赞
踩
定义余弦相似度层，并在 batch 内进行负采样：每个 query 与其对应的正样本 doc 以及从同一 batch 中采样的 NEG 个负样本 doc 计算余弦相似度，经 softmax 后以正样本概率的负对数均值作为损失。
NEG, batch_size = 20, 128


class NegativeCosineLayer:
    """Cosine-similarity loss layer with in-batch negative sampling.

    For every (query, doc) pair in a batch, ``neg`` negative docs are taken
    from other rows of the same batch. Each query is scored against its
    positive doc and its negatives by cosine similarity. The scores are
    multiplied by ``gamma`` and passed through a softmax. The layer outputs
    the mean negative log-probability assigned to the positive doc.

    The output is already the training loss. Compile the model with an
    identity loss such as ``lambda y_true, y_pred: y_pred``.
    """

    def __init__(self, neg=None, gamma=20.0):
        """
        Args:
            neg: number of in-batch negatives per query; defaults to the
                module-level ``NEG``.
            gamma: factor multiplied into the cosine scores before the
                softmax (previously hard-coded as 20).
        """
        self.neg = NEG if neg is None else neg
        self.gamma = gamma

    def __call__(self, inputs):
        """Build the loss tensor from ``[query_encoder, doc_encoder]``.

        Args:
            inputs: sequence whose first two items are the query and doc
                encodings. Presumably both are (batch, dim) and row i of each
                forms a positive pair -- TODO confirm against the encoders.

        Returns:
            The output of a ``Lambda`` layer holding the scalar mean loss.
        """
        neg = self.neg
        gamma = self.gamma

        def _cosine(x):
            query_encoder, doc_encoder = x
            batch = tf.shape(doc_encoder)[0]

            # Row block 0 holds the positive docs; blocks 1..neg hold negatives.
            docs = [doc_encoder]
            for _ in range(neg):
                # Rotate the batch by a random non-zero offset so that row i is
                # paired with another row's doc. The previous random
                # permutation has on average one fixed point, so it regularly
                # used a query's own positive doc as a "negative". With a
                # batch of size 1 no other row exists and the doc pairs with
                # itself.
                shift = tf.random.uniform(
                    [], minval=1, maxval=tf.maximum(batch, 2), dtype=tf.int32)
                docs.append(tf.roll(doc_encoder, shift=shift, axis=0))
            doc_encoder_fd = tf.concat(docs, axis=0)
            query_encoder_fd = tf.tile(query_encoder, [neg + 1, 1])

            # l2_normalize clamps the norm with a small epsilon, so an all-zero
            # encoding no longer produces 0/0 = NaN.
            query_unit = tf.nn.l2_normalize(query_encoder_fd, axis=1)
            doc_unit = tf.nn.l2_normalize(doc_encoder_fd, axis=1)
            cos_sim_raw = tf.reduce_sum(query_unit * doc_unit, axis=1)

            # Consecutive chunks of length B hold the scores against each doc
            # set. Reshape to (neg + 1, B), then transpose, so each row is one
            # query with its positive score in column 0.
            cos_sim = tf.transpose(tf.reshape(cos_sim_raw, [neg + 1, -1])) * gamma

            # log_softmax is the numerically stable form of log(softmax(.)),
            # and tf.log does not exist in TF2.
            log_prob = tf.nn.log_softmax(cos_sim, axis=-1)
            return -tf.reduce_mean(log_prob[:, 0], name="sim-mean")

        output_shape = (1,)
        value = Lambda(_cosine, output_shape=output_shape)([inputs[0], inputs[1]])
        return value

使用方法：
import tensorflow as tf
from keras.models import Model
# Model lives in keras.models; importing it from keras.layers raises
# ImportError. Lambda is needed by NegativeCosineLayer.
from keras.layers import Input, Embedding, Dense, LSTM, Lambda


query_max_len = 16
doc_max_len = 128
vocab_size = 10000
embed_dim = 64

# Token-id inputs for the two towers.
query_input = Input(shape=(query_max_len, ), name="query_input")
doc_input = Input(shape=(doc_max_len, ), name="doc_input")

# One embedding table shared by both towers (vocab_size + 1 rows; presumably
# id 0 is reserved for padding -- confirm against the tokenizer).
embedding = Embedding(vocab_size + 1, embed_dim)
query_embed = embedding(query_input)
doc_embed = embedding(doc_input)

# Each tower has its own LSTM encoder.
query_encoder = LSTM(128)(query_embed)
doc_encoder = LSTM(128)(doc_embed)

# The layer's output is the training loss itself, not a similarity score.
loss = NegativeCosineLayer()([query_encoder, doc_encoder])

model = Model(inputs=[query_input, doc_input], outputs=loss)
# Stand-alone encoders that share weights with `model`, e.g. for encoding
# queries and docs separately after training.
query_model = Model(inputs=query_input, outputs=query_encoder)
doc_model = Model(inputs=doc_input, outputs=doc_encoder)

# Identity loss: y_pred already is the loss and y_true is ignored, so fit()
# only needs dummy labels with the right batch dimension.
model.compile(optimizer="adam", loss=lambda y_true, y_pred: y_pred)
# Keras models are trained with fit(), not train(). Negatives come from the
# same batch, so use batch_size > NEG, e.g.:
#   model.fit([query_ids, doc_ids], dummy_labels, batch_size=batch_size)
model.fit(...)

Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。