GLM în R: Model liniar generalizat cu exemplu

Ce este regresia logistică?

Regresia logistică este folosită pentru a prezice o clasă, adică o probabilitate. Regresia logistică poate prezice cu acuratețe un rezultat binar.

Imaginați-vă că doriți să preziceți dacă un împrumut este refuzat/acceptat pe baza mai multor atribute. Regresia logistică este de forma 0/1. y = 0 dacă un împrumut este respins, y = 1 dacă este acceptat.

Un model de regresie logistică diferă de modelul de regresie liniară în două moduri.

  • În primul rând, regresia logistică acceptă doar intrare dihotomică (binară) ca variabilă dependentă (adică, un vector de 0 și 1).
  • În al doilea rând, rezultatul este măsurat prin următoarea funcție de legătură probabilistică numită sigmoid datorită formei sale de S.:

Regresie logistică

Ieșirea funcției este întotdeauna între 0 și 1. Verificați imaginea de mai jos

Regresie logistică

Funcția sigmoidă returnează valori de la 0 la 1. Pentru sarcina de clasificare, avem nevoie de o ieșire discretă de 0 sau 1.

Pentru a converti un flux continuu în valoare discretă, putem seta o limită de decizie la 0.5. Toate valorile peste acest prag sunt clasificate ca 1

Regresie logistică

Cum se creează modelul generalizat de căptușeală (GLM)

Să folosim adult set de date pentru a ilustra regresia logistică. „Adultul” este un set de date grozav pentru sarcina de clasificare. Obiectivul este de a prezice dacă venitul anual în dolari al unei persoane va depăși 50.000. Setul de date conține 46,033 de observații și zece caracteristici:

  • vârsta: vârsta individului. Numeric
  • educație: Nivelul educațional al individului. Factor.
  • starea civilă: Maristatutul individual al individului. Factorul, adică niciodată căsătorit, căsătorit-civ-soț,...
  • gen: Genul individului. Factorul, adică bărbat sau femeie
  • sursa de venit: Target variabilă. Venituri peste sau sub 50K. Factor, adică >50K, <=50K

printre altii

library(dplyr)
data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")
glimpse(data_adult)

ieșire:

Observations: 48,842
Variables: 10
$ x               <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,...
$ age             <int> 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26...
$ workclass       <fctr> Private, Private, Local-gov, Private, ?, Private,...
$ education       <fctr> 11th, HS-grad, Assoc-acdm, Some-college, Some-col...
$ educational.num <int> 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,...
$ marital.status  <fctr> Never-married, Married-civ-spouse, Married-civ-sp...
$ race            <fctr> Black, White, White, Black, White, White, Black, ...
$ gender          <fctr> Male, Male, Male, Male, Female, Male, Male, Male,...
$ hours.per.week  <int> 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39...
$ income          <fctr> <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5...

Vom proceda astfel:

  • Pasul 1: Verificați variabilele continue
  • Pasul 2: Verificați variabilele factor
  • Pasul 3: Ingineria caracteristicilor
  • Pasul 4: Rezumat statistic
  • Pasul 5: Setul de tren/test
  • Pasul 6: Construiți modelul
  • Pasul 7: Evaluați performanța modelului
  • Pasul 8: Îmbunătățiți modelul

Sarcina ta este să prezici care individ va avea un venit mai mare de 50.

În acest tutorial, fiecare pas va fi detaliat pentru a efectua o analiză pe un set de date real.

Pasul 1) Verificați variabilele continue

În primul pas, puteți vedea distribuția variabilelor continue.

continuous <-select_if(data_adult, is.numeric)
summary(continuous)

Explicarea codului

  • continuu <- select_if(data_adult, is.numeric): Utilizați funcția select_if() din biblioteca dplyr pentru a selecta numai coloanele numerice
  • summary(continuous): Imprimați statistica rezumat

ieșire:

##        X              age        educational.num hours.per.week 
##  Min.   :    1   Min.   :17.00   Min.   : 1.00   Min.   : 1.00  
##  1st Qu.:11509   1st Qu.:28.00   1st Qu.: 9.00   1st Qu.:40.00  
##  Median :23017   Median :37.00   Median :10.00   Median :40.00  
##  Mean   :23017   Mean   :38.56   Mean   :10.13   Mean   :40.95  
##  3rd Qu.:34525   3rd Qu.:47.00   3rd Qu.:13.00   3rd Qu.:45.00  
##  Max.   :46033   Max.   :90.00   Max.   :16.00   Max.   :99.00	

Din tabelul de mai sus, puteți vedea că datele au scări total diferite și ore.pe.săptămâni au valori aberante mari (adică uitați-vă la ultimul quartil și la valoarea maximă).

Poți să o rezolvi în doi pași:

  • 1: Grafică distribuția orelor.pe.săptămână
  • 2: Standardizați variabilele continue
  1. Trasează distribuția

Să ne uităm mai atent la distribuția orelor.pe.săptămână

# Histogram with kernel density curve
library(ggplot2)
ggplot(continuous, aes(x = hours.per.week)) +
    geom_density(alpha = .2, fill = "#FF6666")

ieșire:

Verificați variabilele continue

Variabila are o mulțime de valori aberante și o distribuție nu este bine definită. Puteți rezolva parțial această problemă ștergând primele 0.01% din orele pe săptămână.

Sintaxa de bază a cuantilei:

quantile(variable, percentile)
arguments:
-variable:  Select the variable in the data frame to compute the percentile
-percentile:  Can be a single value between 0 and 1 or multiple value. If multiple, use this format:  `c(A,B,C, ...)
- `A`,`B`,`C` and `...` are all integer from 0 to 1.

Calculăm prima percentila de 2 procente

top_one_percent <- quantile(data_adult$hours.per.week, .99)
top_one_percent

Explicarea codului

  • quantile(data_adult$hours.per.week, .99): Calculați valoarea celor 99% din timpul de lucru

ieșire:

## 99% 
##  80

98% din populație lucrează sub 80 de ore pe săptămână.

Puteți renunța la observațiile peste acest prag. Folosești filtrul de la dplyr bibliotecă.

data_adult_drop <-data_adult %>%
filter(hours.per.week<top_one_percent)
dim(data_adult_drop)

ieșire:

## [1] 45537    10
  1. Standardizați variabilele continue

Puteți standardiza fiecare coloană pentru a îmbunătăți performanța deoarece datele dvs. nu au aceeași scară. Puteți folosi funcția mutate_if din biblioteca dplyr. Sintaxa de bază este:

mutate_if(df, condition, funs(function))
arguments:
-`df`: Data frame used to compute the function
- `condition`: Statement used. Do not use parenthesis
- funs(function):  Return the function to apply. Do not use parenthesis for the function

Puteți standardiza coloanele numerice după cum urmează:

data_adult_rescale <- data_adult_drop % > %
	mutate_if(is.numeric, funs(as.numeric(scale(.))))
head(data_adult_rescale)

Explicarea codului

  • mutate_if(is.numeric, funs(scale)): condiția este doar o coloană numerică, iar funcția este scară

ieșire:

##           X         age        workclass    education educational.num
## 1 -1.732680 -1.02325949          Private         11th     -1.22106443
## 2 -1.732605 -0.03969284          Private      HS-grad     -0.43998868
## 3 -1.732530 -0.79628257        Local-gov   Assoc-acdm      0.73162494
## 4 -1.732455  0.41426100          Private Some-college     -0.04945081
## 5 -1.732379 -0.34232873          Private         10th     -1.61160231
## 6 -1.732304  1.85178149 Self-emp-not-inc  Prof-school      1.90323857
##       marital.status  race gender hours.per.week income
## 1      Never-married Black   Male    -0.03995944  <=50K
## 2 Married-civ-spouse White   Male     0.86863037  <=50K
## 3 Married-civ-spouse White   Male    -0.03995944   >50K
## 4 Married-civ-spouse Black   Male    -0.03995944   >50K
## 5      Never-married White   Male    -0.94854924  <=50K
## 6 Married-civ-spouse White   Male    -0.76683128   >50K

Pasul 2) Verificați variabilele factor

Acest pas are două obiective:

  • Verificați nivelul din fiecare coloană categorială
  • Definiți noi niveluri

Vom împărți acest pas în trei părți:

  • Selectați coloanele categorice
  • Stocați diagrama cu bare a fiecărei coloane într-o listă
  • Tipăriți graficele

Putem selecta coloanele de factori cu codul de mai jos:

# Select categorical column
factor <- data.frame(select_if(data_adult_rescale, is.factor))
	ncol(factor)

Explicarea codului

  • data.frame(select_if(data_adult, is.factor)): Stocăm coloanele factor în factor într-un tip de cadru de date. Biblioteca ggplot2 necesită un obiect cadru de date.

ieșire:

## [1] 6

Setul de date conține 6 variabile categoriale

Al doilea pas este mai priceput. Doriți să trasați o diagramă cu bare pentru fiecare coloană din factorul cadru de date. Este mai convenabil să automatizezi procesul, mai ales în situația în care există o mulțime de coloane.

library(ggplot2)
# Create graph for each column
graph <- lapply(names(factor),
    function(x) 
	ggplot(factor, aes(get(x))) +
		geom_bar() +
		theme(axis.text.x = element_text(angle = 90)))

Explicarea codului

  • lapply(): Utilizați funcția lapply() pentru a trece o funcție în toate coloanele setului de date. Salvați rezultatul într-o listă
  • function(x): Funcția va fi procesată pentru fiecare x. Aici x sunt coloanele
  • ggplot(factor, aes(get(x))) + geom_bar()+ theme(axis.text.x = element_text(angle = 90)): Creați o diagramă cu bare pentru fiecare element x. Rețineți, pentru a returna x ca coloană, trebuie să-l includeți în get()

Ultimul pas este relativ ușor. Doriți să imprimați cele 6 grafice.

# Print the graph
graph

ieșire:

## [[1]]

Verificați variabilele factor

## ## [[2]]

Verificați variabilele factor

## ## [[3]]

Verificați variabilele factor

## ## [[4]]

Verificați variabilele factor

## ## [[5]]

Verificați variabilele factor

## ## [[6]]

Verificați variabilele factor

Notă: Folosiți butonul următor pentru a naviga la următorul grafic

Verificați variabilele factor

Pasul 3) Ingineria caracteristicilor

Reformarea educației

Din graficul de mai sus, puteți vedea că variabila educație are 16 niveluri. Acest lucru este substanțial, iar unele niveluri au un număr relativ scăzut de observații. Dacă doriți să îmbunătățiți cantitatea de informații pe care o puteți obține din această variabilă, o puteți transforma la un nivel superior. Și anume, creezi grupuri mai mari cu nivel similar de educație. De exemplu, nivelul scăzut de educație va fi transformat în abandon școlar. Nivelurile superioare de educație vor fi schimbate în master.

Iată detaliul:

Nivel vechi Nivel nou
Preşcolar renunța
10 Renunța
11 Renunța
12 Renunța
1st-4th Renunța
5th-6th Renunța
7th-8th Renunța
9 Renunța
HS-Grad HighGrad
O facultate Comunitate
Conf.-acdm Comunitate
Conf.-voc Comunitate
burlaci burlaci
masterat masterat
Prof-scoala masterat
Doctorat Dr.
recast_data <- data_adult_rescale % > %
	select(-X) % > %
	mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",
    ifelse(education == "Bachelors", "Bachelors",
        ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))

Explicarea codului

  • Folosim verbul mutare din biblioteca dplyr. Schimbăm valorile educației cu afirmația ifelse

În tabelul de mai jos, creați o statistică rezumativă pentru a vedea, în medie, câți ani de educație (valoarea z) sunt necesari pentru a ajunge la licență, master sau doctorat.

recast_data % > %
	group_by(education) % > %
	summarize(average_educ_year = mean(educational.num),
		count = n()) % > %
	arrange(average_educ_year)

ieșire:

## # A tibble: 6 x 3
## education average_educ_year count			
##      <fctr>             <dbl> <int>
## 1   dropout       -1.76147258  5712
## 2  HighGrad       -0.43998868 14803
## 3 Community        0.09561361 13407
## 4 Bachelors        1.12216282  7720
## 5    Master        1.60337381  3338
## 6       PhD        2.29377644   557

Reformare Marital-status

De asemenea, este posibil să se creeze niveluri inferioare pentru starea civilă. În următorul cod modificați nivelul după cum urmează:

Nivel vechi Nivel nou
Niciodata casatorit Necasatorit
Căsătorit-soț-absent Necasatorit
Căsătorit-AF-soț Căsătorit
Căsătorit-civ-soț
Separat Separat
Divorţat
văduvele Văduvă
# Change level marry
recast_data <- recast_data % > %
	mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))

Puteți verifica numărul de persoane din fiecare grup.

table(recast_data$marital.status)

ieșire:

## ##     Married Not_married   Separated       Widow
##       21165       15359        7727        1286

Pasul 4) Statistică rezumată

Este timpul să verificăm câteva statistici despre variabilele noastre țintă. În graficul de mai jos, numărați procentul de persoane care câștigă mai mult de 50, având în vedere sexul lor.

# Plot gender income
ggplot(recast_data, aes(x = gender, fill = income)) +
    geom_bar(position = "fill") +
    theme_classic()

ieșire:

Statistica rezumată

Apoi, verificați dacă originea individului îi afectează câștigurile.

# Plot origin income
ggplot(recast_data, aes(x = race, fill = income)) +
    geom_bar(position = "fill") +
    theme_classic() +
    theme(axis.text.x = element_text(angle = 90))

ieșire:

Statistica rezumată

Numărul de ore de muncă în funcție de sex.

# box plot gender working time
ggplot(recast_data, aes(x = gender, y = hours.per.week)) +
    geom_boxplot() +
    stat_summary(fun.y = mean,
        geom = "point",
        size = 3,
        color = "steelblue") +
    theme_classic()

ieșire:

Statistica rezumată

Box plot confirmă faptul că distribuția timpului de lucru se potrivește diferitelor grupuri. În box plot, ambele sexe nu au observații omogene.

Puteți verifica densitatea timpului de lucru săptămânal în funcție de tipul de studii. Distribuțiile au multe alegeri distincte. Poate fi explicat prin tipul de contract din SUA.

# Plot distribution working time by education
ggplot(recast_data, aes(x = hours.per.week)) +
    geom_density(aes(color = education), alpha = 0.5) +
    theme_classic()

Explicarea codului

  • ggplot(recast_data, aes(x= hours.per.week)): Un grafic de densitate necesită doar o variabilă
  • geom_density(aes(culoare = educație), alpha =0.5): obiectul geometric pentru a controla densitatea

ieșire:

Statistica rezumată

Pentru a vă confirma gândurile, puteți efectua un singur sens test ANOVA:

anova <- aov(hours.per.week~education, recast_data)
summary(anova)

ieșire:

##                Df Sum Sq Mean Sq F value Pr(>F)    
## education       5   1552  310.31   321.2 <2e-16 ***
## Residuals   45531  43984    0.97                   
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Testul ANOVA confirmă diferența de medie între grupuri.

Non-liniaritatea

Înainte de a rula modelul, puteți vedea dacă numărul de ore lucrate este legat de vârstă.

library(ggplot2)
ggplot(recast_data, aes(x = age, y = hours.per.week)) +
    geom_point(aes(color = income),
        size = 0.5) +
    stat_smooth(method = 'lm',
        formula = y~poly(x, 2),
        se = TRUE,
        aes(color = income)) +
    theme_classic()

Explicarea codului

  • ggplot(recast_data, aes(x = age, y = hours.per.week)): setați estetica graficului
  • geom_point(aes(culoare=venit), dimensiune =0.5): Construiți graficul cu puncte
  • stat_smooth(): Adăugați linia de tendință cu următoarele argumente:
    • method='lm': Trasează valoarea ajustată dacă regresie liniara
    • formula = y~poly(x,2): Potriviți o regresie polinomială
    • se = TRUE: Adăugați eroarea standard
    • aes(culoare= venit): sparge modelul după venit

ieșire:

Non-liniaritatea

Pe scurt, puteți testa termenii de interacțiune din model pentru a detecta efectul de neliniaritate dintre timpul de lucru săptămânal și alte caracteristici. Este important să se detecteze în ce condiții diferă timpul de lucru.

Corelație

Următoarea verificare este de a vizualiza corelația dintre variabile. Convertiți tipul de nivel de factor în numeric, astfel încât să puteți reprezenta o hartă termică care conține coeficientul de corelație calculat cu metoda Spearman.

library(GGally)
# Convert data to numeric
corr <- data.frame(lapply(recast_data, as.integer))
# Plot the graphggcorr(corr,
    method = c("pairwise", "spearman"),
    nbreaks = 6,
    hjust = 0.8,
    label = TRUE,
    label_size = 3,
    color = "grey50")

Explicarea codului

  • data.frame(lapply(recast_data,as.integer)): Convertiți datele în numere
  • ggcorr() trasează harta termică cu următoarele argumente:
    • metoda: Metodă de calcul a corelației
    • nbreaks = 6: Număr de pauze
    • hjust = 0.8: Poziția de control a numelui variabilei în grafic
    • label = TRUE: Adăugați etichete în centrul ferestrelor
    • label_size = 3: etichete de dimensiune
    • culoare = „grey50”): culoarea etichetei

ieșire:

Corelație

Pasul 5) Set de tren/test

Orice supravegheat masina de învățare sarcina necesită împărțirea datelor între un set de tren și un set de testare. Puteți utiliza „funcția” pe care ați creat-o în celelalte tutoriale de învățare supravegheată pentru a crea un set de tren/test.

set.seed(1234)
create_train_test <- function(data, size = 0.8, train = TRUE) {
    n_row = nrow(data)
    total_row = size * n_row
    train_sample <- 1: total_row
    if (train == TRUE) {
        return (data[train_sample, ])
    } else {
        return (data[-train_sample, ])
    }
}
data_train <- create_train_test(recast_data, 0.8, train = TRUE)
data_test <- create_train_test(recast_data, 0.8, train = FALSE)
dim(data_train)

ieșire:

## [1] 36429     9
dim(data_test)

ieșire:

## [1] 9108    9

Pasul 6) Construiți modelul

Pentru a vedea cum funcționează algoritmul, utilizați pachetul glm(). The Model liniar generalizat este o colecție de modele. Sintaxa de bază este:

glm(formula, data=data, family=linkfunction()
Argument:
- formula:  Equation used to fit the model- data: dataset used
- Family:     - binomial: (link = "logit")			
- gaussian: (link = "identity")			
- Gamma:    (link = "inverse")			
- inverse.gaussian: (link = "1/mu^2")			
- poisson:  (link = "log")			
- quasi:    (link = "identity", variance = "constant")			
- quasibinomial:    (link = "logit")			
- quasipoisson: (link = "log")	

Sunteți gata să estimați modelul logistic pentru a împărți nivelul veniturilor între un set de caracteristici.

formula <- income~.
logit <- glm(formula, data = data_train, family = 'binomial')
summary(logit)

Explicarea codului

  • formula <- venit ~ .: Creați modelul care să se potrivească
  • logit <- glm(formula, data = data_train, family = 'binom'): Potriviți un model logistic (family = 'binom') cu datele data_train.
  • summary(logit): Imprimați rezumatul modelului

ieșire:

## 
## Call:
## glm(formula = formula, family = "binomial", data = data_train)
## ## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.6456  -0.5858  -0.2609  -0.0651   3.1982  
## 
## Coefficients:
##                           Estimate Std. Error z value Pr(>|z|)    
## (Intercept)                0.07882    0.21726   0.363  0.71675    
## age                        0.41119    0.01857  22.146  < 2e-16 ***
## workclassLocal-gov        -0.64018    0.09396  -6.813 9.54e-12 ***
## workclassPrivate          -0.53542    0.07886  -6.789 1.13e-11 ***
## workclassSelf-emp-inc     -0.07733    0.10350  -0.747  0.45499    
## workclassSelf-emp-not-inc -1.09052    0.09140 -11.931  < 2e-16 ***
## workclassState-gov        -0.80562    0.10617  -7.588 3.25e-14 ***
## workclassWithout-pay      -1.09765    0.86787  -1.265  0.20596    
## educationCommunity        -0.44436    0.08267  -5.375 7.66e-08 ***
## educationHighGrad         -0.67613    0.11827  -5.717 1.08e-08 ***
## educationMaster            0.35651    0.06780   5.258 1.46e-07 ***
## educationPhD               0.46995    0.15772   2.980  0.00289 ** 
## educationdropout          -1.04974    0.21280  -4.933 8.10e-07 ***
## educational.num            0.56908    0.07063   8.057 7.84e-16 ***
## marital.statusNot_married -2.50346    0.05113 -48.966  < 2e-16 ***
## marital.statusSeparated   -2.16177    0.05425 -39.846  < 2e-16 ***
## marital.statusWidow       -2.22707    0.12522 -17.785  < 2e-16 ***
## raceAsian-Pac-Islander     0.08359    0.20344   0.411  0.68117    
## raceBlack                  0.07188    0.19330   0.372  0.71001    
## raceOther                  0.01370    0.27695   0.049  0.96054    
## raceWhite                  0.34830    0.18441   1.889  0.05894 .  
## genderMale                 0.08596    0.04289   2.004  0.04506 *  
## hours.per.week             0.41942    0.01748  23.998  < 2e-16 ***
## ---## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## ## (Dispersion parameter for binomial family taken to be 1)
## ##     Null deviance: 40601  on 36428  degrees of freedom
## Residual deviance: 27041  on 36406  degrees of freedom
## AIC: 27087
## 
## Number of Fisher Scoring iterations: 6

Rezumatul modelului nostru dezvăluie informații interesante. Performanța unei regresii logistice este evaluată cu valori cheie specifice.

  • AIC (Akaike Information Criteria): Acesta este echivalentul R2 în regresie logistică. Măsoară potrivirea atunci când se aplică o penalizare numărului de parametri. Mai mic AIC valorile indică faptul că modelul este mai aproape de adevăr.
  • Devianța nulă: se potrivește modelului numai cu interceptarea. Gradul de libertate este n-1. O putem interpreta ca o valoare Chi-pătrat (valoare ajustată diferită de testarea ipotezei valorii reale).
  • Devianța reziduală: Model cu toate variabilele. De asemenea, este interpretată ca o testare a ipotezei Chi-pătrat.
  • Numărul de iterații Fisher Scoring: numărul de iterații înainte de convergență.

Ieșirea funcției glm() este stocată într-o listă. Codul de mai jos arată toate elementele disponibile în variabila logit pe care am construit-o pentru a evalua regresia logistică.

# Lista este foarte lungă, imprimați doar primele trei elemente

lapply(logit, class)[1:3]

ieșire:

## $coefficients
## [1] "numeric"
## 
## $residuals
## [1] "numeric"
## 
## $fitted.values
## [1] "numeric"

Fiecare valoare poate fi extrasă cu semnul $ urmat de numele valorilor. De exemplu, ați stocat modelul ca logit. Pentru a extrage criteriile AIC, utilizați:

logit$aic

ieșire:

## [1] 27086.65

Pasul 7) Evaluați performanța modelului

Matricea confuziei

matrice de confuzie este o alegere mai bună pentru a evalua performanța clasificării în comparație cu diferitele valori pe care le-ați văzut anterior. Ideea generală este de a număra de câte ori instanțele adevărate sunt clasificate ca fiind false.

Matricea confuziei

Pentru a calcula matricea de confuzie, trebuie mai întâi să aveți un set de predicții, astfel încât acestea să poată fi comparate cu țintele reale.

predict <- predict(logit, data_test, type = 'response')
# confusion matrix
table_mat <- table(data_test$income, predict > 0.5)
table_mat

Explicarea codului

  • predict(logit,data_test, type = 'răspuns'): Calculați predicția pe setul de testare. Setați tip = „răspuns” pentru a calcula probabilitatea de răspuns.
  • table(data_test$income, predict > 0.5): Calculați matricea de confuzie. prezice > 0.5 înseamnă că returnează 1 dacă probabilitățile prezise sunt peste 0.5, altfel 0.

ieșire:

##        
##         FALSE TRUE
##   <=50K  6310  495
##   >50K   1074 1229	

Fiecare rând dintr-o matrice de confuzie reprezintă o țintă reală, în timp ce fiecare coloană reprezintă o țintă estimată. Primul rând al acestei matrice consideră venitul mai mic de 50k (clasa Fals): 6241 au fost clasificați corect ca persoane cu venituri mai mici de 50k (Adevărat negativ), în timp ce cel rămas a fost clasificat greșit ca peste 50k (Fals pozitiv). Al doilea rând ia în considerare venitul peste 50k, clasa pozitivă a fost 1229 (Adevărat pozitiv), in timp ce Adevărat negativ a fost 1074.

Puteți calcula modelul precizie prin însumarea adevăratului pozitiv + adevăratul negativ peste observația totală

Matricea confuziei

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
accuracy_Test

Explicarea codului

  • sum(diag(table_mat)): Suma diagonalei
  • sum(table_mat): Suma matricei.

ieșire:

## [1] 0.8277339

Modelul pare să sufere de o problemă, supraestimează numărul de fals negative. Aceasta se numește paradoxul testului de precizie. Am afirmat că acuratețea este raportul dintre predicțiile corecte și numărul total de cazuri. Putem avea o precizie relativ mare, dar un model inutil. Se întâmplă când există o clasă dominantă. Dacă priviți înapoi la matricea de confuzie, puteți vedea că majoritatea cazurilor sunt clasificate drept negative adevărate. Imaginați-vă acum, modelul a clasificat toate clasele drept negative (adică mai mici de 50k). Ai avea o precizie de 75 la sută (6718/6718+2257). Modelul tău are performanțe mai bune, dar se luptă să distingă adevăratul pozitiv de adevăratul negativ.

Într-o astfel de situație, este de preferat să aveți o metrică mai concisă. Ne putem uita la:

  • Precizie=TP/(TP+FP)
  • Recall=TP/(TP+FN)

Precizie vs rechemare

Precizie se uită la acuratețea predicției pozitive. Rechemare este raportul instanțelor pozitive care sunt detectate corect de către clasificator;

Puteți construi două funcții pentru a calcula aceste două valori

  1. Precizie de construcție
precision <- function(matrix) {
	# True positive
    tp <- matrix[2, 2]
	# false positive
    fp <- matrix[1, 2]
    return (tp / (tp + fp))
}

Explicarea codului

  • mat[1,1]: Returnează prima celulă a primei coloane a cadrului de date, adică adevăratul pozitiv
  • mat[1,2]; Returnează prima celulă din a doua coloană a cadrului de date, adică fals pozitiv
recall <- function(matrix) {
# true positive
    tp <- matrix[2, 2]# false positive
    fn <- matrix[2, 1]
    return (tp / (tp + fn))
}

Explicarea codului

  • mat[1,1]: Returnează prima celulă a primei coloane a cadrului de date, adică adevăratul pozitiv
  • mat[2,1]; Returnați a doua celulă a primei coloane a cadrului de date, adică negativul fals

Vă puteți testa funcțiile

prec <- precision(table_mat)
prec
rec <- recall(table_mat)
rec

ieșire:

## [1] 0.712877
## [2] 0.5336518

Când modelul spune că este un individ de peste 50k, este corect doar în 54% din cazuri și poate revendica persoane cu peste 50k în 72% din cazuri.

Puteți crea fișierul Precizie vs rechemare scor bazat pe precizie și reamintire. The Precizie vs rechemare este o medie armonică a acestor două metrici, ceea ce înseamnă că acordă mai multă pondere valorilor mai mici.

Precizie vs rechemare

f1 <- 2 * ((prec * rec) / (prec + rec))
f1

ieșire:

## [1] 0.6103799

Precizie vs rechemare compromis

Este imposibil să aveți atât o precizie ridicată, cât și o reamintire ridicată.

Dacă creștem precizia, individul corect va fi mai bine prezis, dar am rata foarte multe dintre ele (reamintire mai mică). În unele situații, preferăm o precizie mai mare decât reamintirea. Există o relație concavă între precizie și reamintire.

  • Imaginați-vă, trebuie să preziceți dacă un pacient are o boală. Vrei să fii cât mai precis posibil.
  • Dacă trebuie să detectați potențialele persoane frauduloase de pe stradă prin recunoașterea facială, ar fi mai bine să prindeți multe persoane etichetate drept frauduloase, chiar dacă precizia este scăzută. Poliția va putea elibera persoana nefrauduloasă.

Curba ROC

Receptor Operating Caracteristică curba este un alt instrument comun utilizat cu clasificarea binară. Este foarte asemănătoare cu curba de precizie/rechemare, dar în loc de a reprezenta un grafic precizie versus reamintire, curba ROC arată rata pozitivă adevărată (adică, reamintirea) față de rata fals pozitivă. Rata fals pozitive este raportul dintre cazurile negative care sunt clasificate incorect drept pozitive. Este egal cu unu minus rata negativă adevărată. Rata adevărată negativă se mai numește specificitate. De aici curba ROC grafică sensibilitate (reamintire) versus 1-specificitate

Pentru a trasa curba ROC, trebuie să instalăm o bibliotecă numită RORC. Putem găsi în conda bibliotecă. Puteți introduce codul:

conda install -cr r-rocr –da

Putem reprezenta ROC cu funcțiile prediction() și performance().

library(ROCR)
ROCRpred <- prediction(predict, data_test$income)
ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')
plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))

Explicarea codului

  • prediction(predict, data_test$income): biblioteca ROCR trebuie să creeze un obiect de predicție pentru a transforma datele de intrare
  • performanță(ROCRpred, 'tpr','fpr'): returnează cele două combinații de produs în grafic. Aici, tpr și fpr sunt construite. Pentru a reprezenta precizia și a reaminti împreună, folosiți „prec”, „rec”.

ieșire:

Curba ROC

Pas 8) Îmbunătățiți modelul

Puteți încerca să adăugați non-liniaritate modelului cu interacțiunea dintre

  • vârsta și ore.pe.săptămână
  • sex și ore.pe.săptămână.

Trebuie să utilizați testul de scor pentru a compara ambele modele

formula_2 <- income~age: hours.per.week + gender: hours.per.week + .
logit_2 <- glm(formula_2, data = data_train, family = 'binomial')
predict_2 <- predict(logit_2, data_test, type = 'response')
table_mat_2 <- table(data_test$income, predict_2 > 0.5)
precision_2 <- precision(table_mat_2)
recall_2 <- recall(table_mat_2)
f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))
f1_2

ieșire:

## [1] 0.6109181

Scorul este puțin mai mare decât cel anterior. Puteți continua să lucrați la date și să încercați să bateți scorul.

Rezumat

Putem rezuma funcția de antrenare a unei regresii logistice în tabelul de mai jos:

Pachet Obiectiv Funcţie Argument
- Creați un set de date tren/test create_train_set() date, dimensiune, tren
glm Antrenează un model liniar generalizat glm() formula, date, familie*
glm Rezumă modelul rezumat() model montat
de bază Faceți predicții prezice() model adaptat, set de date, tip = „răspuns”
de bază Creați o matrice de confuzie masa() y, prezice()
de bază Creați un scor de precizie suma(diag(tabel())/sum(tabel()
ROCR Creați ROC: Pasul 1 Creați predicție predicție () prezice(), y
ROCR Creați ROC: Pasul 2 Creați performanță performanţă() predicție(), „tpr”, „fpr”
ROCR Creați ROC: Pasul 3 Trasați graficul complot() performanţă()

Celălalt GLM tipuri de modele sunt:

– binom: (link = „logit”)

– gaussian: (link = „identitate”)

– Gamma: (link = „invers”)

– invers.gaussian: (link = „1/mu^2”)

– poisson: (link = „jurnal”)

– cvasi: (link = „identitate”, varianță = „constant”)

– cvasibinom: (link = „logit”)

– cvasipoisson: (link = „log”)