GLM in R: modello lineare generalizzato con esempio

Cos'è la regressione logistica?

La regressione logistica viene utilizzata per prevedere una classe, ovvero una probabilità. La regressione logistica può prevedere accuratamente un risultato binario.

Immagina di voler prevedere se un prestito verrà rifiutato/accettato in base a molti attributi. La regressione logistica è della forma 0/1. y = 0 se il prestito viene rifiutato, y = 1 se accettato.

Un modello di regressione logistica differisce dal modello di regressione lineare in due modi.

  • Innanzitutto, la regressione logistica accetta solo input dicotomici (binari) come variabile dipendente (cioè un vettore di 0 e 1).
  • In secondo luogo, il risultato è misurato dalla seguente funzione di collegamento probabilistico chiamata sigma grazie alla sua forma a S.:

Regressione logistica

L'output della funzione è sempre compreso tra 0 e 1. Controlla l'immagine qui sotto

Regressione logistica

La funzione sigmoide restituisce valori da 0 a 1. Per l'attività di classificazione, abbiamo bisogno di un output discreto di 0 o 1.

Per convertire un flusso continuo in un valore discreto, possiamo impostare un limite decisionale a 0.5. Tutti i valori superiori a questa soglia sono classificati come 1

Regressione logistica

Come creare un modello di rivestimento generalizzato (GLM)

Usiamo il file adulto set di dati per illustrare la regressione logistica. L'"adulto" è un ottimo set di dati per l'attività di classificazione. L'obiettivo è prevedere se il reddito annuo in dollari di un individuo supererà i 50.000. Il set di dati contiene 46,033 osservazioni e dieci caratteristiche:

  • età: età dell'individuo. Numerico
  • istruzione: livello di istruzione dell'individuo. Fattore.
  • stato civile: Maristato mentale dell'individuo. Fattore cioè Mai sposato, Sposato-coniuge civile,...
  • genere: genere dell'individuo. Fattore, cioè Maschio o Femmina
  • reddito: Target variabile. Reddito superiore o inferiore a 50. Fattore ovvero >50, <=50

fra gli altri

library(dplyr)
data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")
glimpse(data_adult)

Produzione:

Observations: 48,842
Variables: 10
$ x               <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,...
$ age             <int> 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26...
$ workclass       <fctr> Private, Private, Local-gov, Private, ?, Private,...
$ education       <fctr> 11th, HS-grad, Assoc-acdm, Some-college, Some-col...
$ educational.num <int> 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,...
$ marital.status  <fctr> Never-married, Married-civ-spouse, Married-civ-sp...
$ race            <fctr> Black, White, White, Black, White, White, Black, ...
$ gender          <fctr> Male, Male, Male, Male, Female, Male, Male, Male,...
$ hours.per.week  <int> 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39...
$ income          <fctr> <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5...

Procederemo come segue:

  • Passaggio 1: verificare le variabili continue
  • Passaggio 2: verificare le variabili fattore
  • Passaggio 3: ingegneria delle funzionalità
  • Passaggio 4: statistica riepilogativa
  • Passaggio 5: set di allenamento/prova
  • Passaggio 6: costruisci il modello
  • Passaggio 7: valutare le prestazioni del modello
  • passaggio 8: migliorare il modello

Il tuo compito è prevedere quale individuo avrà entrate superiori a 50.

In questo tutorial, ogni passaggio sarà dettagliato per eseguire un'analisi su un set di dati reale.

Passaggio 1) Controllare le variabili continue

Nel primo passaggio puoi vedere la distribuzione delle variabili continue.

continuous <-select_if(data_adult, is.numeric)
summary(continuous)

Spiegazione del codice

  • continue <- select_if(data_adult, is.numeric): utilizza la funzione select_if() dalla libreria dplyr per selezionare solo le colonne numeriche
  • summary(continuous): stampa la statistica riepilogativa

Produzione:

##        X              age        educational.num hours.per.week 
##  Min.   :    1   Min.   :17.00   Min.   : 1.00   Min.   : 1.00  
##  1st Qu.:11509   1st Qu.:28.00   1st Qu.: 9.00   1st Qu.:40.00  
##  Median :23017   Median :37.00   Median :10.00   Median :40.00  
##  Mean   :23017   Mean   :38.56   Mean   :10.13   Mean   :40.95  
##  3rd Qu.:34525   3rd Qu.:47.00   3rd Qu.:13.00   3rd Qu.:45.00  
##  Max.   :46033   Max.   :90.00   Max.   :16.00   Max.   :99.00	

Dalla tabella sopra, puoi vedere che i dati hanno scale completamente diverse e ore.per.settimane presenta ampi valori anomali (ad esempio, guarda l'ultimo quartile e il valore massimo).

Puoi affrontarlo seguendo due passaggi:

  • 1: Traccia la distribuzione delle ore.per.settimana
  • 2: Standardizzare le variabili continue
  1. Traccia la distribuzione

Diamo un'occhiata più da vicino alla distribuzione delle ore.per.settimana

# Histogram with kernel density curve
library(ggplot2)
ggplot(continuous, aes(x = hours.per.week)) +
    geom_density(alpha = .2, fill = "#FF6666")

Produzione:

Controlla variabili continue

La variabile ha molti valori anomali e una distribuzione non ben definita. Puoi affrontare parzialmente questo problema eliminando lo 0.01% delle ore settimanali più ricche.

Sintassi di base del quantile:

quantile(variable, percentile)
arguments:
-variable:  Select the variable in the data frame to compute the percentile
-percentile:  Can be a single value between 0 and 1 or multiple value. If multiple, use this format:  `c(A,B,C, ...)
- `A`,`B`,`C` and `...` are all integer from 0 to 1.

Calcoliamo il percentile del 2% più alto

top_one_percent <- quantile(data_adult$hours.per.week, .99)
top_one_percent

Spiegazione del codice

  • quantile(data_adult$hours.per.week, .99): calcola il valore del 99% dell'orario di lavoro

Produzione:

## 99% 
##  80

Il 98% della popolazione lavora meno di 80 ore settimanali.

È possibile eliminare le osservazioni al di sopra di questa soglia. Utilizzi il filtro da dplyr biblioteca.

data_adult_drop <-data_adult %>%
filter(hours.per.week<top_one_percent)
dim(data_adult_drop)

Produzione:

## [1] 45537    10
  1. Standardizzare le variabili continue

Puoi standardizzare ciascuna colonna per migliorare le prestazioni perché i tuoi dati non hanno la stessa scala. Puoi usare la funzione mutate_if dalla libreria dplyr. La sintassi di base è:

mutate_if(df, condition, funs(function))
arguments:
-`df`: Data frame used to compute the function
- `condition`: Statement used. Do not use parenthesis
- funs(function):  Return the function to apply. Do not use parenthesis for the function

È possibile standardizzare le colonne numeriche come segue:

data_adult_rescale <- data_adult_drop % > %
	mutate_if(is.numeric, funs(as.numeric(scale(.))))
head(data_adult_rescale)

Spiegazione del codice

  • mutate_if(is.numeric, funs(scale)): la condizione è solo una colonna numerica e la funzione è scala

Produzione:

##           X         age        workclass    education educational.num
## 1 -1.732680 -1.02325949          Private         11th     -1.22106443
## 2 -1.732605 -0.03969284          Private      HS-grad     -0.43998868
## 3 -1.732530 -0.79628257        Local-gov   Assoc-acdm      0.73162494
## 4 -1.732455  0.41426100          Private Some-college     -0.04945081
## 5 -1.732379 -0.34232873          Private         10th     -1.61160231
## 6 -1.732304  1.85178149 Self-emp-not-inc  Prof-school      1.90323857
##       marital.status  race gender hours.per.week income
## 1      Never-married Black   Male    -0.03995944  <=50K
## 2 Married-civ-spouse White   Male     0.86863037  <=50K
## 3 Married-civ-spouse White   Male    -0.03995944   >50K
## 4 Married-civ-spouse Black   Male    -0.03995944   >50K
## 5      Never-married White   Male    -0.94854924  <=50K
## 6 Married-civ-spouse White   Male    -0.76683128   >50K

Passaggio 2) Controllare le variabili fattore

Questo passaggio ha due obiettivi:

  • Controlla il livello in ciascuna colonna categoriale
  • Definire nuovi livelli

Divideremo questo passaggio in tre parti:

  • Seleziona le colonne categoriali
  • Memorizza il grafico a barre di ciascuna colonna in un elenco
  • Stampa i grafici

Possiamo selezionare le colonne dei fattori con il codice seguente:

# Select categorical column
factor <- data.frame(select_if(data_adult_rescale, is.factor))
	ncol(factor)

Spiegazione del codice

  • data.frame(select_if(data_adult, is.factor)): memorizziamo le colonne dei fattori in factor in un tipo di frame di dati. La libreria ggplot2 richiede un oggetto frame di dati.

Produzione:

## [1] 6

Il set di dati contiene 6 variabili categoriali

Il secondo passaggio è più abile. Vuoi tracciare un grafico a barre per ogni colonna nel fattore frame dati. È più conveniente automatizzare il processo, soprattutto in situazioni in cui sono presenti molte colonne.

library(ggplot2)
# Create graph for each column
graph <- lapply(names(factor),
    function(x) 
	ggplot(factor, aes(get(x))) +
		geom_bar() +
		theme(axis.text.x = element_text(angle = 90)))

Spiegazione del codice

  • lapply(): utilizza la funzione lapply() per passare una funzione in tutte le colonne del set di dati. Memorizzi l'output in un elenco
  • function(x): la funzione verrà elaborata per ogni x. Qui x sono le colonne
  • ggplot(factor, aes(get(x))) + geom_bar()+ theme(axis.text.x = element_text(angle = 90)): crea un grafico a barre per ogni elemento x. Nota, per restituire x come colonna, devi includerla all'interno di get()

L'ultimo passaggio è relativamente semplice. Vuoi stampare i 6 grafici.

# Print the graph
graph

Produzione:

## [[1]]

Controllare le variabili fattore

## ## [[2]]

Controllare le variabili fattore

## ## [[3]]

Controllare le variabili fattore

## ## [[4]]

Controllare le variabili fattore

## ## [[5]]

Controllare le variabili fattore

## ## [[6]]

Controllare le variabili fattore

Nota: utilizzare il pulsante successivo per passare al grafico successivo

Controllare le variabili fattore

Passaggio 3) Ingegneria delle funzionalità

Riformulare l'istruzione

Dal grafico sopra puoi vedere che la variabile istruzione ha 16 livelli. Questo è sostanziale e alcuni livelli hanno un numero relativamente basso di osservazioni. Se desideri migliorare la quantità di informazioni che puoi ottenere da questa variabile, puoi riformularla a un livello superiore. Vale a dire, crei gruppi più grandi con un livello di istruzione simile. Ad esempio, un basso livello di istruzione si tradurrà in abbandono scolastico. I livelli di istruzione più elevati verranno modificati in master.

Ecco il dettaglio:

Vecchio livello Nuovo livello
Asilo Nido Dropout
10° dropout
11° dropout
12° dropout
1 ° -4 ° dropout
5th-6th dropout
7th-8th dropout
dropout
Laurea HS Grado elevato
Qualche college Comunità
Assoc-acdm Comunità
Assoc-voc Comunità
Diploma di laurea Diploma di laurea
Masters Masters
Prof-scuola Masters
Dottorato PhD
recast_data <- data_adult_rescale % > %
	select(-X) % > %
	mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",
    ifelse(education == "Bachelors", "Bachelors",
        ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))

Spiegazione del codice

  • Usiamo il verbo mutare dalla libreria dplyr. Cambiamo i valori dell'educazione con l'affermazione ifelse

Nella tabella seguente, crei una statistica riepilogativa per vedere, in media, quanti anni di istruzione (valore z) sono necessari per raggiungere la laurea, il master o il dottorato.

recast_data % > %
	group_by(education) % > %
	summarize(average_educ_year = mean(educational.num),
		count = n()) % > %
	arrange(average_educ_year)

Produzione:

## # A tibble: 6 x 3
## education average_educ_year count			
##      <fctr>             <dbl> <int>
## 1   dropout       -1.76147258  5712
## 2  HighGrad       -0.43998868 14803
## 3 Community        0.09561361 13407
## 4 Bachelors        1.12216282  7720
## 5    Master        1.60337381  3338
## 6       PhD        2.29377644   557

Rifusione Marital-stato

È anche possibile creare livelli inferiori per lo stato civile. Nel codice seguente si modifica il livello come segue:

Vecchio livello Nuovo livello
Mai sposato Non sposato
Sposato-coniuge-assente Non sposato
Coniuge sposato Sposato
Sposato-civ-coniuge
Separato Separato
Divorziato
Widows Vedova
# Change level marry
recast_data <- recast_data % > %
	mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))

È possibile verificare il numero di individui all'interno di ciascun gruppo.

table(recast_data$marital.status)

Produzione:

## ##     Married Not_married   Separated       Widow
##       21165       15359        7727        1286

Passaggio 4) Statistica riepilogativa

È tempo di controllare alcune statistiche sulle nostre variabili target. Nel grafico sottostante si conta la percentuale di individui che guadagnano più di 50 in base al sesso.

# Plot gender income
ggplot(recast_data, aes(x = gender, fill = income)) +
    geom_bar(position = "fill") +
    theme_classic()

Produzione:

Statistica riassuntiva

Successivamente, controlla se l'origine dell'individuo influisce sul suo guadagno.

# Plot origin income
ggplot(recast_data, aes(x = race, fill = income)) +
    geom_bar(position = "fill") +
    theme_classic() +
    theme(axis.text.x = element_text(angle = 90))

Produzione:

Statistica riassuntiva

Il numero di ore lavorate per genere.

# box plot gender working time
ggplot(recast_data, aes(x = gender, y = hours.per.week)) +
    geom_boxplot() +
    stat_summary(fun.y = mean,
        geom = "point",
        size = 3,
        color = "steelblue") +
    theme_classic()

Produzione:

Statistica riassuntiva

Il box plot conferma che la distribuzione del tempo di lavoro si adatta a gruppi diversi. Nel box plot, entrambi i sessi non hanno osservazioni omogenee.

Puoi verificare la densità dell'orario di lavoro settimanale per tipo di istruzione. Le distribuzioni hanno molte scelte distinte. Ciò può probabilmente essere spiegato dal tipo di contratto negli Stati Uniti.

# Plot distribution working time by education
ggplot(recast_data, aes(x = hours.per.week)) +
    geom_density(aes(color = education), alpha = 0.5) +
    theme_classic()

Spiegazione del codice

  • ggplot(recast_data, aes( x= ore.per.settimana)): un grafico di densità richiede solo una variabile
  • geom_density(aes(color = education), alpha =0.5): l'oggetto geometrico per controllare la densità

Produzione:

Statistica riassuntiva

Per confermare i tuoi pensieri, puoi eseguire un viaggio di sola andata Prova ANOVA:

anova <- aov(hours.per.week~education, recast_data)
summary(anova)

Produzione:

##                Df Sum Sq Mean Sq F value Pr(>F)    
## education       5   1552  310.31   321.2 <2e-16 ***
## Residuals   45531  43984    0.97                   
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Il test ANOVA conferma la differenza nella media tra i gruppi.

Non linearità

Prima di eseguire il modello, puoi verificare se il numero di ore lavorate è correlato all'età.

library(ggplot2)
ggplot(recast_data, aes(x = age, y = hours.per.week)) +
    geom_point(aes(color = income),
        size = 0.5) +
    stat_smooth(method = 'lm',
        formula = y~poly(x, 2),
        se = TRUE,
        aes(color = income)) +
    theme_classic()

Spiegazione del codice

  • ggplot(recast_data, aes(x = age, y = hour.per.week)): imposta l'estetica del grafico
  • geom_point(aes(color= reddito), size =0.5): costruisce il dot plot
  • stat_smooth(): Aggiungi la linea di tendenza con i seguenti argomenti:
    • metodo='lm': traccia il valore adattato se il regressione lineare
    • formula = y~poly(x,2): adatta una regressione polinomiale
    • se = TRUE: aggiungi l'errore standard
    • aes(color=reddito): spezza il modello in base al reddito

Produzione:

Non linearità

In poche parole, puoi testare i termini di interazione nel modello per rilevare l'effetto di non linearità tra l'orario di lavoro settimanale e altre funzionalità. È importante rilevare in quali condizioni l'orario di lavoro differisce.

Correlazione

Il controllo successivo consiste nel visualizzare la correlazione tra le variabili. Convertire il tipo di livello fattore in numerico in modo da poter tracciare una mappa termica contenente il coefficiente di correlazione calcolato con il metodo Spearman.

library(GGally)
# Convert data to numeric
corr <- data.frame(lapply(recast_data, as.integer))
# Plot the graphggcorr(corr,
    method = c("pairwise", "spearman"),
    nbreaks = 6,
    hjust = 0.8,
    label = TRUE,
    label_size = 3,
    color = "grey50")

Spiegazione del codice

  • data.frame(lapply(recast_data,as.integer)): converte i dati in numerici
  • ggcorr() traccia la mappa di calore con i seguenti argomenti:
    • metodo: metodo per calcolare la correlazione
    • nbreaks = 6: Numero di interruzioni
    • hjust = 0.8: controlla la posizione del nome della variabile nel grafico
    • label = TRUE: Aggiungi etichette al centro delle finestre
    • label_size = 3: etichette delle dimensioni
    • color = “grey50”): Colore dell'etichetta

Produzione:

Correlazione

Passaggio 5) Addestramento/test del set

Qualsiasi supervisionato machine learning l'attività richiede di dividere i dati tra un convoglio e un set di prova. Puoi utilizzare la "funzione" che hai creato negli altri tutorial di apprendimento supervisionato per creare un set di training/test.

set.seed(1234)
create_train_test <- function(data, size = 0.8, train = TRUE) {
    n_row = nrow(data)
    total_row = size * n_row
    train_sample <- 1: total_row
    if (train == TRUE) {
        return (data[train_sample, ])
    } else {
        return (data[-train_sample, ])
    }
}
data_train <- create_train_test(recast_data, 0.8, train = TRUE)
data_test <- create_train_test(recast_data, 0.8, train = FALSE)
dim(data_train)

Produzione:

## [1] 36429     9
dim(data_test)

Produzione:

## [1] 9108    9

Passaggio 6) Costruisci il modello

Per vedere come si comporta l'algoritmo, usi il pacchetto glm(). IL Modello lineare generalizzato è una raccolta di modelli. La sintassi di base è:

glm(formula, data=data, family=linkfunction()
Argument:
- formula:  Equation used to fit the model- data: dataset used
- Family:     - binomial: (link = "logit")			
- gaussian: (link = "identity")			
- Gamma:    (link = "inverse")			
- inverse.gaussian: (link = "1/mu^2")			
- poisson:  (link = "log")			
- quasi:    (link = "identity", variance = "constant")			
- quasibinomial:    (link = "logit")			
- quasipoisson: (link = "log")	

Sei pronto per stimare il modello logistico per suddividere il livello di reddito tra una serie di funzionalità.

formula <- income~.
logit <- glm(formula, data = data_train, family = 'binomial')
summary(logit)

Spiegazione del codice

  • formula <- reddito ~ .: crea il modello adatto
  • logit <- glm(formula, data = data_train, family = 'binomial'): adatta un modello logistico (family = 'binomial') con i dati data_train.
  • summary(logit): stampa il riepilogo del modello

Produzione:

## 
## Call:
## glm(formula = formula, family = "binomial", data = data_train)
## ## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.6456  -0.5858  -0.2609  -0.0651   3.1982  
## 
## Coefficients:
##                           Estimate Std. Error z value Pr(>|z|)    
## (Intercept)                0.07882    0.21726   0.363  0.71675    
## age                        0.41119    0.01857  22.146  < 2e-16 ***
## workclassLocal-gov        -0.64018    0.09396  -6.813 9.54e-12 ***
## workclassPrivate          -0.53542    0.07886  -6.789 1.13e-11 ***
## workclassSelf-emp-inc     -0.07733    0.10350  -0.747  0.45499    
## workclassSelf-emp-not-inc -1.09052    0.09140 -11.931  < 2e-16 ***
## workclassState-gov        -0.80562    0.10617  -7.588 3.25e-14 ***
## workclassWithout-pay      -1.09765    0.86787  -1.265  0.20596    
## educationCommunity        -0.44436    0.08267  -5.375 7.66e-08 ***
## educationHighGrad         -0.67613    0.11827  -5.717 1.08e-08 ***
## educationMaster            0.35651    0.06780   5.258 1.46e-07 ***
## educationPhD               0.46995    0.15772   2.980  0.00289 ** 
## educationdropout          -1.04974    0.21280  -4.933 8.10e-07 ***
## educational.num            0.56908    0.07063   8.057 7.84e-16 ***
## marital.statusNot_married -2.50346    0.05113 -48.966  < 2e-16 ***
## marital.statusSeparated   -2.16177    0.05425 -39.846  < 2e-16 ***
## marital.statusWidow       -2.22707    0.12522 -17.785  < 2e-16 ***
## raceAsian-Pac-Islander     0.08359    0.20344   0.411  0.68117    
## raceBlack                  0.07188    0.19330   0.372  0.71001    
## raceOther                  0.01370    0.27695   0.049  0.96054    
## raceWhite                  0.34830    0.18441   1.889  0.05894 .  
## genderMale                 0.08596    0.04289   2.004  0.04506 *  
## hours.per.week             0.41942    0.01748  23.998  < 2e-16 ***
## ---## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## ## (Dispersion parameter for binomial family taken to be 1)
## ##     Null deviance: 40601  on 36428  degrees of freedom
## Residual deviance: 27041  on 36406  degrees of freedom
## AIC: 27087
## 
## Number of Fisher Scoring iterations: 6

La sintesi del nostro modello rivela informazioni interessanti. Le prestazioni di una regressione logistica vengono valutate con metriche chiave specifiche.

  • AIC (Akaike Information Criteria): questo è l'equivalente di R2 nella regressione logistica. Misura l'adattamento quando viene applicata una penalità al numero di parametri. Più piccola AIC i valori indicano che il modello è più vicino alla verità.
  • Devianza nulla: adatta il modello solo con l'intercetta. Il grado di libertà è n-1. Possiamo interpretarlo come un valore Chi-quadrato (valore adattato diverso dal test dell'ipotesi del valore effettivo).
  • Devianza residua: modello con tutte le variabili. Viene anche interpretato come un test di ipotesi del Chi-quadrato.
  • Numero di iterazioni del punteggio Fisher: numero di iterazioni prima della convergenza.

L'output della funzione glm() viene memorizzato in un elenco. Il codice seguente mostra tutti gli elementi disponibili nella variabile logit che abbiamo costruito per valutare la regressione logistica.

# La lista è molto lunga, stampa solo i primi tre elementi

lapply(logit, class)[1:3]

Produzione:

## $coefficients
## [1] "numeric"
## 
## $residuals
## [1] "numeric"
## 
## $fitted.values
## [1] "numeric"

Ogni valore può essere estratto con il segno $ seguito dal nome della metrica. Ad esempio, hai archiviato il modello come logit. Per estrarre i criteri AIC, si utilizza:

logit$aic

Produzione:

## [1] 27086.65

Passaggio 7) Valutare le prestazioni del modello

Matrice di confusione

matrice di confusione è una scelta migliore per valutare le prestazioni della classificazione rispetto alle diverse metriche visualizzate in precedenza. L'idea generale è contare il numero di volte in cui le istanze Vere vengono classificate come False.

Matrice di confusione

Per calcolare la matrice di confusione, è necessario innanzitutto disporre di una serie di previsioni in modo che possano essere confrontate con gli obiettivi effettivi.

predict <- predict(logit, data_test, type = 'response')
# confusion matrix
table_mat <- table(data_test$income, predict > 0.5)
table_mat

Spiegazione del codice

  • predit(logit,data_test, type = 'response'): calcola la previsione sul set di test. Imposta type = 'response' per calcolare la probabilità di risposta.
  • table(data_test$income, suggest > 0.5): calcola la matrice di confusione. predire > 0.5 significa che restituisce 1 se le probabilità previste sono superiori a 0.5, altrimenti 0.

Produzione:

##        
##         FALSE TRUE
##   <=50K  6310  495
##   >50K   1074 1229	

Ogni riga in una matrice di confusione rappresenta un obiettivo effettivo, mentre ogni colonna rappresenta un obiettivo previsto. La prima riga di questa matrice considera i redditi inferiori a 50mila (classe Falsi): 6241 sono stati correttamente classificati come individui con reddito inferiore a 50mila (Vero negativo), mentre il restante è stato erroneamente classificato superiore a 50mila (Falso positivo). La seconda riga considera i redditi superiori a 50k, le classi positive erano 1229 (Vero positivo), Mentre la Vero negativo era 1074.

Puoi calcolare il modello precisione sommando il vero positivo + il vero negativo sull'osservazione totale

Matrice di confusione

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
accuracy_Test

Spiegazione del codice

  • sum(diag(table_mat)): somma della diagonale
  • sum(table_mat): somma della matrice.

Produzione:

## [1] 0.8277339

Il modello sembra soffrire di un problema: sovrastima il numero di falsi negativi. Questo è chiamato il paradosso del test di precisione. Abbiamo affermato che l’accuratezza è il rapporto tra le previsioni corrette e il numero totale di casi. Possiamo avere una precisione relativamente elevata ma un modello inutile. Succede quando c'è una classe dominante. Se guardi indietro alla matrice di confusione, puoi vedere che la maggior parte dei casi sono classificati come veri negativi. Immaginiamo ora che il modello classifichi tutte le classi come negative (cioè inferiori a 50). Avresti una precisione del 75% (6718/6718+2257). Il tuo modello funziona meglio ma fatica a distinguere il vero positivo dal vero negativo.

In tale situazione, è preferibile avere una metrica più concisa. Possiamo guardare:

  • Precisione=TP/(TP+FP)
  • Richiamo=TP/(TP+FN)

Precisione vs richiamo

Precisione esamina l'accuratezza della previsione positiva. Richiamo è il rapporto di istanze positive rilevate correttamente dal classificatore;

È possibile costruire due funzioni per calcolare queste due metriche

  1. Precisione costruttiva
precision <- function(matrix) {
	# True positive
    tp <- matrix[2, 2]
	# false positive
    fp <- matrix[1, 2]
    return (tp / (tp + fp))
}

Spiegazione del codice

  • mat[1,1]: Restituisce la prima cella della prima colonna del frame di dati, ovvero il vero positivo
  • mat[1,2]; Restituisce la prima cella della seconda colonna del frame di dati, ovvero il falso positivo
recall <- function(matrix) {
# true positive
    tp <- matrix[2, 2]# false positive
    fn <- matrix[2, 1]
    return (tp / (tp + fn))
}

Spiegazione del codice

  • mat[1,1]: Restituisce la prima cella della prima colonna del frame di dati, ovvero il vero positivo
  • mat[2,1]; Restituisce la seconda cella della prima colonna del frame di dati, ovvero il falso negativo

Puoi testare le tue funzioni

prec <- precision(table_mat)
prec
rec <- recall(table_mat)
rec

Produzione:

## [1] 0.712877
## [2] 0.5336518

Quando il modello dice che si tratta di un individuo superiore a 50, è corretto solo nel 54% dei casi e può rivendicare individui superiori a 50 nel 72% dei casi.

Puoi creare il file Precisione vs richiamo punteggio basato sulla precisione e sul ricordo. IL Precisione vs richiamo è una media armonica di questi due parametri, il che significa che dà più peso ai valori più bassi.

Precisione vs richiamo

f1 <- 2 * ((prec * rec) / (prec + rec))
f1

Produzione:

## [1] 0.6103799

Compromesso tra precisione e richiamo

È impossibile avere sia un'alta precisione che un alto richiamo.

Se aumentiamo la precisione, l'individuo corretto sarà previsto meglio, ma ne perderemmo molti (richiamo inferiore). In alcune situazioni, preferiamo una precisione maggiore rispetto al richiamo. Esiste una relazione concava tra precisione e richiamo.

  • Immagina di dover prevedere se un paziente ha una malattia. Vuoi essere il più preciso possibile.
  • Se è necessario individuare potenziali truffatori per strada attraverso il riconoscimento facciale, sarebbe meglio catturare molte persone etichettate come fraudolente anche se la precisione è bassa. La polizia potrà rilasciare la persona non fraudolenta.

La curva ROC

Ricevitore Operacaratteristica La curva è un altro strumento comune utilizzato con la classificazione binaria. È molto simile alla curva precisione/richiamo, ma invece di tracciare la precisione rispetto al richiamo, la curva ROC mostra il tasso di veri positivi (cioè il richiamo) rispetto al tasso di falsi positivi. Il tasso di falsi positivi è il rapporto tra i casi negativi erroneamente classificati come positivi. È uguale a uno meno il vero tasso negativo. Viene anche chiamato il vero tasso negativo specificità. Quindi viene tracciata la curva ROC sensibilità (richiamo) rispetto a 1-specificità

Per tracciare la curva ROC, dobbiamo installare una libreria chiamata RORC. Possiamo trovare nella conda biblioteca. Puoi digitare il codice:

conda install -cr r-rocr –yes

Possiamo tracciare il ROC con le funzioni prediction() e performance().

library(ROCR)
ROCRpred <- prediction(predict, data_test$income)
ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')
plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))

Spiegazione del codice

  • predizione(predict, data_test$income): la libreria ROCR deve creare un oggetto di previsione per trasformare i dati di input
  • performance(ROCRpred, 'tpr','fpr'): restituisce le due combinazioni da produrre nel grafico. Qui vengono costruiti tpr e fpr. Per tracciare precisione e richiamo insieme, usa “prec”, “rec”.

Produzione:

La curva ROC

Passo 8) Migliora il modello

Puoi provare ad aggiungere non linearità al modello con l'interazione tra

  • età e ore.a.settimana
  • sesso e ore.per.settimana.

È necessario utilizzare il test del punteggio per confrontare entrambi i modelli

formula_2 <- income~age: hours.per.week + gender: hours.per.week + .
logit_2 <- glm(formula_2, data = data_train, family = 'binomial')
predict_2 <- predict(logit_2, data_test, type = 'response')
table_mat_2 <- table(data_test$income, predict_2 > 0.5)
precision_2 <- precision(table_mat_2)
recall_2 <- recall(table_mat_2)
f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))
f1_2

Produzione:

## [1] 0.6109181

Il punteggio è leggermente più alto del precedente. Puoi continuare a lavorare sui dati per provare a battere il punteggio.

Sintesi

Possiamo riassumere la funzione per addestrare una regressione logistica nella tabella seguente:

CONFEZIONE Obiettivo Funzione Argomento
- Creare un set di dati di training/test create_train_set() dati, dimensioni, treno
glm Addestrare un modello lineare generalizzato glm() formula, dati, famiglia*
glm Riassumere il modello riepilogo() modello montato
base Fai una previsione prevedere () modello adattato, set di dati, tipo = 'risposta'
base Crea una matrice di confusione tavolo() sì, predire()
base Creare un punteggio di precisione somma(diag(tabella())/somma(tabella()
ROCR Crea ROC: passaggio 1 Crea previsione predizione() predire(), sì
ROCR Crea ROC: passo 2 Crea performance prestazione() previsione(), 'tpr', 'fpr'
ROCR Creare ROC: Passaggio 3 Tracciare il grafico complotto() prestazione()

L'altra GLM tipo di modelli sono:

– binomio: (link = “logit”)

– gaussiano: (link = “identità”)

– Gamma: (link = “inverso”)

– inverse.gaussiana: (link = “1/mu^2”)

– poisson: (link = “log”)

– quasi: (link = “identità”, varianza = “costante”)

– quasibinomiale: (link = “logit”)

– quasipoisson: (link = “log”)