2023年11月25日~12月1日周报(继续调试OpenFWI代码)

目录

一、前言

二、学习情况

2.1 train.py的理解

2.11 定义数据集

2.12 定义损失函数、优化器

2.13 加载模型

2.14 开始训练

2.2 test.py的理解

2.21 定义测试集

2.22 加载模型

2.23  开始test

三、遇到的部分问题及解决

3.1 module 'torchvision' has no attribute '__version__'

3.2 Python之__call__的理解

四、相关参考

4.1 关键字global的用法

4.2 python 创建文件夹之 mkdir() 和makedirs()

4.3 os.environ模块环境变量解释

4.4 Tensorboard的使用 ---- SummaryWriter类

4.5 torch.backends.cudnn.benchmark = true的作用

4.6 MinMaxNormalize 规一化算法

4.7 符号函数np.sign()的介绍及用法

4.8 numpy库ndarray多维数组的运算:np.abs()

4.9 数据平滑处理之np.log1p的介绍 

4.10 DistributedSamper()的介绍

4.11 RandomSampler()的理解

4.12 argparse模块介绍

4.13 os.path.join()函数介绍

4.14 Compose()函数

五、总结

5.1 存在的疑惑

5.2 下周安排


一、前言

        上周抄写了OpenFWI代码中的InversionNet网络部分,本周继续完成剩下部分的抄写任务,并初步阅读OpenFWI的论文。

二、学习情况

2.1 train.py的理解

2.11 定义数据集

  • 创建文件保存的路径、初始化分布式模式:
utils.mkdir(args.output_path)
utils.init_distributed_mode(args)
  • 确定训练设备:
device = torch.device(args.device)
  • 确定数据与标签的归一化方式:
# Normalize data and label to [-1, 1]   将数据归一化在[-1, 1]
transform_data = Compose([
    # 数据平滑处理
    T.LogTransform(k=args.k),
    # 归一化让不同维度之间的特征在数值上有一定比较性,可以大大提高分类器的准确性。
    T.MinMaxNormalize(T.log_transform(ctx['data_min'], k=args.k), T.log_transform(ctx['data_max'], k=args.k))
])
transform_label = Compose([
    T.MinMaxNormalize(ctx['label_min'], ctx['label_max'])
])
  •  初始化训练集与验证集:
# 判断文件类型是否为txt
if args.train_anno[-3:] == 'txt':
    # 数据加载
    dataset_train = FWIDataset(
        args.train_anno, # 文件的路径
        preload=True, # 是否将整个数据集加载到内存中
        sample_ratio=args.sample_temporal, # 地震数据的下采样率
        file_size=ctx['file_size'], # 每个npy文件中的样本数
        transform_data=transform_data, # 数据转换
        transform_label=transform_label # 标签转换
    )
else:
    dataset_train = torch.load(args.train_anno)

print('Loading validation data')
if args.val_anno[-3:] == 'txt':
    dataset_valid = FWIDataset(
        args.val_anno,
        preload=True,
        sample_ratio=args.sample_temporal,
        file_size=ctx['file_size'],
        transform_data=transform_data,
        transform_label=transform_label
    )
else:
    dataset_valid = torch.load(args.val_anno)
  •  加载数据集:
if args.distributed:
    train_sampler = DistributedSampler(dataset_train, shuffle=True) # 分布式网络训练
    valid_sampler = DistributedSampler(dataset_valid, shuffle=True)
else:
    train_sampler = RandomSampler(dataset_train) # RandomSampler()表示随机对数据样本进行采样,返回的是DataSet中的索引位置(indices)
    valid_sampler = RandomSampler(dataset_valid)

# 读取数据,将train_sampler与valid_sampler加载到网络模型中进行训练
dataloader_train = DataLoader(
    dataset_train, # 传入的数据集
    batch_size=args.batch_size, # 每个batch的样本数
    sampler=train_sampler, # 自定义从数据集中取走样本的策略,若指定这个参数,shuffle必须为false
    # num_workers=args.workers,
    num_workers = 0, # 设置进程数
    pin_memory=True, # 设置为True表示dalaloader在返回它们之前,会将tensors拷贝到CUDA中的固定内存中
    drop_last=True, # 设置为True表示丢弃最后一批样本
    collate_fn=default_collate # 将一个list的sample组成一个mini-batch的函数
)

dataloader_valid = DataLoader(
    dataset_valid, batch_size=args.batch_size,
    sampler=valid_sampler,
    # num_workers=args.workers,
    num_workers=0,
    pin_memory=True, collate_fn=default_collate)

2.12 定义损失函数、优化器

  • 损失函数:
l1loss = nn.L1Loss() # 平均绝对误差(MAE)
l2loss = nn.MSELoss() # 均方误差(MSE)
  • 学习率:
lr = args.lr * args.world_size
warmup_iters = args.lr_warmup_epochs * len(dataloader_train)
lr_milestones = [len(dataloader_train) * m for m in args.lr_milestones]
# 确定学习率调整策略(学习率预热)
# (warmup策略-开始以很小的学习率进行训练,使得网络熟悉数据,随着训练的进行学习率慢慢变大,到达一定程度,以设置的初始学习率进行训练,通过一些inter后,学习率再慢慢减少)
# 学习率变化:上升—平稳—下降
lr_scheduler = WarmupMultiStepLR(
    optimizer, milestones=lr_milestones, gamma=args.lr_gamma,
    warmup_iters=warmup_iters, warmup_factor=1e-5)
  • 优化器:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=args.weight_decay)

2.13 加载模型

model_without_ddp = model
# 判断是否采用分布式训练
if args.distributed:
    # DistributedDataParallel—DDP分布式多卡训练(可实现单机多卡、多机多卡)
    model = DistributedDataParallel(model, device_ids=[args.local_rank])
    model_without_ddp = model.module

# 判断是否加载预训练模型
if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    model_without_ddp.load_state_dict(network.replace_legacy(checkpoint['model']))
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    args.start_epoch = checkpoint['epoch'] + 1
    step = checkpoint['step']
    lr_scheduler.milestones = lr_milestones

2.14 开始训练

  • 训练一个批次:
# 在该函数中会进行损失函数的计算、反向传播、更新模型、将L1、L1以及混合损失记录在Tensorboard的writer中等
train_one_epoch(model, criterion, optimizer, lr_scheduler, dataloader_train, device, epoch, args.print_freq, train_writer)

# 如下所示:
ptimizer.zero_grad()
data, label = data.to(device), label.to(device)
output = model(data)
loss, loss_g1v, loss_g2v = criterion(output, label)
loss.backward()
optimizer.step()
  • 评价模型:
loss = evaluate(model, criterion, dataloader_valid, device, val_writer)
  • 保存训练好的模型:
# 文件中保存的内容:建立字典
# 注意:若模型是由nn.Model类继承的模型,保存pth文件时,state_dict参数需要由model.state_dict指定
checkpoint = {
    'model': model_without_ddp.state_dict(),
    'optimizer': optimizer.state_dict(),
    'lr_scheduler': lr_scheduler.state_dict(),
    'epoch': epoch,
    'step': step,
    'args': args}

# Save checkpoint per epoch
# pth文件通过有序字典来保持模型参数
# 当想要回复某一阶段的训练(或者进行测试)时,可以读取之前保存的网络模型参数
if loss < best_loss:
    utils.save_on_master(
        checkpoint,
        os.path.join(args.output_path, 'checkpoint.pth'))
    print('saving checkpoint at epoch: ', epoch)
    chp = epoch
    best_loss = loss
# Save checkpoint every epoch block
print('current best loss: ', best_loss)
print('current best epoch: ', chp)
if args.output_path and (epoch + 1) % args.epoch_block == 0:
    utils.save_on_master(
        checkpoint,
        os.path.join(args.output_path, 'model_{}.pth'.format(epoch + 1)))

注意:常规dict是无序的,OrderedDict能够比dict更好地处理频繁的重新排序操作。 

2.2 test.py的理解

2.21 定义测试集

  • 创建文件保存的路径、确定训练设备、设置优化运行效率
utils.mkdir(args.output_path) # 创建相对路径的文件夹
device = torch.device(args.device) # 确定训练设备
torch.backends.cudnn.benchmark = True # 优化运行效率
  •  确定数据与标签的预处理方式
transform_valid_data = Compose([
    T.LogTransform(k=args.k),
    T.MinMaxNormalize(log_data_min, log_data_max),
])

transform_valid_label = Compose([
    T.MinMaxNormalize(ctx['label_min'], ctx['label_max'])
])
  • 初始化测试集
if args.val_anno[-3:] == 'txt':
    dataset_valid = FWIDataset(
        args.val_anno,
        sample_ratio=args.sample_temporal,
        file_size=ctx['file_size'],
        transform_data=transform_valid_data,
        transform_label=transform_valid_label
    )
else:
    dataset_valid = torch.load(args.val_anno)
  • 加载数据集
valid_sampler = SequentialSampler(dataset_valid)
dataloader_valid = torch.utils.data.DataLoader(
    dataset_valid, batch_size=args.batch_size,
    sampler=valid_sampler, num_workers=args.workers,
    pin_memory=True, collate_fn=default_collate
)

2.22 加载模型

model = network.model_dict[args.model](upsample_mode=args.up_mode,
                                       sample_spatial=args.sample_spatial,
                                       sample_temporal=args.sample_temporal,
                                       norm=args.norm).to(device)

2.23  开始test

evaluate(model, criterions, dataloader_valid, device, args.k, ctx,
         vis_path, args.vis_batch, args.vis_sample, args.missing, args.std)

三、遇到的部分问题及解决

3.1 module 'torchvision' has no attribute '__version__'

        问题描述:module 'torchvision' has no attribute '__version__';这个问题是由于torchvision库的版本问题引起的,在较旧的版本中,可能没有‘version’这个属性

        参考:AttributeError: module 'torchvision' has no attribute 'version' - CSDN文库 

        解决方式:升级torchvision库到最新版本。命令如下:

pip install --upgrade trochvision -i https://pypi.tuna.tsinghua.edu.cn/simple

3.2 Python之__call__的理解

       在调试代码的时候对类的调用理解不够到位,在学习的时候困扰较多。

transform_data = Compose([
    # 数据平滑处理
    T.LogTransform(k=args.k),
    # 数据归一化
    # 创建MinMaxNormalize对象
    T.MinMaxNormalize(T.log_transform(ctx['data_min'], k=args.k), T.log_transform(ctx['data_max'], k=args.k))
])

# 数据加载
dataset_train = FWIDataset(
        args.train_anno,
        preload=True,
        sample_ratio=args.sample_temporal,
        file_size=ctx['file_size'],
        transform_data=transform_data,
        transform_label=transform_label
)

# 读取数据
dataloader_train = DataLoader(
    dataset_train, # 传入的数据集
    batch_size=args.batch_size, # 每个batch的样本数
    sampler=train_sampler, # 自定义从数据集中取走样本的策略,若指定这个参数,shuffle必须为false
    num_workers=args.workers, # 设置进程数
    pin_memory=True, # 设置为True表示dalaloader在返回它们之前,会将tensors拷贝到CUDA中的固定内存中
    drop_last=True, # 设置为True表示丢弃最后一批样本
    collate_fn=default_collate # 将一个list的sample组成一个mini-batch的函数
)

# 开始训练
train_one_epoch(model, criterion, optimizer, lr_scheduler, dataloader_train, device, epoch, args.print_freq, train_writer)

# 训练一个轮次的模型——完成数据转换
def train_one_epoch(model, criterion, optimizer, lr_scheduler, dataloader, device, epoch, print_freq, writer):

    for data, label in metric_logger.log_every(dataloader, print_freq, header):

def log_every(self, iterable, print_freq, header=None):
    for obj in iterable:

        参考:python特殊函数__call__(self) - 知乎 (zhihu.com)

四、相关参考

4.1 关键字global的用法

        global是Python中全局变量的关键字。在定义函数时,若需要在函数内部对函数外部的变量进行操作,需要在函数内部将函数外部的变量声明为global变量。

        在add()函数中,没有在a前加global,因此add()函数无法将a赋值为3,无法对a的值进行修改,a在函数外部的值没有进行改变。

注意:

  1. 变量分为全局变量(可以由对象或函数创建,也可以在程序任何地方创建,创建成功后可以被本程序内所有对象或函数引用)与局部变量(内部变量-由某个对象或函数创建,只能被内部引用,无法被其它对象或函数引用);
  2. global需要再函数内部声明,可以使用同一个global语句指定多个全局变量;
  3. 全局变量无法使用局部变量,只有对应局部作用域有效。
  4. 参考:python 中关键字 global 的用法_python global_ Marks的博客-CSDN博客
a = 1
b = 2

def add():
    a = 3
    global b
    b = 4
    print("② a + b =", a, "+", b,  "=", a + b)

print("① a + b =", a, "+", b,  "=", a + b)
add()
print("③ a + b =", a, "+", b,  "=", a + b)

4.2 python 创建文件夹之 mkdir() 和makedirs()

        mkdir()命令用于创建一级目录makedirs()命令用于创建多级目录

        参考:python 创建文件夹之 mkdir() 和makedirs()_mkdir python_薰珞婷紫小亭子的博客-CSDN博客

        代码与目录创建情况截图如下:

import os

output_path1 = "output1/"  # 模型保存地址
os.mkdir(output_path1)  # only create one folder
output_path2 = "output2/output2"  # 模型保存地址
os.makedirs(output_path2)  # create more then one folder

4.3 os.environ模块环境变量解释

        os.environ是一个环境变量的字典,环境变量是程序与操作系统之间的通信方式。在Python中通过os.environ可以获取到有关系统的各种信息。

        参考:os.environ模块环境变量详解-CSDN博客

import os
print(os.environ.keys()) # 打印os.environ.keys()主目录下所有的key
print(os.environ.get("HOME")) # 获取环境变量,若有这个键,返回对应的值;反之,返回none
print(os.environ.get("HOME", "default"))	#环境变量HOME不存在,返回	default

# 设置系统环境变量
os.environ['环境变量名称']='环境变量值' #其中key和value均为string类型
os.putenv('环境变量名称', '环境变量值')
os.environ.setdefault('环境变量名称', '环境变量值')

# 修改系统环境变量
os.environ['环境变量名称']='新环境变量值'

# 获取系统环境变量
os.environ['环境变量名称']
os.getenv('环境变量名称')
os.environ.get('环境变量名称', '默认值')	#默认值可给可不给,环境变量不存在返回默认值

# 删除系统环境变量
del os.environ['环境变量名称']
del(os.environ['环境变量名称'])

# 判断系统环境变量是否存在
'环境变量值' in os.environ   # 存在返回 True,不存在返回 False

4.4 Tensorboard的使用 ---- SummaryWriter类

       SummaryWriter:在给定目录中创建事件文件,并向其中添加摘要和事件。 该类异步更新文件内容,这允许训练程序调用方法以直接从训练循环将数据添加到文件中,而不会减慢训练速度。

from torch.utils.tensorboard import SummaryWriter

        参考:Tensorboard的使用 ---- SummaryWriter类(pytorch版)_chuanauc的博客-CSDN博客

        下图表示使用SummaryWriter类构建了两个实例train_writer与val_trainwriter,若都为true,则运行完后会出现对应的两个文件夹,其中包含可以被tensorboard所解释的文件。

         可视化展示:在Pycharm的terminal中,键入指令tensorboard –logdir=XXXX(其中,XXXX是指文件写入的地方)

4.5 torch.backends.cudnn.benchmark = true的作用

        在很多情况下,代码中有这样一行:

torch.backends.cudnn.benchmark = true

        上述代码可以提升训练速度,程序在开始的时候会花费一点额外时间,为整个网络的每个卷积层搜索最适合它的卷积实现算法,进而实现网络的加速。但是在计算中具有随机性,每次网络的结果可能存在一点点差异。

        需要注意的是:

  1. 若网络的输入数据维度或类型上变化不大,上述设置可以增加运行效率;
  2. 若网络的输入数据在每次迭代都变化的话,会导致cnDNN每次都去寻找一遍最优配置,反而会降低运行效率;
  3. 因此,如果cuDNN使用非确定性算法,可以通过torch.backends.cudnn.enabled = False来禁用。

        参考:torch.backends.cudnn.benchmark = true的作用-CSDN博客

4.6 MinMaxNormalize 规一化算法

        数据归一化是指将数据按照一定比例缩放(转换),使之落入一个小的特定的区间,转换为类似于(0,1)或者(-1,1)之间的小数,将有量纲的表达式转换为无量纲的表达式。在多指标评价体系中,不同的评价指标性质(量纲、数量级等)不同,当各指标间的水平差距很大时,若直接使用原始数据进行分析,会突出数值较高的指标在综合性分析中的作用,相对削弱值水平低指标的作用。原始数据经过数据标准化处理后,各指标处于同一数量级,适合进行综合对比分析。

         参考:如何理解归一化(normalization)? - 知乎 (zhihu.com)

         本文采用的公式如下:

vid = (\frac{vid-vmin}{vmax-vmin}-0.5)*2

其中vmin表示地震数据/速度模型的最小值,vmax表示地震数据/速度模型的最大值。

4.7 符号函数np.sign()的介绍及用法

        np.sign()是Python的Numpy中的取数字符号(数字前的正负号)的函数:

        参考:【Python】Numpy库之符号函数sign()的介绍及用法_python sign-CSDN博客 

sign(x)=\left\{\begin{matrix} 1,x>0\\ 0,x=0\\ -1,x<0\end{matrix}\right.

import numpy as np # 导入numpy库

data = [-0.8, -1.1, 0, 2.3, 4.5]
print("输入数据为:", data)

# 使用numpy的sign(x)函数求输入数据的符号
signResult = np.sign(data)
print("使用sign函数的输出符号为:", signResult)

4.8 numpy库ndarray多维数组的运算:np.abs()

        np.abs()表示计算数组中各元素的绝对值(多个元素并行处理),返回的类型为ndarray。

        参考:Python numpy.abs和abs函数别再傻傻分不清了_小熊爱喝牛奶的博客-CSDN博客

        需要注意的是:np.abs()与abs()需要进行区分,abs()适用于处理单个元素,返回的类型是int。 

4.9 数据平滑处理之np.log1p的介绍 

        在进行数据预处理的时候,可以首先对偏度较大的数据使用log1p函数进行转换,使其更加服从高斯分布,更加利于后续得到更好的分类结果。log1p将数据压缩到一个区间内,与数据的标准化类似。

        注意:由于使用log1p对数据进行了压缩,最预测出来的平滑数据需要进行一个还原,使用log1p的逆运算expm1。

log1p=log(x+1),即ln(x+1)

expm1=exp(x)-1

        参考:数据平滑处理——log1p()和exmp1()-CSDN博客

4.10 DistributedSamper()的介绍

         DistributedSampler保证测试数据集中加载固定的顺序,DistributedSampler()位于torch.utils.data,该类通常用于分布式单机多卡(或多机多卡)的审计网络训练。在使用时,首先需要初始化DistributedSampler,将该对象作为参数传入rotch.utils.data.DataLoader()的sampler参数中。此时DataLoader具备分布式采样的能力,以单机多卡为例,若当前环境中有N张显卡,则整个数据集会被分割为N份,每张显卡会获得属于自己的那一份数据。一个epoch数据总和/num_gpu=每个GPU得到的数目;一个epoch每个GPU迭代的次数=得到的数据/batch_size。如下:

from torch.utils.data import DistributedSampler,DataLoader
from torchvision import datasets

dataset = datasets.ImageFolder(data_path, transform)
sampler = DistributedSampler(dataset)
loader = DataLoader(
                    dataset = dataset,
                    sampler = sampler,
                   )

         DistributedSampler的构造函数如下所示:

def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

其中:

  1. dataset:类型为torch.utils.data.Dataset,是该采样器需要处理的对象;
  2. num_replicas:将数据集划分为几块,默认为None,在后续代码中会进行判断;
  3. rank:表示此sampler要处理的环境的rank号,在单机多卡环境下就是第几张显卡,默认为None,在后续代码中会进行判断;

  4. shuffle:是否要打乱数据的顺序;

  5. seed:随机数种子,用于打乱顺序;

  6. drop_last:是否丢弃最后一组数据;

        注意:DistributedSampler中保证参数shuffle=False,训练集需要保证shuffle=True(默认参数是True)。在DataLoader中需要保证测试数据集和训练数据集都是shuffle=False(参数默认是False),因为有了sampler进行数据采样,如果shuffle=True会与sampler进行采样冲突,出现报错。如果不是DDP,则需要保证训练数据集的dataloader中shuffle参数是True,测试数据集的dataloader中shuffle参数是False。

         参考(待深入理解)Pytorch - DistributedSampler 源码分析(一) - 知乎 (zhihu.com)

4.11 RandomSampler()的理解

        RandomSampler()表示对数据进行随机采样。

torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)

 其中:

  1. data_source:表示被采样的数据集合;
  2. replacement:采样策略,若为True,则代表使用替换采样策略,即可重复对一个样本进行采样;若为False,则表示不用替换采样策略,即一个样本最多只能被采一次;
  3. num_samples:所采样本的数量,默认采全部样本;当replacement规定为True时,可指定采样数量,即修改num_samples的大小;如果replacement设置为False,则该参数不可做修改;
  4. generator:采样过程中的生成器;

        参考:PyTorch学习笔记:data.RandomSampler——数据随机采样_pytorch随机抽样_视觉萌新、的博客-CSDN博客

4.12 argparse模块介绍

        argparse是一个Python模块:命令行选项、参数和子命令解析器。

        使用流程:

        ①创建解析器:ArgumentParser对象将命令行解析成Python数据类型所需的全部信息;

parser = argparse.ArgumentParser(description='Process some integers.')
class argparse.ArgumentParser(prog=None, usage=None, description=None, epilog=None, parents=[], formatter_class=argparse.HelpFormatter, prefix_chars='-', fromfile_prefix_chars=None, argument_default=None, conflict_handler='error', add_help=True, allow_abbrev=True)

 其中:

  • prog:表示程序的名称(默认:sys.argv[0]);
  • usage:表示描述程序用途的字符串(默认值:从添加到解析器的参数生成;
  • description:表示在参数帮助文档之前显示的文本(默认值:无),描述这个程序做什么以及怎么做,显示在命令行用法字符串和各种参数的帮助消息之间;
  • epilog:表示在参数帮助文档之后显示的文本(默认值:无);
  • parents:表示一个 ArgumentParser 对象的列表,它们的参数也应包含在内;
  • formatter_class:表示用于自定义帮助文档输出格式的类;
  • prefix_chars:表示可选参数的前缀字符集合(默认值:’-’);
  • fromfile_prefix_chars:当需要从文件中读取其他参数时,用于标识文件名的前缀字符集合(默认值:None);
  • argument_default:表示参数的全局默认值(默认值: None);
  • conflict_handler:表示解决冲突选项的策略(通常是不必要的);
  • add_help:表示为解析器添加一个 -h/--help 选项(默认值: True);
  • allow_abbrev:表示如果缩写是无歧义的,则允许缩写长选项 (默认值:True)

        ②添加参数:调用add_argument()方法向一个ArgumentParser中添加程序的参数信息;

# 相当于增加integers属性,后续可以打印args.integers中的内容
parser.add_argument('integers', metavar='N', type=int, nargs='+', help='an integer for the accumulator')
ArgumentParser.add_argument(name or flags...[, action][, nargs][, const][, default][, type][, choices][, required][, help][, metavar][, dest])

 其中:

  • name or flags:一个命名或者一个选项字符串的列表,例如 foo 或 -f, --foo(也可以对参数名进行简写;给属性名之前加上“- -”,就能将之变为可选参数);
  • action:当参数在命令行中出现时使用的动作基本类型;
  • nargs:命令行参数应当消耗的数目;
  • const:被一些 action 和 nargs 选择所需求的常数;
  • default:当参数未在命令行中出现时使用的值;
  • type:命令行参数应当被转换成的类型;
  • choices:可用的参数的容器;
  • required:此命令行选项是否可省略 (仅选项可用);
  • help:一个此选项作用的简单描述;
  • metavar:在使用方法消息中使用的参数值示例;
  • dest:被添加到 parse_args() 所返回对象上的属性名;

        ③将参数传给args实例:把parser中设置的所有"add_argument"给返回到args子类实例当中, 那么parser中增加的属性内容都会在args实例中,使用即可。

args = parser.parse_args()

        ④解析参数:ArgumentParser通过parse_args()方法解析参数;

>>> parser.parse_args(['--sum', '7', '-1', '42'])
Namespace(accumulate=<built-in function sum>, integers=[7, -1, 42])

        参考:①Python 讲堂 parse_args()详解_parser.parse_args-CSDN博客;② argparse.ArgumentParser()用法解析-CSDN博客;③python之parser.add_argument()用法——命令行选项、参数和子命令解析器-CSDN博客

4.13 os.path.join()函数介绍

        os.path.join()函数用于路径拼接文件的路径,可以传入多个路径。

  • 若不存在以“ / ”开始的参数,函数会自动加上;
  • 若存在以“ / ”开始的参数,从最后一个出现以“ / ”开头的参数开始拼接,之前的参数都会被丢弃;
  • 若同时存在以“ \ ”和“ / ”开始的参数,以“ / ”为主,从最后一个以“ / ”开头的参数开始拼接,之前的参数都会被丢弃;
  • 若只存在以“ ./ ”开始的参数,会从“ ./ ”开头的参数的上一个参数开始拼接;

        参考:os.path.join()函数用法详解_os.path.join函数-CSDN博客

4.14 Compose()函数

        在Pytorch中,Compose 是 torchvision.transforms 模块中的一个类,它用于将多个图像预处理操作组合在一起,以便在深度学习任务中按顺序应用这些操作。通常,它用于预处理输入数据,例如图像,以供神经网络模型使用。

import torch
from torchvision import transforms

# Compose创建一个图像预处理管道,其中的参数是个列表,列表中的元素是想执行的transforms操作
transform_data = Compose([
    # 数据平滑处理
    T.LogTransform(k=args.k),
    # 数据归一化
    T.MinMaxNormalize(T.log_transform(ctx['data_min'], k=args.k), T.log_transform(ctx['data_max'], k=args.k))
    ])

        Compose()类会将transforms列表中的transform操作进行遍历,实现的部分代码如下:

def __call__(self, img):
    for t in self.transforms:   
        img = t(img)
    return img

        参考:【pytorch】transforms.Compose()使用 - 知乎 (zhihu.com) 

五、总结

5.1 存在的疑惑

  1. 被各种函数的调用实实在在绕晕了
  2. warm-up策略的原理
  3. Python类的理解
  4. SmoothedValue与MetricLogger的理解
  5. 内容损失、风格迁移:神经网络风格迁移Pytorch_image.to(device)-CSDN博客

5.2 下周安排

  1. 尝试查看tensorboard的可视化展示
  2. 学习与理解:[Python] 深入理解 self、cls、__call__、__new__、__init__、__del__、__str__、__class__、__doc__等_python cls-CSDN博客
  3. 将OpenFWI论文读完
  4. 运行完OpenFWI代码

(本周无语之ctrl+z回到前一天晚上的草稿,当时没注意可以返回历史记录,为自己的愚蠢感到难过)

猜你喜欢

转载自blog.csdn.net/m0_53096519/article/details/134625424