龙良曲 PyTorch Study Notes 10

The LeNet5 network and the CIFAR10 dataset

main function -- dataloader -- train -- test

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optim
from lenet5 import LeNet5

def main():
    batch_size = 32
    cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)

    # the DataLoader batches multiple images at a time
    cifar_train = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', train=False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)

    # the DataLoader batches multiple images at a time
    cifar_test = DataLoader(cifar_test, batch_size=batch_size, shuffle=True)

    # once the data is loaded, verify the shapes
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cuda')
    model = LeNet5().to(device)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    print(model)

    for epoch in range(1000):

        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b,3,32,32], label: [b]
            x, label = x.to(device), label.to(device)

            logits = model(x)
            # logits: [b,10]
            # label:  [b]
            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # loss of the last batch in this epoch
        print(epoch, loss.item())

        model.eval()
        # no gradient computation is needed during evaluation
        with torch.no_grad():
            # test
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                # logits: [b,10]
                logits = model(x)
                pred = logits.argmax(dim=1)
                # accumulate the number of correct predictions per batch
                total_correct += torch.eq(pred, label).float().sum().item()
                # x.size(0) is the batch size
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, acc)

if __name__ == '__main__':
    main()
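
As a quick sanity check of the shape annotations in the comments above (x: [b,3,32,32] --> logits: [b,10] --> scalar loss), a minimal sketch with one random batch on CPU might look like the following; it assumes the LeNet5 class from lenet5.py below and skips the dataset and DataLoader entirely:

import torch
from torch import nn
from lenet5 import LeNet5

# a minimal shape check with one fake batch on CPU (no dataset required)
model = LeNet5()
criteon = nn.CrossEntropyLoss()

x = torch.randn(4, 3, 32, 32)         # fake batch: [b,3,32,32]
label = torch.randint(0, 10, (4,))    # fake labels: [b]

logits = model(x)                     # [b,10]
loss = criteon(logits, label)         # scalar
pred = logits.argmax(dim=1)           # [b]
acc = torch.eq(pred, label).float().mean().item()
print(logits.shape, loss.item(), acc)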

LeNet5 network -- tmp test

import torch
from torch import nn
from torch.nn import functional as F

class LeNet5(nn.Module):
    """
    for cifar10 dataset.
    """
    def __init__(self):
        super(LeNet5, self).__init__()

        self.conv_unit = nn.Sequential(
            # x: [b,3,32,32] --> [b,6,28,28]
            # in_channels, out_channels, kernel_size, stride, padding
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            # [b,6,14,14] --> [b,16,5,5]
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
        )
        # Flatten
        # fc_unit
        self.fc_unit = nn.Sequential(
            # the input size 16*5*5 comes from the test below
            nn.Linear(16*5*5, 120),
            # fully connected layers can suffer from vanishing gradients, so add a ReLU
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )
        '''
        tmp = torch.randn(2, 3, 32, 32)
        out = self.conv_unit(tmp)
        # check the output shape, needed for the fully connected layer
        # [2,16,5,5]
        print('conv_out:', out.shape)
        '''

        # use Cross Entropy Loss
        # kept outside the class, so forward() does not need a y argument
        # self.criteon = nn.CrossEntropyLoss()

    # the computation flows left to right; backward() follows it automatically
    def forward(self, x):
        # read the shape of x; index 0 is the batch size
        batch_size = x.size(0)
        # [b,3,32,32] --> [b,16,5,5]
        x = self.conv_unit(x)
        # [b,16,5,5] --> [b,16*5*5]
        x = x.view(batch_size, 16*5*5)
        # [b,16*5*5] --> [b,10]
        logits = self.fc_unit(x)
        return logits
        # [b,10] CrossEntropyLoss already includes the softmax, so it is not needed here
        # pred = F.softmax(logits, dim=1)
        # loss = self.criteon(logits, y)


def main():

    net = LeNet5()
    tmp = torch.randn(2, 3, 32, 32)
    out = net(tmp)
    print('lenet_out:', out.shape)


if __name__ == '__main__':
    main()
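
The 16*5*5 input size of the first Linear layer can also be derived by hand: with padding 0 and stride 1, each 5x5 convolution shrinks the spatial size by 4 and each 2x2 average pool halves it, so 32 -> 28 -> 14 -> 10 -> 5. The standalone snippet below is a sketch equivalent to the commented-out tmp test in __init__ and confirms the shape:

import torch
from torch import nn

# the same layers as conv_unit above, built standalone to inspect the output shape
conv_unit = nn.Sequential(
    nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),   # 32 -> 28
    nn.AvgPool2d(kernel_size=2, stride=2, padding=0),       # 28 -> 14
    nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),   # 14 -> 10
    nn.AvgPool2d(kernel_size=2, stride=2, padding=0),        # 10 -> 5
)
out = conv_unit(torch.randn(2, 3, 32, 32))
print('conv_out:', out.shape)   # torch.Size([2, 16, 5, 5]); flatten size = 16*5*5 = 400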


Reposted from www.cnblogs.com/fxw-learning/p/12317504.html