座標から距離行列への変換

#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import torch


def by_l2_square(a, b):
    """Return the squared Euclidean distance between every row pair of a and b.

    A singleton axis is inserted at position -2 of ``a`` and at position -3
    of ``b`` so that their difference broadcasts over all
    (row of ``a``, row of ``b``) combinations. The squared differences are
    then summed over the trailing (coordinate) axis.

    For example, ``a`` of shape ``(B, n, d)`` and ``b`` of shape
    ``(B, m, d)`` yield a result of shape ``(B, n, m)``.
    """
    rows = torch.unsqueeze(a, -2)   # (..., n, 1, d)
    cols = torch.unsqueeze(b, -3)   # (..., 1, m, d)
    diff = rows - cols              # broadcasts to (..., n, m, d)
    return diff.pow(2).sum(dim=-1)


def by_p_norm(a, b, p):
    """Return the p-norm distance between every row pair of a and b.

    The inputs are expanded with singleton axes (position -2 for ``a``,
    position -3 for ``b``) so that subtraction broadcasts over all
    (row of ``a``, row of ``b``) combinations. The vector norm of order
    ``p`` is then taken along the trailing (coordinate) axis.

    Unlike ``by_l2_square`` the result is the norm itself, not its square.
    """
    pairwise_diff = torch.unsqueeze(a, -2) - torch.unsqueeze(b, -3)
    return torch.linalg.vector_norm(pairwise_diff, ord=p, dim=-1)


def origin(a, b):
    """Compute the squared Euclidean distance matrix with explicit loops.

    This is the naive element-by-element baseline: for every batch index and
    every pair ``(i, j)`` it sums the squared coordinate differences between
    ``a[batch, i]`` and ``b[batch, j]``.

    Args:
        a: 3-D tensor unpacked as ``(B, n, _)``.
        b: 3-D tensor unpacked as ``(_, m, _)``; it is indexed with the same
            batch indices as ``a``.

    Returns:
        Tensor of shape ``(B, n, m)`` holding the squared distances, using
        the promoted dtype of ``a`` and ``b`` and the device of ``a``.
    """
    B, n, _ = a.shape
    _, m, _ = b.shape

    # Allocate with the inputs' promoted dtype and on the inputs' device.
    # A bare torch.empty(B, n, m) would always be default-dtype on the CPU,
    # silently downcasting float64 inputs and ignoring where the data lives.
    C = torch.empty(B, n, m, dtype=torch.result_type(a, b), device=a.device)
    for batch in range(B):
        for i in range(n):
            for j in range(m):
                C[batch, i, j] = torch.sum((a[batch, i] - b[batch, j]) ** 2)
    return C


if __name__ == '__main__':
    # Self-check: the two vectorised implementations and the loop-based
    # reference should agree on a small random problem (one batch, 2-D points).
    num_rows, num_cols = 5, 4
    points_a = torch.rand(1, num_rows, 2)
    points_b = torch.rand(1, num_cols, 2)

    squared_l2 = by_l2_square(points_a, points_b)
    # by_p_norm returns the plain L2 norm, so square it before comparing.
    squared_from_norm = by_p_norm(points_a, points_b, 2) ** 2
    reference = origin(points_a, points_b)

    comparisons = (
        (squared_l2, reference),
        (squared_l2, squared_from_norm),
        (squared_from_norm, reference),
    )
    for lhs, rhs in comparisons:
        print(torch.allclose(lhs, rhs))



おすすめ

転載: blog.csdn.net/qq_39942341/article/details/131783575