当前位置:   article > 正文

机器学习---决策树和随机森林代码_决策森林代码

决策森林代码

1、决策树代码

  1. 1.object ClassificationDecisionTree {
  2. 2.
  3. 3. def main(args: Array[String]): Unit = {
  4. 4. val conf = new SparkConf()
  5. 5. conf.setAppName("analysItem")
  6. 6. conf.setMaster("local[3]")
  7. 7. val sc = new SparkContext(conf)
  8. 8. val data = MLUtils.loadLibSVMFile(sc, "汽车数据样本.txt")
  9. 9. // Split the data into training and test sets (30% held out for testing)
  10. 10. val splits = data.randomSplit(Array(0.7, 0.3))
  11. 11. val (trainingData, testData) = (splits(0), splits(1))
  12. 12. //指明类别
  13. 13. val numClasses=2
  14. 14. //指定离散变量,未指明的都当作连续变量处理
  15. 15. //1,2,3,4维度进来就变成了0,1,2,3
  16. 16. //这里天气维度有3类,但是要指明4,这里是个坑,后面以此类推
  17. 17. val categoricalFeaturesInfo=Map[Int,Int](0->4,1->4,2->3,3->3)
  18. 18. //设定评判标准 "gini"/"entropy"
  19. 19. val impurity="entropy"
  20. 20. //树的最大深度,太深运算量大也没有必要 剪枝 防止模型的过拟合!!!
  21. 21. val maxDepth=3
  22. 22. //设置离散化程度,连续数据需要离散化,分成32个区间,默认其实就是32,分割的区间保证数量差不多 这个参数也可以进行剪枝
  23. 23. val maxBins=32
  24. 24. //生成模型
  25. 25. val model =DecisionTree.trainClassifier(trainingData,numClasses,categoricalFeaturesInfo,impurity,maxDepth,maxBins)
  26. 26. //测试
  27. 27. val labelAndPreds = testData.map { point =>
  28. 28. val prediction = model.predict(point.features)
  29. 29. (point.label, prediction)
  30. 30. }
  31. 31. val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
  32. 32. println("Test Error = " + testErr)
  33. 33. println("Learned classification tree model:\n" + model.toDebugString)
  34. 34.
  35. 35. }
  36. 36.}

2、随机森林代码

  1. 1.object ClassificationRandomForest {
  2. 2. def main(args: Array[String]): Unit = {
  3. 3. val conf = new SparkConf()
  4. 4. conf.setAppName("analysItem")
  5. 5. conf.setMaster("local[3]")
  6. 6. val sc = new SparkContext(conf)
  7. 7. //读取数据
  8. 8. val data = MLUtils.loadLibSVMFile(sc,"汽车数据样本.txt")
  9. 9. //将样本按73的比例分成
  10. 10. val splits = data.randomSplit(Array(0.7, 0.3))
  11. 11. val (trainingData, testData) = (splits(0), splits(1))
  12. 12. //分类数
  13. 13. val numClasses = 2
  14. 14. // categoricalFeaturesInfo 为空,意味着所有的特征为连续型变量
  15. 15. val categoricalFeaturesInfo =Map[Int, Int](0->4,1->4,2->3,3->3)
  16. 16. //树的个数
  17. 17. val numTrees = 3
  18. 18. //特征子集采样策略,auto 表示算法自主选取
  19. 19. //"auto"根据特征数量在4个中进行选择
  20. 20. // 1,all 全部特征 2,sqrt 把特征数量开根号后随机选择的 3,log2 取对数个 4,onethird 三分之一
  21. 21. val featureSubsetStrategy = "auto"
  22. 22. //纯度计算 "gini"/"entropy"
  23. 23. val impurity = "entropy"
  24. 24. //树的最大层次
  25. 25. val maxDepth = 3
  26. 26. //特征最大装箱数,即连续数据离散化的区间
  27. 27. val maxBins = 32
  28. 28. //训练随机森林分类器,trainClassifier 返回的是 RandomForestModel 对象
  29. 29. val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  30. 30. numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
  31. 31. //打印模型
  32. 32. println(model.toDebugString)
  33. 33. //保存模型
  34. 34. //model.save(sc,"汽车保险")
  35. 35. //在测试集上进行测试
  36. 36. val count = testData.map { point =>
  37. 37. val prediction = model.predict(point.features)
  38. 38. // Math.abs(prediction-point.label)
  39. 39. (prediction,point.label)
  40. 40. }.filter(r => r._1 != r._2).count()
  41. 41. println("Test Error = " + count.toDouble/testData.count().toDouble)
  42. 42. println()
  43. 43. }
  44. 44.}

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/神奇cpp/article/detail/945019
推荐阅读
相关标签
  

闽ICP备14008679号