references:
- "PyTorch: searching for pictures with pictures" (_-week-_'s blog, CSDN blog; topic: two methods of picture search algorithms in PyTorch)
- "PyTorch: loading your own picture data set" (_-week-_'s blog, CSDN blog; topic: reading an image dataset in PyTorch)
1. Modification of the network
1) Keep the VGG16 feature-extraction network, remove the fully connected layers and the avgpool layer, and change the last convolutional layer so that it outputs a single channel
# Start from the pretrained VGG16 and reduce it to a feature extractor.
net = models.vgg16(pretrained=True)
# Replace the last convolution (features[28]) with one that emits a single
# output channel; this new layer starts from fresh, untrained weights.
single_channel_conv = nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1)
net.features[28] = single_channel_conv
# An empty Sequential returns its input unchanged, which effectively removes
# the average-pooling stage and the fully connected classifier.
net.avgpool = nn.Sequential()
net.classifier = nn.Sequential()
2. Dataset loading
1) Define a function that writes the file names in a folder to a txt index file (f.txt)
def mak_txt(root, file_name):
    """Write an index file ``f.txt`` listing the entries of a directory.

    The directory ``root/file_name`` is listed and every entry name is
    written on its own line to ``f.txt`` inside that same directory. An
    already existing ``f.txt`` is skipped so the index never lists itself.

    Args:
        root: Parent directory.
        file_name: Name of the sub-directory of ``root`` to index.

    Raises:
        OSError: If the directory cannot be listed or the index cannot be
            written.
    """
    path = os.path.join(root, file_name)
    # os.listdir returns entries in arbitrary order; sorting makes the line
    # number of each entry reproducible between runs.
    names = sorted(name for name in os.listdir(path) if name != 'f.txt')
    # os.path.join (instead of a hard-coded '\\') puts the index inside the
    # directory on every platform; `with` closes the file even if a write
    # fails.
    with open(os.path.join(path, 'f.txt'), 'w') as f:
        f.writelines(name + '\n' for name in names)
2) Call the mak_txt function to generate the txt index file for each image folder
# Build both image folders from one base directory so that the folders
# indexed by mak_txt and the folders read later are guaranteed to be the
# same (the inputs folder was previously spelled 'images_retrev', which
# pointed at a different directory than the one mak_txt indexed).
path = r'D:\AI\images_retreve'
image_packages = os.path.join(path, 'image_packages')
inputs_images = os.path.join(path, 'inputs_images')
# Write an f.txt index into each of the two folders.
mak_txt(path, 'image_packages')
mak_txt(path, 'inputs_images')
3) Image preprocessing
# Image preprocessing: resize, centre-crop to 224x224, convert to a tensor,
# then normalise.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # The pretrained torchvision weights were trained on inputs normalised
    # with the ImageNet channel mean/std; without this step the pretrained
    # layers see inputs on a different scale than they were trained for.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
4) Define the MyDataset class
class MyDataset(Dataset):
    """Dataset of the images listed in the ``f.txt`` index of a directory.

    Each non-empty line of ``<img_path>/f.txt`` is taken as a file name
    relative to ``img_path``. Items are the images opened as RGB and passed
    through ``transform`` when one is given.
    """

    def __init__(self, img_path, transform=None):
        """Read the index file and build the list of image paths.

        Args:
            img_path: Directory holding the images and their ``f.txt``.
            transform: Optional callable applied to each opened image.

        Raises:
            OSError: If the index file cannot be opened.
        """
        super(MyDataset, self).__init__()
        self.img_path = img_path
        # os.path.join instead of a hard-coded backslash keeps the index
        # path valid on every platform.
        self.txt_root = os.path.join(img_path, 'f.txt')
        with open(self.txt_root, 'r') as f:
            # Use the whole stripped line as the file name: splitting on
            # whitespace would truncate names that contain spaces.
            names = [line.strip() for line in f]
        # Blank lines carry no file name and are skipped.
        self.img = [os.path.join(self.img_path, name) for name in names if name]
        self.transform = transform

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.img)

    def __getitem__(self, item):
        """Open image number ``item`` as RGB and apply the transform, if any."""
        img = Image.open(self.img[item]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img
5) Load the dataset
# Query images: served one at a time, in index order.
dataset_inputs = MyDataset(inputs_images, transform)
# Gallery images: served in batches of up to 100, in index order.
dataset_packages = MyDataset(image_packages, transform)
data_loader_inputs = DataLoader(dataset_inputs, batch_size=1, shuffle=False)
data_loader_packages = DataLoader(dataset_packages, batch_size=100, shuffle=False)
3. Calculate similarity
1) Start inputting data
# Inference only: no_grad stops autograd from recording the forward passes,
# which would otherwise keep every intermediate activation alive.
with torch.no_grad():
    # The distance step that follows compares one query against the gallery,
    # so, as before, the features of the last query batch are the ones kept.
    for data in data_loader_inputs:
        output_inputs = net(data)
    # Keep the features of EVERY gallery batch, concatenated in loader
    # order, so that row k corresponds to entry k of the gallery dataset.
    # Overwriting the variable per batch would leave only the last batch.
    output_packages = torch.cat(
        [net(data) for data in data_loader_packages], dim=0
    )
print(output_inputs.shape)
print(output_packages.shape)
2) Compute the Euclidean distance with the pairwise_distance function of the F module
# Euclidean (p=2) distance between the query features and the gallery
# features.
# NOTE(review): assumes output_inputs holds a single row that broadcasts
# against the rows of output_packages, giving one distance per gallery
# image - confirm against the shapes printed above.
dist2 = F.pairwise_distance(output_inputs, output_packages, p=2)
print(dist2.shape)
4. Output the three most similar pictures
1) Output the indexes of the three most similar pictures
# Indexes of the (up to) three smallest distances, closest first.
# torch.topk with largest=False selects the smallest values without
# overwriting entries of dist2 with a sentinel, and capping k at the number
# of candidates avoids returning the same index more than once when fewer
# than three images are available.
# NOTE(review): assumes dist2 is one-dimensional (one distance per gallery
# image) - confirm.
top_k = min(3, dist2.numel())
max_list = torch.topk(dist2, top_k, largest=False).indices.tolist()
print(max_list)
2) Find the original picture according to the index
# Map each selected index back to a file path through the gallery index.
# os.path.join keeps the index path valid on every platform, and `with`
# closes the file once it has been read.
path_dir = os.path.join(image_packages, 'f.txt')
with open(path_dir, 'r') as f:
    # Strip the trailing newline so it does not end up inside the joined
    # path; blank lines carry no file name and are dropped.
    data = [line.strip() for line in f if line.strip()]
# One path per selected index, in the same (closest-first) order.
data_img = [os.path.join(image_packages, data[index]) for index in max_list]
3) Create a canvas and display the picture on the canvas
fig = plt.figure(figsize=(10, 10))
# One row per result in a single column, filled top to bottom. The row
# count follows the number of results instead of a hard-coded 3, so fewer
# results do not cause an index error.
n_rows = len(data_img)
for row, img_file in enumerate(data_img, start=1):
    ax = fig.add_subplot(n_rows, 1, row)
    # The context manager closes the image file once it has been drawn.
    with Image.open(img_file.strip()) as img:
        ax.imshow(img)
plt.show()