Official documentation:
https://www.tensorflow.org/api_docs/python/tf/train/slice_input_producer
The program below is incomplete, but it is enough to illustrate tf.train.batch.
image_batch,label_batch=tf.train.batch([image,label],batch_size=batch_size,capacity=capacity)
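According to the official docs, this function returns a list of tensors, but how does that actually work? I found this confusing, so I wrote a small program to experiment.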
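The input parameters are:
(1) image_list: the full paths of the image files
(2) label_list: the labels, already converted to one-hot vectors
(3) image_w, image_h: the size the images are cropped or padded to
(4) batch_size, capacity: the batch size and the queue capacity passed to tf.train.batch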
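The one-hot labels in item (2) come from a helper, make_one_hot_label, that is not shown in this post. As a rough idea of what such a helper might do, here is a minimal sketch. The name one_hot_sketch, the NumPy dependency, and the assumption that the raw labels are integer class ids are all mine, not taken from the original program.

```python
import numpy as np


def one_hot_sketch(class_ids, num_classes):
    """Hypothetical stand-in for make_one_hot_label (not shown in the post)."""
    class_ids = np.asarray(class_ids, dtype=np.int64)
    # Row i of the identity matrix is the one-hot vector for class i,
    # so indexing with the class ids picks one row per label.
    return np.eye(num_classes, dtype=np.int32)[class_ids]


labels = one_hot_sketch([0, 2, 1], num_classes=3)
print(labels)
```

With 100 classes, each label becomes a length-100 int32 vector, which would match the (8, 100) label shape printed further down.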
import tensorflow as tf

def get_batch(image_list, label_list, image_w, image_h, batch_size, capacity):
    image = tf.cast(image_list, tf.string)
    label = tf.cast(label_list, tf.int32)
    # create the input queue: yields one (path, label) pair at a time
    input_queue = tf.train.slice_input_producer([image, label], shuffle=True)
    label = input_queue[1]
    image_contents = tf.read_file(input_queue[0])
    image = tf.image.decode_jpeg(image_contents, channels=3)
    # the argument order is (image, target_height, target_width)
    image = tf.image.resize_image_with_crop_or_pad(image, image_h, image_w)
    image = tf.image.per_image_standardization(image)
    # create a list of tensors (each tensor represents one batch)
    image_batch, label_batch = tf.train.batch([image, label], batch_size=batch_size, capacity=capacity)
    # label_batch = tf.reshape(label_batch, [batch_size])  # I don't know why this line would be needed; it only works for scalar labels, not one-hot vectors
    return image_batch, label_batch
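valid_list, valid_label = get_list(valid_path)
valid_label = make_one_hot_label(valid_label)
# this first call creates the op named "batch", so the calls in the loop start at "batch_1"
image_batch, label_batch = get_batch(valid_list, valid_label, 224, 224, 8, 64)
for i in range(4):
    # every call adds a new input pipeline and a new batch op to the graph
    image_batch, label_batch = get_batch(valid_list, valid_label, 224, 224, 8, 64)
    print(image_batch)
    print(label_batch)
    print('------')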
Program output:
Tensor("batch_1:0", shape=(8, 224, 224, 3), dtype=float32)
Tensor("batch_1:1", shape=(8, 100), dtype=int32)
------
Tensor("batch_2:0", shape=(8, 224, 224, 3), dtype=float32)
Tensor("batch_2:1", shape=(8, 100), dtype=int32)
------
Tensor("batch_3:0", shape=(8, 224, 224, 3), dtype=float32)
Tensor("batch_3:1", shape=(8, 100), dtype=int32)
------
Tensor("batch_4:0", shape=(8, 224, 224, 3), dtype=float32)
Tensor("batch_4:1", shape=(8, 100), dtype=int32)
------
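Each call to tf.train.batch adds a new batch op to the graph (batch_1, batch_2, ...) and returns two tensors, one per input: output :0 is the image batch with shape (8, 224, 224, 3), and output :1 is the label batch with shape (8, 100). What gets printed are only symbolic tensors. No image is actually read until these tensors are evaluated in a session with the queue runners started.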
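To see real batches rather than symbolic tensors, the batch op is normally built once and then evaluated repeatedly with sess.run. The sketch below is my own illustration, not the original program: it replaces the image files with small integer arrays so it runs without any data on disk, and it assumes TensorFlow 1.14+ or 2.x, where the queue-based API is available under tf.compat.v1. The function name fetch_batches and the toy data are hypothetical.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # the queue-based input API only works in graph mode


def fetch_batches(num_steps=3, batch_size=4):
    """Build ONE batch op, then pull several batches from it with sess.run."""
    graph = tf.Graph()
    results = []
    with graph.as_default():
        # Toy stand-ins for image_list / label_list (assumption: 8 samples).
        data = tf.constant(np.arange(8, dtype=np.int32))
        labels = tf.constant(np.arange(8, dtype=np.int32) * 10)

        # Yields one (item, label) pair at a time, like input_queue in get_batch.
        item, label = tf1.train.slice_input_producer([data, labels], shuffle=False)

        # Built once: two symbolic tensors, one per input.
        item_batch, label_batch = tf1.train.batch(
            [item, label], batch_size=batch_size, capacity=16)

        with tf1.Session(graph=graph) as sess:
            coord = tf1.train.Coordinator()
            threads = tf1.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for _ in range(num_steps):
                    # Each run dequeues the next batch_size elements.
                    results.append(sess.run([item_batch, label_batch]))
            finally:
                coord.request_stop()
                coord.join(threads)
    return results


batches = fetch_batches()
for items, labs in batches:
    print(items, labs)
```

The point of the design is that get_batch would be called once, outside the loop, and only sess.run sits inside the loop. Calling get_batch on every iteration, as in the experiment above, keeps growing the graph instead of fetching new data.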