PyTorch implementation of simple image search (image retrieval) based on VGG cosine similarity

The code is shown below:

from PIL import Image
from torchvision import transforms
import os
import torch
import torchvision
import torch.nn.functional as F

class VGGSim(torch.nn.Module):
    def __init__(self):
        super(VGGSim, self).__init__()
        blocks = []
        # Load the pretrained VGG16 once and split its conv layers into four blocks
        vgg_features = torchvision.models.vgg16(pretrained=True).features
        blocks.append(vgg_features[:4].eval())     # conv1_1 - relu1_2
        blocks.append(vgg_features[4:9].eval())    # pool1 - relu2_2
        blocks.append(vgg_features[9:16].eval())   # pool2 - relu3_3
        blocks.append(vgg_features[16:23].eval())  # pool3 - relu4_3
        for bl in blocks:
            # Freeze the VGG weights; iterate over parameters, not submodules
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))

    def forward(self, input, target):
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input-self.mean) / self.std
        target = (target-self.mean) / self.std
        x = input
        y = target

        res = []
        for block in self.blocks:
            x = block(x)
            y = block(y)
            x_flat = torch.flatten(x, start_dim=1)
            y_flat = torch.flatten(y, start_dim=1)
            similarity = torch.nn.functional.cosine_similarity(x_flat, y_flat)
            res.append(similarity.cpu().item())
        # Option 1: use only the last VGG block's features for the cosine similarity
        # return res[-1]
        # Option 2: sum the cosine similarities over all VGG blocks
        return sum(res)

def load_image(path):
    image = Image.open(path).convert('RGB')
    image = transforms.Resize([224,224])(image)
    image = transforms.ToTensor()(image)
    image = image.unsqueeze(0)
    return image.cuda()

query_image_path = "query.jpeg"  # the query image to look for
query_image = load_image(query_image_path) 
target_image_dir = "cat_images/"  # the album (gallery) to be searched
target_images = [os.path.join(target_image_dir, name) for name in os.listdir(target_image_dir)]
vgg_sim = VGGSim().cuda()
scores = []
for path in target_images:
    target_image = load_image(path)
    score = vgg_sim(query_image, target_image)
    scores.append([path, score])
scores.sort(key=lambda x: -x[1])
for i in range(5):
    print("Top", (i + 1), "similiar =>", scores[i][0].split("/")[-1])

The core idea of the code above is similar to Perceptual Loss: use VGG to extract multi-level features from the two images and compare them. The difference is that Perceptual Loss usually measures the distance between features with MAE or MSE, while the code here uses cosine similarity.
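For comparison, here is a minimal sketch of that Perceptual-Loss-style scoring, reusing the normalization and blocks of the VGGSim module above. The function name perceptual_distance is my own and is not part of the original code. Since this returns a distance, smaller means more similar, so the gallery would be sorted in ascending order instead:

import torch
import torch.nn.functional as F

@torch.no_grad()
def perceptual_distance(vgg_sim, input, target):
    # Assumes 3-channel inputs; same ImageNet normalization as VGGSim.forward
    input = (input - vgg_sim.mean) / vgg_sim.std
    target = (target - vgg_sim.mean) / vgg_sim.std
    x, y = input, target
    dist = 0.0
    for block in vgg_sim.blocks:
        x = block(x)
        y = block(y)
        dist += F.mse_loss(x, y).item()  # swap in F.l1_loss for the MAE variant
    return dist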

An example is as follows. Given an image of a cihua (Chinese tabby) cat as the query:

[query image]

we hope to find images of other cihua cats in the album:

[album images]

In the dataset above, images numbered 01 to 10 are cow cats and images numbered 11 to 20 are cihua cats. Running the code gives the following result:

Top 1 similar => 04.jpeg
Top 2 similar => 20.jpeg
Top 3 similar => 14.jpeg
Top 4 similar => 12.jpeg
Top 5 similar => 15.jpeg

It can be seen that the search is basically correct: 20, 14, 12, and 15 are all cihua cats. The reason 04 gets the highest similarity is that its pose is very close to the query's and the environment (the floor) is also similar, so on another level the two images are indeed alike.
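One last note on efficiency: in the retrieval loop above, the query image is pushed through VGG once for every gallery image. Below is a minimal sketch of caching the per-block features instead; the helpers extract_features and feature_similarity are my own names, not part of the original code, and they assume 3-channel inputs:

import torch
import torch.nn.functional as F

@torch.no_grad()
def extract_features(vgg_sim, image):
    # Run the image through the VGG blocks once and keep the flattened features
    x = (image - vgg_sim.mean) / vgg_sim.std
    feats = []
    for block in vgg_sim.blocks:
        x = block(x)
        feats.append(torch.flatten(x, start_dim=1))
    return feats

def feature_similarity(query_feats, target_feats):
    # Same scoring as VGGSim.forward: sum of per-block cosine similarities
    return sum(F.cosine_similarity(q, t).item()
               for q, t in zip(query_feats, target_feats))

# The query is now encoded only once, then compared against cached gallery features
query_feats = extract_features(vgg_sim, query_image)
scores = []
for path in target_images:
    target_feats = extract_features(vgg_sim, load_image(path))
    scores.append([path, feature_similarity(query_feats, target_feats)])
scores.sort(key=lambda x: -x[1])  # then print the top-5 as before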

Source: blog.csdn.net/qq_40714949/article/details/132212179