Arborele de decizie în R: Arborele de clasificare cu exemplu
Ce sunt arborii de decizie?
Arbori de decizie sunt algoritmi de învățare automată versatil care pot efectua atât sarcini de clasificare, cât și de regresie. Sunt algoritmi foarte puternici, capabili să se potrivească cu seturi de date complexe. În plus, arborii de decizie sunt componente fundamentale ale pădurilor aleatorii, care sunt printre cei mai puternici algoritmi de învățare automată disponibili astăzi.
Antrenarea și vizualizarea unui arbore de decizie în R
Pentru a construi primul arbore de decizie în exemplul R, vom proceda după cum urmează în acest tutorial pentru arborele de decizie:
- Pasul 1: importați datele
- Pasul 2: Curățați setul de date
- Pasul 3: Creați setul de tren/test
- Pasul 4: Construiți modelul
- Pasul 5: Faceți predicții
- Pasul 6: Măsurați performanța
- Pasul 7: Reglați hiper-parametrii
Pasul 1) Importați datele
Dacă ești curios despre soarta Titanicului, poți urmări acest videoclip pe YouTube. Scopul acestui set de date este de a prezice ce oameni au mai multe șanse de a supraviețui după ciocnirea cu aisbergul. Setul de date conține 13 variabile și 1309 observații. Setul de date este ordonat după variabila X.
set.seed(678) path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv' titanic <-read.csv(path) head(titanic)
ieșire:
## X pclass survived name sex ## 1 1 1 1 Allen, Miss. Elisabeth Walton female ## 2 2 1 1 Allison, Master. Hudson Trevor male ## 3 3 1 0 Allison, Miss. Helen Loraine female ## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male ## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female ## 6 6 1 1 Anderson, Mr. Harry male ## age sibsp parch ticket fare cabin embarked ## 1 29.0000 0 0 24160 211.3375 B5 S ## 2 0.9167 1 2 113781 151.5500 C22 C26 S ## 3 2.0000 1 2 113781 151.5500 C22 C26 S ## 4 30.0000 1 2 113781 151.5500 C22 C26 S ## 5 25.0000 1 2 113781 151.5500 C22 C26 S ## 6 48.0000 0 0 19952 26.5500 E12 S ## home.dest ## 1 St Louis, MO ## 2 Montreal, PQ / Chesterville, ON ## 3 Montreal, PQ / Chesterville, ON ## 4 Montreal, PQ / Chesterville, ON ## 5 Montreal, PQ / Chesterville, ON ## 6 New York, NY
tail(titanic)
ieșire:
## X pclass survived name sex age sibsp ## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0 ## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1 ## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1 ## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0 ## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0 ## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0 ## parch ticket fare cabin embarked home.dest ## 1304 0 2627 14.4583 C ## 1305 0 2665 14.4542 C ## 1306 0 2665 14.4542 C ## 1307 0 2656 7.2250 C ## 1308 0 2670 7.2250 C ## 1309 0 315082 7.8750 S
De la ieșirea cap și coadă, puteți observa că datele nu sunt amestecate. Aceasta este o mare problemă! Când vă veți împărți datele între un set de tren și un set de testare, veți selecta pasagerul din clasa 1 și 2 (Nici un pasager din clasa 3 nu se află în primele 80 la sută din observații), ceea ce înseamnă că algoritmul nu va vedea niciodată caracteristicile pasagerului din clasa 3. Această greșeală va duce la o predicție slabă.
Pentru a depăși această problemă, puteți utiliza funcția sample().
shuffle_index <- sample(1:nrow(titanic)) head(shuffle_index)
Arborele de decizie cod R Explicație
- sample(1:nrow(titanic)): Generați o listă aleatorie de index de la 1 la 1309 (adică numărul maxim de rânduri).
ieșire:
## [1] 288 874 1078 633 887 992
Veți folosi acest index pentru a amesteca setul de date Titanic.
titanic <- titanic[shuffle_index, ] head(titanic)
ieșire:
## X pclass survived ## 288 288 1 0 ## 874 874 3 0 ## 1078 1078 3 1 ## 633 633 3 0 ## 887 887 3 1 ## 992 992 3 1 ## name sex age ## 288 Sutton, Mr. Frederick male 61 ## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42 ## 1078 O'Driscoll, Miss. Bridget female NA ## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39 ## 887 Jermyn, Miss. Annie female NA ## 992 Mamee, Mr. Hanna male NA ## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ ## 874 0 0 348121 7.6500 F G63 S ## 1078 0 0 14311 7.7500 Q ## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN ## 887 0 0 14313 7.7500 Q ## 992 0 0 2677 7.2292 C
Pasul 2) Curățați setul de date
Structura datelor arată că unele variabile au NA. Curățarea datelor trebuie efectuată după cum urmează
- Eliminați variabilele home.dest,cabin, name, X și bilet
- Creați variabile factor pentru pclass și supraviețuit
- Aruncă NA
library(dplyr) # Drop variables clean_titanic <- titanic % > % select(-c(home.dest, cabin, name, X, ticket)) % > % #Convert to factor level mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')), survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > % na.omit() glimpse(clean_titanic)
Explicarea codului
- select(-c(home.dest, cabin, name, X, ticket)): Eliminați variabilele inutile
- pclass = factor(pclass, levels = c(1,2,3), labels= c('Super', 'Middle', 'Lower')): Adăugați o etichetă la variabila pclass. 1 devine superior, 2 devine mijlociu și 3 devine inferior
- factor(a supraviețuit, niveluri = c(0,1), etichete = c('Nu', 'Da')): Adăugați o etichetă la variabila supraviețuită. 1 devine Nu și 2 devine Da
- na.omit(): Eliminați observațiile NA
ieșire:
## Observations: 1,045 ## Variables: 8 ## $ pclass <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U... ## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y... ## $ sex <fctr> male, male, female, female, male, male, female, male... ## $ age <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ... ## $ sibsp <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,... ## $ parch <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,... ## $ fare <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ... ## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...
Pasul 3) Creați setul de tren/test
Înainte de a vă antrena modelul, trebuie să efectuați doi pași:
- Creați un tren și un set de testare: antrenați modelul pe garnitura de tren și testați predicția pe setul de testare (adică date nevăzute)
- Instalați rpart.plot din consolă
Practica obișnuită este împărțirea datelor 80/20, 80% din date servesc la antrenarea modelului și 20% pentru a face predicții. Trebuie să creați două cadre de date separate. Nu doriți să atingeți setul de testare până nu terminați de construit modelul. Puteți crea un nume de funcție create_train_test() care ia trei argumente.
create_train_test(df, size = 0.8, train = TRUE) arguments: -df: Dataset used to train the model. -size: Size of the split. By default, 0.8. Numerical value -train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample < - 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } }
Explicarea codului
- function(data, size=0.8, train = TRUE): Adăugați argumentele în funcție
- n_row = nrow(date): numărați numărul de rânduri din setul de date
- total_row = size*n_row: Returnează al n-lea rând pentru a construi setul de tren
- train_sample <- 1:total_row: Selectați primul rând până la al n-lea rând
- if (tren ==TRUE){ } else { }: Dacă condiția este setată la adevărat, returnați setul de tren, altfel setul de testare.
Puteți testa funcția și verifica dimensiunea.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE) data_test <- create_train_test(clean_titanic, 0.8, train = FALSE) dim(data_train)
ieșire:
## [1] 836 8
dim(data_test)
ieșire:
## [1] 209 8
Setul de date de tren are 1046 de rânduri, în timp ce setul de date de testare are 262 de rânduri.
Utilizați funcția prop.table() combinată cu table() pentru a verifica dacă procesul de randomizare este corect.
prop.table(table(data_train$survived))
ieșire:
## ## No Yes ## 0.5944976 0.4055024
prop.table(table(data_test$survived))
ieșire:
## ## No Yes ## 0.5789474 0.4210526
În ambele seturi de date, numărul supraviețuitorilor este același, aproximativ 40 la sută.
Instalați rpart.plot
rpart.plot nu este disponibil din bibliotecile conda. Îl poți instala de pe consolă:
install.packages("rpart.plot")
Pasul 4) Construiți modelul
Sunteți gata să construiți modelul. Sintaxa pentru funcția de arbore de decizie Rpart este:
rpart(formula, data=, method='') arguments: - formula: The function to predict - data: Specifies the data frame- method: - "class" for a classification tree - "anova" for a regression tree
Folosești metoda clasei pentru că prezici o clasă.
library(rpart) library(rpart.plot) fit <- rpart(survived~., data = data_train, method = 'class') rpart.plot(fit, extra = 106
Explicarea codului
- rpart(): Funcție pentru a se potrivi modelului. Argumentele sunt:
- supraviețuit ~.: Formula of the Decision Trees
- data = data_train: Set de date
- method = 'class': Potriviți un model binar
- rpart.plot(fit, extra= 106): Trasează arborele. Caracteristicile suplimentare sunt setate la 101 pentru a afișa probabilitatea clasei a 2-a (utilă pentru răspunsuri binare). Vă puteți referi la vinietă pentru mai multe informații despre celelalte opțiuni.
ieșire:
Începeți de la nodul rădăcină (adâncimea 0 peste 3, partea de sus a graficului):
- În partea de sus, este probabilitatea generală de supraviețuire. Arată proporția de pasageri care au supraviețuit accidentului. 41% dintre pasageri au supraviețuit.
- Acest nod întreabă dacă genul pasagerului este bărbat. Dacă da, atunci coborâți la nodul copil stâng al rădăcinii (adâncimea 2). 63 la sută sunt bărbați cu o probabilitate de supraviețuire de 21 la sută.
- În cel de-al doilea nod, întrebați dacă pasagerul de sex masculin are peste 3.5 ani. Dacă da, atunci șansa de supraviețuire este de 19%.
- Continuați să mergeți așa pentru a înțelege ce caracteristici influențează probabilitatea de supraviețuire.
Rețineți că, una dintre numeroasele calități ale arborilor de decizie este că necesită foarte puțină pregătire a datelor. În special, nu necesită scalare sau centrare a caracteristicilor.
În mod implicit, funcția rpart() folosește gini măsură de impuritate pentru a împărți nota. Cu cât coeficientul Gini este mai mare, cu atât mai multe instanțe diferite în nod.
Pasul 5) Faceți o predicție
Puteți prezice setul de date de testare. Pentru a face o predicție, puteți utiliza funcția predict(). Sintaxa de bază a predict pentru arborele de decizie R este:
predict(fitted_model, df, type = 'class') arguments: - fitted_model: This is the object stored after model estimation. - df: Data frame used to make the prediction - type: Type of prediction - 'class': for classification - 'prob': to compute the probability of each class - 'vector': Predict the mean response at the node level
Doriți să preziceți care pasageri au mai multe șanse de a supraviețui după ciocnire din setul de testare. Înseamnă că vei ști printre cei 209 de pasageri care va supraviețui sau nu.
predict_unseen <-predict(fit, data_test, type = 'class')
Explicarea codului
- predict(fit, data_test, type = 'class'): preziceți clasa (0/1) a setului de testare
Testarea pasagerului care nu a reușit și a celor care au făcut-o.
table_mat <- table(data_test$survived, predict_unseen) table_mat
Explicarea codului
- table(data_test$survived, predict_unseen): Creați un tabel pentru a număra câți pasageri sunt clasificați ca supraviețuitori și au murit în comparație cu clasificarea corectă a arborelui de decizie din R
ieșire:
## predict_unseen ## No Yes ## No 106 15 ## Yes 30 58
Modelul a prezis corect 106 pasageri morți, dar a clasificat 15 supraviețuitori ca morți. Prin analogie, modelul a clasificat greșit 30 de pasageri drept supraviețuitori, în timp ce s-au dovedit a fi morți.
Pasul 6) Măsurați performanța
Puteți calcula o măsură de precizie pentru sarcina de clasificare cu matrice de confuzie:
matrice de confuzie este o alegere mai bună pentru a evalua performanța clasificării. Ideea generală este de a număra de câte ori instanțele adevărate sunt clasificate ca fiind false.
Fiecare rând dintr-o matrice de confuzie reprezintă o țintă reală, în timp ce fiecare coloană reprezintă o țintă estimată. Primul rând al acestei matrice ia în considerare pasagerii morți (clasa False): 106 au fost clasificați corect ca morți (Adevărat negativ), în timp ce cel rămas a fost clasificat în mod greșit drept supraviețuitor (Fals pozitiv). Al doilea rând ia în considerare supraviețuitorii, clasa pozitivă a fost 58 (Adevărat pozitiv), in timp ce Adevărat negativ a fost 30.
Puteți calcula test de precizie din matricea de confuzie:
Este proporția dintre adevăratele pozitive și adevăratele negative față de suma matricei. Cu R, puteți codifica după cum urmează:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Explicarea codului
- sum(diag(table_mat)): Suma diagonalei
- sum(table_mat): Suma matricei.
Puteți imprima acuratețea setului de testare:
print(paste('Accuracy for test', accuracy_Test))
ieșire:
## [1] "Accuracy for test 0.784688995215311"
Aveți un scor de 78 la sută pentru setul de test. Puteți replica același exercițiu cu setul de date de antrenament.
Pasul 7) Reglați hiper-parametrii
Arborele de decizie din R are diferiți parametri care controlează aspectele potrivirii. În biblioteca arborelui de decizie rpart, puteți controla parametrii folosind funcția rpart.control(). În următorul cod, introduceți parametrii pe care îi veți regla. Vă puteți referi la vinietă pentru alți parametri.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30) Arguments: -minsplit: Set the minimum number of observations in the node before the algorithm perform a split -minbucket: Set the minimum number of observations in the final note i.e. the leaf -maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
Vom proceda astfel:
- Construiți funcția pentru a returna acuratețea
- Reglați adâncimea maximă
- Reglați numărul minim de eșantion pe care trebuie să îl aibă un nod înainte de a se putea diviza
- Reglați numărul minim de eșantion pe care trebuie să îl aibă un nod frunză
Puteți scrie o funcție pentru a afișa acuratețea. Pur și simplu împachetați codul pe care l-ați folosit înainte:
- prezice: predict_unseen <- predict(potrivire, data_test, tip = 'clasa')
- Produceți tabel: table_mat <- table(data_test$survived, predict_unseen)
- Precizie de calcul: accuracy_Test <- sum(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) { predict_unseen <- predict(fit, data_test, type = 'class') table_mat <- table(data_test$survived, predict_unseen) accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test }
Puteți încerca să reglați parametrii și să vedeți dacă puteți îmbunătăți modelul față de valoarea implicită. Ca o reamintire, trebuie să obțineți o precizie mai mare de 0.78
control <- rpart.control(minsplit = 4, minbucket = round(5 / 3), maxdepth = 3, cp = 0) tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control) accuracy_tune(tune_fit)
ieșire:
## [1] 0.7990431
Cu următorul parametru:
minsplit = 4 minbucket= round(5/3) maxdepth = 3cp=0
Obtii o performanta mai mare decat modelul anterior. Felicitari!
Rezumat
Putem rezuma funcțiile în care să antrenăm un algoritm de arbore de decizie R
Bibliotecă | Obiectiv | Funcţie | Clasă | parametrii | Detalii |
---|---|---|---|---|---|
rpart | Arborele de clasificare a trenurilor în R | rpart() | clasă | formula, df, metoda | |
rpart | Arborele de regresie al trenului | rpart() | anova | formula, df, metoda | |
rpart | Puneți copacii | rpart.plot() | model montat | ||
de bază | prezice | prezice() | clasă | model montat, tip | |
de bază | prezice | prezice() | Prob | model montat, tip | |
de bază | prezice | prezice() | vector | model montat, tip | |
rpart | Parametrii de control | rpart.control() | minsplit | Setați numărul minim de observații în nod înainte ca algoritmul să efectueze o împărțire | |
minbucket | Setați numărul minim de observații în nota finală, adică în frunză | ||||
adancime maxima | Setați adâncimea maximă a oricărui nod al arborelui final. Nodul rădăcină este tratat cu o adâncime 0 | ||||
rpart | Model de tren cu parametru de control | rpart() | formulă, df, metodă, control |
Notă: Antrenați modelul pe date de antrenament și testați performanța pe un set de date nevăzut, adică set de testare.