torch 打印网络参数、结构

要打印网络结构,可以使用print或print(model)语句,其中model是定义的神经网络模型对象。这将输出整个网络的结构信息,包括每个层的名称、输入和输出尺寸以及参数量等。

要打印网络参数,可以使用以下代码:

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)

该代码遍历了模型中所有需要梯度更新的参数,并打印出参数名称和对应的数值。
如果只想打印网络结构的摘要信息,可以使用以下代码:

from torchsummary import summary
summary(model, input_size=(input_channels, input_height, input_width))

上述代码使用了第三方库torchsummary,它提供了一种方便的方式来打印网络结构的摘要信息,包括每个层的名称、形状和参数数量等。其中input_size指定了输入张量的形状。
要查看定义的神经网络结构,可以使用以下代码:

import torch
from your_module import YourNetwork
model = YourNetwork()  # 实例化网络
# 打印网络结构
print(model)

在上述代码中,YourNetwork是定义的神经网络类。首先需要导入该类,然后实例化一个对象,并将其赋值给model变量。最后,通过调用print(model)语句来输出网络结构信息。
注意:如果模型包含多个子模块(例如,使用nn.Sequential组合多个层),则可以使用以下代码来打印每个子模块的结构信息:

for name, module in model.named_children():
    print(name)
    print(module)

以上代码遍历了所有子模块,并打印出每个子模块的名称和对应的结构信息。

猜你喜欢

转载自blog.csdn.net/AdamCY888/article/details/134828966
今日推荐