Trainers¶
Trainers are responsible for orchestrating and managing the training/evaluation of agents and their interactions with the environment.
Trainers | PyTorch | JAX
---|---|---
Sequential trainer | \(\blacksquare\) | \(\blacksquare\)
Parallel trainer | \(\blacksquare\) | \(\square\)
Step trainer | \(\blacksquare\) | \(\blacksquare\)
Manual training | \(\blacksquare\) | \(\blacksquare\)
(\(\blacksquare\): supported, \(\square\): not supported)
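In practice, a trainer is built from a wrapped environment, one or more agents and an optional configuration dictionary, and training or evaluation is started by calling its train() or eval() method. The following minimal sketch uses the sequential trainer (PyTorch); the env and agent objects are placeholders for an already wrapped environment and an already instantiated agent:

from skrl.trainers.torch import SequentialTrainer

# assumptions: `env` is an environment already wrapped with skrl's wrap_env
# and `agent` is an already instantiated and configured agent
cfg_trainer = {"timesteps": 100000, "headless": True}
trainer = SequentialTrainer(env=env, agents=agent, cfg=cfg_trainer)

# train and/or evaluate the agent
trainer.train()
trainer.eval()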
Base class¶
Note
This is the base class for all the other classes in this module. It provides the basic functionality for the other classes. It is not intended to be used directly.
Basic inheritance usage¶
from typing import Union, List, Optional

import copy

from skrl.envs.wrappers.torch import Wrapper
from skrl.agents.torch import Agent
from skrl.trainers.torch import Trainer


CUSTOM_DEFAULT_CONFIG = {
    "timesteps": 100000,                # number of timesteps to train for
    "headless": False,                  # whether to use headless mode (no rendering)
    "disable_progressbar": False,       # whether to disable the progressbar. If None, disable on non-TTY
    "close_environment_at_exit": True,  # whether to close the environment on normal program termination
}


class CustomTrainer(Trainer):
    def __init__(self,
                 env: Wrapper,
                 agents: Union[Agent, List[Agent], List[List[Agent]]],
                 agents_scope: Optional[List[int]] = None,
                 cfg: Optional[dict] = None) -> None:
        """
        :param env: Environment to train on
        :type env: skrl.envs.wrappers.torch.Wrapper
        :param agents: Agents to train
        :type agents: Union[Agent, List[Agent], List[List[Agent]]]
        :param agents_scope: Number of environments for each agent to train on (default: [])
        :type agents_scope: list of int, optional
        :param cfg: Configuration dictionary
        :type cfg: dict, optional
        """
        _cfg = copy.deepcopy(CUSTOM_DEFAULT_CONFIG)
        _cfg.update(cfg if cfg is not None else {})
        agents_scope = agents_scope if agents_scope is not None else []
        super().__init__(env=env, agents=agents, agents_scope=agents_scope, cfg=_cfg)

        # ================================
        # - init agents
        # ================================

    def train(self) -> None:
        """Train the agents"""
        # ================================
        # - run training loop
        #   + call agents.pre_interaction(...)
        #   + compute actions using agents.act(...)
        #   + step environment using env.step(...)
        #   + render scene using env.render(...)
        #   + record environment transition in memory using agents.record_transition(...)
        #   + call agents.post_interaction(...)
        #   + reset environment using env.reset(...)
        # ================================

    def eval(self) -> None:
        """Evaluate the agents"""
        # ================================
        # - run evaluation loop
        #   + compute actions using agents.act(...)
        #   + step environment using env.step(...)
        #   + render scene using env.render(...)
        #   + call agents.post_interaction(...) parent method to write data to TensorBoard
        #   + reset environment using env.reset(...)
        # ================================
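The commented steps in the train() stub above can be filled in using the agent and environment methods referenced in the comments. The following is a simplified, single-agent PyTorch sketch (not the implementation of the built-in trainers): it assumes the agent was already initialized in __init__ (e.g. self.agents.init(trainer_cfg=self.cfg)), that the wrapped environment returns batched tensors, and it is meant to be placed inside CustomTrainer with the tqdm and torch imports added at module level.

import tqdm
import torch

# possible body for CustomTrainer.train() (single-agent case)
def train(self) -> None:
    """Train the agents"""
    # switch the agent's models to training mode (mirrors the built-in trainers)
    self.agents.set_running_mode("train")

    # reset the environment
    states, infos = self.env.reset()

    for timestep in tqdm.tqdm(range(self.timesteps), disable=self.disable_progressbar):
        # pre-interaction (prepare the agent before stepping)
        self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)

        with torch.no_grad():
            # compute actions and step the environment
            actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
            next_states, rewards, terminated, truncated, infos = self.env.step(actions)

            # render the scene unless running headless
            if not self.headless:
                self.env.render()

            # record the environment transition in the agent's memory
            self.agents.record_transition(states=states,
                                          actions=actions,
                                          rewards=rewards,
                                          next_states=next_states,
                                          terminated=terminated,
                                          truncated=truncated,
                                          infos=infos,
                                          timestep=timestep,
                                          timesteps=self.timesteps)

        # post-interaction (update the agent and write tracking data)
        self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)

        # reset the environment when any sub-environment has finished
        with torch.no_grad():
            if terminated.any() or truncated.any():
                states, infos = self.env.reset()
            else:
                states = next_states

The JAX variant of the same inheritance template is shown below.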
from typing import Union, List, Optional

import copy

from skrl.envs.wrappers.jax import Wrapper
from skrl.agents.jax import Agent
from skrl.trainers.jax import Trainer


CUSTOM_DEFAULT_CONFIG = {
    "timesteps": 100000,                # number of timesteps to train for
    "headless": False,                  # whether to use headless mode (no rendering)
    "disable_progressbar": False,       # whether to disable the progressbar. If None, disable on non-TTY
    "close_environment_at_exit": True,  # whether to close the environment on normal program termination
}


class CustomTrainer(Trainer):
    def __init__(self,
                 env: Wrapper,
                 agents: Union[Agent, List[Agent], List[List[Agent]]],
                 agents_scope: Optional[List[int]] = None,
                 cfg: Optional[dict] = None) -> None:
        """
        :param env: Environment to train on
        :type env: skrl.envs.wrappers.jax.Wrapper
        :param agents: Agents to train
        :type agents: Union[Agent, List[Agent], List[List[Agent]]]
        :param agents_scope: Number of environments for each agent to train on (default: [])
        :type agents_scope: list of int, optional
        :param cfg: Configuration dictionary
        :type cfg: dict, optional
        """
        _cfg = copy.deepcopy(CUSTOM_DEFAULT_CONFIG)
        _cfg.update(cfg if cfg is not None else {})
        agents_scope = agents_scope if agents_scope is not None else []
        super().__init__(env=env, agents=agents, agents_scope=agents_scope, cfg=_cfg)

        # ================================
        # - init agents
        # ================================

    def train(self) -> None:
        """Train the agents"""
        # ================================
        # - run training loop
        #   + call agents.pre_interaction(...)
        #   + compute actions using agents.act(...)
        #   + step environment using env.step(...)
        #   + render scene using env.render(...)
        #   + record environment transition in memory using agents.record_transition(...)
        #   + call agents.post_interaction(...)
        #   + reset environment using env.reset(...)
        # ================================

    def eval(self) -> None:
        """Evaluate the agents"""
        # ================================
        # - run evaluation loop
        #   + compute actions using agents.act(...)
        #   + step environment using env.step(...)
        #   + render scene using env.render(...)
        #   + call agents.post_interaction(...) parent method to write data to TensorBoard
        #   + reset environment using env.reset(...)
        # ================================
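Whichever backend is used, the resulting custom trainer is instantiated and driven in the same way as the built-in trainers. In the following sketch, env and agent are placeholders for an already wrapped environment and an already configured agent:

# assumptions: `env` is a wrapped environment and `agent` an instantiated agent
trainer = CustomTrainer(env=env, agents=agent, cfg={"timesteps": 50000, "headless": True})

# run the loops implemented in train() / eval()
trainer.train()
trainer.eval()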
API (PyTorch)¶
- class skrl.trainers.torch.base.Trainer(env: Wrapper, agents: Agent | List[Agent], agents_scope: List[int] | None = None, cfg: dict | None = None)¶
Bases: object
- __init__(env: Wrapper, agents: Agent | List[Agent], agents_scope: List[int] | None = None, cfg: dict | None = None) → None ¶
Base class for trainers
- Parameters:
  - env (skrl.envs.wrappers.torch.Wrapper) – Environment to train on
  - agents (Agent or list of Agent) – Agents to train
  - agents_scope (list of int, optional) – Number of environments for each agent to train on (default: [])
  - cfg (dict, optional) – Configuration dictionary
- __str__() → str ¶
Generate a string representation of the trainer
- Returns:
  Representation of the trainer as a string
- Return type:
  str
- _setup_agents() → None ¶
Set up the agents for training
- Raises:
ValueError – Invalid setup
- eval() → None ¶
Evaluate the agents
- Raises:
NotImplementedError – Not implemented
- multi_agent_eval() → None ¶
Evaluate multi-agents
This method executes the following steps in a loop:
  - Compute actions (sequentially)
  - Interact with the environments
  - Render scene
  - Reset environments
- multi_agent_train() → None ¶
Train multi-agents
This method executes the following steps in a loop:
  - Pre-interaction
  - Compute actions
  - Interact with the environments
  - Render scene
  - Record transitions
  - Post-interaction
  - Reset environments
- single_agent_eval() → None ¶
Evaluate agent
This method executes the following steps in a loop:
  - Compute actions (sequentially)
  - Interact with the environments
  - Render scene
  - Reset environments
- single_agent_train() → None ¶
Train agent
This method executes the following steps in a loop:
  - Pre-interaction
  - Compute actions
  - Interact with the environments
  - Render scene
  - Record transitions
  - Post-interaction
  - Reset environments
- train() → None ¶
Train the agents
- Raises:
NotImplementedError – Not implemented
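Since the base class already implements generic loops in single_agent_train/eval and multi_agent_train/eval, a derived trainer may simply delegate to them rather than re-implementing the step sequence. A minimal single-agent sketch follows; the init and set_running_mode calls mirror the pattern of the built-in trainers and are assumptions here, not requirements of the base class:

from skrl.trainers.torch import Trainer

class DelegatingTrainer(Trainer):
    def __init__(self, env, agents, agents_scope=None, cfg=None):
        super().__init__(env=env,
                         agents=agents,
                         agents_scope=agents_scope if agents_scope is not None else [],
                         cfg=cfg if cfg is not None else {})
        # initialize the (single) agent with the trainer configuration
        self.agents.init(trainer_cfg=self.cfg)

    def train(self) -> None:
        self.agents.set_running_mode("train")
        # pre-interaction, act, step, render, record, post-interaction, reset (see above)
        self.single_agent_train()

    def eval(self) -> None:
        self.agents.set_running_mode("eval")
        self.single_agent_eval()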
API (JAX)¶
- class skrl.trainers.jax.base.Trainer(env: Wrapper, agents: Agent | List[Agent], agents_scope: List[int] | None = None, cfg: dict | None = None)¶
Bases: object
- __init__(env: Wrapper, agents: Agent | List[Agent], agents_scope: List[int] | None = None, cfg: dict | None = None) → None ¶
Base class for trainers
- Parameters:
  - env (skrl.envs.wrappers.jax.Wrapper) – Environment to train on
  - agents (Agent or list of Agent) – Agents to train
  - agents_scope (list of int, optional) – Number of environments for each agent to train on (default: [])
  - cfg (dict, optional) – Configuration dictionary
- __str__() → str ¶
Generate a string representation of the trainer
- Returns:
  Representation of the trainer as a string
- Return type:
  str
- _setup_agents() → None ¶
Set up the agents for training
- Raises:
ValueError – Invalid setup
- eval() → None ¶
Evaluate the agents
- Raises:
NotImplementedError – Not implemented
- multi_agent_eval() → None ¶
Evaluate multi-agents
This method executes the following steps in a loop:
  - Compute actions (sequentially)
  - Interact with the environments
  - Render scene
  - Reset environments
- multi_agent_train() → None ¶
Train multi-agents
This method executes the following steps in a loop:
  - Pre-interaction
  - Compute actions
  - Interact with the environments
  - Render scene
  - Record transitions
  - Post-interaction
  - Reset environments
- single_agent_eval() → None ¶
Evaluate agent
This method executes the following steps in a loop:
  - Compute actions (sequentially)
  - Interact with the environments
  - Render scene
  - Reset environments
- single_agent_train() → None ¶
Train agent
This method executes the following steps in a loop:
  - Pre-interaction
  - Compute actions
  - Interact with the environments
  - Render scene
  - Record transitions
  - Post-interaction
  - Reset environments
- train() → None ¶
Train the agents
- Raises:
NotImplementedError – Not implemented