Advantages of BN:
(1) A larger learning rate can be used (with traditional methods, too large a learning rate easily makes gradients explode/vanish, or gets the model stuck in poor local minima).
(2) Dropout is no longer needed.
(3) Initialization can be less careful.
But BN is not just a matter of inserting BN layers; to actually train faster, the paper also changes the following (a rough sketch of these settings follows this list):
(1) Give the learning rate a larger initial value and decay it faster (e.g., scale the learning rate 5x, from 0.0015 to 0.0075, and decay it 6x faster).
(2) Remove Dropout.
(3) Reduce the L2 weight decay (e.g., divide it by 5).
(4) Remove LRN (Local Response Normalization).
(5, 6) Others, such as shuffling training examples more thoroughly and reducing photometric distortions; see the paper.
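As a rough, hypothetical illustration of that recipe (mine, not from the paper or its code) — the only number below taken from the paper is the learning rate 0.0015 -> 0.0075; the dropout and weight-decay values are placeholders:

# Hypothetical before/after hyperparameters for the BN-x5-style recipe above.
baseline = {"learning_rate": 0.0015,        # original learning rate
            "dropout": True,                # dropout enabled
            "l2_weight_decay": 1e-4,        # placeholder value
            "lrn": True}                    # LRN enabled
with_bn  = {"learning_rate": 0.0015 * 5,    # (1) 5x larger initial learning rate
            "dropout": False,               # (2) Dropout removed
            "l2_weight_decay": 1e-4 / 5,    # (3) L2 weight decay divided by 5
            "lrn": False}                   # (4) LRN removed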
ResNet uses BN; for its CIFAR-10 networks the hyperparameters are as follows (a small sketch of the schedule appears below):
For the 20-, 32-, 44-, and 56-layer networks:
learning rate = 0.1, divided by 10 at 32k and again at 48k iterations;
L2 weight decay = 1e-4.
For the 110-layer network: a learning rate of 0.01 is used to warm up training until the training error drops below 80%, after which training continues at 0.1.
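A minimal sketch of that step schedule (plain Python; the function name resnet_cifar_lr is mine):

def resnet_cifar_lr(iteration):
    # Piecewise-constant schedule: start at 0.1, /10 at 32k, /10 again at 48k.
    lr = 0.1
    if iteration >= 32000:
        lr /= 10.0
    if iteration >= 48000:
        lr /= 10.0
    return lr

print(resnet_cifar_lr(0), resnet_cifar_lr(40000), resnet_cifar_lr(50000))  # 0.1 0.01 0.001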
1. tf.nn.batch_normalization
TensorFlow provides an API for Batch Normalization. But this API is very flexible, and the price of that flexibility is that we must supply every parameter ourselves.
(For example, we even have to compute the mean and variance of the Tensor we pass in.)
tf.nn.batch_normalization(
    x,                 # Tensor to apply BN to
    mean,              # Tensor, usually the mean of x, float32
    variance,          # Tensor, usually the variance of x, float32
    offset,            # Tensor, the beta value (BN's shift); usually initialized to 0
    scale,             # Tensor, the gamma value (BN's scale); usually initialized to 1
    variance_epsilon,  # float, a small constant that prevents division by zero
    name=None
)
"""
Return value (Tensor):
    y = (x - mean) / sqrt(variance + variance_epsilon) * scale + offset
mean and variance must be computed in advance; TensorFlow provides another
API to compute them (or, of course, we can hack up our own).
"""
This API follows the paper's design exactly and is even more flexible (e.g., mean and variance may be set to values other than the actual mean and variance of x, and likewise for beta and gamma).
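To make the formula concrete, here is a small NumPy check (mine, not from the original post) of the transform on the same numbers used later in section 3:

import numpy as np

x = np.array([[1., 5.], [10., 100.]], dtype=np.float32)
mean, variance = x.mean(), x.var()   # mean=29.0, variance=1690.5
beta, gamma, eps = 0.0, 1.0, 1e-10   # BN shift/scale and the epsilon term
y = (x - mean) / np.sqrt(variance + eps) * gamma + beta
print(y.mean(), y.var())             # ~0.0 and ~1.0, matching the example below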
2. Computing a Tensor's mean and variance: tf.nn.moments
Since the API above requires mean and variance to be computed by hand, this is where tf.nn.moments comes in.

tf.nn.moments(
    x,            # Tensor whose mean and variance we want
    axes,         # dimensions to reduce over; for BN here it is all of them,
                  # i.e. [d for d in range(len(x.get_shape()))]
    shift=None,
    name=None,
    keep_dims=False
)
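A quick sketch (mine) of how the axes argument behaves: reducing over every axis gives scalar statistics, while for an NHWC conv feature map, standard BN would instead reduce over axes=[0, 1, 2] to get one mean/variance per channel:

import tensorflow as tf

sess = tf.Session()
x = tf.ones([4, 8, 8, 16])                                # hypothetical NHWC feature map
mean_all, var_all = tf.nn.moments(x, axes=[0, 1, 2, 3])   # reduce over everything
mean_ch,  var_ch  = tf.nn.moments(x, axes=[0, 1, 2])      # keep the channel axis
print(sess.run(mean_all).shape)   # () -- a scalar
print(sess.run(mean_ch).shape)    # (16,) -- one mean per channel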
3. Example

import tensorflow as tf

sess = tf.Session()
x = tf.constant([[1, 5], [10, 100]], dtype=tf.float32)

# Dimensions to reduce over
axes = [d for d in range(len(x.get_shape()))]

# beta and gamma parameters
beta  = tf.get_variable("beta",  shape=[], initializer=tf.constant_initializer(0.0))
gamma = tf.get_variable("gamma", shape=[], initializer=tf.constant_initializer(1.0))
sess.run(tf.global_variables_initializer())

# Compute mean and variance, then apply BN
x_mean, x_variance = tf.nn.moments(x, axes)
y = tf.nn.batch_normalization(x, x_mean, x_variance, beta, gamma, 1e-10, "bn")

# Inspect the results
y_mean, y_variance = tf.nn.moments(y, axes)
x_val, xm_val, xv_val, y_val, ym_val, yv_val = sess.run(
    [x, x_mean, x_variance, y, y_mean, y_variance])
print("*********Variable x before BN:************")
print("x=%s\n x mean=%s\n x variance=%s" % (x_val, xm_val, xv_val))
print("*********Variable y after BN:************")
print("y=%s \n y mean=%s\n y variance=%s" % (y_val, ym_val, yv_val))
The output is:
*********Variable x before BN:************
x=[[ 1. 5.]
[ 10. 100.]]
x mean=29.0
x variance=1690.5
*********Variable y after BN:************
y=[[-0.68100518 -0.58371872]
[-0.46211064 1.72683454]]
y mean=0.0
y variance=1.0
We can see that after BN, x becomes y, whose mean is 0 and variance is 1 (when beta is 0 and gamma is 1).
If we change the initial values of beta and gamma, y's mean becomes beta and its variance becomes gamma^2, as the sketch below illustrates.
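For example (my variation on the example above, with fresh variable names beta2/gamma2 so they do not clash with the earlier graph), beta = 3 and gamma = 2 should give y a mean of 3 and a variance of 4:

import tensorflow as tf

sess = tf.Session()
x = tf.constant([[1, 5], [10, 100]], dtype=tf.float32)
axes = [d for d in range(len(x.get_shape()))]
# beta=3 and gamma=2 instead of the defaults 0 and 1
beta  = tf.get_variable("beta2",  shape=[], initializer=tf.constant_initializer(3.0))
gamma = tf.get_variable("gamma2", shape=[], initializer=tf.constant_initializer(2.0))
sess.run(tf.global_variables_initializer())
x_mean, x_variance = tf.nn.moments(x, axes)
y = tf.nn.batch_normalization(x, x_mean, x_variance, beta, gamma, 1e-10, "bn2")
print(sess.run(tf.nn.moments(y, axes)))   # ~(3.0, 4.0): mean=beta, variance=gamma^2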
4. Where to place the BN layer
In a BN network, a convolutional or fully-connected layer applies three operations to its input x: the BN operation, the weight operation (convolution or matmul), and the ReLU operation. How should these three be ordered?
The original paper only says that "any layer that previously received x as the input now receives BN(x)"; but what about the sub-steps inside a single convolutional layer?
For a 2*conv16 => 2*conv32 => 2*conv64 => fc-10 network on MNIST, I tried the following three orderings (a small sketch of them follows the paper link below):
(1) x -> bn -> weight -> relu
(2) x -> bn -> relu -> weight
(3) x -> weight ->bn -> relu
All three turned out to work well, possibly because this dataset is too easy; it needs further testing.
In the ~1000-layer ResNet, however, the second ordering works better than the third (when shortcuts are present).
Paper: https://arxiv.org/pdf/1603.05027.pdf
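For concreteness, here is a minimal sketch (mine, not from the post's original code) of the three orderings as composable functions; bn and conv stand for any BN helper and weight layer (such as the self.bn and self.conv methods defined in VGG.PY below):

import tensorflow as tf

def order1(x, bn, conv, relu=tf.nn.relu):   # (1) x -> bn -> weight -> relu
    return relu(conv(bn(x)))

def order2(x, bn, conv, relu=tf.nn.relu):   # (2) x -> bn -> relu -> weight ("pre-activation")
    return conv(relu(bn(x)))

def order3(x, bn, conv, relu=tf.nn.relu):   # (3) x -> weight -> bn -> relu (the BN paper's placement)
    return relu(bn(conv(x)))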
5. BN on MNIST: a comparison
Because MNIST is too easy, adding a BN layer to an ordinary CNN makes no obvious difference.
So we make the model harder to train: replace ReLU with Sigmoid.
(Sigmoid makes training painfully slow, easily close to a hundred times slower; at first I thought the network was broken. ReLU really is powerful!)
All other settings are identical: a 784*100*100*100*10 fully-connected model, learning rate 1e-4, momentum = 0.9, L2 weight decay = 1e-4, batch size 50, trained for 10 epochs.
Training results without the BN layer:
[step100] accuracy=0.1 loss=116.691 [step200] accuracy=0.16 loss=114.826 [step300] accuracy=0.1 loss=115.051 [step400] accuracy=0.1 loss=117.023 [step500] accuracy=0.08 loss=115.734 [step600] accuracy=0.14 loss=114.13 [step700] accuracy=0.1 loss=115.985 [step800] accuracy=0.1 loss=115.7 [step900] accuracy=0.02 loss=117.614 [step1000] accuracy=0.1 loss=115.558 [*]Test Result=0.0892000000738 at epoch0
[step100] accuracy=0.08 loss=114.817 [step200] accuracy=0.22 loss=113.812 [step300] accuracy=0.1 loss=115.722 [step400] accuracy=0.04 loss=116.21 [step500] accuracy=0.14 loss=115.215 [step600] accuracy=0.08 loss=115.071 [step700] accuracy=0.14 loss=115.076 [step800] accuracy=0.06 loss=116.63 [step900] accuracy=0.12 loss=114.81 [step1000] accuracy=0.08 loss=115.669 [*]Test Result=0.100900000408 at epoch1
[step100] accuracy=0.1 loss=115.425 [step200] accuracy=0.1 loss=115.394 [step300] accuracy=0.08 loss=115.214 [step400] accuracy=0.04 loss=114.856 [step500] accuracy=0.08 loss=117.108 [step600] accuracy=0.14 loss=113.223 [step700] accuracy=0.08 loss=115.142 [step800] accuracy=0.16 loss=114.448 [step900] accuracy=0.1 loss=114.995 [step1000] accuracy=0.18 loss=115.651 [*]Test Result=0.113499999568 at epoch2
[step100] accuracy=0.12 loss=114.254 [step200] accuracy=0.08 loss=116.074 [step300] accuracy=0.2 loss=113.781 [step400] accuracy=0.08 loss=115.302 [step500] accuracy=0.06 loss=115.785 [step600] accuracy=0.08 loss=116.462 [step700] accuracy=0.08 loss=114.897 [step800] accuracy=0.14 loss=116.592 [step900] accuracy=0.1 loss=116.425 [step1000] accuracy=0.06 loss=114.058 [*]Test Result=0.103200000077 at epoch3
[step100] accuracy=0.26 loss=113.873 [step200] accuracy=0.08 loss=115.774 [step300] accuracy=0.14 loss=114.722 [step400] accuracy=0.1 loss=114.43 [step500] accuracy=0.12 loss=114.766 [step600] accuracy=0.08 loss=116.453 [step700] accuracy=0.02 loss=116.828 [step800] accuracy=0.06 loss=115.831 [step900] accuracy=0.14 loss=114.576 [step1000] accuracy=0.04 loss=114.588 [*]Test Result=0.113499999568 at epoch4
[step100] accuracy=0.24 loss=114.013 [step200] accuracy=0.1 loss=115.269 [step300] accuracy=0.08 loss=115.71 [step400] accuracy=0.18 loss=113.4 [step500] accuracy=0.14 loss=115.153 [step600] accuracy=0.08 loss=114.52 [step700] accuracy=0.12 loss=114.871 [step800] accuracy=0.22 loss=115.017 [step900] accuracy=0.12 loss=113.872 [step1000] accuracy=0.12 loss=115.084 [*]Test Result=0.171800000742 at epoch5
[step100] accuracy=0.12 loss=116.787 [step200] accuracy=0.1 loss=116.283 [step300] accuracy=0.04 loss=115.422 [step400] accuracy=0.14 loss=114.826 [step500] accuracy=0.18 loss=114.08 [step600] accuracy=0.14 loss=114.935 [step700] accuracy=0.18 loss=114.367 [step800] accuracy=0.02 loss=115.996 [step900] accuracy=0.08 loss=114.403 [step1000] accuracy=0.24 loss=113.339 [*]Test Result=0.113499999568 at epoch6
[step100] accuracy=0.18 loss=114.502 [step200] accuracy=0.12 loss=114.226 [step300] accuracy=0.14 loss=114.238 [step400] accuracy=0.28 loss=113.135 [step500] accuracy=0.04 loss=115.067 [step600] accuracy=0.16 loss=113.927 [step700] accuracy=0.1 loss=113.124 [step800] accuracy=0.06 loss=114.841 [step900] accuracy=0.16 loss=113.212 [step1000] accuracy=0.26 loss=112.934 [*]Test Result=0.199200000018 at epoch7
[step100] accuracy=0.16 loss=114.148 [step200] accuracy=0.12 loss=113.84 [step300] accuracy=0.14 loss=112.673 [step400] accuracy=0.2 loss=112.878 [step500] accuracy=0.2 loss=114.386 [step600] accuracy=0.12 loss=112.982 [step700] accuracy=0.38 loss=111.301 [step800] accuracy=0.3 loss=112.395 [step900] accuracy=0.52 loss=110.003 [step1000] accuracy=0.12 loss=111.22 [*]Test Result=0.122199999765 at epoch8
[step100] accuracy=0.08 loss=112.523 [step200] accuracy=0.42 loss=108.418 [step300] accuracy=0.4 loss=105.239 [step400] accuracy=0.5 loss=98.153 [step500] accuracy=0.22 loss=103.485 [step600] accuracy=0.2 loss=104.636 [step700] accuracy=0.48 loss=95.7585 [step800] accuracy=0.24 loss=94.8633 [step900] accuracy=0.38 loss=93.5662 [step1000] accuracy=0.36 loss=89.0528 [*]Test Result=0.351300003231 at epoch9
After 10 epochs, the test accuracy has only reached 35%.
Training results with the BN layer added:
[step100] accuracy=0.1 loss=116.102 [step200] accuracy=0.22 loss=112.854 [step300] accuracy=0.14 loss=115.377 [step400] accuracy=0.1 loss=115.649 [step500] accuracy=0.1 loss=115.625 [step600] accuracy=0.24 loss=114.879 [step700] accuracy=0.1 loss=115.61 [step800] accuracy=0.12 loss=114.699 [step900] accuracy=0.14 loss=115.097 [step1000] accuracy=0.1 loss=114.932 [*]Test Result=0.0974000002816 at epoch0
[step100] accuracy=0.1 loss=116.12 [step200] accuracy=0.06 loss=116.164 [step300] accuracy=0.1 loss=115.818 [step400] accuracy=0.12 loss=115.697 [step500] accuracy=0.18 loss=115.264 [step600] accuracy=0.2 loss=114.414 [step700] accuracy=0.04 loss=115.895 [step800] accuracy=0.12 loss=114.564 [step900] accuracy=0.06 loss=115.524 [step1000] accuracy=0.22 loss=114.622 [*]Test Result=0.161500000656 at epoch1
[step100] accuracy=0.2 loss=115.315 [step200] accuracy=0.14 loss=114.43 [step300] accuracy=0.1 loss=115.918 [step400] accuracy=0.16 loss=114.786 [step500] accuracy=0.26 loss=112.941 [step600] accuracy=0.3 loss=113.985 [step700] accuracy=0.3 loss=112.463 [step800] accuracy=0.14 loss=113.471 [step900] accuracy=0.14 loss=112.914 [step1000] accuracy=0.14 loss=112.23 [*]Test Result=0.24730000034 at epoch2
[step100] accuracy=0.2 loss=111.719 [step200] accuracy=0.32 loss=108.348 [step300] accuracy=0.24 loss=106.837 [step400] accuracy=0.36 loss=102.211 [step500] accuracy=0.32 loss=99.1392 [step600] accuracy=0.42 loss=94.0066 [step700] accuracy=0.5 loss=82.9231 [step800] accuracy=0.5 loss=78.0428 [step900] accuracy=0.56 loss=75.0709 [step1000] accuracy=0.56 loss=72.2615 [*]Test Result=0.569599996507 at epoch3
[step100] accuracy=0.54 loss=72.2187 [step200] accuracy=0.62 loss=62.6503 [step300] accuracy=0.7 loss=51.1989 [step400] accuracy=0.7 loss=50.0574 [step500] accuracy=0.62 loss=48.4715 [step600] accuracy=0.58 loss=56.4319 [step700] accuracy=0.76 loss=48.7727 [step800] accuracy=0.76 loss=39.0827 [step900] accuracy=0.66 loss=44.0735 [step1000] accuracy=0.74 loss=40.6393 [*]Test Result=0.731999999881 at epoch4
[step100] accuracy=0.82 loss=39.1621 [step200] accuracy=0.8 loss=33.0594 [step300] accuracy=0.68 loss=41.5027 [step400] accuracy=0.72 loss=49.6565 [step500] accuracy=0.8 loss=32.1081 [step600] accuracy=0.8 loss=42.5631 [step700] accuracy=0.84 loss=31.7484 [step800] accuracy=0.8 loss=34.406 [step900] accuracy=0.7 loss=36.0701 [step1000] accuracy=0.76 loss=39.4207 [*]Test Result=0.798400003314 at epoch5
[step100] accuracy=0.66 loss=38.2423 [step200] accuracy=0.88 loss=23.5632 [step300] accuracy=0.8 loss=37.7658 [step400] accuracy=0.8 loss=41.1382 [step500] accuracy=0.84 loss=31.7916 [step600] accuracy=0.86 loss=24.6395 [step700] accuracy=0.8 loss=29.7371 [step800] accuracy=0.84 loss=33.4366 [step900] accuracy=0.84 loss=25.56 [step1000] accuracy=0.92 loss=23.0958 [*]Test Result=0.841499999762 at epoch6
[step100] accuracy=0.9 loss=17.4944 [step200] accuracy=0.74 loss=35.0277 [step300] accuracy=0.9 loss=30.2663 [step400] accuracy=0.78 loss=34.679 [step500] accuracy=0.82 loss=25.4055 [step600] accuracy=0.86 loss=19.0345 [step700] accuracy=0.98 loss=14.34 [step800] accuracy=0.86 loss=27.425 [step900] accuracy=0.78 loss=35.237 [step1000] accuracy=0.88 loss=23.2125 [*]Test Result=0.86880000174 at epoch7
[step100] accuracy=0.88 loss=23.3765 [step200] accuracy=0.82 loss=33.0606 [step300] accuracy=0.76 loss=44.3354 [step400] accuracy=0.9 loss=17.5737 [step500] accuracy=0.82 loss=27.3082 [step600] accuracy=0.92 loss=18.8941 [step700] accuracy=0.84 loss=27.9557 [step800] accuracy=0.9 loss=16.8646 [step900] accuracy=0.92 loss=12.3513 [step1000] accuracy=0.9 loss=22.4553 [*]Test Result=0.886900002956 at epoch8
After 9 epochs the accuracy reaches roughly 88%. A rough estimate: to reach the same 35% accuracy, the network without BN needs 10 epochs while the BN network needs about 3.4 epochs.
That is a speedup of roughly 3x, about the same as the speedup the paper reports for BN-Baseline versus Inception; tuning the learning rate should make it faster still.
Code:
###VGG.PY#########

import tensorflow as tf

"""
(1) Constructor (__init__) parameters:
    input_sz:  4-D shape of the input placeholder, e.g. [None,28,28,1] for MNIST
    fc_layers: size of each fully-connected layer, appended after the conv
               layers; e.g. [128,84,10] or [10] for MNIST
    conv_info: convolution and pooling blocks; e.g. VGG16 is
               [(2,64),(2,128),(3,256),(3,512),(3,512)], i.e. 2+2+3+3+3=13 conv
               layers and 5 pooling layers, with the given channel counts
(2) train: runs one training step
    batch_input:   the input batch
    batch_output:  the labels
    learning_rate: the learning rate
    Returns the accuracy and loss (float) as {"accuracy":accuracy,"loss":loss}
(3) forward: inference after training
(4) save(save_path,steps): saves the model
(5) restore(path): restores the latest model from a directory
(6) The loss is the one-hot cross-entropy version: y*log(y_net)
(7) The optimizer is MomentumOptimizer (Adam and plain SGD are commented out
    in get_optimizer)
"""

class VGG:
    # VGG classifier
    sess=None
    # Tensors
    input=None
    output=None
    desired_out=None
    loss=None
    iscorrect=None
    accuracy=None
    optimizer=None
    param_num=0   # number of parameters

    # Hyperparameters
    learning_rate=None
    MOMENTUM = 0.9
    WEIGHT_DECAY = 1e-4   # L2 regularization
    ACTIVATE = None
    CONV_PADDING = "SAME"
    MAX_POOL_PADDING = "SAME"
    CONV_WEIGHT_INITAILIZER = tf.truncated_normal_initializer(stddev=0.1)
    CONV_BIAS_INITAILIZER = tf.constant_initializer(value=0.0)
    FC_WEIGHT_INITAILIZER = tf.truncated_normal_initializer(stddev=0.1)
    FC_BIAS_INITAILIZER = tf.constant_initializer(value=0.0)

    def train(self,batch_input,batch_output,learning_rate):
        _,accuracy,loss=self.sess.run([self.optimizer,self.accuracy,self.loss],
            feed_dict={self.input:batch_input,self.desired_out:batch_output,self.learning_rate:learning_rate})
        return {"accuracy":accuracy,"loss":loss}

    def forward(self,batch_input):
        return self.sess.run(self.output,feed_dict={self.input:batch_input})

    def save(self,save_path,steps):
        saver=tf.train.Saver(max_to_keep=5)
        saver.save(self.sess,save_path,global_step=steps)

    def restore(self,restore_path):
        path=tf.train.latest_checkpoint(restore_path)
        print("[*]Restore from %s" %(path))
        if path==None:
            return False
        saver=tf.train.Saver(max_to_keep=5)
        saver.restore(self.sess,path)
        return True

    def bn(self,x,name="bn"):
        #return x   # uncomment to disable BN
        axes = [d for d in range(len(x.get_shape()))]
        beta = self._get_variable("beta", shape=[],initializer=tf.constant_initializer(0.0))
        gamma= self._get_variable("gamma",shape=[],initializer=tf.constant_initializer(1.0))
        x_mean,x_variance=tf.nn.moments(x,axes)
        y=tf.nn.batch_normalization(x,x_mean,x_variance,beta,gamma,1e-10,name)
        return y

    def get_optimizer(self):
        #self.optimizer=tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss)
        #self.optimizer=tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)   # reaches the target error range after ~1300 steps
        self.optimizer=tf.train.MomentumOptimizer(self.learning_rate,self.MOMENTUM).minimize(self.loss)   # reaches the target error range after ~9000 steps

    # One convolution (+ bias) applied to x
    def conv(self,x,name,channels,ksize=3):
        x_shape=x.get_shape()
        x_channels=x_shape[3].value
        weight_shape=[ksize,ksize,x_channels,channels]
        bias_shape=[channels]
        weight=self._get_variable("weight",weight_shape,initializer=self.CONV_WEIGHT_INITAILIZER)
        bias  =self._get_variable("bias",bias_shape,initializer=self.CONV_BIAS_INITAILIZER)
        y=tf.nn.conv2d(x,weight,strides=[1,1,1,1],padding=self.CONV_PADDING,name=name)
        y=tf.add(y,bias,name=name)
        return y

    def max_pool(self,x,name):
        return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding=self.MAX_POOL_PADDING,name=name)

    # _get_variable wraps variable creation so that L2 regularization and
    # parameter counting happen in one place
    def _get_variable(self,name,shape,initializer):
        param=1
        for i in range(0,len(shape)):
            param*=shape[i]
        self.param_num+=param
        if self.WEIGHT_DECAY>0:
            regularizer=tf.contrib.layers.l2_regularizer(self.WEIGHT_DECAY)
        else:
            regularizer=None
        return tf.get_variable(name,shape=shape,initializer=initializer,regularizer=regularizer)

    def fc(self,x,num,name):
        x_num=x.get_shape()[1].value
        weight_shape=[x_num,num]
        bias_shape =[num]
        weight=self._get_variable("weight",shape=weight_shape,initializer=self.FC_WEIGHT_INITAILIZER)
        bias  =self._get_variable("bias",shape=bias_shape,initializer=self.FC_BIAS_INITAILIZER)
        y=tf.add(tf.matmul(x,weight),bias,name=name)
        return y

    def _loss(self):
        cross_entropy=-tf.reduce_sum(self.desired_out*tf.log(tf.clip_by_value(self.output,1e-10,1.0)))
        regularization_losses=tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        self.loss=tf.add_n([cross_entropy]+regularization_losses)
        #tf.scalar_summary('loss', loss_)
        return self.loss

    def __init__(self,input_sz,fc_layers,conv_info=[],activate_fun=tf.nn.relu):
        self.ACTIVATE=activate_fun
        self.param_num=0
        self.sess=tf.Session()
        layers=[]

        # (1) Placeholders (input, output, learning_rate)
        self.input=tf.placeholder(tf.float32,input_sz,name="input")
        layers.append(self.input)
        # layers.append(self.bn(layers[-1]))
        output_sz=[None,fc_layers[-1]]
        self.desired_out=tf.placeholder(tf.float32,output_sz,name="desired_out")
        self.learning_rate=tf.placeholder(tf.float32,name="learning_rate")

        # (2) Convolution + pooling blocks
        with tf.variable_scope("convolution"):
            conv_block_id=0
            for cur_layers in conv_info:
                # add one block of stacked conv layers
                with tf.variable_scope("conv_block_%d" %(conv_block_id)) as scope:
                    cur_conv_num=cur_layers[0]   # number of stacked conv layers
                    cur_channels=cur_layers[1]   # channels of each conv layer
                    for conv_id in range(0,cur_conv_num):
                        with tf.variable_scope("conv_%d" %(conv_id)):
                            x=layers[-1]
                            """
                            # Order 1: x -> bn -> weight -> relu
                            x2=self.bn(x)
                            x3=self.conv(x2,channels=cur_channels,name="conv")
                            x4=self.ACTIVATE(x3)
                            """
                            #"""
                            # Order 2: x -> bn -> relu -> weight
                            x2=self.bn(x)
                            x3=self.ACTIVATE(x2)
                            x4=self.conv(x3,channels=cur_channels,name="conv")
                            #"""
                            """
                            # Order 3: x -> weight -> bn -> relu
                            x2=self.conv(x,channels=cur_channels,name="conv")
                            x3=self.bn(x2)
                            x4=self.ACTIVATE(x3)
                            """
                            layers.append(x4)
                    # each conv block is followed by a pooling layer
                    last_layer=layers[-1]
                    pool=self.max_pool(last_layer,"max_pool")
                    layers.append(pool)
                conv_block_id+=1

        # (3) Flatten the conv output
        last_layer=layers[-1]
        last_shape=last_layer.get_shape()
        neu_num=1
        for dim in range(1,len(last_shape)):
            neu_num*=last_shape[dim].value
        flat_layer=tf.reshape(last_layer,[-1,neu_num],name="flatten")
        layers.append(flat_layer)

        # (4) Fully-connected layers
        # !!! do not apply the activation after the last layer !!!
        with tf.variable_scope("full_connection"):
            for fc_id in range(0,len(fc_layers)):
                with tf.variable_scope("fc_%d" %(fc_id)):
                    num=fc_layers[fc_id]
                    x=layers[-1]
                    x2=self.bn(x)
                    x3=self.ACTIVATE(x2,name="relu")
                    y=self.fc(x3,num,"fc")
                    layers.append(y)

        # (5) Softmax and the loss function
        self.output=tf.nn.softmax(layers[-1])
        self._loss()

        # (6) Bookkeeping: accuracy
        self.iscorrect=tf.equal(tf.argmax(self.desired_out,1),tf.argmax(self.output,1),name="iscorrect")
        self.accuracy=tf.reduce_mean(tf.cast(self.iscorrect,dtype=tf.float32),name="accuracy")

        # (7) Optimizer and variable initialization
        self.get_optimizer()
        self.sess.run(tf.global_variables_initializer())
        writer=tf.summary.FileWriter("./tboard/",self.sess.graph)

    def __del__(self):
        self.sess.close()
####VGG_MNIST.PY####

import VGG
import tensorflow as tf
import sys
from tensorflow.examples.tutorials.mnist import input_data

vgg=VGG.VGG([None,28,28,1],[100,100,100,10],activate_fun=tf.sigmoid)   #,[(3,16),(3,32),(3,64),(3,128)])
#vgg=VGG.VGG([None,28,28,1],[10],[(2,16),(2,32),(2,64)])
print("param_num=%d" %(vgg.param_num))
#writer = tf.summary.FileWriter("./tboard/",vgg.sess.graph)

mnist = input_data.read_data_sets("input_data", one_hot=True)

# Reshape a batch of flat 784-vectors into [28,28,1] images
def get_mnist_batch(num,get_test=False):
    batch=None
    if get_test:
        batch=[mnist.test.images,mnist.test.labels]
    else:
        batch=mnist.train.next_batch(num)
    input=[]
    for x in batch[0]:
        inp=[[0 for _ in range(0,28)] for _ in range(0,28)]
        for row in range(0,28):
            for col in range(0,28):
                inp[row][col]=[x[row*28+col]]
        """
        # Debug: dump the digit as ASCII art
        if inp[row][col][0]>0.6:
            print(" ",end="")
        else:
            if inp[row][col][0]>0.3:
                print(".",end="")
            else:
                print("w",end="")
        if col==27:
            print("")
        sys.exit(0)
        """
        input.append(inp)
    return input,batch[1]

# Evaluate test accuracy in chunks of 100; learning_rate=0 so the
# weights are not updated
def get_mnist_test_accuracy():
    batch=get_mnist_batch(0,True)
    accuracy=0
    for st in range(0,10000,100):
        ret=vgg.train(batch[0][st:st+100],batch[1][st:st+100],learning_rate=0)
        accuracy+=ret["accuracy"]/100
    return accuracy

"""
if vgg.restore("./model/"):
    test_acc=get_mnist_test_accuracy()
    print("[*]Test Result=%s at epoch%d" %(test_acc,0))
"""

learning_rate=1e-4
for epoch in range(0,10):
    batch_sz=50
    for i in range(int(50000/batch_sz)):
        batch=get_mnist_batch(batch_sz)
        ret=vgg.train(batch[0],batch[1],learning_rate=learning_rate)
        if i%100==0:
            #print(batch[1][0])
            #print(ret[0][0])
            print("[step%d] accuracy=%s loss=%s" %(i+100,ret["accuracy"],ret["loss"]))
    #learning_rate/=2
    vgg.save("model/mnist_epoch",epoch)
    test_acc=get_mnist_test_accuracy()
    print("[*]Test Result=%s at epoch%d" %(test_acc,epoch))