Multi-Categorical model
Multi-Categorical models run discrete-domain stochastic policies.
skrl provides a Python mixin (MultiCategoricalMixin) to assist in the creation of these types of models,
allowing users to have full control over the function approximator definitions and architectures.
Note that the use of this mixin must comply with the following rules:
The definition of multiple inheritance must always include the Model base class at the end.
The Model base class constructor must be invoked before the mixin's constructor.
Note
For models in JAX/Flax it is imperative to define all parameters (except observation_space,
state_space, action_space and device) with default values to avoid errors during initialization
(TypeError: __init__() missing N required positional argument).
In addition, it is necessary to initialize the model’s state_dict (via the init_state_dict method) after
its instantiation to avoid errors during its use (AttributeError: object has no attribute "state_dict".
If "state_dict" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply').
class MultiCategoricalModel(MultiCategoricalMixin, Model):
    """Skeleton of a Multi-Categorical model (PyTorch).

    The ``Model`` base class is listed last in the inheritance and its
    constructor is invoked before the mixin's constructor.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        unnormalized_log_prob=True,
        reduction="sum",
    ):
        # arguments consumed by the Model base class
        base_kwargs = {
            "observation_space": observation_space,
            "state_space": state_space,
            "action_space": action_space,
            "device": device,
        }
        # arguments consumed by the mixin
        mixin_kwargs = {
            "unnormalized_log_prob": unnormalized_log_prob,
            "reduction": reduction,
        }
        # base class first, mixin second
        Model.__init__(self, **base_kwargs)
        MultiCategoricalMixin.__init__(self, **mixin_kwargs)
class MultiCategoricalModel(MultiCategoricalMixin, Model):
    """Skeleton of a Multi-Categorical model (JAX/Flax).

    Every parameter other than the spaces and the device has a default value,
    and any extra keyword arguments are forwarded to the ``Model`` base class.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        unnormalized_log_prob=True,
        reduction="sum",
        **kwargs,
    ):
        # arguments consumed by the Model base class (plus forwarded extras)
        base_kwargs = {
            "observation_space": observation_space,
            "state_space": state_space,
            "action_space": action_space,
            "device": device,
        }
        # arguments consumed by the mixin
        mixin_kwargs = {
            "unnormalized_log_prob": unnormalized_log_prob,
            "reduction": reduction,
        }
        # base class first, mixin second
        Model.__init__(self, **base_kwargs, **kwargs)
        MultiCategoricalMixin.__init__(self, **mixin_kwargs)
Concept
Usage
Multi-Layer Perceptron (MLP)
Convolutional Neural Network (CNN)
Recurrent Neural Network (RNN)
Gated Recurrent Unit RNN (GRU)
Long Short-Term Memory RNN (LSTM)
import torch
import torch.nn as nn
from skrl.models.torch import Model, MultiCategoricalMixin
# define the model
class MLP(MultiCategoricalMixin, Model):
    """Multi-Categorical MLP policy (PyTorch, ``nn.Sequential`` style).

    Observations go through two hidden layers (64 and 32 units, ReLU) and a
    final linear layer with ``self.num_actions`` outputs.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        unnormalized_log_prob=True,
        reduction="sum",
    ):
        base_kwargs = {
            "observation_space": observation_space,
            "state_space": state_space,
            "action_space": action_space,
            "device": device,
        }
        # the Model constructor must run before the mixin's constructor
        Model.__init__(self, **base_kwargs)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob=unnormalized_log_prob, reduction=reduction)

        layers = [
            nn.Linear(self.num_observations, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, self.num_actions),
        ]
        self.net = nn.Sequential(*layers)

    def compute(self, inputs, role):
        """Return the network output for ``inputs["observations"]`` and an empty dict."""
        net_output = self.net(inputs["observations"])
        return net_output, {}


# build the policy (assumes a wrapped environment named `env`)
policy = MLP(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    unnormalized_log_prob=True,
    reduction="sum",
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from skrl.models.torch import Model, MultiCategoricalMixin
# define the model
class MLP(MultiCategoricalMixin, Model):
    """Multi-Categorical MLP policy (PyTorch, functional style).

    Layers are stored as individual attributes (``fc1``, ``fc2``, ``logits``)
    and the ReLU activations are applied in :meth:`compute`.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        unnormalized_log_prob=True,
        reduction="sum",
    ):
        base_kwargs = {
            "observation_space": observation_space,
            "state_space": state_space,
            "action_space": action_space,
            "device": device,
        }
        # the Model constructor must run before the mixin's constructor
        Model.__init__(self, **base_kwargs)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob=unnormalized_log_prob, reduction=reduction)

        self.fc1 = nn.Linear(self.num_observations, 64)
        self.fc2 = nn.Linear(64, 32)
        self.logits = nn.Linear(32, self.num_actions)

    def compute(self, inputs, role):
        """Return the output of the ``logits`` layer and an empty dict."""
        hidden = F.relu(self.fc1(inputs["observations"]))
        hidden = F.relu(self.fc2(hidden))
        return self.logits(hidden), {}


# build the policy (assumes a wrapped environment named `env`)
policy = MLP(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    unnormalized_log_prob=True,
    reduction="sum",
)
import flax.linen as nn
from skrl.models.jax import Model, MultiCategoricalMixin
# define the model
class MLP(MultiCategoricalMixin, Model):
    """Multi-Categorical MLP policy (JAX/Flax, ``setup`` style).

    Submodules are declared in :meth:`setup` and applied in :meth:`__call__`.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        unnormalized_log_prob=True,
        reduction="sum",
        **kwargs,
    ):
        base_kwargs = {
            "observation_space": observation_space,
            "state_space": state_space,
            "action_space": action_space,
            "device": device,
        }
        # the Model constructor must run before the mixin's constructor
        Model.__init__(self, **base_kwargs, **kwargs)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob=unnormalized_log_prob, reduction=reduction)

    def setup(self):
        """Declare the dense layers (64 -> 32 -> ``self.num_actions`` features)."""
        self.fc1 = nn.Dense(64)
        self.fc2 = nn.Dense(32)
        self.fc3 = nn.Dense(self.num_actions)

    def __call__(self, inputs, role):
        """Return the output of the last dense layer and an empty dict."""
        hidden = nn.relu(self.fc1(inputs["observations"]))
        hidden = nn.relu(self.fc2(hidden))
        return self.fc3(hidden), {}


# build the policy (assumes a wrapped environment named `env`)
policy = MLP(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    unnormalized_log_prob=True,
    reduction="sum",
)
# the state dict must be initialized after instantiation
policy.init_state_dict(role="policy")
import flax.linen as nn
from skrl.models.jax import Model, MultiCategoricalMixin
# define the model
class MLP(MultiCategoricalMixin, Model):
    """Multi-Categorical MLP policy (JAX/Flax, ``nn.compact`` style).

    Submodules are created inline inside :meth:`__call__`.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        unnormalized_log_prob=True,
        reduction="sum",
        **kwargs,
    ):
        base_kwargs = {
            "observation_space": observation_space,
            "state_space": state_space,
            "action_space": action_space,
            "device": device,
        }
        # the Model constructor must run before the mixin's constructor
        Model.__init__(self, **base_kwargs, **kwargs)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob=unnormalized_log_prob, reduction=reduction)

    @nn.compact  # lets the method define its submodules inline
    def __call__(self, inputs, role):
        """Return the output of the last dense layer and an empty dict."""
        hidden = inputs["observations"]
        # hidden layers are created in order: 64 units, then 32 units
        for features in (64, 32):
            hidden = nn.relu(nn.Dense(features)(hidden))
        return nn.Dense(self.num_actions)(hidden), {}


# build the policy (assumes a wrapped environment named `env`)
policy = MLP(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    unnormalized_log_prob=True,
    reduction="sum",
)
# the state dict must be initialized after instantiation
policy.init_state_dict(role="policy")
import torch
import torch.nn as nn
from skrl.models.torch import Model, MultiCategoricalMixin
# define the model
class CNN(MultiCategoricalMixin, Model):
    """Multi-Categorical CNN policy (PyTorch, ``nn.Sequential`` style).

    Three convolutions over a 3-channel input are followed by a stack of
    linear layers ending in ``self.num_actions`` outputs. The first linear
    layer is hard-coded to 1024 flattened features, so the convolutional
    stack must produce exactly that many.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        unnormalized_log_prob=True,
        reduction="sum",
    ):
        base_kwargs = {
            "observation_space": observation_space,
            "state_space": state_space,
            "action_space": action_space,
            "device": device,
        }
        # the Model constructor must run before the mixin's constructor
        Model.__init__(self, **base_kwargs)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob=unnormalized_log_prob, reduction=reduction)

        feature_extractor = [
            nn.Conv2d(3, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        head = [
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 16),
            nn.Tanh(),
            nn.Linear(16, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, self.num_actions),
        ]
        self.net = nn.Sequential(*feature_extractor, *head)

    def compute(self, inputs, role):
        """Return the network output for the reshaped observations and an empty dict."""
        # restore the observation-space shape, then move the last axis (channels) to axis 1:
        # (samples, width * height * channels) -> (samples, channels, width, height)
        images = inputs["observations"].view(-1, *self.observation_space.shape)
        channels_first = images.permute(0, 3, 1, 2)
        return self.net(channels_first), {}


# build the policy (assumes a wrapped environment named `env`)
policy = CNN(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    unnormalized_log_prob=True,
    reduction="sum",
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from skrl.models.torch import Model, MultiCategoricalMixin
# define the model
class CNN(MultiCategoricalMixin, Model):
    """Multi-Categorical CNN policy (PyTorch, functional style).

    Layers are stored as individual attributes and the activations are
    applied in :meth:`compute`. ``fc1`` is hard-coded to 1024 input features,
    so the convolutional stack must flatten to exactly that many.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        unnormalized_log_prob=True,
        reduction="sum",
    ):
        base_kwargs = {
            "observation_space": observation_space,
            "state_space": state_space,
            "action_space": action_space,
            "device": device,
        }
        # the Model constructor must run before the mixin's constructor
        Model.__init__(self, **base_kwargs)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob=unnormalized_log_prob, reduction=reduction)

        # convolutional feature extractor (3 input channels)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        # fully connected head
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 16)
        self.fc3 = nn.Linear(16, 64)
        self.fc4 = nn.Linear(64, 32)
        self.fc5 = nn.Linear(32, self.num_actions)

    def compute(self, inputs, role):
        """Return the output of ``fc5`` and an empty dict."""
        # restore the observation-space shape, then move the last axis (channels) to axis 1:
        # (samples, width * height * channels) -> (samples, channels, width, height)
        images = inputs["observations"].view(-1, *self.observation_space.shape)
        features = images.permute(0, 3, 1, 2)

        # convolutions with ReLU
        features = F.relu(self.conv1(features))
        features = F.relu(self.conv2(features))
        features = F.relu(self.conv3(features))
        features = torch.flatten(features, start_dim=1)

        # linear head: ReLU after fc1, tanh after fc2..fc4, no activation on the output
        features = F.relu(self.fc1(features))
        features = torch.tanh(self.fc2(features))
        features = torch.tanh(self.fc3(features))
        features = torch.tanh(self.fc4(features))
        return self.fc5(features), {}


# build the policy (assumes a wrapped environment named `env`)
policy = CNN(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    unnormalized_log_prob=True,
    reduction="sum",
)
import flax.linen as nn
from skrl.models.jax import Model, MultiCategoricalMixin
# define the model
class CNN(MultiCategoricalMixin, Model):
    """Multi-Categorical CNN policy (JAX/Flax, ``setup`` style).

    Submodules are declared in :meth:`setup` and applied in :meth:`__call__`.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        unnormalized_log_prob=True,
        reduction="sum",
        **kwargs,
    ):
        base_kwargs = {
            "observation_space": observation_space,
            "state_space": state_space,
            "action_space": action_space,
            "device": device,
        }
        # the Model constructor must run before the mixin's constructor
        Model.__init__(self, **base_kwargs, **kwargs)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob=unnormalized_log_prob, reduction=reduction)

    def setup(self):
        """Declare the convolutional and dense layers."""
        # convolutional feature extractor (no padding)
        self.conv1 = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")
        self.conv2 = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")
        self.conv3 = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")

        # fully connected head
        self.fc1 = nn.Dense(512)
        self.fc2 = nn.Dense(16)
        self.fc3 = nn.Dense(64)
        self.fc4 = nn.Dense(32)
        self.fc5 = nn.Dense(self.num_actions)

    def __call__(self, inputs, role):
        """Return the output of ``fc5`` and an empty dict."""
        # restore the observation-space shape for every sample
        features = inputs["observations"].reshape((-1, *self.observation_space.shape))

        # convolutions with ReLU, then flatten everything but the batch axis
        features = nn.relu(self.conv1(features))
        features = nn.relu(self.conv2(features))
        features = nn.relu(self.conv3(features))
        features = features.reshape((features.shape[0], -1))

        # dense head: ReLU after fc1, tanh after fc2..fc4, no activation on the output
        features = nn.relu(self.fc1(features))
        features = nn.tanh(self.fc2(features))
        features = nn.tanh(self.fc3(features))
        features = nn.tanh(self.fc4(features))
        return self.fc5(features), {}


# build the policy (assumes a wrapped environment named `env`)
policy = CNN(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    unnormalized_log_prob=True,
    reduction="sum",
)
# the state dict must be initialized after instantiation
policy.init_state_dict(role="policy")
import flax.linen as nn
from skrl.models.jax import Model, MultiCategoricalMixin
# define the model
class CNN(MultiCategoricalMixin, Model):
    """Multi-Categorical CNN policy (JAX/Flax, ``nn.compact`` style).

    Submodules are created inline inside :meth:`__call__`.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        unnormalized_log_prob=True,
        reduction="sum",
        **kwargs,
    ):
        base_kwargs = {
            "observation_space": observation_space,
            "state_space": state_space,
            "action_space": action_space,
            "device": device,
        }
        # the Model constructor must run before the mixin's constructor
        Model.__init__(self, **base_kwargs, **kwargs)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob=unnormalized_log_prob, reduction=reduction)

    @nn.compact  # lets the method define its submodules inline
    def __call__(self, inputs, role):
        """Return the output of the last dense layer and an empty dict."""
        # restore the observation-space shape for every sample
        features = inputs["observations"].reshape((-1, *self.observation_space.shape))

        # (features, kernel, stride) of each convolution, in creation order
        conv_specs = ((32, 8, 4), (64, 4, 2), (64, 3, 1))
        for out_features, kernel, stride in conv_specs:
            conv = nn.Conv(out_features, kernel_size=(kernel, kernel), strides=(stride, stride), padding="VALID")
            features = nn.relu(conv(features))

        # flatten everything but the batch axis
        features = features.reshape((features.shape[0], -1))

        # dense head: ReLU after the first layer, tanh after the next three
        features = nn.relu(nn.Dense(512)(features))
        for width in (16, 64, 32):
            features = nn.tanh(nn.Dense(width)(features))
        return nn.Dense(self.num_actions)(features), {}


# build the policy (assumes a wrapped environment named `env`)
policy = CNN(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    unnormalized_log_prob=True,
    reduction="sum",
)
# the state dict must be initialized after instantiation
policy.init_state_dict(role="policy")
where:
The following points are relevant in the definition of recurrent models:
- The .get_specification() method must be overwritten to return a dictionary with key "rnn". Under such key, a sub-dictionary must contain the following items:
  - The sequence length (under sub-key "sequence_length").
  - A list of the dimensions for each initial hidden/cell state (under sub-key "sizes").
- The .compute() method's inputs parameter may include the following items:
  - "observations": observations of the environment.
  - "states": state of the environment.
  - "taken_actions": actions taken by the policy for the given observations and/or states, if applicable.
  - "terminated": episode termination status for sampled environment transitions. This key is only defined during the training process.
  - "rnn": list of initial hidden states ordered according to the model specification.
- The .compute() method should include, under the "rnn" key of the returned dictionary, a list of each final hidden/cell state (when applicable).
import torch
import torch.nn as nn
from skrl.models.torch import Model, MultiCategoricalMixin
# define the model
class RNN(MultiCategoricalMixin, Model):
    """Multi-Categorical recurrent policy (PyTorch ``nn.RNN`` + ``nn.Sequential`` head).

    The recurrent layer consumes the observations and its output is fed to an
    MLP head (64 and 32 units, ReLU) with ``self.num_actions`` outputs.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        unnormalized_log_prob=True,
        reduction="sum",
        num_envs=1,
        num_layers=1,
        hidden_size=64,
        sequence_length=10,
    ):
        base_kwargs = {
            "observation_space": observation_space,
            "state_space": state_space,
            "action_space": action_space,
            "device": device,
        }
        # the Model constructor must run before the mixin's constructor
        Model.__init__(self, **base_kwargs)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob=unnormalized_log_prob, reduction=reduction)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hout
        self.sequence_length = sequence_length

        # batch_first=True -> tensors are (batch, sequence, features)
        self.rnn = nn.RNN(
            input_size=self.num_observations,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
        )

        head_layers = [
            nn.Linear(self.hidden_size, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, self.num_actions),
        ]
        self.net = nn.Sequential(*head_layers)

    def get_specification(self):
        """Return the sequence length and the shape of the initial hidden state."""
        # (D * num_layers, N, Hout), where the batch size N is the number of envs during rollout
        hidden_shape = (self.num_layers, self.num_envs, self.hidden_size)
        rnn_spec = {"sequence_length": self.sequence_length, "sizes": [hidden_shape]}
        return {"rnn": rnn_spec}

    def compute(self, inputs, role):
        """Run the recurrent layer and the head.

        Returns the head output and a dict whose ``"rnn"`` key holds the final
        hidden state.
        """
        observations = inputs["observations"]
        terminated = inputs.get("terminated", None)
        hidden_states = inputs["rnn"][0]

        if not self.training:
            # rollout: one step per environment -> (N, L, Hin) with N=num_envs, L=1
            rnn_input = observations.view(-1, 1, observations.shape[-1])
            rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
        else:
            # training: (N, L, Hin) with N=batch_size, L=sequence_length
            rnn_input = observations.view(-1, self.sequence_length, observations.shape[-1])
            # (D * num_layers, N, L, Hout)
            hidden_states = hidden_states.view(
                self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]
            )
            # keep only the state at the start of each sequence -> (D * num_layers, N, Hout)
            hidden_states = hidden_states[:, :, 0, :].contiguous()

            if terminated is None or not torch.any(terminated):
                # no termination: process each sequence in one pass
                rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
            else:
                # some episode terminated: split the sequence after every step where
                # any sample terminated, and zero the state of the terminated samples
                terminated = terminated.view(-1, self.sequence_length)
                cut_points = (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist()
                boundaries = [0, *cut_points, self.sequence_length]

                chunks = []
                for start, stop in zip(boundaries[:-1], boundaries[1:]):
                    chunk, hidden_states = self.rnn(rnn_input[:, start:stop, :], hidden_states)
                    hidden_states[:, terminated[:, stop - 1], :] = 0
                    chunks.append(chunk)
                rnn_output = torch.cat(chunks, dim=1)

        # (N, L, D * Hout) -> (N * L, D * Hout)
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)
        return self.net(rnn_output), {"rnn": [hidden_states]}


# build the policy (assumes a wrapped environment named `env`)
policy = RNN(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    unnormalized_log_prob=True,
    reduction="sum",
    num_envs=env.num_envs,
    num_layers=1,
    hidden_size=64,
    sequence_length=10,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from skrl.models.torch import Model, MultiCategoricalMixin
# define the model
class RNN(MultiCategoricalMixin, Model):
    """Multi-Categorical recurrent policy (PyTorch ``nn.RNN``, functional head).

    The head layers are stored as individual attributes (``fc1``, ``fc2``,
    ``logits``) and the ReLU activations are applied in :meth:`compute`.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        unnormalized_log_prob=True,
        reduction="sum",
        num_envs=1,
        num_layers=1,
        hidden_size=64,
        sequence_length=10,
    ):
        base_kwargs = {
            "observation_space": observation_space,
            "state_space": state_space,
            "action_space": action_space,
            "device": device,
        }
        # the Model constructor must run before the mixin's constructor
        Model.__init__(self, **base_kwargs)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob=unnormalized_log_prob, reduction=reduction)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hout
        self.sequence_length = sequence_length

        # batch_first=True -> tensors are (batch, sequence, features)
        self.rnn = nn.RNN(
            input_size=self.num_observations,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
        )

        self.fc1 = nn.Linear(self.hidden_size, 64)
        self.fc2 = nn.Linear(64, 32)
        self.logits = nn.Linear(32, self.num_actions)

    def get_specification(self):
        """Return the sequence length and the shape of the initial hidden state."""
        # (D * num_layers, N, Hout), where the batch size N is the number of envs during rollout
        hidden_shape = (self.num_layers, self.num_envs, self.hidden_size)
        rnn_spec = {"sequence_length": self.sequence_length, "sizes": [hidden_shape]}
        return {"rnn": rnn_spec}

    def compute(self, inputs, role):
        """Run the recurrent layer and the head.

        Returns the output of the ``logits`` layer and a dict whose ``"rnn"``
        key holds the final hidden state.
        """
        observations = inputs["observations"]
        terminated = inputs.get("terminated", None)
        hidden_states = inputs["rnn"][0]

        if not self.training:
            # rollout: one step per environment -> (N, L, Hin) with N=num_envs, L=1
            rnn_input = observations.view(-1, 1, observations.shape[-1])
            rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
        else:
            # training: (N, L, Hin) with N=batch_size, L=sequence_length
            rnn_input = observations.view(-1, self.sequence_length, observations.shape[-1])
            # (D * num_layers, N, L, Hout)
            hidden_states = hidden_states.view(
                self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]
            )
            # keep only the state at the start of each sequence -> (D * num_layers, N, Hout)
            hidden_states = hidden_states[:, :, 0, :].contiguous()

            if terminated is None or not torch.any(terminated):
                # no termination: process each sequence in one pass
                rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
            else:
                # some episode terminated: split the sequence after every step where
                # any sample terminated, and zero the state of the terminated samples
                terminated = terminated.view(-1, self.sequence_length)
                cut_points = (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist()
                boundaries = [0, *cut_points, self.sequence_length]

                chunks = []
                for start, stop in zip(boundaries[:-1], boundaries[1:]):
                    chunk, hidden_states = self.rnn(rnn_input[:, start:stop, :], hidden_states)
                    hidden_states[:, terminated[:, stop - 1], :] = 0
                    chunks.append(chunk)
                rnn_output = torch.cat(chunks, dim=1)

        # (N, L, D * Hout) -> (N * L, D * Hout)
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)

        features = F.relu(self.fc1(rnn_output))
        features = F.relu(self.fc2(features))
        return self.logits(features), {"rnn": [hidden_states]}


# build the policy (assumes a wrapped environment named `env`)
policy = RNN(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    unnormalized_log_prob=True,
    reduction="sum",
    num_envs=env.num_envs,
    num_layers=1,
    hidden_size=64,
    sequence_length=10,
)
where:
The following points are relevant in the definition of recurrent models:
- The .get_specification() method must be overwritten to return a dictionary with key "rnn". Under such key, a sub-dictionary must contain the following items:
  - The sequence length (under sub-key "sequence_length").
  - A list of the dimensions for each initial hidden/cell state (under sub-key "sizes").
- The .compute() method's inputs parameter may include the following items:
  - "observations": observations of the environment.
  - "states": state of the environment.
  - "taken_actions": actions taken by the policy for the given observations and/or states, if applicable.
  - "terminated": episode termination status for sampled environment transitions. This key is only defined during the training process.
  - "rnn": list of initial hidden states ordered according to the model specification.
- The .compute() method should include, under the "rnn" key of the returned dictionary, a list of each final hidden/cell state (when applicable).
import torch
import torch.nn as nn
from skrl.models.torch import Model, MultiCategoricalMixin
# define the model
class GRU(MultiCategoricalMixin, Model):
    """Multi-Categorical recurrent policy (PyTorch ``nn.GRU`` + ``nn.Sequential`` head).

    The recurrent layer consumes the observations and its output is fed to an
    MLP head (64 and 32 units, ReLU) with ``self.num_actions`` outputs.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        unnormalized_log_prob=True,
        reduction="sum",
        num_envs=1,
        num_layers=1,
        hidden_size=64,
        sequence_length=10,
    ):
        base_kwargs = {
            "observation_space": observation_space,
            "state_space": state_space,
            "action_space": action_space,
            "device": device,
        }
        # the Model constructor must run before the mixin's constructor
        Model.__init__(self, **base_kwargs)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob=unnormalized_log_prob, reduction=reduction)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hout
        self.sequence_length = sequence_length

        # batch_first=True -> tensors are (batch, sequence, features)
        self.gru = nn.GRU(
            input_size=self.num_observations,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
        )

        head_layers = [
            nn.Linear(self.hidden_size, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, self.num_actions),
        ]
        self.net = nn.Sequential(*head_layers)

    def get_specification(self):
        """Return the sequence length and the shape of the initial hidden state."""
        # (D * num_layers, N, Hout), where the batch size N is the number of envs during rollout
        hidden_shape = (self.num_layers, self.num_envs, self.hidden_size)
        rnn_spec = {"sequence_length": self.sequence_length, "sizes": [hidden_shape]}
        return {"rnn": rnn_spec}

    def compute(self, inputs, role):
        """Run the GRU and the head.

        Returns the head output and a dict whose ``"rnn"`` key holds the final
        hidden state.
        """
        observations = inputs["observations"]
        terminated = inputs.get("terminated", None)
        hidden_states = inputs["rnn"][0]

        if not self.training:
            # rollout: one step per environment -> (N, L, Hin) with N=num_envs, L=1
            rnn_input = observations.view(-1, 1, observations.shape[-1])
            rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
        else:
            # training: (N, L, Hin) with N=batch_size, L=sequence_length
            rnn_input = observations.view(-1, self.sequence_length, observations.shape[-1])
            # (D * num_layers, N, L, Hout)
            hidden_states = hidden_states.view(
                self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]
            )
            # keep only the state at the start of each sequence -> (D * num_layers, N, Hout)
            hidden_states = hidden_states[:, :, 0, :].contiguous()

            if terminated is None or not torch.any(terminated):
                # no termination: process each sequence in one pass
                rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
            else:
                # some episode terminated: split the sequence after every step where
                # any sample terminated, and zero the state of the terminated samples
                terminated = terminated.view(-1, self.sequence_length)
                cut_points = (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist()
                boundaries = [0, *cut_points, self.sequence_length]

                chunks = []
                for start, stop in zip(boundaries[:-1], boundaries[1:]):
                    chunk, hidden_states = self.gru(rnn_input[:, start:stop, :], hidden_states)
                    hidden_states[:, terminated[:, stop - 1], :] = 0
                    chunks.append(chunk)
                rnn_output = torch.cat(chunks, dim=1)

        # (N, L, D * Hout) -> (N * L, D * Hout)
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)
        return self.net(rnn_output), {"rnn": [hidden_states]}


# build the policy (assumes a wrapped environment named `env`)
policy = GRU(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    unnormalized_log_prob=True,
    reduction="sum",
    num_envs=env.num_envs,
    num_layers=1,
    hidden_size=64,
    sequence_length=10,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from skrl.models.torch import Model, MultiCategoricalMixin
# define the model
class GRU(MultiCategoricalMixin, Model):
    """Multi-Categorical recurrent policy (PyTorch ``nn.GRU``, functional head).

    The head layers are stored as individual attributes (``fc1``, ``fc2``,
    ``logits``) and the ReLU activations are applied in :meth:`compute`.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        unnormalized_log_prob=True,
        reduction="sum",
        num_envs=1,
        num_layers=1,
        hidden_size=64,
        sequence_length=10,
    ):
        base_kwargs = {
            "observation_space": observation_space,
            "state_space": state_space,
            "action_space": action_space,
            "device": device,
        }
        # the Model constructor must run before the mixin's constructor
        Model.__init__(self, **base_kwargs)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob=unnormalized_log_prob, reduction=reduction)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hout
        self.sequence_length = sequence_length

        # batch_first=True -> tensors are (batch, sequence, features)
        self.gru = nn.GRU(
            input_size=self.num_observations,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
        )

        self.fc1 = nn.Linear(self.hidden_size, 64)
        self.fc2 = nn.Linear(64, 32)
        self.logits = nn.Linear(32, self.num_actions)

    def get_specification(self):
        """Return the sequence length and the shape of the initial hidden state."""
        # (D * num_layers, N, Hout), where the batch size N is the number of envs during rollout
        hidden_shape = (self.num_layers, self.num_envs, self.hidden_size)
        rnn_spec = {"sequence_length": self.sequence_length, "sizes": [hidden_shape]}
        return {"rnn": rnn_spec}

    def compute(self, inputs, role):
        """Run the GRU and the head.

        Returns the output of the ``logits`` layer and a dict whose ``"rnn"``
        key holds the final hidden state.
        """
        observations = inputs["observations"]
        terminated = inputs.get("terminated", None)
        hidden_states = inputs["rnn"][0]

        if not self.training:
            # rollout: one step per environment -> (N, L, Hin) with N=num_envs, L=1
            rnn_input = observations.view(-1, 1, observations.shape[-1])
            rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
        else:
            # training: (N, L, Hin) with N=batch_size, L=sequence_length
            rnn_input = observations.view(-1, self.sequence_length, observations.shape[-1])
            # (D * num_layers, N, L, Hout)
            hidden_states = hidden_states.view(
                self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]
            )
            # keep only the state at the start of each sequence -> (D * num_layers, N, Hout)
            hidden_states = hidden_states[:, :, 0, :].contiguous()

            if terminated is None or not torch.any(terminated):
                # no termination: process each sequence in one pass
                rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
            else:
                # some episode terminated: split the sequence after every step where
                # any sample terminated, and zero the state of the terminated samples
                terminated = terminated.view(-1, self.sequence_length)
                cut_points = (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist()
                boundaries = [0, *cut_points, self.sequence_length]

                chunks = []
                for start, stop in zip(boundaries[:-1], boundaries[1:]):
                    chunk, hidden_states = self.gru(rnn_input[:, start:stop, :], hidden_states)
                    hidden_states[:, terminated[:, stop - 1], :] = 0
                    chunks.append(chunk)
                rnn_output = torch.cat(chunks, dim=1)

        # (N, L, D * Hout) -> (N * L, D * Hout)
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)

        features = F.relu(self.fc1(rnn_output))
        features = F.relu(self.fc2(features))
        return self.logits(features), {"rnn": [hidden_states]}


# build the policy (assumes a wrapped environment named `env`)
policy = GRU(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    unnormalized_log_prob=True,
    reduction="sum",
    num_envs=env.num_envs,
    num_layers=1,
    hidden_size=64,
    sequence_length=10,
)
where:
The following points are relevant in the definition of recurrent models:
- The .get_specification() method must be overwritten to return a dictionary with key "rnn". Under such key, a sub-dictionary must contain the following items:
  - The sequence length (under sub-key "sequence_length").
  - A list of the dimensions for each initial hidden/cell state (under sub-key "sizes").
- The .compute() method's inputs parameter may include the following items:
  - "observations": observations of the environment.
  - "states": state of the environment.
  - "taken_actions": actions taken by the policy for the given observations and/or states, if applicable.
  - "terminated": episode termination status for sampled environment transitions. This key is only defined during the training process.
  - "rnn": list of initial hidden states ordered according to the model specification.
- The .compute() method should include, under the "rnn" key of the returned dictionary, a list of each final hidden/cell state (when applicable).
import torch
import torch.nn as nn
from skrl.models.torch import Model, MultiCategoricalMixin
# define the model
class LSTM(MultiCategoricalMixin, Model):
    """Multi-Categorical recurrent policy (PyTorch ``nn.LSTM`` + ``nn.Sequential`` head).

    The LSTM consumes the observations and its output is fed to an MLP head
    (64 and 32 units, ReLU) with ``self.num_actions`` outputs. Both the hidden
    and the cell states are exchanged through the ``"rnn"`` entries.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        unnormalized_log_prob=True,
        reduction="sum",
        num_envs=1,
        num_layers=1,
        hidden_size=64,
        sequence_length=10,
    ):
        base_kwargs = {
            "observation_space": observation_space,
            "state_space": state_space,
            "action_space": action_space,
            "device": device,
        }
        # the Model constructor must run before the mixin's constructor
        Model.__init__(self, **base_kwargs)
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob=unnormalized_log_prob, reduction=reduction)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hcell (Hout is Hcell because proj_size = 0)
        self.sequence_length = sequence_length

        # batch_first=True -> tensors are (batch, sequence, features)
        self.lstm = nn.LSTM(
            input_size=self.num_observations,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
        )

        head_layers = [
            nn.Linear(self.hidden_size, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, self.num_actions),
        ]
        self.net = nn.Sequential(*head_layers)

    def get_specification(self):
        """Return the sequence length and the shapes of the initial hidden and cell states."""
        # batch size (N) is the number of envs during rollout
        hidden_shape = (self.num_layers, self.num_envs, self.hidden_size)  # (D * num_layers, N, Hout)
        cell_shape = (self.num_layers, self.num_envs, self.hidden_size)  # (D * num_layers, N, Hcell)
        rnn_spec = {"sequence_length": self.sequence_length, "sizes": [hidden_shape, cell_shape]}
        return {"rnn": rnn_spec}

    def compute(self, inputs, role):
        """Run the LSTM and the head.

        Returns the head output and a dict whose ``"rnn"`` key holds the final
        hidden state followed by the final cell state.
        """
        observations = inputs["observations"]
        terminated = inputs.get("terminated", None)
        hidden_states, cell_states = inputs["rnn"][0], inputs["rnn"][1]

        if not self.training:
            # rollout: one step per environment -> (N, L, Hin) with N=num_envs, L=1
            rnn_input = observations.view(-1, 1, observations.shape[-1])
            rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
        else:
            # training: (N, L, Hin) with N=batch_size, L=sequence_length
            rnn_input = observations.view(-1, self.sequence_length, observations.shape[-1])

            # (D * num_layers, N, L, Hout), reduced to the state at the start of each
            # sequence -> (D * num_layers, N, Hout)
            hidden_states = hidden_states.view(
                self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]
            )
            hidden_states = hidden_states[:, :, 0, :].contiguous()

            # (D * num_layers, N, L, Hcell), reduced the same way -> (D * num_layers, N, Hcell)
            cell_states = cell_states.view(
                self.num_layers, -1, self.sequence_length, cell_states.shape[-1]
            )
            cell_states = cell_states[:, :, 0, :].contiguous()

            if terminated is None or not torch.any(terminated):
                # no termination: process each sequence in one pass
                rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
            else:
                # some episode terminated: split the sequence after every step where
                # any sample terminated, and zero the states of the terminated samples
                terminated = terminated.view(-1, self.sequence_length)
                cut_points = (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist()
                boundaries = [0, *cut_points, self.sequence_length]

                chunks = []
                for start, stop in zip(boundaries[:-1], boundaries[1:]):
                    chunk, (hidden_states, cell_states) = self.lstm(
                        rnn_input[:, start:stop, :], (hidden_states, cell_states)
                    )
                    finished = terminated[:, stop - 1]
                    hidden_states[:, finished, :] = 0
                    cell_states[:, finished, :] = 0
                    chunks.append(chunk)

                rnn_states = (hidden_states, cell_states)
                rnn_output = torch.cat(chunks, dim=1)

        # (N, L, D * Hout) -> (N * L, D * Hout)
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)
        return self.net(rnn_output), {"rnn": [rnn_states[0], rnn_states[1]]}


# build the policy (assumes a wrapped environment named `env`)
policy = LSTM(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    unnormalized_log_prob=True,
    reduction="sum",
    num_envs=env.num_envs,
    num_layers=1,
    hidden_size=64,
    sequence_length=10,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from skrl.models.torch import Model, MultiCategoricalMixin
# define the model
class LSTM(MultiCategoricalMixin, Model):
    """Recurrent multi-categorical policy: an LSTM followed by a small MLP that outputs the logits."""

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        unnormalized_log_prob=True,
        reduction="sum",
        num_envs=1,
        num_layers=1,
        hidden_size=64,
        sequence_length=10,
    ):
        """Create the LSTM and the fully connected head.

        :param observation_space: Observation space, forwarded to the ``Model`` base class.
        :param state_space: State space, forwarded to the ``Model`` base class.
        :param action_space: Action space, forwarded to the ``Model`` base class.
        :param device: Device, forwarded to the ``Model`` base class.
        :param unnormalized_log_prob: Forwarded to ``MultiCategoricalMixin``.
        :param reduction: Forwarded to ``MultiCategoricalMixin``.
        :param num_envs: Batch size (N) of the RNN states reported by ``get_specification``.
        :param num_layers: Number of stacked LSTM layers.
        :param hidden_size: LSTM hidden size (Hcell; Hout equals Hcell because proj_size = 0).
        :param sequence_length: Length (L) of the sequences the batch is split into during training.
        """
        # the Model base class constructor must run before the mixin constructor
        Model.__init__(
            self,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            device=device,
        )
        MultiCategoricalMixin.__init__(self, unnormalized_log_prob=unnormalized_log_prob, reduction=reduction)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hcell (Hout is Hcell because proj_size = 0)
        self.sequence_length = sequence_length

        # recurrent part: inputs are laid out as (batch, sequence, features)
        self.lstm = nn.LSTM(
            input_size=self.num_observations,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
        )

        # fully connected head mapping the LSTM output to the logits
        self.fc1 = nn.Linear(self.hidden_size, 64)
        self.fc2 = nn.Linear(64, 32)
        self.logits = nn.Linear(32, self.num_actions)

    def get_specification(self):
        """Return the RNN specification: the sequence length and the sizes of the hidden and cell states."""
        # batch size (N) is the number of envs during rollout
        state_size = (self.num_layers, self.num_envs, self.hidden_size)
        return {
            "rnn": {
                "sequence_length": self.sequence_length,
                "sizes": [
                    state_size,  # hidden states (D * num_layers, N, Hout)
                    state_size,  # cell states (D * num_layers, N, Hcell)
                ],
            }
        }

    def compute(self, inputs, role):
        """Run the LSTM and the MLP head on the given observations.

        :param inputs: Dictionary with the keys ``"observations"`` and ``"rnn"`` (hidden states at
                       index 0, cell states at index 1) and, optionally, ``"terminated"``.
        :param role: Role played by the model (not used by this implementation).

        :return: The logits and a dictionary whose ``"rnn"`` entry lists the resulting hidden and cell states.
        """
        observations = inputs["observations"]
        terminated = inputs.get("terminated")
        rnn_inputs = inputs["rnn"]
        hidden_states, cell_states = rnn_inputs[0], rnn_inputs[1]
        num_features = observations.shape[-1]

        if not self.training:
            # rollout: a single step per environment
            rnn_input = observations.view(-1, 1, num_features)  # (N, L, Hin): N=num_envs, L=1
            rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
        else:
            # training: the batch is reinterpreted as sequences of fixed length
            seq_len = self.sequence_length
            rnn_input = observations.view(-1, seq_len, num_features)  # (N, L, Hin): N=batch_size, L=sequence_length

            # (D * num_layers, N, L, H): keep only the states of the first step of every sequence
            hidden_states = hidden_states.view(self.num_layers, -1, seq_len, hidden_states.shape[-1])
            cell_states = cell_states.view(self.num_layers, -1, seq_len, cell_states.shape[-1])
            hidden_states = hidden_states[:, :, 0, :].contiguous()  # (D * num_layers, N, Hout)
            cell_states = cell_states[:, :, 0, :].contiguous()  # (D * num_layers, N, Hcell)

            if terminated is not None and torch.any(terminated):
                # some sequence terminates: the RNN state must be reset in the middle of the sequence
                terminated = terminated.view(-1, seq_len)
                # cut right after every step (except the last one) at which any sequence terminated
                cut_points = (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist()
                boundaries = [0, *cut_points, seq_len]

                segment_outputs = []
                for start, stop in zip(boundaries[:-1], boundaries[1:]):
                    segment_output, (hidden_states, cell_states) = self.lstm(
                        rnn_input[:, start:stop, :], (hidden_states, cell_states)
                    )
                    # zero the states of the sequences that terminated at the last step of this segment
                    finished = terminated[:, stop - 1]
                    hidden_states[:, finished, :] = 0
                    cell_states[:, finished, :] = 0
                    segment_outputs.append(segment_output)

                rnn_states = (hidden_states, cell_states)
                rnn_output = torch.cat(segment_outputs, dim=1)
            else:
                # no termination: the whole sequence can be processed in one call
                rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))

        # flatten the RNN output: (N, L, D * Hout) -> (N * L, D * Hout)
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)

        x = F.relu(self.fc1(rnn_output))
        x = F.relu(self.fc2(x))

        return self.logits(x), {"rnn": [rnn_states[0], rnn_states[1]]}
# instantiate the model (given a wrapped environment: `env`)
# - the spaces, device and number of environments are read from the wrapped environment
# - `num_envs` is the batch dimension (N) of the RNN state sizes reported by `get_specification`
# - `sequence_length` is the length (L) of the sequences the observations are split into during training
policy = LSTM(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    unnormalized_log_prob=True,
    reduction="sum",
    num_envs=env.num_envs,
    num_layers=1,
    hidden_size=64,
    sequence_length=10,
)
API¶
PyTorch¶
MultiCategorical mixin model (stochastic model). |
- class skrl.models.torch.multicategorical.MultiCategoricalMixin(*, unnormalized_log_prob: bool = True, reduction: Literal['mean', 'sum', 'prod', 'none'] = 'sum', role: str = '')[source]¶
Bases: object
MultiCategorical mixin model (stochastic model).
- Parameters:
unnormalized_log_prob – Flag to indicate how the model’s output will be interpreted. If True, the model’s output is interpreted as unnormalized log probabilities (it can be any real number), otherwise as normalized probabilities (the output must be non-negative, finite and have a non-zero sum).
reduction – Reduction method for returning the log probability density function. If "none", the log probability density function is returned as a tensor of shape (num_samples, num_actions) instead of (num_samples, 1).
role – Role played by the model.
- Raises:
ValueError – If the reduction method is not valid.
Methods:
act(inputs, *[, role])Act stochastically in response to the observations/states of the environment.
distribution(*[, role])Get the current distribution of the model.
get_entropy(*[, role])Compute and return the entropy of the model.
- act(inputs: dict[str, Any], *, role: str = '') tuple[torch.Tensor, dict[str, Any]][source]¶
Act stochastically in response to the observations/states of the environment.
- Parameters:
inputs –
Model inputs. The most common keys are:
"observations": observation of the environment used to make the decision.
"states": state of the environment used to make the decision.
"taken_actions": actions taken by the policy for the given observations/states.
role – Role played by the model.
- Returns:
Model output. The first component is the expected action/value returned by the model. The second component is a dictionary containing the following extra output values:
"log_prob": log of the probability density function.
"net_output": network output.
- distribution(*, role: str = '') torch.distributions.Categorical[source]¶
Get the current distribution of the model.
- Parameters:
role – Role played by the model.
- Returns:
First distribution of the model.
- get_entropy(*, role: str = '') torch.Tensor[source]¶
Compute and return the entropy of the model.
- Parameters:
role – Role played by the model.
- Returns:
Entropy of the model.
JAX¶
MultiCategorical mixin model (stochastic model). |
- class skrl.models.jax.multicategorical.MultiCategoricalMixin(*, unnormalized_log_prob: bool = True, reduction: Literal['mean', 'sum', 'prod', 'none'] = 'sum', role: str = '')[source]¶
Bases: object
MultiCategorical mixin model (stochastic model).
- Parameters:
unnormalized_log_prob – Flag to indicate how the model’s output will be interpreted. If True, the model’s output is interpreted as unnormalized log probabilities (it can be any real number), otherwise as normalized probabilities (the output must be non-negative, finite and have a non-zero sum).
reduction – Reduction method for returning the log probability density function. If "none", the log probability density function is returned as a tensor of shape (num_samples, num_actions) instead of (num_samples, 1).
role – Role played by the model.
- Raises:
ValueError – If the reduction method is not valid.
Methods:
act(inputs, *[, role, params])Act stochastically in response to the observations/states of the environment.
get_entropy(stddev, *[, role])Compute and return the entropy of the model.
- act(inputs: dict[str, Any], *, role: str = '', params: jax.Array | None = None) tuple[jax.Array, dict[str, Any]][source]¶
Act stochastically in response to the observations/states of the environment.
- Parameters:
inputs –
Model inputs. The most common keys are:
"observations": observation of the environment used to make the decision.
"states": state of the environment used to make the decision.
"taken_actions": actions taken by the policy for the given observations/states.
role – Role played by the model.
params – Parameters used to compute the output. If not provided, internal parameters will be used.
- Returns:
Model output. The first component is the expected action/value returned by the model. The second component is a dictionary containing the following extra output values:
"log_prob": log of the probability density function.
"net_output": network output.