
Statistical Learning Methods, Chapter 5 Exercises: Building Decision Trees from the Given Training Data

5.1 Build a decision tree from the training data in Table 5.1 using the information gain ratio (the C4.5 algorithm)

DTree.py implements tree building for ID3 and C4.5. For CART it only computes the Gini index and stops short of building the tree — the steps are the same, so the code can largely be reused (my progress was slow, so I never got to it). It is worth working the numbers by hand first; I did the calculation once before writing the code.
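For reference, the quantities the code computes are the standard ones from the book (D is the data set, Ck the classes, Di the subset on which feature A takes its i-th value):

    H(D)      = -Σ_k (|Ck|/|D|) · log2(|Ck|/|D|)        empirical entropy
    H(D|A)    =  Σ_i (|Di|/|D|) · H(Di)                 empirical conditional entropy
    g(D, A)   =  H(D) - H(D|A)                          information gain (ID3)
    g_R(D, A) =  g(D, A) / H_A(D)                       information gain ratio (C4.5)

where H_A(D) = -Σ_i (|Di|/|D|) · log2(|Di|/|D|) is the entropy of D with respect to the values of A itself. In the code, emp_entropy, emp_cdtl_entropy, infgain and infgainrate compute these four quantities.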

# -*- coding: utf-8 -*-
import math

C45_Flag = True    # algorithm selector flags
ID3_Flag = False

class DtreeStruct:
    def __init__(self, next_nodelist=None, Ai=None, Aivalue=None, ck=None, value=None):
        self.next_nodelist = next_nodelist  # child nodes; None marks a leaf
        self.Ai = Ai                        # feature the parent split on
        self.Aival = Aivalue                # feature value that leads to this node
        self.value = value                  # data subset stored at the node
        self.ck = ck                        # class label (leaves only)

    def Print(self):
        def Fprint(node):
            if node.next_nodelist is None:
                print("leaf node:", node.Ai, node.Aival, node.value, node.ck)
                return
            print("node:", node.Ai, node.Aival, node.ck)
            for child in node.next_nodelist:
                Fprint(child)
        print("root node:", self.Ai, self.Aival, self.ck)
        for child in self.next_nodelist:
            Fprint(child)

class DTree:
    def __init__(self, datasets, labels):
        self.tree = None
        self.datasets = datasets
        self.labels = labels
        self.GetAandC()

    def GetAandC(self):
        # C: class column name -> set of class values
        # A: feature name -> set of feature values
        self.C = {}
        self.A = {}
        self.C[self.labels[-1]] = set(line[-1] for line in self.datasets)
        for i in range(len(self.labels) - 1):
            self.A[self.labels[i]] = set(line[i] for line in self.datasets)

    def ID3CreateTree(self, epsilon, ICflag):
        # empirical entropy of the data with respect to one column
        # (usually the class column)
        def emp_entropy(data, label):
            dic = {}
            datalen = len(data)
            indx = self.labels.index(label)
            for line in data:
                dic[line[indx]] = dic.get(line[indx], 0) + 1
            return -sum((n / datalen) * math.log(n / datalen, 2) for n in dic.values())

        # empirical conditional entropy H(D|Ai)
        def emp_cdtl_entropy(data, Ai):
            dic = {}
            data_dic = {}
            c = list(self.C.keys())[0]
            indx = self.labels.index(Ai)
            datalen = len(data)
            for line in data:
                if line[indx] not in dic:
                    dic[line[indx]] = 0
                    data_dic[line[indx]] = []
                dic[line[indx]] += 1               # count of each value of Ai
                data_dic[line[indx]].append(line)  # data subset for that value
            # H(D|Ai) = sum_i |Di|/|D| * H(Di)
            return sum((dic[k] / datalen) * emp_entropy(data_dic[k], c) for k in dic)

        # information gain g(D, Ai) of every feature in A (ID3)
        def infgain(dataset, C, A):
            D = emp_entropy(dataset, C)
            return [D - emp_cdtl_entropy(dataset, Ai) for Ai in A]

        # information gain ratio g_R(D, Ai) of every feature in A (C4.5)
        def infgainrate(dataset, C, A):
            D = emp_entropy(dataset, C)
            rates = []
            for Ai in A:
                HA = emp_entropy(dataset, Ai)  # entropy of D w.r.t. the values of Ai
                rates.append((D - emp_cdtl_entropy(dataset, Ai)) / HA if HA > 0 else 0.0)
            return rates

        # majority class of a data subset
        def getCk(Data, C_key):
            cat = [p[-1] for p in Data]
            maxck = None
            for Cv in self.C[C_key]:
                if maxck is None or cat.count(Cv) > cat.count(maxck):
                    maxck = Cv
            return maxck

        # position and value of the largest gain (or gain ratio)
        def getMaxAg(ifglist):
            Max = 0.0
            Ag = 0
            for i, ifg in enumerate(ifglist):
                if ifg > Max:
                    Max = ifg
                    Ag = i
            return Max, Ag

        # the ID3 / C4.5 algorithm; the two differ only in the splitting criterion
        def id3orC45(node, Data, C, A, ICflag):
            C_key = list(self.C.keys())[0]
            # 1. all remaining samples share one class -> leaf
            if len(set(p[-1] for p in Data)) == 1:
                node.value = Data
                node.ck = Data[0][-1]
                return
            # 2. no features left -> leaf labelled with the majority class
            if not A:
                node.ck = getCk(Data, C_key)
                node.value = Data
                return
            # 3. compute the gain (ID3) or gain ratio (C4.5) of every feature
            ifglist = infgain(Data, C, A) if ICflag == ID3_Flag else infgainrate(Data, C, A)
            Max, Ag = getMaxAg(ifglist)
            # 4. best criterion below the threshold -> majority-class leaf
            if Max < epsilon:
                node.ck = getCk(Data, C_key)
                node.value = Data
                return
            # 5. split D by the values of the chosen feature ag
            ag = A[Ag]
            indx = self.labels.index(ag)   # column of ag in the raw data
            spdict = {}
            for line in Data:
                spdict.setdefault(line[indx], []).append(line)
            node.next_nodelist = []
            # 6. recurse on each subset Di with the feature set A - {ag};
            #    pass a copy so sibling branches are not affected
            subA = [a for a in A if a != ag]
            for k in spdict:
                nodei = DtreeStruct(Ai=ag, Aivalue=k)
                node.next_nodelist.append(nodei)
                id3orC45(nodei, spdict[k], C, subA, ICflag)

        C = list(self.C.keys())[0]
        A = list(self.A.keys())
        self.tree = DtreeStruct()
        id3orC45(self.tree, self.datasets, C, A, ICflag)

    # CART: only the Gini index computation; the tree-building step would
    # follow the same recursion as id3orC45 and is not implemented here
    def CART(self):
        # Gini(D) = 1 - sum_k (|Ck|/|D|)^2, over the class column
        def Gini(data):
            if not data:
                return 0.0
            dic = {}
            for line in data:
                dic[line[-1]] = dic.get(line[-1], 0) + 1
            return 1 - sum((n / len(data)) ** 2 for n in dic.values())

        # Gini(D, Ai = a) for every value a of feature Ai: split into the
        # "== a" and "!= a" halves and weight their Gini indices by size
        def ContialGini(data, Ai):
            val = []
            indx = self.labels.index(Ai)
            datalen = len(data)
            for aival in self.A[Ai]:
                mydata = [line for line in data if line[indx] == aival]
                otherdata = [line for line in data if line[indx] != aival]
                val.append(len(mydata) / datalen * Gini(mydata)
                           + len(otherdata) / datalen * Gini(otherdata))
            return val

        def AGini(data, A):
            return [ContialGini(data, Ai) for Ai in A]

        print(AGini(self.datasets, self.A))
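Since only the Gini computation is implemented, here is a minimal sketch of what the missing CART classification-tree step could look like, following the same recursion as id3orC45. The helper names gini, best_gini_split and cart_tree are mine, not part of DTree.py, and the sketch assumes the same datasets/labels layout as in main.py below.

# A sketch only: pick the (feature, value) pair with the smallest weighted
# Gini index, split into "== value" / "!= value" halves, and recurse.
def gini(data):
    if not data:
        return 0.0
    counts = {}
    for line in data:
        counts[line[-1]] = counts.get(line[-1], 0) + 1
    return 1 - sum((n / len(data)) ** 2 for n in counts.values())

def best_gini_split(data, labels, features):
    best = None   # (gini value, feature, feature value)
    for f in features:
        idx = labels.index(f)
        for v in set(line[idx] for line in data):
            left = [l for l in data if l[idx] == v]
            right = [l for l in data if l[idx] != v]
            g = len(left) / len(data) * gini(left) + len(right) / len(data) * gini(right)
            if best is None or g < best[0]:
                best = (g, f, v)
    return best

def cart_tree(data, labels, features):
    classes = [line[-1] for line in data]
    # leaf: the node is pure, or no features remain -> majority class
    if len(set(classes)) == 1 or not features:
        return max(set(classes), key=classes.count)
    g, f, v = best_gini_split(data, labels, features)
    idx = labels.index(f)
    # for simplicity the chosen feature is dropped in both branches,
    # although CART proper allows a feature to be reused
    rest = [x for x in features if x != f]
    return {(f, '==', v): cart_tree([l for l in data if l[idx] == v], labels, rest),
            (f, '!=', v): cart_tree([l for l in data if l[idx] != v], labels, rest)}

Called as cart_tree(datasets, labels, labels[:-1]) it returns a nested dict keyed by (feature, relation, value); a real implementation would also carry a stopping threshold and the pruning step of Section 5.5.2.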

main.py

# -*- coding: utf-8 -*-
import DTree

def create_data():
    # Table 5.1 from the book. Columns: 年龄 (age), 有工作 (has a job),
    # 有自己的房子 (owns a house), 信贷情况 (credit rating), 类别 (class).
    datasets = [['青年', '否', '否', '一般', '否'],
                ['青年', '否', '否', '好', '否'],
                ['青年', '是', '否', '好', '是'],
                ['青年', '是', '是', '一般', '是'],
                ['青年', '否', '否', '一般', '否'],
                ['中年', '否', '否', '一般', '否'],
                ['中年', '否', '否', '好', '否'],
                ['中年', '是', '是', '好', '是'],
                ['中年', '否', '是', '非常好', '是'],
                ['中年', '否', '是', '非常好', '是'],
                ['老年', '否', '是', '非常好', '是'],
                ['老年', '否', '是', '好', '是'],
                ['老年', '是', '否', '好', '是'],
                ['老年', '是', '否', '非常好', '是'],
                ['老年', '否', '否', '一般', '否'],
                ]
    labels = [u'年龄', u'有工作', u'有自己的房子', u'信贷情况', u'类别']
    # return the data set and the name of each column
    return datasets, labels

def main():
    datasets, labels = create_data()
    dtree = DTree.DTree(datasets, labels)
    dtree.ID3CreateTree(0.1, DTree.ID3_Flag)
    dtree.tree.Print()

if __name__ == '__main__':
    main()
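To answer 5.1 with the information gain ratio instead of the plain gain, pass the other flag:

    dtree.ID3CreateTree(0.1, DTree.C45_Flag)

If my hand computation is right, C4.5 makes the same choices as ID3 on Table 5.1: 有自己的房子 (owns a house) at the root (gain ratio ≈ 0.433, the largest) and 有工作 (has a job) below it, after which every branch is pure.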

5.2 Use the squared-error criterion to build a binary regression tree

For this one I used the least-squares method described in the book. Straight to the code (it is a little messy):
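Concretely, the criterion from the book is: for each candidate split point s, divide the (sorted) data into R1 = {x | x ≤ s} and R2 = {x | x > s} and minimize

    min_s [ min_{c1} Σ_{xi∈R1} (yi − c1)² + min_{c2} Σ_{xi∈R2} (yi − c2)² ]

where each inner minimum is attained at the region mean (c1 = mean of R1, c2 = mean of R2). The sse helper in the code below computes exactly these mean-centred sums of squares.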

# -*- coding: utf-8 -*-
class Ctree:
    def __init__(self, spvalue=None):
        self.spvalue = spvalue   # split point s (internal nodes only)
        self.value = None        # data falling into this node
        self.lnode = None
        self.rnode = None

class CART:
    def __init__(self, data):
        self.data = sorted(data)
        self.tree = None

    def SqartLost(self):
        # sum of squared deviations from the region mean;
        # min_c sum (y - c)^2 is attained at c = mean(data)
        def sse(data):
            c = sum(data) / len(data)
            return sum((d - c) ** 2 for d in data)

        def createTree(node, data):
            if len(data) == 1:
                node.value = data
                return
            # try every split position j: left = data[:j], right = data[j:]
            minsp = float("inf")
            indx = 1
            for j in range(1, len(data)):
                sp = sse(data[:j]) + sse(data[j:])
                if sp < minsp:
                    minsp = sp
                    indx = j
            node.value = data
            node.spvalue = data[indx - 1]   # x <= spvalue goes to the left child
            node.lnode = Ctree()
            createTree(node.lnode, data[:indx])
            node.rnode = Ctree()
            createTree(node.rnode, data[indx:])

        self.tree = Ctree()
        createTree(self.tree, self.data)

    def Print(self):
        def Fprint(node):
            if node.lnode is None:
                print("leaf node:", node.value)
                return
            print("node: split at", node.spvalue, node.value)
            Fprint(node.lnode)
            Fprint(node.rnode)
        print("root node")
        Fprint(self.tree)

def main():
    data = [4.50, 4.75, 4.91, 5.34, 5.80, 7.05, 7.90, 8.23, 8.70, 9.00]
    cart = CART(data)
    cart.SqartLost()
    cart.Print()

if __name__ == '__main__':
    main()
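If I did the arithmetic right, the first split falls between 5.80 and 7.05, separating the five smaller values from the five larger ones, with region means c1 = 5.06 and c2 ≈ 8.18; that agrees with the usual solution to exercise 5.2, and the recursion then splits each half further until every leaf holds a single value.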

I'm not going to write out the two proof exercises:

5.3 Prove that in the CART pruning algorithm, for a fixed α there exists a unique minimal subtree T_α that minimizes the loss function C_α(T).

5.4 Prove that the subtree sequence {T0, T1, …, Tn} produced by the CART pruning algorithm consists of the optimal subtrees T_α on the intervals α ∈ [α_i, α_{i+1}), for i = 0, 1, …, n, with 0 = α_0 < α_1 < ⋯ < α_n < +∞.
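For reference, both statements revolve around the penalized loss and the critical value of α at which collapsing an internal node t first pays off:

    C_α(T) = C(T) + α · |T|
    g(t)   = (C(t) − C(T_t)) / (|T_t| − 1)

where |T| is the number of leaves, T_t is the subtree rooted at t, and C(t) is the loss when t is collapsed to a single leaf; the pruning algorithm obtains T_{i+1} from T_i by cutting the node with the smallest g(t), and that smallest value is α_{i+1}.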
 
