How to export an ONNX model from PyTorch and use it for image super-resolution

Foreword

        In this tutorial, we will show how to convert a model defined in PyTorch to the ONNX format and then run it with ONNX Runtime.

        ONNX Runtime is a performance-focused engine for ONNX models that runs inference efficiently across multiple platforms (Windows, Linux, and Mac) and hardware (CPUs and GPUs). ONNX Runtime has been reported to significantly improve performance on several models.

        For this tutorial, you need ONNX and ONNX Runtime installed. Binary builds of both are available with pip install onnx onnxruntime. ONNX Runtime recommends using the latest stable release of PyTorch.
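        Before going further, it can help to confirm that the required packages are visible to the Python interpreter. The following is a minimal sketch that uses only the standard library; the helper name missing_packages is hypothetical and not part of ONNX or PyTorch.

```python
# Sketch: report which required packages cannot be found by the import system.
from importlib.util import find_spec


def missing_packages(names):
    """Return the subset of top-level module names that are not importable."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    # Import names of the packages used in this tutorial.
    required = ["torch", "onnx", "onnxruntime"]
    missing = missing_packages(required)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are installed.")
```

        If anything is reported as missing, install it with pip before running the rest of the code.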

How to create, export and use models

# Some standard imports
import io
import numpy as np

from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx

        Here we take image super-resolution as an example. Super-resolution is a way of increasing the resolution of images and videos, and it is widely used in image processing and video editing. In this tutorial, we will use a small super-resolution model.

        First, let's create a SuperResolution model in PyTorch. The model uses the efficient sub-pixel convolution layer described in "Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" (Shi et al.) to increase the resolution of an image by an upscale factor. The model expects the Y component of the image's YCbCr representation as input and outputs the upscaled Y component at super resolution.
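        To make the sub-pixel idea concrete, here is a minimal pure-Python sketch of the rearrangement that a pixel-shuffle layer performs: r*r low-resolution feature maps are interleaved into one map that is r times larger in each direction. It works on nested lists for illustration only and is not PyTorch's implementation; the function name and the sample data are assumptions.

```python
def pixel_shuffle(channels, r):
    """Rearrange C*r*r feature maps of size HxW into C maps of size (H*r)x(W*r).

    `channels` is a nested list indexed as channels[c][h][w].
    """
    c_in = len(channels)
    if c_in % (r * r) != 0:
        raise ValueError("channel count must be divisible by r*r")
    height, width = len(channels[0]), len(channels[0][0])
    c_out = c_in // (r * r)
    out = [[[0] * (width * r) for _ in range(height * r)] for _ in range(c_out)]
    for c in range(c_out):
        for h in range(height):
            for w in range(width):
                for i in range(r):
                    for j in range(r):
                        # Channel (i, j) of the group fills sub-pixel (i, j)
                        # of the r x r output cell at position (h, w).
                        src = channels[c * r * r + i * r + j][h][w]
                        out[c][h * r + i][w * r + j] = src
    return out


# Four 1x2 feature maps become one 2x4 image when r = 2.
features = [[[1, 5]], [[2, 6]], [[3, 7]], [[4, 8]]]
print(pixel_shuffle(features, 2))  # [[[1, 2, 5, 6], [3, 4, 7, 8]]]
```

        This is why the last convolution of such a model produces upscale_factor ** 2 channels: each group of channels supplies the sub-pixels of one output channel.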

Create the model

        The model comes directly from PyTorch's example without modification:

# Super Resolution model definition in PyTorch
import torch.nn as nn
import torch.nn.init as init


class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)

# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)
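        As a sanity check on the architecture, the output shape can be worked out by hand. The sketch below applies the standard convolution output-size formula (dilation 1) to the kernel, stride, and padding values from the model definition above; the helper names are hypothetical.

```python
def conv2d_out(size, kernel, stride, padding):
    """Output length along one spatial axis of a 2-D convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def super_resolution_output_shape(batch, height, width, upscale_factor):
    """Follow the tensor shape through the four convolutions and the pixel shuffle."""
    # (out_channels, kernel, stride, padding) for conv1..conv4.
    layers = [
        (64, 5, 1, 2),
        (64, 3, 1, 1),
        (32, 3, 1, 1),
        (upscale_factor ** 2, 3, 1, 1),
    ]
    channels = 1
    for out_channels, kernel, stride, padding in layers:
        height = conv2d_out(height, kernel, stride, padding)
        width = conv2d_out(width, kernel, stride, padding)
        channels = out_channels
    # PixelShuffle trades r*r channels for an r-times larger height and width.
    channels //= upscale_factor ** 2
    return (batch, channels, height * upscale_factor, width * upscale_factor)


print(super_resolution_output_shape(1, 224, 224, 3))  # (1, 1, 672, 672)
```

        Every convolution preserves the spatial size, so a 224x224 input with upscale_factor=3 yields a 672x672 output.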

Export the model

        Normally, you would train this model now; for this tutorial, however, we will download some pretrained weights instead. Note that this model was not fully trained, so it does not perform very well. It is used here for demonstration purposes only.

        It is important to call torch_model.eval() or torch_model.train(False) before exporting the model, to switch it to inference mode. This is required because operators like dropout or batchnorm behave differently in inference and training modes.

# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1    # just a random number

# Initialize model with the pretrained weights
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

# set the model to inference mode
torch_model.eval()
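        To illustrate why the mode matters, here is a minimal pure-Python sketch of inverted dropout, assuming the usual formulation in which surviving values are scaled by 1 / (1 - p) during training. The network in this tutorial contains neither dropout nor batchnorm layers, so calling eval() here is a matter of good practice; the function below is illustrative and is not PyTorch's implementation.

```python
import random


def dropout(values, p, training, rng=random):
    """Inverted dropout on a list of floats.

    In training mode each value is zeroed with probability p and the
    survivors are scaled by 1 / (1 - p); in inference mode the input is
    returned unchanged.
    """
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        return list(values)
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in values]


activations = [1.0, 2.0, 3.0, 4.0]
print(dropout(activations, p=0.5, training=False))  # [1.0, 2.0, 3.0, 4.0]
print(dropout(activations, p=0.5, training=True))   # random: each entry is 0.0 or doubled
```

        A model traced in training mode would bake the random, training-time behavior into the exported graph, which is not what you want at inference time.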

        Exporting a model in PyTorch works through tracing or scripting. This tutorial uses a model exported by tracing as an example. To export the model, we call the torch.onnx.export() function. This executes the model, recording a trace of the operators used to compute the output. Because export runs the model, we need to provide an input tensor x. The values in it can be random as long as it has the right type and size. Note that the input size is fixed in the exported ONNX graph for every input dimension unless that dimension is specified as a dynamic axis. In this example, we export the model with an input of batch_size 1, but then specify the first dimension as dynamic in the dynamic_axes parameter of torch.onnx.export(). The exported model will therefore accept inputs of size [batch_size, 1, 224, 224], where batch_size can vary.
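        One way to picture the effect of a dynamic axis is as a shape template in which some entries are names rather than numbers. The sketch below is only a mental model written in plain Python, not the validation code that ONNX Runtime actually runs; the function name and the shape representation are assumptions.

```python
def shape_is_accepted(declared, actual):
    """Check a concrete shape against a declared shape.

    In `declared`, an int is a fixed size and a str names a dynamic axis
    that accepts any positive size.
    """
    if len(declared) != len(actual):
        return False
    for expected, size in zip(declared, actual):
        if isinstance(expected, str):
            if size < 1:
                return False
        elif expected != size:
            return False
    return True


# Declared input shape when axis 0 is marked as dynamic at export time.
declared_input = ["batch_size", 1, 224, 224]
print(shape_is_accepted(declared_input, (1, 1, 224, 224)))   # True
print(shape_is_accepted(declared_input, (16, 1, 224, 224)))  # True
print(shape_is_accepted(declared_input, (1, 1, 256, 256)))   # False
```

        Only the batch dimension is free in this example; the height and width remain fixed at the values of the tensor used for tracing.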

        To learn more details about PyTorch's export interface, check out the torch.onnx documentation.

# Input to the model
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
torch_out = torch_model(x)

# Export the model
torch.onnx.export(torch_model,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "super_resolution.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX opset version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})

        We also compute torch_out, the model's output, which we will use to verify that the exported model computes the same values when run in ONNX Runtime.

        Before verifying the model's output with ONNX Runtime, however, we will check the ONNX model with the ONNX API. First, onnx.load("super_resolution.onnx") loads the saved model and returns an onnx.ModelProto structure (a top-level file/container format for bundling an ML model; for more information, see the onnx.proto documentation). Then, onnx.checker.check_model(onnx_model) verifies the model's structure and confirms that the model has a valid schema. The validity of the ONNX graph is verified by checking the model's version, the graph's structure, and the nodes and their inputs and outputs.

Use the model

import onnx

onnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)

        Now, let's compute the output using ONNX Runtime's Python API. This part can usually be done in a separate process or on another machine, but we will continue in the same process so that we can verify that ONNX Runtime and PyTorch compute the same values for the network.

        To run the model with ONNX Runtime, we need to create an inference session for the model with the chosen configuration parameters (here we use the default configuration). Once the session is created, we evaluate the model using the run() API. The output of this call is a list containing the outputs of the model computed by ONNX Runtime.

import onnxruntime

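# Name the execution provider explicitly; ONNX Runtime builds that ship several providers require it.
ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])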

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

        We should see that the outputs of PyTorch and ONNX Runtime match numerically within the given tolerances (rtol=1e-03 and atol=1e-05). As a side note, if they do not match, there is a problem in the ONNX exporter.

        So far, we have exported a model from PyTorch and shown how to load it and run it in ONNX Runtime with a dummy tensor as input.
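        The two tolerances combine as in the following pure-Python sketch of the element-wise criterion that np.testing.assert_allclose applies, |actual - desired| <= atol + rtol * |desired|. The helper names are hypothetical, and the sketch ignores NaN and infinity handling.

```python
def is_close(actual, desired, rtol=1e-3, atol=1e-5):
    """Element-wise test |actual - desired| <= atol + rtol * |desired|."""
    return abs(actual - desired) <= atol + rtol * abs(desired)


def all_close(actual, desired, rtol=1e-3, atol=1e-5):
    """True if every pair of elements in two equal-length sequences is close."""
    if len(actual) != len(desired):
        raise ValueError("sequences must have the same length")
    return all(is_close(a, d, rtol, atol) for a, d in zip(actual, desired))


print(all_close([1.0005, 0.0], [1.0, 0.000005]))  # True
print(all_close([1.002], [1.0]))                  # False
```

        The absolute tolerance dominates for values near zero, while the relative tolerance dominates for larger values.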

Sample code

        In this tutorial, we will use a famous, widely used cat image, shown below.

        First, let's load the image and preprocess it with the standard PIL Python library. Note that this kind of preprocessing is standard practice when preparing data for training or testing neural networks.

        We first resize the image to fit the model's input size (224x224). Then we split the image into its Y, Cb, and Cr components. These components represent a grayscale image (Y) and the blue-difference (Cb) and red-difference (Cr) chroma components. The human eye is more sensitive to the Y component, so this is the component we will transform. After extracting the Y component, we convert it to a tensor, which will be the input to the model.

from PIL import Image
import torchvision.transforms as transforms

img = Image.open("./_static/img/cat.jpg")

resize = transforms.Resize([224, 224])
img = resize(img)

img_ycbcr = img.convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()

to_tensor = transforms.ToTensor()
img_y = to_tensor(img_y)
img_y.unsqueeze_(0)
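        For intuition about what the Y, Cb, and Cr channels contain, the sketch below converts a single RGB pixel using the full-range JFIF coefficients. This is an assumption about the color space in use; PIL's own conversion may round slightly differently, so treat the function as illustrative rather than as PIL's implementation.

```python
def rgb_to_ycbcr(r, g, b):
    """Convert one 8-bit RGB pixel to full-range (JFIF-style) YCbCr."""
    y = 0.299 * r + 0.587 * g + 0.114 * b
    cb = 128 - 0.168736 * r - 0.331264 * g + 0.5 * b
    cr = 128 + 0.5 * r - 0.418688 * g - 0.081312 * b

    def to_byte(value):
        # Round to the nearest integer and clamp to the 8-bit range.
        return max(0, min(255, round(value)))

    return to_byte(y), to_byte(cb), to_byte(cr)


print(rgb_to_ycbcr(255, 255, 255))  # (255, 128, 128)
print(rgb_to_ycbcr(0, 0, 0))        # (0, 128, 128)
```

        Neutral colors carry all of their information in Y, with Cb and Cr sitting at the midpoint 128, which is why the model can work on the Y channel alone while the chroma channels are simply resized.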

        Next, let's take the tensor representing the resized grayscale cat image and run the super-resolution model in ONNX Runtime as described earlier.

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)}
ort_outs = ort_session.run(None, ort_inputs)
img_out_y = ort_outs[0]

        At this point, the output of the model is a tensor. Now we process that output to reconstruct the final image from the output tensor and save it. The post-processing steps are taken from the PyTorch implementation of the super-resolution model.

img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L')

# build the output image, following the post-processing step from the PyTorch implementation
final_img = Image.merge(
    "YCbCr", [
        img_out_y,
        img_cb.resize(img_out_y.size, Image.BICUBIC),
        img_cr.resize(img_out_y.size, Image.BICUBIC),
    ]).convert("RGB")

# Save the image, we will compare this with the output image from mobile device
final_img.save("./_static/img/cat_superres_with_ort.jpg")
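        The conversion from the model's floating-point output to 8-bit pixel values used above (scale by 255, clip to [0, 255], cast to an unsigned byte) can be sketched in plain Python as follows. The helper name is hypothetical, and the cast is modeled as truncation toward zero.

```python
def to_uint8(values):
    """Scale floats from the nominal [0, 1] range to integers in [0, 255]."""
    result = []
    for v in values:
        scaled = v * 255.0
        clipped = min(255.0, max(0.0, scaled))  # same role as .clip(0, 255)
        result.append(int(clipped))             # truncation, like casting to uint8
    return result


print(to_uint8([-0.1, 0.0, 0.5, 1.0, 1.2]))  # [0, 0, 127, 255, 255]
```

        Clipping before the cast matters: the network can overshoot the [0, 1] range slightly, and an unclipped value would wrap around when stored as a byte.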

Image effect after super-resolution: 

Summary

ONNX Runtime is a cross-platform engine: you can run it on multiple platforms and on both CPUs and GPUs.

ONNX Runtime can also be deployed to the cloud for model inference using hosted machine learning services.

Origin blog.csdn.net/qq_39312146/article/details/132074762