基于Matterport版本的Mask-RCNN训练自己的数据集

转载自:https://blog.csdn.net/l297969586/article/details/79140840


本文是转载的,但可以做个补充:
将model.train()的头结构训练阶段的epochs=1改为较大的数,如100,将全网络训练的epochs=1改为200,可以在tensorboard上得到较为平滑的loss曲线。如果不改,得到的loss曲线是一条直线(只保存了两次迭代的loss数据)。因为epochs=100后,总共会迭代100*100=10000次,所以可以将STEPS_PER_EPOCH = 100改小一点,以缩短训练时间;


一、工具

cuda与cudnn安装请参考我之前博客:
http://blog.csdn.net/l297969586/article/details/53320706
http://blog.csdn.net/l297969586/article/details/67632608
tensorflow安装:
http://blog.csdn.net/l297969586/article/details/72820310
ipython-notebook:
http://blog.csdn.net/l297969586/article/details/77851039
Mask-RCNN :

https://github.com/matterport/Mask_RCNN

labelme(标注mask数据集用的):

https://github.com/wkentaro/labelme

二、修改训练代码

主要修改train_shapes.ipynb,我个人感觉ipython-notebook不好用,所以我将它转成.py格式,就是把代码粘出来。let’s go!
1、注释%matplotlib inline
2、在ShapesConfig类中,GPU_COUNT = 2,IMAGES_PER_GPU = 1两个参数自己根据自己电脑配置修改参数,由于该工程用的resnet101为主干的网络,训练需要大量的显存支持,我的图片尺寸为1280*800的,IMAGES_PER_GPU 设置为2,在两个GeForce GTX TITAN X上训练显存都会溢出,所以IMAGES_PER_GPU = 1,大佬可忽略。
NUM_CLASSES = 1 + 4为你数据集的类别数,第一类为bg,我的是4类,所以为1+4
IMAGE_MIN_DIM = 800,IMAGE_MAX_DIM = 1280修改为自己图片尺寸
RPN_ANCHOR_SCALES = (8 * 6, 16 * 6, 32 * 6, 64 * 6, 128 * 6),根据自己情况设置anchor大小
3、在全局定义一个iter_num=0
△4、重新写一个训练类
名字自己起,我的叫

class DrugDataset(utils.Dataset):

添加函数

#得到该图中有多少个实例(物体)
def get_obj_index(self, image):
        n = np.max(image)
        return n
#解析labelme中得到的yaml文件,从而得到mask每一层对应的实例标签
def from_yaml_get_class(self,image_id):
        info=self.image_info[image_id]
        with open(info['yaml_path']) as f:
            temp=yaml.load(f.read())
            labels=temp['label_names']
            del labels[0]
        return labels
#重新写draw_mask
def draw_mask(self, num_obj, mask, image):
        info = self.image_info[image_id]
        for index in range(num_obj):
            for i in range(info['width']):
                for j in range(info['height']):
                    at_pixel = image.getpixel((i, j))
                    if at_pixel == index + 1:
                        mask[j, i, index] =1
        return mask
#重新写load_shapes,里面包含自己的自己的类别(我的是box、column、package、fruit四类)
#并在self.image_info信息中添加了path、mask_path 、yaml_path
def load_shapes(self, count, height, width, img_floder, mask_floder, imglist,dataset_root_path):
        """Generate the requested number of synthetic images.
        count: number of images to generate.
        height, width: the size of the generated images.
        """
        # Add classes
        self.add_class("shapes", 1, "box")
        self.add_class("shapes", 2, "column")
        self.add_class("shapes", 3, "package")
        self.add_class("shapes", 4, "fruit")
        for i in range(count):
            filestr = imglist[i].split(".")[0]
            filestr = filestr.split("_")[1]
            mask_path = mask_floder + "/" + filestr + ".png"
            yaml_path=dataset_root_path+"total/rgb_"+filestr+"_json/info.yaml"
            self.add_image("shapes", image_id=i, path=img_floder + "/" + imglist[i],
                           width=width, height=height, mask_path=mask_path,yaml_path=yaml_path)
#重写load_mask
    def load_mask(self, image_id):
        """Generate instance masks for shapes of the given image ID.
        """
        global iter_num
        info = self.image_info[image_id]
        count = 1  # number of object
        img = Image.open(info['mask_path'])
        num_obj = self.get_obj_index(img)
        mask = np.zeros([info['height'], info['width'], num_obj], dtype=np.uint8)
        mask = self.draw_mask(num_obj, mask, img)
        occlusion = np.logical_not(mask[:, :, -1]).astype(np.uint8)
        for i in range(count - 2, -1, -1):
            mask[:, :, i] = mask[:, :, i] * occlusion
            occlusion = np.logical_and(occlusion, np.logical_not(mask[:, :, i]))
        labels=[]
        labels=self.from_yaml_get_class(image_id)
        labels_form=[]
        for i in range(len(labels)):
            if labels[i].find("box")!=-1:
                #print "box"
                labels_form.append("box")
            elif labels[i].find("column")!=-1:
                #print "column"
                labels_form.append("column")
            elif labels[i].find("package")!=-1:
                #print "package"
                labels_form.append("package")
            elif labels[i].find("fruit")!=-1:
                #print "fruit"
                labels_form.append("fruit")
        class_ids = np.array([self.class_names.index(s) for s in labels_form])
        return mask, class_ids.astype(np.int32)

4、代码主体修改

#基础设置
dataset_root_path="/home/yangjunfeng/workspace_lj/fg_dateset/"
img_floder = dataset_root_path+"rgb"
mask_floder = dataset_root_path+"mask"
#yaml_floder = dataset_root_path
imglist = listdir(img_floder)
count = len(imglist)
width = 1280
height = 800
#train与val数据集准备
dataset_train = DrugDataset()
dataset_train.load_shapes(count, 800, 1280, img_floder, mask_floder, imglist,dataset_root_path)
dataset_train.prepare()

dataset_val = DrugDataset()
dataset_val.load_shapes(count, 800, 1280, img_floder, mask_floder, imglist,dataset_root_path)
dataset_val.prepare()

注释掉
model.train(dataset_train,dataset_val,learning_rate=config.LEARNING_RATE/10,epochs=50,layers="all")之后的代码就好了

三、使用labelme生成mask掩码数据集

github地址:https://github.com/wkentaro/labelme
安装方式:

sudo apt-get install python-qt4 pyqt4-dev-tools
sudo pip install labelme

使用,只需在终端输入:

labelme

我的数据集命名如下
这里写图片描述
Note:在画掩码过程中如有多个box、fruit…命名规则为box1、box2..fruit1、fruit2..。因为labelme这个标定工具还是不太智能,最后生成的标签为一个label.png文件,这个文件只有一通道,在你标注时同一标签mask会被给予一个标签位,而mask要求不同的实例要放在不同的层中。最终训练索要得到的输入为一个w*h*n的ndarray,其中n为该图片中实例的个数。总而言之,画mask时就按照上述命名规则就好了,具体的过程已经在上述代码中实现。如图:这里写图片描述
此时labelme生成的为.json文件,需要将json文件转换为我们需要的标签文件,我这里写了一个简单的脚本,不用一个个去转化了,只需将s1改为你对应的路径及图片前缀名,循环数改为自己数据集数即可

#!/bin/bash
s1="/media/lj/GSP1RMCPRXV/fg_dateset/json/rgb_"
s2=".json"
for((i=1;i<901;i++))
do 
s3=${i}
labelme_json_to_dataset ${s1}</span><span class="hljs-variable">${s3}${s2}
done

在你图片目录下会生成多个rgb_x_json文件夹,每个文件夹中有img.png(原图),info.yaml,label.png,label_viz.png四个文件,其中需要用的只有info.yaml以及label.png
转化出来的可视化标签如图:
这里写图片描述

四、转化label.png为可用格式

labelme生成的掩码标签 label.png为16位存储,opencv默认读取8位,需要将16位转8位
参考:http://blog.csdn.net/l297969586/article/details/79154150

五、训练

直接运行修改后的py文件即可,训练中图片展示:
这里写图片描述

六、结果展示

测试demo也需要改,回头再写。。
我只训练了四个类(box,column,package,friut)
测试图片未参与训练,测试结果如下:
这里写图片描述
这里写图片描述

猜你喜欢

转载自blog.csdn.net/Xiongchao99/article/details/79778046