PointRend使用记录

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

下面是 PointRend 的源码地址,接下来先跑一下看看:

GitHub - zsef123/PointRend-PyTorch: A PyTorch implementation of PointRend: Image Segmentation as Rendering https://github.com/zsef123/PointRend-PyTorch

(1)数据准备 

数据使用公共数据集 CamVid,该数据集加上背景 0 共 12 个类,标签值为 0-11。下面是一级目录,目录结构及文件名务必保持一致:我在数据读取部分添加了读取自己数据集的函数,里面的文件夹名字是写死的(当然你也可以改代码)。

[图:CamVid 数据集一级目录结构]

二级目录为 train/val/test,三者的目录结构需要一致;如果 test 只有图像,也可以不要 labels 文件夹。

[图:train/val/test 二级目录结构]

 (2)添加自己的数据加载模块

在 `datas/__init__.py` 文件中添加了 get_own 数据集类,然后在 get_loader 函数中加入对自己数据集的分支(配置中 name 为 "own")。另外需要强调一下:我添加的数据加载没有专门加数据增强策略,大家可以自己加上,加了效果应该会好一点。

[图:在 get_loader 中添加自定义数据集分支]

`datas/__init__.py` 代码:

import os
import cv2
from PIL import Image
import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.datasets.voc import VOCSegmentation
from torchvision.datasets.cityscapes import Cityscapes

from .transforms import Compose, Resize, ToTensor, Normalize, RandomCrop, RandomFlip, ConvertMaskID


def get_voc(C, split="train"):
    """Build the torchvision Pascal VOC segmentation dataset for ``split``.

    All splits are converted to tensors, resized to 256x256 and normalised with
    the ImageNet mean/std. The ``"train"`` split additionally takes a random
    256x256 crop before the resize.

    Args:
        C: dataset config; only ``C['root']`` is read.
        split: VOC image set name (``"train"``, ``"val"``, ...).

    Returns:
        VOCSegmentation: the dataset (downloaded into ``C['root']`` if needed).
    """
    steps = [ToTensor()]
    if split == "train":
        steps.append(RandomCrop((256, 256)))
    steps += [
        Resize((256, 256)),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return VOCSegmentation(C['root'], download=True, image_set=split,
                           transforms=Compose(steps))


def get_cityscapes(C, split="train"):
    """Build the torchvision Cityscapes dataset for ``split``.

    Training follows the paper's Appendix B (semantic segmentation details):
    random 768 crop followed by a random flip. Every other split is resized to
    768 instead. Both pipelines apply ``ConvertMaskID(Cityscapes.classes)``
    (presumably remapping raw label ids to train ids) and ImageNet mean/std
    normalisation.

    Args:
        C: forwarded verbatim as keyword arguments to ``Cityscapes``.
        split: Cityscapes split name.
    """
    is_train = split == "train"
    pipeline = [
        ToTensor(),
        RandomCrop(768) if is_train else Resize(768),
        ConvertMaskID(Cityscapes.classes),
    ]
    if is_train:
        pipeline.append(RandomFlip())
    pipeline.append(Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    return Cityscapes(**C, split=split, transforms=Compose(pipeline))

class get_own(torch.utils.data.Dataset):
    """Segmentation dataset for a custom folder layout (CamVid in this post).

    Expected layout under ``C['root']``::

        <split>/images/<name>
        <split>/labels/<name>   # same file name as the image

    Each sample is ``(image, label)``. The image is read with OpenCV (BGR
    channel order) and scaled to float32 in [0, 1]. The label is read as a
    single-channel image whose pixel values are the class ids, kept as float32
    (presumably cast to long by the loss / metric code -- TODO confirm). Both go
    through ``self.transform``, which currently only applies ``ToTensor``.
    """

    def __init__(self, C, split="train"):
        """
        Args:
            C: dataset config (dict-like); only ``C['root']`` is read.
            split: sub-directory of the root, e.g. ``train``/``val``/``test``.
        """
        images_dir = os.path.join(C['root'], split, 'images')
        labels_dir = os.path.join(C['root'], split, 'labels')

        # os.listdir order is arbitrary; sort for a reproducible sample order.
        names = sorted(os.listdir(images_dir))
        self.images_path_list = [os.path.join(images_dir, n) for n in names]
        self.labels_path_list = [os.path.join(labels_dir, n) for n in names]

        # Train and eval currently share the same pipeline: no augmentation and
        # no normalisation. Add RandomCrop/RandomFlip here for "train" to improve
        # results; if Normalize is added, apply the same normalisation at
        # prediction time as well.
        self.transform = Compose([ToTensor()])

    def __getitem__(self, index):
        image_path = self.images_path_list[index]
        label_path = self.labels_path_list[index]

        image = cv2.imread(image_path)                        # BGR, uint8
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)  # class-id mask
        # cv2.imread does not raise on a missing/unreadable file; it returns None.
        if image is None:
            raise FileNotFoundError(f"Cannot read image: {image_path}")
        if label is None:
            raise FileNotFoundError(f"Cannot read label: {label_path}")

        image = image.astype(np.float32) / 255.0
        label = label.astype(np.float32)
        image, label = self.transform(image, label)
        return image, label

    def __len__(self):
        return len(self.images_path_list)


def get_loader(C, split, distributed):
    """Build the DataLoader for one split of the dataset named in the config.

    Args:
        C (Config): ``C.data`` from the YAML config. ``C.name`` selects the
            dataset (``"cityscapes"``, ``"pascalvoc"`` or ``"own"``),
            ``C.dataset`` is passed to the dataset builder and ``C.loader`` is
            forwarded as extra ``DataLoader`` keyword arguments.
        split (str): split passed to the dataset. For Cityscapes this is
            ``train``/``test``/``val`` when ``mode="fine"``, otherwise
            ``train``/``train_extra``/``val``.
        distributed (bool): use a ``DistributedSampler`` (multi-process training).

    Returns:
        torch.utils.data.DataLoader: shuffled only for the ``"train"`` split.

    Raises:
        ValueError: if ``C.name`` is not one of the supported datasets.
    """
    print(C.name, C.dataset, split)
    if C.name == "cityscapes":
        dset = get_cityscapes(C.dataset, split)
    elif C.name == "pascalvoc":
        dset = get_voc(C.dataset, split)
    elif C.name == "own":
        dset = get_own(C.dataset, split)
    else:
        raise ValueError(
            f"Unknown dataset name {C.name!r}; "
            "expected 'cityscapes', 'pascalvoc' or 'own'"
        )

    shuffle = split == "train"
    sampler = None
    if distributed:
        # The sampler does the shuffling; DataLoader rejects shuffle=True
        # together with an explicit sampler.
        sampler = DistributedSampler(dset, shuffle=shuffle)
        shuffle = False

    return DataLoader(dset, **C.loader, sampler=sampler,
                      shuffle=shuffle, drop_last=False,
                      pin_memory=True)

 (3)训练

这个 GitHub 项目结构比较清晰:训练逻辑在 train.py 中,不需要改动,主要修改 main.py 中的部分内容。由于该项目使用 apex 来加速训练,而我这里安装不方便,还报了错,所以 main.py 的主要改动就是注释掉 apex 相关的部分。

main.py代码

import os
import sys
import argparse
import logging
from tokenize import Double
from configs.parser import Parser

import torch

# from apex import amp
# from apex.parallel import DistributedDataParallel as ApexDDP

from model import deeplabv3, PointHead, PointRend
from datas import get_loader
from train import train
from utils.gpus import synchronize, is_main_process


def parse_args():
    """Parse the training script's command line.

    Options:
        --config      YAML config file (default ``./configs/default.yaml``). It is
                      effectively required: all hyper-parameters live there.
        --save        sub-directory of ``outs/`` that receives checkpoints/logs.
        --local_rank  process-local GPU index, supplied by a distributed launcher.
    """
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    specs = [
        ("--config", str, "./configs/default.yaml", "It must be config/*.yaml"),
        ("--save", str, "build", "Save path in out directory"),
        ("--local_rank", int, 0, "Using for Apex DDP"),
    ]
    for flag, kind, default, help_text in specs:
        parser.add_argument(flag, type=kind, default=default, help=help_text)
    return parser.parse_args()


def amp_init(args):
    """Set up optional multi-process training and enable cuDNN autotuning.

    This used to be the Apex initialisation hook. Apex is no longer used by this
    script, so the function now only does the following:

    * If the ``LOCAL_RANK`` environment variable is set, it overrides
      ``args.local_rank``. ``torchrun`` and ``torch.distributed.launch --use_env``
      pass the rank this way instead of through ``--local_rank``.
    * It sets ``args.distributed`` to True when ``WORLD_SIZE`` > 1. In that case
      it binds this process to GPU ``args.local_rank`` and joins the NCCL
      process group.
    * It sets ``torch.backends.cudnn.benchmark = True``.

    Args:
        args: namespace from ``parse_args``; ``local_rank`` is read and may be
            updated, and ``distributed`` is written.
    """
    args.local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))
    args.distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    torch.backends.cudnn.benchmark = True


def set_loggging(save_dir):
    """Route INFO-level log records to stdout and to ``<save_dir>/log.txt``.

    Creates ``save_dir`` if it does not exist. Console output is configured via
    ``logging.basicConfig`` with a ``[yy/mm/dd HH:MM:SS]`` timestamp. The file
    handler is added to the root logger and uses the same message layout but
    ``logging``'s default timestamp format.
    """
    missing = not os.path.exists(save_dir)
    if missing:
        os.makedirs(save_dir)

    fmt = '%(asctime)s %(message)s'
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format=fmt,
        datefmt='[%y/%m/%d %H:%M:%S]',
    )

    file_handler = logging.FileHandler(f"{save_dir}/log.txt")
    file_handler.setFormatter(logging.Formatter(fmt))
    root_logger = logging.getLogger()
    root_logger.addHandler(file_handler)


if __name__ == "__main__":
    args = parse_args()
    amp_init(args)

    parser = Parser(args.config)
    C = parser.C
    save_dir = f"{os.getcwd()}/outs/{args.save}"

    # Only the main process writes the config snapshot and the log file.
    if is_main_process():
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, mode=0o775)

        parser.dump(f"{save_dir}/config.yaml")

        set_loggging(save_dir)

    device = torch.device("cuda")
    train_loader = get_loader(C.data, "train", distributed=args.distributed)
    valid_loader = get_loader(C.data, "val", distributed=args.distributed)

    net = PointRend(
        deeplabv3(**C.net.deeplab),
        PointHead(**C.net.pointhead)
    ).to(device)

    # Three parameter groups; the DeepLab classifier uses 10x the base lr.
    # float(): YAML values such as 1e-3 may be loaded as strings.
    lr = float(C.train.lr)
    params = [{"params": net.backbone.backbone.parameters(),   "lr": lr},
              {"params": net.head.parameters(),                "lr": lr},
              {"params": net.backbone.classifier.parameters(), "lr": lr * 10}]

    # The upstream repo used SGD:
    # optim = torch.optim.SGD(params, momentum=C.train.momentum, weight_decay=C.train.weight_decay)
    # AdamW is tried here instead.
    optim = torch.optim.AdamW(params, lr=lr, weight_decay=float(C.train.weight_decay))

    # Apex is not installed, so amp.initialize / ApexDDP are not used:
    # net, optim = amp.initialize(net, optim, opt_level=C.apex.opt)
    # Multi-process runs still need a DDP wrapper. Without one, every rank
    # trains its own copy with unsynchronised gradients, so fall back to
    # PyTorch's native DDP. find_unused_parameters=True is a conservative
    # choice: it is not visible here whether every parameter receives a
    # gradient on each step.
    if args.distributed:
        net = torch.nn.parallel.DistributedDataParallel(
            net,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    train(C.run, save_dir, train_loader, valid_loader, net, optim, device)



# Apex 混合精度加速简介:为了提高 PyTorch 的训练效率,英伟达提供了混合精度训练工具 Apex。
# 号称能够在不降低性能的情况下,将模型训练速度提升 2-4 倍,训练显存消耗减少为原来的一半。
# 该项目开源于 https://github.com/NVIDIA/apex ,文档地址是 https://nvidia.github.io/apex/index.html 。该工具提供了三个功能:amp、parallel 和 normalization。

训练用的配置文件 configs/default.yaml:

# Training configuration (configs/default.yaml) for the custom CamVid setup.
data:
  name: "own"  # selects the get_own branch in get_loader
  dataset:     # passed to the dataset builder; get_own only reads "root"
    root: "./datasets/CamVid/"
    mode: "fine"             # only consumed by the Cityscapes builder
    target_type: "semantic"  # only consumed by the Cityscapes builder
  loader:      # forwarded as extra DataLoader keyword arguments
    batch_size: 5
    num_workers: 0

net:
  deeplab:
    pretrained: False
    resnet: "res101"
    head_in_ch: 2048
    num_classes: 12  # CamVid: 11 classes + background 0 (label values 0-11)
  pointhead:
    in_c: 524 # 512 + num_classes; update together with num_classes
    num_classes: 12
    k: 3       # presumably PointRend's point over-sampling factor k
    beta: 0.75 # presumably PointRend's importance-sampling ratio beta

run:       # passed to train() as C.run
  epochs: 101

train:
  lr: 1e-3       # main.py wraps this in float(): PyYAML may load "1e-3" (no dot) as a string
  momentum: 0.9  # only used by the commented-out SGD optimizer in main.py
  weight_decay: 1e-3

apex:      # unused while amp.initialize is commented out in main.py
  opt: "O0"

(4)预测

原始项目的预测使用 infer.py,它需要加载标签并给出精度评价。考虑到经常会有没有标签、只需要直接预测的情况,我另外写了一个预测脚本:

predict.py代码:

import os
import time
import logging
import cv2
from PIL import Image
import numpy as np
import torch
import argparse
from configs.parser import Parser
from model import deeplabv3, PointHead, PointRend
from utils.metrics import ConfusionMatrix
from utils.gpus import synchronize, is_main_process

@torch.no_grad()
def infer(loader, net, device, num_classes=2):
    """Evaluate ``net`` on a labelled loader and return its mean IoU.

    Args:
        loader: iterable of ``(image, label)`` batches. ``squeeze(1)`` drops a
            singleton channel dim from the label (if present), and the label is
            cast to ``long``.
        net: model whose ``net(x)["fine"]`` output holds per-class scores on dim 1.
        device: device the batches are moved to.
        num_classes: number of classes for the confusion matrix. It must match
            the model/dataset; the CamVid config in this post uses 12. The
            default of 2 only preserves the previously hard-coded value.

    Returns:
        The result of ``ConfusionMatrix.mIoU()``, which is also logged.
    """
    net.eval()
    metric = ConfusionMatrix(num_classes)
    for x, gt in loader:
        x = x.to(device, non_blocking=True)
        gt = gt.squeeze(1).to(device, dtype=torch.long, non_blocking=True)

        pred = net(x)["fine"].argmax(1)

        metric.update(pred, gt)

    mIoU = metric.mIoU()
    logging.info(f"[Infer] mIOU : {mIoU}")
    return mIoU

def parse_args():
    """Read the prediction script's command-line options.

    ``--config`` selects the YAML file used to rebuild the network. ``--save``
    and ``--local_rank`` mirror the training script's options (``--save`` is not
    used by this script's ``__main__``).
    """
    ap = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    add = ap.add_argument
    add("--config", default="./configs/default.yaml", type=str, help="It must be config/*.yaml")
    add("--save", default="build", type=str, help="Save path in out directory")
    add("--local_rank", default=0, type=int, help="Using for Apex DDP")
    options = ap.parse_args()
    return options

def amp_init(args):
    """Prepare the process: optional NCCL process group plus cuDNN autotuning.

    ``args.distributed`` becomes True only when the ``WORLD_SIZE`` environment
    variable is set to a value greater than 1. In that case this process is bound
    to GPU ``args.local_rank`` and joins the process group. cuDNN benchmark mode
    is always enabled.
    """
    world_size = os.environ.get('WORLD_SIZE')
    args.distributed = world_size is not None and int(world_size) > 1

    if args.distributed:
        rank = args.local_rank
        torch.cuda.set_device(rank)
        torch.distributed.init_process_group(init_method="env://", backend="nccl")
        synchronize()

    torch.backends.cudnn.benchmark = True

def predict(data_path, model_path, net, save_path):
    """Segment every image in a folder and write class-index masks (no labels needed).

    Each image is read with OpenCV (BGR order) and scaled to float32 in [0, 1],
    matching the preprocessing of the ``own`` training dataset. It is run
    through ``net``, and the arg-max of the ``"fine"`` output is saved to
    ``save_path`` under the same file name. Use a lossless extension (e.g. PNG)
    for the inputs, because lossy formats would corrupt the saved class ids.

    Args:
        data_path: directory of images; non-file entries (sub-directories) and
            unreadable files are skipped.
        model_path: path to a ``state_dict`` checkpoint for ``net``.
        net: the model, already moved to its target device.
        save_path: output directory; created if it does not exist.
    """
    # Use the model's own device instead of relying on a global variable.
    device = next(net.parameters()).device
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.eval()
    os.makedirs(save_path, exist_ok=True)  # cv2.imwrite fails silently otherwise

    with torch.no_grad():  # inference only: don't build autograd graphs
        for name in sorted(os.listdir(data_path)):
            full_path = os.path.join(data_path, name)
            if not os.path.isfile(full_path):
                continue
            image = cv2.imread(full_path)
            if image is None:  # cv2.imread returns None instead of raising
                logging.warning(f"[Predict] skip unreadable file: {full_path}")
                continue
            image = image.astype(np.float32) / 255.0
            # HWC -> 1 x C x H x W
            chw = np.ascontiguousarray(image.transpose(2, 0, 1))
            x = torch.from_numpy(chw).unsqueeze(0).to(device, non_blocking=True)
            pred = net(x)["fine"].argmax(1)
            # argmax yields int64, which cv2.imwrite cannot encode; class ids
            # fit in uint8 as long as there are fewer than 256 classes.
            mask = pred[0].cpu().numpy().astype(np.uint8)
            cv2.imwrite(os.path.join(save_path, name), mask)


if __name__ == "__main__":
    # The test split uses the same layout as train/val (test/images plus an
    # optional test/labels). Point at the images sub-directory: listing test/
    # itself would yield the folder names "images"/"labels", which cv2.imread
    # cannot read.
    path = './datasets/CamVid/test/images/'
    save_path = './datasets/pred/'
    model_path = './outs/CamVid/epoch_0100_loss_0.54185.pth'
    args = parse_args()
    amp_init(args)

    parser = Parser(args.config)
    C = parser.C

    device = torch.device("cuda")
    net = PointRend(
        deeplabv3(**C.net.deeplab),
        PointHead(**C.net.pointhead)
    ).to(device)

    # cv2.imwrite does not create missing directories (it just returns False).
    os.makedirs(save_path, exist_ok=True)
    predict(path, model_path, net, save_path)

效果:

[图:原始图像]  [图:标签]

                                    图像                                                                        标签

[图:预测结果]

 预测结果

从结果看,效果明显不理想。不过不必太悲观:由于去掉了加速模块,训练比较慢,我只训练了 100 个 epoch 就停了,而且数据也没有做增强,效果肯定还能再提高。

猜你喜欢

转载自juejin.im/post/7123841452165038094