GLM w R: Uogólniony model liniowy z przykładem

Co to jest regresja logistyczna?

Regresja logistyczna służy do przewidywania klasy, tj. prawdopodobieństwa. Regresja logistyczna może dokładnie przewidzieć wynik binarny.

Wyobraź sobie, że chcesz przewidzieć, czy pożyczka zostanie odrzucona/przyjęta na podstawie wielu atrybutów. Regresja logistyczna ma postać 0/1. y = 0 w przypadku odrzucenia pożyczki, y = 1 w przypadku jej przyjęcia.

Model regresji logistycznej różni się od modelu regresji liniowej pod dwoma względami.

  • Po pierwsze, regresja logistyczna akceptuje jedynie dane dychotomiczne (binarne) jako zmienną zależną (tj. wektor 0 i 1).
  • Po drugie, wynik mierzony jest następującą funkcją łączącą probabilistyczną zwaną esicy ze względu na kształt litery S.:

Regresja logistyczna

Wynik funkcji zawsze mieści się w przedziale od 0 do 1. Sprawdź obrazek poniżej

Regresja logistyczna

Funkcja sigmoidalna zwraca wartości od 0 do 1. Do zadania klasyfikacji potrzebujemy wyjścia dyskretnego o wartości 0 lub 1.

Aby przekonwertować ciągły przepływ na wartość dyskretną, możemy ustawić granicę decyzji na 0.5. Wszystkie wartości powyżej tego progu są klasyfikowane jako 1

Regresja logistyczna

Jak utworzyć uogólniony model wykładziny (GLM)

Użyjmy dorosły zestaw danych ilustrujących regresję logistyczną. „Dorosły” jest doskonałym zbiorem danych do zadania klasyfikacji. Celem jest przewidzenie, czy roczny dochód w dolarach danej osoby przekroczy 50.000 46,033. Zbiór danych zawiera XNUMX XNUMX obserwacji i dziesięć cech:

  • wiek: wiek jednostki. Numeryczny
  • edukacja: Poziom wykształcenia jednostki. Czynnik.
  • stan cywilny: Maricałkowity status jednostki. Czynnik tj. osoba nigdy niebędąca w związku małżeńskim, żonaty-małżonek cywilny,…
  • płeć: płeć jednostki. Czynnik, tj. mężczyzna lub kobieta
  • dochód: Target zmienny. Dochód powyżej lub poniżej 50 tys. Współczynnik tj. >50K, <=50K

wśród innych

library(dplyr)
data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")
glimpse(data_adult)

Wyjście:

Observations: 48,842
Variables: 10
$ x               <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,...
$ age             <int> 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26...
$ workclass       <fctr> Private, Private, Local-gov, Private, ?, Private,...
$ education       <fctr> 11th, HS-grad, Assoc-acdm, Some-college, Some-col...
$ educational.num <int> 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,...
$ marital.status  <fctr> Never-married, Married-civ-spouse, Married-civ-sp...
$ race            <fctr> Black, White, White, Black, White, White, Black, ...
$ gender          <fctr> Male, Male, Male, Male, Female, Male, Male, Male,...
$ hours.per.week  <int> 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39...
$ income          <fctr> <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5...

Postępujemy następująco:

  • Krok 1: Sprawdź zmienne ciągłe
  • Krok 2: Sprawdź zmienne czynnikowe
  • Krok 3: Inżynieria funkcji
  • Krok 4: Statystyka podsumowująca
  • Krok 5: Trenuj/zestaw testowy
  • Krok 6: Zbuduj model
  • Krok 7: Oceń wydajność modelu
  • krok 8: Ulepsz model

Twoim zadaniem jest przewidzenie, która osoba będzie miała przychód wyższy niż 50 tys.

W tym samouczku każdy krok zostanie szczegółowo opisany w celu przeprowadzenia analizy prawdziwego zbioru danych.

Krok 1) Sprawdź zmienne ciągłe

W pierwszym kroku można zobaczyć rozkład zmiennych ciągłych.

continuous <-select_if(data_adult, is.numeric)
summary(continuous)

Objaśnienie kodu

  • ciągłe <-select_if(data_adult, is.numeric): Użyj funkcjiselect_if() z biblioteki dplyr, aby wybrać tylko kolumny liczbowe
  • podsumowanie(ciągłe): Wydrukuj statystykę podsumowującą

Wyjście:

##        X              age        educational.num hours.per.week 
##  Min.   :    1   Min.   :17.00   Min.   : 1.00   Min.   : 1.00  
##  1st Qu.:11509   1st Qu.:28.00   1st Qu.: 9.00   1st Qu.:40.00  
##  Median :23017   Median :37.00   Median :10.00   Median :40.00  
##  Mean   :23017   Mean   :38.56   Mean   :10.13   Mean   :40.95  
##  3rd Qu.:34525   3rd Qu.:47.00   3rd Qu.:13.00   3rd Qu.:45.00  
##  Max.   :46033   Max.   :90.00   Max.   :16.00   Max.   :99.00	

Z powyższej tabeli widać, że dane mają zupełnie inną skalę, a liczba godzin na tydzień wykazuje duże odchylenia (np. spójrz na ostatni kwartyl i wartość maksymalną).

Można sobie z tym poradzić, wykonując dwa kroki:

  • 1: Narysuj rozkład godzin w tygodniu
  • 2: Standaryzacja zmiennych ciągłych
  1. Narysuj rozkład

Przyjrzyjmy się bliżej rozkładowi godzin w tygodniu

# Histogram with kernel density curve
library(ggplot2)
ggplot(continuous, aes(x = hours.per.week)) +
    geom_density(alpha = .2, fill = "#FF6666")

Wyjście:

Sprawdź zmienne ciągłe

Zmienna ma wiele wartości odstających i nie jest dobrze zdefiniowana. Możesz częściowo rozwiązać ten problem, usuwając górne 0.01 procent godzin w tygodniu.

Podstawowa składnia kwantyla:

quantile(variable, percentile)
arguments:
-variable:  Select the variable in the data frame to compute the percentile
-percentile:  Can be a single value between 0 and 1 or multiple value. If multiple, use this format:  `c(A,B,C, ...)
- `A`,`B`,`C` and `...` are all integer from 0 to 1.

Obliczamy górne 2 procent percentyla

top_one_percent <- quantile(data_adult$hours.per.week, .99)
top_one_percent

Objaśnienie kodu

  • quantile(data_adult$hours.per.week, .99): Oblicz wartość 99 procent czasu pracy

Wyjście:

## 99% 
##  80

98 procent populacji pracuje mniej niż 80 godzin tygodniowo.

Możesz pominąć obserwacje powyżej tego progu. Używasz filtra z dplyr biblioteka.

data_adult_drop <-data_adult %>%
filter(hours.per.week<top_one_percent)
dim(data_adult_drop)

Wyjście:

## [1] 45537    10
  1. Standaryzacja zmiennych ciągłych

Możesz ujednolicić każdą kolumnę, aby poprawić wydajność, ponieważ dane nie mają tej samej skali. Możesz użyć funkcji mutate_if z biblioteki dplyr. Podstawowa składnia to:

mutate_if(df, condition, funs(function))
arguments:
-`df`: Data frame used to compute the function
- `condition`: Statement used. Do not use parenthesis
- funs(function):  Return the function to apply. Do not use parenthesis for the function

Możesz ujednolicić kolumny liczbowe w następujący sposób:

data_adult_rescale <- data_adult_drop % > %
	mutate_if(is.numeric, funs(as.numeric(scale(.))))
head(data_adult_rescale)

Objaśnienie kodu

  • mutate_if(is.numeric, funs(scale)): Warunek to tylko kolumna numeryczna, a funkcją jest skala

Wyjście:

##           X         age        workclass    education educational.num
## 1 -1.732680 -1.02325949          Private         11th     -1.22106443
## 2 -1.732605 -0.03969284          Private      HS-grad     -0.43998868
## 3 -1.732530 -0.79628257        Local-gov   Assoc-acdm      0.73162494
## 4 -1.732455  0.41426100          Private Some-college     -0.04945081
## 5 -1.732379 -0.34232873          Private         10th     -1.61160231
## 6 -1.732304  1.85178149 Self-emp-not-inc  Prof-school      1.90323857
##       marital.status  race gender hours.per.week income
## 1      Never-married Black   Male    -0.03995944  <=50K
## 2 Married-civ-spouse White   Male     0.86863037  <=50K
## 3 Married-civ-spouse White   Male    -0.03995944   >50K
## 4 Married-civ-spouse Black   Male    -0.03995944   >50K
## 5      Never-married White   Male    -0.94854924  <=50K
## 6 Married-civ-spouse White   Male    -0.76683128   >50K

Krok 2) Sprawdź zmienne czynnikowe

Ten krok ma dwa cele:

  • Sprawdź poziom w każdej kolumnie kategorycznej
  • Zdefiniuj nowe poziomy

Podzielimy ten krok na trzy części:

  • Wybierz kolumny kategorialne
  • Zapisz wykres słupkowy każdej kolumny na liście
  • Wydrukuj wykresy

Możemy wybrać kolumny współczynników za pomocą poniższego kodu:

# Select categorical column
factor <- data.frame(select_if(data_adult_rescale, is.factor))
	ncol(factor)

Objaśnienie kodu

  • data.frame(select_if(data_adult, is.factor)): Przechowujemy kolumny współczynników we współczynniku w typie ramki danych. Biblioteka ggplot2 wymaga obiektu ramki danych.

Wyjście:

## [1] 6

Zbiór danych zawiera 6 zmiennych kategorycznych

Drugi krok jest bardziej wykwalifikowany. Chcesz wykreślić wykres słupkowy dla każdej kolumny współczynnika ramki danych. Wygodniej jest zautomatyzować proces, szczególnie w sytuacji dużej liczby kolumn.

library(ggplot2)
# Create graph for each column
graph <- lapply(names(factor),
    function(x) 
	ggplot(factor, aes(get(x))) +
		geom_bar() +
		theme(axis.text.x = element_text(angle = 90)))

Objaśnienie kodu

  • lapply(): Użyj funkcji lapply(), aby przekazać funkcję we wszystkich kolumnach zbioru danych. Dane wyjściowe przechowujesz na liście
  • funkcja(x): Funkcja będzie przetwarzana dla każdego x. Tutaj x to kolumny
  • ggplot(factor, aes(get(x))) + geom_bar()+ topic(axis.text.x = element_text(angle = 90)): Utwórz wykres słupkowy dla każdego elementu x. Uwaga, aby zwrócić x jako kolumnę, musisz umieścić ją w funkcji get()

Ostatni krok jest stosunkowo łatwy. Chcesz wydrukować 6 wykresów.

# Print the graph
graph

Wyjście:

## [[1]]

Sprawdź zmienne współczynników

## ## [[2]]

Sprawdź zmienne współczynników

## ## [[3]]

Sprawdź zmienne współczynników

## ## [[4]]

Sprawdź zmienne współczynników

## ## [[5]]

Sprawdź zmienne współczynników

## ## [[6]]

Sprawdź zmienne współczynników

Uwaga: Użyj przycisku Dalej, aby przejść do następnego wykresu

Sprawdź zmienne współczynników

Krok 3) Inżynieria cech

Przekształcenie edukacji

Z powyższego wykresu widać, że zmienna edukacja ma 16 poziomów. Jest to znaczne, a niektóre poziomy mają stosunkowo małą liczbę obserwacji. Jeśli chcesz zwiększyć ilość informacji, które możesz uzyskać z tej zmiennej, możesz przekształcić ją na wyższy poziom. Mianowicie tworzycie większe grupy o podobnym poziomie wykształcenia. Na przykład niski poziom wykształcenia spowoduje porzucenie nauki. Wyższe poziomy edukacji zostaną zmienione na mistrzowskie.

Oto szczegóły:

Stary poziom Nowy poziom
Przedszkole spadkowicz
10 Spadkowicz
11 Spadkowicz
12 Spadkowicz
1-4 Spadkowicz
5th-6th Spadkowicz
7th-8th Spadkowicz
9 Spadkowicz
Stopień HS Wysoki Grad
Uczelnia Społeczność
Assoc-acdm Społeczność
doc Społeczność
Doktorantów Doktorantów
Masters Masters
Szkoła prof Masters
Doktorat Dr
recast_data <- data_adult_rescale % > %
	select(-X) % > %
	mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",
    ifelse(education == "Bachelors", "Bachelors",
        ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))

Objaśnienie kodu

  • Używamy czasownika mutate z biblioteki dplyr. Stwierdzeniem ifelse zmieniamy wartości edukacji

W poniższej tabeli tworzysz statystyki podsumowujące, aby zobaczyć, ile lat edukacji (wartość z) potrzeba, aby uzyskać tytuł licencjata, magistra lub doktora.

recast_data % > %
	group_by(education) % > %
	summarize(average_educ_year = mean(educational.num),
		count = n()) % > %
	arrange(average_educ_year)

Wyjście:

## # A tibble: 6 x 3
## education average_educ_year count			
##      <fctr>             <dbl> <int>
## 1   dropout       -1.76147258  5712
## 2  HighGrad       -0.43998868 14803
## 3 Community        0.09561361 13407
## 4 Bachelors        1.12216282  7720
## 5    Master        1.60337381  3338
## 6       PhD        2.29377644   557

Przerobić Maristatus tal

Można również utworzyć niższe poziomy dla stanu cywilnego. W poniższym kodzie zmieniasz poziom w następujący sposób:

Stary poziom Nowy poziom
Nigdy nie żonaty Niezamężny
Żonaty-małżonek-nieobecny Niezamężny
Żonaty-AF-małżonek Żonaty
Żonaty-obywatelski współmałżonek
Rozdzielony Rozdzielony
Rozwiedziony
Wdowy Wdowa
# Change level marry
recast_data <- recast_data % > %
	mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))

Możesz sprawdzić liczbę osób w każdej grupie.

table(recast_data$marital.status)

Wyjście:

## ##     Married Not_married   Separated       Widow
##       21165       15359        7727        1286

Krok 4) Statystyka podsumowująca

Czas sprawdzić statystyki dotyczące naszych zmiennych docelowych. Na poniższym wykresie liczysz odsetek osób zarabiających powyżej 50 tys., biorąc pod uwagę płeć.

# Plot gender income
ggplot(recast_data, aes(x = gender, fill = income)) +
    geom_bar(position = "fill") +
    theme_classic()

Wyjście:

Statystyka podsumowująca

Następnie sprawdź, czy pochodzenie danej osoby wpływa na jej zarobki.

# Plot origin income
ggplot(recast_data, aes(x = race, fill = income)) +
    geom_bar(position = "fill") +
    theme_classic() +
    theme(axis.text.x = element_text(angle = 90))

Wyjście:

Statystyka podsumowująca

Liczba godzin pracy według płci.

# box plot gender working time
ggplot(recast_data, aes(x = gender, y = hours.per.week)) +
    geom_boxplot() +
    stat_summary(fun.y = mean,
        geom = "point",
        size = 3,
        color = "steelblue") +
    theme_classic()

Wyjście:

Statystyka podsumowująca

Wykres pudełkowy potwierdza, że ​​rozkład czasu pracy pasuje do różnych grup. Na wykresie pudełkowym obie płcie nie mają jednorodnych obserwacji.

Możesz sprawdzić gęstość tygodniowego czasu pracy według rodzaju wykształcenia. Dystrybucje mają wiele różnych typów. Prawdopodobnie można to wytłumaczyć rodzajem umowy w USA.

# Plot distribution working time by education
ggplot(recast_data, aes(x = hours.per.week)) +
    geom_density(aes(color = education), alpha = 0.5) +
    theme_classic()

Objaśnienie kodu

  • ggplot(recast_data, aes( x= hours.per.week)): Wykres gęstości wymaga tylko jednej zmiennej
  • geom_density(aes(color = edukacja), alfa =0.5): Obiekt geometryczny kontrolujący gęstość

Wyjście:

Statystyka podsumowująca

Aby potwierdzić swoje przemyślenia, możesz wykonać operację w jedną stronę Test ANOVA:

anova <- aov(hours.per.week~education, recast_data)
summary(anova)

Wyjście:

##                Df Sum Sq Mean Sq F value Pr(>F)    
## education       5   1552  310.31   321.2 <2e-16 ***
## Residuals   45531  43984    0.97                   
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Test ANOVA potwierdza różnice w średnich między grupami.

Nieliniowość

Zanim uruchomisz model, możesz sprawdzić, czy liczba przepracowanych godzin jest związana z wiekiem.

library(ggplot2)
ggplot(recast_data, aes(x = age, y = hours.per.week)) +
    geom_point(aes(color = income),
        size = 0.5) +
    stat_smooth(method = 'lm',
        formula = y~poly(x, 2),
        se = TRUE,
        aes(color = income)) +
    theme_classic()

Objaśnienie kodu

  • ggplot(recast_data, aes(x = wiek, y = godziny.na.tydzień)): Ustaw estetykę wykresu
  • geom_point(aes(color= dochód), size =0.5): Skonstruuj wykres punktowy
  • stat_smooth(): Dodaj linię trendu z następującymi argumentami:
    • method='lm': Wykreśl dopasowaną wartość, jeśli regresji liniowej
    • formuła = y~poly(x,2): Dopasuj regresję wielomianową
    • se = TRUE: Dodaj błąd standardowy
    • aes(color= dochód): Rozbij model według dochodu

Wyjście:

Nieliniowość

Krótko mówiąc, możesz przetestować warunki interakcji w modelu, aby wychwycić efekt nieliniowości pomiędzy tygodniowym czasem pracy a innymi cechami. Ważne jest, aby wykryć, w jakich warunkach czas pracy jest różny.

Korelacja

Następną kontrolą jest wizualizacja korelacji pomiędzy zmiennymi. Konwertujesz typ poziomu współczynnika na numeryczny, aby można było wykreślić mapę cieplną zawierającą współczynnik korelacji obliczony metodą Spearmana.

library(GGally)
# Convert data to numeric
corr <- data.frame(lapply(recast_data, as.integer))
# Plot the graphggcorr(corr,
    method = c("pairwise", "spearman"),
    nbreaks = 6,
    hjust = 0.8,
    label = TRUE,
    label_size = 3,
    color = "grey50")

Objaśnienie kodu

  • data.frame(lapply(recast_data,as.integer)): Konwertuj dane na numeryczne
  • ggcorr() tworzy mapę cieplną z następującymi argumentami:
    • metoda: Metoda obliczania korelacji
    • nbreaks = 6: Liczba przerw
    • hjust = 0.8: Pozycja kontrolna nazwy zmiennej na wykresie
    • etykieta = TRUE: Dodaj etykiety na środku okien
    • label_size = 3: Rozmiar etykiet
    • color = „grey50”): Kolor etykiety

Wyjście:

Korelacja

Krok 5) Trenuj/zestaw testowy

Każdy nadzorowany uczenie maszynowe zadanie wymaga podziału danych pomiędzy zestawem pociągowym i zestawem testowym. Możesz użyć „funkcji” utworzonej w innych samouczkach uczenia się nadzorowanego, aby utworzyć zestaw pociągowy/testowy.

set.seed(1234)
create_train_test <- function(data, size = 0.8, train = TRUE) {
    n_row = nrow(data)
    total_row = size * n_row
    train_sample <- 1: total_row
    if (train == TRUE) {
        return (data[train_sample, ])
    } else {
        return (data[-train_sample, ])
    }
}
data_train <- create_train_test(recast_data, 0.8, train = TRUE)
data_test <- create_train_test(recast_data, 0.8, train = FALSE)
dim(data_train)

Wyjście:

## [1] 36429     9
dim(data_test)

Wyjście:

## [1] 9108    9

Krok 6) Zbuduj model

Aby zobaczyć, jak działa algorytm, użyj pakietu glm(). The Uogólniony model liniowy to zbiór modeli. Podstawowa składnia to:

glm(formula, data=data, family=linkfunction()
Argument:
- formula:  Equation used to fit the model- data: dataset used
- Family:     - binomial: (link = "logit")			
- gaussian: (link = "identity")			
- Gamma:    (link = "inverse")			
- inverse.gaussian: (link = "1/mu^2")			
- poisson:  (link = "log")			
- quasi:    (link = "identity", variance = "constant")			
- quasibinomial:    (link = "logit")			
- quasipoisson: (link = "log")	

Jesteś gotowy do oszacowania modelu logistycznego, aby podzielić poziom dochodu pomiędzy zestawem cech.

formula <- income~.
logit <- glm(formula, data = data_train, family = 'binomial')
summary(logit)

Objaśnienie kodu

  • formuła <- dochód ~.: Utwórz model pasujący
  • logit <- glm(formula, data = data_train, rodzina = 'dwumianowy'): Dopasuj model logistyczny (rodzina = 'dwumianowy') za pomocą danych data_train.
  • podsumowanie(logit): Wydrukuj podsumowanie modelu

Wyjście:

## 
## Call:
## glm(formula = formula, family = "binomial", data = data_train)
## ## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.6456  -0.5858  -0.2609  -0.0651   3.1982  
## 
## Coefficients:
##                           Estimate Std. Error z value Pr(>|z|)    
## (Intercept)                0.07882    0.21726   0.363  0.71675    
## age                        0.41119    0.01857  22.146  < 2e-16 ***
## workclassLocal-gov        -0.64018    0.09396  -6.813 9.54e-12 ***
## workclassPrivate          -0.53542    0.07886  -6.789 1.13e-11 ***
## workclassSelf-emp-inc     -0.07733    0.10350  -0.747  0.45499    
## workclassSelf-emp-not-inc -1.09052    0.09140 -11.931  < 2e-16 ***
## workclassState-gov        -0.80562    0.10617  -7.588 3.25e-14 ***
## workclassWithout-pay      -1.09765    0.86787  -1.265  0.20596    
## educationCommunity        -0.44436    0.08267  -5.375 7.66e-08 ***
## educationHighGrad         -0.67613    0.11827  -5.717 1.08e-08 ***
## educationMaster            0.35651    0.06780   5.258 1.46e-07 ***
## educationPhD               0.46995    0.15772   2.980  0.00289 ** 
## educationdropout          -1.04974    0.21280  -4.933 8.10e-07 ***
## educational.num            0.56908    0.07063   8.057 7.84e-16 ***
## marital.statusNot_married -2.50346    0.05113 -48.966  < 2e-16 ***
## marital.statusSeparated   -2.16177    0.05425 -39.846  < 2e-16 ***
## marital.statusWidow       -2.22707    0.12522 -17.785  < 2e-16 ***
## raceAsian-Pac-Islander     0.08359    0.20344   0.411  0.68117    
## raceBlack                  0.07188    0.19330   0.372  0.71001    
## raceOther                  0.01370    0.27695   0.049  0.96054    
## raceWhite                  0.34830    0.18441   1.889  0.05894 .  
## genderMale                 0.08596    0.04289   2.004  0.04506 *  
## hours.per.week             0.41942    0.01748  23.998  < 2e-16 ***
## ---## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## ## (Dispersion parameter for binomial family taken to be 1)
## ##     Null deviance: 40601  on 36428  degrees of freedom
## Residual deviance: 27041  on 36406  degrees of freedom
## AIC: 27087
## 
## Number of Fisher Scoring iterations: 6

Podsumowanie naszego modelu ujawnia ciekawe informacje. Wydajność regresji logistycznej ocenia się za pomocą określonych kluczowych wskaźników.

  • AIC (Kryteria informacyjne Akaike): Jest to odpowiednik R2 w regresji logistycznej. Mierzy dopasowanie, gdy do liczby parametrów zostanie zastosowana kara. Mniejszy AIC wartości wskazują, że model jest bliższy prawdy.
  • Odchylenie zerowe: Pasuje do modelu tylko z wyrazem wolnym. Stopień swobody wynosi n-1. Możemy to zinterpretować jako wartość chi-kwadrat (wartość dopasowana różni się od testowania hipotezy wartości rzeczywistej).
  • Odchylenie resztkowe: Model ze wszystkimi zmiennymi. Jest to również interpretowane jako testowanie hipotezy chi-kwadrat.
  • Liczba iteracji punktacji Fishera: Liczba iteracji przed zbieżnością.

Dane wyjściowe funkcji glm() są przechowywane na liście. Poniższy kod pokazuje wszystkie elementy dostępne w zmiennej logit, którą skonstruowaliśmy w celu oceny regresji logistycznej.

# Lista jest bardzo długa, wydrukuj tylko pierwsze trzy elementy

lapply(logit, class)[1:3]

Wyjście:

## $coefficients
## [1] "numeric"
## 
## $residuals
## [1] "numeric"
## 
## $fitted.values
## [1] "numeric"

Każdą wartość można wyodrębnić za pomocą znaku $, po którym następuje nazwa metryki. Na przykład zapisałeś model jako logit. Aby wyodrębnić kryteria AIC, użyj:

logit$aic

Wyjście:

## [1] 27086.65

Krok 7) Oceń wydajność modelu

Macierz zamieszania

matryca zamieszania jest lepszym wyborem do oceny skuteczności klasyfikacji w porównaniu z różnymi metrykami, które widziałeś wcześniej. Ogólna koncepcja polega na tym, aby policzyć, ile razy prawdziwe przypadki zostały sklasyfikowane jako fałszywe.

Macierz zamieszania

Aby obliczyć macierz zamieszania, należy najpierw dysponować zestawem przewidywań, aby można je było porównać z rzeczywistymi wartościami docelowymi.

predict <- predict(logit, data_test, type = 'response')
# confusion matrix
table_mat <- table(data_test$income, predict > 0.5)
table_mat

Objaśnienie kodu

  • przewidywanie(logit,data_test, type = 'odpowiedź'): Oblicz przewidywanie na zestawie testowym. Ustaw typ = „odpowiedź”, aby obliczyć prawdopodobieństwo odpowiedzi.
  • table(data_test$income, przewidywanie > 0.5): Oblicz macierz zamieszania. przewidywanie > 0.5 oznacza, że ​​zwraca 1, jeśli przewidywane prawdopodobieństwa są większe niż 0.5, w przeciwnym razie 0.

Wyjście:

##        
##         FALSE TRUE
##   <=50K  6310  495
##   >50K   1074 1229	

Każdy wiersz w macierzy zamieszania reprezentuje rzeczywisty cel, podczas gdy każda kolumna reprezentuje przewidywany cel. Pierwszy wiersz tej macierzy uwzględnia dochody mniejsze niż 50 tys. (klasa Fałsz): 6241 zostało poprawnie zakwalifikowanych do osób o dochodach niższych niż 50 tys. (Prawdziwy negatyw), natomiast pozostała została błędnie sklasyfikowana jako powyżej 50 tys. (Fałszywie pozytywne). Drugi rząd uwzględnia dochody powyżej 50 tys., klasa dodatnia to 1229 (Prawdziwie pozytywne), podczas, gdy Prawdziwy negatyw było 1074.

Można obliczyć model precyzja poprzez zsumowanie prawdziwie dodatnich i prawdziwie ujemnych wartości z całej obserwacji

Macierz zamieszania

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
accuracy_Test

Objaśnienie kodu

  • sum(diag(table_mat)): Suma przekątnej
  • sum(table_mat): Suma macierzy.

Wyjście:

## [1] 0.8277339

Model wydaje się mieć jeden problem: przeszacowuje liczbę wyników fałszywie negatywnych. Nazywa się to paradoks testu dokładności. Ustaliliśmy, że trafność to stosunek poprawnych przewidywań do całkowitej liczby przypadków. Możemy mieć stosunkowo wysoką dokładność, ale model jest bezużyteczny. Dzieje się tak, gdy istnieje klasa dominująca. Jeśli spojrzysz wstecz na matrycę zamieszania, zobaczysz, że większość przypadków jest klasyfikowana jako prawdziwie negatywna. Wyobraź sobie teraz, że model sklasyfikował wszystkie klasy jako ujemne (tzn. poniżej 50 tys.). Uzyskałbyś dokładność na poziomie 75 procent (6718/6718+2257). Twój model działa lepiej, ale ma trudności z odróżnieniem prawdziwego pozytywu od prawdziwego negatywu.

W takiej sytuacji lepiej jest mieć bardziej zwięzłe dane. Możemy przyjrzeć się:

  • Precyzja = TP/(TP+FP)
  • Przywołanie = TP/(TP+FN)

Precyzja kontra przypominanie

Detaliczność sprawdza dokładność pozytywnej prognozy. Odwołanie jest stosunkiem pozytywnych przypadków, które zostały poprawnie wykryte przez klasyfikator;

Można skonstruować dwie funkcje, aby obliczyć te dwie metryki

  1. Konstruuj precyzję
precision <- function(matrix) {
	# True positive
    tp <- matrix[2, 2]
	# false positive
    fp <- matrix[1, 2]
    return (tp / (tp + fp))
}

Objaśnienie kodu

  • mat[1,1]: Zwraca pierwszą komórkę pierwszej kolumny ramki danych, tj. wartość dodatnią
  • mata[1,2]; Zwróć pierwszą komórkę drugiej kolumny ramki danych, tj. wynik fałszywie dodatni
recall <- function(matrix) {
# true positive
    tp <- matrix[2, 2]# false positive
    fn <- matrix[2, 1]
    return (tp / (tp + fn))
}

Objaśnienie kodu

  • mat[1,1]: Zwraca pierwszą komórkę pierwszej kolumny ramki danych, tj. wartość dodatnią
  • mata[2,1]; Zwróć drugą komórkę pierwszej kolumny ramki danych, tj. wartość fałszywie ujemną

Możesz przetestować swoje funkcje

prec <- precision(table_mat)
prec
rec <- recall(table_mat)
rec

Wyjście:

## [1] 0.712877
## [2] 0.5336518

Kiedy model podaje, że jest to osoba powyżej 50 tys., ma rację tylko w 54% przypadków, a w 50% przypadków może wskazać osoby powyżej 72 tys.

Możesz stworzyć Precyzja kontra przypominanie Wynik oparty na precyzji i zapamiętywaniu. The Precyzja kontra przypominanie jest średnią harmoniczną tych dwóch wskaźników, co oznacza, że ​​przypisuje większą wagę niższym wartościom.

Precyzja kontra przypominanie

f1 <- 2 * ((prec * rec) / (prec + rec))
f1

Wyjście:

## [1] 0.6103799

Kompromis precyzja kontra przypomnienie

Niemożliwe jest uzyskanie zarówno wysokiej precyzji, jak i wysokiej pamięci.

Jeśli zwiększymy precyzję, łatwiej będzie przewidzieć właściwą osobę, ale wiele z nich przeoczymy (niższa pamięć). W niektórych sytuacjach wolimy większą precyzję niż przypominanie. Istnieje wklęsła zależność pomiędzy precyzją a przypominaniem.

  • Wyobraź sobie, że musisz przewidzieć, czy pacjent ma chorobę. Chcesz być jak najbardziej precyzyjny.
  • Jeśli chcesz wykryć potencjalnych oszustów na ulicy za pomocą rozpoznawania twarzy, lepiej będzie wychwycić wiele osób oznaczonych jako oszuści, nawet jeśli dokładność jest niska. Policja będzie mogła zwolnić osobę, która nie dopuściła się oszustwa.

Krzywa ROC

Odbiornik OperaCharakterystyka krzywa to kolejne popularne narzędzie używane w klasyfikacji binarnej. Jest bardzo podobna do krzywej precyzji/przypomnienia, ale zamiast wykreślać precyzję w funkcji przypominania, krzywa ROC pokazuje prawdziwie dodatni współczynnik (tj. przypominanie) w porównaniu z fałszywie dodatnim współczynnikiem. Odsetek wyników fałszywie pozytywnych to stosunek przypadków negatywnych, które zostały błędnie sklasyfikowane jako pozytywne. Jest ona równa jeden minus rzeczywista stopa ujemna. Prawdziwie ujemna stopa jest również nazywana specyficzność. Stąd wykresy krzywej ROC wrażliwość (przypomnijmy) w porównaniu ze specyficznością 1

Aby wykreślić krzywą ROC, musimy zainstalować bibliotekę o nazwie RORC. Możemy znaleźć w Condzie biblioteka. Możesz wpisać kod:

conda install -cr r-rocr – tak

Możemy wykreślić ROC za pomocą funkcji przewidywania() i wydajności().

library(ROCR)
ROCRpred <- prediction(predict, data_test$income)
ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')
plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))

Objaśnienie kodu

  • przewidywanie(predykt, data_test$income): Biblioteka ROCR musi utworzyć obiekt przewidywania, aby przekształcić dane wejściowe
  • wydajność(ROCRpred, 'tpr','fpr'): Zwróć dwie kombinacje, które chcesz przedstawić na wykresie. Tutaj konstruowane są tpr i fpr. Aby uzyskać precyzję wydruku i przywołanie razem, użyj „prec”, „rec”.

Wyjście:

Krzywa ROC

Krok 8) Ulepsz model

Możesz spróbować dodać nieliniowość do modelu z interakcją pomiędzy

  • wiek i godziny.na.tydzień
  • płeć i godziny.tygodniowo.

Aby porównać oba modele, należy skorzystać z testu punktacji

formula_2 <- income~age: hours.per.week + gender: hours.per.week + .
logit_2 <- glm(formula_2, data = data_train, family = 'binomial')
predict_2 <- predict(logit_2, data_test, type = 'response')
table_mat_2 <- table(data_test$income, predict_2 > 0.5)
precision_2 <- precision(table_mat_2)
recall_2 <- recall(table_mat_2)
f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))
f1_2

Wyjście:

## [1] 0.6109181

Wynik jest nieco wyższy niż poprzedni. Możesz dalej pracować nad danymi, próbując pobić wynik.

Podsumowanie

Funkcję służącą do trenowania regresji logistycznej możemy podsumować w poniższej tabeli:

Pakiet Cel Funkcjonować Argument
- Utwórz zbiór danych pociągu/testu utwórz_pociąg_set() dane, rozmiar, pociąg
glm Trenuj uogólniony model liniowy glm() formuła, dane, rodzina*
glm Podsumuj model streszczenie() dopasowany model
baza Przewidzieć przepowiadać, wywróżyć() dopasowany model, zbiór danych, typ = „odpowiedź”
baza Utwórz macierz zamieszania tabela() y, przewiduj()
baza Utwórz wynik dokładności suma(diag(tabela())/suma(tabela()
ROCR Utwórz ROC: Krok 1 Utwórz prognozę prognoza() przewidywać(), j
ROCR Utwórz ROC: Krok 2 Stwórz wydajność wydajność() przewidywanie(), „tpr”, „fpr”
ROCR Utwórz ROC: Krok 3 Narysuj wykres wątek() wydajność()

Inny GLM rodzaje modeli to:

– dwumian: (link = „logit”)

– gaussowski: (link = „tożsamość”)

– Gamma: (link = „odwrotność”)

– odwrotność.gaussa: (link = „1/mu^2”)

– poisson: (link = „log”)

– quasi: (link = „tożsamość”, wariancja = „stała”)

– quasibinomial: (link = „logit”)

– quasipoisson: (link = „log”)

Codzienny biuletyn Guru99

Rozpocznij dzień od najnowszych i najważniejszych wiadomości na temat sztucznej inteligencji, dostarczanych już teraz.