网络训练时参数解析的三种方式(yaml、argparser、class属性)

一个模型中包含众多的训练参数,如文件保存目录、数据集目录、学习率、epoch数量、模块中的参数等。 参数解析常用的有yaml文件解析还有argparser解析。

但也不是绝对的,我也有看见别人喜欢直接全局定义方式去写,当然在本章我会讲解一种更加简单的参数解析的方式,这也是我个人喜欢的方式。

yaml文件解析

将所有参数都放在yaml文件中,通过读取yaml文件来配置参数。这种常见于比较复杂的项目,例如有多个模型,对应多组参数。这样就可以每个模型配置一个yaml文件,里面对应的是每个模型的对应的参数。

pip install pyyaml

这里我随意写了一个yaml文件

# config.yaml

# Directory to save files
save_dir: 'path/to/save/directory'

# Dataset directory
dataset_dir: 'path/to/dataset'

# Training parameters
learning_rate: 0.001
epochs: 10

# Model parameters
model_params:
  hidden_size: 128
  num_layers: 3
  dropout: 0.2

 使用pyyaml进行解析

import yaml

with open('config.yaml', 'r') as file:
    config = yaml.safe_load(file)

print("Save Directory:", config['save_dir'])
print("Dataset Directory:", config['dataset_dir'])
print("Learning Rate:", config['learning_rate'])
print("Epochs:", config['epochs'])
print("Model Parameters:")
print("  Hidden Size:", config['model_params']['hidden_size'])
print("  Num Layers:", config['model_params']['num_layers'])
print("  Dropout:", config['model_params']['dropout'])

这里的打印输出结果为:


Save Directory: path/to/save/directory
Dataset Directory: path/to/dataset
Learning Rate: 0.001
Epochs: 10
Model Parameters:
  Hidden Size: 128
  Num Layers: 3
  Dropout: 0.2

 
argparser解析

argparser解析的形式一般放在train.py文件的最前面,适用于参数相对比较少,每次只需要改一 两个参数的情况。

import argparse

def parse_arguments():
    parser = argparse.ArgumentParser(description='Your Description Here')

    parser.add_argument('--output_dir', default='output', help='Directory to save results')
    parser.add_argument('--data_dir', default='data', help='Directory where dataset is stored')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate for training')
    parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')

    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parse_arguments()

    print("Loaded configuration:")
    print(vars(args))

Loaded configuration:
{'output_dir': 'output', 'data_dir': 'data', 'learning_rate': 0.001, 'epochs': 10}

argparse其还支持命令行运行,这在很多项目中都会用到这种方式。

类定义解析

class parser_args():
    def __init__(self):
        self.net = "vgg16"
        self.Cuda = True
        self.EPOCHS = 100
        self.batch_size = 4
        self.warm = 1

    def _help(self):
        stc = {
            "log_dir": "存放训练模型.pth的路径",
            "Cuda": "是否使用Cuda,如果没有GPU,可以使用CUP,i.e: Cuda=False",
            "EPOCHS": "训练的轮次,这里默认就跑100轮",
            "batch_size": "批量大小,一般为1,2,4",
            "warm": "控制学习率的'热身'或'预热'过程"
        }

实例化类对象:args = parser_args()

这种方式相对简洁,易于理解,并且在实例化类对象后可以直接访问这些参数。这种方式在参数不太频繁变动的情况下比较的合适。

总结

总体来说,选择使用哪种方式通常取决于项目的需求和个人偏好。YAML 适用于配置文件,argparse 适用于命令行参数,而类属性适用于更结构化的面向对象设计。

猜你喜欢

转载自blog.csdn.net/m0_62919535/article/details/134362040
今日推荐