pytorch 张量的操作

Tensor Operation

1. 张量拼接与切分

torch.cat():将张量的维度dim进行拼接,不会扩张张量的维度

                    如dim=0,则两个向量将在第0维进行拼接:(3,4)concat(3,4)-->(6,4)

torch.stack():在新创建的维度dim上进行拼接

                    如dim=0,则(3,4)stack(3,4)-->(2,3,4)

                    如dim=2,则(3,4)stack(3,4)-->(3,4,2)

torch.chunk():将张量维度dim进行平均切分,返回张量列表。

                         若不能整除,最后一份张量小于其他张量

   chunk:要切的份数

torch.split():将张量按维度dim进行切分,返回张量列表

split_size_or_sections:当为int时,表示每一份的长度;当为list时,按list元素切分

2. 张量索引

torch.index_select():在维度dim上,按index索引数据。依index索引数据拼接的张量。

index:是dtype为torch.long的tensor

t = torch.randint(0, 9, size=(3, 3))
idx = torch.tensor([0, 1], dtype=torch.long)
t_select = torch.index_select(t, dim=0, index=idx)

torch.mask_select():按mask中的True进行索引,返回一维张量。

t = torch.randint(0, 9, size=(3, 3))
mask = t.ge(5)  #  >=5 return true; else false
t_select = torch.masked_select(t, mask)

3. 张量变换

torch.reshape():变换张量形状。当张量在内存中是连续时,新张量与input共享数据内存。

torch.transpose():变换张量的两个维度

torch.t():二维张量的转置

torch.squeeze():压缩长度为1的维度。

  dim为None时,移除所有长度为1的轴;若指定维度,当且仅当该轴长度为1时,可以被移除

torch.unqueeze():依据dim扩展维度

Tensor Math Operation

torch.add():逐元素计算input+alpha+other

torch.addcdiv():

                        

torch.addcmul():

                           

                         

发布了55 篇原创文章 · 获赞 22 · 访问量 4万+

猜你喜欢

转载自blog.csdn.net/li_k_y/article/details/103965574