torch.jit.script と torch.jit.trace

torch.jit.script と torch.jit.trace

torch.jit.script および はtorch.jit.trace、モデルをスクリプトに変換したり、モデルの実行を追跡したりするための PyTorch のツールです。

これらは PyTorch の Just-in-Time Compilation モジュールの一部であり、モデルの実行効率を向上させ、モデルのデプロイメントをサポートするために使用されます。

トーチ.jit.スクリプト

torch.jit.scriptモデルをスクリプトに変換する関数です。

PyTorch モデルを入力として受け取り、それを実行可能なスクリプトに変換します。変換されたスクリプトは、通常の Python 関数と同様に呼び出すことも、ディスクに保存して PyTorch 依存関係のない環境で実行することもできます。

この変換の利点は、Python インタープリターのオーバーヘッドが排除されるため、モデル実行時のオーバーヘッドが軽減されることです。

例:

import torch


# 定义模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.fc = torch.nn.Linear(64 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv(x)
        x = torch.nn.functional.relu(x)
        x = x.view(-1, 64 * 8 * 8)
        x = self.fc(x)
        return x


model = MyModel()

# 将模型转换为Torch脚本模块
scripted_model = torch.jit.script(model)

# 调用
output = scripted_model(torch.randn(1, 3, 32, 32))
print(output)

# 保存模型
torch.jit.save(scripted_model, './model/Test/scripted_model.pth')

トーチ.ジット.トレース

torch.jit.traceモデルの実行を追跡する関数です。

モデルとサンプル入力を受け取り、指定された入力でのモデルの実行を記録し、トレースされたモデルを返します。

トレース モデルは、同じ機能を持つスクリプト モデルとして見ることができますが、元のモデルの動的な性質も保持しており、動的なグラフや制御フローなどのより高度な機能を使用することもできます。

例:

import torch


# 定义模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.fc = torch.nn.Linear(64 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv(x)
        x = torch.nn.functional.relu(x)
        x = x.view(-1, 64 * 8 * 8)
        x = self.fc(x)
        return x


model = MyModel()

# 将模型转换为Torch脚本模块
traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32))

# 调用
output = traced_model(torch.randn(1, 3, 32, 32))
print(output)

# 保存模型
torch.jit.save(traced_model, './model/Test/traced_model.pth')

知らせ

このメソッドは指定された入力テンソルの実行パスのみをトレースするためtorch.jit.trace、変換されたモジュール オブジェクトを推論に使用する場合、入力テンソルはトレースに使用されたものと同じ次元とデータ型である必要があります。

トーチ.ジット.保存

torch.jit.scriptまたはを使用しtorch.jit.traceて変換されたモジュール オブジェクトは、推論に直接使用することも、torch.jit.save必要に応じてモデルをロードするメソッドを使用してファイルに保存することもできます。

トーチ.jit.load

torch.jit.loadPyTorch モデルは、モデル ファイル パスまたはファイル オブジェクトを入力パラメーターとして受け入れることができる関数を使用してロードできます。具体的な手順は次のとおりです。

  • モデル ファイルをロードします。
import torch

model = torch.jit.load("model.pt")

これにより、model.ptという名前のモデル ファイルがロードされます。

  • モデル ファイルをロードし、デバイスを指定します。
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.jit.load("model.pt", map_location=device)

これにより、model.pt指定されたモデル ファイルがロードされ、利用可能な CUDA デバイスに配置されます。

  • モデル ファイルをロードし、evalスキーマを使用します。
import torch

model = torch.jit.load("model.pt")
model.eval()

model.ptこれにより、指定されたモデル ファイルがロードされ、評価モードに変換されます。

知らせ:

モデルが CUDA などの特定のデバイスを使用している場合は、モデルをロードするときにそのデバイスが使用可能であることを確認する必要があります。デバイスが使用できない場合は、map_locationパラメータを使用してモデルを使用可能なデバイスにマッピングする必要があります。

コード

import torch


# 定义模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.fc = torch.nn.Linear(64 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv(x)
        x = torch.nn.functional.relu(x)
        x = x.view(-1, 64 * 8 * 8)
        x = self.fc(x)
        return x


model = MyModel()
print(model)

# 将模型转换为Torch脚本模块
scripted_model = torch.jit.script(model)
traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32))

# 调用
output_scripted = scripted_model(torch.randn(1, 3, 32, 32))
output_traced = traced_model(torch.randn(1, 3, 32, 32))

# 保存模型
torch.jit.save(scripted_model, './model/Test/scripted_model.pth')
torch.jit.save(traced_model, './model/Test/traced_model.pth')

# 加载模型
load_scripted_model = torch.jit.load('./model/Test/scripted_model.pth')
print(load_scripted_model)

load_traced_model = torch.jit.load('./model/Test/traced_model.pth')
print(load_traced_model)
MyModel(
  (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc): Linear(in_features=4096, out_features=10, bias=True)
)
RecursiveScriptModule(
  original_name=MyModel
  (conv): RecursiveScriptModule(original_name=Conv2d)
  (fc): RecursiveScriptModule(original_name=Linear)
)
RecursiveScriptModule(
  original_name=MyModel
  (conv): RecursiveScriptModule(original_name=Conv2d)
  (fc): RecursiveScriptModule(original_name=Linear)
)

例証します:

  • RecursiveScriptModuleツリー構造に似た、再帰的な TorchScript モジュールを表します。

    • モジュールの元の名前は でMyModel、これがモデルのコンテナであることを示しています。
  • コンテナには 2 つのサブモジュールconvと が含まれておりfc、これらはConv2dLinearの再帰スクリプト モジュールです。これは、これら 2 つのサブモジュールも TorchScript モジュールであり、TorchScript で操作できることを意味します。

  • RecursiveScriptModuletorch.jit.scriptPyTorch モデルは、またはを介し​​てtorch.jit.traceTorchScript モジュールに変換できます。変換プロセス中に、各サブモジュールも、親モジュール内にネストされた、対応する TorchScript モジュールに変換されます。

  • この入れ子構造は、深層学習モデルの階層をよく表現できます。

  • RecursiveScriptModuleのモジュール名と元の名前には、original_name属性を通じてアクセスできます。

    • たとえば、MyModelモジュールの元の名前はMyModelconvモジュールの元の名前はConv2dfcモジュールの元の名前は ですLinear

おすすめ

転載: blog.csdn.net/m0_70885101/article/details/131498340
おすすめ