Trainers

Trainers are responsible for orchestrating and managing the training/evaluation of agents and their interactions with the environment.



Implemented trainers

The following table lists the implemented trainers and their support for different frameworks.

Trainers              pytorch   jax   warp
Sequential trainer    ■         ■     ■
Parallel trainer      ■         □     □
Step trainer          ■         ■     □

■ supported, □ not supported



Base class

Base class and configuration for trainer implementations.

API


PyTorch

TrainerCfg

Base class for the trainer's configuration.

Trainer

Base trainer class for implementing custom trainers.

class skrl.trainers.torch.TrainerCfg(*, timesteps: int = 100000, headless: bool = False, render_interval: int = 1, disable_progressbar: bool | None = False, close_environment_at_exit: bool = True, environment_info: str = 'episode', stochastic_evaluation: bool = False)

Bases: ABC

Base class for the trainer’s configuration.

Methods:

expand()

Expand the configuration.

validate()

Validate the configuration.

Attributes:

close_environment_at_exit

Whether to close the environment on normal program termination.

disable_progressbar

Whether to disable the progress bar.

environment_info

Key used to get and log environment info.

headless

Whether to run in headless mode (do not call env.render()).

render_interval

Interval (in timesteps) for rendering the environments.

stochastic_evaluation

Whether to use sampled (stochastic) actions rather than deterministic mean actions during evaluation.

timesteps

Number of timesteps to train/evaluate for.

expand() → None

Expand the configuration.

validate() → bool

Validate the configuration.

close_environment_at_exit: bool = True

Whether to close the environment on normal program termination.

disable_progressbar: bool | None = False

Whether to disable the progress bar. If None, the progress bar is disabled when the output is not a TTY (for example, when it is redirected to a file).

environment_info: str = 'episode'

Key used to get and log environment info.

headless: bool = False

Whether to run in headless mode (do not call env.render()).

render_interval: int = 1

Interval (in timesteps) for rendering the environments. Only effective if headless is False.

stochastic_evaluation: bool = False

Whether to use sampled (stochastic) actions rather than deterministic mean actions during evaluation.

timesteps: int = 100000

Number of timesteps to train/evaluate for.

class skrl.trainers.torch.Trainer(*, cfg: TrainerCfg, env: Wrapper | MultiAgentEnvWrapper, agents: Agent | MultiAgent | list[Agent] | list[MultiAgent], scopes: list[int] | None = None)

Bases: ABC

Base trainer class for implementing custom trainers.

Parameters:
  • cfg – Trainer configuration (see TrainerCfg).

  • env – Environment to train/evaluate on.

  • agents – Agent(s) to train/evaluate.

  • scopes – Number of environments for each simultaneous agent to train/evaluate on.

Methods:

eval()

Evaluate the agent (single-agent or multi-agent).

train()

Train the agent (single-agent or multi-agent).

eval() → None

Evaluate the agent (single-agent or multi-agent).

This method executes the following steps in a loop:

  • Pre-interaction

  • Compute actions

  • Interact with the environments

  • Render environments

  • Record transitions

  • Reset environments

Raises:

AssertionError – If the method is called in a simultaneous agents setup.
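The evaluation steps above can be sketched with toy stand-ins. `ToyAgent`, `ToyEnv` and `evaluate()` are hypothetical and do not reproduce skrl's interfaces; the sketch only illustrates the documented step order and how the `stochastic_evaluation` option of `TrainerCfg` chooses between a sampled action and the deterministic mean action. Unlike training, there is no post-interaction step.

```python
import random


class ToyAgent:
    """Hypothetical agent: returns a sampled action together with its mean."""

    def __init__(self):
        self.log = []

    def pre_interaction(self, timestep):
        self.log.append("pre_interaction")

    def act(self, obs):
        self.log.append("act")
        mean = 0.5 * obs
        return mean + random.gauss(0.0, 1.0), mean

    def record_transition(self, obs, action, reward, next_obs, done):
        self.log.append("record_transition")


class ToyEnv:
    """Hypothetical 1-D environment whose episodes end after 3 steps."""

    def __init__(self):
        self.t = 0

    def reset(self):
        self.t = 0
        return 1.0

    def step(self, action):
        self.t += 1
        return 1.0, -abs(action), self.t >= 3

    def render(self):
        pass


def evaluate(agent, env, timesteps, stochastic_evaluation=False, headless=True):
    obs = env.reset()
    actions_taken = []
    for timestep in range(timesteps):
        agent.pre_interaction(timestep)
        sampled, mean = agent.act(obs)
        # Use the deterministic mean action unless stochastic evaluation is enabled
        action = sampled if stochastic_evaluation else mean
        actions_taken.append(action)
        next_obs, reward, done = env.step(action)
        if not headless:
            env.render()
        agent.record_transition(obs, action, reward, next_obs, done)
        # Reset finished episodes; no post-interaction (learning) step in evaluation
        obs = env.reset() if done else next_obs
    return actions_taken
```

With the default `stochastic_evaluation=False`, `evaluate(ToyAgent(), ToyEnv(), timesteps=4)` returns the mean action `0.5` at every step.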

train() → None

Train the agent (single-agent or multi-agent).

This method executes the following steps in a loop:

  • Pre-interaction

  • Compute actions

  • Interact with the environments

  • Render environments

  • Record transitions

  • Post-interaction

  • Reset environments

Raises:

AssertionError – If the method is called in a simultaneous agents setup.
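The training steps above differ from evaluation mainly by the post-interaction step, which runs after the transition is recorded and before finished environments are reset. The sketch below traces that order with hypothetical stand-ins (`TracingAgent`, `CountingEnv`, `train()`; none of them are skrl classes) and mirrors the documented `AssertionError` for simultaneous agents with a simplified check.

```python
class TracingAgent:
    """Hypothetical agent whose hooks only record the order in which they run."""

    def __init__(self):
        self.trace = []

    def pre_interaction(self, timestep, timesteps):
        self.trace.append("pre_interaction")

    def act(self, obs, timestep, timesteps):
        self.trace.append("act")
        return 0.0

    def record_transition(self, *transition):
        self.trace.append("record_transition")

    def post_interaction(self, timestep, timesteps):
        # In a real agent, learning updates would typically be triggered here
        self.trace.append("post_interaction")


class CountingEnv:
    """Hypothetical environment whose episodes last two steps."""

    def __init__(self):
        self.t = 0

    def reset(self):
        self.t = 0
        return 0.0

    def step(self, action):
        self.t += 1
        return float(self.t), 1.0, self.t >= 2

    def render(self):
        pass


def train(agents, env, timesteps, headless=True):
    # Simplified version of the documented restriction: one agent only
    is_sequence = isinstance(agents, (list, tuple))
    assert not is_sequence or len(agents) == 1, "simultaneous agents are not supported"
    agent = agents[0] if is_sequence else agents
    obs = env.reset()
    for timestep in range(timesteps):
        agent.pre_interaction(timestep, timesteps)
        action = agent.act(obs, timestep, timesteps)
        next_obs, reward, done = env.step(action)
        if not headless:
            env.render()
        agent.record_transition(obs, action, reward, next_obs, done)
        agent.post_interaction(timestep, timesteps)
        obs = env.reset() if done else next_obs
```

After `train(agent, CountingEnv(), timesteps=2)`, `agent.trace` repeats `pre_interaction`, `act`, `record_transition`, `post_interaction` twice, while passing a list of two agents raises `AssertionError`.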


JAX

TrainerCfg

Base class for the trainer's configuration.

Trainer

Base trainer class for implementing custom trainers.

class skrl.trainers.jax.TrainerCfg(*, timesteps: int = 100000, headless: bool = False, render_interval: int = 1, disable_progressbar: bool | None = False, close_environment_at_exit: bool = True, environment_info: str = 'episode', stochastic_evaluation: bool = False)

Bases: ABC

Base class for the trainer’s configuration.

Methods:

expand()

Expand the configuration.

validate()

Validate the configuration.

Attributes:

close_environment_at_exit

Whether to close the environment on normal program termination.

disable_progressbar

Whether to disable the progress bar.

environment_info

Key used to get and log environment info.

headless

Whether to run in headless mode (do not call env.render()).

render_interval

Interval (in timesteps) for rendering the environments.

stochastic_evaluation

Whether to use sampled (stochastic) actions rather than deterministic mean actions during evaluation.

timesteps

Number of timesteps to train/evaluate for.

expand() → None

Expand the configuration.

validate() → bool

Validate the configuration.
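`validate()` returns a `bool`, but the checks it performs are not described here. The sketch below shows plausible checks (no unknown fields, positive `timesteps`, `render_interval` of at least 1, `disable_progressbar` being a bool or None, a non-empty `environment_info`) on a plain dictionary of the documented fields. Both the checks and the `validate_trainer_cfg` helper are hypothetical; skrl's actual validation may differ.

```python
# Documented TrainerCfg fields and their default values
DEFAULTS = {
    "timesteps": 100000,
    "headless": False,
    "render_interval": 1,
    "disable_progressbar": False,
    "close_environment_at_exit": True,
    "environment_info": "episode",
    "stochastic_evaluation": False,
}


def validate_trainer_cfg(cfg: dict) -> bool:
    """Hypothetical checks on a partial config; skrl's validate() may differ."""
    # Reject misspelled or unsupported field names
    if set(cfg) - set(DEFAULTS):
        return False
    merged = {**DEFAULTS, **cfg}
    progressbar = merged["disable_progressbar"]
    return (
        isinstance(merged["timesteps"], int) and merged["timesteps"] > 0
        and isinstance(merged["render_interval"], int) and merged["render_interval"] >= 1
        and (progressbar is None or isinstance(progressbar, bool))
        and isinstance(merged["environment_info"], str) and merged["environment_info"] != ""
    )
```

For example, `validate_trainer_cfg({"timesteps": 0})` and `validate_trainer_cfg({"render_interval": 0})` return `False`, while `validate_trainer_cfg({"disable_progressbar": None})` returns `True`.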

close_environment_at_exit: bool = True

Whether to close the environment on normal program termination.

disable_progressbar: bool | None = False

Whether to disable the progress bar. If None, the progress bar is disabled when the output is not a TTY (for example, when it is redirected to a file).

environment_info: str = 'episode'

Key used to get and log environment info.

headless: bool = False

Whether to run in headless mode (do not call env.render()).

render_interval: int = 1

Interval (in timesteps) for rendering the environments. Only effective if headless is False.

stochastic_evaluation: bool = False

Whether to use sampled (stochastic) actions rather than deterministic mean actions during evaluation.

timesteps: int = 100000

Number of timesteps to train/evaluate for.
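`environment_info` names the key under which the trainer looks up environment info to log. A minimal sketch of picking loggable scalars out of an info dictionary returned by an environment step follows. The nested-dictionary layout, the `"Info / <name>"` tag format and the `extract_env_info` helper are assumptions for illustration, not documented skrl behavior.

```python
def extract_env_info(infos: dict, key: str = "episode") -> dict[str, float]:
    """Hypothetical helper: collect scalar entries stored under infos[key]."""
    section = infos.get(key)
    if not isinstance(section, dict):
        return {}
    # Keep only plain numeric scalars (not bools), which can be logged directly
    return {
        f"Info / {name}": float(value)
        for name, value in section.items()
        if isinstance(value, (int, float)) and not isinstance(value, bool)
    }
```

With `infos = {"episode": {"return": 12.5, "length": 200, "flags": [1, 2]}}`, the default key yields `{"Info / return": 12.5, "Info / length": 200.0}`, and a key that is not present yields an empty dictionary.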

class skrl.trainers.jax.Trainer(*, cfg: TrainerCfg, env: Wrapper | MultiAgentEnvWrapper, agents: Agent | MultiAgent | list[Agent] | list[MultiAgent], scopes: list[int] | None = None)

Bases: ABC

Base trainer class for implementing custom trainers.

Parameters:
  • cfg – Trainer configuration (see TrainerCfg).

  • env – Environment to train/evaluate on.

  • agents – Agent(s) to train/evaluate.

  • scopes – Number of environments for each simultaneous agent to train/evaluate on.

Methods:

eval()

Evaluate the agent (single-agent or multi-agent).

train()

Train the agent (single-agent or multi-agent).

eval() → None

Evaluate the agent (single-agent or multi-agent).

This method executes the following steps in a loop:

  • Pre-interaction

  • Compute actions

  • Interact with the environments

  • Render environments

  • Record transitions

  • Reset environments

Raises:

AssertionError – If the method is called in a simultaneous agents setup.

train() → None

Train the agent (single-agent or multi-agent).

This method executes the following steps in a loop:

  • Pre-interaction

  • Compute actions

  • Interact with the environments

  • Render environments

  • Record transitions

  • Post-interaction

  • Reset environments

Raises:

AssertionError – If the method is called in a simultaneous agents setup.


Warp

TrainerCfg

Base class for the trainer's configuration.

Trainer

Base trainer class for implementing custom trainers.