Beslutsträd i R: Klassificeringsträd med exempel
Vad är beslutsträd?
Beslutsträd är mångsidiga maskininlärningsalgoritmer som kan utföra både klassificerings- och regressionsuppgifter. De är mycket kraftfulla algoritmer som kan passa komplexa datauppsättningar. Dessutom är beslutsträd grundläggande komponenter i slumpmässiga skogar, som är bland de mest potenta maskininlärningsalgoritmerna som finns tillgängliga idag.
Träning och visualisering av beslutsträd i R
För att bygga ditt första beslutsträd i R-exemplet kommer vi att gå tillväga enligt följande i denna handledning för beslutsträd:
- Steg 1: Importera data
- Steg 2: Rengör datasetet
- Steg 3: Skapa tåg/testset
- Steg 4: Bygg modellen
- Steg 5: Gör en förutsägelse
- Steg 6: Mät prestanda
- Steg 7: Justera hyperparametrarna
Steg 1) Importera data
Om du är nyfiken på Titanics öde kan du se den här videon på youtube. Syftet med denna datauppsättning är att förutsäga vilka personer som är mer benägna att överleva efter kollisionen med isberget. Datauppsättningen innehåller 13 variabler och 1309 observationer. Datauppsättningen är ordnad efter variabeln X.
set.seed(678) path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv' titanic <-read.csv(path) head(titanic)
Produktion:
## X pclass survived name sex ## 1 1 1 1 Allen, Miss. Elisabeth Walton female ## 2 2 1 1 Allison, Master. Hudson Trevor male ## 3 3 1 0 Allison, Miss. Helen Loraine female ## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male ## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female ## 6 6 1 1 Anderson, Mr. Harry male ## age sibsp parch ticket fare cabin embarked ## 1 29.0000 0 0 24160 211.3375 B5 S ## 2 0.9167 1 2 113781 151.5500 C22 C26 S ## 3 2.0000 1 2 113781 151.5500 C22 C26 S ## 4 30.0000 1 2 113781 151.5500 C22 C26 S ## 5 25.0000 1 2 113781 151.5500 C22 C26 S ## 6 48.0000 0 0 19952 26.5500 E12 S ## home.dest ## 1 St Louis, MO ## 2 Montreal, PQ / Chesterville, ON ## 3 Montreal, PQ / Chesterville, ON ## 4 Montreal, PQ / Chesterville, ON ## 5 Montreal, PQ / Chesterville, ON ## 6 New York, NY
tail(titanic)
Produktion:
## X pclass survived name sex age sibsp ## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0 ## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1 ## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1 ## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0 ## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0 ## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0 ## parch ticket fare cabin embarked home.dest ## 1304 0 2627 14.4583 C ## 1305 0 2665 14.4542 C ## 1306 0 2665 14.4542 C ## 1307 0 2656 7.2250 C ## 1308 0 2670 7.2250 C ## 1309 0 315082 7.8750 S
Från huvud- och svansutgången kan du märka att data inte blandas. Detta är en stor fråga! När du ska dela upp din data mellan ett tågset och ett testset väljer du endast passageraren från klass 1 och 2 (Ingen passagerare från klass 3 är bland de översta 80 procenten av observationerna), vilket innebär att algoritmen aldrig kommer att se egenskaperna hos passagerare i klass 3. Detta misstag kommer att leda till dålig förutsägelse.
För att lösa det här problemet kan du använda funktionen sample().
shuffle_index <- sample(1:nrow(titanic)) head(shuffle_index)
Beslutsträd R-kod Förklaring
- sample(1:nrow(titanic)): Skapa en slumpmässig lista med index från 1 till 1309 (dvs det maximala antalet rader).
Produktion:
## [1] 288 874 1078 633 887 992
Du kommer att använda detta index för att blanda titanic-datauppsättningen.
titanic <- titanic[shuffle_index, ] head(titanic)
Produktion:
## X pclass survived ## 288 288 1 0 ## 874 874 3 0 ## 1078 1078 3 1 ## 633 633 3 0 ## 887 887 3 1 ## 992 992 3 1 ## name sex age ## 288 Sutton, Mr. Frederick male 61 ## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42 ## 1078 O'Driscoll, Miss. Bridget female NA ## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39 ## 887 Jermyn, Miss. Annie female NA ## 992 Mamee, Mr. Hanna male NA ## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ ## 874 0 0 348121 7.6500 F G63 S ## 1078 0 0 14311 7.7500 Q ## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN ## 887 0 0 14313 7.7500 Q ## 992 0 0 2677 7.2292 C
Steg 2) Rengör datasetet
Datastrukturen visar att vissa variabler har NA. Datarensning ska göras enligt följande
- Släpp variablerna home.dest,cabin, name, X och ticket
- Skapa faktorvariabler för pclass och överlevde
- Släpp NA
library(dplyr) # Drop variables clean_titanic <- titanic % > % select(-c(home.dest, cabin, name, X, ticket)) % > % #Convert to factor level mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')), survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > % na.omit() glimpse(clean_titanic)
Kodförklaring
- select(-c(home.dest, cabin, name, X, ticket)): Släpp onödiga variabler
- pclass = factor(pclass, levels = c(1,2,3), labels= c('Övre', 'Middle', 'Lower')): Lägg till etikett till variabeln pclass. 1 blir Upper, 2 blir Middle och 3 blir lägre
- faktor(överlevde, nivåer = c(0,1), etiketter = c('Nej', 'Ja')): Lägg till etikett till variabeln överlevd. 1 Blir Nej och 2 blir Ja
- na.omit(): Ta bort NA-observationerna
Produktion:
## Observations: 1,045 ## Variables: 8 ## $ pclass <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U... ## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y... ## $ sex <fctr> male, male, female, female, male, male, female, male... ## $ age <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ... ## $ sibsp <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,... ## $ parch <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,... ## $ fare <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ... ## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...
Steg 3) Skapa tåg/testset
Innan du tränar din modell måste du utföra två steg:
- Skapa ett tåg och testset: Du tränar modellen på tågsetet och testar förutsägelsen på testsetet (dvs osynliga data)
- Installera rpart.plot från konsolen
Vanlig praxis är att dela upp data 80/20, 80 procent av data tjänar till att träna modellen och 20 procent för att göra förutsägelser. Du måste skapa två separata dataramar. Du vill inte röra testsetet förrän du är färdig med att bygga din modell. Du kan skapa ett funktionsnamn create_train_test() som tar tre argument.
create_train_test(df, size = 0.8, train = TRUE) arguments: -df: Dataset used to train the model. -size: Size of the split. By default, 0.8. Numerical value -train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample < - 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } }
Kodförklaring
- function(data, storlek=0.8, tåg = TRUE): Lägg till argumenten i funktionen
- n_row = nrow(data): Räkna antalet rader i datamängden
- total_row = size*n_row: Returnera den n:e raden för att konstruera tågsetet
- train_sample <- 1:total_row: Välj den första raden till den n:e raden
- if (tåg ==TRUE){ } annat { }: Om villkoret är sant, returnera tåguppsättningen, annars testuppsättningen.
Du kan testa din funktion och kontrollera dimensionen.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE) data_test <- create_train_test(clean_titanic, 0.8, train = FALSE) dim(data_train)
Produktion:
## [1] 836 8
dim(data_test)
Produktion:
## [1] 209 8
Tågsuppsättningen har 1046 rader medan testdatauppsättningen har 262 rader.
Du använder funktionen prop.table() i kombination med table() för att verifiera om randomiseringsprocessen är korrekt.
prop.table(table(data_train$survived))
Produktion:
## ## No Yes ## 0.5944976 0.4055024
prop.table(table(data_test$survived))
Produktion:
## ## No Yes ## 0.5789474 0.4210526
I båda datamängderna är antalet överlevande detsamma, cirka 40 procent.
Installera rpart.plot
rpart.plot är inte tillgängligt från condas bibliotek. Du kan installera det från konsolen:
install.packages("rpart.plot")
Steg 4) Bygg modellen
Du är redo att bygga modellen. Syntaxen för Rparts beslutsträdsfunktion är:
rpart(formula, data=, method='') arguments: - formula: The function to predict - data: Specifies the data frame- method: - "class" for a classification tree - "anova" for a regression tree
Du använder klassmetoden eftersom du förutsäger en klass.
library(rpart) library(rpart.plot) fit <- rpart(survived~., data = data_train, method = 'class') rpart.plot(fit, extra = 106
Kodförklaring
- rpart(): Funktion för att passa modellen. Argumenten är:
- överlevde ~.: Beslutsträdens formel
- data = data_train: Dataset
- metod = 'klass': Passar en binär modell
- rpart.plot(fit, extra= 106): Rita trädet. Extrafunktionerna är inställda på 101 för att visa sannolikheten för den andra klassen (användbart för binära svar). Du kan hänvisa till karikatyrerna för mer information om de andra valen.
Produktion:
Du börjar vid rotnoden (djup 0 över 3, toppen av grafen):
- Överst är det den totala sannolikheten för överlevnad. Den visar andelen passagerare som överlevde kraschen. 41 procent av passagerarna överlevde.
- Denna nod frågar om passagerarens kön är man. Om ja, så går du ner till rotens vänstra barnnod (djup 2). 63 procent är män med en överlevnadssannolikhet på 21 procent.
- I den andra noden frågar du om den manliga passageraren är över 3.5 år. Om ja, så är chansen att överleva 19 procent.
- Du fortsätter så för att förstå vilka egenskaper som påverkar sannolikheten för överlevnad.
Observera att en av de många egenskaperna hos beslutsträd är att de kräver väldigt lite dataförberedelse. I synnerhet kräver de inte funktionsskalning eller centrering.
Som standard använder funktionen rpart() Gini föroreningsåtgärd för att dela sedeln. Ju högre Gini-koefficient, desto fler olika instanser inom noden.
Steg 5) Gör en förutsägelse
Du kan förutsäga din testdatauppsättning. För att göra en förutsägelse kan du använda predict()-funktionen. Den grundläggande syntaxen för förutsägelse för R beslutsträd är:
predict(fitted_model, df, type = 'class') arguments: - fitted_model: This is the object stored after model estimation. - df: Data frame used to make the prediction - type: Type of prediction - 'class': for classification - 'prob': to compute the probability of each class - 'vector': Predict the mean response at the node level
Du vill förutsäga vilka passagerare som är mer benägna att överleva efter kollisionen från testsetet. Det betyder att du kommer att veta bland de 209 passagerarna, vilken som kommer att överleva eller inte.
predict_unseen <-predict(fit, data_test, type = 'class')
Kodförklaring
- predict(fit, data_test, type = 'class'): Förutsäg klassen (0/1) för testuppsättningen
Testar passageraren som inte klarade sig och de som gjorde det.
table_mat <- table(data_test$survived, predict_unseen) table_mat
Kodförklaring
- table(data_test$survived, predict_unseen): Skapa en tabell för att räkna hur många passagerare som klassas som överlevande och avlidna jämfört med den korrekta beslutsträdsklassificeringen i R
Produktion:
## predict_unseen ## No Yes ## No 106 15 ## Yes 30 58
Modellen förutspådde korrekt 106 döda passagerare men klassificerade 15 överlevande som döda. I analogi missklassificerade modellen 30 passagerare som överlevande medan de visade sig vara döda.
Steg 6) Mät prestanda
Du kan beräkna ett noggrannhetsmått för klassificeringsuppgiften med förvirringsmatris:
Ocuco-landskapet förvirringsmatris är ett bättre val för att utvärdera klassificeringsprestanda. Den allmänna idén är att räkna antalet gånger som sanna instanser klassificeras är falska.
Varje rad i en förvirringsmatris representerar ett faktiskt mål, medan varje kolumn representerar ett förutsagt mål. Den första raden i denna matris betraktar döda passagerare (False-klassen): 106 klassades korrekt som döda (Riktigt negativt), medan den återstående felaktigt klassificerades som en överlevande (Falskt positivt). Den andra raden betraktar de överlevande, den positiva klassen var 58 (Riktigt positivt), medan Riktigt negativt var 30.
Du kan beräkna noggrannhetstest från förvirringsmatrisen:
Det är andelen sant positivt och sant negativt över summan av matrisen. Med R kan du koda enligt följande:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Kodförklaring
- sum(diag(table_mat)): Summan av diagonalen
- sum(table_mat): Summan av matrisen.
Du kan skriva ut testuppsättningens noggrannhet:
print(paste('Accuracy for test', accuracy_Test))
Produktion:
## [1] "Accuracy for test 0.784688995215311"
Du har en poäng på 78 procent för testsetet. Du kan replikera samma övning med träningsdatauppsättningen.
Steg 7) Justera hyperparametrarna
Beslutsträd i R har olika parametrar som styr aspekter av passformen. I rparts beslutsträdsbibliotek kan du styra parametrarna med funktionen rpart.control(). I följande kod introducerar du parametrarna du ska ställa in. Du kan hänvisa till karikatyrerna för andra parametrar.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30) Arguments: -minsplit: Set the minimum number of observations in the node before the algorithm perform a split -minbucket: Set the minimum number of observations in the final note i.e. the leaf -maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
Vi kommer att gå tillväga enligt följande:
- Konstruera funktion för att returnera noggrannhet
- Ställ in det maximala djupet
- Justera det minsta antalet sampel som en nod måste ha innan den kan delas
- Justera det minsta antalet prov en lövnod måste ha
Du kan skriva en funktion för att visa noggrannheten. Du slår helt enkelt in koden du använde tidigare:
- predict: predict_unseen <- predict(fit, data_test, type = 'class')
- Producera tabell: table_mat <- table(data_test$survived, predict_unseen)
- Beräkna noggrannhet: precision_Test <- sum(diag(tabell_mat))/sum(tabell_mat)
accuracy_tune <- function(fit) { predict_unseen <- predict(fit, data_test, type = 'class') table_mat <- table(data_test$survived, predict_unseen) accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test }
Du kan försöka justera parametrarna och se om du kan förbättra modellen över standardvärdet. Som en påminnelse måste du få en noggrannhet högre än 0.78
control <- rpart.control(minsplit = 4, minbucket = round(5 / 3), maxdepth = 3, cp = 0) tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control) accuracy_tune(tune_fit)
Produktion:
## [1] 0.7990431
Med följande parameter:
minsplit = 4 minbucket= round(5/3) maxdepth = 3cp=0
Du får en högre prestanda än den tidigare modellen. Grattis!
Sammanfattning
Vi kan sammanfatta funktionerna att träna en beslutsträdsalgoritm i R
Bibliotek | Mål | Funktion | Klass | Driftparametrar | Detaljer |
---|---|---|---|---|---|
rpart | Tågklassificeringsträd i R | rpart() | klass | formel, df, metod | |
rpart | Träna regressionsträd | rpart() | anova | formel, df, metod | |
rpart | Rita träden | rpart.plot() | monterad modell | ||
bas | förutse | förutspå() | klass | monterad modell, typ | |
bas | förutse | förutspå() | prob | monterad modell, typ | |
bas | förutse | förutspå() | vektor | monterad modell, typ | |
rpart | Kontrollparametrar | rpart.control() | minsplit | Ställ in det minsta antalet observationer i noden innan algoritmen utför en split | |
minbucket | Ställ in minsta antal observationer i slutnoten, dvs bladet | ||||
Max djup | Ställ in maximalt djup för valfri nod i det slutliga trädet. Rotnoden behandlas med ett djup på 0 | ||||
rpart | Tågmodell med kontrollparameter | rpart() | formel, df, metod, kontroll |
Obs: Träna modellen på träningsdata och testa prestandan på en osynlig datauppsättning, dvs testset.