本稿では、著者:ファン・ユングは、承認後にリリース。
元のリンク:https : //cloud.tencent.com/developer/article/1546403
この記事では、主にpytorchを使用してresnet18モデルをトレーニングし、cifar10を分類してから、cifar10のデータを調整し、トレーニング済みモデルをロードし、調整したデータをFINETUNINGで元のモデルに分類します。pytorch 公式ウェブサイトのチュートリアルを参照してください。
resnet18モデル
pytorchのresnet18モデルリファレンス:https : //github.com/kuangliu/pytorch-cifar
モデルの詳細については、githubのmodels / resnet.pyを参照してください。詳細な説明はここには記載されていません。Readmeの説明の正確度は93.02%に達する可能性がありますが、ローカルテストの200回の反復がこの数値に達せず、ローカル200回の反復の正確度は87.40%でした。 。
必要なパッケージをインポートする
import os
import numpy as np
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import *
from utils import progress_bar
結果を再現可能にするためにランダムシードを設定します
ここで長い間試してみました。CPUを実行するにはtorch.manual_seed(SEED)を設定するだけで、安定して結果を再現できます。ただし、GPUではまだ機能しません。後でランダム性の問題が常にあります。後で、友人の助けを借りて、公式を確認しましたデータは最終的には、感謝の問題を解決しました。その中でも、テンソルフローはGPUで結果を出すことができず、着実に再現できるようです。ご存知の方がいらっしゃいましたら、指導をお願いします〜
SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)
CPUまたはGPUで実行するかどうかを設定します
デバイスを実行するために選択できるGPUがあるかどうかに応じて、ドライバーのインストール、バージョンの互換性に注意してください。ドライバーは長い間私を苦しめてきました。。dockerで実行しているため、ダウンロードしたドライバーのバージョンに一貫性がないため、GPUが検出されていません
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0
start_epoch = 0
データの読み込みと前処理
データは、pyファイルと同じディレクトリの下のデータフォルダーに保存されます。データが存在しない場合、ダウンロードはTrueに設定され、pytorchから自動的にダウンロードされます。ここでは、データの変換方法が異なり、データの多様性が向上します。
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
データセットを調整する
元のcifarデータセットには10のカテゴリが含まれています
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
FINETUNINGを練習する必要があるため、データセットが10カテゴリから2カテゴリ(動物と車両)に変更されました。馬は交通機関ですか?^。^
clz_idx = trainset.class_to_idx
clz_to_idx = {'animal': 0, 'transport': 1}
clz = ['animal', 'transport']
animal_name = ["bird", "cat", "deer", "dog", "frog", "horse"]
animal = [clz_idx[x] for x in animal_name]
trainset.targets = [0 if x in animal else 1 for x in trainset.targets]
trainset.class_to_idx = clz_to_idx
trainset.classes = clz
testset.targets = [0 if x in animal else 1 for x in testset.targets]
testset.class_to_idx = clz_to_idx
testset.classes = clz
事前トレーニング済みモデルを読み込む
モデルはチェックポイントディレクトリに保存されます。モデルのトレーニングは上記のResnet18です。GPUトレーニングの場合は、if内のコードの順序に特に注意してください。
- 並列トレーニングのためにネットをDataParallelに置き換えます。元のResnet18はGPUでのトレーニングにDataParallelを使用したため、ここでもカプセル化する必要があり、モジュールのレイヤーがパッケージ化されます。
- FINETUNING:最後のレイヤーの10種類の出力を2種類の出力に変更します。gpu、net.module.linearの表現に注意してください
net = net.to(device)
モデルを変更した後、モデルをGPUにプッシュする必要があります。この手順を進めることはできず、パラメーターがGPUにないというエラーが発生します。
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
checkpoint = torch.load('./checkpoint/ckpt.pth')
net = ResNet18()
if device == 'cuda':
net = torch.nn.DataParallel(net)
net.load_state_dict(checkpoint['net'])
net.module.linear = nn.Linear(net.module.linear.in_features, 2)
else:
net.load_state_dict(checkpoint['net'])
net.linear = nn.Linear(net.linear.in_features, 2)
net = net.to(device)
調整する必要のないレイヤーの数を指定します
最初の40層のパラメータは固定されており、学ぶ必要はありません
for idx, (name, param) in enumerate(net.named_parameters()):
if idx > 40: # count of layers is 62
param.requires_grad = False
if param.requires_grad == True:
print("\t", idx, name)
損失関数と最適化アルゴリズム
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
トレーニング機能とテスト機能
唯一の精度向上に基づいて進める一方、main.pyで参考Resnet18は、フォローアップ研修を継続する訓練の結果の保存、テスト時には、ケースフォルダは、保存された保存を。
def train(epoch):
print('\nEpoch: %d' % epoch)
net.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
# print('%d/%d, [Loss: %.03f | Acc: %.3f%% (%d/%d)]'
# % (batch_idx+1, len(trainloader), train_loss/(batch_idx+1), 100.*correct/total, correct, total))
progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
best_acc = 0
def test(epoch):
global best_acc
net.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(testloader):
inputs, targets = inputs.to(device), targets.to(device)
outputs = net(inputs)
loss = criterion(outputs, targets)
test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
# Save checkpoint.
acc = 100. * correct / total
if acc > best_acc:
print('Saving..')
state = {
'net': net.state_dict(),
'acc': acc,
'epoch': epoch,
}
if not os.path.isdir('checkpoint_ft'):
os.mkdir('checkpoint_ft')
torch.save(state, './checkpoint_ft/ckpt.pth')
best_acc = acc
トレーニングを開始
トレーニングはすでにトレーニングされたモデルに基づいているため、ここでの反復回数はあまり多くなくても高精度に到達できます
for epoch in range(start_epoch, start_epoch + 20):
train(epoch)
test(epoch)
結果ショー
Epoch: 0
[================================================================>] Step: 53ms | Tot: 3 391/391 Loss: 0.520 | Acc: 88.662% (44331/50000)
[================================================================>] Step: 21ms | Tot 100/100 | Loss: 0.449 | Acc: 95.090% (9509/10000)
Saving..
Epoch: 1
[================================================================>] Step: 53ms | Tot: 3 391/391 Loss: 0.430 | Acc: 95.342% (47671/50000)
[================================================================>] Step: 20ms | Tot 100/100 | Loss: 0.411 | Acc: 95.590% (9559/10000)
Saving..
Epoch: 2
[================================================================>] Step: 53ms | Tot: 3 391/391 Loss: 0.394 | Acc: 95.816% (47908/50000)
[================================================================>] Step: 20ms | Tot 100/100 | Loss: 0.373 | Acc: 96.110% (9611/10000)
Saving..
Epoch: 3
[================================================================>] Step: 54ms | Tot: 3 391/391 Loss: 0.376 | Acc: 96.002% (48001/50000)
[================================================================>] Step: 20ms | Tot 100/100 | Loss: 0.386 | Acc: 94.560% (9456/10000)
Epoch: 4
[================================================================>] Step: 54ms | Tot: 3 391/391 Loss: 0.368 | Acc: 96.160% (48080/50000)
[================================================================>] Step: 20ms | Tot 100/100 | Loss: 0.365 | Acc: 96.350% (9635/10000)
Saving..
Epoch: 5
[================================================================>] Step: 53ms | Tot: 3 391/391 Loss: 0.362 | Acc: 96.214% (48107/50000)
[================================================================>] Step: 20ms | Tot 100/100 | Loss: 0.381 | Acc: 93.430% (9343/10000)
Epoch: 6
[================================================================>] Step: 54ms | Tot: 3 391/391 Loss: 0.360 | Acc: 96.070% (48035/50000)
[================================================================>] Step: 20ms | Tot 100/100 | Loss: 0.362 | Acc: 95.400% (9540/10000)
Epoch: 7
[================================================================>] Step: 53ms | Tot: 3 391/391 Loss: 0.358 | Acc: 96.062% (48031/50000)
[================================================================>] Step: 21ms | Tot 100/100 | Loss: 0.400 | Acc: 90.730% (9073/10000)
Epoch: 8
[================================================================>] Step: 54ms | Tot: 3 391/391 Loss: 0.356 | Acc: 96.214% (48107/50000)
[================================================================>] Step: 20ms | Tot 100/100 | Loss: 0.362 | Acc: 96.280% (9628/10000)
Epoch: 9
[================================================================>] Step: 53ms | Tot: 3 391/391 Loss: 0.353 | Acc: 96.242% (48121/50000)
[================================================================>] Step: 20ms | Tot 100/100 | Loss: 0.376 | Acc: 94.590% (9459/10000)
Epoch: 10
[================================================================>] Step: 53ms | Tot: 3 391/391 Loss: 0.352 | Acc: 96.348% (48174/50000)
[================================================================>] Step: 21ms | Tot 100/100 | Loss: 0.384 | Acc: 93.080% (9308/10000)
Epoch: 11
[================================================================>] Step: 54ms | Tot: 3 391/391 Loss: 0.351 | Acc: 96.236% (48118/50000)
[================================================================>] Step: 20ms | Tot 100/100 | Loss: 0.356 | Acc: 95.480% (9548/10000)
Epoch: 12
[================================================================>] Step: 53ms | Tot: 3 391/391 Loss: 0.350 | Acc: 96.348% (48174/50000)
[================================================================>] Step: 20ms | Tot 100/100 | Loss: 0.383 | Acc: 93.170% (9317/10000)
Epoch: 13
[================================================================>] Step: 53ms | Tot: 3 391/391 Loss: 0.348 | Acc: 96.358% (48179/50000)
[================================================================>] Step: 20ms | Tot 100/100 | Loss: 0.373 | Acc: 93.330% (9333/10000)
Epoch: 14
[================================================================>] Step: 53ms | Tot: 3 391/391 Loss: 0.347 | Acc: 96.446% (48223/50000)
[================================================================>] Step: 20ms | Tot 100/100 | Loss: 0.391 | Acc: 91.670% (9167/10000)
Epoch: 15
[================================================================>] Step: 54ms | Tot: 3 391/391 Loss: 0.346 | Acc: 96.324% (48162/50000)
[================================================================>] Step: 21ms | Tot 100/100 | Loss: 0.347 | Acc: 95.880% (9588/10000)
Epoch: 16
[================================================================>] Step: 54ms | Tot: 3 391/391 Loss: 0.344 | Acc: 96.488% (48244/50000)
[================================================================>] Step: 20ms | Tot 100/100 | Loss: 0.343 | Acc: 95.980% (9598/10000)
Epoch: 17
[================================================================>] Step: 53ms | Tot: 3 391/391 Loss: 0.344 | Acc: 96.416% (48208/50000)
[================================================================>] Step: 21ms | Tot 100/100 | Loss: 0.344 | Acc: 95.890% (9589/10000)
Epoch: 18
[================================================================>] Step: 54ms | Tot: 3 391/391 Loss: 0.344 | Acc: 96.370% (48185/50000)
[================================================================>] Step: 20ms | Tot 100/100 | Loss: 0.354 | Acc: 95.060% (9506/10000)
Epoch: 19
[================================================================>] Step: 53ms | Tot: 3 391/391 Loss: 0.344 | Acc: 96.338% (48169/50000)
[================================================================>] Step: 20ms | Tot 100/100 | Loss: 0.399 | Acc: 89.760% (8976/10000)
87.4%の精度で既存のResnet18モデルで2つの分類をFINETUNINGすることにより、最初の反復の精度は95.09%に達することができ、収束速度は非常に速く、分類効果も優れています。
最終的な20回の反復テストセットの最大値は96.11%です。
最後
pytorchモデルは比較的単純で、コードは非常に明確に見え、ドキュメントのサポートは包括的です。