Entscheidungsbaum in R: Klassifizierungsbaum mit Beispiel

Was sind Entscheidungsbรคume?

Entscheidungsbรคume sind vielseitige Machine-Learning-Algorithmen, die sowohl Klassifizierungs- als auch Regressionsaufgaben ausfรผhren kรถnnen. Es sind sehr leistungsstarke Algorithmen, die in der Lage sind, komplexe Datensรคtze anzupassen. Darรผber hinaus sind Entscheidungsbรคume grundlegende Komponenten von Random Forests, die zu den leistungsfรคhigsten Machine-Learning-Algorithmen gehรถren, die heute verfรผgbar sind.

Training und Visualisierung eines Entscheidungsbaums in R

Um Ihren ersten Entscheidungsbaum im R-Beispiel zu erstellen, gehen wir in diesem Entscheidungsbaum-Tutorial wie folgt vor:

  • Schritt 1: Importieren Sie die Daten
  • Schritt 2: Bereinigen Sie den Datensatz
  • Schritt 3: Zug-/Testset erstellen
  • Schritt 4: Erstellen Sie das Modell
  • Schritt 5: Machen Sie eine Vorhersage
  • Schritt 6: Leistung messen
  • Schritt 7: Optimieren Sie die Hyperparameter

Schritt 1) โ€‹โ€‹Importieren Sie die Daten

Wenn Sie neugierig auf das Schicksal der Titanic sind, kรถnnen Sie sich dieses Video ansehen Youtube. Der Zweck dieses Datensatzes besteht darin, vorherzusagen, welche Menschen nach der Kollision mit dem Eisberg mit grรถรŸerer Wahrscheinlichkeit รผberleben. Der Datensatz enthรคlt 13 Variablen und 1309 Beobachtungen. Der Datensatz wird nach der Variablen X geordnet.

set.seed(678)
path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'
titanic <-read.csv(path)
head(titanic)

Ausgang:

##   X pclass survived                                            name    sex
## 1 1      1        1                   Allen, Miss. Elisabeth Walton female
## 2 2      1        1                  Allison, Master. Hudson Trevor   male
## 3 3      1        0                    Allison, Miss. Helen Loraine female
## 4 4      1        0            Allison, Mr. Hudson Joshua Creighton   male
## 5 5      1        0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female
## 6 6      1        1                             Anderson, Mr. Harry   male
##       age sibsp parch ticket     fare   cabin embarked
## 1 29.0000     0     0  24160 211.3375      B5        S
## 2  0.9167     1     2 113781 151.5500 C22 C26        S
## 3  2.0000     1     2 113781 151.5500 C22 C26        S
## 4 30.0000     1     2 113781 151.5500 C22 C26        S
## 5 25.0000     1     2 113781 151.5500 C22 C26        S
## 6 48.0000     0     0  19952  26.5500     E12        S
##                         home.dest
## 1                    St Louis, MO
## 2 Montreal, PQ / Chesterville, ON
## 3 Montreal, PQ / Chesterville, ON
## 4 Montreal, PQ / Chesterville, ON
## 5 Montreal, PQ / Chesterville, ON
## 6                    New York, NY
tail(titanic)

Ausgang:

##         X pclass survived                      name    sex  age sibsp
## 1304 1304      3        0     Yousseff, Mr. Gerious   male   NA     0
## 1305 1305      3        0      Zabour, Miss. Hileni female 14.5     1
## 1306 1306      3        0     Zabour, Miss. Thamine female   NA     1
## 1307 1307      3        0 Zakarian, Mr. Mapriededer   male 26.5     0
## 1308 1308      3        0       Zakarian, Mr. Ortin   male 27.0     0
## 1309 1309      3        0        Zimmerman, Mr. Leo   male 29.0     0
##      parch ticket    fare cabin embarked home.dest
## 1304     0   2627 14.4583              C          
## 1305     0   2665 14.4542              C          
## 1306     0   2665 14.4542              C          
## 1307     0   2656  7.2250              C          
## 1308     0   2670  7.2250              C          
## 1309     0 315082  7.8750              S

Anhand der Head- und Tail-Ausgabe kรถnnen Sie erkennen, dass die Daten nicht gemischt werden. Das ist ein groรŸes Problem! Wenn Sie Ihre Daten zwischen einem Zugsatz und einem Testsatz aufteilen, wรคhlen Sie aus einzige der Passagier der Klassen 1 und 2 (Kein Passagier der Klasse 3 befindet sich in den oberen 80 Prozent der Beobachtungen), was bedeutet, dass der Algorithmus niemals die Merkmale des Passagiers der Klasse 3 sehen wird. Dieser Fehler fรผhrt zu einer schlechten Vorhersage.

Um dieses Problem zu lรถsen, kรถnnen Sie die Funktion sample() verwenden.

shuffle_index <- sample(1:nrow(titanic))
head(shuffle_index)

Entscheidungsbaum-R-Code-Erklรคrung

  • sample(1:nrow(titanic)): Generieren Sie eine zufรคllige Indexliste von 1 bis 1309 (dh die maximale Anzahl von Zeilen).

Ausgang:

## [1]  288  874 1078  633  887  992

Sie werden diesen Index verwenden, um den Titanic-Datensatz zu mischen.

titanic <- titanic[shuffle_index, ]
head(titanic)

Ausgang:

##         X pclass survived
## 288   288      1        0
## 874   874      3        0
## 1078 1078      3        1
## 633   633      3        0
## 887   887      3        1
## 992   992      3        1
##                                                           name    sex age
## 288                                      Sutton, Mr. Frederick   male  61
## 874                   Humblen, Mr. Adolf Mathias Nicolai Olsen   male  42
## 1078                                 O'Driscoll, Miss. Bridget female  NA
## 633  Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female  39
## 887                                        Jermyn, Miss. Annie female  NA
## 992                                           Mamee, Mr. Hanna   male  NA
##      sibsp parch ticket    fare cabin embarked           home.dest## 288      0     0  36963 32.3208   D50        S     Haddenfield, NJ
## 874      0     0 348121  7.6500 F G63        S                    
## 1078     0     0  14311  7.7500              Q                    
## 633      1     5 347082 31.2750              S Sweden Winnipeg, MN
## 887      0     0  14313  7.7500              Q                    
## 992      0     0   2677  7.2292              C	

Schritt 2) Bereinigen Sie den Datensatz

Die Struktur der Daten zeigt, dass einige Variablen NAs haben. Die Datenbereinigung muss wie folgt durchgefรผhrt werden

  • Lรถschen Sie die Variablen home.dest,cabin, name, X und ticket
  • Erstellen Sie Faktorvariablen fรผr PCLASS und Surved
  • Lass die NA fallen
library(dplyr)
# Drop variables
clean_titanic <- titanic % > %
select(-c(home.dest, cabin, name, X, ticket)) % > % 
#Convert to factor level
	mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),
	survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %
na.omit()
glimpse(clean_titanic)

Code Erklรคrung

  • select(-c(home.dest, Kabine, Name, X, Ticket)): Lรถschen Sie unnรถtige Variablen
  • pclass = Faktor(pclass, Ebenen = c(1,2,3), Labels= c('Upper', 'Middle', 'Lower')): Beschriftung zur Variablen pclass hinzufรผgen. 1 wird zum oberen Wert, 2 wird zum mittleren Wert und 3 wird zum niedrigeren Wert
  • Faktor(รผberlebt, Ebenen = c(0,1), Labels = c('Nein', 'Ja')): Fรผge Label zur Variablen รผberlebt hinzu. 1 wird zu Nein und 2 wird zu Ja
  • na.omit(): Entfernen Sie die NA-Beobachtungen

Ausgang:

## Observations: 1,045
## Variables: 8
## $ pclass   <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U...
## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y...
## $ sex      <fctr> male, male, female, female, male, male, female, male...
## $ age      <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ...
## $ sibsp    <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,...
## $ parch    <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,...
## $ fare     <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ...
## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...		

Schritt 3) Erstellen Sie einen Zug-/Testsatz

Bevor Sie Ihr Modell trainieren, mรผssen Sie zwei Schritte ausfรผhren:

  • Erstellen Sie einen Zug- und Testsatz: Sie trainieren das Modell auf dem Zugsatz und testen die Vorhersage auf dem Testsatz (d. h. unsichtbaren Daten).
  • Installieren Sie rpart.plot รผber die Konsole

Die gรคngige Praxis besteht darin, die Daten im Verhรคltnis 80/20 aufzuteilen, wobei 80 Prozent der Daten zum Trainieren des Modells und 20 Prozent zur Erstellung von Vorhersagen dienen. Sie mรผssen zwei separate Datenrahmen erstellen. Sie mรถchten das Testset erst dann anfassen, wenn Sie mit der Erstellung Ihres Modells fertig sind. Sie kรถnnen einen Funktionsnamen create_train_test() erstellen, der drei Argumente akzeptiert.

create_train_test(df, size = 0.8, train = TRUE)
arguments:
-df: Dataset used to train the model.
-size: Size of the split. By default, 0.8. Numerical value
-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {
    n_row = nrow(data)
    total_row = size * n_row
    train_sample < - 1: total_row
    if (train == TRUE) {
        return (data[train_sample, ])
    } else {
        return (data[-train_sample, ])
    }
}

Code Erklรคrung

  • function(data, size=0.8, train = TRUE): Fรผgen Sie die Argumente in der Funktion hinzu
  • n_row = nrow(data): Zรคhlt die Anzahl der Zeilen im Datensatz
  • total_row = size*n_row: Gibt die n-te Zeile zurรผck, um den Zugsatz zu erstellen
  • train_sample <- 1:total_row: Wรคhlen Sie die erste bis n-te Zeile aus
  • if (train ==TRUE){ } else { }: Wenn die Bedingung auf true gesetzt ist, wird der Zugsatz zurรผckgegeben, andernfalls der Testsatz.

Sie kรถnnen Ihre Funktion testen und die Abmessung รผberprรผfen.

data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)
data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)
dim(data_train)

Ausgang:

## [1] 836   8
dim(data_test)

Ausgang:

## [1] 209   8

Der Zugdatensatz hat 1046 Zeilen, wรคhrend der Testdatensatz 262 Zeilen hat.

Sie verwenden die Funktion prop.table() in Kombination mit table(), um zu รผberprรผfen, ob der Randomisierungsprozess korrekt ist.

prop.table(table(data_train$survived))

Ausgang:

##
##        No       Yes 
## 0.5944976 0.4055024
prop.table(table(data_test$survived))

Ausgang:

## 
##        No       Yes 
## 0.5789474 0.4210526

In beiden Datensรคtzen ist die Anzahl der รœberlebenden gleich, etwa 40 Prozent.

Installieren Sie rpart.plot

rpart.plot ist in den Conda-Bibliotheken nicht verfรผgbar. Sie kรถnnen es รผber die Konsole installieren:

install.packages("rpart.plot")

Schritt 4) Erstellen Sie das Modell

Sie sind bereit, das Modell zu erstellen. Die Syntax fรผr die Rpart-Entscheidungsbaumfunktion lautet:

rpart(formula, data=, method='')
arguments:			
- formula: The function to predict
- data: Specifies the data frame- method: 			
- "class" for a classification tree 			
- "anova" for a regression tree	

Sie verwenden die Klassenmethode, weil Sie eine Klasse vorhersagen.

library(rpart)
library(rpart.plot)
fit <- rpart(survived~., data = data_train, method = 'class')
rpart.plot(fit, extra = 106

Code Erklรคrung

  • rpart(): Funktion zur Anpassung an das Modell. Die Argumente sind:
    • รผberlebt ~.: Formel der Entscheidungsbรคume
    • data = data_train: Datensatz
    • method = 'class': Passen Sie ein binรคres Modell an
  • rpart.plot(fit, extra= 106): Plotten Sie den Baum. Die zusรคtzlichen Funktionen werden auf 101 gesetzt, um die Wahrscheinlichkeit der 2. Klasse anzuzeigen (nรผtzlich fรผr binรคre Antworten). Sie kรถnnen sich auf die beziehen Vignette Weitere Informationen zu den anderen Optionen finden Sie hier.

Ausgang:

 Erstellen Sie ein Modell von Entscheidungsbรคumen in R

Sie beginnen am Wurzelknoten (Tiefe 0 รผber 3, oben im Diagramm):

  1. Ganz oben steht die Gesamtรผberlebenswahrscheinlichkeit. Es zeigt den Anteil der Passagiere, die den Unfall รผberlebt haben. 41 Prozent der Passagiere รผberlebten.
  2. Dieser Knoten fragt, ob das Geschlecht des Passagiers mรคnnlich ist. Wenn ja, gehen Sie zum linken untergeordneten Knoten der Wurzel (Tiefe 2). 63 Prozent sind Mรคnner mit einer รœberlebenswahrscheinlichkeit von 21 Prozent.
  3. Im zweiten Knoten fragen Sie, ob der mรคnnliche Passagier รคlter als 3.5 Jahre ist. Wenn ja, dann liegt die รœberlebenschance bei 19 Prozent.
  4. So machen Sie weiter, um zu verstehen, welche Merkmale die รœberlebenswahrscheinlichkeit beeinflussen.

Beachten Sie, dass eine der vielen Eigenschaften von Entscheidungsbรคumen darin besteht, dass sie nur sehr wenig Datenvorbereitung erfordern. Insbesondere ist keine Feature-Skalierung oder -Zentrierung erforderlich.

StandardmรครŸig verwendet die Funktion rpart() die Gini VerunreinigungsmaรŸnahme, um die Note zu spalten. Je hรถher der Gini-Koeffizient, desto mehr unterschiedliche Instanzen innerhalb des Knotens.

Schritt 5) Machen Sie eine Vorhersage

Sie kรถnnen Ihren Testdatensatz vorhersagen. Um eine Vorhersage zu treffen, kรถnnen Sie die Funktion Predict() verwenden. Die grundlegende Syntax der Vorhersage fรผr den R-Entscheidungsbaum lautet:

predict(fitted_model, df, type = 'class')
arguments:
- fitted_model: This is the object stored after model estimation. 
- df: Data frame used to make the prediction
- type: Type of prediction			
    - 'class': for classification			
    - 'prob': to compute the probability of each class			
    - 'vector': Predict the mean response at the node level	

Sie mรถchten anhand des Testsatzes vorhersagen, welche Passagiere nach der Kollision mit grรถรŸerer Wahrscheinlichkeit รผberleben. Das heiรŸt, Sie werden wissen, welcher von diesen 209 Passagieren รผberleben wird oder nicht.

predict_unseen <-predict(fit, data_test, type = 'class')

Code Erklรคrung

  • Predict(fit, data_test, type = 'class'): Sagen Sie die Klasse (0/1) des Testsatzes voraus

Testen der Passagiere, die es nicht geschafft haben, und derjenigen, die es geschafft haben.

table_mat <- table(data_test$survived, predict_unseen)
table_mat

Code Erklรคrung

  • table(data_test$survived, Predict_unseen): Erstellen Sie eine Tabelle, um zu zรคhlen, wie viele Passagiere als รœberlebende und Verstorbene eingestuft werden, im Vergleich zur korrekten Entscheidungsbaumklassifizierung in R

Ausgang:

##      predict_unseen
##        No Yes
##   No  106  15
##   Yes  30  58

Das Modell hat 106 tote Passagiere korrekt vorhergesagt, 15 รœberlebende jedoch als tot eingestuft. Analog dazu klassifizierte das Modell 30 Passagiere fรคlschlicherweise als รœberlebende, wรคhrend sich herausstellte, dass sie tot waren.

Schritt 6) Messen Sie die Leistung

Mit dem kรถnnen Sie ein GenauigkeitsmaรŸ fรผr Klassifizierungsaufgaben berechnen Verwirrung Matrix:

Die Verwirrung Matrix ist eine bessere Wahl, um die Klassifizierungsleistung zu bewerten. Die allgemeine Idee besteht darin, zu zรคhlen, wie oft Wahre Instanzen als Falsch klassifiziert werden.

Messen Sie die Leistung von Entscheidungsbรคumen in R

Jede Zeile in einer Verwirrungsmatrix stellt ein tatsรคchliches Ziel dar, wรคhrend jede Spalte ein vorhergesagtes Ziel darstellt. Die erste Zeile dieser Matrix berรผcksichtigt tote Passagiere (die False-Klasse): 106 wurden korrekt als tot klassifiziert (Wahr-negativ), wรคhrend der Rest fรคlschlicherweise als รœberlebender eingestuft wurde (Falsch positiv). Die zweite Reihe berรผcksichtigt die รœberlebenden, die positive Klasse betrug 58 (Richtig positiv), Wรคhrend die Wahr-negativ war 30.

Sie kรถnnen das berechnen Genauigkeitstest aus der Verwirrungsmatrix:

Messen Sie die Leistung von Entscheidungsbรคumen in R

Es ist das Verhรคltnis von richtig positiv und richtig negativ an der Summe der Matrix. Mit R kรถnnen Sie wie folgt codieren:

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)

Code Erklรคrung

  • sum(diag(table_mat)): Summe der Diagonalen
  • sum(table_mat): Summe der Matrix.

Sie kรถnnen die Genauigkeit des Testsatzes ausdrucken:

print(paste('Accuracy for test', accuracy_Test))

Ausgang:

## [1] "Accuracy for test 0.784688995215311"

Sie haben fรผr den Testsatz eine Punktzahl von 78 Prozent. Sie kรถnnen dieselbe รœbung mit dem Trainingsdatensatz replizieren.

Schritt 7) Optimieren Sie die Hyperparameter

Entscheidungsbรคume in R haben verschiedene Parameter, die Aspekte der Anpassung steuern. In der Entscheidungsbaumbibliothek rpart kรถnnen Sie die Parameter mit der Funktion rpart.control() steuern. Im folgenden Code fรผhren Sie die Parameter ein, die Sie anpassen werden. Sie kรถnnen sich auf die Vignette fรผr andere Parameter.

rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)
Arguments:
-minsplit: Set the minimum number of observations in the node before the algorithm perform a split
-minbucket:  Set the minimum number of observations in the final note i.e. the leaf
-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0

Wir werden wie folgt vorgehen:

  • Konstruieren Sie eine Funktion, um Genauigkeit zurรผckzugeben
  • Stellen Sie die maximale Tiefe ein
  • Passen Sie die Mindestanzahl an Samples an, die ein Knoten haben muss, bevor er geteilt werden kann
  • Passen Sie die Mindestanzahl an Samples an, die ein Blattknoten haben muss

Sie kรถnnen eine Funktion schreiben, um die Genauigkeit anzuzeigen. Sie packen einfach den Code ein, den Sie zuvor verwendet haben:

  1. Vorhersage: Predict_unseen <- Predict(fit, data_test, type = 'class')
  2. Tabelle erstellen: table_mat <- table(data_test$survived, Predict_unseen)
  3. Genauigkeit berechnen: precision_Test <- sum(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) {
    predict_unseen <- predict(fit, data_test, type = 'class')
    table_mat <- table(data_test$survived, predict_unseen)
    accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
    accuracy_Test
}

Sie kรถnnen versuchen, die Parameter zu optimieren und zu sehen, ob Sie das Modell gegenรผber dem Standardwert verbessern kรถnnen. Zur Erinnerung: Sie mรผssen eine Genauigkeit von mehr als 0.78 erreichen

control <- rpart.control(minsplit = 4,
    minbucket = round(5 / 3),
    maxdepth = 3,
    cp = 0)
tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)
accuracy_tune(tune_fit)

Ausgang:

## [1] 0.7990431

Mit folgendem Parameter:

minsplit = 4
minbucket= round(5/3)
maxdepth = 3cp=0

Sie erhalten eine hรถhere Leistung als beim Vorgรคngermodell. Herzlichen Glรผckwunsch!

Zusammenfassung

Wir kรถnnen die Funktionen zum Trainieren eines Entscheidungsbaumalgorithmus in R

Bibliothek Ziel Funktion Klasse KenngrรถรŸen Details
Teil Zugklassifizierungsbaum in R rpart() Klasse Formel, df, Methode
Teil Regressionsbaum trainieren rpart() Anova Formel, df, Methode
Teil Plotten Sie die Bรคume rpart.plot() angepasstes Modell
Base vorhersagen vorhersagen() Klasse Einbaumodell, Typ
Base vorhersagen vorhersagen() prob Einbaumodell, Typ
Base vorhersagen vorhersagen() Vektor Einbaumodell, Typ
Teil Regelparameter rpart.control() Minsplit Legen Sie die Mindestanzahl an Beobachtungen im Knoten fest, bevor der Algorithmus eine Aufteilung durchfรผhrt
Minbucket Legen Sie die Mindestanzahl an Beobachtungen in der Schlussnotiz, also dem Blatt, fest
maximale Tiefe Legen Sie die maximale Tiefe eines beliebigen Knotens des endgรผltigen Baums fest. Der Wurzelknoten wird mit der Tiefe 0 behandelt
Teil Trainieren Sie das Modell mit Steuerparametern rpart() Formel, df, Methode, Kontrolle

Hinweis: Trainieren Sie das Modell anhand von Trainingsdaten und testen Sie die Leistung anhand eines unsichtbaren Datensatzes, z. B. eines Testsatzes.

Fassen Sie diesen Beitrag mit folgenden Worten zusammen: