PyTorch 使用上の注意事項 (ユースケース コード + 詳細な注意事項付き)

カスタム データセット - データセットの簡単な使用

# 对 Dataset 类用例介绍
from torch.utils.data import Dataset
# utils为torch中常用的工具区,从其中的data区,导入与文件读写操作相关的Dataset
from PIL import Image  # 用于图片操作,也可以使用cv2
import os  # python 当中关于系统的一个库,用于获取所有图片的地址


class MyData(Dataset):  # 创建一个class继承Dataset列表

    def __init__(self, root_dir, label_dir):  # 根据类创建特例、实例的时候运行的函数。为整个class提供一些全局变量。
        self.root_dir = root_dir  # 不同函数之间,变量不互通,通过self相当于创造了全局变量
        self.label_dir = label_dir
        self.path = os.path.join(root_dir, label_dir)  # Windows 和 Linux 操作系统对于路径的解析是不一样的,使用join函数避免错误
        self.img_path = os.listdir(self.path)  # dir 即为文件夹,通过列表的方式读取文件夹中的文件。

    def __getitem__(self, idx):
        img_name = self.img_path[idx]  # 图片名称
        img_item_path = os.path.join(self.path, img_name)  # 图片相对路径
        img = Image.open(img_item_path)  # 获取图片文件
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)  # 整个数据集的长度


root_dir = "dataset/train" 
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = MyData(root_dir, ants_label_dir)  # 获取ants数据集
bees_dataset = MyData(root_dir, bees_label_dir)  # 获取bees数据集

# 可以在 Python控制台 直接调用 ants_dataset[0] ,看输出结果,输出即为 __getitem__ return的结果
img, label = ants_dataset[1]
img.show()  # 对ants的第2张图片进行展示

train_dataset = ants_dataset + bees_dataset  # 将两个数据集合在一起

予防:

  • dirはフォルダという意味です
  • Ctrl + Shift + C でファイル パスをすばやくコピーできます
  • Ctrl + Alt + Shift + C でファイル パスをすばやくコピーできます
  • 呼び出すクラスのオブジェクトにパラメータを直接入力した場合、__getitem__関数のreturn結果が出力されます。

TensorBoard の使用

TensorBoard はデータ (関数、画像) を視覚化するためのツールであり、モデルのトレーニングに大きな役割を果たします。一般的に使用されるツールは次のとおりです:writer.add_image関数とwriter.add_scalar関数

まず必要なパッケージをインポートし、SummaryWriterインスタンス

from torch.utils.tensorboard import SummaryWriter
# 从 torch 的 utils 工具箱中导入 tensorboard,然后导入 SummaryWriter这个类
import numpy as np
from PIL import Image

writer = SummaryWriter("logs")  # 创建一个类的实例,将事件文件存储到 logs 文件夹下

Writer.add_image 関数

処理された画像を TensorBoard インスタンスに追加するために使用されます

使用説明書:

  1. 最初のパラメータは title、2 番目のパラメータは img、3 番目のパラメータは step です。
  2. 画像とステップを変更することで、1つのタイトルに複数の画像を記録できます。
  3. 画像形式は numpy.array と torch.Tensor をサポートしていますが、PIL 形式はサポートしていません。
image_path = 'dataset/train/ants_image/0013035.jpg'  # 记录图片地址,此处为相对地址,根据自己的数据集地址做修改即可。

img_PIL = Image.open(image_path)  # 通过PIL的Image导入图片,格式为PIL
# 此处可以使用 print(type(img_PIL)) 来读取其格式

img_array = np.array(img_PIL)  
# 将PIL格式的文件转化为 numpy.array 格式,因为 add_image 函数不支持PIL格式。可以直接使用opencv导入图片,格式为 torch.Tensor 符合 add_image 的要求。
# img_array 的 shape 为 HWC,而阅读 add_image 函数说明,其默认支持的为CHW格式,故需要在调用函数时添加 dataformats="HWC" 参数
writer.add_image("test", img_array, 1, dataformats="HWC")  

Writer.add_scalar 関数は次を使用します。

TensorBoard インスタンスに関数を追加するために使用されます。

使用説明書:

  1. 最初のパラメータはラベル名です。
  2. 2 番目のパラメーターは y 軸に対応します。
  3. 3 番目のパラメーターは x 軸に対応します。
for i in range(100):  # 添加 y = x 
    writer.add_scalar("y = x", i, i)  

for i in range(100):  # 添加 y = 2x 
    writer.add_scalar("y = 2x", 2 * i, i)

最後に、TensorBoard のデフォルトの close オペレーションを呼び出します。

writer.close()  

予防:

  • ヒント - エラー レポートを上手に使用する: SummaryWriter にインポートされたヘッダー ファイルは長すぎるため、覚えるのが不便です。使用する必要がある場合は SummaryWriter 関数を直接呼び出し、エラー プロンプトで导入 torch.utils.tensorboard.SummaryWriterを選択
  • 描画されたイメージを開く:ターミナルで実行して描画されたイメージtensorboard --logdir=logsを開きます。logdir は SummaryWriter によって作成されたクラスのインスタンスのフォルダーを意味します。
  • ポート番号を変更する:を追加する--port=···ことで、tensorboard --logdir=logs --port=6007
  • from PIL import ImagePython に付属する画像処理クラス。画像を開くために使用できます。

トーチビジョン.トランスフォームの使用

変換は、画像に対して何らかの変換を実行するために使用されます。一般的に使用されるツール、、、、、、ToTensorですNormalizeResizeComposeRandomCrop

まず、必要なパッケージをインポートし、SummaryWriterインスタンス

from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

writer = SummaryWriter("logs")

ToTensor形式の変換

画像形式をテンソル形式に変換する

img = Image.open("dataset/train/bees_image/39672681_1302d204d1.jpg")
trans_toTensor = transforms.ToTensor()  # 创建 ToTensor 实例对象
img_tensor = trans_toTensor(img)
writer.add_image("img_tensor", img_tensor)

ノーマライズ

画像を正規化する

使用説明書:

  1. パラメータ 1 は平均mean (sequence)です
  2. パラメータ 2 は標準偏差std (sequence)です。
print(img_tensor[0][0][0])  # 打印标准化前的图片信息
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # 创建 Normalize 实例
img_norm = trans_norm(img_tensor)  # 对图片进行标准化
print(img_norm[0][0][0])  # 打印标准化后的图片信息
writer.add_image("img_norm", img_norm)

サイズ変更 サイズ調整

画像サイズを調整する

使用説明書:

  1. インスタンス作成時、パラメータを同時にWHに入力する場合はシーケンスである必要があり、値が1つだけ入力された場合は比例(短い方=値)でスケーリングされます。
  2. PyTorch の Resize コマンドの新しいバージョンは PIL およびテンソル形式と互換性があり、出力は入力形式と同じです。
print(img.size)  # 打印 resize 前图片的尺寸信息
trans_resize = transforms.Resize((300, 250))  # 创建 Resize实例
img_resize = trans_resize(img_tensor)  # 对图片进行 Resize 处理
print(img_resize.size)  # 打印 resize 后图片的尺寸信息
writer.add_image("Resize", img_resize)  # writer.add_image 仅支持numpy array, torch tensor 格式图片

併用を構成する

複数のtransforms機能を組み合わせる

使用説明書:

  1. Compose() のパラメータはリストである必要があります。Compose([transforms パラメータ 1, transforms パラメータ 2,...])
  2. 前のパラメータの出力形式は、後のパラメータの入力形式と一致する必要があります。
print(img.size)
trans_resize_2 = transforms.Resize(50)
trans_compose = transforms.Compose([trans_resize_2, trans_toTensor])  # 创建 Compose 实例,将 resize 和 toTensor 组合使用

img_resize_2 = trans_compose(img)
print(img_resize_2.size)
writer.add_image("Compose - Resize - 2", img_resize_2)

RandomCrop ランダムクロップ

指定されたサイズパラメータに従って、指定された画像をランダムにトリミングします。

使用説明書:

  1. WH を同時に入力する場合は、次のようなシーケンスである必要があります。transforms.RandomCrop((500, 400))
  2. 値を 1 つだけ入力すると、正方形が切り取られます。
trans_random = transforms.RandomCrop(250)  # 创建 RandomCrop 实例
for i in range(10):
    img_random = trans_random(img_tensor)  # 获取随机裁剪的图片
    writer.add_image("img_random", img_random, i)

最後に、TensorBoard のデフォルトの close オペレーションを呼び出します。

writer.close()
  • 懸念点:入力と出力のタイプ、公式ドキュメント、メソッドに必要なパラメーター
  • この関数は使用されません。Ctrlキーを押したままマウスを関数上に移動すると、対応するドキュメントをすぐに開くことができます。
  • 戻り値が不明: print、print(type())、debug
  • クラスdef __call__()内の関数は、「.」を追加せずにクラスのオブジェクトから直接呼び出すことができます。

データローダーの使用

データセット操作の読み込み用

# 头文件
from torch.utils.data import DataLoader

# 加载数据集(在数据集导入之后)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=False)

パラメータの説明:

  • データセットはデータセットを表します
  • batch_size は、毎回ロードされる画像の数を示します。
  • shuffle は、データセットのロード後にシャッフル操作を実行するかどうかを示します。
  • num_workers はサブプロセスの数を示し、0 はメインプログラムでのみ実行されます (Windows で 0 でない場合、「BrokenPipeError」が表示される可能性があります)。
  • drop_last は、batch_size 未満の最後に残っている画像アイテムをロードするかどうかを示します。

テンソル次元の一般的な理解

最初の値は、対応する数の sub[] が最初の [] に含まれることを意味します。
2 番目の値は、対応する数の [] が最初の [] の sub[] に含まれることを意味します

[4,3,2] のテンソル
まず、テンソルには括弧が必要です。これには 4 つの括弧
torch = torch.tensor([ [ ], [ ], [ ], [ ] ]) が含まれており
、次に 4 つの括弧内にそれぞれ追加します。 3 つの括弧
torch = torch.tensor([[],[],[]],[[],[],[]],[[],[],[]],[[],[] ,[] ])
次に、3 つの括弧内に 2 つの数値があります
torch = torch.tensor([ [ [1,2],[1,3],[1,4] ],
[ [2,3],[ 3,4] 、[2,4]]、
[[1,3]、[2,4]、[3,3]]、
[[2,3]、[2,4]、[4,4] ] ] )
print(torch.shape)
の出力は[4,3,2] です。

TORCH.NNの使用

ニューラルネットワークの基本骨格 - nn.module

公式ドキュメントには次のように記載されています:
CLASS torch.nn.Module[SOURCE]
すべてのニューラル ネットワーク モジュールの基本クラス.
モデルはこのクラスもサブクラス化する必要があります.
すべてのニューラル ネットワークは、親クラスと同等の Module から継承する必要があります。

# 简单神经网络的搭建
import torch
from torch import nn  # nn -> neural network

class MyModule(nn.Module):
    # 可以通过 PyCharm 代码 -> 生成 -> 重写方法 直接添加 __init__ 函数
    def __init__(self):  # 调用父类的初始化函数,必须写
        super().__init__()

    def forward(self, input):  # 对输入进行操作的函数 input -> forward -> output
        output = input + 1
        return output


my_module = MyModule()
x = torch.tensor(1.0)
output = my_module(x)
print(output)

ニューラル ネットワークの畳み込み演算 - ニューラル ネットワークの畳み込み

コンボリューション演算は、コンボリューションカーネル(重み行列)と入力行列の対応する位置をプーリングカーネルと一致する度に乗算し、乗算後の値をすべて累積した結果を出力します。

torch.nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)

パラメータの説明:

  • kernel_size パラメータは、指定されたコンボリューション カーネルのサイズを示します。
  • 特定の値を指定する必要はありません。変換プロセスが自動的に判断します。
  • 実際、ニューラル ネットワークの畳み込み操作は、画像のパラメーターを通じてその畳み込みカーネルを継続的に調整するプロセスでもあります。
import torch
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=64)


# 创建神经网络,在其中对输入的数据进行卷积操作
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()  # 调用父类的初始化函数
        # 生成 torch.nn.Conv2d 类的对象,用于在 forward 中调用
        self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)

    def forward(self, x):
        output = self.conv1(x)
        return output


mymodule = MyModule()
writer = SummaryWriter("conv2d_logs")
step = 0

for data in dataloader:
    imgs, targets = data
    # print(imgs.shape)  # torch.Size([64, 3, 32, 32]) 参数一表示 batch_size 也即批量的大小,参数二表示 channel 也即通道的数目,参数三、四表示HW。
    output = mymodule(imgs)  # 调用神经网络进行卷积化
    # print(output.shape)  # torch.Size([64, 6, 30, 30])

    writer.add_images("input", imgs, step)
    output = torch.reshape(output, [-1, 3, 30, 30])
    # writer.add_images 只能识别3通道数的图像,故对output做reshape处理,-1表示 batch_size 的值随其他参数而自动判定
    # print(output.shape)
    writer.add_images("output", output, step)

    step += 1

writer.close()

ニューラル ネットワークの最大プーリング操作

プーリング操作では、プーリング カーネルと一致するたびに最大値のみを出力します。
入力の特徴を維持しつつ、データ量を減らして学習を高速化することが目的であり、
画像に対して操作を行う場合、解像度を下げる(モザイクと同様)ことに相当します。

MaxPool2d(kernel_size=3, ceil_mode=True)

パラメータの説明:

  • デフォルトでは、stride は kernel_size の値と等しくなります。
  • ceil_modeがFalseの場合、プーリングカーネルのサイズ未満の部分は予約されません。
import torchvision.datasets
from torch import nn
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 准备数据集
dataset = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor(), download=True)
# 加载数据集
downloader = DataLoader(dataset, 64)


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.maxPool = MaxPool2d(kernel_size=3, ceil_mode=True)

    def forward(self, x):
        output = self.maxPool(x)
        return output


maxPool = MyModule()

writer = SummaryWriter('maxPool_logs')
step = 0
for data in downloader:
    imgs, target = data
    writer.add_images("input", imgs, step)
    output = maxPool(imgs)
    writer.add_images("output", output, step)
    step += 1

writer.close()

非線形アクティベーション 非線形アクティベーション

目的は、ネットワークにいくつかの非線形特徴を導入することであり、非線形であればあるほど、さまざまな曲線や特性に適合するモデルをトレーニングし、モデルの汎化能力を向上させることが有益です。

ReLU() 
Sigmoid()

パラメータの説明:

  • パラメーター inplace が true に設定されている場合、計算結果は入力値を直接置き換えます。
import torch
import torchvision.datasets
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

input = torch.tensor([[2, -1.5],
                      [-3.5, 3]])

dataset = torchvision.datasets.CIFAR10("./dataset", train=False,
                                       transform=torchvision.transforms.ToTensor, download=True)
dataloader = DataLoader(dataset, batch_size=64)


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        # torch.nn.ReLU 的使用
        self.relu = ReLU() 
        # torch.nn.Sigmoid 的使用
        self.sigmoid = Sigmoid()

    def forward(self, x):
        output = self.sigmoid(x)
        return output


mymodule = MyModule()

writer = SummaryWriter("./non_linear_act")
step = 0
for data in dataloader:
    imgs, target = data
    writer.add_images("input", imgs, step)
    output = mymodule(imgs)
    writer.add_images("output", output, step)
    step += 1

writer.close()

リニアレイヤー LINEAR

テンソルの直線の長さを変更するために使用されます

CLASS torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("dataset", train=False,
                                       transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=64, drop_last=True)


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linear = nn.Linear(196608, 30)

    def forward(self, input):
        output = self.linear(input)
        return output


mymodule = MyModule()

for data in dataloader:
    imgs, target = data
    print(imgs.shape)
    output = torch.flatten(imgs)  # 对 torch 线性处理,使用 torch.reshape 同样可以实现
    print(output.shape)
    output = mymodule(output)
    print(output.shape)

ポートフォリオ - 順次

ニューラル ネットワークでの関数演算の結合

Sequential()

パラメータの説明:
パラメータに操作する必要のある関数を順番に記述するだけです

import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.tensorboard import SummaryWriter


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


mymodule = MyModule()
print(mymodule)
input = torch.ones((64, 3, 32, 32))
output = mymodule(input)
print(output.shape)

writer = SummaryWriter("seq_logs")
writer.add_graph(mymodule, input)  # 通用 tensorboard 绘制神经网络流程图
writer.close()

損失関数 - 損失関数

機能:
1. 実際の出力と目標の間のギャップを計算します。
2. 出力を更新するための一定の基準を提供します (バックプロパゲーション)

torch.nn.L1Loss(size_average=None, reduce=None, reduction='mean')
# 将所有位置数值差的绝对值求和后取平均

torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')
# 将所有位置数值差的平方求和后取平均

torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', label_smoothing=0.0)

パラメータの説明:

  • L1Loss 関数では、入力テンソル データ型が浮動小数点数である必要があり、そのほとんどは実際の操作ではすでに浮動小数点数になっています。
  • duction='sum'を設定すると、平均値はとられず、合計結果が直接返されます。返されるデフォルト値は平均値です。
import torch
from torch.nn import L1Loss
from torch import nn

inputs = torch.tensor([1, 2, 3], dtype=torch.float32) 
targets = torch.tensor([1, 2, 5], dtype=torch.float32)


loss = L1Loss(reduction='sum') 
result = loss(inputs, targets)


loss_mse = nn.MSELoss() 
result_mse = loss_mse(inputs, targets)

print(result)
print(result_mse)


x = torch.tensor([0.1, 0.2, 0.3])
y = torch.tensor([1])
x = torch.reshape(x, (1, 3))
loss_cross = nn.CrossEntropyLoss()
result_cross = loss_cross(x, y)
print(result_cross)

ニューラル ネットワーク + バックプロパゲーションの損失関数の使用

import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
mymodule = MyModule()
for data in dataloader:
    imgs, targets = data
    outputs = mymodule(imgs)
    result_loss = loss(outputs, targets)
    result_loss.backward()  # 反向传播,用于对神经网络中 grad 的更新,便于后续的优化处理
    print("over")

注:
transform=torchvision.transforms.ToTensor() が括弧を追加するのを忘れた場合、TypeError は次のように報告されます。
TypeError: init () は 1 つの位置引数を受け取りますが、2 つが指定されました

オプティマイザー - オプティマイザー

TORCH.OPTIMはパラメータの最適化に使用されます

import torchvision
from torch import nn, optim
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("dataset", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
mymodule = MyModule()
optimizer = optim.SGD(mymodule.parameters(), lr=0.01)

for epoch in range(20):  # 扫描全部数据集,重复进行 20 次优化
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = mymodule(imgs)
        result_loss = loss(outputs, targets)
        optimizer.zero_grad()  # 将前一次计算的梯度清零
        result_loss.backward()  # 反向传播,对神经网络中的梯度进行更新
        optimizer.step()  # 调用优化器,对每个参数进行调优
        running_loss = running_loss + result_loss  # 对每一个 batch 的 loss 值做累加
    print(running_loss)

モデルの使用法

モデルの呼び出しと変更

  • Pytorch が提供する事前トレーニング済みモデルまたは事前トレーニングされていないモデルを呼び出します。
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16()

print(vgg16)
  • 呼び出されたモデルを変更する
vgg16.classifier.add_module('add_linear', nn.Linear(1000, 10))  # 在 vgg16 的 classifier 中添加一项操作,需要给一个名称
print(vgg16)

vgg16.classifier[6] = nn.Linear(4096, 10)  # 更改 vgg16 的 classifier 中对应项的函数。
print(vgg16)

モデルを保存する

  • 方法は 2 つあり、公式で推奨されているのはメモリ占有量が少ない 2 番目の方法です。
import torch
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16()

# 保存方式1,模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

# 保存方式2,模型参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
  • 方法 1 を使用して独自のモデルを保存します
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3))

    def forward(self, x):
        x = self.conv1(x)
        return x


myModel_save = MyModule()
torch.save(myModel_save, "myModel_save.pth")

モデルのロード

import torchvision

# 方式1:加载 使用方式1保存的模型
model = torch.load("mymodule.pth")
print(model)

# 方式2:加载 使用方式2保存的模型
vgg16 = torchvision.models.vgg16()
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
model = torch.load("vgg16_method2.pth")
print(vgg16)

トラップ

  • モデルが方法 1 を使用して保存された場合、ロード時に、最初に必要なニューラル ネットワーク コードをインポートする必要があります。from myModel_save import *使用して直接インポートすることも、そのクラス コードを直接コピーして貼り付けることもできます

おすすめ

転載: blog.csdn.net/qq_61539914/article/details/126437285