把分割结果的mask画到图上

1. 获取分割结果

2. 把分割结果画到图片上

# -*- coding: utf-8 -*-
import cv2
import numpy as np
import xml.etree.ElementTree as ET
 
def xml_reader(filename):
    """Parse an annotation XML file into a list of per-image records.

    Every ``<image>`` element directly under the document root yields one
    dict with two keys:

    * ``'name'``     -- the image's ``name`` attribute.
    * ``'polygons'`` -- a list of dicts, one per child ``<polygon>``, each
      holding that polygon's raw ``'label'`` and ``'points'`` attribute
      strings (the points string is returned unparsed).

    A missing ``name``/``label``/``points`` attribute raises ``KeyError``.
    """
    root = ET.parse(filename)
    records = []
    for image_el in root.findall('image'):
        # Read the name first so a missing attribute fails before polygons.
        image_name = image_el.attrib['name']
        shapes = [
            {
                'label': shape_el.attrib['label'],
                'points': shape_el.attrib['points'],
            }
            for shape_el in image_el.findall('polygon')
        ]
        records.append({'name': image_name, 'polygons': shapes})
    return records


def vis_mask(img, box, col, alpha=0.7, show_border=True, border_thick=3,
             border_color=(255, 255, 0)):
    """Alpha-blend one filled polygon onto an image and optionally outline it.

    Args:
        img: image array indexed as ``img[row, col]`` or
            ``img[row, col, channel]`` (e.g. the BGR array returned by
            ``cv2.imread``). It is not modified; a new array is returned.
        box: polygon vertices as an ``np.int32`` array of ``(x, y)`` points,
            e.g. ``np.array([[x1, y1], [x2, y2], [x3, y3]], dtype=np.int32)``.
        col: fill colour, one value per channel (array, list or tuple), e.g.
            ``np.array([0.2, 0.7, 0.4]) * 255``.
        alpha: fill opacity; 0 leaves pixels unchanged, 1 replaces them
            with ``col``.
        show_border: when True, draw the outline of the filled region.
        border_thick: outline thickness in pixels.
        border_color: outline colour, in the image's channel order.

    Returns:
        A new ``np.uint8`` image with the polygon blended in.
    """
    out = img.astype(np.float32)  # astype copies, so the caller's image is untouched
    mask = np.zeros(out.shape[:2], dtype=np.uint8)
    cv2.fillPoly(mask, [box], 1)
    ys, xs = np.nonzero(mask)
    # Alpha blend inside the polygon: new = (1 - alpha) * old + alpha * col.
    # Indexing without a trailing ':' also works for single-channel images.
    out[ys, xs] *= 1.0 - alpha
    out[ys, xs] += alpha * np.asarray(col, dtype=np.float32)
    if show_border:
        # findContours returns (contours, hierarchy) in OpenCV 2.x/4.x but
        # (image, contours, hierarchy) in 3.x; [-2] is the contour list in
        # every version. Older OpenCV versions modify the input, hence copy().
        contours = cv2.findContours(mask.copy(), cv2.RETR_CCOMP,
                                    cv2.CHAIN_APPROX_NONE)[-2]
        # Draw every contour (an empty list is a no-op), not just the first.
        cv2.drawContours(out, contours, -1, border_color, border_thick)
    # Saturate rather than wrap around when values leave the uint8 range.
    return np.clip(out, 0, 255).astype(np.uint8)


if __name__ == "__main__":
    # Fill colour per polygon label, in OpenCV's BGR channel order (the
    # order cv2.imread uses). Labels not listed fall back to DEFAULT_COLOR.
    LABEL_COLORS = {
        "person": np.array([0.1, 0.1, 0.8]) * 225,
        "box": np.array([0.1, 0.8, 0.1]) * 225,
    }
    DEFAULT_COLOR = np.array([0.8, 0.1, 0.1]) * 225

    objects = xml_reader('0.xml')
    for idx, obj in enumerate(objects):
        img_path = obj['name']
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # instead of raising; skip it rather than crash inside vis_mask.
            print("warning: could not read image %r, skipping" % img_path)
            continue
        for polygon in obj['polygons']:
            col = LABEL_COLORS.get(polygon['label'], DEFAULT_COLOR)
            # The points attribute is "x1,y1;x2,y2;..." with float coordinates.
            coords = [pair.split(",") for pair in polygon['points'].split(";")]
            # np.float was removed in NumPy 1.24; use the builtin float. Round
            # to the nearest pixel instead of truncating toward zero.
            box = np.round(np.array(coords, dtype=float)).astype(np.int32)
            img = vis_mask(img, box, col)
        cv2.imwrite("%s.jpg" % idx, img)
            

效果如下：

end!

猜你喜欢

转载自blog.csdn.net/Guo_Python/article/details/106408045