当前位置:   article > 正文

R语言机器学习篇——决策树_r语言决策树

r语言决策树
参考书籍:陈强.机器学习及R应用.北京:高等教育出版社,2020

决策树"算法是一种非参数方法,它本质上也是一种“近邻”方法,因此本章分别介绍运用于回归问题以及分类问题的决策树算法。

一 回归树

在本例中,使用Boston房价数据,该数据集包含1970年波士顿506个社区有关房价的14个变量,响应变量为社区房价中位数medv,下面使用rpart包估计决策树,该算法与CART非常接近。部分数据集及R代码如下所示。

  1. #估计回归树
  2. library(rpart)
  3. library(MASS)
  4. dim(Boston) #506个观测,14个变量
  5. #[1] 506 14
  6. set.seed(1)
  7. train<-sample(506,354) #随机选取354个观测值(70%)作为训练集,其余作为测试集
  8. set.seed(123)
  9. fit<-rpart(medv~.,data = Boston,subset = train)
  10. fit #结果展示如下,默认进行10折交叉验证
  11. #n= 354
  12. node), split, n, deviance, yval
  13. * denotes terminal node
  14. 1) root 354 32268.9600 22.95085
  15. 2) rm< 6.945 296 10831.8100 19.82230
  16. 4) lstat>=14.405 119 2214.7900 14.84202
  17. 8) crim>=5.76921 56 636.1371 12.04286 *
  18. 9) crim< 5.76921 63 749.8527 17.33016 *
  19. 5) lstat< 14.405 177 3681.0470 23.17062
  20. 10) rm< 6.543 138 1690.0800 21.85580 *
  21. 11) rm>=6.543 39 908.2292 27.82308 *
  22. 3) rm>=6.945 58 3754.2630 38.91724
  23. 6) rm< 7.445 33 749.6655 33.12727 *
  24. 7) rm>=7.445 25 438.0200 46.56000 *

通过结果,显示出共有11节点,而后跟”*“号则为终节点,每行输出结果的内容依次为:node(节点)、split(到节点的分裂条件),n(该节点的样本数),deviance(该节点的偏离度,对回归问题就是残差平方和),yval(该节点的预测值,即y的平均值)。但是这样的结果没有这么直观,因此可用图像的形式加以表述。

  1. #决策树图像
  2. op<-par(no.readonly = TRUE)
  3. par(mar=c(1,1,1,1)) #设置图像的英分单位
  4. plot(fit,margin = 0.1) #参数margin表示在决策树的边框留下0.1的空间
  5. text(fit) #在图像中加入文字信息
  6. par(op)

该图像的右边表示“是”,左边表示“否",如根节点的分裂条件为房间数“rm<6.945”。如果不满足此条件’则为“大宅”向右,满足条件“rm>=6.945”的大宅又可进一步细分为“rm>=7.445“的“豪宅”以及满足“rm<7.445“的—般大宅。该图还显示终节点“豪宅’,的预测均价为46.56而终节点“—般大宅”的预测均价为33.13.

数据集总有13个特征变量,但此决策树仅用了3个变量,对于树规模的合理确定,可使决策树拥有更好的泛化预测能力,可用交叉验证来确定。

  1. #确定决策树规模
  2. plotcp(fit)

在交叉验证误差图中,下方横轴为复杂性参数cp,控制对模型复杂度的惩罚力度,上方横轴为决策树规模,即终节点的数目,纵轴为交叉验证误差。

在此图中显示终节点数目为6时,交叉验证误差最低,如果使用“一个标准误(1SE)”的规则,则应选择终节点数目为5,图中虚线表示离最优cp值一个标准差的位置,此图的具体信息还可通过cptable来查看,过程如下:

  1. fit$cptable
  2. fit$cptable
  3. # CP nsplit rel error xerror xstd
  4. 1 0.54798440 0 1.0000000 1.0079393 0.09789436
  5. 2 0.15296356 1 0.4520156 0.4815081 0.04954874
  6. 3 0.07953702 2 0.2990520 0.3307847 0.03894189
  7. 4 0.03355353 3 0.2195150 0.2525102 0.03664166
  8. 5 0.02568412 4 0.1859615 0.2267945 0.03582720
  9. 6 0.01000000 5 0.1602774 0.1959728 0.03280781

其中第1列为复杂性参数CP:第2列nsplit为分裂数,也就是终节点数目减1;第3列rel error为训练集的相对误差;第4列xerror为交叉验证误差;第5列xstd为交叉验证误差的标准误。

为得到修枝后的最优模型,需提取能使交叉验证误差xerror最小化的最优复杂性参数cp,如下所示:

  1. #决策树修枝
  2. min_cp<-fit$cptable[which.min(fit$cptable[,"xerror"]),"CP"] #选出xerror最小cp值
  3. min_cp
  4. #[1] 0.01
  5. fit_best<-prune(fit,cp=min_cp) #使用修枝函数,得到最终决策树
  6. library(rpart.plot)
  7. prp(fit_best,type=2) #画出修枝后的最优决策树图像,type可设为1-5

下面,对测试集进行预测,并计算测试误差:

  1. tree.pred<-predict(fit_best,newdata = Boston[train,])
  2. y.test<-Boston[-train,"medv"]
  3. mean((tree.pred-y.test)^2)
  4. #[1] 36.2319
  5. plot(tree.pred,y.test,main = "Boston Housing")
  6. abline(0,1)

通过结果可知测试集的均方误为36.23,因此画出测试集效应变量的实际值与预测值的散点图:

(abline(0,1)表示直线“y=0+1*x”)

如图所见,此回归树模型只有6个预测值,而实际值变化较大。因此尝试用”一个标准差“的规则来预测,在修枝时,可通过设定参数cp=0.03来实现

  1. #1SE规则
  2. fit_1se<-prune(fit,cp=0.03)
  3. tree.pred.1se<-predict(fit_1se,newdata = Boston[-train,])
  4. mean((tree.pred.1se-y.test)^2)
  5. #[1] 38.17137

然而,测试集的均方误差反而上升至38.17,最后作为对比,考察线性回归(ols)模型的测试误差:

  1. #ols
  2. ols.fit<-lm(medv~.,Boston,subset = train)
  3. ols.pred<-predict(ols.fit,newdata = Boston[-train,])
  4. mean((ols.pred-y.test)^2)
  5. #[1] 27.31196

结果显示,ols回归的测试集均方误差仅为27.31,明显低于回归树的均方误差,进一步直观地展示ols的预测效果

  1. plot(ols.pred,y.test,main = "OLS Prediction")
  2. abline(0,1)

从图可见,与决策树的6个预测值相比,OLS的预测值更为多样化,故图中的散点更为紧密地围绕在45度线周围。虽然在此例中,决策树的预测效果不及线性回归,但基于决策树的随机森林却明显优于OLS。

二 分类树

在本例中,使用葡萄牙银行市场营销的数据集来演示分类树的R操作,部分数据集及操作过程如下所示:

  1. bank <- read.csv("bank-additional.csv",header = TRUE,sep=";")
  2. str(bank,vec.len=1)
  3. #'data.frame': 4119 obs. of 21 variables:
  4. $ age : int 30 39 ...
  5. $ job : chr "blue-collar" ...
  6. $ marital : chr "married" ...
  7. $ education : chr "basic.9y" ...
  8. $ default : chr "no" ...
  9. $ housing : chr "yes" ...
  10. $ loan : chr "no" ...
  11. $ contact : chr "cellular" ...
  12. $ month : chr "may" ...
  13. $ day_of_week : chr "fri" ...
  14. $ duration : int 487 346 ...
  15. $ campaign : int 2 4 ...
  16. $ pdays : int 999 999 ...
  17. $ previous : int 0 0 ...
  18. $ poutcome : chr "nonexistent" ...
  19. $ emp.var.rate : num -1.8 1.1 ...
  20. $ cons.price.idx: num 92.9 ...
  21. $ cons.conf.idx : num -46.2 -36.4 ...
  22. $ euribor3m : num 1.31 ...
  23. $ nr.employed : num 5099 ...
  24. $ y : chr "no" ...

函数str()的参数“vec.1en=1”限制作为示例的观测值个数,默认值为4。结果显示,此数据框包含4119个观测值与21个变量,其中,响应变量Y为因子(取值为”yes“或"no“)’表示在接到银行的直销电话后,客户是否会购买“银行定期存款"产品。特征变量包括客户的个人特征,比如年龄、职业

类型、婚姻状况、教育程度;经济状况,比如是否有信用违约、是否有房贷,是否有个人贷款,工作单位人数;营销状态 ,比如自上次联络以来的天数等。

特别地’特征变量duration表示自上次去电后过了多少秒,显然,在去电前,这个变量对于预测客户购买意愿毫无意义,故从数据框中去掉:

  1. bank$duration<-NULL
  2. prop.table(table(bank$y)) #考虑样本中有购买金融产品意愿的比例
  3. # no yes
  4. 0.8905074 0.1094926

结果显示,只有10.9%的客户有购买银行定期存款的意愿。下面,我们把样本随机分为两组保留1000个观测值作为测试集’而以其余3119个观测值为训练集,并使用rpart()函数估计分类树,画出交叉验证图,结果如下所示:

  1. set.seed(1)
  2. train <- sample(4119,3119) #随机选择3119组观测作为训练集
  3. library(rpart)
  4. set.seed(123)
  5. fit <- rpart(y~.,data=bank,subset=train)
  6. plotcp(fit)

使用rpart()函数无须区别分类树或回归树,因为它会根据响应变量的类型自动识别。也可用参数“method=class”指定分类树;或用参数“method=ANOVA”指定回归树。

由图可见,当分类树规模,也就是终节点数目为3时,交叉验证误差达到最小值,进一步可显示交叉验证的细节:

  1. #交叉验证细节
  2. fit$cptable
  3. # CP nsplit rel error xerror xstd
  4. 1 0.06051873 0 1.0000000 1.0000000 0.05060858
  5. 2 0.01056676 2 0.8789625 0.8904899 0.04808341
  6. 3 0.01008646 5 0.8472622 0.9769452 0.05009393
  7. 4 0.01000000 7 0.8270893 0.9769452 0.05009393
  8. min_cp <- fit$cptable[which.min(fit$cptable[,"xerror"]),"CP"]
  9. min_cp
  10. #[1] 0.01056676

结果显示,当复杂性参数CP等于0.01056676,分裂数nsplit为2(故终节点数为3)时,交叉验证误差xerror达到最小。因此估计修枝后的最优模型,并画此分类树的结构图:

  1. fit_best <- prune(fit, cp = min_cp) #最优决策树
  2. op <- par(no.readonly = TRUE)
  3. par(mar=c(1,1,1,1))
  4. plot(fit_best,uniform=TRUE,margin=0.1)
  5. text(fit_best,cex=1.5)
  6. par(op)

参数“uniform=TRUE”表示使不同节点垂直下降的高度保持一致(默认与基尼指数的下

降幅度成正比)。

图中的决策树只有3个终节点,这意味着只要给这一类客户致电即可,即工作单位人数小于5088人,而且自上次营销致电已过去12.5天的客户。

下面,在测试集中进行预测,并展示混淆矩阵:

  1. #测试集预测
  2. tree.pred <- predict(fit_best,bank[-train,],type="class")
  3. y.test <- bank[-train,"y"]
  4. (table <- table(tree.pred,y.test))
  5. # y.test
  6. tree.pred no yes
  7. no 890 87
  8. yes 6 17
  9. (accuracy <- sum(diag(table))/sum(table)) #准确率
  10. #[1] 0.907
  11. (sensitivity <- table[2,2]/(table[1,2]+table[2,2])) #灵敏度
  12. #[1] 0.1634615

结果显示,虽然预测准确率(accuracy)高达90.7%;但算法的灵敏度(sensitivity)仅有16.3%,即只能成功识别16.3%有购买意向的客户。因为无购买意向的客户占比达到89.1%,故只要猜想所有客户都不够买,即可达到89.1%的准确率。

在做以上预测时可默认以“概率大于0.5”作为预测标准。为提高算法的灵敏度,以识别更多有购买意向的潜在客户’可降低此概率门槛值,比如将“概率大于0.1”即视为有购买意向。为此’输人以下命令:

  1. tree.prob <- predict(fit_best,bank[-train,],type="prob")
  2. tree.pred <- tree.prob[,2] >= 0.1
  3. (table <- table(tree.pred,y.test))
  4. # y.test
  5. tree.pred no yes
  6. FALSE 826 59
  7. TRUE 70 45
  8. (accuracy <- sum(diag(table))/sum(table))
  9. #[1] 0.871
  10. (sensitivity <- table[2,2]/(table[1,2]+table[2,2]))
  11. #[1] 0.4326923

函数rpart()默认使用基尼系数估计分类树,下面尝试使用信息(parms=list(split="information"))作为分裂准则,结果显示所得的混淆矩阵及预测率都与基尼指数结果完全相同。

  1. #信息熵准则
  2. set.seed(123)
  3. fit <- rpart(y~.,data=bank,subset=train,parms=list(split="information"))
  4. min_cp <- fit$cptable[which.min(fit$cptable[,"xerror"]),"CP"]
  5. fit_best <- prune(fit, cp = min_cp)
  6. tree.pred <- predict(fit_best,bank[-train,],type="class")
  7. (table <- table(tree.pred,y.test))
  8. # y.test
  9. tree.pred no yes
  10. no 890 87
  11. yes 6 17
  12. (accuracy <- sum(diag(table))/sum(table))
  13. #[1] 0.907

为了避免过拟合,函数rpart()还设置了默从的参数值“minsplit=20”,表示如果节点样本数少于20即不再分裂;以及“minbucket=5”表示终节点至少应包含5个观测值。下面尝试去掉这两个限制,再次进行预测:

  1. #去限制
  2. set.seed(123)
  3. fit <- rpart(y~.,data=bank,subset=train,control=rpart.control(minsplit = 0,minbucket = 0))
  4. min_cp <- fit$cptable[which.min(fit$cptable[,"xerror"]),"CP"]
  5. fit_best <- prune(fit, cp = min_cp)
  6. tree.pred <- predict(fit_best,bank[-train,],type="class")
  7. (table <- table(tree.pred,y.test))
  8. # y.test
  9. tree.pred no yes
  10. no 890 87
  11. yes 6 17
  12. (accuracy <- sum(diag(table))/sum(table))
  13. #[1] 0.907

其中,函数rpart()的参数“control=rpart.control(minsplit = 0,minbucket = 0))”,表示既不限制节点分裂的前提条件,也不限制终节点的规模。结果显示,变动这两个参数,对于混淆矩阵与预测准确率并无影响。可能的原因是,虽然未限制节点分裂条件与终节点规模,或许导致过拟合(决策树过于枝繁叶茂),但经过交叉验证进行修枝后,模型的复杂性依然能得到合理控制。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/53085
推荐阅读
相关标签
  

闽ICP备14008679号