TensorFlow 保存训练好的模型

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
# MNIST dataset read from ./MNIST_data; labels come back one-hot encoded so
# they can be fed straight into Net.y (shape [None, 10]).
mnist = input_data.read_data_sets(train_dir='./MNIST_data', one_hot=True)

class Net:
    """Three-layer fully connected MNIST classifier (784 -> 512 -> 256 -> 10).

    Build the graph by calling ``forward()`` and then ``backward()``.

    Two scalar placeholders switch between training and evaluation. Their
    defaults equal the values the graph used to hard-code, so code that never
    feeds them sees the same dropout / batch-norm behaviour as before:

    * ``keep_prob`` (float, default 0.5): dropout keep probability of the two
      hidden layers. Feed 1.0 when evaluating or predicting.
    * ``training`` (bool, default False): batch-normalization mode. Feed True
      on training steps so batch statistics are used and the moving
      mean/variance get updated; leave it False for evaluation.
    """

    def __init__(self):
        # Inputs: 784-feature rows and 10-class one-hot labels.
        self.x = tf.placeholder(shape=[None, 784], dtype=tf.float32)
        self.y = tf.placeholder(shape=[None, 10], dtype=tf.float32)
        # Train/eval switches. Placeholders are not checkpointed, so adding
        # them does not affect restoring older checkpoints.
        self.keep_prob = tf.placeholder_with_default(0.5, shape=[])
        self.training = tf.placeholder_with_default(False, shape=[])
        # Variables stay unnamed and are created in this exact order so their
        # auto-generated names ("Variable", "Variable_1", ...) keep matching
        # checkpoints written by earlier runs.
        self.w1 = tf.Variable(tf.truncated_normal(shape=[784, 512], dtype=tf.float32, stddev=0.1))
        self.b1 = tf.Variable(tf.zeros([512]))
        self.w2 = tf.Variable(tf.truncated_normal(shape=[512, 256], dtype=tf.float32, stddev=0.1))
        self.b2 = tf.Variable(tf.zeros([256]))
        self.w3 = tf.Variable(tf.truncated_normal(shape=[256, 10], dtype=tf.float32, stddev=0.1))
        self.b3 = tf.Variable(tf.zeros([10]))

    def forward(self):
        """Build the inference graph.

        Sets ``self.y1`` / ``self.y2`` (hidden activations after dropout),
        ``self.logits`` (unnormalized class scores) and ``self.y3`` (softmax
        class probabilities).
        """
        h1 = tf.layers.batch_normalization(
            tf.matmul(self.x, self.w1) + self.b1, training=self.training)
        self.y1 = tf.nn.dropout(tf.nn.relu(h1), keep_prob=self.keep_prob)
        h2 = tf.layers.batch_normalization(
            tf.matmul(self.y1, self.w2) + self.b2, training=self.training)
        self.y2 = tf.nn.dropout(tf.nn.relu(h2), keep_prob=self.keep_prob)
        # The logits are kept separately because the loss needs the raw
        # scores, not the softmax output.
        self.logits = tf.layers.batch_normalization(
            tf.matmul(self.y2, self.w3) + self.b3, training=self.training)
        self.y3 = tf.nn.softmax(self.logits)

    def backward(self, learning_rate=0.0001):
        """Build the loss, the Adam training step and the accuracy metric.

        Must be called after ``forward()``.

        Args:
            learning_rate: Adam learning rate (default 1e-4, the value that
                was previously hard-coded).
        """
        # softmax_cross_entropy_with_logits applies softmax internally, so it
        # must receive the raw logits; feeding self.y3 would apply softmax
        # twice and squash the scores the loss is computed from.
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.y))
        # Batch normalization registers its moving-mean/variance updates in
        # UPDATE_OPS; they only run if the training step depends on them.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.opt = tf.train.AdamOptimizer(learning_rate).minimize(self.loss)
        # argmax is unaffected by softmax, so y3 gives the same prediction as
        # the logits would.
        correct = tf.equal(tf.argmax(self.y3, axis=1), tf.argmax(self.y, axis=1))
        self.acc = tf.reduce_mean(tf.cast(correct, dtype=tf.float32))


# Checkpoint prefix. The trailing "chpk" is required: Saver.save treats the
# last path component as a file-name prefix, not as a directory.
save_path = 'params/chpk'

if __name__ == '__main__':
    net = Net()
    net.forward()
    net.backward()
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()  # checkpoint writer/reader for this graph
    save_dir = os.path.dirname(save_path)
    # Saver.save refuses to write when the parent directory does not exist.
    os.makedirs(save_dir, exist_ok=True)
    with tf.Session() as sess:
        # Resume from the most recent checkpoint if there is one, otherwise
        # start from freshly initialized variables.
        latest = tf.train.latest_checkpoint(save_dir)
        if latest:
            saver.restore(sess, latest)
        else:
            sess.run(init)
        # Each step trains on one mini-batch of 100 examples.
        for step in range(1000):
            x, y = mnist.train.next_batch(100)
            sess.run(net.opt, feed_dict={net.x: x, net.y: y,
                                         net.keep_prob: 0.5, net.training: True})
            if step % 100 == 0:
                # Evaluate on a validation batch: dropout off, batch norm in
                # inference mode, and no optimizer step.
                xs, ys = mnist.validation.next_batch(100)
                error, acc = sess.run(
                    [net.loss, net.acc],
                    feed_dict={net.x: xs, net.y: ys,
                               net.keep_prob: 1.0, net.training: False})
                print('step %d: validation loss %.4f, accuracy %.4f' % (step, error, acc))
                saver.save(sess, save_path=save_path)  # checkpoint every 100 steps
                print('saved')

猜你喜欢

转载自 https://blog.csdn.net/weixin_38241876/article/details/89553716