GLM in R: gegeneraliseerd lineair model met voorbeeld

Wat is logistieke regressie?

Logistische regressie wordt gebruikt om een ​​klasse, dat wil zeggen een waarschijnlijkheid, te voorspellen. Logistieke regressie kan een binaire uitkomst nauwkeurig voorspellen.

Stel je voor dat je op basis van vele kenmerken wilt voorspellen of een lening wordt geweigerd/geaccepteerd. De logistische regressie heeft de vorm 0/1. y = 0 als een lening wordt afgewezen, y = 1 als deze wordt geaccepteerd.

Een logistisch regressiemodel verschilt op twee manieren van een lineair regressiemodel.

  • Allereerst accepteert de logistische regressie alleen dichotome (binaire) invoer als afhankelijke variabele (dwz een vector van 0 en 1).
  • In de tweede plaats wordt de uitkomst gemeten met behulp van de volgende waarschijnlijke koppelingsfunctie, genaamd sigmoïde vanwege zijn S-vorm.:

Logistische regressie

De uitvoer van de functie ligt altijd tussen 0 en 1. Controleer onderstaande afbeelding

Logistische regressie

De sigmoïdefunctie retourneert waarden van 0 tot 1. Voor de classificatietaak hebben we een discrete uitvoer van 0 of 1 nodig.

Om een ​​continue stroom om te zetten in discrete waarde, kunnen we een beslissingsgrens instellen op 0.5. Alle waarden boven deze drempel worden geclassificeerd als 1

Logistische regressie

Hoe u een gegeneraliseerd voeringmodel (GLM) kunt maken

Laten we de volwassen dataset om logistieke regressie te illustreren. De “volwassene” is een geweldige dataset voor de classificatietaak. Het doel is om te voorspellen of het jaarinkomen in dollars van een individu hoger zal zijn dan 50.000. De dataset bevat 46,033 observaties en tien kenmerken:

  • leeftijd: leeftijd van het individu. Numeriek
  • opleiding: opleidingsniveau van het individu. Factor.
  • burgerlijke staat: Maristatus van het individu. Factor dwz nooit getrouwd, getrouwd-burgerlijke echtgenoot, …
  • geslacht: geslacht van het individu. Factor, dwz mannelijk of vrouwelijk
  • inkomen: Target variabel. Inkomen boven of onder 50K. Factor dwz >50K, <=50K

onder anderen

library(dplyr)
data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")
glimpse(data_adult)

Output:

Observations: 48,842
Variables: 10
$ x               <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,...
$ age             <int> 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26...
$ workclass       <fctr> Private, Private, Local-gov, Private, ?, Private,...
$ education       <fctr> 11th, HS-grad, Assoc-acdm, Some-college, Some-col...
$ educational.num <int> 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,...
$ marital.status  <fctr> Never-married, Married-civ-spouse, Married-civ-sp...
$ race            <fctr> Black, White, White, Black, White, White, Black, ...
$ gender          <fctr> Male, Male, Male, Male, Female, Male, Male, Male,...
$ hours.per.week  <int> 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39...
$ income          <fctr> <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5...

Wij gaan als volgt te werk:

  • Stap 1: Controleer continue variabelen
  • Stap 2: Controleer de factorvariabelen
  • Stap 3: Functie-engineering
  • Stap 4: Samenvattende statistiek
  • Stap 5: Trein-/testset
  • Stap 6: Bouw het model
  • Stap 7: Beoordeel de prestaties van het model
  • stap 8: Verbeter het model

Jouw taak is om te voorspellen welke persoon een omzet zal hebben die hoger is dan 50.

In deze zelfstudie wordt elke stap gedetailleerd beschreven voor het uitvoeren van een analyse op een echte dataset.

Stap 1) Controleer continue variabelen

In de eerste stap kunt u de verdeling van de continue variabelen zien.

continuous <-select_if(data_adult, is.numeric)
summary(continuous)

Code Uitleg

  • continu <- select_if(data_adult, is.numeric): Gebruik de functie select_if() uit de dplyr-bibliotheek om alleen de numerieke kolommen te selecteren
  • summary(continuous): Druk de samenvattende statistiek af

Output:

##        X              age        educational.num hours.per.week 
##  Min.   :    1   Min.   :17.00   Min.   : 1.00   Min.   : 1.00  
##  1st Qu.:11509   1st Qu.:28.00   1st Qu.: 9.00   1st Qu.:40.00  
##  Median :23017   Median :37.00   Median :10.00   Median :40.00  
##  Mean   :23017   Mean   :38.56   Mean   :10.13   Mean   :40.95  
##  3rd Qu.:34525   3rd Qu.:47.00   3rd Qu.:13.00   3rd Qu.:45.00  
##  Max.   :46033   Max.   :90.00   Max.   :16.00   Max.   :99.00	

Uit de bovenstaande tabel blijkt dat de gegevens een totaal verschillende schaal hebben en dat het aantal uren per week grote uitschieters kent (kijk bijvoorbeeld naar het laatste kwartiel en de maximumwaarde).

U kunt dit op de volgende twee manieren aanpakken:

  • 1: Teken de verdeling van uren per week
  • 2: Standaardiseer de continue variabelen
  1. Teken de verdeling

Laten we de verdeling van uren per week eens nader bekijken

# Histogram with kernel density curve
library(ggplot2)
ggplot(continuous, aes(x = hours.per.week)) +
    geom_density(alpha = .2, fill = "#FF6666")

Output:

Controleer continue variabelen

De variabele heeft veel outliers en een niet goed gedefinieerde distributie. U kunt dit probleem gedeeltelijk aanpakken door de bovenste 0.01 procent van de uren per week te verwijderen.

Basissyntaxis van kwantiel:

quantile(variable, percentile)
arguments:
-variable:  Select the variable in the data frame to compute the percentile
-percentile:  Can be a single value between 0 and 1 or multiple value. If multiple, use this format:  `c(A,B,C, ...)
- `A`,`B`,`C` and `...` are all integer from 0 to 1.

We berekenen het bovenste 2 procent-percentiel

top_one_percent <- quantile(data_adult$hours.per.week, .99)
top_one_percent

Code Uitleg

  • quantile(data_adult$hours.per.week, .99): Bereken de waarde van 99 procent van de werktijd

Output:

## 99% 
##  80

98 procent van de bevolking werkt minder dan 80 uur per week.

U kunt de waarnemingen boven deze drempel laten vallen. Je gebruikt het filter uit de dplyr bibliotheek.

data_adult_drop <-data_adult %>%
filter(hours.per.week<top_one_percent)
dim(data_adult_drop)

Output:

## [1] 45537    10
  1. Standaardiseer de continue variabelen

U kunt elke kolom standaardiseren om de prestaties te verbeteren, omdat uw gegevens niet dezelfde schaal hebben. U kunt de functie mute_if uit de dplyr-bibliotheek gebruiken. De basissyntaxis is:

mutate_if(df, condition, funs(function))
arguments:
-`df`: Data frame used to compute the function
- `condition`: Statement used. Do not use parenthesis
- funs(function):  Return the function to apply. Do not use parenthesis for the function

U kunt de numerieke kolommen als volgt standaardiseren:

data_adult_rescale <- data_adult_drop % > %
	mutate_if(is.numeric, funs(as.numeric(scale(.))))
head(data_adult_rescale)

Code Uitleg

  • mute_if(is.numeric, funs(scale)): De voorwaarde is alleen een numerieke kolom en de functie is schaal

Output:

##           X         age        workclass    education educational.num
## 1 -1.732680 -1.02325949          Private         11th     -1.22106443
## 2 -1.732605 -0.03969284          Private      HS-grad     -0.43998868
## 3 -1.732530 -0.79628257        Local-gov   Assoc-acdm      0.73162494
## 4 -1.732455  0.41426100          Private Some-college     -0.04945081
## 5 -1.732379 -0.34232873          Private         10th     -1.61160231
## 6 -1.732304  1.85178149 Self-emp-not-inc  Prof-school      1.90323857
##       marital.status  race gender hours.per.week income
## 1      Never-married Black   Male    -0.03995944  <=50K
## 2 Married-civ-spouse White   Male     0.86863037  <=50K
## 3 Married-civ-spouse White   Male    -0.03995944   >50K
## 4 Married-civ-spouse Black   Male    -0.03995944   >50K
## 5      Never-married White   Male    -0.94854924  <=50K
## 6 Married-civ-spouse White   Male    -0.76683128   >50K

Stap 2) Controleer de factorvariabelen

Deze stap heeft twee doelstellingen:

  • Controleer het niveau in elke categorische kolom
  • Definieer nieuwe niveaus

We verdelen deze stap in drie delen:

  • Selecteer de categorische kolommen
  • Bewaar het staafdiagram van elke kolom in een lijst
  • Print de grafieken

We kunnen de factorkolommen selecteren met de onderstaande code:

# Select categorical column
factor <- data.frame(select_if(data_adult_rescale, is.factor))
	ncol(factor)

Code Uitleg

  • data.frame(select_if(data_adult, is.factor)): We slaan de factorkolommen op in factor in een dataframetype. De bibliotheek ggplot2 vereist een dataframe-object.

Output:

## [1] 6

De dataset bevat 6 categorische variabelen

De tweede stap is vaardiger. U wilt voor elke kolom in de dataframefactor een staafdiagram plotten. Het is handiger om het proces te automatiseren, vooral als er veel kolommen zijn.

library(ggplot2)
# Create graph for each column
graph <- lapply(names(factor),
    function(x) 
	ggplot(factor, aes(get(x))) +
		geom_bar() +
		theme(axis.text.x = element_text(angle = 90)))

Code Uitleg

  • lapply(): Gebruik de functie lapply() om een ​​functie door te geven in alle kolommen van de dataset. De uitvoer sla je op in een lijst
  • function(x): De functie wordt voor elke x verwerkt. Hier zijn x de kolommen
  • ggplot(factor, aes(get(x))) + geom_bar()+ thema(axis.text.x = element_text(angle = 90)): Maak een staafdiagram voor elk x-element. Let op: om x als kolom terug te geven, moet je het in de get() opnemen

De laatste stap is relatief eenvoudig. U wilt de 6 grafieken afdrukken.

# Print the graph
graph

Output:

## [[1]]

Controleer factorvariabelen

## ## [[2]]

Controleer factorvariabelen

## ## [[3]]

Controleer factorvariabelen

## ## [[4]]

Controleer factorvariabelen

## ## [[5]]

Controleer factorvariabelen

## ## [[6]]

Controleer factorvariabelen

Let op: Gebruik de knop Volgende om naar de volgende grafiek te navigeren

Controleer factorvariabelen

Stap 3) Functie-engineering

Herschikking onderwijs

Uit de bovenstaande grafiek kun je zien dat de variabele opleiding 16 niveaus heeft. Dit is aanzienlijk, en sommige niveaus hebben een relatief laag aantal waarnemingen. Als u de hoeveelheid informatie die u uit deze variabele kunt halen, wilt verbeteren, kunt u deze naar een hoger niveau herschikken. Je creëert namelijk grotere groepen met een vergelijkbaar opleidingsniveau. Zo zal een laag opleidingsniveau zich vertalen in uitval. Hogere onderwijsniveaus zullen worden gewijzigd in master.

Hier zijn de details:

Oud niveau Nieuw level
Peuter afvaller
10 Afvaller
11 Afvaller
12 Afvaller
1e-4e Afvaller
5th-6th Afvaller
7th-8th Afvaller
9 Afvaller
HS-Grad Hooggrad
Sommige-college Maatschappelijke verantwoordelijkheid
Assoc-acdm Maatschappelijke verantwoordelijkheid
Assoc-voc Maatschappelijke verantwoordelijkheid
Bachelors Bachelors
Masters Masters
Prof-school Masters
Doctoraat PhD
recast_data <- data_adult_rescale % > %
	select(-X) % > %
	mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",
    ifelse(education == "Bachelors", "Bachelors",
        ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))

Code Uitleg

  • We gebruiken het werkwoord muteren uit de dplyr-bibliotheek. We veranderen de waarden van onderwijs met de uitspraak ifelse

In onderstaande tabel maak je een samenvattende statistiek om te zien hoeveel jaar opleiding (z-waarde) er gemiddeld nodig is om de bachelor, master of PhD te bereiken.

recast_data % > %
	group_by(education) % > %
	summarize(average_educ_year = mean(educational.num),
		count = n()) % > %
	arrange(average_educ_year)

Output:

## # A tibble: 6 x 3
## education average_educ_year count			
##      <fctr>             <dbl> <int>
## 1   dropout       -1.76147258  5712
## 2  HighGrad       -0.43998868 14803
## 3 Community        0.09561361 13407
## 4 Bachelors        1.12216282  7720
## 5    Master        1.60337381  3338
## 6       PhD        2.29377644   557

Herschikking Marital-status

Het is ook mogelijk om lagere niveaus voor de burgerlijke staat te creëren. In de volgende code verandert u het niveau als volgt:

Oud niveau Nieuw level
Nooit getrouwd Niet getrouwd
Getrouwde echtgenoot-afwezig Niet getrouwd
Getrouwd-AF-echtgenoot Getrouwd
Getrouwd-civ-echtgenoot
Gescheiden Gescheiden
Gescheiden
weduwen Weduwe
# Change level marry
recast_data <- recast_data % > %
	mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))

U kunt het aantal individuen binnen elke groep controleren.

table(recast_data$marital.status)

Output:

## ##     Married Not_married   Separated       Widow
##       21165       15359        7727        1286

Stap 4) Samenvattende statistiek

Het is tijd om wat statistieken over onze doelvariabelen te bekijken. In de onderstaande grafiek tel je het percentage individuen dat meer dan 50 verdient, gegeven hun geslacht.

# Plot gender income
ggplot(recast_data, aes(x = gender, fill = income)) +
    geom_bar(position = "fill") +
    theme_classic()

Output:

Samenvattende Statistiek

Controleer vervolgens of de afkomst van het individu van invloed is op zijn verdiensten.

# Plot origin income
ggplot(recast_data, aes(x = race, fill = income)) +
    geom_bar(position = "fill") +
    theme_classic() +
    theme(axis.text.x = element_text(angle = 90))

Output:

Samenvattende Statistiek

Het aantal werkuren per geslacht.

# box plot gender working time
ggplot(recast_data, aes(x = gender, y = hours.per.week)) +
    geom_boxplot() +
    stat_summary(fun.y = mean,
        geom = "point",
        size = 3,
        color = "steelblue") +
    theme_classic()

Output:

Samenvattende Statistiek

De boxplot bevestigt dat de verdeling van werktijd past bij verschillende groepen. In de boxplot hebben beide geslachten geen homogene observaties.

Per opleidingstype kunt u de dichtheid van de wekelijkse arbeidstijd nagaan. De distributies hebben veel verschillende keuzes. Dit kan waarschijnlijk worden verklaard door het type contract in de VS.

# Plot distribution working time by education
ggplot(recast_data, aes(x = hours.per.week)) +
    geom_density(aes(color = education), alpha = 0.5) +
    theme_classic()

Code Uitleg

  • ggplot(recast_data, aes( x= hours.per.week)): Een dichtheidsplot vereist slechts één variabele
  • geom_density(aes(color = education), alpha =0.5): Het geometrische object om de dichtheid te regelen

Output:

Samenvattende Statistiek

Om uw gedachten te bevestigen, kunt u eenrichtingsverkeer uitvoeren ANOVA-test:

anova <- aov(hours.per.week~education, recast_data)
summary(anova)

Output:

##                Df Sum Sq Mean Sq F value Pr(>F)    
## education       5   1552  310.31   321.2 <2e-16 ***
## Residuals   45531  43984    0.97                   
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

De ANOVA-test bevestigt het verschil in gemiddelde tussen groepen.

Niet-lineariteit

Voordat u het model uitvoert, kunt u zien of het aantal gewerkte uren verband houdt met uw leeftijd.

library(ggplot2)
ggplot(recast_data, aes(x = age, y = hours.per.week)) +
    geom_point(aes(color = income),
        size = 0.5) +
    stat_smooth(method = 'lm',
        formula = y~poly(x, 2),
        se = TRUE,
        aes(color = income)) +
    theme_classic()

Code Uitleg

  • ggplot(recast_data, aes(x = leeftijd, y = uren.per.week)): Stel de esthetiek van de grafiek in
  • geom_point(aes(color=inkomen), size =0.5): Construeer de puntenplot
  • stat_smooth(): Voeg de trendlijn toe met de volgende argumenten:
    • method='lm': Teken de aangepaste waarde als de lineaire regressie
    • formule = y~poly(x,2): Pas een polynoomregressie toe
    • se = TRUE: voeg de standaardfout toe
    • aes(kleur=inkomen): Verdeel het model op basis van inkomen

Output:

Niet-lineariteit

Kortom, u kunt interactietermen in het model testen om het niet-lineariteitseffect tussen de wekelijkse werktijd en andere kenmerken op te sporen. Het is belangrijk om te detecteren onder welke omstandigheden de werktijd verschilt.

Correlatie

De volgende controle is het visualiseren van de correlatie tussen de variabelen. U converteert het factorniveautype naar numeriek, zodat u een heatmap kunt plotten met de correlatiecoëfficiënt die is berekend met de Spearman-methode.

library(GGally)
# Convert data to numeric
corr <- data.frame(lapply(recast_data, as.integer))
# Plot the graphggcorr(corr,
    method = c("pairwise", "spearman"),
    nbreaks = 6,
    hjust = 0.8,
    label = TRUE,
    label_size = 3,
    color = "grey50")

Code Uitleg

  • data.frame(lapply(recast_data,as.integer)): Converteer gegevens naar numeriek
  • ggcorr() plot de heatmap met de volgende argumenten:
    • methode: Methode om de correlatie te berekenen
    • nbreaks = 6: Aantal pauzes
    • hjust = 0.8: Controlepositie van de variabelenaam in de plot
    • label = TRUE: Voeg labels toe in het midden van de vensters
    • label_size = 3: Maatlabels
    • kleur = “grey50”): Kleur van het label

Output:

Correlatie

Stap 5) Trein-/testset

Elke onder toezicht machine learning taak vereisen om de gegevens te splitsen tussen een treinset en een testset. U kunt de “functie” die u in de andere begeleide leertutorials hebt gemaakt, gebruiken om een ​​trein-/testset te maken.

set.seed(1234)
create_train_test <- function(data, size = 0.8, train = TRUE) {
    n_row = nrow(data)
    total_row = size * n_row
    train_sample <- 1: total_row
    if (train == TRUE) {
        return (data[train_sample, ])
    } else {
        return (data[-train_sample, ])
    }
}
data_train <- create_train_test(recast_data, 0.8, train = TRUE)
data_test <- create_train_test(recast_data, 0.8, train = FALSE)
dim(data_train)

Output:

## [1] 36429     9
dim(data_test)

Output:

## [1] 9108    9

Stap 6) Bouw het model

Om te zien hoe het algoritme presteert, gebruikt u het glm()-pakket. De Gegeneraliseerd lineair model is een verzameling modellen. De basissyntaxis is:

glm(formula, data=data, family=linkfunction()
Argument:
- formula:  Equation used to fit the model- data: dataset used
- Family:     - binomial: (link = "logit")			
- gaussian: (link = "identity")			
- Gamma:    (link = "inverse")			
- inverse.gaussian: (link = "1/mu^2")			
- poisson:  (link = "log")			
- quasi:    (link = "identity", variance = "constant")			
- quasibinomial:    (link = "logit")			
- quasipoisson: (link = "log")	

U bent klaar om het logistieke model te schatten om het inkomensniveau op te splitsen over een reeks kenmerken.

formula <- income~.
logit <- glm(formula, data = data_train, family = 'binomial')
summary(logit)

Code Uitleg

  • formule <- inkomen ~ .: Creëer het model dat past
  • logit <- glm(formule, data = data_train, family = 'binomial'): Pas een logistiek model (family = 'binomial') aan met de data_train-gegevens.
  • summary(logit): Druk de samenvatting van het model af

Output:

## 
## Call:
## glm(formula = formula, family = "binomial", data = data_train)
## ## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.6456  -0.5858  -0.2609  -0.0651   3.1982  
## 
## Coefficients:
##                           Estimate Std. Error z value Pr(>|z|)    
## (Intercept)                0.07882    0.21726   0.363  0.71675    
## age                        0.41119    0.01857  22.146  < 2e-16 ***
## workclassLocal-gov        -0.64018    0.09396  -6.813 9.54e-12 ***
## workclassPrivate          -0.53542    0.07886  -6.789 1.13e-11 ***
## workclassSelf-emp-inc     -0.07733    0.10350  -0.747  0.45499    
## workclassSelf-emp-not-inc -1.09052    0.09140 -11.931  < 2e-16 ***
## workclassState-gov        -0.80562    0.10617  -7.588 3.25e-14 ***
## workclassWithout-pay      -1.09765    0.86787  -1.265  0.20596    
## educationCommunity        -0.44436    0.08267  -5.375 7.66e-08 ***
## educationHighGrad         -0.67613    0.11827  -5.717 1.08e-08 ***
## educationMaster            0.35651    0.06780   5.258 1.46e-07 ***
## educationPhD               0.46995    0.15772   2.980  0.00289 ** 
## educationdropout          -1.04974    0.21280  -4.933 8.10e-07 ***
## educational.num            0.56908    0.07063   8.057 7.84e-16 ***
## marital.statusNot_married -2.50346    0.05113 -48.966  < 2e-16 ***
## marital.statusSeparated   -2.16177    0.05425 -39.846  < 2e-16 ***
## marital.statusWidow       -2.22707    0.12522 -17.785  < 2e-16 ***
## raceAsian-Pac-Islander     0.08359    0.20344   0.411  0.68117    
## raceBlack                  0.07188    0.19330   0.372  0.71001    
## raceOther                  0.01370    0.27695   0.049  0.96054    
## raceWhite                  0.34830    0.18441   1.889  0.05894 .  
## genderMale                 0.08596    0.04289   2.004  0.04506 *  
## hours.per.week             0.41942    0.01748  23.998  < 2e-16 ***
## ---## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## ## (Dispersion parameter for binomial family taken to be 1)
## ##     Null deviance: 40601  on 36428  degrees of freedom
## Residual deviance: 27041  on 36406  degrees of freedom
## AIC: 27087
## 
## Number of Fisher Scoring iterations: 6

De samenvatting van ons model onthult interessante informatie. De prestaties van een logistische regressie worden geëvalueerd met specifieke belangrijke statistieken.

  • AIC (Akaike Information Criteria): Dit is het equivalent van R2 bij logistische regressie. Het meet de fit wanneer er een boete wordt toegepast op het aantal parameters. Kleiner AIC waarden geven aan dat het model dichter bij de waarheid ligt.
  • Nulafwijking: Past alleen op het model met het snijpunt. De vrijheidsgraad is n-1. We kunnen het interpreteren als een Chi-kwadraatwaarde (gepaste waarde die afwijkt van het testen van de werkelijke waardehypothese).
  • Residuele afwijking: Model met alle variabelen. Het wordt ook geïnterpreteerd als een Chi-kwadraathypothesetest.
  • Aantal Fisher Scoring-iteraties: aantal iteraties vóór convergentie.

De uitvoer van de glm()-functie wordt opgeslagen in een lijst. De onderstaande code toont alle items die beschikbaar zijn in de logitvariabele die we hebben geconstrueerd om de logistische regressie te evalueren.

# De lijst is erg lang, druk alleen de eerste drie elementen af

lapply(logit, class)[1:3]

Output:

## $coefficients
## [1] "numeric"
## 
## $residuals
## [1] "numeric"
## 
## $fitted.values
## [1] "numeric"

Elke waarde kan worden geëxtraheerd met het $-teken, gevolgd door de naam van de statistieken. U hebt het model bijvoorbeeld opgeslagen als logit. Om de AIC-criteria te extraheren, gebruikt u:

logit$aic

Output:

## [1] 27086.65

Stap 7) Beoordeel de prestaties van het model

Verwarring Matrix

Ocuco's Medewerkers verwarring matrix is een betere keuze om de classificatieprestaties te evalueren in vergelijking met de verschillende statistieken die u eerder zag. Het algemene idee is om het aantal keren te tellen dat True-instanties zijn geclassificeerd als False.

Verwarring Matrix

Om de verwarringsmatrix te berekenen, heb je eerst een reeks voorspellingen nodig, zodat deze kunnen worden vergeleken met de werkelijke doelstellingen.

predict <- predict(logit, data_test, type = 'response')
# confusion matrix
table_mat <- table(data_test$income, predict > 0.5)
table_mat

Code Uitleg

  • voorspellen(logit,data_test, type = 'response'): Bereken de voorspelling op de testset. Stel type = 'respons' in om de responskans te berekenen.
  • table(data_test$income, voorspellen > 0.5): Bereken de verwarringsmatrix. voorspellen > 0.5 betekent dat het 1 retourneert als de voorspelde kansen hoger zijn dan 0.5, anders 0.

Output:

##        
##         FALSE TRUE
##   <=50K  6310  495
##   >50K   1074 1229	

Elke rij in een verwarringsmatrix vertegenwoordigt een feitelijk doel, terwijl elke kolom een ​​voorspeld doel vertegenwoordigt. De eerste rij van deze matrix beschouwt het inkomen lager dan 50 (de False-klasse): 6241 werden correct geclassificeerd als individuen met een inkomen lager dan 50 (Echt negatief), terwijl de overige ten onrechte werd geclassificeerd als boven de 50k (Vals positief). De tweede rij beschouwt het inkomen boven de 50, de positieve klasse was 1229 (Echt positief), Terwijl de Echt negatief was 1074.

U kunt het model berekenen nauwkeurigheid door het werkelijk positieve + het ware negatieve op te tellen over de totale waarneming

Verwarring Matrix

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
accuracy_Test

Code Uitleg

  • sum(diag(table_mat)): Som van de diagonaal
  • sum(table_mat): Som van de matrix.

Output:

## [1] 0.8277339

Het model lijkt met één probleem te kampen: het overschat het aantal fout-negatieven. Dit heet de nauwkeurigheidstestparadox. We stelden dat de nauwkeurigheid de verhouding is tussen correcte voorspellingen en het totale aantal gevallen. We kunnen een relatief hoge nauwkeurigheid hebben, maar een nutteloos model. Het gebeurt als er een dominante klasse is. Als je terugkijkt op de verwarringsmatrix, zie je dat de meeste gevallen als echt negatief zijn geclassificeerd. Stel je nu voor dat het model alle klassen als negatief classificeerde (dwz lager dan 50k). U zou een nauwkeurigheid van 75 procent hebben (6718/6718+2257). Uw model presteert beter, maar heeft moeite om het echte positieve van het echte negatieve te onderscheiden.

In dergelijke situaties verdient het de voorkeur om over een beknoptere metriek te beschikken. We kunnen kijken naar:

  • Precisie=TP/(TP+FP)
  • Terugroepen=TP/(TP+FN)

Precisie versus terugroepen

precisie kijkt naar de nauwkeurigheid van de positieve voorspelling. Terugroepen is de verhouding van positieve exemplaren die correct worden gedetecteerd door de classificator;

U kunt twee functies construeren om deze twee statistieken te berekenen

  1. Precisie construeren
precision <- function(matrix) {
	# True positive
    tp <- matrix[2, 2]
	# false positive
    fp <- matrix[1, 2]
    return (tp / (tp + fp))
}

Code Uitleg

  • mat[1,1]: Geeft de eerste cel van de eerste kolom van het dataframe terug, dwz het echte positieve
  • mat[1,2]; Retourneert de eerste cel van de tweede kolom van het dataframe, dat wil zeggen de fout-positieve cel
recall <- function(matrix) {
# true positive
    tp <- matrix[2, 2]# false positive
    fn <- matrix[2, 1]
    return (tp / (tp + fn))
}

Code Uitleg

  • mat[1,1]: Geeft de eerste cel van de eerste kolom van het dataframe terug, dwz het echte positieve
  • mat[2,1]; Retourneer de tweede cel van de eerste kolom van het dataframe, dwz het vals-negatieve getal

U kunt uw functies testen

prec <- precision(table_mat)
prec
rec <- recall(table_mat)
rec

Output:

## [1] 0.712877
## [2] 0.5336518

Als het model zegt dat het om een ​​persoon boven de 50 gaat, is dat in slechts 54 procent van de gevallen juist, en kan het in 50 procent van de gevallen individuen boven de 72 claimen.

U kunt de Precisie versus terugroepen score gebaseerd op de precisie en herinnering. De Precisie versus terugroepen is een harmonisch gemiddelde van deze twee metrieken, wat betekent dat het meer gewicht toekent aan de lagere waarden.

Precisie versus terugroepen

f1 <- 2 * ((prec * rec) / (prec + rec))
f1

Output:

## [1] 0.6103799

Afweging tussen precisie en terugroepen

Het is onmogelijk om zowel een hoge precisie als een hoge herinnering te hebben.

Als we de precisie verhogen, zal het juiste individu beter worden voorspeld, maar we zouden er veel van missen (lagere herinnering). In sommige situaties geven we de voorkeur aan een hogere precisie dan terugroepen. Er is een concave relatie tussen precisie en herinnering.

  • Stel je voor: je moet voorspellen of een patiënt een ziekte heeft. Je wilt zo precies mogelijk zijn.
  • Als je potentiële frauduleuze mensen op straat moet opsporen via gezichtsherkenning, kun je beter veel mensen betrappen die als frauduleus worden bestempeld, ook al is de nauwkeurigheid laag. De politie zal de niet-frauduleuze persoon kunnen vrijlaten.

De ROC-curve

Ocuco's Medewerkers Ontvanger Operakenmerkend curve is een ander veelgebruikt hulpmiddel bij binaire classificatie. Het lijkt erg op de precisie/herinneringscurve, maar in plaats van precisie versus herinnering weer te geven, toont de ROC-curve het werkelijk positieve percentage (dat wil zeggen, herinnering) tegenover het fout-positieve percentage. Het percentage fout-positieve resultaten is het aantal negatieve gevallen dat ten onrechte als positief is geclassificeerd. Het is gelijk aan één minus het werkelijk negatieve tarief. Het werkelijk negatieve tarief wordt ook wel genoemd specificiteit. Vandaar de ROC-curvegrafieken gevoeligheid (herinnering) versus 1-specificiteit

Om de ROC-curve uit te zetten, moeten we een bibliotheek genaamd RORC installeren. We kunnen het vinden in de conda bibliotheek. U kunt de code typen:

conda install -cr r-rocr –ja

We kunnen de ROC uitzetten met de functies forecast() en performance().

library(ROCR)
ROCRpred <- prediction(predict, data_test$income)
ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')
plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))

Code Uitleg

  • voorspelling(predict, data_test$income): De ROCR-bibliotheek moet een voorspellingsobject maken om de invoergegevens te transformeren
  • performance(ROCRpred, 'tpr','fpr'): Retourneert de twee combinaties die in de grafiek moeten worden geproduceerd. Hier worden tpr en fpr geconstrueerd. Om precisie en herinnering samen te plotten, gebruikt u “prec”, “rec”.

Output:

De ROC-curve

Stap 8) Verbeter het model

Je kunt proberen om niet-lineariteit aan het model toe te voegen met de interactie ertussen

  • leeftijd en uren per week
  • geslacht en uren per week.

U moet de scoretest gebruiken om beide modellen te vergelijken

formula_2 <- income~age: hours.per.week + gender: hours.per.week + .
logit_2 <- glm(formula_2, data = data_train, family = 'binomial')
predict_2 <- predict(logit_2, data_test, type = 'response')
table_mat_2 <- table(data_test$income, predict_2 > 0.5)
precision_2 <- precision(table_mat_2)
recall_2 <- recall(table_mat_2)
f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))
f1_2

Output:

## [1] 0.6109181

De score is iets hoger dan de vorige. U kunt aan de gegevens blijven werken en proberen de score te verbeteren.

Samenvatting

We kunnen de functie voor het trainen van een logistieke regressie samenvatten in de onderstaande tabel:

Pakket Objectief Functie Argument
- Maak een trein-/testgegevensset create_train_set() gegevens, grootte, trein
glm Train een gegeneraliseerd lineair model glm() formule, gegevens, familie*
glm Vat het model samen samenvatting() getailleerd model
baseren Maak voorspelling voorspellen() passend model, dataset, type = 'reactie'
baseren Maak een verwarringsmatrix tafel() y, voorspellen()
baseren Nauwkeurigheidsscore maken som(diag(tabel())/som(tabel()
ROCR ROC aanmaken: Stap 1 Maak een voorspelling voorspelling() voorspellen(), y
ROCR Creëer ROC: Stap 2 Creëer prestatie prestatie() voorspelling(), 'tpr', 'fpr'
ROCR ROC maken: Stap 3 Grafiek tekenen verhaallijn() prestatie()

Andere GLM soort modellen zijn:

– binomiaal: (link = “logit”)

– gaussiaans: (link = “identiteit”)

– Gamma: (link = “invers”)

– inverse.gaussisch: (link = “1/mu^2”)

– poisson: (link = “logboek”)

– quasi: (link = “identiteit”, variantie = “constant”)

– quasibinomiaal: (link = “logit”)

– quasipoisson: (link = “logboek”)