欢迎光临散文网 会员登陆 & 注册

R语言基于树的方法:决策树,随机森林,Bagging,增强树

2021-03-05 09:58 作者:拓端tecdat  | 我要投稿

原文链接:http://tecdat.cn/?p=9859

 

概观

本文是有关  基于树的  回归和分类方法的。

树方法简单易懂,但对于解释却非常有用,但就预测准确性而言,它们通常无法与最佳监督学习方法竞争。因此,我们还介绍了Bagging(自助法),随机森林和增强树。这些示例中的每一个都涉及产生多个树,然后将其合并以产生单个共识预测。我们看到,合并大量的树可以大大提高预测准确性,但代价是损失解释能力。

决策树可以应用于回归和分类问题。我们将首先考虑回归。

决策树基础:回归

我们从一个简单的例子开始:

我们预测棒球运动员的  Salary 。

结果将是一系列分裂规则。第一个分支会将数据分割  Years < 4.5 为左侧的分支,其余的为右侧。如果我们对此模型进行编码,我们会发现关系最终变得稍微复杂一些。

  1. library(tree)

  2. library(ISLR)

  3. attach(Hitters)

  4. # 删除NA数据

  5. Hitters<- na.omit(Hitters)

  6. # log转换Salary使其更正态分布

  7. hist(Hitters$Salary)

  1. Hitters$Salary <- log(Hitters$Salary)

  2. hist(Hitters$Salary)

summary(tree.fit)

  1. ##

  2. ## Regression tree:

  3. ## tree(formula = Salary ~ Hits + Years, data = Hitters)

  4. ## Number of terminal nodes:  8

  5. ## Residual mean deviance:  0.271 = 69.1 / 255

  6. ## Distribution of residuals:

  7. ##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max.

  8. ## -2.2400 -0.2980 -0.0365  0.0000  0.3230  2.1500

现在,我们讨论通过对特征空间进行分层来构建预测树。通常,有两个步骤。

  1. 找到最能分隔因变量的变量/拆分,从而产生最低的RSS。

  2. 将数据分为两个在第一个标识的节点上的叶子。

  3. 在每片叶子中,找到分隔结果的最佳变量/分割。

目标是找到最小化RSS的区域数。但是,考虑将每个可能的分区划分为J个区域在计算上是不可行的  。为此,我们采取了  自上而下的贪婪  的方法。它是自顶向下的,因为我们从所有观测值都属于一个区域的点开始。贪婪是因为在树构建过程的每个步骤中,都会在该特定步骤中选择最佳拆分,而不是向前看会在将来的某个步骤中生成更好树的拆分。

一旦创建了所有区域,我们将使用每个区域中训练观察的平均值预测给定测试观察的因变量。

剪枝

尽管上面的模型可以对训练数据产生良好的预测,但是基本的树方法可能会过度拟合数据,从而导致测试性能不佳。这是因为生成的树往往过于复杂。具有较少拆分的较小树通常以较小的偏差为代价,从而导致方差较低,易于解释且测试错误较低。实现此目的的一种可能方法是仅在每次拆分导致的RSS减少量超过某个(高)阈值时,才构建一棵树。

因此,更好的策略是生成一棵树,然后  修剪  回去以获得更好的子树。

成本复杂度剪枝算法-也称为最弱链接修剪为我们提供了解决此问题的方法。而不是考虑每个可能的子树,我们考虑由非负调整参数索引的树序列  alpha

 


  1. trees <- tree(Salary~., train)

  2. plot(trees)

  3. text(trees, pretty=0)


  1. plot(cv.trees)

似乎第7棵树的偏差最小。然后我们可以剪枝树。但是,这并不能真正剪枝模型,因此我们可以选择较小的树来改善偏差状态。这大约是在第四个分支。

  1. prune.trees <- prune.tree(trees, best=4)

  2. plot(prune.trees)

  3. text(prune.trees, pretty=0)

使用剪枝的树对测试集进行预测。

mean((yhat - test$Salary)^2)## [1] 0.3531

分类树

分类树与回归树非常相似,不同之处在于分类树用于预测定性而不是定量。

为了增长分类树,我们使用相同的递归二进制拆分,但是现在RSS不能用作拆分标准。替代方法是使用  分类错误率。虽然很直观,但事实证明,此方法对于树木生长不够敏感。

实际上,另外两种方法是可取的,尽管它们在数值上非常相似:

Gini index_是K个  类之间总方差的度量  。

如果给定类别中的训练观测值的比例都接近零或一,则__cross-entropy_的值将接近零。

修剪树时,首选这两种方法,但如果以最终修剪模型的预测精度为目标,则规则分类错误率是优选的。

为了证明这一点,我们将使用  Heart 数据集。这些数据包含AHD 303名胸痛患者的二进制结果变量  。结果被编码为  Yes 或  No 存在心脏病。

  1. dim(Heart)

  2. [1] 303 15

 

到目前为止,这是一棵非常复杂的树。让我们确定是否可以通过使用分类评分方法的交叉验证来使用修剪后的版本改善拟合度。

cv.trees

  1. ## $size

  2. ## [1] 16  9  5  3  2  1

  3. ##

  4. ## $dev

  5. ## [1] 44 45 42 41 41 81

  6. ##

  7. ## $k

  8. ## [1] -Inf  0.0  1.0  2.5  5.0 37.0

  9. ##

  10. ## $method

  11. ## [1] "misclass"

  12. ##

  13. ## attr(,"class")

  14. ## [1] "prune"         "tree.sequence"

看起来4棵分裂树的偏差最小。让我们看看这棵树是什么样子。同样,我们使用  prune.misclass 分类设置。

  1. prune.trees <- prune.misclass(trees, best=4)

  2. plot(prune.trees)

  3. text(prune.trees, pretty=0)

  1. ## Confusion Matrix and Statistics

  2. ##

  3. ##           Reference

  4. ## Prediction No Yes

  5. ##        No  72  24

  6. ##        Yes 10  45

  7. ##

  8. ##                Accuracy : 0.775

  9. ##                  95% CI : (0.7, 0.839)

  10. ##     No Information Rate : 0.543

  11. ##     P-Value [Acc > NIR] : 2.86e-09

  12. ##

  13. ##                   Kappa : 0.539

  14. ##  Mcnemar's Test P-Value : 0.0258

  15. ##

  16. ##             Sensitivity : 0.878

  17. ##             Specificity : 0.652

  18. ##          Pos Pred Value : 0.750

  19. ##          Neg Pred Value : 0.818

  20. ##              Prevalence : 0.543

  21. ##          Detection Rate : 0.477

  22. ##    Detection Prevalence : 0.636

  23. ##       Balanced Accuracy : 0.765

  24. ##

  25. ##        'Positive' Class : No

  26. ##

在这里,我们获得了约76%的精度。

那么为什么要进行拆分呢?拆分导致节点纯度提高  ,这可能会在使用测试数据时有更好的预测。

树与线性模型

最好的模型始终取决于当前的问题。如果可以通过线性模型近似该关系,则线性回归将很可能占主导地位。相反,如果我们在特征和y之间具有复杂的,高度非线性的关系,则决策树可能会胜过传统方法。

优点/缺点

优点

  • 树比线性回归更容易解释。

  • 更能反映了人类的决策。

  • 易于以图形方式显示。

  • 可以处理没有伪变量的定性预测变量。

缺点

  • 树木通常不具有与传统方法相同的预测准确性,但是,诸如  Bagging,随机森林和增强等方法  可以提高性能。

其他例子

 

树结构中实际使用的变量:“价格”、“ CompPrice”、“年龄”、“收入”、“ ShelveLoc”、“广告”,终端节点数:19,残差平均偏差:0.414 = 92/222,错误分类错误率:0.0996 = 24/241

在这里,我们看到训练误差约为9%。我们  plot() 用来显示树结构和  text() 显示节点标签。

  1. plot(sales.tree)

  2. text(sales.tree, pretty=0)

让我们看看完整的树如何处理测试数据。

  1. ## Confusion Matrix and Statistics

  2. ##

  3. ##           Reference

  4. ## Prediction High Low

  5. ##       High   56  12

  6. ##       Low    23  68

  7. ##

  8. ##                Accuracy : 0.78

  9. ##                  95% CI : (0.707, 0.842)

  10. ##     No Information Rate : 0.503

  11. ##     P-Value [Acc > NIR] : 6.28e-13

  12. ##

  13. ##                   Kappa : 0.559

  14. ##  Mcnemar's Test P-Value : 0.091

  15. ##

  16. ##             Sensitivity : 0.709

  17. ##             Specificity : 0.850

  18. ##          Pos Pred Value : 0.824

  19. ##          Neg Pred Value : 0.747

  20. ##              Prevalence : 0.497

  21. ##          Detection Rate : 0.352

  22. ##    Detection Prevalence : 0.428

  23. ##       Balanced Accuracy : 0.779

  24. ##

  25. ##        'Positive' Class : High

  26. ##

约74%的测试错误率相当不错,但是我们可以通过交叉验证来改善它。

在这里,我们看到最低的错误分类错误是模型4的。现在我们可以将树修剪为4模型。

  1. ## Confusion Matrix and Statistics

  2. ##

  3. ##           Reference

  4. ## Prediction High Low

  5. ##       High   52  20

  6. ##       Low    27  60

  7. ##

  8. ##                Accuracy : 0.704

  9. ##                  95% CI : (0.627, 0.774)

  10. ##     No Information Rate : 0.503

  11. ##     P-Value [Acc > NIR] : 2.02e-07

  12. ##

  13. ##                   Kappa : 0.408

  14. ##  Mcnemar's Test P-Value : 0.381

  15. ##

  16. ##             Sensitivity : 0.658

  17. ##             Specificity : 0.750

  18. ##          Pos Pred Value : 0.722

  19. ##          Neg Pred Value : 0.690

  20. ##              Prevalence : 0.497

  21. ##          Detection Rate : 0.327

  22. ##    Detection Prevalence : 0.453

  23. ##       Balanced Accuracy : 0.704

  24. ##

  25. ##        'Positive' Class : High

  26. ##

这并不能真正改善我们的分类,但是我们大大简化了模型。

  1. ## CART

  2. ##

  3. ## 241 samples

  4. ##  10 predictors

  5. ##   2 classes: 'High', 'Low'

  6. ##

  7. ## No pre-processing

  8. ## Resampling: Cross-Validated (10 fold)

  9. ##

  10. ## Summary of sample sizes: 217, 217, 216, 217, 217, 217, ...

  11. ##

  12. ## Resampling results across tuning parameters:

  13. ##

  14. ##   cp    ROC  Sens  Spec  ROC SD  Sens SD  Spec SD

  15. ##   0.06  0.7  0.7   0.7   0.1     0.2      0.1

  16. ##   0.1   0.6  0.7   0.6   0.2     0.2      0.2

  17. ##   0.4   0.5  0.3   0.8   0.09    0.3      0.3

  18. ##

  19. ## ROC was used to select the optimal model using  the largest value.

  20. ## The final value used for the model was cp = 0.06.

  1. ## Confusion Matrix and Statistics

  2. ##

  3. ##           Reference

  4. ## Prediction High Low

  5. ##       High   56  21

  6. ##       Low    23  59

  7. ##

  8. ##                Accuracy : 0.723

  9. ##                  95% CI : (0.647, 0.791)

  10. ##     No Information Rate : 0.503

  11. ##     P-Value [Acc > NIR] : 1.3e-08

  12. ##

  13. ##                   Kappa : 0.446

  14. ##  Mcnemar's Test P-Value : 0.88

  15. ##

  16. ##             Sensitivity : 0.709

  17. ##             Specificity : 0.738

  18. ##          Pos Pred Value : 0.727

  19. ##          Neg Pred Value : 0.720

  20. ##              Prevalence : 0.497

  21. ##          Detection Rate : 0.352

  22. ##    Detection Prevalence : 0.484

  23. ##       Balanced Accuracy : 0.723

  24. ##

  25. ##        'Positive' Class : High

  26. ##

选择了更简单的树,预测精度有所降低。

最受欢迎的见解

1.从决策树模型看员工为什么离职

2.R语言基于树的方法:决策树,随机森林

3.python中使用scikit-learn和pandas决策树

4.机器学习:在SAS中运行随机森林数据分析报告

5.R语言用随机森林和文本挖掘提高航空公司客户满意度

6.机器学习助推快时尚精准销售时间序列

7.用机器学习识别不断变化的股市状况——隐马尔可夫模型的应用

8.python机器学习:推荐系统实现(以矩阵分解来协同过滤)

9.python中用pytorch机器学习分类预测银行客户流失


R语言基于树的方法:决策树,随机森林,Bagging,增强树的评论 (共 条)

分享到微博请遵守国家法律