tutorial práctico de tensorflow 1.x (11) - guardado y restauración de modelos

Objetivo

Este artículo tiene como objetivo presentar los puntos de conocimiento introductorios y ejemplos prácticos de tensorflow. Espero que todos los estudiantes novatos puedan dominar las operaciones básicas relacionadas con tensorflow después de aprender.

guardar modelo

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST", one_hot=True)
batch_size = 64
n_batches = mnist.train.num_examples // batch_size

x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
 
w = tf.Variable(tf.random_normal([784, 10], stddev=0.1))
b = tf.Variable(tf.zeros([10]))
predict = tf.nn.softmax(tf.matmul(x, w) + b)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predict, labels=y))
opt = tf.train.AdamOptimizer(0.001).minimize(loss)
correct = tf.equal(tf.argmax(y, 1), tf.argmax(predict, 1))
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    total_batch = 0
    last = 0
    best = 0
    for epoch in range(100):
        for _ in range(n_batches):
            xx,yy = mnist.train.next_batch(batch_size)
            sess.run(opt, {x:xx, y:yy})
        acc, l = sess.run([accuracy, loss], {x:mnist.test.images, y:mnist.test.labels})
        if acc > best:
            best = acc
            last = total_batch
            saver.save(sess, 'saved_model/model') # 每次只保存最好的结果
            print(epoch, acc, l)
        if total_batch - last > 5:
            print('early stop')
            break
        total_batch += 1
        
复制代码

salida de resultados

0 0.9035 1.5953374
1 0.9147 1.5688152
2 0.9212 1.5580758
3 0.9234 1.552525
4 0.9239 1.5495663
5 0.9264 1.5462393
6 0.9271 1.5441632
7 0.9288 1.5419955
8 0.9302 1.5403246
12 0.9308 1.5376735
14 0.9324 1.5360526
19 0.9333 1.534032
25 0.9338 1.5329739
26 0.934 1.5326717
early stop
复制代码

modelo de lectura

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, 'saved_model/model')
    acc, l = sess.run([accuracy, loss], {x:mnist.test.images, y:mnist.test.labels})
    print(acc, l)
    
复制代码

impresión de resultados

0.934 1.5326717
复制代码

Punto clave 1

Debido a que solo estamos guardando el modelo de mejor rendimiento, estamos leyendo el modelo y el resultado de probarlo con los mismos datos es el mismo que la última vez que se entrenó.

Referencia en este artículo

Referencia para este artículo: blog.csdn.net/qq_19672707…

Supongo que te gusta

Origin juejin.im/post/7087050067076448270
Recomendado
Clasificación