Creating and reading TFRecord-format data (based on the Cats vs. Dogs and CIFAR-10 datasets)

1. Make a TFRecord dataset

  Opening images with Image.open() takes up a lot of memory. Here I use tf.gfile.FastGFile instead, so I suggest you try tf.gfile.FastGFile(path, 'rb') to read the image bytes;
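
The example function below uses the usual int64/bytes feature wrappers, which are not shown in the original; a minimal sketch, assuming TensorFlow 1.x:

import os
import glob
import numpy as np
import tensorflow as tf

def _int64_feature(value):
    # wrap an integer label in a tf.train.Feature
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    # wrap raw image bytes in a tf.train.Feature
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))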

def create_tf_example(img_list, label, sess):
    #image = Image.open(img_list)   # reading images with PIL / skimage / cv2 takes up a lot of memory
    #image = image.resize((300, 300))
    #image = image.tobytes()
    with tf.gfile.FastGFile(img_list, 'rb') as fid:
        img = fid.read()   # raw encoded image bytes

    ## Preprocessing here works but is slow; it is better to resize the images in advance
    #img = tf.image.decode_png(img, channels=3)   # can also be decoded to 1 channel
    #img = tf.image.resize_image_with_crop_or_pad(img, 40, 40)   # very slow, but pads with black or center-crops, which gives better results
    #img = tf.image.resize_images(img, [40, 40])   # faster, but still slow
    #image = sess.run(img)
    #image_bytes = image.tobytes()   # convert the tensor to bytes; note this format is decoded differently
    example = tf.train.Example(features=tf.train.Features(
        feature={
            'label': _int64_feature(label),
            'img_raw': _bytes_feature(img),
            #'width': _int64_feature(width),
            #'height': _int64_feature(height)
        }))

    return example   # returns a writable example
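
As the comments above note, decoding and resizing inside the writer is slow, so it is better to resize the images in advance. A minimal sketch of pre-resizing with PIL (the resize_in_advance helper and output directory are hypothetical, not part of the original):

from PIL import Image

def resize_in_advance(img_paths, out_dir, size=(300, 300)):
    # resize each image once, before the TFRecord is created,
    # to avoid slow per-example preprocessing in create_tf_example
    os.makedirs(out_dir, exist_ok=True)
    for path in img_paths:
        img = Image.open(path).convert('RGB')
        img = img.resize(size)
        img.save(os.path.join(out_dir, os.path.basename(path)))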

2. Read the Cats vs. Dogs data and generate a record:

filepath = '/Users/***/Git_Mac/Cat_Vs_Dog/train/'   # cat/dog images under the root directory
out_dir = 'Record/cat_dogs_2.record'

def Creat_Cats_Vs_Dogs(file_dir, out_dir):
    cats = []
    label_cats = []
    dogs = []
    label_dogs = []
    for file in os.listdir(file_dir):   # collect the paths of all images
        name = file.split(sep='.')
        if name[0] == 'cat':
            cats.append(file_dir + file)
            label_cats.append(0)
        else:
            dogs.append(file_dir + file)
            label_dogs.append(1)
    print('There are %d cats\nThere are %d dogs' % (len(cats), len(dogs)))

    image_list = np.hstack((cats, dogs))   # combine the data
    label_list = np.hstack((label_cats, label_dogs))

    temp = np.array([image_list, label_list])
    temp = temp.transpose()   # transpose
    np.random.shuffle(temp)   # shuffle the data

    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    label_list = [int(i) for i in label_list]

    sess = tf.Session()
    write = tf.python_io.TFRecordWriter(out_dir)
    count = 0
    for img, lbe in zip(image_list, label_list):
        example = create_tf_example(img, lbe, sess)   # build an example for each image and write it
        write.write(example.SerializeToString())

        count += 1
        if count % 1000 == 0:   # print progress
            print(count)

    sess.close()
    write.close()
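
With the paths defined above, one call writes the whole record file:

Creat_Cats_Vs_Dogs(filepath, out_dir)   # writes Record/cat_dogs_2.record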

3. Read the CIFAR-10 data and generate a record:

path = 'Git_Mac/cifar10'    # data root directory
def Create_Cifar10_Record(path,out_dir):
    write = tf.python_io.TFRecordWriter(out_dir)

    label_list = [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ] # labels
    sess = tf.Session()

    for index,directory in zip(label_list,['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']):
        img_list = glob.glob(os.path.join(path,'{}/*.png'.format(directory)))   # read all image paths
        count = 0

        for img in img_list:
            example = create_tf_example(img,index,sess=sess)
            write.write(example.SerializeToString())
            count+=1
            if count % 100 == 0:   # check the label and progress
                print(count)

    sess.close()
    write.close()
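
As above, a single call writes the whole CIFAR-10 record; the output file name here is only an example:

Create_Cifar10_Record(path, 'Record/cifar10.record')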

4. Read data from the TFRecord format

   Read data from the record file and decode it

def read_and_decode(tfrecords_file, batch_size, shuffle,n_class,one_hot = False):
    filename_queue = tf.train.string_input_producer([tfrecords_file])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  

    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
            #'width': tf.FixedLenFeature([], tf.int64),
            #'height': tf.FixedLenFeature([], tf.int64)
        })

    #image = tf.decode_raw(img_features['img_raw'], tf.uint8)
    #width = tf.cast(img_features['width'], tf.int32)
    #height = tf.cast(img_features['height'], tf.int32)

    img = img_features['img_raw']
    img = tf.image.decode_png(img, channels=3)   # decode a PNG image; for JPEG use decode_jpeg()
    image = tf.reshape(img, [32, 32, 3])   # 32*32*3; adjust this to your own image size
    label = tf.cast(img_features['label'], tf.int32)
    #image = tf.reshape(image, [300,300,3])
    image = tf.image.per_image_standardization(image)   # per-image standardization

    if shuffle:   # whether to shuffle the data; if capacity is too small the data will not mix thoroughly, and shuffled reading uses a lot of memory
        image_batch, label_batch = tf.train.shuffle_batch(
            [image, label],
            batch_size = batch_size,
            num_threads= 64,
            capacity = 20000,
            min_after_dequeue = 1000)
    else:
        image_batch, label_batch= tf.train.batch(
            [image,label],
            batch_size = batch_size,
            num_threads = 64,
            capacity= 2000)

    image_batch = tf.cast(image_batch, tf.float32)   # convert to tf.float32

    if one_hot:    # generate one-hot labels; different losses expect different label formats
        label_batch = tf.one_hot(label_batch, depth= n_class)
        label_batch = tf.cast(label_batch, dtype=tf.int32)
        label_batch = tf.reshape(label_batch, [batch_size, n_class])

    return image_batch, label_batch
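
The one_hot flag matters because the two label formats pair with different loss functions; a minimal sketch, assuming a logits tensor produced by your own model:

# one_hot=False: integer labels -> sparse softmax cross-entropy
loss = tf.losses.sparse_softmax_cross_entropy(labels=label_batch, logits=logits)
# one_hot=True: one-hot labels -> ordinary softmax cross-entropy
loss = tf.losses.softmax_cross_entropy(onehot_labels=label_batch, logits=logits)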

 Read the data with queue-runner threads

def Read_Record(filepath):
    with tf.Session() as sess:   # start a session
        image,label = read_and_decode(filepath,batch_size=batch_size,shuffle=True,n_class=2,one_hot=False)

        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        coord = tf.train.Coordinator()    # very important: coordinates the queue-runner threads
        threads= tf.train.start_queue_runners(coord=coord)
        try:
            for step in range(MAX_STEP):
                if coord.should_stop():
                    break
                img,lbe = sess.run([image,label])
                # add your own model training here
                #plot_images(img,lbe,batch_size=batch_size)

        except tf.errors.OutOfRangeError as e:
            print(e)
        finally:
            coord.request_stop()

        coord.join(threads)
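
Note that batch_size and MAX_STEP are assumed to be module-level constants that the original does not show; a minimal sketch of driving the reader:

batch_size = 32     # assumed value
MAX_STEP = 1000     # assumed value

if __name__ == '__main__':
    Read_Record('Record/cat_dogs_2.record')   # the record file written in step 2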
