About TFRecord

Writing TFRecords

        print("convert data into tfrecord:train\n")
        out_file_train = "/home/huadong.wang/bo.yan/fudan_mtl/data/ace2005/bn_nw.train.tfrecord"
        writer = tf.python_io.TFRecordWriter(out_file_train)

        for i in tqdm(range(len(data_train))):
            record = tf.train.Example(features=tf.train.Features(feature={
                'word_ids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_x[i].tostring()])),
                'et_ids1': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_et1[i].tostring()])),
                'et_ids2': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_et2[i].tostring()])),
                'position_ids1': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_p1[i].tostring()])),
                'position_ids2': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_p2[i].tostring()])),
                'chunks': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_chunks[i].tostring()])),
                'spath_ids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_spath[i].tostring()])),
                'seq_len': tf.train.Feature(int64_list=tf.train.Int64List(value=[train_x_len[i]])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[np.argmax(train_relation[i])])),
                'task': tf.train.Feature(int64_list=tf.train.Int64List(value=[np.int64(0)]))
            }))
            writer.write(record.SerializeToString())
        writer.close()
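Note that the loop above serializes each numpy array as raw bytes inside a tf.train.Example, so reading those records back would require tf.parse_single_example together with tf.decode_raw and the original array dtype. The parser shown in the next section instead expects a tf.train.SequenceExample whose sequence features are int64 lists. A minimal sketch of a writer matching that parser (assuming every train_* array is a 1-D integer sequence; the helper name is hypothetical):

def make_sequence_example(word_ids, et1, et2, p1, p2, chunks, spath,
                          seq_len, relation, task_id):
  # hypothetical helper: build one tf.train.SequenceExample matching the
  # context/sequence features that _parse_tfexample (next section) expects
  def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))

  def int64_feature_list(values):
    return tf.train.FeatureList(feature=[int64_feature(v) for v in values])

  context = tf.train.Features(feature={
      'label'  : int64_feature(np.argmax(relation)),
      'task'   : int64_feature(task_id),
      'seq_len': int64_feature(seq_len)})
  feature_lists = tf.train.FeatureLists(feature_list={
      'word_ids'     : int64_feature_list(word_ids),
      'et_ids1'      : int64_feature_list(et1),
      'et_ids2'      : int64_feature_list(et2),
      'position_ids1': int64_feature_list(p1),
      'position_ids2': int64_feature_list(p2),
      'chunks'       : int64_feature_list(chunks),
      'spath_ids'    : int64_feature_list(spath)})
  return tf.train.SequenceExample(context=context, feature_lists=feature_lists)

Each record would then be written with writer.write(make_sequence_example(...).SerializeToString()), exactly as in the loop above.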

  

Parsing TFRecords

def _parse_tfexample(serialized_example):
  '''parse serialized tf.train.SequenceExample to tensors
  context features : label, task
  sequence features: sentence
  '''
  context_features = {'label'  : tf.FixedLenFeature([], tf.int64),
                      'task'   : tf.FixedLenFeature([], tf.int64),
                      'seq_len': tf.FixedLenFeature([], tf.int64)}
  sequence_features = {'word_ids'     : tf.FixedLenSequenceFeature([], tf.int64),
                       'et_ids1'      : tf.FixedLenSequenceFeature([], tf.int64),
                       'et_ids2'      : tf.FixedLenSequenceFeature([], tf.int64),
                       'position_ids1': tf.FixedLenSequenceFeature([], tf.int64),
                       'position_ids2': tf.FixedLenSequenceFeature([], tf.int64),
                       'chunks'       : tf.FixedLenSequenceFeature([], tf.int64),
                       'spath_ids'    : tf.FixedLenSequenceFeature([], tf.int64)}
  context_dict, sequence_dict = tf.parse_single_sequence_example(
                      serialized_example,
                      context_features   = context_features,
                      sequence_features  = sequence_features)

  sentence = (sequence_dict['word_ids'], sequence_dict['et_ids1'], sequence_dict['et_ids2'],
              sequence_dict['position_ids1'], sequence_dict['position_ids2'],
              sequence_dict['chunks'], sequence_dict['spath_ids'], context_dict['seq_len'])

  label = context_dict['label']
  task = context_dict['task']

  return task, label, sentence



def read_tfrecord(epoch, batch_size):
  for dataset in DATASETS:
    train_record_file = os.path.join(OUT_DIR, dataset+'.train.tfrecord')
    test_record_file = os.path.join(OUT_DIR, dataset+'.test.tfrecord')

    train_data = util.read_tfrecord(train_record_file, 
                                    epoch, 
                                    batch_size, 
                                    _parse_tfexample, 
                                    shuffle=True)

    test_data = util.read_tfrecord(test_record_file,
                                   epoch,
                                   batch_size,
                                   _parse_tfexample,
                                   shuffle=False)
    yield train_data, test_data
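
util.read_tfrecord itself is not shown in the post. A rough sketch of what such a helper might do with the TF 1.x tf.data API, given the (task, label, sentence) structure returned by _parse_tfexample above (the function name and buffer size are assumptions, not the original implementation):

def read_tfrecord_dataset(record_file, epoch, batch_size, parse_fn, shuffle=True):
  # hypothetical stand-in for util.read_tfrecord (TF 1.x tf.data API)
  dataset = tf.data.TFRecordDataset([record_file])
  dataset = dataset.map(parse_fn)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=1000)   # buffer size is an assumption
  dataset = dataset.repeat(epoch)
  # task, label and seq_len are scalars; the seven id sequences are
  # variable-length, so pad them to the longest sequence in each batch
  sentence_shapes = tuple([tf.TensorShape([None])] * 7) + (tf.TensorShape([]),)
  padded_shapes = (tf.TensorShape([]),   # task
                   tf.TensorShape([]),   # label
                   sentence_shapes)
  dataset = dataset.padded_batch(batch_size, padded_shapes)
  return dataset.make_one_shot_iterator().get_next()

Because the seven sequence features differ in length from example to example, padded_batch pads each of them to the longest sequence in the batch.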

Using it in the model:

  def build_task_graph(self, data):
    task_label, labels, sentence = data
    # sentence = tf.nn.embedding_lookup(self.word_embed, sentence)
    word_ids, et_ids1, et_ids2, position_ids1, position_ids2, chunks, spath_ids, seq_len = sentence
    # sentence = word_ids

    self.word_ids = word_ids
    self.position_ids1 = position_ids1
    self.position_ids2 = position_ids2
    self.et_ids1 = et_ids1
    self.et_ids2 = et_ids2
    self.chunks_ids = chunks
    self.spath_ids = spath_ids
    self.seq_len = seq_len

    sentence = self.add_embedding_layers()
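
add_embedding_layers is not included in the post either. A minimal sketch of what it could do, assuming each id sequence is looked up in its own embedding table and the results are concatenated along the feature axis (self.pos_embed and self.et_embed are hypothetical names; only self.word_embed appears in the original code):

  def add_embedding_layers(self):
    # hypothetical sketch: look up each id sequence in its embedding table
    # and concatenate the results along the last (feature) dimension
    word_emb = tf.nn.embedding_lookup(self.word_embed, self.word_ids)
    pos1_emb = tf.nn.embedding_lookup(self.pos_embed, self.position_ids1)  # self.pos_embed is assumed
    pos2_emb = tf.nn.embedding_lookup(self.pos_embed, self.position_ids2)
    et1_emb  = tf.nn.embedding_lookup(self.et_embed, self.et_ids1)         # self.et_embed is assumed
    et2_emb  = tf.nn.embedding_lookup(self.et_embed, self.et_ids2)
    return tf.concat([word_emb, pos1_emb, pos2_emb, et1_emb, et2_emb], axis=-1)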

  

 

Reproduced from www.cnblogs.com/huadongw/p/11483730.html