Source code for monailabel.utils.others.class_utils

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import importlib.util
import inspect
import logging
import os
import sys
from distutils.util import strtobool
from typing import List

from monailabel.interfaces.exception import MONAILabelError, MONAILabelException

logger = logging.getLogger(__name__)

def _strtobool(val):
    """Convert a truth-value string to 1 or 0.

    Same accepted values and error as ``distutils.util.strtobool`` (distutils is
    deprecated by PEP 632 and removed in Python 3.12), matched case-insensitively.

    Raises:
        ValueError: ``val`` is not a recognised truth value.
    """
    val = val.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return 1
    if val in ("n", "no", "f", "false", "off", "0"):
        return 0
    raise ValueError(f"invalid truth value {val!r}")


def unload_module(name):
    """Remove module ``name`` and all of its submodules (``name.*``) from ``sys.modules``.

    Nothing is removed when the ``MONAI_LABEL_RELOAD_APP_LIB`` environment variable holds a
    false value (it defaults to ``"true"``). Once removed, a later ``import`` loads the
    module afresh instead of returning the cached object.

    Raises:
        ValueError: ``MONAI_LABEL_RELOAD_APP_LIB`` is set to an unrecognised truth value
            (only checked when at least one matching module is loaded).
    """
    modules = sorted(m for m in sys.modules if m == name or m.startswith(f"{name}."))
    if modules and _strtobool(os.environ.get("MONAI_LABEL_RELOAD_APP_LIB", "true")):
        logger.info(f"Remove/Reload previous Modules: {modules}")
        for m in modules:
            del sys.modules[m]
def module_from_file(module_name, file_path):
    """Execute the Python file ``file_path`` and return it as a new module named ``module_name``.

    The file's directory is put at the front of ``sys.path`` while the file executes, so
    imports of sibling modules inside it resolve. Any already-loaded ``lib`` package is
    unloaded first via ``unload_module("lib")`` (presumably so an app-local ``lib`` package
    is re-imported rather than a stale one — confirm against callers).

    Raises:
        ImportError: no module spec / loader can be created for ``file_path``.
        Exception: anything raised while executing the module body propagates; the
            ``sys.path`` change is undone in every case.
    """
    app_dir = os.path.dirname(file_path)
    sys.path.insert(0, app_dir)
    try:
        unload_module("lib")
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load module '{module_name}' from: {file_path}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    finally:
        # Always undo the temporary sys.path entry, even when loading/executing failed.
        sys.path.remove(app_dir)
    logger.debug(f"module: {module}")
    return module
def is_subclass(n, o, class_c):
    """Return True if ``o`` is a class with a *direct* base class named ``class_c``.

    Only the names of ``o.__bases__`` are compared, so indirect descendants are not
    detected. A member whose name ``n`` equals ``class_c`` (i.e. the base itself) is
    never reported.

    Args:
        n: the name under which ``o`` was found.
        o: any object; non-classes always yield False.
        class_c: base-class *name* (a string), not a class object.
    """
    if inspect.isclass(o) and n != class_c:
        b = [cls.__name__ for cls in o.__bases__]
        logger.debug(f"Base classes => {b}")
        if class_c in b:
            logger.info(f"Subclass for {class_c} Found: {o}")
            return True
    return False
def get_class_of_subclass(module, class_c):
    """Find the first class in ``module`` that directly derives from a base named ``class_c``.

    Members are visited in the order ``inspect.getmembers`` returns them (sorted by name),
    and each class member is tested with ``is_subclass``.

    Returns:
        The matching class object, or ``None`` if no member matches.
    """
    logger.debug(f"{module} => {class_c}")
    class_members = ((name, obj) for name, obj in inspect.getmembers(module) if inspect.isclass(obj))
    for name, obj in class_members:
        logger.debug(f"{name} => {obj}")
        if is_subclass(name, obj, class_c):
            return obj
    return None
def get_class_of_subclass_from_file(module_name, file_path, class_c):
    """Load ``file_path`` as module ``module_name`` and return its first subclass of ``class_c``.

    ``class_c`` is a base-class name; returns ``None`` when the loaded module has no
    direct subclass of it.
    """
    loaded = module_from_file(module_name, file_path)
    return get_class_of_subclass(loaded, class_c)
def to_expression(class_path, class_args):
    """Render ``class_path`` and keyword ``class_args`` as a Python call expression.

    Example: ``to_expression("a.B", {"x": 1, "s": "hi"})`` -> ``"a.B(x=1, s='hi')"``.

    String values (top-level, or items of a tuple/list) become escaped string literals via
    ``repr``, so quotes, backslashes and newlines survive an ``eval`` of the result. Tuples
    and lists are rendered item by item. A one-element tuple gets a trailing comma so it
    still evaluates to a tuple. Every other value is rendered with ``str()``.
    """

    def _literal(value):
        # format() reproduces f-string conversion of the value; repr() then escapes it as a literal.
        return repr(format(value)) if isinstance(value, str) else str(value)

    key_val = []
    for key, val in class_args.items():
        if isinstance(val, (tuple, list)):
            items = ", ".join(_literal(v) for v in val)
            if isinstance(val, tuple):
                # "(x)" is just x; "(x,)" is a one-element tuple.
                rendered = f"({items},)" if len(val) == 1 else f"({items})"
            else:
                rendered = f"[{items}]"
        else:
            rendered = _literal(val)
        key_val.append(f"{key}={rendered}")
    return f"{class_path}({', '.join(key_val)})"
def class_args_to_exp(c, mappings=None):
    """Turn a ``{"name": ..., "args": {...}}`` description into a call-expression string.

    When ``mappings`` is a non-empty dict, the class name is translated through it. Names
    not present in ``mappings`` are kept unchanged. A missing ``"args"`` key means the call
    has no arguments.
    """
    name = c["name"]
    if mappings:
        name = mappings.get(name, name)
    return to_expression(name, c.get("args", {}))
def get_class_info(exp, handle_bool=True):
    """Split a class expression into ``(class_path, class_args)``.

    Args:
        exp: either a dict with a ``"name"`` key and an optional ``"args"`` dict, or a
            string such as ``"pkg.module.Class(a=1, b='x')"`` (surrounding whitespace ignored).
        handle_bool: when True, bare lowercase ``true`` / ``false`` inside the argument list
            are accepted as ``True`` / ``False``.

    Returns:
        ``(class_path, class_args)``. ``class_args`` is a dict of keyword arguments. It is
        ``{}`` for a dict without ``"args"`` or a string without parentheses, and ``None``
        for an empty call such as ``"Class()"``.

    Security:
        The argument list is evaluated with ``eval`` in this module's global scope, so
        ``exp`` must only come from trusted configuration.
    """
    if isinstance(exp, dict):
        return exp["name"], exp.get("args", {})

    exp = exp.strip()
    open_idx = exp.find("(")
    if open_idx == -1:
        return exp, {}

    def foo(**kwargs):
        return kwargs

    class_path = exp[:open_idx]
    class_args = exp[open_idx + 1 : -1]  # assumes exp ends with the closing ")"
    logger.debug(f"Eval Input:: {class_path} => {class_args}")

    if class_args:
        scope = dict(globals())
        scope["foo"] = foo
        if handle_bool:
            # Bind true/false as names instead of rewriting the text, so string values such as
            # 'is true' and keyword names such as true_rate are left untouched.
            scope.update(true=True, false=False)
        # NOTE(security): eval executes arbitrary code contained in exp.
        class_args = eval("foo(" + class_args + ")", scope)
    else:
        class_args = None
    logger.debug(f"{class_path} => {class_args}")
    return class_path, class_args
def init_class(class_path, class_args):
    """Instantiate the class named by a dotted ``class_path``.

    The module part is imported and then reloaded before the class attribute is fetched,
    so the class comes from a fresh execution of the module source.

    Args:
        class_path: ``"package.module.ClassName"``.
        class_args: keyword arguments for the constructor; any falsy value means "no arguments".

    Raises:
        MONAILabelException: (``CLASS_INIT_ERROR``) ``class_path`` contains no ``.``.
    """
    if "." not in class_path:
        raise MONAILabelException(
            MONAILabelError.CLASS_INIT_ERROR, "Class path need to be in the form [module/file].[class_name]."
        )
    module_name, _, class_name = class_path.rpartition(".")
    module = importlib.import_module(module_name)
    importlib.reload(module)
    cls = getattr(module, class_name)
    if not class_args:
        return cls()
    return cls(**class_args)
def init_class_from_exp(exp):
    """Instantiate the class described by ``exp``.

    ``exp`` is either a ``{"name": ..., "args": ...}`` dict or a ``"path.Class(k=v, ...)"`` string.
    """
    return init_class(*get_class_info(exp))
def _class_names_in_module(module, subclass):
    """Return ``"module.Class"`` paths of concrete classes defined in ``module`` itself (optionally filtered)."""
    names = []
    for attr in dir(module):
        c = getattr(module, attr)
        # Test for a class first: evaluating `not c` on arbitrary module attributes can raise
        # (e.g. objects with an ambiguous truth value).
        if not inspect.isclass(c) or not c or inspect.isabstract(c):
            continue
        if c.__module__ == module.__name__ and (not subclass or is_subclass(c.__name__, c, subclass)):
            names.append(c.__module__ + "." + c.__name__)
    return names


def get_class_names(p, subclass=None) -> List[str]:
    """List dotted paths of the concrete classes defined by module ``p``.

    If ``p`` was loaded from a file whose name starts with ``__`` (e.g. a package's
    ``__init__.py``), every ``*.py*`` file in its directory is imported as a submodule of
    ``p.__package__``. Files whose names start with ``__`` are skipped, and each module name
    is imported only once. Modules are visited in sorted order, and the classes defined in
    each are collected. Otherwise only the classes defined in ``p`` itself are collected.

    Abstract classes and classes merely imported into a module are excluded. If
    ``subclass`` (a base-class name) is given, only classes with a direct base of that
    name are kept.
    """
    logger.debug(f"Module File Path: {p.__file__}")
    if not os.path.basename(p.__file__).startswith("__"):
        return _class_names_in_module(p, subclass)

    current_dir = os.path.dirname(p.__file__)
    result = []
    seen = set()
    for file in sorted(glob.glob(os.path.join(glob.escape(current_dir), "*.py*"))):
        name = os.path.splitext(os.path.basename(file))[0]
        # "x.py" and e.g. "x.pyi" map to the same module; import/collect it only once.
        if name.startswith("__") or name in seen:
            continue
        seen.add(name)
        module = importlib.import_module("." + name, package=p.__package__)
        result.extend(_class_names_in_module(module, subclass))
    return result