PyTorchは、ResNet18転送学習に基づくポケモンデータセット分類を実装します

1.実装プロセス

1.データセットの説明

データセットは、次のように5つのカテゴリに分類されます。

  • ピカチュウ:234
  • ミュウツー:239
  • ジェニータートル:223
  • リトルファイアドラゴン:238
  • カエルの種:234

セルフフェッチリンク:https
://pan.baidu.com/s/1bsppVXDRsweVKAxSoLy4sw抽出コード:9fqo
画像ファイルの拡張子にはjpg、jepg、png、gifの4種類があり、画像のサイズが同じではないため、画像のサイズ変更などの操作を行うために必要です。本稿では、画像サイズを224×224サイズに変更します。

2.データの前処理

このホワイトペーパーでは、データセットフレームワークを使用してデータセットを前処理し、画像データセットを{images、labels}などのマッピング関係に変換します。

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()

        self.root = root
        self.resize = resize

        self.name2label = {
    
    }    # "sq...": 0
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root,name)):
                continue
            self.name2label[name] = len(self.name2label.keys())
        # print(self.name2label)

        # image,label
        self.images, self.labels = self.load_csv('images.csv')

        # 数据集裁剪:训练集、验证集、测试集
        if mode == 'train': # 60%
            self.images = self.images[0:int(0.6*len(self.images))]
            self.labels = self.labels[0:int(0.6*len(self.labels))]
        elif mode == 'val': # 20% = 60% -> 80%
            self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
            self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
        else:               # 20% = 80% -> 100%
            self.images = self.images[int(0.8*len(self.images)):]
            self.labels = self.labels[int(0.8*len(self.labels)):]

その中で、rootはデータセットが保存されているファイルルートディレクトリを表し、 resizeはデータセット出力の均一なサイズを表し、 modeはデータセットを読み取るときのモード(train、val、test)を表します。name2labelは画像カテゴリの名前とラベル。画像カテゴリのラベルを取得すると便利です。load_csvメソッドは{images、labels}のマッピング関係を作成することです。ここで、imagesは画像が配置されているファイルパスを表し、コードは次のとおりです。次のように:

    def load_csv(self, filename):
        if not os.path.exists(os.path.join(self.root, filename)):
            # 文件不存在,则需要创建该文件
            images = []
            for name in self.name2label.keys():
                # pokemon\\mewtwo\\00001.png
                images += glob.glob(os.path.join(self.root,name,'*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
                images += glob.glob(os.path.join(self.root, name, '*.gif'))
            # 1168, 'pokemon\\bulbasaur\\00000000.png'
            print(len(images),images)
            # 保存成image,label的csv文件
            random.shuffle(images)
            with open(os.path.join(self.root, filename),mode='w',newline='') as f:
                writer = csv.writer(f)
                for img in images:  # 'pokemon\\bulbasaur\\00000000.png'
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    # 'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img, label])
                # print('writen into csv file:',filename)
        # 加载已保存的csv文件
        images, labels = [],[]
        with open(os.path.join(self.root,filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)
                images.append(img)
                labels.append(label)
        assert len(images) == len(labels)
        return images, labels

データセットのサイズとインデックス要素の位置を取得するコードは次のとおりです。

    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        # idx:[0, len(self.images)]
        # self.images, self.labels
        # img:'G:/datasets/pokemon\\charmander\\00000182.png'
        # label: 0,1,2,3,4
        img, label = self.images[idx], self.labels[idx]
        transform = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # string path => image data
            transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),
            transforms.RandomRotation(15),      # 随机旋转
            transforms.CenterCrop(self.resize), # 中心裁剪
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.485,0.456,0.406],
            #                      std=[0.229,0.224,0.225])
            transforms.Normalize(mean=[0.6096, 0.7286, 0.5103],
                                 std=[1.5543, 1.4887, 1.5958])
        ])

        img = transform(img)
        label = torch.tensor(label)
        return img, label

その中で、変換での平均とstdの計算を参照してください。正規化するか、経験値mean = [0.485、0.456、0.406]とstd = [0.229、0.224、0.225]を直接使用します。
Visdom視覚化ツールによって表示されるbatch_size=32の画像を次の図に示します。
ここに画像の説明を挿入

2.設計モデル

このペーパーでは、移行学習の概念を採用し、resnet18分類子を直接使用し、ネットワーク構造の最初の17層を保持し、それに応じて最後の層を変更します。コードは次のとおりです。

trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1],     # [b,512,1,1]
                      Flatten(),   # [b,512,1,1] => [b,512]
                      nn.Linear(512, 5)
                      ).to(device)

その中で、Flatten()はデータフラット化メソッドであり、コードは次のとおりです。

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)

3.損失関数とオプティマイザを構築します

損失関数はクロスエントロピーを使用し、オプティマイザーはAdamを使用し、学習率は0.001に設定されています。コードは次のとおりです。

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

4.トレーニング、検証、およびテスト

	best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):
        for step, (x,y) in enumerate(train_loader):
            # x: [b,3,224,224]  y: [b]
            x, y = x.to(device), y.to(device)
            output = model(x)
            loss = criterion(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1
        # 验证集
        if epoch % 1 == 0:
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(model.state_dict(), 'best.mdl')
                viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch+1)
    # 加载最好的模型
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')
    test_acc = evaluate(model, test_loader)
    print('test acc:', test_acc)
def evaluate(model, loader):
    correct = 0
    total = len(loader.dataset)
    for (x, y) in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            output = model(x)
            pred = output.argmax(dim=1)
            correct += torch.eq(pred, y).sum().item()
    return correct/total

5.テスト結果

トレーニングセットの損失値の変化曲線とテストセットの精度の変化曲線を次の図に示します。
ここに画像の説明を挿入コンソール出力は次のとおりです。

best acc: 0.9358974358974359 best epoch: 3
loaded from ckpt!

test acc: 0.9401709401709402

これは、epoch = 3の場合、検証セットの精度が最高に達し、この時点のモデルが最良のモデルと見なされ、テストセットのテストに使用され、94.02の精度に達することを示しています。 %。

2.参考文献

[1] https://www.bilibili.com/video/BV1f34y1k7fi?p=106
[2] https://blog.csdn.net/Weary_PJ/article/details/122765199

おすすめ

転載: blog.csdn.net/weixin_43821559/article/details/123561478