PyTorch 导出 ONNX

#torch.onnx.export 函数中,input_names 需要是由字符串组成的 list
#如果输入是两个 tensor,dummy_input 需要是 tuple,不能是 list
import torch 
import torch.nn.functional as F
class TinyModel(torch.nn.Module):
    """Toy two-input module: projects both inputs with one shared Linear(4, 5) and sums them."""

    def __init__(self):
        super().__init__()
        # One layer reused for both inputs, so x and y share the same weights.
        self.linear1 = torch.nn.Linear(4, 5)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Return ``linear1(x) + linear1(y)``.

        Presumably both inputs have a trailing dimension of 4 (the Linear's
        in_features), as with the (3, 4) dummy inputs used for export.
        """
        projected_x = self.linear1(x)
        projected_y = self.linear1(y)
        return projected_x + projected_y
model = TinyModel()
input1 = torch.rand([3, 4])
input2 = torch.rand([3, 4])

# With more than one forward() argument, the example inputs go to export() as
# a tuple. Its elements are passed positionally, so input1 -> x and input2 -> y.
dummy_input = (input1, input2)
# Graph input/output names are lists of strings, in the same order as the
# tuple above and the forward() outputs. Avoid naming a variable `input`,
# because that shadows the builtin input().
output_names = ['output1']
input_names = ['input1', 'input2']
torch.onnx.export(
    model,
    dummy_input,
    "two_input.onnx",
    opset_version=11,
    verbose=True,
    input_names=input_names,
    output_names=output_names,
)

猜你喜欢

转载自blog.csdn.net/weixin_44594953/article/details/130322247