import argparse [参数解析器](https://zhuanlan.zhihu.com/p/539331146)import ast
import os #操作系统import torch#机器学习库import yaml#[yaml](https://blog.csdn.net/m0_56654441/article/details/120604119)from src.processor import processor#当前src目录下processor.py文件中processor对象# Use Deterministic mode and set random seed
torch.backends.cudnn.deterministic =True
torch.backends.cudnn.benchmark =False
torch.manual_seed(0)#[保证相同输入,具有相同的输出,复现性强](https://xhfei.blog.csdn.net/article/details/106268969?spm=1001.2101.3001.6661.1)defget_parser():#参数解析器
parser = argparse.ArgumentParser(
description='STAR')#参数解析器对象
parser.add_argument('--dataset', default='eth5')#添加各项参数--代表可选default代表默认值type代表数据类型
parser.add_argument('--save_dir')
parser.add_argument('--model_dir')
parser.add_argument('--config')
parser.add_argument('--using_cuda', default=True,type=ast.literal_eval)
parser.add_argument('--test_set', default='eth',type=str,help='Set this value to [eth, hotel, zara1, zara2, univ] for ETH-univ, ETH-hotel, UCY-zara01, UCY-zara02, UCY-univ')
parser.add_argument('--base_dir', default='.',help='Base directory including these scripts.')
parser.add_argument('--save_base_dir', default='./output/',help='Directory for saving caches and models.')
parser.add_argument('--phase', default='train',help='Set this value to \'train\' or \'test\'')
parser.add_argument('--train_model', default='star',help='Your model name')
parser.add_argument('--load_model', default=None,type=str,help="load pretrained model for test or training")
parser.add_argument('--model', default='star.STAR')
parser.add_argument('--seq_length', default=20,type=int)
parser.add_argument('--obs_length', default=8,type=int)
parser.add_argument('--pred_length', default=12,type=int)
parser.add_argument('--batch_around_ped', default=256,type=int)
parser.add_argument('--batch_size', default=8,type=int)
parser.add_argument('--test_batch_size', default=4,type=int)
parser.add_argument('--show_step', default=100,type=int)
parser.add_argument('--start_test', default=10,type=int)
parser.add_argument('--sample_num', default=20,type=int)
parser.add_argument('--num_epochs', default=300,type=int)
parser.add_argument('--ifshow_detail', default=True,type=ast.literal_eval)
parser.add_argument('--ifsave_results', default=False,type=ast.literal_eval)
parser.add_argument('--randomRotate', default=True,type=ast.literal_eval,help="=True:random rotation of each trajectory fragment")
parser.add_argument('--neighbor_thred', default=10,type=int)
parser.add_argument('--learning_rate', default=0.0015,type=float)
parser.add_argument('--clip', default=1,type=int)return parser#返回defload_arg(p):#加载参数# save argif os.path.exists(p.config):#如果文件存在,则加载withopen(p.config,'r')as f:
default_arg = yaml.load(f)
key =vars(p).keys()for k in default_arg.keys():if k notin key:print('WRONG ARG: {}'.format(k))try:assert(k in key)except:
s =1
parser.set_defaults(**default_arg)return parser.parse_args()else:returnFalse