GFPGAN源码分析—第十三篇

2021SC@SDUSC

源码:

models\gfpgan_model.py

本篇分析 models\gfpgan_model.py 中 class GFPGANModel(BaseModel) 类的最后几个方法。

目录

class GFPGANModel(BaseModel)

_log_validation_metric_values

get_current_visuals(self)

save(self, epoch, current_iter)


class GFPGANModel(BaseModel)

_log_validation_metric_values

def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
    """Report the accumulated validation metrics.

    The metrics are always written as one multi-line message to the root
    logger. They are also written as scalars when a TensorBoard-style logger
    is supplied.

    Args:
        current_iter (int): Training iteration; used as the scalar's step.
        dataset_name (str): Name of the validation dataset, shown in the header line.
        tb_logger: Object exposing ``add_scalar(tag, value, step)``.
            Presumably a TensorBoard ``SummaryWriter`` (confirm against caller).
            Any falsy value (e.g. ``None``) skips scalar logging.
    """
    results = self.metric_results
    # Header line followed by one tab-indented line per metric, 4 decimal places.
    pieces = [f'Validation {dataset_name}\n']
    pieces.extend(f'\t # {name}: {score:.4f}\n' for name, score in results.items())
    get_root_logger().info(''.join(pieces))

    if not tb_logger:
        return
    # Each metric becomes its own scalar series under the 'metrics/' tag prefix.
    for name, score in results.items():
        tb_logger.add_scalar(f'metrics/{name}', score, current_iter)

get_current_visuals(self)

def get_current_visuals(self):
    """Collect the current ground-truth and restored images for visualization.

    Returns:
        OrderedDict: Keys in insertion order.
            ``'gt'`` is included only when ``self.gt`` exists.
            ``'sr'`` (the restored output) is always included.
            Each value is ``tensor.detach().cpu()``: a copy cut off from the
            autograd graph and placed on the CPU.
    """
    # OrderedDict keeps the insertion order ('gt' before 'sr'), which callers may rely on.
    out_dict = OrderedDict()
    # Ground truth is not guaranteed to be attached (e.g. inference/validation
    # data without GT); reading it unconditionally would raise AttributeError.
    if hasattr(self, 'gt'):
        out_dict['gt'] = self.gt.detach().cpu()
    # detach() drops the autograd history; cpu() returns a CPU tensor.
    out_dict['sr'] = self.output.detach().cpu()
    return out_dict

save(self, epoch, current_iter)

网络保存

def save(self, epoch, current_iter):
    """Checkpoint all networks and then the training state.

    Save order:
      1. The generator and its EMA copy, passed together under the label
         'net_g' with param keys 'params' / 'params_ema'.
      2. The main discriminator ('net_d').
      3. If ``self.use_facial_disc`` is true, the left-eye, right-eye and
         mouth component discriminators.
      4. The training state, via ``save_training_state``.

    Args:
        epoch (int): Current epoch.
        current_iter (int): Current iteration; forwarded to every save call.
    """
    self.save_network(
        [self.net_g, self.net_g_ema],
        'net_g',
        current_iter,
        param_key=['params', 'params_ema'],
    )
    self.save_network(self.net_d, 'net_d', current_iter)

    if self.use_facial_disc:
        # Facial-component discriminators: the attribute name doubles as the save label.
        for component in ('left_eye', 'right_eye', 'mouth'):
            label = f'net_d_{component}'
            self.save_network(getattr(self, label), label, current_iter)

    # Saved last so that training can be resumed from this checkpoint.
    self.save_training_state(epoch, current_iter)

介绍一下save_network与save_training_state函数的几个参数

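save_network Args:
    net (nn.Module | list[nn.Module]): 要保存的网络,可以是单个网络或网络列表(Network(s) to be saved).
    net_label (str): 网络标签(Network label).
    current_iter (int): 当前迭代次数(Current iter number).
    param_key (str | list[str]): 保存网络参数时使用的键名,与 net 一一对应(The parameter key(s) to save network).
save_training_state Args:
    epoch (int): 当前轮次(Current epoch).
    current_iter (int): 当前迭代次数(Current iteration).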

Guess you like

Origin blog.csdn.net/Vaifer233/article/details/122172608