当前位置:   article > 正文

【Python学习系列八】Python实现线性可分SVM(支持向量机)_python支持向量机实现线性分割

python支持向量机实现线性分割

1、运行环境:Eclipse + PyDev + Anaconda2-4.4.0(Python 2.7),依赖 numpy 与 matplotlib(用于绘图)。

2、代码:

  1. # -*- coding: utf-8 -*-
  2. __author__ = 'Jason.F'
  3. from numpy import *
  4. import matplotlib.pyplot as plt
  5. import operator
  6. import time
  7. #导入数据,格式: value1 value2 label
  8. #3.542485 1.977398 -1
  9. #3.018896 2.556416 1
  10. def loadDataSet(fileName):
  11. dataMat = []
  12. labelMat = []
  13. with open(fileName) as fr:
  14. for line in fr.readlines():
  15. lineArr = line.strip().split()
  16. labelMat.append(float(lineArr[2]))
  17. #i=lineArr.__len__()
  18. #for i in range(1,i):
  19. dataMat.append([float(lineArr[0]),float(lineArr[1])])
  20. return dataMat, labelMat
  21. def selectJrand(i, m):
  22. j = i
  23. while (j == i):
  24. j = int(random.uniform(0, m))
  25. return j
  26. def clipAlpha(aj, H, L):
  27. if aj > H:
  28. aj = H
  29. if L > aj:
  30. aj = L
  31. return aj
  32. class optStruct:
  33. def __init__(self, dataMatIn, classLabels, C, toler):
  34. self.X = dataMatIn
  35. self.labelMat = classLabels
  36. self.C = C
  37. self.tol = toler
  38. self.m = shape(dataMatIn)[0]
  39. self.alphas = mat(zeros((self.m, 1)))
  40. self.b = 0
  41. self.eCache = mat(zeros((self.m, 2)))
  42. def calcEk(oS, k):
  43. fXk = float(multiply(oS.alphas, oS.labelMat).T * (oS.X * oS.X[k, :].T)) + oS.b
  44. Ek = fXk - float(oS.labelMat[k])
  45. return Ek
  46. def selectJ(i, oS, Ei):
  47. maxK = -1
  48. maxDeltaE = 0
  49. Ej = 0
  50. oS.eCache[i] = [1, Ei]
  51. validEcacheList = nonzero(oS.eCache[:, 0].A)[0]
  52. if (len(validEcacheList)) > 1:
  53. for k in validEcacheList:
  54. if k == i:
  55. continue
  56. Ek = calcEk(oS, k)
  57. deltaE = abs(Ei - Ek)
  58. if (deltaE > maxDeltaE):
  59. maxK = k
  60. maxDeltaE = deltaE
  61. Ej = Ek
  62. return maxK, Ej
  63. else:
  64. j = selectJrand(i, oS.m)
  65. Ej = calcEk(oS, j)
  66. return j, Ej
  67. def updateEk(oS, k):
  68. Ek = calcEk(oS, k)
  69. oS.eCache[k] = [1, Ek]
  70. def innerL(i, oS):
  71. Ei = calcEk(oS, i)
  72. if ((oS.labelMat[i] * Ei < -oS.tol) and (oS.alphas[i] < oS.C)) or ((oS.labelMat[i] * Ei > oS.tol) and (oS.alphas[i] > 0)):
  73. j, Ej = selectJ(i, oS, Ei)
  74. alphaIold = oS.alphas[i].copy()
  75. alphaJold = oS.alphas[j].copy()
  76. if (oS.labelMat[i] != oS.labelMat[j]):
  77. L = max(0, oS.alphas[j] - oS.alphas[i])
  78. H = min(oS.C, oS.C + oS.alphas[j] - oS.alphas[i])
  79. else:
  80. L = max(0, oS.alphas[j] + oS.alphas[i] - oS.C)
  81. H = min(oS.C, oS.alphas[j] + oS.alphas[i])
  82. if (L == H):
  83. # print("L == H")
  84. return 0
  85. eta = 2.0 * oS.X[i, :] * oS.X[j, :].T - oS.X[i, :] * oS.X[i, :].T - oS.X[j, :] * oS.X[j, :].T
  86. if eta >= 0:
  87. # print("eta >= 0")
  88. return 0
  89. oS.alphas[j] -= oS.labelMat[j] * (Ei - Ej) / eta
  90. oS.alphas[j] = clipAlpha(oS.alphas[j], H, L)
  91. updateEk(oS, j)
  92. if (abs(oS.alphas[j] - alphaJold) < 0.00001):
  93. # print("j not moving enough")
  94. return 0
  95. oS.alphas[i] += oS.labelMat[j] * oS.labelMat[i] * (alphaJold - oS.alphas[j])
  96. updateEk(oS, i)
  97. b1 = oS.b - Ei - oS.labelMat[i] * (oS.alphas[i] - alphaIold) * oS.X[i, :] * oS.X[i, :].T - oS.labelMat[j] * (oS.alphas[j] - alphaJold) * oS.X[i, :] * oS.X[j, :].T
  98. b2 = oS.b - Ei - oS.labelMat[i] * (oS.alphas[i] - alphaIold) * oS.X[i, :] * oS.X[j, :].T - oS.labelMat[j] * (oS.alphas[j] - alphaJold) * oS.X[j, :] * oS.X[j, :].T
  99. if (0 < oS.alphas[i]) and (oS.C > oS.alphas[i]):
  100. oS.b = b1
  101. elif (0 < oS.alphas[j]) and (oS.C > oS.alphas[j]):
  102. oS.b = b2
  103. else:
  104. oS.b = (b1 + b2) / 2.0
  105. return 1
  106. else:
  107. return 0
  108. def smoP(dataMatIn, classLabels, C, toler, maxIter, kTup=('lin', 0)):
  109. """
  110. 输入:数据集, 类别标签, 常数C, 容错率, 最大循环次数
  111. 输出:目标b, 参数alphas
  112. """
  113. oS = optStruct(mat(dataMatIn), mat(classLabels).transpose(), C, toler)
  114. iterr = 0
  115. entireSet = True
  116. alphaPairsChanged = 0
  117. while (iterr < maxIter) and ((alphaPairsChanged > 0) or (entireSet)):
  118. alphaPairsChanged = 0
  119. if entireSet:
  120. for i in range(oS.m):
  121. alphaPairsChanged += innerL(i, oS)
  122. # print("fullSet, iter: %d i:%d, pairs changed %d" % (iterr, i, alphaPairsChanged))
  123. iterr += 1
  124. else:
  125. nonBoundIs = nonzero((oS.alphas.A > 0) * (oS.alphas.A < C))[0]
  126. for i in nonBoundIs:
  127. alphaPairsChanged += innerL(i, oS)#内积
  128. # print("non-bound, iter: %d i:%d, pairs changed %d" % (iterr, i, alphaPairsChanged))
  129. iterr += 1
  130. if entireSet:
  131. entireSet = False
  132. elif (alphaPairsChanged == 0):
  133. entireSet = True
  134. # print("iteration number: %d" % iterr)
  135. return oS.b, oS.alphas
  136. def calcWs(alphas, dataArr, classLabels):
  137. """
  138. 输入:alphas, 数据集, 类别标签
  139. 输出:目标w
  140. """
  141. X = mat(dataArr)
  142. labelMat = mat(classLabels).transpose()
  143. m, n = shape(X)
  144. w = zeros((n, 1))
  145. for i in range(m):
  146. w += multiply(alphas[i] * labelMat[i], X[i, :].T)
  147. return w
  148. def plotFeature(dataMat, labelMat, weights, b):
  149. dataArr = array(dataMat)
  150. n = shape(dataArr)[0]
  151. xcord1 = []; ycord1 = []
  152. xcord2 = []; ycord2 = []
  153. for i in range(n):
  154. if int(labelMat[i]) == 1:
  155. xcord1.append(dataArr[i, 0])
  156. ycord1.append(dataArr[i, 1])
  157. else:
  158. xcord2.append(dataArr[i, 0])
  159. ycord2.append(dataArr[i, 1])
  160. fig = plt.figure()
  161. ax = fig.add_subplot(111)
  162. ax.scatter(xcord1, ycord1, s=30, c='red', marker='s')
  163. ax.scatter(xcord2, ycord2, s=30, c='green')
  164. x = arange(2, 7.0, 0.1)
  165. y = (-b[0, 0] * x) - 10 / linalg.norm(weights)
  166. ax.plot(x, y)
  167. plt.xlabel('X1'); plt.ylabel('X2')
  168. plt.show()
  169. def main():
  170. trainDataSet, trainLabel = loadDataSet('D:\set.txt')
  171. b, alphas = smoP(trainDataSet, trainLabel, 0.6, 0.0001, 40)
  172. ws = calcWs(alphas, trainDataSet, trainLabel)
  173. print("ws = \n", ws)
  174. print("b = \n", b)
  175. plotFeature(trainDataSet, trainLabel, ws, b)
  176. if __name__ == '__main__':
  177. start = time.clock()
  178. main()
  179. end = time.clock()
  180. print('finish all in %s' % str(end - start))

3、set.txt样例数据

3.542485    1.977398    -1
3.018896    2.556416    -1
7.551510    -1.580030   1
2.114999    -0.004466   -1
8.127113    1.274372    1
7.108772    -0.986906   1
8.610639    2.046708    1
2.326297    0.265213    -1
3.634009    1.730537    -1
0.341367    -0.894998   -1
3.125951    0.293251    -1
2.123252    -0.783563   -1
0.887835    -2.797792   -1
7.139979    -2.329896   1
1.696414    -1.212496   -1
8.117032    0.623493    1
8.497162    -0.266649   1
4.658191    3.507396    -1
8.197181    1.545132    1
1.208047    0.213100    -1
1.928486    -0.321870   -1
2.175808    -0.014527   -1
7.886608    0.461755    1
3.223038    -0.552392   -1
3.628502    2.190585    -1
7.407860    -0.121961   1
7.286357    0.251077    1
2.301095    -0.533988   -1
-0.232542   -0.547690   -1
3.457096    -0.082216   -1
3.023938    -0.057392   -1
8.015003    0.885325    1
8.991748    0.923154    1
7.916831    -1.781735   1
7.616862    -0.217958   1
2.450939    0.744967    -1
7.270337    -2.507834   1
1.749721    -0.961902   -1
1.803111    -0.176349   -1
8.804461    3.044301    1
1.231257    -0.568573   -1
2.074915    1.410550    -1
-0.743036   -1.736103   -1
3.536555    3.964960    -1
8.410143    0.025606    1
7.382988    -0.478764   1
6.960661    -0.245353   1
8.234460    0.701868    1
8.168618    -0.903835   1
1.534187    -0.622492   -1
9.229518    2.066088    1
7.886242    0.191813    1
2.893743    -1.643468   -1
1.870457    -1.040420   -1
5.286862    -2.358286   1
6.080573    0.418886    1
2.544314    1.714165    -1
6.016004    -3.753712   1
0.926310    -0.564359   -1
0.870296    -0.109952   -1
2.369345    1.375695    -1
1.363782    -0.254082   -1
7.279460    -0.189572   1
1.896005    0.515080    -1
8.102154    -0.603875   1
2.529893    0.662657    -1
1.963874    -0.365233   -1
8.132048    0.785914    1
8.245938    0.372366    1
6.543888    0.433164    1
-0.236713   -5.766721   -1
8.112593    0.295839    1
9.803425    1.495167    1
1.497407    -0.552916   -1
1.336267    -1.632889   -1
9.205805    -0.586480   1
1.966279    -1.840439   -1
8.398012    1.584918    1
7.239953    -1.764292   1
7.556201    0.241185    1
9.015509    0.345019    1
8.266085    -0.230977   1
8.545620    2.788799    1
9.295969    1.346332    1
2.404234    0.570278    -1
2.037772    0.021919    -1
1.727631    -0.453143   -1
1.979395    -0.050773   -1
8.092288    -1.372433   1
1.667645    0.239204    -1
9.854303    1.365116    1
7.921057    -1.327587   1
8.500757    1.492372    1
1.339746    -0.291183   -1
3.107511    0.758367    -1
2.609525    0.902979    -1
3.263585    1.367898    -1
2.912122    -0.202359   -1
1.731786    0.589096    -1
2.387003    1.573131    -1

4、执行结果(在 Python 2.7 中,print(...) 会把括号内的多个参数当作元组输出,因此结果带有括号和引号):

('ws = \n', array([[ 0.65307162],
       [-0.17196128]]))
('b = \n', matrix([[-2.89901748]]))
finish all in 19.5581056613


声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/53400
推荐阅读
相关标签
  

闽ICP备14008679号