GLM en R: modelo lineal generalizado con ejemplo

¿Qué es la regresión logística?

La regresión logística se utiliza para predecir una clase, es decir, una probabilidad. La regresión logística puede predecir con precisión un resultado binario.

Imagine que desea predecir si se deniega o acepta un préstamo en función de muchos atributos. La regresión logística es de la forma 0/1. y = 0 si se rechaza un préstamo, y = 1 si se acepta.

Un modelo de regresión logística se diferencia del modelo de regresión lineal en dos formas.

  • En primer lugar, la regresión logística acepta sólo datos dicotómicos (binarios) como variable dependiente (es decir, un vector de 0 y 1).
  • En segundo lugar, el resultado se mide mediante la siguiente función de enlace probabilístico denominada sigmoideo por su forma de S.:

Regresión logística

La salida de la función siempre está entre 0 y 1. Consulte la imagen a continuación

Regresión logística

La función sigmoidea devuelve valores de 0 a 1. Para la tarea de clasificación, necesitamos una salida discreta de 0 o 1.

Para convertir un flujo continuo en un valor discreto, podemos establecer un límite de decisión en 0.5. Todos los valores por encima de este umbral se clasifican como 1

Regresión logística

Cómo crear un modelo de revestimiento generalizado (GLM)

Usemos el adulto conjunto de datos para ilustrar la regresión logística. El "adulto" es un gran conjunto de datos para la tarea de clasificación. El objetivo es predecir si los ingresos anuales en dólares de un individuo superarán los 50.000. El conjunto de datos contiene 46,033 observaciones y diez características:

  • edad: edad del individuo. Numérico
  • educación: Nivel educativo del individuo. Factor.
  • Estado civil: Mariestatus tal del individuo. Factor, es decir, nunca casado, casado-cónyuge civil,…
  • género: Género del individuo. Factor, es decir, masculino o femenino
  • ingreso: Target variable. Ingresos superiores o inferiores a 50K. Factor, es decir >50K, <=50K

Entre otros

library(dplyr)
data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")
glimpse(data_adult)

Salida:

Observations: 48,842
Variables: 10
$ x               <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,...
$ age             <int> 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26...
$ workclass       <fctr> Private, Private, Local-gov, Private, ?, Private,...
$ education       <fctr> 11th, HS-grad, Assoc-acdm, Some-college, Some-col...
$ educational.num <int> 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,...
$ marital.status  <fctr> Never-married, Married-civ-spouse, Married-civ-sp...
$ race            <fctr> Black, White, White, Black, White, White, Black, ...
$ gender          <fctr> Male, Male, Male, Male, Female, Male, Male, Male,...
$ hours.per.week  <int> 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39...
$ income          <fctr> <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5...

Procederemos de la siguiente manera:

  • Paso 1: Verifique las variables continuas
  • Paso 2: verificar las variables de los factores
  • Paso 3: ingeniería de funciones
  • Paso 4: resumen estadístico
  • Paso 5: Conjunto de entrenamiento/prueba
  • Paso 6: construye el modelo
  • Paso 7: evaluar el rendimiento del modelo
  • Paso 8: Mejorar el modelo

Su tarea es predecir qué individuo tendrá unos ingresos superiores a 50.

En este tutorial, se detallará cada paso para realizar un análisis en un conjunto de datos real.

Paso 1) Verificar variables continuas

En el primer paso, puedes ver la distribución de las variables continuas.

continuous <-select_if(data_adult, is.numeric)
summary(continuous)

Explicación del código

  • continuo <- select_if(data_adult, is.numeric): use la función select_if() de la biblioteca dplyr para seleccionar solo las columnas numéricas
  • resumen (continuo): imprime la estadística resumida

Salida:

##        X              age        educational.num hours.per.week 
##  Min.   :    1   Min.   :17.00   Min.   : 1.00   Min.   : 1.00  
##  1st Qu.:11509   1st Qu.:28.00   1st Qu.: 9.00   1st Qu.:40.00  
##  Median :23017   Median :37.00   Median :10.00   Median :40.00  
##  Mean   :23017   Mean   :38.56   Mean   :10.13   Mean   :40.95  
##  3rd Qu.:34525   3rd Qu.:47.00   3rd Qu.:13.00   3rd Qu.:45.00  
##  Max.   :46033   Max.   :90.00   Max.   :16.00   Max.   :99.00	

De la tabla anterior, se puede ver que los datos tienen escalas totalmente diferentes y hours.per.weeks tiene valores atípicos grandes (es decir, observe el último cuartil y el valor máximo).

Puedes solucionarlo siguiendo dos pasos:

  • 1: Grafica la distribución de horas por semana
  • 2: estandarizar las variables continuas
  1. Trazar la distribución

Veamos más de cerca la distribución de horas por semana.

# Histogram with kernel density curve
library(ggplot2)
ggplot(continuous, aes(x = hours.per.week)) +
    geom_density(alpha = .2, fill = "#FF6666")

Salida:

Verificar variables continuas

La variable tiene muchos valores atípicos y una distribución no bien definida. Puedes solucionar parcialmente este problema eliminando el 0.01 por ciento superior de las horas por semana.

Sintaxis básica de cuantil:

quantile(variable, percentile)
arguments:
-variable:  Select the variable in the data frame to compute the percentile
-percentile:  Can be a single value between 0 and 1 or multiple value. If multiple, use this format:  `c(A,B,C, ...)
- `A`,`B`,`C` and `...` are all integer from 0 to 1.

Calculamos el percentil del 2 por ciento superior

top_one_percent <- quantile(data_adult$hours.per.week, .99)
top_one_percent

Explicación del código

  • quantile(data_adult$hours.per.week, .99): Calcula el valor del 99 por ciento del tiempo de trabajo

Salida:

## 99% 
##  80

El 98 por ciento de la población trabaja menos de 80 horas semanales.

Puede dejar las observaciones por encima de este umbral. Usas el filtro del dplyr biblioteca.

data_adult_drop <-data_adult %>%
filter(hours.per.week<top_one_percent)
dim(data_adult_drop)

Salida:

## [1] 45537    10
  1. Estandarizar las variables continuas.

Puedes estandarizar cada columna para mejorar el rendimiento porque tus datos no tienen la misma escala. Puede utilizar la función mutate_if de la biblioteca dplyr. La sintaxis básica es:

mutate_if(df, condition, funs(function))
arguments:
-`df`: Data frame used to compute the function
- `condition`: Statement used. Do not use parenthesis
- funs(function):  Return the function to apply. Do not use parenthesis for the function

Puede estandarizar las columnas numéricas de la siguiente manera:

data_adult_rescale <- data_adult_drop % > %
	mutate_if(is.numeric, funs(as.numeric(scale(.))))
head(data_adult_rescale)

Explicación del código

  • mutate_if(is.numeric, funs(scale)): la condición es solo una columna numérica y la función es escala

Salida:

##           X         age        workclass    education educational.num
## 1 -1.732680 -1.02325949          Private         11th     -1.22106443
## 2 -1.732605 -0.03969284          Private      HS-grad     -0.43998868
## 3 -1.732530 -0.79628257        Local-gov   Assoc-acdm      0.73162494
## 4 -1.732455  0.41426100          Private Some-college     -0.04945081
## 5 -1.732379 -0.34232873          Private         10th     -1.61160231
## 6 -1.732304  1.85178149 Self-emp-not-inc  Prof-school      1.90323857
##       marital.status  race gender hours.per.week income
## 1      Never-married Black   Male    -0.03995944  <=50K
## 2 Married-civ-spouse White   Male     0.86863037  <=50K
## 3 Married-civ-spouse White   Male    -0.03995944   >50K
## 4 Married-civ-spouse Black   Male    -0.03995944   >50K
## 5      Never-married White   Male    -0.94854924  <=50K
## 6 Married-civ-spouse White   Male    -0.76683128   >50K

Paso 2) Verificar las variables de los factores

Este paso tiene dos objetivos:

  • Consulta el nivel en cada columna categórica.
  • Definir nuevos niveles

Dividiremos este paso en tres partes:

  • Seleccione las columnas categóricas
  • Almacene el gráfico de barras de cada columna en una lista
  • imprimir los gráficos

Podemos seleccionar las columnas de factores con el siguiente código:

# Select categorical column
factor <- data.frame(select_if(data_adult_rescale, is.factor))
	ncol(factor)

Explicación del código

  • data.frame(select_if(data_adult, is.factor)): almacenamos las columnas de factores en factor en un tipo de marco de datos. La biblioteca ggplot2 requiere un objeto de marco de datos.

Salida:

## [1] 6

El conjunto de datos contiene 6 variables categóricas.

El segundo paso es más hábil. Quiere trazar un gráfico de barras para cada columna del factor del marco de datos. Es más conveniente automatizar el proceso, especialmente en situaciones en las que hay muchas columnas.

library(ggplot2)
# Create graph for each column
graph <- lapply(names(factor),
    function(x) 
	ggplot(factor, aes(get(x))) +
		geom_bar() +
		theme(axis.text.x = element_text(angle = 90)))

Explicación del código

  • lapply(): Utilice la función lapply() para pasar una función en todas las columnas del conjunto de datos. Almacenas la salida en una lista.
  • función(x): La función se procesará para cada x. Aquí x son las columnas.
  • ggplot(factor, aes(get(x))) + geom_bar()+ theme(axis.text.x = element_text(angle = 90)): crea un gráfico de caracteres de barras para cada elemento x. Tenga en cuenta que para devolver x como columna, debe incluirla dentro de get()

El último paso es relativamente fácil. Quieres imprimir los 6 gráficos.

# Print the graph
graph

Salida:

## [[1]]

Variables de factores de verificación

## ## [[2]]

Variables de factores de verificación

## ## [[3]]

Variables de factores de verificación

## ## [[4]]

Variables de factores de verificación

## ## [[5]]

Variables de factores de verificación

## ## [[6]]

Variables de factores de verificación

Nota: Utilice el botón siguiente para navegar al siguiente gráfico.

Variables de factores de verificación

Paso 3) Ingeniería de funciones

Educación refundida

En el gráfico anterior se puede ver que la variable educación tiene 16 niveles. Esto es sustancial y algunos niveles tienen un número relativamente bajo de observaciones. Si desea mejorar la cantidad de información que puede obtener de esta variable, puede reformularla a un nivel superior. Es decir, creas grupos más grandes con un nivel de educación similar. Por ejemplo, un bajo nivel de educación se convertirá en abandono escolar. Los niveles superiores de educación se cambiarán a maestría.

Aquí está el detalle:

Antiguo nivel Nuevo nivel
Preescolar Dropout
10 Punteras
11 Punteras
12 Punteras
1st-4th Punteras
5mo-6vo Punteras
7mo-8vo Punteras
9 Punteras
Graduado HS alto grado
Alguna educación superior Comunidad
asoc-acdm Comunidad
Voc-Asociado Comunidad
Solteros Solteros
Masters Masters
escuela profesional Masters
Doctorado Doctorado
recast_data <- data_adult_rescale % > %
	select(-X) % > %
	mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",
    ifelse(education == "Bachelors", "Bachelors",
        ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))

Explicación del código

  • Usamos el verbo mutar de la biblioteca dplyr. Cambiamos los valores de la educación con la afirmación ifelse

En la siguiente tabla, crea una estadística resumida para ver, en promedio, cuántos años de educación (valor z) se necesitan para alcanzar la licenciatura, la maestría o el doctorado.

recast_data % > %
	group_by(education) % > %
	summarize(average_educ_year = mean(educational.num),
		count = n()) % > %
	arrange(average_educ_year)

Salida:

## # A tibble: 6 x 3
## education average_educ_year count			
##      <fctr>             <dbl> <int>
## 1   dropout       -1.76147258  5712
## 2  HighGrad       -0.43998868 14803
## 3 Community        0.09561361 13407
## 4 Bachelors        1.12216282  7720
## 5    Master        1.60337381  3338
## 6       PhD        2.29377644   557

Nuevo reparto de papeles Mariestado-tal

También es posible crear niveles inferiores para el estado civil. En el código siguiente, cambia el nivel de la siguiente manera:

Antiguo nivel Nuevo nivel
Nunca casado No casado
cónyuge-casado-ausente No casado
Cónyuge-AF-casado Casado
Cónyuge-civil-casado
Separado Separado
Divorciado
Viudas Viuda
# Change level marry
recast_data <- recast_data % > %
	mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))

Puede consultar el número de personas dentro de cada grupo.

table(recast_data$marital.status)

Salida:

## ##     Married Not_married   Separated       Widow
##       21165       15359        7727        1286

Paso 4) Estadística resumida

Es hora de comprobar algunas estadísticas sobre nuestras variables objetivo. En el siguiente gráfico, se cuenta el porcentaje de personas que ganan más de 50 XNUMX según su género.

# Plot gender income
ggplot(recast_data, aes(x = gender, fill = income)) +
    geom_bar(position = "fill") +
    theme_classic()

Salida:

Estadística resumida

A continuación, verifique si el origen del individuo afecta sus ingresos.

# Plot origin income
ggplot(recast_data, aes(x = race, fill = income)) +
    geom_bar(position = "fill") +
    theme_classic() +
    theme(axis.text.x = element_text(angle = 90))

Salida:

Estadística resumida

Número de horas trabajadas por género.

# box plot gender working time
ggplot(recast_data, aes(x = gender, y = hours.per.week)) +
    geom_boxplot() +
    stat_summary(fun.y = mean,
        geom = "point",
        size = 3,
        color = "steelblue") +
    theme_classic()

Salida:

Estadística resumida

El diagrama de caja confirma que la distribución del tiempo de trabajo se ajusta a diferentes grupos. En el diagrama de caja, ambos géneros no tienen observaciones homogéneas.

Puedes consultar la densidad del tiempo de trabajo semanal por tipo de educación. Las distribuciones tienen muchas selecciones distintas. Probablemente pueda explicarse por el tipo de contrato en Estados Unidos.

# Plot distribution working time by education
ggplot(recast_data, aes(x = hours.per.week)) +
    geom_density(aes(color = education), alpha = 0.5) +
    theme_classic()

Explicación del código

  • ggplot(recast_data, aes( x= hours.per.week)): Un gráfico de densidad solo requiere una variable
  • geom_density(aes(color = educación), alfa =0.5): El objeto geométrico para controlar la densidad

Salida:

Estadística resumida

Para confirmar sus pensamientos, puede realizar una prueba unidireccional. Prueba ANOVA:

anova <- aov(hours.per.week~education, recast_data)
summary(anova)

Salida:

##                Df Sum Sq Mean Sq F value Pr(>F)    
## education       5   1552  310.31   321.2 <2e-16 ***
## Residuals   45531  43984    0.97                   
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

La prueba ANOVA confirma la diferencia de promedios entre los grupos.

No linealidad

Antes de ejecutar el modelo, puedes ver si el número de horas trabajadas está relacionado con la edad.

library(ggplot2)
ggplot(recast_data, aes(x = age, y = hours.per.week)) +
    geom_point(aes(color = income),
        size = 0.5) +
    stat_smooth(method = 'lm',
        formula = y~poly(x, 2),
        se = TRUE,
        aes(color = income)) +
    theme_classic()

Explicación del código

  • ggplot(recast_data, aes(x = age, y = hours.per.week)): establece la estética del gráfico
  • geom_point(aes(color= ingresos), tamaño =0.5): construye el diagrama de puntos
  • stat_smooth(): agrega la línea de tendencia con los siguientes argumentos:
    • método = 'lm': Traza el valor ajustado si el regresión lineal
    • fórmula = y~poly(x,2): Ajustar una regresión polinómica
    • se = VERDADERO: Agrega el error estándar
    • aes(color= ingresos): divide el modelo por ingresos

Salida:

No linealidad

En pocas palabras, puede probar los términos de interacción en el modelo para detectar el efecto de no linealidad entre el tiempo de trabajo semanal y otras características. Es importante detectar en qué condiciones difiere el tiempo de trabajo.

La correlación

La siguiente comprobación es visualizar la correlación entre las variables. El tipo de nivel de factor se convierte en numérico para poder trazar un mapa de calor que contenga el coeficiente de correlación calculado con el método de Spearman.

library(GGally)
# Convert data to numeric
corr <- data.frame(lapply(recast_data, as.integer))
# Plot the graphggcorr(corr,
    method = c("pairwise", "spearman"),
    nbreaks = 6,
    hjust = 0.8,
    label = TRUE,
    label_size = 3,
    color = "grey50")

Explicación del código

  • data.frame(lapply(recast_data,as.integer)): convierte datos a numéricos
  • ggcorr() traza el mapa de calor con los siguientes argumentos:
    • método: Método para calcular la correlación
    • nbreaks = 6: Número de descansos
    • hjust = 0.8: posición de control del nombre de la variable en el gráfico
    • etiqueta = VERDADERO: Agrega etiquetas en el centro de las ventanas
    • label_size = 3: etiquetas de tamaño
    • color = “grey50”): Color de la etiqueta

Salida:

La correlación

Paso 5) Conjunto de entrenamiento/prueba

Cualquier supervisado máquina de aprendizaje La tarea requiere dividir los datos entre un conjunto de trenes y un conjunto de prueba. Puede utilizar la "función" que creó en los otros tutoriales de aprendizaje supervisado para crear un conjunto de entrenamiento/prueba.

set.seed(1234)
create_train_test <- function(data, size = 0.8, train = TRUE) {
    n_row = nrow(data)
    total_row = size * n_row
    train_sample <- 1: total_row
    if (train == TRUE) {
        return (data[train_sample, ])
    } else {
        return (data[-train_sample, ])
    }
}
data_train <- create_train_test(recast_data, 0.8, train = TRUE)
data_test <- create_train_test(recast_data, 0.8, train = FALSE)
dim(data_train)

Salida:

## [1] 36429     9
dim(data_test)

Salida:

## [1] 9108    9

Paso 6) Construye el modelo

Para ver cómo funciona el algoritmo, utilice el paquete glm(). El Modelo lineal generalizado es una colección de modelos. La sintaxis básica es:

glm(formula, data=data, family=linkfunction()
Argument:
- formula:  Equation used to fit the model- data: dataset used
- Family:     - binomial: (link = "logit")			
- gaussian: (link = "identity")			
- Gamma:    (link = "inverse")			
- inverse.gaussian: (link = "1/mu^2")			
- poisson:  (link = "log")			
- quasi:    (link = "identity", variance = "constant")			
- quasibinomial:    (link = "logit")			
- quasipoisson: (link = "log")	

Está listo para estimar el modelo logístico para dividir el nivel de ingresos entre un conjunto de características.

formula <- income~.
logit <- glm(formula, data = data_train, family = 'binomial')
summary(logit)

Explicación del código

  • fórmula <- ingresos ~.: Crea el modelo para que se ajuste
  • logit <- glm(fórmula, datos = data_train, familia = 'binomial'): ajusta un modelo logístico (familia = 'binomial') con los datos de data_train.
  • resumen (logit): imprime el resumen del modelo

Salida:

## 
## Call:
## glm(formula = formula, family = "binomial", data = data_train)
## ## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.6456  -0.5858  -0.2609  -0.0651   3.1982  
## 
## Coefficients:
##                           Estimate Std. Error z value Pr(>|z|)    
## (Intercept)                0.07882    0.21726   0.363  0.71675    
## age                        0.41119    0.01857  22.146  < 2e-16 ***
## workclassLocal-gov        -0.64018    0.09396  -6.813 9.54e-12 ***
## workclassPrivate          -0.53542    0.07886  -6.789 1.13e-11 ***
## workclassSelf-emp-inc     -0.07733    0.10350  -0.747  0.45499    
## workclassSelf-emp-not-inc -1.09052    0.09140 -11.931  < 2e-16 ***
## workclassState-gov        -0.80562    0.10617  -7.588 3.25e-14 ***
## workclassWithout-pay      -1.09765    0.86787  -1.265  0.20596    
## educationCommunity        -0.44436    0.08267  -5.375 7.66e-08 ***
## educationHighGrad         -0.67613    0.11827  -5.717 1.08e-08 ***
## educationMaster            0.35651    0.06780   5.258 1.46e-07 ***
## educationPhD               0.46995    0.15772   2.980  0.00289 ** 
## educationdropout          -1.04974    0.21280  -4.933 8.10e-07 ***
## educational.num            0.56908    0.07063   8.057 7.84e-16 ***
## marital.statusNot_married -2.50346    0.05113 -48.966  < 2e-16 ***
## marital.statusSeparated   -2.16177    0.05425 -39.846  < 2e-16 ***
## marital.statusWidow       -2.22707    0.12522 -17.785  < 2e-16 ***
## raceAsian-Pac-Islander     0.08359    0.20344   0.411  0.68117    
## raceBlack                  0.07188    0.19330   0.372  0.71001    
## raceOther                  0.01370    0.27695   0.049  0.96054    
## raceWhite                  0.34830    0.18441   1.889  0.05894 .  
## genderMale                 0.08596    0.04289   2.004  0.04506 *  
## hours.per.week             0.41942    0.01748  23.998  < 2e-16 ***
## ---## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## ## (Dispersion parameter for binomial family taken to be 1)
## ##     Null deviance: 40601  on 36428  degrees of freedom
## Residual deviance: 27041  on 36406  degrees of freedom
## AIC: 27087
## 
## Number of Fisher Scoring iterations: 6

El resumen de nuestro modelo revela información interesante. El rendimiento de una regresión logística se evalúa con métricas clave específicas.

  • AIC (Criterios de información de Akaike): Es el equivalente a R2 en regresión logística. Mide el ajuste cuando se aplica una penalización al número de parámetros. Menor AIC Los valores indican que el modelo está más cerca de la verdad.
  • Desviación nula: se ajusta al modelo solo con la intersección. El grado de libertad es n-1. Podemos interpretarlo como un valor de Chi-cuadrado (valor ajustado diferente del valor real de la prueba de hipótesis).
  • Desviación Residual: Modelo con todas las variables. También se interpreta como una prueba de hipótesis de Chi-cuadrado.
  • Número de iteraciones de Fisher Scoring: número de iteraciones antes de converger.

La salida de la función glm() se almacena en una lista. El siguiente código muestra todos los elementos disponibles en la variable logit que construimos para evaluar la regresión logística.

# La lista es muy larga, imprime solo los primeros tres elementos

lapply(logit, class)[1:3]

Salida:

## $coefficients
## [1] "numeric"
## 
## $residuals
## [1] "numeric"
## 
## $fitted.values
## [1] "numeric"

Cada valor se puede extraer con el signo $ seguido del nombre de las métricas. Por ejemplo, almacenó el modelo como logit. Para extraer los criterios AIC, utiliza:

logit$aic

Salida:

## [1] 27086.65

Paso 7) Evaluar el desempeño del modelo.

Matriz de confusión

La directiva matriz de confusión es una mejor opción para evaluar el rendimiento de la clasificación en comparación con las diferentes métricas que vio antes. La idea general es contar el número de veces que las instancias Verdaderas se clasifican como Falsas.

Matriz de confusión

Para calcular la matriz de confusión, primero es necesario tener un conjunto de predicciones para poder compararlas con los objetivos reales.

predict <- predict(logit, data_test, type = 'response')
# confusion matrix
table_mat <- table(data_test$income, predict > 0.5)
table_mat

Explicación del código

  • predict(logit,data_test, type = 'response'): calcula la predicción en el conjunto de prueba. Establezca tipo = 'respuesta' para calcular la probabilidad de respuesta.
  • tabla(data_test$ingresos, predecir > ​​0.5): Calcula la matriz de confusión. predecir > ​​0.5 significa que devuelve 1 si las probabilidades predichas son superiores a 0.5; en caso contrario, 0.

Salida:

##        
##         FALSE TRUE
##   <=50K  6310  495
##   >50K   1074 1229	

Cada fila de una matriz de confusión representa un objetivo real, mientras que cada columna representa un objetivo previsto. La primera fila de esta matriz considera los ingresos inferiores a 50k (la clase Falsa): 6241 fueron clasificados correctamente como individuos con ingresos inferiores a 50k (Verdadero-negativo), mientras que el restante fue clasificado erróneamente como superior a 50k (Falso positivo). La segunda fila considera los ingresos superiores a 50k, la clase positiva fue 1229 (Verdadero positivo), Mientras que el Verdadero-negativo fue 1074.

Puedes calcular el modelo. la exactitud sumando el verdadero positivo + el verdadero negativo sobre la observación total

Matriz de confusión

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
accuracy_Test

Explicación del código

  • sum(diag(table_mat)): Suma de la diagonal
  • sum(table_mat): Suma de la matriz.

Salida:

## [1] 0.8277339

El modelo parece sufrir un problema: sobreestima el número de falsos negativos. Esto se llama el paradoja de la prueba de precisión. Dijimos que la precisión es la relación entre las predicciones correctas y el número total de casos. Podemos tener una precisión relativamente alta pero un modelo inútil. Ocurre cuando hay una clase dominante. Si vuelve a mirar la matriz de confusión, podrá ver que la mayoría de los casos se clasifican como verdaderos negativos. Imagínese ahora, el modelo clasificó todas las clases como negativas (es decir, inferiores a 50k). Tendría una precisión del 75 por ciento (6718/6718+2257). Su modelo funciona mejor pero tiene dificultades para distinguir lo verdadero positivo de lo verdadero negativo.

En tal situación, es preferible tener una métrica más concisa. Podemos mirar:

  • Precisión=TP/(TP+FP)
  • Recuperar=TP/(TP+FN)

Precisión vs recuperación

Precisión analiza la precisión de la predicción positiva. Recordar es la proporción de instancias positivas que el clasificador detecta correctamente;

Puedes construir dos funciones para calcular estas dos métricas.

  1. Precisión de construcción
precision <- function(matrix) {
	# True positive
    tp <- matrix[2, 2]
	# false positive
    fp <- matrix[1, 2]
    return (tp / (tp + fp))
}

Explicación del código

  • mat[1,1]: Devuelve la primera celda de la primera columna del marco de datos, es decir, el verdadero positivo
  • estera[1,2]; Devuelve la primera celda de la segunda columna del marco de datos, es decir, el falso positivo.
recall <- function(matrix) {
# true positive
    tp <- matrix[2, 2]# false positive
    fn <- matrix[2, 1]
    return (tp / (tp + fn))
}

Explicación del código

  • mat[1,1]: Devuelve la primera celda de la primera columna del marco de datos, es decir, el verdadero positivo
  • estera[2,1]; Devuelve la segunda celda de la primera columna del marco de datos, es decir, el falso negativo.

Puedes probar tus funciones.

prec <- precision(table_mat)
prec
rec <- recall(table_mat)
rec

Salida:

## [1] 0.712877
## [2] 0.5336518

Cuando el modelo dice que es un individuo por encima de 50k, es correcto sólo en el 54 por ciento de los casos, y puede reclamar individuos por encima de 50k en el 72 por ciento de los casos.

Puede crear el Precisión vs recuperación Puntuación basada en la precisión y el recuerdo. El Precisión vs recuperación es una media armónica de estas dos métricas, lo que significa que da más peso a los valores más bajos.

Precisión vs recuperación

f1 <- 2 * ((prec * rec) / (prec + rec))
f1

Salida:

## [1] 0.6103799

Compensación entre precisión y recuperación

Es imposible tener al mismo tiempo una alta precisión y una alta recuperación.

Si aumentamos la precisión, se predecirá mejor el individuo correcto, pero perderíamos muchos de ellos (menor recuerdo). En algunas situaciones, preferimos una mayor precisión que la recuperación. Existe una relación cóncava entre precisión y recuperación.

  • Imagínese, necesita predecir si un paciente tiene una enfermedad. Quieres ser lo más preciso posible.
  • Si necesitas detectar posibles personas fraudulentas en la calle mediante el reconocimiento facial, sería mejor detectar a muchas personas etiquetadas como fraudulentas aunque la precisión sea baja. La policía podrá liberar al individuo no fraudulento.

La curva ROC

La directiva Receptor Operacaracterística de tintura La curva es otra herramienta común utilizada con la clasificación binaria. Es muy similar a la curva de precisión/recuperación, pero en lugar de representar gráficamente la precisión versus la recuperación, la curva ROC muestra la tasa de verdaderos positivos (es decir, recuperación) frente a la tasa de falsos positivos. La tasa de falsos positivos es la proporción de casos negativos que se clasifican incorrectamente como positivos. Es igual a uno menos la tasa negativa verdadera. La tasa negativa verdadera también se llama especificidad. Por lo tanto, los gráficos de la curva ROC sensibilidad (recuerdo) versus 1-especificidad

Para trazar la curva ROC, necesitamos instalar una biblioteca llamada RORC. Podemos encontrar en la conda. bibliotecas. Puedes escribir el código:

instalación de conda -c r r-rocr –sí

Podemos trazar la República de China con las funciones de predicción() y rendimiento().

library(ROCR)
ROCRpred <- prediction(predict, data_test$income)
ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')
plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))

Explicación del código

  • predicción (predecir, prueba_datos $ ingresos): la biblioteca ROCR necesita crear un objeto de predicción para transformar los datos de entrada
  • rendimiento (ROCRpred, 'tpr', 'fpr'): devuelve las dos combinaciones para producir en el gráfico. Aquí se construyen tpr y fpr. Para trazar la precisión y recuperar juntos, use "prec", "rec".

Salida:

La curva ROC

Paso 8) Mejorar el modelo

Puedes intentar agregar no linealidad al modelo con la interacción entre

  • Edad y horas por semana
  • género y horas por semana.

Debe utilizar la prueba de puntuación para comparar ambos modelos.

formula_2 <- income~age: hours.per.week + gender: hours.per.week + .
logit_2 <- glm(formula_2, data = data_train, family = 'binomial')
predict_2 <- predict(logit_2, data_test, type = 'response')
table_mat_2 <- table(data_test$income, predict_2 > 0.5)
precision_2 <- precision(table_mat_2)
recall_2 <- recall(table_mat_2)
f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))
f1_2

Salida:

## [1] 0.6109181

La puntuación es ligeramente superior a la anterior. Puedes seguir trabajando en los datos e intentar superar la puntuación.

Resum

Podemos resumir la función para entrenar una regresión logística en la siguiente tabla:

PREMIUM Objetivo Función Argumento
Crear conjunto de datos de entrenamiento/prueba crear_tren_set() datos, tamaño, tren
glm Entrenar un modelo lineal generalizado glm() fórmula, datos, familia*
glm Resumir el modelo resumen() modelo ajustado
bases Hacer predicción predecir() modelo ajustado, conjunto de datos, tipo = 'respuesta'
bases Crea una matriz de confusión mesa() y, predecir()
bases Crear puntuación de precisión suma(diag(tabla())/suma(tabla()
ROCR Crear ROC: Paso 1 Crear predicción predicción() predecir(), y
ROCR Crear ROC: Paso 2 Crear rendimiento actuación() predicción(), 'tpr', 'fpr'
ROCR Crear ROC: Paso 3 Trazar gráfico trama() actuación()

El otro GLM tipo de modelos son:

– binomial: (enlace = “logit”)

– gaussiano: (enlace = “identidad”)

– Gamma: (enlace = “inverso”)

– inverso.gaussiano: (enlace = “1/mu^2”)

– poisson: (enlace = “registro”)

– cuasi: (enlace = “identidad”, varianza = “constante”)

– cuasibinomial: (enlace = “logit”)

– quasipoisson: (enlace = “registro”)