Source code for monai.data.meta_obj

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import itertools
import pprint
from copy import deepcopy
from typing import Any, Iterable

import numpy as np
import torch

from monai.utils import TraceKeys, first, is_immutable

_TRACK_META = True

__all__ = ["get_track_meta", "set_track_meta", "MetaObj"]


def set_track_meta(val: bool) -> None:
    """
    Globally switch metadata tracking on or off.

    When tracking is enabled (``True``), metadata is associated with its data
    through subclasses of `MetaObj`. When it is disabled (``False``), data is
    returned with empty metadata, and standard data objects (e.g.,
    `torch.Tensor` and `np.ndarray`) are returned instead of MONAI's enhanced
    objects.

    Tracking is enabled by default, and most users will want to leave it that
    way. If metadata is causing problems and preserving it is not of interest,
    it can be disabled here.

    Args:
        val: the new value of the module-level tracking flag.
    """
    # The flag lives at module level so that `get_track_meta` can read it back.
    global _TRACK_META
    _TRACK_META = val
def get_track_meta() -> bool:
    """
    Report whether metadata tracking is currently enabled.

    When the flag is ``True``, metadata is associated with its data through
    subclasses of `MetaObj`. When it is ``False``, data is returned with empty
    metadata, and standard data objects (e.g., `torch.Tensor` and
    `np.ndarray`) are returned instead of MONAI's enhanced objects.

    The flag is ``True`` by default and is changed with `set_track_meta`.

    Returns:
        the current value of the module-level tracking flag.
    """
    return _TRACK_META
class MetaObj:
    """
    Abstract base class that stores data as well as any extra metadata.

    This allows for subclassing `torch.Tensor` and `np.ndarray` through multiple inheritance.

    Metadata is stored in the form of a dictionary.

    Behavior should be the same as extended class (e.g., `torch.Tensor` or `np.ndarray`)
    aside from the extended meta functionality.

    Copying of information:

        * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the
          first instance of `MetaObj` if `a.is_batch` is False
          (For batched data, the metadata will be shallow copied for efficiency purposes).

    Note:
        The property getters fall back to default values when the backing
        attribute is missing, so they are usable on instances for which
        ``__init__`` did not run.
    """

    def __init__(self) -> None:
        self._meta: dict = MetaObj.get_default_meta()
        self._applied_operations: list = MetaObj.get_default_applied_operations()
        # pending operations use the same default (an empty list) as applied operations
        self._pending_operations: list = MetaObj.get_default_applied_operations()
        self._is_batch: bool = False

    @staticmethod
    def flatten_meta_objs(*args: Iterable):
        """
        Recursively flatten input and yield all instances of `MetaObj`.
        This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and
        their numpy equivalents), we yield `a` and `b` if both are of type `MetaObj`.

        Only lists and tuples are descended into; any other item that is not a
        `MetaObj` is skipped.

        Args:
            args: Iterables of inputs to be flattened.

        Yields:
            each nested `MetaObj` found in the input, in iteration order.
        """
        for a in itertools.chain(*args):
            if isinstance(a, (list, tuple)):
                yield from MetaObj.flatten_meta_objs(a)
            elif isinstance(a, MetaObj):
                yield a

    @staticmethod
    def copy_items(data):
        """
        Return a copy of the data.

        Values that `is_immutable` reports as immutable are returned unchanged.
        Lists, dicts and `np.ndarray` are copied with their own ``copy()`` method
        (a shallow copy for list and dict, for efficiency purposes). Tensors are
        detached and cloned. Anything else is deep copied.
        """
        if is_immutable(data):
            return data
        if isinstance(data, (list, dict, np.ndarray)):
            return data.copy()
        if isinstance(data, torch.Tensor):
            return data.detach().clone()
        return deepcopy(data)

    def copy_meta_from(self, input_objs, copy_attr=True, keys=None):
        """
        Copy metadata from a `MetaObj` or an iterable of `MetaObj` instances.

        Args:
            input_objs: a `MetaObj`, or an iterable of `MetaObj` to copy data from.
                For an iterable, the object returned by ``first(input_objs, default=self)``
                is used as the source.
            copy_attr: whether to copy each attribute with `MetaObj.copy_items`.
                note that if the attribute is a nested list or dict, only a shallow copy will be done.
                If False, the attribute values are shared with the source object.
            keys: the keys of attributes to copy from the ``input_objs``.
                If None, all keys from the input_objs will be copied.

        Returns:
            ``self``, to allow chaining.
        """
        first_meta = input_objs if isinstance(input_objs, MetaObj) else first(input_objs, default=self)
        if not hasattr(first_meta, "__dict__"):
            # nothing to copy from an object that carries no instance attributes
            return self
        first_meta = first_meta.__dict__
        keys = first_meta.keys() if keys is None else keys
        if not copy_attr:
            # shallow copy for performance; note that this REPLACES the instance
            # ``__dict__``, so attributes of ``self`` not listed in ``keys`` are dropped
            self.__dict__ = {a: first_meta[a] for a in keys if a in first_meta}
        else:
            # merge copied attributes into the existing instance ``__dict__``
            self.__dict__.update({a: MetaObj.copy_items(first_meta[a]) for a in keys if a in first_meta})
        return self

    @staticmethod
    def get_default_meta() -> dict:
        """Get the default meta.

        Returns:
            default metadata (a new empty dict on every call).
        """
        return {}

    @staticmethod
    def get_default_applied_operations() -> list:
        """Get the default applied operations.

        Returns:
            default applied operations (a new empty list on every call).
        """
        return []

    def __repr__(self) -> str:
        """String representation of class."""
        out: str = "\nMetadata\n"
        if self.meta is not None:
            out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items())
        else:
            out += "None"

        out += "\nApplied operations\n"
        if self.applied_operations is not None:
            out += pprint.pformat(self.applied_operations, indent=2, compact=True, width=120)
        else:
            out += "None"

        out += f"\nIs batch?: {self.is_batch}"

        return out

    @property
    def meta(self) -> dict:
        """Get the meta. Defaults to ``{}``."""
        return self._meta if hasattr(self, "_meta") else MetaObj.get_default_meta()

    @meta.setter
    def meta(self, d) -> None:
        """Set the meta. ``TraceKeys.NONE`` resets it to the default."""
        if d == TraceKeys.NONE:
            self._meta = MetaObj.get_default_meta()
        else:
            self._meta = d

    @property
    def applied_operations(self) -> list[dict]:
        """Get the applied operations. Defaults to ``[]``."""
        if hasattr(self, "_applied_operations"):
            return self._applied_operations
        return MetaObj.get_default_applied_operations()

    @applied_operations.setter
    def applied_operations(self, t) -> None:
        """Set the applied operations. ``TraceKeys.NONE`` resets them to the default."""
        if t == TraceKeys.NONE:
            # received no operations when decollating a batch
            self._applied_operations = MetaObj.get_default_applied_operations()
            return
        self._applied_operations = t

    def push_applied_operation(self, t: Any) -> None:
        """
        Append ``t`` to the applied operations.

        The backing list is created on demand, so this also works on instances
        where it has not been set up yet (consistent with the ``applied_operations``
        getter, which tolerates the missing attribute).
        """
        if not hasattr(self, "_applied_operations"):
            # appending to the getter's fallback value would be lost, so store the list first
            self._applied_operations = MetaObj.get_default_applied_operations()
        self._applied_operations.append(t)

    def pop_applied_operation(self) -> Any:
        """Remove and return the most recently pushed applied operation."""
        return self._applied_operations.pop()

    @property
    def pending_operations(self) -> list[dict]:
        """Get the pending operations. Defaults to ``[]``."""
        if hasattr(self, "_pending_operations"):
            return self._pending_operations
        return MetaObj.get_default_applied_operations()  # the same default as applied_ops

    @property
    def has_pending_operations(self) -> bool:
        """
        Determine whether there are pending operations.

        Returns:
            True if there are pending operations; False if not
        """
        return self.pending_operations is not None and len(self.pending_operations) > 0

    def push_pending_operation(self, t: Any) -> None:
        """
        Append ``t`` to the pending operations.

        The backing list is created on demand, so this also works on instances
        where it has not been set up yet (consistent with the ``pending_operations``
        getter, which tolerates the missing attribute).
        """
        if not hasattr(self, "_pending_operations"):
            # appending to the getter's fallback value would be lost, so store the list first
            self._pending_operations = MetaObj.get_default_applied_operations()
        self._pending_operations.append(t)

    def pop_pending_operation(self) -> Any:
        """Remove and return the most recently pushed pending operation."""
        return self._pending_operations.pop()

    def clear_pending_operations(self) -> None:
        """Reset the pending operations to the default (an empty list)."""
        self._pending_operations = MetaObj.get_default_applied_operations()

    @property
    def is_batch(self) -> bool:
        """Return whether object is part of batch or not."""
        return self._is_batch if hasattr(self, "_is_batch") else False

    @is_batch.setter
    def is_batch(self, val: bool) -> None:
        """Set whether object is part of batch or not."""
        self._is_batch = val