GFPGAN源码分析—第十二篇

2021SC@SDUSC

源码:

models\gfpgan_model.py

本篇分析models\gfpgan_model.py下的

class GFPGANModel(BaseModel) 类的部分方法

class GFPGANModel(BaseModel)

目录

class GFPGANModel(BaseModel)

test(self)

dist_validation()

nondist_validation()


test(self)

测试

def test(self):
    """Run the generator on ``self.lq`` without building an autograd graph.

    Uses ``self.net_g_ema`` when the model has that attribute. Otherwise it
    logs a warning, falls back to ``self.net_g``, and switches that network
    back to training mode once the forward pass is done. The generator's
    result is unpacked as a pair; its first element is stored in
    ``self.output`` and the second is discarded.
    """
    # Everything inside torch.no_grad() runs without recording a
    # computation graph, so no gradients are tracked during inference.
    with torch.no_grad():
        has_ema = hasattr(self, 'net_g_ema')
        if has_ema:
            generator = self.net_g_ema
        else:
            get_root_logger().warning('Do not have self.net_g_ema, use self.net_g.')
            generator = self.net_g
        generator.eval()
        self.output, _ = generator(self.lq)
        if not has_ema:
            # Only the fallback path touched net_g, so only it is returned
            # to train mode here.
            self.net_g.train()

dist_validation()

def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
    """Validation entry point used in distributed runs.

    Only the process whose ``self.opt['rank']`` is 0 does any work. It
    delegates to ``nondist_validation`` with the same arguments. Every other
    rank returns immediately without validating.
    """
    if self.opt['rank'] != 0:
        return
    self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

nondist_validation()

参数:
self, dataloader, current_iter, tb_logger, save_img

分几步看一下代码

1.进度条与with_metrics的初始化

# Name of the validation dataset, later used in save paths and logging.
dataset_name = dataloader.dataset.opt['name']
val_opt = self.opt['val']
# Metrics are computed only when the validation config provides a
# 'metrics' entry that is not None.
with_metrics = val_opt.get('metrics') is not None
if with_metrics:
    # One running total per configured metric name, starting from zero.
    self.metric_results = dict.fromkeys(val_opt['metrics'], 0)
# Progress bar whose total is the number of dataloader batches.
pbar = tqdm(unit='image', total=len(dataloader))

2.遍历dataloader,依次执行feed data、推理(test)、张量到图像的转换、图像保存以及指标累加

# Walk the validation set: run inference, optionally save the restored
# image, and accumulate a running sum for every configured metric.
for idx, val_data in enumerate(dataloader):
    # File stem (basename without extension) of the low-quality input,
    # used in the save path and the progress-bar label.
    img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
    # Load the current batch into the model before running inference.
    self.feed_data(val_data)
    self.test()
    visuals = self.get_current_visuals()
    # Convert the network output tensor to a numpy image.
    # NOTE(review): min_max=(-1, 1) is presumably the tensor's value range
    # passed to tensor2img -- confirm against basicsr's tensor2img.
    sr_img = tensor2img([visuals['sr']], min_max=(-1, 1))
    # Reset per image so a ground truth from an earlier batch is never
    # reused for this image's metrics.
    gt_img = None
    # Ground truth is optional: convert (and free) it only when present.
    if 'gt' in visuals:
        gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
        del self.gt
    # tentative for out of GPU memory
    del self.lq
    del self.output
    torch.cuda.empty_cache()

    if save_img:
        # Training runs: <visualization>/<img_name>/<img_name>_<iter>.png.
        # Test runs: <visualization>/<dataset>/<img_name>_<suffix>.png,
        # using the experiment name when no suffix is configured.
        if self.opt['is_train']:
            save_img_path = osp.join(self.opt['path']['visualization'], img_name,
                                     f'{img_name}_{current_iter}.png')
        else:
            if self.opt['val']['suffix']:
                save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                         f'{img_name}_{self.opt["val"]["suffix"]}.png')
            else:
                save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                         f'{img_name}_{self.opt["name"]}.png')
        imwrite(sr_img, save_img_path)

    if with_metrics:
        # Add this image's score to each metric's running sum. Metrics
        # compare the restored image against gt_img, so a ground truth is
        # required whenever metrics are configured.
        for name, opt_ in self.opt['val']['metrics'].items():
            metric_data = dict(img1=sr_img, img2=gt_img)
            self.metric_results[name] += calculate_metric(metric_data, opt_)
    # Advance the progress bar and show the image just processed.
    pbar.update(1)
    pbar.set_description(f'Test {img_name}')
pbar.close()

3.计算各指标的平均值,并调用_log_validation_metric_values记录结果

#with_metrics一定为True
# Only runs when metrics were configured. When with_metrics is False,
# nothing is averaged or logged.
if with_metrics:
    # idx is the last index from the enumerate() loop over the dataloader,
    # so idx + 1 images were scored; turn each running sum into a mean.
    # NOTE(review): idx is unbound if the dataloader yielded no batches.
    for metric_name in self.metric_results:
        self.metric_results[metric_name] /= idx + 1
    # Hand the averaged values to the logging helper.
    self._log_validation_metric_values(current_iter, dataset_name, tb_logger)

Guess you like

Origin blog.csdn.net/Vaifer233/article/details/122172474