VGG：使用重复元素的网络

VGG-11 由 5 个卷积块和 3 个全连接层组成：前 2 个卷积块各含 1 个卷积层，后 3 个卷积块各含 2 个卷积层，共 8 个卷积层；再加上 3 个全连接层，总计 11 层，故名 VGG-11。

from mxnet import gluon,init,nd,autograd
from mxnet.gluon import nn,loss

def vgg_block(num_convs, num_channels):
    """Build one VGG block as an ``nn.Sequential``.

    The block stacks ``num_convs`` 3x3 convolutions (padding 1, ReLU), each
    producing ``num_channels`` output channels, and ends with a 2x2 max-pool
    of stride 2. The padding keeps the spatial size through the convolutions,
    so only the pool downsamples.

    Args:
        num_convs: number of conv layers in the block (0 yields pool only).
        num_channels: output channels of every conv layer in the block.

    Returns:
        The assembled ``nn.Sequential`` block.
    """
    block = nn.Sequential()
    convs = [
        nn.Conv2D(num_channels, kernel_size=3, padding=1, activation='relu')
        for _ in range(num_convs)
    ]
    block.add(*convs)
    # Downsample height and width by 2 at the end of the block.
    block.add(nn.MaxPool2D(pool_size=2, strides=2))
    return block

# One (num_convs, num_channels) pair per VGG block: 1 + 1 + 2 + 2 + 2 = 8 conv
# layers which, together with the 3 dense layers added in vgg(), give VGG-11.
conv_arch = (
    (1, 64),
    (1, 128),
    (2, 256),
    (2, 512),
    (2, 512),
)

def vgg(conv_arch, num_classes=10):
    """Build a VGG network from a convolutional architecture spec.

    Args:
        conv_arch: iterable of ``(num_convs, num_channels)`` pairs. Each pair
            becomes one ``vgg_block`` containing ``num_convs`` conv layers
            with ``num_channels`` output channels.
        num_classes: number of units in the final Dense layer. Defaults to 10,
            the value the network previously hard-coded.

    Returns:
        An ``nn.Sequential`` made of the conv blocks followed by the
        classifier head: Dense(4096, relu) -> Dropout(0.5) ->
        Dense(4096, relu) -> Dropout(0.5) -> Dense(num_classes).
    """
    net = nn.Sequential()
    # Convolutional part: one block per (num_convs, num_channels) pair.
    for num_convs, num_channels in conv_arch:
        # The first argument is the number of conv layers in the block.
        # Passing num_channels here would, e.g., stack 64 conv layers into the
        # first block instead of 1.
        net.add(vgg_block(num_convs, num_channels))

    # Fully connected classifier head.
    net.add(nn.Dense(4096, activation='relu'), nn.Dropout(0.5),
            nn.Dense(4096, activation='relu'), nn.Dropout(0.5),
            nn.Dense(num_classes))
    return net

# Build the network and initialize its parameters with Gluon's default
# initializer.
net = vgg(conv_arch)
net.initialize()

# Push one random dummy input of shape (1, 1, 224, 224) through each top-level
# child in turn, printing the shape after every step to check the architecture.
X = nd.random.uniform(shape=(1, 1, 224, 224))
for blk in net:
    X = blk(X)
    print(blk.name, 'output shape:\t', X.shape)

运行上述代码，会逐层打印每个块的名称及其输出形状。

猜你喜欢

转载自www.cnblogs.com/TreeDream/p/10082108.html
今日推荐