# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from importlib import import_module
from pkgutil import walk_packages
from re import match
from typing import Any, Callable, Tuple
OPTIONAL_IMPORT_MSG_FMT = "{}"
[docs]def export(modname):
"""
Make the decorated object a member of the named module. This will also add the object under its aliases if it has
a `__aliases__` member, thus this decorator should be before the `alias` decorator to pick up those names. Alias
names which conflict with package names or existing members will be ignored.
"""
def _inner(obj):
mod = import_module(modname)
if not hasattr(mod, obj.__name__):
setattr(mod, obj.__name__, obj)
# add the aliases for `obj` to the target module
for alias in getattr(obj, "__aliases__", ()):
if not hasattr(mod, alias):
setattr(mod, alias, obj)
return obj
return _inner
[docs]def load_submodules(basemod, load_all: bool = True, exclude_pattern: str = "(.*[tT]est.*)|(_.*)"):
"""
Traverse the source of the module structure starting with module `basemod`, loading all packages plus all files if
`loadAll` is True, excluding anything whose name matches `excludePattern`.
"""
submodules = []
for importer, name, is_pkg in walk_packages(basemod.__path__):
if (is_pkg or load_all) and match(exclude_pattern, name) is None:
mod = import_module(basemod.__name__ + "." + name) # why do I need to do this first?
importer.find_module(name).load_module(name)
submodules.append(mod)
return submodules
@export("monai.utils")
def get_full_type_name(typeobj):
module = typeobj.__module__
if module is None or module == str.__class__.__module__:
return typeobj.__name__ # Avoid reporting __builtin__
else:
return module + "." + typeobj.__name__
[docs]def min_version(the_module, min_version_str: str = "") -> bool:
"""
Convert version strings into tuples of int and compare them.
Returns True if the module's version is greater or equal to the 'min_version'.
When min_version_str is not provided, it always returns True.
"""
if min_version_str:
mod_version = tuple(int(x) for x in the_module.__version__.split(".")[:2])
required = tuple(int(x) for x in min_version_str.split(".")[:2])
return mod_version >= required
return True # always valid version
[docs]def exact_version(the_module, version_str: str = "") -> bool:
"""
Returns True if the module's __version__ matches version_str
"""
return bool(the_module.__version__ == version_str)
[docs]def optional_import(
module: str,
version: str = "",
version_checker: Callable = min_version,
name: str = "",
descriptor: str = OPTIONAL_IMPORT_MSG_FMT,
version_args=None,
allow_namespace_pkg: bool = False,
) -> Tuple[Any, bool]:
"""
Imports an optional module specified by `module` string.
Any importing related exceptions will be stored, and exceptions raise lazily
when attempting to use the failed-to-import module.
Args:
module: name of the module to be imported.
version: version string used by the version_checker.
version_checker: a callable to check the module version, Defaults to monai.utils.min_version.
name: a non-module attribute (such as method/class) to import from the imported module.
descriptor: a format string for the final error message when using a not imported module.
version_args: additional parameters to the version checker.
allow_namespace_pkg: whether importing a namespace package is allowed. Defaults to False.
Returns:
The imported module and a boolean flag indicating whether the import is successful.
Raises:
_exception: Optional import: {msg}.
Examples::
>>> torch, flag = optional_import('torch', '1.1')
>>> print(torch, flag)
<module 'torch' from 'python/lib/python3.6/site-packages/torch/__init__.py'> True
>>> the_module, flag = optional_import('unknown_module')
>>> print(flag)
False
>>> the_module.method # trying to access a module which is not imported
AttributeError: Optional import: import unknown_module (No module named 'unknown_module').
>>> torch, flag = optional_import('torch', '42', exact_version)
>>> torch.nn # trying to access a module for which there isn't a proper version imported
AttributeError: Optional import: import torch (requires version '42' by 'exact_version').
>>> conv, flag = optional_import('torch.nn.functional', '1.0', name='conv1d')
>>> print(conv)
<built-in method conv1d of type object at 0x11a49eac0>
>>> conv, flag = optional_import('torch.nn.functional', '42', name='conv1d')
>>> conv() # trying to use a function from the not successfully imported module (due to unmatched version)
AttributeError: Optional import: from torch.nn.functional import conv1d (requires version '42' by 'min_version').
"""
tb = None
exception_str = ""
if name:
actual_cmd = f"from {module} import {name}"
else:
actual_cmd = f"import {module}"
try:
pkg = __import__(module) # top level module
the_module = importlib.import_module(module)
if not allow_namespace_pkg:
is_namespace = getattr(the_module, "__file__", None) is None and hasattr(the_module, "__path__")
assert not is_namespace
if name: # user specified to load class/function/... from the module
the_module = getattr(the_module, name)
except Exception as import_exception: # any exceptions during import
tb = import_exception.__traceback__
exception_str = f"{import_exception}"
else: # found the module
if version_args and version_checker(pkg, f"{version}", version_args):
return the_module, True
if not version_args and version_checker(pkg, f"{version}"):
return the_module, True
# preparing lazy error message
msg = descriptor.format(actual_cmd)
if version and tb is None: # a pure version issue
msg += f" (requires '{module} {version}' by '{version_checker.__name__}')"
if exception_str:
msg += f" ({exception_str})"
class _LazyRaise:
def __init__(self, *_args, **_kwargs):
_default_msg = (
f"Optional import: {msg}."
+ "\n\nFor details about installing the optional dependencies, please visit:"
+ "\n https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies"
)
if tb is None:
self._exception = AttributeError(_default_msg)
else:
self._exception = AttributeError(_default_msg).with_traceback(tb)
def __getattr__(self, name):
raise self._exception
def __call__(self, *_args, **_kwargs):
raise self._exception
return _LazyRaise(), False