Noises

Definition of the noises used by the agents during the exploration stage. All noises inherit from a base class that defines a uniform interface.
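During exploration, an agent typically adds sampled noise to the action produced by its policy and clips the result to the action bounds. The following is a minimal stdlib-only sketch that assumes a flat list of actions, a Gaussian noise with a hypothetical `std=0.1`, and bounds of [-1, 1]. The functions `gaussian_sample` and `explore` are illustrative and are not part of skrl:

```python
import random
from typing import List


def gaussian_sample(size: int, mean: float = 0.0, std: float = 0.1) -> List[float]:
    """Draw `size` independent Gaussian values (stand-in for a noise's sample method)."""
    return [random.gauss(mean, std) for _ in range(size)]


def explore(actions: List[float], low: float = -1.0, high: float = 1.0) -> List[float]:
    """Add noise of the same size to each action, then clip to [low, high]."""
    noise = gaussian_sample(len(actions))
    return [min(max(a + n, low), high) for a, n in zip(actions, noise)]


random.seed(42)
noisy = explore([0.5, -0.95, 0.0])
```

Clipping after adding the noise keeps the perturbed actions valid even when the policy output already sits near a bound.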



Noises                      pytorch   jax
Gaussian noise              ■         ■
Ornstein-Uhlenbeck noise    ■         ■
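The two noises differ in temporal correlation. Gaussian noise draws each sample independently. Ornstein-Uhlenbeck noise carries an internal state that decays toward a mean, so consecutive samples are correlated. The following is a minimal scalar sketch of the discretized update x ← x + θ(μ − x) + σ·ε, with ε ~ N(0, 1) and a unit time step. The class name `OUProcess` and the parameter values are illustrative assumptions, not the skrl implementation:

```python
import random
from typing import List


class OUProcess:
    """Scalar Ornstein-Uhlenbeck process with unit time step (illustrative sketch)."""

    def __init__(self, theta: float = 0.15, sigma: float = 0.2, mu: float = 0.0) -> None:
        self.theta = theta  # mean-reversion rate
        self.sigma = sigma  # scale of the random perturbation
        self.mu = mu        # long-run mean the state reverts to
        self.state = mu

    def sample(self) -> float:
        # drift toward mu plus a Gaussian kick: the new state depends on the old one
        self.state += self.theta * (self.mu - self.state) + self.sigma * random.gauss(0.0, 1.0)
        return self.state

    def reset(self) -> None:
        # typically called at the start of each episode
        self.state = self.mu


def lag1_autocorr(xs: List[float]) -> float:
    """Lag-1 autocorrelation of a sequence."""
    n = len(xs)
    m = sum(xs) / n
    num = sum((xs[i] - m) * (xs[i + 1] - m) for i in range(n - 1))
    den = sum((x - m) ** 2 for x in xs)
    return num / den


random.seed(0)
ou = OUProcess()
ou_seq = [ou.sample() for _ in range(5000)]
gauss_seq = [random.gauss(0.0, 0.2) for _ in range(5000)]
```

With θ = 0.15 the update is an AR(1) process with coefficient 1 − θ = 0.85, so the OU sequence shows a lag-1 autocorrelation near 0.85. The independent Gaussian sequence stays near zero. This correlation is why OU noise is often used for smoother exploration in continuous control.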

Base class

Note

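This is the base class for all the other classes in this module. It defines the interface shared by every noise: the sample method, which subclasses must implement, and the sample_like method, which is built on top of it. It is not intended to be used directly; inherit from it instead.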
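The note describes a base class that fixes the interface and leaves sampling to subclasses. That pattern can be sketched in plain Python. This is a hypothetical stand-in, not skrl's source: `BaseNoise` and `ConstantNoise` are illustrative names, and shapes are plain tuples instead of tensor sizes.

```python
from math import prod
from typing import List, Tuple


class BaseNoise:
    """Defines the interface; sampling is left to subclasses."""

    def sample(self, size: Tuple[int, ...]) -> List[float]:
        raise NotImplementedError("The sampling method has not been implemented by the inheriting class")


class ConstantNoise(BaseNoise):
    """Trivial subclass: returns a flat list filled with a fixed value."""

    def __init__(self, value: float) -> None:
        self.value = value

    def sample(self, size: Tuple[int, ...]) -> List[float]:
        # one entry per element of the requested shape
        return [self.value] * prod(size)
```

Calling `BaseNoise().sample((2,))` raises `NotImplementedError`, so the base class is only useful through inheritance.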


Basic inheritance usage

from typing import Optional, Tuple, Union

import torch

from skrl.resources.noises.torch import Noise


class CustomNoise(Noise):
    def __init__(self, device: Optional[Union[str, torch.device]] = None) -> None:
        """
        :param device: Device on which a torch tensor is or will be allocated (default: None).
                       If None, the device will be either "cuda" if available or "cpu"
        :type device: str or torch.device, optional
        """
        super().__init__(device)

    def sample(self, size: Union[Tuple[int], torch.Size]) -> torch.Tensor:
        """Sample noise

        :param size: Shape of the sampled tensor
        :type size: tuple or list of integers, or torch.Size

        :return: Sampled noise
        :rtype: torch.Tensor
        """
        # ================================
        # - sample noise
        # ================================
        # placeholder: standard Gaussian noise on the configured device;
        # replace with the custom sampling logic
        return torch.randn(size, device=self.device)

API (PyTorch)

class skrl.resources.noises.torch.base.Noise(device: str | torch.device | None = None)

Bases: object

__init__(device: str | torch.device | None = None) → None

Base class representing a noise

Parameters:

device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"
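The device fallback described above amounts to a simple rule. The sketch below uses a hypothetical helper `resolve_device`. CUDA availability is passed in as a flag so the logic can be shown without a GPU; with PyTorch installed, that flag would typically come from `torch.cuda.is_available()`.

```python
from typing import Optional


def resolve_device(device: Optional[str], cuda_available: bool) -> str:
    """Return the requested device, or fall back to "cuda"/"cpu" when none is given."""
    if device is not None:
        return device  # an explicit choice is used as-is
    return "cuda" if cuda_available else "cpu"
```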

Custom noises should override the sample method:

import torch
from skrl.resources.noises.torch import Noise

class CustomNoise(Noise):
    def __init__(self, device=None):
        super().__init__(device)

    def sample(self, size):
        return torch.rand(size, device=self.device)

sample(size: Tuple[int] | torch.Size) → torch.Tensor

Noise sampling method to be implemented by the inheriting classes

Parameters:

size (tuple or list of int, or torch.Size) – Shape of the sampled tensor

Raises:

NotImplementedError – The method has not been implemented by the inheriting class

Returns:

Sampled noise

Return type:

torch.Tensor

sample_like(tensor: torch.Tensor) → torch.Tensor

Sample a noise with the same size (shape) as the input tensor

This method calls the sampling method as .sample(tensor.shape)

Parameters:

tensor (torch.Tensor) – Input tensor used to determine output tensor size (shape)

Returns:

Sampled noise

Return type:

torch.Tensor

Example:

>>> x = torch.rand(3, 2, device="cuda:0")
>>> noise.sample_like(x)
tensor([[-0.0423, -0.1325],
        [-0.0639, -0.0957],
        [-0.1367,  0.1031]], device='cuda:0')
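`sample_like` only forwards the input's shape, so any subclass that implements `sample` gets `sample_like` for free. The following stdlib sketch shows that delegation. It assumes a hypothetical `Array` container that exposes a `.shape` attribute the way a `torch.Tensor` does. `Array` and `UniformNoise` are illustrative names, not skrl's.

```python
import random
from dataclasses import dataclass
from math import prod
from typing import List, Tuple


@dataclass
class Array:
    """Minimal stand-in for a tensor: flat data plus a shape."""
    data: List[float]
    shape: Tuple[int, ...]


class UniformNoise:
    def sample(self, size: Tuple[int, ...]) -> Array:
        # draw one value in [-1, 1] per element of the requested shape
        values = [random.uniform(-1.0, 1.0) for _ in range(prod(size))]
        return Array(values, tuple(size))

    def sample_like(self, tensor: Array) -> Array:
        # same delegation as documented: .sample(tensor.shape)
        return self.sample(tensor.shape)


x = Array([0.0] * 6, (3, 2))
noise = UniformNoise()
y = noise.sample_like(x)
```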

API (JAX)

class skrl.resources.noises.jax.base.Noise(device: str | jax.Device | None = None)

Bases: object

__init__(device: str | jax.Device | None = None) → None

Base class representing a noise

Parameters:

device (str or jax.Device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"

Custom noises should override the sample method:

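import jax
from skrl.resources.noises.jax import Noise

class CustomNoise(Noise):
    def __init__(self, device=None, seed=0):
        super().__init__(device)
        # store a PRNG key so that each call yields new samples
        # (the seed argument is illustrative, not part of the base class)
        self._key = jax.random.PRNGKey(seed)

    def sample(self, size):
        # split the key: keep one half for the next call, consume the other now
        self._key, subkey = jax.random.split(self._key)
        return jax.random.uniform(subkey, size)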
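JAX random functions are deterministic given their key. A noise that builds its key from the same seed on every call therefore returns the same numbers each time. The pitfall and its fix can be shown with the standard-library `random.Random` as a stand-in for a JAX key. This is an analogy, not JAX code, and `FrozenSeedNoise` and `StatefulNoise` are hypothetical names.

```python
import random
from typing import List


class FrozenSeedNoise:
    """Re-seeds on every call, like passing jax.random.PRNGKey(0) each time."""

    def sample(self, n: int) -> List[float]:
        rng = random.Random(0)
        return [rng.uniform(0.0, 1.0) for _ in range(n)]


class StatefulNoise:
    """Keeps its generator between calls, like storing and splitting a key."""

    def __init__(self, seed: int = 0) -> None:
        self.rng = random.Random(seed)

    def sample(self, n: int) -> List[float]:
        return [self.rng.uniform(0.0, 1.0) for _ in range(n)]


frozen, stateful = FrozenSeedNoise(), StatefulNoise()
repeated = frozen.sample(3) == frozen.sample(3)  # same draws on every call
fresh = stateful.sample(3) != stateful.sample(3)  # new draws on every call
```

The stateful version stays reproducible: two instances built with the same seed produce the same stream.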

sample(size: Tuple[int]) → ndarray | jax.Array

Noise sampling method to be implemented by the inheriting classes

Parameters:

size (tuple or list of int) – Shape of the sampled tensor

Raises:

NotImplementedError – The method has not been implemented by the inheriting class

Returns:

Sampled noise

Return type:

np.ndarray or jax.Array

sample_like(tensor: ndarray | jax.Array) → ndarray | jax.Array

Sample a noise with the same size (shape) as the input tensor

This method calls the sampling method as .sample(tensor.shape)

Parameters:

tensor (np.ndarray or jax.Array) – Input tensor used to determine output tensor size (shape)

Returns:

Sampled noise

Return type:

np.ndarray or jax.Array

Example:

>>> x = jax.random.uniform(jax.random.PRNGKey(0), (3, 2))
>>> noise.sample_like(x)
Array([[0.57450044, 0.09968603],
       [0.7419659 , 0.8941783 ],
       [0.59656656, 0.45325184]], dtype=float32)