Deterministic model
Deterministic models run continuous-domain deterministic policies.
skrl provides a Python mixin (DeterministicMixin) to assist in the creation of these types of models,
allowing users to have full control over the function approximator definitions and architectures.
Note that the use of this mixin must comply with the following rules:
When defining the class through multiple inheritance, the Model base class must always be listed last.
The Model base class constructor must be invoked before the mixin's constructor.
Note
For models in JAX/Flax it is imperative to define all parameters (except observation_space,
state_space, action_space and device) with default values to avoid errors during initialization
(TypeError: __init__() missing N required positional argument(s)).
In addition, it is necessary to initialize the model's state_dict (via the init_state_dict method) after
its instantiation to avoid errors during its use (AttributeError: object has no attribute "state_dict").
If "state_dict" is defined in '.setup()', remember that such fields are only accessible from inside 'init' or 'apply'.
class DeterministicModel(DeterministicMixin, Model):
    """Skeleton of a deterministic model (PyTorch).

    ``Model`` is listed last among the base classes, and its constructor runs
    before the ``DeterministicMixin`` constructor, as the mixin rules require.
    """

    def __init__(self, observation_space, state_space, action_space, device, clip_actions=False):
        spaces = {
            "observation_space": observation_space,
            "state_space": state_space,
            "action_space": action_space,
        }
        # base class first, then the mixin
        Model.__init__(self, device=device, **spaces)
        DeterministicMixin.__init__(self, clip_actions=clip_actions)
class DeterministicModel(DeterministicMixin, Model):
    """Skeleton of a deterministic model (JAX/Flax).

    Any extra keyword arguments are forwarded unchanged to the ``Model``
    constructor, which runs before the ``DeterministicMixin`` constructor.
    """

    def __init__(self, observation_space, state_space, action_space, device, clip_actions=False, **kwargs):
        # base class first, then the mixin
        Model.__init__(
            self,
            device=device,
            action_space=action_space,
            state_space=state_space,
            observation_space=observation_space,
            **kwargs,
        )
        DeterministicMixin.__init__(self, clip_actions=clip_actions)
Concept
Usage
Multi-Layer Perceptron (MLP)
Convolutional Neural Network (CNN)
Recurrent Neural Network (RNN)
Gated Recurrent Unit RNN (GRU)
Long Short-Term Memory RNN (LSTM)
import torch
import torch.nn as nn
from skrl.models.torch import Model, DeterministicMixin
# define the model
class MLP(DeterministicMixin, Model):
    """Deterministic critic built as a single ``nn.Sequential`` (PyTorch).

    Observations and taken actions are concatenated along dimension 1 and
    mapped to one output value per sample.
    """

    def __init__(self, observation_space, state_space, action_space, device, clip_actions=False):
        # the Model base class constructor must run before the mixin's constructor
        Model.__init__(
            self,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            device=device,
        )
        DeterministicMixin.__init__(self, clip_actions=clip_actions)

        in_features = self.num_observations + self.num_actions
        layers = [
            nn.Linear(in_features, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        ]
        self.net = nn.Sequential(*layers)

    def compute(self, inputs, role):
        # assumes 2-D (samples, features) inputs, since they are joined along dim=1
        observation_action = torch.cat([inputs["observations"], inputs["taken_actions"]], dim=1)
        value = self.net(observation_action)
        return value, {}


# instantiate the model (given a wrapped environment: `env`)
critic = MLP(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    clip_actions=False,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from skrl.models.torch import Model, DeterministicMixin
# define the model
class MLP(DeterministicMixin, Model):
    """Deterministic critic written with explicit layers and functional activations (PyTorch).

    Observations and taken actions are concatenated along dimension 1, passed
    through two ReLU hidden layers (64 and 32 units), and mapped to one output
    value per sample.
    """

    def __init__(self, observation_space, state_space, action_space, device, clip_actions=False):
        # the Model base class constructor must run before the mixin's constructor
        Model.__init__(
            self,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            device=device,
        )
        DeterministicMixin.__init__(self, clip_actions=clip_actions)

        self.fc1 = nn.Linear(self.num_observations + self.num_actions, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)

    def compute(self, inputs, role):
        hidden = torch.cat([inputs["observations"], inputs["taken_actions"]], dim=1)
        # hidden layers share the same ReLU activation
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden), {}


# instantiate the model (given a wrapped environment: `env`)
critic = MLP(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    clip_actions=False,
)
import jax.numpy as jnp
import flax.linen as nn
from skrl.models.jax import Model, DeterministicMixin
# define the model
class MLP(DeterministicMixin, Model):
    """Deterministic critic whose submodules are declared in ``setup`` (JAX/Flax).

    Observations and taken actions are concatenated along the last axis,
    passed through two ReLU hidden layers (64 and 32 units), and mapped to one
    output value.
    """

    def __init__(self, observation_space, state_space, action_space, device, clip_actions=False, **kwargs):
        # base class first, then the mixin; extra keyword arguments go to Model
        Model.__init__(
            self,
            device=device,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            **kwargs,
        )
        DeterministicMixin.__init__(self, clip_actions=clip_actions)

    def setup(self):
        self.fc1 = nn.Dense(64)
        self.fc2 = nn.Dense(32)
        self.fc3 = nn.Dense(1)

    def __call__(self, inputs, role):
        hidden = jnp.concatenate([inputs["observations"], inputs["taken_actions"]], axis=-1)
        for dense in (self.fc1, self.fc2):
            hidden = nn.relu(dense(hidden))
        return self.fc3(hidden), {}


# instantiate the model (given a wrapped environment: `env`)
critic = MLP(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    clip_actions=False,
)

# initialize model's state dict
critic.init_state_dict(role="critic")
import jax.numpy as jnp
import flax.linen as nn
from skrl.models.jax import Model, DeterministicMixin
# define the model
class MLP(DeterministicMixin, Model):
    """Deterministic critic whose submodules are declared inline (JAX/Flax ``@nn.compact``).

    Observations and taken actions are concatenated along the last axis,
    passed through two ReLU hidden layers (64 and 32 units), and mapped to one
    output value. The Dense layers are created in the same order on every call.
    """

    def __init__(self, observation_space, state_space, action_space, device, clip_actions=False, **kwargs):
        # base class first, then the mixin; extra keyword arguments go to Model
        Model.__init__(
            self,
            device=device,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            **kwargs,
        )
        DeterministicMixin.__init__(self, clip_actions=clip_actions)

    @nn.compact  # marks the given module method allowing inlined submodules
    def __call__(self, inputs, role):
        hidden = jnp.concatenate([inputs["observations"], inputs["taken_actions"]], axis=-1)
        for units in (64, 32):
            hidden = nn.relu(nn.Dense(units)(hidden))
        return nn.Dense(1)(hidden), {}


# instantiate the model (given a wrapped environment: `env`)
critic = MLP(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    clip_actions=False,
)

# initialize model's state dict
critic.init_state_dict(role="critic")
import torch
import torch.nn as nn
from skrl.models.torch import Model, DeterministicMixin
# define the model
class CNN(DeterministicMixin, Model):
    """Deterministic critic for image observations (PyTorch, ``nn.Sequential`` style).

    A convolutional feature extractor embeds each image observation into 16
    features. These are concatenated with the taken actions and mapped by an
    MLP head to one output value per sample.

    The observation space is assumed to describe an image laid out as
    (height, width, channels); the ``permute`` in ``compute`` already relies on
    this. The number of input channels and the size of the flattened
    convolutional features are derived from that shape instead of being
    hard-coded. For a 64x64x3 image the layers are exactly ``Conv2d(3, ...)``
    and ``Linear(1024, 512)``, with unchanged state_dict keys.
    """

    def __init__(self, observation_space, state_space, action_space, device, clip_actions=False):
        # the Model base class constructor must run before the mixin's constructor
        Model.__init__(
            self,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            device=device,
        )
        DeterministicMixin.__init__(self, clip_actions=clip_actions)

        height, width, channels = self.observation_space.shape
        convolutions = [
            nn.Conv2d(channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        # infer the flattened feature size with a dummy forward pass (1024 for a 64x64 image)
        with torch.no_grad():
            num_features = nn.Sequential(*convolutions)(torch.zeros(1, channels, height, width)).shape[1]
        # same module layout (and therefore the same state_dict keys) as a single hand-written nn.Sequential
        self.features_extractor = nn.Sequential(
            *convolutions,
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Linear(512, 16),
            nn.Tanh(),
        )

        self.net = nn.Sequential(
            nn.Linear(16 + self.num_actions, 64), nn.Tanh(), nn.Linear(64, 32), nn.Tanh(), nn.Linear(32, 1)
        )

    def compute(self, inputs, role):
        # reshape to (samples, height, width, channels), then permute to the channels-first
        # layout expected by nn.Conv2d: (samples, channels, height, width)
        x = self.features_extractor(inputs["observations"].view(-1, *self.observation_space.shape).permute(0, 3, 1, 2))
        return self.net(torch.cat([x, inputs["taken_actions"]], dim=1)), {}


# instantiate the model (given a wrapped environment: `env`)
critic = CNN(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    clip_actions=False,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from skrl.models.torch import Model, DeterministicMixin
# define the model
class CNN(DeterministicMixin, Model):
    """Deterministic critic for image observations (PyTorch, explicit layers).

    Three ReLU convolutions and two dense layers embed each image observation
    into 16 features. These are concatenated with the taken actions and mapped
    by a tanh MLP head to one output value per sample.

    The observation space is assumed to describe an image laid out as
    (height, width, channels); the ``permute`` in ``compute`` already relies on
    this. The number of input channels and the input size of ``fc1`` are
    derived from that shape instead of being hard-coded. For a 64x64x3 image
    they are exactly 3 and 1024.
    """

    def __init__(self, observation_space, state_space, action_space, device, clip_actions=False):
        # the Model base class constructor must run before the mixin's constructor
        Model.__init__(
            self,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            device=device,
        )
        DeterministicMixin.__init__(self, clip_actions=clip_actions)

        height, width, channels = self.observation_space.shape
        self.conv1 = nn.Conv2d(channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        # infer the flattened feature size with a dummy forward pass (1024 for a 64x64 image);
        # ReLU does not change shapes, so only the convolutions are needed here
        with torch.no_grad():
            dummy = torch.zeros(1, channels, height, width)
            num_features = torch.flatten(self.conv3(self.conv2(self.conv1(dummy))), start_dim=1).shape[1]
        self.fc1 = nn.Linear(num_features, 512)
        self.fc2 = nn.Linear(512, 16)
        self.fc3 = nn.Linear(16 + self.num_actions, 64)
        self.fc4 = nn.Linear(64, 32)
        self.fc5 = nn.Linear(32, 1)

    def compute(self, inputs, role):
        # reshape to (samples, height, width, channels), then permute to the channels-first
        # layout expected by nn.Conv2d: (samples, channels, height, width)
        x = inputs["observations"].view(-1, *self.observation_space.shape).permute(0, 3, 1, 2)
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = torch.tanh(x)
        # append the taken actions to the 16 image features
        x = self.fc3(torch.cat([x, inputs["taken_actions"]], dim=1))
        x = torch.tanh(x)
        x = self.fc4(x)
        x = torch.tanh(x)
        x = self.fc5(x)
        return x, {}


# instantiate the model (given a wrapped environment: `env`)
critic = CNN(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    clip_actions=False,
)
import jax.numpy as jnp
import flax.linen as nn
from skrl.models.jax import Model, DeterministicMixin
# define the model
class CNN(DeterministicMixin, Model):
    """Deterministic critic for image observations (JAX/Flax, ``setup`` style).

    Three "VALID" convolutions followed by two dense layers embed the image
    observation into 16 features. These are concatenated with the taken
    actions and mapped by a tanh MLP head to one output value.
    """

    def __init__(self, observation_space, state_space, action_space, device, clip_actions=False, **kwargs):
        # base class first, then the mixin; extra keyword arguments go to Model
        Model.__init__(
            self,
            device=device,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            **kwargs,
        )
        DeterministicMixin.__init__(self, clip_actions=clip_actions)

    def setup(self):
        # convolutional feature extractor
        self.conv1 = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")
        self.conv2 = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")
        self.conv3 = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")
        # image embedding
        self.fc1 = nn.Dense(512)
        self.fc2 = nn.Dense(16)
        # critic head (image embedding + taken actions)
        self.fc3 = nn.Dense(64)
        self.fc4 = nn.Dense(32)
        self.fc5 = nn.Dense(1)

    def __call__(self, inputs, role):
        # restore the image layout of the observations: (samples, *observation_space.shape)
        features = inputs["observations"].reshape((-1, *self.observation_space.shape))
        for conv in (self.conv1, self.conv2, self.conv3):
            features = nn.relu(conv(features))
        features = features.reshape((features.shape[0], -1))
        features = nn.relu(self.fc1(features))
        features = nn.tanh(self.fc2(features))

        hidden = jnp.concatenate([features, inputs["taken_actions"]], axis=-1)
        for dense in (self.fc3, self.fc4):
            hidden = nn.tanh(dense(hidden))
        return self.fc5(hidden), {}


# instantiate the model (given a wrapped environment: `env`)
critic = CNN(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    clip_actions=False,
)

# initialize model's state dict
critic.init_state_dict(role="critic")
import jax.numpy as jnp
import flax.linen as nn
from skrl.models.jax import Model, DeterministicMixin
# define the model
class CNN(DeterministicMixin, Model):
    """Deterministic critic for image observations (JAX/Flax ``@nn.compact`` style).

    Three "VALID" convolutions followed by two dense layers embed the image
    observation into 16 features. These are concatenated with the taken
    actions and mapped by a tanh MLP head to one output value. Each inline
    submodule is created on its own line, in the same order on every call.
    """

    def __init__(self, observation_space, state_space, action_space, device, clip_actions=False, **kwargs):
        # base class first, then the mixin; extra keyword arguments go to Model
        Model.__init__(
            self,
            device=device,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            **kwargs,
        )
        DeterministicMixin.__init__(self, clip_actions=clip_actions)

    @nn.compact  # marks the given module method allowing inlined submodules
    def __call__(self, inputs, role):
        # restore the image layout of the observations: (samples, *observation_space.shape)
        x = inputs["observations"].reshape((-1, *self.observation_space.shape))
        # (features, kernel size, strides) of each convolution, applied in order
        conv_specs = ((32, (8, 8), (4, 4)), (64, (4, 4), (2, 2)), (64, (3, 3), (1, 1)))
        for features, kernel_size, strides in conv_specs:
            x = nn.relu(nn.Conv(features, kernel_size=kernel_size, strides=strides, padding="VALID")(x))
        x = x.reshape((x.shape[0], -1))
        x = nn.relu(nn.Dense(512)(x))
        x = nn.tanh(nn.Dense(16)(x))

        x = jnp.concatenate([x, inputs["taken_actions"]], axis=-1)
        for units in (64, 32):
            x = nn.tanh(nn.Dense(units)(x))
        return nn.Dense(1)(x), {}


# instantiate the model (given a wrapped environment: `env`)
critic = CNN(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    clip_actions=False,
)

# initialize model's state dict
critic.init_state_dict(role="critic")
where:
The following points are relevant in the definition of recurrent models:
- The `.get_specification()` method must be overwritten to return a dictionary with the key `"rnn"`. Under that key, a sub-dictionary must contain the following items:
  - The sequence length (under sub-key `"sequence_length"`).
  - A list of the dimensions for each initial hidden/cell state (under sub-key `"sizes"`).
- The `.compute()` method's `inputs` parameter may include the following items:
  - `"observations"`: observations of the environment.
  - `"states"`: state of the environment.
  - `"taken_actions"`: actions taken by the policy for the given observations and/or states, if applicable.
  - `"terminated"`: episode termination status for sampled environment transitions. This key is only defined during the training process.
  - `"rnn"`: list of initial hidden states ordered according to the model specification.
- The `.compute()` method should include, under the `"rnn"` key of the returned dictionary, a list of each final hidden/cell state (when applicable).
import torch
import torch.nn as nn
from skrl.models.torch import Model, DeterministicMixin
# define the model
class RNN(DeterministicMixin, Model):
    """Deterministic recurrent critic: ``nn.RNN`` followed by an ``nn.Sequential`` head (PyTorch).

    The observations are reshaped into sequences of ``sequence_length`` steps
    and processed by the RNN. The flattened RNN outputs are concatenated with
    the taken actions and mapped by the head to one output value per step.

    Args:
        observation_space: forwarded to ``Model.__init__``.
        state_space: forwarded to ``Model.__init__``.
        action_space: forwarded to ``Model.__init__``.
        device: forwarded to ``Model.__init__``.
        clip_actions: forwarded to ``DeterministicMixin.__init__``.
        num_envs: batch size (N) reported for the initial hidden states by
            ``get_specification`` (the number of envs during rollout).
        num_layers: number of stacked recurrent layers.
        hidden_size: size of the recurrent hidden state (Hout).
        sequence_length: number of steps (L) per sequence.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        clip_actions=False,
        num_envs=1,
        num_layers=1,
        hidden_size=64,
        sequence_length=10,
    ):
        # the Model base class constructor must be invoked before the mixin's constructor
        Model.__init__(
            self,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            device=device,
        )
        DeterministicMixin.__init__(self, clip_actions=clip_actions)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hout
        self.sequence_length = sequence_length

        self.rnn = nn.RNN(
            input_size=self.num_observations,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,  # (batch, sequence, features)
        )

        # head input: RNN output features + taken actions
        self.net = nn.Sequential(
            nn.Linear(self.hidden_size + self.num_actions, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def get_specification(self):
        """Return the recurrent specification of the model.

        Returns:
            dict with key ``"rnn"`` holding the sequence length and the size of the
            single initial hidden-state tensor ``(num_layers, num_envs, hidden_size)``.
        """
        # batch size (N) is the number of envs during rollout
        return {
            "rnn": {
                "sequence_length": self.sequence_length,
                "sizes": [
                    (self.num_layers, self.num_envs, self.hidden_size)  # hidden states (D ∗ num_layers, N, Hout)
                ],
            }
        }

    def compute(self, inputs, role):
        """Evaluate the critic on sequences of observations and taken actions.

        Args:
            inputs: dict holding ``"observations"``, ``"taken_actions"``, ``"rnn"``
                (list whose first item is the hidden-state tensor) and, optionally,
                ``"terminated"``.
            role: model role. ``"target_critic"`` starts each sequence from the
                hidden states stored for its second step instead of its first.

        Returns:
            tuple of the head output (one value per step) and ``{"rnn": [hidden_states]}``
            with the final hidden states.
        """
        observations = inputs["observations"]
        terminated = inputs.get("terminated", None)
        hidden_states = inputs["rnn"][0]

        # critic models are only used during training
        rnn_input = observations.view(
            -1, self.sequence_length, observations.shape[-1]
        )  # (N, L, Hin): N=batch_size, L=sequence_length

        # NOTE(review): assumes one stored hidden state per sampled step (num_layers * N * L * Hout elements) — confirm
        hidden_states = hidden_states.view(
            self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]
        )  # (D * num_layers, N, L, Hout)
        # get the hidden states corresponding to the initial sequence
        sequence_index = 1 if role == "target_critic" else 0  # target networks act on the next state of the environment
        hidden_states = hidden_states[:, :, sequence_index, :].contiguous()  # (D * num_layers, N, Hout)

        # reset the RNN state in the middle of a sequence
        if terminated is not None and torch.any(terminated):
            rnn_outputs = []
            terminated = terminated.view(-1, self.sequence_length)
            # segment boundaries: 0, each step that follows a termination in any sequence, and L
            indexes = (
                [0] + (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
            )

            for i in range(len(indexes) - 1):
                i0, i1 = indexes[i], indexes[i + 1]
                rnn_output, hidden_states = self.rnn(rnn_input[:, i0:i1, :], hidden_states)
                # zero the hidden states of the sequences whose segment ends with a termination
                hidden_states[:, (terminated[:, i1 - 1]), :] = 0
                rnn_outputs.append(rnn_output)

            rnn_output = torch.cat(rnn_outputs, dim=1)
        else:
            # no need to reset the RNN state in the sequence
            rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)

        return self.net(torch.cat([rnn_output, inputs["taken_actions"]], dim=1)), {"rnn": [hidden_states]}


# instantiate the model (given a wrapped environment: `env`)
critic = RNN(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    clip_actions=False,
    num_envs=env.num_envs,
    num_layers=1,
    hidden_size=64,
    sequence_length=10,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from skrl.models.torch import Model, DeterministicMixin
# define the model
class RNN(DeterministicMixin, Model):
    """Deterministic recurrent critic: ``nn.RNN`` followed by explicit linear layers (PyTorch).

    The observations are reshaped into sequences of ``sequence_length`` steps
    and processed by the RNN. The flattened RNN outputs are concatenated with
    the taken actions and mapped by ``fc1``-``fc3`` (ReLU in between) to one
    output value per step.

    Args:
        observation_space: forwarded to ``Model.__init__``.
        state_space: forwarded to ``Model.__init__``.
        action_space: forwarded to ``Model.__init__``.
        device: forwarded to ``Model.__init__``.
        clip_actions: forwarded to ``DeterministicMixin.__init__``.
        num_envs: batch size (N) reported for the initial hidden states by
            ``get_specification`` (the number of envs during rollout).
        num_layers: number of stacked recurrent layers.
        hidden_size: size of the recurrent hidden state (Hout).
        sequence_length: number of steps (L) per sequence.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        clip_actions=False,
        num_envs=1,
        num_layers=1,
        hidden_size=64,
        sequence_length=10,
    ):
        # the Model base class constructor must be invoked before the mixin's constructor
        Model.__init__(
            self,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            device=device,
        )
        DeterministicMixin.__init__(self, clip_actions=clip_actions)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hout
        self.sequence_length = sequence_length

        self.rnn = nn.RNN(
            input_size=self.num_observations,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,  # (batch, sequence, features)
        )

        # head input: RNN output features + taken actions
        self.fc1 = nn.Linear(self.hidden_size + self.num_actions, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)

    def get_specification(self):
        """Return the recurrent specification of the model.

        Returns:
            dict with key ``"rnn"`` holding the sequence length and the size of the
            single initial hidden-state tensor ``(num_layers, num_envs, hidden_size)``.
        """
        # batch size (N) is the number of envs during rollout
        return {
            "rnn": {
                "sequence_length": self.sequence_length,
                "sizes": [
                    (self.num_layers, self.num_envs, self.hidden_size)  # hidden states (D ∗ num_layers, N, Hout)
                ],
            }
        }

    def compute(self, inputs, role):
        """Evaluate the critic on sequences of observations and taken actions.

        Args:
            inputs: dict holding ``"observations"``, ``"taken_actions"``, ``"rnn"``
                (list whose first item is the hidden-state tensor) and, optionally,
                ``"terminated"``.
            role: model role. ``"target_critic"`` starts each sequence from the
                hidden states stored for its second step instead of its first.

        Returns:
            tuple of the ``fc3`` output (one value per step) and ``{"rnn": [hidden_states]}``
            with the final hidden states.
        """
        observations = inputs["observations"]
        terminated = inputs.get("terminated", None)
        hidden_states = inputs["rnn"][0]

        # critic models are only used during training
        rnn_input = observations.view(
            -1, self.sequence_length, observations.shape[-1]
        )  # (N, L, Hin): N=batch_size, L=sequence_length

        # NOTE(review): assumes one stored hidden state per sampled step (num_layers * N * L * Hout elements) — confirm
        hidden_states = hidden_states.view(
            self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]
        )  # (D * num_layers, N, L, Hout)
        # get the hidden states corresponding to the initial sequence
        sequence_index = 1 if role == "target_critic" else 0  # target networks act on the next state of the environment
        hidden_states = hidden_states[:, :, sequence_index, :].contiguous()  # (D * num_layers, N, Hout)

        # reset the RNN state in the middle of a sequence
        if terminated is not None and torch.any(terminated):
            rnn_outputs = []
            terminated = terminated.view(-1, self.sequence_length)
            # segment boundaries: 0, each step that follows a termination in any sequence, and L
            indexes = (
                [0] + (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
            )

            for i in range(len(indexes) - 1):
                i0, i1 = indexes[i], indexes[i + 1]
                rnn_output, hidden_states = self.rnn(rnn_input[:, i0:i1, :], hidden_states)
                # zero the hidden states of the sequences whose segment ends with a termination
                hidden_states[:, (terminated[:, i1 - 1]), :] = 0
                rnn_outputs.append(rnn_output)

            rnn_output = torch.cat(rnn_outputs, dim=1)
        else:
            # no need to reset the RNN state in the sequence
            rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)

        x = self.fc1(torch.cat([rnn_output, inputs["taken_actions"]], dim=1))
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)

        return self.fc3(x), {"rnn": [hidden_states]}


# instantiate the model (given a wrapped environment: `env`)
critic = RNN(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    clip_actions=False,
    num_envs=env.num_envs,
    num_layers=1,
    hidden_size=64,
    sequence_length=10,
)
where:
The following points are relevant in the definition of recurrent models:
- The `.get_specification()` method must be overwritten to return a dictionary with the key `"rnn"`. Under that key, a sub-dictionary must contain the following items:
  - The sequence length (under sub-key `"sequence_length"`).
  - A list of the dimensions for each initial hidden/cell state (under sub-key `"sizes"`).
- The `.compute()` method's `inputs` parameter may include the following items:
  - `"observations"`: observations of the environment.
  - `"states"`: state of the environment.
  - `"taken_actions"`: actions taken by the policy for the given observations and/or states, if applicable.
  - `"terminated"`: episode termination status for sampled environment transitions. This key is only defined during the training process.
  - `"rnn"`: list of initial hidden states ordered according to the model specification.
- The `.compute()` method should include, under the `"rnn"` key of the returned dictionary, a list of each final hidden/cell state (when applicable).
import torch
import torch.nn as nn
from skrl.models.torch import Model, DeterministicMixin
# define the model
class GRU(DeterministicMixin, Model):
    """Deterministic recurrent critic: ``nn.GRU`` followed by an ``nn.Sequential`` head (PyTorch).

    The observations are reshaped into sequences of ``sequence_length`` steps
    and processed by the GRU. The flattened GRU outputs are concatenated with
    the taken actions and mapped by the head to one output value per step.

    Args:
        observation_space: forwarded to ``Model.__init__``.
        state_space: forwarded to ``Model.__init__``.
        action_space: forwarded to ``Model.__init__``.
        device: forwarded to ``Model.__init__``.
        clip_actions: forwarded to ``DeterministicMixin.__init__``.
        num_envs: batch size (N) reported for the initial hidden states by
            ``get_specification`` (the number of envs during rollout).
        num_layers: number of stacked GRU layers.
        hidden_size: size of the GRU hidden state (Hout).
        sequence_length: number of steps (L) per sequence.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        clip_actions=False,
        num_envs=1,
        num_layers=1,
        hidden_size=64,
        sequence_length=10,
    ):
        # the Model base class constructor must be invoked before the mixin's constructor
        Model.__init__(
            self,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            device=device,
        )
        DeterministicMixin.__init__(self, clip_actions=clip_actions)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hout
        self.sequence_length = sequence_length

        self.gru = nn.GRU(
            input_size=self.num_observations,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,  # (batch, sequence, features)
        )

        # head input: GRU output features + taken actions
        self.net = nn.Sequential(
            nn.Linear(self.hidden_size + self.num_actions, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def get_specification(self):
        """Return the recurrent specification of the model.

        Returns:
            dict with key ``"rnn"`` holding the sequence length and the size of the
            single initial hidden-state tensor ``(num_layers, num_envs, hidden_size)``.
        """
        # batch size (N) is the number of envs during rollout
        return {
            "rnn": {
                "sequence_length": self.sequence_length,
                "sizes": [
                    (self.num_layers, self.num_envs, self.hidden_size)  # hidden states (D ∗ num_layers, N, Hout)
                ],
            }
        }

    def compute(self, inputs, role):
        """Evaluate the critic on sequences of observations and taken actions.

        Args:
            inputs: dict holding ``"observations"``, ``"taken_actions"``, ``"rnn"``
                (list whose first item is the hidden-state tensor) and, optionally,
                ``"terminated"``.
            role: model role. ``"target_critic"`` starts each sequence from the
                hidden states stored for its second step instead of its first.

        Returns:
            tuple of the head output (one value per step) and ``{"rnn": [hidden_states]}``
            with the final hidden states.
        """
        observations = inputs["observations"]
        terminated = inputs.get("terminated", None)
        hidden_states = inputs["rnn"][0]

        # critic models are only used during training
        rnn_input = observations.view(
            -1, self.sequence_length, observations.shape[-1]
        )  # (N, L, Hin): N=batch_size, L=sequence_length

        # NOTE(review): assumes one stored hidden state per sampled step (num_layers * N * L * Hout elements) — confirm
        hidden_states = hidden_states.view(
            self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]
        )  # (D * num_layers, N, L, Hout)
        # get the hidden states corresponding to the initial sequence
        sequence_index = 1 if role == "target_critic" else 0  # target networks act on the next state of the environment
        hidden_states = hidden_states[:, :, sequence_index, :].contiguous()  # (D * num_layers, N, Hout)

        # reset the RNN state in the middle of a sequence
        if terminated is not None and torch.any(terminated):
            rnn_outputs = []
            terminated = terminated.view(-1, self.sequence_length)
            # segment boundaries: 0, each step that follows a termination in any sequence, and L
            indexes = (
                [0] + (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
            )

            for i in range(len(indexes) - 1):
                i0, i1 = indexes[i], indexes[i + 1]
                rnn_output, hidden_states = self.gru(rnn_input[:, i0:i1, :], hidden_states)
                # zero the hidden states of the sequences whose segment ends with a termination
                hidden_states[:, (terminated[:, i1 - 1]), :] = 0
                rnn_outputs.append(rnn_output)

            rnn_output = torch.cat(rnn_outputs, dim=1)
        else:
            # no need to reset the RNN state in the sequence
            rnn_output, hidden_states = self.gru(rnn_input, hidden_states)

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)

        return self.net(torch.cat([rnn_output, inputs["taken_actions"]], dim=1)), {"rnn": [hidden_states]}


# instantiate the model (given a wrapped environment: `env`)
critic = GRU(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    clip_actions=False,
    num_envs=env.num_envs,
    num_layers=1,
    hidden_size=64,
    sequence_length=10,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from skrl.models.torch import Model, DeterministicMixin
# define the model
class GRU(DeterministicMixin, Model):
    """Deterministic recurrent critic: ``nn.GRU`` followed by explicit linear layers (PyTorch).

    The observations are reshaped into sequences of ``sequence_length`` steps
    and processed by the GRU. The flattened GRU outputs are concatenated with
    the taken actions and mapped by ``fc1``-``fc3`` (ReLU in between) to one
    output value per step.

    Args:
        observation_space: forwarded to ``Model.__init__``.
        state_space: forwarded to ``Model.__init__``.
        action_space: forwarded to ``Model.__init__``.
        device: forwarded to ``Model.__init__``.
        clip_actions: forwarded to ``DeterministicMixin.__init__``.
        num_envs: batch size (N) reported for the initial hidden states by
            ``get_specification`` (the number of envs during rollout).
        num_layers: number of stacked GRU layers.
        hidden_size: size of the GRU hidden state (Hout).
        sequence_length: number of steps (L) per sequence.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        clip_actions=False,
        num_envs=1,
        num_layers=1,
        hidden_size=64,
        sequence_length=10,
    ):
        # the Model base class constructor must be invoked before the mixin's constructor
        Model.__init__(
            self,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            device=device,
        )
        DeterministicMixin.__init__(self, clip_actions=clip_actions)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hout
        self.sequence_length = sequence_length

        self.gru = nn.GRU(
            input_size=self.num_observations,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,  # (batch, sequence, features)
        )

        # head input: GRU output features + taken actions
        self.fc1 = nn.Linear(self.hidden_size + self.num_actions, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)

    def get_specification(self):
        """Return the recurrent specification of the model.

        Returns:
            dict with key ``"rnn"`` holding the sequence length and the size of the
            single initial hidden-state tensor ``(num_layers, num_envs, hidden_size)``.
        """
        # batch size (N) is the number of envs during rollout
        return {
            "rnn": {
                "sequence_length": self.sequence_length,
                "sizes": [
                    (self.num_layers, self.num_envs, self.hidden_size)  # hidden states (D ∗ num_layers, N, Hout)
                ],
            }
        }

    def compute(self, inputs, role):
        """Evaluate the critic on sequences of observations and taken actions.

        Args:
            inputs: dict holding ``"observations"``, ``"taken_actions"``, ``"rnn"``
                (list whose first item is the hidden-state tensor) and, optionally,
                ``"terminated"``.
            role: model role. ``"target_critic"`` starts each sequence from the
                hidden states stored for its second step instead of its first.

        Returns:
            tuple of the ``fc3`` output (one value per step) and ``{"rnn": [hidden_states]}``
            with the final hidden states.
        """
        observations = inputs["observations"]
        terminated = inputs.get("terminated", None)
        hidden_states = inputs["rnn"][0]

        # critic models are only used during training
        rnn_input = observations.view(
            -1, self.sequence_length, observations.shape[-1]
        )  # (N, L, Hin): N=batch_size, L=sequence_length

        # NOTE(review): assumes one stored hidden state per sampled step (num_layers * N * L * Hout elements) — confirm
        hidden_states = hidden_states.view(
            self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]
        )  # (D * num_layers, N, L, Hout)
        # get the hidden states corresponding to the initial sequence
        sequence_index = 1 if role == "target_critic" else 0  # target networks act on the next state of the environment
        hidden_states = hidden_states[:, :, sequence_index, :].contiguous()  # (D * num_layers, N, Hout)

        # reset the RNN state in the middle of a sequence
        if terminated is not None and torch.any(terminated):
            rnn_outputs = []
            terminated = terminated.view(-1, self.sequence_length)
            # segment boundaries: 0, each step that follows a termination in any sequence, and L
            indexes = (
                [0] + (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
            )

            for i in range(len(indexes) - 1):
                i0, i1 = indexes[i], indexes[i + 1]
                rnn_output, hidden_states = self.gru(rnn_input[:, i0:i1, :], hidden_states)
                # zero the hidden states of the sequences whose segment ends with a termination
                hidden_states[:, (terminated[:, i1 - 1]), :] = 0
                rnn_outputs.append(rnn_output)

            rnn_output = torch.cat(rnn_outputs, dim=1)
        else:
            # no need to reset the RNN state in the sequence
            rnn_output, hidden_states = self.gru(rnn_input, hidden_states)

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)

        x = self.fc1(torch.cat([rnn_output, inputs["taken_actions"]], dim=1))
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)

        return self.fc3(x), {"rnn": [hidden_states]}


# instantiate the model (given a wrapped environment: `env`)
critic = GRU(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    clip_actions=False,
    num_envs=env.num_envs,
    num_layers=1,
    hidden_size=64,
    sequence_length=10,
)
where:
The following points are relevant in the definition of recurrent models:
- The `.get_specification()` method must be overwritten to return a dictionary with the key `"rnn"`. Under that key, a sub-dictionary must contain the following items:
  - The sequence length (under sub-key `"sequence_length"`).
  - A list of the dimensions for each initial hidden/cell state (under sub-key `"sizes"`).
- The `.compute()` method's `inputs` parameter may include the following items:
  - `"observations"`: observations of the environment.
  - `"states"`: state of the environment.
  - `"taken_actions"`: actions taken by the policy for the given observations and/or states, if applicable.
  - `"terminated"`: episode termination status for sampled environment transitions. This key is only defined during the training process.
  - `"rnn"`: list of initial hidden states ordered according to the model specification.
- The `.compute()` method should include, under the `"rnn"` key of the returned dictionary, a list of each final hidden/cell state (when applicable).
import torch
import torch.nn as nn
from skrl.models.torch import Model, DeterministicMixin
# define the model
class LSTM(DeterministicMixin, Model):
    """Deterministic recurrent critic: ``nn.LSTM`` followed by an ``nn.Sequential`` head (PyTorch).

    The observations are reshaped into sequences of ``sequence_length`` steps
    and processed by the LSTM. The flattened LSTM outputs are concatenated with
    the taken actions and mapped by the head to one output value per step.

    Args:
        observation_space: forwarded to ``Model.__init__``.
        state_space: forwarded to ``Model.__init__``.
        action_space: forwarded to ``Model.__init__``.
        device: forwarded to ``Model.__init__``.
        clip_actions: forwarded to ``DeterministicMixin.__init__``.
        num_envs: batch size (N) reported for the initial hidden/cell states by
            ``get_specification`` (the number of envs during rollout).
        num_layers: number of stacked LSTM layers.
        hidden_size: size of the LSTM hidden and cell states (Hout = Hcell).
        sequence_length: number of steps (L) per sequence.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        clip_actions=False,
        num_envs=1,
        num_layers=1,
        hidden_size=64,
        sequence_length=10,
    ):
        # the Model base class constructor must be invoked before the mixin's constructor
        Model.__init__(
            self,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            device=device,
        )
        DeterministicMixin.__init__(self, clip_actions=clip_actions)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hcell (Hout is Hcell because proj_size = 0)
        self.sequence_length = sequence_length

        self.lstm = nn.LSTM(
            input_size=self.num_observations,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,  # (batch, sequence, features)
        )

        # head input: LSTM output features + taken actions
        self.net = nn.Sequential(
            nn.Linear(self.hidden_size + self.num_actions, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def get_specification(self):
        """Return the recurrent specification of the model.

        Returns:
            dict with key ``"rnn"`` holding the sequence length and the sizes of the
            two initial state tensors (hidden states first, then cell states), each
            ``(num_layers, num_envs, hidden_size)``.
        """
        # batch size (N) is the number of envs during rollout
        return {
            "rnn": {
                "sequence_length": self.sequence_length,
                "sizes": [
                    (self.num_layers, self.num_envs, self.hidden_size),  # hidden states (D ∗ num_layers, N, Hout)
                    (self.num_layers, self.num_envs, self.hidden_size),  # cell states   (D ∗ num_layers, N, Hcell)
                ],
            }
        }

    def compute(self, inputs, role):
        """Evaluate the critic on sequences of observations and taken actions.

        Args:
            inputs: dict holding ``"observations"``, ``"taken_actions"``, ``"rnn"``
                (list of the hidden-state and cell-state tensors, in that order) and,
                optionally, ``"terminated"``.
            role: model role. ``"target_critic"`` starts each sequence from the
                states stored for its second step instead of its first.

        Returns:
            tuple of the head output (one value per step) and
            ``{"rnn": [hidden_states, cell_states]}`` with the final states.
        """
        observations = inputs["observations"]
        terminated = inputs.get("terminated", None)
        hidden_states, cell_states = inputs["rnn"][0], inputs["rnn"][1]

        # critic models are only used during training
        rnn_input = observations.view(
            -1, self.sequence_length, observations.shape[-1]
        )  # (N, L, Hin): N=batch_size, L=sequence_length

        # NOTE(review): assumes one stored hidden/cell state per sampled step (num_layers * N * L * H elements) — confirm
        hidden_states = hidden_states.view(
            self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]
        )  # (D * num_layers, N, L, Hout)
        cell_states = cell_states.view(
            self.num_layers, -1, self.sequence_length, cell_states.shape[-1]
        )  # (D * num_layers, N, L, Hcell)
        # get the hidden/cell states corresponding to the initial sequence
        sequence_index = 1 if role == "target_critic" else 0  # target networks act on the next state of the environment
        hidden_states = hidden_states[:, :, sequence_index, :].contiguous()  # (D * num_layers, N, Hout)
        cell_states = cell_states[:, :, sequence_index, :].contiguous()  # (D * num_layers, N, Hcell)

        # reset the RNN state in the middle of a sequence
        if terminated is not None and torch.any(terminated):
            rnn_outputs = []
            terminated = terminated.view(-1, self.sequence_length)
            # segment boundaries: 0, each step that follows a termination in any sequence, and L
            indexes = (
                [0] + (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
            )

            for i in range(len(indexes) - 1):
                i0, i1 = indexes[i], indexes[i + 1]
                rnn_output, (hidden_states, cell_states) = self.lstm(
                    rnn_input[:, i0:i1, :], (hidden_states, cell_states)
                )
                # zero the hidden/cell states of the sequences whose segment ends with a termination
                hidden_states[:, (terminated[:, i1 - 1]), :] = 0
                cell_states[:, (terminated[:, i1 - 1]), :] = 0
                rnn_outputs.append(rnn_output)

            rnn_states = (hidden_states, cell_states)
            rnn_output = torch.cat(rnn_outputs, dim=1)
        else:
            # no need to reset the RNN state in the sequence
            rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)

        return self.net(torch.cat([rnn_output, inputs["taken_actions"]], dim=1)), {
            "rnn": [rnn_states[0], rnn_states[1]]
        }


# instantiate the model (given a wrapped environment: `env`)
critic = LSTM(
    observation_space=env.observation_space,
    state_space=env.state_space,
    action_space=env.action_space,
    device=env.device,
    clip_actions=False,
    num_envs=env.num_envs,
    num_layers=1,
    hidden_size=64,
    sequence_length=10,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from skrl.models.torch import Model, DeterministicMixin
# define the model
class LSTM(DeterministicMixin, Model):
    """Deterministic critic built on an LSTM.

    The observation sequence is encoded by an ``nn.LSTM``. Its flattened per-step output is
    concatenated with the taken actions and passed through a 3-layer MLP that produces one
    value per step.
    """

    def __init__(
        self,
        observation_space,
        state_space,
        action_space,
        device,
        clip_actions=False,
        num_envs=1,
        num_layers=1,
        hidden_size=64,
        sequence_length=10,
    ):
        # the Model base class must be initialized before the mixin
        Model.__init__(
            self,
            observation_space=observation_space,
            state_space=state_space,
            action_space=action_space,
            device=device,
        )
        DeterministicMixin.__init__(self, clip_actions=clip_actions)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hcell (equal to Hout since proj_size = 0)
        self.sequence_length = sequence_length

        # submodules are created in this exact order: it fixes parameter init order and state_dict keys
        self.lstm = nn.LSTM(
            input_size=self.num_observations,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,  # tensors are laid out as (batch, sequence, features)
        )
        self.fc1 = nn.Linear(self.hidden_size + self.num_actions, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)

    def get_specification(self):
        """Describe the recurrent state this model expects.

        Returns a dict with the sequence length and the shapes of the hidden and cell
        states, each ``(num_layers, num_envs, hidden_size)``.
        """
        # during rollout the batch dimension (N) equals the number of environments
        state_shape = (self.num_layers, self.num_envs, self.hidden_size)
        return {
            "rnn": {
                "sequence_length": self.sequence_length,
                # [hidden states (D * num_layers, N, Hout), cell states (D * num_layers, N, Hcell)]
                "sizes": [state_shape, state_shape],
            }
        }

    def compute(self, inputs, role):
        """Evaluate the critic on batches of full sequences.

        :param inputs: dict with ``"observations"``, ``"taken_actions"``, ``"rnn"`` (hidden and
            cell states) and optionally ``"terminated"``.
        :param role: model role; ``"target_critic"`` starts from the states of step 1.
        :return: ``(values, {"rnn": [hidden_states, cell_states]})``
        """
        obs = inputs["observations"]
        terminated = inputs.get("terminated", None)
        h, c = inputs["rnn"][0], inputs["rnn"][1]
        seq_len = self.sequence_length

        # critic models are only used during training
        rnn_input = obs.view(-1, seq_len, obs.shape[-1])  # (N, L, Hin)
        h = h.view(self.num_layers, -1, seq_len, h.shape[-1])  # (D * num_layers, N, L, Hout)
        c = c.view(self.num_layers, -1, seq_len, c.shape[-1])  # (D * num_layers, N, L, Hcell)

        # target networks act on the next state of the environment, so they start one step later
        start = 0 if role != "target_critic" else 1
        h = h[:, :, start, :].contiguous()  # (D * num_layers, N, Hout)
        c = c[:, :, start, :].contiguous()  # (D * num_layers, N, Hcell)

        if terminated is None or not torch.any(terminated):
            # no need to reset the RNN state in the sequence
            rnn_output, rnn_states = self.lstm(rnn_input, (h, c))
        else:
            # reset the RNN state in the middle of a sequence: split the sequence right after
            # every step where any environment terminated, and run the LSTM chunk by chunk
            terminated = terminated.view(-1, seq_len)
            cut_points = (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist()
            bounds = [0, *cut_points, seq_len]
            chunks = []
            for i0, i1 in zip(bounds[:-1], bounds[1:]):
                chunk, (h, c) = self.lstm(rnn_input[:, i0:i1, :], (h, c))
                done = terminated[:, i1 - 1]
                # zero the carried state of environments that terminated at the chunk's last step
                h[:, done, :] = 0
                c[:, done, :] = 0
                chunks.append(chunk)
            rnn_states = (h, c)
            rnn_output = torch.cat(chunks, dim=1)

        # (N, L, D * Hout) -> (N * L, D * Hout)
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)

        hidden = F.relu(self.fc1(torch.cat([rnn_output, inputs["taken_actions"]], dim=1)))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden), {"rnn": [rnn_states[0], rnn_states[1]]}
# instantiate the model (given a wrapped environment: `env`)
# the spaces and device are passed positionally, in the order of LSTM.__init__;
# the remaining settings are given by keyword
critic = LSTM(
    env.observation_space,
    env.state_space,
    env.action_space,
    env.device,
    clip_actions=False,
    num_envs=env.num_envs,  # batch size (N) of the RNN states during rollout
    num_layers=1,
    hidden_size=64,
    sequence_length=10,
)
API¶
PyTorch¶
Deterministic mixin model (deterministic model).
- class skrl.models.torch.deterministic.DeterministicMixin(*, clip_actions: bool = False, role: str = '')[source]¶
Bases: object
Deterministic mixin model (deterministic model).
- Parameters:
clip_actions – Flag to indicate whether the actions should be clipped to the action space.
role – Role played by the model.
Methods:
act(inputs, *[, role]) — Act deterministically in response to the observations/states of the environment.
- act(inputs: dict[str, Any], *, role: str = '') → tuple[torch.Tensor, dict[str, Any]][source]¶
Act deterministically in response to the observations/states of the environment.
- Parameters:
inputs –
Model inputs. The most common keys are:
"observations": observation of the environment used to make the decision.
"states": state of the environment used to make the decision.
"taken_actions": actions taken by the policy for the given observations/states.
role – Role played by the model.
- Returns:
Model output. The first component is the expected action/value returned by the model. The second component is a dictionary containing extra output values according to the model.
JAX¶
Deterministic mixin model (deterministic model).
- class skrl.models.jax.deterministic.DeterministicMixin(*, clip_actions: bool = False, role: str = '')[source]¶
Bases: object
Deterministic mixin model (deterministic model).
- Parameters:
clip_actions – Flag to indicate whether the actions should be clipped to the action space.
role – Role played by the model.
Methods:
act(inputs, *[, role, params]) — Act deterministically in response to the observations/states of the environment.
- act(inputs: dict[str, Any], *, role: str = '', params: jax.Array | None = None) → tuple[jax.Array, dict[str, Any]][source]¶
Act deterministically in response to the observations/states of the environment.
- Parameters:
inputs –
Model inputs. The most common keys are:
"observations": observation of the environment used to make the decision.
"states": state of the environment used to make the decision.
"taken_actions": actions taken by the policy for the given observations/states.
role – Role played by the model.
params – Parameters used to compute the output. If not provided, internal parameters will be used.
- Returns:
Model output. The first component is the expected action/value returned by the model. The second component is a dictionary containing extra output values according to the model.
Warp¶
Deterministic mixin model (deterministic model).