Knowledge Distillation for Handwriting Recognition


This article is mainly based on this video: https://www.bilibili.com/video/BV1s7411h7K2?t=906. If there are any errors in my understanding, please point them out.

1. First proposed

Knowledge distillation was first proposed in: https://arxiv.org/pdf/1503.02531.pdf
The authors' motivation was to find a way to distill the knowledge of multiple models into a single model.
Although most classification models measure the gap between predictions and ground truth with cross-entropy, the one-hot vector used as the ground truth cannot provide as much information as a full probability distribution.
Rationale: a probability distribution is more informative than a one-hot vector; this extra information is the "dark knowledge".
loss = 0.7 × KL divergence between the temperature-softened student and teacher outputs + 0.3 × cross entropy with the one-hot label
The distance between the two distributions can be measured with KL divergence.
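In PyTorch terms, that loss might look like the minimal sketch below (the tensors are dummy placeholders, and the 0.7/0.3 weights and temperature T are hyperparameters; the version actually used later appears in section 6.3):

import torch
import torch.nn.functional as F

T = 5.0                                  # distillation temperature (a hyperparameter)
student_logits = torch.randn(8, 10)      # dummy batch: 8 samples, 10 classes
teacher_logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))

# soft term: KL divergence between the temperature-softened student and teacher distributions
soft = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                F.softmax(teacher_logits / T, dim=1),
                reduction='batchmean')
# hard term: ordinary cross entropy against the one-hot ground-truth labels
hard = F.cross_entropy(student_logits, labels)
loss = 0.7 * soft + 0.3 * hard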

2. Brief introduction

Knowledge distillation transfers the knowledge of one network to another; the two networks can be homogeneous or heterogeneous. The method is to first train a teacher network, and then train the student network using both the teacher network's output and the data's ground-truth labels. Knowledge distillation can be used to compress a large network into a small one while retaining performance close to that of the large network; it can also transfer the knowledge learned by multiple networks into a single network, so that the single network's performance approaches the ensemble result.

Example: Fruits and Vegetables

3. Some loss functions in PyTorch


log_softmax_p = F.log_softmax(p, dim=1)   # log-softmax of the raw logits p
loss1 = F.nll_loss(log_softmax_p, target)
loss2 = F.cross_entropy(p, target)        # cross_entropy = log_softmax + nll_loss
loss1 == loss2                            # True: the two losses are identical
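A self-contained check of this equivalence (the tensor shapes below are only illustrative):

import torch
import torch.nn.functional as F

p = torch.randn(4, 10)               # 4 samples, 10 class logits each
target = torch.tensor([3, 1, 0, 9])  # ground-truth class indices

loss1 = F.nll_loss(F.log_softmax(p, dim=1), target)
loss2 = F.cross_entropy(p, target)
print(torch.isclose(loss1, loss2))   # tensor(True)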


4. Core


Without a teacher network, the image is simply passed through the student network and, after softmax, we get the predicted probability distribution q. The loss between q and the ground truth p is called the hard loss: p is the one-hot vector of the true label, and we want q to be as close to p as possible.
With the help of a teacher, the loss comes from both the student and the teacher network. The teacher's output q' must first be distilled (softened, to make it smoother) to obtain q'', which is then used to compute a loss against q. The total loss is the sum of the two.

5. Theoretical analysis

[Figure: total loss combining the soft-target and hard-target terms (omitted)]
By introducing a soft target as part of the total loss, training of the streamlined, low-complexity student network is guided by the teacher, achieving knowledge transfer.
The total loss is designed as a combination of the cross-entropy loss (against the hard, ground-truth labels) and the KL divergence between the soft target and the student's prediction. As shown in the figure above, the larger the weighting coefficient on the soft-target term, the more the transfer relies on the teacher network's contribution. This is necessary in the early stages of training, but in the later stages the weight of the soft target should be reduced.
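One simple way to implement such a schedule (purely illustrative; the function name alpha_for_epoch and the decay values are my own, not from the paper or the video):

def alpha_for_epoch(epoch, total_epochs, alpha_start=0.9, alpha_end=0.3):
    # linearly decay the weight of the soft-target term as training progresses
    frac = (epoch - 1) / max(total_epochs - 1, 1)
    return alpha_start + (alpha_end - alpha_start) * frac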

In the figure, without smoothing the probability assigned to the digit 9 is very small; once smoothing is applied, the share given to 9 becomes noticeably larger.
[Figure: predicted class probabilities before and after smoothing (omitted)]
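The effect of the temperature can be seen directly on a made-up logit vector (the numbers are illustrative only):

import torch

logits = torch.tensor([8.0, 2.0, 0.5])     # hypothetical logits for three classes
print(torch.softmax(logits, dim=0))        # ~[0.997, 0.0025, 0.0006]: the small classes almost vanish
print(torch.softmax(logits / 5.0, dim=0))  # ~[0.66, 0.20, 0.15]: at T=5 the "dark knowledge" stays visible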

6. Example: Knowledge Distillation for Handwriting Recognition
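The snippets in this section assume the usual imports:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms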

6.1 Teacher model

Simple Convolutional Network

class TeacherNet(nn.Module):
    def __init__(self):
        super(TeacherNet,self).__init__()
        self.conv1 = nn.Conv2d(1,32,3,1)
        self.conv2 = nn.Conv2d(32,64,3,1)
        self.dropout1 = nn.Dropout2d(0.3)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216,128)
        self.fc2 = nn.Linear(128,10)
        
    # note: no softmax here; the network returns raw logits
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x,2)
        x = self.dropout1(x)
        x = torch.flatten(x,1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        output = self.fc2(x)
        return output
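The student code in section 6.3 loads teacher.pth, so the teacher has to be trained first. That step is not shown here; a minimal sketch (the function name train_teacher and the plain cross-entropy training loop are my own filling-in) could look like:

def train_teacher(model, device, train_loader, optimizer, epoch):
    # ordinary supervised training of the teacher with the hard labels only
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(data), target)
        loss.backward()
        optimizer.step()

# after the last epoch, save the weights that the student script will load:
# torch.save(teacher_model.state_dict(), 'teacher.pth')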

6.2 Student model

A simpler model built only from fully connected (linear) layers

class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet,self).__init__()
        self.fc1 = nn.Linear(28*28,128)
        self.fc2 = nn.Linear(128,64)
        self.fc3 = nn.Linear(64,10)
        
    def forward(self,x):
        x = torch.flatten(x,1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        output = F.relu(self.fc3(x))
        return output
    

6.3 Teacher and Student

KD's Loss

def distillation(y, labels, teacher_scores, temp, alpha):
    # soft loss: KL between the temperature-softened student and teacher distributions, scaled by temp^2
    soft_loss = nn.KLDivLoss()(F.log_softmax(y / temp, dim=1),
                               F.softmax(teacher_scores / temp, dim=1)) * (temp * temp * 2.0 * alpha)
    # hard loss: ordinary cross entropy against the ground-truth hard labels
    return soft_loss + F.cross_entropy(y, labels) * (1. - alpha)

The teacher teaches the student. The premise is that the teacher network has already been trained before it starts teaching the student.

def train_student_kd(model,device,train_loader,optimizer,epoch):
    model.train()
    teacher_model.eval()
    trained_samples = 0
    for batch_idx,(data,target) in enumerate(train_loader):
        data,target = data.to(device),target.to(device)
        optimizer.zero_grad()
        output = model(data)
        teacher_output = teacher_model(data)
        teacher_output = teacher_output.detach()
        loss = distillation(output,target,teacher_output,temp=5.0,alpha=0.7)
        loss.backward()
        optimizer.step()
        trained_samples += len(data)
        progress = math.ceil(batch_idx / len(train_loader)*50)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
#         print("\r Train epoch %d:%d/%d"%(epoch,trained_samples,len(train_loader.dataset)))
        
def test_student_kd(model,device,test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data,target in test_loader:
            data,target = data.to(device),target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item() 
            pred = output.argmax(dim=1,keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print("\r Test average loss:{:.4f},accuracy:{}/{}({:.0f}%)".format(test_loss,correct,len(test_loader.dataset),100.*correct/len(test_loader.dataset)))
    return test_loss, correct/len(test_loader.dataset)

    
def student_kd_main():
    epochs = 10
    batch_size = 64
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data/MNIST',train=True,download=False,
                                                             transform=transforms.Compose([
                                                                 transforms.ToTensor(),
                                                                 transforms.Normalize((0.1307,),(0.3081,))
                                                             ])),
                                              batch_size=batch_size,shuffle=True)
    
    test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data/MNIST',train=False,download=False,
                                                             transform=transforms.Compose([
                                                                 transforms.ToTensor(),
                                                                 transforms.Normalize((0.1307,),(0.3081,))
                                                             ])),
                                              batch_size=batch_size,shuffle=True)
    
    model = StudentNet().to(device)
    global teacher_model  # train_student_kd reads teacher_model as a module-level name
    teacher_model = TeacherNet().to(device)
    teacher_model.load_state_dict(torch.load('teacher.pth', map_location=device))  # load the trained teacher's parameters
    teacher_model.eval()
    
    optimizer = torch.optim.Adadelta(model.parameters())
    student_history = []
    for epoch in range(1,epochs+1):
        train_student_kd(model,device,train_loader,optimizer,epoch)
        loss , acc = test_student_kd(model,device,test_loader)
        student_history.append((loss,acc))
    torch.save(model.state_dict(),'student.pth')
    return model,student_history
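Putting it together (a usage sketch; the variable names are mine):

if __name__ == '__main__':
    student_kd_model, student_kd_history = student_kd_main()
    # student_kd_history holds one (test_loss, accuracy) pair per epoch,
    # useful for comparing against a student trained without distillation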



Origin blog.csdn.net/snow_maple521/article/details/116007538