درخت تصمیم و جنگل تصادفی در R — راهنمای کاربردی
«درخت تصمیم» (Decision tree) یک راهکار بسیار قدرتمند بصری برای تحلیل یک سری از خروجیهای پیشبینی شده برای یک مدل مشخص است. همچنین، از این الگوریتم اغلب به عنوان مکمل (و یا حتی جایگزین) تحلیل «رگرسیون» (Regression) در تعیین اینکه چگونه یک سری از «متغیرهای توصیفی» (Explanatory Variables) یک متغیر وابسته را تحت تاثیر قرار میدهند استفاده میشود. در مثال بیان شده در این مطلب، تاثیر متغیرهای توصیفی «سن» (age)، «جنسیت» (gender)، «مایل» (miles)، «اعتبار» (debt) (منظور اعتبار کارت بانکی یا همان کارت اعتباری فرد است) و «درآمد» (income) بر «متغیر وابسته» (dependent variable) «قیمت خودرو» (car sales) با استفاده از درخت تصمیم و «جنگل تصادفی» (Random Forest) تحلیل خواهد شد. مجموعه داده مورد استفاده در این مطلب، از مسیر (+) قابل دانلود است.
مساله دستهبندی و درخت تصمیم
ابتدا، مجموعه داده بارگذاری و متغیر پاسخ ساخته میشود (که برای درخت تصمیم مورد استفاده قرار میگیرد زیرا نیاز به تبدیل فروش از متغیر عددی به طبقهای وجود دارد):
1#Set Directory and define response variable
2setwd("C:/Users/michaeljgrogan/Documents/a_documents/computing/data science/datasets")
3fullData <- read.csv("cars.csv")
4attach(fullData)
5
6fullData$response[CarSales > 24000] <- ">24000"
7fullData$response[CarSales > 1000 & CarSales <= 24000] <- ">1000 & <24000"
8fullData$response[CarSales <= 1000] <- "<1000"
9fullData$response<-as.factor(fullData$response)
10str(fullData)
سپس، دادههای آموزش و آزمون ساخته میشود (دادههایی که برای ساخت مدل مورد استفاده قرار خواهد گرفت و سپس دادههایی که برای آزمودن مدل استفاده میشوند).
1#Create training and test data
2inputData <- fullData[1:770, ] # training data
3testData <- fullData[771:963, ] # test data
سپس، درخت دستهبندی ساخته میشود.
1#Classification Tree
2library(rpart)
3formula=response~Age+Gender+Miles+Debt+Income
4dtree=rpart(formula,data=inputData,method="class",control=rpart.control(minsplit=30,cp=0.001))
5plot(dtree)
6text(dtree)
7summary(dtree)
8printcp(dtree)
9plotcp(dtree)
10printcp(dtree)
توجه به این نکته لازم است که cp value چیزی است که سایز درخت مطلوب را تعیین میکند (در ادامه مشاهده میشود که خطای نسبی X-val هنگامی که سایز درخت برابر با ۴ است کمینه میشود). بنابراین، درخت تصمیم با استفاده از متغیر dtree و در نظر گرفتن این متغیر ساخته میشود.
1> summary(dtree)
2
3Call:
4rpart(formula = formula, data = inputData, method = "class",
5 control = rpart.control(minsplit = 30, cp = 0.001))
6 n= 770
7
8 CP nsplit
91 0.496598639 0
102 0.013605442 1
113 0.008503401 6
124 0.001000000 10
13 rel error xerror
141 1.0000000 1.0000000
152 0.5034014 0.5170068
163 0.4353741 0.5646259
174 0.4013605 0.5442177
18 xstd
191 0.07418908
202 0.05630200
213 0.05854027
224 0.05759793
هرس کردن درخت
سپس، درخت تصمیم «هرس» (pruned) میشود و طی آن «گرههای» (nodes) نامناسب از درخت برای جلوگیری از «بیشبرازش» (Overfitting) دادهها حذف میشوند.
1> #Prune the Tree and Plot
2pdtree<- prune(dtree, cp=dtree$cptable[which.min(dtree$cptable[,"xerror"]),"CP"])
3plot(pdtree, uniform=TRUE,
4 main="Pruned Classification Tree For Sales")
5text(pdtree, use.n=TRUE, all=TRUE, cex=.8)
مدل اکنون با استفاده از دادههای تست مورد آزمون قرار گرفته است و میتوان مشاهده کرد که درصد دستهبندی غلط برابر با ٪۱۶.۷۵ است.
واضح است که هر چه این میزان کمتر باشد بهتر است، بنابراین این امر از آن حکایت دارد که در حال حاضر مدل در پیشبینی «دادههای واقعی» صحیحتر عمل میکند.
1> #Model Testing
2> out <- predict(pdtree)
3> table(out[1:193],testData$response)
4> response_predicted <- colnames(out)[max.col(out, ties.method = c("first"))] # predicted
5> response_input <- as.character (testData$response) # actuals
6> mean (response_input != response_predicted) # misclassification %
7[1] 0.2844156
حل مساله رگرسیون با درخت تصمیم
هنگامی که متغیرهای وابسته به جای طبقهای، عددی هستند، از درخت رگرسیون به صورت زیر استفاده میشود.
1> #Regression Tree
2fitreg <- rpart(CarSales~Age+Gender+Miles+Debt+Income,
3 method="anova", data=inputData)
4
5printcp(fitreg)
6plotcp(fitreg)
7summary(fitreg)
8par(mfrow=c(1,2))
9rsq.rpart(fitreg) # cross-validation results
1> #Regression Tree
2> fitreg <- rpart(CarSales~Age+Gender+Miles+Debt+Income,
3+ method="anova", data=inputData)
4>
5> printcp(fitreg)
6
7Regression tree:
8rpart(formula = CarSales ~ Age + Gender + Miles + Debt + Income,
9 data = inputData, method = "anova")
10
11Variables actually used in tree construction:
12[1] Age Debt Income
13
14Root node error: 6.283e+10/770 = 81597576
15
16n= 770
17
18 CP nsplit rel error
191 0.698021 0 1.00000
202 0.094038 1 0.30198
213 0.028161 2 0.20794
224 0.023332 4 0.15162
235 0.010000 5 0.12829
24 xerror xstd
251 1.00162 0.033055
262 0.30373 0.016490
273 0.21261 0.012890
284 0.18149 0.013298
295 0.14781 0.013068
30
31> plotcp(fitreg)
32> summary(fitreg)
33
34Call:
35rpart(formula = CarSales ~ Age + Gender + Miles + Debt + Income,
36 data = inputData, method = "anova")
37 n= 770
38
39 CP nsplit rel error
401 0.69802077 0 1.0000000
412 0.09403824 1 0.3019792
423 0.02816107 2 0.2079410
434 0.02333197 4 0.1516189
445 0.01000000 5 0.1282869
45 xerror xstd
461 1.0016159 0.03305536
472 0.3037301 0.01649002
483 0.2126110 0.01289041
494 0.1814939 0.01329778
505 0.1478078 0.01306756
51
52Variable importance
53 Debt Miles Income Age
54 53 23 20 4
سپس، با استفاده از کد زیر درخت رگرسیون هرس میشود.
1> #Prune the Tree
2pfitreg<- prune(fitreg, cp=fitreg$cptable[which.min(fitreg$cptable[,"xerror"]),"CP"]) # from cptable
3plot(pfitreg, uniform=TRUE,
4 main="Pruned Regression Tree for Sales")
5text(pfitreg, use.n=TRUE, all=TRUE, https://blog.faradars.org/wp-admin/post-new.php#cex=.8)
جنگل تصادفی
اگر درختهای تصمیم بسیاری وجود داشته باشند که هدف «برازش» (Fit) آنها بدون وقوع «بیشبرازش» (Overfitting) باشد، یک راهکار برای حل مساله استفاده از «جنگل تصادفی» (Random Forest) است.
یک جنگل تصادفی امکان تعیین مهمترین پیشبینها را در میان متغیرهای توصیفی با تولید درختهای تصمیم زیاد و رتبهدهی به متغیرها بر اساس اهمیت آنها فراهم میکند.
1> library(randomForest)
2> fitregforest <- randomForest(CarSales~Age+Gender+Miles+Debt+Income,data=inputData)
3> print(fitregforest) # view results
4
5Call:
6 randomForest(formula = CarSales ~ Age + Gender + Miles + Debt + Income, data = inputData)
7 Type of random forest: regression
8 Number of trees: 500
9No. of variables tried at each split: 1
10
11 Mean of squared residuals: 10341022
12 % Var explained: 87.33
13> importance(fitregforest) # importance of each predictor
14 IncNodePurity
15Age 5920357954
16Gender 187391341
17Miles 10811341575
18Debt 21813952812
19Income 12694331712
در شکل بالا می توان مشاهده کرد که اعتبار (کارت اعتباری) به عنوان مهمترین عامل علامتگذاری شده، در واقع مشتریانی که سطح اعتبار بالاتری دارند احتمال بیشتری دارد که پول بیشتری نیز برای خرید خودرو صرف کنند.
میتوان مشاهده کرد که ٪۸۷.۳۳ از تغییرات به وسیله «جنگل تصادفی» توصیف شده و خطا در تقریبا ۱۰۰ درخت کمینه شده است.
اگر نوشته بالا برای شما مفید بوده، آموزشهای زیر نیز به شما پیشنهاد میشوند:
- مجموعه آموزشهای برنامه نویسی پایتون (Python)
- مجموعه آموزشهای آمار، احتمالات و دادهکاوی
- مجموعه آموزشهای یادگیری ماشین و بازشناسی الگو
- مجموعه آموزشهای شبکههای عصبی مصنوعی
- مجموعه آموزشهای هوش محاسباتی
- آموزش برنامهنویسی R و نرمافزار R Studio
- مجموعه آموزشهای برنامه نویسی متلب (MATLAB)
^^