Proximal Policy Optimization (PPO)
PPO is a model-free, stochastic, on-policy policy gradient algorithm. It alternates between sampling data through interaction with the environment and optimizing a surrogate objective function, while keeping the new policy from moving too far away from the old one.
Paper: Proximal Policy Optimization Algorithms.
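The clipped surrogate objective mentioned above can be illustrated with a short NumPy sketch. The function name and its arguments are hypothetical and chosen for this example. `ratio_clip` mirrors the configuration attribute of the same name. The sketch shows the technique from the paper and is not skrl's actual code.

```python
import numpy as np

def clipped_surrogate_loss(log_prob, old_log_prob, advantages, ratio_clip=0.2):
    """Return the PPO clipped surrogate loss (to be minimized) for a batch (illustrative sketch)."""
    # probability ratio r = pi_new(a|s) / pi_old(a|s), computed in log space for stability
    ratio = np.exp(log_prob - old_log_prob)
    # unclipped and clipped surrogate terms
    surrogate = advantages * ratio
    surrogate_clipped = advantages * np.clip(ratio, 1.0 - ratio_clip, 1.0 + ratio_clip)
    # pessimistic bound: element-wise minimum, negated so that it can be minimized
    return -np.minimum(surrogate, surrogate_clipped).mean()

# one sample with advantage 1 and ratio exp(0.5) ~ 1.65: the ratio is clipped to 1.2
print(clipped_surrogate_loss(np.array([0.5]), np.array([0.0]), np.array([1.0])))
```

When the advantage is positive, raising the probability ratio beyond 1 + ratio_clip earns no additional objective. This is what discourages overly large policy updates.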
Algorithm
Algorithm implementation
Learning algorithm
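The advantage estimation step of the learning algorithm (compute_gae(...)) can be sketched as follows for a single environment. This is a minimal NumPy version of Generalized Advantage Estimation. It assumes 1-D arrays of rewards, done flags and value predictions, plus the value estimate of the observation that follows the rollout. The signature is an assumption and may differ from skrl's internal function. `discount_factor` and `gae_lambda` play the roles of the configuration attributes of the same names.

```python
import numpy as np

def compute_gae(rewards, dones, values, next_value, discount_factor=0.99, gae_lambda=0.95):
    """Compute GAE advantages and returns for one rollout of a single environment (sketch).

    rewards, dones, values: arrays of shape (T,).
    next_value: value estimate of the observation that follows the last rollout step.
    """
    advantages = np.zeros_like(rewards, dtype=np.float64)
    advantage = 0.0
    for t in reversed(range(len(rewards))):
        # the bootstrap value is masked out when the episode ended at step t
        not_done = 1.0 - dones[t]
        next_v = next_value if t == len(rewards) - 1 else values[t + 1]
        # TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + discount_factor * not_done * next_v - values[t]
        # recursion: A_t = delta_t + gamma * lambda * A_{t+1}
        advantage = delta + discount_factor * gae_lambda * not_done * advantage
        advantages[t] = advantage
    # returns are the targets for the value function
    returns = advantages + values
    return advantages, returns
```

With `gae_lambda = 1.0`, the advantages reduce to discounted Monte Carlo returns minus the value baseline. With `gae_lambda = 0.0`, they reduce to one-step TD errors.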
compute_gae(...)
_update(...)
Usage
Note
Support for recurrent neural networks (RNN, LSTM, GRU and any other variant) is implemented in a separate file
(ppo_rnn.py) to maintain the readability of the standard implementation (ppo.py).
# import the agent and its default configuration
from skrl.agents.torch.ppo import PPO, PPO_CFG
# instantiate the agent's models
models = {}
models["policy"] = ...
models["value"] = ... # only required during training
# adjust some configuration if necessary
cfg_agent = PPO_CFG()
cfg_agent.KEY = ...
# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
agent = PPO(
    models=models,
    memory=memory,  # only required during training
    cfg=cfg_agent,
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
)
# import the agent and its default configuration
from skrl.agents.jax.ppo import PPO, PPO_CFG
# instantiate the agent's models
models = {}
models["policy"] = ...
models["value"] = ... # only required during training
# adjust some configuration if necessary
cfg_agent = PPO_CFG()
cfg_agent.KEY = ...
# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
agent = PPO(
    models=models,
    memory=memory,  # only required during training
    cfg=cfg_agent,
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
)
Note
When using recurrent models, it is necessary to override their .get_specification() method.
Visit each model’s documentation for more details.
# import the agent and its default configuration
from skrl.agents.torch.ppo import PPO_RNN as PPO, PPO_CFG
# instantiate the agent's models
models = {}
models["policy"] = ...
models["value"] = ... # only required during training
# adjust some configuration if necessary
cfg_agent = PPO_CFG()
cfg_agent.KEY = ...
# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
agent = PPO(
    models=models,
    memory=memory,  # only required during training
    cfg=cfg_agent,
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
)
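For the recurrent variant, the note above requires the models to override .get_specification(). The sketch below shows the kind of dictionary such an override might return for an LSTM. It assumes the layout described in skrl's model documentation: an "rnn" entry holding a "sequence_length" and the "sizes" of the hidden and cell states. The class is a plain stand-in rather than a skrl Model subclass, and its attribute names are hypothetical. Check each model's documentation for the exact expected format.

```python
class LSTMPolicySpecSketch:
    """Hypothetical stand-in for a recurrent policy: only the specification logic is shown."""

    def __init__(self, num_envs, num_layers=1, hidden_size=64, sequence_length=16):
        self.num_envs = num_envs              # batch size N of the recurrent states
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.sequence_length = sequence_length

    def get_specification(self):
        # one (num_layers, num_envs, hidden_size) entry per recurrent state tensor:
        # the hidden state and, for an LSTM, the cell state (assumed layout)
        state_size = (self.num_layers, self.num_envs, self.hidden_size)
        return {"rnn": {"sequence_length": self.sequence_length, "sizes": [state_size, state_size]}}


print(LSTMPolicySpecSketch(num_envs=4).get_specification())
```

A GRU or plain RNN has no cell state, so under the same assumption its "sizes" list would contain a single entry.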
Configuration and hyperparameters
Spaces
The implementation supports the following Gymnasium spaces:
| Gymnasium spaces | Observation | Action |
|---|---|---|
| Discrete | \(\square\) | \(\blacksquare\) |
| MultiDiscrete | \(\square\) | \(\blacksquare\) |
| Box | \(\blacksquare\) | \(\blacksquare\) |
| Dict | \(\blacksquare\) | \(\square\) |
Models
The implementation uses 1 stochastic (discrete or continuous) and 1 deterministic function approximator.
These function approximators (models) must be collected in a dictionary and passed to the constructor of the class
under the argument models.
| Notation | Concept | Key | Input shape | Output shape | Type |
|---|---|---|---|---|---|
| \(\pi_\theta(s)\) | Policy | "policy" | observation | action | Categorical / … |
| \(V_\phi(s)\) | Value | "value" | observation | 1 | Deterministic |
Features
Support for advanced features is described in the following table:
| Feature | Support and remarks | PyTorch | JAX | |
|---|---|---|---|---|
| Shared model | for Policy and Value | \(\blacksquare\) | \(\square\) | \(\blacksquare\) |
| RNN support | RNN, LSTM, GRU and any other variant | \(\blacksquare\) | \(\square\) | \(\square\) |
| Mixed precision | Automatic mixed precision | \(\blacksquare\) | \(\square\) | \(\square\) |
| Distributed | Single Program Multi Data (SPMD) multi-GPU | \(\blacksquare\) | \(\blacksquare\) | \(\square\) |
API
PyTorch
| PPO_CFG | Configuration for the PPO agent. |
| PPO | Proximal Policy Optimization (PPO). |
| PPO_RNN | Proximal Policy Optimization (PPO) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.). |
- class skrl.agents.torch.ppo.PPO_CFG(*, experiment: ExperimentCfg = <factory>, rollouts: int = 16, learning_epochs: int = 8, mini_batches: int = 2, discount_factor: float = 0.99, gae_lambda: float = 0.95, learning_rate: float | tuple[float, float] = 0.001, learning_rate_scheduler: type | tuple[type | None, type | None] | None = None, learning_rate_scheduler_kwargs: dict | tuple[dict, dict] = <factory>, observation_preprocessor: type | None = None, observation_preprocessor_kwargs: dict = <factory>, state_preprocessor: type | None = None, state_preprocessor_kwargs: dict = <factory>, value_preprocessor: type | None = None, value_preprocessor_kwargs: dict = <factory>, random_timesteps: int = 0, learning_starts: int = 0, grad_norm_clip: float = 0.5, ratio_clip: float = 0.2, value_clip: float = 0.2, entropy_loss_scale: float = 0.0, value_loss_scale: float = 2.5, kl_threshold: float = 0.0, time_limit_bootstrap: bool = False, rewards_shaper: Callable | None = None, mixed_precision: bool = False)
Bases: AgentCfg
Configuration for the PPO agent.
Methods:
Attributes:
discount_factor – Parameter that balances the importance of future rewards (close to 1.0) versus immediate rewards (close to 0.0).
entropy_loss_scale – Entropy loss scaling factor.
experiment – Experiment settings.
gae_lambda – TD(lambda) coefficient for computing Generalized Advantage Estimation (GAE).
grad_norm_clip – Clipping coefficient for the gradients by their global norm.
kl_threshold – KL-divergence threshold for early stopping.
learning_epochs – Number of learning epochs to perform during updates.
learning_rate – Learning rate for the policy and value networks.
learning_rate_scheduler – Learning rate scheduler class for the policy and value networks.
learning_rate_scheduler_kwargs – Keyword arguments for the learning rate scheduler's constructor.
learning_starts – Number of steps to perform before calling the algorithm update function.
mini_batches – Number of mini batches to sample when updating.
mixed_precision – Whether to enable automatic mixed precision for higher performance.
observation_preprocessor – Preprocessor class to process the environment's observations.
observation_preprocessor_kwargs – Keyword arguments for the observation preprocessor's constructor.
random_timesteps – Number of random exploration (sampling random actions) steps to perform before sampling actions from the policy.
ratio_clip – Clipping coefficient for computing the clipped surrogate objective.
rewards_shaper – Rewards shaping function.
rollouts – Number of collection steps to perform between updates.
state_preprocessor – Preprocessor class to process the environment's states.
state_preprocessor_kwargs – Keyword arguments for the state preprocessor's constructor.
time_limit_bootstrap – Whether to bootstrap at timeout termination (episode truncation).
value_clip – Clipping coefficient for the predicted value during value loss computation.
value_loss_scale – Value loss scaling factor.
value_preprocessor – Preprocessor class to process the value network's output.
value_preprocessor_kwargs – Keyword arguments for the value preprocessor's constructor.
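The interaction between rollouts, mini_batches and learning_epochs can be worked through with a small helper. It follows the usual PPO bookkeeping. Each update consumes rollouts × num_envs transitions, which are split into mini_batches mini-batches and revisited for learning_epochs epochs. The helper name and `num_envs` (the number of parallel environments) are assumptions for illustration. The helper describes the generic PPO schedule rather than quoting skrl's source.

```python
def update_schedule(rollouts=16, num_envs=64, mini_batches=2, learning_epochs=8):
    """Return (transitions per update, mini-batch size, gradient steps per update) for a generic PPO schedule."""
    transitions = rollouts * num_envs                 # samples collected between two updates
    mini_batch_size = transitions // mini_batches     # samples per gradient step
    gradient_steps = learning_epochs * mini_batches   # upper bound: KL early stopping may cut it short
    return transitions, mini_batch_size, gradient_steps

# PPO_CFG defaults with a hypothetical 64 parallel environments
print(update_schedule())  # (1024, 512, 16)
```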
- discount_factor: float = 0.99
Parameter that balances the importance of future rewards (close to 1.0) versus immediate rewards (close to 0.0).
Range: [0.0, 1.0].
- experiment: ExperimentCfg
Experiment settings.
- gae_lambda: float = 0.95
TD(lambda) coefficient for computing Generalized Advantage Estimation (GAE).
- grad_norm_clip: float = 0.5
Clipping coefficient for the gradients by their global norm.
If less than or equal to 0, the gradients will not be clipped.
- learning_rate: float | tuple[float, float] = 0.001
Learning rate for the policy and value networks.
If a float is provided, the same learning rate will be used for the networks.
If a tuple is provided, its elements will be used for each network in order.
- learning_rate_scheduler: type | tuple[type | None, type | None] | None = None
Learning rate scheduler class for the policy and value networks.
See Learning rate schedulers for more details.
If a class is provided, the same learning rate scheduler will be used for the networks.
If a tuple is provided, its elements will be used for each network in order.
- learning_rate_scheduler_kwargs: dict | tuple[dict, dict]
Keyword arguments for the learning rate scheduler’s constructor.
See Learning rate schedulers for more details.
Warning
The optimizer argument is automatically passed to the learning rate scheduler’s constructor. Therefore, it must not be provided in the keyword arguments.
If a dictionary is provided, the same keyword arguments will be used for the networks.
If a tuple is provided, its elements will be used for each network in order.
- observation_preprocessor: type | None = None
Preprocessor class to process the environment’s observations.
See Preprocessors for more details.
- observation_preprocessor_kwargs: dict
Keyword arguments for the observation preprocessor’s constructor.
See Preprocessors for more details.
- random_timesteps: int = 0
Number of random exploration (sampling random actions) steps to perform before sampling actions from the policy.
- state_preprocessor: type | None = None
Preprocessor class to process the environment’s states.
See Preprocessors for more details.
- state_preprocessor_kwargs: dict
Keyword arguments for the state preprocessor’s constructor.
See Preprocessors for more details.
- time_limit_bootstrap: bool = False
Whether to bootstrap at timeout termination (episode truncation).
- value_clip: float = 0.2
Clipping coefficient for the predicted value during value loss computation.
If less than or equal to 0, the predicted value will not be clipped.
- value_preprocessor: type | None = None
Preprocessor class to process the value network’s output.
See Preprocessors for more details.
- value_preprocessor_kwargs: dict
Keyword arguments for the value preprocessor’s constructor.
See Preprocessors for more details.
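The value_clip and value_loss_scale attributes typically enter the value loss as in the following NumPy sketch. The new prediction is kept within ±value_clip of the prediction stored during the rollout, and the mean squared error against the returns is then scaled. This is one common PPO formulation and is given as an assumption. The function name is hypothetical, and other implementations, for example, take the maximum of the clipped and unclipped errors instead.

```python
import numpy as np

def value_loss(predicted_values, old_values, returns, value_clip=0.2, value_loss_scale=2.5):
    """Scaled mean squared value loss with optional clipping of the new prediction (sketch)."""
    if value_clip > 0:
        # keep the new prediction within +/- value_clip of the rollout-time prediction
        predicted_values = old_values + np.clip(predicted_values - old_values, -value_clip, value_clip)
    return value_loss_scale * np.mean((returns - predicted_values) ** 2)

# the prediction jumps from 0.0 to 1.0 but is clipped to 0.2: 2.5 * (1.0 - 0.2) ** 2 = 1.6
print(value_loss(np.array([1.0]), np.array([0.0]), np.array([1.0])))
```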
- class skrl.agents.torch.ppo.PPO(*, models: dict[str, Model], memory: Memory | None = None, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: PPO_CFG | dict = {})
Bases: Agent
Proximal Policy Optimization (PPO).
https://arxiv.org/abs/1707.06347
- Parameters:
models – Agent’s models.
memory – Memory to store the agent’s data and environment transitions.
observation_space – Observation space.
state_space – State space.
action_space – Action space.
device – Data allocation and computation device. If not specified, the default device will be used.
cfg – Agent’s configuration.
- Raises:
KeyError – If a configuration key is missing.
Methods:
act(observations, states, *, timestep, timesteps) – Process the environment's observations/states to make a decision (actions) using the main policy.
enable_models_training_mode([enabled]) – Set the training mode of all the agent's models: enabled (training) or disabled (evaluation).
enable_training_mode([enabled, apply_to_models]) – Set the training mode of the agent: enabled (training) or disabled (evaluation).
init(*[, trainer_cfg]) – Initialize the agent.
load(path) – Load the agent from the specified path.
post_interaction(*, timestep, timesteps) – Method called after the interaction with the environment.
pre_interaction(*, timestep, timesteps) – Method called before the interaction with the environment.
record_transition(*, observations, states, ...) – Record an environment transition in memory.
save(path) – Save the agent to the specified path.
track_data(tag, value) – Track data to TensorBoard.
update(*, timestep, timesteps) – Algorithm's main update step.
write_checkpoint(*, timestep, timesteps) – Write checkpoint (modules) to persistent storage.
write_tracking_data(*, timestep, timesteps) – Write tracking data to TensorBoard.
- act(observations: torch.Tensor, states: torch.Tensor | None, *, timestep: int, timesteps: int) -> tuple[torch.Tensor, dict[str, Any]]
Process the environment’s observations/states to make a decision (actions) using the main policy.
- Parameters:
observations – Environment observations.
states – Environment states.
timestep – Current timestep.
timesteps – Number of timesteps.
- Returns:
Agent output. The first component is the expected action/value returned by the agent. The second component is a dictionary containing extra output values according to the model.
- enable_models_training_mode(enabled: bool = True) -> None
Set the training mode of all the agent’s models: enabled (training) or disabled (evaluation).
- Parameters:
enabled – True to enable the training mode, False to enable the evaluation mode.
- enable_training_mode(enabled: bool = True, *, apply_to_models: bool = False) -> None
Set the training mode of the agent: enabled (training) or disabled (evaluation).
The training mode can be queried by the training property.
- Parameters:
enabled – True to enable the training mode, False to enable the evaluation mode.
apply_to_models – Whether to apply the training mode to all the agent’s models.
- init(*, trainer_cfg: dict[str, Any] | None = None) -> None
Initialize the agent.
- Parameters:
trainer_cfg – Trainer configuration.
- load(path: str) -> None
Load the agent from the specified path.
Note
The final storage device is determined by the constructor of the agent.
- Parameters:
path – Path to load the agent from.
- post_interaction(*, timestep: int, timesteps: int) -> None
Method called after the interaction with the environment.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
- pre_interaction(*, timestep: int, timesteps: int) -> None
Method called before the interaction with the environment.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
- record_transition(*, observations: torch.Tensor, states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_observations: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) -> None
Record an environment transition in memory.
- Parameters:
observations – Environment observations.
states – Environment states.
actions – Actions taken by the agent.
rewards – Instant rewards achieved by the current actions.
next_observations – Next environment observations.
next_states – Next environment states.
terminated – Signals that indicate episodes have terminated.
truncated – Signals that indicate episodes have been truncated.
infos – Additional information about the environment.
timestep – Current timestep.
timesteps – Number of timesteps.
- save(path: str) -> None
Save the agent to the specified path.
- Parameters:
path – Path to save the agent to.
- track_data(tag: str, value: float) -> None
Track data to TensorBoard.
Note
Currently only scalar data is supported.
- Parameters:
tag – Data identifier (e.g. ‘Loss/Policy loss’).
value – Value to track.
- update(*, timestep: int, timesteps: int) -> None
Algorithm’s main update step.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
- write_checkpoint(*, timestep: int, timesteps: int) -> None
Write checkpoint (modules) to persistent storage.
Note
The checkpoints are stored in the subdirectory checkpoints within the experiment directory. The checkpoint name is the timestep argument value (if it is not None), or the current system date-time otherwise.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
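The methods listed above are meant to be driven by a trainer. The loop below is a sketch of the order in which a simple sequential trainer is expected to call them for each environment step. `RecordingAgent`, `ConstantEnv` and `run_interaction_loop` are hypothetical stand-ins that only reuse the documented method names. A real trainer also handles resets, rendering, gradient-free inference and logging.

```python
class RecordingAgent:
    """Hypothetical stub exposing the same method names as the PPO agent."""

    def __init__(self):
        self.calls = []

    def pre_interaction(self, *, timestep, timesteps):
        self.calls.append("pre_interaction")

    def act(self, observations, states, *, timestep, timesteps):
        self.calls.append("act")
        return 0, {}  # (actions, extra outputs), mirroring the documented return value

    def record_transition(self, **kwargs):
        self.calls.append("record_transition")

    def post_interaction(self, *, timestep, timesteps):
        self.calls.append("post_interaction")


class ConstantEnv:
    """Hypothetical environment that always returns the same values."""

    def reset(self):
        return 0.0, {}

    def step(self, actions):
        # next_observations, rewards, terminated, truncated, infos
        return 0.0, 1.0, False, False, {}


def run_interaction_loop(agent, env, timesteps):
    observations, infos = env.reset()
    states = None
    for timestep in range(timesteps):
        agent.pre_interaction(timestep=timestep, timesteps=timesteps)
        actions, _ = agent.act(observations, states, timestep=timestep, timesteps=timesteps)
        next_observations, rewards, terminated, truncated, infos = env.step(actions)
        agent.record_transition(
            observations=observations, states=states, actions=actions, rewards=rewards,
            next_observations=next_observations, next_states=None,
            terminated=terminated, truncated=truncated, infos=infos,
            timestep=timestep, timesteps=timesteps,
        )
        # assumption: for PPO, this is where update() runs once every `rollouts` steps
        agent.post_interaction(timestep=timestep, timesteps=timesteps)
        observations = next_observations
```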
- class skrl.agents.torch.ppo.PPO_RNN(*, models: dict[str, Model], memory: Memory | None = None, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: PPO_CFG | dict = {})
Bases: Agent
Proximal Policy Optimization (PPO) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.).
https://arxiv.org/abs/1707.06347
- Parameters:
models – Agent’s models.
memory – Memory to store the agent’s data and environment transitions.
observation_space – Observation space.
state_space – State space.
action_space – Action space.
device – Data allocation and computation device. If not specified, the default device will be used.
cfg – Agent’s configuration.
- Raises:
KeyError – If a configuration key is missing.
Methods:
act(observations, states, *, timestep, timesteps) – Process the environment's observations/states to make a decision (actions) using the main policy.
enable_models_training_mode([enabled]) – Set the training mode of all the agent's models: enabled (training) or disabled (evaluation).
enable_training_mode([enabled, apply_to_models]) – Set the training mode of the agent: enabled (training) or disabled (evaluation).
init(*[, trainer_cfg]) – Initialize the agent.
load(path) – Load the agent from the specified path.
post_interaction(*, timestep, timesteps) – Method called after the interaction with the environment.
pre_interaction(*, timestep, timesteps) – Method called before the interaction with the environment.
record_transition(*, observations, states, ...) – Record an environment transition in memory.
save(path) – Save the agent to the specified path.
track_data(tag, value) – Track data to TensorBoard.
update(*, timestep, timesteps) – Algorithm's main update step.
write_checkpoint(*, timestep, timesteps) – Write checkpoint (modules) to persistent storage.
write_tracking_data(*, timestep, timesteps) – Write tracking data to TensorBoard.
- act(observations: torch.Tensor, states: torch.Tensor | None, *, timestep: int, timesteps: int) -> tuple[torch.Tensor, dict[str, Any]]
Process the environment’s observations/states to make a decision (actions) using the main policy.
- Parameters:
observations – Environment observations.
states – Environment states.
timestep – Current timestep.
timesteps – Number of timesteps.
- Returns:
Agent output. The first component is the expected action/value returned by the agent. The second component is a dictionary containing extra output values according to the model.
- enable_models_training_mode(enabled: bool = True) -> None
Set the training mode of all the agent’s models: enabled (training) or disabled (evaluation).
- Parameters:
enabled – True to enable the training mode, False to enable the evaluation mode.
- enable_training_mode(enabled: bool = True, *, apply_to_models: bool = False) -> None
Set the training mode of the agent: enabled (training) or disabled (evaluation).
The training mode can be queried by the training property.
- Parameters:
enabled – True to enable the training mode, False to enable the evaluation mode.
apply_to_models – Whether to apply the training mode to all the agent’s models.
- init(*, trainer_cfg: dict[str, Any] | None = None) -> None
Initialize the agent.
- Parameters:
trainer_cfg – Trainer configuration.
- load(path: str) -> None
Load the agent from the specified path.
Note
The final storage device is determined by the constructor of the agent.
- Parameters:
path – Path to load the agent from.
- post_interaction(*, timestep: int, timesteps: int) -> None
Method called after the interaction with the environment.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
- pre_interaction(*, timestep: int, timesteps: int) -> None
Method called before the interaction with the environment.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
- record_transition(*, observations: torch.Tensor, states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_observations: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) -> None
Record an environment transition in memory.
- Parameters:
observations – Environment observations.
states – Environment states.
actions – Actions taken by the agent.
rewards – Instant rewards achieved by the current actions.
next_observations – Next environment observations.
next_states – Next environment states.
terminated – Signals that indicate episodes have terminated.
truncated – Signals that indicate episodes have been truncated.
infos – Additional information about the environment.
timestep – Current timestep.
timesteps – Number of timesteps.
- save(path: str) -> None
Save the agent to the specified path.
- Parameters:
path – Path to save the agent to.
- track_data(tag: str, value: float) -> None
Track data to TensorBoard.
Note
Currently only scalar data is supported.
- Parameters:
tag – Data identifier (e.g. ‘Loss/Policy loss’).
value – Value to track.
- update(*, timestep: int, timesteps: int) -> None
Algorithm’s main update step.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
- write_checkpoint(*, timestep: int, timesteps: int) -> None
Write checkpoint (modules) to persistent storage.
Note
The checkpoints are stored in the subdirectory checkpoints within the experiment directory. The checkpoint name is the timestep argument value (if it is not None), or the current system date-time otherwise.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
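Both record_transition implementations receive separate terminated and truncated signals. This distinction matters when time_limit_bootstrap is enabled. A common way to bootstrap at truncation is to add the discounted value estimate to the reward of each transition cut off by a time limit, as in the NumPy sketch below. The function name is hypothetical. The `values` argument is assumed to be the value network's prediction, already computed, which stands in for the value of the state that was cut off.

```python
import numpy as np

def bootstrap_truncated_rewards(rewards, values, truncated, discount_factor=0.99):
    """Add the discounted value estimate to the rewards of time-limit-truncated transitions (sketch)."""
    # truncated is 1 where the episode was cut off by a time limit and 0 otherwise;
    # transitions that truly terminated are left unchanged because their future value is zero
    return rewards + discount_factor * values * truncated

print(bootstrap_truncated_rewards(np.array([1.0, 1.0]), np.array([10.0, 10.0]), np.array([0.0, 1.0]), 0.9))
```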
JAX
- class skrl.agents.jax.ppo.PPO_CFG(*, experiment: ExperimentCfg = <factory>, rollouts: int = 16, learning_epochs: int = 8, mini_batches: int = 2, discount_factor: float = 0.99, gae_lambda: float = 0.95, learning_rate: float | tuple[float, float] = 0.001, learning_rate_scheduler: type | tuple[type | None, type | None] | None = None, learning_rate_scheduler_kwargs: dict | tuple[dict, dict] = <factory>, observation_preprocessor: type | None = None, observation_preprocessor_kwargs: dict = <factory>, state_preprocessor: type | None = None, state_preprocessor_kwargs: dict = <factory>, value_preprocessor: type | None = None, value_preprocessor_kwargs: dict = <factory>, random_timesteps: int = 0, learning_starts: int = 0, grad_norm_clip: float = 0.5, ratio_clip: float = 0.2, value_clip: float = 0.2, entropy_loss_scale: float = 0.0, value_loss_scale: float = 2.5, kl_threshold: float = 0.0, time_limit_bootstrap: bool = False, rewards_shaper: Callable | None = None)
Bases: AgentCfg
Configuration for the PPO agent.
Methods:
Attributes:
discount_factor – Parameter that balances the importance of future rewards (close to 1.0) versus immediate rewards (close to 0.0).
entropy_loss_scale – Entropy loss scaling factor.
experiment – Experiment settings.
gae_lambda – TD(lambda) coefficient for computing Generalized Advantage Estimation (GAE).
grad_norm_clip – Clipping coefficient for the gradients by their global norm.
kl_threshold – KL-divergence threshold for early stopping.
learning_epochs – Number of learning epochs to perform during updates.
learning_rate – Learning rate for the policy and value networks.
learning_rate_scheduler – Learning rate scheduler class for the policy and value networks.
learning_rate_scheduler_kwargs – Keyword arguments for the learning rate scheduler's constructor.
learning_starts – Number of steps to perform before calling the algorithm update function.
mini_batches – Number of mini batches to sample when updating.
observation_preprocessor – Preprocessor class to process the environment's observations.
observation_preprocessor_kwargs – Keyword arguments for the observation preprocessor's constructor.
random_timesteps – Number of random exploration (sampling random actions) steps to perform before sampling actions from the policy.
ratio_clip – Clipping coefficient for computing the clipped surrogate objective.
rewards_shaper – Rewards shaping function.
rollouts – Number of collection steps to perform between updates.
state_preprocessor – Preprocessor class to process the environment's states.
state_preprocessor_kwargs – Keyword arguments for the state preprocessor's constructor.
time_limit_bootstrap – Whether to bootstrap at timeout termination (episode truncation).
value_clip – Clipping coefficient for the predicted value during value loss computation.
value_loss_scale – Value loss scaling factor.
value_preprocessor – Preprocessor class to process the value network's output.
value_preprocessor_kwargs – Keyword arguments for the value preprocessor's constructor.
- discount_factor: float = 0.99
Parameter that balances the importance of future rewards (close to 1.0) versus immediate rewards (close to 0.0).
Range: [0.0, 1.0].
- experiment: ExperimentCfg
Experiment settings.
- gae_lambda: float = 0.95
TD(lambda) coefficient for computing Generalized Advantage Estimation (GAE).
- grad_norm_clip: float = 0.5
Clipping coefficient for the gradients by their global norm.
If less than or equal to 0, the gradients will not be clipped.
- learning_rate: float | tuple[float, float] = 0.001
Learning rate for the policy and value networks.
If a float is provided, the same learning rate will be used for the networks.
If a tuple is provided, its elements will be used for each network in order.
- learning_rate_scheduler: type | tuple[type | None, type | None] | None = None
Learning rate scheduler class for the policy and value networks.
See Learning rate schedulers for more details.
If a class is provided, the same learning rate scheduler will be used for the networks.
If a tuple is provided, its elements will be used for each network in order.
- learning_rate_scheduler_kwargs: dict | tuple[dict, dict]
Keyword arguments for the learning rate scheduler’s constructor.
See Learning rate schedulers for more details.
Warning
The optimizer argument is automatically passed to the learning rate scheduler’s constructor. Therefore, it must not be provided in the keyword arguments.
If a dictionary is provided, the same keyword arguments will be used for the networks.
If a tuple is provided, its elements will be used for each network in order.
- observation_preprocessor: type | None = None
Preprocessor class to process the environment’s observations.
See Preprocessors for more details.
- observation_preprocessor_kwargs: dict
Keyword arguments for the observation preprocessor’s constructor.
See Preprocessors for more details.
- random_timesteps: int = 0
Number of random exploration (sampling random actions) steps to perform before sampling actions from the policy.
- state_preprocessor: type | None = None
Preprocessor class to process the environment’s states.
See Preprocessors for more details.
- state_preprocessor_kwargs: dict
Keyword arguments for the state preprocessor’s constructor.
See Preprocessors for more details.
- time_limit_bootstrap: bool = False
Whether to bootstrap at timeout termination (episode truncation).
- value_clip: float = 0.2
Clipping coefficient for the predicted value during value loss computation.
If less than or equal to 0, the predicted value will not be clipped.
- value_preprocessor: type | None = None
Preprocessor class to process the value network’s output.
See Preprocessors for more details.
- value_preprocessor_kwargs: dict
Keyword arguments for the value preprocessor’s constructor.
See Preprocessors for more details.
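The grad_norm_clip attribute refers to clipping by the global norm. All gradients are rescaled together when their joint L2 norm exceeds the threshold. The NumPy sketch below illustrates the idea on a list of arrays; the function name is hypothetical. In practice, a framework utility would typically do this, such as torch.nn.utils.clip_grad_norm_ in PyTorch or optax.clip_by_global_norm in Optax. Whether skrl uses these particular utilities is not stated here.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm=0.5):
    """Rescale a list of gradient arrays so that their joint L2 norm is at most max_norm (sketch)."""
    if max_norm <= 0:
        return grads  # mirrors the documented behaviour: no clipping for values <= 0
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    # a single scale factor for all arrays preserves the overall gradient direction
    scale = min(1.0, max_norm / (global_norm + 1e-6))
    return [g * scale for g in grads]

# global norm of [3] and [4] is 5, so both are scaled by roughly 0.5 / 5 = 0.1
print(clip_by_global_norm([np.array([3.0]), np.array([4.0])], max_norm=0.5))
```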
- class skrl.agents.jax.ppo.PPO(*, models: dict[str, Model], memory: Memory | None = None, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: PPO_CFG | dict = {})
Bases: Agent
Proximal Policy Optimization (PPO).
https://arxiv.org/abs/1707.06347
- Parameters:
models – Agent’s models.
memory – Memory to store the agent’s data and environment transitions.
observation_space – Observation space.
state_space – State space.
action_space – Action space.
device – Data allocation and computation device. If not specified, the default device will be used.
cfg – Agent’s configuration.
- Raises:
KeyError – If a configuration key is missing.
Methods:
act(observations, states, *, timestep, timesteps) – Process the environment's observations/states to make a decision (actions) using the main policy.
enable_models_training_mode([enabled]) – Set the training mode of all the agent's models: enabled (training) or disabled (evaluation).
enable_training_mode([enabled, apply_to_models]) – Set the training mode of the agent: enabled (training) or disabled (evaluation).
init(*[, trainer_cfg]) – Initialize the agent.
load(path) – Load the agent from the specified path.
post_interaction(*, timestep, timesteps) – Method called after the interaction with the environment.
pre_interaction(*, timestep, timesteps) – Method called before the interaction with the environment.
record_transition(*, observations, states, ...) – Record an environment transition in memory.
save(path) – Save the agent to the specified path.
track_data(tag, value) – Track data to TensorBoard.
update(*, timestep, timesteps) – Algorithm's main update step.
write_checkpoint(*, timestep, timesteps) – Write checkpoint (modules) to persistent storage.
write_tracking_data(*, timestep, timesteps) – Write tracking data to TensorBoard.
- act(observations: jax.Array, states: jax.Array | None, *, timestep: int, timesteps: int) -> tuple[jax.Array, dict[str, Any]]
Process the environment’s observations/states to make a decision (actions) using the main policy.
- Parameters:
observations – Environment observations.
states – Environment states.
timestep – Current timestep.
timesteps – Number of timesteps.
- Returns:
Agent output. The first component is the expected action/value returned by the agent. The second component is a dictionary containing extra output values according to the model.
- enable_models_training_mode(enabled: bool = True) -> None
Set the training mode of all the agent’s models: enabled (training) or disabled (evaluation).
- Parameters:
enabled – True to enable the training mode, False to enable the evaluation mode.
- enable_training_mode(enabled: bool = True, *, apply_to_models: bool = False) -> None
Set the training mode of the agent: enabled (training) or disabled (evaluation).
The training mode can be queried by the training property.
- Parameters:
enabled – True to enable the training mode, False to enable the evaluation mode.
apply_to_models – Whether to apply the training mode to all the agent’s models.
- init(*, trainer_cfg: dict[str, Any] | None = None) -> None
Initialize the agent.
- Parameters:
trainer_cfg – Trainer configuration.
- load(path: str) -> None
Load the agent from the specified path.
- Parameters:
path – Path to load the agent from.
- post_interaction(*, timestep: int, timesteps: int) -> None
Method called after the interaction with the environment.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
- pre_interaction(*, timestep: int, timesteps: int) -> None
Method called before the interaction with the environment.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
- record_transition(*, observations: jax.Array, states: jax.Array, actions: jax.Array, rewards: jax.Array, next_observations: jax.Array, next_states: jax.Array, terminated: jax.Array, truncated: jax.Array, infos: Any, timestep: int, timesteps: int) -> None
Record an environment transition in memory.
- Parameters:
observations – Environment observations.
states – Environment states.
actions – Actions taken by the agent.
rewards – Instant rewards achieved by the current actions.
next_observations – Next environment observations.
next_states – Next environment states.
terminated – Signals that indicate episodes have terminated.
truncated – Signals that indicate episodes have been truncated.
infos – Additional information about the environment.
timestep – Current timestep.
timesteps – Number of timesteps.
- save(path: str) -> None
Save the agent to the specified path.
- Parameters:
path – Path to save the agent to.
- track_data(tag: str, value: float) -> None
Track data to TensorBoard.
Note
Currently only scalar data is supported.
- Parameters:
tag – Data identifier (e.g. ‘Loss/Policy loss’).
value – Value to track.
- update(*, timestep: int, timesteps: int) -> None
Algorithm’s main update step.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
- write_checkpoint(*, timestep: int, timesteps: int) -> None
Write checkpoint (modules) to persistent storage.
Note
The checkpoints are stored in the subdirectory checkpoints within the experiment directory. The checkpoint name is the timestep argument value (if it is not None), or the current system date-time otherwise.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
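The kl_threshold attribute of both configurations enables early stopping of the learning epochs inside update(...). A common approach estimates the KL divergence between the old and new policies from the log-probabilities of the sampled actions. The update stops once that estimate exceeds the threshold. The NumPy sketch below uses the estimator mean((r − 1) − log r), where r is the probability ratio. This is one standard choice and is an assumption here; the helper names are hypothetical.

```python
import numpy as np

def approx_kl(log_prob, old_log_prob):
    """Estimate KL(old || new) from log-probabilities of actions sampled by the old policy (sketch)."""
    log_ratio = log_prob - old_log_prob
    # (r - 1) - log r is non-negative per sample, and its expectation under the old policy equals the KL
    return np.mean(np.exp(log_ratio) - 1.0 - log_ratio)

def should_stop_early(log_prob, old_log_prob, kl_threshold=0.0):
    """Return True when early stopping is enabled (threshold > 0) and the KL estimate exceeds it."""
    return bool(kl_threshold > 0 and approx_kl(log_prob, old_log_prob) > kl_threshold)

# identical policies give a zero estimate, so no early stop
print(should_stop_early(np.array([-1.0, -2.0]), np.array([-1.0, -2.0]), kl_threshold=0.01))
```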