Pytorch常用函数(更新中)

一些自己常用的pytorch函数整理

建议直接Ctrl+f搜索

torch.unsqueeze 维度增加

  • torch.unsqueeze(torch.Tensorf,axis)
  • 常用形式torch.unsqueeze(x,0)#最外维度+1
    这个函数功能等价于numpy的expand_dim
  • 在这里插入图片描述

torch.Tensor.mm矩阵相乘

  • 设a=torch.ones(1,2)
  • 设b=torch.rand(2,2)
  • 矩阵点乘直接a*b就行
  • 矩阵乘法就是a.mm(b)

torch.topk(a,k) 获取a中前k个最大的值和下标

  • 对于多维度的,比如二维,结果如下
  • 在这里插入图片描述
  • 返回的是两个 tensor 所以我们可以写如下代码
  • values,indexs=torch.topk(a,k)

torch.tensor几种初始化

  • #标准正太分布均值未0方差为1中随机抽取一组随机数
  • torch.randn(*sizes,dtype)

  • 均匀分布【0,1)中均匀分布中抽取一组随机数
  • torch.rand(*sizes,out=None)

  • 离散正太分布 means均值 std方差
  • torch.normal(means,std,size=(),dtype)

  • 线性间距向量均匀间隔
  • torch.linspace(start,end,steps=100)

torch.view()改变形状

在这里插入图片描述
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_37633207/article/details/108901333