Convert a PyTorch model to a TensorFlow model

Preface
Currently, most models are released in PyTorch format, while TF Serving is commonly used for deployment and requires models in the TensorFlow SavedModel (saved_model.pb) format. This article introduces how to convert a PyTorch model into a TensorFlow model.

Core process: PyTorch ====> ONNX ====> TensorFlow

1. Convert PyTorch to ONNX

The first step is to load the model: initialize the model network, load the model weights, and so on. As an example, I want to load the instructor-large model files. instructor-large is a folder; let's take a look at its contents.
[Image: contents of the instructor-large folder]
The model network is T5EncoderModel, and the loading code is as follows.

from transformers import T5EncoderModel
import torch

# Load the encoder weights from the local instructor-large folder
model = T5EncoderModel.from_pretrained("../instructor-large")
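
Optionally, before exporting, you can sanity-check that the model loads and runs with a quick forward pass. This is a minimal sketch, assuming the instructor-large folder also contains the tokenizer files (AutoTokenizer is used here for illustration).

from transformers import AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("../instructor-large")  # assumes tokenizer files are present
model.eval()
with torch.no_grad():
    inputs = tokenizer("hello world", return_tensors="pt")
    outputs = model(input_ids=inputs["input_ids"])
print(outputs.last_hidden_state.shape)  # (batch, sequence_length, hidden_size)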

After loading successfully, use the torch.onnx.export function to do the conversion.

# arguments: model, dummy input (token ids), output path, I/O names, opset version
torch.onnx.export(model, torch.zeros(1, 512, dtype=torch.long), "./model.onnx",
                  input_names=["input"], output_names=["output"], opset_version=12)

The first three parameters are the core ones: the model itself, a dummy tensor in the model's expected input format, and the target path.
One more thing: opset_version=12. Be sure to add this, otherwise an error will be reported later in the ONNX-to-TensorFlow step. I checked the official repository; it is not exactly a bug so much as a compatibility problem between opset versions, so set opset_version to 12.
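
As an aside, the dummy input above fixes the exported graph to shape (1, 512). If you need the ONNX model to accept other batch sizes or sequence lengths, torch.onnx.export also takes a dynamic_axes argument. This is an optional sketch, not required for the basic conversion:

torch.onnx.export(
    model,
    torch.zeros(1, 512, dtype=torch.long),
    "./model.onnx",
    input_names=["input"],
    output_names=["output"],
    # mark the batch and sequence dimensions as dynamic
    dynamic_axes={"input": {0: "batch", 1: "sequence"},
                  "output": {0: "batch", 1: "sequence"}},
    opset_version=12,
)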

After executing the export code above, the ONNX model is generated (the file we care about here is model.onnx; the others can be ignored).
[Image: the generated model.onnx file]
At this point, PyTorch has been successfully converted to ONNX. (Note that the model size cannot exceed 2 GB: the single-file protobuf format that ONNX uses does not support files larger than 2 GB, so the conversion of a larger model will fail unless the weights are stored as external data.)
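
Before moving on, it is worth checking that the exported file is a valid ONNX graph and actually runs. A minimal sketch, assuming the onnx and onnxruntime packages are installed:

import numpy as np
import onnx
import onnxruntime as ort

# structural validity check of the exported graph
onnx.checker.check_model(onnx.load("model.onnx"))

# run one dummy inference to confirm the graph executes
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
result = sess.run(["output"], {"input": np.zeros((1, 512), dtype=np.int64)})
print(result[0].shape)  # expected: (1, 512, hidden_size)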


2. Convert ONNX to TensorFlow

The first step is to load the ONNX model. The conversion uses the onnx-tf package; tf2onnx converts in the opposite direction (TensorFlow to ONNX) and is not needed here.

import onnx
from onnx_tf.backend import prepare

# Load the ONNX model
onnx_model = onnx.load('model.onnx', load_external_data=False)

Then execute the following code

# Convert the ONNX graph to TensorFlow and export it as a SavedModel directory
tf_pb = prepare(onnx_model)
tf_pb.export_graph('model_tensorflow')

Note: if opset_version was not set to 12 earlier, the following error will occur. Again, an upstream problem, hahaha.
[Image: error reported when opset_version is not set to 12]

Finally, if the code above executes successfully, a model_tensorflow folder is generated, which looks like this. And you're done.
[Image: contents of the generated model_tensorflow folder]
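
To double-check the result, you can load the SavedModel back and run a dummy inference. A minimal sketch; as far as I know onnx-tf keeps the ONNX input name ("input") in the serving signature, but print the signature first to confirm the exact keys in your export:

import tensorflow as tf

loaded = tf.saved_model.load("model_tensorflow")
print(list(loaded.signatures.keys()))    # usually ["serving_default"]
infer = loaded.signatures["serving_default"]
print(infer.structured_input_signature)  # shows the expected input names and dtypes

# the keyword "input" is assumed to match the ONNX input name
outputs = infer(input=tf.zeros((1, 512), dtype=tf.int64))
print({name: t.shape for name, t in outputs.items()})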
If your conversion is not successful, copy the error message into a search engine and look it up; everyone may run into different bugs. (I feel like my writing is getting more and more GPT-like, hahaha...)

Origin: blog.csdn.net/xdg15294969271/article/details/132184661