Model instantiators
Utilities for quickly creating model instances.
Implemented model instantiators
The following table lists the implemented model instantiators and their support for different frameworks.
| Models | PyTorch | JAX | Warp |
|---|---|---|---|
| Tabular model (discrete domain) | \(\blacksquare\) | \(\square\) | \(\square\) |
| Categorical model (discrete domain) | \(\blacksquare\) | \(\blacksquare\) | \(\square\) |
| Multi-Categorical model (discrete domain) | \(\blacksquare\) | \(\blacksquare\) | \(\square\) |
| Gaussian model (continuous domain) | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\) |
| Multivariate Gaussian model (continuous domain) | \(\blacksquare\) | \(\square\) | \(\square\) |
| Deterministic model (continuous domain) | \(\blacksquare\) | \(\blacksquare\) | \(\blacksquare\) |
| Shared model | \(\blacksquare\) | \(\square\) | \(\blacksquare\) |
Network definitions
The network is composed of one or more containers. For each container, the input, hidden layers, and activation functions can be specified.
Implementation details:
- Container names must be valid Python identifiers and unique within the model.
- The network's compute/forward pass calls the containers in the order in which they are defined.
- Containers use torch.nn.Sequential in PyTorch and flax.linen.Sequential in JAX.
- If a single activation function is specified (mapping or sequence), it is applied after each layer (except flatten layers) in the container.
YAML:

network:
  - name: <NAME>  # container name
    input: <INPUT>  # container input (certain operations are supported)
    layers:  # list of supported layers
      - <LAYER 1>
      - ...
      - <LAYER N>
    activations:  # list of supported activation functions
      - <ACTIVATION 1>
      - ...
      - <ACTIVATION N>

Python:

network=[
    {
        "name": <NAME>,  # container name
        "input": <INPUT>,  # container input (certain operations are supported)
        "layers": [  # list of supported layers
            <LAYER 1>,
            ...,
            <LAYER N>,
        ],
        "activations": [  # list of supported activation functions
            <ACTIVATION 1>,
            ...,
            <ACTIVATION N>,
        ],
    },
]
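The container-name rules listed under the implementation details can be checked before a definition is handed to an instantiator. The following is a minimal sketch, not part of the library: `validate_network` is a hypothetical helper, and the layer sizes and the `"relu"` activation name in the sample definition are assumptions chosen for illustration.

```python
def validate_network(network):
    """Check the container-name rules and return the names in call order."""
    names = []
    for container in network:
        name = container["name"]
        # Rule 1: the name must be a valid Python identifier
        if not name.isidentifier():
            raise ValueError(f"container name {name!r} is not a valid Python identifier")
        # Rule 2: the name must be unique within the model
        if name in names:
            raise ValueError(f"container name {name!r} is not unique within the model")
        names.append(name)
    return names


# Illustrative definition (sizes and activation name are assumptions)
network = [
    {"name": "features", "input": "OBSERVATIONS", "layers": [64, 64], "activations": "relu"},
    {"name": "head", "input": "features", "layers": [32], "activations": "relu"},
]
print(validate_network(network))  # ['features', 'head']
```

The returned list preserves the definition order, which is also the order in which the containers are called.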
Inputs
Inputs can be specified using tokens or the outputs of previously defined containers (referenced by container name). Certain operations, including indexing and slicing, can be applied to them.
Hint
Operations can be mixed to create complex input statements.
Available tokens:
- OBSERVATIONS: Unflattened tensorized input observations (tensorized observation space) forwarded to the model.
- STATES: Unflattened tensorized input states (tensorized state space) forwarded to the model.
- ACTIONS: Unflattened tensorized input taken_actions (tensorized action space) forwarded to the model.
- OBSERVATION_SPACE: Token indicating the observation_space of the model.
- STATE_SPACE: Token indicating the state_space of the model.
- ACTION_SPACE: Token indicating the action_space of the model.
Supported operations:
| Operations | Example |
|---|---|
| Tensor/array indexing and slicing | |
| Dictionary indexing by key | |
| Arithmetic | |
| Concatenation | |
| Permute dimensions | |
| One-hot encoding | |
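How an input name is resolved, either to a token or to the output of an earlier container, can be pictured with a small interpreter. This is only a sketch under stated assumptions: `run_network` and the `"fn"` key (a plain callable standing in for a container's layers) are hypothetical, and no input operations are modelled, only plain names.

```python
def run_network(network, tokens):
    """Call containers in definition order; an input is a token or an earlier container name."""
    values = dict(tokens)  # start with the token values, e.g. {"OBSERVATIONS": ...}
    for container in network:
        source = container["input"]
        if source not in values:
            raise KeyError(f"{source!r} is neither a token nor a previously defined container")
        # Store the container output under its name so later containers can use it
        values[container["name"]] = container["fn"](values[source])
    return values


network = [
    {"name": "net", "input": "OBSERVATIONS", "fn": lambda x: [2 * v for v in x]},
    {"name": "head", "input": "net", "fn": sum},
]
outputs = run_network(network, {"OBSERVATIONS": [1.0, 2.0, 3.0]})
print(outputs["head"])  # 12.0
```

Because containers run in definition order, a container can only refer to names defined before it.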
Outputs
Outputs can be specified using tokens or the outputs of previously defined containers (referenced by container name). Certain operations can be applied to them.
Note
If a token is used, a linear layer is created: its number of input features is taken from the output of the last container in the list, and its number of output features is the value represented by the token.
Hint
Operations can be mixed to create complex output statements.
Available tokens:
- ACTIONS: Token indicating that the output shape is the number of elements in the action space.
- ONE: Token indicating that the output shape is 1.
Supported operations:
| Operations | Example |
|---|---|
| Activation function | |
| Arithmetic | |
| Concatenation | |
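The note on output tokens implies a simple rule for the shape of the final linear layer. The sketch below restates that rule in plain Python; `output_layer_shape` is a hypothetical helper, and the feature counts in the example are assumptions.

```python
def output_layer_shape(last_container_features, token, num_actions):
    """Return (in_features, out_features) of the linear layer created for an output token."""
    # ACTIONS -> number of elements in the action space, ONE -> 1
    sizes = {"ACTIONS": num_actions, "ONE": 1}
    return (last_container_features, sizes[token])


# A last container with 32 output features and an action space with 4 elements (assumed values)
print(output_layer_shape(32, "ACTIONS", num_actions=4))  # (32, 4)
print(output_layer_shape(32, "ONE", num_actions=4))      # (32, 1)
```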
Activation functions
The following table lists the supported activation functions:
(Table: supported activation functions and their per-framework equivalents; the table entries are not available in this copy.)
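The rule that a single activation function is applied after each layer except flatten layers can be sketched as follows. `expand_activations` is a hypothetical helper, and the `"relu"` name is an assumption used only as a stand-in for a supported activation.

```python
def expand_activations(layers, activation):
    """Pair every layer with the activation, leaving flatten layers without one."""
    def is_flatten(layer):
        # flatten can be written as a bare string or as a mapping with a "flatten" key
        if isinstance(layer, str):
            return layer == "flatten"
        return isinstance(layer, dict) and "flatten" in layer

    return [(layer, None if is_flatten(layer) else activation) for layer in layers]


layers = [{"conv2d": [32, 8, [4, 4]]}, "flatten", 64]
for layer, activation in expand_activations(layers, "relu"):
    print(layer, "->", activation)
```

In this sketch the convolution and the linear layer are followed by the activation, while the flatten step is not.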
Layers
The following table lists the supported layers and transformations:
linear
Apply a linear transformation (torch.nn.Linear in PyTorch, flax.linen.Dense in JAX).
Note
The tokens NUM_OBSERVATIONS (number of elements in the observation space), NUM_STATES
(number of elements in the state space), NUM_ACTIONS (number of elements in the action space),
and ONE (1) can be used as the layer’s number of input/output features.
Note
If PyTorch's in_features parameter is not specified, it is inferred by using
the torch.nn.LazyLinear module.
| Position | PyTorch parameter | Required | Description |
|---|---|---|---|
| - | in_features | \(\square\) | Number of input features |
| 0 | out_features | \(\blacksquare\) | Number of output features |
| 1 | bias | \(\square\) | Whether to add a bias |
YAML:

layers:
  - 32

layers:
  - linear: 32

layers:
  - linear: [32]

Hint
The parameter names can be interchanged/mixed between PyTorch and JAX.

layers:
  - linear: {out_features: 32}

Python:

"layers": [
    32,
]

"layers": [
    {"linear": 32},
]

"layers": [
    {"linear": [32]},
]

Hint
The parameter names can be interchanged/mixed between PyTorch and JAX.

"layers": [
    {"linear": {"out_features": 32}},
]
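The four spellings above describe the same layer. A minimal sketch of how they could be reduced to one keyword form is shown below; `normalize_linear` is a hypothetical helper written for illustration and is not the library's parser.

```python
def normalize_linear(layer):
    """Reduce any of the equivalent linear-layer spellings to {'out_features': n}."""
    if isinstance(layer, int):           # 32
        return {"out_features": layer}
    spec = layer["linear"]
    if isinstance(spec, int):            # {"linear": 32}
        return {"out_features": spec}
    if isinstance(spec, (list, tuple)):  # {"linear": [32]}, positional form
        return {"out_features": spec[0]}
    return dict(spec)                    # {"linear": {"out_features": 32}}


spellings = [32, {"linear": 32}, {"linear": [32]}, {"linear": {"out_features": 32}}]
print([normalize_linear(s) for s in spellings])
```

All four entries normalize to the same dictionary, which is why the forms can be used interchangeably.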
conv2d
Apply a 2D convolution (torch.nn.Conv2d in PyTorch, flax.linen.Conv in JAX).
Warning
- PyTorch torch.nn.Conv2d expects the input to be in the form NCHW (N: batch, C: channels, H: height, W: width). A permutation operation may be necessary to modify the dimensions of a batch of images, which are typically NHWC.
- JAX flax.linen.Conv expects the input to be in the form NHWC (the typical dimensions of a batch of images).
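The effect of the permutation mentioned in the warning is easiest to see on shapes alone. The sketch below reorders an NHWC shape into NCHW with plain tuples; `permute_shape` is a hypothetical helper and the image size is an assumption.

```python
def permute_shape(shape, order):
    """Return the shape obtained by reordering dimensions: output dim i is input dim order[i]."""
    return tuple(shape[i] for i in order)


nhwc = (16, 84, 84, 3)                    # batch of 16 RGB images, 84x84 (assumed sizes)
nchw = permute_shape(nhwc, (0, 3, 1, 2))  # keep N, move C in front of H and W
print(nchw)  # (16, 3, 84, 84)
```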
Note
If PyTorch's in_channels parameter is not specified, it is inferred by using
the torch.nn.LazyConv2d module.
| Position | PyTorch parameter | Required | Description |
|---|---|---|---|
| - | in_channels | \(\square\) | Number of input channels |
| 0 | out_channels | \(\blacksquare\) | Number of output channels (filters) |
| 1 | kernel_size | \(\blacksquare\) | Convolutional kernel size |
| 2 | stride | \(\square\) | Inter-window strides |
| 3 | padding | \(\square\) | Padding added to all dimensions |
| 4 | bias | \(\square\) | Whether to add a bias |
YAML:

layers:
  - conv2d: [32, 8, [4, 4]]

Hint
The parameter names can be interchanged/mixed between PyTorch and JAX.

layers:
  - conv2d: {out_channels: 32, kernel_size: 8, stride: [4, 4]}

Python:

"layers": [
    {"conv2d": [32, 8, [4, 4]]},
]

Hint
The parameter names can be interchanged/mixed between PyTorch and JAX.

"layers": [
    {"conv2d": {"out_channels": 32, "kernel_size": 8, "stride": [4, 4]}},
]
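To size the layers that follow a convolution, it helps to know the spatial size it produces. The sketch below applies the standard convolution arithmetic (no dilation) to the example layer above; `conv2d_output_size` is a hypothetical helper and the 84x84 input size is an assumption.

```python
def conv2d_output_size(size, kernel_size, stride=1, padding=0):
    """Output length along one spatial dimension (dilation of 1 assumed)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# conv2d: [32, 8, [4, 4]] -> 32 output channels, kernel size 8, stride 4 in both dimensions
height = conv2d_output_size(84, kernel_size=8, stride=4)
width = conv2d_output_size(84, kernel_size=8, stride=4)
print((32, height, width))  # (32, 20, 20)
```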
flatten
Flatten a contiguous range of dimensions (torch.nn.Flatten in PyTorch,
jax.numpy.reshape operation in JAX).
| Position | PyTorch parameter | Required | Description |
|---|---|---|---|
| 0 | start_dim | \(\square\) | First dimension to flatten |
| 1 | end_dim | \(\square\) | Last dimension to flatten |
YAML:

layers:
  - flatten

layers:
  - flatten: [1, -1]

Hint
The parameter names can be interchanged/mixed between PyTorch and JAX.

layers:
  - flatten: {start_dim: 1, end_dim: -1}

Python:

"layers": [
    "flatten",
]

"layers": [
    {"flatten": [1, -1]},
]

Hint
The parameter names can be interchanged/mixed between PyTorch and JAX.

"layers": [
    {"flatten": {"start_dim": 1, "end_dim": -1}},
]
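The shape produced by flattening a contiguous range of dimensions can be computed directly. The sketch below mirrors the start_dim/end_dim semantics on a shape tuple; `flatten_shape` is a hypothetical helper, its defaults follow those of torch.nn.Flatten, and the input shape is an assumption.

```python
from math import prod


def flatten_shape(shape, start_dim=1, end_dim=-1):
    """Shape after merging dimensions start_dim..end_dim (inclusive) into one."""
    ndim = len(shape)
    start = start_dim % ndim  # resolve negative indices
    end = end_dim % ndim
    return shape[:start] + (prod(shape[start:end + 1]),) + shape[end + 1:]


# flatten: [1, -1] on a batch of 16 feature maps with shape (32, 20, 20)
print(flatten_shape((16, 32, 20, 20), 1, -1))  # (16, 12800)
```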
API
PyTorch
- skrl.utils.model_instantiators.torch.categorical_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, unnormalized_log_prob: bool = True, network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str
Instantiate a CategoricalMixin-based model.
- Parameters:
  - observation_space – Observation space. The num_observations property will contain the size of the space.
  - state_space – State space. The num_states property will contain the size of the space.
  - action_space – Action space. The num_actions property will contain the size of the space.
  - device – Data allocation and computation device. If not specified, the default device will be used.
  - unnormalized_log_prob – Flag to indicate how the model's output will be interpreted. If True, the model's output is interpreted as unnormalized log probabilities (it can be any real number), otherwise as normalized probabilities (the output must be non-negative, finite and have a non-zero sum).
  - network – Network definition.
  - output – Output expression.
  - return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.
- Returns:
  Categorical model instance or definition source (if return_source is True).
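The two interpretations selected by unnormalized_log_prob can be illustrated without any framework. The sketch below is an assumption-laden illustration, not the library's code: `to_probabilities` is a hypothetical helper that turns a model output into a probability vector, using a softmax for unnormalized log probabilities and a division by the sum otherwise.

```python
import math


def to_probabilities(output, unnormalized_log_prob=True):
    """Convert a model output for one sample into action probabilities."""
    if unnormalized_log_prob:
        # Any real numbers are allowed: apply a numerically stable softmax
        shift = max(output)
        weights = [math.exp(v - shift) for v in output]
    else:
        # Values must be non-negative, finite and have a non-zero sum
        if any(v < 0 or not math.isfinite(v) for v in output) or sum(output) == 0:
            raise ValueError("probabilities must be non-negative, finite and have a non-zero sum")
        weights = list(output)
    total = sum(weights)
    return [w / total for w in weights]


print(to_probabilities([0.0, 0.0]))                               # [0.5, 0.5]
print(to_probabilities([2.0, 6.0], unnormalized_log_prob=False))  # [0.25, 0.75]
```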
- skrl.utils.model_instantiators.torch.multicategorical_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, unnormalized_log_prob: bool = True, reduction: Literal['mean', 'sum', 'prod', 'none'] = 'sum', network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str
Instantiate a MultiCategoricalMixin-based model.
- Parameters:
  - observation_space – Observation space. The num_observations property will contain the size of the space.
  - state_space – State space. The num_states property will contain the size of the space.
  - action_space – Action space. The num_actions property will contain the size of the space.
  - device – Data allocation and computation device. If not specified, the default device will be used.
  - unnormalized_log_prob – Flag to indicate how the model's output will be interpreted. If True, the model's output is interpreted as unnormalized log probabilities (it can be any real number), otherwise as normalized probabilities (the output must be non-negative, finite and have a non-zero sum).
  - reduction – Reduction method for returning the log probability density function. If "none", the log probability density function is returned as a tensor of shape (num_samples, num_actions) instead of (num_samples, 1).
  - network – Network definition.
  - output – Output expression.
  - return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.
- Returns:
  MultiCategorical model instance or definition source (if return_source is True).
- skrl.utils.model_instantiators.torch.deterministic_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, clip_actions: bool = False, network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str
Instantiate a DeterministicMixin-based model.
- Parameters:
  - observation_space – Observation space. The num_observations property will contain the size of the space.
  - state_space – State space. The num_states property will contain the size of the space.
  - action_space – Action space. The num_actions property will contain the size of the space.
  - device – Data allocation and computation device. If not specified, the default device will be used.
  - clip_actions – Flag to indicate whether the actions should be clipped to the action space.
  - network – Network definition.
  - output – Output expression.
  - return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.
- Returns:
  Deterministic model instance or definition source (if return_source is True).
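What clipping actions to the action space means can be shown with per-dimension bounds. This is a minimal sketch assuming a box-shaped action space described by `low` and `high` lists; `clip_to_space` is a hypothetical helper, not the library's implementation.

```python
def clip_to_space(actions, low, high):
    """Clamp each action component to its [low, high] interval."""
    return [min(max(a, lo), hi) for a, lo, hi in zip(actions, low, high)]


# Three-dimensional action with bounds [-1, 1] in every dimension (assumed values)
print(clip_to_space([-1.5, 0.2, 3.0], low=[-1.0, -1.0, -1.0], high=[1.0, 1.0, 1.0]))
# [-1.0, 0.2, 1.0]
```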
- skrl.utils.model_instantiators.torch.gaussian_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, clip_actions: bool = False, clip_mean_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20, max_log_std: float = 2, reduction: Literal['mean', 'sum', 'prod', 'none'] = 'sum', initial_log_std: float = 0, fixed_log_std: bool = False, network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str
Instantiate a GaussianMixin-based model.
- Parameters:
  - observation_space – Observation space. The num_observations property will contain the size of the space.
  - state_space – State space. The num_states property will contain the size of the space.
  - action_space – Action space. The num_actions property will contain the size of the space.
  - device – Data allocation and computation device. If not specified, the default device will be used.
  - clip_actions – Flag to indicate whether the actions should be clipped to the action space.
  - clip_mean_actions – Flag to indicate whether the mean actions should be clipped to the action space. If True, the mean actions will be clipped before sampling the actions.
  - clip_log_std – Flag to indicate whether the log standard deviations should be clipped.
  - min_log_std – Minimum value of the log standard deviation if clip_log_std is True.
  - max_log_std – Maximum value of the log standard deviation if clip_log_std is True.
  - reduction – Reduction method for returning the log probability density function. If "none", the log probability density function is returned as a tensor of shape (num_samples, num_actions) instead of (num_samples, 1).
  - initial_log_std – Initial value for the log standard deviation.
  - fixed_log_std – Whether the log standard deviation parameter should be fixed. Fixed parameters have the gradient computation deactivated.
  - network – Network definition.
  - output – Output expression.
  - return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.
- Returns:
  Gaussian model instance or definition source (if return_source is True).
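The interplay of clip_log_std, min_log_std, max_log_std and reduction can be worked through for a single sample. The sketch below is an illustration under assumptions, not the library's code: `gaussian_log_prob` is a hypothetical helper that evaluates the log density of a diagonal Gaussian per action dimension and then reduces it.

```python
import math


def gaussian_log_prob(actions, mean, log_std, reduction="sum",
                      clip_log_std=True, min_log_std=-20.0, max_log_std=2.0):
    """Log density of one sample under a diagonal Gaussian, reduced over action dimensions."""
    per_action = []
    for a, m, ls in zip(actions, mean, log_std):
        if clip_log_std:
            ls = min(max(ls, min_log_std), max_log_std)  # clamp the log standard deviation
        std = math.exp(ls)
        per_action.append(-((a - m) ** 2) / (2 * std ** 2) - ls - 0.5 * math.log(2 * math.pi))
    if reduction == "sum":
        return sum(per_action)
    if reduction == "mean":
        return sum(per_action) / len(per_action)
    if reduction == "prod":
        return math.prod(per_action)
    if reduction == "none":
        return per_action  # one value per action dimension
    raise ValueError(f"unknown reduction: {reduction!r}")


actions, mean, log_std = [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]
print(gaussian_log_prob(actions, mean, log_std))                         # single value
print(len(gaussian_log_prob(actions, mean, log_std, reduction="none")))  # 2
```

With "none" the result keeps one entry per action dimension, which corresponds to the (num_samples, num_actions) shape described above; the other reductions collapse it to a single value per sample.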
- skrl.utils.model_instantiators.torch.multivariate_gaussian_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, clip_actions: bool = False, clip_mean_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20, max_log_std: float = 2, initial_log_std: float = 0, fixed_log_std: bool = False, network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str
Instantiate a MultivariateGaussianMixin-based model.
- Parameters:
  - observation_space – Observation space. The num_observations property will contain the size of the space.
  - state_space – State space. The num_states property will contain the size of the space.
  - action_space – Action space. The num_actions property will contain the size of the space.
  - device – Data allocation and computation device. If not specified, the default device will be used.
  - clip_actions – Flag to indicate whether the actions should be clipped to the action space.
  - clip_mean_actions – Flag to indicate whether the mean actions should be clipped to the action space. If True, the mean actions will be clipped before sampling the actions.
  - clip_log_std – Flag to indicate whether the log standard deviations should be clipped.
  - min_log_std – Minimum value of the log standard deviation if clip_log_std is True.
  - max_log_std – Maximum value of the log standard deviation if clip_log_std is True.
  - initial_log_std – Initial value for the log standard deviation.
  - fixed_log_std – Whether the log standard deviation parameter should be fixed. Fixed parameters have the gradient computation deactivated.
  - network – Network definition.
  - output – Output expression.
  - return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.
- Returns:
  Multivariate Gaussian model instance or definition source (if return_source is True).
Instantiate a shared model.
- Parameters:
  - observation_space – Observation space. The num_observations property will contain the size of the space.
  - state_space – State space. The num_states property will contain the size of the space.
  - action_space – Action space. The num_actions property will contain the size of the space.
  - device – Data allocation and computation device. If not specified, the default device will be used.
  - structure – Shared model structure.
  - roles – Organized list of model roles.
  - parameters – Organized list of model instantiator parameters.
  - single_forward_pass – Whether to perform a single forward-pass for the shared layers/network.
  - return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.
- Returns:
  Shared model instance or definition source (if return_source is True).
- skrl.utils.model_instantiators.torch.tabular_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | torch.device | None = None, variant: Literal['epsilon-greedy'] = 'epsilon-greedy', variant_kwargs: dict[str, Any] = {}, return_source: bool = False) -> Model | str
Instantiate a TabularMixin-based model.
Supported variants:
- epsilon-greedy: Simple method of balancing exploration and exploitation by randomly selecting one or the other.

| Argument | Type | Default | Description |
|---|---|---|---|
| epsilon | float | 0.1 | Cut-off probability for choosing to explore |

- Parameters:
  - observation_space – Observation space. The num_observations property will contain the size of the space.
  - state_space – State space. The num_states property will contain the size of the space.
  - action_space – Action space. The num_actions property will contain the size of the space.
  - device – Data allocation and computation device. If not specified, the default device will be used.
  - variant – Variant of the model.
  - variant_kwargs – Variant-specific keyword arguments.
  - return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.
- Returns:
  Tabular model instance or definition source (if return_source is True).
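The epsilon-greedy variant can be summarized in a few lines. The sketch below assumes a list of action values for the current state and uses epsilon as the cut-off probability for exploring; `epsilon_greedy` is a hypothetical helper and not the model's actual implementation.

```python
import random


def epsilon_greedy(q_values, epsilon=0.1, rng=random):
    """Pick a random action with probability epsilon, otherwise the highest-valued one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))  # explore
    return max(range(len(q_values)), key=q_values.__getitem__)  # exploit


q_values = [0.1, 0.9, 0.3]  # assumed action values for one state
print(epsilon_greedy(q_values, epsilon=0.0))  # 1 (always exploits)
```

Setting epsilon to 0 disables exploration entirely, while setting it to 1 makes every choice random.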
JAX
- skrl.utils.model_instantiators.jax.categorical_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | jax.Device | None = None, unnormalized_log_prob: bool = True, network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str
Instantiate a CategoricalMixin-based model.
- Parameters:
  - observation_space – Observation space. The num_observations property will contain the size of the space.
  - state_space – State space. The num_states property will contain the size of the space.
  - action_space – Action space. The num_actions property will contain the size of the space.
  - device – Data allocation and computation device. If not specified, the default device will be used.
  - unnormalized_log_prob – Flag to indicate how the model's output will be interpreted. If True, the model's output is interpreted as unnormalized log probabilities (it can be any real number), otherwise as normalized probabilities (the output must be non-negative, finite and have a non-zero sum).
  - network – Network definition.
  - output – Output expression.
  - return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.
- Returns:
  Categorical model instance or definition source (if return_source is True).
- skrl.utils.model_instantiators.jax.multicategorical_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | jax.Device | None = None, unnormalized_log_prob: bool = True, reduction: Literal['mean', 'sum', 'prod', 'none'] = 'sum', network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str
Instantiate a MultiCategoricalMixin-based model.
- Parameters:
  - observation_space – Observation space. The num_observations property will contain the size of the space.
  - state_space – State space. The num_states property will contain the size of the space.
  - action_space – Action space. The num_actions property will contain the size of the space.
  - device – Data allocation and computation device. If not specified, the default device will be used.
  - unnormalized_log_prob – Flag to indicate how the model's output will be interpreted. If True, the model's output is interpreted as unnormalized log probabilities (it can be any real number), otherwise as normalized probabilities (the output must be non-negative, finite and have a non-zero sum).
  - reduction – Reduction method for returning the log probability density function. If "none", the log probability density function is returned as a tensor of shape (num_samples, num_actions) instead of (num_samples, 1).
  - network – Network definition.
  - output – Output expression.
  - return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.
- Returns:
  MultiCategorical model instance or definition source (if return_source is True).
- skrl.utils.model_instantiators.jax.deterministic_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | jax.Device | None = None, clip_actions: bool = False, network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str
Instantiate a DeterministicMixin-based model.
- Parameters:
  - observation_space – Observation space. The num_observations property will contain the size of the space.
  - state_space – State space. The num_states property will contain the size of the space.
  - action_space – Action space. The num_actions property will contain the size of the space.
  - device – Data allocation and computation device. If not specified, the default device will be used.
  - clip_actions – Flag to indicate whether the actions should be clipped to the action space.
  - network – Network definition.
  - output – Output expression.
  - return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.
- Returns:
  Deterministic model instance or definition source (if return_source is True).
- skrl.utils.model_instantiators.jax.gaussian_model(*, observation_space: gymnasium.Space | None = None, state_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, device: str | jax.Device | None = None, clip_actions: bool = False, clip_mean_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20, max_log_std: float = 2, reduction: Literal['mean', 'sum', 'prod', 'none'] = 'sum', initial_log_std: float = 0, fixed_log_std: bool = False, network: list[dict[str, Any]] = [], output: str | list[str] = '', return_source: bool = False) -> Model | str
Instantiate a GaussianMixin-based model.
- Parameters:
  - observation_space – Observation space. The num_observations property will contain the size of the space.
  - state_space – State space. The num_states property will contain the size of the space.
  - action_space – Action space. The num_actions property will contain the size of the space.
  - device – Data allocation and computation device. If not specified, the default device will be used.
  - clip_actions – Flag to indicate whether the actions should be clipped to the action space.
  - clip_mean_actions – Flag to indicate whether the mean actions should be clipped to the action space. If True, the mean actions will be clipped before sampling the actions.
  - clip_log_std – Flag to indicate whether the log standard deviations should be clipped.
  - min_log_std – Minimum value of the log standard deviation if clip_log_std is True.
  - max_log_std – Maximum value of the log standard deviation if clip_log_std is True.
  - reduction – Reduction method for returning the log probability density function. If "none", the log probability density function is returned as a tensor of shape (num_samples, num_actions) instead of (num_samples, 1).
  - initial_log_std – Initial value for the log standard deviation.
  - fixed_log_std – Whether the log standard deviation parameter should be fixed. Fixed parameters have the gradient computation deactivated.
  - network – Network definition.
  - output – Output expression.
  - return_source – Whether to return the source string containing the model class used to instantiate the model rather than the model instance.
- Returns:
  Gaussian model instance or definition source (if return_source is True).
Warp