Multi-Agent Proximal Policy Optimization (MAPPO)¶
MAPPO is a model-free, on-policy, stochastic policy-gradient multi-agent algorithm that follows the centralized training, decentralized execution (CTDE) paradigm: a centralized value function, conditioned on the shared (global) observations, estimates a single value that guides the policy updates of all agents, improving coordination and cooperation among them.
Paper: The Surprising Effectiveness of PPO in Cooperative, Multi-Agent Games
Algorithm¶
Algorithm implementation¶
Learning algorithm¶
compute_gae(...)
_update(...)
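Returns and advantages are estimated with Generalized Advantage Estimation (GAE) before each policy and value update, using the values produced by the centralized value function. The following is a minimal, illustrative PyTorch sketch of such a computation; the function name, argument names and tensor shapes are assumptions for illustration, not the library's exact internals.

import torch

def compute_gae(rewards, dones, values, next_values, discount_factor=0.99, lambda_coefficient=0.95):
    # Generalized Advantage Estimation (GAE), computed backwards in time.
    # Assumed shapes: rewards, dones, values -> (memory_size, num_envs, 1)
    # next_values: value estimate for the state following the last stored transition
    advantage = 0
    advantages = torch.zeros_like(rewards)
    not_dones = dones.logical_not()
    memory_size = rewards.shape[0]
    for i in reversed(range(memory_size)):
        next_value = values[i + 1] if i < memory_size - 1 else next_values
        # A_t = delta_t + gamma * lambda * (1 - d_t) * A_{t+1},
        # with delta_t = r_t + gamma * (1 - d_t) * V_{t+1} - V_t
        advantage = rewards[i] - values[i] + discount_factor * not_dones[i] * (next_value + lambda_coefficient * advantage)
        advantages[i] = advantage
    # returns are used as value-function targets
    returns = advantages + values
    # normalize advantages for the surrogate objective
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return returns, advantages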
Usage¶
# import the agent and its default configuration
from skrl.multi_agents.torch.mappo import MAPPO, MAPPO_DEFAULT_CONFIG

# instantiate the agent's models
models = {}
for agent_name in env.possible_agents:
    models[agent_name] = {}
    models[agent_name]["policy"] = ...
    models[agent_name]["value"] = ...  # only required during training

# adjust some configuration if necessary
cfg_agent = MAPPO_DEFAULT_CONFIG.copy()
cfg_agent["<KEY>"] = ...

# instantiate the agent
# (assuming a defined environment <env> and memories <memories>)
agent = MAPPO(possible_agents=env.possible_agents,
              models=models,
              memories=memories,  # only required during training
              cfg=cfg_agent,
              observation_spaces=env.observation_spaces,
              action_spaces=env.action_spaces,
              device=env.device,
              shared_observation_spaces=env.shared_observation_spaces)
# import the agent and its default configuration
from skrl.multi_agents.jax.mappo import MAPPO, MAPPO_DEFAULT_CONFIG

# instantiate the agent's models
models = {}
for agent_name in env.possible_agents:
    models[agent_name] = {}
    models[agent_name]["policy"] = ...
    models[agent_name]["value"] = ...  # only required during training

# adjust some configuration if necessary
cfg_agent = MAPPO_DEFAULT_CONFIG.copy()
cfg_agent["<KEY>"] = ...

# instantiate the agent
# (assuming a defined environment <env> and memories <memories>)
agent = MAPPO(possible_agents=env.possible_agents,
              models=models,
              memories=memories,  # only required during training
              cfg=cfg_agent,
              observation_spaces=env.observation_spaces,
              action_spaces=env.action_spaces,
              device=env.device,
              shared_observation_spaces=env.shared_observation_spaces)
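For completeness, the following is a minimal sketch of how the memories and a trainer could be set up around the snippets above (PyTorch shown; the memory size, number of timesteps and the use of SequentialTrainer are illustrative choices, and an analogous setup applies for JAX with skrl.memories.jax and skrl.trainers.jax).

from skrl.memories.torch import RandomMemory
from skrl.trainers.torch import SequentialTrainer

# one memory per agent, sized to hold a full rollout (rollouts=16 in the default configuration)
memories = {}
for agent_name in env.possible_agents:
    memories[agent_name] = RandomMemory(memory_size=16, num_envs=env.num_envs, device=env.device)

# train the agent with a sequential trainer
trainer = SequentialTrainer(env=env, agents=agent, cfg={"timesteps": 100000})
trainer.train()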
Configuration and hyperparameters¶
Note
Specifying a single value for a configuration key automatically extends it to all involved agents. To configure each agent individually, provide a dictionary keyed by agent name instead. For example:
# specify a configuration value for each agent (agent names depend on environment)
cfg["discount_factor"] = {"agent_0": 0.99, "agent_1": 0.995, "agent_2": 0.985}
MAPPO_DEFAULT_CONFIG = {
    "rollouts": 16,                 # number of rollouts before updating
    "learning_epochs": 8,           # number of learning epochs during each update
    "mini_batches": 2,              # number of mini batches during each learning epoch

    "discount_factor": 0.99,        # discount factor (gamma)
    "lambda": 0.95,                 # TD(lambda) coefficient (lam) for computing returns and advantages

    "learning_rate": 1e-3,                  # learning rate
    "learning_rate_scheduler": None,        # learning rate scheduler class (see torch.optim.lr_scheduler)
    "learning_rate_scheduler_kwargs": {},   # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})

    "state_preprocessor": None,             # state preprocessor class (see skrl.resources.preprocessors)
    "state_preprocessor_kwargs": {},        # state preprocessor's kwargs (e.g. {"size": env.observation_space})
    "shared_state_preprocessor": None,      # shared state preprocessor class (see skrl.resources.preprocessors)
    "shared_state_preprocessor_kwargs": {}, # shared state preprocessor's kwargs (e.g. {"size": env.shared_observation_space})
    "value_preprocessor": None,             # value preprocessor class (see skrl.resources.preprocessors)
    "value_preprocessor_kwargs": {},        # value preprocessor's kwargs (e.g. {"size": 1})

    "random_timesteps": 0,          # random exploration steps
    "learning_starts": 0,           # learning starts after this many steps

    "grad_norm_clip": 0.5,              # clipping coefficient for the norm of the gradients
    "ratio_clip": 0.2,                  # clipping coefficient for computing the clipped surrogate objective
    "value_clip": 0.2,                  # clipping coefficient for computing the value loss (if clip_predicted_values is True)
    "clip_predicted_values": False,     # clip predicted values during value loss computation

    "entropy_loss_scale": 0.0,      # entropy loss scaling factor
    "value_loss_scale": 1.0,        # value loss scaling factor

    "kl_threshold": 0,              # KL divergence threshold for early stopping

    "rewards_shaper": None,         # rewards shaping function: Callable(reward, timestep, timesteps) -> reward
    "time_limit_bootstrap": False,  # bootstrap at timeout termination (episode truncation)

    "experiment": {
        "directory": "",                # experiment's parent directory
        "experiment_name": "",          # experiment name
        "write_interval": 250,          # TensorBoard writing interval (timesteps)

        "checkpoint_interval": 1000,    # interval for checkpoints (timesteps)
        "store_separately": False,      # whether to store checkpoints separately

        "wandb": False,                 # whether to use Weights & Biases
        "wandb_kwargs": {}              # wandb kwargs (see https://docs.wandb.ai/ref/python/init)
    }
}
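As a concrete illustration, a few common overrides are shown below (PyTorch variants assumed; KLAdaptiveLR and RunningStandardScaler are classes from skrl's resources, and the specific values are arbitrary examples, not recommendations):

from skrl.resources.schedulers.torch import KLAdaptiveLR
from skrl.resources.preprocessors.torch import RunningStandardScaler

cfg_agent = MAPPO_DEFAULT_CONFIG.copy()
cfg_agent["rollouts"] = 32                      # single value: applied to all agents
cfg_agent["learning_rate_scheduler"] = KLAdaptiveLR
cfg_agent["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.008}
cfg_agent["value_preprocessor"] = RunningStandardScaler
cfg_agent["value_preprocessor_kwargs"] = {"size": 1, "device": env.device}
# per-agent override via a dictionary keyed by agent name (names depend on the environment)
cfg_agent["discount_factor"] = {"agent_0": 0.99, "agent_1": 0.995, "agent_2": 0.985}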
Spaces¶
The implementation supports the following Gym spaces / Gymnasium spaces:

Gym/Gymnasium spaces | Observation | Action
---|---|---
Discrete | \(\square\) | \(\blacksquare\)
MultiDiscrete | \(\square\) | \(\blacksquare\)
Box | \(\blacksquare\) | \(\blacksquare\)
Dict | \(\blacksquare\) | \(\square\)
Models¶
The implementation uses one stochastic (discrete or continuous) and one deterministic function approximator per agent. These function approximators (models) must be collected in a dictionary and passed to the constructor of the class under the argument models.
Notation | Concept | Key | Input shape | Output shape | Type
---|---|---|---|---|---
\(\pi_\theta(s)\) | Policy | "policy" | observation | action | Categorical / Multi-Categorical / Gaussian / MultivariateGaussian
\(V_\phi(s)\) | Value | "value" | observation | 1 | Deterministic
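The policy and value placeholders in the usage snippets can be implemented with skrl's model mixins. Below is a minimal PyTorch sketch; the network sizes are arbitrary, and building the value model from the shared (global) observation space, consistent with the centralized value function described above, is an assumption of this sketch.

import torch
import torch.nn as nn

from skrl.models.torch import DeterministicMixin, GaussianMixin, Model

class Policy(GaussianMixin, Model):
    def __init__(self, observation_space, action_space, device, clip_actions=False):
        Model.__init__(self, observation_space, action_space, device)
        GaussianMixin.__init__(self, clip_actions)
        self.net = nn.Sequential(nn.Linear(self.num_observations, 64), nn.ELU(),
                                 nn.Linear(64, 64), nn.ELU(),
                                 nn.Linear(64, self.num_actions))
        self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))

    def compute(self, inputs, role):
        # mean actions, log standard deviation and extra outputs
        return self.net(inputs["states"]), self.log_std_parameter, {}

class Value(DeterministicMixin, Model):
    def __init__(self, observation_space, action_space, device, clip_actions=False):
        Model.__init__(self, observation_space, action_space, device)
        DeterministicMixin.__init__(self, clip_actions)
        self.net = nn.Sequential(nn.Linear(self.num_observations, 64), nn.ELU(),
                                 nn.Linear(64, 64), nn.ELU(),
                                 nn.Linear(64, 1))

    def compute(self, inputs, role):
        return self.net(inputs["states"]), {}

models = {}
for agent_name in env.possible_agents:
    models[agent_name] = {
        "policy": Policy(env.observation_spaces[agent_name], env.action_spaces[agent_name], env.device),
        # centralized value function: built from the shared (global) observation space
        "value": Value(env.shared_observation_spaces[agent_name], env.action_spaces[agent_name], env.device),
    }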
Features¶
Support for advanced features is described in the next table:

Feature | Support and remarks | PyTorch | JAX
---|---|---|---
Shared model | for Policy and Value | \(\blacksquare\) | \(\square\)
RNN support | - | \(\square\) | \(\square\)
Distributed | Single Program Multi Data (SPMD) multi-GPU | \(\blacksquare\) | \(\square\)
API (PyTorch)¶
- skrl.multi_agents.torch.mappo.MAPPO_DEFAULT_CONFIG¶
alias of {‘clip_predicted_values’: False, ‘discount_factor’: 0.99, ‘entropy_loss_scale’: 0.0, ‘experiment’: {‘checkpoint_interval’: 1000, ‘directory’: ‘’, ‘experiment_name’: ‘’, ‘store_separately’: False, ‘wandb’: False, ‘wandb_kwargs’: {}, ‘write_interval’: 250}, ‘grad_norm_clip’: 0.5, ‘kl_threshold’: 0, ‘lambda’: 0.95, ‘learning_epochs’: 8, ‘learning_rate’: 0.001, ‘learning_rate_scheduler’: None, ‘learning_rate_scheduler_kwargs’: {}, ‘learning_starts’: 0, ‘mini_batches’: 2, ‘random_timesteps’: 0, ‘ratio_clip’: 0.2, ‘rewards_shaper’: None, ‘rollouts’: 16, ‘shared_state_preprocessor’: None, ‘shared_state_preprocessor_kwargs’: {}, ‘state_preprocessor’: None, ‘state_preprocessor_kwargs’: {}, ‘time_limit_bootstrap’: False, ‘value_clip’: 0.2, ‘value_loss_scale’: 1.0, ‘value_preprocessor’: None, ‘value_preprocessor_kwargs’: {}}
- class skrl.multi_agents.torch.mappo.MAPPO(possible_agents: Sequence[str], models: Mapping[str, Model], memories: Mapping[str, Memory] | None = None, observation_spaces: Mapping[str, int] | Mapping[str, gym.Space] | Mapping[str, gymnasium.Space] | None = None, action_spaces: Mapping[str, int] | Mapping[str, gym.Space] | Mapping[str, gymnasium.Space] | None = None, device: str | torch.device | None = None, cfg: dict | None = None, shared_observation_spaces: Mapping[str, int] | Mapping[str, gym.Space] | Mapping[str, gymnasium.Space] | None = None)¶
Bases:
MultiAgent
- __init__(possible_agents: Sequence[str], models: Mapping[str, Model], memories: Mapping[str, Memory] | None = None, observation_spaces: Mapping[str, int] | Mapping[str, gym.Space] | Mapping[str, gymnasium.Space] | None = None, action_spaces: Mapping[str, int] | Mapping[str, gym.Space] | Mapping[str, gymnasium.Space] | None = None, device: str | torch.device | None = None, cfg: dict | None = None, shared_observation_spaces: Mapping[str, int] | Mapping[str, gym.Space] | Mapping[str, gymnasium.Space] | None = None) None ¶
Multi-Agent Proximal Policy Optimization (MAPPO)
https://arxiv.org/abs/2103.01955
- Parameters:
  - possible_agents (list of str) – Names of all possible agents the environment could generate
  - models (nested dictionary of skrl.models.torch.Model) – Models used by the agents. External keys are environment agents' names. Internal keys are the models required by the algorithm
  - memories (dictionary of skrl.memory.torch.Memory, optional) – Memories to store the transitions
  - observation_spaces (dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional) – Observation/state spaces or shapes (default: None)
  - action_spaces (dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional) – Action spaces or shapes (default: None)
  - device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be "cuda" if available, otherwise "cpu"
  - cfg (dict) – Configuration dictionary
  - shared_observation_spaces (dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional) – Shared observation/state spaces or shapes (default: None)
- act(states: Mapping[str, torch.Tensor], timestep: int, timesteps: int) torch.Tensor ¶
Process the environment’s states to make a decision (actions) using the main policies
- Parameters:
states (dictionary of torch.Tensor) – Environment’s states
timestep (int) – Current timestep
timesteps (int) – Number of timesteps
- Returns:
  Actions
- Return type:
  torch.Tensor
- post_interaction(timestep: int, timesteps: int) None ¶
Callback called after the interaction with the environment
- pre_interaction(timestep: int, timesteps: int) None ¶
Callback called before the interaction with the environment
- record_transition(states: Mapping[str, torch.Tensor], actions: Mapping[str, torch.Tensor], rewards: Mapping[str, torch.Tensor], next_states: Mapping[str, torch.Tensor], terminated: Mapping[str, torch.Tensor], truncated: Mapping[str, torch.Tensor], infos: Mapping[str, Any], timestep: int, timesteps: int) None ¶
Record an environment transition in memory
- Parameters:
  - states (dictionary of torch.Tensor) – Observations/states of the environment used to make the decision
  - actions (dictionary of torch.Tensor) – Actions taken by the agents
  - rewards (dictionary of torch.Tensor) – Instant rewards achieved by the current actions
  - next_states (dictionary of torch.Tensor) – Next observations/states of the environment
  - terminated (dictionary of torch.Tensor) – Signals to indicate that episodes have terminated
  - truncated (dictionary of torch.Tensor) – Signals to indicate that episodes have been truncated
  - infos (dictionary of any supported type) – Additional information about the environment
  - timestep (int) – Current timestep
  - timesteps (int) – Number of timesteps
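These methods are normally driven by a trainer, but they can also be called directly. The following is a minimal sketch of such a loop; it assumes a wrapped multi-agent environment returning per-agent dictionaries, that agent.init() can be called without arguments to set up internal buffers, and that act() returns a tuple whose first element is the per-agent actions dictionary. Environment resets, logging and checkpointing handled by a real trainer are omitted.

import torch

agent.init()  # assumption: sets up internal memory tensors, as a trainer would

states, infos = env.reset()
timesteps = 100000

for timestep in range(timesteps):
    agent.pre_interaction(timestep=timestep, timesteps=timesteps)

    with torch.no_grad():
        outputs = agent.act(states, timestep=timestep, timesteps=timesteps)
        actions = outputs[0]  # assumption: first element holds the per-agent actions dictionary
        next_states, rewards, terminated, truncated, infos = env.step(actions)
        agent.record_transition(states=states, actions=actions, rewards=rewards,
                                next_states=next_states, terminated=terminated,
                                truncated=truncated, infos=infos,
                                timestep=timestep, timesteps=timesteps)

    # rollout accounting and learning updates are triggered here
    agent.post_interaction(timestep=timestep, timesteps=timesteps)

    states = next_states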
API (JAX)¶
- skrl.multi_agents.jax.mappo.MAPPO_DEFAULT_CONFIG¶
alias of {‘clip_predicted_values’: False, ‘discount_factor’: 0.99, ‘entropy_loss_scale’: 0.0, ‘experiment’: {‘checkpoint_interval’: 1000, ‘directory’: ‘’, ‘experiment_name’: ‘’, ‘store_separately’: False, ‘wandb’: False, ‘wandb_kwargs’: {}, ‘write_interval’: 250}, ‘grad_norm_clip’: 0.5, ‘kl_threshold’: 0, ‘lambda’: 0.95, ‘learning_epochs’: 8, ‘learning_rate’: 0.001, ‘learning_rate_scheduler’: None, ‘learning_rate_scheduler_kwargs’: {}, ‘learning_starts’: 0, ‘mini_batches’: 2, ‘random_timesteps’: 0, ‘ratio_clip’: 0.2, ‘rewards_shaper’: None, ‘rollouts’: 16, ‘shared_state_preprocessor’: None, ‘shared_state_preprocessor_kwargs’: {}, ‘state_preprocessor’: None, ‘state_preprocessor_kwargs’: {}, ‘time_limit_bootstrap’: False, ‘value_clip’: 0.2, ‘value_loss_scale’: 1.0, ‘value_preprocessor’: None, ‘value_preprocessor_kwargs’: {}}
- class skrl.multi_agents.jax.mappo.MAPPO(possible_agents: Sequence[str], models: Mapping[str, Model], memories: Mapping[str, Memory] | None = None, observation_spaces: Mapping[str, int] | Mapping[str, gym.Space] | Mapping[str, gymnasium.Space] | None = None, action_spaces: Mapping[str, int] | Mapping[str, gym.Space] | Mapping[str, gymnasium.Space] | None = None, device: str | jax.Device | None = None, cfg: dict | None = None, shared_observation_spaces: Mapping[str, int] | Mapping[str, gym.Space] | Mapping[str, gymnasium.Space] | None = None)¶
Bases:
MultiAgent
- __init__(possible_agents: Sequence[str], models: Mapping[str, Model], memories: Mapping[str, Memory] | None = None, observation_spaces: Mapping[str, int] | Mapping[str, gym.Space] | Mapping[str, gymnasium.Space] | None = None, action_spaces: Mapping[str, int] | Mapping[str, gym.Space] | Mapping[str, gymnasium.Space] | None = None, device: str | jax.Device | None = None, cfg: dict | None = None, shared_observation_spaces: Mapping[str, int] | Mapping[str, gym.Space] | Mapping[str, gymnasium.Space] | None = None) None ¶
Multi-Agent Proximal Policy Optimization (MAPPO)
https://arxiv.org/abs/2103.01955
- Parameters:
  - possible_agents (list of str) – Names of all possible agents the environment could generate
  - models (nested dictionary of skrl.models.jax.Model) – Models used by the agents. External keys are environment agents' names. Internal keys are the models required by the algorithm
  - memories (dictionary of skrl.memory.jax.Memory, optional) – Memories to store the transitions
  - observation_spaces (dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional) – Observation/state spaces or shapes (default: None)
  - action_spaces (dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional) – Action spaces or shapes (default: None)
  - device (str or jax.Device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be "cuda" if available, otherwise "cpu"
  - cfg (dict) – Configuration dictionary
  - shared_observation_spaces (dictionary of int, sequence of int, gym.Space or gymnasium.Space, optional) – Shared observation/state spaces or shapes (default: None)
- act(states: Mapping[str, ndarray | jax.Array], timestep: int, timesteps: int) ndarray | jax.Array ¶
Process the environment’s states to make a decision (actions) using the main policies
- post_interaction(timestep: int, timesteps: int) None ¶
Callback called after the interaction with the environment
- pre_interaction(timestep: int, timesteps: int) None ¶
Callback called before the interaction with the environment
- record_transition(states: Mapping[str, ndarray | jax.Array], actions: Mapping[str, ndarray | jax.Array], rewards: Mapping[str, ndarray | jax.Array], next_states: Mapping[str, ndarray | jax.Array], terminated: Mapping[str, ndarray | jax.Array], truncated: Mapping[str, ndarray | jax.Array], infos: Mapping[str, Any], timestep: int, timesteps: int) None ¶
Record an environment transition in memory
- Parameters:
  - states (dictionary of np.ndarray or jax.Array) – Observations/states of the environment used to make the decision
  - actions (dictionary of np.ndarray or jax.Array) – Actions taken by the agents
  - rewards (dictionary of np.ndarray or jax.Array) – Instant rewards achieved by the current actions
  - next_states (dictionary of np.ndarray or jax.Array) – Next observations/states of the environment
  - terminated (dictionary of np.ndarray or jax.Array) – Signals to indicate that episodes have terminated
  - truncated (dictionary of np.ndarray or jax.Array) – Signals to indicate that episodes have been truncated
  - infos (dictionary of any type supported by the environment) – Additional information about the environment
  - timestep (int) – Current timestep
  - timesteps (int) – Number of timesteps