Tutorial práctico de tensorflow 1.x (diez) - red neuronal recurrente

Objetivo

Este artículo tiene como objetivo presentar los puntos de conocimiento introductorios y ejemplos prácticos de tensorflow. Espero que todos los estudiantes novatos puedan dominar las operaciones básicas relacionadas con tensorflow después de aprender.

Red neuronal recurrente simple

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST', one_hot=True)
batch_size = 64
n_batches = mnist.train.num_examples // batch_size
n_classes = 10 # 类别个数
hidden_size = 128 # 隐层纬度
steps = 28 # 最大序列
embedding_size = 28 # 输入维度

x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

weights = tf.Variable(tf.random_normal([hidden_size, n_classes], stddev=0.1))
biases = tf.Variable(tf.zeros([n_classes]))

def RNN(x, w, b):
    inputs = tf.reshape(x, shape = [-1, steps, embedding_size]) # 将图像拉成一个时序序列
    cell = tf.contrib.rnn.BasicRNNCell(hidden_size) # 每个 RNN 隐藏输出维度 hidden_size
    _, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32) # rnn 模型计算获取最有状态向量
    result = tf.nn.softmax(tf.matmul(final_state, w) + b) # 对结果进行 softmax
    return result

predict = RNN(x, weights, biases) # 预测结果
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=predict)) # 损失
opt = tf.train.AdamOptimizer(0.001).minimize(loss) # 定义优化器
correct = tf.equal(tf.argmax(y,1), tf. argmax(predict,1))
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) # 准确率

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    total_batch = 0
    last_batch = 0
    best = 0
    for epoch in range(100):
        for _ in range(n_batches):
            xx, yy = mnist.train.next_batch(batch_size)
            sess.run(opt, {x:xx, y:yy})
        acc, l = sess.run([accuracy, loss], {x:mnist.test.images, y:mnist.test.labels})
        if acc > best:
            best = acc
            last_batch = total_batch
            print('eopch:%d, acc:%f, loss:%f'%(epoch, acc, l))
        if total_batch - last_batch > 5: # 早停条件
            print('early stop')
            break
        total_batch += 1
复制代码

salida de resultados

eopch:0, acc:0.878200, loss:1.589309
eopch:1, acc:0.907200, loss:1.556449
eopch:2, acc:0.917600, loss:1.546643
eopch:4, acc:0.933500, loss:1.528619
eopch:5, acc:0.950100, loss:1.512312
eopch:6, acc:0.951100, loss:1.511004
early stop
复制代码

Punto clave 1

Puede encontrar el conocimiento básico de las redes neuronales recurrentes y sus variantes en línea. Hay muchos materiales de aprendizaje, y también puede consultar los artículos que escribí antes: juejin.cn/post/697234…

punto dos

En comparación con la anterior red simple de múltiples capas ocultas + abandono, la tasa de precisión es solo del 97,8 %, la tasa de precisión del uso de la red neuronal cíclica es solo del 95 % y la tasa de precisión de la red neuronal convolucional no es tan alta como esa. de la red neuronal convolucional, lo que indica que CNN es inherentemente bueno en el procesamiento de imágenes Ventajosamente, las redes neuronales recurrentes son adecuadas para el procesamiento de datos textuales.

Referencia en este artículo

Referencia para este artículo: blog.csdn.net/qq_19672707…

Supongo que te gusta

Origin juejin.im/post/7087049893562286094
Recomendado
Clasificación