GLM ใน R: โมเดลเชิงเส้นทั่วไปพร้อมตัวอย่าง
การถดถอยโลจิสติกคืออะไร?
การถดถอยโลจิสติกใช้ในการทำนายคลาส เช่น ความน่าจะเป็น การถดถอยโลจิสติกสามารถทำนายผลลัพธ์ไบนารี่ได้อย่างแม่นยำ
ลองจินตนาการว่าคุณต้องการคาดการณ์ว่าเงินกู้จะถูกปฏิเสธ/ยอมรับโดยพิจารณาจากคุณลักษณะหลายประการหรือไม่ การถดถอยโลจิสติกอยู่ในรูปแบบ 0/1 y = 0 หากเงินกู้ถูกปฏิเสธ y = 1 หากได้รับการยอมรับ
แบบจำลองการถดถอยโลจิสติกแตกต่างจากแบบจำลองการถดถอยเชิงเส้นในสองวิธี
- ประการแรก การถดถอยโลจิสติกยอมรับเฉพาะอินพุตแบบไดโคโตมัส (ไบนารี่) เป็นตัวแปรตาม (เช่น เวกเตอร์ 0 และ 1)
- ประการที่สอง ผลลัพธ์จะถูกวัดโดยฟังก์ชันลิงก์ความน่าจะเป็นต่อไปนี้เรียกว่า ซิกมอยด์ เนื่องจากมีลักษณะเป็นรูปตัว S:
ผลลัพธ์ของฟังก์ชันจะอยู่ระหว่าง 0 ถึง 1 เสมอ ตรวจสอบรูปภาพด้านล่าง
ฟังก์ชัน sigmoid ส่งคืนค่าจาก 0 ถึง 1 สำหรับงานการจัดหมวดหมู่ เราจำเป็นต้องมีเอาต์พุตแบบไม่ต่อเนื่องเป็น 0 หรือ 1
ในการแปลงการไหลต่อเนื่องเป็นค่าที่ไม่ต่อเนื่อง เราสามารถกำหนดขอบเขตการตัดสินใจไว้ที่ 0.5 ค่าทั้งหมดที่อยู่เหนือเกณฑ์นี้จัดเป็น 1
วิธีสร้าง Generalized Liner Model (GLM)
มาใช้ไฟล์ ผู้ใหญ่ ชุดข้อมูลเพื่อแสดงการถดถอยลอจิสติก “ผู้ใหญ่” เป็นชุดข้อมูลที่ยอดเยี่ยมสำหรับงานจำแนกประเภท วัตถุประสงค์คือเพื่อคาดการณ์ว่ารายได้ต่อปีในสกุลเงินดอลลาร์ของบุคคลจะเกิน 50.000 หรือไม่ ชุดข้อมูลประกอบด้วยข้อสังเกต 46,033 รายการและคุณลักษณะ XNUMX ประการ:
- อายุ: อายุของบุคคล ตัวเลข
- การศึกษา: ระดับการศึกษาของแต่ละบุคคล ปัจจัย.
- สถานะการสมรส: Mariสถานะของบุคคล ปัจจัย ได้แก่ ไม่เคยแต่งงาน, แต่งงานแล้ว-คู่สมรส, …
- เพศ: เพศของบุคคล ปัจจัย เช่น ชายหรือหญิง
- รายได้: Target ตัวแปร. รายได้สูงกว่าหรือต่ำกว่า 50K ปัจจัย เช่น >50K, <=50K
ท่ามกลางคนอื่น ๆ
library(dplyr) data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv") glimpse(data_adult)
Output:
Observations: 48,842 Variables: 10 $ x <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,... $ age <int> 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26... $ workclass <fctr> Private, Private, Local-gov, Private, ?, Private,... $ education <fctr> 11th, HS-grad, Assoc-acdm, Some-college, Some-col... $ educational.num <int> 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,... $ marital.status <fctr> Never-married, Married-civ-spouse, Married-civ-sp... $ race <fctr> Black, White, White, Black, White, White, Black, ... $ gender <fctr> Male, Male, Male, Male, Female, Male, Male, Male,... $ hours.per.week <int> 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39... $ income <fctr> <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5...
เราจะดำเนินการดังนี้:
- ขั้นตอนที่ 1: ตรวจสอบตัวแปรต่อเนื่อง
- ขั้นตอนที่ 2: ตรวจสอบตัวแปรปัจจัย
- ขั้นตอนที่ 3: วิศวกรรมคุณลักษณะ
- ขั้นตอนที่ 4: สถิติสรุป
- ขั้นตอนที่ 5: ฝึก/ชุดทดสอบ
- ขั้นตอนที่ 6: สร้างแบบจำลอง
- ขั้นตอนที่ 7: ประเมินประสิทธิภาพของแบบจำลอง
- ขั้นตอนที่ 8: ปรับปรุงโมเดล
งานของคุณคือคาดการณ์ว่าบุคคลใดจะมีรายได้สูงกว่า 50K
ในบทช่วยสอนนี้ แต่ละขั้นตอนจะมีรายละเอียดเพื่อทำการวิเคราะห์ชุดข้อมูลจริง
ขั้นตอนที่ 1) ตรวจสอบตัวแปรต่อเนื่อง
ในขั้นตอนแรก คุณจะเห็นการกระจายตัวของตัวแปรต่อเนื่อง
continuous <-select_if(data_adult, is.numeric) summary(continuous)
คำอธิบายรหัส
- ต่อเนื่อง <- select_if(data_adult, is.numeric): ใช้ฟังก์ชัน select_if() จากไลบรารี dplyr เพื่อเลือกเฉพาะคอลัมน์ตัวเลข
- summary(continuous): พิมพ์สถิติสรุป
Output:
## X age educational.num hours.per.week ## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00 ## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00 ## Median :23017 Median :37.00 Median :10.00 Median :40.00 ## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95 ## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00 ## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00
จากตารางด้านบน คุณจะเห็นได้ว่าข้อมูลมีมาตราส่วนที่ต่างกันโดยสิ้นเชิง และชั่วโมงต่อสัปดาห์มีค่าผิดปกติที่มาก (เช่น ดูที่ควอร์ไทล์สุดท้ายและค่าสูงสุด)
คุณสามารถจัดการกับมันได้โดยปฏิบัติตามสองขั้นตอน:
- 1: พล็อตการกระจายตัวของชั่วโมงต่อสัปดาห์
- 2: สร้างมาตรฐานให้กับตัวแปรต่อเนื่อง
- พล็อตการกระจาย
มาดูการกระจายของชั่วโมงต่อสัปดาห์กันอย่างใกล้ชิด
# Histogram with kernel density curve library(ggplot2) ggplot(continuous, aes(x = hours.per.week)) + geom_density(alpha = .2, fill = "#FF6666")
Output:
ตัวแปรมีค่าผิดปกติจำนวนมากและไม่มีการกระจายที่ชัดเจน คุณสามารถแก้ปัญหานี้บางส่วนได้โดยการลบ 0.01 เปอร์เซ็นต์ชั่วโมงสูงสุดในแต่ละสัปดาห์
ไวยากรณ์พื้นฐานของควอนไทล์:
quantile(variable, percentile) arguments: -variable: Select the variable in the data frame to compute the percentile -percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C, ...) - `A`,`B`,`C` and `...` are all integer from 0 to 1.
เราคำนวณเปอร์เซ็นต์ไทล์ 2 อันดับแรก
top_one_percent <- quantile(data_adult$hours.per.week, .99) top_one_percent
คำอธิบายรหัส
- quantile(data_adult$hours.per.week, .99): คำนวณค่า 99 เปอร์เซ็นต์ของเวลาทำงาน
Output:
## 99% ## 80
98 เปอร์เซ็นต์ของประชากรทำงานน้อยกว่า 80 ชั่วโมงต่อสัปดาห์
คุณสามารถยกเลิกการสังเกตเหนือเกณฑ์นี้ได้ คุณใช้ตัวกรองจาก ดีพลีร์ ห้องสมุด.
data_adult_drop <-data_adult %>% filter(hours.per.week<top_one_percent) dim(data_adult_drop)
Output:
## [1] 45537 10
- สร้างมาตรฐานให้กับตัวแปรต่อเนื่อง
คุณสามารถกำหนดมาตรฐานแต่ละคอลัมน์เพื่อปรับปรุงประสิทธิภาพได้เนื่องจากข้อมูลของคุณไม่มีขนาดเท่ากัน คุณสามารถใช้ฟังก์ชัน mutate_if จากไลบรารี dplyr ไวยากรณ์พื้นฐานคือ:
mutate_if(df, condition, funs(function)) arguments: -`df`: Data frame used to compute the function - `condition`: Statement used. Do not use parenthesis - funs(function): Return the function to apply. Do not use parenthesis for the function
คุณสามารถกำหนดมาตรฐานคอลัมน์ตัวเลขได้ดังนี้:
data_adult_rescale <- data_adult_drop % > % mutate_if(is.numeric, funs(as.numeric(scale(.)))) head(data_adult_rescale)
คำอธิบายรหัส
- mutate_if(is.numeric, funs(scale)): เงื่อนไขเป็นเพียงคอลัมน์ตัวเลขและฟังก์ชันคือมาตราส่วน
Output:
## X age workclass education educational.num ## 1 -1.732680 -1.02325949 Private 11th -1.22106443 ## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868 ## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494 ## 4 -1.732455 0.41426100 Private Some-college -0.04945081 ## 5 -1.732379 -0.34232873 Private 10th -1.61160231 ## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857 ## marital.status race gender hours.per.week income ## 1 Never-married Black Male -0.03995944 <=50K ## 2 Married-civ-spouse White Male 0.86863037 <=50K ## 3 Married-civ-spouse White Male -0.03995944 >50K ## 4 Married-civ-spouse Black Male -0.03995944 >50K ## 5 Never-married White Male -0.94854924 <=50K ## 6 Married-civ-spouse White Male -0.76683128 >50K
ขั้นตอนที่ 2) ตรวจสอบตัวแปรปัจจัย
ขั้นตอนนี้มีวัตถุประสงค์สองประการ:
- ตรวจสอบระดับในแต่ละคอลัมน์หมวดหมู่
- กำหนดระดับใหม่
เราจะแบ่งขั้นตอนนี้ออกเป็นสามส่วน:
- เลือกคอลัมน์หมวดหมู่
- เก็บแผนภูมิแท่งของแต่ละคอลัมน์ในรายการ
- พิมพ์กราฟ
เราสามารถเลือกคอลัมน์ปัจจัยด้วยรหัสด้านล่าง:
# Select categorical column factor <- data.frame(select_if(data_adult_rescale, is.factor)) ncol(factor)
คำอธิบายรหัส
- data.frame(select_if(data_adult, is.factor)): เราจัดเก็บคอลัมน์ตัวประกอบเป็นตัวประกอบในประเภทเฟรมข้อมูล ไลบรารี ggplot2 ต้องการวัตถุกรอบข้อมูล
Output:
## [1] 6
ชุดข้อมูลประกอบด้วยตัวแปรหมวดหมู่ 6 ตัว
ขั้นตอนที่สองมีความชำนาญมากขึ้น คุณต้องการสร้างแผนภูมิแท่งสำหรับแต่ละคอลัมน์ในตัวประกอบกรอบข้อมูล การทำให้กระบวนการเป็นแบบอัตโนมัติสะดวกกว่า โดยเฉพาะอย่างยิ่งในสถานการณ์ที่มีคอลัมน์จำนวนมาก
library(ggplot2) # Create graph for each column graph <- lapply(names(factor), function(x) ggplot(factor, aes(get(x))) + geom_bar() + theme(axis.text.x = element_text(angle = 90)))
คำอธิบายรหัส
- lapply() : ใช้ฟังก์ชัน lapply() เพื่อส่งผ่านฟังก์ชันในทุกคอลัมน์ของชุดข้อมูล คุณเก็บผลลัพธ์ไว้ในรายการ
- function(x): ฟังก์ชันจะถูกประมวลผลสำหรับแต่ละ x โดยที่ x คือคอลัมน์
- ggplot(factor, aes(get(x))) + geom_bar()+ theme(axis.text.x = element_text(angle = 90)): สร้างแผนภูมิแท่งถ่านสำหรับแต่ละองค์ประกอบ x หมายเหตุ หากต้องการคืนค่า x เป็นคอลัมน์ คุณต้องรวมไว้ใน get()
ขั้นตอนสุดท้ายค่อนข้างง่าย คุณต้องการพิมพ์ 6 กราฟ
# Print the graph graph
Output:
## [[1]]
## ## [[2]]
## ## [[3]]
## ## [[4]]
## ## [[5]]
## ## [[6]]
หมายเหตุ: ใช้ปุ่มถัดไปเพื่อนำทางไปยังกราฟถัดไป
ขั้นตอนที่ 3) คุณสมบัติทางวิศวกรรม
หล่อหลอมการศึกษา
จากกราฟด้านบนจะเห็นว่าการศึกษาตัวแปรมี 16 ระดับ นี่เป็นเรื่องสำคัญ และบางระดับก็มีจำนวนการสังเกตค่อนข้างน้อย หากคุณต้องการปรับปรุงปริมาณข้อมูลที่คุณจะได้รับจากตัวแปรนี้ คุณสามารถเขียนใหม่ให้อยู่ในระดับที่สูงขึ้นได้ กล่าวคือ คุณสร้างกลุ่มขนาดใหญ่ขึ้นโดยมีระดับการศึกษาใกล้เคียงกัน ตัวอย่างเช่น การศึกษาในระดับต่ำจะถูกเปลี่ยนให้อยู่ในภาวะออกจากกลางคัน ระดับการศึกษาที่สูงขึ้นจะเปลี่ยนเป็นปริญญาโท
นี่คือรายละเอียด:
ระดับเก่า | ระดับใหม่ |
---|---|
เด็กก่อนวัยเรียน | การออกกลางคัน |
10th | การออกกลางคัน |
11th | การออกกลางคัน |
12th | การออกกลางคัน |
1st-4th | การออกกลางคัน |
5th-6th | การออกกลางคัน |
7th-8th | การออกกลางคัน |
9th | การออกกลางคัน |
HS-Grad | มัธยมศึกษาตอนปลาย |
วิทยาลัยบางแห่ง | สังคม |
รศ.อ | สังคม |
รศ | สังคม |
ปริญญาตรี | ปริญญาตรี |
ปริญญาโท | ปริญญาโท |
ศ.-โรงเรียน | ปริญญาโท |
ปริญญาเอก | ปริญญาเอก |
recast_data <- data_adult_rescale % > % select(-X) % > % mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community", ifelse(education == "Bachelors", "Bachelors", ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))
คำอธิบายรหัส
- เราใช้คำกริยากลายพันธุ์จากไลบรารี dplyr เราเปลี่ยนคุณค่าของการศึกษาด้วยคำว่า ifelse
ในตารางด้านล่าง คุณสร้างสถิติสรุปเพื่อดูโดยเฉลี่ยว่าต้องใช้เวลากี่ปีในการศึกษา (ค่า z) เพื่อเข้าเรียนในระดับปริญญาตรี ปริญญาโท หรือปริญญาเอก
recast_data % > % group_by(education) % > % summarize(average_educ_year = mean(educational.num), count = n()) % > % arrange(average_educ_year)
Output:
## # A tibble: 6 x 3 ## education average_educ_year count ## <fctr> <dbl> <int> ## 1 dropout -1.76147258 5712 ## 2 HighGrad -0.43998868 14803 ## 3 Community 0.09561361 13407 ## 4 Bachelors 1.12216282 7720 ## 5 Master 1.60337381 3338 ## 6 PhD 2.29377644 557
แต่งใหม่ Mariสถานะทัล
นอกจากนี้ยังสามารถสร้างระดับที่ต่ำกว่าสำหรับสถานะการสมรสได้อีกด้วย ในโค้ดต่อไปนี้ คุณสามารถเปลี่ยนระดับได้ดังนี้:
ระดับเก่า | ระดับใหม่ |
---|---|
ไม่เคยแต่งงาน | ยังไม่แต่งงาน |
แต่งงาน-คู่สมรส-ไม่อยู่ | ยังไม่แต่งงาน |
แต่งงาน-AF-คู่สมรส | แต่งงาน |
สมรส-พลเมือง-คู่สมรส | |
แยก | แยก |
หย่า | |
แม่ม่าย | แม่ม่าย |
# Change level marry recast_data <- recast_data % > % mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))
คุณสามารถตรวจสอบจำนวนบุคคลในแต่ละกลุ่มได้
table(recast_data$marital.status)
Output:
## ## Married Not_married Separated Widow ## 21165 15359 7727 1286
ขั้นตอนที่ 4) สถิติสรุป
ถึงเวลาตรวจสอบสถิติเกี่ยวกับตัวแปรเป้าหมายของเราแล้ว ในกราฟด้านล่าง คุณนับเปอร์เซ็นต์ของบุคคลที่มีรายได้มากกว่า 50 ตามเพศ
# Plot gender income ggplot(recast_data, aes(x = gender, fill = income)) + geom_bar(position = "fill") + theme_classic()
Output:
จากนั้น ตรวจสอบว่าที่มาของบุคคลส่งผลต่อรายได้หรือไม่
# Plot origin income ggplot(recast_data, aes(x = race, fill = income)) + geom_bar(position = "fill") + theme_classic() + theme(axis.text.x = element_text(angle = 90))
Output:
จำนวนชั่วโมงการทำงานแยกตามเพศ
# box plot gender working time ggplot(recast_data, aes(x = gender, y = hours.per.week)) + geom_boxplot() + stat_summary(fun.y = mean, geom = "point", size = 3, color = "steelblue") + theme_classic()
Output:
แผนภาพกล่องยืนยันว่าการกระจายเวลาทำงานเหมาะกับกลุ่มต่างๆ ในแผนภาพกล่อง ทั้งสองเพศไม่มีการสังเกตที่เป็นเนื้อเดียวกัน
คุณสามารถตรวจสอบความหนาแน่นของเวลาทำงานรายสัปดาห์ตามประเภทการศึกษา การแจกแจงมีตัวเลือกที่แตกต่างกันมากมาย อาจอธิบายได้จากประเภทของสัญญาในสหรัฐอเมริกา
# Plot distribution working time by education ggplot(recast_data, aes(x = hours.per.week)) + geom_density(aes(color = education), alpha = 0.5) + theme_classic()
คำอธิบายรหัส
- ggplot(recast_data, aes( x= hours.per.week)): กราฟความหนาแน่นต้องการตัวแปรเพียงตัวเดียว
- geom_density(aes(color = educational), alpha =0.5): วัตถุทางเรขาคณิตที่ใช้ควบคุมความหนาแน่น
Output:
เพื่อยืนยันความคิดของคุณ คุณสามารถดำเนินการทางเดียวได้ การทดสอบ ANOVA:
anova <- aov(hours.per.week~education, recast_data) summary(anova)
Output:
## Df Sum Sq Mean Sq F value Pr(>F) ## education 5 1552 310.31 321.2 <2e-16 *** ## Residuals 45531 43984 0.97 ## --- ## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
การทดสอบ ANOVA ยืนยันความแตกต่างในค่าเฉลี่ยระหว่างกลุ่ม
แบบไม่เชิงเส้น
ก่อนที่คุณจะรันโมเดล คุณสามารถดูว่าจำนวนชั่วโมงที่ทำงานเกี่ยวข้องกับอายุหรือไม่
library(ggplot2) ggplot(recast_data, aes(x = age, y = hours.per.week)) + geom_point(aes(color = income), size = 0.5) + stat_smooth(method = 'lm', formula = y~poly(x, 2), se = TRUE, aes(color = income)) + theme_classic()
คำอธิบายรหัส
- ggplot(recast_data, aes(x = age, y = hours.per.week)): กำหนดความสวยงามของกราฟ
- geom_point(aes(color= Income), size =0.5): สร้างจุดพล็อต
- stat_smooth(): เพิ่มเส้นแนวโน้มด้วยอาร์กิวเมนต์ต่อไปนี้:
- method='lm': พล็อตค่าที่ติดตั้งถ้า การถดถอยเชิงเส้น
- สูตร = y~poly(x,2): ปรับการถดถอยพหุนามให้พอดี
- se = TRUE: เพิ่มข้อผิดพลาดมาตรฐาน
- aes(color=income): แบ่งโมเดลตามรายได้
Output:
โดยสรุป คุณสามารถทดสอบเงื่อนไขการโต้ตอบในแบบจำลองเพื่อรับผลกระทบที่ไม่เป็นเชิงเส้นระหว่างเวลาทำงานรายสัปดาห์และคุณสมบัติอื่นๆ สิ่งสำคัญคือต้องตรวจสอบว่าเวลาทำงานแตกต่างกันภายใต้สภาวะใด
ความสัมพันธ์
การตรวจสอบครั้งต่อไปคือการแสดงภาพความสัมพันธ์ระหว่างตัวแปรต่างๆ คุณแปลงประเภทระดับปัจจัยเป็นตัวเลข เพื่อให้คุณสามารถพล็อตแผนที่ความร้อนที่มีค่าสัมประสิทธิ์สหสัมพันธ์ที่คำนวณด้วยวิธี Spearman
library(GGally) # Convert data to numeric corr <- data.frame(lapply(recast_data, as.integer)) # Plot the graphggcorr(corr, method = c("pairwise", "spearman"), nbreaks = 6, hjust = 0.8, label = TRUE, label_size = 3, color = "grey50")
คำอธิบายรหัส
- data.frame(lapply(recast_data,as.integer)): แปลงข้อมูลเป็นตัวเลข
- ggcorr() พล็อตแผนที่ความร้อนด้วยอาร์กิวเมนต์ต่อไปนี้:
- วิธีการ: วิธีการคำนวณความสัมพันธ์
- nbreaks = 6: จำนวนการพัก
- hjust = 0.8: ตำแหน่งควบคุมของชื่อตัวแปรในพล็อต
- ป้าย = TRUE: เพิ่มป้ายชื่อที่กึ่งกลางของหน้าต่าง
- label_size = 3: ป้ายขนาด
- color = “grey50”): สีของฉลาก
Output:
ขั้นตอนที่ 5) ฝึก/ชุดทดสอบ
กำกับดูแลแต่อย่างใด เรียนรู้เครื่อง งานจำเป็นต้องแบ่งข้อมูลระหว่างชุดรถไฟและชุดทดสอบ คุณสามารถใช้ "ฟังก์ชัน" ที่คุณสร้างขึ้นในบทช่วยสอนการเรียนรู้ภายใต้การดูแลอื่นๆ เพื่อสร้างชุดฝึก/ชุดทดสอบ
set.seed(1234) create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample <- 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } } data_train <- create_train_test(recast_data, 0.8, train = TRUE) data_test <- create_train_test(recast_data, 0.8, train = FALSE) dim(data_train)
Output:
## [1] 36429 9
dim(data_test)
Output:
## [1] 9108 9
ขั้นตอนที่ 6) สร้างแบบจำลอง
หากต้องการดูว่าอัลกอริทึมทำงานอย่างไร คุณใช้แพ็คเกจ glm() ที่ โมเดลเชิงเส้นทั่วไป เป็นการรวบรวมโมเดล ไวยากรณ์พื้นฐานคือ:
glm(formula, data=data, family=linkfunction() Argument: - formula: Equation used to fit the model- data: dataset used - Family: - binomial: (link = "logit") - gaussian: (link = "identity") - Gamma: (link = "inverse") - inverse.gaussian: (link = "1/mu^2") - poisson: (link = "log") - quasi: (link = "identity", variance = "constant") - quasibinomial: (link = "logit") - quasipoisson: (link = "log")
คุณพร้อมที่จะประมาณแบบจำลองลอจิสติกส์เพื่อแบ่งระดับรายได้ระหว่างชุดคุณลักษณะต่างๆ
formula <- income~. logit <- glm(formula, data = data_train, family = 'binomial') summary(logit)
คำอธิบายรหัส
- สูตร <- รายได้ ~ .: สร้างโมเดลให้ลงตัว
- logit <- glm(formula, data = data_train, family = 'binomial'): ปรับโมเดลลอจิสติกส์ (family = 'binomial') ด้วยข้อมูล data_train
- summary(logit): พิมพ์ข้อมูลสรุปของโมเดล
Output:
## ## Call: ## glm(formula = formula, family = "binomial", data = data_train) ## ## Deviance Residuals: ## Min 1Q Median 3Q Max ## -2.6456 -0.5858 -0.2609 -0.0651 3.1982 ## ## Coefficients: ## Estimate Std. Error z value Pr(>|z|) ## (Intercept) 0.07882 0.21726 0.363 0.71675 ## age 0.41119 0.01857 22.146 < 2e-16 *** ## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 *** ## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 *** ## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499 ## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 *** ## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 *** ## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596 ## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 *** ## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 *** ## educationMaster 0.35651 0.06780 5.258 1.46e-07 *** ## educationPhD 0.46995 0.15772 2.980 0.00289 ** ## educationdropout -1.04974 0.21280 -4.933 8.10e-07 *** ## educational.num 0.56908 0.07063 8.057 7.84e-16 *** ## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 *** ## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 *** ## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 *** ## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117 ## raceBlack 0.07188 0.19330 0.372 0.71001 ## raceOther 0.01370 0.27695 0.049 0.96054 ## raceWhite 0.34830 0.18441 1.889 0.05894 . ## genderMale 0.08596 0.04289 2.004 0.04506 * ## hours.per.week 0.41942 0.01748 23.998 < 2e-16 *** ## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 ## ## (Dispersion parameter for binomial family taken to be 1) ## ## Null deviance: 40601 on 36428 degrees of freedom ## Residual deviance: 27041 on 36406 degrees of freedom ## AIC: 27087 ## ## Number of Fisher Scoring iterations: 6
การสรุปแบบจำลองของเราเผยให้เห็นข้อมูลที่น่าสนใจ ประสิทธิภาพของการถดถอยโลจิสติกได้รับการประเมินด้วยตัวชี้วัดหลักเฉพาะ
- AIC (เกณฑ์ข้อมูล Akaike): ซึ่งเทียบเท่ากับ R2 ในการถดถอยโลจิสติก โดยจะวัดความพอดีเมื่อมีการใช้การลงโทษกับจำนวนพารามิเตอร์ เล็กลง AIC ค่าบ่งชี้ว่าแบบจำลองนั้นใกล้เคียงกับความจริงมากขึ้น
- การเบี่ยงเบนแบบ Null: เหมาะกับโมเดลที่มีจุดตัดกันเท่านั้น ระดับความเป็นอิสระคือ n-1 เราสามารถตีความได้ว่าเป็นค่าไคสแควร์ (ค่าที่เหมาะสมแตกต่างจากการทดสอบสมมติฐานค่าจริง)
- ความเบี่ยงเบนตกค้าง: แบบจำลองที่มีตัวแปรทั้งหมด นอกจากนี้ยังตีความว่าเป็นการทดสอบสมมติฐานไคสแควร์ด้วย
- จำนวนการวนซ้ำของ Fisher Scoring: จำนวนการวนซ้ำก่อนที่จะมาบรรจบกัน
ผลลัพธ์ของฟังก์ชัน glm() จะถูกจัดเก็บไว้ในรายการ โค้ดด้านล่างแสดงรายการทั้งหมดที่มีอยู่ในตัวแปร logit ที่เราสร้างขึ้นเพื่อประเมินการถดถอยโลจิสติก
#รายการยาวมาก พิมพ์แค่ 3 องค์ประกอบแรกเท่านั้น
lapply(logit, class)[1:3]
Output:
## $coefficients ## [1] "numeric" ## ## $residuals ## [1] "numeric" ## ## $fitted.values ## [1] "numeric"
แต่ละค่าสามารถแยกออกมาได้ด้วยเครื่องหมาย $ ตามด้วยชื่อของเมตริก ตัวอย่างเช่น คุณจัดเก็บโมเดลเป็นบันทึก หากต้องการแยกเกณฑ์ AIC คุณใช้:
logit$aic
Output:
## [1] 27086.65
ขั้นตอนที่ 7) ประเมินประสิทธิภาพของแบบจำลอง
เมทริกซ์ความสับสน
เทศกาล เมทริกซ์ความสับสน เป็นตัวเลือกที่ดีกว่าในการประเมินประสิทธิภาพการจัดหมวดหมู่โดยเปรียบเทียบกับเมตริกต่างๆ ที่คุณเห็นก่อนหน้านี้ แนวคิดทั่วไปคือการนับจำนวนครั้งที่อินสแตนซ์ True ถูกจัดว่าเป็นเท็จ
ในการคำนวณเมทริกซ์ความสับสน คุณต้องมีชุดการคาดการณ์ก่อนจึงจะสามารถเปรียบเทียบกับเป้าหมายจริงได้
predict <- predict(logit, data_test, type = 'response') # confusion matrix table_mat <- table(data_test$income, predict > 0.5) table_mat
คำอธิบายรหัส
- Predict(logit,data_test, type = 'response'): คำนวณการทำนายบนชุดทดสอบ ตั้งค่า type = 'response' เพื่อคำนวณความน่าจะเป็นในการตอบสนอง
- table(data_test$income, ทำนาย > 0.5): คำนวณเมทริกซ์ความสับสน ทำนาย > 0.5 หมายความว่าจะส่งคืนค่า 1 หากความน่าจะเป็นที่คาดการณ์ไว้สูงกว่า 0.5 มิฉะนั้นจะเป็น 0
Output:
## ## FALSE TRUE ## <=50K 6310 495 ## >50K 1074 1229
แต่ละแถวในเมทริกซ์ความสับสนแสดงถึงเป้าหมายที่แท้จริง ในขณะที่แต่ละคอลัมน์แสดงถึงเป้าหมายที่คาดการณ์ไว้ แถวแรกของเมทริกซ์นี้ถือว่ารายได้ต่ำกว่า 50 (คลาสเท็จ): 6241 ถูกจำแนกอย่างถูกต้องว่าเป็นบุคคลที่มีรายได้ต่ำกว่า 50 (ลบจริง) ในขณะที่ส่วนที่เหลือจัดประเภทผิดว่าเกิน 50 (บวกเท็จ- แถวที่สองพิจารณารายได้ที่สูงกว่า 50 ชั้นที่เป็นบวกคือ 1229 (บวกจริง), ในขณะที่ ลบจริง คือ 1074
คุณสามารถคำนวณแบบจำลองได้ ความถูกต้อง โดยการรวมค่าบวกจริง + ค่าลบจริงเข้ากับการสังเกตทั้งหมด
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test
คำอธิบายรหัส
- sum(diag(table_mat)): ผลรวมของเส้นทแยงมุม
- sum(table_mat): ผลรวมของเมทริกซ์
Output:
## [1] 0.8277339
ดูเหมือนว่าแบบจำลองจะประสบปัญหาเดียว โดยประเมินค่าจำนวนผลลบลวงสูงเกินไป สิ่งนี้เรียกว่า ความขัดแย้งในการทดสอบความแม่นยำ- เราระบุว่าความแม่นยำคืออัตราส่วนของการคาดการณ์ที่ถูกต้องต่อจำนวนเคสทั้งหมด เราสามารถมีความแม่นยำค่อนข้างสูงแต่เป็นโมเดลที่ไม่มีประโยชน์ มันเกิดขึ้นเมื่อมีชนชั้นที่โดดเด่น หากคุณมองย้อนกลับไปที่เมทริกซ์ความสับสน คุณจะเห็นว่ากรณีส่วนใหญ่จัดอยู่ในประเภทเชิงลบที่แท้จริง ลองนึกภาพตอนนี้ โมเดลจำแนกคลาสทั้งหมดเป็นเชิงลบ (เช่น ต่ำกว่า 50) คุณจะมีความแม่นยำ 75 เปอร์เซ็นต์ (6718/6718+2257) แบบจำลองของคุณทำงานได้ดีขึ้นแต่ยังมีความยากลำบากในการแยกแยะระหว่างผลบวกที่แท้จริงกับผลลบที่แท้จริง
ในสถานการณ์เช่นนี้ ควรมีการวัดที่กระชับกว่านี้จะดีกว่า เราสามารถดูได้ที่:
- ความแม่นยำ=TP/(TP+FP)
- เรียกคืน=TP/(TP+FN)
ความแม่นยำเทียบกับการเรียกคืน
ความแม่นยำ ดูความแม่นยำของการทำนายเชิงบวก จำ คืออัตราส่วนของอินสแตนซ์เชิงบวกที่ตัวแยกประเภทตรวจพบอย่างถูกต้อง
คุณสามารถสร้างฟังก์ชันสองฟังก์ชันเพื่อคำนวณเมตริกทั้งสองนี้ได้
- สร้างความแม่นยำ
precision <- function(matrix) { # True positive tp <- matrix[2, 2] # false positive fp <- matrix[1, 2] return (tp / (tp + fp)) }
คำอธิบายรหัส
- mat[1,1]: ส่งกลับเซลล์แรกของคอลัมน์แรกของกรอบข้อมูล นั่นคือค่าบวกที่แท้จริง
- เสื่อ[1,2]; ส่งกลับเซลล์แรกของคอลัมน์ที่สองของกรอบข้อมูล เช่น ผลบวกลวง
recall <- function(matrix) { # true positive tp <- matrix[2, 2]# false positive fn <- matrix[2, 1] return (tp / (tp + fn)) }
คำอธิบายรหัส
- mat[1,1]: ส่งกลับเซลล์แรกของคอลัมน์แรกของกรอบข้อมูล นั่นคือค่าบวกที่แท้จริง
- เสื่อ[2,1]; ส่งกลับเซลล์ที่สองของคอลัมน์แรกของกรอบข้อมูล เช่น ผลลบลวง
คุณสามารถทดสอบการทำงานของคุณได้
prec <- precision(table_mat) prec rec <- recall(table_mat) rec
Output:
## [1] 0.712877 ## [2] 0.5336518
เมื่อแบบจำลองระบุว่าเป็นบุคคลที่อายุมากกว่า 50 ปี ถือว่าถูกต้องเพียง 54 เปอร์เซ็นต์ของคดี และสามารถอ้างสิทธิ์ในบุคคลที่สูงกว่า 50 ใน 72 เปอร์เซ็นต์ของคดีได้
คุณสามารถสร้าง คะแนนขึ้นอยู่กับความแม่นยำและการจดจำ ที่
คือค่าเฉลี่ยฮาร์มอนิกของทั้งสองหน่วยเมตริก ซึ่งหมายความว่าค่าที่ต่ำกว่าจะให้น้ำหนักมากกว่า
f1 <- 2 * ((prec * rec) / (prec + rec)) f1
Output:
## [1] 0.6103799
ความแม่นยำเทียบกับการแลกเปลี่ยนการเรียกคืน
เป็นไปไม่ได้ที่จะมีทั้งความแม่นยำและการเรียกคืนสูง
หากเราเพิ่มความแม่นยำ บุคคลที่ถูกต้องจะถูกคาดการณ์ได้ดีขึ้น แต่เราจะพลาดจำนวนมาก (การเรียกคืนน้อยกว่า) ในบางสถานการณ์ เราต้องการความแม่นยำมากกว่าการเรียกคืน มีความสัมพันธ์แบบเว้าระหว่างความแม่นยำและการจดจำ
- ลองนึกภาพ คุณต้องคาดการณ์ว่าผู้ป่วยมีโรคหรือไม่ คุณต้องการที่จะแม่นยำที่สุด
- หากคุณต้องการตรวจจับผู้ที่อาจฉ้อโกงบนท้องถนนผ่านการจดจำใบหน้า คงจะดีกว่าถ้าจับคนจำนวนมากที่ถูกระบุว่าฉ้อโกงแม้ว่าจะมีความแม่นยำต่ำก็ตาม ตำรวจจะสามารถปล่อยตัวผู้ไม่ฉ้อโกงได้
เส้นโค้ง ROC
เทศกาล ผู้รับ Operaลักษณะเฉพาะ curve เป็นอีกหนึ่งเครื่องมือทั่วไปที่ใช้กับการจำแนกไบนารี มันคล้ายกับเส้นโค้งความแม่นยำ/การเรียกคืนมาก แต่แทนที่จะวางแผนความแม่นยำกับการเรียกคืน เส้นโค้ง ROC จะแสดงอัตราบวกที่แท้จริง (เช่น การเรียกคืน) เทียบกับอัตราบวกลวง อัตราผลบวกลวงคืออัตราส่วนของอินสแตนซ์เชิงลบที่จัดประเภทไม่ถูกต้องว่าเป็นบวก มันเท่ากับ 1 ลบอัตราติดลบจริง อัตราลบที่แท้จริงเรียกอีกอย่างว่า ความจำเพาะ- ดังนั้นกราฟ ROC ความไว (การเรียกคืน) กับ 1 ความจำเพาะ
ในการพล็อตเส้นโค้ง ROC เราจำเป็นต้องติดตั้งไลบรารีชื่อ RORC เราสามารถพบได้ในคอนดา ห้องสมุด- คุณสามารถพิมพ์รหัส:
conda ติดตั้ง -cr r-rocr – ใช่
เราสามารถพล็อต ROC ด้วยฟังก์ชันการทำนาย () และประสิทธิภาพ ()
library(ROCR) ROCRpred <- prediction(predict, data_test$income) ROCRperf <- performance(ROCRpred, 'tpr', 'fpr') plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))
คำอธิบายรหัส
- การทำนาย(ทำนาย, data_test$income): ไลบรารี ROCR จำเป็นต้องสร้างออบเจ็กต์การทำนายเพื่อแปลงข้อมูลอินพุต
- ประสิทธิภาพ (ROCRpred, 'tpr', 'fpr'): ส่งคืนชุดค่าผสมทั้งสองเพื่อสร้างในกราฟ ที่นี่ tpr และ fpr ถูกสร้างขึ้น Tot พล็อตความแม่นยำและการเรียกคืนร่วมกัน ใช้ "prec", "rec"
Output:
ขั้นตอน 8) ปรับปรุงรูปแบบ
คุณสามารถลองเพิ่มความไม่เชิงเส้นให้กับโมเดลด้วยการโต้ตอบระหว่าง
- อายุและชั่วโมงต่อสัปดาห์
- เพศและชั่วโมงต่อสัปดาห์
คุณต้องใช้การทดสอบคะแนนเพื่อเปรียบเทียบทั้งสองรุ่น
formula_2 <- income~age: hours.per.week + gender: hours.per.week + . logit_2 <- glm(formula_2, data = data_train, family = 'binomial') predict_2 <- predict(logit_2, data_test, type = 'response') table_mat_2 <- table(data_test$income, predict_2 > 0.5) precision_2 <- precision(table_mat_2) recall_2 <- recall(table_mat_2) f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2)) f1_2
Output:
## [1] 0.6109181
คะแนนสูงกว่าครั้งก่อนเล็กน้อย คุณสามารถทำงานกับข้อมูลต่อไปเพื่อพยายามเอาชนะคะแนน
สรุป
เราสามารถสรุปฟังก์ชันในการฝึกการถดถอยแบบโลจิสติกได้ในตารางด้านล่างนี้:
แพ็คเกจ | วัตถุประสงค์ | ฟังก์ชัน | ข้อโต้แย้ง |
---|---|---|---|
- | สร้างชุดข้อมูลฝึก/ทดสอบ | create_train_set() | ข้อมูล ขนาด รถไฟ |
กล | ฝึกโมเดลเชิงเส้นทั่วไป | จีแอลเอ็ม() | สูตร ข้อมูล ครอบครัว* |
กล | สรุปแบบจำลอง | สรุป() | รุ่นที่ติดตั้ง |
ฐาน | ทำนายกัน | ทำนาย() | รุ่นที่ติดตั้ง, ชุดข้อมูล, ประเภท = 'การตอบสนอง' |
ฐาน | สร้างเมทริกซ์ความสับสน | โต๊ะ() | ใช่ ทำนาย() |
ฐาน | สร้างคะแนนความแม่นยำ | ผลรวม (diag (ตาราง ()) / ผลรวม (ตาราง () | |
โรซีอาร์ | สร้าง ROC : ขั้นตอนที่ 1 สร้างการทำนาย | การทำนาย() | ทำนาย (), y |
โรซีอาร์ | สร้าง ROC : ขั้นตอนที่ 2 สร้างประสิทธิภาพ | ผลงาน() | การทำนาย (), 'tpr', 'fpr' |
โรซีอาร์ | สร้าง ROC : ขั้นตอนที่ 3 พล็อตกราฟ | พล็อต() | ผลงาน() |
อื่น ๆ GLM ประเภทของรุ่นได้แก่:
– ทวินาม: (ลิงก์ = “logit”)
– เกาส์เซียน: (ลิงก์ = “ตัวตน”)
– แกมมา: (ลิงก์ = “ผกผัน”)
– inverse.gaussian: (ลิงก์ = “1/mu^2”)
– ปัวซง: (ลิงก์ = “บันทึก”)
– กึ่ง: (link = “identity”, variance = “constant”)
– กึ่งทวินาม: (link = “logit”)
– เสมือน: (link = “log”)