当前位置:   article > 正文

python实现线性可分支持向量机模型_随机生成两个样本点,每一类不少于100的数据集

随机生成两个样本点,每一类不少于100的数据集

内容:

        随机生成两类且维数为2的100个样本的数据集(注:每类均为100个样本) ,使用2/3数据训练支持向量机,剩余1/3数据进行测试,计算正确率。

代码:

        实在不想写了就来这看看吧~~~

  1. import math # 数学
  2. import random # 随机
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. def zhichi_w(zhichi, xy, a): # 计算更新 w
  6. w = [0, 0]
  7. if len(zhichi) == 0: # 初始化的0
  8. return w
  9. for i in zhichi:
  10. w[0] += a[i] * xy[0][i] * xy[2][i] # 更新w
  11. w[1] += a[i] * xy[1][i] * xy[2][i]
  12. return w
  13. def zhichi_b(zhichi, xy, a): # 计算更新 b
  14. b = 0
  15. if len(zhichi) == 0: # 初始化的0
  16. return 0
  17. for s in zhichi: # 对任意的支持向量有 ysf(xs)=1 所有支持向量求解平均值
  18. sum = 0
  19. for i in zhichi:
  20. sum += a[i] * xy[2][i] * (xy[0][i] * xy[0][s] + xy[1][i] * xy[1][s])
  21. b += 1 / xy[2][s] - sum
  22. return b / len(zhichi)
  23. def SMO(xy):
  24. # 初始化=========================================
  25. fx = [] # 储存所有的fx
  26. yfx = [] # 储存所有yfx-1的值
  27. Ek = [] # Ek,记录fx-y用于启发式搜索
  28. E_ = -1 # 贮存最大偏差,减少计算
  29. a1 = 0 # SMO a1
  30. a2 = 0 # SMO a2
  31. # 初始化结束======================================
  32. a = [0.0] * len(xy[0]) # 拉格朗日乘子
  33. zhichi = set() # 支持向量下标
  34. loop = 1 # 循环标记(符合KKT)
  35. w = [0, 0] # 初始化 w
  36. b = 0 # 初始化 b
  37. while loop:
  38. loop += 1
  39. if loop == 50:
  40. print("达到早停标准")
  41. print("循环了:", loop, "次")
  42. loop=0
  43. break
  44. # 初始化=========================================
  45. fx = [] # 储存所有的fx
  46. yfx = [] # 储存所有yfx-1的值
  47. Ek = [] # Ek,记录fx-y用于启发式搜索
  48. E_ = -1 # 贮存最大偏差,减少计算
  49. a1 = 0 # SMO a1
  50. a2 = 0 # SMO a2
  51. # 初始化结束======================================
  52. # 寻找a1,a2======================================
  53. for i in range(len(xy[0])): # 计算所有的 fx yfx-1 Ek
  54. fx.append(w[0] * xy[0][i] + w[1] * xy[1][i] + b) # 计算 fx=wx+b
  55. yfx.append(xy[2][i] * fx[i] - 1) # 计算 yfx-1
  56. Ek.append(fx[i] - xy[2][i]) # 计算 fx-y
  57. if i in zhichi: # 之前看过的不看了,防止重复找某个a
  58. continue
  59. if yfx[i] <= yfx[a1]:
  60. a1 = i # 得到偏离最大位置的下标(数值最小的)
  61. if yfx[a1] >= 0: # 最小的也满足KKT
  62. print("循环了:", loop, "次")
  63. loop = 0 # 循环标记(符合KKT)置零(没有用到)
  64. break
  65. for i in range(len(xy[0])): # 遍历找间隔最大的a2
  66. if i == a1: # 如果是a1,跳过
  67. continue
  68. Ei = abs(Ek[i] - Ek[a1]) # |Eki-Eka1|
  69. if Ei < E_: # 找偏差
  70. E_ = Ei # 储存偏差的值
  71. a2 = i # 储存偏差的下标
  72. # 寻找a1,a2结束===================================
  73. zhichi.add(a1) # a1录入支持向量
  74. zhichi.add(a2) # a2录入支持向量
  75. # 分析约束条件=====================================
  76. # c=a1*y1+a2*y2
  77. c = a[a1] * xy[2][a1] + a[a2] * xy[2][a2] # 求出c
  78. # n=K11+k22-2*k12
  79. n = xy[0][a1] ** 2 + xy[1][a1] ** 2 + xy[0][a2] ** 2 + xy[1][a2] ** 2 - 2 * (
  80. xy[0][a1] * xy[0][a2] + xy[1][a1] * xy[1][a2])
  81. # 确定a1的可行域=====================================
  82. if xy[2][a1] == xy[2][a2]:
  83. L = max(0.0, a[a1] + a[a2] - 0.5) # 下界
  84. H = min(0.5, a[a1] + a[a2]) # 上界
  85. else:
  86. L = max(0.0, a[a1] - a[a2]) # 下界
  87. H = min(0.5, 0.5 + a[a1] - a[a2]) # 上界
  88. if n > 0:
  89. a1_New = a[a1] - xy[2][a1] * (Ek[a1] - Ek[a2]) / n # a1_New = a1_old-y1(e1-e2)/n
  90. # print("x1=",xy[0][a1],"y1=",xy[1][a1],"z1=",xy[2][a1],"x2=",xy[0][a2],"y2=",xy[1][a2],"z2=",xy[2][a2],"a1_New=",a1_New)
  91. # 越界裁剪============================================================
  92. if a1_New >= H:
  93. a1_New = H
  94. elif a1_New <= L:
  95. a1_New = L
  96. else:
  97. a1_New = min(H, L)
  98. # 参数更新=======================================
  99. a[a2] = a[a2] + xy[2][a1] * xy[2][a2] * (a[a1] - a1_New) # a2更新
  100. a[a1] = a1_New # a1更新
  101. w = zhichi_w(zhichi, xy, a) # 更新w
  102. b = zhichi_b(zhichi, xy, a) # 更新b
  103. # print("W=", w, "b=", b, "zhichi=", zhichi, "a1=", a[a1], "a2=", a[a2])
  104. # 标记支持向量======================================
  105. for i in zhichi:
  106. if a[i] == 0: # 选了,但值仍为0
  107. loop = loop + 1
  108. e = 'silver'
  109. else:
  110. if xy[2][i] == 1:
  111. e = 'b'
  112. else:
  113. e = 'r'
  114. plt.scatter(x1[0][i], x1[1][i], c='none', s=100, linewidths=1, edgecolor=e)
  115. print("支持向量数为:", len(zhichi), "\na为零支持向量:", loop)
  116. print("有用向量数:", len(zhichi) - loop)
  117. # 返回数据 w b =======================================
  118. return [w, b]
  119. def panduan(xyz, w_b):
  120. c = 0
  121. for i in range(len(xyz[0])):
  122. if (xyz[0][i] * w_b[0][0] + xyz[1][i] * w_b[0][1] + w_b[1]) * xyz[2][i] < 0:
  123. c = c + 1
  124. return (1 - c / len(xyz[0])) * 100
  125. # 生成数据集=============================================
  126. x = [] # 数据集x属性
  127. y = [] # 数据集y属性
  128. x.extend(np.random.normal(loc=40.0, scale=10, size=100)) # 向x中放100个均值为30,方差为10,正态分布的随机数
  129. x.extend(np.random.normal(loc=80.0, scale=10, size=100)) # 向x中放100个均值为80,方差为10,正态分布的随机数
  130. y.extend(np.random.normal(loc=80.0, scale=10, size=100)) # 向y中放100个均值为80,方差为10,正态分布的随机数
  131. y.extend(np.random.normal(loc=40.0, scale=10, size=100)) # 向y中放100个均值为30,方差为10,正态分布的随机数
  132. c = [1] * 100 # c标签第一类为 1 x均值:30 y均值:80
  133. c.extend([-1] * 100) # # c标签第二类为 -1 x均值:80 y均值:30
  134. # 生成训练集与测试集=======================================
  135. lt = list(range(200)) # 得到一个顺序序列
  136. random.shuffle(lt) # 打乱序列
  137. x1 = [[], [], []] # 初始化x1
  138. x2 = [[], [], []] # 初始化x2
  139. for i in lt[0:150]: # 截取部分做训练集
  140. x1[0].append(x[i]) # 加上数据集x属性
  141. x1[1].append(y[i]) # 加上数据集y属性
  142. x1[2].append(c[i]) # 加上数据集c标签
  143. for i in lt[150:200]: # 截取部分做测试集
  144. x2[0].append(x[i]) # 加上数据集x属性
  145. x2[1].append(y[i]) # 加上数据集y属性
  146. x2[2].append(c[i]) # 加上数据集c标签
  147. # 计算 w b============================================
  148. plt.figure(1) # 第一张画布
  149. wb = SMO(x1)
  150. print("w为:", wb[0], " b为:", wb[1])
  151. # 计算正确率===========================================
  152. print("训练集上的正确率为:", panduan(x1, wb), "%")
  153. print("测试集上的正确率为:", panduan(x2, wb), "%")
  154. # 绘图 ===============================================
  155. # 红色和蓝色是训练集,圈着的是曾经选中的值,灰色的是选中但更新为0
  156. # 黄色和蓝色是测试集
  157. for i in range(len(x1[2])): # 对训练集‘上色’
  158. if x1[2][i] == 1:
  159. x1[2][i] = 'r' # 训练集 1 红色
  160. else:
  161. x1[2][i] = 'b' # 训练集 -1 蓝色
  162. for i in range(len(x2[2])): # 对测试集‘上色’
  163. if x2[2][i] == 1:
  164. x2[2][i] = 'y' # 测试集 1 黄色
  165. else:
  166. x2[2][i] = 'g' # 测试集 -1 绿色
  167. plt.scatter(x1[0], x1[1], c=x1[2], alpha=0.8) # 绘点训练集
  168. plt.scatter(x2[0], x2[1], c=x2[2], alpha=0.8) # 绘点测试集
  169. plt.xlabel('x') # x轴标签
  170. plt.ylabel('y') # y轴标签
  171. plt.title('y=wx+b') # 标题
  172. plt.xlim((0, 120))
  173. plt.ylim((0, 120))
  174. xl = np.arange(-10, 120, 0.1) # 绘制分类线
  175. yl = (-wb[0][0] * xl - wb[1]) / wb[0][1]
  176. plt.plot(xl, yl, 'k')
  177. plt.show() # 显示

ps:

第78行:

曾经试过这么写

if Ei > E_:  # 找偏差最大的
    E_ = Ei  # 储存偏差最大的值
    a2 = i  # 储存偏差最大的下标

但是不如找最小的效果好,虽然循环次数减少,但是误差可能会更高。

示例:

 

 

结果实例:

 

 

 

重要的地方:

e=f(x)-y

n=x1^Tx1+x2^Tx2+2(x1^Tx2)

a1_{New} = a1_{old}-y1(e1-e2)/n

a2_{New} = a2_{old} + y_1y_2(a1_{old} - a1_{New})

bug繁多,请多指正。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/53395
推荐阅读
相关标签
  

闽ICP备14008679号