参考:
【テンソルをサクッと理解】torch.randとexamplesでテンソルを解説_Neo超ハードブログ - CSDNブログ
Pytorch 公式ドキュメント学習メモ - 3. Build Model_pytorch build_model_Coding_Qi のブログ - CSDN ブログ
クイックスタート — PyTorch チュートリアル 2.0.1+cu117 ドキュメント (2 メッセージ) pytorch の基本 - モデル パラメーターの最適化 (6)_torch.optim.sgd(model.parameters(), lr=learning_ra_小さなプレーリードッグのブログ - CSDN ブログ
1. テンソル テンソル
- データ構造 (配列行列に似ています) - テンソルを使用してモデルの入力、出力、パラメーターをエンコードします
-最後の列-1
- 連結テンソル、.cat
Dim は寸法を表し、角括弧なしの場合はdim =0; 角括弧 1 個の場合はdim =1;
-matplotlib は Python 用の 2D プロット ライブラリです
2. データセットとデータローダー
図: フィーチャー; ラベル: ラベル;
カスタムデータセット(&C): 3 つの機能
_init_: データセット オブジェクトのインスタンス化、画像、アノテーション ファイル、および 2 つの変換 (transform、target_transform) の初期化時に 1 回実行します。
_len_: データセット内のサンプル数を返します。
_getitem_: 指定されたインデックスから、イメージ ラベルをテンソルに変換し、テンソル イメージと対応するラベルを返します。
3. 変身
変換とターゲット変換
正規化されたテンソルとしての特徴と、 を使用したワンホット エンコードされたテンソルとしてのラベル。ToTensor和
Lambda进行转换
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
&ToTensor は PIL 画像または NumPy を a. そのような状況に変換し、画像のピクセル強度値は [0., 1.] の範囲内になります。
&Lambda 変換: 整数をワンホット エンコードされたテンソルに変換する Lambda 関数です。
通常のステータス コードは 000,001,010,011,100,101、
ワンホット エンコーディングは 000001,000010,000100,001000,010000,100000 です。
target_transform = Lambda(lambda y: torch.zeros(
10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
首先创建一个大小为 10(我们数据集中的标签数量)的零张量,并调用 scatter_ 在标签给出的索引上分配一个
scatter_(input, dim,index, src):index のインデックスに従って、src のデータを dim の方向に入力に埋め込みます。要素の配置または要素の変更として理解できます。
- dim: どの次元に沿ってインデックスを付けるか
- Index: スキャッターに使用される要素インデックス
- src: スキャッターに使用されるソース要素。スカラーまたはテンソルにすることができます。
4. モデルを構築する
- nn.Module をサブクラス化してニューラル ネットワークを定義し、__init__ を使用してニューラル ネットワーク層を初期化し、各 nn.Module サブクラスが forward メソッドで入力データの操作を実装します。
-Flatten 層は、入力を「平坦化」する、つまり多次元入力を 1 次元に変換するために使用され、畳み込み層から全結合層への移行でよく使用されます。(3, 32, 64) は要素数が 3*32*64=6144 個の 3 次元データであり、この 3 次元データを直線に引くと、直線の長さは 6144 になります。
-nn.Sequential
クラスは torch.nn
コンテナ内のシーケンスコンテナの一種であり、コンテナ内にニューラルネットワークの特定の機能に関連する様々なクラスを入れ子にすることでニューラルネットワークモデルの構築が完了します、このクラスの括弧内がニューラルネットワークです私たちが構築したモデルの具体的な構造
-nn.Linear: モデルの線形層を定義し、前述の線形変換を完了するために使用されます。パラメーターは (入力フィーチャの数、出力フィーチャの数、バイアスを使用するかどうか (デフォルトは true))、および重みです。対応する次元のパラメータとバイアスが自動的に生成されます。
-nn.ReLU
このクラスは非線形活性化分類に属しており、定義時にデフォルトでパラメータを渡す必要はありません。
-logits は、通常、次のステップでソフトマックスにスローされるベクトルです 。ソフトマックス正規化指数関数
- torch.rand の機能は、平たく言えば均一に分散されたデータを生成することであり、torch.rand() の括弧内にいくつかの数値を入力すると、複数の次元のテンソルが生成されます。
x = torch.rand(3,4): 2 次元テンソル、3 行 4 列
3 次元テンソルも比較的理解しやすく、2 次元テンソルは平面とみなすことができ、3 次元テンソルは多数の 2 次元テンソル平面が平行に配置されたものとみなすことができます。
たとえば、一般的な RGB 画像は、並べて配置された 3 つの 2 次元のグレースケール画像として理解できます。
5. Autograd 自動導出
勾配: パラメータに関する損失関数の導関数
逆伝播アルゴリズム: パラメーター (モデルの重み) は、指定されたパラメーターに対する損失関数の勾配に従って調整されます。
torch.autograd は、 あらゆる計算グラフの勾配の自動計算をサポートします。
6. 最適化最適化パラメータ
- モデルのトレーニングを繰り返すたびに、モデルは出力を推測し、推測と実際のラベル間の誤差を計算し、パラメーターに関する誤差の導関数を収集し、勾配降下法を使用してこれらのパラメーターを最適化します。
-ハイパーパラメータを調整してモデルの最適化プロセスを制御します。異なるハイパーパラメータ値はモデルのトレーニングと収束速度に影響します
- 損失: 特定のデータ サンプルの入力を使用して予測を行い、それを実際のデータ ラベル値と比較します。
-SGD オプティマイザー
- トレーニング ループ、最適化のための 3 つのステップ
·optimizer.zero_grad()を呼び出して、モデルパラメータの勾配をリセットします。勾配はデフォルトで合計されます。二重カウントを防ぐために、反復ごとに明示的にゼロにします。
· loss.backward() を呼び出して、予測損失を逆伝播します。PyTorch は、損失に関連する各パラメーターの勾配を保存します。
·勾配を取得したら、optimizer.step() を呼び出して、バックプロパゲーション中に収集された勾配を通じてパラメーターを調整します。
トレーニングループ最適化コード:
def train_loop(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
for batch, (X, y) in enumerate(dataloader):
# Compute prediction and loss
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad() #用于清空优化器中的梯度
loss.backward() #计算损失函数对参数的梯度,自动求导
optimizer.step() #根据梯度更新网络参数的值
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
勾配降下法アルゴリズムを使用してモデル パラメーターを更新します。このアルゴリズムでは、モデル パラメーターに対する損失関数の勾配を計算する必要があります。この計算プロセスはバックプロパゲーション アルゴリズムです。
loss.backward() の機能は、損失関数を導出し、損失関数に対する各モデルパラメータの勾配を取得することです。この勾配は、現在の状態での損失関数に対するモデル パラメーターの寄与の大きさと方向、つまりパラメーター更新の方向と大きさを表すことができます。
更新されたパラメータは、次のフォワードパス計算とバックプロパゲーション計算に使用されます。
7. モデルの保存とロード モデルの保存とロード
PyTorch モデルは、学習したパラメーターを state_dictと呼ばれる 内部状態辞書に保存します。これらのパラメータは torch.save メソッドを通じて保存できます。