如何将多个torch.Tensor存储下来并在新的纬度上进行合并?

output_list = []                        # 存放输出的list

for idlayer in range(self.num_layers):  # 每一次循环
    current_input = input[idlayer,...]
    current_output,current_hidden_state = self.cell(current_input,current_hidden_state)
    # 按照输出间隔将输出存储下来
    if idlayer % self.out_stride == self.out_stride - 1:
        output_list.append(current_output)

output = torch.stack(output_list, dim=0)    # 将输出合并成一个tensor

先用list.append存储下来
再用torch.stack接受list,并定义需要拼接的新维度即可。
current_output的纬度为[10,4,64,64]
n个current_output拼接的output的纬度为[n,10,4,64,64]

猜你喜欢

转载自blog.csdn.net/weixin_43905212/article/details/107484876
今日推荐