pytorch 深度学习之余弦相似度

用处

此方法特别重要,经常可以用来修改论文,提出创新点.

定理

余弦相似度是通过计算两个向量之间的夹角余弦值来衡量它们的相似性。给定两个非零向量 x 和 y,它们之间的余弦相似度可以使用以下公式计算:

cosine_similarity(x, y) = (x · y) / (||x|| * ||y||)

其中,

  • (x · y) 表示向量 x 和 y 的点积(内积),是两个向量对应元素相乘再求和的结果。
  • ||x|| 表示向量 x 的范数,通常使用 L2 范数表示,即向量 x 的所有元素平方和的平方根。
  • ||y|| 表示向量 y 的范数,也是使用 L2 范数进行计算。

使用上述公式,我们可以将两个向量的点积除以它们的范数的乘积,得到余弦相似度的标量结果,取值范围在 -1 到 1 之间。越接近 1 表示两个向量越相似,越接近 -1 表示两个向量越不相似,0 表示两个向量正交(无关)。

代码

  • 代码1:

如果您想在指定的维度(channels, height, width)上计算范数并保持计算过程中的维度,可以进行如下修改:

import torch.nn.functional as F

def cosine_similarity(tensor_1, tensor_2):
    normalized_tensor_1 = F.normalize(tensor_1, p=2, dim=(1, 2, 3))
    normalized_tensor_2 = F.normalize(tensor_2, p=2, dim=(1, 2, 3))
    cosine_sim = torch.sum(normalized_tensor_1 * normalized_tensor_2, dim=(1, 2, 3), keepdim=True)
    return cosine_sim

在这里,我们使用 dim=(1, 2, 3) 将计算范数的维度指定为 (channels, height, width),并使用 keepdim=True 保持了计算过程中的维度。

这样,函数将在指定的维度上进行范数计算,并返回一个与输入张量形状相同的张量,其中的每个元素是沿着指定维度(channels, height, width)计算得到的余弦相似度值,并保持了指定维度的维度大小。

  • 代码2:

如果你希望使用 torch.norm() 函数计算张量的范数,可以对上述代码进行如下修改:

import torch

def cosine_similarity(tensor_1, tensor_2):
    normalized_tensor_1 = tensor_1 / torch.norm(tensor_1, p=2, dim=(1, 2, 3), keepdim=True)
    normalized_tensor_2 = tensor_2 / torch.norm(tensor_2, p=2, dim=(1, 2, 3), keepdim=True)
    cosine_sim = torch.sum(normalized_tensor_1 * normalized_tensor_2, dim=(1, 2, 3), keepdim=True)
    return cosine_sim

在这个修改后的代码中,我们使用了 torch.norm() 函数计算指定维度上的张量范数,并将其作为分母来归一化输入张量。参数 p=2 表示使用 L2 范数计算。

然后,我们使用 torch.sum() 函数在指定的维度上求和,并保持计算过程中的维度,得到余弦相似度的向量。

请确保已经导入了 torch 模块。

F.normalize() 和 F.norm() 的区别

F.normalize()F.norm() 是两个不同的函数,它们在功能和使用方式上有所不同。

  1. F.normalize() 函数是用来对张量进行归一化处理的。它接受一个输入张量和一个参数 p,并根据指定的范数类型对输入张量进行归一化。常见的范数类型包括 L1 范数、L2 范数等。归一化后的张量将具有单位长度,方便进行一些距离度量或相似度计算的操作。

  2. F.norm() 函数是用来计算张量的范数的。它接受一个输入张量和一个参数 p,并返回指定范数类型的计算结果。常见的范数类型包括 L1 范数、L2 范数等。F.norm() 函数返回的是一个标量结果,而不是对输入张量进行归一化处理。

总结:
F.normalize() 函数用于对张量进行归一化处理,返回归一化后的张量;
F.norm() 函数用于计算张量的范数,返回范数的标量结果。

猜你喜欢

转载自blog.csdn.net/weixin_41504611/article/details/134429857