Models

A model (or agent model) is a representation of an agent's policy, value function, or similar component that the agent uses to make decisions. An agent can have one or more models, whose parameters are adjusted by the optimization algorithm.



Implemented models

The following table lists the implemented models and their support for different frameworks.

Models                                          | pytorch          | jax              | warp
Tabular model (discrete domain)                 | \(\blacksquare\) | \(\square\)      | \(\square\)
Categorical model (discrete domain)             | \(\blacksquare\) | \(\blacksquare\) | \(\square\)
Multi-Categorical model (discrete domain)       | \(\blacksquare\) | \(\blacksquare\) | \(\square\)
Gaussian model (continuous domain)              | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\)
Multivariate Gaussian model (continuous domain) | \(\blacksquare\) | \(\square\)      | \(\square\)
Deterministic model (continuous domain)         | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\)
Shared model                                    | \(\blacksquare\) | \(\square\)      | \(\blacksquare\)



Base class

Base class for models.

API


PyTorch

Model

Base model class for implementing custom models.

class skrl.models.torch.Model(*args: Any, **kwargs: Any)

Bases: Module, ABC

Base model class for implementing custom models.

Parameters:
  • observation_space – Observation space. The num_observations property will contain the size of the space.

  • state_space – State space. The num_states property will contain the size of the space.

  • action_space – Action space. The num_actions property will contain the size of the space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

Methods:

act(inputs, *[, role])

Act according to the specified behavior.

broadcast_parameters(*[, rank])

Broadcast model parameters to the whole group (e.g.: across all nodes) in distributed runs.

compute(inputs, *[, role])

Define the computation performed by the model.

enable_training_mode([enabled])

Set the training mode of the model: enabled (training) or disabled (evaluation).

forward(inputs, *[, role])

Forward pass of the model.

freeze_parameters([freeze])

Freeze or unfreeze internal parameters.

get_specification()

Returns the specification of the model.

init_biases([method_name])

Initialize the model biases according to the specified method name.

init_parameters([method_name])

Initialize the model parameters according to the specified method name.

init_state_dict([inputs, role])

Initialize lazy modules' parameters.

init_weights([method_name])

Initialize the model weights according to the specified method name.

load(path)

Load the model from the specified path.

random_act(inputs, *[, role])

Act randomly according to the action space.

reduce_parameters()

Reduce model parameters across all workers/processes in the whole group (e.g.: across all nodes).

save(path, *[, state_dict])

Save the model to the specified path.

update_parameters(model, *[, polyak])

Update internal parameters by hard or soft (polyak averaging) update.

abstractmethod act(inputs: dict[str, Any], *, role: str = '') -> tuple[torch.Tensor, dict[str, Any]]

Act according to the specified behavior.

Agents will call this method to get the expected action/value based on the observations/states.

Warning

This method is currently implemented by the helper models (e.g.: GaussianMixin). Classes that inherit from a helper model only need to implement the compute() method.

Parameters:
  • inputs

    Model inputs. The most common keys are:

    • "observations": observation of the environment used to make the decision.

    • "states": state of the environment used to make the decision.

    • "taken_actions": actions taken by the policy for the given observations/states.

  • role – Role played by the model.

Returns:

Model output. The first component is the expected action/value returned by the model. The second component is a dictionary containing extra output values according to the model.

Raises:

NotImplementedError – This method must be implemented by subclasses.

broadcast_parameters(*, rank: int = 0) -> None

Broadcast model parameters to the whole group (e.g.: across all nodes) in distributed runs.

After calling this method, the distributed model will contain the parameters broadcast from the worker/process with the given rank.

Parameters:

rank – Worker/process rank from which to broadcast model parameters.

Example:

# broadcast model parameters from worker/process with rank 1
>>> if config.torch.is_distributed:
...     model.broadcast_parameters(rank=1)

abstractmethod compute(inputs: dict[str, Any], *, role: str = '') -> tuple[torch.Tensor, dict[str, Any]]

Define the computation performed by the model.

Warning

This method is abstract and must be implemented by subclasses.

Parameters:
  • inputs

    Model inputs. The most common keys are:

    • "observations": observation of the environment used to make the decision.

    • "states": state of the environment used to make the decision.

    • "taken_actions": actions taken by the policy for the given observations/states.

  • role – Role played by the model.

Returns:

Computation performed by the model.

Raises:

NotImplementedError – This method must be implemented by subclasses.
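The split between act() and compute() described above can be illustrated with a framework-free sketch. The classes SketchModel, ClipMixin and Policy are hypothetical stand-ins written for this illustration (they are not skrl classes), and the clipping step is only an assumed example of what a helper might do after calling compute().

```python
from abc import ABC, abstractmethod
from typing import Any


class SketchModel(ABC):
    """Hypothetical base class: declares both entry points as abstract."""

    @abstractmethod
    def act(self, inputs: dict[str, Any], *, role: str = "") -> tuple[Any, dict[str, Any]]:
        ...

    @abstractmethod
    def compute(self, inputs: dict[str, Any], *, role: str = "") -> tuple[Any, dict[str, Any]]:
        ...


class ClipMixin:
    """Hypothetical helper: implements act() by post-processing compute()."""

    def act(self, inputs: dict[str, Any], *, role: str = "") -> tuple[Any, dict[str, Any]]:
        values, extras = self.compute(inputs, role=role)
        # Illustrative post-processing only: clip every value to [-1, 1].
        return [max(-1.0, min(1.0, v)) for v in values], extras


class Policy(ClipMixin, SketchModel):
    """User-defined model: only compute() has to be written."""

    def compute(self, inputs: dict[str, Any], *, role: str = "") -> tuple[Any, dict[str, Any]]:
        return [2.0 * o for o in inputs["observations"]], {"role": role}


policy = Policy()
actions, extras = policy.act({"observations": [0.25, -3.0]}, role="policy")
print(actions, extras)
```

The helper is listed first in the bases so that its act() takes precedence in the method resolution order, which leaves compute() as the only abstract method the user-defined class has to provide.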

enable_training_mode(enabled: bool = True) -> None

Set the training mode of the model: enabled (training) or disabled (evaluation).

Parameters:

enabled – True to enable the training mode, False to enable the evaluation mode. See torch.nn.Module.train() for more details.

forward(inputs: dict[str, Any], *, role: str = '') -> tuple[torch.Tensor, dict[str, Any]]

Forward pass of the model.

Note

This method calls the act() method and returns its outputs. It exists for compatibility with the torch.nn.Module class.

Parameters:
  • inputs

    Model inputs. The most common keys are:

    • "observations": observation of the environment used to make the decision.

    • "states": state of the environment used to make the decision.

    • "taken_actions": actions taken by the policy for the given observations/states.

  • role – Role played by the model.

Returns:

Model output. The first component is the expected action/value returned by the model. The second component is a dictionary containing extra output values according to the model.

freeze_parameters(freeze: bool = True) -> None

Freeze or unfreeze internal parameters.

  • Freeze: disable gradient computation (parameters.requires_grad = False).

  • Unfreeze: enable gradient computation (parameters.requires_grad = True).

Parameters:

freeze – Whether to freeze or unfreeze the internal parameters.

Example:

# freeze model parameters
>>> model.freeze_parameters(True)

# unfreeze model parameters
>>> model.freeze_parameters(False)

get_specification() -> dict[str, Any]

Returns the specification of the model.

The following keys are used by the agents for initialization:

  • "rnn": Recurrent Neural Network (RNN) specification for RNN, LSTM and GRU layers/cells.

    • "sizes": List of RNN shapes (number of layers, number of environments, number of features in the RNN state). There must be as many tuples as there are states in the recurrent layer/cell. E.g.: LSTM has 2 states (hidden and cell).

Returns:

Dictionary containing advanced specification of the model.

Example:

# model with a LSTM layer
# - number of layers: 1
# - number of environments: 4
# - number of features in the RNN state: 64
>>> model.get_specification()
{'rnn': {'sizes': [(1, 4, 64), (1, 4, 64)]}}
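
As a rough illustration of how the "sizes" list relates to the recurrent cell type, the sketch below builds the specification dictionary from the three shape components. The helper rnn_sizes is hypothetical (it is not part of skrl), and the per-cell state counts are the usual ones: one state for RNN and GRU, two (hidden and cell) for LSTM.

```python
def rnn_sizes(cell: str, num_layers: int, num_envs: int, hidden_size: int) -> list[tuple[int, int, int]]:
    """Return one (layers, environments, features) tuple per recurrent state."""
    states_per_cell = {"rnn": 1, "gru": 1, "lstm": 2}
    if cell not in states_per_cell:
        raise ValueError(f"Unknown recurrent cell type: {cell}")
    return [(num_layers, num_envs, hidden_size)] * states_per_cell[cell]


# Same configuration as the example above: 1 layer, 4 environments, 64 features.
spec = {"rnn": {"sizes": rnn_sizes("lstm", 1, 4, 64)}}
print(spec)
```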

init_biases(method_name: str = 'constant_', *args, **kwargs) -> None

Initialize the model biases according to the specified method name.

Method names are from the torch.nn.init module. Allowed method names are "uniform_", "normal_", "constant_", etc.

The following layers will be initialized:

  • torch.nn.Linear

Parameters:
  • method_name

    torch.nn.init method name.

  • args – Positional arguments of the method to be called.

  • kwargs – Key-value arguments of the method to be called.

Example:

# initialize all biases with a constant value (0)
>>> model.init_biases(method_name="constant_", val=0)

# initialize all biases with normal distribution with mean 0 and standard deviation 0.25
>>> model.init_biases(method_name="normal_", mean=0.0, std=0.25)

init_parameters(method_name: str = 'normal_', *args, **kwargs) -> None

Initialize the model parameters according to the specified method name.

Method names are from the torch.nn.init module. Allowed method names are "uniform_", "normal_", "constant_", etc.

Parameters:
  • method_name

    torch.nn.init method name.

  • args – Positional arguments of the method to be called.

  • kwargs – Key-value arguments of the method to be called.

Example:

# initialize all parameters as (semi-)orthogonal matrices with a gain of 0.5
>>> model.init_parameters("orthogonal_", gain=0.5)

# initialize all parameters as a sparse matrix with a sparsity of 0.1
>>> model.init_parameters("sparse_", sparsity=0.1)

init_state_dict(inputs: dict[str, Any] = {}, *, role: str = '') -> None

Initialize lazy modules’ parameters.

Hint

Calling this method only makes sense when using models that contain lazy modules (e.g. model instantiators), and always before performing any operation on model parameters.

Parameters:
  • inputs

    Model inputs. The most common keys are:

    • "observations": observation of the environment used to make the decision.

    • "states": state of the environment used to make the decision.

    • "taken_actions": actions taken by the policy for the given observations/states.

    If not specified, inputs will have random samples from the observation, state and action spaces.

  • role – Role played by the model.

init_weights(method_name: str = 'orthogonal_', *args, **kwargs) -> None

Initialize the model weights according to the specified method name.

Method names are from the torch.nn.init module. Allowed method names are "uniform_", "normal_", "constant_", etc.

The following layers will be initialized:

  • torch.nn.Linear

Parameters:
  • method_name

    torch.nn.init method name.

  • args – Positional arguments of the method to be called.

  • kwargs – Key-value arguments of the method to be called.
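
The way a method name and its arguments are forwarded to an initializer can be sketched without any framework. This is only an assumption about the general dispatch pattern, not skrl's actual code: the init namespace below is a hypothetical stand-in for torch.nn.init, and plain lists stand in for tensors.

```python
import random
from types import SimpleNamespace


def constant_(values: list[float], val: float) -> None:
    """Fill the list in place with a constant."""
    values[:] = [val] * len(values)


def uniform_(values: list[float], a: float = 0.0, b: float = 1.0) -> None:
    """Fill the list in place with samples drawn uniformly from [a, b]."""
    values[:] = [random.uniform(a, b) for _ in values]


# Hypothetical stand-in for the torch.nn.init module.
init = SimpleNamespace(constant_=constant_, uniform_=uniform_)


def init_by_name(tensors: list[list[float]], method_name: str, *args, **kwargs) -> None:
    """Look the initializer up by name and apply it to every tensor."""
    method = getattr(init, method_name)
    for tensor in tensors:
        method(tensor, *args, **kwargs)


biases = [[0.3, -0.2], [1.5]]
init_by_name(biases, "constant_", val=0.0)

weights = [[0.0] * 4]
init_by_name(weights, "uniform_", a=-0.1, b=0.1)
print(biases, weights)
```

The positional and keyword arguments are passed through untouched, so they must match the signature of the chosen initializer.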

Example:

# initialize all weights with uniform distribution in range [-0.1, 0.1]
>>> model.init_weights(method_name="uniform_", a=-0.1, b=0.1)

# initialize all weights with normal distribution with mean 0 and standard deviation 0.25
>>> model.init_weights(method_name="normal_", mean=0.0, std=0.25)

load(path: str) -> None

Load the model from the specified path.

Note

The final storage device is determined by the constructor of the model.

Parameters:

path – Path to load the model from.

Example:

# load the model onto the CPU
>>> model = Model(device="cpu")
>>> model.load("model.pt")

# load the model onto the GPU 1
>>> model = Model(device="cuda:1")
>>> model.load("model.pt")

random_act(inputs: dict[str, Any], *, role: str = '') -> tuple[torch.Tensor, dict[str, Any]]

Act randomly according to the action space.

Warning

Sampling from unbounded action spaces may lead to numerical instabilities.

Parameters:
  • inputs

    Model inputs. The most common keys are:

    • "observations": observation of the environment used to make the decision.

    • "states": state of the environment used to make the decision.

    • "taken_actions": actions taken by the policy for the given observations/states.

  • role – Role played by the model.

Returns:

Model output. The first component is the randomly sampled actions, with the same batch size as the "observations" entry of the inputs. The second component is an empty dictionary.

Raises:

ValueError – Unsupported action space.
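
The return contract described above (one random action per observation, plus an empty dictionary) can be sketched as follows. The function sample_random_actions is hypothetical and assumes a box-shaped action space with finite per-dimension bounds given as plain lists; it does not reproduce skrl's handling of other space types.

```python
import random


def sample_random_actions(
    observations: list[list[float]],
    low: list[float],
    high: list[float],
    rng: random.Random,
) -> tuple[list[list[float]], dict]:
    """Draw one action per observation, uniformly within [low, high] per dimension."""
    if len(low) != len(high):
        raise ValueError("low and high must have the same number of dimensions")
    batch_size = len(observations)
    actions = [[rng.uniform(lo, hi) for lo, hi in zip(low, high)] for _ in range(batch_size)]
    return actions, {}


rng = random.Random(0)
observations = [[0.0, 0.0, 0.0] for _ in range(5)]  # batch of 5 observations
actions, extras = sample_random_actions(observations, low=[-1.0, 0.0], high=[1.0, 2.0], rng=rng)
print(len(actions), len(actions[0]), extras)
```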

reduce_parameters() -> None

Reduce model parameters across all workers/processes in the whole group (e.g.: across all nodes).

After calling this method, the distributed model parameters will be bitwise identical for all workers/processes.

Example:

# reduce model parameters across all workers/processes
>>> if config.torch.is_distributed:
...     model.reduce_parameters()

save(path: str, *, state_dict: dict[str, Any] | None = None) -> None

Save the model to the specified path.

Parameters:
  • path – Path to save the model to.

  • state_dict – State dictionary to save. If None, the model’s state_dict will be saved.
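
When an older version of the model is to be saved later, the state dictionary has to be copied at the time of interest, because a state dictionary generally refers to the live parameters rather than to a frozen snapshot. The sketch below shows the difference with plain Python containers; live_state and the state_dict function are hypothetical stand-ins for a model's parameters and its state_dict() method.

```python
import copy

# Hypothetical stand-in for the parameters held by a model.
live_state = {"linear.weight": [[0.1, 0.2]], "linear.bias": [0.0]}


def state_dict() -> dict:
    """Return a new dict whose values are the live parameter containers."""
    return dict(live_state)


alias = state_dict()                     # values still point at the live parameters
snapshot = copy.deepcopy(state_dict())   # values are independent copies

live_state["linear.bias"][0] = 1.0       # simulate an in-place training update

print(alias["linear.bias"], snapshot["linear.bias"])
```

Only the deep copy keeps the value from before the update, which is why a snapshot intended for a later save should be taken with copy.deepcopy.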

Example:

# save the current model to the specified path
>>> model.save("/tmp/model.pt")

# save an older version of the model to the specified path
>>> import copy
>>> old_state_dict = copy.deepcopy(model.state_dict())
>>> # ...
>>> model.save("/tmp/model.pt", state_dict=old_state_dict)

update_parameters(model: torch.nn.Module, *, polyak: float = 1.0) -> None

Update internal parameters by hard or soft (polyak averaging) update.

  • Hard update: \(\theta = \theta_{net}\)

  • Soft (polyak averaging) update: \(\theta = (1 - \rho) \theta + \rho \theta_{net}\)

Parameters:
  • model – Model used to update the internal parameters.

  • polyak – Polyak hyperparameter \(\rho\) between 0 and 1. A hard update is performed when its value is 1.
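
The two update rules above can be checked numerically with a small sketch on plain lists of floats. The function polyak_update is a hypothetical illustration of the formulas, not skrl's implementation, and the lists stand in for the internal parameters \(\theta\) and the source parameters \(\theta_{net}\).

```python
def polyak_update(target: list[float], source: list[float], polyak: float = 1.0) -> None:
    """Update target in place: hard copy if polyak == 1, otherwise a soft blend."""
    if not 0.0 <= polyak <= 1.0:
        raise ValueError("polyak must be between 0 and 1")
    if polyak == 1.0:
        # Hard update: theta = theta_net
        target[:] = source
    else:
        # Soft update: theta = (1 - rho) * theta + rho * theta_net
        target[:] = [(1.0 - polyak) * t + polyak * s for t, s in zip(target, source)]


theta_net = [1.0, 20.0]

soft = [0.0, 10.0]
polyak_update(soft, theta_net, polyak=0.5)

hard = [0.0, 10.0]
polyak_update(hard, theta_net)

print(soft, hard)
```

A small value such as 0.005 moves the internal parameters only slightly towards the source model on each call.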

Example:

# hard update (from source model)
>>> model.update_parameters(source_model)

# soft update (from source model)
>>> model.update_parameters(source_model, polyak=0.005)

JAX

Model

Base model class for implementing custom models.

class skrl.models.jax.Model(*args: Any, **kwargs: Any)

Bases: Module, ABC

Base model class for implementing custom models.

Parameters:
  • observation_space – Observation space. The num_observations property will contain the size of the space.

  • state_space – State space. The num_states property will contain the size of the space.

  • action_space – Action space. The num_actions property will contain the size of the space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

  • parent – The parent Module of this Module. It is a Flax reserved attribute.

  • name – The name of this Module. It is a Flax reserved attribute.

Methods:

act(inputs, *[, role, params])

Act according to the specified behavior.

broadcast_parameters(*[, rank])

Broadcast model parameters to the whole group (e.g.: across all nodes) in distributed runs.

enable_training_mode([enabled])

Set the training mode of the model: enabled (training) or disabled (evaluation).

freeze_parameters([freeze])

Freeze or unfreeze internal parameters.

get_specification()

Returns the specification of the model.

init_biases([method_name])

Initialize the model biases according to the specified method name.

init_parameters([method_name])

Initialize the model parameters according to the specified method name.

init_state_dict([inputs, role, key])

Initialize state dictionary.

init_weights([method_name])

Initialize the model weights according to the specified method name.

load(path)

Load the model from the specified path.

random_act(inputs, *[, role, params])

Act randomly according to the action space.

reduce_parameters(tree)

Reduce model parameters across all workers/processes in the whole group (e.g.: across all nodes).

save(path, *[, state_dict])

Save the model to the specified path.

update_parameters(model, *[, polyak])

Update internal parameters by hard or soft (polyak averaging) update.

Attributes:

action_space, device, observation_space, state_dict, state_space

abstractmethod act(inputs: dict[str, Any], *, role: str = '', params: jax.Array | None = None) -> tuple[jax.Array, dict[str, Any]]

Act according to the specified behavior.

Agents will call this method to get the expected action/value based on the observations/states.

Warning

This method is currently implemented by the helper models (e.g.: GaussianMixin). Classes that inherit from a helper model only need to implement the __call__() method.

Parameters:
  • inputs

    Model inputs. The most common keys are:

    • "observations": observation of the environment used to make the decision.

    • "states": state of the environment used to make the decision.

    • "taken_actions": actions taken by the policy for the given observations/states.

  • role – Role played by the model.

  • params – Parameters used to compute the output. If not provided, internal parameters will be used.

Returns:

Model output. The first component is the expected action/value returned by the model. The second component is a dictionary containing extra output values according to the model.

Raises:

NotImplementedError – This method must be implemented by subclasses.

broadcast_parameters(*, rank: int = 0) -> None

Broadcast model parameters to the whole group (e.g.: across all nodes) in distributed runs.

After calling this method, the distributed model will contain the parameters broadcast from the worker/process with the given rank.

Parameters:

rank – Worker/process rank from which to broadcast model parameters.

Example:

# broadcast model parameters from worker/process with rank 1
>>> if config.jax.is_distributed:
...     model.broadcast_parameters(rank=1)

enable_training_mode(enabled: bool = True) -> None

Set the training mode of the model: enabled (training) or disabled (evaluation).

Parameters:

enabled – True to enable the training mode, False to enable the evaluation mode. The specific behavior can be accessed via the training property.

freeze_parameters(freeze: bool = True) -> None

Freeze or unfreeze internal parameters.

Note

This method does nothing. It exists only to maintain compatibility with other ML frameworks.

Parameters:

freeze – Whether to freeze or unfreeze the internal parameters.

Example:

# freeze model parameters
>>> model.freeze_parameters(True)

# unfreeze model parameters
>>> model.freeze_parameters(False)

get_specification() -> dict[str, Any]

Returns the specification of the model.

The following keys are used by the agents for initialization:

  • "rnn": Recurrent Neural Network (RNN) specification for RNN, LSTM and GRU layers/cells.

    • "sizes": List of RNN shapes (number of layers, number of environments, number of features in the RNN state). There must be as many tuples as there are states in the recurrent layer/cell. E.g.: LSTM has 2 states (hidden and cell).

Returns:

Dictionary containing advanced specification of the model.

Example:

# model with a LSTM layer
# - number of layers: 1
# - number of environments: 4
# - number of features in the RNN state: 64
>>> model.get_specification()
{'rnn': {'sizes': [(1, 4, 64), (1, 4, 64)]}}

init_biases(method_name: str = 'constant_', *args, **kwargs) -> None

Initialize the model biases according to the specified method name.

Method names are from the flax.linen.initializers module. Allowed method names are "uniform", "normal", "constant", etc.

Parameters:
  • method_name

    flax.linen.initializers method name.

  • args – Positional arguments of the method to be called.

  • kwargs – Key-value arguments of the method to be called.

Example:

# initialize all biases with a constant value (0)
>>> model.init_biases(method_name="constant", value=0)

# initialize all biases with normal distribution with mean 0 and standard deviation 0.25
>>> model.init_biases(method_name="normal", stddev=0.25)

init_parameters(method_name: str = 'normal', *args, **kwargs) -> None

Initialize the model parameters according to the specified method name.

Method names are from the flax.linen.initializers module. Allowed method names are "uniform", "normal", "constant", etc.

Parameters:
  • method_name

    flax.linen.initializers method name.

  • args – Positional arguments of the method to be called.

  • kwargs – Key-value arguments of the method to be called.

Example:

# initialize all parameters with an orthogonal initializer with a scale of 0.5
>>> model.init_parameters("orthogonal", scale=0.5)

# initialize all parameters with a normal distribution with a standard deviation of 0.1
>>> model.init_parameters("normal", stddev=0.1)

init_state_dict(inputs: dict[str, Any] = {}, *, role: str = '', key: jax.Array | None = None) -> None

Initialize state dictionary.

Parameters:
  • inputs

    Model inputs. The most common keys are:

    • "observations": observation of the environment used to make the decision.

    • "states": state of the environment used to make the decision.

    • "taken_actions": actions taken by the policy for the given observations/states.

    If not specified, inputs will have random samples from the observation, state and action spaces.

  • role – Role played by the model.

  • key – Pseudo-random number generator (PRNG) key. If not provided, config.jax.key will be used.

init_weights(method_name: str = 'normal', *args, **kwargs) -> None

Initialize the model weights according to the specified method name.

Method names are from the flax.linen.initializers module. Allowed method names are "uniform", "normal", "constant", etc.

Parameters:
  • method_name

    flax.linen.initializers method name.

  • args – Positional arguments of the method to be called.

  • kwargs – Key-value arguments of the method to be called.

Example:

# initialize all weights with uniform distribution in range [0, 0.1)
>>> model.init_weights(method_name="uniform", scale=0.1)

# initialize all weights with normal distribution with mean 0 and standard deviation 0.25
>>> model.init_weights(method_name="normal", stddev=0.25)

load(path: str) -> None

Load the model from the specified path.

Parameters:

path – Path to load the model from.

Example:

# load the model
>>> model = Model(observation_space, action_space)
>>> model.load("model.flax")

random_act(inputs: dict[str, Any], *, role: str = '', params: jax.Array | None = None) -> tuple[jax.Array, dict[str, Any]]

Act randomly according to the action space.

Warning

Sampling from unbounded action spaces may lead to numerical instabilities.

Parameters:
  • inputs

    Model inputs. The most common keys are:

    • "observations": observation of the environment used to make the decision.

    • "states": state of the environment used to make the decision.

    • "taken_actions": actions taken by the policy for the given observations/states.

  • role – Role played by the model.

  • params – Parameters used to compute the output. If not provided, internal parameters will be used.

Returns:

Model output. The first component is the randomly sampled actions, with the same batch size as the "observations" entry of the inputs. The second component is an empty dictionary.

Raises:

ValueError – Unsupported action space.

reduce_parameters(tree: Any) -> Any

Reduce model parameters across all workers/processes in the whole group (e.g.: across all nodes).

After calling this method, the distributed model parameters will be bitwise identical for all workers/processes.

Parameters:

tree – Pytree to apply collective reduction.

Returns:

All-reduced pytree.

Example:

# reduce a pytree (e.g.: gradients) across all workers/processes
>>> if config.jax.is_distributed:
...     grad = model.reduce_parameters(grad)
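
What a collective reduction over a pytree produces can be sketched with nested dictionaries standing in for pytrees. The reduction operation is not stated here, so the element-wise mean used below is an assumption, and tree_mean is a hypothetical helper that simulates all workers inside a single process.

```python
def tree_mean(trees: list):
    """Average a list of identically structured nested dicts, leaf by leaf."""
    first = trees[0]
    if isinstance(first, dict):
        return {key: tree_mean([tree[key] for tree in trees]) for key in first}
    # Leaf: plain number (assumed mean reduction).
    return sum(trees) / len(trees)


# One pytree per simulated worker/process.
worker_grads = [
    {"dense": {"kernel": 1.0, "bias": 0.0}},
    {"dense": {"kernel": 3.0, "bias": 2.0}},
]

reduced = tree_mean(worker_grads)
print(reduced)
```

Every worker would receive the same reduced tree, which is why the returned value, rather than the input, should be used afterwards.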

save(path: str, *, state_dict: dict[str, Any] | None = None) -> None

Save the model to the specified path.

Parameters:
  • path – Path to save the model to.

  • state_dict – State dictionary to save. If None, the model’s state_dict will be saved.

Example:

# save the current model to the specified path
>>> model.save("/tmp/model.flax")

update_parameters(model: flax.linen.Module, *, polyak: float = 1.0) -> None

Update internal parameters by hard or soft (polyak averaging) update.

  • Hard update: \(\theta = \theta_{net}\)

  • Soft (polyak averaging) update: \(\theta = (1 - \rho) \theta + \rho \theta_{net}\)

Parameters:
  • model – Model used to update the internal parameters.

  • polyak – Polyak hyperparameter \(\rho\) between 0 and 1. A hard update is performed when its value is 1.

Example:

# hard update (from source model)
>>> model.update_parameters(source_model)

# soft update (from source model)
>>> model.update_parameters(source_model, polyak=0.005)

action_space: gymnasium.Space | None = None
device: str | jax.Device | None = None
observation_space: gymnasium.Space | None = None
state_dict: StateDict
state_space: gymnasium.Space | None = None

Warp

Model

Base model class for implementing custom models.