TensorFlow: saving and restoring the network structure, parameters, etc.

In deep learning, transfer learning is a very common operation: part of an already trained network is transferred into another network and becomes part of that network's computation. But how do we save and transfer it? This post gives a brief introduction using TensorFlow code as an example.
The class used is: tf.train.Saver()
1. Steps for saving and restoring
(1) Saving: saver.save(sess, save_path)

saver = tf.train.Saver()  # instantiate a tf.train.Saver() object to handle saving
save_path = saver.save(sess, 'save/filename.ckpt')  # save under the relative directory 'save' with the name filename.ckpt

After saving, there are several files with different suffixes:
filename.ckpt.meta: stores the TensorFlow network (computation graph) structure
filename.ckpt: stores the value of every variable in the graph (in newer TF 1.x releases this is split into .index and .data files)
checkpoint: stores the list of all model files in the directory
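As a quick, self-contained sketch (the single variable and the 'save' directory here are placeholders, not part of the MNIST example later in this post), saving looks like this; the optional global_step argument appends the step number to the checkpoint name, e.g. filename.ckpt-16000:

import os
import tensorflow as tf

# a toy graph with a single variable to persist
w = tf.Variable(tf.zeros([2, 2]), name='w')

saver = tf.train.Saver()  # by default the Saver covers all variables in the graph
os.makedirs('save', exist_ok=True)  # make sure the target directory exists

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # global_step is optional; it appends the step to the checkpoint name
    save_path = saver.save(sess, 'save/filename.ckpt', global_step=16000)
    print('[+] Model saved as %s' % save_path)  # -> save/filename.ckpt-16000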
(2) Restoring: saver.restore()

saver.restore(sess, 'save/filename.ckpt')  # restore from the save path

Before restoring, first define variables identical to those in the original model; the restored values are assigned directly to these variables for use.
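Continuing the toy sketch above (same variable w, same placeholder path), restoring just means redefining the variable and letting the Saver assign its saved value; no initializer has to run for restored variables:

import tensorflow as tf

# define the same variable (same name and shape) as in the saved model
w = tf.Variable(tf.zeros([2, 2]), name='w')

saver = tf.train.Saver()

with tf.Session() as sess:
    # no sess.run(init) needed for w: restore() assigns its saved value
    saver.restore(sess, 'save/filename.ckpt-16000')
    print(sess.run(w))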
(3) Testing an already trained model directly
You can rebuild the network from the meta graph, load the parameters obtained during training, and use the default session:

saver = tf.train.import_meta_graph('save/filename.ckpt-16000.meta')
saver.restore(tf.get_default_session(), 'save/filename.ckpt-16000')
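One caveat: tf.get_default_session() only returns a session inside a with tf.Session() block (or a sess.as_default() block). A rough sketch of the full pattern, where the tensor names 'X:0' and 'logits:0' are hypothetical and depend on how the original graph named its ops:

import tensorflow as tf

with tf.Session() as sess:
    # rebuild the graph structure from the .meta file
    saver = tf.train.import_meta_graph('save/filename.ckpt-16000.meta')
    # load the trained variable values into the default session (sess here)
    saver.restore(tf.get_default_session(), 'save/filename.ckpt-16000')

    # tensors can then be looked up by name in the restored graph;
    # 'X:0' and 'logits:0' are hypothetical names used only for illustration
    graph = tf.get_default_graph()
    X = graph.get_tensor_by_name('X:0')
    logits = graph.get_tensor_by_name('logits:0')
    # predictions = sess.run(logits, feed_dict={X: some_batch})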

2. Code implementation
I was too lazy to write the code myself, so below I quote a fairly clear version from another author (Traphix): https://www.jianshu.com/p/83fa3aa2d0e9
(1) Training the network

import tensorflow as tf
import sys

# load MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('data', one_hot=True)

# some hyperparameters
activation = tf.nn.relu
batch_size = 100
iteration = 20000
hidden1_units = 30
# NOTE: this is the save path!
model_path = sys.path[0] + '/simple_mnist.ckpt'

X = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])

W_fc1 = tf.Variable(tf.truncated_normal([784, hidden1_units], stddev=0.2))
b_fc1 = tf.Variable(tf.zeros([hidden1_units]))
W_fc2 = tf.Variable(tf.truncated_normal([hidden1_units, 10], stddev=0.2))
b_fc2 = tf.Variable(tf.zeros([10]))

def inference(img):
    fc1 = activation(tf.nn.bias_add(tf.matmul(img, W_fc1), b_fc1))
    logits = tf.nn.bias_add(tf.matmul(fc1, W_fc2), b_fc2)
    return logits

def loss(logits, labels):
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    loss = tf.reduce_mean(cross_entropy)
    return loss

def evaluation(logits, labels):
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    return accuracy

logits = inference(X)
loss = loss(logits, y_)
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)
accuracy = evaluation(logits, y_)

# first instantiate a Saver() object
saver = tf.train.Saver()
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for i in range(iteration):
        batch = mnist.train.next_batch(batch_size)
        if i%1000 == 0 and i:
            train_accuracy = sess.run(accuracy, feed_dict={X: batch[0], y_: batch[1]})
            print "step %d, train accuracy %g" %(i, train_accuracy)
        sess.run(train_op, feed_dict={X: batch[0], y_: batch[1]})
    print('[+] Test accuracy is %f' % sess.run(accuracy, feed_dict={X: mnist.test.images, y_: mnist.test.labels}))
    # save the trained variables
    save_path = saver.save(sess, model_path)
    print("[+] Model saved in file: %s" % save_path)

(2) Testing

import tensorflow as tf
import sys

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('data', one_hot=True)

activation = tf.nn.relu
hidden1_units = 30
model_path = sys.path[0] + '/simple_mnist.ckpt'

X = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])

W_fc1 = tf.Variable(tf.truncated_normal([784, hidden1_units], stddev=0.2))
b_fc1 = tf.Variable(tf.zeros([hidden1_units]))
W_fc2 = tf.Variable(tf.truncated_normal([hidden1_units, 10], stddev=0.2))
b_fc2 = tf.Variable(tf.zeros([10]))

def inference(img):
    fc1 = activation(tf.nn.bias_add(tf.matmul(img, W_fc1), b_fc1))
    logits = tf.nn.bias_add(tf.matmul(fc1, W_fc2), b_fc2)
    return logits

def evaluation(logits, labels):
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    return accuracy

logits = inference(X)
accuracy = evaluation(logits, y_)

saver = tf.train.Saver()

with tf.Session() as sess:
    # restore the previously trained variables
    saver.restore(sess, model_path)
    print("[+] Model restored from %s" % model_path)
    print('[+] Test accuracy is %f' % sess.run(accuracy, feed_dict={X: mnist.test.images, y_: mnist.test.labels}))
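Coming back to the transfer-learning use case from the introduction: tf.train.Saver takes a var_list argument, so only part of a trained network needs to be saved or restored into a new graph. The sketch below is illustrative only; it assumes the first-layer weights were saved under the names 'W_fc1' and 'b_fc1' (the code above does not name them explicitly, in which case you would pass a dict mapping the checkpoint's variable names to the new variables instead), and the new output layer is made up for the example:

import tensorflow as tf

hidden1_units = 30

# layer-1 variables, shaped exactly as in the trained model above
W_fc1 = tf.Variable(tf.truncated_normal([784, hidden1_units], stddev=0.2), name='W_fc1')
b_fc1 = tf.Variable(tf.zeros([hidden1_units]), name='b_fc1')

# new output layer for the new task (hypothetical, not present in the checkpoint)
W_out = tf.Variable(tf.truncated_normal([hidden1_units, 2], stddev=0.2), name='W_out')
b_out = tf.Variable(tf.zeros([2]), name='b_out')

# this Saver only touches the listed variables; everything else is left alone
restore_saver = tf.train.Saver(var_list=[W_fc1, b_fc1])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())       # initialize all variables first
    restore_saver.restore(sess, 'simple_mnist.ckpt')  # then overwrite W_fc1, b_fc1 from the checkpoint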

Reposted from blog.csdn.net/xuan_zizizi/article/details/79106583