R 中的决策树:分类树示例

什么是决策树?

决策树 是多功能的机器学习算法,可以执行分类和回归任务。它们是非常强大的算法,能够拟合复杂的数据集。此外,决策树是随机森林的基本组成部分,是当今最强大的机器学习算法之一。

在 R 中训练和可视化决策树

为了在 R 示例中构建您的第一棵决策树,我们将按照本决策树教程中的说明进行操作:

  • 步骤 1:导入数据
  • 第 2 步:清理数据集
  • 步骤 3:创建训练/测试集
  • 步骤 4:建立模型
  • 步骤 5:做出预测
  • 第 6 步:衡量绩效
  • 步骤 7:调整超参数

步骤1)导入数据

如果你对泰坦尼克号的命运感到好奇,你可以观看此视频 Youtube。该数据集的目的是预测哪些人与冰山相撞后更有可能幸存。数据集包含 13 个变量和 1309 个观测值。数据集按变量 X 排序。

set.seed(678)
path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'
titanic <-read.csv(path)
head(titanic)

输出:

##   X pclass survived                                            name    sex
## 1 1      1        1                   Allen, Miss. Elisabeth Walton female
## 2 2      1        1                  Allison, Master. Hudson Trevor   male
## 3 3      1        0                    Allison, Miss. Helen Loraine female
## 4 4      1        0            Allison, Mr. Hudson Joshua Creighton   male
## 5 5      1        0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female
## 6 6      1        1                             Anderson, Mr. Harry   male
##       age sibsp parch ticket     fare   cabin embarked
## 1 29.0000     0     0  24160 211.3375      B5        S
## 2  0.9167     1     2 113781 151.5500 C22 C26        S
## 3  2.0000     1     2 113781 151.5500 C22 C26        S
## 4 30.0000     1     2 113781 151.5500 C22 C26        S
## 5 25.0000     1     2 113781 151.5500 C22 C26        S
## 6 48.0000     0     0  19952  26.5500     E12        S
##                         home.dest
## 1                    St Louis, MO
## 2 Montreal, PQ / Chesterville, ON
## 3 Montreal, PQ / Chesterville, ON
## 4 Montreal, PQ / Chesterville, ON
## 5 Montreal, PQ / Chesterville, ON
## 6                    New York, NY
tail(titanic)

输出:

##         X pclass survived                      name    sex  age sibsp
## 1304 1304      3        0     Yousseff, Mr. Gerious   male   NA     0
## 1305 1305      3        0      Zabour, Miss. Hileni female 14.5     1
## 1306 1306      3        0     Zabour, Miss. Thamine female   NA     1
## 1307 1307      3        0 Zakarian, Mr. Mapriededer   male 26.5     0
## 1308 1308      3        0       Zakarian, Mr. Ortin   male 27.0     0
## 1309 1309      3        0        Zimmerman, Mr. Leo   male 29.0     0
##      parch ticket    fare cabin embarked home.dest
## 1304     0   2627 14.4583              C          
## 1305     0   2665 14.4542              C          
## 1306     0   2665 14.4542              C          
## 1307     0   2656  7.2250              C          
## 1308     0   2670  7.2250              C          
## 1309     0 315082  7.8750              S

从头部和尾部输出中,您可以注意到数据没有被打乱。这是一个大问题!当您将数据拆分为训练集和测试集时,您将选择 仅由 1 类和 2 类乘客(3 类乘客中没有一个进入观测值前 80%),这意味着算法永远不会看到 3 类乘客的特征。这个错误会导致预测不佳。

为了解决这个问题,您可以使用函数sample()。

shuffle_index <- sample(1:nrow(titanic))
head(shuffle_index)

决策树 R 代码解释

  • sample(1:nrow(titanic)):生成从1到1309(即最大行数)的索引随机列表。

输出:

## [1]  288  874 1078  633  887  992

您将使用此索引来重新排列泰坦尼克号数据集。

titanic <- titanic[shuffle_index, ]
head(titanic)

输出:

##         X pclass survived
## 288   288      1        0
## 874   874      3        0
## 1078 1078      3        1
## 633   633      3        0
## 887   887      3        1
## 992   992      3        1
##                                                           name    sex age
## 288                                      Sutton, Mr. Frederick   male  61
## 874                   Humblen, Mr. Adolf Mathias Nicolai Olsen   male  42
## 1078                                 O'Driscoll, Miss. Bridget female  NA
## 633  Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female  39
## 887                                        Jermyn, Miss. Annie female  NA
## 992                                           Mamee, Mr. Hanna   male  NA
##      sibsp parch ticket    fare cabin embarked           home.dest## 288      0     0  36963 32.3208   D50        S     Haddenfield, NJ
## 874      0     0 348121  7.6500 F G63        S                    
## 1078     0     0  14311  7.7500              Q                    
## 633      1     5 347082 31.2750              S Sweden Winnipeg, MN
## 887      0     0  14313  7.7500              Q                    
## 992      0     0   2677  7.2292              C	

步骤2)清理数据集

数据结构显示一些变量有 NA。数据清理工作如下

  • 删除变量 home.dest、cabin、name、X 和 ticket
  • 为 pclass 和 survivors 创建因子变量
  • 放弃 NA
library(dplyr)
# Drop variables
clean_titanic <- titanic % > %
select(-c(home.dest, cabin, name, X, ticket)) % > % 
#Convert to factor level
	mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),
	survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %
na.omit()
glimpse(clean_titanic)

代码说明

  • select(-c(home.dest, cabin, name, X, ticket)):删除不必要的变量
  • pclass = factor(pclass, levels = c(1,2,3), labels= c('Upper', 'Middle', 'Lower')): 为变量 pclass 添加标签。1 变为 Upper,2 变为 MIddle,3 变为 lower
  • factor(survived, levels = c(0,1), labels = c('No', 'Yes')):给变量survived添加标签。1变为No,2变为Yes
  • na.omit():删除 NA 观测值

输出:

## Observations: 1,045
## Variables: 8
## $ pclass   <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U...
## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y...
## $ sex      <fctr> male, male, female, female, male, male, female, male...
## $ age      <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ...
## $ sibsp    <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,...
## $ parch    <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,...
## $ fare     <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ...
## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...		

步骤3)创建训练/测试集

在训练模型之前,您需要执行两个步骤:

  • 创建训练和测试集:在训练集上训练模型,并在测试集(即未见数据)上测试预测
  • 从控制台安装 rpart.plot

常见的做法是将数据分成 80/20,80% 的数据用于训练模型,20% 的数据用于进行预测。您需要创建两个单独的数据框。在完成模型构建之前,您不想接触测试集。您可以创建一个名为 create_train_test() 的函数,它接受三个参数。

create_train_test(df, size = 0.8, train = TRUE)
arguments:
-df: Dataset used to train the model.
-size: Size of the split. By default, 0.8. Numerical value
-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {
    n_row = nrow(data)
    total_row = size * n_row
    train_sample < - 1: total_row
    if (train == TRUE) {
        return (data[train_sample, ])
    } else {
        return (data[-train_sample, ])
    }
}

代码说明

  • function(data, size=0.8, train = TRUE): 在函数中添加参数
  • n_row = nrow(data): 计算数据集中的行数
  • total_row = size*n_row:返回第 n 行以构建训练集
  • train_sample <- 1:total_row: 选择第一行到第 n 行
  • if (train ==TRUE){ } else { }:如果条件设置为真,则返回训练集,否则返回测试集。

您可以测试您的功能并检查尺寸。

data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)
data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)
dim(data_train)

输出:

## [1] 836   8
dim(data_test)

输出:

## [1] 209   8

训练数据集有 1046 行,而测试数据集有 262 行。

您可以使用函数 prop.table() 结合 table() 来验证随机化过程是否正确。

prop.table(table(data_train$survived))

输出:

##
##        No       Yes 
## 0.5944976 0.4055024
prop.table(table(data_test$survived))

输出:

## 
##        No       Yes 
## 0.5789474 0.4210526

在两个数据集中,幸存者的数量相同,约为 40%。

安装 rpart.plot

conda 库中没有 rpart.plot。您可以从控制台安装它:

install.packages("rpart.plot")

步骤4)建立模型

您已准备好构建模型。Rpart 决策树函数的语法是:

rpart(formula, data=, method='')
arguments:			
- formula: The function to predict
- data: Specifies the data frame- method: 			
- "class" for a classification tree 			
- "anova" for a regression tree	

您使用类方法是因为您要预测一个类。

library(rpart)
library(rpart.plot)
fit <- rpart(survived~., data = data_train, method = 'class')
rpart.plot(fit, extra = 106

代码说明

  • rpart():用于拟合模型的函数。参数为:
    • 幸存下来〜。:决策树的公式
    • 数据=data_train:数据集
    • method = ‘class’:拟合二元模型
  • rpart.plot(fit, extra= 106):绘制树。额外特征设置为 101,以显示第二类的概率(适用于二元响应)。您可以参考 小插图 了解有关其他选择的更多信息。

输出:

在 R 中构建决策树模型

从根节点(深度 0 除以 3,图的顶部)开始:

  1. 最上方是总体生存概率。它显示了在事故中幸存的乘客比例。41% 的乘客幸存。
  2. 此节点询问乘客的性别是否为男性。如果是,则向下转到根的左子节点(深度 2)。63% 为男性,存活概率为 21%。
  3. 在第二个节点中,您询问男性乘客是否年满 3.5 岁。如果是,则生存几率为 19%。
  4. 你继续这样做,就能了解哪些特征会影响生存的可能性。

请注意,决策树的众多特性之一是它们几乎不需要数据准备。特别是,它们不需要特征缩放或居中。

默认情况下,rpart() 函数使用 基尼 不纯度度量来分割节点。基尼系数越高,节点内的不同实例越多。

步骤 5)做出预测

您可以预测测试数据集。要进行预测,可以使用 predict() 函数。R 决策树预测的基本语法是:

predict(fitted_model, df, type = 'class')
arguments:
- fitted_model: This is the object stored after model estimation. 
- df: Data frame used to make the prediction
- type: Type of prediction			
    - 'class': for classification			
    - 'prob': to compute the probability of each class			
    - 'vector': Predict the mean response at the node level	

你想从测试集中预测哪些乘客在碰撞后更有可能幸存。这意味着,你将知道在这 209 名乘客中,哪些人会幸存,哪些人会死。

predict_unseen <-predict(fit, data_test, type = 'class')

代码说明

  • predict(fit, data_test, type = 'class'): 预测测试集的类(0/1)

对未能成功的乘客和成功完成任务的乘客进行测试。

table_mat <- table(data_test$survived, predict_unseen)
table_mat

代码说明

  • table(data_test$survived, predict_unseen):创建一个表来计算有多少乘客被归类为幸存者并去世,并与 R 中正确的决策树分类进行比较

输出:

##      predict_unseen
##        No Yes
##   No  106  15
##   Yes  30  58

该模型正确预测了 106 名乘客死亡,但将 15 名幸存者归类为死亡。类似地,该模型错误地将 30 名乘客归类为幸存者,而他们最终都被认定为死亡。

步骤 6)衡量绩效

您可以使用以下方法计算分类任务的准确度度量 混淆矩阵:

这个 混淆矩阵 是评估分类性能的更好选择。一般的想法是计算 True 实例被分类为 False 的次数。

测量 R 中决策树的性能

混淆矩阵中的每一行代表一个实际目标,而每一列代表一个预测目标。该矩阵的第一行考虑了死亡乘客(错误类别):106 人被正确归类为死亡(真阴性),而剩下的一个被错误地归类为幸存者(假阳性)。第二行考虑幸存者,阳性类别为 58(真阳性),而 真阴性 是30。

您可以计算 准确度测试 来自混淆矩阵:

测量 R 中决策树的性能

它是真阳性和真阴性占矩阵总和的比例。使用 R,您可以按如下方式编写代码:

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)

代码说明

  • sum(diag(table_mat)):对角线的总和
  • sum(table_mat):矩阵的总和。

您可以打印测试集的准确率:

print(paste('Accuracy for test', accuracy_Test))

输出:

## [1] "Accuracy for test 0.784688995215311"

您在测试集中的得分为 78%。您可以使用训练数据集重复相同的练习。

步骤 7)调整超参数

R 中的决策树有各种参数来控制拟合的各个方面。在 rpart 决策树库中,您可以使用 rpart.control() 函数控制参数。在下面的代码中,您将介绍要调整的参数。您可以参考 小插图 其他参数。

rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)
Arguments:
-minsplit: Set the minimum number of observations in the node before the algorithm perform a split
-minbucket:  Set the minimum number of observations in the final note i.e. the leaf
-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0

我们将按照以下步骤进行:

  • 构造函数来返回准确率
  • 调整最大深度
  • 调整节点分裂前必须拥有的最小样本数量
  • 调整叶节点必须具有的最小样本数量

你可以编写一个函数来显示准确率。你只需包装一下之前使用的代码:

  1. 预测:predict_unseen <- 预测(fit,data_test,type ='class')
  2. 生成表:table_mat <- table(data_test$survived, predict_unseen)
  3. 计算准确度:accuracy_Test <- sum(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) {
    predict_unseen <- predict(fit, data_test, type = 'class')
    table_mat <- table(data_test$survived, predict_unseen)
    accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
    accuracy_Test
}

您可以尝试调整参数,看看是否可以将模型改进到默认值以上。提醒一下,您需要获得高于 0.78 的准确率

control <- rpart.control(minsplit = 4,
    minbucket = round(5 / 3),
    maxdepth = 3,
    cp = 0)
tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)
accuracy_tune(tune_fit)

输出:

## [1] 0.7990431

使用以下参数:

minsplit = 4
minbucket= round(5/3)
maxdepth = 3cp=0

您获得的性能比之前的型号更高。恭喜!

结语

我们可以总结训练决策树算法的函数 R

自学资料库 目的 功能 增益级 参数 信息
部分 在 R 中训练分类树 rpart() 公式、df、方法
部分 训练回归树 rpart() 方差分析 公式、df、方法
部分 绘制树木 rpart.plot() 拟合模型
基地 预测 预测() 拟合模型,类型
基地 预测 预测() 概率 拟合模型,类型
基地 预测 预测() 向量 拟合模型,类型
部分 控制参数 rpart.控制() 最小分割 在算法进行拆分之前,设置节点中的最小观察数
最小桶 设置最终注释中的最小观察次数,即叶子
最大深度 设置最终树中任意节点的最大深度。根节点的深度为 0
部分 使用控制参数训练模型 rpart() 公式、df、方法、控制

注意:在训练数据上训练模型,并在看不见的数据集(即测试集)上测试性能。