GLM dans R : modèle linéaire généralisé avec exemple
Qu’est-ce que la régression logistique ?
La régression logistique est utilisée pour prédire une classe, c'est-à-dire une probabilité. La régression logistique peut prédire avec précision un résultat binaire.
Imaginez que vous souhaitiez prédire si un prêt est refusé/accepté en fonction de nombreux attributs. La régression logistique est de la forme 0/1. y = 0 si un prêt est rejeté, y = 1 s'il est accepté.
Un modèle de régression logistique diffère du modèle de régression linéaire de deux manières.
- Tout d'abord, la régression logistique n'accepte que les entrées dichotomiques (binaires) comme variable dépendante (c'est-à-dire un vecteur de 0 et 1).
- Deuxièmement, le résultat est mesuré par la fonction de lien probabiliste suivante appelée sigmoïde grâce à sa forme en S. :
La sortie de la fonction est toujours comprise entre 0 et 1. Vérifiez l'image ci-dessous
La fonction sigmoïde renvoie des valeurs de 0 à 1. Pour la tâche de classification, nous avons besoin d'une sortie discrète de 0 ou 1.
Pour convertir un flux continu en valeur discrète, nous pouvons fixer une limite de décision à 0.5. Toutes les valeurs supérieures à ce seuil sont classées 1
Comment créer un modèle de revêtement généralisé (GLM)
Utilisons le adulte ensemble de données pour illustrer la régression logistique. L’« adulte » constitue un excellent ensemble de données pour la tâche de classification. L'objectif est de prédire si le revenu annuel en dollars d'un individu dépassera 50.000. L'ensemble de données contient 46,033 observations et dix caractéristiques :
- âge : âge de l’individu. Numérique
- éducation : Niveau d’éducation de l’individu. Facteur.
- état matrimonial : Maristatut réel de l’individu. Facteur c'est-à-dire jamais marié, marié-conjoint civil,…
- genre : Sexe de l’individu. Facteur, c'est-à-dire masculin ou féminin
- le revenu: Target variable. Revenu supérieur ou inférieur à 50K. Facteur c'est-à-dire >50K, <=50K
entre autres
library(dplyr) data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv") glimpse(data_adult)
Sortie :
Observations: 48,842 Variables: 10 $ x <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,... $ age <int> 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26... $ workclass <fctr> Private, Private, Local-gov, Private, ?, Private,... $ education <fctr> 11th, HS-grad, Assoc-acdm, Some-college, Some-col... $ educational.num <int> 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,... $ marital.status <fctr> Never-married, Married-civ-spouse, Married-civ-sp... $ race <fctr> Black, White, White, Black, White, White, Black, ... $ gender <fctr> Male, Male, Male, Male, Female, Male, Male, Male,... $ hours.per.week <int> 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39... $ income <fctr> <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5...
Nous procéderons de la manière suivante :
- Étape 1 : Vérifiez les variables continues
- Étape 2 : Vérifier les variables factorielles
- Étape 3 : Ingénierie des fonctionnalités
- Étape 4 : Statistique récapitulative
- Étape 5 : Former/tester l’ensemble
- Étape 6 : Construire le modèle
- Étape 7 : Évaluer les performances du modèle
- étape 8 : Améliorer le modèle
Votre tâche consiste à prédire quelle personne aura un revenu supérieur à 50 $.
Dans ce tutoriel, chaque étape sera détaillée pour effectuer une analyse sur un jeu de données réel.
Étape 1) Vérifiez les variables continues
Dans la première étape, vous pouvez voir la distribution des variables continues.
continuous <-select_if(data_adult, is.numeric) summary(continuous)
Explication du code
- continu <- select_if(data_adult, is.numeric) : utilisez la fonction select_if() de la bibliothèque dplyr pour sélectionner uniquement les colonnes numériques
- résumé (continu) : Imprimer la statistique récapitulative
Sortie :
## X age educational.num hours.per.week ## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00 ## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00 ## Median :23017 Median :37.00 Median :10.00 Median :40.00 ## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95 ## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00 ## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00
Dans le tableau ci-dessus, vous pouvez voir que les données ont des échelles totalement différentes et que les heures par semaine présentent de grandes valeurs aberrantes (c'est-à-dire regardez le dernier quartile et la valeur maximale).
Vous pouvez y faire face en suivant deux étapes :
- 1 : Tracer la répartition des heures.par.semaine
- 2 : Standardiser les variables continues
- Tracer la distribution
Regardons de plus près la répartition des heures par semaine
# Histogram with kernel density curve library(ggplot2) ggplot(continuous, aes(x = hours.per.week)) + geom_density(alpha = .2, fill = "#FF6666")
Sortie :
La variable présente de nombreuses valeurs aberrantes et une distribution mal définie. Vous pouvez résoudre partiellement ce problème en supprimant les 0.01 pour cent des heures les plus élevées par semaine.
Syntaxe de base du quantile :
quantile(variable, percentile) arguments: -variable: Select the variable in the data frame to compute the percentile -percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C, ...) - `A`,`B`,`C` and `...` are all integer from 0 to 1.
Nous calculons le centile supérieur de 2 pour cent
top_one_percent <- quantile(data_adult$hours.per.week, .99) top_one_percent
Explication du code
- quantile (data_adult$hours.per.week, .99) : calcule la valeur de 99 % du temps de travail
Sortie :
## 99% ## 80
98 pour cent de la population travaille moins de 80 heures par semaine.
Vous pouvez déposer les observations au-dessus de ce seuil. Vous utilisez le filtre du déplyr bibliothèque.
data_adult_drop <-data_adult %>% filter(hours.per.week<top_one_percent) dim(data_adult_drop)
Sortie :
## [1] 45537 10
- Standardiser les variables continues
Vous pouvez standardiser chaque colonne pour améliorer les performances car vos données n'ont pas la même échelle. Vous pouvez utiliser la fonction mutate_if de la bibliothèque dplyr. La syntaxe de base est la suivante :
mutate_if(df, condition, funs(function)) arguments: -`df`: Data frame used to compute the function - `condition`: Statement used. Do not use parenthesis - funs(function): Return the function to apply. Do not use parenthesis for the function
Vous pouvez standardiser les colonnes numériques comme suit :
data_adult_rescale <- data_adult_drop % > % mutate_if(is.numeric, funs(as.numeric(scale(.)))) head(data_adult_rescale)
Explication du code
- mutate_if(is.numeric, funs(scale)) : La condition est uniquement une colonne numérique et la fonction est une échelle
Sortie :
## X age workclass education educational.num ## 1 -1.732680 -1.02325949 Private 11th -1.22106443 ## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868 ## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494 ## 4 -1.732455 0.41426100 Private Some-college -0.04945081 ## 5 -1.732379 -0.34232873 Private 10th -1.61160231 ## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857 ## marital.status race gender hours.per.week income ## 1 Never-married Black Male -0.03995944 <=50K ## 2 Married-civ-spouse White Male 0.86863037 <=50K ## 3 Married-civ-spouse White Male -0.03995944 >50K ## 4 Married-civ-spouse Black Male -0.03995944 >50K ## 5 Never-married White Male -0.94854924 <=50K ## 6 Married-civ-spouse White Male -0.76683128 >50K
Étape 2) Vérifiez les variables factorielles
Cette étape a deux objectifs :
- Vérifiez le niveau dans chaque colonne catégorielle
- Définir de nouveaux niveaux
Nous diviserons cette étape en trois parties :
- Sélectionnez les colonnes catégorielles
- Stocker le graphique à barres de chaque colonne dans une liste
- Imprimer les graphiques
Nous pouvons sélectionner les colonnes de facteurs avec le code ci-dessous :
# Select categorical column factor <- data.frame(select_if(data_adult_rescale, is.factor)) ncol(factor)
Explication du code
- data.frame(select_if(data_adult, is.factor)) : Nous stockons les colonnes de facteurs dans factor dans un type de trame de données. La bibliothèque ggplot2 nécessite un objet data frame.
Sortie :
## [1] 6
L'ensemble de données contient 6 variables catégorielles
La deuxième étape est plus qualifiée. Vous souhaitez tracer un graphique à barres pour chaque colonne du facteur de bloc de données. Il est plus pratique d’automatiser le processus, surtout dans les cas où il y a beaucoup de colonnes.
library(ggplot2) # Create graph for each column graph <- lapply(names(factor), function(x) ggplot(factor, aes(get(x))) + geom_bar() + theme(axis.text.x = element_text(angle = 90)))
Explication du code
- lapply() : Utilisez la fonction lapply() pour passer une fonction dans toutes les colonnes de l'ensemble de données. Vous stockez la sortie dans une liste
- function(x) : La fonction sera traitée pour chaque x. Ici x sont les colonnes
- ggplot(factor, aes(get(x))) + geom_bar()+ theme(axis.text.x = element_text(angle = 90)) : Créez un graphique à barres pour chaque élément x. Notez que pour renvoyer x sous forme de colonne, vous devez l'inclure dans get()
La dernière étape est relativement simple. Vous souhaitez imprimer les 6 graphiques.
# Print the graph graph
Sortie :
## [[1]]
## ## [[2]]
## ## [[3]]
## ## [[4]]
## ## [[5]]
## ## [[6]]
Remarque : utilisez le bouton suivant pour accéder au graphique suivant.
Étape 3) Ingénierie des fonctionnalités
Refondre l'éducation
Sur le graphique ci-dessus, vous pouvez voir que la variable éducation comporte 16 niveaux. C'est important et certains niveaux ont un nombre d'observations relativement faible. Si vous souhaitez améliorer la quantité d'informations que vous pouvez obtenir de cette variable, vous pouvez la redéfinir à un niveau supérieur. Autrement dit, vous créez des groupes plus grands avec un niveau d’éducation similaire. Par exemple, un faible niveau d’éducation se transformera en abandon scolaire. Les niveaux d'enseignement supérieurs seront remplacés par le master.
Voici le détail :
Ancien niveau | Nouveau niveau |
---|---|
Maternelle | Dropout |
10 | marginal |
11 | marginal |
12 | marginal |
1er-4e | marginal |
5th-6th | marginal |
7th-8th | marginal |
9 | marginal |
HS-Grad | Diplômé supérieur |
Un collège | Communauté |
Assoc-acdm | Communauté |
Assoc-voc | Communauté |
Les bacheliers | Les bacheliers |
Maîtrise | Maîtrise |
École professionnelle | Maîtrise |
Doctorat | PhD |
recast_data <- data_adult_rescale % > % select(-X) % > % mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community", ifelse(education == "Bachelors", "Bachelors", ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))
Explication du code
- Nous utilisons le verbe muter de la bibliothèque dplyr. Nous changeons les valeurs de l'éducation avec la déclaration ifelse
Dans le tableau ci-dessous, vous créez une statistique récapitulative pour voir, en moyenne, combien d'années d'études (valeur z) il faut pour atteindre le baccalauréat, le master ou le doctorat.
recast_data % > % group_by(education) % > % summarize(average_educ_year = mean(educational.num), count = n()) % > % arrange(average_educ_year)
Sortie :
## # A tibble: 6 x 3 ## education average_educ_year count ## <fctr> <dbl> <int> ## 1 dropout -1.76147258 5712 ## 2 HighGrad -0.43998868 14803 ## 3 Community 0.09561361 13407 ## 4 Bachelors 1.12216282 7720 ## 5 Master 1.60337381 3338 ## 6 PhD 2.29377644 557
Refonte Maristatut-tal
Il est également possible de créer des niveaux inférieurs pour l'état civil. Dans le code suivant, vous modifiez le niveau comme suit :
Ancien niveau | Nouveau niveau |
---|---|
Jamais marié | Pas marié |
Marié-conjoint-absent | Pas marié |
Marié-AF-conjoint | Marié |
Marié-civ-conjoint | |
Séparé | Séparé |
Divorcé | |
Veuves | Veuve |
# Change level marry recast_data <- recast_data % > % mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))
Vous pouvez vérifier le nombre d'individus au sein de chaque groupe.
table(recast_data$marital.status)
Sortie :
## ## Married Not_married Separated Widow ## 21165 15359 7727 1286
Étape 4) Statistique récapitulative
Il est temps de vérifier quelques statistiques sur nos variables cibles. Dans le graphique ci-dessous, vous comptez le pourcentage d'individus gagnant plus de 50 $ en fonction de leur sexe.
# Plot gender income ggplot(recast_data, aes(x = gender, fill = income)) + geom_bar(position = "fill") + theme_classic()
Sortie :
Ensuite, vérifiez si l’origine de l’individu affecte ses revenus.
# Plot origin income ggplot(recast_data, aes(x = race, fill = income)) + geom_bar(position = "fill") + theme_classic() + theme(axis.text.x = element_text(angle = 90))
Sortie :
Le nombre d’heures travaillées par sexe.
# box plot gender working time ggplot(recast_data, aes(x = gender, y = hours.per.week)) + geom_boxplot() + stat_summary(fun.y = mean, geom = "point", size = 3, color = "steelblue") + theme_classic()
Sortie :
Le diagramme en boîte confirme que la répartition du temps de travail correspond à différents groupes. Dans le box plot, les deux sexes n’ont pas d’observations homogènes.
Vous pouvez vérifier la densité du temps de travail hebdomadaire par type d'enseignement. Les distributions ont de nombreux choix distincts. Cela peut probablement s’expliquer par le type de contrat aux États-Unis.
# Plot distribution working time by education ggplot(recast_data, aes(x = hours.per.week)) + geom_density(aes(color = education), alpha = 0.5) + theme_classic()
Explication du code
- ggplot(recast_data, aes( x= hours.per.week)) : un tracé de densité ne nécessite qu'une seule variable
- geom_density(aes(color = education), alpha =0.5) : L'objet géométrique pour contrôler la densité
Sortie :
Pour confirmer vos pensées, vous pouvez effectuer un aller simple Test ANOVA:
anova <- aov(hours.per.week~education, recast_data) summary(anova)
Sortie :
## Df Sum Sq Mean Sq F value Pr(>F) ## education 5 1552 310.31 321.2 <2e-16 *** ## Residuals 45531 43984 0.97 ## --- ## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Le test ANOVA confirme la différence de moyenne entre les groupes.
Non-linéarité
Avant d'exécuter le modèle, vous pouvez voir si le nombre d'heures travaillées est lié à l'âge.
library(ggplot2) ggplot(recast_data, aes(x = age, y = hours.per.week)) + geom_point(aes(color = income), size = 0.5) + stat_smooth(method = 'lm', formula = y~poly(x, 2), se = TRUE, aes(color = income)) + theme_classic()
Explication du code
- ggplot(recast_data, aes(x = age, y = hours.per.week)) : Définir l'esthétique du graphique
- geom_point(aes(color= revenue), size =0.5) : Construire le diagramme de points
- stat_smooth() : ajoutez la ligne de tendance avec les arguments suivants :
- method='lm' : tracez la valeur ajustée si le régression linéaire
- formule = y~poly(x,2) : Ajuster une régression polynomiale
- se = TRUE : Ajouter l'erreur standard
- aes(color= revenue) : décomposer le modèle par revenu
Sortie :
En un mot, vous pouvez tester les termes d'interaction dans le modèle pour détecter l'effet de non-linéarité entre le temps de travail hebdomadaire et d'autres caractéristiques. Il est important de détecter dans quelles conditions le temps de travail diffère.
Corrélation
La vérification suivante consiste à visualiser la corrélation entre les variables. Vous convertissez le type de niveau de facteur en numérique afin de pouvoir tracer une carte thermique contenant le coefficient de corrélation calculé avec la méthode Spearman.
library(GGally) # Convert data to numeric corr <- data.frame(lapply(recast_data, as.integer)) # Plot the graphggcorr(corr, method = c("pairwise", "spearman"), nbreaks = 6, hjust = 0.8, label = TRUE, label_size = 3, color = "grey50")
Explication du code
- data.frame (lapply (recast_data, as.integer)): Convertir les données en numérique
- ggcorr() trace la carte thermique avec les arguments suivants :
- method : Méthode pour calculer la corrélation
- nbreaks = 6 : Nombre de pauses
- hjust = 0.8 : Position de contrôle du nom de la variable dans le tracé
- label = TRUE : Ajoute des étiquettes au centre des fenêtres
- label_size = 3 : étiquettes de taille
- color = « grey50 » : Couleur de l’étiquette
Sortie :
Étape 5) Entraîner/tester l’ensemble
Tout surveillé machine learning La tâche nécessite de diviser les données entre une rame et une rame de test. Vous pouvez utiliser la « fonction » que vous avez créée dans les autres tutoriels d’apprentissage supervisé pour créer un ensemble d’entraînement/test.
set.seed(1234) create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample <- 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } } data_train <- create_train_test(recast_data, 0.8, train = TRUE) data_test <- create_train_test(recast_data, 0.8, train = FALSE) dim(data_train)
Sortie :
## [1] 36429 9
dim(data_test)
Sortie :
## [1] 9108 9
Étape 6) Construire le modèle
Pour voir comment l'algorithme fonctionne, vous utilisez le package glm(). Le Modèle linéaire généralisé est une collection de modèles. La syntaxe de base est la suivante :
glm(formula, data=data, family=linkfunction() Argument: - formula: Equation used to fit the model- data: dataset used - Family: - binomial: (link = "logit") - gaussian: (link = "identity") - Gamma: (link = "inverse") - inverse.gaussian: (link = "1/mu^2") - poisson: (link = "log") - quasi: (link = "identity", variance = "constant") - quasibinomial: (link = "logit") - quasipoisson: (link = "log")
Vous êtes prêt à estimer le modèle logistique pour répartir le niveau de revenu entre un ensemble de fonctionnalités.
formula <- income~. logit <- glm(formula, data = data_train, family = 'binomial') summary(logit)
Explication du code
- formule <- revenu ~ .: Créer le modèle pour s'adapter
- logit <- glm(formula, data = data_train, family = 'binomial') : Ajustez un modèle logistique (family = 'binomial') avec les données data_train.
- summary(logit) : Imprimer le résumé du modèle
Sortie :
## ## Call: ## glm(formula = formula, family = "binomial", data = data_train) ## ## Deviance Residuals: ## Min 1Q Median 3Q Max ## -2.6456 -0.5858 -0.2609 -0.0651 3.1982 ## ## Coefficients: ## Estimate Std. Error z value Pr(>|z|) ## (Intercept) 0.07882 0.21726 0.363 0.71675 ## age 0.41119 0.01857 22.146 < 2e-16 *** ## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 *** ## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 *** ## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499 ## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 *** ## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 *** ## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596 ## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 *** ## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 *** ## educationMaster 0.35651 0.06780 5.258 1.46e-07 *** ## educationPhD 0.46995 0.15772 2.980 0.00289 ** ## educationdropout -1.04974 0.21280 -4.933 8.10e-07 *** ## educational.num 0.56908 0.07063 8.057 7.84e-16 *** ## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 *** ## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 *** ## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 *** ## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117 ## raceBlack 0.07188 0.19330 0.372 0.71001 ## raceOther 0.01370 0.27695 0.049 0.96054 ## raceWhite 0.34830 0.18441 1.889 0.05894 . ## genderMale 0.08596 0.04289 2.004 0.04506 * ## hours.per.week 0.41942 0.01748 23.998 < 2e-16 *** ## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 ## ## (Dispersion parameter for binomial family taken to be 1) ## ## Null deviance: 40601 on 36428 degrees of freedom ## Residual deviance: 27041 on 36406 degrees of freedom ## AIC: 27087 ## ## Number of Fisher Scoring iterations: 6
Le résumé de notre modèle révèle des informations intéressantes. Les performances d'une régression logistique sont évaluées avec des mesures clés spécifiques.
- AIC (Akaike Information Criteria) : C'est l'équivalent de R2 en régression logistique. Il mesure l’ajustement lorsqu’une pénalité est appliquée au nombre de paramètres. Plus petit AIC les valeurs indiquent que le modèle est plus proche de la vérité.
- Déviance nulle : s'adapte au modèle uniquement avec l'interception. Le degré de liberté est n-1. Nous pouvons l'interpréter comme une valeur du Chi carré (valeur ajustée différente du test d'hypothèse de valeur réelle).
- Déviance résiduelle : Modèle avec toutes les variables. Il est également interprété comme un test d’hypothèse du chi carré.
- Nombre d'itérations de Fisher Scoring : nombre d'itérations avant la convergence.
La sortie de la fonction glm() est stockée dans une liste. Le code ci-dessous montre tous les éléments disponibles dans la variable logit que nous avons construite pour évaluer la régression logistique.
# La liste est très longue, n'imprimez que les trois premiers éléments
lapply(logit, class)[1:3]
Sortie :
## $coefficients ## [1] "numeric" ## ## $residuals ## [1] "numeric" ## ## $fitted.values ## [1] "numeric"
Chaque valeur peut être extraite avec le signe $ suivi du nom de la métrique. Par exemple, vous avez stocké le modèle sous forme de logit. Pour extraire les critères AIC, vous utilisez :
logit$aic
Sortie :
## [1] 27086.65
Étape 7) Évaluer les performances du modèle
Matrice de confusion
Les matrice de confusion est un meilleur choix pour évaluer les performances de classification par rapport aux différentes mesures que vous avez vues auparavant. L'idée générale est de compter le nombre de fois où les instances vraies sont classées comme étant fausses.
Pour calculer la matrice de confusion, vous devez d’abord disposer d’un ensemble de prédictions afin qu’elles puissent être comparées aux cibles réelles.
predict <- predict(logit, data_test, type = 'response') # confusion matrix table_mat <- table(data_test$income, predict > 0.5) table_mat
Explication du code
- Predict(logit,data_test, type = 'response') : calcule la prédiction sur l'ensemble de test. Définissez type = 'response' pour calculer la probabilité de réponse.
- table(data_test$ Income, Predict > 0.5) : Calculez la matrice de confusion. prédire > 0.5 signifie qu'il renvoie 1 si les probabilités prédites sont supérieures à 0.5, sinon 0.
Sortie :
## ## FALSE TRUE ## <=50K 6310 495 ## >50K 1074 1229
Chaque ligne d'une matrice de confusion représente une cible réelle, tandis que chaque colonne représente une cible prévue. La première ligne de cette matrice considère les revenus inférieurs à 50 6241 (la classe Faux) : 50 ont été correctement classés comme individus ayant un revenu inférieur à (Vrai négatif), tandis que le reste a été classé à tort comme supérieur à 50 (Faux positif). La deuxième ligne considère les revenus supérieurs à 50 1229, la classe positive étant de (Vrai positif), tandis que le Vrai négatif était 1074.
Vous pouvez calculer le modèle précision en additionnant les vrais positifs + les vrais négatifs sur l'observation totale
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test
Explication du code
- sum(diag(table_mat)) : Somme de la diagonale
- sum(table_mat) : Somme de la matrice.
Sortie :
## [1] 0.8277339
Le modèle semble souffrir d’un problème : il surestime le nombre de faux négatifs. C'est ce qu'on appelle le paradoxe du test de précision. Nous avons déclaré que l'exactitude est le rapport entre les prédictions correctes et le nombre total de cas. Nous pouvons avoir une précision relativement élevée mais un modèle inutile. Cela se produit lorsqu’il existe une classe dominante. Si vous regardez la matrice de confusion, vous pouvez voir que la plupart des cas sont classés comme vrais négatifs. Imaginez maintenant, le modèle classe toutes les classes comme négatives (c'est-à-dire inférieures à 50 75). Vous auriez une précision de 6718 pour cent (6718/2257+). Votre modèle fonctionne mieux mais a du mal à distinguer le vrai positif du vrai négatif.
Dans une telle situation, il est préférable d’avoir une métrique plus concise. Nous pouvons regarder :
- Précision=TP/(TP+FP)
- Rappel=TP/(TP+FN)
Précision vs rappel
La précision examine l’exactitude de la prédiction positive. Rappeler est le ratio d'instances positives correctement détectées par le classificateur ;
Vous pouvez construire deux fonctions pour calculer ces deux métriques
- Construire avec précision
precision <- function(matrix) { # True positive tp <- matrix[2, 2] # false positive fp <- matrix[1, 2] return (tp / (tp + fp)) }
Explication du code
- mat[1,1] : Renvoie la première cellule de la première colonne de la trame de données, c'est à dire le vrai positif
- tapis[1,2]; Renvoie la première cellule de la deuxième colonne de la trame de données, c'est à dire le faux positif
recall <- function(matrix) { # true positive tp <- matrix[2, 2]# false positive fn <- matrix[2, 1] return (tp / (tp + fn)) }
Explication du code
- mat[1,1] : Renvoie la première cellule de la première colonne de la trame de données, c'est à dire le vrai positif
- tapis[2,1]; Renvoie la deuxième cellule de la première colonne de la trame de données, c'est à dire le faux négatif
Vous pouvez tester vos fonctions
prec <- precision(table_mat) prec rec <- recall(table_mat) rec
Sortie :
## [1] 0.712877 ## [2] 0.5336518
Lorsque le modèle indique qu'il s'agit d'un individu de plus de 50 54 ans, il est correct dans seulement 50 % des cas et peut revendiquer des individus de plus de 72 ans dans % des cas.
Vous pouvez créer le score basé sur la précision et le rappel. Le
est une moyenne harmonique de ces deux mesures, ce qui signifie qu'elle donne plus de poids aux valeurs les plus faibles.
f1 <- 2 * ((prec * rec) / (prec + rec)) f1
Sortie :
## [1] 0.6103799
Compromis entre précision et rappel
Il est impossible d’avoir à la fois une haute précision et un haut rappel.
Si nous augmentons la précision, l’individu correct sera mieux prédit, mais nous en manquerons beaucoup (rappel plus faible). Dans certaines situations, nous préférons une précision supérieure au rappel. Il existe une relation concave entre précision et rappel.
- Imaginez, vous devez prédire si un patient souffre d'une maladie. Vous voulez être aussi précis que possible.
- Si vous devez détecter des personnes frauduleuses potentielles dans la rue grâce à la reconnaissance faciale, il serait préférable d'attraper de nombreuses personnes étiquetées comme frauduleuses, même si la précision est faible. La police pourra libérer l'individu non frauduleux.
La courbe ROC
Les Cible OperaCaractéristique de réglage La courbe est un autre outil couramment utilisé avec la classification binaire. Elle est très similaire à la courbe précision/rappel, mais au lieu de tracer la précision par rapport au rappel, la courbe ROC montre le taux de vrais positifs (c'est-à-dire le rappel) par rapport au taux de faux positifs. Le taux de faux positifs est le rapport des instances négatives qui sont incorrectement classées comme positives. Il est égal à un moins le taux véritablement négatif. Le vrai taux négatif est également appelé spécificité. D'où la courbe ROC Un niveau de sensibilité élevée (rappel) versus 1-spécificité
Pour tracer la courbe ROC, nous devons installer une bibliothèque appelée RORC. On peut trouver dans le conda bibliothèque. Vous pouvez taper le code :
conda install -cr r-rocr –oui
Nous pouvons tracer le ROC avec les fonctions prédiction() et performance().
library(ROCR) ROCRpred <- prediction(predict, data_test$income) ROCRperf <- performance(ROCRpred, 'tpr', 'fpr') plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))
Explication du code
- prédiction(predict, data_test$ Income) : la bibliothèque ROCR doit créer un objet de prédiction pour transformer les données d'entrée
- performance(ROCRpred, 'tpr','fpr') : Renvoie les deux combinaisons à produire dans le graphique. Ici, tpr et fpr sont construits. Pour tracer la précision et rappeler ensemble, utilisez « prec », « rec ».
Sortie :
Étape 8) Améliorer le modèle
Vous pouvez essayer d'ajouter de la non-linéarité au modèle avec l'interaction entre
- âge et heures par semaine
- sexe et heures par semaine.
Vous devez utiliser le test de score pour comparer les deux modèles
formula_2 <- income~age: hours.per.week + gender: hours.per.week + . logit_2 <- glm(formula_2, data = data_train, family = 'binomial') predict_2 <- predict(logit_2, data_test, type = 'response') table_mat_2 <- table(data_test$income, predict_2 > 0.5) precision_2 <- precision(table_mat_2) recall_2 <- recall(table_mat_2) f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2)) f1_2
Sortie :
## [1] 0.6109181
Le score est légèrement supérieur au précédent. Vous pouvez continuer à travailler sur les données et essayer de battre le score.
Résumé
Nous pouvons résumer la fonction pour entraîner une régression logistique dans le tableau ci-dessous :
Forfait | Objectif | Fonction | Argument |
---|---|---|---|
- | Créer un ensemble de données d'entraînement/test | créer_train_set() | données, taille, train |
gm | Former un modèle linéaire généralisé | glm() | formule, données, famille* |
gm | Résumer le modèle | résumé() | modèle ajusté |
base | Faire des prédictions | prédire() | modèle ajusté, ensemble de données, type = « réponse » |
base | Créer une matrice de confusion | tableau() | oui, prédire() |
base | Créer un score de précision | somme(diag(table())/somme(table() | |
ROCR | Créer ROC : Étape 1 Créer une prédiction | prédiction() | prédire(), oui |
ROCR | Créer ROC : Étape 2 Créer de la performance | performance() | prédiction(), 'tpr', 'fpr' |
ROCR | Créer un ROC : Étape 3 : Tracer le graphique | terrain() | performance() |
L'autre GLM les types de modèles sont :
– binôme : (lien = « logit »)
– gaussien : (lien = « identité »)
– Gamma : (lien = « inverse »)
– inverse.gaussien : (lien = « 1/mu^2 »)
– poisson : (lien = « log »)
– quasi : (lien = « identité », variance = « constante »)
– quasibinomial : (lien = « logit »)
– quasipoisson : (lien = « log »)