R 中的决策树:分类树示例
什么是决策树?
决策树 是多功能的机器学习算法,可以执行分类和回归任务。它们是非常强大的算法,能够拟合复杂的数据集。此外,决策树是随机森林的基本组成部分,是当今最强大的机器学习算法之一。
在 R 中训练和可视化决策树
为了在 R 示例中构建您的第一棵决策树,我们将按照本决策树教程中的说明进行操作:
- 步骤 1:导入数据
- 第 2 步:清理数据集
- 步骤 3:创建训练/测试集
- 步骤 4:建立模型
- 步骤 5:做出预测
- 第 6 步:衡量绩效
- 步骤 7:调整超参数
步骤1)导入数据
如果你对泰坦尼克号的命运感到好奇,你可以观看此视频 Youtube。该数据集的目的是预测哪些人与冰山相撞后更有可能幸存。数据集包含 13 个变量和 1309 个观测值。数据集按变量 X 排序。
set.seed(678) path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv' titanic <-read.csv(path) head(titanic)
输出:
## X pclass survived name sex ## 1 1 1 1 Allen, Miss. Elisabeth Walton female ## 2 2 1 1 Allison, Master. Hudson Trevor male ## 3 3 1 0 Allison, Miss. Helen Loraine female ## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male ## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female ## 6 6 1 1 Anderson, Mr. Harry male ## age sibsp parch ticket fare cabin embarked ## 1 29.0000 0 0 24160 211.3375 B5 S ## 2 0.9167 1 2 113781 151.5500 C22 C26 S ## 3 2.0000 1 2 113781 151.5500 C22 C26 S ## 4 30.0000 1 2 113781 151.5500 C22 C26 S ## 5 25.0000 1 2 113781 151.5500 C22 C26 S ## 6 48.0000 0 0 19952 26.5500 E12 S ## home.dest ## 1 St Louis, MO ## 2 Montreal, PQ / Chesterville, ON ## 3 Montreal, PQ / Chesterville, ON ## 4 Montreal, PQ / Chesterville, ON ## 5 Montreal, PQ / Chesterville, ON ## 6 New York, NY
tail(titanic)
输出:
## X pclass survived name sex age sibsp ## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0 ## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1 ## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1 ## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0 ## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0 ## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0 ## parch ticket fare cabin embarked home.dest ## 1304 0 2627 14.4583 C ## 1305 0 2665 14.4542 C ## 1306 0 2665 14.4542 C ## 1307 0 2656 7.2250 C ## 1308 0 2670 7.2250 C ## 1309 0 315082 7.8750 S
从头部和尾部输出中,您可以注意到数据没有被打乱。这是一个大问题!当您将数据拆分为训练集和测试集时,您将选择 仅由 1 类和 2 类乘客(3 类乘客中没有一个进入观测值前 80%),这意味着算法永远不会看到 3 类乘客的特征。这个错误会导致预测不佳。
为了解决这个问题,您可以使用函数sample()。
shuffle_index <- sample(1:nrow(titanic)) head(shuffle_index)
决策树 R 代码解释
- sample(1:nrow(titanic)):生成从1到1309(即最大行数)的索引随机列表。
输出:
## [1] 288 874 1078 633 887 992
您将使用此索引来重新排列泰坦尼克号数据集。
titanic <- titanic[shuffle_index, ] head(titanic)
输出:
## X pclass survived ## 288 288 1 0 ## 874 874 3 0 ## 1078 1078 3 1 ## 633 633 3 0 ## 887 887 3 1 ## 992 992 3 1 ## name sex age ## 288 Sutton, Mr. Frederick male 61 ## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42 ## 1078 O'Driscoll, Miss. Bridget female NA ## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39 ## 887 Jermyn, Miss. Annie female NA ## 992 Mamee, Mr. Hanna male NA ## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ ## 874 0 0 348121 7.6500 F G63 S ## 1078 0 0 14311 7.7500 Q ## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN ## 887 0 0 14313 7.7500 Q ## 992 0 0 2677 7.2292 C
步骤2)清理数据集
数据结构显示一些变量有 NA。数据清理工作如下
- 删除变量 home.dest、cabin、name、X 和 ticket
- 为 pclass 和 survivors 创建因子变量
- 放弃 NA
library(dplyr) # Drop variables clean_titanic <- titanic % > % select(-c(home.dest, cabin, name, X, ticket)) % > % #Convert to factor level mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')), survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > % na.omit() glimpse(clean_titanic)
代码说明
- select(-c(home.dest, cabin, name, X, ticket)):删除不必要的变量
- pclass = factor(pclass, levels = c(1,2,3), labels= c('Upper', 'Middle', 'Lower')): 为变量 pclass 添加标签。1 变为 Upper,2 变为 MIddle,3 变为 lower
- factor(survived, levels = c(0,1), labels = c('No', 'Yes')):给变量survived添加标签。1变为No,2变为Yes
- na.omit():删除 NA 观测值
输出:
## Observations: 1,045 ## Variables: 8 ## $ pclass <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U... ## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y... ## $ sex <fctr> male, male, female, female, male, male, female, male... ## $ age <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ... ## $ sibsp <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,... ## $ parch <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,... ## $ fare <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ... ## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...
步骤3)创建训练/测试集
在训练模型之前,您需要执行两个步骤:
- 创建训练和测试集:在训练集上训练模型,并在测试集(即未见数据)上测试预测
- 从控制台安装 rpart.plot
常见的做法是将数据分成 80/20,80% 的数据用于训练模型,20% 的数据用于进行预测。您需要创建两个单独的数据框。在完成模型构建之前,您不想接触测试集。您可以创建一个名为 create_train_test() 的函数,它接受三个参数。
create_train_test(df, size = 0.8, train = TRUE) arguments: -df: Dataset used to train the model. -size: Size of the split. By default, 0.8. Numerical value -train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample < - 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } }
代码说明
- function(data, size=0.8, train = TRUE): 在函数中添加参数
- n_row = nrow(data): 计算数据集中的行数
- total_row = size*n_row:返回第 n 行以构建训练集
- train_sample <- 1:total_row: 选择第一行到第 n 行
- if (train ==TRUE){ } else { }:如果条件设置为真,则返回训练集,否则返回测试集。
您可以测试您的功能并检查尺寸。
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE) data_test <- create_train_test(clean_titanic, 0.8, train = FALSE) dim(data_train)
输出:
## [1] 836 8
dim(data_test)
输出:
## [1] 209 8
训练数据集有 1046 行,而测试数据集有 262 行。
您可以使用函数 prop.table() 结合 table() 来验证随机化过程是否正确。
prop.table(table(data_train$survived))
输出:
## ## No Yes ## 0.5944976 0.4055024
prop.table(table(data_test$survived))
输出:
## ## No Yes ## 0.5789474 0.4210526
在两个数据集中,幸存者的数量相同,约为 40%。
安装 rpart.plot
conda 库中没有 rpart.plot。您可以从控制台安装它:
install.packages("rpart.plot")
步骤4)建立模型
您已准备好构建模型。Rpart 决策树函数的语法是:
rpart(formula, data=, method='') arguments: - formula: The function to predict - data: Specifies the data frame- method: - "class" for a classification tree - "anova" for a regression tree
您使用类方法是因为您要预测一个类。
library(rpart) library(rpart.plot) fit <- rpart(survived~., data = data_train, method = 'class') rpart.plot(fit, extra = 106
代码说明
- rpart():用于拟合模型的函数。参数为:
- 幸存下来〜。:决策树的公式
- 数据=data_train:数据集
- method = ‘class’:拟合二元模型
- rpart.plot(fit, extra= 106):绘制树。额外特征设置为 101,以显示第二类的概率(适用于二元响应)。您可以参考 小插图 了解有关其他选择的更多信息。
输出:
从根节点(深度 0 除以 3,图的顶部)开始:
- 最上方是总体生存概率。它显示了在事故中幸存的乘客比例。41% 的乘客幸存。
- 此节点询问乘客的性别是否为男性。如果是,则向下转到根的左子节点(深度 2)。63% 为男性,存活概率为 21%。
- 在第二个节点中,您询问男性乘客是否年满 3.5 岁。如果是,则生存几率为 19%。
- 你继续这样做,就能了解哪些特征会影响生存的可能性。
请注意,决策树的众多特性之一是它们几乎不需要数据准备。特别是,它们不需要特征缩放或居中。
默认情况下,rpart() 函数使用 基尼 不纯度度量来分割节点。基尼系数越高,节点内的不同实例越多。
步骤 5)做出预测
您可以预测测试数据集。要进行预测,可以使用 predict() 函数。R 决策树预测的基本语法是:
predict(fitted_model, df, type = 'class') arguments: - fitted_model: This is the object stored after model estimation. - df: Data frame used to make the prediction - type: Type of prediction - 'class': for classification - 'prob': to compute the probability of each class - 'vector': Predict the mean response at the node level
你想从测试集中预测哪些乘客在碰撞后更有可能幸存。这意味着,你将知道在这 209 名乘客中,哪些人会幸存,哪些人会死。
predict_unseen <-predict(fit, data_test, type = 'class')
代码说明
- predict(fit, data_test, type = 'class'): 预测测试集的类(0/1)
对未能成功的乘客和成功完成任务的乘客进行测试。
table_mat <- table(data_test$survived, predict_unseen) table_mat
代码说明
- table(data_test$survived, predict_unseen):创建一个表来计算有多少乘客被归类为幸存者并去世,并与 R 中正确的决策树分类进行比较
输出:
## predict_unseen ## No Yes ## No 106 15 ## Yes 30 58
该模型正确预测了 106 名乘客死亡,但将 15 名幸存者归类为死亡。类似地,该模型错误地将 30 名乘客归类为幸存者,而他们最终都被认定为死亡。
步骤 6)衡量绩效
您可以使用以下方法计算分类任务的准确度度量 混淆矩阵:
这个 混淆矩阵 是评估分类性能的更好选择。一般的想法是计算 True 实例被分类为 False 的次数。
混淆矩阵中的每一行代表一个实际目标,而每一列代表一个预测目标。该矩阵的第一行考虑了死亡乘客(错误类别):106 人被正确归类为死亡(真阴性),而剩下的一个被错误地归类为幸存者(假阳性)。第二行考虑幸存者,阳性类别为 58(真阳性),而 真阴性 是30。
您可以计算 准确度测试 来自混淆矩阵:
它是真阳性和真阴性占矩阵总和的比例。使用 R,您可以按如下方式编写代码:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
代码说明
- sum(diag(table_mat)):对角线的总和
- sum(table_mat):矩阵的总和。
您可以打印测试集的准确率:
print(paste('Accuracy for test', accuracy_Test))
输出:
## [1] "Accuracy for test 0.784688995215311"
您在测试集中的得分为 78%。您可以使用训练数据集重复相同的练习。
步骤 7)调整超参数
R 中的决策树有各种参数来控制拟合的各个方面。在 rpart 决策树库中,您可以使用 rpart.control() 函数控制参数。在下面的代码中,您将介绍要调整的参数。您可以参考 小插图 其他参数。
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30) Arguments: -minsplit: Set the minimum number of observations in the node before the algorithm perform a split -minbucket: Set the minimum number of observations in the final note i.e. the leaf -maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
我们将按照以下步骤进行:
- 构造函数来返回准确率
- 调整最大深度
- 调整节点分裂前必须拥有的最小样本数量
- 调整叶节点必须具有的最小样本数量
你可以编写一个函数来显示准确率。你只需包装一下之前使用的代码:
- 预测:predict_unseen <- 预测(fit,data_test,type ='class')
- 生成表:table_mat <- table(data_test$survived, predict_unseen)
- 计算准确度:accuracy_Test <- sum(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) { predict_unseen <- predict(fit, data_test, type = 'class') table_mat <- table(data_test$survived, predict_unseen) accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test }
您可以尝试调整参数,看看是否可以将模型改进到默认值以上。提醒一下,您需要获得高于 0.78 的准确率
control <- rpart.control(minsplit = 4, minbucket = round(5 / 3), maxdepth = 3, cp = 0) tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control) accuracy_tune(tune_fit)
输出:
## [1] 0.7990431
使用以下参数:
minsplit = 4 minbucket= round(5/3) maxdepth = 3cp=0
您获得的性能比之前的型号更高。恭喜!
结语
我们可以总结训练决策树算法的函数 R
自学资料库 | 目的 | 功能 | 增益级 | 参数 | 信息 |
---|---|---|---|---|---|
部分 | 在 R 中训练分类树 | rpart() | 程 | 公式、df、方法 | |
部分 | 训练回归树 | rpart() | 方差分析 | 公式、df、方法 | |
部分 | 绘制树木 | rpart.plot() | 拟合模型 | ||
基地 | 预测 | 预测() | 程 | 拟合模型,类型 | |
基地 | 预测 | 预测() | 概率 | 拟合模型,类型 | |
基地 | 预测 | 预测() | 向量 | 拟合模型,类型 | |
部分 | 控制参数 | rpart.控制() | 最小分割 | 在算法进行拆分之前,设置节点中的最小观察数 | |
最小桶 | 设置最终注释中的最小观察次数,即叶子 | ||||
最大深度 | 设置最终树中任意节点的最大深度。根节点的深度为 0 | ||||
部分 | 使用控制参数训练模型 | rpart() | 公式、df、方法、控制 |
注意:在训练数据上训练模型,并在看不见的数据集(即测试集)上测试性能。