解决tensorflow的tfrecords文件在训练时所占内存过大的问题

tfrecords文件在训练时所占内存过大的原因

项目中需要把tiff格式的图像先制作成tensorflow的tfrecords格式,然后在训练时读取tfrecords文件为tf.data.Dataset 进行训练。而在读取tfrecords文件生成Dataset时一般都需要设置shuffle的大小,目的是在训练过程中,每个step都随机从Dataset中随机的获取一个batch的数据进行训练,能较好的防止过拟合的问题。

dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

首先需要说一下shuffle的原理,shuffle操作后,在训练开始前,会从dataset的第一个元素开始(一个样本或者一个batch)依次取buffer_size个元素放入了缓存中,然后训练时随机从缓存的元素中取一个用于第一个step的训练,并取dataset的下一个元素放入缓存替换到取出来训练的元素的位置。之后每个的训练step重复以上的操作。
更为具体的对dataset的操作性能的讲解可以参考这篇博客:TensorFlow 高性能数据输入管道设计指南
所以可以想象一下,当我们的buffer_size设置的比较小时,这个时候就不能实现完全的随机性抽取训练样本,这样会影响训练的效果。比如我们有一个二分类的任务对猫和狗进行分类,如果buffer_size设置较小,那么前面的epoch的训练模型只能看到猫的图像,看不到狗的图像,影响训练效果。
这时往往需要设置buffer_size=训练元素的数目 来实现均匀的随机化。而当我们的训练集很大的时候,设置buffer_size=训练元素的数目会将整个训练集放入缓存中,会占用很大的内存,甚至会内存溢出。

实现均匀随机化又要节省大量内存的解决方案

我们可以使用下面的方法实现训练样本的均匀随机化:
1.首先将训练样本随机的制作成很多个小的tfrecords文件
2.使用tensorflow的interleave函数读取多个tfrecordds文件前将tfrecordds文件先随机排列
3.设置shuffle的buffer_size=小的tfrecords文件的元素数目
由于buffer_size设置的较小,所以训练占的内存就很小,这样又实现了均匀随机化,减缓了过拟合的问题。

具体的实现代码

1)将tiff格式的图像随机的制作成很多个小的tfrecords文件

import os
import tensorflow as tf
from PIL import Image
import glob
from skimage import io
import numpy as np
import time
import random

def img2TFRecord(img_dir,label_dir,tfrecords_dir,num_examples_per_tfrecords_txt_file,num_examples_per_tfrecords):
    """
    将图像和对应的标签制作成TFRecord文件,之后使用tf.data.TFRecordDataset读取TFRecord文件制作成dataset
    :param img_dir: 图像的文件夹
    :param label_dir: 标签的文件夹
    :param tfrecords_dir: 制作成的tfrecords文件夹
    :param num_examples_per_tfrecords_txt_file: 每个tfrecords文件包含的样本数的txt文件
    :param num_examples_per_tfrecords: 每个tfrecords文件包含的样本数
    :return:
    """
    t0 = time.time()
    img_filenames = glob.glob(img_dir+'/*.tif')

    # 对列表进行随机打乱
    random.shuffle(img_filenames)

    print(img_filenames[:10])

    # 把样本数目写入在一个txt中,之后需要用
    with open(num_examples_per_tfrecords_txt_file,'w') as f:
        f.write(str(num_examples_per_tfrecords))

    # 每num_examples_per_tfrecord生成一个tfrecords文件
    i = 0
    index = 0
    # len(img_filenames)
    while i < len(img_filenames):
        # 要生成的tfrecord文件
        tfrecords_file = tfrecords_dir + '/' + 'train_' + str(index)  + '.tfrecords'
        print("准备生成%s..." % (tfrecords_file))
        writer = tf.python_io.TFRecordWriter(tfrecords_file)
        for j in range(num_examples_per_tfrecords):
            if i < len(img_filenames):
            	# windows的操作
                img_name = img_filenames[i].split("\\")[1]
                # linux的操作需要将上一行代码替换为下面的代码
                # img_name = img_filenames[i].split("/")[-1]
                # 对应的label文件
                label_filename = label_dir + '/' + img_name
                img = io.imread(img_filenames[i])
                label = io.imread(label_filename)
                # 将图像和Label转化为二进制格式
                img_raw = img.tobytes()
                label_raw = label.tobytes()
                # 对label和img进行封装
                example = tf.train.Example(features=tf.train.Features(feature={
                    "label_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_raw])),
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                }))
                # 序列化为字符串
                writer.write(example.SerializeToString())
                i += 1
        writer.close()
        print("%s 制作完毕!"%(tfrecords_file))
        index += 1

    t1 = time.time()
    print("用时为%s s"%(str(t1-t0)))

2)读取tfrecords文件为Dataset

import tensorflow as tf
import glob
import numpy as np
import os
from skimage import io

########################使用tfrecords文件创建dataset##############################################
img_shape = (256,256,3)

# 解析TFRecorddataset
def _parse_function(example_proto):
  features = {"img_raw": tf.FixedLenFeature((), tf.string),
              "label_raw": tf.FixedLenFeature((), tf.string)}
  parsed_features = tf.parse_single_example(example_proto, features)
  img_str = parsed_features["img_raw"]
  label_str = parsed_features["label_raw"]
  # 对 tf.string进行解码
  img = tf.decode_raw(img_str,tf.uint8)
  img = tf.reshape(img,img_shape)

  label = tf.decode_raw(label_str,tf.uint8)
  label = tf.reshape(label,img_shape[0:2])
  # 一个label在keras中要求有3-D
  label = tf.expand_dims(label, axis=-1)

  # 归一化
  img = tf.cast(img,tf.float32) * (1.0/255)
  label = tf.cast(label,tf.float32)
  # 注意当label取 0,255时需要去掉下面的注释;当label取 0,1时不需要
  # label = label  * (1.0/255)

  return img, label

def get_dataset_from_tfrecords(tfrecords_pattern,threads=1,batch_size=1,shuffle=True,shuffle_buffer_size=1,cycle_length=1):
    """
    使用tfrecords文件创建dataset
    :param ttfrecords_pattern:tfrecords文件的模板,如train_*.tfrecords
    :param threads: map操作的多线程数
    :param batch_size: 批次大小
    :param shuffle: 是否随机打乱
    :param shuffle_buffer_size: shuffle操作的buffer_size
    :param cycle_length: interleave同时读取的文件数目
    :return: DataSet
    """
    with tf.device("/cpu:0"):
        # 读取tfrecords文件的路径,并随机打乱
        files = tf.data.Dataset.list_files(tfrecords_pattern, shuffle=True)
        dataset = files.interleave(map_func=tf.data.TFRecordDataset, cycle_length=cycle_length)
        dataset = dataset.map(_parse_function, num_parallel_calls=threads)
        if shuffle:
            dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
        dataset = dataset.batch(batch_size).repeat()
        # 提高性能
        dataset.prefetch(batch_size)
    return dataset
发布了37 篇原创文章 · 获赞 6 · 访问量 5407

猜你喜欢

转载自blog.csdn.net/qq_37891889/article/details/102883038
今日推荐