PyTorch model construction

This post records several ways to construct a model in PyTorch:

Inheriting the `Module` class to build a model

`Module` is the base class for all neural network modules. We usually build a model by subclassing `Module` and overriding its `__init__` and `forward` functions.

Example

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
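As a quick sanity check, the model above can be instantiated and applied to a dummy input (the 28x28 input size below is illustrative, not required by the model; the class is repeated so the snippet runs on its own):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# model from the example above
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

net = Model()
x = torch.zeros(1, 1, 28, 28)  # dummy batch: 1 image, 1 channel, 28x28
print(net(x).shape)            # each 5x5 conv shrinks each spatial dim by 4
# torch.Size([1, 20, 20, 20])
```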

Using `Module` subclasses

PyTorch provides several classes that inherit from `Module` and make model construction easier, such as `Sequential`, `ModuleList`, and `ModuleDict`.

  • Using `Sequential`

    When the forward computation of the model is just a simple chain of layers, `Sequential` offers a simpler way to define the model. This is the purpose of the `Sequential` class: it can receive an ordered dictionary (`OrderedDict`) of submodules, or a series of submodules as positional arguments, to add to its `_modules` attribute, and the model's forward computation applies these modules one by one in the order they were added.

    Here is a `MySequential` class that implements functionality similar to `Sequential`:

    from collections import OrderedDict
    import torch.nn as nn

    class MySequential(nn.Module):
        def __init__(self, *args):
            super(MySequential, self).__init__()
            if len(args) == 1 and isinstance(args[0], OrderedDict):  # an OrderedDict was passed in
                for key, module in args[0].items():
                    self.add_module(key, module)  # add_module inserts the module into self._modules (an OrderedDict)
            else:  # individual Modules were passed in
                for idx, module in enumerate(args):
                    self.add_module(str(idx), module)

        def forward(self, input):
            # self._modules is an OrderedDict, so members are visited in the order they were added
            for module in self._modules.values():
                input = module(input)
            return input
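Both construction forms of the built-in `Sequential` can be exercised directly (the layer sizes below are illustrative):

```python
import torch
import torch.nn as nn
from collections import OrderedDict

# Positional-argument form: modules are applied in the given order.
net = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

# OrderedDict form: additionally gives each submodule a name.
named_net = nn.Sequential(OrderedDict([
    ('hidden', nn.Linear(784, 256)),
    ('act', nn.ReLU()),
    ('output', nn.Linear(256, 10)),
]))

x = torch.zeros(2, 784)
print(net(x).shape)   # torch.Size([2, 10])
print(named_net.act)  # the named submodule, a ReLU
```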
    
  • Using `ModuleList`

    `ModuleList` holds submodules in a list ( `list`).
    `ModuleList` supports the usual Python list operations such as `append()` and `extend()`; the difference is that the parameters of all modules stored in a `ModuleList` are automatically registered with the whole network.

    Example

    net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
    net.append(nn.Linear(256, 10))  # list-like append operation
    print(net[-1])  # list-like index access
    print(net)

    Although both `Sequential` and `ModuleList` can list the modules of a network, there are two differences. A `ModuleList` is just a list that stores modules: there is no connection or order between them (so matching input and output dimensions of adjacent layers is not guaranteed), and it does not implement a `forward` function (you need to implement it yourself). The modules inside a `Sequential`, by contrast, must be arranged in order, with the input and output sizes of adjacent layers matching, and the `forward` function is already implemented internally.

    `ModuleList` makes it more flexible to define the forward pass of a network:

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

        def forward(self, x):
            # ModuleList can act as an iterable, or be indexed using ints
            for i, l in enumerate(self.linears):
                x = self.linears[i // 2](x) + l(x)
            return x
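The automatic parameter registration mentioned above can be checked with a small sketch (the class names here are illustrative): a module stored in a plain Python list is not registered, so its parameters are invisible to `parameters()`:

```python
import torch.nn as nn

class WithModuleList(nn.Module):
    def __init__(self):
        super(WithModuleList, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10)])  # registered

class WithPlainList(nn.Module):
    def __init__(self):
        super(WithPlainList, self).__init__()
        self.linears = [nn.Linear(10, 10)]  # plain Python list: NOT registered

print(len(list(WithModuleList().parameters())))  # 2 (weight and bias)
print(len(list(WithPlainList().parameters())))   # 0
```

In the second case the optimizer would never see the layer's weights, which is why `ModuleList` (or `ModuleDict`) should be used instead of a plain list for submodules.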
  • Using `ModuleDict`

    `ModuleDict` receives a dictionary of submodules as input; submodules can then be added and accessed much like entries of an ordinary dictionary:

    net = nn.ModuleDict({
        'linear': nn.Linear(784, 256),
        'act': nn.ReLU(),
    })
    net['output'] = nn.Linear(256, 10)  # add a submodule
    print(net['linear'])  # access a submodule
    print(net.output)
    print(net)
    # net(torch.zeros(1, 784))  # raises NotImplementedError: forward is not defined

    Like `ModuleList`, `ModuleDict` requires you to define the `forward` function yourself when using it.
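A minimal sketch of how this might look in practice (the `MLP` name and layer sizes are illustrative): wrap the `ModuleDict` in a `Module` subclass and write `forward` by hand:

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.layers = nn.ModuleDict({
            'linear': nn.Linear(784, 256),
            'act': nn.ReLU(),
        })
        self.layers['output'] = nn.Linear(256, 10)  # dict-like addition

    def forward(self, x):
        # apply the submodules explicitly, in the order we choose
        x = self.layers['act'](self.layers['linear'](x))
        return self.layers['output'](x)

net = MLP()
print(net(torch.zeros(1, 784)).shape)  # torch.Size([1, 10])
```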

References:

  1. Dive into Deep Learning (PyTorch edition)
  2. PyTorch official documentation

Origin www.cnblogs.com/patrolli/p/11896776.html