Twin-Delayed DDPG (TD3)

TD3 is a model-free, deterministic, off-policy actor-critic algorithm (based on DDPG) that relies on clipped double Q-learning, target policy smoothing and delayed policy updates to address the problems introduced by overestimation bias in actor-critic algorithms.
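The overestimation bias that motivates the clipped double-Q target can be seen in a small toy simulation (not part of skrl). All actions have a true value of 0. One noisy estimate picks the greedy action, standing in for the actor. The single-estimator target then keeps that estimate's value, while the clipped target takes the minimum of two independent estimates. The function name and all parameters are illustrative assumptions.

```python
import random


def overestimation_demo(num_actions=10, noise_std=1.0, trials=20000, seed=0):
    """Toy comparison of a single-estimator target and a clipped double-Q target.

    Every action has a true value of 0, so any nonzero mean is pure bias caused
    by the (simulated) function approximation noise.
    """
    rng = random.Random(seed)
    single, clipped = [], []
    for _ in range(trials):
        # two independent noisy estimates of the same (all-zero) action values
        q1 = [rng.gauss(0.0, noise_std) for _ in range(num_actions)]
        q2 = [rng.gauss(0.0, noise_std) for _ in range(num_actions)]
        best = max(range(num_actions), key=q1.__getitem__)  # greedy w.r.t. the first estimate
        single.append(q1[best])                               # max-based (single) target
        clipped.append(min(q1[best], q2[best]))               # clipped double-Q target
    return sum(single) / trials, sum(clipped) / trials


bias_single, bias_clipped = overestimation_demo()
print(f"single estimator bias: {bias_single:+.3f}")
print(f"clipped double-Q bias: {bias_clipped:+.3f}")
```

With these settings the max-based target is clearly biased upward. The clipped target stays much closer to zero and can even underestimate slightly, which is the trade-off TD3 accepts.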

Paper: Addressing Function Approximation Error in Actor-Critic Methods



Algorithm


Algorithm implementation

Main notation/symbols:
- policy function approximator (\(\mu_\theta\)), critic function approximator (\(Q_\phi\))
- states (\(s\)), actions (\(a\)), rewards (\(r\)), next states (\(s'\)), dones (\(d\))
- loss (\(L\))

Decision making


act(...)
\(a \leftarrow \mu_\theta(s)\)
\(noise \leftarrow\) sample noise
\(scale \leftarrow (1 - \text{timestep} / T) \, (s_{init} - s_{final}) + s_{final} \qquad\) with \(T\) as timesteps, \(s_{init}\) as initial_scale and \(s_{final}\) as final_scale
\(a \leftarrow \text{clip}(a + noise * scale, {a}_{Low}, {a}_{High})\)
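The decision-making step above can be sketched in plain Python as follows. This is a minimal illustration, not skrl's code. Actions are lists of floats, and the helper names `exploration_scale` and `noisy_action` are hypothetical. The exploration noise is assumed to be zero-mean Gaussian, whereas in the agent it is whatever object is configured under `"exploration"`/`"noise"`.

```python
import random


def exploration_scale(timestep, timesteps, initial_scale=1.0, final_scale=1e-3):
    """Linear decay from initial_scale (timestep 0) to final_scale (timestep == timesteps)."""
    return (1 - timestep / timesteps) * (initial_scale - final_scale) + final_scale


def noisy_action(action, low, high, timestep, timesteps, noise_std=0.1, rng=random):
    """Add scaled exploration noise to a deterministic action, then clip it to the bounds."""
    scale = exploration_scale(timestep, timesteps)
    noisy = []
    for a, a_low, a_high in zip(action, low, high):
        noise = rng.gauss(0.0, noise_std)  # assumed: zero-mean Gaussian exploration noise
        noisy.append(min(max(a + noise * scale, a_low), a_high))
    return noisy


rng = random.Random(0)
print(exploration_scale(0, 1000), exploration_scale(500, 1000), exploration_scale(1000, 1000))
print(noisy_action([0.5, -0.99], low=[-1.0, -1.0], high=[1.0, 1.0],
                   timestep=100, timesteps=1000, rng=rng))
```

Early in training the noise is applied almost at full strength. By the end of the decay only `final_scale` of it remains, and clipping keeps every action inside the action space.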

Learning algorithm


_update(...)
# gradient steps
FOR each gradient step up to gradient_steps DO
    # sample a batch from memory
    [\(s, a, r, s', d\)] \(\leftarrow\) states, actions, rewards, next_states, dones of size batch_size
    # target policy smoothing
    \(a' \leftarrow \mu_{\theta_{target}}(s')\)
    \(noise \leftarrow \text{clip}(\epsilon, -c, c) \qquad\) with \(\epsilon\) sampled from smooth_regularization_noise and \(c\) as smooth_regularization_clip
    \(a' \leftarrow a' + noise\)
    \(a' \leftarrow \text{clip}(a', {a'}_{Low}, {a'}_{High})\)
    # compute target values
    \(Q_{1_{target}} \leftarrow Q_{{\phi 1}_{target}}(s', a')\)
    \(Q_{2_{target}} \leftarrow Q_{{\phi 2}_{target}}(s', a')\)
    \(Q_{target} \leftarrow \text{min}(Q_{1_{target}}, Q_{2_{target}})\)
    \(y \leftarrow r + \gamma \, \neg d \, Q_{target} \qquad\) with \(\gamma\) as discount_factor
    # compute critic loss
    \(Q_1 \leftarrow Q_{\phi 1}(s, a)\)
    \(Q_2 \leftarrow Q_{\phi 2}(s, a)\)
    \(L_{Q_\phi} \leftarrow \frac{1}{N} \sum_{i=1}^N (Q_{1,i} - y_i)^2 + \frac{1}{N} \sum_{i=1}^N (Q_{2,i} - y_i)^2 \qquad\) with \(N\) as batch_size
    # optimization step (critic)
    reset \(\text{optimizer}_\phi\) gradients
    compute \(\nabla_{\phi} L_{Q_\phi}\)
    \(\text{clip}(\lVert \nabla_{\phi} \rVert)\) with grad_norm_clip
    step \(\text{optimizer}_\phi\)
    # delayed update
    IF the number of critic updates performed so far is a multiple of policy_delay THEN
        # compute policy (actor) loss
        \(a \leftarrow \mu_\theta(s)\)
        \(Q_1 \leftarrow Q_{\phi 1}(s, a)\)
        \(L_{\mu_\theta} \leftarrow - \frac{1}{N} \sum_{i=1}^N Q_{1,i}\)
        # optimization step (policy)
        reset \(\text{optimizer}_\theta\) gradients
        compute \(\nabla_{\theta} L_{\mu_\theta}\)
        \(\text{clip}(\lVert \nabla_{\theta} \rVert)\) with grad_norm_clip
        step \(\text{optimizer}_\theta\)
        # update target networks
        \(\theta_{target} \leftarrow \tau \, \theta + (1 - \tau) \, \theta_{target} \qquad\) with \(\tau\) as polyak
        \({\phi 1}_{target} \leftarrow \tau \, {\phi 1} + (1 - \tau) \, {\phi 1}_{target}\)
        \({\phi 2}_{target} \leftarrow \tau \, {\phi 2} + (1 - \tau) \, {\phi 2}_{target}\)
    # update learning rate
    IF there is a learning_rate_scheduler THEN
        step \(\text{scheduler}_\theta (\text{optimizer}_\theta)\)
        step \(\text{scheduler}_\phi (\text{optimizer}_\phi)\)
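The arithmetic of the learning algorithm above can be sketched with scalar actions and plain Python numbers. This only illustrates the formulas and is not skrl's implementation. The function names are hypothetical, and the networks are replaced by a plain callable and floats. The smoothing noise is assumed to be Gaussian (`smooth_regularization_noise` is `None` by default and must be supplied by the user). Gradient computation and optimizer steps are omitted.

```python
import random


def clip(x, low, high):
    """Clamp a scalar to [low, high]."""
    return min(max(x, low), high)


def smoothed_target_action(target_policy, next_state, noise_std, noise_clip, low, high, rng):
    """Target policy smoothing: a' = clip(mu_target(s') + clip(eps, -c, c), low, high)."""
    a_next = target_policy(next_state)
    eps = clip(rng.gauss(0.0, noise_std), -noise_clip, noise_clip)  # assumed Gaussian eps
    return clip(a_next + eps, low, high)


def td3_target(reward, done, q1_target, q2_target, discount_factor):
    """Clipped double-Q target: y = r + gamma * (not d) * min(Q1', Q2')."""
    return reward + discount_factor * (1.0 - float(done)) * min(q1_target, q2_target)


def critic_loss(q1_values, q2_values, targets):
    """Sum of both critics' mean squared errors against the shared targets y."""
    n = len(targets)
    mse_1 = sum((q - y) ** 2 for q, y in zip(q1_values, targets)) / n
    mse_2 = sum((q - y) ** 2 for q, y in zip(q2_values, targets)) / n
    return mse_1 + mse_2


def should_update_policy(critic_update_counter, policy_delay):
    """Delayed update: actor and target networks change once every policy_delay critic updates."""
    return critic_update_counter % policy_delay == 0


def polyak_update(target_params, params, polyak):
    """Soft update: target <- polyak * online + (1 - polyak) * target, element-wise."""
    return [polyak * p + (1.0 - polyak) * t for p, t in zip(params, target_params)]


rng = random.Random(0)
a_next = smoothed_target_action(lambda s: 0.9 * s, next_state=1.0, noise_std=0.2,
                                noise_clip=0.5, low=-1.0, high=1.0, rng=rng)
y = td3_target(reward=1.0, done=False, q1_target=10.0, q2_target=8.0, discount_factor=0.99)
print(a_next, y)  # y is built from the smaller estimate: 1.0 + 0.99 * 8.0
```

With `polyak` at its default of 0.005, each soft update moves the target parameters only 0.5% of the way toward the online parameters. Terminal transitions (`done=True`) drop the bootstrapped term entirely.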

Usage

Note

Support for recurrent neural networks (RNN, LSTM, GRU and any other variant) is implemented in a separate file (td3_rnn.py) to keep the standard implementation (td3.py) readable.

import copy

# import the agent and its default configuration
from skrl.agents.torch.td3 import TD3, TD3_DEFAULT_CONFIG

# instantiate the agent's models
models = {}
models["policy"] = ...
models["target_policy"] = ...  # only required during training
models["critic_1"] = ...  # only required during training
models["critic_2"] = ...  # only required during training
models["target_critic_1"] = ...  # only required during training
models["target_critic_2"] = ...  # only required during training

# adjust some configuration if necessary
# (deep copy, so that nested dictionaries such as "exploration" are not shared with the defaults)
cfg_agent = copy.deepcopy(TD3_DEFAULT_CONFIG)
cfg_agent["<KEY>"] = ...

# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
agent = TD3(models=models,
            memory=memory,  # only required during training
            cfg=cfg_agent,
            observation_space=env.observation_space,
            action_space=env.action_space,
            device=env.device)

Configuration and hyperparameters

TD3_DEFAULT_CONFIG = {
    "gradient_steps": 1,            # gradient steps
    "batch_size": 64,               # training batch size

    "discount_factor": 0.99,        # discount factor (gamma)
    "polyak": 0.005,                # soft update hyperparameter (tau)

    "actor_learning_rate": 1e-3,    # actor learning rate
    "critic_learning_rate": 1e-3,   # critic learning rate
    "learning_rate_scheduler": None,        # learning rate scheduler class (see torch.optim.lr_scheduler)
    "learning_rate_scheduler_kwargs": {},   # learning rate scheduler's kwargs (e.g. {"step_size": 1000})

    "state_preprocessor": None,             # state preprocessor class (see skrl.resources.preprocessors)
    "state_preprocessor_kwargs": {},        # state preprocessor's kwargs (e.g. {"size": env.observation_space})

    "random_timesteps": 0,          # random exploration steps
    "learning_starts": 0,           # learning starts after this many steps

    "grad_norm_clip": 0,            # clipping coefficient for the norm of the gradients

    "exploration": {
        "noise": None,              # exploration noise
        "initial_scale": 1.0,       # initial scale for the noise
        "final_scale": 1e-3,        # final scale for the noise
        "timesteps": None,          # timesteps for the noise decay
    },

    "policy_delay": 2,                      # number of critic updates per policy (and target network) update
    "smooth_regularization_noise": None,    # noise for target policy smoothing (regularization)
    "smooth_regularization_clip": 0.5,      # clip value for the target policy smoothing noise

    "rewards_shaper": None,         # rewards shaping function: Callable(reward, timestep, timesteps) -> reward

    "experiment": {
        "directory": "",            # experiment's parent directory
        "experiment_name": "",      # experiment name
        "write_interval": "auto",   # TensorBoard writing interval (timesteps)

        "checkpoint_interval": "auto",      # interval for checkpoints (timesteps)
        "store_separately": False,          # whether to store checkpoints separately

        "wandb": False,             # whether to use Weights & Biases
        "wandb_kwargs": {}          # wandb kwargs (see https://docs.wandb.ai/ref/python/init)
    }
}
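The `"rewards_shaper"` entry expects a callable with the signature `(reward, timestep, timesteps) -> reward`. A minimal sketch of one follows, assuming a purely illustrative annealing schedule. The function name and the 1.0 to 0.1 schedule are hypothetical. Because it only multiplies, the same function works for Python floats and for reward tensors/arrays.

```python
def annealed_rewards_shaper(rewards, timestep, timesteps):
    """Hypothetical shaper: scale rewards from 1.0x at the start of training to 0.1x at the end."""
    progress = min(timestep / max(timesteps, 1), 1.0)  # fraction of training completed
    scale = 1.0 - 0.9 * progress
    return rewards * scale


# plugging it into the agent configuration (cfg_agent as in the usage example above):
# cfg_agent["rewards_shaper"] = annealed_rewards_shaper
print(annealed_rewards_shaper(10.0, 0, 1000), annealed_rewards_shaper(10.0, 1000, 1000))
```

A constant rescaling (e.g. `lambda rewards, timestep, timesteps: rewards * 0.1`) is the simpler and more common choice. The `timestep`/`timesteps` arguments only matter when the shaping should change over training.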

Spaces

The implementation supports the following Gym/Gymnasium spaces:

Gym/Gymnasium spaces   Observation   Action
Discrete               no            no
MultiDiscrete          no            no
Box                    yes           yes
Dict                   yes           no


Models

The implementation uses 6 deterministic function approximators. These function approximators (models) must be collected in a dictionary and passed to the constructor of the class via the models argument.

Notation                          Concept                 Key                 Input shape            Output shape   Type
\(\mu_\theta(s)\)                 Policy (actor)          "policy"            observation            action         Deterministic
\(\mu_{\theta_{target}}(s)\)      Target policy           "target_policy"     observation            action         Deterministic
\(Q_{\phi 1}(s, a)\)              Q1-network (critic 1)   "critic_1"          observation + action   1              Deterministic
\(Q_{\phi 2}(s, a)\)              Q2-network (critic 2)   "critic_2"          observation + action   1              Deterministic
\(Q_{{\phi 1}_{target}}(s, a)\)   Target Q1-network       "target_critic_1"   observation + action   1              Deterministic
\(Q_{{\phi 2}_{target}}(s, a)\)   Target Q2-network       "target_critic_2"   observation + action   1              Deterministic

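The models table can be turned into a small sanity check before constructing the agent. This is a minimal sketch. `check_models` and `critic_input` are hypothetical helpers, not skrl functions. The split between "always required" and "only required during training" follows the comments in the usage example. Reading "observation + action" as a concatenation of both vectors is an assumption about how the critic input is built.

```python
POLICY_KEYS = ("policy",)
TRAINING_KEYS = ("target_policy", "critic_1", "critic_2", "target_critic_1", "target_critic_2")


def check_models(models, training=True):
    """Hypothetical helper: raise KeyError if a model needed for the given mode is missing."""
    required = POLICY_KEYS + (TRAINING_KEYS if training else ())
    missing = [key for key in required if models.get(key) is None]
    if missing:
        raise KeyError(f"missing model(s): {', '.join(missing)}")


def critic_input(observation, action):
    """Assumed critic input layout: observation and action vectors concatenated."""
    return list(observation) + list(action)


check_models({"policy": object()}, training=False)  # evaluation: only the policy is needed
try:
    check_models({"policy": object()})  # training: the other five models are missing
except KeyError as error:
    print(error)
print(critic_input([0.1, 0.2, 0.3], [0.5]))  # 3 observation values + 1 action value -> 4 inputs
```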

Features

Support for advanced features is described in the following table:

Feature        Support and remarks                           pytorch   jax
Shared model   -                                             no        no
RNN support    RNN, LSTM, GRU and any other variant          yes       no
Distributed    Single Program Multi Data (SPMD) multi-GPU    yes       yes


API (PyTorch)

skrl.agents.torch.td3.TD3_DEFAULT_CONFIG

alias of {‘actor_learning_rate’: 0.001, ‘batch_size’: 64, ‘critic_learning_rate’: 0.001, ‘discount_factor’: 0.99, ‘experiment’: {‘checkpoint_interval’: ‘auto’, ‘directory’: ‘’, ‘experiment_name’: ‘’, ‘store_separately’: False, ‘wandb’: False, ‘wandb_kwargs’: {}, ‘write_interval’: ‘auto’}, ‘exploration’: {‘final_scale’: 0.001, ‘initial_scale’: 1.0, ‘noise’: None, ‘timesteps’: None}, ‘grad_norm_clip’: 0, ‘gradient_steps’: 1, ‘learning_rate_scheduler’: None, ‘learning_rate_scheduler_kwargs’: {}, ‘learning_starts’: 0, ‘policy_delay’: 2, ‘polyak’: 0.005, ‘random_timesteps’: 0, ‘rewards_shaper’: None, ‘smooth_regularization_clip’: 0.5, ‘smooth_regularization_noise’: None, ‘state_preprocessor’: None, ‘state_preprocessor_kwargs’: {}}

class skrl.agents.torch.td3.TD3(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)

Bases: Agent

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) → None

Twin Delayed DDPG (TD3)

https://arxiv.org/abs/1802.09477

Parameters:
  • models (dictionary of skrl.models.torch.Model) – Models used by the agent

  • memory (skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None) – Memory to store the transitions. If it is a tuple, the first element is used for training and the remaining elements only receive the environment transitions

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary

Raises:

KeyError – If the models dictionary is missing a required key
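The documented behavior of the memory parameter (the first memory is used for training, the rest only receive environment transitions) can be sketched as follows. `ListMemory`, `split_memories` and `store_transition` are hypothetical names, not skrl APIs. Accepting a list as well as a tuple is an assumption made for convenience.

```python
class ListMemory:
    """Hypothetical stand-in for a replay memory that just stores transitions in a list."""

    def __init__(self):
        self.transitions = []

    def add(self, **transition):
        self.transitions.append(transition)


def split_memories(memory):
    """Return (training_memory, secondary_memories) for a single memory, a tuple/list, or None."""
    if memory is None:
        return None, []
    if isinstance(memory, (tuple, list)):
        return (memory[0] if memory else None), list(memory[1:])
    return memory, []


def store_transition(memory, **transition):
    """Every memory receives the transition; only the first one would be sampled for training."""
    training_memory, secondary_memories = split_memories(memory)
    targets = ([training_memory] if training_memory is not None else []) + secondary_memories
    for m in targets:
        m.add(**transition)


main, extra = ListMemory(), ListMemory()
store_transition((main, extra), states=[0.0], actions=[0.1], rewards=[1.0], next_states=[0.1])
print(len(main.transitions), len(extra.transitions))
```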

_update(timestep: int, timesteps: int) → None

Algorithm’s main update step

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

act(states: torch.Tensor, timestep: int, timesteps: int) → torch.Tensor

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
  • states (torch.Tensor) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

Returns:

Actions

Return type:

torch.Tensor

init(trainer_cfg: Mapping[str, Any] | None = None) → None

Initialize the agent

post_interaction(timestep: int, timesteps: int) → None

Callback called after the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) → None

Callback called before the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) → None

Record an environment transition in memory

Parameters:
  • states (torch.Tensor) – Observations/states of the environment used to make the decision

  • actions (torch.Tensor) – Actions taken by the agent

  • rewards (torch.Tensor) – Instant rewards achieved by the current actions

  • next_states (torch.Tensor) – Next observations/states of the environment

  • terminated (torch.Tensor) – Signals to indicate that episodes have terminated

  • truncated (torch.Tensor) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps
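The callbacks documented above are meant to be driven by a trainer once per timestep. The following minimal sketch shows one plausible call order: pre_interaction, act, environment step, record_transition, post_interaction. `SketchAgent` and `run` are hypothetical and only log what happens. Two behaviors are assumptions inferred from the configuration comments rather than taken from skrl's source: uniform random actions during the first `random_timesteps`, and no `_update` before `learning_starts`.

```python
import random


class SketchAgent:
    """Hypothetical agent that only logs which step of the interaction runs at each timestep."""

    def __init__(self, random_timesteps=2, learning_starts=3, seed=0):
        self.random_timesteps = random_timesteps  # mirrors cfg["random_timesteps"]
        self.learning_starts = learning_starts    # mirrors cfg["learning_starts"]
        self.rng = random.Random(seed)
        self.log = []

    def pre_interaction(self, timestep, timesteps):
        self.log.append((timestep, "pre_interaction"))

    def act(self, states, timestep, timesteps):
        if timestep < self.random_timesteps:  # assumed: random actions first
            self.log.append((timestep, "act:random"))
            return [self.rng.uniform(-1.0, 1.0) for _ in states]
        self.log.append((timestep, "act:policy"))
        return [0.0 for _ in states]  # placeholder for mu_theta(s) + exploration noise

    def record_transition(self, states, actions, rewards, next_states,
                          terminated, truncated, infos, timestep, timesteps):
        self.log.append((timestep, "record_transition"))

    def post_interaction(self, timestep, timesteps):
        if timestep >= self.learning_starts:  # assumed: no updates before learning_starts
            self.log.append((timestep, "_update"))


def run(agent, timesteps=5):
    """Minimal sequential loop calling the callbacks in the order pre -> act -> record -> post."""
    states = [0.0]
    for timestep in range(timesteps):
        agent.pre_interaction(timestep, timesteps)
        actions = agent.act(states, timestep, timesteps)
        next_states = [s + a for s, a in zip(states, actions)]  # toy dynamics
        rewards = [-abs(s) for s in next_states]
        agent.record_transition(states, actions, rewards, next_states,
                                terminated=[False], truncated=[False], infos={},
                                timestep=timestep, timesteps=timesteps)
        agent.post_interaction(timestep, timesteps)
        states = next_states
    return agent.log


for entry in run(SketchAgent()):
    print(entry)
```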

class skrl.agents.torch.td3.TD3_RNN(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)

Bases: Agent

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) → None

Twin Delayed DDPG (TD3) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.)

https://arxiv.org/abs/1802.09477

Parameters:
  • models (dictionary of skrl.models.torch.Model) – Models used by the agent

  • memory (skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None) – Memory to store the transitions. If it is a tuple, the first element is used for training and the remaining elements only receive the environment transitions

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary

Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) → None

Algorithm’s main update step

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

act(states: torch.Tensor, timestep: int, timesteps: int) → torch.Tensor

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
  • states (torch.Tensor) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

Returns:

Actions

Return type:

torch.Tensor

init(trainer_cfg: Mapping[str, Any] | None = None) → None

Initialize the agent

post_interaction(timestep: int, timesteps: int) → None

Callback called after the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) → None

Callback called before the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) → None

Record an environment transition in memory

Parameters:
  • states (torch.Tensor) – Observations/states of the environment used to make the decision

  • actions (torch.Tensor) – Actions taken by the agent

  • rewards (torch.Tensor) – Instant rewards achieved by the current actions

  • next_states (torch.Tensor) – Next observations/states of the environment

  • terminated (torch.Tensor) – Signals to indicate that episodes have terminated

  • truncated (torch.Tensor) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps


API (JAX)

skrl.agents.jax.td3.TD3_DEFAULT_CONFIG

alias of {‘actor_learning_rate’: 0.001, ‘batch_size’: 64, ‘critic_learning_rate’: 0.001, ‘discount_factor’: 0.99, ‘experiment’: {‘checkpoint_interval’: ‘auto’, ‘directory’: ‘’, ‘experiment_name’: ‘’, ‘store_separately’: False, ‘wandb’: False, ‘wandb_kwargs’: {}, ‘write_interval’: ‘auto’}, ‘exploration’: {‘final_scale’: 0.001, ‘initial_scale’: 1.0, ‘noise’: None, ‘timesteps’: None}, ‘grad_norm_clip’: 0, ‘gradient_steps’: 1, ‘learning_rate_scheduler’: None, ‘learning_rate_scheduler_kwargs’: {}, ‘learning_starts’: 0, ‘policy_delay’: 2, ‘polyak’: 0.005, ‘random_timesteps’: 0, ‘rewards_shaper’: None, ‘smooth_regularization_clip’: 0.5, ‘smooth_regularization_noise’: None, ‘state_preprocessor’: None, ‘state_preprocessor_kwargs’: {}}

class skrl.agents.jax.td3.TD3(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None)

Bases: Agent

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None) → None

Twin Delayed DDPG (TD3)

https://arxiv.org/abs/1802.09477

Parameters:
  • models (dictionary of skrl.models.jax.Model) – Models used by the agent

  • memory (skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None) – Memory to store the transitions. If it is a tuple, the first element is used for training and the remaining elements only receive the environment transitions

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or jax.Device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary

Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) → None

Algorithm’s main update step

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

act(states: ndarray | jax.Array, timestep: int, timesteps: int) → ndarray | jax.Array

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
  • states (np.ndarray or jax.Array) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

Returns:

Actions

Return type:

np.ndarray or jax.Array

init(trainer_cfg: Mapping[str, Any] | None = None) → None

Initialize the agent

post_interaction(timestep: int, timesteps: int) → None

Callback called after the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) → None

Callback called before the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: ndarray | jax.Array, actions: ndarray | jax.Array, rewards: ndarray | jax.Array, next_states: ndarray | jax.Array, terminated: ndarray | jax.Array, truncated: ndarray | jax.Array, infos: Any, timestep: int, timesteps: int) → None

Record an environment transition in memory

Parameters:
  • states (np.ndarray or jax.Array) – Observations/states of the environment used to make the decision

  • actions (np.ndarray or jax.Array) – Actions taken by the agent

  • rewards (np.ndarray or jax.Array) – Instant rewards achieved by the current actions

  • next_states (np.ndarray or jax.Array) – Next observations/states of the environment

  • terminated (np.ndarray or jax.Array) – Signals to indicate that episodes have terminated

  • truncated (np.ndarray or jax.Array) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps