近端策略优化


近端策略优化(Proximal Policy Optimization,PPO)是一种流行的强化学习算法,它在实现简单性、样本效率和性能之间取得了良好的平衡。PPO 是一种在线策略(on-policy)算法,意味着它通过当前策略与环境的交互来学习。PPO 是对信任域策略优化(Trust Region Policy Optimization, TRPO)的改进,广泛应用于研究和实际场景中。

以下是 PPO 的核心思想及其关键组成部分:


PPO 的核心概念

  1. 策略优化
  2. PPO 优化的是一个随机策略 (\pi_\theta(a|s)),它根据当前状态 (s) 输出动作 (a) 的概率分布。
  3. 目标是通过调整策略参数 (\theta) 来最大化累积奖励。

  4. 替代目标函数

  5. PPO 使用一个替代目标函数来确保更新的稳定性。该函数是对策略梯度目标的裁剪版本,防止更新过大而导致训练不稳定。
  6. 替代目标函数为: [ L^{CLIP}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t \right) \right] ] 其中:

    • (r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}) 是新旧策略的概率比。
    • (\hat{A}_t) 是时间步 (t) 的优势估计值。
    • (\epsilon) 是一个超参数(通常设为 0.1 或 0.2),用于控制裁剪范围。
  7. 优势估计

  8. PPO 使用广义优势估计(Generalized Advantage Estimation, GAE)来计算优势函数 (\hat{A}_t),衡量某个动作相对于平均动作的优势。
  9. GAE 在偏差和方差之间取得了平衡。

  10. 裁剪机制

  11. 裁剪机制确保策略在单次更新中不会变化过大,从而提高训练的稳定性。

  12. 多轮优化

  13. PPO 对同一批数据进行多轮优化,提高了样本效率。

  14. 价值函数

  15. PPO 同时学习一个价值函数 (V_\phi(s)),用于估计从某个状态开始的期望累积奖励。价值函数用于计算优势并减少策略更新的方差。

PPO 算法步骤

  1. 收集轨迹
  2. 使用当前策略 (\pi_\theta) 与环境交互,收集轨迹(状态、动作和奖励的序列)。

  3. 计算优势

  4. 使用 GAE 计算每个时间步的优势值 (\hat{A}_t)。

  5. 优化替代目标函数

  6. 对裁剪后的替代目标函数 (L^{CLIP}(\theta)) 进行多轮梯度上升优化。

  7. 更新价值函数

  8. 使用均方误差(MSE)更新价值函数 (V_\phi(s)),使其逼近实际回报。

  9. 重复

  10. 重复上述过程,直到策略收敛或达到停止条件。

关键超参数

  • 裁剪范围 ((\epsilon)):通常设为 0.1 或 0.2。
  • 学习率:控制优化器的步长。
  • 优化轮数:每批数据的优化次数。
  • GAE 参数 ((\lambda)):控制优势估计的偏差-方差权衡。
  • 折扣因子 ((\gamma)):用于折扣未来奖励。

PPO 的优点

  • 稳定性:裁剪机制防止策略更新过大,训练更稳定。
  • 样本效率:对同一批数据进行多轮优化,提高了样本效率。
  • 易于实现:相比 TRPO 等算法,PPO 的实现更简单。

伪代码示例

import torch
import torch.optim as optim

def ppo_update(policy, optimizer, states, actions, rewards, advantages, epsilon=0.2):
    # 计算旧策略的对数概率
    old_log_probs = policy.get_log_prob(states, actions).detach()

    for _ in range(num_epochs):
        # 计算新策略的对数概率和熵
        new_log_probs = policy.get_log_prob(states, actions)
        entropy = policy.get_entropy(states)

        # 计算概率比
        ratio = torch.exp(new_log_probs - old_log_probs)

        # 计算替代损失
        surrogate1 = ratio * advantages
        surrogate2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
        policy_loss = -torch.min(surrogate1, surrogate2).mean()

        # 计算价值函数损失
        value_loss = ((rewards - policy.get_value(states)) ** 2).mean()

        # 总损失
        loss = policy_loss + 0.5 * value_loss - 0.01 * entropy

        # 更新策略
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

常用库

  • Stable-Baselines3:一个流行的强化学习库,支持 PPO。
  • Ray RLlib:一个可扩展的强化学习库,支持 PPO。
  • PyTorch/TensorFlow:可以使用这些框架自定义实现 PPO。

PPO 是一种强大且通用的算法,适用于从简单控制任务到复杂游戏和机器人等多种环境。如果你有具体问题或需要更详细的实现细节,可以进一步讨论!