Arbre de décision dans R : arbre de classification avec exemple

Que sont les arbres de décision ?

Arbres de décision sont des algorithmes d'apprentissage automatique polyvalents qui peuvent effectuer à la fois des tâches de classification et de régression. Ce sont des algorithmes très puissants, capables d’ajuster des ensembles de données complexes. En outre, les arbres de décision sont des composants fondamentaux des forêts aléatoires, qui comptent parmi les algorithmes d’apprentissage automatique les plus puissants disponibles aujourd’hui.

Formation et visualisation d'un arbre de décision dans R

Pour construire votre premier arbre de décision dans l'exemple R, nous procéderons comme suit dans ce didacticiel d'arbre de décision :

  • Étape 1 : Importer les données
  • Étape 2 : Nettoyer l'ensemble de données
  • Étape 3 : Créer un ensemble d'entraînement/de test
  • Étape 4 : Construire le modèle
  • Étape 5 : Faire une prédiction
  • Étape 6 : Mesurer les performances
  • Étape 7 : Ajustez les hyper-paramètres

Étape 1) Importez les données

Si vous êtes curieux de connaître le sort du Titanic, vous pouvez regarder cette vidéo sur Youtube. Le but de cet ensemble de données est de prédire quelles personnes ont le plus de chances de survivre après une collision avec un iceberg. L'ensemble de données contient 13 variables et 1309 observations. L'ensemble de données est ordonné par la variable X.

set.seed(678)
path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'
titanic <-read.csv(path)
head(titanic)

Sortie :

##   X pclass survived                                            name    sex
## 1 1      1        1                   Allen, Miss. Elisabeth Walton female
## 2 2      1        1                  Allison, Master. Hudson Trevor   male
## 3 3      1        0                    Allison, Miss. Helen Loraine female
## 4 4      1        0            Allison, Mr. Hudson Joshua Creighton   male
## 5 5      1        0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female
## 6 6      1        1                             Anderson, Mr. Harry   male
##       age sibsp parch ticket     fare   cabin embarked
## 1 29.0000     0     0  24160 211.3375      B5        S
## 2  0.9167     1     2 113781 151.5500 C22 C26        S
## 3  2.0000     1     2 113781 151.5500 C22 C26        S
## 4 30.0000     1     2 113781 151.5500 C22 C26        S
## 5 25.0000     1     2 113781 151.5500 C22 C26        S
## 6 48.0000     0     0  19952  26.5500     E12        S
##                         home.dest
## 1                    St Louis, MO
## 2 Montreal, PQ / Chesterville, ON
## 3 Montreal, PQ / Chesterville, ON
## 4 Montreal, PQ / Chesterville, ON
## 5 Montreal, PQ / Chesterville, ON
## 6                    New York, NY
tail(titanic)

Sortie :

##         X pclass survived                      name    sex  age sibsp
## 1304 1304      3        0     Yousseff, Mr. Gerious   male   NA     0
## 1305 1305      3        0      Zabour, Miss. Hileni female 14.5     1
## 1306 1306      3        0     Zabour, Miss. Thamine female   NA     1
## 1307 1307      3        0 Zakarian, Mr. Mapriededer   male 26.5     0
## 1308 1308      3        0       Zakarian, Mr. Ortin   male 27.0     0
## 1309 1309      3        0        Zimmerman, Mr. Leo   male 29.0     0
##      parch ticket    fare cabin embarked home.dest
## 1304     0   2627 14.4583              C          
## 1305     0   2665 14.4542              C          
## 1306     0   2665 14.4542              C          
## 1307     0   2656  7.2250              C          
## 1308     0   2670  7.2250              C          
## 1309     0 315082  7.8750              S

À partir des sorties head et tail, vous pouvez remarquer que les données ne sont pas mélangées. C'est un gros problème! Lorsque vous diviserez vos données entre une rame et une rame de test, vous sélectionnerez uniquement le passager des classes 1 et 2 (aucun passager de la classe 3 ne figure dans les 80 % des observations), ce qui signifie que l'algorithme ne verra jamais les caractéristiques du passager de la classe 3. Cette erreur entraînera une mauvaise prédiction.

Pour résoudre ce problème, vous pouvez utiliser la fonction sample().

shuffle_index <- sample(1:nrow(titanic))
head(shuffle_index)

Arbre de décision Code R Explication

  • sample(1:nrow(titanic)) : génère une liste aléatoire d'index de 1 à 1309 (c'est-à-dire le nombre maximum de lignes).

Sortie :

## [1]  288  874 1078  633  887  992

Vous utiliserez cet index pour mélanger l'ensemble de données titanesque.

titanic <- titanic[shuffle_index, ]
head(titanic)

Sortie :

##         X pclass survived
## 288   288      1        0
## 874   874      3        0
## 1078 1078      3        1
## 633   633      3        0
## 887   887      3        1
## 992   992      3        1
##                                                           name    sex age
## 288                                      Sutton, Mr. Frederick   male  61
## 874                   Humblen, Mr. Adolf Mathias Nicolai Olsen   male  42
## 1078                                 O'Driscoll, Miss. Bridget female  NA
## 633  Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female  39
## 887                                        Jermyn, Miss. Annie female  NA
## 992                                           Mamee, Mr. Hanna   male  NA
##      sibsp parch ticket    fare cabin embarked           home.dest## 288      0     0  36963 32.3208   D50        S     Haddenfield, NJ
## 874      0     0 348121  7.6500 F G63        S                    
## 1078     0     0  14311  7.7500              Q                    
## 633      1     5 347082 31.2750              S Sweden Winnipeg, MN
## 887      0     0  14313  7.7500              Q                    
## 992      0     0   2677  7.2292              C	

Étape 2) Nettoyer l'ensemble de données

La structure des données montre que certaines variables ont des NA. Le nettoyage des données doit être effectué comme suit

  • Supprimez les variables home.dest, cabine, nom, X et ticket
  • Créer des variables factorielles pour pclass et survécu
  • Abandonnez le NA
library(dplyr)
# Drop variables
clean_titanic <- titanic % > %
select(-c(home.dest, cabin, name, X, ticket)) % > % 
#Convert to factor level
	mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),
	survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %
na.omit()
glimpse(clean_titanic)

Explication du code

  • select(-c(home.dest, cabin, name, X, ticket)) : supprimez les variables inutiles
  • pclass = factor(pclass,levels = c(1,2,3), labels= c('Upper', 'Middle', 'Lower')) : Ajoutez une étiquette à la variable pclass. 1 devient supérieur, 2 devient MOYEN et 3 devient inférieur
  • factor(survived,levels = c(0,1), labels = c('No', 'Yes')) : Ajoutez une étiquette à la variable survécue. 1 devient non et 2 devient oui
  • na.omit() : Supprime les observations NA

Sortie :

## Observations: 1,045
## Variables: 8
## $ pclass   <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U...
## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y...
## $ sex      <fctr> male, male, female, female, male, male, female, male...
## $ age      <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ...
## $ sibsp    <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,...
## $ parch    <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,...
## $ fare     <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ...
## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...		

Étape 3) Créer un ensemble d'entraînement/test

Avant d'entraîner votre modèle, vous devez effectuer deux étapes :

  • Créez un ensemble de train et de test : vous entraînez le modèle sur la rame et testez la prédiction sur l'ensemble de test (c'est-à-dire des données invisibles)
  • Installez rpart.plot depuis la console

La pratique courante consiste à diviser les données à 80/20, 80 % des données servant à entraîner le modèle et 20 % à faire des prédictions. Vous devez créer deux blocs de données distincts. Vous ne voulez pas toucher à l'ensemble de test tant que vous n'avez pas fini de créer votre modèle. Vous pouvez créer un nom de fonction create_train_test() qui prend trois arguments.

create_train_test(df, size = 0.8, train = TRUE)
arguments:
-df: Dataset used to train the model.
-size: Size of the split. By default, 0.8. Numerical value
-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {
    n_row = nrow(data)
    total_row = size * n_row
    train_sample < - 1: total_row
    if (train == TRUE) {
        return (data[train_sample, ])
    } else {
        return (data[-train_sample, ])
    }
}

Explication du code

  • function(data, size=0.8, train = TRUE) : ​​Ajoutez les arguments dans la fonction
  • n_row = nrow(data) : compte le nombre de lignes dans l'ensemble de données
  • total_row = size*n_row : renvoie la nième ligne pour construire la rame
  • train_sample <- 1:total_row : sélectionnez la première ligne jusqu'à la nième ligne
  • if (train ==TRUE){ } else { } : si la condition est définie sur true, renvoie la rame, sinon l'ensemble de test.

Vous pouvez tester votre fonction et vérifier la dimension.

data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)
data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)
dim(data_train)

Sortie :

## [1] 836   8
dim(data_test)

Sortie :

## [1] 209   8

L'ensemble de données de train comporte 1046 262 lignes tandis que l'ensemble de données de test comporte lignes.

Vous utilisez la fonction prop.table() combinée avec table() pour vérifier si le processus de randomisation est correct.

prop.table(table(data_train$survived))

Sortie :

##
##        No       Yes 
## 0.5944976 0.4055024
prop.table(table(data_test$survived))

Sortie :

## 
##        No       Yes 
## 0.5789474 0.4210526

Dans les deux ensembles de données, le nombre de survivants est le même, environ 40 pour cent.

Installer rpart.plot

rpart.plot n'est pas disponible dans les bibliothèques conda. Vous pouvez l'installer depuis la console :

install.packages("rpart.plot")

Étape 4) Construire le modèle

Vous êtes prêt à construire le modèle. La syntaxe de la fonction d'arbre de décision Rpart est la suivante :

rpart(formula, data=, method='')
arguments:			
- formula: The function to predict
- data: Specifies the data frame- method: 			
- "class" for a classification tree 			
- "anova" for a regression tree	

Vous utilisez la méthode de classe parce que vous prédisez une classe.

library(rpart)
library(rpart.plot)
fit <- rpart(survived~., data = data_train, method = 'class')
rpart.plot(fit, extra = 106

Explication du code

  • rpart() : Fonction pour ajuster le modèle. Les arguments sont :
    • survécu ~.: Formule des arbres de décision
    • data = data_train : ensemble de données
    • method = 'class' : Ajuster un modèle binaire
  • rpart.plot(fit, extra= 106) : tracez l'arborescence. Les fonctionnalités supplémentaires sont définies sur 101 pour afficher la probabilité de la 2ème classe (utile pour les réponses binaires). Vous pouvez vous référer au vignette pour plus d’informations sur les autres choix.

Sortie :

Construire un modèle d'arbres de décision dans R

Vous commencez au nœud racine (profondeur 0 sur 3, haut du graphique) :

  1. Au sommet, c’est la probabilité globale de survie. Il montre la proportion de passagers qui ont survécu à l'accident. 41 pour cent des passagers ont survécu.
  2. Ce nœud demande si le sexe du passager est un homme. Si oui, alors vous descendez jusqu'au nœud enfant gauche de la racine (profondeur 2). 63 pour cent sont des hommes avec une probabilité de survie de 21 pour cent.
  3. Dans le deuxième nœud, vous demandez si le passager masculin a plus de 3.5 ans. Si oui, les chances de survie sont de 19 pour cent.
  4. Vous continuez ainsi pour comprendre quelles caractéristiques ont un impact sur les chances de survie.

Notez que l’une des nombreuses qualités des arbres de décision est qu’ils nécessitent très peu de préparation des données. En particulier, ils ne nécessitent pas de mise à l’échelle ou de centrage des fonctionnalités.

Par défaut, la fonction rpart() utilise le Gini mesure d'impureté pour diviser la note. Plus le coefficient de Gini est élevé, plus il y a d'instances différentes au sein du nœud.

Étape 5) Faites une prédiction

Vous pouvez prédire votre ensemble de données de test. Pour faire une prédiction, vous pouvez utiliser la fonction prédire(). La syntaxe de base de la prévision pour l'arbre de décision R est :

predict(fitted_model, df, type = 'class')
arguments:
- fitted_model: This is the object stored after model estimation. 
- df: Data frame used to make the prediction
- type: Type of prediction			
    - 'class': for classification			
    - 'prob': to compute the probability of each class			
    - 'vector': Predict the mean response at the node level	

Vous souhaitez prédire quels passagers sont les plus susceptibles de survivre après la collision à partir de l'ensemble de test. Cela signifie que vous saurez parmi ces 209 passagers lequel survivra ou non.

predict_unseen <-predict(fit, data_test, type = 'class')

Explication du code

  • predict(fit, data_test, type = 'class') : prédire la classe (0/1) de l'ensemble de test

Tester le passager qui n’a pas réussi et ceux qui l’ont fait.

table_mat <- table(data_test$survived, predict_unseen)
table_mat

Explication du code

  • table(data_test$survived, prédict_unseen) : créez un tableau pour compter le nombre de passagers classés comme survivants et décédés par rapport à la classification correcte de l'arbre de décision dans R

Sortie :

##      predict_unseen
##        No Yes
##   No  106  15
##   Yes  30  58

Le modèle a correctement prédit 106 passagers morts, mais a classé 15 survivants comme morts. Par analogie, le modèle a classé à tort 30 passagers comme survivants alors qu’ils se sont révélés morts.

Étape 6) Mesurer les performances

Vous pouvez calculer une mesure de précision pour une tâche de classification avec l'outil matrice de confusion:

La série matrice de confusion est un meilleur choix pour évaluer les performances de classification. L'idée générale est de compter le nombre de fois où les instances vraies sont classées comme étant fausses.

Mesurer les performances des arbres de décision dans R

Chaque ligne d'une matrice de confusion représente une cible réelle, tandis que chaque colonne représente une cible prévue. La première ligne de cette matrice considère les passagers morts (la classe Faux) : 106 ont été correctement classés comme morts (Vrai négatif), tandis que le dernier a été classé à tort comme survivant (Faux positif). La deuxième ligne considère les survivants, la classe positive était de 58 (Vrai positif), tandis que le Vrai négatif était 30.

Vous pouvez calculer le test de précision de la matrice de confusion :

Mesurer les performances des arbres de décision dans R

C'est la proportion de vrais positifs et de vrais négatifs sur la somme de la matrice. Avec R, vous pouvez coder comme suit :

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)

Explication du code

  • sum(diag(table_mat)) : Somme de la diagonale
  • sum(table_mat) : Somme de la matrice.

Vous pouvez imprimer la précision de l'ensemble de test :

print(paste('Accuracy for test', accuracy_Test))

Sortie :

## [1] "Accuracy for test 0.784688995215311"

Vous avez un score de 78 pour cent pour l’ensemble de tests. Vous pouvez reproduire le même exercice avec l'ensemble de données d'entraînement.

Étape 7) Ajustez les hyper-paramètres

L'arbre de décision dans R comporte divers paramètres qui contrôlent les aspects de l'ajustement. Dans la bibliothèque d'arbres de décision rpart, vous pouvez contrôler les paramètres à l'aide de la fonction rpart.control(). Dans le code suivant, vous introduisez les paramètres que vous allez régler. Vous pouvez vous référer au vignette pour les autres paramètres.

rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)
Arguments:
-minsplit: Set the minimum number of observations in the node before the algorithm perform a split
-minbucket:  Set the minimum number of observations in the final note i.e. the leaf
-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0

Nous procéderons de la manière suivante :

  • Construire une fonction pour renvoyer la précision
  • Ajustez la profondeur maximale
  • Ajustez le nombre minimum d'échantillons qu'un nœud doit avoir avant de pouvoir se diviser
  • Ajustez le nombre minimum d’échantillons qu’un nœud feuille doit avoir

Vous pouvez écrire une fonction pour afficher la précision. Vous enveloppez simplement le code que vous avez utilisé auparavant :

  1. prédire : prédire_unseen <- prédire (fit, data_test, type = 'class')
  2. Produire la table : table_mat <- table(data_test$survived, prédict_unseen)
  3. Précision du calcul : précision_Test <- sum(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) {
    predict_unseen <- predict(fit, data_test, type = 'class')
    table_mat <- table(data_test$survived, predict_unseen)
    accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
    accuracy_Test
}

Vous pouvez essayer d'ajuster les paramètres et voir si vous pouvez améliorer le modèle par rapport à la valeur par défaut. Pour rappel, il faut obtenir une précision supérieure à 0.78

control <- rpart.control(minsplit = 4,
    minbucket = round(5 / 3),
    maxdepth = 3,
    cp = 0)
tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)
accuracy_tune(tune_fit)

Sortie :

## [1] 0.7990431

Avec le paramètre suivant :

minsplit = 4
minbucket= round(5/3)
maxdepth = 3cp=0

Vous obtenez des performances supérieures à celles du modèle précédent. Félicitation !

Résumé

Nous pouvons résumer les fonctions pour entraîner un algorithme d'arbre de décision dans R

Bibliothèque Objectif Fonction Classe Paramètres Détails
partie Arbre de classification des trains dans R rpart() classe formule, df, méthode
partie Arbre de régression de train rpart() anova formule, df, méthode
partie Tracer les arbres rpart.plot() modèle ajusté
base prévoir prédire() classe modèle ajusté, type
base prévoir prédire() problèmes modèle ajusté, type
base prévoir prédire() vecteur modèle ajusté, type
partie Paramètres de contrôle rpart.control() minsplit Définissez le nombre minimum d'observations dans le nœud avant que l'algorithme effectue une division
minbucket Définissez le nombre minimum d'observations dans la note finale, c'est-à-dire la feuille
profondeur max Définissez la profondeur maximale de n’importe quel nœud de l’arborescence finale. Le nœud racine est traité avec une profondeur 0
partie Modèle de train avec paramètre de contrôle rpart() formule, df, méthode, contrôle

Remarque : entraînez le modèle sur des données d'entraînement et testez les performances sur un ensemble de données invisible, c'est-à-dire un ensemble de test.