Agents

Agents are autonomous entities that interact with the environment to learn and improve their behavior. An agent's goal is to learn an optimal policy: a mapping from states to actions that maximizes the cumulative reward received from the environment over time.
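To make "policy" and "cumulative reward" concrete, here is a minimal sketch on a made-up three-state chain. The names `step`, `POLICY`, and `rollout` are hypothetical and are not part of this library; the policy is a plain dictionary from states to actions, and the rollout sums the rewards collected while following it.

```python
from typing import Dict, Tuple

# Hypothetical toy environment: states 0, 1, 2 on a line; state 2 is the goal.
def step(state: int, action: str) -> Tuple[int, float, bool]:
    """Return (next_state, reward, terminated) for the toy chain."""
    next_state = min(state + 1, 2) if action == "right" else max(state - 1, 0)
    terminated = next_state == 2
    reward = 1.0 if terminated else -0.1
    return next_state, reward, terminated

# A policy is a mapping from states to actions.
POLICY: Dict[int, str] = {0: "right", 1: "right"}

def rollout(policy: Dict[int, str], start: int = 0, max_steps: int = 10) -> float:
    """Follow the policy and return the cumulative (undiscounted) reward."""
    state, total = start, 0.0
    for _ in range(max_steps):
        state, reward, terminated = step(state, policy[state])
        total += reward
        if terminated:
            break
    return total

print(rollout(POLICY))
```

A policy that always moves left never reaches the goal in this toy setting and accumulates only the per-step penalty, which is the sense in which the first policy is better.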



Agents                                      pytorch            jax

Advantage Actor Critic (A2C)                \(\blacksquare\)   \(\blacksquare\)
Adversarial Motion Priors (AMP)             \(\blacksquare\)   \(\square\)
Cross-Entropy Method (CEM)                  \(\blacksquare\)   \(\blacksquare\)
Deep Deterministic Policy Gradient (DDPG)   \(\blacksquare\)   \(\blacksquare\)
Double Deep Q-Network (DDQN)                \(\blacksquare\)   \(\blacksquare\)
Deep Q-Network (DQN)                        \(\blacksquare\)   \(\blacksquare\)
Proximal Policy Optimization (PPO)          \(\blacksquare\)   \(\blacksquare\)
Q-learning (Q-learning)                     \(\blacksquare\)   \(\square\)
Robust Policy Optimization (RPO)            \(\blacksquare\)   \(\blacksquare\)
Soft Actor-Critic (SAC)                     \(\blacksquare\)   \(\blacksquare\)
State Action Reward State Action (SARSA)    \(\blacksquare\)   \(\square\)
Twin-Delayed DDPG (TD3)                     \(\blacksquare\)   \(\blacksquare\)
Trust Region Policy Optimization (TRPO)     \(\blacksquare\)   \(\square\)

Base class

Note

This is the base class for all agents in this module and provides only basic functionality that is not tied to any implementation of the optimization algorithms. It is not intended to be used directly.


Basic inheritance usage

from typing import Union, Tuple, Dict, Any, Optional

import gym, gymnasium
import copy

import torch

from skrl.memories.torch import Memory
from skrl.models.torch import Model

from skrl.agents.torch import Agent


CUSTOM_DEFAULT_CONFIG = {
    # ...

    "experiment": {
        "directory": "",            # experiment's parent directory
        "experiment_name": "",      # experiment name
        "write_interval": 250,      # TensorBoard writing interval (timesteps)

        "checkpoint_interval": 1000,        # interval for checkpoints (timesteps)
        "store_separately": False,          # whether to store checkpoints separately

        "wandb": False,             # whether to use Weights & Biases
        "wandb_kwargs": {}          # wandb kwargs (see https://docs.wandb.ai/ref/python/init)
    }
}


class CUSTOM(Agent):
    def __init__(self,
                 models: Dict[str, Model],
                 memory: Optional[Union[Memory, Tuple[Memory]]] = None,
                 observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
                 action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
                 device: Optional[Union[str, torch.device]] = None,
                 cfg: Optional[dict] = None) -> None:
        """Custom agent

        :param models: Models used by the agent
        :type models: dictionary of skrl.models.torch.Model
        :param memory: Memory to store the transitions.
                       If it is a tuple, the first element will be used for training and
                       for the rest only the environment transitions will be added
        :type memory: skrl.memories.torch.Memory, list of skrl.memories.torch.Memory or None
        :param observation_space: Observation/state space or shape (default: None)
        :type observation_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional
        :param action_space: Action space or shape (default: None)
        :type action_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional
        :param device: Device on which a torch tensor is or will be allocated (default: ``None``).
                       If None, the device will be either ``"cuda:0"`` if available or ``"cpu"``
        :type device: str or torch.device, optional
        :param cfg: Configuration dictionary
        :type cfg: dict
        """
        _cfg = copy.deepcopy(CUSTOM_DEFAULT_CONFIG)
        _cfg.update(cfg if cfg is not None else {})
        super().__init__(models=models,
                         memory=memory,
                         observation_space=observation_space,
                         action_space=action_space,
                         device=device,
                         cfg=_cfg)
        # =======================================================================
        # - get and process models from `self.models`
        # - populate `self.checkpoint_modules` dictionary for storing checkpoints
        # - parse configurations from `self.cfg`
        # - setup optimizers and learning rate scheduler
        # - set up preprocessors
        # =======================================================================

    def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
        """Initialize the agent
        """
        super().init(trainer_cfg=trainer_cfg)
        self.set_mode("eval")
        # =================================================================
        # - create tensors in memory if required
        # - create temporary variables needed for storage and computation
        # =================================================================

    def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tensor:
        """Process the environment's states to make a decision (actions) using the main policy

        :param states: Environment's states
        :type states: torch.Tensor
        :param timestep: Current timestep
        :type timestep: int
        :param timesteps: Number of timesteps
        :type timesteps: int

        :return: Actions
        :rtype: torch.Tensor
        """
        # ======================================
        # - sample random actions if required or
        #   sample and return agent's actions
        # ======================================

    def record_transition(self,
                          states: torch.Tensor,
                          actions: torch.Tensor,
                          rewards: torch.Tensor,
                          next_states: torch.Tensor,
                          terminated: torch.Tensor,
                          truncated: torch.Tensor,
                          infos: Any,
                          timestep: int,
                          timesteps: int) -> None:
        """Record an environment transition in memory

        :param states: Observations/states of the environment used to make the decision
        :type states: torch.Tensor
        :param actions: Actions taken by the agent
        :type actions: torch.Tensor
        :param rewards: Instant rewards achieved by the current actions
        :type rewards: torch.Tensor
        :param next_states: Next observations/states of the environment
        :type next_states: torch.Tensor
        :param terminated: Signals to indicate that episodes have terminated
        :type terminated: torch.Tensor
        :param truncated: Signals to indicate that episodes have been truncated
        :type truncated: torch.Tensor
        :param infos: Additional information about the environment
        :type infos: Any type supported by the environment
        :param timestep: Current timestep
        :type timestep: int
        :param timesteps: Number of timesteps
        :type timesteps: int
        """
        super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps)
        # ========================================
        # - record agent's specific data in memory
        # ========================================

    def pre_interaction(self, timestep: int, timesteps: int) -> None:
        """Callback called before the interaction with the environment

        :param timestep: Current timestep
        :type timestep: int
        :param timesteps: Number of timesteps
        :type timesteps: int
        """
        # =====================================
        # - call `self.update(...)` if required
        # =====================================

    def post_interaction(self, timestep: int, timesteps: int) -> None:
        """Callback called after the interaction with the environment

        :param timestep: Current timestep
        :type timestep: int
        :param timesteps: Number of timesteps
        :type timesteps: int
        """
        # =====================================
        # - call `self.update(...)` if required
        # =====================================
        # call parent's method for checkpointing and TensorBoard writing
        super().post_interaction(timestep, timesteps)

    def _update(self, timestep: int, timesteps: int) -> None:
        """Algorithm's main update step

        :param timestep: Current timestep
        :type timestep: int
        :param timesteps: Number of timesteps
        :type timesteps: int
        """
        # ===================================================
        # - implement algorithm's update step
        # - record tracking data using `self.track_data(...)`
        # ===================================================
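One detail of the template constructor is worth checking in isolation: `_cfg.update(...)` is a shallow merge, so a nested dictionary such as `"experiment"` is replaced as a whole rather than merged key by key. The sketch below reproduces only those two merge lines with a trimmed-down stand-in for the defaults; `DEFAULT_CONFIG` and `merge_like_template` are hypothetical names used for illustration.

```python
import copy

# Trimmed-down stand-in for CUSTOM_DEFAULT_CONFIG (values copied from the template)
DEFAULT_CONFIG = {
    "experiment": {
        "write_interval": 250,
        "checkpoint_interval": 1000,
        "wandb": False,
    }
}

def merge_like_template(cfg=None):
    """Reproduce the two merge lines of the template constructor."""
    _cfg = copy.deepcopy(DEFAULT_CONFIG)
    _cfg.update(cfg if cfg is not None else {})
    return _cfg

# Passing a partial "experiment" dictionary replaces the whole sub-dictionary
partial = merge_like_template({"experiment": {"write_interval": 100}})
print(sorted(partial["experiment"]))

# Copying the defaults first and editing single keys keeps the other entries
cfg = copy.deepcopy(DEFAULT_CONFIG)
cfg["experiment"]["write_interval"] = 100
full = merge_like_template(cfg)
print(sorted(full["experiment"]))
```

With this merge strategy, starting from a copy of the default configuration and overriding individual keys is the safer way to build `cfg`.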

API (PyTorch)

class skrl.agents.torch.base.Agent(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)

Bases: object

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) None

Base class that represents an RL agent

Parameters:
  • models (dictionary of skrl.models.torch.Model) – Models used by the agent

  • memory (skrl.memories.torch.Memory, list of skrl.memories.torch.Memory or None) – Memory to store the transitions. If it is a tuple, the first element will be used for training and for the rest only the environment transitions will be added

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary
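The rule for the memory parameter (first element for training, remaining elements only receive environment transitions) can be sketched as follows. `ListMemory`, `split_memories`, and `record` are hypothetical stand-ins written for this illustration; they only assume that a memory exposes an `add_samples(**tensors)`-style method.

```python
from typing import List, Optional, Sequence, Tuple, Union

class ListMemory:
    """Hypothetical stand-in for a memory: it only appends samples to a list."""
    def __init__(self) -> None:
        self.samples: List[dict] = []

    def add_samples(self, **tensors) -> None:
        self.samples.append(tensors)

def split_memories(memory: Optional[Union[ListMemory, Sequence[ListMemory]]]
                   ) -> Tuple[Optional[ListMemory], List[ListMemory]]:
    """Return (training memory, secondary memories) following the rule above."""
    if memory is None:
        return None, []
    if isinstance(memory, (tuple, list)):
        return memory[0], list(memory[1:])
    return memory, []

def record(memory, **transition) -> None:
    """Add one environment transition to every memory."""
    primary, secondary = split_memories(memory)
    if primary is not None:
        primary.add_samples(**transition)
    for extra in secondary:
        extra.add_samples(**transition)

train_memory, log_memory = ListMemory(), ListMemory()
record((train_memory, log_memory), states=[0.0], actions=[1], rewards=[0.5])
print(len(train_memory.samples), len(log_memory.samples))
```

In this sketch only the first memory would be sampled by an update step; the secondary ones act as passive recorders.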

__str__() str

Generate a representation of the agent as string

Returns:

Representation of the agent as string

Return type:

str

_empty_preprocessor(_input: Any, *args, **kwargs) Any

Empty preprocess method

This method is defined because PyTorch multiprocessing can’t pickle lambdas

Parameters:

_input (Any) – Input to preprocess

Returns:

Preprocessed input

Return type:

Any

_get_internal_value(_module: Any) Any

Get internal module/variable state/value

Parameters:

_module (Any) – Module or variable

Returns:

Module/variable state/value

Return type:

Any

_update(timestep: int, timesteps: int) None

Algorithm’s main update step

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

Raises:

NotImplementedError – The method is not implemented by the inheriting classes

act(states: torch.Tensor, timestep: int, timesteps: int) torch.Tensor

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
  • states (torch.Tensor) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

Raises:

NotImplementedError – The method is not implemented by the inheriting classes

Returns:

Actions

Return type:

torch.Tensor

init(trainer_cfg: Mapping[str, Any] | None = None) None

Initialize the agent

This method should be called before the agent is used. It will initialize the TensorBoard writer (and optionally Weights & Biases) and create the checkpoints directory

Parameters:

trainer_cfg (dict, optional) – Trainer configuration

load(path: str) None

Load the model from the specified path

The final storage device is determined by the constructor of the model

Parameters:

path (str) – Path to load the model from

migrate(path: str, name_map: Mapping[str, Mapping[str, str]] = {}, auto_mapping: bool = True, verbose: bool = False) bool

Migrate the specified external checkpoint to the current agent

The final storage device is determined by the constructor of the agent. Only files generated by the rl_games library are supported at the moment

For ambiguous models (where 2 or more parameters, for source or current model, have equal shape) it is necessary to define the name_map, at least for those parameters, to perform the migration successfully
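The ambiguity condition can be illustrated with a small shape-matching sketch. This is not the library's matching code: `candidates_by_shape` and `ambiguous` are hypothetical helpers, and the shapes are a subset of those printed for the policy model in the example output further down.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]

def candidates_by_shape(current: Dict[str, Shape], source: Dict[str, Shape]) -> Dict[str, List[str]]:
    """For each current parameter, list the source parameters with the same shape."""
    by_shape = defaultdict(list)
    for name, shape in source.items():
        by_shape[shape].append(name)
    return {name: list(by_shape.get(shape, [])) for name, shape in current.items()}

def ambiguous(current: Dict[str, Shape], source: Dict[str, Shape]) -> List[str]:
    """Names of current parameters that match two or more source parameters."""
    matches = candidates_by_shape(current, source)
    return [name for name, found in matches.items() if len(found) > 1]

current = {"net.0.weight": (32, 4), "net.0.bias": (32,),
           "net.2.bias": (32,), "net.4.weight": (1, 32)}
source = {"a2c_network.actor_mlp.0.weight": (32, 4),
          "a2c_network.actor_mlp.0.bias": (32,),
          "a2c_network.actor_mlp.2.bias": (32,),
          "a2c_network.value.weight": (1, 32),
          "a2c_network.mu.weight": (1, 32)}
print(ambiguous(current, source))
```

Parameters reported as ambiguous by a check like this are the ones that need an explicit entry in name_map; a uniquely shaped parameter such as `net.0.weight` can be mapped automatically.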

Parameters:
  • path (str) – Path to the external checkpoint to migrate from

  • name_map (Mapping[str, Mapping[str, str]], optional) – Name map to use for the migration (default: {}). Keys are the current parameter names and values are the external parameter names

  • auto_mapping (bool, optional) – Automatically map the external state dict to the current state dict (default: True)

  • verbose (bool, optional) – Show model names and migration (default: False)

Raises:

ValueError – If the correct file type cannot be identified from the path parameter

Returns:

True if the migration was successful, False otherwise. Migration is successful if all parameters of the current model are found in the external model

Return type:

bool

Example:

# migrate a rl_games checkpoint with ambiguous state_dict
>>> agent.migrate(path="./runs/Cartpole/nn/Cartpole.pth", verbose=False)
[skrl:WARNING] Ambiguous match for net.0.bias <- [a2c_network.actor_mlp.0.bias, a2c_network.actor_mlp.2.bias]
[skrl:WARNING] Ambiguous match for net.2.bias <- [a2c_network.actor_mlp.0.bias, a2c_network.actor_mlp.2.bias]
[skrl:WARNING] Ambiguous match for net.4.weight <- [a2c_network.value.weight, a2c_network.mu.weight]
[skrl:WARNING] Ambiguous match for net.4.bias <- [a2c_network.value.bias, a2c_network.mu.bias]
[skrl:WARNING] Multiple use of a2c_network.actor_mlp.0.bias -> [net.0.bias, net.2.bias]
[skrl:WARNING] Multiple use of a2c_network.actor_mlp.2.bias -> [net.0.bias, net.2.bias]
[skrl:WARNING] Ambiguous match for net.0.bias <- [a2c_network.actor_mlp.0.bias, a2c_network.actor_mlp.2.bias]
[skrl:WARNING] Ambiguous match for net.2.bias <- [a2c_network.actor_mlp.0.bias, a2c_network.actor_mlp.2.bias]
[skrl:WARNING] Ambiguous match for net.4.weight <- [a2c_network.value.weight, a2c_network.mu.weight]
[skrl:WARNING] Ambiguous match for net.4.bias <- [a2c_network.value.bias, a2c_network.mu.bias]
[skrl:WARNING] Multiple use of a2c_network.actor_mlp.0.bias -> [net.0.bias, net.2.bias]
[skrl:WARNING] Multiple use of a2c_network.actor_mlp.2.bias -> [net.0.bias, net.2.bias]
False
>>> name_map = {"policy": {"net.0.bias": "a2c_network.actor_mlp.0.bias",
...                        "net.2.bias": "a2c_network.actor_mlp.2.bias",
...                        "net.4.weight": "a2c_network.mu.weight",
...                        "net.4.bias": "a2c_network.mu.bias"},
...             "value": {"net.0.bias": "a2c_network.actor_mlp.0.bias",
...                       "net.2.bias": "a2c_network.actor_mlp.2.bias",
...                       "net.4.weight": "a2c_network.value.weight",
...                       "net.4.bias": "a2c_network.value.bias"}}
>>> agent.migrate(path="./runs/Cartpole/nn/Cartpole.pth", name_map=name_map, verbose=True)
[skrl:INFO] Modules
[skrl:INFO]   |-- current
[skrl:INFO]   |    |-- policy (Policy)
[skrl:INFO]   |    |    |-- log_std_parameter : [1]
[skrl:INFO]   |    |    |-- net.0.weight : [32, 4]
[skrl:INFO]   |    |    |-- net.0.bias : [32]
[skrl:INFO]   |    |    |-- net.2.weight : [32, 32]
[skrl:INFO]   |    |    |-- net.2.bias : [32]
[skrl:INFO]   |    |    |-- net.4.weight : [1, 32]
[skrl:INFO]   |    |    |-- net.4.bias : [1]
[skrl:INFO]   |    |-- value (Value)
[skrl:INFO]   |    |    |-- net.0.weight : [32, 4]
[skrl:INFO]   |    |    |-- net.0.bias : [32]
[skrl:INFO]   |    |    |-- net.2.weight : [32, 32]
[skrl:INFO]   |    |    |-- net.2.bias : [32]
[skrl:INFO]   |    |    |-- net.4.weight : [1, 32]
[skrl:INFO]   |    |    |-- net.4.bias : [1]
[skrl:INFO]   |    |-- optimizer (Adam)
[skrl:INFO]   |    |    |-- state (dict)
[skrl:INFO]   |    |    |-- param_groups (list)
[skrl:INFO]   |    |-- state_preprocessor (RunningStandardScaler)
[skrl:INFO]   |    |    |-- running_mean : [4]
[skrl:INFO]   |    |    |-- running_variance : [4]
[skrl:INFO]   |    |    |-- current_count : []
[skrl:INFO]   |    |-- value_preprocessor (RunningStandardScaler)
[skrl:INFO]   |    |    |-- running_mean : [1]
[skrl:INFO]   |    |    |-- running_variance : [1]
[skrl:INFO]   |    |    |-- current_count : []
[skrl:INFO]   |-- source
[skrl:INFO]   |    |-- model (OrderedDict)
[skrl:INFO]   |    |    |-- value_mean_std.running_mean : [1]
[skrl:INFO]   |    |    |-- value_mean_std.running_var : [1]
[skrl:INFO]   |    |    |-- value_mean_std.count : []
[skrl:INFO]   |    |    |-- running_mean_std.running_mean : [4]
[skrl:INFO]   |    |    |-- running_mean_std.running_var : [4]
[skrl:INFO]   |    |    |-- running_mean_std.count : []
[skrl:INFO]   |    |    |-- a2c_network.sigma : [1]
[skrl:INFO]   |    |    |-- a2c_network.actor_mlp.0.weight : [32, 4]
[skrl:INFO]   |    |    |-- a2c_network.actor_mlp.0.bias : [32]
[skrl:INFO]   |    |    |-- a2c_network.actor_mlp.2.weight : [32, 32]
[skrl:INFO]   |    |    |-- a2c_network.actor_mlp.2.bias : [32]
[skrl:INFO]   |    |    |-- a2c_network.value.weight : [1, 32]
[skrl:INFO]   |    |    |-- a2c_network.value.bias : [1]
[skrl:INFO]   |    |    |-- a2c_network.mu.weight : [1, 32]
[skrl:INFO]   |    |    |-- a2c_network.mu.bias : [1]
[skrl:INFO]   |    |-- epoch (int)
[skrl:INFO]   |    |-- optimizer (dict)
[skrl:INFO]   |    |-- frame (int)
[skrl:INFO]   |    |-- last_mean_rewards (float32)
[skrl:INFO]   |    |-- env_state (NoneType)
[skrl:INFO] Migration
[skrl:INFO] Model: policy (Policy)
[skrl:INFO] Models
[skrl:INFO]   |-- current: 7 items
[skrl:INFO]   |    |-- log_std_parameter : [1]
[skrl:INFO]   |    |-- net.0.weight : [32, 4]
[skrl:INFO]   |    |-- net.0.bias : [32]
[skrl:INFO]   |    |-- net.2.weight : [32, 32]
[skrl:INFO]   |    |-- net.2.bias : [32]
[skrl:INFO]   |    |-- net.4.weight : [1, 32]
[skrl:INFO]   |    |-- net.4.bias : [1]
[skrl:INFO]   |-- source: 9 items
[skrl:INFO]   |    |-- a2c_network.sigma : [1]
[skrl:INFO]   |    |-- a2c_network.actor_mlp.0.weight : [32, 4]
[skrl:INFO]   |    |-- a2c_network.actor_mlp.0.bias : [32]
[skrl:INFO]   |    |-- a2c_network.actor_mlp.2.weight : [32, 32]
[skrl:INFO]   |    |-- a2c_network.actor_mlp.2.bias : [32]
[skrl:INFO]   |    |-- a2c_network.value.weight : [1, 32]
[skrl:INFO]   |    |-- a2c_network.value.bias : [1]
[skrl:INFO]   |    |-- a2c_network.mu.weight : [1, 32]
[skrl:INFO]   |    |-- a2c_network.mu.bias : [1]
[skrl:INFO] Migration
[skrl:INFO]   |-- auto: log_std_parameter <- a2c_network.sigma
[skrl:INFO]   |-- auto: net.0.weight <- a2c_network.actor_mlp.0.weight
[skrl:INFO]   |-- map:  net.0.bias <- a2c_network.actor_mlp.0.bias
[skrl:INFO]   |-- auto: net.2.weight <- a2c_network.actor_mlp.2.weight
[skrl:INFO]   |-- map:  net.2.bias <- a2c_network.actor_mlp.2.bias
[skrl:INFO]   |-- map:  net.4.weight <- a2c_network.mu.weight
[skrl:INFO]   |-- map:  net.4.bias <- a2c_network.mu.bias
[skrl:INFO] Model: value (Value)
[skrl:INFO] Models
[skrl:INFO]   |-- current: 6 items
[skrl:INFO]   |    |-- net.0.weight : [32, 4]
[skrl:INFO]   |    |-- net.0.bias : [32]
[skrl:INFO]   |    |-- net.2.weight : [32, 32]
[skrl:INFO]   |    |-- net.2.bias : [32]
[skrl:INFO]   |    |-- net.4.weight : [1, 32]
[skrl:INFO]   |    |-- net.4.bias : [1]
[skrl:INFO]   |-- source: 9 items
[skrl:INFO]   |    |-- a2c_network.sigma : [1]
[skrl:INFO]   |    |-- a2c_network.actor_mlp.0.weight : [32, 4]
[skrl:INFO]   |    |-- a2c_network.actor_mlp.0.bias : [32]
[skrl:INFO]   |    |-- a2c_network.actor_mlp.2.weight : [32, 32]
[skrl:INFO]   |    |-- a2c_network.actor_mlp.2.bias : [32]
[skrl:INFO]   |    |-- a2c_network.value.weight : [1, 32]
[skrl:INFO]   |    |-- a2c_network.value.bias : [1]
[skrl:INFO]   |    |-- a2c_network.mu.weight : [1, 32]
[skrl:INFO]   |    |-- a2c_network.mu.bias : [1]
[skrl:INFO] Migration
[skrl:INFO]   |-- auto: net.0.weight <- a2c_network.actor_mlp.0.weight
[skrl:INFO]   |-- map:  net.0.bias <- a2c_network.actor_mlp.0.bias
[skrl:INFO]   |-- auto: net.2.weight <- a2c_network.actor_mlp.2.weight
[skrl:INFO]   |-- map:  net.2.bias <- a2c_network.actor_mlp.2.bias
[skrl:INFO]   |-- map:  net.4.weight <- a2c_network.value.weight
[skrl:INFO]   |-- map:  net.4.bias <- a2c_network.value.bias
True

post_interaction(timestep: int, timesteps: int) None

Callback called after the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) None

Callback called before the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) None

Record an environment transition in memory (to be implemented by the inheriting classes)

Inheriting classes must call this method to record episode information (rewards, timesteps, etc.). In addition to recording environment transitions (states, rewards, etc.), agent-specific information can be recorded.

Parameters:
  • states (torch.Tensor) – Observations/states of the environment used to make the decision

  • actions (torch.Tensor) – Actions taken by the agent

  • rewards (torch.Tensor) – Instant rewards achieved by the current actions

  • next_states (torch.Tensor) – Next observations/states of the environment

  • terminated (torch.Tensor) – Signals to indicate that episodes have terminated

  • truncated (torch.Tensor) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps
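How the callbacks fit together is easiest to see in a loop. The sketch below assumes a per-timestep order of pre_interaction, act, environment step, record_transition, post_interaction; that order, the `LoggingAgent` class, and the `toy_env_step` function are assumptions made for illustration and are not taken from this page.

```python
from typing import List

class LoggingAgent:
    """Hypothetical agent that only records the order in which its hooks run."""
    def __init__(self) -> None:
        self.calls: List[str] = []

    def pre_interaction(self, timestep: int, timesteps: int) -> None:
        self.calls.append("pre_interaction")

    def act(self, states, timestep: int, timesteps: int):
        self.calls.append("act")
        return [0 for _ in states]

    def record_transition(self, states, actions, rewards, next_states,
                          terminated, truncated, infos, timestep: int, timesteps: int) -> None:
        self.calls.append("record_transition")

    def post_interaction(self, timestep: int, timesteps: int) -> None:
        self.calls.append("post_interaction")

def toy_env_step(states, actions):
    """Made-up environment: every state advances by one and yields reward 1.0."""
    next_states = [s + 1 for s in states]
    rewards = [1.0 for _ in states]
    done = [False for _ in states]
    return next_states, rewards, done, done, {}

def run(agent: LoggingAgent, timesteps: int) -> None:
    """Assumed per-timestep call order of a training loop."""
    states = [0]
    for timestep in range(timesteps):
        agent.pre_interaction(timestep=timestep, timesteps=timesteps)
        actions = agent.act(states, timestep=timestep, timesteps=timesteps)
        next_states, rewards, terminated, truncated, infos = toy_env_step(states, actions)
        agent.record_transition(states, actions, rewards, next_states,
                                terminated, truncated, infos, timestep, timesteps)
        agent.post_interaction(timestep=timestep, timesteps=timesteps)
        states = next_states

agent = LoggingAgent()
run(agent, timesteps=2)
print(agent.calls[:4])
```

Whether an algorithm triggers its update from pre_interaction or from post_interaction is a per-agent choice, as the comments in the inheritance template suggest.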

save(path: str) None

Save the agent to the specified path

Parameters:

path (str) – Path to save the model to

set_mode(mode: str) None

Set the model mode (training or evaluation)

Parameters:

mode (str) – Mode: ‘train’ for training or ‘eval’ for evaluation

set_running_mode(mode: str) None

Set the current running mode (training or evaluation)

This method sets the value of the training property (boolean). This property can be used to know if the agent is running in training or evaluation mode.

Parameters:

mode (str) – Mode: ‘train’ for training or ‘eval’ for evaluation
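As a hedged illustration of how the boolean training property might be consumed, the sketch below gates epsilon-greedy exploration on it. `ModeAwareAgent` is a hypothetical class, and treating every value other than 'train' as evaluation is an assumption of this sketch.

```python
import random

class ModeAwareAgent:
    """Hypothetical agent whose behavior depends on a boolean ``training`` flag."""
    def __init__(self) -> None:
        self.training = False

    def set_running_mode(self, mode: str) -> None:
        # 'train' switches the flag on; any other value (e.g. 'eval') switches it off
        self.training = mode == "train"

    def act(self, greedy_action: int, num_actions: int, epsilon: float = 0.1) -> int:
        # explore only while training; evaluation always takes the greedy action
        if self.training and random.random() < epsilon:
            return random.randrange(num_actions)
        return greedy_action

agent = ModeAwareAgent()
agent.set_running_mode("eval")
print(agent.training, agent.act(greedy_action=2, num_actions=4))
```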

track_data(tag: str, value: float) None

Track data to TensorBoard

Currently only scalar data are supported

Parameters:
  • tag (str) – Data identifier (e.g. ‘Loss / policy loss’)

  • value (float) – Value to track
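A minimal sketch of the track-then-write pattern follows. It assumes that values tracked under the same tag are averaged when they are written; that aggregation, the `ScalarTracker` class, and the `written` list standing in for a TensorBoard writer are assumptions of this example.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, List

class ScalarTracker:
    """Hypothetical tracker: collects scalars per tag, then flushes their mean."""
    def __init__(self) -> None:
        self.tracking_data: Dict[str, List[float]] = defaultdict(list)
        self.written: List[tuple] = []   # stands in for a TensorBoard writer

    def track_data(self, tag: str, value: float) -> None:
        # only scalar values are accepted, mirroring the restriction stated above
        self.tracking_data[tag].append(float(value))

    def write_tracking_data(self, timestep: int, timesteps: int) -> None:
        # emit one (tag, mean value, timestep) record per tag, then reset the buffers
        for tag, values in self.tracking_data.items():
            self.written.append((tag, mean(values), timestep))
        self.tracking_data.clear()

tracker = ScalarTracker()
tracker.track_data("Loss / policy loss", 0.5)
tracker.track_data("Loss / policy loss", 1.5)
tracker.write_tracking_data(timestep=250, timesteps=1000)
print(tracker.written)
```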

write_checkpoint(timestep: int, timesteps: int) None

Write checkpoint (modules) to disk

The checkpoints are saved in the directory ‘checkpoints’ in the experiment directory. The name of the checkpoint is the current timestep if timestep is not None, otherwise it is the current time.

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps
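The naming rule can be sketched as a small path helper. The `agent_` prefix, the `.pt` extension, and the timestamp format are assumptions of this sketch; only the 'checkpoints' directory and the timestep-or-current-time rule come from the description above.

```python
import datetime
import os
from typing import Optional

def checkpoint_path(experiment_dir: str,
                    timestep: Optional[int],
                    now: Optional[datetime.datetime] = None) -> str:
    """Build a checkpoint path: the timestep if given, otherwise a timestamp."""
    if timestep is not None:
        tag = str(timestep)
    else:
        # `now` can be injected to make the result reproducible
        now = now or datetime.datetime.now()
        tag = now.strftime("%y-%m-%d_%H-%M-%S-%f")
    # hypothetical file name pattern: agent_<tag>.pt inside 'checkpoints'
    return os.path.join(experiment_dir, "checkpoints", f"agent_{tag}.pt")

print(checkpoint_path("runs/experiment", 1000))
print(checkpoint_path("runs/experiment", None))
```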

write_tracking_data(timestep: int, timesteps: int) None

Write tracking data to TensorBoard

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps


API (JAX)

class skrl.agents.jax.base.Agent(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None)

Bases: object

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None) None

Base class that represents an RL agent

Parameters:
  • models (dictionary of skrl.models.jax.Model) – Models used by the agent

  • memory (skrl.memories.jax.Memory, list of skrl.memories.jax.Memory or None) – Memory to store the transitions. If it is a tuple, the first element will be used for training and for the rest only the environment transitions will be added

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or jax.Device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary

__str__() str

Generate a representation of the agent as string

Returns:

Representation of the agent as string

Return type:

str

_empty_preprocessor(_input: Any, *args, **kwargs) Any

Empty preprocess method

This method is defined because PyTorch multiprocessing can’t pickle lambdas

Parameters:

_input (Any) – Input to preprocess

Returns:

Preprocessed input

Return type:

Any

_get_internal_value(_module: Any) Any

Get internal module/variable state/value

Parameters:

_module (Any) – Module or variable

Returns:

Module/variable state/value

Return type:

Any

_update(timestep: int, timesteps: int) None

Algorithm’s main update step

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

Raises:

NotImplementedError – The method is not implemented by the inheriting classes

act(states: ndarray | jax.Array, timestep: int, timesteps: int) ndarray | jax.Array

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
  • states (np.ndarray or jax.Array) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

Raises:

NotImplementedError – The method is not implemented by the inheriting classes

Returns:

Actions

Return type:

np.ndarray or jax.Array

init(trainer_cfg: Mapping[str, Any] | None = None) None

Initialize the agent

This method should be called before the agent is used. It will initialize the TensorBoard writer (and optionally Weights & Biases) and create the checkpoints directory

Parameters:

trainer_cfg (dict, optional) – Trainer configuration

load(path: str) None

Load the model from the specified path

Parameters:

path (str) – Path to load the model from

migrate(path: str, name_map: Mapping[str, Mapping[str, str]] = {}, auto_mapping: bool = True, verbose: bool = False) bool

Migrate the specified external checkpoint to the current agent

Raises:

NotImplementedError – Not yet implemented

post_interaction(timestep: int, timesteps: int) None

Callback called after the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) None

Callback called before the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: ndarray | jax.Array, actions: ndarray | jax.Array, rewards: ndarray | jax.Array, next_states: ndarray | jax.Array, terminated: ndarray | jax.Array, truncated: ndarray | jax.Array, infos: Any, timestep: int, timesteps: int) None

Record an environment transition in memory (to be implemented by the inheriting classes)

Inheriting classes must call this method to record episode information (rewards, timesteps, etc.). In addition to recording environment transitions (states, rewards, etc.), agent-specific information can be recorded.

Parameters:
  • states (np.ndarray or jax.Array) – Observations/states of the environment used to make the decision

  • actions (np.ndarray or jax.Array) – Actions taken by the agent

  • rewards (np.ndarray or jax.Array) – Instant rewards achieved by the current actions

  • next_states (np.ndarray or jax.Array) – Next observations/states of the environment

  • terminated (np.ndarray or jax.Array) – Signals to indicate that episodes have terminated

  • truncated (np.ndarray or jax.Array) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

save(path: str) None

Save the agent to the specified path

Parameters:

path (str) – Path to save the model to

set_mode(mode: str) None

Set the model mode (training or evaluation)

Parameters:

mode (str) – Mode: ‘train’ for training or ‘eval’ for evaluation

set_running_mode(mode: str) None

Set the current running mode (training or evaluation)

This method sets the value of the training property (boolean). This property can be used to know if the agent is running in training or evaluation mode.

Parameters:

mode (str) – Mode: ‘train’ for training or ‘eval’ for evaluation

track_data(tag: str, value: float) None

Track data to TensorBoard

Currently only scalar data are supported

Parameters:
  • tag (str) – Data identifier (e.g. ‘Loss / policy loss’)

  • value (float) – Value to track

write_checkpoint(timestep: int, timesteps: int) None

Write checkpoint (modules) to disk

The checkpoints are saved in the directory ‘checkpoints’ in the experiment directory. The name of the checkpoint is the current timestep if timestep is not None, otherwise it is the current time.

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

write_tracking_data(timestep: int, timesteps: int) None

Write tracking data to TensorBoard

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps