上节我们介绍了Open Images V4数据集的结构信息,这节我们来尝试真正下载相应的图片。整个数据集很大,有561GB,这么大的数据量对于学习者来说,传输和存储都是个问题。其实最常用的方式是只下载某些(或某个)对象的图片,这通常就够了。根据上节讲的关系,以对象检测为例,我们可以写一个脚本,单独地获取某些对象的图片。这节我们讲述如何快速下载一个乌龟(Tortoise)的图像集,我们先在V4的官网上浏览Tortoise,差不多是这样:
一、安装TensorFlow Object Detection API
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
|
# Create a working directory named "output" at the filesystem root
# (-p: no error if it already exists, so the script can be re-run).
mkdir -p /output
cd /output/
# Download a pinned snapshot of tensorflow/models (the Object Detection API
# lives inside it); the latest version had problems at the time of writing
# (2018-04-20).
wget https://github.com/tensorflow/models/archive/dcfe009a024854207c9067d785c105f5ebf5a01b.zip
unzip dcfe009a024854207c9067d785c105f5ebf5a01b.zip
mv models-dcfe009a024854207c9067d785c105f5ebf5a01b models
rm dcfe009a024854207c9067d785c105f5ebf5a01b.zip
# Install dependencies. Cython goes first on its own because pycocotools
# needs it at build time.
pip install Cython
pip install pillow lxml jupyter matplotlib opencv-python pycocotools
# Compile the protobuf definitions, then verify the Object Detection API.
cd /output/models/research/
protoc object_detection/protos/*.proto --python_out=.
# Make object_detection and slim importable in this shell session only.
# ${PYTHONPATH:+...} avoids a leading empty entry when PYTHONPATH is unset.
export PYTHONPATH="${PYTHONPATH:+$PYTHONPATH:}$(pwd):$(pwd)/slim"
python object_detection/builders/model_builder_test.py
|
|
下载代码github
二、根据关键字生成tfrecord
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
|
import
pandas as pd
import
numpy as np
import
os
import
tensorflow as tf
import
io
import
logging
import
random
import
sys
import
PIL.Image
import
hashlib
from
urllib
import
request
sys.path.append(
"/output/models/research/"
)
from
object_detection.utils
import
dataset_util
class open_image_dataset:
    """Build per-class TFRecord files from the Open Images V4 annotations.

    For each split, three annotation CSVs are cached in a local folder named
    after the split ("train", "val" or "test"):

    * ``image.csv``     -- one row per image, including its ``OriginalURL``;
    * ``box.csv``       -- one row per bounding box (XMin/XMax/YMin/YMax);
    * ``classname.csv`` -- label id -> class display name (no header row).

    ``create_tfrecord`` then downloads only the images that contain the
    requested class and writes ``<keywords>/<split>/<keywords>.tfrecord``.
    """

    # Per-split annotation CSVs of the Open Images V4 (2018_04) release.
    test_image_csv = "https://storage.googleapis.com/openimages/2018_04/test/test-images-with-rotation.csv"
    test_box_csv = "https://storage.googleapis.com/openimages/2018_04/test/test-annotations-bbox.csv"
    val_image_csv = "https://storage.googleapis.com/openimages/2018_04/validation/validation-images-with-rotation.csv"
    val_box_csv = "https://storage.googleapis.com/openimages/2018_04/validation/validation-annotations-bbox.csv"
    train_image_csv = "https://storage.googleapis.com/openimages/2018_04/train/train-images-boxable-with-rotation.csv"
    train_box_csv = "https://storage.googleapis.com/openimages/2018_04/train/train-annotations-bbox.csv"
    # Class-name table shared by every split (label id -> display name).
    classname_csv = "https://storage.googleapis.com/openimages/2018_04/class-descriptions-boxable.csv"

    @staticmethod
    def _fetch(url, path):
        """Download ``url`` to ``path`` unless ``path`` already exists.

        The transfer is written to ``path + ".part"`` and renamed into place
        only after it completes. An interrupted download therefore never
        leaves a truncated file that a later run would treat as finished.
        """
        if os.path.exists(path):
            return
        part_path = path + ".part"
        try:
            request.urlretrieve(url, part_path)
        except BaseException:
            # Do not leave a half-written temporary file behind; re-raise.
            if os.path.exists(part_path):
                os.remove(part_path)
            raise
        os.replace(part_path, path)

    def _download_split(self, folder, image_csv, box_csv):
        """Cache the image, box and class-name CSVs of one split in ``folder``."""
        os.makedirs(folder, exist_ok=True)
        self._fetch(image_csv, os.path.join(folder, "image.csv"))
        self._fetch(box_csv, os.path.join(folder, "box.csv"))
        self._fetch(self.classname_csv, os.path.join(folder, "classname.csv"))

    def download_test(self):
        """Download the test-split annotation CSVs into ./test (cached)."""
        print("start download test info")
        self._download_split("test", self.test_image_csv, self.test_box_csv)
        print("download test complete")

    def download_val(self):
        """Download the validation-split annotation CSVs into ./val (cached)."""
        self._download_split("val", self.val_image_csv, self.val_box_csv)
        print("download val complete")

    def download_train(self):
        """Download the training-split annotation CSVs into ./train (cached)."""
        self._download_split("train", self.train_image_csv, self.train_box_csv)
        print("download train complete")

    @staticmethod
    def _select_boxes(df_classname, df_box, df_image, keywords):
        """Return the box rows of class ``keywords`` joined with their image URL.

        ``df_classname`` maps ``labelID`` to the display name ``LabelName``.
        ``df_box["LabelName"]`` holds the same ids as ``labelID``. Boxes whose
        image is absent from ``df_image`` are dropped (inner join). The result
        keeps the box columns plus ``OriginalURL``.
        """
        label_ids = df_classname.loc[df_classname["LabelName"] == keywords, "labelID"]
        boxes = df_box[df_box["LabelName"].isin(label_ids)]
        return boxes.merge(df_image[["ImageID", "OriginalURL"]], on="ImageID", how="inner")

    def _write_image(self, writer, folder_path, image_id, boxes, keywords):
        """Download one image and write a single Example holding all its boxes.

        Returns True if a record was written. Returns False if the image was
        skipped (missing URL, failed download, or not a decodable JPEG).
        """
        file_name = image_id + ".jpg"
        file_path = os.path.join(folder_path, file_name)
        url = boxes["OriginalURL"].iloc[0]
        if not isinstance(url, str):
            print("missing url " + file_path)
            return False
        try:
            self._fetch(url, file_path)
        except (OSError, ValueError) as err:
            # urlretrieve raises URLError/HTTPError (OSError subclasses) for
            # unreachable URLs; skip this image instead of aborting the export.
            print("download error " + file_path + " " + str(err))
            return False

        with tf.gfile.GFile(file_path, "rb") as fid:
            encoded_jpg = fid.read()
        # The bytes are in memory now; the file on disk was only a buffer.
        os.remove(file_path)

        try:
            with PIL.Image.open(io.BytesIO(encoded_jpg)) as image:
                image_format = image.format
                width, height = image.size
        except OSError:
            # Pillow signals undecodable data with OSError subclasses.
            image_format = None
        if image_format != "JPEG":
            print("file format error " + file_path)
            return False

        count = len(boxes)
        encoded_name = file_name.encode("utf8")
        key = hashlib.sha256(encoded_jpg).hexdigest()
        example = tf.train.Example(features=tf.train.Features(feature={
            "image/height": dataset_util.int64_feature(int(height)),
            "image/width": dataset_util.int64_feature(int(width)),
            "image/filename": dataset_util.bytes_feature(encoded_name),
            "image/source_id": dataset_util.bytes_feature(encoded_name),
            "image/key/sha256": dataset_util.bytes_feature(key.encode("utf8")),
            "image/encoded": dataset_util.bytes_feature(encoded_jpg),
            "image/format": dataset_util.bytes_feature("jpeg".encode("utf8")),
            "image/object/bbox/xmin": dataset_util.float_list_feature(boxes["XMin"].astype(float).tolist()),
            "image/object/bbox/xmax": dataset_util.float_list_feature(boxes["XMax"].astype(float).tolist()),
            "image/object/bbox/ymin": dataset_util.float_list_feature(boxes["YMin"].astype(float).tolist()),
            "image/object/bbox/ymax": dataset_util.float_list_feature(boxes["YMax"].astype(float).tolist()),
            "image/object/class/text": dataset_util.bytes_list_feature([keywords.encode("utf8")] * count),
            # Single-class export: every box gets label id 1.
            "image/object/class/label": dataset_util.int64_list_feature([1] * count),
        }))
        writer.write(example.SerializeToString())
        print("file " + file_path)
        return True

    def create_tfrecord(self, folder, keywords):
        """Write every image of class ``keywords`` in split ``folder`` to a TFRecord.

        Args:
            folder: split folder already filled by a ``download_*`` method
                ("train", "val" or "test").
            keywords: class display name as listed in classname.csv,
                e.g. "Tortoise".

        The output is ``<keywords>/<folder>/<keywords>.tfrecord``. Each record
        holds one image and all of that image's boxes for the class. The
        ``image/object/*`` features are therefore lists with one entry per
        box, and readers should parse them as variable-length features.
        Images are downloaded one at a time and deleted once encoded.
        """
        df_image = pd.read_csv(
            os.path.join(folder, "image.csv"),
            usecols=["ImageID", "OriginalURL"],
            dtype=str,
        )
        df_box = pd.read_csv(
            os.path.join(folder, "box.csv"),
            usecols=["ImageID", "LabelName", "XMin", "XMax", "YMin", "YMax"],
            # Keep ids as strings so file_name = ImageID + ".jpg" always works.
            dtype={"ImageID": str, "LabelName": str},
        )
        df_classname = pd.read_csv(
            os.path.join(folder, "classname.csv"), names=["labelID", "LabelName"]
        )
        data = self._select_boxes(df_classname, df_box, df_image, keywords)
        if data.empty:
            print("no boxes found for " + keywords + " in " + folder)

        folder_path = os.path.join(keywords, folder)
        os.makedirs(folder_path, exist_ok=True)
        tfrecord_file = os.path.join(folder_path, keywords + ".tfrecord")
        writer = tf.python_io.TFRecordWriter(tfrecord_file)
        try:
            # One record per image: group the selected box rows by image id.
            for image_id, boxes in data.groupby("ImageID", sort=False):
                self._write_image(writer, folder_path, image_id, boxes, keywords)
        finally:
            writer.close()
        print("create " + tfrecord_file + " success!")

    def create_train_tfrecord(self, keywords):
        """Download the training CSVs (if needed) and build its TFRecord."""
        self.download_train()
        self.create_tfrecord("train", keywords)

    def create_val_tfrecord(self, keywords):
        """Download the validation CSVs (if needed) and build its TFRecord."""
        self.download_val()
        self.create_tfrecord("val", keywords)

    def create_test_tfrecord(self, keywords):
        """Download the test CSVs (if needed) and build its TFRecord."""
        self.download_test()
        self.create_tfrecord("test", keywords)

    def create_all_tfrecord(self, keywords):
        """Build TFRecords for the training and validation splits (not test)."""
        self.create_train_tfrecord(keywords)
        self.create_val_tfrecord(keywords)
# Build the TFRecord for the "Tortoise" class from the test split;
# create_test_tfrecord first fetches the test annotation CSVs, then exports.
dataset = open_image_dataset()
dataset.create_test_tfrecord("Tortoise")

# Other splits for the same class (uncomment as needed):
# dataset.create_val_tfrecord("Tortoise")    # validation split
# dataset.create_train_tfrecord("Tortoise")  # training split
# dataset.create_all_tfrecord("Tortoise")    # training + validation splits
|
三、对生成的tfrecord进行验证
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
|
import
tensorflow as tf
import
numpy as np
import
os
import
skimage.io as io
import
cv2
# TFRecord produced for the "Tortoise" class from the test split.
tfrecords_filename = "Tortoise/test/Tortoise.tfrecord"
# How many records to display; the input producer cycles through the file,
# so a file with fewer records simply repeats.
num_samples = 20

filename_queue = tf.train.string_input_producer([tfrecords_filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

# Box coordinates and class fields hold one entry per object, so they are
# parsed as variable-length features. FixedLenFeature([]) would reject any
# record that contains more than one box.
features = tf.parse_single_example(serialized_example, features={
    'image/width': tf.FixedLenFeature([], tf.int64),
    'image/height': tf.FixedLenFeature([], tf.int64),
    'image/filename': tf.FixedLenFeature([], tf.string),
    'image/source_id': tf.FixedLenFeature([], tf.string),
    'image/key/sha256': tf.FixedLenFeature([], tf.string),
    'image/encoded': tf.FixedLenFeature([], tf.string),
    'image/format': tf.FixedLenFeature([], tf.string),
    'image/object/bbox/xmin': tf.VarLenFeature(tf.float32),
    'image/object/bbox/xmax': tf.VarLenFeature(tf.float32),
    'image/object/bbox/ymin': tf.VarLenFeature(tf.float32),
    'image/object/bbox/ymax': tf.VarLenFeature(tf.float32),
    'image/object/class/text': tf.VarLenFeature(tf.string),
    'image/object/class/label': tf.VarLenFeature(tf.int64),
})

width = tf.cast(features['image/width'], tf.int32)
height = tf.cast(features['image/height'], tf.int32)
filename = features['image/filename']
image_format = features['image/format']
# .values turns each sparse feature into a dense 1-D tensor (one entry per box).
xmin = features['image/object/bbox/xmin'].values
xmax = features['image/object/bbox/xmax'].values
ymin = features['image/object/bbox/ymin'].values
ymax = features['image/object/bbox/ymax'].values
text = features['image/object/class/text'].values
label = features['image/object/class/label'].values
# channels=3 forces 3-channel output even for grayscale JPEGs, so the
# reshape to [height, width, 3] below always matches.
image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
image = tf.reshape(image, tf.stack([height, width, 3]))

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for _ in range(num_samples):
            (width1, height1, filename1, format1, xmin1, xmax1,
             ymin1, ymax1, text1, label1, image1) = sess.run(
                [width, height, filename, image_format, xmin, xmax,
                 ymin, ymax, text, label, image])
            print(width1, height1, filename1, format1, xmin1, xmax1, ymin1, ymax1, text1, label1)
            # Box coordinates are fractions of the image size; scale to pixels.
            canvas = np.array(image1)
            for x_lo, y_lo, x_hi, y_hi in zip(xmin1, ymin1, xmax1, ymax1):
                top_left = (int(x_lo * width1), int(y_lo * height1))
                bottom_right = (int(x_hi * width1), int(y_hi * height1))
                cv2.rectangle(canvas, top_left, bottom_right, (0, 255, 0), 3)
            io.imshow(canvas, cmap='gray', interpolation='bicubic')
            io.show()
    finally:
        # Always stop the queue-runner threads, even if a record fails.
        coord.request_stop()
        coord.join(threads)
|
最终的结果如下: