PyTorch: similarities and differences between stack + reshape and cat

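# prev_query, now_query: (bs, num_query, C); the continuation line needs parentheses to be valid Python
value_query = (torch.stack([prev_query, now_query], 1)      # (bs, 2, num_query, C)
               .reshape(bs*2, num_query, -1).permute(1, 0, 2))  # -> (num_query, bs*2, C)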

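# With prev_query, now_query laid out as (num_query, bs, C), stacking on dim 2 keeps each batch element's prev/now pair adjacent
value_query1 = (torch.stack([prev_query, now_query], 2)     # (num_query, bs, 2, C)
                .reshape(num_query, bs*2, -1))              # -> (num_query, bs*2, C)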
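A minimal check of the two lines above, as a sketch. It assumes the first line receives batch-first `(bs, num_query, C)` inputs and the second receives the same data transposed to `(num_query, bs, C)`; the sizes and the names `prev_qf`/`now_qf` are illustrative assumptions, not from the original code. Both paths yield the same `(num_query, bs*2, C)` tensor, with prev/now interleaved per batch element along dim 1.

```python
import torch

# Illustrative sizes (assumption): bs batch elements, num_query queries, C channels
bs, num_query, C = 2, 3, 4
prev_query = torch.randn(bs, num_query, C)  # assumed batch-first layout
now_query = torch.randn(bs, num_query, C)

# Line 1: stack on dim 1, merge (bs, 2) into bs*2, then move queries to the front
value_query = (torch.stack([prev_query, now_query], 1)   # (bs, 2, num_query, C)
               .reshape(bs * 2, num_query, -1)           # (bs*2, num_query, C)
               .permute(1, 0, 2))                        # (num_query, bs*2, C)

# Line 2: same data in query-first layout (hypothetical names prev_qf / now_qf)
prev_qf = prev_query.permute(1, 0, 2)  # (num_query, bs, C)
now_qf = now_query.permute(1, 0, 2)
value_query1 = (torch.stack([prev_qf, now_qf], 2)        # (num_query, bs, 2, C)
                .reshape(num_query, bs * 2, -1))         # (num_query, bs*2, C)

print(value_query.shape)                               # torch.Size([3, 4, 4])
print(torch.equal(value_query, value_query1))          # True
# Slot 1 along dim 1 holds now_query of batch element 0 (order: prev_b0, now_b0, prev_b1, now_b1)
print(torch.equal(value_query1[:, 1], now_query[0]))   # True
```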

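# rather than stacking on dim 1, which places every prev_query before every now_query (the same as torch.cat([prev_query, now_query], 1)):
value_query1 = (torch.stack([prev_query, now_query], 1)     # (num_query, 2, bs, C)
                .reshape(num_query, bs*2, -1))              # -> (num_query, bs*2, C), wrong order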
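The link to `cat` in the title can be checked directly. This sketch assumes query-first `(num_query, bs, C)` inputs with deterministic, illustrative values: stacking on dim 1 and merging gives exactly `torch.cat(..., 1)` (blocked order), while stacking on dim 2 gives the interleaved order.

```python
import torch

# Illustrative sizes and values (assumption); query-first layout (num_query, bs, C)
num_query, bs, C = 3, 2, 4
prev_query = torch.arange(num_query * bs * C, dtype=torch.float32).reshape(num_query, bs, C)
now_query = prev_query + 100  # distinct values so the two orders cannot coincide

# stack on dim 2, merge (bs, 2): prev_b0, now_b0, prev_b1, now_b1, ...
interleaved = torch.stack([prev_query, now_query], 2).reshape(num_query, bs * 2, -1)

# stack on dim 1, merge (2, bs): prev_b0, prev_b1, ..., now_b0, now_b1, ...
blocked = torch.stack([prev_query, now_query], 1).reshape(num_query, bs * 2, -1)

# Concatenating along dim 1 produces the blocked order
concatenated = torch.cat([prev_query, now_query], 1)

print(torch.equal(blocked, concatenated))               # True
print(torch.equal(interleaved, concatenated))           # False
print(torch.equal(interleaved[:, 1], now_query[:, 0]))  # True
print(torch.equal(blocked[:, 1], prev_query[:, 1]))     # True
```

In general, stacking along dim k and merging the new axis with the axis after it reproduces `cat` along dim k; merging it with the axis before it interleaves the inputs instead. Both results have the same shape, so only the element order reveals the mistake.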


Reposted from blog.csdn.net/DragonGirI/article/details/126980761