基于ResNet50特征提取与faiss索引的电子产品图像检索

# Download and unpack the DIGIX 2020 image-retrieval data (IPython/Colab shell commands).
# -nc skips a download whose archive already exists (no train_data.zip.1 duplicates on
# re-runs); unzip -o overwrites without prompting, which would otherwise hang a
# non-interactive notebook cell; -q keeps the per-file listing out of the output.
!wget -nc https://digix-algo-challenge.obs.cn-east-2.myhuaweicloud.com/2020/cv/6rKDTsB6sX8A1O2DA2IAq7TgHPdSPxJF/train_data.zip
!wget -nc https://digix-algo-challenge.obs.cn-east-2.myhuaweicloud.com/2020/cv/6rKDTsB6sX8A1O2DA2IAq7TgHPdSPxJF/test_data_A.zip
!unzip -o -q /content/train_data.zip
!unzip -o -q /content/test_data_A.zip

一、使用ResNet50作为图像特征提取器（在训练集上微调）

import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
from torch.optim import lr_scheduler
from PIL import Image
from glob import glob
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import time
import os
import copy
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import cv2
from tqdm import tqdm

# Root folder of the training images: one sub-directory per product class.
TRAIN_DATASET_PATH = '/content/train_data'
IMG_SIZE = (512, 512)  # square network input resolution
BATCH_SIZE = 20
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if _use_cuda else "cpu")

class ImageDataset(torch.utils.data.Dataset):
    """Labelled image dataset whose class is the name of each file's parent folder.

    Args:
        image_fns: list of image paths laid out as ``.../<class_name>/<file>``.
        label_dict: mapping from class-folder name to integer class id.
        data_transforms: callable applied to each RGB PIL image (e.g. a
            ``transforms.Compose`` pipeline).

    Items are ``(transformed_image, label)`` pairs.
    """

    def __init__(self, image_fns, label_dict, data_transforms):
        self.image_fns = image_fns
        self.label_dict = label_dict
        self.transforms = data_transforms

    def _label(self, index):
        """Return the integer label of ``self.image_fns[index]`` (its parent folder)."""
        # Must index this dataset's OWN file list: reading the module-level
        # ``image_fns`` (the unsplit list) pairs train/val indices with the wrong
        # files and leaks validation images into training.
        return self.label_dict[self.image_fns[index].split('/')[-2]]

    def __getitem__(self, index):
        label = self._label(index)
        image = Image.open(self.image_fns[index]).convert("RGB")
        image = self.transforms(image)
        return image, label

    def __len__(self):
        return len(self.image_fns)
        
image_fns = glob(os.path.join(TRAIN_DATASET_PATH, '*', '*.*'))
# Each image's class is the name of the directory that contains it.
label_names = [fp.split('/')[-2] for fp in image_fns]
# Distinct class names in sorted order -> contiguous integer ids 0..NUM_CLASSES-1.
unique_labels = sorted(set(label_names))
id_labels = {name: idx for idx, name in enumerate(unique_labels)}

NUM_CLASSES = len(unique_labels)
print("NUM_CLASSES:", NUM_CLASSES)

# NOTE(review): these mean/std values appear to be the commonly used CIFAR-10
# statistics rather than the ImageNet ones ResNet50 was pretrained with. They are
# kept as-is so training matches the preprocessing used later for indexing.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: light geometric/colour augmentation, then the same
# resize + centre-crop + normalisation as validation.
# transforms.Resize(int) replaces the deprecated (and in newer torchvision removed)
# transforms.Scale with identical semantics: the shorter side is resized to the value.
train_transform = transforms.Compose(
    [transforms.RandomRotation((-15, 15)),
     transforms.Resize(IMG_SIZE[0]),
     transforms.CenterCrop(IMG_SIZE[0]),
     transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
     transforms.ToTensor(),
     transforms.Normalize(_NORM_MEAN, _NORM_STD), ])

# Deterministic evaluation pipeline (no augmentation).
val_transform = transforms.Compose(
    [transforms.Resize(IMG_SIZE[0]),
     transforms.CenterCrop(IMG_SIZE[0]),
     transforms.ToTensor(),
     transforms.Normalize(_NORM_MEAN, _NORM_STD), ])


# Hold out 10% of the images for validation. A fixed random_state makes the split
# (and therefore the reported validation accuracy) reproducible across runs.
train_fns, val_fns = train_test_split(image_fns, test_size=0.1, shuffle=True,
                                      random_state=42)

train_dataset = ImageDataset(train_fns, id_labels, train_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                           shuffle=True)
val_dataset = ImageDataset(val_fns, id_labels, val_transform)
# Validation loss/accuracy are aggregated over the whole set, so shuffling buys
# nothing; a fixed order keeps evaluation deterministic.
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE,
                                         shuffle=False)

# Phase name -> loader, as consumed by train_model(). The misspelled name is kept
# because the training call refers to it; ``dataloaders_dict`` is a correctly
# spelled alias for new code.
datalaoders_dict = {'train': train_loader, 'val': val_loader}
dataloaders_dict = datalaoders_dict

# from pyretri.models.backbone.backbone_impl.reid_baseline import ft_net_own
# path_file = '/data/nextcloud/dbc2017/files/jupyter/model/resnet50-19c8e357.pth'
# model = ft_net_own(progress=True)

# ImageNet-pretrained ResNet50 with its 1000-way classifier swapped for one that
# predicts our NUM_CLASSES product classes.
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, NUM_CLASSES)

# Optionally resume from an earlier fine-tuning run. This has to happen AFTER the
# head swap: checkpoints written by train_model come from this model with the new
# head, so their fc layer is NUM_CLASSES-wide and cannot be loaded into the stock
# 1000-way head. map_location='cpu' lets a GPU-saved checkpoint load anywhere; the
# model is moved to `device` before training.
RESUME_CHECKPOINT = 'res50_512_best.pth'
if os.path.exists(RESUME_CHECKPOINT):
    model.load_state_dict(torch.load(RESUME_CHECKPOINT, map_location='cpu'))
print(model.eval())

def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25,
                save_path='models/resnet50_512_best.pth'):
    """Train ``model`` and keep the weights with the best validation accuracy.

    Each epoch runs a ``'train'`` phase (forward, backward, optimizer step) and then
    a ``'val'`` phase (forward only). ``scheduler.step()`` is called once per epoch,
    after the training phase. Whenever validation accuracy improves, the weights
    are deep-copied and, if ``save_path`` is set, written to disk.

    Args:
        model: network to train; it must already be on the global ``device``
            (only the batches are moved here).
        dataloaders: dict with ``'train'`` and ``'val'`` DataLoaders yielding
            ``(inputs, labels)`` batches; each loader must expose ``.dataset``.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of train+val epochs.
        save_path: file that receives the best ``state_dict``; its parent
            directory is created if missing. ``None`` or ``''`` disables saving.

    Returns:
        ``model`` with the best validation weights loaded.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    if save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            # torch.save does not create missing directories.
            os.makedirs(save_dir, exist_ok=True)

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train mode for 'train', eval mode for 'val'

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Build the autograd graph only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Assumes criterion averages over the batch (the default
                # reduction), so re-weight by batch size before summing.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()
            if is_train:
                scheduler.step()

            num_samples = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects / num_samples

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # Keep (and persist) the best-so-far validation weights.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                if save_path:
                    torch.save(best_model_wts, save_path)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # Hand back the best weights, not those from the last epoch.
    model.load_state_dict(best_model_wts)
    return model

# Place the network on the target device before the optimizer is created, so
# the optimizer is bound to the device-resident parameters.
model = model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
# Divide the learning rate by 10 every 7 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# train_model writes the best weights to disk itself whenever validation improves.
model_net = train_model(
    model,
    datalaoders_dict,
    criterion,
    optimizer,
    exp_lr_scheduler,
    num_epochs=20,
)

二、使用faiss建立特征索引并检索（Indexing）

import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from glob import glob
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import time
import os
import copy
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import cv2
from tqdm import tqdm

加载第一部分微调得到的模型权重

IMG_SIZE = (512, 512)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Must equal the width of the checkpoint's final fc layer, i.e. the number of
# classes it was fine-tuned on.
NUM_CLASSES = 3094
# No ImageNet download needed: load_state_dict (strict) overwrites every parameter.
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, NUM_CLASSES)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
model.load_state_dict(torch.load('/content/res50_512_best.pth', map_location=device))
model.to(device)
model.eval()  # inference only: fixed BatchNorm statistics, no dropout

注册前向钩子（forward hook），获取AvgPool层的输出作为图像特征向量

# Most recent output seen by each forward hook, keyed by the name given to
# get_activation().
feat_act = {}


def get_activation(name):
    """Return a forward hook that records a module's output under ``name``.

    The returned callable has the ``(module, input, output)`` signature used by
    ``register_forward_hook``; each call stores ``output.detach()`` in
    ``feat_act[name]``, replacing any earlier value.
    """
    def _store_output(_module, _inputs, out):
        feat_act[name] = out.detach()
    return _store_output
# Every forward pass of `model` now stores the avgpool output in feat_act['avgpool'];
# keep the handle so the hook can be removed later with .remove().
avgpool_hook_handle = model.avgpool.register_forward_hook(get_activation('avgpool'))

特征提取

class TestImageDataset(torch.utils.data.Dataset):
    """Unlabelled image dataset yielding ``(image_tensor, file_name)`` pairs.

    Used for the gallery and query sets, where each image is identified by its
    base file name rather than its full path.
    """

    def __init__(self, image_fns, data_transforms):
        # image_fns: list of image paths; data_transforms: callable applied to
        # each RGB PIL image.
        self.image_fns = image_fns
        self.transforms = data_transforms

    def __len__(self):
        return len(self.image_fns)

    def __getitem__(self, index):
        path = self.image_fns[index]
        rgb = Image.open(path).convert("RGB")
        return self.transforms(rgb), os.path.basename(path)

import numpy as np
TEST_BATCH_SZIE = 64  # (sic) batch size for feature extraction
d = 2048  # length of ResNet50's pooled feature vector (input size of its fc layer)
# NOTE(review): the download cells appear to unpack test_data_A under /content,
# while this path is relative to another working directory; confirm before running.
TEST_DATASET_PATH = '../dataset/test_data_A/'
# Sort the globbed paths so the extraction order, and hence the row order of any
# cached feature arrays, is the same on every run.
gallery_image_fns = sorted(glob(os.path.join(TEST_DATASET_PATH, 'gallery', '*.*')))
print(gallery_image_fns[:10])
query_image_fns = sorted(glob(os.path.join(TEST_DATASET_PATH, 'query', '*.*')))
if not gallery_image_fns or not query_image_fns:
    # Fail loudly instead of silently producing an empty index and submission.
    raise FileNotFoundError(
        'No gallery/query images found under {}'.format(TEST_DATASET_PATH))

# Same preprocessing as validation. Resize(int) is the non-deprecated equivalent of
# transforms.Scale(int).
test_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE[0]),
    transforms.CenterCrop(IMG_SIZE[0]),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Inference loaders must not shuffle. The extracted features are cached to .npy
# and reloaded later, and a random order would not match the filename lists.
gallery_dataset = TestImageDataset(gallery_image_fns, test_transform)
gallery_loader = torch.utils.data.DataLoader(gallery_dataset, batch_size=TEST_BATCH_SZIE,
                                             shuffle=False, num_workers=20)

query_dataset = TestImageDataset(query_image_fns, test_transform)
query_loader = torch.utils.data.DataLoader(query_dataset, batch_size=TEST_BATCH_SZIE,
                                           shuffle=False, num_workers=20)

# A gallery image is indexed only if the sum of its squared softmax probabilities
# exceeds this value. The sum is large when probability mass sits on few classes.
# NOTE(review): heuristic confidence filter; confirm the 0.06 cut-off on held-out data.
GALLERY_CONF_THRESHOLD = 0.06

gallery_vector = []
gallery_fns = []
query_vector = []
query_fns = []
with torch.no_grad():
    for img, fn in tqdm(gallery_loader):
        # The forward pass also fills feat_act['avgpool'] through the registered hook.
        probs = torch.softmax(model(img.to(device)), dim=1).cpu().numpy()
        vectors = feat_act['avgpool'].view(-1, d).cpu().numpy()
        confidence = np.sum(probs ** 2, axis=1)
        for v, n, c in zip(vectors, fn, confidence):
            if c > GALLERY_CONF_THRESHOLD:
                gallery_vector.append(v)
                gallery_fns.append(n)

    for img, fn in tqdm(query_loader):
        # Queries need only the hooked avgpool features; the logits are not used.
        model(img.to(device))
        vectors = feat_act['avgpool'].view(-1, d).cpu().numpy()
        query_vector.extend(vectors)
        query_fns.extend(fn)

# faiss wants float32 (N, d) arrays. reshape keeps the 2-D shape even when a list
# is empty (np.array([]) would be shape (0,)).
gallery_vector = np.asarray(gallery_vector, dtype=np.float32).reshape(-1, d)
query_vector = np.asarray(query_vector, dtype=np.float32).reshape(-1, d)

保存特征

# Cache the extracted features together with their file names. Row i of each
# array belongs to element i of the matching *_fns list, so both must be saved.
# (np.save returns None; its result must not be assigned back to the arrays.)
np.save('gallery_vector.npy', gallery_vector)
np.save('query_vector.npy', query_vector)
np.save('gallery_fns.npy', np.array(gallery_fns))
np.save('query_fns.npy', np.array(query_fns))

加载保存好的特征

# Reload cached features only when the filenames were cached as well. Otherwise the
# freshly extracted in-memory arrays are kept, which avoids a FileNotFoundError on a
# first run and avoids pairing stale vectors with the current filename lists.
_feature_cache = ('gallery_vector.npy', 'query_vector.npy',
                  'gallery_fns.npy', 'query_fns.npy')
if all(os.path.exists(p) for p in _feature_cache):
    gallery_vector = np.load('gallery_vector.npy')
    query_vector = np.load('query_vector.npy')
    gallery_fns = np.load('gallery_fns.npy').tolist()
    query_fns = np.load('query_fns.npy').tolist()

向量匹配

import faiss                   # make faiss available

# faiss requires C-contiguous float32 input.
gallery_vector = np.ascontiguousarray(gallery_vector, dtype=np.float32)
query_vector = np.ascontiguousarray(query_vector, dtype=np.float32)

index = faiss.IndexFlatL2(d)   # exact (brute-force) L2 nearest-neighbour index
index.add(gallery_vector)
if index.ntotal == 0:
    raise ValueError('The gallery index is empty; there is nothing to search.')

# Never request more neighbours than the index holds. faiss pads missing results
# with index -1, which Python indexing would read as the last gallery file.
k = min(10, index.ntotal)

D, I = index.search(query_vector, k)     # D: distances, I: gallery row indices

lines = []
display_num = 10  # how many query / top-1 gallery pairs to plot
for indices, q_fn in zip(I, query_fns):
    # faiss fills missing neighbours with -1. Skip them so they are not read as
    # gallery_fns[-1] (the last gallery file).
    hits = [gallery_fns[i] for i in indices if i >= 0]

    # Visualization: plot the query next to its nearest gallery image.
    if display_num > 0 and hits:
        q_img = cv2.imread(os.path.join(TEST_DATASET_PATH, 'query', q_fn))
        g_img = cv2.imread(os.path.join(TEST_DATASET_PATH, 'gallery', hits[0]))
        # cv2.imread returns None for missing or unreadable files.
        if q_img is not None and g_img is not None:
            f = plt.figure()
            f.add_subplot(1, 2, 1)
            plt.imshow(q_img[..., ::-1].copy())  # OpenCV BGR -> RGB for matplotlib
            f.add_subplot(1, 2, 2)
            plt.imshow(g_img[..., ::-1].copy())
            plt.show()
        display_num -= 1

    # One submission line per query: "<query>,{<g1>,<g2>,...}"
    lines.append(q_fn + ',{' + ','.join(hits) + '}\n')

result = ''.join(lines)

（原文此处为检索结果可视化截图：左为查询图像，右为最相似的图库图像）

保存到CSV文件

# Mode 'w' truncates any previous submission, so no manual truncate(0) is needed.
# The with-block closes the file. Use UTF-8 explicitly rather than the locale default.
with open('submission.csv', 'w', encoding='utf-8') as f:
    f.write(result)

猜你喜欢

转载自blog.csdn.net/qq_41375318/article/details/108466311
今日推荐