Deterministic model
Deterministic models run continuous-domain deterministic policies.
skrl provides a Python mixin (DeterministicMixin
) to assist in the creation of these types of models, allowing users to have full control over the function approximator definitions and architectures. Note that the use of this mixin must comply with the following rules:
The definition of multiple inheritance must always include the Model base class at the end.
class DeterministicModel(DeterministicMixin, Model): def __init__(self, observation_space, action_space, device="cuda:0", clip_actions=False): Model.__init__(self, observation_space, action_space, device) DeterministicMixin.__init__(self, clip_actions)
The Model base class constructor must be invoked before the mixin's constructor.
class DeterministicModel(DeterministicMixin, Model): def __init__(self, observation_space, action_space, device="cuda:0", clip_actions=False): Model.__init__(self, observation_space, action_space, device) DeterministicMixin.__init__(self, clip_actions)
Concept
Basic usage
Multi-Layer Perceptron (MLP)
Convolutional Neural Network (CNN)
Recurrent Neural Network (RNN)
Gated Recurrent Unit RNN (GRU)
Long Short-Term Memory RNN (LSTM)
import torch
import torch.nn as nn

from skrl.models.torch import Model, DeterministicMixin


# define the model
class MLP(DeterministicMixin, Model):
    """Deterministic critic: an MLP mapping concatenated (state, action) pairs to a scalar value."""

    def __init__(self, observation_space, action_space, device, clip_actions=False):
        # the Model base class must be initialized before the mixin
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)

        self.net = nn.Sequential(nn.Linear(self.num_observations + self.num_actions, 64),
                                 nn.ReLU(),
                                 nn.Linear(64, 32),
                                 nn.ReLU(),
                                 nn.Linear(32, 1))

    def compute(self, inputs, role):
        # concatenate states and taken actions along the feature dimension
        return self.net(torch.cat([inputs["states"], inputs["taken_actions"]], dim=1)), {}


# instantiate the model (assumes there is a wrapped environment: env)
critic = MLP(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             clip_actions=False)
import torch
import torch.nn as nn
import torch.nn.functional as F

from skrl.models.torch import Model, DeterministicMixin


# define the model
class MLP(DeterministicMixin, Model):
    """Deterministic critic (functional-API variant): MLP over concatenated (state, action) -> scalar value."""

    def __init__(self, observation_space, action_space, device, clip_actions=False):
        # the Model base class must be initialized before the mixin
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)

        self.fc1 = nn.Linear(self.num_observations + self.num_actions, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)

    def compute(self, inputs, role):
        # concatenate states and taken actions along the feature dimension
        x = self.fc1(torch.cat([inputs["states"], inputs["taken_actions"]], dim=1))
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        return self.fc3(x), {}


# instantiate the model (assumes there is a wrapped environment: env)
critic = MLP(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             clip_actions=False)
import torch
import torch.nn as nn

from skrl.models.torch import Model, DeterministicMixin


# define the model
class CNN(DeterministicMixin, Model):
    """Deterministic critic for image observations: CNN feature extractor followed by an MLP head."""

    def __init__(self, observation_space, action_space, device, clip_actions=False):
        # the Model base class must be initialized before the mixin
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)

        # assumes 3-channel image observations (channels-last in the observation space)
        self.features_extractor = nn.Sequential(nn.Conv2d(3, 32, kernel_size=8, stride=4),
                                                nn.ReLU(),
                                                nn.Conv2d(32, 64, kernel_size=4, stride=2),
                                                nn.ReLU(),
                                                nn.Conv2d(64, 64, kernel_size=3, stride=1),
                                                nn.ReLU(),
                                                nn.Flatten(),
                                                nn.Linear(1024, 512),
                                                nn.ReLU(),
                                                nn.Linear(512, 16),
                                                nn.Tanh())

        self.net = nn.Sequential(nn.Linear(16 + self.num_actions, 64),
                                 nn.Tanh(),
                                 nn.Linear(64, 32),
                                 nn.Tanh(),
                                 nn.Linear(32, 1))

    def compute(self, inputs, role):
        # permute (samples, width * height * channels) -> (samples, channels, width, height)
        x = self.features_extractor(inputs["states"].view(-1, *self.observation_space.shape).permute(0, 3, 1, 2))
        return self.net(torch.cat([x, inputs["taken_actions"]], dim=1)), {}


# instantiate the model (assumes there is a wrapped environment: env)
critic = CNN(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             clip_actions=False)
import torch
import torch.nn as nn
import torch.nn.functional as F

from skrl.models.torch import Model, DeterministicMixin


# define the model
class CNN(DeterministicMixin, Model):
    """Deterministic critic for image observations (functional-API variant): CNN features + MLP head."""

    def __init__(self, observation_space, action_space, device, clip_actions=False):
        # the Model base class must be initialized before the mixin
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)

        # assumes 3-channel image observations (channels-last in the observation space)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 16)
        self.fc3 = nn.Linear(16 + self.num_actions, 64)
        self.fc4 = nn.Linear(64, 32)
        self.fc5 = nn.Linear(32, 1)

    def compute(self, inputs, role):
        # permute (samples, width * height * channels) -> (samples, channels, width, height)
        x = inputs["states"].view(-1, *self.observation_space.shape).permute(0, 3, 1, 2)
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = torch.tanh(x)
        # concatenate the extracted features with the taken actions
        x = self.fc3(torch.cat([x, inputs["taken_actions"]], dim=1))
        x = torch.tanh(x)
        x = self.fc4(x)
        x = torch.tanh(x)
        x = self.fc5(x)
        return x, {}


# instantiate the model (assumes there is a wrapped environment: env)
critic = CNN(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             clip_actions=False)
where:
The following points are relevant in the definition of recurrent models:
The
.get_specification()
method must be overwritten to return, under a dictionary key"rnn"
, a sub-dictionary that includes the sequence length (under key"sequence_length"
) as a number and a list of the dimensions (under key"sizes"
) of each initial hidden stateThe
.compute()
method’sinputs
parameter will have, at least, the following items in the dictionary:"states"
: state of the environment used to make the decision"taken_actions"
: actions taken by the policy for the given states, if applicable"terminated"
: episode termination status for sampled environment transitions. This key is only defined during the training process"rnn"
: list of initial hidden states ordered according to the model specification
The
.compute()
method must include, under the "rnn"
key of the returned dictionary, a list of each final hidden state
import torch
import torch.nn as nn

from skrl.models.torch import Model, DeterministicMixin


# define the model
class RNN(DeterministicMixin, Model):
    """Recurrent deterministic critic: vanilla RNN over state sequences, MLP head over (features, actions)."""

    def __init__(self, observation_space, action_space, device, clip_actions=False,
                 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
        # the Model base class must be initialized before the mixin
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hout
        self.sequence_length = sequence_length

        self.rnn = nn.RNN(input_size=self.num_observations,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          batch_first=True)  # batch_first -> (batch, sequence, features)

        self.net = nn.Sequential(nn.Linear(self.hidden_size + self.num_actions, 64),
                                 nn.ReLU(),
                                 nn.Linear(64, 32),
                                 nn.ReLU(),
                                 nn.Linear(32, 1))

    def get_specification(self):
        # batch size (N) is the number of envs during rollout
        return {"rnn": {"sequence_length": self.sequence_length,
                        "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}}  # hidden states (D * num_layers, N, Hout)

    def compute(self, inputs, role):
        states = inputs["states"]
        terminated = inputs.get("terminated", None)
        hidden_states = inputs["rnn"][0]

        # critic models are only used during training
        rnn_input = states.view(-1, self.sequence_length, states.shape[-1])  # (N, L, Hin): N=batch_size, L=sequence_length

        hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1])  # (D * num_layers, N, L, Hout)
        # get the hidden states corresponding to the initial sequence
        sequence_index = 1 if role == "target_critic" else 0  # target networks act on the next state of the environment
        hidden_states = hidden_states[:, :, sequence_index, :].contiguous()  # (D * num_layers, N, Hout)

        # reset the RNN state in the middle of a sequence
        if terminated is not None and torch.any(terminated):
            rnn_outputs = []
            terminated = terminated.view(-1, self.sequence_length)
            # split the sequence at every step where at least one environment terminated
            indexes = [0] + (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]

            for i in range(len(indexes) - 1):
                i0, i1 = indexes[i], indexes[i + 1]
                rnn_output, hidden_states = self.rnn(rnn_input[:, i0:i1, :], hidden_states)
                # zero the hidden state of environments that terminated at the end of this sub-sequence
                hidden_states[:, (terminated[:, i1 - 1]), :] = 0
                rnn_outputs.append(rnn_output)

            rnn_output = torch.cat(rnn_outputs, dim=1)
        # no need to reset the RNN state in the sequence
        else:
            rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D * Hout) -> (N * L, D * Hout)

        return self.net(torch.cat([rnn_output, inputs["taken_actions"]], dim=1)), {"rnn": [hidden_states]}


# instantiate the model (assumes there is a wrapped environment: env)
critic = RNN(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             clip_actions=False,
             num_envs=env.num_envs,
             num_layers=1,
             hidden_size=64,
             sequence_length=10)
import torch
import torch.nn as nn
import torch.nn.functional as F

from skrl.models.torch import Model, DeterministicMixin


# define the model
class RNN(DeterministicMixin, Model):
    """Recurrent deterministic critic (functional-API variant): vanilla RNN + fully-connected head."""

    def __init__(self, observation_space, action_space, device, clip_actions=False,
                 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
        # the Model base class must be initialized before the mixin
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hout
        self.sequence_length = sequence_length

        self.rnn = nn.RNN(input_size=self.num_observations,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          batch_first=True)  # batch_first -> (batch, sequence, features)

        self.fc1 = nn.Linear(self.hidden_size + self.num_actions, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)

    def get_specification(self):
        # batch size (N) is the number of envs during rollout
        return {"rnn": {"sequence_length": self.sequence_length,
                        "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}}  # hidden states (D * num_layers, N, Hout)

    def compute(self, inputs, role):
        states = inputs["states"]
        terminated = inputs.get("terminated", None)
        hidden_states = inputs["rnn"][0]

        # critic models are only used during training
        rnn_input = states.view(-1, self.sequence_length, states.shape[-1])  # (N, L, Hin): N=batch_size, L=sequence_length

        hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1])  # (D * num_layers, N, L, Hout)
        # get the hidden states corresponding to the initial sequence
        sequence_index = 1 if role == "target_critic" else 0  # target networks act on the next state of the environment
        hidden_states = hidden_states[:, :, sequence_index, :].contiguous()  # (D * num_layers, N, Hout)

        # reset the RNN state in the middle of a sequence
        if terminated is not None and torch.any(terminated):
            rnn_outputs = []
            terminated = terminated.view(-1, self.sequence_length)
            # split the sequence at every step where at least one environment terminated
            indexes = [0] + (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]

            for i in range(len(indexes) - 1):
                i0, i1 = indexes[i], indexes[i + 1]
                rnn_output, hidden_states = self.rnn(rnn_input[:, i0:i1, :], hidden_states)
                # zero the hidden state of environments that terminated at the end of this sub-sequence
                hidden_states[:, (terminated[:, i1 - 1]), :] = 0
                rnn_outputs.append(rnn_output)

            rnn_output = torch.cat(rnn_outputs, dim=1)
        # no need to reset the RNN state in the sequence
        else:
            rnn_output, hidden_states = self.rnn(rnn_input, hidden_states)

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D * Hout) -> (N * L, D * Hout)

        x = self.fc1(torch.cat([rnn_output, inputs["taken_actions"]], dim=1))
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)

        return self.fc3(x), {"rnn": [hidden_states]}


# instantiate the model (assumes there is a wrapped environment: env)
critic = RNN(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             clip_actions=False,
             num_envs=env.num_envs,
             num_layers=1,
             hidden_size=64,
             sequence_length=10)
where:
The following points are relevant in the definition of recurrent models:
The
.get_specification()
method must be overwritten to return, under a dictionary key"rnn"
, a sub-dictionary that includes the sequence length (under key"sequence_length"
) as a number and a list of the dimensions (under key"sizes"
) of each initial hidden stateThe
.compute()
method’sinputs
parameter will have, at least, the following items in the dictionary:"states"
: state of the environment used to make the decision"taken_actions"
: actions taken by the policy for the given states, if applicable"terminated"
: episode termination status for sampled environment transitions. This key is only defined during the training process"rnn"
: list of initial hidden states ordered according to the model specification
The
.compute()
method must include, under the "rnn"
key of the returned dictionary, a list of each final hidden state
import torch
import torch.nn as nn

from skrl.models.torch import Model, DeterministicMixin


# define the model
class GRU(DeterministicMixin, Model):
    """Recurrent deterministic critic: GRU over state sequences, MLP head over (features, actions)."""

    def __init__(self, observation_space, action_space, device, clip_actions=False,
                 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
        # the Model base class must be initialized before the mixin
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hout
        self.sequence_length = sequence_length

        self.gru = nn.GRU(input_size=self.num_observations,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          batch_first=True)  # batch_first -> (batch, sequence, features)

        self.net = nn.Sequential(nn.Linear(self.hidden_size + self.num_actions, 64),
                                 nn.ReLU(),
                                 nn.Linear(64, 32),
                                 nn.ReLU(),
                                 nn.Linear(32, 1))

    def get_specification(self):
        # batch size (N) is the number of envs during rollout
        return {"rnn": {"sequence_length": self.sequence_length,
                        "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}}  # hidden states (D * num_layers, N, Hout)

    def compute(self, inputs, role):
        states = inputs["states"]
        terminated = inputs.get("terminated", None)
        hidden_states = inputs["rnn"][0]

        # critic models are only used during training
        rnn_input = states.view(-1, self.sequence_length, states.shape[-1])  # (N, L, Hin): N=batch_size, L=sequence_length

        hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1])  # (D * num_layers, N, L, Hout)
        # get the hidden states corresponding to the initial sequence
        sequence_index = 1 if role == "target_critic" else 0  # target networks act on the next state of the environment
        hidden_states = hidden_states[:, :, sequence_index, :].contiguous()  # (D * num_layers, N, Hout)

        # reset the RNN state in the middle of a sequence
        if terminated is not None and torch.any(terminated):
            rnn_outputs = []
            terminated = terminated.view(-1, self.sequence_length)
            # split the sequence at every step where at least one environment terminated
            indexes = [0] + (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]

            for i in range(len(indexes) - 1):
                i0, i1 = indexes[i], indexes[i + 1]
                rnn_output, hidden_states = self.gru(rnn_input[:, i0:i1, :], hidden_states)
                # zero the hidden state of environments that terminated at the end of this sub-sequence
                hidden_states[:, (terminated[:, i1 - 1]), :] = 0
                rnn_outputs.append(rnn_output)

            rnn_output = torch.cat(rnn_outputs, dim=1)
        # no need to reset the RNN state in the sequence
        else:
            rnn_output, hidden_states = self.gru(rnn_input, hidden_states)

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D * Hout) -> (N * L, D * Hout)

        return self.net(torch.cat([rnn_output, inputs["taken_actions"]], dim=1)), {"rnn": [hidden_states]}


# instantiate the model (assumes there is a wrapped environment: env)
critic = GRU(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             clip_actions=False,
             num_envs=env.num_envs,
             num_layers=1,
             hidden_size=64,
             sequence_length=10)
import torch
import torch.nn as nn
import torch.nn.functional as F

from skrl.models.torch import Model, DeterministicMixin


# define the model
class GRU(DeterministicMixin, Model):
    """Recurrent deterministic critic (functional-API variant): GRU + fully-connected head."""

    def __init__(self, observation_space, action_space, device, clip_actions=False,
                 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
        # the Model base class must be initialized before the mixin
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hout
        self.sequence_length = sequence_length

        self.gru = nn.GRU(input_size=self.num_observations,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          batch_first=True)  # batch_first -> (batch, sequence, features)

        self.fc1 = nn.Linear(self.hidden_size + self.num_actions, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)

    def get_specification(self):
        # batch size (N) is the number of envs during rollout
        return {"rnn": {"sequence_length": self.sequence_length,
                        "sizes": [(self.num_layers, self.num_envs, self.hidden_size)]}}  # hidden states (D * num_layers, N, Hout)

    def compute(self, inputs, role):
        states = inputs["states"]
        terminated = inputs.get("terminated", None)
        hidden_states = inputs["rnn"][0]

        # critic models are only used during training
        rnn_input = states.view(-1, self.sequence_length, states.shape[-1])  # (N, L, Hin): N=batch_size, L=sequence_length

        hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1])  # (D * num_layers, N, L, Hout)
        # get the hidden states corresponding to the initial sequence
        sequence_index = 1 if role == "target_critic" else 0  # target networks act on the next state of the environment
        hidden_states = hidden_states[:, :, sequence_index, :].contiguous()  # (D * num_layers, N, Hout)

        # reset the RNN state in the middle of a sequence
        if terminated is not None and torch.any(terminated):
            rnn_outputs = []
            terminated = terminated.view(-1, self.sequence_length)
            # split the sequence at every step where at least one environment terminated
            indexes = [0] + (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]

            for i in range(len(indexes) - 1):
                i0, i1 = indexes[i], indexes[i + 1]
                rnn_output, hidden_states = self.gru(rnn_input[:, i0:i1, :], hidden_states)
                # zero the hidden state of environments that terminated at the end of this sub-sequence
                hidden_states[:, (terminated[:, i1 - 1]), :] = 0
                rnn_outputs.append(rnn_output)

            rnn_output = torch.cat(rnn_outputs, dim=1)
        # no need to reset the RNN state in the sequence
        else:
            rnn_output, hidden_states = self.gru(rnn_input, hidden_states)

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D * Hout) -> (N * L, D * Hout)

        x = self.fc1(torch.cat([rnn_output, inputs["taken_actions"]], dim=1))
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)

        return self.fc3(x), {"rnn": [hidden_states]}


# instantiate the model (assumes there is a wrapped environment: env)
critic = GRU(observation_space=env.observation_space,
             action_space=env.action_space,
             device=env.device,
             clip_actions=False,
             num_envs=env.num_envs,
             num_layers=1,
             hidden_size=64,
             sequence_length=10)
where:
The following points are relevant in the definition of recurrent models:
The
.get_specification()
method must be overwritten to return, under a dictionary key"rnn"
, a sub-dictionary that includes the sequence length (under key"sequence_length"
) as a number and a list of the dimensions (under key"sizes"
) of each initial hidden/cell statesThe
.compute()
method’sinputs
parameter will have, at least, the following items in the dictionary:"states"
: state of the environment used to make the decision"taken_actions"
: actions taken by the policy for the given states, if applicable"terminated"
: episode termination status for sampled environment transitions. This key is only defined during the training process"rnn"
: list of initial hidden/cell states ordered according to the model specification
The
.compute()
method must include, under the "rnn"
key of the returned dictionary, a list of each final hidden/cell states
import torch
import torch.nn as nn

from skrl.models.torch import Model, DeterministicMixin


# define the model
class LSTM(DeterministicMixin, Model):
    """Recurrent deterministic critic: LSTM over state sequences, MLP head over (features, actions)."""

    def __init__(self, observation_space, action_space, device, clip_actions=False,
                 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
        # the Model base class must be initialized before the mixin
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hcell (Hout is Hcell because proj_size = 0)
        self.sequence_length = sequence_length

        self.lstm = nn.LSTM(input_size=self.num_observations,
                            hidden_size=self.hidden_size,
                            num_layers=self.num_layers,
                            batch_first=True)  # batch_first -> (batch, sequence, features)

        self.net = nn.Sequential(nn.Linear(self.hidden_size + self.num_actions, 64),
                                 nn.ReLU(),
                                 nn.Linear(64, 32),
                                 nn.ReLU(),
                                 nn.Linear(32, 1))

    def get_specification(self):
        # batch size (N) is the number of envs during rollout
        return {"rnn": {"sequence_length": self.sequence_length,
                        "sizes": [(self.num_layers, self.num_envs, self.hidden_size),    # hidden states (D * num_layers, N, Hout)
                                  (self.num_layers, self.num_envs, self.hidden_size)]}}  # cell states (D * num_layers, N, Hcell)

    def compute(self, inputs, role):
        states = inputs["states"]
        terminated = inputs.get("terminated", None)
        hidden_states, cell_states = inputs["rnn"][0], inputs["rnn"][1]

        # critic models are only used during training
        rnn_input = states.view(-1, self.sequence_length, states.shape[-1])  # (N, L, Hin): N=batch_size, L=sequence_length

        hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1])  # (D * num_layers, N, L, Hout)
        cell_states = cell_states.view(self.num_layers, -1, self.sequence_length, cell_states.shape[-1])  # (D * num_layers, N, L, Hcell)
        # get the hidden/cell states corresponding to the initial sequence
        sequence_index = 1 if role == "target_critic" else 0  # target networks act on the next state of the environment
        hidden_states = hidden_states[:, :, sequence_index, :].contiguous()  # (D * num_layers, N, Hout)
        cell_states = cell_states[:, :, sequence_index, :].contiguous()  # (D * num_layers, N, Hcell)

        # reset the RNN state in the middle of a sequence
        if terminated is not None and torch.any(terminated):
            rnn_outputs = []
            terminated = terminated.view(-1, self.sequence_length)
            # split the sequence at every step where at least one environment terminated
            indexes = [0] + (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]

            for i in range(len(indexes) - 1):
                i0, i1 = indexes[i], indexes[i + 1]
                rnn_output, (hidden_states, cell_states) = self.lstm(rnn_input[:, i0:i1, :], (hidden_states, cell_states))
                # zero the hidden/cell states of environments that terminated at the end of this sub-sequence
                hidden_states[:, (terminated[:, i1 - 1]), :] = 0
                cell_states[:, (terminated[:, i1 - 1]), :] = 0
                rnn_outputs.append(rnn_output)

            rnn_states = (hidden_states, cell_states)
            rnn_output = torch.cat(rnn_outputs, dim=1)
        # no need to reset the RNN state in the sequence
        else:
            rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D * Hout) -> (N * L, D * Hout)

        return self.net(torch.cat([rnn_output, inputs["taken_actions"]], dim=1)), {"rnn": [rnn_states[0], rnn_states[1]]}


# instantiate the model (assumes there is a wrapped environment: env)
critic = LSTM(observation_space=env.observation_space,
              action_space=env.action_space,
              device=env.device,
              clip_actions=False,
              num_envs=env.num_envs,
              num_layers=1,
              hidden_size=64,
              sequence_length=10)
import torch
import torch.nn as nn
import torch.nn.functional as F

from skrl.models.torch import Model, DeterministicMixin


# define the model
class LSTM(DeterministicMixin, Model):
    """Recurrent deterministic critic (functional-API variant): LSTM + fully-connected head."""

    def __init__(self, observation_space, action_space, device, clip_actions=False,
                 num_envs=1, num_layers=1, hidden_size=64, sequence_length=10):
        # the Model base class must be initialized before the mixin
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)

        self.num_envs = num_envs
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # Hcell (Hout is Hcell because proj_size = 0)
        self.sequence_length = sequence_length

        self.lstm = nn.LSTM(input_size=self.num_observations,
                            hidden_size=self.hidden_size,
                            num_layers=self.num_layers,
                            batch_first=True)  # batch_first -> (batch, sequence, features)

        self.fc1 = nn.Linear(self.hidden_size + self.num_actions, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)

    def get_specification(self):
        # batch size (N) is the number of envs during rollout
        return {"rnn": {"sequence_length": self.sequence_length,
                        "sizes": [(self.num_layers, self.num_envs, self.hidden_size),    # hidden states (D * num_layers, N, Hout)
                                  (self.num_layers, self.num_envs, self.hidden_size)]}}  # cell states (D * num_layers, N, Hcell)

    def compute(self, inputs, role):
        states = inputs["states"]
        terminated = inputs.get("terminated", None)
        hidden_states, cell_states = inputs["rnn"][0], inputs["rnn"][1]

        # critic models are only used during training
        rnn_input = states.view(-1, self.sequence_length, states.shape[-1])  # (N, L, Hin): N=batch_size, L=sequence_length

        hidden_states = hidden_states.view(self.num_layers, -1, self.sequence_length, hidden_states.shape[-1])  # (D * num_layers, N, L, Hout)
        cell_states = cell_states.view(self.num_layers, -1, self.sequence_length, cell_states.shape[-1])  # (D * num_layers, N, L, Hcell)
        # get the hidden/cell states corresponding to the initial sequence
        sequence_index = 1 if role == "target_critic" else 0  # target networks act on the next state of the environment
        hidden_states = hidden_states[:, :, sequence_index, :].contiguous()  # (D * num_layers, N, Hout)
        cell_states = cell_states[:, :, sequence_index, :].contiguous()  # (D * num_layers, N, Hcell)

        # reset the RNN state in the middle of a sequence
        if terminated is not None and torch.any(terminated):
            rnn_outputs = []
            terminated = terminated.view(-1, self.sequence_length)
            # split the sequence at every step where at least one environment terminated
            indexes = [0] + (terminated[:, :-1].any(dim=0).nonzero(as_tuple=True)[0] + 1).tolist() + [self.sequence_length]

            for i in range(len(indexes) - 1):
                i0, i1 = indexes[i], indexes[i + 1]
                rnn_output, (hidden_states, cell_states) = self.lstm(rnn_input[:, i0:i1, :], (hidden_states, cell_states))
                # zero the hidden/cell states of environments that terminated at the end of this sub-sequence
                hidden_states[:, (terminated[:, i1 - 1]), :] = 0
                cell_states[:, (terminated[:, i1 - 1]), :] = 0
                rnn_outputs.append(rnn_output)

            rnn_states = (hidden_states, cell_states)
            rnn_output = torch.cat(rnn_outputs, dim=1)
        # no need to reset the RNN state in the sequence
        else:
            rnn_output, rnn_states = self.lstm(rnn_input, (hidden_states, cell_states))

        # flatten the RNN output
        rnn_output = torch.flatten(rnn_output, start_dim=0, end_dim=1)  # (N, L, D * Hout) -> (N * L, D * Hout)

        x = self.fc1(torch.cat([rnn_output, inputs["taken_actions"]], dim=1))
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)

        return self.fc3(x), {"rnn": [rnn_states[0], rnn_states[1]]}


# instantiate the model (assumes there is a wrapped environment: env)
critic = LSTM(observation_space=env.observation_space,
              action_space=env.action_space,
              device=env.device,
              clip_actions=False,
              num_envs=env.num_envs,
              num_layers=1,
              hidden_size=64,
              sequence_length=10)
API
- class skrl.models.torch.deterministic.DeterministicMixin(clip_actions: bool = False, role: str = '')
Bases:
object
- __init__(clip_actions: bool = False, role: str = '') None
Deterministic mixin model (deterministic model)
- Parameters
Example:
# define the model >>> import torch >>> import torch.nn as nn >>> from skrl.models.torch import Model, DeterministicMixin >>> >>> class Value(DeterministicMixin, Model): ... def __init__(self, observation_space, action_space, device="cuda:0", clip_actions=False): ... Model.__init__(self, observation_space, action_space, device) ... DeterministicMixin.__init__(self, clip_actions) ... ... self.net = nn.Sequential(nn.Linear(self.num_observations, 32), ... nn.ELU(), ... nn.Linear(32, 32), ... nn.ELU(), ... nn.Linear(32, 1)) ... ... def compute(self, inputs, role): ... return self.net(inputs["states"]), {} ... >>> # given an observation_space: gym.spaces.Box with shape (60,) >>> # and an action_space: gym.spaces.Box with shape (8,) >>> model = Value(observation_space, action_space) >>> >>> print(model) Value( (net): Sequential( (0): Linear(in_features=60, out_features=32, bias=True) (1): ELU(alpha=1.0) (2): Linear(in_features=32, out_features=32, bias=True) (3): ELU(alpha=1.0) (4): Linear(in_features=32, out_features=1, bias=True) ) )
- act(inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = '') Tuple[torch.Tensor, Optional[torch.Tensor], Mapping[str, Union[torch.Tensor, Any]]]
Act deterministically in response to the state of the environment
- Parameters
inputs (dict where the values are typically torch.Tensor) –
Model inputs. The most common keys are:
"states"
: state of the environment used to make the decision"taken_actions"
: actions taken by the policy for the given states
role (str, optional) – Role play by the model (default:
""
)
- Returns
Model output. The first component is the action to be taken by the agent. The second component is
None
. The third component is a dictionary containing extra output values- Return type
tuple of torch.Tensor, torch.Tensor or None, and dictionary
Example:
>>> # given a batch of sample states with shape (4096, 60) >>> actions, _, outputs = model.act({"states": states}) >>> print(actions.shape, outputs) torch.Size([4096, 1]) {}