Reinforcement learning: Deep reinforcement learning Gomoku (five-in-a-row) based on Monte Carlo tree search and a policy-value network (source code included)

Column details: [Reinforcement Learning Principles + Projects Column] Must-see series: single-agent and multi-agent algorithm principles + project practice, plus related skills (hyperparameter tuning, plotting, etc.), fun project implementations, and academic/applied project implementations.

The roadmap for the deep reinforcement learning column is:

  • Basic single-agent algorithm tutorials (based on the gym environment)
  • Mainstream multi-agent algorithm tutorials (based on the gym environment)
    • Mainstream algorithms: DDPG, DQN, TD3, SAC, PPO, Rainbow DQN, Q-Learning, A2C, and other algorithm projects
  • Some fun projects (Super Mario, Gomoku, Dou Dizhu (Fight the Landlord), and various other game applications)
  • Hands-on single-agent and multi-agent problems (reproducing parts of papers for business scenarios such as UAV scheduling optimization, power resource scheduling, and other applied projects)

This column aims to help beginners quickly grasp single-agent and multi-agent reinforcement learning algorithm principles together with project practice. Follow-up posts will keep breaking down the deep learning concepts involved, so that readers can build up knowledge while working through the projects: knowing what works, why it works, and why it has to work that way.

Disclaimer: some of the projects are classic projects from the web, included so everyone can learn quickly; practical follow-ups (competitions, papers, real applications, etc.) will be added later.

Deep reinforcement learning Gomoku based on Monte Carlo tree search and a policy-value network (source code included)

  • Features

    • Self-play
    • Detailed comments
    • Simple pipeline
  • Code structure

    • net: policy-value network implementation
    • mcts: Monte Carlo tree implementation
    • server: front-end interface code
    • legacy: obsolete code
    • docs: other documents
    • utils: utility code
    • network.py: ported network structure code
    • model_5400.pkl: ported network training weights
    • train_agent.py: training script
    • web_server.py: game service script
    • web_server_demo.py: game service script (ported network)

1.1 Process

Training follows the AlphaZero-style loop visible in the script below: the agent plays a game against itself using Monte Carlo tree search guided by the policy-value network, the game record is converted into training samples (board state, move distribution, winner), and every few games the network is retrained on the accumulated data.
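During self-play, each move is selected by the Monte Carlo tree using the network's priors. The repository's mcts code is not reproduced in this post, so below is a minimal sketch of the PUCT selection rule that AlphaZero-style trees typically use; all class and function names here are illustrative assumptions, not the actual MonteTree code:

import math

class Node:
    """One MCTS node: holds the prior from the policy head plus visit statistics."""
    def __init__(self, prior):
        self.prior = prior      # P(s, a) from the policy network
        self.visit_count = 0    # N(s, a)
        self.value_sum = 0.0    # W(s, a), sum of backed-up values
        self.children = {}      # action -> Node

    def q_value(self):
        # mean action value Q(s, a); 0 for unvisited nodes
        return self.value_sum / self.visit_count if self.visit_count else 0.0

def select_child(node, c_puct=5.0):
    """Pick the child maximizing Q + U (the PUCT rule in AlphaZero-style MCTS)."""
    total_visits = sum(child.visit_count for child in node.children.values())
    best_action, best_score = None, -float('inf')
    for action, child in node.children.items():
        u = c_puct * child.prior * math.sqrt(total_visits) / (1 + child.visit_count)
        score = child.q_value() + u
        if score > best_score:
            best_action, best_score = action, score
    return best_action, node.children[best_action]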

1.2 Policy-value network

The network adopts a ResNet-like structure, with an SPP (spatial pyramid pooling) module added.
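The net module itself is not excerpted in this post, so the following is a minimal sketch of what a ResNet-like trunk with an SPP module and two heads (policy and value) can look like in PyTorch; the layer sizes and names are assumptions for illustration, not the actual LinXiaoNet code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + x)   # skip connection

class SPP(nn.Module):
    """Spatial pyramid pooling: concatenate max-pooled features at several scales."""
    def __init__(self, pool_sizes=(1, 2, 4)):
        super().__init__()
        self.pools = nn.ModuleList(nn.AdaptiveMaxPool2d(s) for s in pool_sizes)

    def forward(self, x):
        return torch.cat([p(x).flatten(1) for p in self.pools], dim=1)

class PolicyValueNet(nn.Module):
    """ResNet-like trunk + SPP, followed by a policy head and a value head."""
    def __init__(self, in_channels=3, channels=64, board_size=8, n_blocks=4):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels), nn.ReLU())
        self.trunk = nn.Sequential(*[ResidualBlock(channels) for _ in range(n_blocks)])
        self.spp = SPP()
        spp_dim = channels * (1 + 4 + 16)   # pooled 1x1 + 2x2 + 4x4 grids
        self.policy_head = nn.Linear(spp_dim, board_size * board_size)
        self.value_head = nn.Sequential(
            nn.Linear(spp_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Tanh())

    def forward(self, x):
        feat = self.spp(self.trunk(self.stem(x)))
        prob = F.softmax(self.policy_head(feat), dim=1)   # move distribution
        value = self.value_head(feat)                     # expected outcome in [-1, 1]
        return prob, value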

(Training is time-consuming: after running for more than three weeks I had accumulated just over 2,000 self-play game records. In experiments the resulting policy network still does not perform well, most likely because it simply has not been trained enough.)

For demonstration purposes, another open-source policy network and its trained weights (network.py, model_5400.pkl) were also ported.

1.3 Training

Adjust the train_agent.py file according to the comments, then run the script.

Part of the code is shown below:


if __name__ == '__main__':

    conf = LinXiaoNetConfig()
    conf.set_cuda(True)
    conf.set_input_shape(8, 8)
    conf.set_train_info(5, 16, 1e-2)
    conf.set_checkpoint_config(5, 'checkpoints/v2train')
    conf.set_num_worker(0)
    conf.set_log('log/v2train.log')
    # conf.set_pretrained_path('checkpoints/v2m4000/epoch_15')

    init_logger(conf.log_file)
    logger()(conf)

    device = 'cuda' if conf.use_cuda else 'cpu'

    # create the policy network
    model = LinXiaoNet(3)
    model.to(device)

    loss_func = AlphaLoss()
    loss_func.to(device)

    optimizer = torch.optim.SGD(model.parameters(), conf.init_lr, 0.9, weight_decay=5e-4)
    lr_schedule = torch.optim.lr_scheduler.StepLR(optimizer, 1, 0.95)

    # initialize the Monte Carlo tree
    tree = MonteTree(model, device, chess_size=conf.input_shape[0], simulate_count=500)
    data_cache = TrainDataCache(num_worker=conf.num_worker)

    ep_num = 0
    chess_num = 0
    # train the network every `train_every_chess` self-play games
    train_every_chess = 18

    # load a checkpoint if a pretrained path is configured
    if conf.pretrain_path is not None:
        model_data, optimizer_data, lr_schedule_data, data_cache, ep_num, chess_num = load_checkpoint(conf.pretrain_path)
        model.load_state_dict(model_data)
        optimizer.load_state_dict(optimizer_data)
        lr_schedule.load_state_dict(lr_schedule_data)
        logger()('successfully load pretrained : {}'.format(conf.pretrain_path))

    while True:
        logger()(f'self chess game no.{chess_num+1} start.')
        # play one self-play game and collect the game record
        chess_record = tree.self_game()
        logger()(f'self chess game no.{chess_num+1} end.')
        # generate training data from the game record
        train_data = generate_train_data(tree.chess_size, chess_record)
        # push the training data into the cache
        for i in range(len(train_data)):
            data_cache.push(train_data[i])
        if chess_num % train_every_chess == 0:
            logger()(f'train start.')
            loader = data_cache.get_loader(conf.batch_size)
            model.train()
            for _ in range(conf.epoch_num):
                loss_record = []
                for bat_state, bat_dist, bat_winner in loader:
                    bat_state, bat_dist, bat_winner = bat_state.to(device), bat_dist.to(device), bat_winner.to(device)
                    optimizer.zero_grad()
                    prob, value = model(bat_state)
                    loss = loss_func(prob, value, bat_dist, bat_winner)
                    loss.backward()
                    optimizer.step()
                    loss_record.append(loss.item())
                logger()(f'train epoch {ep_num} loss: {sum(loss_record) / float(len(loss_record))}')
                ep_num += 1
                if ep_num % conf.checkpoint_save_every_num == 0:
                    save_checkpoint(
                        os.path.join(conf.checkpoint_save_dir, f'epoch_{ep_num}'),
                        ep_num, chess_num, model.state_dict(), optimizer.state_dict(), lr_schedule.state_dict(), data_cache
                    )
            lr_schedule.step()
            logger()(f'train end.')
        chess_num += 1
        save_chess_record(
            os.path.join(conf.checkpoint_save_dir, f'chess_record_{chess_num}.pkl'),
            chess_record
        )
        # break

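The AlphaLoss used above is not shown in this excerpt. Assuming it is the standard AlphaZero loss, i.e. the value mean-squared error plus the policy cross-entropy, l = (z − v)² − πᵀ log p (with L2 regularization left to the optimizer's weight_decay, as in the script), a minimal sketch matching the call signature loss_func(prob, value, bat_dist, bat_winner) would be:

import torch
import torch.nn as nn

class AlphaLoss(nn.Module):
    """AlphaZero-style loss: (z - v)^2 - pi^T log p.
    L2 regularization is handled by the optimizer's weight_decay."""
    def forward(self, prob, value, target_dist, target_winner):
        # value head: mean squared error against the game outcome z
        value_loss = torch.mean((target_winner.view(-1) - value.view(-1)) ** 2)
        # policy head: cross-entropy against the MCTS visit distribution pi
        policy_loss = -torch.mean(torch.sum(target_dist * torch.log(prob + 1e-10), dim=1))
        return value_loss + policy_loss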

1.4 Simulation experiment

Adjust the web_server.py file according to the comments, load the pretrained weights, and run the script.

Open http://127.0.0.1:8080/ in a browser to play the game.

Part of the code is shown below:

# the client polls for the machine's move state
@app.route('/state/get/<state_id>', methods=['GET'])
def get_state(state_id):
    global state_result
    state_id = int(state_id)
    state = 0
    chess_state = None
    if state_id in state_result.keys() and state_result[state_id] is not None:
        state = 1
        chess_state = state_result[state_id]
        state_result[state_id] = None
    ret = {
        'code': 0,
        'msg': 'OK',
        'data': {
            'state': state,
            'chess_state': chess_state
        }
    }
    return jsonify(ret)


# game start: create a Monte Carlo tree for this game
@app.route('/game/start', methods=['POST'])
def game_start():
    global trees
    global model, device, chess_size, simulate_count
    tree_id = random.randint(1000, 100000)
    trees[tree_id] = MonteTree(model, device, chess_size=chess_size, simulate_count=simulate_count)
    ret = {
        'code': 0,
        'msg': 'OK',
        'data': {
            'tree_id': tree_id
        }
    }
    return jsonify(ret)


# game end: destroy the Monte Carlo tree
@app.route('/game/end/<tree_id>', methods=['POST'])
def game_end(tree_id):
    global trees
    tree_id = int(tree_id)
    trees[tree_id] = None
    ret = {
        'code': 0,
        'msg': 'OK',
        'data': {}
    }
    return jsonify(ret)


if __name__ == '__main__':
    app.run(
        '0.0.0.0',
        8080
    )
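For reference, the endpoints shown above can be exercised directly with the requests library. The move-submission endpoint is not included in this excerpt, so the sketch below only starts a game, polls the move state once (assuming, purely for illustration, that the state id equals the tree id), and ends the game:

import requests

BASE = 'http://127.0.0.1:8080'

# start a game: the server creates a Monte Carlo tree and returns its id
tree_id = requests.post(f'{BASE}/game/start').json()['data']['tree_id']

# poll the machine's move state (state == 1 once a move is ready);
# the state id used here is an assumption for illustration
resp = requests.get(f'{BASE}/state/get/{tree_id}').json()
print('state:', resp['data']['state'], 'chess_state:', resp['data']['chess_state'])

# end the game: the server destroys the tree
requests.post(f'{BASE}/game/end/{tree_id}')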

1.5 Simulation experiment (ported network)

Run the script: python web_server_demo.py

Open http://127.0.0.1:8080/ in a browser to play the game.

The source code download link (also at the top of the article):

https://download.csdn.net/download/sinat_39620217/88045879
