知識の蒸留 - 原則 + コードの実践 (蒸留 CNN および漸進的蒸留拡散)

1. 蒸留の基本概念

知識の蒸留は、模型压缩迁移学习 で広く使用されています。先駆的な研究は、「ニューラル ネットワークでの知識の蒸留」 であるはずです。この記事における著者の動機は、方法を見つけることです。通常、トレーニングされた は、別の を教えるために使用されます。通常、モデル A はモデル B よりも強力です。モデル A の指導の下では、モデル B は独学よりもよく学習できます。 把多个模型的知识提炼给单个模型Teacher Model AStudent Model B

方法: まず教師ネットワークをトレーニングし、次にこの teacher网络的输出数据的真实标签训练student网络 に使用します。知識の蒸留を使用すると、ネットワークを大規模ネットワークから小規模ネットワークに変換し、大規模ネットワークに近いパフォーマンスを維持できます。また、複数のネットワークの学習した知識を 1 つのネットワークに転送して、単一ネットワークのパフォーマンスを維持することもできます。の結果に似ています。

たとえば、次の画像分類タスクの場合:

ここに画像の説明を挿入します

  • 従来のトレーニング: Teacher ネットワークがない場合、データのみが Student ネットワークを通過します。ソフトマックスの後、確率分布値 q が出力され、q label p による Cross_entropy loss の検出は Hard loss と呼ばれます。これは、この p が実数値のワンホット ベクトルであり、より近い値になることが期待されるためです。 q と p があればあるほど良いです。

  • 知識の蒸留: 教師の協力がある場合、生徒と教師のネットワークから損失が生じます。そして、教師による q' 出力は 带温度的Softmax を通過する必要があります (スムーズにするため、考え方は label smooth と似ています)。 < a i=5> に続いて を実行して損失を計算します。 = + < a i=9>。 L = α ⋅ ハード _ L o s s + ( 1 − α ) ⋅ ソフト t _ L o s s = α ⋅ C E ( p , q ) + ( 1 − α ) ⋅ C E ( q ' ' , q ) L=\alpha\cdot Hard\_Loss+(1-\alpha)\cdot Soft\_Loss=\alpha\cdot CE(p,q) + (1-\alpha)\cdot CE(q' ',q) q''q总lossTeacher q'' 和 Student q 的 lossStudent q 和 label p 的 loss
    L=あるHard_ Loss+(1α)Soft_ Loss=あるCE(p,q)+(1α)CE(q''q)ここに画像の説明を挿入します

  • SoftMax问题
    普通的Cross Entropy Loss是由NLL LossLogSoftmax组成的:

F.cross_entropy(p,target)) = F.nll_loss(torch.log(torch.softmax(p)), target)

このcross_entropy 損失のソフトマックスは、実際にはそれほどソフトではありません。出力確率分布は、 正しいカテゴリの信頼度は高くなりますが、他のカテゴリの確率はほぼ 0 です。この場合、教師ネットワークによって学習されたデータの類似情報 (たとえば、数字 2、3、7 が非常に似ている、この種のソフトラベル情報) を生徒ネットワークに伝えるのは困難です。 したがって、この記事では温度係数 T を持つソフトマックスを提案しています。 ここで q i q_i Softmax-T
ここに画像の説明を挿入します
qi は生徒のネットワーク学習 (ソフト ターゲット) のオブジェクトです。 z i z_i i ニューラル ネットワークのソフトマックス前の出力ロジットです。 T が 1 に設定されている場合、この式はソフトマックスとなり、ロジットに基づいて各カテゴリの確率を出力します。 T が 0 に近い場合、最大値は 1 に近づき、その他の値は 0 に近づきます。これは、ワンホット エンコーディングと似ています。 Tが大きいほど出力結果の分布が平坦になるため、平滑化に相当し、同様の情報を保持する役割を果たします。 T が無限大に等しい場合、それは一様分布です。

比較Softmax (上) と Softmax-T (下) モデル予測結果の確率分布の視覚化:
ここに画像の説明を挿入します

最終的な蒸留損失は、元の CE 損失の Soft LossSoftmaxSoftmax_T に置き換えて KD 損失を取得します。 :
K D _ L o s s = α ⋅ H a rd _ L o s s + ( 1 − α ) ⋅ S o f t _ L o s s = α ⋅ C E ( p , q ) + ( 1 − α ) ⋅ C E ( q '' , q ) KD\_Loss=\alpha\cdot Hard\_Loss+(1-\alpha)\cdot Soft\_Loss=\alpha\cdot CE(p,q) + (1-\alpha)\ cdot CE(q'',q) KD_Loss=あるHard_ Loss+(1α)Soft_ Loss=あるCE(p,q)+(1α)CE(q''q)
p は実際のラベル ラベル、q は生徒の出力、q'' は教師の出力です。

def distillation_loss(y,labels,teacher_scores,temp,alpha):
	soft_loss = nn.KLDivLoss()(F.log_softmax(y/temp, dim=1), F.softmax(teacher_scores/temp,dim=1))
	hard_loss = F.cross_entropy(y,labels)
    return soft_loss *(temp*temp*2.0*alpha) + hard_loss *(1. - alpha)

2. 蒸留 MNIST CNN 分類コードの実践

  1. ライブラリのインポート:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import torchvision
from torchvision import transforms
  1. 教師 CNN モデル (big) と生徒 CNN モデル (small) を定義します:
class TeacherModel(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(TeacherModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.relu(self.dropout(self.fc1(x)))
        x = self.relu(self.dropout(self.fc2(x)))
        x = self.fc3(x)
        return x

class StudentModel(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(StudentModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.relu(self.dropout(self.fc1(x)))
        x = self.relu(self.dropout(self.fc2(x)))
        x = self.fc3(x)
        return x
  1. 教師モデルのトレーニングの機能:
def teacher(device, train_loader, test_loader):
    print('--------------teachermodel start--------------')
    model = TeacherModel()
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    epochs = 6
    for epoch in range(epochs):
        model.train()

        for data, target in tqdm(train_loader):
            data = data.to(device)
            target = target.to(device)
            preds = model(data)
            loss = criterion(preds, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        num_correct = 0
        num_samples = 0

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                preds = model(x)
                predictions = preds.max(1).indices
                num_correct += (predictions.eq(y)).sum().item()
                num_samples += predictions.size(0)
            acc = num_correct / num_samples

        model.train()
        print('Epoch:{}\t Acc:{:.4f}'.format(epoch + 1, acc))
    torch.save(model, 'teacher.pkl')
    print('--------------teachermodel end--------------')
  1. 学生モデルの機能 independently
def student(device, train_loader, test_loader):
    print('--------------studentmodel start--------------')

    model = StudentModel()
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    epochs = 3
    for epoch in range(epochs):
        model.train()

        for data, target in tqdm(train_loader):
            data = data.to(device)
            target = target.to(device)
            preds = model(data)
            loss = criterion(preds, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        num_correct = 0
        num_samples = 0

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                # print(y)
                preds = model(x)
                #             print(preds)
                predictions = preds.max(1).indices
                # print(predictions)
                num_correct += (predictions.eq(y)).sum().item()
                num_samples += predictions.size(0)
            acc = num_correct / num_samples

        model.train()
        print('Epoch:{}\t Acc:{:.4f}'.format(epoch + 1, acc))
    print('--------------studentmodel prediction end--------------')
  1. 教師モデルから生徒モデルへのDistillingの機能:(核心)
def kd(teachermodel, device, train_loader, test_loader):
    print('--------------kdmodel start--------------')

    teachermodel.eval()

    studentmodel = StudentModel()
    studentmodel = studentmodel.to(device)
    studentmodel.train()

    temp = 7    #蒸馏温度
    alpha = 0.3

    hard_loss = nn.CrossEntropyLoss()
    soft_loss = nn.KLDivLoss(reduction='batchmean')

    optimizer = torch.optim.Adam(studentmodel.parameters(), lr=1e-4)

    epochs = 20
    for epoch in range(epochs):
        for data, target in tqdm(train_loader):
            data = data.to(device)
            target = target.to(device)

            with torch.no_grad():
                teacher_preds = teachermodel(data)

            student_preds = studentmodel(data)
            student_loss = hard_loss(student_preds, target) #hard_loss

            distillation_loss = soft_loss(
                F.log_softmax(student_preds / temp, dim=1),
                F.softmax(teacher_preds / temp, dim=1)
            )   #soft_loss

            loss = alpha * student_loss + (1 - alpha) * distillation_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        studentmodel.eval()
        num_correct = 0
        num_samples = 0

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                preds = studentmodel(x)
                predictions = preds.max(1).indices
                num_correct += (predictions.eq(y)).sum().item()
                num_samples += predictions.size(0)
            acc = num_correct / num_samples

        studentmodel.train()
        print('Epoch:{}\t Acc:{:.4f}'.format(epoch + 1, acc))
    print('--------------kdmodel end--------------')
  1. 主な機能 (データのロード、トレイン機能の実装):
if __name__ == '__main__':
    torch.manual_seed(0)

    device = torch.device("cuda" if torch.cuda.is_available else "cpu")
    torch.backends.cudnn.benchmark = True
    #加载数据集
    X_train = torchvision.datasets.MNIST(
        root="dataset/",
        train=True,
        transform=transforms.ToTensor(),
        download=True
    )

    X_test = torchvision.datasets.MNIST(
        root="dataset/",
        train=False,
        transform=transforms.ToTensor(),
        download=True
    )

    train_loader = DataLoader(dataset=X_train, batch_size=32, shuffle=True)
    test_loader = DataLoader(dataset=X_test, batch_size=32, shuffle=False)

    #从头训练教师模型,并预测
    teacher(device, train_loader, test_loader)

   #从头训练学生模型,并预测
    student(device, train_loader, test_loader)

   #知识蒸馏训练学生模型
    model = torch.load('teacher.pkl')
    kd(model, device, train_loader, test_loader)

Teacher MdeolStuent Model without DistillationStuent Model with Distillation の精度を比較した最終トレーニング結果: < /span> ①教師蒸留を使用して訓練された生徒は、単独で訓練された生徒よりも強いです。 ②実際の場面では、生徒自身が教師よりも著しく弱い場合が多く、教師の成績を超えることは困難です。

  • Teacher Mdeol:エポック:3 精度:0.9689,エポック:6 精度:0.9764
  • Stuent Model without Distillation:エポック:3 アクセス:0.8173
  • Stuent Model with Distillation:エポック:3 精度:0.8387、エポック:20 精度:0.9015

3. 漸進蒸留拡散コード生成の実践

段階的蒸留により拡散モデルのサンプリング ステップ数を削減する方法。主な内容は次のとおりです。progressive distillationguided diffusion distillationstep distillationData-free DistillationLatent Consistency Models

このセクションでは主に順蒸留について説明します。この記事で提案する方法はv-parameterization次のような用途で使用されるためです。その後の拡散 Imagen video、Imagen、Stable Diffusion、Dall E など、推論を高速化する作業で広く使用されています。

3.1 逐次蒸留の原理

プログレッシブ蒸留の目標は、通常は反復を繰り返すことにより、多くのステップを含む教師用拡散を、より少ないステップを含むスチューデント拡散に蒸留することです。各反復、Student企图1步学习Teacher模型2步的结果。蒸留を繰り返すたびに、生徒に必要なサンプル ステップの数が半分に減り、現在の生徒が次の教師になります。

ここに画像の説明を挿入します
上の図に示すように、教師拡散 f ( z , η ) f(z,\eta) f(z,η) 4 つの決定論的ステップ、スチューデント拡散を通じてランダム ノイズ ε をサンプル x にマッピングします f ( z , θ ) f(z,\theta) f(z,θ) このマッピング関係は、わずか 1 ステップで学習できます。

段階的蒸留法:

  1. 教師の拡散のトレーニング: 教師モデルは 标准Diffusion模型训练方法 を使用してトレーニングされ、その トレーニング損失関数が定義されます。ノイズとして ε 空間の平均二乗誤差:
    ここに画像の説明を挿入します
    関連する変数の定義:
    ここに画像の説明を挿入します
    : トレーニング () のために x から を直接予測することに加えて、 x と ε() を個別に予測し、それらを次のようにマージします: 、または ) を予測し、次のように計算します。(vx-parameterizationε-parameterization
    ここに画像の説明を挿入します
    v-parameterization
    ここに画像の説明を挿入します

  2. 漸進的蒸留学生拡散: 蒸留前に、教師拡散の重みで学生拡散を初期化します。モデル構造は同じです。 渐进蒸馏Diffusion方法标准Diffusion模型训练方法 の主な違いは、 去噪模型的 Label 值 の決定方法です。

    • 標準の拡散トレーニング: 拡散によってノイズ除去されたラベルはDDIM每个step的预定义好的Noise
    • 漸進蒸留拡散: スチューデント拡散ノイズ除去モデルが予測する必要があるラベルは Teacher模型预测的Noise です。そして、生徒の拡散は 1 ステップ予測ノイズ匹配 教師の拡散の 2 ステップ予測ノイズ、つまり生徒を使用しようとします。拡散は < /span> z t ' ' z_t^ {' 39;} にあります。教師拡散 2 ステップの予測ノイズです ε 空間のラベルt'' (ε-parameterization)。 再利用 z ˉ t ' = z t ' \bar z_t^{'' } = z_t^{''} ˉt'' =t'' x 空間(x-parameterization) に変換することもできます。
      ここに画像の説明を挿入します
      ここに画像の説明を挿入します

总结:传统Diffusion训练渐进蒸馏Diffusion
ここに画像の説明を挿入します

3.2 v パラメータ化

従来の拡散モデルはノイズ予測、つまりε-parameterization-predictionによってノイズ除去されていることは誰もが知っています。 速度予測 v-parameterization-predictionとは何ですか?なぜ速度予測を使用する必要があるのですか?

ノイズ予測に基づく従来の拡散モデルとは異なり、速度予測に基づく拡散モデルの出力は速度 v ^ θ \hat v_{\theta} ^θ 、対応する最適化目的関数は次のとおりです。
ここに画像の説明を挿入します
ここで、v は速度の真の値で、実際のサンプル x とノイズ レベルに応じたノイズ ε から計算できます。< /span>
ここに画像の説明を挿入します

拡散モデルの蒸留では、v-parameterization モデルのパフォーマンスが ε-parameterization よりも優れている傾向があるため、 a> に微調整されます。 ε-parameterizationv-parameterization

ここに画像の説明を挿入します

次に、xvzε (上の図と組み合わせる):
ここに画像の説明を挿入します
ここに画像の説明を挿入します
3 つのパラメータ化を要約します:
ここに画像の説明を挿入します

3.2 漸進的蒸留 cifar コードの実践

参考Colab:diffusion_distillation.ipynb

  1. Downloadコードとライブラリ、およびインポート ライブラリ:
!apt-get -qq install subversion
!svn checkout https://github.com/google-research/google-research/trunk/diffusion_distillation
!pip install -r diffusion_distillation/diffusion_distillation/requirements.txt --quiet

import os
import time
import requests
import functools
import jax
from jax import config
import jax.numpy as jnp
import flax
from matplotlib import pyplot as plt
import numpy as onp
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
from diffusion_distillation import diffusion_distillation
  1. TPU を使用するように 設定JAX: JAX は、CPU、GPU、TPU で実行できる Google のオープンソース numpy であり、機械学習用に設計されています。研究用の高性能自己微分コンピューティング加速フレームワーク。
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'
  resp = requests.post(url)
  time.sleep(5)
  TPU_DRIVER_MODE = 1
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)
  1. Train新しい普及モデル:
# create model
config = diffusion_distillation.config.cifar_base.get_config()
model = diffusion_distillation.model.Model(config)

# init params
state = jax.device_get(model.make_init_state())
state = flax.jax_utils.replicate(state)

# JIT compile training step
train_step = functools.partial(model.step_fn, jax.random.PRNGKey(0), True)
train_step = functools.partial(jax.lax.scan, train_step)  # for substeps
train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,))

# build input pipeline
total_bs = config.train.batch_size
device_bs = total_bs // jax.device_count()
train_ds = model.dataset.get_shuffled_repeated_dataset(
    split='train',
    batch_shape=(
        jax.local_device_count(),  # for pmap
        config.train.substeps,  # for lax.scan over multiple substeps
        device_bs,  # batch size per device
    ),
    local_rng=jax.random.PRNGKey(0),
    augment=True)
train_iter = diffusion_distillation.utils.numpy_iter(train_ds)

# run training
for step in range(10):
  batch = next(train_iter)
  state, metrics = train_step(state, batch)
  metrics = jax.device_get(flax.jax_utils.unreplicate(metrics))
  metrics = jax.tree_map(lambda x: float(x.mean(axis=0)), metrics)
  print(metrics)
  1. Distill訓練された拡散モデル:(核心)
# create model
config = diffusion_distillation.config.cifar_distill.get_config()
model = diffusion_distillation.model.Model(config)

# load the teacher params
model.load_teacher_state(config.distillation.teacher_checkpoint_path)

# init student state
init_params = diffusion_distillation.utils.copy_pytree(model.teacher_state.ema_params)
optim = model.make_optimizer_def().create(init_params)
state = diffusion_distillation.model.TrainState(
    step=model.teacher_state.step,
    optimizer=optim,
    ema_params=diffusion_distillation.utils.copy_pytree(init_params),
    num_sample_steps=model.teacher_state.num_sample_steps//2)
# build input pipeline
total_bs = config.train.batch_size
device_bs = total_bs // jax.device_count()
train_ds = model.dataset.get_shuffled_repeated_dataset(
    split='train',
    batch_shape=(
        jax.local_device_count(),  # for pmap
        config.train.substeps,  # for lax.scan over multiple substeps
        device_bs,  # batch size per device
    ),
    local_rng=jax.random.PRNGKey(0),
    augment=True)
train_iter = diffusion_distillation.utils.numpy_iter(train_ds)

steps_per_distill_iter = 10  # number of distillation steps per iteration of progressive distillation
end_num_steps = 4  # eventual number of sampling steps we want to use
while state.num_sample_steps >= end_num_steps:

  # compile training step
  train_step = functools.partial(model.step_fn, jax.random.PRNGKey(0), True)
  train_step = functools.partial(jax.lax.scan, train_step)  # for substeps
  train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,))

  # train the student against the teacher model
  print('distilling teacher using %d sampling steps into student using %d steps'
        % (model.teacher_state.num_sample_steps, state.num_sample_steps))
  state = flax.jax_utils.replicate(state)
  for step in range(steps_per_distill_iter):
    batch = next(train_iter)
    state, metrics = train_step(state, batch)
    metrics = jax.device_get(flax.jax_utils.unreplicate(metrics))
    metrics = jax.tree_map(lambda x: float(x.mean(axis=0)), metrics)
    print(metrics)

  # student becomes new teacher for next distillation iteration
  model.teacher_state = jax.device_get(
      flax.jax_utils.unreplicate(state).replace(optimizer=None))

  # reset student optimizer for next distillation iteration
  init_params = diffusion_distillation.utils.copy_pytree(model.teacher_state.ema_params)
  optim = model.make_optimizer_def().create(init_params)
  state = diffusion_distillation.model.TrainState(
      step=model.teacher_state.step,
      optimizer=optim,
      ema_params=diffusion_distillation.utils.copy_pytree(init_params),
      num_sample_steps=model.teacher_state.num_sample_steps//2)
  1. 蒸留されたモデルのチェックポイントをロードし、そこからサンプリングします。
# list all available distilled checkpoints
!gsutil ls gs://gresearch/diffusion-distillation

# create imagenet model
config = diffusion_distillation.config.imagenet64_base.get_config()
model = diffusion_distillation.model.Model(config)

# load distilled checkpoint for 8 sampling steps
loaded_params = diffusion_distillation.checkpoints.restore_from_path('gs://gresearch/diffusion-distillation/imagenet_8', target=None)['ema_params']

# fix possible flax version errors
ema_params = jax.device_get(model.make_init_state()).ema_params
loaded_params = flax.core.unfreeze(loaded_params)
loaded_params = jax.tree_map(
    lambda x, y: onp.reshape(x, y.shape) if hasattr(y, 'shape') else x,
    loaded_params,
    flax.core.unfreeze(ema_params))
loaded_params = flax.core.freeze(loaded_params)
del ema_params

# sample from the model
imagenet_classes = {
    
    'malamute': 249, 'siamese': 284, 'great_white': 2,
                    'speedboat': 814, 'reef': 973, 'sports_car': 817,
                    'race_car': 751, 'model_t': 661, 'truck': 867}
labels = imagenet_classes['truck'] * jnp.ones((4,), dtype=jnp.int32)
samples = model.samples_fn(rng=jax.random.PRNGKey(0), labels=labels, params=loaded_params, num_steps=8)
samples = jax.device_get(samples).astype(onp.uint8)

# visualize samples
padded_samples = onp.pad(samples, ((0,0), (1,1), (1,1), (0,0)), mode='constant', constant_values=255)
nrows = int(onp.sqrt(padded_samples.shape[0]))
ncols = padded_samples.shape[0]//nrows
_, height, width, channels = padded_samples.shape
img_grid = padded_samples.reshape(nrows, ncols, height, width, channels).swapaxes(1,2).reshape(height*nrows, width*ncols, channels)
img = plt.imshow(img_grid)
plt.axis('off')

(-0.5、131.5、131.5、-0.5)
ここに画像の説明を挿入します

元の Diffusion と比較して、蒸留された Diffusion は少ないステップで良好な生成品質の FID を取得できることがわかります。 (DDIM 采样器优化的stochastic随机采样器蒸馏)
ここに画像の説明を挿入します

おすすめ

転載: blog.csdn.net/weixin_54338498/article/details/134656534