Step trainer
Train agents by controlling the training/evaluation loop step by step.
Concept
Usage
# PyTorch
from skrl.trainers.torch import StepTrainer

# assuming there is an environment called 'env'
# and an agent or a list of agents called 'agents'

# create a step trainer
cfg = {"timesteps": 50000, "headless": False}
trainer = StepTrainer(env=env, agents=agents, cfg=cfg)

# train the agent(s), one timestep per call
for timestep in range(cfg["timesteps"]):
    trainer.train(timestep=timestep)

# evaluate the agent(s), one timestep per call
for timestep in range(cfg["timesteps"]):
    trainer.eval(timestep=timestep)

# JAX
from skrl.trainers.jax import StepTrainer

# assuming there is an environment called 'env'
# and an agent or a list of agents called 'agents'

# create a step trainer
cfg = {"timesteps": 50000, "headless": False}
trainer = StepTrainer(env=env, agents=agents, cfg=cfg)

# train the agent(s), one timestep per call
for timestep in range(cfg["timesteps"]):
    trainer.train(timestep=timestep)

# evaluate the agent(s), one timestep per call
for timestep in range(cfg["timesteps"]):
    trainer.eval(timestep=timestep)
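Configuration
The trainer is configured through the StepTrainerCfg dataclass (or an equivalent dictionary), documented for PyTorch and JAX in the API section.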
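The trainer accepts its configuration either as a StepTrainerCfg instance or as a dictionary. The following is a minimal sketch of how a dictionary can be mapped onto a dataclass with the defaults shown in the documented StepTrainerCfg signature. It runs without skrl installed: `ToyStepTrainerCfg` and `cfg_from_dict` are hypothetical names, and rejecting unknown keys is a choice of this sketch, not a documented behavior of the library.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class ToyStepTrainerCfg:
    """Hypothetical stand-in mirroring the documented StepTrainerCfg defaults."""

    timesteps: int = 100000
    headless: bool = False
    render_interval: int = 1
    disable_progressbar: Optional[bool] = False
    close_environment_at_exit: bool = True
    environment_info: str = "episode"
    stochastic_evaluation: bool = False


def cfg_from_dict(overrides: dict) -> ToyStepTrainerCfg:
    """Build a config from a plain dict, rejecting unknown keys early (sketch-only behavior)."""
    known = {f.name for f in fields(ToyStepTrainerCfg)}
    unknown = set(overrides) - known
    if unknown:
        raise KeyError(f"unknown configuration keys: {sorted(unknown)}")
    return ToyStepTrainerCfg(**overrides)


# same dictionary as in the usage example; unspecified fields keep their defaults
cfg = cfg_from_dict({"timesteps": 50000, "headless": False})
print(cfg.timesteps, cfg.render_interval)
```

Keys that are not given keep the default from the signature, which is why a two-entry dictionary is enough in the usage example.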
API
PyTorch
| Class | Description |
|---|---|
| StepTrainerCfg | Configuration for the step trainer. |
| StepTrainer | Step-by-step trainer. |
- class skrl.trainers.torch.step.StepTrainerCfg(*, timesteps: int = 100000, headless: bool = False, render_interval: int = 1, disable_progressbar: bool | None = False, close_environment_at_exit: bool = True, environment_info: str = 'episode', stochastic_evaluation: bool = False)
Bases: TrainerCfg
Configuration for the step trainer.
Attributes:
close_environment_at_exit – Whether to close the environment on normal program termination.
disable_progressbar – Whether to disable the progressbar.
environment_info – Key used to get and log environment info.
headless – Whether to run in headless mode (do not call env.render()).
render_interval – Interval (in timesteps) for rendering the environments.
stochastic_evaluation – Whether to use stochastic (sampled) actions rather than (deterministic) mean actions during evaluation.
timesteps – Number of timesteps to train/evaluate for.
- close_environment_at_exit: bool = True
Whether to close the environment on normal program termination.
- disable_progressbar: bool | None = False
Whether to disable the progressbar. If None, disable on non-TTY.
- render_interval: int = 1
Interval (in timesteps) for rendering the environments. Only effective if headless is False.
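The interaction between `headless`, `render_interval` and the tri-state `disable_progressbar` can be sketched as two small helpers. This is an illustration under stated assumptions, not the library's code: `should_render` and `resolve_disable_progressbar` are hypothetical names, rendering on timesteps that are multiples of the interval is an assumption, and so is treating a non-positive interval as "never render".

```python
import io
import sys
from typing import Optional


def should_render(timestep: int, headless: bool, render_interval: int) -> bool:
    """Return True when a (hypothetical) trainer would call env.render() at this timestep."""
    if headless:
        return False  # headless mode never renders, whatever the interval
    if render_interval <= 0:
        return False  # assumption: a non-positive interval disables rendering
    return timestep % render_interval == 0  # assumption: render on multiples of the interval


def resolve_disable_progressbar(value: Optional[bool], stream=sys.stderr) -> bool:
    """Map the tri-state setting to a flag: None means 'disable when the stream is not a TTY'."""
    if value is None:
        return not stream.isatty()
    return value


print([t for t in range(10) if should_render(t, headless=False, render_interval=5)])
print(resolve_disable_progressbar(None, stream=io.StringIO()))  # StringIO is not a TTY
```

With `render_interval=1` (the default) the first helper returns True on every timestep unless `headless` is set.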
- class skrl.trainers.torch.step.StepTrainer(*, env: Wrapper | MultiAgentEnvWrapper, agents: Agent | MultiAgent | list[Agent] | list[MultiAgent], scopes: list[int] | None = None, cfg: StepTrainerCfg | dict = {})
Bases: Trainer
Step-by-step trainer.
Train agents by controlling the training/evaluation loop step by step.
- Parameters:
env – Environment to train/evaluate on.
agents – Agent(s) to train/evaluate.
scopes – Number of environments for each simultaneous agent to train/evaluate on.
cfg – Configuration, given as a StepTrainerCfg instance or a dictionary.
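To make the `scopes` parameter concrete, the sketch below turns per-agent environment counts into index ranges over a vectorized environment. `scopes_to_ranges` is a hypothetical helper written for this page. Splitting the environments evenly when no scopes are given, with leftovers going to the last agent, is an assumption of the sketch.

```python
from typing import List, Optional, Tuple


def scopes_to_ranges(
    num_envs: int, num_agents: int, scopes: Optional[List[int]] = None
) -> List[Tuple[int, int]]:
    """Turn per-agent environment counts into half-open (start, end) index ranges."""
    if not scopes:
        # assumption: without explicit scopes, split evenly and give leftovers to the last agent
        base = num_envs // num_agents
        scopes = [base] * num_agents
        scopes[-1] += num_envs - base * num_agents
    if len(scopes) != num_agents:
        raise ValueError("one scope per simultaneous agent is required")
    if sum(scopes) != num_envs:
        raise ValueError("scopes must add up to the number of environments")
    ranges = []
    start = 0
    for count in scopes:
        ranges.append((start, start + count))
        start += count
    return ranges


print(scopes_to_ranges(num_envs=10, num_agents=2, scopes=[3, 7]))  # [(0, 3), (3, 10)]
print(scopes_to_ranges(num_envs=10, num_agents=3))                 # [(0, 3), (3, 6), (6, 10)]
```

Each agent would then act only on the slice of observations that falls inside its range.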
Methods:
eval([timestep, timesteps]) – Execute an evaluation iteration.
reset() – Reset the trainer.
train([timestep, timesteps]) – Execute a training iteration.
- eval(timestep: int | None = None, timesteps: int | None = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]
Execute an evaluation iteration.
This method executes the following steps once per call:
- Pre-interaction (sequentially if num_simultaneous_agents > 1)
- Compute actions (sequentially if num_simultaneous_agents > 1)
- Interact with the environments
- Render environments
- Record transitions (sequentially if num_simultaneous_agents > 1)
- Reset environments
- Parameters:
timestep – Current timestep. If None, the current timestep will be carried by an internal variable.
timesteps – Total number of timesteps. If None, it is obtained from the trainer’s config.
- Returns:
Environment’s observations, rewards, terminated, truncated and info.
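Both optional arguments fall back to internal state when omitted. The sketch below illustrates one way such a fallback can work. `ToyStepCounter` is a hypothetical helper, and details such as starting the counter at 0 and leaving it untouched when an explicit timestep is passed are assumptions of the sketch, not documented behavior.

```python
from typing import Optional, Tuple


class ToyStepCounter:
    """Hypothetical helper showing how optional `timestep`/`timesteps` arguments can be resolved."""

    def __init__(self, default_timesteps: int) -> None:
        self.default_timesteps = default_timesteps  # stands in for the trainer's config value
        self._timestep = 0  # internal counter used when the caller passes None

    def resolve(
        self, timestep: Optional[int] = None, timesteps: Optional[int] = None
    ) -> Tuple[int, int]:
        if timestep is None:
            # carry the current timestep internally and advance it for the next call
            timestep = self._timestep
            self._timestep += 1
        if timesteps is None:
            # fall back to the configured total
            timesteps = self.default_timesteps
        return timestep, timesteps


counter = ToyStepCounter(default_timesteps=50000)
print([counter.resolve()[0] for _ in range(3)])    # [0, 1, 2]
print(counter.resolve(timestep=10, timesteps=20))  # (10, 20)
```

In practice this means a caller can either pass the loop index explicitly, as in the usage example, or omit it and let the trainer keep count.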
- train(timestep: int | None = None, timesteps: int | None = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]
Execute a training iteration.
This method executes the following steps once:
- Pre-interaction (sequentially if num_simultaneous_agents > 1)
- Compute actions (sequentially if num_simultaneous_agents > 1)
- Interact with the environments
- Render environments
- Record transitions (sequentially if num_simultaneous_agents > 1)
- Post-interaction (sequentially if num_simultaneous_agents > 1)
- Reset environments
- Parameters:
timestep – Current timestep. If None, the current timestep will be carried by an internal variable.
timesteps – Total number of timesteps. If None, it is obtained from the trainer’s config.
- Returns:
Environment’s observations, rewards, terminated, truncated and info.
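The order of the training steps listed for this method can be traced with stub objects that only log which hook ran. This is a sketch for a single agent and a single, non-vectorized environment. `RecordingAgent`, `RecordingEnv` and `train_once` are hypothetical names that borrow the step names from the list, not the library's implementation.

```python
class RecordingAgent:
    """Stub agent that only records which hook was called."""

    def __init__(self, log):
        self.log = log

    def pre_interaction(self, timestep, timesteps):
        self.log.append("pre_interaction")

    def act(self, observations, timestep, timesteps):
        self.log.append("act")
        return 0  # dummy action

    def record_transition(self, **transition):
        self.log.append("record_transition")

    def post_interaction(self, timestep, timesteps):
        self.log.append("post_interaction")


class RecordingEnv:
    """Stub single environment whose episodes end after `episode_length` steps."""

    def __init__(self, log, episode_length=2):
        self.log = log
        self.episode_length = episode_length
        self._steps = 0

    def reset(self):
        self.log.append("reset")
        self._steps = 0
        return 0.0, {}

    def step(self, action):
        self.log.append("step")
        self._steps += 1
        terminated = self._steps >= self.episode_length
        return float(self._steps), 1.0, terminated, False, {}

    def render(self):
        self.log.append("render")


def train_once(env, agent, observations, timestep, timesteps, headless=False):
    """Run one training iteration in the listed order and return the step results."""
    agent.pre_interaction(timestep=timestep, timesteps=timesteps)
    actions = agent.act(observations, timestep=timestep, timesteps=timesteps)
    next_observations, rewards, terminated, truncated, infos = env.step(actions)
    if not headless:
        env.render()
    agent.record_transition(
        observations=observations, actions=actions, rewards=rewards,
        next_observations=next_observations, terminated=terminated, truncated=truncated,
    )
    agent.post_interaction(timestep=timestep, timesteps=timesteps)
    if terminated or truncated:
        # assumption: a single environment that has to be reset by hand
        next_observations, _ = env.reset()
    return next_observations, rewards, terminated, truncated, infos


log = []
env, agent = RecordingEnv(log), RecordingAgent(log)
observations, _ = env.reset()
for timestep in range(2):
    observations, *_ = train_once(env, agent, observations, timestep, timesteps=2)
print(log)
```

Because each call performs exactly one iteration, custom logic (logging, curriculum changes, early stopping) can be placed between calls in the caller's own loop.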
JAX
| Class | Description |
|---|---|
| StepTrainerCfg | Configuration for the step trainer. |
| StepTrainer | Step-by-step trainer. |
- class skrl.trainers.jax.step.StepTrainerCfg(*, timesteps: int = 100000, headless: bool = False, render_interval: int = 1, disable_progressbar: bool | None = False, close_environment_at_exit: bool = True, environment_info: str = 'episode', stochastic_evaluation: bool = False)
Bases: TrainerCfg
Configuration for the step trainer.
Attributes:
close_environment_at_exit – Whether to close the environment on normal program termination.
disable_progressbar – Whether to disable the progressbar.
environment_info – Key used to get and log environment info.
headless – Whether to run in headless mode (do not call env.render()).
render_interval – Interval (in timesteps) for rendering the environments.
stochastic_evaluation – Whether to use stochastic (sampled) actions rather than (deterministic) mean actions during evaluation.
timesteps – Number of timesteps to train/evaluate for.
- close_environment_at_exit: bool = True
Whether to close the environment on normal program termination.
- disable_progressbar: bool | None = False
Whether to disable the progressbar. If None, disable on non-TTY.
- render_interval: int = 1
Interval (in timesteps) for rendering the environments. Only effective if headless is False.
- class skrl.trainers.jax.step.StepTrainer(*, env: Wrapper | MultiAgentEnvWrapper, agents: Agent | MultiAgent | list[Agent] | list[MultiAgent], scopes: list[int] | None = None, cfg: StepTrainerCfg | dict = {})
Bases: Trainer
Step-by-step trainer.
Train agents by controlling the training/evaluation loop step by step.
- Parameters:
env – Environment to train/evaluate on.
agents – Agent(s) to train/evaluate.
scopes – Number of environments for each simultaneous agent to train/evaluate on.
cfg – Configuration, given as a StepTrainerCfg instance or a dictionary.
Methods:
eval([timestep, timesteps]) – Execute an evaluation iteration.
reset() – Reset the trainer.
train([timestep, timesteps]) – Execute a training iteration.
- eval(timestep: int | None = None, timesteps: int | None = None) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, Any]
Execute an evaluation iteration.
This method executes the following steps once per call:
- Pre-interaction (sequentially if num_simultaneous_agents > 1)
- Compute actions (sequentially if num_simultaneous_agents > 1)
- Interact with the environments
- Render environments
- Record transitions (sequentially if num_simultaneous_agents > 1)
- Reset environments
- Parameters:
timestep – Current timestep. If None, the current timestep will be carried by an internal variable.
timesteps – Total number of timesteps. If None, it is obtained from the trainer’s config.
- Returns:
Environment’s observations, rewards, terminated, truncated and info.
- train(timestep: int | None = None, timesteps: int | None = None) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, Any]
Execute a training iteration.
This method executes the following steps once:
- Pre-interaction (sequentially if num_simultaneous_agents > 1)
- Compute actions (sequentially if num_simultaneous_agents > 1)
- Interact with the environments
- Render environments
- Record transitions (sequentially if num_simultaneous_agents > 1)
- Post-interaction (sequentially if num_simultaneous_agents > 1)
- Reset environments
- Parameters:
timestep – Current timestep. If None, the current timestep will be carried by an internal variable.
timesteps – Total number of timesteps. If None, it is obtained from the trainer’s config.
- Returns:
Environment’s observations, rewards, terminated, truncated and info.
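Since every call returns the rewards and the terminated/truncated flags for that step, the caller can keep its own statistics between calls. The sketch below accumulates per-environment episodic returns. Plain Python lists stand in for the arrays a trainer would return, and `accumulate_returns` is a hypothetical helper, not part of the library.

```python
from typing import List, Sequence, Tuple


def accumulate_returns(
    running: Sequence[float],
    rewards: Sequence[float],
    terminated: Sequence[bool],
    truncated: Sequence[bool],
) -> Tuple[List[float], List[float]]:
    """Update per-environment running returns and collect the returns of finished episodes."""
    updated, finished = [], []
    for total, reward, term, trunc in zip(running, rewards, terminated, truncated):
        total += reward
        if term or trunc:
            finished.append(total)  # the episode ended on this step
            total = 0.0             # start a fresh return for the next episode
        updated.append(total)
    return updated, finished


# three fake steps for two environments: (rewards, terminated, truncated)
steps = [
    ([1.0, 0.5], [False, False], [False, False]),
    ([1.0, 0.5], [True, False], [False, False]),
    ([2.0, 0.5], [False, False], [False, True]),
]
running, completed = [0.0, 0.0], []
for rewards, terminated, truncated in steps:
    running, finished = accumulate_returns(running, rewards, terminated, truncated)
    completed.extend(finished)
print(running, completed)
```

With real arrays the same bookkeeping would typically be vectorized; the element-wise loop is used here only to keep the sketch free of dependencies.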