Gaussian model¶
Gaussian models run continuous-domain stochastic policies.
skrl provides a Python mixin (GaussianMixin) to assist in the creation of these types of models, allowing users to have full control over the function approximator definitions and architectures. Note that the use of this mixin must comply with the following rules:
- The definition of multiple inheritance must always include the Model base class at the end.
- The Model base class constructor must be invoked before the mixin's constructor.
Warning
For models in JAX/Flax it is imperative to define all parameters (except observation_space, action_space and device) with default values to avoid errors (TypeError: __init__() missing N required positional arguments) during initialization.
In addition, it is necessary to initialize the model's state_dict (via the init_state_dict method) after its instantiation to avoid errors (AttributeError: object has no attribute "state_dict". If "state_dict" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply') during its use.
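The following snippets show this pattern for PyTorch and JAX/Flax models, respectively. Note how the JAX/Flax variant gives default values to all the extra parameters and forwards **kwargs to the Model base class: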
# PyTorch
class GaussianModel(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device=None,
clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
# JAX (Flax)
class GaussianModel(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device=None,
clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
Usage¶
Implementation examples are provided below for the following network architectures:
- Multi-Layer Perceptron (MLP)
- Convolutional Neural Network (CNN)
- Recurrent Neural Network (RNN)
- Gated Recurrent Unit RNN (GRU)
- Long Short-Term Memory RNN (LSTM)
import torch
import torch.nn as nn
from skrl.models.torch import Model, GaussianMixin
# define the model (MLP, using nn.Sequential)
class MLP(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device,
clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
self.net = nn.Sequential(nn.Linear(self.num_observations, 64),
nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, self.num_actions),
nn.Tanh())
self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
def compute(self, inputs, role):
return self.net(inputs["states"]), self.log_std_parameter, {}
# instantiate the model (assumes there is a wrapped environment: env)
policy = MLP(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=True,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
reduction="sum")
import torch
import torch.nn as nn
import torch.nn.functional as F
from skrl.models.torch import Model, GaussianMixin
# define the model (MLP, using the functional API)
class MLP(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device,
clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
self.fc1 = nn.Linear(self.num_observations, 64)
self.fc2 = nn.Linear(64, 32)
self.fc3 = nn.Linear(32, self.num_actions)
self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
def compute(self, inputs, role):
x = self.fc1(inputs["states"])
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
return torch.tanh(x), self.log_std_parameter, {}
# instantiate the model (assumes there is a wrapped environment: env)
policy = MLP(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=True,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
reduction="sum")
import jax.numpy as jnp
import flax.linen as nn
from skrl.models.jax import Model, GaussianMixin
# define the model (MLP, JAX/Flax, setup-style submodules)
class MLP(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device=None,
clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
def setup(self):
self.fc1 = nn.Dense(64)
self.fc2 = nn.Dense(32)
self.fc3 = nn.Dense(self.num_actions)
self.log_std_parameter = self.param("log_std_parameter", lambda _: jnp.zeros(self.num_actions))
def __call__(self, inputs, role):
x = self.fc1(inputs["states"])
x = nn.relu(x)
x = self.fc2(x)
x = nn.relu(x)
x = self.fc3(x)
return nn.tanh(x), self.log_std_parameter, {}
# instantiate the model (assumes there is a wrapped environment: env)
policy = MLP(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=True,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
reduction="sum")
# initialize model's state dict
policy.init_state_dict("policy")
import jax.numpy as jnp
import flax.linen as nn
from skrl.models.jax import Model, GaussianMixin
# define the model (MLP, JAX/Flax, single @nn.compact method)
class MLP(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device=None,
clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.Dense(64)(inputs["states"])
x = nn.relu(x)
x = nn.Dense(32)(x)
x = nn.relu(x)
x = nn.Dense(self.num_actions)(x)
log_std_parameter = self.param("log_std_parameter", lambda _: jnp.zeros(self.num_actions))
return nn.tanh(x), log_std_parameter, {}
# instantiate the model (assumes there is a wrapped environment: env)
policy = MLP(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=True,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
reduction="sum")
# initialize model's state dict
policy.init_state_dict("policy")
import torch
import torch.nn as nn
from skrl.models.torch import Model, GaussianMixin
# define the model (CNN, using nn.Sequential)
class CNN(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device,
clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
self.net = nn.Sequential(nn.Conv2d(3, 32, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1),
nn.ReLU(),
nn.Flatten(),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Linear(512, 16),
nn.Tanh(),
nn.Linear(16, 64),
nn.Tanh(),
nn.Linear(64, 32),
nn.Tanh(),
nn.Linear(32, self.num_actions))
self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
def compute(self, inputs, role):
# permute (samples, width * height * channels) -> (samples, channels, width, height)
return self.net(inputs["states"].view(-1, *self.observation_space.shape).permute(0, 3, 1, 2)), self.log_std_parameter, {}
# instantiate the model (assumes there is a wrapped environment: env)
policy = CNN(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=True,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
reduction="sum")
import torch
import torch.nn as nn
import torch.nn.functional as F
from skrl.models.torch import Model, GaussianMixin
# define the model (CNN, using the functional API)
class CNN(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device,
clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
self.conv1 = nn.Conv2d(3, 32, kernel_size=8, stride=4)
self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 16)
self.fc3 = nn.Linear(16, 64)
self.fc4 = nn.Linear(64, 32)
self.fc5 = nn.Linear(32, self.num_actions)
self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
def compute(self, inputs, role):
# permute (samples, width * height * channels) -> (samples, channels, width, height)
x = inputs["states"].view(-1, *self.observation_space.shape).permute(0, 3, 1, 2)
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = self.conv3(x)
x = F.relu(x)
x = torch.flatten(x, start_dim=1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = torch.tanh(x)
x = self.fc3(x)
x = torch.tanh(x)
x = self.fc4(x)
x = torch.tanh(x)
x = self.fc5(x)
return x, self.log_std_parameter, {}
# instantiate the model (assumes there is a wrapped environment: env)
policy = CNN(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=True,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
reduction="sum")
import jax.numpy as jnp
import flax.linen as nn
from skrl.models.jax import Model, GaussianMixin
# define the model (CNN, JAX/Flax, setup-style submodules)
class CNN(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device=None,
clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
def setup(self):
self.conv1 = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")
self.conv2 = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")
self.conv3 = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")
self.fc1 = nn.Dense(512)
self.fc2 = nn.Dense(16)
self.fc3 = nn.Dense(64)
self.fc4 = nn.Dense(32)
self.fc5 = nn.Dense(self.num_actions)
self.log_std_parameter = self.param("log_std_parameter", lambda _: jnp.zeros(self.num_actions))
def __call__(self, inputs, role):
x = inputs["states"].reshape((-1, *self.observation_space.shape))
x = self.conv1(x)
x = nn.relu(x)
x = self.conv2(x)
x = nn.relu(x)
x = self.conv3(x)
x = nn.relu(x)
x = x.reshape((x.shape[0], -1))
x = self.fc1(x)
x = nn.relu(x)
x = self.fc2(x)
x = nn.tanh(x)
x = self.fc3(x)
x = nn.tanh(x)
x = self.fc4(x)
x = nn.tanh(x)
x = self.fc5(x)
return nn.tanh(x), self.log_std_parameter, {}
# instantiate the model (assumes there is a wrapped environment: env)
policy = CNN(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=True,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
reduction="sum")
# initialize model's state dict
policy.init_state_dict("policy")
import jax.numpy as jnp
import flax.linen as nn
from skrl.models.jax import Model, GaussianMixin
# define the model (CNN, JAX/Flax, single @nn.compact method)
class CNN(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device=None,
clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = inputs["states"].reshape((-1, *self.observation_space.shape))
x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x)
x = nn.relu(x)
x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x)
x = nn.relu(x)
x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x)
x = nn.relu(x)
x = x.reshape((x.shape[0], -1))
x = nn.Dense(512)(x)
x = nn.relu(x)
x = nn.Dense(16)(x)
x = nn.tanh(x)
x = nn.Dense(64)(x)
x = nn.tanh(x)
x = nn.Dense(32)(x)
x = nn.tanh(x)
x = nn.Dense(self.num_actions)(x)
log_std_parameter = self.param("log_std_parameter", lambda _: jnp.zeros(self.num_actions))
return nn.tanh(x), log_std_parameter, {}
# instantiate the model (assumes there is a wrapped environment: env)
policy = CNN(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=True,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
reduction="sum")
# initialize model's state dict
policy.init_state_dict("policy")
The shape comments in the code below use the following notation:
- N: batch size (the number of environments during rollout)
- L: sequence length
- Hin: input_size (the number of observations)
- Hout: hidden_size
- D: 2 if bidirectional=True, otherwise 1
The following points are relevant in the definition of recurrent models:
- The .get_specification() method must be overwritten to return, under the dictionary key "rnn", a sub-dictionary that includes the sequence length (under the key "sequence_length") as a number and a list of the dimensions (under the key "sizes") of each initial hidden state.
- The .compute() method's inputs parameter will have, at least, the following items in the dictionary:
  - "states": state of the environment used to make the decision
  - "taken_actions": actions taken by the policy for the given states, if applicable
  - "terminated": episode termination status for sampled environment transitions (this key is only defined during the training process)
  - "rnn": list of initial hidden states ordered according to the model specification
- The .compute() method must include, under the "rnn" key of the returned dictionary, a list of each final hidden state.
import torch
import torch.nn as nn
from skrl.models.torch import Model, GaussianMixin
# define the model (RNN feeding an nn.Sequential head)
class RNN(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device, clip_actions=False,
clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum",
num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
self.num_envs = num_envs
self.num_layers = num_layers
self.hidden_size = hidden_size # Hout
self.sequence_length = sequence_length
self.rnn = nn.RNN(input_size=self.num_observations,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
batch_first=True) # batch_first -> (batch, sequence, features)
self.net = nn.Sequential(nn.Linear(self.hidden_size, 64),
nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, self.num_actions),
nn.Tanh())
self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
def get_specification(self):
# batch size (N) is the number of envs during rollout
return {"rnn": {"sequence_length": self.sequence_length,
"sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}} # hidden states (D ∗ num_layers, N, Hout)
def compute(self, inputs, role):
states = inputs["states"]
terminated = inputs.get("terminated", None)
hidden_states = inputs["rnn"][0]
# training
if self.training:
rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
# get the hidden states corresponding to the initial sequence
hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
# reset the RNN state in the middle of a sequence
if terminated is not None and torch.any(terminated):
rnn_outputs = []
terminated = terminated.view(-1, self.sequence_length)
indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
for i in range(len(indexes) - 1):
i0, i1 = indexes[i], indexes[i + 1]
rnn_output, hidden_states = self.rnn(rnn_input[:,i0:i1,:], hidden_states)
hidden_states[:, (terminated[:,i1-1]), :] = 0
rnn_outputs.append(rnn_output)
rnn_output = torch.cat(rnn_outputs, dim=1)
# no need to reset the RNN state in the sequence
else:
rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
# rollout
else:
rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
# flatten the RNN output
rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
return self.net(rnn_output), self.log_std_parameter, {"rnn": [hidden_states]}
# instantiate the model (assumes there is a wrapped environment: env)
policy = RNN(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=True,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
reduction="sum",
num_envs=env.num_envs,
num_layers=1,
hidden_size=64,
sequence_length=10)
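The sequence-splitting logic in compute() can be hard to follow. The toy example below (an illustrative sketch, not part of skrl) shows how the indexes computation cuts a training sequence at terminations so the hidden states can be zeroed between the resulting sub-sequences:

import torch

# sequence_length = 5 and 2 environments; env 0 terminates at step 1, env 1 at step 3
terminated = torch.tensor([[False, True, False, False, False],
                           [False, False, False, True, False]])
indexes = [0] + (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [5]
print(indexes)  # [0, 2, 4, 5] -> the RNN is run on sub-sequences [0:2], [2:4] and [4:5]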
import torch
import torch.nn as nn
import torch.nn.functional as F
from skrl.models.torch import Model, GaussianMixin
# define the model (RNN with a functional-API head)
class RNN(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device, clip_actions=False,
clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum",
num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
self.num_envs = num_envs
self.num_layers = num_layers
self.hidden_size = hidden_size # Hout
self.sequence_length = sequence_length
self.rnn = nn.RNN(input_size=self.num_observations,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
batch_first=True) # batch_first -> (batch, sequence, features)
self.fc1 = nn.Linear(self.hidden_size, 64)
self.fc2 = nn.Linear(64, 32)
self.fc3 = nn.Linear(32, self.num_actions)
self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
def get_specification(self):
# batch size (N) is the number of envs during rollout
return {"rnn": {"sequence_length": self.sequence_length,
"sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}} # hidden states (D ∗ num_layers, N, Hout)
def compute(self, inputs, role):
states = inputs["states"]
terminated = inputs.get("terminated", None)
hidden_states = inputs["rnn"][0]
# training
if self.training:
rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
# get the hidden states corresponding to the initial sequence
hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
# reset the RNN state in the middle of a sequence
if terminated is not None and torch.any(terminated):
rnn_outputs = []
terminated = terminated.view(-1, self.sequence_length)
indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
for i in range(len(indexes) - 1):
i0, i1 = indexes[i], indexes[i + 1]
rnn_output, hidden_states = self.rnn(rnn_input[:,i0:i1,:], hidden_states)
hidden_states[:, (terminated[:,i1-1]), :] = 0
rnn_outputs.append(rnn_output)
rnn_output = torch.cat(rnn_outputs, dim=1)
# no need to reset the RNN state in the sequence
else:
rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
# rollout
else:
rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
# flatten the RNN output
rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
x = self.fc1(rnn_output)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
return torch.tanh(x), self.log_std_parameter, {"rnn": [hidden_states]}
# instantiate the model (assumes there is a wrapped environment: env)
policy = RNN(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=True,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
reduction="sum",
num_envs=env.num_envs,
num_layers=1,
hidden_size=64,
sequence_length=10)
The shape comments in the code below use the following notation:
- N: batch size (the number of environments during rollout)
- L: sequence length
- Hin: input_size (the number of observations)
- Hout: hidden_size
- D: 2 if bidirectional=True, otherwise 1
The following points are relevant in the definition of recurrent models:
- The .get_specification() method must be overwritten to return, under the dictionary key "rnn", a sub-dictionary that includes the sequence length (under the key "sequence_length") as a number and a list of the dimensions (under the key "sizes") of each initial hidden state.
- The .compute() method's inputs parameter will have, at least, the following items in the dictionary:
  - "states": state of the environment used to make the decision
  - "taken_actions": actions taken by the policy for the given states, if applicable
  - "terminated": episode termination status for sampled environment transitions (this key is only defined during the training process)
  - "rnn": list of initial hidden states ordered according to the model specification
- The .compute() method must include, under the "rnn" key of the returned dictionary, a list of each final hidden state.
import torch
import torch.nn as nn
from skrl.models.torch import Model, GaussianMixin
# define the model (GRU feeding an nn.Sequential head)
class GRU(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device, clip_actions=False,
clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum",
num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
self.num_envs = num_envs
self.num_layers = num_layers
self.hidden_size = hidden_size # Hout
self.sequence_length = sequence_length
self.gru = nn.GRU(input_size=self.num_observations,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
batch_first=True) # batch_first -> (batch, sequence, features)
self.net = nn.Sequential(nn.Linear(self.hidden_size, 64),
nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, self.num_actions),
nn.Tanh())
self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
def get_specification(self):
# batch size (N) is the number of envs during rollout
return {"rnn": {"sequence_length": self.sequence_length,
"sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}} # hidden states (D ∗ num_layers, N, Hout)
def compute(self, inputs, role):
states = inputs["states"]
terminated = inputs.get("terminated", None)
hidden_states = inputs["rnn"][0]
# training
if self.training:
rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
# get the hidden states corresponding to the initial sequence
hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
# reset the RNN state in the middle of a sequence
if terminated is not None and torch.any(terminated):
rnn_outputs = []
terminated = terminated.view(-1, self.sequence_length)
indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
for i in range(len(indexes) - 1):
i0, i1 = indexes[i], indexes[i + 1]
rnn_output, hidden_states = self.gru(rnn_input[:,i0:i1,:], hidden_states)
hidden_states[:, (terminated[:,i1-1]), :] = 0
rnn_outputs.append(rnn_output)
rnn_output = torch.cat(rnn_outputs, dim=1)
# no need to reset the RNN state in the sequence
else:
rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
# rollout
else:
rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
# flatten the RNN output
rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
return self.net(rnn_output), self.log_std_parameter, {"rnn": [hidden_states]}
# instantiate the model (assumes there is a wrapped environment: env)
policy = GRU(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=True,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
reduction="sum",
num_envs=env.num_envs,
num_layers=1,
hidden_size=64,
sequence_length=10)
import torch
import torch.nn as nn
import torch.nn.functional as F
from skrl.models.torch import Model, GaussianMixin
# define the model (GRU with a functional-API head)
class GRU(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device, clip_actions=False,
clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum",
num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
self.num_envs = num_envs
self.num_layers = num_layers
self.hidden_size = hidden_size # Hout
self.sequence_length = sequence_length
self.gru = nn.GRU(input_size=self.num_observations,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
batch_first=True) # batch_first -> (batch, sequence, features)
self.fc1 = nn.Linear(self.hidden_size, 64)
self.fc2 = nn.Linear(64, 32)
self.fc3 = nn.Linear(32, self.num_actions)
self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
def get_specification(self):
# batch size (N) is the number of envs during rollout
return {"rnn": {"sequence_length": self.sequence_length,
"sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}} # hidden states (D ∗ num_layers, N, Hout)
def compute(self, inputs, role):
states = inputs["states"]
terminated = inputs.get("terminated", None)
hidden_states = inputs["rnn"][0]
# training
if self.training:
rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
# get the hidden states corresponding to the initial sequence
hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
# reset the RNN state in the middle of a sequence
if terminated is not None and torch.any(terminated):
rnn_outputs = []
terminated = terminated.view(-1, self.sequence_length)
indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
for i in range(len(indexes) - 1):
i0, i1 = indexes[i], indexes[i + 1]
rnn_output, hidden_states = self.gru(rnn_input[:,i0:i1,:], hidden_states)
hidden_states[:, (terminated[:,i1-1]), :] = 0
rnn_outputs.append(rnn_output)
rnn_output = torch.cat(rnn_outputs, dim=1)
# no need to reset the RNN state in the sequence
else:
rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
# rollout
else:
rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
# flatten the RNN output
rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
x = self.fc1(rnn_output)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
return torch.tanh(x), self.log_std_parameter, {"rnn": [hidden_states]}
# instantiate the model (assumes there is a wrapped environment: env)
policy = GRU(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=True,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
reduction="sum",
num_envs=env.num_envs,
num_layers=1,
hidden_size=64,
sequence_length=10)
The shape comments in the code below use the following notation:
- N: batch size (the number of environments during rollout)
- L: sequence length
- Hin: input_size (the number of observations)
- Hcell: hidden_size
- Hout: proj_size if proj_size > 0, otherwise hidden_size (here Hout = Hcell since proj_size = 0)
- D: 2 if bidirectional=True, otherwise 1
The following points are relevant in the definition of recurrent models:
- The .get_specification() method must be overwritten to return, under the dictionary key "rnn", a sub-dictionary that includes the sequence length (under the key "sequence_length") as a number and a list of the dimensions (under the key "sizes") of each initial hidden and cell state.
- The .compute() method's inputs parameter will have, at least, the following items in the dictionary:
  - "states": state of the environment used to make the decision
  - "taken_actions": actions taken by the policy for the given states, if applicable
  - "terminated": episode termination status for sampled environment transitions (this key is only defined during the training process)
  - "rnn": list of initial hidden/cell states ordered according to the model specification
- The .compute() method must include, under the "rnn" key of the returned dictionary, a list of the final hidden and cell states.
import torch
import torch.nn as nn
from skrl.models.torch import Model, GaussianMixin
# define the model (LSTM feeding an nn.Sequential head)
class LSTM(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device, clip_actions=False,
clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum",
num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
self.num_envs = num_envs
self.num_layers = num_layers
self.hidden_size = hidden_size # Hcell (Hout is Hcell because proj_size = 0)
self.sequence_length = sequence_length
self.lstm = nn.LSTM(input_size=self.num_observations,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
batch_first=True) # batch_first -> (batch, sequence, features)
self.net = nn.Sequential(nn.Linear(self.hidden_size, 64),
nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, self.num_actions),
nn.Tanh())
self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
def get_specification(self):
# batch size (N) is the number of envs during rollout
return {"rnn": {"sequence_length": self.sequence_length,
"sizes": [(self.num_layers, self.num_envs, self.hidden_size), # hidden states (D ∗ num_layers, N, Hout)
(self.num_layers, self.num_envs, self.hidden_size)]}} # cell states (D ∗ num_layers, N, Hcell)
def compute(self, inputs, role):
states = inputs["states"]
terminated = inputs.get("terminated", None)
hidden_states, cell_states = inputs["rnn"][0], inputs["rnn"][1]
# training
if self.training:
rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
cell_states = cell_states.view(self.num_layers, -1, self.sequence_length, cell_states.shape[-1]) # (D * num_layers, N, L, Hcell)
# get the hidden/cell states corresponding to the initial sequence
hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
cell_states = cell_states[:,:,0,:].contiguous() # (D * num_layers, N, Hcell)
# reset the RNN state in the middle of a sequence
if terminated is not None and torch.any(terminated):
rnn_outputs = []
terminated = terminated.view(-1, self.sequence_length)
indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
for i in range(len(indexes) - 1):
i0, i1 = indexes[i], indexes[i + 1]
rnn_output, (hidden_states, cell_states) = self.lstm(rnn_input[:,i0:i1,:], (hidden_states, cell_states))
hidden_states[:, (terminated[:,i1-1]), :] = 0
cell_states[:, (terminated[:,i1-1]), :] = 0
rnn_outputs.append(rnn_output)
rnn_states = (hidden_states, cell_states)
rnn_output = torch.cat(rnn_outputs, dim=1)
# no need to reset the RNN state in the sequence
else:
rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
# rollout
else:
rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
# flatten the RNN output
rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
return self.net(rnn_output), self.log_std_parameter, {"rnn": [rnn_states[0], rnn_states[1]]}
# instantiate the model (assumes there is a wrapped environment: env)
policy = LSTM(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=True,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
reduction="sum",
num_envs=env.num_envs,
num_layers=1,
hidden_size=64,
sequence_length=10)
import torch
import torch.nn as nn
import torch.nn.functional as F
from skrl.models.torch import Model, GaussianMixin
# define the model (LSTM with a functional-API head)
class LSTM(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device, clip_actions=False,
clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum",
num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
self.num_envs = num_envs
self.num_layers = num_layers
self.hidden_size = hidden_size # Hcell (Hout is Hcell because proj_size = 0)
self.sequence_length = sequence_length
self.lstm = nn.LSTM(input_size=self.num_observations,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
batch_first=True) # batch_first -> (batch, sequence, features)
self.fc1 = nn.Linear(self.hidden_size, 64)
self.fc2 = nn.Linear(64, 32)
self.fc3 = nn.Linear(32, self.num_actions)
self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
def get_specification(self):
# batch size (N) is the number of envs during rollout
return {"rnn": {"sequence_length": self.sequence_length,
"sizes": [(self.num_layers, self.num_envs, self.hidden_size), # hidden states (D ∗ num_layers, N, Hout)
(self.num_layers, self.num_envs, self.hidden_size)]}} # cell states (D ∗ num_layers, N, Hcell)
def compute(self, inputs, role):
states = inputs["states"]
terminated = inputs.get("terminated", None)
hidden_states, cell_states = inputs["rnn"][0], inputs["rnn"][1]
# training
if self.training:
rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
cell_states = cell_states.view(self.num_layers, -1, self.sequence_length, cell_states.shape[-1]) # (D * num_layers, N, L, Hcell)
# get the hidden/cell states corresponding to the initial sequence
hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
cell_states = cell_states[:,:,0,:].contiguous() # (D * num_layers, N, Hcell)
# reset the RNN state in the middle of a sequence
if terminated is not None and torch.any(terminated):
rnn_outputs = []
terminated = terminated.view(-1, self.sequence_length)
indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
for i in range(len(indexes) - 1):
i0, i1 = indexes[i], indexes[i + 1]
rnn_output, (hidden_states, cell_states) = self.lstm(rnn_input[:,i0:i1,:], (hidden_states, cell_states))
hidden_states[:, (terminated[:,i1-1]), :] = 0
cell_states[:, (terminated[:,i1-1]), :] = 0
rnn_outputs.append(rnn_output)
rnn_states = (hidden_states, cell_states)
rnn_output = torch.cat(rnn_outputs, dim=1)
# no need to reset the RNN state in the sequence
else:
rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
# rollout
else:
rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
# flatten the RNN output
rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
x = self.fc1(rnn_output)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
return torch.tanh(x), self.log_std_parameter, {"rnn": [rnn_states[0], rnn_states[1]]}
# instantiate the model (assumes there is a wrapped environment: env)
policy = LSTM(observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
clip_actions=True,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
reduction="sum",
num_envs=env.num_envs,
num_layers=1,
hidden_size=64,
sequence_length=10)
API (PyTorch)¶
- class skrl.models.torch.gaussian.GaussianMixin(clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20, max_log_std: float = 2, reduction: str = 'sum', role: str = '')¶
Bases: object
- __init__(clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20, max_log_std: float = 2, reduction: str = 'sum', role: str = '') → None¶
Gaussian mixin model (stochastic model)
- Parameters:
  - clip_actions (bool, optional) – Flag to indicate whether the actions should be clipped to the action space (default: False)
  - clip_log_std (bool, optional) – Flag to indicate whether the log standard deviations should be clipped (default: True)
  - min_log_std (float, optional) – Minimum value of the log standard deviation if clip_log_std is True (default: -20)
  - max_log_std (float, optional) – Maximum value of the log standard deviation if clip_log_std is True (default: 2)
  - reduction (str, optional) – Reduction method for returning the log probability density function (default: "sum"). Supported values are "mean", "sum", "prod" and "none". If "none", the log probability density function is returned as a tensor of shape (num_samples, num_actions) instead of (num_samples, 1)
  - role (str, optional) – Role played by the model (default: "")
- Raises:
ValueError – If the reduction method is not valid
Example:
>>> # define the model
>>> import torch
>>> import torch.nn as nn
>>> from skrl.models.torch import Model, GaussianMixin
>>>
>>> class Policy(GaussianMixin, Model):
...     def __init__(self, observation_space, action_space, device="cuda:0",
...                  clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
...         Model.__init__(self, observation_space, action_space, device)
...         GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
...
...         self.net = nn.Sequential(nn.Linear(self.num_observations, 32),
...                                  nn.ELU(),
...                                  nn.Linear(32, 32),
...                                  nn.ELU(),
...                                  nn.Linear(32, self.num_actions))
...         self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
...
...     def compute(self, inputs, role):
...         return self.net(inputs["states"]), self.log_std_parameter, {}
...
>>> # given an observation_space: gym.spaces.Box with shape (60,)
>>> # and an action_space: gym.spaces.Box with shape (8,)
>>> model = Policy(observation_space, action_space)
>>>
>>> print(model)
Policy(
  (net): Sequential(
    (0): Linear(in_features=60, out_features=32, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=32, out_features=8, bias=True)
  )
)
- act(inputs: Mapping[str, torch.Tensor | Any], role: str = '') → Tuple[torch.Tensor, torch.Tensor | None, Mapping[str, torch.Tensor | Any]]¶
Act stochastically in response to the state of the environment
- Parameters:
  - inputs (dict where the values are typically torch.Tensor) – Model inputs. The most common keys are:
    - "states": state of the environment used to make the decision
    - "taken_actions": actions taken by the policy for the given states
  - role (str, optional) – Role played by the model (default: "")
- Returns:
  Model output. The first component is the action to be taken by the agent. The second component is the log of the probability density function. The third component is a dictionary containing the mean actions "mean_actions" and extra output values.
- Return type:
  tuple of torch.Tensor, torch.Tensor or None, and dict
Example:
>>> # given a batch of sample states with shape (4096, 60)
>>> actions, log_prob, outputs = model.act({"states": states})
>>> print(actions.shape, log_prob.shape, outputs["mean_actions"].shape)
torch.Size([4096, 8]) torch.Size([4096, 1]) torch.Size([4096, 8])
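The (4096, 1) log-probability shape above follows from the default reduction="sum". As noted in the constructor parameters, a model built with reduction="none" would instead keep one value per action (a hypothetical variation of the example above):

>>> # same call on a model constructed with reduction="none"
>>> actions, log_prob, outputs = model.act({"states": states})
>>> print(log_prob.shape)
torch.Size([4096, 8])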
- distribution(role: str = '') → torch.distributions.Normal¶
Get the current distribution of the model
- Returns:
Distribution of the model
- Return type:
torch.distributions.Normal
- Parameters:
  role (str, optional) – Role played by the model (default: "")
Example:
>>> distribution = model.distribution()
>>> print(distribution)
Normal(loc: torch.Size([4096, 8]), scale: torch.Size([4096, 8]))
- get_entropy(role: str = '') → torch.Tensor¶
Compute and return the entropy of the model
- Returns:
  Entropy of the model
- Return type:
  torch.Tensor
- Parameters:
  role (str, optional) – Role played by the model (default: "")
Example:
>>> entropy = model.get_entropy()
>>> print(entropy.shape)
torch.Size([4096, 8])
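A common use of this quantity (a generic RL pattern, not an skrl requirement) is to reduce the per-action entropies to a scalar bonus for the policy loss:

>>> entropy_bonus = 0.01 * model.get_entropy().mean()  # 0.01 is a hypothetical coefficient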
- get_log_std(role: str = '') → torch.Tensor¶
Return the log standard deviation of the model
- Returns:
  Log standard deviation of the model
- Return type:
  torch.Tensor
- Parameters:
  role (str, optional) – Role played by the model (default: "")
Example:
>>> log_std = model.get_log_std()
>>> print(log_std.shape)
torch.Size([4096, 8])
API (JAX)¶
- class skrl.models.jax.gaussian.GaussianMixin(clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20, max_log_std: float = 2, reduction: str = 'sum', role: str = '')¶
Bases: object
- __init__(clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20, max_log_std: float = 2, reduction: str = 'sum', role: str = '') → None¶
Gaussian mixin model (stochastic model)
- Parameters:
  - clip_actions (bool, optional) – Flag to indicate whether the actions should be clipped to the action space (default: False)
  - clip_log_std (bool, optional) – Flag to indicate whether the log standard deviations should be clipped (default: True)
  - min_log_std (float, optional) – Minimum value of the log standard deviation if clip_log_std is True (default: -20)
  - max_log_std (float, optional) – Maximum value of the log standard deviation if clip_log_std is True (default: 2)
  - reduction (str, optional) – Reduction method for returning the log probability density function (default: "sum"). Supported values are "mean", "sum", "prod" and "none". If "none", the log probability density function is returned as a tensor of shape (num_samples, num_actions) instead of (num_samples, 1)
  - role (str, optional) – Role played by the model (default: "")
- Raises:
ValueError – If the reduction method is not valid
Example:
>>> # define the model
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> from skrl.models.jax import Model, GaussianMixin
>>>
>>> class Policy(GaussianMixin, Model):
...     def __init__(self, observation_space, action_space, device=None,
...                  clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
...         Model.__init__(self, observation_space, action_space, device, **kwargs)
...         GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
...
...     def setup(self):
...         self.layer_1 = nn.Dense(32)
...         self.layer_2 = nn.Dense(32)
...         self.layer_3 = nn.Dense(self.num_actions)
...
...         self.log_std_parameter = self.param("log_std_parameter", lambda _: jnp.zeros(self.num_actions))
...
...     def __call__(self, inputs, role):
...         x = nn.elu(self.layer_1(inputs["states"]))
...         x = nn.elu(self.layer_2(x))
...         return self.layer_3(x), self.log_std_parameter, {}
...
>>> # given an observation_space: gym.spaces.Box with shape (60,)
>>> # and an action_space: gym.spaces.Box with shape (8,)
>>> model = Policy(observation_space, action_space)
>>>
>>> print(model)
Policy(
    # attributes
    observation_space = Box(-1.0, 1.0, (60,), float32)
    action_space = Box(-1.0, 1.0, (8,), float32)
    device = StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)
)
- act(inputs: Mapping[str, ndarray | jax.Array | Any], role: str = '', params: jax.Array | None = None) → Tuple[jax.Array, jax.Array | None, Mapping[str, jax.Array | Any]]¶
Act stochastically in response to the state of the environment
- Parameters:
  - inputs (dict where the values are typically np.ndarray or jax.Array) – Model inputs. The most common keys are:
    - "states": state of the environment used to make the decision
    - "taken_actions": actions taken by the policy for the given states
  - role (str, optional) – Role played by the model (default: "")
  - params (jnp.array, optional) – Parameters used to compute the output (default: None). If None, internal parameters will be used
- Returns:
  Model output. The first component is the action to be taken by the agent. The second component is the log of the probability density function. The third component is a dictionary containing the mean actions "mean_actions" and extra output values.
- Return type:
  tuple of jax.Array, jax.Array or None, and dict
Example:
>>> # given a batch of sample states with shape (4096, 60)
>>> actions, log_prob, outputs = model.act({"states": states})
>>> print(actions.shape, log_prob.shape, outputs["mean_actions"].shape)
(4096, 8) (4096, 1) (4096, 8)
- get_entropy(stddev: jax.Array, role: str = '') → jax.Array¶
Compute and return the entropy of the model
- Parameters:
  - stddev (jax.Array) – Standard deviation of the model
  - role (str, optional) – Role played by the model (default: "")
- Returns:
  Entropy of the model
- Return type:
  jax.Array
Example:
>>> # given a standard deviation array: stddev
>>> entropy = model.get_entropy(stddev)
>>> print(entropy.shape)
(4096, 8)