# Soft Actor-Critic (SAC)

SAC is a model-free, stochastic, off-policy actor-critic algorithm that uses double Q-learning (like TD3) and entropy regularization. The policy is trained to maximize a trade-off between expected return (exploitation) and policy entropy (exploration).
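The trade-off can be made concrete with the maximum-entropy objective from the SAC paper (https://arxiv.org/abs/1801.01290). The symbols $$\rho_\pi$$ (state-action distribution induced by the policy) and $$\mathcal{H}$$ (entropy) are introduced here only for illustration and are not used elsewhere on this page:

$$J(\pi) = \sum_{t} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi} \left[ r(s_t, a_t) + \alpha \, \mathcal{H}(\pi(\cdot \mid s_t)) \right]$$

where $$\mathcal{H}(\pi(\cdot \mid s)) = \mathbb{E}_{a \sim \pi}[-\log \pi(a \mid s)]$$, so a larger entropy coefficient $$\alpha$$ gives more weight to exploration.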

## Algorithm

### Algorithm implementation

Main notation/symbols:
- policy function approximator ($$\pi_\theta$$), critic function approximator ($$Q_\phi$$)
- states ($$s$$), actions ($$a$$), rewards ($$r$$), next states ($$s'$$), dones ($$d$$)
- log probabilities ($$logp$$), entropy coefficient ($$\alpha$$)
- loss ($$L$$)

#### Learning algorithm

_update(...)
# sample a batch from memory
[$$s, a, r, s', d$$] $$\leftarrow$$ states, actions, rewards, next_states, dones of size batch_size
# compute target values
$$a',\; logp' \leftarrow \pi_\theta(s')$$
$$Q_{1_{target}} \leftarrow Q_{{\phi 1}_{target}}(s', a')$$
$$Q_{2_{target}} \leftarrow Q_{{\phi 2}_{target}}(s', a')$$
$$Q_{_{target}} \leftarrow \text{min}(Q_{1_{target}}, Q_{2_{target}}) - \alpha \; logp'$$
$$y \leftarrow r \;+$$ discount_factor $$\neg d \; Q_{_{target}}$$
# compute critic loss
$$Q_1 \leftarrow Q_{\phi 1}(s, a)$$
$$Q_2 \leftarrow Q_{\phi 2}(s, a)$$
$$L_{Q_\phi} \leftarrow 0.5 \; (\frac{1}{N} \sum_{i=1}^N (Q_1 - y)^2 + \frac{1}{N} \sum_{i=1}^N (Q_2 - y)^2)$$
# optimization step (critic)
reset $$\text{optimizer}_\phi$$
$$\nabla_{\phi} L_{Q_\phi}$$
$$\text{clip}(\lVert \nabla_{\phi} \rVert)$$ with grad_norm_clip
step $$\text{optimizer}_\phi$$
# compute policy (actor) loss
$$a,\; logp \leftarrow \pi_\theta(s)$$
$$Q_1 \leftarrow Q_{\phi 1}(s, a)$$
$$Q_2 \leftarrow Q_{\phi 2}(s, a)$$
$$L_{\pi_\theta} \leftarrow \frac{1}{N} \sum_{i=1}^N (\alpha \; logp - \text{min}(Q_1, Q_2))$$
# optimization step (policy)
reset $$\text{optimizer}_\theta$$
$$\nabla_{\theta} L_{\pi_\theta}$$
$$\text{clip}(\lVert \nabla_{\theta} \rVert)$$ with grad_norm_clip
step $$\text{optimizer}_\theta$$
# entropy learning
IF learn_entropy is enabled THEN
    # compute entropy loss
    $${L}_{entropy} \leftarrow - \frac{1}{N} \sum_{i=1}^N (log(\alpha) \; (logp + \alpha_{Target}))$$
    # optimization step (entropy)
    reset $$\text{optimizer}_\alpha$$
    $$\nabla_{\alpha} {L}_{entropy}$$
    step $$\text{optimizer}_\alpha$$
    # compute entropy coefficient
    $$\alpha \leftarrow e^{log(\alpha)}$$
# update target networks
$${\phi 1}_{target} \leftarrow$$ polyak $${\phi 1} + (1 \;-$$ polyak $$) {\phi 1}_{target}$$
$${\phi 2}_{target} \leftarrow$$ polyak $${\phi 2} + (1 \;-$$ polyak $$) {\phi 2}_{target}$$
# update learning rate
IF there is a learning_rate_scheduler THEN
    step $$\text{scheduler}_\theta (\text{optimizer}_\theta)$$
    step $$\text{scheduler}_\phi (\text{optimizer}_\phi)$$
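The quantities computed in the update step can be checked by hand with a small framework-free sketch. This is not the library's implementation: it works on plain Python floats and lists, computes loss values only (no gradients or optimizer steps), and all function names are hypothetical. `target_entropy` plays the role of $$\alpha_{Target}$$ above.

```python
def soft_td_target(r, d, q1_next, q2_next, logp_next, alpha, discount_factor=0.99):
    """y = r + discount_factor * (not d) * (min(Q1', Q2') - alpha * logp')."""
    q_target = min(q1_next, q2_next) - alpha * logp_next
    not_done = 0.0 if d else 1.0
    return r + discount_factor * not_done * q_target


def critic_loss(q1, q2, y):
    """0.5 * (MSE(Q1, y) + MSE(Q2, y)) over a batch given as lists."""
    n = len(y)
    mse1 = sum((q - t) ** 2 for q, t in zip(q1, y)) / n
    mse2 = sum((q - t) ** 2 for q, t in zip(q2, y)) / n
    return 0.5 * (mse1 + mse2)


def actor_loss(logp, q1, q2, alpha):
    """Mean of (alpha * logp - min(Q1, Q2)) over a batch."""
    n = len(logp)
    return sum(alpha * lp - min(a, b) for lp, a, b in zip(logp, q1, q2)) / n


def entropy_loss(log_alpha, logp, target_entropy):
    """-mean(log(alpha) * (logp + target_entropy)) over a batch."""
    n = len(logp)
    return -sum(log_alpha * (lp + target_entropy) for lp in logp) / n


def polyak_update(target_params, params, polyak=0.005):
    """target <- polyak * params + (1 - polyak) * target, element-wise."""
    return [polyak * p + (1.0 - polyak) * t for p, t in zip(params, target_params)]


# One transition: r = 1, not done, Q1' = 2, Q2' = 3, logp' = -1, alpha = 0.2
y = soft_td_target(1.0, False, 2.0, 3.0, -1.0, 0.2)
# min(2, 3) - 0.2 * (-1) = 2.2, so y = 1 + 0.99 * 2.2 = 3.178
```

When the transition is terminal (`d` is true), the bootstrap term is dropped and the target reduces to the reward.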

## Usage

Note: Support for recurrent neural networks (RNN, LSTM, GRU and any other variant) is implemented in a separate file (sac_rnn.py) to maintain the readability of the standard implementation (sac.py).

# import the agent and its default configuration
from skrl.agents.torch.sac import SAC, SAC_DEFAULT_CONFIG

# instantiate the agent's models
models = {}
models["policy"] = ...
models["critic_1"] = ...  # only required during training
models["critic_2"] = ...  # only required during training
models["target_critic_1"] = ...  # only required during training
models["target_critic_2"] = ...  # only required during training

# adjust some configuration if necessary
cfg_agent = SAC_DEFAULT_CONFIG.copy()
cfg_agent["<KEY>"] = ...

# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
agent = SAC(models=models,
            memory=memory,  # only required during training
            cfg=cfg_agent,
            observation_space=env.observation_space,
            action_space=env.action_space,
            device=env.device)
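One point worth keeping in mind about the `.copy()` call above: `dict.copy()` is shallow, so nested dictionaries such as the `"experiment"` entry remain shared with the defaults. The sketch below illustrates this with a small stand-in dictionary (`DEFAULT_CONFIG` is hypothetical and holds only a few keys, so the example runs without the library installed) and shows `copy.deepcopy` as one possible way to avoid it.

```python
import copy

# Hypothetical stand-in for SAC_DEFAULT_CONFIG (only a few keys, for illustration)
DEFAULT_CONFIG = {
    "batch_size": 64,
    "experiment": {"write_interval": 250, "checkpoint_interval": 1000},
}

# dict.copy() is shallow: top-level keys are independent of the defaults...
cfg_shallow = DEFAULT_CONFIG.copy()
cfg_shallow["batch_size"] = 256

# ...but the nested "experiment" dictionary is the same object as in the defaults
cfg_shallow["experiment"]["write_interval"] = 100
shallow_leaked = DEFAULT_CONFIG["experiment"]["write_interval"] == 100

# restore the default, then use a deep copy so nested values are duplicated too
DEFAULT_CONFIG["experiment"]["write_interval"] = 250
cfg_deep = copy.deepcopy(DEFAULT_CONFIG)
cfg_deep["experiment"]["write_interval"] = 100
```

With the deep copy, changing `cfg_deep["experiment"]` leaves the defaults untouched, which matters when several agents are configured in the same process.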


### Configuration and hyperparameters

SAC_DEFAULT_CONFIG = {
    "gradient_steps": 1,            # gradient steps
    "batch_size": 64,               # training batch size

    "discount_factor": 0.99,        # discount factor (gamma)
    "polyak": 0.005,                # soft update hyperparameter (tau)

    "actor_learning_rate": 1e-3,    # actor learning rate
    "critic_learning_rate": 1e-3,   # critic learning rate
    "learning_rate_scheduler": None,        # learning rate scheduler class (see torch.optim.lr_scheduler)
    "learning_rate_scheduler_kwargs": {},   # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})

    "state_preprocessor": None,             # state preprocessor class (see skrl.resources.preprocessors)
    "state_preprocessor_kwargs": {},        # state preprocessor's kwargs (e.g. {"size": env.observation_space})

    "random_timesteps": 0,          # random exploration steps
    "learning_starts": 0,           # learning starts after this many steps

    "grad_norm_clip": 0,            # clipping coefficient for the norm of the gradients

    "learn_entropy": True,          # learn entropy
    "entropy_learning_rate": 1e-3,  # entropy learning rate
    "initial_entropy_value": 0.2,   # initial entropy value
    "target_entropy": None,         # target entropy

    "rewards_shaper": None,         # rewards shaping function: Callable(reward, timestep, timesteps) -> reward

    "experiment": {
        "base_directory": "",       # base directory for the experiment
        "experiment_name": "",      # experiment name
        "write_interval": 250,      # TensorBoard writing interval (timesteps)

        "checkpoint_interval": 1000,        # interval for checkpoints (timesteps)
        "store_separately": False,          # whether to store checkpoints separately

        "wandb": False,             # whether to use Weights & Biases
        "wandb_kwargs": {}          # wandb kwargs (see https://docs.wandb.ai/ref/python/init)
    }
}
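As an illustration of the `rewards_shaper` entry, the following sketch builds a function with the `Callable(reward, timestep, timesteps) -> reward` shape given in the comment above. The factory name `make_scaled_shaper` and the scale value are hypothetical, and the configuration is a stand-in dictionary so the example runs on its own. A plain multiplication is used, which should behave the same for floats and for array-like rewards.

```python
def make_scaled_shaper(scale):
    """Return a rewards shaper that multiplies every reward by `scale`."""
    def shaper(rewards, timestep, timesteps):
        # timestep and timesteps are accepted to match the expected signature;
        # this simple shaper does not use them
        return rewards * scale
    return shaper


# Hypothetical stand-in for a copy of SAC_DEFAULT_CONFIG
cfg_agent = {"rewards_shaper": None}
cfg_agent["rewards_shaper"] = make_scaled_shaper(0.5)

shaped = cfg_agent["rewards_shaper"](10.0, 0, 1000)
```

A shaper that depends on training progress could use `timestep / timesteps`, for example to anneal a bonus over time.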


### Spaces

The implementation supports the following Gym spaces / Gymnasium spaces:

| Gym/Gymnasium spaces | Observation | Action |
|---|---|---|
| Discrete | $$\square$$ | $$\square$$ |
| MultiDiscrete | $$\square$$ | $$\square$$ |
| Box | $$\blacksquare$$ | $$\blacksquare$$ |
| Dict | $$\blacksquare$$ | $$\square$$ |

### Models

The implementation uses 1 stochastic and 4 deterministic function approximators. These function approximators (models) must be collected in a dictionary and passed to the constructor of the class under the argument models.

| Notation | Concept | Key | Input shape | Output shape | Type |
|---|---|---|---|---|---|
| $$\pi_\theta(s)$$ | Policy (actor) | "policy" | observation | action | Stochastic |
| $$Q_{\phi 1}(s, a)$$ | Q1-network (critic 1) | "critic_1" | observation + action | 1 | Deterministic |
| $$Q_{\phi 2}(s, a)$$ | Q2-network (critic 2) | "critic_2" | observation + action | 1 | Deterministic |
| $$Q_{{\phi 1}_{target}}(s, a)$$ | Target Q1-network | "target_critic_1" | observation + action | 1 | Deterministic |
| $$Q_{{\phi 2}_{target}}(s, a)$$ | Target Q2-network | "target_critic_2" | observation + action | 1 | Deterministic |

### Features

Support for advanced features is described in the following table.

| Feature | Support and remarks | PyTorch | JAX |
|---|---|---|---|
| Shared model | - | $$\square$$ | $$\square$$ |
| RNN support | RNN, LSTM, GRU and any other variant | $$\blacksquare$$ | $$\square$$ |

## API (PyTorch)

skrl.agents.torch.sac.SAC_DEFAULT_CONFIG

alias of {'actor_learning_rate': 0.001, 'batch_size': 64, 'critic_learning_rate': 0.001, 'discount_factor': 0.99, 'entropy_learning_rate': 0.001, 'experiment': {'base_directory': '', 'checkpoint_interval': 1000, 'experiment_name': '', 'store_separately': False, 'wandb': False, 'wandb_kwargs': {}, 'write_interval': 250}, 'grad_norm_clip': 0, 'gradient_steps': 1, 'initial_entropy_value': 0.2, 'learn_entropy': True, 'learning_rate_scheduler': None, 'learning_rate_scheduler_kwargs': {}, 'learning_starts': 0, 'polyak': 0.005, 'random_timesteps': 0, 'rewards_shaper': None, 'state_preprocessor': None, 'state_preprocessor_kwargs': {}, 'target_entropy': None}

class skrl.agents.torch.sac.SAC(models: Dict[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)

Bases: Agent

__init__(models: Dict[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) -> None

Soft Actor-Critic (SAC)

https://arxiv.org/abs/1801.01290

Parameters:
• models (dictionary of skrl.models.torch.Model) – Models used by the agent

• memory (skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None) – Memory used to store the transitions. If it is a tuple, the first element will be used for training and, for the rest, only the environment transitions will be added

• observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

• action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

• device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

• cfg (dict) – Configuration dictionary

Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) -> None

Algorithm’s main update step

Parameters:
• timestep (int) – Current timestep

• timesteps (int) – Number of timesteps

act(states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tensor

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
• states (torch.Tensor) – Environment’s states

• timestep (int) – Current timestep

• timesteps (int) – Number of timesteps

Returns:

Actions

Return type:

torch.Tensor

init(trainer_cfg=None) -> None

Initialize the agent

post_interaction(timestep: int, timesteps: int) -> None

Callback called after the interaction with the environment

Parameters:
• timestep (int) – Current timestep

• timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) -> None

Callback called before the interaction with the environment

Parameters:
• timestep (int) – Current timestep

• timesteps (int) – Number of timesteps

record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) -> None

Record an environment transition in memory

Parameters:
• states (torch.Tensor) – Observations/states of the environment used to make the decision

• actions (torch.Tensor) – Actions taken by the agent

• rewards (torch.Tensor) – Instant rewards achieved by the current actions

• next_states (torch.Tensor) – Next observations/states of the environment

• terminated (torch.Tensor) – Signals to indicate that episodes have terminated

• truncated (torch.Tensor) – Signals to indicate that episodes have been truncated

• infos (Any type supported by the environment) – Additional information about the environment

• timestep (int) – Current timestep

• timesteps (int) – Number of timesteps

class skrl.agents.torch.sac.SAC_RNN(models: Dict[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None)

Bases: Agent

__init__(models: Dict[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: dict | None = None) -> None

Soft Actor-Critic (SAC) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.)

https://arxiv.org/abs/1801.01290

Parameters:
• models (dictionary of skrl.models.torch.Model) – Models used by the agent

• memory (skrl.memory.torch.Memory, list of skrl.memory.torch.Memory or None) – Memory used to store the transitions. If it is a tuple, the first element will be used for training and, for the rest, only the environment transitions will be added

• observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

• action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

• device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

• cfg (dict) – Configuration dictionary

Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) -> None

Algorithm’s main update step

Parameters:
• timestep (int) – Current timestep

• timesteps (int) – Number of timesteps

act(states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tensor

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
• states (torch.Tensor) – Environment’s states

• timestep (int) – Current timestep

• timesteps (int) – Number of timesteps

Returns:

Actions

Return type:

torch.Tensor

init(trainer_cfg=None) -> None

Initialize the agent

post_interaction(timestep: int, timesteps: int) -> None

Callback called after the interaction with the environment

Parameters:
• timestep (int) – Current timestep

• timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) -> None

Callback called before the interaction with the environment

Parameters:
• timestep (int) – Current timestep

• timesteps (int) – Number of timesteps

record_transition(states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) -> None

Record an environment transition in memory

Parameters:
• states (torch.Tensor) – Observations/states of the environment used to make the decision

• actions (torch.Tensor) – Actions taken by the agent

• rewards (torch.Tensor) – Instant rewards achieved by the current actions

• next_states (torch.Tensor) – Next observations/states of the environment

• terminated (torch.Tensor) – Signals to indicate that episodes have terminated

• truncated (torch.Tensor) – Signals to indicate that episodes have been truncated

• infos (Any type supported by the environment) – Additional information about the environment

• timestep (int) – Current timestep

• timesteps (int) – Number of timesteps

## API (JAX)

skrl.agents.jax.sac.SAC_DEFAULT_CONFIG

alias of {'actor_learning_rate': 0.001, 'batch_size': 64, 'critic_learning_rate': 0.001, 'discount_factor': 0.99, 'entropy_learning_rate': 0.001, 'experiment': {'checkpoint_interval': 1000, 'directory': '', 'experiment_name': '', 'store_separately': False, 'wandb': False, 'wandb_kwargs': {}, 'write_interval': 250}, 'grad_norm_clip': 0, 'gradient_steps': 1, 'initial_entropy_value': 0.2, 'learn_entropy': True, 'learning_rate_scheduler': None, 'learning_rate_scheduler_kwargs': {}, 'learning_starts': 0, 'polyak': 0.005, 'random_timesteps': 0, 'rewards_shaper': None, 'state_preprocessor': None, 'state_preprocessor_kwargs': {}, 'target_entropy': None}

class skrl.agents.jax.sac.SAC(models: Dict[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None)

Bases: Agent

__init__(models: Dict[str, Model], memory: Memory | Tuple[Memory] | None = None, observation_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, action_space: int | Tuple[int] | gym.Space | gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: dict | None = None) -> None

Soft Actor-Critic (SAC)

https://arxiv.org/abs/1801.01290

Parameters:
• models (dictionary of skrl.models.jax.Model) – Models used by the agent

• memory (skrl.memory.jax.Memory, list of skrl.memory.jax.Memory or None) – Memory used to store the transitions. If it is a tuple, the first element will be used for training and, for the rest, only the environment transitions will be added

• observation_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Observation/state space or shape (default: None)

• action_space (int, tuple or list of int, gym.Space, gymnasium.Space or None, optional) – Action space or shape (default: None)

• device (str or jax.Device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

• cfg (dict) – Configuration dictionary

Raises:

KeyError – If the models dictionary is missing a required key

_update(timestep: int, timesteps: int) -> None

Algorithm’s main update step

Parameters:
• timestep (int) – Current timestep

• timesteps (int) – Number of timesteps

act(states: np.ndarray | jax.Array, timestep: int, timesteps: int) -> np.ndarray | jax.Array

Process the environment’s states to make a decision (actions) using the main policy

Parameters:
• states (np.ndarray or jax.Array) – Environment’s states

• timestep (int) – Current timestep

• timesteps (int) – Number of timesteps

Returns:

Actions

Return type:

np.ndarray or jax.Array

init(trainer_cfg=None) -> None

Initialize the agent

post_interaction(timestep: int, timesteps: int) -> None

Callback called after the interaction with the environment

Parameters:
• timestep (int) – Current timestep

• timesteps (int) – Number of timesteps

pre_interaction(timestep: int, timesteps: int) -> None

Callback called before the interaction with the environment

Parameters:
• timestep (int) – Current timestep

• timesteps (int) – Number of timesteps

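record_transition(states: np.ndarray | jax.Array, actions: np.ndarray | jax.Array, rewards: np.ndarray | jax.Array, next_states: np.ndarray | jax.Array, terminated: np.ndarray | jax.Array, truncated: np.ndarray | jax.Array, infos: Any, timestep: int, timesteps: int) -> None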

Record an environment transition in memory

Parameters:
• states (np.ndarray or jax.Array) – Observations/states of the environment used to make the decision

• actions (np.ndarray or jax.Array) – Actions taken by the agent

• rewards (np.ndarray or jax.Array) – Instant rewards achieved by the current actions

• next_states (np.ndarray or jax.Array) – Next observations/states of the environment

• terminated (np.ndarray or jax.Array) – Signals to indicate that episodes have terminated

• truncated (np.ndarray or jax.Array) – Signals to indicate that episodes have been truncated

• infos (Any type supported by the environment) – Additional information about the environment

• timestep (int) – Current timestep

• timesteps (int) – Number of timesteps