赞
踩
标签传播算法(LPA)的做法比较简单:
第一步:为所有节点指定一个唯一的标签;
第二步:逐轮刷新所有节点的标签,直到达到收敛要求为止。对于每一轮刷新,节点标签刷新的规则如下:
对于某一个节点,考察其所有邻居节点的标签,并进行统计,将出现个数最多的那个标签赋给当前节点。当个数最多的标签不唯一时,随机选一个。
注:算法中的记号 N_n^k 表示节点 n 的邻居中标签为 k 的所有节点构成的集合。
以上资料来源于:
http://blog.csdn.net/cleverlzc/article/details/39494957
下面我们来简单实现以下这个算法:
数据,自己编的:
1 2,3,4
2 1,3,4,7
3 1,2,4
4 1,2,3
5 6,7,8
6 5,7,8
7 2,5,6,8
8 5,6,7
加载数据用的函数:
- def loadLpaData(filename):
- f = open(filename,'r')
- data = {}
- for i in f.readlines():
- order,ship = i.split()[0],i.split()[1]
- ships = ship.split(',')
- data.setdefault(order,ships)
- f.close()
- return data
获取数目最多的相邻接点,有多个的话随机选一个:
- def getMost(ships):
- import collections
- counter = collections.Counter(ships)
- tmp = sorted(counter.items(),key = lambda x:x[1])
-
- maxc = tmp[-1][1]
- maxset = []
- for i in tmp:
- if i[1] == maxc:maxset.append(i[0])
-
- import random
- random.shuffle(maxset)
- return maxset[0]
更新标签:
- def updateShips(cluster,data):
- for _ in data.keys():
- data[_] = [cluster[i] for i in data[_]]
- def checkStatus(cluster,data):
- flag = 0
- for d in data.keys():
- if cluster[d] != getMost(data[d]):return 0
- return 1
主函数:
- def main(mydata):
- data = mydata.copy()
- cluster = dict([(_,_) for _ in data.keys()])
- while 1:
- if checkStatus(cluster,data):break
- for i in cluster.keys():
- cluster[i] = getMost(data[i])
- updateShips(cluster,data)
- return cluster
- data = loadLpaData('LPAdataset')
- main(data)
别的训练集没有测试。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。