Ornstein-Uhlenbeck noise
Noise generated by a stochastic process that is characterized by its mean-reverting behavior.
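The LaTeX below writes this mean-reverting behavior as one discrete update. It is a sketch that assumes three things: the internal state x starts at zero, it is updated once per call in the way the parameter descriptions in the API section suggest (theta scales the current state, sigma scales a normal sample, base_scale scales the output), and the step is an Euler discretization with unit time step of an Ornstein-Uhlenbeck process whose long-run mean is 0:

```latex
% continuous-time OU process with long-run mean 0
dx = -\theta\, x\, dt + \sigma\, dW
% one Euler step with unit time step (assumed discretization)
x_{t+1} = x_t - \theta\, x_t + \sigma\, \varepsilon_t, \qquad \varepsilon_t \sim \mathcal{N}(\mathrm{mean}, \mathrm{std}^2)
% returned sample
\mathrm{noise}_t = \mathrm{base\_scale} \cdot x_{t+1}
```

With 0 < θ < 1, the state shrinks by the factor (1 − θ) at each step before the new random shock is added. As a result, consecutive samples are correlated instead of independent.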
Usage
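The noise is set in each agent's configuration. Assign the noise class to exploration_noise and its constructor keyword arguments to exploration_noise_kwargs.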
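PyTorch
# import the noise class
from skrl.resources.noises.torch import OrnsteinUhlenbeckNoise

# AGENT_CFG is a placeholder for the agent's configuration class
cfg = AGENT_CFG()
# ...
cfg.exploration_noise = OrnsteinUhlenbeckNoise
# keyword arguments used to instantiate the noise (device: the agent's computation device)
cfg.exploration_noise_kwargs = {"theta": 0.15, "sigma": 0.1, "base_scale": 0.5, "device": device}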
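JAX
# import the noise class
from skrl.resources.noises.jax import OrnsteinUhlenbeckNoise

# AGENT_CFG is a placeholder for the agent's configuration class
cfg = AGENT_CFG()
# ...
cfg.exploration_noise = OrnsteinUhlenbeckNoise
# keyword arguments used to instantiate the noise (device: the agent's computation device)
cfg.exploration_noise_kwargs = {"theta": 0.15, "sigma": 0.1, "base_scale": 0.5, "device": device}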
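Warp
# import the noise class
from skrl.resources.noises.warp import OrnsteinUhlenbeckNoise

# AGENT_CFG is a placeholder for the agent's configuration class
cfg = AGENT_CFG()
# ...
cfg.exploration_noise = OrnsteinUhlenbeckNoise
# keyword arguments used to instantiate the noise (device: the agent's computation device)
cfg.exploration_noise_kwargs = {"theta": 0.15, "sigma": 0.1, "base_scale": 0.5, "device": device}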
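The agent applies the configured noise internally. Any scaling schedule it uses is agent-specific and is not described on this page. The sketch below is a rough, hypothetical illustration of what an off-policy agent such as DDPG typically does with exploration noise: it adds a noise sample of the same shape to the deterministic actions and clips the result to the action bounds. The helper add_exploration_noise and the stand-in class _ConstantNoise are not part of skrl, and NumPy arrays stand in for framework tensors.

```python
import numpy as np


def add_exploration_noise(actions, noise, low, high, scale=1.0):
    """Hypothetical helper: perturb deterministic actions with exploration noise.

    `noise` is any object exposing `sample_like(array)`; `low`/`high` are the
    action-space bounds and `scale` is an optional extra multiplier.
    """
    perturbed = actions + scale * noise.sample_like(actions)
    # keep the perturbed actions inside the valid action range
    return np.clip(perturbed, low, high)


class _ConstantNoise:
    """Stand-in noise used only for this example: always returns 0.3."""

    def sample_like(self, array):
        return np.full_like(array, 0.3)


actions = np.array([[0.0, 0.9], [-0.5, 1.0]])
noisy = add_exploration_noise(actions, _ConstantNoise(), low=-1.0, high=1.0)
print(noisy)
```

Clipping after adding the noise keeps every action valid, even when a large noise sample would push it outside the bounds.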
API
PyTorch
- class skrl.resources.noises.torch.ornstein_uhlenbeck.OrnsteinUhlenbeckNoise(*, theta: float, sigma: float, base_scale: float, mean: float = 0, std: float = 1, device: str | torch.device | None = None)
Bases: Noise
Ornstein-Uhlenbeck noise.
- Parameters:
theta – Factor applied to the current internal state (the mean-reversion rate).
sigma – Factor applied to the sample drawn from the normal distribution (the scale of each random shock).
base_scale – Factor applied to the returned noise.
mean – Mean of the normal distribution.
std – Standard deviation of the normal distribution.
device – Device used for data allocation and computation. If not specified, the default device is used.
Example:
>>> noise = OrnsteinUhlenbeckNoise(theta=0.1, sigma=0.2, base_scale=0.5)
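The minimal NumPy sketch below shows how each parameter acts on the sample. It rests on stated assumptions and is not skrl's source. The class name OrnsteinUhlenbeckSketch and the seed argument are hypothetical. The internal state is assumed to start at zero and to restart whenever the requested shape changes. The update is the discrete one implied by the parameter descriptions above.

```python
import numpy as np


class OrnsteinUhlenbeckSketch:
    """Hypothetical NumPy sketch of an Ornstein-Uhlenbeck noise generator.

    `theta` scales the current internal state, `sigma` scales a normal sample
    drawn with (`mean`, `std`), and `base_scale` scales the returned noise.
    """

    def __init__(self, *, theta, sigma, base_scale, mean=0.0, std=1.0, seed=None):
        self.theta = theta
        self.sigma = sigma
        self.base_scale = base_scale
        self.mean = mean
        self.std = std
        self.rng = np.random.default_rng(seed)  # `seed` is an addition for reproducibility
        self.state = None  # created lazily with the first requested shape

    def sample(self, size):
        size = tuple(size)
        # assumption: (re)start from a zero state when the requested shape changes
        if self.state is None or self.state.shape != size:
            self.state = np.zeros(size)
        eps = self.rng.normal(self.mean, self.std, size)
        # mean-reverting step: pull the state toward 0, then add a random shock
        self.state = self.state - self.theta * self.state + self.sigma * eps
        return self.base_scale * self.state


noise = OrnsteinUhlenbeckSketch(theta=0.15, sigma=0.1, base_scale=0.5, seed=0)
print(noise.sample((3, 2)))
```

Setting theta to 1 discards the previous state entirely, which turns the output into scaled independent Gaussian noise. A theta close to 0 makes the state behave like a slowly decaying random walk.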
Methods:
sample(size) – Sample Ornstein-Uhlenbeck noise.
sample_like(tensor) – Sample noise with the same size (shape) as the input tensor.
- sample(size: list[int] | torch.Size) → torch.Tensor
Sample Ornstein-Uhlenbeck noise.
- Parameters:
size – Noise shape.
- Returns:
Sampled noise.
Example:
>>> noise.sample((3, 2))
tensor([[-0.0452,  0.0162],
        [ 0.0649, -0.0708],
        [-0.0211,  0.0066]], device='cuda:0')
>>> x = torch.rand(3, 2, device="cuda:0")
>>> noise.sample(x.shape)
tensor([[-0.0540,  0.0461],
        [ 0.1117, -0.1157],
        [-0.0074,  0.0420]], device='cuda:0')
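Successive Ornstein-Uhlenbeck samples are correlated, unlike independent Gaussian noise. The short NumPy simulation below checks this numerically. It assumes the discrete update x ← x − θx + σε with a scalar state and omits base_scale, which only rescales the output. For that recursion, the lag-1 autocorrelation is 1 − θ and the stationary variance is σ²/(θ(2 − θ)). The helper simulate_ou is hypothetical.

```python
import numpy as np


def simulate_ou(theta, sigma, steps, seed=0):
    """Hypothetical helper: run x <- x - theta*x + sigma*eps for `steps` steps (scalar state)."""
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal(steps)
    x = np.empty(steps)
    state = 0.0
    for t in range(steps):
        state = state - theta * state + sigma * eps[t]
        x[t] = state
    return x


theta, sigma = 0.15, 0.1
x = simulate_ou(theta, sigma, steps=200_000)[1_000:]  # drop a burn-in period

# lag-1 autocorrelation of consecutive samples; theory for this recursion: 1 - theta
lag1 = np.corrcoef(x[:-1], x[1:])[0, 1]
# stationary variance; theory: sigma**2 / (theta * (2 - theta))
var_theory = sigma**2 / (theta * (2 - theta))
print(f"lag-1 autocorrelation: {lag1:.3f} (theory {1 - theta:.3f})")
print(f"variance: {x.var():.5f} (theory {var_theory:.5f})")
```

A smaller theta raises the autocorrelation, so the noise drifts more smoothly from one sample to the next.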
- sample_like(tensor: torch.Tensor) → torch.Tensor
Sample noise with the same size (shape) as the input tensor.
This method calls the sampling method as follows: .sample(tensor.shape).
- Parameters:
tensor – Input tensor used to determine output tensor size (shape).
- Returns:
Sampled noise.
Example:
>>> x = torch.rand(3, 2, device="cuda:0")
>>> noise.sample_like(x)
tensor([[-0.0423, -0.1325],
        [-0.0639, -0.0957],
        [-0.1367,  0.1031]], device='cuda:0')
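Because sample_like delegates to sample(tensor.shape), a noise class only needs to implement sample. The minimal sketch below shows that pattern using abc and NumPy. NoiseBase and GaussianNoise are hypothetical names chosen for illustration, not skrl classes, and GaussianNoise draws independent normal samples rather than Ornstein-Uhlenbeck noise.

```python
from abc import ABC, abstractmethod

import numpy as np


class NoiseBase(ABC):
    """Hypothetical base class: subclasses implement `sample`; `sample_like` comes for free."""

    @abstractmethod
    def sample(self, size):
        """Return noise of the given shape."""

    def sample_like(self, tensor):
        # only the shape of `tensor` is used; its values are ignored
        return self.sample(tensor.shape)


class GaussianNoise(NoiseBase):
    """Example subclass drawing i.i.d. normal noise (stand-in, not the OU process)."""

    def __init__(self, std=1.0, seed=None):
        self.std = std
        self.rng = np.random.default_rng(seed)

    def sample(self, size):
        return self.rng.normal(0.0, self.std, size)


x = np.random.default_rng(0).uniform(size=(3, 2))
noise = GaussianNoise(std=0.1, seed=0)
print(noise.sample_like(x).shape)  # (3, 2)
```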
JAX
- class skrl.resources.noises.jax.ornstein_uhlenbeck.OrnsteinUhlenbeckNoise(*, theta: float, sigma: float, base_scale: float, mean: float = 0, std: float = 1, device: str | jax.Device | None = None)
Bases: Noise
Ornstein-Uhlenbeck noise.
- Parameters:
theta – Factor applied to the current internal state (the mean-reversion rate).
sigma – Factor applied to the sample drawn from the normal distribution (the scale of each random shock).
base_scale – Factor applied to the returned noise.
mean – Mean of the normal distribution.
std – Standard deviation of the normal distribution.
device – Device used for data allocation and computation. If not specified, the default device is used.
Example:
>>> noise = OrnsteinUhlenbeckNoise(theta=0.1, sigma=0.2, base_scale=0.5)
Methods:
sample(size) – Sample Ornstein-Uhlenbeck noise.
sample_like(tensor) – Sample noise with the same size (shape) as the input tensor.
- sample(size: list[int]) → jax.Array
Sample Ornstein-Uhlenbeck noise.
- Parameters:
size – Noise shape.
- Returns:
Sampled noise.
Example:
>>> noise.sample((3, 2))
Array([[ 0.01878439, -0.12833427],
       [ 0.06494182,  0.12490594],
       [ 0.024447  , -0.01174496]], dtype=float32)
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (3, 2))
>>> noise.sample(x.shape)
Array([[ 0.17988093, -1.2289404 ],
       [ 0.6218886 ,  1.1961104 ],
       [ 0.23410667, -0.11247082]], dtype=float32)
- sample_like(tensor: jax.Array) → jax.Array
Sample noise with the same size (shape) as the input tensor.
This method calls the sampling method as follows: .sample(tensor.shape).
- Parameters:
tensor – Input tensor used to determine output tensor size (shape).
- Returns:
Sampled noise.
Example:
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (3, 2))
>>> noise.sample_like(x)
Array([[0.57450044, 0.09968603],
       [0.7419659 , 0.8941783 ],
       [0.59656656, 0.45325184]], dtype=float32)
Warp
- class skrl.resources.noises.warp.ornstein_uhlenbeck.OrnsteinUhlenbeckNoise(*, theta: float, sigma: float, base_scale: float, mean: float = 0, std: float = 1, device: str | wp.Device | None = None)
Bases: Noise
Ornstein-Uhlenbeck noise.
- Parameters:
theta – Factor applied to the current internal state (the mean-reversion rate).
sigma – Factor applied to the sample drawn from the normal distribution (the scale of each random shock).
base_scale – Factor applied to the returned noise.
mean – Mean of the normal distribution.
std – Standard deviation of the normal distribution.
device – Device used for data allocation and computation. If not specified, the default device is used.
Example:
>>> noise = OrnsteinUhlenbeckNoise(theta=0.1, sigma=0.2, base_scale=0.5)
Methods:
sample(size) – Sample Ornstein-Uhlenbeck noise.
sample_like(tensor) – Sample noise with the same size (shape) as the input tensor.
- sample(size: list[int]) → warp.array
Sample Ornstein-Uhlenbeck noise.
- Parameters:
size – Noise shape.
- Returns:
Sampled noise.
Example:
>>> n = noise.sample((3, 2))
>>> n.shape
(3, 2)
>>> x = wp.zeros((3, 2), dtype=wp.float32)
>>> noise.sample(x.shape).shape
(3, 2)
- sample_like(tensor: warp.array) → warp.array
Sample noise with the same size (shape) as the input tensor.
This method calls the sampling method as follows: .sample(tensor.shape).
- Parameters:
tensor – Input tensor used to determine output tensor size (shape).
- Returns:
Sampled noise.
Example:
>>> x = wp.zeros((3, 2), dtype=wp.float32)
>>> noise.sample_like(x).shape
(3, 2)