Дерево рішень у R: Дерево класифікації з прикладом
Що таке дерева рішень?
Дерева рішень є універсальним алгоритмом машинного навчання, який може виконувати як завдання класифікації, так і регресії. Це дуже потужні алгоритми, здатні підбирати складні набори даних. Крім того, дерева рішень є основними компонентами випадкових лісів, які є одними з найпотужніших алгоритмів машинного навчання, доступних сьогодні.
Навчання та візуалізація дерев рішень у R
Щоб побудувати своє перше дерево рішень у прикладі R, ми виконаємо наступні дії в цьому підручнику з дерева рішень:
- Крок 1. Імпортуйте дані
- Крок 2. Очистіть набір даних
- Крок 3: Створіть тренувальний/тестовий набір
- Крок 4: Побудуйте модель
- Крок 5: Зробіть прогноз
- Крок 6. Виміряйте продуктивність
- Крок 7: Налаштуйте гіперпараметри
Крок 1) Імпортуйте дані
Якщо вам цікава доля титаніка, ви можете переглянути це відео на Youtube. Мета цього набору даних — передбачити, які люди мають більше шансів вижити після зіткнення з айсбергом. Набір даних містить 13 змінних і 1309 спостережень. Набір даних упорядковано за змінною X.
set.seed(678) path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv' titanic <-read.csv(path) head(titanic)
вихід:
## X pclass survived name sex ## 1 1 1 1 Allen, Miss. Elisabeth Walton female ## 2 2 1 1 Allison, Master. Hudson Trevor male ## 3 3 1 0 Allison, Miss. Helen Loraine female ## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male ## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female ## 6 6 1 1 Anderson, Mr. Harry male ## age sibsp parch ticket fare cabin embarked ## 1 29.0000 0 0 24160 211.3375 B5 S ## 2 0.9167 1 2 113781 151.5500 C22 C26 S ## 3 2.0000 1 2 113781 151.5500 C22 C26 S ## 4 30.0000 1 2 113781 151.5500 C22 C26 S ## 5 25.0000 1 2 113781 151.5500 C22 C26 S ## 6 48.0000 0 0 19952 26.5500 E12 S ## home.dest ## 1 St Louis, MO ## 2 Montreal, PQ / Chesterville, ON ## 3 Montreal, PQ / Chesterville, ON ## 4 Montreal, PQ / Chesterville, ON ## 5 Montreal, PQ / Chesterville, ON ## 6 New York, NY
tail(titanic)
вихід:
## X pclass survived name sex age sibsp ## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0 ## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1 ## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1 ## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0 ## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0 ## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0 ## parch ticket fare cabin embarked home.dest ## 1304 0 2627 14.4583 C ## 1305 0 2665 14.4542 C ## 1306 0 2665 14.4542 C ## 1307 0 2656 7.2250 C ## 1308 0 2670 7.2250 C ## 1309 0 315082 7.8750 S
З вихідних даних голови та хвоста ви можете помітити, що дані не перетасовані. Це велика проблема! Ви вибираєте, коли ви розділяєте свої дані між набором поїздів і тестовим набором тільки пасажир з класу 1 і 2 (жодного пасажира з класу 3 не входить до 80 відсотків спостережень), що означає, що алгоритм ніколи не побачить характеристики пасажира класу 3. Ця помилка призведе до поганого прогнозу.
Щоб подолати цю проблему, ви можете використовувати функцію sample().
shuffle_index <- sample(1:nrow(titanic)) head(shuffle_index)
Дерево рішень R код Пояснення
- sample(1:nrow(titanic)): генерує випадковий список індексів від 1 до 1309 (тобто максимальну кількість рядків).
вихід:
## [1] 288 874 1078 633 887 992
Ви будете використовувати цей індекс, щоб перетасувати титанічний набір даних.
titanic <- titanic[shuffle_index, ] head(titanic)
вихід:
## X pclass survived ## 288 288 1 0 ## 874 874 3 0 ## 1078 1078 3 1 ## 633 633 3 0 ## 887 887 3 1 ## 992 992 3 1 ## name sex age ## 288 Sutton, Mr. Frederick male 61 ## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42 ## 1078 O'Driscoll, Miss. Bridget female NA ## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39 ## 887 Jermyn, Miss. Annie female NA ## 992 Mamee, Mr. Hanna male NA ## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ ## 874 0 0 348121 7.6500 F G63 S ## 1078 0 0 14311 7.7500 Q ## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN ## 887 0 0 14313 7.7500 Q ## 992 0 0 2677 7.2292 C
Крок 2) Очистіть набір даних
Структура даних показує, що деякі змінні мають NA. Очищення даних необхідно виконати наступним чином
- Витягніть змінні home.dest,cabin, name, X і квиток
- Створіть факторні змінні для pclass і survived
- Відпустіть НС
library(dplyr) # Drop variables clean_titanic <- titanic % > % select(-c(home.dest, cabin, name, X, ticket)) % > % #Convert to factor level mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')), survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > % na.omit() glimpse(clean_titanic)
Пояснення коду
- select(-c(home.dest, cabin, name, X, ticket)): видалити непотрібні змінні
- pclass = factor(pclass, levels = c(1,2,3), labels= c('Upper', 'Middle', 'Lower')): Додайте мітку до змінної pclass. 1 стає верхнім, 2 стає середнім і 3 стає нижнім
- factor(survived, levels = c(0,1), labels = c('No', 'Yes')): Додайте мітку до змінної survived. 1 стає Ні, а 2 стає Так
- na.omit(): Видалити спостереження NA
вихід:
## Observations: 1,045 ## Variables: 8 ## $ pclass <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U... ## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y... ## $ sex <fctr> male, male, female, female, male, male, female, male... ## $ age <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ... ## $ sibsp <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,... ## $ parch <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,... ## $ fare <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ... ## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...
Крок 3) Створіть навчальний/тестовий набір
Перш ніж навчити свою модель, вам потрібно виконати два кроки:
- Створення тренування та тестового набору: Ви навчаєте модель на тренуванні та перевіряєте прогноз на тестовому наборі (тобто невидимі дані).
- Встановіть rpart.plot з консолі
Загальноприйнятою практикою є розділення даних 80/20, 80 відсотків даних служать для навчання моделі, а 20 відсотків — для прогнозування. Вам потрібно створити два окремих кадри даних. Ви не хочете торкатися тестового набору, доки не завершите створення своєї моделі. Ви можете створити назву функції create_train_test(), яка приймає три аргументи.
create_train_test(df, size = 0.8, train = TRUE) arguments: -df: Dataset used to train the model. -size: Size of the split. By default, 0.8. Numerical value -train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample < - 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } }
Пояснення коду
- function(data, size=0.8, train = TRUE): додайте аргументи у функцію
- n_row = nrow(data): підрахувати кількість рядків у наборі даних
- total_row = size*n_row: повертає n-й рядок, щоб побудувати набір поїздів
- train_sample <- 1:total_row: Виберіть перший рядок до n-го рядків
- if (train ==TRUE){ } else { }: якщо умова має значення true, повертає набір поїздів, інакше тестовий набір.
Ви можете перевірити свою функцію та перевірити розмір.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE) data_test <- create_train_test(clean_titanic, 0.8, train = FALSE) dim(data_train)
вихід:
## [1] 836 8
dim(data_test)
вихід:
## [1] 209 8
Набір даних поїзда містить 1046 рядків, тоді як тестовий набір даних має 262 рядки.
Ви використовуєте функцію prop.table() у поєднанні з table(), щоб перевірити правильність процесу рандомізації.
prop.table(table(data_train$survived))
вихід:
## ## No Yes ## 0.5944976 0.4055024
prop.table(table(data_test$survived))
вихід:
## ## No Yes ## 0.5789474 0.4210526
В обох наборах даних кількість вижили однакова, близько 40 відсотків.
Встановіть rpart.plot
rpart.plot недоступний з бібліотек conda. Ви можете встановити його з консолі:
install.packages("rpart.plot")
Крок 4) Побудуйте модель
Ви готові побудувати модель. Синтаксис функції дерева рішень Rpart:
rpart(formula, data=, method='') arguments: - formula: The function to predict - data: Specifies the data frame- method: - "class" for a classification tree - "anova" for a regression tree
Ви використовуєте метод класу, оскільки ви прогнозуєте клас.
library(rpart) library(rpart.plot) fit <- rpart(survived~., data = data_train, method = 'class') rpart.plot(fit, extra = 106
Пояснення коду
- rpart(): функція для підгонки моделі. Аргументи:
- пережив ~.: Формула дерев рішень
- data = data_train: Набір даних
- method = 'class': підібрати бінарну модель
- rpart.plot(fit, extra= 106): побудувати дерево. Для додаткових функцій встановлено значення 101, щоб відобразити ймовірність 2-го класу (корисно для двійкових відповідей). Ви можете звернутися до віньєтка для отримання додаткової інформації про інші варіанти.
вихід:
Ви починаєте з кореневого вузла (глибина 0 на 3, верхня частина графіка):
- На вершині це загальна ймовірність виживання. Він показує частку пасажирів, які вижили в аварії. 41 відсоток пасажирів вижив.
- Цей вузол запитує, чи є стать пасажира чоловічої статі. Якщо так, то ви спускаєтеся до лівого дочірнього вузла кореня (глибина 2). 63 відсотки — самці з ймовірністю виживання 21 відсоток.
- У другому вузлі ви запитуєте, чи пасажиру чоловічої статі більше 3.5 років. Якщо так, то шанс вижити становить 19 відсотків.
- Ви продовжуєте так, щоб зрозуміти, які особливості впливають на ймовірність виживання.
Зауважте, що однією з багатьох властивостей дерев рішень є те, що вони вимагають дуже мало підготовки даних. Зокрема, вони не вимагають масштабування або центрування функцій.
За замовчуванням функція rpart() використовує Джині міра домішки, щоб розділити ноту. Чим вищий коефіцієнт Джині, тим більше різних екземплярів у вузлі.
Крок 5) Зробіть прогноз
Ви можете передбачити свій тестовий набір даних. Щоб зробити прогноз, ви можете скористатися функцією predict(). Основний синтаксис прогнозу для дерева рішень R:
predict(fitted_model, df, type = 'class') arguments: - fitted_model: This is the object stored after model estimation. - df: Data frame used to make the prediction - type: Type of prediction - 'class': for classification - 'prob': to compute the probability of each class - 'vector': Predict the mean response at the node level
Ви хочете передбачити, які пасажири з більшою ймовірністю виживуть після зіткнення з тестового набору. Тобто серед тих 209 пасажирів ви будете знати, хто виживе чи ні.
predict_unseen <-predict(fit, data_test, type = 'class')
Пояснення коду
- predict(fit, data_test, type = 'class'): передбачити клас (0/1) тестового набору
Тестування пасажира, який не встиг, і тих, хто встиг.
table_mat <- table(data_test$survived, predict_unseen) table_mat
Пояснення коду
- table(data_test$survived, predict_unseen): Створіть таблицю для підрахунку кількості пасажирів, класифікованих як вижили та померлих, у порівнянні з правильною класифікацією дерева рішень у R
вихід:
## predict_unseen ## No Yes ## No 106 15 ## Yes 30 58
Модель правильно передбачила 106 загиблих пасажирів, але класифікувала 15 тих, хто вижив, як мертвих. За аналогією, модель неправильно класифікувала 30 пасажирів як тих, хто вижив, хоча вони виявилися мертвими.
Крок 6) Виміряйте продуктивність
Ви можете обчислити міру точності для завдання класифікації за допомогою матриця плутанини:
Команда матриця плутанини є кращим вибором для оцінки ефективності класифікації. Загальна ідея полягає в тому, щоб підрахувати кількість разів, коли випадки True класифікуються як False.
Кожен рядок у матриці плутанини представляє фактичну ціль, тоді як кожен стовпець представляє прогнозовану ціль. Перший рядок цієї матриці враховує загиблих пасажирів (клас False): 106 було правильно класифіковано як загиблі (Справжній негатив), тоді як решта була помилково класифікована як вижила (Хибно позитивний). Другий рядок враховує тих, хто вижив, позитивний клас склав 58 (Справжній позитив), У той час як Справжній негатив було 30.
Ви можете обчислити тест на точність з матриці плутанини:
Це частка істинно позитивного та істинно негативного значення в сумі матриці. За допомогою R ви можете кодувати наступним чином:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Пояснення коду
- sum(diag(table_mat)): сума діагоналі
- sum(table_mat): сума матриці.
Ви можете роздрукувати точність тестового набору:
print(paste('Accuracy for test', accuracy_Test))
вихід:
## [1] "Accuracy for test 0.784688995215311"
У вас є 78 відсотків балів за набір тестів. Ви можете відтворити ту саму вправу за допомогою навчального набору даних.
Крок 7) Налаштуйте гіперпараметри
Дерево рішень у R має різні параметри, які контролюють аспекти відповідності. У бібліотеці дерева рішень rpart ви можете керувати параметрами за допомогою функції rpart.control(). У наступному коді ви вводите параметри, які ви будете налаштовувати. Ви можете звернутися до віньєтка для інших параметрів.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30) Arguments: -minsplit: Set the minimum number of observations in the node before the algorithm perform a split -minbucket: Set the minimum number of observations in the final note i.e. the leaf -maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
Ми будемо діяти наступним чином:
- Побудова функції для повернення точності
- Налаштувати максимальну глибину
- Налаштуйте мінімальну кількість зразків, які повинен мати вузол, перш ніж його можна буде розділити
- Налаштуйте мінімальну кількість зразків, які повинен мати листовий вузол
Ви можете написати функцію для відображення точності. Ви просто обертаєте код, який використовували раніше:
- predict: predict_unseen <- predict(fit, data_test, type = 'class')
- Створення таблиці: table_mat <- table(data_test$survived, predict_unseen)
- Точність обчислення: accuracy_Test <- sum(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) { predict_unseen <- predict(fit, data_test, type = 'class') table_mat <- table(data_test$survived, predict_unseen) accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test }
Ви можете спробувати налаштувати параметри та перевірити, чи зможете ви покращити модель порівняно зі значенням за замовчуванням. Нагадуємо, що вам потрібна точність вище 0.78
control <- rpart.control(minsplit = 4, minbucket = round(5 / 3), maxdepth = 3, cp = 0) tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control) accuracy_tune(tune_fit)
вихід:
## [1] 0.7990431
З таким параметром:
minsplit = 4 minbucket= round(5/3) maxdepth = 3cp=0
Ви отримуєте вищу продуктивність, ніж у попередньої моделі. Вітаю!
Підсумки
Ми можемо узагальнити функції для навчання алгоритму дерева рішень R
Library | Мета | функція | Клас | параметри | ПОДРОБИЦІ |
---|---|---|---|---|---|
rpart | Дерево класифікації поїздів у R | rpart() | клас | формула, df, метод | |
rpart | Дерево тренування регресії | rpart() | anova | формула, df, метод | |
rpart | Накресліть дерева | rpart.plot() | приталена модель | ||
база | передбачати | передбачити() | клас | приталена модель, тип | |
база | передбачати | передбачити() | проб | приталена модель, тип | |
база | передбачати | передбачити() | вектор | приталена модель, тип | |
rpart | Параметри управління | rpart.control() | minsplit | Встановіть мінімальну кількість спостережень у вузлі, перш ніж алгоритм виконає розбиття | |
minbucket | Встановіть мінімальну кількість спостережень у фінальній ноті, тобто аркуші | ||||
максимальна глибина | Встановіть максимальну глибину будь-якого вузла кінцевого дерева. Кореневий вузол обробляється на глибину 0 | ||||
rpart | Модель поїзда з контрольним параметром | rpart() | формула, df, метод, контроль |
Примітка. Навчіть модель на навчальних даних і перевірте продуктивність на невидимому наборі даних, тобто тестовому наборі.