Suite de segmentación de imágenes Análisis completo de PaddleSeg (dos) Interpretación del código de configuración

El artículo anterior presentó la situación general de la suite de segmentación de imágenes PaddleSeg e introdujo el archivo de entrada de entrenamiento train.py. El archivo de configuración se analiza en el archivo train.py para obtener los parámetros de entrenamiento. Este artículo presenta principalmente cómo analizar archivos de configuración a través de la clase Config.

La clase Config se define en el archivo paddleseg / cvlibs / config.py. Guarda todos los hiperparámetros como la configuración del conjunto de datos, la configuración del modelo, la configuración de la red troncal, la configuración de la función de pérdida, etc.

En PaddleSeg, la configuración se guarda mediante archivos YAML. La ventaja de este método es que solo necesita modificar el YAML o crear un nuevo archivo YAML para crear una nueva tarea de entrenamiento.

La sintaxis de YAML es relativamente simple y la estructura del archivo es fácil de leer. Comencemos con el archivo de configuración de la red FCN más básica para la segmentación de imágenes para aprender a generar objetos de configuración a partir de archivos YAML.

Por ejemplo, observe el contenido del archivo dygraph / configs / fcn / fcn_hrnetw18_cityscapes_1024x512_80k.yml:

# _base_ 不是必须的,其作用更像基类。
# _base_指定的文件可以保存通用的配置,避免相同配置重复书写。若存在相同配置,会覆盖_base_指定yml文件的配置。
_base_: '../_base_/cityscapes.yml'

#模型信息
model:
  #模型的类型FCN
  type: FCN
  #使用的主干网络为HRNet 
  backbone:
    type: HRNet_W18
    #主干网络的预训练模型的下载地址。
    pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
  #模型分类数为19,可根据实际情况修改
  num_classes: 19
  #模型的预训练地址,这里为空
  pretrained: Null
  #这个是创建模型时需要传入的参数,该参数可以根据具体模型情况进行自定义设置,这个结合模型在具体讲解。
  backbone_indices: [-1]

#优化器设置,这里只设置了正则化的衰减系数,原因是因为在base里面已经设置了优化器的名称和学习率。
optimizer:
  weight_decay: 0.0005
#总迭代次数为80000次。
iters: 80000

Echemos un vistazo al contenido del archivo cityscape.yml a continuación:

#如果fcn的配置文件,配置了相同内容会覆盖本配置内容。
batch_size: 4
#迭代次数
iters: 80000
#训练集配置
train_dataset:
  #类型为Cityscapes,这里的type对应的值会在Config类中实例化具体的对象,所以名字要跟类名一致。
  #Citycapes类保存在dygraph/paddleseg/datasets/cityscapes.py文件中
  type: Cityscapes
  #指定数据集的根目录,这里没有指定具体的文件List,是因为list是在Cityscape类中生成的。
  dataset_root: data/cityscapes
  #数据增强操作
  transforms:
  #每一个type 则代表了一个数据增强操作对应的类名。下面的值则为创建对象需要传递的参数。
    - type: ResizeStepScaling
      min_scale_factor: 0.5
      max_scale_factor: 2.0
      scale_step_size: 0.25
    - type: RandomPaddingCrop
      crop_size: [1024, 512]
    - type: RandomHorizontalFlip
    - type: Normalize
  #模式为训练模式
  mode: train
#验证集配置
val_dataset:
  type: Cityscapes
  dataset_root: data/cityscapes
  transforms:
    - type: Normalize
  #模式为验证集模式
  mode: val

#优化器设置。
optimizer:
  #优化器为SGG
  type: sgd
  #动量
  momentum: 0.9
  #正则化
  weight_decay: 4.0e-5

#学习率设置
learning_rate:
  #学习率
  value: 0.01
  #学习率衰减策略
  decay:
    type: poly
    power: 0.9
    end_lr: 0.0
#损失函数设置
loss:
  types:
    #支持多种损失函数
    - type: CrossEntropyLoss
  #损失权重,若包含多个损失函数,可以在此处设置权重,权重数量需要与损失函数数量一致。
  coef: [1]

El contenido del archivo de configuración yml se presenta arriba, y la siguiente es una explicación de cómo la clase Config convierte el archivo yml en un objeto. El código de configuración es relativamente largo, lo siguiente intercepta métodos importantes para la interpretación.

El método de construcción de la clase Config:

  def __init__(self,
               path: str,
               learning_rate: float = None,
               batch_size: int = None,
               iters: int = None):
      #path为yml文件的路径,若果没有指定路径则抛出异常。
      if not path:
          raise ValueError('Please specify the configuration file path.')
      #还需要判断路径是否存在,如果不存在则抛出异常。
      if not os.path.exists(path):
          raise FileNotFoundError('File {} does not exist'.format(path))
      #初始化成员变量,模型对象和损失函数对象。
      self._model = None
      self._losses = None
      #判断配置文件类型是否为YAML。
      if path.endswith('yml') or path.endswith('yaml'):
          #如果文件类型正确,则通过_parse_from_yaml方法将文件内容保存到字典中。
          self.dic = self._parse_from_yaml(path)
      else:
          raise RuntimeError('Config file should in yaml format!')
      #更新配置中的learning_rate、batch_size和iters三个参数,这个三个参数是通过命令行传递过来的,
      #优先级高于yaml配置,会覆盖配置文件中的配置。
      self.update(
          learning_rate=learning_rate, batch_size=batch_size, iters=iters)

Echemos un vistazo al código fuente del método _parse_from_yaml encontrado en el constructor:

    def _parse_from_yaml(self, path: str):
        '''Parse a yaml file and build config'''
        #首先打开配置文件,通过yaml库中的load方法转换为字典。yaml为第三方库,可以同pip安装。具体使用方法参考yaml相关文档。
        with codecs.open(path, 'r', 'utf-8') as file:
            dic = yaml.load(file, Loader=yaml.FullLoader)
		#判断_base_是否在字典中,本次使用的FCN的配置文件是包含的也就是上面讲解的cityscape.yml文件。
        if '_base_' in dic:
            #同样获取cityscape.yml的路径然后通过本方法获取base配置的字典。
            cfg_dir = os.path.dirname(path)
            base_path = dic.pop('_base_')
            base_path = os.path.join(cfg_dir, base_path)
            #递归调用,因为cityscape.yml中并不包含_base_,所以下面的方法就不会执行到现在这部分代码。
            base_dic = self._parse_from_yaml(base_path)
            #更新dic字典中的内容。
            dic = self._update_dic(dic, base_dic)
        return dic

La siguiente es una explicación del método de actualización en el constructor. Este método es relativamente simple para actualizar la tasa de aprendizaje, el tamaño del lote y los iters.

    def update(self,
               learning_rate: float = None,
               batch_size: int = None,
               iters: int = None):
        '''Update config'''
        #如果learning_rate存在,更新字典中的值。
        if learning_rate:
            self.dic['learning_rate']['value'] = learning_rate
        #更新batch_size
        if batch_size:
            self.dic['batch_size'] = batch_size
        #更新iters。
        if iters:
            self.dic['iters'] = iters

Llame al método _update_dic en _parse_from_yaml para actualizar los parámetros del diccionario. Veamos la diferencia con la actualización anterior.

    def _update_dic(self, dic, base_dic):
        """
        Update config from dic based base_dic
        """
        #首先复制一个basc_dic
        base_dic = base_dic.copy()
        #遍历dic中的键值对。
        for key, val in dic.items():
        	#如果dic中的值的类型为字典,同时这个键在base_dic中存在,则需要使用base_dic中值进行更新。
            #递归调用本方法进行更新,直到val类型是基本类型。
            if isinstance(val, dict) and key in base_dic:
                base_dic[key] = self._update_dic(val, base_dic[key])
            #如果是基本类型则直接更新,上面递归到此处会停止,在下面return处直接返回。
            else:
                base_dic[key] = val
        dic = base_dic
        return dic

La clase Config también contiene muchos métodos anotados con @property, correspondientes a train_dataset, val_dataset, modelo, pérdida y otras configuraciones en el archivo de configuración yaml. Como se mencionó anteriormente, estas configuraciones contendrán una clave llamada type, y su valor correspondiente es el nombre de la clase. El método anotado con property creará el objeto con el nombre de la clase y devolverá el objeto al usuario. Aquí se usa el método de carga diferida, y se creará solo cuando se llame. Démosle un ejemplo de los atributos del modelo para explicar El flujo de trabajo de otros atributos es similar.

@property
  def model(self) -> paddle.nn.Layer:
      #从Config的配置字典中获取model的配置内容对应yaml文件中的部分如下:
      #model:
  	  #type: FCN
      #backbone:
      #		type: HRNet_W18
      #		pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
  	  #num_classes: 19
  	  #pretrained: Null
      #backbone_indices: [-1]
      
      model_cfg = self.dic.get('model').copy()
      #使用train_dataset配置中的类别数量覆盖model中的配置
      model_cfg['num_classes'] = self.train_dataset.num_classes
      #如果model_cfg 不存在则抛出异常
      if not model_cfg:
          raise RuntimeError('No model specified in the configuration file.')
      #在构造函数中_model配置为None,这里只创建一次模型对象。
      if not self._model:
          #创建模型对象。下面会继续解读_load_object方法。
          self._model = self._load_object(model_cfg)
      return self._model

Interpretación del método _load_object:

 def _load_object(self, cfg: dict) -> Any:
 		#拷贝一份配置,因为需要通过type的值创建对象,所以如果cfg中不包含type键则会抛出异常。
        cfg = cfg.copy()
        if 'type' not in cfg:
            raise RuntimeError('No object information in {}.'.format(cfg))
        #通过_load_component方法根据type的值获取类组件,这里的组件都是在定义各个类的时候通过
        #装饰器添加到manager维护的List中的,所以这里可以直接获取。至于如何加入list会在第3节接触到。
        component = self._load_component(cfg.pop('type'))
		#此处获取创建对象需要传递的参数,保存在params中。
        params = {}
        #遍历cfg中的键值对。
        for key, val in cfg.items():
            #这里使用_is_meta_type方法来判断val是字典同时也包含type值,如果包含的的话说明val对应的也是一个对象,
            #需要使用递归的方式获取到,直到参数类型为简单对象。
            if self._is_meta_type(val):
                params[key] = self._load_object(val)
            #如果参数是一个列表,则需要遍历列表中的内容,判断是否需要递归创建对象。
            elif isinstance(val, list):
                params[key] = [
                    self._load_object(item)
                    if self._is_meta_type(item) else item for item in val
                ]
            #遇到基本类型,保存参数。
            else:
                params[key] = val
		#遍历借宿创建对象。
        return component(**params)

La interpretación del código de la clase Config ahora está completa.

Dirección del almacén de PaddleSeg: https://github.com/PaddlePaddle/PaddleSeg

Supongo que te gusta

Origin blog.csdn.net/txyugood/article/details/111031176
Recomendado
Clasificación