GLM v R: Zobecněný lineární model s příkladem
Co je logistická regrese?
Logistická regrese se používá k predikci třídy, tj. pravděpodobnosti. Logistická regrese dokáže přesně předpovědět binární výsledek.
Představte si, že chcete předpovědět, zda bude půjčka zamítnuta/přijata na základě mnoha atributů. Logistická regrese má tvar 0/1. y = 0, pokud je půjčka zamítnuta, y = 1, pokud je přijata.
Logistický regresní model se liší od lineárního regresního modelu dvěma způsoby.
- Za prvé, logistická regrese akceptuje pouze dichotomický (binární) vstup jako závisle proměnnou (tj. vektor 0 a 1).
- Za druhé, výsledek je měřen pomocí následující funkce pravděpodobnostního odkazu sigmatu díky svému tvaru S.:
Výstup funkce je vždy mezi 0 a 1. Zkontrolujte obrázek níže
Funkce sigmoid vrací hodnoty od 0 do 1. Pro klasifikační úlohu potřebujeme diskrétní výstup 0 nebo 1.
Pro převod spojitého toku na diskrétní hodnotu můžeme nastavit rozhodovací hranici na 0.5. Všechny hodnoty nad touto hranicí jsou klasifikovány jako 1
Jak vytvořit Generalized Liner Model (GLM)
Pojďme použít dospělý soubor dat pro ilustraci logistické regrese. „Dospělý“ je skvělá datová sada pro klasifikační úkol. Cílem je předpovědět, zda roční příjem jednotlivce v dolarech překročí 50.000 46,033. Soubor dat obsahuje XNUMX XNUMX pozorování a deset funkcí:
- věk: věk jedince. Numerický
- vzdělání: Vzdělanostní úroveň jedince. Faktor.
- rodinný.stav: Maristav jednotlivce. Faktor, tj. nikdy nevdaná, vdaná-občanská manželka, …
- pohlaví: Pohlaví jednotlivce. Faktor, tedy Muž nebo Žena
- příjem: Target variabilní. Příjem nad nebo pod 50 tis. Faktor, tj. >50K, <=50K
mimo jiné
library(dplyr) data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv") glimpse(data_adult)
Výstup:
Observations: 48,842 Variables: 10 $ x <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,... $ age <int> 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26... $ workclass <fctr> Private, Private, Local-gov, Private, ?, Private,... $ education <fctr> 11th, HS-grad, Assoc-acdm, Some-college, Some-col... $ educational.num <int> 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,... $ marital.status <fctr> Never-married, Married-civ-spouse, Married-civ-sp... $ race <fctr> Black, White, White, Black, White, White, Black, ... $ gender <fctr> Male, Male, Male, Male, Female, Male, Male, Male,... $ hours.per.week <int> 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39... $ income <fctr> <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5...
Budeme postupovat následovně:
- Krok 1: Zkontrolujte spojité proměnné
- Krok 2: Zkontrolujte proměnné faktoru
- Krok 3: Funkce
- Krok 4: Souhrnná statistika
- Krok 5: Trénink/testovací sada
- Krok 6: Sestavte model
- Krok 7: Posuďte výkon modelu
- Krok 8: Vylepšete model
Vaším úkolem je předpovědět, která osoba bude mít tržby vyšší než 50 tisíc.
V tomto tutoriálu bude každý krok podrobně popsán k provedení analýzy skutečné datové sady.
Krok 1) Zkontrolujte spojité proměnné
V prvním kroku můžete vidět rozložení spojitých proměnných.
continuous <-select_if(data_adult, is.numeric) summary(continuous)
Vysvětlení kódu
- kontinuální <- select_if(data_adult, is.numeric): Pomocí funkce select_if() z knihovny dplyr vyberte pouze číselné sloupce
- souhrn (nepřetržitý): Tisk souhrnné statistiky
Výstup:
## X age educational.num hours.per.week ## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00 ## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00 ## Median :23017 Median :37.00 Median :10.00 Median :40.00 ## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95 ## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00 ## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00
Z výše uvedené tabulky můžete vidět, že data mají úplně jiná měřítka a hodiny za týden mají velké odlehlé hodnoty (např. podívejte se na poslední kvartil a maximální hodnotu).
Můžete si s tím poradit ve dvou krocích:
- 1: Znázorněte rozložení hodin.za.týden
- 2: Standardizujte spojité proměnné
- Zakreslete distribuci
Podívejme se blíže na rozložení hodin.za.týden
# Histogram with kernel density curve library(ggplot2) ggplot(continuous, aes(x = hours.per.week)) + geom_density(alpha = .2, fill = "#FF6666")
Výstup:
Proměnná má mnoho odlehlých hodnot a není dobře definovaná distribuce. Tento problém můžete částečně vyřešit odstraněním horních 0.01 procenta hodin týdně.
Základní syntaxe kvantilu:
quantile(variable, percentile) arguments: -variable: Select the variable in the data frame to compute the percentile -percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C, ...) - `A`,`B`,`C` and `...` are all integer from 0 to 1.
Vypočítáme horní 2 procenta percentilu
top_one_percent <- quantile(data_adult$hours.per.week, .99) top_one_percent
Vysvětlení kódu
- quantile(data_adult$hours.per.week, 99): Vypočítejte hodnotu 99 procent pracovní doby
Výstup:
## 99% ## 80
98 procent populace pracuje pod 80 hodin týdně.
Pozorování můžete pustit nad tuto hranici. Používáte filtr z dplyr knihovna.
data_adult_drop <-data_adult %>% filter(hours.per.week<top_one_percent) dim(data_adult_drop)
Výstup:
## [1] 45537 10
- Standardizujte spojité proměnné
Každý sloupec můžete standardizovat, abyste zlepšili výkon, protože vaše data nemají stejné měřítko. Můžete použít funkci mutate_if z knihovny dplyr. Základní syntaxe je:
mutate_if(df, condition, funs(function)) arguments: -`df`: Data frame used to compute the function - `condition`: Statement used. Do not use parenthesis - funs(function): Return the function to apply. Do not use parenthesis for the function
Číselné sloupce můžete standardizovat následovně:
data_adult_rescale <- data_adult_drop % > % mutate_if(is.numeric, funs(as.numeric(scale(.)))) head(data_adult_rescale)
Vysvětlení kódu
- mutate_if(is.numeric, funs(scale)): Podmínka je pouze číselný sloupec a funkce je měřítko
Výstup:
## X age workclass education educational.num ## 1 -1.732680 -1.02325949 Private 11th -1.22106443 ## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868 ## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494 ## 4 -1.732455 0.41426100 Private Some-college -0.04945081 ## 5 -1.732379 -0.34232873 Private 10th -1.61160231 ## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857 ## marital.status race gender hours.per.week income ## 1 Never-married Black Male -0.03995944 <=50K ## 2 Married-civ-spouse White Male 0.86863037 <=50K ## 3 Married-civ-spouse White Male -0.03995944 >50K ## 4 Married-civ-spouse Black Male -0.03995944 >50K ## 5 Never-married White Male -0.94854924 <=50K ## 6 Married-civ-spouse White Male -0.76683128 >50K
Krok 2) Zkontrolujte proměnné faktoru
Tento krok má dva cíle:
- Zkontrolujte úroveň v každém kategorickém sloupci
- Definujte nové úrovně
Tento krok rozdělíme na tři části:
- Vyberte kategorické sloupce
- Uložte sloupcový graf každého sloupce do seznamu
- Vytiskněte grafy
Sloupce faktoru můžeme vybrat pomocí kódu níže:
# Select categorical column factor <- data.frame(select_if(data_adult_rescale, is.factor)) ncol(factor)
Vysvětlení kódu
- data.frame(select_if(data_adult, is.factor)): Sloupce faktoru ukládáme faktorem v typu datového rámce. Knihovna ggplot2 vyžaduje objekt datového rámce.
Výstup:
## [1] 6
Dataset obsahuje 6 kategoriálních proměnných
Druhý krok je zručnější. Chcete vykreslit pruhový graf pro každý sloupec ve faktoru datového rámce. Je pohodlnější proces automatizovat, zvláště v situaci, kdy je mnoho sloupců.
library(ggplot2) # Create graph for each column graph <- lapply(names(factor), function(x) ggplot(factor, aes(get(x))) + geom_bar() + theme(axis.text.x = element_text(angle = 90)))
Vysvětlení kódu
- lapply(): Použijte funkci lapply() k předání funkce ve všech sloupcích datové sady. Výstup uložíte do seznamu
- function(x): Funkce bude zpracována pro každé x. Zde x jsou sloupce
- ggplot(factor, aes(get(x))) + geom_bar()+ theme(axis.text.x = element_text(angle = 90)): Vytvořte pruhový znakový graf pro každý x prvek. Všimněte si, že chcete-li vrátit x jako sloupec, musíte jej zahrnout do get()
Poslední krok je poměrně snadný. Chcete vytisknout 6 grafů.
# Print the graph graph
Výstup:
## [[1]]
## ## [[2]]
## ## [[3]]
## ## [[4]]
## ## [[5]]
## ## [[6]]
Poznámka: Pomocí tlačítka Další přejděte na další graf
Krok 3) Inženýrství funkcí
Přepracování vzdělávání
Z výše uvedeného grafu vidíte, že proměnná vzdělání má 16 úrovní. To je podstatné a některé úrovně mají relativně nízký počet pozorování. Pokud chcete zlepšit množství informací, které můžete z této proměnné získat, můžete ji přetvořit na vyšší úroveň. Totiž vytváříte větší skupiny s podobnou úrovní vzdělání. Nízká úroveň vzdělání bude například převedena na předčasný odchod. Vyšší stupně vzdělání se změní na mistrovské.
Zde je detail:
Stará úroveň | Nová úroveň |
---|---|
Předškolní | výpadek |
10. | Výpadek |
11. | Výpadek |
12. | Výpadek |
1.-4 | Výpadek |
5th-6th | Výpadek |
7th-8th | Výpadek |
9. | Výpadek |
HS-Grad | HighGrad |
Některé vysoké školy | Naše projekty |
Assoc-acdm | Naše projekty |
Assoc-voc | Naše projekty |
Bachelors | Bachelors |
Masters | Masters |
Prof-škola | Masters |
Doktorát | PhD |
recast_data <- data_adult_rescale % > % select(-X) % > % mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community", ifelse(education == "Bachelors", "Bachelors", ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))
Vysvětlení kódu
- Používáme sloveso mutate z knihovny dplyr. Hodnoty vzdělání měníme výrokem ifelse
V níže uvedené tabulce vytvoříte souhrnnou statistiku, abyste v průměru viděli, kolik let vzdělání (hodnota z) je potřeba k dosažení bakalářského, magisterského nebo doktorského studia.
recast_data % > % group_by(education) % > % summarize(average_educ_year = mean(educational.num), count = n()) % > % arrange(average_educ_year)
Výstup:
## # A tibble: 6 x 3 ## education average_educ_year count ## <fctr> <dbl> <int> ## 1 dropout -1.76147258 5712 ## 2 HighGrad -0.43998868 14803 ## 3 Community 0.09561361 13407 ## 4 Bachelors 1.12216282 7720 ## 5 Master 1.60337381 3338 ## 6 PhD 2.29377644 557
Přepracovat Marital-stav
Je také možné vytvořit nižší úrovně pro rodinný stav. V následujícím kódu změníte úroveň následovně:
Stará úroveň | Nová úroveň |
---|---|
Se nikdy neoženil | Není vdaná |
Ženatý-manžel-nepřítomný | Není vdaná |
Ženatý-AF-manžel | Ženatý |
Ženatý občan | |
Oddělil | Oddělil |
Rozvedený | |
Vdovy | Vdova |
# Change level marry recast_data <- recast_data % > % mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))
Můžete zkontrolovat počet jednotlivců v každé skupině.
table(recast_data$marital.status)
Výstup:
## ## Married Not_married Separated Widow ## 21165 15359 7727 1286
Krok 4) Souhrnná statistika
Je čas zkontrolovat některé statistiky o našich cílových proměnných. V níže uvedeném grafu spočítáte procento jednotlivců, kteří vydělávají více než 50 tisíc vzhledem k jejich pohlaví.
# Plot gender income ggplot(recast_data, aes(x = gender, fill = income)) + geom_bar(position = "fill") + theme_classic()
Výstup:
Dále zkontrolujte, zda původ jednotlivce ovlivňuje jeho výdělky.
# Plot origin income ggplot(recast_data, aes(x = race, fill = income)) + geom_bar(position = "fill") + theme_classic() + theme(axis.text.x = element_text(angle = 90))
Výstup:
Počet hodin práce podle pohlaví.
# box plot gender working time ggplot(recast_data, aes(x = gender, y = hours.per.week)) + geom_boxplot() + stat_summary(fun.y = mean, geom = "point", size = 3, color = "steelblue") + theme_classic()
Výstup:
Krabicový graf potvrzuje, že rozložení pracovní doby vyhovuje různým skupinám. V krabicovém grafu nemají obě pohlaví homogenní pozorování.
Hustotu týdenní pracovní doby si můžete ověřit podle typu vzdělání. Distribuce mají mnoho odlišných výběrů. Pravděpodobně se to dá vysvětlit typem smlouvy v USA.
# Plot distribution working time by education ggplot(recast_data, aes(x = hours.per.week)) + geom_density(aes(color = education), alpha = 0.5) + theme_classic()
Vysvětlení kódu
- ggplot(recast_data, aes(x= hours.per.week)): Graf hustoty vyžaduje pouze jednu proměnnou
- geom_density(aes(barva = vzdělání), alfa =0.5): Geometrický objekt pro řízení hustoty
Výstup:
Chcete-li potvrdit své myšlenky, můžete provést jednosměrný postup Test ANOVA:
anova <- aov(hours.per.week~education, recast_data) summary(anova)
Výstup:
## Df Sum Sq Mean Sq F value Pr(>F) ## education 5 1552 310.31 321.2 <2e-16 *** ## Residuals 45531 43984 0.97 ## --- ## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Test ANOVA potvrzuje rozdíl v průměru mezi skupinami.
Nelinearita
Před spuštěním modelu můžete zjistit, zda počet odpracovaných hodin souvisí s věkem.
library(ggplot2) ggplot(recast_data, aes(x = age, y = hours.per.week)) + geom_point(aes(color = income), size = 0.5) + stat_smooth(method = 'lm', formula = y~poly(x, 2), se = TRUE, aes(color = income)) + theme_classic()
Vysvětlení kódu
- ggplot(recast_data, aes(x = věk, y = hodiny.za.týden)): Nastavení estetiky grafu
- geom_point(aes(barva= příjem), velikost =0.5): Vytvořte bodový graf
- stat_smooth(): Přidejte trendovou linii s následujícími argumenty:
- method='lm': Vykreslete přizpůsobenou hodnotu, pokud je lineární regrese
- vzorec = y~poly(x,2): Fit polynomiální regrese
- se = TRUE: Přidá standardní chybu
- aes(barva= příjem): Rozdělte model podle příjmu
Výstup:
Stručně řečeno, můžete otestovat podmínky interakce v modelu, abyste zjistili nelineární efekt mezi týdenní pracovní dobou a dalšími funkcemi. Je důležité zjistit, za jakých podmínek se liší pracovní doba.
Korelace
Další kontrolou je vizualizace korelace mezi proměnnými. Typ úrovně faktoru převedete na numerický, abyste mohli vykreslit tepelnou mapu obsahující koeficient korelace vypočítaný Spearmanovou metodou.
library(GGally) # Convert data to numeric corr <- data.frame(lapply(recast_data, as.integer)) # Plot the graphggcorr(corr, method = c("pairwise", "spearman"), nbreaks = 6, hjust = 0.8, label = TRUE, label_size = 3, color = "grey50")
Vysvětlení kódu
- data.frame(lapply(recast_data,as.integer)): Převede data na číselná
- ggcorr() vykresluje tepelnou mapu s následujícími argumenty:
- metoda: Metoda pro výpočet korelace
- nbreaks = 6: Počet přerušení
- hjust = 0.8: Kontrolní pozice názvu proměnné v grafu
- label = TRUE: Přidat štítky do středu oken
- label_size = 3: Velikost štítků
- barva = “šedá50”): Barva štítku
Výstup:
Krok 5) Trénink/testovací sada
Jakýkoli pod dohledem strojové učení úkol vyžaduje rozdělit data mezi vlakovou soupravu a testovací soupravu. K vytvoření vlakové/testovací sady můžete použít „funkci“, kterou jste vytvořili v jiných výukových programech pod dohledem.
set.seed(1234) create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample <- 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } } data_train <- create_train_test(recast_data, 0.8, train = TRUE) data_test <- create_train_test(recast_data, 0.8, train = FALSE) dim(data_train)
Výstup:
## [1] 36429 9
dim(data_test)
Výstup:
## [1] 9108 9
Krok 6) Sestavte model
Chcete-li zjistit, jak algoritmus funguje, použijte balíček glm(). The Zobecněný lineární model je sbírka modelů. Základní syntaxe je:
glm(formula, data=data, family=linkfunction() Argument: - formula: Equation used to fit the model- data: dataset used - Family: - binomial: (link = "logit") - gaussian: (link = "identity") - Gamma: (link = "inverse") - inverse.gaussian: (link = "1/mu^2") - poisson: (link = "log") - quasi: (link = "identity", variance = "constant") - quasibinomial: (link = "logit") - quasipoisson: (link = "log")
Jste připraveni odhadnout logistický model pro rozdělení úrovně příjmu mezi sadu funkcí.
formula <- income~. logit <- glm(formula, data = data_train, family = 'binomial') summary(logit)
Vysvětlení kódu
- vzorec <- příjem ~ .: Vytvořte model tak, aby vyhovoval
- logit <- glm(vzorec, data = data_train, rodina = 'binomický'): Proložte logistický model (rodina = 'binomický') pomocí dat data_train.
- Summary(logit): Vytiskne souhrn modelu
Výstup:
## ## Call: ## glm(formula = formula, family = "binomial", data = data_train) ## ## Deviance Residuals: ## Min 1Q Median 3Q Max ## -2.6456 -0.5858 -0.2609 -0.0651 3.1982 ## ## Coefficients: ## Estimate Std. Error z value Pr(>|z|) ## (Intercept) 0.07882 0.21726 0.363 0.71675 ## age 0.41119 0.01857 22.146 < 2e-16 *** ## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 *** ## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 *** ## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499 ## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 *** ## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 *** ## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596 ## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 *** ## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 *** ## educationMaster 0.35651 0.06780 5.258 1.46e-07 *** ## educationPhD 0.46995 0.15772 2.980 0.00289 ** ## educationdropout -1.04974 0.21280 -4.933 8.10e-07 *** ## educational.num 0.56908 0.07063 8.057 7.84e-16 *** ## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 *** ## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 *** ## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 *** ## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117 ## raceBlack 0.07188 0.19330 0.372 0.71001 ## raceOther 0.01370 0.27695 0.049 0.96054 ## raceWhite 0.34830 0.18441 1.889 0.05894 . ## genderMale 0.08596 0.04289 2.004 0.04506 * ## hours.per.week 0.41942 0.01748 23.998 < 2e-16 *** ## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 ## ## (Dispersion parameter for binomial family taken to be 1) ## ## Null deviance: 40601 on 36428 degrees of freedom ## Residual deviance: 27041 on 36406 degrees of freedom ## AIC: 27087 ## ## Number of Fisher Scoring iterations: 6
Shrnutí našeho modelu odhaluje zajímavé informace. Výkon logistické regrese se hodnotí pomocí specifických klíčových metrik.
- AIC (Akaike Information Criteria): Toto je ekvivalent R2 v logistické regresi. Měří přizpůsobení, když je na počet parametrů aplikována penalizace. Menší AIC hodnoty ukazují, že model je blíže pravdě.
- Nulová odchylka: Vyhovuje modelu pouze s průsečíkem. Stupeň volnosti je n-1. Můžeme ji interpretovat jako hodnotu Chí-kvadrát (proložená hodnota odlišná od testování hypotézy skutečné hodnoty).
- Reziduální odchylka: Model se všemi proměnnými. Je také interpretován jako testování hypotézy chí-kvadrát.
- Počet iterací Fisher Scoring: Počet iterací před konvergováním.
Výstup funkce glm() je uložen v seznamu. Níže uvedený kód zobrazuje všechny položky dostupné v proměnné logit, kterou jsme vytvořili pro vyhodnocení logistické regrese.
# Seznam je velmi dlouhý, vytiskněte pouze první tři prvky
lapply(logit, class)[1:3]
Výstup:
## $coefficients ## [1] "numeric" ## ## $residuals ## [1] "numeric" ## ## $fitted.values ## [1] "numeric"
Každou hodnotu lze extrahovat znakem $ následovaným názvem metrik. Například jste uložili model jako logit. Chcete-li extrahovat kritéria AIC, použijte:
logit$aic
Výstup:
## [1] 27086.65
Krok 7) Posuďte výkon modelu
Matice zmatků
Jedno matoucí matice je lepší volbou pro hodnocení výkonu klasifikace ve srovnání s různými metrikami, které jste viděli dříve. Obecnou myšlenkou je spočítat, kolikrát jsou skutečné instance klasifikovány jako nepravdivé.
Chcete-li vypočítat matici zmatků, musíte mít nejprve sadu předpovědí, aby je bylo možné porovnat se skutečnými cíli.
predict <- predict(logit, data_test, type = 'response') # confusion matrix table_mat <- table(data_test$income, predict > 0.5) table_mat
Vysvětlení kódu
- forecast(logit,data_test, type = 'response'): Vypočítejte předpověď na testovací sadě. Nastavte typ = 'response' pro výpočet pravděpodobnosti odpovědi.
- table(data_test$income, forecast > 0.5): Vypočtěte matici zmatků. predikovat > 0.5 znamená, že vrátí 1, pokud jsou předpokládané pravděpodobnosti vyšší než 0.5, jinak 0.
Výstup:
## ## FALSE TRUE ## <=50K 6310 495 ## >50K 1074 1229
Každý řádek v matečné matici představuje skutečný cíl, zatímco každý sloupec představuje předpokládaný cíl. První řádek této matice bere v úvahu příjem nižší než 50 6241 (třída False): 50 bylo správně klasifikováno jako jednotlivci s příjmem nižším než XNUMX XNUMX (Pravda negativní), zatímco zbývající byl chybně klasifikován jako nad 50 XNUMX (Falešně pozitivní). Druhý řádek bere v úvahu příjem nad 50 1229, kladná třída byla XNUMX (Pravda pozitivní), zatímco Pravda negativní byl 1074.
Můžete vypočítat model přesnost sečtením skutečných kladných + skutečných záporných hodnot z celkového pozorování
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test
Vysvětlení kódu
- sum(diag(table_mat)): Součet úhlopříčky
- sum(table_mat): Součet matice.
Výstup:
## [1] 0.8277339
Zdá se, že model trpí jedním problémem, nadhodnocuje počet falešně negativních výsledků. Toto se nazývá paradox testu přesnosti. Uvedli jsme, že přesnost je poměr správných předpovědí k celkovému počtu případů. Můžeme mít relativně vysokou přesnost, ale nepoužitelný model. Stává se to, když existuje dominantní třída. Pokud se podíváte zpět na matici zmatků, můžete vidět, že většina případů je klasifikována jako skutečně negativní. Představte si nyní, že model klasifikoval všechny třídy jako negativní (tj. nižší než 50k). Měli byste přesnost 75 procent (6718/6718+2257). Váš model funguje lépe, ale má potíže s rozlišením skutečného pozitivního od skutečného negativního.
V takové situaci je vhodnější mít stručnější metriku. Můžeme se podívat na:
- Přesnost=TP/(TP+FP)
- Vyvolat=TP/(TP+FN)
Přesnost vs
Přesnost dívá se na přesnost pozitivní předpovědi. Odvolání je poměr pozitivních instancí, které jsou správně detekovány klasifikátorem;
Pro výpočet těchto dvou metrik můžete sestavit dvě funkce
- Konstrukční přesnost
precision <- function(matrix) { # True positive tp <- matrix[2, 2] # false positive fp <- matrix[1, 2] return (tp / (tp + fp)) }
Vysvětlení kódu
- mat[1,1]: Vrátí první buňku prvního sloupce datového rámce, tj. true positive
- mat[1,2]; Vraťte první buňku druhého sloupce datového rámce, tj. falešně pozitivní
recall <- function(matrix) { # true positive tp <- matrix[2, 2]# false positive fn <- matrix[2, 1] return (tp / (tp + fn)) }
Vysvětlení kódu
- mat[1,1]: Vrátí první buňku prvního sloupce datového rámce, tj. true positive
- mat[2,1]; Vraťte druhou buňku prvního sloupce datového rámce, tj. falešně negativní
Můžete otestovat své funkce
prec <- precision(table_mat) prec rec <- recall(table_mat) rec
Výstup:
## [1] 0.712877 ## [2] 0.5336518
Když model říká, že jde o osobu nad 50 tisíc, je to správné pouze v 54 procentech případů a v 50 procentech případů může tvrdit osoby nad 72 tisíc.
Můžete vytvořit skóre na základě přesnosti a zapamatovatelnosti. The
je harmonický průměr těchto dvou metrik, což znamená, že dává větší váhu nižším hodnotám.
f1 <- 2 * ((prec * rec) / (prec + rec)) f1
Výstup:
## [1] 0.6103799
Přesnost vs Recall kompromis
Je nemožné mít současně vysokou přesnost a vysokou vybavitelnost.
Pokud zvýšíme přesnost, bude se lépe předpovídat správný jedinec, ale mnoho z nich bychom vynechali (nižší zapamatovatelnost). V některých situacích dáváme přednost vyšší přesnosti než vyvolání. Mezi přesností a vyvoláním je konkávní vztah.
- Představte si, že potřebujete předpovědět, zda má pacient nějakou nemoc. Chcete být co nejpřesnější.
- Pokud potřebujete odhalit potenciální podvodné lidi na ulici pomocí rozpoznávání obličeje, bylo by lepší chytit mnoho lidí označených jako podvodníci, i když je přesnost nízká. Policie bude moci nepodvedenou osobu propustit.
ROC křivka
Jedno Přijímač Operating Charakteristika křivka je dalším běžným nástrojem používaným s binární klasifikací. Je velmi podobná křivce přesnosti/vybavení, ale místo vynesení přesnosti versus vybavování ukazuje křivka ROC skutečnou pozitivní četnost (tj. vybavování) proti četnosti falešně pozitivních. Míra falešně pozitivních výsledků je poměr negativních instancí, které jsou nesprávně klasifikovány jako pozitivní. Rovná se jedné mínus skutečná záporná sazba. Skutečná záporná sazba se také nazývá specifičnost. Proto se vykresluje ROC křivka citlivost (odvolání) versus 1-specifičnost
Pro vykreslení ROC křivky musíme nainstalovat knihovnu s názvem RORC. Můžeme najít v conda knihovna. Můžete zadat kód:
conda install -cr r-rocr –ano
ROC můžeme vykreslit pomocí funkcí forecast() a performance().
library(ROCR) ROCRpred <- prediction(predict, data_test$income) ROCRperf <- performance(ROCRpred, 'tpr', 'fpr') plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))
Vysvětlení kódu
- predikce(predikce, data_test$income): Knihovna ROCR potřebuje vytvořit objekt predikce pro transformaci vstupních dat
- performance(ROCRpred, 'tpr', 'fpr'): Vraťte dvě kombinace, které se mají vytvořit v grafu. Zde jsou konstruovány tpr a fpr. Pro přesnost vykreslování a vyvolání dohromady použijte „prec“, „rec“.
Výstup:
Krok 8) Vylepšit model
Můžete zkusit přidat do modelu nelinearitu s interakcí mezi nimi
- věk a hodiny.týdně
- pohlaví a hodiny.týdně.
K porovnání obou modelů musíte použít test skóre
formula_2 <- income~age: hours.per.week + gender: hours.per.week + . logit_2 <- glm(formula_2, data = data_train, family = 'binomial') predict_2 <- predict(logit_2, data_test, type = 'response') table_mat_2 <- table(data_test$income, predict_2 > 0.5) precision_2 <- precision(table_mat_2) recall_2 <- recall(table_mat_2) f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2)) f1_2
Výstup:
## [1] 0.6109181
Skóre je o něco vyšší než předchozí. Můžete pokračovat v práci na datech a snažit se překonat skóre.
Shrnutí
Funkci pro trénování logistické regrese můžeme shrnout v tabulce níže:
Balíček | Objektivní | funkce | Argument |
---|---|---|---|
- | Vytvořte vlakovou/testovací datovou sadu | create_train_set() | údaje, velikost, vlak |
glm | Trénujte zobecněný lineární model | glm() | vzorec, data, rodina* |
glm | Shrňte model | souhrn() | namontovaný model |
základna | Proveďte předpověď | předpovědět() | přizpůsobený model, datová sada, typ = 'odpověď' |
základna | Vytvořte matici zmatků | stůl() | y, předpovědět() |
základna | Vytvořte skóre přesnosti | součet(diag(tabulka())/součet(tabulka() | |
ROCR | Vytvoření ROC: Krok 1 Vytvořte předpověď | předpověď() | předpovídat(), y |
ROCR | Vytvořte ROC: Krok 2 Vytvořte výkon | výkon() | predikce(), 'tpr', 'fpr' |
ROCR | Vytvořte ROC: Krok 3 Vykreslete graf | spiknutí() | výkon() |
Ostatní GLM typy modelů jsou:
– binomický: (odkaz = „logit“)
– gaussovský: (odkaz = „identita“)
– Gamma: (odkaz = „inverzní“)
– inverse.gaussian: (odkaz = „1/mu^2“)
– poisson: (odkaz = „log“)
– kvazi: (odkaz = „identita“, rozptyl = „konstanta“)
– kvazibinomický: (link = “logit”)
– kvazipoisson: (odkaz = „log“)