Soft Actor-Critic (SAC)#

SAC is a model-free, off-policy actor-critic algorithm with a stochastic policy. It uses double Q-learning (like TD3) and entropy regularization: the policy is trained to maximize a trade-off between expected return and policy entropy, which balances exploitation and exploration.

Paper: Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor



Algorithm#


Algorithm implementation#

Main notation/symbols:
- policy function approximator (\(\pi_\theta\)), critic function approximator (\(Q_\phi\))
- states (\(s\)), actions (\(a\)), rewards (\(r\)), next states (\(s'\)), dones (\(d\))
- log probabilities (\(logp\)), entropy coefficient (\(\alpha\))
- loss (\(L\))

Learning algorithm#


_update(...)
# sample a batch from memory
[\(s, a, r, s', d\)] \(\leftarrow\) states, actions, rewards, next_states, dones of size batch_size
# gradient steps
FOR each gradient step up to gradient_steps DO
    # compute target values
    \(a',\; logp' \leftarrow \pi_\theta(s')\)
    \(Q_{1_{target}} \leftarrow Q_{{\phi 1}_{target}}(s', a')\)
    \(Q_{2_{target}} \leftarrow Q_{{\phi 2}_{target}}(s', a')\)
    \(Q_{_{target}} \leftarrow \text{min}(Q_{1_{target}}, Q_{2_{target}}) - \alpha \; logp'\)
    \(y \leftarrow r \;+\) discount_factor \(\neg d \; Q_{_{target}}\)
    # compute critic loss
    \(Q_1 \leftarrow Q_{\phi 1}(s, a)\)
    \(Q_2 \leftarrow Q_{\phi 2}(s, a)\)
    \(L_{Q_\phi} \leftarrow 0.5 \; (\frac{1}{N} \sum_{i=1}^N (Q_1 - y)^2 + \frac{1}{N} \sum_{i=1}^N (Q_2 - y)^2)\)
    # optimization step (critic)
    reset \(\text{optimizer}_\phi\)
    \(\nabla_{\phi} L_{Q_\phi}\)
    \(\text{clip}(\lVert \nabla_{\phi} \rVert)\) with grad_norm_clip
    step \(\text{optimizer}_\phi\)
    # compute policy (actor) loss
    \(a,\; logp \leftarrow \pi_\theta(s)\)
    \(Q_1 \leftarrow Q_{\phi 1}(s, a)\)
    \(Q_2 \leftarrow Q_{\phi 2}(s, a)\)
    \(L_{\pi_\theta} \leftarrow \frac{1}{N} \sum_{i=1}^N (\alpha \; logp - \text{min}(Q_1, Q_2))\)
    # optimization step (policy)
    reset \(\text{optimizer}_\theta\)
    \(\nabla_{\theta} L_{\pi_\theta}\)
    \(\text{clip}(\lVert \nabla_{\theta} \rVert)\) with grad_norm_clip
    step \(\text{optimizer}_\theta\)
    # entropy learning
    IF learn_entropy is enabled THEN
        # compute entropy loss
        \({L}_{entropy} \leftarrow - \frac{1}{N} \sum_{i=1}^N (log(\alpha) \; (logp + \alpha_{Target}))\)
        # optimization step (entropy)
        reset \(\text{optimizer}_\alpha\)
        \(\nabla_{\alpha} {L}_{entropy}\)
        step \(\text{optimizer}_\alpha\)
        # compute entropy coefficient
        \(\alpha \leftarrow e^{log(\alpha)}\)
    # update target networks
    \({\phi 1}_{target} \leftarrow\) polyak \({\phi 1} + (1 \;-\) polyak \() {\phi 1}_{target}\)
    \({\phi 2}_{target} \leftarrow\) polyak \({\phi 2} + (1 \;-\) polyak \() {\phi 2}_{target}\)
    # update learning rate
    IF there is a learning_rate_scheduler THEN
        step \(\text{scheduler}_\theta (\text{optimizer}_\theta)\)
        step \(\text{scheduler}_\phi (\text{optimizer}_\phi)\)
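The quantities in the update step can be checked by hand on a tiny batch. The following is a minimal sketch in plain Python, assuming scalar floats in place of tensors and precomputed Q-values and log probabilities in place of network calls. The function names (td_target, critic_loss, actor_loss, entropy_loss, polyak_update) and the toy numbers are hypothetical and are not part of the library; no gradients are computed here.

```python
import math

def td_target(r, d, q1_next, q2_next, logp_next, alpha, discount_factor=0.99):
    """y = r + discount_factor * (not d) * (min(Q1_target, Q2_target) - alpha * logp')."""
    q_target = min(q1_next, q2_next) - alpha * logp_next
    return r + discount_factor * (0.0 if d else 1.0) * q_target

def critic_loss(q1, q2, y):
    """0.5 * (mean squared error of Q1 + mean squared error of Q2)."""
    n = len(y)
    mse_1 = sum((q - t) ** 2 for q, t in zip(q1, y)) / n
    mse_2 = sum((q - t) ** 2 for q, t in zip(q2, y)) / n
    return 0.5 * (mse_1 + mse_2)

def actor_loss(q1, q2, logp, alpha):
    """Mean of alpha * logp - min(Q1, Q2) for actions sampled from the policy."""
    return sum(alpha * lp - min(a, b) for a, b, lp in zip(q1, q2, logp)) / len(logp)

def entropy_loss(log_alpha, logp, target_entropy):
    """-mean(log(alpha) * (logp + target_entropy)), minimized over log(alpha)."""
    return -sum(log_alpha * (lp + target_entropy) for lp in logp) / len(logp)

def polyak_update(target_params, params, polyak=0.005):
    """target <- polyak * params + (1 - polyak) * target, element-wise."""
    return [polyak * p + (1.0 - polyak) * t for p, t in zip(params, target_params)]

# toy batch of two transitions (made-up numbers); the second one is terminal
alpha = 0.2
y = [td_target(1.0, False, 2.0, 1.5, -1.0, alpha),
     td_target(0.5, True, 3.0, 3.5, -0.5, alpha)]
loss_q = critic_loss([2.0, 1.0], [3.0, 0.0], y)
loss_pi = actor_loss([2.0, 1.0], [3.0, 0.0], [-1.0, -2.0], alpha)
loss_alpha = entropy_loss(math.log(alpha), [-1.0, -2.0], target_entropy=-1.0)
new_target = polyak_update([0.0, 1.0], [1.0, 1.0])
```

For the terminal transition the bootstrap term vanishes, so the target reduces to the reward. With a small polyak value the target parameters move only slightly toward the online parameters at each update.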

Usage#

Note

Support for recurrent neural networks (RNN, LSTM, GRU and any other variant) is implemented in a separate file (sac_rnn.py) to maintain the readability of the standard implementation (sac.py)

# import the agent and its default configuration
from skrl.agents.torch.sac import SAC, SAC_DEFAULT_CONFIG

# instantiate the agent's models
models = {}
models["policy"] = ...
models["critic_1"] = ...  # only required during training
models["critic_2"] = ...  # only required during training
models["target_critic_1"] = ...  # only required during training
models["target_critic_2"] = ...  # only required during training

# adjust some configuration if necessary
cfg_agent = SAC_DEFAULT_CONFIG.copy()
cfg_agent["<KEY>"] = ...

# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
agent = SAC(models=models,
            memory=memory,  # only required during training
            cfg=cfg_agent,
            observation_space=env.observation_space,
            action_space=env.action_space,
            device=env.device)
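One detail of the configuration step is worth illustrating: dict.copy() makes a shallow copy, so nested dictionaries such as the "experiment" entry remain shared with the original. The sketch below uses a small stand-in dictionary (DEFAULT_CONFIG, a hypothetical name with only a few of the real keys) rather than the library's own configuration, and shows copy.deepcopy as one way to avoid the sharing. Whether this matters in practice depends on whether nested keys are modified in place.

```python
import copy

# stand-in for a default configuration with a nested section (illustrative only)
DEFAULT_CONFIG = {
    "batch_size": 64,
    "experiment": {"write_interval": 250, "checkpoint_interval": 1000},
}

# a deep copy duplicates the nested dictionary, so the defaults stay untouched
cfg_deep = copy.deepcopy(DEFAULT_CONFIG)
cfg_deep["experiment"]["write_interval"] = 1000
default_after_deep = DEFAULT_CONFIG["experiment"]["write_interval"]

# a shallow copy only duplicates the top level
cfg_shallow = DEFAULT_CONFIG.copy()
cfg_shallow["batch_size"] = 256  # does not affect the defaults
cfg_shallow["experiment"]["checkpoint_interval"] = 5000  # also changes the defaults
shares_nested = cfg_shallow["experiment"] is DEFAULT_CONFIG["experiment"]
```

Replacing a nested section wholesale (for example, assigning a new dictionary to the "experiment" key of the copy) is another way to keep the defaults unchanged.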

Configuration and hyperparameters#

SAC_DEFAULT_CONFIG = {
    "gradient_steps": 1,            # gradient steps
    "batch_size": 64,               # training batch size

    "discount_factor": 0.99,        # discount factor (gamma)
    "polyak": 0.005,                # soft update hyperparameter (tau)

    "actor_learning_rate": 1e-3,    # actor learning rate
    "critic_learning_rate": 1e-3,   # critic learning rate
    "learning_rate_scheduler": None,        # learning rate scheduler class (see torch.optim.lr_scheduler)
    "learning_rate_scheduler_kwargs": {},   # learning rate scheduler's kwargs (e.g. {"step_size": 1000})

    "state_preprocessor": None,             # state preprocessor class (see skrl.resources.preprocessors)
    "state_preprocessor_kwargs": {},        # state preprocessor's kwargs (e.g. {"size": env.observation_space})

    "random_timesteps": 0,          # random exploration steps
    "learning_starts": 0,           # learning starts after this many steps

    "grad_norm_clip": 0,            # clipping coefficient for the norm of the gradients (0 disables clipping)

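    "learn_entropy": True,          # whether to learn the entropy coefficient (alpha)
    "entropy_learning_rate": 1e-3,  # learning rate for the entropy coefficient
    "initial_entropy_value": 0.2,   # initial value of the entropy coefficient
    "target_entropy": None,         # target entropy (alpha_Target in the algorithm above)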

    "rewards_shaper": None,         # rewards shaping function: Callable(reward, timestep, timesteps) -> reward

    "experiment": {
        "base_directory": "",       # base directory for the experiment
        "experiment_name": "",      # experiment name
        "write_interval": 250,      # TensorBoard writing interval (timesteps)

        "checkpoint_interval": 1000,        # interval for checkpoints (timesteps)
        "store_separately": False,          # whether to store checkpoints separately

        "wandb": False,             # whether to use Weights & Biases
        "wandb_kwargs": {}          # wandb kwargs (see https://docs.wandb.ai/ref/python/init)
    }
}
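The rewards_shaper entry expects a callable with the signature Callable(reward, timestep, timesteps) -> reward. As a minimal sketch, the two factories below build such callables: one applies a constant scale, the other anneals the scale linearly over training. The names make_scaling_shaper and make_annealed_shaper are hypothetical, and plain floats stand in for the reward tensors an agent would pass.

```python
def make_scaling_shaper(scale):
    """Return a shaper that multiplies every reward by a constant factor."""
    def shaper(rewards, timestep, timesteps):
        return rewards * scale
    return shaper

def make_annealed_shaper(start_scale, end_scale):
    """Return a shaper whose scale moves linearly from start_scale to end_scale."""
    def shaper(rewards, timestep, timesteps):
        progress = timestep / timesteps
        scale = start_scale + (end_scale - start_scale) * progress
        return rewards * scale
    return shaper

# configuration fragment: only the key documented above is set
cfg_fragment = {"rewards_shaper": make_scaling_shaper(0.1)}

scaled = cfg_fragment["rewards_shaper"](10.0, 0, 100)
annealed = make_annealed_shaper(1.0, 0.0)(2.0, 50, 100)
```

Because the shapers only use multiplication, the same functions would also apply to array-like rewards that support scalar multiplication.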

Spaces#

The implementation supports the following Gym spaces / Gymnasium spaces (\(\blacksquare\): supported, \(\square\): not supported)

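====================  ================  ================
Gym/Gymnasium spaces  Observation       Action
====================  ================  ================
Discrete              \(\square\)       \(\square\)
MultiDiscrete         \(\square\)       \(\square\)
Box                   \(\blacksquare\)  \(\blacksquare\)
Dict                  \(\blacksquare\)  \(\square\)
====================  ================  ================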


Models#

The implementation uses 1 stochastic and 4 deterministic function approximators. These function approximators (models) must be collected in a dictionary and passed to the constructor of the class under the argument models

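===============================  =====================  =================  ====================  ============  ===============================
Notation                         Concept                Key                Input shape           Output shape  Type
===============================  =====================  =================  ====================  ============  ===============================
\(\pi_\theta(s)\)                Policy (actor)         "policy"           observation           action        Gaussian / MultivariateGaussian
\(Q_{\phi 1}(s, a)\)             Q1-network (critic 1)  "critic_1"         observation + action  1             Deterministic
\(Q_{\phi 2}(s, a)\)             Q2-network (critic 2)  "critic_2"         observation + action  1             Deterministic
\(Q_{{\phi 1}_{target}}(s, a)\)  Target Q1-network      "target_critic_1"  observation + action  1             Deterministic
\(Q_{{\phi 2}_{target}}(s, a)\)  Target Q2-network      "target_critic_2"  observation + action  1             Deterministic
===============================  =====================  =================  ====================  ============  ===============================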
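Since the constructor expects specific dictionary keys, it can be convenient to check the dictionary before building the agent. The helper below is a hypothetical sketch, not part of the library. It assumes the five keys listed in the table, and that only "policy" is needed outside of training, as the usage comments indicate.

```python
# keys taken from the models table; the critics are only required during training
REQUIRED_TRAINING_KEYS = (
    "policy",
    "critic_1",
    "critic_2",
    "target_critic_1",
    "target_critic_2",
)

def missing_model_keys(models, training=True):
    """Return the required keys that are absent (or set to None) in the models dictionary."""
    required = REQUIRED_TRAINING_KEYS if training else ("policy",)
    return [key for key in required if models.get(key) is None]

# placeholder objects stand in for real model instances
models = {"policy": object(), "critic_1": object(), "critic_2": object()}
missing_for_training = missing_model_keys(models)
missing_for_evaluation = missing_model_keys(models, training=False)
```

Running such a check up front gives a readable list of what is missing instead of a failure later during agent construction.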


Features#

Support for advanced features is described in the next table

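============  ====================================  ================  ================
Feature       Support and remarks                   pytorch           jax
============  ====================================  ================  ================
Shared model  -                                     \(\square\)       \(\square\)
RNN support   RNN, LSTM, GRU and any other variant  \(\blacksquare\)  \(\square\)
============  ====================================  ================  ================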


API (PyTorch)#

skrl.agents.torch.sac.SAC_DEFAULT_CONFIG#

alias of {'actor_learning_rate': 0.001, 'batch_size': 64, 'critic_learning_rate': 0.001, 'discount_factor': 0.99, 'entropy_learning_rate': 0.001, 'experiment': {'base_directory': '', 'checkpoint_interval': 1000, 'experiment_name': '', 'store_separately': False, 'wandb': False, 'wandb_kwargs': {}, 'write_interval': 250}, 'grad_norm_clip': 0, 'gradient_steps': 1, 'initial_entropy_value': 0.2, 'learn_entropy': True, 'learning_rate_scheduler': None, 'learning_rate_scheduler_kwargs': {}, 'learning_starts': 0, 'polyak': 0.005, 'random_timesteps': 0, 'rewards_shaper': None, 'state_preprocessor': None, 'state_preprocessor_kwargs': {}, 'target_entropy': None}

class skrl.agents.torch.sac.SAC(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)#

Bases: Agent

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) -> None#

Soft Actor-Critic (SAC)

https://arxiv.org/abs/1801.01290

Parameters:
  • models (dictionary of skrl.models.torch.Model) – Models used by the agent

  • memory (skrl.memory.torch.Memory, tuple of skrl.memory.torch.Memory or None) – Memory used to store the transitions. If it is a tuple, the first element is used for training; the remaining elements only receive the environment transitions

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary

Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) -> None#

Algorithm’s main update step

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

act(states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tensor#

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
  • states (torch.Tensor) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

Returns:

Actions

Return type:

torch.Tensor

init(trainer_cfg: Mapping[str, Any] | None = None) -> None#

Initialize the agent

post_interaction(timestep: int, timesteps: int) -> None#

Callback called after the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) -> None#

Callback called before the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) -> None#

Record an environment transition in memory

Parameters:
  • states (torch.Tensor) – Observations/states of the environment used to make the decision

  • actions (torch.Tensor) – Actions taken by the agent

  • rewards (torch.Tensor) – Instant rewards achieved by the current actions

  • next_states (torch.Tensor) – Next observations/states of the environment

  • terminated (torch.Tensor) – Signals to indicate that episodes have terminated

  • truncated (torch.Tensor) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps
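The methods above are meant to be called around each environment step. The sketch below illustrates one plausible call order using stub classes: StubAgent and StubEnv are hypothetical stand-ins that only mimic the method names and arguments documented here, and the loop is a simplified assumption of what a sequential training loop does (it omits environment resets, rendering and tensor handling).

```python
class StubAgent:
    """Hypothetical stand-in that records which documented method was called."""

    def __init__(self):
        self.calls = []

    def pre_interaction(self, timestep, timesteps):
        self.calls.append("pre_interaction")

    def act(self, states, timestep, timesteps):
        self.calls.append("act")
        return [0.0 for _ in states]  # one dummy action per state

    def record_transition(self, states, actions, rewards, next_states,
                          terminated, truncated, infos, timestep, timesteps):
        self.calls.append("record_transition")

    def post_interaction(self, timestep, timesteps):
        self.calls.append("post_interaction")


class StubEnv:
    """Hypothetical single-environment stub with a Gymnasium-like interface."""

    def reset(self):
        return [0.0], {}

    def step(self, actions):
        # next_states, rewards, terminated, truncated, infos
        return [0.0], [1.0], [False], [False], {}


def run(agent, env, timesteps):
    """Call the agent's callbacks around each environment step."""
    states, infos = env.reset()
    for timestep in range(timesteps):
        agent.pre_interaction(timestep=timestep, timesteps=timesteps)
        actions = agent.act(states, timestep=timestep, timesteps=timesteps)
        next_states, rewards, terminated, truncated, infos = env.step(actions)
        agent.record_transition(states=states, actions=actions, rewards=rewards,
                                next_states=next_states, terminated=terminated,
                                truncated=truncated, infos=infos,
                                timestep=timestep, timesteps=timesteps)
        agent.post_interaction(timestep=timestep, timesteps=timesteps)
        states = next_states
    return agent.calls


calls = run(StubAgent(), StubEnv(), timesteps=2)
```

In this arrangement the transition is stored before post_interaction runs, so any learning triggered after the interaction can already sample the newest transition from memory.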

class skrl.agents.torch.sac.SAC_RNN(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)#

Bases: Agent

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) -> None#

Soft Actor-Critic (SAC) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.)

https://arxiv.org/abs/1801.01290

Parameters:
  • models (dictionary of skrl.models.torch.Model) – Models used by the agent

  • memory (skrl.memory.torch.Memory, tuple of skrl.memory.torch.Memory or None) – Memory used to store the transitions. If it is a tuple, the first element is used for training; the remaining elements only receive the environment transitions

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary

Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) -> None#

Algorithm’s main update step

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

act(states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tensor#

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
  • states (torch.Tensor) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

Returns:

Actions

Return type:

torch.Tensor

init(trainer_cfg: Mapping[str, Any] | None = None) -> None#

Initialize the agent

post_interaction(timestep: int, timesteps: int) -> None#

Callback called after the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) -> None#

Callback called before the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) -> None#

Record an environment transition in memory

Parameters:
  • states (torch.Tensor) – Observations/states of the environment used to make the decision

  • actions (torch.Tensor) – Actions taken by the agent

  • rewards (torch.Tensor) – Instant rewards achieved by the current actions

  • next_states (torch.Tensor) – Next observations/states of the environment

  • terminated (torch.Tensor) – Signals to indicate that episodes have terminated

  • truncated (torch.Tensor) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps


API (JAX)#

skrl.agents.jax.sac.SAC_DEFAULT_CONFIG#

alias of {'actor_learning_rate': 0.001, 'batch_size': 64, 'critic_learning_rate': 0.001, 'discount_factor': 0.99, 'entropy_learning_rate': 0.001, 'experiment': {'checkpoint_interval': 1000, 'directory': '', 'experiment_name': '', 'store_separately': False, 'wandb': False, 'wandb_kwargs': {}, 'write_interval': 250}, 'grad_norm_clip': 0, 'gradient_steps': 1, 'initial_entropy_value': 0.2, 'learn_entropy': True, 'learning_rate_scheduler': None, 'learning_rate_scheduler_kwargs': {}, 'learning_starts': 0, 'polyak': 0.005, 'random_timesteps': 0, 'rewards_shaper': None, 'state_preprocessor': None, 'state_preprocessor_kwargs': {}, 'target_entropy': None}

class skrl.agents.jax.sac.SAC(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None)#

Bases: Agent

__init__(models: Mapping[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None) -> None#

Soft Actor-Critic (SAC)

https://arxiv.org/abs/1801.01290

Parameters:
  • models (dictionary of skrl.models.jax.Model) – Models used by the agent

  • memory (skrl.memory.jax.Memory, tuple of skrl.memory.jax.Memory or None) – Memory used to store the transitions. If it is a tuple, the first element is used for training; the remaining elements only receive the environment transitions

  • observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

  • action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

  • device (str or jax.Device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • cfg (dict) – Configuration dictionary

Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) -> None#

Algorithm’s main update step

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

act(states: ndarray | jax.Array, timestep: int, timesteps: int) -> ndarray | jax.Array#

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
  • states (np.ndarray or jax.Array) – Environment’s states

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

Returns:

Actions

Return type:

np.ndarray or jax.Array

init(trainer_cfg: Mapping[str, Any] | None = None) -> None#

Initialize the agent

post_interaction(timestep: int, timesteps: int) -> None#

Callback called after the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) -> None#

Callback called before the interaction with the environment

Parameters:
  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps

record_transition(states: ndarray | jax.Array, actions: ndarray | jax.Array, rewards: ndarray | jax.Array, next_states: ndarray | jax.Array, terminated: ndarray | jax.Array, truncated: ndarray | jax.Array, infos: Any, timestep: int, timesteps: int) -> None#

Record an environment transition in memory

Parameters:
  • states (np.ndarray or jax.Array) – Observations/states of the environment used to make the decision

  • actions (np.ndarray or jax.Array) – Actions taken by the agent

  • rewards (np.ndarray or jax.Array) – Instant rewards achieved by the current actions

  • next_states (np.ndarray or jax.Array) – Next observations/states of the environment

  • terminated (np.ndarray or jax.Array) – Signals to indicate that episodes have terminated

  • truncated (np.ndarray or jax.Array) – Signals to indicate that episodes have been truncated

  • infos (Any type supported by the environment) – Additional information about the environment

  • timestep (int) – Current timestep

  • timesteps (int) – Number of timesteps