OmniSafe Tools
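get_flat_params_from(model) | This function is used to get the flattened parameters from the model.
get_flat_gradients_from(model) | This function is used to get the flattened gradients from the model.
set_param_values_to_model(model, vals) | This function is used to set the parameters to the model.
custom_cfgs_to_dict(key_list, value) | This function is used to convert the custom configurations to dict.
update_dict(total_dict, item_dict) | Updater of multi-level dictionary.
load_yaml(path) | Get the default kwargs from a yaml file.
recursive_check_config(config, default_config, exclude_keys=()) | Check whether config is valid in default_config.
seed_all(seed) | This function is used to set the random seed for all the packages.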
Algorithms Tools
- omnisafe.utils.tools.get_flat_params_from(model)
This function is used to get the flattened parameters from the model.
Note
Some algorithms, such as TRPO and CPO, need to get the flattened parameters from the model. In these algorithms, the parameters are flattened and then used to calculate the loss.
Examples
>>> model = torch.nn.Linear(2, 2, bias=False)
>>> model.weight.data = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
>>> get_flat_params_from(model)
tensor([1., 2., 3., 4.])
- Parameters:
model (torch.nn.Module) – The model whose parameters are flattened.
- Returns:
Flattened parameters.
- Raises:
AssertionError – If no trainable parameters were found in the model.
- Return type:
Tensor
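The flattening step can be sketched in a few lines. This is a minimal illustration, not OmniSafe's source: the name `flat_params_sketch` is hypothetical, and it assumes that only parameters with `requires_grad=True` are collected, in the order returned by `model.parameters()`. PyTorch's built-in `torch.nn.utils.parameters_to_vector` concatenates whatever parameters it is given in the same way.

```python
import torch


def flat_params_sketch(model: torch.nn.Module) -> torch.Tensor:
    """Concatenate all trainable parameters into one 1-D tensor (hypothetical helper)."""
    # Assumption: frozen parameters (requires_grad=False) are skipped.
    flat = [p.detach().reshape(-1) for p in model.parameters() if p.requires_grad]
    assert flat, 'No trainable parameters were found in the model.'
    return torch.cat(flat)


model = torch.nn.Linear(2, 2, bias=False)
model.weight.data = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
print(flat_params_sketch(model))  # tensor([1., 2., 3., 4.])

# The built-in utility gives the same values (it keeps the autograd graph, hence .detach()).
builtin = torch.nn.utils.parameters_to_vector(model.parameters()).detach()
print(torch.equal(flat_params_sketch(model), builtin))  # True
```

Skipping frozen parameters keeps the flat parameter vector aligned with the flat gradient vector, since gradients only exist for trainable parameters.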
- omnisafe.utils.tools.get_flat_gradients_from(model)
This function is used to get the flattened gradients from the model.
Note
Some algorithms, such as TRPO and CPO, need to get the flattened gradients from the model. In these algorithms, the gradients are flattened and then used to calculate the loss.
- Parameters:
model (torch.nn.Module) – The model to be flattened.
- Returns:
Flattened gradients.
- Raises:
AssertionError – If no gradients were found in model parameters.
- Return type:
Tensor
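Collecting flattened gradients works the same way, after a backward pass has populated each parameter's `.grad`. The sketch below is hypothetical (`flat_grads_sketch` is not an OmniSafe function) and uses a bias-free linear layer so the result can be checked by hand: for `loss = sum(W x)`, the gradient of every row of `W` equals `x`.

```python
import torch


def flat_grads_sketch(model: torch.nn.Module) -> torch.Tensor:
    """Concatenate the .grad of every trainable parameter (hypothetical helper)."""
    grads = [
        p.grad.reshape(-1)
        for p in model.parameters()
        if p.requires_grad and p.grad is not None
    ]
    # Mirrors the documented AssertionError when no gradients exist yet.
    assert grads, 'No gradients were found in model parameters.'
    return torch.cat(grads)


model = torch.nn.Linear(2, 2, bias=False)
x = torch.tensor([[1.0, 2.0]])
loss = model(x).sum()  # d(loss)/dW[i, j] = x[0, j] for every row i
loss.backward()
print(flat_grads_sketch(model))  # tensor([1., 2., 1., 2.])
```

Calling the helper on a fresh model that has never run `backward()` raises the `AssertionError`, because every `.grad` is still `None`.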
- omnisafe.utils.tools.set_param_values_to_model(model, vals)
This function is used to set the parameters to the model.
Note
Some algorithms (e.g., TRPO and CPO) need to set the parameters of the model directly instead of using optimizer.step().
Examples
>>> model = torch.nn.Linear(2, 2, bias=False)
>>> vals = torch.tensor([1.0, 2.0, 3.0, 4.0])
>>> set_param_values_to_model(model, vals)
>>> model.weight.data
tensor([[1., 2.],
        [3., 4.]])
- Parameters:
model (torch.nn.Module) – The model to be set.
vals (torch.Tensor) – The parameters to be set.
- Raises:
AssertionError – If vals is not a torch.Tensor, or if the length of vals does not match the total number of model parameters.
- Return type:
None
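Writing a flat vector back into a model is the inverse operation: slice the vector parameter by parameter and reshape each slice. The sketch below pairs it with the update pattern used by TRPO-style methods, `theta_new = theta_old + step_size * direction`, applied without an optimizer. The helper `set_flat_params_sketch`, the `direction` vector, and `step_size` are hypothetical; this is not OmniSafe's implementation.

```python
import torch


def set_flat_params_sketch(model: torch.nn.Module, vals: torch.Tensor) -> None:
    """Copy consecutive slices of vals into the trainable parameters (hypothetical helper)."""
    assert isinstance(vals, torch.Tensor), 'vals must be a torch.Tensor'
    params = [p for p in model.parameters() if p.requires_grad]
    total = sum(p.numel() for p in params)
    # Check the length before touching any parameter.
    assert vals.numel() == total, f'Lengths do not match: {total} vs. {vals.numel()}'
    offset = 0
    with torch.no_grad():
        for p in params:
            n = p.numel()
            # Reshape the slice to the parameter's shape and copy it in place.
            p.copy_(vals[offset:offset + n].view_as(p))
            offset += n


model = torch.nn.Linear(2, 2, bias=False)
model.weight.data = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

theta = torch.cat([p.detach().reshape(-1) for p in model.parameters()])
direction = torch.tensor([1.0, 0.0, 0.0, -1.0])  # hypothetical search direction
step_size = 0.5  # hypothetical step size
set_flat_params_sketch(model, theta + step_size * direction)
print(model.weight.data)  # weight is now [[1.5, 2.0], [3.0, 3.5]]
```

Checking the total length up front raises the `AssertionError` before any parameter is modified, so a mismatched vector never leaves the model half-updated.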
Config Tools
- omnisafe.utils.tools.custom_cfgs_to_dict(key_list, value)
This function is used to convert the custom configurations to dict.
Note
For example, if the custom configuration key is train_cfgs:use_wandb and the value is True, then the output dict will be {'train_cfgs': {'use_wandb': True}}.
- Parameters:
key_list (str) – The keys, separated by colons (e.g., train_cfgs:use_wandb).
value (Any) – The value assigned to the innermost key.
- Returns:
The converted dict.
- Return type:
dict[str, Any]
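The conversion amounts to splitting the key string on colons and nesting the value from the innermost key outward. A minimal sketch under that assumption; `cfgs_to_dict_sketch` is a hypothetical name, and it assumes `value` already has its final Python type (no string parsing).

```python
from __future__ import annotations

from typing import Any


def cfgs_to_dict_sketch(key_list: str, value: Any) -> dict[str, Any]:
    """Turn 'a:b:c' and a value into {'a': {'b': {'c': value}}} (hypothetical helper)."""
    keys = key_list.split(':')
    result: dict[str, Any] = {keys[-1]: value}
    # Wrap the dict once per remaining key, working from the inside out.
    for key in reversed(keys[:-1]):
        result = {key: result}
    return result


print(cfgs_to_dict_sketch('train_cfgs:use_wandb', True))
# {'train_cfgs': {'use_wandb': True}}
```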
- omnisafe.utils.tools.update_dict(total_dict, item_dict)
Update a multi-level (nested) dictionary in place.
- Parameters:
total_dict (dict[str, Any]) – The total dictionary.
item_dict (dict[str, Any]) – The item dictionary.
- Return type:
None
Examples
>>> total_dict = {'a': {'b': 1, 'c': 2}}
>>> item_dict = {'a': {'b': 3, 'd': 4}}
>>> update_dict(total_dict, item_dict)
>>> total_dict
{'a': {'b': 3, 'c': 2, 'd': 4}}
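The merge in this example follows a simple recursive rule: when both sides hold a dict under the same key, merge them level by level; otherwise the value from `item_dict` wins. A minimal sketch of that rule with the hypothetical name `update_dict_sketch`; its behavior when only one side is a dict (plain overwrite) is an assumption of this sketch.

```python
from __future__ import annotations

from typing import Any


def update_dict_sketch(total_dict: dict[str, Any], item_dict: dict[str, Any]) -> None:
    """Merge item_dict into total_dict in place, recursing into sub-dicts (hypothetical helper)."""
    for key, item_value in item_dict.items():
        total_value = total_dict.get(key)
        if isinstance(total_value, dict) and isinstance(item_value, dict):
            # Both sides are dicts: merge level by level instead of replacing.
            update_dict_sketch(total_value, item_value)
        else:
            # New key or leaf value: overwrite.
            total_dict[key] = item_value


total_dict = {'a': {'b': 1, 'c': 2}}
update_dict_sketch(total_dict, {'a': {'b': 3, 'd': 4}})
print(total_dict)  # {'a': {'b': 3, 'c': 2, 'd': 4}}
```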
- omnisafe.utils.tools.load_yaml(path)
Get the default kwargs from a yaml file.
Note
This function searches for the yaml file by the algorithm name and environment name. Make sure your newly implemented algorithm or environment has the same name as its yaml file.
- Parameters:
path (str) – The path of the yaml file.
- Returns:
The default kwargs.
- Raises:
AssertionError – If the yaml file is not found.
- Return type:
dict[str, Any]
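A minimal sketch of loading default kwargs from a yaml file, assuming the PyYAML package. `load_yaml_sketch` and the file name `PPO.yaml` are hypothetical, as are the file contents; the sketch only mirrors the documented contract (path in, dict out, `AssertionError` when the file is missing).

```python
from __future__ import annotations

import os
import tempfile
from typing import Any

import yaml  # PyYAML


def load_yaml_sketch(path: str) -> dict[str, Any]:
    """Read a yaml file into a dict of default kwargs (hypothetical helper)."""
    # Mirror the documented contract: AssertionError if the file is missing.
    assert os.path.isfile(path), f'{path} not found'
    with open(path, encoding='utf-8') as file:
        return yaml.safe_load(file)


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical algorithm-named config file with hypothetical contents.
    cfg_path = os.path.join(tmp, 'PPO.yaml')
    with open(cfg_path, 'w', encoding='utf-8') as f:
        f.write('seed: 0\ntrain_cfgs:\n  use_wandb: false\n')
    kwargs = load_yaml_sketch(cfg_path)

print(kwargs)  # {'seed': 0, 'train_cfgs': {'use_wandb': False}}
```

`yaml.safe_load` is used because it only builds plain Python types (dicts, lists, strings, numbers, booleans), which is all a config file needs.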
- omnisafe.utils.tools.recursive_check_config(config, default_config, exclude_keys=())
Recursively check whether config is valid against default_config.
- Parameters:
config (dict[str, Any]) – The config to be checked.
default_config (dict[str, Any]) – The default config.
exclude_keys (tuple of str, optional) – The keys to be excluded. Defaults to ().
- Raises:
AssertionError – If the type of the value is not the same as the default value.
KeyError – If the key is not in default_config.
- Return type:
None
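A recursive validator along the lines described by the Raises section can be sketched as follows. `check_config_sketch` is hypothetical; it assumes that "same type" means `isinstance(value, type(default_value))`, that excluded keys are skipped at every nesting level, and that nested dicts are checked recursively.

```python
from __future__ import annotations

from typing import Any


def check_config_sketch(
    config: dict[str, Any],
    default_config: dict[str, Any],
    exclude_keys: tuple[str, ...] = (),
) -> None:
    """Validate config keys and value types against default_config (hypothetical helper)."""
    for key, value in config.items():
        if key in exclude_keys:
            continue  # assumption: excluded keys are skipped entirely
        if key not in default_config:
            raise KeyError(f'Invalid key: {key}')
        default_value = default_config[key]
        # isinstance check against the type of the default value.
        assert isinstance(value, type(default_value)), (
            f'Type of {key!r} should be {type(default_value).__name__}, '
            f'got {type(value).__name__}'
        )
        if isinstance(value, dict):
            check_config_sketch(value, default_value, exclude_keys)


defaults = {'seed': 0, 'train_cfgs': {'use_wandb': False, 'device': 'cpu'}}
check_config_sketch({'train_cfgs': {'use_wandb': True}}, defaults)  # passes silently
```

With this rule a `bool` passes where the default is an `int`, because `bool` subclasses `int` in Python.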
Seed Tools
- omnisafe.utils.tools.seed_all(seed)
This function is used to set the random seed for all the packages.
Hint
To reproduce results, you need to set the random seed for all the packages, including numpy, random, torch, torch.cuda, and torch.backends.cudnn.
Warning
If you want to use torch.backends.cudnn.benchmark or torch.backends.cudnn.deterministic and your CUDA version is 10.2 or later, you need to set the CUBLAS_WORKSPACE_CONFIG and PYTHONHASHSEED environment variables.
- Parameters:
seed (int) – The random seed.
- Return type:
None
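A minimal sketch of what seeding "all the packages" can involve. Assumptions: the helper name `seed_all_sketch` is hypothetical, numpy and torch are seeded only if installed, and the `CUBLAS_WORKSPACE_CONFIG` value `:4096:8` is one of the settings suggested in PyTorch's reproducibility notes. This is not necessarily what `seed_all` does internally.

```python
import os
import random

try:
    import numpy as np
except ImportError:  # numpy is optional in this sketch
    np = None

try:
    import torch
except ImportError:  # torch is optional in this sketch
    torch = None


def seed_all_sketch(seed: int) -> None:
    """Seed every available RNG for reproducible runs (hypothetical helper)."""
    # Only affects subprocesses started later; hash randomization of the
    # current interpreter is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    if np is not None:
        np.random.seed(seed)
    if torch is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # silently ignored without CUDA
        # Needed by cuBLAS for deterministic results on CUDA >= 10.2;
        # must be set before the first cuBLAS call.
        os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


seed_all_sketch(42)
first = [random.random() for _ in range(3)]
seed_all_sketch(42)
print(first == [random.random() for _ in range(3)])  # True
```

Re-seeding with the same value before each draw yields identical sequences from `random`, and likewise from numpy and torch when they are installed.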