pytorch学习笔记(三)数据的拼接、分割与运算

一、前言

        前文简单的介绍了tensor的索引、切片等操作。本文主要介绍数据的拼接与分割以及数学运算。

二、数据的拼接

[In] a = torch.rand(4,32,8)
[In] b = torch.rand(5,32,8)

torch.cat()   #需要合并的维度值可以不同,其他维度必须完全相同
[In] torch.cat([a,b],dim=0).shape
[Out] torch.Size([9,32,8])         

torch.stack()    #创建一个新维度,两组数据的其他维度需要相同
[In] a = torch.rand(4,3,16,32)
[In] b = torch.rand(4,3,16,32)
[In] torch.stack([a,b],dim=2).shape
[Out] torch.Size([4,3,2,16,32])

torch.split() #与cat相反,用于分割维度
[In] a = torch.rand(2,32,8)
[In] c,d = a.split([1,1],sim=0)
[In] c.shape , d.shape
[Out] torch.Size([1,32,8]),torch.Size([1,32,8])

三、数据运算

#加法
torch.add() or +
#数据维度相同,或者某一维度为1
[In] a = torch.rand(3,4)
[In] b = torch.rand(4)
[In] a+b
[In] torch.add(a,b)
#注意:torch.sum()与torch.add()的区别,前者用于压缩数据的某一维度
[In] torch.sum(a,dim=0) #dim不赋值则对数据中所有元素求和

#减法
torch.sub() or -

#乘法
点乘:* or torch.mul()
矩阵乘法: @ = torch.matmul() or (torch.mm()只能二维矩阵)

#除法
除法:/ or torch.div()  #注意:torch.div的输入应为浮点型

#乘方 
torch.pow() or **2

#开方
torch.sqrt()

#开方取倒数
torch.rsqrt()

#指数e
torch.exp()

#对数 ln
torch.log()

torch.floor()  #向下取整
torch.ceil()   #向上取整
torch.trunc()  #取整数部分
torch.frac()   #取小数部分

torch.any()    #数据中任一元素为True,则返回True
torch.all()    #数据中所有元素为True,则返回True

torch.max()    #取某一维度最大值或者整个数据中的最大值以及其索引
torch.maximum() #比较两个相同形状张量的元素大小,取较大值,输出与输入相同形状
torch.min()
torch.minimum()



四、总结

        整个数据的运算有很多函数,有时候用到时忘记咋用了,或者长啥样,为了防止到处搜,这次直接放到我的博客中记录下来。

猜你喜欢

转载自blog.csdn.net/qq_40691868/article/details/122870383
今日推荐