GLM em R: modelo linear generalizado com exemplo
O que é regressão logística?
A regressão logística é usada para prever uma classe, ou seja, uma probabilidade. A regressão logística pode prever um resultado binário com precisão.
Imagine que você deseja prever se um empréstimo será negado/aceito com base em vários atributos. A regressão logística tem o formato 0/1. y = 0 se um empréstimo for rejeitado, y = 1 se for aceito.
Um modelo de regressão logística difere do modelo de regressão linear de duas maneiras.
- Em primeiro lugar, a regressão logística aceita apenas dados dicotômicos (binários) como variável dependente (ou seja, um vetor de 0 e 1).
- Em segundo lugar, o resultado é medido pela seguinte função de ligação probabilística chamada sigmóide devido ao seu formato em S.:
A saída da função está sempre entre 0 e 1. Verifique a imagem abaixo
A função sigmóide retorna valores de 0 a 1. Para a tarefa de classificação, precisamos de uma saída discreta de 0 ou 1.
Para converter um fluxo contínuo em valor discreto, podemos definir um limite de decisão em 0.5. Todos os valores acima deste limite são classificados como 1
Como criar um modelo de liner generalizado (GLM)
Vamos usar o adulto conjunto de dados para ilustrar a regressão logística. O “adulto” é um ótimo conjunto de dados para a tarefa de classificação. O objetivo é prever se a renda anual em dólares de um indivíduo ultrapassará 50.000. O conjunto de dados contém 46,033 observações e dez recursos:
- idade: idade do indivíduo. Numérico
- educação: Nível educacional do indivíduo. Fator.
- estado civil: Mariestatuto social do indivíduo. Fator, ou seja, nunca casado, cônjuge civil casado, ...
- gênero: Gênero do indivíduo. Fator, ou seja, Masculino ou Feminino
- renda: Target variável. Renda acima ou abaixo de 50K. Fator, ou seja, >50K, <=50K
entre outros
library(dplyr) data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv") glimpse(data_adult)
Saída:
Observations: 48,842 Variables: 10 $ x <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,... $ age <int> 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26... $ workclass <fctr> Private, Private, Local-gov, Private, ?, Private,... $ education <fctr> 11th, HS-grad, Assoc-acdm, Some-college, Some-col... $ educational.num <int> 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,... $ marital.status <fctr> Never-married, Married-civ-spouse, Married-civ-sp... $ race <fctr> Black, White, White, Black, White, White, Black, ... $ gender <fctr> Male, Male, Male, Male, Female, Male, Male, Male,... $ hours.per.week <int> 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39... $ income <fctr> <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5...
Procederemos da seguinte forma:
- Passo 1: Verifique variáveis contínuas
- Passo 2: Verifique as variáveis fatoriais
- Etapa 3: engenharia de recursos
- Etapa 4: estatística resumida
- Etapa 5: conjunto de treinamento/teste
- Etapa 6: construir o modelo
- Etapa 7: Avalie o desempenho do modelo
- etapa 8: Melhore o modelo
Sua tarefa é prever qual indivíduo terá uma receita superior a 50 mil.
Neste tutorial, cada etapa será detalhada para realizar uma análise em um conjunto de dados real.
Passo 1) Verifique variáveis contínuas
Na primeira etapa, você pode ver a distribuição das variáveis contínuas.
continuous <-select_if(data_adult, is.numeric) summary(continuous)
Explicação do código
- contínuo <- select_if(data_adult, is.numeric): Use a função select_if() da biblioteca dplyr para selecionar apenas as colunas numéricas
- resumo (contínuo): Imprime a estatística do resumo
Saída:
## X age educational.num hours.per.week ## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00 ## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00 ## Median :23017 Median :37.00 Median :10.00 Median :40.00 ## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95 ## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00 ## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00
Na tabela acima, você pode ver que os dados têm escalas totalmente diferentes e horas por semana têm grandes discrepâncias (ou seja, observe o último quartil e o valor máximo).
Você pode lidar com isso seguindo duas etapas:
- 1: Trace a distribuição de horas por semana
- 2: Padronize as variáveis contínuas
- Trace a distribuição
Vejamos mais de perto a distribuição de horas por semana
# Histogram with kernel density curve library(ggplot2) ggplot(continuous, aes(x = hours.per.week)) + geom_density(alpha = .2, fill = "#FF6666")
Saída:
A variável tem muitos valores discrepantes e distribuição não bem definida. Você pode resolver parcialmente esse problema excluindo 0.01% das principais horas da semana.
Sintaxe básica do quantil:
quantile(variable, percentile) arguments: -variable: Select the variable in the data frame to compute the percentile -percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C, ...) - `A`,`B`,`C` and `...` are all integer from 0 to 1.
Calculamos o percentil dos 2 por cento superiores
top_one_percent <- quantile(data_adult$hours.per.week, .99) top_one_percent
Explicação do código
- quantile(data_adult$hours.per.week, .99): Calcule o valor de 99 por cento do tempo de trabalho
Saída:
## 99% ## 80
98 por cento da população trabalha menos de 80 horas por semana.
Você pode deixar as observações acima desse limite. Você usa o filtro do dplyr biblioteca.
data_adult_drop <-data_adult %>% filter(hours.per.week<top_one_percent) dim(data_adult_drop)
Saída:
## [1] 45537 10
- Padronize as variáveis contínuas
Você pode padronizar cada coluna para melhorar o desempenho porque seus dados não têm a mesma escala. Você pode usar a função mutate_if da biblioteca dplyr. A sintaxe básica é:
mutate_if(df, condition, funs(function)) arguments: -`df`: Data frame used to compute the function - `condition`: Statement used. Do not use parenthesis - funs(function): Return the function to apply. Do not use parenthesis for the function
Você pode padronizar as colunas numéricas da seguinte forma:
data_adult_rescale <- data_adult_drop % > % mutate_if(is.numeric, funs(as.numeric(scale(.)))) head(data_adult_rescale)
Explicação do código
- mutate_if(is.numeric, funs(scale)): A condição é apenas coluna numérica e a função é escala
Saída:
## X age workclass education educational.num ## 1 -1.732680 -1.02325949 Private 11th -1.22106443 ## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868 ## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494 ## 4 -1.732455 0.41426100 Private Some-college -0.04945081 ## 5 -1.732379 -0.34232873 Private 10th -1.61160231 ## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857 ## marital.status race gender hours.per.week income ## 1 Never-married Black Male -0.03995944 <=50K ## 2 Married-civ-spouse White Male 0.86863037 <=50K ## 3 Married-civ-spouse White Male -0.03995944 >50K ## 4 Married-civ-spouse Black Male -0.03995944 >50K ## 5 Never-married White Male -0.94854924 <=50K ## 6 Married-civ-spouse White Male -0.76683128 >50K
Passo 2) Verifique as variáveis fatoriais
Esta etapa tem dois objetivos:
- Verifique o nível em cada coluna categórica
- Defina novos níveis
Dividiremos esta etapa em três partes:
- Selecione as colunas categóricas
- Armazene o gráfico de barras de cada coluna em uma lista
- Imprima os gráficos
Podemos selecionar as colunas de fator com o código abaixo:
# Select categorical column factor <- data.frame(select_if(data_adult_rescale, is.factor)) ncol(factor)
Explicação do código
- data.frame(select_if(data_adult, is.factor)): Armazenamos as colunas de fator em factor em um tipo de quadro de dados. A biblioteca ggplot2 requer um objeto de quadro de dados.
Saída:
## [1] 6
O conjunto de dados contém 6 variáveis categóricas
A segunda etapa é mais qualificada. Você deseja traçar um gráfico de barras para cada coluna no fator do quadro de dados. É mais conveniente automatizar o processo, principalmente em situações em que há muitas colunas.
library(ggplot2) # Create graph for each column graph <- lapply(names(factor), function(x) ggplot(factor, aes(get(x))) + geom_bar() + theme(axis.text.x = element_text(angle = 90)))
Explicação do código
- lapply(): Use a função lapply() para passar uma função em todas as colunas do conjunto de dados. Você armazena a saída em uma lista
- function(x): A função será processada para cada x. Aqui x são as colunas
- ggplot(factor, aes(get(x))) + geom_bar()+ theme(axis.text.x = element_text(angle = 90)): Crie um gráfico de barras para cada elemento x. Observe que para retornar x como uma coluna, você precisa incluí-lo dentro de get()
A última etapa é relativamente fácil. Você deseja imprimir os 6 gráficos.
# Print the graph graph
Saída:
## [[1]]
## ## [[2]]
## ## [[3]]
## ## [[4]]
## ## [[5]]
## ## [[6]]
Nota: Use o botão seguinte para navegar para o próximo gráfico
Etapa 3) Engenharia de recursos
Reformular a educação
No gráfico acima você pode perceber que a variável escolaridade possui 16 níveis. Isto é substancial e alguns níveis têm um número relativamente baixo de observações. Se quiser melhorar a quantidade de informações que pode obter dessa variável, você pode reformulá-la para um nível superior. Ou seja, você cria grupos maiores com nível de escolaridade semelhante. Por exemplo, o baixo nível de escolaridade será convertido em abandono escolar. Níveis mais elevados de educação serão alterados para mestrado.
Aqui está o detalhe:
Nível antigo | Novo nível |
---|---|
Pré escola | Dropout |
sec 10 | Cair fora |
sec 11 | Cair fora |
sec 12 | Cair fora |
1º a 4º | Cair fora |
5th-6th | Cair fora |
7th-8th | Cair fora |
sec 9 | Cair fora |
HS-Graduação | Alta graduação |
Alguma faculdade | Comunidade |
Associado-acdm | Comunidade |
Associado | Comunidade |
Bacharelado | Bacharelado |
mestres | mestres |
Escola profissional | mestres |
Doutorado | PhD |
recast_data <- data_adult_rescale % > % select(-X) % > % mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community", ifelse(education == "Bachelors", "Bachelors", ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))
Explicação do código
- Usamos o verbo mutate da biblioteca dplyr. Mudamos os valores da educação com a afirmação ifelse
Na tabela abaixo, você cria uma estatística resumida para ver, em média, quantos anos de estudo (valor z) são necessários para se chegar ao bacharelado, mestrado ou doutorado.
recast_data % > % group_by(education) % > % summarize(average_educ_year = mean(educational.num), count = n()) % > % arrange(average_educ_year)
Saída:
## # A tibble: 6 x 3 ## education average_educ_year count ## <fctr> <dbl> <int> ## 1 dropout -1.76147258 5712 ## 2 HighGrad -0.43998868 14803 ## 3 Community 0.09561361 13407 ## 4 Bachelors 1.12216282 7720 ## 5 Master 1.60337381 3338 ## 6 PhD 2.29377644 557
Reformulação Marital-status
Também é possível criar níveis inferiores para o estado civil. No código a seguir, você altera o nível da seguinte maneira:
Nível antigo | Novo nível |
---|---|
Nunca casado | Solteiro |
Casado-cônjuge-ausente | Solteiro |
Casado-AF-cônjuge | Casado |
Cônjuge casada | |
Separado | Separado |
Divorciado | |
Viúvas | Viúva |
# Change level marry recast_data <- recast_data % > % mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))
Você pode verificar o número de indivíduos dentro de cada grupo.
table(recast_data$marital.status)
Saída:
## ## Married Not_married Separated Widow ## 21165 15359 7727 1286
Etapa 4) Estatística resumida
É hora de verificar algumas estatísticas sobre nossas variáveis-alvo. No gráfico abaixo, você conta a porcentagem de indivíduos que ganham mais de 50 mil de acordo com seu gênero.
# Plot gender income ggplot(recast_data, aes(x = gender, fill = income)) + geom_bar(position = "fill") + theme_classic()
Saída:
A seguir, verifique se a origem do indivíduo afeta seus rendimentos.
# Plot origin income ggplot(recast_data, aes(x = race, fill = income)) + geom_bar(position = "fill") + theme_classic() + theme(axis.text.x = element_text(angle = 90))
Saída:
O número de horas de trabalho por gênero.
# box plot gender working time ggplot(recast_data, aes(x = gender, y = hours.per.week)) + geom_boxplot() + stat_summary(fun.y = mean, geom = "point", size = 3, color = "steelblue") + theme_classic()
Saída:
O box plot confirma que a distribuição do tempo de trabalho se ajusta a diferentes grupos. No box plot, ambos os sexos não possuem observações homogêneas.
Você pode verificar a densidade do tempo de trabalho semanal por tipo de ensino. As distribuições têm muitas escolhas distintas. Provavelmente pode ser explicado pelo tipo de contrato nos EUA.
# Plot distribution working time by education ggplot(recast_data, aes(x = hours.per.week)) + geom_density(aes(color = education), alpha = 0.5) + theme_classic()
Explicação do código
- ggplot(recast_data, aes( x= hours.per.week)): Um gráfico de densidade requer apenas uma variável
- geom_density(aes(color = education), alpha =0.5): O objeto geométrico para controlar a densidade
Saída:
Para confirmar seus pensamentos, você pode realizar um teste unilateral Teste ANOVA:
anova <- aov(hours.per.week~education, recast_data) summary(anova)
Saída:
## Df Sum Sq Mean Sq F value Pr(>F) ## education 5 1552 310.31 321.2 <2e-16 *** ## Residuals 45531 43984 0.97 ## --- ## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
O teste ANOVA confirma a diferença de média entre os grupos.
Não-linearidade
Antes de executar o modelo, você pode ver se o número de horas trabalhadas está relacionado à idade.
library(ggplot2) ggplot(recast_data, aes(x = age, y = hours.per.week)) + geom_point(aes(color = income), size = 0.5) + stat_smooth(method = 'lm', formula = y~poly(x, 2), se = TRUE, aes(color = income)) + theme_classic()
Explicação do código
- ggplot(recast_data, aes(x = age, y = hours.per.week)): Defina a estética do gráfico
- geom_point(aes(color= income), size =0.5): Construa o gráfico de pontos
- stat_smooth(): Adicione a linha de tendência com os seguintes argumentos:
- method='lm': Plote o valor ajustado se o regressão linear
- fórmula = y~poly(x,2): Ajustar uma regressão polinomial
- se = TRUE: Adicione o erro padrão
- aes(cor=renda): Divida o modelo por renda
Saída:
Resumindo, você pode testar os termos de interação no modelo para captar o efeito de não linearidade entre o horário de trabalho semanal e outros recursos. É importante detectar em que condições o tempo de trabalho difere.
Correlação
A próxima verificação é visualizar a correlação entre as variáveis. Você converte o tipo de nível de fator em numérico para poder traçar um mapa de calor contendo o coeficiente de correlação calculado com o método de Spearman.
library(GGally) # Convert data to numeric corr <- data.frame(lapply(recast_data, as.integer)) # Plot the graphggcorr(corr, method = c("pairwise", "spearman"), nbreaks = 6, hjust = 0.8, label = TRUE, label_size = 3, color = "grey50")
Explicação do código
- data.frame(lapply(recast_data,as.integer)): Converte dados em numéricos
- ggcorr() plote o mapa de calor com os seguintes argumentos:
- método: Método para calcular a correlação
- nbreaks = 6: Número de pausas
- hjust = 0.8: posição de controle do nome da variável no gráfico
- label = TRUE: Adicione rótulos no centro das janelas
- label_size = 3: rótulos de tamanho
- color = “grey50”): Cor da etiqueta
Saída:
Etapa 5) Conjunto de treinamento/teste
Qualquer supervisionado aprendizado de máquina A tarefa exige dividir os dados entre um conjunto de treinamento e um conjunto de teste. Você pode usar a “função” criada em outros tutoriais de aprendizagem supervisionada para criar um conjunto de treinamento/teste.
set.seed(1234) create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample <- 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } } data_train <- create_train_test(recast_data, 0.8, train = TRUE) data_test <- create_train_test(recast_data, 0.8, train = FALSE) dim(data_train)
Saída:
## [1] 36429 9
dim(data_test)
Saída:
## [1] 9108 9
Etapa 6) Construa o modelo
Para ver o desempenho do algoritmo, você usa o pacote glm(). O Modelo Linear Generalizado é uma coleção de modelos. A sintaxe básica é:
glm(formula, data=data, family=linkfunction() Argument: - formula: Equation used to fit the model- data: dataset used - Family: - binomial: (link = "logit") - gaussian: (link = "identity") - Gamma: (link = "inverse") - inverse.gaussian: (link = "1/mu^2") - poisson: (link = "log") - quasi: (link = "identity", variance = "constant") - quasibinomial: (link = "logit") - quasipoisson: (link = "log")
Você está pronto para estimar o modelo logístico para dividir o nível de renda entre um conjunto de recursos.
formula <- income~. logit <- glm(formula, data = data_train, family = 'binomial') summary(logit)
Explicação do código
- fórmula <- renda ~ .: Crie o modelo para caber
- logit <- glm(formula, data = data_train, family = 'binomial'): Ajusta um modelo logístico (family = 'binomial') com os dados data_train.
- summary(logit): Imprime o resumo do modelo
Saída:
## ## Call: ## glm(formula = formula, family = "binomial", data = data_train) ## ## Deviance Residuals: ## Min 1Q Median 3Q Max ## -2.6456 -0.5858 -0.2609 -0.0651 3.1982 ## ## Coefficients: ## Estimate Std. Error z value Pr(>|z|) ## (Intercept) 0.07882 0.21726 0.363 0.71675 ## age 0.41119 0.01857 22.146 < 2e-16 *** ## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 *** ## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 *** ## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499 ## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 *** ## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 *** ## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596 ## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 *** ## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 *** ## educationMaster 0.35651 0.06780 5.258 1.46e-07 *** ## educationPhD 0.46995 0.15772 2.980 0.00289 ** ## educationdropout -1.04974 0.21280 -4.933 8.10e-07 *** ## educational.num 0.56908 0.07063 8.057 7.84e-16 *** ## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 *** ## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 *** ## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 *** ## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117 ## raceBlack 0.07188 0.19330 0.372 0.71001 ## raceOther 0.01370 0.27695 0.049 0.96054 ## raceWhite 0.34830 0.18441 1.889 0.05894 . ## genderMale 0.08596 0.04289 2.004 0.04506 * ## hours.per.week 0.41942 0.01748 23.998 < 2e-16 *** ## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 ## ## (Dispersion parameter for binomial family taken to be 1) ## ## Null deviance: 40601 on 36428 degrees of freedom ## Residual deviance: 27041 on 36406 degrees of freedom ## AIC: 27087 ## ## Number of Fisher Scoring iterations: 6
O resumo do nosso modelo revela informações interessantes. O desempenho de uma regressão logística é avaliado com métricas-chave específicas.
- AIC (Critérios de Informação Akaike): Isto é equivalente a R2 na regressão logística. Mede o ajuste quando uma penalidade é aplicada ao número de parâmetros. Menor AIC valores indicam que o modelo está mais próximo da verdade.
- Desvio nulo: ajusta o modelo apenas com o intercepto. O grau de liberdade é n-1. Podemos interpretá-lo como um valor qui-quadrado (valor ajustado diferente do teste de hipótese de valor real).
- Desvio Residual: Modelo com todas as variáveis. Também é interpretado como um teste de hipótese do qui-quadrado.
- Número de iterações do Fisher Scoring: Número de iterações antes da convergência.
A saída da função glm() é armazenada em uma lista. O código abaixo mostra todos os itens disponíveis na variável logit que construímos para avaliar a regressão logística.
# A lista é muito longa, imprima apenas os três primeiros elementos
lapply(logit, class)[1:3]
Saída:
## $coefficients ## [1] "numeric" ## ## $residuals ## [1] "numeric" ## ## $fitted.values ## [1] "numeric"
Cada valor pode ser extraído com o sinal $ seguido do nome das métricas. Por exemplo, você armazenou o modelo como logit. Para extrair os critérios AIC, você usa:
logit$aic
Saída:
## [1] 27086.65
Etapa 7) Avalie o desempenho do modelo
Matriz de Confusão
A matriz de confusão é a melhor opção para avaliar o desempenho da classificação em comparação com as diferentes métricas que você viu antes. A ideia geral é contar o número de vezes que instâncias Verdadeiras são classificadas como Falsas.
Para calcular a matriz de confusão, primeiro você precisa ter um conjunto de previsões para que possam ser comparadas com os alvos reais.
predict <- predict(logit, data_test, type = 'response') # confusion matrix table_mat <- table(data_test$income, predict > 0.5) table_mat
Explicação do código
- predizer(logit,data_test, type = 'response'): Calcula a previsão no conjunto de teste. Defina type = 'response' para calcular a probabilidade de resposta.
- table(data_test$income, prever > 0.5): Calcula a matriz de confusão. prever > 0.5 significa que retornará 1 se as probabilidades previstas estiverem acima de 0.5, caso contrário, 0.
Saída:
## ## FALSE TRUE ## <=50K 6310 495 ## >50K 1074 1229
Cada linha em uma matriz de confusão representa um alvo real, enquanto cada coluna representa um alvo previsto. A primeira linha desta matriz considera a renda inferior a 50k (a classe Falsa): 6241 foram corretamente classificados como indivíduos com renda inferior a 50k (Verdadeiro negativo), enquanto o restante foi erroneamente classificado como acima de 50k (Falso positivo). A segunda linha considera a renda acima de 50k, a classe positiva foi 1229 (Verdadeiro-positivo), Enquanto que o Verdadeiro negativo foi 1074.
Você pode calcular o modelo precisão somando o verdadeiro positivo + verdadeiro negativo sobre a observação total
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test
Explicação do código
- sum(diag(table_mat)): Soma da diagonal
- sum(table_mat): Soma da matriz.
Saída:
## [1] 0.8277339
O modelo parece sofrer de um problema: superestima o número de falsos negativos. Isso é chamado de paradoxo do teste de precisão. Afirmamos que a precisão é a razão entre as previsões corretas e o número total de casos. Podemos ter uma precisão relativamente alta, mas um modelo inútil. Isso acontece quando há uma classe dominante. Se você olhar novamente para a matriz de confusão, verá que a maioria dos casos são classificados como verdadeiros negativos. Imagine agora, o modelo classificou todas as classes como negativas (ou seja, inferiores a 50k). Você teria uma precisão de 75 por cento (6718/6718+2257). Seu modelo tem melhor desempenho, mas tem dificuldade para distinguir o verdadeiro positivo do verdadeiro negativo.
Nessa situação, é preferível ter uma métrica mais concisa. Podemos olhar para:
- Precisão=TP/(TP+FP)
- Rechamada=TP/(TP+FN)
Precisão vs recall
Precisão analisa a precisão da previsão positiva. Recordar é a proporção de instâncias positivas que são detectadas corretamente pelo classificador;
Você pode construir duas funções para calcular essas duas métricas
- Precisão de construção
precision <- function(matrix) { # True positive tp <- matrix[2, 2] # false positive fp <- matrix[1, 2] return (tp / (tp + fp)) }
Explicação do código
- mat[1,1]: Retorna a primeira célula da primeira coluna do quadro de dados, ou seja, o verdadeiro positivo
- tapete[1,2]; Retorna a primeira célula da segunda coluna do data frame, ou seja, o falso positivo
recall <- function(matrix) { # true positive tp <- matrix[2, 2]# false positive fn <- matrix[2, 1] return (tp / (tp + fn)) }
Explicação do código
- mat[1,1]: Retorna a primeira célula da primeira coluna do quadro de dados, ou seja, o verdadeiro positivo
- tapete[2,1]; Retorna a segunda célula da primeira coluna do data frame, ou seja, o falso negativo
Você pode testar suas funções
prec <- precision(table_mat) prec rec <- recall(table_mat) rec
Saída:
## [1] 0.712877 ## [2] 0.5336518
Quando o modelo diz que é um indivíduo acima de 50 mil, está correto em apenas 54% dos casos e pode reivindicar indivíduos acima de 50 mil em 72% dos casos.
Você pode criar o pontuação com base na precisão e recall. O
é uma média harmônica dessas duas métricas, o que significa que dá mais peso aos valores mais baixos.
f1 <- 2 * ((prec * rec) / (prec + rec)) f1
Saída:
## [1] 0.6103799
Troca entre precisão e recall
É impossível ter alta precisão e alto recall.
Se aumentarmos a precisão, o indivíduo correto será melhor previsto, mas perderemos muitos deles (menor recordação). Em algumas situações, preferimos maior precisão do que recall. Existe uma relação côncava entre precisão e recall.
- Imagine, você precisa prever se um paciente tem alguma doença. Você quer ser o mais preciso possível.
- Se você precisar detectar possíveis pessoas fraudulentas nas ruas por meio do reconhecimento facial, seria melhor capturar muitas pessoas rotuladas como fraudulentas, mesmo que a precisão seja baixa. A polícia poderá libertar o indivíduo não fraudulento.
A curva ROC
A recebedor Operacaracterística curva é outra ferramenta comum usada com classificação binária. É muito semelhante à curva de precisão/recall, mas em vez de representar graficamente precisão versus recall, a curva ROC mostra a taxa de verdadeiros positivos (ou seja, recall) contra a taxa de falsos positivos. A taxa de falsos positivos é a proporção de instâncias negativas que são classificadas incorretamente como positivas. É igual a um menos a taxa verdadeiramente negativa. A taxa verdadeiramente negativa também é chamada especificidade. Daí os gráficos da curva ROC sensibilidade (recall) versus 1 especificidade
Para traçar a curva ROC, precisamos instalar uma biblioteca chamada RORC. Podemos encontrar no conda biblioteca. Você pode digitar o código:
conda instalar -cr r-rocr –sim
Podemos traçar o ROC com as funções de previsão() e desempenho().
library(ROCR) ROCRpred <- prediction(predict, data_test$income) ROCRperf <- performance(ROCRpred, 'tpr', 'fpr') plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))
Explicação do código
- previsão (prever, data_test$income): A biblioteca ROCR precisa criar um objeto de previsão para transformar os dados de entrada
- performance(ROCRpred, 'tpr','fpr'): Retorna as duas combinações a serem produzidas no gráfico. Aqui, tpr e fpr são construídos. Para plotar precisão e recall juntos, use “prec”, “rec”.
Saída:
Passo 8) Melhorar o modelo
Você pode tentar adicionar não linearidade ao modelo com a interação entre
- idade e horas.por.semana
- gênero e horas.por.semana.
Você precisa usar o teste de pontuação para comparar os dois modelos
formula_2 <- income~age: hours.per.week + gender: hours.per.week + . logit_2 <- glm(formula_2, data = data_train, family = 'binomial') predict_2 <- predict(logit_2, data_test, type = 'response') table_mat_2 <- table(data_test$income, predict_2 > 0.5) precision_2 <- precision(table_mat_2) recall_2 <- recall(table_mat_2) f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2)) f1_2
Saída:
## [1] 0.6109181
A pontuação é um pouco superior à anterior. Você pode continuar trabalhando nos dados e tentar bater a pontuação.
Resumo
Podemos resumir a função para treinar uma regressão logística na tabela abaixo:
Pacote | Objetivo | função | Argumento |
---|---|---|---|
- | Criar conjunto de dados de treinamento/teste | create_train_set() | dados, tamanho, trem |
glm | Treine um modelo linear generalizado | glm() | fórmula, dados, família* |
glm | Resuma o modelo | resumo() | modelo ajustado |
base | Fazer previsão | prever() | modelo ajustado, conjunto de dados, tipo = 'resposta' |
base | Crie uma matriz de confusão | mesa() | sim, prever() |
base | Criar pontuação de precisão | soma(diag(tabela())/soma(tabela() | |
ROCR | Criar ROC: Etapa 1 Criar previsão | predição() | prever(), y |
ROCR | Criar ROC: Etapa 2 Criar desempenho | desempenho() | previsão(), 'tpr', 'fpr' |
ROCR | Criar ROC: Gráfico de plotagem da etapa 3 | enredo() | desempenho() |
A outra GLM tipo de modelos são:
– binômio: (link = “logit”)
– gaussiano: (link = “identidade”)
– Gama: (link = “inverso”)
– inverso.gaussiano: (link = “1/mu^2”)
– poisson: (link = “log”)
– quase: (link = “identidade”, variância = “constante”)
– quasebinomial: (link = “logit”)
– quasepoisson: (link = “log”)