Gaussian model

Gaussian models run continuous-domain stochastic policies.



skrl provides a Python mixin (GaussianMixin) to assist in the creation of these types of models, allowing users to have full control over the function approximator definitions and architectures. Note that the use of this mixin must comply with the following rules:

  • The definition of multiple inheritance must always include the Model base class at the end.

  • The Model base class constructor must be invoked before the mixin's constructor.
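The two rules can be illustrated with plain Python stand-ins. The classes below are hypothetical and only share their names with skrl's Model and GaussianMixin; in this sketch a "space" is just a number of dimensions. Python resolves methods left to right along the MRO, so listing Model last lets the mixin's act() take precedence. A plausible reason for the constructor order is that the mixin's constructor reads attributes such as action_space, which only exist once Model.__init__ has run.

```python
# Hypothetical stand-ins for skrl's Model and GaussianMixin. They only mimic
# the inheritance and constructor-order rules; they are not the real classes.


class Model:
    def __init__(self, observation_space, action_space):
        # In this sketch a "space" is just the number of dimensions
        self.observation_space = observation_space
        self.action_space = action_space

    def act(self, inputs, *, role=""):
        raise NotImplementedError("act() is provided by a mixin")


class GaussianMixin:
    def __init__(self, *, clip_actions=False):
        self.clip_actions = clip_actions
        # Reads an attribute set by Model.__init__ (AttributeError otherwise)
        self.num_actions = self.action_space

    def act(self, inputs, *, role=""):
        return "gaussian act", {}


class Policy(GaussianMixin, Model):  # Model base class listed last
    def __init__(self, observation_space, action_space, clip_actions=False):
        Model.__init__(self, observation_space, action_space)  # 1st: base class
        GaussianMixin.__init__(self, clip_actions=clip_actions)  # 2nd: mixin


class WrongOrderPolicy(GaussianMixin, Model):
    def __init__(self, observation_space, action_space):
        GaussianMixin.__init__(self)  # mixin first: action_space not set yet
        Model.__init__(self, observation_space, action_space)


def constructs(cls):
    """Return True if cls(4, 2) can be built without an AttributeError."""
    try:
        cls(4, 2)
    except AttributeError:
        return False
    return True


policy = Policy(observation_space=4, action_space=2)
print([cls.__name__ for cls in Policy.__mro__])  # ['Policy', 'GaussianMixin', 'Model', 'object']
print(policy.act({})[0])  # gaussian act (the mixin's act() wins over Model.act())
print(constructs(Policy), constructs(WrongOrderPolicy))  # True False
```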

Note

For models in JAX/Flax it is imperative to define all parameters (except observation_space, state_space, action_space and device) with default values to avoid errors during initialization (TypeError: __init__() missing N required positional argument).

In addition, it is necessary to initialize the model’s state_dict (via the init_state_dict method) after its instantiation to avoid errors during its use (AttributeError: object has no attribute "state_dict". If "state_dict" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply').

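# import from skrl.models.torch (PyTorch) or skrl.models.jax (JAX)
from skrl.models.torch import Model, GaussianMixin


class GaussianModel(GaussianMixin, Model):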
    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        clip_actions=False,
        clip_mean_actions=False,
        clip_log_std=True,
        min_log_std=-20,
        max_log_std=2,
        reduction="sum",
    ):
        Model.__init__(
            self,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            device=device,
        )
        GaussianMixin.__init__(
            self,
            clip_actions=clip_actions,
            clip_mean_actions=clip_mean_actions,
            clip_log_std=clip_log_std,
            min_log_std=min_log_std,
            max_log_std=max_log_std,
            reduction=reduction,
        )

Concept

[Figure: Gaussian model]

Usage

  • Multi-Layer Perceptron (MLP)

  • Convolutional Neural Network (CNN)

  • Recurrent Neural Network (RNN)

  • Gated Recurrent Unit RNN (GRU)

  • Long Short-Term Memory RNN (LSTM)

[Figure: Gaussian MLP model]

import torch
import torch.nn as nn

from skrl.models.torch import Model, GaussianMixin


# define the model
class MLP(GaussianMixin, Model):
    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        clip_actions=False,
        clip_mean_actions=False,
        clip_log_std=True,
        min_log_std=-20,
        max_log_std=2,
        reduction="sum",
    ):
        Model.__init__(
            self,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            device=device,
        )
        GaussianMixin.__init__(
            self,
            clip_actions=clip_actions,
            clip_mean_actions=clip_mean_actions,
            clip_log_std=clip_log_std,
            min_log_std=min_log_std,
            max_log_std=max_log_std,
            reduction=reduction,
        )

        self.net = nn.Sequential(
            nn.Linear(self.num_observations, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, self.num_actions),
            nn.Tanh(),
        )

        self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))

    def compute(self, inputs, role):
        return self.net(inputs["observations"]), {"log_std": self.log_std_parameter}


# instantiate the model (given a wrapped environment: `env`)
policy = MLP(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    clip_actions=True,
    clip_mean_actions=True,
    clip_log_std=True,
    min_log_std=-20,
    max_log_std=2,
    reduction="sum",
)

API


PyTorch

GaussianMixin

Gaussian mixin model (stochastic model).

class skrl.models.torch.gaussian.GaussianMixin(*, clip_actions: bool = False, clip_mean_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20, max_log_std: float = 2, reduction: Literal['mean', 'sum', 'prod', 'none'] = 'sum', role: str = '')

Bases: object


Parameters:
  • clip_actions – Flag to indicate whether the actions should be clipped to the action space.

  • clip_mean_actions – Flag to indicate whether the mean actions should be clipped to the action space. If True, the mean actions will be clipped before sampling the actions.

  • clip_log_std – Flag to indicate whether the log standard deviations should be clipped.

  • min_log_std – Minimum value of the log standard deviation if clip_log_std is True.

  • max_log_std – Maximum value of the log standard deviation if clip_log_std is True.

  • reduction – Reduction method ("mean", "sum", "prod" or "none") applied across the action dimensions of the log probability density function. If "none", the log probability density function is returned as a tensor of shape (num_samples, num_actions) instead of (num_samples, 1).

  • role – Role played by the model.
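The three clipping flags act at different points. The per-dimension sketch below uses plain Python lists; the helper sample_action and its low/high arguments (standing in for the action space bounds) are hypothetical, with the other argument names mirroring the parameters above. It assumes clip_actions clamps the sampled actions, while clip_mean_actions clamps the means before sampling, as documented.

```python
import math
import random


def clamp(x, low, high):
    """Limit x to the closed interval [low, high]."""
    return max(low, min(high, x))


def sample_action(mean, log_std, low, high, *, clip_actions=False,
                  clip_mean_actions=False, clip_log_std=True,
                  min_log_std=-20.0, max_log_std=2.0):
    """Hypothetical per-dimension sketch of how the clipping flags interact."""
    if clip_log_std:  # bound log_std to [min_log_std, max_log_std]
        log_std = [clamp(s, min_log_std, max_log_std) for s in log_std]
    if clip_mean_actions:  # clip the means *before* sampling
        mean = [clamp(m, lo, hi) for m, lo, hi in zip(mean, low, high)]
    # sample a ~ N(mean, exp(log_std)**2) independently per dimension
    actions = [m + math.exp(s) * random.gauss(0.0, 1.0) for m, s in zip(mean, log_std)]
    if clip_actions:  # clip the sampled actions *after* sampling
        actions = [clamp(a, lo, hi) for a, lo, hi in zip(actions, low, high)]
    return actions, mean, log_std


random.seed(0)
actions, mean, log_std = sample_action(
    mean=[5.0, -5.0], log_std=[10.0, -30.0], low=[-1.0, -1.0], high=[1.0, 1.0],
    clip_actions=True, clip_mean_actions=True,
)
print(log_std)  # [2.0, -20.0]
print(mean)     # [1.0, -1.0]
```

With both action flags set, the means and the final actions stay inside the bounds, even though the unclipped standard deviation would be very large.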

Raises:

ValueError – If the reduction method is not valid.
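A minimal sketch of the documented reduction names and of rejecting unknown ones with ValueError, assuming one sample's per-dimension log-densities as a plain list. The REDUCTIONS table and get_reduction helper are hypothetical; a one-element result mirrors the (num_samples, 1) shape and "none" mirrors (num_samples, num_actions).

```python
import math

# Hypothetical table mapping the documented reduction names to functions that
# reduce the per-dimension log-densities of one sample; "none" keeps them all.
REDUCTIONS = {
    "mean": lambda values: [sum(values) / len(values)],
    "sum": lambda values: [sum(values)],
    "prod": lambda values: [math.prod(values)],
    "none": lambda values: list(values),
}


def get_reduction(name):
    """Return the reduction called `name`, rejecting unknown names."""
    if name not in REDUCTIONS:
        raise ValueError(f"Unknown reduction method: {name!r}")
    return REDUCTIONS[name]


log_prob = [-1.0, -2.0, -3.0]  # one sample, three action dimensions
print(get_reduction("sum")(log_prob))   # [-6.0]
print(get_reduction("none")(log_prob))  # [-1.0, -2.0, -3.0]
```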

Methods:

act(inputs, *[, role])

Act stochastically in response to the observations/states of the environment.

distribution(*[, role])

Get the current distribution of the model.

get_entropy(*[, role])

Compute and return the entropy of the model.

act(inputs: dict[str, Any], *, role: str = '') -> tuple[torch.Tensor, dict[str, Any]]

Act stochastically in response to the observations/states of the environment.

Parameters:
  • inputs – Model inputs. The most common keys are:

    • "observations": observation of the environment used to make the decision.

    • "states": state of the environment used to make the decision.

    • "taken_actions": actions taken by the policy for the given observations/states.

  • role – Role played by the model.

Returns:

Model output. The first component is the action sampled from the Gaussian distribution (clipped to the action space if clip_actions is True). The second component is a dictionary containing the following extra output values:

  • "log_std": log of the standard deviation.

  • "log_prob": log of the probability density function.

  • "mean_actions": mean actions (network output after optional clipping).
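The stochastic step can be sketched for a single sample with plain Python. The function act_sketch is hypothetical: it samples with the reparameterization a = mean + exp(log_std) * eps, eps ~ N(0, 1), and sums the per-dimension Gaussian log-densities (the "sum" reduction). It assumes, as the "taken_actions" input key suggests, that the density is evaluated at the taken actions when they are given and at the sampled actions otherwise.

```python
import math
import random


def gaussian_log_prob(x, mean, log_std):
    """Log-density of N(mean, exp(log_std)**2) evaluated at x (one dimension)."""
    std = math.exp(log_std)
    return -((x - mean) ** 2) / (2 * std**2) - log_std - 0.5 * math.log(2 * math.pi)


def act_sketch(mean_actions, log_std, taken_actions=None):
    """Hypothetical sketch of a stochastic Gaussian act() for one sample."""
    # Reparameterization: action = mean + std * eps, with eps ~ N(0, 1)
    actions = [m + math.exp(s) * random.gauss(0.0, 1.0) for m, s in zip(mean_actions, log_std)]
    # Evaluate the density at the taken actions if given, else at the samples
    points = taken_actions if taken_actions is not None else actions
    log_prob = sum(gaussian_log_prob(x, m, s) for x, m, s in zip(points, mean_actions, log_std))
    return actions, {"log_std": log_std, "log_prob": [log_prob], "mean_actions": mean_actions}


random.seed(0)
actions, outputs = act_sketch([0.0, 0.5], [0.0, 0.0], taken_actions=[0.0, 0.5])
# Taken actions equal to the means with unit std give log_prob == -log(2 * pi)
print(outputs["log_prob"])
```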

distribution(*, role: str = '') -> torch.distributions.Normal

Get the current distribution of the model.

Parameters:

role – Role played by the model.

Returns:

Distribution of the model.
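The word "current" suggests that this is the distribution built during the most recent act() call. The sketch below assumes that caching behaviour. It uses Python's statistics.NormalDist as a stand-in for torch.distributions.Normal, which likewise takes the standard deviation exp(log_std) as its scale rather than the variance. The class CachedGaussianPolicy is hypothetical.

```python
import math
from statistics import NormalDist


class CachedGaussianPolicy:
    """Hypothetical policy that keeps the distribution built by its last act()."""

    def __init__(self, log_std):
        self.log_std = log_std
        self._distribution = None  # nothing is cached before the first act()

    def act(self, mean_actions):
        # One independent Normal per action dimension, with sigma = exp(log_std)
        self._distribution = [
            NormalDist(mu=m, sigma=math.exp(s)) for m, s in zip(mean_actions, self.log_std)
        ]
        return [d.samples(1)[0] for d in self._distribution]

    def distribution(self):
        """Return the per-dimension distributions from the most recent act()."""
        if self._distribution is None:
            raise RuntimeError("act() must be called before distribution()")
        return self._distribution


policy = CachedGaussianPolicy(log_std=[0.0, math.log(2.0)])
actions = policy.act([1.0, -1.0])
print([(d.mean, round(d.stdev, 6)) for d in policy.distribution()])
# [(1.0, 1.0), (-1.0, 2.0)]
```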

get_entropy(*, role: str = '') -> torch.Tensor

Compute and return the entropy of the model.

Parameters:

role – Role played by the model.

Returns:

Entropy of the model.
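For reference, the differential entropy of a Normal distribution with standard deviation sigma is 0.5 + 0.5 * log(2 * pi) + log(sigma) per dimension, which is the formula torch.distributions.Normal.entropy() uses. The hypothetical helper below computes it from log_std. It returns one value per dimension; any further reduction is outside this sketch.

```python
import math


def gaussian_entropy(log_std):
    """Per-dimension differential entropy (in nats) of N(mu, exp(log_std)**2)."""
    # H = 0.5 + 0.5 * log(2 * pi) + log(sigma), with log(sigma) = log_std
    return [0.5 + 0.5 * math.log(2 * math.pi) + s for s in log_std]


entropy = gaussian_entropy([0.0, 1.0])
# The mean does not appear; each unit increase of log_std adds one nat
print(entropy[1] - entropy[0])
```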


JAX

GaussianMixin

Gaussian mixin model (stochastic model).

class skrl.models.jax.gaussian.GaussianMixin(*, clip_actions: bool = False, clip_mean_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20, max_log_std: float = 2, reduction: Literal['mean', 'sum', 'prod', 'none'] = 'sum', role: str = '')

Bases: object


Parameters:
  • clip_actions – Flag to indicate whether the actions should be clipped to the action space.

  • clip_mean_actions – Flag to indicate whether the mean actions should be clipped to the action space. If True, the mean actions will be clipped before sampling the actions.

  • clip_log_std – Flag to indicate whether the log standard deviations should be clipped.

  • min_log_std – Minimum value of the log standard deviation if clip_log_std is True.

  • max_log_std – Maximum value of the log standard deviation if clip_log_std is True.

  • reduction – Reduction method ("mean", "sum", "prod" or "none") applied across the action dimensions of the log probability density function. If "none", the log probability density function is returned as a tensor of shape (num_samples, num_actions) instead of (num_samples, 1).

  • role – Role played by the model.

Raises:

ValueError – If the reduction method is not valid.

Methods:

act(inputs, *[, role, params])

Act stochastically in response to the observations/states of the environment.

get_entropy(stddev, *[, role])

Compute and return the entropy of the model.

act(inputs: dict[str, Any], *, role: str = '', params: jax.Array | None = None) -> tuple[jax.Array, dict[str, Any]]

Act stochastically in response to the observations/states of the environment.

Parameters:
  • inputs – Model inputs. The most common keys are:

    • "observations": observation of the environment used to make the decision.

    • "states": state of the environment used to make the decision.

    • "taken_actions": actions taken by the policy for the given observations/states.

  • role – Role played by the model.

  • params – Parameters used to compute the output. If not provided, internal parameters will be used.
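The optional params argument follows a functional style that is common in JAX, where gradients are usually taken with respect to an explicitly passed parameter pytree. The framework-free sketch below shows only the fallback rule: explicit parameters override the stored ones, and the stored ones are used otherwise. TinyGaussianPolicy, its linear mean and its parameter dict are hypothetical.

```python
import math
import random


class TinyGaussianPolicy:
    """Hypothetical model with a linear mean and a stored parameter dict."""

    def __init__(self, weight, log_std):
        self.params = {"weight": weight, "log_std": log_std}  # internal parameters

    def act(self, inputs, *, role="", params=None):
        # Fall back to the internal parameters when none are passed explicitly
        params = self.params if params is None else params
        mean = [params["weight"] * x for x in inputs["observations"]]
        std = math.exp(params["log_std"])
        actions = [m + std * random.gauss(0.0, 1.0) for m in mean]
        return actions, {"log_std": params["log_std"], "mean_actions": mean}


policy = TinyGaussianPolicy(weight=2.0, log_std=0.0)
inputs = {"observations": [1.0, 3.0]}
_, out_internal = policy.act(inputs)
_, out_explicit = policy.act(inputs, params={"weight": -1.0, "log_std": 0.0})
print(out_internal["mean_actions"], out_explicit["mean_actions"])  # [2.0, 6.0] [-1.0, -3.0]
```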

Returns:

Model output. The first component is the action sampled from the Gaussian distribution (clipped to the action space if clip_actions is True). The second component is a dictionary containing the following extra output values:

  • "log_std": log of the standard deviation.

  • "log_prob": log of the probability density function.

  • "mean_actions": mean actions (network output after optional clipping).

get_entropy(stddev: jax.Array, *, role: str = '') -> jax.Array

Compute and return the entropy of the model.

Parameters:
  • stddev – Model standard deviation.

  • role – Role played by the model.

Returns:

Entropy of the model.


Warp

GaussianMixin

Gaussian mixin model (stochastic model).