VanillaNet 戦闘: VanillaNet を使用して画像分類を実現する (1)

まとめ

論文翻訳: https://blog.csdn.net/m0_47867638/article/details/131057152
公式ソースコード: https://github.com/huawei-noah/VanillaNet

VanillaNet は、2023 年に Huawei によってリリースされた最小限の CNN ネットワークです。最も一般的な CNN ネットワークを使用しますが、非常に良好な結果が得られます。
ここに画像の説明を挿入

この記事では、VanillaNet を使用して植物分類タスクを完了し、モデルでは VanillaNet10 を使用して VanillaNet の使用方法を示します。以下に示すように、事前トレーニングされたモデルがないため、VanillaNet10 はこのデータセットで 87% の ACC を達成しました。

画像の説明を追加してください
画像の説明を追加してください

この記事を通じて、次のことを学ぶことができます。

  1. 変換、CutOut、MixUp、CutMix およびその他の拡張方法の拡張を含む、データ拡張を使用するにはどうすればよいですか?
  2. VanillaNet モデルをトレーニング用に実装するにはどうすればよいですか?
  3. pytorch 独自の混合精度を使用するにはどうすればよいですか?
  4. グラデーションクリッピングを使用してグラデーションの爆発を防ぐにはどうすればよいですか?
  5. DP マルチグラフィックス カード トレーニングの使用方法は?
  6. 損失とacc曲線を描くにはどうすればよいですか?
  7. val 評価レポートを生成するにはどうすればよいですか?
  8. テスト スイートをテストするテスト スクリプトを作成するにはどうすればよいですか?
  9. コサインアニーリング戦略を使用して学習率を調整するにはどうすればよいですか?
  10. AverageMeter クラスを使用して ACC や損失などのカスタム変数をカウントするにはどうすればよいですか?
  11. ACC1 と ACC5 を理解して数えるにはどうすればよいですか?
  12. EMAの使い方は?
  13. Grad-CAMを使用してヒートマップ可視化を実現するにはどうすればよいですか?

基礎が弱く、上記の機能を理解するのが難しい場合は、私のコラム「古典的バックボーンネットワークの集中講義と実戦」を
読んで、誰でも受け入れやすいようにゼロからステップバイステップで説明します。

インストールパッケージ

ティムをインストールする

pip を使用するだけです:

pip install timm

ミックスアップ強化と EMA は timm を使用します

grad-cam をインストールする

pip install grad-cam

データ拡張のカットアウトとミックスアップ

パフォーマンスを向上させるために、Cutout と Mixup という 2 つの拡張メソッドをコードに追加しました。これらの機能強化を両方実装するには、torchtoolbox をインストールする必要があります。インストールコマンド:

pip install torchtoolbox

変換におけるカットアウトの実装。

from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    Cutout(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])

パッケージをインポートする必要があります: from timm.data.mixup import Mixup,

Mixup と SoftTargetCrossEntropy を定義する

  mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
    prob=0.1, switch_prob=0.5, mode='batch',
    label_smoothing=0.1, num_classes=12)
 criterion_train = SoftTargetCrossEntropy()

パラメータの詳細な説明:

mixup_alpha (float): ミックスアップのアルファ値。 > 0 の場合、ミックスアップがアクティブです。

Cutmix_alpha (float): Cutmix アルファ値、> 0 の場合、cutmix はアクティブです。

Cutmix_minmax (List[float]): Cutmix の最小/最大画像比。Cutmix がアクティブです。None でない場合は、これとアルファを使用します。

Cutmix_minmax が設定されている場合、cutmix_alpha のデフォルトは 1.0 です。

prob (float): バッチまたは要素ごとにミックスアップまたはカットミックスを適用する確率。

switch_prob (float): カットミックスとミックスアップが両方ともアクティブな場合に切り替わる確率。

mode (str): mixup/cutmix パラメータ (各 'batch'、'pair' (要素のペア)、'elem' (要素) の適用方法。

correct_lam (bool):cutmix bbox が画像の境界線によってクリップされる場合に適用されます。ラムダ補正

label_smoothing (float): 混合ターゲット テンソルにラベル スムージングを適用します。

num_classes (int): ターゲットのクラスの数。

EMA

EMA (指数移動平均) は指数移動平均です。深層学習では、履歴パラメータのコピーを保存し、一定のトレーニング期間が経過した後、その履歴パラメータを使用して、現在学習されているパラメータを平滑化します。具体的な実装は以下の通りです。


import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)

class ModelEma:
    def __init__(self, model, decay=0.9999, device='', resume=''):
        # make a copy of the model for accumulating moving average of weights
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if device:
            self.ema.to(device=device)
        self.ema_has_module = hasattr(self.ema, 'module')
        if resume:
            self._load_checkpoint(resume)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def _load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        assert isinstance(checkpoint, dict)
        if 'state_dict_ema' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict_ema'].items():
                # ema model may have been wrapped by DataParallel, and need module prefix
                if self.ema_has_module:
                    name = 'module.' + k if not k.startswith('module') else k
                else:
                    name = k
                new_state_dict[name] = v
            self.ema.load_state_dict(new_state_dict)
            _logger.info("Loaded state_dict_ema")
        else:
            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")

    def update(self, model):
        # correct a mismatch in state dict keys
        needs_module = hasattr(model, 'module') and not self.ema_has_module
        with torch.no_grad():
            msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if needs_module:
                    k = 'module.' + k
                model_v = msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

モデルに追加されました。

#初始化
if use_ema:
     model_ema = ModelEma(
            model_ft,
            decay=model_ema_decay,
            device='cpu',
            resume=resume)

# 训练过程中,更新完参数后,同步update shadow weights
def train():
    optimizer.step()
    if model_ema is not None:
        model_ema.update(model)


# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)

事前トレーニングのないモデルの場合、EMA がスコアを獲得できない可能性が高くなります。誰もがこれに注意する必要があります。

プロジェクト構造

VanillaNet_Demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─models
│  └─vanillanet.py
├─mean_std.py
├─makedata.py
├─train.py
├─cam_image.py
└─test.py

モデル: ソースの公式コード。反対側のコードに適応的な変更が加えられています。事前トレーニングをロードしてモデルを呼び出すためのロジックを追加しました。
means_std.py: 平均値と標準偏差の値を計算します。
makedata.py: データセットを生成します。
ema.py: EMA スクリプト
train.py: SeaFormer モデルのトレーニング
cam_image.py: ヒート マップの視覚化

平均値と標準偏差を計算する

モデルをより速く収束させるには、mean と std の値を計算し、新しい means_std.py を作成してコードを挿入する必要があります。

from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms

def get_mean_and_std(train_data):
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0,
        pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in train_loader:
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())

if __name__ == '__main__':
    train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())
    print(get_mean_and_std(train_dataset))

データセット構造:

画像-20220221153058619

操作結果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

後で使用できるようにこの結果を記録してください。

データセットを生成する

整理した画像分類のデータセット構造はこんな感じ

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch と keras のデフォルトの読み込み方法は ImageNet データセット形式で、形式は次のとおりです。

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

形式変換スクリプト makedata.py を追加し、コードを挿入します。

import glob
import os
import shutil

image_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):
    print('true')
    #os.rmdir(file_dir)
    shutil.rmtree(file_dir)#删除再建立
    os.makedirs(file_dir)
else:
    os.makedirs(file_dir)

from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(train_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

for file in val_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(val_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

上記の内容を完了したら、トレーニングとテストを開始できます。

おすすめ

転載: blog.csdn.net/m0_47867638/article/details/131216849