Find the three most similar images to a given image

References:

- "PyTorch: searching for images with images" (blog of "-周-", CSDN) — two methods for image-search algorithms
- "PyTorch: loading your own image dataset" (blog of "-周-", CSDN) — reading an image dataset in PyTorch

- "PyTorch: adding, deleting, and modifying layers of a pre-trained model" (blog of "-周-", CSDN) — modifying a network's structure in PyTorch

1. Modifying the network

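1) Keep the VGG16 feature-extraction network, remove the fully connected layers (classifier) and the avgpool layer, and change the number of output channels of the last convolutional layer to 1. Note that the replacement convolution is newly created, so its weights are randomly initialized rather than pretrained.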

# Start from the ImageNet-pretrained VGG16 and keep only its convolutional
# feature extractor.
net = models.vgg16(pretrained=True)

# nn.Identity returns its input unchanged, so assigning it drops the adaptive
# average pool and the fully connected classifier. The model then outputs the
# (flattened, per torchvision's VGG.forward) feature map of `features`.
net.avgpool = nn.Identity()
net.classifier = nn.Identity()

# Index 28 is the last Conv2d in VGG16's `features` stack. A 512 -> 1 channel
# convolution replaces it, so each image is reduced to a single-channel map.
# NOTE(review): this layer is freshly (randomly) initialised, not pretrained,
# so embeddings differ between runs unless a random seed is fixed.
net.features[28] = nn.Conv2d(
    in_channels=512,
    out_channels=1,
    kernel_size=3,
    stride=1,
    padding=1,
)

2. Dataset loading

1) Define a function that writes a folder's file names to a txt index (f.txt)

def mak_txt(root, file_name):
    """Write an index file ``f.txt`` listing the files in ``root/file_name``.

    One file name is written per line, in sorted order. Sub-directories and
    an existing ``f.txt`` are skipped, so the index only names regular files
    and re-running the function produces the same index.

    Args:
        root: Parent directory.
        file_name: Name of the image folder inside ``root``.
    """
    path = os.path.join(root, file_name)
    # os.listdir returns entries in arbitrary order; sorting makes the index
    # (and therefore the dataset order) deterministic.
    names = sorted(
        name
        for name in os.listdir(path)
        if name != 'f.txt' and os.path.isfile(os.path.join(path, name))
    )
    # os.path.join instead of a hard-coded '\\' so this also works off Windows;
    # the with-block closes the file even if a write fails.
    with open(os.path.join(path, 'f.txt'), 'w') as f:
        f.writelines(name + '\n' for name in names)

2) Call mak_txt to build the f.txt index for each image folder

# Derive both image folders from a single root so that the paths read by
# MyDataset always match the folders that mak_txt indexes.
path = r'D:\AI\images_retreve'
image_packages = os.path.join(path, 'image_packages')  # gallery to search in
inputs_images = os.path.join(path, 'inputs_images')    # query image(s)
mak_txt(path, 'image_packages')
mak_txt(path, 'inputs_images')

3) Image preprocessing


# Image preprocessing. Resize the shorter side to 256 and centre-crop to
# 224x224. Convert to a float tensor, then normalise with the ImageNet
# channel statistics that torchvision's pretrained VGG16 weights expect.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

4) Define MyDataset

class MyDataset(Dataset):
    """Dataset of the images listed in ``<img_path>/f.txt``.

    Each non-blank line of ``f.txt`` is a file name relative to ``img_path``.
    Items are returned in file order, as RGB PIL images passed through
    ``transform`` when one is given.

    Args:
        img_path: Folder containing the images and their ``f.txt`` index.
        transform: Optional callable applied to each loaded image.
    """

    def __init__(self, img_path, transform=None):
        super(MyDataset, self).__init__()
        self.img_path = img_path
        # os.path.join instead of a hard-coded backslash so this is portable.
        self.txt_root = os.path.join(img_path, 'f.txt')
        # Strip the whole line instead of keeping only its first
        # whitespace-separated token, so file names containing spaces survive.
        # Blank lines are skipped rather than raising IndexError.
        with open(self.txt_root, 'r') as f:
            names = [line.strip() for line in f]
        self.img = [os.path.join(self.img_path, name) for name in names if name]
        self.transform = transform

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.img)

    def __getitem__(self, item):
        """Load image ``item`` as RGB and apply ``transform`` if set."""
        # The with-block releases the file handle once convert() has produced
        # an independent in-memory copy.
        with Image.open(self.img[item]) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img

5) Load the dataset

# Query images and gallery ("package") images share one preprocessing pipeline.
dataset_inputs, dataset_packages = (
    MyDataset(inputs_images, transform=transform),
    MyDataset(image_packages, transform=transform),
)

# shuffle=False keeps batches in f.txt order, so feature rows can be mapped back
# to file names by index. Queries go one per batch; the gallery 100 per batch.
data_loader_inputs = DataLoader(dataset_inputs, batch_size=1, shuffle=False)
data_loader_packages = DataLoader(dataset_packages, batch_size=100, shuffle=False)

3. Calculate similarity

1) Run both image sets through the network to extract their features

# Extract features for the query image(s) and the whole gallery.
# eval() puts any train-time layers into inference mode, and torch.no_grad()
# avoids building an autograd graph, since nothing here is trained.
net.eval()
with torch.no_grad():
    # The query loader has batch_size=1. Only the last query image's features
    # are kept, because the distance step compares one query against the
    # gallery.
    for data in data_loader_inputs:
        output_inputs = net(data)

    # Concatenate every gallery batch. Keeping only the last batch would drop
    # all but the final 100 images and misalign row indices with f.txt.
    output_packages = torch.cat([net(batch) for batch in data_loader_packages])

print(output_inputs.shape)
print(output_packages.shape)
2) Compute the Euclidean (L2) distances with F.pairwise_distance
# Euclidean (p=2) distance between the query features and the gallery
# features. The query row is presumably broadcast against every gallery row,
# giving one distance per gallery image — confirm with the printed shape.
dist2 = F.pairwise_distance(x1=output_inputs, x2=output_packages, p=2)
print(dist2.shape)

4. Output the three most similar pictures

1) Get the indices of the three most similar pictures (the smallest distances)
# Indices of the (up to) three gallery images closest to the query, nearest
# first. topk(largest=False) selects the smallest distances directly. It does
# not overwrite dist2 with a magic sentinel, and it never repeats an index when
# the gallery has fewer than three images. flatten() matches argmin's
# flattened-index semantics.
max_list = torch.topk(
    dist2.flatten(), min(3, dist2.numel()), largest=False
).indices.tolist()
print(max_list)

2) Look up the original picture files by index

# Map the selected indices back to image files. dataset_packages.img is the
# exact ordered list the gallery DataLoader iterated (shuffle=False), so
# index idx names the image that produced row idx of output_packages. This
# avoids re-reading f.txt, which left the file open and kept each line's
# trailing newline in the path.
data_img = [dataset_packages.img[idx] for idx in max_list]

3) Create a figure and display the pictures on it

# Show the matched images stacked vertically, one subplot per image, in
# data_img order. The row count follows len(data_img) instead of a
# hard-coded 3.
fig = plt.figure(figsize=(10, 10))
for row, img_file in enumerate(data_img, start=1):
    # Grid of len(data_img) rows x 1 column; draw into subplot number `row`.
    ax = fig.add_subplot(len(data_img), 1, row)
    # imshow converts the PIL image to an array immediately, so the file
    # handle can be closed as soon as it returns.
    with Image.open(img_file.strip()) as img:
        ax.imshow(img)
plt.show()

Guess you like

Source: blog.csdn.net/weixin_52950958/article/details/125781318