目次
近年、Self-Attention構造に基づくモデル、特にTransformerモデルの開発により、自然言語処理モデルの開発が大きく推進されています。Transformer の計算効率とスケーラビリティにより、100B を超えるパラメータを使用して前例のない規模のモデルをトレーニングすることができました。
ViT は、自然言語処理とコンピューター ビジョンを融合したものです。畳み込み演算に依存しなくても、画像分類タスクで良好な結果を達成できます。
モデル構造
ViT モデルの主な構造は、Transformer モデルの Encoder 部分 (Normalization の位置が標準の Transformer と異なるなど、一部の構造の順序が調整されています) とその構造図 [1] に基づいています。モデルの特徴 ViT モデルは主に
画像分類の分野で使用されます
。したがって、従来の Transformer と比較して、そのモデル構造は次のような特徴があります。
データセットの元の画像が複数のパッチに分割された後、2 次元パッチ (チャネルに関係なく) が 1 次元ベクトルに変換され、カテゴリ ベクトルと位置ベクトルをモデル入力として加えます。
モデル本体の Block 構造は Transformer の Encoder 構造をベースにしていますが、Normalization の位置が調整されており、その中で最も重要な構造はやはり Multi-head Attender 構造です。
モデルは、ブロックの積み重ね後に全結合層に接続され、カテゴリ ベクトルの出力を入力として受け入れ、分類に使用されます。通常、最後に完全に接続されたレイヤーをヘッドと呼び、トランスフォーマー エンコーダー部分がバックボーンとなります。
以下では、コード例を通じて、ViT に基づく ImageNet 分類タスクの実装について詳しく説明します。
MindSpore に興味がある場合は、 Shengsi MindSpore コミュニティをフォローしてください。
1. 環境整備
1.モデルアーツ公式サイトにアクセス
クラウド プラットフォームは、ユーザーがモデルを迅速に作成およびデプロイし、フルサイクル AI ワークフローを管理するのに役立ちます。次のクラウド プラットフォームを選択して Shengsi MindSpore の使用を開始し、インストール コマンドを取得し、MindSpore2.0.0-alpha バージョンをインストールして、ModelArts 公式 Web サイトに入ります。Shengsiチュートリアル_
以下の CodeLab を選択して、すぐに体験してください
環境が構築されるまで待ちます
2. CodeLab を使用して Notebook インスタンスを体験する
NoteBook サンプル コード、Vision Transformer 画像分類を.ipynb
サンプル コードとしてダウンロード
.ipynb
ファイルをアップロードするには、ModelArts Upload Files を選択します。
カーネル環境を選択します
GPU環境に切り替え、初回期間限定無料に切り替える
Shengsi MindSpore 公式ウェブサイトに入り、上のインストールをクリックします
インストールコマンドを取得する
ノートブックに戻り、コードの最初のブロックの前にコマンドを追加します。
conda update -n base -c defaults conda
MindSpore 2.0 GPU バージョンをインストールする
conda install mindspore=2.0.0a0 -c mindspore -c conda-forge
マインドビジョンをインストールする
pip install mindvision
インストールダウンロードダウンロード
pip install download
2. 環境準備とデータ読み込み
実験を開始する前に、Python 環境と MindSpore がローカルにインストールされていることを確認してください。
まず最初に、このケースのデータ セットをダウンロードする必要があります。完全な ImageNet データ セットは http://image-net.org からダウンロードできます。このケースで使用されるデータ セットは、ImageNet から選択されたサブセットです。
最初のコードを実行すると、自動的にダウンロードされて解凍されます。データセットのパスが次の構造になっていることを確認してください。
.dataset/
├── ILSVRC2012_devkit_t12.tar.gz
├── train/
├── infer/
└── val/
from download import download
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"
path = download(dataset_url, path, kind="zip", replace=True)
import os
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms
data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)
trans_train = [
transforms.RandomCropDecodeResize(size=224,
scale=(0.08, 1.0),
ratio=(0.75, 1.333)),
transforms.RandomHorizontalFlip(prob=0.5),
transforms.Normalize(mean=mean, std=std),
transforms.HWC2CHW()
]
dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)
3. モデル分析
以下では、ViT モデルの内部構造をコードを通じて詳細に分析します。
トランスの基礎
Transformer モデルは 2017 年の記事に由来します [2]。この記事で提案したアテンション機構に基づくエンコーダ・デコーダ構造は、自然言語処理の分野で大きな成功を収めています。モデルの構造を次の図に示します。
その主な構造は複数の Encoder および Decoder モジュールで構成されており、Encoder および Decoder の詳細な構造は次の図 [2] に示されています。
エンコーダとデコーダは、マルチヘッド アテンション層、フィード
フォワード層、正規化層、さらには残留接続 (
図の「追加」) などの多くの構造で構成されます。ただし、最も重要な構造はマルチヘッド アテンション
構造です。これはセルフ アテンション メカニズムに基づいており、複数のセルフ アテンションを並列に構成したものです。したがって、セルフアテンションを理解することで、Transformer の核心がわかります。
アテンションモジュール
from mindspore import nn, ops
class Attention(nn.Cell):
def __init__(self,
dim: int,
num_heads: int = 8,
keep_prob: float = 1.0,
attention_keep_prob: float = 1.0):
super(Attention, self).__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = ms.Tensor(head_dim ** -0.5)
self.qkv = nn.Dense(dim, dim * 3)
self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
self.out = nn.Dense(dim, dim)
self.out_drop = nn.Dropout(p=1.0-keep_prob)
self.attn_matmul_v = ops.BatchMatMul()
self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
self.softmax = nn.Softmax(axis=-1)
def construct(self, x):
"""Attention construct."""
b, n, c = x.shape
qkv = self.qkv(x)
qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
q, k, v = ops.unstack(qkv, axis=0)
attn = self.q_matmul_k(q, k)
attn = ops.mul(attn, self.scale)
attn = self.softmax(attn)
attn = self.attn_drop(attn)
out = self.attn_matmul_v(attn, v)
out = ops.transpose(out, (0, 2, 1, 3))
out = ops.reshape(out, (b, n, c))
out = self.out(out)
out = self.out_drop(out)
return out
トランスエンコーダ
Self-Attention 構造を理解した後、
Feed Forward、Residual Connection、その他の構造を結合することで Transformer の基本構造を形成できます。次のコードは、Feed Forward と Residual
Connection 構造を実装します。
from typing import Optional, Dict
class FeedForward(nn.Cell):
def __init__(self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
activation: nn.Cell = nn.GELU,
keep_prob: float = 1.0):
super(FeedForward, self).__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.dense1 = nn.Dense(in_features, hidden_features)
self.activation = activation()
self.dense2 = nn.Dense(hidden_features, out_features)
self.dropout = nn.Dropout(p=1.0-keep_prob)
def construct(self, x):
"""Feed Forward construct."""
x = self.dense1(x)
x = self.activation(x)
x = self.dropout(x)
x = self.dense2(x)
x = self.dropout(x)
return x
class ResidualCell(nn.Cell):
def __init__(self, cell):
super(ResidualCell, self).__init__()
self.cell = cell
def construct(self, x):
"""ResidualCell construct."""
return self.cell(x) + x
次に、セルフ アテンションを使用して ViT モデルの TransformerEncoder 部分を構築します。これは、次の図 [1] に示すように、Transformer エンコーダ パートを構築するのと同様です。
ビットエンコーダー
ViT モデルの基本構造は、標準の Transformer とは異なります。主に、Normalization の位置が Self-tention と Feed
Forward の前に配置され、Residual Connection、Feed
Forward、Normalization などの他の構造は Transformer と同様に設計されています。Transformer 構造の図から、複数のサブエンコーダーの積層によってモデル エンコーダーの構築が完了することがわかります。ViT モデルでも、この考え方は引き続き踏襲されています。ハイパーパラメーター num_layers を設定することで、積層されるレイヤーの数が決まります。と判断できる。
Residual Connection と Normalization の構造により、モデルの強力なスケーラビリティが保証され (Residual Connection の
役割である深い処理後に情報が劣化しないようにするため)、Normalization と Dropout を適用することで、モデルの一般化能力を強化できます。
モデル。Transformer の構造は、次のソース コードから明確にわかります。TransformerEncoder 構造と多層パーセプトロン (MLP) を組み合わせることで、ViT モデルのバックボーン部分が構成されます。
class TransformerEncoder(nn.Cell):
def __init__(self,
dim: int,
num_layers: int,
num_heads: int,
mlp_dim: int,
keep_prob: float = 1.,
attention_keep_prob: float = 1.0,
drop_path_keep_prob: float = 1.0,
activation: nn.Cell = nn.GELU,
norm: nn.Cell = nn.LayerNorm):
super(TransformerEncoder, self).__init__()
layers = []
for _ in range(num_layers):
normalization1 = norm((dim,))
normalization2 = norm((dim,))
attention = Attention(dim=dim,
num_heads=num_heads,
keep_prob=keep_prob,
attention_keep_prob=attention_keep_prob)
feedforward = FeedForward(in_features=dim,
hidden_features=mlp_dim,
activation=activation,
keep_prob=keep_prob)
layers.append(
nn.SequentialCell([
ResidualCell(nn.SequentialCell([normalization1, attention])),
ResidualCell(nn.SequentialCell([normalization2, feedforward]))
])
)
self.layers = nn.SequentialCell(layers)
def construct(self, x):
"""Transformer construct."""
return self.layers(x)
ViTモデルの入力
従来の Transformer 構造は、主に自然言語の分野で単語ベクトル (Word Embedding または Word Vector) を処理するために使用されます。単語ベクトルと従来の画像データの主な違いは、単語ベクトルは通常 1 次元ベクトルとしてスタックされるのに対し、画像データはスタックされることです。スタッキング、マルチヘッド アテンション メカニズムは、1 次元の単語ベクトルのスタッキングを処理するときに単語ベクトル間の接続、つまりコンテキスト セマンティクスを抽出します。これにより、Transformer は自然言語処理の分野で非常に役立ちます。 、および 2 次元の画像行列は 1 次元の Word ベクトル変換とどのように比較されるのかなど、Transformer が画像処理の分野に参入するための小さな敷居となっています。
ViT モデルでは次のようになります。
入力画像を各チャンネルの 16*16 パッチに分割することで、このステップはコンボリューション演算によって実行されます。もちろん、手動で分割することもできますが、コンボリューション演算でも目的を達成でき、1 回の実行で済みます。 ; たとえば、
224 x 224 の入力画像は最初に畳み込み処理によって 16 x 16 のパッチを取得し、その後、各パッチのサイズは 14 x 14 になります。
次に、各パッチの行列を 1 次元ベクトルに引き伸ばし、近似的なワード ベクトル スタッキングの効果を取得します。前のステップで取得した 14 x 14 パッチは、長さ 196 のベクトルに変換されます。
これは、画像入力ネットワークが通過する最初のステップです。具体的なパッチ埋め込みコードは次のとおりです。
class PatchEmbedding(nn.Cell):
MIN_NUM_PATCHES = 4
def __init__(self,
image_size: int = 224,
patch_size: int = 16,
embed_dim: int = 768,
input_channels: int = 3):
super(PatchEmbedding, self).__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = (image_size // patch_size) ** 2
self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)
def construct(self, x):
"""Path Embedding construct."""
x = self.conv(x)
b, c, h, w = x.shape
x = ops.reshape(x, (b, c, h * w))
x = ops.transpose(x, (0, 2, 1))
return x
入力画像はパッチに分割された後、pos_embeddingとclass_embeddingの2つの処理を経ます。
class_embedding は主にテキスト分類用の BERT モデルのアイデアを利用しており、各単語ベクトルの前 (通常はベクトルの最初の場所) にカテゴリ値を追加し
、前のステップで取得した 196 次元のベクトルに class_embedding を追加すると 197 になります。寸法。追加された class_embedding は学習可能なパラメータであり、ネットワークの継続的な学習の後、最終的な出力カテゴリは出力ベクトルの最初の次元の出力によって最終的に決定されます; 入力が 16 x 16 パッチであるため、出力は分類されます分類用の 16 x 16 class_embeddings として。
pos_embedding は、処理されたパッチ マトリックスに追加される学習可能なパラメーターのセットでもあります。
pos_embedding も学習可能なパラメータであるため、その追加はフル リンク ネットワークと畳み込みのバイアスに似ています。このステップでは、長さの次元が 197 のトレーニング可能なベクトルを作成し、それを class_embedding の後のベクトルに追加します。
実際、pos_embedding には合計 4 つのスキームがあります。しかし、著者の議論によれば、pos_embedding を追加するだけで pos_embedding を追加しないのは大きな影響を及ぼし、pos_embedding が 1 次元であるか 2 次元であるかは分類結果にほとんど影響を与えないため、コードでは 1 次元としています。 pos_embedding も使用しますが、pos_embedding の前に class_embedding が追加されるため、pos_embedding の次元はパッチストレッチ後の次元より 1 高くなります。
一般に、ViT モデルは、コンテキスト セマンティクスを処理する際に Transformer モデルを引き続き利用し、画像を "バリアント ワード ベクトル" に変換してから処理します。この変換の重要性は、複数のパッチの間にスペースがあることです。 、これは一種の「空間セマンティクス」に似ているため、より優れた処理効果が得られます。
ViT全体を構築する
次のコードは、完全な ViT モデルを構築します。
from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameter
def init(init_type, shape, dtype, name, requires_grad):
"""Init."""
initial = initializer(init_type, shape, dtype).init_data()
return Parameter(initial, name=name, requires_grad=requires_grad)
class ViT(nn.Cell):
def __init__(self,
image_size: int = 224,
input_channels: int = 3,
patch_size: int = 16,
embed_dim: int = 768,
num_layers: int = 12,
num_heads: int = 12,
mlp_dim: int = 3072,
keep_prob: float = 1.0,
attention_keep_prob: float = 1.0,
drop_path_keep_prob: float = 1.0,
activation: nn.Cell = nn.GELU,
norm: Optional[nn.Cell] = nn.LayerNorm,
pool: str = 'cls') -> None:
super(ViT, self).__init__()
self.patch_embedding = PatchEmbedding(image_size=image_size,
patch_size=patch_size,
embed_dim=embed_dim,
input_channels=input_channels)
num_patches = self.patch_embedding.num_patches
self.cls_token = init(init_type=Normal(sigma=1.0),
shape=(1, 1, embed_dim),
dtype=ms.float32,
name='cls',
requires_grad=True)
self.pos_embedding = init(init_type=Normal(sigma=1.0),
shape=(1, num_patches + 1, embed_dim),
dtype=ms.float32,
name='pos_embedding',
requires_grad=True)
self.pool = pool
self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
self.norm = norm((embed_dim,))
self.transformer = TransformerEncoder(dim=embed_dim,
num_layers=num_layers,
num_heads=num_heads,
mlp_dim=mlp_dim,
keep_prob=keep_prob,
attention_keep_prob=attention_keep_prob,
drop_path_keep_prob=drop_path_keep_prob,
activation=activation,
norm=norm)
self.dropout = nn.Dropout(p=1.0-keep_prob)
self.dense = nn.Dense(embed_dim, num_classes)
def construct(self, x):
"""ViT construct."""
x = self.patch_embedding(x)
cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
x = ops.concat((cls_tokens, x), axis=1)
x += self.pos_embedding
x = self.pos_dropout(x)
x = self.transformer(x)
x = self.norm(x)
x = x[:, 0]
if self.training:
x = self.dropout(x)
x = self.dense(x)
return x
全体的なフローチャートは次のとおりです。
4. モデルのトレーニングと推論
モデルトレーニング
from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train
# define super parameter
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()
# construct model
network = ViT()
# load ckpt
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
# define learning rate
lr = nn.cosine_decay_lr(min_lr=float(0),
max_lr=0.00005,
total_step=epoch_size * step_size,
step_per_epoch=step_size,
decay_epoch=10)
# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)
# define loss function
class CrossEntropySmooth(LossBase):
"""CrossEntropy."""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = ops.OneHot()
self.sparse = sparse
self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, label)
return loss
network_loss = CrossEntropySmooth(sparse=True,
reduction="mean",
smooth_factor=0.1,
num_classes=num_classes)
# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)
# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={
"acc"}, amp_level="O2")
else:
model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={
"acc"}, amp_level="O0")
# train model
model.train(epoch_size,
dataset_train,
callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
dataset_sink_mode=False,)
モデルの検証
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)
trans_val = [
transforms.Decode(),
transforms.Resize(224 + 32),
transforms.CenterCrop(224),
transforms.Normalize(mean=mean, std=std),
transforms.HWC2CHW()
]
dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)
# construct model
network = ViT()
# load ckpt
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
network_loss = CrossEntropySmooth(sparse=True,
reduction="mean",
smooth_factor=0.1,
num_classes=num_classes)
# define metric
eval_metrics = {
'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
'Top_5_Accuracy': train.Top5CategoricalAccuracy()}
if ascend_target:
model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:
model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")
# evaluate model
result = model.eval(dataset_val)
print(result)
モデル推論
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)
trans_infer = [
transforms.Decode(),
transforms.Resize([224, 224]),
transforms.Normalize(mean=mean, std=std),
transforms.HWC2CHW()
]
dataset_infer = dataset_infer.map(operations=trans_infer,
input_columns=["image"],
num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)
import os
import pathlib
import cv2
import numpy as np
from PIL import Image
from enum import Enum
from scipy import io
class Color(Enum):
"""dedine enum color."""
red = (0, 0, 255)
green = (0, 255, 0)
blue = (255, 0, 0)
cyan = (255, 255, 0)
yellow = (0, 255, 255)
magenta = (255, 0, 255)
white = (255, 255, 255)
black = (0, 0, 0)
def check_file_exist(file_name: str):
"""check_file_exist."""
if not os.path.isfile(file_name):
raise FileNotFoundError(f"File `{
file_name}` does not exist.")
def color_val(color):
"""color_val."""
if isinstance(color, str):
return Color[color].value
if isinstance(color, Color):
return color.value
if isinstance(color, tuple):
assert len(color) == 3
for channel in color:
assert 0 <= channel <= 255
return color
if isinstance(color, int):
assert 0 <= color <= 255
return color, color, color
if isinstance(color, np.ndarray):
assert color.ndim == 1 and color.size == 3
assert np.all((color >= 0) & (color <= 255))
color = color.astype(np.uint8)
return tuple(color)
raise TypeError(f'Invalid type for color: {type(color)}')
def imread(image, mode=None):
"""imread."""
if isinstance(image, pathlib.Path):
image = str(image)
if isinstance(image, np.ndarray):
pass
elif isinstance(image, str):
check_file_exist(image)
image = Image.open(image)
if mode:
image = np.array(image.convert(mode))
else:
raise TypeError("Image must be a `ndarray`, `str` or Path object.")
return image
def imwrite(image, image_path, auto_mkdir=True):
"""imwrite."""
if auto_mkdir:
dir_name = os.path.abspath(os.path.dirname(image_path))
if dir_name != '':
dir_name = os.path.expanduser(dir_name)
os.makedirs(dir_name, mode=777, exist_ok=True)
image = Image.fromarray(image)
image.save(image_path)
def imshow(img, win_name='', wait_time=0):
"""imshow"""
cv2.imshow(win_name, imread(img))
if wait_time == 0: # prevent from hanging if windows was closed
while True:
ret = cv2.waitKey(1)
closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
# if user closed window or if some key pressed
if closed or ret != -1:
break
else:
ret = cv2.waitKey(wait_time)
def show_result(img: str,
result: Dict[int, float],
text_color: str = 'green',
font_scale: float = 0.5,
row_width: int = 20,
show: bool = False,
win_name: str = '',
wait_time: int = 0,
out_file: Optional[str] = None) -> None:
"""Mark the prediction results on the picture."""
img = imread(img, mode="RGB")
img = img.copy()
x, y = 0, row_width
text_color = color_val(text_color)
for k, v in result.items():
if isinstance(v, float):
v = f'{v:.2f}'
label_text = f'{k}: {v}'
cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
font_scale, text_color)
y += row_width
if out_file:
show = False
imwrite(img, out_file)
if show:
imshow(img, win_name, wait_time)
def index2label():
"""Dictionary output for image numbers and categories of the ImageNet dataset."""
metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
meta = io.loadmat(metafile, squeeze_me=True)['synsets']
nums_children = list(zip(*meta))[4]
meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
_, wnids, classes = list(zip(*meta))[:3]
clssname = [tuple(clss.split(', ')) for clss in classes]
wnid2class = {
wnid: clss for wnid, clss in zip(wnids, clssname)}
wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])
mapping = {
}
for index, (_, class_name) in enumerate(wind2class_name):
mapping[index] = class_name[0]
return mapping
# Read data for inference
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
image = image["image"]
image = ms.Tensor(image)
prob = model.predict(image)
label = np.argmax(prob.asnumpy(), axis=1)
mapping = index2label()
output = {
int(label): mapping[int(label)]}
print(output)
show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG",
result=output,
out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")
推論プロセスが完了すると、推論フォルダーの下に画像の推論結果が表示され、予測結果がドーベルマンであり、予想された結果と同じであることがわかり、モデルの精度が検証されます。