二维矩阵
import torch
if __name__ == '__main__':
a = torch.tensor([[1,1,1],
[2,2,2],
[3,3,3]],dtype=torch.float)
b = torch.tensor([[1,0,1],
[2,0,2],
[3,0,3]],dtype=torch.float)
c = torch.cat( (a, b), 0)#按行拼接
d = torch.cat( (a, b), 1)#按列拼接
print(c)
print(d)
二维矩阵result:
tensor([[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.],
[1., 0., 1.],
[2., 0., 2.],
[3., 0., 3.]])
tensor([[1., 1., 1., 1., 0., 1.],
[2., 2., 2., 2., 0., 2.],
[3., 3., 3., 3., 0., 3.]])
三维矩阵
if __name__ == '__main__':
cls = torch.tensor([[1,1,0],
[1,0,1],
[0,1,1]],dtype=torch.float)
a = torch.tensor([ [ [1,2,3],
[2,5,6] ],
[ [3,0,2],
[5,1,0] ],
[ [3,0,2],
[5,1,0]] ],dtype=torch.float)
cls = cls.unsqueeze(1)#升维[3*3]-->[3*1*3]
new = torch.cat((cls, a), dim=1)#注意dim
print(new)
三维矩阵result:
tensor([[[1., 1., 0.],
[1., 2., 3.],
[2., 5., 6.]],
[[1., 0., 1.],
[3., 0., 2.],
[5., 1., 0.]],
[[0., 1., 1.],
[3., 0., 2.],
[5., 1., 0.]]])