Chainer Backbone Networks: Refactoring AlexNet (Source Code Included)


Preface

This article implements the AlexNet network structure in Chainer, organizing the Chainer version along the lines of the torch implementation, and also computes AlexNet's parameter count.


Implementation

import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np
from chainer.functions.activation.relu import ReLU
from chainer.functions.noise.dropout import Dropout
from chainer.functions.pooling.max_pooling_2d import MaxPooling2D


class AlexNet(chainer.Chain):
    # Placeholder for configuration variants; only a single AlexNet layout is used here.
    cfgs = {
        'alexnet': None
    }

    def get_out_channels(self, W, F, P, S):
        # Output spatial size of a conv/pooling layer: floor((W - F + 2P) / S) + 1
        return int((W - F + 2 * P) / S + 1)

    def __init__(self, num_classes=1000, channels=3, image_size=224,
                 initialW=chainer.initializers.HeNormal(), **kwargs):
        super(AlexNet, self).__init__()
        
        # Feature extractor: (name, layer/function) pairs in the same order as the torch model.
        # Names starting with '_' are parameter-free function nodes and are not registered as links.
        self.features = []
        self.features += [("conv1", L.Convolution2D(channels, 64, ksize=11, stride=4, pad=2, initialW=initialW))]
        out_size = self.get_out_channels(image_size, 11, 2, 4)  # 55 when image_size=224
        self.features += [("_conv1_relu", ReLU())]
        self.features += [("_conv1_maxpooling", MaxPooling2D(ksize=3, stride=2))]
        # F=2 gives the same output size as MaxPooling2D with ksize=3, stride=2 (cover_all=True by default).
        out_size = self.get_out_channels(out_size, 2, 0, 2)     # 27
        self.features += [("conv2", L.Convolution2D(64, 192, ksize=5, stride=1, pad=2, initialW=initialW))]
        out_size = self.get_out_channels(out_size, 5, 2, 1)     # 27
        self.features += [("_conv2_relu", ReLU())]
        self.features += [("_conv2_maxpooling", MaxPooling2D(ksize=3, stride=2))]
        out_size = self.get_out_channels(out_size, 2, 0, 2)     # 13
        self.features += [("conv3", L.Convolution2D(192, 384, ksize=3, stride=1, pad=1, initialW=initialW))]
        out_size = self.get_out_channels(out_size, 3, 1, 1)     # 13
        self.features += [("_conv3_relu", ReLU())]
        self.features += [("conv4", L.Convolution2D(384, 256, ksize=3, stride=1, pad=1, initialW=initialW))]
        out_size = self.get_out_channels(out_size, 3, 1, 1)     # 13
        self.features += [("_conv4_relu", ReLU())]
        self.features += [("conv5", L.Convolution2D(256, 256, ksize=3, stride=1, pad=1, initialW=initialW))]
        out_size = self.get_out_channels(out_size, 3, 1, 1)     # 13
        self.features += [("_conv5_relu", ReLU())]
        self.features += [("_conv5_maxpooling", MaxPooling2D(ksize=3, stride=2))]
        out_size = self.get_out_channels(out_size, 2, 0, 2)     # 6

        # Classifier: fc1 takes 256 * out_size * out_size inputs (9216 for a 224x224 image).
        # Note: the Dropout node is applied directly (not via F.dropout), so it is not
        # switched off by chainer.config.train.
        self.classifier = []
        self.classifier += [("_dropout1", Dropout(0.5))]
        self.classifier += [("fc1", L.Linear(256 * out_size * out_size, 4096, initialW=initialW))]
        self.classifier += [("_fc1_relu", ReLU())]
        self.classifier += [("_dropout2", Dropout(0.5))]
        self.classifier += [("fc2", L.Linear(4096, 4096, initialW=initialW))]
        self.classifier += [("_fc2_relu", ReLU())]
        self.classifier += [("output_1", L.Linear(4096, num_classes, initialW=initialW))]
    
        # Register only the parameterized layers as links; '_'-prefixed entries hold
        # stateless function nodes and stay outside the link hierarchy.
        with self.init_scope():
            for n in self.features:
                if not n[0].startswith('_'):
                    setattr(self, n[0], n[1])
            for n in self.classifier:
                if not n[0].startswith('_'):
                    setattr(self, n[0], n[1])
        
    def __call__(self, x):
        # Feature extractor: registered links are called directly, '_'-prefixed
        # entries are raw function nodes and go through FunctionNode.apply().
        for n, f in self.features:
            origin_size = x.shape
            if not n.startswith('_'):
                x = getattr(self, n)(x)
            else:
                x = f.apply((x,))[0]
            print(n, origin_size, x.shape)

        # Classifier: L.Linear flattens the conv feature map automatically.
        for n, f in self.classifier:
            origin_size = x.shape
            if not n.startswith('_'):
                x = getattr(self, n)(x)
            else:
                x = f.apply((x,))[0]
            print(n, origin_size, x.shape)

        # Return raw logits during training, softmax probabilities at test time.
        if chainer.config.train:
            return x
        return F.softmax(x)
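The out_size bookkeeping above is what determines the fc1 input dimension. A quick standalone sketch (using a hypothetical helper spatial_out that mirrors get_out_channels, and assuming the default 224x224 input) traces the spatial size through the network and ends at 6, giving 256 * 6 * 6 = 9216 inputs for fc1:

def spatial_out(W, F, P, S):
    # floor((W - F + 2P) / S) + 1, same formula as AlexNet.get_out_channels
    return int((W - F + 2 * P) / S + 1)

s = 224
s = spatial_out(s, 11, 2, 4)  # conv1          -> 55
s = spatial_out(s, 2, 0, 2)   # maxpool        -> 27
s = spatial_out(s, 5, 2, 1)   # conv2          -> 27
s = spatial_out(s, 2, 0, 2)   # maxpool        -> 13
s = spatial_out(s, 3, 1, 1)   # conv3/4/5      -> 13
s = spatial_out(s, 2, 0, 2)   # maxpool        -> 6
print(256 * s * s)            # 9216 = fc1 input size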

This class is the complete AlexNet implementation. Note that the forward pass distinguishes between training and testing: during training it returns the raw logits x directly, while at test time the output is passed through softmax to obtain class probabilities.
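For instance, a minimal test-time call looks like this (model and x are placeholders for an AlexNet instance and a float32 input batch, such as model_simple and x in the usage section below):

with chainer.using_config('train', False):
    probs = model(x)              # softmax probabilities, shape (batch, num_classes)
pred = probs.data.argmax(axis=1)  # predicted class index per sample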

Usage

if __name__ == '__main__':
    batch_size = 4
    n_channels = 3
    image_size = 224
    num_classes = 10

    model_simple = AlexNet(num_classes=num_classes, channels=n_channels, image_size=image_size)
    print(model_simple.count_params())

    # Dummy batch: random float32 images and random integer labels.
    x = np.random.rand(batch_size, n_channels, image_size, image_size).astype(np.float32)
    t = np.random.randint(0, num_classes, size=(batch_size,)).astype(np.int32)
    with chainer.using_config('train', True):
        y1 = model_simple(x)
    loss1 = F.softmax_cross_entropy(y1, t)
    print(loss1.data)
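The count_params() figure printed above can be checked by hand: each layer contributes its weights plus biases, so with num_classes=10 the total should be 57,044,810 (and 61,100,840 for the standard 1000-class head). A small sketch of that arithmetic:

# (in_channels, out_channels, kernel) for the five conv layers
conv = [(3, 64, 11), (64, 192, 5), (192, 384, 3), (384, 256, 3), (256, 256, 3)]
# (in_features, out_features) for fc1, fc2 and the 10-class output layer
fc = [(256 * 6 * 6, 4096), (4096, 4096), (4096, 10)]
total = sum(i * o * k * k + o for i, o, k in conv) + sum(i * o + o for i, o in fc)
print(total)  # 57044810, should match model_simple.count_params()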



Reposted from blog.csdn.net/ctu_sue/article/details/128682822