TF中的多层神经网络

多层神经网络

神经网络中的层数,是以隐含层的数目而言的,一般不会去统计输入层与输出层;本文采用的是简单的全连接层,所谓全连接,就是上一层的每一个节点到要与下一层的每一个节点一一相连。作为案例,将进行最基本的多层网络构建,并实现mnist数据及分类。

全连接层构造函数

# 定义全连接层构造函数
def fcn_layer(inputs,  # 输入的数据
              input_dim,  # 输入的维度
              output_dim,  # 输出的维度,也就是神经元数目
              activation=None  # 激活函数,可以不使用
              ):
    w = tf.Variable(tf.truncated_normal([input_dim, output_dim], stddev=0.1))
    b = tf.Variable(tf.zeros([output_dim]))
    xwb = tf.matmul(inputs, w) + b
    # 基于输入确定是否激活
    if activation is None:
        outputs = xwb
    else:
        outputs = activation(xwb)
    # 返回最终结果
    return outputs

建立如上的全连接层构造函数,在给定输入数据、输入输出维度以及是否使用激活函数的前提下,可以实现端到端的全连接层建立;输出维度也就是一层网络中的神经元数,可以认为每一个神经元在试图提取样本的一种特征,随着网络不断加深,所提取的特征也就越高级,或者说人类越看不懂;权重w一般不建议使用完全随机分布,而是此处的截断式的随机分布;激活函数的选择有很多种,但是隐层一般使用relu而非s型函数或者双曲正切函数,因为会存在梯度消失的问题,而输出层一般采用softmax,将计算的scores映射到概率域。

利用构造函数,构建双层网络

# 利用模型函数进行双层多元网络构建
H1_nn = 256  # 首层神经元数
H2_nn = 64  # 次层神经元数
# 定义输入
x = tf.placeholder(tf.float32, [None, 784], name="X")
y = tf.placeholder(tf.float32, [None, 10], name="Y")
# 定义隐含1、2层,以及前向计算与输出层
h1 = fcn_layer(x, 784, H1_nn, tf.nn.relu)
h2 = fcn_layer(h1, 256, H2_nn, tf.nn.relu)
forward = fcn_layer(h2, H2_nn, 10)
preb = tf.nn.softmax(forward)

如上建立两个隐层一个输出层;因为使用矩阵乘法,注意前后相邻层的维度关系,并且最终输出层的维度是固定的,即分类的数目;此处将前向计算与输出层分开定义,是为了在之后方便定义准确率损失函数。

设置超参

# 设置超参与日志路径
train_epoch = 50
batch_size = 50
total_batch = int(mnist.train.num_examples/batch_size)
learning_rate = 0.01
display_step = 1
save_step = 5
ckpt_dir = "./ckpt_dir/"
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

# 设置损失函数与优化器
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=forward, labels=y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
# 设置准确率
correct_prediction = tf.equal(tf.argmax(preb, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

如上设置各项超参。注意使用的损失函数是TF提供的带有softmax的交叉熵损失函数,所以定义损失函数时,输入的是前向计算。

开始训练

# 启动会话并初始化
sess = tf.Session()
saver = tf.train.Saver()  # 初始化写入对象,要定义在变量初始化之前
sess.run(tf.global_variables_initializer())
start_time = time()  # 记录起始时间

# 开始训练
print("Start Train!")
for epoch in range(train_epoch):
    for batch in range(total_batch):
        xs, ys = mnist.train.next_batch(batch_size)
        sess.run(optimizer, feed_dict={x: xs, y: ys})
    loss, acc = sess.run([loss_function, accuracy], feed_dict={x: mnist.validation.images, y: mnist.validation.labels})  # 每训练一轮用验证集检测损失函数与准确率
    if (epoch+1) % display_step == 0:  # 以一定粒度显示loss与accuracy
        print("epoch: ", epoch+1, ",loss: ", loss, ",accuracy: ", acc)
    if(epoch+1) % save_step == 0:  # 以一定粒度保存模型
        saver.save(sess, os.path.join(ckpt_dir, 'Mnist_h256_model_{:06d}.ckpt'.format(epoch+1)))
        print('Mnist_h256_model_{:06d}.ckpt saved!'.format(epoch+1))
saver.save(sess, os.path.join(ckpt_dir, 'Mnist_h256_model_final.ckpt'))  # 保存最终模型
print("Model Saved!")
duration = time()-start_time  # 统计时长
print("duration: ", duration)

可视化辅助函数

def print_predict_errs(labels,  # 标签集合
                       prediction  # 预测值集合
                       ):
    count = 0
    compare_list = (prediction == np.argmax(labels, 1))
    err_list = [i for i in range(len(compare_list)) if compare_list[i] == False]
    for x in err_list:
        print("index= "+str(x)+"标签值= ", np.argmax(labels[x]),"预测值= ", prediction[x])
        count = count+1
    print("总计: ", str(count))

要求输入的是mnist的标签集,以及预测结果集;所谓预测结果集要求已经从独热编码变为直接显示预测的值的形式,具体见模型检测的代码。

载入并测试模型以及可视化

# 载入最新的模型
saver = tf.train.Saver(tf.global_variables())
moudke_file = tf.train.latest_checkpoint('ckpt_dir')
saver.restore(sess, moudke_file)
print("Load model!")

# 测试模型,利用测试集数据
acc_test = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print("Test Accuracy: ", acc_test)


# 利用可视化函数,将对于测试集的预测值与预测集的标签值进行比对,并输出不匹配列表
# 按顺序返回预测值,需按行取最大值,将独热编码转换为一列,所取得结果组成一个(1000,1)的元组
predicion_result = sess.run(tf.argmax(preb, 1), feed_dict={x: mnist.test.images})
print_predict_errs(mnist.test.labels,predicion_result)
发布了5 篇原创文章 · 获赞 0 · 访问量 62

猜你喜欢

转载自blog.csdn.net/weixin_41707744/article/details/104790117