# Source code for monai.utils.misc

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import random
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, Tuple

import numpy as np
import torch

# Seed most recently passed to `set_determinism` (None when never set or when
# determinism was disabled); read back through `get_seed`.
_seed = None


def zip_with(op, *vals, mapfunc=map):
    """
    Apply `op` to every tuple obtained by zipping the iterables in `vals`.

    The mapping itself is delegated to `mapfunc` (builtin `map` by default), so
    `op` receives each zipped tuple as a single argument unless a different
    mapping function (e.g. `itertools.starmap`) is supplied.
    """
    zipped = zip(*vals)
    return mapfunc(op, zipped)
def star_zip_with(op, *vals):
    """
    Zip the iterables in `vals` and call `op` with each zipped tuple unpacked
    into positional arguments (the `itertools.starmap` variant of `zip_with`).
    """
    return itertools.starmap(op, zip(*vals))
def first(iterable, default=None):
    """
    Return the first element produced by `iterable`, or `default` when it yields
    nothing. Handy with generator expressions, e.g. ``first(x for x in xs if x > 0)``.
    """
    return next(iter(iterable), default)
def issequenceiterable(obj) -> bool:
    """
    Determine if the object can be iterated element-wise as a sequence.

    Returns True for any `Iterable` except:
      - `str` and `bytes`, which are treated as atomic values rather than
        sequences of characters/ints;
      - 0-d `torch.Tensor` and 0-d `np.ndarray`, which define ``__iter__`` but
        raise ``TypeError`` when iterated or passed to ``len()``.
    """
    if isinstance(obj, (torch.Tensor, np.ndarray)):
        # a 0-d tensor/array is a scalar container and cannot be iterated
        return obj.ndim > 0
    return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes))
def ensure_tuple(vals: Any) -> Tuple:
    """
    Convert `vals` into a tuple.

    Iterables (other than those `issequenceiterable` rejects, such as strings)
    are expanded element-wise; anything else is wrapped in a 1-tuple.
    """
    if issequenceiterable(vals):
        return tuple(vals)
    return (vals,)
def ensure_tuple_size(tup, dim: int, pad_val=0) -> Tuple:
    """
    Return `tup` (coerced with `ensure_tuple`) as a tuple of exactly `dim` items:
    surplus items are truncated and missing ones are filled with `pad_val`.
    """
    items = ensure_tuple(tup)
    n_missing = max(0, dim - len(items))
    return items[:dim] + (pad_val,) * n_missing
def ensure_tuple_rep(tup: Any, dim: int) -> Tuple:
    """
    Returns a copy of `tup` with `dim` values by either shortened or duplicated input.

    Non-sequence inputs are repeated `dim` times; sequence-like inputs (including
    unsized iterables such as generators) must contain exactly `dim` items.

    Examples::

        >>> ensure_tuple_rep(1, 3)
        (1, 1, 1)
        >>> ensure_tuple_rep(None, 3)
        (None, None, None)
        >>> ensure_tuple_rep('test', 3)
        ('test', 'test', 'test')
        >>> ensure_tuple_rep([1, 2, 3], 3)
        (1, 2, 3)
        >>> ensure_tuple_rep(range(3), 3)
        (0, 1, 2)
        >>> ensure_tuple_rep((x for x in range(3)), 3)
        (0, 1, 2)
        >>> ensure_tuple_rep([1, 2], 3)
        ValueError: sequence must have length 3, got length 2.

    Raises:
        ValueError: sequence must have length {dim}, got length {len(tup)}.
    """
    if not issequenceiterable(tup):
        return (tup,) * dim
    # Materialise once: `issequenceiterable` accepts any Iterable, and unsized
    # ones (e.g. generators) would otherwise fail on `len()` with a TypeError.
    items = tuple(tup)
    if len(items) == dim:
        return items
    raise ValueError(f"sequence must have length {dim}, got length {len(items)}.")
def fall_back_tuple(user_provided: Any, default: Sequence, func: Callable = lambda x: x and x > 0) -> Tuple:
    """
    Validate `user_provided` element-wise against `default` and return the result as a tuple.

    `user_provided` is first broadcast/checked to ``len(default)`` items with
    `ensure_tuple_rep`. For each position, the user value is kept when
    ``func(value)`` is truthy; otherwise the value at the same position in
    `default` is used instead. A typical use is a user-specified window size
    where non-positive (or None) components should be replaced by data-derived sizes.

    Args:
        user_provided: item to be validated.
        default: a sequence used to provide the fallbacks.
        func: a Callable to validate every component of `user_provided`.

    Examples::

        >>> fall_back_tuple((1, 2), (32, 32))
        (1, 2)
        >>> fall_back_tuple(None, (32, 32))
        (32, 32)
        >>> fall_back_tuple((-1, 10), (32, 32))
        (32, 10)
        >>> fall_back_tuple((-1, None), (32, 32))
        (32, 32)
        >>> fall_back_tuple((1, None), (32, 32))
        (1, 32)
        >>> fall_back_tuple(0, (32, 32))
        (32, 32)
        >>> fall_back_tuple(range(3), (32, 64, 48))
        (32, 1, 2)
        >>> fall_back_tuple([0], (32, 32))
        ValueError: sequence must have length 2, got length 1.
    """
    n_dims = len(default)
    candidates = ensure_tuple_rep(user_provided, n_dims)

    def _pick(fallback, candidate):
        # keep the user's component only when it passes validation
        return candidate if func(candidate) else fallback

    return tuple(map(_pick, default, candidates))
def is_scalar_tensor(val) -> bool:
    """Return True only when `val` is a torch tensor with zero dimensions."""
    return torch.is_tensor(val) and val.ndim == 0


def is_scalar(val) -> bool:
    """
    Return True when `val` is a 0-d torch tensor, or when ``np.isscalar(val)``
    holds (NumPy's notion of a scalar, which includes Python numbers and strings).
    """
    zero_dim_tensor = torch.is_tensor(val) and val.ndim == 0
    return zero_dim_tensor or bool(np.isscalar(val))
def progress_bar(
    index: int, count: int, desc: Optional[str] = None, bar_len: int = 30, newline: bool = False
) -> None:
    """Print a progress bar to track some time-consuming task.

    Args:
        index: current status in progress.
        count: total steps of the progress. A non-positive `count` is rendered
            as a full bar instead of raising ``ZeroDivisionError``.
        desc: description of the progress bar, if not None, show before the progress bar.
        bar_len: the total length of the bar on screen, default is 30 chars.
        newline: whether to print in a new line for every index.
    """
    end = "\r" if newline is False else "\r\n"
    if count > 0:
        filled_len = int(bar_len * index // count)
    else:
        # nothing to divide by: treat an empty task as already complete
        filled_len = bar_len
    # clamp so an out-of-range index cannot over- or under-fill the bar
    filled_len = min(max(filled_len, 0), bar_len)
    bar = f"{desc} " if desc is not None else ""
    bar += "[" + "=" * filled_len + " " * (bar_len - filled_len) + "]"
    print(f"{index}/{count} {bar}", end=end)
    if index == count:
        print("")
def get_seed():
    """
    Return the seed most recently recorded by `set_determinism`.

    The value is None if `set_determinism` was never called or was last called
    with ``seed=None``.
    """
    current = _seed
    return current
def set_determinism(seed: Optional[int] = np.iinfo(np.int32).max, additional_settings=None) -> None:
    """
    Set random seed for modules to enable or disable deterministic training.

    Args:
        seed (None, int): the random seed to use, default is np.iinfo(np.int32).max.
            It is recommended to set a large seed, i.e. a number that has a good balance
            of 0 and 1 bits. Avoid having many 0 bits in the seed.
            Any integer is accepted; it is reduced modulo 2**32 because
            ``np.random.seed`` only accepts values in ``[0, 2**32 - 1]``.
            If set to None, will disable deterministic training.
        additional_settings (Callable, list or tuple of Callables): additional settings
            that need to set random seed. Each callable is called with the
            (normalized) seed, or with None when determinism is disabled.
    """
    if seed is None:
        # Re-seed torch's default CPU generator non-deterministically and reuse
        # that value, cast to a 32 bit seed, for CUDA.
        seed_ = torch.default_generator.seed() % (np.iinfo(np.int32).max + 1)
        if not torch.cuda._is_in_bad_fork():
            torch.cuda.manual_seed_all(seed_)
    else:
        # Normalise before seeding anything: np.random.seed rejects values outside
        # [0, 2**32 - 1], and failing there would leave torch/random already seeded.
        seed = int(seed) % (np.iinfo(np.uint32).max + 1)
        torch.manual_seed(seed)

    global _seed
    _seed = seed
    random.seed(seed)
    np.random.seed(seed)

    if additional_settings is not None:
        additional_settings = ensure_tuple(additional_settings)
        for func in additional_settings:
            func(seed)

    if seed is not None:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.deterministic = False