Netron: a viewer for deep learning models
Project page: https://github.com/lutzroeder/netron
Netron is an excellent tool for visualizing deep learning models.
Netron can display most deep learning model formats, but it does not fully support the .pt/.pth files that PyTorch saves. Converting those files to ONNX format, which Netron does support, works around this.
Installation
So far I have only installed it on Ubuntu, where it is very simple. I have not tried other systems myself, but the GitHub page has instructions for them.
pip install netron
The code
import torch
import torch.nn as nn


# Define a simple binary classification network
class SimpleNet(nn.Module):
    """Small CNN that maps an image batch to two output scores.

    Each of the five conv blocks is a 3x3 convolution (stride 1, padding 1),
    a ReLU and a 2x2 max-pool.  Every max-pool halves (flooring) the spatial
    size, so an ``H x W`` input reaches the fully-connected head as
    ``50 * (H // 32) * (W // 32)`` features.

    Args:
        input_hw: ``(height, width)`` of the input images.  The default
            ``(640, 360)`` reproduces the original head size ``50 * 20 * 11``.
    """

    def __init__(self, input_hw=(640, 360)):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=50, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=50, out_channels=200, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=200, out_channels=500, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channels=500, out_channels=200, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels=200, out_channels=50, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        height, width = input_hw
        # Five floor-halvings by MaxPool2d(2) equal one floor-division by 32.
        self.fc = nn.Sequential(
            nn.Linear(50 * (height // 32) * (width // 32), 5000)
        )
        self.fc1 = nn.Sequential(nn.Linear(5000, 2000))
        self.fc2 = nn.Sequential(nn.Linear(2000, 50))
        self.classifier = nn.Sequential(nn.Linear(50, 2))

    def forward(self, x):
        """Return raw two-class scores of shape ``(N, 2)`` for the batch ``x``.

        ``x`` may be a tensor or anything ``torch.as_tensor`` accepts (e.g. a
        NumPy array); it is cast to float32 before the first convolution.
        """
        # torch.tensor(x) on a Tensor copies/detaches it and, when traced for
        # ONNX export, is recorded as a constant -- the exported graph would
        # then ignore its real input.  Only wrap non-tensors, and cast with
        # .to(), which is traced as an ordinary op.
        if not torch.is_tensor(x):
            x = torch.as_tensor(x)
        x = x.to(torch.float32)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return self.classifier(x)


def main():
    """Build the demo network, export it to ONNX and open the file in Netron."""
    # Imported here so SimpleNet can be imported without netron installed.
    import netron

    d = torch.rand(1, 3, 640, 360)
    m = SimpleNet()
    with torch.no_grad():
        m(d)  # sanity-check the forward pass before exporting
    onnx_path = "dirtyjudgment640320_pattern.onnx"
    torch.onnx.export(m, d, onnx_path)
    netron.start(onnx_path)


if __name__ == "__main__":
    main()
Results
Note: the Conv node appears in both images because the two screenshots overlap.
Converting a .pth file to ONNX
def convert_model_to_ONNX(input_img_size, input_pth_model, output_ONNX):
    """Load SimpleNet weights from a ``.pth`` checkpoint and export them to ONNX.

    Args:
        input_img_size: Image size indexed as ``(width, height)``.  The dummy
            export input has shape ``(2, 3, input_img_size[1], input_img_size[0])``.
            The default SimpleNet head (``50 * 20 * 11`` features) matches a
            640 (H) x 360 (W) input, i.e. ``input_img_size == (360, 640)``.
        input_pth_model: Path to a checkpoint containing a SimpleNet state
            dict, optionally saved from ``nn.DataParallel`` (keys prefixed
            with ``"module."``).
        output_ONNX: Destination path of the exported ONNX file.
    """
    from collections import OrderedDict

    width, height = input_img_size[0], input_img_size[1]
    dummy_input = torch.randn(2, 3, height, width)
    model = SimpleNet()  # network structure
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only hosts.
    state_dict = torch.load(input_pth_model, map_location="cpu")
    # nn.DataParallel prefixes every key with "module."; strip it only where
    # present so checkpoints saved without DataParallel load unchanged.
    prefix = "module."
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    model.load_state_dict(new_state_dict)
    model.eval()
    # Names given to the ONNX graph's input and output.
    input_names = ["input_image"]
    output_names = ["output_classification"]
    torch.onnx.export(
        model,
        dummy_input,
        output_ONNX,
        verbose=True,
        input_names=input_names,
        output_names=output_names,
    )