ベクトル検索: 画像検索システムを構築するための ResNet 事前トレーニング モデルに基づく

目次

1 プロジェクトの背景の紹介

2 キーテクノロジーの紹介

2.1 レスネットネットワーク

2.2 Milvus ベクター データベース

3 システムコードの実装

3.1 動作環境の構築

3.2 データセットのダウンロード

3.3 事前トレーニングモデルのダウンロード

3.4 コードの実装

3.4.1 Milvus テーブルとインデックスの作成

 3.4.2 Resnet エンコーディング ネットワークの構築

3.4.3 データのベクトル化とロード

3.4.4 検索ウェブの構築

4 まとめ


1 プロジェクトの背景の紹介

画像による画像検索は、画像をアップロードすることで、他の画像やそれに関連する関連情報を検索して見つけるベクトル検索技術です。画像検索テクノロジーは、より直感的で効率的な情報検索方法を提供します。このテクノロジーには幅広い応用シナリオと価値があり、商品の検索とショッピング、動植物の識別、食品の識別、知識の検索などの分野でよく使用されます。画像検索の技術的なポイントは以下のとおりです。

  • 画像データをベクトルエンコードする方法
  • 大量のベクターデータを保存する方法
  • 大量のベクター データを迅速に取得する方法

Milvus ベクトル データベースと組み合わせた Resnet 事前学習モデルに基づいて、このプロジェクトは果物データセットに地図検索システムを実装します。読者はデータセットを他の分野に拡張し、自分のビジネスに合った地図検索システムを構築できます。

2 キーテクノロジーの紹介

2.1 レスネットネットワーク

ResNet (Residual Network の正式名) は、深層学習の分野で非常に重要な畳み込みニューラル ネットワーク (CNN) アーキテクチャの 1 つです。2015 年に Kaiming He 氏らによって提案され、ImageNet の画像分類コンテストで目覚ましい成果を上げ、当時は分類タスク、ターゲット検出、画像セグメンテーションの分野で 1 位を獲得しました。ResNet の革新的な点は、残留接続 (残留接続) の導入であり、これにより、ネットワークがトレーニング中にディープ ネットワークをより簡単にトレーニングできるようになります。

従来のニューラル ネットワークでは、ネットワーク層の数が増加するにつれて、パフォーマンスが飽和したり、低下したりする可能性があります。これは、勾配の消失や爆発などの問題によりトレーニングが困難になる可能性があるためです。ResNet は、残差ブロックを導入することでこの問題を解決します。各残差ブロックはメインの畳み込み層で構成され、その出力と入力の差は「残差」と呼ばれ、残差が加算されて最終出力が得られます。このようなアーキテクチャにより、ネットワークが非常に深くなった場合でも、情報がネットワーク内でより容易に拡散できるようになります。

ResNet の古典的なネットワーク構造は、ResNet-18、ResNet-34、ResNet-50、ResNet-101、ResNet-152 です。このうち、ResNet-18 と ResNet-34 は同じ基本構造を持ち、比較的浅いネットワークに属します。後の 3 つはより深いネットワークに属し、その中では RestNet50 が最も一般的に使用されます。

 ResNet には次のような利点があります。

  • より深いネットワークのトレーニング: 残留接続の導入により、トレーニング時により容易に収束する傾向がある非常に深いネットワークの構築が可能になります。
  • 勾配の消失と爆発を回避する: 残留接続により、ネットワーク内での勾配の拡散が促進され、勾配の消失と爆発の問題が軽減されます。
  • 特徴学習の向上: 残差ブロックにより、ネットワークは残差を学習できます。つまり、キャプチャしやすいきめの細かい特徴を学習できます。

ResNet の詳細な紹介: ResNet

2.2 Milvus ベクター データベース

Milvus は、高可用性、高性能、簡単な拡張を備えたクラウドネイティブのベクトル データベースであり、大量のベクトル データをリアルタイムに呼び出すために使用されます。

Milvus は FAISS、Annoy、HNSW およびその他のベクトル検索ライブラリに基づいて構築されており、その中心は高密度ベクトル類似性検索の問題を解決することです。ベクトル検索ライブラリに基づいて、Milvus はデータ分割、データ永続化、増分データ取り込み、スカラー ベクトル ハイブリッド クエリ、タイム トラベルなどの機能をサポートし、ベクトル検索シナリオのアプリケーション要件を満たすためにベクトル検索のパフォーマンスを大幅に最適化します。 。一般に、最高の可用性と回復力を得るために、Kubernetes を使用して Milvus をデプロイすることをお勧めします。

Milvus は共有ストレージ アーキテクチャを採用しており、ストレージとコンピューティングは完全に分離されており、コンピューティング ノードは水平方向の拡張をサポートしています。アーキテクチャの観点から見ると、Milvus はデータ フローと制御フローの分離に従い、全体としてアクセス層、コーディネーター サービス、ワーカー ノード、ストレージ層の 4 つの層に分かれています。各レベルは互いに独立しており、独立した拡張と災害復旧が行われます。

 Milvus ベクトル データベースは、ユーザーが大量の非構造化データ (画像/ビデオ/音声/テキスト) の検索を簡単に処理できるようにします。シングルノードの Milvus は数秒以内に 10 億レベルのベクトル検索を完了でき、分散アーキテクチャはユーザーの水平拡張要件にも対応できます。

milvusの特徴をまとめると以下のようになります。

  • 高性能: 高性能で、大規模なデータセットに対してベクトル類似性検索を実行できます。
  • 高可用性と信頼性: Milvus はクラウド上での拡張をサポートしており、その災害復旧機能により高いサービス可用性を確保できます。
  • ハイブリッド クエリ: Milvus は、ハイブリッド クエリを実現するために、ベクトル類似性検索中のスカラー フィールド フィルタリングをサポートします。
  • 開発者に優しい: 多言語とマルチツールをサポートする Milvus エコシステム。

ミルバスの詳細:ミルバス

3 システムコードの実装

3.1 動作環境の構築

conda 環境の準備については、annoconda を参照してください。

git clone https://gitcode.net/ai-medical/image_image_search.git
cd image_image_search

pip install -r requirements.txt

3.2 データセットのダウンロード

ダウンロードリンク:

最初のパケット: package01

2 番目のパッケージ: package01

次の図に示すように、データセット ディレクトリの下には 10 個のフォルダーがあり、フォルダー名は果物の種類であり、各フォルダーにはこの種類の果物の数百から数千の写真が含まれています。

 例として apple フォルダーを取り上げます。内容は次のとおりです。

ダウンロード後、解凍してD:/dataset/fruitディレクトリに保存すると以下のように表示されます

# ll fruit/
总用量 508
drwxr-xr-x 2 root root 36864 8月   2 16:35 apple
drwxr-xr-x 2 root root 24576 8月   2 16:36 apricot
drwxr-xr-x 2 root root 40960 8月   2 16:36 banana
drwxr-xr-x 2 root root 20480 8月   2 16:36 blueberry
drwxr-xr-x 2 root root 45056 8月   2 16:37 cherry
drwxr-xr-x 2 root root 12288 8月   2 16:37 citrus
drwxr-xr-x 2 root root 49152 8月   2 16:38 grape
drwxr-xr-x 2 root root 16384 8月   2 16:38 lemon
drwxr-xr-x 2 root root 36864 8月   2 16:39 litchi
drwxr-xr-x 2 root root 49152 8月   2 16:39 mango

3.3 事前トレーニングモデルのダウンロード

 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',

resnet50 の事前トレーニング モデル: resnet50 をダウンロードし、 D:/models ディレクトリに保存します。

3.4 コードの実装

3.4.1 Milvus テーブルとインデックスの作成

from pymilvus import connections, db

conn = connections.connect(host="192.168.1.156", port=19530)
database = db.create_database("image_vector_db")

db.using_database("image_vector_db")
print(db.list_database())

コレクションを作成する

from pymilvus import CollectionSchema, FieldSchema, DataType
from pymilvus import Collection, db, connections


conn = connections.connect(host="192.168.1.156", port=19530)
db.using_database("image_vector_db")

m_id = FieldSchema(name="m_id", dtype=DataType.INT64, is_primary=True,)
embeding = FieldSchema(name="embeding", dtype=DataType.FLOAT_VECTOR, dim=2048,)
path = FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=256,)
schema = CollectionSchema(
  fields=[m_id, embeding, path],
  description="image to image embeding search",
  enable_dynamic_field=True
)

collection_name = "fruit_vector"
collection = Collection(name=collection_name, schema=schema, using='default', shards_num=2)

インデックスを作成する

from pymilvus import Collection, utility, connections, db

conn = connections.connect(host="192.168.1.156", port=19530)
db.using_database("image_vector_db")

index_params = {
  "metric_type": "L2",
  "index_type": "IVF_FLAT",
  "params": {"nlist": 1024}
}

collection = Collection("fruit_vector")
collection.create_index(
  field_name="embeding",
  index_params=index_params
)

utility.index_building_progress("fruit_vector")

 3.4.2 Resnet エンコーディング ネットワークの構築

Resnet 事前トレーニング モデルを読み込み、完全に接続された層を削除して、Resnet エンコード出力フィーチャの次元が 2048 になるようにします。

from torchvision.models import resnet50
import torch
from torchvision import transforms
from torch import nn


class ResnetEmbeding:
    pretrained_model = 'D:/models/resnet50-0676ba61.pth'

    def __init__(self):
        self.model = resnet50()
        self.model.load_state_dict(torch.load(self.pretrained_model))

        # delete fc layer
        self.model.fc = nn.Sequential()
        self.transform = transforms.Compose([transforms.Resize((224, 224)),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                                  std=[0.26862954, 0.26130258, 0.27577711])])

    def embeding(self, image):
        trans_image = self.transform(image)
        trans_image = trans_image.unsqueeze_(0)
        return self.model(trans_image)


resnet_embeding = ResnetEmbeding()

3.4.3 データのベクトル化とロード

from resnet_embeding import resnet_embeding
from milvus_operator import restnet_image, MilvusOperator
from PIL import Image, ImageSequence
import os


def update_image_vector(data_path, operator: MilvusOperator):
    idxs, embedings, paths = [], [], []

    total_count = 0
    for dir_name in os.listdir(data_path):
        sub_dir = os.path.join(data_path, dir_name)
        for file in os.listdir(sub_dir):

            image = Image.open(os.path.join(sub_dir, file)).convert('RGB')
            embeding = resnet_embeding.embeding(image)

            idxs.append(total_count)
            embedings.append(embeding[0].detach().numpy().tolist())
            paths.append(os.path.join(sub_dir, file))
            total_count += 1

            if total_count % 50 == 0:
                data = [idxs, embedings, paths]
                operator.insert_data(data)

                print(f'success insert {operator.coll_name} items:{len(idxs)}')
                idxs, embedings, paths = [], [], []

        if len(idxs):
            data = [idxs, embedings, paths]
            operator.insert_data(data)
            print(f'success insert {operator.coll_name} items:{len(idxs)}')

    print(f'finish update {operator.coll_name} items: {total_count}')


if __name__ == '__main__':
    data_dir = 'D:/dataset/fruit'
    update_image_vector(data_dir, resnet_image)

3.4.4 検索ウェブの構築

import gradio as gr
import torch
import numpy as np
import argparse
from net_helper import net_helper
from PIL import Image
from restnet_embeding import restnet_embeding
from milvus_operator import resnet_image


def image_search(image):
    if image is None:
        return None

    image = image.convert("RGB")

    # resnet编码
    imput_embeding = resnet_embeding.embeding(image)
    imput_embeding = imput_embeding[0].detach().cpu().numpy()

    results = restnet_image.search_data(imput_embeding)
    pil_images = [Image.open(result['path']) for result in results]
    return pil_images


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true",
                        default=False, help="share gradio app")
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    app = gr.Blocks(theme='default', title="image",
                    css=".gradio-container, .gradio-container button {background-color: #009FCC} "
                        "footer {visibility: hidden}")
    with app:
        with gr.Tabs():
            with gr.TabItem("image search"):
                with gr.Row():
                    with gr.Column():
                        image = gr.inputs.Image(type="pil", source='upload')
                        btn = gr.Button(label="search")

                    with gr.Column():
                        with gr.Row():
                            output_images = [gr.outputs.Image(type="pil", label=None) for _ in range(16)]

                btn.click(image_search, inputs=[image], outputs=output_images, show_progress=True)

    ip_addr = net_helper.get_host_ip()
    app.queue(concurrency_count=3).launch(show_api=False, share=True, server_name=ip_addr, server_port=9099)

4 まとめ

Resnet 事前学習モデルと milvus ベクトル データベースの 2 つの主要なテクノロジーに基づいて、このプロジェクトは画像検索のための画像検索システムを構築します。構築プロセス中に Resnet ネットワーク モデルが変換され、全結合層が削除され、 Restnet エンコード後 各画像の出力ベクトル次元は 2048 で、これは milvus ベクトル データベースに保存されます。画像検索の効率を確保するために、スクリプトを通じてベクトル インデックスが milvus ベクトル データベースに構築されます。このプロジェクトは参照として使用でき、同様の画像検索プロジェクトの実際の開発に直接使用できます。

プロジェクトの完全なコード アドレス: code

おすすめ

転載: blog.csdn.net/lsb2002/article/details/132456845