BNN Binary-Weight Network Test Code

The binary_layer.py that circulates widely online implements a binary-weight network; practically every copy on GitHub is this same file.

Who originally wrote it is unclear at this point, but everyone uses it.

Its main content is the Dense_BinaryLayer class. The file itself is not reproduced here; a minimal sketch of the core idea it implements is given below.
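The sketch shows the trick such a layer is usually built on: deterministic sign binarization of the weights in the forward pass, with a straight-through estimator (identity gradient) in the backward pass, in the style of BinaryConnect. The names hard_sigmoid and binarize are illustrative; the actual binary_layer.py may differ in detail.

import tensorflow as tf

def hard_sigmoid(x):
    # piecewise-linear sigmoid approximation, clipped to [0, 1]
    return tf.clip_by_value((x + 1.) / 2., 0., 1.)

def binarize(W):
    # forward pass: Wb in {-1, +1}; backward pass: identity gradient to W
    Wb = 2. * tf.round(hard_sigmoid(W)) - 1.   # {0, 1} -> {-1, +1}
    return W + tf.stop_gradient(Wb - W)        # straight-through estimator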

The code below follows the earlier posts in this series and builds the same NN:

input: 28*28 = 784

hidden layer: 500

output: 10

import tensorflow as tf
import binary_layer as binary
import numpy as np
import os
import cv2 as cv
import time

def fully_connect_bn(pre_layer, output_dim, act, use_bias, training):
    pre_act = binary.dense_binary(pre_layer, output_dim,
                                    use_bias = use_bias,
                                    activation = None,
                                    kernel_constraint = lambda w: tf.clip_by_value(w, -1.0, 1.0))
    bn = binary.batch_normalization(pre_act, momentum=0.9, epsilon=1e-4, training=training)
    if act is None:
        output = bn
    else:
        output = act(bn)
    return output

def shuffle(X,y):
    shuffle_parts = 1
    chunk_size = int(len(X) / shuffle_parts)
    shuffled_range = np.arange(chunk_size)

    X_buffer = np.copy(X[0:chunk_size])
    y_buffer = np.copy(y[0:chunk_size])

    for k in range(shuffle_parts):

        np.random.shuffle(shuffled_range)

        for i in range(chunk_size):

            X_buffer[i] = X[k * chunk_size + shuffled_range[i]]
            y_buffer[i] = y[k * chunk_size + shuffled_range[i]]

        X[k * chunk_size:(k + 1) * chunk_size] = X_buffer
        y[k * chunk_size:(k + 1) * chunk_size] = y_buffer

    return X,y
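
# Note: with shuffle_parts == 1 the loop above amounts to a single random
# permutation. Assuming X and y were numpy arrays, an equivalent one-liner
# would be:
#
#     perm = np.random.permutation(len(X))
#     return X[perm], y[perm]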

def train():
    x = tf.placeholder(tf.float32, shape=[None, 784])
    target = tf.placeholder(tf.float32, shape=[None, 10])
    training = tf.placeholder(tf.bool)

    # Training set: images are assumed to be named "<label>_xxx.png"
    train_path = 'E:\\[1]Paper\\Datasets\\MINST\\train'
    file_list = os.listdir(train_path)
    train_data = []
    train_label = []

    for filename in file_list:
        filepath = '%s\\%s' % (train_path, filename)
        img = cv.imread(filepath, 0)   # read as grayscale
        if img is None:
            continue
        img = img / 255                # scale pixels to [0, 1]
        rows, cols = img.shape
        img = img.reshape((rows * cols))
        train_data.append(img)         # flattened 1-D image
        labels = [0] * 10
        labels[int(filename.split('_')[0])] = 1
        train_label.append(labels)     # one-hot label

    print('train data loaded!\n')
 
    # Validation set
    val_path = 'E:\\[1]Paper\\Datasets\\MINST\\query'
    file_list = os.listdir(val_path)
    val_data = []
    val_label = []

    for filename in file_list:
        filepath = '%s\\%s' % (val_path, filename)
        img = cv.imread(filepath, 0)   # read as grayscale
        if img is None:
            continue
        img = img / 255                # scale pixels to [0, 1]
        rows, cols = img.shape
        img = img.reshape((rows * cols))
        val_data.append(img)           # flattened 1-D image
        labels = [0] * 10
        labels[int(filename.split('_')[0])] = 1
        val_label.append(labels)       # one-hot label

    print('validation data loaded!\n')

    fc1 = fully_connect_bn(x, 500, act=binary.binary_tanh_unit, use_bias=True, training=training)
    fc2 = fully_connect_bn(fc1, 10, act=None, use_bias=True, training=training)

    loss = tf.reduce_mean(tf.square(tf.maximum(0.,1. - target * fc2)))
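    # Squared hinge loss, as in BinaryConnect/BinaryNet. Note that those papers
    # use targets in {-1, +1}; here `target` is a 0/1 one-hot vector, so every
    # entry with target == 0 contributes a constant 1 to the loss (and no
    # gradient), and only the true-class logit is pushed above the margin.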

    train_num = 5000
    batch_size = 64
    epochs = 1000
    lr_start = 0.001
    lr_end = 0.00001
    lr_decay = (lr_end / lr_start) ** (1. / epochs)
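    # With lr_start = 1e-3 and lr_end = 1e-5 over 1000 epochs,
    # lr_decay = 0.01 ** (1/1000) ≈ 0.9954. decay_steps is one epoch's worth of
    # batches (5000 // 64 = 78), so the learning rate is multiplied by lr_decay
    # once per epoch and reaches lr_end by the final epoch.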
    global_step1 = tf.Variable(0, trainable=False)
    global_step2 = tf.Variable(0, trainable=False)
    lr1 = tf.train.exponential_decay(lr_start, global_step=global_step1, decay_steps=int(train_num / batch_size), decay_rate=lr_decay)
    lr2 = tf.train.exponential_decay(lr_start, global_step=global_step2, decay_steps=int(train_num / batch_size), decay_rate=lr_decay)

    sess = tf.Session()
    saver = tf.train.Saver()
    #saver.restore(sess, "model/model.ckpt")
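    # Two optimizers: binary_layer's AdamOptimizer updates the binarized kernels
    # (in the common implementation it applies per-layer learning-rate scales
    # from get_all_LR_scale() and keeps the real-valued weights clipped), while
    # a plain Adam trains everything else (biases and batch-norm parameters).
    # The filter below selects every trainable variable that is not a kernel.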
    other_var = [var for var in tf.trainable_variables() if not var.name.endswith('kernel:0')]
    opt = binary.AdamOptimizer(binary.get_all_LR_scale(), lr1)
    opt2 = tf.train.AdamOptimizer(lr2)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):   # when training, the moving_mean and moving_variance in the BN need to be updated
        train_kernel_op = opt.apply_gradients(binary.compute_grads(loss, opt),  global_step=global_step1)
        train_other_op = opt2.minimize(loss, var_list=other_var,  global_step=global_step2)

    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(fc2, 1), tf.argmax(target, 1)), tf.float32))
    sess.run(tf.global_variables_initializer())

    old_acc = 0.0
    for i in range(epochs):
        X_train, Y_train = shuffle(train_data, train_label)
        batches = int(train_num / batch_size)
        for b in range(batches):
            info = sess.run([loss, train_kernel_op, train_other_op],
                     feed_dict={x: X_train[b * batch_size:(b + 1) * batch_size],
                                target: Y_train[b * batch_size:(b + 1) * batch_size],
                                training: True})

        hist = sess.run([accuracy, opt._lr], feed_dict={x: val_data, target: val_label, training: False})
        print('epoch {0}: acc={1:.4f}, lr={2:.6f}'.format(i, hist[0], hist[1]))

        if hist[0] > old_acc:
            old_acc = hist[0]
            save_path = saver.save(sess, "./bnn_model.ckpt")

if __name__ == '__main__':
    train()
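
The comparison below was presumably produced by timing the validation pass (the time import above is otherwise unused in the listing). A minimal sketch of such a measurement, reusing the names from train():

start = time.time()
acc = sess.run(accuracy, feed_dict={x: val_data, target: val_label, training: False})
print('val acc = {0:.3f}, time = {1:.2f}s'.format(acc, time.time() - start))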
Results on the validation set:

                   Validation time    Validation accuracy
Binary weights     0.03 s             0.855
Full precision     0.14 s             0.871

The binary-weight network runs the validation pass roughly 4-5x faster, at a cost of about 1.6 percentage points of accuracy.


Reposted from blog.csdn.net/XLcaoyi/article/details/94204767