# Source code for monai.data.meta_obj
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import itertools
import pprint
from copy import deepcopy
from typing import Any, Iterable
import numpy as np
import torch
from monai.utils.enums import TraceKeys
from monai.utils.misc import first
# Module-level switch for metadata tracking: written by `set_track_meta`,
# read by `get_track_meta`. Tracking is enabled by default.
_TRACK_META = True
# Public names exported by a star-import of this module.
__all__ = ["get_track_meta", "set_track_meta", "MetaObj"]
def set_track_meta(val: bool) -> None:
    """
    Boolean to set whether metadata is tracked. If `True`, metadata will be associated
    with its data by using subclasses of `MetaObj`. If `False`, then data will be returned
    with empty metadata.

    If `set_track_meta` is `False`, then standard data objects will be returned (e.g.,
    `torch.Tensor` and `np.ndarray`) as opposed to MONAI's enhanced objects.

    By default, this is `True`, and most users will want to leave it this way. However,
    if you are experiencing any problems regarding metadata, and aren't interested in
    preserving metadata, then you can disable it.

    Args:
        val: the new value of the module-level tracking flag.
    """
    # Rebind the module-level flag rather than creating a function-local name.
    global _TRACK_META
    _TRACK_META = val
def get_track_meta() -> bool:
    """
    Return the boolean as to whether metadata is tracked. If `True`, metadata will be
    associated with its data by using subclasses of `MetaObj`. If `False`, then data will be
    returned with empty metadata.

    If `set_track_meta` is `False`, then standard data objects will be returned (e.g.,
    `torch.Tensor` and `np.ndarray`) as opposed to MONAI's enhanced objects.

    By default, this is `True`, and most users will want to leave it this way. However,
    if you are experiencing any problems regarding metadata, and aren't interested in
    preserving metadata, then you can disable it.

    Returns:
        the current value of the module-level tracking flag.
    """
    return _TRACK_META
class MetaObj:
    """
    Abstract base class that stores data as well as any extra metadata.

    This allows for subclassing `torch.Tensor` and `np.ndarray` through multiple inheritance.

    Metadata is stored in the form of a dictionary.

    Behavior should be the same as extended class (e.g., `torch.Tensor` or `np.ndarray`)
    aside from the extended meta functionality.

    Copying of information:

        * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the
          first instance of `MetaObj` if `a.is_batch` is False
          (For batched data, the metadata will be shallow copied for efficiency purposes).
    """

    def __init__(self):
        self._meta: dict = MetaObj.get_default_meta()
        self._applied_operations: list = MetaObj.get_default_applied_operations()
        self._is_batch: bool = False

    @staticmethod
    def flatten_meta_objs(*args: Iterable):
        """
        Recursively flatten input and yield all instances of `MetaObj`.
        This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and
        their numpy equivalents), we return `[a, b]` if both `a` and `b` are of type
        `MetaObj`.

        Args:
            args: Iterables of inputs to be flattened.

        Returns:
            generator over the nested `MetaObj` instances found in the input.
        """
        for a in itertools.chain(*args):
            if isinstance(a, (list, tuple)):
                # Descend into nested sequences; other non-`MetaObj` items are skipped.
                yield from MetaObj.flatten_meta_objs(a)
            elif isinstance(a, MetaObj):
                yield a

    @staticmethod
    def copy_items(data):
        """
        Return a copy of the data.

        `list`, `dict` and `np.ndarray` are copied with their own ``copy()`` method
        (shallow for list and dict, for efficiency purposes), `torch.Tensor` is detached
        and cloned, and anything else is deep copied.
        """
        if isinstance(data, (list, dict, np.ndarray)):
            return data.copy()
        if isinstance(data, torch.Tensor):
            return data.detach().clone()
        return deepcopy(data)

    def copy_meta_from(self, input_objs, copy_attr=True) -> None:
        """
        Copy metadata from a `MetaObj` or an iterable of `MetaObj` instances.

        Args:
            input_objs: a `MetaObj`, or an iterable of `MetaObj`, to copy data from.
                For an iterable, only its first item is used (falling back to `self`).
            copy_attr: whether to copy each attribute with `MetaObj.copy_items`.
                note that if the attribute is a nested list or dict, only a shallow copy will be done.
        """
        first_meta = input_objs if isinstance(input_objs, MetaObj) else first(input_objs, default=self)
        first_meta = first_meta.__dict__
        if not copy_attr:
            self.__dict__ = first_meta.copy()  # shallow copy for performance
        else:
            self.__dict__.update({a: MetaObj.copy_items(first_meta[a]) for a in first_meta})

    @staticmethod
    def get_default_meta() -> dict:
        """Get the default meta.

        Returns:
            default metadata.
        """
        return {}

    @staticmethod
    def get_default_applied_operations() -> list:
        """Get the default applied operations.

        Returns:
            default applied operations.
        """
        return []

    def __repr__(self) -> str:
        """String representation of class."""
        out: str = "\nMetadata\n"
        if self.meta is not None:
            out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items())
        else:
            out += "None"

        out += "\nApplied operations\n"
        if self.applied_operations is not None:
            out += pprint.pformat(self.applied_operations, indent=2, compact=True, width=120)
        else:
            out += "None"

        out += f"\nIs batch?: {self.is_batch}"

        return out

    @property
    def meta(self) -> dict:
        """Get the meta. Defaults to ``{}``."""
        return self._meta if hasattr(self, "_meta") else MetaObj.get_default_meta()

    @meta.setter
    def meta(self, d) -> None:
        """Set the meta.

        If `d` equals ``TraceKeys.NONE``, the meta is reset to the default from
        `get_default_meta` instead of storing the sentinel value itself.
        """
        if d == TraceKeys.NONE:
            self._meta = MetaObj.get_default_meta()
            # Return here so the default is not overwritten by the sentinel below.
            return
        self._meta = d

    @property
    def applied_operations(self) -> list[dict]:
        """Get the applied operations. Defaults to ``[]``."""
        if hasattr(self, "_applied_operations"):
            return self._applied_operations
        return MetaObj.get_default_applied_operations()

    @applied_operations.setter
    def applied_operations(self, t) -> None:
        """Set the applied operations."""
        if t == TraceKeys.NONE:
            # received no operations when decollating a batch
            self._applied_operations = MetaObj.get_default_applied_operations()
            return
        self._applied_operations = t

    def push_applied_operation(self, t: Any) -> None:
        """Append `t` to the end of the list of applied operations."""
        self._applied_operations.append(t)

    def pop_applied_operation(self) -> Any:
        """Remove and return the last item of the list of applied operations."""
        return self._applied_operations.pop()

    @property
    def is_batch(self) -> bool:
        """Return whether object is part of batch or not."""
        return self._is_batch if hasattr(self, "_is_batch") else False

    @is_batch.setter
    def is_batch(self, val: bool) -> None:
        """Set whether object is part of batch or not."""
        self._is_batch = val