当前位置:   article > 正文

自定义层 余弦相似度_自定义余弦相似度

自定义余弦相似度
  1. import tensorflow as tf
  2. from keras.layers import *
  3. import keras.backend as K
  4. from tensorflow.python.ops import control_flow_ops
  5. from tensorflow.python.ops import tensor_array_ops
  def cosine2(input_t, states):
      """Cosine similarity between one query vector and each candidate state.

      Args:
          input_t: Tensor of query vectors, expected as ``(batch, dim)``; it
              is reduced over axis 1 and reshaped to ``(-1, dim, 1)`` below.
          states: Rank-3 tensor ``(batch, n_states, dim)``. The last dimension
              must be statically known because it is read via ``K.int_shape``.

      Returns:
          Tensor of shape ``(batch, 1, n_states)`` where entry ``[b, 0, j]`` is
          ``dot(input_t[b], states[b, j]) / (|input_t[b]| * |states[b, j]|)``.
          No epsilon is added, so an all-zero vector yields a non-finite score.
      """
      feature_dim = K.int_shape(states)[-1]

      # L2 norm of every query vector, kept as a column so that it broadcasts
      # against any number of states (no hard-coded tile count) and does not
      # require a statically known batch size.
      input_norm = tf.sqrt(tf.reduce_sum(tf.multiply(input_t, input_t), 1))
      input_norm = tf.reshape(input_norm, (-1, 1))

      # L2 norm of every state vector: (batch, n_states).
      states_norm = tf.sqrt(tf.reduce_sum(tf.multiply(states, states), 2))

      # Batched dot products: (batch, n_states, dim) @ (batch, dim, 1).
      dots = tf.matmul(states, tf.reshape(input_t, (-1, feature_dim, 1)))
      # Move the trailing singleton axis to the front so the remaining
      # (batch, n_states) axes line up with the norm product.
      dots = tf.transpose(dots, [2, 0, 1])

      score = tf.divide(dots, input_norm * states_norm)
      return tf.transpose(score, [1, 0, 2])
  def cosine22(input_t, states):
      """NumPy counterpart of ``cosine2``: query-vs-states cosine similarity.

      Args:
          input_t: Array-like of query vectors, expected as ``(batch, dim)``.
          states: Array-like of shape ``(batch, n_states, dim)``.

      Returns:
          ``numpy.ndarray`` of shape ``(batch, 1, n_states)`` where entry
          ``[b, 0, j]`` is the cosine similarity between ``input_t[b]`` and
          ``states[b, j]``. No epsilon is added, so an all-zero vector yields
          a non-finite score.
      """
      # Imported here because the module never imports numpy explicitly.
      import numpy as np

      input_t = np.asarray(input_t)
      states = np.asarray(states)
      feature_dim = states.shape[-1]

      # Per-query L2 norm as a column; broadcasting against the
      # (batch, n_states) state norms replaces the former fixed tile of 3.
      input_norm = np.sqrt(np.sum(np.multiply(input_t, input_t), 1))
      input_norm = np.reshape(input_norm, (-1, 1))

      states_norm = np.sqrt(np.sum(np.multiply(states, states), 2))

      # Batched dot products: (batch, n_states, dim) @ (batch, dim, 1).
      dots = np.matmul(states, np.reshape(input_t, (-1, feature_dim, 1)))
      # Singleton axis first so (batch, n_states) aligns with the norms.
      dots = np.transpose(dots, [2, 0, 1])

      score = np.divide(dots, input_norm * states_norm)
      return np.transpose(score, [1, 0, 2])
  # Demo sequence batch: a 2 x 3 x 4 float32 literal, passed to ``lala`` as
  # its ``inputs`` argument.
  inputs = tf.constant(
      [
          [[1, 2, 3, 4], [1, 0, 1, 5], [0, 1, 7, 1]],
          [[1, 8, 1, 0], [4, 1, 1, 2], [0, 5, 1, 2]],
      ],
      dtype="float32",
  )

  # Demo candidate states: a 2 x 3 x 4 float32 literal, passed to ``lala`` as
  # its ``state`` argument.
  state = tf.constant(
      [
          [[1, 0, 1, 1], [0, 1, 1, 1], [0, 2, 1, 1]],
          [[1, 0, 1, 1], [2, 1, 1, 1], [2, 0, 1, 1]],
      ],
      dtype="float32",
  )
  def lala(inputs, state):
      """Score every time step of ``inputs`` against ``state`` via ``cosine2``.

      The sequence is walked with a symbolic while-loop (the same pattern a
      backend RNN loop uses): each step reads one time slice, scores it with
      ``cosine2`` and stores the result in a TensorArray.

      Args:
          inputs: Tensor with batch on axis 0 and time on axis 1, expected as
              ``(batch, time, dim)``.
          state: Tensor handed unchanged to ``cosine2`` at every step,
              expected as ``(batch, n_states, dim)``.

      Returns:
          Tensor of per-step scores. With ``cosine2`` producing
          ``(batch, 1, n_states)`` per step this is
          ``(batch, time, n_states)``.
      """
      # Static sequence length (may be None); used only as the loop's cap.
      input_length = K.int_shape(inputs)[1]

      # Put time on axis 0 so the sequence can be unstacked one step at a time.
      ndim = len(inputs.get_shape())
      inputs = tf.transpose(inputs, [1, 0] + list(range(2, ndim)))
      time_steps = tf.shape(inputs)[0]

      input_ta = tensor_array_ops.TensorArray(
          dtype=inputs.dtype,
          size=time_steps,
          tensor_array_name='input_ta')
      input_ta = input_ta.unstack(inputs)

      cosin_score = tensor_array_ops.TensorArray(
          dtype=inputs.dtype,
          size=time_steps,
          tensor_array_name='cosin_score')

      def _step(time, cosin_score_t):
          """Score the slice at ``time`` and advance the loop counter."""
          current_input = input_ta.read(time)
          score = cosine2(current_input, state)
          return time + 1, cosin_score_t.write(time, score)

      # Only the score array is carried through the loop. The former pass-through
      # output array was never used in the result, and its "last output" read
      # targeted the empty pre-loop array instead of the loop's return value.
      _, cos_score = control_flow_ops.while_loop(
          cond=lambda time, *_: time < time_steps,
          body=_step,
          loop_vars=(tf.constant(0, dtype='int32', name='time'), cosin_score),
          parallel_iterations=32,
          swap_memory=True,
          maximum_iterations=input_length)

      # Stacked scores have time on axis 0 followed by the per-step score axes.
      scores = cos_score.stack()
      # Swap axes 0 and 2, then take index 0 of the new leading axis.
      scores = tf.transpose(scores, [2, 1, 0, 3])
      return scores[0]
  # Build the score graph for the demo tensors, evaluate it through the Keras
  # backend, then show the resulting array and its shape.
  aaa = K.eval(lala(inputs, state))
  print(aaa)
  print(aaa.shape)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/42293
推荐阅读
相关标签
  

闽ICP备14008679号