torch.jit.script と torch.jit.trace
torch.jit.script
および はtorch.jit.trace
、モデルをスクリプトに変換したり、モデルの実行を追跡したりするための PyTorch のツールです。
これらは PyTorch の Just-in-Time Compilation モジュールの一部であり、モデルの実行効率を向上させ、モデルのデプロイメントをサポートするために使用されます。
トーチ.jit.スクリプト
torch.jit.script
モデルをスクリプトに変換する関数です。
PyTorch モデルを入力として受け取り、それを実行可能なスクリプトに変換します。変換されたスクリプトは、通常の Python 関数と同様に呼び出すことも、ディスクに保存して PyTorch 依存関係のない環境で実行することもできます。
この変換の利点は、Python インタープリターのオーバーヘッドが排除されるため、モデル実行時のオーバーヘッドが軽減されることです。
例:
import torch
# 定义模型
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.fc = torch.nn.Linear(64 * 8 * 8, 10)
def forward(self, x):
x = self.conv(x)
x = torch.nn.functional.relu(x)
x = x.view(-1, 64 * 8 * 8)
x = self.fc(x)
return x
model = MyModel()
# 将模型转换为Torch脚本模块
scripted_model = torch.jit.script(model)
# 调用
output = scripted_model(torch.randn(1, 3, 32, 32))
print(output)
# 保存模型
torch.jit.save(scripted_model, './model/Test/scripted_model.pth')
トーチ.ジット.トレース
torch.jit.trace
モデルの実行を追跡する関数です。
モデルとサンプル入力を受け取り、指定された入力でのモデルの実行を記録し、トレースされたモデルを返します。
トレース モデルは、同じ機能を持つスクリプト モデルとして見ることができますが、元のモデルの動的な性質も保持しており、動的なグラフや制御フローなどのより高度な機能を使用することもできます。
例:
import torch
# 定义模型
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.fc = torch.nn.Linear(64 * 8 * 8, 10)
def forward(self, x):
x = self.conv(x)
x = torch.nn.functional.relu(x)
x = x.view(-1, 64 * 8 * 8)
x = self.fc(x)
return x
model = MyModel()
# 将模型转换为Torch脚本模块
traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32))
# 调用
output = traced_model(torch.randn(1, 3, 32, 32))
print(output)
# 保存模型
torch.jit.save(traced_model, './model/Test/traced_model.pth')
知らせ
このメソッドは指定された入力テンソルの実行パスのみをトレースするため
torch.jit.trace
、変換されたモジュール オブジェクトを推論に使用する場合、入力テンソルはトレースに使用されたものと同じ次元とデータ型である必要があります。
トーチ.ジット.保存
torch.jit.script
またはを使用しtorch.jit.trace
て変換されたモジュール オブジェクトは、推論に直接使用することも、torch.jit.save
必要に応じてモデルをロードするメソッドを使用してファイルに保存することもできます。
トーチ.jit.load
torch.jit.load
PyTorch モデルは、モデル ファイル パスまたはファイル オブジェクトを入力パラメーターとして受け入れることができる関数を使用してロードできます。具体的な手順は次のとおりです。
- モデル ファイルをロードします。
import torch
model = torch.jit.load("model.pt")
これにより、
model.pt
という名前のモデル ファイルがロードされます。
- モデル ファイルをロードし、デバイスを指定します。
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.jit.load("model.pt", map_location=device)
これにより、
model.pt
指定されたモデル ファイルがロードされ、利用可能な CUDA デバイスに配置されます。
- モデル ファイルをロードし、
eval
スキーマを使用します。
import torch
model = torch.jit.load("model.pt")
model.eval()
model.pt
これにより、指定されたモデル ファイルがロードされ、評価モードに変換されます。
知らせ:
モデルが CUDA などの特定のデバイスを使用している場合は、モデルをロードするときにそのデバイスが使用可能であることを確認する必要があります。デバイスが使用できない場合は、map_location
パラメータを使用してモデルを使用可能なデバイスにマッピングする必要があります。
コード
import torch
# 定义模型
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.fc = torch.nn.Linear(64 * 8 * 8, 10)
def forward(self, x):
x = self.conv(x)
x = torch.nn.functional.relu(x)
x = x.view(-1, 64 * 8 * 8)
x = self.fc(x)
return x
model = MyModel()
print(model)
# 将模型转换为Torch脚本模块
scripted_model = torch.jit.script(model)
traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32))
# 调用
output_scripted = scripted_model(torch.randn(1, 3, 32, 32))
output_traced = traced_model(torch.randn(1, 3, 32, 32))
# 保存模型
torch.jit.save(scripted_model, './model/Test/scripted_model.pth')
torch.jit.save(traced_model, './model/Test/traced_model.pth')
# 加载模型
load_scripted_model = torch.jit.load('./model/Test/scripted_model.pth')
print(load_scripted_model)
load_traced_model = torch.jit.load('./model/Test/traced_model.pth')
print(load_traced_model)
MyModel(
(conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc): Linear(in_features=4096, out_features=10, bias=True)
)
RecursiveScriptModule(
original_name=MyModel
(conv): RecursiveScriptModule(original_name=Conv2d)
(fc): RecursiveScriptModule(original_name=Linear)
)
RecursiveScriptModule(
original_name=MyModel
(conv): RecursiveScriptModule(original_name=Conv2d)
(fc): RecursiveScriptModule(original_name=Linear)
)
例証します:
-
RecursiveScriptModule
ツリー構造に似た、再帰的な TorchScript モジュールを表します。- モジュールの元の名前は で
MyModel
、これがモデルのコンテナであることを示しています。
- モジュールの元の名前は で
-
コンテナには 2 つのサブモジュール
conv
と が含まれておりfc
、これらはConv2d
とLinear
の再帰スクリプト モジュールです。これは、これら 2 つのサブモジュールも TorchScript モジュールであり、TorchScript で操作できることを意味します。 -
RecursiveScriptModule
torch.jit.script
PyTorch モデルは、またはを介してtorch.jit.trace
TorchScript モジュールに変換できます。変換プロセス中に、各サブモジュールも、親モジュール内にネストされた、対応する TorchScript モジュールに変換されます。 -
この入れ子構造は、深層学習モデルの階層をよく表現できます。
-
RecursiveScriptModule
のモジュール名と元の名前には、original_name
属性を通じてアクセスできます。- たとえば、
MyModel
モジュールの元の名前はMyModel
、conv
モジュールの元の名前はConv2d
、fc
モジュールの元の名前は ですLinear
。
- たとえば、