当前位置:   article > 正文

机器学习(二):聚类算法1——K-means算法_kmeans如何更新中心点

kmeans如何更新中心点

Kmeans是一种经典的聚类算法,所谓聚类,是指在没有给出目标的情况下,将样本根据某种关系分为某几类。那在kmeans中,是根据样本点间的距离,将样本n分为k个类。

K-means实现步骤:

1.首先,输入数据N并确定聚类个数K

2.初始化聚类中心 :随机选K个初始中心点。

3.计算所有样本NK个中心点的距离,将其归到距离最近的一簇。

4.针对每一簇,计算该簇内所有样本到中心点距离的均值,最为新的中心点。

5.不断迭代,直到中心点不再改变或误差达到阈值。

还有一个与K-means算法非常类似的算法是K-medoids,步骤也与K-means一致,唯一的区别是k-means的中心是各个样本点的平均,可能是样本点中不存在的点。K-medoids的质心一定是某个样本点的值。

K-meansMATLAB实现:

 1.使用MATLAB自带的函数实现

  1. idx = kmeans(X,k) %将数据x分为k类,返回类标签
  2. idx = kmeans(X,k,Name,Value) %可以指定距离、使用新的初始值重复聚类的次数或使用并行计算。
  3. [idx,C] = kmeans(___) %返回值可以返回中心点的坐标
  4. [idx,C,sumd] = kmeans(___) %返回向量中点到质心距离的簇内总和sumd
  5. [idx,C,sumd,D] = kmeans(___) %返回输入矩阵中每个点到每个质心的距离D

K-medoids自带函数实现

  1. idx = kmedoids(X,k)
  2. idx = kmedoids(X,k,Name,Value)
  3. [idx,C] = kmedoids(___)
  4. [idx,C,sumd] = kmedoids(___)
  5. [idx,C,sumd,D] = kmedoids(___)
  6. [idx,C,sumd,D,midx] = kmedoids(___)
  7. [idx,C,sumd,D,midx,info] = kmedoids(___)

示例

  1. rng('default') % For reproducibility
  2. X = [randn(100,2)*0.75+ones(100,2);
  3. randn(100,2)*0.5-ones(100,2);
  4. randn(100,2)*0.75];
  5. [idx,C] = kmeans(X,3);
  6. figure
  7. gscatter(X(:,1),X(:,2),idx,'bgm')
  8. hold on
  9. plot(C(:,1),C(:,2),'kx')
  10. legend('Cluster 1','Cluster 2','Cluster 3','Cluster Centroid')

2.K-means代码实现

  1. clear all;
  2. clc;
  3. % 第一组数据
  4. mu1=[0 0 ]; %均值(是需要生成的数据的均值)
  5. S1=[.08 0 ;0 .08]; %协方差(需要生成的数据的自相关矩阵(相关系数矩阵))
  6. data1=mvnrnd(mu1,S1,3200); %产生高斯分布数据
  7. %第二组数据
  8. mu2=[1.5 1.5 ];
  9. S2=[.08 0 ;0 .08];
  10. data2=mvnrnd(mu2,S2,3200);
  11. % 第三组数据
  12. mu3=[-1.5 1.5 ];
  13. S3=[.08 0 ;0 .08];
  14. data3=mvnrnd(mu3,S3,3200);
  15. % 显示数据
  16. plot(data1(:,1),data1(:,2),'b.');
  17. hold on;%不覆盖原图,要关闭则使用hold off
  18. plot(data2(:,1),data2(:,2),'r.');
  19. plot(data3(:,1),data3(:,2),'g.');
  20. grid on;%显示表格
  21. % 三类数据合成一个不带标号的数据类
  22. data=[data1;data2;data3];
  23. N=3;%设置聚类数目
  24. [m,n]=size(data);%表示矩阵data大小,m行n列
  25. pattern=zeros(m,n+1);%生成0矩阵
  26. center=zeros(N,n);%初始化聚类中心
  27. pattern(:,1:n)=data(:,:);
  28. for x=1:N
  29. center(x,:)=data( randi(300,1),:);%第一次随机产生聚类中心
  30. end
  31. while 1 %循环迭代每次的聚类簇;
  32. distence=zeros(1,N);%最小距离矩阵
  33. num=zeros(1,N);%聚类簇数矩阵
  34. new_center=zeros(N,n);%聚类中心矩阵
  35. for x=1:m
  36. for y=1:N
  37. distence(y)=norm(data(x,:)-center(y,:));%计算到每个类的距离
  38. end
  39. [~, temp]=min(distence);%求最小的距离
  40. pattern(x,n+1)=temp;%划分所有对象点到最近的聚类中心;标记为1,2,3
  41. end
  42. k=0;
  43. for y=1:N
  44. for x=1:m
  45. if pattern(x,n+1)==y
  46. new_center(y,:)=new_center(y,:)+pattern(x,1:n);
  47. num(y)=num(y)+1;
  48. end
  49. end
  50. new_center(y,:)=new_center(y,:)/num(y);%求均值,即新的聚类中心;
  51. if norm(new_center(y,:)-center(y,:))<0.1%检查集群中心是否已收敛。如果是则终止。
  52. k=k+1;
  53. end
  54. end
  55. if k==N
  56. break;
  57. else
  58. center=new_center;
  59. end
  60. end
  61. [m, n]=size(pattern);
  62. %最后显示聚类后的数据
  63. figure;
  64. hold on;
  65. for i=1:m
  66. if pattern(i,n)==1
  67. plot(pattern(i,1),pattern(i,2),'r.');
  68. plot(center(1,1),center(1,2),'kp');%用小圆圈标记中心点;
  69. elseif pattern(i,n)==2
  70. plot(pattern(i,1),pattern(i,2),'g.');
  71. plot(center(2,1),center(2,2),'kp');
  72. elseif pattern(i,n)==3
  73. plot(pattern(i,1),pattern(i,2),'c.');
  74. plot(center(3,1),center(3,2),'kp');
  75. elseif pattern(i,n)==4
  76. plot(pattern(i,1),pattern(i,2),'y.');
  77. plot(center(4,1),center(4,2),'kp');
  78. else
  79. plot(pattern(i,1),pattern(i,2),'m.');
  80. plot(center(4,1),center(4,2),'kp');
  81. end
  82. end

3.K-means算法Python实现

Python代码来自机器学习(二)——K-均值聚类(K-means)算法 - 1ang - 博客园

  1. #k-means算法的实现
  2. #-*-coding:utf-8 -*-
  3. from numpy import *
  4. from math import sqrt
  5. import sys
  6. sys.path.append("C:/Users/Administrator/Desktop/k-means的python实现")
  7. def loadData(fileName):
  8. data = []
  9. fr = open(fileName)
  10. for line in fr.readlines():
  11. curline = line.strip().split('\t')
  12. frline = map(float,curline)
  13. data.append(frline)
  14. return data
  15. '''
  16. #test
  17. a = mat(loadData("C:/Users/Administrator/Desktop/k-means/testSet.txt"))
  18. print a
  19. '''
  20. #计算欧氏距离
  21. def distElud(vecA,vecB):
  22. return sqrt(sum(power((vecA - vecB),2)))
  23. #初始化聚类中心
  24. def randCent(dataSet,k):
  25. n = shape(dataSet)[1]
  26. center = mat(zeros((k,n)))
  27. for j in range(n):
  28. rangeJ = float(max(dataSet[:,j]) - min(dataSet[:,j]))
  29. center[:,j] = min(dataSet[:,j]) + rangeJ * random.rand(k,1)
  30. return center
  31. '''
  32. #test
  33. a = mat(loadData("C:/Users/Administrator/Desktop/k-means/testSet.txt"))
  34. n = 3
  35. b = randCent(a,3)
  36. print b
  37. '''
  38. def kMeans(dataSet,k,dist = distElud,createCent = randCent):
  39. m = shape(dataSet)[0]
  40. clusterAssment = mat(zeros((m,2)))
  41. center = createCent(dataSet,k)
  42. clusterChanged = True
  43. while clusterChanged:
  44. clusterChanged = False
  45. for i in range(m):
  46. minDist = inf
  47. minIndex = -1
  48. for j in range(k):
  49. distJI = dist(dataSet[i,:],center[j,:])
  50. if distJI < minDist:
  51. minDist = distJI
  52. minIndex = j
  53. if clusterAssment[i,0] != minIndex:#判断是否收敛
  54. clusterChanged = True
  55. clusterAssment[i,:] = minIndex,minDist ** 2
  56. print center
  57. for cent in range(k):#更新聚类中心
  58. dataCent = dataSet[nonzero(clusterAssment[:,0].A == cent)[0]]
  59. center[cent,:] = mean(dataCent,axis = 0)#axis是普通的将每一列相加,而axis=1表示的是将向量的每一行进行相加
  60. return center,clusterAssment
  61. '''
  62. #test
  63. dataSet = mat(loadData("C:/Users/Administrator/Desktop/k-means/testSet.txt"))
  64. k = 4
  65. a = kMeans(dataSet,k)
  66. print a
  67. '''

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/812935
推荐阅读
相关标签
  

闽ICP备14008679号