【Pytorch中的常用操作】

Pytorch中的常用函数与操作

PYtorch中的各种函数就像英语单词一样,见得多用得多就慢慢掌握了,这里以DQN代码为例,记录我经常用的和碰见的函数方法。(还会有一些python操作)

.detach().cpu()

out = model(inputs)
ls.append(out.detach().cpu().numpy())

detach阻断反向传播的,经过detach()方法后,变量仍然在GPU上,再利用.cpu()将数据移至CPU中进行后续操作,如tensor变量转numpy。

关于梯度的一些问题12

np.array与np.ndarray的区别

import numpy as np

# numpy.array() 和 numpy.ndarray()的区别?
mat1 = np.array([[1,2,3],[4,5,6]])
print("mat1 data:{}".format(mat1))
print("mat1 type:{}".format(type(mat1)))
print("mat1 dtype:{}".format(mat1.dtype))

mat2 = np.ndarray(shape=(2,3), dtype=np.int32)
print("mat2 data:{}".format(mat2))
print("mat2 type:{}".format(type(mat2)))
print("mat2 dtype:{}".format(mat2.dtype))

>>>output
mat1 data:[[1 2 3]
 [4 5 6]]
mat1 type:<class 'numpy.ndarray'>
mat1 dtype:int32
mat2 data:[[ -153199152         440           0]
 [          0      131074 -2147483648]]
mat2 type:<class 'numpy.ndarray'>
mat2 dtype:int32

ndarray是一个类,其默认构造函数是ndarray()。
array是一个函数,便于创建一个ndarray对象。
np.ndarray()构造函数相对更low-level一些,使用默认构造函数创建的ndarray对象的数组元素是随机值,而numpy提供了一系列的创建ndarray对象的函数,array()就是其中的一种;通常使用这些上层一点的函数来构造ndarray对象会更方便一些3

torch.load

torch.save(net.state_dict(), 'test.pth')  # save的是net的state_dict
net.load_state_dict(torch.load('test.pth'))  # 加载的也是state_dict

因为state_dict本质上Python字典对象,所以可以很好地进行保存、更新、修改和恢复操作(python字典结构的特性),从而为PyTorch模型和优化器增加了大量的模块化。

.modules()和.children()

model.modules()能够迭代地遍历模型的所有子层,而model.children()只会遍历模型的子层。4
注意这两实现的时候用的是set,相同的网络只输出一次,但是使用nn.Sequential就不会有这种困扰

    def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
        r"""Returns an iterator over all modules in the network, yielding
        both the name of the module as well as the module itself.

        Args:
            memo: a memo to store the set of modules already added to the result
            prefix: a prefix that will be added to the name of the module
            remove_duplicate: whether to remove the duplicated module instances in the result
                or not

        Yields:
            (string, Module): Tuple of name and module

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

        Example::

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.named_modules()):
                    print(idx, '->', m)

            0 -> ('', Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            ))
            1 -> ('0', Linear(in_features=2, out_features=2, bias=True))

        """

        if memo is None:
            memo = set()
        if self not in memo:
            if remove_duplicate:
                memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + name
                for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
                    yield m

关键词yield的解释:5

题外话:foo函数的命名由来

在计算机程序设计与计算机技术的相关文档中,术语foobar是一个常见的无名氏化名,常被作为“伪变量”使用。 从技术上讲,“foobar”很可能在1960年代至1970年代初通过迪吉多的系统手册传播开来。另一种说法是,“foobar”可能来源于电子学中反转的foo信号;这是因为如果一个数字信号是低电平有效,那么在信号标记上方一般会标有一根水平横线,而横线的英文即为“bar”。在《新黑客辞典》中,还提到“foo”可能早于“FUBAR”出现。http://zh.wikipedia.org/zh-cn/

permute和reshape/view的区别

permute作用为调换Tensor的维度,参数为调换的维度。例如对于一个二维Tensor来说,调用tensor.permute(1,0)意为将1轴(列轴)与0轴(行轴)调换,相当于进行转置。使用view或者reshape,得到的tensor并不是转置的效果,而是相当于将原tensor的元素按行取出,然后按行放入到新形状的tensor中。6

In [20]: a              
Out[20]:                
tensor([[0, 1, 2],      
        [3, 4, 5]])     
                        
In [21]: a.permute(1,0) 
Out[21]:                
tensor([[0, 3],         
        [1, 4],         
        [2, 5]])        


In [22]: a.reshape(3,2) 
Out[22]:                
tensor([[0, 1],         
        [2, 3],         
        [4, 5]])        
                        
In [23]: a.view(3,2)    
Out[23]:                
tensor([[0, 1],         
        [2, 3],         
        [4, 5]])        

可以理解为,对于一个高维的Tensor执行permute,我们没有改变数据的相对位置,而只是旋转了一下这个(超)立方体。或者也可以说,改变了我们对这个(超)立方体的“观察角度”而已。

.parameters()

#网络参数数量
def get_parameter_number(net):
    total_num = sum(p.numel() for p in net.parameters())
    trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('Total:{}, Trainable:{}'.format(total_num, trainable_num))

.diag()

import torch
 
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(a)
# output:
# tensor([[1, 2, 3],
#        [4, 5, 6],
#        [7, 8, 9]])
 
print(torch.diag(a))
# output:
# tensor([1, 5, 9])
 
print(torch.diag(a, 1))
# output:
# tensor([2, 6])
 
print(torch.diag(a, -1))
# output:
# tensor([4, 8])
 
print(torch.diag(a, 2))
# output:
# tensor([3])
 
print(torch.diag(a, -2))
# output:
# tensor([7])

  1. PyTorch张量的关闭自动梯度的三种方式 ↩︎

  2. requires_grad,grad_fn,grad的含义及使用 ↩︎

  3. np.array与np.ndarray的区别 ↩︎

  4. PyTorch中的model.modules(), model.children(), model.named_children(), model.parameters(), model.nam… ↩︎

  5. python中yield的用法详解——最简单,最清晰的解释 ↩︎

  6. PyTorch的permute和reshape/view的区别 ↩︎

猜你喜欢

转载自blog.csdn.net/a_piece_of_ppx/article/details/125179510