Gaussian model
Gaussian models run continuous-domain stochastic policies.
skrl provides a Python mixin (GaussianMixin) to assist in the creation of these types of models, allowing users to have full control over the function approximator definitions and architectures. Note that the use of this mixin must comply with the following rules:
The definition of multiple inheritance must always include the Model base class at the end.
class GaussianModel(GaussianMixin, Model): def __init__(self, observation_space, action_space, device="cuda:0", clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"): Model.__init__(self, observation_space, action_space, device) GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
The Model base class constructor must be invoked before the mixin's constructor.
class GaussianModel(GaussianMixin, Model): def __init__(self, observation_space, action_space, device="cuda:0", clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"): Model.__init__(self, observation_space, action_space, device) GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
Concept
Basic usage
Multi-Layer Perceptron (MLP)
Convolutional Neural Network (CNN)
Recurrent Neural Network (RNN)
Gated Recurrent Unit RNN (GRU)
Long Short-Term Memory RNN (LSTM)
1import torch
2import torch.nn as nn
3
4from skrl.models.torch import Model, GaussianMixin
5
6
7# define the model
8class MLP(GaussianMixin, Model):
9 def __init__(self, observation_space, action_space, device,
10 clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
11 Model.__init__(self, observation_space, action_space, device)
12 GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
13
14 self.net = nn.Sequential(nn.Linear(self.num_observations, 64),
15 nn.ReLU(),
16 nn.Linear(64, 32),
17 nn.ReLU(),
18 nn.Linear(32, self.num_actions),
19 nn.Tanh())
20
21 self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
22
23 def compute(self, inputs, role):
24 return self.net(inputs["states"]), self.log_std_parameter, {}
25
26
27# instantiate the model (assumes there is a wrapped environment: env)
28policy = MLP(observation_space=env.observation_space,
29 action_space=env.action_space,
30 device=env.device,
31 clip_actions=True,
32 clip_log_std=True,
33 min_log_std=-20,
34 max_log_std=2,
35 reduction="sum")
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4
5from skrl.models.torch import Model, GaussianMixin
6
7
8# define the model
9class MLP(GaussianMixin, Model):
10 def __init__(self, observation_space, action_space, device,
11 clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
12 Model.__init__(self, observation_space, action_space, device)
13 GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
14
15 self.fc1 = nn.Linear(self.num_observations, 64)
16 self.fc2 = nn.Linear(64, 32)
17 self.fc3 = nn.Linear(32, self.num_actions)
18
19 self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
20
21 def compute(self, inputs, role):
22 x = self.fc1(inputs["states"])
23 x = F.relu(x)
24 x = self.fc2(x)
25 x = F.relu(x)
26 x = self.fc3(x)
27 return torch.tanh(x), self.log_std_parameter, {}
28
29
30# instantiate the model (assumes there is a wrapped environment: env)
31policy = MLP(observation_space=env.observation_space,
32 action_space=env.action_space,
33 device=env.device,
34 clip_actions=True,
35 clip_log_std=True,
36 min_log_std=-20,
37 max_log_std=2,
38 reduction="sum")
1import torch
2import torch.nn as nn
3
4from skrl.models.torch import Model, GaussianMixin
5
6
7# define the model
8class CNN(GaussianMixin, Model):
9 def __init__(self, observation_space, action_space, device,
10 clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
11 Model.__init__(self, observation_space, action_space, device)
12 GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
13
14 self.net = nn.Sequential(nn.Conv2d(3, 32, kernel_size=8, stride=4),
15 nn.ReLU(),
16 nn.Conv2d(32, 64, kernel_size=4, stride=2),
17 nn.ReLU(),
18 nn.Conv2d(64, 64, kernel_size=3, stride=1),
19 nn.ReLU(),
20 nn.Flatten(),
21 nn.Linear(1024, 512),
22 nn.ReLU(),
23 nn.Linear(512, 16),
24 nn.Tanh(),
25 nn.Linear(16, 64),
26 nn.Tanh(),
27 nn.Linear(64, 32),
28 nn.Tanh(),
29 nn.Linear(32, self.num_actions))
30
31 self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
32
33 def compute(self, inputs, role):
34 # permute (samples, width * height * channels) -> (samples, channels, width, height)
35 return self.net(inputs["states"].view(-1, *self.observation_space.shape).permute(0, 3, 1, 2)), self.log_std_parameter, {}
36
37
38# instantiate the model (assumes there is a wrapped environment: env)
39policy = CNN(observation_space=env.observation_space,
40 action_space=env.action_space,
41 device=env.device,
42 clip_actions=True,
43 clip_log_std=True,
44 min_log_std=-20,
45 max_log_std=2,
46 reduction="sum")
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4
5from skrl.models.torch import Model, GaussianMixin
6
7
8# define the model
9class CNN(GaussianMixin, Model):
10 def __init__(self, observation_space, action_space, device,
11 clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
12 Model.__init__(self, observation_space, action_space, device)
13 GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
14
15 self.conv1 = nn.Conv2d(3, 32, kernel_size=8, stride=4)
16 self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
17 self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
18 self.fc1 = nn.Linear(1024, 512)
19 self.fc2 = nn.Linear(512, 16)
20 self.fc3 = nn.Linear(16, 64)
21 self.fc4 = nn.Linear(64, 32)
22 self.fc5 = nn.Linear(32, self.num_actions)
23
24 self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
25
26 def compute(self, inputs, role):
27 # permute (samples, width * height * channels) -> (samples, channels, width, height)
28 x = inputs["states"].view(-1, *self.observation_space.shape).permute(0, 3, 1, 2)
29 x = self.conv1(x)
30 x = F.relu(x)
31 x = self.conv2(x)
32 x = F.relu(x)
33 x = self.conv3(x)
34 x = F.relu(x)
35 x = torch.flatten(x, start_dim=1)
36 x = self.fc1(x)
37 x = F.relu(x)
38 x = self.fc2(x)
39 x = torch.tanh(x)
40 x = self.fc3(x)
41 x = torch.tanh(x)
42 x = self.fc4(x)
43 x = torch.tanh(x)
44 x = self.fc5(x)
45 return x, self.log_std_parameter, {}
46
47
48# instantiate the model (assumes there is a wrapped environment: env)
49policy = CNN(observation_space=env.observation_space,
50 action_space=env.action_space,
51 device=env.device,
52 clip_actions=True,
53 clip_log_std=True,
54 min_log_std=-20,
55 max_log_std=2,
56 reduction="sum")
where:
The following points are relevant in the definition of recurrent models:
The .get_specification() method must be overwritten to return, under a dictionary key "rnn", a sub-dictionary that includes the sequence length (under key "sequence_length") as a number and a list of the dimensions (under key "sizes") of each initial hidden state.
The .compute() method's inputs parameter will have, at least, the following items in the dictionary:
- "states": state of the environment used to make the decision
- "taken_actions": actions taken by the policy for the given states, if applicable
- "terminated": episode termination status for sampled environment transitions. This key is only defined during the training process
- "rnn": list of initial hidden states ordered according to the model specification
The .compute() method must include, under the "rnn" key of the returned dictionary, a list of each final hidden state
1import torch
2import torch.nn as nn
3
4from skrl.models.torch import Model, GaussianMixin
5
6
7# define the model
8class RNN(GaussianMixin, Model):
9 def __init__(self, observation_space, action_space, device, clip_actions=False,
10 clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum",
11 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
12 Model.__init__(self, observation_space, action_space, device)
13 GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
14
15 self.num_envs = num_envs
16 self.num_layers = num_layers
17 self.hidden_size = hidden_size # Hout
18 self.sequence_length = sequence_length
19
20 self.rnn = nn.RNN(input_size=self.num_observations,
21 hidden_size=self.hidden_size,
22 num_layers=self.num_layers,
23 batch_first=True) # batch_first -> (batch, sequence, features)
24
25 self.net = nn.Sequential(nn.Linear(self.hidden_size, 64),
26 nn.ReLU(),
27 nn.Linear(64, 32),
28 nn.ReLU(),
29 nn.Linear(32, self.num_actions),
30 nn.Tanh())
31
32 self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
33
34 def get_specification(self):
35 # batch size (N) is the number of envs during rollout
36 return {"rnn": {"sequence_length": self.sequence_length,
37 "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}} # hidden states (D ∗ num_layers, N, Hout)
38
39 def compute(self, inputs, role):
40 states = inputs["states"]
41 terminated = inputs.get("terminated", None)
42 hidden_states = inputs["rnn"][0]
43
44 # training
45 if self.training:
46 rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
47 hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
48 # get the hidden states corresponding to the initial sequence
49 hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
50
51 # reset the RNN state in the middle of a sequence
52 if terminated is not None and torch.any(terminated):
53 rnn_outputs = []
54 terminated = terminated.view(-1, self.sequence_length)
55 indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
56
57 for i in range(len(indexes) - 1):
58 i0, i1 = indexes[i], indexes[i + 1]
59 rnn_output, hidden_states = self.rnn(rnn_input[:,i0:i1,:], hidden_states)
60 hidden_states[:, (terminated[:,i1-1]), :] = 0
61 rnn_outputs.append(rnn_output)
62
63 rnn_output = torch.cat(rnn_outputs, dim=1)
64 # no need to reset the RNN state in the sequence
65 else:
66 rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
67 # rollout
68 else:
69 rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
70 rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
71
72 # flatten the RNN output
73 rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
74
75 return self.net(rnn_output), self.log_std_parameter, {"rnn": [hidden_states]}
76
77
78# instantiate the model (assumes there is a wrapped environment: env)
79policy = RNN(observation_space=env.observation_space,
80 action_space=env.action_space,
81 device=env.device,
82 clip_actions=True,
83 clip_log_std=True,
84 min_log_std=-20,
85 max_log_std=2,
86 reduction="sum",
87 num_envs=env.num_envs,
88 num_layers=1,
89 hidden_size=64,
90 sequence_length=10)
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4
5from skrl.models.torch import Model, GaussianMixin
6
7
8# define the model
9class RNN(GaussianMixin, Model):
10 def __init__(self, observation_space, action_space, device, clip_actions=False,
11 clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum",
12 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
13 Model.__init__(self, observation_space, action_space, device)
14 GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
15
16 self.num_envs = num_envs
17 self.num_layers = num_layers
18 self.hidden_size = hidden_size # Hout
19 self.sequence_length = sequence_length
20
21 self.rnn = nn.RNN(input_size=self.num_observations,
22 hidden_size=self.hidden_size,
23 num_layers=self.num_layers,
24 batch_first=True) # batch_first -> (batch, sequence, features)
25
26 self.fc1 = nn.Linear(self.hidden_size, 64)
27 self.fc2 = nn.Linear(64, 32)
28 self.fc3 = nn.Linear(32, self.num_actions)
29
30 self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
31
32 def get_specification(self):
33 # batch size (N) is the number of envs during rollout
34 return {"rnn": {"sequence_length": self.sequence_length,
35 "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}} # hidden states (D ∗ num_layers, N, Hout)
36
37 def compute(self, inputs, role):
38 states = inputs["states"]
39 terminated = inputs.get("terminated", None)
40 hidden_states = inputs["rnn"][0]
41
42 # training
43 if self.training:
44 rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
45 hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
46 # get the hidden states corresponding to the initial sequence
47 hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
48
49 # reset the RNN state in the middle of a sequence
50 if terminated is not None and torch.any(terminated):
51 rnn_outputs = []
52 terminated = terminated.view(-1, self.sequence_length)
53 indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
54
55 for i in range(len(indexes) - 1):
56 i0, i1 = indexes[i], indexes[i + 1]
57 rnn_output, hidden_states = self.rnn(rnn_input[:,i0:i1,:], hidden_states)
58 hidden_states[:, (terminated[:,i1-1]), :] = 0
59 rnn_outputs.append(rnn_output)
60
61 rnn_output = torch.cat(rnn_outputs, dim=1)
62 # no need to reset the RNN state in the sequence
63 else:
64 rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
65 # rollout
66 else:
67 rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
68 rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)
69
70 # flatten the RNN output
71 rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
72
73 x = self.fc1(rnn_output)
74 x = F.relu(x)
75 x = self.fc2(x)
76 x = F.relu(x)
77 x = self.fc3(x)
78
79 return torch.tanh(x), self.log_std_parameter, {"rnn": [hidden_states]}
80
81
82# instantiate the model (assumes there is a wrapped environment: env)
83policy = RNN(observation_space=env.observation_space,
84 action_space=env.action_space,
85 device=env.device,
86 clip_actions=True,
87 clip_log_std=True,
88 min_log_std=-20,
89 max_log_std=2,
90 reduction="sum",
91 num_envs=env.num_envs,
92 num_layers=1,
93 hidden_size=64,
94 sequence_length=10)
where:
The following points are relevant in the definition of recurrent models:
The .get_specification() method must be overwritten to return, under a dictionary key "rnn", a sub-dictionary that includes the sequence length (under key "sequence_length") as a number and a list of the dimensions (under key "sizes") of each initial hidden state.
The .compute() method's inputs parameter will have, at least, the following items in the dictionary:
- "states": state of the environment used to make the decision
- "taken_actions": actions taken by the policy for the given states, if applicable
- "terminated": episode termination status for sampled environment transitions. This key is only defined during the training process
- "rnn": list of initial hidden states ordered according to the model specification
The .compute() method must include, under the "rnn" key of the returned dictionary, a list of each final hidden state
1import torch
2import torch.nn as nn
3
4from skrl.models.torch import Model, GaussianMixin
5
6
7# define the model
8class GRU(GaussianMixin, Model):
9 def __init__(self, observation_space, action_space, device, clip_actions=False,
10 clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum",
11 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
12 Model.__init__(self, observation_space, action_space, device)
13 GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
14
15 self.num_envs = num_envs
16 self.num_layers = num_layers
17 self.hidden_size = hidden_size # Hout
18 self.sequence_length = sequence_length
19
20 self.gru = nn.GRU(input_size=self.num_observations,
21 hidden_size=self.hidden_size,
22 num_layers=self.num_layers,
23 batch_first=True) # batch_first -> (batch, sequence, features)
24
25 self.net = nn.Sequential(nn.Linear(self.hidden_size, 64),
26 nn.ReLU(),
27 nn.Linear(64, 32),
28 nn.ReLU(),
29 nn.Linear(32, self.num_actions),
30 nn.Tanh())
31
32 self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
33
34 def get_specification(self):
35 # batch size (N) is the number of envs during rollout
36 return {"rnn": {"sequence_length": self.sequence_length,
37 "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}} # hidden states (D ∗ num_layers, N, Hout)
38
39 def compute(self, inputs, role):
40 states = inputs["states"]
41 terminated = inputs.get("terminated", None)
42 hidden_states = inputs["rnn"][0]
43
44 # training
45 if self.training:
46 rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
47 hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
48 # get the hidden states corresponding to the initial sequence
49 hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
50
51 # reset the RNN state in the middle of a sequence
52 if terminated is not None and torch.any(terminated):
53 rnn_outputs = []
54 terminated = terminated.view(-1, self.sequence_length)
55 indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
56
57 for i in range(len(indexes) - 1):
58 i0, i1 = indexes[i], indexes[i + 1]
59 rnn_output, hidden_states = self.gru(rnn_input[:,i0:i1,:], hidden_states)
60 hidden_states[:, (terminated[:,i1-1]), :] = 0
61 rnn_outputs.append(rnn_output)
62
63 rnn_output = torch.cat(rnn_outputs, dim=1)
64 # no need to reset the RNN state in the sequence
65 else:
66 rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
67 # rollout
68 else:
69 rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
70 rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
71
72 # flatten the RNN output
73 rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
74
75 return self.net(rnn_output), self.log_std_parameter, {"rnn": [hidden_states]}
76
77
78# instantiate the model (assumes there is a wrapped environment: env)
79policy = GRU(observation_space=env.observation_space,
80 action_space=env.action_space,
81 device=env.device,
82 clip_actions=True,
83 clip_log_std=True,
84 min_log_std=-20,
85 max_log_std=2,
86 reduction="sum",
87 num_envs=env.num_envs,
88 num_layers=1,
89 hidden_size=64,
90 sequence_length=10)
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4
5from skrl.models.torch import Model, GaussianMixin
6
7
8# define the model
9class GRU(GaussianMixin, Model):
10 def __init__(self, observation_space, action_space, device, clip_actions=False,
11 clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum",
12 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
13 Model.__init__(self, observation_space, action_space, device)
14 GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
15
16 self.num_envs = num_envs
17 self.num_layers = num_layers
18 self.hidden_size = hidden_size # Hout
19 self.sequence_length = sequence_length
20
21 self.gru = nn.GRU(input_size=self.num_observations,
22 hidden_size=self.hidden_size,
23 num_layers=self.num_layers,
24 batch_first=True) # batch_first -> (batch, sequence, features)
25
26 self.fc1 = nn.Linear(self.hidden_size, 64)
27 self.fc2 = nn.Linear(64, 32)
28 self.fc3 = nn.Linear(32, self.num_actions)
29
30 self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
31
32 def get_specification(self):
33 # batch size (N) is the number of envs during rollout
34 return {"rnn": {"sequence_length": self.sequence_length,
35 "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}} # hidden states (D ∗ num_layers, N, Hout)
36
37 def compute(self, inputs, role):
38 states = inputs["states"]
39 terminated = inputs.get("terminated", None)
40 hidden_states = inputs["rnn"][0]
41
42 # training
43 if self.training:
44 rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
45 hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
46 # get the hidden states corresponding to the initial sequence
47 hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
48
49 # reset the RNN state in the middle of a sequence
50 if terminated is not None and torch.any(terminated):
51 rnn_outputs = []
52 terminated = terminated.view(-1, self.sequence_length)
53 indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
54
55 for i in range(len(indexes) - 1):
56 i0, i1 = indexes[i], indexes[i + 1]
57 rnn_output, hidden_states = self.gru(rnn_input[:,i0:i1,:], hidden_states)
58 hidden_states[:, (terminated[:,i1-1]), :] = 0
59 rnn_outputs.append(rnn_output)
60
61 rnn_output = torch.cat(rnn_outputs, dim=1)
62 # no need to reset the RNN state in the sequence
63 else:
64 rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
65 # rollout
66 else:
67 rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
68 rnn_output, hidden_states = self.gru(rnn_input, hidden_states)
69
70 # flatten the RNN output
71 rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
72
73 x = self.fc1(rnn_output)
74 x = F.relu(x)
75 x = self.fc2(x)
76 x = F.relu(x)
77 x = self.fc3(x)
78
79 return torch.tanh(x), self.log_std_parameter, {"rnn": [hidden_states]}
80
81
82# instantiate the model (assumes there is a wrapped environment: env)
83policy = GRU(observation_space=env.observation_space,
84 action_space=env.action_space,
85 device=env.device,
86 clip_actions=True,
87 clip_log_std=True,
88 min_log_std=-20,
89 max_log_std=2,
90 reduction="sum",
91 num_envs=env.num_envs,
92 num_layers=1,
93 hidden_size=64,
94 sequence_length=10)
where:
The following points are relevant in the definition of recurrent models:
The .get_specification() method must be overwritten to return, under a dictionary key "rnn", a sub-dictionary that includes the sequence length (under key "sequence_length") as a number and a list of the dimensions (under key "sizes") of each initial hidden/cell state.
The .compute() method's inputs parameter will have, at least, the following items in the dictionary:
- "states": state of the environment used to make the decision
- "taken_actions": actions taken by the policy for the given states, if applicable
- "terminated": episode termination status for sampled environment transitions. This key is only defined during the training process
- "rnn": list of initial hidden/cell states ordered according to the model specification
The .compute() method must include, under the "rnn" key of the returned dictionary, a list of the final hidden/cell states
1import torch
2import torch.nn as nn
3
4from skrl.models.torch import Model, GaussianMixin
5
6
7# define the model
8class LSTM(GaussianMixin, Model):
9 def __init__(self, observation_space, action_space, device, clip_actions=False,
10 clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum",
11 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
12 Model.__init__(self, observation_space, action_space, device)
13 GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
14
15 self.num_envs = num_envs
16 self.num_layers = num_layers
17 self.hidden_size = hidden_size # Hcell (Hout is Hcell because proj_size = 0)
18 self.sequence_length = sequence_length
19
20 self.lstm = nn.LSTM(input_size=self.num_observations,
21 hidden_size=self.hidden_size,
22 num_layers=self.num_layers,
23 batch_first=True) # batch_first -> (batch, sequence, features)
24
25 self.net = nn.Sequential(nn.Linear(self.hidden_size, 64),
26 nn.ReLU(),
27 nn.Linear(64, 32),
28 nn.ReLU(),
29 nn.Linear(32, self.num_actions),
30 nn.Tanh())
31
32 self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
33
34 def get_specification(self):
35 # batch size (N) is the number of envs during rollout
36 return {"rnn": {"sequence_length": self.sequence_length,
37 "sizes": [(self.num_layers, self.num_envs, self.hidden_size), # hidden states (D ∗ num_layers, N, Hout)
38 (self.num_layers, self.num_envs, self.hidden_size)]}} # cell states (D ∗ num_layers, N, Hcell)
39
40 def compute(self, inputs, role):
41 states = inputs["states"]
42 terminated = inputs.get("terminated", None)
43 hidden_states, cell_states = inputs["rnn"][0], inputs["rnn"][1]
44
45 # training
46 if self.training:
47 rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
48 hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
49 cell_states = cell_states.view(self.num_layers, -1, self.sequence_length, cell_states.shape[-1]) # (D * num_layers, N, L, Hcell)
50 # get the hidden/cell states corresponding to the initial sequence
51 hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
52 cell_states = cell_states[:,:,0,:].contiguous() # (D * num_layers, N, Hcell)
53
54 # reset the RNN state in the middle of a sequence
55 if terminated is not None and torch.any(terminated):
56 rnn_outputs = []
57 terminated = terminated.view(-1, self.sequence_length)
58 indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
59
60 for i in range(len(indexes) - 1):
61 i0, i1 = indexes[i], indexes[i + 1]
62 rnn_output, (hidden_states, cell_states) = self.lstm(rnn_input[:,i0:i1,:], (hidden_states, cell_states))
63 hidden_states[:, (terminated[:,i1-1]), :] = 0
64 cell_states[:, (terminated[:,i1-1]), :] = 0
65 rnn_outputs.append(rnn_output)
66
67 rnn_states = (hidden_states, cell_states)
68 rnn_output = torch.cat(rnn_outputs, dim=1)
69 # no need to reset the RNN state in the sequence
70 else:
71 rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
72 # rollout
73 else:
74 rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
75 rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
76
77 # flatten the RNN output
78 rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
79
80 return self.net(rnn_output), self.log_std_parameter, {"rnn": [rnn_states[0], rnn_states[1]]}
81
82
83# instantiate the model (assumes there is a wrapped environment: env)
84policy = LSTM(observation_space=env.observation_space,
85 action_space=env.action_space,
86 device=env.device,
87 clip_actions=True,
88 clip_log_std=True,
89 min_log_std=-20,
90 max_log_std=2,
91 reduction="sum",
92 num_envs=env.num_envs,
93 num_layers=1,
94 hidden_size=64,
95 sequence_length=10)
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4
5from skrl.models.torch import Model, GaussianMixin
6
7
8# define the model
9class LSTM(GaussianMixin, Model):
10 def __init__(self, observation_space, action_space, device, clip_actions=False,
11 clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum",
12 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
13 Model.__init__(self, observation_space, action_space, device)
14 GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
15
16 self.num_envs = num_envs
17 self.num_layers = num_layers
18 self.hidden_size = hidden_size # Hcell (Hout is Hcell because proj_size = 0)
19 self.sequence_length = sequence_length
20
21 self.lstm = nn.LSTM(input_size=self.num_observations,
22 hidden_size=self.hidden_size,
23 num_layers=self.num_layers,
24 batch_first=True) # batch_first -> (batch, sequence, features)
25
26 self.fc1 = nn.Linear(self.hidden_size, 64)
27 self.fc2 = nn.Linear(64, 32)
28 self.fc3 = nn.Linear(32, self.num_actions)
29
30 self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
31
32 def get_specification(self):
33 # batch size (N) is the number of envs during rollout
34 return {"rnn": {"sequence_length": self.sequence_length,
35 "sizes": [(self.num_layers, self.num_envs, self.hidden_size), # hidden states (D ∗ num_layers, N, Hout)
36 (self.num_layers, self.num_envs, self.hidden_size)]}} # cell states (D ∗ num_layers, N, Hcell)
37
38 def compute(self, inputs, role):
39 states = inputs["states"]
40 terminated = inputs.get("terminated", None)
41 hidden_states, cell_states = inputs["rnn"][0], inputs["rnn"][1]
42
43 # training
44 if self.training:
45 rnn_input = states.view(-1, self.sequence_length, states.shape[-1]) # (N, L, Hin): N=batch_size, L=sequence_length
46 hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1]) # (D * num_layers, N, L, Hout)
47 cell_states = cell_states.view(self.num_layers, -1, self.sequence_length, cell_states.shape[-1]) # (D * num_layers, N, L, Hcell)
48 # get the hidden/cell states corresponding to the initial sequence
49 hidden_states = hidden_states[:,:,0,:].contiguous() # (D * num_layers, N, Hout)
50 cell_states = cell_states[:,:,0,:].contiguous() # (D * num_layers, N, Hcell)
51
52 # reset the RNN state in the middle of a sequence
53 if terminated is not None and torch.any(terminated):
54 rnn_outputs = []
55 terminated = terminated.view(-1, self.sequence_length)
56 indexes = [0] + (terminated[:,:-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]
57
58 for i in range(len(indexes) - 1):
59 i0, i1 = indexes[i], indexes[i + 1]
60 rnn_output, (hidden_states, cell_states) = self.lstm(rnn_input[:,i0:i1,:], (hidden_states, cell_states))
61 hidden_states[:, (terminated[:,i1-1]), :] = 0
62 cell_states[:, (terminated[:,i1-1]), :] = 0
63 rnn_outputs.append(rnn_output)
64
65 rnn_states = (hidden_states, cell_states)
66 rnn_output = torch.cat(rnn_outputs, dim=1)
67 # no need to reset the RNN state in the sequence
68 else:
69 rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
70 # rollout
71 else:
72 rnn_input = states.view(-1, 1, states.shape[-1]) # (N, L, Hin): N=num_envs, L=1
73 rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))
74
75 # flatten the RNN output
76 rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1) # (N, L, D ∗ Hout) -> (N * L, D ∗ Hout)
77
78 x = self.fc1(rnn_output)
79 x = F.relu(x)
80 x = self.fc2(x)
81 x = F.relu(x)
82 x = self.fc3(x)
83
84 return torch.tanh(x), self.log_std_parameter, {"rnn": [rnn_states[0], rnn_states[1]]}
85
86
87# instantiate the model (assumes there is a wrapped environment: env)
88policy = LSTM(observation_space=env.observation_space,
89 action_space=env.action_space,
90 device=env.device,
91 clip_actions=True,
92 clip_log_std=True,
93 min_log_std=-20,
94 max_log_std=2,
95 reduction="sum",
96 num_envs=env.num_envs,
97 num_layers=1,
98 hidden_size=64,
99 sequence_length=10)
API
- class skrl.models.torch.gaussian.GaussianMixin(clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20, max_log_std: float = 2, reduction: str = 'sum', role: str = '')
Bases:
object
- __init__(clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20, max_log_std: float = 2, reduction: str = 'sum', role: str = '') -> None
Gaussian mixin model (stochastic model)
- Parameters
clip_actions (bool, optional) – Flag to indicate whether the actions should be clipped to the action space (default:
False
)clip_log_std (bool, optional) – Flag to indicate whether the log standard deviations should be clipped (default:
True
)min_log_std (float, optional) – Minimum value of the log standard deviation if
clip_log_std
is True (default:-20
)max_log_std (float, optional) – Maximum value of the log standard deviation if
clip_log_std
is True (default:2
)reduction (str, optional) – Reduction method for returning the log probability density function: (default:
"sum"
). Supported values are"mean"
,"sum"
,"prod"
and"none"
. If "none"
, the log probability density function is returned as a tensor of shape(num_samples, num_actions)
instead of(num_samples, 1)
role (str, optional) – Role played by the model (default:
""
)
- Raises
ValueError – If the reduction method is not valid
Example:
# define the model >>> import torch >>> import torch.nn as nn >>> from skrl.models.torch import Model, GaussianMixin >>> >>> class Policy(GaussianMixin, Model): ... def __init__(self, observation_space, action_space, device="cuda:0", ... clip_actions=False, clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"): ... Model.__init__(self, observation_space, action_space, device) ... GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction) ... ... self.net = nn.Sequential(nn.Linear(self.num_observations, 32), ... nn.ELU(), ... nn.Linear(32, 32), ... nn.ELU(), ... nn.Linear(32, self.num_actions)) ... self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions)) ... ... def compute(self, inputs, role): ... return self.net(inputs["states"]), self.log_std_parameter, {} ... >>> # given an observation_space: gym.spaces.Box with shape (60,) >>> # and an action_space: gym.spaces.Box with shape (8,) >>> model = Policy(observation_space, action_space) >>> >>> print(model) Policy( (net): Sequential( (0): Linear(in_features=60, out_features=32, bias=True) (1): ELU(alpha=1.0) (2): Linear(in_features=32, out_features=32, bias=True) (3): ELU(alpha=1.0) (4): Linear(in_features=32, out_features=8, bias=True) ) )
- act(inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = '') -> Tuple[torch.Tensor, Optional[torch.Tensor], Mapping[str, Union[torch.Tensor, Any]]]
Act stochastically in response to the state of the environment
- Parameters
inputs (dict where the values are typically torch.Tensor) –
Model inputs. The most common keys are:
"states"
: state of the environment used to make the decision"taken_actions"
: actions taken by the policy for the given states
role (str, optional) – Role played by the model (default:
""
)
- Returns
Model output. The first component is the action to be taken by the agent. The second component is the log of the probability density function. The third component is a dictionary containing the mean actions
"mean_actions"
and extra output values
- Return type
tuple of torch.Tensor, torch.Tensor or None, and dictionary
Example:
>>> # given a batch of sample states with shape (4096, 60) >>> actions, log_prob, outputs = model.act({"states": states}) >>> print(actions.shape, log_prob.shape, outputs["mean_actions"].shape) torch.Size([4096, 8]) torch.Size([4096, 1]) torch.Size([4096, 8])
- distribution(role: str = '') -> torch.distributions.normal.Normal
Get the current distribution of the model
- Returns
Distribution of the model
- Return type
torch.distributions.Normal
- Parameters
role (str, optional) – Role played by the model (default:
""
)
Example:
>>> distribution = model.distribution() >>> print(distribution) Normal(loc: torch.Size([4096, 8]), scale: torch.Size([4096, 8]))
- get_entropy(role: str = '') -> torch.Tensor
Compute and return the entropy of the model
- Returns
Entropy of the model
- Return type
torch.Tensor
- Parameters
role (str, optional) – Role played by the model (default:
""
)
Example:
>>> entropy = model.get_entropy() >>> print(entropy.shape) torch.Size([4096, 8])
- get_log_std(role: str = '') -> torch.Tensor
Return the log standard deviation of the model
- Returns
Log standard deviation of the model
- Return type
torch.Tensor
- Parameters
role (str, optional) – Role played by the model (default:
""
)
Example:
>>> log_std = model.get_log_std() >>> print(log_std.shape) torch.Size([4096, 8])