GLM em R: modelo linear generalizado com exemplo

O que é regressão logística?

A regressão logística é usada para prever uma classe, ou seja, uma probabilidade. A regressão logística pode prever um resultado binário com precisão.

Imagine que você deseja prever se um empréstimo será negado/aceito com base em vários atributos. A regressão logística tem o formato 0/1. y = 0 se um empréstimo for rejeitado, y = 1 se for aceito.

Um modelo de regressão logística difere do modelo de regressão linear de duas maneiras.

  • Em primeiro lugar, a regressão logística aceita apenas dados dicotômicos (binários) como variável dependente (ou seja, um vetor de 0 e 1).
  • Em segundo lugar, o resultado é medido pela seguinte função de ligação probabilística chamada sigmóide devido ao seu formato em S.:

Regressão Logística

A saída da função está sempre entre 0 e 1. Verifique a imagem abaixo

Regressão Logística

A função sigmóide retorna valores de 0 a 1. Para a tarefa de classificação, precisamos de uma saída discreta de 0 ou 1.

Para converter um fluxo contínuo em valor discreto, podemos definir um limite de decisão em 0.5. Todos os valores acima deste limite são classificados como 1

Regressão Logística

Como criar um modelo de liner generalizado (GLM)

Vamos usar o adulto conjunto de dados para ilustrar a regressão logística. O “adulto” é um ótimo conjunto de dados para a tarefa de classificação. O objetivo é prever se a renda anual em dólares de um indivíduo ultrapassará 50.000. O conjunto de dados contém 46,033 observações e dez recursos:

  • idade: idade do indivíduo. Numérico
  • educação: Nível educacional do indivíduo. Fator.
  • estado civil: Mariestatuto social do indivíduo. Fator, ou seja, nunca casado, cônjuge civil casado, ...
  • gênero: Gênero do indivíduo. Fator, ou seja, Masculino ou Feminino
  • renda: Target variável. Renda acima ou abaixo de 50K. Fator, ou seja, >50K, <=50K

entre outros

library(dplyr)
data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")
glimpse(data_adult)

Saída:

Observations: 48,842
Variables: 10
$ x               <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,...
$ age             <int> 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26...
$ workclass       <fctr> Private, Private, Local-gov, Private, ?, Private,...
$ education       <fctr> 11th, HS-grad, Assoc-acdm, Some-college, Some-col...
$ educational.num <int> 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,...
$ marital.status  <fctr> Never-married, Married-civ-spouse, Married-civ-sp...
$ race            <fctr> Black, White, White, Black, White, White, Black, ...
$ gender          <fctr> Male, Male, Male, Male, Female, Male, Male, Male,...
$ hours.per.week  <int> 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39...
$ income          <fctr> <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5...

Procederemos da seguinte forma:

  • Passo 1: Verifique variáveis ​​contínuas
  • Passo 2: Verifique as variáveis ​​fatoriais
  • Etapa 3: engenharia de recursos
  • Etapa 4: estatística resumida
  • Etapa 5: conjunto de treinamento/teste
  • Etapa 6: construir o modelo
  • Etapa 7: Avalie o desempenho do modelo
  • etapa 8: Melhore o modelo

Sua tarefa é prever qual indivíduo terá uma receita superior a 50 mil.

Neste tutorial, cada etapa será detalhada para realizar uma análise em um conjunto de dados real.

Passo 1) Verifique variáveis ​​contínuas

Na primeira etapa, você pode ver a distribuição das variáveis ​​contínuas.

continuous <-select_if(data_adult, is.numeric)
summary(continuous)

Explicação do código

  • contínuo <- select_if(data_adult, is.numeric): Use a função select_if() da biblioteca dplyr para selecionar apenas as colunas numéricas
  • resumo (contínuo): Imprime a estatística do resumo

Saída:

##        X              age        educational.num hours.per.week 
##  Min.   :    1   Min.   :17.00   Min.   : 1.00   Min.   : 1.00  
##  1st Qu.:11509   1st Qu.:28.00   1st Qu.: 9.00   1st Qu.:40.00  
##  Median :23017   Median :37.00   Median :10.00   Median :40.00  
##  Mean   :23017   Mean   :38.56   Mean   :10.13   Mean   :40.95  
##  3rd Qu.:34525   3rd Qu.:47.00   3rd Qu.:13.00   3rd Qu.:45.00  
##  Max.   :46033   Max.   :90.00   Max.   :16.00   Max.   :99.00	

Na tabela acima, você pode ver que os dados têm escalas totalmente diferentes e horas por semana têm grandes discrepâncias (ou seja, observe o último quartil e o valor máximo).

Você pode lidar com isso seguindo duas etapas:

  • 1: Trace a distribuição de horas por semana
  • 2: Padronize as variáveis ​​contínuas
  1. Trace a distribuição

Vejamos mais de perto a distribuição de horas por semana

# Histogram with kernel density curve
library(ggplot2)
ggplot(continuous, aes(x = hours.per.week)) +
    geom_density(alpha = .2, fill = "#FF6666")

Saída:

Verifique variáveis ​​contínuas

A variável tem muitos valores discrepantes e distribuição não bem definida. Você pode resolver parcialmente esse problema excluindo 0.01% das principais horas da semana.

Sintaxe básica do quantil:

quantile(variable, percentile)
arguments:
-variable:  Select the variable in the data frame to compute the percentile
-percentile:  Can be a single value between 0 and 1 or multiple value. If multiple, use this format:  `c(A,B,C, ...)
- `A`,`B`,`C` and `...` are all integer from 0 to 1.

Calculamos o percentil dos 2 por cento superiores

top_one_percent <- quantile(data_adult$hours.per.week, .99)
top_one_percent

Explicação do código

  • quantile(data_adult$hours.per.week, .99): Calcule o valor de 99 por cento do tempo de trabalho

Saída:

## 99% 
##  80

98 por cento da população trabalha menos de 80 horas por semana.

Você pode deixar as observações acima desse limite. Você usa o filtro do dplyr biblioteca.

data_adult_drop <-data_adult %>%
filter(hours.per.week<top_one_percent)
dim(data_adult_drop)

Saída:

## [1] 45537    10
  1. Padronize as variáveis ​​contínuas

Você pode padronizar cada coluna para melhorar o desempenho porque seus dados não têm a mesma escala. Você pode usar a função mutate_if da biblioteca dplyr. A sintaxe básica é:

mutate_if(df, condition, funs(function))
arguments:
-`df`: Data frame used to compute the function
- `condition`: Statement used. Do not use parenthesis
- funs(function):  Return the function to apply. Do not use parenthesis for the function

Você pode padronizar as colunas numéricas da seguinte forma:

data_adult_rescale <- data_adult_drop % > %
	mutate_if(is.numeric, funs(as.numeric(scale(.))))
head(data_adult_rescale)

Explicação do código

  • mutate_if(is.numeric, funs(scale)): A condição é apenas coluna numérica e a função é escala

Saída:

##           X         age        workclass    education educational.num
## 1 -1.732680 -1.02325949          Private         11th     -1.22106443
## 2 -1.732605 -0.03969284          Private      HS-grad     -0.43998868
## 3 -1.732530 -0.79628257        Local-gov   Assoc-acdm      0.73162494
## 4 -1.732455  0.41426100          Private Some-college     -0.04945081
## 5 -1.732379 -0.34232873          Private         10th     -1.61160231
## 6 -1.732304  1.85178149 Self-emp-not-inc  Prof-school      1.90323857
##       marital.status  race gender hours.per.week income
## 1      Never-married Black   Male    -0.03995944  <=50K
## 2 Married-civ-spouse White   Male     0.86863037  <=50K
## 3 Married-civ-spouse White   Male    -0.03995944   >50K
## 4 Married-civ-spouse Black   Male    -0.03995944   >50K
## 5      Never-married White   Male    -0.94854924  <=50K
## 6 Married-civ-spouse White   Male    -0.76683128   >50K

Passo 2) Verifique as variáveis ​​fatoriais

Esta etapa tem dois objetivos:

  • Verifique o nível em cada coluna categórica
  • Defina novos níveis

Dividiremos esta etapa em três partes:

  • Selecione as colunas categóricas
  • Armazene o gráfico de barras de cada coluna em uma lista
  • Imprima os gráficos

Podemos selecionar as colunas de fator com o código abaixo:

# Select categorical column
factor <- data.frame(select_if(data_adult_rescale, is.factor))
	ncol(factor)

Explicação do código

  • data.frame(select_if(data_adult, is.factor)): Armazenamos as colunas de fator em factor em um tipo de quadro de dados. A biblioteca ggplot2 requer um objeto de quadro de dados.

Saída:

## [1] 6

O conjunto de dados contém 6 variáveis ​​​​categóricas

A segunda etapa é mais qualificada. Você deseja traçar um gráfico de barras para cada coluna no fator do quadro de dados. É mais conveniente automatizar o processo, principalmente em situações em que há muitas colunas.

library(ggplot2)
# Create graph for each column
graph <- lapply(names(factor),
    function(x) 
	ggplot(factor, aes(get(x))) +
		geom_bar() +
		theme(axis.text.x = element_text(angle = 90)))

Explicação do código

  • lapply(): Use a função lapply() para passar uma função em todas as colunas do conjunto de dados. Você armazena a saída em uma lista
  • function(x): A função será processada para cada x. Aqui x são as colunas
  • ggplot(factor, aes(get(x))) + geom_bar()+ theme(axis.text.x = element_text(angle = 90)): Crie um gráfico de barras para cada elemento x. Observe que para retornar x como uma coluna, você precisa incluí-lo dentro de get()

A última etapa é relativamente fácil. Você deseja imprimir os 6 gráficos.

# Print the graph
graph

Saída:

## [[1]]

Verifique variáveis ​​de fator

## ## [[2]]

Verifique variáveis ​​de fator

## ## [[3]]

Verifique variáveis ​​de fator

## ## [[4]]

Verifique variáveis ​​de fator

## ## [[5]]

Verifique variáveis ​​de fator

## ## [[6]]

Verifique variáveis ​​de fator

Nota: Use o botão seguinte para navegar para o próximo gráfico

Verifique variáveis ​​de fator

Etapa 3) Engenharia de recursos

Reformular a educação

No gráfico acima você pode perceber que a variável escolaridade possui 16 níveis. Isto é substancial e alguns níveis têm um número relativamente baixo de observações. Se quiser melhorar a quantidade de informações que pode obter dessa variável, você pode reformulá-la para um nível superior. Ou seja, você cria grupos maiores com nível de escolaridade semelhante. Por exemplo, o baixo nível de escolaridade será convertido em abandono escolar. Níveis mais elevados de educação serão alterados para mestrado.

Aqui está o detalhe:

Nível antigo Novo nível
Pré escola Dropout
sec 10 Cair fora
sec 11 Cair fora
sec 12 Cair fora
1º a 4º Cair fora
5th-6th Cair fora
7th-8th Cair fora
sec 9 Cair fora
HS-Graduação Alta graduação
Alguma faculdade Comunidade
Associado-acdm Comunidade
Associado Comunidade
Bacharelado Bacharelado
mestres mestres
Escola profissional mestres
Doutorado PhD
recast_data <- data_adult_rescale % > %
	select(-X) % > %
	mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",
    ifelse(education == "Bachelors", "Bachelors",
        ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))

Explicação do código

  • Usamos o verbo mutate da biblioteca dplyr. Mudamos os valores da educação com a afirmação ifelse

Na tabela abaixo, você cria uma estatística resumida para ver, em média, quantos anos de estudo (valor z) são necessários para se chegar ao bacharelado, mestrado ou doutorado.

recast_data % > %
	group_by(education) % > %
	summarize(average_educ_year = mean(educational.num),
		count = n()) % > %
	arrange(average_educ_year)

Saída:

## # A tibble: 6 x 3
## education average_educ_year count			
##      <fctr>             <dbl> <int>
## 1   dropout       -1.76147258  5712
## 2  HighGrad       -0.43998868 14803
## 3 Community        0.09561361 13407
## 4 Bachelors        1.12216282  7720
## 5    Master        1.60337381  3338
## 6       PhD        2.29377644   557

Reformulação Marital-status

Também é possível criar níveis inferiores para o estado civil. No código a seguir, você altera o nível da seguinte maneira:

Nível antigo Novo nível
Nunca casado Solteiro
Casado-cônjuge-ausente Solteiro
Casado-AF-cônjuge Casado
Cônjuge casada
Separado Separado
Divorciado
Viúvas Viúva
# Change level marry
recast_data <- recast_data % > %
	mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))

Você pode verificar o número de indivíduos dentro de cada grupo.

table(recast_data$marital.status)

Saída:

## ##     Married Not_married   Separated       Widow
##       21165       15359        7727        1286

Etapa 4) Estatística resumida

É hora de verificar algumas estatísticas sobre nossas variáveis-alvo. No gráfico abaixo, você conta a porcentagem de indivíduos que ganham mais de 50 mil de acordo com seu gênero.

# Plot gender income
ggplot(recast_data, aes(x = gender, fill = income)) +
    geom_bar(position = "fill") +
    theme_classic()

Saída:

Estatística resumida

A seguir, verifique se a origem do indivíduo afeta seus rendimentos.

# Plot origin income
ggplot(recast_data, aes(x = race, fill = income)) +
    geom_bar(position = "fill") +
    theme_classic() +
    theme(axis.text.x = element_text(angle = 90))

Saída:

Estatística resumida

O número de horas de trabalho por gênero.

# box plot gender working time
ggplot(recast_data, aes(x = gender, y = hours.per.week)) +
    geom_boxplot() +
    stat_summary(fun.y = mean,
        geom = "point",
        size = 3,
        color = "steelblue") +
    theme_classic()

Saída:

Estatística resumida

O box plot confirma que a distribuição do tempo de trabalho se ajusta a diferentes grupos. No box plot, ambos os sexos não possuem observações homogêneas.

Você pode verificar a densidade do tempo de trabalho semanal por tipo de ensino. As distribuições têm muitas escolhas distintas. Provavelmente pode ser explicado pelo tipo de contrato nos EUA.

# Plot distribution working time by education
ggplot(recast_data, aes(x = hours.per.week)) +
    geom_density(aes(color = education), alpha = 0.5) +
    theme_classic()

Explicação do código

  • ggplot(recast_data, aes( x= hours.per.week)): Um gráfico de densidade requer apenas uma variável
  • geom_density(aes(color = education), alpha =0.5): O objeto geométrico para controlar a densidade

Saída:

Estatística resumida

Para confirmar seus pensamentos, você pode realizar um teste unilateral Teste ANOVA:

anova <- aov(hours.per.week~education, recast_data)
summary(anova)

Saída:

##                Df Sum Sq Mean Sq F value Pr(>F)    
## education       5   1552  310.31   321.2 <2e-16 ***
## Residuals   45531  43984    0.97                   
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

O teste ANOVA confirma a diferença de média entre os grupos.

Não-linearidade

Antes de executar o modelo, você pode ver se o número de horas trabalhadas está relacionado à idade.

library(ggplot2)
ggplot(recast_data, aes(x = age, y = hours.per.week)) +
    geom_point(aes(color = income),
        size = 0.5) +
    stat_smooth(method = 'lm',
        formula = y~poly(x, 2),
        se = TRUE,
        aes(color = income)) +
    theme_classic()

Explicação do código

  • ggplot(recast_data, aes(x = age, y = hours.per.week)): Defina a estética do gráfico
  • geom_point(aes(color= income), size =0.5): Construa o gráfico de pontos
  • stat_smooth(): Adicione a linha de tendência com os seguintes argumentos:
    • method='lm': Plote o valor ajustado se o regressão linear
    • fórmula = y~poly(x,2): Ajustar uma regressão polinomial
    • se = TRUE: Adicione o erro padrão
    • aes(cor=renda): Divida o modelo por renda

Saída:

Não-linearidade

Resumindo, você pode testar os termos de interação no modelo para captar o efeito de não linearidade entre o horário de trabalho semanal e outros recursos. É importante detectar em que condições o tempo de trabalho difere.

Correlação

A próxima verificação é visualizar a correlação entre as variáveis. Você converte o tipo de nível de fator em numérico para poder traçar um mapa de calor contendo o coeficiente de correlação calculado com o método de Spearman.

library(GGally)
# Convert data to numeric
corr <- data.frame(lapply(recast_data, as.integer))
# Plot the graphggcorr(corr,
    method = c("pairwise", "spearman"),
    nbreaks = 6,
    hjust = 0.8,
    label = TRUE,
    label_size = 3,
    color = "grey50")

Explicação do código

  • data.frame(lapply(recast_data,as.integer)): Converte dados em numéricos
  • ggcorr() plote o mapa de calor com os seguintes argumentos:
    • método: Método para calcular a correlação
    • nbreaks = 6: Número de pausas
    • hjust = 0.8: posição de controle do nome da variável no gráfico
    • label = TRUE: Adicione rótulos no centro das janelas
    • label_size = 3: rótulos de tamanho
    • color = “grey50”): Cor da etiqueta

Saída:

Correlação

Etapa 5) Conjunto de treinamento/teste

Qualquer supervisionado aprendizado de máquina A tarefa exige dividir os dados entre um conjunto de treinamento e um conjunto de teste. Você pode usar a “função” criada em outros tutoriais de aprendizagem supervisionada para criar um conjunto de treinamento/teste.

set.seed(1234)
create_train_test <- function(data, size = 0.8, train = TRUE) {
    n_row = nrow(data)
    total_row = size * n_row
    train_sample <- 1: total_row
    if (train == TRUE) {
        return (data[train_sample, ])
    } else {
        return (data[-train_sample, ])
    }
}
data_train <- create_train_test(recast_data, 0.8, train = TRUE)
data_test <- create_train_test(recast_data, 0.8, train = FALSE)
dim(data_train)

Saída:

## [1] 36429     9
dim(data_test)

Saída:

## [1] 9108    9

Etapa 6) Construa o modelo

Para ver o desempenho do algoritmo, você usa o pacote glm(). O Modelo Linear Generalizado é uma coleção de modelos. A sintaxe básica é:

glm(formula, data=data, family=linkfunction()
Argument:
- formula:  Equation used to fit the model- data: dataset used
- Family:     - binomial: (link = "logit")			
- gaussian: (link = "identity")			
- Gamma:    (link = "inverse")			
- inverse.gaussian: (link = "1/mu^2")			
- poisson:  (link = "log")			
- quasi:    (link = "identity", variance = "constant")			
- quasibinomial:    (link = "logit")			
- quasipoisson: (link = "log")	

Você está pronto para estimar o modelo logístico para dividir o nível de renda entre um conjunto de recursos.

formula <- income~.
logit <- glm(formula, data = data_train, family = 'binomial')
summary(logit)

Explicação do código

  • fórmula <- renda ~ .: Crie o modelo para caber
  • logit <- glm(formula, data = data_train, family = 'binomial'): Ajusta um modelo logístico (family = 'binomial') com os dados data_train.
  • summary(logit): Imprime o resumo do modelo

Saída:

## 
## Call:
## glm(formula = formula, family = "binomial", data = data_train)
## ## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.6456  -0.5858  -0.2609  -0.0651   3.1982  
## 
## Coefficients:
##                           Estimate Std. Error z value Pr(>|z|)    
## (Intercept)                0.07882    0.21726   0.363  0.71675    
## age                        0.41119    0.01857  22.146  < 2e-16 ***
## workclassLocal-gov        -0.64018    0.09396  -6.813 9.54e-12 ***
## workclassPrivate          -0.53542    0.07886  -6.789 1.13e-11 ***
## workclassSelf-emp-inc     -0.07733    0.10350  -0.747  0.45499    
## workclassSelf-emp-not-inc -1.09052    0.09140 -11.931  < 2e-16 ***
## workclassState-gov        -0.80562    0.10617  -7.588 3.25e-14 ***
## workclassWithout-pay      -1.09765    0.86787  -1.265  0.20596    
## educationCommunity        -0.44436    0.08267  -5.375 7.66e-08 ***
## educationHighGrad         -0.67613    0.11827  -5.717 1.08e-08 ***
## educationMaster            0.35651    0.06780   5.258 1.46e-07 ***
## educationPhD               0.46995    0.15772   2.980  0.00289 ** 
## educationdropout          -1.04974    0.21280  -4.933 8.10e-07 ***
## educational.num            0.56908    0.07063   8.057 7.84e-16 ***
## marital.statusNot_married -2.50346    0.05113 -48.966  < 2e-16 ***
## marital.statusSeparated   -2.16177    0.05425 -39.846  < 2e-16 ***
## marital.statusWidow       -2.22707    0.12522 -17.785  < 2e-16 ***
## raceAsian-Pac-Islander     0.08359    0.20344   0.411  0.68117    
## raceBlack                  0.07188    0.19330   0.372  0.71001    
## raceOther                  0.01370    0.27695   0.049  0.96054    
## raceWhite                  0.34830    0.18441   1.889  0.05894 .  
## genderMale                 0.08596    0.04289   2.004  0.04506 *  
## hours.per.week             0.41942    0.01748  23.998  < 2e-16 ***
## ---## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## ## (Dispersion parameter for binomial family taken to be 1)
## ##     Null deviance: 40601  on 36428  degrees of freedom
## Residual deviance: 27041  on 36406  degrees of freedom
## AIC: 27087
## 
## Number of Fisher Scoring iterations: 6

O resumo do nosso modelo revela informações interessantes. O desempenho de uma regressão logística é avaliado com métricas-chave específicas.

  • AIC (Critérios de Informação Akaike): Isto é equivalente a R2 na regressão logística. Mede o ajuste quando uma penalidade é aplicada ao número de parâmetros. Menor AIC valores indicam que o modelo está mais próximo da verdade.
  • Desvio nulo: ajusta o modelo apenas com o intercepto. O grau de liberdade é n-1. Podemos interpretá-lo como um valor qui-quadrado (valor ajustado diferente do teste de hipótese de valor real).
  • Desvio Residual: Modelo com todas as variáveis. Também é interpretado como um teste de hipótese do qui-quadrado.
  • Número de iterações do Fisher Scoring: Número de iterações antes da convergência.

A saída da função glm() é armazenada em uma lista. O código abaixo mostra todos os itens disponíveis na variável logit que construímos para avaliar a regressão logística.

# A lista é muito longa, imprima apenas os três primeiros elementos

lapply(logit, class)[1:3]

Saída:

## $coefficients
## [1] "numeric"
## 
## $residuals
## [1] "numeric"
## 
## $fitted.values
## [1] "numeric"

Cada valor pode ser extraído com o sinal $ seguido do nome das métricas. Por exemplo, você armazenou o modelo como logit. Para extrair os critérios AIC, você usa:

logit$aic

Saída:

## [1] 27086.65

Etapa 7) Avalie o desempenho do modelo

Matriz de Confusão

A matriz de confusão é a melhor opção para avaliar o desempenho da classificação em comparação com as diferentes métricas que você viu antes. A ideia geral é contar o número de vezes que instâncias Verdadeiras são classificadas como Falsas.

Matriz de Confusão

Para calcular a matriz de confusão, primeiro você precisa ter um conjunto de previsões para que possam ser comparadas com os alvos reais.

predict <- predict(logit, data_test, type = 'response')
# confusion matrix
table_mat <- table(data_test$income, predict > 0.5)
table_mat

Explicação do código

  • predizer(logit,data_test, type = 'response'): Calcula a previsão no conjunto de teste. Defina type = 'response' para calcular a probabilidade de resposta.
  • table(data_test$income, prever > 0.5): Calcula a matriz de confusão. prever > 0.5 significa que retornará 1 se as probabilidades previstas estiverem acima de 0.5, caso contrário, 0.

Saída:

##        
##         FALSE TRUE
##   <=50K  6310  495
##   >50K   1074 1229	

Cada linha em uma matriz de confusão representa um alvo real, enquanto cada coluna representa um alvo previsto. A primeira linha desta matriz considera a renda inferior a 50k (a classe Falsa): 6241 foram corretamente classificados como indivíduos com renda inferior a 50k (Verdadeiro negativo), enquanto o restante foi erroneamente classificado como acima de 50k (Falso positivo). A segunda linha considera a renda acima de 50k, a classe positiva foi 1229 (Verdadeiro-positivo), Enquanto que o Verdadeiro negativo foi 1074.

Você pode calcular o modelo precisão somando o verdadeiro positivo + verdadeiro negativo sobre a observação total

Matriz de Confusão

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
accuracy_Test

Explicação do código

  • sum(diag(table_mat)): Soma da diagonal
  • sum(table_mat): Soma da matriz.

Saída:

## [1] 0.8277339

O modelo parece sofrer de um problema: superestima o número de falsos negativos. Isso é chamado de paradoxo do teste de precisão. Afirmamos que a precisão é a razão entre as previsões corretas e o número total de casos. Podemos ter uma precisão relativamente alta, mas um modelo inútil. Isso acontece quando há uma classe dominante. Se você olhar novamente para a matriz de confusão, verá que a maioria dos casos são classificados como verdadeiros negativos. Imagine agora, o modelo classificou todas as classes como negativas (ou seja, inferiores a 50k). Você teria uma precisão de 75 por cento (6718/6718+2257). Seu modelo tem melhor desempenho, mas tem dificuldade para distinguir o verdadeiro positivo do verdadeiro negativo.

Nessa situação, é preferível ter uma métrica mais concisa. Podemos olhar para:

  • Precisão=TP/(TP+FP)
  • Rechamada=TP/(TP+FN)

Precisão vs recall

Precisão analisa a precisão da previsão positiva. Recordar é a proporção de instâncias positivas que são detectadas corretamente pelo classificador;

Você pode construir duas funções para calcular essas duas métricas

  1. Precisão de construção
precision <- function(matrix) {
	# True positive
    tp <- matrix[2, 2]
	# false positive
    fp <- matrix[1, 2]
    return (tp / (tp + fp))
}

Explicação do código

  • mat[1,1]: Retorna a primeira célula da primeira coluna do quadro de dados, ou seja, o verdadeiro positivo
  • tapete[1,2]; Retorna a primeira célula da segunda coluna do data frame, ou seja, o falso positivo
recall <- function(matrix) {
# true positive
    tp <- matrix[2, 2]# false positive
    fn <- matrix[2, 1]
    return (tp / (tp + fn))
}

Explicação do código

  • mat[1,1]: Retorna a primeira célula da primeira coluna do quadro de dados, ou seja, o verdadeiro positivo
  • tapete[2,1]; Retorna a segunda célula da primeira coluna do data frame, ou seja, o falso negativo

Você pode testar suas funções

prec <- precision(table_mat)
prec
rec <- recall(table_mat)
rec

Saída:

## [1] 0.712877
## [2] 0.5336518

Quando o modelo diz que é um indivíduo acima de 50 mil, está correto em apenas 54% dos casos e pode reivindicar indivíduos acima de 50 mil em 72% dos casos.

Você pode criar o Precisão vs recall pontuação com base na precisão e recall. O Precisão vs recall é uma média harmônica dessas duas métricas, o que significa que dá mais peso aos valores mais baixos.

Precisão vs recall

f1 <- 2 * ((prec * rec) / (prec + rec))
f1

Saída:

## [1] 0.6103799

Troca entre precisão e recall

É impossível ter alta precisão e alto recall.

Se aumentarmos a precisão, o indivíduo correto será melhor previsto, mas perderemos muitos deles (menor recordação). Em algumas situações, preferimos maior precisão do que recall. Existe uma relação côncava entre precisão e recall.

  • Imagine, você precisa prever se um paciente tem alguma doença. Você quer ser o mais preciso possível.
  • Se você precisar detectar possíveis pessoas fraudulentas nas ruas por meio do reconhecimento facial, seria melhor capturar muitas pessoas rotuladas como fraudulentas, mesmo que a precisão seja baixa. A polícia poderá libertar o indivíduo não fraudulento.

A curva ROC

A recebedor Operacaracterística curva é outra ferramenta comum usada com classificação binária. É muito semelhante à curva de precisão/recall, mas em vez de representar graficamente precisão versus recall, a curva ROC mostra a taxa de verdadeiros positivos (ou seja, recall) contra a taxa de falsos positivos. A taxa de falsos positivos é a proporção de instâncias negativas que são classificadas incorretamente como positivas. É igual a um menos a taxa verdadeiramente negativa. A taxa verdadeiramente negativa também é chamada especificidade. Daí os gráficos da curva ROC sensibilidade (recall) versus 1 especificidade

Para traçar a curva ROC, precisamos instalar uma biblioteca chamada RORC. Podemos encontrar no conda biblioteca. Você pode digitar o código:

conda instalar -cr r-rocr –sim

Podemos traçar o ROC com as funções de previsão() e desempenho().

library(ROCR)
ROCRpred <- prediction(predict, data_test$income)
ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')
plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))

Explicação do código

  • previsão (prever, data_test$income): A biblioteca ROCR precisa criar um objeto de previsão para transformar os dados de entrada
  • performance(ROCRpred, 'tpr','fpr'): Retorna as duas combinações a serem produzidas no gráfico. Aqui, tpr e fpr são construídos. Para plotar precisão e recall juntos, use “prec”, “rec”.

Saída:

A curva ROC

Passo 8) Melhorar o modelo

Você pode tentar adicionar não linearidade ao modelo com a interação entre

  • idade e horas.por.semana
  • gênero e horas.por.semana.

Você precisa usar o teste de pontuação para comparar os dois modelos

formula_2 <- income~age: hours.per.week + gender: hours.per.week + .
logit_2 <- glm(formula_2, data = data_train, family = 'binomial')
predict_2 <- predict(logit_2, data_test, type = 'response')
table_mat_2 <- table(data_test$income, predict_2 > 0.5)
precision_2 <- precision(table_mat_2)
recall_2 <- recall(table_mat_2)
f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))
f1_2

Saída:

## [1] 0.6109181

A pontuação é um pouco superior à anterior. Você pode continuar trabalhando nos dados e tentar bater a pontuação.

Resumo

Podemos resumir a função para treinar uma regressão logística na tabela abaixo:

Pacote Objetivo função Argumento
- Criar conjunto de dados de treinamento/teste create_train_set() dados, tamanho, trem
glm Treine um modelo linear generalizado glm() fórmula, dados, família*
glm Resuma o modelo resumo() modelo ajustado
base Fazer previsão prever() modelo ajustado, conjunto de dados, tipo = 'resposta'
base Crie uma matriz de confusão mesa() sim, prever()
base Criar pontuação de precisão soma(diag(tabela())/soma(tabela()
ROCR Criar ROC: Etapa 1 Criar previsão predição() prever(), y
ROCR Criar ROC: Etapa 2 Criar desempenho desempenho() previsão(), 'tpr', 'fpr'
ROCR Criar ROC: Gráfico de plotagem da etapa 3 enredo() desempenho()

A outra GLM tipo de modelos são:

– binômio: (link = “logit”)

– gaussiano: (link = “identidade”)

– Gama: (link = “inverso”)

– inverso.gaussiano: (link = “1/mu^2”)

– poisson: (link = “log”)

– quase: (link = “identidade”, variância = “constante”)

– quasebinomial: (link = “logit”)

– quasepoisson: (link = “log”)