Hướng dẫn về rừng ngẫu nhiên R với ví dụ

Rừng ngẫu nhiên trong R là gì?

Rừng ngẫu nhiên được xây dựng dựa trên một ý tưởng đơn giản: 'sự khôn ngoan của đám đông'. Tổng hợp các kết quả của nhiều yếu tố dự đoán cho kết quả dự đoán tốt hơn so với yếu tố dự đoán riêng lẻ tốt nhất. Một nhóm các yếu tố dự đoán được gọi là toàn thể. Vì vậy, kỹ thuật này được gọi là học hòa tấu.

Trong phần hướng dẫn trước, bạn đã học cách sử dụng Cây quyết định để thực hiện một dự đoán nhị phân. Để cải thiện kỹ thuật của mình, chúng tôi có thể đào tạo một nhóm Trình phân loại cây quyết định, mỗi tập hợp con ngẫu nhiên khác nhau của tập hợp tàu. Để đưa ra dự đoán, chúng ta chỉ cần lấy dự đoán của tất cả các cây riêng lẻ, sau đó dự đoán lớp nhận được nhiều phiếu bầu nhất. Kỹ thuật này được gọi là Rừng ngẫu nhiên.

Bước 1) Nhập dữ liệu

Để đảm bảo bạn có cùng tập dữ liệu như trong hướng dẫn dành cho cây quyết định, bài kiểm tra tàu và bộ bài kiểm tra được lưu trữ trên internet. Bạn có thể nhập chúng mà không thực hiện bất kỳ thay đổi nào.

library(dplyr)
data_train <- read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/train.csv")
glimpse(data_train)
data_test <- read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/test.csv") 
glimpse(data_test)

Bước 2) Huấn luyện mô hình

Một cách để đánh giá hiệu suất của một mô hình là huấn luyện nó trên một số tập dữ liệu nhỏ hơn khác nhau và đánh giá chúng qua tập thử nghiệm nhỏ hơn khác. Đây được gọi là Xác thực chéo F-Fold tính năng này. R có chức năng phân chia ngẫu nhiên số lượng tập dữ liệu có kích thước gần như nhau. Ví dụ: nếu k=9, mô hình được đánh giá trên chín thư mục và được kiểm tra trên bộ kiểm tra còn lại. Quá trình này được lặp lại cho đến khi tất cả các tập con đã được đánh giá. Kỹ thuật này được sử dụng rộng rãi để lựa chọn mô hình, đặc biệt khi mô hình có các tham số cần điều chỉnh.

Bây giờ chúng ta đã có cách đánh giá mô hình của mình, chúng ta cần tìm ra cách chọn các tham số tổng quát hóa dữ liệu tốt nhất.

Rừng ngẫu nhiên chọn một tập hợp con các tính năng ngẫu nhiên và xây dựng nhiều Cây quyết định. Mô hình tính trung bình tất cả các dự đoán của cây Quyết định.

Rừng ngẫu nhiên có một số tham số có thể được thay đổi để cải thiện tính tổng quát của dự đoán. Bạn sẽ sử dụng hàm RandomForest() để huấn luyện mô hình.

Cú pháp của Rừng Randon là

RandomForest(formula, ntree=n, mtry=FALSE, maxnodes = NULL)
Arguments:
- Formula: Formula of the fitted model
- ntree: number of trees in the forest
- mtry: Number of candidates draw to feed the algorithm. By default, it is the square of the number of columns.
- maxnodes: Set the maximum amount of terminal nodes in the forest
- importance=TRUE: Whether independent variables importance in the random forest be assessed

Lưu ý: Rừng ngẫu nhiên có thể được huấn luyện trên nhiều thông số hơn. Bạn có thể tham khảo các họa tiết để xem các thông số khác nhau.

Điều chỉnh một mô hình là công việc rất tẻ nhạt. Có rất nhiều sự kết hợp có thể có giữa các tham số. Bạn không nhất thiết phải có thời gian để thử tất cả chúng. Một giải pháp thay thế tốt là để máy tìm ra sự kết hợp tốt nhất cho bạn. Có hai phương pháp có sẵn:

  • Tìm kiếm ngẫu nhiên
  • Tìm kiếm lưới

Chúng tôi sẽ xác định cả hai phương pháp nhưng trong hướng dẫn, chúng tôi sẽ huấn luyện mô hình bằng cách sử dụng tìm kiếm lưới

Định nghĩa tìm kiếm lưới

Phương pháp tìm kiếm dạng lưới rất đơn giản, mô hình sẽ được đánh giá dựa trên tất cả sự kết hợp mà bạn chuyển vào hàm, sử dụng xác thực chéo.

Ví dụ: bạn muốn thử mô hình với số lượng 10, 20, 30 cây và mỗi cây sẽ được thử nghiệm trên một số mtry bằng 1, 2, 3, 4, 5. Sau đó máy sẽ thử nghiệm 15 mô hình khác nhau:

    .mtry ntrees
 1      1     10
 2      2     10
 3      3     10
 4      4     10
 5      5     10
 6      1     20
 7      2     20
 8      3     20
 9      4     20
 10     5     20
 11     1     30
 12     2     30
 13     3     30
 14     4     30
 15     5     30	

Thuật toán sẽ đánh giá:

RandomForest(formula, ntree=10, mtry=1)
RandomForest(formula, ntree=10, mtry=2)
RandomForest(formula, ntree=10, mtry=3)
RandomForest(formula, ntree=20, mtry=2)
...

Mỗi lần, các thử nghiệm rừng ngẫu nhiên đều được xác thực chéo. Một thiếu sót của tìm kiếm lưới là số lượng thử nghiệm. Nó có thể trở nên rất dễ nổ khi số lượng kết hợp cao. Để khắc phục vấn đề này, bạn có thể sử dụng tìm kiếm ngẫu nhiên

Định nghĩa tìm kiếm ngẫu nhiên

Sự khác biệt lớn giữa tìm kiếm ngẫu nhiên và tìm kiếm lưới là, tìm kiếm ngẫu nhiên sẽ không đánh giá tất cả các kết hợp siêu tham số trong không gian tìm kiếm. Thay vào đó, nó sẽ chọn ngẫu nhiên kết hợp ở mỗi lần lặp. Ưu điểm là nó làm giảm chi phí tính toán.

Đặt tham số điều khiển

Bạn sẽ tiến hành như sau để xây dựng và đánh giá mô hình:

  • Đánh giá mô hình với cài đặt mặc định
  • Tìm số mtry tốt nhất
  • Tìm số lượng maxnode tốt nhất
  • Tìm số ntree tốt nhất
  • Đánh giá mô hình trên tập dữ liệu thử nghiệm

Trước khi bắt đầu khám phá tham số, bạn cần cài đặt hai thư viện.

Bạn có thể nhập chúng cùng với RandomForest

library(randomForest)
library(caret)
library(e1071)

Thiết lập mặc định

Xác thực chéo K-Fold được kiểm soát bởi hàm trainControl()

trainControl(method = "cv", number = n, search ="grid")
arguments
- method = "cv": The method used to resample the dataset. 
- number = n: Number of folders to create
- search = "grid": Use the search grid method. For randomized method, use "grid"
Note: You can refer to the vignette to see the other arguments of the function.

Bạn có thể thử chạy mô hình với các tham số mặc định và xem điểm chính xác.

Lưu ý: Bạn sẽ sử dụng các điều khiển tương tự trong suốt phần hướng dẫn.

# Define the control
trControl <- trainControl(method = "cv",
    number = 10,
    search = "grid")

Bạn sẽ sử dụng thư viện dấu mũ để đánh giá mô hình của mình. Thư viện có một hàm gọi là train() để đánh giá hầu hết tất cả học máy thuật toán. Nói cách khác, bạn có thể dùng hàm này để huấn luyện các thuật toán khác.

Cú pháp cơ bản là:

train(formula, df, method = "rf", metric= "Accuracy", trControl = trainControl(), tuneGrid = NULL)
argument
- `formula`: Define the formula of the algorithm
- `method`: Define which model to train. Note, at the end of the tutorial, there is a list of all the models that can be trained
- `metric` = "Accuracy": Define how to select the optimal model
- `trControl = trainControl()`: Define the control parameters
- `tuneGrid = NULL`: Return a data frame with all the possible combination

Hãy thử xây dựng mô hình với các giá trị mặc định.

set.seed(1234)
# Run the model
rf_default <- train(survived~.,
    data = data_train,
    method = "rf",
    metric = "Accuracy",
    trControl = trControl)
# Print the results
print(rf_default)

Giải thích mã

  • trainControl(method=”cv”, number=10, search=”grid”): Đánh giá mô hình bằng cách tìm kiếm dạng lưới 10 thư mục
  • train(…): Huấn luyện mô hình rừng ngẫu nhiên. Mô hình tốt nhất được chọn với thước đo chính xác.

Đầu ra:

## Random Forest 
## 
## 836 samples
##   7 predictor
##   2 classes: 'No', 'Yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 753, 752, 753, 752, 752, 752, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    2    0.7919248  0.5536486
##    6    0.7811245  0.5391611
##   10    0.7572002  0.4939620
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was mtry = 2.

Thuật toán sử dụng 500 cây và thử nghiệm ba giá trị khác nhau của mtry: 2, 6, 10.

Giá trị cuối cùng được sử dụng cho mô hình là mtry = 2 với độ chính xác là 0.78. Hãy cố gắng đạt được điểm cao hơn.

Bước 2) Tìm kiếm mtry tốt nhất

Bạn có thể kiểm tra mô hình với các giá trị mtry từ 1 đến 10

set.seed(1234)
tuneGrid <- expand.grid(.mtry = c(1: 10))
rf_mtry <- train(survived~.,
    data = data_train,
    method = "rf",
    metric = "Accuracy",
    tuneGrid = tuneGrid,
    trControl = trControl,
    importance = TRUE,
    nodesize = 14,
    ntree = 300)
print(rf_mtry)

Giải thích mã

  • tuneGrid <- Expand.grid(.mtry=c(3:10)): Xây dựng một vectơ có giá trị từ 3:10

Giá trị cuối cùng được sử dụng cho mô hình là mtry = 4.

Đầu ra:

## Random Forest 
## 
## 836 samples
##   7 predictor
##   2 classes: 'No', 'Yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 753, 752, 753, 752, 752, 752, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    1    0.7572576  0.4647368
##    2    0.7979346  0.5662364
##    3    0.8075158  0.5884815
##    4    0.8110729  0.5970664
##    5    0.8074727  0.5900030
##    6    0.8099111  0.5949342
##    7    0.8050918  0.5866415
##    8    0.8050918  0.5855399
##    9    0.8050631  0.5855035
##   10    0.7978916  0.5707336
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was mtry = 4.

Giá trị tốt nhất của mtry được lưu trữ trong:

rf_mtry$bestTune$mtry

Bạn có thể lưu trữ và sử dụng khi cần điều chỉnh các thông số khác.

max(rf_mtry$results$Accuracy)

Đầu ra:

## [1] 0.8110729
best_mtry <- rf_mtry$bestTune$mtry 
best_mtry

Đầu ra:

## [1] 4

Bước 3) Tìm kiếm các maxnode tốt nhất

Bạn cần tạo một vòng lặp để đánh giá các giá trị khác nhau của maxnodes. Trong đoạn mã sau, bạn sẽ:

  • Tạo một danh sách
  • Tạo một biến có giá trị tốt nhất của tham số mtry; Bắt buộc
  • Tạo vòng lặp
  • Lưu trữ giá trị hiện tại của maxnode
  • Tóm tắt kết quả
store_maxnode <- list()
tuneGrid <- expand.grid(.mtry = best_mtry)
for (maxnodes in c(5: 15)) {
    set.seed(1234)
    rf_maxnode <- train(survived~.,
        data = data_train,
        method = "rf",
        metric = "Accuracy",
        tuneGrid = tuneGrid,
        trControl = trControl,
        importance = TRUE,
        nodesize = 14,
        maxnodes = maxnodes,
        ntree = 300)
    current_iteration <- toString(maxnodes)
    store_maxnode[[current_iteration]] <- rf_maxnode
}
results_mtry <- resamples(store_maxnode)
summary(results_mtry)

Giải thích mã:

  • store_maxnode <- list(): Kết quả của mô hình sẽ được lưu trữ trong danh sách này
  • Expand.grid(.mtry=best_mtry): Sử dụng giá trị tốt nhất của mtry
  • for (maxnodes in c(15:25)) { … }: Tính toán mô hình với các giá trị của maxnodes bắt đầu từ 15 đến 25.
  • maxnodes=maxnodes: Đối với mỗi lần lặp, maxnodes bằng giá trị hiện tại của maxnodes. tức là 15, 16, 17,…
  • key <- toString(maxnodes): Lưu trữ giá trị của maxnode dưới dạng biến chuỗi.
  • store_maxnode[[key]] <- rf_maxnode: Lưu kết quả của mô hình vào danh sách.
  • resamples(store_maxnode): Sắp xếp kết quả của mô hình
  • tóm tắt (kết quả_mtry): In tóm tắt của tất cả sự kết hợp.

Đầu ra:

## 
## Call:
## summary.resamples(object = results_mtry)
## 
## Models: 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 
## Number of resamples: 10 
## 
## Accuracy 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 5  0.6785714 0.7529762 0.7903758 0.7799771 0.8168388 0.8433735    0
## 6  0.6904762 0.7648810 0.7784710 0.7811962 0.8125000 0.8313253    0
## 7  0.6904762 0.7619048 0.7738095 0.7788009 0.8102410 0.8333333    0
## 8  0.6904762 0.7627295 0.7844234 0.7847820 0.8184524 0.8433735    0
## 9  0.7261905 0.7747418 0.8083764 0.7955250 0.8258749 0.8333333    0
## 10 0.6904762 0.7837780 0.7904475 0.7895869 0.8214286 0.8433735    0
## 11 0.7023810 0.7791523 0.8024240 0.7943775 0.8184524 0.8433735    0
## 12 0.7380952 0.7910929 0.8144005 0.8051205 0.8288511 0.8452381    0
## 13 0.7142857 0.8005952 0.8192771 0.8075158 0.8403614 0.8452381    0
## 14 0.7380952 0.7941050 0.8203528 0.8098967 0.8403614 0.8452381    0
## 15 0.7142857 0.8000215 0.8203528 0.8075301 0.8378873 0.8554217    0
## 
## Kappa 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 5  0.3297872 0.4640436 0.5459706 0.5270773 0.6068751 0.6717371    0
## 6  0.3576471 0.4981484 0.5248805 0.5366310 0.6031287 0.6480921    0
## 7  0.3576471 0.4927448 0.5192771 0.5297159 0.5996437 0.6508314    0
## 8  0.3576471 0.4848320 0.5408159 0.5427127 0.6200253 0.6717371    0
## 9  0.4236277 0.5074421 0.5859472 0.5601687 0.6228626 0.6480921    0
## 10 0.3576471 0.5255698 0.5527057 0.5497490 0.6204819 0.6717371    0
## 11 0.3794326 0.5235007 0.5783191 0.5600467 0.6126720 0.6717371    0
## 12 0.4460432 0.5480930 0.5999072 0.5808134 0.6296780 0.6717371    0
## 13 0.4014252 0.5725752 0.6087279 0.5875305 0.6576219 0.6678832    0
## 14 0.4460432 0.5585005 0.6117973 0.5911995 0.6590982 0.6717371    0
## 15 0.4014252 0.5689401 0.6117973 0.5867010 0.6507194 0.6955990    0

Giá trị cuối cùng của maxnode có độ chính xác cao nhất. Bạn có thể thử với giá trị cao hơn để xem liệu bạn có thể đạt điểm cao hơn không.

store_maxnode <- list()
tuneGrid <- expand.grid(.mtry = best_mtry)
for (maxnodes in c(20: 30)) {
    set.seed(1234)
    rf_maxnode <- train(survived~.,
        data = data_train,
        method = "rf",
        metric = "Accuracy",
        tuneGrid = tuneGrid,
        trControl = trControl,
        importance = TRUE,
        nodesize = 14,
        maxnodes = maxnodes,
        ntree = 300)
    key <- toString(maxnodes)
    store_maxnode[[key]] <- rf_maxnode
}
results_node <- resamples(store_maxnode)
summary(results_node)

Đầu ra:

## 
## Call:
## summary.resamples(object = results_node)
## 
## Models: 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30 
## Number of resamples: 10 
## 
## Accuracy 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 20 0.7142857 0.7821644 0.8144005 0.8075301 0.8447719 0.8571429    0
## 21 0.7142857 0.8000215 0.8144005 0.8075014 0.8403614 0.8571429    0
## 22 0.7023810 0.7941050 0.8263769 0.8099254 0.8328313 0.8690476    0
## 23 0.7023810 0.7941050 0.8263769 0.8111302 0.8447719 0.8571429    0
## 24 0.7142857 0.7946429 0.8313253 0.8135112 0.8417599 0.8690476    0
## 25 0.7142857 0.7916667 0.8313253 0.8099398 0.8408635 0.8690476    0
## 26 0.7142857 0.7941050 0.8203528 0.8123207 0.8528758 0.8571429    0
## 27 0.7023810 0.8060456 0.8313253 0.8135112 0.8333333 0.8690476    0
## 28 0.7261905 0.7941050 0.8203528 0.8111015 0.8328313 0.8690476    0
## 29 0.7142857 0.7910929 0.8313253 0.8087063 0.8333333 0.8571429    0
## 30 0.6785714 0.7910929 0.8263769 0.8063253 0.8403614 0.8690476    0
## 
## Kappa 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 20 0.3956835 0.5316120 0.5961830 0.5854366 0.6661120 0.6955990    0
## 21 0.3956835 0.5699332 0.5960343 0.5853247 0.6590982 0.6919315    0
## 22 0.3735084 0.5560661 0.6221836 0.5914492 0.6422128 0.7189781    0
## 23 0.3735084 0.5594228 0.6228827 0.5939786 0.6657372 0.6955990    0
## 24 0.3956835 0.5600352 0.6337821 0.5992188 0.6604703 0.7189781    0
## 25 0.3956835 0.5530760 0.6354875 0.5912239 0.6554912 0.7189781    0
## 26 0.3956835 0.5589331 0.6136074 0.5969142 0.6822128 0.6955990    0
## 27 0.3735084 0.5852459 0.6368425 0.5998148 0.6426088 0.7189781    0
## 28 0.4290780 0.5589331 0.6154905 0.5946859 0.6356141 0.7189781    0
## 29 0.4070588 0.5534173 0.6337821 0.5901173 0.6423101 0.6919315    0
## 30 0.3297872 0.5534173 0.6202632 0.5843432 0.6590982 0.7189781    0

Điểm chính xác cao nhất đạt được với giá trị maxnode bằng 22.

Bước 4) Tìm kiếm ntrees tốt nhất

Bây giờ bạn đã có giá trị tốt nhất của mtry và maxnode, bạn có thể điều chỉnh số lượng cây. Phương pháp này hoàn toàn giống với maxnode.

store_maxtrees <- list()
for (ntree in c(250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000)) {
    set.seed(5678)
    rf_maxtrees <- train(survived~.,
        data = data_train,
        method = "rf",
        metric = "Accuracy",
        tuneGrid = tuneGrid,
        trControl = trControl,
        importance = TRUE,
        nodesize = 14,
        maxnodes = 24,
        ntree = ntree)
    key <- toString(ntree)
    store_maxtrees[[key]] <- rf_maxtrees
}
results_tree <- resamples(store_maxtrees)
summary(results_tree)

Đầu ra:

## 
## Call:
## summary.resamples(object = results_tree)
## 
## Models: 250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000 
## Number of resamples: 10 
## 
## Accuracy 
##           Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 250  0.7380952 0.7976190 0.8083764 0.8087010 0.8292683 0.8674699    0
## 300  0.7500000 0.7886905 0.8024240 0.8027199 0.8203397 0.8452381    0
## 350  0.7500000 0.7886905 0.8024240 0.8027056 0.8277623 0.8452381    0
## 400  0.7500000 0.7886905 0.8083764 0.8051009 0.8292683 0.8452381    0
## 450  0.7500000 0.7886905 0.8024240 0.8039104 0.8292683 0.8452381    0
## 500  0.7619048 0.7886905 0.8024240 0.8062914 0.8292683 0.8571429    0
## 550  0.7619048 0.7886905 0.8083764 0.8099062 0.8323171 0.8571429    0
## 600  0.7619048 0.7886905 0.8083764 0.8099205 0.8323171 0.8674699    0
## 800  0.7619048 0.7976190 0.8083764 0.8110820 0.8292683 0.8674699    0
## 1000 0.7619048 0.7976190 0.8121510 0.8086723 0.8303571 0.8452381    0
## 2000 0.7619048 0.7886905 0.8121510 0.8086723 0.8333333 0.8452381    0
## 
## Kappa 
##           Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 250  0.4061697 0.5667400 0.5836013 0.5856103 0.6335363 0.7196807    0
## 300  0.4302326 0.5449376 0.5780349 0.5723307 0.6130767 0.6710843    0
## 350  0.4302326 0.5449376 0.5780349 0.5723185 0.6291592 0.6710843    0
## 400  0.4302326 0.5482030 0.5836013 0.5774782 0.6335363 0.6710843    0
## 450  0.4302326 0.5449376 0.5780349 0.5750587 0.6335363 0.6710843    0
## 500  0.4601542 0.5449376 0.5780349 0.5804340 0.6335363 0.6949153    0
## 550  0.4601542 0.5482030 0.5857118 0.5884507 0.6396872 0.6949153    0
## 600  0.4601542 0.5482030 0.5857118 0.5884374 0.6396872 0.7196807    0
## 800  0.4601542 0.5667400 0.5836013 0.5910088 0.6335363 0.7196807    0
## 1000 0.4601542 0.5667400 0.5961590 0.5857446 0.6343666 0.6678832    0
## 2000 0.4601542 0.5482030 0.5961590 0.5862151 0.6440678 0.6656337    0

Bạn đã có mô hình cuối cùng. Bạn có thể huấn luyện rừng ngẫu nhiên với các tham số sau:

  • ntree =800: 800 cây sẽ được huấn luyện
  • mtry=4: 4 tính năng được chọn cho mỗi lần lặp
  • maxnodes = 24: Tối đa 24 nút trong các nút đầu cuối (lá)
fit_rf <- train(survived~.,
    data_train,
    method = "rf",
    metric = "Accuracy",
    tuneGrid = tuneGrid,
    trControl = trControl,
    importance = TRUE,
    nodesize = 14,
    ntree = 800,
    maxnodes = 24)

Bước 5) Đánh giá mô hình

Dấu mũ thư viện có chức năng đưa ra dự đoán.

predict(model, newdata= df)
argument
- `model`: Define the model evaluated before. 
- `newdata`: Define the dataset to make prediction
prediction <-predict(fit_rf, data_test)

Bạn có thể sử dụng dự đoán để tính toán ma trận nhầm lẫn và xem điểm chính xác

confusionMatrix(prediction, data_test$survived)

Đầu ra:

## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Yes
##        No  110  32
##        Yes  11  56
##                                          
##                Accuracy : 0.7943         
##                  95% CI : (0.733, 0.8469)
##     No Information Rate : 0.5789         
##     P-Value [Acc > NIR] : 3.959e-11      
##                                          
##                   Kappa : 0.5638         
##  Mcnemar's Test P-Value : 0.002289       
##                                          
##             Sensitivity : 0.9091         
##             Specificity : 0.6364         
##          Pos Pred Value : 0.7746         
##          Neg Pred Value : 0.8358         
##              Prevalence : 0.5789         
##          Detection Rate : 0.5263         
##    Detection Prevalence : 0.6794         
##       Balanced Accuracy : 0.7727         
##                                          
##        'Positive' Class : No             
## 

Bạn có độ chính xác là 0.7943 phần trăm, cao hơn giá trị mặc định

Bước 6) Trực quan hóa kết quả

Cuối cùng, bạn có thể xem tầm quan trọng của tính năng bằng hàm varImp(). Có vẻ như những đặc điểm quan trọng nhất là giới tính và tuổi tác. Điều đó không có gì đáng ngạc nhiên vì những đặc điểm quan trọng thường xuất hiện gần gốc cây hơn, trong khi những đặc điểm ít quan trọng hơn thường xuất hiện ở gần lá.

varImpPlot(fit_rf)

Đầu ra:

varImp(fit_rf)
## rf variable importance
## 
##              Importance
## sexmale         100.000
## age              28.014
## pclassMiddle     27.016
## fare             21.557
## pclassUpper      16.324
## sibsp            11.246
## parch             5.522
## embarkedC         4.908
## embarkedQ         1.420
## embarkedS         0.000		

Tổng kết

Chúng ta có thể tóm tắt cách huấn luyện và đánh giá một khu rừng ngẫu nhiên bằng bảng dưới đây:

Thư viện Mục tiêu Chức năng Tham số
ngẫu nhiênForest Tạo một khu rừng ngẫu nhiên Rừng ngẫu nhiên() công thức, ntree=n, mtry=FALSE, maxnodes = NULL
dấu mũ Tạo xác thực chéo thư mục K tàuControl() phương thức = “cv”, số = n, tìm kiếm =”lưới”
dấu mũ Huấn luyện một khu rừng ngẫu nhiên xe lửa() công thức, df, phương thức = “rf”, số liệu= “Độ chính xác”, trControl = trainControl(), tuneGrid = NULL
dấu mũ Dự đoán ngoài mẫu dự đoán mô hình, dữ liệu mới= df
dấu mũ Ma trận nhầm lẫn và thống kê ma trận hỗn loạn() mô hình, kiểm tra y
dấu mũ tầm quan trọng thay đổi cvarImp() kiểu mẫu

Phụ lục

Danh sách mô hình được sử dụng trong dấu mũ

names>(getModelInfo())

Đầu ra:

##   [1] "ada"                 "AdaBag"              "AdaBoost.M1"        ##   [4] "adaboost"            "amdai"               "ANFIS"              ##   [7] "avNNet"              "awnb"                "awtan"              ##  [10] "bag"                 "bagEarth"            "bagEarthGCV"        ##  [13] "bagFDA"              "bagFDAGCV"           "bam"                ##  [16] "bartMachine"         "bayesglm"            "binda"              ##  [19] "blackboost"          "blasso"              "blassoAveraged"     ##  [22] "bridge"              "brnn"                "BstLm"              ##  [25] "bstSm"               "bstTree"             "C5.0"               ##  [28] "C5.0Cost"            "C5.0Rules"           "C5.0Tree"           ##  [31] "cforest"             "chaid"               "CSimca"             ##  [34] "ctree"               "ctree2"              "cubist"             ##  [37] "dda"                 "deepboost"           "DENFIS"             ##  [40] "dnn"                 "dwdLinear"           "dwdPoly"            ##  [43] "dwdRadial"           "earth"               "elm"                ##  [46] "enet"                "evtree"              "extraTrees"         ##  [49] "fda"                 "FH.GBML"             "FIR.DM"             ##  [52] "foba"                "FRBCS.CHI"           "FRBCS.W"            ##  [55] "FS.HGD"              "gam"                 "gamboost"           ##  [58] "gamLoess"            "gamSpline"           "gaussprLinear"      ##  [61] "gaussprPoly"         "gaussprRadial"       "gbm_h3o"            ##  [64] "gbm"                 "gcvEarth"            "GFS.FR.MOGUL"       ##  [67] "GFS.GCCL"            "GFS.LT.RS"           "GFS.THRIFT"         ##  [70] "glm.nb"              "glm"                 "glmboost"           ##  [73] "glmnet_h3o"          "glmnet"              "glmStepAIC"         ##  [76] "gpls"                "hda"                 "hdda"               ##  [79] "hdrda"               "HYFIS"               "icr"                ##  [82] "J48"                 "JRip"                "kernelpls"          ##  [85] "kknn"                "knn"                 "krlsPoly"           ##  [88] "krlsRadial"          "lars"                "lars2"              ##  [91] "lasso"               "lda"                 "lda2"               ##  [94] "leapBackward"        "leapForward"         "leapSeq"            ##  [97] "Linda"               "lm"                  "lmStepAIC"          ## [100] "LMT"                 "loclda"              "logicBag"           ## [103] "LogitBoost"          "logreg"              "lssvmLinear"        ## [106] "lssvmPoly"           "lssvmRadial"         "lvq"                ## [109] "M5"                  "M5Rules"             "manb"               ## [112] "mda"                 "Mlda"                "mlp"                ## [115] "mlpKerasDecay"       "mlpKerasDecayCost"   "mlpKerasDropout"    ## [118] "mlpKerasDropoutCost" "mlpML"               "mlpSGD"             ## [121] "mlpWeightDecay"      "mlpWeightDecayML"    "monmlp"             ## [124] "msaenet"             "multinom"            "mxnet"              ## [127] "mxnetAdam"           "naive_bayes"         "nb"                 ## [130] "nbDiscrete"          "nbSearch"            "neuralnet"          ## [133] "nnet"                "nnls"                "nodeHarvest"        ## [136] "null"                "OneR"                "ordinalNet"         ## [139] "ORFlog"              "ORFpls"              "ORFridge"           ## [142] "ORFsvm"              "ownn"                "pam"                ## [145] "parRF"               "PART"                "partDSA"            ## [148] "pcaNNet"             "pcr"                 "pda"                ## [151] "pda2"                "penalized"           "PenalizedLDA"       ## [154] "plr"                 "pls"                 "plsRglm"            ## [157] "polr"                "ppr"                 "PRIM"               ## [160] "protoclass"          "pythonKnnReg"        "qda"                ## [163] "QdaCov"              "qrf"                 "qrnn"               ## [166] "randomGLM"           "ranger"              "rbf"                ## [169] "rbfDDA"              "Rborist"             "rda"                ## [172] "regLogistic"         "relaxo"              "rf"                 ## [175] "rFerns"              "RFlda"               "rfRules"            ## [178] "ridge"               "rlda"                "rlm"                ## [181] "rmda"                "rocc"                "rotationForest"     ## [184] "rotationForestCp"    "rpart"               "rpart1SE"           ## [187] "rpart2"              "rpartCost"           "rpartScore"         ## [190] "rqlasso"             "rqnc"                "RRF"                ## [193] "RRFglobal"           "rrlda"               "RSimca"             ## [196] "rvmLinear"           "rvmPoly"             "rvmRadial"          ## [199] "SBC"                 "sda"                 "sdwd"               ## [202] "simpls"              "SLAVE"               "slda"               ## [205] "smda"                "snn"                 "sparseLDA"          ## [208] "spikeslab"           "spls"                "stepLDA"            ## [211] "stepQDA"             "superpc"             "svmBoundrangeString"## [214] "svmExpoString"       "svmLinear"           "svmLinear2"         ## [217] "svmLinear3"          "svmLinearWeights"    "svmLinearWeights2"  ## [220] "svmPoly"             "svmRadial"           "svmRadialCost"      ## [223] "svmRadialSigma"      "svmRadialWeights"    "svmSpectrumString"  ## [226] "tan"                 "tanSearch"           "treebag"            ## [229] "vbmpRadial"          "vglmAdjCat"          "vglmContRatio"      ## [232] "vglmCumulative"      "widekernelpls"       "WM"                 ## [235] "wsrf"                "xgbLinear"           "xgbTree"            ## [238] "xyf"