知識蒸留 DEiT アルゴリズムの戦闘: RegNet 蒸留 DEiT モデルの使用

まとめ

論文翻訳: [パート 58] DEiT: アテンション トレーニング データによる効率的な画像変換と蒸留
DEiT は、蒸留トークンを導入することで蒸留を実現します。蒸留には 2 つの方法があります。

  • 1. 蒸留トークンを教師タグとして使用します。2 つのトークンは、注意を通じてトランスフォーマー内で相互作用します。蒸留を達成します。使用方法のリファレンス:
    DEiT の実践: DEiT を使用して画像分類タスクを実装する (1)
  • 2. 畳み込みニューラル ネットワークを通じてトークンを抽出し、変換器に畳み込みニューラル ネットワークから誘導バイアスなどのいくつかの畳み込み特徴を学習させます。この点については著者も懐疑的である。

この記事は、畳み込みニューラル ネットワークを使用して DEiT を抽出する 2 番目のポイントから始まります。
説明ビデオ: https://www.zhihu.com/zvideo/1588881049425276928

最終結論

結論から先に話しましょう!Teacher ネットワークは RegNet の regnetx_160 ネットワークを使用し、Student ネットワークは DEiT の deit_tiny_distilled_pa​​tch16_224 モデルを使用します。次のように

通信網 エポック ACC
DET 100 94%
RegNet 100 96%
DEiT+ハード 100 95%

ここに画像の説明を挿入

プロジェクト構造

DeiT_dist_demo
├─data
│  ├─train
│  │  ├─Black-grass
│  │  ├─Charlock
│  │  ├─Cleavers
│  │  ├─Common Chickweed
│  │  ├─Common wheat
│  │  ├─Fat Hen
│  │  ├─Loose Silky-bent
│  │  ├─Maize
│  │  ├─Scentless Mayweed
│  │  ├─Shepherds Purse
│  │  ├─Small-flowered Cranesbill
│  │  └─Sugar beet
│  └─val
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet
├─models
│  └─models.py
├─losses.py
├─teacher_train.py
├─student_train.py
├─train_kd.py
└─test.py

data: データセット。train と val に分かれています。
モデル: モデル ファイルを保存します。
loss.py: 損失ファイル、外部蒸留損失を計算します。
Teacher_train.py: Teacher モデルをトレーニングする
Student_train.py: Student モデルをトレーニングする
train_kd.py: 蒸留モデルをトレーニングする
test: テスト結果。

モデルと損失

モデル models.py と損失スクリプト loss.py は、公式モデルのリンク: https://github.com/facebookresearch/deit から取得する必要があります。

model.py コード

# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
import torch.nn as nn
from functools import partial
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_

__all__ = [
    'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
    'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
    'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
    'deit_base_distilled_patch16_384',
]

class DistilledVisionTransformer(VisionTransformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
        trunc_normal_(self.dist_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.head_dist.apply(self._init_weights)
    def forward_features(self, x):
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0], x[:, 1]

    def forward(self, x):
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            return x, x_dist
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2


@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    print(model.default_cfg)
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model

loss.py コード

# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Implements the knowledge distillation loss
"""
import torch
from torch.nn import functional as F


class DistillationLoss(torch.nn.Module):
    """
    This module wraps a standard criterion and adds an extra knowledge distillation loss by
    taking a teacher model prediction and using it as additional supervision.
    """
    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        assert distillation_type in ['none', 'soft', 'hard']
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        """
        Args:
            inputs: The original inputs that are feed to the teacher model
            outputs: the outputs of the model to be trained. It is expected to be
                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
                in the first position and the distillation predictions as the second output
            labels: the labels for the base criterion
        """
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            # assume that the model outputs a tuple of [outputs, outputs_kd]
            outputs, outputs_kd = outputs
        base_loss = self.base_criterion(outputs, labels)
        if self.distillation_type == 'none':
            return base_loss

        if outputs_kd is None:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
                             "class_token and the dist_token")
        # don't backprop throught the teacher
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)

        if self.distillation_type == 'soft':
            T = self.tau
            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
            # with slight modifications
            distillation_loss = F.kl_div(
                F.log_softmax(outputs_kd / T, dim=1),
                #We provide the teacher's targets in log probability because we use log_target=True 
                #(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)
                #but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
                F.log_softmax(teacher_outputs / T, dim=1),
                reduction='sum',
                log_target=True
            ) * (T * T) / outputs_kd.numel()
            #We divide by outputs_kd.numel() to have the legacy PyTorch behavior. 
            #But we also experiments output_kd.size(0) 
            #see issue 61(https://github.com/facebookresearch/deit/issues/61) for more details
        elif self.distillation_type == 'hard':
            distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))

        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
        return loss

Teacher モデルをトレーニングする

教師は regnetx_160 を選択します。このモデルの事前学習モデルは比較的大きいため、直接ダウンロードできない場合は、特定の鉱山ダウンロードなどのダウンロード ツールを使用できます。

ステップ

新しい Teacher_train.py を作成し、コードを挿入します。

必要なライブラリをインポートする

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from torch.autograd import Variable
from timm.models import regnetx_160

import json
import os
# 定义训练过程

トレーニングおよび検証関数を定義する

# 设置随机因子
def seed_everything(seed=42):
    os.environ['PYHTONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

# 训练函数
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data).to(device), Variable(target).to(device)
        out = model(data)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))

Best_ACC=0
# 验证过程
@torch.no_grad()
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = Variable(data).to(device), Variable(target).to(device)
            out = model(data)
            loss = criterion(out, target)
            _, pred = torch.max(out.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.data.item()
            test_loss += print_loss
        correct = correct.data.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        if acc > Best_ACC:
            torch.save(model, file_dir + '/' + 'best.pth')
            Best_ACC = acc
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        return acc

グローバルパラメータを定義する

if __name__ == '__main__':
    # 创建保存模型的文件夹
    file_dir = 'TeacherModel'
    if os.path.exists(file_dir):
        print('true')

        os.makedirs(file_dir, exist_ok=True)
    else:
        os.makedirs(file_dir)

    # 设置全局参数
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED=42
    seed_everything(SEED)

画像の前処理と強化

 # 数据预处理7
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])

    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])


読み取りデータ

デフォルトでデータを読み取るには pytorch を使用します。

   # 读取数据
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    # 导入数据
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

モデルと損失を設定する

    # 实例化模型并且移动到GPU
    criterion = nn.CrossEntropyLoss()
    model_ft = regnetx_160(pretrained=True)
    model_ft.reset_classifier(num_classes=12)
    model_ft.to(DEVICE)
    # 选择简单暴力的Adam优化器,学习率调低
    optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
    # 训练
    val_acc_list= {
    
    }
    for epoch in range(1, EPOCHS + 1):
        train(model_ft, DEVICE, train_loader, optimizer, epoch)
        cosine_schedule.step()
        acc=val(model_ft, DEVICE, test_loader)
        val_acc_list[epoch]=acc
        with open('result.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(val_acc_list))
    torch.save(model_ft, 'TeacherModel/model_final.pth')

上記のコードを完了したら、教師ネットワークのトレーニングを開始できます。

学生ネットワーク

学生ネットワークは deit_tiny_distilled_pa​​tch16_224 を使用します。これは比較的小規模なネットワークであり、モデルのサイズは 20M です。100 エポックのトレーニング。

ステップ

新しい Student_train.py を作成し、コードを挿入します。

必要なライブラリをインポートする

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from torch.autograd import Variable
from models.models import deit_tiny_distilled_patch16_224

import json
import os

トレーニングおよび検証関数を定義する

# 设置随机因子
def seed_everything(seed=42):
    os.environ['PYHTONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
# 定义训练过程
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data).to(device), Variable(target).to(device)
        out = model(data)[0]
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))

Best_ACC=0
# 验证过程
@torch.no_grad()
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = Variable(data).to(device), Variable(target).to(device)
            out = model(data)
            loss = criterion(out, target)
            _, pred = torch.max(out.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.data.item()
            test_loss += print_loss
        correct = correct.data.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        if acc > Best_ACC:
            torch.save(model, file_dir + '/' + 'best.pth')
            Best_ACC = acc
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        return acc

ここで、使用する公式モデルにより、通常のトレーニングを行う場合、戻り値が 2 つ、つまり x と x_dist になることに注意してください。
ここに画像の説明を挿入
損失の計算には、前の値、つまり次の値のみが必要です。

 out = model(data)[0]

検証する場合は値を返します。したがって、上記の操作を行う必要はありません。

 out = model(data)

グローバルパラメータを定義する

if __name__ == '__main__':
    # 创建保存模型的文件夹
    file_dir = 'StudentModel'
    if os.path.exists(file_dir):
        print('true')
        os.makedirs(file_dir, exist_ok=True)
    else:
        os.makedirs(file_dir)
    # 设置全局参数
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED=42
    seed_everything(SEED)

画像の前処理と強化

 # 数据预处理7
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])

読み取りデータ

デフォルトでデータを読み取るには pytorch を使用します。

    # 读取数据
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    # 导入数据
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

モデルと損失を設定する

	 # 实例化模型并且移动到GPU
    criterion = nn.CrossEntropyLoss()
    model_ft = deit_tiny_distilled_patch16_224(pretrained=True)
    num_ftrs = model_ft.head.in_features
    model_ft.head = nn.Linear(num_ftrs, 12)
    num_ftrs_dist = model_ft.head_dist.in_features
    model_ft.head_dist = nn.Linear(num_ftrs_dist, 12)
    print(model_ft)
    model_ft.to(DEVICE)
    # 选择简单暴力的Adam优化器,学习率调低
    optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
    # 训练
    val_acc_list= {
    
    }
    for epoch in range(1, EPOCHS + 1):
        train(model_ft, DEVICE, train_loader, optimizer, epoch)
        cosine_schedule.step()
        acc=val(model_ft, DEVICE, test_loader)
        val_acc_list[epoch]=acc
        with open('result_student.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(val_acc_list))
    torch.save(model_ft, 'StudentModel/model_final.pth')

上記のコードを完了したら、Student ネットワークのトレーニングを開始できます。

蒸留された学生ネットワーク

学生ネットワークは deit_tiny_distilled_pa​​tch16_224 を引き続き使用し、教師ネットワークを使用して学生ネットワークを蒸留し、100 エポックの間トレーニングします。

ステップ

新しい train_kd.py.py を作成し、コードを挿入します。

必要なライブラリをインポートする

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from timm.loss import LabelSmoothingCrossEntropy
from torchvision import datasets
from models.models import deit_tiny_distilled_patch16_224
import json
import os
from losses import DistillationLoss

トレーニングおよび検証関数を定義する

# 设置随机因子
def seed_everything(seed=42):
    os.environ['PYHTONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
# 定义训练过程
def train(s_net,t_net, device,criterionKD,train_loader, optimizer, epoch):
    s_net.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out_s = s_net(data)
        loss = criterionKD(data,out_s, target)
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))

Best_ACC=0
# 验证过程
@torch.no_grad()
def val(model, device,criterionCls, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            out_s = model(data)
            loss = criterionCls(out_s, target)
            _, pred = torch.max(out_s.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.data.item()
            test_loss += print_loss
        correct = correct.data.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        if acc > Best_ACC:
            torch.save(model, file_dir + '/' + 'best.pth')
            Best_ACC = acc
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        return acc

グローバルパラメータを定義する

if __name__ == '__main__':
    # 创建保存模型的文件夹
    file_dir = 'KDModel'
    if os.path.exists(file_dir):
        print('true')
        os.makedirs(file_dir, exist_ok=True)
    else:
        os.makedirs(file_dir)

    # 设置全局参数
    modellr = 1e-4
    BATCH_SIZE = 4
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED=42
    seed_everything(SEED)
    distillation_type='hard'  #['none', 'soft', 'hard']
    distillation_alpha=0.5
    distillation_tau=1.0

distillation_type: 蒸留のタイプ。この記事ではハードを選択します。
distillation_alpha: α係数、蒸留損失の重み係数。
distillation_tau: T、蒸留温度の意味。

画像の前処理と強化

 # 数据预处理7
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])

    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])


読み取りデータ

デフォルトでデータを読み取るには pytorch を使用します。

    # 读取数据
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    # 导入数据
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

モデルと損失を設定する

 	model_ft = deit_tiny_distilled_patch16_224(pretrained=True)
    num_ftrs = model_ft.head.in_features
    model_ft.head = nn.Linear(num_ftrs, 12)
    num_ftrs_dist = model_ft.head_dist.in_features
    model_ft.head_dist = nn.Linear(num_ftrs_dist, 12)
    print(model_ft)
    model_ft.to(DEVICE)
    # 选择简单暴力的Adam优化器,学习率调低
    optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
    teacher_model=torch.load('TeacherModel/best.pth')
    teacher_model.eval()
    # 实例化模型并且移动到GPU
    criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
    criterionKD = DistillationLoss(
        criterion, teacher_model, distillation_type, distillation_alpha, distillation_tau
    )
    criterionCls = nn.CrossEntropyLoss()
    # 训练
    val_acc_list= {
    
    }
    for epoch in range(1, EPOCHS + 1):
        train(model_ft,teacher_model, DEVICE,criterionKD, train_loader, optimizer, epoch)
        cosine_schedule.step()
        acc=val(model_ft,DEVICE,criterionCls , test_loader)
        val_acc_list[epoch]=acc
        with open('result_kd.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(val_acc_list))
    torch.save(model_ft, 'KDModel/model_final.pth')

上記のコードを完了したら、蒸留モードを開始できます。

結果の比較

保存した結果を読み込み、acc カーブを描画します。

import numpy as np
from matplotlib import pyplot as plt
import json
teacher_file='result.json'
student_file='result_student.json'
student_kd_file='result_kd.json'
def read_json(file):
    with open(file, 'r', encoding='utf8') as fp:
        json_data = json.load(fp)
        print(json_data)
    return json_data

teacher_data=read_json(teacher_file)
student_data=read_json(student_file)
student_kd_data=read_json(student_kd_file)


x =[int(x) for x in  list(dict(teacher_data).keys())]
print(x)

plt.plot(x, list(teacher_data.values()), label='teacher')
plt.plot(x,list(student_data.values()), label='student without IRG')
plt.plot(x, list(student_kd_data.values()), label='student with IRG')

plt.title('Test accuracy')
plt.legend()

plt.show()

ここに画像の説明を挿入

要約する

この記事では、外部モデル抽出アルゴリズムを使用して DeiT モデルを抽出する方法に焦点を当てます。役に立ったと思われる場合は、ブックマーク、いいね、転送を歓迎します。ご質問がある場合は、メッセージを残して議論することもできます。
この実際の戦闘で使用されるコードとデータ セットについては、以下で詳しく説明します。

https://download.csdn.net/download/うおおおおおおwwwwww/87323531

おすすめ

転載: blog.csdn.net/m0_47867638/article/details/131180435