当前位置:   article > 正文

Keras-DSSM之in-batch余弦相似度负采样层_dssm keras

dssm keras

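定义余弦相似度层,并在batch内进行负采样:对每个 query,把同一 batch 内的 doc 随机打乱 NEG 次作为负样本,与正样本 doc 一起计算余弦相似度,乘以平滑系数后做 softmax,最后返回正样本概率的负对数均值作为损失。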

  1. NEG, batch_size = 20, 128
  2. class NegativeCosineLayer():
  3. """ 自定义batch内负采样并做cosine相似度的层 """
  4. def __call__(self, inputs):
  5. def _cosine(x):
  6. query_encoder, doc_encoder = x
  7. doc_encoder_fd = doc_encoder
  8. for i in range(NEG):
  9. ss = tf.gather(doc_encoder, tf.random.shuffle(tf.range(tf.shape(doc_encoder)[0])))
  10. doc_encoder_fd = tf.concat([doc_encoder_fd, ss], axis=0)
  11. query_norm = tf.tile(tf.sqrt(tf.reduce_sum(tf.square(query_encoder), axis=1, keepdims=True)),[NEG + 1, 1])
  12. doc_norm = tf.sqrt(tf.reduce_sum(tf.square(doc_encoder_fd), axis=1, keepdims=True))
  13. query_encoder_fd = tf.tile(query_encoder, [NEG + 1, 1])
  14. prod = tf.reduce_sum(tf.multiply(query_encoder_fd, doc_encoder_fd, name="sim-multiply"), axis=1, keepdims=True)
  15. norm_prod = tf.multiply(query_norm, doc_norm)
  16. cos_sim_raw = tf.truediv(prod, norm_prod)
  17. cos_sim = tf.transpose(tf.reshape(tf.transpose(cos_sim_raw), [NEG + 1, -1])) * 20
  18. prob = tf.nn.softmax(cos_sim, name="sim-softmax")
  19. hit_prob = tf.slice(prob, [0, 0], [-1, 1], name="sim-slice")
  20. loss = -tf.reduce_mean(tf.log(hit_prob), name="sim-mean")
  21. return loss
  22. output_shape = (1,)
  23. value = Lambda(_cosine, output_shape=output_shape)([inputs[0], inputs[1]])
  24. return value

使用方法:

  1. import tensorflow as tf
  2. from keras.models import Model
  3. from keras.layers import Input, Embedding, Dense, Model, LSTM
  4. query_max_len = 16
  5. doc_max_len = 128
  6. vocab_size = 10000
  7. embed_dim = 64
  8. query_input = Input(shape=(query_max_len, ), name="query_input")
  9. doc_input = Input(shape=(doc_max_len, ), name="doc_input")
  10. embedding = Embedding(vocab_size+1, embed_dim)
  11. query_embed = embedding(query_input)
  12. doc_embed = embedding(doc_input)
  13. query_encoder = LSTM(128)(query_embed)
  14. doc_encoder = LSTM(128)(doc_embed)
  15. cos_sim = NegativeCosineLayer()([query_encoder, doc_encoder])
  16. model = Model(inputs=[query_input, doc_input], outputs=cos_sim)
  17. query_model = Model(inputs=query_input, outputs=query_encoder)
  18. doc_model = Model(inputs=doc_input, outputs=doc_encoder)
  19. model.compile(optimizer="adam", loss=lambda y_true, y_pred: y_pred)
  20. model.train(...)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/42274
推荐阅读
相关标签
  

闽ICP备14008679号