torch.functional 和torch.nn.functional,torch.nn和torch.nn.functional的区别

torch.functional 和torch.nn.functional,torch.nn和torch.nn.functional的区别

torch.functional 和torch.nn.functional

torch.nn中的函数在torch.nn.funtional中都有一个与之对应的函数。

二者的区别在于:

torch. nn 中实现的都是一个个的类,是用class xx()定义的,而 nn.functional中的函数,就是是纯函数,由def xx( )定义。

这样弄有什么用呢?

一般用def xx()定义的函数是就是一个运算公式,也可以看作一个简单的工具,可以根据你的需要输入数据产生你需要的输出。
但是在深度学习中,你构建的网络会有很多权重,在训练网络的时候这些权重是在不断更新的,不可能每训练一次你就用一大堆的参数来保存这些权重信息然后等到下次调用函数就再传进去再更改,这样就会非常麻烦,所以就会采用类的方式,以确保能在参数发生变化时仍能使用我们之前定好的运算步骤。
像ReLu层,池化层,这类只是运用它的函数功能,并不要保存在权重信息中,但是像卷积层、全连接层这些就含有权重信息。如果所有的层我们都用torch.nn.functional来定义,那么我们需要将卷积层和全连接层中的weights、bias全部手动来写,这样是非常不方便的。

例如:

torch.nn.Conv2d()和torch.nn.functional.conv2d()
我们看到它们的源码:

torch.nn.Conv2d
class Conv2d(_ConvNd):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros'):
                 
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)

        def forward(self, input):
            return self.conv2d_forward(input, self.weight)    
torch.nn.functional.conv2d()
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
                groups=1):
                
    if input is not None and input.dim() != 4:
        raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(input.dim()))
        
    f = _ConvNd(_pair(stride), _pair(padding), _pair(dilation), False,
                        _pair(0), groups, torch.backends.cudnn.benchmark,
            torch.backends.cudnn.deterministic,torch.backends.cudnn.enabled)  
    return f(input, weight, bias)

对比上面的代码,torch.nn.Conv2d是一个类,而torch.nn.functiona.conv2d()是一个函数,并且torch.nn.Conv2d中的forward()函数是由torch.nn.functiona.conv2d()实现的(在Module类中有一个__call__的方法实现了forward的调用)

猜你喜欢

转载自blog.csdn.net/weixin_46088099/article/details/125927918