GLM i R: Generalisert lineær modell med eksempel

Hva er logistisk regresjon?

Logistisk regresjon brukes til å forutsi en klasse, dvs. en sannsynlighet. Logistisk regresjon kan forutsi et binært utfall nøyaktig.

Tenk deg at du vil forutsi om et lån blir nektet/akseptert basert på mange attributter. Den logistiske regresjonen er av formen 0/1. y = 0 hvis et lån blir avvist, y = 1 hvis det godtas.

En logistisk regresjonsmodell skiller seg fra lineær regresjonsmodell på to måter.

  • Først og fremst aksepterer den logistiske regresjonen bare dikotom (binær) input som en avhengig variabel (dvs. en vektor på 0 og 1).
  • For det andre måles utfallet av følgende sannsynlighetskoblingsfunksjon kalt sigmoid på grunn av sin S-form.:

Logistisk regresjon

Utgangen til funksjonen er alltid mellom 0 og 1. Sjekk bildet nedenfor

Logistisk regresjon

Sigmoid-funksjonen returnerer verdier fra 0 til 1. For klassifiseringsoppgaven trenger vi en diskret utgang på 0 eller 1.

For å konvertere en kontinuerlig flyt til diskret verdi, kan vi sette en beslutningsgrense til 0.5. Alle verdier over denne terskelen er klassifisert som 1

Logistisk regresjon

Hvordan lage Generalized Liner Model (GLM)

La oss bruke voksen datasett for å illustrere logistisk regresjon. "Voksen" er et flott datasett for klassifiseringsoppgaven. Målet er å forutsi om en persons årlige inntekt i dollar vil overstige 50.000. Datasettet inneholder 46,033 XNUMX observasjoner og ti funksjoner:

  • alder: individets alder. Numerisk
  • utdanning: Utdanningsnivået til den enkelte. Faktor.
  • marital.status: Maripersonens individuelle status. Faktor dvs. aldri gift, gift borgerlig ektefelle, …
  • kjønn: Individets kjønn. Faktor, dvs. mann eller kvinne
  • inntekt: Target variabel. Inntekt over eller under 50K. Faktor dvs. >50K, <=50K

blant andre

library(dplyr)
data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")
glimpse(data_adult)

Utgang:

Observations: 48,842
Variables: 10
$ x               <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,...
$ age             <int> 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26...
$ workclass       <fctr> Private, Private, Local-gov, Private, ?, Private,...
$ education       <fctr> 11th, HS-grad, Assoc-acdm, Some-college, Some-col...
$ educational.num <int> 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,...
$ marital.status  <fctr> Never-married, Married-civ-spouse, Married-civ-sp...
$ race            <fctr> Black, White, White, Black, White, White, Black, ...
$ gender          <fctr> Male, Male, Male, Male, Female, Male, Male, Male,...
$ hours.per.week  <int> 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39...
$ income          <fctr> <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5...

Vi vil gå frem som følger:

  • Trinn 1: Sjekk kontinuerlige variabler
  • Trinn 2: Sjekk faktorvariabler
  • Trinn 3: Funksjonsteknikk
  • Trinn 4: Sammendragsstatistikk
  • Trinn 5: Tren/testsett
  • Trinn 6: Bygg modellen
  • Trinn 7: Vurder ytelsen til modellen
  • trinn 8: Forbedre modellen

Din oppgave er å forutsi hvilken person som vil ha en inntekt høyere enn 50K.

I denne opplæringen vil hvert trinn bli detaljert for å utføre en analyse på et ekte datasett.

Trinn 1) Sjekk kontinuerlige variabler

I det første trinnet kan du se fordelingen av de kontinuerlige variablene.

continuous <-select_if(data_adult, is.numeric)
summary(continuous)

Kode Forklaring

  • kontinuerlig <- select_if(data_adult, is.numeric): Bruk funksjonen select_if() fra dplyr-biblioteket for å velge bare de numeriske kolonnene
  • summary (kontinuerlig): Skriv ut sammendragsstatistikken

Utgang:

##        X              age        educational.num hours.per.week 
##  Min.   :    1   Min.   :17.00   Min.   : 1.00   Min.   : 1.00  
##  1st Qu.:11509   1st Qu.:28.00   1st Qu.: 9.00   1st Qu.:40.00  
##  Median :23017   Median :37.00   Median :10.00   Median :40.00  
##  Mean   :23017   Mean   :38.56   Mean   :10.13   Mean   :40.95  
##  3rd Qu.:34525   3rd Qu.:47.00   3rd Qu.:13.00   3rd Qu.:45.00  
##  Max.   :46033   Max.   :90.00   Max.   :16.00   Max.   :99.00	

Fra tabellen ovenfor kan du se at dataene har helt forskjellige skalaer og timer.per.uker har store uteliggere (dvs. se på siste kvartil og maksimal verdi).

Du kan håndtere det ved å følge to trinn:

  • 1: Plott fordelingen av timer.per.uke
  • 2: Standardiser de kontinuerlige variablene
  1. Tegn fordelingen

La oss se nærmere på fordelingen av timer.per.uke

# Histogram with kernel density curve
library(ggplot2)
ggplot(continuous, aes(x = hours.per.week)) +
    geom_density(alpha = .2, fill = "#FF6666")

Utgang:

Sjekk kontinuerlige variabler

Variabelen har mange uteliggere og ikke veldefinert fordeling. Du kan delvis takle dette problemet ved å slette de øverste 0.01 prosentene av timene per uke.

Grunnleggende syntaks for kvantil:

quantile(variable, percentile)
arguments:
-variable:  Select the variable in the data frame to compute the percentile
-percentile:  Can be a single value between 0 and 1 or multiple value. If multiple, use this format:  `c(A,B,C, ...)
- `A`,`B`,`C` and `...` are all integer from 0 to 1.

Vi beregner den øverste 2 prosent persentilen

top_one_percent <- quantile(data_adult$hours.per.week, .99)
top_one_percent

Kode Forklaring

  • quantile(data_adult$hours.per.week, .99): Beregn verdien av 99 prosent av arbeidstiden

Utgang:

## 99% 
##  80

98 prosent av befolkningen jobber under 80 timer per uke.

Du kan slippe observasjonene over denne terskelen. Du bruker filteret fra dplyr bibliotek.

data_adult_drop <-data_adult %>%
filter(hours.per.week<top_one_percent)
dim(data_adult_drop)

Utgang:

## [1] 45537    10
  1. Standardiser de kontinuerlige variablene

Du kan standardisere hver kolonne for å forbedre ytelsen fordi dataene dine ikke har samme skala. Du kan bruke funksjonen mutate_if fra dplyr-biblioteket. Den grunnleggende syntaksen er:

mutate_if(df, condition, funs(function))
arguments:
-`df`: Data frame used to compute the function
- `condition`: Statement used. Do not use parenthesis
- funs(function):  Return the function to apply. Do not use parenthesis for the function

Du kan standardisere de numeriske kolonnene som følger:

data_adult_rescale <- data_adult_drop % > %
	mutate_if(is.numeric, funs(as.numeric(scale(.))))
head(data_adult_rescale)

Kode Forklaring

  • mutate_if(is.numeric, funs(scale)): Betingelsen er kun numerisk kolonne og funksjonen er skala

Utgang:

##           X         age        workclass    education educational.num
## 1 -1.732680 -1.02325949          Private         11th     -1.22106443
## 2 -1.732605 -0.03969284          Private      HS-grad     -0.43998868
## 3 -1.732530 -0.79628257        Local-gov   Assoc-acdm      0.73162494
## 4 -1.732455  0.41426100          Private Some-college     -0.04945081
## 5 -1.732379 -0.34232873          Private         10th     -1.61160231
## 6 -1.732304  1.85178149 Self-emp-not-inc  Prof-school      1.90323857
##       marital.status  race gender hours.per.week income
## 1      Never-married Black   Male    -0.03995944  <=50K
## 2 Married-civ-spouse White   Male     0.86863037  <=50K
## 3 Married-civ-spouse White   Male    -0.03995944   >50K
## 4 Married-civ-spouse Black   Male    -0.03995944   >50K
## 5      Never-married White   Male    -0.94854924  <=50K
## 6 Married-civ-spouse White   Male    -0.76683128   >50K

Trinn 2) Sjekk faktorvariabler

Dette trinnet har to mål:

  • Sjekk nivået i hver kategorisk kolonne
  • Definer nye nivåer

Vi deler dette trinnet inn i tre deler:

  • Velg de kategoriske kolonnene
  • Lagre søylediagrammet for hver kolonne i en liste
  • Skriv ut grafene

Vi kan velge faktorkolonnene med koden nedenfor:

# Select categorical column
factor <- data.frame(select_if(data_adult_rescale, is.factor))
	ncol(factor)

Kode Forklaring

  • data.frame(select_if(data_adult, is.factor)): Vi lagrer faktorkolonnene i faktor i en datarammetype. Biblioteket ggplot2 krever et datarammeobjekt.

Utgang:

## [1] 6

Datasettet inneholder 6 kategoriske variabler

Det andre trinnet er mer dyktig. Du vil plotte et stolpediagram for hver kolonne i datarammefaktoren. Det er mer praktisk å automatisere prosessen, spesielt i situasjoner hvor det er mange kolonner.

library(ggplot2)
# Create graph for each column
graph <- lapply(names(factor),
    function(x) 
	ggplot(factor, aes(get(x))) +
		geom_bar() +
		theme(axis.text.x = element_text(angle = 90)))

Kode Forklaring

  • lapply(): Bruk funksjonen lapply() for å sende en funksjon i alle kolonnene i datasettet. Du lagrer utdataene i en liste
  • funksjon(x): Funksjonen vil bli behandlet for hver x. Her er x kolonnene
  • ggplot(faktor, aes(get(x))) + geom_bar()+ theme(axis.text.x = element_text(angle = 90)): Lag et søylediagram for hvert x-element. Merk at for å returnere x som en kolonne, må du inkludere den i get()

Det siste trinnet er relativt enkelt. Du vil skrive ut de 6 grafene.

# Print the graph
graph

Utgang:

## [[1]]

Sjekk faktorvariabler

## ## [[2]]

Sjekk faktorvariabler

## ## [[3]]

Sjekk faktorvariabler

## ## [[4]]

Sjekk faktorvariabler

## ## [[5]]

Sjekk faktorvariabler

## ## [[6]]

Sjekk faktorvariabler

Merk: Bruk neste-knappen for å navigere til neste graf

Sjekk faktorvariabler

Trinn 3) Funksjonsteknikk

Omarbeidet utdanning

Fra grafen over kan du se at variabelen utdanning har 16 nivåer. Dette er betydelig, og enkelte nivåer har et relativt lavt antall observasjoner. Hvis du ønsker å forbedre mengden informasjon du kan få fra denne variabelen, kan du omforme den til et høyere nivå. Man lager nemlig større grupper med tilsvarende utdanningsnivå. For eksempel vil lavt utdanningsnivå konverteres til frafall. Høyere utdanningsnivå endres til master.

Her er detaljen:

Gammelt nivå Nytt nivå
Barnehage frafall
10. dropout
11. dropout
12. dropout
1.-4 dropout
5th-6th dropout
7th-8th dropout
9. dropout
HS-Grad HighGrad
En eller annen høyskole Samfunn
Assoc-acdm Samfunn
Assoc-voc Samfunn
ungkarer ungkarer
mestere mestere
Prof-skole mestere
Doktorgrad PhD
recast_data <- data_adult_rescale % > %
	select(-X) % > %
	mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",
    ifelse(education == "Bachelors", "Bachelors",
        ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))

Kode Forklaring

  • Vi bruker verbet mutere fra dplyr library. Vi endrer verdiene til utdanning med utsagnet ifelse

I tabellen nedenfor lager du en oppsummerende statistikk for å se i gjennomsnitt hvor mange års utdanning (z-verdi) det tar å nå Bachelor, Master eller PhD.

recast_data % > %
	group_by(education) % > %
	summarize(average_educ_year = mean(educational.num),
		count = n()) % > %
	arrange(average_educ_year)

Utgang:

## # A tibble: 6 x 3
## education average_educ_year count			
##      <fctr>             <dbl> <int>
## 1   dropout       -1.76147258  5712
## 2  HighGrad       -0.43998868 14803
## 3 Community        0.09561361 13407
## 4 Bachelors        1.12216282  7720
## 5    Master        1.60337381  3338
## 6       PhD        2.29377644   557

Omarbeide Marital-status

Det er også mulig å lage lavere nivåer for sivilstanden. I følgende kode endrer du nivået som følger:

Gammelt nivå Nytt nivå
Aldri gift Ikke-gift
Gift-ektefelle-fraværende Ikke-gift
Gift-AF-ektefelle Gift
Gift-civ-ektefelle
Separert Separert
Skilt
enker Enke
# Change level marry
recast_data <- recast_data % > %
	mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))

Du kan sjekke antall individer innenfor hver gruppe.

table(recast_data$marital.status)

Utgang:

## ##     Married Not_married   Separated       Widow
##       21165       15359        7727        1286

Trinn 4) Oppsummeringsstatistikk

Det er på tide å sjekke litt statistikk om målvariablene våre. I grafen nedenfor teller du prosentandelen av individer som tjener mer enn 50 XNUMX gitt kjønn.

# Plot gender income
ggplot(recast_data, aes(x = gender, fill = income)) +
    geom_bar(position = "fill") +
    theme_classic()

Utgang:

Sammendragsstatistikk

Deretter må du sjekke om opprinnelsen til individet påvirker inntjeningen.

# Plot origin income
ggplot(recast_data, aes(x = race, fill = income)) +
    geom_bar(position = "fill") +
    theme_classic() +
    theme(axis.text.x = element_text(angle = 90))

Utgang:

Sammendragsstatistikk

Antall timer arbeid etter kjønn.

# box plot gender working time
ggplot(recast_data, aes(x = gender, y = hours.per.week)) +
    geom_boxplot() +
    stat_summary(fun.y = mean,
        geom = "point",
        size = 3,
        color = "steelblue") +
    theme_classic()

Utgang:

Sammendragsstatistikk

Boksplottet bekrefter at fordelingen av arbeidstid passer til ulike grupper. I boksplotten har ikke begge kjønn homogene observasjoner.

Du kan sjekke tettheten av den ukentlige arbeidstiden etter type utdanning. Distribusjonene har mange forskjellige valg. Det kan trolig forklares med typen kontrakt i USA.

# Plot distribution working time by education
ggplot(recast_data, aes(x = hours.per.week)) +
    geom_density(aes(color = education), alpha = 0.5) +
    theme_classic()

Kode Forklaring

  • ggplot(recast_data, aes( x= hours.per.week)): Et tetthetsplott krever bare én variabel
  • geom_density(aes(color = education), alpha =0.5): Det geometriske objektet for å kontrollere tettheten

Utgang:

Sammendragsstatistikk

For å bekrefte tankene dine, kan du utføre en enveis ANOVA test:

anova <- aov(hours.per.week~education, recast_data)
summary(anova)

Utgang:

##                Df Sum Sq Mean Sq F value Pr(>F)    
## education       5   1552  310.31   321.2 <2e-16 ***
## Residuals   45531  43984    0.97                   
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

ANOVA-testen bekrefter forskjellen i gjennomsnitt mellom gruppene.

Ikke-linearitet

Før du kjører modellen kan du se om antall arbeidstimer er relatert til alder.

library(ggplot2)
ggplot(recast_data, aes(x = age, y = hours.per.week)) +
    geom_point(aes(color = income),
        size = 0.5) +
    stat_smooth(method = 'lm',
        formula = y~poly(x, 2),
        se = TRUE,
        aes(color = income)) +
    theme_classic()

Kode Forklaring

  • ggplot(recast_data, aes(x = alder, y = timer.per.week)): Angi estetikken til grafen
  • geom_point(aes(farge=inntekt), størrelse =0.5): Konstruer prikkplottet
  • stat_smooth(): Legg til trendlinjen med følgende argumenter:
    • method='lm': Plott den tilpassede verdien hvis lineær regresjon
    • formel = y~poly(x,2): Tilpass en polynomregresjon
    • se = TRUE: Legg til standardfeilen
    • aes(farge= inntekt): Bryt modellen etter inntekt

Utgang:

Ikke-linearitet

I et nøtteskall kan du teste interaksjonsbegreper i modellen for å fange opp ikke-linearitetseffekten mellom den ukentlige arbeidstiden og andre funksjoner. Det er viktig å oppdage under hvilke forhold arbeidstiden er forskjellig.

Korrelasjon

Neste kontroll er å visualisere korrelasjonen mellom variablene. Du konverterer faktornivåtypen til numerisk slik at du kan plotte et varmekart som inneholder korrelasjonskoeffisienten beregnet med Spearman-metoden.

library(GGally)
# Convert data to numeric
corr <- data.frame(lapply(recast_data, as.integer))
# Plot the graphggcorr(corr,
    method = c("pairwise", "spearman"),
    nbreaks = 6,
    hjust = 0.8,
    label = TRUE,
    label_size = 3,
    color = "grey50")

Kode Forklaring

  • data.frame(lapply(recast_data,as.integer)): Konverter data til numeriske
  • ggcorr() plott varmekartet med følgende argumenter:
    • metode: Metode for å beregne korrelasjonen
    • nbreaks = 6: Antall pauser
    • hjust = 0.8: Kontrollposisjon for variabelnavnet i plottet
    • label = TRUE: Legg til etiketter i midten av vinduene
    • label_size = 3: Størrelsesetiketter
    • farge = "grå50"): Farge på etiketten

Utgang:

Korrelasjon

Trinn 5) Tren/prøvesett

Eventuelle overvåket maskinlæring oppgave krever å dele dataene mellom et togsett og et testsett. Du kan bruke "funksjonen" du opprettet i de andre veiledede læringsopplæringene for å lage et tog-/testsett.

set.seed(1234)
create_train_test <- function(data, size = 0.8, train = TRUE) {
    n_row = nrow(data)
    total_row = size * n_row
    train_sample <- 1: total_row
    if (train == TRUE) {
        return (data[train_sample, ])
    } else {
        return (data[-train_sample, ])
    }
}
data_train <- create_train_test(recast_data, 0.8, train = TRUE)
data_test <- create_train_test(recast_data, 0.8, train = FALSE)
dim(data_train)

Utgang:

## [1] 36429     9
dim(data_test)

Utgang:

## [1] 9108    9

Trinn 6) Bygg modellen

For å se hvordan algoritmen fungerer, bruker du glm()-pakken. De Generalisert lineær modell er en samling av modeller. Den grunnleggende syntaksen er:

glm(formula, data=data, family=linkfunction()
Argument:
- formula:  Equation used to fit the model- data: dataset used
- Family:     - binomial: (link = "logit")			
- gaussian: (link = "identity")			
- Gamma:    (link = "inverse")			
- inverse.gaussian: (link = "1/mu^2")			
- poisson:  (link = "log")			
- quasi:    (link = "identity", variance = "constant")			
- quasibinomial:    (link = "logit")			
- quasipoisson: (link = "log")	

Du er klar til å estimere den logistiske modellen for å dele inntektsnivået mellom et sett med funksjoner.

formula <- income~.
logit <- glm(formula, data = data_train, family = 'binomial')
summary(logit)

Kode Forklaring

  • formel <- inntekt ~ .: Lag modellen som passer
  • logit <- glm(formel, data = datatog, familie = 'binomial'): Tilpass en logistisk modell (familie = 'binomial') med datatogdataene.
  • summary(logit): Skriv ut sammendraget av modellen

Utgang:

## 
## Call:
## glm(formula = formula, family = "binomial", data = data_train)
## ## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.6456  -0.5858  -0.2609  -0.0651   3.1982  
## 
## Coefficients:
##                           Estimate Std. Error z value Pr(>|z|)    
## (Intercept)                0.07882    0.21726   0.363  0.71675    
## age                        0.41119    0.01857  22.146  < 2e-16 ***
## workclassLocal-gov        -0.64018    0.09396  -6.813 9.54e-12 ***
## workclassPrivate          -0.53542    0.07886  -6.789 1.13e-11 ***
## workclassSelf-emp-inc     -0.07733    0.10350  -0.747  0.45499    
## workclassSelf-emp-not-inc -1.09052    0.09140 -11.931  < 2e-16 ***
## workclassState-gov        -0.80562    0.10617  -7.588 3.25e-14 ***
## workclassWithout-pay      -1.09765    0.86787  -1.265  0.20596    
## educationCommunity        -0.44436    0.08267  -5.375 7.66e-08 ***
## educationHighGrad         -0.67613    0.11827  -5.717 1.08e-08 ***
## educationMaster            0.35651    0.06780   5.258 1.46e-07 ***
## educationPhD               0.46995    0.15772   2.980  0.00289 ** 
## educationdropout          -1.04974    0.21280  -4.933 8.10e-07 ***
## educational.num            0.56908    0.07063   8.057 7.84e-16 ***
## marital.statusNot_married -2.50346    0.05113 -48.966  < 2e-16 ***
## marital.statusSeparated   -2.16177    0.05425 -39.846  < 2e-16 ***
## marital.statusWidow       -2.22707    0.12522 -17.785  < 2e-16 ***
## raceAsian-Pac-Islander     0.08359    0.20344   0.411  0.68117    
## raceBlack                  0.07188    0.19330   0.372  0.71001    
## raceOther                  0.01370    0.27695   0.049  0.96054    
## raceWhite                  0.34830    0.18441   1.889  0.05894 .  
## genderMale                 0.08596    0.04289   2.004  0.04506 *  
## hours.per.week             0.41942    0.01748  23.998  < 2e-16 ***
## ---## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## ## (Dispersion parameter for binomial family taken to be 1)
## ##     Null deviance: 40601  on 36428  degrees of freedom
## Residual deviance: 27041  on 36406  degrees of freedom
## AIC: 27087
## 
## Number of Fisher Scoring iterations: 6

Sammendraget av modellen vår avslører interessant informasjon. Ytelsen til en logistisk regresjon blir evaluert med spesifikke nøkkeltall.

  • AIC (Akaike Information Criteria): Dette tilsvarer R2 i logistisk regresjon. Den måler passformen når en straff pålegges antall parametere. Mindre AIC verdier indikerer at modellen er nærmere sannheten.
  • Nullavvik: Passer kun modellen med avskjæring. Frihetsgraden er n-1. Vi kan tolke det som en chi-kvadratverdi (tilpasset verdi forskjellig fra testingen av faktisk verdihypotese).
  • Residual Deviance: Modell med alle variablene. Det tolkes også som en chi-kvadrat hypotesetesting.
  • Antall Fisher Scoring-iterasjoner: Antall iterasjoner før konvergering.

Utdataene fra glm()-funksjonen er lagret i en liste. Koden nedenfor viser alle elementene som er tilgjengelige i logit-variabelen vi konstruerte for å evaluere den logistiske regresjonen.

# Listen er veldig lang, skriv ut kun de tre første elementene

lapply(logit, class)[1:3]

Utgang:

## $coefficients
## [1] "numeric"
## 
## $residuals
## [1] "numeric"
## 
## $fitted.values
## [1] "numeric"

Hver verdi kan trekkes ut med $-tegnet etterfulgt av navnet på beregningene. For eksempel lagret du modellen som logit. For å trekke ut AIC-kriteriene bruker du:

logit$aic

Utgang:

## [1] 27086.65

Trinn 7) Vurder ytelsen til modellen

Forvirringsmatrise

Ocuco forvirringsmatrise er et bedre valg for å evaluere klassifiseringsytelsen sammenlignet med de forskjellige beregningene du så før. Den generelle ideen er å telle antall ganger sanne forekomster klassifiseres er falske.

Forvirringsmatrise

For å beregne forvirringsmatrisen må du først ha et sett med spådommer slik at de kan sammenlignes med de faktiske målene.

predict <- predict(logit, data_test, type = 'response')
# confusion matrix
table_mat <- table(data_test$income, predict > 0.5)
table_mat

Kode Forklaring

  • predict(logit,data_test, type = 'response'): Beregn prediksjonen på testsettet. Sett type = 'respons' for å beregne responssannsynligheten.
  • table(data_test$income, predict > 0.5): Beregn forvirringsmatrisen. forutsi > 0.5 betyr at den returnerer 1 hvis de forutsagte sannsynlighetene er over 0.5, ellers 0.

Utgang:

##        
##         FALSE TRUE
##   <=50K  6310  495
##   >50K   1074 1229	

Hver rad i en forvirringsmatrise representerer et faktisk mål, mens hver kolonne representerer et forutsagt mål. Den første raden i denne matrisen vurderer inntekten lavere enn 50k (falsk-klassen): 6241 ble korrekt klassifisert som individer med inntekt lavere enn 50k (Ekte negativt), mens den gjenværende feilaktig ble klassifisert som over 50k (Falsk positiv). Den andre raden vurderer inntekten over 50k, den positive klassen var 1229 (Riktig positiv), mens Ekte negativt var 1074.

Du kan beregne modellen nøyaktighet ved å summere det sanne positive + det sanne negative over den totale observasjonen

Forvirringsmatrise

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
accuracy_Test

Kode Forklaring

  • sum(diag(table_mat)): Summen av diagonalen
  • sum(tabellmatte): Summen av matrisen.

Utgang:

## [1] 0.8277339

Modellen ser ut til å lide av ett problem, den overvurderer antallet falske negativer. Dette kalles paradoks for nøyaktighetstest. Vi uttalte at nøyaktigheten er forholdet mellom korrekte spådommer og det totale antallet tilfeller. Vi kan ha relativt høy nøyaktighet, men en ubrukelig modell. Det skjer når det er en dominerende klasse. Hvis du ser tilbake på forvirringsmatrisen, kan du se at de fleste tilfellene er klassifisert som ekte negative. Tenk deg nå, modellen klassifiserte alle klassene som negative (dvs. lavere enn 50k). Du vil ha en nøyaktighet på 75 prosent (6718/6718+2257). Modellen din presterer bedre, men sliter med å skille det sanne positive med det sanne negative.

I en slik situasjon er det å foretrekke å ha en mer kortfattet beregning. Vi kan se på:

  • Presisjon=TP/(TP+FP)
  • Recall=TP/(TP+FN)

Presisjon vs tilbakekalling

Precision ser på nøyaktigheten til den positive prediksjonen. Husker er forholdet mellom positive forekomster som er korrekt oppdaget av klassifisereren;

Du kan konstruere to funksjoner for å beregne disse to beregningene

  1. Konstruer presisjon
precision <- function(matrix) {
	# True positive
    tp <- matrix[2, 2]
	# false positive
    fp <- matrix[1, 2]
    return (tp / (tp + fp))
}

Kode Forklaring

  • mat[1,1]: Returner den første cellen i den første kolonnen i datarammen, dvs. den sanne positive
  • matte[1,2]; Returner den første cellen i den andre kolonnen i datarammen, dvs. den falske positive
recall <- function(matrix) {
# true positive
    tp <- matrix[2, 2]# false positive
    fn <- matrix[2, 1]
    return (tp / (tp + fn))
}

Kode Forklaring

  • mat[1,1]: Returner den første cellen i den første kolonnen i datarammen, dvs. den sanne positive
  • matte[2,1]; Returner den andre cellen i den første kolonnen i datarammen, dvs. den falske negative

Du kan teste funksjonene dine

prec <- precision(table_mat)
prec
rec <- recall(table_mat)
rec

Utgang:

## [1] 0.712877
## [2] 0.5336518

Når modellen sier at det er et individ over 50k, er det riktig i bare 54 prosent av tilfellene, og kan kreve individer over 50k i 72 prosent av tilfellet.

Du kan opprette Presisjon vs tilbakekalling score basert på presisjon og tilbakekalling. De Presisjon vs tilbakekalling er et harmonisk gjennomsnitt av disse to beregningene, noe som betyr at det gir større vekt til de lavere verdiene.

Presisjon vs tilbakekalling

f1 <- 2 * ((prec * rec) / (prec + rec))
f1

Utgang:

## [1] 0.6103799

Avveining mellom presisjon og tilbakekalling

Det er umulig å ha både høy presisjon og høy tilbakekalling.

Hvis vi øker presisjonen, vil det riktige individet bli bedre forutsagt, men vi ville savnet mange av dem (lavere tilbakekalling). I noen situasjoner foretrekker vi høyere presisjon enn tilbakekalling. Det er et konkavt forhold mellom presisjon og gjenkalling.

  • Tenk deg, du må forutsi om en pasient har en sykdom. Du ønsker å være så presis som mulig.
  • Hvis du trenger å oppdage potensielle bedragerske personer på gaten ved hjelp av ansiktsgjenkjenning, ville det være bedre å fange mange personer stemplet som uredelige selv om presisjonen er lav. Politiet vil kunne løslate den ikke-bedrageriske personen.

ROC-kurven

Ocuco Mottaker Operating Karakteristisk kurve er et annet vanlig verktøy som brukes med binær klassifisering. Den er veldig lik presisjon/gjenkallingskurven, men i stedet for å plotte presisjon versus gjenkalling, viser ROC-kurven den sanne positive raten (dvs. gjenkalling) mot den falske positive raten. Den falske positive frekvensen er forholdet mellom negative forekomster som feilaktig klassifiseres som positive. Det er lik én minus den sanne negative kursen. Den sanne negative satsen kalles også spesifisitet. Derfor plotter ROC-kurven følsomhet (gjenkalling) versus 1-spesifisitet

For å plotte ROC-kurven må vi installere et bibliotek kalt RORC. Vi kan finne i conda bibliotek. Du kan skrive inn koden:

conda installer -cr r-rocr -ja

Vi kan plotte ROC med funksjonene prediksjon() og ytelse().

library(ROCR)
ROCRpred <- prediction(predict, data_test$income)
ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')
plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))

Kode Forklaring

  • prediksjon(prediksjon, data_test$inntekt): ROCR-biblioteket må lage et prediksjonsobjekt for å transformere inndataene
  • ytelse(ROCRpred, 'tpr','fpr'): Returner de to kombinasjonene som skal produseres i grafen. Her er tpr og fpr konstruert. For å plotte presisjon og huske sammen, bruk "prec", "rec".

Utgang:

ROC-kurven

Trinn 8) Forbedre modellen

Du kan prøve å legge til ikke-linearitet til modellen med interaksjonen mellom

  • alder og timer.per.uke
  • kjønn og timer.per.uke.

Du må bruke poengtesten for å sammenligne begge modellene

formula_2 <- income~age: hours.per.week + gender: hours.per.week + .
logit_2 <- glm(formula_2, data = data_train, family = 'binomial')
predict_2 <- predict(logit_2, data_test, type = 'response')
table_mat_2 <- table(data_test$income, predict_2 > 0.5)
precision_2 <- precision(table_mat_2)
recall_2 <- recall(table_mat_2)
f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))
f1_2

Utgang:

## [1] 0.6109181

Poengsummen er litt høyere enn den forrige. Du kan fortsette å jobbe med dataene og prøve å slå poengsummen.

Sammendrag

Vi kan oppsummere funksjonen for å trene en logistisk regresjon i tabellen nedenfor:

Pakke Målet Funksjon Argument
- Lag tog/testdatasett create_train_set() data, størrelse, tog
glm Tren en generalisert lineær modell glm() formel, data, familie*
glm Oppsummer modellen sammendrag() montert modell
basen Gjør spådommer spå() tilpasset modell, datasett, type = 'respons'
basen Lag en forvirringsmatrise bord() y, forutsi()
basen Lag nøyaktighetsscore sum(diag(tabell())/sum(tabell()
ROCR Lag ROC: Trinn 1 Lag prediksjon prediksjon() forutsi(), y
ROCR Lag ROC: Trinn 2 Opprett ytelse ytelse() prediksjon(), 'tpr', 'fpr'
ROCR Lag ROC: Trinn 3 Plott graf plott() ytelse()

Den andre GLM type modeller er:

– binomial: (link = “logit”)

– gaussisk: (lenke = "identitet")

– Gamma: (lenke = “invers”)

– inverse.gaussian: (link = “1/mu^2”)

– gift: (lenke = "logg")

– kvasi: (lenke = "identitet", varians = "konstant")

– kvasibinomial: (lenke = “logit”)

– kvasipoisson: (lenke = "logg")