Rozhodovací strom v R: Klasifikační strom s příkladem

Co jsou rozhodovací stromy?

Rozhodovací stromy jsou všestranný algoritmus strojového učení, který může provádět klasifikační i regresní úkoly. Jsou to velmi výkonné algoritmy, které jsou schopné přizpůsobit složité datové sady. Kromě toho jsou rozhodovací stromy základními součástmi náhodných lesů, které patří mezi nejúčinnější dnes dostupné algoritmy strojového učení.

Školení a vizualizace rozhodovacích stromů v R

Chcete-li vytvořit svůj první rozhodovací strom v příkladu R, budeme v tomto tutoriálu rozhodovacího stromu postupovat následovně:

  • Krok 1: Importujte data
  • Krok 2: Vyčistěte datovou sadu
  • Krok 3: Vytvořte vlak/testovací sadu
  • Krok 4: Sestavte model
  • Krok 5: Proveďte předpověď
  • Krok 6: Změřte výkon
  • Krok 7: Vylaďte hyperparametry

Krok 1) Importujte data

Pokud vás zajímá osud Titanicu, můžete se podívat na toto video Youtube. Účelem tohoto souboru dat je předpovědět, kteří lidé s větší pravděpodobností přežijí po srážce s ledovcem. Soubor dat obsahuje 13 proměnných a 1309 pozorování. Soubor dat je uspořádán podle proměnné X.

set.seed(678)
path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'
titanic <-read.csv(path)
head(titanic)

Výstup:

##   X pclass survived                                            name    sex
## 1 1      1        1                   Allen, Miss. Elisabeth Walton female
## 2 2      1        1                  Allison, Master. Hudson Trevor   male
## 3 3      1        0                    Allison, Miss. Helen Loraine female
## 4 4      1        0            Allison, Mr. Hudson Joshua Creighton   male
## 5 5      1        0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female
## 6 6      1        1                             Anderson, Mr. Harry   male
##       age sibsp parch ticket     fare   cabin embarked
## 1 29.0000     0     0  24160 211.3375      B5        S
## 2  0.9167     1     2 113781 151.5500 C22 C26        S
## 3  2.0000     1     2 113781 151.5500 C22 C26        S
## 4 30.0000     1     2 113781 151.5500 C22 C26        S
## 5 25.0000     1     2 113781 151.5500 C22 C26        S
## 6 48.0000     0     0  19952  26.5500     E12        S
##                         home.dest
## 1                    St Louis, MO
## 2 Montreal, PQ / Chesterville, ON
## 3 Montreal, PQ / Chesterville, ON
## 4 Montreal, PQ / Chesterville, ON
## 5 Montreal, PQ / Chesterville, ON
## 6                    New York, NY
tail(titanic)

Výstup:

##         X pclass survived                      name    sex  age sibsp
## 1304 1304      3        0     Yousseff, Mr. Gerious   male   NA     0
## 1305 1305      3        0      Zabour, Miss. Hileni female 14.5     1
## 1306 1306      3        0     Zabour, Miss. Thamine female   NA     1
## 1307 1307      3        0 Zakarian, Mr. Mapriededer   male 26.5     0
## 1308 1308      3        0       Zakarian, Mr. Ortin   male 27.0     0
## 1309 1309      3        0        Zimmerman, Mr. Leo   male 29.0     0
##      parch ticket    fare cabin embarked home.dest
## 1304     0   2627 14.4583              C          
## 1305     0   2665 14.4542              C          
## 1306     0   2665 14.4542              C          
## 1307     0   2656  7.2250              C          
## 1308     0   2670  7.2250              C          
## 1309     0 315082  7.8750              S

Z výstupu hlavy a ocasu si můžete všimnout, že data nejsou zamíchána. To je velký problém! Když rozdělíte data mezi vlakovou soupravu a testovací soupravu, vyberete 👔 cestující z třídy 1 a 2 (žádný cestující z třídy 3 není v horních 80 procentech pozorování), což znamená, že algoritmus nikdy neuvidí vlastnosti cestujícího třídy 3. Tato chyba povede ke špatné predikci.

Chcete-li tento problém vyřešit, můžete použít funkci sample().

shuffle_index <- sample(1:nrow(titanic))
head(shuffle_index)

Rozhodovací strom R kód Vysvětlení

  • sample(1:nrow(titanic)): Vygeneruje náhodný seznam indexů od 1 do 1309 (tj. maximální počet řádků).

Výstup:

## [1]  288  874 1078  633  887  992

Tento index použijete k promíchání titanické datové sady.

titanic <- titanic[shuffle_index, ]
head(titanic)

Výstup:

##         X pclass survived
## 288   288      1        0
## 874   874      3        0
## 1078 1078      3        1
## 633   633      3        0
## 887   887      3        1
## 992   992      3        1
##                                                           name    sex age
## 288                                      Sutton, Mr. Frederick   male  61
## 874                   Humblen, Mr. Adolf Mathias Nicolai Olsen   male  42
## 1078                                 O'Driscoll, Miss. Bridget female  NA
## 633  Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female  39
## 887                                        Jermyn, Miss. Annie female  NA
## 992                                           Mamee, Mr. Hanna   male  NA
##      sibsp parch ticket    fare cabin embarked           home.dest## 288      0     0  36963 32.3208   D50        S     Haddenfield, NJ
## 874      0     0 348121  7.6500 F G63        S                    
## 1078     0     0  14311  7.7500              Q                    
## 633      1     5 347082 31.2750              S Sweden Winnipeg, MN
## 887      0     0  14313  7.7500              Q                    
## 992      0     0   2677  7.2292              C	

Krok 2) Vyčistěte datovou sadu

Struktura dat ukazuje, že některé proměnné mají NA. Čištění dat proveďte následovně

  • Vypusťte proměnné home.dest,cabin, name, X a ticket
  • Vytvořte proměnné faktoru pro pclass a přežil
  • Pusťte NA
library(dplyr)
# Drop variables
clean_titanic <- titanic % > %
select(-c(home.dest, cabin, name, X, ticket)) % > % 
#Convert to factor level
	mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),
	survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %
na.omit()
glimpse(clean_titanic)

Vysvětlení kódu

  • select(-c(domov.cíl, kajuta, jméno, X, letenka)): Vynechte nepotřebné proměnné
  • pclass = factor(pclass, levely = c(1,2,3), labels= c('Upper', 'Middle', 'Lower')): Přidejte popisek k proměnné pclass. 1 se změní na horní, 2 na střední a 3 na nižší
  • factor(survived, levels = c(0,1), labels = c('No', 'Yes')): Přidejte popisek k proměnné, která přežila. 1 se změní na Ne a 2 se změní na Ano
  • na.omit(): Odstraňte pozorování NA

Výstup:

## Observations: 1,045
## Variables: 8
## $ pclass   <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U...
## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y...
## $ sex      <fctr> male, male, female, female, male, male, female, male...
## $ age      <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ...
## $ sibsp    <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,...
## $ parch    <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,...
## $ fare     <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ...
## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...		

Krok 3) Vytvořte vlak/testovací sadu

Před trénováním modelu musíte provést dva kroky:

  • Vytvořte vlak a testovací soupravu: Natrénujete model na soupravě vlaku a otestujete předpověď na testovací soupravě (tj. neviditelná data)
  • Nainstalujte rpart.plot z konzoly

Běžnou praxí je rozdělit data 80/20, 80 procent dat slouží k trénování modelu a 20 procent k předpovědím. Musíte vytvořit dva samostatné datové rámce. Nechcete se dotknout testovací sady, dokud nedokončíte stavbu modelu. Můžete vytvořit název funkce create_train_test(), který přebírá tři argumenty.

create_train_test(df, size = 0.8, train = TRUE)
arguments:
-df: Dataset used to train the model.
-size: Size of the split. By default, 0.8. Numerical value
-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {
    n_row = nrow(data)
    total_row = size * n_row
    train_sample < - 1: total_row
    if (train == TRUE) {
        return (data[train_sample, ])
    } else {
        return (data[-train_sample, ])
    }
}

Vysvětlení kódu

  • function(data, size=0.8, train = TRUE): Přidejte argumenty do funkce
  • n_row = nrow(data): Počet řádků v datové sadě
  • total_row = size*n_row: Vraťte n-tý řádek pro sestavení vlakové soupravy
  • train_sample <- 1:total_row: Vyberte první řádek až n-tý řádek
  • if (train ==TRUE){ } else { }: Pokud je podmínka nastavena na true, vrátí vlakovou sadu, jinak testovací sadu.

Můžete otestovat svou funkci a zkontrolovat rozměr.

data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)
data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)
dim(data_train)

Výstup:

## [1] 836   8
dim(data_test)

Výstup:

## [1] 209   8

Datový soubor vlaku má 1046 řádků, zatímco testovací datový soubor má 262 řádků.

Pomocí funkce prop.table() v kombinaci s table() ověříte, zda je proces randomizace správný.

prop.table(table(data_train$survived))

Výstup:

##
##        No       Yes 
## 0.5944976 0.4055024
prop.table(table(data_test$survived))

Výstup:

## 
##        No       Yes 
## 0.5789474 0.4210526

V obou souborech dat je počet přeživších stejný, asi 40 procent.

Nainstalujte rpart.plot

rpart.plot není dostupný z knihoven conda. Můžete jej nainstalovat z konzole:

install.packages("rpart.plot")

Krok 4) Sestavte model

Jste připraveni postavit model. Syntaxe funkce rozhodovacího stromu Rpart je:

rpart(formula, data=, method='')
arguments:			
- formula: The function to predict
- data: Specifies the data frame- method: 			
- "class" for a classification tree 			
- "anova" for a regression tree	

Používáte metodu třídy, protože předpovídáte třídu.

library(rpart)
library(rpart.plot)
fit <- rpart(survived~., data = data_train, method = 'class')
rpart.plot(fit, extra = 106

Vysvětlení kódu

  • rpart(): Funkce přizpůsobená modelu. Argumenty jsou:
    • přežil ~.: Vzorec rozhodovacích stromů
    • data = data_train: Dataset
    • method = 'class': Fit binární model
  • rpart.plot(fit, extra= 106): Vykreslete strom. Další funkce jsou nastaveny na 101 pro zobrazení pravděpodobnosti 2. třídy (užitečné pro binární odezvy). Můžete odkazovat na viněta pro více informací o dalších možnostech.

Výstup:

Sestavte model rozhodovacích stromů v R

Začnete u kořenového uzlu (hloubka 0 nad 3, horní část grafu):

  1. Na vrcholu je celková pravděpodobnost přežití. Ukazuje podíl cestujících, kteří nehodu přežili. 41 procent cestujících přežilo.
  2. Tento uzel se ptá, zda je pohlaví cestujícího muž. Pokud ano, přejděte dolů do levého podřízeného uzlu kořene (hloubka 2). 63 procent jsou muži s pravděpodobností přežití 21 procent.
  3. Ve druhém uzlu se zeptáte, zda je cestujícímu muž starší 3.5 roku. Pokud ano, pak je šance na přežití 19 procent.
  4. Pokračujte v tom, abyste pochopili, jaké vlastnosti ovlivňují pravděpodobnost přežití.

Všimněte si, že jednou z mnoha vlastností rozhodovacích stromů je, že vyžadují velmi malou přípravu dat. Zejména nevyžadují změnu měřítka prvků nebo centrování.

Ve výchozím nastavení funkce rpart() používá Gini míra nečistot k rozdělení poznámky. Čím vyšší je Gini koeficient, tím více různých instancí v uzlu.

Krok 5) Proveďte předpověď

Můžete předvídat testovací datovou sadu. Chcete-li provést předpověď, můžete použít funkci forecast(). Základní syntaxe predikce pro rozhodovací strom R je:

predict(fitted_model, df, type = 'class')
arguments:
- fitted_model: This is the object stored after model estimation. 
- df: Data frame used to make the prediction
- type: Type of prediction			
    - 'class': for classification			
    - 'prob': to compute the probability of each class			
    - 'vector': Predict the mean response at the node level	

Chcete předpovědět, kteří cestující s větší pravděpodobností přežijí po srážce z testovací soupravy. To znamená, že mezi těmi 209 cestujícími budete vědět, který z nich přežije nebo ne.

predict_unseen <-predict(fit, data_test, type = 'class')

Vysvětlení kódu

  • predikovat(fit, data_test, typ = 'třída'): Předpovídá třídu (0/1) testovací sady

Testování cestujících, kteří to nestihli, a těch, kteří ano.

table_mat <- table(data_test$survived, predict_unseen)
table_mat

Vysvětlení kódu

  • table(data_test$survived, forecast_unseen): Vytvořte tabulku, která spočítá, kolik cestujících je klasifikováno jako přeživší a zemřelo v porovnání se správnou klasifikací rozhodovacího stromu v R

Výstup:

##      predict_unseen
##        No Yes
##   No  106  15
##   Yes  30  58

Model správně předpověděl 106 mrtvých cestujících, ale 15 přeživších klasifikoval jako mrtvé. Analogicky model nesprávně klasifikoval 30 cestujících jako přeživší, zatímco se ukázalo, že jsou mrtví.

Krok 6) Změřte výkon

Můžete vypočítat míru přesnosti pro klasifikační úlohu pomocí matoucí matice:

Jedno matoucí matice je lepší volbou pro hodnocení výkonu klasifikace. Obecnou myšlenkou je spočítat, kolikrát jsou skutečné instance klasifikovány jako nepravdivé.

Měření výkonu rozhodovacích stromů v R

Každý řádek v matečné matici představuje skutečný cíl, zatímco každý sloupec představuje předpokládaný cíl. První řádek této matice uvažuje mrtvé cestující (třída False): 106 bylo správně klasifikováno jako mrtvé (Pravda negativní), zatímco zbývající byl chybně klasifikován jako přeživší (Falešně pozitivní). Druhá řada bere v úvahu přeživší, pozitivní třída byla 58 (Pravda pozitivní), zatímco Pravda negativní byl 30.

Můžete vypočítat test přesnosti z matoucí matice:

Měření výkonu rozhodovacích stromů v R

Je to podíl skutečných kladných a záporných hodnot na součtu matice. S R můžete kódovat následovně:

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)

Vysvětlení kódu

  • sum(diag(table_mat)): Součet úhlopříčky
  • sum(table_mat): Součet matice.

Přesnost testovací sady si můžete vytisknout:

print(paste('Accuracy for test', accuracy_Test))

Výstup:

## [1] "Accuracy for test 0.784688995215311"

Za testovací sadu máte skóre 78 procent. Stejné cvičení můžete replikovat pomocí trénovací datové sady.

Krok 7) Vylaďte hyperparametry

Rozhodovací strom v R má různé parametry, které řídí aspekty přizpůsobení. V knihovně rozhodovacího stromu rpart můžete ovládat parametry pomocí funkce rpart.control(). V následujícím kódu uvedete parametry, které budete ladit. Můžete odkazovat na viněta pro další parametry.

rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)
Arguments:
-minsplit: Set the minimum number of observations in the node before the algorithm perform a split
-minbucket:  Set the minimum number of observations in the final note i.e. the leaf
-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0

Budeme postupovat následovně:

  • Vytvořte funkci pro návrat přesnosti
  • Nalaďte maximální hloubku
  • Vylaďte minimální počet vzorků, které musí mít uzel, než se může rozdělit
  • Vylaďte minimální počet vzorků, které musí mít listový uzel

Můžete napsat funkci pro zobrazení přesnosti. Jednoduše zabalíte kód, který jste použili dříve:

  1. předpovídat: předpovídat_nezobrazeno <- předvídat(přizpůsobit, test_dat, typ = 'třída')
  2. Vytvořit tabulku: table_mat <- table(data_test$survived, forecast_unseen)
  3. Přesnost výpočtu: přesnost_Test <- součet(diag(podložka_tabulky))/součet(podložka_tabulky)
accuracy_tune <- function(fit) {
    predict_unseen <- predict(fit, data_test, type = 'class')
    table_mat <- table(data_test$survived, predict_unseen)
    accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
    accuracy_Test
}

Můžete zkusit vyladit parametry a zjistit, zda můžete vylepšit model nad výchozí hodnotu. Připomínáme, že musíte získat přesnost vyšší než 0.78

control <- rpart.control(minsplit = 4,
    minbucket = round(5 / 3),
    maxdepth = 3,
    cp = 0)
tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)
accuracy_tune(tune_fit)

Výstup:

## [1] 0.7990431

S následujícím parametrem:

minsplit = 4
minbucket= round(5/3)
maxdepth = 3cp=0

Získáte vyšší výkon než předchozí model. gratuluji!

Shrnutí

Můžeme shrnout funkce pro trénování algoritmu rozhodovacího stromu R

Knihovna Objektivní funkce Třída parametry Detaily
rpart Strom klasifikace vlaků v R rpart() třída vzorec, df, metoda
rpart Vlak regresní strom rpart() anova vzorec, df, metoda
rpart Vykreslete stromy rpart.plot() namontovaný model
základna předpovědět předpovědět() třída osazený model, typ
základna předpovědět předpovědět() prob osazený model, typ
základna předpovědět předpovědět() vektor osazený model, typ
rpart Kontrolní parametry rpart.control() minsplit Než algoritmus provede rozdělení, nastavte minimální počet pozorování v uzlu
minbucket Nastavte minimální počet pozorování v závěrečné notě, tj. listu
maxdepth Nastavte maximální hloubku libovolného uzlu konečného stromu. Kořenový uzel je ošetřen hloubkou 0
rpart Model vlaku s řídicím parametrem rpart() vzorec, df, metoda, kontrola

Poznámka: Trénujte model na trénovacích datech a otestujte výkon na neviditelné datové sadě, tj. testovací sadě.