Basic usage of the TFRecord training format in TensorFlow; saving a Spark DataFrame as TFRecord

Reference: basic usage of the TFRecord training format in TensorFlow

1. Basic usage of the TFRecord training format in TensorFlow

Writing:


import tensorflow as tf

# As introduced in the previous section, each Example is made up of several
# kinds of Feature. The four helper functions below make Features easy to build.
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.encode()]))

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _int64list_feature(value_list):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value_list))

# Serialize one Example into a byte string
def serialize_example(user_id, city_id, app_type, viewed_pois, avg_paid, comment):
    # The data must be assembled according to the format; this dict builds
    # one Example following the given schema (note the key names must match
    # the schema used when reading the file back)
    feature = {
      'user_id': _int64_feature(user_id),
      'city_id': _int64_feature(city_id),
      'app_type': _int64_feature(app_type),
      'viewed_pois': _int64list_feature(viewed_pois),
      'avg_paid': _float_feature(avg_paid),
      'comment': _bytes_feature(comment),
    }
    # Call the relevant API to serialize the Example into a byte string
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

# Sample production: write two sample records into a TFRecord file
def write_demo(filepath):
    with tf.python_io.TFRecordWriter(filepath) as writer:
        writer.write(serialize_example(1, 10, 1, [658, 325], 36.3, "yummy food."))
        writer.write(serialize_example(2, 20, 2, [897, 568, 126], 89.6, "nice place to have dinner."))
    print ("write demo data done.")

filepath = "testdata.tfrecord"
write_demo(filepath)
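
A quick way to sanity-check the file is to walk the raw records and parse each one back into an Example proto, without building a dataset. A minimal sketch using the TF 1.x record iterator (the file name comes from the demo above):

import tensorflow as tf

# Iterate over the raw serialized records in the file and parse each back
# into a tf.train.Example proto for inspection
for record in tf.python_io.tf_record_iterator("testdata.tfrecord"):
    example = tf.train.Example.FromString(record)
    print(example.features.feature["user_id"].int64_list.value)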

Reading:

def read_demo(filepath):
    # Define the parsing schema
    schema = {
        'user_id': tf.FixedLenFeature([], tf.int64),
        'city_id': tf.FixedLenFeature([], tf.int64),
        'app_type': tf.FixedLenFeature([], tf.int64),
        'viewed_pois': tf.VarLenFeature(tf.int64),
        'avg_paid': tf.FixedLenFeature([], tf.float32, default_value=0.0),
        'comment': tf.FixedLenFeature([], tf.string, default_value=''),
    }
    
    # Use the relevant API to parse the samples in the dataset against the schema
    def _parse_function(example_proto):
        return tf.parse_single_example(example_proto, schema)
    
    # Create a dataset by reading the TFRecord file
    dataset = tf.data.TFRecordDataset(filepath)
    # Parse every sample in the dataset according to the schema
    parsed_dataset = dataset.map(_parse_function)
    # Create an Iterator; stepping it yields the samples in the dataset
    next_element = parsed_dataset.make_one_shot_iterator().get_next()
    
    # Use a Session directly here to print the samples in the dataset
    with tf.Session() as sess:
        while True:
            try:
                print(sess.run(next_element))
            except tf.errors.OutOfRangeError:
                print("out of data")
                break


filepath = "testdata.tfrecord"
read_demo(filepath)
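
Note that the snippet above relies on TensorFlow 1.x APIs (tf.python_io, tf.Session, make_one_shot_iterator). Under TensorFlow 2.x the same read path is shorter, because datasets are eagerly iterable; a minimal sketch using the tf.io variants of the same feature specs:

import tensorflow as tf

schema = {
    'user_id': tf.io.FixedLenFeature([], tf.int64),
    'city_id': tf.io.FixedLenFeature([], tf.int64),
    'app_type': tf.io.FixedLenFeature([], tf.int64),
    'viewed_pois': tf.io.VarLenFeature(tf.int64),
    'avg_paid': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
    'comment': tf.io.FixedLenFeature([], tf.string, default_value=''),
}

dataset = tf.data.TFRecordDataset("testdata.tfrecord")
parsed_dataset = dataset.map(lambda proto: tf.io.parse_single_example(proto, schema))
# In eager mode the dataset is directly iterable; no Session or Iterator needed
for example in parsed_dataset:
    print(example)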


2. Saving a Spark DataFrame as TFRecord

from pyspark.sql import SparkSession

def main():
    # Hive support is needed for spark.sql against a Hive table
    spark = SparkSession.builder.enableHiveSupport().getOrCreate()
    # Read the data from the Hive table
    df = spark.sql("""
    select * from experiment.table""")
    # Output path for the TFRecord files
    path = "viewfs:///user/hadoop-hdp/ml/demo/tensorflow/data/tfrecord"
    # Number of output part files; 10 is an example value, tune it to the data size
    file_num = 10
    # Convert the Spark DataFrame into TFRecord format and save it
    df.repartition(file_num).write      \
        .mode("overwrite")              \
        .format("tfrecords")            \
        .option("recordType", "Example")\
        .save(path)

if __name__ == "__main__":
    main()
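
The "tfrecords" data source is not built into Spark: it is provided by the spark-tensorflow-connector package, which has to be on the job's classpath (for example via --jars or --packages when submitting). Once the job has run, the saved part files can be fed straight back into a tf.data pipeline; a minimal sketch, where the part-file glob pattern is an assumption about the connector's output naming:

import tensorflow as tf

path = "viewfs:///user/hadoop-hdp/ml/demo/tensorflow/data/tfrecord"
# One TFRecord part file is written per DataFrame partition; the exact file
# prefix is an assumption, so adjust the glob to the actual output. Reading
# viewfs:// paths also requires TensorFlow built with Hadoop filesystem support.
files = tf.io.gfile.glob(path + "/part*")
dataset = tf.data.TFRecordDataset(files)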


Reposted from blog.csdn.net/weixin_42357472/article/details/120328235