PyTorch Neural Network Model Visualization (Netron)

Netron is a tool for visualizing deep learning models, which can help us better understand the structure and parameters of the model.

Netron supports model files in formats including the following:

Format            Sample model
ONNX              squeezenet
TensorFlow Lite   yamnet
TensorFlow        chessbot
Keras             mobilenet
TorchScript       traced_online_pred_layer
Core ML           exermote
Darknet           yolo

GitHub link: https://github.com/lutzroeder/netron

Official website: https://netron.app


ONNX

(1) In PyTorch, you can use the torch.onnx.export function to export the model to ONNX format:

import torch
import netron


# Define the PyTorch model
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.bn = torch.nn.BatchNorm2d(64)
        self.relu = torch.nn.ReLU(inplace=True)
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        # A 32x32 input is 16x16 after the 2x2 pooling, so the flattened size is 64 * 16 * 16
        self.fc = torch.nn.Linear(64 * 16 * 16, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(-1, 64 * 16 * 16)  # flatten to (batch, features)
        x = self.fc(x)
        return x


# Create a model instance (randomly initialized; no pretrained weights are loaded here)
model = MyModel()

# Example input used to run the model during export
dummy_input = torch.randn(1, 3, 32, 32)

# Export the model to ONNX format (the ./model/Test directory must already exist)
torch.onnx.export(model, dummy_input, './model/Test/onnx_model.onnx')  # after exporting, open it with netron.start(path)
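The in_features value of the final Linear layer has to match the flattened output of the pooling layer. As a minimal sketch, the helper below (conv_output_size is a hypothetical name, not part of PyTorch) applies the standard output-size formula for convolution and pooling layers without dilation to the layer sizes used above: a 3x3 convolution with padding 1, a 2x2 max pooling with stride 2, and a 32x32 input.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel_size) // stride + 1


side = 32                                                          # input height/width
side = conv_output_size(side, kernel_size=3, stride=1, padding=1)  # after the conv layer
side = conv_output_size(side, kernel_size=2, stride=2)             # after the max pooling
flat_features = 64 * side * side                                   # 64 output channels

print(side, flat_features)  # 16 16384
```

With these settings the flattened size is 64 * 16 * 16. A smaller value passed to view(-1, ...) would still run, but it would silently change the batch dimension of the output.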

(2) Open the exported ONNX model file with Netron's netron.start function:

import netron

# Open the exported ONNX model file
netron.start('./model/Test/onnx_model.onnx')
Serving './model/Test/onnx_model.onnx' at http://localhost:8080

The Netron tool will automatically launch in the browser and visualize the model file.

Note:

Exporting the model to ONNX format produces a file with the .onnx extension, which can also be visualized by uploading it to the official Netron website (https://netron.app).

In Netron, you can view the model's structure, parameters, inputs, outputs, and other information. You can zoom and pan the graph to inspect individual layers and their parameters more closely.

torch.save

Visualizing a model saved with torch.save:

# Save the model (state_dict only)
torch.save(model.state_dict(), './model/Test/saved_model.pt')

# Visualize
netron.start('./model/Test/saved_model.pt')

With this method, Netron does not display the detailed structure of the model.

So Netron cannot visualize the network from a file saved with torch.save(model.state_dict(), ...): the file holds only the parameter tensors, not the computation graph.
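A quick way to see what such a file contains is to save a state_dict and load it back. The sketch below uses a small torch.nn.Linear layer and an in-memory buffer instead of the model and paths above (both are illustrative choices) and lists what was actually stored.

```python
import io

import torch

# A tiny stand-in model; any torch.nn.Module behaves the same way here
layer = torch.nn.Linear(4, 2)

# Save the state_dict to an in-memory buffer instead of a file on disk
buffer = io.BytesIO()
torch.save(layer.state_dict(), buffer)

# Load it back and inspect the contents
buffer.seek(0)
loaded = torch.load(buffer)

for name, tensor in loaded.items():
    print(name, tuple(tensor.shape))
```

The loaded object is a plain mapping from parameter names to tensors. Nothing in it describes how the layers are connected, which is the information a graph view needs.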

torch.jit.script

For reference: torch.jit.script and torch.jit.trace

Use torch.jit.script to convert the model to a TorchScript module first, then save it with torch.jit.save and visualize it:

# TorchScript: script
scripted_model = torch.jit.script(model)

# Save the model
torch.jit.save(scripted_model, './model/Test/scripted_model.pth')

# Visualize
netron.start('./model/Test/scripted_model.pth')

torch.jit.trace

For reference: torch.jit.script and torch.jit.trace

Use torch.jit.trace to record the operations the model executes on an example input, then save the traced model with torch.jit.save, and finally visualize it:

# TorchScript: trace
traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32))

# Save the model
torch.jit.save(traced_model, './model/Test/traced_model.pth')

# Visualize
netron.start('./model/Test/traced_model.pth')
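Because tracing records only the operations that run for the example input, data-dependent Python control flow is fixed at trace time. The following is a minimal sketch of that behavior; the function scale_by_sign is a hypothetical example and is not part of the model above.

```python
import warnings

import torch


def scale_by_sign(x):
    # Data-dependent branch: which side runs depends on the values in x
    if x.sum() > 0:
        return x * 2
    return x * -2


example = torch.ones(3)  # positive input, so the first branch is recorded
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning about the branch
    traced = torch.jit.trace(scale_by_sign, example)

negative = -torch.ones(3)
eager_out = scale_by_sign(negative)  # eager Python takes the second branch
traced_out = traced(negative)        # the trace replays the recorded first branch

print(eager_out)   # tensor([2., 2., 2.])
print(traced_out)  # tensor([-2., -2., -2.])
```

The model used in this article has no such branches, so the traced and scripted versions describe the same graph. For models whose forward pass branches on tensor values, torch.jit.script is usually the safer choice.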


Origin blog.csdn.net/m0_70885101/article/details/131527770