Döntési fa az R-ben: Osztályozási fa példával
Mik azok a döntési fák?
Döntési fák sokoldalú gépi tanulási algoritmus, amely osztályozási és regressziós feladatokat is végrehajthat. Nagyon hatékony algoritmusok, amelyek képesek összetett adatkészletek illesztésére. Emellett a döntési fák a véletlenszerű erdők alapvető összetevői, amelyek a ma elérhető leghatékonyabb gépi tanulási algoritmusok közé tartoznak.
Döntési fák képzése és megjelenítése R-ben
Az első döntési fa felépítéséhez az R példában a következőképpen járunk el ebben a döntési fa oktatóanyagában:
- 1. lépés: Importálja az adatokat
- 2. lépés: Tisztítsa meg az adatkészletet
- 3. lépés: Hozzon létre vonat/tesztkészletet
- 4. lépés: Készítse el a modellt
- 5. lépés: Készítsen előrejelzést
- 6. lépés: Mérje meg a teljesítményt
- 7. lépés: Hangolja be a hiperparamétereket
1. lépés) Importálja az adatokat
Ha kíváncsi vagy a Titanic sorsára, megtekintheti ezt a videót itt Youtube. Ennek az adathalmaznak az a célja, hogy megjósolja, mely emberek maradnak nagyobb valószínűséggel a jéghegygel való ütközés után. Az adatkészlet 13 változót és 1309 megfigyelést tartalmaz. Az adatkészletet az X változó rendezi.
set.seed(678) path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv' titanic <-read.csv(path) head(titanic)
output:
## X pclass survived name sex ## 1 1 1 1 Allen, Miss. Elisabeth Walton female ## 2 2 1 1 Allison, Master. Hudson Trevor male ## 3 3 1 0 Allison, Miss. Helen Loraine female ## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male ## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female ## 6 6 1 1 Anderson, Mr. Harry male ## age sibsp parch ticket fare cabin embarked ## 1 29.0000 0 0 24160 211.3375 B5 S ## 2 0.9167 1 2 113781 151.5500 C22 C26 S ## 3 2.0000 1 2 113781 151.5500 C22 C26 S ## 4 30.0000 1 2 113781 151.5500 C22 C26 S ## 5 25.0000 1 2 113781 151.5500 C22 C26 S ## 6 48.0000 0 0 19952 26.5500 E12 S ## home.dest ## 1 St Louis, MO ## 2 Montreal, PQ / Chesterville, ON ## 3 Montreal, PQ / Chesterville, ON ## 4 Montreal, PQ / Chesterville, ON ## 5 Montreal, PQ / Chesterville, ON ## 6 New York, NY
tail(titanic)
output:
## X pclass survived name sex age sibsp ## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0 ## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1 ## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1 ## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0 ## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0 ## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0 ## parch ticket fare cabin embarked home.dest ## 1304 0 2627 14.4583 C ## 1305 0 2665 14.4542 C ## 1306 0 2665 14.4542 C ## 1307 0 2656 7.2250 C ## 1308 0 2670 7.2250 C ## 1309 0 315082 7.8750 S
A fej és a farok kimenetén láthatja, hogy az adatok nincsenek megkeverve. Ez nagy kérdés! Amikor felosztja adatait egy vonatszerelvény és egy tesztkészlet között, akkor kiválasztja csak az 1. és 2. osztály utasa (a megfigyelések felső 3 százalékában egy 80. osztályú utas sem szerepel), ami azt jelenti, hogy az algoritmus soha nem fogja látni a 3. osztályú utas jellemzőit. Ez a hiba rossz előrejelzéshez vezet.
A probléma megoldásához használhatja a sample() függvényt.
shuffle_index <- sample(1:nrow(titanic)) head(shuffle_index)
Döntési fa R kód Magyarázat
- sample(1:nrow(titanic)): Véletlenszerű indexlistát generál 1-től 1309-ig (vagyis a sorok maximális számát).
output:
## [1] 288 874 1078 633 887 992
Ezt az indexet fogja használni a Titanic adatkészlet megkeverésére.
titanic <- titanic[shuffle_index, ] head(titanic)
output:
## X pclass survived ## 288 288 1 0 ## 874 874 3 0 ## 1078 1078 3 1 ## 633 633 3 0 ## 887 887 3 1 ## 992 992 3 1 ## name sex age ## 288 Sutton, Mr. Frederick male 61 ## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42 ## 1078 O'Driscoll, Miss. Bridget female NA ## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39 ## 887 Jermyn, Miss. Annie female NA ## 992 Mamee, Mr. Hanna male NA ## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ ## 874 0 0 348121 7.6500 F G63 S ## 1078 0 0 14311 7.7500 Q ## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN ## 887 0 0 14313 7.7500 Q ## 992 0 0 2677 7.2292 C
2. lépés) Tisztítsa meg az adatkészletet
Az adatok szerkezete azt mutatja, hogy néhány változónak NA-ja van. Az adatok tisztítását az alábbiak szerint kell elvégezni
- Dobd el a home.dest,cabin, name, X and ticket változókat
- Hozzon létre faktorváltozókat a pclass és a túlélők számára
- Dobd el az NA-t
library(dplyr) # Drop variables clean_titanic <- titanic % > % select(-c(home.dest, cabin, name, X, ticket)) % > % #Convert to factor level mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')), survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > % na.omit() glimpse(clean_titanic)
Kód Magyarázat
- select(-c(home.dest, cabin, name, X, ticket)): A szükségtelen változók eldobása
- pclass = factor(pclass, levelek = c(1,2,3), labels= c('Felső', 'Közép', 'Alsó')): Címke hozzáadása a pclass változóhoz. Az 1-ből Felső, a 2-ből a középső, a 3-ból pedig alacsonyabb lesz
- factor(survived, level = c(0,1), labels = c('Nem', 'Igen')): Címke hozzáadása a fennmaradt változóhoz. 1 nem lesz, a 2 pedig igen
- na.omit(): Az NA megfigyelések eltávolítása
output:
## Observations: 1,045 ## Variables: 8 ## $ pclass <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U... ## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y... ## $ sex <fctr> male, male, female, female, male, male, female, male... ## $ age <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ... ## $ sibsp <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,... ## $ parch <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,... ## $ fare <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ... ## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...
3. lépés) Hozzon létre vonat/tesztkészletet
Mielőtt betanítaná modelljét, két lépést kell végrehajtania:
- Hozzon létre egy vonat- és tesztkészletet: Betanítja a modellt a vonatkészleten, és teszteli az előrejelzést a tesztkészleten (azaz nem látott adatok)
- Telepítse az rpart.plot fájlt a konzolról
Az általános gyakorlat szerint az adatokat 80/20 arányban osztják fel, az adatok 80 százaléka a modell betanítását, 20 százaléka pedig előrejelzéseket szolgál. Két külön adatkeretet kell létrehoznia. Addig ne nyúljon a tesztkészlethez, amíg be nem fejezi a modell elkészítését. Létrehozhat egy create_train_test() függvénynevet, amely három argumentumot tartalmaz.
create_train_test(df, size = 0.8, train = TRUE) arguments: -df: Dataset used to train the model. -size: Size of the split. By default, 0.8. Numerical value -train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample < - 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } }
Kód Magyarázat
- function(data, size=0.8, train = TRUE): Adja hozzá az argumentumokat a függvényhez
- n_row = nrow(data): Az adatkészletben lévő sorok számlálása
- összesen_sor = méret*n_sor: Visszaadja az n-edik sort a vonatkészlet összeállításához
- train_sample <- 1:total_row: Válassza ki az első sort az n-edik sorig
- if (train ==TRUE){ } else { }: Ha a feltétel igazra van állítva, akkor a vonatkészletet adja vissza, különben a teszthalmazt.
Tesztelheti működését és ellenőrizheti a méretet.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE) data_test <- create_train_test(clean_titanic, 0.8, train = FALSE) dim(data_train)
output:
## [1] 836 8
dim(data_test)
output:
## [1] 209 8
A vonatadatkészlet 1046, míg a tesztadatkészlet 262 sorból áll.
A prop.table() függvényt a table()-vel kombinálva ellenőrizheti, hogy a véletlenszerűsítési folyamat helyes-e.
prop.table(table(data_train$survived))
output:
## ## No Yes ## 0.5944976 0.4055024
prop.table(table(data_test$survived))
output:
## ## No Yes ## 0.5789474 0.4210526
Mindkét adatkészletben a túlélők száma azonos, körülbelül 40 százalék.
Telepítse az rpart.plot
Az rpart.plot nem érhető el a conda könyvtárakból. Telepítheti a konzolról:
install.packages("rpart.plot")
4. lépés) Építse meg a modellt
Készen áll a modell megépítésére. Az Rpart döntési fa függvény szintaxisa a következő:
rpart(formula, data=, method='') arguments: - formula: The function to predict - data: Specifies the data frame- method: - "class" for a classification tree - "anova" for a regression tree
Azért használja az osztály metódust, mert megjósol egy osztályt.
library(rpart) library(rpart.plot) fit <- rpart(survived~., data = data_train, method = 'class') rpart.plot(fit, extra = 106
Kód Magyarázat
- rpart(): A modellhez illeszkedő függvény. Az érvek a következők:
- fennmaradt ~.: A határozati fák képlete
- data = data_train: Adatkészlet
- method = 'osztály': Bináris modell illesztése
- rpart.plot(fit, extra= 106): Ábrázolja a fát. Az extra szolgáltatások 101-re vannak állítva a 2. osztály valószínűségének megjelenítéséhez (bináris válaszoknál hasznos). Hivatkozhat a címke további információkért a többi választási lehetőségről.
output:
A gyökércsomópontnál kezdi (mélység 0 a 3 felett, a grafikon teteje):
- A tetején a túlélés általános valószínűsége. Megmutatja a balesetet túlélő utasok arányát. Az utasok 41 százaléka életben maradt.
- Ez a csomópont megkérdezi, hogy az utas neme férfi-e. Ha igen, akkor menj le a gyökér bal gyermekcsomópontjához (2. mélység). 63 százalékuk férfi, akiknek a túlélési valószínűsége 21 százalék.
- A második csomópontban azt kérdezi, hogy a férfi utas 3.5 évesnél idősebb-e. Ha igen, akkor a túlélés esélye 19 százalék.
- Így folytatod, hogy megértsd, milyen jellemzők befolyásolják a túlélés valószínűségét.
Vegye figyelembe, hogy a döntési fák számos tulajdonsága közül az egyik az, hogy nagyon kevés adat-előkészítést igényelnek. Különösen nem igényelnek funkcióméretezést vagy központosítást.
Alapértelmezés szerint az rpart() függvény a Gini szennyeződés mértéke a bankjegy felosztásához. Minél magasabb a Gini-együttható, annál több a különböző példány a csomóponton belül.
5. lépés) Készítsen előrejelzést
Megjósolhatja a tesztadatkészletet. Előrejelzés készítéséhez használhatja a predikció() függvényt. Az R döntési fa előrejelzésének alapvető szintaxisa a következő:
predict(fitted_model, df, type = 'class') arguments: - fitted_model: This is the object stored after model estimation. - df: Data frame used to make the prediction - type: Type of prediction - 'class': for classification - 'prob': to compute the probability of each class - 'vector': Predict the mean response at the node level
A tesztkészletből szeretné megjósolni, hogy mely utasok élik túl nagyobb valószínűséggel az ütközést. Ez azt jelenti, hogy a 209 utas közül tudni fogja, melyik éli túl vagy sem.
predict_unseen <-predict(fit, data_test, type = 'class')
Kód Magyarázat
- ennusta(fit, data_test, type = 'osztály'): A tesztkészlet osztályának (0/1) előrejelzése
Az utasok tesztelése, akiknek nem sikerült, és akiknek sikerült.
table_mat <- table(data_test$survived, predict_unseen) table_mat
Kód Magyarázat
- table(data_test$survived,prediction_unseen): Hozzon létre egy táblázatot, amely megszámolja, hány utas minősül túlélőnek és halt el, összehasonlítva a helyes döntési fa besorolásával az R-ben
output:
## predict_unseen ## No Yes ## No 106 15 ## Yes 30 58
A modell helyesen 106 halott utast jósolt meg, de 15 túlélőt halottnak minősített. Hasonlatosan, a modell 30 utast rosszul minősített túlélőnek, miközben kiderült, hogy meghaltak.
6. lépés) Mérje meg a teljesítményt
Kiszámíthatja az osztályozási feladat pontossági mértékét a zavart mátrix:
A zavart mátrix jobb választás az osztályozási teljesítmény értékelésére. Az általános ötlet az, hogy megszámoljuk, hogy az Igaz példányok hányszor hamisnak minősülnek.
A zavaros mátrix minden sora egy tényleges célt, míg minden oszlop egy előre jelzett célt jelent. Ennek a mátrixnak az első sora a halott utasokat veszi figyelembe (a False osztály): 106-ot helyesen halottaknak minősítettek (Igaz negatív), míg a maradékot tévesen túlélőnek minősítették (Álpozitív). A második sor a túlélőket tartalmazza, a pozitív osztály 58 (Igaz pozitív), amíg a Igaz negatív 30 volt.
Ki tudja számolni a pontossági teszt a zavaros mátrixból:
Ez az igazi pozitív és a valódi negatív aránya a mátrix összegéhez képest. Az R-vel a következőképpen kódolhat:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Kód Magyarázat
- sum(diag(table_mat)): Az átló összege
- sum(table_mat): A mátrix összege.
A tesztkészlet pontosságát kinyomtathatja:
print(paste('Accuracy for test', accuracy_Test))
output:
## [1] "Accuracy for test 0.784688995215311"
Ön 78 százalékos pontszámot kapott a tesztkészletre. Ugyanazt a gyakorlatot megismételheti az edzési adatkészlettel.
7. lépés) Hangolja be a hiperparamétereket
Az R döntési fájának különféle paraméterei vannak, amelyek szabályozzák az illeszkedést. Az rpart döntési fa könyvtárban a paramétereket az rpart.control() függvénnyel vezérelheti. A következő kódban bemutatja a beállítani kívánt paramétereket. Hivatkozhat a címke egyéb paraméterekhez.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30) Arguments: -minsplit: Set the minimum number of observations in the node before the algorithm perform a split -minbucket: Set the minimum number of observations in the final note i.e. the leaf -maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
A következőképpen járunk el:
- Konstruálja a függvényt a pontosság visszaadásához
- Hangolja be a maximális mélységet
- Hangolja be azt a minimális számú mintát, amelyre egy csomópontnak rendelkeznie kell, mielőtt feloszthatja
- Hangolja be a levélcsomóponthoz szükséges minimális számú mintát
Írhat függvényt a pontosság megjelenítéséhez. Egyszerűen becsomagolja a korábban használt kódot:
- megjósolni: predikció_unseen <- ennusta(fit, data_test, type = 'osztály')
- Előállítási tábla: table_mat <- table(data_test$survived, ennusta_láthatatlan)
- Számítási pontosság: accuracy_Test <- sum(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) { predict_unseen <- predict(fit, data_test, type = 'class') table_mat <- table(data_test$survived, predict_unseen) accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test }
Megpróbálhatja hangolni a paramétereket, és megnézheti, hogy javíthatja-e a modellt az alapértelmezett értékhez képest. Emlékeztetőül: 0.78-nál nagyobb pontosságot kell elérnie
control <- rpart.control(minsplit = 4, minbucket = round(5 / 3), maxdepth = 3, cp = 0) tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control) accuracy_tune(tune_fit)
output:
## [1] 0.7990431
A következő paraméterrel:
minsplit = 4 minbucket= round(5/3) maxdepth = 3cp=0
Az előző modellnél nagyobb teljesítményt kap. Gratulálunk!
Összegzésként
Összefoglalhatjuk azokat a függvényeket, amelyekben egy döntési fa algoritmust betaníthatunk R
könyvtár | Objektív | Funkció | Osztály | paraméterek | Részletek |
---|---|---|---|---|---|
rpart | Vonatosztályozási fa R-ben | rpart() | osztály | képlet, df, módszer | |
rpart | Vonat regressziós fa | rpart() | anova | képlet, df, módszer | |
rpart | Telepítse a fákat | rpart.plot() | felszerelt modell | ||
bázis | előre | megjósolni () | osztály | felszerelt modell, típus | |
bázis | előre | megjósolni () | prob | felszerelt modell, típus | |
bázis | előre | megjósolni () | vektor | felszerelt modell, típus | |
rpart | Ellenőrzési paraméterek | rpart.control() | minsplit | Állítsa be a megfigyelések minimális számát a csomópontban, mielőtt az algoritmus felosztást hajtana végre | |
minbucket | Állítsa be a megfigyelések minimális számát az utolsó megjegyzésben, azaz a levélben | ||||
maximális mélység | Állítsa be a végső fa bármely csomópontjának maximális mélységét. A gyökércsomópontot 0 mélységgel kezeljük | ||||
rpart | Vonatmodell vezérlőparaméterrel | rpart() | képlet, df, módszer, vezérlés |
Megjegyzés: Tanítsa meg a modellt egy betanítási adatokon, és tesztelje a teljesítményt egy nem látott adatkészleten, azaz tesztkészleten.