PytorchのMNIST手書きデータセットに基づくRNNおよびCNNの実装

PytorchのMNIST手書きデータセットに基づくRNNおよびCNNの実装

LSTMは手書きのデータセットを処理します(分類問題)
LSTMはsinを介してcosを予測します(回帰問題)

pytorch入力パラメータ形式

input_size-入力の特徴の次元、つまり、単語の埋め込みにおけるワンホットの長さ、単語の長さは300、次にinput_sizeは300、入力画像の幅は28、次にinput_sizeは28です。

hidden_​​size-非表示状態の機能ディメンション。自分で設定できます

num_layers-レイヤーの数(シーケンス展開と区別される)、通常は1または2

バイアス– Falseの場合、LSTMはbih、bhh b_ {ih}、b_ {hh}を使用しませんbおよびhb時間時間、デフォルトはTrueです。

batch_first – Trueの場合、入力および出力Tensorの形状は(batch、time_step、input_size)、それ以外の場合(time_step batch、input_size)です。

ドロップアウト-ゼロ以外の場合、最後のレイヤーを除いて、ドロップアウトがRNNの出力に追加されます。

双方向-Trueの場合、双方向RNNになり、デフォルトはFalseです。


time_step:長さは文の長さ、文の単語数、
batch_size:バッチで送信されるrnnの数
次のコードはlstmからのものです

lstm输入是input, (h_0, c_0)
input (time_step, batch, input_size) 如果设置了batch_first,则batch为第一维。

(h_0, c_0) 隐层状态
h0 shape:(num_layers * num_directions, batch, hidden_size) 
c0 shape:(num_layers * num_directions, batch, hidden_size)

lstm输出是output, (h_n, c_n)
output (time_step, batch, hidden_size * num_directions) 包含每一个时刻的输出特征,如果设置了batch_first,则batch为第一维
(h_n, c_n) 隐层状态 
h_n shape: (num_layers * num_directions, batch, hidden_size)
c_n shape: (num_layers * num_directions, batch, hidden_size)

一部のドキュメントでは、time_stepとseq_lenの両方がタイムステップを表します

class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.LSTM(         
            input_size=INPUT_SIZE,
            hidden_size=64,        
            num_layers=1,          
            batch_first=True,     #(batch, time_step, input_size)
        )

        self.out = nn.Linear(64, 10)  #Linear(num_layer*hidden_size,分类的个数)

    def forward(self, x):
        # x shape (batch, time_step, input_size)
        # r_out shape (batch, time_step, output_size)
        # h_n shape (n_layers, batch, hidden_size)
        # h_c shape (n_layers, batch, hidden_size)
        r_out, (h_n, h_c) = self.rnn(x, None)   # 初始状态为None
        out = self.out(r_out[:, -1, :])
        return out

その中で、out = self.out(r_out [:、-1、:])xtとht-1を入力するとr_outと(h_n、h_c)が生成されるため、この文はlineraレイヤーの入力として最後の瞬間の出力としてtime_stepを取ります。 )、生成されたhnは引き続きrnnの入力に送信され、次のr_outを生成し続けるため、最後の瞬間の出力を入力として使用する必要があります。

完全なコードがコンパイルされた後、
CNN実装の手書きデータセットコードは次のようにリリースされます

import torch
import torch.nn
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets,transforms
import torch.optim as optim
import torch.nn.functional as F



BATCH_SIZE=512
EPOCHS=20
DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu")


transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_data=datasets.MNIST(root='./mnist',download=True,train=True,transform=transform)
test_data=datasets.MNIST(root='./mnist',download=True,train=False,transform=transform)

train_loader=DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)
test_loader=DataLoader(test_data,batch_size=BATCH_SIZE,shuffle=True)







class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)

        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=10, out_channels=20, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.out = nn.Linear(20 * 5 * 5, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        output = self.out(x)
        return output


model = Net()  # 实例化网络net,再送入gpu训练
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()


def train(model, device, train_loader, optimizer, epoch, criterion):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):

        output = model(data)

        # loss=criterion(output,target)

        optimizer.zero_grad()
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if (batch_idx + 1) % 30 == 0:  # train_loader的长度为train_loader.dataset的长度除以batch_size
            print('Train Epoch:{} [{}/{} ({:.0f}%)]\tLoss:{:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader),
                loss.item()
            ))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    test_corr = 0
    with torch.no_grad():
        for img, label in test_loader:
            output = model(img)
            test_loss += criterion(output, label)
            pred = output.max(1, keepdim=True)[1]
            test_corr += pred.eq(label.view_as(pred)).sum().item()

        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, test_corr, len(test_loader.dataset), 100. * (test_corr / len(test_loader.dataset))
        ))


for epoch in range(1, EPOCHS + 1):
    train(model, DEVICE, train_loader, optimizer, epoch, criterion)
    test(model, DEVICE, test_loader)

トレーニング結果は次のとおり
です。トレインエポック:1 [14848/60000(25%)]損失:0.809119
トレインエポック:1 [30208/60000(50%)]損失:0.332066
トレインエポック:1 [45568/60000(75%)]損失: 0.248601

テストセット:平均損失:3.3879、精度:9515/10000(95%)

列車エポック:2 [14848/60000(25%)]損失:0.200926
列車エポック:2 [30208/60000(50%)]損失:0.167642
列車エポック:2 [45568/60000(75%)]損失:0.129635

テストセット:平均損失:1.9960、精度:9700/1000(97%)

列車エポック:3 [14848/60000(25%)]損失:0.097073
列車エポック:3 [30208/60000(50%)]損失:0.078300
列車エポック:3 [45568/60000(75%)]損失:0.095262

テストセット:平均損失:1.5412、精度:9764/10000(98%)

列車エポック:4 [14848/60000(25%)]損失:0.067570
列車エポック:4 [30208/60000(50%)]損失:0.091387
列車エポック:4 [45568/60000(75%)]損失:0.058170

テストセット:平均損失:1.3722、精度:9795/10000(98%)

列車エポック:5 [14848/60000(25%)]損失:0.081385
列車エポック:5 [30208/60000(50%)]損失:0.069488
列車エポック:5 [45568/60000(75%)]損失:0.108909

テストセット:平均損失:1.1676、精度:9818/10000(98%)

列車エポック:6 [14848/60000(25%)]損失:0.060494
列車エポック:6 [30208/60000(50%)]損失:0.070833
列車エポック:6 [45568/60000(75%)]損失:0.085588

テストセット:平均損失:1.0887、精度:9833/10000(98%)

列車エポック:7 [14848/60000(25%)]損失:0.067081
列車エポック:7 [30208/60000(50%)]損失:0.082414
列車エポック:7 [45568/60000(75%)]損失:0.045014

テストセット:平均損失:1.0601、精度:9837/10000(98%)

列車エポック:8 [14848/60000(25%)]損失:0.062390
列車エポック:8 [30208/60000(50%)]損失:0.048241
列車エポック:8 [45568/60000(75%)]損失:0.042879

テストセット:平均損失:0.9528、精度:9836/10000(98%)

列車エポック:9 [14848/60000(25%)]損失:0.048539
列車エポック:9 [30208/60000(50%)]損失:0.055073
列車エポック:9 [45568/60000(75%)]損失:0.055796

テストセット:平均損失:0.8623、精度:9866/10000(99%)

列車エポック:10 [14848/60000(25%)]損失:0.051431
列車エポック:10 [30208/60000(50%)]損失:0.045435
列車エポック:10 [45568/60000(75%)]損失:0.075674

テストセット:平均損失:0.7783、精度:9874/10000(99%)

列車エポック:11 [14848/60000(25%)]損失:0.028392
列車エポック:11 [30208/60000(50%)]損失:0.049267
列車エポック:11 [45568/60000(75%)]損失:0.042472

テストセット:平均損失:0.8189、精度:9875/10000(99%)

列車エポック:12 [14848/60000(25%)]損失:0.058731
列車エポック:12 [30208/60000(50%)]損失:0.025470
列車エポック:12 [45568/60000(75%)]損失:0.029647

テストセット:平均損失:0.7829、精度:9871/10000(99%)

列車エポック:13 [14848/60000(25%)]損失:0.052567
列車エポック:13 [30208/60000(50%)]損失:0.028609
列車エポック:13 [45568/60000(75%)]損失:0.020649

テストセット:平均損失:0.7527、精度:9872/10000(99%)

列車エポック:14 [14848/60000(25%)]損失:0.039200
列車エポック:14 [30208/60000(50%)]損失:0.019106
列車エポック:14 [45568/60000(75%)]損失:0.067107

テストセット:平均損失:0.7386、精度:9886/10000(99%)

列車エポック:15 [14848/60000(25%)]損失:0.038181
列車エポック:15 [30208/60000(50%)]損失:0.022419
列車エポック:15 [45568/60000(75%)]損失:0.016036

テストセット:平均損失:0.7954、精度:9862/10000(99%)

列車エポック:16 [14848/60000(25%)]損失:0.018675
列車エポック:16 [30208/60000(50%)]損失:0.039494
列車エポック:16 [45568/60000(75%)]損失:0.017992

テストセット:平均損失:0.8029、精度:9859/10000(99%)

列車エポック:17 [14848/60000(25%)]損失:0.019442
列車エポック:17 [30208/60000(50%)]損失:0.014947
列車エポック:17 [45568/60000(75%)]損失:0.024432

テストセット:平均損失:0.6863、精度:9874/10000(99%)

列車エポック:18 [14848/60000(25%)]損失:0.013267
列車エポック:18 [30208/60000(50%)]損失:0.022075
列車エポック:18 [45568/60000(75%)]損失:0.024906

テストセット:平均損失:0.6707、精度:9887/10000(99%)

列車エポック:19 [14848/60000(25%)]損失:0.031900
列車エポック:19 [30208/60000(50%)]損失:0.014791
列車エポック:19 [45568/60000(75%)]損失:0.037303

テストセット:平均損失:0.7329、精度:9878/10000(99%)

列車エポック:20 [14848/60000(25%)]損失:0.030795
列車エポック:20 [30208/60000(50%)]損失:0.016112
列車エポック:20 [45568/60000(75%)]損失:0.020148

テストセット:平均損失:0.6894、精度:9884/10000(99%)

おすすめ

転載: blog.csdn.net/qq_41333847/article/details/109296140
おすすめ