GLM in R: gegeneraliseerd lineair model met voorbeeld
Wat is logistieke regressie?
Logistische regressie wordt gebruikt om een klasse, dat wil zeggen een waarschijnlijkheid, te voorspellen. Logistieke regressie kan een binaire uitkomst nauwkeurig voorspellen.
Stel je voor dat je op basis van vele kenmerken wilt voorspellen of een lening wordt geweigerd/geaccepteerd. De logistische regressie heeft de vorm 0/1. y = 0 als een lening wordt afgewezen, y = 1 als deze wordt geaccepteerd.
Een logistisch regressiemodel verschilt op twee manieren van een lineair regressiemodel.
- Allereerst accepteert de logistische regressie alleen dichotome (binaire) invoer als afhankelijke variabele (dwz een vector van 0 en 1).
- In de tweede plaats wordt de uitkomst gemeten met behulp van de volgende waarschijnlijke koppelingsfunctie, genaamd sigmoïde vanwege zijn S-vorm.:
De uitvoer van de functie ligt altijd tussen 0 en 1. Controleer onderstaande afbeelding
De sigmoïdefunctie retourneert waarden van 0 tot 1. Voor de classificatietaak hebben we een discrete uitvoer van 0 of 1 nodig.
Om een continue stroom om te zetten in discrete waarde, kunnen we een beslissingsgrens instellen op 0.5. Alle waarden boven deze drempel worden geclassificeerd als 1
Hoe u een gegeneraliseerd voeringmodel (GLM) kunt maken
Laten we de volwassen dataset om logistieke regressie te illustreren. De “volwassene” is een geweldige dataset voor de classificatietaak. Het doel is om te voorspellen of het jaarinkomen in dollars van een individu hoger zal zijn dan 50.000. De dataset bevat 46,033 observaties en tien kenmerken:
- leeftijd: leeftijd van het individu. Numeriek
- opleiding: opleidingsniveau van het individu. Factor.
- burgerlijke staat: Maristatus van het individu. Factor dwz nooit getrouwd, getrouwd-burgerlijke echtgenoot, …
- geslacht: geslacht van het individu. Factor, dwz mannelijk of vrouwelijk
- inkomen: Target variabel. Inkomen boven of onder 50K. Factor dwz >50K, <=50K
onder anderen
library(dplyr) data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv") glimpse(data_adult)
Output:
Observations: 48,842 Variables: 10 $ x <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,... $ age <int> 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26... $ workclass <fctr> Private, Private, Local-gov, Private, ?, Private,... $ education <fctr> 11th, HS-grad, Assoc-acdm, Some-college, Some-col... $ educational.num <int> 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,... $ marital.status <fctr> Never-married, Married-civ-spouse, Married-civ-sp... $ race <fctr> Black, White, White, Black, White, White, Black, ... $ gender <fctr> Male, Male, Male, Male, Female, Male, Male, Male,... $ hours.per.week <int> 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39... $ income <fctr> <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5...
Wij gaan als volgt te werk:
- Stap 1: Controleer continue variabelen
- Stap 2: Controleer de factorvariabelen
- Stap 3: Functie-engineering
- Stap 4: Samenvattende statistiek
- Stap 5: Trein-/testset
- Stap 6: Bouw het model
- Stap 7: Beoordeel de prestaties van het model
- stap 8: Verbeter het model
Jouw taak is om te voorspellen welke persoon een omzet zal hebben die hoger is dan 50.
In deze zelfstudie wordt elke stap gedetailleerd beschreven voor het uitvoeren van een analyse op een echte dataset.
Stap 1) Controleer continue variabelen
In de eerste stap kunt u de verdeling van de continue variabelen zien.
continuous <-select_if(data_adult, is.numeric) summary(continuous)
Code Uitleg
- continu <- select_if(data_adult, is.numeric): Gebruik de functie select_if() uit de dplyr-bibliotheek om alleen de numerieke kolommen te selecteren
- summary(continuous): Druk de samenvattende statistiek af
Output:
## X age educational.num hours.per.week ## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00 ## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00 ## Median :23017 Median :37.00 Median :10.00 Median :40.00 ## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95 ## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00 ## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00
Uit de bovenstaande tabel blijkt dat de gegevens een totaal verschillende schaal hebben en dat het aantal uren per week grote uitschieters kent (kijk bijvoorbeeld naar het laatste kwartiel en de maximumwaarde).
U kunt dit op de volgende twee manieren aanpakken:
- 1: Teken de verdeling van uren per week
- 2: Standaardiseer de continue variabelen
- Teken de verdeling
Laten we de verdeling van uren per week eens nader bekijken
# Histogram with kernel density curve library(ggplot2) ggplot(continuous, aes(x = hours.per.week)) + geom_density(alpha = .2, fill = "#FF6666")
Output:
De variabele heeft veel outliers en een niet goed gedefinieerde distributie. U kunt dit probleem gedeeltelijk aanpakken door de bovenste 0.01 procent van de uren per week te verwijderen.
Basissyntaxis van kwantiel:
quantile(variable, percentile) arguments: -variable: Select the variable in the data frame to compute the percentile -percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C, ...) - `A`,`B`,`C` and `...` are all integer from 0 to 1.
We berekenen het bovenste 2 procent-percentiel
top_one_percent <- quantile(data_adult$hours.per.week, .99) top_one_percent
Code Uitleg
- quantile(data_adult$hours.per.week, .99): Bereken de waarde van 99 procent van de werktijd
Output:
## 99% ## 80
98 procent van de bevolking werkt minder dan 80 uur per week.
U kunt de waarnemingen boven deze drempel laten vallen. Je gebruikt het filter uit de dplyr bibliotheek.
data_adult_drop <-data_adult %>% filter(hours.per.week<top_one_percent) dim(data_adult_drop)
Output:
## [1] 45537 10
- Standaardiseer de continue variabelen
U kunt elke kolom standaardiseren om de prestaties te verbeteren, omdat uw gegevens niet dezelfde schaal hebben. U kunt de functie mute_if uit de dplyr-bibliotheek gebruiken. De basissyntaxis is:
mutate_if(df, condition, funs(function)) arguments: -`df`: Data frame used to compute the function - `condition`: Statement used. Do not use parenthesis - funs(function): Return the function to apply. Do not use parenthesis for the function
U kunt de numerieke kolommen als volgt standaardiseren:
data_adult_rescale <- data_adult_drop % > % mutate_if(is.numeric, funs(as.numeric(scale(.)))) head(data_adult_rescale)
Code Uitleg
- mute_if(is.numeric, funs(scale)): De voorwaarde is alleen een numerieke kolom en de functie is schaal
Output:
## X age workclass education educational.num ## 1 -1.732680 -1.02325949 Private 11th -1.22106443 ## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868 ## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494 ## 4 -1.732455 0.41426100 Private Some-college -0.04945081 ## 5 -1.732379 -0.34232873 Private 10th -1.61160231 ## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857 ## marital.status race gender hours.per.week income ## 1 Never-married Black Male -0.03995944 <=50K ## 2 Married-civ-spouse White Male 0.86863037 <=50K ## 3 Married-civ-spouse White Male -0.03995944 >50K ## 4 Married-civ-spouse Black Male -0.03995944 >50K ## 5 Never-married White Male -0.94854924 <=50K ## 6 Married-civ-spouse White Male -0.76683128 >50K
Stap 2) Controleer de factorvariabelen
Deze stap heeft twee doelstellingen:
- Controleer het niveau in elke categorische kolom
- Definieer nieuwe niveaus
We verdelen deze stap in drie delen:
- Selecteer de categorische kolommen
- Bewaar het staafdiagram van elke kolom in een lijst
- Print de grafieken
We kunnen de factorkolommen selecteren met de onderstaande code:
# Select categorical column factor <- data.frame(select_if(data_adult_rescale, is.factor)) ncol(factor)
Code Uitleg
- data.frame(select_if(data_adult, is.factor)): We slaan de factorkolommen op in factor in een dataframetype. De bibliotheek ggplot2 vereist een dataframe-object.
Output:
## [1] 6
De dataset bevat 6 categorische variabelen
De tweede stap is vaardiger. U wilt voor elke kolom in de dataframefactor een staafdiagram plotten. Het is handiger om het proces te automatiseren, vooral als er veel kolommen zijn.
library(ggplot2) # Create graph for each column graph <- lapply(names(factor), function(x) ggplot(factor, aes(get(x))) + geom_bar() + theme(axis.text.x = element_text(angle = 90)))
Code Uitleg
- lapply(): Gebruik de functie lapply() om een functie door te geven in alle kolommen van de dataset. De uitvoer sla je op in een lijst
- function(x): De functie wordt voor elke x verwerkt. Hier zijn x de kolommen
- ggplot(factor, aes(get(x))) + geom_bar()+ thema(axis.text.x = element_text(angle = 90)): Maak een staafdiagram voor elk x-element. Let op: om x als kolom terug te geven, moet je het in de get() opnemen
De laatste stap is relatief eenvoudig. U wilt de 6 grafieken afdrukken.
# Print the graph graph
Output:
## [[1]]
## ## [[2]]
## ## [[3]]
## ## [[4]]
## ## [[5]]
## ## [[6]]
Let op: Gebruik de knop Volgende om naar de volgende grafiek te navigeren
Stap 3) Functie-engineering
Herschikking onderwijs
Uit de bovenstaande grafiek kun je zien dat de variabele opleiding 16 niveaus heeft. Dit is aanzienlijk, en sommige niveaus hebben een relatief laag aantal waarnemingen. Als u de hoeveelheid informatie die u uit deze variabele kunt halen, wilt verbeteren, kunt u deze naar een hoger niveau herschikken. Je creëert namelijk grotere groepen met een vergelijkbaar opleidingsniveau. Zo zal een laag opleidingsniveau zich vertalen in uitval. Hogere onderwijsniveaus zullen worden gewijzigd in master.
Hier zijn de details:
Oud niveau | Nieuw level |
---|---|
Peuter | afvaller |
10 | Afvaller |
11 | Afvaller |
12 | Afvaller |
1e-4e | Afvaller |
5th-6th | Afvaller |
7th-8th | Afvaller |
9 | Afvaller |
HS-Grad | Hooggrad |
Sommige-college | Maatschappelijke verantwoordelijkheid |
Assoc-acdm | Maatschappelijke verantwoordelijkheid |
Assoc-voc | Maatschappelijke verantwoordelijkheid |
Bachelors | Bachelors |
Masters | Masters |
Prof-school | Masters |
Doctoraat | PhD |
recast_data <- data_adult_rescale % > % select(-X) % > % mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community", ifelse(education == "Bachelors", "Bachelors", ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))
Code Uitleg
- We gebruiken het werkwoord muteren uit de dplyr-bibliotheek. We veranderen de waarden van onderwijs met de uitspraak ifelse
In onderstaande tabel maak je een samenvattende statistiek om te zien hoeveel jaar opleiding (z-waarde) er gemiddeld nodig is om de bachelor, master of PhD te bereiken.
recast_data % > % group_by(education) % > % summarize(average_educ_year = mean(educational.num), count = n()) % > % arrange(average_educ_year)
Output:
## # A tibble: 6 x 3 ## education average_educ_year count ## <fctr> <dbl> <int> ## 1 dropout -1.76147258 5712 ## 2 HighGrad -0.43998868 14803 ## 3 Community 0.09561361 13407 ## 4 Bachelors 1.12216282 7720 ## 5 Master 1.60337381 3338 ## 6 PhD 2.29377644 557
Herschikking Marital-status
Het is ook mogelijk om lagere niveaus voor de burgerlijke staat te creëren. In de volgende code verandert u het niveau als volgt:
Oud niveau | Nieuw level |
---|---|
Nooit getrouwd | Niet getrouwd |
Getrouwde echtgenoot-afwezig | Niet getrouwd |
Getrouwd-AF-echtgenoot | Getrouwd |
Getrouwd-civ-echtgenoot | |
Gescheiden | Gescheiden |
Gescheiden | |
weduwen | Weduwe |
# Change level marry recast_data <- recast_data % > % mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))
U kunt het aantal individuen binnen elke groep controleren.
table(recast_data$marital.status)
Output:
## ## Married Not_married Separated Widow ## 21165 15359 7727 1286
Stap 4) Samenvattende statistiek
Het is tijd om wat statistieken over onze doelvariabelen te bekijken. In de onderstaande grafiek tel je het percentage individuen dat meer dan 50 verdient, gegeven hun geslacht.
# Plot gender income ggplot(recast_data, aes(x = gender, fill = income)) + geom_bar(position = "fill") + theme_classic()
Output:
Controleer vervolgens of de afkomst van het individu van invloed is op zijn verdiensten.
# Plot origin income ggplot(recast_data, aes(x = race, fill = income)) + geom_bar(position = "fill") + theme_classic() + theme(axis.text.x = element_text(angle = 90))
Output:
Het aantal werkuren per geslacht.
# box plot gender working time ggplot(recast_data, aes(x = gender, y = hours.per.week)) + geom_boxplot() + stat_summary(fun.y = mean, geom = "point", size = 3, color = "steelblue") + theme_classic()
Output:
De boxplot bevestigt dat de verdeling van werktijd past bij verschillende groepen. In de boxplot hebben beide geslachten geen homogene observaties.
Per opleidingstype kunt u de dichtheid van de wekelijkse arbeidstijd nagaan. De distributies hebben veel verschillende keuzes. Dit kan waarschijnlijk worden verklaard door het type contract in de VS.
# Plot distribution working time by education ggplot(recast_data, aes(x = hours.per.week)) + geom_density(aes(color = education), alpha = 0.5) + theme_classic()
Code Uitleg
- ggplot(recast_data, aes( x= hours.per.week)): Een dichtheidsplot vereist slechts één variabele
- geom_density(aes(color = education), alpha =0.5): Het geometrische object om de dichtheid te regelen
Output:
Om uw gedachten te bevestigen, kunt u eenrichtingsverkeer uitvoeren ANOVA-test:
anova <- aov(hours.per.week~education, recast_data) summary(anova)
Output:
## Df Sum Sq Mean Sq F value Pr(>F) ## education 5 1552 310.31 321.2 <2e-16 *** ## Residuals 45531 43984 0.97 ## --- ## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
De ANOVA-test bevestigt het verschil in gemiddelde tussen groepen.
Niet-lineariteit
Voordat u het model uitvoert, kunt u zien of het aantal gewerkte uren verband houdt met uw leeftijd.
library(ggplot2) ggplot(recast_data, aes(x = age, y = hours.per.week)) + geom_point(aes(color = income), size = 0.5) + stat_smooth(method = 'lm', formula = y~poly(x, 2), se = TRUE, aes(color = income)) + theme_classic()
Code Uitleg
- ggplot(recast_data, aes(x = leeftijd, y = uren.per.week)): Stel de esthetiek van de grafiek in
- geom_point(aes(color=inkomen), size =0.5): Construeer de puntenplot
- stat_smooth(): Voeg de trendlijn toe met de volgende argumenten:
- method='lm': Teken de aangepaste waarde als de lineaire regressie
- formule = y~poly(x,2): Pas een polynoomregressie toe
- se = TRUE: voeg de standaardfout toe
- aes(kleur=inkomen): Verdeel het model op basis van inkomen
Output:
Kortom, u kunt interactietermen in het model testen om het niet-lineariteitseffect tussen de wekelijkse werktijd en andere kenmerken op te sporen. Het is belangrijk om te detecteren onder welke omstandigheden de werktijd verschilt.
Correlatie
De volgende controle is het visualiseren van de correlatie tussen de variabelen. U converteert het factorniveautype naar numeriek, zodat u een heatmap kunt plotten met de correlatiecoëfficiënt die is berekend met de Spearman-methode.
library(GGally) # Convert data to numeric corr <- data.frame(lapply(recast_data, as.integer)) # Plot the graphggcorr(corr, method = c("pairwise", "spearman"), nbreaks = 6, hjust = 0.8, label = TRUE, label_size = 3, color = "grey50")
Code Uitleg
- data.frame(lapply(recast_data,as.integer)): Converteer gegevens naar numeriek
- ggcorr() plot de heatmap met de volgende argumenten:
- methode: Methode om de correlatie te berekenen
- nbreaks = 6: Aantal pauzes
- hjust = 0.8: Controlepositie van de variabelenaam in de plot
- label = TRUE: Voeg labels toe in het midden van de vensters
- label_size = 3: Maatlabels
- kleur = “grey50”): Kleur van het label
Output:
Stap 5) Trein-/testset
Elke onder toezicht machine learning taak vereisen om de gegevens te splitsen tussen een treinset en een testset. U kunt de “functie” die u in de andere begeleide leertutorials hebt gemaakt, gebruiken om een trein-/testset te maken.
set.seed(1234) create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample <- 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } } data_train <- create_train_test(recast_data, 0.8, train = TRUE) data_test <- create_train_test(recast_data, 0.8, train = FALSE) dim(data_train)
Output:
## [1] 36429 9
dim(data_test)
Output:
## [1] 9108 9
Stap 6) Bouw het model
Om te zien hoe het algoritme presteert, gebruikt u het glm()-pakket. De Gegeneraliseerd lineair model is een verzameling modellen. De basissyntaxis is:
glm(formula, data=data, family=linkfunction() Argument: - formula: Equation used to fit the model- data: dataset used - Family: - binomial: (link = "logit") - gaussian: (link = "identity") - Gamma: (link = "inverse") - inverse.gaussian: (link = "1/mu^2") - poisson: (link = "log") - quasi: (link = "identity", variance = "constant") - quasibinomial: (link = "logit") - quasipoisson: (link = "log")
U bent klaar om het logistieke model te schatten om het inkomensniveau op te splitsen over een reeks kenmerken.
formula <- income~. logit <- glm(formula, data = data_train, family = 'binomial') summary(logit)
Code Uitleg
- formule <- inkomen ~ .: Creëer het model dat past
- logit <- glm(formule, data = data_train, family = 'binomial'): Pas een logistiek model (family = 'binomial') aan met de data_train-gegevens.
- summary(logit): Druk de samenvatting van het model af
Output:
## ## Call: ## glm(formula = formula, family = "binomial", data = data_train) ## ## Deviance Residuals: ## Min 1Q Median 3Q Max ## -2.6456 -0.5858 -0.2609 -0.0651 3.1982 ## ## Coefficients: ## Estimate Std. Error z value Pr(>|z|) ## (Intercept) 0.07882 0.21726 0.363 0.71675 ## age 0.41119 0.01857 22.146 < 2e-16 *** ## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 *** ## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 *** ## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499 ## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 *** ## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 *** ## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596 ## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 *** ## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 *** ## educationMaster 0.35651 0.06780 5.258 1.46e-07 *** ## educationPhD 0.46995 0.15772 2.980 0.00289 ** ## educationdropout -1.04974 0.21280 -4.933 8.10e-07 *** ## educational.num 0.56908 0.07063 8.057 7.84e-16 *** ## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 *** ## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 *** ## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 *** ## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117 ## raceBlack 0.07188 0.19330 0.372 0.71001 ## raceOther 0.01370 0.27695 0.049 0.96054 ## raceWhite 0.34830 0.18441 1.889 0.05894 . ## genderMale 0.08596 0.04289 2.004 0.04506 * ## hours.per.week 0.41942 0.01748 23.998 < 2e-16 *** ## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 ## ## (Dispersion parameter for binomial family taken to be 1) ## ## Null deviance: 40601 on 36428 degrees of freedom ## Residual deviance: 27041 on 36406 degrees of freedom ## AIC: 27087 ## ## Number of Fisher Scoring iterations: 6
De samenvatting van ons model onthult interessante informatie. De prestaties van een logistische regressie worden geëvalueerd met specifieke belangrijke statistieken.
- AIC (Akaike Information Criteria): Dit is het equivalent van R2 bij logistische regressie. Het meet de fit wanneer er een boete wordt toegepast op het aantal parameters. Kleiner AIC waarden geven aan dat het model dichter bij de waarheid ligt.
- Nulafwijking: Past alleen op het model met het snijpunt. De vrijheidsgraad is n-1. We kunnen het interpreteren als een Chi-kwadraatwaarde (gepaste waarde die afwijkt van het testen van de werkelijke waardehypothese).
- Residuele afwijking: Model met alle variabelen. Het wordt ook geïnterpreteerd als een Chi-kwadraathypothesetest.
- Aantal Fisher Scoring-iteraties: aantal iteraties vóór convergentie.
De uitvoer van de glm()-functie wordt opgeslagen in een lijst. De onderstaande code toont alle items die beschikbaar zijn in de logitvariabele die we hebben geconstrueerd om de logistische regressie te evalueren.
# De lijst is erg lang, druk alleen de eerste drie elementen af
lapply(logit, class)[1:3]
Output:
## $coefficients ## [1] "numeric" ## ## $residuals ## [1] "numeric" ## ## $fitted.values ## [1] "numeric"
Elke waarde kan worden geëxtraheerd met het $-teken, gevolgd door de naam van de statistieken. U hebt het model bijvoorbeeld opgeslagen als logit. Om de AIC-criteria te extraheren, gebruikt u:
logit$aic
Output:
## [1] 27086.65
Stap 7) Beoordeel de prestaties van het model
Verwarring Matrix
Ocuco's Medewerkers verwarring matrix is een betere keuze om de classificatieprestaties te evalueren in vergelijking met de verschillende statistieken die u eerder zag. Het algemene idee is om het aantal keren te tellen dat True-instanties zijn geclassificeerd als False.
Om de verwarringsmatrix te berekenen, heb je eerst een reeks voorspellingen nodig, zodat deze kunnen worden vergeleken met de werkelijke doelstellingen.
predict <- predict(logit, data_test, type = 'response') # confusion matrix table_mat <- table(data_test$income, predict > 0.5) table_mat
Code Uitleg
- voorspellen(logit,data_test, type = 'response'): Bereken de voorspelling op de testset. Stel type = 'respons' in om de responskans te berekenen.
- table(data_test$income, voorspellen > 0.5): Bereken de verwarringsmatrix. voorspellen > 0.5 betekent dat het 1 retourneert als de voorspelde kansen hoger zijn dan 0.5, anders 0.
Output:
## ## FALSE TRUE ## <=50K 6310 495 ## >50K 1074 1229
Elke rij in een verwarringsmatrix vertegenwoordigt een feitelijk doel, terwijl elke kolom een voorspeld doel vertegenwoordigt. De eerste rij van deze matrix beschouwt het inkomen lager dan 50 (de False-klasse): 6241 werden correct geclassificeerd als individuen met een inkomen lager dan 50 (Echt negatief), terwijl de overige ten onrechte werd geclassificeerd als boven de 50k (Vals positief). De tweede rij beschouwt het inkomen boven de 50, de positieve klasse was 1229 (Echt positief), Terwijl de Echt negatief was 1074.
U kunt het model berekenen nauwkeurigheid door het werkelijk positieve + het ware negatieve op te tellen over de totale waarneming
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test
Code Uitleg
- sum(diag(table_mat)): Som van de diagonaal
- sum(table_mat): Som van de matrix.
Output:
## [1] 0.8277339
Het model lijkt met één probleem te kampen: het overschat het aantal fout-negatieven. Dit heet de nauwkeurigheidstestparadox. We stelden dat de nauwkeurigheid de verhouding is tussen correcte voorspellingen en het totale aantal gevallen. We kunnen een relatief hoge nauwkeurigheid hebben, maar een nutteloos model. Het gebeurt als er een dominante klasse is. Als je terugkijkt op de verwarringsmatrix, zie je dat de meeste gevallen als echt negatief zijn geclassificeerd. Stel je nu voor dat het model alle klassen als negatief classificeerde (dwz lager dan 50k). U zou een nauwkeurigheid van 75 procent hebben (6718/6718+2257). Uw model presteert beter, maar heeft moeite om het echte positieve van het echte negatieve te onderscheiden.
In dergelijke situaties verdient het de voorkeur om over een beknoptere metriek te beschikken. We kunnen kijken naar:
- Precisie=TP/(TP+FP)
- Terugroepen=TP/(TP+FN)
Precisie versus terugroepen
precisie kijkt naar de nauwkeurigheid van de positieve voorspelling. Terugroepen is de verhouding van positieve exemplaren die correct worden gedetecteerd door de classificator;
U kunt twee functies construeren om deze twee statistieken te berekenen
- Precisie construeren
precision <- function(matrix) { # True positive tp <- matrix[2, 2] # false positive fp <- matrix[1, 2] return (tp / (tp + fp)) }
Code Uitleg
- mat[1,1]: Geeft de eerste cel van de eerste kolom van het dataframe terug, dwz het echte positieve
- mat[1,2]; Retourneert de eerste cel van de tweede kolom van het dataframe, dat wil zeggen de fout-positieve cel
recall <- function(matrix) { # true positive tp <- matrix[2, 2]# false positive fn <- matrix[2, 1] return (tp / (tp + fn)) }
Code Uitleg
- mat[1,1]: Geeft de eerste cel van de eerste kolom van het dataframe terug, dwz het echte positieve
- mat[2,1]; Retourneer de tweede cel van de eerste kolom van het dataframe, dwz het vals-negatieve getal
U kunt uw functies testen
prec <- precision(table_mat) prec rec <- recall(table_mat) rec
Output:
## [1] 0.712877 ## [2] 0.5336518
Als het model zegt dat het om een persoon boven de 50 gaat, is dat in slechts 54 procent van de gevallen juist, en kan het in 50 procent van de gevallen individuen boven de 72 claimen.
U kunt de score gebaseerd op de precisie en herinnering. De
is een harmonisch gemiddelde van deze twee metrieken, wat betekent dat het meer gewicht toekent aan de lagere waarden.
f1 <- 2 * ((prec * rec) / (prec + rec)) f1
Output:
## [1] 0.6103799
Afweging tussen precisie en terugroepen
Het is onmogelijk om zowel een hoge precisie als een hoge herinnering te hebben.
Als we de precisie verhogen, zal het juiste individu beter worden voorspeld, maar we zouden er veel van missen (lagere herinnering). In sommige situaties geven we de voorkeur aan een hogere precisie dan terugroepen. Er is een concave relatie tussen precisie en herinnering.
- Stel je voor: je moet voorspellen of een patiënt een ziekte heeft. Je wilt zo precies mogelijk zijn.
- Als je potentiële frauduleuze mensen op straat moet opsporen via gezichtsherkenning, kun je beter veel mensen betrappen die als frauduleus worden bestempeld, ook al is de nauwkeurigheid laag. De politie zal de niet-frauduleuze persoon kunnen vrijlaten.
De ROC-curve
Ocuco's Medewerkers Ontvanger Operakenmerkend curve is een ander veelgebruikt hulpmiddel bij binaire classificatie. Het lijkt erg op de precisie/herinneringscurve, maar in plaats van precisie versus herinnering weer te geven, toont de ROC-curve het werkelijk positieve percentage (dat wil zeggen, herinnering) tegenover het fout-positieve percentage. Het percentage fout-positieve resultaten is het aantal negatieve gevallen dat ten onrechte als positief is geclassificeerd. Het is gelijk aan één minus het werkelijk negatieve tarief. Het werkelijk negatieve tarief wordt ook wel genoemd specificiteit. Vandaar de ROC-curvegrafieken gevoeligheid (herinnering) versus 1-specificiteit
Om de ROC-curve uit te zetten, moeten we een bibliotheek genaamd RORC installeren. We kunnen het vinden in de conda bibliotheek. U kunt de code typen:
conda install -cr r-rocr –ja
We kunnen de ROC uitzetten met de functies forecast() en performance().
library(ROCR) ROCRpred <- prediction(predict, data_test$income) ROCRperf <- performance(ROCRpred, 'tpr', 'fpr') plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))
Code Uitleg
- voorspelling(predict, data_test$income): De ROCR-bibliotheek moet een voorspellingsobject maken om de invoergegevens te transformeren
- performance(ROCRpred, 'tpr','fpr'): Retourneert de twee combinaties die in de grafiek moeten worden geproduceerd. Hier worden tpr en fpr geconstrueerd. Om precisie en herinnering samen te plotten, gebruikt u “prec”, “rec”.
Output:
Stap 8) Verbeter het model
Je kunt proberen om niet-lineariteit aan het model toe te voegen met de interactie ertussen
- leeftijd en uren per week
- geslacht en uren per week.
U moet de scoretest gebruiken om beide modellen te vergelijken
formula_2 <- income~age: hours.per.week + gender: hours.per.week + . logit_2 <- glm(formula_2, data = data_train, family = 'binomial') predict_2 <- predict(logit_2, data_test, type = 'response') table_mat_2 <- table(data_test$income, predict_2 > 0.5) precision_2 <- precision(table_mat_2) recall_2 <- recall(table_mat_2) f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2)) f1_2
Output:
## [1] 0.6109181
De score is iets hoger dan de vorige. U kunt aan de gegevens blijven werken en proberen de score te verbeteren.
Samenvatting
We kunnen de functie voor het trainen van een logistieke regressie samenvatten in de onderstaande tabel:
Pakket | Objectief | Functie | Argument |
---|---|---|---|
- | Maak een trein-/testgegevensset | create_train_set() | gegevens, grootte, trein |
glm | Train een gegeneraliseerd lineair model | glm() | formule, gegevens, familie* |
glm | Vat het model samen | samenvatting() | getailleerd model |
baseren | Maak voorspelling | voorspellen() | passend model, dataset, type = 'reactie' |
baseren | Maak een verwarringsmatrix | tafel() | y, voorspellen() |
baseren | Nauwkeurigheidsscore maken | som(diag(tabel())/som(tabel() | |
ROCR | ROC aanmaken: Stap 1 Maak een voorspelling | voorspelling() | voorspellen(), y |
ROCR | Creëer ROC: Stap 2 Creëer prestatie | prestatie() | voorspelling(), 'tpr', 'fpr' |
ROCR | ROC maken: Stap 3 Grafiek tekenen | verhaallijn() | prestatie() |
Andere GLM soort modellen zijn:
– binomiaal: (link = “logit”)
– gaussiaans: (link = “identiteit”)
– Gamma: (link = “invers”)
– inverse.gaussisch: (link = “1/mu^2”)
– poisson: (link = “logboek”)
– quasi: (link = “identiteit”, variantie = “constant”)
– quasibinomiaal: (link = “logit”)
– quasipoisson: (link = “logboek”)