ML frameworks configuration

Configurations for modifying the behavior of Machine Learning (ML) frameworks.



API


PyTorch

skrl.config.torch.parse_device(device: str | 'torch.device' | None, validate: bool = True) → torch.device

Parse the input device and return a device instance.

Parameters:
  • device – Device specification. If the specified device is None or it cannot be resolved, the default available device will be returned instead.

  • validate – Whether to check that the specified device is valid. Since PyTorch does not check whether the specified device index is valid, a tensor is created on the device to verify it.

Returns:

PyTorch device.
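The fallback rules above can be sketched without PyTorch. The snippet below is a hypothetical, stdlib-only illustration: `parse_device_spec`, `default_device`, `cuda_available` and `cuda_device_count` are assumptions, not skrl API. It resolves a "type[:index]" string, returns the default device when the specification is None or cannot be resolved, and, when `validate` is set, checks the index against an injected device count instead of allocating a tensor as described for skrl. It also assumes a non-distributed run, so the default is cuda:0 or cpu.

```python
from typing import Optional

# Device types understood by this illustration only
# (PyTorch itself accepts more, e.g. "mps" or "xpu").
KNOWN_TYPES = {"cpu", "cuda"}


def default_device(cuda_available: bool) -> str:
    # Non-distributed default: cuda:0 if CUDA is available, cpu otherwise
    return "cuda:0" if cuda_available else "cpu"


def parse_device_spec(
    spec: Optional[str],
    cuda_available: bool,
    cuda_device_count: int,
    validate: bool = True,
) -> str:
    """Hypothetical sketch: resolve a "type[:index]" string or fall back to the default."""
    fallback = default_device(cuda_available)
    if spec is None:
        return fallback
    kind, _, index = spec.partition(":")
    # Unresolvable specification (unknown type or non-numeric index) -> default device
    if kind not in KNOWN_TYPES or (index and not index.isdigit()):
        return fallback
    if validate and kind == "cuda":
        # skrl validates by creating a tensor on the device; this sketch
        # compares the index against an injected device count instead
        if not cuda_available or int(index or 0) >= cuda_device_count:
            return fallback
    return spec
```

Without validation, an out-of-range index such as "cuda:5" on a two-GPU machine is passed through unchanged, which is the case the validate parameter exists to catch.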

skrl.config.torch.device: torch.device = "cuda:${LOCAL_RANK}" | "cpu"

Default device.

The default device, unless specified, is cuda:0 (or cuda:LOCAL_RANK in a distributed environment) if CUDA is available, cpu otherwise.

skrl.config.torch.key: int = 0

Pseudo-random number generator (PRNG) key.

skrl.config.torch.local_rank: int = 0

The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node).

This property reads from the LOCAL_RANK environment variable (0 if it doesn’t exist). See torch.distributed for more details.

Read-only attribute.

skrl.config.torch.rank: int = 0

The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes).

This property reads from the RANK environment variable (0 if it doesn’t exist). See torch.distributed for more details.

Read-only attribute.

skrl.config.torch.world_size: int = 1

The total number of workers/processes (e.g.: GPUs) in a worker group (e.g.: across all nodes).

This property reads from the WORLD_SIZE environment variable (1 if it doesn’t exist). See torch.distributed for more details.

Read-only attribute.

skrl.config.torch.is_distributed: bool = False

Whether the process is running in a distributed environment.

This property is True when PyTorch’s distributed environment variable WORLD_SIZE is greater than 1.

Read-only attribute.
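The PyTorch attributes above are all derived from environment variables that launchers such as torchrun set for every worker they start. The following stdlib-only sketch reproduces the documented defaults and the is_distributed rule; `read_torch_distributed_config` and `default_device` are hypothetical helper names, not part of skrl, and CUDA availability is passed in rather than queried.

```python
import os
from typing import Mapping, Optional


def read_torch_distributed_config(env: Optional[Mapping[str, str]] = None) -> dict:
    """Hypothetical sketch: read LOCAL_RANK, RANK and WORLD_SIZE with the documented defaults."""
    env = os.environ if env is None else env
    local_rank = int(env.get("LOCAL_RANK", "0"))
    rank = int(env.get("RANK", "0"))
    world_size = int(env.get("WORLD_SIZE", "1"))
    return {
        "local_rank": local_rank,
        "rank": rank,
        "world_size": world_size,
        # Distributed only when there is more than one worker
        "is_distributed": world_size > 1,
    }


def default_device(local_rank: int, cuda_available: bool) -> str:
    # cuda:LOCAL_RANK (cuda:0 outside a distributed run) if CUDA is available, cpu otherwise
    return f"cuda:{local_rank}" if cuda_available else "cpu"
```

For example, on a job with 2 nodes of 4 GPUs each, the worker on the third GPU of the second node would see LOCAL_RANK=2, RANK=6 and WORLD_SIZE=8, giving cuda:2 as its default device.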


JAX

skrl.config.jax.parse_device(device: str | 'jax.Device' | None) → jax.Device

Parse the input device and return a Device instance.

Hint

This function supports the PyTorch-like "type:ordinal" string specification (e.g.: "cuda:0").

Warning

In a distributed environment, this method returns (i.e., forces the use of) the device local to the process.

Parameters:

device – Device specification. If the specified device is None or it cannot be resolved, the default available device will be returned instead.

Returns:

JAX Device.
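The PyTorch-like "type:ordinal" convention from the hint can be sketched in plain Python. In this hypothetical illustration, `split_type_ordinal` and `resolve_device` are not skrl API, and the `available` mapping stands in for the device list a real implementation would obtain from JAX (its labels are made up). Treating a missing ordinal as 0 is an assumption of the sketch.

```python
from typing import Dict, List, Optional, Tuple


def split_type_ordinal(spec: str) -> Tuple[str, int]:
    """Split a "type:ordinal" string; a missing ordinal defaults to 0."""
    # Appending ":0" makes "cpu" behave like "cpu:0" while "cuda:1" keeps its ordinal
    device_type, ordinal = f"{spec}:0".split(":")[:2]
    return device_type, int(ordinal)


def resolve_device(
    spec: Optional[str], available: Dict[str, List[str]], default: str
) -> str:
    """Pick a device from a mapping of type -> devices, else return the default."""
    if spec is None:
        return default
    try:
        device_type, ordinal = split_type_ordinal(spec)
        return available[device_type][ordinal]
    except (KeyError, IndexError, ValueError):
        # Unknown type, out-of-range ordinal or non-integer ordinal
        return default
```

With available = {"cpu": ["CPU 0"], "cuda": ["GPU 0", "GPU 1"]}, the specification "cuda:1" selects "GPU 1", while "cuda:7" or "tpu:0" falls back to the default.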

skrl.config.jax.device: jax.Device = "cuda:${JAX_LOCAL_RANK}" | "cpu"

Default device.

The default device, unless specified, is cuda:0 if CUDA is available, cpu otherwise. In a distributed environment, however, it is the device local to the process with index JAX_RANK.

skrl.config.jax.key: jax.Array = [0, 0]

Pseudo-random number generator (PRNG) key.

The key is formatted as 32-bit unsigned integers and is placed on the default device.
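The default value [0, 0] is a pair of 32-bit unsigned words. As an illustration, and assuming JAX's default threefry PRNG implementation, jax.random.PRNGKey(seed) lays a non-negative integer seed out as [high 32 bits, low 32 bits], so PRNGKey(0) is [0, 0] and PRNGKey(42) is [0, 42]. The stdlib sketch below reproduces only that layout; `seed_to_key_words` is a hypothetical helper, not a JAX or skrl function.

```python
from typing import Tuple

UINT32_MASK = 0xFFFFFFFF  # largest 32-bit unsigned value


def seed_to_key_words(seed: int) -> Tuple[int, int]:
    """Split a non-negative seed below 2**64 into (high, low) 32-bit unsigned words."""
    if not 0 <= seed < 2**64:
        raise ValueError("seed must be in the range [0, 2**64)")
    high = (seed >> 32) & UINT32_MASK  # upper 32 bits
    low = seed & UINT32_MASK  # lower 32 bits
    return high, low
```

Under this layout, seed_to_key_words(0) gives (0, 0), matching the documented default key.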

skrl.config.jax.local_rank: int = 0

The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node).

This property reads from the JAX_LOCAL_RANK environment variable (0 if it doesn’t exist).

Read-only attribute.

skrl.config.jax.rank: int = 0

The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes).

This property reads from the JAX_RANK environment variable (0 if it doesn’t exist).

Read-only attribute.

skrl.config.jax.world_size: int = 1

The total number of workers/processes (e.g.: GPUs) in a worker group (e.g.: across all nodes).

This property reads from the JAX_WORLD_SIZE environment variable (1 if it doesn’t exist).

Read-only attribute.

skrl.config.jax.coordinator_address: str = "127.0.0.1:1234"

IP address and port where process 0 will start a JAX service.

This property is built from the JAX_COORDINATOR_ADDR and JAX_COORDINATOR_PORT environment variables as JAX_COORDINATOR_ADDR:JAX_COORDINATOR_PORT (127.0.0.1:1234 if they don’t exist).

Read-only attribute.

skrl.config.jax.is_distributed: bool = False

Whether the process is running in a distributed environment.

This property is True when the JAX distributed environment variable JAX_WORLD_SIZE is greater than 1.

Read-only attribute.
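The JAX attributes above follow the same environment-variable pattern, using JAX_-prefixed names. The stdlib sketch below reproduces the documented defaults; `read_jax_distributed_config` is a hypothetical helper, and falling back to 127.0.0.1 and 1234 independently for each missing variable is an assumption about how "if they don’t exist" is handled.

```python
import os
from typing import Mapping, Optional


def read_jax_distributed_config(env: Optional[Mapping[str, str]] = None) -> dict:
    """Hypothetical sketch: read the JAX_* variables with the documented defaults."""
    env = os.environ if env is None else env
    world_size = int(env.get("JAX_WORLD_SIZE", "1"))
    # Assumption: address and port fall back to their defaults independently
    addr = env.get("JAX_COORDINATOR_ADDR", "127.0.0.1")
    port = env.get("JAX_COORDINATOR_PORT", "1234")
    return {
        "local_rank": int(env.get("JAX_LOCAL_RANK", "0")),
        "rank": int(env.get("JAX_RANK", "0")),
        "world_size": world_size,
        # "ADDR:PORT" where process 0 starts the JAX service
        "coordinator_address": f"{addr}:{port}",
        # Distributed only when there is more than one worker
        "is_distributed": world_size > 1,
    }
```

Values of this shape line up with the coordinator_address, num_processes and process_id parameters of jax.distributed.initialize, although how skrl passes them on is not described here.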


Warp

skrl.config.warp.parse_device(device: str | 'warp.Device' | None) → warp.Device

Parse the input device and return a Device instance.

Parameters:

device – Device specification. If the specified device is None or it cannot be resolved, the default available device will be returned instead.

Returns:

Warp Device.

skrl.config.warp.device: warp.context.Device = "cuda:0" | "cpu"

Default device.

The default device, unless specified, is cuda:0 if CUDA is available, cpu otherwise.

skrl.config.warp.key: int = 0

Pseudo-random number generator (PRNG) key.