torch.jit.script and torch.jit.trace
torch.jit.script and torch.jit.trace are PyTorch tools for converting a model into TorchScript. torch.jit.script does this by compiling the model's code, and torch.jit.trace does it by recording the model's execution on an example input.
Both belong to PyTorch's just-in-time (JIT) compilation module, torch.jit, which is used to make models run more efficiently and to support model deployment.
torch.jit.script
torch.jit.script converts a model into TorchScript by compiling its Python source code.
It takes a PyTorch module (or function) as input and returns a TorchScript module (or function). The result can be called like the original Python object. It can also be saved to disk and later loaded in an environment without a Python interpreter, for example from C++ through LibTorch.
This conversion can reduce execution overhead because the compiled program runs in the TorchScript runtime instead of the Python interpreter, and the JIT compiler can optimize it.
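torch.jit.script compiles the source of a function or `forward` method instead of recording a single run. As a result, Python control flow survives the conversion. The following is a minimal sketch. The function `scale_by_sign` is hypothetical and does not come from the original article. It branches on the sign of its input's sum, and the scripted version evaluates that condition again on every call.

```python
import torch

@torch.jit.script
def scale_by_sign(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: both branches are compiled,
    # and the condition is re-evaluated on every call.
    if x.sum() > 0:
        return x * 2
    else:
        return -x

pos = scale_by_sign(torch.ones(3))   # positive sum -> doubled
neg = scale_by_sign(-torch.ones(3))  # negative sum -> negated
print(pos, neg)
print(scale_by_sign.code)  # TorchScript source produced by the compiler
```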
Example:
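```python
import os
import torch

# Define the model
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.fc = torch.nn.Linear(64 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv(x)                          # (N, 64, 32, 32)
        x = torch.nn.functional.relu(x)
        x = torch.nn.functional.max_pool2d(x, 4)  # (N, 64, 8, 8), matches the fc input size
        x = x.view(-1, 64 * 8 * 8)                # (N, 4096)
        x = self.fc(x)                            # (N, 10)
        return x

model = MyModel()

# Convert the model to a TorchScript module
scripted_model = torch.jit.script(model)

# Call it like the original module
output = scripted_model(torch.randn(1, 3, 32, 32))
print(output)

# Save the model (create the target directory first)
os.makedirs('./model/Test', exist_ok=True)
torch.jit.save(scripted_model, './model/Test/scripted_model.pth')
```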
torch.jit.trace
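torch.jit.trace converts a model by tracing its execution.
It takes a model and an example input and runs the model on that input. While the model runs, it records the tensor operations that are executed, and it then returns a traced TorchScript module.
For inputs that follow the same execution path, the traced model behaves like the original. Unlike torch.jit.script, however, tracing does not capture Python control flow. `if` statements and loops are evaluated once during tracing, and only the operations that actually ran are recorded. Models with data-dependent control flow should therefore be converted with torch.jit.script.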
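A data-dependent branch shows the difference from torch.jit.script most clearly. The following is a minimal sketch using a hypothetical function `scale_by_sign`. Because it is traced with a positive example input, only the `x * 2` branch is recorded. PyTorch may emit a TracerWarning about converting a tensor to a Python boolean.

```python
import torch

def scale_by_sign(x):
    # Data-dependent branch on the sign of the sum
    if x.sum() > 0:
        return x * 2
    else:
        return -x

# Tracing with a positive input records only the "x * 2" branch
traced = torch.jit.trace(scale_by_sign, torch.ones(3))
# Scripting compiles both branches
scripted = torch.jit.script(scale_by_sign)

neg = -torch.ones(3)
print(traced(neg))    # tensor([-2., -2., -2.]): frozen branch, not the correct result
print(scripted(neg))  # tensor([1., 1., 1.]): condition re-evaluated
```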
Example:
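```python
import os
import torch

# Define the model
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.fc = torch.nn.Linear(64 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv(x)                          # (N, 64, 32, 32)
        x = torch.nn.functional.relu(x)
        x = torch.nn.functional.max_pool2d(x, 4)  # (N, 64, 8, 8), matches the fc input size
        x = x.view(-1, 64 * 8 * 8)                # (N, 4096)
        x = self.fc(x)                            # (N, 10)
        return x

model = MyModel()

# Convert the model to a TorchScript module by tracing it on an example input
traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32))

# Call it like the original module
output = traced_model(torch.randn(1, 3, 32, 32))
print(output)

# Save the model (create the target directory first)
os.makedirs('./model/Test', exist_ok=True)
torch.jit.save(traced_model, './model/Test/traced_model.pth')
```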
Notice
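torch.jit.trace only records the operations executed for the given example input. For that reason, inputs passed to the traced module at inference time should generally have the same shape (dimensions) and data type as the example used for tracing. Control flow and Python-level values that depended on the example input were fixed at trace time, so inputs that differ may cause errors or silently wrong results.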
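The following sketch shows a concrete case of this restriction. The hypothetical function `row_sums` loops over the first dimension in Python. Tracing unrolls that loop for the example input's 2 rows, so the traced version still returns only 2 sums when it is given 4 rows. PyTorch may print a TracerWarning during tracing.

```python
import torch

def row_sums(x):
    # Python loop over the first dimension: unrolled during tracing
    out = []
    for i in range(x.size(0)):
        out.append(x[i].sum())
    return torch.stack(out)

traced = torch.jit.trace(row_sums, torch.ones(2, 3))

x = torch.ones(4, 3)
print(row_sums(x).shape)  # torch.Size([4])
print(traced(x).shape)    # torch.Size([2]): loop length fixed at trace time
```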
torch.jit.save
Module objects converted with torch.jit.script or torch.jit.trace can be used directly for inference. They can also be saved to a file with torch.jit.save and loaded later with torch.jit.load when needed.
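torch.jit.save also accepts a file-like object, which makes it easy to check that a save/load round trip preserves the model. The following is a minimal sketch. For brevity, a small `torch.nn.Linear` stands in for a real model. This is an assumption and not the article's MyModel.

```python
import io
import torch

# Small stand-in model, converted with torch.jit.script
scripted = torch.jit.script(torch.nn.Linear(4, 2))

buffer = io.BytesIO()
torch.jit.save(scripted, buffer)  # serializes code and parameters
buffer.seek(0)                    # rewind before reading it back
restored = torch.jit.load(buffer)

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.equal(scripted(x), restored(x))
print(same)  # True
```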
torch.jit.load
A TorchScript model saved with torch.jit.save can be loaded with the torch.jit.load function, which accepts either a file path or a file-like object. Typical usage is as follows:
- Load the model file:
```python
import torch
model = torch.jit.load("model.pt")
```
This loads the model file named model.pt.
- Load the model file and specify the device:
```python
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.jit.load("model.pt", map_location=device)
```
This loads model.pt and places it on a CUDA device if one is available, otherwise on the CPU.
- Load the model file and switch it to evaluation mode with eval():
```python
import torch
model = torch.jit.load("model.pt")
model.eval()
```
This loads model.pt and switches it to evaluation mode.
Notice:
If the model was saved on a specific device, such as a CUDA GPU, that device must be available when the model is loaded. If it is not available, use the map_location parameter to map the model to a device that is.
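The loading steps and the map_location note above can be combined into one runnable sketch. It saves a small stand-in model to a temporary file. The stand-in is an `nn.Sequential`, assumed for illustration in place of the article's MyModel. The sketch then loads the file with `map_location="cpu"`, so the parameters end up on the CPU regardless of where they were saved. Finally, it switches the model to evaluation mode.

```python
import os
import tempfile
import torch

# Small stand-in model; scripted and traced modules are saved the same way
model = torch.jit.script(
    torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pt")
    torch.jit.save(model, path)
    # map_location remaps every stored tensor to the given device
    loaded = torch.jit.load(path, map_location="cpu")

loaded.eval()  # evaluation mode (affects layers such as dropout and batch norm)
with torch.no_grad():
    out = loaded(torch.randn(1, 4))

print(out.shape)                         # torch.Size([1, 2])
print(next(loaded.parameters()).device)  # cpu
```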
Code
```python
import os
import torch

# Define the model
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.fc = torch.nn.Linear(64 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv(x)                          # (N, 64, 32, 32)
        x = torch.nn.functional.relu(x)
        x = torch.nn.functional.max_pool2d(x, 4)  # (N, 64, 8, 8), matches the fc input size
        x = x.view(-1, 64 * 8 * 8)                # (N, 4096)
        x = self.fc(x)                            # (N, 10)
        return x

model = MyModel()
print(model)

# Convert the model to TorchScript modules, once by scripting and once by tracing
scripted_model = torch.jit.script(model)
traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32))

# Call both converted modules
output_scripted = scripted_model(torch.randn(1, 3, 32, 32))
output_traced = traced_model(torch.randn(1, 3, 32, 32))

# Save the models (create the target directory first)
os.makedirs('./model/Test', exist_ok=True)
torch.jit.save(scripted_model, './model/Test/scripted_model.pth')
torch.jit.save(traced_model, './model/Test/traced_model.pth')

# Load the models
load_scripted_model = torch.jit.load('./model/Test/scripted_model.pth')
print(load_scripted_model)
load_traced_model = torch.jit.load('./model/Test/traced_model.pth')
print(load_traced_model)
```
Output:
```text
MyModel(
  (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc): Linear(in_features=4096, out_features=10, bias=True)
)
RecursiveScriptModule(
  original_name=MyModel
  (conv): RecursiveScriptModule(original_name=Conv2d)
  (fc): RecursiveScriptModule(original_name=Linear)
)
RecursiveScriptModule(
  original_name=MyModel
  (conv): RecursiveScriptModule(original_name=Conv2d)
  (fc): RecursiveScriptModule(original_name=Linear)
)
```
Explanation:
- RecursiveScriptModule represents a recursive TorchScript module, organized like a tree.
- The top-level module's original name is MyModel, which indicates that it is the container for the whole model.
- The container holds two submodules, conv and fc. These are recursive script modules converted from Conv2d and Linear, which means they are themselves TorchScript modules and can run in TorchScript.
- torch.jit.script and torch.jit.trace both convert a PyTorch model into a TorchScript module, and as the output above shows, both load back as a RecursiveScriptModule. During conversion, each submodule is also converted into a corresponding TorchScript module and nested within its parent module. This nested structure mirrors the hierarchy of the deep learning model.
- The original name of each RecursiveScriptModule can be read through its original_name attribute. For example, the original name of the top-level module is MyModel, the original name of conv is Conv2d, and the original name of fc is Linear.
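The original_name attribute can also be read programmatically. The following is a minimal sketch. It uses a smaller, hypothetical version of MyModel with the same conv/fc layout, because `forward` is only compiled here and never called. It then walks the direct submodules of the scripted module.

```python
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.fc = torch.nn.Linear(4 * 8 * 8, 10)  # assumes 8x8 inputs

    def forward(self, x):
        return self.fc(torch.flatten(self.conv(x), 1))

scripted = torch.jit.script(MyModel())
print(scripted.original_name)  # MyModel
for name, child in scripted.named_children():
    print(name, child.original_name)  # conv Conv2d, then fc Linear
```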