神经网络知识蒸馏 Knowledge Distillation_哔哩哔哩_bilibili
Intuition——水果蔬菜分类
传统的训练:西红柿【1, 0, 0】,这是西红柿
有老师教:西红柿【1, 0, 0】+【0.7, 0.29, 0.01】,这是西红柿,但它跟柿子长得挺像。
Teacher-Student
蒸馏中常用的 Loss Function in PyTorch
- Softmax:将一个数值序列映射到概率空间
- log_softmax:在softmax的基础上取对数
- NLLLoss:负对数似然损失。输入为log_softmax的输出和真实类别索引(等价于one-hot标签),取真实类别对应的对数概率并取负,再对batch求平均
- CrossEntropy:衡量两个概率分布的差别
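Log_softmax + NLLLoss是CrossEntropy的特例:
CrossEntropy用于衡量两个概率分布的差异;若其中一个分布(真实标签)为one-hot形式,交叉熵就退化为“真实类别的负对数概率”,此时可以用(Log_softmax + NLLLoss)代替交叉熵。在PyTorch中,F.cross_entropy(logits, target) 正等价于 F.nll_loss(F.log_softmax(logits, dim=1), target)。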
理论分析
知识蒸馏相当于“加入了正则化的损失函数”:蒸馏在原始预测输出与真实one-hot标签之间的损失(硬损失)的基础上,又加入了一个软损失(学生与教师软化输出之间的差异),这个软损失起到了正则化的作用。
代码
代码很简单,关键在于用【学生网络对输入的输出、真实标签(类别索引)、教师网络的输出、温度系数T、权重系数α】计算蒸馏损失:
loss = 0.7 × T² × KL散度(带温度T的学生输出, 带温度T的老师输出) + 0.3 × 交叉熵(学生输出, 真实标签)(其中α=0.7;乘以T²是为了让软损失的梯度量级不随温度变化)
loss = 软损失 + 硬损失
# Key part: the knowledge-distillation (KD) loss.
def distillation(y, labels, teacher_scores, temp, alpha):
    """Return the knowledge-distillation loss: a soft (teacher) term plus a hard (label) term.

    loss = alpha * T^2 * KL(softmax(teacher / T) || softmax(student / T))
           + (1 - alpha) * cross_entropy(student, labels)

    Args:
        y: student logits, with classes along dim 1.
        labels: ground-truth class indices, as accepted by ``F.cross_entropy``.
        teacher_scores: teacher logits with the same shape as ``y``. The caller
            is expected to detach them; this function does not.
        temp: softmax temperature T. A higher T gives softer distributions.
        alpha: weight of the soft term; the hard term is weighted by ``1 - alpha``.

    Returns:
        A scalar loss tensor.
    """
    soft_student = F.log_softmax(y / temp, dim=1)
    soft_teacher = F.softmax(teacher_scores / temp, dim=1)
    # 'batchmean' sums over classes and averages over the batch, giving the true
    # per-sample KL divergence. The default 'mean' would also divide by the
    # number of classes, which PyTorch documents as not matching the KL definition.
    soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
    hard_loss = F.cross_entropy(y, labels)
    # Multiply the soft term by T^2 so its gradient magnitude stays comparable
    # as the temperature changes.
    return soft_loss * (temp * temp * alpha) + hard_loss * (1. - alpha)
-----------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.utils.data
# Fix the random seeds (CPU generator and the current CUDA device) for reproducible runs.
_SEED = 0
torch.manual_seed(_SEED)
torch.cuda.manual_seed(_SEED)
## 训练教师网络
class TeacherNet(nn.Module):
    """Convolutional MNIST classifier used as the distillation teacher.

    conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU -> 2x2 max-pool ->
    channel dropout -> flatten -> fc(9216->128) -> ReLU -> dropout -> fc(128->10).
    ``forward`` returns raw class logits (no softmax).
    """

    def __init__(self):
        super(TeacherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Dropout2d zeroes whole feature maps; here it acts on the 4-D conv output.
        self.dropout1 = nn.Dropout2d(0.3)
        # dropout2 acts on the 2-D (batch, 128) fc activations, so it must be
        # plain element-wise Dropout. Feeding a 2-D tensor to Dropout2d is
        # deprecated in PyTorch (it warns and is slated to become an error).
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12: for a 28x28 input, two 3x3 convs give
        # 24x24, and the 2x2 pool halves that to 12x12.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        output = self.fc2(x)
        return output
def train_teacher(model, device, train_loader, optimizer, epoch):
    """Run one training epoch of ``model`` with plain cross-entropy loss.

    Prints a single-line progress bar, which is redrawn in place via a carriage return.

    Args:
        model: network to train; it is switched to training mode.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches that supports
            ``len()`` and exposes a ``.dataset`` with ``len()``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only in the progress message.
    """
    model.train()
    trained_samples = 0
    num_batches = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        trained_samples += len(data)
        # Use batch_idx + 1 to count the batch just finished, so the bar reaches
        # 50 segments (100%) on the last batch. The 51-char field holds 50 '-' plus '>'.
        progress = math.ceil((batch_idx + 1) / num_batches * 50)
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')
def test_teacher(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
return test_loss, correct / len(test_loader.dataset)
def teacher_main(epochs=10, batch_size=64):
    """Train TeacherNet on MNIST, evaluating on the test split after every epoch.

    Downloads MNIST to ``../data/MNIST`` if needed and saves the final weights
    to ``teacher.pt``.

    Args:
        epochs: number of training epochs (default 10).
        batch_size: training mini-batch size (default 64).

    Returns:
        ``(model, history)``, where ``history`` holds one
        ``(test_loss, test_accuracy)`` tuple per epoch.
    """
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Both splits use the same preprocessing.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=True, download=True,
                       transform=mnist_transform),
        batch_size=batch_size, shuffle=True)
    # Evaluation sums loss and correct counts over the whole split, so batch
    # order is irrelevant and shuffling the test set is unnecessary.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False, download=True,
                       transform=mnist_transform),
        batch_size=1000, shuffle=False)
    model = TeacherNet().to(device)
    optimizer = torch.optim.Adadelta(model.parameters())
    teacher_history = []
    for epoch in range(1, epochs + 1):
        train_teacher(model, device, train_loader, optimizer, epoch)
        loss, acc = test_teacher(model, device, test_loader)
        teacher_history.append((loss, acc))
    torch.save(model.state_dict(), "teacher.pt")
    return model, teacher_history


# Train the teacher network. train_student_kd reads the module-level teacher_model set here.
teacher_model, teacher_history = teacher_main()
## 让老师教学生网络
# 定义学生网络
class StudentNet(nn.Module):
    """Small fully connected MNIST classifier used as the distillation student.

    flatten -> fc(784->128) -> ReLU -> fc(128->64) -> ReLU -> fc(64->10).
    ``forward`` returns raw class logits (no softmax).
    """

    def __init__(self):
        super(StudentNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # Flatten everything after the batch dim; fc1 needs 28*28 features per sample.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Do not apply ReLU here. These are logits consumed by cross_entropy and by
        # softmax(logits / T). Clamping negatives to 0 distorts both and zeroes
        # the gradient for any logit that goes negative.
        output = self.fc3(x)
        return output
# Key part: the knowledge-distillation (KD) loss.
def distillation(y, labels, teacher_scores, temp, alpha):
    """Return the knowledge-distillation loss: a soft (teacher) term plus a hard (label) term.

    loss = alpha * T^2 * KL(softmax(teacher / T) || softmax(student / T))
           + (1 - alpha) * cross_entropy(student, labels)

    Args:
        y: student logits, with classes along dim 1.
        labels: ground-truth class indices, as accepted by ``F.cross_entropy``.
        teacher_scores: teacher logits with the same shape as ``y``. The caller
            is expected to detach them; this function does not.
        temp: softmax temperature T. A higher T gives softer distributions.
        alpha: weight of the soft term; the hard term is weighted by ``1 - alpha``.

    Returns:
        A scalar loss tensor.
    """
    soft_student = F.log_softmax(y / temp, dim=1)
    soft_teacher = F.softmax(teacher_scores / temp, dim=1)
    # 'batchmean' sums over classes and averages over the batch, giving the true
    # per-sample KL divergence. The default 'mean' would also divide by the
    # number of classes, which PyTorch documents as not matching the KL definition.
    soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
    hard_loss = F.cross_entropy(y, labels)
    # Multiply the soft term by T^2 so its gradient magnitude stays comparable
    # as the temperature changes.
    return soft_loss * (temp * temp * alpha) + hard_loss * (1. - alpha)
def train_student_kd(model, device, train_loader, optimizer, epoch):
    """Run one epoch of knowledge-distillation training for the student.

    Each batch is also scored by the module-level ``teacher_model``. The
    student is optimised on ``distillation(student_logits, labels,
    teacher_logits, temp=5.0, alpha=0.7)``. Prints a single-line progress bar,
    which is redrawn in place via a carriage return.

    Args:
        model: student network to train; it is switched to training mode.
        device: device the batches are moved to. The teacher is expected to
            already live on this device.
        train_loader: iterable of ``(data, target)`` batches that supports
            ``len()`` and exposes a ``.dataset`` with ``len()``.
        optimizer: optimizer over the student's parameters.
        epoch: epoch number, used only in the progress message.
    """
    model.train()
    # The teacher only supplies targets, so make sure its dropout layers are inactive.
    teacher_model.eval()
    trained_samples = 0
    num_batches = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # No gradient may flow into the teacher. Running it under no_grad,
        # rather than forward-then-detach, also skips building its autograd graph.
        with torch.no_grad():
            teacher_output = teacher_model(data)
        # Core of KD: soft loss against the teacher plus hard loss against the labels.
        loss = distillation(output, target, teacher_output, temp=5.0, alpha=0.7)
        loss.backward()
        optimizer.step()
        trained_samples += len(data)
        # Use batch_idx + 1 to count the batch just finished, so the bar reaches
        # 50 segments (100%) on the last batch. The 51-char field holds 50 '-' plus '>'.
        progress = math.ceil((batch_idx + 1) / num_batches * 50)
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')
def test_student_kd(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
return test_loss, correct / len(test_loader.dataset)
def student_kd_main(epochs=10, batch_size=64):
    """Train StudentNet on MNIST with knowledge distillation (via ``train_student_kd``).

    The model is evaluated on the test split after every epoch. MNIST is
    downloaded to ``../data/MNIST`` if needed, and the final weights are saved
    to ``student_kd.pt``.

    Args:
        epochs: number of training epochs (default 10).
        batch_size: training mini-batch size (default 64).

    Returns:
        ``(model, history)``, where ``history`` holds one
        ``(test_loss, test_accuracy)`` tuple per epoch.
    """
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Both splits use the same preprocessing.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=True, download=True,
                       transform=mnist_transform),
        batch_size=batch_size, shuffle=True)
    # Evaluation sums loss and correct counts over the whole split, so batch
    # order is irrelevant and shuffling the test set is unnecessary.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False, download=True,
                       transform=mnist_transform),
        batch_size=1000, shuffle=False)
    model = StudentNet().to(device)
    optimizer = torch.optim.Adadelta(model.parameters())
    student_history = []
    for epoch in range(1, epochs + 1):
        train_student_kd(model, device, train_loader, optimizer, epoch)
        loss, acc = test_student_kd(model, device, test_loader)
        student_history.append((loss, acc))
    torch.save(model.state_dict(), "student_kd.pt")
    return model, student_history


student_kd_model, student_kd_history = student_kd_main()
## 让学生自己学,不使用KD
def train_student(model, device, train_loader, optimizer, epoch):
    """Run one training epoch of the student alone (no teacher), using cross-entropy loss.

    Prints a single-line progress bar, which is redrawn in place via a carriage return.

    Args:
        model: network to train; it is switched to training mode.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches that supports
            ``len()`` and exposes a ``.dataset`` with ``len()``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only in the progress message.
    """
    model.train()
    trained_samples = 0
    num_batches = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        trained_samples += len(data)
        # Use batch_idx + 1 to count the batch just finished, so the bar reaches
        # 50 segments (100%) on the last batch. The 51-char field holds 50 '-' plus '>'.
        progress = math.ceil((batch_idx + 1) / num_batches * 50)
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')
def test_student(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
return test_loss, correct / len(test_loader.dataset)
def student_main(epochs=10, batch_size=64):
    """Train StudentNet on MNIST with plain cross-entropy (no distillation), as a baseline.

    The model is evaluated on the test split after every epoch. MNIST is
    downloaded to ``../data/MNIST`` if needed, and the final weights are saved
    to ``student.pt``.

    Args:
        epochs: number of training epochs (default 10).
        batch_size: training mini-batch size (default 64).

    Returns:
        ``(model, history)``, where ``history`` holds one
        ``(test_loss, test_accuracy)`` tuple per epoch.
    """
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Both splits use the same preprocessing.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=True, download=True,
                       transform=mnist_transform),
        batch_size=batch_size, shuffle=True)
    # Evaluation sums loss and correct counts over the whole split, so batch
    # order is irrelevant and shuffling the test set is unnecessary.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False, download=True,
                       transform=mnist_transform),
        batch_size=1000, shuffle=False)
    model = StudentNet().to(device)
    optimizer = torch.optim.Adadelta(model.parameters())
    student_history = []
    for epoch in range(1, epochs + 1):
        train_student(model, device, train_loader, optimizer, epoch)
        loss, acc = test_student(model, device, test_loader)
        student_history.append((loss, acc))
    torch.save(model.state_dict(), "student.pt")
    return model, student_history


student_simple_model, student_simple_history = student_main()
## -------------------通过绘图看一下他们之间的区别---------------------------
import matplotlib.pyplot as plt

# Each *_history is a list of (test_loss, test_accuracy) tuples, one per epoch.
# The x-axis comes from each recorded history's length rather than from a
# second hard-coded epoch count that could drift from the *_main() settings.
histories = [
    ('teacher', teacher_history),
    ('student with KD', student_kd_history),
    ('student without KD', student_simple_history),
]

plt.subplot(2, 1, 1)
for label, history in histories:
    plt.plot(range(1, len(history) + 1), [acc for _, acc in history], label=label)
plt.title('Test accuracy')
plt.legend()

plt.subplot(2, 1, 2)
for label, history in histories:
    plt.plot(range(1, len(history) + 1), [loss for loss, _ in history], label=label)
plt.title('Test loss')
plt.legend()

# Keep the two subplot titles from overlapping, and display the figure when run as a script.
plt.tight_layout()
plt.show()