Tensorflow tf.train.batch

官网说明:

https://www.tensorflow.org/api_docs/python/tf/train/slice_input_producer

下面的程序中给出的是不完整的,但是为了说明一下tf.train.batch足够了。

image_batch,label_batch=tf.train.batch([image,label],batch_size=batch_size,capacity=capacity)

看官网可以知道该函数返回一系列的tensor, 但是究竟是如何做到的呢?这个是让人很困惑的,所以我就是用程序来实验了一下。

下面的输入参数:

(1)image_list:是图片的完整路径

(2)label_list:是已经生成的one_hot 标签

(3)image_w, image_h 是要把图片resize后得到的尺寸

(4)batch_size,capacity 不做过多解释

def get_batch(image_list,label_list,image_w,image_h,batch_size,capacity):
    
    image=tf.cast(image_list,tf.string)
    label=tf.cast(label_list,tf.int32)
    # create file_queue
    input_queue=tf.train.slice_input_producer([image,label],shuffle=True)
    
    label=input_queue[1]
    image_contents=tf.read_file(input_queue[0])
    image=tf.image.decode_jpeg(image_contents,channels=3)
    
    image=tf.image.resize_image_with_crop_or_pad(image,image_w,image_h)
    image=tf.image.per_image_standardization(image)
    #create a list of tensor(every tensor represent a batch)
    image_batch,label_batch=tf.train.batch([image,label],batch_size=batch_size,capacity=capacity)
    #label_batch=tf.reshape(label_batch,[batch_size])  # I con't know why need this line 
    return image_batch,label_batch

valid_list,valid_label=get_list(valid_path)
valid_label=make_one_hot_label(valid_label)
image_batch,label_batch=get_batch(valid_list,valid_label,224,224,8,64)

for i in range(4):
    image_batch,label_batch=get_batch(valid_list,valid_label,224,224,8,64)
    print(image_batch) 
    print(label_batch)
    print('------')

程序输出结果:

Tensor("batch_1:0", shape=(8, 224, 224, 3), dtype=float32)
Tensor("batch_1:1", shape=(8, 100), dtype=int32)
------
Tensor("batch_2:0", shape=(8, 224, 224, 3), dtype=float32)
Tensor("batch_2:1", shape=(8, 100), dtype=int32)
------
Tensor("batch_3:0", shape=(8, 224, 224, 3), dtype=float32)
Tensor("batch_3:1", shape=(8, 100), dtype=int32)
------
Tensor("batch_4:0", shape=(8, 224, 224, 3), dtype=float32)
Tensor("batch_4:1", shape=(8, 100), dtype=int32)
------

每次运行tf.train.batch将得到一个tensor,每个tensor是二维的,因为包含image和label

猜你喜欢

转载自blog.csdn.net/yuanliang861/article/details/83007158
今日推荐