网上流传甚广的binary_layer.py,实现了二值权重网络,github上基本上都是这个文件。
也不知道最初作者是谁了,反正大家都在用。
主要实现了Dense_BinaryLayer这个类,这里就不上传了。
下面根据前面的博文,构造了结构相同的 NN 网络:
- 输入层:28×28 = 784
- 隐藏层:500
- 输出层:10
import tensorflow as tf
import binary_layer as binary
import numpy as np
import os
import cv2 as cv
import time
def fully_connect_bn(pre_layer, output_dim, act, use_bias, training):
    """Binary-weight dense layer followed by batch norm and optional activation.

    Args:
        pre_layer: input tensor of shape (batch, in_dim).
        output_dim: number of output units.
        act: activation callable applied after batch norm, or None for a
            linear output.
        use_bias: whether the dense layer adds a bias term.
        training: bool placeholder/tensor; selects batch statistics (True)
            vs. moving statistics (False) in batch normalization.

    Returns:
        Output tensor of shape (batch, output_dim).
    """
    # Real-valued shadow weights are clipped to [-1, 1]; the binarization
    # itself happens inside binary.dense_binary.
    pre_act = binary.dense_binary(
        pre_layer, output_dim,
        use_bias=use_bias,
        activation=None,
        kernel_constraint=lambda w: tf.clip_by_value(w, -1.0, 1.0))
    bn = binary.batch_normalization(pre_act, momentum=0.9, epsilon=1e-4,
                                    training=training)
    # `is None`, not `== None`: `==` on a TF tensor builds an equality op
    # instead of testing for the sentinel.
    return bn if act is None else act(bn)
def shuffle(X, y):
    """Shuffle X and y in unison, in place, and return them.

    A single random permutation is applied to both sequences so that
    X[i]/y[i] pairs stay aligned. Works on lists or numpy arrays; the
    inputs are mutated via slice assignment (matching the original
    behavior) and also returned for convenience.

    Args:
        X: sequence of samples.
        y: sequence of labels, same length as X.

    Returns:
        The (shuffled) X and y.
    """
    n = len(X)
    perm = np.random.permutation(n)
    # Materialize copies first so fancy indexing cannot alias the
    # destination during the slice assignment below.
    X_buf = np.copy(X[0:n])
    y_buf = np.copy(y[0:n])
    X[0:n] = X_buf[perm]
    y[0:n] = y_buf[perm]
    return X, y
def _load_dataset(path):
    """Load a directory of digit images as (data, one-hot labels).

    Each file name is expected to start with the digit label followed by an
    underscore (e.g. '7_0013.png'). Images are read as grayscale, scaled to
    [0, 1], and flattened to 1-D. Non-image files are skipped.

    Args:
        path: directory containing the image files.

    Returns:
        (data, labels): list of flat float arrays and list of 10-element
        one-hot label lists.
    """
    data, labels = [], []
    for filename in os.listdir(path):
        img = cv.imread(os.path.join(path, filename), 0)
        if img is None:  # not an image file
            continue
        img = img / 255
        rows, cols = img.shape
        data.append(img.reshape(rows * cols))  # flatten to 1-D
        one_hot = [0] * 10
        one_hot[int(filename.split('_')[0])] = 1
        labels.append(one_hot)
    return data, labels


def train():
    """Train a binary-weight 784-500-10 MLP on MNIST image folders.

    Builds the TF1 graph, trains for a fixed number of epochs with two Adam
    optimizers (the custom binary one for kernels, a plain one for the rest),
    and checkpoints the model whenever validation accuracy improves.
    """
    x = tf.placeholder(tf.float32, shape=[None, 784])
    target = tf.placeholder(tf.float32, shape=[None, 10])
    training = tf.placeholder(tf.bool)  # batch-norm mode switch

    # Training and validation sets (folder layout per the blog post).
    train_data, train_label = _load_dataset('E:\\[1]Paper\\Datasets\\MINST\\train')
    print('train data load!\n')
    val_data, val_label = _load_dataset('E:\\[1]Paper\\Datasets\\MINST\\query')
    print('test data load!\n')

    fc1 = fully_connect_bn(x, 500, act=binary.binary_tanh_unit, use_bias=True, training=training)
    fc2 = fully_connect_bn(fc1, 10, act=None, use_bias=True, training=training)
    # Squared hinge loss. NOTE(review): targets here are one-hot 0/1 rather
    # than +-1, so non-target logits only contribute a constant term with
    # zero gradient — confirm this is the intended loss.
    loss = tf.reduce_mean(tf.square(tf.maximum(0., 1. - target * fc2)))

    train_num = 5000
    batch_size = 64
    epochs = 1000
    lr_start = 0.001
    lr_end = 0.00001
    # Per-epoch decay factor so the LR sweeps from lr_start to lr_end
    # over the full run.
    lr_decay = (lr_end / lr_start) ** (1. / epochs)
    steps_per_epoch = int(train_num / batch_size)
    global_step1 = tf.Variable(0, trainable=False)
    global_step2 = tf.Variable(0, trainable=False)
    lr1 = tf.train.exponential_decay(lr_start, global_step=global_step1,
                                     decay_steps=steps_per_epoch, decay_rate=lr_decay)
    lr2 = tf.train.exponential_decay(lr_start, global_step=global_step2,
                                     decay_steps=steps_per_epoch, decay_rate=lr_decay)

    sess = tf.Session()
    saver = tf.train.Saver()
    # saver.restore(sess, "model/model.ckpt")

    # Binary kernels are updated by the custom optimizer from binary_layer;
    # everything else (biases, BN gamma/beta) by a plain Adam.
    other_var = [var for var in tf.trainable_variables() if not var.name.endswith('kernel:0')]
    opt = binary.AdamOptimizer(binary.get_all_LR_scale(), lr1)
    opt2 = tf.train.AdamOptimizer(lr2)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # BN moving_mean / moving_variance updates must run with every train step.
    with tf.control_dependencies(update_ops):
        train_kernel_op = opt.apply_gradients(binary.compute_grads(loss, opt), global_step=global_step1)
        train_other_op = opt2.minimize(loss, var_list=other_var, global_step=global_step2)

    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(fc2, 1), tf.argmax(target, 1)), tf.float32))
    sess.run(tf.global_variables_initializer())

    best_acc = 0.0
    for epoch in range(epochs):
        X_train, Y_train = shuffle(train_data, train_label)
        for b in range(steps_per_epoch):
            sess.run([loss, train_kernel_op, train_other_op],
                     feed_dict={x: X_train[b * batch_size:(b + 1) * batch_size],
                                target: Y_train[b * batch_size:(b + 1) * batch_size],
                                training: True})
        # Evaluate on the validation set with BN in inference mode.
        hist = sess.run([accuracy, opt._lr],
                        feed_dict={x: val_data, target: val_label, training: False})
        print('epochs {0} : acc={1}'.format(epoch, hist))
        # Keep only the checkpoint with the best validation accuracy so far.
        if hist[0] > best_acc:
            best_acc = hist[0]
            saver.save(sess, "./bnn_model.ckpt")
# Entry point: run training only when executed as a script.
if __name__ == '__main__':
    train()
|          | 验证集时间消耗 | 验证集准确率 |
| -------- | -------------- | ------------ |
| 二值权重 | 0.03s          | 0.855        |
| 全精度   | 0.14s          | 0.871        |