神经网络学习小记录33——pytorch中squeeze()和unsqueeze()函数的简单介绍

神经网络学习小记录33——pytorch中squeeze和unsqueeze函数的简单介绍

学习前言

经常看到在tf中看到squeeze,学会pytorch,结果刚入门就发现了这个函数,我决定弄懂它,顺便写篇文章水一下。
在这里插入图片描述

1、unsqueeze

其实unsqueeze的作用和np.expand_dims的作用非常类似,都是为矩阵增加一个维度,unsqueeze是为了pytorch中的tensor增加一个维度。

函数声明为:

torch.unsqueeze(dim)

其中dim表示需要在哪一维增加一个维度,dim必须被指定。

试验示例:

import torch
before_unsqueeze = torch.arange(12).reshape([3,4])
print(before_unsqueeze.data)
after_unsqueeze = before_unsqueeze.unsqueeze(1)
print(after_unsqueeze.data)

结果:

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
tensor([[[ 0,  1,  2,  3]],

        [[ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11]]])

2、squeeze

其实squeeze的作用和tf.squeeze的作用非常类似,二者都是将被操作目标中维度为1的部分去除。

函数声明为:

torch.squeeze(dim=None)

其中dim表示需要在哪一维去掉一个维度,如果不指定则自动寻找,如果指定则当指定的维度为1时去掉,如果不为1则不改变。

试验示例:

import torch
before_squeeze = torch.arange(12).reshape([1,3,4])
print(before_squeeze.data)
# 指定维度
after_squeeze = before_squeeze.squeeze(1)
print(after_squeeze.data)
# 自动去除
after_squeeze = before_squeeze.squeeze()
print(after_squeeze.data)

结果:

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]]])
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]]])
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
发布了167 篇原创文章 · 获赞 112 · 访问量 24万+

猜你喜欢

转载自blog.csdn.net/weixin_44791964/article/details/103657577