Step trainer¶
Train agents by controlling the training/evaluation loop step by step.
Concept¶
(Figure: diagram of the step-by-step training/evaluation flow)
Usage¶
from skrl.trainers.torch import StepTrainer

# assuming there is an environment called 'env'
# and an agent or a list of agents called 'agents'

# create a step-by-step trainer
cfg = {"timesteps": 50000, "headless": False}
trainer = StepTrainer(env=env, agents=agents, cfg=cfg)

# train the agent(s)
for timestep in range(cfg["timesteps"]):
    trainer.train(timestep=timestep)

# evaluate the agent(s)
for timestep in range(cfg["timesteps"]):
    trainer.eval(timestep=timestep)
from skrl.trainers.jax import StepTrainer

# assuming there is an environment called 'env'
# and an agent or a list of agents called 'agents'

# create a step-by-step trainer
cfg = {"timesteps": 50000, "headless": False}
trainer = StepTrainer(env=env, agents=agents, cfg=cfg)

# train the agent(s)
for timestep in range(cfg["timesteps"]):
    trainer.train(timestep=timestep)

# evaluate the agent(s)
for timestep in range(cfg["timesteps"]):
    trainer.eval(timestep=timestep)
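Because each call to train() or eval() advances the environments by exactly one timestep and returns the resulting transition, custom logic can be interleaved between steps. A minimal sketch (PyTorch version, assuming trainer and cfg as defined above; the logging interval and the use of rewards.mean() are illustrative choices, not part of the API):

for timestep in range(cfg["timesteps"]):
    observations, rewards, terminated, truncated, infos = trainer.train(timestep=timestep)
    # custom per-step logic, e.g. simple reward logging (interval is illustrative)
    if timestep % 1000 == 0:
        print(f"timestep {timestep}: mean reward {rewards.mean().item():.3f}")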
Configuration¶
STEP_TRAINER_DEFAULT_CONFIG = {
    "timesteps": 100000,                # number of timesteps to train for
    "headless": False,                  # whether to use headless mode (no rendering)
    "disable_progressbar": False,       # whether to disable the progressbar. If None, disable on non-TTY
    "close_environment_at_exit": True,  # whether to close the environment on normal program termination
    "environment_info": "episode",      # key used to get and log environment info
    "stochastic_evaluation": False,     # whether to sample (stochastic) actions rather than taking the deterministic mean actions during evaluation
}
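Only the values to be overridden need to appear in the dictionary passed to the trainer (as in the usage example above); the trainer merges it with the defaults. An explicit copy-and-update, sketched below for PyTorch, makes the effective configuration visible (assuming 'env' and 'agents' as in the usage section):

from skrl.trainers.torch import StepTrainer
from skrl.trainers.torch.step import STEP_TRAINER_DEFAULT_CONFIG

cfg = STEP_TRAINER_DEFAULT_CONFIG.copy()                       # shallow copy of the flat defaults
cfg.update({"timesteps": 50000, "disable_progressbar": True})  # override only selected keys
trainer = StepTrainer(env=env, agents=agents, cfg=cfg)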
API (PyTorch)¶
- skrl.trainers.torch.step.STEP_TRAINER_DEFAULT_CONFIG¶
alias of {'close_environment_at_exit': True, 'disable_progressbar': False, 'environment_info': 'episode', 'headless': False, 'stochastic_evaluation': False, 'timesteps': 100000}
- class skrl.trainers.torch.step.StepTrainer(env: Wrapper, agents: Agent | List[Agent], agents_scope: List[int] | None = None, cfg: dict | None = None)¶
Bases: Trainer
Step-by-step trainer
Train agents by controlling the training/evaluation loop step by step
- Parameters:
env (skrl.envs.wrappers.torch.Wrapper) – Environment to train on
agents (Agent or list of Agent) – Agents to train
agents_scope (tuple or list of int, optional) – Number of environments for each agent to train on (default: None)
cfg (dict, optional) – Configuration dictionary (default: None). See STEP_TRAINER_DEFAULT_CONFIG for default values
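As a hedged illustration of agents_scope (the agent names and environment counts below are placeholders): when training multiple simultaneous agents on a vectorized environment, agents_scope assigns a number of the environment's parallel instances to each agent.

trainer = StepTrainer(env=env,                    # e.g. a wrapped environment with 512 parallel instances
                      agents=[agent_a, agent_b],  # two previously instantiated agents (placeholders)
                      agents_scope=[256, 256],    # number of environments assigned to each agent
                      cfg={"timesteps": 100000})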
- eval(timestep: int | None = None, timesteps: int | None = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]¶
Evaluate the agents sequentially
This method executes the following steps once:
- Compute actions (sequentially if num_simultaneous_agents > 1)
- Interact with the environments
- Render scene
- Reset environments
- Parameters:
timestep (int, optional) – Current timestep (default: None)
timesteps (int, optional) – Total number of timesteps (default: None)
- Returns:
Observation, reward, terminated, truncated, info
- Return type:
tuple of torch.Tensor and any other info
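Because eval() returns the transition of each call, a custom evaluation loop can accumulate returns directly. A minimal sketch, assuming the trainer and cfg from the usage section (num_envs and device are attributes of skrl environment wrappers):

import torch

total_rewards = torch.zeros(env.num_envs, device=env.device)
for timestep in range(cfg["timesteps"]):
    observations, rewards, terminated, truncated, infos = trainer.eval(timestep=timestep)
    total_rewards += rewards.flatten()  # one reward per parallel environment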
- multi_agent_eval() None ¶
Evaluate multi-agents
This method executes the following steps in a loop:
- Compute actions (sequentially)
- Interact with the environments
- Render scene
- Reset environments
- multi_agent_train() None ¶
Train multi-agents
This method executes the following steps in a loop:
- Pre-interaction
- Compute actions
- Interact with the environments
- Render scene
- Record transitions
- Post-interaction
- Reset environments
- single_agent_eval() None ¶
Evaluate agent
This method executes the following steps in a loop:
- Compute actions (sequentially)
- Interact with the environments
- Render scene
- Reset environments
- single_agent_train() None ¶
Train agent
This method executes the following steps in a loop:
- Pre-interaction
- Compute actions
- Interact with the environments
- Render scene
- Record transitions
- Post-interaction
- Reset environments
- train(timestep: int | None = None, timesteps: int | None = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]¶
Execute a training iteration
This method executes the following steps once:
- Pre-interaction (sequentially if num_simultaneous_agents > 1)
- Compute actions (sequentially if num_simultaneous_agents > 1)
- Interact with the environments
- Render scene
- Record transitions (sequentially if num_simultaneous_agents > 1)
- Post-interaction (sequentially if num_simultaneous_agents > 1)
- Reset environments
- Parameters:
timestep (int, optional) – Current timestep (default: None)
timesteps (int, optional) – Total number of timesteps (default: None)
- Returns:
Observation, reward, terminated, truncated, info
- Return type:
tuple of torch.Tensor and any other info
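The returned terminated and truncated tensors allow the outer loop to observe episode boundaries, even though the environments themselves are reset internally. A hedged sketch (the reaction to episode ends is a placeholder, not a skrl API):

import torch

for timestep in range(cfg["timesteps"]):
    observations, rewards, terminated, truncated, infos = trainer.train(timestep=timestep)
    # environments are reset internally; this only observes episode ends
    if torch.logical_or(terminated, truncated).any():
        pass  # e.g. log episode boundaries or update a curriculum (placeholder)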
API (JAX)¶
- skrl.trainers.jax.step.STEP_TRAINER_DEFAULT_CONFIG¶
alias of {'close_environment_at_exit': True, 'disable_progressbar': False, 'environment_info': 'episode', 'headless': False, 'stochastic_evaluation': False, 'timesteps': 100000}
- class skrl.trainers.jax.step.StepTrainer(env: Wrapper, agents: Agent | List[Agent], agents_scope: List[int] | None = None, cfg: dict | None = None)¶
Bases: Trainer
Step-by-step trainer
Train agents by controlling the training/evaluation loop step by step
- Parameters:
env (skrl.envs.wrappers.jax.Wrapper) – Environment to train on
agents (Agent or list of Agent) – Agents to train
agents_scope (tuple or list of int, optional) – Number of environments for each agent to train on (default: None)
cfg (dict, optional) – Configuration dictionary (default: None). See STEP_TRAINER_DEFAULT_CONFIG for default values
- eval(timestep: int | None = None, timesteps: int | None = None) → Tuple[ndarray | jax.Array, ndarray | jax.Array, ndarray | jax.Array, ndarray | jax.Array, Any]¶
Evaluate the agents sequentially
This method executes the following steps once:
- Compute actions (sequentially if num_simultaneous_agents > 1)
- Interact with the environments
- Render scene
- Reset environments
- Parameters:
timestep (int, optional) – Current timestep (default: None)
timesteps (int, optional) – Total number of timesteps (default: None)
- Returns:
Observation, reward, terminated, truncated, info
- Return type:
tuple of np.ndarray or jax.Array and any other info
- multi_agent_eval() None ¶
Evaluate multi-agents
This method executes the following steps in a loop:
- Compute actions (sequentially)
- Interact with the environments
- Render scene
- Reset environments
- multi_agent_train() None ¶
Train multi-agents
This method executes the following steps in a loop:
- Pre-interaction
- Compute actions
- Interact with the environments
- Render scene
- Record transitions
- Post-interaction
- Reset environments
- single_agent_eval() None ¶
Evaluate agent
This method executes the following steps in a loop:
- Compute actions (sequentially)
- Interact with the environments
- Render scene
- Reset environments
- single_agent_train() None ¶
Train agent
This method executes the following steps in a loop:
- Pre-interaction
- Compute actions
- Interact with the environments
- Render scene
- Record transitions
- Post-interaction
- Reset environments
- train(timestep: int | None = None, timesteps: int | None = None) → Tuple[ndarray | jax.Array, ndarray | jax.Array, ndarray | jax.Array, ndarray | jax.Array, Any]¶
Execute a training iteration
This method executes the following steps once:
- Pre-interaction (sequentially if num_simultaneous_agents > 1)
- Compute actions (sequentially if num_simultaneous_agents > 1)
- Interact with the environments
- Render scene
- Record transitions (sequentially if num_simultaneous_agents > 1)
- Post-interaction (sequentially if num_simultaneous_agents > 1)
- Reset environments
- Parameters:
timestep (int, optional) – Current timestep (default: None)
timesteps (int, optional) – Total number of timesteps (default: None)
- Returns:
Observation, reward, terminated, truncated, info
- Return type:
tuple of np.ndarray or jax.Array and any other info
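In the JAX API, the returned arrays may be np.ndarray or jax.Array depending on the configured skrl backend, so host-side bookkeeping should convert explicitly. A minimal sketch, assuming the trainer and cfg from the usage section:

import numpy as np

for timestep in range(cfg["timesteps"]):
    observations, rewards, terminated, truncated, infos = trainer.train(timestep=timestep)
    mean_reward = float(np.asarray(rewards).mean())  # valid for both np.ndarray and jax.Array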