当前位置:   article > 正文

AI-机器学习-自学笔记(七)支持向量机(SVG)算法_svg算法

svg算法

       支持向量机(Support Vector Machine, SVM)是一类按监督学习(supervised learning)方式对数据进行二元分类的广义线性分类器(generalized linear classifier),其决策边界是对学习样本求解的最大边距超平面(maximum-margin hyperplane)

       在二维空间上,两类点被一条直线完全分开叫做线性可分。从二维扩展到多维空间中时,将两类N维空间完全分开的N-1维面就成了一个超平面。

 这些靠近超平面最近的一些点,就称为支持向量

对于非线性问题,运用核函数将数据映射到高维空间后应用线性SVM可获得解决。

SVM在scikit- learn 中的实现类是 SVC 类,我们通过一个简单的例子来演示一下:

  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. from sklearn import svm
  4. def loadDataSet(fileName):
  5. """
  6. Args:
  7. fileName 文件名
  8. Returns:
  9. dataMat 数据矩阵
  10. labelMat 类标签
  11. """
  12. dataMat = []
  13. labelMat = []
  14. fr = open(fileName)
  15. for line in fr.readlines():
  16. lineArr = line.strip().split(',')
  17. dataMat.append([float(lineArr[0]), float(lineArr[1])])
  18. labelMat.append(float(lineArr[2]))
  19. return dataMat, labelMat
  20. X, Y = loadDataSet('./data/datalog2.txt')
  21. X = np.mat(X)
  22. print("X=", X[:5])
  23. print("Y=", Y[:5])
  24. clf = svm.SVC(C=8,kernel='linear',gamma=10,probability=True)
  25. #SVC(C=5, cache_size=200, class_weight=None, coef0=0.0,
  26. #, decision_function_shape='ovr', degree=3, gamma=10, kernel='linear',
  27. #, max_iter=-1, probability=False, random_state=None, shrinking=True,
  28. #, tol=0.001, verbose=False)
  29. clf.fit(X, Y)
  30. # 获取分割超平面
  31. w = clf.coef_[0]
  32. # 斜率
  33. a = -w[0] / w[1]
  34. # 从-2到10,顺序间隔采样50个样本,默认是num=50
  35. xx = np.linspace(-2, 10) # , num=50)
  36. # 二维的直线方程
  37. yy = a * xx - (clf.intercept_[0]) / w[1]
  38. print("yy=", yy)
  39. print("support_vectors_=", clf.support_vectors_)
  40. b = clf.support_vectors_[0]
  41. yy_down = a * xx + (b[1] - a * b[0])
  42. b = clf.support_vectors_[-1]
  43. yy_up = a * xx + (b[1] - a * b[0])
  44. plt.plot(xx, yy, 'k-')
  45. plt.plot(xx, yy_down, 'k--')
  46. plt.plot(xx, yy_up, 'k--')
  47. plt.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], s=80, facecolors='none')
  48. plt.scatter(X[:, 0].flat, X[:, 1].flat, c=Y, cmap=plt.cm.Paired)
  49. plt.axis('tight')
  50. plt.show()

运行后得到下图

 我们再用scikit-learn中自带的手写数字数据集进行实验

  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import scipy,cv2,imageio
  4. from sklearn import svm
  5. from sklearn.datasets import load_digits
  6. from sklearn.model_selection import train_test_split
  7. from fractions import Fraction
  8. from skimage.transform import resize
  9. #读取sklearn.datasets自带的手写数字数据集
  10. datas = load_digits()
  11. #print(datas.data[1])
  12. #前63个值为特征,赋值给x,最后一个值是分类,赋值给y
  13. x = datas.data[:, :-1]
  14. y = datas.target
  15. x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=666)
  16. #调用svm.SVC方法进行训练
  17. clf = svm.SVC(C=8,kernel='linear',gamma=10,probability=True)
  18. #SVC(C=5, cache_size=200, class_weight=None, coef0=0.0,
  19. #, decision_function_shape='ovr', degree=3, gamma=10, kernel='linear',
  20. #, max_iter=-1, probability=False, random_state=None, shrinking=True,
  21. #, tol=0.001, verbose=False)
  22. clf.fit(x, y)
  23. #print(clf.predict(x[0:15]))#必须以区间取值的方式,4:5 其实就是取 4 这个值
  24. #训练集准确率
  25. print("Train :", clf.score(x_train, y_train))
  26. #测试集准确率
  27. print("Test :", clf.score(x_test, y_test))
  28. #以下为实现用训练好的模型识别自己手写的图片
  29. #图片处理函数,主要是把图片压缩为8*8的格式(和数据集一致),包括变灰度、黑白反转
  30. def image2Digit(image):
  31. # 调整为8*8大小
  32. #im_resized = scipy.misc.imresize(image, (8,8))#scipy.misc.imresize这个函数现在不能用了
  33. #print(image.shape)
  34. im_resized=cv2.resize(image,(8, 8))
  35. #print('im_resized:')
  36. #print(im_resized.shape)
  37. im_resized2=im_resized.astype(np.float32) #这里是个坑,CV2默认数据格式是float64的,np默认格式是float32的,这里要把数据格式转一下,否则后面会报错
  38. #print('im_resized2:')
  39. #print(im_resized2)
  40. # RGB(三维)转为灰度图(一维)
  41. im_gray = cv2.cvtColor(im_resized2, cv2.COLOR_BGR2GRAY)
  42. #print('im_gray')
  43. #print(im_gray.shape)
  44. # 调整为0-16之间(digits训练数据的特征规格)像素值——16/255
  45. im_hex = Fraction(16,255) * im_gray
  46. #print('im_hex')
  47. #print(im_hex)
  48. # 将图片数据反相(digits训练数据的特征规格——黑底白字)
  49. im_reverse = 16 - im_hex
  50. return im_reverse.astype(np.int)
  51. #图片文件路径
  52. fp='data/numbers/test1.png'
  53. # 读取单张自定义手写数字的图片
  54. #image = scipy.misc.imread(fp) #新版本scipy不支持imread,可以用imageio.imread代替
  55. image = imageio.imread(fp)
  56. # 调用上面的函数,将图片转为digits训练数据的规格——即数据的表征方式要统一
  57. im_reverse = image2Digit(image)
  58. # 显示图片转换后的像素值
  59. print(im_reverse)
  60. # 8*8转为1*64(预测方法的参数要求)
  61. reshaped = im_reverse.reshape(1,64)
  62. # 预测
  63. result = clf.predict(reshaped[:, :-1])
  64. print('识别到的数字为:{}'.format(result[0]))

打印结果如下:

  1. PS C:\coding\machinelearning>SVM-手写数字数据集实验.py
  2. Train : 1.0
  3. Test : 1.0
  4. [[ 0 0 0 0 0 0 0 0]
  5. [ 0 0 16 16 16 16 15 0]
  6. [ 0 0 16 16 9 9 16 0]
  7. [ 0 0 0 0 0 16 16 0]
  8. [ 0 0 0 1 14 16 0 0]
  9. [ 0 0 16 16 16 8 0 0]
  10. [ 0 0 1 15 16 16 16 16]
  11. [ 0 0 0 0 0 0 0 2]]
  12. 识别到的数字为:2
  13. PS C:\coding\machinelearning>

从图形也能看出来,这是个数字2

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/377442
推荐阅读
相关标签
  

闽ICP备14008679号