Tutoriel R Random Forest avec exemple

Quโ€™est-ce que la forรชt alรฉatoire dans R ?

Les forรชts alรฉatoires reposent sur une idรฉe simple : ยซ la sagesse de la foule ยป. Lโ€™agrรฉgation des rรฉsultats de plusieurs prรฉdicteurs donne une meilleure prรฉdiction que le meilleur prรฉdicteur individuel. Un groupe de prรฉdicteurs est appelรฉ un ensemble. Ainsi, cette technique est appelรฉe Apprentissage d'ensemble.

Dans le didacticiel prรฉcรฉdent, vous avez appris ร  utiliser Arbres de dรฉcision faire une prรฉdiction binaire. Pour amรฉliorer notre technique, nous pouvons former un groupe de Classificateurs d'arbre de dรฉcision, chacun sur un sous-ensemble alรฉatoire diffรฉrent de la rame. Pour faire une prรฉdiction, nous obtenons simplement les prรฉdictions de tous les arbres individuels, puis prรฉdisons la classe qui obtient le plus de votes. Cette technique est appelรฉe Forรชt alรฉatoire.

ร‰tape 1) Importez les donnรฉes

Pour vous assurer que vous disposez du mรชme ensemble de donnรฉes que dans le didacticiel pour arbres de dรฉcision, le test du train et l'ensemble de tests sont stockรฉs sur Internet. Vous pouvez les importer sans apporter aucune modification.

library(dplyr)
data_train <- read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/train.csv")
glimpse(data_train)
data_test <- read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/test.csv") 
glimpse(data_test)

ร‰tape 2) Entraรฎner le modรจle

Une faรงon d'รฉvaluer les performances d'un modรจle consiste ร  l'entraรฎner sur un certain nombre d'ensembles de donnรฉes diffรฉrents plus petits et ร  les รฉvaluer sur l'autre ensemble de tests plus petit. C'est ce qu'on appelle le Validation croisรฉe F-fold . R a une fonction pour diviser alรฉatoirement un nombre dโ€™ensembles de donnรฉes presque de la mรชme taille. Par exemple, si k=9, le modรจle est รฉvaluรฉ sur les neuf dossiers et testรฉ sur l'ensemble de test restant. Ce processus est rรฉpรฉtรฉ jusqu'ร  ce que tous les sous-ensembles aient รฉtรฉ รฉvaluรฉs. Cette technique est largement utilisรฉe pour la sรฉlection de modรจles, en particulier lorsque le modรจle comporte des paramรจtres ร  rรฉgler.

Maintenant que nous disposons dโ€™un moyen dโ€™รฉvaluer notre modรจle, nous devons dรฉterminer comment choisir les paramรจtres qui gรฉnรฉralisent le mieux les donnรฉes.

La forรชt alรฉatoire choisit un sous-ensemble alรฉatoire de fonctionnalitรฉs et crรฉe de nombreux arbres de dรฉcision. Le modรจle fait la moyenne de toutes les prรฉdictions des arbres de dรฉcision.

La forรชt alรฉatoire possรจde certains paramรจtres qui peuvent รชtre modifiรฉs pour amรฉliorer la gรฉnรฉralisation de la prรฉdiction. Vous utiliserez la fonction RandomForest() pour entraรฎner le modรจle.

La syntaxe de Randon Forest est

RandomForest(formula, ntree=n, mtry=FALSE, maxnodes = NULL)
Arguments:
- Formula: Formula of the fitted model
- ntree: number of trees in the forest
- mtry: Number of candidates draw to feed the algorithm. By default, it is the square of the number of columns.
- maxnodes: Set the maximum amount of terminal nodes in the forest
- importance=TRUE: Whether independent variables importance in the random forest be assessed

Note: La forรชt alรฉatoire peut รชtre entraรฎnรฉe sur plus de paramรจtres. Vous pouvez vous rรฉfรฉrer au vignette pour voir les diffรฉrents paramรจtres.

Le rรฉglage d'un modรจle est un travail trรจs fastidieux. Il existe de nombreuses combinaisons possibles entre les paramรจtres. On n'a pas forcรฉment le temps de tous les essayer. Une bonne alternative consiste ร  laisser la machine trouver la meilleure combinaison pour vous. Deux mรฉthodes sont disponibles :

  • Recherche alรฉatoire
  • Recherche de grille

Nous dรฉfinirons les deux mรฉthodes mais au cours du didacticiel, nous entraรฎnerons le modรจle ร  l'aide de la recherche par grille.

Dรฉfinition de la recherche en grille

La mรฉthode de recherche par grille est simple, le modรจle sera รฉvaluรฉ sur toute la combinaison que vous transmettez dans la fonction, par validation croisรฉe.

Par exemple, vous souhaitez essayer le modรจle avec 10, 20, 30 arbres et chaque arbre sera testรฉ sur un nombre de mรจtres รฉgal ร  1, 2, 3, 4, 5. Ensuite, la machine testera 15 modรจles diffรฉrents :

    .mtry ntrees
 1      1     10
 2      2     10
 3      3     10
 4      4     10
 5      5     10
 6      1     20
 7      2     20
 8      3     20
 9      4     20
 10     5     20
 11     1     30
 12     2     30
 13     3     30
 14     4     30
 15     5     30	

L'algorithme รฉvaluera :

RandomForest(formula, ntree=10, mtry=1)
RandomForest(formula, ntree=10, mtry=2)
RandomForest(formula, ntree=10, mtry=3)
RandomForest(formula, ntree=20, mtry=2)
...

A chaque fois, la forรชt alรฉatoire expรฉrimente une validation croisรฉe. Lโ€™un des inconvรฉnients de la recherche par grille est le nombre dโ€™expรฉrimentations. Cela peut devenir trรจs facilement explosif lorsque le nombre de combinaisons est รฉlevรฉ. Pour surmonter ce problรจme, vous pouvez utiliser la recherche alรฉatoire

Dรฉfinition de la recherche alรฉatoire

La grande diffรฉrence entre la recherche alรฉatoire et la recherche sur grille est que la recherche alรฉatoire n'รฉvaluera pas toutes les combinaisons d'hyperparamรจtres dans l'espace de recherche. Au lieu de cela, il choisira une combinaison au hasard ร  chaque itรฉration. L'avantage est que cela rรฉduit le coรปt de calcul.

Dรฉfinir le paramรจtre de contrรดle

Vous procรฉderez comme suit pour construire et รฉvaluer le modรจle :

  • ร‰valuer le modรจle avec le paramรจtre par dรฉfaut
  • Trouver le meilleur nombre de mtry
  • Trouver le meilleur nombre de maxnodes
  • Trouver le meilleur nombre d'arbres
  • ร‰valuer le modรจle sur l'ensemble de donnรฉes de test

Avant de commencer l'exploration des paramรจtres, vous devez installer deux bibliothรจques.

  • caret : bibliothรจque d'apprentissage automatique R. Si tu as installer R avec r-essentiel. Il est dรฉjร  dans la bibliothรจque
  • e1071 : bibliothรจque dโ€™apprentissage automatique R.

Vous pouvez les importer avec RandomForest

library(randomForest)
library(caret)
library(e1071)

Paramรจtres par dรฉfaut

La validation croisรฉe K-fold est contrรดlรฉe par la fonction trainControl()

trainControl(method = "cv", number = n, search ="grid")
arguments
- method = "cv": The method used to resample the dataset. 
- number = n: Number of folders to create
- search = "grid": Use the search grid method. For randomized method, use "grid"
Note: You can refer to the vignette to see the other arguments of the function.

Vous pouvez essayer d'exรฉcuter le modรจle avec les paramรจtres par dรฉfaut et voir le score de prรฉcision.

Note: Vous utiliserez les mรชmes commandes pendant tout le tutoriel.

# Define the control
trControl <- trainControl(method = "cv",
    number = 10,
    search = "grid")

Vous utiliserez la bibliothรจque caret pour รฉvaluer votre modรจle. La bibliothรจque a une fonction appelรฉe train() pour รฉvaluer presque tout machine learning algorithme. Autrement dit, vous pouvez utiliser cette fonction pour entraรฎner dโ€™autres algorithmes.

La syntaxe de base est:

train(formula, df, method = "rf", metric= "Accuracy", trControl = trainControl(), tuneGrid = NULL)
argument
- `formula`: Define the formula of the algorithm
- `method`: Define which model to train. Note, at the end of the tutorial, there is a list of all the models that can be trained
- `metric` = "Accuracy": Define how to select the optimal model
- `trControl = trainControl()`: Define the control parameters
- `tuneGrid = NULL`: Return a data frame with all the possible combination

Essayons de construire le modรจle avec les valeurs par dรฉfaut.

set.seed(1234)
# Run the model
rf_default <- train(survived~.,
    data = data_train,
    method = "rf",
    metric = "Accuracy",
    trControl = trControl)
# Print the results
print(rf_default)

Explication du code

  • trainControl(method=โ€cvโ€, number=10, search=โ€gridโ€) : ร‰valuez le modรจle avec une recherche de grille de 10 dossiers
  • train(โ€ฆ) : Entraรฎner un modรจle de forรชt alรฉatoire. Le meilleur modรจle est choisi avec la mesure de prรฉcision.

Sortie :

## Random Forest 
## 
## 836 samples
##   7 predictor
##   2 classes: 'No', 'Yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 753, 752, 753, 752, 752, 752, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    2    0.7919248  0.5536486
##    6    0.7811245  0.5391611
##   10    0.7572002  0.4939620
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was mtry = 2.

L'algorithme utilise 500 arbres et a testรฉ trois valeurs diffรฉrentes de mtry : 2, 6, 10.

La valeur finale utilisรฉe pour le modรจle รฉtait mtry = 2 avec une prรฉcision de 0.78. Essayons d'obtenir un score plus รฉlevรฉ.

ร‰tape 2) Recherchez le meilleur essai

Vous pouvez tester le modรจle avec des valeurs de mtry de 1 ร  10

set.seed(1234)
tuneGrid <- expand.grid(.mtry = c(1: 10))
rf_mtry <- train(survived~.,
    data = data_train,
    method = "rf",
    metric = "Accuracy",
    tuneGrid = tuneGrid,
    trControl = trControl,
    importance = TRUE,
    nodesize = 14,
    ntree = 300)
print(rf_mtry)

Explication du code

  • tuneGrid <- expand.grid(.mtry=c(3:10)) : Construisez un vecteur avec une valeur de 3:10

La valeur finale utilisรฉe pour le modรจle รฉtait mtry = 4.

Sortie :

## Random Forest 
## 
## 836 samples
##   7 predictor
##   2 classes: 'No', 'Yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 753, 752, 753, 752, 752, 752, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    1    0.7572576  0.4647368
##    2    0.7979346  0.5662364
##    3    0.8075158  0.5884815
##    4    0.8110729  0.5970664
##    5    0.8074727  0.5900030
##    6    0.8099111  0.5949342
##    7    0.8050918  0.5866415
##    8    0.8050918  0.5855399
##    9    0.8050631  0.5855035
##   10    0.7978916  0.5707336
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was mtry = 4.

La meilleure valeur de mtry est stockรฉe dans :

rf_mtry$bestTune$mtry

Vous pouvez le stocker et l'utiliser lorsque vous devez rรฉgler les autres paramรจtres.

max(rf_mtry$results$Accuracy)

Sortie :

## [1] 0.8110729
best_mtry <- rf_mtry$bestTune$mtry 
best_mtry

Sortie :

## [1] 4

ร‰tape 3) Recherchez les meilleurs maxnodes

Vous devez crรฉer une boucle pour รฉvaluer les diffรฉrentes valeurs de maxnodes. Dans le code suivant, vous allez :

  • Crรฉer une liste
  • Crรฉez une variable avec la meilleure valeur du paramรจtre mtry ; Obligatoire
  • Crรฉer la boucle
  • Stocker la valeur actuelle de maxnode
  • Rรฉsumer les rรฉsultats
store_maxnode <- list()
tuneGrid <- expand.grid(.mtry = best_mtry)
for (maxnodes in c(5: 15)) {
    set.seed(1234)
    rf_maxnode <- train(survived~.,
        data = data_train,
        method = "rf",
        metric = "Accuracy",
        tuneGrid = tuneGrid,
        trControl = trControl,
        importance = TRUE,
        nodesize = 14,
        maxnodes = maxnodes,
        ntree = 300)
    current_iteration <- toString(maxnodes)
    store_maxnode[[current_iteration]] <- rf_maxnode
}
results_mtry <- resamples(store_maxnode)
summary(results_mtry)

Explication du code :

  • store_maxnode <- list() : Les rรฉsultats du modรจle seront stockรฉs dans cette liste
  • expand.grid(.mtry=best_mtry) : utiliser la meilleure valeur de mtry
  • for (maxnodes in c(15:25)) { โ€ฆ } : calculez le modรจle avec des valeurs de maxnodes allant de 15 ร  25.
  • maxnodes=maxnodes : Pour chaque itรฉration, maxnodes est รฉgal ร  la valeur actuelle de maxnodes. soit 15, 16, 17,โ€ฆ
  • key <- toString(maxnodes) : stocke sous forme de variable de chaรฎne la valeur de maxnode.
  • store_maxnode[[key]] <- rf_maxnode : Enregistrez le rรฉsultat du modรจle dans la liste.
  • resamples(store_maxnode) : Organiser les rรฉsultats du modรจle
  • summary(results_mtry) : Imprime le rรฉsumรฉ de toutes les combinaisons.

Sortie :

## 
## Call:
## summary.resamples(object = results_mtry)
## 
## Models: 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 
## Number of resamples: 10 
## 
## Accuracy 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 5  0.6785714 0.7529762 0.7903758 0.7799771 0.8168388 0.8433735    0
## 6  0.6904762 0.7648810 0.7784710 0.7811962 0.8125000 0.8313253    0
## 7  0.6904762 0.7619048 0.7738095 0.7788009 0.8102410 0.8333333    0
## 8  0.6904762 0.7627295 0.7844234 0.7847820 0.8184524 0.8433735    0
## 9  0.7261905 0.7747418 0.8083764 0.7955250 0.8258749 0.8333333    0
## 10 0.6904762 0.7837780 0.7904475 0.7895869 0.8214286 0.8433735    0
## 11 0.7023810 0.7791523 0.8024240 0.7943775 0.8184524 0.8433735    0
## 12 0.7380952 0.7910929 0.8144005 0.8051205 0.8288511 0.8452381    0
## 13 0.7142857 0.8005952 0.8192771 0.8075158 0.8403614 0.8452381    0
## 14 0.7380952 0.7941050 0.8203528 0.8098967 0.8403614 0.8452381    0
## 15 0.7142857 0.8000215 0.8203528 0.8075301 0.8378873 0.8554217    0
## 
## Kappa 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 5  0.3297872 0.4640436 0.5459706 0.5270773 0.6068751 0.6717371    0
## 6  0.3576471 0.4981484 0.5248805 0.5366310 0.6031287 0.6480921    0
## 7  0.3576471 0.4927448 0.5192771 0.5297159 0.5996437 0.6508314    0
## 8  0.3576471 0.4848320 0.5408159 0.5427127 0.6200253 0.6717371    0
## 9  0.4236277 0.5074421 0.5859472 0.5601687 0.6228626 0.6480921    0
## 10 0.3576471 0.5255698 0.5527057 0.5497490 0.6204819 0.6717371    0
## 11 0.3794326 0.5235007 0.5783191 0.5600467 0.6126720 0.6717371    0
## 12 0.4460432 0.5480930 0.5999072 0.5808134 0.6296780 0.6717371    0
## 13 0.4014252 0.5725752 0.6087279 0.5875305 0.6576219 0.6678832    0
## 14 0.4460432 0.5585005 0.6117973 0.5911995 0.6590982 0.6717371    0
## 15 0.4014252 0.5689401 0.6117973 0.5867010 0.6507194 0.6955990    0

La derniรจre valeur de maxnode a la plus grande prรฉcision. Vous pouvez essayer avec des valeurs plus รฉlevรฉes pour voir si vous pouvez obtenir un score plus รฉlevรฉ.

store_maxnode <- list()
tuneGrid <- expand.grid(.mtry = best_mtry)
for (maxnodes in c(20: 30)) {
    set.seed(1234)
    rf_maxnode <- train(survived~.,
        data = data_train,
        method = "rf",
        metric = "Accuracy",
        tuneGrid = tuneGrid,
        trControl = trControl,
        importance = TRUE,
        nodesize = 14,
        maxnodes = maxnodes,
        ntree = 300)
    key <- toString(maxnodes)
    store_maxnode[[key]] <- rf_maxnode
}
results_node <- resamples(store_maxnode)
summary(results_node)

Sortie :

## 
## Call:
## summary.resamples(object = results_node)
## 
## Models: 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30 
## Number of resamples: 10 
## 
## Accuracy 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 20 0.7142857 0.7821644 0.8144005 0.8075301 0.8447719 0.8571429    0
## 21 0.7142857 0.8000215 0.8144005 0.8075014 0.8403614 0.8571429    0
## 22 0.7023810 0.7941050 0.8263769 0.8099254 0.8328313 0.8690476    0
## 23 0.7023810 0.7941050 0.8263769 0.8111302 0.8447719 0.8571429    0
## 24 0.7142857 0.7946429 0.8313253 0.8135112 0.8417599 0.8690476    0
## 25 0.7142857 0.7916667 0.8313253 0.8099398 0.8408635 0.8690476    0
## 26 0.7142857 0.7941050 0.8203528 0.8123207 0.8528758 0.8571429    0
## 27 0.7023810 0.8060456 0.8313253 0.8135112 0.8333333 0.8690476    0
## 28 0.7261905 0.7941050 0.8203528 0.8111015 0.8328313 0.8690476    0
## 29 0.7142857 0.7910929 0.8313253 0.8087063 0.8333333 0.8571429    0
## 30 0.6785714 0.7910929 0.8263769 0.8063253 0.8403614 0.8690476    0
## 
## Kappa 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 20 0.3956835 0.5316120 0.5961830 0.5854366 0.6661120 0.6955990    0
## 21 0.3956835 0.5699332 0.5960343 0.5853247 0.6590982 0.6919315    0
## 22 0.3735084 0.5560661 0.6221836 0.5914492 0.6422128 0.7189781    0
## 23 0.3735084 0.5594228 0.6228827 0.5939786 0.6657372 0.6955990    0
## 24 0.3956835 0.5600352 0.6337821 0.5992188 0.6604703 0.7189781    0
## 25 0.3956835 0.5530760 0.6354875 0.5912239 0.6554912 0.7189781    0
## 26 0.3956835 0.5589331 0.6136074 0.5969142 0.6822128 0.6955990    0
## 27 0.3735084 0.5852459 0.6368425 0.5998148 0.6426088 0.7189781    0
## 28 0.4290780 0.5589331 0.6154905 0.5946859 0.6356141 0.7189781    0
## 29 0.4070588 0.5534173 0.6337821 0.5901173 0.6423101 0.6919315    0
## 30 0.3297872 0.5534173 0.6202632 0.5843432 0.6590982 0.7189781    0

Le score de prรฉcision le plus รฉlevรฉ est obtenu avec une valeur de maxnode รฉgale ร  22.

ร‰tape 4) Recherchez les meilleurs ntrees

Maintenant que vous disposez de la meilleure valeur de mtry et maxnode, vous pouvez rรฉgler le nombre d'arbres. La mรฉthode est exactement la mรชme que celle de maxnode.

store_maxtrees <- list()
for (ntree in c(250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000)) {
    set.seed(5678)
    rf_maxtrees <- train(survived~.,
        data = data_train,
        method = "rf",
        metric = "Accuracy",
        tuneGrid = tuneGrid,
        trControl = trControl,
        importance = TRUE,
        nodesize = 14,
        maxnodes = 24,
        ntree = ntree)
    key <- toString(ntree)
    store_maxtrees[[key]] <- rf_maxtrees
}
results_tree <- resamples(store_maxtrees)
summary(results_tree)

Sortie :

## 
## Call:
## summary.resamples(object = results_tree)
## 
## Models: 250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000 
## Number of resamples: 10 
## 
## Accuracy 
##           Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 250  0.7380952 0.7976190 0.8083764 0.8087010 0.8292683 0.8674699    0
## 300  0.7500000 0.7886905 0.8024240 0.8027199 0.8203397 0.8452381    0
## 350  0.7500000 0.7886905 0.8024240 0.8027056 0.8277623 0.8452381    0
## 400  0.7500000 0.7886905 0.8083764 0.8051009 0.8292683 0.8452381    0
## 450  0.7500000 0.7886905 0.8024240 0.8039104 0.8292683 0.8452381    0
## 500  0.7619048 0.7886905 0.8024240 0.8062914 0.8292683 0.8571429    0
## 550  0.7619048 0.7886905 0.8083764 0.8099062 0.8323171 0.8571429    0
## 600  0.7619048 0.7886905 0.8083764 0.8099205 0.8323171 0.8674699    0
## 800  0.7619048 0.7976190 0.8083764 0.8110820 0.8292683 0.8674699    0
## 1000 0.7619048 0.7976190 0.8121510 0.8086723 0.8303571 0.8452381    0
## 2000 0.7619048 0.7886905 0.8121510 0.8086723 0.8333333 0.8452381    0
## 
## Kappa 
##           Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 250  0.4061697 0.5667400 0.5836013 0.5856103 0.6335363 0.7196807    0
## 300  0.4302326 0.5449376 0.5780349 0.5723307 0.6130767 0.6710843    0
## 350  0.4302326 0.5449376 0.5780349 0.5723185 0.6291592 0.6710843    0
## 400  0.4302326 0.5482030 0.5836013 0.5774782 0.6335363 0.6710843    0
## 450  0.4302326 0.5449376 0.5780349 0.5750587 0.6335363 0.6710843    0
## 500  0.4601542 0.5449376 0.5780349 0.5804340 0.6335363 0.6949153    0
## 550  0.4601542 0.5482030 0.5857118 0.5884507 0.6396872 0.6949153    0
## 600  0.4601542 0.5482030 0.5857118 0.5884374 0.6396872 0.7196807    0
## 800  0.4601542 0.5667400 0.5836013 0.5910088 0.6335363 0.7196807    0
## 1000 0.4601542 0.5667400 0.5961590 0.5857446 0.6343666 0.6678832    0
## 2000 0.4601542 0.5482030 0.5961590 0.5862151 0.6440678 0.6656337    0

Vous avez votre modรจle final. Vous pouvez entraรฎner la forรชt alรฉatoire avec les paramรจtres suivants :

  • ntree =800 : 800 arbres seront palissรฉs
  • mtry=4 : 4 fonctionnalitรฉs sont choisies pour chaque itรฉration
  • maxnodes = 24 : 24 nล“uds maximum dans les nล“uds terminaux (feuilles)
fit_rf <- train(survived~.,
    data_train,
    method = "rf",
    metric = "Accuracy",
    tuneGrid = tuneGrid,
    trControl = trControl,
    importance = TRUE,
    nodesize = 14,
    ntree = 800,
    maxnodes = 24)

ร‰tape 5) ร‰valuer le modรจle

Le curseur de bibliothรจque a pour fonction de faire des prรฉdictions.

predict(model, newdata= df)
argument
- `model`: Define the model evaluated before. 
- `newdata`: Define the dataset to make prediction
prediction <-predict(fit_rf, data_test)

Vous pouvez utiliser la prรฉdiction pour calculer la matrice de confusion et voir le score de prรฉcision

confusionMatrix(prediction, data_test$survived)

Sortie :

## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Yes
##        No  110  32
##        Yes  11  56
##                                          
##                Accuracy : 0.7943         
##                  95% CI : (0.733, 0.8469)
##     No Information Rate : 0.5789         
##     P-Value [Acc > NIR] : 3.959e-11      
##                                          
##                   Kappa : 0.5638         
##  Mcnemar's Test P-Value : 0.002289       
##                                          
##             Sensitivity : 0.9091         
##             Specificity : 0.6364         
##          Pos Pred Value : 0.7746         
##          Neg Pred Value : 0.8358         
##              Prevalence : 0.5789         
##          Detection Rate : 0.5263         
##    Detection Prevalence : 0.6794         
##       Balanced Accuracy : 0.7727         
##                                          
##        'Positive' Class : No             
## 

Vous disposez d'une prรฉcision de 0.7943 pour cent, ce qui est supรฉrieur ร  la valeur par dรฉfaut

ร‰tape 6) Visualiser le rรฉsultat

Enfin, vous pouvez examiner l'importance des fonctionnalitรฉs avec la fonction varImp(). Il semble que les caractรฉristiques les plus importantes soient le sexe et lโ€™รขge. Cela nโ€™est pas surprenant car les caractรฉristiques importantes apparaรฎtront probablement plus prรจs de la racine de lโ€™arbre, tandis que les caractรฉristiques moins importantes apparaรฎtront souvent plus prรจs des feuilles.

varImpPlot(fit_rf)

Sortie :

varImp(fit_rf)
## rf variable importance
## 
##              Importance
## sexmale         100.000
## age              28.014
## pclassMiddle     27.016
## fare             21.557
## pclassUpper      16.324
## sibsp            11.246
## parch             5.522
## embarkedC         4.908
## embarkedQ         1.420
## embarkedS         0.000		

Rรฉsumรฉ

Nous pouvons rรฉsumer comment entraรฎner et รฉvaluer une forรชt alรฉatoire avec le tableau ci-dessous :

Bibliothรจque Objectif Fonction Paramรจtres
alรฉatoireforรชt Crรฉer une forรชt alรฉatoire Forรชt alรฉatoire() formule, ntree=n, mtry=FALSE, maxnodes = NULL
caret Crรฉer une validation croisรฉe du dossier K trainControl() mรฉthode = ยซ cv ยป, nombre = n, recherche = ยซ grille ยป
caret Entraรฎner une forรชt alรฉatoire former() formule, df, mรฉthode = ยซ rf ยป, metric = ยซ Prรฉcision ยป, trControl = trainControl(), tuneGrid = NULL
caret Prรฉdire hors รฉchantillon prรฉvoir modรจle, newdata= df
caret Matrice de confusion et statistiques confusionMatrix() modรจle, y test
caret importance variable cvarImp() modรจle

Appendice

Liste des modรจles utilisรฉs dans le caret

names>(getModelInfo())

Sortie :

##   [1] "ada"                 "AdaBag"              "AdaBoost.M1"        ##   [4] "adaboost"            "amdai"               "ANFIS"              ##   [7] "avNNet"              "awnb"                "awtan"              ##  [10] "bag"                 "bagEarth"            "bagEarthGCV"        ##  [13] "bagFDA"              "bagFDAGCV"           "bam"                ##  [16] "bartMachine"         "bayesglm"            "binda"              ##  [19] "blackboost"          "blasso"              "blassoAveraged"     ##  [22] "bridge"              "brnn"                "BstLm"              ##  [25] "bstSm"               "bstTree"             "C5.0"               ##  [28] "C5.0Cost"            "C5.0Rules"           "C5.0Tree"           ##  [31] "cforest"             "chaid"               "CSimca"             ##  [34] "ctree"               "ctree2"              "cubist"             ##  [37] "dda"                 "deepboost"           "DENFIS"             ##  [40] "dnn"                 "dwdLinear"           "dwdPoly"            ##  [43] "dwdRadial"           "earth"               "elm"                ##  [46] "enet"                "evtree"              "extraTrees"         ##  [49] "fda"                 "FH.GBML"             "FIR.DM"             ##  [52] "foba"                "FRBCS.CHI"           "FRBCS.W"            ##  [55] "FS.HGD"              "gam"                 "gamboost"           ##  [58] "gamLoess"            "gamSpline"           "gaussprLinear"      ##  [61] "gaussprPoly"         "gaussprRadial"       "gbm_h3o"            ##  [64] "gbm"                 "gcvEarth"            "GFS.FR.MOGUL"       ##  [67] "GFS.GCCL"            "GFS.LT.RS"           "GFS.THRIFT"         ##  [70] "glm.nb"              "glm"                 "glmboost"           ##  [73] "glmnet_h3o"          "glmnet"              "glmStepAIC"         ##  [76] "gpls"                "hda"                 "hdda"               ##  [79] "hdrda"               "HYFIS"               "icr"                ##  [82] "J48"                 "JRip"                "kernelpls"          ##  [85] "kknn"                "knn"                 "krlsPoly"           ##  [88] "krlsRadial"          "lars"                "lars2"              ##  [91] "lasso"               "lda"                 "lda2"               ##  [94] "leapBackward"        "leapForward"         "leapSeq"            ##  [97] "Linda"               "lm"                  "lmStepAIC"          ## [100] "LMT"                 "loclda"              "logicBag"           ## [103] "LogitBoost"          "logreg"              "lssvmLinear"        ## [106] "lssvmPoly"           "lssvmRadial"         "lvq"                ## [109] "M5"                  "M5Rules"             "manb"               ## [112] "mda"                 "Mlda"                "mlp"                ## [115] "mlpKerasDecay"       "mlpKerasDecayCost"   "mlpKerasDropout"    ## [118] "mlpKerasDropoutCost" "mlpML"               "mlpSGD"             ## [121] "mlpWeightDecay"      "mlpWeightDecayML"    "monmlp"             ## [124] "msaenet"             "multinom"            "mxnet"              ## [127] "mxnetAdam"           "naive_bayes"         "nb"                 ## [130] "nbDiscrete"          "nbSearch"            "neuralnet"          ## [133] "nnet"                "nnls"                "nodeHarvest"        ## [136] "null"                "OneR"                "ordinalNet"         ## [139] "ORFlog"              "ORFpls"              "ORFridge"           ## [142] "ORFsvm"              "ownn"                "pam"                ## [145] "parRF"               "PART"                "partDSA"            ## [148] "pcaNNet"             "pcr"                 "pda"                ## [151] "pda2"                "penalized"           "PenalizedLDA"       ## [154] "plr"                 "pls"                 "plsRglm"            ## [157] "polr"                "ppr"                 "PRIM"               ## [160] "protoclass"          "pythonKnnReg"        "qda"                ## [163] "QdaCov"              "qrf"                 "qrnn"               ## [166] "randomGLM"           "ranger"              "rbf"                ## [169] "rbfDDA"              "Rborist"             "rda"                ## [172] "regLogistic"         "relaxo"              "rf"                 ## [175] "rFerns"              "RFlda"               "rfRules"            ## [178] "ridge"               "rlda"                "rlm"                ## [181] "rmda"                "rocc"                "rotationForest"     ## [184] "rotationForestCp"    "rpart"               "rpart1SE"           ## [187] "rpart2"              "rpartCost"           "rpartScore"         ## [190] "rqlasso"             "rqnc"                "RRF"                ## [193] "RRFglobal"           "rrlda"               "RSimca"             ## [196] "rvmLinear"           "rvmPoly"             "rvmRadial"          ## [199] "SBC"                 "sda"                 "sdwd"               ## [202] "simpls"              "SLAVE"               "slda"               ## [205] "smda"                "snn"                 "sparseLDA"          ## [208] "spikeslab"           "spls"                "stepLDA"            ## [211] "stepQDA"             "superpc"             "svmBoundrangeString"## [214] "svmExpoString"       "svmLinear"           "svmLinear2"         ## [217] "svmLinear3"          "svmLinearWeights"    "svmLinearWeights2"  ## [220] "svmPoly"             "svmRadial"           "svmRadialCost"      ## [223] "svmRadialSigma"      "svmRadialWeights"    "svmSpectrumString"  ## [226] "tan"                 "tanSearch"           "treebag"            ## [229] "vbmpRadial"          "vglmAdjCat"          "vglmContRatio"      ## [232] "vglmCumulative"      "widekernelpls"       "WM"                 ## [235] "wsrf"                "xgbLinear"           "xgbTree"            ## [238] "xyf"

Rรฉsumez cet article avec :