R บทช่วยสอนป่าสุ่มพร้อมตัวอย่าง

Random Forest ใน R คืออะไร?

ป่าสุ่มมีพื้นฐานอยู่บนแนวคิดง่ายๆ นั่นคือ "ภูมิปัญญาของฝูงชน" การรวมผลลัพธ์ของตัวทำนายหลายตัวทำให้สามารถทำนายได้ดีกว่าตัวทำนายแต่ละตัวที่ดีที่สุด กลุ่มตัวทำนายเรียกว่า ทั้งมวล- จึงเรียกเทคนิคนี้ว่า การเรียนรู้ทั้งมวล.

ในบทช่วยสอนก่อนหน้านี้ คุณได้เรียนรู้วิธีใช้งานแล้ว ต้นไม้ตัดสินใจ เพื่อทำการทำนายแบบไบนารี เพื่อปรับปรุงเทคนิคของเรา เราสามารถฝึกอบรมกลุ่ม ตัวแยกประเภทต้นไม้การตัดสินใจโดยแต่ละชุดจะอยู่บนเซตย่อยแบบสุ่มที่แตกต่างกันของชุดรถไฟ ในการทำนาย เราเพียงแค่รับคำทำนายของแผนผังบุคคลทั้งหมด จากนั้นทายชั้นเรียนที่ได้รับคะแนนโหวตมากที่สุด เทคนิคนี้เรียกว่า ป่าสุ่ม.

ขั้นตอน 1) นำเข้าข้อมูล

เพื่อให้แน่ใจว่าคุณมีชุดข้อมูลเดียวกันกับในบทช่วยสอน ต้นไม้ตัดสินใจชุดทดสอบรถไฟและชุดทดสอบจะถูกจัดเก็บไว้บนอินเทอร์เน็ต คุณสามารถนำเข้าได้โดยไม่ต้องทำการเปลี่ยนแปลงใดๆ

library(dplyr)
data_train <- read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/train.csv")
glimpse(data_train)
data_test <- read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/test.csv") 
glimpse(data_test)

ขั้นตอนที่ 2) ฝึกโมเดล

วิธีหนึ่งในการประเมินประสิทธิภาพของโมเดลคือการฝึกโมเดลกับชุดข้อมูลขนาดเล็กต่างๆ จำนวนหนึ่ง และประเมินโมเดลเหล่านั้นเหนือชุดการทดสอบขนาดเล็กอื่นๆ สิ่งนี้เรียกว่า การตรวจสอบข้าม F-fold ลักษณะ R มีฟังก์ชันสุ่มแบ่งชุดข้อมูลที่มีขนาดเกือบเท่ากัน ตัวอย่างเช่น ถ้า k=9 โมเดลจะได้รับการประเมินในโฟลเดอร์ทั้งเก้าและทดสอบกับชุดทดสอบที่เหลือ กระบวนการนี้จะถูกทำซ้ำจนกว่าชุดย่อยทั้งหมดจะได้รับการประเมิน เทคนิคนี้ใช้กันอย่างแพร่หลายในการเลือกแบบจำลอง โดยเฉพาะอย่างยิ่งเมื่อแบบจำลองมีพารามิเตอร์ที่ต้องปรับแต่ง

ตอนนี้เรามีวิธีประเมินโมเดลของเราแล้ว เราต้องหาวิธีเลือกพารามิเตอร์ที่สรุปข้อมูลได้ดีที่สุด

ฟอเรสต์สุ่มเลือกชุดย่อยของคุณสมบัติแบบสุ่มและสร้างแผนผังการตัดสินใจจำนวนมาก แบบจำลองจะเฉลี่ยการคาดการณ์ทั้งหมดของแผนผังการตัดสินใจ

ฟอเรสต์สุ่มมีพารามิเตอร์บางตัวที่สามารถเปลี่ยนแปลงได้เพื่อปรับปรุงลักษณะทั่วไปของการทำนาย คุณจะใช้ฟังก์ชัน RandomForest() เพื่อฝึกโมเดล

ไวยากรณ์สำหรับ Randon Forest คือ

RandomForest(formula, ntree=n, mtry=FALSE, maxnodes = NULL)
Arguments:
- Formula: Formula of the fitted model
- ntree: number of trees in the forest
- mtry: Number of candidates draw to feed the algorithm. By default, it is the square of the number of columns.
- maxnodes: Set the maximum amount of terminal nodes in the forest
- importance=TRUE: Whether independent variables importance in the random forest be assessed

หมายเหตุ: สามารถฝึกฟอเรสต์แบบสุ่มได้โดยใช้พารามิเตอร์เพิ่มเติม คุณสามารถอ้างถึง บทความสั้น เพื่อดูพารามิเตอร์ต่างๆ

การปรับแต่งโมเดลเป็นงานที่น่าเบื่อมาก มีการผสมผสานระหว่างพารามิเตอร์ได้หลายอย่าง คุณไม่จำเป็นต้องมีเวลาลองทั้งหมด ทางเลือกที่ดีคือให้เครื่องค้นหาชุดค่าผสมที่ดีที่สุดสำหรับคุณ มีสองวิธีที่ใช้ได้:

  • สุ่มค้นหา
  • ค้นหากริด

เราจะกำหนดทั้งสองวิธี แต่ในระหว่างการสอน เราจะฝึกโมเดลโดยใช้การค้นหาตาราง

คำนิยามการค้นหากริด

วิธีการค้นหาตารางนั้นง่ายดาย โดยโมเดลจะถูกประเมินจากชุดค่าผสมทั้งหมดที่คุณส่งผ่านในฟังก์ชัน โดยใช้การตรวจสอบข้าม

ตัวอย่างเช่น คุณต้องการทดลองโมเดลด้วยจำนวนต้นไม้ 10, 20, 30 ต้น และต้นไม้แต่ละต้นจะถูกทดสอบในจำนวนเมตรเท่ากับ 1, 2, 3, 4, 5 จากนั้นเครื่องจะทดสอบโมเดลที่แตกต่างกัน 15 แบบ:

    .mtry ntrees
 1      1     10
 2      2     10
 3      3     10
 4      4     10
 5      5     10
 6      1     20
 7      2     20
 8      3     20
 9      4     20
 10     5     20
 11     1     30
 12     2     30
 13     3     30
 14     4     30
 15     5     30	

อัลกอริทึมจะประเมิน:

RandomForest(formula, ntree=10, mtry=1)
RandomForest(formula, ntree=10, mtry=2)
RandomForest(formula, ntree=10, mtry=3)
RandomForest(formula, ntree=20, mtry=2)
...

แต่ละครั้ง ฟอเรสต์สุ่มจะทำการทดลองด้วยการตรวจสอบข้าม ข้อบกพร่องประการหนึ่งของการค้นหาตารางคือจำนวนการทดลอง มันสามารถระเบิดได้ง่ายมากเมื่อมีจำนวนการรวมกันสูง เพื่อแก้ไขปัญหานี้ คุณสามารถใช้การค้นหาแบบสุ่มได้

คำจำกัดความการค้นหาแบบสุ่ม

ความแตกต่างที่สำคัญระหว่างการค้นหาแบบสุ่มและการค้นหาแบบกริดคือ การค้นหาแบบสุ่มจะไม่ประเมินค่าไฮเปอร์พารามิเตอร์ทั้งหมดในพื้นที่การค้นหา แต่จะเลือกค่าผสมแบบสุ่มในทุกการวนซ้ำ ข้อดีคือมีต้นทุนการคำนวณที่ต่ำลง

ตั้งค่าพารามิเตอร์การควบคุม

คุณจะดำเนินการดังต่อไปนี้เพื่อสร้างและประเมินแบบจำลอง:

  • ประเมินโมเดลด้วยการตั้งค่าเริ่มต้น
  • ค้นหาจำนวน mtry ที่ดีที่สุด
  • ค้นหาจำนวน Maxnodes ที่ดีที่สุด
  • ค้นหาจำนวน ntree ที่ดีที่สุด
  • ประเมินแบบจำลองบนชุดข้อมูลทดสอบ

ก่อนที่คุณจะเริ่มสำรวจพารามิเตอร์ คุณต้องติดตั้งไลบรารี 2 ไลบรารีก่อน

  • คาเร็ต: ไลบรารีการเรียนรู้ของเครื่อง R ถ้าคุณมี ติดตั้ง R ด้วยความจำเป็นอย่างยิ่ง มันมีอยู่แล้วในห้องสมุด
    • งู: conda ติดตั้ง -cr r-caret
  • e1071: ไลบรารีการเรียนรู้ของเครื่อง R
    • งู: conda ติดตั้ง -cr r-e1071

คุณสามารถนำเข้ามันพร้อมกับ RandomForest

library(randomForest)
library(caret)
library(e1071)

การตั้งค่าเริ่มต้น

การตรวจสอบข้าม K-fold ถูกควบคุมโดยฟังก์ชัน trainControl()

trainControl(method = "cv", number = n, search ="grid")
arguments
- method = "cv": The method used to resample the dataset. 
- number = n: Number of folders to create
- search = "grid": Use the search grid method. For randomized method, use "grid"
Note: You can refer to the vignette to see the other arguments of the function.

คุณสามารถลองรันโมเดลด้วยพารามิเตอร์เริ่มต้นและดูคะแนนความแม่นยำได้

หมายเหตุ: คุณจะใช้การควบคุมเดียวกันระหว่างการฝึกสอนทั้งหมด

# Define the control
trControl <- trainControl(method = "cv",
    number = 10,
    search = "grid")

คุณจะใช้ไลบรารีคาเร็ตเพื่อประเมินโมเดลของคุณ ไลบรารีมีฟังก์ชันหนึ่งที่เรียกว่า train() เพื่อประเมินเกือบทั้งหมด เรียนรู้เครื่อง อัลกอริทึม พูดอีกอย่างก็คือ คุณสามารถใช้ฟังก์ชันนี้เพื่อฝึกอัลกอริทึมอื่นได้

ไวยากรณ์พื้นฐานคือ:

train(formula, df, method = "rf", metric= "Accuracy", trControl = trainControl(), tuneGrid = NULL)
argument
- `formula`: Define the formula of the algorithm
- `method`: Define which model to train. Note, at the end of the tutorial, there is a list of all the models that can be trained
- `metric` = "Accuracy": Define how to select the optimal model
- `trControl = trainControl()`: Define the control parameters
- `tuneGrid = NULL`: Return a data frame with all the possible combination

เรามาลองสร้างโมเดลด้วยค่าเริ่มต้นกันดีกว่า

set.seed(1234)
# Run the model
rf_default <- train(survived~.,
    data = data_train,
    method = "rf",
    metric = "Accuracy",
    trControl = trControl)
# Print the results
print(rf_default)

คำอธิบายรหัส

  • trainControl(method=”cv”, number=10, search=”grid”): ประเมินโมเดลด้วยการค้นหากริดจาก 10 โฟลเดอร์
  • รถไฟ(…): ฝึกโมเดลฟอเรสต์แบบสุ่ม เลือกรุ่นที่ดีที่สุดด้วยการวัดความแม่นยำ

Output:

## Random Forest 
## 
## 836 samples
##   7 predictor
##   2 classes: 'No', 'Yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 753, 752, 753, 752, 752, 752, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    2    0.7919248  0.5536486
##    6    0.7811245  0.5391611
##   10    0.7572002  0.4939620
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was mtry = 2.

อัลกอริทึมใช้ต้นไม้ 500 ต้นและทดสอบค่า mtry ที่แตกต่างกันสามค่า: 2, 6, 10

ค่าสุดท้ายที่ใช้สำหรับแบบจำลองคือ mtry = 2 โดยมีความแม่นยำ 0.78 มาพยายามที่จะได้รับคะแนนที่สูงขึ้น

ขั้นตอนที่ 2) ค้นหา mtry ที่ดีที่สุด

คุณสามารถทดสอบโมเดลด้วยค่า mtry ได้ตั้งแต่ 1 ถึง 10

set.seed(1234)
tuneGrid <- expand.grid(.mtry = c(1: 10))
rf_mtry <- train(survived~.,
    data = data_train,
    method = "rf",
    metric = "Accuracy",
    tuneGrid = tuneGrid,
    trControl = trControl,
    importance = TRUE,
    nodesize = 14,
    ntree = 300)
print(rf_mtry)

คำอธิบายรหัส

  • tuneGrid <- expand.grid(.mtry=c(3:10)): สร้างเวกเตอร์ด้วยค่าตั้งแต่ 3:10

ค่าสุดท้ายที่ใช้สำหรับโมเดลคือ mtry = 4

Output:

## Random Forest 
## 
## 836 samples
##   7 predictor
##   2 classes: 'No', 'Yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 753, 752, 753, 752, 752, 752, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    1    0.7572576  0.4647368
##    2    0.7979346  0.5662364
##    3    0.8075158  0.5884815
##    4    0.8110729  0.5970664
##    5    0.8074727  0.5900030
##    6    0.8099111  0.5949342
##    7    0.8050918  0.5866415
##    8    0.8050918  0.5855399
##    9    0.8050631  0.5855035
##   10    0.7978916  0.5707336
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was mtry = 4.

ค่าที่ดีที่สุดของ mtry จะถูกเก็บไว้ใน:

rf_mtry$bestTune$mtry

คุณสามารถจัดเก็บและใช้งานได้เมื่อคุณต้องการปรับแต่งพารามิเตอร์อื่นๆ

max(rf_mtry$results$Accuracy)

Output:

## [1] 0.8110729
best_mtry <- rf_mtry$bestTune$mtry 
best_mtry

Output:

## [1] 4

ขั้นตอนที่ 3) ค้นหา Maxnodes ที่ดีที่สุด

คุณต้องสร้างลูปเพื่อประเมินค่าต่างๆ ของ maxnodes ในโค้ดต่อไปนี้ คุณจะ:

  • สร้างรายการ
  • สร้างตัวแปรที่มีค่าที่ดีที่สุดของพารามิเตอร์ mtry ภาคบังคับ
  • สร้างวง
  • เก็บค่าปัจจุบันของ maxnode
  • สรุปผล
store_maxnode <- list()
tuneGrid <- expand.grid(.mtry = best_mtry)
for (maxnodes in c(5: 15)) {
    set.seed(1234)
    rf_maxnode <- train(survived~.,
        data = data_train,
        method = "rf",
        metric = "Accuracy",
        tuneGrid = tuneGrid,
        trControl = trControl,
        importance = TRUE,
        nodesize = 14,
        maxnodes = maxnodes,
        ntree = 300)
    current_iteration <- toString(maxnodes)
    store_maxnode[[current_iteration]] <- rf_maxnode
}
results_mtry <- resamples(store_maxnode)
summary(results_mtry)

คำอธิบายรหัส:

  • store_maxnode <- list(): ผลลัพธ์ของโมเดลจะถูกเก็บไว้ในรายการนี้
  • expand.grid(.mtry=best_mtry): ใช้ค่าที่ดีที่สุดของ mtry
  • for (maxnodes ใน c(15:25)) { … }: คำนวณโมเดลด้วยค่าของ maxnodes เริ่มต้นจาก 15 ถึง 25
  • maxnodes=maxnodes: สำหรับการวนซ้ำแต่ละครั้ง maxnodes จะเท่ากับค่าปัจจุบันของ maxnodes เช่น 15, 16, 17, …
  • คีย์ <- toString(maxnodes): เก็บเป็นตัวแปรสตริงค่าของ maxnode
  • store_maxnode[[key]] <- rf_maxnode: บันทึกผลลัพธ์ของโมเดลในรายการ
  • resamples(store_maxnode): จัดเรียงผลลัพธ์ของโมเดล
  • summary(results_mtry): พิมพ์สรุปของชุดค่าผสมทั้งหมด

Output:

## 
## Call:
## summary.resamples(object = results_mtry)
## 
## Models: 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 
## Number of resamples: 10 
## 
## Accuracy 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 5  0.6785714 0.7529762 0.7903758 0.7799771 0.8168388 0.8433735    0
## 6  0.6904762 0.7648810 0.7784710 0.7811962 0.8125000 0.8313253    0
## 7  0.6904762 0.7619048 0.7738095 0.7788009 0.8102410 0.8333333    0
## 8  0.6904762 0.7627295 0.7844234 0.7847820 0.8184524 0.8433735    0
## 9  0.7261905 0.7747418 0.8083764 0.7955250 0.8258749 0.8333333    0
## 10 0.6904762 0.7837780 0.7904475 0.7895869 0.8214286 0.8433735    0
## 11 0.7023810 0.7791523 0.8024240 0.7943775 0.8184524 0.8433735    0
## 12 0.7380952 0.7910929 0.8144005 0.8051205 0.8288511 0.8452381    0
## 13 0.7142857 0.8005952 0.8192771 0.8075158 0.8403614 0.8452381    0
## 14 0.7380952 0.7941050 0.8203528 0.8098967 0.8403614 0.8452381    0
## 15 0.7142857 0.8000215 0.8203528 0.8075301 0.8378873 0.8554217    0
## 
## Kappa 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 5  0.3297872 0.4640436 0.5459706 0.5270773 0.6068751 0.6717371    0
## 6  0.3576471 0.4981484 0.5248805 0.5366310 0.6031287 0.6480921    0
## 7  0.3576471 0.4927448 0.5192771 0.5297159 0.5996437 0.6508314    0
## 8  0.3576471 0.4848320 0.5408159 0.5427127 0.6200253 0.6717371    0
## 9  0.4236277 0.5074421 0.5859472 0.5601687 0.6228626 0.6480921    0
## 10 0.3576471 0.5255698 0.5527057 0.5497490 0.6204819 0.6717371    0
## 11 0.3794326 0.5235007 0.5783191 0.5600467 0.6126720 0.6717371    0
## 12 0.4460432 0.5480930 0.5999072 0.5808134 0.6296780 0.6717371    0
## 13 0.4014252 0.5725752 0.6087279 0.5875305 0.6576219 0.6678832    0
## 14 0.4460432 0.5585005 0.6117973 0.5911995 0.6590982 0.6717371    0
## 15 0.4014252 0.5689401 0.6117973 0.5867010 0.6507194 0.6955990    0

ค่าสุดท้ายของ maxnode มีความแม่นยำสูงสุด คุณสามารถลองใช้ค่าที่สูงกว่าเพื่อดูว่าคุณจะได้คะแนนที่สูงขึ้นหรือไม่

store_maxnode <- list()
tuneGrid <- expand.grid(.mtry = best_mtry)
for (maxnodes in c(20: 30)) {
    set.seed(1234)
    rf_maxnode <- train(survived~.,
        data = data_train,
        method = "rf",
        metric = "Accuracy",
        tuneGrid = tuneGrid,
        trControl = trControl,
        importance = TRUE,
        nodesize = 14,
        maxnodes = maxnodes,
        ntree = 300)
    key <- toString(maxnodes)
    store_maxnode[[key]] <- rf_maxnode
}
results_node <- resamples(store_maxnode)
summary(results_node)

Output:

## 
## Call:
## summary.resamples(object = results_node)
## 
## Models: 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30 
## Number of resamples: 10 
## 
## Accuracy 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 20 0.7142857 0.7821644 0.8144005 0.8075301 0.8447719 0.8571429    0
## 21 0.7142857 0.8000215 0.8144005 0.8075014 0.8403614 0.8571429    0
## 22 0.7023810 0.7941050 0.8263769 0.8099254 0.8328313 0.8690476    0
## 23 0.7023810 0.7941050 0.8263769 0.8111302 0.8447719 0.8571429    0
## 24 0.7142857 0.7946429 0.8313253 0.8135112 0.8417599 0.8690476    0
## 25 0.7142857 0.7916667 0.8313253 0.8099398 0.8408635 0.8690476    0
## 26 0.7142857 0.7941050 0.8203528 0.8123207 0.8528758 0.8571429    0
## 27 0.7023810 0.8060456 0.8313253 0.8135112 0.8333333 0.8690476    0
## 28 0.7261905 0.7941050 0.8203528 0.8111015 0.8328313 0.8690476    0
## 29 0.7142857 0.7910929 0.8313253 0.8087063 0.8333333 0.8571429    0
## 30 0.6785714 0.7910929 0.8263769 0.8063253 0.8403614 0.8690476    0
## 
## Kappa 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 20 0.3956835 0.5316120 0.5961830 0.5854366 0.6661120 0.6955990    0
## 21 0.3956835 0.5699332 0.5960343 0.5853247 0.6590982 0.6919315    0
## 22 0.3735084 0.5560661 0.6221836 0.5914492 0.6422128 0.7189781    0
## 23 0.3735084 0.5594228 0.6228827 0.5939786 0.6657372 0.6955990    0
## 24 0.3956835 0.5600352 0.6337821 0.5992188 0.6604703 0.7189781    0
## 25 0.3956835 0.5530760 0.6354875 0.5912239 0.6554912 0.7189781    0
## 26 0.3956835 0.5589331 0.6136074 0.5969142 0.6822128 0.6955990    0
## 27 0.3735084 0.5852459 0.6368425 0.5998148 0.6426088 0.7189781    0
## 28 0.4290780 0.5589331 0.6154905 0.5946859 0.6356141 0.7189781    0
## 29 0.4070588 0.5534173 0.6337821 0.5901173 0.6423101 0.6919315    0
## 30 0.3297872 0.5534173 0.6202632 0.5843432 0.6590982 0.7189781    0

คะแนนความแม่นยำสูงสุดจะได้มาโดยมีค่า maxnode เท่ากับ 22

ขั้นตอนที่ 4) ค้นหา ntrees ที่ดีที่สุด

ตอนนี้คุณมีค่า mtry และ maxnode ที่ดีที่สุดแล้ว คุณสามารถปรับแต่งจำนวนต้นไม้ได้ วิธีการนี้เหมือนกับ maxnode ทุกประการ

store_maxtrees <- list()
for (ntree in c(250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000)) {
    set.seed(5678)
    rf_maxtrees <- train(survived~.,
        data = data_train,
        method = "rf",
        metric = "Accuracy",
        tuneGrid = tuneGrid,
        trControl = trControl,
        importance = TRUE,
        nodesize = 14,
        maxnodes = 24,
        ntree = ntree)
    key <- toString(ntree)
    store_maxtrees[[key]] <- rf_maxtrees
}
results_tree <- resamples(store_maxtrees)
summary(results_tree)

Output:

## 
## Call:
## summary.resamples(object = results_tree)
## 
## Models: 250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000 
## Number of resamples: 10 
## 
## Accuracy 
##           Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 250  0.7380952 0.7976190 0.8083764 0.8087010 0.8292683 0.8674699    0
## 300  0.7500000 0.7886905 0.8024240 0.8027199 0.8203397 0.8452381    0
## 350  0.7500000 0.7886905 0.8024240 0.8027056 0.8277623 0.8452381    0
## 400  0.7500000 0.7886905 0.8083764 0.8051009 0.8292683 0.8452381    0
## 450  0.7500000 0.7886905 0.8024240 0.8039104 0.8292683 0.8452381    0
## 500  0.7619048 0.7886905 0.8024240 0.8062914 0.8292683 0.8571429    0
## 550  0.7619048 0.7886905 0.8083764 0.8099062 0.8323171 0.8571429    0
## 600  0.7619048 0.7886905 0.8083764 0.8099205 0.8323171 0.8674699    0
## 800  0.7619048 0.7976190 0.8083764 0.8110820 0.8292683 0.8674699    0
## 1000 0.7619048 0.7976190 0.8121510 0.8086723 0.8303571 0.8452381    0
## 2000 0.7619048 0.7886905 0.8121510 0.8086723 0.8333333 0.8452381    0
## 
## Kappa 
##           Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 250  0.4061697 0.5667400 0.5836013 0.5856103 0.6335363 0.7196807    0
## 300  0.4302326 0.5449376 0.5780349 0.5723307 0.6130767 0.6710843    0
## 350  0.4302326 0.5449376 0.5780349 0.5723185 0.6291592 0.6710843    0
## 400  0.4302326 0.5482030 0.5836013 0.5774782 0.6335363 0.6710843    0
## 450  0.4302326 0.5449376 0.5780349 0.5750587 0.6335363 0.6710843    0
## 500  0.4601542 0.5449376 0.5780349 0.5804340 0.6335363 0.6949153    0
## 550  0.4601542 0.5482030 0.5857118 0.5884507 0.6396872 0.6949153    0
## 600  0.4601542 0.5482030 0.5857118 0.5884374 0.6396872 0.7196807    0
## 800  0.4601542 0.5667400 0.5836013 0.5910088 0.6335363 0.7196807    0
## 1000 0.4601542 0.5667400 0.5961590 0.5857446 0.6343666 0.6678832    0
## 2000 0.4601542 0.5482030 0.5961590 0.5862151 0.6440678 0.6656337    0

คุณมีโมเดลสุดท้ายของคุณแล้ว คุณสามารถฝึกป่าสุ่มด้วยพารามิเตอร์ต่อไปนี้:

  • ntree =800: อบรมต้นไม้ 800 ต้น
  • mtry=4: เลือกคุณสมบัติ 4 รายการสำหรับการวนซ้ำแต่ละครั้ง
  • maxnodes = 24: สูงสุด 24 โหนดในโหนดเทอร์มินัล (ใบ)
fit_rf <- train(survived~.,
    data_train,
    method = "rf",
    metric = "Accuracy",
    tuneGrid = tuneGrid,
    trControl = trControl,
    importance = TRUE,
    nodesize = 14,
    ntree = 800,
    maxnodes = 24)

ขั้นตอนที่ 5) ประเมินแบบจำลอง

คาเร็ตของห้องสมุดมีฟังก์ชันในการทำนาย

predict(model, newdata= df)
argument
- `model`: Define the model evaluated before. 
- `newdata`: Define the dataset to make prediction
prediction <-predict(fit_rf, data_test)

คุณสามารถใช้การทำนายเพื่อคำนวณเมทริกซ์ความสับสนและดูคะแนนความแม่นยำได้

confusionMatrix(prediction, data_test$survived)

Output:

## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Yes
##        No  110  32
##        Yes  11  56
##                                          
##                Accuracy : 0.7943         
##                  95% CI : (0.733, 0.8469)
##     No Information Rate : 0.5789         
##     P-Value [Acc > NIR] : 3.959e-11      
##                                          
##                   Kappa : 0.5638         
##  Mcnemar's Test P-Value : 0.002289       
##                                          
##             Sensitivity : 0.9091         
##             Specificity : 0.6364         
##          Pos Pred Value : 0.7746         
##          Neg Pred Value : 0.8358         
##              Prevalence : 0.5789         
##          Detection Rate : 0.5263         
##    Detection Prevalence : 0.6794         
##       Balanced Accuracy : 0.7727         
##                                          
##        'Positive' Class : No             
## 

คุณมีความแม่นยำร้อยละ 0.7943 ซึ่งสูงกว่าค่าเริ่มต้น

ขั้นตอนที่ 6) เห็นภาพผลลัพธ์

สุดท้ายนี้ คุณสามารถดูความสำคัญของฟีเจอร์ได้ด้วยฟังก์ชัน varImp() ดูเหมือนว่าลักษณะที่สำคัญที่สุดคือเพศและอายุ จึงไม่น่าแปลกใจเพราะลักษณะที่สำคัญมักจะปรากฏใกล้กับโคนต้นไม้ ในขณะที่ลักษณะที่สำคัญน้อยกว่ามักจะปรากฏใกล้กับใบ

varImpPlot(fit_rf)

Output:

varImp(fit_rf)
## rf variable importance
## 
##              Importance
## sexmale         100.000
## age              28.014
## pclassMiddle     27.016
## fare             21.557
## pclassUpper      16.324
## sibsp            11.246
## parch             5.522
## embarkedC         4.908
## embarkedQ         1.420
## embarkedS         0.000		

สรุป

เราสามารถสรุปวิธีการฝึกและประเมินป่าสุ่มได้จากตารางด้านล่างนี้:

ห้องสมุด วัตถุประสงค์ ฟังก์ชัน พารามิเตอร์
ป่าสุ่ม สร้างป่าสุ่ม สุ่มป่า() สูตร ntree=n, mtry=FALSE, maxnodes = NULL
คาเร็ต สร้างการตรวจสอบข้ามโฟลเดอร์ K รถไฟควบคุม() method = “cv”, number = n, ค้นหา =”grid”
คาเร็ต ฝึกฝนป่าสุ่ม รถไฟ() สูตร, df, วิธีการ = “rf”, เมตริก = “ความแม่นยำ”, trControl = trainControl(), tuneGrid = NULL
คาเร็ต ทำนายจากตัวอย่าง คาดการณ์ โมเดล newdata= df
คาเร็ต เมทริกซ์ความสับสนและสถิติ ความสับสนเมทริกซ์() แบบทดสอบและแบบทดสอบ
คาเร็ต ความสำคัญของตัวแปร cvarImp() แบบ

ภาคผนวก

รายการรุ่นที่ใช้ในคาเร็ต

names>(getModelInfo())

Output:

##   [1] "ada"                 "AdaBag"              "AdaBoost.M1"        ##   [4] "adaboost"            "amdai"               "ANFIS"              ##   [7] "avNNet"              "awnb"                "awtan"              ##  [10] "bag"                 "bagEarth"            "bagEarthGCV"        ##  [13] "bagFDA"              "bagFDAGCV"           "bam"                ##  [16] "bartMachine"         "bayesglm"            "binda"              ##  [19] "blackboost"          "blasso"              "blassoAveraged"     ##  [22] "bridge"              "brnn"                "BstLm"              ##  [25] "bstSm"               "bstTree"             "C5.0"               ##  [28] "C5.0Cost"            "C5.0Rules"           "C5.0Tree"           ##  [31] "cforest"             "chaid"               "CSimca"             ##  [34] "ctree"               "ctree2"              "cubist"             ##  [37] "dda"                 "deepboost"           "DENFIS"             ##  [40] "dnn"                 "dwdLinear"           "dwdPoly"            ##  [43] "dwdRadial"           "earth"               "elm"                ##  [46] "enet"                "evtree"              "extraTrees"         ##  [49] "fda"                 "FH.GBML"             "FIR.DM"             ##  [52] "foba"                "FRBCS.CHI"           "FRBCS.W"            ##  [55] "FS.HGD"              "gam"                 "gamboost"           ##  [58] "gamLoess"            "gamSpline"           "gaussprLinear"      ##  [61] "gaussprPoly"         "gaussprRadial"       "gbm_h3o"            ##  [64] "gbm"                 "gcvEarth"            "GFS.FR.MOGUL"       ##  [67] "GFS.GCCL"            "GFS.LT.RS"           "GFS.THRIFT"         ##  [70] "glm.nb"              "glm"                 "glmboost"           ##  [73] "glmnet_h3o"          "glmnet"              "glmStepAIC"         ##  [76] "gpls"                "hda"                 "hdda"               ##  [79] "hdrda"               "HYFIS"               "icr"                ##  [82] "J48"                 "JRip"                "kernelpls"          ##  [85] "kknn"                "knn"                 "krlsPoly"           ##  [88] "krlsRadial"          "lars"                "lars2"              ##  [91] "lasso"               "lda"                 "lda2"               ##  [94] "leapBackward"        "leapForward"         "leapSeq"            ##  [97] "Linda"               "lm"                  "lmStepAIC"          ## [100] "LMT"                 "loclda"              "logicBag"           ## [103] "LogitBoost"          "logreg"              "lssvmLinear"        ## [106] "lssvmPoly"           "lssvmRadial"         "lvq"                ## [109] "M5"                  "M5Rules"             "manb"               ## [112] "mda"                 "Mlda"                "mlp"                ## [115] "mlpKerasDecay"       "mlpKerasDecayCost"   "mlpKerasDropout"    ## [118] "mlpKerasDropoutCost" "mlpML"               "mlpSGD"             ## [121] "mlpWeightDecay"      "mlpWeightDecayML"    "monmlp"             ## [124] "msaenet"             "multinom"            "mxnet"              ## [127] "mxnetAdam"           "naive_bayes"         "nb"                 ## [130] "nbDiscrete"          "nbSearch"            "neuralnet"          ## [133] "nnet"                "nnls"                "nodeHarvest"        ## [136] "null"                "OneR"                "ordinalNet"         ## [139] "ORFlog"              "ORFpls"              "ORFridge"           ## [142] "ORFsvm"              "ownn"                "pam"                ## [145] "parRF"               "PART"                "partDSA"            ## [148] "pcaNNet"             "pcr"                 "pda"                ## [151] "pda2"                "penalized"           "PenalizedLDA"       ## [154] "plr"                 "pls"                 "plsRglm"            ## [157] "polr"                "ppr"                 "PRIM"               ## [160] "protoclass"          "pythonKnnReg"        "qda"                ## [163] "QdaCov"              "qrf"                 "qrnn"               ## [166] "randomGLM"           "ranger"              "rbf"                ## [169] "rbfDDA"              "Rborist"             "rda"                ## [172] "regLogistic"         "relaxo"              "rf"                 ## [175] "rFerns"              "RFlda"               "rfRules"            ## [178] "ridge"               "rlda"                "rlm"                ## [181] "rmda"                "rocc"                "rotationForest"     ## [184] "rotationForestCp"    "rpart"               "rpart1SE"           ## [187] "rpart2"              "rpartCost"           "rpartScore"         ## [190] "rqlasso"             "rqnc"                "RRF"                ## [193] "RRFglobal"           "rrlda"               "RSimca"             ## [196] "rvmLinear"           "rvmPoly"             "rvmRadial"          ## [199] "SBC"                 "sda"                 "sdwd"               ## [202] "simpls"              "SLAVE"               "slda"               ## [205] "smda"                "snn"                 "sparseLDA"          ## [208] "spikeslab"           "spls"                "stepLDA"            ## [211] "stepQDA"             "superpc"             "svmBoundrangeString"## [214] "svmExpoString"       "svmLinear"           "svmLinear2"         ## [217] "svmLinear3"          "svmLinearWeights"    "svmLinearWeights2"  ## [220] "svmPoly"             "svmRadial"           "svmRadialCost"      ## [223] "svmRadialSigma"      "svmRadialWeights"    "svmSpectrumString"  ## [226] "tan"                 "tanSearch"           "treebag"            ## [229] "vbmpRadial"          "vglmAdjCat"          "vglmContRatio"      ## [232] "vglmCumulative"      "widekernelpls"       "WM"                 ## [235] "wsrf"                "xgbLinear"           "xgbTree"            ## [238] "xyf"