torch.nn.Module.buffers(recurse=True)

参考链接: torch.nn.Module.buffers(recurse=True)
参考链接: Pytorch模型中的parameter与buffer
参考链接: Pytorch模型中的parameter与buffer
参考链接: What is the difference between register_buffer and register_parameter of nn.Module
参考链接: Pytorch模型中的parameter与buffer(torch.nn.Module的成员)

在这里插入图片描述

原文及翻译:

buffers(recurse=True)
方法: buffers(recurse=True)
    Returns an iterator over module buffers.
    返回一个迭代器,带迭代器可以遍历模块的缓冲buffer.
    Parameters 参数
        recurse (bool)if True, then yields buffers of this module and all submodules. 
        Otherwise, yields only buffers that are direct members of this module.
		recurse (布尔类型) – 如果该参数是True,表示递归地迭代返回,即迭代返回该模块的缓冲及其所有
		子模块的缓冲.否则,只返回作为该模块直接成员的缓冲.
    Yields  迭代返回
        torch.Tensor – module buffer
        torch.Tensor类型 – 该模块的缓冲
    
    Example:  例子:
    >>> for buf in model.buffers():
    >>>     print(type(buf.data), buf.size())
    <class 'torch.FloatTensor'> (20L,)
    <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)

总结,缓冲buffer必须要登记注册才会有效,如果仅仅将张量赋值给Module模块的属性,不会被自动转为缓冲buffer.因而也无法被state_dict()、buffers()、named_buffers()访问到.此外state_dict()可以遍历缓冲buffer和参数Parameter.
总结,缓冲buffer和参数Parameter的区别是前者不需要训练优化,而后者需要训练优化.在创建方法上也有区别,前者必须要将一个张量使用方法register_buffer()来登记注册,后者比较灵活,可以直接赋值给模块的属性,也可以使用方法register_parameter()来登记注册.

代码实验:

import torch 
import torch.nn as nn
torch.manual_seed(seed=20200910)
class Model(torch.nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.conv1=torch.nn.Sequential(  # 输入torch.Size([64, 1, 28, 28])
                torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1),
                torch.nn.ReLU(),  # 输出torch.Size([64, 64, 28, 28])
        )
        self.attribute_buffer_in = torch.randn(3,5)
        register_buffer_in_temp = torch.randn(4,6)
        self.register_buffer('register_buffer_in', register_buffer_in_temp)

    def forward(self,x): 
        pass

print('cuda(GPU)是否可用:',torch.cuda.is_available())
print('torch的版本:',torch.__version__)
model = Model() #.cuda()



print('初始化之后模型修改之前'.center(100,"-"))
print('调用named_buffers()'.center(100,"-"))
for name, buf in model.named_buffers():
    print(name,'-->',buf.shape)

print('调用named_parameters()'.center(100,"-"))
for name, param in model.named_parameters():
    print(name,'-->',param.shape)

print('调用buffers()'.center(100,"-"))
for buf in model.buffers():
    print(buf.shape)

print('调用parameters()'.center(100,"-"))
for param in model.parameters():
    print(param.shape)

print('调用state_dict()'.center(100,"-"))
for k, v in model.state_dict().items():
    print(k, '-->', v.shape)



model.attribute_buffer_out = torch.randn(10,10)
register_buffer_out_temp = torch.randn(15,15)
model.register_buffer('register_buffer_out', register_buffer_out_temp)
print('模型初始化以及修改之后'.center(100,"-"))
print('调用named_buffers()'.center(100,"-"))
for name, buf in model.named_buffers():
    print(name,'-->',buf.shape)

print('调用named_parameters()'.center(100,"-"))
for name, param in model.named_parameters():
    print(name,'-->',param.shape)

print('调用buffers()'.center(100,"-"))
for buf in model.buffers():
    print(buf.shape)

print('调用parameters()'.center(100,"-"))
for param in model.parameters():
    print(param.shape)

print('调用state_dict()'.center(100,"-"))
for k, v in model.state_dict().items():
    print(k, '-->', v.shape)

控制台输出结果:

Windows PowerShell
版权所有 (C) Microsoft Corporation。保留所有权利。

尝试新的跨平台 PowerShell https://aka.ms/pscore6

加载个人及系统配置文件用了 840 毫秒。
(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>  & 'D:\Anaconda3\envs\ssd4pytorch1_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2020.12.424452561\pythonFiles\lib\python\debugpy\launcher' '63490' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\test2.py'
cuda(GPU)是否可用: True
torch的版本: 1.2.0+cu92
--------------------------------------------初始化之后模型修改之前---------------------------------------------
-----------------------------------------调用named_buffers()------------------------------------------
register_buffer_in --> torch.Size([4, 6])
----------------------------------------调用named_parameters()----------------------------------------
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
--------------------------------------------调用buffers()---------------------------------------------
torch.Size([4, 6])
-------------------------------------------调用parameters()-------------------------------------------
torch.Size([64, 1, 3, 3])
torch.Size([64])
-------------------------------------------调用state_dict()-------------------------------------------
register_buffer_in --> torch.Size([4, 6])
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
--------------------------------------------模型初始化以及修改之后---------------------------------------------
-----------------------------------------调用named_buffers()------------------------------------------
register_buffer_in --> torch.Size([4, 6])
register_buffer_out --> torch.Size([15, 15])
----------------------------------------调用named_parameters()----------------------------------------
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
--------------------------------------------调用buffers()---------------------------------------------
torch.Size([4, 6])
torch.Size([15, 15])
-------------------------------------------调用parameters()-------------------------------------------
torch.Size([64, 1, 3, 3])
torch.Size([64])
-------------------------------------------调用state_dict()-------------------------------------------
register_buffer_in --> torch.Size([4, 6])
register_buffer_out --> torch.Size([15, 15])
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> 

猜你喜欢

转载自blog.csdn.net/m0_46653437/article/details/112760304
今日推荐