网络模型输入输出可视化代码

最近调整网络模型结构,添加模块的时候遇到输入输出维度的问题,每一次都需要debug来看网络输出的维度,很麻烦,因此去找了一些能将模型每一层输入输出每一层可视化的代码

可视化示例(Darknet53为例)

 需要注意的是,每一个模块之后会有一行是显示该模块的输入和输出,并不单指卷积、激活等操作,如下图,1-3行是每一次计算操作的输入输出及参数数量,而1-3行属于模块BasicConv,所以第4显示的是BasicConv模块整体输入输出和参数数量:

代码主体部分(input_output.py)

# coding:utf8
import torch
from torch.autograd import Variable
from collections import OrderedDict
from torch import nn
import pandas as pd
import numpy as np
from nets.yolo import YoloBody

def get_names_dict(model):
    names = {}
    def _get_names(module, parent_name=''):
        for key, module in module.named_children():
            name = parent_name + '.' + key if parent_name else key
            names[name] = module
            if isinstance(module, torch.nn.Module):
                _get_names(module, parent_name=name)

    _get_names(model)
    return names


def torch_summarize_df(input_size, model, weights=False, input_shape=True, nb_trainable=False):
    def register_hook(module):
        def hook(module, input, output):
            name = ''
            for key, item in names.items():
                if item == module:
                    name = key
            # <class 'torch.nn.modules.conv.Conv2d'>
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            module_idx = len(summary)

            m_key = module_idx + 1

            summary[m_key] = OrderedDict()
            # summary[m_key]['name'] = name
            #这个name可能会特别长,影响可视化,不介意的可以也打出来 
            summary[m_key]['class_name'] = class_name
            if input_shape:
                summary[m_key]['input_shape'] = (-1,) + tuple(input[0].size())[1:]
            summary[m_key]['output_shape'] = (-1,) + tuple(output[0].size())[:]
            if weights:
                summary[m_key]['weights'] = list(
                    [tuple(p.size()) for p in module.parameters()])

            # summary[m_key]['trainable'] = any([p.requires_grad for p in module.parameters()])
            if nb_trainable:
                params_trainable = sum(
                    [torch.LongTensor(list(p.size())).prod() for p in module.parameters() if p.requires_grad])
                summary[m_key]['nb_trainable'] = params_trainable
            params = sum([torch.LongTensor(list(p.size())).prod() for p in module.parameters()])
            summary[m_key]['nb_params'] = params

        if not isinstance(module, nn.Sequential) and \
                not isinstance(module, nn.ModuleList) and \
                not (module == model):
            hooks.append(module.register_forward_hook(hook))

    # 名称存储在parent中,path+name是唯一的,而不是名称
    names = get_names_dict(model)

    # 检查网络是否有多个输入
    if isinstance(input_size[0], (list, tuple)):
        x = [Variable(torch.rand(1, *in_size)) for in_size in input_size]
    else:
        x = Variable(torch.rand(1, *input_size))

    if next(model.parameters()).is_cuda:
        x = x.cuda()

    # 创建属性
    summary = OrderedDict()
    hooks = []

    # 统计参数信息/ 注册hook
    model.apply(register_hook)

    # 向前传递
    model(x)

    # 移除这些hook
    for h in hooks:
        h.remove()

    # 制作结构
    df_summary = pd.DataFrame.from_dict(summary, orient='index')

    return df_summary

注意 :第36行的位置,我将summary[m_key]['name'] = name注释调了,因为name显示的是每一层在代码中的名称,有时候特别长会影响我查找,如果大家有需要可以将其还原,以下是存在name的情况下:

 使用方式

# 导入项目中模型,在最上面导入就好
from nets.yolo import YoloBody 

device  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model= YoloBody([[6, 7, 8], [3, 4, 5], [0, 1, 2]], 20, phi=0).to(device)

# input_size可以不用考虑batch_size这里只看网络结构
# 显示出来的网络结构会在batch_size的位置显示-1
df = torch_summarize_df(input_size=(3, 416, 416), model=model)

# 以下代码是解决print的时候会存在内容过多而省略无法显示的情况
# 如果以下代码还不能完全打印,请根据注释提示将其修改
np.set_printoptions(threshold=np.inf)
pd.set_option('display.width', 500)# 设置字符显示宽度
pd.set_option('display.max_rows', None)# 设置显示最大行
pd.set_option('display.max_columns', None)# 设置显示最大列,None为显示所有列

# 自动创建一个名为input_output.txt
# 'w'是覆盖式写入文件,以防调整需要的信息时省去删除文件内容的操作
# 不想覆盖的可以换成'a'
f = open("input_output.txt", 'w')
print(df,file=f) # 写入文件input_output.txt

上述文件运行后会在项目文件下直接生成一个<input_output.txt>的文件,大家可以自行查看

以上内容参考pytorch实用工具总结 - 知乎进行修改,但是该文章内容我在一开始运行的时候有一定的错误,本文章的内容都是在调整修改后的代码,直接使用即可。

猜你喜欢

转载自blog.csdn.net/weixin_64064486/article/details/124548883