TensorFlow training with batch normalization and learning-rate decay
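The full script below trains a small MobileNet-style binary classifier: stride-2 and depthwise-separable convolutions, each followed by batch normalization and ReLU, fed with 96x96 crops that are randomly flipped, rescaled, and cropped for augmentation. The learning rate drops in stages over the epochs, and a checkpoint is written after every epoch.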

import tensorflow as tf
import numpy as np
import cv2 
import random

batchsize = 50

#Load the data
#------------------------------------------------------------------------------
import pickle
def load_obj(name):
    with open(name, 'rb') as f:
        return pickle.load(f)

#dict_read = load_obj('AttackDict_S.dat')
dict_read = load_obj('AttackDict_M.dat')

tTrainN = len(dict_read['TrainLab'])
tValidN = len(dict_read['ValidLab'])

x_train = np.zeros( (tTrainN,112,112,3))
y_train = np.zeros( (tTrainN,2))
x_valid = np.zeros( (tValidN,96,96,3))
y_valid = np.zeros( (tValidN,2))
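# Note the shapes: training images are kept at their raw 112x112 size and are
# cropped to 96x96 on the fly by the augmentation in GetTrainBatch; validation
# images are center-cropped to 96x96 once, up front.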

for i in range(tValidN):
    y_valid[i][dict_read['ValidLab'][i]] = 1          # one-hot label
    img = cv2.imread(dict_read['ValidName'][i])
    img = img[8:8+96, 8:8+96]                         # center crop 112 -> 96
    x_valid[i] = (img/127.5)-1                        # scale to [-1, 1]

for i in range(tTrainN):
    y_train[i][dict_read['TrainLab'][i]] = 1          # one-hot label
    img = cv2.imread(dict_read['TrainName'][i])
    x_train[i] = img                                  # raw pixels; normalized in GetTrainBatch

TrainBatchN = tTrainN//batchsize   # full batches per epoch; any remainder is dropped
ValidBatchN = tValidN//batchsize

ValidP = 0
TrainP = 0
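# ValidP and TrainP are read cursors: each call slices the next sequential
# batch, and the training set is reshuffled whenever its cursor wraps.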

def GetValidBatch():
    global ValidP

    # Wrap before the slice would run short (tValidN may not be a multiple
    # of batchsize).
    if ValidP + batchsize > tValidN:
        ValidP = 0

    OutY = y_valid[ValidP:ValidP+batchsize]
    OutX = x_valid[ValidP:ValidP+batchsize]

    ValidP = ValidP + batchsize

    return OutX, OutY

def GetTrainBatch():
    global TrainP
    global y_train
    global x_train

    # Wrap before the slice would run short: a short OutXt would make the
    # fixed-size loop below raise an IndexError. Reshuffle on every pass.
    if TrainP + batchsize > tTrainN:
        TrainP = 0
        li = list(range(tTrainN))
        random.shuffle(li)
        y_train = y_train[li]
        x_train = x_train[li]

    OutY = y_train[TrainP:TrainP+batchsize]
    OutXt = x_train[TrainP:TrainP+batchsize]
    OutX = np.zeros((batchsize, 96, 96, 3))

    for i in range(batchsize):
        img = OutXt[i]

        # Horizontal flip with ~50% probability
        r1 = random.randint(0, 100)
        if r1 > 50:
            img = cv2.flip(img, 1, dst=None)

        # Random scale: resize the 112x112 image to r1 x r1, r1 in [96, 118]
        r1 = random.randint(96, 118)
        img = cv2.resize(img, (r1, r1))
        tz = r1 - 96

        # Random 96x96 crop out of the rescaled image
        r1 = random.randint(0, tz)
        r2 = random.randint(0, tz)
        img = img[r1:96+r1, r2:96+r2]
        OutX[i] = (img/127.5)-1                       # scale to [-1, 1]

    TrainP = TrainP + batchsize

    return OutX,OutY

#------------------------------------------------------------------------------
x = tf.placeholder(tf.float32, [None, 96, 96, 3], name='input')

# Feed True while training and False at eval time, so batch normalization
# switches between batch statistics and its moving averages.
is_training = tf.placeholder(tf.bool)

def conv2d(input, kernel_size, filters, stride):
    if stride == 2 and kernel_size == 3:
        # space_to_batch_nd with block_shape [1, 1] does no batch rearrangement;
        # it is used here purely to zero-pad each spatial side by one pixel, so
        # the following VALID 3x3/stride-2 conv halves the feature map exactly.
        out = tf.space_to_batch_nd(input, [1, 1], [[1, 1], [1, 1]])
        out = tf.layers.conv2d(out, filters, kernel_size, [stride, stride], 'valid', use_bias=True)
    else:
        out = tf.layers.conv2d(input, filters, kernel_size, [stride, stride], 'same', use_bias=True)
    out = tf.layers.batch_normalization(out, training=is_training)
    out = tf.nn.relu(out)
    return out

def separable_conv2d(input, kernel_size, filters, stride):
    # Depthwise-separable conv (as in MobileNet); same padding trick as conv2d
    if stride == 2 and kernel_size == 3:
        out = tf.space_to_batch_nd(input, [1, 1], [[1, 1], [1, 1]])
        out = tf.layers.separable_conv2d(out, filters, kernel_size, [stride, stride], 'valid', use_bias=True)
    else:
        out = tf.layers.separable_conv2d(input, filters, kernel_size, [stride, stride], 'same', use_bias=True)
    out = tf.layers.batch_normalization(out, training=is_training)
    out = tf.nn.relu(out)
    return out

def gap(input):
    # Global average pooling (defined but unused below; the model flattens directly)
    shape = input.get_shape().as_list()
    size = [shape[1], shape[2]]
    out = tf.layers.average_pooling2d(input, size, size, 'same')
    out = tf.layers.flatten(out)
    return out

def fc(input, units):
    out = tf.layers.dense(input, units, use_bias=True)   # raw logits; no BN on the head
    #out = tf.layers.batch_normalization(out, training=is_training)
    return out

# Backbone: spatial size 96 -> 48 -> 24 -> 12 -> 6, channels 16 -> 256
out = conv2d(x, 3, 16, 2)
out = separable_conv2d(out, 3, 16, 1)
out = separable_conv2d(out, 3, 32, 2)
out = separable_conv2d(out, 3, 32, 1)
out = separable_conv2d(out, 3, 64, 2)
out = separable_conv2d(out, 3, 64, 1)
out = separable_conv2d(out, 3, 128, 2)
for i in range(5):
    out = separable_conv2d(out, 3, 128, 1)
out = separable_conv2d(out, 3, 256, 1)
print(out)                            # debug: show the final 6x6x256 feature map
out = tf.layers.flatten(out)          # 6*6*256 = 9216 features
out = fc(out, 2)
out = tf.identity(out, name='logits')
prediction = tf.nn.softmax(out,name='prediction')
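# Naming 'input', 'logits' and 'prediction' makes these tensors easy to look
# up by name later, e.g. when the graph is frozen or re-imported.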

y = tf.placeholder(tf.float32,[None,2])

Lrate = tf.placeholder(tf.float32)
# Staged decay: 0.1 until epoch 20, 0.01 until epoch 80, 0.001 after that.
# (With only 41 epochs below, the 0.001 stage is never reached.)
LrateEpoch = [20, 80]
LrateChange = [0.1, 0.01, 0.001]

#Cross-entropy loss
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=out))

# The batch-norm moving-average updates live in UPDATE_OPS; attach them as
# control dependencies of the train op so they run on every training step.
# (Wrapping only the loss, as the original did, does not guarantee this.)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    #Plain SGD is used; Adam is left available but commented out
    #train_step = tf.train.AdamOptimizer(Lrate).minimize(cross_entropy)
    train_step = tf.train.GradientDescentOptimizer(Lrate).minimize(cross_entropy)

#correct_prediction is a boolean vector; argmax returns the index of the
#largest entry along axis 1
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
#Accuracy = mean of the boolean vector cast to float
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

print(x)
print(y)
print(prediction)

saver = tf.train.Saver()  #Saver for checkpointing the model
import os
ckpt_dir = "./ckpt_dir"
os.makedirs(ckpt_dir, exist_ok=True)  # saver.save fails if the directory is missing

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    #saver.restore(sess, ckpt_dir+"/tbn.ckpt-29")   # resume from a checkpoint
    for epoch in range(41):
        # Pick the learning rate for this stage of the schedule
        if epoch >= LrateEpoch[1]:
            lr = LrateChange[2]
        elif epoch >= LrateEpoch[0]:
            lr = LrateChange[1]
        else:
            lr = LrateChange[0]
        print("lr=", lr)

        gloss = 0   # reset the running loss each epoch
        for i in range(TrainBatchN):
            batch_x, batch_y = GetTrainBatch()
            sess.run(train_step, feed_dict={x: batch_x, y: batch_y, is_training: True, Lrate: lr})
            # Loss is re-evaluated in inference mode (BN uses moving averages)
            t = sess.run(cross_entropy, feed_dict={x: batch_x, y: batch_y, is_training: False})
            gloss = gloss + t
            if 0 == i % 100:
                print("batch = ", i, "loss = ", t)

        gloss = gloss/TrainBatchN
        print("epoch = ", epoch, "loss = ", gloss)

        acc = 0
        for i in range(ValidBatchN):
            batch_x, batch_y = GetValidBatch()
            t = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y, is_training: False})
            acc = acc + t
        acc = acc/ValidBatchN
        print("V ACC = ", acc)

        saver.save(sess, ckpt_dir+"/tbn.ckpt", epoch)   #checkpoint after every epoch
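
For reference, here is a minimal inference sketch (not in the original post). It assumes the graph-building code above has already been run in the same process, so x, prediction, and is_training exist; the checkpoint suffix "-40" and the file test.jpg are hypothetical placeholders.

# Hedged sketch: restore a checkpoint and classify one image.
with tf.Session() as sess:
    saver.restore(sess, ckpt_dir + "/tbn.ckpt-40")   # hypothetical epoch suffix

    img = cv2.imread('test.jpg')                     # hypothetical 112x112 test image
    img = img[8:8+96, 8:8+96]                        # same center crop as validation
    batch = ((img/127.5) - 1)[np.newaxis]            # scale to [-1, 1], shape (1, 96, 96, 3)

    p = sess.run(prediction, feed_dict={x: batch, is_training: False})
    print("probabilities:", p[0], "predicted class:", int(np.argmax(p[0])))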

Reprinted from blog.csdn.net/masbbx123/article/details/86470759