Дерево рішень у R: Дерево класифікації з прикладом

Що таке дерева рішень?

Дерева рішень є універсальним алгоритмом машинного навчання, який може виконувати як завдання класифікації, так і регресії. Це дуже потужні алгоритми, здатні підбирати складні набори даних. Крім того, дерева рішень є основними компонентами випадкових лісів, які є одними з найпотужніших алгоритмів машинного навчання, доступних сьогодні.

Навчання та візуалізація дерев рішень у R

Щоб побудувати своє перше дерево рішень у прикладі R, ми виконаємо наступні дії в цьому підручнику з дерева рішень:

  • Крок 1. Імпортуйте дані
  • Крок 2. Очистіть набір даних
  • Крок 3: Створіть тренувальний/тестовий набір
  • Крок 4: Побудуйте модель
  • Крок 5: Зробіть прогноз
  • Крок 6. Виміряйте продуктивність
  • Крок 7: Налаштуйте гіперпараметри

Крок 1) Імпортуйте дані

Якщо вам цікава доля титаніка, ви можете переглянути це відео на Youtube. Мета цього набору даних — передбачити, які люди мають більше шансів вижити після зіткнення з айсбергом. Набір даних містить 13 змінних і 1309 спостережень. Набір даних упорядковано за змінною X.

set.seed(678)
path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'
titanic <-read.csv(path)
head(titanic)

вихід:

##   X pclass survived                                            name    sex
## 1 1      1        1                   Allen, Miss. Elisabeth Walton female
## 2 2      1        1                  Allison, Master. Hudson Trevor   male
## 3 3      1        0                    Allison, Miss. Helen Loraine female
## 4 4      1        0            Allison, Mr. Hudson Joshua Creighton   male
## 5 5      1        0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female
## 6 6      1        1                             Anderson, Mr. Harry   male
##       age sibsp parch ticket     fare   cabin embarked
## 1 29.0000     0     0  24160 211.3375      B5        S
## 2  0.9167     1     2 113781 151.5500 C22 C26        S
## 3  2.0000     1     2 113781 151.5500 C22 C26        S
## 4 30.0000     1     2 113781 151.5500 C22 C26        S
## 5 25.0000     1     2 113781 151.5500 C22 C26        S
## 6 48.0000     0     0  19952  26.5500     E12        S
##                         home.dest
## 1                    St Louis, MO
## 2 Montreal, PQ / Chesterville, ON
## 3 Montreal, PQ / Chesterville, ON
## 4 Montreal, PQ / Chesterville, ON
## 5 Montreal, PQ / Chesterville, ON
## 6                    New York, NY
tail(titanic)

вихід:

##         X pclass survived                      name    sex  age sibsp
## 1304 1304      3        0     Yousseff, Mr. Gerious   male   NA     0
## 1305 1305      3        0      Zabour, Miss. Hileni female 14.5     1
## 1306 1306      3        0     Zabour, Miss. Thamine female   NA     1
## 1307 1307      3        0 Zakarian, Mr. Mapriededer   male 26.5     0
## 1308 1308      3        0       Zakarian, Mr. Ortin   male 27.0     0
## 1309 1309      3        0        Zimmerman, Mr. Leo   male 29.0     0
##      parch ticket    fare cabin embarked home.dest
## 1304     0   2627 14.4583              C          
## 1305     0   2665 14.4542              C          
## 1306     0   2665 14.4542              C          
## 1307     0   2656  7.2250              C          
## 1308     0   2670  7.2250              C          
## 1309     0 315082  7.8750              S

З вихідних даних голови та хвоста ви можете помітити, що дані не перетасовані. Це велика проблема! Ви вибираєте, коли ви розділяєте свої дані між набором поїздів і тестовим набором тільки пасажир з класу 1 і 2 (жодного пасажира з класу 3 не входить до 80 відсотків спостережень), що означає, що алгоритм ніколи не побачить характеристики пасажира класу 3. Ця помилка призведе до поганого прогнозу.

Щоб подолати цю проблему, ви можете використовувати функцію sample().

shuffle_index <- sample(1:nrow(titanic))
head(shuffle_index)

Дерево рішень R код Пояснення

  • sample(1:nrow(titanic)): генерує випадковий список індексів від 1 до 1309 (тобто максимальну кількість рядків).

вихід:

## [1]  288  874 1078  633  887  992

Ви будете використовувати цей індекс, щоб перетасувати титанічний набір даних.

titanic <- titanic[shuffle_index, ]
head(titanic)

вихід:

##         X pclass survived
## 288   288      1        0
## 874   874      3        0
## 1078 1078      3        1
## 633   633      3        0
## 887   887      3        1
## 992   992      3        1
##                                                           name    sex age
## 288                                      Sutton, Mr. Frederick   male  61
## 874                   Humblen, Mr. Adolf Mathias Nicolai Olsen   male  42
## 1078                                 O'Driscoll, Miss. Bridget female  NA
## 633  Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female  39
## 887                                        Jermyn, Miss. Annie female  NA
## 992                                           Mamee, Mr. Hanna   male  NA
##      sibsp parch ticket    fare cabin embarked           home.dest## 288      0     0  36963 32.3208   D50        S     Haddenfield, NJ
## 874      0     0 348121  7.6500 F G63        S                    
## 1078     0     0  14311  7.7500              Q                    
## 633      1     5 347082 31.2750              S Sweden Winnipeg, MN
## 887      0     0  14313  7.7500              Q                    
## 992      0     0   2677  7.2292              C	

Крок 2) Очистіть набір даних

Структура даних показує, що деякі змінні мають NA. Очищення даних необхідно виконати наступним чином

  • Витягніть змінні home.dest,cabin, name, X і квиток
  • Створіть факторні змінні для pclass і survived
  • Відпустіть НС
library(dplyr)
# Drop variables
clean_titanic <- titanic % > %
select(-c(home.dest, cabin, name, X, ticket)) % > % 
#Convert to factor level
	mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),
	survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %
na.omit()
glimpse(clean_titanic)

Пояснення коду

  • select(-c(home.dest, cabin, name, X, ticket)): видалити непотрібні змінні
  • pclass = factor(pclass, levels = c(1,2,3), labels= c('Upper', 'Middle', 'Lower')): Додайте мітку до змінної pclass. 1 стає верхнім, 2 стає середнім і 3 стає нижнім
  • factor(survived, levels = c(0,1), labels = c('No', 'Yes')): Додайте мітку до змінної survived. 1 стає Ні, а 2 стає Так
  • na.omit(): Видалити спостереження NA

вихід:

## Observations: 1,045
## Variables: 8
## $ pclass   <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U...
## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y...
## $ sex      <fctr> male, male, female, female, male, male, female, male...
## $ age      <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ...
## $ sibsp    <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,...
## $ parch    <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,...
## $ fare     <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ...
## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...		

Крок 3) Створіть навчальний/тестовий набір

Перш ніж навчити свою модель, вам потрібно виконати два кроки:

  • Створення тренування та тестового набору: Ви навчаєте модель на тренуванні та перевіряєте прогноз на тестовому наборі (тобто невидимі дані).
  • Встановіть rpart.plot з консолі

Загальноприйнятою практикою є розділення даних 80/20, 80 відсотків даних служать для навчання моделі, а 20 відсотків — для прогнозування. Вам потрібно створити два окремих кадри даних. Ви не хочете торкатися тестового набору, доки не завершите створення своєї моделі. Ви можете створити назву функції create_train_test(), яка приймає три аргументи.

create_train_test(df, size = 0.8, train = TRUE)
arguments:
-df: Dataset used to train the model.
-size: Size of the split. By default, 0.8. Numerical value
-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {
    n_row = nrow(data)
    total_row = size * n_row
    train_sample < - 1: total_row
    if (train == TRUE) {
        return (data[train_sample, ])
    } else {
        return (data[-train_sample, ])
    }
}

Пояснення коду

  • function(data, size=0.8, train = TRUE): додайте аргументи у функцію
  • n_row = nrow(data): підрахувати кількість рядків у наборі даних
  • total_row = size*n_row: повертає n-й рядок, щоб побудувати набір поїздів
  • train_sample <- 1:total_row: Виберіть перший рядок до n-го рядків
  • if (train ==TRUE){ } else { }: якщо умова має значення true, повертає набір поїздів, інакше тестовий набір.

Ви можете перевірити свою функцію та перевірити розмір.

data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)
data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)
dim(data_train)

вихід:

## [1] 836   8
dim(data_test)

вихід:

## [1] 209   8

Набір даних поїзда містить 1046 рядків, тоді як тестовий набір даних має 262 рядки.

Ви використовуєте функцію prop.table() у поєднанні з table(), щоб перевірити правильність процесу рандомізації.

prop.table(table(data_train$survived))

вихід:

##
##        No       Yes 
## 0.5944976 0.4055024
prop.table(table(data_test$survived))

вихід:

## 
##        No       Yes 
## 0.5789474 0.4210526

В обох наборах даних кількість вижили однакова, близько 40 відсотків.

Встановіть rpart.plot

rpart.plot недоступний з бібліотек conda. Ви можете встановити його з консолі:

install.packages("rpart.plot")

Крок 4) Побудуйте модель

Ви готові побудувати модель. Синтаксис функції дерева рішень Rpart:

rpart(formula, data=, method='')
arguments:			
- formula: The function to predict
- data: Specifies the data frame- method: 			
- "class" for a classification tree 			
- "anova" for a regression tree	

Ви використовуєте метод класу, оскільки ви прогнозуєте клас.

library(rpart)
library(rpart.plot)
fit <- rpart(survived~., data = data_train, method = 'class')
rpart.plot(fit, extra = 106

Пояснення коду

  • rpart(): функція для підгонки моделі. Аргументи:
    • пережив ~.: Формула дерев рішень
    • data = data_train: Набір даних
    • method = 'class': підібрати бінарну модель
  • rpart.plot(fit, extra= 106): побудувати дерево. Для додаткових функцій встановлено значення 101, щоб відобразити ймовірність 2-го класу (корисно для двійкових відповідей). Ви можете звернутися до віньєтка для отримання додаткової інформації про інші варіанти.

вихід:

Побудуйте модель дерев рішень у R

Ви починаєте з кореневого вузла (глибина 0 на 3, верхня частина графіка):

  1. На вершині це загальна ймовірність виживання. Він показує частку пасажирів, які вижили в аварії. 41 відсоток пасажирів вижив.
  2. Цей вузол запитує, чи є стать пасажира чоловічої статі. Якщо так, то ви спускаєтеся до лівого дочірнього вузла кореня (глибина 2). 63 відсотки — самці з ймовірністю виживання 21 відсоток.
  3. У другому вузлі ви запитуєте, чи пасажиру чоловічої статі більше 3.5 років. Якщо так, то шанс вижити становить 19 відсотків.
  4. Ви продовжуєте так, щоб зрозуміти, які особливості впливають на ймовірність виживання.

Зауважте, що однією з багатьох властивостей дерев рішень є те, що вони вимагають дуже мало підготовки даних. Зокрема, вони не вимагають масштабування або центрування функцій.

За замовчуванням функція rpart() використовує Джині міра домішки, щоб розділити ноту. Чим вищий коефіцієнт Джині, тим більше різних екземплярів у вузлі.

Крок 5) Зробіть прогноз

Ви можете передбачити свій тестовий набір даних. Щоб зробити прогноз, ви можете скористатися функцією predict(). Основний синтаксис прогнозу для дерева рішень R:

predict(fitted_model, df, type = 'class')
arguments:
- fitted_model: This is the object stored after model estimation. 
- df: Data frame used to make the prediction
- type: Type of prediction			
    - 'class': for classification			
    - 'prob': to compute the probability of each class			
    - 'vector': Predict the mean response at the node level	

Ви хочете передбачити, які пасажири з більшою ймовірністю виживуть після зіткнення з тестового набору. Тобто серед тих 209 пасажирів ви будете знати, хто виживе чи ні.

predict_unseen <-predict(fit, data_test, type = 'class')

Пояснення коду

  • predict(fit, data_test, type = 'class'): передбачити клас (0/1) тестового набору

Тестування пасажира, який не встиг, і тих, хто встиг.

table_mat <- table(data_test$survived, predict_unseen)
table_mat

Пояснення коду

  • table(data_test$survived, predict_unseen): Створіть таблицю для підрахунку кількості пасажирів, класифікованих як вижили та померлих, у порівнянні з правильною класифікацією дерева рішень у R

вихід:

##      predict_unseen
##        No Yes
##   No  106  15
##   Yes  30  58

Модель правильно передбачила 106 загиблих пасажирів, але класифікувала 15 тих, хто вижив, як мертвих. За аналогією, модель неправильно класифікувала 30 пасажирів як тих, хто вижив, хоча вони виявилися мертвими.

Крок 6) Виміряйте продуктивність

Ви можете обчислити міру точності для завдання класифікації за допомогою матриця плутанини:

Команда матриця плутанини є кращим вибором для оцінки ефективності класифікації. Загальна ідея полягає в тому, щоб підрахувати кількість разів, коли випадки True класифікуються як False.

Вимірювання продуктивності дерев рішень у R

Кожен рядок у матриці плутанини представляє фактичну ціль, тоді як кожен стовпець представляє прогнозовану ціль. Перший рядок цієї матриці враховує загиблих пасажирів (клас False): 106 було правильно класифіковано як загиблі (Справжній негатив), тоді як решта була помилково класифікована як вижила (Хибно позитивний). Другий рядок враховує тих, хто вижив, позитивний клас склав 58 (Справжній позитив), У той час як Справжній негатив було 30.

Ви можете обчислити тест на точність з матриці плутанини:

Вимірювання продуктивності дерев рішень у R

Це частка істинно позитивного та істинно негативного значення в сумі матриці. За допомогою R ви можете кодувати наступним чином:

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)

Пояснення коду

  • sum(diag(table_mat)): сума діагоналі
  • sum(table_mat): сума матриці.

Ви можете роздрукувати точність тестового набору:

print(paste('Accuracy for test', accuracy_Test))

вихід:

## [1] "Accuracy for test 0.784688995215311"

У вас є 78 відсотків балів за набір тестів. Ви можете відтворити ту саму вправу за допомогою навчального набору даних.

Крок 7) Налаштуйте гіперпараметри

Дерево рішень у R має різні параметри, які контролюють аспекти відповідності. У бібліотеці дерева рішень rpart ви можете керувати параметрами за допомогою функції rpart.control(). У наступному коді ви вводите параметри, які ви будете налаштовувати. Ви можете звернутися до віньєтка для інших параметрів.

rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)
Arguments:
-minsplit: Set the minimum number of observations in the node before the algorithm perform a split
-minbucket:  Set the minimum number of observations in the final note i.e. the leaf
-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0

Ми будемо діяти наступним чином:

  • Побудова функції для повернення точності
  • Налаштувати максимальну глибину
  • Налаштуйте мінімальну кількість зразків, які повинен мати вузол, перш ніж його можна буде розділити
  • Налаштуйте мінімальну кількість зразків, які повинен мати листовий вузол

Ви можете написати функцію для відображення точності. Ви просто обертаєте код, який використовували раніше:

  1. predict: predict_unseen <- predict(fit, data_test, type = 'class')
  2. Створення таблиці: table_mat <- table(data_test$survived, predict_unseen)
  3. Точність обчислення: accuracy_Test <- sum(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) {
    predict_unseen <- predict(fit, data_test, type = 'class')
    table_mat <- table(data_test$survived, predict_unseen)
    accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
    accuracy_Test
}

Ви можете спробувати налаштувати параметри та перевірити, чи зможете ви покращити модель порівняно зі значенням за замовчуванням. Нагадуємо, що вам потрібна точність вище 0.78

control <- rpart.control(minsplit = 4,
    minbucket = round(5 / 3),
    maxdepth = 3,
    cp = 0)
tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)
accuracy_tune(tune_fit)

вихід:

## [1] 0.7990431

З таким параметром:

minsplit = 4
minbucket= round(5/3)
maxdepth = 3cp=0

Ви отримуєте вищу продуктивність, ніж у попередньої моделі. Вітаю!

Підсумки

Ми можемо узагальнити функції для навчання алгоритму дерева рішень R

Library Мета функція Клас параметри ПОДРОБИЦІ
rpart Дерево класифікації поїздів у R rpart() клас формула, df, метод
rpart Дерево тренування регресії rpart() anova формула, df, метод
rpart Накресліть дерева rpart.plot() приталена модель
база передбачати передбачити() клас приталена модель, тип
база передбачати передбачити() проб приталена модель, тип
база передбачати передбачити() вектор приталена модель, тип
rpart Параметри управління rpart.control() minsplit Встановіть мінімальну кількість спостережень у вузлі, перш ніж алгоритм виконає розбиття
minbucket Встановіть мінімальну кількість спостережень у фінальній ноті, тобто аркуші
максимальна глибина Встановіть максимальну глибину будь-якого вузла кінцевого дерева. Кореневий вузол обробляється на глибину 0
rpart Модель поїзда з контрольним параметром rpart() формула, df, метод, контроль

Примітка. Навчіть модель на навчальних даних і перевірте продуктивність на невидимому наборі даних, тобто тестовому наборі.