Binary_layer.py, que se distribuye ampliamente en Internet, implementa una red de pesos binarios; el archivo está disponible en GitHub.
No sé quién fue el autor original; en cualquier caso, es el que todo el mundo utiliza.
Implementa principalmente la clase Dense_BinaryLayer, por lo que no la reproduciré aquí.
Lo siguiente se basa en la publicación de blog anterior para construir la misma red neuronal:
- Entrada: 28 × 28 (784)
- Capa oculta: 500
- Salida: 10
import tensorflow as tf
import binary_layer as binary
import numpy as np
import os
import cv2 as cv
import time
def fully_connect_bn(pre_layer, output_dim, act, use_bias, training):
    """Binary-weight dense layer followed by batch normalization and an optional activation.

    Args:
        pre_layer: input tensor feeding the dense layer.
        output_dim: number of output units.
        act: activation callable applied after batch norm, or None for a linear output.
        use_bias: whether the dense layer adds a bias term.
        training: boolean placeholder/tensor selecting train vs. inference mode for BN.

    Returns:
        The activated (or linear, when act is None) batch-normalized output tensor.
    """
    pre_act = binary.dense_binary(pre_layer, output_dim,
                                  use_bias=use_bias,
                                  activation=None,
                                  # Keep the real-valued shadow weights clipped to [-1, 1],
                                  # as in the BinaryConnect training scheme.
                                  kernel_constraint=lambda w: tf.clip_by_value(w, -1.0, 1.0))
    bn = binary.batch_normalization(pre_act, momentum=0.9, epsilon=1e-4, training=training)
    # Bug fix: compare against None with `is`, not `==` (PEP 8; `==` can be
    # hijacked by operator overloading on tensor-like objects).
    if act is None:
        return bn
    return act(bn)
def shuffle(X, y):
    """Shuffle X and y in place with one shared random permutation.

    X and y are parallel sequences (samples and their labels); each (X[i], y[i])
    pair stays matched after shuffling. Both containers are mutated in place and
    also returned for convenience.
    """
    shuffle_parts = 1
    chunk = int(len(X) / shuffle_parts)
    order = np.arange(chunk)
    buf_X = np.copy(X[0:chunk])
    buf_y = np.copy(y[0:chunk])
    for part in range(shuffle_parts):
        np.random.shuffle(order)
        base = part * chunk
        # Gather the permuted chunk into the buffers, then write it back.
        for dst, src in enumerate(order):
            buf_X[dst] = X[base + src]
            buf_y[dst] = y[base + src]
        X[base:base + chunk] = buf_X
        y[base:base + chunk] = buf_y
    return X, y
def _load_dataset(dir_path):
    """Load grayscale digit images from dir_path as flat vectors with one-hot labels.

    Filenames are expected to start with the digit class followed by an
    underscore (e.g. '7_0013.png') — TODO confirm against the dataset layout.
    Files OpenCV cannot read are skipped.

    Returns:
        (data, labels): data is a list of 1-D arrays of pixels scaled to [0, 1];
        labels is a list of length-10 one-hot lists.
    """
    data = []
    labels = []
    # NOTE: the original code bound os.listdir's result to the name `list`,
    # shadowing the builtin; avoided here.
    for filename in os.listdir(dir_path):
        filepath = os.path.join(dir_path, filename)
        img = cv.imread(filepath, 0)  # 0 -> read as grayscale
        if img is None:  # not a readable image, skip
            continue
        # 255.0 forces true division (plain `/ 255` floors to 0/1 under Python 2).
        img = img / 255.0
        rows, cols = img.shape
        data.append(img.reshape((rows * cols)))  # flatten to a 1-D vector
        one_hot = [0] * 10
        one_hot[int(filename.split('_')[0])] = 1  # class id is the filename prefix
        labels.append(one_hot)
    return data, labels


def train():
    """Train the binary-weight MLP (784 -> 500 -> 10) on MNIST images on disk.

    Builds the TF1 graph, trains the binarized kernels with the scaled-LR Adam
    from binary_layer and all other variables with a plain Adam, evaluates on
    the validation set each epoch, and checkpoints whenever validation accuracy
    improves.
    """
    x = tf.placeholder(tf.float32, shape=[None, 784])
    target = tf.placeholder(tf.float32, shape=[None, 10])
    training = tf.placeholder(tf.bool)

    # Training set
    train_data, train_label = _load_dataset('E:\\[1]Paper\\Datasets\\MINST\\train')
    print('train data load!\n')
    # Validation set
    val_data, val_label = _load_dataset('E:\\[1]Paper\\Datasets\\MINST\\query')
    print('test data load!\n')

    # 784 -> 500 -> 10 network with binarized weights.
    fc1 = fully_connect_bn(x, 500, act=binary.binary_tanh_unit, use_bias=True, training=training)
    fc2 = fully_connect_bn(fc1, 10, act=None, use_bias=True, training=training)
    # Squared hinge loss on the raw logits against the one-hot targets.
    loss = tf.reduce_mean(tf.square(tf.maximum(0., 1. - target * fc2)))

    train_num = 5000
    batch_size = 64
    epochs = 1000
    lr_start = 0.001
    lr_end = 0.00001
    # Per-epoch decay rate so the LR moves from lr_start to lr_end over all epochs.
    lr_decay = (lr_end / lr_start) ** (1. / epochs)
    global_step1 = tf.Variable(0, trainable=False)
    global_step2 = tf.Variable(0, trainable=False)
    lr1 = tf.train.exponential_decay(lr_start, global_step=global_step1,
                                     decay_steps=int(train_num / batch_size), decay_rate=lr_decay)
    lr2 = tf.train.exponential_decay(lr_start, global_step=global_step2,
                                     decay_steps=int(train_num / batch_size), decay_rate=lr_decay)

    sess = tf.Session()
    saver = tf.train.Saver()
    # saver.restore(sess, "model/model.ckpt")

    # Binary kernels (variables named '...kernel:0') get the LR-scaled Adam from
    # binary_layer; everything else (biases, BN params) gets a plain Adam.
    other_var = [var for var in tf.trainable_variables() if not var.name.endswith('kernel:0')]
    opt = binary.AdamOptimizer(binary.get_all_LR_scale(), lr1)
    opt2 = tf.train.AdamOptimizer(lr2)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # When training, the moving_mean and moving_variance in the BN need to be updated.
    with tf.control_dependencies(update_ops):
        train_kernel_op = opt.apply_gradients(binary.compute_grads(loss, opt), global_step=global_step1)
        train_other_op = opt2.minimize(loss, var_list=other_var, global_step=global_step2)

    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(fc2, 1), tf.argmax(target, 1)), tf.float32))
    sess.run(tf.global_variables_initializer())

    old_acc = 0.0
    for i in range(epochs):
        X_train, Y_train = shuffle(train_data, train_label)
        batches = int(train_num / batch_size)
        for b in range(batches):
            sess.run([loss, train_kernel_op, train_other_op],
                     feed_dict={x: X_train[b * batch_size:(b + 1) * batch_size],
                                target: Y_train[b * batch_size:(b + 1) * batch_size],
                                training: True})
        hist = sess.run([accuracy, opt._lr],
                        feed_dict={x: val_data, target: val_label, training: False})
        print('epochs {0} : acc={1}'.format(i, hist))
        # Keep only the best-so-far checkpoint by validation accuracy.
        if hist[0] > old_acc:
            old_acc = hist[0]
            saver.save(sess, "./bnn_model.ckpt")
if __name__ == '__main__':
    # Script entry point: run the full training loop.
    train()
|  | Consumo de tiempo (validación) | Precisión (validación) |
| --- | --- | --- |
| Peso binario | 0,03 s | 0,855 |
| Precisión total | 0,14 s | 0,871 |