Beslutningstræ i R: Klassifikationstræ med eksempel
Hvad er beslutningstræer?
Beslutningstræer er alsidig Machine Learning-algoritme, der kan udføre både klassifikations- og regressionsopgaver. De er meget kraftfulde algoritmer, der er i stand til at tilpasse komplekse datasæt. Desuden er beslutningstræer grundlæggende komponenter i tilfældige skove, som er blandt de mest potente Machine Learning-algoritmer, der findes i dag.
Træning og visualisering af beslutningstræer i R
For at bygge dit første beslutningstræ i R-eksemplet, vil vi fortsætte som følger i denne beslutningstræ-vejledning:
- Trin 1: Importer dataene
- Trin 2: Rens datasættet
- Trin 3: Opret tog/testsæt
- Trin 4: Byg modellen
- Trin 5: Lav forudsigelse
- Trin 6: Mål ydeevne
- Trin 7: Indstil hyper-parametrene
Trin 1) Importer dataene
Hvis du er nysgerrig efter titanics skæbne, kan du se denne video på Youtube. Formålet med dette datasæt er at forudsige, hvilke mennesker der er mere tilbøjelige til at overleve efter kollisionen med isbjerget. Datasættet indeholder 13 variabler og 1309 observationer. Datasættet er ordnet efter variablen X.
set.seed(678) path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv' titanic <-read.csv(path) head(titanic)
Output:
## X pclass survived name sex ## 1 1 1 1 Allen, Miss. Elisabeth Walton female ## 2 2 1 1 Allison, Master. Hudson Trevor male ## 3 3 1 0 Allison, Miss. Helen Loraine female ## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male ## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female ## 6 6 1 1 Anderson, Mr. Harry male ## age sibsp parch ticket fare cabin embarked ## 1 29.0000 0 0 24160 211.3375 B5 S ## 2 0.9167 1 2 113781 151.5500 C22 C26 S ## 3 2.0000 1 2 113781 151.5500 C22 C26 S ## 4 30.0000 1 2 113781 151.5500 C22 C26 S ## 5 25.0000 1 2 113781 151.5500 C22 C26 S ## 6 48.0000 0 0 19952 26.5500 E12 S ## home.dest ## 1 St Louis, MO ## 2 Montreal, PQ / Chesterville, ON ## 3 Montreal, PQ / Chesterville, ON ## 4 Montreal, PQ / Chesterville, ON ## 5 Montreal, PQ / Chesterville, ON ## 6 New York, NY
tail(titanic)
Output:
## X pclass survived name sex age sibsp ## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0 ## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1 ## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1 ## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0 ## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0 ## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0 ## parch ticket fare cabin embarked home.dest ## 1304 0 2627 14.4583 C ## 1305 0 2665 14.4542 C ## 1306 0 2665 14.4542 C ## 1307 0 2656 7.2250 C ## 1308 0 2670 7.2250 C ## 1309 0 315082 7.8750 S
Fra hoved- og haleudgangen kan du bemærke, at dataene ikke blandes. Dette er et stort problem! Når du vil dele dine data mellem et togsæt og et testsæt, skal du vælge kun passageren fra klasse 1 og 2 (Ingen passager fra klasse 3 er i top 80 procent af observationerne), hvilket betyder, at algoritmen aldrig vil se funktionerne for passagerer i klasse 3. Denne fejl vil føre til dårlig forudsigelse.
For at løse dette problem kan du bruge funktionen sample().
shuffle_index <- sample(1:nrow(titanic)) head(shuffle_index)
Beslutningstræ R-kode Forklaring
- sample(1:nrow(titanic)): Generer en tilfældig liste med indeks fra 1 til 1309 (dvs. det maksimale antal rækker).
Output:
## [1] 288 874 1078 633 887 992
Du skal bruge dette indeks til at blande det titaniske datasæt.
titanic <- titanic[shuffle_index, ] head(titanic)
Output:
## X pclass survived ## 288 288 1 0 ## 874 874 3 0 ## 1078 1078 3 1 ## 633 633 3 0 ## 887 887 3 1 ## 992 992 3 1 ## name sex age ## 288 Sutton, Mr. Frederick male 61 ## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42 ## 1078 O'Driscoll, Miss. Bridget female NA ## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39 ## 887 Jermyn, Miss. Annie female NA ## 992 Mamee, Mr. Hanna male NA ## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ ## 874 0 0 348121 7.6500 F G63 S ## 1078 0 0 14311 7.7500 Q ## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN ## 887 0 0 14313 7.7500 Q ## 992 0 0 2677 7.2292 C
Trin 2) Rens datasættet
Strukturen af dataene viser, at nogle variable har NA'er. Oprydning af data skal udføres som følger
- Drop variabler home.dest,cabin, name, X og ticket
- Opret faktorvariabler for pclass og overlevede
- Drop NA
library(dplyr) # Drop variables clean_titanic <- titanic % > % select(-c(home.dest, cabin, name, X, ticket)) % > % #Convert to factor level mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')), survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > % na.omit() glimpse(clean_titanic)
Kode Forklaring
- select(-c(home.dest, cabin, name, X, ticket)): Slip unødvendige variabler
- pclass = factor(pclass, levels = c(1,2,3), labels= c('Øvre', 'Middle', 'Lower')): Tilføj etiket til variablen pclass. 1 bliver til Øvre, 2 bliver til Midt og 3 bliver til lavere
- faktor(overlevet, niveauer = c(0,1), etiketter = c('Nej', 'Ja')): Tilføj etiket til variablen overlevede. 1 bliver til nej og 2 bliver til ja
- na.omit(): Fjern NA-observationerne
Output:
## Observations: 1,045 ## Variables: 8 ## $ pclass <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U... ## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y... ## $ sex <fctr> male, male, female, female, male, male, female, male... ## $ age <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ... ## $ sibsp <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,... ## $ parch <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,... ## $ fare <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ... ## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...
Trin 3) Opret tog/testsæt
Før du træner din model, skal du udføre to trin:
- Opret et tog og testsæt: Du træner modellen på togsættet og tester forudsigelsen på testsættet (dvs. usete data)
- Installer rpart.plot fra konsollen
Den almindelige praksis er at dele dataene 80/20, 80 procent af dataene tjener til at træne modellen og 20 procent til at lave forudsigelser. Du skal oprette to separate datarammer. Du ønsker ikke at røre ved testsættet, før du er færdig med at bygge din model. Du kan oprette et funktionsnavn create_train_test(), der tager tre argumenter.
create_train_test(df, size = 0.8, train = TRUE) arguments: -df: Dataset used to train the model. -size: Size of the split. By default, 0.8. Numerical value -train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample < - 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } }
Kode Forklaring
- function(data, størrelse=0.8, tog = TRUE): Tilføj argumenterne i funktionen
- n_row = nrow(data): Tæl antallet af rækker i datasættet
- total_row = size*n_row: Returner den n'te række for at konstruere togsættet
- train_sample <- 1:total_row: Vælg den første række til den n'te række
- if (tog ==TRUE){ } else { }: Hvis betingelsen sættes til sand, returneres togsættet, ellers testsættet.
Du kan teste din funktion og kontrollere dimensionen.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE) data_test <- create_train_test(clean_titanic, 0.8, train = FALSE) dim(data_train)
Output:
## [1] 836 8
dim(data_test)
Output:
## [1] 209 8
Togdatasættet har 1046 rækker, mens testdatasættet har 262 rækker.
Du bruger funktionen prop.table() kombineret med table() for at kontrollere, om randomiseringsprocessen er korrekt.
prop.table(table(data_train$survived))
Output:
## ## No Yes ## 0.5944976 0.4055024
prop.table(table(data_test$survived))
Output:
## ## No Yes ## 0.5789474 0.4210526
I begge datasæt er mængden af overlevende den samme, omkring 40 procent.
Installer rpart.plot
rpart.plot er ikke tilgængelig fra conda-biblioteker. Du kan installere det fra konsollen:
install.packages("rpart.plot")
Trin 4) Byg modellen
Du er klar til at bygge modellen. Syntaksen for Rpart-beslutningstræfunktionen er:
rpart(formula, data=, method='') arguments: - formula: The function to predict - data: Specifies the data frame- method: - "class" for a classification tree - "anova" for a regression tree
Du bruger klassemetoden, fordi du forudsiger en klasse.
library(rpart) library(rpart.plot) fit <- rpart(survived~., data = data_train, method = 'class') rpart.plot(fit, extra = 106
Kode Forklaring
- rpart(): Funktion til at passe til modellen. Argumenterne er:
- overlevede ~.: Beslutningstræernes formel
- data = data_tog: Datasæt
- metode = 'klasse': Tilpas en binær model
- rpart.plot(fit, extra= 106): Plot træet. De ekstra funktioner er sat til 101 for at vise sandsynligheden for 2. klasse (nyttigt til binære svar). Du kan henvise til tegnefilm for mere information om de andre valg.
Output:
Du starter ved rodknuden (dybde 0 over 3, toppen af grafen):
- Øverst er det den samlede sandsynlighed for overlevelse. Det viser andelen af passagerer, der overlevede styrtet. 41 procent af passagererne overlevede.
- Denne node spørger, om passagerens køn er mand. Hvis ja, så går du ned til rodens venstre barneknude (dybde 2). 63 procent er mænd med en overlevelsessandsynlighed på 21 procent.
- I den anden knude spørger du, om den mandlige passager er over 3.5 år. Hvis ja, så er chancen for overlevelse 19 procent.
- Du fortsætter sådan for at forstå, hvilke funktioner der påvirker sandsynligheden for overlevelse.
Bemærk, at en af de mange kvaliteter ved beslutningstræer er, at de kræver meget lidt dataforberedelse. De kræver især ikke funktionsskalering eller centrering.
Som standard bruger rpart()-funktionen Gini urenhedsmål for at opdele sedlen. Jo højere Gini-koefficienten er, jo flere forskellige forekomster i noden.
Trin 5) Lav en forudsigelse
Du kan forudsige dit testdatasæt. For at lave en forudsigelse kan du bruge funktionen forudsige(). Den grundlæggende syntaks for forudsigelse for R beslutningstræ er:
predict(fitted_model, df, type = 'class') arguments: - fitted_model: This is the object stored after model estimation. - df: Data frame used to make the prediction - type: Type of prediction - 'class': for classification - 'prob': to compute the probability of each class - 'vector': Predict the mean response at the node level
Du ønsker at forudsige, hvilke passagerer der har størst sandsynlighed for at overleve efter kollisionen fra testsættet. Det betyder, at du vil vide blandt de 209 passagerer, hvilken der vil overleve eller ej.
predict_unseen <-predict(fit, data_test, type = 'class')
Kode Forklaring
- forudsige(tilpas, data_test, type = 'klasse'): Forudsige klassen (0/1) af testsættet
Test af passageren, der ikke nåede det, og dem, der gjorde det.
table_mat <- table(data_test$survived, predict_unseen) table_mat
Kode Forklaring
- table(data_test$survived, predict_unseen): Opret en tabel for at tælle, hvor mange passagerer der er klassificeret som overlevende og døde sammenlignet med den korrekte beslutningstræklassificering i R
Output:
## predict_unseen ## No Yes ## No 106 15 ## Yes 30 58
Modellen forudsagde korrekt 106 døde passagerer, men klassificerede 15 overlevende som døde. I analogi misklassificerede modellen 30 passagerer som overlevende, mens de viste sig at være døde.
Trin 6) Mål ydeevne
Du kan beregne et nøjagtighedsmål for klassificeringsopgave med forvirringsmatrix:
forvirringsmatrix er et bedre valg til at evaluere klassificeringens ydeevne. Den generelle idé er at tælle antallet af gange, hvor sande tilfælde klassificeres er falske.
Hver række i en forvirringsmatrix repræsenterer et faktisk mål, mens hver kolonne repræsenterer et forudsagt mål. Den første række af denne matrix betragter døde passagerer (den falske klasse): 106 blev korrekt klassificeret som døde (Sand negativ), mens den resterende fejlagtigt blev klassificeret som en overlevende (Falsk positiv). Den anden række betragter de overlevende, den positive klasse var 58 (Rigtig positiv), Mens Sand negativ var 30.
Du kan beregne nøjagtighedstest fra forvirringsmatricen:
Det er andelen af sand positiv og sand negativ over summen af matricen. Med R kan du kode som følger:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Kode Forklaring
- sum(diag(table_mat)): Summen af diagonalen
- sum(tabel_mat): Summen af matricen.
Du kan udskrive nøjagtigheden af testsættet:
print(paste('Accuracy for test', accuracy_Test))
Output:
## [1] "Accuracy for test 0.784688995215311"
Du har en score på 78 procent for testsættet. Du kan replikere den samme øvelse med træningsdatasættet.
Trin 7) Indstil hyper-parametrene
Beslutningstræ i R har forskellige parametre, der styrer aspekter af pasformen. I rpart-beslutningstræbiblioteket kan du styre parametrene ved hjælp af funktionen rpart.control(). I den følgende kode introducerer du de parametre, du vil tune. Du kan henvise til tegnefilm for andre parametre.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30) Arguments: -minsplit: Set the minimum number of observations in the node before the algorithm perform a split -minbucket: Set the minimum number of observations in the final note i.e. the leaf -maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
Vi fortsætter som følger:
- Konstruer funktion for at returnere nøjagtighed
- Indstil den maksimale dybde
- Indstil det mindste antal sample, en node skal have, før den kan opdeles
- Indstil det mindste antal prøver, som en bladknude skal have
Du kan skrive en funktion for at vise nøjagtigheden. Du pakker blot den kode, du brugte før:
- forudsige: forudsige_usynlig <- forudsige(tilpas, data_test, type = 'klasse')
- Fremstil tabel: table_mat <- table(data_test$survived, predict_unseen)
- Beregn nøjagtighed: nøjagtighed_Test <- sum(diag(tabel_mat))/sum(tabel_mat)
accuracy_tune <- function(fit) { predict_unseen <- predict(fit, data_test, type = 'class') table_mat <- table(data_test$survived, predict_unseen) accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test }
Du kan prøve at justere parametrene og se, om du kan forbedre modellen i forhold til standardværdien. Som en påmindelse skal du have en nøjagtighed højere end 0.78
control <- rpart.control(minsplit = 4, minbucket = round(5 / 3), maxdepth = 3, cp = 0) tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control) accuracy_tune(tune_fit)
Output:
## [1] 0.7990431
Med følgende parameter:
minsplit = 4 minbucket= round(5/3) maxdepth = 3cp=0
Du får en højere ydeevne end den tidligere model. Tillykke!
Resumé
Vi kan opsummere funktionerne til at træne en beslutningstræalgoritme i R
Bibliotek | Objektiv | Funktion | Klasse | parametre | Detaljer |
---|---|---|---|---|---|
rpart | Togklassifikationstræ i R | rpart() | klasse | formel, df, metode | |
rpart | Træn regressionstræ | rpart() | ANOVA | formel, df, metode | |
rpart | Tegn træerne | rpart.plot() | monteret model | ||
bund | forudsige | forudsige() | klasse | monteret model, type | |
bund | forudsige | forudsige() | sandsynlighed | monteret model, type | |
bund | forudsige | forudsige() | vektor | monteret model, type | |
rpart | Kontrolparametre | rpart.control() | minsplit | Indstil det mindste antal observationer i noden, før algoritmen udfører en opdeling | |
minbucket | Indstil minimumsantallet af observationer i den sidste note, dvs. bladet | ||||
maxdepth | Indstil den maksimale dybde for enhver knude i det endelige træ. Rodknuden behandles med en dybde på 0 | ||||
rpart | Togmodel med kontrolparameter | rpart() | formel, df, metode, kontrol |
Bemærk: Træn modellen på et træningsdata og test ydeevnen på et uset datasæt, dvs. testsæt.