1. Proceso de implementación
El módulo de inicio se propuso y adoptó por primera vez en GoogLeNet. Su estructura básica se muestra en la Figura 1. Toda la estructura de inicio se compone de varios de estos módulos de inicio conectados en serie. Hay dos contribuciones principales de la estructura inicial: una es usar una convolución 1x1 para aumentar y disminuir la dimensión; la otra es realizar convolución y reagregación en múltiples tamaños al mismo tiempo. Este documento utiliza la estructura de inicio de la Figura 1 para lograr la clasificación múltiple del conjunto de datos MNIST.
Encapsule la estructura de inicio en una clase para reducir la redundancia de código. el código se muestra a continuación:
class InceptionA(torch.nn.Module):
def __init__(self, in_channels):
super(InceptionA,self).__init__()
self.branch1x1 = torch.nn.Conv2d(in_channels,16,kernel_size=1)
self.branch5x5_1 = torch.nn.Conv2d(in_channels,16,kernel_size=1)
self.branch5x5_2 = torch.nn.Conv2d(16,24,kernel_size=5,padding=2)
self.branch3x3_1 = torch.nn.Conv2d(in_channels,16,kernel_size=1)
self.branch3x3_2 = torch.nn.Conv2d(16,24,kernel_size=3,padding=1)
self.branch3x3_3 = torch.nn.Conv2d(24,24,kernel_size=3,padding=1)
self.branch_pool = torch.nn.Conv2d(in_channels,24,kernel_size=1)
def forward(self,x):
branch1x1 = self.branch1x1(x)
branch5x5 = self.branch5x5_1(x)
branch5x5 = self.branch5x5_2(branch5x5)
branch3x3 = self.branch3x3_1(x)
branch3x3 = self.branch3x3_2(branch3x3)
branch3x3 = self.branch3x3_3(branch3x3)
branch_pool = F.avg_pool2d(x,kernel_size=3,stride=1,padding=1)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1,branch5x5,branch3x3,branch_pool]
return torch.cat(outputs,dim=1)
El código de pieza de la red se cambia a:
# 2.设计模型
class Net(torch.nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = torch.nn.Conv2d(1,10,kernel_size=5)
self.conv2 = torch.nn.Conv2d(88,20,kernel_size=5)
self.incep1 = InceptionA(in_channels=10)
self.incep2 = InceptionA(in_channels=20)
self.mp = torch.nn.MaxPool2d(2)
self.fc = torch.nn.Linear(1408,10)
def forward(self,x):
# Flatten data from (n,1,28,28) to (n,784)
in_size = x.size(0)
x = F.relu(self.mp(self.conv1(x)))
x = self.incep1(x)
x = F.relu(self.mp(self.conv2(x)))
x = self.incep2(x)
x = x.view(in_size,-1) # flatten
return self.fc(x)
model = Net()
El resto del código permanece sin cambios.
El resultado de ejecución es:
[1,300] loss: 0.788
[1,600] loss: 0.225
[1,900] loss: 0.155
Accuracy on test set: 97.02 % [9702/10000]
[2,300] loss: 0.115
[2,600] loss: 0.102
[2,900] loss: 0.087
Accuracy on test set: 97.97 % [9797/10000]
[3,300] loss: 0.078
[3,600] loss: 0.073
[3,900] loss: 0.069
Accuracy on test set: 98.35 % [9835/10000]
[4,300] loss: 0.061
[4,600] loss: 0.061
[4,900] loss: 0.060
Accuracy on test set: 98.56 % [9856/10000]
[5,300] loss: 0.053
[5,600] loss: 0.051
[5,900] loss: 0.047
Accuracy on test set: 98.61 % [9861/10000]
[6,300] loss: 0.041
[6,600] loss: 0.046
[6,900] loss: 0.048
Accuracy on test set: 98.85 % [9885/10000]
[7,300] loss: 0.041
[7,600] loss: 0.039
[7,900] loss: 0.041
Accuracy on test set: 98.56 % [9856/10000]
[8,300] loss: 0.034
[8,600] loss: 0.038
[8,900] loss: 0.039
Accuracy on test set: 98.78 % [9878/10000]
[9,300] loss: 0.036
[9,600] loss: 0.031
[9,900] loss: 0.035
Accuracy on test set: 98.87 % [9887/10000]
[10,300] loss: 0.030
[10,600] loss: 0.033
[10,900] loss: 0.032
Accuracy on test set: 98.94 % [9894/10000]
Complemento:
la altura (anchura) después de la convolución se puede calcular mediante la siguiente fórmula: H ′ = H − F + 2 ps + 1 (1) H'=\frac{H-F+2p}{s}+1\tag { 1}H′=sH−F+2 p+1( 1 ) Entre ellos,FFF es el tamaño del kernel de convolución (kernel_size),ppp es el número de relleno de convolución (relleno),sss es el paso de convolución.
2. Referencias
[1] https://www.bilibili.com/video/BV1Y7411d7Ys?p=11
[2] https://baike.baidu.com/item/GoogLeNet/22689587?fr=aladdin