Shared model

Sometimes it is desirable to define models that use shared layers or networks to represent multiple function approximators (for example, a policy and a value function). This practice, known as shared parameters (or parameter sharing), shared layers, shared model, shared networks or joint architecture, among other names, is typically justified by the following criteria:

  • Learn common features, especially when processing large inputs (such as images).

  • Reduce the number of parameters in the whole system.

  • Make the computation more efficient.
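To make the parameter-reduction argument concrete, the following pure-Python sketch counts weights and biases for a policy and a value function, first as two separate networks and then with a shared trunk. The sizes (8 observations, 2 actions, two hidden layers of 32 units, plus a log-standard-deviation vector for the policy) are assumptions chosen to mirror the example later in this section; activation layers are ignored because they have no parameters.

```python
# Hypothetical sizes: 8 observations, 2 actions, 32 hidden units.


def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def trunk_params(n_obs: int, hidden: int) -> int:
    """Two stacked hidden layers: n_obs -> hidden -> hidden."""
    return linear_params(n_obs, hidden) + linear_params(hidden, hidden)


def separate_models(n_obs: int, n_act: int, hidden: int = 32) -> int:
    """Policy and value function, each with its own copy of the trunk."""
    policy = trunk_params(n_obs, hidden) + linear_params(hidden, n_act) + n_act  # + log-std vector
    value = trunk_params(n_obs, hidden) + linear_params(hidden, 1)
    return policy + value


def shared_model(n_obs: int, n_act: int, hidden: int = 32) -> int:
    """Policy and value heads on top of a single, shared trunk."""
    trunk = trunk_params(n_obs, hidden)  # counted only once
    return trunk + linear_params(hidden, n_act) + n_act + linear_params(hidden, 1)


print(separate_models(8, 2), shared_model(8, 2))  # 2789 1445
```

With these sizes, sharing saves exactly one copy of the trunk (1344 parameters); the relative saving grows as the shared part, such as an image encoder, dominates the heads.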


By combining the implemented mixins, it is possible to define shared models with skrl. In these cases, the role argument (a Python string) becomes relevant: agents call the models by setting the role argument according to their requirements. See each agent's documentation (the Key column of the table in the Spaces and models section) for the values this argument can take.
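The following toy sketch mimics, in plain Python, how one object can serve several agent keys, with the role string selecting the behavior. `ToySharedModel` and `agent_step` are hypothetical names and this is not skrl's agent code; it only assumes that an agent with "policy" and "value" keys queries each model through `.act(inputs, role=...)` using the key as the role.

```python
from typing import Any, Dict, Tuple


class ToySharedModel:
    """Hypothetical stand-in for a shared model: one object, several roles."""

    def act(self, inputs: Dict[str, Any], role: str) -> str:
        # a real model would run the network head selected by `role`
        if role == "policy":
            return f"action for {inputs['states']}"
        if role == "value":
            return f"value of {inputs['states']}"
        raise ValueError(f"unknown role: {role!r}")


def agent_step(models: Dict[str, ToySharedModel], states: Any) -> Tuple[str, str]:
    """Hypothetical agent step: each model is queried with its own key as the role."""
    action = models["policy"].act({"states": states}, role="policy")
    value = models["value"].act({"states": states}, role="value")
    return action, value


shared = ToySharedModel()
models = {"policy": shared, "value": shared}  # the same instance under both keys
print(agent_step(models, "s0"))  # ('action for s0', 'value of s0')
```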

The code snippet below shows how to define a shared model. The following practices for building shared models can be identified:

  • The list of base classes (multiple inheritance) must always place the Model base class last.

  • The Model base class constructor must be invoked before the mixin constructors.

  • All mixin constructors must be invoked.

    • Specifying the role argument is optional if all constructors belong to different mixins.

    • If multiple models of the same mixin type are required, the same constructor must be invoked as many times as needed. In that case, specifying the role argument is mandatory.

  • The .act(...) method must be overridden to disambiguate its call, since each inherited mixin provides its own implementation.

  • The same instance of the shared model must be passed to all keys involved.
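Some of these rules follow from Python's method resolution order (MRO). The sketch below uses hypothetical toy classes (`ToyModel`, `ToyGaussianMixin`, `ToyDeterministicMixin`), not skrl's implementation, and assumes each mixin keeps its settings in a dictionary keyed by role. It shows that without an override `act` resolves to the leftmost mixin, that an overriding `act` dispatches by role, and that per-role storage lets the same mixin constructor run twice.

```python
class ToyModel:
    """Hypothetical stand-in for a model base class."""

    def __init__(self, num_actions):
        self.num_actions = num_actions

    def act(self, inputs, role=""):
        # if ToyModel came first in the bases, this would shadow the mixins
        raise NotImplementedError


class ToyGaussianMixin:
    """Hypothetical mixin that stores its settings per role."""

    def __init__(self, clip_actions=False, role=""):
        if not hasattr(self, "_gaussian_clip"):
            self._gaussian_clip = {}
        self._gaussian_clip[role] = clip_actions

    def act(self, inputs, role=""):
        # raises KeyError if this mixin was never initialized for `role`
        return ("gaussian", role, self._gaussian_clip[role])


class ToyDeterministicMixin:
    """Hypothetical mixin that stores its settings per role."""

    def __init__(self, clip_actions=False, role=""):
        if not hasattr(self, "_deterministic_clip"):
            self._deterministic_clip = {}
        self._deterministic_clip[role] = clip_actions

    def act(self, inputs, role=""):
        return ("deterministic", role, self._deterministic_clip[role])


class NoOverride(ToyGaussianMixin, ToyDeterministicMixin, ToyModel):
    def __init__(self):
        ToyModel.__init__(self, num_actions=2)  # base class constructor first
        ToyGaussianMixin.__init__(self, role="policy")
        ToyDeterministicMixin.__init__(self, role="value")


class WithOverride(NoOverride):
    # dispatch explicitly on the role (the pattern SharedModel uses)
    def act(self, inputs, role=""):
        if role == "policy":
            return ToyGaussianMixin.act(self, inputs, role)
        if role == "value":
            return ToyDeterministicMixin.act(self, inputs, role)
        raise ValueError(f"unknown role: {role!r}")


class TwoCritics(ToyDeterministicMixin, ToyModel):
    # the same mixin constructor twice: distinct roles keep both settings
    def __init__(self):
        ToyModel.__init__(self, num_actions=1)
        ToyDeterministicMixin.__init__(self, clip_actions=False, role="critic_1")
        ToyDeterministicMixin.__init__(self, clip_actions=True, role="critic_2")


print([cls.__name__ for cls in NoOverride.__mro__])
# ['NoOverride', 'ToyGaussianMixin', 'ToyDeterministicMixin', 'ToyModel', 'object']
print(WithOverride().act({}, role="value"))   # ('deterministic', 'value', False)
print(TwoCritics().act({}, role="critic_2"))  # ('deterministic', 'critic_2', True)
```

In this toy, `NoOverride().act({}, role="value")` reaches `ToyGaussianMixin.act` and raises a `KeyError`, which is the ambiguity the override removes. Had both `TwoCritics` constructor calls omitted the role, the second call would silently overwrite the first one's settings.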

```python
import torch
import torch.nn as nn

from skrl.models.torch import Model, GaussianMixin, DeterministicMixin

# define the shared model
class SharedModel(GaussianMixin, DeterministicMixin, Model):
    def __init__(self, observation_space, action_space, device, clip_actions=False,
                 clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
        Model.__init__(self, observation_space, action_space, device)
        GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction, role="policy")
        DeterministicMixin.__init__(self, clip_actions, role="value")

        # shared layers/network
        self.net = nn.Sequential(nn.Linear(self.num_observations, 32),
                                 nn.ELU(),
                                 nn.Linear(32, 32),
                                 nn.ELU())

        # separated layers ("policy")
        self.mean_layer = nn.Linear(32, self.num_actions)
        self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))

        # separated layer ("value")
        self.value_layer = nn.Linear(32, 1)

    # override the .act(...) method to disambiguate its call
    def act(self, inputs, role):
        if role == "policy":
            return GaussianMixin.act(self, inputs, role)
        elif role == "value":
            return DeterministicMixin.act(self, inputs, role)

    # forward the input through the shared network, then through the head for the specified role
    def compute(self, inputs, role):
        if role == "policy":
            return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
        elif role == "value":
            return self.value_layer(self.net(inputs["states"])), {}

# instantiate the shared model and pass the same instance to the other key
# (`env` is assumed to be an already created and wrapped skrl environment)
models = {}
models["policy"] = SharedModel(env.observation_space, env.action_space, env.device)
models["value"] = models["policy"]
```