
NLP: A SentenceTransformer usage example (loading a local sentence-transformers model)

First, download the sentence-transformers model (here, all-MiniLM-L6-v2) from the Hugging Face website and save it to a local directory.
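If you prefer to script the download instead of using the website, here is a minimal sketch using the huggingface_hub library (an assumption beyond the original post: the package must be installed, and the local path is just an example):

  from huggingface_hub import snapshot_download

  # Download a full snapshot of the model repo into a local folder;
  # the resulting path can then be passed to from_pretrained() below.
  snapshot_download(repo_id='sentence-transformers/all-MiniLM-L6-v2',
                    local_dir='D:/Model/sentence-transformers/all-MiniLM-L6-v2')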

1. Import the required libraries

  from transformers import AutoTokenizer, AutoModel
  import numpy as np
  import torch
  import torch.nn.functional as F

2. Load the pretrained model

  path = 'D:/Model/sentence-transformers/all-MiniLM-L6-v2'
  tokenizer = AutoTokenizer.from_pretrained(path)
  model = AutoModel.from_pretrained(path)
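Optionally, switch the model to evaluation mode so that dropout is disabled during inference (a standard one-line step, not in the original post):

  model.eval()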

3. Define mean pooling

Mean pooling averages the token embeddings of each sentence, using the attention mask so that padding tokens do not distort the average.

  def mean_pooling(model_output, attention_mask):
      # The first element of model_output contains all token embeddings
      token_embeddings = model_output[0]
      input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
      return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
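To see why the mask matters, here is a tiny sanity check with hypothetical toy values (the tensors below are made up for illustration): the padded position contributes nothing to the average.

  # One "sentence" with 3 token embeddings of dimension 2;
  # the third token is padding (mask = 0).
  toy_embeddings = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [99.0, 99.0]]])
  toy_mask = torch.tensor([[1, 1, 0]])
  print(mean_pooling((toy_embeddings,), toy_mask))  # tensor([[2., 2.]]), padding is ignored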

4. Embed the sentences

  sentences = ['loved thisand know really bought wanted see pictures myselfIm lucky enough someone could justify buying present',
               'issue pages stickers restuck really used configurations made regular pages rather taking pieces robot back',
               'stickers dont stick well first time placing',
               'Great fun grandson loves robots',
               'would suggest younger kids son 3']
  encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
  with torch.no_grad():
      model_output = model(**encoded_input)
  sentence_embeddings1 = mean_pooling(model_output, encoded_input['attention_mask'])
  print("Sentence embeddings:")
  print(sentence_embeddings1)
  # Normalize embeddings to unit length
  sentence_embeddings2 = F.normalize(sentence_embeddings1, p=2, dim=1)
  print("Normalized sentence embeddings:")
  print(sentence_embeddings2)

5. Run results

Running the code above prints two 5 x 384 tensors: the raw mean-pooled sentence embeddings, followed by their L2-normalized counterparts.
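For comparison, the sentence-transformers library wraps tokenization, the forward pass, mean pooling, and normalization into a single call. A minimal sketch, assuming the sentence-transformers package is installed (for all-MiniLM-L6-v2 the default pooling is mean pooling, so this should match the manual pipeline above):

  from sentence_transformers import SentenceTransformer

  st_model = SentenceTransformer('D:/Model/sentence-transformers/all-MiniLM-L6-v2')
  # encode() returns a numpy array; normalize_embeddings=True reproduces
  # the F.normalize step from the manual pipeline.
  embeddings = st_model.encode(sentences, normalize_embeddings=True)
  print(embeddings.shape)  # (5, 384)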

6. Define a similarity score between sentences

  def compute_sim_score(v1, v2):
      # Cosine similarity: dot product divided by the product of the norms
      return v1.dot(v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))

7. Compute sentence similarity

  # 'issue pages stickers restuck really used configurations made regular pages rather taking pieces robot back'
  # 'stickers dont stick well first time placing'
  compute_sim_score(sentence_embeddings1[1], sentence_embeddings1[2])
  # result: tensor(0.5126)
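As a cross-check, PyTorch's built-in cosine similarity should return the same value; a quick sketch using the embeddings computed above:

  # F.cosine_similarity works on batched inputs, hence the unsqueeze(0)
  print(F.cosine_similarity(sentence_embeddings1[1].unsqueeze(0),
                            sentence_embeddings1[2].unsqueeze(0)))
  # expected to match compute_sim_score: tensor([0.5126])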

8. Check the shape of the embeddings

  sentence_embeddings1.shape
  # torch.Size([5, 384])
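Because sentence_embeddings2 is already L2-normalized, the full pairwise cosine-similarity matrix reduces to a single matrix product; a short sketch:

  # (5, 384) @ (384, 5) -> (5, 5) matrix of pairwise cosine similarities
  sim_matrix = sentence_embeddings2 @ sentence_embeddings2.T
  print(sim_matrix.shape)  # torch.Size([5, 5])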

Outlook and summary:

Next, I plan to try embedding real user review sentences about products.
