Utilities#
Configurations#
- class monai.config.deviceconfig.IgniteInfo[source]#
Config information of the PyTorch ignite package.
- monai.config.deviceconfig.get_system_info()[source]#
Get system info as an ordered dictionary.
- Return type:
OrderedDict
- monai.config.deviceconfig.print_config(file=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>)[source]#
Print the package versions to file.
- Parameters:
file – print() text stream file. Defaults to sys.stdout.
- monai.config.deviceconfig.print_debug_info(file=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>)[source]#
Print config (installed dependencies, etc.) and system info for debugging.
- Parameters:
file (TextIO) – print() text stream file. Defaults to sys.stdout.
- Return type:
None
Module utils#
- exception monai.utils.module.InvalidPyTorchVersionError(required_version, name)[source]#
Raised when called function or method requires a more recent PyTorch version than that installed.
- exception monai.utils.module.OptionalImportError[source]#
Could not import APIs from an optional dependency.
- monai.utils.module.damerau_levenshtein_distance(s1, s2)[source]#
Calculates the Damerau–Levenshtein distance between two strings for spelling correction. https://en.wikipedia.org/wiki/Damerau–Levenshtein_distance
- Return type:
int
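As an illustration of the kind of edit distance used for spelling correction, here is a minimal sketch of the restricted ("optimal string alignment") variant of the Damerau–Levenshtein distance. The name `osa_distance` is hypothetical, and this is not MONAI's implementation, which may handle transpositions differently.

```python
def osa_distance(s1: str, s2: str) -> int:
    """Restricted Damerau-Levenshtein distance (hypothetical helper, not MONAI code)."""
    rows, cols = len(s1) + 1, len(s2) + 1
    # d[i][j] holds the distance between s1[:i] and s2[:j]
    d = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        d[i][0] = i  # delete all i characters
    for j in range(cols):
        d[0][j] = j  # insert all j characters
    for i in range(1, rows):
        for j in range(1, cols):
            cost = 0 if s1[i - 1] == s2[j - 1] else 1
            d[i][j] = min(
                d[i - 1][j] + 1,         # deletion
                d[i][j - 1] + 1,         # insertion
                d[i - 1][j - 1] + cost,  # substitution or match
            )
            # transposition of two adjacent characters
            if i > 1 and j > 1 and s1[i - 1] == s2[j - 2] and s1[i - 2] == s2[j - 1]:
                d[i][j] = min(d[i][j], d[i - 2][j - 2] + 1)
    return d[-1][-1]


typo_distance = osa_distance("read", "red")  # one deletion
swap_distance = osa_distance("ab", "ba")     # one transposition
```

A small distance between a misspelled option and a supported one is what makes a "did you mean" suggestion possible.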
- monai.utils.module.exact_version(the_module, version_str='', *_args)[source]#
Returns True if the module’s __version__ matches version_str
- Return type:
bool
- monai.utils.module.export(modname)[source]#
Make the decorated object a member of the named module. This will also add the object under its aliases if it has a __aliases__ member, thus this decorator should be before the alias decorator to pick up those names. Alias names which conflict with package names or existing members will be ignored.
- monai.utils.module.get_full_type_name(typeobj)[source]#
Utility to get the full path name of a class or object type.
- monai.utils.module.get_package_version(dep_name, default='NOT INSTALLED or UNKNOWN VERSION.')[source]#
Try to load package and get version. If not found, return default.
- monai.utils.module.get_torch_version_tuple()#
- Returns:
tuple of ints representing the PyTorch major/minor version.
- monai.utils.module.instantiate(__path, __mode, **kwargs)[source]#
Create an object instance or call a callable object from a class or function represented by __path. kwargs will be part of the input arguments to the class constructor or function. The target component must be a class or a function; if it is not, the component is returned directly.
- Parameters:
__path (str) – if a string is provided, it’s interpreted as the full path of the target class or function component. If a callable is provided, __path(**kwargs) will be invoked and returned for __mode="default". For __mode="callable", the callable will be returned as __path or, if kwargs are provided, as functools.partial(__path, **kwargs) for future invoking.
__mode (str) – the operating mode for invoking the (callable) component represented by __path:
"default": returns component(**kwargs)
"callable": returns component or, if kwargs are provided, functools.partial(component, **kwargs)
"debug": returns pdb.runcall(component, **kwargs)
kwargs (Any) – keyword arguments to the callable represented by __path.
- Return type:
Any
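The "default" and "callable" modes described above can be sketched with the standard library alone. The names `locate_component` and `instantiate_sketch` are hypothetical, the "debug" mode is left out, and this is only an approximation of the documented behavior, not MONAI's code.

```python
import functools
import importlib
from fractions import Fraction
from typing import Any


def locate_component(path: str) -> Any:
    """Resolve a dotted path such as 'fractions.Fraction' (hypothetical helper)."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def instantiate_sketch(path: Any, mode: str = "default", **kwargs: Any) -> Any:
    component = locate_component(path) if isinstance(path, str) else path
    if not callable(component):
        return component  # not a class or function: returned directly
    if mode == "default":
        return component(**kwargs)
    if mode == "callable":
        # defer the call; bind kwargs only when some were given
        return functools.partial(component, **kwargs) if kwargs else component
    raise ValueError(f"unsupported mode in this sketch: {mode}")


half = instantiate_sketch("fractions.Fraction", numerator=1, denominator=2)
quarters = instantiate_sketch("fractions.Fraction", "callable", denominator=4)
three_quarters = quarters(3)  # the bound denominator is applied at call time
```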
- monai.utils.module.load_submodules(basemod, load_all=True, exclude_pattern='(.*[tT]est.*)|(_.*)')[source]#
Traverse the source of the module structure starting with module basemod, loading all packages plus all files if load_all is True, excluding anything whose name matches exclude_pattern.
- Return type:
tuple[list[module], list[str]]
- monai.utils.module.look_up_option(opt_str, supported, default='no_default', print_all_options=True)[source]#
Look up the option in the supported collection and return the matched item. Raise a value error possibly with a guess of the closest match.
- Parameters:
opt_str – The option string or Enum to look up.
supported – The collection of supported options, it can be list, tuple, set, dict, or Enum.
default – If it is given, this method will return default when opt_str is not found, instead of raising a ValueError. Otherwise, it defaults to “no_default”, so that the method may raise a ValueError.
print_all_options – whether to print all available options when opt_str is not found. Defaults to True
Examples:
    from enum import Enum
    from monai.utils import look_up_option

    class Color(Enum):
        RED = "red"
        BLUE = "blue"

    look_up_option("red", Color)  # <Color.RED: 'red'>
    look_up_option(Color.RED, Color)  # <Color.RED: 'red'>
    look_up_option("read", Color)
    # ValueError: By 'read', did you mean 'red'?
    # 'read' is not a valid option.
    # Available options are {'blue', 'red'}.
    look_up_option("red", {"red", "blue"})  # "red"
Adapted from NifTK/NiftyNet
- monai.utils.module.min_version(the_module, min_version_str='', *_args)[source]#
Convert version strings into tuples of int and compare them.
Returns True if the module’s version is greater than or equal to the ‘min_version’. When min_version_str is not provided, it always returns True.
- Return type:
bool
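A minimal sketch of the tuple-based comparison described above. It assumes plain dotted numeric versions and compares only the major and minor components; `min_version_sketch` is a hypothetical name and the real function may differ in detail.

```python
from types import SimpleNamespace


def min_version_sketch(the_module, min_version_str: str = "") -> bool:
    """Compare 'major.minor' as integer tuples (illustrative only)."""
    if not min_version_str or not hasattr(the_module, "__version__"):
        return True  # nothing to check against
    current = tuple(int(x) for x in the_module.__version__.split(".")[:2])
    required = tuple(int(x) for x in min_version_str.split(".")[:2])
    return current >= required


# stand-in for an imported module exposing __version__
fake_module = SimpleNamespace(__version__="1.13.1")
new_enough = min_version_sketch(fake_module, "1.9")
too_old = min_version_sketch(fake_module, "2.0")
```

Comparing integer tuples rather than raw strings matters here: as strings, "1.13" would sort before "1.9".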
- monai.utils.module.optional_import(module, version='', version_checker=<function min_version>, name='', descriptor='{}', version_args=None, allow_namespace_pkg=False, as_type='default')[source]#
Imports an optional module specified by module string. Any import-related exceptions will be stored, and are raised lazily when attempting to use the failed-to-import module.
- Parameters:
module (str) – name of the module to be imported.
version (str) – version string used by the version_checker.
version_checker (Callable[…, bool]) – a callable to check the module version. Defaults to monai.utils.min_version.
name (str) – a non-module attribute (such as method/class) to import from the imported module.
descriptor (str) – a format string for the final error message when using a not imported module.
version_args (Optional[Any]) – additional parameters to the version checker.
allow_namespace_pkg (bool) – whether importing a namespace package is allowed. Defaults to False.
as_type (str) – there are cases where the optionally imported object is used as a base class, or a decorator, the exceptions should raise accordingly. The current supported values are “default” (call once to raise), “decorator” (call the constructor and the second call to raise), and anything else will return a lazy class that can be used as a base class (call the constructor to raise).
- Return type:
tuple[Any, bool]
- Returns:
The imported module and a boolean flag indicating whether the import is successful.
Examples:
    >>> torch, flag = optional_import('torch', '1.1')
    >>> print(torch, flag)
    <module 'torch' from 'python/lib/python3.6/site-packages/torch/__init__.py'> True

    >>> the_module, flag = optional_import('unknown_module')
    >>> print(flag)
    False
    >>> the_module.method  # trying to access a module which is not imported
    OptionalImportError: import unknown_module (No module named 'unknown_module').

    >>> torch, flag = optional_import('torch', '42', exact_version)
    >>> torch.nn  # trying to access a module for which there isn't a proper version imported
    OptionalImportError: import torch (requires version '42' by 'exact_version').

    >>> conv, flag = optional_import('torch.nn.functional', '1.0', name='conv1d')
    >>> print(conv)
    <built-in method conv1d of type object at 0x11a49eac0>

    >>> conv, flag = optional_import('torch.nn.functional', '42', name='conv1d')
    >>> conv()  # trying to use a function from the not successfully imported module (due to unmatched version)
    OptionalImportError: from torch.nn.functional import conv1d (requires version '42' by 'min_version').
- monai.utils.module.pytorch_after(major, minor, patch=0, current_ver_string=None)#
Compute whether the current pytorch version is after or equal to the specified version. The current system pytorch version is determined by torch.__version__ or via system environment variable PYTORCH_VER.
- Parameters:
major – major version number to be compared with
minor – minor version number to be compared with
patch – patch version number to be compared with
current_ver_string – if None, torch.__version__ will be used.
- Returns:
True if the current pytorch version is greater than or equal to the specified version.
- monai.utils.module.require_pkg(pkg_name, version='', version_checker=<function min_version>, raise_error=True)[source]#
Decorator function to check the required package installation.
- Parameters:
pkg_name (str) – required package name, like: “itk”, “nibabel”, etc.
version (str) – required version string used by the version_checker.
version_checker (Callable[…, bool]) – a callable to check the module version, defaults to monai.utils.min_version.
raise_error (bool) – if True, raise OptionalImportError if the required package is not installed or the version doesn’t match the requirement; if False, print the error in a warning.
- Return type:
Callable
- monai.utils.module.version_geq(lhs, rhs)[source]#
Returns True if version lhs is later or equal to rhs.
- Parameters:
lhs (str) – version name to compare with rhs, return True if later or equal to rhs.
rhs (str) – version name to compare with lhs, return True if earlier or equal to lhs.
- Return type:
bool
- monai.utils.module.version_leq(lhs, rhs)[source]#
Returns True if version lhs is earlier or equal to rhs.
- Parameters:
lhs (str) – version name to compare with rhs, return True if earlier or equal to rhs.
rhs (str) – version name to compare with lhs, return True if later or equal to lhs.
- Return type:
bool
Aliases#
This module is written for configurable workflow, not currently in use.
- monai.utils.aliases.alias(*names)[source]#
Stores the decorated function or class in the global aliases table under the given names and as the __aliases__ member of the decorated object. This new member will contain all alias names declared for that object.
- monai.utils.aliases.resolve_name(name)[source]#
Search for the declaration (function or class) with the given name. This will first search the list of aliases to see if it was declared with this aliased name, then search treating name as a fully qualified name, then search the loaded modules for one having a declaration with the given name. If no declaration is found, raise ValueError.
- Raises:
ValueError – When the module is not found.
ValueError – When the module does not have the specified member.
ValueError – When multiple modules with the declaration name are found.
ValueError – When no module with the specified member is found.
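The alias table idea can be sketched as a small decorator that records names in a module-level dictionary. `ALIAS_TABLE` and `alias_sketch` are hypothetical names, and the sketch omits any locking or name-resolution logic the real module may have.

```python
ALIAS_TABLE = {}  # hypothetical global table: alias name -> decorated object


def alias_sketch(*names):
    """Register the decorated function or class under each given alias name."""

    def _outer(obj):
        for name in names:
            ALIAS_TABLE[name] = obj
        # keep every alias declared for this object on the object itself
        obj.__aliases__ = getattr(obj, "__aliases__", ()) + tuple(names)
        return obj

    return _outer


@alias_sketch("relu_fn", "rectifier")
def relu(x: float) -> float:
    return max(x, 0.0)


looked_up = ALIAS_TABLE["rectifier"]  # same object as relu
```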
Misc#
- class monai.utils.misc.ConvertUnits(input_unit, target_unit)[source]#
Convert the values from input unit to the target unit
- Parameters:
input_unit (str) – the unit of the input quantity
target_unit (str) – the unit of the target quantity
- monai.utils.misc.check_key_duplicates(ordered_pairs)[source]#
Checks if there is a duplicated key in the sequence of ordered_pairs. If there is, it will log a warning or raise ValueError (if configured by the environment variable MONAI_FAIL_ON_DUPLICATE_CONFIG==1).
Otherwise, it returns the dict made from this sequence.
Satisfies a format for an object_pairs_hook in json.load
- Parameters:
ordered_pairs (Sequence[tuple[Any, Any]]) – sequence of (key, value)
- Return type:
dict[Any, Any]
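To show how a function of this shape plugs into json loading as an object_pairs_hook, here is a minimal stand-in. `check_key_duplicates_sketch` is a hypothetical name, and a plain `fail_on_duplicate` argument replaces the environment variable for the sake of a self-contained example.

```python
import json
import warnings


def check_key_duplicates_sketch(ordered_pairs, fail_on_duplicate=False):
    """Warn (or raise) on repeated keys, then build the dict (illustrative only)."""
    seen = set()
    for key, _ in ordered_pairs:
        if key in seen:
            if fail_on_duplicate:
                raise ValueError(f"Duplicate key: `{key}`")
            warnings.warn(f"Duplicate key: `{key}`")
        seen.add(key)
    return dict(ordered_pairs)  # later values win, as with a plain dict


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # json calls the hook with the (key, value) pairs of each decoded object
    config = json.loads('{"lr": 0.1, "lr": 0.01}', object_pairs_hook=check_key_duplicates_sketch)
```

Without such a hook, json silently keeps the last value for a repeated key, which is easy to miss in a long config file.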
- monai.utils.misc.check_kwargs_exist_in_class_init(cls, kwargs)[source]#
Check if all keys in kwargs exist in the __init__ method of the class.
- Parameters:
cls – the class to check.
kwargs – kwargs to examine.
- Returns:
a boolean indicating whether all keys exist, and a set of extra keys that are not used in the __init__.
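One way to perform this kind of check is with inspect.signature. The sketch below uses hypothetical names (`check_kwargs_sketch`, `Crop`) and is not necessarily how MONAI implements it.

```python
import inspect


def check_kwargs_sketch(cls, kwargs):
    """Return (all_keys_exist, extra_keys) for the class __init__ (illustrative)."""
    init_params = set(inspect.signature(cls.__init__).parameters) - {"self"}
    extra_keys = set(kwargs) - init_params
    return not extra_keys, extra_keys


class Crop:  # hypothetical class used only for the demonstration
    def __init__(self, roi_size, random_center=True):
        self.roi_size = roi_size
        self.random_center = random_center


ok, extra = check_kwargs_sketch(Crop, {"roi_size": 4, "mode": "nearest"})
```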
- monai.utils.misc.check_parent_dir(path, create_dir=True)[source]#
Utility to check whether the parent directory of the path exists.
- Parameters:
path (Union[str, PathLike]) – input path to check the parent directory.
create_dir (bool) – if True, when the parent directory doesn’t exist, create the directory, otherwise, raise exception.
- Return type:
None
- monai.utils.misc.copy_to_device(obj, device, non_blocking=True, verbose=False)[source]#
Copy object or tuple/list/dictionary of objects to device.
- Parameters:
obj – object or tuple/list/dictionary of objects to move to device.
device – move obj to this device. Can be a string (e.g., cpu, cuda, cuda:0, etc.) or of type torch.device.
non_blocking – when True, moves data to device asynchronously if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.
verbose – when True, will print a warning for any elements of incompatible type not copied to device.
- Returns:
Same as input, copied to device where possible. Original input will be unchanged.
- monai.utils.misc.ensure_tuple(vals, wrap_array=False)[source]#
Returns a tuple of vals.
- Parameters:
vals (Any) – input data to convert to a tuple.
wrap_array (bool) – if True, treat the input numerical array (ndarray/tensor) as one item of the tuple. If False, try to convert the array with tuple(vals). Defaults to False.
- Return type:
tuple
- monai.utils.misc.ensure_tuple_rep(tup, dim)[source]#
Returns a copy of tup with dim values, by either shortening or duplicating the input.
- Raises:
ValueError – When tup is a sequence and tup length is not dim.
Examples:
    >>> ensure_tuple_rep(1, 3)
    (1, 1, 1)
    >>> ensure_tuple_rep(None, 3)
    (None, None, None)
    >>> ensure_tuple_rep('test', 3)
    ('test', 'test', 'test')
    >>> ensure_tuple_rep([1, 2, 3], 3)
    (1, 2, 3)
    >>> ensure_tuple_rep(range(3), 3)
    (0, 1, 2)
    >>> ensure_tuple_rep([1, 2], 3)
    ValueError: Sequence must have length 3, got length 2.
- Return type:
tuple[Any, …]
- monai.utils.misc.ensure_tuple_size(vals, dim, pad_val=0, pad_from_start=False)[source]#
Returns a copy of vals with dim values, either shortened or padded with pad_val as necessary.
- Return type:
tuple
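The shorten-or-pad behavior can be illustrated with a few lines of plain Python. `ensure_tuple_size_sketch` is a hypothetical name and the sketch assumes the input is already a sequence.

```python
def ensure_tuple_size_sketch(vals, dim, pad_val=0, pad_from_start=False):
    """Truncate or pad a sequence to exactly dim items (illustrative only)."""
    tup = tuple(vals)
    padding = (pad_val,) * max(dim - len(tup), 0)
    padded = padding + tup if pad_from_start else tup + padding
    return padded[:dim]  # truncation only takes effect when tup was too long


padded_end = ensure_tuple_size_sketch((1, 2), 4)
padded_start = ensure_tuple_size_sketch((1, 2), 4, pad_from_start=True)
shortened = ensure_tuple_size_sketch((1, 2, 3), 2)
```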
- monai.utils.misc.fall_back_tuple(user_provided, default, func=<function <lambda>>)[source]#
Refine user_provided according to the default, and returns as a validated tuple.
The validation is done for each element in user_provided using func. If func(user_provided[idx]) returns False, the corresponding default[idx] will be used as the fallback.
Typically used when user_provided is a tuple of window size provided by the user, default is defined by data, this function returns an updated user_provided with its non-positive components replaced by the corresponding components from default.
- Parameters:
user_provided – item to be validated.
default – a sequence used to provide the fallbacks.
func – a Callable to validate every component of user_provided.
Examples:
    >>> fall_back_tuple((1, 2), (32, 32))
    (1, 2)
    >>> fall_back_tuple(None, (32, 32))
    (32, 32)
    >>> fall_back_tuple((-1, 10), (32, 32))
    (32, 10)
    >>> fall_back_tuple((-1, None), (32, 32))
    (32, 32)
    >>> fall_back_tuple((1, None), (32, 32))
    (1, 32)
    >>> fall_back_tuple(0, (32, 32))
    (32, 32)
    >>> fall_back_tuple(range(3), (32, 64, 48))
    (32, 1, 2)
    >>> fall_back_tuple([0], (32, 32))
    ValueError: Sequence must have length 2, got length 1.
- monai.utils.misc.first(iterable: Iterable[T], default: T) T [source]#
- monai.utils.misc.first(iterable: Iterable[T]) T | None
Returns the first item in the given iterable or default if empty, meaningful mostly with ‘for’ expressions.
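The remark about 'for' expressions refers to pulling the first match out of a generator without building a list. A minimal equivalent, with the hypothetical name `first_sketch`:

```python
def first_sketch(iterable, default=None):
    """Return the first item of iterable, or default if it is empty."""
    for item in iterable:
        return item  # stop at the very first element
    return default


# first value greater than 3, taken lazily from a generator expression
first_large = first_sketch(x for x in range(10) if x > 3)
nothing = first_sketch([], default="empty")
```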
- monai.utils.misc.has_option(obj, keywords)[source]#
Return a boolean indicating whether the given callable obj has the keywords in its signature.
- monai.utils.misc.is_immutable(obj)[source]#
Determine if the object is an immutable object.
see also python/cpython
- Return type:
bool
- monai.utils.misc.is_module_ver_at_least(module, version)[source]#
Determine if a module’s version is at least equal to the given value.
- Parameters:
module – imported module’s name, e.g., np or torch.
version – required version, given as a tuple, e.g., (1, 8, 0).
- Returns:
True if module is the given version or newer.
- monai.utils.misc.issequenceiterable(obj)[source]#
Determine if the object is an iterable sequence and is not a string.
- Return type:
bool
- monai.utils.misc.label_union(x)[source]#
Compute the union of class IDs in label and generate a list to include all class IDs.
- Parameters:
x – a list of numbers (for example, class_IDs)
- Returns:
a list showing the union of the class IDs
- monai.utils.misc.list_to_dict(items)[source]#
To convert a list of “key=value” pairs into a dictionary. For examples: items: [“a=1”, “b=2”, “c=3”], return: {“a”: “1”, “b”: “2”, “c”: “3”}. If no “=” in the pair, use None as the value, for example: [“a”], return: {“a”: None}. Note that it will remove the blanks around keys and values.
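The conversion described above can be sketched as follows. `list_to_dict_sketch` is a hypothetical name; the sketch keeps values as strings, exactly as in the documented example, and does not attempt any further type conversion.

```python
def list_to_dict_sketch(items):
    """Turn ["a=1", "b"] into {"a": "1", "b": None} (illustrative only)."""
    result = {}
    for item in items:
        # split on the first "=" only, so values may themselves contain "="
        key, sep, value = item.partition("=")
        result[key.strip()] = value.strip() if sep else None
    return result


parsed = list_to_dict_sketch(["a=1", " b = 2 ", "c"])
```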
- monai.utils.misc.path_to_uri(path)[source]#
Convert a file path to URI. If the path is not absolute, it is converted to an absolute path first.
- Parameters:
path (Union[str, PathLike]) – input file path to convert, can be a string or Path object.
- Return type:
str
- monai.utils.misc.pprint_edges(val, n_lines=20)[source]#
Pretty print the head and tail n_lines of val, and omit the middle part if the part has more than 3 lines.
Returns: the formatted string.
- Return type:
str
- monai.utils.misc.progress_bar(index, count, desc=None, bar_len=30, newline=False)[source]#
Print a progress bar to track a time-consuming task.
- Parameters:
index – current status in progress.
count – total steps of the progress.
desc – description of the progress bar, if not None, show before the progress bar.
bar_len – the total length of the bar on screen, default is 30 char.
newline – whether to print in a new line for every index.
- monai.utils.misc.run_cmd(cmd_list, **kwargs)[source]#
Run a command by using subprocess.run with capture_output=True and stderr=subprocess.STDOUT so that the raised exception will have that information. The argument capture_output can be set explicitly if desired, but will be overridden with the debug status from the variable.
- Parameters:
cmd_list (list[str]) – a list of strings describing the command to run.
kwargs (Any) – keyword arguments supported by the subprocess.run method.
- Return type:
CompletedProcess
- Returns:
a CompletedProcess instance after the command completes.
- monai.utils.misc.sample_slices(data, dim=1, as_indices=True, *slicevals)[source]#
Sample several slices of input numpy array or Tensor on specified dim.
- Parameters:
data (Union[ndarray, Tensor]) – input data to sample slices, can be numpy array or PyTorch Tensor.
dim (int) – expected dimension index to sample slices, default to 1.
as_indices (bool) – if True, slicevals arg will be treated as the expected indices of slice, like: 1, 3, 5 means data[…, [1, 3, 5], …], if False, slicevals arg will be treated as args for slice func, like: 1, None means data[…, [1:], …], 1, 5 means data[…, [1: 5], …].
slicevals (int) – indices of slices or start and end indices of expected slices, depends on as_indices flag.
- Return type:
Union[ndarray, Tensor]
- monai.utils.misc.save_obj(obj, path, create_dir=True, atomic=True, func=None, **kwargs)[source]#
Save an object to file with specified path. Supports serializing to a temporary file first, then moving it to the final destination, so that files are guaranteed not to be damaged if an exception occurs.
- Parameters:
obj – input object data to save.
path – target file path to save the input object.
create_dir – whether to create the directory of the path if it does not exist, default to True.
atomic – if True, state is serialized to a temporary file first, then moved to the final destination, so that files are guaranteed not to be damaged if an exception occurs. Default to True.
func – the function to save file, if None, default to torch.save.
kwargs – other args for the save func except for the checkpoint and filename. default func is torch.save(), details of other args: https://pytorch.org/docs/stable/generated/torch.save.html.
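The write-to-temporary-then-move pattern behind the atomic option can be sketched with the standard library. `save_atomic_sketch` is a hypothetical name, and pickle stands in for torch.save so that the example has no third-party dependencies.

```python
import os
import pickle
import tempfile


def save_atomic_sketch(obj, path, func=None):
    """Serialize to a temp file in the target directory, then move it into place."""
    path = os.path.abspath(os.fspath(path))
    parent = os.path.dirname(path)
    os.makedirs(parent, exist_ok=True)  # mirrors create_dir=True
    fd, tmp_path = tempfile.mkstemp(dir=parent, suffix=".tmp")
    os.close(fd)
    try:
        if func is None:
            with open(tmp_path, "wb") as f:
                pickle.dump(obj, f)  # stand-in for the default save function
        else:
            func(obj, tmp_path)
        os.replace(tmp_path, path)  # the target never holds a half-written file
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)  # clean up if saving failed before the move


with tempfile.TemporaryDirectory() as workdir:
    target = os.path.join(workdir, "sub", "obj.pkl")
    save_atomic_sketch({"epoch": 3}, target)
    with open(target, "rb") as f:
        restored = pickle.load(f)
    leftover_files = os.listdir(os.path.dirname(target))
```

Creating the temporary file in the same directory as the target keeps the final move on one filesystem, which is what allows os.replace to swap the file in a single step.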
- monai.utils.misc.set_determinism(seed=4294967295, use_deterministic_algorithms=None, additional_settings=None)[source]#
Set random seed for modules to enable or disable deterministic training.
- Parameters:
seed – the random seed to use, default is np.iinfo(np.uint32).max. It is recommended to set a large seed, i.e. a number that has a good balance of 0 and 1 bits. Avoid having many 0 bits in the seed. If set to None, deterministic training is disabled.
use_deterministic_algorithms – Set whether PyTorch operations must use “deterministic” algorithms.
additional_settings – additional settings that need to set random seed.
Note
This function will not affect the randomizable objects in monai.transforms.Randomizable, which have independent random states. For those objects, the set_random_state() method should be used to ensure the deterministic behavior (alternatively, monai.data.DataLoader by default sets the seeds according to the global random state, please see also: monai.data.utils.worker_init_fn and monai.data.utils.set_rnd).
- monai.utils.misc.str2bool(value, default=False, raise_exc=True)[source]#
Convert a string to a boolean. Case insensitive. True: yes, true, t, y, 1. False: no, false, f, n, 0.
- Parameters:
value – string to be converted to a boolean. If value is a bool already, simply return it.
raise_exc – if value not in tuples of expected true or false inputs, should we raise an exception? If not, return default.
- Raises:
ValueError – value not in tuples of expected true or false inputs and raise_exc is True.
Useful with argparse, for example:
    parser.add_argument("--convert", default=False, type=str2bool)
    python mycode.py --convert=True
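A self-contained sketch of the conversion rules listed above, wired into argparse. `str2bool_sketch` is a hypothetical stand-in that follows the documented true/false strings; it is not the library function itself.

```python
import argparse

TRUE_STRINGS = ("yes", "true", "t", "y", "1")
FALSE_STRINGS = ("no", "false", "f", "n", "0")


def str2bool_sketch(value, default=False, raise_exc=True):
    """Case-insensitive string to bool conversion (illustrative only)."""
    if isinstance(value, bool):
        return value  # already a bool: return it unchanged
    if isinstance(value, str):
        lowered = value.lower()
        if lowered in TRUE_STRINGS:
            return True
        if lowered in FALSE_STRINGS:
            return False
    if raise_exc:
        raise ValueError(f"Got {value!r}, expected one of {TRUE_STRINGS} or {FALSE_STRINGS}.")
    return default


parser = argparse.ArgumentParser()
parser.add_argument("--convert", default=False, type=str2bool_sketch)
# equivalent to running: python mycode.py --convert=True
args = parser.parse_args(["--convert=True"])
```

Using a converter like this as the argparse type avoids the common pitfall of type=bool, where any non-empty string (including "False") is truthy.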
- monai.utils.misc.str2list(value, raise_exc=True)[source]#
Convert a string to a list. Useful with argparse commandline arguments:
    parser.add_argument("--blocks", default=[1,2,3], type=str2list)
    python mycode.py --blocks=1,2,2,4
- Parameters:
value – string (comma separated) to be converted to a list
raise_exc – if not possible to convert to a list, raise an exception
- Raises:
ValueError – value not a string or list or not possible to convert
- monai.utils.misc.to_tuple_of_dictionaries(dictionary_of_tuples, keys)[source]#
Given a dictionary whose values contain scalars or tuples (with the same length as keys), create a dictionary for each key containing the scalar values mapping to that key.
- Parameters:
dictionary_of_tuples (dict) – a dictionary whose values are scalars or tuples whose length is the length of keys
keys (Any) – a tuple of string values representing the keys in question
- Return type:
tuple[dict[Any, Any], …]
- Returns:
a tuple of dictionaries that contain scalar values, one dictionary for each key
- Raises:
ValueError – when values in the dictionary are tuples but not the same length as the length of keys
Examples
    >>> to_tuple_of_dictionaries({'a': 1, 'b': (2, 3), 'c': (4, 4)}, ("x", "y"))
    ({'a':1, 'b':2, 'c':4}, {'a':1, 'b':3, 'c':4})
NVTX Annotations#
Decorators and context managers for NVIDIA Tools Extension to profile MONAI components
- class monai.utils.nvtx.Range(name=None, methods=None, append_method_name=None, recursive=False)[source]#
A decorator and context manager for NVIDIA Tools Extension (NVTX) Range for profiling. When used as a decorator it encloses a specific method of the object with an NVTX Range. When used as a context manager, it encloses the runtime context (created by with statement) with an NVTX Range.
- Parameters:
name – the name to be associated to the range
methods – (only when used as decorator) the name of a method (or a list of the names of the methods) to be wrapped by NVTX range. If None (default), the method(s) will be inferred based on the object’s type for various MONAI components, such as Networks, Losses, Functions, Transforms, and Datasets. Otherwise, it looks up predefined methods: “forward”, “__call__”, “__next__”, “__getitem__”
append_method_name – whether to append the name of the methods to be decorated to the range’s name. If None (default), it appends the method’s name only if we are annotating more than one method.
recursive – if set to True, it will recursively annotate every individual module in a list or in a chain of modules (chained using Compose). Default to False.
Profiling#
- class monai.utils.profiling.PerfContext[source]#
Context manager for tracking how much time is spent within context blocks. This uses time.perf_counter to accumulate the total amount of time in seconds in the attribute total_time over however many context blocks the object is used in.
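The accumulate-over-many-blocks behavior can be illustrated with a tiny context manager built on time.perf_counter. `PerfTimerSketch` is a hypothetical class written for this example, not the library class.

```python
import time


class PerfTimerSketch:
    """Accumulates elapsed seconds across every with-block it is used in."""

    def __init__(self):
        self.total_time = 0.0
        self._start = None

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # add this block's duration to the running total
        self.total_time += time.perf_counter() - self._start
        self._start = None


timer = PerfTimerSketch()
with timer:
    time.sleep(0.01)
after_first = timer.total_time
with timer:
    time.sleep(0.01)
after_second = timer.total_time  # includes both blocks
```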
- class monai.utils.profiling.ProfileHandler(name, profiler, start_event, end_event)[source]#
Handler for Ignite Engine classes which measures the time from a start event to an end event. This can be used to profile epoch, iteration, and other events as defined in ignite.engine.Events. This class should be used only within the context of a profiler object.
- Parameters:
name – name of event to profile
profiler – instance of WorkflowProfiler used by the handler, should be within the context of this object
start_event – item in ignite.engine.Events stating event at which to start timing
end_event – item in ignite.engine.Events stating event at which to stop timing
- class monai.utils.profiling.WorkflowProfiler(call_selector=<function select_transform_call>)[source]#
Profiler for timing all aspects of a workflow. This includes using stack tracing to capture call times for all selected calls (by default calls to Transform.__call__ methods), times within context blocks, times to generate items from iterables, and times to execute decorated functions.
This profiler must be used only within its context because it uses an internal thread to read results from a multiprocessing queue. This allows the profiler to function across multiple threads and processes, though the multiprocess tracing is at times unreliable and not available in Windows at all.
The profiler uses sys.settrace and threading.settrace to find all calls to profile, this will be set when the context enters and cleared when it exits so proper use of the context is essential to prevent excessive tracing. Note that tracing has a high overhead so times will not accurately reflect real world performance but give an idea of relative share of time spent.
The tracing functionality uses a selector to choose which calls to trace, since tracing all calls induces infinite loops and would be terribly slow even if not. This selector is a callable accepting a call trace frame and returns True if the call should be traced. The default is select_transform_call which will return True for Transform.__call__ calls only.
Example showing use of all profiling functions:
    import monai.transforms as mt
    from monai.utils import WorkflowProfiler
    import torch

    comp = mt.Compose([mt.ScaleIntensity(), mt.RandAxisFlip(0.5)])

    with WorkflowProfiler() as wp:
        for _ in wp.profile_iter("range", range(5)):
            with wp.profile_ctx("Loop"):
                for i in range(10):
                    comp(torch.rand(1, 16, 16))

        @wp.profile_callable()
        def foo():
            pass

        foo()
        foo()

    print(wp.get_times_summary_pd())  # print results
- Parameters:
call_selector – selector to determine which calls to trace, use None to disable tracing
- add_result(result)[source]#
Add a result in a thread-safe manner to the internal results dictionary.
- Return type:
None
- dump_csv(stream=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>)[source]#
Save all results to a csv file.
- get_results()[source]#
Get a fresh results dictionary containing fresh tuples of ProfileResult objects.
- get_times_summary(times_in_s=True)[source]#
Returns a dictionary mapping results entries to tuples containing the number of items, time sum, time average, time std dev, time min, and time max.
- get_times_summary_pd(times_in_s=True)[source]#
Returns the same information as get_times_summary but in a Pandas DataFrame.
- profile_callable(name=None)[source]#
Decorator which can be applied to a function which profiles any calls to it. All calls to decorated callables must be done within the context of the profiler.
- profile_ctx(name, caller=None)#
Creates a context to profile, placing a timing result onto the queue when it exits.
- monai.utils.profiling.select_transform_call(frame)[source]#
Returns True if frame is a call to a Transform object’s __call__ method.
- monai.utils.profiling.torch_profiler_full(func)[source]#
A decorator which will run the torch profiler for the decorated function, printing the results in full. Note: Enforces a gpu sync point which could slow down pipelines.
Deprecated#
- monai.utils.deprecate_utils.deprecated(since=None, removed=None, msg_suffix='', version_val='0+untagged.50.g59a7211.dirty', warning_category=<class 'FutureWarning'>)[source]#
Marks a function or class as deprecated. If since is given this should be a version at or earlier than the current version and states the version at which the definition was marked as deprecated. If removed is given this can be any version and marks when the definition was removed.
When the decorated definition is called, that is when the function is called or the class instantiated, a warning_category is issued if since is given and the current version is at or later than that given. A DeprecatedError exception is instead raised if removed is given and the current version is at or later than that, or if neither since nor removed is provided.
The relevant docstring of the deprecating function should also be updated accordingly, using the Sphinx directives such as .. versionchanged:: version and .. deprecated:: version. https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded
- Parameters:
since – version at which the definition was marked deprecated but not removed.
removed – version at which the definition was/will be removed and no longer usable.
msg_suffix – message appended to warning/exception detailing reasons for deprecation and what to use instead.
version_val – (used for testing) version to compare since and removed against, default is MONAI version.
warning_category – a warning category class, defaults to FutureWarning.
- Returns:
Decorated definition which warns or raises exception when used
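The warn-or-raise rules above can be sketched for plain functions as follows. `deprecated_sketch`, `DeprecatedErrorSketch` and `_version_tuple` are hypothetical names, the sketch assumes simple dotted numeric versions, and it does not cover classes.

```python
import functools
import warnings


class DeprecatedErrorSketch(Exception):
    """Hypothetical stand-in for the error raised once a definition is removed."""


def _version_tuple(version):
    # assumes plain dotted numeric versions such as "1.2"
    return tuple(int(part) for part in version.split(".") if part.isdigit())


def deprecated_sketch(since=None, removed=None, msg_suffix="", version_val="1.0",
                      warning_category=FutureWarning):
    current = _version_tuple(version_val)
    is_removed = removed is not None and current >= _version_tuple(removed)
    is_deprecated = since is not None and current >= _version_tuple(since)

    def _decorator(func):
        @functools.wraps(func)
        def _wrapper(*args, **kwargs):
            msg = f"{func.__qualname__} is deprecated. {msg_suffix}".strip()
            # raise if removed, or if neither since nor removed was given
            if is_removed or (since is None and removed is None):
                raise DeprecatedErrorSketch(msg)
            if is_deprecated:
                warnings.warn(msg, category=warning_category, stacklevel=2)
            return func(*args, **kwargs)

        return _wrapper

    return _decorator


@deprecated_sketch(since="1.0", version_val="1.2", msg_suffix="Use new_add instead.")
def old_add(a, b):
    return a + b


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = old_add(1, 2)  # still works, but emits a FutureWarning
```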
- monai.utils.deprecate_utils.deprecated_arg(name, since=None, removed=None, msg_suffix='', version_val='0+untagged.50.g59a7211.dirty', new_name=None, warning_category=<class 'FutureWarning'>)[source]#
Marks a particular named argument of a callable as deprecated. The same conditions for since and removed as described in the deprecated decorator.
When the decorated definition is called, that is when the function is called or the class instantiated with args, a warning_category is issued if since is given and the current version is at or later than that given. A DeprecatedError exception is instead raised if removed is given and the current version is at or later than that, or if neither since nor removed is provided.
The relevant docstring of the deprecating function should also be updated accordingly, using the Sphinx directives such as .. versionchanged:: version and .. deprecated:: version. https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded
- Parameters:
name – name of position or keyword argument to mark as deprecated.
since – version at which the argument was marked deprecated but not removed.
removed – version at which the argument was/will be removed and no longer usable.
msg_suffix – message appended to warning/exception detailing reasons for deprecation and what to use instead.
version_val – (used for testing) version to compare since and removed against, default is MONAI version.
new_name – name of position or keyword argument to replace the deprecated argument. if it is specified and the signature of the decorated function has a kwargs, the value to the deprecated argument name will be removed.
warning_category – a warning category class, defaults to FutureWarning.
- Returns:
Decorated callable which warns or raises exception when deprecated argument used.
- monai.utils.deprecate_utils.deprecated_arg_default(name, old_default, new_default, since=None, replaced=None, msg_suffix='', version_val='0+untagged.50.g59a7211.dirty', warning_category=<class 'FutureWarning'>)[source]#
Marks a particular argument's default of a callable as deprecated. It is changed from old_default to new_default in version replaced.
When the decorated definition is called, a warning_category is issued if since is given, the default is not explicitly set by the caller and the current version is at or later than that given. Another warning with the same category is issued if replaced is given and the current version is at or later.
The relevant docstring of the deprecating function should also be updated accordingly, using the Sphinx directives such as .. versionchanged:: version and .. deprecated:: version. https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded
- Parameters:
name – name of position or keyword argument where the default is deprecated/changed.
old_default – name of the old default. This is only for the warning message, it will not be validated.
new_default – name of the new default. It is validated that this value is not present as the default before version replaced. This means that you can also use this if the actual default value is None and set later in the function. You can also set this to any string representation, e.g. “calculate_default_value()” if the default is calculated from another function.
since – version at which the argument default was marked deprecated but not replaced.
replaced – version at which the argument default was/will be replaced.
msg_suffix – message appended to warning/exception detailing reasons for deprecation.
version_val – (used for testing) version to compare since and replaced against, default is MONAI version.
warning_category – a warning category class, defaults to FutureWarning.
- Returns:
Decorated callable which warns when deprecated default argument is not explicitly specified.
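The idea of warning only when the caller relies on the default can be sketched with the standard library. This is a simplified illustration under stated assumptions, not MONAI's code: the decorator name warn_on_implicit_default and the function resize are hypothetical, and version checks are omitted.

```python
import functools
import inspect
import warnings


def warn_on_implicit_default(name, old_default, new_default, warning_category=FutureWarning):
    """Warn when the decorated callable is called without an explicit value for `name`."""

    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # bind() records only the arguments the caller actually passed
            bound = sig.bind(*args, **kwargs)
            if name not in bound.arguments:
                warnings.warn(
                    f"Default of argument `{name}` will change from {old_default!r} to {new_default!r}.",
                    warning_category,
                    stacklevel=2,
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator


@warn_on_implicit_default("mode", old_default="nearest", new_default="bilinear")
def resize(size, mode="nearest"):
    return size, mode


print(resize(4, mode="nearest"))  # explicit value, so no warning is issued
```

Calling resize(4) without mode issues one FutureWarning, while passing mode explicitly stays silent.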
Type conversion#
- monai.utils.type_conversion.convert_data_type(data, output_type=None, device=None, dtype=None, wrap_sequence=False, safe=False)[source]#
Convert to MetaTensor, torch.Tensor or np.ndarray from MetaTensor, torch.Tensor, np.ndarray, float, int, etc.
- Parameters:
data – data to be converted
output_type – monai.data.MetaTensor, torch.Tensor, or np.ndarray (if None, unchanged)
device – if output is MetaTensor or torch.Tensor, select device (if None, unchanged)
dtype – dtype of output data. Converted to correct library type (e.g., np.float32 is converted to torch.float32 if output type is torch.Tensor). If left blank, it remains unchanged.
wrap_sequence – if False, then lists will recursively call this function. E.g., [1, 2] -> [array(1), array(2)]. If True, then [1, 2] -> array([1, 2]).
safe – if True, the dtype conversion clips values to the range of the target dtype instead of letting them overflow. Defaults to False. E.g., with False, [256, -12] -> [array(0), array(244)]; with True, [256, -12] -> [array(255), array(0)].
- Returns:
modified data, orig_type, orig_device
Note
When both output_type and dtype are specified with different backends (e.g., torch.Tensor and np.float32), the output_type will be used as the primary type, for example:
>>> convert_data_type(1, torch.Tensor, dtype=np.float32)
(1.0, <class 'torch.Tensor'>, None)
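The two results quoted for the safe parameter can be reproduced with plain integer arithmetic. The sketch below assumes the target dtype is an unsigned 8-bit integer, which is what the numbers [0, 244] and [255, 0] suggest. The helper names are hypothetical and the code does not call MONAI.

```python
UINT8_MIN, UINT8_MAX = 0, 255  # assumed target range (unsigned 8-bit)


def wrap_to_uint8(values):
    """Overflowing conversion: values wrap around modulo 256 (the safe=False case)."""
    return [v % (UINT8_MAX + 1) for v in values]


def clip_to_uint8(values):
    """Clipping conversion: values are limited to the valid range first (the safe=True case)."""
    return [min(max(v, UINT8_MIN), UINT8_MAX) for v in values]


print(wrap_to_uint8([256, -12]))  # [0, 244]
print(clip_to_uint8([256, -12]))  # [255, 0]
```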
- monai.utils.type_conversion.convert_to_cupy(data, dtype=None, wrap_sequence=False, safe=False)[source]#
Utility to convert the input data to a cupy array. If passing a dictionary, list or tuple, recursively check every item and convert it to cupy array.
- Parameters:
data – input data can be PyTorch Tensor, numpy array, cupy array, list, dictionary, int, float, bool, str, etc. Tensor, numpy array, cupy array, float, int, bool are converted to cupy arrays. For dictionary, list or tuple, every item is converted to a cupy array if applicable.
dtype – target data type when converting to Cupy array, it must be an argument of numpy.dtype, for more details: https://docs.cupy.dev/en/stable/reference/generated/cupy.array.html.
wrap_sequence – if False, then lists will recursively call this function. E.g., [1, 2] -> [array(1), array(2)]. If True, then [1, 2] -> array([1, 2]).
safe – if True, the dtype conversion clips values to the range of the target dtype instead of letting them overflow. Defaults to False. E.g., with False, [256, -12] -> [array(0), array(244)]; with True, [256, -12] -> [array(255), array(0)].
- monai.utils.type_conversion.convert_to_dst_type(src, dst, dtype=None, wrap_sequence=False, device=None, safe=False)[source]#
Convert source data to the same data type and device as the destination data. If dst is an instance of torch.Tensor or its subclass, convert src to torch.Tensor with the same data type as dst, if dst is an instance of numpy.ndarray or its subclass, convert to numpy.ndarray with the same data type as dst, otherwise, convert to the type of dst directly.
- Parameters:
src – source data to convert type.
dst – destination data; src is converted to the same data type as it.
dtype – an optional argument if the target dtype is different from the original dst’s data type.
wrap_sequence – if False, then lists will recursively call this function. E.g., [1, 2] -> [array(1), array(2)]. If True, then [1, 2] -> array([1, 2]).
device – target device to put the converted Tensor data. If unspecified, dst.device will be used if possible.
safe – if True, the dtype conversion clips values to the range of the target dtype instead of letting them overflow. Defaults to False. E.g., with False, [256, -12] -> [array(0), array(244)]; with True, [256, -12] -> [array(255), array(0)].
- monai.utils.type_conversion.convert_to_numpy(data, dtype=None, wrap_sequence=False, safe=False)[source]#
Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple, recursively check every item and convert it to numpy array.
- Parameters:
data (Any) – input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. Tensor, numpy array, float, int and bool are converted to numpy arrays; strings and objects are kept as they are. For a dictionary, list or tuple, every item is converted to a numpy array if applicable.
dtype (Union[dtype, type, str, None]) – target data type when converting to numpy array.
wrap_sequence (bool) – if False, then lists will recursively call this function. E.g., [1, 2] -> [array(1), array(2)]. If True, then [1, 2] -> array([1, 2]).
safe (bool) – if True, the dtype conversion clips values to the range of the target dtype instead of letting them overflow. Defaults to False. E.g., with False, [256, -12] -> [array(0), array(244)]; with True, [256, -12] -> [array(255), array(0)].
- Return type:
Any
- monai.utils.type_conversion.convert_to_tensor(data, dtype=None, device=None, wrap_sequence=False, track_meta=False, safe=False)[source]#
Utility to convert the input data to a PyTorch Tensor. If track_meta is True, the output will be a MetaTensor; otherwise, the output will be a regular torch Tensor. If passing a dictionary, list or tuple, recursively check every item and convert it to PyTorch Tensor.
- Parameters:
data – input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to Tensor, strings and objects keep the original. for dictionary, list or tuple, convert every item to a Tensor if applicable.
dtype – target data type when converting to Tensor.
device – target device to put the converted Tensor data.
wrap_sequence – if False, then lists will recursively call this function. E.g., [1, 2] -> [tensor(1), tensor(2)]. If True, then [1, 2] -> tensor([1, 2]).
track_meta – whether to track the meta information, if True, will convert to MetaTensor. default to False.
safe – if True, the dtype conversion clips values to the range of the target dtype instead of letting them overflow. Defaults to False. E.g., with False, [256, -12] -> [tensor(0), tensor(244)]; with True, [256, -12] -> [tensor(255), tensor(0)].
- monai.utils.type_conversion.dtype_numpy_to_torch(dtype)[source]#
Convert a numpy dtype to its torch equivalent.
- Return type:
dtype
- monai.utils.type_conversion.dtype_torch_to_numpy(dtype)[source]#
Convert a torch dtype to its numpy equivalent.
- Return type:
dtype
- monai.utils.type_conversion.get_dtype(data)[source]#
Get the dtype of an image, or if there is a sequence, recursively call the method on the 0th element.
This therefore assumes that in a Sequence, all types are the same.
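The recursion on the 0th element can be sketched as follows. This is a hedged illustration only: first_dtype and FakeImage are hypothetical names, FakeImage stands in for any array-like object exposing a dtype attribute, and the fallback to type(data) for plain Python values is an assumption.

```python
from collections.abc import Sequence


class FakeImage:
    """Hypothetical stand-in for an array or tensor that exposes `dtype`."""

    def __init__(self, dtype):
        self.dtype = dtype


def first_dtype(data):
    """Return the dtype of `data`, descending into the first element of sequences."""
    if hasattr(data, "dtype"):
        return data.dtype
    # strings are sequences of strings, so exclude them to avoid endless recursion
    if isinstance(data, Sequence) and not isinstance(data, (str, bytes)) and len(data) > 0:
        return first_dtype(data[0])
    return type(data)  # assumed fallback for plain values such as float


nested = [[FakeImage("float32"), FakeImage("int64")], [FakeImage("uint8")]]
print(first_dtype(nested))  # float32
```

Only the first element is inspected, which is why all items in a sequence are assumed to share one type.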
- monai.utils.type_conversion.get_equivalent_dtype(dtype, data_type)[source]#
Convert to the dtype that corresponds to data_type.
The input dtype can also be a string. e.g., “float32” becomes torch.float32 or np.float32 as necessary.
Example:
im = torch.tensor(1)
dtype = get_equivalent_dtype(np.float32, type(im))
Decorators#
Distributed Data Parallel#
- class monai.utils.dist.RankFilter(rank=None, filter_fn=<function RankFilter.<lambda>>)[source]#
The RankFilter class is a convenient filter that extends the Filter class in the Python logging module. The purpose is to control which log records are processed based on the rank in a distributed environment.
- Parameters:
rank – the rank of the process in the torch.distributed. Default is None and then it will use dist.get_rank().
filter_fn – an optional lambda function used as the filtering criteria. The default function logs only if the rank of the process is 0, but the user can define their own function to implement custom filtering logic.
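A rank-based logging filter of this kind can be sketched with the standard logging module. The class below is a hypothetical stand-in, not RankFilter itself: it takes the rank explicitly instead of querying torch.distributed, so it runs without a process group.

```python
import logging


class SimpleRankFilter(logging.Filter):
    """Pass log records only when `filter_fn(rank)` is true (illustrative sketch)."""

    def __init__(self, rank, filter_fn=lambda rank: rank == 0):
        super().__init__()
        self.rank = rank  # assumed to be supplied by the caller in this sketch
        self.filter_fn = filter_fn

    def filter(self, record):
        # the decision depends on the process rank, not on the record itself
        return self.filter_fn(self.rank)


logger = logging.getLogger("rank_demo")
logger.addFilter(SimpleRankFilter(rank=0))
logger.warning("this record passes the filter because rank is 0")
```

A custom criterion, for example logging on even ranks only, would be passed as filter_fn=lambda rank: rank % 2 == 0.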
- monai.utils.dist.evenly_divisible_all_gather(data: Tensor, concat: Literal[True]) → Tensor[source]#
- monai.utils.dist.evenly_divisible_all_gather(data: Tensor, concat: Literal[False]) → list[Tensor]
- monai.utils.dist.evenly_divisible_all_gather(data: Tensor, concat: bool) → torch.Tensor | list[torch.Tensor]
Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather. The input data of every rank should have the same number of dimensions, only the first dim can be different.
Note: If ignite is installed, this executes via the ignite distributed APIs; otherwise, if the native PyTorch distributed group is initialized, it executes via the native PyTorch distributed APIs.
- Parameters:
data – source tensor to pad and execute all_gather in distributed data parallel.
concat – whether to concatenate the gathered list into a single Tensor. If False, return a list of Tensors, similar to torch.distributed.all_gather(). Defaults to True.
Note
The input data on different ranks must have exactly the same dtype.
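The pad-then-gather idea can be simulated on a single process with plain lists, where each inner list plays the role of one rank's first dimension. This is a sketch under assumptions: the function name is hypothetical, the all_gather call is replaced by a simple list, and trimming the padding after the gather is an assumed step that the text does not spell out.

```python
def pad_gather_trim(per_rank_data, pad_value=0, concat=True):
    """Simulate gathering variable-length data from several ranks."""
    lengths = [len(rows) for rows in per_rank_data]
    max_len = max(lengths)
    # pad every rank to the same length so a fixed-size gather would be possible
    padded = [rows + [pad_value] * (max_len - len(rows)) for rows in per_rank_data]
    gathered = padded  # stand-in for the real all_gather across processes
    # drop the padding again using the original lengths (assumed step)
    trimmed = [rows[:n] for rows, n in zip(gathered, lengths)]
    if concat:
        return [item for rows in trimmed for item in rows]
    return trimmed


print(pad_gather_trim([[1, 2], [3, 4, 5], [6]]))  # [1, 2, 3, 4, 5, 6]
```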
- monai.utils.dist.get_dist_device()[source]#
Get the expected target device in the native PyTorch distributed data parallel. For NCCL backend, return GPU device of current process. For GLOO backend, return CPU. For any other backends, return None as the default, tensor.to(None) will not change the device.
- monai.utils.dist.string_list_all_gather(strings, delimiter='\\t')[source]#
Utility function for distributed data parallel to all gather a list of strings. Refer to the idea of ignite all_gather(string): https://pytorch.org/ignite/v0.4.5/distributed.html#ignite.distributed.utils.all_gather.
Note: If ignite is installed, this executes via the ignite distributed APIs; otherwise, if the native PyTorch distributed group is initialized, it executes via the native PyTorch distributed APIs.
- Parameters:
strings (list[str]) – a list of strings to all gather.
delimiter (str) – delimiter used to join the string list into one long string, which is then gathered across ranks and split back into a list. Defaults to "\t".
- Return type:
list[str]
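The join, gather and split steps described for the delimiter can be simulated on one process. The function below is a hypothetical sketch: the list of joined strings stands in for the real gather across ranks, and skipping empty strings is an assumption made so that ranks with no strings contribute nothing.

```python
def gather_strings(per_rank_strings, delimiter="\t"):
    """Simulate all-gathering lists of strings from several ranks."""
    # each rank joins its strings into one long string
    joined = [delimiter.join(strings) for strings in per_rank_strings]
    gathered = joined  # stand-in for the real gather across processes
    result = []
    for text in gathered:
        if text:  # assumed handling: an empty rank adds no entries
            result.extend(text.split(delimiter))
    return result


print(gather_strings([["a", "b"], ["c"], []]))  # ['a', 'b', 'c']
```

The delimiter must not occur inside any of the strings, otherwise the split would produce extra entries.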
Enums#
- class monai.utils.enums.AlgoKeys(value)[source]#
Default keys for templated Auto3DSeg Algo. ID is the identifier of the algorithm. The string has the format of <name>_<idx>_<other>. ALGO is the Auto3DSeg Algo instance. IS_TRAINED is the status that shows if the Algo has been trained. SCORE is the score the Algo has achieved after training.
- class monai.utils.enums.BlendMode(value)[source]#
See also:
monai.data.utils.compute_importance_map
- class monai.utils.enums.BundleProperty(value)[source]#
Bundle property fields: DESC is the description of the property. REQUIRED is a flag indicating whether the property is required or optional.
- class monai.utils.enums.BundlePropertyConfig(value)[source]#
Additional bundle property fields for config based bundle workflow: ID is the config item ID of the property. REF_ID is the ID of the config item which is supposed to refer to this property. For properties that do not have REF_ID, None should be set. This field is only useful to check the optional property ID.
- class monai.utils.enums.ChannelMatching(value)[source]#
See also:
monai.networks.nets.HighResBlock
- class monai.utils.enums.CommonKeys(value)[source]#
A set of common keys for dictionary based supervised training process. IMAGE is the input image data. LABEL is the training or evaluation label of segmentation or classification task. PRED is the prediction data of model output. LOSS is the loss value of current iteration. INFO is some useful information during training or evaluation, like loss value, etc.
- class monai.utils.enums.CompInitMode(value)[source]#
Mode names for instantiating a class or calling a callable.
See also:
monai.utils.module.instantiate()
- class monai.utils.enums.DataStatsKeys(value)[source]#
Default keys for dataset statistical analysis modules
- class monai.utils.enums.EngineStatsKeys(value)[source]#
Default keys for the statistics of trainer and evaluator engines.
- class monai.utils.enums.FastMRIKeys(value)[source]#
The keys to be used for extracting data from the fastMRI dataset
- class monai.utils.enums.ForwardMode(value)[source]#
See also:
monai.engines.evaluator.Evaluator
- class monai.utils.enums.GanKeys(value)[source]#
A set of common keys for generative adversarial networks.
- class monai.utils.enums.GridPatchSort(value)[source]#
The sorting method for the generated patches in GridPatch
- class monai.utils.enums.GridSampleMode(value)[source]#
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
interpolation mode of torch.nn.functional.grid_sample
Note
(documentation from torch.nn.functional.grid_sample) mode=’bicubic’ supports only 4-D input. When mode=’bilinear’ and the input is 5-D, the interpolation mode used internally will actually be trilinear. However, when the input is 4-D, the interpolation mode will legitimately be bilinear.
- class monai.utils.enums.GridSamplePadMode(value)[source]#
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- class monai.utils.enums.HoVerNetBranch(value)[source]#
Three branches of HoVerNet model, which results in three outputs: HV is horizontal and vertical gradient map of each nucleus (regression), NP is the pixel prediction of all nuclei (segmentation), and NC is the type of each nucleus (classification).
- class monai.utils.enums.HoVerNetMode(value)[source]#
Modes for HoVerNet model: FAST is a faster implementation (than original), ORIGINAL is the original implementation.
- class monai.utils.enums.ImageStatsKeys(value)[source]#
Default keys for dataset statistical analysis image modules
- class monai.utils.enums.InterpolateMode(value)[source]#
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
- class monai.utils.enums.LabelStatsKeys(value)[source]#
Default keys for dataset statistical analysis label modules
- class monai.utils.enums.LazyAttr(value)[source]#
MetaTensor with pending operations requires some key attributes tracked especially when the primary array is not up-to-date due to lazy evaluation. This class specifies the set of key attributes to be tracked for each MetaTensor. See also monai.transforms.lazy.utils.resample() for more details.
- class monai.utils.enums.NdimageMode(value)[source]#
The available options determine how the input array is extended beyond its boundaries when interpolating. See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
- class monai.utils.enums.NumpyPadMode(value)[source]#
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
- class monai.utils.enums.PytorchPadMode(value)[source]#
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
- class monai.utils.enums.SkipMode(value)[source]#
See also:
monai.networks.layers.SkipConnection
- class monai.utils.enums.SpaceKeys(value)[source]#
The coordinate system keys, for example, Nifti1 uses Right-Anterior-Superior or “RAS”, DICOM (0020,0032) uses Left-Posterior-Superior or “LPS”. This type does not distinguish spatial 1/2/3D.
- class monai.utils.enums.SplineMode(value)[source]#
Order of spline interpolation.
See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
- class monai.utils.enums.StrEnum(value)[source]#
Enum subclass that converts its value to a string.
import monai
from monai.utils import StrEnum

class Example(StrEnum):
    MODE_A = "A"
    MODE_B = "B"

assert list(Example) == ["A", "B"]
assert Example.MODE_A == "A"
assert str(Example.MODE_A) == "A"
assert monai.utils.look_up_option("A", Example) == "A"
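An Enum whose members behave like their string values can be built with the standard library alone. The sketch below shows one common way to do it and is not necessarily how monai.utils.StrEnum is written. MiniStrEnum and Mode are hypothetical names.

```python
from enum import Enum


class MiniStrEnum(str, Enum):
    """Enum whose members compare equal to, and print as, their string values."""

    def __str__(self):
        return self.value

    def __repr__(self):
        return self.value


class Mode(MiniStrEnum):
    MODE_A = "A"
    MODE_B = "B"


print(list(Mode) == ["A", "B"])  # True: members compare equal to plain strings
print(str(Mode.MODE_A))  # A
```

Mixing in str is what makes the equality checks against plain strings succeed, while __str__ controls how members are printed.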
- class monai.utils.enums.TraceKeys(value)[source]#
Extra metadata keys used for traceable transforms.
- class monai.utils.enums.TraceStatusKeys(value)[source]#
Enumerable status keys for the TraceKeys.STATUS flag
- class monai.utils.enums.TransformBackends(value)[source]#
Transform backends. Most monai.transforms components first convert the input data into torch.Tensor or monai.data.MetaTensor. Internally, some transforms are made by converting the data into numpy.array or cupy.array, using the underlying transform backend API to achieve the actual output array, and converting back to Tensor/MetaTensor. Transforms with more than one backend indicate that they may convert the input data types to accommodate the underlying API.
- class monai.utils.enums.UpsampleMode(value)[source]#
See also:
monai.networks.blocks.UpSample
- class monai.utils.enums.Weight(value)[source]#
See also:
monai.losses.dice.GeneralizedDiceLoss
Jupyter Utilities#
This set of utility functions is meant to make using Jupyter notebooks easier with MONAI. Plotting functions using Matplotlib produce common plots for metrics and images.
- class monai.utils.jupyter_utils.StatusMembers(value)[source]#
Named members of the status dictionary, others may be present for named metric values.
- class monai.utils.jupyter_utils.ThreadContainer(engine, loss_transform=<function _get_loss_from_output>, metric_transform=<function ThreadContainer.<lambda>>, status_format='{}: {:.4}')[source]#
Contains a running Engine object within a separate thread from main thread in a Jupyter notebook. This allows an engine to begin a run in the background and allow the starting notebook cell to complete. A user can thus start a run and then navigate away from the notebook without concern for losing connection with the running cell. All output is acquired through methods which synchronize with the running engine using an internal lock member; acquiring this lock allows the engine to be inspected while it’s prevented from starting the next iteration.
- Parameters:
engine (Engine) – wrapped Engine object, when the container is started its run method is called
loss_transform (Callable) – callable to convert an output dict into a single numeric value
metric_transform (Callable) – callable to convert a named metric value into a single numeric value
status_format (str) – format string for status key-value pairs.
- plot_status(logger, plot_func=<function plot_engine_status>)[source]#
Generate a plot of the current status of the contained engine whose loss and metrics were tracked by logger. The function plot_func must accept arguments title, engine, logger, and fig which are the plot title, self.engine, logger, and self.fig respectively. The return value must be a figure object (stored in self.fig) and a list of Axes objects for the plots in the figure. Only the figure is returned by this method, which holds the internal lock during the plot generation.
- property status_dict: dict[str, str]#
A dictionary containing status information, current loss, and current metric values.
- Return type:
dict[str, str]
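The pattern of running work in a background thread and reading its state under a lock can be sketched with the threading module. The class below is a hypothetical stand-in that replaces the Engine with a simple loop. It is meant only to show how holding the lock during each iteration lets a reader see a consistent status.

```python
import threading


class BackgroundRunner(threading.Thread):
    """Run a toy training loop in a thread and expose its status under a lock."""

    def __init__(self, num_iters):
        super().__init__(daemon=True)
        self.lock = threading.RLock()
        self.num_iters = num_iters
        self.iteration = 0
        self.loss = float("inf")

    def run(self):
        for i in range(1, self.num_iters + 1):
            with self.lock:  # each iteration runs while holding the lock
                self.iteration = i
                self.loss = 1.0 / i  # placeholder for a real loss value

    @property
    def status_dict(self):
        # while the lock is held here, the loop cannot start its next iteration
        with self.lock:
            return {"Iters": str(self.iteration), "Loss": "{:.4}".format(self.loss)}


runner = BackgroundRunner(num_iters=5)
runner.start()  # returns immediately, the loop runs in the background
runner.join()
print(runner.status_dict)
```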
- monai.utils.jupyter_utils.plot_engine_status(engine, logger, title='Training Log', yscale='log', avg_keys=('loss', ), window_fraction=20, image_fn=<function tensor_to_images>, fig=None, selected_inst=0)[source]#
Plot the status of the given Engine with its logger. The plot will consist of a graph of loss values and metrics taken from the logger, and images taken from the output and batch members of engine.state. The images are converted to Numpy arrays suitable for input to Axes.imshow using image_fn, if this is None then no image plotting is done.
- Parameters:
engine – Engine to extract images from
logger – MetricLogger to extract loss and metric data from
title – graph title
yscale – for metric plot, scale for y-axis compatible with Axes.set_yscale
avg_keys – for metric plot, tuple of keys in graphmap to provide running average plots for
window_fraction – for metric plot, what fraction of the graph value length to use as the running average window
image_fn – callable converting tensors keyed to a name in the Engine to a tuple of images to plot
fig – Figure object to plot into, reuse from previous plotting for flicker-free refreshing
selected_inst – index of the instance to show in the image plot
- Returns:
Figure object (or fig if given), list of Axes objects for graph and images
- monai.utils.jupyter_utils.plot_metric_graph(ax, title, graphmap, yscale='log', avg_keys=('loss',), window_fraction=20)[source]#
Plot metrics on a single graph with running averages plotted for selected keys. The values in graphmap should be lists of (timepoint, value) pairs as stored in MetricLogger objects.
- Parameters:
ax – Axes object to plot into
title – graph title
graphmap – dictionary of named graph values, which are lists of values or (index, value) pairs
yscale – scale for y-axis compatible with Axes.set_yscale
avg_keys – tuple of keys in graphmap to provide running average plots for
window_fraction – what fraction of the graph value length to use as the running average window
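One way to read window_fraction is that the running-average window is the number of values divided by it. The sketch below follows that reading. The function name, the integer division and the padding with the first value are all assumptions for illustration, not a statement of how plot_metric_graph computes its curves.

```python
def running_average(values, window_fraction=20):
    """Running average whose window is len(values) // window_fraction (assumed rule)."""
    window = len(values) // window_fraction
    if window < 1:
        return list(values)  # too few points for a window, so return them unchanged
    # repeat the first value so the output has the same length as the input
    padded = [values[0]] * (window - 1) + list(values)
    return [sum(padded[i:i + window]) / window for i in range(len(values))]


print(running_average([0, 2, 4, 6], window_fraction=2))  # [0.0, 1.0, 3.0, 5.0]
```

With the default of 20, a curve of 200 points would be smoothed over a window of 10 points under this reading.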
- monai.utils.jupyter_utils.plot_metric_images(fig, title, graphmap, imagemap, yscale='log', avg_keys=('loss',), window_fraction=20)[source]#
Plot metric graph data with images below into figure fig. The intended use is for the graph data to be metrics from a training run and the images to be the batch and output from the last iteration. This uses plot_metric_graph to plot the metric graph.
- Parameters:
fig – Figure object to plot into, reuse from previous plotting for flicker-free refreshing
title – graph title
graphmap – dictionary of named graph values, which are lists of values or (index, value) pairs
imagemap – dictionary of named images to show with metric plot
yscale – for metric plot, scale for y-axis compatible with Axes.set_yscale
avg_keys – for metric plot, tuple of keys in graphmap to provide running average plots for
window_fraction – for metric plot, what fraction of the graph value length to use as the running average window
- Returns:
list of Axes objects for graph followed by images
- monai.utils.jupyter_utils.tensor_to_images(name, tensor)[source]#
Return a tuple of images derived from the given tensor. The name value indicates which key from the output or batch value the tensor was stored as, or is “Batch” or “Output” if these were single tensors instead of dictionaries. Returns a tuple of 2D images of shape HW, or 3D images of shape CHW where C is color channels RGB or RGBA. This allows multiple images to be created from a single tensor, i.e. to show each channel separately.
State Cacher#
- class monai.utils.state_cacher.StateCacher(in_memory, cache_dir=None, allow_overwrite=True, pickle_module=pickle, pickle_protocol=2)[source]#
Class to cache and retrieve the state of an object.
Objects can either be stored in memory or on disk. If stored on disk, they can be stored in a given directory, or alternatively a temporary location will be used.
If necessary/possible, restored objects will be returned to their original device.
Example:
>>> state_cacher = StateCacher(memory_cache, cache_dir=cache_dir)
>>> state_cacher.store("model", model.state_dict())
>>> model.load_state_dict(state_cacher.retrieve("model"))
- __init__(in_memory, cache_dir=None, allow_overwrite=True, pickle_module=pickle, pickle_protocol=2)[source]#
Constructor.
- Parameters:
in_memory – boolean to determine if the object will be cached in memory or on disk.
cache_dir – directory for data to be cached if in_memory==False. Defaults to using a temporary directory. Any created files will be deleted during the StateCacher’s destructor.
allow_overwrite – allow the cache to be overwritten. If set to False, an error will be thrown if a matching key already exists in the list of cached objects.
pickle_module – module used for pickling metadata and objects, default to pickle. this arg is used by torch.save, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
pickle_protocol – can be specified to override the default protocol, default to 2. this arg is used by torch.save, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
- store(key, data_obj, pickle_module=None, pickle_protocol=None)[source]#
Store a given object with the given key name.
- Parameters:
key – key of the data object to store.
data_obj – data object to store.
pickle_module – module used for pickling metadata and objects, default to self.pickle_module. this arg is used by torch.save, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
pickle_protocol – can be specified to override the default protocol, default to self.pickle_protocol. this arg is used by torch.save, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
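The store and retrieve behaviour, in memory or on disk, can be sketched with the standard library. TinyStateCacher is a hypothetical simplification: it uses pickle directly where the documentation mentions torch.save, it does not restore devices, and it does not delete its files on destruction.

```python
import copy
import os
import pickle
import tempfile


class TinyStateCacher:
    """Cache objects by key, either as in-memory copies or as pickle files."""

    def __init__(self, in_memory, cache_dir=None, allow_overwrite=True):
        self.in_memory = in_memory
        self.allow_overwrite = allow_overwrite
        # fall back to a temporary directory only when caching on disk
        self.cache_dir = cache_dir if (in_memory or cache_dir) else tempfile.mkdtemp()
        self.cached = {}

    def store(self, key, data_obj):
        if key in self.cached and not self.allow_overwrite:
            raise RuntimeError(f"Cached key {key!r} already exists and overwriting is disabled.")
        if self.in_memory:
            # deep copy so later changes to the original do not leak into the cache
            self.cached[key] = copy.deepcopy(data_obj)
        else:
            path = os.path.join(self.cache_dir, f"state_{key}.pkl")
            with open(path, "wb") as f:
                pickle.dump(data_obj, f)
            self.cached[key] = path

    def retrieve(self, key):
        if key not in self.cached:
            raise KeyError(f"Target {key!r} was not cached.")
        if self.in_memory:
            return self.cached[key]
        with open(self.cached[key], "rb") as f:
            return pickle.load(f)


cacher = TinyStateCacher(in_memory=True)
cacher.store("model", {"weight": [1.0, 2.0]})
print(cacher.retrieve("model"))  # {'weight': [1.0, 2.0]}
```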
Component store#
- class monai.utils.component_store.ComponentStore(name, description)[source]#
Represents a storage object for other objects (specifically functions) keyed to a name with a description.
These objects act as global named places for storing components for objects parameterised by component names. Typically this is functions although other objects can be added. Printing a component store will produce a list of members along with their docstring information if present.
Example:
from monai.utils.component_store import ComponentStore

TestStore = ComponentStore("Test Store", "A test store for demo purposes")

@TestStore.add_def("my_func_name", "Some description of your function")
def _my_func(a, b):
    '''A description of your function here.'''
    return a * b

print(TestStore)  # will print out name, description, and 'my_func_name' with the docstring

func = TestStore["my_func_name"]
result = func(7, 6)
- add(name, desc, value)[source]#
Store the object value under the name name with description desc.
- Return type:
~T
- add_def(name, desc)[source]#
Returns a decorator which stores the decorated function under name with description desc.
- Return type:
Callable
- property names: tuple[str, ...]#
Produces all factory names.
- Return type:
tuple[str, …]