YOLOv4中的Mosaic数据增强方式实现代码
-
Mosaic是一种通过混合4张训练图像的数据增强方式,示例如下图,这样做也同时混合了四张图片的语义,目的在于使被检测的目标超出它们普遍的语义,使得模型具有更好的鲁棒性。同时,这样做使得训练时的批量归一化(Batch Normalization,BN)操作一次统计了4张图像,能够很好地降低训练时对大mini-batch的需求。
-
-
详细步骤与完整代码请参考文末链接;
- 裁剪图片为了拼接:
# Generate sub-image data
count = 0
sub_names = []
sub_infor = []
for img_name in person_img_names:
count += 1
img_path = "/trainval/VOCdevkit/VOC2012/JPEGImages/" + img_name + ".jpg"
xml_path = "/VOC2012/Annotations/" + img_name + ".xml"
image = cv2.imread(img_path)
img_info = []
with open(xml_path, "r") as new_f:
root = ET.parse(xml_path).getroot()
for obj in root.findall('object'):
obj_name = obj.find('name').text
bndbox = obj.find('bndbox')
left = bndbox.find('xmin').text
top = bndbox.find('ymin').text
right = bndbox.find('xmax').text
bottom = bndbox.find('ymax').text
img_info.append([obj_name, left, top, right, bottom])
print("Img", count, "- Num of Objs: ", len(img_info))
# Innitialize
new_w = 360//2
new_h = 360//2
# Crop Image
cropped_img_0 = image[0:new_h, 0:new_w]
cropped_img_1 = image[0:new_h, 360//2:360]
cropped_img_2 = image[360//2:360, 0:new_w]
cropped_img_3 = image[360//2:360, 360//2:360]
# TOP-LEFT
new_img_info_0 = []
for obj in img_info:
x1 = int(float(obj[1]))
y1 = int(float(obj[2]))
if x1 < new_w and y1 < new_h:
x2 = int(float(obj[3]))
y2 = int(float(obj[4]))
if x2 > new_w: x2 = new_w
if y2 > new_h: y2 = new_h
new_img_info_0.append([obj[0], x1, y1, x2, y2])
# TOP-RIGHT
new_img_info_1 = []
for obj in img_info:
x1 = int(float(obj[1])) - new_w
x2 = int(float(obj[3])) - new_w
y1 = int(float(obj[2]))
if x2 > 0 and y1 < new_h:
if x1 < 0: x1 = 0
y2 = int(float(obj[4]))
if y2 > new_h: y2 = new_h
new_img_info_1.append([obj[0], x1, y1, x2, y2])
# BOTTOM-LEFT
new_img_info_2 = []
for obj in img_info:
y1 = int(float(obj[2])) - new_h
y2 = int(float(obj[4])) - new_h
x1 = int(float(obj[1]))
if y2 > 0 and x1 < new_w:
if y1 < 0: y1 = 0
x2 = int(float(obj[3]))
if x2 > new_w: x2 = new_w
new_img_info_2.append([obj[0], x1, y1, x2, y2])
# BOTTOM-RIGHT
new_img_info_3 = []
for obj in img_info:
x1 = int(float(obj[1])) - new_w
y1 = int(float(obj[2])) - new_h
x2 = int(float(obj[3])) - new_w
y2 = int(float(obj[4])) - new_h
if x2 > 0 and y2 > 0:
if x1 < 0: x1 = 0
if y1 < 0: y1 = 0
new_img_info_3.append([obj[0], x1, y1, x2, y2])
for infor_i in new_img_info_0:
if infor_i[0] == "person":
# Write Image
cv2.imwrite("/mosaic/trainval/VOCdevkit/VOC2012/sub-images/" + img_name + "_0.jpg", cropped_img_0)
# Write Text
text_file = open("/mosaic/trainval/VOCdevkit/VOC2012/sub-annotations/" + img_name + "_0.txt", "w+")
for infor in new_img_info_0:
print(infor[0], infor[1], infor[2], infor[3], infor[4], file=text_file)
text_file.close()
break
for infor_i in new_img_info_1:
if infor_i[0] == "person":
# Write Image
cv2.imwrite("/mosaic/trainval/VOCdevkit/VOC2012/sub-images/" + img_name + "_1.jpg", cropped_img_1)
# Write Text
text_file = open("/mosaic/trainval/VOCdevkit/VOC2012/sub-annotations/" + img_name + "_1.txt", "w+")
for infor in new_img_info_1:
print(infor[0], infor[1], infor[2], infor[3], infor[4], file=text_file)
text_file.close()
break
for infor_i in new_img_info_2:
if infor_i[0] == "person":
# Write Image
cv2.imwrite("/mosaic/trainval/VOCdevkit/VOC2012/sub-images/" + img_name + "_2.jpg", cropped_img_2)
# Write Text
text_file = open("/mosaic/trainval/VOCdevkit/VOC2012/sub-annotations/" + img_name + "_2.txt", "w+")
for infor in new_img_info_2:
print(infor[0], infor[1], infor[2], infor[3], infor[4], file=text_file)
text_file.close()
break
for infor_i in new_img_info_3:
if infor_i[0] == "person":
# Write Image
cv2.imwrite("/mosaic/trainval/VOCdevkit/VOC2012/sub-images/" + img_name + "_3.jpg", cropped_img_3)
# Write Text
text_file = open("/mosaic/trainval/VOCdevkit/VOC2012/sub-annotations/" + img_name + "_3.txt", "w+")
for infor in new_img_info_3:
print(infor[0], infor[1], infor[2], infor[3], infor[4], file=text_file)
text_file.close()
break
# if count == 10: break
print(count)
- 拼接图像:
for a in range(1):
i = randint(0, 1000)
top_left = cv2.imread(sub_img_src[4*i])
top_right = cv2.imread(sub_img_src[4*i + 1])
bot_left = cv2.imread(sub_img_src[4*i + 2])
bot_right = cv2.imread(sub_img_src[4*i + 3])
plt.figure(figsize = (15, 10))
plt.subplot(221), plt.imshow(top_left)
plt.subplot(222), plt.imshow(top_right)
plt.subplot(223), plt.imshow(bot_left)
plt.subplot(224), plt.imshow(bot_right)
plt.show()
top = np.hstack((top_left, top_right))
bot = np.hstack((bot_left, bot_right))
full = np.vstack((top, bot))
plt.figure(figsize = (15, 10))
plt.imshow(cv2.cvtColor(full, cv2.COLOR_BGR2RGB))
plt.show()
for a in range(1):
# i = randint(0, 1000)
top_left = cv2.imread(sub_img_src[4*i])
anno = file_lines_to_list("/mosaic/trainval/VOCdevkit/VOC2012/sub-annotations/" + sub_img_names[4*i] + ".txt")
dymmy = top_left.copy()
for obj in anno:
rec_tf = cv2.rectangle(dymmy, (int(obj[1]), int(obj[2])), (int(obj[3]), int(obj[4])), (0, 255, 0), thickness = 1)
top_right = cv2.imread(sub_img_src[4*i + 1])
anno = file_lines_to_list("/mosaic/trainval/VOCdevkit/VOC2012/sub-annotations/" + sub_img_names[4*i + 1] + ".txt")
dymmy = top_right.copy()
for obj in anno:
rec_tr = cv2.rectangle(dymmy, (int(obj[1]), int(obj[2])), (int(obj[3]), int(obj[4])), (0, 255, 0), thickness = 1)
bot_left = cv2.imread(sub_img_src[4*i + 2])
anno = file_lines_to_list("/mosaic/trainval/VOCdevkit/VOC2012/sub-annotations/" + sub_img_names[4*i + 2] + ".txt")
dymmy = bot_left.copy()
for obj in anno:
rec_bl = cv2.rectangle(dymmy, (int(obj[1]), int(obj[2])), (int(obj[3]), int(obj[4])), (0, 255, 0), thickness = 1)
bot_right = cv2.imread(sub_img_src[4*i + 3])
anno = file_lines_to_list("/mosaic/trainval/VOCdevkit/VOC2012/sub-annotations/" + sub_img_names[4*i + 3] + ".txt")
dymmy = bot_right.copy()
for obj in anno:
rec_br = cv2.rectangle(dymmy, (int(obj[1]), int(obj[2])), (int(obj[3]), int(obj[4])), (0, 255, 0), thickness = 1)
plt.figure(figsize = (15, 10))
plt.subplot(221), plt.imshow(rec_tf)
plt.subplot(222), plt.imshow(rec_tr)
plt.subplot(223), plt.imshow(rec_bl)
plt.subplot(224), plt.imshow(rec_br)
plt.show()
reg_top = np.hstack((rec_tf, rec_tr))
reg_bot = np.hstack((rec_bl, rec_br))
reg_full = np.vstack((reg_top, reg_bot))
plt.figure(figsize = (15, 10))
plt.imshow(cv2.cvtColor(reg_full, cv2.COLOR_BGR2RGB))
plt.show()
- 为这些增强的图片,重写xml文件:
from xml.etree import ElementTree
from xml.dom import minidom
from xml.etree import ElementTree
from xml.etree.ElementTree import Element, SubElement
import os.path as osp
def prettify(elem):
    """Return a pretty-printed XML string (tab-indented) for the Element."""
    raw = ElementTree.tostring(elem, 'utf-8')
    return minidom.parseString(raw).toprettyxml(indent="\t")
w_offset = 180
h_offset = 180
for i in range(len(sub_img_src)//4):
bbox = []
top_left = cv2.imread(sub_img_src[4*i])
anno = file_lines_to_list("/mosaic/trainval/VOCdevkit/VOC2012/sub-annotations/" + sub_img_names[4*i] + ".txt")
for obj in anno:
bbox.append(obj)
top_right = cv2.imread(sub_img_src[4*i + 1])
anno = file_lines_to_list("/mosaic/trainval/VOCdevkit/VOC2012/sub-annotations/" + sub_img_names[4*i + 1] + ".txt")
for obj in anno:
x1 = int(obj[1]) + w_offset
x2 = int(obj[3]) + w_offset
bbox.append([obj[0], x1, obj[2], x2, obj[4]])
bot_left = cv2.imread(sub_img_src[4*i + 2])
anno = file_lines_to_list("/mosaic/trainval/VOCdevkit/VOC2012/sub-annotations/" + sub_img_names[4*i + 2] + ".txt")
for obj in anno:
y1 = int(obj[2]) + h_offset
y2 = int(obj[4]) + h_offset
bbox.append([obj[0], obj[1], y1, obj[3], y2])
bot_right = cv2.imread(sub_img_src[4*i + 3])
anno = file_lines_to_list("/mosaic/trainval/VOCdevkit/VOC2012/sub-annotations/" + sub_img_names[4*i + 3] + ".txt")
for obj in anno:
x1 = int(obj[1]) + w_offset
x2 = int(obj[3]) + w_offset
y1 = int(obj[2]) + h_offset
y2 = int(obj[4]) + h_offset
bbox.append([obj[0], x1, y1, x2, y2])
top = np.hstack((top_left, top_right))
bot = np.hstack((bot_left, bot_right))
full = np.vstack((top, bot))
bbox = np.array(bbox)
annotation = Element('annotation')
for obj in bbox:
SubElement(annotation, 'filename').text = sub_img_names[4*i] + "_stack.jpg"
SubElement(annotation, 'folder').text = "VOC2012"
size_ = SubElement(annotation, 'size')
SubElement(size_,'width').text = str(360)
SubElement(size_,'height').text = str(360)
object_ = Element('object')
SubElement(object_, 'name').text = obj[0]
bndbox = Element('bndbox')
SubElement(bndbox, 'xmin').text = obj[1]
SubElement(bndbox, 'ymin').text = obj[2]
SubElement(bndbox, 'xmax').text = obj[3]
SubElement(bndbox, 'ymax').text = obj[4]
object_.append(bndbox)
annotation.append(object_)
# Write XML
with open(osp.join("/mosaic/trainval/VOCdevkit/VOC2012/full-annotations/" + sub_img_names[4*i] + "_stack.xml"), 'w') as f:
f.write(prettify(annotation))
print(i)
- 本文参考github链接为:Implementation of Mosaic Data Augmentation in YOLOv4