Pytorch常用函数解析(一) Tensor 拼接

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/TH_NUM/article/details/81979346

torch模块下的数学操作符

1 . torch.numel() 返回一个tensor变量内所有元素个数,可以理解为矩阵内元素的个数

2 . torch.squeeze() 对于tensor变量进行维度压缩,去除维数为1的的维度。例如一矩阵维度为A*1*B*C*1*D,通过squeeze()返回向量的维度为A*B*C*D。squeeze(a),表示将a的维数位1的维度删掉,squeeze(a,N)表示,如果第N维维数为1,则压缩去掉,否则a矩阵不变

3 . torch.unsqueeze() 是squeeze()的反向操作,增加一个维度,该维度维数为1,可以指定添加的维度。例如unsqueeze(a,1)表示在1这个维度进行添加

4 . torch.stack(sequence, dim=0, out=None),做tensor的拼接。sequence表示Tensor列表,dim表示拼接的维度,注意这个函数和concatenate是不同的,torch的concatenate函数是torch.cat,是在已有的维度上拼接,而stack是建立一个新的维度,然后再在该纬度上进行拼接。
例子:

import  torch

a=torch.Tensor([[1,2,3],[4,5,6]])
b=torch.Tensor([[7,8,9],[10,11,12]])
d=torch.stack( (a,b) ,dim = 1)

print(d)

输出:

tensor([[[  1.,   2.,   3.],
         [  7.,   8.,   9.]],

        [[  4.,   5.,   6.],
         [ 10.,  11.,  12.]]])

5 . expand_as(a)这是tensor变量的一个内置方法,如果使用b.expand_as(a)就是将b进行扩充,扩充到a的维度,需要说明的是a的低维度需要比b大,例如b的shape是3*1,如果a的shape是3*2不会出错,但是是2*2就会报错了

猜你喜欢

转载自blog.csdn.net/TH_NUM/article/details/81979346