La estructura inicial de PyTorch

1. Proceso de implementación

El módulo de inicio se propuso y adoptó por primera vez en GoogLeNet. Su estructura básica se muestra en la Figura 1. Toda la estructura de inicio se compone de varios de estos módulos de inicio conectados en serie. Hay dos contribuciones principales de la estructura inicial: una es usar una convolución 1x1 para aumentar y disminuir la dimensión; la otra es realizar convolución y reagregación en múltiples tamaños al mismo tiempo. Este documento utiliza la estructura de inicio de la Figura 1 para lograr la clasificación múltiple del conjunto de datos MNIST.
inserte la descripción de la imagen aquí

Figura 1 Estructura básica de inicio

Encapsule la estructura de inicio en una clase para reducir la redundancia de código. el código se muestra a continuación:

class InceptionA(torch.nn.Module):
    def __init__(self, in_channels):
        super(InceptionA,self).__init__()
        self.branch1x1 = torch.nn.Conv2d(in_channels,16,kernel_size=1)
        
        self.branch5x5_1 = torch.nn.Conv2d(in_channels,16,kernel_size=1)
        self.branch5x5_2 = torch.nn.Conv2d(16,24,kernel_size=5,padding=2)
    
        self.branch3x3_1 = torch.nn.Conv2d(in_channels,16,kernel_size=1)
        self.branch3x3_2 = torch.nn.Conv2d(16,24,kernel_size=3,padding=1)
        self.branch3x3_3 = torch.nn.Conv2d(24,24,kernel_size=3,padding=1)
        
        self.branch_pool = torch.nn.Conv2d(in_channels,24,kernel_size=1)
        
    def forward(self,x):
        branch1x1 = self.branch1x1(x)
        
        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)
        
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)
        branch3x3 = self.branch3x3_3(branch3x3)
        
        branch_pool = F.avg_pool2d(x,kernel_size=3,stride=1,padding=1)
        branch_pool = self.branch_pool(branch_pool)
        
        outputs = [branch1x1,branch5x5,branch3x3,branch_pool]
        return torch.cat(outputs,dim=1)

El código de pieza de la red se cambia a:

# 2.设计模型
class Net(torch.nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = torch.nn.Conv2d(1,10,kernel_size=5)
        self.conv2 = torch.nn.Conv2d(88,20,kernel_size=5)
        
        self.incep1 = InceptionA(in_channels=10)
        self.incep2 = InceptionA(in_channels=20)
        
        self.mp = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(1408,10)
        
    def forward(self,x):
        # Flatten data from (n,1,28,28) to (n,784)
        in_size = x.size(0)
        x = F.relu(self.mp(self.conv1(x)))
        x = self.incep1(x)
        x = F.relu(self.mp(self.conv2(x)))
        x = self.incep2(x)
        x = x.view(in_size,-1)  # flatten
        return self.fc(x)
model = Net()

El resto del código permanece sin cambios.
El resultado de ejecución es:

[1,300] loss: 0.788
[1,600] loss: 0.225
[1,900] loss: 0.155
Accuracy on test set: 97.02 % [9702/10000]
[2,300] loss: 0.115
[2,600] loss: 0.102
[2,900] loss: 0.087
Accuracy on test set: 97.97 % [9797/10000]
[3,300] loss: 0.078
[3,600] loss: 0.073
[3,900] loss: 0.069
Accuracy on test set: 98.35 % [9835/10000]
[4,300] loss: 0.061
[4,600] loss: 0.061
[4,900] loss: 0.060
Accuracy on test set: 98.56 % [9856/10000]
[5,300] loss: 0.053
[5,600] loss: 0.051
[5,900] loss: 0.047
Accuracy on test set: 98.61 % [9861/10000]
[6,300] loss: 0.041
[6,600] loss: 0.046
[6,900] loss: 0.048
Accuracy on test set: 98.85 % [9885/10000]
[7,300] loss: 0.041
[7,600] loss: 0.039
[7,900] loss: 0.041
Accuracy on test set: 98.56 % [9856/10000]
[8,300] loss: 0.034
[8,600] loss: 0.038
[8,900] loss: 0.039
Accuracy on test set: 98.78 % [9878/10000]
[9,300] loss: 0.036
[9,600] loss: 0.031
[9,900] loss: 0.035
Accuracy on test set: 98.87 % [9887/10000]
[10,300] loss: 0.030
[10,600] loss: 0.033
[10,900] loss: 0.032
Accuracy on test set: 98.94 % [9894/10000]

inserte la descripción de la imagen aquí
Complemento:
la altura (anchura) después de la convolución se puede calcular mediante la siguiente fórmula: H ′ = H − F + 2 ps + 1 (1) H'=\frac{H-F+2p}{s}+1\tag { 1}H=sHF+2 p+1( 1 ) Entre ellos,FFF es el tamaño del kernel de convolución (kernel_size),ppp es el número de relleno de convolución (relleno),sss es el paso de convolución.

2. Referencias

[1] https://www.bilibili.com/video/BV1Y7411d7Ys?p=11
[2] https://baike.baidu.com/item/GoogLeNet/22689587?fr=aladdin

Supongo que te gusta

Origin blog.csdn.net/weixin_43821559/article/details/123379833
Recomendado
Clasificación