# Source code for monai.data.meta_tensor
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from copy import deepcopy
from typing import Any, Callable, Sequence
import torch
from monai.config.type_definitions import NdarrayTensor
from monai.data.meta_obj import MetaObj, get_track_meta
from monai.data.utils import decollate_batch, list_data_collate, remove_extra_metadata
from monai.utils.enums import PostFix
__all__ = ["MetaTensor"]
class MetaTensor(MetaObj, torch.Tensor):
    """
    Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for metadata.

    Metadata is stored in the form of a dictionary. Nested, an affine matrix will be
    stored. This should be in the form of `torch.Tensor`.

    Behavior should be the same as `torch.Tensor` aside from the extended
    meta functionality.

    Copying of information:

        * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the
          first instance of `MetaTensor`.

    Example:
        .. code-block:: python

            import torch
            from monai.data import MetaTensor

            t = torch.tensor([1, 2, 3])
            affine = torch.eye(4) * 100
            meta = {"some": "info"}
            m = MetaTensor(t, affine=affine, meta=meta)
            m2 = m + m
            assert isinstance(m2, MetaTensor)
            assert m2.meta["some"] == "info"
            assert m2.affine == affine

    Notes:
        - Requires pytorch 1.9 or newer for full compatibility.
        - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may
          not work if `im` is of type `MetaTensor`. This can be resolved with
          `torch.jit.trace(net, im.as_tensor())`.
        - For pytorch < 1.8, sharing `MetaTensor` instances across processes may not be supported.
        - A warning will be raised if in the constructor `affine` is not `None` and
          `meta` already contains the key `affine`.
        - You can query whether the `MetaTensor` is a batch with the `is_batch` attribute.
        - With a batch of data, `batch[0]` will return the 0th image
          with the 0th metadata. When the batch dimension is non-singleton, e.g.,
          `batch[:, 0]`, `batch[..., -1]` and `batch[1:3]`, then all (or a subset in the
          last example) of the metadata will be returned, and `is_batch` will return `True`.
        - When creating a batch with this class, use `monai.data.DataLoader` as opposed
          to `torch.utils.data.DataLoader`, as this will take care of collating the
          metadata properly.
    """
@staticmethod
def __new__(
cls,
x,
affine: torch.Tensor | None = None,
meta: dict | None = None,
applied_operations: list | None = None,
*args,
**kwargs,
) -> MetaTensor:
_kwargs = {"device": kwargs.pop("device", None), "dtype": kwargs.pop("dtype", None)} if kwargs else {}
return torch.as_tensor(x, *args, **_kwargs).as_subclass(cls) # type: ignore
[docs] def __init__(
self,
x,
affine: torch.Tensor | None = None,
meta: dict | None = None,
applied_operations: list | None = None,
*_args,
**_kwargs,
) -> None:
"""
If `meta` is given, use it. Else, if `meta` exists in the input tensor, use it.
Else, use the default value. Similar for the affine, except this could come from
four places.
Priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`.
"""
super().__init__()
# set meta
if meta is not None:
self.meta = meta
elif isinstance(x, MetaObj):
self.meta = x.meta
# set the affine
if affine is not None:
if "affine" in self.meta:
warnings.warn("Setting affine, but the applied meta contains an affine. This will be overwritten.")
self.affine = affine
elif "affine" in self.meta:
# by using the setter function, we ensure it is converted to torch.Tensor if not already
self.affine = self.meta["affine"]
elif isinstance(x, MetaTensor):
self.affine = x.affine
else:
self.affine = self.get_default_affine()
# applied_operations
if applied_operations is not None:
self.applied_operations = applied_operations
elif isinstance(x, MetaTensor):
self.applied_operations = x.applied_operations
else:
self.applied_operations = self.get_default_applied_operations()
# if we are creating a new MetaTensor, then deep copy attributes
if isinstance(x, torch.Tensor) and not isinstance(x, MetaTensor):
self.meta = deepcopy(self.meta)
self.applied_operations = deepcopy(self.applied_operations)
self.affine = self.affine.to(self.device)
def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None:
super()._copy_attr(attribute, input_objs, default_fn, deep_copy)
val = getattr(self, attribute)
if isinstance(val, torch.Tensor):
setattr(self, attribute, val.to(self.device))
[docs] @staticmethod
def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
"""
Update the metadata from the output of `MetaTensor.__torch_function__`.
The output of `torch.Tensor.__torch_function__` could be a single object or a
sequence of them. Hence, in `MetaTensor.__torch_function__` we convert them to a
list of not already, and then we loop across each element, processing metadata
as necessary. For each element, if not of type `MetaTensor`, then nothing to do.
Args:
rets: the output from `torch.Tensor.__torch_function__`, which has been
converted to a list in `MetaTensor.__torch_function__` if it wasn't
already a `Sequence`.
func: the torch function that was applied. Examples might be `torch.squeeze`
or `torch.Tensor.__add__`. We need this since the metadata need to be
treated differently if a batch of data is considered. For example,
slicing (`torch.Tensor.__getitem__`) the ith element of the 0th
dimension of a batch of data should return a ith tensor with the ith
metadata.
args: positional arguments that were passed to `func`.
kwargs: keyword arguments that were passed to `func`.
Returns:
A sequence with the same number of elements as `rets`. For each element, if
the input type was not `MetaTensor`, then no modifications will have been
made. If global parameters have been set to false (e.g.,
`not get_track_meta()`), then any `MetaTensor` will be converted to
`torch.Tensor`. Else, metadata will be propogated as necessary (see
:py:func:`MetaTensor._copy_meta`).
"""
out = []
metas = None
for idx, ret in enumerate(rets):
# if not `MetaTensor`, nothing to do.
if not isinstance(ret, MetaTensor):
pass
# if not tracking, convert to `torch.Tensor`.
elif not get_track_meta():
ret = ret.as_tensor()
# else, handle the `MetaTensor` metadata.
else:
meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values()))
ret._copy_meta(meta_args)
# If we have a batch of data, then we need to be careful if a slice of
# the data is returned. Depending on how the data are indexed, we return
# some or all of the metadata, and the return object may or may not be a
# batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
if ret.is_batch:
# only decollate metadata once
if metas is None:
metas = decollate_batch(ret.meta)
# if indexing e.g., `batch[0]`
if func == torch.Tensor.__getitem__:
idx = args[1]
if isinstance(idx, Sequence):
idx = idx[0]
# if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
# first element will be `slice(None, None, None)` and `Ellipsis`,
# respectively. Don't need to do anything with the metadata.
if idx not in (slice(None, None, None), Ellipsis):
meta = metas[idx]
# if using e.g., `batch[0:2]`, then `is_batch` should still be
# `True`. Also re-collate the remaining elements.
if isinstance(meta, list) and len(meta) > 1:
ret.meta = list_data_collate(meta)
# if using e.g., `batch[0]` or `batch[0, 1]`, then return single
# element from batch, and set `is_batch` to `False`.
else:
ret.meta = meta
ret.is_batch = False
# `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
# But we only want to split the batch if the `unbind` is along the 0th
# dimension.
elif func == torch.Tensor.unbind:
if len(args) > 1:
dim = args[1]
elif "dim" in kwargs:
dim = kwargs["dim"]
else:
dim = 0
if dim == 0:
ret.meta = metas[idx]
ret.is_batch = False
ret.affine = ret.affine.to(ret.device)
out.append(ret)
# if the input was a tuple, then return it as a tuple
return tuple(out) if isinstance(rets, tuple) else out
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:
"""Wraps all torch functions."""
if kwargs is None:
kwargs = {}
ret = super().__torch_function__(func, types, args, kwargs)
# if `out` has been used as argument, metadata is not copied, nothing to do.
# if "out" in kwargs:
# return ret
# we might have 1 or multiple outputs. Might be MetaTensor, might be something
# else (e.g., `__repr__` returns a string).
# Convert to list (if necessary), process, and at end remove list if one was added.
if isinstance(ret, (str, bytes)) or not isinstance(ret, Sequence):
ret = [ret]
unpack = True
else:
unpack = False
ret = MetaTensor.update_meta(ret, func, args, kwargs)
return ret[0] if unpack else ret
def get_default_affine(self, dtype=torch.float64) -> torch.Tensor:
return torch.eye(4, device=self.device, dtype=dtype)
[docs] def as_tensor(self) -> torch.Tensor:
"""
Return the `MetaTensor` as a `torch.Tensor`.
It is OS dependent as to whether this will be a deep copy or not.
"""
return self.as_subclass(torch.Tensor) # type: ignore
[docs] def as_dict(self, key: str) -> dict:
"""
Get the object as a dictionary for backwards compatibility.
Args:
key: Base key to store main data. The key for the metadata will be
determined using `PostFix.meta`.
Return:
A dictionary consisting of two keys, the main data (stored under `key`) and
the metadata.
"""
return {
key: self.as_tensor(),
PostFix.meta(key): deepcopy(self.meta),
PostFix.transforms(key): deepcopy(self.applied_operations),
}
@property
def affine(self) -> torch.Tensor:
"""Get the affine."""
return self.meta["affine"] # type: ignore
@affine.setter
def affine(self, d: NdarrayTensor) -> None:
"""Set the affine."""
self.meta["affine"] = torch.as_tensor(d, device=self.device)
[docs] def new_empty(self, size, dtype=None, device=None, requires_grad=False):
"""
must be defined for deepcopy to work
See:
- https://pytorch.org/docs/stable/generated/torch.Tensor.new_empty.html#torch-tensor-new-empty
"""
return type(self)(
self.as_tensor().new_empty(size=size, dtype=dtype, device=device, requires_grad=requires_grad)
)
[docs] @staticmethod
def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict):
"""
Convert the image to `torch.Tensor`. If `affine` is in the `meta` dictionary,
convert that to `torch.Tensor`, too. Remove any superfluous metadata.
Args:
im: Input image (`np.ndarray` or `torch.Tensor`)
meta: Metadata dictionary.
Returns:
By default, a `MetaTensor` is returned.
However, if `get_track_meta()` is `False`, a `torch.Tensor` is returned.
"""
img = torch.as_tensor(im)
# if not tracking metadata, return `torch.Tensor`
if not get_track_meta() or meta is None:
return img
# ensure affine is of type `torch.Tensor`
if "affine" in meta:
meta["affine"] = torch.as_tensor(meta["affine"])
# remove any superfluous metadata.
remove_extra_metadata(meta)
# return the `MetaTensor`
return MetaTensor(img, meta=meta)
def __repr__(self, *, tensor_contents=None):
return self.as_tensor().__repr__() + super().__repr__()