PyTorch经验分享:新手如何搭建PyTorch程序

为什么是PyTorch

在2017年的10月份,笔者还在研二的时候,曾发布过一篇有关TensorFlow编程经验的博客。当时,笔者刚接触深度学习框架TensorFlow不足半年,相较于笔者最开始接触的深度学习框架Caffe,TensorFlow更易于上手,更方便随心所欲地实现各种网络结构,确实是一个非常优良的适合算法研究的深度学习框架。但是,在笔者对TensorFlow一年多的使用,写过一些程序后,对TensorFlow的部分局限性感受深刻,比如:

  1. TensorFlow的数据流图是静态的,因此在TensorFlow框架编程的过程中,是无法直接对Tensor进行条件判断,只能通过会话层tf.Session()将Tensor的值run出来再进行判断。因此,在TensorFlow中,进行条件分支判断就非常难,比如采用tf.cond或者tf.py_func
  2. 上述的静态图机制,也使得TensorFlow程序非常难调试。
  3. 笔者不得不吐槽TensorFlow的接口,在不断更新换代升级的过程中变化太快也太大(比如TensorFlow 2.0与TensorFlow 1.0)。这就造成了许多老版本程序,无法在新版本的TensorFlow框架上运行。

当然,纵使TensorFlow具有上述缺点,依然是非常优秀的深度学习框架,非常多的开源代码仍旧选择TensorFlow作为实现平台。

可是,该如何克服TensorFlow的上述缺点呢?答案是选择PyTorch。时间来到2018-2019年,深度学习框架PyTorch的热度越来越高,因为PyTorch具有非常多的优点。首先,TensorFlow的优点PyTorch也同样拥有,比如:

  1. 使用Python接口,使得算法实现轻松容易。
  2. 在构造普通模块的过程中,用户无需关心训练过程中的梯度反传,这为算法研究与实现带来了极大的方便。
  3. 接口简单,模块集成度高,支持用户较快地实现算法思想。

除此之外,PyTorch还具有许多其他的优点:

  1. 采用动态图机制。在PyTorch中,能够非常方便地将GPU中的Tensor与Python中的Numpy Array进行相互转换,读出Tensor的值。这就使得用户能非常方便地实现分支程序,而且极大地增加了调试的方便性。
  2. PyTorch官方文档比较全面且详细,不同版本的程序接口变化不大。为用户提供了极大的便利。

基于以上的优点,笔者也对PyTorch进行了了解并上手。在本篇博客中,笔者将与大家分享搭建PyTorch程序的经验,下面开始干货。

如何搭建PyTorch程序

在计算机视觉相关的深度学习任务中,简单的PyTorch程序主要包含以下这4个模块。

-网络模型定义文件,比如名称为network.py。
-数据读取文件,比如名称为dataset.py。
-训练接口程序,比如名称为train.py。
-测试接口程序,比如名称为evaluate.py。

比如,随便打开一个笔者的PyTorch程序,如下图所示:
在这里插入图片描述
在上图中,就包含之前提到的network,dataset,train和evaluate。当然也包含一些笔者自己定制的模块,比如cal_metrics用来计算实验指标,cfg用来设置某些超参数,image_io_and_process用来对图像进行预处理和后处理,loss_functions用来定义与实现某些损失,utils用来承载一些其他操作,比如读写txt文档。

网络模型的定义

在PyTorch中,定义一个网络模型,需要定义一个Python类,并继承torch.nn.Module这个类。在这个类中,通常是通过重写构造函数__init__来定义网络中使用的层,在定义层或者模块的时候,经常会使用到torch.nn这个库。然后再重写父类的forward函数,在其中使用__init__函数中定义的层实现网络的前传过程。比如下面的例子:

class My_Network(nn.Module):
    def __init__(self):
        super(network, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) #3×3卷积
        self.bn1 = nn.BatchNorm2d(64) #BatchNorm
        self.relu = nn.ReLU() #激活
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) #最大池化
        self.fc = nn.Linear(64 * 14 * 14, 512) #全连接,将[n, (64*14*14)]变成[n, 512]

        for m in self.modules(): #初始化
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x): #前传,输入x尺寸为[n, c, 28, 28]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = x.view(x.size(0), -1) #将四维Tensor变成两维,作为全连接的输入
        x = self.fc(x)

        return x

在构造函数__init__里面,定义好了网络所使用的各个模块。这个网络对输入完依次进行了卷积,批归一化,激活,最大池化,全连接的功能。并且用一个for循环对可训练参数进行初始化。初始化的过程一般都是在网络定义的构造函数中完成。在初始化的过程中可见,在构造函数中定义的所有模块,都被存放在self.modules()这个列表中
然后,我们重写函数forward,就能实现对输入的x进行前传了。这样,就定义好了一个完整的简单的网络。

数据读取接口

在PyTorch中,与其他框架具有鲜明差异的一点,就是数据读取接口非常规范。在PyTorch中,通过继承torch.utils.data.Dataset这个类实现自己的数据集读取。在继承时,主要需要重写三个函数:__init__函数,__getitem__函数和__len__函数。其中,__init__函数主要进行一些初始化,比如读取一下所需的记录数据的txt或者xml文件,传入一些预处理参数等等。而__getitem__函数主要规定了,在每个batch训练给网络喂数据的时候,应该采用怎样的方式读取数据,以及做怎样的预处理。至于__len__函数的主要功能,就是告诉PyTorch,数据集中有多少数据。下面,就是一个数据读取接口的简单实现:

class My_Dataset(torch.utils.data.Dataset):

    def __init__(self, data_txt_path):

        data_list = read_txt(data_txt_path) #读取一下记录数据与标签的txt
        self.data_list = np.random.permutation(data_list) #打乱一下txt

        self.transform = T.ToTensor() #预处理,将数据转化成[n, c, h, w]的形式,并归一化到[0,1]

    def __getitem__(self, index):
        sample = self.data_list[index] #读取txt中的一行,是一个数据对(image_path, label)
        image_path = sample.split(' ')[0]
        label = int(sample.split(' ')[-1])
        resized_image = read_image(image_path) #读取一下图像
        image = self.transform(resized_image) #转化成Tensor
        return image.float(), label 

    def __len__(self):
        return len(self.data_list)

如上所示,在定义数据接口时,通过重写__init__,__getitem__和__len__三个函数,实现读取数据的功能,该数据接口可用于简单的图像分类。然后在训练程序中,只需要使用torch.utils.data.DataLoader开始在每一次训练中对数据集进行读取与遍历就行

train_data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None)

其中,除去dataset,用的比较多的是batch_size,shuffle,num_workers和pin_memory这几个参数。batch_size指定了一个批次的数据容量,shuffle指定了是否打乱数据,num_workers指定了用多少个线程对数据进行读取,而pin_memory则指定了是否将数据放入显存。

训练程序

在搭建好网络,完善好数据读取接口后,我们就可以进行训练程序的搭建了。在进行训练程序的搭建时,与其他深度学习框架类似,都是先搭建图(模型),然后在图上进行训练。比如,一个简单的训练接口示意就如下所示:

from __future__ import print_function
import torch.nn
from torch.nn import DataParallel
import torch.optim
import torch.utils.data
import argparse
#import 所需的各种模块

parser = argparse.ArgumentParser(description="")
#添加各种自定义参数
args = parser.parse_args()


def main():

    device = torch.device("cuda")
    train_dataset = My_Dataset(txt_path) 
    train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                    batch_size,
                                                    shuffle,
                                                    num_workers)
    model = My_Network()
    #若有预训练参数,可以载入预训练参数
    model.load_state_dict(torch.load(ckpt_path), strict=False)
    
    #定义损失,比如criterion = torch.nn.CrossEntropyLoss()
    criterion = My_Loss
    
    model.to(device)
    #可以将模型训练放在多张GPU上并行
    model = DataParallel(model)
    
    optimizer = torch.optim.SGD(model.parameters(), lr, weight_decay)
    
    model.train() #将模型置为训练模式
    
    for i in range(epoch):
        for ii, data in enumerate(train_data_loader):
            data_input, label = data
            data_input = data_input.to(device)
            label = label.to(device).long()
            output = model(data_input)
            loss = criterion(output, label)
            
            print("loss =: ", loss)
            
            optimizer.zero_grad() #首先梯度置零
            loss.backward() #然后求梯度
            optimizer.step() #通过梯度更新参数

            iters = i * len(train_data_loader) + ii
            
            if iters % save_interval == 0:
                torch.save(model.state_dict(), save_path)

if __name__ == "__main__":
    main()

如上所示,在进行训练时,首先进行四个步骤,即:

  1. 引用之前定义的数据读取接口
  2. 引用model,即网络,需要的话就载入预训练参数。
  3. 定义loss
  4. 定义优化器
    然后就可以开始在for循环中进行可训练参数更新了。

最后,在PyTorch训练程序中,还需要注意三点。
第一点是保存模型和加载预训练参数
在保存模型时,通常是保存模型的已训练参数(也可以保存整个模型),通过torch.save函数完成;在加载预训练参数时,通过model的父类,即torch.nn.Module的load_state_dict完成,里面的“strict”参数表示是否需要将预训练模型里面的参数与model里面的完全对齐。

第二点是使用GPU训练
如果需要使用GPU训练时,需要使用到to(device)函数,意思就是将数据或者模型放到GPU上面。

第三点是参数更新过程
在进行参数更新时,一共分成三步:

  1. 将可训练参数梯度置零 optimizer.zero_grad()
  2. 根据损失值求梯度 loss.backward()
  3. 更新可训练参数 optimizer.step()

第四点是将model置为训练模式
对应代码中的

model.train()

训练模式会对某些定义的网络模块有影响,比如使用dropout层,在训练时会被激活。

测试程序

在模型训练完毕之后,需要对模型进行测试。在测试时,与其他深度学习框架类似,还是通过先对模型进行搭建,载入已训练参数,将测试数据前传得到结果。比如,一个简单的测试接口示意就如下所示:

from __future__ import print_function
import torch.nn
from torch.nn import DataParallel
import torch.optim
import torch.utils.data
import argparse
#import 所需的各种模块

parser = argparse.ArgumentParser(description="")
#添加各种自定义参数
args = parser.parse_args()

def main():
    device = torch.device("cuda")

    model = My_Network()
    model = DataParallel(model)

    model.load_state_dict(torch.load(snapshot_path)) #载入参数
    model.to(device)

    model.eval() #将模型置为测试模式

    for eval_data_path in eval_data_list:
        eval_data = read_image(eval_data_path)
        data = torch.from_numpy(eval_data) #将data转化为Tensor
        data = data.to(device)
        output = model(data) #前传得到结果

        #自定义操作,比如计算精度

if __name__ == "__main__":
    main()

可以看到,测试程序比训练程序简单很多,主要先进行两个步骤:

  1. 定义网络
  2. 载入已训练参数

然后就可以读取测试样本进行前传了。需要注意的是,在进行模型的测试时,需要将model置为测试模式,对应代码中的

model.eval()

经过以上的示例,可见PyTorch程序比较规范,简洁与清晰。并且在PyTorch中还使用到了非常多的面向对象的规范,有许多操作都是通过继承Python类进行实现的

写在后面

到这里,本篇博客就接近尾声了。在本篇博客中,笔者只是简单地展示了PyTorch程序搭建的架构,这也仅仅是笔者的个人习惯,旨在展示PyTorch程序的执行过程。各位读者朋友可以根据自己的喜好进行自定制的PyTorch代码搭建。

在学习PyTorch的过程中,推荐大家多阅读pyTorch官方文档。也可以多阅读github上的优秀开源项目,比如HRNetmmdetection

欢迎阅读笔者后续博客,各位读者朋友的支持与鼓励是我最大的动力!

written by jiong
不忘初心,牢记使命!

猜你喜欢

转载自blog.csdn.net/jiongnima/article/details/103324839