RS 1: Notes of PyTorch Tutorials —— Transfer Learning

  1 """
  2 迁移学习:利用一个已经在其他训练集训练好的模型的权重或特征层来对目标训练集进行训练
  3 """
  4 # __future__模块用于把下一个版本的一些新特性导入当前版本,使得当前版本兼容这些新特性
  5 # import print_function:在python2版本中输出不需要加括号,但是在python3版本中需要
  6 # import division:在python2的代码中可以直接使用python3的除法
  7 from __future__ import print_function, division
  8 
  9 # 导入torch包
 10 import torch
 11 
 12 # 导入torch.nn包,其中主要包含了用来搭建各个层的模块(Modules),如全连接、二维卷积、池化等;
 13 # torch.nn包中还包含了一系列Loss函数,如CrossEntopyLoss、MSELoss等;
 14 # torch.nn.functional子包中包含了常用的激活函数,如relu、leaky_relu、prelu、sigmoid等。
 15 import torch.nn as nn 
 16 
 17 # 导入torch.optim包,用于定义损失函数和优化方法
 18 import torch.optim as optim
 19 
 20 # torch.optim.lr_scheduler提供了基于多种epoch数目调整学习率的方法,用于定义学习率的变化策略
 21 from torch.optim import lr_scheduler 
 22 
 23 import numpy as np 
 24 
 25 # torchvision包用于生成图片、视频数据集、一些流行的模型类和一些预训练模型
 26 # torchvision由四个部分组成:torchvision.datasets,torchvision.models,torchvision.transforms,torchvision.utils
 27 # torchvision.datasets包含多个数据集,其中ImageFolder是一种通用的data loader
 28 # torchvision.models提供了多个模型结构,如AlexNet、VGG、ResNet等
 29 # torchvision.transforms提供常见的图像变换(预处理)操作,这些变换可以用torchvision.transforms.Compose连接在一起
 30 # torchvision.utils包含了两个图像处理的工具:torchvision.utils.make_grid()和torchvision.utils.save_image()
 31 import torchvision
 32 from torchvision import datasets, models, transforms
 33 
 34 # 用plt进行画图
 35 import matplotlib.pyplot as plt 
 36 
 37 # 导入时间模块
 38 import time
 39 
 40 # 通过os模块调用系统命令
 41 import os
 42 
 43 # copy模块实现复制操作
 44 # python3中可以直接使用copy()方法,但deepcopy()仍然需要导入copy模块
 45 # copy()方法共享内存地址,一个子列表数据的变动会导致其他列表数据的变动
 46 # deepcopy()复制后分开存储两份数据,它们之间是否有变动互不影响
 47 import copy 
 48 
 49 
 50 # 打开交互模式(使matplotlib的显示模式转换为交互模式)
 51 plt.ion()
 52 
 53 
 54 data_transforms = {
 55     'train': transforms.Compose([
 56         transforms.RandomResizedCrop(224),  # 先将给定的PIL.Image随机切割,然后再resize成224*224
 57         transforms.RandomHorizontalFlip(),  # 以0.5的概率将给定的PIL.Image水平翻转
 58         # transforms.RandomVerticalFlip(),  # 以0.5的概率将给定的PIL.Image垂直翻转
 59         transforms.ToTensor(),              # 把一个取值范围是[0, 255]的PIL.Image或shape为(H, W, C)的numpy.ndarray转换成形状为[C, H, W],取值范围是[0, 1.0]的torch.FloatTensor
 60                                             # transforms.RandomResizedCrop和transforms.RandomHorizontalFlip()的输入对象都是PIL.Image,而transforms.Normalize()的作用对象需要
 61                                             # 是tensor,因此在transforms.Normalize()之前需要用transforms.ToTensor()将PIL.Image转换成tensor
 62         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
 63         # transforms.Normalize(mean, std):给定均值(R, G, B)和方差(R, G, B),将输入的Tensor归一化,即:Normalized_image = (image-mean)/std
 64     ]),
 65     'val': transforms.Compose([
 66         transforms.Resize(256),     # 将图片的大小进行缩放,以统一图片格式。如果输入int,则表示将输入图像的短边resize到这个int数,在保持图像的长宽比不变的前提下将对应的长边进行调整;
 67                                     # 如果输入是个(h, w)的序列(h和w都是int),则会直接将输入图像resize到这个(h, w)尺寸,相当于force resize,因此一般最后图像会被拉伸或压缩(长宽比发生变化)
 68         transforms.CenterCrop(224), # 以输入PIL.Image的中心点为中心,将其进行切割,得到224*224的图片
 69         transforms.ToTensor(),      # transforms.Resize和transforms.CenterCrop的输入对象都是PIL.Image,而transforms.Normalize()的作用对象需要
 70                                     # 是tensor,因此在transforms.Normalize()之前需要用transforms.ToTensor()将PIL.Image转换成tensor
 71         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
 72     ]),
 73 }
 74 
 75 
 76 # 数据导入
 77 # data_dir指定图像所在的文件夹,data_dir目录下一般包括两个文件夹:train和val,每个文件夹下面又包含N个子文件夹(N是分类的类别数目),且每个子文件夹里存放的就是这个类别的图像,
 78 # 此时torchvision.datasets.ImageFolder会返回一个列表(如下面代码中的image_datasets['train']或image_datasets['val']),列表中的每个值都是一个tuple,每个tuple都包含图像和标签信息。
 79 data_dir = 'data/hymenoptera_data'
 80 
 81 # image_datasets是对数据集进行处理
 82 image_datasets = {x: datasets.ImageFolder( # torchvision.datasets.ImageFolder(root="root folder path", [transforms, target_transfroms])接口实现数据导入
 83                                            # 需要提供图像所在的文件夹及数据的变换方式
 84     os.path.join(data_dir, x),             # os.path.join('path1', 'path2',...)用于将多个路径组合后返回:path1/path2/...
 85                                            # 这里返回的路径有两个:data/hymenoptera_data/train 和 data/hymenoptera_data/val
 86     data_transforms[x])                    # data_transforms是一个字典,进行图像预处理
 87     for x in ['train', 'val']}             # image_datasets最终返回的值是一个list,而list是不能作为模型输入的,因此在python中需要用另一个类来封装list,即下面代码中
 88                                            # 的torch.utils.data.DataLoader类
 89                                            # image_datasets = {'train': Dataset ImageFolder
 90                                            # Number of datapoints: 244
 91                                            # Root Location: data/hymenoptera_data/train
 92                                            # Transforms (if any): Compose(
 93                                            #                          RandomResizedCrop(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=PIL.Image.BILINEAR)
 94                                            #                          RandomHorizontalFlip(p=0.5)
 95                                            #                          ToTensor()
 96                                            #                          Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 97                                            #                      )
 98                                            # Target Transforms (if any): None, 'val': Dataset ImageFolder
 99                                            # Number of datapoints: 153
100                                            # Root Location: data/hymenoptera_data/val
101                                            # Transforms (if any): Compose(
102                                            #                          Resize(size=256, interpolation=PIL.Image.BILINEAR)
103                                            #                          CenterCrop(size=(224, 224))
104                                            #                          ToTensor()
105                                            #                          Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
106                                            #                      )
107                                            # Target Transforms (if any): None}
108 
109 # dataloaders实现对数据集的加载
110 dataloaders = {x: torch.utils.data.DataLoader( # torch.utils.data.DataLoader类将list型的输入数据封装成tensor数据格式,以供模型使用。它会将图像和标签分别封装成一个Tensor,因此
111                                                # 如果图像数据不是按照一个类别一个文件夹的方式存放时,需要自己定义一个类来读取数据,这个自定义的类必须继承自torch.utils.data.Dataset
112                                                # 这个基类,然后同样使用torch.utils.data.DataLoader封装成tensor
113                                                # DataLoader是数据加载器,它组合了数据集和采样器
114     image_datasets[x],                         # 要加载的数据集
115     batch_size=4,                              # 每个batch加载4个样本(默认是1)
116     shuffle=True,                              # 将shuffle设置为True时会在每个epoch时重新打乱数据(默认是False)
117     num_workers=4)                             # 指定用4个子进程加载数据,0表示数据将在主进程中加载(默认是0)
118     for x in ['train', 'val']}                 # dataloaders = {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f37e270cda0>, 
119                                                # 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f37e270ce80>}
120                                                # 因此dataloaders是一个字典,dataloaders['train']存的是训练数据,for循环的目的是从dataloaders['train']中读取batch_size个
121                                                # 数据,这些数据中同时包含了图像和它对应的标签
122 
123 dataset_sizes = {x: len(image_datasets[x])     # dataset_sizes = {train的图片数与val的图片数},即dataset_sizes = {'train': 244, 'val': 153}
124                 for x in ['train', 'val']}
125 
126 class_names = image_datasets['train'].classes  # class_names中保存训练集train中所有类别的名称,即class_names = ['ants', 'bees']
127 
128 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 通过torch.device指定将tensor放到GPU还是CPU上运行
129 
130 
131 
132 # Tensor数据可视化:visualize a few images
133 def imshow(inp, title=None):
134     inp = inp.numpy().transpose((1, 2, 0)) # 经make_grid函数处理之后,得到的out依旧是Tensor型的数据,而且它的三个维度依次是[C, W, H],但在imshow的时候要将其转换成numpy的ndarray,
135     mean = np.array([0.485, 0.456, 0.406]) # 并且数据的维度特征必须为[W, H, C],
136     std = np.array([0.229, 0.224, 0.225])  # 除此之外,
137     inp = std * inp + mean                 # 还要乘以方差并加上均值(即反规范化,因为之前预处理的时候是减去均值除以方差)
138     inp = np.clip(inp, 0, 1)               # numpy.clip(self, min=None, max=None, out=None)函数的作用是:将self中小于min的数全部替换为min,大于max的数全部替换为max,
139                                            # 在[min, max]之间的数保持不变,并返回修改后的和self形状一样的array。
140     plt.imshow(inp)
141     if title is not None:
142         plt.title(title)
143     plt.pause(10) # plt.pause()控制图片显示时间的长短,pause的时间越长,图片停留在屏幕上的时间就越久
144 
145 inputs, classes = next(iter(dataloaders['train'])) # 取出一个batch的训练数据:inputs.shape=[4, 3, 224, 224],classes.shape=[4]
146                                                    # classes的值只能取0和1,因为class_names只包含两个元素:class_names = ['ants', 'bees'],
147                                                    # 由此可见,在dataloaders中,类别的名称是以二值的形式进行存储的
148                                                    # iter(object[, sentinel])函数用来生成迭代器(意思就是它可以遍历object中的每个元素),其中的object是支持迭代的集合对象
149                                                    # 此处的dataloaders是一个dict:
150                                                    # dataloaders = {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f37e270cda0>, 
151                                                    # 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f37e270ce80>}
152                                                    # dataloaders['train'] = <torch.utils.data.dataloader.DataLoader object at 0x7feaa158aef0>,
153                                                    # 是Tensor类型的数据对象,而且包含两组Tensor:一组是图像,另一组是标签。它们是以batch的形式存放的(batch_size=4),
154                                                    # (可以将next函数理解为指针)因此,每次程序运行时,next函数都会指向下一个batch的位置,并将它们的值分别返回给inputs和classes
155                                                    # next函数所指向的对象必须是iterator型的可迭代对象
156                                                    # next(iterator[, default])函数用来返回迭代器的下一个项目,其中iterator是可迭代对象(可理解为可遍历对象)
157 out = torchvision.utils.make_grid(inputs) # torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False)
158                                           # 用于将若干幅图像拼成一幅图像
159                                           # 输入的tensor可以是形状为(B*C*H*W)的4D mini-batch Tensor,也可以是形状都相同的图片列表
160                                           # nrow控制每行显示多少个小图片,padding控制两个小图片之间间隔的距离
161                                           # out.shape=[3, 228, 906],可见,与inputs.shape=[4, 3, 224, 224]相比,它是将4张图片的宽度拼起来形成1张图片(996=224*4 + 100)
162 imshow(out, title=[class_names[x] for x in classes]) # e.g. classes = tensor([0, 1, 0, 1]),而class_names = ['ants', 'bees']
163 
164 
165 
166 
167 """
168  定义一个训练函数以实现以下功能:1.可以对学习率进行调控;2.寻找并保存最佳的模型
169 """
170 def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
171     since = time.time() # 计时开始(计时是在训练之前开始的)
172 
173     best_model_wts = copy.deepcopy(model.state_dict()) # 先深拷贝一份当前模型的参数,后面迭代过程中若遇到更优模型则替换
174     best_acc = 0.0 # 初始准确率
175 
176     # 轮训练 --> 批训练
177     for epoch in range(num_epochs):
178         print('Epoch {}/{}'.format(epoch, num_epochs - 1)) # epoch表示当前正在进行第几轮,num_epochs表示总的训练轮数
179         print('-' * 10) # 连续打印10个 - 符号,即----------
180 
181         for phase in ['train', 'val']: # 每个epoch中都有训练部分和验证部分
182             if phase == 'train':
183                 scheduler.step() # 在训练开始时先更新学习率(因为前面已经定义了学习率的变化策略,所以在每个epoch开始时都要进行更新)
184                 model.train()    # phase == 'train'时,将模型设置为训练模式(即将model中所有层的training都设置为True),
185             else:
186                 model.eval()     # 否则,将模型设置为验证模式(即将model中所有层的training都设置为False)
187 
188             running_loss = 0.0   # running_loss用来记录每轮训练的总loss
189             running_corrects = 0 # running_corrects用来记录每轮训练后分类正确的图片的个数
190 
191             for inputs, labels in dataloaders[phase]: # dataloaders[phase]中以每4个一批的方式存放了所有的图片数据,每次循环inputs和labels均以4个一批的方式取dataloaders[phase]
192                                                       # 中的数据,当循环结束时,dataloaders[phase]中的数据恰好以4个一批的方式被全部取出
193                                                       # 用逗号连接可以同时遍历inputs和labels两个变量
194                                                       # 若遍历的变量位于多个不同的数组中,则要用zip()函数将这些不同的地方封装成一个tuple,例如
195                                                       # for i, j, k, l, m in zip(I, J, K, L, M):
196                 inputs = inputs.to(device) # 将数据放到相应的设备上
197                                            # inputs的形状为[4, 3, 224, 224],是一个batch的数据,相当于每轮只会从dataloaders中挑出一个batch的数据进行训练,直至遍历完所有batch
198                 labels = labels.to(device) # labels的形状为[4],是形如tensor([0, 1, 1, 0])的数据
199 
200                 optimizer.zero_grad() # 在反向传播之前需要先将优化器中的梯度值清零,因为在默认情况下反向传播的梯度值会进行累加
201 
202                 with torch.set_grad_enabled(phase == 'train'): # torch.set_grad_enabled(bool),当bool=True时,with语句块中涉及到的所有变量的requires_grad属性都将被设置为True,
203                                                                # 否则,requires_grad都将被设置为False。这一设置不会影响到with语句块之外的变量
204                                                                # with语句适用于事先需要设置且事后需要进行清理的任务
205                                                                # with语句可以保证不管处理过程中是否发生错误或者异常都会执行规定的__exit__操作,释放被访问的资源
206                     outputs = model(inputs)           # 正向传播
207                     _, preds = torch.max(outputs, 1) # 返回第1维(即每一行)中元素的最大值,且返回其索引(该最大值元素在这一行的列索引)
208                                                      # 因此preds中保存的是每一行中最大值元素的列索引,其输出形式如 tensor([1, 1, 1, 1])
209                                                      # torch.max()[0]  只返回最大值的每个数
210                                                      # torch.max()[1]  只返回最大值的每个索引
211                     loss = criterion(outputs, labels) # 将输出的outputs和原来导入的labels作为loss函数的输入即可得到损失
212                                                       # criterion被定义为交叉熵损失函数
213 
214                     if phase == 'train': # 计算得到loss后就要回传损失,因为这是在训练过程中才会有的操作(测试的时候只有forward过程),所以要加上条件判断 if phase == 'train'
215                         loss.backward()  # 反向传播,计算损失函数对于网络参数的梯度值
216                         optimizer.step() # 反向传播过程中要计算梯度,然后根据这些梯度更新参数,这一过程可以通过optimizer.step()实现
217                                          # optimizer.step()后,就可以从optimizer.param_groups[0]['params']中看到各个层的梯度和权值信息
218                 
219                 running_loss += loss.item() * inputs.size(0) # PyTorch 0.4.0的loss是一个零维的张量,使用loss.item()可以从张量中获取Python数字
220                                                              # 如果在累加损失时未将loss转换为Python数字,则可能会出现程序内存使用量增加的情况,因为
221                                                              # 对张量进行累加的同时也会累加它的梯度历史,可能会产生很大的autograd图,耗费内存和计算资源
222                                                              # loss的输出形如tensor(0.8567, grad_fn=<NllLossBackward>)
223                                                              # loss.item()的输出形如0.8567252159118652
224                                                              # inputs.size(0)输出为4
225                 running_corrects += torch.sum(preds == labels.data) # tensor.data返回和原tensor相同的tensor数据,但不会加入到原tensor的计算历史中
226                                                                     # preds的输出形如tensor([1, 1, 1, 1])
227                                                                     # labels的输出形如tensor([0, 1, 0, 1])
228                                                                     # labels.data的输出也形如tensor([0, 1, 0, 1])
229                                                                     # torch.sum函数可以统计preds和labels.data中对应位置相同的元素的对数,并返回一个tensor型的数据
230                                                                     # 如:torch.sum(tensor([0,1,1,0]) == tensor([1,1,1,0]))=tensor(3)  
231                                                                     #    torch.sum(tensor([0,1,1,0]) == tensor([1,0,0,1]))=tensor(0)    
232                 # 每次循环,running_loss和running_corrects都会将相应批次的总loss和总正确数目累加至上一次的running_loss和running_correct中,
233             # 当整个一轮的训练结束时,running_loss记录的就是该轮的总loss,running_corrects记录的就是该轮的总正确数目  
234             epoch_loss = running_loss / dataset_sizes[phase] # dataset_sizes['train']是训练集中的图片数目,dataset_sizes['val']是测试集中的图片数目,
235                                                              # 此时running_loss记录的就是该轮的总loss,running_corrects记录的就是该轮的总正确数目,
236                                                              # 因此epoch_loss记录的是该轮的平均loss,
237             epoch_acc = running_corrects.double() / dataset_sizes[phase] # epoch_acc记录的是该轮的正确率
238 
239             print('{} Loss: {:.4f} Acc: {:.4f}'.format( # 在每一个大循环(即每一轮训练)内,会执行两个小循环(第一个循环是训练过程,第二个循环是测试过程),每一个小循环都会输出一批数据,
240                 phase, epoch_loss, epoch_acc            # 第一个小循环输出的是训练过程的loss和准确率,第二个小循环输出的是测试过程的loss和准确率
241             ))
242 
243             if phase == 'val' and epoch_acc > best_acc: # 验证时,若遇到更好的模型则予以保留(epoch_acc是每轮的正确率)
244                 best_acc = epoch_acc
245                 best_model_wts = copy.deepcopy(model.state_dict())
246             # 每一个epoch的每一个小循环至此结束
247 
248         print() # 每一轮的验证和测试都结束之后,输出一个换行符,使每两轮训练的输出结果之间间隔一个空行
249         # 每一个epoch至此结束
250     
251     time_elapsed = time.time() - since # 所有轮的训练都完成后,记录本次训练的总时间,单位为秒(当前时间没有意义,有意义的是时间差)
252                                        # time.time()函数返回程序对应位置处当前时间的时间戳(1970纪元后经过的浮点秒数)
253     print('Training complete in {:.0f}m {:.0f}s'.format( # 在python3中print是一个函数,通过格式化函数format()来控制输出格式,此时冒号相当于原来的百分号
254                                                          # .0f表示小数点后保留0位(即只取整数位)
255         time_elapsed // 60, time_elapsed % 60            # 单斜杠/表示浮点数除法,只要/两边有一个数是浮点数,结果就是浮点数;
256     ))                                                   # 而双斜杠//表示整数除法,它返回一个不大于对应浮点数除法运算结果的最大整数
257     print('Best val Acc: {:4f}'.format(best_acc)) # 整个训练过程采用的是边训练边测试的策略,best_acc记录的是在测试集上性能最好的那轮训练的结果
258 
259     model.load_state_dict(best_model_wts) # 将最优的模型参数加载下来
260     return model # 调用train_model最终返回的是一个model
261 
262 
263 
264 """
265 加载一个经预训练的模型,并重新设定最后一个全连接层
266 """
267 model_ft = models.resnet18(pretrained=False) # 导入ResNet18网络,pretrained=True表示通过网络的方式在线加载模型参数,使用pretrained=False可以通过离线的方式加载模型参数
268 model_ft.load_state_dict(torch.load('resnet18.pth')) # 加载已经在ImageNet上训练好的模型参数
269                                             # torchvision.models模块中已经内置了一些网络结构,如VGG,ResNet,DenseNet等
270 num_ftrs = model_ft.fc.in_features          # 获取全连接层的输入channel个数,这里为512
271                                             # 因为预训练网络一般是在1000类的ImageNet数据集上进行的,所以如果要迁移到自己数据集的2分类,需要将最后的全连接层替换为所需要的输出,即
272 model_ft.fc = nn.Linear(num_ftrs, 2)        # 用前面获取的channel个数和要分类的类别数(此处是2)替换原来模型中的全连接层(即将最后一个全连接层由(512, 1000)改为(512, 2))
273 
274 model_ft = model_ft.to(device)              # 将模型放到GPU或CPU上
275 
276 criterion = nn.CrossEntropyLoss()           # 定义损失函数为交叉熵损失函数,
277                                             # torch.nn模块同时定义网络所有层的损失函数(如卷积层、池化层、损失层等)
278 
279 optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9) # 定义优化函数,通过torch.optim模块实现
280                                                                         # 此处用的是带动量项的SGD,即Adam优化方式
281                                                                         # 这个类的输入包括三项需要优化的参数:model.parameters()、学习率、和Adam相关的momentum参数
282 
283 exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1) # 定义学习率的变化策略
284                                                                              # 使用torch.optim.lr_sheduler模块的StepLR类,表示每隔step_size个epoch,就将学习率降为原来的gamma倍
285 
286 
287 
288 """
289 训练
290 """
291 model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25) # 训练完成后会return一个model
292 
293 
294 
295 """
296 定义函数,实现模型预测结果的可视化
297 """
298 def visualize_model(model, num_images=6):
299     was_training = model.training # 检查是否是训练模式
300                                   # model.training是一个bool型的数据,若model.training = True,则表示该model所有层的training都被设置为True,表示处在训练模式
301                                   # 反之,若model.training = False,则表示该model所有层的training都被设置为False,表示处在测试模式
302     model.eval() # 将模式设置为测试模式
303     images_so_far = 0 # images_so_far确定了每个子图放置的位置
304     fig = plt.figure() # 创建自定义图像
305 
306     with torch.no_grad(): # torch.no_grad()的作用是停止梯度计算,在此模式下,每一步的计算结果中requires_grad都是False,即使input设置为requires_grad=True
307         for i, (inputs, labels) in enumerate(dataloaders['val']): # 函数原型是enumerate(sequence, [start=0]),其功能是对一个可遍历的数据对象(如列表、元组或字符串),
308                                                                   # 将该数据对象组合为一个索引序列,同时列出数据下标和相应的数据
309                                                                   # 因此i中存的是索引,(inputs, labels)中存的是对应的数据
310             inputs = inputs.to(device)
311             labels = labels.to(device)
312 
313             outputs = model(inputs)
314             _, preds = torch.max(outputs, 1) # preds中存的是索引
315 
316             for j in range(inputs.size()[0]): # inputs是一个批次的数据,即4*3*224*224
317                                               # inputs.size() = torch.Size([4, 3, 224, 224])
318                                               # inputs.size()[0] = 4
319                 images_so_far += 1 # 第1个子图放在第一个位置上
320                 ax = plt.subplot(num_images//2, 2, images_so_far) # 3行2列
321                 ax.axis('off')
322                 ax.set_title('predicted: {}'.format(class_names[preds[j]])) # class_names = ['ants', 'bees'],将标题设置为图像对应的名称
323                 imshow(inputs.cpu().data[j]) # imshow()函数接收一个tensor
324 
325                 if images_so_far == num_images:
326                     model.train(mode=was_training)
327                     return
328         model.train(mode=was_training) # 在测试之后将模型恢复成之前的形式
329 
330 
331 
332 
333 """
334 可视化结果
335 """
336 visualize_model(model_ft)

注:本博客只是为源代码提供了大量的注释,对源码基本没有改动。有些注释可能未必准确,后续还需做进一步的改动。

代码中所需要的数据源及源码地址均参见https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

猜你喜欢

转载自www.cnblogs.com/tbgatgb/p/10735258.html