Shared model¶
Sometimes it is desirable to define models that use shared layers or networks to represent multiple function approximators. This practice, known as parameter sharing (also called shared layers, shared model, shared networks or joint architecture, among others), is typically justified by the following criteria:
Learning the same features, especially when processing large inputs (such as images).
Reducing the number of parameters in the whole system.
Making the computation more efficient (a single forward-pass).
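The parameter-reduction criterion can be illustrated with a rough pure-Python count. This is a sketch assuming the two-layer, 32-unit MLP used in the snippet below, together with hypothetical space sizes (a 24-dimensional observation space and a 4-dimensional action space):

```python
def linear_params(in_features, out_features):
    # weights + biases of a fully-connected (nn.Linear-like) layer
    return in_features * out_features + out_features

obs, act, hidden = 24, 4, 32  # hypothetical space sizes

# shared trunk: obs -> 32 -> 32 (two linear layers with biases)
trunk = linear_params(obs, hidden) + linear_params(hidden, hidden)
policy_head = linear_params(hidden, act) + act  # mean layer + log-std parameter
value_head = linear_params(hidden, 1)

shared = trunk + policy_head + value_head                 # one trunk, two heads
separate = (trunk + policy_head) + (trunk + value_head)   # two full networks

print(shared, separate)  # 2025 3881: sharing saves exactly one trunk (1856 parameters here)
```

The savings equal the size of the duplicated trunk, so they grow with the depth and width of the shared layers (e.g.: convolutional encoders for image inputs).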
Implementation¶
By combining the implemented mixins, it is possible to define shared models with skrl. In these cases, the use of the role argument (a Python string) is relevant: the agents will call the models setting the role argument according to their requirements. Visit each agent's documentation (Key column of the table under the Spaces and models section) to know the possible values that this parameter can take.
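The role-based dispatch can be pictured with a minimal stdlib sketch (the class and return values here are hypothetical stand-ins; real skrl agents pass role when calling .act(...) on the model instance):

```python
class TinySharedModel:
    # toy stand-in for a shared model: one object, two behaviors selected by role
    def act(self, inputs, role):
        if role == "policy":
            return ("action-for", inputs["states"])
        elif role == "value":
            return ("value-for", inputs["states"])
        raise ValueError(f"unknown role: {role}")

model = TinySharedModel()
# an agent invokes the same instance with different roles, as its algorithm requires
print(model.act({"states": [0.1, 0.2]}, role="policy"))  # ('action-for', [0.1, 0.2])
print(model.act({"states": [0.1, 0.2]}, role="value"))   # ('value-for', [0.1, 0.2])
```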
The code snippet below shows how to define a shared model. The following practices for building shared models can be identified:
The definition of multiple inheritance must always include the Model base class at the end.
The Model base class constructor must be invoked before the mixins' constructors.
All mixin constructors must be invoked.
Specifying the role argument is optional if all constructors belong to different mixins. If multiple models of the same mixin type are required, the same constructor must be invoked as many times as needed; to do so, it is mandatory to specify the role argument.
The .act(...) method needs to be overridden to disambiguate its call.
The same instance of the shared model must be passed to all keys involved.
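The inheritance and constructor rules above can be sketched in plain Python, with toy classes standing in for skrl's Model and mixins (hypothetical attributes; the real classes do much more):

```python
class Model:  # stand-in for skrl's Model base class
    def __init__(self, device):
        self.device = device

class GaussianMixin:  # stand-in mixin: records the role it was registered under
    def __init__(self, clip_actions, role=""):
        self._gaussian_role = role

class DeterministicMixin:
    def __init__(self, clip_actions, role=""):
        self._deterministic_role = role

# the base class goes last so the mixins take precedence in the method resolution order
class SharedModel(GaussianMixin, DeterministicMixin, Model):
    def __init__(self, device):
        Model.__init__(self, device)                            # base class constructor first...
        GaussianMixin.__init__(self, False, role="policy")      # ...then every mixin constructor,
        DeterministicMixin.__init__(self, False, role="value")  # each with its own role

print([c.__name__ for c in SharedModel.__mro__])
# ['SharedModel', 'GaussianMixin', 'DeterministicMixin', 'Model', 'object']
```

Because the mixins precede Model in the MRO, an unqualified method lookup (e.g.: a mixin-provided .act(...)) resolves to a mixin first, which is why the override shown below is needed to pick the right one per role.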
Warning
The implementation described for the single forward-pass requires that the value-pass always follows the policy-pass (e.g.: PPO), which may not generalize to other algorithms. If this requirement is not met, other forms of caching the shared layers/network output could be implemented.
Single forward-pass implementation (the shared layers/network output computed during the policy-pass is cached and reused during the value-pass):

import torch
import torch.nn as nn

from skrl.models.torch import Model, GaussianMixin, DeterministicMixin


# define the shared model
class SharedModel(GaussianMixin, DeterministicMixin, Model):
    def __init__(self, observation_space, action_space, device, clip_actions=False,
                 clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
        Model.__init__(self, observation_space, action_space, device)
        GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction, role="policy")
        DeterministicMixin.__init__(self, clip_actions, role="value")

        # cached shared output (set during the policy-pass, consumed during the value-pass)
        self._shared_output = None

        # shared layers/network
        self.net = nn.Sequential(nn.Linear(self.num_observations, 32),
                                 nn.ELU(),
                                 nn.Linear(32, 32),
                                 nn.ELU())

        # separated layers ("policy")
        self.mean_layer = nn.Linear(32, self.num_actions)
        self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))

        # separated layer ("value")
        self.value_layer = nn.Linear(32, 1)

    # override the .act(...) method to disambiguate its call
    def act(self, inputs, role):
        if role == "policy":
            return GaussianMixin.act(self, inputs, role)
        elif role == "value":
            return DeterministicMixin.act(self, inputs, role)

    # forward the input to compute the model output according to the specified role
    def compute(self, inputs, role):
        if role == "policy":
            # save the shared layers/network output to perform a single forward-pass
            self._shared_output = self.net(inputs["states"])
            return self.mean_layer(self._shared_output), self.log_std_parameter, {}
        elif role == "value":
            # use the saved shared layers/network output if available; otherwise, compute it
            shared_output = self.net(inputs["states"]) if self._shared_output is None else self._shared_output
            self._shared_output = None  # reset the saved shared output to prevent the use of stale data in subsequent steps
            return self.value_layer(shared_output), {}


# instantiate the shared model and pass the same instance to the other key
models = {}
models["policy"] = SharedModel(env.observation_space, env.action_space, env.device)
models["value"] = models["policy"]
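As the warning notes, the caching scheme above assumes the value-pass always follows the policy-pass. A pure-Python sketch of one alternative (hypothetical) scheme keys the cache on the input object identity instead, so the passes may arrive in any order:

```python
class CachedTrunk:
    # toy shared trunk: recomputes only when it sees a new input object
    def __init__(self):
        self._last_input = None
        self._last_output = None
        self.calls = 0  # counts actual forward computations

    def __call__(self, states):
        if states is not self._last_input:  # identity check, not equality
            self.calls += 1
            self._last_output = [s * 2.0 for s in states]  # stand-in computation
            self._last_input = states
        return self._last_output

trunk = CachedTrunk()
obs = [1.0, 2.0]
a = trunk(obs)  # computed
b = trunk(obs)  # cached: same object, regardless of which role asked first
print(trunk.calls)  # 1
```

Note that identity-based caching holds gradients alive until the next input arrives; in a real model, the reset-after-use strategy of the snippet above may still be preferable when the pass order is known.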
The same shared model can also be defined performing a forward-pass of the shared layers/network for each role (simpler, but computationally less efficient):

import torch
import torch.nn as nn

from skrl.models.torch import Model, GaussianMixin, DeterministicMixin


# define the shared model
class SharedModel(GaussianMixin, DeterministicMixin, Model):
    def __init__(self, observation_space, action_space, device, clip_actions=False,
                 clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
        Model.__init__(self, observation_space, action_space, device)
        GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction, role="policy")
        DeterministicMixin.__init__(self, clip_actions, role="value")

        # shared layers/network
        self.net = nn.Sequential(nn.Linear(self.num_observations, 32),
                                 nn.ELU(),
                                 nn.Linear(32, 32),
                                 nn.ELU())

        # separated layers ("policy")
        self.mean_layer = nn.Linear(32, self.num_actions)
        self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))

        # separated layer ("value")
        self.value_layer = nn.Linear(32, 1)

    # override the .act(...) method to disambiguate its call
    def act(self, inputs, role):
        if role == "policy":
            return GaussianMixin.act(self, inputs, role)
        elif role == "value":
            return DeterministicMixin.act(self, inputs, role)

    # forward the input to compute the model output according to the specified role
    def compute(self, inputs, role):
        if role == "policy":
            return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
        elif role == "value":
            return self.value_layer(self.net(inputs["states"])), {}


# instantiate the shared model and pass the same instance to the other key
models = {}
models["policy"] = SharedModel(env.observation_space, env.action_space, env.device)
models["value"] = models["policy"]
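The efficiency difference between the two variants can be checked with a call counter standing in for the shared trunk. This is a pure-Python sketch with hypothetical names, simulating one agent step in which the value-pass follows the policy-pass (as in PPO):

```python
def trunk_calls_per_step(single_pass):
    calls = {"trunk": 0}

    def trunk(states):
        calls["trunk"] += 1  # count actual forward computations of the shared layers
        return [s + 1.0 for s in states]

    cached = {"out": None}

    def compute(states, role):
        if role == "policy":
            out = trunk(states)
            if single_pass:
                cached["out"] = out  # cache for the upcoming value-pass
            return out
        # value role: reuse the cached output if available, otherwise recompute
        out = cached["out"] if cached["out"] is not None else trunk(states)
        cached["out"] = None
        return out

    # one agent step: policy-pass followed by value-pass
    states = [0.5]
    compute(states, "policy")
    compute(states, "value")
    return calls["trunk"]

print(trunk_calls_per_step(single_pass=True), trunk_calls_per_step(single_pass=False))  # 1 2
```

With deep shared layers (e.g.: image encoders), halving the trunk evaluations per step is the main payoff of the first variant.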