Twin-Delayed DDPG (TD3)¶
TD3 is a model-free, deterministic, off-policy actor-critic algorithm (based on DDPG) that relies on clipped double Q-learning, target policy smoothing and delayed policy updates to address the problems introduced by overestimation bias in actor-critic algorithms.
Paper: Addressing Function Approximation Error in Actor-Critic Methods
Algorithm¶
Algorithm implementation¶
Decision making¶
act(...)
Learning algorithm¶
_update(...)
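The update step combines the three TD3 mechanisms: clipped double Q-learning, target policy smoothing and delayed policy updates. The following is an illustrative sketch of a single update, not skrl's actual _update(...) implementation; the model call signatures, the batch layout, the shared critic optimizer and the default hyperparameter values are assumptions for the example.

# Illustrative sketch of one TD3 update step (not skrl's internal implementation).
# Assumptions: policy/critic objects are callables mapping tensors to tensors,
# the batch tensors are float tensors, actions are normalized to [-1, 1],
# and critic_optimizer covers the parameters of both critics.
import torch
import torch.nn.functional as F

def td3_update(batch, policy, target_policy, critic_1, critic_2,
               target_critic_1, target_critic_2,
               actor_optimizer, critic_optimizer, step,
               gamma=0.99, polyak=0.005, policy_delay=2,
               noise_std=0.2, noise_clip=0.5):
    states, actions, rewards, next_states, dones = batch

    with torch.no_grad():
        # target policy smoothing: add clipped noise to the target action
        noise = (torch.randn_like(actions) * noise_std).clamp(-noise_clip, noise_clip)
        next_actions = (target_policy(next_states) + noise).clamp(-1.0, 1.0)
        # clipped double Q-learning: take the minimum of both target critics
        target_q = torch.min(target_critic_1(next_states, next_actions),
                             target_critic_2(next_states, next_actions))
        target = rewards + gamma * (1 - dones) * target_q

    # critic update (every gradient step)
    critic_loss = F.mse_loss(critic_1(states, actions), target) \
                + F.mse_loss(critic_2(states, actions), target)
    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()

    # delayed policy (actor) and target network updates
    if step % policy_delay == 0:
        actor_loss = -critic_1(states, policy(states)).mean()
        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()

        # soft (polyak) update of all target networks
        for net, target_net in ((policy, target_policy),
                                (critic_1, target_critic_1),
                                (critic_2, target_critic_2)):
            for p, tp in zip(net.parameters(), target_net.parameters()):
                tp.data.mul_(1 - polyak).add_(polyak * p.data)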
Usage¶
Note
Support for recurrent neural networks (RNN, LSTM, GRU and any other variant) is implemented in a separate file (td3_rnn.py) to maintain the readability of the standard implementation (td3.py).
# import the agent and its default configuration
from skrl.agents.torch.td3 import TD3, TD3_DEFAULT_CONFIG
# instantiate the agent's models
models = {}
models["policy"] = ...
models["target_policy"] = ... # only required during training
models["critic_1"] = ... # only required during training
models["critic_2"] = ... # only required during training
models["target_critic_1"] = ... # only required during training
models["target_critic_2"] = ... # only required during training
# adjust some configuration if necessary
cfg_agent = TD3_DEFAULT_CONFIG.copy()
cfg_agent["<KEY>"] = ...
# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
agent = TD3(models=models,
            memory=memory,  # only required during training
            cfg=cfg_agent,
            observation_space=env.observation_space,
            action_space=env.action_space,
            device=env.device)
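As a reference for the model placeholders above, the sketch below shows one possible way to define the deterministic actor and critic using skrl's Model and DeterministicMixin base classes; the network sizes, the Tanh-bounded output and the use of inputs["states"] / inputs["taken_actions"] are assumptions for the example (see the models documentation for the exact API).

# Minimal sketch of deterministic models for TD3 (network sizes are illustrative)
import torch
import torch.nn as nn

from skrl.models.torch import DeterministicMixin, Model

class Actor(DeterministicMixin, Model):
    def __init__(self, observation_space, action_space, device, clip_actions=False):
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)
        self.net = nn.Sequential(nn.Linear(self.num_observations, 256), nn.ReLU(),
                                 nn.Linear(256, 256), nn.ReLU(),
                                 nn.Linear(256, self.num_actions), nn.Tanh())

    def compute(self, inputs, role):
        return self.net(inputs["states"]), {}

class Critic(DeterministicMixin, Model):
    def __init__(self, observation_space, action_space, device, clip_actions=False):
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)
        self.net = nn.Sequential(nn.Linear(self.num_observations + self.num_actions, 256), nn.ReLU(),
                                 nn.Linear(256, 256), nn.ReLU(),
                                 nn.Linear(256, 1))

    def compute(self, inputs, role):
        # critics receive the observation concatenated with the taken action
        return self.net(torch.cat([inputs["states"], inputs["taken_actions"]], dim=1)), {}

# instantiate one model per dictionary key
models = {}
models["policy"] = Actor(env.observation_space, env.action_space, env.device)
models["target_policy"] = Actor(env.observation_space, env.action_space, env.device)
for key in ["critic_1", "critic_2", "target_critic_1", "target_critic_2"]:
    models[key] = Critic(env.observation_space, env.action_space, env.device)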
# import the agent and its default configuration
from skrl.agents.jax.td3 import TD3, TD3_DEFAULT_CONFIG
# instantiate the agent's models
models = {}
models["policy"] = ...
models["target_policy"] = ... # only required during training
models["critic_1"] = ... # only required during training
models["critic_2"] = ... # only required during training
models["target_critic_1"] = ... # only required during training
models["target_critic_2"] = ... # only required during training
# adjust some configuration if necessary
cfg_agent = TD3_DEFAULT_CONFIG.copy()
cfg_agent["<KEY>"] = ...
# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
agent = TD3(models=models,
            memory=memory,  # only required during training
            cfg=cfg_agent,
            observation_space=env.observation_space,
            action_space=env.action_space,
            device=env.device)
Note
When using recurrent models it is necessary to override their .get_specification() method; a rough sketch is shown after the usage example below. Visit each model's documentation for more details.
# import the agent and its default configuration
from skrl.agents.torch.td3 import TD3_RNN as TD3, TD3_DEFAULT_CONFIG
# instantiate the agent's models
models = {}
models["policy"] = ...
models["target_policy"] = ... # only required during training
models["critic_1"] = ... # only required during training
models["critic_2"] = ... # only required during training
models["target_critic_1"] = ... # only required during training
models["target_critic_2"] = ... # only required during training
# adjust some configuration if necessary
cfg_agent = TD3_DEFAULT_CONFIG.copy()
cfg_agent["<KEY>"] = ...
# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
agent = TD3(models=models,
            memory=memory,  # only required during training
            cfg=cfg_agent,
            observation_space=env.observation_space,
            action_space=env.action_space,
            device=env.device)
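A rough sketch of such a .get_specification() override for an LSTM-based model is shown below. The attribute names (sequence_length, num_layers, num_envs, hidden_size) and the exact dictionary layout are assumptions for illustration; check each model's documentation for the authoritative specification format.

# Rough sketch of overriding .get_specification() inside a recurrent model class.
# All attribute names and the dictionary layout are assumptions for illustration.
def get_specification(self):
    # batch size (N) is the number of parallel environments
    return {"rnn": {"sequence_length": self.sequence_length,
                    "sizes": [(self.num_layers, self.num_envs, self.hidden_size),    # hidden states
                              (self.num_layers, self.num_envs, self.hidden_size)]}}  # cell states (LSTM only)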
Configuration and hyperparameters¶
TD3_DEFAULT_CONFIG = {
    "gradient_steps": 1,            # gradient steps
    "batch_size": 64,               # training batch size

    "discount_factor": 0.99,        # discount factor (gamma)
    "polyak": 0.005,                # soft update hyperparameter (tau)

    "actor_learning_rate": 1e-3,    # actor learning rate
    "critic_learning_rate": 1e-3,   # critic learning rate
    "learning_rate_scheduler": None,        # learning rate scheduler class (see torch.optim.lr_scheduler)
    "learning_rate_scheduler_kwargs": {},   # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})

    "state_preprocessor": None,             # state preprocessor class (see skrl.resources.preprocessors)
    "state_preprocessor_kwargs": {},        # state preprocessor's kwargs (e.g. {"size": env.observation_space})

    "random_timesteps": 0,          # random exploration steps
    "learning_starts": 0,           # learning starts after this many steps

    "grad_norm_clip": 0,            # clipping coefficient for the norm of the gradients

    "exploration": {
        "noise": None,              # exploration noise
        "initial_scale": 1.0,       # initial scale for the noise
        "final_scale": 1e-3,        # final scale for the noise
        "timesteps": None,          # timesteps for the noise decay
    },

    "policy_delay": 2,                      # policy delay update with respect to critic update
    "smooth_regularization_noise": None,    # smooth noise for regularization
    "smooth_regularization_clip": 0.5,      # clip for smooth regularization

    "rewards_shaper": None,         # rewards shaping function: Callable(reward, timestep, timesteps) -> reward

    "experiment": {
        "directory": "",            # experiment's parent directory
        "experiment_name": "",      # experiment name
        "write_interval": "auto",   # TensorBoard writing interval (timesteps)

        "checkpoint_interval": "auto",      # interval for checkpoints (timesteps)
        "store_separately": False,          # whether to store checkpoints separately

        "wandb": False,             # whether to use Weights & Biases
        "wandb_kwargs": {}          # wandb kwargs (see https://docs.wandb.ai/ref/python/init)
    }
}
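For example, the exploration noise and the smoothing-regularization noise entries (None by default) are typically replaced with a noise instance. The snippet below assumes skrl's GaussianNoise class from skrl.resources.noises.torch and uses illustrative values; verify the class and its arguments against the noises documentation.

# Example of overriding the default configuration (values are illustrative)
from skrl.resources.noises.torch import GaussianNoise

cfg_agent = TD3_DEFAULT_CONFIG.copy()
cfg_agent["batch_size"] = 128
cfg_agent["discount_factor"] = 0.98
cfg_agent["learning_starts"] = 1000
# exploration noise added to the actions returned by the policy
cfg_agent["exploration"]["noise"] = GaussianNoise(0, 0.1, device=env.device)
# target policy smoothing noise (clipped to smooth_regularization_clip)
cfg_agent["smooth_regularization_noise"] = GaussianNoise(0, 0.2, device=env.device)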
Spaces¶
The implementation supports the following Gym/Gymnasium spaces:

| Gym/Gymnasium spaces | Observation | Action |
|---|---|---|
| Discrete | \(\square\) | \(\square\) |
| MultiDiscrete | \(\square\) | \(\square\) |
| Box | \(\blacksquare\) | \(\blacksquare\) |
| Dict | \(\blacksquare\) | \(\square\) |
Models¶
The implementation uses 6 deterministic function approximators. These function approximators (models) must be collected in a dictionary and passed to the constructor of the class under the argument models.

| Notation | Concept | Key | Input shape | Output shape | Type |
|---|---|---|---|---|---|
| \(\mu_\theta(s)\) | Policy (actor) | "policy" | observation | action | Deterministic |
| \(\mu_{\theta_{target}}(s)\) | Target policy | "target_policy" | observation | action | Deterministic |
| \(Q_{\phi 1}(s, a)\) | Q1-network (critic 1) | "critic_1" | observation + action | 1 | Deterministic |
| \(Q_{\phi 2}(s, a)\) | Q2-network (critic 2) | "critic_2" | observation + action | 1 | Deterministic |
| \(Q_{{\phi 1}_{target}}(s, a)\) | Target Q1-network | "target_critic_1" | observation + action | 1 | Deterministic |
| \(Q_{{\phi 2}_{target}}(s, a)\) | Target Q2-network | "target_critic_2" | observation + action | 1 | Deterministic |
Features¶
Support for advanced features is described in the next table.

| Feature | Support and remarks | PyTorch | JAX |
|---|---|---|---|
| Shared model | - | \(\square\) | \(\square\) |
| RNN support | RNN, LSTM, GRU and any other variant | \(\blacksquare\) | \(\square\) |
| Distributed | Single Program Multi Data (SPMD) multi-GPU | \(\blacksquare\) | \(\blacksquare\) |
API (PyTorch)¶
- skrl.agents.torch.td3.TD3_DEFAULT_CONFIG¶
alias of {‘actor_learning_rate’: 0.001, ‘batch_size’: 64, ‘critic_learning_rate’: 0.001, ‘discount_factor’: 0.99, ‘experiment’: {‘checkpoint_interval’: ‘auto’, ‘directory’: ‘’, ‘experiment_name’: ‘’, ‘store_separately’: False, ‘wandb’: False, ‘wandb_kwargs’: {}, ‘write_interval’: ‘auto’}, ‘exploration’: {‘final_scale’: 0.001, ‘initial_scale’: 1.0, ‘noise’: None, ‘timesteps’: None}, ‘grad_norm_clip’: 0, ‘gradient_steps’: 1, ‘learning_rate_scheduler’: None, ‘learning_rate_scheduler_kwargs’: {}, ‘learning_starts’: 0, ‘policy_delay’: 2, ‘polyak’: 0.005, ‘random_timesteps’: 0, ‘rewards_shaper’: None, ‘smooth_regularization_clip’: 0.5, ‘smooth_regularization_noise’: None, ‘state_preprocessor’: None, ‘state_preprocessor_kwargs’: {}}
- class skrl.agents.torch.td3.TD3(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)¶
Bases:
Agent
- __init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) None ¶
Twin Delayed DDPG (TD3)
https://arxiv.org/abs/1802.09477
- Parameters:
models (dictionary of skrl.models.torch.Model) – Models used by the agent
memory (skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None) – Memory to store the transitions. If it is a tuple, the first element will be used for training and for the rest only the environment transitions will be added
observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)
action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)
device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"
cfg (dict) – Configuration dictionary
- Raises:
KeyError – If the models dictionary is missing a required key
- act(states: torch.Tensor, timestep: int, timesteps: int) torch.Tensor ¶
Process the environment’s states to make a decision (actions) using the main policy
- Parameters:
states (torch.Tensor) – Environment’s states
timestep (int) – Current timestep
timesteps (int) – Number of timesteps
- Returns:
Actions
- Return type:
torch.Tensor
- post_interaction(timestep: int, timesteps: int) None ¶
Callback called after the interaction with the environment
- pre_interaction(timestep: int, timesteps: int) None ¶
Callback called before the interaction with the environment
- record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) None ¶
Record an environment transition in memory
- Parameters:
states (torch.Tensor) – Observations/states of the environment used to make the decision
actions (torch.Tensor) – Actions taken by the agent
rewards (torch.Tensor) – Instant rewards achieved by the current actions
next_states (torch.Tensor) – Next observations/states of the environment
terminated (torch.Tensor) – Signals to indicate that episodes have terminated
truncated (torch.Tensor) – Signals to indicate that episodes have been truncated
infos (Any type supported by the environment) – Additional information about the environment
timestep (int) – Current timestep
timesteps (int) – Number of timesteps
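These methods are normally driven by one of skrl's trainers. The sketch below shows how they could be wired together manually, assuming a vectorized environment whose reset()/step() calls return torch tensors and assuming act() returns the actions tensor directly, as documented above; it is a sketch, not skrl's trainer implementation.

# Manual interaction loop sketch (skrl's trainers normally orchestrate this)
import torch

timesteps = 10000
states, infos = env.reset()
for timestep in range(timesteps):
    agent.pre_interaction(timestep=timestep, timesteps=timesteps)
    with torch.no_grad():
        # random timesteps and exploration noise are handled inside act()
        actions = agent.act(states, timestep=timestep, timesteps=timesteps)
    next_states, rewards, terminated, truncated, infos = env.step(actions)
    agent.record_transition(states=states, actions=actions, rewards=rewards,
                            next_states=next_states, terminated=terminated,
                            truncated=truncated, infos=infos,
                            timestep=timestep, timesteps=timesteps)
    # post_interaction() triggers the learning/update and logging steps
    agent.post_interaction(timestep=timestep, timesteps=timesteps)
    states = next_states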
- class skrl.agents.torch.td3.TD3_RNN(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)¶
Bases:
Agent
- __init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) None ¶
Twin Delayed DDPG (TD3) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.)
https://arxiv.org/abs/1802.09477
- Parameters:
models (dictionary of skrl.models.torch.Model) – Models used by the agent
memory (skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None) – Memory to store the transitions. If it is a tuple, the first element will be used for training and for the rest only the environment transitions will be added
observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)
action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)
device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"
cfg (dict) – Configuration dictionary
- Raises:
KeyError – If the models dictionary is missing a required key
- act(states: torch.Tensor, timestep: int, timesteps: int) torch.Tensor ¶
Process the environment’s states to make a decision (actions) using the main policy
- Parameters:
states (torch.Tensor) – Environment’s states
timestep (int) – Current timestep
timesteps (int) – Number of timesteps
- Returns:
Actions
- Return type:
torch.Tensor
- post_interaction(timestep: int, timesteps: int) None ¶
Callback called after the interaction with the environment
- pre_interaction(timestep: int, timesteps: int) None ¶
Callback called before the interaction with the environment
- record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) None ¶
Record an environment transition in memory
- Parameters:
states (torch.Tensor) – Observations/states of the environment used to make the decision
actions (torch.Tensor) – Actions taken by the agent
rewards (torch.Tensor) – Instant rewards achieved by the current actions
next_states (torch.Tensor) – Next observations/states of the environment
terminated (torch.Tensor) – Signals to indicate that episodes have terminated
truncated (torch.Tensor) – Signals to indicate that episodes have been truncated
infos (Any type supported by the environment) – Additional information about the environment
timestep (int) – Current timestep
timesteps (int) – Number of timesteps
API (JAX)¶
- skrl.agents.jax.td3.TD3_DEFAULT_CONFIG¶
alias of {‘actor_learning_rate’: 0.001, ‘batch_size’: 64, ‘critic_learning_rate’: 0.001, ‘discount_factor’: 0.99, ‘experiment’: {‘checkpoint_interval’: ‘auto’, ‘directory’: ‘’, ‘experiment_name’: ‘’, ‘store_separately’: False, ‘wandb’: False, ‘wandb_kwargs’: {}, ‘write_interval’: ‘auto’}, ‘exploration’: {‘final_scale’: 0.001, ‘initial_scale’: 1.0, ‘noise’: None, ‘timesteps’: None}, ‘grad_norm_clip’: 0, ‘gradient_steps’: 1, ‘learning_rate_scheduler’: None, ‘learning_rate_scheduler_kwargs’: {}, ‘learning_starts’: 0, ‘policy_delay’: 2, ‘polyak’: 0.005, ‘random_timesteps’: 0, ‘rewards_shaper’: None, ‘smooth_regularization_clip’: 0.5, ‘smooth_regularization_noise’: None, ‘state_preprocessor’: None, ‘state_preprocessor_kwargs’: {}}
- class skrl.agents.jax.td3.TD3(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None)¶
Bases:
Agent
- __init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None) None ¶
Twin Delayed DDPG (TD3)
https://arxiv.org/abs/1802.09477
- Parameters:
models (dictionary of skrl.models.jax.Model) – Models used by the agent
memory (skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None) – Memory to store the transitions. If it is a tuple, the first element will be used for training and for the rest only the environment transitions will be added
observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)
action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)
device (str or jax.Device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"
cfg (dict) – Configuration dictionary
- Raises:
KeyError – If the models dictionary is missing a required key
- act(states: ndarray | jax.Array, timestep: int, timesteps: int) ndarray | jax.Array ¶
Process the environment’s states to make a decision (actions) using the main policy
- post_interaction(timestep: int, timesteps: int) None ¶
Callback called after the interaction with the environment
- pre_interaction(timestep: int, timesteps: int) None ¶
Callback called before the interaction with the environment
- record_transition(states: ndarray | jax.Array, actions: ndarray | jax.Array, rewards: ndarray | jax.Array, next_states: ndarray | jax.Array, terminated: ndarray | jax.Array, truncated: ndarray | jax.Array, infos: Any, timestep: int, timesteps: int) None ¶
Record an environment transition in memory
- Parameters:
states (np.ndarray or jax.Array) – Observations/states of the environment used to make the decision
actions (np.ndarray or jax.Array) – Actions taken by the agent
rewards (np.ndarray or jax.Array) – Instant rewards achieved by the current actions
next_states (np.ndarray or jax.Array) – Next observations/states of the environment
terminated (np.ndarray or jax.Array) – Signals to indicate that episodes have terminated
truncated (np.ndarray or jax.Array) – Signals to indicate that episodes have been truncated
infos (Any type supported by the environment) – Additional information about the environment
timestep (int) – Current timestep
timesteps (int) – Number of timesteps