Deeplearning4j 实战 (9):强化学习 -- Cartpole任务的训练和效果测试

在之前的博客中,我用Deeplearning4j构建深度神经网络来解决监督、无监督的机器学习问题。但除了这两类问题外,强化学习也是机器学习中一个重要的分支,并且Deeplearning4j的子项目--Rl4j提供了对部分强化学习算法的支持。这里,就以强化学习中的经典任务--Cartpole问题作为学习Rl4j的入门例子,讲解从环境搭建、模型训练再到最后的效果评估的结果。

Cartpole描述的问题可以认为是:在一辆小车上竖立一根杆子,然后给小车一个推或者拉的力,使得杆子尽量保持平衡不滑倒。更详细的描述可参见openai官网上关于Cartpole问题的解释:https://gym.openai.com/envs/CartPole-v0

接着给出强化学习的一些概念:environment,action,reward

environment:描述强化学习问题中的外部环境,比如:Cartpole问题中杆子的角度,小车的位置、速度等。

action:在不同外部环境条件下采取的动作,比如:Cartpole问题中对于小车施加推或者拉的力。action可以是离散的集合,也可以是连续的。

reward:对于agent/network作出的action后获取的回报/评价。比如:Cartpole问题中如果施加的力可以继续让杆子保持平衡,那reward就+1。

在描述reward这个概念时,提到了agent这个概念,在实际应用中,agent可以用神经网络来实现。

对于强化学习训练后的agent来说,学习到的是如何在变化中的environment和reward选择action的能力。通常有两种学习策略可以选择:Policy-Based和Value-Based。 Policy-Based直接学习action,通过Policy Gradient来更新模型参数,而相对的,Value-Based是最优化action所带来的reward(action-value function,Q-function)来间接选取action。一般认为如果action是离散的,那么Value-Based会优于Policy-Based,而连续的action则相反。在这里主要讨论Value-Based的学习策略,或者更具体的说Q-learning的问题。对于Policy-Based还有Model-Based不做讨论。

Q-learning的概念早在20多年前就已经提出,再与近年来流行的深度神经网络结合产生了DQN的概念。Q-learning的目标是最大化Q值从而学习到选取action的策略。Q-leaning学习的策略公式:

Q(st,at)Q(st,at)+α[rt+1+λmaxaQ(st+1,a)Q(st,at)]

对于这里主要讨论的Catpole问题,我们也采用Q-learning来实现。

可以看到,与监督学习相比,强化学习多了action,environment等概念。虽然可以将reward类比成监督学习中的label(或者反过来,label也可以认为是强化学习中最终的reward),但通过action与environment不断的交互甚至改变environment这一特点,是监督学习中所没有的。在构建应用的时候,监督学习的学习的目标:label,灌入的数据都是一个定值。比如,图像的分类的问题,在用CNN训练的时候,图片本身不发生变化,label也不会发生变化,唯一变化的是神经网络中的权重值。但强化学习在训练的时候,除了神经网络中的权重会发生变化(如果用NN建模的话),environment、reward等都会发生动态的变化。这样构建合适正确的训练数据会比较麻烦,容易出错,所以对于CartPole问题,我们可以采用openAI提供的强化学习开发环境gym来训练/测试agent。

gym的官方地址:https://gym.openai.com/

gym提供了棋类、视频游戏等强化学习问题的学习/测试/算法效果比较的环境。这里要处理的Cartpole问题,gym也提供了环境的支持。但是,除了python,gym对其他语言的支持不是很友好,所以为了可以获取gym中的数据,RL4j提供了对gym-http-api(https://github.com/openai/gym-http-api)调用的包装类。gym-http-api是为了方便除python外的其他语言也可以使用gym环境数据的一个REST接口。简单来说,对于像RL4j这样以Java实现的强化学习算法库可以通过gym-http-api获取gym提供的数据。

gym的REST接口的安装可以参见之前给出的github地址,里面有详细的描述。下面先给出gym-http-api的安装和启动过程的截图:



下面就结合上面说的内容,给出RL4j的Catpole实现逻辑

1. 定义Q-learning的参数以及神经网络结构,两者共同决定DQN的属性

2. 定义读取gym数据的包装类对象

3. 训练DQN并保存模型

4. 加载保存的模型并测试

这里先贴下需要的Maven依赖以及代码版本

  <properties>
	  <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> 
	  <nd4j.version>0.8.0</nd4j.version>  
	  <dl4j.version>0.8.0</dl4j.version>  
	  <datavec.version>0.8.0</datavec.version>  
	  <rl4j.version>0.8.0</rl4j.version>
	  <scala.binary.version>2.10</scala.binary.version>  
  </properties>
  <dependencies>
	<dependency>  
		<groupId>org.nd4j</groupId>  
		<artifactId>nd4j-native</artifactId>   
		<version>${nd4j.version}</version>  
	</dependency>  
        <dependency>  
		<groupId>org.deeplearning4j</groupId>  
		<artifactId>deeplearning4j-core</artifactId>  
		<version>${dl4j.version}</version>  
	</dependency>  
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>rl4j-core</artifactId>
            <version>${rl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>rl4j-gym</artifactId>
            <version>${rl4j.version}</version>
        </dependency>

  </dependencies>


第一部分的代码如下:

    public static QLearning.QLConfiguration CARTPOLE_QL =
            new QLearning.QLConfiguration(
                    123,    //Random seed
                    200,    //Max step By epoch
                    150000, //Max step
                    150000, //Max size of experience replay
                    32,     //size of batches
                    500,    //target update (hard)
                    10,     //num step noop warmup
                    0.01,   //reward scaling
                    0.99,   //gamma
                    1.0,    //td-error clipping
                    0.1f,   //min epsilon
                    1000,   //num step for eps greedy anneal
                    true    //double DQN
            );

    public static DQNFactoryStdDense.Configuration CARTPOLE_NET = DQNFactoryStdDense.Configuration.builder()            												.l2(0.001)            																        .learningRate(0.0005)
       								.numHiddenNodes(16)
           							.numLayer(3)
            							.build();

第一部分中定义Q-learning的参数,包括每一轮的训练的可采取的action的步数,最大步数以及存储过往action的最大步数等。除此以外,DQNFactoryStdDense用来定义基于MLP的DQN网络结构,包括网络深度等常见参数。这里的代码定义的是一个三层(只有一层隐藏层)的全连接神经网络。

接下来,定义两个方法分别用于训练和测试。catpole方法用于训练DQN,而loadCartpole则用于测试。

训练:

    public static void cartPole() {

        //record the training data in rl4j-data in a new folder (save)
        DataManager manager = new DataManager(true);

        //define the mdp from gym (name, render)
        GymEnv<Box, Integer, DiscreteSpace> mdp = null;
        try {
            mdp = new GymEnv<Box, Integer, DiscreteSpace>("CartPole-v0", false, false);
        } catch (RuntimeException e){
            System.out.print("To run this example, download and start the gym-http-api repo found at https://github.com/openai/gym-http-api.");
        }
        //define the training
        QLearningDiscreteDense<Box> dql = new QLearningDiscreteDense<Box>(mdp, CARTPOLE_NET, CARTPOLE_QL, manager);

        //train
        dql.train();

        //get the final policy
        DQNPolicy<Box> pol = dql.getPolicy();

        //serialize and save (serialization showcase, but not required)
        pol.save("/tmp/pol1");

        //close the mdp (close http)
        mdp.close();

    }

测试:

    public static void loadCartpole(){

        //showcase serialization by using the trained agent on a new similar mdp (but render it this time)

        //define the mdp from gym (name, render)
        GymEnv<Box, Integer, DiscreteSpace> mdp2 = new GymEnv<Box, Integer, DiscreteSpace>("CartPole-v0", true, false);

        //load the previous agent
        DQNPolicy<Box> pol2 = DQNPolicy.load("/tmp/pol1");

        //evaluate the agent
        double rewards = 0;
        for (int i = 0; i < 1000; i++) {
            mdp2.reset();
            double reward = pol2.play(mdp2);
            rewards += reward;
            Logger.getAnonymousLogger().info("Reward: " + reward);
        }

        Logger.getAnonymousLogger().info("average: " + rewards/1000);
        
        mdp2.close();

    }

在训练模型的方法中,包含了第二、三步的内容。首先需要定义gym数据读取对象,即代码中的GymEnv<Box, Integer, DiscreteSpace> mdp。它会通过gym-http-api的接口读取训练数据。接着,将第一步中定义的Q-learning的相关参数,神经网络结构作为参数传入DQN训练的包装类中。其中DataManager的作用是用来管理训练数据。

测试部分的代码实现了之前说的第四步,即加载策略模型并进行测试的过程。在测试的过程中,将每次action的reward打上log,并最后求取平均的reward。

训练的过程截图如下:


最后我们其实最关心的还是这个模型的效果。纯粹通过平均reward的数值大小可能并不是非常的直观,因此这里给出一张gif的效果图:


总结一下Cartpole问题的整个解决过程。首先我们明确,这是一个强化学习的问题,而不是传统的监督学习,因为涉及到与环境的交互等因素。然后,利用openAI提供的强化学习开发环境gym来构建训练平台,而RL4j则可以定义并训练DQN。最后的效果就是上面这张gif图片。需要注意的是,这张gif效果图并非是RL4j直接生成的,而是通过xvfb命令截取虚拟monitor的在每个action后的效果拼接起来的图。具体可先查阅xvfb的相关内容。

猜你喜欢

转载自blog.csdn.net/wangongxi/article/details/73921083
今日推荐