A Comprehensive Analysis of the PaddleSeg Image Segmentation Suite (Part 2): Interpreting the Config Code

The previous article gave an overview of the PaddleSeg image segmentation suite and introduced the training entry point, train.py. train.py parses a configuration file to obtain the training parameters. This article explains how the Config class parses that configuration file.

The Config class is defined in paddleseg/cvlibs/config.py. It holds all the hyperparameters: the dataset configuration, model configuration, backbone network configuration, loss function configuration, and so on.

PaddleSeg stores its configuration in YAML files. The advantage is that creating a new training task only requires editing an existing YAML file or writing a new one.

YAML syntax is relatively simple, and the file structure is easy to read. Let's start with the configuration file of the most basic FCN segmentation network to see how a Config object is generated from a YAML file.

For example, take a look at the content of the dygraph/configs/fcn/fcn_hrnetw18_cityscapes_1024x512_80k.yml file:

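# _base_ is optional; it works much like a base class.
# The file named by _base_ can hold shared settings so they are not repeated. If the same key appears in this file, it overrides the value from the _base_ yml file.
_base_: '../_base_/cityscapes.yml'

# Model information
model:
  # The model type is FCN
  type: FCN
  # The backbone network is HRNet
  backbone:
    type: HRNet_W18
    # Download URL of the backbone's pretrained model.
    pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
  # The model has 19 classes; change this to suit your data
  num_classes: 19
  # Location of the pretrained weights for the whole model; empty here
  pretrained: Null
  # An argument passed when the model is created. Such arguments are model-specific and are explained together with the model.
  backbone_indices: [-1]

# Optimizer settings. Only the weight decay coefficient is set here, because the optimizer type and the learning rate are already set in the base file.
optimizer:
  weight_decay: 0.0005
# Total number of iterations: 80000.
iters: 80000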

Next, let's look at the contents of the cityscapes.yml file:

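# If the fcn configuration file sets the same keys, they override the values in this file.
batch_size: 4
# Number of iterations
iters: 80000
# Training set configuration
train_dataset:
  # The type is Cityscapes. The Config class instantiates an object from the value of type, so the name must match the class name.
  # The Cityscapes class is defined in dygraph/paddleseg/datasets/cityscapes.py
  type: Cityscapes
  # Root directory of the dataset. No file list is given here because the list is generated inside the Cityscapes class.
  dataset_root: data/cityscapes
  # Data augmentation operations
  transforms:
    # Each type is the class name of a data augmentation operation. The keys below it are the arguments passed when the object is created.
    - type: ResizeStepScaling
      min_scale_factor: 0.5
      max_scale_factor: 2.0
      scale_step_size: 0.25
    - type: RandomPaddingCrop
      crop_size: [1024, 512]
    - type: RandomHorizontalFlip
    - type: Normalize
  # Training mode
  mode: train
# Validation set configuration
val_dataset:
  type: Cityscapes
  dataset_root: data/cityscapes
  transforms:
    - type: Normalize
  # Validation mode
  mode: val

# Optimizer settings.
optimizer:
  # The optimizer is SGD
  type: sgd
  # Momentum
  momentum: 0.9
  # Regularization (weight decay)
  weight_decay: 4.0e-5

# Learning rate settings
learning_rate:
  # Initial learning rate
  value: 0.01
  # Learning rate decay policy
  decay:
    type: poly
    power: 0.9
    end_lr: 0.0
# Loss function settings
loss:
  types:
    # Multiple loss functions are supported
    - type: CrossEntropyLoss
  # Loss weights. With multiple loss functions, set their weights here; the number of weights must match the number of loss functions.
  coef: [1]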
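
The loss section pairs each entry in types with one weight in coef. The following is a minimal sketch of that pairing, not PaddleSeg's own code: combine_losses is a hypothetical helper, the loss values are plain floats, and treating the total as a weighted sum is an assumption made for illustration.

```python
def combine_losses(loss_values, coef):
    """Weighted sum of per-loss values (hypothetical helper, for illustration)."""
    # The config comment requires one weight per loss function.
    if len(loss_values) != len(coef):
        raise ValueError('Got {} losses but {} coefficients.'.format(
            len(loss_values), len(coef)))
    return sum(w * v for w, v in zip(coef, loss_values))


# One loss with weight 1, matching the shape of the loss section above.
single = combine_losses([0.5], [1])
# Two losses with different weights (made-up values).
double = combine_losses([0.5, 2.0], [1, 0.5])
```

A mismatch between the two lists is rejected up front rather than silently truncated by zip.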

The sections above covered the contents of the yml configuration files. Next we look at how the Config class turns a yml file into objects. The Config code is fairly long, so only the important methods are excerpted and explained below.

The constructor of the Config class:

  def __init__(self,
               path: str,
               learning_rate: float = None,
               batch_size: int = None,
               iters: int = None):
      # path is the path to the yml file; raise an exception if no path is given.
      if not path:
          raise ValueError('Please specify the configuration file path.')
      # Also check that the path exists; raise an exception if it does not.
      if not os.path.exists(path):
          raise FileNotFoundError('File {} does not exist'.format(path))
      # Initialize the member variables for the model object and the loss objects.
      self._model = None
      self._losses = None
      # Check that the configuration file is a YAML file.
      if path.endswith('yml') or path.endswith('yaml'):
          # If the file type is correct, _parse_from_yaml loads the file contents into a dictionary.
          self.dic = self._parse_from_yaml(path)
      else:
          raise RuntimeError('Config file should in yaml format!')
      # Update learning_rate, batch_size and iters. These three parameters come from the command line,
      # take priority over the yaml file, and override its values.
      self.update(
          learning_rate=learning_rate, batch_size=batch_size, iters=iters)

Let's take a look at the source code of the _parse_from_yaml method encountered in the constructor:

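    def _parse_from_yaml(self, path: str):
        '''Parse a yaml file and build config'''
        # Open the configuration file and convert it to a dictionary with load from the yaml library.
        # yaml is a third-party library (the PyYAML package) that can be installed with pip; see its documentation for details.
        with codecs.open(path, 'r', 'utf-8') as file:
            dic = yaml.load(file, Loader=yaml.FullLoader)
        # Check whether _base_ is in the dictionary. The FCN configuration file used here has it:
        # it points to the cityscapes.yml file discussed above.
        if '_base_' in dic:
            # Build the path of cityscapes.yml relative to this file, then load the base dictionary with this same method.
            cfg_dir = os.path.dirname(path)
            base_path = dic.pop('_base_')
            base_path = os.path.join(cfg_dir, base_path)
            # Recursive call. cityscapes.yml has no _base_ key, so the nested call skips this branch.
            base_dic = self._parse_from_yaml(base_path)
            # Merge dic on top of base_dic.
            dic = self._update_dic(dic, base_dic)
        return dic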
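
To make the path handling concrete, here is a small sketch of how the _base_ chain is followed. It is an illustration under assumptions, not PaddleSeg code: FILES is a hypothetical in-memory stand-in for YAML files that have already been parsed, base_chain is a made-up helper, and merging is left out so that only the path resolution and recursion are visible.

```python
import posixpath

# Hypothetical stand-in for files on disk, keyed by path.
FILES = {
    'configs/fcn/fcn_hrnetw18.yml': {
        '_base_': '../_base_/cityscapes.yml',
        'iters': 80000,
    },
    'configs/_base_/cityscapes.yml': {'batch_size': 4},
}


def base_chain(path):
    """Return the config paths visited, from the given file to its root base."""
    dic = FILES[path]
    chain = [path]
    if '_base_' in dic:
        # The _base_ value is relative to the directory of the current file.
        base_path = posixpath.join(posixpath.dirname(path), dic['_base_'])
        # normpath collapses '..' so the result matches a key in FILES;
        # a real filesystem resolves '..' on its own.
        chain += base_chain(posixpath.normpath(base_path))
    return chain


chain = base_chain('configs/fcn/fcn_hrnetw18.yml')
```

The recursion ends at the first file without a _base_ key, which is why a base file can itself name another base.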

Next is the update method called in the constructor. It is simple: it updates the learning rate, batch size, and iters.

    def update(self,
               learning_rate: float = None,
               batch_size: int = None,
               iters: int = None):
        '''Update config'''
        # If learning_rate was given, update the value in the dictionary.
        if learning_rate:
            self.dic['learning_rate']['value'] = learning_rate
        # Update batch_size.
        if batch_size:
            self.dic['batch_size'] = batch_size
        # Update iters.
        if iters:
            self.dic['iters'] = iters
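
The checks in update test truthiness, so an argument left at None is skipped, and so is a falsy value such as 0. The sketch below reproduces those checks on a plain dictionary; apply_overrides is a hypothetical name used only for this illustration.

```python
def apply_overrides(dic, learning_rate=None, batch_size=None, iters=None):
    """Apply the same truthiness checks as update(), on a plain dict."""
    if learning_rate:
        dic['learning_rate']['value'] = learning_rate
    if batch_size:
        dic['batch_size'] = batch_size
    if iters:
        dic['iters'] = iters
    return dic


cfg = {'learning_rate': {'value': 0.01}, 'batch_size': 4, 'iters': 80000}
# Only batch_size is passed, so only batch_size changes.
apply_overrides(cfg, batch_size=2)
# 0 is falsy, so it is treated the same as "not provided".
apply_overrides(cfg, learning_rate=0)
```

After both calls, batch_size is 2 while the learning rate and iters keep their YAML values.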

_parse_from_yaml calls the _update_dic method to merge the two dictionaries. Let's see how it differs from the update method above.

    def _update_dic(self, dic, base_dic):
        """
        Update config from dic based base_dic
        """
        # First make a copy of base_dic.
        base_dic = base_dic.copy()
        # Iterate over the key-value pairs in dic.
        for key, val in dic.items():
            # If the value in dic is a dictionary and the key also exists in base_dic,
            # merge the two by calling this method recursively until val is a basic type.
            if isinstance(val, dict) and key in base_dic:
                base_dic[key] = self._update_dic(val, base_dic[key])
            # A basic type (or a key missing from base_dic) is assigned directly.
            # The recursion stops here and returns at the return statement below.
            else:
                base_dic[key] = val
        dic = base_dic
        return dic
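
The effect of the recursive merge is easiest to see next to a flat dict.update. The sketch below is a standalone copy of the same logic under the hypothetical name merge, applied to the optimizer values from the two YAML files above.

```python
def merge(dic, base_dic):
    """Recursively overlay dic on a shallow copy of base_dic."""
    base_dic = base_dic.copy()
    for key, val in dic.items():
        if isinstance(val, dict) and key in base_dic:
            base_dic[key] = merge(val, base_dic[key])
        else:
            base_dic[key] = val
    return base_dic


base = {
    'batch_size': 4,
    'optimizer': {'type': 'sgd', 'momentum': 0.9, 'weight_decay': 4.0e-5},
}
child = {'optimizer': {'weight_decay': 0.0005}}

# Recursive merge: only weight_decay is replaced inside optimizer.
merged = merge(child, base)

# Flat update for comparison: the whole optimizer section is replaced.
flat = dict(base)
flat.update(child)
```

With the recursive merge, type and momentum from the base file survive; with the flat update they are lost, which is why a child file can set weight_decay alone.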

The Config class also contains many methods decorated with @property, corresponding to train_dataset, val_dataset, model, loss, and the other sections of the YAML configuration file. As mentioned earlier, these sections contain a key named type whose value is the name of a class. The property method creates an object from that class name and returns it to the caller. Loading is lazy: the object is created only when the property is first accessed. The model property serves as the example below; the other properties work in a similar way.

  @property
  def model(self) -> paddle.nn.Layer:
      # Get the model section from the Config dictionary. The matching part of the yaml file is:
      # model:
      #   type: FCN
      #   backbone:
      #     type: HRNet_W18
      #     pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
      #   num_classes: 19
      #   pretrained: Null
      #   backbone_indices: [-1]

      model_cfg = self.dic.get('model').copy()
      # Override num_classes in the model section with the class count from train_dataset.
      model_cfg['num_classes'] = self.train_dataset.num_classes
      # Raise an exception if model_cfg is empty.
      # Note: a missing model section would already fail at .copy() above, before reaching this check.
      if not model_cfg:
          raise RuntimeError('No model specified in the configuration file.')
      # _model is set to None in the constructor, so the model object is created only once.
      if not self._model:
          # Create the model object. The _load_object method is explained next.
          self._model = self._load_object(model_cfg)
      return self._model
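
The lazy-loading pattern can be shown in isolation. In this sketch LazyConfig is a hypothetical stand-in for Config, object() stands in for an expensive network, and build_count exists only to make the behavior observable. The sketch checks `is None` rather than truthiness; that is a choice made here, not taken from the source.

```python
class LazyConfig:
    """Hypothetical stand-in for Config that counts how often the model is built."""

    def __init__(self):
        self._model = None
        self.build_count = 0

    @property
    def model(self):
        # Build on first access only; later accesses return the cached object.
        if self._model is None:
            self.build_count += 1
            self._model = object()  # placeholder for an expensive network
        return self._model


cfg = LazyConfig()
built_before_access = cfg.build_count
first = cfg.model
second = cfg.model
```

Nothing is built when the config is constructed, and the two accesses return the same object after a single build.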

Interpretation of the _load_object method:

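    def _load_object(self, cfg: dict) -> Any:
        # Copy the configuration. The object is created from the value of type,
        # so an exception is raised if cfg has no type key.
        cfg = cfg.copy()
        if 'type' not in cfg:
            raise RuntimeError('No object information in {}.'.format(cfg))
        # _load_component looks up the class by the value of type. The classes were added to the
        # lists maintained by the manager through decorators when they were defined, so they can be
        # fetched directly here. How they are added to the lists is covered in part 3.
        component = self._load_component(cfg.pop('type'))
        # Collect the arguments needed to create the object in params.
        params = {}
        # Iterate over the key-value pairs in cfg.
        for key, val in cfg.items():
            # _is_meta_type checks whether val is a dictionary that contains a type key. If so, val
            # also describes an object, which is built recursively until the arguments are simple values.
            if self._is_meta_type(val):
                params[key] = self._load_object(val)
            # If the argument is a list, check each item to see whether it must be built recursively.
            elif isinstance(val, list):
                params[key] = [
                    self._load_object(item)
                    if self._is_meta_type(item) else item for item in val
                ]
            # A basic type is stored as an argument as-is.
            else:
                params[key] = val
        # After the loop finishes, create the object.
        return component(**params)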
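
The whole mechanism, a registry filled by a decorator plus recursive construction from nested dictionaries, fits in a short standalone sketch. Everything named here is hypothetical and only illustrates the idea: REGISTRY and register stand in for PaddleSeg's manager, Net and Backbone are toy classes, and is_meta_type follows the description in the comments above rather than the actual source.

```python
REGISTRY = {}


def register(cls):
    """Class decorator that records a class under its own name."""
    REGISTRY[cls.__name__] = cls
    return cls


@register
class Backbone:
    def __init__(self, width):
        self.width = width


@register
class Net:
    def __init__(self, backbone, num_classes, backbone_indices):
        self.backbone = backbone
        self.num_classes = num_classes
        self.backbone_indices = backbone_indices


def is_meta_type(item):
    # A dict with a 'type' key describes an object to be built.
    return isinstance(item, dict) and 'type' in item


def load_object(cfg):
    cfg = cfg.copy()
    if 'type' not in cfg:
        raise RuntimeError('No object information in {}.'.format(cfg))
    component = REGISTRY[cfg.pop('type')]
    params = {}
    for key, val in cfg.items():
        if is_meta_type(val):
            params[key] = load_object(val)
        elif isinstance(val, list):
            params[key] = [
                load_object(item) if is_meta_type(item) else item
                for item in val
            ]
        else:
            params[key] = val
    return component(**params)


net = load_object({
    'type': 'Net',
    'backbone': {'type': 'Backbone', 'width': 18},
    'num_classes': 19,
    'backbone_indices': [-1],
})
```

The nested backbone dictionary becomes a Backbone instance before Net is constructed, while the plain list [-1] is passed through unchanged.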

The interpretation of the Config class code is now complete.

PaddleSeg repository: https://github.com/PaddlePaddle/PaddleSeg

Origin blog.csdn.net/txyugood/article/details/111031176