Entscheidungsbaum in R: Klassifizierungsbaum mit Beispiel
Was sind Entscheidungsbรคume?
Entscheidungsbรคume sind vielseitige Machine-Learning-Algorithmen, die sowohl Klassifizierungs- als auch Regressionsaufgaben ausfรผhren kรถnnen. Es sind sehr leistungsstarke Algorithmen, die in der Lage sind, komplexe Datensรคtze anzupassen. Darรผber hinaus sind Entscheidungsbรคume grundlegende Komponenten von Random Forests, die zu den leistungsfรคhigsten Machine-Learning-Algorithmen gehรถren, die heute verfรผgbar sind.
Training und Visualisierung eines Entscheidungsbaums in R
Um Ihren ersten Entscheidungsbaum im R-Beispiel zu erstellen, gehen wir in diesem Entscheidungsbaum-Tutorial wie folgt vor:
- Schritt 1: Importieren Sie die Daten
- Schritt 2: Bereinigen Sie den Datensatz
- Schritt 3: Zug-/Testset erstellen
- Schritt 4: Erstellen Sie das Modell
- Schritt 5: Machen Sie eine Vorhersage
- Schritt 6: Leistung messen
- Schritt 7: Optimieren Sie die Hyperparameter
Schritt 1) โโImportieren Sie die Daten
Wenn Sie neugierig auf das Schicksal der Titanic sind, kรถnnen Sie sich dieses Video ansehen Youtube. Der Zweck dieses Datensatzes besteht darin, vorherzusagen, welche Menschen nach der Kollision mit dem Eisberg mit grรถรerer Wahrscheinlichkeit รผberleben. Der Datensatz enthรคlt 13 Variablen und 1309 Beobachtungen. Der Datensatz wird nach der Variablen X geordnet.
set.seed(678) path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv' titanic <-read.csv(path) head(titanic)
Ausgang:
## X pclass survived name sex ## 1 1 1 1 Allen, Miss. Elisabeth Walton female ## 2 2 1 1 Allison, Master. Hudson Trevor male ## 3 3 1 0 Allison, Miss. Helen Loraine female ## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male ## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female ## 6 6 1 1 Anderson, Mr. Harry male ## age sibsp parch ticket fare cabin embarked ## 1 29.0000 0 0 24160 211.3375 B5 S ## 2 0.9167 1 2 113781 151.5500 C22 C26 S ## 3 2.0000 1 2 113781 151.5500 C22 C26 S ## 4 30.0000 1 2 113781 151.5500 C22 C26 S ## 5 25.0000 1 2 113781 151.5500 C22 C26 S ## 6 48.0000 0 0 19952 26.5500 E12 S ## home.dest ## 1 St Louis, MO ## 2 Montreal, PQ / Chesterville, ON ## 3 Montreal, PQ / Chesterville, ON ## 4 Montreal, PQ / Chesterville, ON ## 5 Montreal, PQ / Chesterville, ON ## 6 New York, NY
tail(titanic)
Ausgang:
## X pclass survived name sex age sibsp ## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0 ## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1 ## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1 ## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0 ## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0 ## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0 ## parch ticket fare cabin embarked home.dest ## 1304 0 2627 14.4583 C ## 1305 0 2665 14.4542 C ## 1306 0 2665 14.4542 C ## 1307 0 2656 7.2250 C ## 1308 0 2670 7.2250 C ## 1309 0 315082 7.8750 S
Anhand der Head- und Tail-Ausgabe kรถnnen Sie erkennen, dass die Daten nicht gemischt werden. Das ist ein groรes Problem! Wenn Sie Ihre Daten zwischen einem Zugsatz und einem Testsatz aufteilen, wรคhlen Sie aus einzige der Passagier der Klassen 1 und 2 (Kein Passagier der Klasse 3 befindet sich in den oberen 80 Prozent der Beobachtungen), was bedeutet, dass der Algorithmus niemals die Merkmale des Passagiers der Klasse 3 sehen wird. Dieser Fehler fรผhrt zu einer schlechten Vorhersage.
Um dieses Problem zu lรถsen, kรถnnen Sie die Funktion sample() verwenden.
shuffle_index <- sample(1:nrow(titanic)) head(shuffle_index)
Entscheidungsbaum-R-Code-Erklรคrung
- sample(1:nrow(titanic)): Generieren Sie eine zufรคllige Indexliste von 1 bis 1309 (dh die maximale Anzahl von Zeilen).
Ausgang:
## [1] 288 874 1078 633 887 992
Sie werden diesen Index verwenden, um den Titanic-Datensatz zu mischen.
titanic <- titanic[shuffle_index, ] head(titanic)
Ausgang:
## X pclass survived ## 288 288 1 0 ## 874 874 3 0 ## 1078 1078 3 1 ## 633 633 3 0 ## 887 887 3 1 ## 992 992 3 1 ## name sex age ## 288 Sutton, Mr. Frederick male 61 ## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42 ## 1078 O'Driscoll, Miss. Bridget female NA ## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39 ## 887 Jermyn, Miss. Annie female NA ## 992 Mamee, Mr. Hanna male NA ## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ ## 874 0 0 348121 7.6500 F G63 S ## 1078 0 0 14311 7.7500 Q ## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN ## 887 0 0 14313 7.7500 Q ## 992 0 0 2677 7.2292 C
Schritt 2) Bereinigen Sie den Datensatz
Die Struktur der Daten zeigt, dass einige Variablen NAs haben. Die Datenbereinigung muss wie folgt durchgefรผhrt werden
- Lรถschen Sie die Variablen home.dest,cabin, name, X und ticket
- Erstellen Sie Faktorvariablen fรผr PCLASS und Surved
- Lass die NA fallen
library(dplyr)
# Drop variables
clean_titanic <- titanic % > %
select(-c(home.dest, cabin, name, X, ticket)) % > %
#Convert to factor level
mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),
survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %
na.omit()
glimpse(clean_titanic)
Code Erklรคrung
- select(-c(home.dest, Kabine, Name, X, Ticket)): Lรถschen Sie unnรถtige Variablen
- pclass = Faktor(pclass, Ebenen = c(1,2,3), Labels= c('Upper', 'Middle', 'Lower')): Beschriftung zur Variablen pclass hinzufรผgen. 1 wird zum oberen Wert, 2 wird zum mittleren Wert und 3 wird zum niedrigeren Wert
- Faktor(รผberlebt, Ebenen = c(0,1), Labels = c('Nein', 'Ja')): Fรผge Label zur Variablen รผberlebt hinzu. 1 wird zu Nein und 2 wird zu Ja
- na.omit(): Entfernen Sie die NA-Beobachtungen
Ausgang:
## Observations: 1,045 ## Variables: 8 ## $ pclass <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U... ## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y... ## $ sex <fctr> male, male, female, female, male, male, female, male... ## $ age <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ... ## $ sibsp <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,... ## $ parch <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,... ## $ fare <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ... ## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...
Schritt 3) Erstellen Sie einen Zug-/Testsatz
Bevor Sie Ihr Modell trainieren, mรผssen Sie zwei Schritte ausfรผhren:
- Erstellen Sie einen Zug- und Testsatz: Sie trainieren das Modell auf dem Zugsatz und testen die Vorhersage auf dem Testsatz (d. h. unsichtbaren Daten).
- Installieren Sie rpart.plot รผber die Konsole
Die gรคngige Praxis besteht darin, die Daten im Verhรคltnis 80/20 aufzuteilen, wobei 80 Prozent der Daten zum Trainieren des Modells und 20 Prozent zur Erstellung von Vorhersagen dienen. Sie mรผssen zwei separate Datenrahmen erstellen. Sie mรถchten das Testset erst dann anfassen, wenn Sie mit der Erstellung Ihres Modells fertig sind. Sie kรถnnen einen Funktionsnamen create_train_test() erstellen, der drei Argumente akzeptiert.
create_train_test(df, size = 0.8, train = TRUE) arguments: -df: Dataset used to train the model. -size: Size of the split. By default, 0.8. Numerical value -train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {
n_row = nrow(data)
total_row = size * n_row
train_sample < - 1: total_row
if (train == TRUE) {
return (data[train_sample, ])
} else {
return (data[-train_sample, ])
}
}
Code Erklรคrung
- function(data, size=0.8, train = TRUE): Fรผgen Sie die Argumente in der Funktion hinzu
- n_row = nrow(data): Zรคhlt die Anzahl der Zeilen im Datensatz
- total_row = size*n_row: Gibt die n-te Zeile zurรผck, um den Zugsatz zu erstellen
- train_sample <- 1:total_row: Wรคhlen Sie die erste bis n-te Zeile aus
- if (train ==TRUE){ } else { }: Wenn die Bedingung auf true gesetzt ist, wird der Zugsatz zurรผckgegeben, andernfalls der Testsatz.
Sie kรถnnen Ihre Funktion testen und die Abmessung รผberprรผfen.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE) data_test <- create_train_test(clean_titanic, 0.8, train = FALSE) dim(data_train)
Ausgang:
## [1] 836 8
dim(data_test)
Ausgang:
## [1] 209 8
Der Zugdatensatz hat 1046 Zeilen, wรคhrend der Testdatensatz 262 Zeilen hat.
Sie verwenden die Funktion prop.table() in Kombination mit table(), um zu รผberprรผfen, ob der Randomisierungsprozess korrekt ist.
prop.table(table(data_train$survived))
Ausgang:
## ## No Yes ## 0.5944976 0.4055024
prop.table(table(data_test$survived))
Ausgang:
## ## No Yes ## 0.5789474 0.4210526
In beiden Datensรคtzen ist die Anzahl der รberlebenden gleich, etwa 40 Prozent.
Installieren Sie rpart.plot
rpart.plot ist in den Conda-Bibliotheken nicht verfรผgbar. Sie kรถnnen es รผber die Konsole installieren:
install.packages("rpart.plot")
Schritt 4) Erstellen Sie das Modell
Sie sind bereit, das Modell zu erstellen. Die Syntax fรผr die Rpart-Entscheidungsbaumfunktion lautet:
rpart(formula, data=, method='') arguments: - formula: The function to predict - data: Specifies the data frame- method: - "class" for a classification tree - "anova" for a regression tree
Sie verwenden die Klassenmethode, weil Sie eine Klasse vorhersagen.
library(rpart) library(rpart.plot) fit <- rpart(survived~., data = data_train, method = 'class') rpart.plot(fit, extra = 106
Code Erklรคrung
- rpart(): Funktion zur Anpassung an das Modell. Die Argumente sind:
- รผberlebt ~.: Formel der Entscheidungsbรคume
- data = data_train: Datensatz
- method = 'class': Passen Sie ein binรคres Modell an
- rpart.plot(fit, extra= 106): Plotten Sie den Baum. Die zusรคtzlichen Funktionen werden auf 101 gesetzt, um die Wahrscheinlichkeit der 2. Klasse anzuzeigen (nรผtzlich fรผr binรคre Antworten). Sie kรถnnen sich auf die beziehen Vignette Weitere Informationen zu den anderen Optionen finden Sie hier.
Ausgang:
Sie beginnen am Wurzelknoten (Tiefe 0 รผber 3, oben im Diagramm):
- Ganz oben steht die Gesamtรผberlebenswahrscheinlichkeit. Es zeigt den Anteil der Passagiere, die den Unfall รผberlebt haben. 41 Prozent der Passagiere รผberlebten.
- Dieser Knoten fragt, ob das Geschlecht des Passagiers mรคnnlich ist. Wenn ja, gehen Sie zum linken untergeordneten Knoten der Wurzel (Tiefe 2). 63 Prozent sind Mรคnner mit einer รberlebenswahrscheinlichkeit von 21 Prozent.
- Im zweiten Knoten fragen Sie, ob der mรคnnliche Passagier รคlter als 3.5 Jahre ist. Wenn ja, dann liegt die รberlebenschance bei 19 Prozent.
- So machen Sie weiter, um zu verstehen, welche Merkmale die รberlebenswahrscheinlichkeit beeinflussen.
Beachten Sie, dass eine der vielen Eigenschaften von Entscheidungsbรคumen darin besteht, dass sie nur sehr wenig Datenvorbereitung erfordern. Insbesondere ist keine Feature-Skalierung oder -Zentrierung erforderlich.
Standardmรครig verwendet die Funktion rpart() die Gini Verunreinigungsmaรnahme, um die Note zu spalten. Je hรถher der Gini-Koeffizient, desto mehr unterschiedliche Instanzen innerhalb des Knotens.
Schritt 5) Machen Sie eine Vorhersage
Sie kรถnnen Ihren Testdatensatz vorhersagen. Um eine Vorhersage zu treffen, kรถnnen Sie die Funktion Predict() verwenden. Die grundlegende Syntax der Vorhersage fรผr den R-Entscheidungsbaum lautet:
predict(fitted_model, df, type = 'class')
arguments:
- fitted_model: This is the object stored after model estimation.
- df: Data frame used to make the prediction
- type: Type of prediction
- 'class': for classification
- 'prob': to compute the probability of each class
- 'vector': Predict the mean response at the node level
Sie mรถchten anhand des Testsatzes vorhersagen, welche Passagiere nach der Kollision mit grรถรerer Wahrscheinlichkeit รผberleben. Das heiรt, Sie werden wissen, welcher von diesen 209 Passagieren รผberleben wird oder nicht.
predict_unseen <-predict(fit, data_test, type = 'class')
Code Erklรคrung
- Predict(fit, data_test, type = 'class'): Sagen Sie die Klasse (0/1) des Testsatzes voraus
Testen der Passagiere, die es nicht geschafft haben, und derjenigen, die es geschafft haben.
table_mat <- table(data_test$survived, predict_unseen) table_mat
Code Erklรคrung
- table(data_test$survived, Predict_unseen): Erstellen Sie eine Tabelle, um zu zรคhlen, wie viele Passagiere als รberlebende und Verstorbene eingestuft werden, im Vergleich zur korrekten Entscheidungsbaumklassifizierung in R
Ausgang:
## predict_unseen ## No Yes ## No 106 15 ## Yes 30 58
Das Modell hat 106 tote Passagiere korrekt vorhergesagt, 15 รberlebende jedoch als tot eingestuft. Analog dazu klassifizierte das Modell 30 Passagiere fรคlschlicherweise als รberlebende, wรคhrend sich herausstellte, dass sie tot waren.
Schritt 6) Messen Sie die Leistung
Mit dem kรถnnen Sie ein Genauigkeitsmaร fรผr Klassifizierungsaufgaben berechnen Verwirrung Matrix:
Die Verwirrung Matrix ist eine bessere Wahl, um die Klassifizierungsleistung zu bewerten. Die allgemeine Idee besteht darin, zu zรคhlen, wie oft Wahre Instanzen als Falsch klassifiziert werden.
Jede Zeile in einer Verwirrungsmatrix stellt ein tatsรคchliches Ziel dar, wรคhrend jede Spalte ein vorhergesagtes Ziel darstellt. Die erste Zeile dieser Matrix berรผcksichtigt tote Passagiere (die False-Klasse): 106 wurden korrekt als tot klassifiziert (Wahr-negativ), wรคhrend der Rest fรคlschlicherweise als รberlebender eingestuft wurde (Falsch positiv). Die zweite Reihe berรผcksichtigt die รberlebenden, die positive Klasse betrug 58 (Richtig positiv), Wรคhrend die Wahr-negativ war 30.
Sie kรถnnen das berechnen Genauigkeitstest aus der Verwirrungsmatrix:
Es ist das Verhรคltnis von richtig positiv und richtig negativ an der Summe der Matrix. Mit R kรถnnen Sie wie folgt codieren:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Code Erklรคrung
- sum(diag(table_mat)): Summe der Diagonalen
- sum(table_mat): Summe der Matrix.
Sie kรถnnen die Genauigkeit des Testsatzes ausdrucken:
print(paste('Accuracy for test', accuracy_Test))
Ausgang:
## [1] "Accuracy for test 0.784688995215311"
Sie haben fรผr den Testsatz eine Punktzahl von 78 Prozent. Sie kรถnnen dieselbe รbung mit dem Trainingsdatensatz replizieren.
Schritt 7) Optimieren Sie die Hyperparameter
Entscheidungsbรคume in R haben verschiedene Parameter, die Aspekte der Anpassung steuern. In der Entscheidungsbaumbibliothek rpart kรถnnen Sie die Parameter mit der Funktion rpart.control() steuern. Im folgenden Code fรผhren Sie die Parameter ein, die Sie anpassen werden. Sie kรถnnen sich auf die Vignette fรผr andere Parameter.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30) Arguments: -minsplit: Set the minimum number of observations in the node before the algorithm perform a split -minbucket: Set the minimum number of observations in the final note i.e. the leaf -maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
Wir werden wie folgt vorgehen:
- Konstruieren Sie eine Funktion, um Genauigkeit zurรผckzugeben
- Stellen Sie die maximale Tiefe ein
- Passen Sie die Mindestanzahl an Samples an, die ein Knoten haben muss, bevor er geteilt werden kann
- Passen Sie die Mindestanzahl an Samples an, die ein Blattknoten haben muss
Sie kรถnnen eine Funktion schreiben, um die Genauigkeit anzuzeigen. Sie packen einfach den Code ein, den Sie zuvor verwendet haben:
- Vorhersage: Predict_unseen <- Predict(fit, data_test, type = 'class')
- Tabelle erstellen: table_mat <- table(data_test$survived, Predict_unseen)
- Genauigkeit berechnen: precision_Test <- sum(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) {
predict_unseen <- predict(fit, data_test, type = 'class')
table_mat <- table(data_test$survived, predict_unseen)
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
accuracy_Test
}
Sie kรถnnen versuchen, die Parameter zu optimieren und zu sehen, ob Sie das Modell gegenรผber dem Standardwert verbessern kรถnnen. Zur Erinnerung: Sie mรผssen eine Genauigkeit von mehr als 0.78 erreichen
control <- rpart.control(minsplit = 4,
minbucket = round(5 / 3),
maxdepth = 3,
cp = 0)
tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)
accuracy_tune(tune_fit)
Ausgang:
## [1] 0.7990431
Mit folgendem Parameter:
minsplit = 4 minbucket= round(5/3) maxdepth = 3cp=0
Sie erhalten eine hรถhere Leistung als beim Vorgรคngermodell. Herzlichen Glรผckwunsch!
Zusammenfassung
Wir kรถnnen die Funktionen zum Trainieren eines Entscheidungsbaumalgorithmus in R
| Bibliothek | Ziel | Funktion | Klasse | Kenngrรถรen | Details |
|---|---|---|---|---|---|
| Teil | Zugklassifizierungsbaum in R | rpart() | Klasse | Formel, df, Methode | |
| Teil | Regressionsbaum trainieren | rpart() | Anova | Formel, df, Methode | |
| Teil | Plotten Sie die Bรคume | rpart.plot() | angepasstes Modell | ||
| Base | vorhersagen | vorhersagen() | Klasse | Einbaumodell, Typ | |
| Base | vorhersagen | vorhersagen() | prob | Einbaumodell, Typ | |
| Base | vorhersagen | vorhersagen() | Vektor | Einbaumodell, Typ | |
| Teil | Regelparameter | rpart.control() | Minsplit | Legen Sie die Mindestanzahl an Beobachtungen im Knoten fest, bevor der Algorithmus eine Aufteilung durchfรผhrt | |
| Minbucket | Legen Sie die Mindestanzahl an Beobachtungen in der Schlussnotiz, also dem Blatt, fest | ||||
| maximale Tiefe | Legen Sie die maximale Tiefe eines beliebigen Knotens des endgรผltigen Baums fest. Der Wurzelknoten wird mit der Tiefe 0 behandelt | ||||
| Teil | Trainieren Sie das Modell mit Steuerparametern | rpart() | Formel, df, Methode, Kontrolle |
Hinweis: Trainieren Sie das Modell anhand von Trainingsdaten und testen Sie die Leistung anhand eines unsichtbaren Datensatzes, z. B. eines Testsatzes.



