pytorch中查看可训练参数

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/jeryjeryjery/article/details/83057199

  pytorch中我们有时候可能需要设定某些变量是参与训练的,这时候就需要查看哪些是可训练参数,以确定这些设置是成功的。

  pytorch中model.parameters()函数定义如下:

    def parameters(self):
        r"""Returns an iterator over module parameters.

        This is typically passed to an optimizer.

        Yields:
            Parameter: module parameter

        Example::

            >>> for param in model.parameters():
            >>>     print(type(param.data), param.size())
            <class 'torch.FloatTensor'> (20L,)
            <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)

        """
        for name, param in self.named_parameters():
            yield param

所以,我们可以遍历named_parameters()中的所有的参数,只打印那些param.requires_grad=True的变量。具体实现代码如下所示:

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

这样打印出的结果就是模型中所有的可训练参数列表!

猜你喜欢

转载自blog.csdn.net/jeryjeryjery/article/details/83057199