OmniSafe Distributed#
- setup_distributed() – Set up the distributed training environment.
- get_rank() – Get the rank of the calling process.
- world_size() – Count active MPI processes.
- fork() – The entry point for multi-processing.
- avg_tensor() – Average a torch tensor over MPI processes.
- avg_grads() – Average contents of gradient buffers across MPI processes.
- sync_params() – Sync all parameters of a module across all MPI processes.
- avg_params() – Average contents of all parameters across MPI processes.
- dist_avg() – Average a tensor over distributed processes.
- dist_sum() – Sum a tensor over distributed processes.
- dist_max() – Determine the global maximum of a tensor over distributed processes.
- dist_min() – Determine the global minimum of a tensor over distributed processes.
- dist_op() – Apply a multi-processing reduction operation.
- dist_statistics_scalar() – Get mean/std and optional min/max of a scalar across MPI processes.
Set up distributed training#
Documentation
- omnisafe.utils.distributed.setup_distributed()[source]#
Set up the distributed training environment.
Avoid slowdowns caused by each separate process's PyTorch using more than its fair share of CPU resources.
- Return type:
None
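One common way such a setup tames CPU oversubscription is to give each worker an even share of the available torch threads. A minimal sketch of that idea (the exact OmniSafe internals may differ; fair_num_threads is an illustrative name):

import torch
from omnisafe.utils import distributed

# Sketch: split the default thread budget evenly across workers so no
# single process grabs more than its fair share of CPU resources.
if distributed.world_size() > 1 and torch.get_num_threads() > 1:
    fair_num_threads = max(torch.get_num_threads() // distributed.world_size(), 1)
    torch.set_num_threads(fair_num_threads)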
- omnisafe.utils.distributed.get_rank()[source]#
Get the rank of the calling process.
- Return type:
int
Examples
>>> # In process 0
>>> get_rank()
0
- Returns:
The rank of the calling process.
- omnisafe.utils.distributed.world_size()[source]#
Count active MPI processes.
- Returns:
The number of active MPI processes.
- Return type:
int
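A typical pattern is to combine get_rank() and world_size() so that only one process handles logging or checkpointing. A small illustrative sketch:

from omnisafe.utils import distributed

# Sketch: rank 0 handles I/O; every rank still runs the training loop.
if distributed.get_rank() == 0:
    print(f'running with {distributed.world_size()} MPI processes')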
- omnisafe.utils.distributed.fork(parallel, device='cpu', manual_args=None)[source]#
The entry point for multi-processing.
Re-launches the current script with workers linked by MPI, and terminates the original process that launched it. Taken almost without modification from the Baselines function of the same name.
- Parameters:
parallel (int) – The number of processes to launch.
device (str, optional) – The device to be used. Defaults to ‘cpu’.
manual_args (list of str or None, optional) – The arguments to be passed to the new processes. Defaults to None.
- Return type:
bool
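A plausible launch pattern, assuming the returned bool is True in the original process that should hand control to the MPI workers (this return convention is an assumption, not stated by the signature alone):

import sys
from omnisafe.utils import distributed

# Sketch: re-launch this script under MPI with 4 workers; the original
# process exits once the workers have been spawned.
if distributed.fork(parallel=4, device='cpu'):
    sys.exit()
# From here on, the code runs in each of the 4 MPI worker processes.
distributed.setup_distributed()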
Tensor Operations#
Documentation
- omnisafe.utils.distributed.avg_tensor(value)[source]#
Average a torch tensor over MPI processes.
Since torch and numpy share the same memory space, tensors of dim > 0 can be manipulated through call-by-reference, while scalars must be reassigned.
Examples
>>> # In process 0
>>> x = torch.tensor(1.0)
>>> # In process 1
>>> x = torch.tensor(2.0)
>>> avg_tensor(x)
>>> x
tensor(1.5)
- Parameters:
value (torch.Tensor) – The value to be averaged.
- Return type:
None
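Conceptually, the in-place average is an all-reduce SUM followed by division by the process count. A minimal sketch in terms of torch.distributed, assuming a process group is already initialized (avg_tensor_sketch is an illustrative stand-in, not the library code):

import torch
import torch.distributed as dist

def avg_tensor_sketch(value: torch.Tensor) -> None:
    # Sum the tensor across all ranks in place, then divide by the
    # number of ranks so every process holds the element-wise mean.
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    value /= dist.get_world_size()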
- omnisafe.utils.distributed.avg_grads(module)[source]#
Average contents of gradient buffers across MPI processes.
Note
This function only takes effect when training runs in multiple processes.
Examples
>>> # In process 0
>>> x = torch.tensor(1.0, requires_grad=True)
>>> y = x ** 2
>>> y.backward()
>>> x.grad
tensor(2.)
>>> # In process 1
>>> x = torch.tensor(2.0, requires_grad=True)
>>> y = x ** 2
>>> y.backward()
>>> x.grad
tensor(4.)
>>> avg_grads(x)
>>> x.grad
tensor(3.)
- Parameters:
module (torch.nn.Module) – The module whose gradients need to be averaged.
- Return type:
None
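In effect this applies the same averaging to every gradient buffer of the module. A sketch built on the avg_tensor_sketch helper from the previous sketch (illustrative, not the library code):

import torch

def avg_grads_sketch(module: torch.nn.Module) -> None:
    # Average each parameter's .grad tensor in place across processes;
    # parameters without gradients are skipped.
    for param in module.parameters():
        if param.grad is not None:
            avg_tensor_sketch(param.grad)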
- omnisafe.utils.distributed.sync_params(module)[source]#
Sync all parameters of module across all MPI processes.
Note
This function only takes effect when training runs in multiple processes.
Examples
>>> # In process 0
>>> model = torch.nn.Linear(1, 1)
>>> model.weight.data = torch.tensor([[1.]])
>>> model.weight.data
tensor([[1.]])
>>> # In process 1
>>> model = torch.nn.Linear(1, 1)
>>> model.weight.data = torch.tensor([[2.]])
>>> model.weight.data
tensor([[2.]])
>>> sync_params(model)
>>> model.weight.data
tensor([[1.]])
- Parameters:
module (torch.nn.Module) – The module to be synchronized.
- Return type:
None
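Synchronization makes every process adopt one rank's weights, which is why both processes end up with tensor([[1.]]) above. A sketch of the idea using a broadcast from rank 0 (illustrative; the actual transport OmniSafe uses may differ):

import torch
import torch.distributed as dist

def sync_params_sketch(module: torch.nn.Module) -> None:
    # Copy rank 0's parameter tensors into every other rank, in place.
    for param in module.parameters():
        dist.broadcast(param.data, src=0)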
- omnisafe.utils.distributed.avg_params(module)[source]#
Average contents of all parameters across MPI processes.
Examples
>>> # In process 0
>>> model = torch.nn.Linear(1, 1)
>>> model.weight.data = torch.tensor([[1.]])
>>> model.weight.data
tensor([[1.]])
>>> # In process 1
>>> model = torch.nn.Linear(1, 1)
>>> model.weight.data = torch.tensor([[2.]])
>>> model.weight.data
tensor([[2.]])
>>> avg_params(model)
>>> model.weight.data
tensor([[1.5]])
- Parameters:
module (torch.nn.Module) – The module whose parameters need to be averaged.
- Return type:
None
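Unlike sync_params(), which overwrites every rank with rank 0's weights, this computes the element-wise mean, hence tensor([[1.5]]) above. The sketch mirrors avg_grads_sketch but acts on the weights themselves (illustrative, reusing avg_tensor_sketch from above):

import torch

def avg_params_sketch(module: torch.nn.Module) -> None:
    # Replace each parameter with its mean across all processes.
    for param in module.parameters():
        avg_tensor_sketch(param.data)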
Distributed Operations#
Documentation
- omnisafe.utils.distributed.dist_avg(value)[source]#
Average a tensor over distributed processes.
Examples
>>> # In process 0
>>> x = torch.tensor(1.0)
>>> # In process 1
>>> x = torch.tensor(2.0)
>>> dist_avg(x)
tensor(1.5)
- Parameters:
value (np.ndarray, torch.Tensor, int, or float) – The value to be averaged.
- Returns:
Averaged tensor.
- Return type:
torch.Tensor
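Given the module's own dist_sum() and world_size(), a distributed average reduces to a sum divided by the process count. A one-line sketch (illustrative, not the library source):

from omnisafe.utils.distributed import dist_sum, world_size

def dist_avg_sketch(value):
    # Mean = global sum / number of participating processes.
    return dist_sum(value) / world_size()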
- omnisafe.utils.distributed.dist_sum(value)[source]#
Sum a tensor over distributed processes.
Examples
>>> # In process 0
>>> x = torch.tensor(1.0)
>>> # In process 1
>>> x = torch.tensor(2.0)
>>> dist_sum(x)
tensor(3.)
- Parameters:
value (np.ndarray, torch.Tensor, int, or float) – The value to be summed.
- Returns:
Summed tensor.
- Return type:
torch.Tensor
- omnisafe.utils.distributed.dist_max(value)[source]#
Determine global maximum of tensor over distributed processes.
Examples
>>> # In process 0
>>> x = torch.tensor(1.0)
>>> # In process 1
>>> x = torch.tensor(2.0)
>>> dist_max(x)
tensor(2.)
- Parameters:
value (np.ndarray, torch.Tensor, int, or float) – The value whose global maximum is to be found.
- Returns:
Maximum tensor.
- Return type:
torch.Tensor
- omnisafe.utils.distributed.dist_min(value)[source]#
Determine global minimum of tensor over distributed processes.
Examples
>>> # In process 0
>>> x = torch.tensor(1.0)
>>> # In process 1
>>> x = torch.tensor(2.0)
>>> dist_min(x)
tensor(1.)
- Parameters:
value (np.ndarray, torch.Tensor, int, or float) – The value whose global minimum is to be found.
- Returns:
Minimum tensor.
- Return type:
torch.Tensor
- omnisafe.utils.distributed.dist_op(value, operation)[source]#
Multi-processing operation.
Note
The operation can be ReduceOp.SUM, ReduceOp.MAX, or ReduceOp.MIN, corresponding to dist_sum(), dist_max(), and dist_min(), respectively.
- Parameters:
value (np.ndarray, torch.Tensor, int, or float) – The value to be operated.
operation (ReduceOp) – The reduction operation to apply.
- Returns:
The reduced tensor (the SUM, MAX, or MIN across processes).
- Return type:
torch.Tensor
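Putting the note above into code: a sketch of a generic reduction that coerces the input to a tensor and all-reduces it with the requested ReduceOp (assumes an initialized process group; dist_op_sketch is an illustrative stand-in, not the library source):

import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

def dist_op_sketch(value, operation: ReduceOp) -> torch.Tensor:
    # Accepts np.ndarray, torch.Tensor, int, or float.
    tensor = torch.as_tensor(value, dtype=torch.float32)
    dist.all_reduce(tensor, op=operation)  # in-place reduction across ranks
    return tensor

# Under this sketch, dist_sum(x) behaves like dist_op_sketch(x, ReduceOp.SUM).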
- omnisafe.utils.distributed.dist_statistics_scalar(value, with_min_and_max=False)[source]#
Get mean/std and optional min/max of scalar x across MPI processes.
Examples
>>> # In process 0
>>> x = torch.tensor(1.0)
>>> # In process 1
>>> x = torch.tensor(2.0)
>>> dist_statistics_scalar(x)
(tensor(1.5), tensor(0.5))
- Parameters:
value (torch.Tensor) – The value over which statistics are computed.
with_min_and_max (bool, optional) – Whether to also return the min and max. Defaults to False.
- Returns:
A tuple (mean, std) of the input tensor, or (mean, std, min, max) if with_min_and_max is True.
- Return type:
tuple[Tensor, ...]
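The (tensor(1.5), tensor(0.5)) result above is the global mean and population standard deviation of the per-process values. A sketch of how those statistics can be assembled from the module's own reductions (illustrative, not the library source):

import torch
from omnisafe.utils.distributed import dist_max, dist_min, dist_sum

def dist_statistics_scalar_sketch(value, with_min_and_max=False):
    value = torch.as_tensor(value, dtype=torch.float32).flatten()
    global_n = dist_sum(torch.tensor(float(value.numel())))
    mean = dist_sum(value.sum()) / global_n
    # Population std: sqrt of the mean squared deviation from the mean.
    std = torch.sqrt(dist_sum(((value - mean) ** 2).sum()) / global_n)
    if with_min_and_max:
        return mean, std, dist_min(value.min()), dist_max(value.max())
    return mean, std

# With x = 1.0 on rank 0 and x = 2.0 on rank 1: mean 1.5, std 0.5,
# matching the doctest example above.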