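Copyright notice: Wang Jialin's 2018 book 《SPARK大数据商业实战三部曲》 (SPARK Big Data Commercial Practice Trilogy), published by Tsinghua University Press; Tsinghua University Press official flagship store (Tmall): https://qhdx.tmall.com/?spm=a220o.1000855.1997427721.d4918089.4b2a2e5dT6bUsM https://blog.csdn.net/duan_zhihua/article/details/82596292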
PyTorch Learning (6): Parameter Initialization in torch.nn Convolution Layers
class Conv1d(_ConvNd):
    ......
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        # Normalize each int-or-tuple argument to a 1-tuple, e.g. 3 -> (3,)
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(dilation)
        # The two positional arguments after dilation are
        # transposed=False and output_padding=(0,)
        super(Conv1d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _single(0), groups, bias)

    def forward(self, input):
        # The stored tuples are passed straight to the functional conv1d
        return F.conv1d(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
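Normalizing the arguments to tuples lets the rest of the code read them in one uniform way, whether the caller passed `3` or `(3,)`. As a rough illustration, the sketch below reuses the `_ntuple` helper discussed next (written with `collections.abc.Iterable` so it runs on Python 3.10+) and plugs the normalized values into the output-length formula from the `nn.Conv1d` documentation, L_out = floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1). The name `conv1d_out_len` is hypothetical and not part of PyTorch; it computes only the length and does not run a convolution.

```python
import collections.abc
from itertools import repeat


def _ntuple(n):
    # Same normalization idea as the helper in this article: iterables pass
    # through unchanged, a scalar is repeated n times.
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


_single = _ntuple(1)


def conv1d_out_len(l_in, kernel_size, stride=1, padding=0, dilation=1):
    """Hypothetical helper: output length of a 1-D convolution, following
    the L_out formula in the nn.Conv1d documentation."""
    # Normalize like Conv1d.__init__, then unpack each 1-tuple.
    (k,) = _single(kernel_size)
    (s,) = _single(stride)
    (p,) = _single(padding)
    (d,) = _single(dilation)
    # Integer floor division implements the floor(...) in the formula
    # for non-negative numerators.
    return (l_in + 2 * p - d * (k - 1) - 1) // s + 1


print(conv1d_out_len(50, 3))               # 48
print(conv1d_out_len(50, (3,), stride=2))  # 24
print(conv1d_out_len(50, 3, padding=1))    # 50
print(conv1d_out_len(50, 3, dilation=2))   # 46
```

Passing `3` or `(3,)` yields the same result, which is exactly what the `_single` normalization is for.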
To normalize these arguments, the initializer calls helpers produced by the _ntuple function:
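import collections.abc
from itertools import repeat

def _ntuple(n):
    def parse(x):
        # collections.abc.Iterable: the old alias collections.Iterable
        # was removed in Python 3.10.
        # Tuples, lists, etc. are returned unchanged
        if isinstance(x, collections.abc.Iterable):
            return x
        # A scalar is repeated n times, e.g. 3 -> (3, 3) for n=2
        return tuple(repeat(x, n))
    return parse

_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)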
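_ntuple is a higher-order function in the functional-programming sense: it returns another function. _single = _ntuple(1) calls _ntuple with n=1 and gets back the inner parse function, which remembers n=1 through its closure. Calling _single(kernel_size) then runs parse(kernel_size): if kernel_size is already iterable (for example a tuple) it is returned as-is; otherwise tuple(repeat(kernel_size, 1)) builds a 1-tuple such as (5,).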
A quick test:
import collections.abc
from itertools import repeat

def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse

_single = _ntuple(1)
print(_single(0))                   # scalar -> 1-tuple
_pair = _ntuple(2)
print(_pair(0))                     # scalar -> 2-tuple
# kernel_size = 5                   # an int here would give (5, 5)
kernel_size = (3, 5)
kernel_size = _pair(kernel_size)    # a tuple is returned unchanged
print(kernel_size)
_triple = _ntuple(3)
kernel_size = (3, 5, 2)
kernel_size = _triple(kernel_size)
print(kernel_size)
The output is:
(0,)
(0, 0)
(3, 5)
(3, 5, 2)