
Models (or agent models) are the representations of the agent’s policy, value function, etc. that the agent uses to make decisions. An agent can have one or more models, whose parameters are adjusted by the optimization algorithms.




Tabular model (discrete domain)



Categorical model (discrete domain)



Multi-Categorical model (discrete domain)



Gaussian model (continuous domain)



Multivariate Gaussian model (continuous domain)



Deterministic model (continuous domain)



Shared model



Base class


This is the base class for all models in this module and provides only basic functionality that is not tied to any specific implementation. It is not intended to be used directly.

Mixin and inheritance

from typing import Union, Mapping, Tuple, Any

import torch

class CustomMixin:
    def __init__(self, role: str = "") -> None:
        """Custom mixin

        :param role: Role played by the model (default: ``""``)
        :type role: str, optional
        """
        # =====================================
        # - define custom attributes and others
        # =====================================

    def act(self,
            inputs: Mapping[str, Union[torch.Tensor, Any]],
            role: str = "") -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]:
        """Act according to the specified behavior

        :param inputs: Model inputs. The most common keys are:

                       - ``"states"``: state of the environment used to make the decision
                       - ``"taken_actions"``: actions taken by the policy for the given states
        :type inputs: dict where the values are typically torch.Tensor
        :param role: Role played by the model (default: ``""``)
        :type role: str, optional

        :return: Model output. The first component is the action to be taken by the agent.
                 The second component is the log of the probability density function for stochastic models
                 or None for deterministic models. The third component is a dictionary containing extra output values
        :rtype: tuple of torch.Tensor, torch.Tensor or None, and dictionary
        """
        # ==============================
        # - act in response to the state
        # ==============================
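
The way a mixin and a base class are combined through multiple inheritance can be sketched without any ML framework. This is only an illustration: the names BaseModel, GreedyMixin and Policy are hypothetical, and plain Python lists stand in for tensors.

```python
from typing import Any, Mapping, Optional, Tuple


class BaseModel:
    """Stand-in for the base model: stores sizes, leaves compute() abstract."""

    def __init__(self, num_observations: int, num_actions: int) -> None:
        self.num_observations = num_observations
        self.num_actions = num_actions

    def compute(self, inputs: Mapping[str, Any], role: str = ""):
        raise NotImplementedError("Child class must implement this method")


class GreedyMixin:
    """Hypothetical mixin: implements act() on top of compute()."""

    def __init__(self, role: str = "") -> None:
        self.role = role

    def act(self, inputs: Mapping[str, Any], role: str = "") -> Tuple[int, Optional[float], dict]:
        scores, outputs = self.compute(inputs, role)
        # pick the index of the highest score; deterministic, so no log-probability
        action = max(range(len(scores)), key=scores.__getitem__)
        return action, None, outputs


class Policy(GreedyMixin, BaseModel):
    """User model: initializes both parents and only implements compute()."""

    def __init__(self, num_observations: int, num_actions: int) -> None:
        BaseModel.__init__(self, num_observations, num_actions)
        GreedyMixin.__init__(self, role="policy")

    def compute(self, inputs, role=""):
        state = inputs["states"]
        # toy scoring rule: action i scores state[i % len(state)]
        scores = [state[i % len(state)] for i in range(self.num_actions)]
        return scores, {}


policy = Policy(num_observations=2, num_actions=3)
action, log_prob, extra = policy.act({"states": [0.1, 0.9]})
```

The user-defined class lists the mixin first so that its act() takes precedence, and calls both parent constructors explicitly.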

API (PyTorch)

class skrl.models.torch.base.Model(*args: Any, **kwargs: Any)

Bases: Module

__init__(observation_space: int | Sequence[int] | gym.Space | gymnasium.Space, action_space: int | Sequence[int] | gym.Space | gymnasium.Space, device: str | torch.device | None = None) None

Base class representing a function approximator

The following properties are defined:

  • device (torch.device): Device to be used for the computations

  • observation_space (int, sequence of int, gym.Space, gymnasium.Space): Observation/state space

  • action_space (int, sequence of int, gym.Space, gymnasium.Space): Action space

  • num_observations (int): Number of elements in the observation/state space

  • num_actions (int): Number of elements in the action space

Parameters:

  • observation_space (int, sequence of int, gym.Space, gymnasium.Space) – Observation/state space or shape. The num_observations property will contain the size of that space

  • action_space (int, sequence of int, gym.Space, gymnasium.Space) – Action space or shape. The num_actions property will contain the size of that space

  • device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

Custom models should override the act method:

import torch
import torch.nn as nn
import torch.nn.functional as F

from skrl.models.torch import Model

class CustomModel(Model):
    def __init__(self, observation_space, action_space, device="cuda:0"):
        Model.__init__(self, observation_space, action_space, device)

        self.layer_1 = nn.Linear(self.num_observations, 64)
        self.layer_2 = nn.Linear(64, self.num_actions)

    def act(self, inputs, role=""):
        x = F.relu(self.layer_1(inputs["states"]))
        x = F.relu(self.layer_2(x))
        return x, None, {}
property device

Device to be used for the computations

property observation_space

Observation/state space. It is a replica of the class constructor parameter of the same name

property action_space

Action space. It is a replica of the class constructor parameter of the same name

property num_observations

Number of elements in the observation/state space

property num_actions

Number of elements in the action space

_get_space_size(space: int | Sequence[int] | gym.Space | gymnasium.Space, number_of_elements: bool = True) int

Get the size (number of elements) of a space

  • space (int, sequence of int, gym.Space, or gymnasium.Space) – Space or shape from which to obtain the number of elements

  • number_of_elements (bool, optional) – Whether the number of elements occupied by the space is returned (default: True). If False, the shape of the space is returned. It only affects Discrete and MultiDiscrete spaces


ValueError – If the space is not supported


Size of the space (number of elements)

Return type:

int


# from int
>>> model._get_space_size(2)

# from sequence of int
>>> model._get_space_size([2, 3])

# Box space
>>> space = gym.spaces.Box(low=-1, high=1, shape=(2, 3))
>>> model._get_space_size(space)

# Discrete space
>>> space = gym.spaces.Discrete(4)
>>> model._get_space_size(space)
>>> model._get_space_size(space, number_of_elements=False)

# MultiDiscrete space
>>> space = gym.spaces.MultiDiscrete([5, 3, 2])
>>> model._get_space_size(space)
>>> model._get_space_size(space, number_of_elements=False)

# Dict space
>>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)),
...                          'b': gym.spaces.Discrete(4)})
>>> model._get_space_size(space)
>>> model._get_space_size(space, number_of_elements=False)
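
For the two simplest cases (an int and a sequence of int), the size computation can be sketched as follows. The helper space_size is hypothetical and is not the library implementation; Gym/Gymnasium spaces are deliberately left out.

```python
import math
from typing import Sequence, Union


def space_size(space: Union[int, Sequence[int]]) -> int:
    """Number of elements described by an int or a shape (hypothetical helper)."""
    if isinstance(space, int):
        # an int already is the number of elements
        return space
    if isinstance(space, (list, tuple)):
        # a shape occupies the product of its dimensions
        return math.prod(space)
    raise ValueError(f"Space type {type(space)} not supported")


size_from_int = space_size(2)
size_from_shape = space_size([2, 3])
```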
act(inputs: Mapping[str, torch.Tensor | Any], role: str = '') Tuple[torch.Tensor, torch.Tensor | None, Mapping[str, torch.Tensor | Any]]

Act according to the specified behavior (to be implemented by the inheriting classes)

Agents will call this method to obtain the decision to be taken given the state of the environment. This method is currently implemented by the helper models (GaussianModel, etc.). The classes that inherit from the latter must only implement the .compute() method

  • inputs (dict where the values are typically torch.Tensor) –

    Model inputs. The most common keys are:

    • "states": state of the environment used to make the decision

    • "taken_actions": actions taken by the policy for the given states

  • role (str, optional) – Role played by the model (default: "")


NotImplementedError – Child class must implement this method


Model output. The first component is the action to be taken by the agent. The second component is the log of the probability density function for stochastic models or None for deterministic models. The third component is a dictionary containing extra output values

Return type:

tuple of torch.Tensor, torch.Tensor or None, and dict

compute(inputs: Mapping[str, torch.Tensor | Any], role: str = '') Tuple[torch.Tensor | Mapping[str, torch.Tensor | Any]]

Define the computation performed by the models (to be implemented by the inheriting classes)

  • inputs (dict where the values are typically torch.Tensor) –

    Model inputs. The most common keys are:

    • "states": state of the environment used to make the decision

    • "taken_actions": actions taken by the policy for the given states

  • role (str, optional) – Role played by the model (default: "")


NotImplementedError – Child class must implement this method


Computation performed by the models

Return type:

tuple of torch.Tensor and dict

forward(inputs: Mapping[str, torch.Tensor | Any], role: str = '') Tuple[torch.Tensor, torch.Tensor | None, Mapping[str, torch.Tensor | Any]]

Forward pass of the model

This method calls the .act() method and returns its outputs

  • inputs (dict where the values are typically torch.Tensor) –

    Model inputs. The most common keys are:

    • "states": state of the environment used to make the decision

    • "taken_actions": actions taken by the policy for the given states

  • role (str, optional) – Role played by the model (default: "")


Model output. The first component is the action to be taken by the agent. The second component is the log of the probability density function for stochastic models or None for deterministic models. The third component is a dictionary containing extra output values

Return type:

tuple of torch.Tensor, torch.Tensor or None, and dict

freeze_parameters(freeze: bool = True) None

Freeze or unfreeze internal parameters

  • Freeze: disable gradient computation (parameters.requires_grad = False)

  • Unfreeze: enable gradient computation (parameters.requires_grad = True)


freeze (bool, optional) – Freeze the internal parameters if True, otherwise unfreeze them (default: True)


# freeze model parameters
>>> model.freeze_parameters(True)

# unfreeze model parameters
>>> model.freeze_parameters(False)
get_specification() Mapping[str, Any]

Returns the specification of the model

The following keys are used by the agents for initialization:

  • "rnn": Recurrent Neural Network (RNN) specification for RNN, LSTM and GRU layers/cells

    • "sizes": List of RNN shapes (number of layers, number of environments, number of features in the RNN state). There must be as many tuples as there are states in the recurrent layer/cell. E.g., LSTM has 2 states (hidden and cell).


Dictionary containing advanced specification of the model

Return type:

dict


# model with a LSTM layer.
# - number of layers: 1
# - number of environments: 4
# - number of features in the RNN state: 64
>>> model.get_specification()
{'rnn': {'sizes': [(1, 4, 64), (1, 4, 64)]}}
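
A minimal sketch of how such a dictionary could be assembled, assuming one (number of layers, number of environments, number of features) tuple per recurrent state. The helper rnn_specification is hypothetical.

```python
from typing import Dict, List, Tuple


def rnn_specification(num_layers: int, num_envs: int, hidden_size: int, num_states: int) -> Dict[str, dict]:
    """Build an {"rnn": {"sizes": [...]}} dictionary (hypothetical helper)."""
    sizes: List[Tuple[int, int, int]] = [(num_layers, num_envs, hidden_size) for _ in range(num_states)]
    return {"rnn": {"sizes": sizes}}


# an LSTM keeps two states (hidden and cell), a GRU keeps one (hidden)
lstm_spec = rnn_specification(num_layers=1, num_envs=4, hidden_size=64, num_states=2)
gru_spec = rnn_specification(num_layers=1, num_envs=4, hidden_size=64, num_states=1)
```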
init_biases(method_name: str = 'constant_', *args, **kwargs) None

Initialize the model biases according to the specified method name

Method names are from the torch.nn.init module. Allowed method names are uniform_, normal_, constant_, etc.

The following layers will be initialized:

  • torch.nn.Linear

  • method_name (str, optional) –

    torch.nn.init method name (default: "constant_")

  • args (tuple, optional) – Positional arguments of the method to be called

  • kwargs (dict, optional) – Key-value arguments of the method to be called


# initialize all biases with a constant value (0)
>>> model.init_biases(method_name="constant_", val=0)

# initialize all biases with normal distribution with mean 0 and standard deviation 0.25
>>> model.init_biases(method_name="normal_", mean=0.0, std=0.25)
init_parameters(method_name: str = 'normal_', *args, **kwargs) None

Initialize the model parameters according to the specified method name

Method names are from the torch.nn.init module. Allowed method names are uniform_, normal_, constant_, etc.

  • method_name (str, optional) –

    torch.nn.init method name (default: "normal_")

  • args (tuple, optional) – Positional arguments of the method to be called

  • kwargs (dict, optional) – Key-value arguments of the method to be called


# initialize all parameters with an orthogonal distribution with a gain of 0.5
>>> model.init_parameters("orthogonal_", gain=0.5)

# initialize all parameters as a sparse matrix with a sparsity of 0.1
>>> model.init_parameters("sparse_", sparsity=0.1)
init_weights(method_name: str = 'orthogonal_', *args, **kwargs) None

Initialize the model weights according to the specified method name

Method names are from the torch.nn.init module. Allowed method names are uniform_, normal_, constant_, etc.

The following layers will be initialized:

  • torch.nn.Linear

  • method_name (str, optional) –

    torch.nn.init method name (default: "orthogonal_")

  • args (tuple, optional) – Positional arguments of the method to be called

  • kwargs (dict, optional) – Key-value arguments of the method to be called


# initialize all weights with uniform distribution in range [-0.1, 0.1]
>>> model.init_weights(method_name="uniform_", a=-0.1, b=0.1)

# initialize all weights with normal distribution with mean 0 and standard deviation 0.25
>>> model.init_weights(method_name="normal_", mean=0.0, std=0.25)
load(path: str) None

Load the model from the specified path

The final storage device is determined by the constructor of the model


path (str) – Path to load the model from


# load the model onto the CPU
>>> model = Model(observation_space, action_space, device="cpu")
>>> model.load("")

# load the model onto the GPU 1
>>> model = Model(observation_space, action_space, device="cuda:1")
>>> model.load("")
migrate(state_dict: Mapping[str, torch.Tensor] | None = None, path: str | None = None, name_map: Mapping[str, str] = {}, auto_mapping: bool = True, verbose: bool = False) bool

Migrate the specified external model’s state dict to the current model

The final storage device is determined by the constructor of the model

Only one of state_dict or path can be specified. At the moment, the path parameter can only load the state_dict automatically from files generated by the rl_games and stable-baselines3 libraries

For ambiguous models (where 2 or more parameters, for source or current model, have equal shape) it is necessary to define the name_map, at least for those parameters, to perform the migration successfully

  • state_dict (Mapping[str, torch.Tensor], optional) – External model’s state dict to migrate from (default: None)

  • path (str, optional) – Path to the external checkpoint to migrate from (default: None)

  • name_map (Mapping[str, str], optional) – Name map to use for the migration (default: {}). Keys are the current parameter names and values are the external parameter names

  • auto_mapping (bool, optional) – Automatically map the external state dict to the current state dict (default: True)

  • verbose (bool, optional) – Show model names and migration (default: False)

  • ValueError – If neither or both of state_dict and path parameters have been set

  • ValueError – If the correct file type cannot be identified from the path parameter


True if the migration was successful, False otherwise. Migration is successful if all parameters of the current model are found in the external model

Return type:

bool


# migrate a rl_games checkpoint with unambiguous state_dict
>>> model.migrate(path="./runs/Ant/nn/Ant.pth")

# migrate a rl_games checkpoint with ambiguous state_dict
>>> model.migrate(path="./runs/Cartpole/nn/Cartpole.pth", verbose=False)
[skrl:WARNING] Ambiguous match for log_std_parameter <- [value_mean_std.running_mean, value_mean_std.running_var, a2c_network.sigma]
[skrl:WARNING] Ambiguous match for net.0.bias <- [a2c_network.actor_mlp.0.bias, a2c_network.actor_mlp.2.bias]
[skrl:WARNING] Ambiguous match for net.2.bias <- [a2c_network.actor_mlp.0.bias, a2c_network.actor_mlp.2.bias]
[skrl:WARNING] Ambiguous match for net.4.weight <- [a2c_network.value.weight,]
[skrl:WARNING] Ambiguous match for net.4.bias <- [a2c_network.value.bias,]
[skrl:WARNING] Multiple use of a2c_network.actor_mlp.0.bias -> [net.0.bias, net.2.bias]
[skrl:WARNING] Multiple use of a2c_network.actor_mlp.2.bias -> [net.0.bias, net.2.bias]
>>> name_map = {"log_std_parameter": "a2c_network.sigma",
...             "net.0.bias": "a2c_network.actor_mlp.0.bias",
...             "net.2.bias": "a2c_network.actor_mlp.2.bias",
...             "net.4.weight": "",
...             "net.4.bias": ""}
>>> model.migrate(path="./runs/Cartpole/nn/Cartpole.pth", name_map=name_map, verbose=True)
[skrl:INFO] Models
[skrl:INFO]   |-- current: 7 items
[skrl:INFO]   |    |-- log_std_parameter : torch.Size([1])
[skrl:INFO]   |    |-- net.0.weight : torch.Size([32, 4])
[skrl:INFO]   |    |-- net.0.bias : torch.Size([32])
[skrl:INFO]   |    |-- net.2.weight : torch.Size([32, 32])
[skrl:INFO]   |    |-- net.2.bias : torch.Size([32])
[skrl:INFO]   |    |-- net.4.weight : torch.Size([1, 32])
[skrl:INFO]   |    |-- net.4.bias : torch.Size([1])
[skrl:INFO]   |-- source: 15 items
[skrl:INFO]   |    |-- value_mean_std.running_mean : torch.Size([1])
[skrl:INFO]   |    |-- value_mean_std.running_var : torch.Size([1])
[skrl:INFO]   |    |-- value_mean_std.count : torch.Size([])
[skrl:INFO]   |    |-- running_mean_std.running_mean : torch.Size([4])
[skrl:INFO]   |    |-- running_mean_std.running_var : torch.Size([4])
[skrl:INFO]   |    |-- running_mean_std.count : torch.Size([])
[skrl:INFO]   |    |-- a2c_network.sigma : torch.Size([1])
[skrl:INFO]   |    |-- a2c_network.actor_mlp.0.weight : torch.Size([32, 4])
[skrl:INFO]   |    |-- a2c_network.actor_mlp.0.bias : torch.Size([32])
[skrl:INFO]   |    |-- a2c_network.actor_mlp.2.weight : torch.Size([32, 32])
[skrl:INFO]   |    |-- a2c_network.actor_mlp.2.bias : torch.Size([32])
[skrl:INFO]   |    |-- a2c_network.value.weight : torch.Size([1, 32])
[skrl:INFO]   |    |-- a2c_network.value.bias : torch.Size([1])
[skrl:INFO]   |    |-- : torch.Size([1, 32])
[skrl:INFO]   |    |-- : torch.Size([1])
[skrl:INFO] Migration
[skrl:INFO]   |-- map:  log_std_parameter <- a2c_network.sigma
[skrl:INFO]   |-- auto: net.0.weight <- a2c_network.actor_mlp.0.weight
[skrl:INFO]   |-- map:  net.0.bias <- a2c_network.actor_mlp.0.bias
[skrl:INFO]   |-- auto: net.2.weight <- a2c_network.actor_mlp.2.weight
[skrl:INFO]   |-- map:  net.2.bias <- a2c_network.actor_mlp.2.bias
[skrl:INFO]   |-- map:  net.4.weight <-
[skrl:INFO]   |-- map:  net.4.bias <-

# migrate a stable-baselines3 checkpoint with unambiguous state_dict
>>> model.migrate(path="./")

# migrate from any exported model by loading its state_dict (unambiguous state_dict)
>>> state_dict = torch.load("./")
>>> model.migrate(state_dict=state_dict)
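
Why parameters with equal shapes need a name_map can be illustrated with a simplified shape-matching sketch. The function match_by_shape is hypothetical and is not the algorithm used by the library; the parameter names and shapes are taken from the Cartpole log above.

```python
from typing import Dict, List, Mapping, Tuple

Shape = Tuple[int, ...]


def match_by_shape(current: Mapping[str, Shape],
                   source: Mapping[str, Shape],
                   name_map: Mapping[str, str]) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
    """Map current parameter names to source names (hypothetical sketch).

    Explicit name_map entries win; the remaining names are matched by equal shape.
    """
    mapping: Dict[str, str] = {}
    ambiguous: Dict[str, List[str]] = {}
    for name, shape in current.items():
        if name in name_map:
            mapping[name] = name_map[name]
            continue
        candidates = [src for src, src_shape in source.items() if src_shape == shape]
        if len(candidates) == 1:
            mapping[name] = candidates[0]
        elif candidates:
            ambiguous[name] = candidates
    return mapping, ambiguous


current = {"log_std_parameter": (1,), "net.0.weight": (32, 4), "net.0.bias": (32,),
           "net.2.weight": (32, 32), "net.2.bias": (32,)}
source = {"value_mean_std.running_mean": (1,), "value_mean_std.running_var": (1,),
          "a2c_network.sigma": (1,), "a2c_network.actor_mlp.0.weight": (32, 4),
          "a2c_network.actor_mlp.0.bias": (32,), "a2c_network.actor_mlp.2.weight": (32, 32),
          "a2c_network.actor_mlp.2.bias": (32,)}

# without a name map, parameters that share a shape cannot be resolved
_, ambiguous = match_by_shape(current, source, name_map={})

# explicit entries resolve the ambiguous parameters
name_map = {"log_std_parameter": "a2c_network.sigma",
            "net.0.bias": "a2c_network.actor_mlp.0.bias",
            "net.2.bias": "a2c_network.actor_mlp.2.bias"}
mapping, unresolved = match_by_shape(current, source, name_map)
```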
random_act(inputs: Mapping[str, torch.Tensor | Any], role: str = '') Tuple[torch.Tensor, None, Mapping[str, torch.Tensor | Any]]

Act randomly according to the action space

  • inputs (dict where the values are typically torch.Tensor) –

    Model inputs. The most common keys are:

    • "states": state of the environment used to make the decision

    • "taken_actions": actions taken by the policy for the given states

  • role (str, optional) – Role played by the model (default: "")


NotImplementedError – Unsupported action space


Model output. The first component is the action to be taken by the agent

Return type:

tuple of torch.Tensor, None, and dict

save(path: str, state_dict: dict | None = None) None

Save the model to the specified path

  • path (str) – Path to save the model to

  • state_dict (dict, optional) – State dictionary to save (default: None). If None, the model’s state_dict will be saved


# save the current model to the specified path

# save an older version of the model to the specified path
>>> old_state_dict = copy.deepcopy(model.state_dict())
>>> # ...
>>> model.save("/tmp/", old_state_dict)
set_mode(mode: str) None

Set the model mode (training or evaluation)


mode (str) – Mode: "train" for training or "eval" for evaluation. See torch.nn.Module.train


ValueError – If the mode is not "train" or "eval"

tensor_to_space(tensor: torch.Tensor, space: gym.Space | gymnasium.Space, start: int = 0) torch.Tensor | dict

Map a flat tensor to a Gym/Gymnasium space

The mapping is done in the following way:

  • Tensors belonging to Discrete spaces are returned without modification

  • Tensors belonging to Box spaces are reshaped to the corresponding space shape keeping the first dimension (number of samples) as they are

  • Tensors belonging to Dict spaces are mapped into a dictionary with the same keys as the original space

  • tensor (torch.Tensor) – Tensor to map from

  • space (gym.Space or gymnasium.Space) – Space to map the tensor to

  • start (int, optional) – Index of the first element of the tensor to map (default: 0)


ValueError – If the space is not supported


Mapped tensor or dictionary

Return type:

torch.Tensor or dict


>>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)),
...                          'b': gym.spaces.Discrete(4)})
>>> tensor = torch.tensor([[-0.3, -0.2, -0.1, 0.1, 0.2, 0.3, 2]])
>>> model.tensor_to_space(tensor, space)
{'a': tensor([[[-0.3000, -0.2000, -0.1000],
               [ 0.1000,  0.2000,  0.3000]]]),
 'b': tensor([[2.]])}
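
The mapping rules listed above can be sketched for a single flat sample with plain Python lists. This is an illustration under assumptions: spaces are described by hypothetical ("discrete", n), ("box", shape) and ("dict", {...}) tuples instead of Gym/Gymnasium objects, and dictionary keys are visited in sorted order.

```python
import math
from typing import Any, Dict, Sequence, Tuple


def size(space: Tuple[str, Any]) -> int:
    """Number of flat elements occupied by a toy space description."""
    kind, spec = space
    if kind == "discrete":
        return 1
    if kind == "box":
        return math.prod(spec)
    return sum(size(sub) for sub in spec.values())


def reshape(flat: Sequence[float], shape: Tuple[int, ...]) -> Any:
    """Reshape a flat sequence into nested lists of the given shape."""
    if len(shape) == 1:
        return list(flat)
    step = math.prod(shape[1:])
    return [reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def row_to_space(row: Sequence[float], space: Tuple[str, Any], start: int = 0) -> Any:
    """Map one flat sample to a toy space description, beginning at index start."""
    kind, spec = space
    if kind == "discrete":
        return [row[start]]  # returned without modification
    if kind == "box":
        return reshape(row[start:start + size(space)], spec)
    if kind == "dict":
        output: Dict[str, Any] = {}
        for key in sorted(spec):  # sorted key order is an assumption of this sketch
            output[key] = row_to_space(row, spec[key], start)
            start += size(spec[key])
        return output
    raise ValueError(f"Space type {kind} not supported")


space = ("dict", {"a": ("box", (2, 3)), "b": ("discrete", 4)})
mapped = row_to_space([-0.3, -0.2, -0.1, 0.1, 0.2, 0.3, 2], space)
```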
update_parameters(model: torch.nn.Module, polyak: float = 1) None

Update internal parameters by hard or soft (polyak averaging) update

  • Hard update: \(\theta = \theta_{net}\)

  • Soft (polyak averaging) update: \(\theta = (1 - \rho) \theta + \rho \theta_{net}\)

  • model (torch.nn.Module (skrl.models.torch.Model)) – Model used to update the internal parameters

  • polyak (float, optional) – Polyak hyperparameter between 0 and 1 (default: 1). A hard update is performed when its value is 1


# hard update (from source model)
>>> model.update_parameters(source_model)

# soft update (from source model)
>>> model.update_parameters(source_model, polyak=0.005)
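
The two update rules can be checked numerically with a framework-free sketch. The function below is hypothetical and operates on lists of floats rather than on model parameters; target plays the role of \(\theta\) and source the role of \(\theta_{net}\).

```python
from typing import List


def polyak_update(target: List[float], source: List[float], polyak: float = 1.0) -> List[float]:
    """Return the updated target parameters (hypothetical, list-based sketch)."""
    if polyak == 1:
        # hard update: theta = theta_net
        return list(source)
    # soft update: theta = (1 - rho) * theta + rho * theta_net
    return [(1 - polyak) * t + polyak * s for t, s in zip(target, source)]


target = [0.0, 1.0]
source = [1.0, 3.0]
hard = polyak_update(target, source)
soft = polyak_update(target, source, polyak=0.5)
```

A small polyak value, such as the 0.005 used in the example above, moves the internal parameters only slightly towards the source model on each call.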


class skrl.models.jax.base.Model(*args: Any, **kwargs: Any)

Bases: Module

__init__(observation_space: int | Sequence[int] | gym.Space | gymnasium.Space, action_space: int | Sequence[int] | gym.Space | gymnasium.Space, device: str | jax.Device | None = None, parent: Any | None = None, name: str | None = None) None

Base class representing a function approximator

The following properties are defined:

  • device (jax.Device): Device to be used for the computations

  • observation_space (int, sequence of int, gym.Space, gymnasium.Space): Observation/state space

  • action_space (int, sequence of int, gym.Space, gymnasium.Space): Action space

  • num_observations (int): Number of elements in the observation/state space

  • num_actions (int): Number of elements in the action space

Parameters:

  • observation_space (int, sequence of int, gym.Space, gymnasium.Space) – Observation/state space or shape. The num_observations property will contain the size of that space

  • action_space (int, sequence of int, gym.Space, gymnasium.Space) – Action space or shape. The num_actions property will contain the size of that space

  • device (str or jax.Device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

  • parent (str, optional) – The parent Module of this Module (default: None). It is a Flax reserved attribute

  • name (str, optional) – The name of this Module (default: None). It is a Flax reserved attribute

Custom models should override the act method:

import flax.linen as nn
from skrl.models.jax import Model

class CustomModel(Model):
    def __init__(self, observation_space, action_space, device=None, **kwargs):
        Model.__init__(self, observation_space, action_space, device, **kwargs)


    @nn.compact  # required by Flax to declare submodules inline in __call__
    def __call__(self, inputs, role):
        x = nn.relu(nn.Dense(64)(inputs["states"]))
        x = nn.relu(nn.Dense(self.num_actions)(x))
        return x, None, {}
property device

Device to be used for the computations

property observation_space

Observation/state space. It is a replica of the class constructor parameter of the same name

property action_space

Action space. It is a replica of the class constructor parameter of the same name

property num_observations

Number of elements in the observation/state space

property num_actions

Number of elements in the action space

_get_space_size(space: int | Sequence[int] | gym.Space | gymnasium.Space, number_of_elements: bool = True) int

Get the size (number of elements) of a space

  • space (int, sequence of int, gym.Space, or gymnasium.Space) – Space or shape from which to obtain the number of elements

  • number_of_elements (bool, optional) – Whether the number of elements occupied by the space is returned (default: True). If False, the shape of the space is returned. It only affects Discrete and MultiDiscrete spaces


ValueError – If the space is not supported


Size of the space (number of elements)

Return type:

int


# from int
>>> model._get_space_size(2)

# from sequence of int
>>> model._get_space_size([2, 3])

# Box space
>>> space = gym.spaces.Box(low=-1, high=1, shape=(2, 3))
>>> model._get_space_size(space)

# Discrete space
>>> space = gym.spaces.Discrete(4)
>>> model._get_space_size(space)
>>> model._get_space_size(space, number_of_elements=False)

# MultiDiscrete space
>>> space = gym.spaces.MultiDiscrete([5, 3, 2])
>>> model._get_space_size(space)
>>> model._get_space_size(space, number_of_elements=False)

# Dict space
>>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)),
...                          'b': gym.spaces.Discrete(4)})
>>> model._get_space_size(space)
>>> model._get_space_size(space, number_of_elements=False)
act(inputs: Mapping[str, ndarray | jax.Array | Any], role: str = '', params: jax.Array | None = None) Tuple[jax.Array, jax.Array | None, Mapping[str, jax.Array | Any]]

Act according to the specified behavior (to be implemented by the inheriting classes)

Agents will call this method to obtain the decision to be taken given the state of the environment. This method is currently implemented by the helper models (GaussianModel, etc.). The classes that inherit from the latter must only implement the .__call__() method

  • inputs (dict where the values are typically np.ndarray or jax.Array) –

    Model inputs. The most common keys are:

    • "states": state of the environment used to make the decision

    • "taken_actions": actions taken by the policy for the given states

  • role (str, optional) – Role played by the model (default: "")

  • params (jnp.array) – Parameters used to compute the output (default: None). If None, internal parameters will be used


NotImplementedError – Child class must implement this method


Model output. The first component is the action to be taken by the agent. The second component is the log of the probability density function for stochastic models or None for deterministic models. The third component is a dictionary containing extra output values

Return type:

tuple of jax.Array, jax.Array or None, and dict

action_space: int | Sequence[int] | gym.Space | gymnasium.Space
device: str | jax.Device | None = None
freeze_parameters(freeze: bool = True) None

Freeze or unfreeze internal parameters


This method does nothing; it only maintains compatibility with other ML frameworks


freeze (bool, optional) – Freeze the internal parameters if True, otherwise unfreeze them (default: True)


# freeze model parameters
>>> model.freeze_parameters(True)

# unfreeze model parameters
>>> model.freeze_parameters(False)
get_specification() Mapping[str, Any]

Returns the specification of the model

The following keys are used by the agents for initialization:

  • "rnn": Recurrent Neural Network (RNN) specification for RNN, LSTM and GRU layers/cells

    • "sizes": List of RNN shapes (number of layers, number of environments, number of features in the RNN state). There must be as many tuples as there are states in the recurrent layer/cell. E.g., LSTM has 2 states (hidden and cell).


Dictionary containing advanced specification of the model

Return type:

dict


# model with a LSTM layer.
# - number of layers: 1
# - number of environments: 4
# - number of features in the RNN state: 64
>>> model.get_specification()
{'rnn': {'sizes': [(1, 4, 64), (1, 4, 64)]}}
init_biases(method_name: str = 'constant_', *args, **kwargs) None

Initialize the model biases according to the specified method name

Method names are from the flax.linen.initializers module. Allowed method names are uniform, normal, constant, etc.

  • method_name (str, optional) –

    flax.linen.initializers method name (default: "normal")

  • args (tuple, optional) – Positional arguments of the method to be called

  • kwargs (dict, optional) – Key-value arguments of the method to be called


# initialize all biases with a constant value (0)
>>> model.init_biases(method_name="constant", value=0)

# initialize all biases with normal distribution with mean 0 and standard deviation 0.25
>>> model.init_biases(method_name="normal", stddev=0.25)
init_parameters(method_name: str = 'normal', *args, **kwargs) None

Initialize the model parameters according to the specified method name

Method names are from the flax.linen.initializers module. Allowed method names are uniform, normal, constant, etc.

  • method_name (str, optional) –

    flax.linen.initializers method name (default: "normal")

  • args (tuple, optional) – Positional arguments of the method to be called

  • kwargs (dict, optional) – Key-value arguments of the method to be called


# initialize all parameters with an orthogonal distribution with a scale of 0.5
>>> model.init_parameters("orthogonal", scale=0.5)

# initialize all parameters as a normal distribution with a standard deviation of 0.1
>>> model.init_parameters("normal", stddev=0.1)
init_state_dict(role: str, inputs: Mapping[str, ndarray | jax.Array] = {}, key: jax.Array | None = None) None

Initialize state dictionary

  • role (str) – Role played by the model

  • inputs (dict of np.ndarray or jax.Array, optional) –

    Model inputs. The most common keys are:

    • "states": state of the environment used to make the decision

    • "taken_actions": actions taken by the policy for the given states

    If not specified, the keys will be populated with observation and action space samples

  • key (jax.Array, optional) – Pseudo-random number generator (PRNG) key (default: None). If not provided, the skrl’s PRNG key (config.jax.key) will be used

init_weights(method_name: str = 'normal', *args, **kwargs) None

Initialize the model weights according to the specified method name

Method names are from the flax.linen.initializers module. Allowed method names are uniform, normal, constant, etc.

  • method_name (str, optional) –

    flax.linen.initializers method name (default: "normal")

  • args (tuple, optional) – Positional arguments of the method to be called

  • kwargs (dict, optional) – Key-value arguments of the method to be called


# initialize all weights with uniform distribution in range [0, 0.1)
>>> model.init_weights(method_name="uniform", scale=0.1)

# initialize all weights with normal distribution with mean 0 and standard deviation 0.25
>>> model.init_weights(method_name="normal", stddev=0.25)
load(path: str) None

Load the model from the specified path


path (str) – Path to load the model from


# load the model
>>> model = Model(observation_space, action_space)
>>> model.load("model.flax")
migrate(state_dict: Mapping[str, Any] | None = None, path: str | None = None, name_map: Mapping[str, str] = {}, auto_mapping: bool = True, verbose: bool = False) bool

Migrate the specified external model’s state dict to the current model


This method is not implemented yet; it only maintains compatibility with other ML frameworks


NotImplementedError – Not implemented

observation_space: int | Sequence[int] | gym.Space | gymnasium.Space
random_act(inputs: Mapping[str, ndarray | jax.Array | Any], role: str = '', params: jax.Array | None = None) Tuple[ndarray | jax.Array, ndarray | jax.Array | None, Mapping[str, ndarray | jax.Array | Any]]

Act randomly according to the action space

  • inputs (dict where the values are typically np.ndarray or jax.Array) –

    Model inputs. The most common keys are:

    • "states": state of the environment used to make the decision

    • "taken_actions": actions taken by the policy for the given states

  • role (str, optional) – Role played by the model (default: "")

  • params (jnp.array) – Parameters used to compute the output (default: None). If None, internal parameters will be used


NotImplementedError – Unsupported action space


Model output. The first component is the action to be taken by the agent

Return type:

tuple of np.ndarray or jax.Array, None, and dict

save(path: str, state_dict: dict | None = None) None

Save the model to the specified path

  • path (str) – Path to save the model to

  • state_dict (dict, optional) – State dictionary to save (default: None). If None, the model’s state_dict will be saved


# save the current model to the specified path

# TODO: save an older version of the model to the specified path
set_mode(mode: str) None

Set the model mode (training or evaluation)


mode (str) – Mode: "train" for training or "eval" for evaluation


ValueError – If the mode is not "train" or "eval"

tensor_to_space(tensor: ndarray | jax.Array, space: gym.Space | gymnasium.Space, start: int = 0) ndarray | jax.Array | dict

Map a flat tensor to a Gym/Gymnasium space

The mapping is done in the following way:

  • Tensors belonging to Discrete spaces are returned without modification

  • Tensors belonging to Box spaces are reshaped to the corresponding space shape keeping the first dimension (number of samples) as they are

  • Tensors belonging to Dict spaces are mapped into a dictionary with the same keys as the original space

  • tensor (np.ndarray or jax.Array) – Tensor to map from

  • space (gym.Space or gymnasium.Space) – Space to map the tensor to

  • start (int, optional) – Index of the first element of the tensor to map (default: 0)


ValueError – If the space is not supported


Mapped tensor or dictionary

Return type:

np.ndarray or jax.Array, or dict


>>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)),
...                          'b': gym.spaces.Discrete(4)})
>>> tensor = jnp.array([[-0.3, -0.2, -0.1, 0.1, 0.2, 0.3, 2]])
>>> model.tensor_to_space(tensor, space)
{'a': Array([[[-0.3, -0.2, -0.1],
              [ 0.1,  0.2,  0.3]]], dtype=float32),
 'b': Array([[2.]], dtype=float32)}
update_parameters(model: flax.linen.Module, polyak: float = 1) None

Update internal parameters by hard or soft (polyak averaging) update

  • Hard update: \(\theta = \theta_{net}\)

  • Soft (polyak averaging) update: \(\theta = (1 - \rho) \theta + \rho \theta_{net}\)

  • model (flax.linen.Module (skrl.models.jax.Model)) – Model used to update the internal parameters

  • polyak (float, optional) – Polyak hyperparameter between 0 and 1 (default: 1). A hard update is performed when its value is 1


# hard update (from source model)
>>> model.update_parameters(source_model)

# soft update (from source model)
>>> model.update_parameters(source_model, polyak=0.005)