Drzewo decyzyjne w R: Drzewo klasyfikacyjne z przykładem
Czym są drzewa decyzyjne?
Drzewa decyzyjne są wszechstronnymi algorytmami uczenia maszynowego, które mogą wykonywać zarówno zadania klasyfikacyjne, jak i regresyjne. Są to bardzo potężne algorytmy, zdolne do dopasowania złożonych zestawów danych. Ponadto drzewa decyzyjne są podstawowymi składnikami lasów losowych, które są jednymi z najpotężniejszych algorytmów uczenia maszynowego dostępnych obecnie.
Uczenie i wizualizacja drzew decyzyjnych w R
Aby zbudować pierwsze drzewo decyzyjne w przykładzie R, postępujemy w następujący sposób w tym samouczku dotyczącym drzewa decyzyjnego:
- Krok 1: Zaimportuj dane
- Krok 2: Wyczyść zbiór danych
- Krok 3: Utwórz zestaw pociągowy/testowy
- Krok 4: Zbuduj model
- Krok 5: Przewiduj
- Krok 6: Zmierz wydajność
- Krok 7: Dostosuj hiperparametry
Krok 1) Zaimportuj dane
Jeśli ciekawi Cię los Titanica, możesz obejrzeć ten film na YouTube. Celem tego zbioru danych jest przewidzenie, którzy ludzie mają większe szanse na przeżycie po zderzeniu z górą lodową. Zbiór danych zawiera 13 zmiennych i 1309 obserwacji. Zbiór danych jest uporządkowany według zmiennej X.
set.seed(678) path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv' titanic <-read.csv(path) head(titanic)
Wyjście:
## X pclass survived name sex ## 1 1 1 1 Allen, Miss. Elisabeth Walton female ## 2 2 1 1 Allison, Master. Hudson Trevor male ## 3 3 1 0 Allison, Miss. Helen Loraine female ## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male ## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female ## 6 6 1 1 Anderson, Mr. Harry male ## age sibsp parch ticket fare cabin embarked ## 1 29.0000 0 0 24160 211.3375 B5 S ## 2 0.9167 1 2 113781 151.5500 C22 C26 S ## 3 2.0000 1 2 113781 151.5500 C22 C26 S ## 4 30.0000 1 2 113781 151.5500 C22 C26 S ## 5 25.0000 1 2 113781 151.5500 C22 C26 S ## 6 48.0000 0 0 19952 26.5500 E12 S ## home.dest ## 1 St Louis, MO ## 2 Montreal, PQ / Chesterville, ON ## 3 Montreal, PQ / Chesterville, ON ## 4 Montreal, PQ / Chesterville, ON ## 5 Montreal, PQ / Chesterville, ON ## 6 New York, NY
tail(titanic)
Wyjście:
## X pclass survived name sex age sibsp ## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0 ## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1 ## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1 ## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0 ## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0 ## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0 ## parch ticket fare cabin embarked home.dest ## 1304 0 2627 14.4583 C ## 1305 0 2665 14.4542 C ## 1306 0 2665 14.4542 C ## 1307 0 2656 7.2250 C ## 1308 0 2670 7.2250 C ## 1309 0 315082 7.8750 S
Na podstawie danych wyjściowych typu „head and tail” można zauważyć, że dane nie są tasowane. To duży problem! Kiedy podzielisz dane pomiędzy zestawem pociągowym i zestawem testowym, dokonasz wyboru tylko pasażera z klasy 1 i 2 (żaden pasażer z klasy 3 nie znajduje się w górnych 80 procentach obserwacji), co oznacza, że algorytm nigdy nie dostrzeże cech pasażera klasy 3. Ten błąd będzie skutkować słabą predykcją.
Aby rozwiązać ten problem, możesz użyć funkcji sample().
shuffle_index <- sample(1:nrow(titanic)) head(shuffle_index)
Drzewo decyzyjne Kod R Wyjaśnienie
- sample(1:nrow(titanic)): Wygeneruj losową listę indeksów od 1 do 1309 (tzn. maksymalną liczbę wierszy).
Wyjście:
## [1] 288 874 1078 633 887 992
Użyjesz tego indeksu do przetasowania zbioru danych Titanica.
titanic <- titanic[shuffle_index, ] head(titanic)
Wyjście:
## X pclass survived ## 288 288 1 0 ## 874 874 3 0 ## 1078 1078 3 1 ## 633 633 3 0 ## 887 887 3 1 ## 992 992 3 1 ## name sex age ## 288 Sutton, Mr. Frederick male 61 ## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42 ## 1078 O'Driscoll, Miss. Bridget female NA ## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39 ## 887 Jermyn, Miss. Annie female NA ## 992 Mamee, Mr. Hanna male NA ## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ ## 874 0 0 348121 7.6500 F G63 S ## 1078 0 0 14311 7.7500 Q ## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN ## 887 0 0 14313 7.7500 Q ## 992 0 0 2677 7.2292 C
Krok 2) Wyczyść zbiór danych
Struktura danych pokazuje, że niektóre zmienne mają NA. Czyszczenie danych należy wykonać w następujący sposób
- Upuść zmienne home.dest,cabin, name, X i bilet
- Utwórz zmienne czynnikowe dla pclass i przetrwaj
- Porzuć NA
library(dplyr) # Drop variables clean_titanic <- titanic % > % select(-c(home.dest, cabin, name, X, ticket)) % > % #Convert to factor level mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')), survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > % na.omit() glimpse(clean_titanic)
Objaśnienie kodu
- wybierz(-c(home.dest, kabina, nazwa, X, bilet)): Usuń niepotrzebne zmienne
- pclass = współczynnik(pklasa, poziomy = c(1,2,3), etykiety= c('Upper', 'Middle', 'Lower')): Dodaj etykietę do zmiennej pclass. 1 staje się górnym, 2 staje się środkowym, a 3 staje się niższym
- współczynnik(przetrwał, poziomy = c(0,1), etykiety = c('Nie', 'Tak')): Dodaj etykietę do zmiennej przetrwał. 1 staje się Nie, a 2 staje się Tak
- na.omit(): Usuń obserwacje NA
Wyjście:
## Observations: 1,045 ## Variables: 8 ## $ pclass <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U... ## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y... ## $ sex <fctr> male, male, female, female, male, male, female, male... ## $ age <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ... ## $ sibsp <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,... ## $ parch <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,... ## $ fare <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ... ## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...
Krok 3) Utwórz zestaw pociągowy/testowy
Zanim wytrenujesz swój model, musisz wykonać dwa kroki:
- Utwórz zestaw pociągowy i testowy: Trenujesz model na zestawie pociągowym i testujesz predykcję na zestawie testowym (tj. Niewidoczne dane)
- Zainstaluj rpart.plot z konsoli
Powszechną praktyką jest dzielenie danych w stosunku 80/20, przy czym 80 procent danych służy do uczenia modelu, a 20 procent do tworzenia prognoz. Musisz utworzyć dwie oddzielne ramki danych. Nie chcesz dotykać zestawu testowego, dopóki nie zakończysz budowania modelu. Możesz utworzyć funkcję o nazwie create_train_test(), która przyjmuje trzy argumenty.
create_train_test(df, size = 0.8, train = TRUE) arguments: -df: Dataset used to train the model. -size: Size of the split. By default, 0.8. Numerical value -train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample < - 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } }
Objaśnienie kodu
- funkcja(dane, rozmiar=0.8, pociąg = TRUE): Dodaj argumenty funkcji
- n_row = nrow(data): Zlicz liczbę wierszy w zbiorze danych
- total_row = size*n_row: Zwróć n-ty wiersz, aby skonstruować zestaw
- train_sample <- 1:total_row: Wybierz pierwszy do n-tych wierszy
- if (train ==TRUE){ } else { }: Jeśli warunek ma wartość true, zwróć zestaw pociągowy, w przeciwnym razie zestaw testowy.
Możesz przetestować swoją funkcję i sprawdzić wymiar.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE) data_test <- create_train_test(clean_titanic, 0.8, train = FALSE) dim(data_train)
Wyjście:
## [1] 836 8
dim(data_test)
Wyjście:
## [1] 209 8
Zbiór danych pociągu ma 1046 wierszy, podczas gdy zbiór danych testowych ma 262 wiersze.
Używasz funkcji prop.table() w połączeniu z table() w celu sprawdzenia, czy proces randomizacji jest prawidłowy.
prop.table(table(data_train$survived))
Wyjście:
## ## No Yes ## 0.5944976 0.4055024
prop.table(table(data_test$survived))
Wyjście:
## ## No Yes ## 0.5789474 0.4210526
W obu zbiorach danych liczba ocalałych jest taka sama i wynosi około 40 procent.
Zainstaluj rpart.plot
Plik rpart.plot nie jest dostępny w bibliotekach Conda. Możesz zainstalować go z konsoli:
install.packages("rpart.plot")
Krok 4) Zbuduj model
Jesteś gotowy do zbudowania modelu. Składnia funkcji drzewa decyzyjnego Rpart jest następująca:
rpart(formula, data=, method='') arguments: - formula: The function to predict - data: Specifies the data frame- method: - "class" for a classification tree - "anova" for a regression tree
Używasz metody klasowej, ponieważ przewidujesz klasę.
library(rpart) library(rpart.plot) fit <- rpart(survived~., data = data_train, method = 'class') rpart.plot(fit, extra = 106
Objaśnienie kodu
- rpart(): Funkcja dopasowująca się do modelu. Argumenty są następujące:
- przetrwał ~.: Formuła drzew decyzyjnych
- dane = data_train: Zbiór danych
- metoda = „klasa”: Dopasuj model binarny
- rpart.plot(fit, extra= 106): Narysuj drzewo. Dodatkowe funkcje są ustawione na 101, aby wyświetlić prawdopodobieństwo drugiej klasy (przydatne w przypadku odpowiedzi binarnych). Możesz odwołać się do winieta aby uzyskać więcej informacji na temat innych opcji.
Wyjście:
Zaczynasz od węzła głównego (głębokość 0 na 3, góra wykresu):
- Na górze jest to ogólne prawdopodobieństwo przeżycia. Pokazuje odsetek pasażerów, którzy przeżyli katastrofę. 41 procent pasażerów przeżyło.
- Węzeł ten pyta, czy pasażer jest płci męskiej. Jeśli tak, zejdź do lewego węzła podrzędnego korzenia (głębokość 2). 63 procent to mężczyźni, a prawdopodobieństwo przeżycia wynosi 21 procent.
- W drugim węźle pytasz, czy pasażer płci męskiej ma więcej niż 3.5 roku. Jeśli tak, to szansa na przeżycie wynosi 19 procent.
- Kontynuuj w ten sposób, aby zrozumieć, jakie cechy wpływają na prawdopodobieństwo przeżycia.
Należy pamiętać, że jedną z wielu zalet drzew decyzyjnych jest to, że wymagają one bardzo niewielkiego przygotowania danych. W szczególności nie wymagają skalowania ani centrowania funkcji.
Domyślnie funkcja rpart() używa Gini miara zanieczyszczeń, aby podzielić notatkę. Im wyższy współczynnik Giniego, tym więcej różnych instancji w węźle.
Krok 5) Przewiduj
Możesz przewidzieć swój testowy zbiór danych. Aby dokonać przewidywania, możesz użyć funkcji przewidywania(). Podstawowa składnia drzewa decyzyjnego przewidywania dla R jest następująca:
predict(fitted_model, df, type = 'class') arguments: - fitted_model: This is the object stored after model estimation. - df: Data frame used to make the prediction - type: Type of prediction - 'class': for classification - 'prob': to compute the probability of each class - 'vector': Predict the mean response at the node level
Chcesz przewidzieć, którzy pasażerowie mają większe szanse na przeżycie po zderzeniu ze zbioru testowego. Oznacza to, że spośród tych 209 pasażerów będziesz wiedział, który z nich przeżyje, czy nie.
predict_unseen <-predict(fit, data_test, type = 'class')
Objaśnienie kodu
- przewidywanie(fit, data_test, type = 'class'): Przewiduj klasę (0/1) zbioru testowego
Testowanie pasażera, któremu się nie udało, i tych, którym się udało.
table_mat <- table(data_test$survived, predict_unseen) table_mat
Objaśnienie kodu
- table(data_test$survived, przewidywanie_unseen): Utwórz tabelę, aby policzyć, ilu pasażerów zostało sklasyfikowanych jako ocaleni i zmarło w porównaniu z poprawną klasyfikacją drzewa decyzyjnego w R
Wyjście:
## predict_unseen ## No Yes ## No 106 15 ## Yes 30 58
Model prawidłowo przewidział śmierć 106 pasażerów, ale sklasyfikował 15 ocalałych jako martwych. Przez analogię model błędnie sklasyfikował 30 pasażerów jako ocalałych, choć okazało się, że nie żyją.
Krok 6) Zmierz wydajność
Możesz obliczyć miarę dokładności dla zadania klasyfikacji za pomocą matryca zamieszania:
matryca zamieszania jest lepszym wyborem do oceny skuteczności klasyfikacji. Ogólna koncepcja polega na tym, aby policzyć, ile razy prawdziwe przypadki zostały sklasyfikowane jako fałszywe.
Każdy wiersz w macierzy zamieszania reprezentuje rzeczywisty cel, podczas gdy każda kolumna reprezentuje przewidywany cel. Pierwszy wiersz tej macierzy uwzględnia martwych pasażerów (klasa Fałsz): 106 zostało poprawnie sklasyfikowanych jako martwych (Prawdziwy negatyw), natomiast pozostałego błędnie zakwalifikowano jako ocalałego (Fałszywie pozytywne). Drugi rząd uwzględnia ocalałych, klasa pozytywna to 58 (Prawdziwie pozytywne), podczas, gdy Prawdziwy negatyw było 30.
Można obliczyć próba dokładności z macierzy zamieszania:
Jest to proporcja prawdziwie dodatniego i prawdziwie ujemnego w sumie macierzy. Za pomocą R możesz kodować w następujący sposób:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Objaśnienie kodu
- sum(diag(table_mat)): Suma przekątnej
- sum(table_mat): Suma macierzy.
Możesz wydrukować dokładność zbioru testowego:
print(paste('Accuracy for test', accuracy_Test))
Wyjście:
## [1] "Accuracy for test 0.784688995215311"
W zestawie testowym uzyskałeś wynik 78 procent. Możesz powtórzyć to samo ćwiczenie, korzystając ze zbioru danych szkoleniowych.
Krok 7) Dostosuj hiperparametry
Drzewo decyzyjne w R ma różne parametry, które kontrolują aspekty dopasowania. W bibliotece drzewa decyzyjnego rpart możesz kontrolować parametry za pomocą funkcji rpart.control(). W poniższym kodzie wprowadzasz parametry, które będziesz dostrajać. Możesz odwołać się do winieta dla innych parametrów.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30) Arguments: -minsplit: Set the minimum number of observations in the node before the algorithm perform a split -minbucket: Set the minimum number of observations in the final note i.e. the leaf -maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
Postępujemy następująco:
- Skonstruuj funkcję zwracającą dokładność
- Dostrój maksymalną głębokość
- Dostosuj minimalną liczbę próbek, jaką musi mieć węzeł, zanim będzie mógł się podzielić
- Dostosuj minimalną liczbę próbek, jaką musi mieć węzeł liścia
Można napisać funkcję wyświetlającą dokładność. Po prostu zawiń kod, którego użyłeś wcześniej:
- przewidywać: przewidywać_unseen <- przewidywać(fit, data_test, type = 'class')
- Utwórz tabelę: table_mat <- table(data_test$survived, przewidywanie_unseen)
- Dokładność obliczeń: dokładność_Test <- sum(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) { predict_unseen <- predict(fit, data_test, type = 'class') table_mat <- table(data_test$survived, predict_unseen) accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test }
Możesz spróbować dostroić parametry i sprawdzić, czy możesz ulepszyć model w stosunku do wartości domyślnej. Dla przypomnienia należy uzyskać dokładność większą niż 0.78
control <- rpart.control(minsplit = 4, minbucket = round(5 / 3), maxdepth = 3, cp = 0) tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control) accuracy_tune(tune_fit)
Wyjście:
## [1] 0.7990431
Z następującym parametrem:
minsplit = 4 minbucket= round(5/3) maxdepth = 3cp=0
Otrzymujesz wyższą wydajność niż w poprzednim modelu. Gratulacje!
Podsumowanie
Możemy podsumować funkcje służące do trenowania algorytmu drzewa decyzyjnego w R
Biblioteka | Cel | Funkcjonować | Klasa | Parametry | Szczegóły |
---|---|---|---|---|---|
część | Drzewo klasyfikacji pociągów w R | rpart() | klasa | formuła, df, metoda | |
część | Drzewo regresji pociągu | rpart() | anowa | formuła, df, metoda | |
część | Narysuj drzewa | rpart.plot() | dopasowany model | ||
baza | przewidzieć | przepowiadać, wywróżyć() | klasa | dopasowany model, typ | |
baza | przewidzieć | przepowiadać, wywróżyć() | prawd | dopasowany model, typ | |
baza | przewidzieć | przepowiadać, wywróżyć() | wektor | dopasowany model, typ | |
część | Parametry dotyczące kontroli | rpart.control() | podział minut | Ustaw minimalną liczbę obserwacji w węźle, zanim algorytm wykona podział | |
minbucket | Ustaw minimalną liczbę obserwacji w notatce końcowej, czyli na liściu | ||||
maksymalna głębokość | Ustaw maksymalną głębokość dowolnego węzła końcowego drzewa. Węzeł główny jest traktowany na głębokości 0 | ||||
część | Model pociągu z parametrem kontrolnym | rpart() | formuła, df, metoda, kontrola |
Uwaga: trenuj model na danych szkoleniowych i przetestuj wydajność na niewidocznym zestawie danych, tj. zestawie testowym.