Árbol de decisión en R: Árbol de clasificación con ejemplo
¿Qué son los árboles de decisión?
Árboles de decisión Son algoritmos de aprendizaje automático versátiles que pueden realizar tareas tanto de clasificación como de regresión. Son algoritmos muy potentes, capaces de ajustar conjuntos de datos complejos. Además, los árboles de decisión son componentes fundamentales de los bosques aleatorios, que se encuentran entre los algoritmos de aprendizaje automático más potentes disponibles en la actualidad.
Entrenamiento y visualización de árboles de decisión en R
Para construir su primer árbol de decisión en el ejemplo de R, procederemos de la siguiente manera en este tutorial de Árbol de decisión:
- Paso 1: importar los datos
- Paso 2: limpiar el conjunto de datos
- Paso 3: Crear tren/conjunto de prueba
- Paso 4: construye el modelo
- Paso 5: haz una predicción
- Paso 6: medir el rendimiento
- Paso 7: ajuste los hiperparámetros
Paso 1) Importar los datos
Si tienes curiosidad sobre el destino del Titanic, puedes ver este vídeo en YouTube. El propósito de este conjunto de datos es predecir qué personas tienen más probabilidades de sobrevivir después de la colisión con el iceberg. El conjunto de datos contiene 13 variables y 1309 observaciones. El conjunto de datos está ordenado por la variable X.
set.seed(678) path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv' titanic <-read.csv(path) head(titanic)
Salida:
## X pclass survived name sex ## 1 1 1 1 Allen, Miss. Elisabeth Walton female ## 2 2 1 1 Allison, Master. Hudson Trevor male ## 3 3 1 0 Allison, Miss. Helen Loraine female ## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male ## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female ## 6 6 1 1 Anderson, Mr. Harry male ## age sibsp parch ticket fare cabin embarked ## 1 29.0000 0 0 24160 211.3375 B5 S ## 2 0.9167 1 2 113781 151.5500 C22 C26 S ## 3 2.0000 1 2 113781 151.5500 C22 C26 S ## 4 30.0000 1 2 113781 151.5500 C22 C26 S ## 5 25.0000 1 2 113781 151.5500 C22 C26 S ## 6 48.0000 0 0 19952 26.5500 E12 S ## home.dest ## 1 St Louis, MO ## 2 Montreal, PQ / Chesterville, ON ## 3 Montreal, PQ / Chesterville, ON ## 4 Montreal, PQ / Chesterville, ON ## 5 Montreal, PQ / Chesterville, ON ## 6 New York, NY
tail(titanic)
Salida:
## X pclass survived name sex age sibsp ## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0 ## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1 ## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1 ## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0 ## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0 ## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0 ## parch ticket fare cabin embarked home.dest ## 1304 0 2627 14.4583 C ## 1305 0 2665 14.4542 C ## 1306 0 2665 14.4542 C ## 1307 0 2656 7.2250 C ## 1308 0 2670 7.2250 C ## 1309 0 315082 7.8750 S
Desde la salida inicial y final, puede notar que los datos no están mezclados. ¡Este es un gran problema! Cuando divida sus datos entre un conjunto de tren y un conjunto de prueba, seleccionará only el pasajero de las clases 1 y 2 (ningún pasajero de la clase 3 está en el 80 por ciento superior de las observaciones), lo que significa que el algoritmo nunca verá las características del pasajero de la clase 3. Este error conducirá a una mala predicción.
Para superar este problema, puede utilizar la función sample().
shuffle_index <- sample(1:nrow(titanic)) head(shuffle_index)
Árbol de decisión Código R Explicación
- muestra (1: nrow (titanic)): genera una lista aleatoria de índices del 1 al 1309 (es decir, el número máximo de filas).
Salida:
## [1] 288 874 1078 633 887 992
Utilizará este índice para mezclar el conjunto de datos del Titanic.
titanic <- titanic[shuffle_index, ] head(titanic)
Salida:
## X pclass survived ## 288 288 1 0 ## 874 874 3 0 ## 1078 1078 3 1 ## 633 633 3 0 ## 887 887 3 1 ## 992 992 3 1 ## name sex age ## 288 Sutton, Mr. Frederick male 61 ## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42 ## 1078 O'Driscoll, Miss. Bridget female NA ## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39 ## 887 Jermyn, Miss. Annie female NA ## 992 Mamee, Mr. Hanna male NA ## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ ## 874 0 0 348121 7.6500 F G63 S ## 1078 0 0 14311 7.7500 Q ## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN ## 887 0 0 14313 7.7500 Q ## 992 0 0 2677 7.2292 C
Paso 2) Limpiar el conjunto de datos
La estructura de los datos muestra que algunas variables tienen NA. La limpieza de datos se realizará de la siguiente manera
- Eliminar variables home.dest,cabin, nombre, X y ticket
- Crear variables de factor para pclass y sobrevivió
- Deja la NA
library(dplyr) # Drop variables clean_titanic <- titanic % > % select(-c(home.dest, cabin, name, X, ticket)) % > % #Convert to factor level mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')), survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > % na.omit() glimpse(clean_titanic)
Explicación del código
- select(-c(home.dest, cabina, nombre, X, boleto)): Elimina variables innecesarias
- pclass = factor(pclass, niveles = c(1,2,3), etiquetas = c('Upper', 'Middle', 'Lower')): Agrega etiqueta a la variable pclass. 1 se vuelve superior, 2 se vuelve medio y 3 se vuelve inferior
- factor (sobrevivido, niveles = c (0,1), etiquetas = c ('No', 'Sí')): agregue una etiqueta a la variable sobrevivida. 1 se convierte en No y 2 se convierte en Sí
- na.omit(): Elimina las observaciones de NA
Salida:
## Observations: 1,045 ## Variables: 8 ## $ pclass <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U... ## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y... ## $ sex <fctr> male, male, female, female, male, male, female, male... ## $ age <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ... ## $ sibsp <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,... ## $ parch <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,... ## $ fare <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ... ## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...
Paso 3) Crear tren/conjunto de prueba
Antes de entrenar su modelo, debe realizar dos pasos:
- Cree un tren y un conjunto de prueba: entrene el modelo en el conjunto de trenes y pruebe la predicción en el conjunto de prueba (es decir, datos no vistos).
- Instale rpart.plot desde la consola
La práctica común es dividir los datos 80/20, el 80 por ciento de los datos sirve para entrenar el modelo y el 20 por ciento para hacer predicciones. Necesita crear dos marcos de datos separados. No querrás tocar el conjunto de prueba hasta que termines de construir tu modelo. Puede crear un nombre de función create_train_test() que tome tres argumentos.
create_train_test(df, size = 0.8, train = TRUE) arguments: -df: Dataset used to train the model. -size: Size of the split. By default, 0.8. Numerical value -train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample < - 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } }
Explicación del código
- función (datos, tamaño = 0.8, tren = VERDADERO): agrega los argumentos en la función
- n_row = nrow(data): cuenta el número de filas en el conjunto de datos
- total_row = size*n_row: Devuelve la enésima fila para construir el conjunto de trenes
- train_sample <- 1:total_row: seleccione la primera fila hasta la enésima fila
- if (train ==TRUE){ } else { }: si la condición se establece en verdadera, devuelve el conjunto de tren; de lo contrario, el conjunto de prueba.
Puede probar su función y verificar la dimensión.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE) data_test <- create_train_test(clean_titanic, 0.8, train = FALSE) dim(data_train)
Salida:
## [1] 836 8
dim(data_test)
Salida:
## [1] 209 8
El conjunto de datos del tren tiene 1046 filas, mientras que el conjunto de datos de prueba tiene 262 filas.
Utiliza la función prop.table() combinada con table() para verificar si el proceso de aleatorización es correcto.
prop.table(table(data_train$survived))
Salida:
## ## No Yes ## 0.5944976 0.4055024
prop.table(table(data_test$survived))
Salida:
## ## No Yes ## 0.5789474 0.4210526
En ambos conjuntos de datos, la cantidad de supervivientes es la misma, alrededor del 40 por ciento.
Instalar rpart.plot
rpart.plot no está disponible en las bibliotecas conda. Puedes instalarlo desde la consola:
install.packages("rpart.plot")
Paso 4) Construye el modelo
Estás listo para construir el modelo. La sintaxis de la función del árbol de decisión Rpart es:
rpart(formula, data=, method='') arguments: - formula: The function to predict - data: Specifies the data frame- method: - "class" for a classification tree - "anova" for a regression tree
Utiliza el método de clase porque predice una clase.
library(rpart) library(rpart.plot) fit <- rpart(survived~., data = data_train, method = 'class') rpart.plot(fit, extra = 106
Explicación del código
- rpart(): Función para ajustar el modelo. Los argumentos son:
- sobrevivió ~.: Fórmula de los árboles de decisión
- datos = data_train: conjunto de datos
- método = 'clase': Ajustar un modelo binario
- rpart.plot(fit, extra= 106): traza el árbol. Las características adicionales se establecen en 101 para mostrar la probabilidad de la segunda clase (útil para respuestas binarias). Puedes consultar el viñeta para obtener más información sobre las otras opciones.
Salida:
Comienza en el nodo raíz (profundidad 0 sobre 3, la parte superior del gráfico):
- En la parte superior está la probabilidad global de supervivencia. Muestra la proporción de pasajeros que sobrevivieron al accidente. El 41 por ciento de los pasajeros sobrevivió.
- Este nodo pregunta si el sexo del pasajero es masculino. En caso afirmativo, baje al nodo secundario izquierdo de la raíz (profundidad 2). El 63 por ciento son hombres con una probabilidad de supervivencia del 21 por ciento.
- En el segundo nodo se pregunta si el pasajero varón tiene más de 3.5 años. En caso afirmativo, la probabilidad de supervivencia es del 19 por ciento.
- Continúe así para comprender qué características afectan la probabilidad de supervivencia.
Tenga en cuenta que una de las muchas cualidades de los árboles de decisión es que requieren muy poca preparación de datos. En particular, no requieren escalado ni centrado de funciones.
Por defecto, la función rpart() utiliza el Gini Medida de impureza para dividir el billete. Cuanto mayor sea el coeficiente de Gini, más instancias diferentes habrá dentro del nodo.
Paso 5) Haz una predicción
Puede predecir su conjunto de datos de prueba. Para hacer una predicción, puede utilizar la función predict(). La sintaxis básica de predicción para el árbol de decisión R es:
predict(fitted_model, df, type = 'class') arguments: - fitted_model: This is the object stored after model estimation. - df: Data frame used to make the prediction - type: Type of prediction - 'class': for classification - 'prob': to compute the probability of each class - 'vector': Predict the mean response at the node level
Quiere predecir qué pasajeros tienen más probabilidades de sobrevivir después de la colisión a partir del conjunto de prueba. Es decir, sabrás entre esos 209 pasajeros cuál sobrevivirá o no.
predict_unseen <-predict(fit, data_test, type = 'class')
Explicación del código
- predecir (ajuste, prueba_datos, tipo = 'clase'): predice la clase (0/1) del conjunto de prueba
Probando al pasajero que no lo logró y a los que sí lo lograron.
table_mat <- table(data_test$survived, predict_unseen) table_mat
Explicación del código
- table(data_test$survived, predict_unseen): cree una tabla para contar cuántos pasajeros se clasifican como sobrevivientes y fallecieron en comparación con la clasificación correcta del árbol de decisión en R
Salida:
## predict_unseen ## No Yes ## No 106 15 ## Yes 30 58
El modelo predijo correctamente 106 pasajeros muertos, pero clasificó a 15 supervivientes como muertos. Por analogía, el modelo clasificó erróneamente a 30 pasajeros como supervivientes cuando resultaron estar muertos.
Paso 6) Medir el desempeño
Puede calcular una medida de precisión para la tarea de clasificación con el matriz de confusión:
Los matriz de confusión es una mejor opción para evaluar el rendimiento de la clasificación. La idea general es contar el número de veces que las instancias Verdaderas se clasifican como Falsas.
Cada fila de una matriz de confusión representa un objetivo real, mientras que cada columna representa un objetivo previsto. La primera fila de esta matriz considera pasajeros fallecidos (la clase Falsa): 106 fueron clasificados correctamente como muertos (Verdadero-negativo), mientras que el restante fue clasificado erróneamente como superviviente (Falso positivo). La segunda fila considera a los supervivientes, la clase positiva fueron 58 (Verdadero positivo), Mientras que el Verdadero-negativo fue 30.
Puede calcular el prueba de precisión de la matriz de confusión:
Es la proporción de verdaderos positivos y verdaderos negativos sobre la suma de la matriz. Con R, puedes codificar de la siguiente manera:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Explicación del código
- sum(diag(table_mat)): Suma de la diagonal
- sum(table_mat): Suma de la matriz.
Puede imprimir la precisión del conjunto de prueba:
print(paste('Accuracy for test', accuracy_Test))
Salida:
## [1] "Accuracy for test 0.784688995215311"
Tiene una puntuación del 78 por ciento en el conjunto de pruebas. Puedes replicar el mismo ejercicio con el conjunto de datos de entrenamiento.
Paso 7) Ajusta los hiperparámetros
El árbol de decisión en R tiene varios parámetros que controlan aspectos del ajuste. En la biblioteca de árboles de decisión de rpart, puede controlar los parámetros mediante la función rpart.control(). En el siguiente código, introduce los parámetros que ajustará. Puede consultar la viñeta para otros parámetros.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30) Arguments: -minsplit: Set the minimum number of observations in the node before the algorithm perform a split -minbucket: Set the minimum number of observations in the final note i.e. the leaf -maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
Procederemos de la siguiente manera:
- Construir función para devolver precisión
- Ajusta la profundidad máxima
- Ajuste la cantidad mínima de muestra que debe tener un nodo antes de poder dividirse
- Ajuste el número mínimo de muestras que debe tener un nodo hoja
Puede escribir una función para mostrar la precisión. Simplemente envuelve el código que usaste antes:
- predecir: predecir_unseen <- predecir (ajuste, prueba_datos, tipo = 'clase')
- Producir tabla: table_mat <- table(data_test$survived, predict_unseen)
- Precisión de cálculo: exactitud_Test <- sum(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) { predict_unseen <- predict(fit, data_test, type = 'class') table_mat <- table(data_test$survived, predict_unseen) accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test }
Puede intentar ajustar los parámetros y ver si puede mejorar el modelo con respecto al valor predeterminado. Como recordatorio, debe obtener una precisión superior a 0.78.
control <- rpart.control(minsplit = 4, minbucket = round(5 / 3), maxdepth = 3, cp = 0) tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control) accuracy_tune(tune_fit)
Salida:
## [1] 0.7990431
Con el siguiente parámetro:
minsplit = 4 minbucket= round(5/3) maxdepth = 3cp=0
Obtienes un rendimiento superior al del modelo anterior. ¡Felicitaciones!
Resum
Podemos resumir las funciones para entrenar un algoritmo de árbol de decisión en R
Biblioteca | Objetivo | Función | Clase | parámetros | Detalles |
---|---|---|---|---|---|
parte | Árbol de clasificación de trenes en R | parte() | clase | fórmula, df, método | |
parte | Tren de árbol de regresión | parte() | anova | fórmula, df, método | |
parte | Trazar los árboles | rpart.plot() | modelo ajustado | ||
bases | predecir | predecir() | clase | modelo equipado, tipo | |
bases | predecir | predecir() | problema | modelo equipado, tipo | |
bases | predecir | predecir() | vector | modelo equipado, tipo | |
parte | Parámetros de control | rpart.control() | división mínima | Establezca el número mínimo de observaciones en el nodo antes de que el algoritmo realice una división | |
minbucket | Establezca el número mínimo de observaciones en la nota final, es decir, la hoja. | ||||
máxima profundidad | Establece la profundidad máxima de cualquier nodo del árbol final. El nodo raíz se trata con una profundidad 0. | ||||
parte | Modelo de tren con parámetro de control. | parte() | fórmula, df, método, control |
Nota: Entrene el modelo con datos de entrenamiento y pruebe el rendimiento en un conjunto de datos invisible, es decir, un conjunto de prueba.