GLM în R: Model liniar generalizat cu exemplu
Ce este regresia logistică?
Regresia logistică este folosită pentru a prezice o clasă, adică o probabilitate. Regresia logistică poate prezice cu acuratețe un rezultat binar.
Imaginați-vă că doriți să preziceți dacă un împrumut este refuzat/acceptat pe baza mai multor atribute. Regresia logistică este de forma 0/1. y = 0 dacă un împrumut este respins, y = 1 dacă este acceptat.
Un model de regresie logistică diferă de modelul de regresie liniară în două moduri.
- În primul rând, regresia logistică acceptă doar intrare dihotomică (binară) ca variabilă dependentă (adică, un vector de 0 și 1).
- În al doilea rând, rezultatul este măsurat prin următoarea funcție de legătură probabilistică numită sigmoid datorită formei sale de S.:
Ieșirea funcției este întotdeauna între 0 și 1. Verificați imaginea de mai jos
Funcția sigmoidă returnează valori de la 0 la 1. Pentru sarcina de clasificare, avem nevoie de o ieșire discretă de 0 sau 1.
Pentru a converti un flux continuu în valoare discretă, putem seta o limită de decizie la 0.5. Toate valorile peste acest prag sunt clasificate ca 1
Cum se creează modelul generalizat de căptușeală (GLM)
Să folosim adult set de date pentru a ilustra regresia logistică. „Adultul” este un set de date grozav pentru sarcina de clasificare. Obiectivul este de a prezice dacă venitul anual în dolari al unei persoane va depăși 50.000. Setul de date conține 46,033 de observații și zece caracteristici:
- vârsta: vârsta individului. Numeric
- educație: Nivelul educațional al individului. Factor.
- starea civilă: Maristatutul individual al individului. Factorul, adică niciodată căsătorit, căsătorit-civ-soț,...
- gen: Genul individului. Factorul, adică bărbat sau femeie
- sursa de venit: Target variabilă. Venituri peste sau sub 50K. Factor, adică >50K, <=50K
printre altii
library(dplyr) data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv") glimpse(data_adult)
ieșire:
Observations: 48,842 Variables: 10 $ x <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,... $ age <int> 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26... $ workclass <fctr> Private, Private, Local-gov, Private, ?, Private,... $ education <fctr> 11th, HS-grad, Assoc-acdm, Some-college, Some-col... $ educational.num <int> 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,... $ marital.status <fctr> Never-married, Married-civ-spouse, Married-civ-sp... $ race <fctr> Black, White, White, Black, White, White, Black, ... $ gender <fctr> Male, Male, Male, Male, Female, Male, Male, Male,... $ hours.per.week <int> 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39... $ income <fctr> <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5...
Vom proceda astfel:
- Pasul 1: Verificați variabilele continue
- Pasul 2: Verificați variabilele factor
- Pasul 3: Ingineria caracteristicilor
- Pasul 4: Rezumat statistic
- Pasul 5: Setul de tren/test
- Pasul 6: Construiți modelul
- Pasul 7: Evaluați performanța modelului
- Pasul 8: Îmbunătățiți modelul
Sarcina ta este să prezici care individ va avea un venit mai mare de 50.
În acest tutorial, fiecare pas va fi detaliat pentru a efectua o analiză pe un set de date real.
Pasul 1) Verificați variabilele continue
În primul pas, puteți vedea distribuția variabilelor continue.
continuous <-select_if(data_adult, is.numeric) summary(continuous)
Explicarea codului
- continuu <- select_if(data_adult, is.numeric): Utilizați funcția select_if() din biblioteca dplyr pentru a selecta numai coloanele numerice
- summary(continuous): Imprimați statistica rezumat
ieșire:
## X age educational.num hours.per.week ## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00 ## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00 ## Median :23017 Median :37.00 Median :10.00 Median :40.00 ## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95 ## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00 ## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00
Din tabelul de mai sus, puteți vedea că datele au scări total diferite și ore.pe.săptămâni au valori aberante mari (adică uitați-vă la ultimul quartil și la valoarea maximă).
Poți să o rezolvi în doi pași:
- 1: Grafică distribuția orelor.pe.săptămână
- 2: Standardizați variabilele continue
- Trasează distribuția
Să ne uităm mai atent la distribuția orelor.pe.săptămână
# Histogram with kernel density curve library(ggplot2) ggplot(continuous, aes(x = hours.per.week)) + geom_density(alpha = .2, fill = "#FF6666")
ieșire:
Variabila are o mulțime de valori aberante și o distribuție nu este bine definită. Puteți rezolva parțial această problemă ștergând primele 0.01% din orele pe săptămână.
Sintaxa de bază a cuantilei:
quantile(variable, percentile) arguments: -variable: Select the variable in the data frame to compute the percentile -percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C, ...) - `A`,`B`,`C` and `...` are all integer from 0 to 1.
Calculăm prima percentila de 2 procente
top_one_percent <- quantile(data_adult$hours.per.week, .99) top_one_percent
Explicarea codului
- quantile(data_adult$hours.per.week, .99): Calculați valoarea celor 99% din timpul de lucru
ieșire:
## 99% ## 80
98% din populație lucrează sub 80 de ore pe săptămână.
Puteți renunța la observațiile peste acest prag. Folosești filtrul de la dplyr bibliotecă.
data_adult_drop <-data_adult %>% filter(hours.per.week<top_one_percent) dim(data_adult_drop)
ieșire:
## [1] 45537 10
- Standardizați variabilele continue
Puteți standardiza fiecare coloană pentru a îmbunătăți performanța deoarece datele dvs. nu au aceeași scară. Puteți folosi funcția mutate_if din biblioteca dplyr. Sintaxa de bază este:
mutate_if(df, condition, funs(function)) arguments: -`df`: Data frame used to compute the function - `condition`: Statement used. Do not use parenthesis - funs(function): Return the function to apply. Do not use parenthesis for the function
Puteți standardiza coloanele numerice după cum urmează:
data_adult_rescale <- data_adult_drop % > % mutate_if(is.numeric, funs(as.numeric(scale(.)))) head(data_adult_rescale)
Explicarea codului
- mutate_if(is.numeric, funs(scale)): condiția este doar o coloană numerică, iar funcția este scară
ieșire:
## X age workclass education educational.num ## 1 -1.732680 -1.02325949 Private 11th -1.22106443 ## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868 ## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494 ## 4 -1.732455 0.41426100 Private Some-college -0.04945081 ## 5 -1.732379 -0.34232873 Private 10th -1.61160231 ## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857 ## marital.status race gender hours.per.week income ## 1 Never-married Black Male -0.03995944 <=50K ## 2 Married-civ-spouse White Male 0.86863037 <=50K ## 3 Married-civ-spouse White Male -0.03995944 >50K ## 4 Married-civ-spouse Black Male -0.03995944 >50K ## 5 Never-married White Male -0.94854924 <=50K ## 6 Married-civ-spouse White Male -0.76683128 >50K
Pasul 2) Verificați variabilele factor
Acest pas are două obiective:
- Verificați nivelul din fiecare coloană categorială
- Definiți noi niveluri
Vom împărți acest pas în trei părți:
- Selectați coloanele categorice
- Stocați diagrama cu bare a fiecărei coloane într-o listă
- Tipăriți graficele
Putem selecta coloanele de factori cu codul de mai jos:
# Select categorical column factor <- data.frame(select_if(data_adult_rescale, is.factor)) ncol(factor)
Explicarea codului
- data.frame(select_if(data_adult, is.factor)): Stocăm coloanele factor în factor într-un tip de cadru de date. Biblioteca ggplot2 necesită un obiect cadru de date.
ieșire:
## [1] 6
Setul de date conține 6 variabile categoriale
Al doilea pas este mai priceput. Doriți să trasați o diagramă cu bare pentru fiecare coloană din factorul cadru de date. Este mai convenabil să automatizezi procesul, mai ales în situația în care există o mulțime de coloane.
library(ggplot2) # Create graph for each column graph <- lapply(names(factor), function(x) ggplot(factor, aes(get(x))) + geom_bar() + theme(axis.text.x = element_text(angle = 90)))
Explicarea codului
- lapply(): Utilizați funcția lapply() pentru a trece o funcție în toate coloanele setului de date. Salvați rezultatul într-o listă
- function(x): Funcția va fi procesată pentru fiecare x. Aici x sunt coloanele
- ggplot(factor, aes(get(x))) + geom_bar()+ theme(axis.text.x = element_text(angle = 90)): Creați o diagramă cu bare pentru fiecare element x. Rețineți, pentru a returna x ca coloană, trebuie să-l includeți în get()
Ultimul pas este relativ ușor. Doriți să imprimați cele 6 grafice.
# Print the graph graph
ieșire:
## [[1]]
## ## [[2]]
## ## [[3]]
## ## [[4]]
## ## [[5]]
## ## [[6]]
Notă: Folosiți butonul următor pentru a naviga la următorul grafic
Pasul 3) Ingineria caracteristicilor
Reformarea educației
Din graficul de mai sus, puteți vedea că variabila educație are 16 niveluri. Acest lucru este substanțial, iar unele niveluri au un număr relativ scăzut de observații. Dacă doriți să îmbunătățiți cantitatea de informații pe care o puteți obține din această variabilă, o puteți transforma la un nivel superior. Și anume, creezi grupuri mai mari cu nivel similar de educație. De exemplu, nivelul scăzut de educație va fi transformat în abandon școlar. Nivelurile superioare de educație vor fi schimbate în master.
Iată detaliul:
Nivel vechi | Nivel nou |
---|---|
Preşcolar | renunța |
10 | Renunța |
11 | Renunța |
12 | Renunța |
1st-4th | Renunța |
5th-6th | Renunța |
7th-8th | Renunța |
9 | Renunța |
HS-Grad | HighGrad |
O facultate | Comunitate |
Conf.-acdm | Comunitate |
Conf.-voc | Comunitate |
burlaci | burlaci |
masterat | masterat |
Prof-scoala | masterat |
Doctorat | Dr. |
recast_data <- data_adult_rescale % > % select(-X) % > % mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community", ifelse(education == "Bachelors", "Bachelors", ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))
Explicarea codului
- Folosim verbul mutare din biblioteca dplyr. Schimbăm valorile educației cu afirmația ifelse
În tabelul de mai jos, creați o statistică rezumativă pentru a vedea, în medie, câți ani de educație (valoarea z) sunt necesari pentru a ajunge la licență, master sau doctorat.
recast_data % > % group_by(education) % > % summarize(average_educ_year = mean(educational.num), count = n()) % > % arrange(average_educ_year)
ieșire:
## # A tibble: 6 x 3 ## education average_educ_year count ## <fctr> <dbl> <int> ## 1 dropout -1.76147258 5712 ## 2 HighGrad -0.43998868 14803 ## 3 Community 0.09561361 13407 ## 4 Bachelors 1.12216282 7720 ## 5 Master 1.60337381 3338 ## 6 PhD 2.29377644 557
Reformare Marital-status
De asemenea, este posibil să se creeze niveluri inferioare pentru starea civilă. În următorul cod modificați nivelul după cum urmează:
Nivel vechi | Nivel nou |
---|---|
Niciodata casatorit | Necasatorit |
Căsătorit-soț-absent | Necasatorit |
Căsătorit-AF-soț | Căsătorit |
Căsătorit-civ-soț | |
Separat | Separat |
Divorţat | |
văduvele | Văduvă |
# Change level marry recast_data <- recast_data % > % mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))
Puteți verifica numărul de persoane din fiecare grup.
table(recast_data$marital.status)
ieșire:
## ## Married Not_married Separated Widow ## 21165 15359 7727 1286
Pasul 4) Statistică rezumată
Este timpul să verificăm câteva statistici despre variabilele noastre țintă. În graficul de mai jos, numărați procentul de persoane care câștigă mai mult de 50, având în vedere sexul lor.
# Plot gender income ggplot(recast_data, aes(x = gender, fill = income)) + geom_bar(position = "fill") + theme_classic()
ieșire:
Apoi, verificați dacă originea individului îi afectează câștigurile.
# Plot origin income ggplot(recast_data, aes(x = race, fill = income)) + geom_bar(position = "fill") + theme_classic() + theme(axis.text.x = element_text(angle = 90))
ieșire:
Numărul de ore de muncă în funcție de sex.
# box plot gender working time ggplot(recast_data, aes(x = gender, y = hours.per.week)) + geom_boxplot() + stat_summary(fun.y = mean, geom = "point", size = 3, color = "steelblue") + theme_classic()
ieșire:
Box plot confirmă faptul că distribuția timpului de lucru se potrivește diferitelor grupuri. În box plot, ambele sexe nu au observații omogene.
Puteți verifica densitatea timpului de lucru săptămânal în funcție de tipul de studii. Distribuțiile au multe alegeri distincte. Poate fi explicat prin tipul de contract din SUA.
# Plot distribution working time by education ggplot(recast_data, aes(x = hours.per.week)) + geom_density(aes(color = education), alpha = 0.5) + theme_classic()
Explicarea codului
- ggplot(recast_data, aes(x= hours.per.week)): Un grafic de densitate necesită doar o variabilă
- geom_density(aes(culoare = educație), alpha =0.5): obiectul geometric pentru a controla densitatea
ieșire:
Pentru a vă confirma gândurile, puteți efectua un singur sens test ANOVA:
anova <- aov(hours.per.week~education, recast_data) summary(anova)
ieșire:
## Df Sum Sq Mean Sq F value Pr(>F) ## education 5 1552 310.31 321.2 <2e-16 *** ## Residuals 45531 43984 0.97 ## --- ## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Testul ANOVA confirmă diferența de medie între grupuri.
Non-liniaritatea
Înainte de a rula modelul, puteți vedea dacă numărul de ore lucrate este legat de vârstă.
library(ggplot2) ggplot(recast_data, aes(x = age, y = hours.per.week)) + geom_point(aes(color = income), size = 0.5) + stat_smooth(method = 'lm', formula = y~poly(x, 2), se = TRUE, aes(color = income)) + theme_classic()
Explicarea codului
- ggplot(recast_data, aes(x = age, y = hours.per.week)): setați estetica graficului
- geom_point(aes(culoare=venit), dimensiune =0.5): Construiți graficul cu puncte
- stat_smooth(): Adăugați linia de tendință cu următoarele argumente:
- method='lm': Trasează valoarea ajustată dacă regresie liniara
- formula = y~poly(x,2): Potriviți o regresie polinomială
- se = TRUE: Adăugați eroarea standard
- aes(culoare= venit): sparge modelul după venit
ieșire:
Pe scurt, puteți testa termenii de interacțiune din model pentru a detecta efectul de neliniaritate dintre timpul de lucru săptămânal și alte caracteristici. Este important să se detecteze în ce condiții diferă timpul de lucru.
Corelație
Următoarea verificare este de a vizualiza corelația dintre variabile. Convertiți tipul de nivel de factor în numeric, astfel încât să puteți reprezenta o hartă termică care conține coeficientul de corelație calculat cu metoda Spearman.
library(GGally) # Convert data to numeric corr <- data.frame(lapply(recast_data, as.integer)) # Plot the graphggcorr(corr, method = c("pairwise", "spearman"), nbreaks = 6, hjust = 0.8, label = TRUE, label_size = 3, color = "grey50")
Explicarea codului
- data.frame(lapply(recast_data,as.integer)): Convertiți datele în numere
- ggcorr() trasează harta termică cu următoarele argumente:
- metoda: Metodă de calcul a corelației
- nbreaks = 6: Număr de pauze
- hjust = 0.8: Poziția de control a numelui variabilei în grafic
- label = TRUE: Adăugați etichete în centrul ferestrelor
- label_size = 3: etichete de dimensiune
- culoare = „grey50”): culoarea etichetei
ieșire:
Pasul 5) Set de tren/test
Orice supravegheat masina de învățare sarcina necesită împărțirea datelor între un set de tren și un set de testare. Puteți utiliza „funcția” pe care ați creat-o în celelalte tutoriale de învățare supravegheată pentru a crea un set de tren/test.
set.seed(1234) create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample <- 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } } data_train <- create_train_test(recast_data, 0.8, train = TRUE) data_test <- create_train_test(recast_data, 0.8, train = FALSE) dim(data_train)
ieșire:
## [1] 36429 9
dim(data_test)
ieșire:
## [1] 9108 9
Pasul 6) Construiți modelul
Pentru a vedea cum funcționează algoritmul, utilizați pachetul glm(). The Model liniar generalizat este o colecție de modele. Sintaxa de bază este:
glm(formula, data=data, family=linkfunction() Argument: - formula: Equation used to fit the model- data: dataset used - Family: - binomial: (link = "logit") - gaussian: (link = "identity") - Gamma: (link = "inverse") - inverse.gaussian: (link = "1/mu^2") - poisson: (link = "log") - quasi: (link = "identity", variance = "constant") - quasibinomial: (link = "logit") - quasipoisson: (link = "log")
Sunteți gata să estimați modelul logistic pentru a împărți nivelul veniturilor între un set de caracteristici.
formula <- income~. logit <- glm(formula, data = data_train, family = 'binomial') summary(logit)
Explicarea codului
- formula <- venit ~ .: Creați modelul care să se potrivească
- logit <- glm(formula, data = data_train, family = 'binom'): Potriviți un model logistic (family = 'binom') cu datele data_train.
- summary(logit): Imprimați rezumatul modelului
ieșire:
## ## Call: ## glm(formula = formula, family = "binomial", data = data_train) ## ## Deviance Residuals: ## Min 1Q Median 3Q Max ## -2.6456 -0.5858 -0.2609 -0.0651 3.1982 ## ## Coefficients: ## Estimate Std. Error z value Pr(>|z|) ## (Intercept) 0.07882 0.21726 0.363 0.71675 ## age 0.41119 0.01857 22.146 < 2e-16 *** ## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 *** ## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 *** ## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499 ## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 *** ## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 *** ## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596 ## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 *** ## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 *** ## educationMaster 0.35651 0.06780 5.258 1.46e-07 *** ## educationPhD 0.46995 0.15772 2.980 0.00289 ** ## educationdropout -1.04974 0.21280 -4.933 8.10e-07 *** ## educational.num 0.56908 0.07063 8.057 7.84e-16 *** ## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 *** ## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 *** ## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 *** ## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117 ## raceBlack 0.07188 0.19330 0.372 0.71001 ## raceOther 0.01370 0.27695 0.049 0.96054 ## raceWhite 0.34830 0.18441 1.889 0.05894 . ## genderMale 0.08596 0.04289 2.004 0.04506 * ## hours.per.week 0.41942 0.01748 23.998 < 2e-16 *** ## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 ## ## (Dispersion parameter for binomial family taken to be 1) ## ## Null deviance: 40601 on 36428 degrees of freedom ## Residual deviance: 27041 on 36406 degrees of freedom ## AIC: 27087 ## ## Number of Fisher Scoring iterations: 6
Rezumatul modelului nostru dezvăluie informații interesante. Performanța unei regresii logistice este evaluată cu valori cheie specifice.
- AIC (Akaike Information Criteria): Acesta este echivalentul R2 în regresie logistică. Măsoară potrivirea atunci când se aplică o penalizare numărului de parametri. Mai mic AIC valorile indică faptul că modelul este mai aproape de adevăr.
- Devianța nulă: se potrivește modelului numai cu interceptarea. Gradul de libertate este n-1. O putem interpreta ca o valoare Chi-pătrat (valoare ajustată diferită de testarea ipotezei valorii reale).
- Devianța reziduală: Model cu toate variabilele. De asemenea, este interpretată ca o testare a ipotezei Chi-pătrat.
- Numărul de iterații Fisher Scoring: numărul de iterații înainte de convergență.
Ieșirea funcției glm() este stocată într-o listă. Codul de mai jos arată toate elementele disponibile în variabila logit pe care am construit-o pentru a evalua regresia logistică.
# Lista este foarte lungă, imprimați doar primele trei elemente
lapply(logit, class)[1:3]
ieșire:
## $coefficients ## [1] "numeric" ## ## $residuals ## [1] "numeric" ## ## $fitted.values ## [1] "numeric"
Fiecare valoare poate fi extrasă cu semnul $ urmat de numele valorilor. De exemplu, ați stocat modelul ca logit. Pentru a extrage criteriile AIC, utilizați:
logit$aic
ieșire:
## [1] 27086.65
Pasul 7) Evaluați performanța modelului
Matricea confuziei
matrice de confuzie este o alegere mai bună pentru a evalua performanța clasificării în comparație cu diferitele valori pe care le-ați văzut anterior. Ideea generală este de a număra de câte ori instanțele adevărate sunt clasificate ca fiind false.
Pentru a calcula matricea de confuzie, trebuie mai întâi să aveți un set de predicții, astfel încât acestea să poată fi comparate cu țintele reale.
predict <- predict(logit, data_test, type = 'response') # confusion matrix table_mat <- table(data_test$income, predict > 0.5) table_mat
Explicarea codului
- predict(logit,data_test, type = 'răspuns'): Calculați predicția pe setul de testare. Setați tip = „răspuns” pentru a calcula probabilitatea de răspuns.
- table(data_test$income, predict > 0.5): Calculați matricea de confuzie. prezice > 0.5 înseamnă că returnează 1 dacă probabilitățile prezise sunt peste 0.5, altfel 0.
ieșire:
## ## FALSE TRUE ## <=50K 6310 495 ## >50K 1074 1229
Fiecare rând dintr-o matrice de confuzie reprezintă o țintă reală, în timp ce fiecare coloană reprezintă o țintă estimată. Primul rând al acestei matrice consideră venitul mai mic de 50k (clasa Fals): 6241 au fost clasificați corect ca persoane cu venituri mai mici de 50k (Adevărat negativ), în timp ce cel rămas a fost clasificat greșit ca peste 50k (Fals pozitiv). Al doilea rând ia în considerare venitul peste 50k, clasa pozitivă a fost 1229 (Adevărat pozitiv), in timp ce Adevărat negativ a fost 1074.
Puteți calcula modelul precizie prin însumarea adevăratului pozitiv + adevăratul negativ peste observația totală
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test
Explicarea codului
- sum(diag(table_mat)): Suma diagonalei
- sum(table_mat): Suma matricei.
ieșire:
## [1] 0.8277339
Modelul pare să sufere de o problemă, supraestimează numărul de fals negative. Aceasta se numește paradoxul testului de precizie. Am afirmat că acuratețea este raportul dintre predicțiile corecte și numărul total de cazuri. Putem avea o precizie relativ mare, dar un model inutil. Se întâmplă când există o clasă dominantă. Dacă priviți înapoi la matricea de confuzie, puteți vedea că majoritatea cazurilor sunt clasificate drept negative adevărate. Imaginați-vă acum, modelul a clasificat toate clasele drept negative (adică mai mici de 50k). Ai avea o precizie de 75 la sută (6718/6718+2257). Modelul tău are performanțe mai bune, dar se luptă să distingă adevăratul pozitiv de adevăratul negativ.
Într-o astfel de situație, este de preferat să aveți o metrică mai concisă. Ne putem uita la:
- Precizie=TP/(TP+FP)
- Recall=TP/(TP+FN)
Precizie vs rechemare
Precizie se uită la acuratețea predicției pozitive. Rechemare este raportul instanțelor pozitive care sunt detectate corect de către clasificator;
Puteți construi două funcții pentru a calcula aceste două valori
- Precizie de construcție
precision <- function(matrix) { # True positive tp <- matrix[2, 2] # false positive fp <- matrix[1, 2] return (tp / (tp + fp)) }
Explicarea codului
- mat[1,1]: Returnează prima celulă a primei coloane a cadrului de date, adică adevăratul pozitiv
- mat[1,2]; Returnează prima celulă din a doua coloană a cadrului de date, adică fals pozitiv
recall <- function(matrix) { # true positive tp <- matrix[2, 2]# false positive fn <- matrix[2, 1] return (tp / (tp + fn)) }
Explicarea codului
- mat[1,1]: Returnează prima celulă a primei coloane a cadrului de date, adică adevăratul pozitiv
- mat[2,1]; Returnați a doua celulă a primei coloane a cadrului de date, adică negativul fals
Vă puteți testa funcțiile
prec <- precision(table_mat) prec rec <- recall(table_mat) rec
ieșire:
## [1] 0.712877 ## [2] 0.5336518
Când modelul spune că este un individ de peste 50k, este corect doar în 54% din cazuri și poate revendica persoane cu peste 50k în 72% din cazuri.
Puteți crea fișierul scor bazat pe precizie și reamintire. The
este o medie armonică a acestor două metrici, ceea ce înseamnă că acordă mai multă pondere valorilor mai mici.
f1 <- 2 * ((prec * rec) / (prec + rec)) f1
ieșire:
## [1] 0.6103799
Precizie vs rechemare compromis
Este imposibil să aveți atât o precizie ridicată, cât și o reamintire ridicată.
Dacă creștem precizia, individul corect va fi mai bine prezis, dar am rata foarte multe dintre ele (reamintire mai mică). În unele situații, preferăm o precizie mai mare decât reamintirea. Există o relație concavă între precizie și reamintire.
- Imaginați-vă, trebuie să preziceți dacă un pacient are o boală. Vrei să fii cât mai precis posibil.
- Dacă trebuie să detectați potențialele persoane frauduloase de pe stradă prin recunoașterea facială, ar fi mai bine să prindeți multe persoane etichetate drept frauduloase, chiar dacă precizia este scăzută. Poliția va putea elibera persoana nefrauduloasă.
Curba ROC
Receptor Operating Caracteristică curba este un alt instrument comun utilizat cu clasificarea binară. Este foarte asemănătoare cu curba de precizie/rechemare, dar în loc de a reprezenta un grafic precizie versus reamintire, curba ROC arată rata pozitivă adevărată (adică, reamintirea) față de rata fals pozitivă. Rata fals pozitive este raportul dintre cazurile negative care sunt clasificate incorect drept pozitive. Este egal cu unu minus rata negativă adevărată. Rata adevărată negativă se mai numește specificitate. De aici curba ROC grafică sensibilitate (reamintire) versus 1-specificitate
Pentru a trasa curba ROC, trebuie să instalăm o bibliotecă numită RORC. Putem găsi în conda bibliotecă. Puteți introduce codul:
conda install -cr r-rocr –da
Putem reprezenta ROC cu funcțiile prediction() și performance().
library(ROCR) ROCRpred <- prediction(predict, data_test$income) ROCRperf <- performance(ROCRpred, 'tpr', 'fpr') plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))
Explicarea codului
- prediction(predict, data_test$income): biblioteca ROCR trebuie să creeze un obiect de predicție pentru a transforma datele de intrare
- performanță(ROCRpred, 'tpr','fpr'): returnează cele două combinații de produs în grafic. Aici, tpr și fpr sunt construite. Pentru a reprezenta precizia și a reaminti împreună, folosiți „prec”, „rec”.
ieșire:
Pas 8) Îmbunătățiți modelul
Puteți încerca să adăugați non-liniaritate modelului cu interacțiunea dintre
- vârsta și ore.pe.săptămână
- sex și ore.pe.săptămână.
Trebuie să utilizați testul de scor pentru a compara ambele modele
formula_2 <- income~age: hours.per.week + gender: hours.per.week + . logit_2 <- glm(formula_2, data = data_train, family = 'binomial') predict_2 <- predict(logit_2, data_test, type = 'response') table_mat_2 <- table(data_test$income, predict_2 > 0.5) precision_2 <- precision(table_mat_2) recall_2 <- recall(table_mat_2) f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2)) f1_2
ieșire:
## [1] 0.6109181
Scorul este puțin mai mare decât cel anterior. Puteți continua să lucrați la date și să încercați să bateți scorul.
Rezumat
Putem rezuma funcția de antrenare a unei regresii logistice în tabelul de mai jos:
Pachet | Obiectiv | Funcţie | Argument |
---|---|---|---|
- | Creați un set de date tren/test | create_train_set() | date, dimensiune, tren |
glm | Antrenează un model liniar generalizat | glm() | formula, date, familie* |
glm | Rezumă modelul | rezumat() | model montat |
de bază | Faceți predicții | prezice() | model adaptat, set de date, tip = „răspuns” |
de bază | Creați o matrice de confuzie | masa() | y, prezice() |
de bază | Creați un scor de precizie | suma(diag(tabel())/sum(tabel() | |
ROCR | Creați ROC: Pasul 1 Creați predicție | predicție () | prezice(), y |
ROCR | Creați ROC: Pasul 2 Creați performanță | performanţă() | predicție(), „tpr”, „fpr” |
ROCR | Creați ROC: Pasul 3 Trasați graficul | complot() | performanţă() |
Celălalt GLM tipuri de modele sunt:
– binom: (link = „logit”)
– gaussian: (link = „identitate”)
– Gamma: (link = „invers”)
– invers.gaussian: (link = „1/mu^2”)
– poisson: (link = „jurnal”)
– cvasi: (link = „identitate”, varianță = „constant”)
– cvasibinom: (link = „logit”)
– cvasipoisson: (link = „log”)