torch.nn.Conv2d

class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

主要参数:Conv2d(

           in_channels: 输入通道数(输入channels),

           out_channels: 输出通道数(输出channels),

           kernel_size: 卷积核大小)

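备注:将卷积层作为属性注册到 nn.Module(即 model)中后,model 的 state_dict 会自动包含该层的参数.

     每定义一个卷积层(默认 bias=True),会在 model 的 state_dict 中自动生成两个 tensor 参数:

         1)conv2d.weight   shape=[输出channels, 输入channels/groups, kernel_size, kernel_size](默认 groups=1 时即为 [输出channels, 输入channels, kernel_size, kernel_size])

         2)conv2d.bias     shape=[输出channels](若 bias=False,则不会生成该参数)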

举例:

# Define a convolutional layer conv1: 5 input channels, 10 output channels, 3x3 kernel
conv1 = torch.nn.Conv2d(in_channels=5, out_channels=10, kernel_size=3)
# 则在 conv1 的 state_dict 中(若 conv1 作为子模块注册到 model 中,也会出现在 model 的 state_dict 中)会自动保存以下参数
conv1.weight    torch.Size([10, 5, 3, 3])
conv1.bias      torch.Size([10])

示例2:

扫描二维码关注公众号,回复: 3659434 查看本文章
#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim



# define model
class TheModelClass(nn.Module):
    """Small LeNet-style CNN used to illustrate the contents of ``state_dict``.

    The layer sizes are set up for 3-channel 32x32 inputs. Each 5x5 convolution
    (no padding) is followed by a 2x2 max pool, so the spatial size goes
    32 -> 28 -> 14 -> 10 -> 5. This leaves 16 feature maps of 5x5 = 400
    features for ``fc1``.

    The attribute names (conv1, conv2, fc1, ...) become the prefixes of the
    ``state_dict`` keys, e.g. ``conv1.weight``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)       # 3 -> 6 channels, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)        # 2x2 window, stride 2; reused after both convs
        self.conv2 = nn.Conv2d(6, 16, 5)      # 6 -> 16 channels, 5x5 kernel
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return unnormalized class scores of shape (N, 10).

        Args:
            x: input batch, expected shape (N, 3, 32, 32).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten every dimension except the batch dimension. Unlike
        # view(-1, 400), this never changes the batch size, so an input of
        # the wrong spatial size raises in fc1 instead of being silently
        # reshaped into a different number of samples.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)  # no activation: raw scores (logits)

# Instantiate the model.
model = TheModelClass()

# Initialize the optimizer: SGD over all model parameters, lr=0.001, momentum=0.9.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print the model's state_dict: one line per entry with its tensor shape.
# state_dict() rebuilds the dict on each call, so call it once and iterate
# its items instead of calling it again for every key.
print("model's state_dict:")
for param_name, param_value in model.state_dict().items():
    print(param_name, '\t', param_value.size())

# Print the optimizer's state_dict (its 'state' and 'param_groups' entries).
print("\noptimizer's state_dict")
for var_name, var_value in optimizer.state_dict().items():
    print(var_name, '\t', var_value)

# Print one specific parameter: conv1's weight shape, then its values.
print("\nprint particular param")
print('\n', model.conv1.weight.size())
print('\n', model.conv1.weight)

输出:

model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])

optimizer's state_dict
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [139985995931368, 139984655959456, 139984655959600, 139984655959672, 139984655959744, 139984655959816, 139984655959888, 139984655959960, 139984655960032, 139984655960104]}]

print particular param

 torch.Size([6, 3, 5, 5])

 Parameter containing:
tensor([[[[-0.0080, -0.0503,  0.0092, -0.1068, -0.0789],
          [-0.1028, -0.0067, -0.1015, -0.0660,  0.1107],
          [ 0.0733,  0.0195, -0.0236,  0.0244,  0.0168],
          [-0.0310, -0.0915,  0.0267, -0.0465, -0.0112],
          [-0.0876, -0.0579, -0.0689, -0.0397, -0.1020]],

         [[ 0.0148, -0.0605, -0.0428, -0.0280, -0.0038],
          [-0.0452,  0.0938,  0.0793, -0.0857,  0.0700],
          [-0.0463, -0.0326, -0.0130,  0.0460,  0.0138],
          [ 0.1144,  0.0173, -0.0178, -0.0745,  0.0625],
          [ 0.0713,  0.0400, -0.0596, -0.0878, -0.0773]],

         [[ 0.0782,  0.0849, -0.0777,  0.0770, -0.0115],
          [-0.0918, -0.0262,  0.0067,  0.0481,  0.0812],
          [ 0.0411, -0.1067,  0.0187,  0.0250,  0.0964],
          [ 0.0076,  0.0715, -0.0559,  0.0888, -0.0787],
          [-0.0894,  0.0258,  0.1001, -0.0621, -0.0245]]],


        [[[-0.0464, -0.0124, -0.0204, -0.0179,  0.0263],
          [ 0.1148,  0.0955, -0.0630,  0.0382, -0.0889],
          [ 0.1114,  0.0027, -0.0478, -0.0857, -0.0735],
          [ 0.0446,  0.0893, -0.0671,  0.0066, -0.0356],
          [-0.1027,  0.0593, -0.0410, -0.0647,  0.0377]],

         [[-0.0145,  0.0259, -0.0488, -0.1128, -0.0441],
          [-0.0269, -0.0213,  0.0958, -0.0159, -0.1011],
          [ 0.0614, -0.0445, -0.0642, -0.0092,  0.0317],
          [ 0.0399, -0.0608, -0.0156,  0.1112,  0.0865],
          [ 0.0679, -0.0030,  0.0948,  0.0804, -0.0644]],

         [[ 0.0625,  0.0002, -0.0690,  0.0803, -0.0091],
          [ 0.0073,  0.1063,  0.0663,  0.0094, -0.0997],
          [-0.0938,  0.0973, -0.0571, -0.0281, -0.0008],
          [ 0.0502, -0.0266, -0.0459, -0.0831,  0.0589],
          [ 0.1062,  0.0144,  0.0318,  0.0814,  0.0641]]],


        [[[ 0.0706,  0.0121, -0.0918,  0.0571, -0.0780],
          [ 0.0068,  0.0786, -0.0118,  0.0070,  0.0367],
          [-0.0983, -0.0742,  0.0878,  0.1115, -0.0342],
          [ 0.0682, -0.1151,  0.0689, -0.1039, -0.0854],
          [-0.0185,  0.0474, -0.0282, -0.0707, -0.0105]],

         [[-0.0562,  0.0887,  0.0002,  0.0974,  0.1088],
          [-0.0568,  0.0291,  0.0522, -0.0791, -0.0136],
          [ 0.0480,  0.0764,  0.1015,  0.0315, -0.0715],
          [ 0.0078,  0.1052,  0.0647, -0.0707, -0.0269],
          [-0.0742,  0.1057,  0.0410,  0.0867, -0.0098]],

         [[-0.0847,  0.0005,  0.0210,  0.1104, -0.0865],
          [ 0.0424, -0.0321, -0.0856,  0.0761, -0.1053],
          [-0.0995,  0.0792,  0.0428,  0.0239,  0.0532],
          [-0.0705,  0.0683, -0.0691,  0.0287, -0.0657],
          [-0.0518, -0.0395,  0.0270,  0.0997, -0.0581]]],


        [[[ 0.0071,  0.1119,  0.0198,  0.0697,  0.0853],
          [-0.0718, -0.0216, -0.0026,  0.0939,  0.0791],
          [ 0.0584, -0.0262,  0.0226,  0.0166, -0.0898],
          [ 0.1004, -0.0992,  0.0630,  0.0591,  0.0152],
          [-0.0731, -0.0343,  0.0821,  0.0518, -0.0257]],

         [[-0.0847,  0.1124, -0.0815, -0.0989,  0.0975],
          [ 0.0750, -0.0998, -0.0341,  0.0603,  0.0299],
          [ 0.0504, -0.0782, -0.0870,  0.0940, -0.0717],
          [-0.0387,  0.1046, -0.0216,  0.0870, -0.0550],
          [-0.0772,  0.0888,  0.0341,  0.0018,  0.0923]],

         [[-0.0257, -0.0024, -0.0461,  0.0309, -0.0204],
          [ 0.0782, -0.1152, -0.1073, -0.0128, -0.1088],
          [ 0.0238,  0.0951, -0.1048,  0.1055,  0.1090],
          [ 0.0984, -0.0634,  0.0864,  0.1067, -0.1024],
          [-0.0499,  0.1054,  0.0025, -0.0640, -0.0089]]],


        [[[-0.0263,  0.0849, -0.0872, -0.0457, -0.1010],
          [-0.0327,  0.0176, -0.0301,  0.0329,  0.0561],
          [-0.0325,  0.0409, -0.0862,  0.0603, -0.0904],
          [-0.0352,  0.0723,  0.0955, -0.0478, -0.1055],
          [-0.0711, -0.0076, -0.0725, -0.0856,  0.0413]],

         [[ 0.0999, -0.0613, -0.0390, -0.1126,  0.0182],
          [ 0.0302,  0.0699,  0.0263,  0.0594,  0.0965],
          [-0.0062,  0.0779,  0.0010,  0.0617,  0.0596],
          [ 0.0058, -0.0344,  0.0266, -0.0754, -0.0667],
          [ 0.0120,  0.1121, -0.0693,  0.0516,  0.0863]],

         [[-0.0897, -0.0838, -0.0126,  0.0938,  0.0570],
          [ 0.0729,  0.0482,  0.0066,  0.0559, -0.0951],
          [ 0.0750,  0.0592,  0.0550,  0.0671,  0.0661],
          [-0.1132, -0.0496, -0.0931,  0.0659, -0.0453],
          [ 0.0177,  0.0018,  0.0622,  0.0571,  0.1092]]],


        [[[ 0.0697,  0.0629,  0.0071,  0.0266,  0.0199],
          [-0.1087,  0.1084,  0.0488, -0.0162,  0.1147],
          [-0.0944, -0.1005, -0.0494,  0.0163, -0.0477],
          [ 0.0199, -0.0245,  0.0768, -0.0319, -0.0087],
          [ 0.0823,  0.1125, -0.0000, -0.0238, -0.0647]],

         [[ 0.0107, -0.0313, -0.0060,  0.0010,  0.0102],
          [-0.0748,  0.0240, -0.0658, -0.0524,  0.0908],
          [-0.0921, -0.1004, -0.0492,  0.0021,  0.0020],
          [-0.1136,  0.0122,  0.0324,  0.0125,  0.0843],
          [-0.0888,  0.0573,  0.0286,  0.0672,  0.0266]],

         [[-0.0215, -0.0275, -0.0994,  0.1052,  0.1087],
          [ 0.0008, -0.1082, -0.0890,  0.0155,  0.0612],
          [ 0.0211,  0.0042, -0.0483,  0.0919, -0.1100],
          [-0.0703, -0.0263, -0.0256, -0.0122, -0.0594],
          [-0.0150, -0.0508, -0.0393, -0.1073,  0.0849]]]],
       requires_grad=True)

参考:https://pytorch.org/docs/stable/nn.html#conv2d

        https://pytorch.org/tutorials/beginner/saving_loading_models.html

 

猜你喜欢

转载自blog.csdn.net/Strive_For_Future/article/details/83240232