GLM w R: Uogólniony model liniowy z przykładem
Co to jest regresja logistyczna?
Regresja logistyczna służy do przewidywania klasy, tj. prawdopodobieństwa. Regresja logistyczna może dokładnie przewidzieć wynik binarny.
Wyobraź sobie, że chcesz przewidzieć, czy pożyczka zostanie odrzucona/przyjęta na podstawie wielu atrybutów. Regresja logistyczna ma postać 0/1. y = 0 w przypadku odrzucenia pożyczki, y = 1 w przypadku jej przyjęcia.
Model regresji logistycznej różni się od modelu regresji liniowej pod dwoma względami.
- Po pierwsze, regresja logistyczna akceptuje jedynie dane dychotomiczne (binarne) jako zmienną zależną (tj. wektor 0 i 1).
- Po drugie, wynik mierzony jest następującą funkcją łączącą probabilistyczną zwaną esicy ze względu na kształt litery S.:
Wynik funkcji zawsze mieści się w przedziale od 0 do 1. Sprawdź obrazek poniżej
Funkcja sigmoidalna zwraca wartości od 0 do 1. Do zadania klasyfikacji potrzebujemy wyjścia dyskretnego o wartości 0 lub 1.
Aby przekonwertować ciągły przepływ na wartość dyskretną, możemy ustawić granicę decyzji na 0.5. Wszystkie wartości powyżej tego progu są klasyfikowane jako 1
Jak utworzyć uogólniony model wykładziny (GLM)
Użyjmy dorosły zestaw danych ilustrujących regresję logistyczną. „Dorosły” jest doskonałym zbiorem danych do zadania klasyfikacji. Celem jest przewidzenie, czy roczny dochód w dolarach danej osoby przekroczy 50.000 46,033. Zbiór danych zawiera XNUMX XNUMX obserwacji i dziesięć cech:
- wiek: wiek jednostki. Numeryczny
- edukacja: Poziom wykształcenia jednostki. Czynnik.
- stan cywilny: Maricałkowity status jednostki. Czynnik tj. osoba nigdy niebędąca w związku małżeńskim, żonaty-małżonek cywilny,…
- płeć: płeć jednostki. Czynnik, tj. mężczyzna lub kobieta
- dochód: Target zmienny. Dochód powyżej lub poniżej 50 tys. Współczynnik tj. >50K, <=50K
wśród innych
library(dplyr) data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv") glimpse(data_adult)
Wyjście:
Observations: 48,842 Variables: 10 $ x <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,... $ age <int> 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26... $ workclass <fctr> Private, Private, Local-gov, Private, ?, Private,... $ education <fctr> 11th, HS-grad, Assoc-acdm, Some-college, Some-col... $ educational.num <int> 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,... $ marital.status <fctr> Never-married, Married-civ-spouse, Married-civ-sp... $ race <fctr> Black, White, White, Black, White, White, Black, ... $ gender <fctr> Male, Male, Male, Male, Female, Male, Male, Male,... $ hours.per.week <int> 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39... $ income <fctr> <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5...
Postępujemy następująco:
- Krok 1: Sprawdź zmienne ciągłe
- Krok 2: Sprawdź zmienne czynnikowe
- Krok 3: Inżynieria funkcji
- Krok 4: Statystyka podsumowująca
- Krok 5: Trenuj/zestaw testowy
- Krok 6: Zbuduj model
- Krok 7: Oceń wydajność modelu
- krok 8: Ulepsz model
Twoim zadaniem jest przewidzenie, która osoba będzie miała przychód wyższy niż 50 tys.
W tym samouczku każdy krok zostanie szczegółowo opisany w celu przeprowadzenia analizy prawdziwego zbioru danych.
Krok 1) Sprawdź zmienne ciągłe
W pierwszym kroku można zobaczyć rozkład zmiennych ciągłych.
continuous <-select_if(data_adult, is.numeric) summary(continuous)
Objaśnienie kodu
- ciągłe <-select_if(data_adult, is.numeric): Użyj funkcjiselect_if() z biblioteki dplyr, aby wybrać tylko kolumny liczbowe
- podsumowanie(ciągłe): Wydrukuj statystykę podsumowującą
Wyjście:
## X age educational.num hours.per.week ## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00 ## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00 ## Median :23017 Median :37.00 Median :10.00 Median :40.00 ## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95 ## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00 ## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00
Z powyższej tabeli widać, że dane mają zupełnie inną skalę, a liczba godzin na tydzień wykazuje duże odchylenia (np. spójrz na ostatni kwartyl i wartość maksymalną).
Można sobie z tym poradzić, wykonując dwa kroki:
- 1: Narysuj rozkład godzin w tygodniu
- 2: Standaryzacja zmiennych ciągłych
- Narysuj rozkład
Przyjrzyjmy się bliżej rozkładowi godzin w tygodniu
# Histogram with kernel density curve library(ggplot2) ggplot(continuous, aes(x = hours.per.week)) + geom_density(alpha = .2, fill = "#FF6666")
Wyjście:
Zmienna ma wiele wartości odstających i nie jest dobrze zdefiniowana. Możesz częściowo rozwiązać ten problem, usuwając górne 0.01 procent godzin w tygodniu.
Podstawowa składnia kwantyla:
quantile(variable, percentile) arguments: -variable: Select the variable in the data frame to compute the percentile -percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C, ...) - `A`,`B`,`C` and `...` are all integer from 0 to 1.
Obliczamy górne 2 procent percentyla
top_one_percent <- quantile(data_adult$hours.per.week, .99) top_one_percent
Objaśnienie kodu
- quantile(data_adult$hours.per.week, .99): Oblicz wartość 99 procent czasu pracy
Wyjście:
## 99% ## 80
98 procent populacji pracuje mniej niż 80 godzin tygodniowo.
Możesz pominąć obserwacje powyżej tego progu. Używasz filtra z dplyr biblioteka.
data_adult_drop <-data_adult %>% filter(hours.per.week<top_one_percent) dim(data_adult_drop)
Wyjście:
## [1] 45537 10
- Standaryzacja zmiennych ciągłych
Możesz ujednolicić każdą kolumnę, aby poprawić wydajność, ponieważ dane nie mają tej samej skali. Możesz użyć funkcji mutate_if z biblioteki dplyr. Podstawowa składnia to:
mutate_if(df, condition, funs(function)) arguments: -`df`: Data frame used to compute the function - `condition`: Statement used. Do not use parenthesis - funs(function): Return the function to apply. Do not use parenthesis for the function
Możesz ujednolicić kolumny liczbowe w następujący sposób:
data_adult_rescale <- data_adult_drop % > % mutate_if(is.numeric, funs(as.numeric(scale(.)))) head(data_adult_rescale)
Objaśnienie kodu
- mutate_if(is.numeric, funs(scale)): Warunek to tylko kolumna numeryczna, a funkcją jest skala
Wyjście:
## X age workclass education educational.num ## 1 -1.732680 -1.02325949 Private 11th -1.22106443 ## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868 ## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494 ## 4 -1.732455 0.41426100 Private Some-college -0.04945081 ## 5 -1.732379 -0.34232873 Private 10th -1.61160231 ## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857 ## marital.status race gender hours.per.week income ## 1 Never-married Black Male -0.03995944 <=50K ## 2 Married-civ-spouse White Male 0.86863037 <=50K ## 3 Married-civ-spouse White Male -0.03995944 >50K ## 4 Married-civ-spouse Black Male -0.03995944 >50K ## 5 Never-married White Male -0.94854924 <=50K ## 6 Married-civ-spouse White Male -0.76683128 >50K
Krok 2) Sprawdź zmienne czynnikowe
Ten krok ma dwa cele:
- Sprawdź poziom w każdej kolumnie kategorycznej
- Zdefiniuj nowe poziomy
Podzielimy ten krok na trzy części:
- Wybierz kolumny kategorialne
- Zapisz wykres słupkowy każdej kolumny na liście
- Wydrukuj wykresy
Możemy wybrać kolumny współczynników za pomocą poniższego kodu:
# Select categorical column factor <- data.frame(select_if(data_adult_rescale, is.factor)) ncol(factor)
Objaśnienie kodu
- data.frame(select_if(data_adult, is.factor)): Przechowujemy kolumny współczynników we współczynniku w typie ramki danych. Biblioteka ggplot2 wymaga obiektu ramki danych.
Wyjście:
## [1] 6
Zbiór danych zawiera 6 zmiennych kategorycznych
Drugi krok jest bardziej wykwalifikowany. Chcesz wykreślić wykres słupkowy dla każdej kolumny współczynnika ramki danych. Wygodniej jest zautomatyzować proces, szczególnie w sytuacji dużej liczby kolumn.
library(ggplot2) # Create graph for each column graph <- lapply(names(factor), function(x) ggplot(factor, aes(get(x))) + geom_bar() + theme(axis.text.x = element_text(angle = 90)))
Objaśnienie kodu
- lapply(): Użyj funkcji lapply(), aby przekazać funkcję we wszystkich kolumnach zbioru danych. Dane wyjściowe przechowujesz na liście
- funkcja(x): Funkcja będzie przetwarzana dla każdego x. Tutaj x to kolumny
- ggplot(factor, aes(get(x))) + geom_bar()+ topic(axis.text.x = element_text(angle = 90)): Utwórz wykres słupkowy dla każdego elementu x. Uwaga, aby zwrócić x jako kolumnę, musisz umieścić ją w funkcji get()
Ostatni krok jest stosunkowo łatwy. Chcesz wydrukować 6 wykresów.
# Print the graph graph
Wyjście:
## [[1]]
## ## [[2]]
## ## [[3]]
## ## [[4]]
## ## [[5]]
## ## [[6]]
Uwaga: Użyj przycisku Dalej, aby przejść do następnego wykresu
Krok 3) Inżynieria cech
Przekształcenie edukacji
Z powyższego wykresu widać, że zmienna edukacja ma 16 poziomów. Jest to znaczne, a niektóre poziomy mają stosunkowo małą liczbę obserwacji. Jeśli chcesz zwiększyć ilość informacji, które możesz uzyskać z tej zmiennej, możesz przekształcić ją na wyższy poziom. Mianowicie tworzycie większe grupy o podobnym poziomie wykształcenia. Na przykład niski poziom wykształcenia spowoduje porzucenie nauki. Wyższe poziomy edukacji zostaną zmienione na mistrzowskie.
Oto szczegóły:
Stary poziom | Nowy poziom |
---|---|
Przedszkole | spadkowicz |
10 | Spadkowicz |
11 | Spadkowicz |
12 | Spadkowicz |
1-4 | Spadkowicz |
5th-6th | Spadkowicz |
7th-8th | Spadkowicz |
9 | Spadkowicz |
Stopień HS | Wysoki Grad |
Uczelnia | Społeczność |
Assoc-acdm | Społeczność |
doc | Społeczność |
Doktorantów | Doktorantów |
Masters | Masters |
Szkoła prof | Masters |
Doktorat | Dr |
recast_data <- data_adult_rescale % > % select(-X) % > % mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community", ifelse(education == "Bachelors", "Bachelors", ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))
Objaśnienie kodu
- Używamy czasownika mutate z biblioteki dplyr. Stwierdzeniem ifelse zmieniamy wartości edukacji
W poniższej tabeli tworzysz statystyki podsumowujące, aby zobaczyć, ile lat edukacji (wartość z) potrzeba, aby uzyskać tytuł licencjata, magistra lub doktora.
recast_data % > % group_by(education) % > % summarize(average_educ_year = mean(educational.num), count = n()) % > % arrange(average_educ_year)
Wyjście:
## # A tibble: 6 x 3 ## education average_educ_year count ## <fctr> <dbl> <int> ## 1 dropout -1.76147258 5712 ## 2 HighGrad -0.43998868 14803 ## 3 Community 0.09561361 13407 ## 4 Bachelors 1.12216282 7720 ## 5 Master 1.60337381 3338 ## 6 PhD 2.29377644 557
Przerobić Maristatus tal
Można również utworzyć niższe poziomy dla stanu cywilnego. W poniższym kodzie zmieniasz poziom w następujący sposób:
Stary poziom | Nowy poziom |
---|---|
Nigdy nie żonaty | Niezamężny |
Żonaty-małżonek-nieobecny | Niezamężny |
Żonaty-AF-małżonek | Żonaty |
Żonaty-obywatelski współmałżonek | |
Rozdzielony | Rozdzielony |
Rozwiedziony | |
Wdowy | Wdowa |
# Change level marry recast_data <- recast_data % > % mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))
Możesz sprawdzić liczbę osób w każdej grupie.
table(recast_data$marital.status)
Wyjście:
## ## Married Not_married Separated Widow ## 21165 15359 7727 1286
Krok 4) Statystyka podsumowująca
Czas sprawdzić statystyki dotyczące naszych zmiennych docelowych. Na poniższym wykresie liczysz odsetek osób zarabiających powyżej 50 tys., biorąc pod uwagę płeć.
# Plot gender income ggplot(recast_data, aes(x = gender, fill = income)) + geom_bar(position = "fill") + theme_classic()
Wyjście:
Następnie sprawdź, czy pochodzenie danej osoby wpływa na jej zarobki.
# Plot origin income ggplot(recast_data, aes(x = race, fill = income)) + geom_bar(position = "fill") + theme_classic() + theme(axis.text.x = element_text(angle = 90))
Wyjście:
Liczba godzin pracy według płci.
# box plot gender working time ggplot(recast_data, aes(x = gender, y = hours.per.week)) + geom_boxplot() + stat_summary(fun.y = mean, geom = "point", size = 3, color = "steelblue") + theme_classic()
Wyjście:
Wykres pudełkowy potwierdza, że rozkład czasu pracy pasuje do różnych grup. Na wykresie pudełkowym obie płcie nie mają jednorodnych obserwacji.
Możesz sprawdzić gęstość tygodniowego czasu pracy według rodzaju wykształcenia. Dystrybucje mają wiele różnych typów. Prawdopodobnie można to wytłumaczyć rodzajem umowy w USA.
# Plot distribution working time by education ggplot(recast_data, aes(x = hours.per.week)) + geom_density(aes(color = education), alpha = 0.5) + theme_classic()
Objaśnienie kodu
- ggplot(recast_data, aes( x= hours.per.week)): Wykres gęstości wymaga tylko jednej zmiennej
- geom_density(aes(color = edukacja), alfa =0.5): Obiekt geometryczny kontrolujący gęstość
Wyjście:
Aby potwierdzić swoje przemyślenia, możesz wykonać operację w jedną stronę Test ANOVA:
anova <- aov(hours.per.week~education, recast_data) summary(anova)
Wyjście:
## Df Sum Sq Mean Sq F value Pr(>F) ## education 5 1552 310.31 321.2 <2e-16 *** ## Residuals 45531 43984 0.97 ## --- ## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Test ANOVA potwierdza różnice w średnich między grupami.
Nieliniowość
Zanim uruchomisz model, możesz sprawdzić, czy liczba przepracowanych godzin jest związana z wiekiem.
library(ggplot2) ggplot(recast_data, aes(x = age, y = hours.per.week)) + geom_point(aes(color = income), size = 0.5) + stat_smooth(method = 'lm', formula = y~poly(x, 2), se = TRUE, aes(color = income)) + theme_classic()
Objaśnienie kodu
- ggplot(recast_data, aes(x = wiek, y = godziny.na.tydzień)): Ustaw estetykę wykresu
- geom_point(aes(color= dochód), size =0.5): Skonstruuj wykres punktowy
- stat_smooth(): Dodaj linię trendu z następującymi argumentami:
- method='lm': Wykreśl dopasowaną wartość, jeśli regresji liniowej
- formuła = y~poly(x,2): Dopasuj regresję wielomianową
- se = TRUE: Dodaj błąd standardowy
- aes(color= dochód): Rozbij model według dochodu
Wyjście:
Krótko mówiąc, możesz przetestować warunki interakcji w modelu, aby wychwycić efekt nieliniowości pomiędzy tygodniowym czasem pracy a innymi cechami. Ważne jest, aby wykryć, w jakich warunkach czas pracy jest różny.
Korelacja
Następną kontrolą jest wizualizacja korelacji pomiędzy zmiennymi. Konwertujesz typ poziomu współczynnika na numeryczny, aby można było wykreślić mapę cieplną zawierającą współczynnik korelacji obliczony metodą Spearmana.
library(GGally) # Convert data to numeric corr <- data.frame(lapply(recast_data, as.integer)) # Plot the graphggcorr(corr, method = c("pairwise", "spearman"), nbreaks = 6, hjust = 0.8, label = TRUE, label_size = 3, color = "grey50")
Objaśnienie kodu
- data.frame(lapply(recast_data,as.integer)): Konwertuj dane na numeryczne
- ggcorr() tworzy mapę cieplną z następującymi argumentami:
- metoda: Metoda obliczania korelacji
- nbreaks = 6: Liczba przerw
- hjust = 0.8: Pozycja kontrolna nazwy zmiennej na wykresie
- etykieta = TRUE: Dodaj etykiety na środku okien
- label_size = 3: Rozmiar etykiet
- color = „grey50”): Kolor etykiety
Wyjście:
Krok 5) Trenuj/zestaw testowy
Każdy nadzorowany uczenie maszynowe zadanie wymaga podziału danych pomiędzy zestawem pociągowym i zestawem testowym. Możesz użyć „funkcji” utworzonej w innych samouczkach uczenia się nadzorowanego, aby utworzyć zestaw pociągowy/testowy.
set.seed(1234) create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample <- 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } } data_train <- create_train_test(recast_data, 0.8, train = TRUE) data_test <- create_train_test(recast_data, 0.8, train = FALSE) dim(data_train)
Wyjście:
## [1] 36429 9
dim(data_test)
Wyjście:
## [1] 9108 9
Krok 6) Zbuduj model
Aby zobaczyć, jak działa algorytm, użyj pakietu glm(). The Uogólniony model liniowy to zbiór modeli. Podstawowa składnia to:
glm(formula, data=data, family=linkfunction() Argument: - formula: Equation used to fit the model- data: dataset used - Family: - binomial: (link = "logit") - gaussian: (link = "identity") - Gamma: (link = "inverse") - inverse.gaussian: (link = "1/mu^2") - poisson: (link = "log") - quasi: (link = "identity", variance = "constant") - quasibinomial: (link = "logit") - quasipoisson: (link = "log")
Jesteś gotowy do oszacowania modelu logistycznego, aby podzielić poziom dochodu pomiędzy zestawem cech.
formula <- income~. logit <- glm(formula, data = data_train, family = 'binomial') summary(logit)
Objaśnienie kodu
- formuła <- dochód ~.: Utwórz model pasujący
- logit <- glm(formula, data = data_train, rodzina = 'dwumianowy'): Dopasuj model logistyczny (rodzina = 'dwumianowy') za pomocą danych data_train.
- podsumowanie(logit): Wydrukuj podsumowanie modelu
Wyjście:
## ## Call: ## glm(formula = formula, family = "binomial", data = data_train) ## ## Deviance Residuals: ## Min 1Q Median 3Q Max ## -2.6456 -0.5858 -0.2609 -0.0651 3.1982 ## ## Coefficients: ## Estimate Std. Error z value Pr(>|z|) ## (Intercept) 0.07882 0.21726 0.363 0.71675 ## age 0.41119 0.01857 22.146 < 2e-16 *** ## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 *** ## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 *** ## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499 ## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 *** ## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 *** ## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596 ## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 *** ## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 *** ## educationMaster 0.35651 0.06780 5.258 1.46e-07 *** ## educationPhD 0.46995 0.15772 2.980 0.00289 ** ## educationdropout -1.04974 0.21280 -4.933 8.10e-07 *** ## educational.num 0.56908 0.07063 8.057 7.84e-16 *** ## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 *** ## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 *** ## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 *** ## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117 ## raceBlack 0.07188 0.19330 0.372 0.71001 ## raceOther 0.01370 0.27695 0.049 0.96054 ## raceWhite 0.34830 0.18441 1.889 0.05894 . ## genderMale 0.08596 0.04289 2.004 0.04506 * ## hours.per.week 0.41942 0.01748 23.998 < 2e-16 *** ## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 ## ## (Dispersion parameter for binomial family taken to be 1) ## ## Null deviance: 40601 on 36428 degrees of freedom ## Residual deviance: 27041 on 36406 degrees of freedom ## AIC: 27087 ## ## Number of Fisher Scoring iterations: 6
Podsumowanie naszego modelu ujawnia ciekawe informacje. Wydajność regresji logistycznej ocenia się za pomocą określonych kluczowych wskaźników.
- AIC (Kryteria informacyjne Akaike): Jest to odpowiednik R2 w regresji logistycznej. Mierzy dopasowanie, gdy do liczby parametrów zostanie zastosowana kara. Mniejszy AIC wartości wskazują, że model jest bliższy prawdy.
- Odchylenie zerowe: Pasuje do modelu tylko z wyrazem wolnym. Stopień swobody wynosi n-1. Możemy to zinterpretować jako wartość chi-kwadrat (wartość dopasowana różni się od testowania hipotezy wartości rzeczywistej).
- Odchylenie resztkowe: Model ze wszystkimi zmiennymi. Jest to również interpretowane jako testowanie hipotezy chi-kwadrat.
- Liczba iteracji punktacji Fishera: Liczba iteracji przed zbieżnością.
Dane wyjściowe funkcji glm() są przechowywane na liście. Poniższy kod pokazuje wszystkie elementy dostępne w zmiennej logit, którą skonstruowaliśmy w celu oceny regresji logistycznej.
# Lista jest bardzo długa, wydrukuj tylko pierwsze trzy elementy
lapply(logit, class)[1:3]
Wyjście:
## $coefficients ## [1] "numeric" ## ## $residuals ## [1] "numeric" ## ## $fitted.values ## [1] "numeric"
Każdą wartość można wyodrębnić za pomocą znaku $, po którym następuje nazwa metryki. Na przykład zapisałeś model jako logit. Aby wyodrębnić kryteria AIC, użyj:
logit$aic
Wyjście:
## [1] 27086.65
Krok 7) Oceń wydajność modelu
Macierz zamieszania
matryca zamieszania jest lepszym wyborem do oceny skuteczności klasyfikacji w porównaniu z różnymi metrykami, które widziałeś wcześniej. Ogólna koncepcja polega na tym, aby policzyć, ile razy prawdziwe przypadki zostały sklasyfikowane jako fałszywe.
Aby obliczyć macierz zamieszania, należy najpierw dysponować zestawem przewidywań, aby można je było porównać z rzeczywistymi wartościami docelowymi.
predict <- predict(logit, data_test, type = 'response') # confusion matrix table_mat <- table(data_test$income, predict > 0.5) table_mat
Objaśnienie kodu
- przewidywanie(logit,data_test, type = 'odpowiedź'): Oblicz przewidywanie na zestawie testowym. Ustaw typ = „odpowiedź”, aby obliczyć prawdopodobieństwo odpowiedzi.
- table(data_test$income, przewidywanie > 0.5): Oblicz macierz zamieszania. przewidywanie > 0.5 oznacza, że zwraca 1, jeśli przewidywane prawdopodobieństwa są większe niż 0.5, w przeciwnym razie 0.
Wyjście:
## ## FALSE TRUE ## <=50K 6310 495 ## >50K 1074 1229
Każdy wiersz w macierzy zamieszania reprezentuje rzeczywisty cel, podczas gdy każda kolumna reprezentuje przewidywany cel. Pierwszy wiersz tej macierzy uwzględnia dochody mniejsze niż 50 tys. (klasa Fałsz): 6241 zostało poprawnie zakwalifikowanych do osób o dochodach niższych niż 50 tys. (Prawdziwy negatyw), natomiast pozostała została błędnie sklasyfikowana jako powyżej 50 tys. (Fałszywie pozytywne). Drugi rząd uwzględnia dochody powyżej 50 tys., klasa dodatnia to 1229 (Prawdziwie pozytywne), podczas, gdy Prawdziwy negatyw było 1074.
Można obliczyć model precyzja poprzez zsumowanie prawdziwie dodatnich i prawdziwie ujemnych wartości z całej obserwacji
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test
Objaśnienie kodu
- sum(diag(table_mat)): Suma przekątnej
- sum(table_mat): Suma macierzy.
Wyjście:
## [1] 0.8277339
Model wydaje się mieć jeden problem: przeszacowuje liczbę wyników fałszywie negatywnych. Nazywa się to paradoks testu dokładności. Ustaliliśmy, że trafność to stosunek poprawnych przewidywań do całkowitej liczby przypadków. Możemy mieć stosunkowo wysoką dokładność, ale model jest bezużyteczny. Dzieje się tak, gdy istnieje klasa dominująca. Jeśli spojrzysz wstecz na matrycę zamieszania, zobaczysz, że większość przypadków jest klasyfikowana jako prawdziwie negatywna. Wyobraź sobie teraz, że model sklasyfikował wszystkie klasy jako ujemne (tzn. poniżej 50 tys.). Uzyskałbyś dokładność na poziomie 75 procent (6718/6718+2257). Twój model działa lepiej, ale ma trudności z odróżnieniem prawdziwego pozytywu od prawdziwego negatywu.
W takiej sytuacji lepiej jest mieć bardziej zwięzłe dane. Możemy przyjrzeć się:
- Precyzja = TP/(TP+FP)
- Przywołanie = TP/(TP+FN)
Precyzja kontra przypominanie
Detaliczność sprawdza dokładność pozytywnej prognozy. Odwołanie jest stosunkiem pozytywnych przypadków, które zostały poprawnie wykryte przez klasyfikator;
Można skonstruować dwie funkcje, aby obliczyć te dwie metryki
- Konstruuj precyzję
precision <- function(matrix) { # True positive tp <- matrix[2, 2] # false positive fp <- matrix[1, 2] return (tp / (tp + fp)) }
Objaśnienie kodu
- mat[1,1]: Zwraca pierwszą komórkę pierwszej kolumny ramki danych, tj. wartość dodatnią
- mata[1,2]; Zwróć pierwszą komórkę drugiej kolumny ramki danych, tj. wynik fałszywie dodatni
recall <- function(matrix) { # true positive tp <- matrix[2, 2]# false positive fn <- matrix[2, 1] return (tp / (tp + fn)) }
Objaśnienie kodu
- mat[1,1]: Zwraca pierwszą komórkę pierwszej kolumny ramki danych, tj. wartość dodatnią
- mata[2,1]; Zwróć drugą komórkę pierwszej kolumny ramki danych, tj. wartość fałszywie ujemną
Możesz przetestować swoje funkcje
prec <- precision(table_mat) prec rec <- recall(table_mat) rec
Wyjście:
## [1] 0.712877 ## [2] 0.5336518
Kiedy model podaje, że jest to osoba powyżej 50 tys., ma rację tylko w 54% przypadków, a w 50% przypadków może wskazać osoby powyżej 72 tys.
Możesz stworzyć Wynik oparty na precyzji i zapamiętywaniu. The
jest średnią harmoniczną tych dwóch wskaźników, co oznacza, że przypisuje większą wagę niższym wartościom.
f1 <- 2 * ((prec * rec) / (prec + rec)) f1
Wyjście:
## [1] 0.6103799
Kompromis precyzja kontra przypomnienie
Niemożliwe jest uzyskanie zarówno wysokiej precyzji, jak i wysokiej pamięci.
Jeśli zwiększymy precyzję, łatwiej będzie przewidzieć właściwą osobę, ale wiele z nich przeoczymy (niższa pamięć). W niektórych sytuacjach wolimy większą precyzję niż przypominanie. Istnieje wklęsła zależność pomiędzy precyzją a przypominaniem.
- Wyobraź sobie, że musisz przewidzieć, czy pacjent ma chorobę. Chcesz być jak najbardziej precyzyjny.
- Jeśli chcesz wykryć potencjalnych oszustów na ulicy za pomocą rozpoznawania twarzy, lepiej będzie wychwycić wiele osób oznaczonych jako oszuści, nawet jeśli dokładność jest niska. Policja będzie mogła zwolnić osobę, która nie dopuściła się oszustwa.
Krzywa ROC
Odbiornik OperaCharakterystyka krzywa to kolejne popularne narzędzie używane w klasyfikacji binarnej. Jest bardzo podobna do krzywej precyzji/przypomnienia, ale zamiast wykreślać precyzję w funkcji przypominania, krzywa ROC pokazuje prawdziwie dodatni współczynnik (tj. przypominanie) w porównaniu z fałszywie dodatnim współczynnikiem. Odsetek wyników fałszywie pozytywnych to stosunek przypadków negatywnych, które zostały błędnie sklasyfikowane jako pozytywne. Jest ona równa jeden minus rzeczywista stopa ujemna. Prawdziwie ujemna stopa jest również nazywana specyficzność. Stąd wykresy krzywej ROC wrażliwość (przypomnijmy) w porównaniu ze specyficznością 1
Aby wykreślić krzywą ROC, musimy zainstalować bibliotekę o nazwie RORC. Możemy znaleźć w Condzie biblioteka. Możesz wpisać kod:
conda install -cr r-rocr – tak
Możemy wykreślić ROC za pomocą funkcji przewidywania() i wydajności().
library(ROCR) ROCRpred <- prediction(predict, data_test$income) ROCRperf <- performance(ROCRpred, 'tpr', 'fpr') plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))
Objaśnienie kodu
- przewidywanie(predykt, data_test$income): Biblioteka ROCR musi utworzyć obiekt przewidywania, aby przekształcić dane wejściowe
- wydajność(ROCRpred, 'tpr','fpr'): Zwróć dwie kombinacje, które chcesz przedstawić na wykresie. Tutaj konstruowane są tpr i fpr. Aby uzyskać precyzję wydruku i przywołanie razem, użyj „prec”, „rec”.
Wyjście:
Krok 8) Ulepsz model
Możesz spróbować dodać nieliniowość do modelu z interakcją pomiędzy
- wiek i godziny.na.tydzień
- płeć i godziny.tygodniowo.
Aby porównać oba modele, należy skorzystać z testu punktacji
formula_2 <- income~age: hours.per.week + gender: hours.per.week + . logit_2 <- glm(formula_2, data = data_train, family = 'binomial') predict_2 <- predict(logit_2, data_test, type = 'response') table_mat_2 <- table(data_test$income, predict_2 > 0.5) precision_2 <- precision(table_mat_2) recall_2 <- recall(table_mat_2) f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2)) f1_2
Wyjście:
## [1] 0.6109181
Wynik jest nieco wyższy niż poprzedni. Możesz dalej pracować nad danymi, próbując pobić wynik.
Podsumowanie
Funkcję służącą do trenowania regresji logistycznej możemy podsumować w poniższej tabeli:
Pakiet | Cel | Funkcjonować | Argument |
---|---|---|---|
- | Utwórz zbiór danych pociągu/testu | utwórz_pociąg_set() | dane, rozmiar, pociąg |
glm | Trenuj uogólniony model liniowy | glm() | formuła, dane, rodzina* |
glm | Podsumuj model | streszczenie() | dopasowany model |
baza | Przewidzieć | przepowiadać, wywróżyć() | dopasowany model, zbiór danych, typ = „odpowiedź” |
baza | Utwórz macierz zamieszania | tabela() | y, przewiduj() |
baza | Utwórz wynik dokładności | suma(diag(tabela())/suma(tabela() | |
ROCR | Utwórz ROC: Krok 1 Utwórz prognozę | prognoza() | przewidywać(), j |
ROCR | Utwórz ROC: Krok 2 Stwórz wydajność | wydajność() | przewidywanie(), „tpr”, „fpr” |
ROCR | Utwórz ROC: Krok 3 Narysuj wykres | wątek() | wydajność() |
Inny GLM rodzaje modeli to:
– dwumian: (link = „logit”)
– gaussowski: (link = „tożsamość”)
– Gamma: (link = „odwrotność”)
– odwrotność.gaussa: (link = „1/mu^2”)
– poisson: (link = „log”)
– quasi: (link = „tożsamość”, wariancja = „stała”)
– quasibinomial: (link = „logit”)
– quasipoisson: (link = „log”)