Дърво на решенията в R: Класификационно дърво с пример
Какво представляват дърветата на решенията?
Дървета за вземане на решения са многофункционален алгоритъм за машинно обучение, който може да изпълнява както задачи за класификация, така и регресия. Те са много мощни алгоритми, способни да монтират сложни масиви от данни. Освен това дърветата на решенията са основни компоненти на произволни гори, които са сред най-мощните алгоритми за машинно обучение, налични днес.
Обучение и визуализиране на дървета на решения в R
За да изградите вашето първо дърво на решения в R пример, ще продължим както следва в този урок за дърво на решенията:
- Стъпка 1: Импортирайте данните
- Стъпка 2: Почистете набора от данни
- Стъпка 3: Създайте комплект за обучение/тест
- Стъпка 4: Изградете модела
- Стъпка 5: Направете прогноза
- Стъпка 6: Измерете ефективността
- Стъпка 7: Настройте хиперпараметрите
Стъпка 1) Импортирайте данните
Ако сте любопитни за съдбата на титаник, можете да гледате това видео на Youtube. Целта на този набор от данни е да се предвиди кои хора са по-склонни да оцелеят след сблъсъка с айсберга. Наборът от данни съдържа 13 променливи и 1309 наблюдения. Наборът от данни е подреден по променливата X.
set.seed(678) path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv' titanic <-read.csv(path) head(titanic)
Изход:
## X pclass survived name sex ## 1 1 1 1 Allen, Miss. Elisabeth Walton female ## 2 2 1 1 Allison, Master. Hudson Trevor male ## 3 3 1 0 Allison, Miss. Helen Loraine female ## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male ## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female ## 6 6 1 1 Anderson, Mr. Harry male ## age sibsp parch ticket fare cabin embarked ## 1 29.0000 0 0 24160 211.3375 B5 S ## 2 0.9167 1 2 113781 151.5500 C22 C26 S ## 3 2.0000 1 2 113781 151.5500 C22 C26 S ## 4 30.0000 1 2 113781 151.5500 C22 C26 S ## 5 25.0000 1 2 113781 151.5500 C22 C26 S ## 6 48.0000 0 0 19952 26.5500 E12 S ## home.dest ## 1 St Louis, MO ## 2 Montreal, PQ / Chesterville, ON ## 3 Montreal, PQ / Chesterville, ON ## 4 Montreal, PQ / Chesterville, ON ## 5 Montreal, PQ / Chesterville, ON ## 6 New York, NY
tail(titanic)
Изход:
## X pclass survived name sex age sibsp ## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0 ## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1 ## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1 ## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0 ## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0 ## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0 ## parch ticket fare cabin embarked home.dest ## 1304 0 2627 14.4583 C ## 1305 0 2665 14.4542 C ## 1306 0 2665 14.4542 C ## 1307 0 2656 7.2250 C ## 1308 0 2670 7.2250 C ## 1309 0 315082 7.8750 S
От изхода на главата и опашката можете да забележите, че данните не са разбъркани. Това е голям проблем! Когато разделяте вашите данни между набор от влакове и набор от тестове, вие ще изберете само за лична употреба пътникът от клас 1 и 2 (Няма пътник от клас 3 в първите 80 процента от наблюденията), което означава, че алгоритъмът никога няма да види характеристиките на пътник от клас 3. Тази грешка ще доведе до лошо прогнозиране.
За да преодолеете този проблем, можете да използвате функцията sample().
shuffle_index <- sample(1:nrow(titanic)) head(shuffle_index)
Дърво на решения R код Обяснение
- sample(1:nrow(titanic)): Генериране на произволен списък с индекси от 1 до 1309 (т.е. максималния брой редове).
Изход:
## [1] 288 874 1078 633 887 992
Ще използвате този индекс, за да разбъркате титаничния набор от данни.
titanic <- titanic[shuffle_index, ] head(titanic)
Изход:
## X pclass survived ## 288 288 1 0 ## 874 874 3 0 ## 1078 1078 3 1 ## 633 633 3 0 ## 887 887 3 1 ## 992 992 3 1 ## name sex age ## 288 Sutton, Mr. Frederick male 61 ## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42 ## 1078 O'Driscoll, Miss. Bridget female NA ## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39 ## 887 Jermyn, Miss. Annie female NA ## 992 Mamee, Mr. Hanna male NA ## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ ## 874 0 0 348121 7.6500 F G63 S ## 1078 0 0 14311 7.7500 Q ## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN ## 887 0 0 14313 7.7500 Q ## 992 0 0 2677 7.2292 C
Стъпка 2) Почистете набора от данни
Структурата на данните показва, че някои променливи имат NA. Почистването на данните трябва да се извърши по следния начин
- Пуснете променливи home.dest,cabin, name, X и билет
- Създайте факторни променливи за pclass и survived
- Пуснете NA
library(dplyr) # Drop variables clean_titanic <- titanic % > % select(-c(home.dest, cabin, name, X, ticket)) % > % #Convert to factor level mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')), survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > % na.omit() glimpse(clean_titanic)
Обяснение на кода
- select(-c(home.dest, cabin, name, X, ticket)): Премахнете ненужните променливи
- pclass = фактор(pclass, нива = c(1,2,3), labels= c('Upper', 'Middle', 'Lower')): Добавяне на етикет към променливата pclass. 1 става Горна, 2 става Средна и 3 става долна
- фактор (оцелял, нива = c(0,1), етикети = c('Не', 'Да')): Добавете етикет към променливата оцелял. 1 става Не и 2 става Да
- na.omit(): Премахване на наблюденията на NA
Изход:
## Observations: 1,045 ## Variables: 8 ## $ pclass <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U... ## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y... ## $ sex <fctr> male, male, female, female, male, male, female, male... ## $ age <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ... ## $ sibsp <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,... ## $ parch <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,... ## $ fare <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ... ## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...
Стъпка 3) Създайте набор от влак/тест
Преди да обучите вашия модел, трябва да извършите две стъпки:
- Създаване на влак и набор от тестове: Вие обучавате модела на набора от влакове и тествате прогнозата върху тестовия набор (т.е. невидяни данни)
- Инсталирайте rpart.plot от конзолата
Обичайната практика е данните да се разделят 80/20, 80 процента от данните служат за обучение на модела, а 20 процента за правене на прогнози. Трябва да създадете две отделни рамки с данни. Не искате да докосвате тестовия набор, докато не завършите изграждането на своя модел. Можете да създадете име на функция create_train_test(), което приема три аргумента.
create_train_test(df, size = 0.8, train = TRUE) arguments: -df: Dataset used to train the model. -size: Size of the split. By default, 0.8. Numerical value -train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample < - 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } }
Обяснение на кода
- функция (данни, размер = 0.8, влак = TRUE): Добавяне на аргументите във функцията
- n_row = nrow(данни): Преброяване на броя редове в набора от данни
- total_row = size*n_row: Върнете n-тия ред, за да конструирате влака
- train_sample <- 1:total_row: Изберете първия ред до n-тия ред
- if (train ==TRUE){ } else { }: Ако условието е вярно, връща набора от влакове, в противен случай тестовия набор.
Можете да тествате функцията си и да проверите размерите.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE) data_test <- create_train_test(clean_titanic, 0.8, train = FALSE) dim(data_train)
Изход:
## [1] 836 8
dim(data_test)
Изход:
## [1] 209 8
Наборът от данни за влака има 1046 реда, докато тестовият набор от данни има 262 реда.
Използвате функцията prop.table() в комбинация с table(), за да проверите дали процесът на рандомизиране е правилен.
prop.table(table(data_train$survived))
Изход:
## ## No Yes ## 0.5944976 0.4055024
prop.table(table(data_test$survived))
Изход:
## ## No Yes ## 0.5789474 0.4210526
И в двата набора от данни броят на оцелелите е еднакъв, около 40 процента.
Инсталирайте rpart.plot
rpart.plot не се предлага от библиотеките на conda. Можете да го инсталирате от конзолата:
install.packages("rpart.plot")
Стъпка 4) Изградете модела
Готови сте да изградите модела. Синтаксисът за функцията на дървото на решенията Rpart е:
rpart(formula, data=, method='') arguments: - formula: The function to predict - data: Specifies the data frame- method: - "class" for a classification tree - "anova" for a regression tree
Използвате метода клас, защото предвиждате клас.
library(rpart) library(rpart.plot) fit <- rpart(survived~., data = data_train, method = 'class') rpart.plot(fit, extra = 106
Обяснение на кода
- rpart(): Функция за пасване на модела. Аргументите са:
- оцелял ~.: Формула на дърветата на решенията
- данни = data_train: Набор от данни
- method = 'class': Напасване на двоичен модел
- rpart.plot(fit, extra= 106): Начертайте дървото. Допълнителните функции са настроени на 101 за показване на вероятността от 2-ри клас (полезно за двоични отговори). Можете да се обърнете към винетка за повече информация относно другите възможности за избор.
Изход:
Започвате от основния възел (дълбочина 0 върху 3, горната част на графиката):
- Най-отгоре това е общата вероятност за оцеляване. Той показва дела на пътниците, оцелели след катастрофата. 41% от пътниците са оцелели.
- Този възел пита дали полът на пътника е мъж. Ако да, тогава слизате до левия дъщерен възел на корена (дълбочина 2). 63 процента са мъже с вероятност за оцеляване 21 процента.
- Във втория възел питате дали мъжкият пътник е на възраст над 3.5 години. Ако да, тогава шансът за оцеляване е 19 процента.
- Продължавате така, за да разберете кои характеристики влияят върху вероятността за оцеляване.
Имайте предвид, че едно от многото качества на дърветата на решенията е, че те изискват много малко подготовка на данни. По-специално, те не изискват мащабиране или центриране на функции.
По подразбиране функцията rpart() използва Джини мярка за примеси за разделяне на нотата. Колкото по-висок е коефициентът на Джини, толкова повече различни екземпляри в рамките на възела.
Стъпка 5) Направете прогноза
Можете да предвидите набора от тестови данни. За да направите прогноза, можете да използвате функцията predict(). Основният синтаксис на предсказване за дърво на решения R е:
predict(fitted_model, df, type = 'class') arguments: - fitted_model: This is the object stored after model estimation. - df: Data frame used to make the prediction - type: Type of prediction - 'class': for classification - 'prob': to compute the probability of each class - 'vector': Predict the mean response at the node level
Искате да предскажете кои пътници е по-вероятно да оцелеят след сблъсъка от тестовия набор. Това означава, че сред тези 209 пътници ще разберете кой ще оцелее или не.
predict_unseen <-predict(fit, data_test, type = 'class')
Обяснение на кода
- predict(fit, data_test, type = 'class'): Предсказване на класа (0/1) на тестовия набор
Тестване на пътниците, които не са успели, и тези, които са успели.
table_mat <- table(data_test$survived, predict_unseen) table_mat
Обяснение на кода
- table(data_test$survived, predict_unseen): Създайте таблица, за да преброите колко пътници са класифицирани като оцелели и починали в сравнение с правилната класификация на дървото на решенията в R
Изход:
## predict_unseen ## No Yes ## No 106 15 ## Yes 30 58
Моделът правилно прогнозира 106 мъртви пътници, но класифицира 15 оцелели като мъртви. По аналогия, моделът погрешно класифицира 30 пътници като оцелели, докато те се оказват мъртви.
Стъпка 6) Измерете ефективността
Можете да изчислите мярка за точност за задача за класификация с матрица на объркване:
- матрица на объркване е по-добър избор за оценка на ефективността на класификацията. Общата идея е да се преброи колко пъти True екземплярите са класифицирани като False.
Всеки ред в матрицата на объркването представлява действителна цел, докато всяка колона представлява предвидена цел. Първият ред на тази матрица разглежда мъртвите пътници (класът False): 106 са правилно класифицирани като мъртви (Истински отрицателен), докато останалият е погрешно класифициран като оцелял (Фалшиво позитивен). Вторият ред разглежда оцелелите, положителният клас са 58 (Истинско положително), докато Истински отрицателен беше 30.
Можете да изчислите тест за точност от матрицата на объркването:
Това е съотношението на истинските положителни и истинските отрицателни спрямо сумата на матрицата. С R можете да кодирате както следва:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Обяснение на кода
- sum(diag(table_mat)): Сума от диагонала
- sum(table_mat): Сума на матрицата.
Можете да отпечатате точността на тестовия комплект:
print(paste('Accuracy for test', accuracy_Test))
Изход:
## [1] "Accuracy for test 0.784688995215311"
Имате резултат от 78 процента за набора от тестове. Можете да повторите същото упражнение с набора от данни за обучение.
Стъпка 7) Настройте хиперпараметрите
Дървото на решенията в R има различни параметри, които контролират аспектите на съответствието. В библиотеката на дървото на решенията на rpart можете да контролирате параметрите с помощта на функцията rpart.control(). В следния код въвеждате параметрите, които ще настроите. Можете да се обърнете към винетка за други параметри.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30) Arguments: -minsplit: Set the minimum number of observations in the node before the algorithm perform a split -minbucket: Set the minimum number of observations in the final note i.e. the leaf -maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
Ще продължим както следва:
- Конструктивна функция за връщане на точност
- Настройте максималната дълбочина
- Настройте минималния брой проби, които трябва да има възел, преди да може да се раздели
- Настройте минималния брой проби, които листовият възел трябва да има
Можете да напишете функция за показване на точността. Просто обвивате кода, който сте използвали преди:
- прогнозиране: predict_unseen <- predict(fit, data_test, type = 'class')
- Създаване на таблица: table_mat <- table(data_test$survived, predict_unseen)
- Изчисляване на точност: accuracy_Test <- sum(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) { predict_unseen <- predict(fit, data_test, type = 'class') table_mat <- table(data_test$survived, predict_unseen) accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test }
Можете да опитате да настроите параметрите и да видите дали можете да подобрите модела над стойността по подразбиране. Напомняме ви, че трябва да получите точност, по-висока от 0.78
control <- rpart.control(minsplit = 4, minbucket = round(5 / 3), maxdepth = 3, cp = 0) tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control) accuracy_tune(tune_fit)
Изход:
## [1] 0.7990431
Със следния параметър:
minsplit = 4 minbucket= round(5/3) maxdepth = 3cp=0
Получавате по-висока производителност от предишния модел. честито!
Oбобщение
Можем да обобщим функциите, в които да обучим алгоритъм за дърво на решенията R
Библиотека | Цел | функция | клас | параметри | Детайли |
---|---|---|---|---|---|
rpart | Дърво за класификация на влаковете в R | rpart() | клас | формула, df, метод | |
rpart | Обучете регресионно дърво | rpart() | анова | формула, df, метод | |
rpart | Начертайте дърветата | rpart.plot() | втален модел | ||
база | предскаже | прогнозирам () | клас | втален модел, вид | |
база | предскаже | прогнозирам () | проб | втален модел, вид | |
база | предскаже | прогнозирам () | вектор | втален модел, вид | |
rpart | Контролни параметри | rpart.control() | минсплит | Задайте минималния брой наблюдения във възела, преди алгоритъмът да извърши разделяне | |
minbucket | Задайте минималния брой наблюдения в крайната бележка, т.е. листа | ||||
максимална дълбочина | Задайте максималната дълбочина на всеки възел на крайното дърво. Коренният възел се третира с дълбочина 0 | ||||
rpart | Модел на влака с контролен параметър | rpart() | формула, df, метод, контрол |
Забележка: Обучете модела на обучителни данни и тествайте ефективността на невидим набор от данни, т.е. тестов набор.