根据keypoint生成heatmap

这里的 keypoint 带有类别信息，生成的 heatmap 中每个通道对应一个类别（类别编号从 1 开始，类别 c 对应第 c-1 个通道）。

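第一种方法较慢，第二种较快：第一种对每个关键点都在整幅图上计算一次高斯函数，计算量约为 关键点数 × h × w；第二种预先计算好一个 (2·radius+1)×(2·radius+1) 的高斯核，只把它贴到关键点周围的局部窗口里（窗口以外的高斯值被截断为 0）。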

第一种

def generate_heatmap(heatmap_size, sigma, class_num, keypoints, normalization):
    """
    Generate a per-class Gaussian heatmap by evaluating a full-frame Gaussian
    for every keypoint (simple, but costs O(num_keypoints * h * w)).

    :param heatmap_size: (h, w) of the output heatmap
    :param sigma: standard deviation of the Gaussian, in pixels
    :param class_num: number of classes, i.e. number of output channels
    :param keypoints: iterable of (x, y, class_id) with 1-based class_id, or
        None; keypoints outside the heatmap are ignored
    :param normalization: if true, divide each channel by its own maximum;
        channels that received no keypoint are left as all zeros

    :return: gaussian heatmap of shape (class_num, h, w); Gaussians of the
        same class are summed
    :raises IndexError: if an in-frame keypoint's class_id is outside
        [1, class_num]
    """
    import numpy as np

    h, w = heatmap_size
    heatmap = np.zeros((class_num, h, w))
    if keypoints is None:
        return heatmap

    # The coordinate grids are identical for every keypoint, so build them once.
    xs = np.arange(w)[None, :]
    ys = np.arange(h)[:, None]
    for x, y, c in keypoints:
        if x < 0 or y < 0 or x >= w or y >= h:
            continue
        channel = int(c) - 1  # class ids are 1-based
        # Without this check class_id 0 would silently land in channel -1.
        if not 0 <= channel < class_num:
            raise IndexError(f"class_id {c} out of range [1, {class_num}]")
        heatmap[channel] += np.exp(-((xs - x) ** 2 + (ys - y) ** 2) / (2 * sigma ** 2))

    if normalization:
        peak = heatmap.max(axis=(-1, -2), keepdims=True)
        # Leave empty channels at 0 instead of producing 0 / 0 = NaN.
        np.divide(heatmap, peak, out=heatmap, where=peak > 0)
    return heatmap

调用

# Example: five keypoints of two classes on a 256x256 heatmap, sigma = 7.
heatmap = generate_heatmap(
    heatmap_size=(256, 256),
    sigma=7,
    class_num=2,
    keypoints=[
        (100, 50, 1),
        (150, 150, 1),
        (200, 50, 2),
        (50, 200, 2),
        (180, 180, 1),
    ],
    normalization=True,
)

最终生成的效果如下：

[图片缺失：效果图 1（原文此处为图片）]
[图片缺失：效果图 2（原文此处为图片）]

第二种

def generate_gaussian(sigma, radius=None):
    """
    Precompute a square Gaussian kernel and return a function that pastes it
    around each keypoint. This is faster than evaluating a full-frame Gaussian
    per keypoint because only a (2 * radius + 1) x (2 * radius + 1) window is
    touched; values outside that window are truncated to 0.

    :param sigma: standard deviation of the Gaussian, in pixels
    :param radius: integer half-width of the kernel window; defaults to
        int(6 * sigma + 3) // 2 (roughly 3 * sigma + 1)

    :return: generate(heatmap_size, class_num, keypoints, normalization)
    """
    import numpy as np

    if radius is None:
        radius = int(6 * sigma + 3) // 2
    else:
        radius = int(radius)
    # Always use an odd side length so the kernel peak sits exactly on a pixel.
    size = 2 * radius + 1
    x = np.arange(size)
    y = x[:, None]
    gaussian_kernel = np.exp(-((x - radius) ** 2 + (y - radius) ** 2) / (2.0 * sigma ** 2))

    def generate(heatmap_size, class_num, keypoints, normalization):
        """
        Generate a per-class heatmap from keypoints.

        :param heatmap_size: (h, w)
        :param class_num: number of classes (output channels)
        :param keypoints: iterable of (x, y, class_id) with 1-based class_id,
            or None; keypoints outside the heatmap are ignored and
            coordinates are truncated to integers
        :param normalization: if true, divide each non-empty channel by its
            maximum (channels with no keypoint stay all zeros)

        :return: gaussian heatmap of shape (class_num, h, w); overlapping
            windows of the same class are merged with an element-wise max
        :raises IndexError: if an in-frame keypoint's class_id is outside
            [1, class_num]
        """
        h, w = heatmap_size
        result = np.zeros((class_num, h, w))
        if keypoints is None:
            return result

        for x, y, class_id in keypoints:
            if y < 0 or x < 0 or y >= h or x >= w:
                continue
            x = int(x)
            y = int(y)
            channel = int(class_id) - 1  # class ids are 1-based
            # Without this check class_id 0 would silently land in channel -1.
            if not 0 <= channel < class_num:
                raise IndexError(f"class_id {class_id} out of range [1, {class_num}]")

            # Kernel window in heatmap coordinates; bottom/right are exclusive,
            # hence the +1, so the full (2 * radius + 1)-wide kernel fits.
            top, left = y - radius, x - radius
            bottom, right = y + radius + 1, x + radius + 1

            # Clip the window to the heatmap and crop the kernel to match.
            aa, bb = max(0, top), min(bottom, h)
            cc, dd = max(0, left), min(right, w)
            kernel = gaussian_kernel[aa - top:bb - top, cc - left:dd - left]

            region = result[channel, aa:bb, cc:dd]  # basic slice -> view
            # Element-wise max rather than assignment, so a later keypoint's
            # low tail cannot overwrite a nearby peak of the same class.
            np.maximum(region, kernel, out=region)

        if normalization:
            peak = result.max(axis=(-1, -2), keepdims=True)
            # Leave empty channels at 0 instead of producing 0 / 0 = NaN.
            np.divide(result, peak, out=result, where=peak > 0)

        return result

    return generate

调用

# Example for the windowed version: sigma 1, kernel half-width 7.
h = w = 256
sigma, radius = 1, 7
size = radius * 2 + 1  # kernel side length; not used by the calls below
keypoints = [
    (100, 50, 1),
    (150, 150, 1),
    (200, 50, 2),
    (50, 200, 2),
    (180, 180, 1),
]
gen = generate_gaussian(sigma, radius)
heatmap = gen(
    (h, w),
    2,
    keypoints,
    True,
)

[图片缺失：效果图 1（原文此处为图片）]
[图片缺失：效果图 2（原文此处为图片）]

猜你喜欢

转载自blog.csdn.net/qq_39942341/article/details/132907458