An annoying part of writing a CNN in PyTorch is that the input size of the first fully connected layer is not obvious. You can compute it by hand, but that is tedious and error-prone, so it is not recommended. Below we let code compute the number for us.
Question: how many input neurons does the first fully connected layer below need?
net = nn.Sequential(
    nn.Conv2d(1, 32, 3), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, 3), nn.ReLU(), nn.MaxPool2d(2, 2),
    # nn.Linear(?, 4096), nn.ReLU(),
    # nn.Linear(4096, 10)
)
First write only the layers in front of the fully connected part: convolution, activation, and pooling.
Then pass a batch through with net(data) and inspect the shape of output, which is (batch_size, out_channels, height, width); from it we can read off the number of input neurons the fully connected layer needs.
import torch, torchvision
import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(1, 32, 3), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, 3), nn.ReLU(), nn.MaxPool2d(2, 2),
    # nn.Linear(?, 4096), nn.ReLU(),
    # nn.Linear(4096, 10)
)

train_data = torchvision.datasets.MNIST('./data', train=True, download=True,
                                        transform=torchvision.transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)

# Run a single batch through the convolutional part and print its shape.
for data, target in train_loader:
    output = net(data)
    print(output.size())
    break
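If the dataset is not on disk yet, the same check works with a dummy tensor of the right input shape; a minimal sketch, assuming 28x28 single-channel MNIST-style images:

import torch
import torch.nn as nn

# The convolutional part from above.
net = nn.Sequential(
    nn.Conv2d(1, 32, 3), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, 3), nn.ReLU(), nn.MaxPool2d(2, 2),
)

# A dummy batch shaped (batch_size, channels, height, width).
dummy = torch.zeros(1, 1, 28, 28)
print(net(dummy).size())   # torch.Size([1, 64, 5, 5])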
Output:
torch.Size([128, 64, 5, 5])
From the output, the first fully connected layer needs 64*5*5 input neurons, i.e. nn.Linear(64*5*5, 4096).
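If you do want to verify the number by hand, applying the standard output-size formula floor((size + 2*padding - kernel) / stride) + 1 layer by layer gives the same 5x5. A minimal sketch (conv_out is a hypothetical helper written here for illustration, not a PyTorch function):

def conv_out(size, kernel, stride=1, padding=0):
    # Spatial output size of a conv or pool layer.
    return (size + 2 * padding - kernel) // stride + 1

s = 28                   # MNIST images are 28x28
s = conv_out(s, 3)       # Conv2d(1, 32, 3):  28 -> 26
s = conv_out(s, 2, 2)    # MaxPool2d(2, 2):   26 -> 13
s = conv_out(s, 3)       # Conv2d(32, 64, 3): 13 -> 11
s = conv_out(s, 2, 2)    # MaxPool2d(2, 2):   11 -> 5
print(64 * s * s)        # 1600, i.e. 64*5*5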
The rest of this post is optional.
Below is the complete CNN code for recognizing MNIST. An nn.Flatten() added in front of the fully connected layers flattens the activations to (batch_size, ?).
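To see concretely what nn.Flatten() does to the conv output, a minimal sketch:

import torch
import torch.nn as nn

x = torch.zeros(128, 64, 5, 5)   # conv output for a batch of 128 images
print(nn.Flatten()(x).size())    # torch.Size([128, 1600]), i.e. (batch_size, 64*5*5)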
import torch, torchvision
import torch.nn as nn

# Build the model
net = nn.Sequential(
    nn.Conv2d(1, 32, 3), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, 3), nn.ReLU(), nn.MaxPool2d(2, 2),
    # nn.Flatten() flattens the activations to 2D: (batch_size, ?)
    nn.Flatten(),
    nn.Linear(64*5*5, 4096), nn.ReLU(),
    nn.Linear(4096, 10)
)
train_data = torchvision.datasets.MNIST('./data', train=True, download=True,
                                        transform=torchvision.transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)

criterion = torch.nn.CrossEntropyLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = net.to(device)   # move the model before building the optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
print('training on:', device)

for epoch in range(10):
    print('epoch: %d' % epoch)
    correct = 0.0
    total = 0.0
    loss_sum = 0.0
    for batch, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = net(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        correct += torch.sum(torch.argmax(output, dim=1) == target).cpu().item()
        total += len(target)
        loss_sum += loss.item()   # .item() avoids holding on to the autograd graph
        if batch % 100 == 0:
            print('\tbatch: %d, loss: %.4f' % (batch, loss.item()))
    print('epoch: %d, acc: %.2f%%, loss: %.4f' % (epoch, 100 * correct / total, loss_sum / len(train_loader)))
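As a side note, PyTorch 1.8+ also offers nn.LazyLinear, which infers in_features automatically on the first forward pass, so 64*5*5 never has to be computed at all; a minimal sketch:

import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(1, 32, 3), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, 3), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.LazyLinear(4096), nn.ReLU(),   # in_features is filled in on first use
    nn.Linear(4096, 10)
)
print(net(torch.zeros(1, 1, 28, 28)).size())   # torch.Size([1, 10])

After the first batch passes through, the lazy layer's weights are materialized and training proceeds as usual.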