Supongamos que tengo la lista de memoria list_of_tensors = [tensor1, tensor2, tensor3, tensor4]
. Cada elemento es un tensor pytorch de forma (1, 1, 84, 84)
.
Quiero que concatenar lista de tensores para obtener un tensor de la forma (4, 1, 84, 84)
. torch.cat(TT, dim=0)
seguramente podría permitir que haga eso. TT
debe ser una tupla de tensor, por lo que torch.cat(*list_of_tensors, dim=0)
o torch.cat((*list_of_tensors), dim=0)
no funcionará.
¿Cómo puedo usar list_of_tensors
y torch.cat(???, dim=0)
para crear un nuevo tensor de la forma(4, 1, 84, 84)
Se puede utilizar la pila , y retirar el excedente con dimensión de aterrizaje
c = (torch.stack(list_of_tensors,dim=1)).squeeze(0)
ahora c.shape es (4, 1, 84, 84)
Se puede encontrar una explicación aquí: https://discuss.pytorch.org/t/how-to-turn-a-list-of-tensor-to-tensor/8868/6