Step trainer

Train agents while controlling the training/evaluation loop step by step.



Concept

Figure: Step-by-step trainer

Usage

from skrl.trainers.torch import StepTrainer

# assuming there is an environment called 'env'
# and an agent or a list of agents called 'agents'

# create a step trainer
cfg = {"timesteps": 50000, "headless": False}
trainer = StepTrainer(env=env, agents=agents, cfg=cfg)

# train the agent(s)
for timestep in range(cfg["timesteps"]):
    trainer.train(timestep=timestep)

# evaluate the agent(s)
for timestep in range(cfg["timesteps"]):
    trainer.eval(timestep=timestep)
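
Because each call advances a single step and returns that step's results, custom logic can run between calls. The sketch below illustrates the idea with a hypothetical stand-in class (FakeStepTrainer, not part of skrl) that mimics only the documented return tuple (observation, reward, terminated, truncated, info). With a real StepTrainer the values would be tensors or arrays rather than plain Python values.

```python
import random


class FakeStepTrainer:
    """Hypothetical stand-in that mimics only the documented return tuple."""

    def __init__(self, episode_length=5, seed=0):
        self._rng = random.Random(seed)
        self._episode_length = episode_length
        self._steps_in_episode = 0

    def train(self, timestep=None):
        # pretend an episode ends every `episode_length` steps
        self._steps_in_episode += 1
        terminated = self._steps_in_episode >= self._episode_length
        if terminated:
            self._steps_in_episode = 0
        observation = [self._rng.random()]
        reward = 1.0
        return observation, reward, terminated, False, {}


trainer = FakeStepTrainer(episode_length=5)
total_reward = 0.0
finished_episodes = 0

for timestep in range(20):
    observation, reward, terminated, truncated, info = trainer.train(timestep=timestep)
    # custom per-step logic goes here (logging, early stopping, ...)
    total_reward += reward
    if terminated or truncated:
        finished_episodes += 1

print(total_reward, finished_episodes)
```

The same pattern applies to eval, which returns a tuple of the same shape.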

Configuration

STEP_TRAINER_DEFAULT_CONFIG = {
    "timesteps": 100000,            # number of timesteps to train for
    "headless": False,              # whether to use headless mode (no rendering)
    "disable_progressbar": False,   # whether to disable the progressbar. If None, disable on non-TTY
    "close_environment_at_exit": True,   # whether to close the environment on normal program termination
}
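
A partial cfg only needs the keys that differ from the defaults. The sketch below assumes the trainer overlays the user dictionary on a copy of the defaults (a plain dict update). The helper name merge_config is hypothetical and only illustrates that assumed behavior.

```python
import copy

STEP_TRAINER_DEFAULT_CONFIG = {
    "timesteps": 100000,
    "headless": False,
    "disable_progressbar": False,
    "close_environment_at_exit": True,
}


def merge_config(cfg=None):
    """Return the defaults overridden by user values (hypothetical helper)."""
    merged = copy.deepcopy(STEP_TRAINER_DEFAULT_CONFIG)
    merged.update(cfg if cfg is not None else {})
    return merged


cfg = merge_config({"timesteps": 50000, "headless": True})
print(cfg)
```

Keys that are not overridden keep their default values, and the module-level defaults are left untouched.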

API (PyTorch)

skrl.trainers.torch.step.STEP_TRAINER_DEFAULT_CONFIG

alias of {'close_environment_at_exit': True, 'disable_progressbar': False, 'headless': False, 'timesteps': 100000}

class skrl.trainers.torch.step.StepTrainer(env: Wrapper, agents: Agent | List[Agent], agents_scope: List[int] | None = None, cfg: dict | None = None)

Bases: Trainer

__init__(env: Wrapper, agents: Agent | List[Agent], agents_scope: List[int] | None = None, cfg: dict | None = None) → None

Step-by-step trainer

Train agents by controlling the training/evaluation loop step by step

Parameters:
  • env (skrl.envs.wrappers.torch.Wrapper) – Environment to train on

  • agents (Union[Agent, List[Agent]]) – Agents to train

  • agents_scope (tuple or list of int, optional) – Number of environments for each agent to train on (default: None)

  • cfg (dict, optional) – Configuration dictionary (default: None). See STEP_TRAINER_DEFAULT_CONFIG for default values
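
To illustrate agents_scope: when several agents train simultaneously, each entry gives the number of parallel environments assigned to the corresponding agent. The sketch below assumes the counts map to consecutive, non-overlapping index ranges and that they must add up to the total number of environments. scope_to_ranges is a hypothetical helper, not part of the skrl API.

```python
def scope_to_ranges(agents_scope, num_envs):
    """Turn per-agent environment counts into (start, end) index ranges."""
    if sum(agents_scope) != num_envs:
        raise ValueError(
            f"scopes add up to {sum(agents_scope)}, expected {num_envs}"
        )
    ranges = []
    start = 0
    for count in agents_scope:
        ranges.append((start, start + count))  # end index is exclusive
        start += count
    return ranges


ranges = scope_to_ranges([100, 200, 212], num_envs=512)
print(ranges)
```

Under these assumptions, the first agent would act on environments 0 to 99, the second on 100 to 299, and the third on 300 to 511.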

eval(timestep: int | None = None, timesteps: int | None = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]

Evaluate the agents sequentially

This method executes the following steps once per call:

  • Compute actions (sequentially if num_simultaneous_agents > 1)

  • Interact with the environments

  • Render scene

  • Reset environments

Parameters:
  • timestep (int, optional) – Current timestep (default: None). If None, the current timestep will be carried by an internal variable

  • timesteps (int, optional) – Total number of timesteps (default: None). If None, the total number of timesteps is obtained from the trainer’s config

Returns:

Observation, reward, terminated, truncated, info

Return type:

tuple of torch.Tensor and any other info
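
The fallbacks for timestep and timesteps described above can be pictured as follows. This is a minimal sketch, assuming an internal counter that advances by one on every call made without an explicit timestep. StepCounter is a hypothetical class, not skrl code, and the library's actual starting value may differ.

```python
class StepCounter:
    """Hypothetical illustration of the timestep/timesteps fallbacks."""

    def __init__(self, timesteps):
        self.timesteps = timesteps  # total taken from the config
        self._timestep = 0          # internal counter (assumed to start at 0)

    def resolve(self, timestep=None, timesteps=None):
        if timestep is None:
            # no explicit timestep: advance and use the internal counter
            self._timestep += 1
            timestep = self._timestep
        if timesteps is None:
            # no explicit total: fall back to the configured value
            timesteps = self.timesteps
        return timestep, timesteps


counter = StepCounter(timesteps=100)
first = counter.resolve()
second = counter.resolve()
explicit = counter.resolve(timestep=7, timesteps=50)
third = counter.resolve()
print(first, second, explicit, third)
```

In this sketch an explicit timestep does not touch the internal counter, so mixing both calling styles in one loop is best avoided.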

multi_agent_eval() → None

Evaluate multi-agents

This method executes the following steps in loop:

  • Compute actions (sequentially)

  • Interact with the environments

  • Render scene

  • Reset environments

multi_agent_train() → None

Train multi-agents

This method executes the following steps in loop:

  • Pre-interaction

  • Compute actions

  • Interact with the environments

  • Render scene

  • Record transitions

  • Post-interaction

  • Reset environments

single_agent_eval() → None

Evaluate agent

This method executes the following steps in loop:

  • Compute actions (sequentially)

  • Interact with the environments

  • Render scene

  • Reset environments

single_agent_train() → None

Train agent

This method executes the following steps in loop:

  • Pre-interaction

  • Compute actions

  • Interact with the environments

  • Render scene

  • Record transitions

  • Post-interaction

  • Reset environments

train(timestep: int | None = None, timesteps: int | None = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]

Execute a training iteration

This method executes the following steps once:

  • Pre-interaction (sequentially if num_simultaneous_agents > 1)

  • Compute actions (sequentially if num_simultaneous_agents > 1)

  • Interact with the environments

  • Render scene

  • Record transitions (sequentially if num_simultaneous_agents > 1)

  • Post-interaction (sequentially if num_simultaneous_agents > 1)

  • Reset environments

Parameters:
  • timestep (int, optional) – Current timestep (default: None). If None, the current timestep will be carried by an internal variable

  • timesteps (int, optional) – Total number of timesteps (default: None). If None, the total number of timesteps is obtained from the trainer’s config

Returns:

Observation, reward, terminated, truncated, info

Return type:

tuple of torch.Tensor and any other info
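
The order of a single training iteration listed above can be sketched with stand-in objects. StubAgent and StubEnv are hypothetical and only log the order of calls. The hook names (pre_interaction, act, record_transition, post_interaction) follow skrl's agent interface, but the argument lists are simplified assumptions, and the explicit reset on episode end is a simplification of what a real wrapped environment may do.

```python
calls = []  # records the order in which the stand-ins are invoked


class StubAgent:
    """Hypothetical agent: logs each hook instead of learning."""

    def pre_interaction(self, timestep, timesteps):
        calls.append("pre_interaction")

    def act(self, states, timestep, timesteps):
        calls.append("act")
        return 0  # dummy action

    def record_transition(self, states, actions, rewards, next_states,
                          terminated, truncated, infos, timestep, timesteps):
        calls.append("record_transition")

    def post_interaction(self, timestep, timesteps):
        calls.append("post_interaction")


class StubEnv:
    """Hypothetical environment whose episodes end after one step."""

    def reset(self):
        calls.append("reset")
        return 0.0, {}

    def step(self, actions):
        calls.append("step")
        return 1.0, 1.0, True, False, {}

    def render(self):
        calls.append("render")


def train_once(env, agent, states, timestep, timesteps, headless=False):
    """Run one iteration in the documented order and return the step results."""
    agent.pre_interaction(timestep=timestep, timesteps=timesteps)
    actions = agent.act(states, timestep=timestep, timesteps=timesteps)
    next_states, rewards, terminated, truncated, infos = env.step(actions)
    if not headless:
        env.render()
    agent.record_transition(states, actions, rewards, next_states,
                            terminated, truncated, infos, timestep, timesteps)
    agent.post_interaction(timestep=timestep, timesteps=timesteps)
    if terminated or truncated:
        env.reset()  # a real trainer would keep the returned states for the next call
    return next_states, rewards, terminated, truncated, infos


env, agent = StubEnv(), StubAgent()
states, _ = env.reset()
calls.clear()  # ignore the initial reset
result = train_once(env, agent, states, timestep=0, timesteps=1)
print(calls)
```

With several simultaneous agents, the agent hooks would be invoked for each agent in turn, which is what the "sequentially" notes in the list refer to.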


API (JAX)

skrl.trainers.jax.step.STEP_TRAINER_DEFAULT_CONFIG

alias of {'close_environment_at_exit': True, 'disable_progressbar': False, 'headless': False, 'timesteps': 100000}

class skrl.trainers.jax.step.StepTrainer(env: Wrapper, agents: Agent | List[Agent], agents_scope: List[int] | None = None, cfg: dict | None = None)

Bases: Trainer

__init__(env: Wrapper, agents: Agent | List[Agent], agents_scope: List[int] | None = None, cfg: dict | None = None) → None

Step-by-step trainer

Train agents by controlling the training/evaluation loop step by step

Parameters:
  • env (skrl.envs.wrappers.jax.Wrapper) – Environment to train on

  • agents (Union[Agent, List[Agent]]) – Agents to train

  • agents_scope (tuple or list of int, optional) – Number of environments for each agent to train on (default: None)

  • cfg (dict, optional) – Configuration dictionary (default: None). See STEP_TRAINER_DEFAULT_CONFIG for default values

eval(timestep: int | None = None, timesteps: int | None = None) → Tuple[ndarray | jax.Array, ndarray | jax.Array, ndarray | jax.Array, ndarray | jax.Array, Any]

Evaluate the agents sequentially

This method executes the following steps once per call:

  • Compute actions (sequentially if num_simultaneous_agents > 1)

  • Interact with the environments

  • Render scene

  • Reset environments

Parameters:
  • timestep (int, optional) – Current timestep (default: None). If None, the current timestep will be carried by an internal variable

  • timesteps (int, optional) – Total number of timesteps (default: None). If None, the total number of timesteps is obtained from the trainer’s config

Returns:

Observation, reward, terminated, truncated, info

Return type:

tuple of np.ndarray or jax.Array and any other info

multi_agent_eval() → None

Evaluate multi-agents

This method executes the following steps in loop:

  • Compute actions (sequentially)

  • Interact with the environments

  • Render scene

  • Reset environments

multi_agent_train() → None

Train multi-agents

This method executes the following steps in loop:

  • Pre-interaction

  • Compute actions

  • Interact with the environments

  • Render scene

  • Record transitions

  • Post-interaction

  • Reset environments

single_agent_eval() → None

Evaluate agent

This method executes the following steps in loop:

  • Compute actions (sequentially)

  • Interact with the environments

  • Render scene

  • Reset environments

single_agent_train() → None

Train agent

This method executes the following steps in loop:

  • Pre-interaction

  • Compute actions

  • Interact with the environments

  • Render scene

  • Record transitions

  • Post-interaction

  • Reset environments

train(timestep: int | None = None, timesteps: int | None = None) → Tuple[ndarray | jax.Array, ndarray | jax.Array, ndarray | jax.Array, ndarray | jax.Array, Any]

Execute a training iteration

This method executes the following steps once:

  • Pre-interaction (sequentially if num_simultaneous_agents > 1)

  • Compute actions (sequentially if num_simultaneous_agents > 1)

  • Interact with the environments

  • Render scene

  • Record transitions (sequentially if num_simultaneous_agents > 1)

  • Post-interaction (sequentially if num_simultaneous_agents > 1)

  • Reset environments

Parameters:
  • timestep (int, optional) – Current timestep (default: None). If None, the current timestep will be carried by an internal variable

  • timesteps (int, optional) – Total number of timesteps (default: None). If None, the total number of timesteps is obtained from the trainer’s config

Returns:

Observation, reward, terminated, truncated, info

Return type:

tuple of np.ndarray or jax.Array and any other info