R Random Forest Tutorial mit Beispiel
Was ist Random Forest in R?
Random Forests basieren auf einer einfachen Idee: „der Weisheit der Menge“. Die Aggregation der Ergebnisse mehrerer Prädiktoren ergibt eine bessere Vorhersage als der beste einzelne Prädiktor. Eine Gruppe von Prädiktoren wird als bezeichnet zusammen zu geniessen. Daher wird diese Technik genannt Ensemble-Lernen.
In einem früheren Tutorial haben Sie die Verwendung gelernt Entscheidungsbäume um eine binäre Vorhersage zu treffen. Um unsere Technik zu verbessern, können wir eine Gruppe trainieren Entscheidungsbaum-Klassifikatoren, jeweils auf einer anderen zufälligen Teilmenge des Zugsatzes. Um eine Vorhersage zu treffen, erhalten wir einfach die Vorhersagen aller einzelnen Bäume und sagen dann die Klasse voraus, die die meisten Stimmen erhält. Diese Technik heißt Zufälliger Wald.
Schritt 1) Importieren Sie die Daten
Um sicherzustellen, dass Sie über denselben Datensatz wie im Tutorial verfügen Entscheidungsbäume, der Zugtest und das Testset werden im Internet gespeichert. Sie können sie importieren, ohne Änderungen vorzunehmen.
library(dplyr) data_train <- read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/train.csv") glimpse(data_train) data_test <- read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/test.csv") glimpse(data_test)
Schritt 2) Trainieren Sie das Modell
Eine Möglichkeit, die Leistung eines Modells zu bewerten, besteht darin, es anhand einer Reihe verschiedener kleinerer Datensätze zu trainieren und diese im Vergleich zu anderen kleineren Testsätzen auszuwerten. Dies nennt man F-fache Kreuzvalidierung -Funktion R verfügt über eine Funktion zum zufälligen Aufteilen einer Anzahl von Datensätzen nahezu gleicher Größe. Wenn beispielsweise k=9, wird das Modell über die neun Ordner ausgewertet und mit dem verbleibenden Testsatz getestet. Dieser Vorgang wird wiederholt, bis alle Teilmengen ausgewertet wurden. Diese Technik wird häufig zur Modellauswahl verwendet, insbesondere wenn das Modell Parameter zum Anpassen aufweist.
Nachdem wir nun die Möglichkeit haben, unser Modell auszuwerten, müssen wir herausfinden, wie wir die Parameter auswählen, die die Daten am besten verallgemeinern.
Random Forest wählt eine zufällige Teilmenge von Features aus und erstellt viele Entscheidungsbäume. Das Modell mittelt alle Vorhersagen der Entscheidungsbäume.
Random Forest verfügt über einige Parameter, die geändert werden können, um die Verallgemeinerung der Vorhersage zu verbessern. Sie werden die Funktion RandomForest() verwenden, um das Modell zu trainieren.
Die Syntax für Randon Forest lautet
RandomForest(formula, ntree=n, mtry=FALSE, maxnodes = NULL) Arguments: - Formula: Formula of the fitted model - ntree: number of trees in the forest - mtry: Number of candidates draw to feed the algorithm. By default, it is the square of the number of columns. - maxnodes: Set the maximum amount of terminal nodes in the forest - importance=TRUE: Whether independent variables importance in the random forest be assessed
Note: Random Forest kann auf mehr Parameter trainiert werden. Sie können sich auf die beziehen Vignette um die verschiedenen Parameter zu sehen.
Das Tuning eines Modells ist eine sehr mühsame Arbeit. Es sind viele Kombinationen zwischen den Parametern möglich. Sie haben nicht unbedingt die Zeit, sie alle auszuprobieren. Eine gute Alternative besteht darin, die Maschine die beste Kombination für Sie finden zu lassen. Es stehen zwei Methoden zur Verfügung:
- Zufällige Suche
- Rastersuche
Wir werden beide Methoden definieren, aber während des Tutorials werden wir das Modell mithilfe der Rastersuche trainieren
Definition der Rastersuche
Die Rastersuchmethode ist einfach: Das Modell wird mithilfe einer Kreuzvalidierung über alle Kombinationen ausgewertet, die Sie in der Funktion übergeben.
Sie möchten das Modell beispielsweise mit einer Anzahl von 10, 20, 30 Bäumen ausprobieren und jeder Baum wird über eine Anzahl von Metern getestet, die 1, 2, 3, 4, 5 entspricht. Dann testet die Maschine 15 verschiedene Modelle:
.mtry ntrees 1 1 10 2 2 10 3 3 10 4 4 10 5 5 10 6 1 20 7 2 20 8 3 20 9 4 20 10 5 20 11 1 30 12 2 30 13 3 30 14 4 30 15 5 30
Der Algorithmus wertet Folgendes aus:
RandomForest(formula, ntree=10, mtry=1) RandomForest(formula, ntree=10, mtry=2) RandomForest(formula, ntree=10, mtry=3) RandomForest(formula, ntree=20, mtry=2) ...
Jedes Mal führt die Zufallsstruktur eine Kreuzvalidierung durch. Ein Manko der Rastersuche ist die Anzahl der Experimente. Es kann sehr leicht explosiv werden, wenn die Anzahl der Kombinationen hoch ist. Um dieses Problem zu lösen, können Sie die Zufallssuche verwenden
Definition der Zufallssuche
Der große Unterschied zwischen der Zufallssuche und der Rastersuche besteht darin, dass die Zufallssuche nicht alle Kombinationen von Hyperparametern im Suchraum auswertet. Stattdessen wird bei jeder Iteration eine zufällige Kombination ausgewählt. Der Vorteil liegt darin, dass der Rechenaufwand geringer ist.
Stellen Sie den Steuerparameter ein
Um das Modell aufzubauen und auszuwerten, gehen Sie wie folgt vor:
- Bewerten Sie das Modell mit der Standardeinstellung
- Finden Sie die beste Anzahl an mtry
- Finden Sie die beste Anzahl von Maxnodes
- Finden Sie die beste Anzahl an Bäumen
- Bewerten Sie das Modell anhand des Testdatensatzes
Bevor Sie mit der Erkundung der Parameter beginnen, müssen Sie zwei Bibliotheken installieren.
- Caret: R-Bibliothek für maschinelles Lernen. Wenn Sie haben R . installieren mit r-essentiell. Es befindet sich bereits in der Bibliothek
- Anaconda: conda install -cr r-caret
- e1071: R-Bibliothek für maschinelles Lernen.
- Anaconda: conda install -cr r-e1071
Sie können sie zusammen mit RandomForest importieren
library(randomForest) library(caret) library(e1071)
Voreinstellung
Die K-fache Kreuzvalidierung wird durch die Funktion trainControl() gesteuert
trainControl(method = "cv", number = n, search ="grid") arguments - method = "cv": The method used to resample the dataset. - number = n: Number of folders to create - search = "grid": Use the search grid method. For randomized method, use "grid" Note: You can refer to the vignette to see the other arguments of the function.
Sie können versuchen, das Modell mit den Standardparametern auszuführen und die Genauigkeitsbewertung anzuzeigen.
Note: Sie werden während des gesamten Tutorials dieselben Steuerelemente verwenden.
# Define the control trControl <- trainControl(method = "cv", number = 10, search = "grid")
Sie verwenden die Caret-Bibliothek, um Ihr Modell zu bewerten. Die Bibliothek verfügt über eine Funktion namens train(), um fast alle auszuwerten Maschinelles Lernen Algorithmus. Anders gesagt, Sie können diese Funktion verwenden, um andere Algorithmen zu trainieren.
Die grundlegende Syntax lautet:
train(formula, df, method = "rf", metric= "Accuracy", trControl = trainControl(), tuneGrid = NULL) argument - `formula`: Define the formula of the algorithm - `method`: Define which model to train. Note, at the end of the tutorial, there is a list of all the models that can be trained - `metric` = "Accuracy": Define how to select the optimal model - `trControl = trainControl()`: Define the control parameters - `tuneGrid = NULL`: Return a data frame with all the possible combination
Versuchen wir, das Modell mit den Standardwerten zu erstellen.
set.seed(1234) # Run the model rf_default <- train(survived~., data = data_train, method = "rf", metric = "Accuracy", trControl = trControl) # Print the results print(rf_default)
Code Erklärung
- trainControl(method=“cv“, number=10, search=“grid“): Bewerten Sie das Modell mit einer Rastersuche in 10 Ordnern
- train(…): Trainieren Sie ein Random-Forest-Modell. Das beste Modell wird mithilfe des Genauigkeitsmaßes ausgewählt.
Ausgang:
## Random Forest ## ## 836 samples ## 7 predictor ## 2 classes: 'No', 'Yes' ## ## No pre-processing ## Resampling: Cross-Validated (10 fold) ## Summary of sample sizes: 753, 752, 753, 752, 752, 752, ... ## Resampling results across tuning parameters: ## ## mtry Accuracy Kappa ## 2 0.7919248 0.5536486 ## 6 0.7811245 0.5391611 ## 10 0.7572002 0.4939620 ## ## Accuracy was used to select the optimal model using the largest value. ## The final value used for the model was mtry = 2.
Der Algorithmus verwendet 500 Bäume und testete drei verschiedene mtry-Werte: 2, 6, 10.
Der für das Modell verwendete Endwert war mtry = 2 mit einer Genauigkeit von 0.78. Versuchen wir, eine höhere Punktzahl zu erreichen.
Schritt 2) Suchen Sie nach dem besten Mtry
Sie können das Modell mit mtry-Werten von 1 bis 10 testen
set.seed(1234) tuneGrid <- expand.grid(.mtry = c(1: 10)) rf_mtry <- train(survived~., data = data_train, method = "rf", metric = "Accuracy", tuneGrid = tuneGrid, trControl = trControl, importance = TRUE, nodesize = 14, ntree = 300) print(rf_mtry)
Code Erklärung
- tuneGrid <- expand.grid(.mtry=c(3:10)): Konstruieren Sie einen Vektor mit einem Wert von 3:10
Der für das Modell verwendete Endwert war mtry = 4.
Ausgang:
## Random Forest ## ## 836 samples ## 7 predictor ## 2 classes: 'No', 'Yes' ## ## No pre-processing ## Resampling: Cross-Validated (10 fold) ## Summary of sample sizes: 753, 752, 753, 752, 752, 752, ... ## Resampling results across tuning parameters: ## ## mtry Accuracy Kappa ## 1 0.7572576 0.4647368 ## 2 0.7979346 0.5662364 ## 3 0.8075158 0.5884815 ## 4 0.8110729 0.5970664 ## 5 0.8074727 0.5900030 ## 6 0.8099111 0.5949342 ## 7 0.8050918 0.5866415 ## 8 0.8050918 0.5855399 ## 9 0.8050631 0.5855035 ## 10 0.7978916 0.5707336 ## ## Accuracy was used to select the optimal model using the largest value. ## The final value used for the model was mtry = 4.
Der beste Wert von mtry wird gespeichert in:
rf_mtry$bestTune$mtry
Sie können es speichern und verwenden, wenn Sie die anderen Parameter optimieren müssen.
max(rf_mtry$results$Accuracy)
Ausgang:
## [1] 0.8110729
best_mtry <- rf_mtry$bestTune$mtry best_mtry
Ausgang:
## [1] 4
Schritt 3) Suchen Sie nach den besten Maxnodes
Sie müssen eine Schleife erstellen, um die verschiedenen Werte von maxnodes auszuwerten. Im folgenden Code werden Sie Folgendes tun:
- Erstelle eine Liste
- Erstellen Sie eine Variable mit dem besten Wert des Parameters mtry; Obligatorisch
- Erstellen Sie die Schleife
- Speichern Sie den aktuellen Wert von maxnode
- Fassen Sie die Ergebnisse zusammen
store_maxnode <- list() tuneGrid <- expand.grid(.mtry = best_mtry) for (maxnodes in c(5: 15)) { set.seed(1234) rf_maxnode <- train(survived~., data = data_train, method = "rf", metric = "Accuracy", tuneGrid = tuneGrid, trControl = trControl, importance = TRUE, nodesize = 14, maxnodes = maxnodes, ntree = 300) current_iteration <- toString(maxnodes) store_maxnode[[current_iteration]] <- rf_maxnode } results_mtry <- resamples(store_maxnode) summary(results_mtry)
Code-Erklärung:
- store_maxnode <- list(): Die Ergebnisse des Modells werden in dieser Liste gespeichert
- expand.grid(.mtry=best_mtry): Verwenden Sie den besten Wert von mtry
- for (maxnodes in c(15:25)) { … }: Berechnen Sie das Modell mit Werten von maxnodes beginnend bei 15 bis 25.
- maxnodes=maxnodes: Für jede Iteration ist maxnodes gleich dem aktuellen Wert von maxnodes. also 15, 16, 17, …
- key <- toString(maxnodes): Speichern Sie den Wert von maxnode als String-Variable.
- store_maxnode[[key]] <- rf_maxnode: Speichern Sie das Ergebnis des Modells in der Liste.
- resamples(store_maxnode): Ordnen Sie die Ergebnisse des Modells an
- summary(results_mtry): Drucken Sie die Zusammenfassung aller Kombinationen.
Ausgang:
## ## Call: ## summary.resamples(object = results_mtry) ## ## Models: 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 ## Number of resamples: 10 ## ## Accuracy ## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's ## 5 0.6785714 0.7529762 0.7903758 0.7799771 0.8168388 0.8433735 0 ## 6 0.6904762 0.7648810 0.7784710 0.7811962 0.8125000 0.8313253 0 ## 7 0.6904762 0.7619048 0.7738095 0.7788009 0.8102410 0.8333333 0 ## 8 0.6904762 0.7627295 0.7844234 0.7847820 0.8184524 0.8433735 0 ## 9 0.7261905 0.7747418 0.8083764 0.7955250 0.8258749 0.8333333 0 ## 10 0.6904762 0.7837780 0.7904475 0.7895869 0.8214286 0.8433735 0 ## 11 0.7023810 0.7791523 0.8024240 0.7943775 0.8184524 0.8433735 0 ## 12 0.7380952 0.7910929 0.8144005 0.8051205 0.8288511 0.8452381 0 ## 13 0.7142857 0.8005952 0.8192771 0.8075158 0.8403614 0.8452381 0 ## 14 0.7380952 0.7941050 0.8203528 0.8098967 0.8403614 0.8452381 0 ## 15 0.7142857 0.8000215 0.8203528 0.8075301 0.8378873 0.8554217 0 ## ## Kappa ## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's ## 5 0.3297872 0.4640436 0.5459706 0.5270773 0.6068751 0.6717371 0 ## 6 0.3576471 0.4981484 0.5248805 0.5366310 0.6031287 0.6480921 0 ## 7 0.3576471 0.4927448 0.5192771 0.5297159 0.5996437 0.6508314 0 ## 8 0.3576471 0.4848320 0.5408159 0.5427127 0.6200253 0.6717371 0 ## 9 0.4236277 0.5074421 0.5859472 0.5601687 0.6228626 0.6480921 0 ## 10 0.3576471 0.5255698 0.5527057 0.5497490 0.6204819 0.6717371 0 ## 11 0.3794326 0.5235007 0.5783191 0.5600467 0.6126720 0.6717371 0 ## 12 0.4460432 0.5480930 0.5999072 0.5808134 0.6296780 0.6717371 0 ## 13 0.4014252 0.5725752 0.6087279 0.5875305 0.6576219 0.6678832 0 ## 14 0.4460432 0.5585005 0.6117973 0.5911995 0.6590982 0.6717371 0 ## 15 0.4014252 0.5689401 0.6117973 0.5867010 0.6507194 0.6955990 0
Der letzte Wert von maxnode hat die höchste Genauigkeit. Sie können es mit höheren Werten versuchen, um zu sehen, ob Sie eine höhere Punktzahl erzielen können.
store_maxnode <- list() tuneGrid <- expand.grid(.mtry = best_mtry) for (maxnodes in c(20: 30)) { set.seed(1234) rf_maxnode <- train(survived~., data = data_train, method = "rf", metric = "Accuracy", tuneGrid = tuneGrid, trControl = trControl, importance = TRUE, nodesize = 14, maxnodes = maxnodes, ntree = 300) key <- toString(maxnodes) store_maxnode[[key]] <- rf_maxnode } results_node <- resamples(store_maxnode) summary(results_node)
Ausgang:
## ## Call: ## summary.resamples(object = results_node) ## ## Models: 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30 ## Number of resamples: 10 ## ## Accuracy ## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's ## 20 0.7142857 0.7821644 0.8144005 0.8075301 0.8447719 0.8571429 0 ## 21 0.7142857 0.8000215 0.8144005 0.8075014 0.8403614 0.8571429 0 ## 22 0.7023810 0.7941050 0.8263769 0.8099254 0.8328313 0.8690476 0 ## 23 0.7023810 0.7941050 0.8263769 0.8111302 0.8447719 0.8571429 0 ## 24 0.7142857 0.7946429 0.8313253 0.8135112 0.8417599 0.8690476 0 ## 25 0.7142857 0.7916667 0.8313253 0.8099398 0.8408635 0.8690476 0 ## 26 0.7142857 0.7941050 0.8203528 0.8123207 0.8528758 0.8571429 0 ## 27 0.7023810 0.8060456 0.8313253 0.8135112 0.8333333 0.8690476 0 ## 28 0.7261905 0.7941050 0.8203528 0.8111015 0.8328313 0.8690476 0 ## 29 0.7142857 0.7910929 0.8313253 0.8087063 0.8333333 0.8571429 0 ## 30 0.6785714 0.7910929 0.8263769 0.8063253 0.8403614 0.8690476 0 ## ## Kappa ## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's ## 20 0.3956835 0.5316120 0.5961830 0.5854366 0.6661120 0.6955990 0 ## 21 0.3956835 0.5699332 0.5960343 0.5853247 0.6590982 0.6919315 0 ## 22 0.3735084 0.5560661 0.6221836 0.5914492 0.6422128 0.7189781 0 ## 23 0.3735084 0.5594228 0.6228827 0.5939786 0.6657372 0.6955990 0 ## 24 0.3956835 0.5600352 0.6337821 0.5992188 0.6604703 0.7189781 0 ## 25 0.3956835 0.5530760 0.6354875 0.5912239 0.6554912 0.7189781 0 ## 26 0.3956835 0.5589331 0.6136074 0.5969142 0.6822128 0.6955990 0 ## 27 0.3735084 0.5852459 0.6368425 0.5998148 0.6426088 0.7189781 0 ## 28 0.4290780 0.5589331 0.6154905 0.5946859 0.6356141 0.7189781 0 ## 29 0.4070588 0.5534173 0.6337821 0.5901173 0.6423101 0.6919315 0 ## 30 0.3297872 0.5534173 0.6202632 0.5843432 0.6590982 0.7189781 0
Die höchste Genauigkeitsbewertung wird mit einem Wert von maxnode von 22 erreicht.
Schritt 4) Suchen Sie nach den besten Bäumen
Da Sie nun den besten Wert für mtry und maxnode haben, können Sie die Anzahl der Bäume anpassen. Die Methode ist genau die gleiche wie bei maxnode.
store_maxtrees <- list() for (ntree in c(250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000)) { set.seed(5678) rf_maxtrees <- train(survived~., data = data_train, method = "rf", metric = "Accuracy", tuneGrid = tuneGrid, trControl = trControl, importance = TRUE, nodesize = 14, maxnodes = 24, ntree = ntree) key <- toString(ntree) store_maxtrees[[key]] <- rf_maxtrees } results_tree <- resamples(store_maxtrees) summary(results_tree)
Ausgang:
## ## Call: ## summary.resamples(object = results_tree) ## ## Models: 250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000 ## Number of resamples: 10 ## ## Accuracy ## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's ## 250 0.7380952 0.7976190 0.8083764 0.8087010 0.8292683 0.8674699 0 ## 300 0.7500000 0.7886905 0.8024240 0.8027199 0.8203397 0.8452381 0 ## 350 0.7500000 0.7886905 0.8024240 0.8027056 0.8277623 0.8452381 0 ## 400 0.7500000 0.7886905 0.8083764 0.8051009 0.8292683 0.8452381 0 ## 450 0.7500000 0.7886905 0.8024240 0.8039104 0.8292683 0.8452381 0 ## 500 0.7619048 0.7886905 0.8024240 0.8062914 0.8292683 0.8571429 0 ## 550 0.7619048 0.7886905 0.8083764 0.8099062 0.8323171 0.8571429 0 ## 600 0.7619048 0.7886905 0.8083764 0.8099205 0.8323171 0.8674699 0 ## 800 0.7619048 0.7976190 0.8083764 0.8110820 0.8292683 0.8674699 0 ## 1000 0.7619048 0.7976190 0.8121510 0.8086723 0.8303571 0.8452381 0 ## 2000 0.7619048 0.7886905 0.8121510 0.8086723 0.8333333 0.8452381 0 ## ## Kappa ## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's ## 250 0.4061697 0.5667400 0.5836013 0.5856103 0.6335363 0.7196807 0 ## 300 0.4302326 0.5449376 0.5780349 0.5723307 0.6130767 0.6710843 0 ## 350 0.4302326 0.5449376 0.5780349 0.5723185 0.6291592 0.6710843 0 ## 400 0.4302326 0.5482030 0.5836013 0.5774782 0.6335363 0.6710843 0 ## 450 0.4302326 0.5449376 0.5780349 0.5750587 0.6335363 0.6710843 0 ## 500 0.4601542 0.5449376 0.5780349 0.5804340 0.6335363 0.6949153 0 ## 550 0.4601542 0.5482030 0.5857118 0.5884507 0.6396872 0.6949153 0 ## 600 0.4601542 0.5482030 0.5857118 0.5884374 0.6396872 0.7196807 0 ## 800 0.4601542 0.5667400 0.5836013 0.5910088 0.6335363 0.7196807 0 ## 1000 0.4601542 0.5667400 0.5961590 0.5857446 0.6343666 0.6678832 0 ## 2000 0.4601542 0.5482030 0.5961590 0.5862151 0.6440678 0.6656337 0
Sie haben Ihr endgültiges Modell. Sie können den Random Forest mit den folgenden Parametern trainieren:
- ntree =800: 800 Bäume werden trainiert
- mtry=4: Für jede Iteration werden 4 Features ausgewählt
- maxnodes = 24: Maximal 24 Knoten in den Endknoten (Blättern)
fit_rf <- train(survived~., data_train, method = "rf", metric = "Accuracy", tuneGrid = tuneGrid, trControl = trControl, importance = TRUE, nodesize = 14, ntree = 800, maxnodes = 24)
Schritt 5) Bewerten Sie das Modell
Das Bibliotheks-Caret hat die Funktion, Vorhersagen zu treffen.
predict(model, newdata= df) argument - `model`: Define the model evaluated before. - `newdata`: Define the dataset to make prediction
prediction <-predict(fit_rf, data_test)
Sie können die Vorhersage verwenden, um die Verwirrungsmatrix zu berechnen und den Genauigkeitswert anzuzeigen
confusionMatrix(prediction, data_test$survived)
Ausgang:
## Confusion Matrix and Statistics ## ## Reference ## Prediction No Yes ## No 110 32 ## Yes 11 56 ## ## Accuracy : 0.7943 ## 95% CI : (0.733, 0.8469) ## No Information Rate : 0.5789 ## P-Value [Acc > NIR] : 3.959e-11 ## ## Kappa : 0.5638 ## Mcnemar's Test P-Value : 0.002289 ## ## Sensitivity : 0.9091 ## Specificity : 0.6364 ## Pos Pred Value : 0.7746 ## Neg Pred Value : 0.8358 ## Prevalence : 0.5789 ## Detection Rate : 0.5263 ## Detection Prevalence : 0.6794 ## Balanced Accuracy : 0.7727 ## ## 'Positive' Class : No ##
Sie haben eine Genauigkeit von 0.7943 Prozent, was über dem Standardwert liegt
Schritt 6) Visualisieren Sie das Ergebnis
Abschließend können Sie die Feature-Wichtigkeit mit der Funktion varImp() betrachten. Es scheint, dass die wichtigsten Merkmale das Geschlecht und das Alter sind. Das ist nicht verwunderlich, da die wichtigen Merkmale wahrscheinlich näher an der Wurzel des Baumes erscheinen, während weniger wichtige Merkmale oft in der Nähe der Blätter erscheinen.
varImpPlot(fit_rf)
Ausgang:
varImp(fit_rf) ## rf variable importance ## ## Importance ## sexmale 100.000 ## age 28.014 ## pclassMiddle 27.016 ## fare 21.557 ## pclassUpper 16.324 ## sibsp 11.246 ## parch 5.522 ## embarkedC 4.908 ## embarkedQ 1.420 ## embarkedS 0.000
Zusammenfassung
Mit der folgenden Tabelle können wir zusammenfassen, wie man einen Random Forest trainiert und auswertet:
Bibliothek | Ziel | Funktion | Parameter |
---|---|---|---|
zufälligerWald | Erstellen Sie eine zufällige Gesamtstruktur | RandomForest() | Formel, ntree=n, mtry=FALSE, maxnodes = NULL |
Caret | Erstellen Sie eine Kreuzvalidierung für den K-Ordner | trainControl() | Methode = „Lebenslauf“, Zahl = n, Suche = „Gitter“ |
Caret | Trainiere einen zufälligen Wald | Zug() | Formel, df, Methode = „rf“, Metrik = „Accuracy“, trControl = trainControl(), tuneGrid = NULL |
Caret | Vorhersagen aus der Stichprobe | vorhersagen | Modell, newdata= df |
Caret | Verwirrungsmatrix und Statistik | Verwirrung Matrix() | Modell, y-Test |
Caret | variable Bedeutung | cvarImp() | Modell |
Anhang
Liste der im Caret verwendeten Modelle
names>(getModelInfo())
Ausgang:
## [1] "ada" "AdaBag" "AdaBoost.M1" ## [4] "adaboost" "amdai" "ANFIS" ## [7] "avNNet" "awnb" "awtan" ## [10] "bag" "bagEarth" "bagEarthGCV" ## [13] "bagFDA" "bagFDAGCV" "bam" ## [16] "bartMachine" "bayesglm" "binda" ## [19] "blackboost" "blasso" "blassoAveraged" ## [22] "bridge" "brnn" "BstLm" ## [25] "bstSm" "bstTree" "C5.0" ## [28] "C5.0Cost" "C5.0Rules" "C5.0Tree" ## [31] "cforest" "chaid" "CSimca" ## [34] "ctree" "ctree2" "cubist" ## [37] "dda" "deepboost" "DENFIS" ## [40] "dnn" "dwdLinear" "dwdPoly" ## [43] "dwdRadial" "earth" "elm" ## [46] "enet" "evtree" "extraTrees" ## [49] "fda" "FH.GBML" "FIR.DM" ## [52] "foba" "FRBCS.CHI" "FRBCS.W" ## [55] "FS.HGD" "gam" "gamboost" ## [58] "gamLoess" "gamSpline" "gaussprLinear" ## [61] "gaussprPoly" "gaussprRadial" "gbm_h3o" ## [64] "gbm" "gcvEarth" "GFS.FR.MOGUL" ## [67] "GFS.GCCL" "GFS.LT.RS" "GFS.THRIFT" ## [70] "glm.nb" "glm" "glmboost" ## [73] "glmnet_h3o" "glmnet" "glmStepAIC" ## [76] "gpls" "hda" "hdda" ## [79] "hdrda" "HYFIS" "icr" ## [82] "J48" "JRip" "kernelpls" ## [85] "kknn" "knn" "krlsPoly" ## [88] "krlsRadial" "lars" "lars2" ## [91] "lasso" "lda" "lda2" ## [94] "leapBackward" "leapForward" "leapSeq" ## [97] "Linda" "lm" "lmStepAIC" ## [100] "LMT" "loclda" "logicBag" ## [103] "LogitBoost" "logreg" "lssvmLinear" ## [106] "lssvmPoly" "lssvmRadial" "lvq" ## [109] "M5" "M5Rules" "manb" ## [112] "mda" "Mlda" "mlp" ## [115] "mlpKerasDecay" "mlpKerasDecayCost" "mlpKerasDropout" ## [118] "mlpKerasDropoutCost" "mlpML" "mlpSGD" ## [121] "mlpWeightDecay" "mlpWeightDecayML" "monmlp" ## [124] "msaenet" "multinom" "mxnet" ## [127] "mxnetAdam" "naive_bayes" "nb" ## [130] "nbDiscrete" "nbSearch" "neuralnet" ## [133] "nnet" "nnls" "nodeHarvest" ## [136] "null" "OneR" "ordinalNet" ## [139] "ORFlog" "ORFpls" "ORFridge" ## [142] "ORFsvm" "ownn" "pam" ## [145] "parRF" "PART" "partDSA" ## [148] "pcaNNet" "pcr" "pda" ## [151] "pda2" "penalized" "PenalizedLDA" ## [154] "plr" "pls" "plsRglm" ## [157] "polr" "ppr" "PRIM" ## [160] "protoclass" "pythonKnnReg" "qda" ## [163] "QdaCov" "qrf" "qrnn" ## [166] "randomGLM" "ranger" "rbf" ## [169] "rbfDDA" "Rborist" "rda" ## [172] "regLogistic" "relaxo" "rf" ## [175] "rFerns" "RFlda" "rfRules" ## [178] "ridge" "rlda" "rlm" ## [181] "rmda" "rocc" "rotationForest" ## [184] "rotationForestCp" "rpart" "rpart1SE" ## [187] "rpart2" "rpartCost" "rpartScore" ## [190] "rqlasso" "rqnc" "RRF" ## [193] "RRFglobal" "rrlda" "RSimca" ## [196] "rvmLinear" "rvmPoly" "rvmRadial" ## [199] "SBC" "sda" "sdwd" ## [202] "simpls" "SLAVE" "slda" ## [205] "smda" "snn" "sparseLDA" ## [208] "spikeslab" "spls" "stepLDA" ## [211] "stepQDA" "superpc" "svmBoundrangeString"## [214] "svmExpoString" "svmLinear" "svmLinear2" ## [217] "svmLinear3" "svmLinearWeights" "svmLinearWeights2" ## [220] "svmPoly" "svmRadial" "svmRadialCost" ## [223] "svmRadialSigma" "svmRadialWeights" "svmSpectrumString" ## [226] "tan" "tanSearch" "treebag" ## [229] "vbmpRadial" "vglmAdjCat" "vglmContRatio" ## [232] "vglmCumulative" "widekernelpls" "WM" ## [235] "wsrf" "xgbLinear" "xgbTree" ## [238] "xyf"