图像质量评价(IQA)读库代码详细介绍

        相信很多朋友在做视觉工作入门的时候首先都会接触到读库代码,那么在图像质量评价方向中的读库代码该如何实现呢?接下来我会给大家介绍一段详细的读库代码,代码框架我主要是从2014年kang等人利用CNN进行图像质量评价的源代码进行修改的。代码基于pytorch框架。

        在进行读库代码介绍前,我们需要有一些先验知识储备。

Dataset和Dataloader

        PyTorch提供torch.utils.data.DataLoader  torch.utils.data.Dataset允许你使用预下载的数据集或自己制作的数据。

Dataloader

        这边我直接用一段代码来解释,IQAdatast在下面提到了。这里偷个懒就不说明dataloader里的参数了,代码浅显易懂,根据参数名字也能判断。

train_data = IQAdatasets(
                root=config.folder_path[args.dataset], index=train_index, transform=transforms,transform_gray = transforms_gray , patch_num=args.train_patch_num)
  
train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=0)

Dataset

        我的上一篇博客提到了torchvision.datasets,链接:torchvision中的dataset

        那么torchvision.datasets和这里的torch.utils.data.Dataset有什么区别呢?

1.torchvision.datasets

        从名字中就可以看到,这个datasets不仅是小写,说明它作用的范围小,而且加了个s,这就说明它是复数,为什么是复数呢,是因为如果我们需要不自己设置训练数据集而去使用官方给的数据集的时候,它里边有。

2.torch.utils.data.Dataset

        这个模块更偏重于自己独立的去建立一个新的训练数据集,需要自己去设定参数

代码详解

        自定义数据集类必须实现三个函数:__init__, __len__, __getitem__

import torch
import torch.utils.data as data
import torchvision
from PIL import Image
import os
import os.path

def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

def gray_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('L')

class IQAdatasets(data.Dataset):

    def __init__(self, root, index, transform,transform_gray, patch_num):

        imgpath = []
        ref_names = []
        for line1 in open("data/im_names.txt", "r"):
            line1 = line1.strip()
            path = os.path.join(root, line1)
            # print(path)
            imgpath.append(path)

        labels = []
        for line5 in open("data/mos.txt", "r"):
            line5 = float(line5.strip())
            labels.append(line5)


        sample = []
        for i, item in enumerate(index):
            for aug in range(patch_num):
                # print(item)
                sample.append((imgpath[item ], labels[item ]))
                
        self.samples = sample
        self.transform = transform
        self.transform_gray = transform_gray

    def __getitem__(self, index):
       
        path, target = self.samples[index]
        sample = pil_loader(path)
        sample = self.transform(sample)
        sample_gra = pil_loader(path)
        sample_gray = self.transform_gray(sample_gra)
        return (sample , sample_gray), target

    def __len__(self):
        length = len(self.samples)
        return length

        在上述代码中,我们重写了data.Dataset方法,在这里,我返回的主要是三通道图像和灰度图像的数据集,代码中的 data/im_names.txt是我自己从数据集中的mat文件中提取出来的标签信息,如果大家要用其他数据集的话,这里也可以跟我一样提取出来到txt中或者用h5py来提取LivefullInfo.mat等数据集中的标签。

        这段代码最主要的地方就在于:

def __getitem__(self, index):
       
        path, target = self.samples[index]
        sample = pil_loader(path)
        sample = self.transform(sample)
        sample_gra = pil_loader(path)
        sample_gray = self.transform_gray(sample_gra)
        return (sample , sample_gray), target

        这里我们返回了(sample , sample_gray), target三类数据,分别是三通道图像、灰度图像、图像标签(在IQA任务中,图像标签就是MOS值、STD值、图像路径等)

        以上就是自定义数据集来做IQA任务的读库代码,现在我们来进行调试一下:

if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', dest='dataset', type=str, default='live',
                        help='Support datasets: livec|koniq-10k|bid|live|csiq|tid2013')
    parser.add_argument('--train_patch_num', dest='train_patch_num', type=int, default=25,
                        help='Number of sample patches from training image')
    parser.add_argument('--test_patch_num', dest='test_patch_num', type=int, default=25,
                        help='Number of sample patches from testing image')
    parser.add_argument('--lr', dest='lr', type=float, default=1e-5, help='Learning rate')
    parser.add_argument('--weight_decay', dest='weight_decay', type=float, default=5e-4, help='Weight decay')
    parser.add_argument('--lr_ratio', dest='lr_ratio', type=int, default=10,
                        help='Learning rate ratio for hyper network')
    parser.add_argument('--batch_size', dest='batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--epochs', dest='epochs', type=int, default=5, help='Epochs for training')
    parser.add_argument('--patch_size', dest='patch_size', type=int, default=224,
                        help='Crop size for training & testing image patches')
    parser.add_argument('--train_test_num', dest='train_test_num', type=int, default=10, help='Train-test times')
    args = parser.parse_args()

    sel_num = config.img_num[args.dataset]
    train_index = sel_num[0:int(round(0.8 * len(sel_num)))]
    test_index = sel_num[int(round(0.8 * len(sel_num))):len(sel_num)]
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomCrop(size=args.patch_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                         std=(0.229, 0.224, 0.225))
    ])
    transforms_gray = torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomCrop(size=args.patch_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.5),
                                         std=(0.5))
    ])
    train_data = IQAdatasets(
                root=config.folder_path[args.dataset], index=train_index, transform=transforms,transform_gray = transforms_gray , patch_num=args.train_patch_num)
  
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=0)
    for index , (sample , target) in tqdm(enumerate(train_loader)):
        x_rgb = sample[0].to(device)
        x_gray = sample[1].to(device)

        我这里是对训练数据进行了随机裁剪成224*224的大小并进行归一化。 这段代码最独特的设计在于,我们这里对训练集数据和测试集数据进行了train_patch_num和test_patch_num的设计,多次随机裁剪,可以获取到图像的更多信息。

        之后我会继续分享图像平均切割的读库代码,原理都是类似的。

猜你喜欢

转载自blog.csdn.net/qq_37925923/article/details/127359242