tensorflow(四)

tensorflow数据处理方法,

1.输入数据集

小数据集,可一次性加载到内存处理。

大数据集,一般由大量数据文件组成,因为数据集的规模太大,无法一次性加载到内存,只能每一步训练时加载数据,可以采用流水线并行读取数据。

流水线并行读取数据过程, (1)创建文件名列表(2)创建文件名队列(3)创建Reader和Decoder(4)创建样例队列

filename_queue = tf.train.string_input_producer(['stat0.csv','stat1.csv'])

reader = tf.TextLinerReader()
_,value = reader.read(filename_queue)

record_defaults = [[0],[0],[0.0],[0.0]]
id,age = tf.decode_csv(value,record_defaults=record_defaults)
features = tf.stack([id,age])
def get_my_example(filename_queue):
    reader = tf.SomeReader()
    _,value = reader.read(filename_queue)
    features = tf.decode_some(value)
    processed_example = some_processing(features)
    return processed_example

def input_pipeline(filenames,batch_size,num_epochs=None):
    filename_queue = tf.train.string_input_producer(filenames,num_epochs,shuffle=True)
    example = get_my_example(filename_queue)
    min_after_deque = 10000
    capacity = min_after_deque + 3*batch_size
    example_batch = tf.train.shuffle_batch([example],batch_size=batch_size,capacity=capacity,min_after_deque=min_after_deque)
    
    return example_batch

x_batch = input_pipeline(['stat.tfrecord'],batch_size=20)
sess = tf.Session()
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())

sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
try:
    for _ in range(1000):
        if not coord.should_stop():
            sess.run(train_op)
            print(example)
except:
    print('catch exception')
finally:
    coord.request_stop()
coord.join(threads)
sess.close()

2.模型参数

模型参数指的是模型的权重值和偏置值,使用tf.Variable创建模型参数

W = tf.Variable(0.0,name='W')
double = tf.multiply(2.0,W)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(4):
        sess.run(tf.assign_add(W,1.0))
        print(sess.run(W))

3.保持和恢复模型参数

tf.train.Saver是辅助训练工具类,它实现了存储模型参数的变量和checkpoint文件间的读写操作。

W = tf.Variable(0.0,name='W')
double = tf.multiply(2.0,W)

saver = tf.train.Saver({'weights':W})

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(4):
        sess.run(tf.assign_add(W,1.0))
        print(sess.run(W))
        saver.save(sess,'/tmp/text/ckpt')

猜你喜欢

转载自www.cnblogs.com/yangyang12138/p/12081892.html
今日推荐