Easy-to-Understand Knowledge Distillation (Part 2): Code Practice (with Detailed Notes)

Step 1: Import the required packages

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.utils.data

torch.manual_seed(0)		# set the random seed for the CPU
torch.cuda.manual_seed(0)	# set the random seed for the GPU
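
Fixing the seeds makes the random parts of the pipeline (weight initialization, data shuffling) reproducible across runs. A tiny illustration (not part of the original code):

# illustrative sketch: the same seed reproduces the same random numbers
torch.manual_seed(0)
a = torch.randn(3)
torch.manual_seed(0)
b = torch.randn(3)
print(torch.equal(a, b))  # True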

Step 2: Define the teacher model

Teacher model network structure (just one example here): convolutional layer → convolutional layer → dropout → dropout → fully connected layer → fully connected layer

class TeacherNet(nn.Module):
    def __init__(self):
        super(TeacherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)		# convolutional layer
        self.conv2 = nn.Conv2d(32, 64, 3, 1)	# convolutional layer
        self.dropout1 = nn.Dropout2d(0.3)		# dropout
        self.dropout2 = nn.Dropout2d(0.5)		# dropout
        self.fc1 = nn.Linear(9216, 128)			# fully connected layer
        self.fc2 = nn.Linear(128, 10)			# fully connected layer

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)			# activation function
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        output = self.fc2(x)
        return output
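
As a quick sanity check of where fc1's 9216 input features come from, we can trace a dummy MNIST-sized tensor through the network (an illustrative sketch, relying on the imports from Step 1; not part of the original post):

# illustrative sketch: trace the tensor shapes through TeacherNet
# a 1x28x28 input -> 32x26x26 after conv1 -> 64x24x24 after conv2
# -> 64x12x12 after max_pool2d, i.e. 64*12*12 = 9216 features after flatten
net = TeacherNet()
dummy = torch.randn(1, 1, 28, 28)  # one fake MNIST image
print(net(dummy).shape)            # expected: torch.Size([1, 10])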

Step 3: Define the method to train the teacher model

This is the standard training loop for a neural network.

def train_teacher(model, device, train_loader, optimizer, epoch):
    model.train()
    trained_samples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)	# move the data to the CPU/GPU
        optimizer.zero_grad()	# zero out all gradients held by the optimizer
        output = model(data)	# forward pass of the data through the model
        loss = F.cross_entropy(output, target)  # compute the loss
        loss.backward()			# backpropagation
        optimizer.step()		# update the parameters

        trained_samples += len(data)
        progress = math.ceil(batch_idx / len(train_loader) * 50) # compute training progress
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')

Step 4: Define the teacher model test method

This is the standard evaluation loop for a neural network.

def test_teacher(model, device, test_loader):
    model.eval()  # switch to evaluation mode
    test_loss = 0
    correct = 0
    with torch.no_grad():  # no gradients are computed, which saves memory and computation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)  # move the data to the CPU/GPU
            output = model(data)  # forward pass through the model
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # accumulate the summed loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability, i.e. the predicted class
            correct += pred.eq(target.view_as(pred)).sum().item()  # pred.eq(target.view_as(pred)) returns a boolean tensor marking which predictions match the targets; .sum().item() counts the True entries, i.e. the number of correct classifications

    test_loss /= len(test_loader.dataset)  # average the loss over the whole test set

    print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, correct / len(test_loader.dataset)

Step 5: Define the teacher model main function

The overall flow is the same as training an ordinary model; here teacher_history is used to record the teacher's test loss and accuracy after each epoch.

def teacher_main():
    epochs = 10
    batch_size = 64
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # select the device to run on
    
	# load the training set
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))  # normalize the data
                       ])),
        batch_size=batch_size, shuffle=True)
    
    # load the test set
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))	# normalize the data
        ])),
        batch_size=1000, shuffle=True)

    model = TeacherNet().to(device)  # instantiate the teacher network and move it to the device
    optimizer = torch.optim.Adadelta(model.parameters())  # use the Adadelta optimizer
    
    teacher_history = []  # record the teacher's results over the epochs
    for epoch in range(1, epochs + 1):
        train_teacher(model, device, train_loader, optimizer, epoch)  # train the model for one epoch
        loss, acc = test_teacher(model, device, test_loader)  # compute the test loss and accuracy
        teacher_history.append((loss, acc))  # record the teacher model's results

    torch.save(model.state_dict(), "teacher.pt")  # save the weights to a file
    return model, teacher_history

Step 6: Start training the teacher model

# train the teacher network
teacher_model, teacher_history = teacher_main()

Step 7: Define the student model network structure

The network structure of the student model is usually simpler than that of the teacher model; this is what makes knowledge distillation a way to obtain a lightweight model.

class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)	# fully connected layer
        self.fc2 = nn.Linear(128, 64)		# fully connected layer
        self.fc3 = nn.Linear(64, 10)		# fully connected layer

    def forward(self, x):
        x = torch.flatten(x, 1)		# flatten the input tensor starting from dimension 1
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        output = F.relu(self.fc3(x))
        return output
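
To make the "lightweight" point concrete, here is a small sketch (not from the original post; count_params is a hypothetical helper) comparing the number of trainable parameters of the two networks:

# illustrative sketch: compare trainable parameter counts of teacher and student
def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("TeacherNet parameters:", count_params(TeacherNet()))  # ~1.2M, dominated by the 9216x128 fc1
print("StudentNet parameters:", count_params(StudentNet()))  # ~0.11M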

Step 8: Define the knowledge distillation method

The core of knowledge distillation here is its loss function, defined below.

def distillation(y, labels, teacher_scores, temp, alpha):
    # soft-target loss: KL divergence between the temperature-softened student and teacher distributions
    soft_loss = nn.KLDivLoss()(F.log_softmax(y / temp, dim=1), F.softmax(teacher_scores / temp, dim=1))
    hard_loss = F.cross_entropy(y, labels)  # hard-target loss against the ground-truth labels
    return soft_loss * (temp * temp * 2.0 * alpha) + hard_loss * (1. - alpha)

Let's write this out as a formula:

$$\mathrm{KLDivLoss}\Big(\log\big(\mathrm{softmax}(\tfrac{y}{temp})\big),\ \mathrm{softmax}(\tfrac{teacher\ scores}{temp})\Big)\cdot 2\,\alpha\, temp^2 + \mathrm{CrossEntropy}(y, labels)\cdot(1-\alpha)$$

where $\alpha$ and $1-\alpha$ weight the two terms, and the factor $temp^2$ rescales the soft-target term (softening the logits by $temp$ shrinks its gradients by roughly $1/temp^2$).
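
To build intuition for the temperature, the following sketch (with made-up logits, not from the original post) shows how dividing by temp softens the teacher's distribution, which is exactly the soft target the student is trained to match:

# illustrative sketch: effect of temperature on the softmax distribution
logits = torch.tensor([[8.0, 2.0, 0.5]])   # made-up teacher logits for 3 classes
print(F.softmax(logits, dim=1))            # temp=1: ~[0.997, 0.002, 0.001], nearly one-hot
print(F.softmax(logits / 5.0, dim=1))      # temp=5: ~[0.66, 0.20, 0.15], much softer targets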

Step 9: Define student model training and testing methods

The student training loop is basically the same as the teacher's, except for two differences.

The first is the line teacher_output = teacher_output.detach(), which cuts the teacher model out of the backpropagation graph so that no gradients flow back into the teacher.

The second is that the loss used for training is the knowledge distillation loss defined above.

def train_student_kd(model, device, train_loader, optimizer, epoch):
    model.train()
    trained_samples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)  # student model forward pass
        teacher_output = teacher_model(data)  # teacher model forward pass
        teacher_output = teacher_output.detach()  # cut off backpropagation into the teacher network
        loss = distillation(output, target, teacher_output, temp=5.0, alpha=0.7)  # total loss, computed with the knowledge distillation loss function
        loss.backward()  # backpropagation
        optimizer.step()  # update the parameters

        trained_samples += len(data)
        progress = math.ceil(batch_idx / len(train_loader) * 50)
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')

def test_student_kd(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # accumulate the summed loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability, i.e. the predicted class
            correct += pred.eq(target.view_as(pred)).sum().item()  # count correct predictions
    test_loss /= len(test_loader.dataset)

    print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, correct / len(test_loader.dataset)

Step 10: Define the main function of the student model

def student_kd_main():
    epochs = 10
    batch_size = 64
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	# load the training set
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True)
    # load the test set
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True)
	# create the student model
    model = StudentNet().to(device)
    optimizer = torch.optim.Adadelta(model.parameters())
    
    student_history = []  # record the student's test loss and accuracy over the epochs
    for epoch in range(1, epochs + 1):
        train_student_kd(model, device, train_loader, optimizer, epoch)
        loss, acc = test_student_kd(model, device, test_loader)
        student_history.append((loss, acc))

    torch.save(model.state_dict(), "student_kd.pt")
    return model, student_history

# train the student network with knowledge distillation
student_kd_model, student_kd_history = student_kd_main()
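
After training, the saved weights can be reloaded for inference. The following is a minimal sketch (assuming the student_kd.pt file produced above), not part of the original code:

# illustrative sketch: reload the distilled student and classify one image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = StudentNet().to(device)
model.load_state_dict(torch.load("student_kd.pt", map_location=device))
model.eval()
with torch.no_grad():
    sample = torch.randn(1, 1, 28, 28, device=device)  # stand-in for a normalized MNIST image
    print("predicted class:", model(sample).argmax(dim=1).item())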

Summary of knowledge distillation steps

(1) First train the teacher model: define its training and test methods.

(2) Define the knowledge distillation loss function.

(3) Then train the student model: define its training and test methods.

(4) When training the student model, feed the teacher model's outputs into the knowledge distillation loss, detach them so that no gradients are backpropagated into the teacher, and backpropagate the distillation loss through the student only.

(5) Obtain the trained student model.

Origin blog.csdn.net/m0_61787307/article/details/131554720