Implementing model distillation (knowledge distillation) in PyTorch

# First, import the required modules and prepare the data
import torch
from torch.utils.data import DataLoader
import torch.utils.data as Data
import torchvision.transforms as transforms
import numpy as np
import os
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Load your own dataset here, so that trainx, trainy, testx, testy are
# NumPy arrays of images and labels
# Min-max normalize the data using the training-set statistics
maxT = trainx.max()
print(maxT)
minT = trainx.min()
print(minT)
trainx = (trainx-minT)/(maxT-minT)
testx = (testx-minT)/(maxT-minT)

trainx = torch.tensor(trainx)
trainy = torch.tensor(trainy)
testx = torch.tensor(testx)
testy = torch.tensor(testy)
# Add a channel dimension; Variable is a no-op since PyTorch 0.4, plain tensors suffice
trainx = Variable(torch.unsqueeze(trainx, dim=1).float(), requires_grad=False)
trainy = Variable(torch.unsqueeze(trainy, dim=1).float(), requires_grad=False)
testx = Variable(torch.unsqueeze(testx, dim=1).float(), requires_grad=False)
testy = Variable(torch.unsqueeze(testy, dim=1).float(), requires_grad=False)
print(trainx.shape)
print(testx.shape)
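
The training loop further down iterates over a loader named trainload, which the original post never constructs. A minimal sketch, with batch_size=64 as an assumption, wraps the tensors above in a TensorDataset:

train_dataset = Data.TensorDataset(trainx, trainy)
trainload = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataset = Data.TensorDataset(testx, testy)
testload = DataLoader(test_dataset, batch_size=64, shuffle=False)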


Next, build the student model and initialize it as needed. If the student network performs too poorly, consider making it deeper.

class studentNet(nn.Module):
    def __init__(self):
        super(studentNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.pool1 = nn.MaxPool2d(2, 1)
        # 6*25*25 assumes 28x28 inputs: conv 3x3 -> 26x26, pool (k=2, s=1) -> 25x25
        self.fc3 = nn.Linear(6*25*25, 2)
    def forward(self,x):
        x = self.conv1(x)
        x = self.pool1(F.relu(x))
        x = x.view(x.size()[0],-1)
        x = self.fc3(x)
        return x
    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight.data, 0, 0.01)
                m.bias.data.zero_()
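
As a quick sanity check (an addition, not in the original post), the 6*25*25 input size of fc3 implies 28x28 single-channel images; a dummy forward pass confirms the shapes:

model = studentNet()
model.initialize_weights()
dummy = torch.zeros(1, 1, 28, 28)  # (batch, channels, height, width)
print(model(dummy).shape)          # expected: torch.Size([1, 2])
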
# Build the teacher model: a trimmed ResNet-style network with attention
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """定义BasicBlock残差块类
        
        参数:
            inplanes (int): 输入的Feature Map的通道数
            planes (int): 第一个卷积层输出的Feature Map的通道数
            stride (int, optional): 第一个卷积层的步长
            downsample (nn.Sequential, optional): 旁路下采样的操作
        注意:
            残差块输出的Feature Map的通道数是planes*expansion
        """
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
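
TeacherNet below calls ChannelAttention and SpatialAttention, which the original post never defines. A minimal CBAM-style sketch (after Woo et al., 2018, "CBAM: Convolutional Block Attention Module") that matches how they are used below; the ratio and kernel_size values are assumptions:

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # max(..., 1) keeps the bottleneck valid even when in_planes=1, as here
        self.fc1 = nn.Conv2d(in_planes, max(in_planes // ratio, 1), 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(max(in_planes // ratio, 1), in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Channel weights from both average- and max-pooled descriptors
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        return self.sigmoid(avg_out + max_out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Spatial weights from channel-wise mean and max maps
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        return self.sigmoid(self.conv1(torch.cat([avg_out, max_out], dim=1)))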

class TeacherNet(nn.Module):

    def __init__(self, block, layers, num_classes=2, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(TeacherNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 1  # single-channel stem for the grayscale input
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        # Add an attention block after the first layer of the network
        self.ca = ChannelAttention(self.inplanes)
        self.sa = SpatialAttention()

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       )
        # layer3 and layer4 are disabled to keep the teacher compact
        # self.layer3 = self._make_layer(block, 256, layers[2], stride=1)
        # self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Add attention after the last convolutional stage
        self.ca1 = ChannelAttention(self.inplanes)
        self.sa1 = SpatialAttention()

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(128 * block.expansion, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                # Bottleneck (from torchvision's ResNet) is not defined in this
                # post; this branch only runs when zero_init_residual=True
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)


    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.ca(x) * x
        x = self.sa(x) * x

        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        #x = self.layer3(x)
        #x = self.layer4(x)

        x = self.ca1(x) * x
        x = self.sa1(x) * x


        x = self.avgpool(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = F.log_softmax(x, dim=1)  # the teacher outputs log-probabilities, not raw logits

        return x


# Load the previously trained TeacherNet weights
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
teach_model = TeacherNet(BasicBlock, [2, 2, 2, 2]).to(device)
# The checkpoint stores a whole model object, so read its state_dict;
# strict=False tolerates missing/unexpected keys
teach_model.load_state_dict(torch.load('F:/模型/teacherNet.pkl').state_dict(), strict=False)
teach_model.eval()  # freeze the teacher's BatchNorm statistics during distillation
# Build the student model and initialize the hyperparameters

model = studentNet().to(device)

criterion = nn.CrossEntropyLoss()
# reduction='batchmean' matches the mathematical definition of KL divergence
criterion2 = nn.KLDivLoss(reduction='batchmean')

optimizer = optim.Adam(model.parameters(),lr = 0.0001)

correct_ratio = []
alpha = 0.5
# Training loop: loss = (1 - alpha) * CE(student, labels) + alpha * T^2 * KL(student soft || teacher soft)

for epoch in range(200):
    loss_sigma = 0.0
    correct = 0.0
    total = 0.0
    for i, data in enumerate(trainload):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.squeeze().long().to(device)
        optimizer.zero_grad()
        
        outputs = model(inputs.float())
        loss1 = criterion(outputs, labels)
        
        # Soft targets from the frozen teacher; no gradients needed here
        with torch.no_grad():
            teacher_outputs = teach_model(inputs.float())
        T = 2  # distillation temperature
        outputs_S = F.log_softmax(outputs/T, dim=1)
        # Note: TeacherNet's forward already applies log_softmax, so these are
        # log-probabilities rather than raw logits being re-softened
        outputs_T = F.softmax(teacher_outputs/T, dim=1)
        loss2 = criterion2(outputs_S, outputs_T)*T*T  # scale by T^2 (Hinton et al., 2015)
        
        # Weighted sum of the hard-label loss and the distillation loss
        loss = loss1*(1-alpha) + loss2*alpha
        loss.backward()
        optimizer.step()
        
        _, predicted = torch.max(outputs.data, dim = 1)
        total += labels.size(0)
        correct += (predicted.cpu() == labels.cpu()).squeeze().sum().item()
        loss_sigma += loss.item()
        if i % 100 == 0:
            loss_avg = loss_sigma / 100  # average over the last 100 batches
            loss_sigma = 0.0
            print('loss_avg:{:.2}   Acc:{:.2%}'.format(loss_avg, correct/total))
            print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch, i * len(inputs), len(trainload.dataset),
                100. * i / len(trainload), loss.item()
            ))
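
The correct_ratio list initialized above is never filled in the original post. A sketch of a per-epoch test-set evaluation, assuming the testload loader built earlier, could close out each epoch (indented to sit inside the epoch loop):

    model.eval()
    with torch.no_grad():
        test_correct, test_total = 0, 0
        for inputs, labels in testload:
            inputs = inputs.to(device)
            labels = labels.squeeze().long().to(device)
            outputs = model(inputs.float())
            _, predicted = torch.max(outputs, dim=1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()
    correct_ratio.append(test_correct / test_total)
    print('Epoch {}: test Acc {:.2%}'.format(epoch, test_correct / test_total))
    model.train()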

Source: https://blog.csdn.net/qq_43360777/article/details/106383210