Visualización del modelo de red neuronal PyTorch (Netron)

Visualización del modelo de red neuronal PyTorch (Netron)

Netron es una herramienta para visualizar modelos de aprendizaje profundo, que puede ayudarnos a comprender mejor la estructura y los parámetros del modelo.

Se admiten archivos de almacenamiento de modelos en los siguientes formatos:

Formato plantilla (archivo) Abrir sin descargar
ONNX exprimidor abierto
TensorFlow Lite hilo abierto
TensorFlow ajedrecista abierto
Duro red móvil abierto
TorchScript traced_online_pred_layer abierto
Aprendizaje automático básico ejercitar abierto
red oscura Yolo abierto

Enlace GitHub: https://github.com/lutzroeder/netron

Sitio web oficial: https://netron.app


ONNX

(1) En PyTorch, puede usar torch.onnx.exportla función para exportar el modelo al formato ONNX:

import torch
import netron


# 定义 PyTorch 模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.bn = torch.nn.BatchNorm2d(64)
        self.relu = torch.nn.ReLU(inplace=True)
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = torch.nn.Linear(64 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(-1, 64 * 8 * 8)
        x = self.fc(x)
        return x


# 创建模型实例并加载预训练权重
model = MyModel()

# 设置示例输入
input = torch.randn(1, 3, 32, 32)

# 将模型导出为 ONNX 格式
torch.onnx.export(model, input, './model/Test/onnx_model.onnx')  # 导出后 netron.start(path) 打开

netron.start(2) Abra el archivo de modelo ONNX exportado con el comando de Netron :

import netron

# 打开导出的 ONNX 模型文件
netron.start('./model/Test/onnx_model.onnx')
Serving './model/Test/onnx_model.onnx' at http://localhost:8080

La herramienta Netron se iniciará automáticamente en el navegador y visualizará el archivo del modelo.

Aviso:

Cuando el modelo se exporta al formato ONNX, se generará un archivo .onnxcon , y también se puede visualizar cargándolo en el sitio web oficial de Netron :

En Netron, puede ver la estructura del modelo, los parámetros, la entrada y salida y otra información. La visualización del modelo se puede ajustar haciendo zoom, girando y traduciendo para comprender mejor la estructura y los parámetros del modelo.

antorcha.guardar

Al torch.savevisualizar un modelo guardado con:

# 保存模型
torch.save(model.state_dict(), './model/Test/saved_model.pt')

# 可视化
netron.start('./model/Test/saved_model.pt')

Como se muestra en la figura a continuación, este método no muestra la información detallada del modelo:

Entonces: Netron no admite torch.savearchivos modelo exportados por PyTorch.

antorcha.jit.script

Para referencia: torch.jit.script y torch.jit.trace

Use para torch.jit.scriptconvertir el modelo en un script primero, luego torch.jit.saveguarde el modelo para visualizarlo:

# TorchScript:script
scripted_model = torch.jit.script(model)

# 保存模型
torch.jit.save(scripted_model, './model/Test/scripted_model.pth')

# 可视化
netron.start('./model/Test/scripted_model.pth')

antorcha.jit.trace

Para referencia: torch.jit.script y torch.jit.trace

Úselo para torch.jit.traceconvertir primero el modelo en una herramienta que rastree la ejecución del modelo, luego úselo torch.jit.savepara guardar el modelo y finalmente visualícelo:

# TorchScript:trace
traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32))

# 保存模型
torch.jit.save(traced_model, './model/Test/traced_model.pth')

# 可视化
netron.start('./model/Test/traced_model.pth')

Supongo que te gusta

Origin blog.csdn.net/m0_70885101/article/details/131527770
Recomendado
Clasificación