[Reinforcement Learning Principles + Projects Column] Must-read series: single-agent and multi-agent algorithm principles with project practice, related skills (parameter tuning, plotting, etc.), fun project implementations, and academic application projects
The planned deep reinforcement learning content:
- Basic single-agent algorithm tutorials (based on gym environments)
- Mainstream multi-agent algorithm tutorials (based on gym environments)
- Mainstream algorithms: DDPG, DQN, TD3, SAC, PPO, RainbowDQN, Q-Learning, A2C and other algorithm projects
- Some fun projects (Super Mario, Gomoku (five-in-a-row), Fight the Landlord, various game applications)
- Hands-on single-agent and multi-agent problems (paper reproductions and business applications such as UAV scheduling optimization and power resource scheduling)
This column is mainly intended to help beginners quickly grasp the principles of single-agent and multi-agent reinforcement learning algorithms through project practice. Later posts will continue to explain the underlying deep learning concepts, so that you build up theory while working on projects and understand not only what works but why it works.
Disclaimer: Some projects are classic online projects included so that everyone can learn quickly; practical material (competitions, papers, real-world applications, etc.) will be added in the future.
Deep reinforcement learning Gomoku based on Monte Carlo tree search and a policy-value network (with source code)
Features
- Self-play
- Detailed comments
- Simple workflow
Code structure
- net: policy-value network implementation
- mcts: Monte Carlo tree search implementation
- server: front-end interface code
- legacy: obsolete code
- docs: other documents
- utils: utility code
- network.py: ported network structure code
- model_5400.pkl: trained weights for the ported network
- train_agent.py: training script
- web_server.py: game service script
- web_server_demo.py: game service script (ported network)
1.1 Process
1.2 Policy-value network
The network uses a ResNet-like residual structure with an added SPP (spatial pyramid pooling) module.
(Because training is time-consuming, after running for more than three weeks I have generated only a little over 2,000 self-play game records. In my experiments this policy network still performs poorly, possibly because it has not been trained enough.)
For demonstration purposes, I also ported another open-source policy network and its trained weights (network.py, model_5400.pkl).
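SPP has a useful property: its output length does not depend on the spatial size of the input. The feature map is max-pooled over a few fixed grids, and the results are concatenated. The plain-Python sketch below illustrates this on a single-channel map.

It makes several assumptions that are not taken from the project:
- The pyramid levels (1, 2, 4) are chosen for illustration.
- The helper names `adaptive_max_pool` and `spp` are hypothetical.
- The project's own SPP module, presumably written in PyTorch, may use different levels or pooling.

```python
import math
from typing import List, Sequence


def adaptive_max_pool(fmap: List[List[float]], bins: int) -> List[float]:
    """Max-pool a 2-D map into a bins x bins grid, flattened row-major.

    Bin edges follow the usual adaptive-pooling rule:
    start = floor(i * H / bins), end = ceil((i + 1) * H / bins).
    """
    h, w = len(fmap), len(fmap[0])
    pooled = []
    for i in range(bins):
        r0, r1 = (i * h) // bins, math.ceil((i + 1) * h / bins)
        for j in range(bins):
            c0, c1 = (j * w) // bins, math.ceil((j + 1) * w / bins)
            pooled.append(max(fmap[r][c] for r in range(r0, r1) for c in range(c0, c1)))
    return pooled


def spp(fmap: List[List[float]], levels: Sequence[int] = (1, 2, 4)) -> List[float]:
    """Concatenate the pooled features of every pyramid level (hypothetical levels)."""
    out: List[float] = []
    for bins in levels:
        out.extend(adaptive_max_pool(fmap, bins))
    return out


# Usage: two different board sizes give the same feature length.
board8 = [[float(r * 8 + c) for c in range(8)] for r in range(8)]
board15 = [[float((r + c) % 7) for c in range(15)] for r in range(15)]
print(len(spp(board8)), len(spp(board15)))  # 21 21
```

With levels (1, 2, 4), each channel contributes 1 + 4 + 16 = 21 values, whether the board is 8×8 or 15×15. This fixed-length output is what makes SPP attractive when the input size may vary.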
1.3 Training
Adjust train_agent.py according to the comments in the file, then run the script.
Part of the code:
```python
import os

import torch

# Project-specific names (LinXiaoNetConfig, LinXiaoNet, AlphaLoss, MonteTree,
# TrainDataCache, init_logger, logger, load_checkpoint, save_checkpoint,
# generate_train_data, save_chess_record) come from the repository's own modules.

if __name__ == '__main__':
    conf = LinXiaoNetConfig()
    conf.set_cuda(True)
    conf.set_input_shape(8, 8)
    conf.set_train_info(5, 16, 1e-2)
    conf.set_checkpoint_config(5, 'checkpoints/v2train')
    conf.set_num_worker(0)
    conf.set_log('log/v2train.log')
    # conf.set_pretrained_path('checkpoints/v2m4000/epoch_15')
    init_logger(conf.log_file)
    logger()(conf)
    device = 'cuda' if conf.use_cuda else 'cpu'

    # Create the policy-value network
    model = LinXiaoNet(3)
    model.to(device)
    loss_func = AlphaLoss()
    loss_func.to(device)
    optimizer = torch.optim.SGD(model.parameters(), conf.init_lr, 0.9, weight_decay=5e-4)
    lr_schedule = torch.optim.lr_scheduler.StepLR(optimizer, 1, 0.95)

    # Initialize the Monte Carlo tree
    tree = MonteTree(model, device, chess_size=conf.input_shape[0], simulate_count=500)
    data_cache = TrainDataCache(num_worker=conf.num_worker)
    ep_num = 0
    chess_num = 0
    # Train once every `train_every_chess` self-play games
    train_every_chess = 18

    # Load a checkpoint if one is configured
    if conf.pretrain_path is not None:
        model_data, optimizer_data, lr_schedule_data, data_cache, ep_num, chess_num = load_checkpoint(conf.pretrain_path)
        model.load_state_dict(model_data)
        optimizer.load_state_dict(optimizer_data)
        lr_schedule.load_state_dict(lr_schedule_data)
        logger()('successfully load pretrained : {}'.format(conf.pretrain_path))

    while True:
        logger()(f'self chess game no.{chess_num+1} start.')
        # Play one self-play game and get its record
        chess_record = tree.self_game()
        logger()(f'self chess game no.{chess_num+1} end.')
        # Generate training data from the game record
        train_data = generate_train_data(tree.chess_size, chess_record)
        # Push the training data into the cache
        for sample in train_data:
            data_cache.push(sample)
        if chess_num % train_every_chess == 0:
            logger()('train start.')
            loader = data_cache.get_loader(conf.batch_size)
            model.train()
            for _ in range(conf.epoch_num):
                loss_record = []
                for bat_state, bat_dist, bat_winner in loader:
                    bat_state, bat_dist, bat_winner = bat_state.to(device), bat_dist.to(device), bat_winner.to(device)
                    optimizer.zero_grad()
                    prob, value = model(bat_state)
                    loss = loss_func(prob, value, bat_dist, bat_winner)
                    loss.backward()
                    optimizer.step()
                    loss_record.append(loss.item())
                # Guard against an empty loader to avoid ZeroDivisionError
                mean_loss = sum(loss_record) / max(len(loss_record), 1)
                logger()(f'train epoch {ep_num} loss: {mean_loss}')
                ep_num += 1
                if ep_num % conf.checkpoint_save_every_num == 0:
                    save_checkpoint(
                        os.path.join(conf.checkpoint_save_dir, f'epoch_{ep_num}'),
                        ep_num, chess_num, model.state_dict(), optimizer.state_dict(), lr_schedule.state_dict(), data_cache
                    )
                lr_schedule.step()
            logger()('train end.')
        chess_num += 1
        save_chess_record(
            os.path.join(conf.checkpoint_save_dir, f'chess_record_{chess_num}.pkl'),
            chess_record
        )
```
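The excerpt does not show `generate_train_data` or `AlphaLoss`. The loader's `bat_state, bat_dist, bat_winner` appear to correspond to the usual AlphaZero training triple (s, π, z):
- π is the MCTS visit distribution at that position.
- z is the game result from the perspective of the player to move: +1 for a win, -1 for a loss, 0 for a draw.

The standard AlphaZero loss is (z - v)² - Σₐ π(a) log p(a) plus an L2 term. Here the L2 term is already covered by `weight_decay=5e-4` in the SGD optimizer.

The plain-Python sketch below assumes that formulation. It also assumes a hypothetical record format of `(state, player_to_move, pi)` per move. The project's actual functions may differ, for example by also adding board-symmetry augmentation.

```python
import math
from typing import Any, List, Optional, Tuple

# Hypothetical record format: one (state, player_to_move, mcts_pi) entry per move.
Record = List[Tuple[Any, int, List[float]]]


def label_game(record: Record, winner: Optional[int]) -> List[Tuple[Any, List[float], float]]:
    """Turn a finished game into (state, pi, z) training triples.

    z is +1 if the player to move at that position went on to win,
    -1 if they lost, and 0 for a draw (winner is None).
    """
    samples = []
    for state, player, pi in record:
        if winner is None:
            z = 0.0
        else:
            z = 1.0 if player == winner else -1.0
        samples.append((state, pi, z))
    return samples


def alpha_zero_loss(log_p: List[float], v: float, pi: List[float], z: float) -> float:
    """Single-sample AlphaZero loss: (z - v)^2 - sum_a pi(a) * log p(a).

    The L2 term is omitted because weight decay is applied by the optimizer.
    """
    value_loss = (z - v) ** 2
    policy_loss = -sum(target * logp for target, logp in zip(pi, log_p))
    return value_loss + policy_loss


# Usage: a three-move toy game won by player 1.
record = [("s0", 1, [0.5, 0.5, 0.0]), ("s1", 2, [0.0, 0.25, 0.75]), ("s2", 1, [0.0, 0.0, 1.0])]
samples = label_game(record, winner=1)
print([z for _, _, z in samples])  # [1.0, -1.0, 1.0]

# A network that outputs a uniform policy and v = 0 on the last position.
log_uniform = [math.log(1 / 3)] * 3
_, pi, z = samples[-1]
print(round(alpha_zero_loss(log_uniform, 0.0, pi, z), 4))  # 2.0986
```

In the toy case the value term contributes (1 - 0)² = 1 and the policy term contributes ln 3, so the total is 1 + ln 3.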
1.4 Simulation experiment
Adjust web_server.py according to the comments in the file, set the pretrained weights to load, and run the script.
Open the URL in the browser: http://127.0.0.1:8080/
Play the game.
Part of the code:
```python
import random

from flask import jsonify

# app, model, device, chess_size, simulate_count, trees and state_result are
# initialized elsewhere in web_server.py.

# Client polls whether the AI has finished its move
@app.route('/state/get/<state_id>', methods=['GET'])
def get_state(state_id):
    global state_result
    state_id = int(state_id)
    state = 0
    chess_state = None
    if state_result.get(state_id) is not None:
        state = 1
        chess_state = state_result[state_id]
        state_result[state_id] = None
    ret = {
        'code': 0,
        'msg': 'OK',
        'data': {
            'state': state,
            'chess_state': chess_state
        }
    }
    return jsonify(ret)


# Game start: create a Monte Carlo tree for this game
@app.route('/game/start', methods=['POST'])
def game_start():
    global trees
    global model, device, chess_size, simulate_count
    tree_id = random.randint(1000, 100000)
    trees[tree_id] = MonteTree(model, device, chess_size=chess_size, simulate_count=simulate_count)
    ret = {
        'code': 0,
        'msg': 'OK',
        'data': {
            'tree_id': tree_id
        }
    }
    return jsonify(ret)


# Game end: destroy the Monte Carlo tree
@app.route('/game/end/<tree_id>', methods=['POST'])
def game_end(tree_id):
    global trees
    # Remove the entry so the tree's memory can be reclaimed
    trees.pop(int(tree_id), None)
    ret = {
        'code': 0,
        'msg': 'OK',
        'data': {}
    }
    return jsonify(ret)


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8080)
```
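Each game gets its own `MonteTree`, which runs `simulate_count` search simulations per move. Its internals are not shown here. In AlphaZero-style MCTS, each simulation usually descends the tree using the PUCT rule:
- At each node it picks the child that maximises Q + U, where U = c_puct · P · √N_parent / (1 + N).
- Once the simulations finish, the move is chosen from the visit counts, with π(a) ∝ N(a)^(1/τ).

The sketch below assumes that standard rule. The `Node` class, the function names, and the value c_puct = 5.0 are illustrative assumptions, not taken from the project.

```python
import math
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class Node:
    prior: float                 # P(s, a) from the policy head
    visits: int = 0              # N(s, a)
    value_sum: float = 0.0       # W(s, a), sum of backed-up values
    children: Dict[int, "Node"] = field(default_factory=dict)

    def q(self) -> float:
        """Mean action value Q(s, a) = W / N, or 0 for an unvisited child."""
        return self.value_sum / self.visits if self.visits else 0.0


def select_child(node: Node, c_puct: float = 5.0) -> int:
    """Pick the move maximising Q + U (PUCT).

    Child values are assumed to be stored from the perspective of the
    player choosing at `node`; c_puct = 5.0 is a hypothetical default.
    """
    sqrt_parent = math.sqrt(node.visits)

    def score(move: int) -> float:
        child = node.children[move]
        u = c_puct * child.prior * sqrt_parent / (1 + child.visits)
        return child.q() + u

    return max(node.children, key=score)


def visit_distribution(node: Node, temperature: float = 1.0) -> Dict[int, float]:
    """Move probabilities pi(a) proportional to N(a)^(1/temperature); temperature > 0."""
    weights = {m: c.visits ** (1.0 / temperature) for m, c in node.children.items()}
    total = sum(weights.values())
    return {m: w / total for m, w in weights.items()}


# Usage: a root with three visited children.
root = Node(prior=1.0, visits=10)
root.children = {
    0: Node(prior=0.6, visits=6, value_sum=1.8),
    1: Node(prior=0.3, visits=3, value_sum=1.5),
    2: Node(prior=0.1, visits=1, value_sum=0.0),
}
print(select_child(root))        # 1
print(visit_distribution(root))  # {0: 0.6, 1: 0.3, 2: 0.1}
```

In the usage example, the exploration term U lets move 1 beat the most-visited move 0, because move 1 has a higher mean value Q and fewer visits. The final move choice, however, still follows the visit counts. Lowering the temperature makes play more greedy, which is the usual setting for competitive games against a human.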
1.5 Simulation experiment (ported network)
Run the script: python web_server_demo.py
Open the URL in the browser: http://127.0.0.1:8080/
Play the game.