Noises
Definition of the noises used by the agents during the exploration stage. All noises inherit from a base class that defines a uniform interface.
| Noises | PyTorch | JAX |
|---|---|---|
| Gaussian noise | \(\blacksquare\) | \(\blacksquare\) |
| Ornstein-Uhlenbeck noise | \(\blacksquare\) | \(\blacksquare\) |
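To illustrate what the two noises in the table compute, the NumPy sketch below draws Gaussian noise independently at every call. It draws Ornstein-Uhlenbeck noise as a discretized mean-reverting process, \(x_{t+1} = x_t + \theta(\mu - x_t)\,\Delta t + \sigma\sqrt{\Delta t}\,\varepsilon_t\) with \(\varepsilon_t \sim \mathcal{N}(0, 1)\). This is a sketch assuming the textbook parameterization. The names `theta`, `sigma`, `mu`, `dt` and `OUProcessSketch` are illustrative assumptions and do not necessarily match the parameters of skrl's own noise classes.

```python
import numpy as np

def gaussian_noise(rng: np.random.Generator, size, mean=0.0, std=0.1) -> np.ndarray:
    # Independent draw at every call: no correlation between consecutive steps
    return rng.normal(mean, std, size=size)

class OUProcessSketch:
    """Discretized Ornstein-Uhlenbeck process (illustrative assumption, not skrl's class)."""

    def __init__(self, size, theta=0.15, sigma=0.2, mu=0.0, dt=1e-2, seed=0):
        self.theta, self.sigma, self.mu, self.dt = theta, sigma, mu, dt
        self.rng = np.random.default_rng(seed)
        self.state = np.full(size, mu, dtype=np.float64)  # start at the mean

    def sample(self) -> np.ndarray:
        # Mean-reverting drift towards mu plus a scaled Gaussian increment
        drift = self.theta * (self.mu - self.state) * self.dt
        diffusion = self.sigma * np.sqrt(self.dt) * self.rng.standard_normal(self.state.shape)
        self.state = self.state + drift + diffusion
        return self.state.copy()

rng = np.random.default_rng(0)
g = gaussian_noise(rng, (2,))
ou = OUProcessSketch(size=(2,))
trajectory = np.stack([ou.sample() for _ in range(5)])
print(g.shape, trajectory.shape)  # (2,) (5, 2)
```

Because each OU sample depends on the previous state, consecutive values are temporally correlated, whereas the Gaussian draws are not.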
Base class
Note
This is the base class for all the other classes in this module. It defines the common interface (the sample and sample_like methods) shared by every noise. It is not intended to be used directly: inherit from it and override the sample method instead.
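The uniform interface can be pictured with the NumPy stand-in below. The base `sample` raises `NotImplementedError`, and `sample_like` forwards the input's shape to `sample`, mirroring the behavior described in the API reference of this page. `NoiseBaseSketch` and `ZeroNoise` are hypothetical names, and this is not skrl's source code.

```python
from typing import Tuple
import numpy as np

class NoiseBaseSketch:
    """Hypothetical stand-in mirroring the documented Noise interface (not skrl's code)."""

    def __init__(self, device=None) -> None:
        self.device = device  # kept only to mirror the documented constructor signature

    def sample(self, size: Tuple[int, ...]) -> np.ndarray:
        # Subclasses must provide the actual sampling logic
        raise NotImplementedError("sample() must be implemented by subclasses")

    def sample_like(self, tensor: np.ndarray) -> np.ndarray:
        # Delegates to sample() using the shape of the input
        return self.sample(tensor.shape)

class ZeroNoise(NoiseBaseSketch):
    """Trivial subclass: overriding sample() is enough to get sample_like() for free."""

    def sample(self, size: Tuple[int, ...]) -> np.ndarray:
        return np.zeros(size)

try:
    NoiseBaseSketch().sample((2,))
except NotImplementedError as e:
    print("base class:", e)

print(ZeroNoise().sample_like(np.ones((3, 2))).shape)  # (3, 2)
```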
Basic inheritance usage
PyTorch
```python
from typing import Union, Tuple

import torch

from skrl.resources.noises.torch import Noise


class CustomNoise(Noise):
    def __init__(self, device: Union[str, torch.device] = "cuda:0") -> None:
        """
        :param device: Device on which a torch tensor is or will be allocated (default: "cuda:0")
        :type device: str or torch.device, optional
        """
        super().__init__(device)

    def sample(self, size: Union[Tuple[int], torch.Size]) -> torch.Tensor:
        """Sample noise

        :param size: Shape of the sampled tensor
        :type size: tuple or list of integers, or torch.Size

        :return: Sampled noise
        :rtype: torch.Tensor
        """
        # ================================
        # - sample noise
        # ================================
```
JAX
```python
from typing import Optional, Union, Tuple

import numpy as np
import jax

from skrl.resources.noises.jax import Noise


class CustomNoise(Noise):
    def __init__(self, device: Optional[Union[str, jax.Device]] = None) -> None:
        """Custom noise

        :param device: Device on which a jax array is or will be allocated (default: ``None``).
                       If None, the device will be either ``"cuda:0"`` if available or ``"cpu"``
        :type device: str or jax.Device, optional
        """
        super().__init__(device)

    def sample(self, size: Tuple[int]) -> Union[np.ndarray, jax.Array]:
        """Sample noise

        :param size: Shape of the sampled tensor
        :type size: tuple or list of integers

        :return: Sampled noise
        :rtype: np.ndarray or jax.Array
        """
        # ================================
        # - sample noise
        # ================================
```
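As one possible way to fill the "sample noise" section of the templates, the sketch below draws zero-mean uniform noise in \([-s, s]\). To stay runnable without skrl, torch or JAX, it uses NumPy and does not inherit from skrl's Noise. `UniformNoiseSketch`, `scale` and `seed` are hypothetical names. In the real templates the class would inherit from Noise and return a torch.Tensor or a NumPy/JAX array.

```python
from typing import Optional, Tuple
import numpy as np

class UniformNoiseSketch:
    """Hypothetical custom noise: zero-mean uniform samples in [-scale, scale)."""

    def __init__(self, scale: float = 0.1, seed: Optional[int] = None) -> None:
        if scale <= 0:
            raise ValueError("scale must be positive")
        self.scale = scale
        # Keep one generator so that successive calls return fresh samples
        self.rng = np.random.default_rng(seed)

    def sample(self, size: Tuple[int, ...]) -> np.ndarray:
        # ================================
        # - sample noise (the part the templates leave open)
        # ================================
        return self.rng.uniform(-self.scale, self.scale, size=size)

noise = UniformNoiseSketch(scale=0.5, seed=42)
x = noise.sample((4, 3))
print(x.shape, bool(np.all(np.abs(x) <= 0.5)))  # (4, 3) True
```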
API (PyTorch)
- class skrl.resources.noises.torch.base.Noise(device: str | torch.device | None = None)
Bases: object
- __init__(device: str | torch.device | None = None) → None
Base class representing a noise
- Parameters:
device (str or torch.device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"
Custom noises should override the sample method:
```python
import torch

from skrl.resources.noises.torch import Noise


class CustomNoise(Noise):
    def __init__(self, device=None):
        super().__init__(device)

    def sample(self, size):
        # uniform noise in [0, 1) allocated on the noise's device
        return torch.rand(size, device=self.device)
```
- sample(size: Tuple[int] | torch.Size) → torch.Tensor
Noise sampling method to be implemented by the inheriting classes
- Parameters:
size (tuple or list of int, or torch.Size) – Shape of the sampled tensor
- Raises:
NotImplementedError – The method is not implemented by the inheriting classes
- Returns:
Sampled noise
- Return type:
torch.Tensor
- sample_like(tensor: torch.Tensor) → torch.Tensor
Sample a noise with the same size (shape) as the input tensor
This method calls the sampling method as follows: .sample(tensor.shape)
- Parameters:
tensor (torch.Tensor) – Input tensor used to determine output tensor size (shape)
- Returns:
Sampled noise
- Return type:
torch.Tensor
Example:
```python
>>> x = torch.rand(3, 2, device="cuda:0")
>>> noise.sample_like(x)
tensor([[-0.0423, -0.1325],
        [-0.0639, -0.0957],
        [-0.1367,  0.1031]], device='cuda:0')
```
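During exploration, sample_like is convenient because the noise automatically matches the batch of actions it perturbs. The NumPy sketch below draws Gaussian noise shaped like a batch of actions (the equivalent of noise.sample_like(actions)), adds it, and clips the result to the action bounds. The bounds, the noise scale and the helper name `explore` are assumptions for illustration, not part of the skrl API.

```python
import numpy as np

def explore(actions: np.ndarray, rng: np.random.Generator, std: float = 0.1,
            low: float = -1.0, high: float = 1.0) -> np.ndarray:
    # Equivalent of noise.sample_like(actions): one sample per action entry
    noise = rng.normal(0.0, std, size=actions.shape)
    # Keep the perturbed actions inside the (assumed) action bounds
    return np.clip(actions + noise, low, high)

rng = np.random.default_rng(0)
actions = np.array([[0.95, -0.2], [0.0, -0.99], [0.5, 0.5]])
noisy = explore(actions, rng)
print(noisy.shape)  # (3, 2)
```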
API (JAX)
- class skrl.resources.noises.jax.base.Noise(device: str | jax.Device | None = None)
Bases: object
- __init__(device: str | jax.Device | None = None) → None
Base class representing a noise
- Parameters:
device (str or jax.Device, optional) – Device on which a tensor/array is or will be allocated (default: None). If None, the device will be either "cuda" if available or "cpu"
Custom noises should override the sample method:
```python
import jax

from skrl.resources.noises.jax import Noise


class CustomNoise(Noise):
    def __init__(self, device=None):
        super().__init__(device)

    def sample(self, size):
        # note: a fixed PRNG key returns the same values on every call
        return jax.random.uniform(jax.random.PRNGKey(0), size)
```
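The JAX override example passes a fixed jax.random.PRNGKey(0), so every call returns the same sample. A common remedy in JAX is to store a key on the noise object and replace it on each call with a fresh split (jax.random.split). The sketch below illustrates the same pitfall and fix with NumPy generators so that it runs without JAX. `FixedSeedNoise` and `StatefulNoise` are hypothetical names.

```python
from typing import Tuple
import numpy as np

class FixedSeedNoise:
    # Re-creates the generator with the same seed on every call, analogous to
    # calling jax.random.uniform(jax.random.PRNGKey(0), size) each time
    def sample(self, size: Tuple[int, ...]) -> np.ndarray:
        return np.random.default_rng(0).uniform(size=size)

class StatefulNoise:
    # Keeps generator state between calls, analogous to storing a JAX key and
    # replacing it after each call (key, subkey = jax.random.split(key))
    def __init__(self, seed: int = 0) -> None:
        self.rng = np.random.default_rng(seed)

    def sample(self, size: Tuple[int, ...]) -> np.ndarray:
        return self.rng.uniform(size=size)

fixed, stateful = FixedSeedNoise(), StatefulNoise()
print(np.array_equal(fixed.sample((3,)), fixed.sample((3,))))        # True
print(np.array_equal(stateful.sample((3,)), stateful.sample((3,))))  # False
```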
- sample(size: Tuple[int]) → ndarray | jax.Array
Noise sampling method to be implemented by the inheriting classes
- Parameters:
size (tuple or list of int) – Shape of the sampled tensor
- Raises:
NotImplementedError – The method is not implemented by the inheriting classes
- Returns:
Sampled noise
- Return type:
np.ndarray or jax.Array
- sample_like(tensor: ndarray | jax.Array) → ndarray | jax.Array
Sample a noise with the same size (shape) as the input tensor
This method calls the sampling method as follows: .sample(tensor.shape)
- Parameters:
tensor (np.ndarray or jax.Array) – Input tensor used to determine output tensor size (shape)
- Returns:
Sampled noise
- Return type:
np.ndarray or jax.Array
Example:
```python
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (3, 2))
>>> noise.sample_like(x)
Array([[0.57450044, 0.09968603],
       [0.7419659 , 0.8941783 ],
       [0.59656656, 0.45325184]], dtype=float32)
```