OmniSafe Math

get_transpose(tensor)

Transpose the last two dimensions of a tensor.

get_diagonal(tensor)

Get the diagonal of the last two dimensions of a tensor.

discount_cumsum(vector_x, discount)

Compute the discounted cumulative sum of vectors.

conjugate_gradients(fisher_product, vector_b)

Implementation of the conjugate gradient algorithm.

SafeTanhTransformer([cache_size])

Safe Tanh Transformer.

TanhNormal(loc, scale)

Create a tanh-normal distribution.

Tensor Operations

Documentation

omnisafe.utils.math.get_transpose(tensor)

Transpose the last two dimensions of a tensor.

Examples

>>> tensor = torch.rand(2, 3)
>>> get_transpose(tensor).shape
torch.Size([3, 2])
Parameters:

tensor (torch.Tensor) – The tensor to transpose.

Returns:

Transposed tensor.

Return type:

Tensor
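
For batched inputs, "the last two dimensions" means that every leading dimension is left alone. The sketch below illustrates this with plain PyTorch (`Tensor.transpose(-2, -1)`) rather than with `get_transpose` itself; the variable names are illustrative only.

```python
import torch

# A batch of 4 matrices, each of shape (2, 3).
batch = torch.arange(24.0).reshape(4, 2, 3)

# Swap only the last two dimensions; the batch dimension is untouched.
transposed = batch.transpose(-2, -1)

print(transposed.shape)  # torch.Size([4, 3, 2])

# Within each matrix, element (i, j) moves to position (j, i).
print(torch.equal(transposed[0], batch[0].T))  # True
```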

omnisafe.utils.math.get_diagonal(tensor)

Get the diagonal of the last two dimensions of a tensor.

Examples

>>> tensor = torch.rand(3, 3)
>>> get_diagonal(tensor).shape
torch.Size([1, 3])
Parameters:

tensor (torch.Tensor) – The tensor to get the diagonal from.

Returns:

Diagonal part of the tensor.

Return type:

Tensor
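
As a point of comparison, the underlying PyTorch operation for taking the diagonal over the last two dimensions is `torch.diagonal(..., dim1=-2, dim2=-1)`. The sketch below uses only that call; it does not call `get_diagonal`, and any post-processing or output shape specific to `get_diagonal` is not reproduced here.

```python
import torch

# A batch of 2 square matrices, each of shape (3, 3).
batch = torch.arange(18.0).reshape(2, 3, 3)

# Take the diagonal of each matrix in the batch.
diag = torch.diagonal(batch, dim1=-2, dim2=-1)

print(diag.shape)  # torch.Size([2, 3])
print(diag[0])     # tensor([0., 4., 8.])
print(diag[1])     # tensor([ 9., 13., 17.])
```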

omnisafe.utils.math.discount_cumsum(vector_x, discount)

Compute the discounted cumulative sum of vectors.

Examples

>>> vector_x = torch.arange(1, 5)
>>> vector_x
tensor([1, 2, 3, 4])
>>> discount_cumsum(vector_x, 0.9)
tensor([8.15, 7.94, 6.60, 4.00])
Parameters:
  • vector_x (torch.Tensor) – A sequence of shape (B, T).

  • discount (float) – The discount factor.

Returns:

The discounted cumulative sum of vectors.

Return type:

Tensor
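
The discounted cumulative sum follows the backward recurrence \(y_t = x_t + \text{discount} \cdot y_{t+1}\), with \(y_T = x_T\) for the last element. The function below is a hypothetical reference sketch for a 1-D input, written to make the arithmetic explicit; `discount_cumsum_reference` is not part of OmniSafe and may differ from the library's implementation in dtype and batching.

```python
import torch


def discount_cumsum_reference(vector_x: torch.Tensor, discount: float) -> torch.Tensor:
    """Hypothetical reference: y[t] = x[t] + discount * y[t + 1] along dim 0."""
    # Work on a float copy so the caller's tensor is not modified.
    out = vector_x.to(torch.float64).clone()
    # Walk backwards so each entry can reuse the sum already computed after it.
    for t in range(out.shape[0] - 2, -1, -1):
        out[t] = out[t] + discount * out[t + 1]
    return out


rewards = torch.arange(1, 5)  # tensor([1, 2, 3, 4])
returns = discount_cumsum_reference(rewards, 0.9)

# Worked by hand: 4.0, 3 + 0.9 * 4.0 = 6.6, 2 + 0.9 * 6.6 = 7.94, 1 + 0.9 * 7.94 = 8.146
print(returns)
```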

omnisafe.utils.math.conjugate_gradients(fisher_product, vector_b, num_steps=10, residual_tol=1e-10, eps=1e-6)

Implementation of the conjugate gradient algorithm.

The conjugate gradient algorithm solves the linear system of equations \(A x = b\) without forming \(A\) explicitly; here \(A\) is accessed only through the matrix-vector product fisher_product. The algorithm is described in detail in the paper Conjugate Gradient Method.

Note

Increasing num_steps yields a more accurate approximation to \(A^{-1} b\), and possibly slightly improved performance, at the cost of slower computation. In most cases this hyperparameter should be left at its default.

Parameters:
  • fisher_product (Callable[[torch.Tensor], torch.Tensor]) – Fisher information matrix vector product.

  • vector_b (torch.Tensor) – The vector \(b\) in the equation \(A x = b\).

  • num_steps (int, optional) – The number of steps to run the algorithm for. Defaults to 10.

  • residual_tol (float, optional) – The tolerance for the residual. Defaults to 1e-10.

  • eps (float, optional) – A small number to avoid dividing by zero. Defaults to 1e-6.

Returns:

The vector \(x\) in the equation \(A x = b\).

Return type:

Tensor
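
To make the roles of the parameters concrete, here is a minimal sketch of a textbook conjugate gradient loop, assuming a symmetric positive-definite \(A\) that is available only as a matrix-vector product. The function name `conjugate_gradients_sketch`, the argument name `matvec`, and the 2x2 test matrix are illustrative assumptions; this is not the OmniSafe source.

```python
import torch


def conjugate_gradients_sketch(matvec, vector_b, num_steps=10, residual_tol=1e-10, eps=1e-6):
    """Hypothetical textbook CG: approximately solve A x = b given v -> A v."""
    x = torch.zeros_like(vector_b)
    r = vector_b.clone()  # residual b - A x, with the initial guess x = 0
    p = r.clone()         # current search direction
    rdotr = torch.dot(r, r)
    for _ in range(num_steps):
        Ap = matvec(p)
        # eps guards against division by zero when p is (numerically) zero.
        alpha = rdotr / (torch.dot(p, Ap) + eps)
        x = x + alpha * p
        r = r - alpha * Ap
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break  # residual is small enough; stop early
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x


# Small symmetric positive-definite system with exact solution [1/11, 7/11].
A = torch.tensor([[4.0, 1.0], [1.0, 3.0]], dtype=torch.float64)
b = torch.tensor([1.0, 2.0], dtype=torch.float64)
x = conjugate_gradients_sketch(lambda v: A @ v, b)

print(x)      # close to [0.0909, 0.6364]
print(A @ x)  # close to b
```

In this sketch the matrix is never inverted or factorized; only products `A @ v` are evaluated, which is why a callable such as fisher_product is sufficient.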

Distribution Operations

Documentation

class omnisafe.utils.math.SafeTanhTransformer(cache_size=0)

Safe Tanh Transformer.

This transformer avoids the numerical errors that arise when the input to the tanh function is too large or too small.

_inverse(y)

Abstract method to compute inverse transformation.

Return type:

Tensor
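
The kind of failure such a transformer guards against can be reproduced in plain PyTorch: in float32, tanh saturates to exactly 1.0 for large inputs, and the inverse (atanh) of 1.0 is infinite. The sketch below shows one common remedy, clamping before inverting. The helper `safe_atanh` and its clamping bounds are assumptions for illustration and are not taken from the SafeTanhTransformer source.

```python
import torch


def safe_atanh(y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: clamp y into the open interval (-1, 1) before atanh."""
    eps = torch.finfo(y.dtype).eps
    return torch.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps))


x = torch.tensor([0.5, 20.0])
y = torch.tanh(x)  # tanh(20) rounds to exactly 1.0 in float32

naive = torch.atanh(y)  # second entry is inf
safe = safe_atanh(y)    # both entries are finite

print(naive)
print(safe)
```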

class omnisafe.utils.math.TanhNormal(loc, scale)

Create a tanh-normal distribution.

\[ \begin{aligned} X &\sim \mathrm{Normal}(loc, scale) \\ Y = \tanh(X) &\sim \mathrm{TanhNormal}(loc, scale) \end{aligned} \]

Examples

>>> m = TanhNormal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # tanh of a normal sample with loc=0 and scale=1
tensor([-0.7616])
Parameters:
  • loc (float or Tensor) – The mean of the underlying normal distribution.

  • scale (float or Tensor) – The standard deviation of the underlying normal distribution.

Initialize an instance of TanhNormal.
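
The relationship \(Y = \tanh(X)\) can be reproduced with stock PyTorch building blocks, which also makes the change-of-variables term in the log-density explicit: \(\log p_Y(y) = \log p_X(x) - \log(1 - \tanh^2(x))\). The sketch below uses `torch.distributions.TransformedDistribution` with `TanhTransform`; it is an illustration under that assumption and does not show how TanhNormal is implemented internally.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

torch.manual_seed(0)

# Underlying normal X with loc=0 and scale=1.
base = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
# Y = tanh(X); cache_size=1 lets log_prob reuse the pre-tanh value of the last sample.
tanh_normal = TransformedDistribution(base, [TanhTransform(cache_size=1)])

y = tanh_normal.rsample()            # always lies strictly inside (-1, 1)
log_prob = tanh_normal.log_prob(y)

# Manual change of variables for comparison.
x = torch.atanh(y)
manual = base.log_prob(x) - torch.log1p(-y.pow(2))

print(y, log_prob, manual)
```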

entropy()

The entropy of the tanh normal distribution.

Return type:

Tensor

expand(batch_shape, instance=None)

Expand the distribution.

Return type:

TanhNormal

property loc: Tensor

The mean of the underlying normal distribution.

property mean: Tensor

The mean of the tanh normal distribution.

property scale: Tensor

The standard deviation of the underlying normal distribution.

property stddev: Tensor

The standard deviation of the tanh normal distribution.

property variance: Tensor

The variance of the tanh normal distribution.