Soft Actor-Critic (SAC)

SAC is a model-free, stochastic, off-policy actor-critic algorithm. It uses clipped double Q-learning (like TD3) and entropy regularization to trade off expected return against policy entropy, that is, exploitation against exploration.
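For reference, here are the maximum-entropy objective and the soft state value used by SAC, written with this page's notation (policy \(\pi_\theta\), entropy coefficient \(\alpha\)). \(\rho_{\pi_\theta}\) denotes the state-action distribution induced by the policy. The term \(\text{min}(Q_{1_{target}}, Q_{2_{target}}) - \alpha \; logp'\) in the algorithm below is a single-sample estimate of \(V(s')\).

```latex
J(\pi_\theta) = \sum_{t} \mathbb{E}_{(s_t, a_t) \sim \rho_{\pi_\theta}} \left[ r(s_t, a_t) + \alpha \, \mathcal{H}\big(\pi_\theta(\cdot \mid s_t)\big) \right]

V(s) = \mathbb{E}_{a \sim \pi_\theta(\cdot \mid s)} \left[ Q(s, a) - \alpha \log \pi_\theta(a \mid s) \right]
```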

Paper: Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.



Algorithm


Algorithm implementation

Main notation/symbols:
- policy function approximator (\(\pi_\theta\)), critic function approximator (\(Q_\phi\))
- states (\(s\)), actions (\(a\)), rewards (\(r\)), next states (\(s'\)), terminated (\(d_{_{end}}\)), truncated (\(d_{_{timeout}}\))
- log probabilities (\(logp\)), entropy coefficient (\(\alpha\))
- loss (\(L\))

Learning algorithm


_update(...)
    # gradient steps
    FOR each gradient step up to gradient_steps DO
        # sample a batch from memory
        [\(s, a, r, s', d_{_{end}}, d_{_{timeout}}\)] with size batch_size
        # compute target values
        \(a',\; logp' \leftarrow \pi_\theta(s')\)
        \(Q_{1_{target}} \leftarrow Q_{{\phi 1}_{target}}(s', a')\)
        \(Q_{2_{target}} \leftarrow Q_{{\phi 2}_{target}}(s', a')\)
        \(Q_{_{target}} \leftarrow \text{min}(Q_{1_{target}}, Q_{2_{target}}) - \alpha \; logp'\)
        \(y \leftarrow r \;+\) discount_factor \(\neg (d_{_{end}} \lor d_{_{timeout}}) \; Q_{_{target}}\)
        # compute critic loss
        \(Q_1 \leftarrow Q_{\phi 1}(s, a)\)
        \(Q_2 \leftarrow Q_{\phi 2}(s, a)\)
        \(L_{Q_\phi} \leftarrow 0.5 \; (\frac{1}{N} \sum_{i=1}^N (Q_1 - y)^2 + \frac{1}{N} \sum_{i=1}^N (Q_2 - y)^2)\)
        # optimization step (critic)
        reset \(\text{optimizer}_\phi\)
        \(\nabla_{\phi} L_{Q_\phi}\)
        \(\text{clip}(\lVert \nabla_{\phi} \rVert)\) with grad_norm_clip
        step \(\text{optimizer}_\phi\)
        # compute policy (actor) loss
        \(a,\; logp \leftarrow \pi_\theta(s)\)
        \(Q_1 \leftarrow Q_{\phi 1}(s, a)\)
        \(Q_2 \leftarrow Q_{\phi 2}(s, a)\)
        \(L_{\pi_\theta} \leftarrow \frac{1}{N} \sum_{i=1}^N (\alpha \; logp - \text{min}(Q_1, Q_2))\)
        # optimization step (policy)
        reset \(\text{optimizer}_\theta\)
        \(\nabla_{\theta} L_{\pi_\theta}\)
        \(\text{clip}(\lVert \nabla_{\theta} \rVert)\) with grad_norm_clip
        step \(\text{optimizer}_\theta\)
        # entropy learning
        IF learn_entropy is enabled THEN
            # compute entropy loss (\(\mathcal{H}_{target}\) is target_entropy)
            \({L}_{entropy} \leftarrow - \frac{1}{N} \sum_{i=1}^N (log(\alpha) \; (logp + \mathcal{H}_{target}))\)
            # optimization step (entropy)
            reset \(\text{optimizer}_\alpha\)
            \(\nabla_{log(\alpha)} {L}_{entropy}\)
            step \(\text{optimizer}_\alpha\)
            # compute entropy coefficient
            \(\alpha \leftarrow e^{log(\alpha)}\)
        # update target networks
        \({\phi 1}_{target} \leftarrow\) polyak \({\phi 1} + (1 \;-\) polyak \() {\phi 1}_{target}\)
        \({\phi 2}_{target} \leftarrow\) polyak \({\phi 2} + (1 \;-\) polyak \() {\phi 2}_{target}\)
        # update learning rate
        IF there is a learning_rate_scheduler THEN
            step \(\text{scheduler}_\theta (\text{optimizer}_\theta)\)
            step \(\text{scheduler}_\phi (\text{optimizer}_\phi)\)
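The target and loss formulas above can be made concrete with a framework-free sketch that evaluates them on small lists of numbers. This is not skrl's implementation, and the function names are hypothetical. In the actual agent these are tensor operations, the targets are computed without gradient tracking, and the losses are backpropagated into \(Q_{\phi 1}\), \(Q_{\phi 2}\) and \(\pi_\theta\). Like the pseudocode, the mask stops bootstrapping on both terminated and truncated transitions.

```python
from typing import Sequence


def soft_td_targets(
    rewards: Sequence[float],
    next_q1: Sequence[float],
    next_q2: Sequence[float],
    next_logp: Sequence[float],
    terminated: Sequence[bool],
    truncated: Sequence[bool],
    alpha: float,
    discount_factor: float = 0.99,
) -> list[float]:
    """y = r + discount_factor * not(terminated or truncated) * (min(Q1', Q2') - alpha * logp')."""
    targets = []
    for r, q1, q2, lp, term, trunc in zip(rewards, next_q1, next_q2, next_logp, terminated, truncated):
        # clipped double-Q estimate of the soft value of s' (a' was sampled from the policy)
        q_target = min(q1, q2) - alpha * lp
        # bootstrap only if the episode neither terminated nor was truncated
        not_done = 0.0 if (term or trunc) else 1.0
        targets.append(r + discount_factor * not_done * q_target)
    return targets


def critic_loss(q1: Sequence[float], q2: Sequence[float], targets: Sequence[float]) -> float:
    """0.5 * (MSE(Q1, y) + MSE(Q2, y)) over the batch."""
    n = len(targets)
    mse1 = sum((q - y) ** 2 for q, y in zip(q1, targets)) / n
    mse2 = sum((q - y) ** 2 for q, y in zip(q2, targets)) / n
    return 0.5 * (mse1 + mse2)


def policy_loss(q1: Sequence[float], q2: Sequence[float], logp: Sequence[float], alpha: float) -> float:
    """mean(alpha * logp - min(Q1, Q2)) for actions freshly sampled from the policy at s."""
    return sum(alpha * lp - min(a, b) for a, b, lp in zip(q1, q2, logp)) / len(logp)
```

For a terminated transition, the target reduces to the reward alone. The policy loss decreases both when the critics rate the sampled actions higher and when their log probabilities are lower (higher entropy).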

Usage

Note

Support for recurrent neural networks (RNN, LSTM, GRU and any other variant) is implemented in a separate file (sac_rnn.py) to maintain the readability of the standard implementation (sac.py).

# import the agent and its default configuration
from skrl.agents.torch.sac import SAC, SAC_CFG

# instantiate the agent's models
models = {}
models["policy"] = ...
models["critic_1"] = ...  # only required during training
models["critic_2"] = ...  # only required during training
models["target_critic_1"] = ...  # only required during training
models["target_critic_2"] = ...  # only required during training

# adjust some configuration if necessary
cfg_agent = SAC_CFG()
cfg_agent.KEY = ...  # replace KEY with any SAC_CFG attribute, e.g. batch_size or learning_rate

# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
agent = SAC(
    models=models,
    memory=memory,  # only required during training
    cfg=cfg_agent,
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
)

Configuration and hyperparameters

Dataclass

The SAC_CFG dataclass is available for the PyTorch, JAX and Warp implementations. Its fields are documented in the API section below.


Spaces

The implementation supports the following Gymnasium spaces:

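Gymnasium spaces | Observation      | Action
-----------------|------------------|-----------------
Discrete         | \(\square\)      | \(\square\)
MultiDiscrete    | \(\square\)      | \(\square\)
Box              | \(\blacksquare\) | \(\blacksquare\)
Dict             | \(\blacksquare\) | \(\square\)

(\(\blacksquare\): supported, \(\square\): not supported)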


Models

The implementation uses 1 stochastic and 4 deterministic function approximators. These function approximators (models) must be collected in a dictionary and passed to the constructor of the class under the argument models.

Notation                        | Concept               | Key               | Input shape          | Output shape | Type
--------------------------------|-----------------------|-------------------|----------------------|--------------|--------------------------------
\(\pi_\theta(s)\)               | Policy (actor)        | "policy"          | observation          | action       | Gaussian / MultivariateGaussian
\(Q_{\phi 1}(s, a)\)            | Q1-network (critic 1) | "critic_1"        | observation + action | 1            | Deterministic
\(Q_{\phi 2}(s, a)\)            | Q2-network (critic 2) | "critic_2"        | observation + action | 1            | Deterministic
\(Q_{{\phi 1}_{target}}(s, a)\) | Target Q1-network     | "target_critic_1" | observation + action | 1            | Deterministic
\(Q_{{\phi 2}_{target}}(s, a)\) | Target Q2-network     | "target_critic_2" | observation + action | 1            | Deterministic
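As an illustration of the dictionary layout described above, the hypothetical helper below reports which of the five keys from the table are missing. It is not part of skrl's API. Following the comments in the usage example, it assumes that only "policy" is needed outside of training.

```python
# Keys taken from the models table; which ones are needed at evaluation time
# is an assumption based on the "only required during training" comments.
REQUIRED_KEYS = ("policy",)
TRAINING_KEYS = ("critic_1", "critic_2", "target_critic_1", "target_critic_2")


def missing_sac_models(models: dict, training: bool = True) -> list[str]:
    """Return the model keys that are absent from `models` or set to None."""
    keys = REQUIRED_KEYS + (TRAINING_KEYS if training else ())
    return [key for key in keys if models.get(key) is None]
```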


Features

Support for advanced features is described in the following table:

Feature         | Remarks                                    | pytorch          | jax              | warp
----------------|--------------------------------------------|------------------|------------------|-------------
Shared model    | -                                          | \(\square\)      | \(\square\)      | \(\square\)
RNN support     | RNN, LSTM, GRU and any other variant       | \(\blacksquare\) | \(\square\)      | \(\square\)
Mixed precision | Automatic mixed precision                  | \(\blacksquare\) | \(\square\)      | \(\square\)
Distributed     | Single Program Multi Data (SPMD) multi-GPU | \(\blacksquare\) | \(\blacksquare\) | \(\square\)


API


PyTorch

SAC_CFG

Configuration for the SAC agent.

SAC

Soft Actor-Critic (SAC).

SAC_RNN

Soft Actor-Critic (SAC) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.).

class skrl.agents.torch.sac.SAC_CFG(*, experiment: ExperimentCfg = <factory>, gradient_steps: int = 1, batch_size: int = 64, discount_factor: float = 0.99, polyak: float = 0.005, learning_rate: float | tuple[float, float, float] = 0.001, learning_rate_scheduler: type | tuple[type | None, type | None, type | None] | None = None, learning_rate_scheduler_kwargs: dict | tuple[dict, dict, dict] = <factory>, observation_preprocessor: type | None = None, observation_preprocessor_kwargs: dict = <factory>, state_preprocessor: type | None = None, state_preprocessor_kwargs: dict = <factory>, random_timesteps: int = 0, learning_starts: int = 0, grad_norm_clip: float = 0, learn_entropy: bool = True, initial_entropy_value: float = 0.2, target_entropy: float | None = None, rewards_shaper: Callable | None = None, mixed_precision: bool = False)

Bases: AgentCfg

Configuration for the SAC agent.

Methods:

expand()

Expand the configuration.

validate()

Validate the configuration.

Attributes:

batch_size

Batch size for sampling transitions from memory during training.

discount_factor

Parameter that balances the importance of future rewards (close to 1.0) versus immediate rewards (close to 0.0).

experiment

Experiment settings.

grad_norm_clip

Clipping coefficient for the gradients by their global norm.

gradient_steps

Number of gradient steps to perform for each update.

initial_entropy_value

Initial value for the entropy coefficient.

learn_entropy

Whether to learn the entropy coefficient.

learning_rate

Learning rate for the actor and critic networks, and entropy coefficient.

learning_rate_scheduler

Learning rate scheduler class for the actor and critic networks, and entropy coefficient.

learning_rate_scheduler_kwargs

Keyword arguments for the learning rate scheduler's constructor.

learning_starts

Number of steps to perform before calling the algorithm update function.

mixed_precision

Whether to enable automatic mixed precision for higher performance.

observation_preprocessor

Preprocessor class to process the environment's observations.

observation_preprocessor_kwargs

Keyword arguments for the observation preprocessor's constructor.

polyak

Parameter to control the update of the target networks by polyak averaging.

random_timesteps

Number of random exploration (sampling random actions) steps to perform before sampling actions from the policy.

rewards_shaper

Rewards shaping function.

state_preprocessor

Preprocessor class to process the environment's states.

state_preprocessor_kwargs

Keyword arguments for the state preprocessor's constructor.

target_entropy

Target value for computing the entropy loss.

expand() → None

Expand the configuration.

validate() → bool

Validate the configuration.

batch_size: int = 64

Batch size for sampling transitions from memory during training.

discount_factor: float = 0.99

Parameter that balances the importance of future rewards (close to 1.0) versus immediate rewards (close to 0.0).

Range: [0.0, 1.0].
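Here is a small worked example of what the range means. With discount_factor = 0.99, a reward k steps ahead is weighted by 0.99**k, and 1 / (1 - 0.99) = 100 steps is a common rule-of-thumb horizon. That horizon is a heuristic, not something the configuration enforces, and the helper names are illustrative.

```python
def discounted_return(rewards: list[float], discount_factor: float = 0.99) -> float:
    """Sum of discount_factor**k * r_k over a sequence of rewards."""
    total = 0.0
    for k, r in enumerate(rewards):
        total += (discount_factor ** k) * r
    return total


def effective_horizon(discount_factor: float) -> float:
    """Rule of thumb: rewards much beyond 1 / (1 - discount_factor) steps contribute little."""
    return float("inf") if discount_factor >= 1.0 else 1.0 / (1.0 - discount_factor)
```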

experiment: ExperimentCfg

Experiment settings.

grad_norm_clip: float = 0

Clipping coefficient for the gradients by their global norm.

If less than or equal to 0, the gradients will not be clipped.
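Clipping "by their global norm" means that all gradients are rescaled by the same factor, so that the L2 norm of their concatenation is at most grad_norm_clip. Below is a plain-Python sketch of that rule, including the documented "less than or equal to 0 disables clipping" case. In PyTorch, torch.nn.utils.clip_grad_norm_ performs this kind of clipping on real tensors (with a small epsilon added to the norm). The function here is illustrative only.

```python
import math


def clip_grad_norm(grads: list[list[float]], max_norm: float) -> tuple[list[list[float]], float]:
    """Scale all gradients jointly so that their global L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the global norm before clipping.
    A max_norm <= 0 disables clipping, matching the grad_norm_clip description.
    """
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if max_norm <= 0 or total_norm <= max_norm:
        return grads, total_norm
    scale = max_norm / total_norm  # same factor for every parameter tensor
    return [[g * scale for g in tensor] for tensor in grads], total_norm
```

Because every tensor shares the same factor, the gradient direction is preserved and only its magnitude shrinks.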

gradient_steps: int = 1

Number of gradient steps to perform for each update.

initial_entropy_value: float = 0.2

Initial value for the entropy coefficient.

learn_entropy: bool = True

Whether to learn the entropy coefficient.

learning_rate: float | tuple[float, float, float] = 0.001

Learning rate for the actor and critic networks, and entropy coefficient.

  • If a float is provided, the same learning rate will be used for the networks/coefficient.

  • If a tuple is provided, its elements will be used for each network/coefficient in order.
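The float-or-tuple rule can be sketched as follows. The helper name is hypothetical. The tuple order (actor, critic, entropy coefficient) is an assumption based on the order in which the attribute description lists them.

```python
from __future__ import annotations


def expand_learning_rate(learning_rate: float | tuple[float, float, float]) -> tuple[float, float, float]:
    """Return (actor_lr, critic_lr, entropy_lr), assuming that order for tuples."""
    if isinstance(learning_rate, (int, float)):
        # a single value is shared by the networks and the entropy coefficient
        return (float(learning_rate),) * 3
    if len(learning_rate) != 3:
        raise ValueError("expected 3 learning rates: actor, critic, entropy coefficient")
    actor_lr, critic_lr, entropy_lr = learning_rate
    return float(actor_lr), float(critic_lr), float(entropy_lr)
```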

learning_rate_scheduler: type | tuple[type | None, type | None, type | None] | None = None

Learning rate scheduler class for the actor and critic networks, and entropy coefficient.

See Learning rate schedulers for more details.

  • If a class is provided, the same learning rate scheduler will be used for the networks/coefficient.

  • If a tuple is provided, its elements will be used for each network/coefficient in order.

learning_rate_scheduler_kwargs: dict | tuple[dict, dict, dict]

Keyword arguments for the learning rate scheduler’s constructor.

See Learning rate schedulers for more details.

Warning

The optimizer argument is automatically passed to the learning rate scheduler’s constructor. Therefore, it must not be provided in the keyword arguments.

  • If a dictionary is provided, the same keyword arguments will be used for the networks/coefficient.

  • If a tuple is provided, its elements will be used for each network/coefficient in order.

learning_starts: int = 0

Number of steps to perform before calling the algorithm update function.

mixed_precision: bool = False

Whether to enable automatic mixed precision for higher performance.

observation_preprocessor: type | None = None

Preprocessor class to process the environment’s observations.

See Preprocessors for more details.

observation_preprocessor_kwargs: dict

Keyword arguments for the observation preprocessor’s constructor.

See Preprocessors for more details.

polyak: float = 0.005

Parameter to control the update of the target networks by polyak averaging.

Range: [0.0, 1.0]. See update_parameters() for more details.
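The polyak update from the algorithm, \({\phi}_{target} \leftarrow\) polyak \({\phi} + (1 \;-\) polyak \() {\phi}_{target}\), is sketched below on plain lists of parameters. With polyak = 1.0 it becomes a hard copy. With the default 0.005, a fixed gap between the online and target parameters shrinks by a factor of 0.995 per update, so it takes about 140 target updates to halve. The function names are illustrative.

```python
def polyak_update(target: list[float], source: list[float], polyak: float = 0.005) -> list[float]:
    """Element-wise target <- polyak * source + (1 - polyak) * target."""
    return [polyak * s + (1.0 - polyak) * t for s, t in zip(source, target)]


def remaining_gap(polyak: float, steps: int) -> float:
    """Fraction of an initial source/target gap left after `steps` updates (source held fixed)."""
    return (1.0 - polyak) ** steps
```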

random_timesteps: int = 0

Number of random exploration (sampling random actions) steps to perform before sampling actions from the policy.
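random_timesteps and learning_starts gate two different things: where actions come from, and whether update() runs. The sketch below only restates the documented semantics. The exact comparison operators that skrl uses at the boundary timesteps are an assumption, and the function is hypothetical.

```python
def interaction_phase(timestep: int, random_timesteps: int = 0, learning_starts: int = 0) -> tuple[str, bool]:
    """Return (action source, whether the algorithm update runs) for a given timestep.

    Assumption: boundaries are exclusive for random actions and inclusive for updates.
    """
    action_source = "random" if timestep < random_timesteps else "policy"
    runs_update = timestep >= learning_starts
    return action_source, runs_update
```

Setting both values to, say, 1000 would fill the memory with uniformly random transitions before the first gradient step. This is a common way to warm-start off-policy learning.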

rewards_shaper: Callable | None = None

Rewards shaping function.
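A minimal rewards shaper is sketched below. It assumes the shaper is called as rewards_shaper(rewards, timestep, timesteps), as in earlier skrl releases. Check the installed version's signature before relying on it. Constant reward scaling is a natural example for SAC, because the balance between reward and entropy depends on the reward scale relative to \(\alpha\).

```python
def scale_rewards(rewards, timestep: int, timesteps: int):
    """Hypothetical shaper: multiply rewards by a constant factor.

    Works for plain floats as well as array/tensor types that support `*`.
    The timestep arguments are accepted (assumed call signature) but unused.
    """
    return rewards * 0.1


# hypothetical usage with the configuration object from the usage example:
# cfg_agent.rewards_shaper = scale_rewards
```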

state_preprocessor: type | None = None

Preprocessor class to process the environment’s states.

See Preprocessors for more details.

state_preprocessor_kwargs: dict

Keyword arguments for the state preprocessor’s constructor.

See Preprocessors for more details.

target_entropy: float | None = None

Target value for computing the entropy loss.
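When target_entropy is None, a common choice for continuous (Box) action spaces is \(-\dim(\mathcal{A})\), the negative number of action dimensions. That skrl uses this default is our assumption. The sketch below computes that heuristic and one gradient-descent step on \(log(\alpha)\) for the entropy loss in the algorithm. \((logp + \text{target})\) is treated as a constant, and \(\alpha\) would start at initial_entropy_value, that is, \(log(\alpha) = \log 0.2\) by default. The function names are illustrative.

```python
import math


def default_target_entropy(action_shape: tuple[int, ...]) -> float:
    """Heuristic -dim(A) for a Box action space of the given shape (assumed default)."""
    return -float(math.prod(action_shape))


def entropy_coefficient_step(log_alpha: float, logp: list[float], target_entropy: float, lr: float) -> float:
    """One gradient-descent step on L = -mean(log_alpha * (logp + target_entropy)).

    dL/dlog_alpha = -mean(logp + target_entropy): if the policy entropy (-mean(logp))
    is below the target, log_alpha grows, which increases the entropy bonus.
    """
    grad = -sum(lp + target_entropy for lp in logp) / len(logp)
    return log_alpha - lr * grad
```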

class skrl.agents.torch.sac.SAC(*, models: dict[str, Model], memory: Memory | None = None, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: SAC_CFG | dict = {})

Bases: Agent

Soft Actor-Critic (SAC).

https://arxiv.org/abs/1801.01290

Parameters:
  • models – Agent’s models.

  • memory – Memory to store the agent’s data and environment transitions.

  • observation_space – Observation space.

  • state_space – State space.

  • action_space – Action space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

  • cfg – Agent’s configuration.

Raises:

KeyError – If a configuration key is missing.

Methods:

act(observations, states, *, timestep, timesteps)

Process the environment's observations/states to make a decision (actions) using the main policy.

enable_models_training_mode([enabled])

Set the training mode of all the agent's models: enabled (training) or disabled (evaluation).

enable_training_mode([enabled, apply_to_models])

Set the training mode of the agent: enabled (training) or disabled (evaluation).

init(*[, trainer_cfg])

Initialize the agent.

load(path)

Load the agent from the specified path.

post_interaction(*, timestep, timesteps)

Method called after the interaction with the environment.

pre_interaction(*, timestep, timesteps)

Method called before the interaction with the environment.

record_transition(*, observations, states, ...)

Record an environment transition in memory.

save(path)

Save the agent to the specified path.

track_data(tag, value)

Track data to TensorBoard.

update(*, timestep, timesteps)

Algorithm's main update step.

write_checkpoint(*, timestep, timesteps)

Write checkpoint (modules) to persistent storage.

write_tracking_data(*, timestep, timesteps)

Write tracking data to TensorBoard.

act(observations: torch.Tensor, states: torch.Tensor | None, *, timestep: int, timesteps: int) → tuple[torch.Tensor, dict[str, Any]]

Process the environment’s observations/states to make a decision (actions) using the main policy.

Parameters:
  • observations – Environment observations.

  • states – Environment states.

  • timestep – Current timestep.

  • timesteps – Number of timesteps.

Returns:

Agent output. The first component is the expected action/value returned by the agent. The second component is a dictionary containing extra output values according to the model.

enable_models_training_mode(enabled: bool = True) → None

Set the training mode of all the agent’s models: enabled (training) or disabled (evaluation).

Parameters:

enabled – True to enable the training mode, False to enable the evaluation mode.

enable_training_mode(enabled: bool = True, *, apply_to_models: bool = False) → None

Set the training mode of the agent: enabled (training) or disabled (evaluation).

The training mode can be queried by the training property.

Parameters:
  • enabled – True to enable the training mode, False to enable the evaluation mode.

  • apply_to_models – Whether to apply the training mode to all the agent’s models.

init(*, trainer_cfg: dict[str, Any] | None = None) → None

Initialize the agent.

Parameters:

trainer_cfg – Trainer configuration.

load(path: str) → None

Load the agent from the specified path.

Note

The final storage device is determined by the constructor of the agent.

Parameters:

path – Path to load the agent from.

post_interaction(*, timestep: int, timesteps: int) → None

Method called after the interaction with the environment.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

pre_interaction(*, timestep: int, timesteps: int) → None

Method called before the interaction with the environment.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

record_transition(*, observations: torch.Tensor, states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_observations: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) → None

Record an environment transition in memory.

Parameters:
  • observations – Environment observations.

  • states – Environment states.

  • actions – Actions taken by the agent.

  • rewards – Instant rewards achieved by the current actions.

  • next_observations – Next environment observations.

  • next_states – Next environment states.

  • terminated – Signals that indicate episodes have terminated.

  • truncated – Signals that indicate episodes have been truncated.

  • infos – Additional information about the environment.

  • timestep – Current timestep.

  • timesteps – Number of timesteps.

save(path: str) → None

Save the agent to the specified path.

Parameters:

path – Path to save the agent to.

track_data(tag: str, value: float) → None

Track data to TensorBoard.

Note

Currently only scalar data is supported.

Parameters:
  • tag – Data identifier (e.g. ‘Loss/Policy loss’).

  • value – Value to track.

update(*, timestep: int, timesteps: int) → None

Algorithm’s main update step.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

write_checkpoint(*, timestep: int, timesteps: int) → None

Write checkpoint (modules) to persistent storage.

Note

The checkpoints are stored in the subdirectory checkpoints within the experiment directory. The checkpoint name is the timestep argument value (if it is not None), or the current system date-time otherwise.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

write_tracking_data(*, timestep: int, timesteps: int) → None

Write tracking data to TensorBoard.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

class skrl.agents.torch.sac.SAC_RNN(*, models: dict[str, Model], memory: Memory | None = None, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, cfg: SAC_CFG | dict = {})

Bases: Agent

Soft Actor-Critic (SAC) with support for Recurrent Neural Networks (RNN, GRU, LSTM, etc.).

https://arxiv.org/abs/1801.01290

Parameters:
  • models – Agent’s models.

  • memory – Memory to store the agent’s data and environment transitions.

  • observation_space – Observation space.

  • state_space – State space.

  • action_space – Action space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

  • cfg – Agent’s configuration.

Raises:

KeyError – If a configuration key is missing.

Methods:

act(observations, states, *, timestep, timesteps)

Process the environment's observations/states to make a decision (actions) using the main policy.

enable_models_training_mode([enabled])

Set the training mode of all the agent's models: enabled (training) or disabled (evaluation).

enable_training_mode([enabled, apply_to_models])

Set the training mode of the agent: enabled (training) or disabled (evaluation).

init(*[, trainer_cfg])

Initialize the agent.

load(path)

Load the agent from the specified path.

post_interaction(*, timestep, timesteps)

Method called after the interaction with the environment.

pre_interaction(*, timestep, timesteps)

Method called before the interaction with the environment.

record_transition(*, observations, states, ...)

Record an environment transition in memory.

save(path)

Save the agent to the specified path.

track_data(tag, value)

Track data to TensorBoard.

update(*, timestep, timesteps)

Algorithm's main update step.

write_checkpoint(*, timestep, timesteps)

Write checkpoint (modules) to persistent storage.

write_tracking_data(*, timestep, timesteps)

Write tracking data to TensorBoard.

act(observations: torch.Tensor, states: torch.Tensor | None, *, timestep: int, timesteps: int) → tuple[torch.Tensor, dict[str, Any]]

Process the environment’s observations/states to make a decision (actions) using the main policy.

Parameters:
  • observations – Environment observations.

  • states – Environment states.

  • timestep – Current timestep.

  • timesteps – Number of timesteps.

Returns:

Agent output. The first component is the expected action/value returned by the agent. The second component is a dictionary containing extra output values according to the model.

enable_models_training_mode(enabled: bool = True) → None

Set the training mode of all the agent’s models: enabled (training) or disabled (evaluation).

Parameters:

enabled – True to enable the training mode, False to enable the evaluation mode.

enable_training_mode(enabled: bool = True, *, apply_to_models: bool = False) → None

Set the training mode of the agent: enabled (training) or disabled (evaluation).

The training mode can be queried by the training property.

Parameters:
  • enabled – True to enable the training mode, False to enable the evaluation mode.

  • apply_to_models – Whether to apply the training mode to all the agent’s models.

init(*, trainer_cfg: dict[str, Any] | None = None) → None

Initialize the agent.

Parameters:

trainer_cfg – Trainer configuration.

load(path: str) → None

Load the agent from the specified path.

Note

The final storage device is determined by the constructor of the agent.

Parameters:

path – Path to load the agent from.

post_interaction(*, timestep: int, timesteps: int) → None

Method called after the interaction with the environment.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

pre_interaction(*, timestep: int, timesteps: int) → None

Method called before the interaction with the environment.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

record_transition(*, observations: torch.Tensor, states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_observations: torch.Tensor, next_states: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: Any, timestep: int, timesteps: int) → None

Record an environment transition in memory.

Parameters:
  • observations – Environment observations.

  • states – Environment states.

  • actions – Actions taken by the agent.

  • rewards – Instant rewards achieved by the current actions.

  • next_observations – Next environment observations.

  • next_states – Next environment states.

  • terminated – Signals that indicate episodes have terminated.

  • truncated – Signals that indicate episodes have been truncated.

  • infos – Additional information about the environment.

  • timestep – Current timestep.

  • timesteps – Number of timesteps.

save(path: str) → None

Save the agent to the specified path.

Parameters:

path – Path to save the agent to.

track_data(tag: str, value: float) → None

Track data to TensorBoard.

Note

Currently only scalar data is supported.

Parameters:
  • tag – Data identifier (e.g. ‘Loss/Policy loss’).

  • value – Value to track.

update(*, timestep: int, timesteps: int) → None

Algorithm’s main update step.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

write_checkpoint(*, timestep: int, timesteps: int) → None

Write checkpoint (modules) to persistent storage.

Note

The checkpoints are stored in the subdirectory checkpoints within the experiment directory. The checkpoint name is the timestep argument value (if it is not None), or the current system date-time otherwise.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

write_tracking_data(*, timestep: int, timesteps: int) → None

Write tracking data to TensorBoard.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.


JAX

SAC_CFG

Configuration for the SAC agent.

SAC

Soft Actor-Critic (SAC).

class skrl.agents.jax.sac.SAC_CFG(*, experiment: ExperimentCfg = <factory>, gradient_steps: int = 1, batch_size: int = 64, discount_factor: float = 0.99, polyak: float = 0.005, learning_rate: float | tuple[float, float, float] = 0.001, learning_rate_scheduler: type | tuple[type | None, type | None, type | None] | None = None, learning_rate_scheduler_kwargs: dict | tuple[dict, dict, dict] = <factory>, observation_preprocessor: type | None = None, observation_preprocessor_kwargs: dict = <factory>, state_preprocessor: type | None = None, state_preprocessor_kwargs: dict = <factory>, random_timesteps: int = 0, learning_starts: int = 0, grad_norm_clip: float = 0, learn_entropy: bool = True, initial_entropy_value: float = 0.2, target_entropy: float | None = None, rewards_shaper: Callable | None = None)

Bases: AgentCfg

Configuration for the SAC agent.

Methods:

expand()

Expand the configuration.

validate()

Validate the configuration.

Attributes:

batch_size

Batch size for sampling transitions from memory during training.

discount_factor

Parameter that balances the importance of future rewards (close to 1.0) versus immediate rewards (close to 0.0).

experiment

Experiment settings.

grad_norm_clip

Clipping coefficient for the gradients by their global norm.

gradient_steps

Number of gradient steps to perform for each update.

initial_entropy_value

Initial value for the entropy coefficient.

learn_entropy

Whether to learn the entropy coefficient.

learning_rate

Learning rate for the actor and critic networks, and entropy coefficient.

learning_rate_scheduler

Learning rate scheduler class for the actor and critic networks, and entropy coefficient.

learning_rate_scheduler_kwargs

Keyword arguments for the learning rate scheduler's constructor.

learning_starts

Number of steps to perform before calling the algorithm update function.

observation_preprocessor

Preprocessor class to process the environment's observations.

observation_preprocessor_kwargs

Keyword arguments for the observation preprocessor's constructor.

polyak

Parameter to control the update of the target networks by polyak averaging.

random_timesteps

Number of random exploration (sampling random actions) steps to perform before sampling actions from the policy.

rewards_shaper

Rewards shaping function.

state_preprocessor

Preprocessor class to process the environment's states.

state_preprocessor_kwargs

Keyword arguments for the state preprocessor's constructor.

target_entropy

Target value for computing the entropy loss.

expand() → None

Expand the configuration.

validate() → bool

Validate the configuration.

batch_size: int = 64

Batch size for sampling transitions from memory during training.

discount_factor: float = 0.99

Parameter that balances the importance of future rewards (close to 1.0) versus immediate rewards (close to 0.0).

Range: [0.0, 1.0].

experiment: ExperimentCfg

Experiment settings.

grad_norm_clip: float = 0

Clipping coefficient for the gradients by their global norm.

If less than or equal to 0, the gradients will not be clipped.

gradient_steps: int = 1

Number of gradient steps to perform for each update.

initial_entropy_value: float = 0.2

Initial value for the entropy coefficient.

learn_entropy: bool = True

Whether to learn the entropy coefficient.

learning_rate: float | tuple[float, float, float] = 0.001

Learning rate for the actor and critic networks, and entropy coefficient.

  • If a float is provided, the same learning rate will be used for the networks/coefficient.

  • If a tuple is provided, its elements will be used for each network/coefficient in order.

learning_rate_scheduler: type | tuple[type | None, type | None, type | None] | None = None

Learning rate scheduler class for the actor and critic networks, and entropy coefficient.

See Learning rate schedulers for more details.

  • If a class is provided, the same learning rate scheduler will be used for the networks/coefficient.

  • If a tuple is provided, its elements will be used for each network/coefficient in order.

learning_rate_scheduler_kwargs: dict | tuple[dict, dict, dict]

Keyword arguments for the learning rate scheduler’s constructor.

See Learning rate schedulers for more details.

Warning

The optimizer argument is automatically passed to the learning rate scheduler’s constructor. Therefore, it must not be provided in the keyword arguments.

  • If a dictionary is provided, the same keyword arguments will be used for the networks/coefficient.

  • If a tuple is provided, its elements will be used for each network/coefficient in order.

learning_starts: int = 0

Number of steps to perform before calling the algorithm update function.

observation_preprocessor: type | None = None

Preprocessor class to process the environment’s observations.

See Preprocessors for more details.

observation_preprocessor_kwargs: dict

Keyword arguments for the observation preprocessor’s constructor.

See Preprocessors for more details.

polyak: float = 0.005

Parameter to control the update of the target networks by polyak averaging.

Range: [0.0, 1.0]. See update_parameters() for more details.

random_timesteps: int = 0

Number of random exploration (sampling random actions) steps to perform before sampling actions from the policy.

rewards_shaper: Callable | None = None

Rewards shaping function.

state_preprocessor: type | None = None

Preprocessor class to process the environment’s states.

See Preprocessors for more details.

state_preprocessor_kwargs: dict

Keyword arguments for the state preprocessor’s constructor.

See Preprocessors for more details.

target_entropy: float | None = None

Target value for computing the entropy loss.

class skrl.agents.jax.sac.SAC(*, models: dict[str, Model], memory: Memory | None = None, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | jax.Device | None = None, cfg: SAC_CFG | dict = {})

Bases: Agent

Soft Actor-Critic (SAC).

https://arxiv.org/abs/1801.01290

Parameters:
  • models – Agent’s models.

  • memory – Memory to store the agent’s data and environment transitions.

  • observation_space – Observation space.

  • state_space – State space.

  • action_space – Action space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

  • cfg – Agent’s configuration.

Raises:

KeyError – If a configuration key is missing.

Methods:

act(observations, states, *, timestep, timesteps)

Process the environment's observations/states to make a decision (actions) using the main policy.

enable_models_training_mode([enabled])

Set the training mode of all the agent's models: enabled (training) or disabled (evaluation).

enable_training_mode([enabled, apply_to_models])

Set the training mode of the agent: enabled (training) or disabled (evaluation).

init(*[, trainer_cfg])

Initialize the agent.

load(path)

Load the agent from the specified path.

post_interaction(*, timestep, timesteps)

Method called after the interaction with the environment.

pre_interaction(*, timestep, timesteps)

Method called before the interaction with the environment.

record_transition(*, observations, states, ...)

Record an environment transition in memory.

save(path)

Save the agent to the specified path.

track_data(tag, value)

Track data to TensorBoard.

update(*, timestep, timesteps)

Algorithm's main update step.

write_checkpoint(*, timestep, timesteps)

Write checkpoint (modules) to persistent storage.

write_tracking_data(*, timestep, timesteps)

Write tracking data to TensorBoard.

act(observations: jax.Array, states: jax.Array | None, *, timestep: int, timesteps: int) → tuple[jax.Array, dict[str, Any]]

Process the environment’s observations/states to make a decision (actions) using the main policy.

Parameters:
  • observations – Environment observations.

  • states – Environment states.

  • timestep – Current timestep.

  • timesteps – Number of timesteps.

Returns:

Agent output. The first component is the expected action/value returned by the agent. The second component is a dictionary containing extra output values according to the model.

enable_models_training_mode(enabled: bool = True) → None

Set the training mode of all the agent’s models: enabled (training) or disabled (evaluation).

Parameters:

enabled – True to enable the training mode, False to enable the evaluation mode.

enable_training_mode(enabled: bool = True, *, apply_to_models: bool = False) → None

Set the training mode of the agent: enabled (training) or disabled (evaluation).

The training mode can be queried by the training property.

Parameters:
  • enabled – True to enable the training mode, False to enable the evaluation mode.

  • apply_to_models – Whether to apply the training mode to all the agent’s models.

init(*, trainer_cfg: dict[str, Any] | None = None) → None

Initialize the agent.

Parameters:

trainer_cfg – Trainer configuration.

load(path: str) → None

Load the agent from the specified path.

Parameters:

path – Path to load the agent from.

post_interaction(*, timestep: int, timesteps: int) → None

Method called after the interaction with the environment.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

pre_interaction(*, timestep: int, timesteps: int) → None

Method called before the interaction with the environment.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

record_transition(*, observations: jax.Array, states: jax.Array, actions: jax.Array, rewards: jax.Array, next_observations: jax.Array, next_states: jax.Array, terminated: jax.Array, truncated: jax.Array, infos: Any, timestep: int, timesteps: int) → None

Record an environment transition in memory.

Parameters:
  • observations – Environment observations.

  • states – Environment states.

  • actions – Actions taken by the agent.

  • rewards – Instant rewards achieved by the current actions.

  • next_observations – Next environment observations.

  • next_states – Next environment states.

  • terminated – Signals that indicate episodes have terminated.

  • truncated – Signals that indicate episodes have been truncated.

  • infos – Additional information about the environment.

  • timestep – Current timestep.

  • timesteps – Number of timesteps.

save(path: str) → None

Save the agent to the specified path.

Parameters:

path – Path to save the agent to.

track_data(tag: str, value: float) → None

Track data to TensorBoard.

Note

Currently only scalar data is supported.

Parameters:
  • tag – Data identifier (e.g. ‘Loss/Policy loss’).

  • value – Value to track.

update(*, timestep: int, timesteps: int) → None

Algorithm’s main update step.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

write_checkpoint(*, timestep: int, timesteps: int) → None

Write checkpoint (modules) to persistent storage.

Note

The checkpoints are stored in the subdirectory checkpoints within the experiment directory. The checkpoint name is the timestep argument value (if it is not None), or the current system date-time otherwise.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.

write_tracking_data(*, timestep: int, timesteps: int) → None

Write tracking data to TensorBoard.

Parameters:
  • timestep – Current timestep.

  • timesteps – Number of timesteps.


Warp

SAC_CFG

Configuration for the SAC agent.

SAC

Soft Actor-Critic (SAC).