tfrecords文件在训练时所占内存过大的原因
项目中需要把tiff格式的图像先制作成tensorflow的tfrecords格式,然后在训练时读取tfrecords文件为tf.data.Dataset 进行训练。而在读取tfrecords文件生成Dataset时一般都需要设置shuffle的大小,目的是在训练过程中,每个step都随机从Dataset中随机的获取一个batch的数据进行训练,能较好的防止过拟合的问题。
dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
首先需要说一下shuffle的原理,shuffle操作后,在训练开始前,会从dataset的第一个元素开始(一个样本或者一个batch)依次取buffer_size个元素放入了缓存中,然后训练时随机从缓存的元素中取一个用于第一个step的训练,并取dataset的下一个元素放入缓存替换到取出来训练的元素的位置。之后每个的训练step重复以上的操作。
更为具体的对dataset的操作性能的讲解可以参考这篇博客:TensorFlow 高性能数据输入管道设计指南
所以可以想象一下,当我们的buffer_size设置的比较小时,这个时候就不能实现完全的随机性抽取训练样本,这样会影响训练的效果。比如我们有一个二分类的任务对猫和狗进行分类,如果buffer_size设置较小,那么前面的epoch的训练模型只能看到猫的图像,看不到狗的图像,影响训练效果。
这时往往需要设置buffer_size=训练元素的数目 来实现均匀的随机化。而当我们的训练集很大的时候,设置buffer_size=训练元素的数目会将整个训练集放入缓存中,会占用很大的内存,甚至会内存溢出。
实现均匀随机化又要节省大量内存的解决方案
我们可以使用下面的方法实现训练样本的均匀随机化:
1.首先将训练样本随机的制作成很多个小的tfrecords文件
2.使用tensorflow的interleave函数读取多个tfrecordds文件前将tfrecordds文件先随机排列
3.设置shuffle的buffer_size=小的tfrecords文件的元素数目
由于buffer_size设置的较小,所以训练占的内存就很小,这样又实现了均匀随机化,减缓了过拟合的问题。
具体的实现代码
1)将tiff格式的图像随机的制作成很多个小的tfrecords文件
import os
import tensorflow as tf
from PIL import Image
import glob
from skimage import io
import numpy as np
import time
import random
def img2TFRecord(img_dir,label_dir,tfrecords_dir,num_examples_per_tfrecords_txt_file,num_examples_per_tfrecords):
"""
将图像和对应的标签制作成TFRecord文件,之后使用tf.data.TFRecordDataset读取TFRecord文件制作成dataset
:param img_dir: 图像的文件夹
:param label_dir: 标签的文件夹
:param tfrecords_dir: 制作成的tfrecords文件夹
:param num_examples_per_tfrecords_txt_file: 每个tfrecords文件包含的样本数的txt文件
:param num_examples_per_tfrecords: 每个tfrecords文件包含的样本数
:return:
"""
t0 = time.time()
img_filenames = glob.glob(img_dir+'/*.tif')
# 对列表进行随机打乱
random.shuffle(img_filenames)
print(img_filenames[:10])
# 把样本数目写入在一个txt中,之后需要用
with open(num_examples_per_tfrecords_txt_file,'w') as f:
f.write(str(num_examples_per_tfrecords))
# 每num_examples_per_tfrecord生成一个tfrecords文件
i = 0
index = 0
# len(img_filenames)
while i < len(img_filenames):
# 要生成的tfrecord文件
tfrecords_file = tfrecords_dir + '/' + 'train_' + str(index) + '.tfrecords'
print("准备生成%s..." % (tfrecords_file))
writer = tf.python_io.TFRecordWriter(tfrecords_file)
for j in range(num_examples_per_tfrecords):
if i < len(img_filenames):
# windows的操作
img_name = img_filenames[i].split("\\")[1]
# linux的操作需要将上一行代码替换为下面的代码
# img_name = img_filenames[i].split("/")[-1]
# 对应的label文件
label_filename = label_dir + '/' + img_name
img = io.imread(img_filenames[i])
label = io.imread(label_filename)
# 将图像和Label转化为二进制格式
img_raw = img.tobytes()
label_raw = label.tobytes()
# 对label和img进行封装
example = tf.train.Example(features=tf.train.Features(feature={
"label_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_raw])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
# 序列化为字符串
writer.write(example.SerializeToString())
i += 1
writer.close()
print("%s 制作完毕!"%(tfrecords_file))
index += 1
t1 = time.time()
print("用时为%s s"%(str(t1-t0)))
2)读取tfrecords文件为Dataset
import tensorflow as tf
import glob
import numpy as np
import os
from skimage import io
########################使用tfrecords文件创建dataset##############################################
img_shape = (256,256,3)
# 解析TFRecorddataset
def _parse_function(example_proto):
features = {"img_raw": tf.FixedLenFeature((), tf.string),
"label_raw": tf.FixedLenFeature((), tf.string)}
parsed_features = tf.parse_single_example(example_proto, features)
img_str = parsed_features["img_raw"]
label_str = parsed_features["label_raw"]
# 对 tf.string进行解码
img = tf.decode_raw(img_str,tf.uint8)
img = tf.reshape(img,img_shape)
label = tf.decode_raw(label_str,tf.uint8)
label = tf.reshape(label,img_shape[0:2])
# 一个label在keras中要求有3-D
label = tf.expand_dims(label, axis=-1)
# 归一化
img = tf.cast(img,tf.float32) * (1.0/255)
label = tf.cast(label,tf.float32)
# 注意当label取 0,255时需要去掉下面的注释;当label取 0,1时不需要
# label = label * (1.0/255)
return img, label
def get_dataset_from_tfrecords(tfrecords_pattern,threads=1,batch_size=1,shuffle=True,shuffle_buffer_size=1,cycle_length=1):
"""
使用tfrecords文件创建dataset
:param ttfrecords_pattern:tfrecords文件的模板,如train_*.tfrecords
:param threads: map操作的多线程数
:param batch_size: 批次大小
:param shuffle: 是否随机打乱
:param shuffle_buffer_size: shuffle操作的buffer_size
:param cycle_length: interleave同时读取的文件数目
:return: DataSet
"""
with tf.device("/cpu:0"):
# 读取tfrecords文件的路径,并随机打乱
files = tf.data.Dataset.list_files(tfrecords_pattern, shuffle=True)
dataset = files.interleave(map_func=tf.data.TFRecordDataset, cycle_length=cycle_length)
dataset = dataset.map(_parse_function, num_parallel_calls=threads)
if shuffle:
dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
dataset = dataset.batch(batch_size).repeat()
# 提高性能
dataset.prefetch(batch_size)
return dataset