Base class

Note

This is the base class for all the other classes in this module. It provides the basic functionality for the other classes. It is not intended to be used directly.

Basic inheritance usage

from typing import Union, List, Optional

import copy

from skrl.envs.torch import Wrapper   # from ...envs.torch import Wrapper
from skrl.agents.torch import Agent   # from ...agents.torch import Agent

from skrl.trainers.torch import Trainer       # from . import Trainer
10
# Default trainer configuration; callers may override any subset of these keys
# via the ``cfg`` argument of ``CustomTrainer`` (unspecified keys keep these values).
CUSTOM_DEFAULT_CONFIG = {
    "timesteps": 100000,            # number of timesteps to train for
    "headless": False,              # whether to use headless mode (no rendering)
    "disable_progressbar": False,   # whether to disable the progressbar. If None, disable on non-TTY
}
16
17
class CustomTrainer(Trainer):
    """Skeleton of a custom trainer built by inheriting from the base ``Trainer`` class.

    The ``train`` and ``eval`` bodies are intentionally left as commented
    step-by-step outlines to be filled in by the implementer.
    """
    def __init__(self,
                 env: Wrapper,
                 agents: Union[Agent, List[Agent], List[List[Agent]]],
                 agents_scope: Optional[List[int]] = None,
                 cfg: Optional[dict] = None) -> None:
        """
        :param env: Environment to train on
        :type env: skrl.envs.torch.Wrapper
        :param agents: Agents to train (a single agent, a list of agents, or a list of agent groups)
        :type agents: Union[Agent, List[Agent], List[List[Agent]]]
        :param agents_scope: Number of environments for each agent to train on
                             (default: None, treated as an empty list)
        :type agents_scope: list of integers, optional
        :param cfg: Configuration dictionary; missing keys fall back to ``CUSTOM_DEFAULT_CONFIG``
                    (default: None)
        :type cfg: dict, optional
        """
        # deep-copy the defaults so user overrides never mutate the module-level dict
        _cfg = copy.deepcopy(CUSTOM_DEFAULT_CONFIG)
        _cfg.update(cfg if cfg is not None else {})
        # resolve the None default here instead of using a mutable default argument
        agents_scope = agents_scope if agents_scope is not None else []
        super().__init__(env=env, agents=agents, agents_scope=agents_scope, cfg=_cfg)

        # ================================
        # - init agents
        # ================================

    def train(self) -> None:
        """Train the agents
        """
        # ================================
        # - run training loop
        #   + call agents.pre_interaction(...)
        #   + compute actions using agents.act(...)
        #   + step environment using env.step(...)
        #   + render scene using env.render(...)
        #   + record environment transition in memory using agents.record_transition(...)
        #   + call agents.post_interaction(...)
        #   + reset environment using env.reset(...)
        # ================================

    def eval(self) -> None:
        """Evaluate the agents
        """
        # ================================
        # - run evaluation loop
        #   + compute actions using agents.act(...)
        #   + step environment using env.step(...)
        #   + render scene using env.render(...)
        #   + call agents.post_interaction(...) parent method to write data to TensorBoard
        #   + reset environment using env.reset(...)
        # ================================

API