Rozhodovací strom v R: Klasifikační strom s příkladem
Co jsou rozhodovací stromy?
Rozhodovací stromy jsou všestranný algoritmus strojového učení, který může provádět klasifikační i regresní úkoly. Jsou to velmi výkonné algoritmy, které jsou schopné přizpůsobit složité datové sady. Kromě toho jsou rozhodovací stromy základními součástmi náhodných lesů, které patří mezi nejúčinnější dnes dostupné algoritmy strojového učení.
Školení a vizualizace rozhodovacích stromů v R
Chcete-li vytvořit svůj první rozhodovací strom v příkladu R, budeme v tomto tutoriálu rozhodovacího stromu postupovat následovně:
- Krok 1: Importujte data
- Krok 2: Vyčistěte datovou sadu
- Krok 3: Vytvořte vlak/testovací sadu
- Krok 4: Sestavte model
- Krok 5: Proveďte předpověď
- Krok 6: Změřte výkon
- Krok 7: Vylaďte hyperparametry
Krok 1) Importujte data
Pokud vás zajímá osud Titanicu, můžete se podívat na toto video Youtube. Účelem tohoto souboru dat je předpovědět, kteří lidé s větší pravděpodobností přežijí po srážce s ledovcem. Soubor dat obsahuje 13 proměnných a 1309 pozorování. Soubor dat je uspořádán podle proměnné X.
set.seed(678) path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv' titanic <-read.csv(path) head(titanic)
Výstup:
## X pclass survived name sex ## 1 1 1 1 Allen, Miss. Elisabeth Walton female ## 2 2 1 1 Allison, Master. Hudson Trevor male ## 3 3 1 0 Allison, Miss. Helen Loraine female ## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male ## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female ## 6 6 1 1 Anderson, Mr. Harry male ## age sibsp parch ticket fare cabin embarked ## 1 29.0000 0 0 24160 211.3375 B5 S ## 2 0.9167 1 2 113781 151.5500 C22 C26 S ## 3 2.0000 1 2 113781 151.5500 C22 C26 S ## 4 30.0000 1 2 113781 151.5500 C22 C26 S ## 5 25.0000 1 2 113781 151.5500 C22 C26 S ## 6 48.0000 0 0 19952 26.5500 E12 S ## home.dest ## 1 St Louis, MO ## 2 Montreal, PQ / Chesterville, ON ## 3 Montreal, PQ / Chesterville, ON ## 4 Montreal, PQ / Chesterville, ON ## 5 Montreal, PQ / Chesterville, ON ## 6 New York, NY
tail(titanic)
Výstup:
## X pclass survived name sex age sibsp ## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0 ## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1 ## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1 ## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0 ## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0 ## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0 ## parch ticket fare cabin embarked home.dest ## 1304 0 2627 14.4583 C ## 1305 0 2665 14.4542 C ## 1306 0 2665 14.4542 C ## 1307 0 2656 7.2250 C ## 1308 0 2670 7.2250 C ## 1309 0 315082 7.8750 S
Z výstupu hlavy a ocasu si můžete všimnout, že data nejsou zamíchána. To je velký problém! Když rozdělíte data mezi vlakovou soupravu a testovací soupravu, vyberete 👔 cestující z třídy 1 a 2 (žádný cestující z třídy 3 není v horních 80 procentech pozorování), což znamená, že algoritmus nikdy neuvidí vlastnosti cestujícího třídy 3. Tato chyba povede ke špatné predikci.
Chcete-li tento problém vyřešit, můžete použít funkci sample().
shuffle_index <- sample(1:nrow(titanic)) head(shuffle_index)
Rozhodovací strom R kód Vysvětlení
- sample(1:nrow(titanic)): Vygeneruje náhodný seznam indexů od 1 do 1309 (tj. maximální počet řádků).
Výstup:
## [1] 288 874 1078 633 887 992
Tento index použijete k promíchání titanické datové sady.
titanic <- titanic[shuffle_index, ] head(titanic)
Výstup:
## X pclass survived ## 288 288 1 0 ## 874 874 3 0 ## 1078 1078 3 1 ## 633 633 3 0 ## 887 887 3 1 ## 992 992 3 1 ## name sex age ## 288 Sutton, Mr. Frederick male 61 ## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42 ## 1078 O'Driscoll, Miss. Bridget female NA ## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39 ## 887 Jermyn, Miss. Annie female NA ## 992 Mamee, Mr. Hanna male NA ## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ ## 874 0 0 348121 7.6500 F G63 S ## 1078 0 0 14311 7.7500 Q ## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN ## 887 0 0 14313 7.7500 Q ## 992 0 0 2677 7.2292 C
Krok 2) Vyčistěte datovou sadu
Struktura dat ukazuje, že některé proměnné mají NA. Čištění dat proveďte následovně
- Vypusťte proměnné home.dest,cabin, name, X a ticket
- Vytvořte proměnné faktoru pro pclass a přežil
- Pusťte NA
library(dplyr) # Drop variables clean_titanic <- titanic % > % select(-c(home.dest, cabin, name, X, ticket)) % > % #Convert to factor level mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')), survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > % na.omit() glimpse(clean_titanic)
Vysvětlení kódu
- select(-c(domov.cíl, kajuta, jméno, X, letenka)): Vynechte nepotřebné proměnné
- pclass = factor(pclass, levely = c(1,2,3), labels= c('Upper', 'Middle', 'Lower')): Přidejte popisek k proměnné pclass. 1 se změní na horní, 2 na střední a 3 na nižší
- factor(survived, levels = c(0,1), labels = c('No', 'Yes')): Přidejte popisek k proměnné, která přežila. 1 se změní na Ne a 2 se změní na Ano
- na.omit(): Odstraňte pozorování NA
Výstup:
## Observations: 1,045 ## Variables: 8 ## $ pclass <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U... ## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y... ## $ sex <fctr> male, male, female, female, male, male, female, male... ## $ age <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ... ## $ sibsp <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,... ## $ parch <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,... ## $ fare <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ... ## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...
Krok 3) Vytvořte vlak/testovací sadu
Před trénováním modelu musíte provést dva kroky:
- Vytvořte vlak a testovací soupravu: Natrénujete model na soupravě vlaku a otestujete předpověď na testovací soupravě (tj. neviditelná data)
- Nainstalujte rpart.plot z konzoly
Běžnou praxí je rozdělit data 80/20, 80 procent dat slouží k trénování modelu a 20 procent k předpovědím. Musíte vytvořit dva samostatné datové rámce. Nechcete se dotknout testovací sady, dokud nedokončíte stavbu modelu. Můžete vytvořit název funkce create_train_test(), který přebírá tři argumenty.
create_train_test(df, size = 0.8, train = TRUE) arguments: -df: Dataset used to train the model. -size: Size of the split. By default, 0.8. Numerical value -train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample < - 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } }
Vysvětlení kódu
- function(data, size=0.8, train = TRUE): Přidejte argumenty do funkce
- n_row = nrow(data): Počet řádků v datové sadě
- total_row = size*n_row: Vraťte n-tý řádek pro sestavení vlakové soupravy
- train_sample <- 1:total_row: Vyberte první řádek až n-tý řádek
- if (train ==TRUE){ } else { }: Pokud je podmínka nastavena na true, vrátí vlakovou sadu, jinak testovací sadu.
Můžete otestovat svou funkci a zkontrolovat rozměr.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE) data_test <- create_train_test(clean_titanic, 0.8, train = FALSE) dim(data_train)
Výstup:
## [1] 836 8
dim(data_test)
Výstup:
## [1] 209 8
Datový soubor vlaku má 1046 řádků, zatímco testovací datový soubor má 262 řádků.
Pomocí funkce prop.table() v kombinaci s table() ověříte, zda je proces randomizace správný.
prop.table(table(data_train$survived))
Výstup:
## ## No Yes ## 0.5944976 0.4055024
prop.table(table(data_test$survived))
Výstup:
## ## No Yes ## 0.5789474 0.4210526
V obou souborech dat je počet přeživších stejný, asi 40 procent.
Nainstalujte rpart.plot
rpart.plot není dostupný z knihoven conda. Můžete jej nainstalovat z konzole:
install.packages("rpart.plot")
Krok 4) Sestavte model
Jste připraveni postavit model. Syntaxe funkce rozhodovacího stromu Rpart je:
rpart(formula, data=, method='') arguments: - formula: The function to predict - data: Specifies the data frame- method: - "class" for a classification tree - "anova" for a regression tree
Používáte metodu třídy, protože předpovídáte třídu.
library(rpart) library(rpart.plot) fit <- rpart(survived~., data = data_train, method = 'class') rpart.plot(fit, extra = 106
Vysvětlení kódu
- rpart(): Funkce přizpůsobená modelu. Argumenty jsou:
- přežil ~.: Vzorec rozhodovacích stromů
- data = data_train: Dataset
- method = 'class': Fit binární model
- rpart.plot(fit, extra= 106): Vykreslete strom. Další funkce jsou nastaveny na 101 pro zobrazení pravděpodobnosti 2. třídy (užitečné pro binární odezvy). Můžete odkazovat na viněta pro více informací o dalších možnostech.
Výstup:
Začnete u kořenového uzlu (hloubka 0 nad 3, horní část grafu):
- Na vrcholu je celková pravděpodobnost přežití. Ukazuje podíl cestujících, kteří nehodu přežili. 41 procent cestujících přežilo.
- Tento uzel se ptá, zda je pohlaví cestujícího muž. Pokud ano, přejděte dolů do levého podřízeného uzlu kořene (hloubka 2). 63 procent jsou muži s pravděpodobností přežití 21 procent.
- Ve druhém uzlu se zeptáte, zda je cestujícímu muž starší 3.5 roku. Pokud ano, pak je šance na přežití 19 procent.
- Pokračujte v tom, abyste pochopili, jaké vlastnosti ovlivňují pravděpodobnost přežití.
Všimněte si, že jednou z mnoha vlastností rozhodovacích stromů je, že vyžadují velmi malou přípravu dat. Zejména nevyžadují změnu měřítka prvků nebo centrování.
Ve výchozím nastavení funkce rpart() používá Gini míra nečistot k rozdělení poznámky. Čím vyšší je Gini koeficient, tím více různých instancí v uzlu.
Krok 5) Proveďte předpověď
Můžete předvídat testovací datovou sadu. Chcete-li provést předpověď, můžete použít funkci forecast(). Základní syntaxe predikce pro rozhodovací strom R je:
predict(fitted_model, df, type = 'class') arguments: - fitted_model: This is the object stored after model estimation. - df: Data frame used to make the prediction - type: Type of prediction - 'class': for classification - 'prob': to compute the probability of each class - 'vector': Predict the mean response at the node level
Chcete předpovědět, kteří cestující s větší pravděpodobností přežijí po srážce z testovací soupravy. To znamená, že mezi těmi 209 cestujícími budete vědět, který z nich přežije nebo ne.
predict_unseen <-predict(fit, data_test, type = 'class')
Vysvětlení kódu
- predikovat(fit, data_test, typ = 'třída'): Předpovídá třídu (0/1) testovací sady
Testování cestujících, kteří to nestihli, a těch, kteří ano.
table_mat <- table(data_test$survived, predict_unseen) table_mat
Vysvětlení kódu
- table(data_test$survived, forecast_unseen): Vytvořte tabulku, která spočítá, kolik cestujících je klasifikováno jako přeživší a zemřelo v porovnání se správnou klasifikací rozhodovacího stromu v R
Výstup:
## predict_unseen ## No Yes ## No 106 15 ## Yes 30 58
Model správně předpověděl 106 mrtvých cestujících, ale 15 přeživších klasifikoval jako mrtvé. Analogicky model nesprávně klasifikoval 30 cestujících jako přeživší, zatímco se ukázalo, že jsou mrtví.
Krok 6) Změřte výkon
Můžete vypočítat míru přesnosti pro klasifikační úlohu pomocí matoucí matice:
Jedno matoucí matice je lepší volbou pro hodnocení výkonu klasifikace. Obecnou myšlenkou je spočítat, kolikrát jsou skutečné instance klasifikovány jako nepravdivé.
Každý řádek v matečné matici představuje skutečný cíl, zatímco každý sloupec představuje předpokládaný cíl. První řádek této matice uvažuje mrtvé cestující (třída False): 106 bylo správně klasifikováno jako mrtvé (Pravda negativní), zatímco zbývající byl chybně klasifikován jako přeživší (Falešně pozitivní). Druhá řada bere v úvahu přeživší, pozitivní třída byla 58 (Pravda pozitivní), zatímco Pravda negativní byl 30.
Můžete vypočítat test přesnosti z matoucí matice:
Je to podíl skutečných kladných a záporných hodnot na součtu matice. S R můžete kódovat následovně:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Vysvětlení kódu
- sum(diag(table_mat)): Součet úhlopříčky
- sum(table_mat): Součet matice.
Přesnost testovací sady si můžete vytisknout:
print(paste('Accuracy for test', accuracy_Test))
Výstup:
## [1] "Accuracy for test 0.784688995215311"
Za testovací sadu máte skóre 78 procent. Stejné cvičení můžete replikovat pomocí trénovací datové sady.
Krok 7) Vylaďte hyperparametry
Rozhodovací strom v R má různé parametry, které řídí aspekty přizpůsobení. V knihovně rozhodovacího stromu rpart můžete ovládat parametry pomocí funkce rpart.control(). V následujícím kódu uvedete parametry, které budete ladit. Můžete odkazovat na viněta pro další parametry.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30) Arguments: -minsplit: Set the minimum number of observations in the node before the algorithm perform a split -minbucket: Set the minimum number of observations in the final note i.e. the leaf -maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
Budeme postupovat následovně:
- Vytvořte funkci pro návrat přesnosti
- Nalaďte maximální hloubku
- Vylaďte minimální počet vzorků, které musí mít uzel, než se může rozdělit
- Vylaďte minimální počet vzorků, které musí mít listový uzel
Můžete napsat funkci pro zobrazení přesnosti. Jednoduše zabalíte kód, který jste použili dříve:
- předpovídat: předpovídat_nezobrazeno <- předvídat(přizpůsobit, test_dat, typ = 'třída')
- Vytvořit tabulku: table_mat <- table(data_test$survived, forecast_unseen)
- Přesnost výpočtu: přesnost_Test <- součet(diag(podložka_tabulky))/součet(podložka_tabulky)
accuracy_tune <- function(fit) { predict_unseen <- predict(fit, data_test, type = 'class') table_mat <- table(data_test$survived, predict_unseen) accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test }
Můžete zkusit vyladit parametry a zjistit, zda můžete vylepšit model nad výchozí hodnotu. Připomínáme, že musíte získat přesnost vyšší než 0.78
control <- rpart.control(minsplit = 4, minbucket = round(5 / 3), maxdepth = 3, cp = 0) tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control) accuracy_tune(tune_fit)
Výstup:
## [1] 0.7990431
S následujícím parametrem:
minsplit = 4 minbucket= round(5/3) maxdepth = 3cp=0
Získáte vyšší výkon než předchozí model. gratuluji!
Shrnutí
Můžeme shrnout funkce pro trénování algoritmu rozhodovacího stromu R
Knihovna | Objektivní | funkce | Třída | parametry | Detaily |
---|---|---|---|---|---|
rpart | Strom klasifikace vlaků v R | rpart() | třída | vzorec, df, metoda | |
rpart | Vlak regresní strom | rpart() | anova | vzorec, df, metoda | |
rpart | Vykreslete stromy | rpart.plot() | namontovaný model | ||
základna | předpovědět | předpovědět() | třída | osazený model, typ | |
základna | předpovědět | předpovědět() | prob | osazený model, typ | |
základna | předpovědět | předpovědět() | vektor | osazený model, typ | |
rpart | Kontrolní parametry | rpart.control() | minsplit | Než algoritmus provede rozdělení, nastavte minimální počet pozorování v uzlu | |
minbucket | Nastavte minimální počet pozorování v závěrečné notě, tj. listu | ||||
maxdepth | Nastavte maximální hloubku libovolného uzlu konečného stromu. Kořenový uzel je ošetřen hloubkou 0 | ||||
rpart | Model vlaku s řídicím parametrem | rpart() | vzorec, df, metoda, kontrola |
Poznámka: Trénujte model na trénovacích datech a otestujte výkon na neviditelné datové sadě, tj. testovací sadě.