知識の蒸留の中心的な考え方は、大きなモデルの知識を小さなモデルに転送することです。
ここでの知識は通常、モデルが学習したデータ分布です。一般に、大規模なモデルは非常に高い精度を特徴としますが、速度が十分でない場合や展開が容易ではない場合があり、小規模なモデルは通常、展開が簡単で高速ですが、大規模モデルほど正確ではありません。
したがって、大規模モデルが(厳密な意味ではなく、たとえば)グラウンドトゥルースとみなすことができ、大規模モデルと小規模モデルの間の需給ギャップは継続的に縮小します。したがって、大きなモデルを教師として使用し、小さなモデルを生徒として使用することができ、生徒は教師の指導の下で学習します。
トレーニングの観点から、知識の蒸留はオフラインの蒸留とオンラインの蒸留に分けることができ、前者は、訓練を受けた教師ネットワークと学生ネットワークの間の関係を確立して蒸留学習を継続的に行うものです。以下に例を示します。修正したモデルを教師として使用し、改良を加えていないモデルを生徒として使用し、両者の間でオフラインの蒸留関数を作成できます。たとえば、モデル内の特定のレイヤーを教師として使用し、別のレイヤーを生徒として使用する場合、この 2 つは損失関数を確立しながら自ら学習することになります。これがオンライン手法です。
蒸留方法から論理蒸留と特性蒸留に分けることもできます。前者はモデル出力のロジスティック回帰を抽出するもので、後者はモデル内のフィーチャ レイヤーを抽出するものです。たとえば、前者は主に 2 つのモデルの出力ラベルを測定しますが、後者は 2 つのフィーチャ レイヤー間の距離を短縮できます。
オフライン蒸留については、私の他の記事「分類ネットワークの知識の蒸留」を参照してください。
この記事はオンライン蒸留であり、Resnet を例として、主に論理蒸留と特徴蒸留を含むコードを詳細に説明します。コンテンツがたくさんあります。シートベルトを締めて、さあ行きましょう~
この記事を学習するときは、Resnet コードを深く理解している必要があります。この記事を学習できるように、Resnet コードの学習の準備もここで行いました: Resnet コードの学習
目次
ネットワーク定義
元の Resnet コードとは異なり、ここではいくつかの変更が加えられています。
1. Resnet の最初の畳み込み層 conv1 は 7x7 畳み込みですが、ここでは 3x3 に変更されます。
# conv1与原始Resnet不同,原始Resnet为7x7卷积
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1,
bias=False)
2. conv1 とlayer1 の間の最大プーリング層が削除されます。
# 最大池化,不过在forward中没有用到
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
残差ブロックlayer1~layer4のコードは変更されず、_make_layer関数が呼び出され続けます。
3. 元の Resnet コードと比較すると、ここでは適応プーリング層と完全接続層が削除されています。4 つのアテンション レイヤーと 4 つのスケーリング レイヤー (scala) に変更します。
scalaスケーリングレイヤー
scala 層は主に、特徴抽出における特徴層のスケーリングに使用されます。コードは次のとおりです (ここでは例として scala1 のみが使用されています)。scala1 は、定義された 3 つの SepConv 畳み込み層と平均プーリング層で構成されます。[scala2 は 2 SepConv+AvgP、scala3 は 1 SepConv+AvgP、scala4 は 1 AvgP]。
self.scala1 = nn.Sequential(
# 输入通道64*4=256,输出通道128*4=512
SepConv( # 尺寸减半
channel_in=64 * block.expansion,
channel_out=128 * block.expansion
),
# 输入通道128*4=512, 输出通道256*4=1024
SepConv( # 尺寸减半
channel_in=128 * block.expansion,
channel_out=256 * block.expansion
),
# 输入通道256*4=1024,输出通道512*4=2048
SepConv( # 尺寸减半
channel_in=256 * block.expansion,
channel_out=512 * block.expansion
),
# 平均池化
nn.AvgPool2d(4, 4)
)
定義された SepConv 畳み込みコードは次のとおりです。
コンボリューションは、ステップ サイズ 2 の 3x3 グループ コンボリューション、1x1 コンボリューション、BN、ReLu、1 グループ コンボリューションの 3x3 ステップ サイズ、1x1 コンボリューション、BN、ReLu で構成されます。[または、深さ分離可能な 2 つの畳み込みから構成されていると理解できます]
class SepConv(nn.Module):
def __init__(self, channel_in, channel_out, kernel_size=3, stride=2, padding=1, affine=True):
super(SepConv, self).__init__()
self.op = nn.Sequential(
# 分组卷积,这里的分组数=输入通道数,那么每个group=channel_in/channel_in=1个通道,就是每个通道进行一个卷积
nn.Conv2d(channel_in, channel_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=channel_in, bias=False),
nn.Conv2d(channel_in, channel_in, kernel_size=1, padding=0, bias=False),
# affine 设为 True 时,BatchNorm 层才会学习参数 gamma 和 beta,否则不包含这两个变量,变量名是 weight 和 bias。
nn.BatchNorm2d(channel_in, affine=affine),
nn.ReLU(inplace=False),
# 分组卷积
nn.Conv2d(channel_in, channel_in, kernel_size=kernel_size, stride=1, padding=padding, groups=channel_in, bias=False),
nn.Conv2d(channel_in, channel_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(channel_out, affine=affine),
nn.ReLU(inplace=False),
)
def forward(self, x):
'''
x-->conv_3x3_s2(分组卷积)-->conv_1x1-->bn-->relu-->conv_3x3(分组卷积)-->conv_1x1-->bn-->relu-->out
'''
return self.op(x)
SepConv の構造は次のとおりです。SepConv は特徴マップのサイズを半分にし、出力チャネルの数は入力の2 倍になります。
最終的な scala1 構造は次のとおりです (要約すると、各 scala はフィーチャ レイヤーを [batchsize, 2048, 7, 7] の形状にスケーリングします)。
注目層
アテンション層はSepConv層、BN層、ReLu、アップサンプル、シグモイドから構成されます。
注意メカニズム。
コードは以下のように表示されます。
self.attention1 = nn.Sequential(
SepConv( # 尺寸减半
channel_in=64 * block.expansion, # 256
channel_out=64 * block.expansion # 256
),
nn.BatchNorm2d(64 * block.expansion),
nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear'), # 恢复原来尺寸
nn.Sigmoid()
)
上記はネットワーク内のさまざまなモジュールであり、完全な Resnet コードは次のとおりです。
自己蒸留用のResNetネットワーク
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=100, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
# 空洞卷积定义
self.dilation = 1
# 是否用空洞卷积代替步长,如果不采用空洞卷积,均为False
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups # 分组卷积分组数
self.base_width = width_per_group # 卷积宽度
# conv1与原始Resnet不同,原始Resnet为7x7卷积
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1,
bias=False)
# bn层
self.bn1 = norm_layer(self.inplanes)
# relu激活函数
self.relu = nn.ReLU(inplace=True)
# 最大池化,不过在forward中没有用到
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0]) # 尺寸不变
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0]) # 尺寸减半
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1]) # 尺寸减半
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2]) # 尺寸减半
'''
此处和原Resnet不同,原Resnet这里是自适应平均池化,然后接一个全连接层。
scala层的作用是对特征层的H,W做缩放处理,因为要和深层网络中其他Bottleneck输出特征层之间做loss
'''
self.scala1 = nn.Sequential(
# 输入通道64*4=256,输出通道128*4=512
SepConv( # 尺寸减半
channel_in=64 * block.expansion,
channel_out=128 * block.expansion
),
# 输入通道128*4=512, 输出通道256*4=1024
SepConv( # 尺寸减半
channel_in=128 * block.expansion,
channel_out=256 * block.expansion
),
# 输入通道256*4=1024,输出通道512*4=2048
SepConv( # 尺寸减半
channel_in=256 * block.expansion,
channel_out=512 * block.expansion
),
# 平均池化
nn.AvgPool2d(4, 4)
)
self.scala2 = nn.Sequential(
# 输入通道128*4=512,输出通道1024
SepConv(
channel_in=128 * block.expansion,
channel_out=256 * block.expansion,
),
# 输入通道256*4=1024,输出通道512*4=2048
SepConv(
channel_in=256 * block.expansion,
channel_out=512 * block.expansion,
),
# 平均池化
nn.AvgPool2d(4, 4)
)
self.scala3 = nn.Sequential(
# 输入通道256*4=1024,输出通道512*4=2048
SepConv(
channel_in=256 * block.expansion,
channel_out=512 * block.expansion,
),
# 平均池化
nn.AvgPool2d(4, 4)
)
# 平均池化
self.scala4 = nn.AvgPool2d(4, 4)
self.attention1 = nn.Sequential(
SepConv( # 尺寸减半
channel_in=64 * block.expansion, # 256
channel_out=64 * block.expansion # 256
), # 比输入前大两个像素
nn.BatchNorm2d(64 * block.expansion),
nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear'), # 恢复原来尺寸
nn.Sigmoid()
)
self.attention2 = nn.Sequential(
SepConv(
channel_in=128 * block.expansion,
channel_out=128 * block.expansion
),
nn.BatchNorm2d(128 * block.expansion),
nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Sigmoid()
)
self.attention3 = nn.Sequential(
SepConv(
channel_in=256 * block.expansion,
channel_out=256 * block.expansion
),
nn.BatchNorm2d(256 * block.expansion),
nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Sigmoid()
)
self.fc1 = nn.Linear(512 * block.expansion, num_classes)
self.fc2 = nn.Linear(512 * block.expansion, num_classes)
self.fc3 = nn.Linear(512 * block.expansion, num_classes)
self.fc4 = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
# 残差边采用1x1卷积升维条件,即当步长不为1或者输入通道数不等于输出通道数的时候
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
# layers用来存储每个当前残差层的所有残差块
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
# 仅在第一个bottleneck采用1x1进行升维,其他的bottleneck是直接输入和输出相加
return nn.Sequential(*layers)
def forward(self, x):
# 以x = (1,3,224,224)为例
feature_list = []
x = self.conv1(x) # get 1,64,224,224
x = self.bn1(x)
x = self.relu(x)
x = self.layer1(x) # conv2_x 输出256通道 1,256,224,224
fea1 = self.attention1(x) # 输出通道为256 224,224
fea1 = fea1 * x
feature_list.append(fea1)
x = self.layer2(x) # conv3_x 1,512,112,112
fea2 = self.attention2(x) # 512,112,112
fea2 = fea2 * x
feature_list.append(fea2)
x = self.layer3(x) # conv4_x 1,1024,56,56
fea3 = self.attention3(x) # 1024,56,56
fea3 = fea3 * x
feature_list.append(fea3)
x = self.layer4(x) # conv5_x 最深层网络 1,2048,28,28
feature_list.append(x)
# feature_list[0].shape is [1,256 224,224] scala1 shape is [1,2048,7,7] view is [1,7*7*2048]
out1_feature = self.scala1(feature_list[0]).view(x.size(0), -1) # # 得到新的特征图 对应到论文中的Bottleneck1
# feature_list[1].shape is [1,512,112,112], scala2 shape is [1,2048,7,7] view is [1,7*7*2048]
out2_feature = self.scala2(feature_list[1]).view(x.size(0), -1) # 得到新的特征图 对应到论文中的Bottleneck2
# feature_list[2].shape is [1,1024,56,56],scala3 shape is [1,2048,7,7] view is [1,7*7*2048]
out3_feature = self.scala3(feature_list[2]).view(x.size(0), -1) # 得到新的特征图 对应到论文中的Bottleneck3
# feature_list[3].shape is [1,2048,28,28],scala4 shape is [1,2048,7,7], view is [1,2048*7*7]
out4_feature = self.scala4(feature_list[3]).view(x.size(0), -1) # conv5_x 最深层网络
out1 = self.fc1(out1_feature)
out2 = self.fc2(out2_feature)
out3 = self.fc3(out3_feature)
out4 = self.fc4(out4_feature)
# 返回的特征层分别是经过全连接和不仅过全连接的
return [out4, out3, out2, out1], [out4_feature, out3_feature, out2_feature, out1_feature]
上記のコードにより、ネットワーク構成図は以下を参照することができます。
このように説明できます. ここでは例として入力サイズを 224x224x3 とします. Layer1 は残差ブロックです. 最初の Layer1 を除いて, 特徴層 H と W は変更されません. 他の層が出力された後, HそしてWは半分になります。同時に各層の下に注目層が存在します(layer4を除く) Att_featは注目を獲得した後の特徴マップであり、scalaを通してサイズ[batch_szie,2048,7,7]に固定され、最終的に出力されます。 FC。
コードの観点からは、次の 2 つの部分が返されます: 1. FC 層の後、2. FC 層なしの出力
知識蒸留トレーニング
以上でネットワークの定義は完了しました。トレーニング部分のコードの詳細な説明を見てみましょう。
入力は画像であり、ラベルは対応するラベルです。
net は、前に定義した Resnet ネットワークです。出力には 2 つの部分 (前述) があり、出力は FC ですが、outputs_feature は FC ではありません。前者は論理出力、後者は機能出力です。
inputs, labels = data # inputs是图片,labels是对应标签
inputs, labels = inputs.to(device), labels.to(device)
outputs, outputs_feature = net(inputs) # 获得4个分类特征层,outputs是经过fc层的,outputs_feature是仅缩放后的特征层
ここでの教師は最も深いネットワーク、つまり構造図の Layer4 であり、ここで取得される Teacher_feature_size は 2048*7*7 です [例として初期入力サイズ 224x224 を使用します]
layer_list = []
teacher_feature_size = outputs_feature[0].size(1)
次のループは各生徒層を取得するため、インデックスインデックスは 1 から始まります。ここのstudent_feature_sizeも2048*7*7です。
for index in range(1, len(outputs_feature)):
student_feature_size = outputs_feature[index].size(1) # 取浅层的三个特征层(没有经过FC)
layer_list.append(nn.Linear(student_feature_size, teacher_feature_size))
ここでの出力は FC の分類出力であり、outputs[0] は Layer4 です。損失関数はクロスエントロピーです。
# for deepest classifier hard loss
loss += criterion(outputs[0], labels)
Teacher_output は最も深い Layer4 [FC 層の後]、これをロジックの蒸留に使用します。Teacher_feature は最も深い Layer4 [FC なし]、これを機能の蒸留に使用します。
teacher_output = outputs[0].detach() # 取出最深层特征层
teacher_feature = outputs_feature[0].detach() # 取出最深层特征层(没有经过FC)
浅い論理出力をトラバースします。
1. Layer4[教師の論理出力]と各生徒の論理出力を損失関数として使用します。これは、論理蒸留損失、ソフト損失です。
2. 生徒とラベルの喪失をハードロスとして使用します。
3. 各生徒と教師の間の特徴距離 (特徴抽出、ソフト損失) を測定します。
上記の方法により、論理損失 + 生徒自身のハード損失 + 特徴抽出損失の 3 つの損失部分が発生します。
# for shallow classifiers
for index in range(1, len(outputs)):
# logits distillation 对分类输出最soft_loss
# 逻辑蒸馏,将教师网络的输出和每个浅层学生网络之间做逻辑蒸馏,Loss source2
loss += CrossEntropy(outputs[index], teacher_output) * args.loss_coefficient # KL_loss soft loss
# loss source1
loss += criterion(outputs[index], labels) * (1 - args.loss_coefficient) # hard loss 学生自己的
# feature distillation hint蒸馏
# 特征蒸馏,loss source3
if index != 1:
loss += torch.dist(net.adaptation_layers[index-1](outputs_feature[index]), teacher_feature) * \
args.feature_loss_coefficient
# the feature distillation loss will not be applied to the shallowest classifier
コードは次のとおりです。
if __name__ == "__main__":
# 记录最高准确率
best_acc = 0
# 开始训练
for epoch in range(args.epoch):
# [0,0,0,0,0]
correct = [0 for _ in range(5)]
# [0,0,0,0,0]
predicted = [0 for _ in range(5)]
# 学习率衰减
if epoch in [args.epoch // 3, args.epoch * 2 // 3, args.epoch - 10]:
for param_group in optimizer.param_groups:
param_group['lr'] /= 10
# train
net.train()
sum_loss, total = 0.0, 0.0
# 数据集的加载
for i, data in enumerate(trainloader, 0):
length = len(trainloader) # 获取数据集长度
inputs, labels = data # inputs是图片,labels是对应标签
inputs, labels = inputs.to(device), labels.to(device)
outputs, outputs_feature = net(inputs) # 获得4个分类特征层,outputs是经过fc层的,outputs_feature是仅缩放后的特征层
ensemble = sum(outputs[:-1])/len(outputs) # outputs[:-1]取出out4, out3, out2(即不包含最深层)
ensemble.detach_()
if init is False: # hint层
# init the adaptation layers.
# we add feature adaptation layers here to soften the influence from feature distillation loss
# the feature distillation in our conference version : | f1-f2 | ^ 2
# the feature distillation in the final version : |Fully Connected Layer(f1) - f2 | ^ 2
layer_list = []
teacher_feature_size = outputs_feature[0].size(1) # outputs_feature[0]是最深层的预测特征层 outputs_feature[1:]是浅层网络(学生)的特征层
for index in range(1, len(outputs_feature)):
student_feature_size = outputs_feature[index].size(1) # 取浅层的三个特征层(没有经过FC)
layer_list.append(nn.Linear(student_feature_size, teacher_feature_size))
net.adaptation_layers = nn.ModuleList(layer_list)
net.adaptation_layers.cuda()
optimizer = optim.SGD(net.parameters(), lr=args.init_lr, weight_decay=5e-4, momentum=0.9)
# define the optimizer here again so it will optimize the net.adaptation_layers
init = True
# compute loss
loss = torch.FloatTensor([0.]).to(device)
# for deepest classifier hard loss
loss += criterion(outputs[0], labels) # 最深层的特征层(经过FC输出)和labels计算交叉熵 [教师自己的]
teacher_output = outputs[0].detach() # 取出最深层特征层
teacher_feature = outputs_feature[0].detach() # 取出最深层特征层(没有经过FC)
# for shallow classifiers
for index in range(1, len(outputs)):
# logits distillation 对分类输出最soft_loss
# 逻辑蒸馏,将教师网络的输出和每个浅层学生网络之间做逻辑蒸馏,Loss source2
loss += CrossEntropy(outputs[index], teacher_output) * args.loss_coefficient # KL_loss soft loss
# loss source1
loss += criterion(outputs[index], labels) * (1 - args.loss_coefficient) # hard loss 学生自己的
# feature distillation hint蒸馏
# 特征蒸馏,loss source3
if index != 1:
loss += torch.dist(net.adaptation_layers[index-1](outputs_feature[index]), teacher_feature) * \
args.feature_loss_coefficient
# the feature distillation loss will not be applied to the shallowest classifier
sum_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
total += float(labels.size(0))
outputs.append(ensemble)
for classifier_index in range(len(outputs)):
_, predicted[classifier_index] = torch.max(outputs[classifier_index].data, 1)
correct[classifier_index] += float(predicted[classifier_index].eq(labels.data).cpu().sum())
print('[epoch:%d, iter:%d] Loss: %.03f | Acc: 4/4: %.2f%% 3/4: %.2f%% 2/4: %.2f%% 1/4: %.2f%%'
' Ensemble: %.2f%%' % (epoch + 1, (i + 1 + epoch * length), sum_loss / (i + 1),
100 * correct[0] / total, 100 * correct[1] / total,
100 * correct[2] / total, 100 * correct[3] / total,
100 * correct[4] / total))
print("Waiting Test!")
with torch.no_grad():
correct = [0 for _ in range(5)]
predicted = [0 for _ in range(5)]
total = 0.0
for data in testloader:
net.eval()
images, labels = data
images, labels = images.to(device), labels.to(device)
outputs, outputs_feature = net(images)
ensemble = sum(outputs) / len(outputs)
outputs.append(ensemble)
for classifier_index in range(len(outputs)):
_, predicted[classifier_index] = torch.max(outputs[classifier_index].data, 1)
correct[classifier_index] += float(predicted[classifier_index].eq(labels.data).cpu().sum())
total += float(labels.size(0))
print('Test Set AccuracyAcc: 4/4: %.4f%% 3/4: %.4f%% 2/4: %.4f%% 1/4: %.4f%%'
' Ensemble: %.4f%%' % (100 * correct[0] / total, 100 * correct[1] / total,
100 * correct[2] / total, 100 * correct[3] / total,
100 * correct[4] / total))
if correct[4] / total > best_acc:
best_acc = correct[4]/total
print("Best Accuracy Updated: ", best_acc * 100)
torch.save(net.state_dict(), "./checkpoints/"+str(args.model)+".pth")
print("Training Finished, TotalEPOCH=%d, Best Accuracy=%.3f" % (args.epoch, best_acc))
完全なプロジェクトコード
GitHub - YINYIPENG-EN/Resnet_self_distillation_pytorch: Resnet 自己蒸留ネットワーク