Model instantiators

Utilities for quickly creating model instances.



Implemented model instantiators

The following table lists the implemented model instantiators and the frameworks that support them (\(\blacksquare\): supported, \(\square\): not supported).

| Models | pytorch | jax | warp |
|---|---|---|---|
| Tabular model (discrete domain) | \(\blacksquare\) | \(\square\) | \(\square\) |
| Categorical model (discrete domain) | \(\blacksquare\) | \(\blacksquare\) | \(\square\) |
| Multi-Categorical model (discrete domain) | \(\blacksquare\) | \(\blacksquare\) | \(\square\) |
| Gaussian model (continuous domain) | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\) |
| Multivariate Gaussian model (continuous domain) | \(\blacksquare\) | \(\square\) | \(\square\) |
| Deterministic model (continuous domain) | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\) |
| Shared model | \(\blacksquare\) | \(\square\) | \(\blacksquare\) |


Network definitions

The network is composed of one or more containers. For each container, the input, hidden layers and activation functions can be specified.

Implementation details:

  • Container names must be valid Python identifiers and unique within the model.

  • The forward pass (compute) of the network calls the containers in the order in which they are defined.

  • Containers use torch.nn.Sequential in PyTorch, and flax.linen.Sequential in JAX.

  • If a single activation function is specified (mapping or sequence), it will be applied after each layer (except flatten layers) in the container.

network:
  - name: <NAME>  # container name
    input: <INPUT>  # container input (certain operations are supported)
    layers:  # list of supported layers
      - <LAYER 1>
      - ...
      - <LAYER N>
    activations:  # list of supported activation functions
      - <ACTIVATION 1>
      - ...
      - <ACTIVATION N>
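As a rough illustration of the schema above, the same definition can be written as the list of dictionaries accepted by the network parameter (typed list[dict[str, Any]] in the API section). The check_container_names helper is hypothetical and only mirrors the naming rule from the implementation details; it is not part of the library. The container names, layer sizes and the elu activation are arbitrary example values.

```python
from typing import Any

# Network definition with two containers; the second one consumes the output
# of the first by referring to its name.
network: list[dict[str, Any]] = [
    {
        "name": "features_extractor",
        "input": "OBSERVATIONS",
        "layers": [64, 64],
        "activations": "elu",
    },
    {
        "name": "net",
        "input": "concatenate([features_extractor, ACTIONS])",
        "layers": [32],
        "activations": "elu",
    },
]


def check_container_names(definition: list[dict[str, Any]]) -> list[str]:
    """Hypothetical helper: return the container names in definition order,
    raising ValueError if a name is not a valid identifier or is repeated."""
    names: list[str] = []
    for container in definition:
        name = container["name"]
        if not name.isidentifier():
            raise ValueError(f"invalid container name: {name!r}")
        if name in names:
            raise ValueError(f"duplicated container name: {name!r}")
        names.append(name)
    return names


print(check_container_names(network))
```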

Inputs

Inputs can be specified using tokens or the outputs of previously defined containers (referenced by container name). Certain operations, including indexing and slicing, can be applied to them.

Hint

Operations can be mixed to create complex input statements.

Available tokens:

  • OBSERVATIONS: Unflattened tensorized input observations (tensorized observation space) forwarded to the model.

  • STATES: Unflattened tensorized input states (tensorized state space) forwarded to the model.

  • ACTIONS: Unflattened tensorized input taken_actions (tensorized action space) forwarded to the model.

  • OBSERVATION_SPACE: Token indicating the observation_space of the model.

  • STATE_SPACE: Token indicating the state_space of the model.

  • ACTION_SPACE: Token indicating the action_space of the model.

Supported operations:

| Operations | Example |
|---|---|
| Tensor/array indexing and slicing. E.g.: Box space | OBSERVATIONS[:, 0]; OBSERVATIONS[:, 2:5] |
| Dictionary indexing by key. E.g.: Dict space | STATES["joint-pos"] |
| Arithmetic (+, -, *, /) | features_extractor + ACTIONS |
| Concatenation | concatenate([features_extractor, ACTIONS]) |
| Permute dimensions | permute(OBSERVATIONS, (0, 3, 1, 2)) |
| One-hot encoding of Discrete and MultiDiscrete spaces | one_hot_encoding(OBSERVATION_SPACE, OBSERVATIONS) |
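To make the one-hot encoding operation listed above concrete, here is a plain-Python sketch of what it does to a batch of observations from a Discrete space. The function name one_hot and the use of nested lists instead of tensors are illustrative assumptions; the library's own one_hot_encoding works on the tensorized space.

```python
def one_hot(n: int, indices: list[int]) -> list[list[float]]:
    """Encode each index in [0, n) as a length-n vector with a single 1.0."""
    encoded = []
    for index in indices:
        if not 0 <= index < n:
            raise ValueError(f"index {index} outside Discrete({n})")
        row = [0.0] * n
        row[index] = 1.0
        encoded.append(row)
    return encoded


# A batch of three observations from a Discrete(4) space
batch = one_hot(4, [2, 0, 3])
for row in batch:
    print(row)
```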


Outputs

Outputs can be specified using tokens or the outputs of previously defined containers (referenced by container name). Certain operations can be applied to them.

Note

If a token is used, a linear layer is created whose number of input features is the output size of the last container in the list and whose number of output features is the value represented by the token.

Hint

Operations can be mixed to create complex output statements.

Available tokens:

  • ACTIONS: Token indicating that the output shape is the number of elements in the action space.

  • ONE: Token indicating that the output shape is 1.
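The note above can be sketched in a few lines. The helper output_layer_shape is hypothetical, and the sizes (a last container with 32 output features, an action space with 4 elements) are made-up values chosen only to show which number ends up where.

```python
def output_layer_shape(last_container_features: int, token: str, num_actions: int) -> tuple[int, int]:
    """Hypothetical helper: (in_features, out_features) of the linear layer
    created when the output expression uses a token."""
    out_features = {"ACTIONS": num_actions, "ONE": 1}[token]
    return last_container_features, out_features


print(output_layer_shape(32, "ACTIONS", num_actions=4))  # policy-like head
print(output_layer_shape(32, "ONE", num_actions=4))      # value-like head
```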

Supported operations:

| Operations | Example |
|---|---|
| Activation function | tanh(ACTIONS) |
| Arithmetic (+, -, *, /) | features_extractor + ONE |
| Concatenation | concatenate([features_extractor, net]) |


Activation functions

The following table lists the supported activation functions:

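| Activations | pytorch | jax | warp |
|---|---|---|---|
| relu | ReLU | relu | ReLU |
| tanh | Tanh | tanh | Tanh |
| sigmoid | Sigmoid | sigmoid | - |
| leaky_relu | LeakyReLU | leaky_relu | - |
| elu | ELU | elu | ELU |
| softplus | Softplus | softplus | - |
| softsign | Softsign | soft_sign | - |
| selu | SELU | selu | - |
| softmax | Softmax | softmax | - |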
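For reference, the standard elementwise definitions behind several of these names can be written with the math module alone. This is a scalar sketch for intuition, not the framework implementations; the default values negative_slope=0.01 and alpha=1.0 are the commonly used defaults and are stated here as assumptions.

```python
import math


def relu(x: float) -> float:
    return max(0.0, x)


def leaky_relu(x: float, negative_slope: float = 0.01) -> float:
    return x if x >= 0 else negative_slope * x


def elu(x: float, alpha: float = 1.0) -> float:
    return x if x > 0 else alpha * (math.exp(x) - 1.0)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def softplus(x: float) -> float:
    return math.log1p(math.exp(x))


def softsign(x: float) -> float:
    return x / (1.0 + abs(x))


def softmax(xs: list[float]) -> list[float]:
    # Subtract the maximum before exponentiating for numerical stability
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]
```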


Layers

The following table lists the supported layers and transformations:

| Layers | pytorch | jax | warp |
|---|---|---|---|
| linear | Linear | Dense | Linear |
| conv2d | Conv2d | Conv | - |
| flatten | Flatten | reshape | Flatten |


linear

Apply a linear transformation (torch.nn.Linear in PyTorch, flax.linen.Dense in JAX).

Note

The tokens NUM_OBSERVATIONS (number of elements in the observation space), NUM_STATES (number of elements in the state space), NUM_ACTIONS (number of elements in the action space), and ONE (1) can be used as the layer’s number of input/output features.

Note

If PyTorch’s in_features parameter is not specified, it is inferred by using the torch.nn.LazyLinear module.

| Position | pytorch | jax | warp | Type | Required | Description |
|---|---|---|---|---|---|---|
|  | in_features | - | in_features | int | \(\square\) | Number of input features |
| 0 | out_features | features | out_features | int | \(\blacksquare\) | Number of output features |
| 1 | bias | use_bias | bias | bool | \(\square\) | Whether to add a bias |

layers:
  - 32
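Because a layer entry only needs the number of output features, the input size of every linear layer follows from whatever precedes it. The sketch below mimics that inference in plain Python; linear_shapes is a hypothetical helper, and the input size 8 and the layer sizes are arbitrary example values.

```python
def linear_shapes(input_features: int, layers: list[int]) -> list[tuple[int, int]]:
    """Hypothetical helper: (in_features, out_features) for a stack of linear
    layers given only their output sizes, as in `layers: [64, 32]`."""
    shapes = []
    in_features = input_features
    for out_features in layers:
        shapes.append((in_features, out_features))
        in_features = out_features  # the next layer consumes this layer's output
    return shapes


print(linear_shapes(8, [64, 32]))
```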

conv2d

Apply a 2D convolution (torch.nn.Conv2d in PyTorch, flax.linen.Conv in JAX).

Warning

  • PyTorch torch.nn.Conv2d expects the input to be in the form NCHW (N: batch, C: channels, H: height, W: width). A permutation operation may be necessary to modify the dimensions of a batch of images which are typically NHWC.

  • JAX flax.linen.Conv expects the input to be in the form NHWC (the typical dimensions of a batch of images).
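The permutation mentioned in the first bullet is the one shown among the input operations, permute(OBSERVATIONS, (0, 3, 1, 2)). The plain-Python sketch below only tracks shapes to show why that order turns NHWC into NCHW; permuted_shape is a hypothetical helper and the batch of 84x84 RGB images is an example value.

```python
def permuted_shape(shape: tuple[int, ...], dims: tuple[int, ...]) -> tuple[int, ...]:
    """Shape after a permutation: output axis i takes input axis dims[i]."""
    if sorted(dims) != list(range(len(shape))):
        raise ValueError("dims must be a permutation of the axes")
    return tuple(shape[d] for d in dims)


nhwc = (16, 84, 84, 3)  # batch, height, width, channels
nchw = permuted_shape(nhwc, (0, 3, 1, 2))
print(nchw)
```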

Note

If PyTorch’s in_channels parameter is not specified, it is inferred by using the torch.nn.LazyConv2d module.

| Position | pytorch | jax | Type | Required | Description |
|---|---|---|---|---|---|
|  | in_channels | - | int | \(\square\) | Number of input channels |
| 0 | out_channels | features | int | \(\blacksquare\) | Number of output channels (filters) |
| 1 | kernel_size | kernel_size | int, tuple[int] | \(\blacksquare\) | Convolutional kernel size |
| 2 | stride | strides | int, tuple[int] | \(\square\) | Inter-window strides |
| 3 | padding | padding | str, int, tuple[int] | \(\square\) | Padding added to all dimensions |
| 4 | bias | use_bias | bool | \(\square\) | Whether to add a bias |

layers:
  - conv2d: [32, 8, [4, 4]]
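Reading the example against the parameter table, the list [32, 8, [4, 4]] appears to fill the numbered parameters in order. The sketch below spells that mapping out using the PyTorch parameter names; CONV2D_POSITIONAL and expand_conv2d are hypothetical names introduced for illustration, and the assumption that list positions follow the numbered rows is inferred from the table rather than taken from the library source.

```python
from typing import Any

# PyTorch parameter names in positional order (in_channels is left out because
# it has no position in the table and can be inferred)
CONV2D_POSITIONAL = ("out_channels", "kernel_size", "stride", "padding", "bias")


def expand_conv2d(args: list[Any]) -> dict[str, Any]:
    """Hypothetical helper: name the positional values of a conv2d entry."""
    if not 2 <= len(args) <= len(CONV2D_POSITIONAL):
        raise ValueError("conv2d needs at least out_channels and kernel_size")
    return dict(zip(CONV2D_POSITIONAL, args))


print(expand_conv2d([32, 8, [4, 4]]))
```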

flatten

Flatten a contiguous range of dimensions (torch.nn.Flatten in PyTorch, jax.numpy.reshape operation in JAX).

| Position | pytorch | jax | warp | Type | Required | Description |
|---|---|---|---|---|---|---|
| 0 | start_dim | - | - | int | \(\square\) | First dimension to flatten |
| 1 | end_dim | - | - | int | \(\square\) | Last dimension to flatten |

layers:
  - flatten
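A shape-only sketch of what flattening a contiguous range of dimensions means. flattened_shape is a hypothetical helper; the defaults start_dim=1 and end_dim=-1 are those of torch.nn.Flatten, which leave the batch dimension intact.

```python
import math


def flattened_shape(shape: tuple[int, ...], start_dim: int = 1, end_dim: int = -1) -> tuple[int, ...]:
    """Shape after merging dimensions start_dim..end_dim (inclusive) into one."""
    ndim = len(shape)
    start = start_dim % ndim  # normalize negative indices
    end = end_dim % ndim
    if start > end:
        raise ValueError("start_dim must not come after end_dim")
    merged = math.prod(shape[start:end + 1])
    return shape[:start] + (merged,) + shape[end + 1:]


print(flattened_shape((16, 64, 7, 7)))
```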

API


PyTorch

  • categorical_model: Instantiate a CategoricalMixin-based model.

  • multicategorical_model: Instantiate a MultiCategoricalMixin-based model.

  • deterministic_model: Instantiate a DeterministicMixin-based model.

  • gaussian_model: Instantiate a GaussianMixin-based model.

  • multivariate_gaussian_model: Instantiate a MultivariateGaussianMixin-based model.

  • shared_model: Instantiate a shared model.

  • tabular_model: Instantiate a TabularMixin-based model.

skrl.utils.model_instantiators.torch.categorical_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, unnormalized_log_prob: bool = True, network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str

Instantiate a CategoricalMixin-based model.

Parameters:
  • observation_space – Observation space. The num_observations property will contain the size of the space.

  • state_space – State space. The num_states property will contain the size of the space.

  • action_space – Action space. The num_actions property will contain the size of the space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

  • unnormalized_log_prob – Flag to indicate how the model’s output will be interpreted. If True, the model’s output is interpreted as unnormalized log probabilities (it can be any real number), otherwise as normalized probabilities (the output must be non-negative, finite and have a non-zero sum).

  • network – Network definition.

  • output – Output expression.

  • return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.

Returns:

Categorical model instance or definition source (if return_source is True).
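The two interpretations selected by unnormalized_log_prob can be compared in plain Python. This sketch assumes the conventional conversions (softmax for unnormalized log probabilities, division by the sum for probabilities); the function names are hypothetical and the numbers are arbitrary.

```python
import math


def probs_from_logits(logits: list[float]) -> list[float]:
    """unnormalized_log_prob=True: any real numbers, normalized via softmax."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def probs_from_weights(weights: list[float]) -> list[float]:
    """unnormalized_log_prob=False: non-negative, finite values with a
    non-zero sum, normalized by dividing by that sum."""
    if any(w < 0 or not math.isfinite(w) for w in weights):
        raise ValueError("weights must be non-negative and finite")
    total = sum(weights)
    if total == 0:
        raise ValueError("weights must have a non-zero sum")
    return [w / total for w in weights]


# Both calls describe the same distribution over three actions
print(probs_from_logits([0.0, 0.0, math.log(2.0)]))
print(probs_from_weights([1.0, 1.0, 2.0]))
```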

skrl.utils.model_instantiators.torch.multicategorical_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, unnormalized_log_prob: bool = True, reduction: Literal['mean', 'sum', 'prod', 'none'] = 'sum', network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str

Instantiate a MultiCategoricalMixin-based model.

Parameters:
  • observation_space – Observation space. The num_observations property will contain the size of the space.

  • state_space – State space. The num_states property will contain the size of the space.

  • action_space – Action space. The num_actions property will contain the size of the space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

  • unnormalized_log_prob – Flag to indicate how the model’s output will be interpreted. If True, the model’s output is interpreted as unnormalized log probabilities (it can be any real number), otherwise as normalized probabilities (the output must be non-negative, finite and have a non-zero sum).

  • reduction – Reduction method for returning the log probability density function. If "none", the log probability density function is returned as a tensor of shape (num_samples, num_actions) instead of (num_samples, 1).

  • network – Network definition.

  • output – Output expression.

  • return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.

Returns:

MultiCategorical model instance or definition source (if return_source is True).
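The effect of reduction on per-action log probabilities can be sketched without any framework. reduce_log_prob is a hypothetical helper working on nested lists: with "none" the (num_samples, num_actions) values are returned unchanged, otherwise each row collapses to a single value, giving shape (num_samples, 1).

```python
import math


def reduce_log_prob(log_prob: list[list[float]], reduction: str = "sum") -> list[list[float]]:
    """Collapse each row of per-action log probabilities to one value."""
    if reduction == "none":
        return log_prob
    reducers = {
        "sum": sum,
        "mean": lambda row: sum(row) / len(row),
        "prod": math.prod,
    }
    return [[reducers[reduction](row)] for row in log_prob]


rows = [[-1.0, -2.0], [-0.5, -1.5]]  # 2 samples, 2 actions
print(reduce_log_prob(rows, "sum"))
print(reduce_log_prob(rows, "mean"))
print(reduce_log_prob(rows, "none"))
```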

skrl.utils.model_instantiators.torch.deterministic_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, clip_actions: bool = False, network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str

Instantiate a DeterministicMixin-based model.

Parameters:
  • observation_space – Observation space. The num_observations property will contain the size of the space.

  • state_space – State space. The num_states property will contain the size of the space.

  • action_space – Action space. The num_actions property will contain the size of the space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

  • clip_actions – Flag to indicate whether the actions should be clipped to the action space.

  • network – Network definition.

  • output – Output expression.

  • return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.

Returns:

Deterministic model instance or definition source (if return_source is True).

skrl.utils.model_instantiators.torch.gaussian_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, clip_actions: bool = False, clip_mean_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20, max_log_std: float = 2, reduction: Literal['mean', 'sum', 'prod', 'none'] = 'sum', initial_log_std: float = 0, fixed_log_std: bool = False, network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str

Instantiate a GaussianMixin-based model.

Parameters:
  • observation_space – Observation space. The num_observations property will contain the size of the space.

  • state_space – State space. The num_states property will contain the size of the space.

  • action_space – Action space. The num_actions property will contain the size of the space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

  • clip_actions – Flag to indicate whether the actions should be clipped to the action space.

  • clip_mean_actions – Flag to indicate whether the mean actions should be clipped to the action space. If True, the mean actions will be clipped before sampling the actions.

  • clip_log_std – Flag to indicate whether the log standard deviations should be clipped.

  • min_log_std – Minimum value of the log standard deviation if clip_log_std is True.

  • max_log_std – Maximum value of the log standard deviation if clip_log_std is True.

  • reduction – Reduction method for returning the log probability density function. If "none", the log probability density function is returned as a tensor of shape (num_samples, num_actions) instead of (num_samples, 1).

  • initial_log_std – Initial value for the log standard deviation.

  • fixed_log_std – Whether the log standard deviation parameter should be fixed. Fixed parameters have the gradient computation deactivated.

  • network – Network definition.

  • output – Output expression.

  • return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.

Returns:

Gaussian model instance or definition source (if return_source is True).
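How clip_log_std, min_log_std and max_log_std interact can be shown with scalars. The sketch below uses the default bounds from the signature (-20 and 2); std_from_log_std is a hypothetical helper, not the library's code.

```python
import math


def std_from_log_std(log_std: float, clip_log_std: bool = True,
                     min_log_std: float = -20.0, max_log_std: float = 2.0) -> float:
    """Standard deviation obtained from a (possibly clipped) log std."""
    if clip_log_std:
        log_std = min(max(log_std, min_log_std), max_log_std)
    return math.exp(log_std)


print(std_from_log_std(0.0))  # initial_log_std=0 gives a unit standard deviation
print(std_from_log_std(5.0))  # clipped to max_log_std before exponentiating
print(std_from_log_std(5.0, clip_log_std=False))
```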

skrl.utils.model_instantiators.torch.multivariate_gaussian_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, clip_actions: bool = False, clip_mean_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20, max_log_std: float = 2, initial_log_std: float = 0, fixed_log_std: bool = False, network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str

Instantiate a MultivariateGaussianMixin-based model.

Parameters:
  • observation_space – Observation space. The num_observations property will contain the size of the space.

  • state_space – State space. The num_states property will contain the size of the space.

  • action_space – Action space. The num_actions property will contain the size of the space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

  • clip_actions – Flag to indicate whether the actions should be clipped to the action space.

  • clip_mean_actions – Flag to indicate whether the mean actions should be clipped to the action space. If True, the mean actions will be clipped before sampling the actions.

  • clip_log_std – Flag to indicate whether the log standard deviations should be clipped.

  • min_log_std – Minimum value of the log standard deviation if clip_log_std is True.

  • max_log_std – Maximum value of the log standard deviation if clip_log_std is True.

  • initial_log_std – Initial value for the log standard deviation.

  • fixed_log_std – Whether the log standard deviation parameter should be fixed. Fixed parameters have the gradient computation deactivated.

  • network – Network definition.

  • output – Output expression.

  • return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.

Returns:

Multivariate Gaussian model instance or definition source (if return_source is True).

skrl.utils.model_instantiators.torch.shared_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, structure: list[str] = ['GaussianMixin', 'DeterministicMixin'], roles: list[str] = [], parameters: list[dict[str, Any]] = [], single_forward_pass: bool = True, return_source: bool = False) -> Model | str

Instantiate a shared model.

Parameters:
  • observation_space – Observation space. The num_observations property will contain the size of the space.

  • state_space – State space. The num_states property will contain the size of the space.

  • action_space – Action space. The num_actions property will contain the size of the space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

  • structure – Shared model structure.

  • roles – Organized list of model roles.

  • parameters – Organized list of model instantiator parameters.

  • single_forward_pass – Whether to perform a single forward-pass for the shared layers/network.

  • return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.

Returns:

Shared model instance or definition source (if return_source is True).
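One plausible reading of structure, roles and parameters is that they are parallel lists, with entry i of each describing the same role. The sketch below arranges the arguments that way and checks the alignment. pair_roles is hypothetical, the role names policy and value are assumptions, and the per-role dictionaries simply reuse parameter names from gaussian_model and deterministic_model above.

```python
from typing import Any

structure = ["GaussianMixin", "DeterministicMixin"]  # default from the signature
roles = ["policy", "value"]  # assumed role names
shared_net = [{"name": "net", "input": "OBSERVATIONS", "layers": [64, 64], "activations": "elu"}]
parameters: list[dict[str, Any]] = [
    # gaussian_model-style arguments for the first role
    {"clip_actions": False, "network": shared_net, "output": "ACTIONS"},
    # deterministic_model-style arguments for the second role
    {"clip_actions": False, "network": shared_net, "output": "ONE"},
]


def pair_roles(structure: list[str], roles: list[str],
               parameters: list[dict[str, Any]]) -> dict[str, tuple[str, str]]:
    """Hypothetical check: one mixin and one parameter set per role."""
    if not len(structure) == len(roles) == len(parameters):
        raise ValueError("structure, roles and parameters must have the same length")
    return {
        role: (mixin, params["output"])
        for role, mixin, params in zip(roles, structure, parameters)
    }


print(pair_roles(structure, roles, parameters))
```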

skrl.utils.model_instantiators.torch.tabular_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, variant: Literal['epsilon-greedy'] = 'epsilon-greedy', variant_kwargs: dict[str, Any] = {}, return_source: bool = False) -> Model | str

Instantiate a TabularMixin-based model.

Supported variants:

  • epsilon-greedy: Simple method of balancing exploration and exploitation by randomly selecting one or the other.

    | Argument | Type | Default | Description |
    |---|---|---|---|
    | epsilon | float | 0.1 | Cut-off probability for choosing to explore |

Parameters:
  • observation_space – Observation space. The num_observations property will contain the size of the space.

  • state_space – State space. The num_states property will contain the size of the space.

  • action_space – Action space. The num_actions property will contain the size of the space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

  • variant – Variant of the model.

  • variant_kwargs – Variant-specific keyword arguments.

  • return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.

Returns:

Tabular model instance or definition source (if return_source is True).
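The epsilon-greedy variant described above can be illustrated with a single row of a Q-table. This is a generic textbook sketch in plain Python, not the code generated by tabular_model; epsilon_greedy and its rng argument are hypothetical, and only the default epsilon=0.1 comes from the table of variant arguments.

```python
import random
from typing import Optional


def epsilon_greedy(q_values: list[float], epsilon: float = 0.1,
                   rng: Optional[random.Random] = None) -> int:
    """Pick a random action with probability epsilon, else the greedy one."""
    rng = rng or random.Random()
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))  # explore
    return max(range(len(q_values)), key=q_values.__getitem__)  # exploit


q_row = [0.1, 0.5, 0.9]  # Q-values of one state for three actions
print(epsilon_greedy(q_row, epsilon=0.0))  # never explores, so the greedy action
```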


JAX

  • categorical_model: Instantiate a CategoricalMixin-based model.

  • multicategorical_model: Instantiate a MultiCategoricalMixin-based model.

  • deterministic_model: Instantiate a DeterministicMixin-based model.

  • gaussian_model: Instantiate a GaussianMixin-based model.

skrl.utils.model_instantiators.jax.categorical_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | jax.Device | None = None, unnormalized_log_prob: bool = True, network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str

Instantiate a CategoricalMixin-based model.

Parameters:
  • observation_space – Observation space. The num_observations property will contain the size of the space.

  • state_space – State space. The num_states property will contain the size of the space.

  • action_space – Action space. The num_actions property will contain the size of the space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

  • unnormalized_log_prob – Flag to indicate how the model’s output will be interpreted. If True, the model’s output is interpreted as unnormalized log probabilities (it can be any real number), otherwise as normalized probabilities (the output must be non-negative, finite and have a non-zero sum).

  • network – Network definition.

  • output – Output expression.

  • return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.

Returns:

Categorical model instance or definition source (if return_source is True).

skrl.utils.model_instantiators.jax.multicategorical_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | jax.Device | None = None, unnormalized_log_prob: bool = True, reduction: Literal['mean', 'sum', 'prod', 'none'] = 'sum', network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str

Instantiate a MultiCategoricalMixin-based model.

Parameters:
  • observation_space – Observation space. The num_observations property will contain the size of the space.

  • state_space – State space. The num_states property will contain the size of the space.

  • action_space – Action space. The num_actions property will contain the size of the space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

  • unnormalized_log_prob – Flag to indicate how the model’s output will be interpreted. If True, the model’s output is interpreted as unnormalized log probabilities (it can be any real number), otherwise as normalized probabilities (the output must be non-negative, finite and have a non-zero sum).

  • reduction – Reduction method for returning the log probability density function. If "none", the log probability density function is returned as a tensor of shape (num_samples, num_actions) instead of (num_samples, 1).

  • network – Network definition.

  • output – Output expression.

  • return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.

Returns:

MultiCategorical model instance or definition source (if return_source is True).

skrl.utils.model_instantiators.jax.deterministic_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | jax.Device | None = None, clip_actions: bool = False, network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str

Instantiate a DeterministicMixin-based model.

Parameters:
  • observation_space – Observation space. The num_observations property will contain the size of the space.

  • state_space – State space. The num_states property will contain the size of the space.

  • action_space – Action space. The num_actions property will contain the size of the space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

  • clip_actions – Flag to indicate whether the actions should be clipped to the action space.

  • network – Network definition.

  • output – Output expression.

  • return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.

Returns:

Deterministic model instance or definition source (if return_source is True).

skrl.utils.model_instantiators.jax.gaussian_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | jax.Device | None = None, clip_actions: bool = False, clip_mean_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20, max_log_std: float = 2, reduction: Literal['mean', 'sum', 'prod', 'none'] = 'sum', initial_log_std: float = 0, fixed_log_std: bool = False, network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str

Instantiate a GaussianMixin-based model.

Parameters:
  • observation_space – Observation space. The num_observations property will contain the size of the space.

  • state_space – State space. The num_states property will contain the size of the space.

  • action_space – Action space. The num_actions property will contain the size of the space.

  • device – Data allocation and computation device. If not specified, the default device will be used.

  • clip_actions – Flag to indicate whether the actions should be clipped to the action space.

  • clip_mean_actions – Flag to indicate whether the mean actions should be clipped to the action space. If True, the mean actions will be clipped before sampling the actions.

  • clip_log_std – Flag to indicate whether the log standard deviations should be clipped.

  • min_log_std – Minimum value of the log standard deviation if clip_log_std is True.

  • max_log_std – Maximum value of the log standard deviation if clip_log_std is True.

  • reduction – Reduction method for returning the log probability density function. If "none", the log probability density function is returned as a tensor of shape (num_samples, num_actions) instead of (num_samples, 1).

  • initial_log_std – Initial value for the log standard deviation.

  • fixed_log_std – Whether the log standard deviation parameter should be fixed. Fixed parameters have the gradient computation deactivated.

  • network – Network definition.

  • output – Output expression.

  • return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.

Returns:

Gaussian model instance or definition source (if return_source is True).


Warp

  • deterministic_model: Instantiate a DeterministicMixin-based model.

  • gaussian_model: Instantiate a GaussianMixin-based model.