# PyTorch: load only part of a model's parameters from a saved checkpoint
# (pytorch加载模型中的部分参数)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets,transforms
from torch.optim import lr_scheduler
import torch.optim as optim

class VGG(nn.Module):
    """Small VGG-style CNN: three conv/ReLU/max-pool stages followed by three FC layers.

    Every 3x3 convolution uses padding 1, so it preserves the spatial size.
    Every 2x2 max-pool halves it, rounding down. An input of spatial size
    ``H x W`` therefore reaches ``fc1`` as ``256 * (H // 8) * (W // 8)``
    features. Only the first stage applies batch normalisation.

    Args:
        num_classes: number of logits produced by ``fc3`` (default 10).
        input_size: ``(height, width)`` of the input images. It is used only
            to size ``fc1``. The default ``(128, 64)`` reproduces the
            original hard-coded ``256 * 16 * 8``.

    Submodule names (``conv1``, ``bn1``, ..., ``fc3``) and their registration
    order are unchanged, so state dicts saved from earlier versions of this
    model still load with the defaults.
    """

    def __init__(self, num_classes=10, input_size=(128, 64)):
        super().__init__()
        height, width = input_size

        self.conv1 = nn.Conv2d(3, 64, 3, padding=(1, 1))
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool1 = nn.MaxPool2d((2, 2))

        self.conv2 = nn.Conv2d(64, 128, 3, padding=(1, 1))
        self.maxpool2 = nn.MaxPool2d((2, 2))

        self.conv3 = nn.Conv2d(128, 256, 3, padding=(1, 1))
        self.maxpool3 = nn.MaxPool2d((2, 2))

        # Three 2x2 pools shrink each spatial dimension by a factor of 8 (floored).
        self.fc1 = nn.Linear(256 * (height // 8) * (width // 8), 4096)
        self.fc2 = nn.Linear(4096, 1000)
        self.fc3 = nn.Linear(1000, num_classes)

    def forward(self, x):
        """Return logits of shape ``(N, num_classes)``.

        Assumes ``x`` is ``(N, 3, H, W)``, with ``H x W`` matching the
        ``input_size`` the model was built for.
        """
        out = self.maxpool1(F.relu(self.bn1(self.conv1(x))))
        out = self.maxpool2(F.relu(self.conv2(out)))
        out = self.maxpool3(F.relu(self.conv3(out)))

        # Flatten everything except the batch dimension. Unlike
        # view(size(0), -1), this also works for an empty batch.
        out = torch.flatten(out, 1)

        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        return self.fc3(out)

# Training-time preprocessing pipeline:
#   1. resize to 256x128 (height x width),
#   2. take a random 128x64 crop,
#   3. convert to a tensor,
#   4. normalise each channel with the commonly used ImageNet mean/std values.
# The 128x64 crop is the input size VGG.fc1 (256*16*8 inputs) is sized for.
transform_train_list = transforms.Compose(
    [
        transforms.Resize((256, 128), interpolation=3),  # 3 == PIL BICUBIC
        transforms.RandomCrop((128, 64)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],  # per-channel mean
            [0.229, 0.224, 0.225],  # per-channel std
        ),
    ]
)

# Images are read from ./train using torchvision's ImageFolder directory layout.
train_dataset = datasets.ImageFolder('./train', transform=transform_train_list)
# NOTE(review): this loader is not iterated anywhere in the visible code.
dataloaders = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=2,
    num_workers=0,
)


# Load a previously saved state dict (parameter/buffer name -> tensor) and list
# its keys. map_location="cpu" lets a checkpoint saved from a GPU model load on
# a CPU-only machine; model1 below is also on the CPU.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
pre_model_dict = torch.load("first.pth", map_location="cpu")
print(pre_model_dict['conv1.bias'])
for k in pre_model_dict:
    print(k)

# Build a fresh model and overwrite a single entry (conv1.bias) of its state
# dict with the pretrained value. Every other parameter keeps its fresh
# initialisation. load_state_dict copies the values into the model.
model1 = VGG()
model1_dict = model1.state_dict()
model1_dict['conv1.bias'] = pre_model_dict['conv1.bias']
model1.load_state_dict(model1_dict)
# Read the bias back from the model itself. model1_dict['conv1.bias'] is the
# tensor assigned above, so printing it would not show whether the model was
# actually updated.
print(model1.state_dict()['conv1.bias'])
for k in model1_dict:
    print(k)
print()

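##
# 发布了36 篇原创文章 · 获赞 11 · 访问量 6550
# (blog footer: 36 original articles published · 11 likes · 6550 views)
#
# 猜你喜欢 ("You may also like")
#
# 转载自 (reposted from) blog.csdn.net/t20134297/article/details/103534368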