当前位置:   article > 正文

pytorch 余弦相似度矩阵cos_similar,批量运算_批量计算余弦相似度 python

批量计算余弦相似度 python
  1. import torch
  2. from torch import Tensor
  3. # params: 张量p和q,保证最后一个维度为特征维度,倒数第二哥维度为所求的相似维度。
  4. # return: sim_matrix[i][j]代表p的第i个特征和q的第j个特征相似度计算值。
  5. def cos_similar(p: Tensor, q: Tensor):
  6. sim_matrix = p.matmul(q.transpose(-2, -1))
  7. a = torch.norm(p, p=2, dim=-1)
  8. b = torch.norm(q, p=2, dim=-1)
  9. sim_matrix /= a.unsqueeze(-1)
  10. sim_matrix /= b.unsqueeze(-2)
  11. return sim_matrix
  12. if __name__ == '__main__':
  13. a = torch.rand((64, 23, 50, 200))
  14. b = torch.rand((64, 23, 60, 200))
  15. print(cos_similar(a, b).size())
声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号