Objetivo
Este artículo tiene como objetivo presentar los puntos de conocimiento introductorios y ejemplos prácticos de tensorflow. Espero que todos los estudiantes novatos puedan dominar las operaciones básicas relacionadas con tensorflow después de aprender.
Red neuronal recurrente simple
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST', one_hot=True)
batch_size = 64
n_batches = mnist.train.num_examples // batch_size
n_classes = 10 # 类别个数
hidden_size = 128 # 隐层纬度
steps = 28 # 最大序列
embedding_size = 28 # 输入维度
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
weights = tf.Variable(tf.random_normal([hidden_size, n_classes], stddev=0.1))
biases = tf.Variable(tf.zeros([n_classes]))
def RNN(x, w, b):
inputs = tf.reshape(x, shape = [-1, steps, embedding_size]) # 将图像拉成一个时序序列
cell = tf.contrib.rnn.BasicRNNCell(hidden_size) # 每个 RNN 隐藏输出维度 hidden_size
_, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32) # rnn 模型计算获取最有状态向量
result = tf.nn.softmax(tf.matmul(final_state, w) + b) # 对结果进行 softmax
return result
predict = RNN(x, weights, biases) # 预测结果
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=predict)) # 损失
opt = tf.train.AdamOptimizer(0.001).minimize(loss) # 定义优化器
correct = tf.equal(tf.argmax(y,1), tf. argmax(predict,1))
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) # 准确率
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
total_batch = 0
last_batch = 0
best = 0
for epoch in range(100):
for _ in range(n_batches):
xx, yy = mnist.train.next_batch(batch_size)
sess.run(opt, {x:xx, y:yy})
acc, l = sess.run([accuracy, loss], {x:mnist.test.images, y:mnist.test.labels})
if acc > best:
best = acc
last_batch = total_batch
print('eopch:%d, acc:%f, loss:%f'%(epoch, acc, l))
if total_batch - last_batch > 5: # 早停条件
print('early stop')
break
total_batch += 1
复制代码
salida de resultados
eopch:0, acc:0.878200, loss:1.589309
eopch:1, acc:0.907200, loss:1.556449
eopch:2, acc:0.917600, loss:1.546643
eopch:4, acc:0.933500, loss:1.528619
eopch:5, acc:0.950100, loss:1.512312
eopch:6, acc:0.951100, loss:1.511004
early stop
复制代码
Punto clave 1
Puede encontrar el conocimiento básico de las redes neuronales recurrentes y sus variantes en línea. Hay muchos materiales de aprendizaje, y también puede consultar los artículos que escribí antes: juejin.cn/post/697234…
punto dos
En comparación con la anterior red simple de múltiples capas ocultas + abandono, la tasa de precisión es solo del 97,8 %, la tasa de precisión del uso de la red neuronal cíclica es solo del 95 % y la tasa de precisión de la red neuronal convolucional no es tan alta como esa. de la red neuronal convolucional, lo que indica que CNN es inherentemente bueno en el procesamiento de imágenes Ventajosamente, las redes neuronales recurrentes son adecuadas para el procesamiento de datos textuales.
Referencia en este artículo
Referencia para este artículo: blog.csdn.net/qq_19672707…