R 中的 GLM:广义线性模型及示例

什么是逻辑回归?

逻辑回归用于预测类别,即概率。逻辑回归可以准确预测二元结果。

假设您想根据许多属性预测一笔贷款是否被拒绝/接受。逻辑回归的形式为 0/1。如果贷款被拒绝,则 y = 0;如果贷款被接受,则 y = 1。

逻辑回归模型与线性回归模型有两点不同。

  • 首先,逻辑回归只接受二分(二进制)输入作为因变量(即 0 和 1 的向量)。
  • 其次,结果通过以下概率链接函数来衡量,称为 乙状结肠 由于其呈S形。:

Logistic回归

该函数的输出始终在 0 和 1 之间。查看下图

Logistic回归

S 型函数返回从 0 到 1 的值。对于分类任务,我们需要 0 或 1 的离散输出。

为了将连续流转换为离散值,我们可以将决策边界设置为 0.5。高于此阈值的所有值均归类为 1

Logistic回归

如何创建广义线性模型 (GLM)

让我们使用 成年人 数据集用于说明 Logistic 回归。“成年人”是分类任务的绝佳数据集。目标是预测个人的年收入是否会超过 50.000 美元。数据集包含 46,033 个观测值和十个特征:

  • age:个人的年龄。数字
  • 教育:个人的教育水平。因素。
  • 婚姻状况: Mari个人的状况。因素包括:未婚、已婚配偶……
  • 性别:个人的性别。因素,即男性或女性
  • 收入: Target 变量。收入高于或低于 50K。因素即 >50K,<=50K

在其他人中

library(dplyr)
data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")
glimpse(data_adult)

输出:

Observations: 48,842
Variables: 10
$ x               <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,...
$ age             <int> 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26...
$ workclass       <fctr> Private, Private, Local-gov, Private, ?, Private,...
$ education       <fctr> 11th, HS-grad, Assoc-acdm, Some-college, Some-col...
$ educational.num <int> 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,...
$ marital.status  <fctr> Never-married, Married-civ-spouse, Married-civ-sp...
$ race            <fctr> Black, White, White, Black, White, White, Black, ...
$ gender          <fctr> Male, Male, Male, Male, Female, Male, Male, Male,...
$ hours.per.week  <int> 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39...
$ income          <fctr> <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5...

我们将按照以下步骤进行:

  • 步骤 1:检查连续变量
  • 第 2 步:检查因子变量
  • 步骤 3:特征工程
  • 步骤 4:汇总统计
  • 步骤5:训练/测试集
  • 步骤 6:建立模型
  • 步骤 7:评估模型的性能
  • 步骤 8:改进模型

您的任务是预测哪个人的收入将高于 50K。

在本教程中,将详细介绍对真实数据集进行分析的每个步骤。

步骤1)检查连续变量

在第一步中,您可以看到连续变量的分布。

continuous <-select_if(data_adult, is.numeric)
summary(continuous)

代码说明

  • 连续<- select_if(data_adult,is.numeric):使用 dplyr 库中的函数 select_if()仅选择数字列
  • summary(continuous): 打印摘要统计信息

输出:

##        X              age        educational.num hours.per.week 
##  Min.   :    1   Min.   :17.00   Min.   : 1.00   Min.   : 1.00  
##  1st Qu.:11509   1st Qu.:28.00   1st Qu.: 9.00   1st Qu.:40.00  
##  Median :23017   Median :37.00   Median :10.00   Median :40.00  
##  Mean   :23017   Mean   :38.56   Mean   :10.13   Mean   :40.95  
##  3rd Qu.:34525   3rd Qu.:47.00   3rd Qu.:13.00   3rd Qu.:45.00  
##  Max.   :46033   Max.   :90.00   Max.   :16.00   Max.   :99.00	

从上表中,您可以看到数据具有完全不同的尺度,并且每周小时数具有较大的异常值(即查看最后一个四分位数和最大值)。

你可以按照两个步骤来处理:

  • 1:绘制每周工作时间分布图
  • 2:标准化连续变量
  1. 绘制分布

让我们仔细看看每周工作时间的分布

# Histogram with kernel density curve
library(ggplot2)
ggplot(continuous, aes(x = hours.per.week)) +
    geom_density(alpha = .2, fill = "#FF6666")

输出:

检查连续变量

该变量有很多异常值,并且分布不明确。您可以通过删除每周前 0.01% 的小时数来部分解决此问题。

分位数的基本语法:

quantile(variable, percentile)
arguments:
-variable:  Select the variable in the data frame to compute the percentile
-percentile:  Can be a single value between 0 and 1 or multiple value. If multiple, use this format:  `c(A,B,C, ...)
- `A`,`B`,`C` and `...` are all integer from 0 to 1.

我们计算前 2% 的百分位数

top_one_percent <- quantile(data_adult$hours.per.week, .99)
top_one_percent

代码说明

  • quantile(data_adult$hours.per.week, .99):计算 99% 工作时间的值

输出:

## 99% 
##  80

98%的人口每周工作时间不足80小时。

您可以将观测值降低到高于此阈值的水平。您可以使用 dplyr 图书馆。

data_adult_drop <-data_adult %>%
filter(hours.per.week<top_one_percent)
dim(data_adult_drop)

输出:

## [1] 45537    10
  1. 标准化连续变量

由于数据的比例不同,您可以标准化每列以提高性能。您可以使用 dplyr 库中的函数 mutate_if。基本语法是:

mutate_if(df, condition, funs(function))
arguments:
-`df`: Data frame used to compute the function
- `condition`: Statement used. Do not use parenthesis
- funs(function):  Return the function to apply. Do not use parenthesis for the function

您可以按如下方式标准化数字列:

data_adult_rescale <- data_adult_drop % > %
	mutate_if(is.numeric, funs(as.numeric(scale(.))))
head(data_adult_rescale)

代码说明

  • mutate_if(is.numeric, funs(scale)):条件仅为数字列,函数为scale

输出:

##           X         age        workclass    education educational.num
## 1 -1.732680 -1.02325949          Private         11th     -1.22106443
## 2 -1.732605 -0.03969284          Private      HS-grad     -0.43998868
## 3 -1.732530 -0.79628257        Local-gov   Assoc-acdm      0.73162494
## 4 -1.732455  0.41426100          Private Some-college     -0.04945081
## 5 -1.732379 -0.34232873          Private         10th     -1.61160231
## 6 -1.732304  1.85178149 Self-emp-not-inc  Prof-school      1.90323857
##       marital.status  race gender hours.per.week income
## 1      Never-married Black   Male    -0.03995944  <=50K
## 2 Married-civ-spouse White   Male     0.86863037  <=50K
## 3 Married-civ-spouse White   Male    -0.03995944   >50K
## 4 Married-civ-spouse Black   Male    -0.03995944   >50K
## 5      Never-married White   Male    -0.94854924  <=50K
## 6 Married-civ-spouse White   Male    -0.76683128   >50K

步骤2)检查因子变量

此步骤有两个目标:

  • 检查每个分类列中的级别
  • 定义新的级别

我们将这一步分为三个部分:

  • 选择分类列
  • 将每列的条形图存储在列表中
  • 打印图表

我们可以使用以下代码来选择因子列:

# Select categorical column
factor <- data.frame(select_if(data_adult_rescale, is.factor))
	ncol(factor)

代码说明

  • data.frame(select_if(data_adult, is.factor)):我们将因子中的因子列存储在数据框类型中。ggplot2 库需要数据框对象。

输出:

## [1] 6

数据集包含 6 个分类变量

第二步比较有技巧。你需要为数据框因子中的每一列绘制一个条形图。自动化这个过程更方便,特别是在列很多的情况下。

library(ggplot2)
# Create graph for each column
graph <- lapply(names(factor),
    function(x) 
	ggplot(factor, aes(get(x))) +
		geom_bar() +
		theme(axis.text.x = element_text(angle = 90)))

代码说明

  • lapply():使用函数 lapply() 将函数传递到数据集的所有列中。将输出存储在列表中
  • function(x):该函数将针对每个 x 进行处理。这里 x 是列
  • ggplot(factor, aes(get(x))) + geom_bar()+ theme(axis.text.x = element_text(angle = 90)):为每个 x 元素创建一个条形图。注意,要将 x 作为列返回,您需要将其包含在 get() 中

最后一步相对简单。您要打印 6 张图表。

# Print the graph
graph

输出:

## [[1]]

检查因素变量

## ## [[2]]

检查因素变量

## ## [[3]]

检查因素变量

## ## [[4]]

检查因素变量

## ## [[5]]

检查因素变量

## ## [[6]]

检查因素变量

注意:使用下一个按钮导航到下一个图表

检查因素变量

步骤3)特征工程

重塑教育

从上图可以看出,变量 education 有 16 个级别。这是相当可观的,而且有些级别的观察值数量相对较少。如果你想提高从这个变量中获取的信息量,你可以将其重铸到更高的级别。也就是说,你可以创建具有相似教育水平的更大组。例如,教育水平低的将转换为 dropout。教育水平较高的将改为 master。

详细信息如下:

旧级别 新水平
学龄前 辍学
10日 退出
11日 退出
12日 退出
1-4日 退出
5th-6th 退出
7th-8th 退出
9日 退出
高中毕业 高年级
一些学院 社区
协会会员 社区
副教授 社区
学士 学士
硕士 硕士
专业学校 硕士
博士学位 博士
recast_data <- data_adult_rescale % > %
	select(-X) % > %
	mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",
    ifelse(education == "Bachelors", "Bachelors",
        ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))

代码说明

  • 我们使用 dplyr 库中的动词 mutate。我们使用 ifelse 语句更改 education 的值

在下表中,您可以创建一个汇总统计数据,以查看平均需要多少年教育(z 值)才能获得学士、硕士或博士学位。

recast_data % > %
	group_by(education) % > %
	summarize(average_educ_year = mean(educational.num),
		count = n()) % > %
	arrange(average_educ_year)

输出:

## # A tibble: 6 x 3
## education average_educ_year count			
##      <fctr>             <dbl> <int>
## 1   dropout       -1.76147258  5712
## 2  HighGrad       -0.43998868 14803
## 3 Community        0.09561361 13407
## 4 Bachelors        1.12216282  7720
## 5    Master        1.60337381  3338
## 6       PhD        2.29377644   557

重铸 Mari地位

也可以为婚姻状况创建较低级别。在下面的代码中,您可以按如下方式更改级别:

旧级别 新水平
未婚 未婚
已婚-配偶缺席 未婚
已婚-AF-配偶 已婚
已婚公民配偶
分离 分离
离婚
寡妇 寡妇
# Change level marry
recast_data <- recast_data % > %
	mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))

您可以检查每个组内的人数。

table(recast_data$marital.status)

输出:

## ##     Married Not_married   Separated       Widow
##       21165       15359        7727        1286

步骤 4)汇总统计

现在该检查一下目标变量的统计数据了。在下图中,您可以计算出收入超过 50 美元的个人的百分比(按性别计算)。

# Plot gender income
ggplot(recast_data, aes(x = gender, fill = income)) +
    geom_bar(position = "fill") +
    theme_classic()

输出:

统计摘要

接下来,检查个人的出身是否影响他们的收入。

# Plot origin income
ggplot(recast_data, aes(x = race, fill = income)) +
    geom_bar(position = "fill") +
    theme_classic() +
    theme(axis.text.x = element_text(angle = 90))

输出:

统计摘要

按性别划分的工作小时数。

# box plot gender working time
ggplot(recast_data, aes(x = gender, y = hours.per.week)) +
    geom_boxplot() +
    stat_summary(fun.y = mean,
        geom = "point",
        size = 3,
        color = "steelblue") +
    theme_classic()

输出:

统计摘要

箱线图证实了工作时间分布适合不同的群体。在箱线图中,两种性别的观察结果并不均匀。

你可以按教育类型查看每周工作时间的密度。分布有很多不同的选择。这可能可以通过美国的合同类型来解释。

# Plot distribution working time by education
ggplot(recast_data, aes(x = hours.per.week)) +
    geom_density(aes(color = education), alpha = 0.5) +
    theme_classic()

代码说明

  • ggplot(recast_data, aes( x= hours.per.week)):密度图只需要一个变量
  • geom_density(aes(color = education), alpha =0.5):控制密度的几何对象

输出:

统计摘要

为了确认你的想法,你可以进行单向 方差分析检验:

anova <- aov(hours.per.week~education, recast_data)
summary(anova)

输出:

##                Df Sum Sq Mean Sq F value Pr(>F)    
## education       5   1552  310.31   321.2 <2e-16 ***
## Residuals   45531  43984    0.97                   
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

方差分析检验证实了组间平均值的差异。

非线性

在运行模型之前,您可以查看工作时间是否与年龄有关。

library(ggplot2)
ggplot(recast_data, aes(x = age, y = hours.per.week)) +
    geom_point(aes(color = income),
        size = 0.5) +
    stat_smooth(method = 'lm',
        formula = y~poly(x, 2),
        se = TRUE,
        aes(color = income)) +
    theme_classic()

代码说明

  • ggplot(recast_data, aes(x = age, y = hours.per.week)):设置图形的美观性
  • geom_point(aes(color=income),size=0.5):构建点图
  • stat_smooth():使用以下参数添加趋势线:
    • method='lm': 如果 线性回归
    • 公式 = y~poly(x,2):拟合多项式回归
    • se = TRUE: 添加标准错误
    • aes(color=income): 按收入划分模型

输出:

非线性

简而言之,您可以测试模型中的交互项,以找出每周工作时间与其他特征之间的非线性效应。检测工作时间在哪种条件下有所不同非常重要。

相关性

下一个检查是可视化变量之间的相关性。将因子级别类型转换为数字,以便绘制包含使用 Spearman 方法计算的相关系数的热图。

library(GGally)
# Convert data to numeric
corr <- data.frame(lapply(recast_data, as.integer))
# Plot the graphggcorr(corr,
    method = c("pairwise", "spearman"),
    nbreaks = 6,
    hjust = 0.8,
    label = TRUE,
    label_size = 3,
    color = "grey50")

代码说明

  • data.frame(lapply(recast_data,as.integer)): 将数据转换为数字
  • ggcorr() 使用以下参数绘制热图:
    • 方法:计算相关性的方法
    • nbreaks = 6:中断次数
    • hjust = 0.8:控制图中变量名称的位置
    • label = TRUE: 在窗口中心添加标签
    • label_size = 3:标签尺寸
    • color = “grey50”:标签的颜色

输出:

相关性

步骤5)训练/测试集

任何受监督 机器学习 任务需要将数据分为训练集和测试集。您可以使用在其他监督学习教程中创建的“函数”来创建训练/测试集。

set.seed(1234)
create_train_test <- function(data, size = 0.8, train = TRUE) {
    n_row = nrow(data)
    total_row = size * n_row
    train_sample <- 1: total_row
    if (train == TRUE) {
        return (data[train_sample, ])
    } else {
        return (data[-train_sample, ])
    }
}
data_train <- create_train_test(recast_data, 0.8, train = TRUE)
data_test <- create_train_test(recast_data, 0.8, train = FALSE)
dim(data_train)

输出:

## [1] 36429     9
dim(data_test)

输出:

## [1] 9108    9

步骤6)建立模型

要查看算法的执行情况,可以使用 glm() 包。 广义线性模型 是模型的集合。基本语法是:

glm(formula, data=data, family=linkfunction()
Argument:
- formula:  Equation used to fit the model- data: dataset used
- Family:     - binomial: (link = "logit")			
- gaussian: (link = "identity")			
- Gamma:    (link = "inverse")			
- inverse.gaussian: (link = "1/mu^2")			
- poisson:  (link = "log")			
- quasi:    (link = "identity", variance = "constant")			
- quasibinomial:    (link = "logit")			
- quasipoisson: (link = "log")	

您已准备好估算逻辑模型,以在一组特征之间划分收入水平。

formula <- income~.
logit <- glm(formula, data = data_train, family = 'binomial')
summary(logit)

代码说明

  • 公式 <- 收入 ~ .: 创建适合的模型
  • logit <- glm(formula, data = data_train, family = 'binomial'): 使用 data_train 数据拟合逻辑模型(family = 'binomial')。
  • summary(logit):打印模型摘要

输出:

## 
## Call:
## glm(formula = formula, family = "binomial", data = data_train)
## ## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.6456  -0.5858  -0.2609  -0.0651   3.1982  
## 
## Coefficients:
##                           Estimate Std. Error z value Pr(>|z|)    
## (Intercept)                0.07882    0.21726   0.363  0.71675    
## age                        0.41119    0.01857  22.146  < 2e-16 ***
## workclassLocal-gov        -0.64018    0.09396  -6.813 9.54e-12 ***
## workclassPrivate          -0.53542    0.07886  -6.789 1.13e-11 ***
## workclassSelf-emp-inc     -0.07733    0.10350  -0.747  0.45499    
## workclassSelf-emp-not-inc -1.09052    0.09140 -11.931  < 2e-16 ***
## workclassState-gov        -0.80562    0.10617  -7.588 3.25e-14 ***
## workclassWithout-pay      -1.09765    0.86787  -1.265  0.20596    
## educationCommunity        -0.44436    0.08267  -5.375 7.66e-08 ***
## educationHighGrad         -0.67613    0.11827  -5.717 1.08e-08 ***
## educationMaster            0.35651    0.06780   5.258 1.46e-07 ***
## educationPhD               0.46995    0.15772   2.980  0.00289 ** 
## educationdropout          -1.04974    0.21280  -4.933 8.10e-07 ***
## educational.num            0.56908    0.07063   8.057 7.84e-16 ***
## marital.statusNot_married -2.50346    0.05113 -48.966  < 2e-16 ***
## marital.statusSeparated   -2.16177    0.05425 -39.846  < 2e-16 ***
## marital.statusWidow       -2.22707    0.12522 -17.785  < 2e-16 ***
## raceAsian-Pac-Islander     0.08359    0.20344   0.411  0.68117    
## raceBlack                  0.07188    0.19330   0.372  0.71001    
## raceOther                  0.01370    0.27695   0.049  0.96054    
## raceWhite                  0.34830    0.18441   1.889  0.05894 .  
## genderMale                 0.08596    0.04289   2.004  0.04506 *  
## hours.per.week             0.41942    0.01748  23.998  < 2e-16 ***
## ---## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## ## (Dispersion parameter for binomial family taken to be 1)
## ##     Null deviance: 40601  on 36428  degrees of freedom
## Residual deviance: 27041  on 36406  degrees of freedom
## AIC: 27087
## 
## Number of Fisher Scoring iterations: 6

我们的模型摘要揭示了一些有趣的信息。逻辑回归的性能可以用特定的关键指标来评估。

  • AIC(赤池信息准则):这相当于 R2 在逻辑回归中。它测量当对参数数量施加惩罚时的拟合度。较小的 AIC 值表明模型更接近真实情况。
  • 零偏差:仅用截距拟合模型。自由度为 n-1。我们可以将其解释为卡方值(不同于实际值假设检验的拟合值)。
  • 残差偏差:包含所有变量的模型。它也被解释为卡方假设检验。
  • 费舍尔评分迭代次数:收敛前的迭代次数。

glm() 函数的输出存储在列表中。下面的代码显示了我们构建的用于评估逻辑回归的 logit 变量中可用的所有项目。

# 列表很长,仅打印前三个元素

lapply(logit, class)[1:3]

输出:

## $coefficients
## [1] "numeric"
## 
## $residuals
## [1] "numeric"
## 
## $fitted.values
## [1] "numeric"

可以使用 $ 符号和指标名称提取每个值。例如,您将模型存储为 logit。要提取 AIC 标准,请使用:

logit$aic

输出:

## [1] 27086.65

步骤7)评估模型的性能

混淆矩阵

- 混淆矩阵 与您之前看到的不同指标相比,是评估分类性能的更好选择。一般的想法是计算 True 实例被分类为 False 的次数。

混淆矩阵

要计算混淆矩阵,首先需要有一组预测,以便可以将它们与实际目标进行比较。

predict <- predict(logit, data_test, type = 'response')
# confusion matrix
table_mat <- table(data_test$income, predict > 0.5)
table_mat

代码说明

  • predict(logit,data_test, type = 'response'):计算测试集的预测。设置 type = 'response' 来计算响应概率。
  • table(data_test$income, predict > 0.5):计算混淆矩阵。predict > 0.5 表示如果预测概率高于 1 则返回 0.5,否则返回 0。

输出:

##        
##         FALSE TRUE
##   <=50K  6310  495
##   >50K   1074 1229	

混淆矩阵中的每一行代表一个实际目标,而每一列代表一个预测目标。该矩阵的第一行考虑收入低于 50k(错误类别):6241 人被正确归类为收入低于 50k 的个人(真阴性),而剩下的一个被错误地归类为 50k 以上(假阳性第二行考虑收入在50万以上,正类为1229(真阳性),而 真阴性 是1074。

您可以计算模型 将真阳性 + 真阴性加到总观察值上

混淆矩阵

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
accuracy_Test

代码说明

  • sum(diag(table_mat)):对角线的总和
  • sum(table_mat):矩阵的总和。

输出:

## [1] 0.8277339

该模型似乎存在一个问题,即它高估了假阴性的数量。这被称为 准确度测试悖论。我们说过,准确率是正确预测与案例总数的比率。我们可以拥有相对较高的准确率,但模型却毫无用处。当存在主导类别时,就会发生这种情况。如果你回顾混淆矩阵,你会看到大多数案例都被归类为真阴性。现在想象一下,该模型将所有类别归类为阴性(即低于 50k)。你的准确率将达到 75%(6718/6718+2257)。你的模型表现更好,但很难区分真阳性和真阴性。

在这种情况下,最好有一个更简洁的指标。我们可以看看:

  • 准确率=TP/(TP+FP)
  • 召回率=TP/(TP+FN)

准确率与召回率

平台精度 查看正面预测的准确性。 记得 是分类器正确检测到的正实例的比例;

您可以构建两个函数来计算这两个指标

  1. 构造精度
precision <- function(matrix) {
	# True positive
    tp <- matrix[2, 2]
	# false positive
    fp <- matrix[1, 2]
    return (tp / (tp + fp))
}

代码说明

  • mat[1,1]:返回数据框第一列第一个单元格,即真阳性
  • mat[1,2];返回数据框第二列第一个单元格,即假阳性
recall <- function(matrix) {
# true positive
    tp <- matrix[2, 2]# false positive
    fn <- matrix[2, 1]
    return (tp / (tp + fn))
}

代码说明

  • mat[1,1]:返回数据框第一列第一个单元格,即真阳性
  • mat[2,1];返回数据框第一列第二个单元格,即假阴性

您可以测试您的功能

prec <- precision(table_mat)
prec
rec <- recall(table_mat)
rec

输出:

## [1] 0.712877
## [2] 0.5336518

当模型预测某个人的收入超过 50 美元时,只有 54% 的概率是正确的,而 50% 的概率可以断定某个人的收入超过 72 美元。

您可以创建 准确率与召回率 根据准确率和召回率进行评分。 准确率与召回率 是这两个指标的调和平均值,这意味着较低的值具有更大的权重。

准确率与召回率

f1 <- 2 * ((prec * rec) / (prec + rec))
f1

输出:

## [1] 0.6103799

准确率与召回率的权衡

同时具有较高的准确率和较高的召回率是不可能的。

如果我们提高精度,正确的个体将会被更好地预测,但我们会错过很多(召回率较低)。在某些情况下,我们更喜欢更高的精度而不是召回率。精度和召回率之间存在凹关系。

  • 想象一下,你需要预测一个病人是否患有疾病。你希望预测结果尽可能准确。
  • 如果需要通过人脸识别来侦查街上潜在的诈骗分子,那么最好能抓到很多被标记为诈骗分子的人,尽管准确率不高,警察就能释放那些没有诈骗行为的人。

ROC曲线

- 接收器 Opera特点 ROC 曲线是二元分类中另一种常用工具。它与精确度/召回率曲线非常相似,但 ROC 曲线不是绘制精确度与召回率,而是显示真实阳性率(即召回率)与假阳性率。假阳性率是被错误分类为阳性的阴性实例的比例。它等于一减去真实阴性率。真实阴性率也称为 特异性。因此 ROC 曲线图 灵敏度 (回忆)与 1-特异性

要绘制 ROC 曲线,我们需要安装一个名为 RORC 的库。我们可以在 conda 中找到 图书馆。您可以输入代码:

conda 安装-cr r-rocr –yes

我们可以使用 prediction() 和 performance() 函数绘制 ROC。

library(ROCR)
ROCRpred <- prediction(predict, data_test$income)
ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')
plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))

代码说明

  • 预测(predict,data_test$income):ROCR 库需要创建一个预测对象来转换输入数据
  • performance(ROCRpred, 'tpr','fpr'):返回在图中产生的两个组合。这里构造了 tpr 和 fpr。要将精度和召回率一起绘制,请使用“prec”、“rec”。

输出:

ROC 曲线

步骤8) 改进模型

您可以尝试在模型中添加非线性,其相互作用如下:

  • 年龄和每周工作时间
  • 性别和每周工作时间。

您需要使用分数检验来比较两个模型

formula_2 <- income~age: hours.per.week + gender: hours.per.week + .
logit_2 <- glm(formula_2, data = data_train, family = 'binomial')
predict_2 <- predict(logit_2, data_test, type = 'response')
table_mat_2 <- table(data_test$income, predict_2 > 0.5)
precision_2 <- precision(table_mat_2)
recall_2 <- recall(table_mat_2)
f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))
f1_2

输出:

## [1] 0.6109181

这个分数比之前的分数略高。您可以继续努力提高数据,争取打破这个分数。

总结

我们可以在下表中总结训练逻辑回归的函数:

小包装 目的 功能 争论
创建训练/测试数据集 创建训练集() 数据、大小、训练
m 训练广义线性模型 GLM() 公式、数据、系列*
m 总结模型 概括() 拟合模型
基地 做出预测 预测() 拟合模型,数据集,类型 = '响应'
基地 创建混淆矩阵 桌子() y,预测()
基地 创建准确度分数 总和(诊断(表())/总和(表()
大中华区 创建 ROC:步骤 1 创建预测 预言() 预测(),y
大中华区 创建 ROC:第 2 步 创建绩效 表现() 预测(),'tpr','fpr'
大中华区 创建 ROC:步骤 3 绘制图表 阴谋() 表现()

GLM 模型类型有:

– 二项式:(链接 = “logit”)

– 高斯:(链接 = “身份”)

– Gamma:(链接 = “逆”)

– inverse.gaussian:(链接 = “1/mu^2”)

– 泊松:(链接 = “log”)

– 准:(链接 = “身份”,方差 = “常数”)

– 准二项式:(链接 = “logit”)

– 拟泊松:(链接 = “log”)