深度强化学习-DQN算法

论文地址:https://arxiv.org/abs/1312.5602

        先讲下在线,离线,同策略和异策略

        同策略(on-policy)和异策略(off-policy)的根本区别在于生成样本的策略和参数更新时的策略是否相同。

        对于同策略,行为策略和要优化的策略是同一策略,更新了策略后,就用该策略的最新版本对数据进行采样;对于异策略,其使用任意行为策略来对数据进行采样,并利用其更新目标策略。例如, Q 学习在计算下一状态的预期奖励时使用了最大化操作,直接选择最优动作,而当前策略并不一定能选择到最优的动作,因此这里生成样本的策略和学习时的策略不同,所以 Q 学习算法是异策略算法;相对应的 Sarsa 算法则是基于当前的策略直接执行一次动作选择,然后用动作和对应的状态更新当前的策略,因此生成样本的策略和学习时的策略相同,所以 Sarsa 算法为同策略算法。

深度 Q 网络和 Q 学习异同点

        整体来说,两者的目标价值以及价值的更新方式基本相同。但有如下不同点:

1)深度 Q 网络将 Q 学习与深度学习结合,用深度网络来近似动作价值函数,而 Q 学习则是采用表格进行存储。

2)深度 Q 网络采用了经验回放的技巧,从历史数据中随机采样,而 Q 学习直接采用下一个状态的数据进行学习。

深度 Q 网络中的两个技巧——目标网络和经验回放

(1)在深度 Q 网络中某个动作价值函数的更新依赖于其他动作价值函数。如果我们一直更新价值网络的参数,会导致更新目标不断变化,也就是我们在追逐一个不断变化的目标,这样势必会不太稳定。为了解决基于时序差分的网络中,优化目标 Qπ (st, at) = rt + Qπ (st+1, π (st+1)) 左右两侧会同时变化使得训练过程不稳定,从而增大回归难度,目标网络选择将优化目标的右边即 rt + Qπ (st+1, π (st+1)) 固定,通过改变优化目标左边的网络参数进行回归。

(2)对于经验回放,其会构建一个回放缓冲区,用来保存数据,每一个数据的内容包括:状态 st、采取的动作 at、得到的奖励 rt、下一个状态 st+1。我们使用 π 与环境交互多次,把收集到的数据都放到回放缓冲区中。防止占用过多的内存,当回放缓冲区装满后,就会自动删去最早进入缓冲区的数据。在训练时,对于每一轮迭代都有相对应的批量(采样得到),然后用这个批量中的数据去更新 Q 函数。即 Q 函数在采样和训练的时候会用到过去的经验数据,也可以消除样本之间的相关性。

算法流程

        算法伪代码

扫描二维码关注公众号,回复: 14647317 查看本文章

代码实现

DQN

class CNNDQN(nn.Module):
    def __init__(self, input_shape, num_actions):
        super(CNNDQN, self).__init__()
        self._input_shape = input_shape
        self._num_actions = num_actions

        self.features = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        self.fc = nn.Sequential(
            nn.Linear(self.feature_size, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions)
        )

    def forward(self, x):
        x = self.features(x).view(x.size()[0], -1)
        return self.fc(x)

    @property
    def feature_size(self):
        x = self.features(torch.zeros(1, *self._input_shape))
        return x.view(1, -1).size(1)

    def act(self, state, epsilon, device):
        if random() > epsilon:
            state = torch.FloatTensor(np.float32(state)) \
                .unsqueeze(0).to(device)
            q_value = self.forward(state)
            action = q_value.max(1)[1].item()
        else:
            action = randrange(self._num_actions)
        return action

        其中输入的shape为(4,84,84)

初始化网络,由于用的是cpu训练,所以加载模型时映射到cpu上

def load_model(environment, model, target_model):
    model_name = join('pretrained_models', '%s.pth' % environment)
    model.load_state_dict(torch.load(model_name,map_location='cpu'))
    target_model.load_state_dict(model.state_dict())
    return model, target_model


def initialize_models(environment, env, device, transfer):
    model = CNNDQN(env.observation_space.shape,
                   env.action_space.n).to(device)
    target_model = CNNDQN(env.observation_space.shape,
                          env.action_space.n).to(device)
    if transfer:
        model, target_model = load_model(environment, model, target_model)
    return model, target_model

计算loss

def compute_td_loss(model, target_net, replay_buffer, gamma, device,
                    batch_size, beta):
    batch = replay_buffer.sample(batch_size, beta)
    state, action, reward, next_state, done, indices, weights = batch

    state = Variable(FloatTensor(np.float32(state))).to(device)
    next_state = Variable(FloatTensor(np.float32(next_state))).to(device)
    action = Variable(LongTensor(action)).to(device)
    reward = Variable(FloatTensor(reward)).to(device)
    done = Variable(FloatTensor(done)).to(device)
    weights = Variable(FloatTensor(weights)).to(device)

    q_values = model(state)
    next_q_values = target_net(next_state)

    q_value = q_values.gather(1, action.unsqueeze(-1)).squeeze(-1)
    next_q_value = next_q_values.max(1)[0]
    expected_q_value = reward + gamma * next_q_value * (1 - done)

    loss = (q_value - expected_q_value.detach()).pow(2) * weights
    prios = loss + 1e-5
    loss = loss.mean()
    loss.backward()
    replay_buffer.update_priorities(indices, prios.data.cpu().numpy())

算法图解

​​​​​​​

猜你喜欢

转载自blog.csdn.net/athrunsunny/article/details/126976433