Trust Region Policy Optimization (TRPO)

TRPO is a model-free, stochastic on-policy policy gradient algorithm that uses an iterative procedure to optimize the policy, with a theoretical guarantee of monotonic improvement.

Paper: Trust Region Policy Optimization



Algorithm

For each iteration do
\(\bullet \;\) Collect, in a rollout memory, a set of states \(s\), actions \(a\), rewards \(r\), dones \(d\), log probabilities \(logp\) and values \(V\) on policy using \(\pi_\theta\) and \(V_\phi\)
\(\bullet \;\) Estimate returns \(R\) and advantages \(A\) using Generalized Advantage Estimation (GAE(\(\lambda\))) from the collected data [\(r, d, V\)]
\(\bullet \;\) Compute the gradient \(g\) of the surrogate objective (policy loss) and the Hessian \(H\) of the \(KL\) divergence with respect to the policy parameters \(\theta\) (\(H\) is only accessed through Hessian-vector products, never formed explicitly)
\(\bullet \;\) Compute the search direction \(\; x \approx H^{-1}g \;\) using the conjugate gradient method
\(\bullet \;\) Compute the maximal (full) step \(\; \beta = \sqrt{\dfrac{2 \delta}{x^T H x}} \, x \;\), where \(\delta\) is the desired (maximum) \(KL\) divergence and \(\; \sqrt{\frac{2 \delta}{x^T H x}} \;\) is the step size
\(\bullet \;\) Perform a backtracking line search, shrinking the coefficient \(\alpha\) exponentially, to find the final policy update \(\; \theta_{new} = \theta + \alpha \; \beta \;\) that improves the surrogate objective and satisfies the \(KL\) divergence constraint
\(\bullet \;\) Update the value function \(V_\phi\) using the computed returns \(R\)
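The factor \(2\delta\) in the full step comes from the quadratic approximation of the KL divergence, \(KL \approx \frac{1}{2} \beta^T H \beta\). Scaling the direction \(x\) by \(\sqrt{2\delta / x^T H x}\) makes this approximation equal \(\delta\) exactly, since \(\frac{1}{2} \cdot \frac{2\delta}{x^T H x} \cdot x^T H x = \delta\). The plain-Python sketch below checks this on a hypothetical 2x2 matrix standing in for \(H\). In practice \(H\) is never formed explicitly, and the function name full_step is illustrative only.

```python
import math


def full_step(x, H, max_kl):
    """Scale direction x so that the quadratic KL model 0.5 * beta^T H beta equals max_kl."""
    # H @ x for a small dense matrix stored as a list of rows
    Hx = [sum(h_ij * x_j for h_ij, x_j in zip(row, x)) for row in H]
    xHx = sum(x_i * hx_i for x_i, hx_i in zip(x, Hx))
    step_size = math.sqrt(2.0 * max_kl / xHx)
    return [step_size * x_i for x_i in x]


# hypothetical symmetric positive-definite stand-in for the KL Hessian, and a search direction
H = [[2.0, 0.5], [0.5, 1.0]]
x = [1.0, -1.0]
beta = full_step(x, H, max_kl=0.01)

# the quadratic KL model evaluated at the full step recovers max_kl
Hb = [sum(h * b for h, b in zip(row, beta)) for row in H]
print(0.5 * sum(b * hb for b, hb in zip(beta, Hb)))  # ≈ 0.01
```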

Algorithm implementation


Learning algorithm


compute_gae(...)
def \(\;f_{GAE} (r, d, V, V_{_{last}}') \;\rightarrow\; R, A:\)
\(adv \leftarrow 0\)
\(A \leftarrow \text{zeros}(r)\)
# advantages computation
FOR each reverse iteration \(i\) up to the number of rows in \(r\) DO
IF \(i\) is not the last row of \(r\) THEN
\(V_i' \leftarrow V_{i+1}\)
ELSE
\(V_i' \leftarrow V_{_{last}}'\)
\(adv \leftarrow r_i - V_i \, +\) discount_factor \(\neg d_i \; (V_i' \, +\) lambda \(adv)\)
\(A_i \leftarrow adv\)
# returns computation
\(R \leftarrow A + V\)
# normalize advantages
\(A \leftarrow \dfrac{A - \bar{A}}{A_\sigma + 10^{-8}}\)
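The GAE recursion above can be sketched in plain Python for a single environment. This sketch assumes that rewards, dones and values are equal-length Python lists. The library itself operates on tensors, so this is an illustration of the recursion rather than its actual code. Each step computes \(A_i = r_i - V_i + \gamma \, \neg d_i \, (V_{i+1} + \lambda A_{i+1})\), so a done flag stops both the value bootstrap and the propagation of later advantages.

```python
import statistics


def compute_gae(rewards, dones, values, last_value, discount_factor=0.99, lambda_coefficient=0.95):
    """Return (returns, normalized advantages) for one environment's rollout (plain lists)."""
    advantages = [0.0] * len(rewards)
    advantage = 0.0
    # walk the rollout backwards so each step can reuse the next step's advantage
    for i in reversed(range(len(rewards))):
        next_value = values[i + 1] if i < len(rewards) - 1 else last_value
        not_done = 0.0 if dones[i] else 1.0
        advantage = rewards[i] - values[i] + discount_factor * not_done * (next_value + lambda_coefficient * advantage)
        advantages[i] = advantage
    # returns = advantages + value baseline
    returns = [a + v for a, v in zip(advantages, values)]
    # normalize advantages (population standard deviation in this sketch)
    mean = statistics.fmean(advantages)
    std = statistics.pstdev(advantages)
    advantages = [(a - mean) / (std + 1e-8) for a in advantages]
    return returns, advantages


# the last step is terminal, so last_value (100.0) is never bootstrapped
returns, advantages = compute_gae(
    rewards=[1.0, 1.0, 1.0], dones=[False, False, True], values=[0.0, 0.0, 0.0],
    last_value=100.0, discount_factor=0.5, lambda_coefficient=1.0,
)
print(returns)  # [1.75, 1.5, 1.0]
```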

surrogate_loss(...)
def \(\;f_{Loss} (\pi_\theta, s, a, logp, A) \;\rightarrow\; L_{\pi_\theta}:\)
\(logp' \leftarrow \pi_\theta(s, a)\)
\(L_{\pi_\theta} \leftarrow \frac{1}{N} \sum_{i=1}^N A \; e^{(logp' - logp)}\)
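The surrogate objective weights each advantage by the probability ratio \(e^{logp' - logp}\) between the current and the data-collecting policy. The following plain-Python sketch assumes the log probabilities are given as lists, and the function name is illustrative. Although it is called a loss, TRPO tries to increase this quantity: the line search only accepts a step when the new value exceeds the old one.

```python
import math


def surrogate_loss(new_log_probs, old_log_probs, advantages):
    """Mean of A * exp(logp' - logp): advantages weighted by the probability ratio."""
    ratios = [math.exp(new - old) for new, old in zip(new_log_probs, old_log_probs)]
    return sum(a * r for a, r in zip(advantages, ratios)) / len(advantages)


# identical log probabilities -> all ratios are 1 -> the plain mean advantage
print(surrogate_loss([0.1, -0.2], [0.1, -0.2], [1.0, 3.0]))  # 2.0
```

Making an action with positive advantage more likely (raising its log probability) raises the objective, which is the direction the policy update pushes toward.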

conjugate_gradient(...) (See conjugate gradient method)
def \(\;f_{CG} (\pi_\theta, s, b) \;\rightarrow\; x:\)
\(x \leftarrow \text{zeros}(b)\)
\(r \leftarrow b\)
\(p \leftarrow b\)
\(rr_{old} \leftarrow r \cdot r\)
FOR each iteration up to conjugate_gradient_steps DO
\(\alpha \leftarrow \dfrac{rr_{old}}{p \cdot f_{Ax}(\pi_\theta, s, p)}\)
\(x \leftarrow x + \alpha \; p\)
\(r \leftarrow r - \alpha \; f_{Ax}(\pi_\theta, s, p)\)
\(rr_{new} \leftarrow r \cdot r\)
IF \(rr_{new} <\) residual tolerance THEN
BREAK LOOP
\(p \leftarrow r + \dfrac{rr_{new}}{rr_{old}} \; p\)
\(rr_{old} \leftarrow rr_{new}\)
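The conjugate gradient method only needs a function that returns matrix-vector products, which is why TRPO never has to build \(H\). The following plain-Python sketch accepts any matvec callable. It assumes a symmetric positive-definite matrix, and the helper names dot and conjugate_gradient are illustrative. Each iteration computes the product with the current direction \(p\) once and reuses it for both \(\alpha\) and the residual update.

```python
def dot(u, v):
    """Inner product of two equal-length lists."""
    return sum(ui * vi for ui, vi in zip(u, v))


def conjugate_gradient(matvec, b, steps=10, residual_tolerance=1e-10):
    """Approximately solve A x = b for a symmetric positive-definite A, given only v -> A v."""
    x = [0.0] * len(b)
    r = list(b)  # residual b - A x (x starts at zero)
    p = list(b)  # current search direction
    rr_old = dot(r, r)
    for _ in range(steps):
        Ap = matvec(p)  # one matrix-vector product per iteration, reused twice below
        alpha = rr_old / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rr_new = dot(r, r)
        if rr_new < residual_tolerance:
            break
        p = [ri + (rr_new / rr_old) * pi for ri, pi in zip(r, p)]
        rr_old = rr_new
    return x


# toy SPD matrix standing in for the (damped) Fisher matrix
A = [[4.0, 1.0], [1.0, 3.0]]
x = conjugate_gradient(lambda v: [dot(row, v) for row in A], b=[1.0, 2.0])
print(x)  # close to [1/11, 7/11]
```

In TRPO the matvec role is played by the Fisher-vector product \(f_{Ax}(\pi_\theta, s, \cdot)\). On an \(n\)-dimensional problem, exact-arithmetic CG converges in at most \(n\) iterations, so the 2x2 example finishes after two steps.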

fisher_vector_product(...) (See fisher vector product in TRPO)
def \(\;f_{Ax} (\pi_\theta, s, v) \;\rightarrow\; hv:\)
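# the first policy's outputs are treated as constants (no gradient); otherwise the KL would be identically zero
\(kl \leftarrow f_{KL}(\pi_\theta, \pi_\theta, s)\)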
\(g_{kl} \leftarrow \nabla_\theta kl\)
\(g_{kl_{flat}} \leftarrow \text{flatten}(g_{kl})\)
\(g_{hv} \leftarrow \nabla_\theta (g_{kl_{flat}} \; v)\)
\(g_{hv_{flat}} \leftarrow \text{flatten}(g_{hv})\)
\(hv \leftarrow g_{hv_{flat}} +\) damping \(v\)

def \(\;f_{KL} (\pi_{\theta 1}, \pi_{\theta 2}, s) \;\rightarrow\; kl:\)
\(\mu_1, \log\sigma_1 \leftarrow \pi_{\theta 1}(s)\)
\(\mu_2, \log\sigma_2 \leftarrow \pi_{\theta 2}(s)\)
\(kl \leftarrow \log\sigma_2 - \log\sigma_1 + \frac{1}{2} \dfrac{(e^{\log\sigma_1})^2 + (\mu_1 - \mu_2)^2}{(e^{\log\sigma_2})^2} - \frac{1}{2}\)
\(kl \leftarrow \frac{1}{N} \sum_{i=1}^N \, (\sum_{dim} kl)\)
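For diagonal Gaussian policies, the standard closed form per action dimension is \(KL(\mathcal{N}_1 \| \mathcal{N}_2) = \log\sigma_2 - \log\sigma_1 + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2 \sigma_2^2} - \frac{1}{2}\). It is summed over dimensions and then averaged over states. The plain-Python sketch below assumes each policy output is given as a (mean list, log standard deviation list) pair, and the function names are illustrative.

```python
import math


def gaussian_kl(mu_1, log_std_1, mu_2, log_std_2):
    """KL(N1 || N2) for diagonal Gaussians, summed over the action dimensions."""
    kl = 0.0
    for m1, ls1, m2, ls2 in zip(mu_1, log_std_1, mu_2, log_std_2):
        var_1, var_2 = math.exp(2.0 * ls1), math.exp(2.0 * ls2)
        kl += ls2 - ls1 + (var_1 + (m1 - m2) ** 2) / (2.0 * var_2) - 0.5
    return kl


def mean_gaussian_kl(batch_1, batch_2):
    """Average the per-state KL; each batch is a list of (mu, log_std) pairs, one per state."""
    return sum(gaussian_kl(*p1, *p2) for p1, p2 in zip(batch_1, batch_2)) / len(batch_1)


print(gaussian_kl([0.0], [0.0], [1.0], [0.0]))  # 0.5 (unit variances, means one apart)
```

The KL is zero for identical distributions and positive otherwise. This property is what makes \(kl < \delta\) a meaningful trust-region test in the line search.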

_update(...)
# compute returns and advantages
\(V_{_{last}}' \leftarrow V_\phi(s')\)
\(R, A \leftarrow f_{GAE}(r, d, V, V_{_{last}}')\)
# sample all from memory
[[\(s, a, logp, A\)]] \(\leftarrow\) states, actions, log_prob, advantages
# compute policy loss gradient
\(L_{\pi_\theta} \leftarrow f_{Loss}(\pi_\theta, s, a, logp, A)\)
\(g \leftarrow \nabla_{\theta} L_{\pi_\theta}\)
\(g_{_{flat}} \leftarrow \text{flatten}(g)\)
# compute the search direction using the conjugate gradient algorithm
\(search_{direction} \leftarrow f_{CG}(\pi_\theta, s, g_{_{flat}})\)
# compute step size and full step
\(xHx \leftarrow search_{direction} \; f_{Ax}(\pi_\theta, s, search_{direction})\)
\(step_{size} \leftarrow \sqrt{\dfrac{2 \, \delta}{xHx}} \qquad\) with \(\; \delta\) as max_kl_divergence
\(\beta \leftarrow step_{size} \; search_{direction}\)
# backtracking line search
\(flag_{restore} \leftarrow \text{True}\)
\(\pi_{\theta_{backup}} \leftarrow \pi_\theta\)
\(\theta \leftarrow \text{get_parameters}(\pi_\theta)\)
\(I_{expected} \leftarrow g_{_{flat}} \; \beta\)
FOR \(\alpha \leftarrow\) step_fraction \(\; 0.5^i \;\) with \(i = 0, 1, 2, ...\) up to max_backtrack_steps DO
\(\theta_{new} \leftarrow \theta + \alpha \; \beta\)
\(\pi_\theta \leftarrow \text{set_parameters}(\theta_{new})\)
\(I_{expected} \leftarrow \alpha \; (g_{_{flat}} \; \beta)\)
\(kl \leftarrow f_{KL}(\pi_{\theta_{backup}}, \pi_\theta, s)\)
\(L \leftarrow f_{Loss}(\pi_\theta, s, a, logp, A)\)
IF \(kl < \delta\) AND \(\dfrac{L - L_{\pi_\theta}}{I_{expected}} >\) accept_ratio THEN
\(flag_{restore} \leftarrow \text{False}\)
BREAK LOOP
IF \(flag_{restore}\) THEN
\(\pi_\theta \leftarrow \pi_{\theta_{backup}}\)
# sample mini-batches from memory
[[\(s, R\)]] \(\leftarrow\) states, returns
# learning epochs
FOR each learning epoch up to learning_epochs DO
# mini-batches loop
FOR each mini-batch [\(s, R\)] up to mini_batches DO
# compute value loss
\(V' \leftarrow V_\phi(s)\)
\(L_{V_\phi} \leftarrow\) value_loss_scale \(\frac{1}{N} \sum_{i=1}^N (R - V')^2\)
# optimization step (value)
reset \(\text{optimizer}_\phi\)
\(\nabla_{\phi} L_{V_\phi}\)
\(\text{clip}(\lVert \nabla_{\phi} \rVert)\) with grad_norm_clip
step \(\text{optimizer}_\phi\)
# update learning rate
IF there is a learning_rate_scheduler THEN
step \(\text{scheduler}_\phi(\text{optimizer}_\phi)\)
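The backtracking part of the _update pseudocode above can be sketched in plain Python. This sketch assumes flat parameter lists and two hypothetical callables: surrogate(params), which returns the surrogate objective, and kl(params), which returns the KL divergence from the pre-update policy. In this sketch the candidate coefficients are \(\alpha = \) step_fraction \(\cdot 0.5^i\). Each candidate's actual improvement is compared against the linear prediction \(\alpha \, g \cdot \beta\) for that step.

```python
def backtracking_line_search(theta, full_step, expected_improvement, surrogate, kl,
                             max_kl=0.01, accept_ratio=0.5, step_fraction=1.0, max_backtrack_steps=10):
    """Return (parameters, accepted) for flat parameter lists.

    expected_improvement is g . beta, the linear prediction for the full step beta.
    """
    old_value = surrogate(theta)
    for i in range(max_backtrack_steps):
        alpha = step_fraction * 0.5 ** i  # exponentially shrinking coefficient
        candidate = [t + alpha * b for t, b in zip(theta, full_step)]
        improvement = surrogate(candidate) - old_value
        # accept when the KL constraint holds and the actual improvement is a large enough
        # fraction of the linear prediction for this step (alpha * g . beta)
        if kl(candidate) < max_kl and improvement / (alpha * expected_improvement) > accept_ratio:
            return candidate, True
    return list(theta), False  # no acceptable step: keep (restore) the old parameters


# toy 1-D problem: the surrogate -(p - 1)^2 has gradient 2 at p = 0; quadratic KL model 0.5 * p^2
def surrogate(p):
    return -(p[0] - 1.0) ** 2


def kl(p):
    return 0.5 * p[0] ** 2


theta, accepted = backtracking_line_search([0.0], [1.0], expected_improvement=2.0 * 1.0,
                                           surrogate=surrogate, kl=kl)
print(theta, accepted)  # [0.125] True
```

In this example the first three candidates (\(\alpha = 1, 0.5, 0.25\)) violate the KL limit. The fourth candidate (\(\alpha = 0.125\)) satisfies it and achieves 93.75 % of the predicted improvement, so it is accepted.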

Usage

Note

Support for recurrent neural networks (RNN, LSTM, GRU and any other variant) is implemented in a separate file (trpo_rnn.py) to keep the standard implementation (trpo.py) readable.

# import the agent and its default configuration
from skrl.agents.torch.trpo import TRPO, TRPO_DEFAULT_CONFIG

# instantiate the agent's models
models = {}
models["policy"] = ...
models["value"] = ...  # only required during training

# adjust some configuration if necessary
cfg_agent = TRPO_DEFAULT_CONFIG.copy()
cfg_agent["<KEY>"] = ...

# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
agent = TRPO(models=models,
             memory=memory,  # only required during training
             cfg=cfg_agent,
             observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device)

Configuration and hyperparameters

TRPO_DEFAULT_CONFIG = {
    "rollouts": 16,                 # number of rollouts before updating
    "learning_epochs": 8,           # number of learning epochs during each update
    "mini_batches": 2,              # number of mini batches during each learning epoch

    "discount_factor": 0.99,        # discount factor (gamma)
    "lambda": 0.95,                 # TD(lambda) coefficient (lam) for computing returns and advantages

    "value_learning_rate": 1e-3,            # value learning rate
    "learning_rate_scheduler": None,        # learning rate scheduler class (see torch.optim.lr_scheduler)
    "learning_rate_scheduler_kwargs": {},   # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})

    "state_preprocessor": None,             # state preprocessor class (see skrl.resources.preprocessors)
    "state_preprocessor_kwargs": {},        # state preprocessor's kwargs (e.g. {"size": env.observation_space})
    "value_preprocessor": None,             # value preprocessor class (see skrl.resources.preprocessors)
    "value_preprocessor_kwargs": {},        # value preprocessor's kwargs (e.g. {"size": 1})

    "random_timesteps": 0,          # random exploration steps
    "learning_starts": 0,           # learning starts after this many steps

    "grad_norm_clip": 0.5,          # clipping coefficient for the norm of the gradients
    "value_loss_scale": 1.0,        # value loss scaling factor

    "damping": 0.1,                     # damping coefficient for computing the Hessian-vector product
    "max_kl_divergence": 0.01,          # maximum KL divergence between old and new policy
    "conjugate_gradient_steps": 10,     # maximum number of iterations for the conjugate gradient algorithm
    "max_backtrack_steps": 10,          # maximum number of backtracking steps during line search
    "accept_ratio": 0.5,                # accept ratio for the line search loss improvement
    "step_fraction": 1.0,               # fraction of the step size for the line search

    "rewards_shaper": None,         # rewards shaping function: Callable(reward, timestep, timesteps) -> reward
    "time_limit_bootstrap": False,  # bootstrap at timeout termination (episode truncation)

    "experiment": {
        "directory": "",            # experiment's parent directory
        "experiment_name": "",      # experiment name
        "write_interval": 250,      # TensorBoard writing interval (timesteps)

        "checkpoint_interval": 1000,        # interval for checkpoints (timesteps)
        "store_separately": False,          # whether to store checkpoints separately

        "wandb": False,             # whether to use Weights & Biases
        "wandb_kwargs": {}          # wandb kwargs (see https://docs.wandb.ai/ref/python/init)
    }
}
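The usage example above builds cfg_agent with TRPO_DEFAULT_CONFIG.copy(). Python's dict.copy() is shallow, so the nested "experiment" dictionary is shared between the copy and the defaults, and editing cfg_agent["experiment"][...] would also modify TRPO_DEFAULT_CONFIG. The sketch below illustrates this with a small stand-in dictionary. DEFAULT_CONFIG is hypothetical and only mirrors the nesting of the configuration above. copy.deepcopy produces a fully independent copy.

```python
import copy

# hypothetical stand-in with the same nesting as the configuration above (only a few keys shown)
DEFAULT_CONFIG = {
    "max_kl_divergence": 0.01,
    "experiment": {"directory": "", "write_interval": 250},
}

shallow = DEFAULT_CONFIG.copy()
print(shallow["experiment"] is DEFAULT_CONFIG["experiment"])  # True: the nested dict is shared

deep = copy.deepcopy(DEFAULT_CONFIG)
print(deep["experiment"] is DEFAULT_CONFIG["experiment"])     # False: independent copy
deep["experiment"]["directory"] = "runs/trpo"
deep["max_kl_divergence"] = 0.02
print(repr(DEFAULT_CONFIG["experiment"]["directory"]))        # '' (defaults unchanged)
```

With the real agent, the equivalent would be cfg_agent = copy.deepcopy(TRPO_DEFAULT_CONFIG) whenever experiment settings are changed.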

Spaces

The implementation supports the following Gym / Gymnasium spaces (\(\blacksquare\): supported, \(\square\): not supported):

Gym/Gymnasium spaces    Observation         Action
Discrete                \(\square\)         \(\square\)
MultiDiscrete           \(\square\)         \(\square\)
Box                     \(\blacksquare\)    \(\blacksquare\)
Dict                    \(\blacksquare\)    \(\square\)


Models

The implementation uses 1 stochastic and 1 deterministic function approximator. These function approximators (models) must be collected in a dictionary and passed to the class constructor through the models argument.

Notation            Concept   Key        Input shape   Output shape   Type
\(\pi_\theta(s)\)   Policy    "policy"   observation   action         Gaussian / MultivariateGaussian
\(V_\phi(s)\)       Value     "value"    observation   1              Deterministic


Features

Support for advanced features is described in the following table:

Feature        Support and remarks                    pytorch            jax
Shared model   -                                      \(\square\)        \(\square\)
RNN support    RNN, LSTM, GRU and any other variant   \(\blacksquare\)   \(\square\)


API (PyTorch)

skrl.agents.torch.trpo.TRPO_DEFAULT_CONFIG

alias of {‘accept_ratio’: 0.5, ‘conjugate_gradient_steps’: 10, ‘damping’: 0.1, ‘discount_factor’: 0.99, ‘experiment’: {‘checkpoint_interval’: 1000, ‘directory’: ‘’, ‘experiment_name’: ‘’, ‘store_separately’: False, ‘wandb’: False, ‘wandb_kwargs’: {}, ‘write_interval’: 250}, ‘grad_norm_clip’: 0.5, ‘lambda’: 0.95, ‘learning_epochs’: 8, ‘learning_rate_scheduler’: None, ‘learning_rate_scheduler_kwargs’: {}, ‘learning_starts’: 0, ‘max_backtrack_steps’: 10, ‘max_kl_divergence’: 0.01, ‘mini_batches’: 2, ‘random_timesteps’: 0, ‘rewards_shaper’: None, ‘rollouts’: 16, ‘state_preprocessor’: None, ‘state_preprocessor_kwargs’: {}, ‘step_fraction’: 1.0, ‘time_limit_bootstrap’: False, ‘value_learning_rate’: 0.001, ‘value_loss_scale’: 1.0, ‘value_preprocessor’: None, ‘value_preprocessor_kwargs’: {}}

class skrl.agents.torch.trpo.TRPO(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)

Bases: Agent

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) → None

Trust Region Policy Optimization (TRPO)

https://arxiv.org/abs/1502.05477

Parameters:
  • models (dictionary of skrl.models.torch.Model) – Models used by the agent

  • memory (skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None) – Memory used to store the transitions. If it is a tuple or list, the first element is used for training and the remaining ones only receive the environment transitions

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary

Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) → None

Algorithm’s main update step

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

act(states: torch.Tensor, timestep: int, timesteps: int) → torch.Tensor

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
  • states (torch.Tensor) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

Returns:

Actions

Return type:

torch.Tensor

init(trainer_cfg: Mapping[str, Any] | None = None) → None

Initialize the agent

post_interaction(timestep: int, timesteps: int) → None

Callback called after the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) → None

Callback called before the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) → None

Record an environment transition in memory

Parameters:
  • states (torch.Tensor) – Observations/states of the environment used to make the decision

  • actions (torch.Tensor) – Actions taken by the agent

  • rewards (torch.Tensor) – Instant rewards achieved by the current actions

  • next_states (torch.Tensor) – Next observations/states of the environment

  • terminated (torch.Tensor) – Signals to indicate that episodes have terminated

  • truncated (torch.Tensor) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

class skrl.agents.torch.trpo.TRPO_RNN(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)

Bases: Agent

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) → None

Trust Region Policy Optimization (TRPO) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.)

https://arxiv.org/abs/1502.05477

Parameters:
  • models (dictionary of skrl.models.torch.Model) – Models used by the agent

  • memory (skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None) – Memory used to store the transitions. If it is a tuple or list, the first element is used for training and the remaining ones only receive the environment transitions

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary

Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) → None

Algorithm’s main update step

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

act(states: torch.Tensor, timestep: int, timesteps: int) → torch.Tensor

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
  • states (torch.Tensor) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

Returns:

Actions

Return type:

torch.Tensor

init(trainer_cfg: Mapping[str, Any] | None = None) → None

Initialize the agent

post_interaction(timestep: int, timesteps: int) → None

Callback called after the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) → None

Callback called before the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) → None

Record an environment transition in memory

Parameters:
  • states (torch.Tensor) – Observations/states of the environment used to make the decision

  • actions (torch.Tensor) – Actions taken by the agent

  • rewards (torch.Tensor) – Instant rewards achieved by the current actions

  • next_states (torch.Tensor) – Next observations/states of the environment

  • terminated (torch.Tensor) – Signals to indicate that episodes have terminated

  • truncated (torch.Tensor) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps