Robust Policy Optimization (RPO)

RPO is a model-free, stochastic on-policy policy gradient algorithm that adds a uniform random perturbation to a base parameterized distribution. The perturbation helps the agent maintain a certain level of stochasticity throughout the training process.

Paper: Robust Policy Optimization in Deep Reinforcement Learning



Algorithm

Note

This algorithm is built on top of the PPO algorithm and simply adds the alpha hyperparameter to the policy input dictionary (under the "alpha" key). It is the user's responsibility to use this hyperparameter inside the policy model to modify the parameterized distribution, as in the following example.

import torch

from skrl.models.torch import GaussianMixin, Model


class Policy(GaussianMixin, Model):
    ...

    def compute(self, inputs, role):
        # compute the mean actions using the neural network
        mean_actions = self.net(inputs["states"])

        # perturb the mean actions by adding a uniform random sample from U(-alpha, alpha)
        rpo_alpha = inputs["alpha"]
        perturbation = torch.zeros_like(mean_actions).uniform_(-rpo_alpha, rpo_alpha)
        # out-of-place addition, so the network output is not modified in place
        mean_actions = mean_actions + perturbation

        return mean_actions, self.log_std_parameter, {}
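To get a feel for what the perturbation does, the following standard-library sketch draws scalar actions from a Gaussian whose mean is shifted by a sample from U(-alpha, alpha). It is not skrl code: `sample_rpo_action` and the values of `mean`, `std` and `alpha` are illustrative assumptions. When the two noise sources are independent, the action variance should be close to std² + alpha²/3, which the sketch checks empirically.

```python
import random
import statistics


def sample_rpo_action(mean, std, alpha, rng):
    """Draw one scalar action from a Gaussian whose mean is shifted by U(-alpha, alpha)."""
    perturbed_mean = mean + rng.uniform(-alpha, alpha)
    return rng.gauss(perturbed_mean, std)


rng = random.Random(0)  # fixed seed so the run is repeatable
mean, std, alpha = 0.0, 0.2, 0.5  # illustrative values, not library defaults (except alpha)

samples = [sample_rpo_action(mean, std, alpha, rng) for _ in range(100_000)]
empirical_var = statistics.pvariance(samples)
# variance of the Gaussian plus variance of a uniform distribution on [-alpha, alpha]
expected_var = std**2 + alpha**2 / 3

print(f"empirical variance: {empirical_var:.4f}")
print(f"expected variance:  {expected_var:.4f}")
```

Even if the learned standard deviation shrinks toward zero, the alpha²/3 term remains, which is one way to read the claim that the policy keeps a certain level of stochasticity.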

Algorithm implementation

Main notation/symbols:
- policy function approximator (\(\pi_\theta\)), value function approximator (\(V_\phi\))
- states (\(s\)), actions (\(a\)), rewards (\(r\)), next states (\(s'\)), dones (\(d\))
- values (\(V\)), advantages (\(A\)), returns (\(R\))
- log probabilities (\(logp\))
- loss (\(L\))

Learning algorithm


compute_gae(...)
def \(\;f_{GAE} (r, d, V, V_{_{last}}') \;\rightarrow\; R, A:\)
    \(adv \leftarrow 0\)
    \(A \leftarrow \text{zeros}(r)\)
    # advantages computation
    FOR each reverse iteration \(i\) up to the number of rows in \(r\) DO
        IF \(i\) is not the last row of \(r\) THEN
            \(V_i' \leftarrow V_{i+1}\)
        ELSE
            \(V_i' \leftarrow V_{_{last}}'\)
        \(adv \leftarrow r_i - V_i \, +\) discount_factor \(\neg d_i \; (V_i' \, +\) lambda \(adv)\)
        \(A_i \leftarrow adv\)
    # returns computation
    \(R \leftarrow A + V\)
    # normalize advantages
    \(A \leftarrow \dfrac{A - \bar{A}}{A_\sigma + 10^{-8}}\)
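As a rough illustration of the advantage computation, here is a plain-Python sketch of the standard GAE recursion (the one-step TD error plus the discounted, lambda-weighted advantage of the next step). It works on lists rather than tensors. The function name, argument names and the use of the population standard deviation for normalization are assumptions made for this example; a tensor library may default to the sample estimate instead.

```python
def compute_gae(rewards, dones, values, last_value,
                discount_factor=0.99, lambda_coefficient=0.95):
    """Return (returns, normalized advantages) for one rollout stored as Python lists."""
    n = len(rewards)
    advantages = [0.0] * n
    advantage = 0.0
    # iterate backwards so each step can reuse the advantage of the following step
    for i in reversed(range(n)):
        next_value = values[i + 1] if i < n - 1 else last_value
        not_done = 0.0 if dones[i] else 1.0
        advantage = (rewards[i] - values[i]
                     + discount_factor * not_done * (next_value + lambda_coefficient * advantage))
        advantages[i] = advantage
    # returns are computed from the unnormalized advantages
    returns = [a + v for a, v in zip(advantages, values)]
    # normalize advantages (population standard deviation assumed here)
    mean = sum(advantages) / n
    std = (sum((a - mean) ** 2 for a in advantages) / n) ** 0.5
    normalized = [(a - mean) / (std + 1e-8) for a in advantages]
    return returns, normalized


# a three-step episode that terminates at the last step
returns, advantages = compute_gae(
    rewards=[1.0, 1.0, 1.0],
    dones=[False, False, True],
    values=[0.5, 0.5, 0.5],
    last_value=0.0,
    discount_factor=1.0,
    lambda_coefficient=1.0,
)
print(returns)
```

With discount_factor and lambda both set to 1, the returns reduce to the undiscounted reward-to-go, so the example prints [3.0, 2.0, 1.0].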

_update(...)
# compute returns and advantages
\(V_{_{last}}' \leftarrow V_\phi(s')\)
\(R, A \leftarrow f_{GAE}(r, d, V, V_{_{last}}')\)
# sample mini-batches from memory
[[\(s, a, logp, V, R, A\)]] \(\leftarrow\) states, actions, log_prob, values, returns, advantages
# learning epochs
FOR each learning epoch up to learning_epochs DO
    # mini-batches loop
    FOR each mini-batch [\(s, a, logp, V, R, A\)] up to mini_batches DO
        \(logp' \leftarrow \pi_\theta(s, a)\)
        # compute approximate KL divergence
        \(ratio \leftarrow logp' - logp\)
        \(KL_{_{divergence}} \leftarrow \frac{1}{N} \sum_{i=1}^N ((e^{ratio} - 1) - ratio)\)
        # early stopping with KL divergence
        IF \(KL_{_{divergence}} >\) kl_threshold THEN
            BREAK LOOP
        # compute entropy loss
        IF entropy computation is enabled THEN
            \({L}_{entropy} \leftarrow \, -\) entropy_loss_scale \(\frac{1}{N} \sum_{i=1}^N \pi_{\theta_{entropy}}\)
        ELSE
            \({L}_{entropy} \leftarrow 0\)
        # compute policy loss
        \(ratio \leftarrow e^{logp' - logp}\)
        \(L_{_{surrogate}} \leftarrow A \; ratio\)
        \(L_{_{clipped\,surrogate}} \leftarrow A \; \text{clip}(ratio, 1 - c, 1 + c) \qquad\) with \(c\) as ratio_clip
        \(L^{clip}_{\pi_\theta} \leftarrow - \frac{1}{N} \sum_{i=1}^N \min(L_{_{surrogate}}, L_{_{clipped\,surrogate}})\)
        # compute value loss
        \(V_{_{predicted}} \leftarrow V_\phi(s)\)
        IF clip_predicted_values is enabled THEN
            \(V_{_{predicted}} \leftarrow V + \text{clip}(V_{_{predicted}} - V, -c, c) \qquad\) with \(c\) as value_clip
        \(L_{V_\phi} \leftarrow\) value_loss_scale \(\frac{1}{N} \sum_{i=1}^N (R - V_{_{predicted}})^2\)
        # optimization step
        reset \(\text{optimizer}_{\theta, \phi}\)
        \(\nabla_{\theta, \, \phi} (L^{clip}_{\pi_\theta} + {L}_{entropy} + L_{V_\phi})\)
        \(\text{clip}(\lVert \nabla_{\theta, \, \phi} \rVert)\) with grad_norm_clip
        step \(\text{optimizer}_{\theta, \phi}\)
    # update learning rate
    IF there is a learning_rate_scheduler THEN
        step \(\text{scheduler}_{\theta, \phi} (\text{optimizer}_{\theta, \phi})\)
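The three scalar quantities in the update (approximate KL divergence, clipped surrogate loss and value loss) can be sketched in plain Python as follows. This is a list-based illustration of the formulas in the pseudocode, not the library's tensor implementation; the function and argument names are assumptions chosen to mirror the configuration keys.

```python
import math


def approx_kl(new_logp, old_logp):
    """Mean of (e^ratio - 1) - ratio, with ratio = logp' - logp."""
    log_ratios = [n - o for n, o in zip(new_logp, old_logp)]
    return sum((math.exp(lr) - 1.0) - lr for lr in log_ratios) / len(log_ratios)


def clipped_policy_loss(new_logp, old_logp, advantages, ratio_clip=0.2):
    """Negative mean of min(A * ratio, A * clip(ratio, 1 - c, 1 + c))."""
    total = 0.0
    for n, o, adv in zip(new_logp, old_logp, advantages):
        ratio = math.exp(n - o)
        clipped_ratio = min(max(ratio, 1.0 - ratio_clip), 1.0 + ratio_clip)
        total += min(adv * ratio, adv * clipped_ratio)
    return -total / len(advantages)


def value_loss(predicted, old_values, returns, value_clip=0.2,
               clip_predicted_values=False, value_loss_scale=1.0):
    """Scaled mean squared error between returns and (optionally clipped) predictions."""
    total = 0.0
    for pred, old, ret in zip(predicted, old_values, returns):
        if clip_predicted_values:
            # keep the new prediction within value_clip of the stored value
            pred = old + min(max(pred - old, -value_clip), value_clip)
        total += (ret - pred) ** 2
    return value_loss_scale * total / len(returns)


# two samples whose probability ratios are 1.5 and 0.5
old_logp = [-1.0, -1.0]
new_logp = [-1.0 + math.log(1.5), -1.0 + math.log(0.5)]
policy_loss = clipped_policy_loss(new_logp, old_logp, advantages=[1.0, -1.0])
kl = approx_kl(new_logp, old_logp)
```

In this example both samples hit the clip: the first contributes 1.2 instead of 1.5 and the second contributes -0.8 instead of -0.5, so the loss is approximately -0.2.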

Usage

Note

Support for recurrent neural networks (RNN, LSTM, GRU and any other variant) is implemented in a separate file (rpo_rnn.py) to keep the standard implementation (rpo.py) readable.

# import the agent and its default configuration
from skrl.agents.torch.rpo import RPO, RPO_DEFAULT_CONFIG

# instantiate the agent's models
models = {}
models["policy"] = ...
models["value"] = ...  # only required during training

# adjust some configuration if necessary
cfg_agent = RPO_DEFAULT_CONFIG.copy()
cfg_agent["<KEY>"] = ...

# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
agent = RPO(models=models,
            memory=memory,  # only required during training
            cfg=cfg_agent,
            observation_space=env.observation_space,
            action_space=env.action_space,
            device=env.device)
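One detail of the snippet above is worth a short illustration: dict.copy() makes a shallow copy, so nested dictionaries such as "experiment" are shared with the original. The sketch below uses a small stand-in dictionary (`DEFAULT_CONFIG`, with illustrative values only, not the real default configuration) to show the difference from copy.deepcopy. Whether the sharing matters in practice depends on how the configuration is consumed.

```python
import copy

# stand-in for a default configuration with a nested dictionary (illustrative values only)
DEFAULT_CONFIG = {
    "alpha": 0.5,
    "experiment": {"write_interval": 250, "checkpoint_interval": 1000},
}

# shallow copy: top-level keys are independent, nested dictionaries are shared
shallow_cfg = DEFAULT_CONFIG.copy()
shallow_cfg["alpha"] = 0.1                          # does not touch DEFAULT_CONFIG
shallow_cfg["experiment"]["write_interval"] = 50    # also changes DEFAULT_CONFIG
shared_write_interval = DEFAULT_CONFIG["experiment"]["write_interval"]

# restore the stand-in default before the second experiment
DEFAULT_CONFIG["experiment"]["write_interval"] = 250

# deep copy: nested dictionaries are duplicated as well
deep_cfg = copy.deepcopy(DEFAULT_CONFIG)
deep_cfg["experiment"]["write_interval"] = 50       # DEFAULT_CONFIG is unaffected
```

Assigning a whole new dictionary to a nested key (for example cfg["experiment"] = {...}) on a shallow copy also leaves the original untouched.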

Configuration and hyperparameters

RPO_DEFAULT_CONFIG = {
    "rollouts": 16,                 # number of rollouts before updating
    "learning_epochs": 8,           # number of learning epochs during each update
    "mini_batches": 2,              # number of mini batches during each learning epoch

    "alpha": 0.5,                   # amount of uniform random perturbation on the mean actions: U(-alpha, alpha)
    "discount_factor": 0.99,        # discount factor (gamma)
    "lambda": 0.95,                 # TD(lambda) coefficient (lam) for computing returns and advantages

    "learning_rate": 1e-3,                  # learning rate
    "learning_rate_scheduler": None,        # learning rate scheduler class (see torch.optim.lr_scheduler)
    "learning_rate_scheduler_kwargs": {},   # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})

    "state_preprocessor": None,             # state preprocessor class (see skrl.resources.preprocessors)
    "state_preprocessor_kwargs": {},        # state preprocessor's kwargs (e.g. {"size": env.observation_space})
    "value_preprocessor": None,             # value preprocessor class (see skrl.resources.preprocessors)
    "value_preprocessor_kwargs": {},        # value preprocessor's kwargs (e.g. {"size": 1})

    "random_timesteps": 0,          # random exploration steps
    "learning_starts": 0,           # learning starts after this many steps

    "grad_norm_clip": 0.5,              # clipping coefficient for the norm of the gradients
    "ratio_clip": 0.2,                  # clipping coefficient for computing the clipped surrogate objective
    "value_clip": 0.2,                  # clipping coefficient for computing the value loss (if clip_predicted_values is True)
    "clip_predicted_values": False,     # clip predicted values during value loss computation

    "entropy_loss_scale": 0.0,      # entropy loss scaling factor
    "value_loss_scale": 1.0,        # value loss scaling factor

    "kl_threshold": 0,              # KL divergence threshold for early stopping

    "rewards_shaper": None,         # rewards shaping function: Callable(reward, timestep, timesteps) -> reward
    "time_limit_bootstrap": False,  # bootstrap at timeout termination (episode truncation)

    "experiment": {
        "directory": "",            # experiment's parent directory
        "experiment_name": "",      # experiment name
        "write_interval": 250,      # TensorBoard writing interval (timesteps)

        "checkpoint_interval": 1000,        # interval for checkpoints (timesteps)
        "store_separately": False,          # whether to store checkpoints separately

        "wandb": False,             # whether to use Weights & Biases
        "wandb_kwargs": {}          # wandb kwargs (see https://docs.wandb.ai/ref/python/init)
    }
}
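The rewards_shaper entry expects a callable with the signature (reward, timestep, timesteps) -> reward. A minimal sketch of such a function is shown below. The name `scale_rewards`, the scaling factor and the stand-in `cfg_agent` dictionary are illustrative assumptions; plain floats are used here, although the same expression would apply to array-like rewards.

```python
def scale_rewards(reward, timestep, timesteps):
    """Scale the reward by a constant factor; timestep and timesteps are unused here."""
    return reward * 0.1


# stand-in for an agent configuration (only the relevant key is shown)
cfg_agent = {"rewards_shaper": None}
cfg_agent["rewards_shaper"] = scale_rewards

shaped = cfg_agent["rewards_shaper"](10.0, timestep=0, timesteps=1000)
```

The two timestep arguments make schedule-dependent shaping possible, for example a bonus that decays as timestep approaches timesteps.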

Spaces

The implementation supports the following Gym / Gymnasium spaces:

| Gym/Gymnasium spaces | Observation | Action |
|---|---|---|
| Discrete | \(\square\) | \(\square\) |
| MultiDiscrete | \(\square\) | \(\square\) |
| Box | \(\blacksquare\) | \(\blacksquare\) |
| Dict | \(\blacksquare\) | \(\square\) |


Models

The implementation uses 1 continuous stochastic and 1 deterministic function approximator. These function approximators (models) must be collected in a dictionary and passed to the class constructor as the models argument.

| Notation | Concept | Key | Input shape | Output shape | Type |
|---|---|---|---|---|---|
| \(\pi_\theta(s)\) | Policy | "policy" | observation | action | Gaussian / MultivariateGaussian |
| \(V_\phi(s)\) | Value | "value" | observation | 1 | Deterministic |


Features

Support for advanced features is described in the following table.

| Feature | Support and remarks | pytorch | jax |
|---|---|---|---|
| Shared model | for Policy and Value | \(\blacksquare\) | \(\square\) |
| RNN support | RNN, LSTM, GRU and any other variant | \(\blacksquare\) | \(\square\) |


API (PyTorch)

skrl.agents.torch.rpo.RPO_DEFAULT_CONFIG

alias of {‘alpha’: 0.5, ‘clip_predicted_values’: False, ‘discount_factor’: 0.99, ‘entropy_loss_scale’: 0.0, ‘experiment’: {‘checkpoint_interval’: 1000, ‘directory’: ‘’, ‘experiment_name’: ‘’, ‘store_separately’: False, ‘wandb’: False, ‘wandb_kwargs’: {}, ‘write_interval’: 250}, ‘grad_norm_clip’: 0.5, ‘kl_threshold’: 0, ‘lambda’: 0.95, ‘learning_epochs’: 8, ‘learning_rate’: 0.001, ‘learning_rate_scheduler’: None, ‘learning_rate_scheduler_kwargs’: {}, ‘learning_starts’: 0, ‘mini_batches’: 2, ‘random_timesteps’: 0, ‘ratio_clip’: 0.2, ‘rewards_shaper’: None, ‘rollouts’: 16, ‘state_preprocessor’: None, ‘state_preprocessor_kwargs’: {}, ‘time_limit_bootstrap’: False, ‘value_clip’: 0.2, ‘value_loss_scale’: 1.0, ‘value_preprocessor’: None, ‘value_preprocessor_kwargs’: {}}

class skrl.agents.torch.rpo.RPO(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)

Bases: Agent

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) -> None

Robust Policy Optimization (RPO)

https://arxiv.org/abs/2212.07536

Parameters:
  • models (dictionary of skrl.models.torch.Model) – Models used by the agent

  • memory (skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None) – Memory used to store the transitions. If it is a tuple, the first element is used for training; the remaining elements only receive the environment transitions

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary

Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) -> None

Algorithm’s main update step

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

act(states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tensor

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
  • states (torch.Tensor) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

Returns:

Actions

Return type:

torch.Tensor

init(trainer_cfg: Mapping[str, Any] | None = None) -> None

Initialize the agent

post_interaction(timestep: int, timesteps: int) -> None

Callback called after the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) -> None

Callback called before the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) -> None

Record an environment transition in memory

Parameters:
  • states (torch.Tensor) – Observations/states of the environment used to make the decision

  • actions (torch.Tensor) – Actions taken by the agent

  • rewards (torch.Tensor) – Instant rewards achieved by the current actions

  • next_states (torch.Tensor) – Next observations/states of the environment

  • terminated (torch.Tensor) – Signals to indicate that episodes have terminated

  • truncated (torch.Tensor) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps
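To show how the methods documented above fit together, the sketch below drives a stand-in agent through a simple interaction loop. Everything in it is hypothetical: `RecordingAgent` only records which method was called, `dummy_env_step` returns fixed values, and the call order (pre_interaction, act, environment step, record_transition, post_interaction) is an assumption about a typical training loop rather than something stated on this page.

```python
class RecordingAgent:
    """Hypothetical stand-in that only records the order in which its methods are called."""

    def __init__(self):
        self.calls = []

    def pre_interaction(self, timestep, timesteps):
        self.calls.append("pre_interaction")

    def act(self, states, timestep, timesteps):
        self.calls.append("act")
        return [0.0 for _ in states]  # dummy actions, one per state

    def record_transition(self, states, actions, rewards, next_states,
                          terminated, truncated, infos, timestep, timesteps):
        self.calls.append("record_transition")

    def post_interaction(self, timestep, timesteps):
        self.calls.append("post_interaction")


def dummy_env_step(actions):
    """Hypothetical environment step returning fixed values for each action."""
    n = len(actions)
    return [0.0] * n, [1.0] * n, [False] * n, [False] * n, {}


def run(agent, timesteps):
    """Drive the agent for a number of timesteps using the assumed call order."""
    states = [0.0]
    for timestep in range(timesteps):
        agent.pre_interaction(timestep=timestep, timesteps=timesteps)
        actions = agent.act(states, timestep=timestep, timesteps=timesteps)
        next_states, rewards, terminated, truncated, infos = dummy_env_step(actions)
        agent.record_transition(states, actions, rewards, next_states,
                                terminated, truncated, infos, timestep, timesteps)
        agent.post_interaction(timestep=timestep, timesteps=timesteps)
        states = next_states


agent = RecordingAgent()
run(agent, timesteps=2)
print(agent.calls)
```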

class skrl.agents.torch.rpo.RPO_RNN(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)

Bases: Agent

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) -> None

Robust Policy Optimization (RPO) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.)

https://arxiv.org/abs/2212.07536

Parameters:
  • models (dictionary of skrl.models.torch.Model) – Models used by the agent

  • memory (skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None) – Memory used to store the transitions. If it is a tuple, the first element is used for training; the remaining elements only receive the environment transitions

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary

Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) -> None

Algorithm’s main update step

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

act(states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tensor

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
  • states (torch.Tensor) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

Returns:

Actions

Return type:

torch.Tensor

init(trainer_cfg: Mapping[str, Any] | None = None) -> None

Initialize the agent

post_interaction(timestep: int, timesteps: int) -> None

Callback called after the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) -> None

Callback called before the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) -> None

Record an environment transition in memory

Parameters:
  • states (torch.Tensor) – Observations/states of the environment used to make the decision

  • actions (torch.Tensor) – Actions taken by the agent

  • rewards (torch.Tensor) – Instant rewards achieved by the current actions

  • next_states (torch.Tensor) – Next observations/states of the environment

  • terminated (torch.Tensor) – Signals to indicate that episodes have terminated

  • truncated (torch.Tensor) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps


API (JAX)

skrl.agents.jax.rpo.RPO_DEFAULT_CONFIG

alias of {‘alpha’: 0.5, ‘clip_predicted_values’: False, ‘discount_factor’: 0.99, ‘entropy_loss_scale’: 0.0, ‘experiment’: {‘checkpoint_interval’: 1000, ‘directory’: ‘’, ‘experiment_name’: ‘’, ‘store_separately’: False, ‘wandb’: False, ‘wandb_kwargs’: {}, ‘write_interval’: 250}, ‘grad_norm_clip’: 0.5, ‘kl_threshold’: 0, ‘lambda’: 0.95, ‘learning_epochs’: 8, ‘learning_rate’: 0.001, ‘learning_rate_scheduler’: None, ‘learning_rate_scheduler_kwargs’: {}, ‘learning_starts’: 0, ‘mini_batches’: 2, ‘random_timesteps’: 0, ‘ratio_clip’: 0.2, ‘rewards_shaper’: None, ‘rollouts’: 16, ‘state_preprocessor’: None, ‘state_preprocessor_kwargs’: {}, ‘time_limit_bootstrap’: False, ‘value_clip’: 0.2, ‘value_loss_scale’: 1.0, ‘value_preprocessor’: None, ‘value_preprocessor_kwargs’: {}}

class skrl.agents.jax.rpo.RPO(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None)

Bases: Agent

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None) -> None

Robust Policy Optimization (RPO)

https://arxiv.org/abs/2212.07536

Parameters:
  • models (dictionary of skrl.models.jax.Model) – Models used by the agent

  • memory (skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None) – Memory used to store the transitions. If it is a tuple, the first element is used for training; the remaining elements only receive the environment transitions

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or jax.Device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary

Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) -> None

Algorithm’s main update step

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

act(states: ndarray | jax.Array, timestep: int, timesteps: int) -> ndarray | jax.Array

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
  • states (np.ndarray or jax.Array) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

Returns:

Actions

Return type:

np.ndarray or jax.Array

init(trainer_cfg: Mapping[str, Any] | None = None) -> None

Initialize the agent

post_interaction(timestep: int, timesteps: int) -> None

Callback called after the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) -> None

Callback called before the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: ndarray | jax.Array, actions: ndarray | jax.Array, rewards: ndarray | jax.Array, next_states: ndarray | jax.Array, terminated: ndarray | jax.Array, truncated: ndarray | jax.Array, infos: Any, timestep: int, timesteps: int) -> None

Record an environment transition in memory

Parameters:
  • states (np.ndarray or jax.Array) – Observations/states of the environment used to make the decision

  • actions (np.ndarray or jax.Array) – Actions taken by the agent

  • rewards (np.ndarray or jax.Array) – Instant rewards achieved by the current actions

  • next_states (np.ndarray or jax.Array) – Next observations/states of the environment

  • terminated (np.ndarray or jax.Array) – Signals to indicate that episodes have terminated

  • truncated (np.ndarray or jax.Array) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps