Päätöspuu R:ssä: Luokittelupuu esimerkin kanssa
Mitä päätöspuut ovat?
Päätöspuut ovat monipuolisia koneoppimisalgoritmeja, jotka voivat suorittaa sekä luokitus- että regressiotehtäviä. Ne ovat erittäin tehokkaita algoritmeja, jotka pystyvät sovittamaan monimutkaisia tietojoukkoja. Lisäksi päätöspuut ovat satunnaisten metsien peruskomponentteja, jotka ovat tehokkaimpia saatavilla olevia koneoppimisalgoritmeja.
Päätöspuiden koulutus ja visualisointi R:ssä
Luodaksemme ensimmäisen päätöspuusi R-esimerkissä toimimme seuraavasti tässä päätöspuun opetusohjelmassa:
- Vaihe 1: Tuo tiedot
- Vaihe 2: Puhdista tietojoukko
- Vaihe 3: Luo juna-/testisarja
- Vaihe 4: Rakenna malli
- Vaihe 5: Tee ennuste
- Vaihe 6: Mittaa suorituskykyä
- Vaihe 7: Säädä hyperparametrit
Vaihe 1) Tuo tiedot
Jos olet utelias titanicin kohtalosta, voit katsoa tämän videon Youtube. Tämän tietojoukon tarkoituksena on ennustaa, ketkä ihmiset selviävät todennäköisemmin jäävuoren törmäyksen jälkeen. Aineisto sisältää 13 muuttujaa ja 1309 havaintoa. Tietojoukko on järjestetty muuttujan X mukaan.
set.seed(678) path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv' titanic <-read.csv(path) head(titanic)
lähtö:
## X pclass survived name sex ## 1 1 1 1 Allen, Miss. Elisabeth Walton female ## 2 2 1 1 Allison, Master. Hudson Trevor male ## 3 3 1 0 Allison, Miss. Helen Loraine female ## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male ## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female ## 6 6 1 1 Anderson, Mr. Harry male ## age sibsp parch ticket fare cabin embarked ## 1 29.0000 0 0 24160 211.3375 B5 S ## 2 0.9167 1 2 113781 151.5500 C22 C26 S ## 3 2.0000 1 2 113781 151.5500 C22 C26 S ## 4 30.0000 1 2 113781 151.5500 C22 C26 S ## 5 25.0000 1 2 113781 151.5500 C22 C26 S ## 6 48.0000 0 0 19952 26.5500 E12 S ## home.dest ## 1 St Louis, MO ## 2 Montreal, PQ / Chesterville, ON ## 3 Montreal, PQ / Chesterville, ON ## 4 Montreal, PQ / Chesterville, ON ## 5 Montreal, PQ / Chesterville, ON ## 6 New York, NY
tail(titanic)
lähtö:
## X pclass survived name sex age sibsp ## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0 ## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1 ## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1 ## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0 ## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0 ## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0 ## parch ticket fare cabin embarked home.dest ## 1304 0 2627 14.4583 C ## 1305 0 2665 14.4542 C ## 1306 0 2665 14.4542 C ## 1307 0 2656 7.2250 C ## 1308 0 2670 7.2250 C ## 1309 0 315082 7.8750 S
Pään ja hännän ulostulosta voit huomata, että tietoja ei ole sekoitettu. Tämä on iso ongelma! Kun jaat tietosi junasarjan ja testijoukon kesken, valitset vain luokan 1 ja 2 matkustaja (Yhtään luokan 3 matkustajaa ei ole havaintojen ylimmässä 80 prosentissa), mikä tarkoittaa, että algoritmi ei koskaan näe luokan 3 matkustajan ominaisuuksia. Tämä virhe johtaa huonoon ennusteeseen.
Voit ratkaista tämän ongelman käyttämällä funktiota sample().
shuffle_index <- sample(1:nrow(titanic)) head(shuffle_index)
Päätöspuu R-koodi Selitys
- sample(1:nrow(titanic)): Luo satunnainen indeksiluettelo väliltä 1–1309 (eli rivien enimmäismäärä).
lähtö:
## [1] 288 874 1078 633 887 992
Käytät tätä hakemistoa titanic-tietojoukon sekoittamiseen.
titanic <- titanic[shuffle_index, ] head(titanic)
lähtö:
## X pclass survived ## 288 288 1 0 ## 874 874 3 0 ## 1078 1078 3 1 ## 633 633 3 0 ## 887 887 3 1 ## 992 992 3 1 ## name sex age ## 288 Sutton, Mr. Frederick male 61 ## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42 ## 1078 O'Driscoll, Miss. Bridget female NA ## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39 ## 887 Jermyn, Miss. Annie female NA ## 992 Mamee, Mr. Hanna male NA ## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ ## 874 0 0 348121 7.6500 F G63 S ## 1078 0 0 14311 7.7500 Q ## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN ## 887 0 0 14313 7.7500 Q ## 992 0 0 2677 7.2292 C
Vaihe 2) Puhdista tietojoukko
Tietojen rakenne osoittaa, että joillakin muuttujilla on NA:t. Tietojen puhdistus tehdään seuraavasti
- Pudota muuttujat home.dest,hytti, nimi, X ja lippu
- Luo tekijämuuttujat pclassille ja selviytyville
- Pudota NA
library(dplyr) # Drop variables clean_titanic <- titanic % > % select(-c(home.dest, cabin, name, X, ticket)) % > % #Convert to factor level mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')), survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > % na.omit() glimpse(clean_titanic)
Koodin selitys
- select(-c(koti.kohde, hytti, nimi, X, lippu)): Pudota tarpeettomat muuttujat
- pclass = factor(pclass, level = c(1,2,3), labels= c('Ylä', 'Keski, 'Alempi')): Lisää nimiö muuttujaan pclass. 1 muuttuu ylemmäksi, 2 muuttuu keskimmäiseksi ja 3 alemmaksi
- factor(survived, level = c(0,1), labels = c('Ei', 'Kyllä')): Lisää nimike muuttujaan elossa. 1 muuttuu ei ja 2 tulee kyllä
- na.omit(): Poista NA-havainnot
lähtö:
## Observations: 1,045 ## Variables: 8 ## $ pclass <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U... ## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y... ## $ sex <fctr> male, male, female, female, male, male, female, male... ## $ age <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ... ## $ sibsp <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,... ## $ parch <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,... ## $ fare <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ... ## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...
Vaihe 3) Luo juna-/testisarja
Ennen kuin koulutat malliasi, sinun on suoritettava kaksi vaihetta:
- Luo juna- ja testijoukko: Harjoittelet mallia junasarjassa ja testaat ennustetta testijoukolla (eli näkymätön data)
- Asenna rpart.plot konsolista
Yleinen käytäntö on jakaa data 80/20:een, 80 prosenttia tiedoista on mallin harjoittelua ja 20 prosenttia ennusteita. Sinun on luotava kaksi erillistä tietokehystä. Et halua koskea testisarjaan ennen kuin olet valmis rakentamaan mallisi. Voit luoda funktion nimen create_train_test(), joka sisältää kolme argumenttia.
create_train_test(df, size = 0.8, train = TRUE) arguments: -df: Dataset used to train the model. -size: Size of the split. By default, 0.8. Numerical value -train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample < - 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } }
Koodin selitys
- function(data, koko=0.8, juna = TOSI): Lisää funktion argumentit
- n_row = nrow(data): Laske tietojoukon rivien määrä
- rivi_koko = koko*n_rivi: Palauta n:s rivi junajoukon muodostamiseksi
- train_sample <- 1:total_row: Valitse ensimmäinen rivi n:nnelle riville
- if (train ==TRUE){ } else { }: Jos ehto on tosi, palauta junajoukko, muuten testijoukko.
Voit testata toimintaasi ja tarkistaa mitat.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE) data_test <- create_train_test(clean_titanic, 0.8, train = FALSE) dim(data_train)
lähtö:
## [1] 836 8
dim(data_test)
lähtö:
## [1] 209 8
Junatietojoukossa on 1046 riviä, kun taas testitietojoukossa on 262 riviä.
Käytät funktiota prop.table() yhdessä table()-funktion kanssa varmistaaksesi, onko satunnaistusprosessi oikea.
prop.table(table(data_train$survived))
lähtö:
## ## No Yes ## 0.5944976 0.4055024
prop.table(table(data_test$survived))
lähtö:
## ## No Yes ## 0.5789474 0.4210526
Kummassakin aineistossa eloonjääneiden määrä on sama, noin 40 prosenttia.
Asenna rpart.plot
rpart.plot ei ole saatavilla conda-kirjastoista. Voit asentaa sen konsolista:
install.packages("rpart.plot")
Vaihe 4) Rakenna malli
Olet valmis rakentamaan mallin. Rpart-päätöspuufunktion syntaksi on:
rpart(formula, data=, method='') arguments: - formula: The function to predict - data: Specifies the data frame- method: - "class" for a classification tree - "anova" for a regression tree
Käytät luokkamenetelmää, koska ennustat luokan.
library(rpart) library(rpart.plot) fit <- rpart(survived~., data = data_train, method = 'class') rpart.plot(fit, extra = 106
Koodin selitys
- rpart(): Malliin sopiva toiminto. Argumentit ovat:
- selvisi ~.: Päätöspuiden kaava
- data = data_train: Tietojoukko
- method = 'luokka': Sovita binäärimalli
- rpart.plot(fit, extra= 106): Piirrä puu. Lisäominaisuudet on asetettu arvoon 101 näyttämään 2. luokan todennäköisyys (hyödyllinen binäärivasteille). Voit viitata vinjetti saadaksesi lisätietoja muista vaihtoehdoista.
lähtö:
Aloitat juurisolmusta (syvyys 0 yli 3, kaavion yläosa):
- Huipulla se on kokonaistodennäköisyys selviytyä. Se näyttää onnettomuudesta selvinneiden matkustajien osuuden. 41 prosenttia matkustajista selvisi.
- Tämä solmu kysyy, onko matkustajan sukupuoli mies. Jos kyllä, siirryt alas juuren vasempaan lapsisolmuun (syvyys 2). 63 prosenttia on miehiä, joiden eloonjäämistodennäköisyys on 21 prosenttia.
- Toisessa solmussa kysyt, onko miesmatkustaja yli 3.5 vuotta vanha. Jos kyllä, niin selviytymismahdollisuus on 19 prosenttia.
- Jatkat samalla tavalla ymmärtääksesi, mitkä ominaisuudet vaikuttavat selviytymisen todennäköisyyteen.
Huomaa, että yksi päätöspuiden monista ominaisuuksista on, että ne vaativat hyvin vähän tietojen valmistelua. Erityisesti ne eivät vaadi ominaisuuksien skaalausta tai keskitystä.
Oletusarvoisesti rpart()-funktio käyttää Gini epäpuhtausmitta setelin jakamiseksi. Mitä korkeampi Gini-kerroin, sitä enemmän eri esiintymiä solmussa.
Vaihe 5) Tee ennuste
Voit ennustaa testitietojoukon. Voit tehdä ennusteen ennustaa()-funktiolla. R-päätöspuun ennustamisen perussyntaksi on:
predict(fitted_model, df, type = 'class') arguments: - fitted_model: This is the object stored after model estimation. - df: Data frame used to make the prediction - type: Type of prediction - 'class': for classification - 'prob': to compute the probability of each class - 'vector': Predict the mean response at the node level
Haluat ennustaa testisarjasta, ketkä matkustajat selviävät todennäköisemmin hengissä törmäyksen jälkeen. Se tarkoittaa, että tiedät näiden 209 matkustajan joukosta, kumpi selviää vai ei.
predict_unseen <-predict(fit, data_test, type = 'class')
Koodin selitys
- ennusta(fit, data_test, type = 'luokka'): Ennusta testijoukon luokka (0/1)
Testaa matkustajaa, joka ei päässyt perille, ja niitä, jotka pääsivät.
table_mat <- table(data_test$survived, predict_unseen) table_mat
Koodin selitys
- table(data_test$survived, ennustaa_unseen): Luo taulukko laskeaksesi kuinka monta matkustajaa on luokiteltu eloonjääneiksi ja kuolleita verrattuna oikeaan R:n päätöspuuluokitukseen
lähtö:
## predict_unseen ## No Yes ## No 106 15 ## Yes 30 58
Malli ennusti oikein 106 kuollutta matkustajaa, mutta 15 eloonjäänyttä luokiteltiin kuolleiksi. Analogisesti malli luokitteli väärin 30 matkustajaa eloonjääneiksi, vaikka he osoittautuivat kuolleiksi.
Vaihe 6) Mittaa suorituskykyä
Voit laskea tarkkuusmitan luokittelutehtävälle -sovelluksella sekaannusmatriisi:
- sekaannusmatriisi on parempi valinta arvioida luokituksen suorituskykyä. Yleisenä ajatuksena on laskea, kuinka monta kertaa todelliset esiintymät luokitellaan vääriksi.
Jokainen sekoitusmatriisin rivi edustaa todellista kohdetta, kun taas jokainen sarake edustaa ennustettua kohdetta. Tämän matriisin ensimmäisellä rivillä tarkastellaan kuolleita matkustajia (false-luokka): 106 luokiteltiin oikein kuolleiksi (Tosi negatiivinen), kun taas loput luokiteltiin virheellisesti eloonjääneeksi (Väärä positiivinen). Toisella rivillä on selviytyjät, positiivinen luokka oli 58 (Tosi positiivista), samalla kun Tosi negatiivinen oli 30.
Voit laskea tarkkuustesti hämmennysmatriisista:
Se on todellisen positiivisen ja tosi negatiivisen osuus matriisin summasta. R:llä voit koodata seuraavasti:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Koodin selitys
- summa(diag(taulukko_matto)): Diagonaalin summa
- sum(table_mat): Matriisin summa.
Voit tulostaa testisarjan tarkkuuden:
print(paste('Accuracy for test', accuracy_Test))
lähtö:
## [1] "Accuracy for test 0.784688995215311"
Testisarjan pistemäärä on 78 prosenttia. Voit toistaa saman harjoituksen harjoitustietojoukolla.
Vaihe 7) Viritä hyperparametrit
R:n päätöspuulla on useita parametreja, jotka ohjaavat sovituksen näkökohtia. rpart-päätöspuukirjastossa voit hallita parametreja rpart.control()-funktiolla. Seuraavassa koodissa esittelet viritettävät parametrit. Voit viitata vinjetti muille parametreille.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30) Arguments: -minsplit: Set the minimum number of observations in the node before the algorithm perform a split -minbucket: Set the minimum number of observations in the final note i.e. the leaf -maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
Jatketaan seuraavasti:
- Rakenna funktio palauttaaksesi tarkkuuden
- Säädä suurin syvyys
- Säädä näytteen vähimmäismäärä, joka solmulla on oltava, ennen kuin se voi jakaa
- Säädä näytteiden vähimmäismäärä, joka lehtisolmussa on oltava
Voit kirjoittaa tarkkuuden näyttävän funktion. Käärit vain aiemmin käyttämäsi koodin:
- ennustaa: ennustaa_näkymätön <- ennusta(fit, data_test, type = 'class')
- Tuota taulukko: table_mat <- table(data_test$survived, ennustaa_näkemätön)
- Laskennan tarkkuus: tarkkuus_testi <- summa(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) { predict_unseen <- predict(fit, data_test, type = 'class') table_mat <- table(data_test$survived, predict_unseen) accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test }
Voit yrittää virittää parametreja ja katsoa, voitko parantaa mallia oletusarvoon verrattuna. Muistutuksena, sinun on saatava suurempi tarkkuus kuin 0.78
control <- rpart.control(minsplit = 4, minbucket = round(5 / 3), maxdepth = 3, cp = 0) tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control) accuracy_tune(tune_fit)
lähtö:
## [1] 0.7990431
Seuraavalla parametrilla:
minsplit = 4 minbucket= round(5/3) maxdepth = 3cp=0
Saat paremman suorituskyvyn kuin edellinen malli. Onnittelut!
Yhteenveto
Voimme tehdä yhteenvedon funktioista, joilla päätöspuualgoritmi koulutetaan R
Kirjasto | Tavoite | Toiminto | luokka | parametrit | Lisätiedot |
---|---|---|---|---|---|
rpart | Junien luokituspuu kirjassa R | rpart() | luokka | kaava, df, menetelmä | |
rpart | Harjoittele regressiopuuta | rpart() | anova | kaava, df, menetelmä | |
rpart | Piirrä puut | rpart.plot() | asennettu malli | ||
pohja | ennustaa | ennustaa() | luokka | asennettu malli, tyyppi | |
pohja | ennustaa | ennustaa() | ongelma | asennettu malli, tyyppi | |
pohja | ennustaa | ennustaa() | vektori | asennettu malli, tyyppi | |
rpart | Ohjausparametrit | rpart.control() | minsplit | Aseta solmun havaintojen vähimmäismäärä ennen kuin algoritmi suorittaa jaon | |
minbucket | Aseta havaintojen vähimmäismäärä viimeiseen nuottiin eli lehtiin | ||||
suurin syvyys | Aseta lopullisen puun minkä tahansa solmun suurin syvyys. Juurisolmua käsitellään syvyydessä 0 | ||||
rpart | Junamalli ohjausparametrilla | rpart() | kaava, df, menetelmä, ohjaus |
Huomautus: Harjoittele malli harjoitusdatalla ja testaa suorituskykyä näkymättömällä datajoukolla, eli testijoukolla.