Google Open Images Dataset V4 图片数据集详解2-分类快速下载

上节我们介绍了 Open Images V4 数据集的结构信息,这节我们来尝试真正下载相应的图片。整个数据集很大,有 561GB,这么大的数据量,对于学习者来说,传输和存储都是个问题。其实我最常用的方式是只下载某些(某个)对象的图片,这就够用了。根据上节讲的对应关系,以对象检测为例,我们可以写一个脚本,单独获取某些对象的图片。这节我们讲述如何快速下载一个乌龟(Tortoise)的图像集。我们先在 V4 的官网上浏览 Tortoise,差不多是这样:

一、安装 TensorFlow Object Detection API

安装脚本如下:
# Create an output directory under the filesystem root.
mkdir -p /output
cd /output/

# Download a pinned, older revision of tensorflow/models (the Object
# Detection API lives inside it); the newest revision had problems at the
# time of writing (2018-04-20).
wget https://github.com/tensorflow/models/archive/dcfe009a024854207c9067d785c105f5ebf5a01b.zip
unzip dcfe009a024854207c9067d785c105f5ebf5a01b.zip
mv models-dcfe009a024854207c9067d785c105f5ebf5a01b models
rm dcfe009a024854207c9067d785c105f5ebf5a01b.zip

# Install the Python dependencies.
pip install Cython
pip install pillow
pip install lxml
pip install jupyter
pip install matplotlib
pip install opencv-python
pip install pycocotools

# Compile the protobuf definitions, put the API (and slim) on PYTHONPATH,
# then run the model-builder test to verify the installation.
cd /output/models/research/
protoc object_detection/protos/*.proto --python_out=.
export PYTHONPATH="$PYTHONPATH:$(pwd):$(pwd)/slim"
python object_detection/builders/model_builder_test.py


下载代码github


二、根据关键字生成tfrecord

完整的 Python 脚本如下:
import  pandas as pd
import  numpy as np
import  os
import  tensorflow as tf
import  io
import  logging
import  random
import  sys
import  PIL.Image
import  hashlib
from  urllib  import  request
  
sys.path.append( "/output/models/research/" )
from  object_detection.utils  import  dataset_util
  
  
class open_image_dataset:
    """Fetch Open Images box metadata and build per-keyword TFRecords.

    Each split ("train", "val", "test") gets a local folder of that name
    holding ``image.csv``, ``box.csv`` and ``classname.csv``.
    ``create_tfrecord`` joins those tables, downloads every image that has a
    box labelled with the requested class name, and writes one
    ``tf.train.Example`` per box.

    NOTE(review): the download methods read the URL attributes
    ``classname_csv``, ``train_image_csv``, ``train_box_csv``,
    ``val_image_csv``, ``val_box_csv``, ``test_image_csv`` and
    ``test_box_csv`` from ``self``, but no definition of them is visible in
    this listing. They must be set on the class or instance before any
    download method is called.
    """

    def _download_metadata(self, folder, image_attr, box_attr):
        """Download the three metadata CSVs of one split into ``folder``.

        ``image_attr`` and ``box_attr`` are the *names* of the attributes
        holding the URLs; they are resolved lazily so that an attribute is
        only read when its file actually has to be fetched. Files that
        already exist are left untouched.
        """
        os.makedirs(folder, exist_ok=True)
        targets = (
            (image_attr, "image.csv"),
            (box_attr, "box.csv"),
            ("classname_csv", "classname.csv"),
        )
        for attr_name, csv_name in targets:
            csv_path = folder + "/" + csv_name
            if not os.path.exists(csv_path):
                request.urlretrieve(getattr(self, attr_name), csv_path)

    def download_test(self):
        """Download the test-split metadata CSVs into ``test/``."""
        print("start download test info")
        self._download_metadata("test", "test_image_csv", "test_box_csv")
        print("download test complete")

    def download_val(self):
        """Download the validation-split metadata CSVs into ``val/``."""
        self._download_metadata("val", "val_image_csv", "val_box_csv")
        print("download val complete")

    def download_train(self):
        """Download the training-split metadata CSVs into ``train/``."""
        self._download_metadata("train", "train_image_csv", "train_box_csv")
        print("download train complete")

    def create_tfrecord(self, folder, keywords):
        """Write ``<keywords>/<folder>/<keywords>.tfrecord`` for one class.

        Args:
            folder: split folder that already contains ``image.csv``,
                ``box.csv`` and ``classname.csv`` (see the download methods).
            keywords: class display name to select, e.g. ``"Tortoise"``.

        Every box whose class name equals ``keywords`` yields one example
        with a single box and class label 1. Images are downloaded from
        the ``OriginalURL`` column, embedded in the record, then deleted
        again. Rows whose image cannot be downloaded, cannot be opened, or
        is not a JPEG are skipped with a message.
        """
        image_csv_path = folder + "/image.csv"
        box_csv_path = folder + "/box.csv"
        classname_csv_path = folder + "/classname.csv"

        df_image = pd.read_csv(image_csv_path)
        df_box = pd.read_csv(box_csv_path)
        df_classname = pd.read_csv(classname_csv_path,
                                   names=['labelID', 'LabelName'])

        # Inner joins keep only boxes of the requested class that also have
        # an image row. This selects the same rows as a right join followed
        # by dropping the unmatched ones, without materialising the full
        # box and image tables.
        data = df_classname[df_classname['LabelName'] == keywords]
        data = pd.merge(data, df_box, left_on='labelID',
                        right_on='LabelName', how='inner')
        data = pd.merge(data, df_image, on='ImageID', how='inner')
        # pandas matches NaN join keys with each other, so filter them out.
        data = data[data['labelID'].notna() & data['ImageID'].notna()]

        folder_path = keywords + "/" + folder + "/"
        os.makedirs(folder_path, exist_ok=True)

        tfrecord_file = folder_path + keywords + ".tfrecord"
        writer = tf.python_io.TFRecordWriter(tfrecord_file)
        try:
            for _, row in data.iterrows():
                file_name = row['ImageID'] + ".jpg"
                file_path = folder_path + file_name
                if not os.path.exists(file_path):
                    try:
                        request.urlretrieve(row['OriginalURL'], file_path)
                    except OSError as err:
                        # urllib's URLError/HTTPError derive from OSError.
                        # One dead link must not abort the whole export.
                        print("download error " + file_path + " " + str(err))
                        if os.path.exists(file_path):
                            os.remove(file_path)
                        continue

                with tf.gfile.GFile(file_path, 'rb') as fid:
                    encoded_jpg = fid.read()

                # Read format and size inside a context manager so the image
                # handle is released on every path.
                try:
                    with PIL.Image.open(io.BytesIO(encoded_jpg)) as image:
                        image_format = image.format
                        width, height = image.size
                except OSError:
                    # Not an image PIL can open (e.g. an HTML error page).
                    image_format = None
                if image_format != 'JPEG':
                    print("file format error " + file_path)
                    os.remove(file_path)
                    continue

                key = hashlib.sha256(encoded_jpg).hexdigest()

                # One box per example; coordinates are stored as read from
                # the CSV.
                xmin = [float(row['XMin'])]
                xmax = [float(row['XMax'])]
                ymin = [float(row['YMin'])]
                ymax = [float(row['YMax'])]
                classes = [1]
                classes_text = [keywords.encode('utf8')]

                example = tf.train.Example(features=tf.train.Features(feature={
                    'image/height': dataset_util.int64_feature(int(height)),
                    'image/width': dataset_util.int64_feature(int(width)),
                    'image/filename': dataset_util.bytes_feature(file_name.encode('utf8')),
                    'image/source_id': dataset_util.bytes_feature(file_name.encode('utf8')),
                    'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
                    'image/encoded': dataset_util.bytes_feature(encoded_jpg),
                    'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
                    'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
                    'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
                    'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
                    'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
                    'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
                    'image/object/class/label': dataset_util.int64_list_feature(classes),
                }))
                writer.write(example.SerializeToString())
                # The image bytes now live in the record; drop the local copy.
                os.remove(file_path)
                print("file " + file_path)
        finally:
            # Flush and close the record file even if a row fails.
            writer.close()
        print("create " + tfrecord_file + " success!")

    def create_train_tfrecord(self, keywords):
        """Download the training metadata, then build its TFRecord."""
        self.download_train()
        self.create_tfrecord("train", keywords)

    def create_val_tfrecord(self, keywords):
        """Download the validation metadata, then build its TFRecord."""
        self.download_val()
        self.create_tfrecord("val", keywords)

    def create_test_tfrecord(self, keywords):
        """Download the test metadata, then build its TFRecord."""
        self.download_test()
        self.create_tfrecord("test", keywords)

    def create_all_tfrecord(self, keywords):
        """Build the training and validation TFRecords (test is not included)."""
        self.create_train_tfrecord(keywords)
        self.create_val_tfrecord(keywords)
          
# Build the TFRecord for the "Tortoise" keyword from the test split; the
# wrapper fetches the test metadata CSVs first and then creates the record.
dataset = open_image_dataset()
dataset.create_test_tfrecord("Tortoise")

# To build the validation or training split instead:
# dataset.create_val_tfrecord("Tortoise")
# dataset.create_train_tfrecord("Tortoise")

# Or build the training and validation splits in one call:
# dataset.create_all_tfrecord("Tortoise")


三、对生成的tfrecord进行验证

验证脚本如下:
import  tensorflow as tf
import  numpy as np
import  os
import  skimage.io as io
import  cv2
# Read the generated TFRecord back with the TF1 queue-based input pipeline
# and display 20 records, each with its bounding box drawn on the image.
tfrecords_filename = "Tortoise/test/Tortoise.tfrecord"

filename_queue = tf.train.string_input_producer([tfrecords_filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

# Every feature is parsed as a scalar, so this only works for records that
# carry exactly one box and one class entry each.
features = tf.parse_single_example(
    serialized_example,
    features={
        'image/width': tf.FixedLenFeature([], tf.int64),
        'image/height': tf.FixedLenFeature([], tf.int64),
        'image/filename': tf.FixedLenFeature([], tf.string),
        'image/source_id': tf.FixedLenFeature([], tf.string),
        'image/key/sha256': tf.FixedLenFeature([], tf.string),
        'image/encoded': tf.FixedLenFeature([], tf.string),
        'image/format': tf.FixedLenFeature([], tf.string),
        'image/object/bbox/xmin': tf.FixedLenFeature([], tf.float32),
        'image/object/bbox/xmax': tf.FixedLenFeature([], tf.float32),
        'image/object/bbox/ymin': tf.FixedLenFeature([], tf.float32),
        'image/object/bbox/ymax': tf.FixedLenFeature([], tf.float32),
        'image/object/class/text': tf.FixedLenFeature([], tf.string),
        'image/object/class/label': tf.FixedLenFeature([], tf.int64),
    })

# Only width/height need a cast (int64 -> int32, for the reshape below);
# the other features already have the dtype they were parsed with.
width = tf.cast(features['image/width'], tf.int32)
height = tf.cast(features['image/height'], tf.int32)
filename = features['image/filename']
image_format = features['image/format']
xmin = features['image/object/bbox/xmin']
xmax = features['image/object/bbox/xmax']
ymin = features['image/object/bbox/ymin']
ymax = features['image/object/bbox/ymax']
text = features['image/object/class/text']
label = features['image/object/class/label']

# Force three channels so grayscale JPEGs also fit the [h, w, 3] reshape.
image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
image = tf.reshape(image, tf.stack([height, width, 3]))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        for _ in range(20):
            (width1, height1, filename1, format1, xmin1, xmax1, ymin1, ymax1,
             text1, label1, image1) = sess.run(
                [width, height, filename, image_format, xmin, xmax, ymin,
                 ymax, text, label, image])
            print(width1, height1, filename1, format1, xmin1, xmax1, ymin1,
                  ymax1, text1, label1)
            # Scale the box coordinates by the image size to get pixel corners.
            x1, y1 = int(xmin1 * width1), int(ymin1 * height1)
            x2, y2 = int(xmax1 * width1), int(ymax1 * height1)
            # np.array() copies the image so cv2 can draw on it in place.
            boxed = cv2.rectangle(np.array(image1), (x1, y1), (x2, y2),
                                  (0, 255, 0), 3)
            io.imshow(boxed, cmap='gray', interpolation='bicubic')
            io.show()
    finally:
        # Always stop the queue-runner threads, even if a read fails.
        coord.request_stop()
        coord.join(threads)

下载代码github

最终的结果如下:











猜你喜欢

转载自blog.csdn.net/wulala789/article/details/80671952