pytorch torch.nn函数学习记录

1.torch.nn.Embedding

官网介绍

torch.nn.Embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, max_norm: Optional[float] = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, sparse: bool = False, _weight: Optional[torch.Tensor] = None)
输入:(*)任意形状的LongTensor包含要提取的索引
输出:( *,H)*是输入大小,H是embedding_dim

参数说明:

num_embeddings (int) – size of the dictionary of embeddings
嵌入s的字典的大小
embedding_dim (int) – the size of each embedding vector
每个嵌入向量的大小
padding_idx (int, optional) – If given, pads the output with the embedding vector at padding_idx (initialized to zeros) whenever it encounters the index.
如果给定,则在 遇到索引时,将输出嵌入矢量padding_idx(初始化为零)
max_norm (float, optional) – If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm.
如果给定,则将范数大于的每个嵌入向量max_norm 重新归一化为norm max_norm
norm_type (float, optional) – The p of the p-norm to compute for the max_norm option. Default 2.
为该max_norm选项计算的p范数的p 。默认2
scale_grad_by_freq (boolean, optional) – If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default False.
如果给定,将按小批量中单词频率的倒数来缩放梯度。默认False。
sparse (bool, optional) – If True, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for more details regarding sparse gradients.
如果为True,则梯度wrtweight矩阵将为稀疏张量。有关稀疏渐变的更多详细信息,请参见注释。

embed = torch.nn.Embedding(6,5,padding_idx = 3) # 6个词,每个词5维,索引为3的置为0
x = torch.LongTensor([[1,2,3],[1,3,5]])
print(embed(x))
tensor([[[-1.3751, -3.0215, -1.3973, -0.3610, 1.6760],
[ 0.1496, 0.3810, -1.4765, 0.7070, 0.0221],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], 索引为3
[[-1.3751, -3.0215, -1.3973, -0.3610, 1.6760],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], 索引为3
[ 1.2162, -0.5878, 0.2110, -0.3564, -1.6092]]],
grad_fn=<EmbeddingBackward>)
print(embed(x).size())
torch.Size([2, 3, 5])

实际任务中:
总结一下一般任务中的流程:

首先将单词转成字典的形式,英语中通常用空格分隔,所以可以直接建立词典索引结构。类似于:
dic={
‘i’:1,
‘am’:2,
‘a’:3,
‘student’:4,
‘like’:5,
‘apple’:6
}
这样的形式。如果是中文的话,首先进行分词操作。

然后在以句子为list,为每个句子建立索引结构,list[[sentence1],[sentence2]]。以上面字典的索引来说,最终建立的就是[[1,2,3,4],[1,5,6]]。这样长短不一的句子。

接下来要进行padding的操作。由于tensor结构中都是等长的,所以要对上面那样的句子做padding操作后再利用nn.Embedding来进行词的初始化。padding后的可能是这样的结构[[1,2,3,4],[1,5,6,0]]。其中0作为填充。(注意:由于在NMT任务中肯定存在着填充问题,所以在embedding时一定存在着第三个参数,让某些索引为下的值为0,代表无实际意义的填充)

2. torch.nn.Parameters

torch官网介绍

torch.nn.Parameter
A kind of Tensor that is to be considered a module parameter.
Parameters are Tensor subclasses, that have a very special property when used with Module s - when they’re assigned as Module attributes they are automatically added to the list of its parameters, and will appear e.g. in parameters() iterator. Assigning a Tensor doesn’t have such effect. This is because one might want to cache some temporary state, like last hidden state of the RNN, in the model. If there was no such class as Parameter, these temporaries would get registered too.

Parameters
data(Tensor)-parameter tensor
requires_grad

简单理解就是,torch.nn.Parameter继承自torch.Tensor的子类,其主要作用是作为nn.Module中的可训练参数使用。它与torch.Tensor的区别就是nn.Parameter会自动被认为是module的可训练参数,即加入到parameter()这个迭代器中去。
注意到,nn.Parameter的对象的requires_grad属性的默认值是True,即是可被训练的,这与torth.Tensor对象的默认值相反。
在nn.Module类中,pytorch也是使用nn.Parameter来对每一个module的参数进行初始化的。

更多资料:
https://www.jianshu.com/p/d8b77cc02410
https://blog.csdn.net/qq_28753373/article/details/104179354

3.torch.nn.Identity()

torch官网介绍

torch.nn.Identity
A placeholder identity operator that is argument-insensitive.

Parameters

  • args - any argument(unused)
  • kwargs - any keyword argument (unused)

这个函数建立一个输入模块,什么都不做,通常用在神经网络的输入层。
多个输入可以在神经网络搭建中起到很好的作用,相当于一个容器,把输入都保留下来了。
比如LSTM,因为LSTM是循环网络,需要保存上一次的信息,nn.Identity()能够很好的保留信息。

举例

>>> input = torch.randn(128,20)
>>> m = nn.Identity(54,unused_argument=0.1, unused_argument2=False)
>>> output = m(input)
>>> output == input
tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])

猜你喜欢

转载自blog.csdn.net/eight_Jessen/article/details/112432032