Tensorflow2.0 YOLO篇之图像信息预处理

Tensorflow2.0 YOLO篇之提取图像信息预处理



在上一篇博文中，我们已经将图片的路径和boxes信息分别存储在imgs和boxes变量中，现在我们把它们转化为YOLO网络可以使用的形式。本文仍然使用VS Code支持的Jupyter分步（按cell）执行功能，有一说一，这个功能是真的好用。这里的代码是接着上一篇的，等整个功能全部完成后，我会在最后一篇中留下完整代码的下载链接。

转化为张量类型

使用tf.data.Dataset.from_tensor_slices将数据加载为tensor张量的形式

def preprocess(img, img_boxes):
    """Read and decode one image file; pass its boxes through untouched.

    Args:
        img: scalar string tensor holding the path of a PNG image.
        img_boxes: the boxes belonging to this image ([40,5] in this
            pipeline); returned unchanged.

    Returns:
        (image, img_boxes) where image is a float32 [H, W, 3] tensor scaled
        to [0, 1] by ``tf.image.convert_image_dtype``.
    """
    raw_bytes = tf.io.read_file(img)
    pixels = tf.image.decode_png(raw_bytes, channels=3)
    return tf.image.convert_image_dtype(pixels, tf.float32), img_boxes

def get_dataset(img_dir, ann_dir, batchsz):
    """Build the training tf.data pipeline.

    Args:
        img_dir: directory containing the image files.
        ann_dir: directory containing the annotation files.
        batchsz: batch size.

    Returns:
        A dataset of (images [b,H,W,3] float32, boxes [b,40,5]) batches that
        repeats forever.
    """
    # Use the caller's directories; they used to be ignored in favour of the
    # hard-coded 'data/train/image' / 'data/train/annotation' paths.
    # NOTE(review): obj_names is a module-level name defined elsewhere.
    imgs, boxes = parse_annotation(img_dir, ann_dir, obj_names)
    db = tf.data.Dataset.from_tensor_slices((imgs, boxes))
    # shuffle file order, decode each image, batch, and loop indefinitely
    db = db.shuffle(1000).map(preprocess).batch(batchsz).repeat()
    print('db Images', len(imgs))
    return db

可视化图片

我们将刚才加载好的数据可视化，查看一下效果。

# visual the db
from matplotlib import pyplot as plt
from matplotlib import patches

def db_visualize(db):
    """Plot the first image of the next batch from ``db`` with its boxes.

    Args:
        db: iterable of (imgs [b,H,W,3], imgs_boxes [b,N,5]) batches whose box
            rows are x1-y1-x2-y2-label in pixels. One batch is pulled via
            ``next(iter(db))``, so a generator passed here is advanced by one.

    Boxes with label 1 (sugar beet) are drawn green and label 2 (weed) red.
    Rows with any other label, e.g. zero padding, are skipped. Earlier code
    stopped at the first such row, which hid every box after it.
    """
    imgs, imgs_boxes = next(iter(db))
    img, img_boxes = imgs[0], imgs_boxes[0]
    # float keys compare equal to the int labels (1.0 == 1), same as `l == 1`
    label_colors = {1: (0, 1, 0), 2: (1, 0, 0)}
    _, ax1 = plt.subplots(1, figsize=(10, 10))
    # display the image, [H,W,3]
    ax1.imshow(img)
    for x1, y1, x2, y2, label in img_boxes:
        color = label_colors.get(float(label))
        if color is None:
            # ignore invalid/padding rows but keep drawing the rest
            continue
        x1, y1, x2, y2 = float(x1), float(y1), float(x2), float(y2)
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                                 edgecolor=color, facecolor='none')
        ax1.add_patch(rect)
        
# %%
# Preview one batch of the (non-augmented) training pipeline.
# NOTE(review): train_db is not defined in this section — presumably created
# by get_dataset(...) in an earlier cell; confirm.
db_visualize(train_db)

效果如下图
在这里插入图片描述

图片增强

通过翻转、调整亮度等操作对图片进行数据增强（也可以加入裁剪、平移、缩放等变换），以增加训练数据的多样性。需要注意的是，在变换图片的同时，目标框的位置信息也必须随之做相同的变换。

# %%
# data augementation
import imgaug as ia
from    imgaug import augmenters as iaa
def augmentation_generator(yolo_dataset):
    '''
    Augmented batch generator from a yolo dataset

    Parameters
    ----------
    - yolo_dataset : iterable of (images, boxes) batches, e.g. the dataset
      built by get_dataset.
        images : tensor (shape : batch_size, IMAGE_H, IMAGE_W, 3); the
                 augmented output is clipped to [0, 1]
        boxes  : tensor (shape : batch_size, max annot, 5), rows are
                 x1-y1-x2-y2-label in pixels; rows whose coordinates sum
                 to 0 are treated as empty

    Yields
    ------
    - batch : tuple(images, annotations)
        batch[0] : images : tensor (shape : batch_size, IMAGE_H, IMAGE_W, 3)
        batch[1] : annotations : tensor (shape : batch_size, max annot, 5)
            boxes that survive augmentation are packed at the front of each
            image's rows with their label kept; the remaining rows are zero.
    '''
    # Build the pipeline once; to_deterministic() below draws new random
    # parameters for every batch.
    seq = iaa.Sequential([
        iaa.Fliplr(0.5),
        iaa.Flipud(0.5),
        iaa.Multiply((0.4, 1.6)),  # change brightness
        # Optional extra augmenters:
        # iaa.ContrastNormalization((0.5, 1.5)),
        # iaa.Affine(translate_px={"x": (-100, 100), "y": (-100, 100)}, scale=(0.7, 1.30)),
    ])
    for batch in yolo_dataset:
        # tensor -> numpy; np.array copies, so writing into `boxes` below can
        # never touch memory shared with the source tensor
        img = np.asarray(batch[0])
        boxes = np.array(batch[1])
        height, width = img.shape[1:3]

        # numpy boxes -> imgaug objects; the label travels with each box so it
        # stays attached even if boxes before it are removed later
        ia_boxes = []
        for img_boxes in boxes:
            ia_bbs = [ia.BoundingBox(x1=bb[0], y1=bb[1], x2=bb[2], y2=bb[3],
                                     label=bb[4])
                      for bb in img_boxes
                      if bb[0] + bb[1] + bb[2] + bb[3] > 0]
            ia_boxes.append(ia.BoundingBoxesOnImage(ia_bbs, shape=(height, width)))

        # apply the same random transform to the images and their boxes
        seq_det = seq.to_deterministic()
        img_aug = np.clip(seq_det.augment_images(img), 0, 1)
        boxes_aug = seq_det.augment_bounding_boxes(ia_boxes)

        # imgaug objects -> numpy
        for i, bbs_on_image in enumerate(boxes_aug):
            kept = bbs_on_image.remove_out_of_image().clip_out_of_image().bounding_boxes
            # clear the rows first so boxes dropped by remove_out_of_image do
            # not leave stale, un-augmented coordinates behind
            boxes[i] = 0
            for j, bb in enumerate(kept):
                boxes[i, j] = [bb.x1, bb.y1, bb.x2, bb.y2, bb.label]

        # numpy -> tensor; the yield is what makes this a generator at all
        # (without it the function returned None)
        yield tf.convert_to_tensor(img_aug), tf.convert_to_tensor(boxes)

#%%
# Wrap the training pipeline with on-the-fly augmentation and preview a batch.
# NOTE(review): relies on augmentation_generator yielding batches. When it is a
# generator, db_visualize consumes the first augmented batch here, and
# train_gen (built from aug_train_db later) starts from the next one.
aug_train_db = augmentation_generator(train_db)
db_visualize(aug_train_db)

Compose labels

刚才我们加载的label信息的格式是[b,40,5]，而YOLO网络需要的形式是[b,16,16,5,6]，接下来我们将label信息转换成对应的格式。

首先，对每个目标框分别计算它与5个锚点（anchor）的IoU，并选取IoU最大的那个anchor负责预测该目标。

# %%
# Input image size in pixels.
IMGSZ = 512
# Output grid size: a 16x16 grid, so each cell covers IMGSZ // GEIDSZ = 32 pixels.
# (Presumably a typo for GRIDSZ; kept because other code refers to this name.)
GEIDSZ = 16
# 5 anchor boxes stored as flat (w, h) pairs in grid-cell units: process_true_boxes
# reshapes this to [5, 2] and compares it with box sizes divided by the cell size.
ANCHORS = [0.57273, 0.677385, 1.87446, 2.06253, 3.33843, 5.47434, 7.88282, 3.52778, 9.77052, 9.16828]


# %%
def process_true_boxes(gt_boxes, anchors):
    """Encode one image's ground-truth boxes onto the YOLO output grid.

    Each valid box is converted from corner form (x1, y1, x2, y2) in pixels to
    centre form (x, y, w, h) in grid-cell units. It is then assigned, in the
    cell containing its centre, to the anchor whose shape has the highest IoU
    with it (box and anchor compared as if sharing the same centre).

    Args:
        gt_boxes: [N, 5] array-like (numpy array or eager tensor) of
            x1-y1-x2-y2-label rows. Rows with non-positive width or height
            (e.g. zero padding) are converted but not assigned.
        anchors: flat sequence of anchor (w, h) pairs in grid-cell units,
            e.g. ``ANCHORS``; any number of pairs is accepted.

    Returns:
        matching_gt_box: [GRID, GRID, A, 5] x-y-w-h-label of the box assigned to
            each (cell, anchor); zeros elsewhere.
        detector_mask: [GRID, GRID, A, 1], 1 where a box was assigned.
        gt_boxes_grid: [N, 5] every input row converted to x-y-w-h-label.
    """
    # 512 // 16 = 32 pixels per grid cell
    scale = IMGSZ // GEIDSZ
    # [A, 2] (w, h) per anchor
    anchors = np.asarray(anchors, dtype=np.float64).reshape((-1, 2))
    num_anchors = anchors.shape[0]
    # np.asarray accepts both eager tensors and plain numpy arrays
    gt_boxes = np.asarray(gt_boxes)

    # mask for object
    detector_mask = np.zeros([GEIDSZ, GEIDSZ, num_anchors, 1])
    # x-y-w-h-l
    matching_gt_box = np.zeros([GEIDSZ, GEIDSZ, num_anchors, 5])
    # [N,5] x1-y1-x2-y2-l => x-y-w-h-l
    gt_boxes_grid = np.zeros(gt_boxes.shape)

    for i, box in enumerate(gt_boxes):
        # pixels => grid-cell units
        x = ((box[0] + box[2]) / 2) / scale
        y = ((box[1] + box[3]) / 2) / scale
        w = (box[2] - box[0]) / scale
        h = (box[3] - box[1]) / scale
        gt_boxes_grid[i] = np.array([x, y, w, h, box[4]])

        if w <= 0 or h <= 0:  # padding / degenerate box
            continue

        # IoU of the box against every anchor. The previous loop compared the
        # IoU with the anchor *index*, so it effectively always chose anchor 1.
        intersect = np.minimum(w, anchors[:, 0]) * np.minimum(h, anchors[:, 1])
        union = w * h + anchors[:, 0] * anchors[:, 1] - intersect
        ious = intersect / union
        best_anchor = int(np.argmax(ious))  # first maximum on ties
        if ious[best_anchor] <= 0:
            continue

        # cell containing the centre, clamped so an edge centre stays in range
        x_coord = min(max(int(np.floor(x)), 0), GEIDSZ - 1)
        y_coord = min(max(int(np.floor(y)), 0), GEIDSZ - 1)
        # [h,w,A,1]
        detector_mask[y_coord, x_coord, best_anchor] = 1
        # [h,w,A,x-y-w-h-l]
        matching_gt_box[y_coord, x_coord, best_anchor] = np.array([x, y, w, h, box[4]])

    return matching_gt_box, detector_mask, gt_boxes_grid

接下来对整个batch逐张图片进行上述转换，并整理成训练所需的格式：

def ground_truth_generator(db):
    """Turn (images, boxes) batches into YOLO training targets.

    For every batch drawn from ``db``, each image's boxes are encoded with
    ``process_true_boxes(..., ANCHORS)``. The per-image results are then
    stacked back into float32 batch tensors.

    Yields:
        imgs:                [b,512,512,3] images, passed through unchanged
        detector_mask:       [b,16,16,5,1]
        matching_gt_box:     [b,16,16,5,5] x-y-w-h-l
        matching_classes_oh: [b,16,16,5,2] one-hot of class ids 1 and 2
        gt_boxes_grid:       [b,40,5] x-y-w-h-l
    """
    for images, boxes_batch in db:
        # encode every image of the batch, in order
        per_image = [process_true_boxes(boxes_batch[k], ANCHORS)
                     for k in range(images.shape[0])]
        gt_box_list = [item[0] for item in per_image]
        mask_list = [item[1] for item in per_image]
        grid_list = [item[2] for item in per_image]

        mask_tensor = tf.cast(np.array(mask_list), dtype=tf.float32)
        gt_box_tensor = tf.cast(np.array(gt_box_list), dtype=tf.float32)
        grid_tensor = tf.cast(np.array(grid_list), dtype=tf.float32)

        # Unassigned slots of matching_gt_box are all zeros (class id 0), so
        # one-hot over 3 ids and drop column 0 to keep only the real classes.
        class_ids = tf.cast(gt_box_tensor[..., 4], dtype=tf.int32)
        class_one_hot = tf.cast(tf.one_hot(class_ids, depth=3)[..., 1:],
                                dtype=tf.float32)

        yield images, mask_tensor, gt_box_tensor, class_one_hot, grid_tensor

可视化效果

接下来我们可视化一下看一下效果

# visualize object mask
# train_db -> aug_train_db -> train_gen
# Build the target generator on top of the augmented stream and take one batch.
train_gen = ground_truth_generator(aug_train_db)
img,detector_mask,matching_gt_box,matching_classes_oh,gt_boxes_grid =\
    next(train_gen)
# Keep only the first sample of the batch.
img,detector_mask,matching_gt_box,matching_classes_oh,gt_boxes_grid=\
    img[0],detector_mask[0],matching_gt_box[0],matching_classes_oh[0],gt_boxes_grid[0]
# Top panel: the image; bottom panel: per-cell count of assigned anchors.
fig,(ax1,ax2) = plt.subplots(2,figsize=(5,10))
ax1.imshow(img)
# detector_mask is already un-batched here: [16,16,5,1] => [16,16,1],
# summing over the anchor axis gives the number of boxes assigned to each cell
mask = tf.reduce_sum(detector_mask,axis=2)
ax2.matshow(mask[...,0]) # [16,16]

在这里插入图片描述


参考书籍: TensorFlow 深度学习 — 龙龙老师

扫描二维码关注公众号,回复: 11196612 查看本文章
原创文章 113 获赞 80 访问量 3万+

猜你喜欢

转载自blog.csdn.net/python_LC_nohtyp/article/details/104842099
今日推荐