# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import functools
import warnings
from copy import deepcopy
from typing import Any, Sequence

import numpy as np
import torch

import monai
from monai.config.type_definitions import NdarrayTensor
from monai.data.meta_obj import MetaObj, get_track_meta
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
from monai.utils import look_up_option
from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor

__all__ = ["MetaTensor"]

def _get_named_tuple_like_type(func):
    if (
        hasattr(torch, "return_types")
        and hasattr(func, "__name__")
        and hasattr(torch.return_types, func.__name__)
        and isinstance(getattr(torch.return_types, func.__name__), type)
        return getattr(torch.return_types, func.__name__)
    return None

def _not_requiring_metadata(ret):
    return isinstance(ret, (int, str, bytes, torch.Size, torch.dtype, torch.device, np.ndarray)) or not (
        isinstance(ret, MetaTensor) or (isinstance(ret, Sequence) and any(isinstance(x, MetaTensor) for x in ret))

class MetaTensor(MetaObj, torch.Tensor):
    """
    Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for metadata.

    Metadata is stored in the form of a dictionary. Nested, an affine matrix will be
    stored. This should be in the form of `torch.Tensor`.

    Behavior should be the same as `torch.Tensor` aside from the extended
    meta functionality.

    Copying of information:

        * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the
          first instance of `MetaTensor` if `a.is_batch` is False
          (For batched data, the metadata will be shallow copied for efficiency purposes).

    Example:
        .. code-block:: python

            import torch
            from monai.data import MetaTensor

            t = torch.tensor([1,2,3])
            affine = torch.as_tensor([[2,0,0,0],
                                      [0,2,0,0],
                                      [0,0,2,0],
                                      [0,0,0,1]], dtype=torch.float64)
            meta = {"some": "info"}
            m = MetaTensor(t, affine=affine, meta=meta)
            m2 = m + m
            assert isinstance(m2, MetaTensor)
            assert m2.meta["some"] == "info"
            assert torch.all(m2.affine == affine)

    Notes:
        - Requires pytorch 1.9 or newer for full compatibility.
        - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may
          not work if `im` is of type `MetaTensor`. This can be resolved with
          `torch.jit.trace(net, im.as_tensor())`.
        - For pytorch < 1.8, sharing `MetaTensor` instances across processes may not be supported.
        - For pytorch < 1.9, next(iter(meta_tensor)) returns a torch.Tensor.
        - A warning will be raised if in the constructor `affine` is not `None` and
          `meta` already contains the key `affine`.
        - You can query whether the `MetaTensor` is a batch with the `is_batch` attribute.
        - With a batch of data, `batch[0]` will return the 0th image
          with the 0th metadata. When the batch dimension is non-singleton, e.g.,
          `batch[:, 0]`, `batch[..., -1]` and `batch[1:3]`, then all (or a subset in the
          last example) of the metadata will be returned, and `is_batch` will return `True`.
        - When creating a batch with this class, use `monai.data.DataLoader` as opposed
          to `torch.utils.data.DataLoader`, as this will take care of collating the
          metadata properly.
    """

    @staticmethod
    def __new__(
        cls,
        x,
        affine: torch.Tensor | None = None,
        meta: dict | None = None,
        applied_operations: list | None = None,
        *args,
        **kwargs,
    ) -> MetaTensor:
        """
        Allocate the tensor part of the object from ``x``.

        Only ``device`` and ``dtype`` are taken from ``kwargs`` and forwarded to
        ``torch.as_tensor``; ``affine``, ``meta`` and ``applied_operations`` are
        consumed by ``__init__``.
        """
        _kwargs = {"device": kwargs.pop("device", None), "dtype": kwargs.pop("dtype", None)} if kwargs else {}
        return torch.as_tensor(x, *args, **_kwargs).as_subclass(cls)

    def __init__(
        self,
        x,
        affine: torch.Tensor | None = None,
        meta: dict | None = None,
        applied_operations: list | None = None,
        *_args,
        **_kwargs,
    ) -> None:
        """
        Args:
            x: initial array for the MetaTensor. Can be a list, tuple, NumPy ndarray, scalar, and other types.
            affine: optional 4x4 array.
            meta: dictionary of metadata.
            applied_operations: list of previously applied operations on the MetaTensor,
                the list is typically maintained by `monai.transforms.TraceableTransform`.
                See also: :py:class:`monai.transforms.TraceableTransform`
            _args: additional args (currently not in use in this constructor).
            _kwargs: additional kwargs (currently not in use in this constructor).

        Note:
            If a `meta` dictionary is given, use it. Else, if `meta` exists in the input tensor `x`, use it.
            Else, use the default value. Similar for the affine, except this could come from
            four places, priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`.

        """
        super().__init__()
        # set meta
        if meta is not None:
            self.meta = meta
        elif isinstance(x, MetaObj):
            self.__dict__ = deepcopy(x.__dict__)

        # set the affine
        if affine is not None:
            if MetaKeys.AFFINE in self.meta:
                warnings.warn("Setting affine, but the applied meta contains an affine. This will be overwritten.")
            self.affine = affine
        elif MetaKeys.AFFINE in self.meta:
            # by using the setter function, we ensure it is converted to torch.Tensor if not already
            self.affine = self.meta[MetaKeys.AFFINE]
        else:
            self.affine = self.get_default_affine()

        # applied_operations
        if applied_operations is not None:
            self.applied_operations = applied_operations
        else:
            self.applied_operations = MetaObj.get_default_applied_operations()

        # if we are creating a new MetaTensor, then deep copy attributes
        if isinstance(x, torch.Tensor) and not isinstance(x, MetaTensor):
            self.copy_meta_from(self)

        if MetaKeys.SPACE not in self.meta:
            self.meta[MetaKeys.SPACE] = SpaceKeys.RAS  # defaulting to the right-anterior-superior space

    @staticmethod
    def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
        """
        Update the metadata from the output of `MetaTensor.__torch_function__`.

        The output of `torch.Tensor.__torch_function__` could be a single object or a
        sequence of them. Hence, in `MetaTensor.__torch_function__` we convert them to a
        list of not already, and then we loop across each element, processing metadata
        as necessary. For each element, if not of type `MetaTensor`, then nothing to do.

        Args:
            rets: the output from `torch.Tensor.__torch_function__`, which has been
                converted to a list in `MetaTensor.__torch_function__` if it wasn't
                already a `Sequence`.
            func: the torch function that was applied. Examples might be `torch.squeeze`
                or `torch.Tensor.__add__`. We need this since the metadata need to be
                treated differently if a batch of data is considered. For example,
                slicing (`torch.Tensor.__getitem__`) the ith element of the 0th
                dimension of a batch of data should return a ith tensor with the ith
                metadata.
            args: positional arguments that were passed to `func`.
            kwargs: keyword arguments that were passed to `func`.

        Returns:
            A sequence with the same number of elements as `rets`. For each element, if
            the input type was not `MetaTensor`, then no modifications will have been
            made. If global parameters have been set to false (e.g.,
            `not get_track_meta()`), then any `MetaTensor` will be converted to
            `torch.Tensor`. Else, metadata will be propagated as necessary (see
            :py:func:`MetaTensor._copy_meta`).
        """
        out = []
        metas = None  # optional output metadicts for each of the return value in `rets`
        is_batch = any(x.is_batch for x in MetaObj.flatten_meta_objs(args, kwargs.values()) if hasattr(x, "is_batch"))
        for idx, ret in enumerate(rets):
            # if not `MetaTensor`, nothing to do.
            if not isinstance(ret, MetaTensor):
                pass
            # if not tracking, convert to `torch.Tensor`.
            elif not get_track_meta():
                ret = ret.as_tensor()
            # else, handle the `MetaTensor` metadata.
            else:
                meta_args = MetaObj.flatten_meta_objs(args, kwargs.values())
                ret.is_batch = is_batch
                ret.copy_meta_from(meta_args, copy_attr=not is_batch)
                # the following is not implemented but the network arch may run into this case:
                # if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args):
                #     raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")
                if is_batch:
                    ret = MetaTensor._handle_batched(ret, idx, metas, func, args, kwargs)
            out.append(ret)
        # if the input was a tuple, then return it as a tuple
        return tuple(out) if isinstance(rets, tuple) else out

    @classmethod
    def _handle_batched(cls, ret, idx, metas, func, args, kwargs):
        """utility function to handle batched MetaTensors."""
        # If we have a batch of data, then we need to be careful if a slice of
        # the data is returned. Depending on how the data are indexed, we return
        # some or all of the metadata, and the return object may or may not be a
        # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
        # if indexing e.g., `batch[0]`
        if func == torch.Tensor.__getitem__:
            if idx > 0 or len(args) < 2 or len(args[0]) < 1:
                return ret
            batch_idx = args[1][0] if isinstance(args[1], Sequence) else args[1]
            # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
            # first element will be `slice(None, None, None)` and `Ellipsis`,
            # respectively. Don't need to do anything with the metadata.
            # (the tensor type check comes first so tensors are never compared with `==`)
            if isinstance(batch_idx, torch.Tensor) or batch_idx in (slice(None, None, None), Ellipsis, None):
                return ret
            dec_batch = decollate_batch(args[0], detach=False)
            ret_meta = dec_batch[batch_idx]
            if isinstance(ret_meta, list) and ret_meta:  # e.g. batch[0:2], re-collate
                try:
                    ret_meta = list_data_collate(ret_meta)
                except (TypeError, ValueError, RuntimeError, IndexError) as e:
                    raise ValueError(
                        "Inconsistent batched metadata dicts when slicing a batch of MetaTensors, "
                        "please consider converting it into a torch Tensor using `x.as_tensor()` or "
                        "a numpy array using `x.array`."
                    ) from e
            elif isinstance(ret_meta, MetaObj):  # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int
                ret_meta.is_batch = False
            if hasattr(ret_meta, "__dict__"):
                ret.__dict__ = ret_meta.__dict__.copy()
        # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
        # But we only want to split the batch if the `unbind` is along the 0th dimension.
        elif func == torch.Tensor.unbind:
            if len(args) > 1:
                dim = args[1]
            elif "dim" in kwargs:
                dim = kwargs["dim"]
            else:
                dim = 0
            if dim == 0:
                # `metas` is rebound locally only; the caller's variable is not updated
                if metas is None:
                    metas = decollate_batch(args[0], detach=False)
                if hasattr(metas[idx], "__dict__"):
                    ret.__dict__ = metas[idx].__dict__.copy()
                ret.is_batch = False
        return ret

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:
        """Wraps all torch functions."""
        if kwargs is None:
            kwargs = {}
        ret = super().__torch_function__(func, types, args, kwargs)
        # if `out` has been used as argument, metadata is not copied, nothing to do.
        # if "out" in kwargs:
        #     return ret
        if _not_requiring_metadata(ret):
            return ret
        named_tuple_type = _get_named_tuple_like_type(func)
        if named_tuple_type is not None and isinstance(ret, named_tuple_type):
            # for torch.max(torch.tensor(1.0), dim=0), the return type is named-tuple like
            out_items = MetaTensor.update_meta(ret, func, args, kwargs)
            for idx in range(ret.n_fields):
                ret[idx].meta = out_items[idx].meta
                ret[idx].applied_operations = out_items[idx].applied_operations
            return ret
        # we might have 1 or multiple outputs. Might be MetaTensor, might be something
        # else (e.g., `__repr__` returns a string).
        # Convert to list (if necessary), process, and at end remove list if one was added.
        unpack = not isinstance(ret, Sequence)
        if unpack:
            ret = [ret]
        ret = MetaTensor.update_meta(ret, func, args, kwargs)
        return ret[0] if unpack else ret

    @staticmethod
    def _convert(x):
        """Convert tensors, tuples and lists to ``np.ndarray``; return anything else unchanged."""
        if isinstance(x, (MetaTensor, torch.Tensor, tuple, list)):
            return convert_data_type(x, output_type=np.ndarray, wrap_sequence=False)[0]
        return x

    def __array_function__(self, func, types, args, kwargs):
        """for numpy Interoperability, so that we can compute ``np.sum(MetaTensor([1.0]))``."""
        try:
            if not func.__module__.startswith("numpy"):
                return NotImplemented
        except AttributeError:
            return NotImplemented
        _args = list(map(MetaTensor._convert, args))
        _kwargs = {k: MetaTensor._convert(v) for k, v in kwargs.items()}
        return func(*_args, **_kwargs)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """
        For numpy interoperability, so that we can compute ``MetaTensor([1.0]) >= np.asarray([1.0])``.
        This is for pytorch > 1.8.
        """
        try:
            if not type(ufunc).__module__.startswith("numpy"):
                return NotImplemented
        except AttributeError:
            return NotImplemented
        if method != "__call__":
            return NotImplemented
        if "out" in kwargs:
            return NotImplemented  # not supported
        _inputs = map(MetaTensor._convert, inputs)
        _kwargs = {k: MetaTensor._convert(v) for k, v in kwargs.items()}
        try:
            return getattr(ufunc, method)(*_inputs, **_kwargs)
        except AttributeError:
            return NotImplemented

    @staticmethod
    def get_default_affine(dtype=torch.float64) -> torch.Tensor:
        """Return a 4x4 identity matrix on the CPU with the given ``dtype``."""
        return torch.eye(4, device=torch.device("cpu"), dtype=dtype)

    def as_tensor(self) -> torch.Tensor:
        """
        Return the `MetaTensor` as a `torch.Tensor`.
        It is OS dependent as to whether this will be a deep copy or not.
        """
        return self.as_subclass(torch.Tensor)

    def get_array(self, output_type=np.ndarray, dtype=None, device=None, *_args, **_kwargs):
        """
        Returns a new array in `output_type`, the array shares the same underlying storage when the output is a
        numpy array. Changes to self tensor will be reflected in the ndarray and vice versa.

        Args:
            output_type: output type, see also: :py:func:`monai.utils.convert_data_type`.
            dtype: dtype of output data. Converted to correct library type (e.g.,
                `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).
                If left blank, it remains unchanged.
            device: if the output is a `torch.Tensor`, select device (if `None`, unchanged).
            _args: currently unused parameters.
            _kwargs: currently unused parameters.
        """
        return convert_data_type(self, output_type=output_type, dtype=dtype, device=device, wrap_sequence=True)[0]

    def set_array(self, src, non_blocking: bool = False, *_args, **_kwargs):
        """
        Copies the elements from src into self tensor and returns self.
        The src tensor must be broadcastable with the self tensor.
        It may be of a different data type or reside on a different device.

        See also: `torch.Tensor.copy_`

        Args:
            src: the source tensor to copy from.
            non_blocking: if True and this copy is between CPU and GPU, the copy may occur
                asynchronously with respect to the host. For other cases, this argument has no effect.
            _args: currently unused parameters.
            _kwargs: currently unused parameters.
        """
        converted: torch.Tensor = convert_to_tensor(src, track_meta=False, wrap_sequence=True)
        try:
            return self.copy_(converted, non_blocking=non_blocking)
        except RuntimeError:  # skip the shape checking
            self.data = converted
            return self

    @property
    def array(self):
        """
        Returns a numpy array of ``self``. The array and ``self`` shares the same underlying storage if self is on cpu.
        Changes to ``self`` (it's a subclass of torch.Tensor) will be reflected in the ndarray and vice versa.
        If ``self`` is not on cpu, the call will move the array to cpu and then the storage is not shared.

        :getter: see also: :py:func:`MetaTensor.get_array()`
        :setter: see also: :py:func:`MetaTensor.set_array()`
        """
        return self.get_array()

    @array.setter
    def array(self, src) -> None:
        """A default setter using ``self.set_array()``"""
        self.set_array(src)

    def as_dict(self, key: str, output_type=torch.Tensor, dtype=None) -> dict:
        """
        Get the object as a dictionary for backwards compatibility.
        This method does not make a deep copy of the objects.

        Args:
            key: Base key to store main data. The key for the metadata will be determined using `PostFix`.
            output_type: `torch.Tensor` or `np.ndarray` for the main data.
            dtype: dtype of output data. Converted to correct library type (e.g.,
                `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).
                If left blank, it remains unchanged.

        Return:
            A dictionary consisting of three keys, the main data (stored under `key`) and the metadata.

        Raises:
            ValueError: when `output_type` is neither `torch.Tensor` nor `np.ndarray`.
        """
        if output_type not in (torch.Tensor, np.ndarray):
            raise ValueError(f"output_type must be torch.Tensor or np.ndarray, got {output_type}.")
        return {
            key: self.get_array(output_type=output_type, dtype=dtype),
            PostFix.meta(key): self.meta,
            PostFix.transforms(key): self.applied_operations,
        }

    def astype(self, dtype, device=None, *_args, **_kwargs):
        """
        Cast to ``dtype``, sharing data whenever possible.

        Args:
            dtype: dtypes such as np.float32, torch.float, "np.float32", float.
            device: the device if `dtype` is a torch data type.
            _args: additional args (currently unused).
            _kwargs: additional kwargs (currently unused).

        Returns:
            data array instance
        """
        if isinstance(dtype, str):
            # "np.float32" -> module "np", dtype "float32"; a bare "float" is used for both
            mod_str, *dtype = dtype.split(".", 1)
            dtype = mod_str if not dtype else dtype[0]
        else:
            mod_str = getattr(dtype, "__module__", "torch")
        mod_str = look_up_option(mod_str, {"torch", "numpy", "np"}, default="numpy")

        out_type: type[torch.Tensor] | type[np.ndarray] | None
        if mod_str == "torch":
            out_type = torch.Tensor
        elif mod_str in ("numpy", "np"):
            out_type = np.ndarray
        else:
            out_type = None
        return self.get_array(output_type=out_type, dtype=dtype, device=device)

    @property
    def affine(self) -> torch.Tensor:
        """Get the affine. Defaults to ``torch.eye(4, dtype=torch.float64)``"""
        return self.meta.get(MetaKeys.AFFINE, self.get_default_affine())  # type: ignore

    @affine.setter
    def affine(self, d: NdarrayTensor) -> None:
        """Set the affine."""
        self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)

    @property
    def pixdim(self):
        """Get the spacing"""
        if self.is_batch:
            return [affine_to_spacing(a) for a in self.affine]
        return affine_to_spacing(self.affine)

    def peek_pending_shape(self):
        """
        Get the currently expected spatial shape as if all the pending operations are executed.
        For tensors that have more than 3 spatial dimensions, only the shapes of the top 3 dimensions will be returned.
        """
        res = None
        if self.pending_operations:
            res = self.pending_operations[-1].get(LazyAttr.SHAPE, None)
        # default to spatial shape (assuming channel-first input)
        return tuple(convert_to_numpy(self.shape, wrap_sequence=True).tolist()[1:]) if res is None else res

    def peek_pending_affine(self):
        """
        Get the affine as if all the pending operations are executed: starting from
        ``self.affine``, the affine of every pending operation (when present) is combined in order.
        A warning is issued when the current affine is neither 2d nor 3d.
        """
        res = self.affine
        r = len(res) - 1
        if r not in (2, 3):
            warnings.warn(f"Only 2d and 3d affine are supported, got {r}d input.")
        for p in self.pending_operations:
            next_matrix = convert_to_tensor(p.get(LazyAttr.AFFINE), dtype=torch.float64)
            if next_matrix is None:
                continue
            res = convert_to_dst_type(res, next_matrix)[0]
            next_matrix = monai.data.utils.to_affine_nd(r, next_matrix)
            res = monai.transforms.lazy.utils.combine_transforms(res, next_matrix)
        return res

    def peek_pending_rank(self):
        """
        Get ``len(affine) - 1`` (at least 1) for the affine of the last pending operation,
        or for ``self.affine`` when nothing is pending. Returns 1 when that affine is ``None``.
        """
        a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) if self.pending_operations else self.affine
        return 1 if a is None else int(max(1, len(a) - 1))

    def new_empty(self, size, dtype=None, device=None, requires_grad=False):
        """
        must be defined for deepcopy to work

        See also: `torch.Tensor.new_empty`
        """
        return type(self)(
            self.as_tensor().new_empty(size=size, dtype=dtype, device=device, requires_grad=requires_grad)
        )

    def clone(self, **kwargs):
        """
        Returns a copy of the MetaTensor instance.

        Args:
            kwargs: additional keyword arguments to `torch.clone`.

        See also: `torch.clone`
        """
        new_inst = MetaTensor(self.as_tensor().clone(**kwargs))
        new_inst.__dict__ = deepcopy(self.__dict__)
        return new_inst

    @staticmethod
    def ensure_torch_and_prune_meta(
        im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = "."
    ):
        """
        Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary,
        convert that to `torch.Tensor`, too. Remove any superfluous metadata.

        Args:
            im: Input image (`np.ndarray` or `torch.Tensor`)
            meta: Metadata dictionary. When it's None, the metadata is not tracked, this method returns a torch.Tensor.
            simple_keys: whether to keep only a simple subset of metadata keys.
            pattern: combined with `sep`, a regular expression used to match and prune keys
                in the metadata (nested dictionary), default to None, no key deletion.
            sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary).
                default is ".", see also :py:class:`monai.transforms.DeleteItemsd`.
                e.g. ``pattern=".*_code$", sep=" "`` removes any meta keys that ends with ``"_code"``.

        Returns:
            By default, a `MetaTensor` is returned.
            However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
        """
        img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None)  # potentially ascontiguousarray
        # if not tracking metadata, return `torch.Tensor`
        if not isinstance(img, MetaTensor):
            return img

        if meta is None:
            meta = {}

        # remove any superfluous metadata.
        if simple_keys:
            # ensure affine is of type `torch.Tensor`
            if MetaKeys.AFFINE in meta:
                meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE])  # bc-breaking
            remove_extra_metadata(meta)  # bc-breaking

        if pattern is not None:
            meta = monai.transforms.DeleteItemsd(keys=pattern, sep=sep, use_re=True)(meta)

        # return the `MetaTensor`
        if meta is None:
            meta = {}
        img.meta = meta
        if MetaKeys.AFFINE in meta:
            img.affine = meta[MetaKeys.AFFINE]  # this uses the affine property setter
        else:
            img.affine = MetaTensor.get_default_affine()
        return img

    def __repr__(self):
        """
        Prints a representation of the tensor.
        Prepends "meta" to ``torch.Tensor.__repr__``.
        Use ``print_verbose`` for associated metadata.
        """
        return f"meta{self.as_tensor()!r}"

    def __str__(self):
        """
        Prints a representation of the tensor.
        Prepends "meta" to ``torch.Tensor.__str__``.
        Use ``print_verbose`` for associated metadata.
        """
        return f"meta{self.as_tensor()!s}"

    def __format__(self, format_spec):
        """
        returns the output of pytorch tensor's ``__format__`` method.
        """
        return self.as_tensor().__format__(format_spec)

    def print_verbose(self) -> None:
        """Verbose print with meta data."""
        print(self)
        if self.meta is not None:
            print(repr(self.meta))