Models
Models (or agent models) are representations of the agent’s policy, value function, etc., that the agent uses to make decisions. An agent can have one or more models, whose parameters are adjusted by the optimization algorithms.
Models | PyTorch | JAX
---|---|---
Tabular model (discrete domain) | \(\blacksquare\) | \(\square\)
Categorical model (discrete domain) | \(\blacksquare\) | \(\blacksquare\)
Gaussian model (continuous domain) | \(\blacksquare\) | \(\blacksquare\)
Multivariate Gaussian model (continuous domain) | \(\blacksquare\) | \(\square\)
Deterministic model (continuous domain) | \(\blacksquare\) | \(\blacksquare\)
 | \(\blacksquare\) | \(\square\)
Base class
Note
This is the base class for all models in this module and provides only basic functionality that is not tied to any specific implementation. It is not intended to be used directly.
Mixin and inheritance
# mixin template (PyTorch)
from typing import Union, Mapping, Tuple, Any

import torch


class CustomMixin:
    def __init__(self, role: str = "") -> None:
        """Custom mixin

        :param role: Role played by the model (default: ``""``)
        :type role: str, optional
        """
        # =====================================
        # - define custom attributes and others
        # =====================================

    def act(self,
            inputs: Mapping[str, Union[torch.Tensor, Any]],
            role: str = "") -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]:
        """Act according to the specified behavior

        :param inputs: Model inputs. The most common keys are:

                       - ``"states"``: state of the environment used to make the decision
                       - ``"taken_actions"``: actions taken by the policy for the given states
        :type inputs: dict where the values are typically torch.Tensor
        :param role: Role played by the model (default: ``""``)
        :type role: str, optional
        :return: Model output. The first component is the action to be taken by the agent.
                 The second component is the log of the probability density function for stochastic models
                 or None for deterministic models. The third component is a dictionary containing extra output values
        :rtype: tuple of torch.Tensor, torch.Tensor or None, and dictionary
        """
        # ==============================
        # - act in response to the state
        # ==============================
# mixin template (JAX)
from typing import Optional, Union, Mapping, Tuple, Any

import flax
import jax.numpy as jnp


class CustomMixin:
    def __init__(self, role: str = "") -> None:
        """Custom mixin

        :param role: Role played by the model (default: ``""``)
        :type role: str, optional
        """
        # =====================================
        # - define custom attributes and others
        # =====================================
        flax.linen.Module.__post_init__(self)

    def act(self,
            inputs: Mapping[str, Union[jnp.ndarray, Any]],
            role: str = "",
            params: Optional[jnp.ndarray] = None) -> Tuple[jnp.ndarray, Union[jnp.ndarray, None], Mapping[str, Union[jnp.ndarray, Any]]]:
        """Act according to the specified behavior

        :param inputs: Model inputs. The most common keys are:

                       - ``"states"``: state of the environment used to make the decision
                       - ``"taken_actions"``: actions taken by the policy for the given states
        :type inputs: dict where the values are typically jnp.ndarray
        :param role: Role played by the model (default: ``""``)
        :type role: str, optional
        :param params: Parameters used to compute the output (default: ``None``).
                       If ``None``, internal parameters will be used
        :type params: jnp.array
        :return: Model output. The first component is the action to be taken by the agent.
                 The second component is the log of the probability density function.
                 The third component is a dictionary containing the mean actions ``"mean_actions"``
                 and extra output values
        :rtype: tuple of jnp.ndarray, jnp.ndarray or None, and dictionary
        """
        # ==============================
        # - act in response to the state
        # ==============================
# model template (PyTorch)
from typing import Optional, Union, Mapping, Sequence, Tuple, Any

import gym, gymnasium
import torch

from skrl.models.torch import Model


class CustomModel(Model):
    def __init__(self,
                 observation_space: Union[int, Sequence[int], gym.Space, gymnasium.Space],
                 action_space: Union[int, Sequence[int], gym.Space, gymnasium.Space],
                 device: Optional[Union[str, torch.device]] = None) -> None:
        """Custom model

        :param observation_space: Observation/state space or shape.
                                  The ``num_observations`` property will contain the size of that space
        :type observation_space: int, sequence of int, gym.Space, gymnasium.Space
        :param action_space: Action space or shape.
                             The ``num_actions`` property will contain the size of that space
        :type action_space: int, sequence of int, gym.Space, gymnasium.Space
        :param device: Device on which a torch tensor is or will be allocated (default: ``None``).
                       If None, the device will be either ``"cuda:0"`` if available or ``"cpu"``
        :type device: str or torch.device, optional
        """
        super().__init__(observation_space, action_space, device)
        # =====================================
        # - define custom attributes and others
        # =====================================

    def act(self,
            inputs: Mapping[str, Union[torch.Tensor, Any]],
            role: str = "") -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]:
        """Act according to the specified behavior

        :param inputs: Model inputs. The most common keys are:

                       - ``"states"``: state of the environment used to make the decision
                       - ``"taken_actions"``: actions taken by the policy for the given states
        :type inputs: dict where the values are typically torch.Tensor
        :param role: Role played by the model (default: ``""``)
        :type role: str, optional
        :return: Model output. The first component is the action to be taken by the agent.
                 The second component is the log of the probability density function for stochastic models
                 or None for deterministic models. The third component is a dictionary containing extra output values
        :rtype: tuple of torch.Tensor, torch.Tensor or None, and dictionary
        """
        # ==============================
        # - act in response to the state
        # ==============================
# model template (JAX)
from typing import Optional, Union, Mapping, Sequence, Tuple, Any

import gym, gymnasium
import flax
import jaxlib
import jax.numpy as jnp

from skrl.models.jax import Model


class CustomModel(Model):
    def __init__(self,
                 observation_space: Union[int, Sequence[int], gym.Space, gymnasium.Space],
                 action_space: Union[int, Sequence[int], gym.Space, gymnasium.Space],
                 device: Optional[Union[str, jaxlib.xla_extension.Device]] = None,
                 parent: Optional[Any] = None,
                 name: Optional[str] = None) -> None:
        """Custom model

        :param observation_space: Observation/state space or shape.
                                  The ``num_observations`` property will contain the size of that space
        :type observation_space: int, sequence of int, gym.Space, gymnasium.Space
        :param action_space: Action space or shape.
                             The ``num_actions`` property will contain the size of that space
        :type action_space: int, sequence of int, gym.Space, gymnasium.Space
        :param device: Device on which a jax array is or will be allocated (default: ``None``).
                       If None, the device will be either ``"cuda:0"`` if available or ``"cpu"``
        :type device: str or jaxlib.xla_extension.Device, optional
        :param parent: The parent Module of this Module (default: ``None``).
                       It is a Flax reserved attribute
        :type parent: str, optional
        :param name: The name of this Module (default: ``None``).
                     It is a Flax reserved attribute
        :type name: str, optional
        """
        Model.__init__(self, observation_space, action_space, device, parent, name)
        # =====================================
        # - define custom attributes and others
        # =====================================
        flax.linen.Module.__post_init__(self)

    def act(self,
            inputs: Mapping[str, Union[jnp.ndarray, Any]],
            role: str = "",
            params: Optional[jnp.ndarray] = None) -> Tuple[jnp.ndarray, Union[jnp.ndarray, None], Mapping[str, Union[jnp.ndarray, Any]]]:
        """Act according to the specified behavior

        :param inputs: Model inputs. The most common keys are:

                       - ``"states"``: state of the environment used to make the decision
                       - ``"taken_actions"``: actions taken by the policy for the given states
        :type inputs: dict where the values are typically jnp.ndarray
        :param role: Role played by the model (default: ``""``)
        :type role: str, optional
        :param params: Parameters used to compute the output (default: ``None``).
                       If ``None``, internal parameters will be used
        :type params: jnp.array
        :return: Model output. The first component is the action to be taken by the agent.
                 The second component is the log of the probability density function.
                 The third component is a dictionary containing the mean actions ``"mean_actions"``
                 and extra output values
        :rtype: tuple of jnp.ndarray, jnp.ndarray or None, and dictionary
        """
        # ==============================
        # - act in response to the state
        # ==============================
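The templates above are meant to be combined by inheritance: a concrete model lists the mixin first and the framework's Model second, then calls each parent's __init__ explicitly. The sketch below shows why that ordering matters, using only the standard library. Model, CustomMixin and Policy here are hypothetical stand-ins, not skrl classes; in real code the base would be skrl.models.torch.Model or skrl.models.jax.Model.

```python
from typing import Any, Dict, List, Optional, Tuple


class Model:
    """Hypothetical stand-in for a framework base model (not skrl's class)."""

    def __init__(self, num_observations: int, num_actions: int) -> None:
        self.num_observations = num_observations
        self.num_actions = num_actions

    def act(self, inputs: Dict[str, Any], role: str = ""):
        # the base class only defines the interface
        raise NotImplementedError("Child class must implement this method")


class CustomMixin:
    """Hypothetical mixin that supplies the act() behavior."""

    def __init__(self, role: str = "") -> None:
        # custom attributes defined by the mixin
        self.role = role
        self.scale = 0.5

    def act(self, inputs: Dict[str, Any], role: str = "") -> Tuple[List[float], Optional[Any], Dict[str, Any]]:
        # combines an attribute set by Model.__init__ (num_actions)
        # with one set by CustomMixin.__init__ (scale)
        actions = [self.scale * s for s in inputs["states"][: self.num_actions]]
        return actions, None, {}


class Policy(CustomMixin, Model):
    def __init__(self, num_observations: int, num_actions: int) -> None:
        # initialize each parent explicitly, as the templates above expect
        Model.__init__(self, num_observations, num_actions)
        CustomMixin.__init__(self, role="policy")


policy = Policy(num_observations=3, num_actions=2)
actions, log_prob, outputs = policy.act({"states": [2.0, 4.0, 6.0]})
print(actions, log_prob, outputs)  # [1.0, 2.0] None {}
```

Because CustomMixin comes first in the bases, Python's method resolution order picks its act() over the base class placeholder, while both __init__ calls ensure the attributes each parent relies on exist.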
API (PyTorch)
- class skrl.models.torch.base.Model(*args: Any, **kwargs: Any)
Bases: Module
- __init__(observation_space: int | Sequence[int] | gym.Space | gymnasium.Space, action_space: int | Sequence[int] | gym.Space | gymnasium.Space, device: str | torch.device | None = None) -> None
Base class representing a function approximator
The following properties are defined:
  - device (torch.device): Device to be used for the computations
  - observation_space (int, sequence of int, gym.Space, gymnasium.Space): Observation/state space
  - action_space (int, sequence of int, gym.Space, gymnasium.Space): Action space
  - num_observations (int): Number of elements in the observation/state space
  - num_actions (int): Number of elements in the action space
- Parameters:
observation_space (int, sequence of int, gym.Space, gymnasium.Space) – Observation/state space or shape. The num_observations property will contain the size of that space
action_space (int, sequence of int, gym.Space, gymnasium.Space) – Action space or shape. The num_actions property will contain the size of that space
device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"
Custom models should override the act method:
import torch
import torch.nn as nn
import torch.nn.functional as F

from skrl.models.torch import Model


class CustomModel(Model):
    def __init__(self, observation_space, action_space, device="cuda:0"):
        Model.__init__(self, observation_space, action_space, device)
        self.layer_1 = nn.Linear(self.num_observations, 64)
        self.layer_2 = nn.Linear(64, self.num_actions)

    def act(self, inputs, role=""):
        x = F.relu(self.layer_1(inputs["states"]))
        x = F.relu(self.layer_2(x))
        return x, None, {}
- property device
Device to be used for the computations
- property observation_space
Observation/state space. It is a replica of the class constructor parameter of the same name
- property action_space
Action space. It is a replica of the class constructor parameter of the same name
- property num_observations
Number of elements in the observation/state space
- property num_actions
Number of elements in the action space
- _get_space_size(space: int | Sequence[int] | gym.Space | gymnasium.Space, number_of_elements: bool = True) -> int
Get the size (number of elements) of a space
- Parameters:
space (int, sequence of int, gym.Space, or gymnasium.Space) – Space or shape from which to obtain the number of elements
number_of_elements (bool, optional) – Whether the number of elements occupied by the space is returned (default: True). If False, the shape of the space is returned. It only affects Discrete and MultiDiscrete spaces
- Raises:
ValueError – If the space is not supported
- Returns:
Size of the space (number of elements)
- Return type:
int
Example:
# from int
>>> model._get_space_size(2)
2
# from sequence of int
>>> model._get_space_size([2, 3])
6
# Box space
>>> space = gym.spaces.Box(low=-1, high=1, shape=(2, 3))
>>> model._get_space_size(space)
6
# Discrete space
>>> space = gym.spaces.Discrete(4)
>>> model._get_space_size(space)
4
>>> model._get_space_size(space, number_of_elements=False)
1
# MultiDiscrete space
>>> space = gym.spaces.MultiDiscrete([5, 3, 2])
>>> model._get_space_size(space)
10
>>> model._get_space_size(space, number_of_elements=False)
3
# Dict space
>>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)),
...                          'b': gym.spaces.Discrete(4)})
>>> model._get_space_size(space)
10
>>> model._get_space_size(space, number_of_elements=False)
7
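The counting rules behind the example above can be reproduced with a small stand-alone function. This is a sketch under stated assumptions: Box, Discrete, MultiDiscrete, DictSpace and get_space_size are hypothetical stand-ins (not gym, gymnasium or skrl), and the real method may handle more space types.

```python
import math
from dataclasses import dataclass
from typing import Dict, Sequence, Tuple, Union


# Hypothetical minimal space classes standing in for gym/gymnasium spaces
@dataclass
class Box:
    shape: Tuple[int, ...]


@dataclass
class Discrete:
    n: int


@dataclass
class MultiDiscrete:
    nvec: Sequence[int]


@dataclass
class DictSpace:
    spaces: Dict[str, "Space"]


Space = Union[int, Sequence[int], Box, Discrete, MultiDiscrete, DictSpace]


def get_space_size(space: Space, number_of_elements: bool = True) -> int:
    """Sketch of the size rules illustrated by the _get_space_size examples."""
    if isinstance(space, int):
        return space
    if isinstance(space, Box):
        return math.prod(space.shape)
    if isinstance(space, Discrete):
        # n elements (e.g. one-hot) vs. a single integer index
        return space.n if number_of_elements else 1
    if isinstance(space, MultiDiscrete):
        return sum(space.nvec) if number_of_elements else len(space.nvec)
    if isinstance(space, DictSpace):
        return sum(get_space_size(sub, number_of_elements) for sub in space.spaces.values())
    if isinstance(space, (list, tuple)):
        # a plain shape: number of elements is the product of its dimensions
        return math.prod(space)
    raise ValueError(f"Space type {type(space)} not supported")


mixed = DictSpace({"a": Box((2, 3)), "b": Discrete(4)})
print(get_space_size(mixed), get_space_size(mixed, number_of_elements=False))  # 10 7
```

Note how number_of_elements only changes the result for the discrete cases, so a Dict space containing a Discrete subspace reports different sizes in the two modes.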
- act(inputs: Mapping[str, torch.Tensor | Any], role: str = '') -> Tuple[torch.Tensor, torch.Tensor | None, Mapping[str, torch.Tensor | Any]]
Act according to the specified behavior (to be implemented by the inheriting classes)
Agents will call this method to obtain the decision to be taken given the state of the environment. This method is currently implemented by the helper models (GaussianModel, etc.). Classes that inherit from these helper models only need to implement the .compute() method
- Parameters:
inputs (dict where the values are typically torch.Tensor) – Model inputs. The most common keys are:
  "states": state of the environment used to make the decision
  "taken_actions": actions taken by the policy for the given states
role (str, optional) – Role played by the model (default: "")
- Raises:
NotImplementedError – Child class must implement this method
- Returns:
Model output. The first component is the action to be taken by the agent. The second component is the log of the probability density function for stochastic models or None for deterministic models. The third component is a dictionary containing extra output values
- Return type:
tuple of torch.Tensor, torch.Tensor or None, and dict
- compute(inputs: Mapping[str, torch.Tensor | Any], role: str = '') -> Tuple[torch.Tensor | Mapping[str, torch.Tensor | Any]]
Define the computation performed by the models (to be implemented by the inheriting classes)
- Parameters:
inputs (dict where the values are typically torch.Tensor) – Model inputs. The most common keys are:
  "states": state of the environment used to make the decision
  "taken_actions": actions taken by the policy for the given states
role (str, optional) – Role played by the model (default: "")
- Raises:
NotImplementedError – Child class must implement this method
- Returns:
Computation performed by the models
- Return type:
tuple of torch.Tensor and dict
- forward(inputs: Mapping[str, torch.Tensor | Any], role: str = '') -> Tuple[torch.Tensor, torch.Tensor | None, Mapping[str, torch.Tensor | Any]]
Forward pass of the model
This method calls the .act() method and returns its outputs
- Parameters:
inputs (dict where the values are typically torch.Tensor) – Model inputs. The most common keys are:
  "states": state of the environment used to make the decision
  "taken_actions": actions taken by the policy for the given states
role (str, optional) – Role played by the model (default: "")
- Returns:
Model output. The first component is the action to be taken by the agent. The second component is the log of the probability density function for stochastic models or None for deterministic models. The third component is a dictionary containing extra output values
- Return type:
tuple of torch.Tensor, torch.Tensor or None, and dict
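The delegation chain described by act(), compute() and forward() (forward calls act, a helper model implements act on top of compute, and the user only writes compute) can be sketched in plain Python. BaseModel, DeterministicHelper and DoublingModel are hypothetical names used only for illustration; they do not reproduce skrl's helper models.

```python
from typing import Any, Dict, List, Optional, Tuple

Output = Tuple[Any, Optional[Any], Dict[str, Any]]


class BaseModel:
    """Hypothetical base: forward() just returns the outputs of act()."""

    def forward(self, inputs: Dict[str, Any], role: str = "") -> Output:
        return self.act(inputs, role)

    def act(self, inputs: Dict[str, Any], role: str = "") -> Output:
        raise NotImplementedError("Child class must implement this method")

    def compute(self, inputs: Dict[str, Any], role: str = "") -> Tuple[Any, Dict[str, Any]]:
        raise NotImplementedError("Child class must implement this method")


class DeterministicHelper(BaseModel):
    """Hypothetical helper model: implements act() on top of compute()."""

    def act(self, inputs: Dict[str, Any], role: str = "") -> Output:
        actions, outputs = self.compute(inputs, role)
        # deterministic models return None in place of the log-probability
        return actions, None, outputs


class DoublingModel(DeterministicHelper):
    """User-defined model: only compute() needs to be written."""

    def compute(self, inputs: Dict[str, Any], role: str = "") -> Tuple[List[float], Dict[str, Any]]:
        return [2 * s for s in inputs["states"]], {"role": role}


model = DoublingModel()
print(model.forward({"states": [1, 2, 3]}, role="policy"))  # ([2, 4, 6], None, {'role': 'policy'})
```

Calling forward() on the bare base class raises NotImplementedError, matching the "Raises" entries above.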
- freeze_parameters(freeze: bool = True) -> None
Freeze or unfreeze internal parameters
  Freeze: disable gradient computation (parameters.requires_grad = False)
  Unfreeze: enable gradient computation (parameters.requires_grad = True)
- Parameters:
freeze (bool, optional) – Freeze the internal parameters if True, otherwise unfreeze them (default: True)
Example:
# freeze model parameters
>>> model.freeze_parameters(True)
# unfreeze model parameters
>>> model.freeze_parameters(False)
- get_specification() -> Mapping[str, Any]
Returns the specification of the model
The following keys are used by the agents for initialization:
  "rnn": Recurrent Neural Network (RNN) specification for RNN, LSTM and GRU layers/cells
    "sizes": List of RNN shapes (number of layers, number of environments, number of features in the RNN state). There must be as many tuples as there are states in the recurrent layer/cell. E.g., LSTM has 2 states (hidden and cell).
- Returns:
Dictionary containing advanced specification of the model
- Return type:
dict
Example:
# model with an LSTM layer.
# - number of layers: 1
# - number of environments: 4
# - number of features in the RNN state: 64
>>> model.get_specification()
{'rnn': {'sizes': [(1, 4, 64), (1, 4, 64)]}}
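A custom recurrent model overriding get_specification() has to return one (layers, environments, features) tuple per recurrent state. The helper below is a hypothetical sketch (rnn_specification is not part of skrl) that builds that structure, assuming LSTM carries two states (hidden and cell, as stated above) and plain RNN and GRU carry one.

```python
from typing import Any, Dict


def rnn_specification(cell: str, num_layers: int, num_envs: int, hidden_size: int) -> Dict[str, Any]:
    """Build the {'rnn': {'sizes': [...]}} structure for a recurrent model (hypothetical helper)."""
    # LSTM has two states (hidden and cell); RNN and GRU have a single hidden state
    num_states = {"lstm": 2, "gru": 1, "rnn": 1}[cell.lower()]
    size = (num_layers, num_envs, hidden_size)
    return {"rnn": {"sizes": [size] * num_states}}


print(rnn_specification("LSTM", num_layers=1, num_envs=4, hidden_size=64))
# {'rnn': {'sizes': [(1, 4, 64), (1, 4, 64)]}}
```

The first call reproduces the LSTM example from the documentation; a GRU with the same sizes would produce a single tuple.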
- init_biases(method_name: str = 'constant_', *args, **kwargs) -> None
Initialize the model biases according to the specified method name
Method names are from the torch.nn.init module. Allowed method names are uniform_, normal_, constant_, etc.
The following layers will be initialized: torch.nn.Linear
- Parameters:
method_name (str, optional) – torch.nn.init method name (default: "constant_")
args (tuple, optional) – Positional arguments of the method to be called
kwargs (dict, optional) – Key-value arguments of the method to be called
Example:
# initialize all biases with a constant value (0)
>>> model.init_biases(method_name="constant_", val=0)
# initialize all biases with normal distribution with mean 0 and standard deviation 0.25
>>> model.init_biases(method_name="normal_", mean=0.0, std=0.25)
- init_parameters(method_name: str = 'normal_', *args, **kwargs) -> None
Initialize the model parameters according to the specified method name
Method names are from the torch.nn.init module. Allowed method names are uniform_, normal_, constant_, etc.
- Parameters:
method_name (str, optional) – torch.nn.init method name (default: "normal_")
args (tuple, optional) – Positional arguments of the method to be called
kwargs (dict, optional) – Key-value arguments of the method to be called
Example:
# initialize all parameters with an orthogonal distribution with a gain of 0.5
>>> model.init_parameters("orthogonal_", gain=0.5)
# initialize all parameters as a sparse matrix with a sparsity of 0.1
>>> model.init_parameters("sparse_", sparsity=0.1)
- init_weights(method_name: str = 'orthogonal_', *args, **kwargs) -> None
Initialize the model weights according to the specified method name
Method names are from the torch.nn.init module. Allowed method names are uniform_, normal_, constant_, etc.
The following layers will be initialized: torch.nn.Linear
- Parameters:
method_name (str, optional) – torch.nn.init method name (default: "orthogonal_")
args (tuple, optional) – Positional arguments of the method to be called
kwargs (dict, optional) – Key-value arguments of the method to be called
Example:
# initialize all weights with uniform distribution in range [-0.1, 0.1]
>>> model.init_weights(method_name="uniform_", a=-0.1, b=0.1)
# initialize all weights with normal distribution with mean 0 and standard deviation 0.25
>>> model.init_weights(method_name="normal_", mean=0.0, std=0.25)
- load(path: str) -> None
Load the model from the specified path
The final storage device is determined by the constructor of the model
- Parameters:
path (str) – Path to load the model from
Example:
# load the model onto the CPU
>>> model = Model(observation_space, action_space, device="cpu")
>>> model.load("model.pt")
# load the model onto the GPU 1
>>> model = Model(observation_space, action_space, device="cuda:1")
>>> model.load("model.pt")
- migrate(state_dict: Mapping[str, torch.Tensor] | None = None, path: str | None = None, name_map: Mapping[str, str] = {}, auto_mapping: bool = True, verbose: bool = False) -> bool
Migrate the specified external model’s state dict to the current model
The final storage device is determined by the constructor of the model
Only one of state_dict or path can be specified. The path parameter currently supports automatic loading of the state_dict only from files generated by the rl_games and stable-baselines3 libraries
For ambiguous models (where 2 or more parameters, for source or current model, have equal shape) it is necessary to define the name_map, at least for those parameters, to perform the migration successfully
- Parameters:
state_dict (Mapping[str, torch.Tensor], optional) – External model’s state dict to migrate from (default: None)
path (str, optional) – Path to the external checkpoint to migrate from (default: None)
name_map (Mapping[str, str], optional) – Name map to use for the migration (default: {}). Keys are the current parameter names and values are the external parameter names
auto_mapping (bool, optional) – Automatically map the external state dict to the current state dict (default: True)
verbose (bool, optional) – Show model names and migration (default: False)
- Raises:
ValueError – If neither or both of state_dict and path parameters have been set
ValueError – If the correct file type cannot be identified from the path parameter
- Returns:
True if the migration was successful, False otherwise. Migration is successful if all parameters of the current model are found in the external model
- Return type:
bool
Example:
# migrate a rl_games checkpoint with unambiguous state_dict
>>> model.migrate(path="./runs/Ant/nn/Ant.pth")
True
# migrate a rl_games checkpoint with ambiguous state_dict
>>> model.migrate(path="./runs/Cartpole/nn/Cartpole.pth", verbose=False)
[skrl:WARNING] Ambiguous match for log_std_parameter <- [value_mean_std.running_mean, value_mean_std.running_var, a2c_network.sigma]
[skrl:WARNING] Ambiguous match for net.0.bias <- [a2c_network.actor_mlp.0.bias, a2c_network.actor_mlp.2.bias]
[skrl:WARNING] Ambiguous match for net.2.bias <- [a2c_network.actor_mlp.0.bias, a2c_network.actor_mlp.2.bias]
[skrl:WARNING] Ambiguous match for net.4.weight <- [a2c_network.value.weight, a2c_network.mu.weight]
[skrl:WARNING] Ambiguous match for net.4.bias <- [a2c_network.value.bias, a2c_network.mu.bias]
[skrl:WARNING] Multiple use of a2c_network.actor_mlp.0.bias -> [net.0.bias, net.2.bias]
[skrl:WARNING] Multiple use of a2c_network.actor_mlp.2.bias -> [net.0.bias, net.2.bias]
False
>>> name_map = {"log_std_parameter": "a2c_network.sigma",
...             "net.0.bias": "a2c_network.actor_mlp.0.bias",
...             "net.2.bias": "a2c_network.actor_mlp.2.bias",
...             "net.4.weight": "a2c_network.mu.weight",
...             "net.4.bias": "a2c_network.mu.bias"}
>>> model.migrate(path="./runs/Cartpole/nn/Cartpole.pth", name_map=name_map, verbose=True)
[skrl:INFO] Models
[skrl:INFO] |-- current: 7 items
[skrl:INFO] | |-- log_std_parameter : torch.Size([1])
[skrl:INFO] | |-- net.0.weight : torch.Size([32, 4])
[skrl:INFO] | |-- net.0.bias : torch.Size([32])
[skrl:INFO] | |-- net.2.weight : torch.Size([32, 32])
[skrl:INFO] | |-- net.2.bias : torch.Size([32])
[skrl:INFO] | |-- net.4.weight : torch.Size([1, 32])
[skrl:INFO] | |-- net.4.bias : torch.Size([1])
[skrl:INFO] |-- source: 15 items
[skrl:INFO] | |-- value_mean_std.running_mean : torch.Size([1])
[skrl:INFO] | |-- value_mean_std.running_var : torch.Size([1])
[skrl:INFO] | |-- value_mean_std.count : torch.Size([])
[skrl:INFO] | |-- running_mean_std.running_mean : torch.Size([4])
[skrl:INFO] | |-- running_mean_std.running_var : torch.Size([4])
[skrl:INFO] | |-- running_mean_std.count : torch.Size([])
[skrl:INFO] | |-- a2c_network.sigma : torch.Size([1])
[skrl:INFO] | |-- a2c_network.actor_mlp.0.weight : torch.Size([32, 4])
[skrl:INFO] | |-- a2c_network.actor_mlp.0.bias : torch.Size([32])
[skrl:INFO] | |-- a2c_network.actor_mlp.2.weight : torch.Size([32, 32])
[skrl:INFO] | |-- a2c_network.actor_mlp.2.bias : torch.Size([32])
[skrl:INFO] | |-- a2c_network.value.weight : torch.Size([1, 32])
[skrl:INFO] | |-- a2c_network.value.bias : torch.Size([1])
[skrl:INFO] | |-- a2c_network.mu.weight : torch.Size([1, 32])
[skrl:INFO] | |-- a2c_network.mu.bias : torch.Size([1])
[skrl:INFO] Migration
[skrl:INFO] |-- map: log_std_parameter <- a2c_network.sigma
[skrl:INFO] |-- auto: net.0.weight <- a2c_network.actor_mlp.0.weight
[skrl:INFO] |-- map: net.0.bias <- a2c_network.actor_mlp.0.bias
[skrl:INFO] |-- auto: net.2.weight <- a2c_network.actor_mlp.2.weight
[skrl:INFO] |-- map: net.2.bias <- a2c_network.actor_mlp.2.bias
[skrl:INFO] |-- map: net.4.weight <- a2c_network.mu.weight
[skrl:INFO] |-- map: net.4.bias <- a2c_network.mu.bias
False
# migrate a stable-baselines3 checkpoint with unambiguous state_dict
>>> model.migrate(path="./ddpg_pendulum.zip")
True
# migrate from any exported model by loading its state_dict (unambiguous state_dict)
>>> state_dict = torch.load("./external_model.pt")
>>> model.migrate(state_dict=state_dict)
True
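The need for name_map follows from matching by shape: if several parameters share a shape, a shape alone cannot decide which external parameter belongs to which current one. The function below is a simplified, hypothetical illustration of that idea (auto_map and the parameter names are invented for this sketch); it is not skrl's migration algorithm, which may apply additional rules.

```python
from typing import Dict, Mapping, Optional, Tuple

Shape = Tuple[int, ...]


def auto_map(current: Mapping[str, Shape],
             external: Mapping[str, Shape],
             name_map: Optional[Mapping[str, str]] = None) -> Tuple[Dict[str, str], bool]:
    """Map current parameter names to external ones, using shapes only when unambiguous.

    Returns the mapping and whether every current parameter was mapped.
    """
    name_map = name_map or {}
    mapping: Dict[str, str] = {}
    for name, shape in current.items():
        if name in name_map:
            # an explicit name_map entry always takes precedence
            mapping[name] = name_map[name]
            continue
        external_candidates = [ext for ext, ext_shape in external.items() if ext_shape == shape]
        current_same_shape = [cur for cur, cur_shape in current.items() if cur_shape == shape]
        # map automatically only if the shape is unique on both sides
        if len(external_candidates) == 1 and len(current_same_shape) == 1:
            mapping[name] = external_candidates[0]
    return mapping, len(mapping) == len(current)


current = {"net.0.weight": (32, 4), "net.0.bias": (32,), "net.2.bias": (32,)}
external = {"actor.0.weight": (32, 4), "actor.0.bias": (32,), "actor.2.bias": (32,)}
print(auto_map(current, external))  # ({'net.0.weight': 'actor.0.weight'}, False)
```

Passing name_map={"net.0.bias": "actor.0.bias", "net.2.bias": "actor.2.bias"} resolves the two ambiguous biases, after which every current parameter is mapped.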
- random_act(inputs: Mapping[str, torch.Tensor | Any], role: str = '') -> Tuple[torch.Tensor, None, Mapping[str, torch.Tensor | Any]]
Act randomly according to the action space
- Parameters:
inputs (dict where the values are typically torch.Tensor) – Model inputs. The most common keys are:
  "states": state of the environment used to make the decision
  "taken_actions": actions taken by the policy for the given states
role (str, optional) – Role played by the model (default: "")
- Raises:
NotImplementedError – Unsupported action space
- Returns:
Model output. The first component is the action to be taken by the agent
- Return type:
tuple of torch.Tensor, None, and dict
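A minimal sketch of random action sampling, assuming uniform sampling within the bounds of a (1-D) Box space and uniform integer sampling for a Discrete space, with one action per state in the batch. Box, Discrete and random_act below are hypothetical stand-ins written with the standard library; they are not the gym classes or skrl's implementation.

```python
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union


# Hypothetical stand-ins for gym/gymnasium action spaces (1-D Box only)
@dataclass
class Box:
    low: float
    high: float
    size: int


@dataclass
class Discrete:
    n: int


def random_act(inputs: Dict[str, Any], action_space: Union[Box, Discrete],
               rng: random.Random) -> Tuple[List[list], None, Dict[str, Any]]:
    """Sample one random action per state in the batch."""
    num_samples = len(inputs["states"])
    if isinstance(action_space, Discrete):
        # one integer index in [0, n) per sample
        actions = [[rng.randrange(action_space.n)] for _ in range(num_samples)]
    elif isinstance(action_space, Box):
        # uniform values between the space bounds
        actions = [[rng.uniform(action_space.low, action_space.high) for _ in range(action_space.size)]
                   for _ in range(num_samples)]
    else:
        raise NotImplementedError(f"Action space type ({type(action_space)}) not supported")
    return actions, None, {}


rng = random.Random(42)
actions, log_prob, outputs = random_act({"states": [[0.0], [1.0], [2.0]]}, Box(-1.0, 1.0, 2), rng)
print(len(actions), log_prob, outputs)  # 3 None {}
```

As in the signature above, the second component is always None because no probability is reported for random actions.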
- save(path: str, state_dict: dict | None = None) -> None
Save the model to the specified path
- Parameters:
path (str) – Path to save the model to
state_dict (dict, optional) – State dictionary to save (default: None). If None, the model’s state_dict will be saved
Example:
# save the current model to the specified path
>>> model.save("/tmp/model.pt")
# save an older version of the model to the specified path
>>> import copy
>>> old_state_dict = copy.deepcopy(model.state_dict())
>>> # ...
>>> model.save("/tmp/model.pt", old_state_dict)
- set_mode(mode: str) -> None
Set the model mode (training or evaluation)
- Parameters:
mode (str) – Mode: "train" for training or "eval" for evaluation. See torch.nn.Module.train
- Raises:
ValueError – If the mode is not "train" or "eval"
- tensor_to_space(tensor: torch.Tensor, space: gym.Space | gymnasium.Space, start: int = 0) -> torch.Tensor | dict
Map a flat tensor to a Gym/Gymnasium space
The mapping is done in the following way:
  - Tensors belonging to Discrete spaces are returned without modification
  - Tensors belonging to Box spaces are reshaped to the corresponding space shape, keeping the first dimension (number of samples) as is
  - Tensors belonging to Dict spaces are mapped into a dictionary with the same keys as the original space
- Parameters:
tensor (torch.Tensor) – Tensor to map from
space (gym.Space or gymnasium.Space) – Space to map the tensor to
start (int, optional) – Index of the first element of the tensor to map (default: 0)
- Raises:
ValueError – If the space is not supported
- Returns:
Mapped tensor or dictionary
- Return type:
torch.Tensor or dict
Example:
>>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)),
...                          'b': gym.spaces.Discrete(4)})
>>> tensor = torch.tensor([[-0.3, -0.2, -0.1, 0.1, 0.2, 0.3, 2]])
>>>
>>> model.tensor_to_space(tensor, space)
{'a': tensor([[[-0.3000, -0.2000, -0.1000],
          [ 0.1000,  0.2000,  0.3000]]]),
 'b': tensor([[2.]])}
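The mapping rules above can be mirrored on plain nested lists to show how the start offset advances through a Dict space. The tuple-based space descriptions and the helpers flat_size, reshape and tensor_to_space are hypothetical and only sketch the idea; they assume a Discrete subspace occupies a single element and that Dict keys are visited in the order stored by the space.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical space descriptions (not gym objects):
#   ("box", shape), ("discrete",), ("dict", {key: space, ...})
Space = Tuple[Any, ...]


def flat_size(space: Space) -> int:
    """Number of flat elements a (sub)space occupies in each sample."""
    kind = space[0]
    if kind == "discrete":
        return 1
    if kind == "box":
        size = 1
        for dim in space[1]:
            size *= dim
        return size
    if kind == "dict":
        return sum(flat_size(sub) for sub in space[1].values())
    raise ValueError(f"Space type {kind} not supported")


def reshape(flat: List[float], shape: Tuple[int, ...]) -> Any:
    """Row-major reshape of a flat list into nested lists."""
    if len(shape) == 1:
        return list(flat)
    step = len(flat) // shape[0]
    return [reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def tensor_to_space(batch: List[List[float]], space: Space, start: int = 0) -> Any:
    """Map each flat sample of a batch onto the space, keeping the batch dimension."""
    kind = space[0]
    if kind == "discrete":
        # a Discrete (sub)space occupies a single element
        return [row[start:start + 1] for row in batch]
    if kind == "box":
        size = flat_size(space)
        return [reshape(row[start:start + size], space[1]) for row in batch]
    if kind == "dict":
        output: Dict[str, Any] = {}
        for key, sub in space[1].items():  # keys in the order stored by the space
            output[key] = tensor_to_space(batch, sub, start)
            start += flat_size(sub)
        return output
    raise ValueError(f"Space type {kind} not supported")


space = ("dict", {"a": ("box", (2, 3)), "b": ("discrete",)})
print(tensor_to_space([[-0.3, -0.2, -0.1, 0.1, 0.2, 0.3, 2]], space))
# {'a': [[[-0.3, -0.2, -0.1], [0.1, 0.2, 0.3]]], 'b': [[2]]}
```

The result has the same structure as the documented example: the Box entry is reshaped to (samples, 2, 3) and the Discrete entry keeps one element per sample.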
- update_parameters(model: torch.nn.Module, polyak: float = 1) -> None
Update internal parameters by hard or soft (polyak averaging) update
Hard update: \(\theta = \theta_{net}\)
Soft (polyak averaging) update: \(\theta = (1 - \rho) \theta + \rho \theta_{net}\)
where \(\theta\) are the internal parameters, \(\theta_{net}\) are the parameters of the given model and \(\rho\) is the polyak hyperparameter
- Parameters:
model (torch.nn.Module (skrl.models.torch.Model)) – Model used to update the internal parameters
polyak (float, optional) – Polyak hyperparameter between 0 and 1 (default: 1). A hard update is performed when its value is 1
Example:
# hard update (from source model)
>>> model.update_parameters(source_model)
# soft update (from source model)
>>> model.update_parameters(source_model, polyak=0.005)
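The update rule above can be checked on plain lists of floats. This sketch applies \(\theta = (1 - \rho) \theta + \rho \theta_{net}\) element-wise in place; the function name and list-based parameters are hypothetical stand-ins for the model's tensors.

```python
from typing import List


def update_parameters(target: List[float], source: List[float], polyak: float = 1.0) -> None:
    """In-place hard (polyak == 1) or soft (0 < polyak < 1) update of target towards source."""
    if not 0.0 <= polyak <= 1.0:
        raise ValueError("polyak must be between 0 and 1")
    for i, (theta, theta_net) in enumerate(zip(target, source)):
        # theta <- (1 - rho) * theta + rho * theta_net
        target[i] = (1.0 - polyak) * theta + polyak * theta_net


soft = [0.0, 2.0]
update_parameters(soft, [1.0, 4.0], polyak=0.5)
print(soft)  # [0.5, 3.0]
```

With polyak=1 the first term vanishes and the target becomes an exact copy of the source (hard update); small values such as 0.005 move the target only slightly towards the source on each call.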
API (JAX)
- class skrl.models.jax.base.Model(*args: Any, **kwargs: Any)
Bases: Module
- __init__(observation_space: int | Sequence[int] | gym.Space | gymnasium.Space, action_space: int | Sequence[int] | gym.Space | gymnasium.Space, device: str | jax.Device | None = None, parent: Any | None = None, name: str | None = None) -> None
Base class representing a function approximator
The following properties are defined:
  - device (jax.Device): Device to be used for the computations
  - observation_space (int, sequence of int, gym.Space, gymnasium.Space): Observation/state space
  - action_space (int, sequence of int, gym.Space, gymnasium.Space): Action space
  - num_observations (int): Number of elements in the observation/state space
  - num_actions (int): Number of elements in the action space
- Parameters:
observation_space (int, sequence of int, gym.Space, gymnasium.Space) – Observation/state space or shape. The num_observations property will contain the size of that space
action_space (int, sequence of int, gym.Space, gymnasium.Space) – Action space or shape. The num_actions property will contain the size of that space
device (str or jax.Device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"
parent (str, optional) – The parent Module of this Module (default: None). It is a Flax reserved attribute
name (str, optional) – The name of this Module (default: None). It is a Flax reserved attribute
Custom models should override the __call__ method:
import flax
import flax.linen as nn

from skrl.models.jax import Model


class CustomModel(Model):
    def __init__(self, observation_space, action_space, device=None, **kwargs):
        Model.__init__(self, observation_space, action_space, device, **kwargs)
        # https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.IncorrectPostInitOverrideError
        flax.linen.Module.__post_init__(self)

    @nn.compact
    def __call__(self, inputs, role):
        x = nn.relu(nn.Dense(64)(inputs["states"]))
        x = nn.relu(nn.Dense(self.num_actions)(x))
        return x, None, {}
- property device
Device to be used for the computations
- property observation_space
Observation/state space. It is a replica of the class constructor parameter of the same name
- property action_space
Action space. It is a replica of the class constructor parameter of the same name
- property num_observations
Number of elements in the observation/state space
- property num_actions
Number of elements in the action space
- _get_space_size(space: int | Sequence[int] | gym.Space | gymnasium.Space, number_of_elements: bool = True) -> int
Get the size (number of elements) of a space
- Parameters:
space (int, sequence of int, gym.Space, or gymnasium.Space) – Space or shape from which to obtain the number of elements
number_of_elements (bool, optional) – Whether the number of elements occupied by the space is returned (default: True). If False, the shape of the space is returned. It only affects Discrete and MultiDiscrete spaces
- Raises:
ValueError – If the space is not supported
- Returns:
Size of the space (number of elements)
- Return type:
int
Example:
# from int
>>> model._get_space_size(2)
2
# from sequence of int
>>> model._get_space_size([2, 3])
6
# Box space
>>> space = gym.spaces.Box(low=-1, high=1, shape=(2, 3))
>>> model._get_space_size(space)
6
# Discrete space
>>> space = gym.spaces.Discrete(4)
>>> model._get_space_size(space)
4
>>> model._get_space_size(space, number_of_elements=False)
1
# MultiDiscrete space
>>> space = gym.spaces.MultiDiscrete([5, 3, 2])
>>> model._get_space_size(space)
10
>>> model._get_space_size(space, number_of_elements=False)
3
# Dict space
>>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)),
...                          'b': gym.spaces.Discrete(4)})
>>> model._get_space_size(space)
10
>>> model._get_space_size(space, number_of_elements=False)
7
- act(inputs: Mapping[str, ndarray | jax.Array | Any], role: str = '', params: jax.Array | None = None) Tuple[jax.Array, jax.Array | None, Mapping[str, jax.Array | Any]] #
Act according to the specified behavior (to be implemented by the inheriting classes)
Agents will call this method to obtain the decision to be taken given the state of the environment. The classes that inherit from the latter must only implement the
.__call__()
method- Parameters:
inputs (dict where the values are typically np.ndarray or jax.Array) – Model inputs. The most common keys are:
"states": state of the environment used to make the decision
"taken_actions": actions taken by the policy for the given states
role (str, optional) – Role played by the model (default: "")
params (jax.Array, optional) – Parameters used to compute the output (default: None). If None, internal parameters will be used
- Raises:
NotImplementedError – Child class must implement this method
- Returns:
Model output. The first component is the action to be taken by the agent. The second component is the log of the probability density function for stochastic models or None for deterministic models. The third component is a dictionary containing extra output values
- Return type:
tuple of (jax.Array, jax.Array or None, dict)
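To illustrate the three-component return contract described above, here is a framework-agnostic NumPy sketch of two hypothetical policies. The names `deterministic_act` and `gaussian_act`, the linear layer `W`, the fixed `log_std`, and the `"mean_actions"` key are all assumptions for illustration. The deterministic policy returns None as the second component; the Gaussian one returns the log-density of the sampled actions, summed over the action dimensions.

```python
import numpy as np


def deterministic_act(inputs, W):
    # actions are a deterministic function of the states
    actions = np.tanh(inputs["states"] @ W)
    # no log-probability for deterministic models; no extra outputs
    return actions, None, {}


def gaussian_act(inputs, W, log_std, rng):
    mean = inputs["states"] @ W
    std = np.exp(log_std)
    # sample actions from N(mean, std^2)
    actions = mean + std * rng.standard_normal(mean.shape)
    # log N(a | mean, std^2) per dimension, then summed over the action dimension
    log_prob = -0.5 * (((actions - mean) / std) ** 2 + 2 * log_std + np.log(2 * np.pi))
    log_prob = log_prob.sum(axis=-1, keepdims=True)
    # extra values are returned in the third component
    return actions, log_prob, {"mean_actions": mean}
```

An agent would unpack either result the same way, e.g. `actions, log_prob, outputs = gaussian_act({"states": states}, W, log_std, rng)`, and only rely on `log_prob` when it is not None.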
- device: str | jax.Device | None = None
- freeze_parameters(freeze: bool = True) → None
Freeze or unfreeze internal parameters
Note
This method does nothing; it only maintains compatibility with other ML frameworks
- Parameters:
freeze (bool, optional) – Freeze the internal parameters if True, otherwise unfreeze them (default: True)
Example:
# freeze model parameters
>>> model.freeze_parameters(True)
# unfreeze model parameters
>>> model.freeze_parameters(False)
- get_specification() → Mapping[str, Any]
Returns the specification of the model
The following keys are used by the agents for initialization:
"rnn": Recurrent Neural Network (RNN) specification for RNN, LSTM and GRU layers/cells
"sizes" (nested under "rnn"): List of RNN shapes (number of layers, number of environments, number of features in the RNN state). There must be as many tuples as there are states in the recurrent layer/cell. E.g., an LSTM has 2 states (hidden and cell).
- Returns:
Dictionary containing advanced specification of the model
- Return type:
dict
Example:
# model with an LSTM layer
# - number of layers: 1
# - number of environments: 4
# - number of features in the RNN state: 64
>>> model.get_specification()
{'rnn': {'sizes': [(1, 4, 64), (1, 4, 64)]}}
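As a sketch of how the "rnn" sizes can be consumed, the snippet below builds the specification for an LSTM (two states: hidden and cell) and allocates zero-initialized states from it with NumPy. The helper names `lstm_specification` and `initial_rnn_states` are hypothetical; the text above only says that agents read these keys during initialization, and this code merely mimics that step.

```python
from typing import Any, Dict, List

import numpy as np


def lstm_specification(num_layers: int, num_envs: int, hidden_size: int) -> Dict[str, Any]:
    # one (layers, environments, features) tuple per recurrent state: hidden and cell
    size = (num_layers, num_envs, hidden_size)
    return {"rnn": {"sizes": [size, size]}}


def initial_rnn_states(specification: Dict[str, Any]) -> List[np.ndarray]:
    # allocate one zero-filled array per state listed under "sizes"
    return [np.zeros(size, dtype=np.float32) for size in specification["rnn"]["sizes"]]
```

A GRU or plain RNN would list a single tuple instead of two, since it carries only a hidden state.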
- init_biases(method_name: str = 'constant_', *args, **kwargs) → None
Initialize the model biases according to the specified method name
Method names are from the flax.linen.initializers module. Allowed method names are uniform, normal, constant, etc.
- Parameters:
method_name (str, optional) – flax.linen.initializers method name (default: "constant_")
args (tuple, optional) – Positional arguments of the method to be called
kwargs (dict, optional) – Key-value arguments of the method to be called
Example:
# initialize all biases with a constant value (0)
>>> model.init_biases(method_name="constant", value=0.0)
# initialize all biases with a normal distribution with mean 0 and standard deviation 0.25
>>> model.init_biases(method_name="normal", stddev=0.25)
- init_parameters(method_name: str = 'normal', *args, **kwargs) → None
Initialize the model parameters according to the specified method name
Method names are from the flax.linen.initializers module. Allowed method names are uniform, normal, constant, etc.
- Parameters:
method_name (str, optional) – flax.linen.initializers method name (default: "normal")
args (tuple, optional) – Positional arguments of the method to be called
kwargs (dict, optional) – Key-value arguments of the method to be called
Example:
# initialize all parameters with an orthogonal distribution with a scale of 0.5
>>> model.init_parameters("orthogonal", scale=0.5)
# initialize all parameters as a normal distribution with a standard deviation of 0.1
>>> model.init_parameters("normal", stddev=0.1)
- init_state_dict(role: str, inputs: Mapping[str, ndarray | jax.Array] = {}, key: jax.Array | None = None) → None
Initialize state dictionary
- Parameters:
role (str) – Role played by the model
inputs (dict of np.ndarray or jax.Array, optional) – Model inputs. The most common keys are:
"states": state of the environment used to make the decision
"taken_actions": actions taken by the policy for the given states
If not specified, the keys will be populated with observation and action space samples
key (jax.Array, optional) – Pseudo-random number generator (PRNG) key (default: None). If not provided, skrl's PRNG key (config.jax.key) will be used
- init_weights(method_name: str = 'normal', *args, **kwargs) → None
Initialize the model weights according to the specified method name
Method names are from the flax.linen.initializers module. Allowed method names are uniform, normal, constant, etc.
- Parameters:
method_name (str, optional) – flax.linen.initializers method name (default: "normal")
args (tuple, optional) – Positional arguments of the method to be called
kwargs (dict, optional) – Key-value arguments of the method to be called
Example:
# initialize all weights with a uniform distribution in range [0, 0.1)
>>> model.init_weights(method_name="uniform", scale=0.1)
# initialize all weights with a normal distribution with mean 0 and standard deviation 0.25
>>> model.init_weights(method_name="normal", stddev=0.25)
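Flax layers such as flax.linen.Dense store their parameters in a nested dictionary under the names "kernel" (weights) and "bias". Assuming that layout, the difference between init_parameters, init_weights and init_biases can be sketched as re-initializing every leaf, only the "kernel" leaves, or only the "bias" leaves. The helper `reinit` and the initializer lambdas in the usage are hypothetical NumPy stand-ins, not flax.linen.initializers and not skrl's internal code.

```python
from typing import Any, Callable, Dict, Optional

import numpy as np

Params = Dict[str, Any]


def reinit(params: Params, init_fn: Callable[[tuple], np.ndarray], name: Optional[str] = None) -> Params:
    """Return a copy of a nested parameter dict with selected leaves re-initialized.

    name=None re-initializes every leaf (like init_parameters); name="kernel"
    or name="bias" restricts it to weights or biases.
    """
    new_params = {}
    for key, value in params.items():
        if isinstance(value, dict):
            # recurse into sub-modules, e.g. {"Dense_0": {"kernel": ..., "bias": ...}}
            new_params[key] = reinit(value, init_fn, name)
        elif name is None or key == name:
            # replace the leaf with a freshly initialized array of the same shape
            new_params[key] = init_fn(value.shape)
        else:
            # leave non-matching leaves untouched
            new_params[key] = value
    return new_params
```

Unlike the `init_fn(shape)` stand-in used here, flax.linen.initializers functions return initializers that are called as `init(key, shape, dtype)` with an explicit PRNG key.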
- load(path: str) → None
Load the model from the specified path
- Parameters:
path (str) – Path to load the model from
Example:
# load the model
>>> model = Model(observation_space, action_space)
>>> model.load("model.flax")
- migrate(state_dict: Mapping[str, Any] | None = None, path: str | None = None, name_map: Mapping[str, str] = {}, auto_mapping: bool = True, verbose: bool = False) → bool
Migrate the specified external model's state dict to the current model
Warning
This method is not implemented yet; it only maintains compatibility with other ML frameworks
- Raises:
NotImplementedError – Not implemented
- random_act(inputs: Mapping[str, ndarray | jax.Array | Any], role: str = '', params: jax.Array | None = None) → Tuple[ndarray | jax.Array, ndarray | jax.Array | None, Mapping[str, ndarray | jax.Array | Any]]
Act randomly according to the action space
- Parameters:
inputs (dict where the values are typically np.ndarray or jax.Array) – Model inputs. The most common keys are:
"states": state of the environment used to make the decision
"taken_actions": actions taken by the policy for the given states
role (str, optional) – Role played by the model (default: "")
params (jax.Array, optional) – Parameters used to compute the output (default: None). If None, internal parameters will be used
- Raises:
NotImplementedError – Unsupported action space
- Returns:
Model output. The first component is the action to be taken by the agent
- Return type:
tuple of (np.ndarray or jax.Array, np.ndarray or jax.Array or None, dict)
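A minimal NumPy sketch of acting randomly according to the action space follows. It assumes hypothetical dict stand-ins for two space types: a Discrete space described by its number of actions `n`, and a Box space described by its `low`/`high` bounds. Unsupported spaces raise NotImplementedError, mirroring the documented behavior; returning None and {} as the second and third components is an assumption of this sketch.

```python
from typing import Any, Dict, Optional, Tuple

import numpy as np


def random_act(states: np.ndarray, space: Dict[str, Any],
               rng: np.random.Generator) -> Tuple[np.ndarray, Optional[np.ndarray], Dict[str, Any]]:
    # one action per sample (first dimension of the states)
    num_samples = states.shape[0]
    if space["type"] == "discrete":
        # one integer action per sample in [0, n)
        actions = rng.integers(0, space["n"], size=(num_samples, 1))
    elif space["type"] == "box":
        low, high = np.asarray(space["low"]), np.asarray(space["high"])
        # uniform sample within the per-dimension bounds
        actions = rng.uniform(low, high, size=(num_samples,) + low.shape)
    else:
        raise NotImplementedError(f"Action space type ({space['type']}) not supported")
    # no log-probability and no extra outputs for random actions
    return actions, None, {}
```

Random actions like these are typically used before learning starts, when the policy's own outputs carry no useful information yet.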
- save(path: str, state_dict: dict | None = None) → None
Save the model to the specified path
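- Parameters:
path (str) – Path to save the model to
state_dict (dict, optional) – State dictionary to save (default: None). If None, the model's state dict will be saved
Example: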
# save the current model to the specified path
>>> model.save("/tmp/model.flax")
# TODO: save an older version of the model to the specified path
- set_mode(mode: str) → None
Set the model mode (training or evaluation)
- Parameters:
mode (str) – Mode: "train" for training or "eval" for evaluation
- Raises:
ValueError – If the mode is not "train" or "eval"
- tensor_to_space(tensor: ndarray | jax.Array, space: gym.Space | gymnasium.Space, start: int = 0) → ndarray | jax.Array | dict
Map a flat tensor to a Gym/Gymnasium space
The mapping is done in the following way:
Tensors belonging to Discrete spaces are returned without modification
Tensors belonging to Box spaces are reshaped to the corresponding space shape keeping the first dimension (number of samples) as they are
Tensors belonging to Dict spaces are mapped into a dictionary with the same keys as the original space
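- Parameters:
tensor (np.ndarray or jax.Array) – Tensor to map from
space (gym.Space or gymnasium.Space) – Space to map the tensor to
start (int, optional) – Index of the first element of the tensor to map (default: 0)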
- Raises:
ValueError – If the space is not supported
- Returns:
Mapped tensor or dictionary
- Return type:
np.ndarray, jax.Array or dict
Example:
>>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)),
...                          'b': gym.spaces.Discrete(4)})
>>> tensor = jnp.array([[-0.3, -0.2, -0.1, 0.1, 0.2, 0.3, 2]])
>>> model.tensor_to_space(tensor, space)
{'a': Array([[[-0.3, -0.2, -0.1],
        [ 0.1,  0.2,  0.3]]], dtype=float32),
 'b': Array([[2.]], dtype=float32)}
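The three mapping rules listed above can be sketched with NumPy, assuming hypothetical stand-in space descriptions: a shape tuple for a Box space, the string "discrete" for a Discrete space, and a dict of stand-ins for a Dict space. In this sketch each subspace of a Dict consumes a contiguous slice of columns in sorted key order, starting at `start`; it reproduces the example output, but it is an illustration, not skrl's implementation.

```python
from typing import Any, Dict, Tuple, Union

import numpy as np

# Hypothetical stand-ins: a tuple is a Box shape, "discrete" is a Discrete space,
# and a dict maps keys to nested stand-ins (a Dict space)
SpaceSpec = Union[Tuple[int, ...], str, Dict[str, Any]]


def flat_size(space: SpaceSpec) -> int:
    # number of columns a space occupies in the flat tensor
    if space == "discrete":
        return 1
    if isinstance(space, tuple):
        return int(np.prod(space))
    if isinstance(space, dict):
        return sum(flat_size(s) for s in space.values())
    raise ValueError(f"Unsupported space: {space!r}")


def tensor_to_space(tensor: np.ndarray, space: SpaceSpec, start: int = 0) -> Union[np.ndarray, dict]:
    if space == "discrete":
        # Discrete: returned without modification
        return tensor
    if isinstance(space, tuple):
        # Box: keep the sample dimension, reshape the rest to the space shape
        return tensor.reshape(tensor.shape[0], *space)
    if isinstance(space, dict):
        output = {}
        for key in sorted(space):
            # each subspace consumes a contiguous slice of columns
            end = start + flat_size(space[key])
            output[key] = tensor_to_space(tensor[:, start:end], space[key])
            start = end
        return output
    raise ValueError(f"Unsupported space: {space!r}")
```

For the example tensor, "a" takes columns 0 to 5 and is reshaped to (1, 2, 3), and "b" takes column 6.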
- update_parameters(model: flax.linen.Module, polyak: float = 1) → None
Update internal parameters by hard or soft (polyak averaging) update
Hard update: \(\theta = \theta_{net}\)
Soft (polyak averaging) update: \(\theta = (1 - \rho) \theta + \rho \theta_{net}\), where \(\rho\) is the polyak hyperparameter
- Parameters:
model (flax.linen.Module (skrl.models.jax.Model)) – Model used to update the internal parameters
polyak (float, optional) – Polyak hyperparameter between 0 and 1 (default: 1). A hard update is performed when its value is 1
Example:
# hard update (from source model)
>>> model.update_parameters(source_model)
# soft update (from source model)
>>> model.update_parameters(source_model, polyak=0.005)
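The hard and soft updates above can be checked numerically with a functional NumPy sketch over nested parameter dictionaries. For example, with \(\rho = 0.005\), \(\theta = 0\) and \(\theta_{net} = 4\), the soft update gives \(0.995 \cdot 0 + 0.005 \cdot 4 = 0.02\). The function name `polyak_update` is hypothetical; in JAX the same rule would more naturally be expressed with jax.tree_util.tree_map over the parameter pytrees.

```python
from typing import Any, Dict

import numpy as np


def polyak_update(target: Dict[str, Any], source: Dict[str, Any], polyak: float = 1.0) -> Dict[str, Any]:
    """Hard (polyak=1) or soft update of a nested parameter dict.

    theta = (1 - polyak) * theta + polyak * theta_net
    """
    updated = {}
    for key, value in target.items():
        if isinstance(value, dict):
            # recurse into sub-modules with matching structure
            updated[key] = polyak_update(value, source[key], polyak)
        elif polyak == 1:
            # hard update: copy the source parameters
            updated[key] = np.array(source[key], copy=True)
        else:
            # soft update: move a fraction polyak of the way toward the source
            updated[key] = (1 - polyak) * value + polyak * source[key]
    return updated
```

Small values such as 0.005 make the target parameters track the source slowly, which is the usual reason for using a soft update on target networks.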