【写给自己】成功搭建了ResNet识别MNIST

但是在RML2018上效果还是很差

网络结构

残差块

class ResidualStack(nn.Module):
    """Residual stack: a 1x1 convolution, two residual units and a max-pooling.

    Unit 1 computes ``relu(conv2(relu(conv1(x))) + shortcut)`` and Unit 2
    computes ``relu(conv4(relu(conv3(y))) + y)``; the result is max-pooled.

    Args:
        input_channels: Number of channels of the input feature map.
        output_channels: Number of channels produced by every convolution.
        kernel_size: Kernel size of ``conv2``/``conv3``/``conv4``. They use
            ``padding='same'`` with stride 1, so the spatial size is kept.
        seq: Label of this stack. Stored on the instance; not used in
            ``forward``.
        pool_size: Kernel size and stride of the final ``nn.MaxPool2d``.
    """

    def __init__(self, input_channels, output_channels, kernel_size, seq, pool_size):
        super().__init__()
        # 1x1 convolution that maps the input to `output_channels` channels.
        self.conv1 = nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=1, padding=0)
        # Residual Unit 1
        self.conv2 = nn.Conv2d(output_channels, output_channels, kernel_size=kernel_size, stride=1, padding='same')
        self.conv3 = nn.Conv2d(output_channels, output_channels, kernel_size=kernel_size, stride=1, padding='same')
        # Residual Unit 2
        self.conv4 = nn.Conv2d(output_channels, output_channels, kernel_size=kernel_size, stride=1, padding='same')
        self.pool = nn.MaxPool2d(kernel_size=pool_size, stride=pool_size)
        self.seq = seq

        # Unit 1 adds the raw input to a tensor with `output_channels` channels.
        # That addition is only defined when the input has the same number of
        # channels or a single channel (broadcast). For those cases the
        # shortcut stays an identity, so parameters and outputs are unchanged.
        # Every other combination used to raise a shape error; it now gets a
        # 1x1 projection on the shortcut path.
        if input_channels in (1, output_channels):
            self.shortcut_proj = nn.Identity()
        else:
            self.shortcut_proj = nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Apply both residual units and the pooling to ``x``.

        Args:
            x: Input tensor whose channel dimension equals ``input_channels``.

        Returns:
            The pooled feature map with ``output_channels`` channels.
        """
        # Residual Unit 1: the shortcut branches off before the 1x1 convolution.
        shortcut = self.shortcut_proj(x)
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = x + shortcut
        x = F.relu(x)
        # Residual Unit 2
        shortcut = x
        x = self.conv3(x)
        x = F.relu(x)
        x = self.conv4(x)
        x = x + shortcut
        x = F.relu(x)
        x = self.pool(x)
        return x

ResNet,数据为2*N维的IQ信号(输入形状为 1×N×2),第一个残差块先用(3,2)的卷积核对它卷积,再经过(2,2)的池化把宽度从2压缩到1,得到1维的信号,之后就用(3,1)的卷积核进行卷积。

class MyResNet(nn.Module):
    """Classifier made of six residual stacks followed by two linear layers.

    The first stack pools with (2, 2) and the other five with (2, 1), so the
    first spatial dimension is halved (rounding down) six times. The flattened
    feature size is therefore ``32 * H_out * W_out``:

    * ``(1, 392, 2)`` input (MNIST re-arranged to 392x2) -> 32 * 6 = 192
    * ``(1, 128, 2)`` input -> 32 * 2 = 64
    * ``(1, 1024, 2)`` input -> 32 * 16 = 512

    Args:
        num_classes: Number of output classes (size of the logits vector).
        fc_in_features: Flattened feature size fed to ``fc1``. Defaults to 192,
            the value for ``(1, 392, 2)`` inputs; pass the matching value for
            other input sizes instead of editing the class.
    """

    def __init__(self, num_classes, fc_in_features=192):
        super().__init__()
        self.num_classes = num_classes
        # Only the first stack uses a (3, 2) kernel and pools along the width.
        self.seq0 = ResidualStack(1, 32, kernel_size=(3, 2), seq="ReStk0", pool_size=(2, 2))
        self.seq1 = ResidualStack(32, 32, kernel_size=(3, 1), seq="ReStk1", pool_size=(2, 1))
        self.seq2 = ResidualStack(32, 32, kernel_size=(3, 1), seq="ReStk2", pool_size=(2, 1))
        self.seq3 = ResidualStack(32, 32, kernel_size=(3, 1), seq="ReStk3", pool_size=(2, 1))
        self.seq4 = ResidualStack(32, 32, kernel_size=(3, 1), seq="ReStk4", pool_size=(2, 1))
        self.seq5 = ResidualStack(32, 32, kernel_size=(3, 1), seq="ReStk5", pool_size=(2, 1))
        self.fc1 = nn.Linear(fc_in_features, 128)
        self.fc2 = nn.Linear(128, num_classes)
        # AlphaDropout is the dropout variant intended for use with SELU.
        self.dropout = nn.AlphaDropout(0.3)

    def forward(self, x):
        """Return the raw class scores (logits) for a batch ``x``.

        No softmax is applied to the output.

        Args:
            x: Input batch with one channel, e.g. ``(N, 1, 392, 2)``.

        Returns:
            Tensor of shape ``(N, num_classes)``.
        """
        x = self.seq0(x)
        x = self.seq1(x)
        x = self.seq2(x)
        x = self.seq3(x)
        x = self.seq4(x)
        x = self.seq5(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)
        x = F.selu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

最后一层不需要softmax输出,直接输出线性层的结果(logits)即可,因为交叉熵损失函数(如PyTorch的nn.CrossEntropyLoss)内部已经包含了softmax。(之前在最后一层用了softmax输出,相当于做了两次softmax,导致loss一直在2.4附近波动)

主程序

if __name__ == "__main__":
    # Train MyResNet on MNIST, then save the weights and the per-epoch
    # loss / accuracy curves.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): torchvision applies `transform` only when samples are
    # fetched through the dataset object. This script reads `.data` and
    # `.targets` directly below, so this pipeline does not touch the tensors
    # that are actually used for training.
    transformer = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
        transforms.Resize([392,2])
    ])
    # Windows absolute paths (backslashes are escaped).
    tran_data = torchvision.datasets.MNIST(root="D:\\桌面\\代码\\mnist", transform=transformer, train=True, download=False)
    test_data = torchvision.datasets.MNIST(root="D:\\桌面\\代码\\mnist", transform=transformer, train=False, download=False)

    x_tran = tran_data.data.to(device)
    y_tran = tran_data.targets.to(device)

    x_test = test_data.data.to(device)
    y_test = test_data.targets.to(device)

    num_classes = len(tran_data.classes)
    model = MyResNet(num_classes).to(device)

    # Re-arrange every image into a (1, 392, 2) sample and scale by 1/255.
    # The sample count is taken from the tensor instead of being hard-coded.
    # `resize` is defined/imported outside this block — presumably a reshape;
    # TODO confirm.
    x_tran = resize(x_tran, (x_tran.shape[0], 1, 392, 2))
    x_tran = x_tran/255
    x_test = resize(x_test, (x_test.shape[0], 1, 392, 2))
    x_test = x_test/255

    train_dataset = Data.TensorDataset(x_tran, y_tran)
    test_dataset = Data.TensorDataset(x_test, y_test)
    # NOTE(review): the training loader does not shuffle; consider
    # shuffle=True if the fixed sample order is not intentional.
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=32)
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=32)

    itrs = 100  # number of calls to train(); one loss/accuracy pair per call
    train_loss = []
    train_acc = []
    print("start training")
    for itr in range(itrs):
        epoch_loss, epoch_acc = train(model, train_dataloader, test_dataloader, itr)
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)

    torch.save(model.state_dict(), "D:\\桌面\\代码\\mnist\\resnetMNIST.pt")
    np.savetxt("trainloss.txt", np.array(train_loss))
    np.savetxt("trainacc.txt", np.array(train_acc))

识别率在98%左右

注意点

resnet最后一层不需要softmax输出

猜你喜欢

转载自blog.csdn.net/weixin_45121008/article/details/129267338