PyTorch's handling of variable-length sequences

The main components used are Dataset, DataLoader, MSELoss, PackedSequence, pack_padded_sequence, and pad_packed_sequence.

The model contains LSTM modules.
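
The snippets below assume roughly the following imports. This is my reconstruction of the omitted boilerplate, matching the names used in the code (np, nn_utils, Variable, DataLoader, tqdm), not something shown in the original post.

import numpy as np
import torch
import torch.nn as nn                      # for the LSTM inside the model (definition not shown)
import torch.nn.utils as nn_utils
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm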

Refer to the following two blog posts for a summary.

http://www.cnblogs.com/lindaxin/p/8052043.html#commentform

https://blog.csdn.net/lssc4205/article/details/79474735

 

When building the dataset with Dataset, the zero-padding is done in the __getitem__ function:

 

def __getitem__(self, index):
    '''
    Loading of the raw data is omitted here; it produces
    input_data and output_data, each of shape seq_length * feature_dim.
    '''
    # If this sequence is shorter than the longest sequence in the dataset,
    # zero-pad it up to max_len so all samples share the same length.
    ori_length = input_data.shape[0]
    if ori_length < self.max_len:
        npi = np.zeros(self.input_feature_dim, dtype=np.float32)
        npi = np.tile(npi, (self.max_len - ori_length, 1))
        input_data = np.row_stack((input_data, npi))
        npo = np.zeros(self.output_feature_dim, dtype=np.float32)
        npo = np.tile(npo, (self.max_len - ori_length, 1))
        output_data = np.row_stack((output_data, npo))
    return input_data, output_data, ori_length, input_data_path
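
For context, here is a minimal sketch of what the rest of such a Dataset might look like. The constructor arguments and the sample layout are assumptions for illustration (in the train function further down, my_dataset() is called without arguments, so the real class presumably reads its configuration from elsewhere):

class my_dataset(Dataset):
    def __init__(self, samples, input_feature_dim, output_feature_dim):
        # samples: assumed list of (input_data, output_data, input_data_path) tuples,
        # each data array shaped seq_length * feature_dim
        self.samples = samples
        self.input_feature_dim = input_feature_dim
        self.output_feature_dim = output_feature_dim
        # global maximum sequence length, used as the padding target in __getitem__
        self.max_len = max(s[0].shape[0] for s in samples)

    def __len__(self):
        return len(self.samples)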

In the model's forward implementation, the input has to go through pack_padded_sequence before the LSTM and through pad_packed_sequence after it; in between, the batch must also be sorted by length in descending order and afterwards restored to its original order.

def forward(self, input_x, length_list, hidden=None):
    if hidden is None:
        # Use the batch size taken from input_x rather than the configured batch_size,
        # so a smaller last batch does not trigger a size-mismatch bug.
        h_0 = input_x.data.new(self.directional * self.layer_num, input_x.shape[0], self.hidden_dim).fill_(0).float()
        c_0 = input_x.data.new(self.directional * self.layer_num, input_x.shape[0], self.hidden_dim).fill_(0).float()
    else:
        h_0, c_0 = hidden
    '''
    Other parts of the model are omitted; only the operations
    immediately before and after the LSTM are shown.
    '''
    # Sort the batch by length in descending order (required by pack_padded_sequence)
    # and remember the permutation needed to restore the original order afterwards.
    _, idx_sort = torch.sort(length_list, dim=0, descending=True)
    _, idx_unsort = torch.sort(idx_sort, dim=0)

    input_x = input_x.index_select(0, Variable(idx_sort))
    length_list = list(length_list[idx_sort])
    pack = nn_utils.rnn.pack_padded_sequence(input_x, length_list, batch_first=self.batch_first)
    output, hidden = self.BiLSTM(pack, (h_0, c_0))
    un_padded = nn_utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first)
    # Restore the original sample order. Note that un_padded is zero-padded only up to
    # the longest length in the current batch, not the global max_len from the Dataset,
    # so the labels in the train function must be truncated to match.
    un_padded = un_padded[0].index_select(0, Variable(idx_unsort))
    return un_padded
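
To see what pack_padded_sequence and pad_packed_sequence do in isolation, here is a small self-contained example with made-up sizes (not part of the original post). Note that the output comes back padded only to the batch's longest length (4), not to max_len (5), which is exactly the point made in the comment above:

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

batch, max_len, feat, hidden = 3, 5, 2, 6
x = torch.zeros(batch, max_len, feat)        # input padded to the global max_len
lengths = [4, 3, 2]                          # true lengths, already sorted descending
lstm = nn.LSTM(feat, hidden, batch_first=True)

packed = rnn_utils.pack_padded_sequence(x, lengths, batch_first=True)
packed_out, _ = lstm(packed)
out, out_lengths = rnn_utils.pad_packed_sequence(packed_out, batch_first=True)
print(out.shape)       # torch.Size([3, 4, 6]) -- padded to the batch max, not max_len
print(out_lengths)     # tensor([4, 3, 2])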

In the main train function, the labels need to be truncated accordingly. When computing the loss, set MSELoss's reduce parameter to False so the loss function returns an element-wise loss matrix, then build a mask matrix, multiply the two element-wise, and sum to get the true loss.

def train(**kwargs):
  train_data = my_dataset()
  train_dataloader = DataLoader(train_data, opt.batch_size, shuffle=True, num_workers=opt.num_workers)
  model = getattr(models, opt.model)(batchsize=opt.batch_size)
  # reduce=False makes MSELoss return the element-wise loss matrix instead of a scalar
  criterion = torch.nn.MSELoss(reduce=False)
  lr = opt.lr
  optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=opt.weight_decay)
  for epoch in range(opt.start_epoch, opt.max_epoch):
    for ii, (data, label, length_list, _) in tqdm(enumerate(train_dataloader)):
      cur_batch_max_len = int(length_list.max())
      data = Variable(data)
      # Truncate the labels to the longest length in the current batch,
      # matching the output of pad_packed_sequence in the model.
      target = Variable(label[:, :cur_batch_max_len])

      optimizer.zero_grad()
      score = model(data, length_list)
      loss_mat = criterion(score, target)
      list_int = list(length_list)
      # Mask matrix: 1 for real time steps, 0 for padded ones.
      mask_mat = Variable(torch.ones(len(list_int), cur_batch_max_len, opt.output_feature_dim))
      num_element = 0
      for idx_sample in range(len(list_int)):
        num_element += list_int[idx_sample] * opt.output_feature_dim
        if list_int[idx_sample] != cur_batch_max_len:
          mask_mat[idx_sample, list_int[idx_sample]:] = 0.0

      # Zero out padded positions and average over the real elements only.
      loss = (loss_mat * mask_mat).sum() / num_element
      loss.backward()
      optimizer.step()
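
As a side note, on recent PyTorch versions reduce=False is deprecated in favor of reduction='none', and the per-sample masking loop above can be replaced with a vectorized mask built directly from the lengths. A sketch under the same shape assumptions as the loop above (length_list is a 1-D tensor of true lengths, cur_batch_max_len an int):

criterion = torch.nn.MSELoss(reduction='none')          # same element-wise loss as reduce=False
steps = torch.arange(cur_batch_max_len).unsqueeze(0)    # (1, T)
mask = (steps < length_list.unsqueeze(1)).float()       # (batch, T): 1 for real steps, 0 for padding
mask = mask.unsqueeze(-1).expand(-1, -1, opt.output_feature_dim)
loss = (loss_mat * mask).sum() / mask.sum()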


 
