
Restricted Boltzmann Machine (MATLAB to Python)

The MATLAB code rbmBB from http://code.google.com/p/matrbm/ has been rewritten in Python, as shown below. References (a brief summary of the CD-1 update they describe follows the list):

1. A Tutorial on Stochastic Approximation Algorithms for Training Restricted Boltzmann Machines and Deep Belief Nets
2. Inductive Principles for Learning Restricted Boltzmann Machines
3. Training Products of Experts by Minimizing Contrastive Divergence
4. 受限波尔兹曼机简介 (An Introduction to Restricted Boltzmann Machines)
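
As a quick orientation before the code (the references above give the full derivations), rbmBB implements one-step contrastive divergence (CD-1) for a Bernoulli-Bernoulli RBM. Up to the learning rate, momentum, and weight decay, the parameter updates are roughly

    \Delta W \propto \langle v h^{\top}\rangle_{\text{data}} - \langle v h^{\top}\rangle_{\text{recon}}, \qquad
    \Delta b \propto \langle h\rangle_{\text{data}} - \langle h\rangle_{\text{recon}}, \qquad
    \Delta c \propto \langle v\rangle_{\text{data}} - \langle v\rangle_{\text{recon}}

where v is a visible vector, h a hidden vector, and the "recon" averages come from one Gibbs sampling step. In the code below these correspond to dW, db, and dc, with b the hidden bias (h_bias) and c the visible bias (v_bias).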


import matplotlib.pylab as plt
import numpy as np
import random
from scipy.linalg import norm


class Rbm:
    def __init__(self, n_visible, n_hidden, max_epoch=50, batch_size=110,
                 penalty=2e-4, anneal=False, w=None, v_bias=None, h_bias=None):
        self.n_visible = n_visible
        self.n_hidden = n_hidden
        self.max_epoch = max_epoch
        self.batch_size = batch_size
        self.penalty = penalty
        self.anneal = anneal
        # small random weights and zero biases unless they are supplied
        self.w = np.random.random((self.n_visible, self.n_hidden)) * 0.1 if w is None else w
        self.v_bias = np.zeros((1, self.n_visible)) if v_bias is None else v_bias
        self.h_bias = np.zeros((1, self.n_hidden)) if h_bias is None else h_bias

    def sigmoid(self, z):
        return 1.0 / (1.0 + np.exp(-z))

    def forward(self, vis):
        # vis: (n_visible, num_cases) -> hidden probabilities (num_cases, n_hidden)
        vis = vis.transpose()
        pre_sigmoid_input = np.dot(vis, self.w) + self.h_bias
        return self.sigmoid(pre_sigmoid_input)

    def backward(self, hid):
        # hid: (num_cases, n_hidden) -> visible probabilities (num_cases, n_visible)
        back_sigmoid_input = np.dot(hid, self.w.transpose()) + self.v_bias
        return self.sigmoid(back_sigmoid_input)

    def batch(self):
        # randomly split the columns of self.x into mini-batches of roughly batch_size
        d, N = self.x.shape
        num_batchs = int(round(N / self.batch_size)) + 1
        groups = np.ravel(np.repeat([np.arange(num_batchs)], self.batch_size, axis=0))
        groups = groups[0:N]
        perm = list(range(N))
        random.shuffle(perm)
        groups = groups[perm]
        batch_data = []
        for i in range(0, num_batchs):
            index = groups == i
            batch_data.append(self.x[:, index])
        return batch_data

    def rbmBB(self, x):
        # train a Bernoulli-Bernoulli RBM on x (n_visible, num_samples) with CD-1
        self.x = x
        eta = 0.1          # learning rate
        momentum = 0.5
        W = self.w
        b = self.h_bias
        c = self.v_bias
        Wavg = W
        bavg = b
        cavg = c
        Winc = np.zeros((self.n_visible, self.n_hidden))
        binc = np.zeros(self.n_hidden)
        cinc = np.zeros(self.n_visible)
        avgstart = self.max_epoch - 5
        batch_data = self.batch()
        num_batch = len(batch_data)
        oldpenalty = self.penalty
        t = 1
        errors = []
        for epoch in range(0, self.max_epoch):
            err_sum = 0.0
            if self.anneal:
                # gradually anneal the weight-decay penalty towards zero
                self.penalty = oldpenalty - 0.9 * epoch / self.max_epoch * oldpenalty
            for batch in range(0, num_batch):
                num_dims, num_cases = batch_data[batch].shape
                data = batch_data[batch]
                # positive phase
                ph = self.forward(data)
                ph_states = np.zeros((num_cases, self.n_hidden))
                ph_states[ph > np.random.random((num_cases, self.n_hidden))] = 1
                # negative phase: reconstruct the visible units
                nh_states = ph_states
                neg_data = self.backward(nh_states)
                neg_data_states = np.zeros((num_cases, num_dims))
                neg_data_states[neg_data > np.random.random((num_cases, num_dims))] = 1
                # forward one more time on the reconstruction
                neg_data_states = neg_data_states.transpose()
                nh = self.forward(neg_data_states)
                nh_states = np.zeros((num_cases, self.n_hidden))
                nh_states[nh > np.random.random((num_cases, self.n_hidden))] = 1
                # CD-1 gradients, then update with momentum and weight decay
                dW = np.dot(data, ph) - np.dot(neg_data_states, nh)
                dc = np.sum(data, axis=1) - np.sum(neg_data_states, axis=1)
                db = np.sum(ph, axis=0) - np.sum(nh, axis=0)
                Winc = momentum * Winc + eta * (dW / num_cases - self.penalty * W)
                binc = momentum * binc + eta * (db / num_cases)
                cinc = momentum * cinc + eta * (dc / num_cases)
                W = W + Winc
                b = b + binc
                c = c + cinc
                self.w = W
                self.h_bias = b
                self.v_bias = c
                if epoch > avgstart:
                    # running average of the parameters over the final epochs
                    Wavg -= (1.0 / t) * (Wavg - W)
                    cavg -= (1.0 / t) * (cavg - c)
                    bavg -= (1.0 / t) * (bavg - b)
                    t += 1
                else:
                    Wavg = W
                    bavg = b
                    cavg = c
                # accumulate reconstruction error
                err = norm(data - neg_data.transpose())
                err_sum += err
            print(epoch, err_sum)
            errors.append(err_sum)
        self.errors = errors
        # hidden activations and reconstructions for the whole data set
        self.hidden_value = self.forward(self.x)
        h_row, h_col = self.hidden_value.shape
        hidden_states = np.zeros((h_row, h_col))
        hidden_states[self.hidden_value > np.random.random((h_row, h_col))] = 1
        self.rebuild_value = self.backward(hidden_states)
        self.w = Wavg
        self.h_bias = b
        self.v_bias = c

    def visualize(self, X):
        # tile the columns of X (each a square image) into one large image
        D, N = X.shape
        s = int(np.sqrt(D))
        if s * s == D:
            num = int(np.ceil(np.sqrt(N)))
            a = np.zeros((num * s + num + 1, num * s + num + 1)) - 1.0
            x = 0
            y = 0
            for i in range(0, N):
                z = X[:, i]
                z = z.reshape(s, s, order='F')
                z = z.transpose()
                a[x * s + x:x * s + s + x, y * s + y:y * s + s + y] = z
                x = x + 1
                if x >= num:
                    x = 0
                    y = y + 1
        else:
            a = X
        return a


def readData(path):
    # each line of the file holds one sample as whitespace-separated floats
    data = []
    for line in open(path, 'r'):
        ele = line.split(' ')
        tmp = []
        for e in ele:
            if e != '':
                tmp.append(float(e.strip()))
        data.append(tmp)
    return data


if __name__ == '__main__':
    data = readData('data.txt')
    data = np.array(data)
    data = data.transpose()          # shape (784, 5000): one digit image per column
    rbm = Rbm(784, 100, max_epoch=50)
    rbm.rbmBB(data)
    a = rbm.visualize(data)
    fig = plt.figure(1)
    ax = fig.add_subplot(111)
    ax.imshow(a)
    plt.title('original data')
    rebuild_value = rbm.rebuild_value.transpose()
    b = rbm.visualize(rebuild_value)
    fig = plt.figure(2)
    ax = fig.add_subplot(111)
    ax.imshow(b)
    plt.title('rebuild data')
    hidden_value = rbm.hidden_value.transpose()
    c = rbm.visualize(hidden_value)
    fig = plt.figure(3)
    ax = fig.add_subplot(111)
    ax.imshow(c)
    plt.title('hidden data')
    w_value = rbm.w
    d = rbm.visualize(w_value)
    fig = plt.figure(4)
    ax = fig.add_subplot(111)
    ax.imshow(d)
    plt.title('weight value(w)')
    plt.show()
The data used by the program can be downloaded from: http://download.csdn.net/detail/zc02051126/5845977
The program produces the following figures: the original data, the rebuilt data, the hidden activations, and the learned weights w.



Data description:

The variable data in the program stores the dataset; its dimensions are 784x5000. Each column holds one handwritten-digit image of 784 pixels, and reshaping those 784 pixels into a 28x28 matrix and displaying it shows the corresponding digit. Taking column 0 as an example, its digit can be displayed with the following code:

c = data[:, 0]
d = c.reshape(28, 28, order='F')
d = d.transpose()
plt.imshow(d)
plt.show()
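
Incidentally, readData assumes data.txt contains whitespace-separated floats, one sample per line. Under that same assumption, the file could also be loaded directly with numpy; this is only a convenience sketch, not part of the original program:

import numpy as np

# Assumes data.txt holds whitespace-separated floats, one 784-pixel sample per line.
# .T gives the same (784, 5000) layout as readData followed by transpose.
data = np.loadtxt('data.txt').T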

The complete resources are available here: http://download.csdn.net/detail/zc02051126/8286677
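
If neither download is still available, the Rbm class can at least be smoke-tested on random binary data. This is a hypothetical stand-in for the handwritten-digit data, with shapes chosen to match the usage above (it assumes the imports and the Rbm class from the listing are in scope); the reconstructions will of course be meaningless noise:

# Hypothetical smoke test with random binary "images" instead of data.txt.
# 784 visible units per sample, 500 samples, a few epochs only.
fake_data = (np.random.random((784, 500)) > 0.5).astype(float)
rbm = Rbm(784, 100, max_epoch=5)
rbm.rbmBB(fake_data)
plt.imshow(rbm.visualize(rbm.rebuild_value.transpose()))
plt.title('rebuild data (random input)')
plt.show()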



