python/pytorch计算tensor的余弦相似度

一、相似度和点积

很多场景里,需要比较两个tensor的相似度(NLP或者CV里都有可能),这种相似度的计算一般用余弦相似度来计算,也就是常说的向量点积(dot-product),比如Transformer里self-attention的相关操作,用点积来计算Q和K的“相似度”

二、Pytorch的简单实现

很好的是,torch里有现成的函数cosine_similarity,不需要像网上那种要自己定义一个复杂的类来实现。

torch.cosine_similarity(input1,input2,dim=1, eps=1e-8)

  • input1和input2都需要是两个torch.Tensor类型的变量
  • dim指定在某个维度上进行计算相似度,default=1,即可以不输入
  • eps是避免出现除数为0的一个极小值,一般不输入

例:

通过transform的编码器对两张图进行编码,得到了两个shape为[1,1,768]的tensor:img1和img2

import torch
# img1.shape = [1,1,768] = img2.shape
cos_sim = torch.cosine_similarity(img1, img2, dim=2)
# tensor([[0.9457]], device='cuda:0')
print(cos_sim)

可以看到这两张图的相似度是0.9457

如果是批量化计算,得到一组cos,怎么方便计算平均余弦相似度呢?

参考做法:

import torch
# img1.shape = [1,1,768] = img2.shape


cos_list = []
for i in range(n):
    cos_sim = torch.cosine_similarity(...)
    cos_list.append(cos_sim)
#此时cos_list为list,但是里面都是一个个tensor 不方便计算
# cos_list.shape = [9,1,1]
# 可以用下面的方法 先建一个新维度 然后在这个维度上mean
mean_cos=torch.stack(cos_list,dim=0).mean(dim=0)
# tensor([[0.9599]], device='cuda:0')
print(mean_cos)

扫描二维码关注公众号,回复: 14867620 查看本文章

猜你喜欢

转载自blog.csdn.net/jiangqixing0728/article/details/129185185
今日推荐