Pytorch中RNN LSTM的input(重点理解batch_size/time_steps)

原文链接:Pytorch中如何理解RNN LSTM的input(重点理解seq_len/time_steps) - 阿矛布朗斯洛特的文章 - 知乎

在建立时序模型时,若使用keras,我们在Input的时候就会在shape内设置好sequence_length(后面均用seq_len表示),接着便可以在自定义的data_generator内进行个性化的使用。这个值同时也就是time_steps,它代表了RNN内部的cell的数量,有点懵的朋友可以再去看看RNN的相关内容:

CSDN-专业IT技术社区-登录​blog.csdn.net

所以设定好这个值是很重要的事情,它和batch_size,feature_dimensions(在词向量的时候就是embedding_size了)构成了我们Input的三大维度,无论是keras/tensorflow,亦或是Pytorch,本质上都是这样。

牵涉到这个问题是听说Pytorch自由度更高,最近在做实验的时候开始尝试用Pytorch了,写完代码跑通后,过了段时间才意识到,好像没有用到seq_len这个参数,果然是Keras用多了的后遗症?(果然是博主比较蠢!)检查了一下才发现,DataLoader生成数据的时候,默认生成为(batch_size, 1, feature_dims)。(这里无视了batch_size和seq_len的顺序,在建立模型的时候,比如nn.LSTM有个batch_first的参数,它决定了谁前谁后,但这不是我们这里讨论的重点)。

所以我们的seq_len/time_steps被默认成了1,这是在使用Pytorch的时候容易发生的问题,由于Keras先天的接口设置在Input时就让我们无脑设置seq_len,这反而不会成为我们在使用Keras时发生的问题,而Pytorch没有让我们在哪里设置这个参数,所以一不小心可能就忽视了。

好了,接下来就来找找问题怎么出现的,又怎么解决。果然问题还是出现在了DataLoader,在__getitem__(self, index)这里,决定了我们如何取出数据,在这里我发现我自己还是一条一条取的。

    def __getitem__(self, idx):
        return self.input[idx], self.target[idx]

完全没有意识到Torch需要在这里进行seq_len的修饰,接下来该怎么解决呢,首先看看我们希望的“取数据方式”。

假如我们有id = 1,2,3,4,5,6,7,8,9,10一共10个sample。

假设我们设定seq_len是3。

那现在数据的形式应该为1-2-3,2-3-4,3-4-5,4-5-6,5-6-7,6-7-8,7-8-9,8-9-10,9-10-0,10-0-0(最后两个数据不完整,进行补零)的10个数据。这是我们真正有了seq_len这个参数,带有“循环”这个概念,要放进RNN等序列模型中进行处理的数据。所以之前说seq_len被我默认弄成了1,那就是把1,2,3,4,5,6,7,8,9,10这样形式的10个数据分别放进了模型训练,自然在DataLoader里取数据的size就成了(batch_size, 1, feature_dims),而我们现在取数据才会是(batch_size, 3, feature_dims)。

假设我们设定batch_size为2。

那我们取出第一个batch为1-2-3,2-3-4。这个batch的size就是(2,3,feature_dims)了。我们把这个玩意儿喂进模型。

接下来第二个batch为3-4-5,4-5-6。

第三个batch为5-6-7,6-7-8。

第四个batch为7-8-9,8-9-10。

第五个batch为9-10-0,10-0-0。我们的数据一共生成了5个batch。

可以看到,num_batch = num_samples / batch_size(这里没有进行向上或向下取整是因为在某些地方可以设置是否需要那些不完整的被进行补零的batch),seq_len仍然不会影响最后生成的batch的数量,只有batch_size和num_samples会对batch的数量进行影响。

可能忽略了feature_dims仅凭借id来代表数据难以理解,那换种方式看看,假如feature_dims为6:

data_ = [[1, 10, 11, 15, 9, 100],
         [2, 11, 12, 16, 9, 100],
         [3, 12, 13, 17, 9, 100],
         [4, 13, 14, 18, 9, 100],
         [5, 14, 15, 19, 9, 100],
         [6, 15, 16, 10, 9, 100],
         [7, 15, 16, 10, 9, 100],
         [8, 15, 16, 10, 9, 100],
         [9, 15, 16, 10, 9, 100],
         [10, 15, 16, 10, 9, 100]]

仍然设置seq_len为3,batch_size为2。

这时我们的第一个batch为

tensor([[[  1.,  10.,  11.,  15.,   9., 100.],
         [  2.,  11.,  12.,  16.,   9., 100.],
         [  3.,  12.,  13.,  17.,   9., 100.]],

        [[  2.,  11.,  12.,  16.,   9., 100.],
         [  3.,  12.,  13.,  17.,   9., 100.],
         [  4.,  13.,  14.,  18.,   9., 100.]]])

这就是刚刚的1-2-3,2-3-4嘛。

而最后一个batch为

tensor([[[  9.,  15.,  16.,  10.,   9., 100.],
         [ 10.,  15.,  16.,  10.,   9., 100.],
         [  0.,   0.,   0.,   0.,   0.,   0.]],

        [[ 10.,  15.,  16.,  10.,   9., 100.],
         [  0.,   0.,   0.,   0.,   0.,   0.],
         [  0.,   0.,   0.,   0.,   0.,   0.]]])

最后放上Demo,由于每个人的数据甚至loss等等都不一样,不过大家应该能够从Demo中得到一些如何针对自己的Project进行修改的点子。

# -*- coding: utf-8 -*-

import torch
import torch.utils.data as Data
import torch.nn as nn
import torchvision.transforms as transforms
import numpy as np
###   Demo dataset

data_ = [[1, 10, 11, 15, 9, 100],
         [2, 11, 12, 16, 9, 100],
         [3, 12, 13, 17, 9, 100],
         [4, 13, 14, 18, 9, 100],
         [5, 14, 15, 19, 9, 100],
         [6, 15, 16, 10, 9, 100],
         [7, 15, 16, 10, 9, 100],
         [8, 15, 16, 10, 9, 100],
         [9, 15, 16, 10, 9, 100],
         [10, 15, 16, 10, 9, 100]]


###   Demo Dataset class

class DemoDatasetLSTM(Data.Dataset):

    """
        Support class for the loading and batching of sequences of samples

        Args:
            dataset (Tensor): Tensor containing all the samples
            sequence_length (int): length of the analyzed sequence by the LSTM
            transforms (object torchvision.transform): Pytorch's transforms used to process the data
    """

    ##  Constructor
    def __init__(self, dataset, sequence_length=1, transforms=None):
        self.dataset = dataset
        self.seq_len = sequence_length
        self.transforms = transforms

    ##  Override total dataset's length getter
    def __len__(self):
        return self.dataset.__len__()

    ##  Override single items' getter
    def __getitem__(self, idx):
        if idx + self.seq_len > self.__len__():
            if self.transforms is not None:
                item = torch.zeros(self.seq_len, self.dataset[0].__len__())
                item[:self.__len__()-idx] = self.transforms(self.dataset[idx:])
                return item, item
            else:
                item = []
                item[:self.__len__()-idx] = self.dataset[idx:]
                return item, item
        else:
            if self.transforms is not None:
                return self.transforms(self.dataset[idx:idx+self.seq_len]), self.transforms(self.dataset[idx:idx+self.seq_len])
            else:
                return self.dataset[idx:idx+self.seq_len], self.dataset[idx:idx+self.seq_len]


###   Helper for transforming the data from a list to Tensor

def listToTensor(list):
    tensor = torch.empty(list.__len__(), list[0].__len__())
    for i in range(list.__len__()):
        tensor[i, :] = torch.FloatTensor(list[i])
    return tensor

###   Dataloader instantiation

# Parameters
seq_len = 3
batch_size = 2
data_transform = transforms.Lambda(lambda x: listToTensor(x))

dataset = DemoDatasetLSTM(data_, seq_len, transforms=data_transform)
data_loader = Data.DataLoader(dataset, batch_size, shuffle=False)

for data in data_loader:
    x, _ = data
    print(x)
    print('\n')

猜你喜欢

转载自blog.csdn.net/ch206265/article/details/106979744