torch.jit.script and torch.jit.trace

torch.jit.script and torch.jit.trace are PyTorch tools for converting a model into TorchScript, either by compiling its source code or by tracing its execution.

They are part of PyTorch's just-in-time (JIT) compilation module, torch.jit, which is used to improve the execution efficiency of models and to support their deployment.

torch.jit.script

torch.jit.script is the function that converts a model into a script.

It takes a PyTorch model as input and compiles it into a TorchScript module. The converted module can be called like a normal Python function, or saved to disk and later executed in an environment without a Python dependency, such as a C++ program.

The benefit of this transformation is that it can reduce overhead during model execution, because the scripted code runs without going through the Python interpreter.

Example:

import torch


# Define the model
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.fc = torch.nn.Linear(64 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv(x)
        x = torch.nn.functional.relu(x)
        # Flatten into rows of 64 * 8 * 8 features (a 1x3x32x32 input yields 16 rows)
        x = x.view(-1, 64 * 8 * 8)
        x = self.fc(x)
        return x


model = MyModel()

# Convert the model to a TorchScript module
scripted_model = torch.jit.script(model)

# Call it
output = scripted_model(torch.randn(1, 3, 32, 32))
print(output)

# Save the model (the ./model/Test directory must already exist)
torch.jit.save(scripted_model, './model/Test/scripted_model.pth')
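
As a quick check of the claim that a scripted module can be called like an ordinary Python callable, the following minimal sketch compares a scripted module with its eager counterpart and prints the TorchScript source exposed through the code attribute. The TinyNet class and its layer sizes are hypothetical and chosen only to keep the example small.

```python
import torch


class TinyNet(torch.nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.nn.functional.relu(self.fc(x))


model = TinyNet().eval()
scripted = torch.jit.script(model)

x = torch.randn(3, 4)
with torch.no_grad():
    eager_out = model(x)
    scripted_out = scripted(x)

# Both calls run the same computation on the same weights
outputs_match = torch.allclose(eager_out, scripted_out)
print(outputs_match)

# TorchScript source generated for forward()
print(scripted.code)
```

Printing the code attribute is a convenient way to see what the compiler actually captured from forward.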

torch.jit.trace

torch.jit.trace is the function that traces model execution.

It takes a model and an example input, records the operations the model executes on that input, and returns a traced model.

The traced model can be used like a scripted model with the same functionality, but it contains only the operations recorded for the example input. Data-dependent control flow, such as an if statement or loop whose behavior depends on tensor values, is not preserved; torch.jit.script is the tool for models that need it.

Example:

import torch


# Define the model
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.fc = torch.nn.Linear(64 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv(x)
        x = torch.nn.functional.relu(x)
        # Flatten into rows of 64 * 8 * 8 features (a 1x3x32x32 input yields 16 rows)
        x = x.view(-1, 64 * 8 * 8)
        x = self.fc(x)
        return x


model = MyModel()

# Trace the model with an example input to produce a TorchScript module
traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32))

# Call it
output = traced_model(torch.randn(1, 3, 32, 32))
print(output)

# Save the model (the ./model/Test directory must already exist)
torch.jit.save(traced_model, './model/Test/traced_model.pth')

Note

Because torch.jit.trace only records the execution path taken for the given example input, inputs passed to the traced module at inference time should have the same number of dimensions and data type as the input used for tracing.
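
To illustrate what "only traces the execution path for a given input" means in practice, here is a minimal sketch with a module whose forward contains a branch that depends on tensor values. The Shift module and the input tensors are hypothetical; the point is only to compare how the traced and scripted versions behave on an input that takes the other branch.

```python
import warnings

import torch


class Shift(torch.nn.Module):  # hypothetical module with a data-dependent branch
    def forward(self, x):
        if x.sum() > 0:
            return x + 10
        return x - 10


model = Shift()
positive = torch.ones(3)
negative = -torch.ones(3)

with warnings.catch_warnings():
    # Tracing warns that the tensor-to-bool conversion is treated as a constant
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(model, positive)
scripted = torch.jit.script(model)

eager_out = model(negative)        # takes the "x - 10" branch
traced_out = traced(negative)      # replays the recorded "x + 10" branch
scripted_out = scripted(negative)  # evaluates the condition at run time

print(eager_out, traced_out, scripted_out)
```

The traced module silently returns a different result from the eager model here, which is why models with this kind of branching are usually scripted rather than traced.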

torch.jit.save

Module objects converted with torch.jit.script or torch.jit.trace can be used directly for inference, or saved to a file with torch.jit.save so that the model can be loaded later when needed.

torch.jit.load

A saved TorchScript model can be loaded with the torch.jit.load function, which accepts a model file path or a file-like object as its input. The specific steps are as follows:

  • Load the model file:
import torch

model = torch.jit.load("model.pt")

This loads the model file named model.pt.

  • Load the model file and specify the device:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.jit.load("model.pt", map_location=device)

This loads the model file named model.pt and places it on the CUDA device if one is available, or on the CPU otherwise.

  • Load the model file and switch to eval mode:
import torch

model = torch.jit.load("model.pt")
model.eval()

This loads the model file named model.pt and switches it to evaluation mode.

Note:

If the model was saved on a specific device, such as CUDA, make sure that device is available when loading the model. If the device is not available, use the map_location parameter to map the model to a device that is.
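
The loading steps above use file paths. Since torch.jit.load also accepts a file-like object, the same round trip can be sketched entirely in memory. This is a minimal sketch: the io.BytesIO buffer and the small Linear layer are assumptions made for illustration and do not appear in the original examples.

```python
import io

import torch

# Script a small stand-in module (hypothetical, for illustration only)
scripted = torch.jit.script(torch.nn.Linear(4, 2))

# Save into an in-memory buffer instead of a file on disk
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)

# Rewind the buffer before reading it back
buffer.seek(0)

# map_location forces the loaded module onto the CPU
loaded = torch.jit.load(buffer, map_location=torch.device("cpu"))
loaded.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(scripted(x), loaded(x))
print(same)
```

Passing map_location explicitly keeps the load step from depending on whichever device the model was saved from.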

Code

import torch


# Define the model
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.fc = torch.nn.Linear(64 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv(x)
        x = torch.nn.functional.relu(x)
        # Flatten into rows of 64 * 8 * 8 features (a 1x3x32x32 input yields 16 rows)
        x = x.view(-1, 64 * 8 * 8)
        x = self.fc(x)
        return x


model = MyModel()
print(model)

# Convert the model to TorchScript modules, by scripting and by tracing
scripted_model = torch.jit.script(model)
traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32))

# Call them
output_scripted = scripted_model(torch.randn(1, 3, 32, 32))
output_traced = traced_model(torch.randn(1, 3, 32, 32))

# Save the models (the ./model/Test directory must already exist)
torch.jit.save(scripted_model, './model/Test/scripted_model.pth')
torch.jit.save(traced_model, './model/Test/traced_model.pth')

# Load the models
load_scripted_model = torch.jit.load('./model/Test/scripted_model.pth')
print(load_scripted_model)

load_traced_model = torch.jit.load('./model/Test/traced_model.pth')
print(load_traced_model)

Output:

MyModel(
  (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc): Linear(in_features=4096, out_features=10, bias=True)
)
RecursiveScriptModule(
  original_name=MyModel
  (conv): RecursiveScriptModule(original_name=Conv2d)
  (fc): RecursiveScriptModule(original_name=Linear)
)
RecursiveScriptModule(
  original_name=MyModel
  (conv): RecursiveScriptModule(original_name=Conv2d)
  (fc): RecursiveScriptModule(original_name=Linear)
)

Explanation:

  • RecursiveScriptModule represents a recursive TorchScript module, similar to a tree structure.

    • The original name of the module is MyModel, indicating that this is the container for the model.
  • The container holds two submodules, conv and fc, which are recursive script modules of Conv2d and Linear. This means that these two submodules are also TorchScript modules and can be operated on in TorchScript.

  • A PyTorch model can be converted to a TorchScript module, shown here as RecursiveScriptModule, via torch.jit.script or torch.jit.trace. During the conversion process, each submodule is also converted to a corresponding TorchScript module, nested within the parent module.

  • This nested structure represents the hierarchy of deep learning models well.

  • The original name of each module in a RecursiveScriptModule can be accessed through the original_name attribute.

    • For example, the original name of the MyModel module is MyModel, the original name of the conv module is Conv2d, and the original name of the fc module is Linear.
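
The original_name lookups described in the list above can be sketched as follows. The TinyNet class is a hypothetical, smaller stand-in for MyModel, with the same conv and fc submodule names so that the attribute access mirrors the printed structure.

```python
import torch


class TinyNet(torch.nn.Module):  # hypothetical stand-in for MyModel
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.fc = torch.nn.Linear(4 * 8 * 8, 10)

    def forward(self, x):
        x = torch.nn.functional.relu(self.conv(x))
        x = x.flatten(1)  # keep the batch dimension, flatten the rest
        return self.fc(x)


scripted = torch.jit.script(TinyNet())

# Read original_name on the parent module and on each nested submodule
names = {
    "top": scripted.original_name,
    "conv": scripted.conv.original_name,
    "fc": scripted.fc.original_name,
}
print(names)
```

Each nested submodule is itself a script module, so original_name is available at every level of the hierarchy.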

Origin blog.csdn.net/m0_70885101/article/details/131498340