1. Make a TFRecord dataset
Opening a picture with Image.open() takes up a lot of memory. Here I used tf.gfile.FastGFile instead, so I suggest you try tf.gfile.FastGFile(path, 'rb') to read the raw image bytes;
def create_tf_example(img_list, label, sess):
    """Build a tf.train.Example holding one encoded image and its integer label.

    Args:
        img_list: path to a single image file; the encoded PNG/JPEG bytes are
            stored as-is (decoding happens at read time in read_and_decode).
        label: integer class label.
        sess: tf.Session — only needed if the commented-out in-graph
            preprocessing below is re-enabled; otherwise unused.

    Returns:
        A tf.train.Example ready to be serialized into a TFRecord file.
    """
    # Using PIL/skimage/cv2 to read and decode images here takes a lot of
    # memory; reading raw bytes with tf.gfile avoids that.
    # image = Image.open(img_list)
    # image = image.resize((300, 300))
    # image = image.tobytes()
    with tf.gfile.FastGFile(img_list, 'rb') as fid:
        img = fid.read()
    # Optional in-graph preprocessing — slow; better to resize in advance:
    # img = tf.image.decode_png(img, channels=3)  # can also decode to 1 channel
    # img = tf.image.resize_image_with_crop_or_pad(img, 40, 40)  # pad/center-crop
    # img = tf.image.resize_images(img, [40, 40])  # faster, but still slow
    # image = sess.run(img)
    # image_bytes = image.tobytes()  # note: raw bytes need a different decode path
    example = tf.train.Example(features=tf.train.Features(feature={
        # fixed: key was ' label' (leading space), which did not match the
        # 'label' key used by tf.parse_single_example on the read side
        'label': _int64_feature(label),
        'img_raw': _bytes_feature(img),
        # 'width': _int64_feature(width),
        # 'height': _int64_feature(height)
    }))
    return example  # returns a writable example
2. Read the Cats-vs-Dogs data and generate a record:
filepath = '/Users/***/Git_Mac/Cat_Vs_Dog/train/' #cat_dog under the root directory out_dir = 'Record/cat_dogs_2.record'
def Creat_Cats_Vs_Dogs(file_dir, out_dir):
    """Scan `file_dir` for cat.*/dog.* images, shuffle them, and write a TFRecord to `out_dir`.

    Files named 'cat.*' get label 0, everything else label 1.
    """
    cats = []
    label_cats = []
    dogs = []
    label_dogs = []
    for file in os.listdir(file_dir):  # collect the path of every picture
        name = file.split(sep='.')
        if name[0] == 'cat':
            cats.append(file_dir + file)
            label_cats.append(0)
        else:
            dogs.append(file_dir + file)
            label_dogs.append(1)
    print('There are %d cats\nThere are %d dogs' % (len(cats), len(dogs)))
    image_list = np.hstack((cats, dogs))  # combined data
    label_list = np.hstack((label_cats, label_dogs))
    temp = np.array([image_list, label_list])
    temp = temp.transpose()  # transpose so each row is (path, label)
    np.random.shuffle(temp)  # shuffle the data
    image_list = list(temp[:, 0])
    # np.array stringified the labels; convert them back to int
    label_list = [int(i) for i in temp[:, 1]]
    sess = tf.Session()  # fixed: was 'sex = tf.Session()', leaving 'sess' undefined below
    write = tf.python_io.TFRecordWriter(out_dir)
    count = 0
    for img, lbe in zip(image_list, label_list):
        # generate an example for each image and write it
        example = create_tf_example(img, lbe, sess)
        write.write(example.SerializeToString())
        count += 1
        if count % 1000 == 0:  # progress
            print(count)
    write.close()  # fixed: writer was never closed, so records might not be flushed
    sess.close()   # fixed: session was never closed
3. Read Cifar10 data and generate record:
path = 'Git_Mac/cifar10'  # data root directory (one sub-folder per CIFAR-10 class)
def Create_Cifar10_Record(path, out_dir):
    """Convert a class-per-folder CIFAR-10 PNG dump under `path` into one TFRecord at `out_dir`.

    The folder index in `class_names` (0-9) is used as the integer label.
    """
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
    writer = tf.python_io.TFRecordWriter(out_dir)
    session = tf.Session()
    for index, directory in enumerate(class_names):
        # read the paths of all PNG images belonging to this class
        pattern = os.path.join(path, '{}/*.png'.format(directory))
        count = 0
        for img_path in glob.glob(pattern):
            example = create_tf_example(img_path, index, sess=session)
            writer.write(example.SerializeToString())
            count += 1
            if count % 100 == 0:  # show progress
                print(count)
    session.close()
    writer.close()
4. Reading data back from the TFRecord format
Read data from the record file and decode it:
def read_and_decode(tfrecords_file, batch_size, shuffle, n_class, one_hot=False):
    """Build a queue-based input pipeline reading (image, label) batches from a TFRecord file.

    Args:
        tfrecords_file: path to the TFRecord file written by create_tf_example.
        batch_size: number of examples per batch.
        shuffle: if True use tf.train.shuffle_batch, otherwise tf.train.batch.
        n_class: number of classes (used only when one_hot is True).
        one_hot: if True, labels are returned one-hot encoded as [batch_size, n_class].

    Returns:
        (image_batch, label_batch) tensors; images are standardized float32.
    """
    filename_queue = tf.train.string_input_producer([tfrecords_file])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
            #'width': tf.FixedLenFeature([], tf.int64),
            #'height': tf.FixedLenFeature([], tf.int64)
        })
    #image = tf.decode_raw(img_features['img_raw'], tf.uint8)
    #width = tf.cast(img_features['width'], tf.int32)
    #height = tf.cast(img_features['height'], tf.int32)
    img = img_features['img_raw']
    img = tf.image.decode_png(img, channels=3)  # decode PNG; for JPEG use decode_jpeg()
    image = tf.reshape(img, [32, 32, 3])  # 32*32*3 — change this to match your own image size
    label = tf.cast(img_features['label'], tf.int32)
    #image = tf.reshape(image, [300,300,3])
    image = tf.image.per_image_standardization(image)  # per-image standardization
    # Whether to shuffle the data order. If capacity is too small, mixing is
    # incomplete; shuffled reading also uses a lot of memory.
    if shuffle:
        image_batch, label_batch = tf.train.shuffle_batch(
            [image, label],
            batch_size = batch_size,
            num_threads= 64,
            capacity = 20000,
            min_after_dequeue = 1000)
    else:
        image_batch, label_batch= tf.train.batch(
            [image,label],
            batch_size = batch_size,
            num_threads = 64,
            capacity= 2000)
    image_batch = tf.cast(image_batch, tf.float32)  # convert to tf.float32
    # Generate one-hot labels; different loss setups expect different label formats.
    if(one_hot == True):
        label_batch = tf.one_hot(label_batch, depth= n_class)
        label_batch = tf.cast(label_batch, dtype=tf.int32)
        label_batch = tf.reshape(label_batch, [batch_size, n_class])
    return image_batch, label_batch
Reading the data with queue-runner threads:
def Read_Record(filepath):
    """Start queue-runner threads and pull batches from the TFRecord at `filepath`.

    NOTE(review): relies on module-level `batch_size` and `MAX_STEP` globals,
    which must be defined elsewhere in the file — confirm before running.
    """
    with tf.Session() as sess:  # open a session
        # Build the input pipeline ops inside the default graph.
        images_op, labels_op = read_and_decode(
            filepath, batch_size=batch_size, shuffle=True, n_class=2, one_hot=False)
        sess.run(tf.global_variables_initializer())
        coordinator = tf.train.Coordinator()  # important: coordinates clean thread shutdown
        runner_threads = tf.train.start_queue_runners(coord=coordinator)
        try:
            for _ in range(MAX_STEP):
                if coordinator.should_stop():
                    break
                images, labels = sess.run([images_op, labels_op])
                # plug in your own model training here
                # plot_images(images, labels, batch_size=batch_size)
        except tf.errors.OutOfRangeError as e:
            print(e)
        finally:
            coordinator.request_stop()
            coordinator.join(runner_threads)