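Using pad_sequence in PyTorch

pad_sequence (from torch.nn.utils.rnn) pads a list of variable-length tensors to the length of the longest one, filling with 0 by default. With batch_first=True the result has shape (batch, max_len).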

>>> import numpy as np
>>> import torch
>>> from torch.nn.utils.rnn import pad_sequence
>>> input_x = [[1, 2, 3], [4, 5, 6, 7, 8], [8, 9]]
>>> # Turn each list into an integer tensor, pad to the longest length (5) with 0, then cast to float32
>>> norm_data_pad = pad_sequence([torch.from_numpy(np.array(x)) for x in input_x], batch_first=True).float()
>>> norm_data_pad
tensor([[1., 2., 3., 0., 0.],
        [4., 5., 6., 7., 8.],
        [8., 9., 0., 0., 0.]])
>>> norm_data_pad.dtype
torch.float32
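As a further sketch (not from the original post), the same input can be padded without the NumPy round-trip by building float tensors directly with torch.tensor. The sketch also shows the default batch_first=False layout, a non-zero padding_value (-1.0, chosen arbitrarily here), and a boolean mask built from the original lengths. The mask construction is an illustrative assumption; it is not something pad_sequence returns.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

input_x = [[1, 2, 3], [4, 5, 6, 7, 8], [8, 9]]

# Build float32 tensors directly, so no .float() cast is needed afterwards
seqs = [torch.tensor(x, dtype=torch.float32) for x in input_x]

# Default batch_first=False: shape is (max_len, batch) = (5, 3)
time_major = pad_sequence(seqs)

# batch_first=True with an arbitrary fill value: shape is (batch, max_len) = (3, 5)
batch_major = pad_sequence(seqs, batch_first=True, padding_value=-1.0)

# Illustrative mask: True where a position holds a real (non-padded) element
lengths = torch.tensor([len(x) for x in input_x])
mask = torch.arange(batch_major.size(1))[None, :] < lengths[:, None]

print(time_major.shape)  # torch.Size([5, 3])
print(batch_major)
print(mask)
```

Keeping the lengths next to the padded tensor is useful because, once padded, a real value that happens to equal padding_value can no longer be told apart from padding.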

Reposted from blog.csdn.net/m0_37586991/article/details/89470658