Agents
Agents are autonomous entities that interact with the environment to learn and improve their behavior. An agent's goal is to learn an optimal policy: a mapping from states to actions that maximizes the cumulative reward received from the environment over time.
Implemented agents
The following table lists the implemented single agents and the frameworks that support each of them.
| Agents | | | |
|---|---|---|---|
| Advantage Actor Critic (A2C) | \(\blacksquare\) | \(\blacksquare\) | \(\square\) |
| | \(\blacksquare\) | \(\square\) | \(\square\) |
| Cross-Entropy Method (CEM) | \(\blacksquare\) | \(\blacksquare\) | \(\square\) |
| | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\) |
| Double Deep Q-Network (DDQN) | \(\blacksquare\) | \(\blacksquare\) | \(\square\) |
| Deep Q-Network (DQN) | \(\blacksquare\) | \(\blacksquare\) | \(\square\) |
| | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\) |
| Q-learning (Q-learning) | \(\blacksquare\) | \(\square\) | \(\square\) |
| | \(\blacksquare\) | \(\blacksquare\) | \(\square\) |
| Soft Actor-Critic (SAC) | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\) |
| State Action Reward State Action (SARSA) | \(\blacksquare\) | \(\square\) | \(\square\) |
| Twin-Delayed DDPG (TD3) | \(\blacksquare\) | \(\blacksquare\) | \(\square\) |
| | \(\blacksquare\) | \(\square\) | \(\square\) |
Base class / configuration
Base class and configuration for single-agent implementations.
API
PyTorch
| AgentCfg | Base class for the agent's configuration. |
| ExperimentCfg | Configuration for the experiment (saving checkpoints and logging data). |
| Agent | Base class that represents an RL agent/algorithm. |
- class skrl.agents.torch.AgentCfg(*, experiment: skrl.agents.torch.base.ExperimentCfg = <factory>)
Bases: ABC
Base class for the agent’s configuration.
Methods:
Attributes:
experiment – Experiment settings.
- experiment: ExperimentCfg
Experiment settings.
- class skrl.agents.torch.ExperimentCfg(*, directory: str = '', experiment_name: str = '', write_interval: int | Literal['auto'] = 'auto', checkpoint_interval: int | Literal['auto'] = 'auto', store_separately: bool = False, wandb: bool = False, wandb_kwargs: dict = <factory>)
Bases: object
Configuration for the experiment (saving checkpoints and logging data).
Attributes:
checkpoint_interval – Interval (in timesteps) for writing checkpoints.
directory – Directory path where the data generated by the different runs (experiments) is stored.
experiment_name – Name of the experiment (training/evaluation run).
store_separately – Whether to store checkpoints separately.
wandb – Whether to enable the use of Weights & Biases for logging and visualization.
wandb_kwargs – Keyword arguments for the Weights & Biases setup.
write_interval – Interval (in timesteps) for writing data to TensorBoard.
- checkpoint_interval: int | Literal['auto'] = 'auto'
Interval (in timesteps) for writing checkpoints.
A value less than or equal to 0 disables the writing of checkpoints.
If set to "auto", the interval is chosen to collect 10 samples throughout training/evaluation (timesteps / 10).
- directory: str = ''
Directory path where the data generated by the different runs (experiments) is stored.
- experiment_name: str = ''
Name of the experiment (training/evaluation run).
If not specified, the format YY-MM-DD_HH-MM-SS-SSSSSS_{agent_name} will be used.
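As a rough illustration of the default name format described above, the following sketch builds a string with the same layout using only the standard library. The helper name default_experiment_name and the use of strftime are assumptions made for illustration; the library's own implementation may differ.

```python
import datetime
from typing import Optional


def default_experiment_name(agent_name: str, now: Optional[datetime.datetime] = None) -> str:
    """Build a name shaped like YY-MM-DD_HH-MM-SS-SSSSSS_{agent_name} (hypothetical helper)."""
    now = now or datetime.datetime.now()
    # %y is the two-digit year; %f is the microsecond, zero-padded to six digits
    return f"{now.strftime('%y-%m-%d_%H-%M-%S-%f')}_{agent_name}"


# fixed date-time so that the result is reproducible
name = default_experiment_name("PPO", datetime.datetime(2024, 1, 2, 3, 4, 5, 678))
print(name)  # 24-01-02_03-04-05-000678_PPO
```

Passing an explicit experiment_name avoids depending on the clock, which can be convenient when runs need predictable directory names.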
- store_separately: bool = False
Whether to store checkpoints separately.
If set to True, all of an agent’s modules (models, optimizers, preprocessors, etc.) will be saved in separate files. By default (False), the modules are grouped in a dictionary and stored in the same file.
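The difference between the two storage layouts can be sketched with plain Python objects. Everything below is illustrative: the write_modules helper, the file names, and the use of pickle are assumptions, and the placeholder dictionaries stand in for real models, optimizers, and preprocessors.

```python
import os
import pickle
import tempfile


def write_modules(modules: dict, directory: str, tag: str, store_separately: bool = False) -> list:
    """Persist modules grouped in one file or as one file per module (illustrative only)."""
    os.makedirs(directory, exist_ok=True)
    paths = []
    if store_separately:
        # one file per module, named after the module (hypothetical naming)
        for name, module in modules.items():
            path = os.path.join(directory, f"{name}_{tag}.pkl")
            with open(path, "wb") as f:
                pickle.dump(module, f)
            paths.append(path)
    else:
        # a single file holding a dictionary with all the modules
        path = os.path.join(directory, f"agent_{tag}.pkl")
        with open(path, "wb") as f:
            pickle.dump(modules, f)
        paths.append(path)
    return paths


# placeholder objects standing in for models, optimizers, preprocessors, etc.
modules = {"policy": {"w": [0.1, 0.2]}, "optimizer": {"lr": 1e-3}}
with tempfile.TemporaryDirectory() as tmp:
    grouped = write_modules(modules, os.path.join(tmp, "grouped"), "1000")
    separate = write_modules(modules, os.path.join(tmp, "separate"), "1000", store_separately=True)
    print(len(grouped), len(separate))  # 1 2
```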
- wandb_kwargs: dict
Keyword arguments for the Weights & Biases setup.
Visit the Weights & Biases documentation for more details.
- write_interval: int | Literal['auto'] = 'auto'
Interval (in timesteps) for writing data to TensorBoard.
A value less than or equal to 0 disables the writing of data to TensorBoard.
If set to "auto", the interval is chosen to collect 100 samples throughout training/evaluation (timesteps / 100).
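The "auto" rule for both intervals amounts to a small calculation, sketched below. The resolve_interval helper is hypothetical, and the use of integer division is an assumption; the documentation only states the ratios timesteps / 100 and timesteps / 10.

```python
from typing import Literal, Union


def resolve_interval(value: Union[int, Literal["auto"]], timesteps: int, samples: int) -> int:
    """Resolve an interval setting to a number of timesteps (hypothetical helper).

    A result less than or equal to 0 means that the corresponding writing is disabled.
    """
    if value == "auto":
        # integer division is an assumption made for this sketch
        return timesteps // samples
    return value


timesteps = 50_000
write_interval = resolve_interval("auto", timesteps, samples=100)
checkpoint_interval = resolve_interval("auto", timesteps, samples=10)
print(write_interval, checkpoint_interval)  # 500 5000
```

With 50,000 timesteps, data would be written every 500 timesteps and a checkpoint every 5,000 timesteps under these assumptions.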
- class skrl.agents.torch.Agent(*, cfg: AgentCfg, models: dict[str, Model], memory: Memory | None = None, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None)
Bases: ABC
Base class that represents an RL agent/algorithm.
- Parameters:
cfg – Agent’s configuration.
models – Agent’s models.
memory – Memory in which to store the agent’s data and environment transitions.
observation_space – Observation space.
state_space – State space.
action_space – Action space.
device – Data allocation and computation device. If not specified, the default device will be used.
Methods:
act(observations, states, *, timestep, timesteps) – Process the environment's observations/states to make a decision (actions) using the main policy.
enable_models_training_mode([enabled]) – Set the training mode of all the agent's models: enabled (training) or disabled (evaluation).
enable_training_mode([enabled, apply_to_models]) – Set the training mode of the agent: enabled (training) or disabled (evaluation).
init(*[, trainer_cfg]) – Initialize the agent.
load(path) – Load the agent from the specified path.
post_interaction(*, timestep, timesteps) – Method called after the interaction with the environment.
pre_interaction(*, timestep, timesteps) – Method called before the interaction with the environment.
record_transition(*, observations, states, ...) – Record an environment transition in memory.
save(path) – Save the agent to the specified path.
track_data(tag, value) – Track data to TensorBoard.
update(*, timestep, timesteps) – Algorithm's main update step.
write_checkpoint(*, timestep, timesteps) – Write checkpoint (modules) to persistent storage.
write_tracking_data(*, timestep, timesteps) – Write tracking data to TensorBoard.
- abstractmethod act(observations: torch.Tensor, states: torch.Tensor | None, *, timestep: int, timesteps: int) -> tuple[torch.Tensor, dict[str, Any]]
Process the environment’s observations/states to make a decision (actions) using the main policy.
- Parameters:
observations – Environment observations.
states – Environment states.
timestep – Current timestep.
timesteps – Number of timesteps.
- Returns:
Agent output. The first component is the expected action/value returned by the agent. The second component is a dictionary containing extra output values according to the model.
- enable_models_training_mode(enabled: bool = True) -> None
Set the training mode of all the agent’s models: enabled (training) or disabled (evaluation).
- Parameters:
enabled – True to enable the training mode, False to enable the evaluation mode.
- enable_training_mode(enabled: bool = True, *, apply_to_models: bool = False) -> None
Set the training mode of the agent: enabled (training) or disabled (evaluation).
The training mode can be queried by the training property.
- Parameters:
enabled – True to enable the training mode, False to enable the evaluation mode.
apply_to_models – Whether to apply the training mode to all the agent’s models.
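The relationship between the agent-level flag and the model-level flag can be pictured with two stand-in classes. ToyModel, ToyAgent, and the model's own enable_training_mode method are hypothetical names used only for this sketch; they mirror the documented signatures but are not the library's implementation.

```python
class ToyModel:
    """Stand-in for a model that carries a training flag (not a real skrl model)."""

    def __init__(self) -> None:
        self.training = False

    def enable_training_mode(self, enabled: bool = True) -> None:
        self.training = enabled


class ToyAgent:
    """Stand-in agent mirroring the two documented training-mode methods."""

    def __init__(self, models: dict) -> None:
        self.models = models
        self.training = False

    def enable_models_training_mode(self, enabled: bool = True) -> None:
        # propagate the mode to every model the agent holds
        for model in self.models.values():
            if model is not None:
                model.enable_training_mode(enabled)

    def enable_training_mode(self, enabled: bool = True, *, apply_to_models: bool = False) -> None:
        self.training = enabled
        if apply_to_models:
            self.enable_models_training_mode(enabled)


agent = ToyAgent({"policy": ToyModel(), "value": ToyModel()})

agent.enable_training_mode(True)  # only the agent flag changes
models_after_first_call = [m.training for m in agent.models.values()]

agent.enable_training_mode(True, apply_to_models=True)  # models follow the agent
models_after_second_call = [m.training for m in agent.models.values()]

print(agent.training, models_after_first_call, models_after_second_call)
```

Keeping apply_to_models off by default lets the agent switch its own behavior without touching model state such as dropout or normalization layers.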
- init(*, trainer_cfg: dict[str, Any] | None = None) -> None
Initialize the agent.
Warning
This method must be called before the agent is used. It will initialize the TensorBoard writer (and optionally Weights & Biases) and create the checkpoints directory.
- Parameters:
trainer_cfg – Trainer configuration.
- load(path: str) -> None
Load the agent from the specified path.
Note
The final storage device is determined by the constructor of the agent.
- Parameters:
path – Path to load the agent from.
- abstractmethod post_interaction(*, timestep: int, timesteps: int) -> None
Method called after the interaction with the environment.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
- abstractmethod pre_interaction(*, timestep: int, timesteps: int) -> None
Method called before the interaction with the environment.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
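To show how the abstract hooks fit around one environment step, here is a minimal sketch with a stand-in base class and a toy scalar environment. The call order (pre_interaction, act, record_transition, post_interaction) is an assumption inferred from the method descriptions, and SketchAgent, CallLogger, and run are hypothetical names rather than library classes.

```python
from abc import ABC, abstractmethod


class SketchAgent(ABC):
    """Reduced stand-in for the documented base class (not the skrl implementation)."""

    @abstractmethod
    def act(self, observations, states, *, timestep, timesteps):
        ...

    @abstractmethod
    def pre_interaction(self, *, timestep, timesteps):
        ...

    @abstractmethod
    def post_interaction(self, *, timestep, timesteps):
        ...

    def record_transition(self, **transition):
        # the documented method takes explicit keyword arguments; simplified here
        pass


class CallLogger(SketchAgent):
    """Concrete agent that only records the order in which its methods are called."""

    def __init__(self):
        self.calls = []

    def act(self, observations, states, *, timestep, timesteps):
        self.calls.append("act")
        return -0.5 * observations, {}  # (actions, extra outputs)

    def pre_interaction(self, *, timestep, timesteps):
        self.calls.append("pre_interaction")

    def post_interaction(self, *, timestep, timesteps):
        self.calls.append("post_interaction")

    def record_transition(self, **transition):
        self.calls.append("record_transition")


def run(agent, timesteps):
    observations = 1.0  # toy scalar observation
    for timestep in range(timesteps):
        agent.pre_interaction(timestep=timestep, timesteps=timesteps)
        actions, _ = agent.act(observations, None, timestep=timestep, timesteps=timesteps)
        next_observations = observations + actions  # toy environment dynamics
        agent.record_transition(
            observations=observations, actions=actions, rewards=-abs(next_observations),
            next_observations=next_observations, timestep=timestep, timesteps=timesteps,
        )
        agent.post_interaction(timestep=timestep, timesteps=timesteps)
        observations = next_observations


agent = CallLogger()
run(agent, timesteps=2)
print(agent.calls[:4])
```

Because act, pre_interaction, and post_interaction are abstract, instantiating the base class directly raises a TypeError; a subclass has to provide all three.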
- record_transition(*, observations: torch.Tensor, states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_observations: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) -> None
Record an environment transition in memory.
Note
This method keeps track of the episode rewards (instantaneous and cumulative) and timesteps when the experiment.write_interval configuration resolves to a positive value. Inheriting classes must call this method to record such information.
- Parameters:
observations – Environment observations.
states – Environment states.
actions – Actions taken by the agent.
rewards – Instant rewards achieved by the current actions.
next_observations – Next environment observations.
next_states – Next environment states.
terminated – Signals that indicate episodes have terminated.
truncated – Signals that indicate episodes have been truncated.
infos – Additional information about the environment.
timestep – Current timestep.
timesteps – Number of timesteps.
- save(path: str) -> None
Save the agent to the specified path.
- Parameters:
path – Path to save the agent to.
- track_data(tag: str, value: float) -> None
Track data to TensorBoard.
Note
Currently only scalar data is supported.
- Parameters:
tag – Data identifier (e.g. ‘Loss/Policy loss’).
value – Value to track.
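One way to picture scalar tracking is a buffer keyed by tag that is flushed at each write interval. The sketch below is a stand-in under stated assumptions: ScalarTracker is a hypothetical class, a plain dictionary replaces the TensorBoard writer, and averaging the buffered values before writing is an assumption rather than documented behavior.

```python
from collections import defaultdict
from statistics import mean


class ScalarTracker:
    """Collects scalar values per tag and flushes an aggregate (illustrative stand-in)."""

    def __init__(self) -> None:
        self.tracking_data = defaultdict(list)
        self.written = {}  # stands in for the TensorBoard writer

    def track_data(self, tag: str, value: float) -> None:
        self.tracking_data[tag].append(value)

    def write_tracking_data(self, *, timestep: int, timesteps: int) -> None:
        for tag, values in self.tracking_data.items():
            # averaging the buffered values is an assumption made for this sketch
            self.written[(tag, timestep)] = mean(values)
        self.tracking_data.clear()


tracker = ScalarTracker()
tracker.track_data("Loss/Policy loss", 0.5)
tracker.track_data("Loss/Policy loss", 0.25)
tracker.write_tracking_data(timestep=100, timesteps=1000)
print(tracker.written[("Loss/Policy loss", 100)])  # 0.375
```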
- abstractmethod update(*, timestep: int, timesteps: int) -> None
Algorithm’s main update step.
Warning
This method should not be called directly, but rather by the agent itself when the algorithm is needed for learning.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
- write_checkpoint(*, timestep: int, timesteps: int) -> None
Write checkpoint (modules) to persistent storage.
Note
The checkpoints are stored in the subdirectory checkpoints within the experiment directory. The checkpoint name is the timestep argument value (if it is not None), or the current system date-time otherwise.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
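The naming rule in the note above can be sketched as a small path builder. The checkpoint_path helper, the agent_ prefix, the pt extension, and the date-time layout are assumptions for illustration; only the checkpoints subdirectory and the timestep-or-date-time rule come from the note.

```python
import datetime
import os
from typing import Optional


def checkpoint_path(experiment_dir: str, timestep: Optional[int], extension: str = "pt") -> str:
    """Build a checkpoint path from the timestep, or from the current date-time if it is None."""
    if timestep is not None:
        tag = str(timestep)
    else:
        # date-time layout chosen for this sketch only
        tag = datetime.datetime.now().strftime("%y-%m-%d_%H-%M-%S-%f")
    # hypothetical file name: prefix and extension are not taken from the documentation
    return os.path.join(experiment_dir, "checkpoints", f"agent_{tag}.{extension}")


path = checkpoint_path(os.path.join("runs", "my_experiment"), 1000)
print(os.path.basename(path))  # agent_1000.pt
```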
JAX
| AgentCfg | Base class for the agent's configuration. |
| ExperimentCfg | Configuration for the experiment (saving checkpoints and logging data). |
| Agent | Base class that represents an RL agent/algorithm. |
- class skrl.agents.jax.AgentCfg(*, experiment: skrl.agents.jax.base.ExperimentCfg = <factory>)
Bases: ABC
Base class for the agent’s configuration.
Methods:
Attributes:
experiment – Experiment settings.
- experiment: ExperimentCfg
Experiment settings.
- class skrl.agents.jax.ExperimentCfg(*, directory: str = '', experiment_name: str = '', write_interval: int | Literal['auto'] = 'auto', checkpoint_interval: int | Literal['auto'] = 'auto', store_separately: bool = False, wandb: bool = False, wandb_kwargs: dict = <factory>)
Bases: object
Configuration for the experiment (saving checkpoints and logging data).
Attributes:
checkpoint_interval – Interval (in timesteps) for writing checkpoints.
directory – Directory path where the data generated by the different runs (experiments) is stored.
experiment_name – Name of the experiment (training/evaluation run).
store_separately – Whether to store checkpoints separately.
wandb – Whether to enable the use of Weights & Biases for logging and visualization.
wandb_kwargs – Keyword arguments for the Weights & Biases setup.
write_interval – Interval (in timesteps) for writing data to TensorBoard.
- checkpoint_interval: int | Literal['auto'] = 'auto'
Interval (in timesteps) for writing checkpoints.
A value less than or equal to 0 disables the writing of checkpoints.
If set to "auto", the interval is chosen to collect 10 samples throughout training/evaluation (timesteps / 10).
- directory: str = ''
Directory path where the data generated by the different runs (experiments) is stored.
- experiment_name: str = ''
Name of the experiment (training/evaluation run).
If not specified, the format YY-MM-DD_HH-MM-SS-SSSSSS_{agent_name} will be used.
- store_separately: bool = False
Whether to store checkpoints separately.
If set to True, all of an agent’s modules (models, optimizers, preprocessors, etc.) will be saved in separate files. By default (False), the modules are grouped in a dictionary and stored in the same file.
- wandb_kwargs: dict
Keyword arguments for the Weights & Biases setup.
Visit the Weights & Biases documentation for more details.
- write_interval: int | Literal['auto'] = 'auto'
Interval (in timesteps) for writing data to TensorBoard.
A value less than or equal to 0 disables the writing of data to TensorBoard.
If set to "auto", the interval is chosen to collect 100 samples throughout training/evaluation (timesteps / 100).
- class skrl.agents.jax.Agent(*, cfg: AgentCfg, models: dict[str, Model], memory: Memory | None = None, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | jax.Device | None = None)
Bases: ABC
Base class that represents an RL agent/algorithm.
- Parameters:
cfg – Agent’s configuration.
models – Agent’s models.
memory – Memory in which to store the agent’s data and environment transitions.
observation_space – Observation space.
state_space – State space.
action_space – Action space.
device – Data allocation and computation device. If not specified, the default device will be used.
Methods:
act(observations, states, *, timestep, timesteps) – Process the environment's observations/states to make a decision (actions) using the main policy.
enable_models_training_mode([enabled]) – Set the training mode of all the agent's models: enabled (training) or disabled (evaluation).
enable_training_mode([enabled, apply_to_models]) – Set the training mode of the agent: enabled (training) or disabled (evaluation).
init(*[, trainer_cfg]) – Initialize the agent.
load(path) – Load the agent from the specified path.
post_interaction(*, timestep, timesteps) – Method called after the interaction with the environment.
pre_interaction(*, timestep, timesteps) – Method called before the interaction with the environment.
record_transition(*, observations, states, ...) – Record an environment transition in memory.
save(path) – Save the agent to the specified path.
track_data(tag, value) – Track data to TensorBoard.
update(*, timestep, timesteps) – Algorithm's main update step.
write_checkpoint(*, timestep, timesteps) – Write checkpoint (modules) to persistent storage.
write_tracking_data(*, timestep, timesteps) – Write tracking data to TensorBoard.
- abstractmethod act(observations: jax.Array, states: jax.Array | None, *, timestep: int, timesteps: int) -> tuple[jax.Array, dict[str, Any]]
Process the environment’s observations/states to make a decision (actions) using the main policy.
- Parameters:
observations – Environment observations.
states – Environment states.
timestep – Current timestep.
timesteps – Number of timesteps.
- Returns:
Agent output. The first component is the expected action/value returned by the agent. The second component is a dictionary containing extra output values according to the model.
- enable_models_training_mode(enabled: bool = True) -> None
Set the training mode of all the agent’s models: enabled (training) or disabled (evaluation).
- Parameters:
enabled – True to enable the training mode, False to enable the evaluation mode.
- enable_training_mode(enabled: bool = True, *, apply_to_models: bool = False) -> None
Set the training mode of the agent: enabled (training) or disabled (evaluation).
The training mode can be queried by the training property.
- Parameters:
enabled – True to enable the training mode, False to enable the evaluation mode.
apply_to_models – Whether to apply the training mode to all the agent’s models.
- init(*, trainer_cfg: dict[str, Any] | None = None) -> None
Initialize the agent.
Warning
This method must be called before the agent is used. It will initialize the TensorBoard writer (and optionally Weights & Biases) and create the checkpoints directory.
- Parameters:
trainer_cfg – Trainer configuration.
- load(path: str) -> None
Load the agent from the specified path.
- Parameters:
path – Path to load the agent from.
- abstractmethod post_interaction(*, timestep: int, timesteps: int) -> None
Method called after the interaction with the environment.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
- abstractmethod pre_interaction(*, timestep: int, timesteps: int) -> None
Method called before the interaction with the environment.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
- record_transition(*, observations: jax.Array, states: jax.Array, actions: jax.Array, rewards: jax.Array, next_observations: jax.Array, next_states: jax.Array, terminated: jax.Array, truncated: jax.Array, infos: Any, timestep: int, timesteps: int) -> None
Record an environment transition in memory.
Note
This method keeps track of the episode rewards (instantaneous and cumulative) and timesteps when the experiment.write_interval configuration resolves to a positive value. Inheriting classes must call this method to record such information.
- Parameters:
observations – Environment observations.
states – Environment states.
actions – Actions taken by the agent.
rewards – Instant rewards achieved by the current actions.
next_observations – Next environment observations.
next_states – Next environment states.
terminated – Signals that indicate episodes have terminated.
truncated – Signals that indicate episodes have been truncated.
infos – Additional information about the environment.
timestep – Current timestep.
timesteps – Number of timesteps.
- save(path: str) -> None
Save the agent to the specified path.
- Parameters:
path – Path to save the agent to.
- track_data(tag: str, value: float) -> None
Track data to TensorBoard.
Note
Currently only scalar data is supported.
- Parameters:
tag – Data identifier (e.g. ‘Loss/Policy loss’).
value – Value to track.
- abstractmethod update(*, timestep: int, timesteps: int) -> None
Algorithm’s main update step.
Warning
This method should not be called directly, but rather by the agent itself when the algorithm is needed for learning.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
- write_checkpoint(*, timestep: int, timesteps: int) -> None
Write checkpoint (modules) to persistent storage.
Note
The checkpoints are stored in the subdirectory checkpoints within the experiment directory. The checkpoint name is the timestep argument value (if it is not None), or the current system date-time otherwise.
- Parameters:
timestep – Current timestep.
timesteps – Number of timesteps.
Warp
| AgentCfg | Base class for the agent's configuration. |
| ExperimentCfg | Configuration for the experiment (saving checkpoints and logging data). |
| Agent | Base class that represents an RL agent/algorithm. |