The difference between torch.functional and torch.nn.functional, torch.nn and torch.nn.functional

torch.functional and torch.nn.functional

Most of the layers in torch.nn have a corresponding function in torch.nn.functional.

The difference between the two is:

What torch.nn implements are classes, defined with class Xx(), while the operations in nn.functional are pure functions, defined with def xx().
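A minimal sketch of this distinction, using ReLU (which exists in both forms): the class must be instantiated before use, while the function is called directly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.tensor([-1.0, 0.0, 2.0])

# Class style: instantiate the layer object first, then call it.
relu_layer = nn.ReLU()
y1 = relu_layer(x)

# Function style: call the pure function directly, no object needed.
y2 = F.relu(x)

print(torch.equal(y1, y2))  # True
```

Both produce the same result here, because ReLU has no parameters of its own.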

What's the use of doing this?

Generally, a function defined with def xx() is a plain computation: it takes input, applies a fixed formula, and returns output, like a simple tool.

In deep learning, however, the network you build carries many weights, and these weights are updated continuously during training. It would be very troublesome to save all of that weight information yourself and pass it back in every time the function is called. A class solves this: it stores the parameters as internal state, so the same computation steps can be reused as the parameters change.

Layers such as ReLU and pooling only apply a function and hold no weight information, but convolutional layers and fully connected layers do contain weights. If we defined every layer with torch.nn.functional, we would have to create and manage the weights and biases of each convolutional and fully connected layer by hand, which is very inconvenient.

For example:

Compare torch.nn.Conv2d() with torch.nn.functional.conv2d().
Let's look at their (simplified) source code:

torch.nn.Conv2d
class Conv2d(_ConvNd):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros'):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)

    def forward(self, input):
        return self.conv2d_forward(input, self.weight)
torch.nn.functional.conv2d()
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
           groups=1):
    if input is not None and input.dim() != 4:
        raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(input.dim()))
    f = _ConvNd(_pair(stride), _pair(padding), _pair(dilation), False,
                _pair(0), groups, torch.backends.cudnn.benchmark,
                torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
    return f(input, weight, bias)

Comparing the code above: torch.nn.Conv2d is a class, while torch.nn.functional.conv2d() is a function, and the forward() method of torch.nn.Conv2d is implemented with torch.nn.functional.conv2d(). (The Module base class defines a __call__ method that invokes forward, which is why a module instance can be called like a function.)
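This relationship can be checked directly: calling a Conv2d module gives the same result as calling F.conv2d with the module's own stored weight and bias.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv2d(2, 4, kernel_size=3, padding=1)
x = torch.randn(1, 2, 8, 8)

# Calling the module goes through Module.__call__, which dispatches to forward().
y_module = conv(x)

# forward() is implemented with the functional op, using the stored parameters.
y_functional = F.conv2d(x, conv.weight, conv.bias, stride=1, padding=1)

print(torch.allclose(y_module, y_functional))  # True
```

So the class is essentially a thin stateful wrapper: it owns the parameters and delegates the actual computation to the functional API.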

Reprinted from: blog.csdn.net/weixin_46088099/article/details/125927918