Robust Policy Optimization (RPO)#
RPO is a model-free, stochastic, on-policy policy-gradient algorithm that adds a uniform random perturbation to a base parameterized distribution so that the agent maintains a certain level of stochasticity throughout the training process.
Paper: Robust Policy Optimization in Deep Reinforcement Learning
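Concretely, for a Gaussian policy the perturbation is applied to the mean before the action is sampled. A sketch of the idea (notation is illustrative; \(\alpha\) is the perturbation range from the paper):

\[ \tilde{\mu} = \mu_\theta(s) + u, \qquad u \sim \mathcal{U}(-\alpha, \alpha), \qquad a \sim \mathcal{N}(\tilde{\mu}, \sigma_\theta^2) \]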
Algorithm#
Note
This algorithm is built on top of the PPO algorithm and simply adds the alpha hyperparameter to the policy's input dictionary. It is the user's responsibility to use this hyperparameter to modify the parameterized distribution.
PyTorch (alpha taken from the inputs dictionary, i.e. from the agent's "alpha" configuration):

class Policy(GaussianMixin, Model):
    ...

    def compute(self, inputs, role):
        # compute the mean actions using the neural network
        mean_actions = self.net(inputs["states"])
        # perturb the mean actions by adding a randomized uniform sample
        rpo_alpha = inputs["alpha"]
        perturbation = torch.zeros_like(mean_actions).uniform_(-rpo_alpha, rpo_alpha)
        mean_actions += perturbation
        return mean_actions, self.log_std_parameter, {}
JAX (alpha taken from the inputs dictionary, i.e. from the agent's "alpha" configuration):

class Policy(GaussianMixin, Model):
    ...

    def __call__(self, inputs, role):
        # compute the mean actions using the neural network
        mean_actions = ...
        log_std = ...
        # perturb the mean actions by adding a randomized uniform sample
        rpo_alpha = inputs["alpha"]
        perturbation = jax.random.uniform(inputs["key"], mean_actions.shape, minval=-rpo_alpha, maxval=rpo_alpha)
        mean_actions += perturbation
        return mean_actions, log_std, {}
PyTorch (constant alpha hard-coded in the model):

class Policy(GaussianMixin, Model):
    ...

    def compute(self, inputs, role):
        # compute the mean actions using the neural network
        mean_actions = self.net(inputs["states"])
        # perturb the mean actions by adding a randomized uniform sample
        rpo_alpha = 0.5
        perturbation = torch.zeros_like(mean_actions).uniform_(-rpo_alpha, rpo_alpha)
        mean_actions += perturbation
        return mean_actions, self.log_std_parameter, {}
JAX (constant alpha hard-coded in the model):

class Policy(GaussianMixin, Model):
    ...

    def __call__(self, inputs, role):
        # compute the mean actions using the neural network
        mean_actions = ...
        log_std = ...
        # perturb the mean actions by adding a randomized uniform sample
        rpo_alpha = 0.5
        perturbation = jax.random.uniform(inputs["key"], mean_actions.shape, minval=-rpo_alpha, maxval=rpo_alpha)
        mean_actions += perturbation
        return mean_actions, log_std, {}
Algorithm implementation#
Learning algorithm#
compute_gae(...)
_update(...)
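As in PPO, returns and advantages are estimated with Generalized Advantage Estimation (GAE) before the policy and value updates. The following is a minimal, illustrative sketch of that computation (names and signature are illustrative, not skrl's exact implementation):

import torch

def compute_gae(rewards, dones, values, next_values, discount_factor=0.99, lambda_coefficient=0.95):
    # iterate backwards over the rollout to accumulate the lambda-weighted advantages
    advantages = torch.zeros_like(rewards)
    advantage = 0
    not_dones = dones.logical_not()
    memory_size = rewards.shape[0]

    for i in reversed(range(memory_size)):
        next_value = values[i + 1] if i < memory_size - 1 else next_values
        # TD error plus the discounted, lambda-weighted advantage of the next step
        advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_value + lambda_coefficient * advantage)
        advantages[i] = advantage

    # returns are the advantages plus the value baseline;
    # advantages are normalized for the clipped surrogate objective
    returns = advantages + values
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return returns, advantages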
Usage#
Note
Support for recurrent neural networks (RNN, LSTM, GRU and any other variant) is implemented in a separate file (rpo_rnn.py) to maintain the readability of the standard implementation (rpo.py).
# import the agent and its default configuration
from skrl.agents.torch.rpo import RPO, RPO_DEFAULT_CONFIG
# instantiate the agent's models
models = {}
models["policy"] = ...
models["value"] = ... # only required during training
# adjust some configuration if necessary
cfg_agent = RPO_DEFAULT_CONFIG.copy()
cfg_agent["<KEY>"] = ...
# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
agent = RPO(models=models,
            memory=memory,  # only required during training
            cfg=cfg_agent,
            observation_space=env.observation_space,
            action_space=env.action_space,
            device=env.device)
# import the agent and its default configuration
from skrl.agents.jax.rpo import RPO, RPO_DEFAULT_CONFIG
# instantiate the agent's models
models = {}
models["policy"] = ...
models["value"] = ... # only required during training
# adjust some configuration if necessary
cfg_agent = RPO_DEFAULT_CONFIG.copy()
cfg_agent["<KEY>"] = ...
# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
agent = RPO(models=models,
            memory=memory,  # only required during training
            cfg=cfg_agent,
            observation_space=env.observation_space,
            action_space=env.action_space,
            device=env.device)
Note
When using recurrent models it is necessary to override their .get_specification() method. Visit each model's documentation for more details.
# import the agent and its default configuration
from skrl.agents.torch.rpo import RPO_RNN as RPO, RPO_DEFAULT_CONFIG
# instantiate the agent's models
models = {}
models["policy"] = ...
models["value"] = ... # only required during training
# adjust some configuration if necessary
cfg_agent = RPO_DEFAULT_CONFIG.copy()
cfg_agent["<KEY>"] = ...
# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
agent = RPO(models=models,
            memory=memory,  # only required during training
            cfg=cfg_agent,
            observation_space=env.observation_space,
            action_space=env.action_space,
            device=env.device)
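After instantiation, the agent is typically trained with one of skrl's trainers. A minimal sketch using the PyTorch SequentialTrainer (the timestep budget is illustrative, and a defined environment <env> and agent are assumed as above):

from skrl.trainers.torch import SequentialTrainer

# configure and instantiate a sequential trainer that alternates interaction and learning
cfg_trainer = {"timesteps": 100000, "headless": True}
trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent)

# start training
trainer.train()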
Configuration and hyperparameters#
RPO_DEFAULT_CONFIG = {
    "rollouts": 16,                 # number of rollouts before updating
    "learning_epochs": 8,           # number of learning epochs during each update
    "mini_batches": 2,              # number of mini batches during each learning epoch
    "alpha": 0.5,                   # amount of uniform random perturbation on the mean actions: U(-alpha, alpha)

    "discount_factor": 0.99,        # discount factor (gamma)
    "lambda": 0.95,                 # TD(lambda) coefficient (lam) for computing returns and advantages

    "learning_rate": 1e-3,                  # learning rate
    "learning_rate_scheduler": None,        # learning rate scheduler class (see torch.optim.lr_scheduler)
    "learning_rate_scheduler_kwargs": {},   # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})

    "state_preprocessor": None,             # state preprocessor class (see skrl.resources.preprocessors)
    "state_preprocessor_kwargs": {},        # state preprocessor's kwargs (e.g. {"size": env.observation_space})
    "value_preprocessor": None,             # value preprocessor class (see skrl.resources.preprocessors)
    "value_preprocessor_kwargs": {},        # value preprocessor's kwargs (e.g. {"size": 1})

    "random_timesteps": 0,          # random exploration steps
    "learning_starts": 0,           # learning starts after this many steps

    "grad_norm_clip": 0.5,              # clipping coefficient for the norm of the gradients
    "ratio_clip": 0.2,                  # clipping coefficient for computing the clipped surrogate objective
    "value_clip": 0.2,                  # clipping coefficient for computing the value loss (if clip_predicted_values is True)
    "clip_predicted_values": False,     # clip predicted values during value loss computation

    "entropy_loss_scale": 0.0,      # entropy loss scaling factor
    "value_loss_scale": 1.0,        # value loss scaling factor

    "kl_threshold": 0,              # KL divergence threshold for early stopping

    "rewards_shaper": None,         # rewards shaping function: Callable(reward, timestep, timesteps) -> reward
    "time_limit_bootstrap": False,  # bootstrap at timeout termination (episode truncation)

    "experiment": {
        "directory": "",            # experiment's parent directory
        "experiment_name": "",      # experiment name
        "write_interval": 250,      # TensorBoard writing interval (timesteps)

        "checkpoint_interval": 1000,    # interval for checkpoints (timesteps)
        "store_separately": False,      # whether to store checkpoints separately

        "wandb": False,             # whether to use Weights & Biases
        "wandb_kwargs": {}          # wandb kwargs (see https://docs.wandb.ai/ref/python/init)
    }
}
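As an illustration (not part of the defaults), the scheduler and preprocessor entries are commonly filled in with classes from skrl's resources. A hedged sketch, assuming a defined environment env as in the usage snippets above; the specific values are only an example:

from skrl.resources.preprocessors.torch import RunningStandardScaler
from skrl.resources.schedulers.torch import KLAdaptiveLR

cfg_agent = RPO_DEFAULT_CONFIG.copy()
# standardize observations and value targets with running statistics
cfg_agent["state_preprocessor"] = RunningStandardScaler
cfg_agent["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": env.device}
cfg_agent["value_preprocessor"] = RunningStandardScaler
cfg_agent["value_preprocessor_kwargs"] = {"size": 1, "device": env.device}
# adapt the learning rate according to the observed KL divergence
cfg_agent["learning_rate_scheduler"] = KLAdaptiveLR
cfg_agent["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.008}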
Spaces#
The implementation supports the following Gym spaces / Gymnasium spaces (\(\blacksquare\): supported, \(\square\): not supported):

| Gym/Gymnasium spaces | Observation | Action |
|---|---|---|
| Discrete | \(\square\) | \(\square\) |
| Box | \(\blacksquare\) | \(\blacksquare\) |
| Dict | \(\blacksquare\) | \(\square\) |
Models#
The implementation uses 1 continuous stochastic and 1 deterministic function approximator. These function approximators (models) must be collected in a dictionary and passed to the constructor of the class under the argument models.

| Notation | Concept | Key | Input shape | Output shape | Type |
|---|---|---|---|---|---|
| \(\pi_\theta(s)\) | Policy | "policy" | observation | action | Gaussian |
| \(V_\phi(s)\) | Value | "value" | observation | 1 | Deterministic |
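The policy snippets above use the Gaussian mixin; a matching value model can be defined with the deterministic mixin. A minimal sketch (the network architecture is illustrative, not prescribed by the agent):

import torch.nn as nn

from skrl.models.torch import DeterministicMixin, Model


class Value(DeterministicMixin, Model):
    def __init__(self, observation_space, action_space, device, clip_actions=False):
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)

        # simple MLP mapping observations to a single state value
        self.net = nn.Sequential(nn.Linear(self.num_observations, 64),
                                 nn.ELU(),
                                 nn.Linear(64, 64),
                                 nn.ELU(),
                                 nn.Linear(64, 1))

    def compute(self, inputs, role):
        return self.net(inputs["states"]), {}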
Features#
Support for advanced features is described in the next table.

| Feature | Support and remarks | PyTorch | JAX |
|---|---|---|---|
| Shared model | for Policy and Value | \(\blacksquare\) | \(\square\) |
| RNN support | RNN, LSTM, GRU and any other variant | \(\blacksquare\) | \(\square\) |
API (PyTorch)#
- skrl.agents.torch.rpo.RPO_DEFAULT_CONFIG#
alias of {‘alpha’: 0.5, ‘clip_predicted_values’: False, ‘discount_factor’: 0.99, ‘entropy_loss_scale’: 0.0, ‘experiment’: {‘checkpoint_interval’: 1000, ‘directory’: ‘’, ‘experiment_name’: ‘’, ‘store_separately’: False, ‘wandb’: False, ‘wandb_kwargs’: {}, ‘write_interval’: 250}, ‘grad_norm_clip’: 0.5, ‘kl_threshold’: 0, ‘lambda’: 0.95, ‘learning_epochs’: 8, ‘learning_rate’: 0.001, ‘learning_rate_scheduler’: None, ‘learning_rate_scheduler_kwargs’: {}, ‘learning_starts’: 0, ‘mini_batches’: 2, ‘random_timesteps’: 0, ‘ratio_clip’: 0.2, ‘rewards_shaper’: None, ‘rollouts’: 16, ‘state_preprocessor’: None, ‘state_preprocessor_kwargs’: {}, ‘time_limit_bootstrap’: False, ‘value_clip’: 0.2, ‘value_loss_scale’: 1.0, ‘value_preprocessor’: None, ‘value_preprocessor_kwargs’: {}}
- class skrl.agents.torch.rpo.RPO(models: Dict[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)#
Bases:
Agent
- __init__(models: Dict[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) → None #
Robust Policy Optimization (RPO)
https://arxiv.org/abs/2212.07536
- Parameters:
models (dictionary of skrl.models.torch.Model) – Models used by the agent
memory (skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None) – Memory to store the transitions. If it is a tuple, the first element will be used for training and, for the rest, only the environment transitions will be added
observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)
action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)
device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"
cfg (dict) – Configuration dictionary
- Raises:
KeyError – If the models dictionary is missing a required key
- act(states: torch.Tensor, timestep: int, timesteps: int) → torch.Tensor #
Process the environment’s states to make a decision (actions) using the main policy
- Parameters:
states (torch.Tensor) – Environment’s states
timestep (int) – Current timestep
timesteps (int) – Number of timesteps
- Returns:
Actions
- Return type:
torch.Tensor
- post_interaction(timestep: int, timesteps: int) → None #
Callback called after the interaction with the environment
- pre_interaction(timestep: int, timesteps: int) → None #
Callback called before the interaction with the environment
- record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) → None #
Record an environment transition in memory
- Parameters:
states (torch.Tensor) – Observations/states of the environment used to make the decision
actions (torch.Tensor) – Actions taken by the agent
rewards (torch.Tensor) – Instant rewards achieved by the current actions
next_states (torch.Tensor) – Next observations/states of the environment
terminated (torch.Tensor) – Signals to indicate that episodes have terminated
truncated (torch.Tensor) – Signals to indicate that episodes have been truncated
infos (Any type supported by the environment) – Additional information about the environment
timestep (int) – Current timestep
timesteps (int) – Number of timesteps
- class skrl.agents.torch.rpo.RPO_RNN(models: Dict[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)#
Bases:
Agent
- __init__(models: Dict[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) → None #
Robust Policy Optimization (RPO) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.)
https://arxiv.org/abs/2212.07536
- Parameters:
models (dictionary of skrl.models.torch.Model) – Models used by the agent
memory (skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None) – Memory to store the transitions. If it is a tuple, the first element will be used for training and, for the rest, only the environment transitions will be added
observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)
action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)
device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"
cfg (dict) – Configuration dictionary
- Raises:
KeyError – If the models dictionary is missing a required key
- act(states: torch.Tensor, timestep: int, timesteps: int) → torch.Tensor #
Process the environment’s states to make a decision (actions) using the main policy
- Parameters:
states (torch.Tensor) – Environment’s states
timestep (int) – Current timestep
timesteps (int) – Number of timesteps
- Returns:
Actions
- Return type:
torch.Tensor
- post_interaction(timestep: int, timesteps: int) → None #
Callback called after the interaction with the environment
- pre_interaction(timestep: int, timesteps: int) → None #
Callback called before the interaction with the environment
- record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) → None #
Record an environment transition in memory
- Parameters:
states (torch.Tensor) – Observations/states of the environment used to make the decision
actions (torch.Tensor) – Actions taken by the agent
rewards (torch.Tensor) – Instant rewards achieved by the current actions
next_states (torch.Tensor) – Next observations/states of the environment
terminated (torch.Tensor) – Signals to indicate that episodes have terminated
truncated (torch.Tensor) – Signals to indicate that episodes have been truncated
infos (Any type supported by the environment) – Additional information about the environment
timestep (int) – Current timestep
timesteps (int) – Number of timesteps
API (JAX)#
- skrl.agents.jax.rpo.RPO_DEFAULT_CONFIG#
alias of {‘alpha’: 0.5, ‘clip_predicted_values’: False, ‘discount_factor’: 0.99, ‘entropy_loss_scale’: 0.0, ‘experiment’: {‘checkpoint_interval’: 1000, ‘directory’: ‘’, ‘experiment_name’: ‘’, ‘store_separately’: False, ‘wandb’: False, ‘wandb_kwargs’: {}, ‘write_interval’: 250}, ‘grad_norm_clip’: 0.5, ‘kl_threshold’: 0, ‘lambda’: 0.95, ‘learning_epochs’: 8, ‘learning_rate’: 0.001, ‘learning_rate_scheduler’: None, ‘learning_rate_scheduler_kwargs’: {}, ‘learning_starts’: 0, ‘mini_batches’: 2, ‘random_timesteps’: 0, ‘ratio_clip’: 0.2, ‘rewards_shaper’: None, ‘rollouts’: 16, ‘state_preprocessor’: None, ‘state_preprocessor_kwargs’: {}, ‘time_limit_bootstrap’: False, ‘value_clip’: 0.2, ‘value_loss_scale’: 1.0, ‘value_preprocessor’: None, ‘value_preprocessor_kwargs’: {}}
- class skrl.agents.jax.rpo.RPO(models: Dict[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None)#
Bases:
Agent
- __init__(models: Dict[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None) → None #
Robust Policy Optimization (RPO)
https://arxiv.org/abs/2212.07536
- Parameters:
models (dictionary of skrl.models.jax.Model) – Models used by the agent
memory (skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None) – Memory to store the transitions. If it is a tuple, the first element will be used for training and, for the rest, only the environment transitions will be added
observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)
action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)
device (str or jax.Device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"
cfg (dict) – Configuration dictionary
- Raises:
KeyError – If the models dictionary is missing a required key
- act(states: ndarray | jax.Array, timestep: int, timesteps: int) → ndarray | jax.Array #
Process the environment’s states to make a decision (actions) using the main policy
- post_interaction(timestep: int, timesteps: int) → None #
Callback called after the interaction with the environment
- pre_interaction(timestep: int, timesteps: int) → None #
Callback called before the interaction with the environment
- record_transition(states: ndarray | jax.Array, actions: ndarray | jax.Array, rewards: ndarray | jax.Array, next_states: ndarray | jax.Array, terminated: ndarray | jax.Array, truncated: ndarray | jax.Array, infos: Any, timestep: int, timesteps: int) → None #
Record an environment transition in memory
- Parameters:
states (np.ndarray or jax.Array) – Observations/states of the environment used to make the decision
actions (np.ndarray or jax.Array) – Actions taken by the agent
rewards (np.ndarray or jax.Array) – Instant rewards achieved by the current actions
next_states (np.ndarray or jax.Array) – Next observations/states of the environment
terminated (np.ndarray or jax.Array) – Signals to indicate that episodes have terminated
truncated (np.ndarray or jax.Array) – Signals to indicate that episodes have been truncated
infos (Any type supported by the environment) – Additional information about the environment
timestep (int) – Current timestep
timesteps (int) – Number of timesteps