Step trainer

Train agents by controlling the training/evaluation loop step by step.



Concept

[Figure: Step-by-step trainer]

Usage

from skrl.trainers.torch import StepTrainer

# assuming there is an environment called 'env'
# and an agent or a list of agents called 'agents'

# create a step trainer
cfg = {"timesteps": 50000, "headless": False}
trainer = StepTrainer(env=env, agents=agents, cfg=cfg)

# train the agent(s)
for timestep in range(cfg["timesteps"]):
    trainer.train(timestep=timestep)

# evaluate the agent(s)
for timestep in range(cfg["timesteps"]):
    trainer.eval(timestep=timestep)
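Because the trainer advances only one timestep per call, the loop body is free to inspect the returned values and react to them. The sketch below uses a hypothetical `FakeStepTrainer` stand-in (not part of skrl) so that it runs without an environment. Its toy rewards and the `train_until` early-stopping rule are illustrative assumptions.

```python
from __future__ import annotations

from typing import Any


class FakeStepTrainer:
    """Hypothetical stand-in for skrl's StepTrainer (not the real class).

    Each train() call advances one timestep and returns
    (observations, rewards, terminated, truncated, info) for 2 environments.
    """

    def __init__(self, timesteps: int) -> None:
        self.cfg = {"timesteps": timesteps}

    def train(self, timestep: int) -> tuple[list[float], list[float], list[bool], list[bool], Any]:
        reward = 0.01 * timestep  # toy reward that grows linearly with time
        return [0.0, 0.0], [reward, reward], [False, False], [False, False], {}


def train_until(trainer: FakeStepTrainer, target_reward: float) -> int:
    """Call train() one step at a time and stop early once the mean reward
    across environments reaches target_reward. Returns the number of steps run."""
    for timestep in range(trainer.cfg["timesteps"]):
        _, rewards, _, _, _ = trainer.train(timestep=timestep)
        if sum(rewards) / len(rewards) >= target_reward:
            return timestep + 1
    return trainer.cfg["timesteps"]
```

The same pattern lets you log custom metrics, save checkpoints or change hyperparameters between steps without subclassing the trainer.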

Configuration

Dataclass

StepTrainerCfg is provided for each supported framework: PyTorch, JAX and Warp.


API


PyTorch

StepTrainerCfg

Configuration for the step trainer.

StepTrainer

Step-by-step trainer.

class skrl.trainers.torch.step.StepTrainerCfg(*, timesteps: int = 100000, headless: bool = False, render_interval: int = 1, disable_progressbar: bool | None = False, close_environment_at_exit: bool = True, environment_info: str = 'episode', stochastic_evaluation: bool = False)

Bases: TrainerCfg

Configuration for the step trainer.

Methods:

expand()

Expand the configuration.

validate()

Validate the configuration.

Attributes:

close_environment_at_exit

Whether to close the environment on normal program termination.

disable_progressbar

Whether to disable the progress bar.

environment_info

Key used to get and log environment info.

headless

Whether to run in headless mode (do not call env.render()).

render_interval

Interval (in timesteps) for rendering the environments.

stochastic_evaluation

Whether to use sampled (stochastic) actions rather than the deterministic mean actions during evaluation.

timesteps

Number of timesteps to train/evaluate for.

expand() -> None

Expand the configuration.

validate() -> bool

Validate the configuration.

close_environment_at_exit: bool = True

Whether to close the environment on normal program termination.

disable_progressbar: bool | None = False

Whether to disable the progress bar. If None, the progress bar is disabled when the output is not a TTY (non-interactive terminal).
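The `None` option follows the same convention as tqdm's `disable=None` argument. A stdlib-only sketch of the rule could look like this. The helper name is hypothetical, and checking standard error is an assumption (it is where tqdm writes by default).

```python
from __future__ import annotations

import sys
from typing import TextIO


def resolve_disable_progressbar(disable: bool | None, stream: TextIO = sys.stderr) -> bool:
    """Turn a disable_progressbar setting into a concrete bool.

    None means "disable when the output stream is not a TTY", e.g. when
    logs are redirected to a file or captured by a CI job.
    """
    if disable is None:
        return not stream.isatty()
    return disable
```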

environment_info: str = 'episode'

Key used to get and log environment info.

headless: bool = False

Whether to run in headless mode (do not call env.render()).

render_interval: int = 1

Interval (in timesteps) for rendering the environments. Only effective if headless is False.
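The interaction between `headless` and `render_interval` can be sketched as a small predicate. The helper is hypothetical, and the exact condition (render when the timestep is a multiple of the interval) is an assumption about the implementation.

```python
def should_render(timestep: int, headless: bool, render_interval: int) -> bool:
    """Return True if env.render() should be called at this timestep.

    Assumption: rendering happens on timesteps that are multiples of
    render_interval, and headless mode disables rendering entirely.
    """
    if headless:
        return False
    return timestep % render_interval == 0
```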

stochastic_evaluation: bool = False

Whether to use sampled (stochastic) actions rather than the deterministic mean actions during evaluation.
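For a continuous policy that outputs a Gaussian distribution, the difference between the two evaluation modes can be sketched as follows. The Gaussian policy and the `select_action` helper are illustrative assumptions; for a discrete policy the deterministic choice would typically be the most probable action rather than a mean.

```python
import random


def select_action(mean: float, std: float, stochastic: bool, rng: random.Random) -> float:
    """Pick an evaluation action from a 1-D Gaussian policy N(mean, std^2).

    stochastic=False returns the mean (deterministic evaluation);
    stochastic=True draws a sample, as stochastic_evaluation=True requests.
    """
    if stochastic:
        return rng.gauss(mean, std)
    return mean
```

Seeding the generator keeps stochastic evaluation reproducible across runs.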

timesteps: int = 100000

Number of timesteps to train/evaluate for.

class skrl.trainers.torch.step.StepTrainer(*, env: Wrapper | MultiAgentEnvWrapper, agents: Agent | MultiAgent | list[Agent] | list[MultiAgent], scopes: list[int] | None = None, cfg: StepTrainerCfg | dict = {})

Bases: Trainer

Step-by-step trainer.

Train agents by controlling the training/evaluation loop step by step.

Parameters:
  • env – Environment to train/evaluate on.

  • agents – Agent(s) to train/evaluate.

  • scopes – Number of environments for each simultaneous agent to train/evaluate on.

  • cfg – Configuration, either a StepTrainerCfg instance or a dictionary.
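With several simultaneous agents, each agent is responsible for a block of environments. A sketch of turning `scopes` into index ranges, assuming contiguous blocks and, when `scopes` is omitted, an even split with the remainder given to the last agent (this fallback and the helper name are assumptions):

```python
from __future__ import annotations


def scope_ranges(num_envs: int, num_agents: int, scopes: list[int] | None = None) -> list[tuple[int, int]]:
    """Assign a contiguous [start, end) environment index range to each agent."""
    if not scopes:
        # Assumed fallback: split as evenly as possible, remainder to the last agent
        base = num_envs // num_agents
        if base == 0:
            raise ValueError("fewer environments than agents")
        scopes = [base] * num_agents
        scopes[-1] += num_envs - sum(scopes)
    if len(scopes) != num_agents or sum(scopes) != num_envs:
        raise ValueError("scopes must have one entry per agent and sum to num_envs")
    ranges, start = [], 0
    for size in scopes:
        ranges.append((start, start + size))
        start += size
    return ranges
```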

Methods:

eval([timestep, timesteps])

Execute an evaluation iteration.

reset()

Reset the trainer.

train([timestep, timesteps])

Execute a training iteration.

eval(timestep: int | None = None, timesteps: int | None = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]

Execute an evaluation iteration.

This method executes the following steps once:

  • Pre-interaction (sequentially if num_simultaneous_agents > 1)

  • Compute actions (sequentially if num_simultaneous_agents > 1)

  • Interact with the environments

  • Render environments

  • Record transitions (sequentially if num_simultaneous_agents > 1)

  • Reset environments
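The order of these steps can be illustrated with hypothetical stub classes that only record which hook was called. `StubAgent`, `StubEnv` and `eval_iteration` are not skrl names, and resetting only when an episode has ended is an assumption (vectorized environments often reset themselves).

```python
from __future__ import annotations

from typing import Any


class StubAgent:
    """Hypothetical agent stub: records the name of every hook the loop calls."""

    def __init__(self, log: list[str]) -> None:
        self.log = log

    def pre_interaction(self, timestep: int) -> None:
        self.log.append("pre_interaction")

    def act(self, observations: list[float], timestep: int) -> list[float]:
        self.log.append("act")
        return [0.0 for _ in observations]  # dummy actions

    def record_transition(self, *transition: Any) -> None:
        self.log.append("record_transition")


class StubEnv:
    """Hypothetical single-environment stub whose episode ends on every step."""

    def __init__(self, log: list[str]) -> None:
        self.log = log

    def step(self, actions: list[float]) -> tuple[list[float], list[float], list[bool], list[bool], dict]:
        self.log.append("step")
        return [1.0], [0.0], [True], [False], {}

    def render(self) -> None:
        self.log.append("render")

    def reset(self) -> tuple[list[float], dict]:
        self.log.append("reset")
        return [0.0], {}


def eval_iteration(agent: StubAgent, env: StubEnv, observations: list[float], timestep: int, headless: bool = False):
    """One evaluation iteration in the listed order (no post-interaction)."""
    agent.pre_interaction(timestep)
    actions = agent.act(observations, timestep)
    next_obs, rewards, terminated, truncated, info = env.step(actions)
    if not headless:
        env.render()
    agent.record_transition(observations, actions, rewards, next_obs, terminated, truncated, info, timestep)
    # Reset only when some episode ended; otherwise continue from next_obs
    if any(terminated) or any(truncated):
        next_obs, info = env.reset()
    # next_obs: observations carried into the next iteration
    return next_obs, rewards, terminated, truncated, info
```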

Parameters:
  • timestep – Current timestep. If None, the current timestep is tracked by an internal counter.

  • timesteps – Total number of timesteps. If None, it is obtained from the trainer’s config.
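A sketch of how the two optional arguments could fall back to their defaults. `TimestepCounter` is hypothetical; the counter starting at 0 and not advancing when an explicit `timestep` is passed are assumptions.

```python
from __future__ import annotations


class TimestepCounter:
    """Hypothetical sketch of the timestep/timesteps fallbacks."""

    def __init__(self, cfg_timesteps: int) -> None:
        self.cfg_timesteps = cfg_timesteps  # stands in for the trainer's config value
        self._timestep = 0  # internal counter used when no timestep is passed

    def resolve(self, timestep: int | None = None, timesteps: int | None = None) -> tuple[int, int]:
        if timestep is None:
            timestep = self._timestep
            self._timestep += 1  # advance for the next call
        if timesteps is None:
            timesteps = self.cfg_timesteps
        return timestep, timesteps
```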

Returns:

Environment’s observations, rewards, terminated, truncated and info.
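The returned `terminated` and `truncated` values are usually combined to find the environments whose episode just ended. A plain-list sketch follows (the helper name is hypothetical); with the PyTorch return values the equivalent mask would be `torch.logical_or(terminated, truncated)`.

```python
def finished_envs(terminated: list[bool], truncated: list[bool]) -> list[int]:
    """Indices of environments whose episode ended at this step, either by
    reaching a terminal state or by truncation (e.g. a time limit)."""
    return [i for i, (term, trunc) in enumerate(zip(terminated, truncated)) if term or trunc]
```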

reset()

Reset the trainer.

train(timestep: int | None = None, timesteps: int | None = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]

Execute a training iteration.

This method executes the following steps once:

  • Pre-interaction (sequentially if num_simultaneous_agents > 1)

  • Compute actions (sequentially if num_simultaneous_agents > 1)

  • Interact with the environments

  • Render environments

  • Record transitions (sequentially if num_simultaneous_agents > 1)

  • Post-interaction (sequentially if num_simultaneous_agents > 1)

  • Reset environments
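When `num_simultaneous_agents > 1`, each step is run once per agent, on that agent's slice of environments. The sketch below shows this for the compute-actions step with hypothetical policy callables and precomputed `[start, end)` scopes; concatenating per-agent actions in environment order before stepping the environments is an assumption about how the results are combined.

```python
from __future__ import annotations

from typing import Callable

Policy = Callable[[list[float]], list[float]]


def act_per_scope(policies: list[Policy], scopes: list[tuple[int, int]], observations: list[float]) -> list[float]:
    """Compute actions for all environments when several agents train at once.

    Each agent only sees the environments in its [start, end) scope, and the
    agents are called one after another (sequentially); their actions are
    concatenated in environment order so a single env.step() can consume them.
    """
    actions: list[float] = []
    for policy, (start, end) in zip(policies, scopes):
        actions.extend(policy(observations[start:end]))
    return actions
```

The pre-interaction, record-transition and post-interaction steps would follow the same per-agent loop.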

Parameters:
  • timestep – Current timestep. If None, the current timestep is tracked by an internal counter.

  • timesteps – Total number of timesteps. If None, it is obtained from the trainer’s config.

Returns:

Environment’s observations, rewards, terminated, truncated and info.


JAX

StepTrainerCfg

Configuration for the step trainer.

StepTrainer

Step-by-step trainer.

class skrl.trainers.jax.step.StepTrainerCfg(*, timesteps: int = 100000, headless: bool = False, render_interval: int = 1, disable_progressbar: bool | None = False, close_environment_at_exit: bool = True, environment_info: str = 'episode', stochastic_evaluation: bool = False)

Bases: TrainerCfg

Configuration for the step trainer.

Methods:

expand()

Expand the configuration.

validate()

Validate the configuration.

Attributes:

close_environment_at_exit

Whether to close the environment on normal program termination.

disable_progressbar

Whether to disable the progress bar.

environment_info

Key used to get and log environment info.

headless

Whether to run in headless mode (do not call env.render()).

render_interval

Interval (in timesteps) for rendering the environments.

stochastic_evaluation

Whether to use sampled (stochastic) actions rather than the deterministic mean actions during evaluation.

timesteps

Number of timesteps to train/evaluate for.

expand() -> None

Expand the configuration.

validate() -> bool

Validate the configuration.

close_environment_at_exit: bool = True

Whether to close the environment on normal program termination.

disable_progressbar: bool | None = False

Whether to disable the progress bar. If None, the progress bar is disabled when the output is not a TTY (non-interactive terminal).

environment_info: str = 'episode'

Key used to get and log environment info.

headless: bool = False

Whether to run in headless mode (do not call env.render()).

render_interval: int = 1

Interval (in timesteps) for rendering the environments. Only effective if headless is False.

stochastic_evaluation: bool = False

Whether to use sampled (stochastic) actions rather than the deterministic mean actions during evaluation.

timesteps: int = 100000

Number of timesteps to train/evaluate for.

class skrl.trainers.jax.step.StepTrainer(*, env: Wrapper | MultiAgentEnvWrapper, agents: Agent | MultiAgent | list[Agent] | list[MultiAgent], scopes: list[int] | None = None, cfg: StepTrainerCfg | dict = {})

Bases: Trainer

Step-by-step trainer.

Train agents by controlling the training/evaluation loop step by step.

Parameters:
  • env – Environment to train/evaluate on.

  • agents – Agent(s) to train/evaluate.

  • scopes – Number of environments for each simultaneous agent to train/evaluate on.

  • cfg – Configuration, either a StepTrainerCfg instance or a dictionary.

Methods:

eval([timestep, timesteps])

Execute an evaluation iteration.

reset()

Reset the trainer.

train([timestep, timesteps])

Execute a training iteration.

eval(timestep: int | None = None, timesteps: int | None = None) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, Any]

Execute an evaluation iteration.

This method executes the following steps once:

  • Pre-interaction (sequentially if num_simultaneous_agents > 1)

  • Compute actions (sequentially if num_simultaneous_agents > 1)

  • Interact with the environments

  • Render environments

  • Record transitions (sequentially if num_simultaneous_agents > 1)

  • Reset environments

Parameters:
  • timestep – Current timestep. If None, the current timestep is tracked by an internal counter.

  • timesteps – Total number of timesteps. If None, it is obtained from the trainer’s config.

Returns:

Environment’s observations, rewards, terminated, truncated and info.

reset()

Reset the trainer.

train(timestep: int | None = None, timesteps: int | None = None) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, Any]

Execute a training iteration.

This method executes the following steps once:

  • Pre-interaction (sequentially if num_simultaneous_agents > 1)

  • Compute actions (sequentially if num_simultaneous_agents > 1)

  • Interact with the environments

  • Render environments

  • Record transitions (sequentially if num_simultaneous_agents > 1)

  • Post-interaction (sequentially if num_simultaneous_agents > 1)

  • Reset environments

Parameters:
  • timestep – Current timestep. If None, the current timestep is tracked by an internal counter.

  • timesteps – Total number of timesteps. If None, it is obtained from the trainer’s config.

Returns:

Environment’s observations, rewards, terminated, truncated and info.