Cây quyết định trong R: Cây phân loại với ví dụ

Cây quyết định là gì?

Cây quyết định là thuật toán Machine Learning đa năng có thể thực hiện cả nhiệm vụ phân loại và hồi quy. Chúng là thuật toán rất mạnh mẽ, có khả năng phù hợp với các tập dữ liệu phức tạp. Bên cạnh đó, cây quyết định là thành phần cơ bản của rừng ngẫu nhiên, một trong những thuật toán Machine Learning mạnh mẽ nhất hiện nay.

Đào tạo và trực quan hóa cây quyết định trong R

Để xây dựng cây quyết định đầu tiên của bạn trong ví dụ R, chúng tôi sẽ tiến hành như sau trong hướng dẫn về Cây quyết định này:

  • Bước 1: Nhập dữ liệu
  • Bước 2: Làm sạch tập dữ liệu
  • Bước 3: Tạo tập huấn luyện/kiểm tra
  • Bước 4: Xây dựng mô hình
  • Bước 5: Đưa ra dự đoán
  • Bước 6: Đo lường hiệu suất
  • Bước 7: Điều chỉnh các siêu tham số

Bước 1) Nhập dữ liệu

Nếu tò mò về số phận của tàu Titanic, bạn có thể xem video này trên Youtube. Mục đích của bộ dữ liệu này là dự đoán những người nào có nhiều khả năng sống sót hơn sau vụ va chạm với tảng băng trôi. Bộ dữ liệu chứa 13 biến và 1309 quan sát. Tập dữ liệu được sắp xếp theo biến X.

set.seed(678)
path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'
titanic <-read.csv(path)
head(titanic)

Đầu ra:

##   X pclass survived                                            name    sex
## 1 1      1        1                   Allen, Miss. Elisabeth Walton female
## 2 2      1        1                  Allison, Master. Hudson Trevor   male
## 3 3      1        0                    Allison, Miss. Helen Loraine female
## 4 4      1        0            Allison, Mr. Hudson Joshua Creighton   male
## 5 5      1        0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female
## 6 6      1        1                             Anderson, Mr. Harry   male
##       age sibsp parch ticket     fare   cabin embarked
## 1 29.0000     0     0  24160 211.3375      B5        S
## 2  0.9167     1     2 113781 151.5500 C22 C26        S
## 3  2.0000     1     2 113781 151.5500 C22 C26        S
## 4 30.0000     1     2 113781 151.5500 C22 C26        S
## 5 25.0000     1     2 113781 151.5500 C22 C26        S
## 6 48.0000     0     0  19952  26.5500     E12        S
##                         home.dest
## 1                    St Louis, MO
## 2 Montreal, PQ / Chesterville, ON
## 3 Montreal, PQ / Chesterville, ON
## 4 Montreal, PQ / Chesterville, ON
## 5 Montreal, PQ / Chesterville, ON
## 6                    New York, NY
tail(titanic)

Đầu ra:

##         X pclass survived                      name    sex  age sibsp
## 1304 1304      3        0     Yousseff, Mr. Gerious   male   NA     0
## 1305 1305      3        0      Zabour, Miss. Hileni female 14.5     1
## 1306 1306      3        0     Zabour, Miss. Thamine female   NA     1
## 1307 1307      3        0 Zakarian, Mr. Mapriededer   male 26.5     0
## 1308 1308      3        0       Zakarian, Mr. Ortin   male 27.0     0
## 1309 1309      3        0        Zimmerman, Mr. Leo   male 29.0     0
##      parch ticket    fare cabin embarked home.dest
## 1304     0   2627 14.4583              C          
## 1305     0   2665 14.4542              C          
## 1306     0   2665 14.4542              C          
## 1307     0   2656  7.2250              C          
## 1308     0   2670  7.2250              C          
## 1309     0 315082  7.8750              S

Từ đầu ra đầu và đuôi, bạn có thể nhận thấy dữ liệu không bị xáo trộn. Đây là một vấn đề lớn! Khi bạn phân chia dữ liệu của mình giữa tập huấn luyện và tập kiểm tra, bạn sẽ chọn có thể hành khách hạng 1 và 2 (Không có hành khách hạng 3 nào nằm trong 80% quan sát hàng đầu), nghĩa là thuật toán sẽ không bao giờ nhìn thấy các đặc điểm của hành khách hạng 3. Sai lầm này sẽ dẫn đến dự đoán kém.

Để khắc phục vấn đề này, bạn có thể sử dụng hàm sample().

shuffle_index <- sample(1:nrow(titanic))
head(shuffle_index)

Cây quyết định Mã R Giải thích

  • sample(1:nrow(titanic)): Tạo danh sách chỉ mục ngẫu nhiên từ 1 đến 1309 (tức là số hàng tối đa).

Đầu ra:

## [1]  288  874 1078  633  887  992

Bạn sẽ sử dụng chỉ mục này để xáo trộn tập dữ liệu titanic.

titanic <- titanic[shuffle_index, ]
head(titanic)

Đầu ra:

##         X pclass survived
## 288   288      1        0
## 874   874      3        0
## 1078 1078      3        1
## 633   633      3        0
## 887   887      3        1
## 992   992      3        1
##                                                           name    sex age
## 288                                      Sutton, Mr. Frederick   male  61
## 874                   Humblen, Mr. Adolf Mathias Nicolai Olsen   male  42
## 1078                                 O'Driscoll, Miss. Bridget female  NA
## 633  Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female  39
## 887                                        Jermyn, Miss. Annie female  NA
## 992                                           Mamee, Mr. Hanna   male  NA
##      sibsp parch ticket    fare cabin embarked           home.dest## 288      0     0  36963 32.3208   D50        S     Haddenfield, NJ
## 874      0     0 348121  7.6500 F G63        S                    
## 1078     0     0  14311  7.7500              Q                    
## 633      1     5 347082 31.2750              S Sweden Winnipeg, MN
## 887      0     0  14313  7.7500              Q                    
## 992      0     0   2677  7.2292              C	

Bước 2) Làm sạch tập dữ liệu

Cấu trúc dữ liệu cho thấy một số biến có NA. Việc dọn dẹp dữ liệu được thực hiện như sau

  • Thả các biến home.dest,cabin,name,X và ticket
  • Tạo các biến nhân tố cho pclass và sống sót
  • Bỏ NA
library(dplyr)
# Drop variables
clean_titanic <- titanic % > %
select(-c(home.dest, cabin, name, X, ticket)) % > % 
#Convert to factor level
	mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),
	survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %
na.omit()
glimpse(clean_titanic)

Giải thích mã

  • select(-c(home.dest, cabin, name, X, ticket)): Bỏ các biến không cần thiết
  • pclass =factor(pclass,levels = c(1,2,3), labels= c('Upper', 'Middle', 'Lower')): Thêm nhãn vào biến pclass. 1 trở thành Thượng, 2 trở thành Trung và 3 trở thành Thấp
  • yếu tố(sống sót, mức độ = c(0,1), nhãn = c('Không', 'Có')): Thêm nhãn vào biến sống sót. 1 trở thành Không và 2 trở thành Có
  • na.omit(): Loại bỏ các quan sát NA

Đầu ra:

## Observations: 1,045
## Variables: 8
## $ pclass   <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U...
## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y...
## $ sex      <fctr> male, male, female, female, male, male, female, male...
## $ age      <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ...
## $ sibsp    <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,...
## $ parch    <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,...
## $ fare     <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ...
## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...		

Bước 3) Tạo tập huấn luyện/kiểm tra

Trước khi huấn luyện mô hình của mình, bạn cần thực hiện hai bước:

  • Tạo tập huấn luyện và kiểm tra: Bạn huấn luyện mô hình trên tập huấn luyện và kiểm tra dự đoán trên tập kiểm tra (tức là dữ liệu chưa nhìn thấy)
  • Cài đặt rpart.plot từ bảng điều khiển

Thực tiễn phổ biến là chia dữ liệu 80/20, 80% dữ liệu dùng để huấn luyện mô hình và 20% để đưa ra dự đoán. Bạn cần tạo hai khung dữ liệu riêng biệt. Bạn không muốn chạm vào tập kiểm tra cho đến khi hoàn thành việc xây dựng mô hình của mình. Bạn có thể tạo tên hàm create_train_test() có ba đối số.

create_train_test(df, size = 0.8, train = TRUE)
arguments:
-df: Dataset used to train the model.
-size: Size of the split. By default, 0.8. Numerical value
-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {
    n_row = nrow(data)
    total_row = size * n_row
    train_sample < - 1: total_row
    if (train == TRUE) {
        return (data[train_sample, ])
    } else {
        return (data[-train_sample, ])
    }
}

Giải thích mã

  • function(data, size=0.8, train = TRUE): Thêm các đối số trong hàm
  • n_row = nrow(data): Đếm số hàng trong tập dữ liệu
  • Total_row = size*n_row: Trả về hàng thứ n để xây dựng tập tàu
  • train_sample <- 1:total_row: Chọn hàng đầu tiên đến hàng thứ n
  • if (train ==TRUE){ } else { }: Nếu điều kiện được đặt thành đúng, trả về tập tàu, nếu không thì trả về tập kiểm tra.

Bạn có thể kiểm tra chức năng của mình và kiểm tra kích thước.

data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)
data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)
dim(data_train)

Đầu ra:

## [1] 836   8
dim(data_test)

Đầu ra:

## [1] 209   8

Tập dữ liệu tàu có 1046 hàng trong khi tập dữ liệu thử nghiệm có 262 hàng.

Bạn sử dụng hàm prop.table() kết hợp với table() để xác minh xem quá trình ngẫu nhiên hóa có đúng hay không.

prop.table(table(data_train$survived))

Đầu ra:

##
##        No       Yes 
## 0.5944976 0.4055024
prop.table(table(data_test$survived))

Đầu ra:

## 
##        No       Yes 
## 0.5789474 0.4210526

Trong cả hai tập dữ liệu, số lượng người sống sót là như nhau, khoảng 40%.

Cài đặt rpart.plot

rpart.plot không có sẵn trong thư viện conda. Bạn có thể cài đặt nó từ bảng điều khiển:

install.packages("rpart.plot")

Bước 4) Xây dựng mô hình

Bạn đã sẵn sàng để xây dựng mô hình. Cú pháp của hàm cây quyết định Rpart là:

rpart(formula, data=, method='')
arguments:			
- formula: The function to predict
- data: Specifies the data frame- method: 			
- "class" for a classification tree 			
- "anova" for a regression tree	

Bạn sử dụng phương thức lớp vì bạn dự đoán một lớp.

library(rpart)
library(rpart.plot)
fit <- rpart(survived~., data = data_train, method = 'class')
rpart.plot(fit, extra = 106

Giải thích mã

  • rpart(): Hàm phù hợp với mô hình. Các đối số là:
    • sống sót ~.: Công thức của cây quyết định
    • data = data_train: Tập dữ liệu
    • Method = 'class': Phù hợp với mô hình nhị phân
  • rpart.plot(fit, extra= 106): Vẽ sơ đồ cây. Các tính năng bổ sung được đặt thành 101 để hiển thị xác suất của loại thứ 2 (hữu ích cho các phản hồi nhị phân). Bạn có thể tham khảo các họa tiết để biết thêm thông tin về các lựa chọn khác.

Đầu ra:

Xây dựng mô hình cây quyết định trong R

Bạn bắt đầu tại nút gốc (độ sâu 0 trên 3, đỉnh biểu đồ):

  1. Ở trên cùng, đó là xác suất sống sót tổng thể. Nó cho thấy tỷ lệ hành khách sống sót sau vụ tai nạn. 41% hành khách sống sót.
  2. Nút này hỏi xem giới tính của hành khách có phải là nam hay không. Nếu có thì bạn đi xuống nút con bên trái của gốc (độ sâu 2). 63% là nam giới với khả năng sống sót là 21%.
  3. Ở nút thứ hai, bạn hỏi xem hành khách nam có trên 3.5 tuổi không. Nếu có thì cơ hội sống sót là 19%.
  4. Bạn cứ tiếp tục như vậy để hiểu những đặc điểm nào ảnh hưởng đến khả năng sống sót.

Lưu ý rằng, một trong nhiều ưu điểm của Cây Quyết định là chúng yêu cầu rất ít việc chuẩn bị dữ liệu. Đặc biệt, chúng không yêu cầu chia tỷ lệ hoặc căn giữa tính năng.

Theo mặc định, hàm rpart() sử dụng Gini thước đo tạp chất để phân chia nốt. Hệ số Gini càng cao thì càng có nhiều trường hợp khác nhau trong nút.

Bước 5) Đưa ra dự đoán

Bạn có thể dự đoán tập dữ liệu thử nghiệm của mình. Để đưa ra dự đoán, bạn có thể sử dụng hàm Predict(). Cú pháp cơ bản của dự đoán cho cây quyết định R là:

predict(fitted_model, df, type = 'class')
arguments:
- fitted_model: This is the object stored after model estimation. 
- df: Data frame used to make the prediction
- type: Type of prediction			
    - 'class': for classification			
    - 'prob': to compute the probability of each class			
    - 'vector': Predict the mean response at the node level	

Bạn muốn dự đoán hành khách nào có nhiều khả năng sống sót hơn sau vụ va chạm từ bộ thử nghiệm. Nghĩa là, bạn sẽ biết trong số 209 hành khách đó, ai sẽ sống sót hay không.

predict_unseen <-predict(fit, data_test, type = 'class')

Giải thích mã

  • dự đoán(fit, data_test, type='class'): Dự đoán lớp (0/1) của tập kiểm tra

Kiểm tra hành khách không đến được và những người đã đến được.

table_mat <- table(data_test$survived, predict_unseen)
table_mat

Giải thích mã

  • table(data_test$survived, Predict_unseen): Tạo bảng đếm xem có bao nhiêu hành khách được phân loại là người sống sót và đã qua đời so với phân loại cây quyết định đúng trong R

Đầu ra:

##      predict_unseen
##        No Yes
##   No  106  15
##   Yes  30  58

Mô hình đã dự đoán chính xác 106 hành khách thiệt mạng nhưng lại phân loại 15 người sống sót là đã chết. Bằng cách tương tự, mô hình đã phân loại sai 30 hành khách là những người sống sót trong khi hóa ra họ đã chết.

Bước 6) Đo lường hiệu suất

Bạn có thể tính toán thước đo độ chính xác cho nhiệm vụ phân loại bằng ma trận hỗn loạn:

ma trận hỗn loạn là sự lựa chọn tốt hơn để đánh giá hiệu suất phân loại. Ý tưởng chung là đếm số lần các trường hợp Đúng được phân loại là Sai.

Đo lường hiệu suất của cây quyết định trong R

Mỗi hàng trong ma trận nhầm lẫn đại diện cho một mục tiêu thực tế, trong khi mỗi cột đại diện cho một mục tiêu được dự đoán. Hàng đầu tiên của ma trận này coi hành khách đã chết (loại Sai): 106 được phân loại chính xác là đã chết (Âm tính thật), trong khi người còn lại bị phân loại sai là người sống sót (Dương tính giả). Hàng thứ hai xem xét những người sống sót, lớp tích cực là 58 (Đúng tích cực), trong khi Âm tính thật là 30.

Bạn có thể tính toán kiểm tra độ chính xác từ ma trận nhầm lẫn:

Đo lường hiệu suất của cây quyết định trong R

Đó là tỷ lệ giữa dương thực và âm thực trên tổng của ma trận. Với R, bạn có thể viết mã như sau:

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)

Giải thích mã

  • sum(diag(table_mat)): Tổng của đường chéo
  • sum(table_mat): Tổng của ma trận.

Bạn có thể in độ chính xác của bộ kiểm tra:

print(paste('Accuracy for test', accuracy_Test))

Đầu ra:

## [1] "Accuracy for test 0.784688995215311"

Bạn có số điểm 78 phần trăm cho bộ bài kiểm tra. Bạn có thể sao chép bài tập tương tự với tập dữ liệu huấn luyện.

Bước 7) Điều chỉnh các siêu tham số

Cây quyết định trong R có nhiều tham số khác nhau để kiểm soát các khía cạnh của sự phù hợp. Trong thư viện cây quyết định rpart, bạn có thể kiểm soát các tham số bằng hàm rpart.control(). Trong đoạn mã sau, bạn giới thiệu các tham số bạn sẽ điều chỉnh. Bạn có thể tham khảo họa tiết cho các thông số khác.

rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)
Arguments:
-minsplit: Set the minimum number of observations in the node before the algorithm perform a split
-minbucket:  Set the minimum number of observations in the final note i.e. the leaf
-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0

Chúng ta sẽ tiến hành như sau:

  • Xây dựng hàm để trả về độ chính xác
  • Điều chỉnh độ sâu tối đa
  • Điều chỉnh số lượng mẫu tối thiểu mà một nút phải có trước khi có thể phân chia
  • Điều chỉnh số lượng mẫu tối thiểu mà một nút lá phải có

Bạn có thể viết một hàm để hiển thị độ chính xác. Bạn chỉ cần bọc mã bạn đã sử dụng trước đó:

  1. dự đoán: dự đoán_unseen <- dự đoán (phù hợp, data_test, type = 'class')
  2. Tạo bảng: table_mat <- table(data_test$survived, Predict_unseen)
  3. Tính toán độ chính xác: độ chính xác_Test <- sum(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) {
    predict_unseen <- predict(fit, data_test, type = 'class')
    table_mat <- table(data_test$survived, predict_unseen)
    accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
    accuracy_Test
}

Bạn có thể thử điều chỉnh các tham số và xem liệu bạn có thể cải thiện mô hình so với giá trị mặc định hay không. Xin nhắc lại, bạn cần có độ chính xác cao hơn 0.78

control <- rpart.control(minsplit = 4,
    minbucket = round(5 / 3),
    maxdepth = 3,
    cp = 0)
tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)
accuracy_tune(tune_fit)

Đầu ra:

## [1] 0.7990431

Với tham số sau:

minsplit = 4
minbucket= round(5/3)
maxdepth = 3cp=0

Bạn sẽ có được hiệu suất cao hơn so với mô hình trước đó. Xin chúc mừng!

Tổng kết

Chúng ta có thể tóm tắt các hàm để huấn luyện thuật toán cây quyết định trong R

Thư viện Mục tiêu Chức năng Lớp Thông số Chi Tiết
phần Cây phân loại tàu trong R rpart() tốt nghiệp lớp XNUMX công thức, df, phương pháp
phần Đào tạo cây hồi quy rpart() anova công thức, df, phương pháp
phần Vẽ cây rpart.plot() mô hình vừa vặn
cơ sở dự đoán dự đoán () tốt nghiệp lớp XNUMX mô hình, loại trang bị
cơ sở dự đoán dự đoán () thăm dò mô hình, loại trang bị
cơ sở dự đoán dự đoán () vector mô hình, loại trang bị
phần thông số điều khiển rpart.control() chia nhỏ Đặt số lượng quan sát tối thiểu trong nút trước khi thuật toán thực hiện phân tách
thùng nhỏ Đặt số lượng quan sát tối thiểu trong nốt cuối cùng tức là lá
độ sâu tối đa Đặt độ sâu tối đa của bất kỳ nút nào của cây cuối cùng. Nút gốc được xử lý ở độ sâu 0
phần Mô hình tàu với tham số điều khiển rpart() công thức, df, phương pháp, điều khiển

Lưu ý: Huấn luyện mô hình trên dữ liệu huấn luyện và kiểm tra hiệu suất trên tập dữ liệu không nhìn thấy, tức là tập kiểm tra.