当前位置:   article > 正文

随机森林的简单实现_idl实现随机森林

idl实现随机森林

近日听了七月天空周博的课,现在对随机森林做一个简单的实现。

随机森林(randomforest)是一种利用多个分类树对数据进行判别与分类的方法,它在对数据进行分类的同时,还可以给出各个变量(基因)的重要性评分,评估各个变量在分类中所起的作用。

随机森林是一个最近比较火的算法,它有很多的优点:

a. 在数据集上表现良好,两个随机性的引入,使得随机森林不容易陷入过拟合

b. 在当前的很多数据集上,相对其他算法有着很大的优势,两个随机性的引入,使得随机森林具有很好的抗噪声能力

c. 它能够处理很高维度(feature很多)的数据,并且不用做特征选择,对数据集的适应能力强:既能处理离散型数据,也能处理连续型数据,数据集无需规范化

d. 可生成一个Proximities矩阵 $P=(p_{ij})$,用于度量样本之间的相似性:$p_{ij}=a_{ij}/N$,其中 $a_{ij}$ 表示样本 i 和样本 j 出现在随机森林中同一个叶子结点的次数,$N$ 为随机森林中树的棵数

e. 在创建随机森林的时候,对generalization error(泛化误差)使用的是无偏估计

f. 训练速度快,可以得到变量重要性排序(两种:基于OOB误分率的增加量和基于分裂时的GINI下降量)

g. 在训练过程中,能够检测到feature间的互相影响

h. 容易做成并行化方法

i. 实现比较简单


  1. /*
  2. * RF.h
  3. *
  4. * Created on: Nov 8, 2015
  5. * Author: shenjiyi
  6. */
  7. #ifndef RF_H_
  8. #define RF_H_
  9. #include <iostream>
  10. #include <cstdlib>
  11. #include <cstring>
  12. #include <cstdio>
  13. #include <cmath>
  14. #include <algorithm>
  15. #include <cmath>
  16. #include <map>
  17. #include <set>
  18. #include <vector>
  19. using namespace std;
  20. #define MAX_RANDOM 10000
  21. #define MAXN 10000
  /// One labelled training sample.
  struct Data {
      std::vector<int> x;  ///< Feature values, indexed by feature number.
      int y;               ///< Class label (any int, including negative values).
  };
  typedef std::vector<Data> Datas;

  /// Result of one candidate split: the threshold and the two partitions.
  struct SubTreeType {
      double v;           ///< Threshold; samples with feature value > v go right.
      Datas right, left;  ///< Samples on each side of the threshold.
  };

  /// One slot of a tree stored as an implicit binary heap
  /// (children of slot i live in slots 2*i+1 and 2*i+2).
  struct TreeNodeType {
      double splitValue;  ///< Threshold of the split; meaningful only when splitFeature >= 0.
      int splitFeature;   ///< Index of the feature tested here, or -1 when the node is not split.
      int winClass;       ///< Majority label of `samp` for non-empty unsplit nodes, otherwise -1.
      Datas samp;         ///< Training samples that reached this node.
  };
  typedef std::vector<TreeNodeType> Tree;
  typedef std::vector<Tree> Forest;

  /// Training parameters bundled for createForest().
  struct Option {
      int treeNumber;  ///< Number of trees to grow.
      int bagNumber;   ///< Number of samples drawn (with replacement) per tree.
      int depth;       ///< Number of split levels; a tree has 2^(depth+1)-1 slots.
      int bestSelect;  ///< Number of random thresholds tried per node.
  };

  /// Small random forest classifier over integer features.
  ///
  /// Every node picks one random feature, tries `bestSelect` random
  /// thresholds on it and keeps the one with the lowest weighted Gini
  /// impurity. Randomness comes from rand(), so callers control
  /// reproducibility with srand().
  class RF {
  private:
      typedef Datas::size_type Index;

      Forest forest;

      /// Returns a pseudo-random index in [0, n). `n` must be non-zero.
      static Index randomIndex(Index n) {
          // Two draws are combined so that platforms where RAND_MAX is only
          // 32767 can still reach every index of a large data set.
          const unsigned long long high = static_cast<unsigned long long>(std::rand());
          const unsigned long long low = static_cast<unsigned long long>(std::rand());
          const unsigned long long wide =
              high * (static_cast<unsigned long long>(RAND_MAX) + 1ULL) + low;
          return static_cast<Index>(wide % n);
      }

      /// Returns a pseudo-random double in [0, 1).
      static double randomUnit() {
          return std::rand() / (static_cast<double>(RAND_MAX) + 1.0);
      }

      /// Computes the Gini impurity 1 - sum(p_k^2) of `data`.
      ///
      /// @param data      samples to measure
      /// @param winClass  receives the most frequent label (smallest label
      ///                  on ties); left untouched when `data` is empty
      /// @return impurity in [0, 1); 0 for an empty or pure set
      static double calGini(const Datas &data, int &winClass) {
          if (data.empty()) {
              return 0.0;
          }
          std::map<int, int> counts;
          for (const Data &sample : data) {
              ++counts[sample.y];
          }
          const double total = static_cast<double>(data.size());
          double sumSquares = 0.0;
          int winNumber = 0;
          for (const auto &entry : counts) {
              // Dividing before squaring avoids int overflow on large counts.
              const double share = entry.second / total;
              sumSquares += share * share;
              if (entry.second > winNumber) {
                  winClass = entry.first;
                  winNumber = entry.second;
              }
          }
          return 1.0 - sumSquares;
      }

      /// Returns the size-weighted Gini impurity of a two-way split.
      static double splitGini(const Datas &left, const Datas &right) {
          const Index total = left.size() + right.size();
          if (total == 0) {
              return 0.0;
          }
          int ignored = -1;
          // Each side is weighted by its FRACTION of the samples; weighting by
          // the raw count makes almost every split look worse than its parent.
          const double leftWeight = static_cast<double>(left.size()) / static_cast<double>(total);
          const double rightWeight = static_cast<double>(right.size()) / static_cast<double>(total);
          return leftWeight * calGini(left, ignored) + rightWeight * calGini(right, ignored);
      }

      /// Partitions `data` on `feature` at a random threshold interpolated
      /// between the feature values of two distinct random samples.
      ///
      /// `subTree` is overwritten. Every sample is assumed to have more than
      /// `feature` entries in `x`. With fewer than two samples no threshold
      /// can be drawn: everything goes left and `v` is -1.
      static void randomSplit(const Datas &data, int feature, SubTreeType &subTree) {
          subTree.left.clear();
          subTree.right.clear();
          subTree.v = -1;
          if (data.size() < 2) {
              subTree.left = data;
              return;
          }
          const Index a = randomIndex(data.size());
          // Draw b from the remaining n-1 slots instead of retrying until b != a.
          Index b = randomIndex(data.size() - 1);
          if (b >= a) {
              ++b;
          }
          const double s = randomUnit();
          subTree.v = s * data[a].x[feature] + (1.0 - s) * data[b].x[feature];
          for (const Data &sample : data) {
              if (sample.x[feature] > subTree.v) {
                  subTree.right.push_back(sample);
              } else {
                  subTree.left.push_back(sample);
              }
          }
      }

      /// Grows one tree from `data` into `singleTree` (previous content dropped).
      ///
      /// The feature count is taken from data[0]; all samples are assumed to
      /// have at least that many features. Empty `data` or a negative `depth`
      /// yields an empty tree.
      static void createSingleTree(const Datas &data, int depth, int bestSelect, Tree &singleTree) {
          singleTree.clear();
          if (data.empty() || depth < 0) {
              return;
          }
          if (depth > 30) {
              depth = 30;  // keeps the shifts below defined; such a tree is unbuildable anyway
          }
          const Index featureCount = data[0].x.size();
          const Index totalNodes = (Index(1) << (depth + 1)) - 1;
          const Index innerNodes = (Index(1) << depth) - 1;

          TreeNodeType blank;
          blank.splitValue = -1;
          blank.splitFeature = -1;
          blank.winClass = -1;
          singleTree.assign(totalNodes, blank);
          singleTree[0].samp = data;

          SubTreeType candidate;
          SubTreeType best;
          for (Index i = 0; i < innerNodes; ++i) {
              const Datas &samples = singleTree[i].samp;
              if (samples.size() < 2 || featureCount == 0) {
                  continue;
              }
              const int feature = static_cast<int>(randomIndex(featureCount));
              int ignored = -1;
              double bestGini = calGini(samples, ignored);
              bool found = false;
              for (int j = 0; j < bestSelect && bestGini > 0; ++j) {
                  randomSplit(samples, feature, candidate);
                  if (candidate.left.empty() || candidate.right.empty()) {
                      continue;  // a one-sided split separates nothing
                  }
                  const double gini = splitGini(candidate.left, candidate.right);
                  if (gini < bestGini) {
                      bestGini = gini;
                      // Keep the winner in its own buffer: `candidate` is
                      // overwritten by the next attempt.
                      best.v = candidate.v;
                      best.left.swap(candidate.left);
                      best.right.swap(candidate.right);
                      found = true;
                  }
              }
              if (found) {
                  singleTree[i].splitValue = best.v;
                  singleTree[i].splitFeature = feature;
                  singleTree[2 * i + 1].samp.swap(best.left);
                  singleTree[2 * i + 2].samp.swap(best.right);
              }
          }

          // Every non-empty node that was not split becomes a leaf. The test
          // uses splitFeature because -1 is a legitimate threshold value.
          for (Index i = 0; i < totalNodes; ++i) {
              TreeNodeType &node = singleTree[i];
              if (node.splitFeature < 0 && !node.samp.empty()) {
                  int majority = -1;
                  calGini(node.samp, majority);
                  node.winClass = majority;
              }
          }
      }

      /// Draws `bagNumber` samples from `data` with replacement.
      /// Returns an empty bag when `data` is empty or `bagNumber` <= 0.
      static Datas bagging(const Datas &data, int bagNumber) {
          Datas bag;
          if (data.empty() || bagNumber <= 0) {
              return bag;
          }
          bag.reserve(static_cast<Index>(bagNumber));
          for (int i = 0; i < bagNumber; ++i) {
              bag.push_back(data[randomIndex(data.size())]);
          }
          return bag;
      }

      /// Replaces the forest with `option.treeNumber` freshly grown trees.
      void createForest(const Datas &data, const Option &option) {
          forest.clear();
          if (option.treeNumber > 0) {
              forest.reserve(static_cast<Forest::size_type>(option.treeNumber));
          }
          for (int i = 0; i < option.treeNumber; ++i) {
              const Datas bag = bagging(data, option.bagNumber);
              // Grow directly inside the forest to avoid copying the whole tree.
              forest.push_back(Tree());
              createSingleTree(bag, option.depth, option.bestSelect, forest.back());
          }
      }

      /// Walks `tree` for the feature vector `x`.
      ///
      /// @param label  receives the predicted class on success
      /// @return false when no leaf is reached (empty tree, empty node, or
      ///         `x` has no entry for a feature the tree tests)
      static bool predWithTree(const Tree &tree, const std::vector<int> &x, int &label) {
          Index i = 0;
          while (i < tree.size()) {
              const TreeNodeType &node = tree[i];
              if (node.splitFeature < 0) {
                  if (node.samp.empty()) {
                      return false;
                  }
                  label = node.winClass;
                  return true;
              }
              const Index feature = static_cast<Index>(node.splitFeature);
              if (feature >= x.size()) {
                  return false;
              }
              // Same rule as training: strictly greater goes right, so a value
              // equal to the threshold follows the branch it was trained on.
              i = (x[feature] > node.splitValue) ? 2 * i + 2 : 2 * i + 1;
          }
          return false;
      }

  public:
      /// Creates an untrained classifier with an empty forest.
      RF() {}

      /// Predicts the class of `x` by majority vote over all trees.
      ///
      /// Writes one "first <label> second <votes>" line per voted label to
      /// standard output, as the original implementation did.
      ///
      /// @param x     feature vector, laid out like the training samples
      /// @param prob  receives the vote count of the winning class, or 0
      ///              when no tree voted
      /// @return the winning label (smallest label on ties), or -1 when no
      ///         tree voted; if -1 is also a real label, check `prob`
      int predWithForest(const std::vector<int> &x, int &prob) const {
          std::map<int, int> votes;
          for (Forest::size_type i = 0; i < forest.size(); ++i) {
              int label = -1;
              if (predWithTree(forest[i], x, label)) {
                  ++votes[label];
              }
          }
          int winClass = -1;
          int winNumber = 0;
          prob = 0;
          for (const auto &entry : votes) {
              std::cout << "first " << entry.first << " second " << entry.second << std::endl;
              if (entry.second > winNumber) {
                  winClass = entry.first;
                  winNumber = entry.second;
                  prob = winNumber;
              }
          }
          return winClass;
      }

      /// Dumps every tree to standard output, one line per tree, each slot as
      /// "[(feature,threshold,class)|x0,y x0,y ...]".
      void print() const {
          std::cout << "forest size " << forest.size() << std::endl;
          for (const Tree &tree : forest) {
              for (const TreeNodeType &node : tree) {
                  std::cout << "[";
                  std::cout << "(" << node.splitFeature << "," << node.splitValue << ","
                            << node.winClass << ")" << "|";
                  for (const Data &sample : node.samp) {
                      if (!sample.x.empty()) {
                          std::cout << sample.x[0];
                      }
                      std::cout << "," << sample.y << " ";
                  }
                  std::cout << "]";
              }
              std::cout << std::endl;
          }
      }

      /// Trains the forest, replacing any previously trained trees.
      ///
      /// @param data        labelled samples; all must share one feature count
      /// @param treeNumber  number of trees to grow
      /// @param bagNumber   samples drawn with replacement for each tree
      /// @param depth       number of split levels per tree
      /// @param bestSelect  random thresholds tried at every node
      void Training(const Datas &data, int treeNumber, int bagNumber, int depth, int bestSelect) {
          Option option;
          option.treeNumber = treeNumber;
          option.bagNumber = bagNumber;
          option.depth = depth;
          option.bestSelect = bestSelect;
          createForest(data, option);
      }
  };
  240. #endif /* RF_H_ */


声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/52886
推荐阅读
相关标签
  

闽ICP备14008679号