記事ディレクトリ
まとめ
論文翻訳: https://blog.csdn.net/m0_47867638/article/details/131057152
公式ソースコード: https://github.com/huawei-noah/VanillaNet
VanillaNet は、2023 年に Huawei によってリリースされた最小限の CNN ネットワークです。最も一般的な CNN ネットワークを使用しますが、非常に良好な結果が得られます。
この記事では、VanillaNet を使用して植物分類タスクを完了し、モデルでは VanillaNet10 を使用して VanillaNet の使用方法を示します。以下に示すように、事前トレーニングされたモデルがないため、VanillaNet10 はこのデータセットで 87% の ACC を達成しました。
この記事を通じて、次のことを学ぶことができます。
- 変換、CutOut、MixUp、CutMix およびその他の拡張方法の拡張を含む、データ拡張を使用するにはどうすればよいですか?
- VanillaNet モデルをトレーニング用に実装するにはどうすればよいですか?
- pytorch 独自の混合精度を使用するにはどうすればよいですか?
- グラデーションクリッピングを使用してグラデーションの爆発を防ぐにはどうすればよいですか?
- DP マルチグラフィックス カード トレーニングの使用方法は?
- 損失とacc曲線を描くにはどうすればよいですか?
- val 評価レポートを生成するにはどうすればよいですか?
- テスト スイートをテストするテスト スクリプトを作成するにはどうすればよいですか?
- コサインアニーリング戦略を使用して学習率を調整するにはどうすればよいですか?
- AverageMeter クラスを使用して ACC や損失などのカスタム変数をカウントするにはどうすればよいですか?
- ACC1 と ACC5 を理解して数えるにはどうすればよいですか?
- EMAの使い方は?
- Grad-CAMを使用してヒートマップ可視化を実現するにはどうすればよいですか?
基礎が弱く、上記の機能を理解するのが難しい場合は、私のコラム「古典的バックボーンネットワークの集中講義と実戦」を
読んで、誰でも受け入れやすいようにゼロからステップバイステップで説明します。
インストールパッケージ
ティムをインストールする
pip を使用するだけです:
pip install timm
ミックスアップ強化と EMA は timm を使用します
grad-cam をインストールする
pip install grad-cam
データ拡張のカットアウトとミックスアップ
パフォーマンスを向上させるために、Cutout と Mixup という 2 つの拡張メソッドをコードに追加しました。これらの機能強化を両方実装するには、torchtoolbox をインストールする必要があります。インストールコマンド:
pip install torchtoolbox
変換におけるカットアウトの実装。
from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
Cutout(),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
パッケージをインポートする必要があります: from timm.data.mixup import Mixup,
Mixup と SoftTargetCrossEntropy を定義する
mixup_fn = Mixup(
mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
prob=0.1, switch_prob=0.5, mode='batch',
label_smoothing=0.1, num_classes=12)
criterion_train = SoftTargetCrossEntropy()
パラメータの詳細な説明:
mixup_alpha (float): ミックスアップのアルファ値。 > 0 の場合、ミックスアップがアクティブです。
Cutmix_alpha (float): Cutmix アルファ値、> 0 の場合、cutmix はアクティブです。
Cutmix_minmax (List[float]): Cutmix の最小/最大画像比。Cutmix がアクティブです。None でない場合は、これとアルファを使用します。
Cutmix_minmax が設定されている場合、cutmix_alpha のデフォルトは 1.0 です。
prob (float): バッチまたは要素ごとにミックスアップまたはカットミックスを適用する確率。
switch_prob (float): カットミックスとミックスアップが両方ともアクティブな場合に切り替わる確率。
mode (str): mixup/cutmix パラメータ (各 'batch'、'pair' (要素のペア)、'elem' (要素) の適用方法。
correct_lam (bool):cutmix bbox が画像の境界線によってクリップされる場合に適用されます。ラムダ補正
label_smoothing (float): 混合ターゲット テンソルにラベル スムージングを適用します。
num_classes (int): ターゲットのクラスの数。
EMA
EMA (指数移動平均) は指数移動平均です。深層学習では、履歴パラメータのコピーを保存し、一定のトレーニング期間が経過した後、その履歴パラメータを使用して、現在学習されているパラメータを平滑化します。具体的な実装は以下の通りです。
import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn
_logger = logging.getLogger(__name__)
class ModelEma:
def __init__(self, model, decay=0.9999, device='', resume=''):
# make a copy of the model for accumulating moving average of weights
self.ema = deepcopy(model)
self.ema.eval()
self.decay = decay
self.device = device # perform ema on different device from model if set
if device:
self.ema.to(device=device)
self.ema_has_module = hasattr(self.ema, 'module')
if resume:
self._load_checkpoint(resume)
for p in self.ema.parameters():
p.requires_grad_(False)
def _load_checkpoint(self, checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
assert isinstance(checkpoint, dict)
if 'state_dict_ema' in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict_ema'].items():
# ema model may have been wrapped by DataParallel, and need module prefix
if self.ema_has_module:
name = 'module.' + k if not k.startswith('module') else k
else:
name = k
new_state_dict[name] = v
self.ema.load_state_dict(new_state_dict)
_logger.info("Loaded state_dict_ema")
else:
_logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
def update(self, model):
# correct a mismatch in state dict keys
needs_module = hasattr(model, 'module') and not self.ema_has_module
with torch.no_grad():
msd = model.state_dict()
for k, ema_v in self.ema.state_dict().items():
if needs_module:
k = 'module.' + k
model_v = msd[k].detach()
if self.device:
model_v = model_v.to(device=self.device)
ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
モデルに追加されました。
#初始化
if use_ema:
model_ema = ModelEma(
model_ft,
decay=model_ema_decay,
device='cpu',
resume=resume)
# 训练过程中,更新完参数后,同步update shadow weights
def train():
optimizer.step()
if model_ema is not None:
model_ema.update(model)
# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)
事前トレーニングのないモデルの場合、EMA がスコアを獲得できない可能性が高くなります。誰もがこれに注意する必要があります。
プロジェクト構造
VanillaNet_Demo
├─data1
│ ├─Black-grass
│ ├─Charlock
│ ├─Cleavers
│ ├─Common Chickweed
│ ├─Common wheat
│ ├─Fat Hen
│ ├─Loose Silky-bent
│ ├─Maize
│ ├─Scentless Mayweed
│ ├─Shepherds Purse
│ ├─Small-flowered Cranesbill
│ └─Sugar beet
├─models
│ └─vanillanet.py
├─mean_std.py
├─makedata.py
├─train.py
├─cam_image.py
└─test.py
モデル: ソースの公式コード。反対側のコードに適応的な変更が加えられています。事前トレーニングをロードしてモデルを呼び出すためのロジックを追加しました。
means_std.py: 平均値と標準偏差の値を計算します。
makedata.py: データセットを生成します。
ema.py: EMA スクリプト
train.py: SeaFormer モデルのトレーニング
cam_image.py: ヒート マップの視覚化
平均値と標準偏差を計算する
モデルをより速く収束させるには、mean と std の値を計算し、新しい means_std.py を作成してコードを挿入する必要があります。
from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms
def get_mean_and_std(train_data):
train_loader = torch.utils.data.DataLoader(
train_data, batch_size=1, shuffle=False, num_workers=0,
pin_memory=True)
mean = torch.zeros(3)
std = torch.zeros(3)
for X, _ in train_loader:
for d in range(3):
mean[d] += X[:, d, :, :].mean()
std[d] += X[:, d, :, :].std()
mean.div_(len(train_data))
std.div_(len(train_data))
return list(mean.numpy()), list(std.numpy())
if __name__ == '__main__':
train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())
print(get_mean_and_std(train_dataset))
データセット構造:
操作結果:
([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])
後で使用できるようにこの結果を記録してください。
データセットを生成する
整理した画像分類のデータセット構造はこんな感じ
data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet
pytorch と keras のデフォルトの読み込み方法は ImageNet データセット形式で、形式は次のとおりです。
├─data
│ ├─val
│ │ ├─Black-grass
│ │ ├─Charlock
│ │ ├─Cleavers
│ │ ├─Common Chickweed
│ │ ├─Common wheat
│ │ ├─Fat Hen
│ │ ├─Loose Silky-bent
│ │ ├─Maize
│ │ ├─Scentless Mayweed
│ │ ├─Shepherds Purse
│ │ ├─Small-flowered Cranesbill
│ │ └─Sugar beet
│ └─train
│ ├─Black-grass
│ ├─Charlock
│ ├─Cleavers
│ ├─Common Chickweed
│ ├─Common wheat
│ ├─Fat Hen
│ ├─Loose Silky-bent
│ ├─Maize
│ ├─Scentless Mayweed
│ ├─Shepherds Purse
│ ├─Small-flowered Cranesbill
│ └─Sugar beet
形式変換スクリプト makedata.py を追加し、コードを挿入します。
import glob
import os
import shutil
image_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):
print('true')
#os.rmdir(file_dir)
shutil.rmtree(file_dir)#删除再建立
os.makedirs(file_dir)
else:
os.makedirs(file_dir)
from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:
file_class=file.replace("\\","/").split('/')[-2]
file_name=file.replace("\\","/").split('/')[-1]
file_class=os.path.join(train_root,file_class)
if not os.path.isdir(file_class):
os.makedirs(file_class)
shutil.copy(file, file_class + '/' + file_name)
for file in val_files:
file_class=file.replace("\\","/").split('/')[-2]
file_name=file.replace("\\","/").split('/')[-1]
file_class=os.path.join(val_root,file_class)
if not os.path.isdir(file_class):
os.makedirs(file_class)
shutil.copy(file, file_class + '/' + file_name)
上記の内容を完了したら、トレーニングとテストを開始できます。