def main():
"""
Main Function
"""
# Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer
assert_and_infer_cfg(args)
writer = prep_experiment(args, parser)
train_loader, val_loader, train_obj = datasets.setup_loaders(args)
1:在train.py文件中,看如何加载数据的。在setup_loaders函数中加载args。
args包括一些基础的配置:
#交叉验证
parser.add_argument('--cv', type=int, default=None,
help='cross-validation split id to use. Default # of splits set to 3 in config')
parser.add_argument('--class_uniform_pct', type=float, default=0.5,
help='What fraction of images is uniformly sampled')
parser.add_argument('--class_uniform_tile', type=int, default=1024,
help='tile size for class uniform sampling')
parser.add_argument('--hardnm', default=0, type=int,
help='0 means no aug, 1 means hard negative mining iter 1,' +
'2 means hard negative mining iter 2')
parser.add_argument('--maxSkip', type=int, default=0,
help='Skip x number of frames of video augmented dataset')
parser.add_argument('--scf', action='store_true', default=False,
help='scale correction factor')
2:在setup_loaders函数中:
"""
Dataset setup and loaders
"""
from datasets import cityscapes
from datasets import mapillary
from datasets import kitti
from datasets import camvid
from datasets import uavid
import torchvision.transforms as standard_transforms
import transforms.joint_transforms as joint_transforms
import transforms.transforms as extended_transforms
from torch.utils.data import DataLoader
def setup_loaders(args):
    """
    Dataset setup and loaders.

    Selects the dataset module named by ``args.dataset``, derives train/val
    batch sizes from ``bs_mult``/``bs_mult_val`` and ``ngpu``, builds the
    train and validation transform pipelines, instantiates the datasets and
    wraps them in DataLoaders (with a DistributedSampler when ``args.apex``
    is set).

    input: argument namespace passed by the user; mutated in place
        (``dataset_cls``, ``train_batch_size``, ``val_batch_size``,
        ``num_workers`` are attached to it)
    return: training data loader, validation data loader, train_set
    """
    # --- Dataset module selection and batch sizes ------------------------
    # Train batch = per-GPU multiplier (bs_mult) * number of GPUs.
    if args.dataset == 'cityscapes':
        args.dataset_cls = cityscapes
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    elif args.dataset == 'mapillary':
        args.dataset_cls = mapillary
        args.train_batch_size = args.bs_mult * args.ngpu
        args.val_batch_size = 4
    elif args.dataset == 'uavid':
        args.dataset_cls = uavid
        args.train_batch_size = args.bs_mult * args.ngpu
        args.val_batch_size = 4
    elif args.dataset == 'ade20k':
        # NOTE(review): `ade20k` is never imported in this file, so this
        # branch raises NameError if taken — confirm the missing import.
        args.dataset_cls = ade20k
        args.train_batch_size = args.bs_mult * args.ngpu
        args.val_batch_size = 4
    elif args.dataset == 'kitti':
        args.dataset_cls = kitti
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    elif args.dataset == 'camvid':
        args.dataset_cls = camvid
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    elif args.dataset == 'null_loader':
        # NOTE(review): `null_loader` is also not imported at the top of
        # this file — verify before using this debug dataset.
        args.dataset_cls = null_loader
        args.train_batch_size = args.bs_mult * args.ngpu
        if args.bs_mult_val > 0:
            args.val_batch_size = args.bs_mult_val * args.ngpu
        else:
            args.val_batch_size = args.bs_mult * args.ngpu
    else:
        raise Exception('Dataset {} is not supported'.format(args.dataset))
    # Readjust batch size to mini-batch size for apex
    if args.apex:
        # Under apex/distributed training each process sees only its own
        # mini-batch, so drop the ngpu multiplier.
        args.train_batch_size = args.bs_mult
        args.val_batch_size = args.bs_mult_val
    args.num_workers = 4 * args.ngpu
    if args.test_mode:
        args.num_workers = 1
    # ImageNet mean/std used to normalise the RGB input.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Geometric image transformations (applied jointly to image and mask).
    train_joint_transform_list = [
        joint_transforms.RandomSizeAndCrop(args.crop_size,
                                           False,
                                           pre_size=args.pre_size,
                                           scale_min=args.scale_min,
                                           scale_max=args.scale_max,
                                           ignore_index=args.dataset_cls.ignore_label),
        joint_transforms.Resize(args.crop_size),
        joint_transforms.RandomHorizontallyFlip()]
    train_joint_transform = joint_transforms.Compose(train_joint_transform_list)
    # Image appearance transformations (image only, mask untouched).
    train_input_transform = []
    if args.color_aug:
        train_input_transform += [extended_transforms.ColorJitter(
            brightness=args.color_aug,
            contrast=args.color_aug,
            saturation=args.color_aug,
            hue=args.color_aug)]
    if args.bblur:
        train_input_transform += [extended_transforms.RandomBilateralBlur()]
    elif args.gblur:
        train_input_transform += [extended_transforms.RandomGaussianBlur()]
    else:
        pass
    train_input_transform += [standard_transforms.ToTensor(),
                              standard_transforms.Normalize(*mean_std)]
    train_input_transform = standard_transforms.Compose(train_input_transform)
    # Validation inputs only get tensor conversion + normalisation.
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    if args.jointwtborder:
        # Relaxed boundary labels for the boundary-relaxation loss.
        target_train_transform = extended_transforms.RelaxedBoundaryLossToTensor(args.dataset_cls.ignore_label,
                                                                                 args.dataset_cls.num_classes)
    else:
        target_train_transform = extended_transforms.MaskToTensor()
    # --- Dataset instantiation -------------------------------------------
    if args.dataset == 'cityscapes':
        city_mode = 'train'  ## Can be trainval
        city_quality = 'fine'
        if args.class_uniform_pct:
            if args.coarse_boost_classes:
                coarse_boost_classes = \
                    [int(c) for c in args.coarse_boost_classes.split(',')]
            else:
                coarse_boost_classes = None
            train_set = args.dataset_cls.CityScapesUniform(
                city_quality, city_mode, args.maxSkip,
                joint_transform_list=train_joint_transform_list,
                transform=train_input_transform,
                target_transform=target_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv,
                class_uniform_pct=args.class_uniform_pct,
                class_uniform_tile=args.class_uniform_tile,
                test=args.test_mode,
                coarse_boost_classes=coarse_boost_classes)
        else:
            train_set = args.dataset_cls.CityScapes(
                city_quality, city_mode, 0,
                joint_transform=train_joint_transform,
                transform=train_input_transform,
                target_transform=target_train_transform,
                dump_images=args.dump_augmentation_images,
                cv_split=args.cv)
        val_set = args.dataset_cls.CityScapes('fine', 'val', 0,
                                              transform=val_input_transform,
                                              target_transform=target_transform,
                                              cv_split=args.cv)
    elif args.dataset == 'mapillary':
        eval_size = 1536
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size, ignore_index=args.dataset_cls.ignore_label)]
        train_set = args.dataset_cls.Mapillary(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode)
        val_set = args.dataset_cls.Mapillary(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False)
    elif args.dataset == 'uavid':
        eval_size = 1536
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size, ignore_index=args.dataset_cls.ignore_label)]
        train_set = args.dataset_cls.UAVid(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode)
        # TODO HACK 'val' set to 'train' due to .
        # NOTE(review): validation therefore runs on the *training* split.
        val_set = args.dataset_cls.UAVid(
            'semantic', 'train',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False)
    elif args.dataset == 'ade20k':
        eval_size = 384
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)]
        train_set = args.dataset_cls.ade20k(
            'semantic', 'train',
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode)
        val_set = args.dataset_cls.ade20k(
            'semantic', 'val',
            joint_transform_list=val_joint_transform_list,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False)
    elif args.dataset == 'kitti':
        # eval_size_h = 384
        # eval_size_w = 1280
        # val_joint_transform_list = [
        #     joint_transforms.ResizeHW(eval_size_h, eval_size_w)]
        train_set = args.dataset_cls.KITTI(
            'semantic', 'train', args.maxSkip,
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode,
            cv_split=args.cv,
            scf=args.scf,
            hardnm=args.hardnm)
        # Val uses the full 'trainval' split at native eval resolution.
        val_set = args.dataset_cls.KITTI(
            'semantic', 'trainval', 0,
            joint_transform_list=None,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False,
            cv_split=args.cv,
            scf=None)
    elif args.dataset == 'camvid':
        # eval_size_h = 384
        # eval_size_w = 1280
        # val_joint_transform_list = [
        #     joint_transforms.ResizeHW(eval_size_h, eval_size_w)]
        train_set = args.dataset_cls.CAMVID(
            'semantic', 'trainval', args.maxSkip,
            joint_transform_list=train_joint_transform_list,
            transform=train_input_transform,
            target_transform=target_train_transform,
            dump_images=args.dump_augmentation_images,
            class_uniform_pct=args.class_uniform_pct,
            class_uniform_tile=args.class_uniform_tile,
            test=args.test_mode,
            cv_split=args.cv,
            scf=args.scf,
            hardnm=args.hardnm)
        val_set = args.dataset_cls.CAMVID(
            'semantic', 'test', 0,
            joint_transform_list=None,
            transform=val_input_transform,
            target_transform=target_transform,
            test=False,
            cv_split=args.cv,
            scf=None)
    elif args.dataset == 'null_loader':
        train_set = args.dataset_cls.null_loader(args.crop_size)
        val_set = args.dataset_cls.null_loader(args.crop_size)
    else:
        raise Exception('Dataset {} is not supported'.format(args.dataset))
    # --- Samplers and loaders --------------------------------------------
    if args.apex:
        # Distributed: each process samples a disjoint shard of the data.
        from datasets.sampler import DistributedSampler
        train_sampler = DistributedSampler(train_set, pad=True, permutation=True, consecutive_sample=False)
        val_sampler = DistributedSampler(val_set, pad=False, permutation=False, consecutive_sample=False)
    else:
        train_sampler = None
        val_sampler = None
    # Shuffle only when no sampler controls the ordering (DataLoader forbids
    # shuffle=True together with a sampler).
    train_loader = DataLoader(train_set, batch_size=args.train_batch_size,
                              num_workers=args.num_workers, shuffle=(train_sampler is None), drop_last=True, sampler = train_sampler)
    val_loader = DataLoader(val_set, batch_size=args.val_batch_size,
                            num_workers=args.num_workers // 2 , shuffle=False, drop_last=False, sampler = val_sampler)
    return train_loader, val_loader, train_set
首先指定需要训练的数据集,假设我们只看KITTI:
elif args.dataset == 'kitti':
args.dataset_cls = kitti
args.train_batch_size = args.bs_mult * args.ngpu
if args.bs_mult_val > 0:
args.val_batch_size = args.bs_mult_val * args.ngpu
else:
args.val_batch_size = args.bs_mult * args.ngpu
首先看后面的if-else用于确定训练和验证的batchsize。args.dataset_cls = kitti指定训练数据集为kitti。那么可以通过args.dataset_cls.来调用kitti里面的方法。
接着指定是否采用混合精度,以及加载数据的线程数。
# Readjust batch size to mini-batch size for apex
if args.apex:
args.train_batch_size = args.bs_mult
args.val_batch_size = args.bs_mult_val
args.num_workers = 4 * args.ngpu
if args.test_mode:
args.num_workers = 1
指定对RGB处理归一化处理时的参数:
mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
对数据集进行预处理:我们调用joint_transforms类中的方法进行处理,包括RandomSizeAndCrop,Resize,RandomHorizontallyFlip,三个操作,最后调用composed串联组合在一起。
# Geometric image transformations
train_joint_transform_list = [
joint_transforms.RandomSizeAndCrop(args.crop_size,
False,
pre_size=args.pre_size,
scale_min=args.scale_min,
scale_max=args.scale_max,
ignore_index=args.dataset_cls.ignore_label),
joint_transforms.Resize(args.crop_size),
joint_transforms.RandomHorizontallyFlip()]
train_joint_transform = joint_transforms.Compose(train_joint_transform_list)
接着train_input_transform 对image进行处理,其中:parser.add_argument(‘–color_aug’, type=float,default=0.25, help=‘level of color augmentation’)首先调用extended_transforms函数的ColorJitter方法。分别对应于:
(import transforms.joint_transforms as joint_transforms
import transforms.transforms as extended_transforms)两个文件。接着是bblur(双边滤波)和gblur(高斯滤波)操作。全部添加到train_input_transform 列表中。
# Image appearance transformations
train_input_transform = []
if args.color_aug:
train_input_transform += [extended_transforms.ColorJitter(
brightness=args.color_aug,
contrast=args.color_aug,
saturation=args.color_aug,
hue=args.color_aug)]
if args.bblur:
train_input_transform += [extended_transforms.RandomBilateralBlur()]
elif args.gblur:
train_input_transform += [extended_transforms.RandomGaussianBlur()]
else:
pass
最后添加转换为tensor和归一化操作。验证集的操作只有换为tensor和归一化操作,剩下的是标签的操作。
train_input_transform += [standard_transforms.ToTensor(),
standard_transforms.Normalize(*mean_std)]
train_input_transform = standard_transforms.Compose(train_input_transform)
val_input_transform = standard_transforms.Compose([
standard_transforms.ToTensor(),
standard_transforms.Normalize(*mean_std)
])
target_transform = extended_transforms.MaskToTensor()
if args.jointwtborder:
target_train_transform = extended_transforms.RelaxedBoundaryLossToTensor(args.dataset_cls.ignore_label,
args.dataset_cls.num_classes)
else:
target_train_transform = extended_transforms.MaskToTensor()
上述操作的总结:
1:
train_joint_transform_list = [
joint_transforms.RandomSizeAndCrop(args.crop_size,#720
False,
pre_size=args.pre_size,#None
scale_min=args.scale_min,#0.5
scale_max=args.scale_max,#2
ignore_index=args.dataset_cls.ignore_label),#255
在call函数中,首先判断image大小和标签大小是否一样,pre_size=None。
接着scale_amt = 1 * random.uniform(self.scale_min, self.scale_max),即首先从(0.5,2)中随机采样一个尺度scale与1相乘,将image的尺寸与scale相乘得到新的尺寸。将原始的图片和标签resize到这个新尺寸。新的image送入到RandomCrop函数中。
class RandomSizeAndCrop(object):
    """
    Randomly rescale an (image, mask) pair, then take a random crop.

    A scale factor is drawn uniformly from [scale_min, scale_max]
    (optionally pre-normalised so the shorter edge equals pre_size), the
    pair is resized by that factor, and RandomCrop is applied to the
    result. If a centroid is given it is scaled along with the image so
    the crop can still be centred on the class of interest.
    """

    def __init__(self, size, crop_nopad,
                 scale_min=0.5, scale_max=2.0, ignore_index=0, pre_size=None):
        self.size = size  # crop size, e.g. 720
        self.crop = RandomCrop(self.size, ignore_index=ignore_index, nopad=crop_nopad)
        self.scale_min = scale_min
        self.scale_max = scale_max
        self.pre_size = pre_size

    def __call__(self, img, mask, centroid=None):
        assert img.size == mask.size
        # first, resize such that shorter edge is pre_size
        if self.pre_size is None:
            scale_amt = 1.
        elif img.size[1] < img.size[0]:
            scale_amt = self.pre_size / img.size[1]
        else:
            scale_amt = self.pre_size / img.size[0]
        # Then apply the random augmentation scale on top.
        scale_amt *= random.uniform(self.scale_min, self.scale_max)
        w, h = [int(i * scale_amt) for i in img.size]
        if centroid is not None:
            # Keep the centroid aligned with the rescaled image.
            centroid = [int(c * scale_amt) for c in centroid]
        # Bicubic for the image, nearest-neighbour for the label mask so
        # class ids are never interpolated.
        img, mask = img.resize((w, h), Image.BICUBIC), mask.resize((w, h), Image.NEAREST)
        return self.crop(img, mask, centroid)
输入的参数为新的rgb的尺寸,忽略标签255,self.size=720。th,tw=(720,720)。如果新生成的image的尺寸等于720,则直接返回image和mask。如果crop的大小大于新的image的大小,进行填充。
class RandomCrop(object):
    """
    Take a random crop of fixed size from an (image, mask) pair.

    If nopad is True and the requested crop exceeds the image, the crop is
    shrunk to the shorter image edge instead of padding; otherwise the
    image is padded with black (and the mask with ignore_index) so the
    crop fits. When a centroid is supplied, the crop window is constrained
    to contain it while staying fully inside the image.
    """

    def __init__(self, size, ignore_index=0, nopad=True):
        if isinstance(size, numbers.Number):
            # A single number means a square crop.
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.ignore_index = ignore_index
        self.nopad = nopad
        self.pad_color = (0, 0, 0)

    def __call__(self, img, mask, centroid=None):
        assert img.size == mask.size
        w, h = img.size
        # ASSUME H, W
        th, tw = self.size
        if w == tw and h == th:
            # Already exactly crop-sized: nothing to do.
            return img, mask
        if self.nopad:
            if th > h or tw > w:
                # Instead of padding, adjust crop size to the shorter edge of image.
                shorter_side = min(w, h)
                th, tw = shorter_side, shorter_side
        else:
            # Check if we need to pad img to fit for crop_size.
            if th > h:
                pad_h = (th - h) // 2 + 1
            else:
                pad_h = 0
            if tw > w:
                pad_w = (tw - w) // 2 + 1
            else:
                pad_w = 0
            border = (pad_w, pad_h, pad_w, pad_h)
            if pad_h or pad_w:
                img = ImageOps.expand(img, border=border, fill=self.pad_color)
                mask = ImageOps.expand(mask, border=border, fill=self.ignore_index)
                w, h = img.size
        if centroid is not None:
            # Need to insure that centroid is covered by crop and that crop
            # sits fully within the image
            c_x, c_y = centroid
            max_x = w - tw
            max_y = h - th
            x1 = random.randint(c_x - tw, c_x)
            x1 = min(max_x, max(0, x1))
            y1 = random.randint(c_y - th, c_y)
            y1 = min(max_y, max(0, y1))
        else:
            # Uniformly random top-left corner.
            if w == tw:
                x1 = 0
            else:
                x1 = random.randint(0, w - tw)
            if h == th:
                y1 = 0
            else:
                y1 = random.randint(0, h - th)
        return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th))
接着是:resize(),将图像和标签resize到720.
joint_transforms.Resize(args.crop_size),
class Resize(object):
    """
    Resize image to exact size of crop.

    The image is resampled with bicubic interpolation; the mask uses
    nearest-neighbour so label ids are never interpolated.
    """

    def __init__(self, size):
        # PIL resize() takes a (width, height) tuple; square target.
        self.size = (size, size)

    def __call__(self, img, mask):
        assert img.size == mask.size
        w, h = img.size
        # Bug fix: the original tested `w == h and w == self.size`, which
        # compared the int width against the (size, size) tuple and was
        # therefore always False — already-sized pairs were needlessly
        # resampled. Compare the full size tuple instead.
        if (w, h) == self.size:
            return img, mask
        return (img.resize(self.size, Image.BICUBIC),
                mask.resize(self.size, Image.NEAREST))
接着是对图片进行翻转:
joint_transforms.RandomHorizontallyFlip()]
class RandomHorizontallyFlip(object):
    """Mirror the (image, mask) pair left-right with probability 0.5."""

    def __call__(self, img, mask):
        # Guard clause: half the time, pass the pair through untouched.
        if random.random() >= 0.5:
            return img, mask
        direction = Image.FLIP_LEFT_RIGHT
        # Flip image and mask together so they stay aligned.
        return img.transpose(direction), mask.transpose(direction)
接着是train_input_transform处理:
首先是:extended_transforms.ColorJitter
class ColorJitter(object):
    """
    Randomly jitter brightness, contrast, saturation and hue of a PIL image.

    Each nonzero constructor argument enables the corresponding adjustment;
    a fresh random factor is drawn per call and the enabled adjustments are
    applied in a random order.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Build a randomly-ordered Compose of the enabled adjustments."""
        transforms = []
        if brightness > 0:
            # Factor drawn from [max(0, 1-b), 1+b].
            brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor)))
        if contrast > 0:
            contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor)))
        if saturation > 0:
            saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor)))
        if hue > 0:
            # Hue shift is symmetric around zero.
            hue_factor = np.random.uniform(-hue, hue)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor)))
        # Randomise the order the adjustments are applied in.
        np.random.shuffle(transforms)
        transform = torch_tr.Compose(transforms)
        return transform

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Input image.
        Returns:
            PIL Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img)
接着是:extended_transforms.RandomBilateralBlur()双边滤波(bilateral filter,保边去噪)。
class RandomBilateralBlur(object):
    """
    Apply Bilateral Filtering

    Edge-preserving smoothing with a spatial sigma drawn uniformly from
    [0.05, 0.75] on every call.
    """

    def __call__(self, img):
        sigma = random.uniform(0.05,0.75)
        # NOTE(review): `multichannel` is deprecated/removed in recent
        # scikit-image versions (use channel_axis=-1) — confirm the pinned
        # version supports it.
        blurred_img = denoise_bilateral(np.array(img), sigma_spatial=sigma, multichannel=True)
        # denoise_bilateral returns floats in [0, 1]; rescale to uint8.
        blurred_img *= 255
        return Image.fromarray(blurred_img.astype(np.uint8))
然后是:将标签转换为tensor,采用torch.from_numpy。
class MaskToTensor(object):
    """Convert a label mask (PIL image or array-like) to a LongTensor of class ids."""

    def __call__(self, img):
        # Go through an int32 numpy array first, then widen to int64
        # (long), the dtype loss functions expect for class-id targets.
        arr = np.array(img, dtype=np.int32)
        return torch.from_numpy(arr).long()
以上就是所有的预处理操作。
最后就是生成训练集和验证集并进行处理,只看KITTI:
train_set = args.dataset_cls.KITTI(
'semantic', 'train', args.maxSkip,
joint_transform_list=train_joint_transform_list,
transform=train_input_transform,
target_transform=target_train_transform,
dump_images=args.dump_augmentation_images,
class_uniform_pct=args.class_uniform_pct,
class_uniform_tile=args.class_uniform_tile,
test=args.test_mode,
cv_split=args.cv,
scf=args.scf,
hardnm=args.hardnm)
val_set = args.dataset_cls.KITTI(
'semantic', 'trainval', 0,
joint_transform_list=None,
transform=val_input_transform,
target_transform=target_transform,
test=False,
cv_split=args.cv,
scf=None)
在KITTI数据集中:我们看KITT这个类别。
import os
import sys
import numpy as np
from PIL import Image
from torch.utils import data
import logging
import datasets.uniform as uniform
import datasets.cityscapes_labels as cityscapes_labels
import json
from config import cfg
trainid_to_name = cityscapes_labels.trainId2name#(0:road)
id_to_trainid = cityscapes_labels.label2trainid#(0:255,1:255)
num_classes = 19
ignore_label = 255
root = cfg.DATASET.KITTI_DIR
aug_root = cfg.DATASET.KITTI_AUG_DIR
#调色板
palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153,
153, 153, 153, 250, 170, 30,
220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60,
255, 0, 0, 0, 0, 142, 0, 0, 70,
0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
zero_pad = 256 * 3 - len(palette)
for i in range(zero_pad):
palette.append(0)
def colorize_mask(mask):
    """
    Colorize a label mask for visualisation.

    mask: numpy array of train-id labels. Returns a palettized ('P' mode)
    PIL image using the module-level Cityscapes colour palette.
    """
    new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
    new_mask.putpalette(palette)
    return new_mask
def get_train_val(cv_split, all_items):
    """
    Split the 200 KITTI items into train/val for one cross-validation fold.

    Three fixed ~90/10 splits are hard-coded; ``cv_split`` selects one.
    Exits the process (after logging) on an unknown split id.
    """
    # 90/10 train/val split, three random splits for cross validation
    val_0 = [1,5,11,29,35,49,57,68,72,82,93,115,119,130,145,154,156,167,169,189,198]
    val_1 = [0,12,24,31,42,50,63,71,84,96,101,112,121,133,141,155,164,171,187,191,197]
    val_2 = [3,6,13,21,41,54,61,73,88,91,110,121,126,131,142,149,150,163,173,183,199]
    folds = {0: val_0, 1: val_1, 2: val_2}
    if cv_split not in folds:
        logging.info('Unknown cv_split {}'.format(cv_split))
        sys.exit()
    val_indices = set(folds[cv_split])
    train_set = []
    val_set = []
    # Same iteration order as the hard-coded dataset: indices 0..199.
    for i in range(200):
        target = val_set if i in val_indices else train_set
        target.append(all_items[i])
    return train_set, val_set
def make_dataset(quality, mode, maxSkip=0, cv_split=0, hardnm=0):
    """
    Build the list of (image path, mask path) items for train/val/trainval.

    Images live under <root>/training/image_2 and masks under
    <root>/training/semantic; mask filenames are assumed to mirror the
    image filenames (TODO confirm against the dataset layout). Returns
    (items, aug_items); aug_items is always empty here — maxSkip and
    hardnm are accepted for interface parity with other dataset modules.
    """
    items = []
    all_items = []
    aug_items = []
    assert quality == 'semantic'
    assert mode in ['train', 'val', 'trainval']
    # note that train and val are randomly determined, no official split
    img_dir_name = "training"
    img_path = os.path.join(root, img_dir_name, 'image_2')  # <root>/training/image_2
    mask_path = os.path.join(root, img_dir_name, 'semantic')  # <root>/training/semantic
    c_items = os.listdir(img_path)  # every file under image_2
    c_items.sort()  # deterministic ordering for the fixed cv splits
    for it in c_items:
        # Pair each image with its same-named mask file.
        item = (os.path.join(img_path, it), os.path.join(mask_path, it))
        all_items.append(item)
    logging.info('KITTI has a total of {} images'.format(len(all_items)))
    # split into train/val
    train_set, val_set = get_train_val(cv_split, all_items)
    if mode == 'train':
        items = train_set
    elif mode == 'val':
        items = val_set
    elif mode == 'trainval':
        items = train_set + val_set
    else:
        items = train_set + val_set if False else items  # unreachable; kept for clarity
    if mode not in ('train', 'val', 'trainval'):
        logging.info('Unknown mode {}'.format(mode))
        sys.exit()
    logging.info('KITTI-{}: {} images'.format(mode, len(items)))
    return items, aug_items
def make_test_dataset(quality, mode, maxSkip=0, cv_split=0):
    """
    Build the list of test items: (image path, None) pairs — no masks exist.

    Test images live under <root>/testing/image_2. maxSkip and cv_split
    are accepted only for interface parity with make_dataset().
    """
    items = []
    assert quality == 'semantic'
    assert mode == 'test'
    img_dir_name = "testing"
    img_path = os.path.join(root, img_dir_name, 'image_2')
    c_items = os.listdir(img_path)
    c_items.sort()  # deterministic ordering
    for it in c_items:
        item = (os.path.join(img_path, it), None)  # no ground-truth mask at test time
        items.append(item)
    logging.info('KITTI has a total of {} test images'.format(len(items)))
    return items, []
class KITTI(data.Dataset):
    """
    KITTI semantic segmentation dataset.

    Loads the (image, mask) pairs produced by make_dataset() /
    make_test_dataset(), optionally builds (and JSON-caches) class-uniform
    sampling centroids, and applies the joint/image/target transforms in
    __getitem__. Raw label ids are remapped to the 19 Cityscapes train ids.
    """

    def __init__(self, quality, mode, maxSkip=0, joint_transform_list=None,
                 transform=None, target_transform=None, dump_images=False,
                 class_uniform_pct=0, class_uniform_tile=0, test=False,
                 cv_split=None, scf=None, hardnm=0):
        self.quality = quality  # e.g. 'semantic'
        self.mode = mode  # 'train' / 'val' / 'trainval' / 'test'
        self.maxSkip = maxSkip  # forwarded to make_dataset(); default 0
        self.joint_transform_list = joint_transform_list  # geometric transforms on (img, mask)
        self.transform = transform  # image-only input transforms
        self.target_transform = target_transform  # converts label mask to tensor
        self.dump_images = dump_images
        self.class_uniform_pct = class_uniform_pct  # e.g. 0.5
        self.class_uniform_tile = class_uniform_tile  # e.g. 1024
        self.scf = scf  # scale-correction-factor mode (2x upsampling)
        self.hardnm = hardnm
        if cv_split:  # cross-validation split id
            self.cv_split = cv_split
            assert cv_split < cfg.DATASET.CV_SPLITS, \
                'expected cv_split {} to be < CV_SPLITS {}'.format(
                    cv_split, cfg.DATASET.CV_SPLITS)
        else:
            self.cv_split = 0
        if self.mode == 'test':
            self.imgs, _ = make_test_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split)
        else:
            self.imgs, _ = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split, hardnm=self.hardnm)
        assert len(self.imgs), 'Found 0 images, please check the data set'
        # Centroids for GT data
        if self.class_uniform_pct > 0:
            if self.scf:
                json_fn = 'kitti_tile{}_cv{}_scf.json'.format(self.class_uniform_tile, self.cv_split)
            else:
                json_fn = 'kitti_tile{}_cv{}_{}_hardnm{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode, self.hardnm)
            if os.path.isfile(json_fn):
                # Reuse cached centroids; JSON keys come back as strings,
                # so convert them back to int class ids.
                with open(json_fn, 'r') as json_data:
                    centroids = json.load(json_data)
                self.centroids = {
                    int(idx): centroids[idx] for idx in centroids}
            else:
                if self.scf:
                    # NOTE(review): `kitti_uniform` is never imported in this
                    # module — this branch raises NameError; it likely should
                    # be `uniform`. Confirm before enabling --scf.
                    self.centroids = kitti_uniform.class_centroids_all(
                        self.imgs,
                        num_classes,
                        id2trainid=id_to_trainid,
                        tile_size=class_uniform_tile)
                else:
                    self.centroids = uniform.class_centroids_all(
                        self.imgs,
                        num_classes,
                        id2trainid=id_to_trainid,
                        tile_size=class_uniform_tile)
                # Cache the computed centroids for the next run.
                with open(json_fn, 'w') as outfile:
                    json.dump(self.centroids, outfile, indent=4)
        self.build_epoch()

    def build_epoch(self, cut=False):
        """Rebuild this epoch's image list, class-uniformly resampled if enabled."""
        if self.class_uniform_pct > 0:
            # NOTE(review): passes cfg.CLASS_UNIFORM_PCT rather than
            # self.class_uniform_pct — confirm which value is intended.
            self.imgs_uniform = uniform.build_epoch(self.imgs,
                                                    self.centroids,
                                                    num_classes,
                                                    cfg.CLASS_UNIFORM_PCT)
        else:
            self.imgs_uniform = self.imgs

    def __getitem__(self, index):
        """Load, resize, remap labels and transform the sample at `index`."""
        elem = self.imgs_uniform[index]
        centroid = None
        if len(elem) == 4:
            # Class-uniform item: carries a centroid to crop around.
            img_path, mask_path, centroid, class_id = elem
        else:
            img_path, mask_path = elem
        if self.mode == 'test':
            img, mask = Image.open(img_path).convert('RGB'), None
        else:
            img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        # kitti scale correction factor
        if self.mode == 'train' or self.mode == 'trainval':
            if self.scf:
                # Upsample 2x before augmentation when scf is enabled.
                width, height = img.size
                img = img.resize((width*2, height*2), Image.BICUBIC)
                mask = mask.resize((width*2, height*2), Image.NEAREST)
        elif self.mode == 'val':
            # Fixed KITTI eval resolution.
            width, height = 1242, 376
            img = img.resize((width, height), Image.BICUBIC)
            mask = mask.resize((width, height), Image.NEAREST)
        elif self.mode == 'test':
            img_keepsize = img.copy()
            width, height = 1280, 384
            img = img.resize((width, height), Image.BICUBIC)
        else:
            # NOTE(review): `mode` is undefined here (should be self.mode) —
            # this branch would raise NameError before logging anything.
            logging.info('Unknown mode {}'.format(mode))
            sys.exit()
        if self.mode != 'test':
            # Remap raw label ids to the 19 train ids via id_to_trainid.
            mask = np.array(mask)
            mask_copy = mask.copy()
            for k, v in id_to_trainid.items():
                mask_copy[mask == k] = v
            mask = Image.fromarray(mask_copy.astype(np.uint8))
        # Image Transformations
        if self.joint_transform_list is not None:
            for idx, xform in enumerate(self.joint_transform_list):
                if idx == 0 and centroid is not None:
                    # HACK
                    # We assume that the first transform is capable of taking
                    # in a centroid
                    img, mask = xform(img, mask, centroid)
                else:
                    img, mask = xform(img, mask)
        # Debug
        if self.dump_images and centroid is not None:
            # Dump the augmented image/mask pair for visual inspection.
            outdir = './dump_imgs_{}'.format(self.mode)
            os.makedirs(outdir, exist_ok=True)
            dump_img_name = trainid_to_name[class_id] + '_' + img_name
            out_img_fn = os.path.join(outdir, dump_img_name + '.png')
            out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png')
            mask_img = colorize_mask(np.array(mask))
            img.save(out_img_fn)
            mask_img.save(out_msk_fn)
        if self.transform is not None:
            img = self.transform(img)
            if self.mode == 'test':
                # No ground truth at test time: return the original-size
                # transformed image in place of the mask.
                img_keepsize = self.transform(img_keepsize)
                mask = img_keepsize
        if self.target_transform is not None:
            if self.mode != 'test':
                mask = self.target_transform(mask)
        return img, mask, img_name

    def __len__(self):
        # Length of the (possibly class-uniformly resampled) epoch list.
        return len(self.imgs_uniform)
1:首先确定quality,即时semantic或者image或者depth。接着确认mode是train或者eval。Maxskip=0紧跟着两个transformer操作,包括交叉验证的cv。
2:如果mode=test模式,输入quality, mode, self.maxSkip, cv_split=self.cv_split参数到make_test_dataset中,即制作test数据集。在make_test_dataset函数中,img_dir_name = “testing”,图片路径位于KITTI数据集根目录下的testing文件下的image_2文件下的图片。
3:列出图片路径下所有图片,并将其排序。遍历文件夹下的每一张图片,作为一个item添加到一个空列表中。返回最终的列表。
def make_test_dataset(quality, mode, maxSkip=0, cv_split=0):
    """
    Build the list of test items: (image path, None) pairs — no masks exist.

    Test images live under <root>/testing/image_2. maxSkip and cv_split
    are accepted only for interface parity with make_dataset().
    """
    items = []
    assert quality == 'semantic'
    assert mode == 'test'
    img_dir_name = "testing"
    img_path = os.path.join(root, img_dir_name, 'image_2')
    c_items = os.listdir(img_path)
    c_items.sort()  # deterministic ordering
    for it in c_items:
        item = (os.path.join(img_path, it), None)  # no ground-truth mask at test time
        items.append(item)
    logging.info('KITTI has a total of {} test images'.format(len(items)))
    return items, []
4:即测试及图片位于self.imgs中。如果mode=eval,则将quality, mode, self.maxSkip, cv_split=self.cv_split, hardnm=self.hardnm输入到make_dataset中。
def make_dataset(quality, mode, maxSkip=0, cv_split=0, hardnm=0):
    """
    Build the list of (image path, mask path) items for train/val/trainval.

    Images live under <root>/training/image_2 and masks under
    <root>/training/semantic; mask filenames are assumed to mirror the
    image filenames (TODO confirm against the dataset layout). Returns
    (items, aug_items); aug_items is always empty here — maxSkip and
    hardnm are accepted for interface parity with other dataset modules.
    """
    items = []
    all_items = []
    aug_items = []
    assert quality == 'semantic'
    assert mode in ['train', 'val', 'trainval']
    # note that train and val are randomly determined, no official split
    img_dir_name = "training"
    img_path = os.path.join(root, img_dir_name, 'image_2')  # <root>/training/image_2
    mask_path = os.path.join(root, img_dir_name, 'semantic')  # <root>/training/semantic
    c_items = os.listdir(img_path)  # every file under image_2
    c_items.sort()  # deterministic ordering for the fixed cv splits
    for it in c_items:
        # Pair each image with its same-named mask file.
        item = (os.path.join(img_path, it), os.path.join(mask_path, it))
        all_items.append(item)
    logging.info('KITTI has a total of {} images'.format(len(all_items)))
    # split into train/val
    train_set, val_set = get_train_val(cv_split, all_items)
    if mode == 'train':
        items = train_set
    elif mode == 'val':
        items = val_set
    elif mode == 'trainval':
        items = train_set + val_set
    else:
        logging.info('Unknown mode {}'.format(mode))
        sys.exit()
    logging.info('KITTI-{}: {} images'.format(mode, len(items)))
    return items, aug_items
5:img_path和mask位于根目录下的training文件下的image_2和Semantic文件夹下。
列出/kitti/training/image_2下所有的图片,并进行排序。
6:遍历两个文件夹下所有图片,将每一张rgb和mask一一对应,作为一个条目添加到列表中。根据cv将training文件夹下的图片进行划分。
7:三组随机交叉验证划分(约90/10):200张图片中,每组固定取约20张作为验证集,其余约180张作为训练集(见代码中硬编码的val_0/val_1/val_2索引列表)。
假如:cv_split=0,则i=0,判断0是否在val_0列表中,不在,执行else,将all_items[0]对应的图片(img,mask)添加到train_set,同理,一直执行,直到val_set里面添加约20张图片,train_set里面有约180张图片。
8:如果mode=train,items=train_set,如果mode=val,则items包含的是20张图片。最后输出items。即self.imgs=items。
def get_train_val(cv_split, all_items):
    """
    Split the 200 KITTI items into train/val for one cross-validation fold.

    Three fixed ~90/10 splits are hard-coded; ``cv_split`` selects one.
    Exits the process (after logging) on an unknown split id.
    """
    # 90/10 train/val split, three random splits for cross validation
    val_0 = [1,5,11,29,35,49,57,68,72,82,93,115,119,130,145,154,156,167,169,189,198]
    val_1 = [0,12,24,31,42,50,63,71,84,96,101,112,121,133,141,155,164,171,187,191,197]
    val_2 = [3,6,13,21,41,54,61,73,88,91,110,121,126,131,142,149,150,163,173,183,199]
    folds = {0: val_0, 1: val_1, 2: val_2}
    if cv_split not in folds:
        logging.info('Unknown cv_split {}'.format(cv_split))
        sys.exit()
    val_indices = set(folds[cv_split])
    train_set = []
    val_set = []
    # Same iteration order as the hard-coded dataset: indices 0..199.
    for i in range(200):
        target = val_set if i in val_indices else train_set
        target.append(all_items[i])
    return train_set, val_set
9:生成一个json文件:将centroids以json文件格式写入outfile(json_fn)。
10:生成的imgs_uniform代替imgs。
11:使用getitem遍历数据集。通过index取元素。
如果mode=test,打开img路径下的rgb图片,没有mask,所以mask为None。如果mode=其他(train,val),打开(rgb,mask)。接着获取图片的名字。如果mode=val,将图片和mask 分别resize为1242,376.
12:如果mode不为test,将mask转换为np模式,并拷贝到mask_copy,遍历
id_to_trainid ,k为原始的33类的类别,v为经过转换后的19类类别。
id_to_trainid = cityscapes_labels.label2trainid
13:mask_copy中mask==k的像素点重新赋值为v,并转换为tensor格式。接着对image和mask进行处理。
生成的训练集和验证集通过dataloader加载。最后输出train_loader, val_loader, train_set。
train_loader = DataLoader(train_set, batch_size=args.train_batch_size,num_workers=args.num_workers, shuffle=(train_sampler is None), drop_last=True, sampler = train_sampler)
val_loader = DataLoader(val_set, batch_size=args.val_batch_size,
num_workers=args.num_workers // 2 , shuffle=False, drop_last=False, sampler = val_sampler)
return train_loader, val_loader, train_set
import os
import sys
import numpy as np
from PIL import Image
from torch.utils import data
import logging
import datasets.uniform as uniform
import datasets.cityscapes_labels as cityscapes_labels
import json
from config import cfg
trainid_to_name = cityscapes_labels.trainId2name#(0:road)
id_to_trainid = cityscapes_labels.label2trainid#(0:255,1:255)
num_classes = 19
ignore_label = 255
root = cfg.DATASET.KITTI_DIR
aug_root = cfg.DATASET.KITTI_AUG_DIR
#调色板
palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153,
153, 153, 153, 250, 170, 30,
220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60,
255, 0, 0, 0, 0, 142, 0, 0, 70,
0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
zero_pad = 256 * 3 - len(palette)
for i in range(zero_pad):
palette.append(0)
def colorize_mask(mask):
    """
    Colorize a label mask for visualisation.

    mask: numpy array of train-id labels. Returns a palettized ('P' mode)
    PIL image using the module-level Cityscapes colour palette.
    """
    new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
    new_mask.putpalette(palette)
    return new_mask
def get_train_val(cv_split, all_items):
    """Split the 200 KITTI training items into train/val for one CV fold.

    A fixed 90/10 split: three hand-picked validation index lists define
    three cross-validation folds (KITTI has no official train/val split).

    Args:
        cv_split: which fold to use (0, 1 or 2).
        all_items: full list of (image_path, mask_path) items; the first
            200 entries are partitioned.

    Returns:
        (train_set, val_set) lists of items.

    Exits the process on an unknown cv_split value (matches original
    behavior).
    """
    val_splits = [
        [1, 5, 11, 29, 35, 49, 57, 68, 72, 82, 93, 115, 119, 130, 145,
         154, 156, 167, 169, 189, 198],
        [0, 12, 24, 31, 42, 50, 63, 71, 84, 96, 101, 112, 121, 133, 141,
         155, 164, 171, 187, 191, 197],
        [3, 6, 13, 21, 41, 54, 61, 73, 88, 91, 110, 121, 126, 131, 142,
         149, 150, 163, 173, 183, 199],
    ]
    if cv_split not in (0, 1, 2):
        logging.info('Unknown cv_split {}'.format(cv_split))
        sys.exit()
    # Set membership is O(1) per test and replaces the three copy-pasted
    # loop bodies of the original with a single loop.
    val_indices = set(val_splits[cv_split])
    train_set = []
    val_set = []
    for i in range(200):
        if i in val_indices:
            val_set.append(all_items[i])
        else:
            train_set.append(all_items[i])
    return train_set, val_set
def make_dataset(quality, mode, maxSkip=0, cv_split=0, hardnm=0):
    """Collect (image, mask) path pairs for the requested KITTI split.

    KITTI has no official train/val split, so the 200 training images are
    partitioned per-fold by get_train_val().

    Args:
        quality: must be 'semantic'.
        mode: one of 'train', 'val' or 'trainval'.
        maxSkip: unused here; kept for interface parity with other datasets.
        cv_split: cross-validation fold passed to get_train_val().
        hardnm: unused here; kept for interface parity.

    Returns:
        (items, aug_items) where aug_items is always empty for KITTI.
    """
    assert quality == 'semantic'
    assert mode in ['train', 'val', 'trainval']

    # Note that train and val are randomly determined; no official split.
    img_dir_name = "training"
    img_path = os.path.join(root, img_dir_name, 'image_2')    # <root>/training/image_2
    mask_path = os.path.join(root, img_dir_name, 'semantic')  # <root>/training/semantic

    # Image and label files share the same basename in both directories.
    all_items = [(os.path.join(img_path, name), os.path.join(mask_path, name))
                 for name in sorted(os.listdir(img_path))]
    logging.info('KITTI has a total of {} images'.format(len(all_items)))

    # Split into train/val according to the requested fold.
    train_set, val_set = get_train_val(cv_split, all_items)
    if mode == 'train':
        items = train_set
    elif mode == 'val':
        items = val_set
    elif mode == 'trainval':
        items = train_set + val_set
    else:
        # Unreachable given the assert above; kept for parity.
        logging.info('Unknown mode {}'.format(mode))
        sys.exit()
    logging.info('KITTI-{}: {} images'.format(mode, len(items)))
    return items, []
def make_test_dataset(quality, mode, maxSkip=0, cv_split=0):
    """Collect test-set image paths; no ground-truth masks exist.

    Args:
        quality: must be 'semantic'.
        mode: must be 'test'.
        maxSkip, cv_split: unused; kept for interface parity.

    Returns:
        (items, aug_items): items are (image_path, None) pairs sorted by
        filename; aug_items is always empty.
    """
    assert quality == 'semantic'
    assert mode == 'test'
    img_dir_name = "testing"
    img_path = os.path.join(root, img_dir_name, 'image_2')
    items = [(os.path.join(img_path, name), None)
             for name in sorted(os.listdir(img_path))]
    logging.info('KITTI has a total of {} test images'.format(len(items)))
    return items, []
class KITTI(data.Dataset):
    """KITTI semantic-segmentation dataset (torch Dataset).

    Builds the (image_path, mask_path) list for the requested split,
    optionally precomputes class centroids for class-uniform sampling
    (cached as JSON), and remaps raw label ids to the 19 training ids
    on the fly in __getitem__.
    """

    def __init__(self, quality, mode, maxSkip=0, joint_transform_list=None,
                 transform=None, target_transform=None, dump_images=False,
                 class_uniform_pct=0, class_uniform_tile=0, test=False,
                 cv_split=None, scf=None, hardnm=0):
        """
        Args:
            quality: must be 'semantic'.
            mode: 'train', 'val', 'trainval' or 'test'.
            maxSkip: frame skip for video-augmented data (unused for KITTI).
            joint_transform_list: transforms applied to (img, mask) jointly;
                the first entry may accept a centroid (see __getitem__).
            transform: image-only transform (e.g. normalize-to-tensor).
            target_transform: mask-only transform (e.g. to tensor).
            dump_images: if True, dump transformed pairs to ./dump_imgs_<mode>.
            class_uniform_pct: fraction of the epoch sampled class-uniformly;
                0 disables centroid computation entirely.
            class_uniform_tile: tile size used when computing centroids.
            test: unused; kept for signature parity with other datasets.
            cv_split: cross-validation fold; falsy values select fold 0.
            scf: KITTI scale-correction factor; doubles train-image size and
                switches to the scf centroid cache file.
            hardnm: hard-negative-mining iteration id (cache-key only here).
        """
        self.quality = quality
        self.mode = mode
        self.maxSkip = maxSkip
        self.joint_transform_list = joint_transform_list
        self.transform = transform
        self.target_transform = target_transform
        self.dump_images = dump_images
        self.class_uniform_pct = class_uniform_pct
        self.class_uniform_tile = class_uniform_tile
        self.scf = scf
        self.hardnm = hardnm
        # NOTE: `if cv_split:` treats fold 0 like None; both paths end with
        # self.cv_split = 0, so the result is correct either way.
        if cv_split:
            self.cv_split = cv_split
            assert cv_split < cfg.DATASET.CV_SPLITS, \
                'expected cv_split {} to be < CV_SPLITS {}'.format(
                    cv_split, cfg.DATASET.CV_SPLITS)
        else:
            self.cv_split = 0

        if self.mode == 'test':
            self.imgs, _ = make_test_dataset(quality, mode, self.maxSkip,
                                             cv_split=self.cv_split)
        else:
            self.imgs, _ = make_dataset(quality, mode, self.maxSkip,
                                        cv_split=self.cv_split,
                                        hardnm=self.hardnm)
        assert len(self.imgs), 'Found 0 images, please check the data set'

        # Centroids for GT data, cached as JSON keyed by tile size / fold /
        # mode so the expensive full-dataset scan runs once per config.
        if self.class_uniform_pct > 0:
            if self.scf:
                json_fn = 'kitti_tile{}_cv{}_scf.json'.format(
                    self.class_uniform_tile, self.cv_split)
            else:
                json_fn = 'kitti_tile{}_cv{}_{}_hardnm{}.json'.format(
                    self.class_uniform_tile, self.cv_split, self.mode,
                    self.hardnm)
            if os.path.isfile(json_fn):
                with open(json_fn, 'r') as json_data:
                    centroids = json.load(json_data)
                # JSON keys are strings; restore the integer class ids.
                self.centroids = {int(idx): centroids[idx]
                                  for idx in centroids}
            else:
                # BUGFIX: the scf branch referenced an undefined name
                # `kitti_uniform` and raised a NameError on first use; both
                # branches called class_centroids_all with identical
                # arguments, so they are collapsed onto datasets.uniform.
                self.centroids = uniform.class_centroids_all(
                    self.imgs,
                    num_classes,
                    id2trainid=id_to_trainid,
                    tile_size=self.class_uniform_tile)
                with open(json_fn, 'w') as outfile:
                    json.dump(self.centroids, outfile, indent=4)

        self.build_epoch()

    def build_epoch(self, cut=False):
        """Rebuild the per-epoch image list.

        With class-uniform sampling enabled, part of the epoch is drawn from
        tiles centered on under-represented classes; otherwise the raw image
        list is used as-is.

        Args:
            cut: unused; kept for interface parity with other datasets.
        """
        if self.class_uniform_pct > 0:
            # NOTE(review): passes cfg.CLASS_UNIFORM_PCT rather than
            # self.class_uniform_pct — presumably the global config is meant
            # to win; confirm against the other dataset wrappers.
            self.imgs_uniform = uniform.build_epoch(self.imgs,
                                                    self.centroids,
                                                    num_classes,
                                                    cfg.CLASS_UNIFORM_PCT)
        else:
            self.imgs_uniform = self.imgs

    def __getitem__(self, index):
        """Load, resize, remap and transform one sample.

        Returns:
            (img, mask, img_name). In 'test' mode mask is the transformed
            original-size image (there is no ground truth).
        """
        elem = self.imgs_uniform[index]
        centroid = None
        if len(elem) == 4:
            # Class-uniform entries carry a centroid to crop around.
            img_path, mask_path, centroid, class_id = elem
        else:
            img_path, mask_path = elem

        if self.mode == 'test':
            img, mask = Image.open(img_path).convert('RGB'), None
        else:
            img, mask = Image.open(img_path).convert('RGB'), \
                Image.open(mask_path)
        img_name = os.path.splitext(os.path.basename(img_path))[0]

        # KITTI images vary slightly in size: bring each mode to a fixed
        # resolution (and apply the scale-correction factor for training).
        if self.mode == 'train' or self.mode == 'trainval':
            if self.scf:
                width, height = img.size
                img = img.resize((width * 2, height * 2), Image.BICUBIC)
                mask = mask.resize((width * 2, height * 2), Image.NEAREST)
        elif self.mode == 'val':
            width, height = 1242, 376
            img = img.resize((width, height), Image.BICUBIC)
            mask = mask.resize((width, height), Image.NEAREST)
        elif self.mode == 'test':
            img_keepsize = img.copy()
            width, height = 1280, 384
            img = img.resize((width, height), Image.BICUBIC)
        else:
            # BUGFIX: original referenced an undefined local `mode`
            # (NameError); use self.mode.
            logging.info('Unknown mode {}'.format(self.mode))
            sys.exit()

        if self.mode != 'test':
            # Remap raw label ids (0..33) onto the 19 training ids.
            mask = np.array(mask)
            mask_copy = mask.copy()
            for k, v in id_to_trainid.items():
                mask_copy[mask == k] = v
            mask = Image.fromarray(mask_copy.astype(np.uint8))

        # Image transformations: joint transforms see the (img, mask) pair.
        if self.joint_transform_list is not None:
            for idx, xform in enumerate(self.joint_transform_list):
                if idx == 0 and centroid is not None:
                    # HACK: we assume the first transform is capable of
                    # taking in a centroid to bias its crop.
                    img, mask = xform(img, mask, centroid)
                else:
                    img, mask = xform(img, mask)

        # Debug: dump the transformed pair for visual inspection.
        if self.dump_images and centroid is not None:
            outdir = './dump_imgs_{}'.format(self.mode)
            os.makedirs(outdir, exist_ok=True)
            dump_img_name = trainid_to_name[class_id] + '_' + img_name
            out_img_fn = os.path.join(outdir, dump_img_name + '.png')
            out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png')
            mask_img = colorize_mask(np.array(mask))
            img.save(out_img_fn)
            mask_img.save(out_msk_fn)

        if self.transform is not None:
            img = self.transform(img)
            if self.mode == 'test':
                # No GT in test mode: return the original-size image in the
                # mask slot so callers can recover the native resolution.
                img_keepsize = self.transform(img_keepsize)
                mask = img_keepsize
        if self.target_transform is not None:
            if self.mode != 'test':
                mask = self.target_transform(mask)

        return img, mask, img_name

    def __len__(self):
        """Number of samples in the current (possibly class-uniform) epoch."""
        return len(self.imgs_uniform)