A simple image retrieval system using VGG as the backbone

        Image retrieval (content-based image retrieval, CBIR for short) searches for images by their semantic content, using features such as color, texture, layout, and CNN-based high-level semantics. The task divides into instance retrieval and category retrieval. In instance retrieval, given a query image of a particular object, scene, or building, the goal is to find images containing that same object/scene/building taken from different viewpoints, under different illumination, or with occlusion; in category retrieval, the goal is to find images of the same category. The needs here are better served by instance retrieval.

         CBIR research formally began in the early 1990s, when researchers indexed images by visual features such as texture and color; many excellent algorithms and retrieval systems were proposed in this period, which are not enumerated here. Figure 2.1 shows the milestones in instance retrieval from 2000 onward, highlighting when the SIFT-based and CNN-based feature methods were proposed. The year 2000 can be regarded as the end of most traditional methods, marked by the review of Smeulders et al., "Content-Based Image Retrieval at the End of the Early Years". Three years later (2003), the bag-of-words (BoW) model entered the image retrieval community, and in 2004 it was applied to image classification in combination with SIFT features. Over the following decade the community witnessed the strength of the BoW model, which brought various improvements to image retrieval. In 2012, Krizhevsky et al. used AlexNet to achieve the world's best recognition accuracy at ILSVRC 2012, and the research focus has since shifted to deep-learning-based methods, especially convolutional neural networks (CNNs).

         Here is the code for a simple image retrieval system:

        cbir.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time        :2022/7/6 11:26
# @Author      :weiz
# @ProjectName :CBIR-master
# @File        :cbir.py
# @Description :content-based image retrieval
# Copyright (C) 2021-2025 Jiangxi Institute Of Intelligent Industry Technology Innovation
import os
import cv2
import pickle
from scipy import spatial

from feature_extraction import *  # provides VGGNet (and torch / np through its imports)
from vgg import *  # vgg.py is not shown in this post


class CBIR(object):
    database_feature_path = "./database_feature"

    def read_images(self, image_folder_path):
        self.image_info_list = []
        for root, _, image_name_list in os.walk(image_folder_path, topdown=False):
            label_name = os.path.basename(root)  # the folder name serves as the class label
            for image_name in image_name_list:
                if image_name.split('.')[-1].lower() in ["png", "jpg", "jpeg"]:
                    image_path = os.path.join(root, image_name)
                    self.image_info_list.append([image_path, label_name])
                else:
                    print("{} is not a picture".format(os.path.join(root, image_name)))

    def load_database(self, image_folder_path=None, is_save=False):
        self.database = []
        database_feature_path = CBIR.database_feature_path + '_' + self.feature_extraction_object.get_name()
        if os.path.exists(database_feature_path) and not image_folder_path:
            with open(database_feature_path, "rb") as f:  # load the cached features
                self.database = pickle.load(f)
        else:
            if image_folder_path:
                self.read_images(image_folder_path)
            for image_info in self.image_info_list:  # [[image path, label], ...]
                image = cv2.imread(image_info[0])
                feature = self.feature_extraction_object(image)
                self.database.append({
                    'image_path': image_info[0],
                    'label': image_info[1],
                    'feature': feature
                })
            if is_save:
                with open(database_feature_path, "wb") as f:  # cache the features to disk
                    pickle.dump(self.database, f)

        return self.database

    def __init__(self, feature_extraction_object, image_folder_path):
        self.feature_extraction_object = feature_extraction_object
        self.read_images(image_folder_path)
        # print(self.image_info_list)
        self.load_database(is_save=True)
        # print(self.database)

    def query(self, image, query_depth=3, is_show=False):
        feature_1 = self.feature_extraction_object(image)

        query = []
        for idx, value_2 in enumerate(self.database):
            image_path_2, label_2, feature_2 = value_2["image_path"], value_2["label"], value_2["feature"]
            query.append({
                "distance": self.distance(feature_1, feature_2, "d3"),
                "label": label_2,
                "image_path": image_path_2
            })
        # keep only the top query_depth candidates when there are enough of them
        query = sorted(query, key=lambda x: x["distance"])
        if query_depth and query_depth <= len(query):
            query = query[:query_depth]

        if is_show:
            cv2.imshow("src", image)
            for idx, q in enumerate(query):
                img = cv2.imread(q["image_path"])
                cv2.imshow("top {}".format(idx + 1), img)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        return query

    def evaluate(self, database=None, query_depth=3):
        if not database:
            database = self.database

        label_list = []
        for tmp in database:
            label_list.append(tmp["label"])
        label_list = list(set(label_list))
        results = {c: [] for c in label_list}
        for value_1 in database:
            image_path_1, label_1, feature_1 = value_1["image_path"], value_1["label"], value_1["feature"]
            query = []
            for idx, value_2 in enumerate(database):
                image_path_2, label_2, feature_2 = value_2["image_path"], value_2["label"], value_2["feature"]
                if image_path_1 == image_path_2:  # the query image itself is excluded from the evaluation
                    continue
                query.append({
                    "distance": self.distance(feature_1, feature_2, "d3"),
                    "label": label_2
                })
            # keep only the top query_depth candidates when there are enough of them
            query = sorted(query, key=lambda x: x["distance"])
            # print(query)
            if query_depth and query_depth <= len(query):
                query = query[:query_depth]

            # count how many of the retrieved images hit the query's label
            hit = 0
            precision = []
            for idx, q in enumerate(query):
                if q["label"] == label_1:
                    hit += 1
                    precision.append((hit / (idx + 1.)))
                # else:
                #     print("query image, path:{}  label:{}".format(value_1["image_path"], value_1["label"]))
                #     print("retrieved image, distance:{}  label:{}".format(q["distance"], q["label"]))
            if hit == 0:
                results[label_1].append(0.)
            else:
                results[label_1].append(np.mean(precision))

        mAPs = []
        for label, Ps in results.items():
            AP = np.mean(Ps)
            print("Class {}, AP {}".format(label, AP))
            mAPs.append(AP)
        print("MAP", np.mean(mAPs))
        return results

    def distance(self, value_1, value_2, d_type="d1"):
        assert value_1.shape == value_2.shape
        if d_type == 'd1':  # Manhattan (L1) distance
            return np.sum(np.absolute(value_1 - value_2))
        elif d_type == 'd2':  # Euclidean (L2) distance
            return np.sqrt(np.sum((value_1 - value_2) ** 2))
        elif d_type == 'd3':  # cosine distance (1 - cosine similarity); smaller means more similar
            return spatial.distance.cosine(value_1, value_2)
        else:
            raise ValueError("unknown distance type: {}".format(d_type))


test_image_path = "./database/cup2"  # ./test_image  ./database/glasses
if __name__ == "__main__":
    vgg_model = VGGNet(requires_grad=False, net_type="vgg16", show_params=False)
    vgg_model.eval()

    if torch.cuda.is_available():
        vgg_model = vgg_model.cuda()

    cbir = CBIR(vgg_model, "./database")
    # cbir.evaluate(query_depth=3)
    test_image_list = os.listdir(test_image_path)
    for image_name in test_image_list[:3]:
        print(image_name)
        image_path = os.path.join(test_image_path, image_name)
        test_image = cv2.imread(image_path)
        print(cbir.query(test_image, is_show=True, query_depth=2))
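
        The evaluation entry point is commented out in the main block above. As a minimal sketch, it could be invoked as follows (vgg_model is the model built in the main block; query_depth controls how many top-ranked candidates are scored for each query):

# Leave-one-out evaluation over the whole database: each image is queried
# against all of the others; per-class AP and the overall mAP are printed.
cbir = CBIR(vgg_model, "./database")
cbir.evaluate(query_depth=3)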

        feature_extraction.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time        :2022/7/11 11:20
# @Author      :weiz
# @ProjectName :CBIR-master
# @File        :feature_extraction.py
# @Description :feature extraction with a VGG backbone
# Copyright (C) 2021-2025 Jiangxi Institute Of Intelligent Industry Technology Innovation
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models.vgg import VGG
import numpy as np


class VGGNet(VGG):
    cfg = {
        'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
        'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512,'M'],
    }

    ranges = {
        'vgg11': ((0, 3), (3, 6), (6, 11), (11, 16), (16, 21)),
        'vgg13': ((0, 5), (5, 10), (10, 15), (15, 20), (20, 25)),
        'vgg16': ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)),
        'vgg19': ((0, 5), (5, 10), (10, 19), (19, 28), (28, 37))
    }

    means = np.array([103.939, 116.779, 123.68]) / 255.  # mean of three channels in the order of BGR

    def net_layers(self, net_type, batch_norm=False):
        """
        Build the convolutional layers from the configuration table
        """
        layers = []
        in_channels = 3
        for value in VGGNet.cfg[net_type]:
            if 'M' == value:
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, value, kernel_size=3, padding=1)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(value), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = value
        return nn.Sequential(*layers)

    def __init__(self, pretrained=True, net_type='vgg16', requires_grad=False, remove_fc=False, show_params=False):
        """
        Initialize the network and optionally load pretrained torchvision weights
        """
        super().__init__(self.net_layers(net_type))
        self.ranges = VGGNet.ranges[net_type]
        self.fc_ranges = ((0, 2), (2, 5), (5, 7))
        self.net_type = net_type
        # print(self.features.state_dict())

        if pretrained:
            # load the matching torchvision pretrained weights for the chosen VGG variant
            self.load_state_dict(getattr(models, net_type)(pretrained=True).state_dict())

        if not requires_grad:
            for param in self.parameters():  # freeze all weights
                param.requires_grad = False

        if remove_fc:  # the fully connected layers are not needed for feature extraction; drop them
            self.classifier = None
            self.avgpool = None

        if show_params:
            for name, param in self.named_parameters():
                print(name, param.size())

    def forward(self, image):
        """
        image must be a BGR array, as returned by cv2.imread
        """
        # image = image[:, :, ::-1]
        image = np.transpose(image, (2, 0, 1)) / 255.  # HWC -> CHW, scaled to [0, 1]
        image[0] -= VGGNet.means[0]  # subtract the B channel mean
        image[1] -= VGGNet.means[1]  # subtract the G channel mean
        image[2] -= VGGNet.means[2]  # subtract the R channel mean
        image = np.expand_dims(image, axis=0)

        inputs = torch.from_numpy(image).float()
        if torch.cuda.is_available():
            inputs = inputs.cuda()

        # print(inputs.shape)
        # print(self.features)
        x = self.features(inputs)
        avg_pool = torch.nn.AvgPool2d((x.size(-2), x.size(-1)), stride=(x.size(-2), x.size(-1)),
                                      padding=0, ceil_mode=False, count_include_pad=True)
        feature = avg_pool(x)                  # avg.size = N * 512 * 1 * 1
        feature = feature.view(feature.size(0), -1)    # avg.size = N * 512

        feature = np.sum(feature.data.cpu().numpy(), axis=0)  # collapse the batch dimension
        feature /= np.sum(feature)  # L1-normalize the descriptor

        return feature

    def get_name(self):
        return self.net_type

    def __call__(self, x):
        return self.forward(x)
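
        As a quick sanity check, here is a minimal sketch of extracting a single descriptor directly with VGGNet (the image path is hypothetical; any BGR image loaded with cv2.imread works):

import cv2
import torch
from feature_extraction import VGGNet

model = VGGNet(requires_grad=False, net_type="vgg16")
model.eval()
if torch.cuda.is_available():  # forward() moves inputs to the GPU when CUDA is available
    model = model.cuda()
image = cv2.imread("./database/cup2/0001.jpg")  # hypothetical path
feature = model(image)  # 512-dimensional, L1-normalized numpy vector
print(feature.shape)    # (512,)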

        The database directory is organized with one sub-folder per class; the folder name is used as the image label (see read_images above).
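
        A sketch of the assumed layout (folder and file names are illustrative; only the one-sub-folder-per-class structure matters):

database/
├── cup2/
│   ├── 0001.jpg
│   └── 0002.jpg
└── glasses/
    ├── 0001.jpg
    └── 0002.jpg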

Origin blog.csdn.net/qq_31112205/article/details/126180926