Ce este regresia logistică?
Regresia logistică este utilizată pentru a prezice o clasă, adică o probabilitate. Regresia logistică poate prezice cu exactitate un rezultat binar.
Imaginați-vă că doriți să preziceți dacă un împrumut este refuzat / acceptat pe baza multor atribute. Regresia logistică are forma 0/1. y = 0 dacă un împrumut este respins, y = 1 dacă este acceptat.
Un model de regresie logistică diferă de modelul de regresie liniară în două moduri.
- În primul rând, regresia logistică acceptă doar intrarea dihotomică (binară) ca variabilă dependentă (adică un vector de 0 și 1).
- În al doilea rând, rezultatul este măsurat prin următoarea funcție de legătură probabilistică numită sigmoid datorită formei sale S:
Ieșirea funcției este întotdeauna între 0 și 1. Verificați imaginea de mai jos
Funcția sigmoidă returnează valori de la 0 la 1. Pentru sarcina de clasificare, avem nevoie de o ieșire discretă de 0 sau 1.
Pentru a converti un flux continuu în valoare discretă, putem seta o decizie legată la 0,5. Toate valorile peste acest prag sunt clasificate ca 1
În acest tutorial, veți învăța
- Ce este regresia logistică?
- Cum se creează un model de linie generalizată (GLM)
- Pasul 1) Verificați variabilele continue
- Pasul 2) Verificați variabilele factorului
- Pasul 3) Ingineria caracteristicilor
- Pasul 4) Statistică sumară
- Pasul 5) Set tren / test
- Pasul 6) Construiți modelul
- Pasul 7) Evaluați performanța modelului
Cum se creează un model de linie generalizată (GLM)
Să folosim setul de date pentru adulți pentru a ilustra regresia logistică. „Adultul” este un set de date excelent pentru sarcina de clasificare. Obiectivul este de a prezice dacă venitul anual în dolari al unei persoane va depăși 50.000. Setul de date conține 46.033 de observații și zece caracteristici:
- vârstă: vârsta individului. Numeric
- educație: Nivelul educațional al individului. Factor.
- marital.status: Starea civilă a individului. Factor, adică Necăsătorit, Căsătorit-civ-soț, ...
- sex: Sexul individului. Factor, adică bărbat sau femeie
- venit: variabilă țintă. Venituri peste sau sub 50K. Factor adică> 50K, <= 50K
printre altii
library(dplyr)data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")glimpse(data_adult)
Ieșire:
Observations: 48,842Variables: 10$ x1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,… $ age 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26… $ workclass Private, Private, Local-gov, Private, ?, Private,… $ education 11th, HS-grad, Assoc-acdm, Some-college, Some-col… $ educational.num 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,… $ marital.status Never-married, Married-civ-spouse, Married-civ-sp… $ race Black, White, White, Black, White, White, Black,… $ gender Male, Male, Male, Male, Female, Male, Male, Male,… $ hours.per.week 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39… $ income <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5…
Vom proceda după cum urmează:
- Pasul 1: Verificați variabilele continue
- Pasul 2: Verificați variabilele factorului
- Pasul 3: Ingineria caracteristicilor
- Pasul 4: statistică sumară
- Pasul 5: Antrenează / testează setul
- Pasul 6: Construiți modelul
- Pasul 7: evaluați performanța modelului
- pasul 8: Îmbunătățiți modelul
Sarcina dvs. este de a prezice care persoană va avea un venit mai mare de 50K.
În acest tutorial, fiecare pas va fi detaliat pentru a efectua o analiză pe un set de date real.
Pasul 1) Verificați variabilele continue
În primul pas, puteți vedea distribuția variabilelor continue.
continuous <-select_if(data_adult, is.numeric)summary(continuous)
Explicarea codului
- continuu <- select_if (data_adult, is.numeric): Utilizați funcția select_if () din biblioteca dplyr pentru a selecta doar coloanele numerice
- rezumat (continuu): tipăriți statistica rezumatului
Ieșire:
## X age educational.num hours.per.week## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00## Median :23017 Median :37.00 Median :10.00 Median :40.00## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00
Din tabelul de mai sus, puteți vedea că datele au scări și ore total diferite.
Puteți face față urmând doi pași:
- 1: Trageți distribuția orelor.pe săptămână
- 2: Standardizați variabilele continue
- Complotați distribuția
Să ne uităm mai atent la distribuția orelor.per.saptamână
# Histogram with kernel density curvelibrary(ggplot2)ggplot(continuous, aes(x = hours.per.week)) +geom_density(alpha = .2, fill = "#FF6666")
Ieșire:
Variabila are o mulțime de valori aberante și o distribuție nu este bine definită. Puteți rezolva parțial această problemă ștergând primele 0,01% din ore pe săptămână.
Sintaxa de bază a cuantilului:
quantile(variable, percentile)arguments:-variable: Select the variable in the data frame to compute the percentile-percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C,… )- `A`,`B`,`C` and `… ` are all integer from 0 to 1.
Calculăm prima percentilă de 2%
top_one_percent <- quantile(data_adult$hours.per.week, .99)top_one_percent
Explicarea codului
- quantile (data_adult $ hours.per.week, .99): calculați valoarea celor 99% din timpul de lucru
Ieșire:
## 99%## 80
98 la sută din populație lucrează sub 80 de ore pe săptămână.
Puteți renunța la observații peste acest prag. Folosiți filtrul din biblioteca dplyr.
data_adult_drop <-data_adult %>%filter(hours.per.weekIeșire:
## [1] 45537 10
- Standardizați variabilele continue
Puteți standardiza fiecare coloană pentru a îmbunătăți performanța, deoarece datele dvs. nu au aceeași scară. Puteți utiliza funcția mutate_if din biblioteca dplyr. Sintaxa de bază este:
mutate_if(df, condition, funs(function))arguments:-`df`: Data frame used to compute the function- `condition`: Statement used. Do not use parenthesis- funs(function): Return the function to apply. Do not use parenthesis for the functionPuteți standardiza coloanele numerice după cum urmează:
data_adult_rescale <- data_adult_drop % > %mutate_if(is.numeric, funs(as.numeric(scale(.))))head(data_adult_rescale)Explicarea codului
- mutate_if (is.numeric, funs (scale)): Condiția este doar coloană numerică și funcția este scale
Ieșire:
## X age workclass education educational.num## 1 -1.732680 -1.02325949 Private 11th -1.22106443## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494## 4 -1.732455 0.41426100 Private Some-college -0.04945081## 5 -1.732379 -0.34232873 Private 10th -1.61160231## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857## marital.status race gender hours.per.week income## 1 Never-married Black Male -0.03995944 <=50K## 2 Married-civ-spouse White Male 0.86863037 <=50K## 3 Married-civ-spouse White Male -0.03995944 >50K## 4 Married-civ-spouse Black Male -0.03995944 >50K## 5 Never-married White Male -0.94854924 <=50K## 6 Married-civ-spouse White Male -0.76683128 >50KPasul 2) Verificați variabilele factorului
Acest pas are două obiective:
- Verificați nivelul din fiecare coloană categorică
- Definiți noi niveluri
Vom împărți acest pas în trei părți:
- Selectați coloanele categorice
- Stocați diagrama cu bare a fiecărei coloane într-o listă
- Imprimați graficele
Putem selecta coloanele factorului cu codul de mai jos:
# Select categorical columnfactor <- data.frame(select_if(data_adult_rescale, is.factor))ncol(factor)Explicarea codului
- data.frame (select_if (data_adult, is.factor)): stocăm coloanele factor în factor într-un tip de cadru de date. Biblioteca ggplot2 necesită un obiect cadru de date.
Ieșire:
## [1] 6Setul de date conține 6 variabile categorice
Al doilea pas este mai priceput. Doriți să reprezentați o diagramă cu bare pentru fiecare coloană în factorul cadru de date. Este mai convenabil să automatizați procesul, mai ales în situația în care există o mulțime de coloane.
library(ggplot2)# Create graph for each columngraph <- lapply(names(factor),function(x)ggplot(factor, aes(get(x))) +geom_bar() +theme(axis.text.x = element_text(angle = 90)))Explicarea codului
- lapply (): Utilizați funcția lapply () pentru a transmite o funcție în toate coloanele setului de date. Stocați ieșirea într-o listă
- funcție (x): funcția va fi procesată pentru fiecare x. Aici x sunt coloanele
- ggplot (factor, aes (get (x))) + geom_bar () + theme (axis.text.x = element_text (angle = 90)): Creați o diagramă de bare pentru fiecare element x. Notă, pentru a returna x ca coloană, trebuie să îl includeți în get ()
Ultimul pas este relativ ușor. Doriți să imprimați cele 6 grafice.
# Print the graphgraphIeșire:
## [[1]]## ## [[2]]## ## [[3]]## ## [[4]]## ## [[5]]## ## [[6]]Notă: Utilizați butonul următor pentru a naviga la următorul grafic
Pasul 3) Ingineria caracteristicilor
Educație reformată
Din graficul de mai sus, puteți vedea că variabila educație are 16 niveluri. Acest lucru este substanțial, iar unele niveluri au un număr relativ mic de observații. Dacă doriți să îmbunătățiți cantitatea de informații pe care o puteți obține de la această variabilă, o puteți reforma la nivel superior. Și anume, creați grupuri mai mari cu un nivel de educație similar. De exemplu, nivelul scăzut de educație va fi convertit în abandon. Nivelurile superioare de educație vor fi schimbate în master.
Iată detaliul:
Vechi nivel
Nivel nou
Preşcolar
renunța
Al 10-lea
Renunța
11
Renunța
Al 12-lea
Renunța
1-4
Renunța
5-6
Renunța
7-8
Renunța
9
Renunța
HS-Grad
HighGrad
Unele colegii
Comunitate
Assoc-acdm
Comunitate
Conf. Voc
Comunitate
Bachelor
Bachelor
studii de masterat
studii de masterat
Prof-scoala
studii de masterat
Doctorat
Doctorat
recast_data <- data_adult_rescale % > %select(-X) % > %mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",ifelse(education == "Bachelors", "Bachelors",ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))Explicarea codului
- Folosim verbul mutare din biblioteca dplyr. Schimbăm valorile educației cu afirmația ifelse
În tabelul de mai jos, creați o statistică sumară pentru a vedea, în medie, câți ani de educație (valoarea z) este nevoie pentru a ajunge la licență, master sau doctorat.
recast_data % > %group_by(education) % > %summarize(average_educ_year = mean(educational.num),count = n()) % > %arrange(average_educ_year)Ieșire:
## # A tibble: 6 x 3## education average_educ_year count#### 1 dropout -1.76147258 5712## 2 HighGrad -0.43998868 14803## 3 Community 0.09561361 13407## 4 Bachelors 1.12216282 7720## 5 Master 1.60337381 3338## 6 PhD 2.29377644 557 Reformarea stării civile
De asemenea, este posibil să se creeze niveluri mai scăzute pentru starea civilă. În următorul cod schimbați nivelul după cum urmează:
Vechi nivel
Nivel nou
Niciodata casatorit
Necasatorit
Căsătorit-soț-absent
Necasatorit
Căsătorit-AF-soț
Căsătorit
Căsătorit-civ-soț
Separat
Separat
Divorţat
Văduve
Văduvă
# Change level marryrecast_data <- recast_data % > %mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))Puteți verifica numărul de persoane din fiecare grup.table(recast_data$marital.status)Ieșire:
## ## Married Not_married Separated Widow## 21165 15359 7727 1286Pasul 4) Statistică sumară
Este timpul să verificăm câteva statistici despre variabilele noastre țintă. În graficul de mai jos, contorizați procentul de persoane care câștigă mai mult de 50.000, având în vedere sexul lor.
# Plot gender incomeggplot(recast_data, aes(x = gender, fill = income)) +geom_bar(position = "fill") +theme_classic()Ieșire:
Apoi, verificați dacă originea persoanei afectează câștigurile lor.
# Plot origin incomeggplot(recast_data, aes(x = race, fill = income)) +geom_bar(position = "fill") +theme_classic() +theme(axis.text.x = element_text(angle = 90))Ieșire:
Numărul de ore de lucru după sex.
# box plot gender working timeggplot(recast_data, aes(x = gender, y = hours.per.week)) +geom_boxplot() +stat_summary(fun.y = mean,geom = "point",size = 3,color = "steelblue") +theme_classic()Ieșire:
Graficul cutiei confirmă faptul că distribuția timpului de lucru se potrivește diferitelor grupuri. În graficul cutiei, ambele sexe nu au observații omogene.
Puteți verifica densitatea timpului de lucru săptămânal în funcție de tipul de educație. Distribuțiile au multe opțiuni distincte. Poate fi explicat probabil prin tipul de contract din SUA.
# Plot distribution working time by educationggplot(recast_data, aes(x = hours.per.week)) +geom_density(aes(color = education), alpha = 0.5) +theme_classic()Explicarea codului
- ggplot (recast_data, aes (x = hours.per.week)): un grafic de densitate necesită o singură variabilă
- geom_density (aes (culoare = educație), alfa = 0,5): Obiectul geometric pentru a controla densitatea
Ieșire:
Pentru a vă confirma gândurile, puteți efectua un test ANOVA unidirecțional:
anova <- aov(hours.per.week~education, recast_data)summary(anova)Ieșire:
## Df Sum Sq Mean Sq F value Pr(>F)## education 5 1552 310.31 321.2 <2e-16 ***## Residuals 45531 43984 0.97## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1Testul ANOVA confirmă diferența de medie între grupuri.
Non-liniaritatea
Înainte de a rula modelul, puteți vedea dacă numărul de ore lucrate este legat de vârstă.
library(ggplot2)ggplot(recast_data, aes(x = age, y = hours.per.week)) +geom_point(aes(color = income),size = 0.5) +stat_smooth(method = 'lm',formula = y~poly(x, 2),se = TRUE,aes(color = income)) +theme_classic()Explicarea codului
- ggplot (recast_data, aes (x = age, y = hours.per.week)): Setați estetica graficului
- geom_point (aes (culoare = venit), dimensiune = 0,5): construiți graficul de puncte
- stat_smooth (): Adăugați linia de tendință cu următoarele argumente:
- method = 'lm': Se trasează valoarea ajustată dacă regresia liniară
- formula = y ~ poli (x, 2): Se potrivește o regresie polinomială
- se = ADEVĂRAT: Adăugați eroarea standard
- aes (culoare = venit): rupeți modelul după venituri
Ieșire:
Pe scurt, puteți testa termenii de interacțiune în model pentru a prelua efectul de non-liniaritate între timpul de lucru săptămânal și alte caracteristici. Este important să detectăm în ce condiție timpul de lucru diferă.
Corelație
Următoarea verificare este de a vizualiza corelația dintre variabile. Puteți converti tipul nivelului factorului în numeric, astfel încât să puteți trasa o hartă de căldură care conține coeficientul de corelație calculat cu metoda Spearman.
library(GGally)# Convert data to numericcorr <- data.frame(lapply(recast_data, as.integer))# Plot the graphggcorr(corr,method = c("pairwise", "spearman"),nbreaks = 6,hjust = 0.8,label = TRUE,label_size = 3,color = "grey50")Explicarea codului
- data.frame (lapply (recast_data, as.integer)): convertiți datele în numerice
- ggcorr () trasează harta de căldură cu următoarele argumente:
- metoda: Metoda de calcul a corelației
- nbreaks = 6: Numărul de pauze
- hjust = 0.8: Poziția de control a numelui variabilei în complot
- etichetă = ADEVĂRAT: Adăugați etichete în centrul ferestrelor
- label_size = 3: dimensiunea etichetelor
- color = "grey50"): Culoarea etichetei
Ieșire:
Pasul 5) Set tren / test
Orice sarcină de învățare automată supravegheată necesită împărțirea datelor între un set de trenuri și un set de testare. Puteți utiliza „funcția” pe care ați creat-o în celelalte tutoriale de învățare supravegheate pentru a crea un set de trenuri / teste.
set.seed(1234)create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample <- 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}data_train <- create_train_test(recast_data, 0.8, train = TRUE)data_test <- create_train_test(recast_data, 0.8, train = FALSE)dim(data_train)Ieșire:
## [1] 36429 9dim(data_test)Ieșire:
## [1] 9108 9Pasul 6) Construiți modelul
Pentru a vedea cum funcționează algoritmul, utilizați pachetul glm (). Modelul liniar generalizat este o colecție de modele. Sintaxa de bază este:
glm(formula, data=data, family=linkfunction()Argument:- formula: Equation used to fit the model- data: dataset used- Family: - binomial: (link = "logit")- gaussian: (link = "identity")- Gamma: (link = "inverse")- inverse.gaussian: (link = "1/mu^2")- poisson: (link = "log")- quasi: (link = "identity", variance = "constant")- quasibinomial: (link = "logit")- quasipoisson: (link = "log")Sunteți gata să estimați modelul logistic pentru a împărți nivelul veniturilor între un set de caracteristici.
formula <- income~.logit <- glm(formula, data = data_train, family = 'binomial')summary(logit)Explicarea codului
- formula <- venit ~.: Creați modelul pentru a se potrivi
- logit <- glm (formula, date = data_train, family = 'binomial'): Se potrivește un model logistic (family = 'binomial') cu datele data_train.
- rezumat (logit): tipăriți rezumatul modelului
Ieșire:
#### Call:## glm(formula = formula, family = "binomial", data = data_train)## ## Deviance Residuals:## Min 1Q Median 3Q Max## -2.6456 -0.5858 -0.2609 -0.0651 3.1982#### Coefficients:## Estimate Std. Error z value Pr(>|z|)## (Intercept) 0.07882 0.21726 0.363 0.71675## age 0.41119 0.01857 22.146 < 2e-16 ***## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 ***## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 ***## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 ***## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 ***## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 ***## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 ***## educationMaster 0.35651 0.06780 5.258 1.46e-07 ***## educationPhD 0.46995 0.15772 2.980 0.00289 **## educationdropout -1.04974 0.21280 -4.933 8.10e-07 ***## educational.num 0.56908 0.07063 8.057 7.84e-16 ***## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 ***## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 ***## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 ***## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117## raceBlack 0.07188 0.19330 0.372 0.71001## raceOther 0.01370 0.27695 0.049 0.96054## raceWhite 0.34830 0.18441 1.889 0.05894 .## genderMale 0.08596 0.04289 2.004 0.04506 *## hours.per.week 0.41942 0.01748 23.998 < 2e-16 ***## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1## ## (Dispersion parameter for binomial family taken to be 1)## ## Null deviance: 40601 on 36428 degrees of freedom## Residual deviance: 27041 on 36406 degrees of freedom## AIC: 27087#### Number of Fisher Scoring iterations: 6Rezumatul modelului nostru dezvăluie informații interesante. Performanța unei regresii logistice este evaluată cu valori cheie specifice.
- AIC (Akaike Information Criteria): Acesta este echivalentul lui R2 în regresia logistică. Măsoară potrivirea atunci când se aplică o penalizare la numărul de parametri. Valorile AIC mai mici indică faptul că modelul este mai aproape de adevăr.
- Devianță nulă: se potrivește modelului numai cu interceptarea. Gradul de libertate este n-1. O putem interpreta ca o valoare Chi-pătrat (valoare potrivită diferită de testarea ipotezei valorii reale).
- Devianță reziduală: Model cu toate variabilele. De asemenea, este interpretat ca o testare a ipotezei Chi-pătrat.
- Numărul de iterații Fisher Scoring: Numărul de iterații înainte de convergență.
Ieșirea funcției glm () este stocată într-o listă. Codul de mai jos arată toate elementele disponibile în variabila logit pe care am construit-o pentru a evalua regresia logistică.
# Lista este foarte lungă, tipăriți doar primele trei elemente
lapply(logit, class)[1:3]Ieșire:
## $coefficients## [1] "numeric"#### $residuals## [1] "numeric"#### $fitted.values## [1] "numeric"Fiecare valoare poate fi extrasă cu semnul $ urmat de numele valorilor. De exemplu, ați stocat modelul ca logit. Pentru a extrage criteriile AIC, utilizați:
logit$aicIeșire:
## [1] 27086.65Pasul 7) Evaluați performanța modelului
Matricea confuziei
Matricea de confuzie este o alegere mai bună pentru a evalua performanța de clasificare în comparație cu diferite valori ați mai văzut înainte. Ideea generală este să numărăm de câte ori sunt clasificate instanțele adevărate ca fiind false.
Pentru a calcula matricea de confuzie, trebuie mai întâi să aveți un set de predicții, astfel încât acestea să poată fi comparate cu țintele reale.
predict <- predict(logit, data_test, type = 'response')# confusion matrixtable_mat <- table(data_test$income, predict > 0.5)table_matExplicarea codului
- predict (logit, data_test, type = 'response'): calculați predicția pe setul de testare. Setați tipul = „răspuns” pentru a calcula probabilitatea de răspuns.
- tabel (date_test $ venit, prezice> 0,5): calculați matricea de confuzie. predict> 0,5 înseamnă că returnează 1 dacă probabilitățile prezise sunt peste 0,5, altfel 0.
Ieșire:
#### FALSE TRUE## <=50K 6310 495## >50K 1074 1229Fiecare rând dintr-o matrice de confuzie reprezintă o țintă reală, în timp ce fiecare coloană reprezintă o țintă prevăzută. Primul rând al acestei matrice consideră venitul mai mic de 50k (clasa False): 6241 au fost corect clasificați ca indivizi cu venituri mai mici de 50k ( Adevărat negativ ), în timp ce cel rămas a fost clasificat în mod greșit ca fiind peste 50k ( False pozitive ). Al doilea rând consideră venitul peste 50k, clasa pozitivă a fost 1229 ( Adevărat pozitiv ), în timp ce Adevărul negativ a fost 1074.
Puteți calcula precizia modelului însumând adevăratul pozitiv + adevăratul negativ peste observația totală
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_TestExplicarea codului
- sum (diag (table_mat)): Suma diagonalei
- sum (table_mat): Suma matricei.
Ieșire:
## [1] 0.8277339Modelul pare să sufere de o problemă, supraestimează numărul de negative negative. Aceasta se numește paradoxul testului de precizie . Am afirmat că acuratețea este raportul dintre predicțiile corecte și numărul total de cazuri. Putem avea o precizie relativ mare, dar un model inutil. Se întâmplă atunci când există o clasă dominantă. Dacă vă uitați înapoi la matricea de confuzie, puteți vedea că majoritatea cazurilor sunt clasificate drept adevărate negative. Imaginați-vă acum, modelul a clasificat toate clasele ca fiind negative (adică mai mici de 50k). Ați avea o precizie de 75% (6718/6718 + 2257). Modelul dvs. funcționează mai bine, dar se luptă să distingă adevăratul pozitiv de adevăratul negativ.
Într-o astfel de situație, este de preferat să aveți o metrică mai concisă. Ne putem uita la:
- Precizie = TP / (TP + FP)
- Recall = TP / (TP + FN)
Precizie vs Recall
Precizia privește acuratețea predicției pozitive. Recall este raportul instanțelor pozitive care sunt detectate corect de clasificator;
Puteți construi două funcții pentru a calcula aceste două valori
- Construiți precizie
precision <- function(matrix) {# True positivetp <- matrix[2, 2]# false positivefp <- matrix[1, 2]return (tp / (tp + fp))}Explicarea codului
- mat [1,1]: Returnează prima celulă din prima coloană a cadrului de date, adică adevăratul pozitiv
- mat [1,2]; Returnează prima celulă din a doua coloană a cadrului de date, adică falsul pozitiv
recall <- function(matrix) {# true positivetp <- matrix[2, 2]# false positivefn <- matrix[2, 1]return (tp / (tp + fn))}Explicarea codului
- mat [1,1]: Returnează prima celulă din prima coloană a cadrului de date, adică adevăratul pozitiv
- mat [2,1]; Returnează a doua celulă a primei coloane a cadrului de date, adică falsul negativ
Vă puteți testa funcțiile
prec <- precision(table_mat)precrec <- recall(table_mat)recIeșire:
## [1] 0.712877## [2] 0.5336518Atunci când modelul spune că este o persoană de peste 50.000 de ori, este corectă doar în 54 la sută din cazuri și poate revendica persoane de peste 50.000 în 72 la sută din caz.
Puteți crea o medie armonică a acestor două valori, ceea ce înseamnă că dă mai multă greutate valorilor inferioare.
f1 <- 2 * ((prec * rec) / (prec + rec))f1Ieșire:
## [1] 0.6103799Compensare Precision vs Recall
Este imposibil să aveți atât o precizie ridicată, cât și o rechemare ridicată.
Dacă creștem precizia, individul corect va fi mai bine prezis, dar ne-ar fi dor de multe dintre ele (rechemare mai mică). În unele situații, preferăm o precizie mai mare decât amintirea. Există o relație concavă între precizie și rechemare.
- Imaginați-vă, trebuie să preziceți dacă un pacient are o boală. Vrei să fii cât mai precis posibil.
- Dacă trebuie să detectați potențiali oameni frauduloși pe stradă prin recunoașterea facială, ar fi mai bine să prindeți mulți oameni etichetați ca fiind frauduloși, chiar dacă precizia este scăzută. Poliția va putea elibera persoana care nu este frauduloasă.
Curba ROC
Caracteristica Receiver Operating Curba este un alt instrument comun folosit cu clasificarea binară. Este foarte asemănătoare cu curba de precizie / rechemare, dar în loc să traseze precizia versus rechemarea, curba ROC arată adevărata rată pozitivă (adică rechemarea) față de rata fals pozitivă. Rata fals pozitivă este raportul instanțelor negative care sunt incorect clasificate ca pozitive. Este egal cu un minus adevărata rată negativă. Adevărata rată negativă se mai numește și specificitate . Prin urmare, curba ROC trasează sensibilitatea (rechemarea) versus specificitatea 1
Pentru a trasa curba ROC, trebuie să instalăm o bibliotecă numită RORC. Putem găsi în biblioteca conda. Puteți introduce codul:
conda install -cr r-rocr - da
Putem parcela ROC cu funcțiile de predicție () și performanță ().
library(ROCR)ROCRpred <- prediction(predict, data_test$income)ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))Explicarea codului
- predicție (predict, data_test $ venit): biblioteca ROCR trebuie să creeze un obiect de predicție pentru a transforma datele de intrare
- performanță (ROCRpred, 'tpr', 'fpr'): returnează cele două combinații pentru a le produce în grafic. Aici se construiesc tpr și fpr. Complotați cu precizie și reamintiți împreună, folosiți „prec”, „rec”.
Ieșire:
Pasul 8) Îmbunătățiți modelul
Puteți încerca să adăugați non-liniaritate modelului cu interacțiunea dintre
- vârsta și orele.pe săptămână
- sex și ore.per.saptamana.
Trebuie să utilizați testul de scor pentru a compara ambele modele
formula_2 <- income~age: hours.per.week + gender: hours.per.week + .logit_2 <- glm(formula_2, data = data_train, family = 'binomial')predict_2 <- predict(logit_2, data_test, type = 'response')table_mat_2 <- table(data_test$income, predict_2 > 0.5)precision_2 <- precision(table_mat_2)recall_2 <- recall(table_mat_2)f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))f1_2Ieșire:
## [1] 0.6109181Scorul este puțin mai mare decât cel precedent. Puteți continua să lucrați la date și să încercați să bateți scorul.
rezumat
Putem rezuma funcția de a antrena o regresie logistică în tabelul de mai jos:
Pachet
Obiectiv
funcţie
argument
-
Creați setul de date tren / test
create_train_set ()
date, dimensiune, tren
glm
Instruiți un model liniar generalizat
glm ()
formula, date, familie *
glm
Rezumați modelul
rezumat()
model montat
baza
Faceți predicție
prezice()
model adaptat, set de date, tip = 'răspuns'
baza
Creați o matrice de confuzie
masa()
y, predict ()
baza
Creează un scor de precizie
sum (diag (table ()) / sum (table ()
ROCR
Creați ROC: Pasul 1 Creați predicție
predicție ()
predict (), y
ROCR
Creați ROC: Pasul 2 Creați performanță
performanţă()
predicție (), „tpr”, „fpr”
ROCR
Creați ROC: Pasul 3 Trasați graficul
complot ()
performanţă()
Celelalte tipuri de modele GLM sunt:
- binom: (link = "logit")
- gaussian: (link = "identitate")
- Gamma: (link = "invers")
- invers.gaussian: (link = "1 / mu 2")
- poisson: (link = "jurnal")
- cvasi: (link = "identitate", varianță = "constantă")
- cvasibinomial: (link = "logit")
- quasipoisson: (link = "jurnal")