Entscheidungsbaum in R: Klassifizierungsbaum mit Beispiel
Was sind Entscheidungsbäume?
Entscheidungsbäume sind vielseitige Machine-Learning-Algorithmen, die sowohl Klassifizierungs- als auch Regressionsaufgaben ausführen können. Es sind sehr leistungsstarke Algorithmen, die in der Lage sind, komplexe Datensätze anzupassen. Darüber hinaus sind Entscheidungsbäume grundlegende Komponenten von Random Forests, die zu den leistungsfähigsten Machine-Learning-Algorithmen gehören, die heute verfügbar sind.
Training und Visualisierung eines Entscheidungsbaums in R
Um Ihren ersten Entscheidungsbaum im R-Beispiel zu erstellen, gehen wir in diesem Entscheidungsbaum-Tutorial wie folgt vor:
- Schritt 1: Importieren Sie die Daten
- Schritt 2: Bereinigen Sie den Datensatz
- Schritt 3: Zug-/Testset erstellen
- Schritt 4: Erstellen Sie das Modell
- Schritt 5: Machen Sie eine Vorhersage
- Schritt 6: Leistung messen
- Schritt 7: Optimieren Sie die Hyperparameter
Schritt 1) Importieren Sie die Daten
Wenn Sie neugierig auf das Schicksal der Titanic sind, können Sie sich dieses Video ansehen Youtube. Der Zweck dieses Datensatzes besteht darin, vorherzusagen, welche Menschen nach der Kollision mit dem Eisberg mit größerer Wahrscheinlichkeit überleben. Der Datensatz enthält 13 Variablen und 1309 Beobachtungen. Der Datensatz wird nach der Variablen X geordnet.
set.seed(678) path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv' titanic <-read.csv(path) head(titanic)
Ausgang:
## X pclass survived name sex ## 1 1 1 1 Allen, Miss. Elisabeth Walton female ## 2 2 1 1 Allison, Master. Hudson Trevor male ## 3 3 1 0 Allison, Miss. Helen Loraine female ## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male ## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female ## 6 6 1 1 Anderson, Mr. Harry male ## age sibsp parch ticket fare cabin embarked ## 1 29.0000 0 0 24160 211.3375 B5 S ## 2 0.9167 1 2 113781 151.5500 C22 C26 S ## 3 2.0000 1 2 113781 151.5500 C22 C26 S ## 4 30.0000 1 2 113781 151.5500 C22 C26 S ## 5 25.0000 1 2 113781 151.5500 C22 C26 S ## 6 48.0000 0 0 19952 26.5500 E12 S ## home.dest ## 1 St Louis, MO ## 2 Montreal, PQ / Chesterville, ON ## 3 Montreal, PQ / Chesterville, ON ## 4 Montreal, PQ / Chesterville, ON ## 5 Montreal, PQ / Chesterville, ON ## 6 New York, NY
tail(titanic)
Ausgang:
## X pclass survived name sex age sibsp ## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0 ## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1 ## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1 ## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0 ## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0 ## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0 ## parch ticket fare cabin embarked home.dest ## 1304 0 2627 14.4583 C ## 1305 0 2665 14.4542 C ## 1306 0 2665 14.4542 C ## 1307 0 2656 7.2250 C ## 1308 0 2670 7.2250 C ## 1309 0 315082 7.8750 S
Anhand der Head- und Tail-Ausgabe können Sie erkennen, dass die Daten nicht gemischt werden. Das ist ein großes Problem! Wenn Sie Ihre Daten zwischen einem Zugsatz und einem Testsatz aufteilen, wählen Sie aus einzige der Passagier der Klassen 1 und 2 (Kein Passagier der Klasse 3 befindet sich in den oberen 80 Prozent der Beobachtungen), was bedeutet, dass der Algorithmus niemals die Merkmale des Passagiers der Klasse 3 sehen wird. Dieser Fehler führt zu einer schlechten Vorhersage.
Um dieses Problem zu lösen, können Sie die Funktion sample() verwenden.
shuffle_index <- sample(1:nrow(titanic)) head(shuffle_index)
Entscheidungsbaum-R-Code-Erklärung
- sample(1:nrow(titanic)): Generieren Sie eine zufällige Indexliste von 1 bis 1309 (dh die maximale Anzahl von Zeilen).
Ausgang:
## [1] 288 874 1078 633 887 992
Sie werden diesen Index verwenden, um den Titanic-Datensatz zu mischen.
titanic <- titanic[shuffle_index, ] head(titanic)
Ausgang:
## X pclass survived ## 288 288 1 0 ## 874 874 3 0 ## 1078 1078 3 1 ## 633 633 3 0 ## 887 887 3 1 ## 992 992 3 1 ## name sex age ## 288 Sutton, Mr. Frederick male 61 ## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42 ## 1078 O'Driscoll, Miss. Bridget female NA ## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39 ## 887 Jermyn, Miss. Annie female NA ## 992 Mamee, Mr. Hanna male NA ## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ ## 874 0 0 348121 7.6500 F G63 S ## 1078 0 0 14311 7.7500 Q ## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN ## 887 0 0 14313 7.7500 Q ## 992 0 0 2677 7.2292 C
Schritt 2) Bereinigen Sie den Datensatz
Die Struktur der Daten zeigt, dass einige Variablen NAs haben. Die Datenbereinigung muss wie folgt durchgeführt werden
- Löschen Sie die Variablen home.dest,cabin, name, X und ticket
- Erstellen Sie Faktorvariablen für PCLASS und Surved
- Lass die NA fallen
library(dplyr) # Drop variables clean_titanic <- titanic % > % select(-c(home.dest, cabin, name, X, ticket)) % > % #Convert to factor level mutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')), survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > % na.omit() glimpse(clean_titanic)
Code Erklärung
- select(-c(home.dest, Kabine, Name, X, Ticket)): Löschen Sie unnötige Variablen
- pclass = Faktor(pclass, Ebenen = c(1,2,3), Labels= c('Upper', 'Middle', 'Lower')): Beschriftung zur Variablen pclass hinzufügen. 1 wird zum oberen Wert, 2 wird zum mittleren Wert und 3 wird zum niedrigeren Wert
- Faktor(überlebt, Ebenen = c(0,1), Labels = c('Nein', 'Ja')): Füge Label zur Variablen überlebt hinzu. 1 wird zu Nein und 2 wird zu Ja
- na.omit(): Entfernen Sie die NA-Beobachtungen
Ausgang:
## Observations: 1,045 ## Variables: 8 ## $ pclass <fctr> Upper, Lower, Lower, Upper, Middle, Upper, Middle, U... ## $ survived <fctr> No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y... ## $ sex <fctr> male, male, female, female, male, male, female, male... ## $ age <dbl> 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0, ... ## $ sibsp <int> 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,... ## $ parch <int> 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,... ## $ fare <dbl> 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542, ... ## $ embarked <fctr> S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C...
Schritt 3) Erstellen Sie einen Zug-/Testsatz
Bevor Sie Ihr Modell trainieren, müssen Sie zwei Schritte ausführen:
- Erstellen Sie einen Zug- und Testsatz: Sie trainieren das Modell auf dem Zugsatz und testen die Vorhersage auf dem Testsatz (d. h. unsichtbaren Daten).
- Installieren Sie rpart.plot über die Konsole
Die gängige Praxis besteht darin, die Daten im Verhältnis 80/20 aufzuteilen, wobei 80 Prozent der Daten zum Trainieren des Modells und 20 Prozent zur Erstellung von Vorhersagen dienen. Sie müssen zwei separate Datenrahmen erstellen. Sie möchten das Testset erst dann anfassen, wenn Sie mit der Erstellung Ihres Modells fertig sind. Sie können einen Funktionsnamen create_train_test() erstellen, der drei Argumente akzeptiert.
create_train_test(df, size = 0.8, train = TRUE) arguments: -df: Dataset used to train the model. -size: Size of the split. By default, 0.8. Numerical value -train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample < - 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } }
Code Erklärung
- function(data, size=0.8, train = TRUE): Fügen Sie die Argumente in der Funktion hinzu
- n_row = nrow(data): Zählt die Anzahl der Zeilen im Datensatz
- total_row = size*n_row: Gibt die n-te Zeile zurück, um den Zugsatz zu erstellen
- train_sample <- 1:total_row: Wählen Sie die erste bis n-te Zeile aus
- if (train ==TRUE){ } else { }: Wenn die Bedingung auf true gesetzt ist, wird der Zugsatz zurückgegeben, andernfalls der Testsatz.
Sie können Ihre Funktion testen und die Abmessung überprüfen.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE) data_test <- create_train_test(clean_titanic, 0.8, train = FALSE) dim(data_train)
Ausgang:
## [1] 836 8
dim(data_test)
Ausgang:
## [1] 209 8
Der Zugdatensatz hat 1046 Zeilen, während der Testdatensatz 262 Zeilen hat.
Sie verwenden die Funktion prop.table() in Kombination mit table(), um zu überprüfen, ob der Randomisierungsprozess korrekt ist.
prop.table(table(data_train$survived))
Ausgang:
## ## No Yes ## 0.5944976 0.4055024
prop.table(table(data_test$survived))
Ausgang:
## ## No Yes ## 0.5789474 0.4210526
In beiden Datensätzen ist die Anzahl der Überlebenden gleich, etwa 40 Prozent.
Installieren Sie rpart.plot
rpart.plot ist in den Conda-Bibliotheken nicht verfügbar. Sie können es über die Konsole installieren:
install.packages("rpart.plot")
Schritt 4) Erstellen Sie das Modell
Sie sind bereit, das Modell zu erstellen. Die Syntax für die Rpart-Entscheidungsbaumfunktion lautet:
rpart(formula, data=, method='') arguments: - formula: The function to predict - data: Specifies the data frame- method: - "class" for a classification tree - "anova" for a regression tree
Sie verwenden die Klassenmethode, weil Sie eine Klasse vorhersagen.
library(rpart) library(rpart.plot) fit <- rpart(survived~., data = data_train, method = 'class') rpart.plot(fit, extra = 106
Code Erklärung
- rpart(): Funktion zur Anpassung an das Modell. Die Argumente sind:
- überlebt ~.: Formel der Entscheidungsbäume
- data = data_train: Datensatz
- method = 'class': Passen Sie ein binäres Modell an
- rpart.plot(fit, extra= 106): Plotten Sie den Baum. Die zusätzlichen Funktionen werden auf 101 gesetzt, um die Wahrscheinlichkeit der 2. Klasse anzuzeigen (nützlich für binäre Antworten). Sie können sich auf die beziehen Vignette Weitere Informationen zu den anderen Optionen finden Sie hier.
Ausgang:
Sie beginnen am Wurzelknoten (Tiefe 0 über 3, oben im Diagramm):
- Ganz oben steht die Gesamtüberlebenswahrscheinlichkeit. Es zeigt den Anteil der Passagiere, die den Unfall überlebt haben. 41 Prozent der Passagiere überlebten.
- Dieser Knoten fragt, ob das Geschlecht des Passagiers männlich ist. Wenn ja, gehen Sie zum linken untergeordneten Knoten der Wurzel (Tiefe 2). 63 Prozent sind Männer mit einer Überlebenswahrscheinlichkeit von 21 Prozent.
- Im zweiten Knoten fragen Sie, ob der männliche Passagier älter als 3.5 Jahre ist. Wenn ja, dann liegt die Überlebenschance bei 19 Prozent.
- So machen Sie weiter, um zu verstehen, welche Merkmale die Überlebenswahrscheinlichkeit beeinflussen.
Beachten Sie, dass eine der vielen Eigenschaften von Entscheidungsbäumen darin besteht, dass sie nur sehr wenig Datenvorbereitung erfordern. Insbesondere ist keine Feature-Skalierung oder -Zentrierung erforderlich.
Standardmäßig verwendet die Funktion rpart() die Gini Verunreinigungsmaßnahme, um die Note zu spalten. Je höher der Gini-Koeffizient, desto mehr unterschiedliche Instanzen innerhalb des Knotens.
Schritt 5) Machen Sie eine Vorhersage
Sie können Ihren Testdatensatz vorhersagen. Um eine Vorhersage zu treffen, können Sie die Funktion Predict() verwenden. Die grundlegende Syntax der Vorhersage für den R-Entscheidungsbaum lautet:
predict(fitted_model, df, type = 'class') arguments: - fitted_model: This is the object stored after model estimation. - df: Data frame used to make the prediction - type: Type of prediction - 'class': for classification - 'prob': to compute the probability of each class - 'vector': Predict the mean response at the node level
Sie möchten anhand des Testsatzes vorhersagen, welche Passagiere nach der Kollision mit größerer Wahrscheinlichkeit überleben. Das heißt, Sie werden wissen, welcher von diesen 209 Passagieren überleben wird oder nicht.
predict_unseen <-predict(fit, data_test, type = 'class')
Code Erklärung
- Predict(fit, data_test, type = 'class'): Sagen Sie die Klasse (0/1) des Testsatzes voraus
Testen der Passagiere, die es nicht geschafft haben, und derjenigen, die es geschafft haben.
table_mat <- table(data_test$survived, predict_unseen) table_mat
Code Erklärung
- table(data_test$survived, Predict_unseen): Erstellen Sie eine Tabelle, um zu zählen, wie viele Passagiere als Überlebende und Verstorbene eingestuft werden, im Vergleich zur korrekten Entscheidungsbaumklassifizierung in R
Ausgang:
## predict_unseen ## No Yes ## No 106 15 ## Yes 30 58
Das Modell hat 106 tote Passagiere korrekt vorhergesagt, 15 Überlebende jedoch als tot eingestuft. Analog dazu klassifizierte das Modell 30 Passagiere fälschlicherweise als Überlebende, während sich herausstellte, dass sie tot waren.
Schritt 6) Messen Sie die Leistung
Mit dem können Sie ein Genauigkeitsmaß für Klassifizierungsaufgaben berechnen Verwirrung Matrix:
Das Verwirrung Matrix ist eine bessere Wahl, um die Klassifizierungsleistung zu bewerten. Die allgemeine Idee besteht darin, zu zählen, wie oft Wahre Instanzen als Falsch klassifiziert werden.
Jede Zeile in einer Verwirrungsmatrix stellt ein tatsächliches Ziel dar, während jede Spalte ein vorhergesagtes Ziel darstellt. Die erste Zeile dieser Matrix berücksichtigt tote Passagiere (die False-Klasse): 106 wurden korrekt als tot klassifiziert (Wahr-negativ), während der Rest fälschlicherweise als Überlebender eingestuft wurde (Falsch positiv). Die zweite Reihe berücksichtigt die Überlebenden, die positive Klasse betrug 58 (Richtig positiv), Während die Wahr-negativ war 30.
Sie können das berechnen Genauigkeitstest aus der Verwirrungsmatrix:
Es ist das Verhältnis von richtig positiv und richtig negativ an der Summe der Matrix. Mit R können Sie wie folgt codieren:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Code Erklärung
- sum(diag(table_mat)): Summe der Diagonalen
- sum(table_mat): Summe der Matrix.
Sie können die Genauigkeit des Testsatzes ausdrucken:
print(paste('Accuracy for test', accuracy_Test))
Ausgang:
## [1] "Accuracy for test 0.784688995215311"
Sie haben für den Testsatz eine Punktzahl von 78 Prozent. Sie können dieselbe Übung mit dem Trainingsdatensatz replizieren.
Schritt 7) Optimieren Sie die Hyperparameter
Entscheidungsbäume in R haben verschiedene Parameter, die Aspekte der Anpassung steuern. In der Entscheidungsbaumbibliothek rpart können Sie die Parameter mit der Funktion rpart.control() steuern. Im folgenden Code führen Sie die Parameter ein, die Sie anpassen werden. Sie können sich auf die Vignette für andere Parameter.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30) Arguments: -minsplit: Set the minimum number of observations in the node before the algorithm perform a split -minbucket: Set the minimum number of observations in the final note i.e. the leaf -maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
Wir werden wie folgt vorgehen:
- Konstruieren Sie eine Funktion, um Genauigkeit zurückzugeben
- Stellen Sie die maximale Tiefe ein
- Passen Sie die Mindestanzahl an Samples an, die ein Knoten haben muss, bevor er geteilt werden kann
- Passen Sie die Mindestanzahl an Samples an, die ein Blattknoten haben muss
Sie können eine Funktion schreiben, um die Genauigkeit anzuzeigen. Sie packen einfach den Code ein, den Sie zuvor verwendet haben:
- Vorhersage: Predict_unseen <- Predict(fit, data_test, type = 'class')
- Tabelle erstellen: table_mat <- table(data_test$survived, Predict_unseen)
- Genauigkeit berechnen: precision_Test <- sum(diag(table_mat))/sum(table_mat)
accuracy_tune <- function(fit) { predict_unseen <- predict(fit, data_test, type = 'class') table_mat <- table(data_test$survived, predict_unseen) accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test }
Sie können versuchen, die Parameter zu optimieren und zu sehen, ob Sie das Modell gegenüber dem Standardwert verbessern können. Zur Erinnerung: Sie müssen eine Genauigkeit von mehr als 0.78 erreichen
control <- rpart.control(minsplit = 4, minbucket = round(5 / 3), maxdepth = 3, cp = 0) tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control) accuracy_tune(tune_fit)
Ausgang:
## [1] 0.7990431
Mit folgendem Parameter:
minsplit = 4 minbucket= round(5/3) maxdepth = 3cp=0
Sie erhalten eine höhere Leistung als beim Vorgängermodell. Herzlichen Glückwunsch!
Zusammenfassung
Wir können die Funktionen zum Trainieren eines Entscheidungsbaumalgorithmus in R
Bibliothek | Ziel | Funktion | Klasse | Parameter | Details |
---|---|---|---|---|---|
Teil | Zugklassifizierungsbaum in R | rpart() | Klasse | Formel, df, Methode | |
Teil | Regressionsbaum trainieren | rpart() | Anova | Formel, df, Methode | |
Teil | Plotten Sie die Bäume | rpart.plot() | angepasstes Modell | ||
Base | vorhersagen | vorhersagen() | Klasse | Einbaumodell, Typ | |
Base | vorhersagen | vorhersagen() | prob | Einbaumodell, Typ | |
Base | vorhersagen | vorhersagen() | Vektor | Einbaumodell, Typ | |
Teil | Regelparameter | rpart.control() | Minsplit | Legen Sie die Mindestanzahl an Beobachtungen im Knoten fest, bevor der Algorithmus eine Aufteilung durchführt | |
Minbucket | Legen Sie die Mindestanzahl an Beobachtungen in der Schlussnotiz, also dem Blatt, fest | ||||
maximale Tiefe | Legen Sie die maximale Tiefe eines beliebigen Knotens des endgültigen Baums fest. Der Wurzelknoten wird mit der Tiefe 0 behandelt | ||||
Teil | Trainieren Sie das Modell mit Steuerparametern | rpart() | Formel, df, Methode, Kontrolle |
Hinweis: Trainieren Sie das Modell anhand von Trainingsdaten und testen Sie die Leistung anhand eines unsichtbaren Datensatzes, z. B. eines Testsatzes.