Multi-Categorical model

Multi-Categorical models run discrete-domain stochastic policies.



skrl provides a Python mixin (MultiCategoricalMixin) to assist in the creation of these types of models, allowing users to have full control over the function approximator definitions and architectures. Note that the use of this mixin must comply with the following rules:

  • The definition of multiple inheritance must always include the Model base class at the end.

  • The Model base class constructor must be invoked before the mixin's constructor.
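Both rules follow from how Python resolves multiple inheritance. The sketch below uses hypothetical stand-in classes (Model and ToyMixin). They are not the skrl classes and only illustrate the mechanics. Listing Model last puts the mixin's methods ahead of Model's in the method resolution order. Calling Model.__init__ first matters because, in this toy, the mixin reads an attribute that the base class constructor creates.

```python
# Hypothetical stand-ins used only to illustrate the two inheritance rules;
# these are NOT skrl's Model / MultiCategoricalMixin classes.

class Model:
    def __init__(self, num_actions):
        # the base class creates attributes that the toy mixin relies on
        self.num_actions = num_actions

    def act(self):
        # placeholder that the mixin is expected to override
        raise NotImplementedError("place the mixin before Model")


class ToyMixin:
    def __init__(self, reduction="sum"):
        self.reduction = reduction
        # reads an attribute created by Model.__init__, so that constructor
        # must already have run (otherwise: AttributeError)
        self.info = f"{self.num_actions} logits, reduction={reduction}"

    def act(self):
        return f"mixin act ({self.info})"


class Policy(ToyMixin, Model):  # rule 1: Model listed last
    def __init__(self, num_actions, reduction="sum"):
        Model.__init__(self, num_actions)   # rule 2: base class constructor first
        ToyMixin.__init__(self, reduction)  # then the mixin constructor


policy = Policy(num_actions=5)
print([cls.__name__ for cls in Policy.__mro__])  # ['Policy', 'ToyMixin', 'Model', 'object']
print(policy.act())                              # mixin act (5 logits, reduction=sum)


# Listing Model first lets the placeholder Model.act shadow the mixin's act
class WrongOrder(Model, ToyMixin):
    pass
```

Calling `WrongOrder(3).act()` raises NotImplementedError. Swapping the two constructor calls in `Policy.__init__` raises AttributeError instead.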

Warning

For models in JAX/Flax it is imperative to define all parameters (except observation_space, action_space and device) with default values to avoid errors (TypeError: __init__() missing N required positional argument) during initialization.

In addition, it is necessary to initialize the model’s state_dict (via the init_state_dict method) after its instantiation to avoid errors (AttributeError: object has no attribute "state_dict". If "state_dict" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply') during its use.

class MultiCategoricalModel(MultiCategoricalMixin, Model):
    def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, reduction="sum"):
        Model.__init__(self, observation_space, action_space, device)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)

Concept

[Figure: Multi-Categorical model]
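As a rough illustration of the concept, the following PyTorch sketch treats each component of a MultiDiscrete action as an independent categorical variable. The network output holds one logit per category of every component, for example 3 + 2 = 5 outputs for nvec = [3, 2], which matches the out_features=5 shown in the API example further down. That output is split into one chunk per component, and each component is sampled separately. The per-component log-probabilities are then summed into a joint log-probability, which corresponds to the default "sum" reduction. The sizes, the random network output, and the variable names are illustrative assumptions, not skrl internals.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)

nvec = [3, 2]      # assumed MultiDiscrete action space (two components)
batch_size = 4
# stand-in for the network output: one logit per category of every component
net_output = torch.randn(batch_size, sum(nvec))

# one independent Categorical distribution per action component
distributions = [Categorical(logits=chunk) for chunk in torch.split(net_output, nvec, dim=-1)]

# sample every component and stack them into a (batch_size, len(nvec)) action
actions = torch.stack([d.sample() for d in distributions], dim=-1)

# per-component log-probabilities, then their sum (joint log-probability)
log_prob = torch.stack(
    [d.log_prob(a) for d, a in zip(distributions, actions.unbind(dim=-1))], dim=-1
)
joint_log_prob = log_prob.sum(dim=-1, keepdim=True)

print(actions.shape, log_prob.shape, joint_log_prob.shape)
# torch.Size([4, 2]) torch.Size([4, 2]) torch.Size([4, 1])
```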

Usage

  • Multi-Layer Perceptron (MLP)

  • Convolutional Neural Network (CNN)

  • Recurrent Neural Network (RNN)

  • Gated Recurrent Unit RNN (GRU)

  • Long Short-Term Memory RNN (LSTM)

[Figure: MLP model]
import torch
import torch.nn as nn

from skrl.models.torch import Model, MultiCategoricalMixin


# define the model
class MLP(MultiCategoricalMixin, Model):
    def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True, reduction="sum"):
        Model.__init__(self, observation_space, action_space, device)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)

        self.net = nn.Sequential(nn.Linear(self.num_observations, 64),
                                 nn.ReLU(),
                                 nn.Linear(64, 32),
                                 nn.ReLU(),
                                 nn.Linear(32, self.num_actions))

    def compute(self, inputs, role):
        return self.net(inputs["states"]), {}


# instantiate the model (assumes there is a wrapped environment: env)
policy = MLP(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             unnormalized_log_prob=True,
             reduction="sum")

API (PyTorch)

class skrl.models.torch.multicategorical.MultiCategoricalMixin(unnormalized_log_prob: bool = True, reduction: str = 'sum', role: str = '')

Bases: object

__init__(unnormalized_log_prob: bool = True, reduction: str = 'sum', role: str = '') → None

MultiCategorical mixin model (stochastic model)

Parameters:
  • unnormalized_log_prob (bool, optional) – Flag indicating how the model's output is interpreted (default: True). If True, the output is interpreted as unnormalized log probabilities (any real number); otherwise, it is interpreted as normalized probabilities (the output must be non-negative, finite and have a non-zero sum)

  • reduction (str, optional) – Reduction method applied to the returned log probability density function (default: "sum"). Supported values are "mean", "sum", "prod" and "none". If "none", the log probability density function is returned as a tensor of shape (num_samples, num_actions) instead of (num_samples, 1)

  • role (str, optional) – Role played by the model (default: "")
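The difference between the two interpretations of the model's output can be seen with torch.distributions.Categorical directly. This sketch assumes that the flag amounts to choosing between Categorical's logits= and probs= arguments. Check the skrl source for the exact wiring. With logits, any real values are valid and a softmax is applied. With probs, the values only need to be non-negative and finite with a non-zero sum, and they are divided by their sum.

```python
import torch
from torch.distributions import Categorical

# unnormalized_log_prob=True: any real numbers, interpreted as logits
logits = torch.tensor([[2.0, -1.0, 0.5]])
# unnormalized_log_prob=False: non-negative, finite, non-zero sum
weights = torch.tensor([[4.0, 1.0, 3.0]])

from_logits = Categorical(logits=logits)  # probabilities = softmax(logits)
from_probs = Categorical(probs=weights)   # probabilities = weights / weights.sum()

print(from_logits.probs)
print(from_probs.probs)   # values 0.5, 0.125, 0.375

# the two views are consistent: log(weights) used as logits gives the same distribution
same = Categorical(logits=weights.log())
print(torch.allclose(same.probs, from_probs.probs))  # True
```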

Raises:

ValueError – If the reduction method is not valid
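The supported reduction values, and the ValueError raised for anything else, can be illustrated with a small PyTorch sketch. The helpers make_reduction and reduce_log_prob are hypothetical and are not skrl's code. They mimic the documented behavior: "mean", "sum" and "prod" collapse the per-component values into a single column, and "none" leaves them unreduced.

```python
import torch


def make_reduction(reduction: str):
    """Hypothetical helper: map a documented reduction name to a torch function."""
    if reduction not in ("mean", "sum", "prod", "none"):
        raise ValueError("reduction must be one of 'mean', 'sum', 'prod' or 'none'")
    # returns None for "none" (no reduction)
    return {"mean": torch.mean, "sum": torch.sum, "prod": torch.prod}.get(reduction)


def reduce_log_prob(log_prob: torch.Tensor, reduction: str) -> torch.Tensor:
    """Reduce per-component log-probabilities of shape (num_samples, num_components)."""
    fn = make_reduction(reduction)
    if fn is None:
        return log_prob                        # (num_samples, num_components)
    return fn(log_prob, dim=-1).unsqueeze(-1)  # (num_samples, 1)


# two samples, two action components each
log_prob = torch.log(torch.tensor([[0.5, 0.25], [0.1, 0.2]]))
for name in ("mean", "sum", "prod", "none"):
    print(name, tuple(reduce_log_prob(log_prob, name).shape))
# mean (2, 1)
# sum (2, 1)
# prod (2, 1)
# none (2, 2)
```

Passing an unsupported name such as "max" to make_reduction raises ValueError.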

Example:

# define the model
>>> import torch
>>> import torch.nn as nn
>>> from skrl.models.torch import Model, MultiCategoricalMixin
>>>
>>> class Policy(MultiCategoricalMixin, Model):
...     def __init__(self, observation_space, action_space, device="cuda:0", unnormalized_log_prob=True, reduction="sum"):
...         Model.__init__(self, observation_space, action_space, device)
...         MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
...
...         self.net = nn.Sequential(nn.Linear(self.num_observations, 32),
...                                  nn.ELU(),
...                                  nn.Linear(32, 32),
...                                  nn.ELU(),
...                                  nn.Linear(32, self.num_actions))
...
...     def compute(self, inputs, role):
...         return self.net(inputs["states"]), {}
...
>>> # given an observation_space: gym.spaces.Box with shape (4,)
>>> # and an action_space: gym.spaces.MultiDiscrete with nvec = [3, 2]
>>> model = Policy(observation_space, action_space)
>>>
>>> print(model)
Policy(
  (net): Sequential(
    (0): Linear(in_features=4, out_features=32, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=32, out_features=5, bias=True)
  )
)
act(inputs: Mapping[str, torch.Tensor | Any], role: str = '') → Tuple[torch.Tensor, torch.Tensor | None, Mapping[str, torch.Tensor | Any]]

Act stochastically in response to the state of the environment

Parameters:
  • inputs (dict where the values are typically torch.Tensor) –

    Model inputs. The most common keys are:

    • "states": state of the environment used to make the decision

    • "taken_actions": actions taken by the policy for the given states

  • role (str, optional) – Role played by the model (default: "")

Returns:

Model output. The first component is the action to be taken by the agent. The second component is the log of the probability density function. The third component is a dictionary containing the network output "net_output" and extra output values

Return type:

tuple of torch.Tensor, torch.Tensor or None, and dict

Example:

>>> # given a batch of sample states with shape (4096, 4)
>>> actions, log_prob, outputs = model.act({"states": states})
>>> print(actions.shape, log_prob.shape, outputs["net_output"].shape)
torch.Size([4096, 2]) torch.Size([4096, 1]) torch.Size([4096, 5])
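The act flow described above can be pictured with the following PyTorch sketch. It is an illustration under stated assumptions, not skrl's implementation. multicategorical_act is a hypothetical function that receives a precomputed network output instead of calling the model. It assumes nvec = [3, 2] and the default settings (unnormalized_log_prob=True, reduction="sum"). It shows both paths: sampling new actions, and evaluating actions passed in as "taken_actions".

```python
import torch
from torch.distributions import Categorical


def multicategorical_act(net_output, nvec, taken_actions=None):
    """Hypothetical sketch of the act() flow (not skrl's implementation).

    net_output: (num_samples, sum(nvec)) unnormalized log probabilities
    taken_actions: optional (num_samples, len(nvec)) integer actions to evaluate
    """
    distributions = [Categorical(logits=chunk) for chunk in torch.split(net_output, nvec, dim=-1)]
    # sample new actions, or evaluate the ones that were already taken
    if taken_actions is None:
        actions = torch.stack([d.sample() for d in distributions], dim=-1)
    else:
        actions = taken_actions
    log_prob = torch.stack(
        [d.log_prob(a) for d, a in zip(distributions, actions.unbind(dim=-1))], dim=-1
    )
    # "sum" reduction: joint log-probability of the independent components
    return actions, log_prob.sum(dim=-1, keepdim=True), {"net_output": net_output}


net_output = torch.randn(4096, 5)  # stand-in for the output of the model's compute()
actions, log_prob, outputs = multicategorical_act(net_output, [3, 2])
print(actions.shape, log_prob.shape, outputs["net_output"].shape)
# torch.Size([4096, 2]) torch.Size([4096, 1]) torch.Size([4096, 5])

# re-evaluating the same actions reproduces their log-probabilities
_, same_log_prob, _ = multicategorical_act(net_output, [3, 2], taken_actions=actions)
```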
distribution(role: str = '') → torch.distributions.Categorical

Get the current distribution of the model

Returns:

The first of the model's distributions (the Categorical distribution of the first action component)

Return type:

torch.distributions.Categorical

Parameters:

role (str, optional) – Role played by the model (default: "")

Example:

>>> distribution = model.distribution()
>>> print(distribution)
Categorical(probs: torch.Size([10, 3]), logits: torch.Size([10, 3]))
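The return value is documented as the first of the model's distributions. As a sketch, assume a multi-categorical model keeps one torch Categorical distribution per action component. The returned object then covers only the first component, which is consistent with the example above reporting probs of shape [10, 3] for nvec = [3, 2]. The batch size of 10 comes from that example. The random network output is an assumption.

```python
import torch
from torch.distributions import Categorical

nvec = [3, 2]                            # assumed MultiDiscrete action space
net_output = torch.randn(10, sum(nvec))  # batch of 10 network outputs
distributions = [Categorical(logits=chunk) for chunk in torch.split(net_output, nvec, dim=-1)]

first = distributions[0]  # covers only the first action component
print([d.probs.shape for d in distributions])
# [torch.Size([10, 3]), torch.Size([10, 2])]
```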
get_entropy(role: str = '') → torch.Tensor

Compute and return the entropy of the model

Returns:

Entropy of the model

Return type:

torch.Tensor

Parameters:

role (str, optional) – Role played by the model (default: "")

Example:

>>> entropy = model.get_entropy()
>>> print(entropy.shape)
torch.Size([4096, 1])
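Assuming the per-component entropies are combined with the configured reduction (here the default "sum"), the (4096, 1) shape in the example above can be reproduced with plain PyTorch. For independent components, the sum of the per-component entropies equals the entropy of the joint distribution. The random network output and nvec = [3, 2] are assumptions.

```python
import torch
from torch.distributions import Categorical

nvec = [3, 2]                              # assumed MultiDiscrete action space
net_output = torch.randn(4096, sum(nvec))  # stand-in for the network output
distributions = [Categorical(logits=chunk) for chunk in torch.split(net_output, nvec, dim=-1)]

# per-component entropies: (4096, 2)
entropy = torch.stack([d.entropy() for d in distributions], dim=-1)
# "sum" reduction, keeping a trailing dimension: (4096, 1)
reduced = entropy.sum(dim=-1, keepdim=True)
print(entropy.shape, reduced.shape)
# torch.Size([4096, 2]) torch.Size([4096, 1])

# sanity check: a uniform 3-way categorical has entropy log(3), about 1.0986
uniform = Categorical(logits=torch.zeros(3))
print(uniform.entropy())
```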

API (JAX)

class skrl.models.jax.multicategorical.MultiCategoricalMixin(unnormalized_log_prob: bool = True, reduction: str = 'sum', role: str = '')

Bases: object

__init__(unnormalized_log_prob: bool = True, reduction: str = 'sum', role: str = '') → None

MultiCategorical mixin model (stochastic model)

Parameters:
  • unnormalized_log_prob (bool, optional) – Flag indicating how the model's output is interpreted (default: True). If True, the output is interpreted as unnormalized log probabilities (any real number); otherwise, it is interpreted as normalized probabilities (the output must be non-negative, finite and have a non-zero sum)

  • reduction (str, optional) – Reduction method applied to the returned log probability density function (default: "sum"). Supported values are "mean", "sum", "prod" and "none". If "none", the log probability density function is returned as a tensor of shape (num_samples, num_actions) instead of (num_samples, 1)

  • role (str, optional) – Role played by the model (default: "")

Raises:

ValueError – If the reduction method is not valid

Example:

# define the model
>>> import flax.linen as nn
>>> from skrl.models.jax import Model, MultiCategoricalMixin
>>>
>>> class Policy(MultiCategoricalMixin, Model):
...     def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, reduction="sum", **kwargs):
...         Model.__init__(self, observation_space, action_space, device, **kwargs)
...         MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction)
...
...     @nn.compact  # marks the given module method allowing inlined submodules
...     def __call__(self, inputs, role):
...         x = nn.elu(nn.Dense(32)(inputs["states"]))
...         x = nn.elu(nn.Dense(32)(x))
...         x = nn.Dense(self.num_actions)(x)
...         return x, {}
...
>>> # given an observation_space: gym.spaces.Box with shape (4,)
>>> # and an action_space: gym.spaces.MultiDiscrete with nvec = [3, 2]
>>> model = Policy(observation_space, action_space)
>>>
>>> print(model)
Policy(
    # attributes
    observation_space = Box(-1.0, 1.0, (4,), float32)
    action_space = MultiDiscrete([3 2])
    device = StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)
)
act(inputs: Mapping[str, ndarray | jax.Array | Any], role: str = '', params: jax.Array | None = None) → Tuple[jax.Array, jax.Array | None, Mapping[str, jax.Array | Any]]

Act stochastically in response to the state of the environment

Parameters:
  • inputs (dict where the values are typically np.ndarray or jax.Array) –

    Model inputs. The most common keys are:

    • "states": state of the environment used to make the decision

    • "taken_actions": actions taken by the policy for the given states

  • role (str, optional) – Role played by the model (default: "")

  • params (jnp.array) – Parameters used to compute the output (default: None). If None, internal parameters will be used

Returns:

Model output. The first component is the action to be taken by the agent. The second component is the log of the probability density function. The third component is a dictionary containing the network output "net_output" and extra output values

Return type:

tuple of jax.Array, jax.Array or None, and dict

Example:

>>> # given a batch of sample states with shape (4096, 4)
>>> actions, log_prob, outputs = model.act({"states": states})
>>> print(actions.shape, log_prob.shape, outputs["net_output"].shape)
(4096, 2) (4096, 1) (4096, 5)
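Because JAX is functional, the params argument lets act run with an explicitly supplied parameter set. The sketch below uses plain jax.numpy and jax.random rather than skrl. It assumes nvec = [3, 2] and a hypothetical single linear layer whose weights are stored in a params dict. multicategorical_act is likewise hypothetical. It samples each component with jax.random.categorical and sums the per-component log-probabilities, which reproduces the shapes of the example above.

```python
import itertools

import jax
import jax.numpy as jnp


def multicategorical_act(params, states, key, nvec=(3, 2)):
    """Hypothetical sketch: params -> logits -> per-component samples and log-probs."""
    # a single linear layer stands in for the policy network (assumption)
    logits = states @ params["w"] + params["b"]            # (N, sum(nvec))
    split_points = list(itertools.accumulate(nvec))[:-1]   # [3] for nvec (3, 2)
    chunks = jnp.split(logits, split_points, axis=-1)
    keys = jax.random.split(key, len(chunks))
    # one independent categorical sample per component: (N, len(nvec))
    actions = jnp.stack(
        [jax.random.categorical(k, c, axis=-1) for k, c in zip(keys, chunks)], axis=-1
    )
    # log-probability of each sampled category: (N, len(nvec))
    log_probs = jnp.stack(
        [jnp.take_along_axis(jax.nn.log_softmax(c, axis=-1), a[:, None], axis=-1)[:, 0]
         for c, a in zip(chunks, actions.T)],
        axis=-1,
    )
    # "sum" reduction over components: (N, 1)
    return actions, log_probs.sum(axis=-1, keepdims=True), {"net_output": logits}


k_w, k_s, k_act = jax.random.split(jax.random.PRNGKey(0), 3)
params = {"w": 0.1 * jax.random.normal(k_w, (4, 5)), "b": jnp.zeros(5)}
states = jax.random.normal(k_s, (4096, 4))
actions, log_prob, outputs = multicategorical_act(params, states, k_act)
print(actions.shape, log_prob.shape, outputs["net_output"].shape)
# (4096, 2) (4096, 1) (4096, 5)
```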
get_entropy(logits: jax.Array, role: str = '') → jax.Array

Compute and return the entropy of the model

Parameters:

  • logits (jax.Array) – Network output (logits) from which the entropy is computed

  • role (str, optional) – Role played by the model (default: "")

Returns:

Entropy of the model

Return type:

jax.Array

Example:

# given the network output (logits) returned by act: outputs["net_output"]
>>> entropy = model.get_entropy(outputs["net_output"])
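The logits argument in the signature indicates that the entropy is computed from the network output passed in by the caller. Here is a minimal jax.numpy sketch. It assumes nvec = [3, 2] and a "sum" reduction over components; the reduction and output shape of skrl's own get_entropy are not documented here. For each component it computes H = -sum(p * log p) and then adds the per-component values. multicategorical_entropy is a hypothetical helper.

```python
import itertools

import jax
import jax.numpy as jnp


def multicategorical_entropy(logits, nvec=(3, 2)):
    """Hypothetical sketch: per-component categorical entropies, summed over components."""
    split_points = list(itertools.accumulate(nvec))[:-1]
    entropies = []
    for chunk in jnp.split(logits, split_points, axis=-1):
        log_p = jax.nn.log_softmax(chunk, axis=-1)
        entropies.append(-(jnp.exp(log_p) * log_p).sum(axis=-1))  # H = -sum(p * log p)
    return jnp.stack(entropies, axis=-1).sum(axis=-1, keepdims=True)  # (N, 1)


logits = jax.random.normal(jax.random.PRNGKey(0), (4096, 5))
print(multicategorical_entropy(logits).shape)
# (4096, 1)

# uniform components: log(3) + log(2) = log(6), about 1.792
print(multicategorical_entropy(jnp.zeros((1, 5))))
```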