VGG16网络结构复现(Pytorch版)

VGG 共有 6 种子模型,分别是 A、A-LRN、B、C、D、E;最常见的是 D 和 E 两种,即 VGG16 和 VGG19。
在这里插入图片描述
在这里插入图片描述
为了方便阅读,下面的代码没有加上激活函数层(ReLU);实际训练或推理时,需要在每个卷积层和前两个全连接层之后加上 ReLU。

from torch import nn
import torch
from torchsummary import summary


class VGG16(nn.Module):
    """VGG16 (configuration D) layer layout, without activation layers.

    The network is thirteen 3x3 convolutions grouped into five stages, each
    stage closed by a 2x2 max-pool, followed by a flatten and three linear
    layers. Activation functions are left out and dropout is left disabled,
    so this module illustrates the layer structure only.

    Args:
        num_classes: Output width of the final linear layer. Defaults to
            1000, the value that was previously hard-coded.
    """

    def __init__(self, num_classes: int = 1000):
        super(VGG16, self).__init__()

        # NOTE: the attribute name is kept as-is so existing state_dict keys
        # ("sum_Module.<idx>.weight", ...) remain valid.
        self.sum_Module = nn.Sequential(
            # Stage 1: 3 -> 64 channels.
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(2, 2),

            # Stage 2: 64 -> 128 channels.
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(2, 2),

            # Stage 3: 128 -> 256 channels.
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(2, 2),

            # Stage 4: 256 -> 512 channels.
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(2, 2),

            # Stage 5: 512 -> 512 channels.
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(2, 2),

            # Classifier head. The first linear layer expects 7 * 7 * 512
            # flattened features, i.e. what five 2x2 poolings leave from a
            # 224x224 input.
            nn.Flatten(),
            nn.Linear(7 * 7 * 512, 4096),
            # Dropout(0.5) after this layer is left disabled here.
            nn.Linear(4096, 4096),
            # Dropout(0.5) after this layer is left disabled here.
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the full layer stack and return the output tensor.

        Args:
            x: Input batch; the first convolution takes 3 input channels.

        Returns:
            The output of the final linear layer as a tensor (not wrapped in
            a tuple).
        """
        return self.sum_Module(x)


if __name__ == '__main__':
    # Pick the first GPU when CUDA is available, otherwise fall back to CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Build the network and move its parameters to the selected device.
    model = VGG16().to(device)

    # Print a per-layer summary. (3, 224, 224) are the three dimensions of a
    # single input (channels, height, width). The device string is derived
    # from the device chosen above instead of being hard-coded to "cuda".
    summary(model, (3, 224, 224), batch_size=1, device=device.type)





请添加图片描述

猜你喜欢

转载自blog.csdn.net/weixin_43694096/article/details/125084531