R Random Forest Tutorial esimerkin kanssa

Mikä on Random Forest R:ssä?

Satunnaiset metsät perustuvat yksinkertaiseen ajatukseen: "joukon viisaudelle". Useiden ennustajien tulosten aggregaatti antaa paremman ennusteen kuin paras yksittäinen ennustaja. Ennustajien ryhmää kutsutaan an kokonaisuus. Siksi tätä tekniikkaa kutsutaan Yhtye-oppiminen.

Aiemmassa opetusohjelmassa opit käyttämään Päättävät puut tehdä binäärinen ennustus. Tekniikkamme parantamiseksi voimme kouluttaa ryhmän Päätöspuun luokittimet, jokainen junajoukon eri satunnaisessa osajoukossa. Ennusteen tekemiseksi hankimme vain kaikkien yksilöpuiden ennusteet ja ennustamme sitten eniten ääniä saaneen luokan. Tätä tekniikkaa kutsutaan Satunnainen metsä.

Vaihe 1) Tuo tiedot

Varmistaaksesi, että sinulla on sama tietojoukko kuin opetusohjelmassa päätöksentekopuut, junatesti ja testisarja on tallennettu Internetiin. Voit tuoda ne ilman muutoksia.

library(dplyr)
data_train <- read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/train.csv")
glimpse(data_train)
data_test <- read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/test.csv") 
glimpse(data_test)

Vaihe 2) Kouluta malli

Yksi tapa arvioida mallin suorituskykyä on harjoitella sitä useilla erilaisilla pienemmillä tietojoukoilla ja arvioida ne toiseen pienempään testijoukkoon verrattuna. Tätä kutsutaan F-kertainen ristiinvalidointi ominaisuus. R on toiminto, joka jakaa satunnaisesti useita lähes samankokoisia tietojoukkoja. Jos esimerkiksi k=9, malli arvioidaan yhdeksän kansion yli ja testataan jäljellä olevalla testijoukolla. Tätä prosessia toistetaan, kunnes kaikki osajoukot on arvioitu. Tätä tekniikkaa käytetään laajalti mallien valinnassa, varsinkin kun mallilla on viritettävät parametrit.

Nyt kun meillä on tapa arvioida malliamme, meidän on selvitettävä, kuinka valita parametrit, jotka yleistivät parhaiten dataa.

Satunnainen metsä valitsee satunnaisen osajoukon ominaisuuksia ja rakentaa useita päätöspuita. Malli laskee keskiarvon kaikista päätöspuiden ennusteista.

Satunnaisessa metsässä on joitain parametreja, joita voidaan muuttaa ennusteen yleistämisen parantamiseksi. Käytät funktiota RandomForest() mallin kouluttamiseen.

Randon Forestin syntaksi on

RandomForest(formula, ntree=n, mtry=FALSE, maxnodes = NULL)
Arguments:
- Formula: Formula of the fitted model
- ntree: number of trees in the forest
- mtry: Number of candidates draw to feed the algorithm. By default, it is the square of the number of columns.
- maxnodes: Set the maximum amount of terminal nodes in the forest
- importance=TRUE: Whether independent variables importance in the random forest be assessed

Huomautuksia: Satunnaista metsää voidaan kouluttaa useammille parametreille. Voit viitata vinjetti nähdäksesi eri parametrit.

Mallin viritys on erittäin työlästä työtä. Parametrien välillä on monia yhdistelmiä. Sinulla ei välttämättä ole aikaa kokeilla kaikkia. Hyvä vaihtoehto on antaa koneen löytää sinulle paras yhdistelmä. Käytettävissä on kaksi tapaa:

  • Satunnainen haku
  • Ruudukkohaku

Määrittelemme molemmat menetelmät, mutta harjoituksen aikana harjoittelemme mallia ruudukkohaun avulla

Ruudukkohaun määritelmä

Ruudukkohakumenetelmä on yksinkertainen, mallia arvioidaan ristiinvalidoinnin avulla kaikkien funktiossa välitettyjen yhdistelmän osalta.

Haluat esimerkiksi kokeilla mallia, jossa on 10, 20, 30 puuta, ja jokainen puu testataan usealla mtry:llä, joka on 1, 2, 3, 4, 5. Sitten kone testaa 15 eri mallia:

    .mtry ntrees
 1      1     10
 2      2     10
 3      3     10
 4      4     10
 5      5     10
 6      1     20
 7      2     20
 8      3     20
 9      4     20
 10     5     20
 11     1     30
 12     2     30
 13     3     30
 14     4     30
 15     5     30	

Algoritmi arvioi:

RandomForest(formula, ntree=10, mtry=1)
RandomForest(formula, ntree=10, mtry=2)
RandomForest(formula, ntree=10, mtry=3)
RandomForest(formula, ntree=20, mtry=2)
...

Joka kerta satunnainen metsä kokeilee ristiinvalidointia. Yksi ruudukkohaun puute on kokeilujen määrä. Siitä voi tulla erittäin helposti räjähdysherkkä, kun yhdistelmien lukumäärä on suuri. Voit ratkaista tämän ongelman käyttämällä satunnaista hakua

Satunnaishaun määritelmä

Suuri ero satunnaishaun ja ruudukkohaun välillä on, että satunnaishaku ei arvioi kaikkia hyperparametrien yhdistelmää hakutilassa. Sen sijaan se valitsee satunnaisesti yhdistelmän joka iteraatiossa. Sen etuna on alhaisemmat laskentakustannukset.

Aseta ohjausparametri

Rakennat ja arvioit mallin seuraavasti:

  • Arvioi malli oletusasetuksilla
  • Etsi paras määrä mtry
  • Etsi paras määrä maxnodeja
  • Etsi paras määrä puita
  • Arvioi malli testitietojoukosta

Ennen kuin aloitat parametrien tutkimisen, sinun on asennettava kaksi kirjastoa.

  • caret: R koneoppimiskirjasto. Jos sinulla on asenna R r-essentialin kanssa. Se on jo kirjastossa
  • e1071: R-koneoppimiskirjasto.

Voit tuoda ne yhdessä RandomForestin kanssa

library(randomForest)
library(caret)
library(e1071)

Oletusasetus

K-kertaista ristiinvalidointia ohjaa trainControl()-funktio

trainControl(method = "cv", number = n, search ="grid")
arguments
- method = "cv": The method used to resample the dataset. 
- number = n: Number of folders to create
- search = "grid": Use the search grid method. For randomized method, use "grid"
Note: You can refer to the vignette to see the other arguments of the function.

Voit yrittää ajaa mallia oletusparametreilla ja nähdä tarkkuuspisteet.

Huomautuksia: Käytät samoja säätimiä koko opetusohjelman ajan.

# Define the control
trControl <- trainControl(method = "cv",
    number = 10,
    search = "grid")

Arvioi mallisi Caret-kirjaston avulla. Kirjastossa on yksi funktio nimeltä train() arvioimaan lähes kaikki koneoppiminen algoritmi. Sano toisin, voit käyttää tätä toimintoa muiden algoritmien opettamiseen.

Perussyntaksi on:

train(formula, df, method = "rf", metric= "Accuracy", trControl = trainControl(), tuneGrid = NULL)
argument
- `formula`: Define the formula of the algorithm
- `method`: Define which model to train. Note, at the end of the tutorial, there is a list of all the models that can be trained
- `metric` = "Accuracy": Define how to select the optimal model
- `trControl = trainControl()`: Define the control parameters
- `tuneGrid = NULL`: Return a data frame with all the possible combination

Kokeillaan rakentaa malli oletusarvoilla.

set.seed(1234)
# Run the model
rf_default <- train(survived~.,
    data = data_train,
    method = "rf",
    metric = "Accuracy",
    trControl = trControl)
# Print the results
print(rf_default)

Koodin selitys

  • trainControl(method=”cv”, numero=10, search=”grid”): Arvioi malli 10 kansion ruudukkohaulla
  • juna(…): Harjoittele satunnainen metsämalli. Paras malli valitaan tarkkuusmittauksella.

lähtö:

## Random Forest 
## 
## 836 samples
##   7 predictor
##   2 classes: 'No', 'Yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 753, 752, 753, 752, 752, 752, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    2    0.7919248  0.5536486
##    6    0.7811245  0.5391611
##   10    0.7572002  0.4939620
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was mtry = 2.

Algoritmi käyttää 500 puuta ja testasi kolmea erilaista mtry-arvoa: 2, 6, 10.

Mallissa käytetty lopullinen arvo oli mtry = 2, jonka tarkkuus oli 0.78. Yritetään saada korkeampi pistemäärä.

Vaihe 2) Etsi paras mtry

Voit testata mallia mtry-arvoilla 1-10

set.seed(1234)
tuneGrid <- expand.grid(.mtry = c(1: 10))
rf_mtry <- train(survived~.,
    data = data_train,
    method = "rf",
    metric = "Accuracy",
    tuneGrid = tuneGrid,
    trControl = trControl,
    importance = TRUE,
    nodesize = 14,
    ntree = 300)
print(rf_mtry)

Koodin selitys

  • tuneGrid <- expand.grid(.mtry=c(3:10)): Rakenna vektori, jonka arvo on 3:10

Mallissa käytetty lopullinen arvo oli mtry = 4.

lähtö:

## Random Forest 
## 
## 836 samples
##   7 predictor
##   2 classes: 'No', 'Yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 753, 752, 753, 752, 752, 752, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    1    0.7572576  0.4647368
##    2    0.7979346  0.5662364
##    3    0.8075158  0.5884815
##    4    0.8110729  0.5970664
##    5    0.8074727  0.5900030
##    6    0.8099111  0.5949342
##    7    0.8050918  0.5866415
##    8    0.8050918  0.5855399
##    9    0.8050631  0.5855035
##   10    0.7978916  0.5707336
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was mtry = 4.

Mtry:n paras arvo on tallennettu:

rf_mtry$bestTune$mtry

Voit tallentaa sen ja käyttää sitä, kun haluat virittää muita parametreja.

max(rf_mtry$results$Accuracy)

lähtö:

## [1] 0.8110729
best_mtry <- rf_mtry$bestTune$mtry 
best_mtry

lähtö:

## [1] 4

Vaihe 3) Etsi parhaat maxsolmut

Sinun on luotava silmukka arvioidaksesi maxnode-arvoja. Seuraavassa koodissa saat:

  • Luo luettelo
  • Luo muuttuja, jolla on paras parametrin arvo mtry; Pakollinen
  • Luo silmukka
  • Tallenna maxnoden nykyinen arvo
  • Tee yhteenveto tuloksista
store_maxnode <- list()
tuneGrid <- expand.grid(.mtry = best_mtry)
for (maxnodes in c(5: 15)) {
    set.seed(1234)
    rf_maxnode <- train(survived~.,
        data = data_train,
        method = "rf",
        metric = "Accuracy",
        tuneGrid = tuneGrid,
        trControl = trControl,
        importance = TRUE,
        nodesize = 14,
        maxnodes = maxnodes,
        ntree = 300)
    current_iteration <- toString(maxnodes)
    store_maxnode[[current_iteration]] <- rf_maxnode
}
results_mtry <- resamples(store_maxnode)
summary(results_mtry)

Koodin selitys:

  • store_maxnode <- list(): Mallin tulokset tallennetaan tähän luetteloon
  • expand.grid(.mtry=best_mtry): Käytä mtry:n parasta arvoa
  • for (maxnodes in c(15:25)) { … }: Laske malli maksimisolmujen arvoilla alkaen 15-25.
  • maxnodes=maxnodes: Jokaisessa iteraatiossa maxnodes on yhtä suuri kuin maxnodesin nykyinen arvo. eli 15, 16, 17...
  • avain <- toString(maxnodes): Tallenna merkkijonomuuttujaksi maxnode-arvo.
  • store_maxnode[[avain]] <- rf_maxnode: Tallenna mallin tulos luetteloon.
  • resamples(store_maxnode): Järjestä mallin tulokset
  • summary(results_mtry): Tulosta yhteenveto kaikesta yhdistelmästä.

lähtö:

## 
## Call:
## summary.resamples(object = results_mtry)
## 
## Models: 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 
## Number of resamples: 10 
## 
## Accuracy 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 5  0.6785714 0.7529762 0.7903758 0.7799771 0.8168388 0.8433735    0
## 6  0.6904762 0.7648810 0.7784710 0.7811962 0.8125000 0.8313253    0
## 7  0.6904762 0.7619048 0.7738095 0.7788009 0.8102410 0.8333333    0
## 8  0.6904762 0.7627295 0.7844234 0.7847820 0.8184524 0.8433735    0
## 9  0.7261905 0.7747418 0.8083764 0.7955250 0.8258749 0.8333333    0
## 10 0.6904762 0.7837780 0.7904475 0.7895869 0.8214286 0.8433735    0
## 11 0.7023810 0.7791523 0.8024240 0.7943775 0.8184524 0.8433735    0
## 12 0.7380952 0.7910929 0.8144005 0.8051205 0.8288511 0.8452381    0
## 13 0.7142857 0.8005952 0.8192771 0.8075158 0.8403614 0.8452381    0
## 14 0.7380952 0.7941050 0.8203528 0.8098967 0.8403614 0.8452381    0
## 15 0.7142857 0.8000215 0.8203528 0.8075301 0.8378873 0.8554217    0
## 
## Kappa 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 5  0.3297872 0.4640436 0.5459706 0.5270773 0.6068751 0.6717371    0
## 6  0.3576471 0.4981484 0.5248805 0.5366310 0.6031287 0.6480921    0
## 7  0.3576471 0.4927448 0.5192771 0.5297159 0.5996437 0.6508314    0
## 8  0.3576471 0.4848320 0.5408159 0.5427127 0.6200253 0.6717371    0
## 9  0.4236277 0.5074421 0.5859472 0.5601687 0.6228626 0.6480921    0
## 10 0.3576471 0.5255698 0.5527057 0.5497490 0.6204819 0.6717371    0
## 11 0.3794326 0.5235007 0.5783191 0.5600467 0.6126720 0.6717371    0
## 12 0.4460432 0.5480930 0.5999072 0.5808134 0.6296780 0.6717371    0
## 13 0.4014252 0.5725752 0.6087279 0.5875305 0.6576219 0.6678832    0
## 14 0.4460432 0.5585005 0.6117973 0.5911995 0.6590982 0.6717371    0
## 15 0.4014252 0.5689401 0.6117973 0.5867010 0.6507194 0.6955990    0

Maxnoden viimeisellä arvolla on suurin tarkkuus. Voit kokeilla suurempia arvoja nähdäksesi, voitko saada korkeamman pistemäärän.

store_maxnode <- list()
tuneGrid <- expand.grid(.mtry = best_mtry)
for (maxnodes in c(20: 30)) {
    set.seed(1234)
    rf_maxnode <- train(survived~.,
        data = data_train,
        method = "rf",
        metric = "Accuracy",
        tuneGrid = tuneGrid,
        trControl = trControl,
        importance = TRUE,
        nodesize = 14,
        maxnodes = maxnodes,
        ntree = 300)
    key <- toString(maxnodes)
    store_maxnode[[key]] <- rf_maxnode
}
results_node <- resamples(store_maxnode)
summary(results_node)

lähtö:

## 
## Call:
## summary.resamples(object = results_node)
## 
## Models: 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30 
## Number of resamples: 10 
## 
## Accuracy 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 20 0.7142857 0.7821644 0.8144005 0.8075301 0.8447719 0.8571429    0
## 21 0.7142857 0.8000215 0.8144005 0.8075014 0.8403614 0.8571429    0
## 22 0.7023810 0.7941050 0.8263769 0.8099254 0.8328313 0.8690476    0
## 23 0.7023810 0.7941050 0.8263769 0.8111302 0.8447719 0.8571429    0
## 24 0.7142857 0.7946429 0.8313253 0.8135112 0.8417599 0.8690476    0
## 25 0.7142857 0.7916667 0.8313253 0.8099398 0.8408635 0.8690476    0
## 26 0.7142857 0.7941050 0.8203528 0.8123207 0.8528758 0.8571429    0
## 27 0.7023810 0.8060456 0.8313253 0.8135112 0.8333333 0.8690476    0
## 28 0.7261905 0.7941050 0.8203528 0.8111015 0.8328313 0.8690476    0
## 29 0.7142857 0.7910929 0.8313253 0.8087063 0.8333333 0.8571429    0
## 30 0.6785714 0.7910929 0.8263769 0.8063253 0.8403614 0.8690476    0
## 
## Kappa 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 20 0.3956835 0.5316120 0.5961830 0.5854366 0.6661120 0.6955990    0
## 21 0.3956835 0.5699332 0.5960343 0.5853247 0.6590982 0.6919315    0
## 22 0.3735084 0.5560661 0.6221836 0.5914492 0.6422128 0.7189781    0
## 23 0.3735084 0.5594228 0.6228827 0.5939786 0.6657372 0.6955990    0
## 24 0.3956835 0.5600352 0.6337821 0.5992188 0.6604703 0.7189781    0
## 25 0.3956835 0.5530760 0.6354875 0.5912239 0.6554912 0.7189781    0
## 26 0.3956835 0.5589331 0.6136074 0.5969142 0.6822128 0.6955990    0
## 27 0.3735084 0.5852459 0.6368425 0.5998148 0.6426088 0.7189781    0
## 28 0.4290780 0.5589331 0.6154905 0.5946859 0.6356141 0.7189781    0
## 29 0.4070588 0.5534173 0.6337821 0.5901173 0.6423101 0.6919315    0
## 30 0.3297872 0.5534173 0.6202632 0.5843432 0.6590982 0.7189781    0

Korkein tarkkuuspistemäärä saadaan, kun maxnode-arvo on 22.

Vaihe 4) Etsi parhaat puut

Nyt kun sinulla on paras arvo mtry:lle ja maxnodelle, voit säätää puiden määrää. Menetelmä on täsmälleen sama kuin maxnode.

store_maxtrees <- list()
for (ntree in c(250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000)) {
    set.seed(5678)
    rf_maxtrees <- train(survived~.,
        data = data_train,
        method = "rf",
        metric = "Accuracy",
        tuneGrid = tuneGrid,
        trControl = trControl,
        importance = TRUE,
        nodesize = 14,
        maxnodes = 24,
        ntree = ntree)
    key <- toString(ntree)
    store_maxtrees[[key]] <- rf_maxtrees
}
results_tree <- resamples(store_maxtrees)
summary(results_tree)

lähtö:

## 
## Call:
## summary.resamples(object = results_tree)
## 
## Models: 250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000 
## Number of resamples: 10 
## 
## Accuracy 
##           Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 250  0.7380952 0.7976190 0.8083764 0.8087010 0.8292683 0.8674699    0
## 300  0.7500000 0.7886905 0.8024240 0.8027199 0.8203397 0.8452381    0
## 350  0.7500000 0.7886905 0.8024240 0.8027056 0.8277623 0.8452381    0
## 400  0.7500000 0.7886905 0.8083764 0.8051009 0.8292683 0.8452381    0
## 450  0.7500000 0.7886905 0.8024240 0.8039104 0.8292683 0.8452381    0
## 500  0.7619048 0.7886905 0.8024240 0.8062914 0.8292683 0.8571429    0
## 550  0.7619048 0.7886905 0.8083764 0.8099062 0.8323171 0.8571429    0
## 600  0.7619048 0.7886905 0.8083764 0.8099205 0.8323171 0.8674699    0
## 800  0.7619048 0.7976190 0.8083764 0.8110820 0.8292683 0.8674699    0
## 1000 0.7619048 0.7976190 0.8121510 0.8086723 0.8303571 0.8452381    0
## 2000 0.7619048 0.7886905 0.8121510 0.8086723 0.8333333 0.8452381    0
## 
## Kappa 
##           Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 250  0.4061697 0.5667400 0.5836013 0.5856103 0.6335363 0.7196807    0
## 300  0.4302326 0.5449376 0.5780349 0.5723307 0.6130767 0.6710843    0
## 350  0.4302326 0.5449376 0.5780349 0.5723185 0.6291592 0.6710843    0
## 400  0.4302326 0.5482030 0.5836013 0.5774782 0.6335363 0.6710843    0
## 450  0.4302326 0.5449376 0.5780349 0.5750587 0.6335363 0.6710843    0
## 500  0.4601542 0.5449376 0.5780349 0.5804340 0.6335363 0.6949153    0
## 550  0.4601542 0.5482030 0.5857118 0.5884507 0.6396872 0.6949153    0
## 600  0.4601542 0.5482030 0.5857118 0.5884374 0.6396872 0.7196807    0
## 800  0.4601542 0.5667400 0.5836013 0.5910088 0.6335363 0.7196807    0
## 1000 0.4601542 0.5667400 0.5961590 0.5857446 0.6343666 0.6678832    0
## 2000 0.4601542 0.5482030 0.5961590 0.5862151 0.6440678 0.6656337    0

Sinulla on lopullinen mallisi. Voit harjoitella satunnaista metsää seuraavilla parametreilla:

  • ntree = 800: 800 puuta koulutetaan
  • mtry=4: Jokaiselle iteraatiolle valitaan 4 ominaisuutta
  • maxnodes = 24: enintään 24 solmua päätesolmuissa (lehdet)
fit_rf <- train(survived~.,
    data_train,
    method = "rf",
    metric = "Accuracy",
    tuneGrid = tuneGrid,
    trControl = trControl,
    importance = TRUE,
    nodesize = 14,
    ntree = 800,
    maxnodes = 24)

Vaihe 5) Arvioi malli

Kirjastokartalla on ennustetoiminto.

predict(model, newdata= df)
argument
- `model`: Define the model evaluated before. 
- `newdata`: Define the dataset to make prediction
prediction <-predict(fit_rf, data_test)

Voit käyttää ennustetta laskeaksesi sekavuusmatriisin ja nähdäksesi tarkkuuspisteet

confusionMatrix(prediction, data_test$survived)

lähtö:

## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Yes
##        No  110  32
##        Yes  11  56
##                                          
##                Accuracy : 0.7943         
##                  95% CI : (0.733, 0.8469)
##     No Information Rate : 0.5789         
##     P-Value [Acc > NIR] : 3.959e-11      
##                                          
##                   Kappa : 0.5638         
##  Mcnemar's Test P-Value : 0.002289       
##                                          
##             Sensitivity : 0.9091         
##             Specificity : 0.6364         
##          Pos Pred Value : 0.7746         
##          Neg Pred Value : 0.8358         
##              Prevalence : 0.5789         
##          Detection Rate : 0.5263         
##    Detection Prevalence : 0.6794         
##       Balanced Accuracy : 0.7727         
##                                          
##        'Positive' Class : No             
## 

Tarkkuus on 0.7943 prosenttia, mikä on suurempi kuin oletusarvo

Vaihe 6) Visualisoi tulos

Lopuksi voit tarkastella ominaisuuden tärkeyttä funktiolla varImp(). Näyttää siltä, ​​​​että tärkeimmät ominaisuudet ovat sukupuoli ja ikä. Tämä ei ole yllättävää, koska tärkeät ominaisuudet näkyvät todennäköisesti lähempänä puun juuria, kun taas vähemmän tärkeät ominaisuudet näkyvät usein suljettuina lehdissä.

varImpPlot(fit_rf)

lähtö:

varImp(fit_rf)
## rf variable importance
## 
##              Importance
## sexmale         100.000
## age              28.014
## pclassMiddle     27.016
## fare             21.557
## pclassUpper      16.324
## sibsp            11.246
## parch             5.522
## embarkedC         4.908
## embarkedQ         1.420
## embarkedS         0.000		

Yhteenveto

Voimme tiivistää satunnaisen metsän kouluttamisesta ja arvioinnista alla olevan taulukon avulla:

Kirjasto Tavoite Toiminto Parametri
satunnainen metsä Luo satunnainen metsä RandomForest() kaava, ntree=n, mtry=FALSE, maxnodes = NULL
caret Luo K-kansion ristiintarkistus trainControl() menetelmä = "cv", numero = n, haku = "ruudukko"
caret Harjoittele satunnaista metsää kouluttaa() kaava, df, metodi = "rf", metriikka = "Tarkkuus", trControl = trainControl(), tuneGrid = NULL
caret Ennusta pois näytteestä ennustaa malli, newdata= df
caret Sekaannusmatriisi ja tilastot confusionMatrix() malli, y testi
caret vaihteleva merkitys cvarImp() malli

Liite

Luettelo caretissa käytetyistä malleista

names>(getModelInfo())

lähtö:

##   [1] "ada"                 "AdaBag"              "AdaBoost.M1"        ##   [4] "adaboost"            "amdai"               "ANFIS"              ##   [7] "avNNet"              "awnb"                "awtan"              ##  [10] "bag"                 "bagEarth"            "bagEarthGCV"        ##  [13] "bagFDA"              "bagFDAGCV"           "bam"                ##  [16] "bartMachine"         "bayesglm"            "binda"              ##  [19] "blackboost"          "blasso"              "blassoAveraged"     ##  [22] "bridge"              "brnn"                "BstLm"              ##  [25] "bstSm"               "bstTree"             "C5.0"               ##  [28] "C5.0Cost"            "C5.0Rules"           "C5.0Tree"           ##  [31] "cforest"             "chaid"               "CSimca"             ##  [34] "ctree"               "ctree2"              "cubist"             ##  [37] "dda"                 "deepboost"           "DENFIS"             ##  [40] "dnn"                 "dwdLinear"           "dwdPoly"            ##  [43] "dwdRadial"           "earth"               "elm"                ##  [46] "enet"                "evtree"              "extraTrees"         ##  [49] "fda"                 "FH.GBML"             "FIR.DM"             ##  [52] "foba"                "FRBCS.CHI"           "FRBCS.W"            ##  [55] "FS.HGD"              "gam"                 "gamboost"           ##  [58] "gamLoess"            "gamSpline"           "gaussprLinear"      ##  [61] "gaussprPoly"         "gaussprRadial"       "gbm_h3o"            ##  [64] "gbm"                 "gcvEarth"            "GFS.FR.MOGUL"       ##  [67] "GFS.GCCL"            "GFS.LT.RS"           "GFS.THRIFT"         ##  [70] "glm.nb"              "glm"                 "glmboost"           ##  [73] "glmnet_h3o"          "glmnet"              "glmStepAIC"         ##  [76] "gpls"                "hda"                 "hdda"               ##  [79] "hdrda"               "HYFIS"               "icr"                ##  [82] "J48"                 "JRip"                "kernelpls"          ##  [85] "kknn"                "knn"                 "krlsPoly"           ##  [88] "krlsRadial"          "lars"                "lars2"              ##  [91] "lasso"               "lda"                 "lda2"               ##  [94] "leapBackward"        "leapForward"         "leapSeq"            ##  [97] "Linda"               "lm"                  "lmStepAIC"          ## [100] "LMT"                 "loclda"              "logicBag"           ## [103] "LogitBoost"          "logreg"              "lssvmLinear"        ## [106] "lssvmPoly"           "lssvmRadial"         "lvq"                ## [109] "M5"                  "M5Rules"             "manb"               ## [112] "mda"                 "Mlda"                "mlp"                ## [115] "mlpKerasDecay"       "mlpKerasDecayCost"   "mlpKerasDropout"    ## [118] "mlpKerasDropoutCost" "mlpML"               "mlpSGD"             ## [121] "mlpWeightDecay"      "mlpWeightDecayML"    "monmlp"             ## [124] "msaenet"             "multinom"            "mxnet"              ## [127] "mxnetAdam"           "naive_bayes"         "nb"                 ## [130] "nbDiscrete"          "nbSearch"            "neuralnet"          ## [133] "nnet"                "nnls"                "nodeHarvest"        ## [136] "null"                "OneR"                "ordinalNet"         ## [139] "ORFlog"              "ORFpls"              "ORFridge"           ## [142] "ORFsvm"              "ownn"                "pam"                ## [145] "parRF"               "PART"                "partDSA"            ## [148] "pcaNNet"             "pcr"                 "pda"                ## [151] "pda2"                "penalized"           "PenalizedLDA"       ## [154] "plr"                 "pls"                 "plsRglm"            ## [157] "polr"                "ppr"                 "PRIM"               ## [160] "protoclass"          "pythonKnnReg"        "qda"                ## [163] "QdaCov"              "qrf"                 "qrnn"               ## [166] "randomGLM"           "ranger"              "rbf"                ## [169] "rbfDDA"              "Rborist"             "rda"                ## [172] "regLogistic"         "relaxo"              "rf"                 ## [175] "rFerns"              "RFlda"               "rfRules"            ## [178] "ridge"               "rlda"                "rlm"                ## [181] "rmda"                "rocc"                "rotationForest"     ## [184] "rotationForestCp"    "rpart"               "rpart1SE"           ## [187] "rpart2"              "rpartCost"           "rpartScore"         ## [190] "rqlasso"             "rqnc"                "RRF"                ## [193] "RRFglobal"           "rrlda"               "RSimca"             ## [196] "rvmLinear"           "rvmPoly"             "rvmRadial"          ## [199] "SBC"                 "sda"                 "sdwd"               ## [202] "simpls"              "SLAVE"               "slda"               ## [205] "smda"                "snn"                 "sparseLDA"          ## [208] "spikeslab"           "spls"                "stepLDA"            ## [211] "stepQDA"             "superpc"             "svmBoundrangeString"## [214] "svmExpoString"       "svmLinear"           "svmLinear2"         ## [217] "svmLinear3"          "svmLinearWeights"    "svmLinearWeights2"  ## [220] "svmPoly"             "svmRadial"           "svmRadialCost"      ## [223] "svmRadialSigma"      "svmRadialWeights"    "svmSpectrumString"  ## [226] "tan"                 "tanSearch"           "treebag"            ## [229] "vbmpRadial"          "vglmAdjCat"          "vglmContRatio"      ## [232] "vglmCumulative"      "widekernelpls"       "WM"                 ## [235] "wsrf"                "xgbLinear"           "xgbTree"            ## [238] "xyf"