Proximal Policy Optimization (PPO)

PPO is a model-free, stochastic, on-policy policy gradient algorithm that alternates between sampling data through interaction with the environment and optimizing a surrogate objective function, while preventing the new policy from moving too far from the old one

Paper: Proximal Policy Optimization Algorithms


For each iteration do:
\(\bullet \;\) Collect, in a rollout memory, a set of states \(s\), actions \(a\), rewards \(r\), dones \(d\), log probabilities \(logp\) and values \(V\) on policy using \(\pi_\theta\) and \(V_\phi\)
\(\bullet \;\) Estimate returns \(R\) and advantages \(A\) using Generalized Advantage Estimation (GAE(\(\lambda\))) from the collected data [\(r, d, V\)]
\(\bullet \;\) Compute the entropy loss \({L}_{entropy}\)
\(\bullet \;\) Compute the clipped surrogate objective (policy loss) with \(ratio\) as the probability ratio between the action under the current policy and the action under the previous policy: \(L^{clip}_{\pi_\theta} = \mathbb{E}[\min(A \; ratio, A \; \text{clip}(ratio, 1-c, 1+c))]\)
\(\bullet \;\) Compute the value loss \(L_{V_\phi}\) as the mean squared error (MSE) between the predicted values \(V_{_{predicted}}\) and the estimated returns \(R\)
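\(\bullet \;\) Optimize the total loss \(L = L^{clip}_{\pi_\theta} - c_1 \, L_{V_\phi} + c_2 \, {L}_{entropy}\), written here as the objective to maximize; the implementation below minimizes the equivalent sum of the negated clipped surrogate, the entropy loss and the value loss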

Algorithm implementation

Main notation/symbols:
- policy function approximator (\(\pi_\theta\)), value function approximator (\(V_\phi\))
- states (\(s\)), actions (\(a\)), rewards (\(r\)), next states (\(s'\)), dones (\(d\))
- values (\(V\)), advantages (\(A\)), returns (\(R\))
- log probabilities (\(logp\))
- loss (\(L\))

Learning algorithm

def \(\;f_{GAE} (r, d, V, V_{_{last}}') \;\rightarrow\; R, A:\)
    \(adv \leftarrow 0\)
    \(A \leftarrow \text{zeros}(r)\)
    # advantages computation
    FOR each reverse iteration \(i\) up to the number of rows in \(r\) DO
        IF \(i\) is not the last row of \(r\) THEN
            \(V_i' \leftarrow V_{i+1}\)
        ELSE
            \(V_i' \leftarrow V_{_{last}}'\)
        \(adv \leftarrow r_i - V_i \, +\) discount_factor \(\neg d_i \; (V_i' \, +\) lambda \(adv)\)
        \(A_i \leftarrow adv\)
    # returns computation
    \(R \leftarrow A + V\)
    # normalize advantages
    \(A \leftarrow \dfrac{A - \bar{A}}{A_\sigma + 10^{-8}}\)
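The GAE routine can be sketched in plain Python for a single rollout of scalar values. This is a minimal illustration, not the library's code: the function name `compute_gae`, the use of Python lists instead of tensors, and the population standard deviation used for normalization are all assumptions. It follows the standard GAE recursion, in which the decayed running advantage is added to the bootstrapped next value.

```python
from statistics import fmean, pstdev


def compute_gae(rewards, dones, values, last_value,
                discount_factor=0.99, lam=0.95):
    """Return (returns, normalized advantages) for one rollout of scalars."""
    n = len(rewards)
    advantages = [0.0] * n
    adv = 0.0
    # iterate backwards so each step can reuse the advantage of the next one
    for i in reversed(range(n)):
        next_value = values[i + 1] if i < n - 1 else last_value
        not_done = 0.0 if dones[i] else 1.0
        adv = (rewards[i] - values[i]
               + discount_factor * not_done * (next_value + lam * adv))
        advantages[i] = adv
    # returns are computed from the advantages before normalization
    returns = [a + v for a, v in zip(advantages, values)]
    # normalization (population standard deviation is an assumption here)
    mean = fmean(advantages)
    std = pstdev(advantages)
    normalized = [(a - mean) / (std + 1e-8) for a in advantages]
    return returns, normalized


returns, advantages = compute_gae(
    rewards=[1.0, 1.0, 1.0],
    dones=[False, False, True],
    values=[0.5, 0.5, 0.5],
    last_value=10.0,
    discount_factor=1.0,
    lam=1.0,
)
```

With a discount factor and lambda of 1, the returns reduce to the undiscounted reward-to-go, and the bootstrapped `last_value` is ignored because the final step is marked as done.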

# compute returns and advantages
\(V_{_{last}}' \leftarrow V_\phi(s')\)
\(R, A \leftarrow f_{GAE}(r, d, V, V_{_{last}}')\)
# sample mini-batches from memory
[[\(s, a, logp, V, R, A\)]] \(\leftarrow\) states, actions, log_prob, values, returns, advantages
# learning epochs
FOR each learning epoch up to learning_epochs DO
    # mini-batches loop
    FOR each mini-batch [\(s, a, logp, V, R, A\)] up to mini_batches DO
        \(logp' \leftarrow \pi_\theta(s, a)\)
        # compute approximate KL divergence
        \(ratio \leftarrow logp' - logp\)
        \(KL_{_{divergence}} \leftarrow \frac{1}{N} \sum_{i=1}^N ((e^{ratio} - 1) - ratio)\)
        # early stopping with KL divergence
        IF \(KL_{_{divergence}} >\) kl_threshold THEN
            BREAK LOOP
        # compute entropy loss
        IF entropy computation is enabled THEN
            \({L}_{entropy} \leftarrow \, -\) entropy_loss_scale \(\frac{1}{N} \sum_{i=1}^N \pi_{\theta_{entropy}}\)
        ELSE
            \({L}_{entropy} \leftarrow 0\)
        # compute policy loss
        \(ratio \leftarrow e^{logp' - logp}\)
        \(L_{_{surrogate}} \leftarrow A \; ratio\)
        \(L_{_{clipped\,surrogate}} \leftarrow A \; \text{clip}(ratio, 1 - c, 1 + c) \qquad\) with \(c\) as ratio_clip
        \(L^{clip}_{\pi_\theta} \leftarrow - \frac{1}{N} \sum_{i=1}^N \min(L_{_{surrogate}}, L_{_{clipped\,surrogate}})\)
        # compute value loss
        \(V_{_{predicted}} \leftarrow V_\phi(s)\)
        IF clip_predicted_values is enabled THEN
            \(V_{_{predicted}} \leftarrow V + \text{clip}(V_{_{predicted}} - V, -c, c) \qquad\) with \(c\) as value_clip
        \(L_{V_\phi} \leftarrow\) value_loss_scale \(\frac{1}{N} \sum_{i=1}^N (R - V_{_{predicted}})^2\)
        # optimization step
        reset \(\text{optimizer}_{\theta, \phi}\)
        \(\nabla_{\theta, \, \phi} (L^{clip}_{\pi_\theta} + {L}_{entropy} + L_{V_\phi})\)
        \(\text{clip}(\lVert \nabla_{\theta, \, \phi} \rVert)\) with grad_norm_clip
        step \(\text{optimizer}_{\theta, \phi}\)
    # update learning rate
    IF there is a learning_rate_scheduler THEN
        step \(\text{scheduler}_{\theta, \phi} (\text{optimizer}_{\theta, \phi})\)
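The loss terms of the mini-batch loop can be illustrated with a small, dependency-free sketch. The function names `ppo_policy_loss` and `ppo_value_loss` are hypothetical, and plain Python lists stand in for tensors; the arithmetic follows the pseudocode above, but this is not the library's implementation.

```python
import math


def ppo_policy_loss(new_logp, old_logp, advantages, ratio_clip=0.2):
    """Return (clipped surrogate loss, approximate KL divergence)."""
    n = len(new_logp)
    log_ratio = [new - old for new, old in zip(new_logp, old_logp)]
    # approximate KL divergence used for early stopping
    kl = sum((math.exp(lr) - 1.0) - lr for lr in log_ratio) / n
    total = 0.0
    for lr, adv in zip(log_ratio, advantages):
        ratio = math.exp(lr)
        clipped = min(max(ratio, 1.0 - ratio_clip), 1.0 + ratio_clip)
        # pessimistic bound: the smaller of the two surrogates
        total += min(adv * ratio, adv * clipped)
    # negated because the optimizer minimizes the loss
    return -total / n, kl


def ppo_value_loss(predicted, old_values, returns, value_clip=0.2,
                   clip_predicted_values=False, value_loss_scale=1.0):
    """Return the scaled mean squared error, optionally clipping predictions."""
    total = 0.0
    for v_pred, v_old, ret in zip(predicted, old_values, returns):
        if clip_predicted_values:
            # keep the new prediction within value_clip of the stored value
            delta = min(max(v_pred - v_old, -value_clip), value_clip)
            v_pred = v_old + delta
        total += (ret - v_pred) ** 2
    return value_loss_scale * total / len(returns)


# unchanged policy: ratio is 1, so the loss is minus the mean advantage
same_loss, same_kl = ppo_policy_loss([-1.0, -1.0], [-1.0, -1.0], [1.0, -2.0])
# ratio of 2 with a positive advantage is clipped to 1 + ratio_clip
up_loss, up_kl = ppo_policy_loss([math.log(2.0)], [0.0], [1.0])
# ratio of 2 with a negative advantage is not clipped by the min
down_loss, _ = ppo_policy_loss([math.log(2.0)], [0.0], [-1.0])
unclipped_value = ppo_value_loss([2.0], [1.0], [1.0])
clipped_value = ppo_value_loss([2.0], [1.0], [1.0], clip_predicted_values=True)
```

The asymmetry between `up_loss` and `down_loss` shows why the minimum is taken: the clip removes the incentive to push the ratio further in the favourable direction, but the unclipped penalty still applies when the change makes a bad action more likely.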



Support for recurrent neural networks (RNN, LSTM, GRU and any other variant) is implemented in a separate file to maintain the readability of the standard implementation

# import the agent and its default configuration
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG

# instantiate the agent's models
models = {}
models["policy"] = ...
models["value"] = ...  # only required during training

# adjust some configuration if necessary
cfg_agent = PPO_DEFAULT_CONFIG.copy()
cfg_agent["<KEY>"] = ...

# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
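agent = PPO(models=models,
            memory=memory,  # only required during training
            cfg=cfg_agent,
            observation_space=env.observation_space,
            action_space=env.action_space,
            device=env.device)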

Configuration and hyperparameters

PPO_DEFAULT_CONFIG = {
    "rollouts": 16,                 # number of rollouts before updating
    "learning_epochs": 8,           # number of learning epochs during each update
    "mini_batches": 2,              # number of mini batches during each learning epoch

    "discount_factor": 0.99,        # discount factor (gamma)
    "lambda": 0.95,                 # TD(lambda) coefficient (lam) for computing returns and advantages

    "learning_rate": 1e-3,                  # learning rate
    "learning_rate_scheduler": None,        # learning rate scheduler class (see torch.optim.lr_scheduler)
    "learning_rate_scheduler_kwargs": {},   # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})

    "state_preprocessor": None,             # state preprocessor class (see skrl.resources.preprocessors)
    "state_preprocessor_kwargs": {},        # state preprocessor's kwargs (e.g. {"size": env.observation_space})
    "value_preprocessor": None,             # value preprocessor class (see skrl.resources.preprocessors)
    "value_preprocessor_kwargs": {},        # value preprocessor's kwargs (e.g. {"size": 1})

    "random_timesteps": 0,          # random exploration steps
    "learning_starts": 0,           # learning starts after this many steps

    "grad_norm_clip": 0.5,              # clipping coefficient for the norm of the gradients
    "ratio_clip": 0.2,                  # clipping coefficient for computing the clipped surrogate objective
    "value_clip": 0.2,                  # clipping coefficient for computing the value loss (if clip_predicted_values is True)
    "clip_predicted_values": False,     # clip predicted values during value loss computation

    "entropy_loss_scale": 0.0,      # entropy loss scaling factor
    "value_loss_scale": 1.0,        # value loss scaling factor

    "kl_threshold": 0,              # KL divergence threshold for early stopping

    "rewards_shaper": None,         # rewards shaping function: Callable(reward, timestep, timesteps) -> reward
    "time_limit_bootstrap": False,  # bootstrap at timeout termination (episode truncation)

    "experiment": {
        "directory": "",            # experiment's parent directory
        "experiment_name": "",      # experiment name
        "write_interval": 250,      # TensorBoard writing interval (timesteps)

        "checkpoint_interval": 1000,        # interval for checkpoints (timesteps)
        "store_separately": False,          # whether to store checkpoints separately

        "wandb": False,             # whether to use Weights & Biases
        "wandb_kwargs": {}          # wandb kwargs (see the Weights & Biases documentation)
    }
}

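When overriding entries of the configuration, note that it contains a nested "experiment" dictionary, so a shallow `.copy()` shares that inner dictionary with the defaults. The sketch below uses a reduced stand-in dictionary (`DEFAULT_CONFIG`, not the real `PPO_DEFAULT_CONFIG`) and a hypothetical rewards shaper named `scale_rewards` to show one way of adjusting values without touching the defaults.

```python
import copy

# stand-in for the default configuration, reduced to a few of the keys above
DEFAULT_CONFIG = {
    "rollouts": 16,
    "learning_rate": 1e-3,
    "rewards_shaper": None,
    "experiment": {"write_interval": 250, "checkpoint_interval": 1000},
}


def scale_rewards(rewards, timestep, timesteps):
    """Hypothetical shaper matching Callable(reward, timestep, timesteps) -> reward."""
    return rewards * 0.1


# a deep copy keeps nested dictionaries independent from the defaults
cfg = copy.deepcopy(DEFAULT_CONFIG)
cfg["rollouts"] = 32
cfg["rewards_shaper"] = scale_rewards
cfg["experiment"]["write_interval"] = 100

# a shallow copy still points at the same nested dictionary
shallow = DEFAULT_CONFIG.copy()
shares_nested = shallow["experiment"] is DEFAULT_CONFIG["experiment"]
```

Top-level keys such as "rollouts" are safe to override on a shallow copy; the deep copy only matters when nested entries such as the experiment settings are changed.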

The implementation supports the following Gym spaces / Gymnasium spaces

The implementation uses 1 stochastic (discrete or continuous) and 1 deterministic function approximator. These function approximators (models) must be collected in a dictionary and passed to the constructor of the class under the argument models




Supported policy model types include Categorical and Multi-Categorical

Support for advanced features is described in the next table

Shared model    for Policy and Value
RNN support     RNN, LSTM, GRU and any other variant



API (PyTorch)


class skrl.agents.torch.ppo.PPO(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)

Bases: Agent

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) → None

Proximal Policy Optimization (PPO)

  • models (dictionary of skrl.models.torch.Model) – Models used by the agent

  • memory (skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None) – Memory to store the transitions. If it is a tuple, the first element will be used for training and for the rest only the environment transitions will be added

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary


Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) → None

Algorithm’s main update step

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

act(states: torch.Tensor, timestep: int, timesteps: int) → torch.Tensor

Process the environment’s states to make a decision (actions) using the main policy

  • states (torch.Tensor) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps



Return type:

torch.Tensor

init(trainer_cfg: Mapping[str, Any] | None = None) → None

Initialize the agent

post_interaction(timestep: int, timesteps: int) → None

Callback called after the interaction with the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) → None

Callback called before the interaction with the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) → None

Record an environment transition in memory

  • states (torch.Tensor) – Observations/states of the environment used to make the decision

  • actions (torch.Tensor) – Actions taken by the agent

  • rewards (torch.Tensor) – Instant rewards achieved by the current actions

  • next_states (torch.Tensor) – Next observations/states of the environment

  • terminated (torch.Tensor) – Signals to indicate that episodes have terminated

  • truncated (torch.Tensor) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps
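The methods above are meant to be called from a training loop. The sketch below uses a hypothetical `StubAgent` that only records calls, with constant values in place of a real environment, to illustrate one plausible call order. Both the order and the rule of triggering `_update` every `rollouts` interactions are assumptions made for illustration, not behaviour confirmed by this page.

```python
class StubAgent:
    """Hypothetical stand-in that only records the order of the calls."""

    def __init__(self, rollouts):
        self.rollouts = rollouts
        self.calls = []
        self.updates = []
        self._rollout = 0

    def pre_interaction(self, timestep, timesteps):
        self.calls.append("pre_interaction")

    def act(self, states, timestep, timesteps):
        self.calls.append("act")
        return [0 for _ in states]  # dummy action per state

    def record_transition(self, states, actions, rewards, next_states,
                          terminated, truncated, infos, timestep, timesteps):
        self.calls.append("record_transition")

    def post_interaction(self, timestep, timesteps):
        self.calls.append("post_interaction")
        # assumed rule: update once every `rollouts` interactions
        self._rollout += 1
        if self._rollout % self.rollouts == 0:
            self._update(timestep, timesteps)

    def _update(self, timestep, timesteps):
        self.updates.append(timestep)


def run(agent, timesteps):
    states = [0.0]
    for timestep in range(timesteps):
        agent.pre_interaction(timestep=timestep, timesteps=timesteps)
        actions = agent.act(states, timestep=timestep, timesteps=timesteps)
        # a real environment step would go here; constants keep the sketch runnable
        next_states, rewards = [0.0], [1.0]
        terminated, truncated, infos = [False], [False], {}
        agent.record_transition(states, actions, rewards, next_states,
                                terminated, truncated, infos,
                                timestep=timestep, timesteps=timesteps)
        agent.post_interaction(timestep=timestep, timesteps=timesteps)
        states = next_states


agent = StubAgent(rollouts=4)
run(agent, timesteps=8)
```

With `rollouts=4` and eight timesteps, the stub records two updates, one after every fourth interaction.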

class skrl.agents.torch.ppo.PPO_RNN(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)

Bases: Agent

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) → None

Proximal Policy Optimization (PPO) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.)

  • models (dictionary of skrl.models.torch.Model) – Models used by the agent

  • memory (skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None) – Memory to store the transitions. If it is a tuple, the first element will be used for training and for the rest only the environment transitions will be added

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary


Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) → None

Algorithm’s main update step

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

act(states: torch.Tensor, timestep: int, timesteps: int) → torch.Tensor

Process the environment’s states to make a decision (actions) using the main policy

  • states (torch.Tensor) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps



Return type:

torch.Tensor

init(trainer_cfg: Mapping[str, Any] | None = None) → None

Initialize the agent

post_interaction(timestep: int, timesteps: int) → None

Callback called after the interaction with the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) → None

Callback called before the interaction with the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) → None

Record an environment transition in memory

  • states (torch.Tensor) – Observations/states of the environment used to make the decision

  • actions (torch.Tensor) – Actions taken by the agent

  • rewards (torch.Tensor) – Instant rewards achieved by the current actions

  • next_states (torch.Tensor) – Next observations/states of the environment

  • terminated (torch.Tensor) – Signals to indicate that episodes have terminated

  • truncated (torch.Tensor) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps



API (JAX)

class skrl.agents.jax.ppo.PPO(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None)

Bases: Agent

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None) → None

Proximal Policy Optimization (PPO)

  • models (dictionary of skrl.models.jax.Model) – Models used by the agent

  • memory (skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None) – Memory to store the transitions. If it is a tuple, the first element will be used for training and for the rest only the environment transitions will be added

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or jax.Device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary


Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) → None

Algorithm’s main update step

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

act(states: ndarray | jax.Array, timestep: int, timesteps: int) → ndarray | jax.Array

Process the environment’s states to make a decision (actions) using the main policy

  • states (np.ndarray or jax.Array) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps



Return type:

np.ndarray or jax.Array

init(trainer_cfg: Mapping[str, Any] | None = None) → None

Initialize the agent

post_interaction(timestep: int, timesteps: int) → None

Callback called after the interaction with the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) → None

Callback called before the interaction with the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: ndarray | jax.Array, actions: ndarray | jax.Array, rewards: ndarray | jax.Array, next_states: ndarray | jax.Array, terminated: ndarray | jax.Array, truncated: ndarray | jax.Array, infos: Any, timestep: int, timesteps: int) → None

Record an environment transition in memory

  • states (np.ndarray or jax.Array) – Observations/states of the environment used to make the decision

  • actions (np.ndarray or jax.Array) – Actions taken by the agent

  • rewards (np.ndarray or jax.Array) – Instant rewards achieved by the current actions

  • next_states (np.ndarray or jax.Array) – Next observations/states of the environment

  • terminated (np.ndarray or jax.Array) – Signals to indicate that episodes have terminated

  • truncated (np.ndarray or jax.Array) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps