通俗易懂的知识蒸馏 Knowledge Distillation(下)——代码实践(附详细注释)

第一步:导入所需要的包

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.utils.data

torch.manual_seed(0)		# 为CPU设置种子
torch.cuda.manual_seed(0)	# 为GPU设置种子

第二步:定义教师模型

教师模型网络结构(此处仅举一个例子):卷积层-卷积层-dropout-dropout-全连接层-全连接层

class TeacherNet(nn.Module):
    def __init__(self):
        super(TeacherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)		# 卷积层
        self.conv2 = nn.Conv2d(32, 64, 3, 1)	# 卷积层
        self.dropout1 = nn.Dropout2d(0.3)		# dropout
        self.dropout2 = nn.Dropout2d(0.5)		# dropout
        self.fc1 = nn.Linear(9216, 128)			# 全连接层
        self.fc2 = nn.Linear(128, 10)			# 全连接层

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)			# 激活函数
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        output = self.fc2(x)
        return output

第三步:定义训练教师模型方法

正常的定义一个神经网络模型

def train_teacher(model, device, train_loader, optimizer, epoch):
    model.train()
    trained_samples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)	# 将数据转移到CPU/GPU
        optimizer.zero_grad()	# 优化器将梯度全部置为0
        output = model(data)	# 数据经过模型向前传播
        loss = F.cross_entropy(output, target)  # 计算损失函数
        loss.backward()			# 反向传播
        optimizer.step()		# 更新梯度

        trained_samples += len(data)
        progress = math.ceil(batch_idx / len(train_loader) * 50) # 计算训练进度
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')

第四步:定义教师模型测试方法

正常的定义一个神经网络模型

def test_teacher(model, device, test_loader):
    model.eval()  # 设置为评估模式
    test_loss = 0
    correct = 0
    with torch.no_grad():  # 不计算梯度,减少计算量
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)  # 将数据转移到CPU/GPU
            output = model(data)  # 经过模型正向传播得到结果
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # 计算总的损失函数
            pred = output.argmax(dim=1, keepdim=True)  # 获取最大对数概率索引
            correct += pred.eq(target.view_as(pred)).sum().item()  # pred.eq(target.view_as(pred)) 会返回一个布尔张量,其中每个元素表示预测值是否等于目标值。然后,.sum().item() 会将所有为 True 的元素相加,从而得到正确分类的数量。 

    test_loss /= len(test_loader.dataset)  # 计算损失函数

    print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, correct / len(test_loader.dataset)

第五步:定义教师模型主函数

整体也和正常模型一样,但是这里使用了teacher_history去保留需要知识蒸馏的数据。

def teacher_main():
    epochs = 10
    batch_size = 64
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 使用的设备类型
    
	# 导入训练集
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))  # 数据正则化
                       ])),
        batch_size=batch_size, shuffle=True)
    
    # 导入测试集
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))	# 数据正则化
        ])),
        batch_size=1000, shuffle=True)

    model = TeacherNet().to(device)  # 传输经过教师模型网络
    optimizer = torch.optim.Adadelta(model.parameters())  # 使用Adadelta优化器
    
    teacher_history = []  # 记录教师得到结果的历史
    for epoch in range(1, epochs + 1):
        train_teacher(model, device, train_loader, optimizer, epoch)  # 开始训练模型
        loss, acc = test_teacher(model, device, test_loader)  # 计算损失函数和准确率
        teacher_history.append((loss, acc))  # 记录教师模型得到的历史数据

    torch.save(model.state_dict(), "teacher.pt")  # 保存到权重文件
    return model, teacher_history

第六步:开始训练教师模型

# 训练教师网络
teacher_model, teacher_history = teacher_main()

第七步:定义学生模型网络结构

学生模型的网络结构定义时一般要比教师模型简单一些,这样才能达到知识蒸馏轻量化的目的

class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)	# 全连接层
        self.fc2 = nn.Linear(128, 64)		# 全连接层
        self.fc3 = nn.Linear(64, 10)		# 全连接层

    def forward(self, x):
        x = torch.flatten(x, 1)		# 将输入张量沿着第二维度平
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        output = F.relu(self.fc3(x))
        return output

第八步:定义知识蒸馏方法

这里定义知识蒸馏主要是实现其损失函数。

def distillation(y, labels, teacher_scores, temp, alpha):
    return nn.KLDivLoss()(F.log_softmax(y / temp, dim=1), F.softmax(teacher_scores / temp, dim=1)) * (
            temp * temp * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)

我们这里写一下这个公式:
K L D i v L o s s ( l o g ( s o f t m a x ( y t e m p ) ) , s o f t m a x ( t e a c h e r   s c o r e s t e m p ) ) ∗ 2 α t e m p 2 C r o s s E n t r o p y ( y , l a b e l s ) ( 1 − α ) KLDivLoss(log(softmax(\frac{y}{temp})),softmax(\frac{teacher~scores}{temp}))*2\alpha temp^2\\CrossEntropy(y,labels)(1-\alpha) KLDivLoss(log(softmax(tempy)),softmax(tempteacher scores))2αtemp2CrossEntropy(y,labels)(1α)
其中 α \alpha α 1 − α 1-\alpha 1α为系数, t e m p 2 temp^2 temp2用于调节量纲。

第九步:定义学生模型训练和测试方法

学生模型训练部分和教师模型训练部分基本一样,除了两个部分。

第一个部分是,需要重点关注teacher_output = teacher_output.detach()切断教师模型反向传播这一行。

第二个部分是,这里训练使用的损失函数是上面定义的知识蒸馏的损失函数

def train_student_kd(model, device, train_loader, optimizer, epoch):
    model.train()
    trained_samples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)  # 学生模型前向传播
        teacher_output = teacher_model(data)  # 教师模型前向传播
        teacher_output = teacher_output.detach()  # 切断老师网络的反向传播
        loss = distillation(output, target, teacher_output, temp=5.0, alpha=0.7)  # 计算总损失函数,这里使用的是知识蒸馏的损失函数
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数

        trained_samples += len(data)
        progress = math.ceil(batch_idx / len(train_loader) * 50)
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')

def test_student_kd(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # 计算总的损失函数
            pred = output.argmax(dim=1, keepdim=True)  # 获取最大对数概率索引
            correct += pred.eq(target.view_as(pred)).sum().item()  # 计算准确率
    test_loss /= len(test_loader.dataset)

    print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, correct / len(test_loader.dataset)

第十步:定义学生模型主函数

def student_kd_main():
    epochs = 10
    batch_size = 64
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	# 加载训练集
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True)
    # 加载测试集
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True)
	# 加载学生模型
    model = StudentNet().to(device)
    optimizer = torch.optim.Adadelta(model.parameters())
    
    student_history = []  # 记录学生训练的模型
    for epoch in range(1, epochs + 1):
        train_student_kd(model, device, train_loader, optimizer, epoch)
        loss, acc = test_student_kd(model, device, test_loader)
        student_history.append((loss, acc))

    torch.save(model.state_dict(), "student_kd.pt")
    return model, student_history
student_kd_model, student_kd_history = student_kd_main()

知识蒸馏步骤总结

(1)先训练教师模型,定义教师模型的训练方法和测试方法

(2)定义知识蒸馏损失函数

(3)再训练学生模型,定义学生模型的训练方法和测试方法

(4)训练学生模型的时候需要将教师模型得到数据输出经过知识蒸馏作为输入,并且要阻断教师模型的反向传播,并利用知识蒸馏损失函数进行反向传播

(5)训练结束后得到学生模型

猜你喜欢

转载自blog.csdn.net/m0_61787307/article/details/131554720
今日推荐