pytorch:load_state_dict

       在训练比较大、耗时较久的网络时,如果突然停电、断网或者一些意外情况发生导致训练中断,那么已经训练好的内容可能全部丢失,这时我们就需要在训练过程中把一些时间点的checkpoint保存下来,及时训练意外中断,那么我们也可以在之后把这些checkpoint下载下来,重新开始训练。
(谁能想到我刚刚码好这段话就停电了呢????)
在这里插入图片描述



以下内容大部分和

cifar-10+resnet.
一样,重点在load_state_dict的,可以直接跳转:

戳这里↓

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
trans = transforms.Compose((transforms.Resize(32),transforms.ToTensor()))
cifar_train = datasets.CIFAR10('cifar',train = True,transform=trans)

cifar_train_batch = DataLoader(cifar_train,batch_size=30,shuffle=True)
cifar_test = datasets.CIFAR10('cifar',train = False,transform=trans)

cifar_test_batch = DataLoader(cifar_test,batch_size=30,shuffle=True)
#搭建resnet
class resblock(nn.Module):
    def __init__(self,ch_in,ch_out,stride):
        super(resblock,self).__init__()
        self.conv_1 = nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)
        self.bn_1 = nn.BatchNorm2d(ch_out)
        self.conv_2 = nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)
        self.bn_2 = nn.BatchNorm2d(ch_out)
        self.ch_in,self.ch_out,self.stride = ch_in,ch_out,stride
        self.ch_trans = nn.Sequential()
        if ch_in != ch_out:
            self.ch_trans = nn.Sequential(nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),nn.BatchNorm2d(self.ch_out))
        #ch_trans表示通道数转变。因为要做short_cut,所以x_pro和x_ch的size应该完全一致
        
    def  forward(self,x):
        x_pro = F.relu(self.bn_1(self.conv_1(x)))
        x_pro = self.bn_2(self.conv_2(x_pro))
        
        #short_cut:
        x_ch = self.ch_trans(x)
        out = x_pro + x_ch
        return out     
class resnet(nn.Module):
    def __init__(self):
        super(resnet,self).__init__()
        self.conv_1 = nn.Sequential(
        nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1),
        nn.BatchNorm2d(64))
        self.block1 = resblock(64,128,2) #长宽减半 32/2=16
        self.block2 = resblock(128,256,2) #长宽再减半 16/2=8
        self.block3 = resblock(256,512,1)
        self.block4 = resblock(512,512,1)
        self.outlayer = nn.Linear(512,10) #512*8*8=32768
        
    def forward(self,x):
        x = F.relu(self.conv_1(x))
        x = self.block1(x) 
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = F.adaptive_avg_pool2d(x,[1,1])
        x = x.reshape(x.size(0),-1)
        result = self.outlayer(x)
        return result
        
device = torch.device('cuda')
net = resnet()
net = net.to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(net.parameters(),lr=1e-3)

load_state_dict

#开始训练
for epoch in range(5):
    for batchidx,(x,label) in enumerate(cifar_train_batch):
        x,label = x.to(device),label.to(device) #x.size (bcs,3,32,32) label.size (bcs)
        logits = net.forward(x)
        loss = loss_fn(logits,label) #logits.size:bcs*10,label.size:bcs
        
        #开始反向传播:
        optimizer.zero_grad()
        loss.backward() #计算gradient
        optimizer.step() #更新参数
        if (batchidx+1)%400 == 0:
            print('这是本次迭代的第{}个batch'.format(batchidx+1))  #本例中一共有50000张照片,每个batch有30张照片,所以一个epoch有1667个batch
            '''
            就是这里!!!每400个batch就存一次checkpoint,
            存到指定的文件,这里我设的是一个TXT文件
            '''
            torch.save(net.state_dict(),"./resnet_ckp.txt")
    
    print('这是第{}迭代,loss是{}'.format(epoch+1,loss.item()))
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第1迭代,loss是0.7839802503585815
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第2迭代,loss是1.0195786952972412
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第3迭代,loss是0.5244616866111755
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第4迭代,loss是0.6468905210494995
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第5迭代,loss是0.8967750668525696
net.eval()
with torch.no_grad():
    correct_num = 0
    total_num = 0
    batch_num = 0
    for x,label in cifar_test_batch: #x的size是30*3*32*32(30是batch_size,3是通道数),label的size是30.
                                     #cifar_test中一共有10000张照片,所以一共有334个batch,因此要循环334次
        x,label = x.to(device),label.to(device) 
        logits = net.forward(x)
        pred = logits.argmax(dim=1)
        correct_num += torch.eq(pred,label).float().sum().item()
        total_num += x.size(0)
        batch_num += 1
        if batch_num%50 == 0:
            print('这是测试集上的第{}个batch'.format(batch_num)) #一共有10000/30≈334个batch
            
    acc = correct_num/total_num  #最终的total_num是10000
    print('测试集上的准确率为:',acc)
这是测试集上的第50个batch
这是测试集上的第100个batch
这是测试集上的第150个batch
这是测试集上的第200个batch
这是测试集上的第250个batch
这是测试集上的第300个batch
测试集上的准确率为: 0.7628
  • 我们刚刚训练resnet的时候,只设置了5个epoch,
  • 如果我们想在已经训练了5个epochs的基础上继续训练,那么我们就可以把之前存下来的checkpoint加载下来继续开始训练~
continue_net = resnet().to(device)
para_in_last_net = torch.load('./resnet_ckp.txt') #把之前网络的参数下载到para_in_last_net中

continue_net.load_state_dict(para_in_last_net) #把para_in_last_net加载到continue_net中
#其实这两步可以合到一起写 continue_net.load_state_dict(torch.load('./resnet_ckp.txt'))
#然后我们再训练2个epoch(这次我们就不保存checkpoint了)
for epoch in range(2):
    for batchidx,(x,label) in enumerate(cifar_train_batch):
        x,label = x.to(device),label.to(device) #x.size (bcs,3,32,32) label.size (bcs)
        logits = continue_net.forward(x)
        loss = loss_fn(logits,label) #logits.size:bcs*10,label.size:bcs
        
        #开始反向传播:
        optimizer.zero_grad()
        loss.backward() #计算gradient
        optimizer.step() #更新参数
        if (batchidx+1)%400 == 0:
            print('这是本次迭代的第{}个batch'.format(batchidx+1))  #本例中一共有50000张照片,每个batch有30张照片,所以一个epoch有1667个batch
#             torch.save(net.state_dict(),"./resnet_ckp.txt")
    
    print('这是第{}迭代,loss是{}'.format(epoch+1,loss.item()))
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第1迭代,loss是0.809099018573761
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第2迭代,loss是0.43857383728027344
continue_net.eval()
with torch.no_grad():
    correct_num = 0
    total_num = 0
    batch_num = 0
    for x,label in cifar_test_batch: #x的size是30*3*32*32(30是batch_size,3是通道数),label的size是30.
                                     #cifar_test中一共有10000张照片,所以一共有334个batch,因此要循环334次
        x,label = x.to(device),label.to(device) 
        logits = continue_net.forward(x)
        pred = logits.argmax(dim=1)
        correct_num += torch.eq(pred,label).float().sum().item()
        total_num += x.size(0)
        batch_num += 1
        if batch_num%50 == 0:
            print('这是测试集上的第{}个batch'.format(batch_num)) #一共有10000/30≈334个batch
            
    acc = correct_num/total_num  #最终的total_num是10000
    print('测试集上的准确率为:',acc)
这是测试集上的第50个batch
这是测试集上的第100个batch
这是测试集上的第150个batch
这是测试集上的第200个batch
这是测试集上的第250个batch
这是测试集上的第300个batch
测试集上的准确率为: 0.7754
  • 可以看到,又训练了两个epoch后,准确率提升了一点点(之前5个epoch是0.7682)

发布了43 篇原创文章 · 获赞 1 · 访问量 741

猜你喜欢

转载自blog.csdn.net/weixin_41391619/article/details/104939067
今日推荐