記事ディレクトリ
1. 蒸留の基本概念
知識の蒸留は、模型压缩
と 迁移学习
で広く使用されています。先駆的な研究は、「ニューラル ネットワークでの知識の蒸留」 であるはずです。この記事における著者の動機は、方法を見つけることです。通常、トレーニングされた は、別の を教えるために使用されます。通常、モデル A はモデル B よりも強力です。モデル A の指導の下では、モデル B は独学よりもよく学習できます。 把多个模型的知识提炼给单个模型
Teacher Model A
Student Model B
方法: まず教師ネットワークをトレーニングし、次にこの teacher网络的输出
と 数据的真实标签
を 训练student网络
に使用します。知識の蒸留を使用すると、ネットワークを大規模ネットワークから小規模ネットワークに変換し、大規模ネットワークに近いパフォーマンスを維持できます。また、複数のネットワークの学習した知識を 1 つのネットワークに転送して、単一ネットワークのパフォーマンスを維持することもできます。の結果に似ています。
たとえば、次の画像分類タスクの場合:
-
従来のトレーニング: Teacher ネットワークがない場合、データのみが Student ネットワークを通過します。ソフトマックスの後、確率分布値 q が出力され、
q
label p
による Cross_entropy loss の検出はHard loss
と呼ばれます。これは、この p が実数値のワンホット ベクトルであり、より近い値になることが期待されるためです。 q と p があればあるほど良いです。 -
知識の蒸留: 教師の協力がある場合、生徒と教師のネットワークから損失が生じます。そして、教師による
q'
出力は带温度的Softmax
を通過する必要があります (スムーズにするため、考え方はlabel smooth
と似ています)。 < a i=5> に続いて を実行して損失を計算します。 = + < a i=9>。 L = α ⋅ ハード _ L o s s + ( 1 − α ) ⋅ ソフト t _ L o s s = α ⋅ C E ( p , q ) + ( 1 − α ) ⋅ C E ( q ' ' , q ) L=\alpha\cdot Hard\_Loss+(1-\alpha)\cdot Soft\_Loss=\alpha\cdot CE(p,q) + (1-\alpha)\cdot CE(q' ',q)q''
q
总loss
Teacher q'' 和 Student q 的 loss
Student q 和 label p 的 loss
L=ある⋅Hard_ Loss+(1−α)⋅Soft_ Loss=ある⋅CE(p,q)+(1−α)⋅CE(q''、q) -
SoftMax问题:
普通的Cross Entropy Loss
是由NLL Loss
、Log
、Softmax
组成的:
F.cross_entropy(p,target)) = F.nll_loss(torch.log(torch.softmax(p)), target)
このcross_entropy 損失のソフトマックスは、実際にはそれほどソフトではありません。出力確率分布は、 正しいカテゴリの信頼度は高くなりますが、他のカテゴリの確率はほぼ 0 です。この場合、教師ネットワークによって学習されたデータの類似情報 (たとえば、数字 2、3、7 が非常に似ている、この種のソフトラベル情報) を生徒ネットワークに伝えるのは困難です。 したがって、この記事では温度係数 T を持つソフトマックスを提案しています。 ここで q i q_i Softmax-T
qi は生徒のネットワーク学習 (ソフト ターゲット) のオブジェクトです。 z i z_i とi ニューラル ネットワークのソフトマックス前の出力ロジットです。 T が 1 に設定されている場合、この式はソフトマックスとなり、ロジットに基づいて各カテゴリの確率を出力します。 T が 0 に近い場合、最大値は 1 に近づき、その他の値は 0 に近づきます。これは、ワンホット エンコーディングと似ています。 Tが大きいほど出力結果の分布が平坦になるため、平滑化に相当し、同様の情報を保持する役割を果たします。 T が無限大に等しい場合、それは一様分布です。
比較Softmax
(上) と Softmax-T
(下) モデル予測結果の確率分布の視覚化:
最終的な蒸留損失は、元の CE 損失の Soft Loss
の Softmax
を Softmax_T
に置き換えて KD 損失を取得します。 :
K D _ L o s s = α ⋅ H a rd _ L o s s + ( 1 − α ) ⋅ S o f t _ L o s s = α ⋅ C E ( p , q ) + ( 1 − α ) ⋅ C E ( q '' , q ) KD\_Loss=\alpha\cdot Hard\_Loss+(1-\alpha)\cdot Soft\_Loss=\alpha\cdot CE(p,q) + (1-\alpha)\ cdot CE(q'',q) KD_Loss=ある⋅Hard_ Loss+(1−α)⋅Soft_ Loss=ある⋅CE(p,q)+(1−α)⋅CE(q''、q)
p は実際のラベル ラベル、q は生徒の出力、q'' は教師の出力です。
def distillation_loss(y,labels,teacher_scores,temp,alpha):
soft_loss = nn.KLDivLoss()(F.log_softmax(y/temp, dim=1), F.softmax(teacher_scores/temp,dim=1))
hard_loss = F.cross_entropy(y,labels)
return soft_loss *(temp*temp*2.0*alpha) + hard_loss *(1. - alpha)
2. 蒸留 MNIST CNN 分類コードの実践
- ライブラリのインポート:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import torchvision
from torchvision import transforms
- 教師 CNN モデル (
big
) と生徒 CNN モデル (small
) を定義します:
class TeacherModel(nn.Module):
def __init__(self, in_channels=1, num_classes=10):
super(TeacherModel, self).__init__()
self.relu = nn.ReLU()
self.fc1 = nn.Linear(784, 1200)
self.fc2 = nn.Linear(1200, 1200)
self.fc3 = nn.Linear(1200, num_classes)
self.dropout = nn.Dropout(p=0.5)
def forward(self, x):
x = x.view(-1, 784)
x = self.relu(self.dropout(self.fc1(x)))
x = self.relu(self.dropout(self.fc2(x)))
x = self.fc3(x)
return x
class StudentModel(nn.Module):
def __init__(self, in_channels=1, num_classes=10):
super(StudentModel, self).__init__()
self.relu = nn.ReLU()
self.fc1 = nn.Linear(784, 20)
self.fc2 = nn.Linear(20, 20)
self.fc3 = nn.Linear(20, num_classes)
self.dropout = nn.Dropout(p=0.5)
def forward(self, x):
x = x.view(-1, 784)
x = self.relu(self.dropout(self.fc1(x)))
x = self.relu(self.dropout(self.fc2(x)))
x = self.fc3(x)
return x
- 教師モデルのトレーニングの機能:
def teacher(device, train_loader, test_loader):
print('--------------teachermodel start--------------')
model = TeacherModel()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 6
for epoch in range(epochs):
model.train()
for data, target in tqdm(train_loader):
data = data.to(device)
target = target.to(device)
preds = model(data)
loss = criterion(preds, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
model.eval()
num_correct = 0
num_samples = 0
with torch.no_grad():
for x, y in test_loader:
x = x.to(device)
y = y.to(device)
preds = model(x)
predictions = preds.max(1).indices
num_correct += (predictions.eq(y)).sum().item()
num_samples += predictions.size(0)
acc = num_correct / num_samples
model.train()
print('Epoch:{}\t Acc:{:.4f}'.format(epoch + 1, acc))
torch.save(model, 'teacher.pkl')
print('--------------teachermodel end--------------')
- 学生モデルの機能
independently
:
def student(device, train_loader, test_loader):
print('--------------studentmodel start--------------')
model = StudentModel()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 3
for epoch in range(epochs):
model.train()
for data, target in tqdm(train_loader):
data = data.to(device)
target = target.to(device)
preds = model(data)
loss = criterion(preds, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
model.eval()
num_correct = 0
num_samples = 0
with torch.no_grad():
for x, y in test_loader:
x = x.to(device)
y = y.to(device)
# print(y)
preds = model(x)
# print(preds)
predictions = preds.max(1).indices
# print(predictions)
num_correct += (predictions.eq(y)).sum().item()
num_samples += predictions.size(0)
acc = num_correct / num_samples
model.train()
print('Epoch:{}\t Acc:{:.4f}'.format(epoch + 1, acc))
print('--------------studentmodel prediction end--------------')
- 教師モデルから生徒モデルへの
Distilling
の機能:(核心)
def kd(teachermodel, device, train_loader, test_loader):
print('--------------kdmodel start--------------')
teachermodel.eval()
studentmodel = StudentModel()
studentmodel = studentmodel.to(device)
studentmodel.train()
temp = 7 #蒸馏温度
alpha = 0.3
hard_loss = nn.CrossEntropyLoss()
soft_loss = nn.KLDivLoss(reduction='batchmean')
optimizer = torch.optim.Adam(studentmodel.parameters(), lr=1e-4)
epochs = 20
for epoch in range(epochs):
for data, target in tqdm(train_loader):
data = data.to(device)
target = target.to(device)
with torch.no_grad():
teacher_preds = teachermodel(data)
student_preds = studentmodel(data)
student_loss = hard_loss(student_preds, target) #hard_loss
distillation_loss = soft_loss(
F.log_softmax(student_preds / temp, dim=1),
F.softmax(teacher_preds / temp, dim=1)
) #soft_loss
loss = alpha * student_loss + (1 - alpha) * distillation_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
studentmodel.eval()
num_correct = 0
num_samples = 0
with torch.no_grad():
for x, y in test_loader:
x = x.to(device)
y = y.to(device)
preds = studentmodel(x)
predictions = preds.max(1).indices
num_correct += (predictions.eq(y)).sum().item()
num_samples += predictions.size(0)
acc = num_correct / num_samples
studentmodel.train()
print('Epoch:{}\t Acc:{:.4f}'.format(epoch + 1, acc))
print('--------------kdmodel end--------------')
- 主な機能 (データのロード、トレイン機能の実装):
if __name__ == '__main__':
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available else "cpu")
torch.backends.cudnn.benchmark = True
#加载数据集
X_train = torchvision.datasets.MNIST(
root="dataset/",
train=True,
transform=transforms.ToTensor(),
download=True
)
X_test = torchvision.datasets.MNIST(
root="dataset/",
train=False,
transform=transforms.ToTensor(),
download=True
)
train_loader = DataLoader(dataset=X_train, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=X_test, batch_size=32, shuffle=False)
#从头训练教师模型,并预测
teacher(device, train_loader, test_loader)
#从头训练学生模型,并预测
student(device, train_loader, test_loader)
#知识蒸馏训练学生模型
model = torch.load('teacher.pkl')
kd(model, device, train_loader, test_loader)
Teacher Mdeol
、Stuent Model without Distillation
、Stuent Model with Distillation
の精度を比較した最終トレーニング結果: < /span> ①教師蒸留を使用して訓練された生徒は、単独で訓練された生徒よりも強いです。 ②実際の場面では、生徒自身が教師よりも著しく弱い場合が多く、教師の成績を超えることは困難です。
Teacher Mdeol
:エポック:3 精度:0.9689,エポック:6 精度:0.9764Stuent Model without Distillation
:エポック:3 アクセス:0.8173Stuent Model with Distillation
:エポック:3 精度:0.8387、エポック:20 精度:0.9015
3. 漸進蒸留拡散コード生成の実践
段階的蒸留により拡散モデルのサンプリング ステップ数を削減する方法。主な内容は次のとおりです。progressive distillation
、guided diffusion distillation
、step distillation
、Data-free Distillation
、Latent Consistency Models
。
このセクションでは主に順蒸留について説明します。この記事で提案する方法はv-parameterization
次のような用途で使用されるためです。その後の拡散 Imagen video、Imagen、Stable Diffusion、Dall E など、推論を高速化する作業で広く使用されています。
3.1 逐次蒸留の原理
プログレッシブ蒸留の目標は、通常は反復を繰り返すことにより、多くのステップを含む教師用拡散を、より少ないステップを含むスチューデント拡散に蒸留することです。各反復、Student企图1步学习Teacher模型2步的结果
。蒸留を繰り返すたびに、生徒に必要なサンプル ステップの数が半分に減り、現在の生徒が次の教師になります。
上の図に示すように、教師拡散 f ( z , η ) f(z,\eta) f(z,η) 4 つの決定論的ステップ、スチューデント拡散を通じてランダム ノイズ ε をサンプル x にマッピングします f ( z , θ ) f(z,\theta) f(z,θ) このマッピング関係は、わずか 1 ステップで学習できます。
段階的蒸留法:
-
教師の拡散のトレーニング: 教師モデルは
标准Diffusion模型训练方法
を使用してトレーニングされ、その トレーニング損失関数が定義されます。ノイズとして ε 空間の平均二乗誤差:
関連する変数の定義:
注 : トレーニング () のために x から を直接予測することに加えて、 x と ε() を個別に予測し、それらを次のようにマージします: 、または ) を予測し、次のように計算します。(vx-parameterization
ε-parameterization
v-parameterization
-
漸進的蒸留学生拡散: 蒸留前に、教師拡散の重みで学生拡散を初期化します。モデル構造は同じです。
渐进蒸馏Diffusion方法
と标准Diffusion模型训练方法
の主な違いは、去噪模型的 Label 值
の決定方法です。- 標準の拡散トレーニング: 拡散によってノイズ除去されたラベルは
DDIM每个step的预定义好的Noise
、 - 中漸進蒸留拡散: スチューデント拡散ノイズ除去モデルが予測する必要があるラベルは
Teacher模型预测的Noise
です。そして、生徒の拡散は 1 ステップ予測ノイズ匹配
教師の拡散の 2 ステップ予測ノイズ、つまり生徒を使用しようとします。拡散は < /span> z t ' ' z_t^ {' 39;} にあります。教師拡散 2 ステップの予測ノイズです ε 空間のラベルとt'' (ε-parameterization
)。 再利用 z ˉ t ' = z t ' \bar z_t^{'' } = z_t^{''} とˉt'' =とt'' は x 空間(x-parameterization
) に変換することもできます。
- 標準の拡散トレーニング: 拡散によってノイズ除去されたラベルは
总结:传统Diffusion训练
和 渐进蒸馏Diffusion
3.2 v パラメータ化
従来の拡散モデルはノイズ予測、つまりε-parameterization-prediction
によってノイズ除去されていることは誰もが知っています。 速度予測 v-parameterization-prediction
とは何ですか?なぜ速度予測を使用する必要があるのですか?
ノイズ予測に基づく従来の拡散モデルとは異なり、速度予測に基づく拡散モデルの出力は速度 v ^ θ \hat v_{\theta} で^θ 、対応する最適化目的関数は次のとおりです。
ここで、v は速度の真の値で、実際のサンプル x とノイズ レベルに応じたノイズ ε から計算できます。< /span>
拡散モデルの蒸留では、v-parameterization
モデルのパフォーマンスが ε-parameterization
よりも優れている傾向があるため、 a> に微調整されます。 ε-parameterization
は v-parameterization
次に、x
、v
、z
、ε
(上の図と組み合わせる):
3 つのパラメータ化を要約します:
3.2 漸進的蒸留 cifar コードの実践
参考Colab:diffusion_distillation.ipynb
Download
コードとライブラリ、およびインポート ライブラリ:
!apt-get -qq install subversion
!svn checkout https://github.com/google-research/google-research/trunk/diffusion_distillation
!pip install -r diffusion_distillation/diffusion_distillation/requirements.txt --quiet
import os
import time
import requests
import functools
import jax
from jax import config
import jax.numpy as jnp
import flax
from matplotlib import pyplot as plt
import numpy as onp
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
from diffusion_distillation import diffusion_distillation
- TPU を使用するように 設定
JAX
:JAX
は、CPU、GPU、TPU で実行できる Google のオープンソース numpy であり、機械学習用に設計されています。研究用の高性能自己微分コンピューティング加速フレームワーク。
if 'TPU_DRIVER_MODE' not in globals():
url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'
resp = requests.post(url)
time.sleep(5)
TPU_DRIVER_MODE = 1
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)
Train
新しい普及モデル:
# create model
config = diffusion_distillation.config.cifar_base.get_config()
model = diffusion_distillation.model.Model(config)
# init params
state = jax.device_get(model.make_init_state())
state = flax.jax_utils.replicate(state)
# JIT compile training step
train_step = functools.partial(model.step_fn, jax.random.PRNGKey(0), True)
train_step = functools.partial(jax.lax.scan, train_step) # for substeps
train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,))
# build input pipeline
total_bs = config.train.batch_size
device_bs = total_bs // jax.device_count()
train_ds = model.dataset.get_shuffled_repeated_dataset(
split='train',
batch_shape=(
jax.local_device_count(), # for pmap
config.train.substeps, # for lax.scan over multiple substeps
device_bs, # batch size per device
),
local_rng=jax.random.PRNGKey(0),
augment=True)
train_iter = diffusion_distillation.utils.numpy_iter(train_ds)
# run training
for step in range(10):
batch = next(train_iter)
state, metrics = train_step(state, batch)
metrics = jax.device_get(flax.jax_utils.unreplicate(metrics))
metrics = jax.tree_map(lambda x: float(x.mean(axis=0)), metrics)
print(metrics)
Distill
訓練された拡散モデル:(核心)
# create model
config = diffusion_distillation.config.cifar_distill.get_config()
model = diffusion_distillation.model.Model(config)
# load the teacher params
model.load_teacher_state(config.distillation.teacher_checkpoint_path)
# init student state
init_params = diffusion_distillation.utils.copy_pytree(model.teacher_state.ema_params)
optim = model.make_optimizer_def().create(init_params)
state = diffusion_distillation.model.TrainState(
step=model.teacher_state.step,
optimizer=optim,
ema_params=diffusion_distillation.utils.copy_pytree(init_params),
num_sample_steps=model.teacher_state.num_sample_steps//2)
# build input pipeline
total_bs = config.train.batch_size
device_bs = total_bs // jax.device_count()
train_ds = model.dataset.get_shuffled_repeated_dataset(
split='train',
batch_shape=(
jax.local_device_count(), # for pmap
config.train.substeps, # for lax.scan over multiple substeps
device_bs, # batch size per device
),
local_rng=jax.random.PRNGKey(0),
augment=True)
train_iter = diffusion_distillation.utils.numpy_iter(train_ds)
steps_per_distill_iter = 10 # number of distillation steps per iteration of progressive distillation
end_num_steps = 4 # eventual number of sampling steps we want to use
while state.num_sample_steps >= end_num_steps:
# compile training step
train_step = functools.partial(model.step_fn, jax.random.PRNGKey(0), True)
train_step = functools.partial(jax.lax.scan, train_step) # for substeps
train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,))
# train the student against the teacher model
print('distilling teacher using %d sampling steps into student using %d steps'
% (model.teacher_state.num_sample_steps, state.num_sample_steps))
state = flax.jax_utils.replicate(state)
for step in range(steps_per_distill_iter):
batch = next(train_iter)
state, metrics = train_step(state, batch)
metrics = jax.device_get(flax.jax_utils.unreplicate(metrics))
metrics = jax.tree_map(lambda x: float(x.mean(axis=0)), metrics)
print(metrics)
# student becomes new teacher for next distillation iteration
model.teacher_state = jax.device_get(
flax.jax_utils.unreplicate(state).replace(optimizer=None))
# reset student optimizer for next distillation iteration
init_params = diffusion_distillation.utils.copy_pytree(model.teacher_state.ema_params)
optim = model.make_optimizer_def().create(init_params)
state = diffusion_distillation.model.TrainState(
step=model.teacher_state.step,
optimizer=optim,
ema_params=diffusion_distillation.utils.copy_pytree(init_params),
num_sample_steps=model.teacher_state.num_sample_steps//2)
- 蒸留されたモデルのチェックポイントをロードし、そこからサンプリングします。
# list all available distilled checkpoints
!gsutil ls gs://gresearch/diffusion-distillation
# create imagenet model
config = diffusion_distillation.config.imagenet64_base.get_config()
model = diffusion_distillation.model.Model(config)
# load distilled checkpoint for 8 sampling steps
loaded_params = diffusion_distillation.checkpoints.restore_from_path('gs://gresearch/diffusion-distillation/imagenet_8', target=None)['ema_params']
# fix possible flax version errors
ema_params = jax.device_get(model.make_init_state()).ema_params
loaded_params = flax.core.unfreeze(loaded_params)
loaded_params = jax.tree_map(
lambda x, y: onp.reshape(x, y.shape) if hasattr(y, 'shape') else x,
loaded_params,
flax.core.unfreeze(ema_params))
loaded_params = flax.core.freeze(loaded_params)
del ema_params
# sample from the model
imagenet_classes = {
'malamute': 249, 'siamese': 284, 'great_white': 2,
'speedboat': 814, 'reef': 973, 'sports_car': 817,
'race_car': 751, 'model_t': 661, 'truck': 867}
labels = imagenet_classes['truck'] * jnp.ones((4,), dtype=jnp.int32)
samples = model.samples_fn(rng=jax.random.PRNGKey(0), labels=labels, params=loaded_params, num_steps=8)
samples = jax.device_get(samples).astype(onp.uint8)
# visualize samples
padded_samples = onp.pad(samples, ((0,0), (1,1), (1,1), (0,0)), mode='constant', constant_values=255)
nrows = int(onp.sqrt(padded_samples.shape[0]))
ncols = padded_samples.shape[0]//nrows
_, height, width, channels = padded_samples.shape
img_grid = padded_samples.reshape(nrows, ncols, height, width, channels).swapaxes(1,2).reshape(height*nrows, width*ncols, channels)
img = plt.imshow(img_grid)
plt.axis('off')
(-0.5、131.5、131.5、-0.5)
元の Diffusion と比較して、蒸留された Diffusion は少ないステップで良好な生成品質の FID を取得できることがわかります。 (DDIM 采样器
対 优化的stochastic随机采样器
対 蒸馏
)