【程序阅读】Spatio-Temporal Graph Transformer Networks for Pedestrian Trajectory Prediction/trainval.py

import argparse  # [参数解析器](https://zhuanlan.zhihu.com/p/539331146)
import ast
import os #操作系统

import torch#机器学习库
import yaml#[yaml](https://blog.csdn.net/m0_56654441/article/details/120604119)

from src.processor import processor#当前src目录下processor.py文件中processor对象

# Use Deterministic mode and set random seed
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(0)
# Makes identical inputs produce identical outputs, for reproducibility; see https://xhfei.blog.csdn.net/article/details/106268969?spm=1001.2101.3001.6661.1

def get_parser():
    """Build the command-line parser for the STAR train/test script.

    Every option is optional (``--name`` form). ``--save_dir``,
    ``--model_dir`` and ``--config`` are registered without a default, so
    they are ``None`` unless supplied.

    Returns:
        argparse.ArgumentParser: the parser with all options registered.
    """
    # Boolean options use ast.literal_eval as their converter so that the
    # command-line text "True"/"False" becomes a real bool.
    as_literal = ast.literal_eval

    # (flag, add_argument keyword arguments), in registration order.
    option_table = (
        ('--dataset', {'default': 'eth5'}),
        ('--save_dir', {}),
        ('--model_dir', {}),
        ('--config', {}),
        ('--using_cuda', {'default': True, 'type': as_literal}),
        ('--test_set', {
            'default': 'eth',
            'type': str,
            'help': 'Set this value to [eth, hotel, zara1, zara2, univ] for ETH-univ, ETH-hotel, UCY-zara01, UCY-zara02, UCY-univ',
        }),
        ('--base_dir', {'default': '.', 'help': 'Base directory including these scripts.'}),
        ('--save_base_dir', {'default': './output/', 'help': 'Directory for saving caches and models.'}),
        ('--phase', {'default': 'train', 'help': "Set this value to 'train' or 'test'"}),
        ('--train_model', {'default': 'star', 'help': 'Your model name'}),
        ('--load_model', {'default': None, 'type': str,
                          'help': "load pretrained model for test or training"}),
        ('--model', {'default': 'star.STAR'}),
        ('--seq_length', {'default': 20, 'type': int}),
        ('--obs_length', {'default': 8, 'type': int}),
        ('--pred_length', {'default': 12, 'type': int}),
        ('--batch_around_ped', {'default': 256, 'type': int}),
        ('--batch_size', {'default': 8, 'type': int}),
        ('--test_batch_size', {'default': 4, 'type': int}),
        ('--show_step', {'default': 100, 'type': int}),
        ('--start_test', {'default': 10, 'type': int}),
        ('--sample_num', {'default': 20, 'type': int}),
        ('--num_epochs', {'default': 300, 'type': int}),
        ('--ifshow_detail', {'default': True, 'type': as_literal}),
        ('--ifsave_results', {'default': False, 'type': as_literal}),
        ('--randomRotate', {
            'default': True,
            'type': as_literal,
            'help': "=True:random rotation of each trajectory fragment",
        }),
        ('--neighbor_thred', {'default': 10, 'type': int}),
        ('--learning_rate', {'default': 0.0015, 'type': float}),
        ('--clip', {'default': 1, 'type': int}),
    )

    cli = argparse.ArgumentParser(description='STAR')
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli


def load_arg(p):
    """Load saved arguments from the YAML config and re-parse the command line.

    The values stored in ``p.config`` are installed as parser defaults and
    the command line (``sys.argv``) is then parsed again, so options given
    explicitly on the command line still take precedence over the file.

    Args:
        p: parsed argument namespace. ``p.config`` is the path of the YAML
            file; the namespace's attribute names define which keys are
            considered known.

    Returns:
        The freshly parsed ``argparse.Namespace``, or ``False`` when there is
        nothing to load (no path, missing file, or empty file).
    """
    if not p.config or not os.path.exists(p.config):
        return False

    with open(p.config, 'r') as f:
        # yaml.load() without an explicit Loader is deprecated since PyYAML
        # 5.1 and raises TypeError on 6.0. The file written by save_arg only
        # holds plain scalars, so the safe loader is sufficient.
        default_arg = yaml.safe_load(f)

    if default_arg is None:
        # An empty file parses to None; treat it like a missing config so the
        # caller can write a fresh one instead of crashing here.
        return False

    known = vars(p)
    for k in default_arg:
        if k not in known:
            # Best-effort warning only: unknown keys are still applied below.
            print('WRONG ARG: {}'.format(k))

    # Build a parser locally rather than relying on a module-level `parser`
    # global, which only exists when this file is executed as a script.
    parser = get_parser()
    parser.set_defaults(**default_arg)
    return parser.parse_args()
p = {Namespace} Namespace(base_dir='.', batch_around_ped=256, batch_size=8, clip=1, config='./output/eth/star//config_train.yaml', dataset='eth5', ifsave_results=False, ifshow_detail=True, learning_rate=0.0015, load_model=None, model='star.STAR', model_dir='./output/eth/star/', neighbor_thred=10, num_epochs=300, obs_length=8, phase='train', pred_length=12, randomRotate=True, sample_num=20, save_base_dir='./output/', save_dir='./output/eth/', seq_length=20, show_step=100, start_test=10, test_batch_size=4, test_set='eth', train_model='star', using_cuda=True)
 base_dir = {str} '.'
 batch_around_ped = {int} 256
 batch_size = {int} 8
 clip = {int} 1
 config = {str} './output/eth/star//config_train.yaml'
 dataset = {str} 'eth5'
 ifsave_results = {bool} False
 ifshow_detail = {bool} True
 learning_rate = {float} 0.0015
 load_model = {NoneType} None
 model = {str} 'star.STAR'
 model_dir = {str} './output/eth/star/'
 neighbor_thred = {int} 10
 num_epochs = {int} 300
 obs_length = {int} 8
 phase = {str} 'train'
 pred_length = {int} 12
 randomRotate = {bool} True
 sample_num = {int} 20
 save_base_dir = {str} './output/'
 save_dir = {str} './output/eth/'
 seq_length = {int} 20
 show_step = {int} 100
 start_test = {int} 10
 test_batch_size = {int} 4
 test_set = {str} 'eth'
 train_model = {str} 'star'
 using_cuda = {bool} True
config_train.yaml
base_dir: .
batch_around_ped: 256
batch_size: 8
clip: 1
config: ./output/eth/star//config_train.yaml
dataset: eth5
ifsave_results: false
ifshow_detail: true
learning_rate: 0.0015
load_model: null
model: star.STAR
model_dir: ./output/eth/star/
neighbor_thred: 10
num_epochs: 300
obs_length: 8
phase: train
pred_length: 12
randomRotate: true
sample_num: 20
save_base_dir: ./output/
save_dir: ./output/eth/
seq_length: 20
show_step: 100
start_test: 10
test_batch_size: 4
test_set: eth
train_model: star
using_cuda: true

def save_arg(args):
    """Write the argument namespace to ``args.config`` as YAML.

    Args:
        args: parsed argument namespace. ``args.model_dir`` is created if
            needed and ``args.config`` is the file that gets (over)written.
            The script entry point places the config file inside
            ``model_dir``; this function does not create any other directory.
    """
    # exist_ok=True replaces the separate exists() check, which could race
    # with another process creating the directory between check and create.
    os.makedirs(args.model_dir, exist_ok=True)
    with open(args.config, 'w') as f:
        # vars() exposes the namespace as a plain dict of option -> value.
        yaml.dump(vars(args), f)


if __name__ == '__main__':
    # Parse the command line, e.g. `python trainval.py --test_set eth`.
    parser = get_parser()
    p = parser.parse_args()

    # Derive where results, models and the config file are stored from the
    # base output directory, the test set and the model name.
    p.save_dir = p.save_base_dir + str(p.test_set) + '/'
    p.model_dir = p.save_base_dir + str(p.test_set) + '/' + p.train_model + '/'
    p.config = p.model_dir + '/config_' + p.phase + '.yaml'

    # Load the saved config once; if there is none yet, write the current
    # arguments out first and then load them back.
    args = load_arg(p)
    if not args:
        save_arg(p)
        args = load_arg(p)

    # Only select a GPU when CUDA was requested, so `--using_cuda False`
    # does not touch CUDA at all.
    if args.using_cuda:
        torch.cuda.set_device(0)

    trainer = processor(args)

    if args.phase == 'test':
        trainer.test()
    else:
        trainer.train()

猜你喜欢

转载自blog.csdn.net/xi_shui/article/details/128141787