Cây quyết định trong R: Cây phân loại với ví dụ
Cây quyết định là gì?
Cây quyết định là thuật toán Machine Learning đa năng có thể thực hiện cả nhiệm vụ phân loại và hồi quy. Chúng là thuật toán rất mạnh mẽ, có khả năng phù hợp với các tập dữ liệu phức tạp. Bên cạnh đó, cây quyết định là thành phần cơ bản của rừng ngẫu nhiên, một trong những thuật toán Machine Learning mạnh mẽ nhất hiện nay.
Đào tạo và trực quan hóa cây quyết định trong R
Để xây dựng cây quyết định đầu tiên của bạn trong ví dụ R, chúng tôi sẽ tiến hành như sau trong hướng dẫn về Cây quyết định này:
- Bước 1: Nhập dữ liệu
- Bước 2: Làm sạch tập dữ liệu
- Bước 3: Tạo tập huấn luyện/kiểm tra
- Bước 4: Xây dựng mô hình
- Bước 5: Đưa ra dự đoán
- Bước 6: Đo lường hiệu suất
- Bước 7: Điều chỉnh các siêu tham số
Bước 1) Nhập dữ liệu
Nếu tò mò về số phận của tàu Titanic, bạn có thể xem video này trên Youtube. Mục đích của bộ dữ liệu này là dự đoán những người nào có nhiều khả năng sống sót hơn sau vụ va chạm với tảng băng trôi. Bộ dữ liệu chứa 13 biến và 1309 quan sát. Tập dữ liệu được sắp xếp theo biến X.
set.seed(678) path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv' titanic <-read.csv(path) head(titanic)
Đầu ra:
## X pclass survived name sex ## 1 1 1 1 Allen, Miss. Elisabeth Walton female ## 2 2 1 1 Allison, Master. Hudson Trevor male ## 3 3 1 0 Allison, Miss. Helen Loraine female ## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male ## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female ## 6 6 1 1 Anderson, Mr. Harry male ## age sibsp parch ticket fare cabin embarked ## 1 29.0000 0 0 24160 211.3375 B5 S ## 2 0.9167 1 2 113781 151.5500 C22 C26 S ## 3 2.0000 1 2 113781 151.5500 C22 C26 S ## 4 30.0000 1 2 113781 151.5500 C22 C26 S ## 5 25.0000 1 2 113781 151.5500 C22 C26 S ## 6 48.0000 0 0 19952 26.5500 E12 S ## home.dest ## 1 St Louis, MO ## 2 Montreal, PQ / Chesterville, ON ## 3 Montreal, PQ / Chesterville, ON ## 4 Montreal, PQ / Chesterville, ON ## 5 Montreal, PQ / Chesterville, ON ## 6 New York, NY
tail(titanic)
Đầu ra:
## X pclass survived name sex age sibsp ## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0 ## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1 ## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1 ## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0 ## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0 ## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0 ## parch ticket fare cabin embarked home.dest ## 1304 0 2627 14.4583 C ## 1305 0 2665 14.4542 C ## 1306 0 2665 14.4542 C ## 1307 0 2656 7.2250 C ## 1308 0 2670 7.2250 C ## 1309 0 315082 7.8750 S
Từ đầu ra đầu và đuôi, bạn có thể nhận thấy dữ liệu không bị xáo trộn. Đây là một vấn đề lớn! Khi bạn phân chia dữ liệu của mình giữa tập huấn luyện và tập kiểm tra, bạn sẽ chọn có thể hành khách hạng 1 và 2 (Không có hành khách hạng 3 nào nằm trong 80% quan sát hàng đầu), nghĩa là thuật toán sẽ không bao giờ nhìn thấy các đặc điểm của hành khách hạng 3. Sai lầm này sẽ dẫn đến dự đoán kém.
Để khắc phục vấn đề này, bạn có thể sử dụng hàm sample().
shuffle_index <- sample(1:nrow(titanic)) head(shuffle_index)
Cây quyết định Mã R Giải thích
- sample(1:nrow(titanic)): Tạo danh sách chỉ mục ngẫu nhiên từ 1 đến 1309 (tức là số hàng tối đa).
Đầu ra:
## [1] 288 874 1078 633 887 992
Bạn sẽ sử dụng chỉ mục này để xáo trộn tập dữ liệu titanic.
titanic <- titanic[shuffle_index, ] head(titanic)
Đầu ra:
## X pclass survived ## 288 288 1 0 ## 874 874 3 0 ## 1078 1078 3 1 ## 633 633 3 0 ## 887 887 3 1 ## 992 992 3 1 ## name sex age ## 288 Sutton, Mr. Frederick male 61 ## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42 ## 1078 O'Driscoll, Miss. Bridget female NA ## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39 ## 887 Jermyn, Miss. Annie female NA ## 992 Mamee, Mr. Hanna male NA ## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ ## 874 0 0 348121 7.6500 F G63 S ## 1078 0 0 14311 7.7500 Q ## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN ## 887 0 0 14313 7.7500 Q ## 992 0 0 2677 7.2292 C
Bước 2) Làm sạch tập dữ liệu
Cấu trúc dữ liệu cho thấy một số biến có NA. Việc dọn dẹp dữ liệu được thực hiện như sau
- Thả các biến home.dest,cabin,name,X và ticket
- Tạo các biến nhân tố cho pclass và sống sót
- Bỏ NA
library(dplyr) # Drop variables clean_titanic <- titanic % > % select(-c(home.dest, cabin, name, X, ticket)) % > % #Convert to factor level mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')), survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > % na.omit() glimpse(clean_titanic)
Giải thích mã
- select(-c(home.dest, cabin, name, X, ticket)): Bỏ các biến không cần thiết
- pclass =factor(pclass,levels = c(1,2,3), labels= c('Upper', 'Middle', 'Lower')): Thêm nhãn vào biến pclass. 1 trở thành Thượng, 2 trở thành Trung và 3 trở thành Thấp
- yếu tố(sống sót, mức độ = c(0,1), nhãn = c('Không', 'Có')): Thêm nhãn vào biến sống sót. 1 trở thành Không và 2 trở thành Có
- na.omit(): Loại bỏ các quan sát NA
Đầu ra:
## Observations: 1,045 ## Variables: 8 ## $ pclass <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U... ## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y... ## $ sex <fctr> male, male, female, female, male, male, female, male... ## $ age <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ... ## $ sibsp <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,... ## $ parch <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,... ## $ fare <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ... ## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...
Bước 3) Tạo tập huấn luyện/kiểm tra
Trước khi huấn luyện mô hình của mình, bạn cần thực hiện hai bước:
- Tạo tập huấn luyện và kiểm tra: Bạn huấn luyện mô hình trên tập huấn luyện và kiểm tra dự đoán trên tập kiểm tra (tức là dữ liệu chưa nhìn thấy)
- Cài đặt rpart.plot từ bảng điều khiển
Thực tiễn phổ biến là chia dữ liệu 80/20, 80% dữ liệu dùng để huấn luyện mô hình và 20% để đưa ra dự đoán. Bạn cần tạo hai khung dữ liệu riêng biệt. Bạn không muốn chạm vào tập kiểm tra cho đến khi hoàn thành việc xây dựng mô hình của mình. Bạn có thể tạo tên hàm create_train_test() có ba đối số.
create_train_test(df, size = 0.8, train = TRUE) arguments: -df: Dataset used to train the model. -size: Size of the split. By default, 0.8. Numerical value -train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample < - 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } }
Giải thích mã
- function(data, size=0.8, train = TRUE): Thêm các đối số trong hàm
- n_row = nrow(data): Đếm số hàng trong tập dữ liệu
- Total_row = size*n_row: Trả về hàng thứ n để xây dựng tập tàu
- train_sample <- 1:total_row: Chọn hàng đầu tiên đến hàng thứ n
- if (train ==TRUE){ } else { }: Nếu điều kiện được đặt thành đúng, trả về tập tàu, nếu không thì trả về tập kiểm tra.
Bạn có thể kiểm tra chức năng của mình và kiểm tra kích thước.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE) data_test <- create_train_test(clean_titanic, 0.8, train = FALSE) dim(data_train)
Đầu ra:
## [1] 836 8
dim(data_test)
Đầu ra:
## [1] 209 8
Tập dữ liệu tàu có 1046 hàng trong khi tập dữ liệu thử nghiệm có 262 hàng.
Bạn sử dụng hàm prop.table() kết hợp với table() để xác minh xem quá trình ngẫu nhiên hóa có đúng hay không.
prop.table(table(data_train$survived))
Đầu ra:
## ## No Yes ## 0.5944976 0.4055024
prop.table(table(data_test$survived))
Đầu ra:
## ## No Yes ## 0.5789474 0.4210526
Trong cả hai tập dữ liệu, số lượng người sống sót là như nhau, khoảng 40%.
Cài đặt rpart.plot
rpart.plot không có sẵn trong thư viện conda. Bạn có thể cài đặt nó từ bảng điều khiển:
install.packages("rpart.plot")
Bước 4) Xây dựng mô hình
Bạn đã sẵn sàng để xây dựng mô hình. Cú pháp của hàm cây quyết định Rpart là:
rpart(formula, data=, method='') arguments: - formula: The function to predict - data: Specifies the data frame- method: - "class" for a classification tree - "anova" for a regression tree
Bạn sử dụng phương thức lớp vì bạn dự đoán một lớp.
library(rpart) library(rpart.plot) fit <- rpart(survived~., data = data_train, method = 'class') rpart.plot(fit, extra = 106
Giải thích mã
- rpart(): Hàm phù hợp với mô hình. Các đối số là:
- sống sót ~.: Công thức của cây quyết định
- data = data_train: Tập dữ liệu
- Method = 'class': Phù hợp với mô hình nhị phân
- rpart.plot(fit, extra= 106): Vẽ sơ đồ cây. Các tính năng bổ sung được đặt thành 101 để hiển thị xác suất của loại thứ 2 (hữu ích cho các phản hồi nhị phân). Bạn có thể tham khảo các họa tiết để biết thêm thông tin về các lựa chọn khác.
Đầu ra:
Bạn bắt đầu tại nút gốc (độ sâu 0 trên 3, đỉnh biểu đồ):
- Ở trên cùng, đó là xác suất sống sót tổng thể. Nó cho thấy tỷ lệ hành khách sống sót sau vụ tai nạn. 41% hành khách sống sót.
- Nút này hỏi xem giới tính của hành khách có phải là nam hay không. Nếu có thì bạn đi xuống nút con bên trái của gốc (độ sâu 2). 63% là nam giới với khả năng sống sót là 21%.
- Ở nút thứ hai, bạn hỏi xem hành khách nam có trên 3.5 tuổi không. Nếu có thì cơ hội sống sót là 19%.
- Bạn cứ tiếp tục như vậy để hiểu những đặc điểm nào ảnh hưởng đến khả năng sống sót.
Lưu ý rằng, một trong nhiều ưu điểm của Cây Quyết định là chúng yêu cầu rất ít việc chuẩn bị dữ liệu. Đặc biệt, chúng không yêu cầu chia tỷ lệ hoặc căn giữa tính năng.
Theo mặc định, hàm rpart() sử dụng Gini thước đo tạp chất để phân chia nốt. Hệ số Gini càng cao thì càng có nhiều trường hợp khác nhau trong nút.
Bước 5) Đưa ra dự đoán
Bạn có thể dự đoán tập dữ liệu thử nghiệm của mình. Để đưa ra dự đoán, bạn có thể sử dụng hàm Predict(). Cú pháp cơ bản của dự đoán cho cây quyết định R là:
predict(fitted_model, df, type = 'class') arguments: - fitted_model: This is the object stored after model estimation. - df: Data frame used to make the prediction - type: Type of prediction - 'class': for classification - 'prob': to compute the probability of each class - 'vector': Predict the mean response at the node level
Bạn muốn dự đoán hành khách nào có nhiều khả năng sống sót hơn sau vụ va chạm từ bộ thử nghiệm. Nghĩa là, bạn sẽ biết trong số 209 hành khách đó, ai sẽ sống sót hay không.
predict_unseen <-predict(fit, data_test, type = 'class')
Giải thích mã
- dự đoán(fit, data_test, type='class'): Dự đoán lớp (0/1) của tập kiểm tra
Kiểm tra hành khách không đến được và những người đã đến được.
table_mat <- table(data_test$survived, predict_unseen) table_mat
Giải thích mã
- table(data_test$survived, Predict_unseen): Tạo bảng đếm xem có bao nhiêu hành khách được phân loại là người sống sót và đã qua đời so với phân loại cây quyết định đúng trong R
Đầu ra:
## predict_unseen ## No Yes ## No 106 15 ## Yes 30 58
Mô hình đã dự đoán chính xác 106 hành khách thiệt mạng nhưng lại phân loại 15 người sống sót là đã chết. Bằng cách tương tự, mô hình đã phân loại sai 30 hành khách là những người sống sót trong khi hóa ra họ đã chết.
Bước 6) Đo lường hiệu suất
Bạn có thể tính toán thước đo độ chính xác cho nhiệm vụ phân loại bằng ma trận hỗn loạn:
ma trận hỗn loạn là sự lựa chọn tốt hơn để đánh giá hiệu suất phân loại. Ý tưởng chung là đếm số lần các trường hợp Đúng được phân loại là Sai.
Mỗi hàng trong ma trận nhầm lẫn đại diện cho một mục tiêu thực tế, trong khi mỗi cột đại diện cho một mục tiêu được dự đoán. Hàng đầu tiên của ma trận này coi hành khách đã chết (loại Sai): 106 được phân loại chính xác là đã chết (Âm tính thật), trong khi người còn lại bị phân loại sai là người sống sót (Dương tính giả). Hàng thứ hai xem xét những người sống sót, lớp tích cực là 58 (Đúng tích cực), trong khi Âm tính thật là 30.
Bạn có thể tính toán kiểm tra độ chính xác từ ma trận nhầm lẫn:
Đó là tỷ lệ giữa dương thực và âm thực trên tổng của ma trận. Với R, bạn có thể viết mã như sau:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Giải thích mã
- sum(diag(table_mat)): Tổng của đường chéo
- sum(table_mat): Tổng của ma trận.
Bạn có thể in độ chính xác của bộ kiểm tra:
print(paste('Accuracy for test', accuracy_Test))
Đầu ra:
## [1] "Accuracy for test 0.784688995215311"
Bạn có số điểm 78 phần trăm cho bộ bài kiểm tra. Bạn có thể sao chép bài tập tương tự với tập dữ liệu huấn luyện.
Bước 7) Điều chỉnh các siêu tham số
Cây quyết định trong R có nhiều tham số khác nhau để kiểm soát các khía cạnh của sự phù hợp. Trong thư viện cây quyết định rpart, bạn có thể kiểm soát các tham số bằng hàm rpart.control(). Trong đoạn mã sau, bạn giới thiệu các tham số bạn sẽ điều chỉnh. Bạn có thể tham khảo họa tiết cho các thông số khác.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30) Arguments: -minsplit: Set the minimum number of observations in the node before the algorithm perform a split -minbucket: Set the minimum number of observations in the final note i.e. the leaf -maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
Chúng ta sẽ tiến hành như sau:
- Xây dựng hàm để trả về độ chính xác
- Điều chỉnh độ sâu tối đa
- Điều chỉnh số lượng mẫu tối thiểu mà một nút phải có trước khi có thể phân chia
- Điều chỉnh số lượng mẫu tối thiểu mà một nút lá phải có
Bạn có thể viết một hàm để hiển thị độ chính xác. Bạn chỉ cần bọc mã bạn đã sử dụng trước đó:
- dự đoán: dự đoán_unseen <- dự đoán (phù hợp, data_test, type = 'class')
- Tạo bảng: table_mat <- table(data_test$survived, Predict_unseen)
- Tính toán độ chính xác: độ chính xác_Test <- sum(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) { predict_unseen <- predict(fit, data_test, type = 'class') table_mat <- table(data_test$survived, predict_unseen) accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test }
Bạn có thể thử điều chỉnh các tham số và xem liệu bạn có thể cải thiện mô hình so với giá trị mặc định hay không. Xin nhắc lại, bạn cần có độ chính xác cao hơn 0.78
control <- rpart.control(minsplit = 4, minbucket = round(5 / 3), maxdepth = 3, cp = 0) tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control) accuracy_tune(tune_fit)
Đầu ra:
## [1] 0.7990431
Với tham số sau:
minsplit = 4 minbucket= round(5/3) maxdepth = 3cp=0
Bạn sẽ có được hiệu suất cao hơn so với mô hình trước đó. Xin chúc mừng!
Tổng kết
Chúng ta có thể tóm tắt các hàm để huấn luyện thuật toán cây quyết định trong R
Thư viện | Mục tiêu | Chức năng | Lớp | Thông số | Chi Tiết |
---|---|---|---|---|---|
phần | Cây phân loại tàu trong R | rpart() | tốt nghiệp lớp XNUMX | công thức, df, phương pháp | |
phần | Đào tạo cây hồi quy | rpart() | anova | công thức, df, phương pháp | |
phần | Vẽ cây | rpart.plot() | mô hình vừa vặn | ||
cơ sở | dự đoán | dự đoán () | tốt nghiệp lớp XNUMX | mô hình, loại trang bị | |
cơ sở | dự đoán | dự đoán () | thăm dò | mô hình, loại trang bị | |
cơ sở | dự đoán | dự đoán () | vector | mô hình, loại trang bị | |
phần | thông số điều khiển | rpart.control() | chia nhỏ | Đặt số lượng quan sát tối thiểu trong nút trước khi thuật toán thực hiện phân tách | |
thùng nhỏ | Đặt số lượng quan sát tối thiểu trong nốt cuối cùng tức là lá | ||||
độ sâu tối đa | Đặt độ sâu tối đa của bất kỳ nút nào của cây cuối cùng. Nút gốc được xử lý ở độ sâu 0 | ||||
phần | Mô hình tàu với tham số điều khiển | rpart() | công thức, df, phương pháp, điều khiển |
Lưu ý: Huấn luyện mô hình trên dữ liệu huấn luyện và kiểm tra hiệu suất trên tập dữ liệu không nhìn thấy, tức là tập kiểm tra.