Trainers

Trainers orchestrate and manage the training and evaluation of agents, as well as their interactions with the environment.



Trainers               pytorch    jax
Sequential trainer        ■         ■
Parallel trainer          ■         □
Step trainer              ■         ■
Manual training           ■         ■

(■ supported, □ not supported)

Base class

Note

This is the base class for all trainers in this module. It provides the common functionality for the concrete trainer implementations and is not intended to be used directly.


Basic inheritance usage

from typing import Union, List, Optional

import copy

from skrl.envs.wrappers.torch import Wrapper
from skrl.agents.torch import Agent

from skrl.trainers.torch import Trainer


CUSTOM_DEFAULT_CONFIG = {
    "timesteps": 100000,            # number of timesteps to train for
    "headless": False,              # whether to use headless mode (no rendering)
    "disable_progressbar": False,   # whether to disable the progressbar. If None, disable on non-TTY
    "close_environment_at_exit": True,   # whether to close the environment on normal program termination
}


class CustomTrainer(Trainer):
    def __init__(self,
                 env: Wrapper,
                 agents: Union[Agent, List[Agent], List[List[Agent]]],
                 agents_scope: Optional[List[int]] = None,
                 cfg: Optional[dict] = None) -> None:
        """
        :param env: Environment to train on
        :type env: skrl.envs.wrappers.torch.Wrapper
        :param agents: Agents to train
        :type agents: Union[Agent, List[Agent], List[List[Agent]]]
        :param agents_scope: Number of environments for each agent to train on (default: None)
        :type agents_scope: list of int, optional
        :param cfg: Configuration dictionary (default: None)
        :type cfg: dict, optional
        """
        _cfg = copy.deepcopy(CUSTOM_DEFAULT_CONFIG)
        _cfg.update(cfg if cfg is not None else {})
        agents_scope = agents_scope if agents_scope is not None else []
        super().__init__(env=env, agents=agents, agents_scope=agents_scope, cfg=_cfg)

        # ================================
        # - init agents
        # ================================

    def train(self) -> None:
        """Train the agents
        """
        # ================================
        # - run training loop
        #   + call agents.pre_interaction(...)
        #   + compute actions using agents.act(...)
        #   + step environment using env.step(...)
        #   + render scene using env.render(...)
        #   + record environment transition in memory using agents.record_transition(...)
        #   + call agents.post_interaction(...)
        #   + reset environment using env.reset(...)
        # ================================

    def eval(self) -> None:
        """Evaluate the agents
        """
        # ================================
        # - run evaluation loop
        #   + compute actions using agents.act(...)
        #   + step environment using env.step(...)
        #   + render scene using env.render(...)
        #   + call agents.post_interaction(...) parent method to write data to TensorBoard
        #   + reset environment using env.reset(...)
        # ================================

API (PyTorch)

class skrl.trainers.torch.base.Trainer(env: Wrapper, agents: Agent | List[Agent], agents_scope: List[int] | None = None, cfg: dict | None = None)

Bases: object

__init__(env: Wrapper, agents: Agent | List[Agent], agents_scope: List[int] | None = None, cfg: dict | None = None) None

Base class for trainers

Parameters:
  • env (skrl.envs.wrappers.torch.Wrapper) – Environment to train on

  • agents (Union[Agent, List[Agent]]) – Agents to train

  • agents_scope (tuple or list of int, optional) – Number of environments for each agent to train on (default: None)

  • cfg (dict, optional) – Configuration dictionary (default: None)

__str__() str

Generate a string representation of the trainer

Returns:

Representation of the trainer as string

Return type:

str

_setup_agents() None

Set up agents for training

Raises:

ValueError – Invalid setup
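_setup_agents validates how simultaneous agents are mapped onto sub-ranges of the vectorized environments via agents_scope. The helper below is an illustrative sketch of this kind of scope partitioning under the assumption that scopes are per-agent environment counts; it is not skrl's actual implementation:

```python
def partition_envs(num_envs, num_agents, agents_scope=None):
    """Map agents onto (start, end) environment index ranges.

    agents_scope lists how many environments each agent controls; when it
    is empty or None, environments are divided as evenly as possible.
    """
    if num_agents < 2:
        return [(0, num_envs)]
    if not agents_scope:
        base, extra = divmod(num_envs, num_agents)
        agents_scope = [base + (1 if i < extra else 0) for i in range(num_agents)]
    if len(agents_scope) != num_agents or sum(agents_scope) != num_envs:
        raise ValueError("Invalid setup: scopes do not match agents/environments")
    # convert per-agent counts into contiguous (start, end) index ranges
    bounds, start = [], 0
    for count in agents_scope:
        bounds.append((start, start + count))
        start += count
    return bounds


print(partition_envs(10, 3))           # → [(0, 4), (4, 7), (7, 10)]
print(partition_envs(10, 2, [6, 4]))   # → [(0, 6), (6, 10)]
```

A mismatch between the scopes and the number of environments raises ValueError, matching the "Invalid setup" condition documented above.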

eval() None

Evaluate the agents

Raises:

NotImplementedError – Not implemented

multi_agent_eval() None

Evaluate multi-agents

This method executes the following steps in a loop:

  • Compute actions (sequentially)

  • Interact with the environments

  • Render scene

  • Reset environments

multi_agent_train() None

Train multi-agents

This method executes the following steps in a loop:

  • Pre-interaction

  • Compute actions

  • Interact with the environments

  • Render scene

  • Record transitions

  • Post-interaction

  • Reset environments

single_agent_eval() None

Evaluate agent

This method executes the following steps in a loop:

  • Compute actions (sequentially)

  • Interact with the environments

  • Render scene

  • Reset environments

single_agent_train() None

Train agent

This method executes the following steps in a loop:

  • Pre-interaction

  • Compute actions

  • Interact with the environments

  • Render scene

  • Record transitions

  • Post-interaction

  • Reset environments

train() None

Train the agents

Raises:

NotImplementedError – Not implemented


API (JAX)

class skrl.trainers.jax.base.Trainer(env: Wrapper, agents: Agent | List[Agent], agents_scope: List[int] | None = None, cfg: dict | None = None)

Bases: object

__init__(env: Wrapper, agents: Agent | List[Agent], agents_scope: List[int] | None = None, cfg: dict | None = None) None

Base class for trainers

Parameters:
  • env (skrl.envs.wrappers.jax.Wrapper) – Environment to train on

  • agents (Union[Agent, List[Agent]]) – Agents to train

  • agents_scope (tuple or list of int, optional) – Number of environments for each agent to train on (default: None)

  • cfg (dict, optional) – Configuration dictionary (default: None)

__str__() str

Generate a string representation of the trainer

Returns:

Representation of the trainer as string

Return type:

str

_setup_agents() None

Set up agents for training

Raises:

ValueError – Invalid setup

eval() None

Evaluate the agents

Raises:

NotImplementedError – Not implemented

multi_agent_eval() None

Evaluate multi-agents

This method executes the following steps in a loop:

  • Compute actions (sequentially)

  • Interact with the environments

  • Render scene

  • Reset environments

multi_agent_train() None

Train multi-agents

This method executes the following steps in a loop:

  • Pre-interaction

  • Compute actions

  • Interact with the environments

  • Render scene

  • Record transitions

  • Post-interaction

  • Reset environments

single_agent_eval() None

Evaluate agent

This method executes the following steps in a loop:

  • Compute actions (sequentially)

  • Interact with the environments

  • Render scene

  • Reset environments

single_agent_train() None

Train agent

This method executes the following steps in a loop:

  • Pre-interaction

  • Compute actions

  • Interact with the environments

  • Render scene

  • Record transitions

  • Post-interaction

  • Reset environments

train() None

Train the agents

Raises:

NotImplementedError – Not implemented