Ce sunt copacii de decizie?
Arborii de decizie sunt algoritmul de învățare automată versatil care poate îndeplini atât sarcini de clasificare, cât și sarcini de regresie. Sunt algoritmi foarte puternici, capabili să potrivească seturi de date complexe. În plus, arborii de decizie sunt componente fundamentale ale pădurilor aleatorii, care se numără printre cei mai puternici algoritmi de învățare automată disponibili astăzi.
Instruirea și vizualizarea copacilor de decizie
Pentru a construi primul dvs. arbore de decizie în exemplul R, vom proceda după cum urmează în acest tutorial Arborele de decizie:
- Pasul 1: Importați datele
- Pasul 2: Curățați setul de date
- Pasul 3: Creați trenul / setul de testare
- Pasul 4: Construiți modelul
- Pasul 5: Faceți predicții
- Pasul 6: Măsurați performanța
- Pasul 7: Acordați hiper-parametrii
Pasul 1) Importați datele
Dacă sunteți curioși de soarta titanului, puteți viziona acest videoclip pe Youtube. Scopul acestui set de date este de a prezice care sunt persoanele care sunt mai susceptibile să supraviețuiască după coliziunea cu aisbergul. Setul de date conține 13 variabile și 1309 de observații. Setul de date este ordonat de variabila X.
set.seed(678)path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'titanic <-read.csv(path)head(titanic)
Ieșire:
## X pclass survived name sex## 1 1 1 1 Allen, Miss. Elisabeth Walton female## 2 2 1 1 Allison, Master. Hudson Trevor male## 3 3 1 0 Allison, Miss. Helen Loraine female## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female## 6 6 1 1 Anderson, Mr. Harry male## age sibsp parch ticket fare cabin embarked## 1 29.0000 0 0 24160 211.3375 B5 S## 2 0.9167 1 2 113781 151.5500 C22 C26 S## 3 2.0000 1 2 113781 151.5500 C22 C26 S## 4 30.0000 1 2 113781 151.5500 C22 C26 S## 5 25.0000 1 2 113781 151.5500 C22 C26 S## 6 48.0000 0 0 19952 26.5500 E12 S## home.dest## 1 St Louis, MO## 2 Montreal, PQ / Chesterville, ON## 3 Montreal, PQ / Chesterville, ON## 4 Montreal, PQ / Chesterville, ON## 5 Montreal, PQ / Chesterville, ON## 6 New York, NY
tail(titanic)
Ieșire:
## X pclass survived name sex age sibsp## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0## parch ticket fare cabin embarked home.dest## 1304 0 2627 14.4583 C## 1305 0 2665 14.4542 C## 1306 0 2665 14.4542 C## 1307 0 2656 7.2250 C## 1308 0 2670 7.2250 C## 1309 0 315082 7.8750 S
Din ieșirea capului și a cozii, puteți observa că datele nu sunt amestecate. Aceasta este o problemă mare! Când vă veți împărți datele între un set de trenuri și un set de testare, veți selecta doar pasagerul din clasa 1 și 2 (Niciun pasager din clasa 3 nu se află în top 80 la sută din observații), ceea ce înseamnă că algoritmul nu va vedea niciodată caracteristicile pasagerului din clasa 3. Această greșeală va duce la predicții slabe.
Pentru a depăși această problemă, puteți utiliza funcția sample ().
shuffle_index <- sample(1:nrow(titanic))head(shuffle_index)
Arborele decizional Cod R Explicație
- eșantion (1: nrow (titanic)): Generați o listă aleatorie a indexului de la 1 la 1309 (adică numărul maxim de rânduri).
Ieșire:
## [1] 288 874 1078 633 887 992
Veți utiliza acest index pentru a amesteca setul de date titanic.
titanic <- titanic[shuffle_index, ]head(titanic)
Ieșire:
## X pclass survived## 288 288 1 0## 874 874 3 0## 1078 1078 3 1## 633 633 3 0## 887 887 3 1## 992 992 3 1## name sex age## 288 Sutton, Mr. Frederick male 61## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42## 1078 O'Driscoll, Miss. Bridget female NA## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39## 887 Jermyn, Miss. Annie female NA## 992 Mamee, Mr. Hanna male NA## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ## 874 0 0 348121 7.6500 F G63 S## 1078 0 0 14311 7.7500 Q## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN## 887 0 0 14313 7.7500 Q## 992 0 0 2677 7.2292 C
Pasul 2) Curățați setul de date
Structura datelor arată că unele variabile au NA. Curățarea datelor trebuie efectuată după cum urmează
- Plasați variabilele home.dest, cabină, nume, X și bilet
- Creați variabile de factor pentru pclass și a supraviețuit
- Aruncați NA
library(dplyr)# Drop variablesclean_titanic <- titanic % > %select(-c(home.dest, cabin, name, X, ticket)) % > %#Convert to factor levelmutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %na.omit()glimpse(clean_titanic)
Explicarea codului
- selectați (-c (home.dest, cabină, nume, X, bilet)): Eliminați variabilele inutile
- pclass = factor (pclass, nivele = c (1,2,3), etichete = c („Upper”, „Middle”, „Lower”)): Adăugați etichetă variabilei pclass. 1 devine superior, 2 devine MIddle și 3 devine inferior
- factor (supraviețuit, niveluri = c (0,1), etichete = c („Nu”, „Da”)): Adăugați etichetă variabilei supraviețuite. 1 Devine Nu și 2 devine Da
- na.omit (): Eliminați observațiile NA
Ieșire:
## Observations: 1,045## Variables: 8## $ pclassUpper, Lower, Lower, Upper, Middle, Upper, Middle, U… ## $ survived No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y… ## $ sex male, male, female, female, male, male, female, male… ## $ age 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0,… ## $ sibsp 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,… ## $ parch 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,… ## $ fare 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542,… ## $ embarked S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C…
Pasul 3) Creați setul de tren / test
Înainte de a vă antrena modelul, trebuie să efectuați doi pași:
- Creați un set de trenuri și teste: instruiți modelul pe setul de trenuri și testați predicția pe setul de testare (adică date nevăzute)
- Instalați rpart.plot din consolă
Practica obișnuită este de a împărți datele 80/20, 80 la sută din date servesc la instruirea modelului și 20 la sută pentru a face predicții. Trebuie să creați două cadre de date separate. Nu doriți să atingeți setul de testare până nu terminați construirea modelului. Puteți crea un nume de funcție create_train_test () care ia trei argumente.
create_train_test(df, size = 0.8, train = TRUE)arguments:-df: Dataset used to train the model.-size: Size of the split. By default, 0.8. Numerical value-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample < - 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}
Explicarea codului
- funcție (date, dimensiune = 0,8, tren = ADEVĂRAT): Adăugați argumentele în funcție
- n_row = nrow (data): Numărați numărul de rânduri din setul de date
- total_row = size * n_row: Returnează al n-lea rând pentru a construi setul de trenuri
- train_sample <- 1: total_row: Selectați primul rând până la al n-lea rând
- if (train == TRUE) {} else {}: Dacă condiția se setează la true, întoarceți setul de tren, altfel setul de testare.
Vă puteți testa funcția și puteți verifica dimensiunea.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)dim(data_train)
Ieșire:
## [1] 836 8
dim(data_test)
Ieșire:
## [1] 209 8
Setul de date de tren are 1046 de rânduri, în timp ce setul de date de testare are 262 de rânduri.
Utilizați funcția prop.table () combinată cu table () pentru a verifica dacă procesul de randomizare este corect.
prop.table(table(data_train$survived))
Ieșire:
#### No Yes## 0.5944976 0.4055024
prop.table(table(data_test$survived))
Ieșire:
#### No Yes## 0.5789474 0.4210526
În ambele seturi de date, cantitatea de supraviețuitori este aceeași, aproximativ 40%.
Instalați rpart.plot
rpart.plot nu este disponibil din bibliotecile conda. Puteți să-l instalați de pe consolă:
install.packages("rpart.plot")
Pasul 4) Construiți modelul
Sunteți gata să construiți modelul. Sintaxa pentru funcția arborelui de decizie Rpart este:
rpart(formula, data=, method='')arguments:- formula: The function to predict- data: Specifies the data frame- method:- "class" for a classification tree- "anova" for a regression tree
Folosești metoda clasei pentru că prezici o clasă.
library(rpart)library(rpart.plot)fit <- rpart(survived~., data = data_train, method = 'class')rpart.plot(fit, extra = 106
Explicarea codului
- rpart (): Funcție pentru a se potrivi modelului. Argumentele sunt:
- a supraviețuit ~ .: Formula copacilor de decizie
- date = data_train: Dataset
- method = 'class': se potrivește unui model binar
- rpart.plot (fit, extra = 106): Plotează arborele. Funcțiile suplimentare sunt setate la 101 pentru a afișa probabilitatea clasei a 2-a (utilă pentru răspunsurile binare). Puteți consulta vigneta pentru mai multe informații despre celelalte opțiuni.
Ieșire:
Începeți de la nodul rădăcină (adâncimea 0 peste 3, partea de sus a graficului):
- În partea de sus, este probabilitatea generală de supraviețuire. Arată proporția de pasageri care a supraviețuit accidentului. 41% dintre pasageri au supraviețuit.
- Acest nod întreabă dacă sexul pasagerului este bărbat. Dacă da, atunci coborâți la nodul copil stânga al rădăcinii (adâncimea 2). 63 la sută sunt bărbați cu o probabilitate de supraviețuire de 21 la sută.
- În al doilea nod, întrebați dacă pasagerul bărbat are peste 3,5 ani. Dacă da, atunci șansa de supraviețuire este de 19%.
- Continuați să mergeți așa pentru a înțelege ce caracteristici influențează probabilitatea de supraviețuire.
Rețineți că, una dintre numeroasele calități ale copacilor de decizie este că acestea necesită foarte puțină pregătire a datelor. În special, nu necesită scalare sau centrare a caracteristicilor.
În mod implicit, funcția rpart () folosește măsura de impuritate Gini pentru a împărți nota. Cu cât este mai mare coeficientul Gini, cu atât sunt mai diferite instanțe în cadrul nodului.
Pasul 5) Faceți o predicție
Puteți prezice setul de date de testare. Pentru a face o predicție, puteți utiliza funcția predict (). Sintaxa de bază a predicției pentru arborele de decizie R este:
predict(fitted_model, df, type = 'class')arguments:- fitted_model: This is the object stored after model estimation.- df: Data frame used to make the prediction- type: Type of prediction- 'class': for classification- 'prob': to compute the probability of each class- 'vector': Predict the mean response at the node level
Vrei să prezici care pasageri sunt mai predispuși să supraviețuiască după coliziunea din setul de testare. Înseamnă că veți ști printre cei 209 de pasageri care dintre ei vor supraviețui sau nu.
predict_unseen <-predict(fit, data_test, type = 'class')
Explicarea codului
- predict (fit, data_test, type = 'class'): prezice clasa (0/1) a setului de testare
Testarea pasagerului care nu a reușit și a celor care au reușit.
table_mat <- table(data_test$survived, predict_unseen)table_mat
Explicarea codului
- table (data_test $ supraviețuit, predict_unseen): Creați un tabel pentru a număra câți pasageri sunt clasificați ca supraviețuitori și au murit, comparați cu clasificarea corectă a arborelui decizional din R
Ieșire:
## predict_unseen## No Yes## No 106 15## Yes 30 58
Modelul a prezis corect 106 pasageri morți, dar a clasificat 15 supraviețuitori ca morți. Prin analogie, modelul a clasificat greșit 30 de pasageri ca supraviețuitori în timp ce s-au dovedit a fi morți.
Pasul 6) Măsurați performanța
Puteți calcula o măsură de precizie pentru sarcina de clasificare cu matricea de confuzie :
Matricea de confuzie este o alegere mai bună pentru a evalua performanța de clasificare. Ideea generală este să numărăm de câte ori sunt clasificate instanțele adevărate ca fiind false.
Fiecare rând dintr-o matrice de confuzie reprezintă o țintă reală, în timp ce fiecare coloană reprezintă o țintă prevăzută. Primul rând al acestei matrice consideră pasagerii morți (clasa False): 106 au fost corect clasificați ca morți ( Adevărat negativ ), în timp ce cel rămas a fost clasificat în mod greșit ca supraviețuitor ( Fals pozitiv ). Al doilea rând îi consideră pe supraviețuitori, clasa pozitivă a fost 58 ( Adevărat pozitiv ), în timp ce Adevărul negativ a fost 30.
Puteți calcula testul de precizie din matricea de confuzie:
Este proporția adevărat pozitiv și negativ adevărat peste suma matricei. Cu R, puteți codifica după cum urmează:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Explicarea codului
- sum (diag (table_mat)): Suma diagonalei
- sum (table_mat): Suma matricei.
Puteți imprima acuratețea setului de testare:
print(paste('Accuracy for test', accuracy_Test))
Ieșire:
## [1] "Accuracy for test 0.784688995215311"
Aveți un scor de 78% pentru setul de testare. Puteți replica același exercițiu cu setul de date de antrenament.
Pasul 7) Reglați hiper-parametrii
Arborele decizional din R are diverși parametri care controlează aspectele potrivirii. În biblioteca arborelui de decizie rpart, puteți controla parametrii utilizând funcția rpart.control (). În următorul cod, introduceți parametrii pe care îi veți regla. Puteți consulta vigneta pentru alți parametri.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)Arguments:-minsplit: Set the minimum number of observations in the node before the algorithm perform a split-minbucket: Set the minimum number of observations in the final note i.e. the leaf-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
Vom proceda după cum urmează:
- Funcția de construcție pentru a restabili precizia
- Reglați adâncimea maximă
- Reglați numărul minim de eșantioane pe care trebuie să le aibă un nod înainte de a se putea împărți
- Reglați numărul minim de eșantioane pe care trebuie să le aibă un nod frunză
Puteți scrie o funcție pentru a afișa acuratețea. Pur și simplu înfășurați codul pe care l-ați folosit înainte:
- predict: predict_unseen <- predict (fit, data_test, type = 'class')
- Produceți tabelul: table_mat <- table (data_test $ supraviețuit, predict_unseen)
- Calculați acuratețea: accurate_Test <- sum (diag (table_mat)) / sum (table_mat)
accuracy_tune <- function(fit) {predict_unseen <- predict(fit, data_test, type = 'class')table_mat <- table(data_test$survived, predict_unseen)accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_Test}
Puteți încerca să reglați parametrii și să vedeți dacă puteți îmbunătăți modelul peste valoarea implicită. Ca memento, trebuie să obțineți o precizie mai mare de 0,78
control <- rpart.control(minsplit = 4,minbucket = round(5 / 3),maxdepth = 3,cp = 0)tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)accuracy_tune(tune_fit)
Ieșire:
## [1] 0.7990431
Cu următorul parametru:
minsplit = 4minbucket= round(5/3)maxdepth = 3cp=0
Obțineți o performanță mai mare decât modelul anterior. Felicitări!
rezumat
Putem rezuma funcțiile de instruire a algoritmului arborelui decizional în R
Bibliotecă |
Obiectiv |
funcţie |
clasă |
parametrii |
Detalii |
---|---|---|---|---|---|
rpart |
Arborele de clasificare a trenului în R |
rpart () |
clasă |
formula, df, metodă | |
rpart |
Arborele de regresie al trenului |
rpart () |
anova |
formula, df, metodă | |
rpart |
Complotează copacii |
rpart.plot () |
model montat | ||
baza |
prezice |
prezice() |
clasă |
model montat, tip | |
baza |
prezice |
prezice() |
prob |
model montat, tip | |
baza |
prezice |
prezice() |
vector |
model montat, tip | |
rpart |
Parametrii de control |
rpart.control () |
minsplit |
Setați numărul minim de observații în nod înainte ca algoritmul să efectueze o împărțire |
|
minbucket |
Setați numărul minim de observații în nota finală, adică frunza |
||||
adancime maxima |
Setați adâncimea maximă a oricărui nod al arborelui final. Nodul rădăcină este tratat la o adâncime 0 |
||||
rpart |
Model de tren cu parametru de control |
rpart () |
formula, df, metodă, control |
Notă: Instruiți modelul pe date de antrenament și testați performanța pe un set de date nevăzut, adică set de testare.