Drzewo decyzyjne w R: Drzewo klasyfikacyjne z przykładem

Czym są drzewa decyzyjne?

Drzewa decyzyjne są wszechstronnymi algorytmami uczenia maszynowego, które mogą wykonywać zarówno zadania klasyfikacyjne, jak i regresyjne. Są to bardzo potężne algorytmy, zdolne do dopasowania złożonych zestawów danych. Ponadto drzewa decyzyjne są podstawowymi składnikami lasów losowych, które są jednymi z najpotężniejszych algorytmów uczenia maszynowego dostępnych obecnie.

Uczenie i wizualizacja drzew decyzyjnych w R

Aby zbudować pierwsze drzewo decyzyjne w przykładzie R, postępujemy w następujący sposób w tym samouczku dotyczącym drzewa decyzyjnego:

  • Krok 1: Zaimportuj dane
  • Krok 2: Wyczyść zbiór danych
  • Krok 3: Utwórz zestaw pociągowy/testowy
  • Krok 4: Zbuduj model
  • Krok 5: Przewiduj
  • Krok 6: Zmierz wydajność
  • Krok 7: Dostosuj hiperparametry

Krok 1) Zaimportuj dane

Jeśli ciekawi Cię los Titanica, możesz obejrzeć ten film na YouTube. Celem tego zbioru danych jest przewidzenie, którzy ludzie mają większe szanse na przeżycie po zderzeniu z górą lodową. Zbiór danych zawiera 13 zmiennych i 1309 obserwacji. Zbiór danych jest uporządkowany według zmiennej X.

set.seed(678)
path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'
titanic <-read.csv(path)
head(titanic)

Wyjście:

##   X pclass survived                                            name    sex
## 1 1      1        1                   Allen, Miss. Elisabeth Walton female
## 2 2      1        1                  Allison, Master. Hudson Trevor   male
## 3 3      1        0                    Allison, Miss. Helen Loraine female
## 4 4      1        0            Allison, Mr. Hudson Joshua Creighton   male
## 5 5      1        0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female
## 6 6      1        1                             Anderson, Mr. Harry   male
##       age sibsp parch ticket     fare   cabin embarked
## 1 29.0000     0     0  24160 211.3375      B5        S
## 2  0.9167     1     2 113781 151.5500 C22 C26        S
## 3  2.0000     1     2 113781 151.5500 C22 C26        S
## 4 30.0000     1     2 113781 151.5500 C22 C26        S
## 5 25.0000     1     2 113781 151.5500 C22 C26        S
## 6 48.0000     0     0  19952  26.5500     E12        S
##                         home.dest
## 1                    St Louis, MO
## 2 Montreal, PQ / Chesterville, ON
## 3 Montreal, PQ / Chesterville, ON
## 4 Montreal, PQ / Chesterville, ON
## 5 Montreal, PQ / Chesterville, ON
## 6                    New York, NY
tail(titanic)

Wyjście:

##         X pclass survived                      name    sex  age sibsp
## 1304 1304      3        0     Yousseff, Mr. Gerious   male   NA     0
## 1305 1305      3        0      Zabour, Miss. Hileni female 14.5     1
## 1306 1306      3        0     Zabour, Miss. Thamine female   NA     1
## 1307 1307      3        0 Zakarian, Mr. Mapriededer   male 26.5     0
## 1308 1308      3        0       Zakarian, Mr. Ortin   male 27.0     0
## 1309 1309      3        0        Zimmerman, Mr. Leo   male 29.0     0
##      parch ticket    fare cabin embarked home.dest
## 1304     0   2627 14.4583              C          
## 1305     0   2665 14.4542              C          
## 1306     0   2665 14.4542              C          
## 1307     0   2656  7.2250              C          
## 1308     0   2670  7.2250              C          
## 1309     0 315082  7.8750              S

Na podstawie danych wyjściowych typu „head and tail” można zauważyć, że dane nie są tasowane. To duży problem! Kiedy podzielisz dane pomiędzy zestawem pociągowym i zestawem testowym, dokonasz wyboru tylko pasażera z klasy 1 i 2 (żaden pasażer z klasy 3 nie znajduje się w górnych 80 procentach obserwacji), co oznacza, że ​​algorytm nigdy nie dostrzeże cech pasażera klasy 3. Ten błąd będzie skutkować słabą predykcją.

Aby rozwiązać ten problem, możesz użyć funkcji sample().

shuffle_index <- sample(1:nrow(titanic))
head(shuffle_index)

Drzewo decyzyjne Kod R Wyjaśnienie

  • sample(1:nrow(titanic)): Wygeneruj losową listę indeksów od 1 do 1309 (tzn. maksymalną liczbę wierszy).

Wyjście:

## [1]  288  874 1078  633  887  992

Użyjesz tego indeksu do przetasowania zbioru danych Titanica.

titanic <- titanic[shuffle_index, ]
head(titanic)

Wyjście:

##         X pclass survived
## 288   288      1        0
## 874   874      3        0
## 1078 1078      3        1
## 633   633      3        0
## 887   887      3        1
## 992   992      3        1
##                                                           name    sex age
## 288                                      Sutton, Mr. Frederick   male  61
## 874                   Humblen, Mr. Adolf Mathias Nicolai Olsen   male  42
## 1078                                 O'Driscoll, Miss. Bridget female  NA
## 633  Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female  39
## 887                                        Jermyn, Miss. Annie female  NA
## 992                                           Mamee, Mr. Hanna   male  NA
##      sibsp parch ticket    fare cabin embarked           home.dest## 288      0     0  36963 32.3208   D50        S     Haddenfield, NJ
## 874      0     0 348121  7.6500 F G63        S                    
## 1078     0     0  14311  7.7500              Q                    
## 633      1     5 347082 31.2750              S Sweden Winnipeg, MN
## 887      0     0  14313  7.7500              Q                    
## 992      0     0   2677  7.2292              C	

Krok 2) Wyczyść zbiór danych

Struktura danych pokazuje, że niektóre zmienne mają NA. Czyszczenie danych należy wykonać w następujący sposób

  • Upuść zmienne home.dest,cabin, name, X i bilet
  • Utwórz zmienne czynnikowe dla pclass i przetrwaj
  • Porzuć NA
library(dplyr)
# Drop variables
clean_titanic <- titanic % > %
select(-c(home.dest, cabin, name, X, ticket)) % > % 
#Convert to factor level
	mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),
	survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %
na.omit()
glimpse(clean_titanic)

Objaśnienie kodu

  • wybierz(-c(home.dest, kabina, nazwa, X, bilet)): Usuń niepotrzebne zmienne
  • pclass = współczynnik(pklasa, poziomy = c(1,2,3), etykiety= c('Upper', 'Middle', 'Lower')): Dodaj etykietę do zmiennej pclass. 1 staje się górnym, 2 staje się środkowym, a 3 staje się niższym
  • współczynnik(przetrwał, poziomy = c(0,1), etykiety = c('Nie', 'Tak')): Dodaj etykietę do zmiennej przetrwał. 1 staje się Nie, a 2 staje się Tak
  • na.omit(): Usuń obserwacje NA

Wyjście:

## Observations: 1,045
## Variables: 8
## $ pclass   <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U...
## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y...
## $ sex      <fctr> male, male, female, female, male, male, female, male...
## $ age      <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ...
## $ sibsp    <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,...
## $ parch    <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,...
## $ fare     <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ...
## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...		

Krok 3) Utwórz zestaw pociągowy/testowy

Zanim wytrenujesz swój model, musisz wykonać dwa kroki:

  • Utwórz zestaw pociągowy i testowy: Trenujesz model na zestawie pociągowym i testujesz predykcję na zestawie testowym (tj. Niewidoczne dane)
  • Zainstaluj rpart.plot z konsoli

Powszechną praktyką jest dzielenie danych w stosunku 80/20, przy czym 80 procent danych służy do uczenia modelu, a 20 procent do tworzenia prognoz. Musisz utworzyć dwie oddzielne ramki danych. Nie chcesz dotykać zestawu testowego, dopóki nie zakończysz budowania modelu. Możesz utworzyć funkcję o nazwie create_train_test(), która przyjmuje trzy argumenty.

create_train_test(df, size = 0.8, train = TRUE)
arguments:
-df: Dataset used to train the model.
-size: Size of the split. By default, 0.8. Numerical value
-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {
    n_row = nrow(data)
    total_row = size * n_row
    train_sample < - 1: total_row
    if (train == TRUE) {
        return (data[train_sample, ])
    } else {
        return (data[-train_sample, ])
    }
}

Objaśnienie kodu

  • funkcja(dane, rozmiar=0.8, pociąg = TRUE): Dodaj argumenty funkcji
  • n_row = nrow(data): Zlicz liczbę wierszy w zbiorze danych
  • total_row = size*n_row: Zwróć n-ty wiersz, aby skonstruować zestaw
  • train_sample <- 1:total_row: Wybierz pierwszy do n-tych wierszy
  • if (train ==TRUE){ } else { }: Jeśli warunek ma wartość true, zwróć zestaw pociągowy, w przeciwnym razie zestaw testowy.

Możesz przetestować swoją funkcję i sprawdzić wymiar.

data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)
data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)
dim(data_train)

Wyjście:

## [1] 836   8
dim(data_test)

Wyjście:

## [1] 209   8

Zbiór danych pociągu ma 1046 wierszy, podczas gdy zbiór danych testowych ma 262 wiersze.

Używasz funkcji prop.table() w połączeniu z table() w celu sprawdzenia, czy proces randomizacji jest prawidłowy.

prop.table(table(data_train$survived))

Wyjście:

##
##        No       Yes 
## 0.5944976 0.4055024
prop.table(table(data_test$survived))

Wyjście:

## 
##        No       Yes 
## 0.5789474 0.4210526

W obu zbiorach danych liczba ocalałych jest taka sama i wynosi około 40 procent.

Zainstaluj rpart.plot

Plik rpart.plot nie jest dostępny w bibliotekach Conda. Możesz zainstalować go z konsoli:

install.packages("rpart.plot")

Krok 4) Zbuduj model

Jesteś gotowy do zbudowania modelu. Składnia funkcji drzewa decyzyjnego Rpart jest następująca:

rpart(formula, data=, method='')
arguments:			
- formula: The function to predict
- data: Specifies the data frame- method: 			
- "class" for a classification tree 			
- "anova" for a regression tree	

Używasz metody klasowej, ponieważ przewidujesz klasę.

library(rpart)
library(rpart.plot)
fit <- rpart(survived~., data = data_train, method = 'class')
rpart.plot(fit, extra = 106

Objaśnienie kodu

  • rpart(): Funkcja dopasowująca się do modelu. Argumenty są następujące:
    • przetrwał ~.: Formuła drzew decyzyjnych
    • dane = data_train: Zbiór danych
    • metoda = „klasa”: Dopasuj model binarny
  • rpart.plot(fit, extra= 106): Narysuj drzewo. Dodatkowe funkcje są ustawione na 101, aby wyświetlić prawdopodobieństwo drugiej klasy (przydatne w przypadku odpowiedzi binarnych). Możesz odwołać się do winieta aby uzyskać więcej informacji na temat innych opcji.

Wyjście:

Zbuduj model drzew decyzyjnych w R

Zaczynasz od węzła głównego (głębokość 0 na 3, góra wykresu):

  1. Na górze jest to ogólne prawdopodobieństwo przeżycia. Pokazuje odsetek pasażerów, którzy przeżyli katastrofę. 41 procent pasażerów przeżyło.
  2. Węzeł ten pyta, czy pasażer jest płci męskiej. Jeśli tak, zejdź do lewego węzła podrzędnego korzenia (głębokość 2). 63 procent to mężczyźni, a prawdopodobieństwo przeżycia wynosi 21 procent.
  3. W drugim węźle pytasz, czy pasażer płci męskiej ma więcej niż 3.5 roku. Jeśli tak, to szansa na przeżycie wynosi 19 procent.
  4. Kontynuuj w ten sposób, aby zrozumieć, jakie cechy wpływają na prawdopodobieństwo przeżycia.

Należy pamiętać, że jedną z wielu zalet drzew decyzyjnych jest to, że wymagają one bardzo niewielkiego przygotowania danych. W szczególności nie wymagają skalowania ani centrowania funkcji.

Domyślnie funkcja rpart() używa Gini miara zanieczyszczeń, aby podzielić notatkę. Im wyższy współczynnik Giniego, tym więcej różnych instancji w węźle.

Krok 5) Przewiduj

Możesz przewidzieć swój testowy zbiór danych. Aby dokonać przewidywania, możesz użyć funkcji przewidywania(). Podstawowa składnia drzewa decyzyjnego przewidywania dla R jest następująca:

predict(fitted_model, df, type = 'class')
arguments:
- fitted_model: This is the object stored after model estimation. 
- df: Data frame used to make the prediction
- type: Type of prediction			
    - 'class': for classification			
    - 'prob': to compute the probability of each class			
    - 'vector': Predict the mean response at the node level	

Chcesz przewidzieć, którzy pasażerowie mają większe szanse na przeżycie po zderzeniu ze zbioru testowego. Oznacza to, że spośród tych 209 pasażerów będziesz wiedział, który z nich przeżyje, czy nie.

predict_unseen <-predict(fit, data_test, type = 'class')

Objaśnienie kodu

  • przewidywanie(fit, data_test, type = 'class'): Przewiduj klasę (0/1) zbioru testowego

Testowanie pasażera, któremu się nie udało, i tych, którym się udało.

table_mat <- table(data_test$survived, predict_unseen)
table_mat

Objaśnienie kodu

  • table(data_test$survived, przewidywanie_unseen): Utwórz tabelę, aby policzyć, ilu pasażerów zostało sklasyfikowanych jako ocaleni i zmarło w porównaniu z poprawną klasyfikacją drzewa decyzyjnego w R

Wyjście:

##      predict_unseen
##        No Yes
##   No  106  15
##   Yes  30  58

Model prawidłowo przewidział śmierć 106 pasażerów, ale sklasyfikował 15 ocalałych jako martwych. Przez analogię model błędnie sklasyfikował 30 pasażerów jako ocalałych, choć okazało się, że nie żyją.

Krok 6) Zmierz wydajność

Możesz obliczyć miarę dokładności dla zadania klasyfikacji za pomocą matryca zamieszania:

matryca zamieszania jest lepszym wyborem do oceny skuteczności klasyfikacji. Ogólna koncepcja polega na tym, aby policzyć, ile razy prawdziwe przypadki zostały sklasyfikowane jako fałszywe.

Zmierz wydajność drzew decyzyjnych w R

Każdy wiersz w macierzy zamieszania reprezentuje rzeczywisty cel, podczas gdy każda kolumna reprezentuje przewidywany cel. Pierwszy wiersz tej macierzy uwzględnia martwych pasażerów (klasa Fałsz): 106 zostało poprawnie sklasyfikowanych jako martwych (Prawdziwy negatyw), natomiast pozostałego błędnie zakwalifikowano jako ocalałego (Fałszywie pozytywne). Drugi rząd uwzględnia ocalałych, klasa pozytywna to 58 (Prawdziwie pozytywne), podczas, gdy Prawdziwy negatyw było 30.

Można obliczyć próba dokładności z macierzy zamieszania:

Zmierz wydajność drzew decyzyjnych w R

Jest to proporcja prawdziwie dodatniego i prawdziwie ujemnego w sumie macierzy. Za pomocą R możesz kodować w następujący sposób:

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)

Objaśnienie kodu

  • sum(diag(table_mat)): Suma przekątnej
  • sum(table_mat): Suma macierzy.

Możesz wydrukować dokładność zbioru testowego:

print(paste('Accuracy for test', accuracy_Test))

Wyjście:

## [1] "Accuracy for test 0.784688995215311"

W zestawie testowym uzyskałeś wynik 78 procent. Możesz powtórzyć to samo ćwiczenie, korzystając ze zbioru danych szkoleniowych.

Krok 7) Dostosuj hiperparametry

Drzewo decyzyjne w R ma różne parametry, które kontrolują aspekty dopasowania. W bibliotece drzewa decyzyjnego rpart możesz kontrolować parametry za pomocą funkcji rpart.control(). W poniższym kodzie wprowadzasz parametry, które będziesz dostrajać. Możesz odwołać się do winieta dla innych parametrów.

rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)
Arguments:
-minsplit: Set the minimum number of observations in the node before the algorithm perform a split
-minbucket:  Set the minimum number of observations in the final note i.e. the leaf
-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0

Postępujemy następująco:

  • Skonstruuj funkcję zwracającą dokładność
  • Dostrój maksymalną głębokość
  • Dostosuj minimalną liczbę próbek, jaką musi mieć węzeł, zanim będzie mógł się podzielić
  • Dostosuj minimalną liczbę próbek, jaką musi mieć węzeł liścia

Można napisać funkcję wyświetlającą dokładność. Po prostu zawiń kod, którego użyłeś wcześniej:

  1. przewidywać: przewidywać_unseen <- przewidywać(fit, data_test, type = 'class')
  2. Utwórz tabelę: table_mat <- table(data_test$survived, przewidywanie_unseen)
  3. Dokładność obliczeń: dokładność_Test <- sum(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) {
    predict_unseen <- predict(fit, data_test, type = 'class')
    table_mat <- table(data_test$survived, predict_unseen)
    accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
    accuracy_Test
}

Możesz spróbować dostroić parametry i sprawdzić, czy możesz ulepszyć model w stosunku do wartości domyślnej. Dla przypomnienia należy uzyskać dokładność większą niż 0.78

control <- rpart.control(minsplit = 4,
    minbucket = round(5 / 3),
    maxdepth = 3,
    cp = 0)
tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)
accuracy_tune(tune_fit)

Wyjście:

## [1] 0.7990431

Z następującym parametrem:

minsplit = 4
minbucket= round(5/3)
maxdepth = 3cp=0

Otrzymujesz wyższą wydajność niż w poprzednim modelu. Gratulacje!

Podsumowanie

Możemy podsumować funkcje służące do trenowania algorytmu drzewa decyzyjnego w R

Biblioteka Cel Funkcjonować Klasa Parametry Szczegóły
część Drzewo klasyfikacji pociągów w R rpart() klasa formuła, df, metoda
część Drzewo regresji pociągu rpart() anowa formuła, df, metoda
część Narysuj drzewa rpart.plot() dopasowany model
baza przewidzieć przepowiadać, wywróżyć() klasa dopasowany model, typ
baza przewidzieć przepowiadać, wywróżyć() prawd dopasowany model, typ
baza przewidzieć przepowiadać, wywróżyć() wektor dopasowany model, typ
część Parametry dotyczące kontroli rpart.control() podział minut Ustaw minimalną liczbę obserwacji w węźle, zanim algorytm wykona podział
minbucket Ustaw minimalną liczbę obserwacji w notatce końcowej, czyli na liściu
maksymalna głębokość Ustaw maksymalną głębokość dowolnego węzła końcowego drzewa. Węzeł główny jest traktowany na głębokości 0
część Model pociągu z parametrem kontrolnym rpart() formuła, df, metoda, kontrola

Uwaga: trenuj model na danych szkoleniowych i przetestuj wydajność na niewidocznym zestawie danych, tj. zestawie testowym.