Agents¶
Agents are autonomous entities that interact with the environment to learn and improve their behavior. An agent's goal is to learn an optimal policy, i.e., a mapping from states to actions that maximizes the expected cumulative reward received from the environment over time.
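In standard reinforcement learning notation (a generic formulation, not an skrl-specific API), this objective can be sketched as finding the policy \(\pi^{*}\) that maximizes the expected discounted return:
\[\pi^{*} = \underset{\pi}{\arg\max}\; \mathbb{E}_{\pi}\!\left[\sum_{t=0}^{\infty} \gamma^{t}\, r_{t}\right]\]
where \(r_{t}\) is the reward received at timestep \(t\) and \(\gamma \in [0, 1)\) is the discount factor.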
Agents | | | |
---|---|---|---|---
Advantage Actor Critic (A2C) | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\) | \(\square\)
Cross-Entropy Method (CEM) | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\)
Double Deep Q-Network (DDQN) | \(\blacksquare\) | \(\blacksquare\) | |
Deep Q-Network (DQN) | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\)
Q-learning (Q-learning) | \(\blacksquare\) | \(\square\) | \(\blacksquare\) | \(\blacksquare\)
Soft Actor-Critic (SAC) | \(\blacksquare\) | \(\blacksquare\) | |
State Action Reward State Action (SARSA) | \(\blacksquare\) | \(\square\) | |
Twin-Delayed DDPG (TD3) | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\) | \(\square\)
Base class¶
Note
This is the base class for all agents in this module and provides only basic functionality that is not tied to any implementation of the optimization algorithms. It is not intended to be used directly.
Basic inheritance usage¶
from typing import Union, Tuple, Dict, Any, Optional

import gym, gymnasium

import copy
import torch

from skrl.memories.torch import Memory
from skrl.models.torch import Model

from skrl.agents.torch import Agent


CUSTOM_DEFAULT_CONFIG = {
    # ...

    "experiment": {
        "directory": "",                # experiment's parent directory
        "experiment_name": "",          # experiment name
        "write_interval": 250,          # TensorBoard writing interval (timesteps)

        "checkpoint_interval": 1000,    # interval for checkpoints (timesteps)
        "store_separately": False,      # whether to store checkpoints separately

        "wandb": False,                 # whether to use Weights & Biases
        "wandb_kwargs": {}              # wandb kwargs (see https://docs.wandb.ai/ref/python/init)
    }
}


class CUSTOM(Agent):
    def __init__(self,
                 models: Dict[str, Model],
                 memory: Optional[Union[Memory, Tuple[Memory]]] = None,
                 observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
                 action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
                 device: Optional[Union[str, torch.device]] = None,
                 cfg: Optional[dict] = None) -> None:
        """Custom agent

        :param models: Models used by the agent
        :type models: dictionary of skrl.models.torch.Model
        :param memory: Memory to store the transitions.
                       If it is a tuple, the first element will be used for training and
                       for the rest only the environment transitions will be added
        :type memory: skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None
        :param observation_space: Observation/state space or shape (default: None)
        :type observation_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional
        :param action_space: Action space or shape (default: None)
        :type action_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional
        :param device: Device on which a torch tensor is or will be allocated (default: ``None``).
                       If None, the device will be either ``"cuda:0"`` if available or ``"cpu"``
        :type device: str or torch.device, optional
        :param cfg: Configuration dictionary
        :type cfg: dict
        """
        _cfg = copy.deepcopy(CUSTOM_DEFAULT_CONFIG)
        _cfg.update(cfg if cfg is not None else {})
        super().__init__(models=models,
                         memory=memory,
                         observation_space=observation_space,
                         action_space=action_space,
                         device=device,
                         cfg=_cfg)
        # =======================================================================
        # - get and process models from `self.models`
        # - populate `self.checkpoint_modules` dictionary for storing checkpoints
        # - parse configurations from `self.cfg`
        # - set up optimizers and learning rate scheduler
        # - set up preprocessors
        # =======================================================================

    def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
        """Initialize the agent
        """
        super().init(trainer_cfg=trainer_cfg)
        self.set_mode("eval")
        # =================================================================
        # - create tensors in memory if required
        # - create temporary variables needed for storage and computation
        # =================================================================

    def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tensor:
        """Process the environment's states to make a decision (actions) using the main policy

        :param states: Environment's states
        :type states: torch.Tensor
        :param timestep: Current timestep
        :type timestep: int
        :param timesteps: Number of timesteps
        :type timesteps: int

        :return: Actions
        :rtype: torch.Tensor
        """
        # ======================================
        # - sample random actions if required or
        #   sample and return agent's actions
        # ======================================

    def record_transition(self,
                          states: torch.Tensor,
                          actions: torch.Tensor,
                          rewards: torch.Tensor,
                          next_states: torch.Tensor,
                          terminated: torch.Tensor,
                          truncated: torch.Tensor,
                          infos: Any,
                          timestep: int,
                          timesteps: int) -> None:
        """Record an environment transition in memory

        :param states: Observations/states of the environment used to make the decision
        :type states: torch.Tensor
        :param actions: Actions taken by the agent
        :type actions: torch.Tensor
        :param rewards: Instant rewards achieved by the current actions
        :type rewards: torch.Tensor
        :param next_states: Next observations/states of the environment
        :type next_states: torch.Tensor
        :param terminated: Signals to indicate that episodes have terminated
        :type terminated: torch.Tensor
        :param truncated: Signals to indicate that episodes have been truncated
        :type truncated: torch.Tensor
        :param infos: Additional information about the environment
        :type infos: Any type supported by the environment
        :param timestep: Current timestep
        :type timestep: int
        :param timesteps: Number of timesteps
        :type timesteps: int
        """
        super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps)
        # ========================================
        # - record agent's specific data in memory
        # ========================================

    def pre_interaction(self, timestep: int, timesteps: int) -> None:
        """Callback called before the interaction with the environment

        :param timestep: Current timestep
        :type timestep: int
        :param timesteps: Number of timesteps
        :type timesteps: int
        """
        # ======================================
        # - call `self._update(...)` if required
        # ======================================

    def post_interaction(self, timestep: int, timesteps: int) -> None:
        """Callback called after the interaction with the environment

        :param timestep: Current timestep
        :type timestep: int
        :param timesteps: Number of timesteps
        :type timesteps: int
        """
        # ======================================
        # - call `self._update(...)` if required
        # ======================================
        # call parent's method for checkpointing and TensorBoard writing
        super().post_interaction(timestep, timesteps)

    def _update(self, timestep: int, timesteps: int) -> None:
        """Algorithm's main update step

        :param timestep: Current timestep
        :type timestep: int
        :param timesteps: Number of timesteps
        :type timesteps: int
        """
        # ===================================================
        # - implement algorithm's update step
        # - record tracking data using `self.track_data(...)`
        # ===================================================
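The PyTorch template above can then be instantiated like any other agent. The snippet below is a minimal, non-normative sketch: Policy is a hypothetical user-defined model class and env is assumed to be an environment already wrapped by skrl.
# minimal instantiation sketch (hypothetical placeholders: `Policy` model class, skrl-wrapped `env`)
models = {"policy": Policy(env.observation_space, env.action_space, env.device)}

cfg = copy.deepcopy(CUSTOM_DEFAULT_CONFIG)
cfg["experiment"]["write_interval"] = 500   # override any default before passing the config to the agent

agent = CUSTOM(models=models,
               memory=None,   # e.g. a skrl.memories.torch.RandomMemory instance for off-policy agents
               observation_space=env.observation_space,
               action_space=env.action_space,
               device=env.device,
               cfg=cfg)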
from typing import Union, Tuple, Dict, Any, Optional

import gym, gymnasium

import copy
import jaxlib
import jax.numpy as jnp

from skrl.memories.jax import Memory
from skrl.models.jax import Model
from skrl.resources.optimizers.jax import Adam

from skrl.agents.jax import Agent


CUSTOM_DEFAULT_CONFIG = {
    # ...

    "experiment": {
        "directory": "",                # experiment's parent directory
        "experiment_name": "",          # experiment name
        "write_interval": 250,          # TensorBoard writing interval (timesteps)

        "checkpoint_interval": 1000,    # interval for checkpoints (timesteps)
        "store_separately": False,      # whether to store checkpoints separately

        "wandb": False,                 # whether to use Weights & Biases
        "wandb_kwargs": {}              # wandb kwargs (see https://docs.wandb.ai/ref/python/init)
    }
}


class CUSTOM(Agent):
    def __init__(self,
                 models: Dict[str, Model],
                 memory: Optional[Union[Memory, Tuple[Memory]]] = None,
                 observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
                 action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
                 device: Optional[Union[str, jaxlib.xla_extension.Device]] = None,
                 cfg: Optional[dict] = None) -> None:
        """Custom agent

        :param models: Models used by the agent
        :type models: dictionary of skrl.models.jax.Model
        :param memory: Memory to store the transitions.
                       If it is a tuple, the first element will be used for training and
                       for the rest only the environment transitions will be added
        :type memory: skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None
        :param observation_space: Observation/state space or shape (default: None)
        :type observation_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional
        :param action_space: Action space or shape (default: None)
        :type action_space: int, tuple or list of integers, gym.Space, gymnasium.Space or None, optional
        :param device: Device on which a jax array is or will be allocated (default: ``None``).
                       If None, the device will be either ``"cuda:0"`` if available or ``"cpu"``
        :type device: str or jaxlib.xla_extension.Device, optional
        :param cfg: Configuration dictionary
        :type cfg: dict
        """
        _cfg = copy.deepcopy(CUSTOM_DEFAULT_CONFIG)
        _cfg.update(cfg if cfg is not None else {})
        super().__init__(models=models,
                         memory=memory,
                         observation_space=observation_space,
                         action_space=action_space,
                         device=device,
                         cfg=_cfg)
        # =======================================================================
        # - get and process models from `self.models`
        # - populate `self.checkpoint_modules` dictionary for storing checkpoints
        # - parse configurations from `self.cfg`
        # - set up optimizers and learning rate scheduler
        # - set up preprocessors
        # =======================================================================

    def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
        """Initialize the agent
        """
        super().init(trainer_cfg=trainer_cfg)
        self.set_mode("eval")
        # =================================================================
        # - create tensors in memory if required
        # - create temporary variables needed for storage and computation
        # - set up models for just-in-time compilation with XLA
        # =================================================================

    def act(self, states: jnp.ndarray, timestep: int, timesteps: int) -> jnp.ndarray:
        """Process the environment's states to make a decision (actions) using the main policy

        :param states: Environment's states
        :type states: jnp.ndarray
        :param timestep: Current timestep
        :type timestep: int
        :param timesteps: Number of timesteps
        :type timesteps: int

        :return: Actions
        :rtype: jnp.ndarray
        """
        # ======================================
        # - sample random actions if required or
        #   sample and return agent's actions
        # ======================================

    def record_transition(self,
                          states: jnp.ndarray,
                          actions: jnp.ndarray,
                          rewards: jnp.ndarray,
                          next_states: jnp.ndarray,
                          terminated: jnp.ndarray,
                          truncated: jnp.ndarray,
                          infos: Any,
                          timestep: int,
                          timesteps: int) -> None:
        """Record an environment transition in memory

        :param states: Observations/states of the environment used to make the decision
        :type states: jnp.ndarray
        :param actions: Actions taken by the agent
        :type actions: jnp.ndarray
        :param rewards: Instant rewards achieved by the current actions
        :type rewards: jnp.ndarray
        :param next_states: Next observations/states of the environment
        :type next_states: jnp.ndarray
        :param terminated: Signals to indicate that episodes have terminated
        :type terminated: jnp.ndarray
        :param truncated: Signals to indicate that episodes have been truncated
        :type truncated: jnp.ndarray
        :param infos: Additional information about the environment
        :type infos: Any type supported by the environment
        :param timestep: Current timestep
        :type timestep: int
        :param timesteps: Number of timesteps
        :type timesteps: int
        """
        super().record_transition(states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps)
        # ========================================
        # - record agent's specific data in memory
        # ========================================

    def pre_interaction(self, timestep: int, timesteps: int) -> None:
        """Callback called before the interaction with the environment

        :param timestep: Current timestep
        :type timestep: int
        :param timesteps: Number of timesteps
        :type timesteps: int
        """
        # ======================================
        # - call `self._update(...)` if required
        # ======================================

    def post_interaction(self, timestep: int, timesteps: int) -> None:
        """Callback called after the interaction with the environment

        :param timestep: Current timestep
        :type timestep: int
        :param timesteps: Number of timesteps
        :type timesteps: int
        """
        # ======================================
        # - call `self._update(...)` if required
        # ======================================
        # call parent's method for checkpointing and TensorBoard writing
        super().post_interaction(timestep, timesteps)

    def _update(self, timestep: int, timesteps: int) -> None:
        """Algorithm's main update step

        :param timestep: Current timestep
        :type timestep: int
        :param timesteps: Number of timesteps
        :type timesteps: int
        """
        # ===================================================
        # - implement algorithm's update step
        # - record tracking data using `self.track_data(...)`
        # ===================================================
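Both templates are meant to be driven by one of skrl's trainers, which call pre_interaction(), act(), record_transition() and post_interaction() at every timestep. The snippet below is a minimal sketch using the PyTorch SequentialTrainer; the agent and env objects are assumed to exist already, and the JAX workflow is analogous with the trainers in skrl.trainers.jax.
from skrl.trainers.torch import SequentialTrainer

# assumptions: `agent` is an instance of the custom agent and `env` is a skrl-wrapped environment
cfg_trainer = {"timesteps": 100000, "headless": True}
trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent)

# the trainer loops: pre_interaction -> act -> env.step -> record_transition -> post_interaction
trainer.train()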
API (PyTorch)¶
- class skrl.agents.torch.base.Agent(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)¶
Bases: object
- __init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) None ¶
Base class that represents an RL agent
- Parameters:
models (dictionary of skrl.models.torch.Model) – Models used by the agent
memory (skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None) – Memory to store the transitions. If it is a tuple, the first element will be used for training and for the rest only the environment transitions will be added
observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)
action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)
device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"
cfg (dict) – Configuration dictionary
- __str__() str ¶
Generate a representation of the agent as a string
- Returns:
Representation of the agent as a string
- Return type:
str
- _empty_preprocessor(_input: Any, *args, **kwargs) Any ¶
Empty preprocess method
This method is defined because PyTorch multiprocessing can’t pickle lambdas
- Parameters:
_input (Any) – Input to preprocess
- Returns:
Preprocessed input
- Return type:
Any
- _get_internal_value(_module: Any) Any ¶
Get internal module/variable state/value
- Parameters:
_module (Any) – Module or variable
- Returns:
Module/variable state/value
- Return type:
Any
- _update(timestep: int, timesteps: int) None ¶
Algorithm’s main update step
- Parameters:
timestep (int) – Current timestep
timesteps (int) – Number of timesteps
- Raises:
NotImplementedError – The method is not implemented by the inheriting classes
- act(states: torch.Tensor, timestep: int, timesteps: int) torch.Tensor ¶
Process the environment’s states to make a decision (actions) using the main policy
- Parameters:
states (torch.Tensor) – Environment’s states
timestep (int) – Current timestep
timesteps (int) – Number of timesteps
- Raises:
NotImplementedError – The method is not implemented by the inheriting classes
- Returns:
Actions
- Return type:
torch.Tensor
- init(trainer_cfg: Mapping[str, Any] | None = None) None ¶
Initialize the agent
This method should be called before the agent is used. It will initialize the TensorBoard writer (and optionally Weights & Biases) and create the checkpoints directory
- Parameters:
trainer_cfg (dict, optional) – Trainer configuration
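For example, when the agent is driven without a trainer (a hypothetical manual interaction loop), this method would be called once up front; a minimal sketch:
# hypothetical manual setup: initialize once, then select the running mode before interacting
agent.init()
agent.set_running_mode("train")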
- load(path: str) None ¶
Load the model from the specified path
The final storage device is determined by the constructor of the model
- Parameters:
path (str) – Path to load the model from
- migrate(path: str, name_map: Mapping[str, Mapping[str, str]] = {}, auto_mapping: bool = True, verbose: bool = False) bool ¶
Migrate the specified external checkpoint to the current agent
The final storage device is determined by the constructor of the agent. Only files generated by the rl_games library are supported at the moment
For ambiguous models (where 2 or more parameters, for source or current model, have equal shape) it is necessary to define the name_map, at least for those parameters, to perform the migration successfully
- Parameters:
path (str) – Path to the external checkpoint to migrate from
name_map (Mapping[str, Mapping[str, str]], optional) – Name map to use for the migration (default: {}). Keys are the current parameter names and values are the external parameter names
auto_mapping (bool, optional) – Automatically map the external state dict to the current state dict (default: True)
verbose (bool, optional) – Show model names and migration (default: False)
- Raises:
ValueError – If the correct file type cannot be identified from the path parameter
- Returns:
True if the migration was successful, False otherwise. Migration is successful if all parameters of the current model are found in the external model
- Return type:
bool
Example:
# migrate a rl_games checkpoint with ambiguous state_dict
>>> agent.migrate(path="./runs/Cartpole/nn/Cartpole.pth", verbose=False)
[skrl:WARNING] Ambiguous match for net.0.bias <- [a2c_network.actor_mlp.0.bias, a2c_network.actor_mlp.2.bias]
[skrl:WARNING] Ambiguous match for net.2.bias <- [a2c_network.actor_mlp.0.bias, a2c_network.actor_mlp.2.bias]
[skrl:WARNING] Ambiguous match for net.4.weight <- [a2c_network.value.weight, a2c_network.mu.weight]
[skrl:WARNING] Ambiguous match for net.4.bias <- [a2c_network.value.bias, a2c_network.mu.bias]
[skrl:WARNING] Multiple use of a2c_network.actor_mlp.0.bias -> [net.0.bias, net.2.bias]
[skrl:WARNING] Multiple use of a2c_network.actor_mlp.2.bias -> [net.0.bias, net.2.bias]
[skrl:WARNING] Ambiguous match for net.0.bias <- [a2c_network.actor_mlp.0.bias, a2c_network.actor_mlp.2.bias]
[skrl:WARNING] Ambiguous match for net.2.bias <- [a2c_network.actor_mlp.0.bias, a2c_network.actor_mlp.2.bias]
[skrl:WARNING] Ambiguous match for net.4.weight <- [a2c_network.value.weight, a2c_network.mu.weight]
[skrl:WARNING] Ambiguous match for net.4.bias <- [a2c_network.value.bias, a2c_network.mu.bias]
[skrl:WARNING] Multiple use of a2c_network.actor_mlp.0.bias -> [net.0.bias, net.2.bias]
[skrl:WARNING] Multiple use of a2c_network.actor_mlp.2.bias -> [net.0.bias, net.2.bias]
False
>>> name_map = {"policy": {"net.0.bias": "a2c_network.actor_mlp.0.bias",
...                        "net.2.bias": "a2c_network.actor_mlp.2.bias",
...                        "net.4.weight": "a2c_network.mu.weight",
...                        "net.4.bias": "a2c_network.mu.bias"},
...             "value": {"net.0.bias": "a2c_network.actor_mlp.0.bias",
...                       "net.2.bias": "a2c_network.actor_mlp.2.bias",
...                       "net.4.weight": "a2c_network.value.weight",
...                       "net.4.bias": "a2c_network.value.bias"}}
>>> agent.migrate(path="./runs/Cartpole/nn/Cartpole.pth", name_map=name_map, verbose=True)
[skrl:INFO] Modules
[skrl:INFO] |-- current
[skrl:INFO] |   |-- policy (Policy)
[skrl:INFO] |   |   |-- log_std_parameter : [1]
[skrl:INFO] |   |   |-- net.0.weight : [32, 4]
[skrl:INFO] |   |   |-- net.0.bias : [32]
[skrl:INFO] |   |   |-- net.2.weight : [32, 32]
[skrl:INFO] |   |   |-- net.2.bias : [32]
[skrl:INFO] |   |   |-- net.4.weight : [1, 32]
[skrl:INFO] |   |   |-- net.4.bias : [1]
[skrl:INFO] |   |-- value (Value)
[skrl:INFO] |   |   |-- net.0.weight : [32, 4]
[skrl:INFO] |   |   |-- net.0.bias : [32]
[skrl:INFO] |   |   |-- net.2.weight : [32, 32]
[skrl:INFO] |   |   |-- net.2.bias : [32]
[skrl:INFO] |   |   |-- net.4.weight : [1, 32]
[skrl:INFO] |   |   |-- net.4.bias : [1]
[skrl:INFO] |   |-- optimizer (Adam)
[skrl:INFO] |   |   |-- state (dict)
[skrl:INFO] |   |   |-- param_groups (list)
[skrl:INFO] |   |-- state_preprocessor (RunningStandardScaler)
[skrl:INFO] |   |   |-- running_mean : [4]
[skrl:INFO] |   |   |-- running_variance : [4]
[skrl:INFO] |   |   |-- current_count : []
[skrl:INFO] |   |-- value_preprocessor (RunningStandardScaler)
[skrl:INFO] |   |   |-- running_mean : [1]
[skrl:INFO] |   |   |-- running_variance : [1]
[skrl:INFO] |   |   |-- current_count : []
[skrl:INFO] |-- source
[skrl:INFO] |   |-- model (OrderedDict)
[skrl:INFO] |   |   |-- value_mean_std.running_mean : [1]
[skrl:INFO] |   |   |-- value_mean_std.running_var : [1]
[skrl:INFO] |   |   |-- value_mean_std.count : []
[skrl:INFO] |   |   |-- running_mean_std.running_mean : [4]
[skrl:INFO] |   |   |-- running_mean_std.running_var : [4]
[skrl:INFO] |   |   |-- running_mean_std.count : []
[skrl:INFO] |   |   |-- a2c_network.sigma : [1]
[skrl:INFO] |   |   |-- a2c_network.actor_mlp.0.weight : [32, 4]
[skrl:INFO] |   |   |-- a2c_network.actor_mlp.0.bias : [32]
[skrl:INFO] |   |   |-- a2c_network.actor_mlp.2.weight : [32, 32]
[skrl:INFO] |   |   |-- a2c_network.actor_mlp.2.bias : [32]
[skrl:INFO] |   |   |-- a2c_network.value.weight : [1, 32]
[skrl:INFO] |   |   |-- a2c_network.value.bias : [1]
[skrl:INFO] |   |   |-- a2c_network.mu.weight : [1, 32]
[skrl:INFO] |   |   |-- a2c_network.mu.bias : [1]
[skrl:INFO] |   |-- epoch (int)
[skrl:INFO] |   |-- optimizer (dict)
[skrl:INFO] |   |-- frame (int)
[skrl:INFO] |   |-- last_mean_rewards (float32)
[skrl:INFO] |   |-- env_state (NoneType)
[skrl:INFO] Migration
[skrl:INFO] Model: policy (Policy)
[skrl:INFO] Models
[skrl:INFO] |-- current: 7 items
[skrl:INFO] |   |-- log_std_parameter : [1]
[skrl:INFO] |   |-- net.0.weight : [32, 4]
[skrl:INFO] |   |-- net.0.bias : [32]
[skrl:INFO] |   |-- net.2.weight : [32, 32]
[skrl:INFO] |   |-- net.2.bias : [32]
[skrl:INFO] |   |-- net.4.weight : [1, 32]
[skrl:INFO] |   |-- net.4.bias : [1]
[skrl:INFO] |-- source: 9 items
[skrl:INFO] |   |-- a2c_network.sigma : [1]
[skrl:INFO] |   |-- a2c_network.actor_mlp.0.weight : [32, 4]
[skrl:INFO] |   |-- a2c_network.actor_mlp.0.bias : [32]
[skrl:INFO] |   |-- a2c_network.actor_mlp.2.weight : [32, 32]
[skrl:INFO] |   |-- a2c_network.actor_mlp.2.bias : [32]
[skrl:INFO] |   |-- a2c_network.value.weight : [1, 32]
[skrl:INFO] |   |-- a2c_network.value.bias : [1]
[skrl:INFO] |   |-- a2c_network.mu.weight : [1, 32]
[skrl:INFO] |   |-- a2c_network.mu.bias : [1]
[skrl:INFO] Migration
[skrl:INFO] |-- auto: log_std_parameter <- a2c_network.sigma
[skrl:INFO] |-- auto: net.0.weight <- a2c_network.actor_mlp.0.weight
[skrl:INFO] |-- map: net.0.bias <- a2c_network.actor_mlp.0.bias
[skrl:INFO] |-- auto: net.2.weight <- a2c_network.actor_mlp.2.weight
[skrl:INFO] |-- map: net.2.bias <- a2c_network.actor_mlp.2.bias
[skrl:INFO] |-- map: net.4.weight <- a2c_network.mu.weight
[skrl:INFO] |-- map: net.4.bias <- a2c_network.mu.bias
[skrl:INFO] Model: value (Value)
[skrl:INFO] Models
[skrl:INFO] |-- current: 6 items
[skrl:INFO] |   |-- net.0.weight : [32, 4]
[skrl:INFO] |   |-- net.0.bias : [32]
[skrl:INFO] |   |-- net.2.weight : [32, 32]
[skrl:INFO] |   |-- net.2.bias : [32]
[skrl:INFO] |   |-- net.4.weight : [1, 32]
[skrl:INFO] |   |-- net.4.bias : [1]
[skrl:INFO] |-- source: 9 items
[skrl:INFO] |   |-- a2c_network.sigma : [1]
[skrl:INFO] |   |-- a2c_network.actor_mlp.0.weight : [32, 4]
[skrl:INFO] |   |-- a2c_network.actor_mlp.0.bias : [32]
[skrl:INFO] |   |-- a2c_network.actor_mlp.2.weight : [32, 32]
[skrl:INFO] |   |-- a2c_network.actor_mlp.2.bias : [32]
[skrl:INFO] |   |-- a2c_network.value.weight : [1, 32]
[skrl:INFO] |   |-- a2c_network.value.bias : [1]
[skrl:INFO] |   |-- a2c_network.mu.weight : [1, 32]
[skrl:INFO] |   |-- a2c_network.mu.bias : [1]
[skrl:INFO] Migration
[skrl:INFO] |-- auto: net.0.weight <- a2c_network.actor_mlp.0.weight
[skrl:INFO] |-- map: net.0.bias <- a2c_network.actor_mlp.0.bias
[skrl:INFO] |-- auto: net.2.weight <- a2c_network.actor_mlp.2.weight
[skrl:INFO] |-- map: net.2.bias <- a2c_network.actor_mlp.2.bias
[skrl:INFO] |-- map: net.4.weight <- a2c_network.value.weight
[skrl:INFO] |-- map: net.4.bias <- a2c_network.value.bias
True
- post_interaction(timestep: int, timesteps: int) None ¶
Callback called after the interaction with the environment
- pre_interaction(timestep: int, timesteps: int) None ¶
Callback called before the interaction with the environment
- record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) None ¶
Record an environment transition in memory (to be implemented by the inheriting classes)
Inheriting classes must call this method to record episode information (rewards, timesteps, etc.). In addition to recording environment transition (such as states, rewards, etc.), agent information can be recorded.
- Parameters:
states (torch.Tensor) – Observations/states of the environment used to make the decision
actions (torch.Tensor) – Actions taken by the agent
rewards (torch.Tensor) – Instant rewards achieved by the current actions
next_states (torch.Tensor) – Next observations/states of the environment
terminated (torch.Tensor) – Signals to indicate that episodes have terminated
truncated (torch.Tensor) – Signals to indicate that episodes have been truncated
infos (Any type supported by the environment) – Additional information about the environment
timestep (int) – Current timestep
timesteps (int) – Number of timesteps
- save(path: str) None ¶
Save the agent to the specified path
- Parameters:
path (str) – Path to save the model to
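A minimal save/load sketch (the checkpoint path is illustrative only):
# write a checkpoint and restore it later (hypothetical path)
agent.save("./runs/experiment/checkpoints/agent.pt")
agent.load("./runs/experiment/checkpoints/agent.pt")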
- set_mode(mode: str) None ¶
Set the model mode (training or evaluation)
- Parameters:
mode (str) – Mode: ‘train’ for training or ‘eval’ for evaluation
- set_running_mode(mode: str) None ¶
Set the current running mode (training or evaluation)
This method sets the value of the training property (boolean). This property can be used to know if the agent is running in training or evaluation mode.
- Parameters:
mode (str) – Mode: ‘train’ for training or ‘eval’ for evaluation
- track_data(tag: str, value: float) None ¶
Track data to TensorBoard
Currently only scalar data are supported
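For example, inside an agent's _update() method a scalar can be logged as follows (the tag and the policy_loss variable are illustrative, not part of the API):
# log a scalar under a custom tag; values are written to TensorBoard every `write_interval` timesteps
self.track_data("Loss / Policy loss", policy_loss.item())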
API (JAX)¶
- class skrl.agents.jax.base.Agent(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None)¶
Bases: object
- __init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None) None ¶
Base class that represents an RL agent
- Parameters:
models (dictionary of skrl.models.jax.Model) – Models used by the agent
memory (skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None) – Memory to store the transitions. If it is a tuple, the first element will be used for training and for the rest only the environment transitions will be added
observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)
action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)
device (str or jax.Device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"
cfg (dict) – Configuration dictionary
- __str__() str ¶
Generate a representation of the agent as a string
- Returns:
Representation of the agent as a string
- Return type:
str
- _empty_preprocessor(_input: Any, *args, **kwargs) Any ¶
Empty preprocess method
This method is defined because PyTorch multiprocessing can’t pickle lambdas
- Parameters:
_input (Any) – Input to preprocess
- Returns:
Preprocessed input
- Return type:
Any
- _get_internal_value(_module: Any) Any ¶
Get internal module/variable state/value
- Parameters:
_module (Any) – Module or variable
- Returns:
Module/variable state/value
- Return type:
Any
- _update(timestep: int, timesteps: int) None ¶
Algorithm’s main update step
- Parameters:
timestep (int) – Current timestep
timesteps (int) – Number of timesteps
- Raises:
NotImplementedError – The method is not implemented by the inheriting classes
- act(states: ndarray | jax.Array, timestep: int, timesteps: int) ndarray | jax.Array ¶
Process the environment’s states to make a decision (actions) using the main policy
- Parameters:
states (np.ndarray or jax.Array) – Environment's states
timestep (int) – Current timestep
timesteps (int) – Number of timesteps
- Raises:
NotImplementedError – The method is not implemented by the inheriting classes
- Returns:
Actions
- Return type:
np.ndarray or jax.Array
- init(trainer_cfg: Mapping[str, Any] | None = None) None ¶
Initialize the agent
This method should be called before the agent is used. It will initialize the TensorBoard writer (and optionally Weights & Biases) and create the checkpoints directory
- Parameters:
trainer_cfg (dict, optional) – Trainer configuration
- load(path: str) None ¶
Load the model from the specified path
- Parameters:
path (str) – Path to load the model from
- migrate(path: str, name_map: Mapping[str, Mapping[str, str]] = {}, auto_mapping: bool = True, verbose: bool = False) bool ¶
Migrate the specified external checkpoint to the current agent
- Raises:
NotImplementedError – Not yet implemented
- post_interaction(timestep: int, timesteps: int) None ¶
Callback called after the interaction with the environment
- pre_interaction(timestep: int, timesteps: int) None ¶
Callback called before the interaction with the environment
- record_transition(states: ndarray | jax.Array, actions: ndarray | jax.Array, rewards: ndarray | jax.Array, next_states: ndarray | jax.Array, terminated: ndarray | jax.Array, truncated: ndarray | jax.Array, infos: Any, timestep: int, timesteps: int) None ¶
Record an environment transition in memory (to be implemented by the inheriting classes)
Inheriting classes must call this method to record episode information (rewards, timesteps, etc.). In addition to recording environment transition (such as states, rewards, etc.), agent information can be recorded.
- Parameters:
states (np.ndarray or jax.Array) – Observations/states of the environment used to make the decision
actions (np.ndarray or jax.Array) – Actions taken by the agent
rewards (np.ndarray or jax.Array) – Instant rewards achieved by the current actions
next_states (np.ndarray or jax.Array) – Next observations/states of the environment
terminated (np.ndarray or jax.Array) – Signals to indicate that episodes have terminated
truncated (np.ndarray or jax.Array) – Signals to indicate that episodes have been truncated
infos (Any type supported by the environment) – Additional information about the environment
timestep (int) – Current timestep
timesteps (int) – Number of timesteps
- save(path: str) None ¶
Save the agent to the specified path
- Parameters:
path (str) – Path to save the model to
- set_mode(mode: str) None ¶
Set the model mode (training or evaluation)
- Parameters:
mode (str) – Mode: ‘train’ for training or ‘eval’ for evaluation
- set_running_mode(mode: str) None ¶
Set the current running mode (training or evaluation)
This method sets the value of the
training
property (boolean). This property can be used to know if the agent is running in training or evaluation mode.- Parameters:
mode (str) – Mode: ‘train’ for training or ‘eval’ for evaluation
- track_data(tag: str, value: float) None ¶
Track data to TensorBoard
Currently only scalar data are supported