from keras.models import Model
from keras.layers import Input, GlobalAveragePooling2D
from keras.layers import Conv2D, MaxPool2D, BatchNormalization
# MLPConv layer (Network in Network): one spatial convolution followed by
# two 1x1 convolutions that act as a per-pixel MLP, each conv followed by
# ReLU activation and batch normalization.
def mlp_layer(x, filters, size, stride, padding):
    """Apply one MLPConv block to tensor `x`.

    Args:
        x: 4-D input tensor (batch, height, width, channels).
        filters: number of output channels for every conv in the block.
        size: spatial kernel size of the first convolution.
        stride: stride of the first convolution.
        padding: padding mode of the first convolution ('same' or 'valid').

    Returns:
        The transformed 4-D tensor.
    """
    x = Conv2D(filters=filters, kernel_size=(size, size),
               strides=(stride, stride), padding=padding,
               activation='relu')(x)
    x = BatchNormalization()(x)
    # The 1x1 convolutions are equivalent to a fully connected layer applied
    # independently at every spatial position.
    x = Conv2D(filters=filters, kernel_size=(1, 1), strides=(1, 1),
               padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    x = Conv2D(filters=filters, kernel_size=(1, 1), strides=(1, 1),
               padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    return x
def body(img_input, classes):
    """Build the Network-in-Network body on a 227x227x3 input.

    Args:
        img_input: input tensor, expected shape 227x227x3 per the comments.
        classes: number of target classes; the final MLPConv block emits one
            feature map per class so global average pooling yields the logits.

    Returns:
        A (batch, classes) tensor produced by global average pooling.
    """
    # Conv1: input 227x227x3
    x = mlp_layer(img_input, 96, 11, 4, 'valid')
    x = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='valid')(x)
    # Conv2: 27x27x96
    x = mlp_layer(x, 256, 5, 1, 'same')
    x = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='valid')(x)
    # Conv3: 13x13x256
    x = mlp_layer(x, 384, 3, 1, 'same')
    x = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='valid')(x)
    # Conv4: 6x6x384 — one output channel per class
    x = mlp_layer(x, classes, 3, 1, 'same')
    # Global average pooling: 6x6xclasses -> classes
    x = GlobalAveragePooling2D()(x)
    return x
# --- Network In Network implemented in PyTorch ---
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
def mlp_layer(channels, kernel_size, padding, strides=1, max_pooling=True):
    """Build one MLPConv block as a torch ``nn.Sequential`` module.

    Args:
        channels: number of output channels for every conv in the block.
        kernel_size: spatial kernel size of the first convolution.
        padding: padding of the first convolution.
        strides: stride of the first convolution.
        max_pooling: when True, append a 3x3 stride-2 max pooling layer.

    Returns:
        An ``nn.Sequential`` implementing the block.
    """
    layers = [
        # LazyConv2d infers in_channels on the first forward pass, matching
        # the original (Gluon-style) API that did not require it up front.
        nn.LazyConv2d(channels, kernel_size=kernel_size,
                      stride=strides, padding=padding),
        nn.ReLU(inplace=True),
        # Two 1x1 convolutions: a per-pixel MLP over the channel dimension.
        nn.LazyConv2d(channels, kernel_size=1),
        nn.ReLU(inplace=True),
        nn.LazyConv2d(channels, kernel_size=1),
        nn.ReLU(inplace=True),
    ]
    if max_pooling:
        layers.append(nn.MaxPool2d(kernel_size=3, stride=2))
    return nn.Sequential(*layers)
if __name__ == '__main__':
    # Assemble the full Network-in-Network model. torch's nn.Sequential takes
    # its submodules in the constructor (there is no Gluon-style .add or
    # name_scope in PyTorch).
    net = nn.Sequential(
        mlp_layer(96, 11, 0, strides=4),
        mlp_layer(256, 5, 2),
        mlp_layer(384, 3, 1),
        nn.Dropout(0.5),
        # Final block emits one feature map per target class (10 classes).
        mlp_layer(10, 3, 1, max_pooling=False),
        # Global average pooling: batch x 10 x H x W -> batch x 10 x 1 x 1.
        # Using adaptive pooling avoids hard-coding the pool size for a
        # particular input resolution.
        nn.AdaptiveAvgPool2d(1),
        # Flatten to batch x 10.
        nn.Flatten(),
    )
    print(net)