Data

Generic Interfaces

Dataset

class monai.data.Dataset(data, transform=None)

A generic dataset with a length property and an optional callable data transform that is applied when a data sample is fetched. If slicing indices are passed, a PyTorch Subset is returned, for example: data: Subset = dataset[1:4]. For more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset

For example, typical input data can be a list of dictionaries:

[
    {'img': 'image1.nii.gz', 'seg': 'label1.nii.gz', 'extra': 123},
    {'img': 'image2.nii.gz', 'seg': 'label2.nii.gz', 'extra': 456},
    {'img': 'image3.nii.gz', 'seg': 'label3.nii.gz', 'extra': 789},
]

__getitem__(index)

Returns a Subset if index is a slice or Sequence, a data item otherwise.

__init__(data, transform=None)
Parameters
  • data (Sequence) – input data to load and transform to generate dataset for model.

  • transform (Optional[Callable]) – a callable data transform on input data.
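The fetch-time behavior described above (raw items kept as-is, the transform applied only when an item is requested, slices returning a lazy subset) can be sketched in plain Python. SketchDataset and SketchSubset are hypothetical names, and SketchSubset only stands in for torch.utils.data.Subset; this is an illustration, not MONAI's implementation.

```python
from typing import Callable, Optional, Sequence


class SketchSubset:
    """Hypothetical stand-in for torch.utils.data.Subset: a lazy view over chosen indices."""

    def __init__(self, dataset, indices: Sequence[int]):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int):
        # Delegate to the parent dataset so the transform still runs at fetch time.
        return self.dataset[self.indices[i]]


class SketchDataset:
    """Hypothetical minimal map-style dataset: keeps raw items, transforms on access."""

    def __init__(self, data: Sequence, transform: Optional[Callable] = None):
        self.data = data
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Turn the slice into concrete indices, as a Subset would store them.
            return SketchSubset(self, range(len(self))[index])
        if isinstance(index, (list, tuple)):
            return SketchSubset(self, index)
        item = self.data[index]
        return self.transform(item) if self.transform is not None else item


items = [{"img": f"image{i}.nii.gz"} for i in (1, 2, 3)]
ds = SketchDataset(items, transform=lambda d: {**d, "loaded": True})
subset = ds[1:3]
print(len(subset), subset[0])
```

Because the transform runs on every access, the stored items are never modified, which is what allows randomized transforms to produce a fresh result each epoch.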

IterableDataset

class monai.data.IterableDataset(data, transform=None)

A generic dataset for an iterable data source, with an optional callable data transform applied when a data sample is fetched. It inherits from PyTorch IterableDataset: https://pytorch.org/docs/stable/data.html?highlight=iterabledataset#torch.utils.data.IterableDataset. For example, typical input data can be a web data stream that supports multi-process access.

To accelerate loading, it supports multi-processing with PyTorch DataLoader workers: each worker process executes the transforms on its own share of the loaded data. Note that in multi-processing mode the order of the output data may not match the data source. Also, each worker process holds a separate copy of the dataset object, so the data source or the DataLoader must guarantee process safety.

__init__(data, transform=None)
Parameters
  • data (Iterable) – input data source to load and transform to generate dataset for model.

  • transform (Optional[Callable]) – a callable data transform on input data.
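The multi-worker behavior described above can be illustrated with round-robin sharding, one plausible scheme assumed here: each worker iterates its own copy of the source and keeps only the items whose position modulo the number of workers equals its worker id. The function shard_stream is hypothetical; in a real DataLoader the worker id and count would come from torch.utils.data.get_worker_info().

```python
from typing import Callable, Iterable, Iterator, Optional


def shard_stream(data: Iterable, worker_id: int, num_workers: int,
                 transform: Optional[Callable] = None) -> Iterator:
    """Yield only the items assigned to worker_id, round-robin by stream position."""
    for position, item in enumerate(data):
        if position % num_workers == worker_id:
            # Only this worker pays the transform cost for its share of the stream.
            yield transform(item) if transform is not None else item


# Simulate 3 workers; each re-iterates its own copy of the (re-iterable) source.
source = range(7)
shards = [list(shard_stream(source, w, 3, transform=lambda x: x * 10)) for w in range(3)]
print(shards)
```

Since every worker skips the other workers' items, each one must be able to read the whole stream from its own copy of the source. The order in which a loader collects the per-worker outputs depends on scheduling, which is why the output order may differ from the source.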

DatasetFunc

class monai.data.DatasetFunc(data, func, **kwargs)

Executes a function on the input data and uses its output as the items of a new Dataset. It can be used to load or fetch the basic dataset items, such as a list of image and label paths, or chained to run more complex logic, such as partition_dataset, resample_datalist, etc. The data argument is passed to func as its first positional argument. Usage example:

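import torch
import monai
from monai.data import Dataset, DatasetFunc

data_list = DatasetFunc(
    data="path to file",
    func=monai.data.load_decathlon_datalist,
    data_list_key="validation",
    base_dir="path to base dir",
)
# partition dataset for every rank (requires an initialized torch.distributed process group);
# `data` is passed to func as its first positional argument, so the lambda must accept it
data_partition = DatasetFunc(
    data=data_list,
    func=lambda data, **kwargs: monai.data.partition_dataset(data, **kwargs)[torch.distributed.get_rank()],
    num_partitions=torch.distributed.get_world_size(),
)
# `transforms` is the user's preprocessing pipeline, e.g. a monai.transforms.Compose
dataset = Dataset(data=data_partition, transform=transforms)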
Parameters
  • data (Any) – input data for func to process; it is passed to func as the first argument.

  • func (Callable) – callable function to generate dataset items.

  • kwargs – other arguments for the func except for the first arg.

reset(data=None, func=None, **kwargs)

Reset the dataset items with the specified func.

Parameters
  • data (Optional[Any]) – if not None, execute func on it; defaults to self.src.

  • func (Optional[Callable]) – if not None, execute this func with the specified kwargs; defaults to self.func.

  • kwargs – other arguments for the func except for the first arg.
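The calling convention described above, with data passed to func as the first positional argument and the remaining keyword arguments forwarded, can be mimicked in a few lines. SketchDatasetFunc and keep_every are hypothetical names; the sketch mirrors the parameter descriptions of __init__ and reset() rather than MONAI's source.

```python
from typing import Any, Callable, Optional


class SketchDatasetFunc:
    """Hypothetical analogue of DatasetFunc: items = func(src, **kwargs)."""

    def __init__(self, data: Any, func: Callable, **kwargs):
        self.src = data
        self.func = func
        self.kwargs = kwargs
        self.reset()

    def reset(self, data: Optional[Any] = None, func: Optional[Callable] = None, **kwargs):
        src = self.src if data is None else data
        if func is None:
            # No new func: reuse the stored func together with its stored kwargs.
            self.data = self.func(src, **self.kwargs)
        else:
            self.data = func(src, **kwargs)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        return self.data[index]


def keep_every(items, step: int):
    """Hypothetical helper: keep every step-th item."""
    return list(items)[::step]


base = SketchDatasetFunc(data=list(range(10)), func=keep_every, step=2)
# Chaining: the first instance becomes the data of the second one.
chained = SketchDatasetFunc(data=base, func=keep_every, step=2)
print(base.data, chained.data)
base.reset(data=list(range(4)))  # re-run the stored func and kwargs on new data
print(base.data)
```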

ShuffleBuffer

class monai.data.ShuffleBuffer(data, transform=None, buffer_size=512, seed=0)

Extends IterableDataset with a buffer from which items are popped at random.

Parameters
  • data – input data source to load and transform to generate dataset for model.

  • transform – a callable data transform on input data.

  • buffer_size (int) – size of the buffer that stores items to be randomly popped, defaults to 512.

  • seed (int) – random seed to initialize the random state of all workers; seed is incremented by 1 on every iter() call, following the PyTorch approach: https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98.

Note

Neither monai.data.DataLoader nor torch.utils.data.DataLoader seeds this class (a subclass of IterableDataset) at run time. The persistent_workers=True flag (with PyTorch > 1.8) is therefore required for loading over multiple epochs when num_workers > 0. For example:

import monai

def run():
    dss = monai.data.ShuffleBuffer([1, 2, 3, 4], buffer_size=30, seed=42)

    dataloader = monai.data.DataLoader(
        dss, batch_size=1, num_workers=2, persistent_workers=True)
    for epoch in range(3):
        for item in dataloader:
            print(f"epoch: {epoch} item: {item}.")

if __name__ == '__main__':
    run()

generate_item()

Fill a buffer list up to self.size, then generate randomly popped items.
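The fill-then-pop procedure described above can be sketched as a generator over any iterable. This is a standalone illustration assuming Python's random.Random for the random state and an O(1) swap-with-last pop; shuffle_buffer is a hypothetical name, not ShuffleBuffer's actual code.

```python
import random
from typing import Iterable, Iterator, List


def _random_pop(buffer: List, rng: random.Random):
    """Remove and return a random element in O(1) by swapping it to the end first."""
    idx = rng.randrange(len(buffer))
    buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
    return buffer.pop()


def shuffle_buffer(data: Iterable, buffer_size: int, seed: int = 0) -> Iterator:
    """Approximately shuffle a stream while holding at most buffer_size items."""
    rng = random.Random(seed)
    buffer: List = []
    for item in data:
        if len(buffer) >= buffer_size:
            # Buffer full: emit one random item before accepting the new one.
            yield _random_pop(buffer, rng)
        buffer.append(item)
    # Source exhausted: drain the remaining items, still in random order.
    while buffer:
        yield _random_pop(buffer, rng)


print(list(shuffle_buffer(range(10), buffer_size=4, seed=42)))
```

A larger buffer_size mixes items from further apart in the stream at the cost of memory; with buffer_size=1 the output order equals the input order. Following the seed parameter description, each new pass could use seed + 1 to obtain a different order per epoch.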

randomize(size)

Within this method, self.R should be used instead of np.random to introduce random factors.

All self.R calls happen here so that errors in synchronizing the random state are easier to identify.

This method can generate the random factors based on properties of the input data.

Raises

NotImplementedError – When the subclass does not override this method.

Return type

None

randomized_pop(buffer)

Return the item at a randomized location self._idx in buffer.

CSVIterableDataset

class monai.data.CSVIterableDataset(src, chunksize=1000, buffer_size=None, col_names=None, col_types=None, col_groups=None, transform=None, shuffle=False, seed=0, kwargs_read_csv=None, **kwargs)

Iterable dataset that loads CSV files and generates dictionary data. It inherits from PyTorch IterableDataset (https://pytorch.org/docs/stable/data.html?highlight=iterabledataset#torch.utils.data.IterableDataset) and is particularly useful when the data come from a stream.

It can also help when loading extremely large CSV files that cannot be read into memory at once: treat the large CSV file as a stream input and call reset() of CSVIterableDataset for every epoch. Note that, as a stream input, the length of the dataset is not available.

To shuffle the data of a large dataset effectively, users can set a large buffer that continuously stores the loaded data, then randomly pick data from the buffer for the following tasks.

To accelerate loading, it supports multi-processing with PyTorch DataLoader workers: each worker process executes the transforms on its own share of the loaded data. Note that in multi-processing mode the order of the output data may not match the data source.

It can load data from multiple CSV files and join the tables using the additional kwargs. It also supports loading only specific columns, and it can group several loaded columns into a new column. For example, with col_groups={"meta": ["meta_0", "meta_1", "meta_2"]}, the output can be:

[
    {"image": "./image0.nii", "meta_0": 11, "meta_1": 12, "meta_2": 13, "meta": [11, 12, 13]},
    {"image": "./image1.nii", "meta_0": 21, "meta_1": 22, "meta_2": 23, "meta": [21, 22, 23]},
]
Parameters
  • src (Union[str, Sequence[str], Iterable, Sequence[Iterable]]) – the filename of a CSV file, which can be a str, URL, path object or file-like object to load. An iterable can also be provided directly as stream input, in which case loading from a filename is skipped. If a list of filenames or iterables is provided, the tables are joined.

  • chunksize (int) – number of rows per chunk when loading iterable data from CSV files, defaults to 1000. For more details: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html.

  • buffer_size (Optional[int]) – size of the buffer that stores the loaded chunks; if None, it is set to 2 x chunksize.

  • col_names (Optional[Sequence[str]]) – names of the expected columns to load; if None, all columns are loaded.

  • col_types (Optional[Dict[str, Optional[Dict[str, Any]]]]) –

    type and default value used to convert the loaded columns; if None, the original data are used. It should be a dictionary in which every item maps to an expected column: the key is the column name and the value is None or a dictionary defining the default value and data type. The supported keys in the dictionary are: ["type", "default"]. For example:

    col_types = {
        "subject_id": {"type": str},
        "label": {"type": int, "default": 0},
        "ehr_0": {"type": float, "default": 0.0},
        "ehr_1": {"type": float, "default": 0.0},
        "image": {"type": str, "default": None},
    }
    

  • col_groups (Optional[Dict[str, Sequence[str]]]) – args to group the loaded columns into new columns. It should be a dictionary in which every item maps to a group: the key is the new column name and the value is the names of the columns to combine. For example: col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}

  • transform (Optional[Callable]) – transform to apply on the loaded items of a dictionary data.

  • shuffle (bool) – whether to shuffle all the data in the buffer every time a new chunk is loaded.

  • seed (int) – random seed to initialize the random state of all the workers if shuffle is True; seed is incremented by 1 on every iter() call, following the PyTorch approach: https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98.

  • kwargs_read_csv (Optional[Dict]) – dictionary args to pass to the pandas read_csv function. Defaults to {"chunksize": chunksize}.

  • kwargs – additional arguments for pandas.merge() API to join tables.

Deprecated since version 0.8.0: filename is deprecated, use src instead.
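To make the col_types and col_groups semantics concrete, the sketch below applies them to rows read with Python's standard csv module from a small made-up table. convert_row is a hypothetical helper; CSVIterableDataset itself reads chunks through pandas, so details such as how missing values are detected may differ.

```python
import csv
import io
from typing import Any, Dict, Optional, Sequence


def convert_row(row: Dict[str, str],
                col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
                col_groups: Optional[Dict[str, Sequence[str]]] = None) -> Dict[str, Any]:
    """Apply col_types-style conversion and col_groups-style grouping to one CSV row."""
    out: Dict[str, Any] = dict(row)
    for name, spec in (col_types or {}).items():
        if spec is None:
            continue  # None: keep the original value unchanged
        value = out.get(name, "")
        if value in ("", None):
            # Missing cell: use the configured default (None if no default is given).
            out[name] = spec.get("default")
        elif "type" in spec:
            out[name] = spec["type"](value)
    for group, members in (col_groups or {}).items():
        # The new column collects the member values in the listed order.
        out[group] = [out[m] for m in members]
    return out


text = "image,meta_0,meta_1,meta_2\n./image0.nii,11,12,13\n./image1.nii,21,,23\n"
types = {"meta_0": {"type": int}, "meta_1": {"type": int, "default": 0}, "meta_2": {"type": int}}
groups = {"meta": ["meta_0", "meta_1", "meta_2"]}
rows = [convert_row(r, types, groups) for r in csv.DictReader(io.StringIO(text))]
print(rows)
```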

close()

Close the pandas TextFileReader iterable objects. If the input src is a file path, the TextFileReader was created internally and needs to be closed. If the input src is an iterable object, whether to close it in this function depends on the user's requirements. For more details, please check: https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration.

reset(src=None)

Reset the pandas TextFileReader iterable object to read data. For more details, please check: https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration.

Parameters

src (Union[str, Sequence[str], Iterable, Sequence[Iterable], None]) – if not None, the filename of a CSV file, which can be a str, URL, path object or file-like object to load. An iterable can also be provided directly as stream input, in which case loading from a filename is skipped. If a list of filenames or iterables is provided, the tables are joined. Defaults to self.src.

PersistentDataset

class monai.data.PersistentDataset(data, transform, cache_dir, hash_func=<function pickle_hashing>, pickle_module='pickle', pickle_protocol=2, hash_transform=None, reset_ops_id=True)

Persistent storage of pre-computed values to efficiently manage larger-than-memory data in dictionary format; transforms can operate on specific fields. Results from the non-random transform components are computed when first used and stored in cache_dir for rapid retrieval on subsequent uses. If slicing indices are passed, a PyTorch Subset is returned, for example: data: Subset = dataset[1:4]. For more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset

The transforms which are supposed to be cached must implement the monai.transforms.Transform interface and should not be Randomizable. This dataset will cache the outcomes before the first Randomizable Transform within a Compose instance.

For example, typical input data can be a list of dictionaries:

[
    {'image': 'image1.nii.gz', 'label': 'label1.nii.gz', 'extra': 123},
    {'image': 'image2.nii.gz', 'label': 'label2.nii.gz', 'extra': 456},
    {'image': 'image3.nii.gz', 'label': 'label3.nii.gz', 'extra': 789},
]

For a composite transform like

from monai.transforms import (
    LoadImaged, Orientationd, RandCropByPosNegLabeld, ScaleIntensityRanged, ToTensord,
)

transform = [
    LoadImaged(keys=['image', 'label']),
    Orientationd(keys=['image', 'label'], axcodes='RAS'),
    ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
    RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', spatial_size=(96, 96, 96),
                           pos=1, neg=1, num_samples=4, image_key='image', image_threshold=0),
    ToTensord(keys=['image', 'label']),
]

Upon first use, a filename-based dataset is processed by the transforms [LoadImaged, Orientationd, ScaleIntensityRanged] and the resulting tensors are written to cache_dir, before the remaining transforms [RandCropByPosNegLabeld, ToTensord], starting with the first random one, are applied for use in the analysis.

Subsequent uses of the dataset read the pre-processed results directly from cache_dir and then apply only the remaining, random-dependent part of the transform chain.

During training, call set_data() to update the input data and recompute the cache content.

Note

The input data must be a list of file paths; they are hashed to form the cache keys.

The filenames of the cached files also try to include a hash of the transforms, so PersistentDataset should be robust to changes in transforms. This is not guaranteed, however, so take care when modifying transforms to avoid unexpected errors. If in doubt, clear the cache directory.
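The first-use and subsequent-use flow described above can be approximated as follows: the deterministic prefix of the transform chain (everything before the first random transform) runs once per item, its result is written to a cache directory under a hash of the item, and later accesses load that file and run only the remaining transforms. This is a hypothetical sketch: the is_random attribute stands in for MONAI's Randomizable check, the md5-of-pickle key is only in the spirit of pickle_hashing, and the file layout is not PersistentDataset's real on-disk format.

```python
import hashlib
import pickle
import tempfile
from pathlib import Path
from typing import Any, Callable, List, Sequence, Tuple


def split_at_first_random(transforms: Sequence[Callable]) -> Tuple[List[Callable], List[Callable]]:
    """Return (deterministic prefix to cache, remaining transforms to run every time)."""
    for i, t in enumerate(transforms):
        if getattr(t, "is_random", False):  # hypothetical marker for random transforms
            return list(transforms[:i]), list(transforms[i:])
    return list(transforms), []


class SketchPersistentCache:
    """Hypothetical disk cache keyed by a hash of each input item."""

    def __init__(self, transforms: Sequence[Callable], cache_dir: Path):
        self.prefix, self.rest = split_at_first_random(transforms)
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def key(item: Any) -> str:
        # Sort dict items so that equal items produce equal hashes.
        payload = sorted(item.items()) if isinstance(item, dict) else item
        return hashlib.md5(pickle.dumps(payload, protocol=2)).hexdigest()

    def __call__(self, item: Any) -> Any:
        path = self.cache_dir / f"{self.key(item)}.pkl"
        if path.exists():
            with path.open("rb") as f:  # cache hit: skip the deterministic prefix
                value = pickle.load(f)
        else:
            value = item
            for t in self.prefix:
                value = t(value)
            with path.open("wb") as f:  # cache miss: compute once and persist
                pickle.dump(value, f, protocol=2)
        for t in self.rest:  # the random part always runs
            value = t(value)
        return value


calls: List[str] = []


def load(item):
    calls.append(item["image"])  # record how often the expensive step runs
    return {**item, "pixels": [1, 2, 3]}


class Reverse:
    is_random = True  # pretend this is a random transform

    def __call__(self, item):
        return {**item, "pixels": item["pixels"][::-1]}


with tempfile.TemporaryDirectory() as tmp:
    cache = SketchPersistentCache([load, Reverse()], Path(tmp))
    first = cache({"image": "a.nii.gz"})
    second = cache({"image": "a.nii.gz"})
print(calls, first == second)
```

The expensive load step runs once for two accesses while the "random" step runs every time. Note that changing the transforms without changing the key would silently reuse stale files, which is the risk that hashing the transforms into the filenames tries to reduce.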

__init__(data, transform, cache_dir, hash_func=<function pickle_hashing>, pickle_module='pickle', pickle_protocol=2, hash_transform=None, reset_ops_id=True)
Parameters
  • data (Sequence) – input data file paths to load and transform to generate the dataset for the model. PersistentDataset expects the input data to be a list of serializable items and hashes them as cache keys using hash_func.

  • transform (Union[Sequence[Callable], Callable]) – transforms to execute operations on input data.

  • cache_dir (Union[Path, str, None]) – if specified, the location for persistent storage of pre-computed transformed data tensors. The cache content is computed once and persists on disk until explicitly removed. Different runs, programs and experiments may share a common cache dir provided that the transform pre-processing is consistent. If cache_dir does not exist, it is created automatically. If cache_dir is None, there is effectively no caching.

  • hash_func (Callable[…, bytes]) – a callable to compute hash from data items to be cached. defaults to monai.data.utils.pickle_hashing.

  • pickle_module (str) – string naming the module used for pickling metadata and objects, defaults to "pickle". Because of pickle limitations with DataLoader multi-processing, the module cannot be passed directly as an arg, so a string name is used instead. To use another pickle module at runtime, register it like:

    from monai.data import utils
    utils.SUPPORTED_PICKLE_MOD["test"] = other_pickle

    This arg is used by torch.save; for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, and monai.data.utils.SUPPORTED_PICKLE_MOD.

  • pickle_protocol (int) – can be specified to override the default protocol, defaults to 2. This arg is used by torch.save; for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.

  • hash_transform (Optional[Callable[…, bytes]]) – a callable to compute a hash from the transform information when caching. This may reduce errors due to transforms changing during experiments. Defaults to None (no hash). Other options are the pickle_hashing and json_hashing functions from monai.data.utils.

  • reset_ops_id (bool) – whether to set TraceKeys.ID to TraceKeys.NONE, defaults to True. When this is enabled, the traced transform instance IDs are removed from the cached MetaTensors. This is useful for skipping the transform instance checks when inverting applied operations using the cached content and re-created transform instances.

set_data(data)

Set the input data and delete all the outdated cache content.

set_transform_hash(hash_xform_func)

Get the hashable transforms, and then hash them. Hashable transforms are deterministic transforms that inherit from Transform. Collection stops at the first non-deterministic transform, or at the first one that does not inherit from MONAI's Transform class.

CacheNTransDataset

class monai.data.CacheNTransDataset(data, transform, cache_n_trans, cache_dir, hash_func=<function pickle_hashing>, pickle_module='pickle', pickle_protocol=2, hash_transform=None, reset_ops_id=True)

Extension of PersistentDataset that can also cache the result of the first N transforms, regardless of whether they are random.

__init__(data, transform, cache_n_trans, cache_dir, hash_func=<function pickle_hashing>, pickle_module='pickle', pickle_protocol=2, hash_transform=None, reset_ops_id=True)
Parameters
  • data (Sequence) – input data file paths to load and transform to generate the dataset for the model. CacheNTransDataset, like PersistentDataset, expects the input data to be a list of serializable items and hashes them as cache keys using hash_func.

  • transform (Union[Sequence[Callable], Callable]) – transforms to execute operations on input data.

  • cache_n_trans (int) – cache the result of first N transforms.

  • cache_dir (Union[Path, str, None]) – if specified, the location for persistent storage of pre-computed transformed data tensors. The cache content is computed once and persists on disk until explicitly removed. Different runs, programs and experiments may share a common cache dir provided that the transform pre-processing is consistent. If cache_dir does not exist, it is created automatically. If cache_dir is None, there is effectively no caching.

  • hash_func (Callable[…, bytes]) – a callable to compute hash from data items to be cached. defaults to monai.data.utils.pickle_hashing.

  • pickle_module (str) – string naming the module used for pickling metadata and objects, defaults to "pickle". Because of pickle limitations with DataLoader multi-processing, the module cannot be passed directly as an arg, so a string name is used instead. To use another pickle module at runtime, register it like:

    from monai.data import utils
    utils.SUPPORTED_PICKLE_MOD["test"] = other_pickle

    This arg is used by torch.save; for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, and monai.data.utils.SUPPORTED_PICKLE_MOD.

  • pickle_protocol (int) – can be specified to override the default protocol, defaults to 2. This arg is used by torch.save; for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.

  • hash_transform (Optional[Callable[…, bytes]]) – a callable to compute a hash from the transform information when caching. This may reduce errors due to transforms changing during experiments. Defaults to None (no hash). Other options are the pickle_hashing and json_hashing functions from monai.data.utils.

  • reset_ops_id (bool) – whether to set TraceKeys.ID to TraceKeys.NONE, defaults to True. When this is enabled, the traced transform instance IDs are removed from the cached MetaTensors. This is useful for skipping the transform instance checks when inverting applied operations using the cached content and re-created transform instances.
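The difference between the two caching rules can be shown with transform names as plain strings: PersistentDataset caches up to (but not including) the first random transform, while CacheNTransDataset caches exactly the first cache_n_trans transforms. The helpers and the name-based is_rand check below are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple


def persistent_split(names: Sequence[str],
                     is_random: Callable[[str], bool]) -> Tuple[List[str], List[str]]:
    """PersistentDataset-style rule: cache everything before the first random transform."""
    for i, name in enumerate(names):
        if is_random(name):
            return list(names[:i]), list(names[i:])
    return list(names), []


def cache_n_split(names: Sequence[str], n: int) -> Tuple[List[str], List[str]]:
    """CacheNTransDataset-style rule: cache the first n transforms, random or not."""
    return list(names[:n]), list(names[n:])


chain = ["LoadImaged", "RandFlipd", "ScaleIntensityd", "ToTensord"]


def is_rand(name: str) -> bool:
    return name.startswith("Rand")  # hypothetical name-based randomness check


print(persistent_split(chain, is_rand))
print(cache_n_split(chain, 3))
```

With cache_n_trans=3, the result of RandFlipd is computed once and reused, so that particular flip is effectively frozen across epochs; this is the trade-off implied by caching transforms whether or not they are random.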

LMDBDataset

class monai.data.LMDBDataset(data, transform, cache_dir='cache', hash_func=<function pickle_hashing>, db_name='monai_cache', progress=True, pickle_protocol=4, hash_transform=None, reset_ops_id=True, lmdb_kwargs=None)

Extension of PersistentDataset using LMDB as the backend.

Examples

>>> import monai
>>> items = [{"data": i} for i in range(5)]
>>> # [{'data': 0}, {'data': 1}, {'data': 2}, {'data': 3}, {'data': 4}]
>>> lmdb_ds = monai.data.LMDBDataset(items, transform=monai.transforms.SimulateDelayd("data", delay_time=1))
>>> print(list(lmdb_ds))  # using the cached results

__init__(data, transform, cache_dir='cache', hash_func=<function pickle_hashing>, db_name='monai_cache', progress=True, pickle_protocol=4, hash_transform=None, reset_ops_id=True, lmdb_kwargs=None)
Parameters
  • data (Sequence) – input data file paths to load and transform to generate the dataset for the model. LMDBDataset expects the input data to be a list of serializable items and hashes them as cache keys using hash_func.

  • transform (Union[Sequence[Callable], Callable]) – transforms to execute operations on input data.

  • cache_dir (Union[Path, str]) – if specified, the location for persistent storage of pre-computed transformed data tensors. The cache content is computed once and persists on disk until explicitly removed. Different runs, programs and experiments may share a common cache dir provided that the transform pre-processing is consistent. If cache_dir does not exist, it is created automatically. Defaults to "./cache".

  • hash_func (Callable[…, bytes]) – a callable to compute hash from data items to be cached. defaults to monai.data.utils.pickle_hashing.

  • db_name (str) – lmdb database file name. Defaults to “monai_cache”.

  • progress (bool) – whether to display a progress bar.

  • pickle_protocol – pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL. https://docs.python.org/3/library/pickle.html#pickle-protocols

  • hash_transform (Optional[Callable[…, bytes]]) – a callable to compute hash from the transform information when caching. This may reduce errors due to transforms changing during experiments. Default to None (no hash). Other options are pickle_hashing and json_hashing functions from monai.data.utils.

  • reset_ops_id (bool) – whether to set TraceKeys.ID to TraceKeys.NONE, defaults to True. When this is enabled, the traced transform instance IDs are removed from the cached MetaTensors. This is useful for skipping the transform instance checks when inverting applied operations using the cached content and re-created transform instances.

  • lmdb_kwargs (Optional[dict]) – additional keyword arguments to the lmdb environment. for more details please visit: https://lmdb.readthedocs.io/en/release/#environment-class

info()

Returns: dataset info dictionary.

set_data(data)

Set the input data and delete all the outdated cache content.

CacheDataset

class monai.data.CacheDataset(data, transform=None, cache_num=9223372036854775807, cache_rate=1.0, num_workers=1, progress=True, copy_cache=True, as_contiguous=True, hash_as_key=False, hash_func=<function pickle_hashing>)

Dataset with a cache mechanism that can load data and cache the results of deterministic transforms during training.

By caching the results of non-random preprocessing transforms, it accelerates the training data pipeline. If the requested data is not in the cache, all transforms will run normally (see also monai.data.dataset.Dataset).

Users can set the cache rate or number of items to cache. It is recommended to experiment with different cache_num or cache_rate to identify the best training speed.

The transforms which are supposed to be cached must implement the monai.transforms.Transform interface and should not be Randomizable. This dataset caches the outcomes before the first Randomizable Transform within a Compose instance, so to improve caching efficiency, place as many non-random transforms as possible before the randomized ones when composing the transform chain. If slicing indices are passed, a PyTorch Subset is returned, for example: data: Subset = dataset[1:4]. For more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset

For example, if the transform is a Compose of:

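from monai.transforms import (
    Compose, EnsureChannelFirstd, LoadImaged, Orientationd, RandCropByPosNegLabeld,
    ScaleIntensityRanged, Spacingd, ToTensord,
)

keys = ["image", "label"]  # argument values below are illustrative
transforms = Compose([
    LoadImaged(keys=keys),
    EnsureChannelFirstd(keys=keys),
    Spacingd(keys=keys, pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    Orientationd(keys=keys, axcodes="RAS"),
    ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
    RandCropByPosNegLabeld(keys=keys, label_key="label", spatial_size=(96, 96, 96)),
    ToTensord(keys=keys),
])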

When transforms is used in a multi-epoch training pipeline, this dataset caches the results up to ScaleIntensityRanged before the first training epoch, as all the non-random transforms LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd and ScaleIntensityRanged can be cached. During training, the dataset loads the cached results and runs RandCropByPosNegLabeld and ToTensord, as RandCropByPosNegLabeld is a randomized transform and its outcome is not cached.

During training, call set_data() to update the input data and recompute the cache content; note that this requires persistent_workers=False in the PyTorch DataLoader.

Note

CacheDataset executes the non-random transforms and prepares the cache content in the main process before the first epoch; all DataLoader subprocesses then read the same cache content from the main process during training. Preparing the cache content may take a long time, depending on the size of the data to cache. So to debug or verify the program before real training, users can set cache_rate=0.0 or cache_num=0 to temporarily skip caching.

__init__(data, transform=None, cache_num=9223372036854775807, cache_rate=1.0, num_workers=1, progress=True, copy_cache=True, as_contiguous=True, hash_as_key=False, hash_func=<function pickle_hashing>)
Parameters
  • data (Sequence) – input data to load and transform to generate dataset for model.

  • transform (Union[Sequence[Callable], Callable, None]) – transforms to execute operations on input data.

  • cache_num (int) – number of items to be cached. Default is sys.maxsize. will take the minimum of (cache_num, data_length x cache_rate, data_length).

  • cache_rate (float) – percentage of cached data in total, default is 1.0 (cache all). will take the minimum of (cache_num, data_length x cache_rate, data_length).

  • num_workers (Optional[int]) – the number of worker threads to use. If num_workers is None, the number returned by os.cpu_count() is used. If a value less than 1 is specified, 1 is used instead.

  • progress (bool) – whether to display a progress bar.

  • copy_cache (bool) – whether to deepcopy the cache content before applying the random transforms, defaults to True. If the random transforms don't modify the cached content (for example, randomly crop from the cached image and deepcopy the crop region), or if every cache item is only used once in a multi-processing environment, copy_cache=False may be set for better performance.

  • as_contiguous (bool) – whether to convert the cached NumPy array or PyTorch tensor to be contiguous. It may improve the performance of subsequent processing.

  • hash_as_key (bool) – whether to use a hash of the input data as the cache key; if the key already exists, duplicated content is not saved again. This can save memory when the dataset contains duplicated items or is an augmented dataset.

  • hash_func (Callable[…, bytes]) – if hash_as_key, a callable to compute hash from data items to be cached. defaults to monai.data.utils.pickle_hashing.
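Per the cache_num and cache_rate descriptions, the number of cached items is the minimum of cache_num, data_length x cache_rate and data_length. A worked sketch, assuming the product is truncated with int() (the exact rounding MONAI applies is an assumption here); effective_cache_num is a hypothetical helper.

```python
import sys


def effective_cache_num(data_length: int, cache_num: int = sys.maxsize,
                        cache_rate: float = 1.0) -> int:
    """Items to cache: min(cache_num, data_length x cache_rate, data_length)."""
    # int() truncates data_length * cache_rate toward zero (assumed rounding).
    return min(int(cache_num), int(data_length * cache_rate), data_length)


for num, rate in [(sys.maxsize, 1.0), (sys.maxsize, 0.25), (3, 1.0), (100, 0.5), (0, 1.0)]:
    print(num, rate, effective_cache_num(10, num, rate))
```

Setting cache_num=0 or cache_rate=0.0 gives 0 cached items, which matches the debugging tip of skipping caching before real training.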

set_data(data)

Set the input data and run deterministic transforms to generate cache content.

Note: call this function after an entire epoch, and set persistent_workers=False in the PyTorch DataLoader, because new worker processes must be created based on the newly generated cache content.

SmartCacheDataset

class monai.data.SmartCacheDataset(data, transform=None, replace_rate=0.1, cache_num=9223372036854775807, cache_rate=1.0, num_init_workers=1, num_replace_workers=1, progress=True, shuffle=True, seed=0, copy_cache=True, as_contiguous=True)

Re-implementation of the SmartCache mechanism in NVIDIA Clara-train SDK. At any time, the cache pool only keeps a subset of the whole dataset. In each epoch, only the items in the cache are used for training. This ensures that data needed for training is readily available, keeping GPU resources busy. Note that cached items may still have to go through a non-deterministic transform sequence before being fed to the GPU. At the same time, another thread prepares replacement items by applying the transform sequence to items not in the cache. Once an epoch is completed, Smart Cache replaces R cached items with the prepared replacement items. Smart Cache uses a simple running window algorithm to determine the cache content and replacement items. Let N be the configured number of objects in cache, and R the number of replacement objects (R = ceil(N * r), where r is the configured replace rate). For more details, please refer to: https://docs.nvidia.com/clara/tlt-mi/clara-train-sdk-v3.0/nvmidl/additional_features/smart_cache.html#smart-cache

If slicing indices are passed, a PyTorch Subset is returned, for example: data: Subset = dataset[1:4]. For more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset

For example, with 5 images [image1, image2, image3, image4, image5], cache_num=4, replace_rate=0.25 (so R = ceil(4 * 0.25) = 1) and shuffle=False, the training images cached for each epoch are:

epoch 1: [image1, image2, image3, image4]
epoch 2: [image2, image3, image4, image5]
epoch 3: [image3, image4, image5, image1]
epoch 4: [image4, image5, image1, image2]
epoch N: [image((N - 1) % 5 + 1), ...]
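The running window in this example can be reproduced by shifting a fixed-size window by R positions per epoch. smart_cache_window is a hypothetical helper that assumes shuffle=False and ignores the background replacement thread; it only computes which items the window covers. Epochs are 0-based in the sketch, so epoch=0 corresponds to "epoch 1" in the listing.

```python
import math
from typing import List, Sequence


def smart_cache_window(items: Sequence, cache_num: int, replace_rate: float, epoch: int) -> List:
    """Items held in the cache at a 0-based epoch under a simple running window."""
    n = min(cache_num, len(items))        # N: number of cached items
    r = math.ceil(n * replace_rate)       # R = ceil(N * r) items replaced per epoch
    start = (epoch * r) % len(items)      # the window advances by R each epoch
    return [items[(start + i) % len(items)] for i in range(n)]


images = [f"image{i}" for i in range(1, 6)]
for epoch in range(5):
    print(f"epoch {epoch + 1}: {smart_cache_window(images, 4, 0.25, epoch)}")
```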

Using SmartCacheDataset involves 4 steps:

  1. Initialize SmartCacheDataset object and cache for the first epoch.

  2. Call start() to run replacement thread in background.

  3. Call update_cache() before every epoch to replace training items.

  4. Call shutdown() when training ends.

To update the input data and recompute the cache content during training, first call shutdown() to stop the replacement thread, then call set_data(), and call start() to restart it.

Note

This replacement will not work in the following cases:

  1. The multiprocessing_context of DataLoader is set to spawn.

  2. Distributed data parallel is launched with torch.multiprocessing.spawn.

  3. Running on Windows (where the default multiprocessing method is spawn) with num_workers greater than 0.

  4. The persistent_workers of DataLoader is set to True with num_workers greater than 0.

If using MONAI workflows, please add SmartCacheHandler to the handler list of trainer, otherwise, please make sure to call start(), update_cache(), shutdown() during training.

Parameters
  • data (Sequence) – input data to load and transform to generate dataset for model.

  • transform (Union[Sequence[Callable], Callable, None]) – transforms to execute operations on input data.

  • replace_rate (float) – percentage of the cached items to be replaced in every epoch (defaults to 0.1).

  • cache_num (int) – number of items to be cached. Default is sys.maxsize. will take the minimum of (cache_num, data_length x cache_rate, data_length).

  • cache_rate (float) – percentage of cached data in total, default is 1.0 (cache all). will take the minimum of (cache_num, data_length x cache_rate, data_length).

  • num_init_workers (Optional[int]) – the number of worker threads to initialize the cache for the first epoch. If num_init_workers is None, the number returned by os.cpu_count() is used. If a value less than 1 is specified, 1 is used instead.

  • num_replace_workers (Optional[int]) – the number of worker threads to prepare the replacement cache for every epoch. If num_replace_workers is None, the number returned by os.cpu_count() is used. If a value less than 1 is specified, 1 is used instead.

  • progress (bool) – whether to display a progress bar when caching for the first epoch.

  • shuffle (bool) – whether to shuffle the whole data list before preparing the cache content for the first epoch. It does not modify the original input data sequence in-place.

  • seed (int) – random seed if shuffle is True, defaults to 0.

  • copy_cache (bool) – whether to deepcopy the cache content before applying the random transforms, defaults to True. If the random transforms don't modify the cache content, or if every cache item is only used once in a multi-processing environment, copy_cache=False may be set for better performance.

  • as_contiguous (bool) – whether to convert the cached NumPy array or PyTorch tensor to be contiguous. It may improve the performance of subsequent processing.

is_started()[source]#

Check whether the replacement thread is already started.

manage_replacement()[source]#

Background thread for replacement.

randomize(data)[source]#

Within this method, self.R should be used, instead of np.random, to introduce random factors.

All self.R calls happen here so that errors in synchronizing the random state are easier to identify.

This method can generate the random factors based on properties of the input data.

Raises

NotImplementedError – When the subclass does not override this method.

Return type

None
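The self.R convention described in randomize() can be illustrated with a small stand-in transform. This is a hypothetical class, not MONAI code: it uses Python's random.Random as the private generator where MONAI uses a NumPy random state, and it keeps every draw inside randomize() so that two identically seeded instances stay in sync.

```python
import random
from typing import List, Optional, Sequence


class ToyRandomShift:
    """Hypothetical transform that follows the self.R convention (not part of MONAI)."""

    def __init__(self, max_offset: float, seed: Optional[int] = None) -> None:
        self.max_offset = max_offset
        self.R = random.Random(seed)  # private generator instead of the global `random` state
        self._offset = 0.0

    def set_random_state(self, seed: int) -> "ToyRandomShift":
        self.R = random.Random(seed)
        return self

    def randomize(self, data: Optional[Sequence[float]] = None) -> None:
        # All self.R calls live here, so the random stream is easy to audit.
        self._offset = self.R.uniform(-self.max_offset, self.max_offset)

    def __call__(self, values: Sequence[float]) -> List[float]:
        self.randomize(values)
        return [v + self._offset for v in values]


a = ToyRandomShift(1.0, seed=0)
b = ToyRandomShift(1.0, seed=0)
print(a([1.0, 2.0]) == b([1.0, 2.0]))  # True: same seed, same random factors
```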

set_data(data)[source]#

Set the input data and run deterministic transforms to generate cache content.

Note: shutdown() should be called before calling this method.

shutdown()[source]#

Shut down the background thread for replacement.

start()[source]#

Start the background thread to replace training items for every epoch.

update_cache()[source]#

Update the cache items for the current epoch; this method needs to be called before every epoch. If the cache has been shut down before, the _replace_mgr thread needs to be restarted.
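To make the per-epoch replacement concrete, the following single-threaded toy keeps a window of cache_num transformed items and, on each update_cache(), drops the oldest replace_num items and appends the next ones from the data list, wrapping around at the end. It is a simplified, hypothetical sketch (ToySmartCache is not a MONAI class): the real SmartCacheDataset prepares replacements in a background thread between start() and shutdown().

```python
from collections import deque
from typing import Any, Callable, Deque, List, Sequence


class ToySmartCache:
    """Hypothetical, synchronous illustration of a rolling cache (not MONAI's implementation)."""

    def __init__(self, data: Sequence[Any], transform: Callable[[Any], Any],
                 cache_num: int, replace_num: int) -> None:
        self.data = list(data)
        self.transform = transform
        self.replace_num = replace_num
        # Cache the first `cache_num` items for the first epoch.
        self.cache: Deque[Any] = deque(transform(x) for x in self.data[:cache_num])
        self._next = cache_num % len(self.data)  # index of the next item to bring in

    def update_cache(self) -> None:
        # Replace the oldest items with the next ones, cycling through the data.
        for _ in range(self.replace_num):
            self.cache.popleft()
            self.cache.append(self.transform(self.data[self._next]))
            self._next = (self._next + 1) % len(self.data)

    def epoch_items(self) -> List[Any]:
        return list(self.cache)


cache = ToySmartCache(["img1", "img2", "img3", "img4", "img5"], str.upper, cache_num=4, replace_num=1)
for epoch in range(3):
    print(epoch + 1, cache.epoch_items())
    cache.update_cache()
# 1 ['IMG1', 'IMG2', 'IMG3', 'IMG4']
# 2 ['IMG2', 'IMG3', 'IMG4', 'IMG5']
# 3 ['IMG3', 'IMG4', 'IMG5', 'IMG1']
```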

ZipDataset#

class monai.data.ZipDataset(datasets, transform=None)[source]#

Zip several PyTorch datasets and output the data with the same index together in a tuple. If the output of a single dataset is already a tuple, it is flattened and extended into the result. For example, if datasetA returns (img, imgmeta) and datasetB returns (seg, segmeta), the result is (img, imgmeta, seg, segmeta). If the datasets don’t have the same length, the minimum length is used as the length of ZipDataset. If passing slicing indices, will return a PyTorch Subset, for example: data: Subset = dataset[1:4], for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset

Examples:

>>> from monai.data import ZipDataset
>>> zip_data = ZipDataset([[1, 2, 3], [4, 5]])
>>> print(len(zip_data))
2
>>> for item in zip_data:
...     print(item)
(1, 4)
(2, 5)
__init__(datasets, transform=None)[source]#
Parameters
  • datasets (Sequence) – list of datasets to zip together.

  • transform (Optional[Callable]) – a callable data transform that operates on the zipped item from datasets.
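The zipping and flattening rules described above can be sketched with plain sequences. zip_item and zip_len are hypothetical helpers written for this illustration, not MONAI functions; they flatten tuple or list outputs, apply an optional transform to the zipped list, and use the shortest dataset as the overall length.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


def zip_item(datasets: Sequence[Sequence[Any]], index: int,
             transform: Optional[Callable[[List[Any]], Any]] = None) -> Tuple[Any, ...]:
    out: List[Any] = []
    for ds in datasets:
        item = ds[index]
        # Flatten tuple/list outputs; wrap anything else as a single element.
        out.extend(item if isinstance(item, (tuple, list)) else [item])
    if transform is not None:
        out = transform(out)
    return tuple(out)


def zip_len(datasets: Sequence[Sequence[Any]]) -> int:
    # The zipped length is the length of the shortest dataset.
    return min(len(ds) for ds in datasets)


images = [("img0", "imgmeta0"), ("img1", "imgmeta1")]
segs = [("seg0", "segmeta0"), ("seg1", "segmeta1"), ("seg2", "segmeta2")]
print(zip_len([images, segs]))      # 2
print(zip_item([images, segs], 1))  # ('img1', 'imgmeta1', 'seg1', 'segmeta1')
```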

ArrayDataset#

class monai.data.ArrayDataset(img, img_transform=None, seg=None, seg_transform=None, labels=None, label_transform=None)[source]#

Dataset for segmentation and classification tasks based on array format input data and transforms. It ensures that the same random seeds are used in the randomized transforms defined for image, segmentation and label. The transform can be monai.transforms.Compose or any other callable object. For example, if training on NIfTI images without metadata, all transforms can be composed:

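from monai.data import ArrayDataset
from monai.transforms import Compose, EnsureChannelFirst, LoadImage, RandAdjustContrast

img_file_list = ["image1.nii.gz", "image2.nii.gz"]  # example file paths
img_transform = Compose(
    [
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        RandAdjustContrast()
    ]
)
ArrayDataset(img_file_list, img_transform=img_transform)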

If training on both images and their metadata, the array transforms cannot be composed directly, because several transforms receive multiple parameters or return multiple values. In that case, users need to define their own callable to parse the metadata from LoadImage or to pass the affine matrix to the Spacing transform:

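from monai.data import ArrayDataset
from monai.transforms import Compose, EnsureChannelFirst, LoadImage, RandAdjustContrast, Spacing


class TestCompose(Compose):
    def __call__(self, input_):
        # LoadImage(image_only=False) returns the image and its metadata
        img, metadata = self.transforms[0](input_)
        img = self.transforms[1](img)
        # pass the affine from the metadata to Spacing; this unpacking assumes a
        # MONAI release where Spacing returns (data, old_affine, new_affine)
        img, _, _ = self.transforms[2](img, metadata["affine"])
        return self.transforms[3](img), metadata


# img_file_list: the list of image file paths from the previous example
img_transform = TestCompose(
    [
        LoadImage(image_only=False),
        EnsureChannelFirst(),
        Spacing(pixdim=(1.5, 1.5, 3.0)),
        RandAdjustContrast()
    ]
)
ArrayDataset(img_file_list, img_transform=img_transform)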

Examples:

>>> ds = ArrayDataset([1, 2, 3, 4], lambda x: x + 0.1)
>>> print(ds[0])
1.1

>>> ds = ArrayDataset(img=[1, 2, 3, 4], seg=[5, 6, 7, 8])
>>> print(ds[0])
(1, 5)
__init__(img, img_transform=None, seg=None, seg_transform=None, labels=None, label_transform=None)[source]#

Initializes the dataset with the filename lists. The transform img_transform is applied to the images and seg_transform to the segmentations.

Parameters
  • img (Sequence) – sequence of images.

  • img_transform (Optional[Callable]) – transform to apply to each element in img.

  • seg (Optional[Sequence]) – sequence of segmentations.

  • seg_transform (Optional[Callable]) – transform to apply to each element in seg.

  • labels (Optional[Sequence]) – sequence of labels.

  • label_transform (Optional[Callable]) – transform to apply to each element in labels.
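A shared seed matters because the image and segmentation transforms must make identical random decisions. The sketch below is hypothetical (random_flip and paired_item are not MONAI functions, and Python's random module stands in for MONAI's random state): a seed is drawn once per item, and each randomized transform gets its own generator built from that same seed.

```python
import random
from typing import List, Sequence, Tuple


def random_flip(values: Sequence[int], rng: random.Random) -> List[int]:
    # Hypothetical randomized transform: reverse the sequence with probability 0.5.
    return list(reversed(values)) if rng.random() < 0.5 else list(values)


def paired_item(img: Sequence[int], seg: Sequence[int], seed: int) -> Tuple[List[int], List[int]]:
    # Re-seed before each transform so image and segmentation get the same draws.
    return random_flip(img, random.Random(seed)), random_flip(seg, random.Random(seed))


dataset_rng = random.Random(0)
for _ in range(3):
    seed = dataset_rng.randint(0, 2**31 - 1)  # one seed per fetched item
    print(paired_item([1, 2, 3], [10, 20, 30], seed))
```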

randomize(data=None)[source]#

Within this method, self.R should be used, instead of np.random, to introduce random factors.

All self.R calls happen here so that errors in synchronizing the random state are easier to identify.

This method can generate the random factors based on properties of the input data.

Raises

NotImplementedError – When the subclass does not override this method.

Return type

None

ImageDataset#

class monai.data.ImageDataset(image_files, seg_files=None, labels=None, transform=None, seg_transform=None, label_transform=None, image_only=True, transform_with_metadata=False, dtype=<class 'numpy.float32'>, reader=None, *args, **kwargs)[source]#

Loads image/segmentation pairs of files from the given filename lists. Transformations can be specified for the image and segmentation arrays separately. Unlike ArrayDataset, this dataset can apply a transform chain to both images and segmentations, can return the images together with their metadata, and does not need a transform to load the images from files. For more information, please see the image_dataset demo in the MONAI tutorial repo, https://github.com/Project-MONAI/tutorials/blob/master/modules/image_dataset.ipynb

__init__(image_files, seg_files=None, labels=None, transform=None, seg_transform=None, label_transform=None, image_only=True, transform_with_metadata=False, dtype=<class 'numpy.float32'>, reader=None, *args, **kwargs)[source]#

Initializes the dataset with the image and segmentation filename lists. The transform transform is applied to the images and seg_transform to the segmentations.

Parameters
  • image_files (Sequence[str]) – list of image filenames.

  • seg_files (Optional[Sequence[str]]) – if in segmentation task, list of segmentation filenames.

  • labels (Optional[Sequence[float]]) – if in classification task, list of classification labels.

  • transform (Optional[Callable]) – transform to apply to image arrays.

  • seg_transform (Optional[Callable]) – transform to apply to segmentation arrays.

  • label_transform (Optional[Callable]) – transform to apply to the label data.

  • image_only (bool) – if True return only the image volume, otherwise, return image volume and the metadata.

  • transform_with_metadata (bool) – if True, the metadata will be passed to the transforms whenever possible.

  • dtype (Union[dtype, type, str, None]) – if not None convert the loaded image to this data type.

  • reader (Union[ImageReader, str, None]) – reader used to load the image file and metadata. If None, the default readers are used. If a reader name string is provided, a reader object is constructed with the *args and **kwargs parameters; supported reader names: “NibabelReader”, “PILReader”, “ITKReader”, “NumpyReader”

  • args – additional parameters for reader if providing a reader name.

  • kwargs – additional parameters for reader if providing a reader name.

Raises

ValueError – When seg_files length differs from image_files

randomize(data=None)[source]#

Within this method, self.R should be used, instead of np.random, to introduce random factors.

All self.R calls happen here so that errors in synchronizing the random state are easier to identify.

This method can generate the random factors based on properties of the input data.

Raises

NotImplementedError – When the subclass does not override this method.

Return type

None

NPZDictItemDataset#

class monai.data.NPZDictItemDataset(npzfile, keys, transform=None, other_keys=())[source]#

Represents a dataset from a loaded NPZ file. The members of the file to load are named by the keys of keys, and each one is stored under the corresponding value of keys. All loaded arrays must have the same 0-dimension (batch) size. Items are always dicts mapping names to an item extracted from the loaded arrays. If passing slicing indices, will return a PyTorch Subset, for example: data: Subset = dataset[1:4], for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset

Parameters
  • npzfile (Union[str, IO]) – Path to .npz file or stream containing .npz file data

  • keys (Dict[str, str]) – maps the names of the members to load from the file to the names under which they are stored in the dataset

  • transform (Optional[Callable[…, Dict[str, Any]]]) – Transform to apply to batch dict

  • other_keys (Optional[Sequence[str]]) – secondary data to load from file and store in dict other_keys, not returned by __getitem__
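The item layout can be sketched without NumPy by letting a dict of lists stand in for the loaded .npz members. npz_len and npz_dict_item are hypothetical helpers, not MONAI functions; keys maps a member name in the file to the name used in the returned dict, and raising ValueError on mismatched batch sizes is this sketch's own choice rather than documented MONAI behavior.

```python
from typing import Any, Dict, Mapping, Sequence


def npz_len(arrays: Mapping[str, Sequence[Any]], keys: Mapping[str, str]) -> int:
    # All loaded arrays must share the same 0-dimension (batch) size.
    lengths = {len(arrays[member]) for member in keys}
    if len(lengths) != 1:
        raise ValueError(f"inconsistent batch sizes: {sorted(lengths)}")
    return lengths.pop()


def npz_dict_item(arrays: Mapping[str, Sequence[Any]], keys: Mapping[str, str], index: int) -> Dict[str, Any]:
    # Each item is a dict: stored name -> the index-th entry of the matching member.
    return {stored: arrays[member][index] for member, stored in keys.items()}


loaded = {"arr_0": [[0, 1], [2, 3], [4, 5]], "arr_1": [0, 1, 1]}  # stand-in for the loaded file
keys = {"arr_0": "image", "arr_1": "label"}
print(npz_len(loaded, keys))           # 3
print(npz_dict_item(loaded, keys, 2))  # {'image': [4, 5], 'label': 1}
```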

CSVDataset#

class monai.data.CSVDataset(src=None, row_indices=None, col_names=None, col_types=None, col_groups=None, transform=None, kwargs_read_csv=None, **kwargs)[source]#

Dataset to load data from CSV files and generate a list of dictionaries, every dictionary maps to a row of the CSV file, and the keys of dictionary map to the column names of the CSV file.

It can load multiple CSV files and join the tables with the additional kwargs. It also supports loading only specific rows and columns, and it can group several loaded columns to generate a new column. For example, with col_groups={"meta": ["meta_0", "meta_1", "meta_2"]}, the output can be:

[
    {"image": "./image0.nii", "meta_0": 11, "meta_1": 12, "meta_2": 13, "meta": [11, 12, 13]},
    {"image": "./image1.nii", "meta_0": 21, "meta_1": 22, "meta_2": 23, "meta": [21, 22, 23]},
]
Parameters
  • src (Union[str, Sequence[str], None]) – if provided the filename of a CSV file, it can be a str, URL, path object or file-like object to load. A pandas DataFrame can also be provided directly, in which case loading from a filename is skipped. If a list of filenames or pandas DataFrames is provided, the tables are joined.

  • row_indices (Optional[Sequence[Union[str, int]]]) – indices of the expected rows to load. It should be a list; every item can be an int or a range [start, end) of indices. For example: row_indices=[[0, 100], 200, 201, 202, 300]. If None, all the rows in the file are loaded.

  • col_names (Optional[Sequence[str]]) – names of the expected columns to load. if None, load all the columns.

  • col_types (Optional[Dict[str, Optional[Dict[str, Any]]]]) –

    type and default value to convert the loaded columns; if None, the original data is used. It should be a dictionary where every item maps to an expected column: the key is the column name and the value is None or a dictionary defining the default value and data type. The supported keys in the dictionary are: [“type”, “default”]. For example:

    col_types = {
        "subject_id": {"type": str},
        "label": {"type": int, "default": 0},
        "ehr_0": {"type": float, "default": 0.0},
        "ehr_1": {"type": float, "default": 0.0},
        "image": {"type": str, "default": None},
    }
    

  • col_groups (Optional[Dict[str, Sequence[str]]]) – args to group the loaded columns to generate a new column. It should be a dictionary where every item maps to a group: the key is the new column name and the value is the names of the columns to combine. For example: col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}

  • transform (Optional[Callable]) – transform to apply on the loaded items of a dictionary data.

  • kwargs_read_csv (Optional[Dict]) – dictionary args to pass to pandas read_csv function.

  • kwargs – additional arguments for pandas.merge() API to join tables.

Deprecated since version 0.8.0: filename is deprecated, use src instead.
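The row_indices and col_groups conventions above can be approximated with the standard csv module. This is a hypothetical sketch, not how CSVDataset is implemented (it relies on pandas): expand_row_indices turns [start, end) pairs into explicit indices, and load_with_groups adds one list-valued column per group. Values stay strings here because type conversion is what col_types handles.

```python
import csv
import io
from typing import Dict, List, Sequence, Union


def expand_row_indices(row_indices: Sequence[Union[int, Sequence[int]]]) -> List[int]:
    rows: List[int] = []
    for item in row_indices:
        if isinstance(item, int):
            rows.append(item)
        else:
            start, end = item  # a half-open range [start, end)
            rows.extend(range(start, end))
    return rows


def load_with_groups(text: str, col_groups: Dict[str, Sequence[str]]) -> List[Dict[str, object]]:
    records: List[Dict[str, object]] = []
    for row in csv.DictReader(io.StringIO(text)):
        record: Dict[str, object] = dict(row)
        for new_col, members in col_groups.items():
            record[new_col] = [row[m] for m in members]  # combine the grouped columns
        records.append(record)
    return records


text = "image,meta_0,meta_1,meta_2\n./image0.nii,11,12,13\n./image1.nii,21,22,23\n"
print(expand_row_indices([[0, 3], 5]))  # [0, 1, 2, 5]
print(load_with_groups(text, {"meta": ["meta_0", "meta_1", "meta_2"]})[0]["meta"])  # ['11', '12', '13']
```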

Patch-based dataset#

GridPatchDataset#

class monai.data.GridPatchDataset(data, patch_iter, transform=None, with_coordinates=True)[source]#

Yields patches from data read from an image dataset. Typically used with PatchIter or PatchIterd so that the patches are chosen in a contiguous grid sampling scheme.

import numpy as np

from monai.data import GridPatchDataset, DataLoader, PatchIter
from monai.transforms import RandShiftIntensity

# image-level dataset
images = [np.arange(16, dtype=float).reshape(1, 4, 4),
          np.arange(16, dtype=float).reshape(1, 4, 4)]
# image-level patch generator, "grid sampling"
patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))
# patch-level intensity shifts
patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)

# construct the dataset
ds = GridPatchDataset(data=images,
                      patch_iter=patch_iter,
                      transform=patch_intensity)
# use the grid patch dataset
for item in DataLoader(ds, batch_size=2, num_workers=2):
    print("patch size:", item[0].shape)
    print("coordinates:", item[1])

# >>> patch size: torch.Size([2, 1, 2, 2])
#     coordinates: tensor([[[0, 1], [0, 2], [0, 2]],
#                          [[0, 1], [2, 4], [0, 2]]])
Parameters
  • data (Union[Iterable, Sequence]) – the data source to read image data from.

  • patch_iter (Callable) – converts an input image (item from dataset) into an iterable of image patches. patch_iter(dataset[idx]) must yield a tuple: (patches, coordinates). See also: monai.data.PatchIter or monai.data.PatchIterd.

  • transform (Optional[Callable]) – a callable data transform that operates on the patches.

  • with_coordinates (bool) – whether to yield the coordinates of each patch, defaults to True.

Deprecated since version 0.8.0: dataset is deprecated, use data instead.

PatchDataset#

class monai.data.PatchDataset(data, patch_func, samples_per_image=1, transform=None)[source]#

Returns patches from an image dataset. The patches are generated by a user-specified callable patch_func and are optionally post-processed by transform. For example, to generate random patch samples from an image dataset:

import numpy as np

from monai.data import PatchDataset, DataLoader
from monai.transforms import RandSpatialCropSamples, RandShiftIntensity

# image dataset
images = [np.arange(16, dtype=float).reshape(1, 4, 4),
          np.arange(16, dtype=float).reshape(1, 4, 4)]
# image patch sampler
n_samples = 5
sampler = RandSpatialCropSamples(roi_size=(3, 3), num_samples=n_samples,
                                 random_center=True, random_size=False)
# patch-level intensity shifts
patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)
# construct the patch dataset
ds = PatchDataset(data=images,
                  patch_func=sampler,
                  samples_per_image=n_samples,
                  transform=patch_intensity)

# use the patch dataset, length: len(images) x samples_per_image
print(len(ds))
# >>> 10

for item in DataLoader(ds, batch_size=2, shuffle=True, num_workers=2):
    print(item.shape)
# >>> torch.Size([2, 1, 3, 3])

Deprecated since version 0.8.0: dataset is deprecated, use data instead.

__init__(data, patch_func, samples_per_image=1, transform=None)[source]#
Parameters
  • data (Sequence) – an image dataset to extract patches from.

  • patch_func (Callable) – converts an input image (item from dataset) into a sequence of image patches. patch_func(dataset[idx]) must return a sequence of patches (length samples_per_image).

  • samples_per_image (int) – patch_func should return a sequence of samples_per_image elements.

  • transform (Optional[Callable]) – transform applied to each patch.
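The relationship between the dataset length, samples_per_image and patch_func can be shown with a hypothetical, simplified class (ToyPatchDataset is not MONAI's implementation): a flat index is split with divmod into an image index and a sample index, patch_func is applied to that image, and the optional transform is applied to the selected patch.

```python
from typing import Any, Callable, Optional, Sequence, Tuple


class ToyPatchDataset:
    """Hypothetical illustration of the PatchDataset indexing scheme."""

    def __init__(self, data: Sequence[Any], patch_func: Callable[[Any], Sequence[Any]],
                 samples_per_image: int = 1, transform: Optional[Callable[[Any], Any]] = None) -> None:
        self.data = data
        self.patch_func = patch_func
        self.samples_per_image = samples_per_image
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data) * self.samples_per_image

    def locate(self, index: int) -> Tuple[int, int]:
        # Flat index -> (image index, sample index within that image).
        return divmod(index, self.samples_per_image)

    def __getitem__(self, index: int) -> Any:
        image_idx, sample_idx = self.locate(index)
        patches = self.patch_func(self.data[image_idx])
        if len(patches) != self.samples_per_image:
            raise ValueError("patch_func must return samples_per_image patches")
        patch = patches[sample_idx]
        return patch if self.transform is None else self.transform(patch)


ds = ToyPatchDataset(["a", "b"], lambda s: [s + str(i) for i in range(3)], samples_per_image=3)
print(len(ds), ds[4], ds.locate(5))  # 6 b1 (1, 2)
```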

PatchIter#

class monai.data.PatchIter(patch_size, start_pos=(), mode=NumpyPadMode.WRAP, **pad_opts)[source]#

Return a patch generator with predefined properties such as patch_size. Typically used with monai.data.GridPatchDataset.

__call__(array)[source]#
Parameters

array (ndarray) – the image to generate patches from.

__init__(patch_size, start_pos=(), mode=NumpyPadMode.WRAP, **pad_opts)[source]#
Parameters
  • patch_size (Sequence[int]) – size of patches to generate slices for, 0/None selects whole dimension

  • start_pos (Sequence[int]) – starting position in the array, default is 0 for each dimension

  • mode (str) – {"constant", "edge", "linear_ramp", "maximum", "mean", "median", "minimum", "reflect", "symmetric", "wrap", "empty"} One of the listed string values or a user supplied function. Defaults to "wrap". See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html

  • pad_opts (Dict) – other arguments for the np.pad function. Note that np.pad treats the channel dimension as the first dimension.

Note

The patch_size is the size of the patch to sample from the input arrays. It is assumed that the array's first dimension is the channel dimension, which is yielded in its entirety, so it should not be specified in patch_size. For example, for an input 3D array with 1 channel of size (1, 20, 20, 20), a regular grid sampling of eight patches of size (1, 10, 10, 10) would be specified by a patch_size of (10, 10, 10).
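The grid described in the note can be enumerated with itertools. grid_patch_starts is a hypothetical helper, not a MONAI function: it covers only the spatial dimensions (the channel dimension is excluded), treats 0 or None as the whole dimension, assumes the sizes divide evenly so no padding is needed, and varies the first spatial dimension fastest to match the ordering visible in the GridPatchDataset example output.

```python
import itertools
from typing import List, Optional, Sequence, Tuple


def grid_patch_starts(spatial_shape: Sequence[int],
                      patch_size: Sequence[Optional[int]]) -> List[Tuple[int, ...]]:
    # 0/None selects the whole dimension.
    sizes = [size or dim for dim, size in zip(spatial_shape, patch_size)]
    ranges = [range(0, dim, size) for dim, size in zip(spatial_shape, sizes)]
    # Product over reversed ranges, then reverse each tuple: the first dimension varies fastest.
    return [pos[::-1] for pos in itertools.product(*ranges[::-1])]


starts = grid_patch_starts((20, 20, 20), (10, 10, 10))
print(len(starts))  # 8
print(starts[:3])   # [(0, 0, 0), (10, 0, 0), (0, 10, 0)]
```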

PatchIterd#

class monai.data.PatchIterd(keys, patch_size, start_pos=(), mode=NumpyPadMode.WRAP, **pad_opts)[source]#

Dictionary-based wrapper of monai.data.PatchIter. Returns a patch generator for dictionary data and the coordinates. Typically used with monai.data.GridPatchDataset. All the fields specified by keys are assumed to have the same shape.

Parameters
  • keys (Union[Collection[Hashable], Hashable]) – keys of the corresponding items to iterate patches.

  • patch_size (Sequence[int]) – size of patches to generate slices for, 0/None selects whole dimension

  • start_pos (Sequence[int]) – starting position in the array, default is 0 for each dimension

  • mode (str) – {"constant", "edge", "linear_ramp", "maximum", "mean", "median", "minimum", "reflect", "symmetric", "wrap", "empty"} One of the listed string values or a user supplied function. Defaults to "wrap". See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html

  • pad_opts – other arguments for the np.pad function. Note that np.pad treats the channel dimension as the first dimension.

__call__(data)[source]#

Call self as a function.

Image reader#

ImageReader#

class monai.data.ImageReader[source]#

An abstract class that defines APIs to load image files.

Typical usage of an implementation of this class is:

image_reader = MyImageReader()
img_obj = image_reader.read(path_to_image)
img_data, meta_data = image_reader.get_data(img_obj)
  • The read call converts image filenames into image objects,

  • The get_data call fetches the image data, as well as metadata.

  • A reader should implement verify_suffix with the logic of checking the input filename by the filename extensions.
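The three-method contract can be sketched with a hypothetical reader for small comma-separated text files using a made-up ".csvimg" suffix. It does not subclass monai.data.ImageReader (so it runs without MONAI), and nested Python lists stand in for the NumPy array a real reader must return. Given a list of inputs, get_data stacks the images along a new first dimension and keeps the first image's metadata, as the concrete readers below describe.

```python
import os
from typing import Any, Dict, List, Sequence, Tuple, Union

PathLike = Union[str, os.PathLike]


class CommaSeparatedReader:
    """Hypothetical reader for '.csvimg' files containing rows of numbers."""

    suffixes = (".csvimg",)

    def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool:
        names = [filename] if isinstance(filename, (str, os.PathLike)) else list(filename)
        return all(os.fspath(n).lower().endswith(self.suffixes) for n in names)

    def read(self, data: Union[Sequence[PathLike], PathLike]) -> Union[str, List[str]]:
        # The "image object" here is simply the raw file text.
        paths = [data] if isinstance(data, (str, os.PathLike)) else list(data)
        texts = []
        for path in paths:
            with open(path) as f:
                texts.append(f.read())
        return texts if len(texts) > 1 else texts[0]

    @staticmethod
    def _parse(text: str) -> List[List[float]]:
        return [[float(v) for v in line.split(",")] for line in text.strip().splitlines()]

    def get_data(self, img: Union[str, List[str]]) -> Tuple[Any, Dict[str, Any]]:
        arrays = [self._parse(t) for t in (img if isinstance(img, list) else [img])]
        meta = {"spatial_shape": (len(arrays[0]), len(arrays[0][0]))}  # from the first image
        data = arrays if isinstance(img, list) else arrays[0]  # list input: new first dimension
        return data, meta


reader = CommaSeparatedReader()
print(reader.verify_suffix(["a.csvimg", "b.CSVIMG"]))  # True
print(reader.get_data("1,2,3\n4,5,6"))
# ([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], {'spatial_shape': (2, 3)})
```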

abstract get_data(img)[source]#

Extract data array and metadata from loaded image and return them. This function must return two objects, the first is a numpy array of image data, the second is a dictionary of metadata.

Parameters

img – an image object loaded from an image file or a list of image objects.

Return type

Tuple[ndarray, Dict]

abstract read(data, **kwargs)[source]#

Read image data from specified file or files. Note that it returns a data object or a sequence of data objects.

Parameters
  • data (Union[Sequence[Union[str, PathLike]], str, PathLike]) – file name or a list of file names to read.

  • kwargs – additional args for actual read API of 3rd party libs.

Return type

Union[Sequence[Any], Any]

abstract verify_suffix(filename)[source]#

Verify whether the specified filename is supported by the current reader. This method should return True if the reader is able to read the format suggested by the filename.

Parameters

filename (Union[Sequence[Union[str, PathLike]], str, PathLike]) – file name or a list of file names to read. if a list of files, verify all the suffixes.

Return type

bool

ITKReader#

class monai.data.ITKReader(channel_dim=None, series_name='', reverse_indexing=False, series_meta=False, affine_lps_to_ras=True, **kwargs)[source]#

Load medical images based on ITK library. All the supported image formats can be found at: https://github.com/InsightSoftwareConsortium/ITK/tree/master/Modules/IO The loaded data array will be in C order, for example, a 3D image NumPy array index order will be CDWH.

Parameters
  • channel_dim (Optional[int]) –

    the channel dimension of the input image, default is None. This is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. If None, original_channel_dim will be either no_channel or -1.

    • NIfTI files are usually “channel last”, so there is no need to specify this argument.

    • PNG files usually have GetNumberOfComponentsPerPixel()==3, so there is no need to specify this argument.

  • series_name (str) – the name of the DICOM series if there are multiple ones. used when loading DICOM series.

  • reverse_indexing (bool) – whether to use a reversed spatial indexing convention for the returned data array. If False, the spatial indexing follows the numpy convention; otherwise, the spatial indexing convention is reversed to be compatible with ITK. Default is False. This option does not affect the metadata.

  • series_meta (bool) – whether to load the metadata of the DICOM series (using the metadata from the first slice). This flag is checked only when loading DICOM series. Default is False.

  • affine_lps_to_ras (bool) – whether to convert the affine matrix from “LPS” to “RAS”. Defaults to True. Set to True to be consistent with NibabelReader, otherwise the affine matrix remains in the ITK convention.

  • kwargs – additional args for itk.imread API. more details about available args: https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itk/support/extras.py

get_data(img)[source]#

Extract data array and metadata from loaded image and return them. This function returns two objects, first is numpy array of image data, second is dict of metadata. It constructs affine, original_affine, and spatial_shape and stores them in meta dict. When loading a list of files, they are stacked together at a new dimension as the first dimension, and the metadata of the first image is used to represent the output metadata.

Parameters

img – an ITK image object loaded from an image file or a list of ITK image objects.

read(data, **kwargs)[source]#

Read image data from specified file or files, it can read a list of images and stack them together as multi-channel data in get_data(). If passing directory path instead of file path, will treat it as DICOM images series and read. Note that the returned object is ITK image object or list of ITK image objects.

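Parameters
  • data (Union[Sequence[Union[str, PathLike]], str, PathLike]) – file name or a list of file names to read.

  • kwargs – additional args for the itk.imread API.
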
verify_suffix(filename)[source]#

Verify whether the specified file or files format is supported by ITK reader.

Parameters

filename (Union[Sequence[Union[str, PathLike]], str, PathLike]) – file name or a list of file names to read. if a list of files, verify all the suffixes.

Return type

bool

NibabelReader#

class monai.data.NibabelReader(channel_dim=None, as_closest_canonical=False, squeeze_non_spatial_dims=False, dtype=<class 'numpy.float32'>, **kwargs)[source]#

Load NIfTI format images based on Nibabel library.

Parameters
  • as_closest_canonical (bool) – if True, load the image as closest to canonical axis format.

  • squeeze_non_spatial_dims (bool) – if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)

  • channel_dim (Optional[int]) – the channel dimension of the input image, default is None. This is used to set original_channel_dim in the metadata; EnsureChannelFirstD reads this field. If None, original_channel_dim will be either no_channel or -1. Most NIfTI files are “channel last”, so there is no need to specify this argument for them.

  • dtype (Union[dtype, type, str, None]) – dtype of the output data array when loading with Nibabel library.

  • kwargs – additional args for nibabel.load API. more details about available args: https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py

get_data(img)[source]#

Extract data array and metadata from loaded image and return them. This function returns two objects, first is numpy array of image data, second is dict of metadata. It constructs affine, original_affine, and spatial_shape and stores them in meta dict. When loading a list of files, they are stacked together at a new dimension as the first dimension, and the metadata of the first image is used to represent the output metadata.

Parameters

img – a Nibabel image object loaded from an image file or a list of Nibabel image objects.

read(data, **kwargs)[source]#

Read image data from specified file or files, it can read a list of images and stack them together as multi-channel data in get_data(). Note that the returned object is Nibabel image object or list of Nibabel image objects.

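Parameters
  • data (Union[Sequence[Union[str, PathLike]], str, PathLike]) – file name or a list of file names to read.

  • kwargs – additional args for the nibabel.load API.
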
verify_suffix(filename)[source]#

Verify whether the specified file or files format is supported by Nibabel reader.

Parameters

filename (Union[Sequence[Union[str, PathLike]], str, PathLike]) – file name or a list of file names to read. if a list of files, verify all the suffixes.

Return type

bool

NumpyReader#

class monai.data.NumpyReader(npz_keys=None, channel_dim=None, **kwargs)[source]#

Load NPY or NPZ format data based on the Numpy library; the data can be arrays or pickled objects. A typical usage is to load the mask data for a classification task. It can load part of an npz file with the specified npz_keys.

Parameters
  • npz_keys (Union[Collection[Hashable], Hashable, None]) – if loading an npz file, only load the specified keys; if None, load all the items. The loaded items are stacked together to construct a new first dimension.

  • channel_dim (Optional[int]) – if not None, explicitly specify the channel dim, otherwise, treat the array as no channel.

  • kwargs – additional args for numpy.load API except allow_pickle. more details about available args: https://numpy.org/doc/stable/reference/generated/numpy.load.html

get_data(img)[source]#

Extract data array and metadata from loaded image and return them. This function returns two objects, first is numpy array of image data, second is dict of metadata. It constructs affine, original_affine, and spatial_shape and stores them in meta dict. When loading a list of files, they are stacked together at a new dimension as the first dimension, and the metadata of the first image is used to represent the output metadata.

Parameters

img – a Numpy array loaded from a file or a list of Numpy arrays.

read(data, **kwargs)[source]#

Read image data from specified file or files, it can read a list of data files and stack them together as multi-channel data in get_data(). Note that the returned object is Numpy array or list of Numpy arrays.

Parameters
  • data (Union[Sequence[Union[str, PathLike]], str, PathLike]) – file name or a list of file names to read.

  • kwargs – additional args for numpy.load API except allow_pickle, will override self.kwargs for existing keys. More details about available args: https://numpy.org/doc/stable/reference/generated/numpy.load.html

verify_suffix(filename)[source]#

Verify whether the specified file or files format is supported by Numpy reader.

Parameters

filename (Union[Sequence[Union[str, PathLike]], str, PathLike]) – file name or a list of file names to read. if a list of files, verify all the suffixes.

Return type

bool

PILReader#

class monai.data.PILReader(converter=None, **kwargs)[source]#

Load common 2D image formats (supports PNG, JPG, BMP) from the provided file path or paths.

Parameters
get_data(img)[source]#

Extract data array and metadata from loaded image and return them. This function returns two objects, first is numpy array of image data, second is dict of metadata. It computes spatial_shape and stores it in meta dict. When loading a list of files, they are stacked together at a new dimension as the first dimension, and the metadata of the first image is used to represent the output metadata. Note that it will swap axis 0 and 1 after loading the array because the HW definition in PIL is different from other common medical packages.

Parameters

img – a PIL Image object loaded from a file or a list of PIL Image objects.

read(data, **kwargs)[source]#

Read image data from specified file or files, it can read a list of images and stack them together as multi-channel data in get_data(). Note that the returned object is a PIL image or a list of PIL images.

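Parameters
  • data (Union[Sequence[Union[str, PathLike]], str, PathLike]) – file name or a list of file names to read.

  • kwargs – additional args for the PIL Image.open API.
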
verify_suffix(filename)[source]#

Verify whether the specified file or files format is supported by PIL reader.

Parameters

filename (Union[Sequence[Union[str, PathLike]], str, PathLike]) – file name or a list of file names to read. if a list of files, verify all the suffixes.

Return type

bool

FastMRIReader#

class monai.apps.reconstruction.fastmri_reader.FastMRIReader(*args, **kwargs)[source]#

Load fastMRI files with ‘.h5’ suffix. fastMRI files, when loaded with “h5py”, are HDF5 dictionary-like datasets. The keys are:

  • kspace: contains the fully-sampled kspace

  • reconstruction_rss: contains the root sum of squares of ifft of kspace. This is the ground-truth image.

It also has several attributes with the following keys:

  • acquisition (str): acquisition mode of the data (e.g., AXT2 denotes T2 brain MRI scans)

  • max (float): dynamic range of the data

  • norm (float): norm of the kspace

  • patient_id (str): the patient’s id whose measurements were recorded

get_data(dat)[source]#

Extract data array and metadata from the loaded data and return them. This function returns two objects, first is numpy array of image data, second is dict of metadata.

Parameters

dat (Dict) – a dictionary loaded from an h5 file

Return type

Tuple[ndarray, dict]

read(data)[source]#

Read data from specified h5 file. Note that the returned object is a dictionary.

Parameters

data (Union[Sequence[Union[str, PathLike]], str, PathLike]) – file name to read.

Return type

Dict

verify_suffix(filename)[source]#

Verify whether the specified file format is supported by h5py reader.

Parameters

filename (Union[Sequence[Union[str, PathLike]], str, PathLike]) – file name

Return type

bool

Image writer#

resolve_writer#

monai.data.resolve_writer(ext_name, error_if_not_found=True)[source]#

Resolves to a tuple of available ImageWriter in SUPPORTED_WRITERS according to the filename extension key ext_name.

Parameters
  • ext_name – the filename extension of the image. As an indexing key it will be converted to a lower case string.

  • error_if_not_found – whether to raise an error if no suitable image writer is found. If True, raise an OptionalImportError; otherwise return an empty tuple. Default is True.

Return type

Sequence

register_writer#

monai.data.register_writer(ext_name, *im_writers)[source]#

Register ImageWriter classes, so that writing a file with filename extension ext_name can be resolved to a tuple of potentially appropriate ImageWriters. Custom writers can be registered by:

from monai.data import register_writer
# `MyWriter` must implement `ImageWriter` interface
register_writer("nii", MyWriter)
Parameters
  • ext_name – the filename extension of the image. As an indexing key, it will be converted to a lower case string.

  • im_writers – one or multiple ImageWriter classes with high priority ones first.
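The resolution logic can be pictured as a dictionary keyed by the lower-cased extension. The sketch below is hypothetical: register and resolve are not the MONAI functions, LookupError stands in for OptionalImportError, the writer classes are placeholders, and placing newly registered writers ahead of previously registered ones is an assumption.

```python
from typing import Dict, Tuple

_WRITERS: Dict[str, Tuple[type, ...]] = {}


def register(ext_name: str, *im_writers: type) -> None:
    key = ext_name.lower()  # the extension is normalised to lower case
    # Assumption: newly registered writers take precedence over existing ones.
    _WRITERS[key] = tuple(im_writers) + _WRITERS.get(key, ())


def resolve(ext_name: str, error_if_not_found: bool = True) -> Tuple[type, ...]:
    found = _WRITERS.get(ext_name.lower(), ())
    if not found and error_if_not_found:
        raise LookupError(f"no writer registered for extension {ext_name!r}")
    return found


class NiftiWriterA: ...  # hypothetical placeholder writer
class NiftiWriterB: ...  # hypothetical placeholder writer


register("NII", NiftiWriterA)
register("nii", NiftiWriterB)
print([w.__name__ for w in resolve("nii")])      # ['NiftiWriterB', 'NiftiWriterA']
print(resolve("png", error_if_not_found=False))  # ()
```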

ImageWriter#

class monai.data.ImageWriter(**kwargs)[source]#

The class is a collection of utilities to write images to disk.

Main aspects to be considered are:

  • dimensionality of the data array, arrangements of spatial dimensions and channel/time dimensions
    • convert_to_channel_last()

  • metadata of the current affine and output affine, the data array should be converted accordingly
    • get_meta_info()

    • resample_if_needed()

  • data type handling of the output image (as part of resample_if_needed())

Subclasses of this class should implement the backend-specific functions:

  • set_data_array() to set the data array (input must be numpy array or torch tensor)
    • this method sets the backend object’s data part

  • set_metadata() to set the metadata and output affine
    • this method sets the metadata including affine handling and image resampling

  • backend-specific data object create_backend_obj()

  • backend-specific writing function write()

The primary usage of subclasses of ImageWriter is:

writer = MyWriter()  # subclass of ImageWriter
writer.set_data_array(data_array)
writer.set_metadata(meta_dict)
writer.write(filename)

This creates an image writer object based on data_array and meta_dict and writes it to filename.
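To make the division of labour between these methods concrete, here is a toy writer that follows the same method names (set_data_array, set_metadata, create_backend_obj, write) but uses numpy's .npy format as its "backend". It is a hypothetical sketch: NpyWriter does not subclass monai.data.ImageWriter, performs no resampling, and only handles a single integer channel_dim.

```python
import io

import numpy as np


class NpyWriter:
    """Hypothetical writer mirroring the ImageWriter workflow with numpy's .npy format."""

    def __init__(self, output_dtype=np.float32):
        self.output_dtype = output_dtype
        self.data_obj = None  # channel-last array, set by set_data_array
        self.affine = None  # set by set_metadata

    def set_data_array(self, data_array, channel_dim=0):
        data = np.asarray(data_array)
        if channel_dim is not None:
            data = np.moveaxis(data, channel_dim, -1)  # make the channel axis last
        self.data_obj = data

    def set_metadata(self, meta_dict=None):
        meta_dict = meta_dict or {}
        # a real writer would resample here; this sketch only records the affine
        self.affine = np.asarray(meta_dict.get("affine", np.eye(4)))

    @classmethod
    def create_backend_obj(cls, data_array, dtype=np.float32):
        # the "backend object" of this toy format is just a typed ndarray
        return np.asarray(data_array, dtype=dtype)

    def write(self, filename):
        np.save(filename, self.create_backend_obj(self.data_obj, dtype=self.output_dtype))


writer = NpyWriter()
writer.set_data_array(np.zeros((2, 4, 4)), channel_dim=0)
writer.set_metadata({"affine": np.eye(4)})
buffer = io.BytesIO()  # np.save accepts a filename or a file-like object
writer.write(buffer)
buffer.seek(0)
print(np.load(buffer).shape)  # (4, 4, 2)
```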

It supports up to three spatial dimensions (the resampling step supports both 2D and 3D). When saving a data_array with multiple time steps or multiple channels, the time and/or modality axes should be at channel_dim. For example, with channel_dim=0, the 2D eight-class segmentation probabilities to be saved could have shape (8, 64, 64); in this case data_array will be converted to (64, 64, 1, 8) (the third dimension is reserved as a spatial dimension).

The metadata could optionally have the following keys:

  • 'original_affine': the original affine of the data; it will be the affine of the output object, defaulting to an identity matrix.

  • 'affine': it should specify the current data affine, defaulting to an identity matrix.

  • 'spatial_shape': for data output spatial shape.

When metadata is specified, the writer may resample the data from the space defined by “affine” to the space defined by “original_affine”; for more details, please refer to the resample_if_needed method.

__init__(**kwargs)[source]#

The constructor supports adding new instance members. The current member in the base class is self.data_obj; subclasses can add more members so that the necessary meta information can be stored in the object and shared among the class methods.

classmethod convert_to_channel_last(data, channel_dim=0, squeeze_end_dims=True, spatial_ndim=3, contiguous=False)[source]#

Rearrange the data array axes to make the channel_dim-th dim the last dimension and ensure there are spatial_ndim number of spatial dimensions.

When squeeze_end_dims is True, a postprocessing step will be applied to remove any trailing singleton dimensions.

Parameters
  • data (Union[ndarray, Tensor]) – input data to be converted to “channel-last” format.

  • channel_dim (Union[None, int, Sequence[int]]) – specifies the channel axes of the data array to move to the last position. None indicates no channel dimension, in which case a new axis will be appended as the channel dimension. A sequence of integers indicates multiple non-spatial dimensions.

  • squeeze_end_dims (bool) – if True, any trailing singleton dimensions will be removed (after the channel has been moved to the end). So if input is (H,W,D,C) and C==1, then it will be saved as (H,W,D). If D is also 1, it will be saved as (H,W). If False, image will always be saved as (H,W,D,C).

  • spatial_ndim (Optional[int]) – if needed, modify the spatial dims so that the output has at least this number of spatial dims. If None, the output will have the same number of spatial dimensions as the input.

  • contiguous (bool) – if True, the output will be contiguous.
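The shape rules of convert_to_channel_last can be illustrated with a simplified numpy re-implementation. This is a sketch under stated assumptions, not MONAI's code: to_channel_last is a hypothetical name, channel_dim is a single integer or None, missing spatial axes are inserted just before the channel axis (matching the (8, 64, 64) to (64, 64, 1, 8) example given for ImageWriter), and contiguous is not modelled.

```python
import numpy as np


def to_channel_last(data, channel_dim=0, spatial_ndim=3, squeeze_end_dims=True):
    """Hypothetical simplified version of the channel-last conversion."""
    data = np.asarray(data)
    if channel_dim is None:
        data = data[..., None]  # no channel axis: append a singleton one
    else:
        data = np.moveaxis(data, channel_dim, -1)  # move the channel axis to the end
    if spatial_ndim is not None:
        # insert singleton spatial axes before the channel axis until there are enough
        while data.ndim < spatial_ndim + 1:
            data = data[..., None, :]
    if squeeze_end_dims:
        # drop trailing singleton axes: first the channel, then spatial ones
        while data.ndim > 1 and data.shape[-1] == 1:
            data = data[..., 0]
    return data


print(to_channel_last(np.zeros((8, 64, 64))).shape)  # (64, 64, 1, 8)
print(to_channel_last(np.zeros((1, 64, 64, 32))).shape)  # (64, 64, 32)
print(to_channel_last(np.zeros((1, 64, 64, 1))).shape)  # (64, 64)
print(to_channel_last(np.zeros((1, 64, 64, 32)), squeeze_end_dims=False).shape)  # (64, 64, 32, 1)
```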

classmethod create_backend_obj(data_array, **kwargs)[source]#

Subclasses should implement this method to return a backend-specific data representation object. This method is used by cls.write, and the input data_array is assumed to be ‘channel-last’.

Return type

ndarray

classmethod get_meta_info(metadata=None)[source]#

Extracts relevant meta information from the metadata object (using .get). Optional keys are "spatial_shape", MetaKeys.AFFINE, "original_affine".

classmethod resample_if_needed(data_array, affine=None, target_affine=None, output_spatial_shape=None, mode=GridSampleMode.BILINEAR, padding_mode=GridSamplePadMode.BORDER, align_corners=False, dtype=<class 'numpy.float64'>)[source]#

Convert the data_array into the coordinate system specified by target_affine, from the current coordinate definition of affine.

If the transform between affine and target_affine could be achieved by simply transposing and flipping data_array, no resampling will happen. Otherwise, this function resamples data_array using the transformation computed from affine and target_affine.
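Whether "simply transposing and flipping" suffices can be read off the relative transform inv(target_affine) @ affine, which maps voxel indices of the current array to voxel indices in the target space: if its linear part is a signed permutation matrix (exactly one entry of ±1 per row and column, zeros elsewhere), axis swaps and flips are enough. The sketch below is a simplified heuristic, not MONAI's logic; needs_resampling is a hypothetical name and the translation part is ignored.

```python
import numpy as np


def needs_resampling(affine, target_affine, atol=1e-6):
    """Hypothetical heuristic: False when only axis swaps/flips separate the two affines."""
    # voxel indices in the current array -> voxel indices in the target space
    rel = np.linalg.solve(np.asarray(target_affine, float), np.asarray(affine, float))
    lin = rel[:-1, :-1]  # linear part; the translation column is ignored here
    rounded = np.round(lin)
    signed_permutation = (
        np.allclose(lin, rounded, atol=atol)
        and np.all(np.abs(rounded).sum(axis=0) == 1)
        and np.all(np.abs(rounded).sum(axis=1) == 1)
    )
    return not signed_permutation


flip_x = np.diag([-1.0, 1.0, 1.0, 1.0])
swap_xy = np.eye(4)[[1, 0, 2, 3]]
print(needs_resampling(np.eye(4), flip_x))  # False: a flip is enough
print(needs_resampling(np.eye(4), swap_xy))  # False: a transpose is enough
print(needs_resampling(np.diag([2.0, 2.0, 2.0, 1.0]), np.eye(4)))  # True: spacing changes
```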

This function assumes the NIfTI dimension notations. Spatially it supports up to three dimensions, that is, H, HW, HWD for 1D, 2D, 3D respectively. When saving multiple time steps or multiple channels, time and/or modality axes should be appended after the first three dimensions. For example, shape of 2D eight-class segmentation probabilities to be saved could be (64, 64, 1, 8). Also, data in shape (64, 64, 8) or (64, 64, 8, 1) will be considered as a single-channel 3D image. The convert_to_channel_last method can be used to convert the data to the format described here.

Note that the shape of the resampled data_array may be subject to rounding errors. For example, resampling a 20x20-pixel image from pixel size (1.5, 1.5)-mm to (3.0, 3.0)-mm space will return a 10x10-pixel image. However, resampling a 20x20-pixel image from pixel size (2.0, 2.0)-mm to (3.0, 3.0)-mm space will output a 14x14-pixel image, because the non-integer output size (nominally 13.333x13.333 pixels) has to be rounded to whole pixels. In this case, output_spatial_shape can be specified so that this function writes image data to a designated shape.
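The 10x10 and 14x14 results in the rounding example are both reproduced by one plausible rule: measure the extent between the first and last pixel centres in units of the new pixel size, add one pixel, and round with numpy's round-half-to-even. This is an assumption about how the output shape is computed, shown only as a sketch; estimate_output_size is a hypothetical helper.

```python
import numpy as np


def estimate_output_size(n_pixels, in_spacing, out_spacing):
    """Hypothetical estimate of the resampled size along one axis."""
    # distance between the first and last pixel centres, in output pixels
    extent = (n_pixels - 1) * in_spacing / out_spacing
    # np.round rounds halves to even, e.g. 10.5 -> 10.0
    return int(np.round(extent + 1.0))


print(estimate_output_size(20, 1.5, 3.0))  # 10 (from 10.5)
print(estimate_output_size(20, 2.0, 3.0))  # 14 (from about 13.667)
```

Under this rule, the naive estimate 20 x 2.0 / 3.0 = 13.333 is never used directly, which is why passing an explicit output_spatial_shape is the reliable way to get a specific size.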

Parameters
  • data_array (Union[ndarray, Tensor]) – input data array to be converted.

  • affine (Union[ndarray, Tensor, None]) – the current affine of data_array. Defaults to identity.

  • target_affine (Union[ndarray, Tensor, None]) – the designated affine of data_array. The actual output affine might be different from this value due to precision changes.

  • output_spatial_shape (Union[Sequence[int], int, None]) – spatial shape of the output image. This option is used when resampling is needed.

  • mode (str) – available options are {"bilinear", "nearest", "bicubic"}. This option is used when resampling is needed. Interpolation mode to calculate output values. Defaults to "bilinear". See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample

  • padding_mode (str) – available options are {"zeros", "border", "reflection"}. This option is used when resampling is needed. Padding mode for outside grid values. Defaults to "border". See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample

  • align_corners (bool) – boolean option of grid_sample to handle the corner convention. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample

  • dtype (Union[dtype, type, str, None]) – data type for resampling computation. Defaults to np.float64 for best precision. If None, use the data type of input data. The output data type of this method is always np.float32.

write(filename, verbose=True, **kwargs)[source]#

Subclasses should implement this method to call the backend-specific writing APIs.

ITKWriter#

class monai.data.ITKWriter(output_dtype=<class 'numpy.float32'>, affine_lps_to_ras=True, **kwargs)[source]#

Write data and metadata into files on disk using ITK-python.

import numpy as np
from monai.data import ITKWriter

np_data = np.arange(48).reshape(3, 4, 4)

# write as 3d spatial image no channel
writer = ITKWriter(output_dtype=np.float32)
writer.set_data_array(np_data, channel_dim=None)
# optionally set metadata affine
writer.set_metadata({"affine": np.eye(4), "original_affine": -1 * np.eye(4)})
writer.write("test1.nii.gz")

# write as 2d image, channel-first
writer = ITKWriter(output_dtype=np.uint8)
writer.set_data_array(np_data, channel_dim=0)
writer.set_metadata({"spatial_shape": (5, 5)})
writer.write("test1.png")
__init__(output_dtype=<class 'numpy.float32'>, affine_lps_to_ras=True, **kwargs)[source]#
Parameters
  • output_dtype (Union[dtype, type, str, None]) – output data type.

  • affine_lps_to_ras (bool) – whether to convert the affine matrix from “LPS” to “RAS”. Defaults to True. Set to True to be consistent with NibabelWriter; otherwise, the affine matrix is assumed to be already in the ITK convention.

  • kwargs – keyword arguments passed to ImageWriter.

The constructor will create self.output_dtype internally. affine and channel_dim are initialized as instance members (default None, 0):

  • user-specified affine should be set in set_metadata,

  • user-specified channel_dim should be set in set_data_array.
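The LPS/RAS distinction behind affine_lps_to_ras is a sign convention on the first two world axes: ITK uses Left-Posterior-Superior, while NIfTI (and hence nibabel) uses Right-Anterior-Superior. A sketch of the conversion, assuming the usual 4x4 voxel-to-world affine, multiplies by diag(-1, -1, 1, 1) on the left; because this matrix is its own inverse, the same product converts in either direction. lps_to_ras is a hypothetical helper, not the ITKWriter code.

```python
import numpy as np

# flips x (Left <-> Right) and y (Posterior <-> Anterior); z (Superior) is unchanged
LPS_RAS_FLIP = np.diag([-1.0, -1.0, 1.0, 1.0])


def lps_to_ras(affine):
    """Hypothetical helper: convert a 4x4 voxel-to-world affine between LPS and RAS."""
    return LPS_RAS_FLIP @ np.asarray(affine, dtype=float)


lps_affine = np.diag([1.0, 2.0, 3.0, 1.0])  # voxel spacing (1, 2, 3)
lps_affine[:3, 3] = [10.0, 20.0, 30.0]  # origin in LPS world coordinates
ras_affine = lps_to_ras(lps_affine)
print(ras_affine[:3, 3])  # origin with x and y negated
print(np.allclose(lps_to_ras(ras_affine), lps_affine))  # True: applying it twice is a no-op
```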

classmethod create_backend_obj(data_array, channel_dim=0, affine=None, dtype=<class 'numpy.float32'>, affine_lps_to_ras=True, **kwargs)[source]#

Create an ITK object from data_array. This method assumes a ‘channel-last’ data_array.

Parameters
  • data_array (Union[ndarray, Tensor]) – input data array.

  • channel_dim (Optional[int]) – channel dimension of the data array. This is used to create a Vector Image if it is not None.

  • affine (Union[ndarray, Tensor, None]) – affine matrix of the data array. This is used to compute spacing, direction and origin.

  • dtype (Union[dtype, type, str, None]) – output data type.

  • affine_lps_to_ras (bool) – whether to convert the affine matrix from “LPS” to “RAS”. Defaults to True. Set to True to be consistent with NibabelWriter; otherwise, the affine matrix is assumed to be already in the ITK convention.

  • kwargs – keyword arguments. Currently, itk.GetImageFromArray reads ttype from this dictionary.

set_data_array(data_array, channel_dim=0, squeeze_end_dims=True, **kwargs)[source]#

Convert data_array into ‘channel-last’ numpy ndarray.

Parameters
  • data_array (Union[ndarray, Tensor]) – input data array with the channel dimension specified by channel_dim.

  • channel_dim (Optional[int]) – channel dimension of the data array. Defaults to 0. None indicates data without any channel dimension.

  • squeeze_end_dims (bool) – if True, any trailing singleton dimensions will be removed.

  • kwargs – keyword arguments passed to self.convert_to_channel_last; currently supports spatial_ndim and contiguous, defaulting to 3 and False respectively.

set_metadata(meta_dict=None, resample=True, **options)[source]#

Resample self.data_obj if needed. This method assumes self.data_obj is a ‘channel-last’ ndarray.

Parameters
  • meta_dict (Optional[Mapping]) – a metadata dictionary for affine, original affine and spatial shape information. Optional keys are "spatial_shape", "affine", "original_affine".

  • resample (bool) – if True, the data will be resampled to the original affine (specified in meta_dict).

  • options – keyword arguments passed to self.resample_if_needed; currently supports mode, padding_mode, align_corners, and dtype, defaulting to bilinear, border, False, and np.float64 respectively.

write(filename, verbose=False, **kwargs)[source]#

Create an ITK object from self.create_backend_obj(self.data_obj, ...) and call itk.imwrite.

Parameters
  • filename (Union[str, PathLike]) – filename or PathLike object.

  • verbose (bool) – if True, log the progress.

  • kwargs – keyword arguments passed to itk.imwrite; currently supports compression and imageio.

NibabelWriter#

class monai.data.NibabelWriter(output_dtype=<class 'numpy.float32'>, **kwargs)[source]#

Write data and metadata into files on disk using Nibabel.

import numpy as np
from monai.data import NibabelWriter

np_data = np.arange(48).reshape(3, 4, 4)
writer = NibabelWriter()
writer.set_data_array(np_data, channel_dim=None)
writer.set_metadata({"affine": np.eye(4), "original_affine": np.eye(4)})
writer.write("test1.nii.gz", verbose=True)
__init__(output_dtype=<class 'numpy.float32'>, **kwargs)[source]#
Parameters
  • output_dtype (Union[dtype, type, str, None]) – output data type.

  • kwargs – keyword arguments passed to ImageWriter.

The constructor will create self.output_dtype internally. affine is initialized as an instance member (default None); a user-specified affine should be set in set_metadata.

classmethod create_backend_obj(data_array, affine=None, dtype=None, **kwargs)[source]#

Create an Nifti1Image object from data_array. This method assumes a ‘channel-last’ data_array.

Parameters
  • data_array (Union[ndarray, Tensor]) – input data array.

  • affine (Union[ndarray, Tensor, None]) – affine matrix of the data array.

  • dtype (Union[dtype, type, str, None]) – output data type.

  • kwargs – keyword arguments. Currently, nib.nifti1.Nifti1Image reads header, extra and file_map from this dictionary.

set_data_array(data_array, channel_dim=0, squeeze_end_dims=True, **kwargs)[source]#

Convert data_array into ‘channel-last’ numpy ndarray.

Parameters
  • data_array (Union[ndarray, Tensor]) – input data array with the channel dimension specified by channel_dim.

  • channel_dim (Optional[int]) – channel dimension of the data array. Defaults to 0. None indicates data without any channel dimension.

  • squeeze_end_dims (bool) – if True, any trailing singleton dimensions will be removed.

  • kwargs – keyword arguments passed to self.convert_to_channel_last; currently supports spatial_ndim, defaulting to 3.

set_metadata(meta_dict, resample=True, **options)[source]#

Resample self.data_obj if needed. This method assumes self.data_obj is a ‘channel-last’ ndarray.

Parameters
  • meta_dict (Optional[Mapping]) – a metadata dictionary for affine, original affine and spatial shape information. Optional keys are "spatial_shape", "affine", "original_affine".

  • resample (bool) – if True, the data will be resampled to the original affine (specified in meta_dict).

  • options – keyword arguments passed to self.resample_if_needed; currently supports mode, padding_mode, align_corners, and dtype, defaulting to bilinear, border, False, and np.float64 respectively.

write(filename, verbose=False, **obj_kwargs)[source]#

Create a Nibabel object from self.create_backend_obj(self.data_obj, ...) and call nib.save.

Parameters
  • filename (Union[str, PathLike]) – filename or PathLike object.

  • verbose (bool) – if True, log the progress.

  • obj_kwargs – keyword arguments passed to self.create_backend_obj.

PILWriter#

class monai.data.PILWriter(output_dtype=<class 'numpy.float32'>, channel_dim=0, scale=255, **kwargs)[source]#

Write image data into files on disk using pillow.

It’s based on the Image module in PIL library: https://pillow.readthedocs.io/en/stable/reference/Image.html

import numpy as np
from monai.data import PILWriter

np_data = np.arange(48).reshape(3, 4, 4)
writer = PILWriter(np.uint8)
writer.set_data_array(np_data, channel_dim=0)
writer.write("test1.png", verbose=True)
__init__(output_dtype=<class 'numpy.float32'>, channel_dim=0, scale=255, **kwargs)[source]#
Parameters
  • output_dtype (Union[dtype, type, str, None]) – output data type.

  • channel_dim (Optional[int]) – channel dimension of the data array. Defaults to 0. None indicates data without any channel dimension.

  • scale (Optional[int]) – {255, 65535} postprocess data by clipping to [0, 1] and scaling to [0, 255] (uint8) or [0, 65535] (uint16). Defaults to 255; set to None to disable scaling.

  • kwargs – keyword arguments passed to ImageWriter.
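The scale option (clip to [0, 1], then stretch to the full integer range) can be sketched as follows. scale_for_png is a hypothetical helper, and truncating with astype rather than rounding is an assumption about the conversion.

```python
import numpy as np


def scale_for_png(data, scale=255):
    """Hypothetical helper: clip to [0, 1] and scale to the uint8/uint16 range."""
    dtypes = {255: np.uint8, 65535: np.uint16}
    if scale not in dtypes:
        raise ValueError(f"unsupported scale {scale}, expected 255 or 65535")
    clipped = np.clip(np.asarray(data, dtype=np.float64), 0.0, 1.0)
    # astype truncates toward zero, e.g. 0.5 * 255 = 127.5 -> 127
    return (clipped * scale).astype(dtypes[scale])


print(scale_for_png([-0.2, 0.5, 1.3]))  # 0, 127 and 255 as uint8
print(scale_for_png([0.5], scale=65535))  # 32767 as uint16
```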

classmethod create_backend_obj(data_array, dtype=None, scale=255, reverse_indexing=True, **kwargs)[source]#

Create a PIL object from data_array.

Parameters
  • data_array (Union[ndarray, Tensor]) – input data array.

  • dtype (Union[dtype, type, str, None]) – output data type.

  • scale (Optional[int]) – {255, 65535} postprocess data by clipping to [0, 1] and scaling to [0, 255] (uint8) or [0, 65535] (uint16). Defaults to 255; set to None to disable scaling.

  • reverse_indexing (bool) – if True, the data array’s first two dimensions will be swapped.

  • kwargs – keyword arguments. Currently, PILImage.fromarray reads image_mode from this dictionary, which defaults to None.

classmethod get_meta_info(metadata=None)[source]#

Extracts relevant meta information from the metadata object (using .get). Optional keys are "spatial_shape", MetaKeys.AFFINE, "original_affine".

classmethod resample_and_clip(data_array, output_spatial_shape=None, mode=InterpolateMode.BICUBIC)[source]#

Resample data_array to output_spatial_shape if needed.

Parameters
  • data_array (Union[ndarray, Tensor]) – input data array. This method assumes the ‘channel-last’ format.

  • output_spatial_shape (Optional[Sequence[int]]) – output spatial shape.

  • mode (str) – interpolation mode, default is InterpolateMode.BICUBIC.

set_data_array(data_array, channel_dim=0, squeeze_end_dims=True, contiguous=False, **kwargs)[source]#

Convert data_array into ‘channel-last’ numpy ndarray.

Parameters
  • data_array (Union[ndarray, Tensor]) – input data array with the channel dimension specified by channel_dim.

  • channel_dim (Optional[int]) – channel dimension of the data array. Defaults to 0. None indicates data without any channel dimension.

  • squeeze_end_dims (bool) – if True, any trailing singleton dimensions will be removed.

  • contiguous (bool) – if True, the data array will be converted to a contiguous array. Default is False.

  • kwargs – keyword arguments passed to self.convert_to_channel_last; currently supports spatial_ndim, defaulting to 2.

set_metadata(meta_dict=None, resample=True, **options)[source]#

Resample self.data_obj if needed. This method assumes self.data_obj is a ‘channel-last’ ndarray.

Parameters
  • meta_dict (Optional[Mapping]) – a metadata dictionary for affine, original affine and spatial shape information. Optional key is "spatial_shape".

  • resample (bool) – if True, the data will be resampled to the spatial shape specified in meta_dict.

  • options – keyword arguments passed to self.resample_and_clip; currently supports mode, defaulting to bicubic.

write(filename, verbose=False, **kwargs)[source]#

Create a PIL image object from self.create_backend_obj(self.data_obj, ...) and call save.

Parameters
  • filename (Union[str, PathLike]) – filename or PathLike object.

  • verbose (bool) – if True, log the progress.

  • kwargs – optional keyword arguments passed to self.create_backend_obj; currently supports reverse_indexing and image_mode, defaulting to True and None respectively.

Nifti format handling#

Writing Nifti#

class monai.data.NiftiSaver(output_dir='./', output_postfix='seg', output_ext='.nii.gz', resample=True, mode=GridSampleMode.BILINEAR, padding_mode=GridSamplePadMode.BORDER, align_corners=False, dtype=<class 'numpy.float64'>, output_dtype=<class 'numpy.float32'>, squeeze_end_dims=True, data_root_dir='', separate_folder=True, print_log=True)[source]#

Save the data as a NIfTI file. It supports a single data item or a batch of data. Typically, the data are segmentation predictions: call save for a single data item or save_batch to save a batch of data together. The name of the saved file will be {input_image_name}_{output_postfix}{output_ext}, where the input image name is extracted from the provided metadata dictionary. If no metadata is provided, the index (starting from 0) is used as the filename prefix.

Note: image should include channel dimension: [B],C,H,W,[D].

Deprecated since version 0.8: Use monai.transforms.SaveImage instead.

__init__(output_dir='./', output_postfix='seg', output_ext='.nii.gz', resample=True, mode=GridSampleMode.BILINEAR, padding_mode=GridSamplePadMode.BORDER, align_corners=False, dtype=<class 'numpy.float64'>, output_dtype=<class 'numpy.float32'>, squeeze_end_dims=True, data_root_dir='', separate_folder=True, print_log=True)[source]#
Parameters
  • output_dir (Union[str, PathLike]) – output image directory.

  • output_postfix (str) – a string appended to all output file names.

  • output_ext (str) – output file extension name.

  • resample (bool) – whether to convert the data array to its original coordinate system based on original_affine in the meta_data.

  • mode (str) – {"bilinear", "nearest"} This option is used when resample = True. Interpolation mode to calculate output values. Defaults to "bilinear". See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

  • padding_mode (str) – {"zeros", "border", "reflection"} This option is used when resample = True. Padding mode for outside grid values. Defaults to "border". See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

  • align_corners (bool) – Geometrically, we consider the pixels of the input as squares rather than points. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

  • dtype (Union[dtype, type, str, None]) – data type for resampling computation. Defaults to np.float64 for best precision. If None, use the data type of input data.

  • output_dtype (Union[dtype, type, str, None]) – data type for saving data. Defaults to np.float32.

  • squeeze_end_dims (bool) – if True, any trailing singleton dimensions will be removed (after the channel has been moved to the end). So if input is (C,H,W,D), it will be altered to (H,W,D,C), and then if C==1, it will be saved as (H,W,D). If D is also 1, it will be saved as (H,W). If False, the image will always be saved as (H,W,D,C).

  • data_root_dir (Union[str, PathLike]) – if not empty, it specifies the beginning part of the input file’s absolute path. It is used to compute input_file_rel_path, the relative path to the file from data_root_dir, to preserve the folder structure when saving in case files in different folders have the same file name. For example, with input_file_name: /foo/bar/test1/image.nii, postfix: seg, output_ext: nii.gz, output_dir: /output and data_root_dir: /foo/bar, the output will be /output/test1/image/image_seg.nii.gz.

  • separate_folder (bool) – whether to save every file in a separate folder. For example, if the input filename is image.nii, postfix is seg and folder_path is output: if True, save as output/image/image_seg.nii; if False, save as output/image_seg.nii. Defaults to True.

  • print_log (bool) – whether to print logs about the saved NIfTI file path, etc. Defaults to True.
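The file-naming rules for output_postfix, output_ext, data_root_dir and separate_folder can be sketched as a small path function that reproduces the data_root_dir and separate_folder examples given in the parameter list. It is an approximation, not MONAI's implementation: output_path is a hypothetical name, paths are treated as POSIX paths, stripping .gz before the real extension (so image.nii.gz yields image) is an assumption, and patch_index handling is omitted.

```python
import posixpath


def output_path(input_file, output_dir, postfix="seg", ext=".nii.gz",
                data_root_dir="", separate_folder=True):
    """Hypothetical sketch of the documented output naming scheme."""
    file_dir, file_name = posixpath.split(input_file)
    stem, suffix = posixpath.splitext(file_name)
    if suffix == ".gz":  # assumption: treat "image.nii.gz" like "image.nii"
        stem = posixpath.splitext(stem)[0]
    parts = [output_dir]
    if data_root_dir:
        # keep the sub-folder structure below data_root_dir
        rel_dir = posixpath.relpath(file_dir, data_root_dir)
        if rel_dir != ".":
            parts.append(rel_dir)
    if separate_folder:
        parts.append(stem)  # one folder per input file
    return posixpath.join(*parts, f"{stem}_{postfix}{ext}")


print(output_path("/foo/bar/test1/image.nii", "/output", data_root_dir="/foo/bar"))
# /output/test1/image/image_seg.nii.gz
print(output_path("image.nii", "output", ext=".nii", separate_folder=False))
# output/image_seg.nii
```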

save(data, meta_data=None)[source]#

Save data into a NIfTI file. The meta_data could optionally have the following keys:

  • 'filename_or_obj' – for output file name creation, corresponding to filename or object.

  • 'original_affine' – for data orientation handling, defaulting to an identity matrix.

  • 'affine' – for data output affine, defaulting to an identity matrix.

  • 'spatial_shape' – for data output shape.

  • 'patch_index' – if the data is a patch of big image, append the patch index to filename.

When meta_data is specified and resample=True, the saver will try to resample batch data from the space defined by “affine” to the space defined by “original_affine”.

If meta_data is None, use the default index (starting from 0) as the filename.

Parameters
  • data (Union[Tensor, ndarray]) – target data content to be saved as a NIfTI format file. The data shape is assumed to start with a channel dimension followed by spatial dimensions.

  • meta_data (Optional[Dict]) – the metadata information corresponding to the data.

See Also

monai.data.nifti_writer.write_nifti()

Return type

None

save_batch(batch_data, meta_data=None)[source]#

Save a batch of data into NIfTI format files.

Spatially it supports up to three dimensions, that is, H, HW, HWD for 1D, 2D, 3D respectively (resampling is supported for 2D and 3D only).

When saving multiple time steps or multiple channels batch_data, time and/or modality axes should be appended after the batch dimensions. For example, the shape of a batch of 2D eight-class segmentation probabilities to be saved could be (batch, 8, 64, 64); in this case each item in the batch will be saved as (64, 64, 1, 8) NIfTI file (the third dimension is reserved as a spatial dimension).

Parameters
  • batch_data (Union[Tensor, ndarray]) – target batch data content to be saved in NIfTI format.

  • meta_data (Optional[Dict]) – each key-value pair in meta_data holds the values for the whole batch of data.

Return type

None

monai.data.write_nifti(data, file_name, affine=None, target_affine=None, resample=True, output_spatial_shape=None, mode=GridSampleMode.BILINEAR, padding_mode=GridSamplePadMode.BORDER, align_corners=False, dtype=<class 'numpy.float64'>, output_dtype=<class 'numpy.float32'>)[source]#

Write numpy data into NIfTI files to disk. This function converts data into the coordinate system defined by target_affine when target_affine is specified.

If the coordinate transform between affine and target_affine can be achieved by simply transposing and flipping data, no resampling will happen. Otherwise, this function will resample data using the coordinate transform computed from affine and target_affine. Note that the shape of the resampled data may be subject to rounding errors. For example, resampling a 20x20-pixel image from pixel size (1.5, 1.5)-mm to (3.0, 3.0)-mm space will return a 10x10-pixel image. However, resampling a 20x20-pixel image from pixel size (2.0, 2.0)-mm to (3.0, 3.0)-mm space will output a 14x14-pixel image, because the non-integer output size (nominally 13.333x13.333 pixels) has to be rounded to whole pixels. In this case, output_spatial_shape can be specified so that this function writes image data to a designated shape.

The saved affine matrix is determined as follows:

  • If affine equals target_affine, save the data with target_affine.

  • If resample=False, transform affine to new_affine based on the orientation of target_affine and save the data with new_affine.

  • If resample=True, save the data with target_affine; if output_spatial_shape is explicitly specified, the shape of the saved data is not computed from target_affine.

  • If target_affine is None, set target_affine=affine and save.

  • If affine and target_affine are both None, the data will be saved with an identity matrix as the image affine.

This function assumes the NIfTI dimension notations. Spatially it supports up to three dimensions, that is, H, HW, HWD for 1D, 2D, 3D respectively. When saving multiple time steps or multiple channels data, time and/or modality axes should be appended after the first three dimensions. For example, shape of 2D eight-class segmentation probabilities to be saved could be (64, 64, 1, 8). Also, data in shape (64, 64, 8), (64, 64, 8, 1) will be considered as a single-channel 3D image.

Parameters
  • data (Union[ndarray, Tensor]) – input data to write to file.

  • file_name (str) – expected file name that saved on disk.

  • affine (Union[ndarray, Tensor, None]) – the current affine of data. Defaults to np.eye(4).

  • target_affine (Optional[ndarray]) – before saving the (data, affine) as a Nifti1Image, transform the data into the coordinates defined by target_affine.

  • resample (bool) – whether to run resampling when the target affine could not be achieved by swapping/flipping data axes.

  • output_spatial_shape (Union[Sequence[int], ndarray, None]) – spatial shape of the output image. This option is used when resample = True.

  • mode (str) – {"bilinear", "nearest"} This option is used when resample = True. Interpolation mode to calculate output values. Defaults to "bilinear". See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

  • padding_mode (str) – {"zeros", "border", "reflection"} This option is used when resample = True. Padding mode for outside grid values. Defaults to "border". See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

  • align_corners (bool) – Geometrically, we consider the pixels of the input as squares rather than points. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

  • dtype (Union[dtype, type, str, None]) – data type for resampling computation. Defaults to np.float64 for best precision. If None, use the data type of input data.

  • output_dtype (Union[dtype, type, str, None]) – data type for saving data. Defaults to np.float32.

Deprecated since version 0.8: Use monai.data.NibabelWriter() instead.

Return type

None

PNG format handling#

Writing PNG#

class monai.data.PNGSaver(output_dir='./', output_postfix='seg', output_ext='.png', resample=True, mode=InterpolateMode.NEAREST, scale=None, data_root_dir='', separate_folder=True, print_log=True)[source]#

Save the data as a png file. It supports a single data item or a batch of data. Typically, the data are segmentation predictions: call save for a single data item or save_batch to save a batch of data together. The name of the saved file will be {input_image_name}_{output_postfix}{output_ext}, where the input image name is extracted from the provided metadata dictionary. If no metadata is provided, the index (starting from 0) is used as the filename prefix.

Deprecated since version 0.8: Use monai.transforms.SaveImage instead.

__init__(output_dir='./', output_postfix='seg', output_ext='.png', resample=True, mode=InterpolateMode.NEAREST, scale=None, data_root_dir='', separate_folder=True, print_log=True)[source]#
Parameters
  • output_dir (Union[str, PathLike]) – output image directory.

  • output_postfix (str) – a string appended to all output file names.

  • output_ext (str) – output file extension name.

  • resample (bool) – whether to resample and resize if providing spatial_shape in the metadata.

  • mode (str) – {"nearest", "linear", "bilinear", "bicubic", "trilinear", "area"} The interpolation mode. Defaults to "nearest". See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html

  • scale (Optional[int]) – {255, 65535} postprocess data by clipping to [0, 1] and scaling [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.

  • data_root_dir (Union[str, PathLike]) – if not empty, it specifies the beginning part of the input file’s absolute path. It is used to compute input_file_rel_path, the relative path to the file from data_root_dir, to preserve the folder structure when saving in case files in different folders have the same file name. For example, with input_file_name: /foo/bar/test1/image.png, postfix: seg, output_ext: png, output_dir: /output and data_root_dir: /foo/bar, the output will be /output/test1/image/image_seg.png.

  • separate_folder (bool) – whether to save every file in a separate folder. For example, if the input filename is image.png, postfix is seg and folder_path is output: if True, save as output/image/image_seg.png; if False, save as output/image_seg.png. Defaults to True.

  • print_log (bool) – whether to print logs about the saved PNG file path, etc. Defaults to True.

save(data, meta_data=None)[source]#

Save data into a png file. The meta_data could optionally have the following keys:

  • 'filename_or_obj' – for output file name creation, corresponding to filename or object.

  • 'spatial_shape' – for data output shape.

  • 'patch_index' – if the data is a patch of big image, append the patch index to filename.

If meta_data is None, use the default index (starting from 0) as the filename.

Parameters
  • data (Union[Tensor, ndarray]) – target data content to be saved as a png format file. The data is assumed to have shape (C,H,W) with spatial dimensions H and W, where C should be 1, 3 or 4.

  • meta_data (Optional[Dict]) – the metadata information corresponding to the data.

Raises

ValueError – When the number of data channels is not one of [1, 3, 4].
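The channel rule for PNGSaver.save (C must be 1, 3 or 4) lines up with the layouts accepted by write_png: (H,W), (H,W,3) or (H,W,4). A hypothetical helper that validates a (C,H,W) array and moves it into that layout could look like this; squeezing a single channel to (H,W) is an assumption of the sketch, and to_png_layout is not part of MONAI.

```python
import numpy as np


def to_png_layout(data):
    """Hypothetical helper: validate (C,H,W) input and return (H,W) or (H,W,C)."""
    data = np.asarray(data)
    if data.ndim != 3 or data.shape[0] not in (1, 3, 4):
        raise ValueError(f"expected (C,H,W) with C in [1, 3, 4], got shape {data.shape}")
    data = np.moveaxis(data, 0, -1)  # channel-last, as PIL-style arrays expect
    return data[..., 0] if data.shape[-1] == 1 else data  # grayscale becomes (H,W)


print(to_png_layout(np.zeros((1, 8, 8))).shape)  # (8, 8)
print(to_png_layout(np.zeros((3, 8, 8))).shape)  # (8, 8, 3)
```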

See Also

monai.data.png_writer.write_png()

Return type

None

save_batch(batch_data, meta_data=None)[source]#

Save a batch of data into png format files.

Parameters
  • batch_data (Union[Tensor, ndarray]) – target batch data content to be saved in png format.

  • meta_data (Optional[Dict]) – each key-value pair in meta_data holds the values for the whole batch of data.

Return type

None

monai.data.write_png(data, file_name, output_spatial_shape=None, mode=InterpolateMode.BICUBIC, scale=None)[source]#

Write numpy data into png files to disk. Spatially it supports HW for 2D, with input shape (H,W), (H,W,3) or (H,W,4). If scale is None, the input data is expected to be of type np.uint8 or np.uint16. It’s based on the Image module in PIL library: https://pillow.readthedocs.io/en/stable/reference/Image.html

Parameters
  • data (ndarray) – input data to write to file.

  • file_name (str) – expected file name that saved on disk.

  • output_spatial_shape (Optional[Sequence[int]]) – spatial shape of the output image.

  • mode (str) – {"nearest", "nearest-exact", "linear", "bilinear", "bicubic", "trilinear", "area"} The interpolation mode. Defaults to "bicubic". See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html

  • scale (Optional[int]) – {255, 65535} postprocess data by clipping to [0, 1] and scaling to [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.

Raises

ValueError – When scale is not one of [255, 65535].

Deprecated since version 0.8: Use monai.data.PILWriter() instead.

Return type

None

Synthetic#

monai.data.synthetic.create_test_image_2d(width, height, num_objs=12, rad_max=30, rad_min=5, noise_max=0.0, num_seg_classes=5, channel_dim=None, random_state=None)[source]#

Return a noisy 2D image with num_objs circles and a 2D mask image. The maximum and minimum radii of the circles are given as rad_max and rad_min. The mask will have num_seg_classes segmentation classes labeled sequentially from 1, plus a background class represented as 0. If noise_max is greater than 0, noise drawn from the uniform distribution on [0, noise_max) will be added to the image. If channel_dim is None, the image is created without a channel dimension; otherwise the channel dimension is placed first or last.

Parameters
  • width (int) – width of the image. The value should be larger than 2 * rad_max.

  • height (int) – height of the image. The value should be larger than 2 * rad_max.

  • num_objs (int) – number of circles to generate. Defaults to 12.

  • rad_max (int) – maximum circle radius. Defaults to 30.

  • rad_min (int) – minimum circle radius. Defaults to 5.

  • noise_max (float) – if greater than 0 then noise will be added to the image taken from the uniform distribution on range [0,noise_max). Defaults to 0.

  • num_seg_classes (int) – number of classes for segmentations. Defaults to 5.

  • channel_dim (Optional[int]) – if None, create an image without channel dimension, otherwise create an image with channel dimension as first dim or last dim. Defaults to None.

  • random_state (Optional[RandomState]) – the random generator to use. Defaults to np.random.

Return type

Tuple[ndarray, ndarray]

Returns

A tuple of randomised NumPy arrays (image, mask), each with shape (width, height) when channel_dim is None.
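The generation procedure described above can be sketched in NumPy as follows. This is a simplified illustration, not MONAI's implementation: the name make_test_image_2d is hypothetical, each circle is given a random class label in 1..num_seg_classes, the noise is simply added to the labelled image as the description states, and channel_dim is not handled.

```python
import numpy as np


def make_test_image_2d(width, height, num_objs=12, rad_max=30, rad_min=5,
                       noise_max=0.0, num_seg_classes=5, random_state=None):
    """Hypothetical sketch: random labelled circles plus optional uniform noise."""
    if min(width, height) <= 2 * rad_max:
        raise ValueError("width and height should be larger than 2 * rad_max")
    rs = np.random.RandomState() if random_state is None else random_state
    labels = np.zeros((width, height), dtype=np.int32)
    xx, yy = np.ogrid[:width, :height]  # broadcastable coordinate grids
    for _ in range(num_objs):
        # keep each centre at least rad_max away from the border
        cx = rs.randint(rad_max, width - rad_max)
        cy = rs.randint(rad_max, height - rad_max)
        rad = rs.randint(rad_min, rad_max)
        circle = (xx - cx) ** 2 + (yy - cy) ** 2 <= rad * rad
        # class labels start at 1; 0 stays the background
        labels[circle] = rs.randint(1, num_seg_classes + 1)
    image = labels.astype(np.float64)
    if noise_max > 0:
        # uniform noise on [0, noise_max) added to every pixel
        image += rs.uniform(0, noise_max, size=image.shape)
    return image, labels
```

Passing an explicit np.random.RandomState makes the output reproducible, which mirrors the purpose of the random_state parameter.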

monai.data.synthetic.create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, rad_min=5, noise_max=0.0, num_seg_classes=5, channel_dim=None, random_state=None)[source]#

Return a noisy 3D image and segmentation.

Parameters
  • height (int) – height of the image. The value should be larger than 2 * rad_max.

  • width (int) – width of the image. The value should be larger than 2 * rad_max.

  • depth (int) – depth of the image. The value should be larger than 2 * rad_max.

  • num_objs (int) – number of spheres to generate. Defaults to 12.

  • rad_max (int) – maximum sphere radius. Defaults to 30.

  • rad_min (int) – minimum sphere radius. Defaults to 5.

  • noise_max (float) – if greater than 0 then noise will be added to the image taken from the uniform distribution on range [0,noise_max). Defaults to 0.

  • num_seg_classes (int) – number of classes for segmentations. Defaults to 5.

  • channel_dim (Optional[int]) – if None, create an image without channel dimension, otherwise create an image with channel dimension as first dim or last dim. Defaults to None.

  • random_state (Optional[RandomState]) – the random generator to use. Defaults to np.random.

Return type

Tuple[ndarray, ndarray]

Returns

A tuple of randomised NumPy arrays (image, segmentation), each with shape (height, width, depth) when channel_dim is None.

Output folder layout#

class monai.data.folder_layout.FolderLayout(output_dir, postfix='', extension='', parent=False, makedirs=False, data_root_dir='')[source]#

A utility class to create organized filenames within output_dir. The filename method can be used to create a filename that follows the folder structure.

Example:

from monai.data import FolderLayout

layout = FolderLayout(
    output_dir="/test_run_1/",
    postfix="seg",
    extension="nii",
    makedirs=False)
layout.filename(subject="Sub-A", idx="00", modality="T1")
# return value: "/test_run_1/Sub-A_seg_00_modality-T1.nii"

The output filename is a string starting with a subject ID, and includes additional information about a customized index and image modality. This utility class doesn’t alter the underlying image data, but provides a convenient way to create filenames.

__init__(output_dir, postfix='', extension='', parent=False, makedirs=False, data_root_dir='')[source]#
Parameters
  • output_dir (Union[str, PathLike]) – output directory.

  • postfix (str) – a postfix string for output file name appended to subject.

  • extension (str) – output file extension to be appended to the end of an output filename.

  • parent (bool) – whether to add a parent folder level, named after the subject, to contain each output image.

  • makedirs (bool) – whether to create the output parent directories if they do not exist.

  • data_root_dir (Union[str, PathLike]) – an optional PathLike object to preserve the folder structure of the input subject. Please see monai.data.utils.create_file_basename() for more details.

filename(subject='subject', idx=None, **kwargs)[source]#

Create a filename based on the input subject and idx.

The output filename is formed as:

output_dir/[subject/]subject[_postfix][_idx][_key-value][ext]

Parameters
  • subject (Union[str, PathLike]) – subject name, used as the primary id of the output filename. When a PathLike object is provided, the base filename will be used as the subject name, the extension name of subject will be ignored, in favor of extension from this class’s constructor.

  • idx – additional index name of the image.

  • kwargs – additional keyword arguments to be used to form the output filename. The key-value pairs will be appended to the output filename as f"_{k}-{v}".
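The filename pattern above can be reproduced with plain string handling. The helper below is hypothetical (it is not FolderLayout itself), uses posixpath so the result is the same on every platform, and assumes that a trailing .gz is stripped together with the preceding extension, as in the common .nii.gz case.

```python
import posixpath


def layout_filename(output_dir, subject="subject", postfix="", extension="",
                    parent=False, idx=None, **kwargs):
    """Hypothetical helper mirroring the documented pattern:
    output_dir/[subject/]subject[_postfix][_idx][_key-value][ext]
    """
    # a path-like subject contributes only its base name, without its extension
    name, ext = posixpath.splitext(posixpath.basename(str(subject)))
    if ext == ".gz":  # treat ".nii.gz"-style double extensions as one
        name, _ = posixpath.splitext(name)
    stem = name
    if postfix:
        stem += f"_{postfix}"
    if idx is not None:
        stem += f"_{idx}"
    for key, value in kwargs.items():
        stem += f"_{key}-{value}"
    if extension:
        stem += extension if extension.startswith(".") else f".{extension}"
    folder = posixpath.join(output_dir, name) if parent else output_dir
    return posixpath.join(folder, stem)
```

With the arguments from the FolderLayout example, this helper yields "/test_run_1/Sub-A_seg_00_modality-T1.nii".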

Utilities#

monai.data.utils.affine_to_spacing(affine, r=3, dtype=<class 'float'>, suppress_zeros=True)[source]#

Compute the current spacing from the affine matrix.

Parameters
  • affine (~NdarrayTensor) – a d x d affine matrix.

  • r (int) – indexing based on the spatial rank, spacing is computed from affine[:r, :r].

  • dtype – data type of the output.

  • suppress_zeros (bool) – whether to replace zero spacings with ones.

Return type

~NdarrayTensor

Returns

an r dimensional vector of spacing.
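A minimal NumPy sketch of the computation, assuming that the spacing along each axis is the Euclidean norm of the corresponding column of affine[:r, :r] (the usual convention for voxel-to-world affines). spacing_from_affine is a hypothetical name and torch inputs are not handled.

```python
import numpy as np


def spacing_from_affine(affine, r=3, suppress_zeros=True):
    """Sketch: voxel spacing as the column norms of the top-left r x r block."""
    block = np.asarray(affine, dtype=float)[:r, :r]
    spacing = np.sqrt((block * block).sum(axis=0))
    if suppress_zeros:
        spacing[spacing == 0] = 1.0  # replace zero spacing with one
    return spacing
```

Because column norms are used, a rotated affine reports the same spacing as its unrotated counterpart.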

monai.data.utils.compute_importance_map(patch_size, mode=BlendMode.CONSTANT, sigma_scale=0.125, device='cpu')[source]#

Get importance map for different weight modes.

Parameters
  • patch_size (Tuple[int, …]) – Size of the required importance map. This should be either (H, W) or (H, W, D).

  • mode (Union[BlendMode, str]) –

    {"constant", "gaussian"} How to blend output of overlapping windows. Defaults to "constant".

    • "constant": gives equal weight to all predictions.

    • "gaussian": gives less weight to predictions on edges of windows.

  • sigma_scale (Union[Sequence[float], float]) – Sigma_scale to calculate sigma for each dimension (sigma = sigma_scale * dim_size). Used for gaussian mode only.

  • device (Union[device, int, str]) – Device to put importance map on.

Raises

ValueError – When mode is not one of [“constant”, “gaussian”].

Return type

Tensor

Returns

Tensor of size patch_size.
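The difference between the two modes can be illustrated with a simplified NumPy sketch. It builds a separable Gaussian centred in the window with sigma = sigma_scale * dim_size per axis and normalises the peak to 1; MONAI's own implementation returns a torch Tensor on device and may differ in details such as normalisation and clamping of small weights. importance_map_sketch is a hypothetical name.

```python
import numpy as np


def importance_map_sketch(patch_size, mode="constant", sigma_scale=0.125):
    """Simplified sketch of constant vs. Gaussian blending weights."""
    if mode == "constant":
        return np.ones(patch_size)  # equal weight everywhere
    if mode != "gaussian":
        raise ValueError(f'mode should be "constant" or "gaussian", got {mode}')
    scales = sigma_scale if isinstance(sigma_scale, (list, tuple)) else [sigma_scale] * len(patch_size)
    weights = np.ones(patch_size)
    for axis, (size, scale) in enumerate(zip(patch_size, scales)):
        sigma = scale * size  # sigma = sigma_scale * dim_size
        coords = np.arange(size) - (size - 1) / 2.0  # offset from the window centre
        profile = np.exp(-0.5 * (coords / sigma) ** 2)
        shape = [1] * len(patch_size)
        shape[axis] = size
        weights = weights * profile.reshape(shape)  # separable product over axes
    return weights / weights.max()
```

In sliding-window inference, weights like these down-weight predictions near window edges where overlapping windows are blended.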

monai.data.utils.compute_shape_offset(spatial_shape, in_affine, out_affine, scale_extent=False)[source]#

Given the input and output affines, compute appropriate shapes in the output space based on the input array’s shape. This function also returns the offset that places the shape appropriately with respect to the world coordinate system.

Parameters
  • spatial_shape (Union[ndarray, Sequence[int]]) – input array’s shape

  • in_affine (matrix) – 2D affine matrix

  • out_affine (matrix) – 2D affine matrix

  • scale_extent (bool) –

    whether the scale is computed based on the spacing or the full extent of voxels, for example, for a factor of 0.5 scaling:

    option 1, “o” represents a voxel, scaling the distance between voxels:

    o--o--o
    o-----o
    

    option 2, each voxel has a physical extent, scaling the full voxel extent:

    | voxel 1 | voxel 2 | voxel 3 | voxel 4 |
    |      voxel 1      |      voxel 2      |
    

    Option 1 may reduce the number of locations that require interpolation. Option 2 is more resolution agnostic, that is, resampling coordinates depend on the scaling factor, not on the number of voxels. Default is False, using option 1 to compute the shape and offset.

Return type

Tuple[ndarray, ndarray]

monai.data.utils.convert_tables_to_dicts(dfs, row_indices=None, col_names=None, col_types=None, col_groups=None, **kwargs)[source]#

Utility to join pandas tables, select rows and columns, and generate column groups. Returns a list of dictionaries, each of which maps to a row of data in the tables.

Parameters
  • dfs – data table in pandas DataFrame format. If a list of tables is provided, they will be joined.

  • row_indices (Optional[Sequence[Union[str, int]]]) – indices of the expected rows to load. It should be a list where every item is either an int or a range [start, end) of indices, for example: row_indices=[[0, 100], 200, 201, 202, 300]. If None, load all the rows in the file.

  • col_names (Optional[Sequence[str]]) – names of the expected columns to load. If None, load all the columns.

  • col_types (Optional[Dict[str, Optional[Dict[str, Any]]]]) –

    type and default value used to convert the loaded columns; if None, use the original data. It should be a dictionary where every item maps to an expected column: the key is the column name and the value is None or a dictionary defining the default value and data type. The supported keys in the dictionary are ["type", "default"]; note that the value of default should not be None. For example:

    col_types = {
        "subject_id": {"type": str},
        "label": {"type": int, "default": 0},
        "ehr_0": {"type": float, "default": 0.0},
        "ehr_1": {"type": float, "default": 0.0},
    }
    

  • col_groups (Optional[Dict[str, Sequence[str]]]) – arguments to group the loaded columns into new columns. It should be a dictionary where every item maps to a group: the key is the new column name and the value is the names of the columns to combine. For example: col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}

  • kwargs – additional arguments for the pandas.merge() API used to join tables.

Return type

List[Dict[str, Any]]
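To show how row_indices, col_types and col_groups interact, here is a pure-Python sketch operating on a list of row dictionaries instead of pandas DataFrames. The names expand_row_indices and rows_to_dicts are hypothetical; table joining (the pandas.merge kwargs) and col_names selection are left out, and grouped columns are returned as plain lists, which may differ from MONAI's output type.

```python
def expand_row_indices(row_indices):
    """Expand items such as [[0, 3], 5] into [0, 1, 2, 5]; a pair means [start, end)."""
    indices = []
    for item in row_indices:
        if isinstance(item, (list, tuple)):
            start, end = item
            indices.extend(range(start, end))
        else:
            indices.append(item)
    return indices


def rows_to_dicts(rows, row_indices=None, col_types=None, col_groups=None):
    """Select rows, fill defaults and convert types, then group columns."""
    selected = range(len(rows)) if row_indices is None else expand_row_indices(row_indices)
    output = []
    for i in selected:
        row = dict(rows[i])
        for col, spec in (col_types or {}).items():
            if not spec:
                continue  # None keeps the original value
            value = row.get(col)
            if value is None and "default" in spec:
                value = spec["default"]  # fill a missing value
            if value is not None and "type" in spec:
                value = spec["type"](value)
            row[col] = value
        for new_col, cols in (col_groups or {}).items():
            row[new_col] = [row[c] for c in cols]  # combine columns into one
        output.append(row)
    return output
```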

monai.data.utils.correct_nifti_header_if_necessary(img_nii)[source]#

Check the NIfTI object header’s format and update the header if needed, so that pixdim in the updated image matches the affine.

Parameters

img_nii – nifti image object

monai.data.utils.create_file_basename(postfix, input_file_name, folder_path, data_root_dir='', separate_folder=True, patch_index=None, makedirs=True)[source]#

Utility function to create the path to the output file based on the input filename (file name extension is not added by this function). When data_root_dir is not specified, the output file name is:

folder_path/input_file_name (no ext.)/input_file_name (no ext.)[_postfix][_patch_index]

otherwise the relative path with respect to data_root_dir will be inserted, for example:

from monai.data import create_file_basename
create_file_basename(
    postfix="seg",
    input_file_name="/foo/bar/test1/image.png",
    folder_path="/output",
    data_root_dir="/foo/bar",
    separate_folder=True,
    makedirs=False)
# output: /output/test1/image/image_seg
Parameters
  • postfix (str) – output name’s postfix

  • input_file_name (Union[str, PathLike]) – path to the input image file.

  • folder_path (Union[str, PathLike]) – path for the output file

  • data_root_dir (Union[str, PathLike]) – if not empty, it specifies the beginning parts of the input file’s absolute path. This is used to compute input_file_rel_path, the relative path to the file from data_root_dir to preserve folder structure when saving in case there are files in different folders with the same file names.

  • separate_folder (bool) – whether to save every file in a separate folder. For example, if the input filename is image.nii, postfix is seg and folder_path is output: if True, save as output/image/image_seg.nii; if False, save as output/image_seg.nii. Defaults to True.

  • patch_index – if not None, append the patch index to filename.

  • makedirs (bool) – whether to create the folder if it does not exist.

Return type

str

monai.data.utils.decollate_batch(batch, detach=True, pad=True, fill_value=None)[source]#

De-collate a batch of data (for example, as produced by a DataLoader).

Returns a list of structures with the original tensor’s 0-th dimension sliced into elements using torch.unbind.

Images originally stored as (B,C,H,W,[D]) will be returned as (C,H,W,[D]). Other information, such as metadata, may have been stored in a list (or a list inside nested dictionaries). In this case we return the element of the list corresponding to the batch idx.

Return types aren’t guaranteed to be the same as the original, since numpy arrays will have been converted to torch.Tensor, sequences may be converted to lists of tensors, mappings may be converted into dictionaries.

For example:

import torch
from monai.data import decollate_batch

batch_data = {
    "image": torch.rand((2, 1, 10, 10)),
    "image_meta_dict": {"scl_slope": torch.Tensor([0.0, 0.0])},  # metadata for "image"
}
out = decollate_batch(batch_data)
print(len(out))
# 2

print(out[0])
# {'image': tensor([[[4.3549e-01...43e-01]]]), 'image_meta_dict': {'scl_slope': 0.0}}

batch_data = [torch.rand((2, 1, 10, 10)), torch.rand((2, 3, 5, 5))]
out = decollate_batch(batch_data)
print(out[0])
# [tensor([[[4.3549e-01...43e-01]]]), tensor([[[5.3435e-01...45e-01]]])]

batch_data = torch.rand((2, 1, 10, 10))
out = decollate_batch(batch_data)
print(out[0])
# tensor([[[4.3549e-01...43e-01]]])

batch_data = {
    "image": [1, 2, 3], "meta": [4, 5],  # undetermined batch size
}
out = decollate_batch(batch_data, pad=True, fill_value=0)
print(out)
# [{'image': 1, 'meta': 4}, {'image': 2, 'meta': 5}, {'image': 3, 'meta': 0}]
out = decollate_batch(batch_data, pad=False)
print(out)
# [{'image': 1, 'meta': 4}, {'image': 2, 'meta': 5}]
Parameters
  • batch – data to be de-collated.

  • detach (bool) – whether to detach the tensors. Scalar tensors will be converted to Python numbers instead of torch tensors.

  • pad – when the items in a batch indicate different batch sizes, whether to pad all the sequences to the longest one. If False, the batch size will be the length of the shortest sequence.

  • fill_value – when pad is True, the fill value to use when padding. Defaults to None.
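The pad and fill_value behaviour for an undetermined batch size can be reproduced with itertools.zip_longest on a flat dictionary of lists. This is a hypothetical sketch (decollate_dict_sketch is not a MONAI function); the real decollate_batch additionally unbinds tensors and recurses into nested structures.

```python
from itertools import zip_longest


def decollate_dict_sketch(batch, pad=True, fill_value=None):
    """Split a dict of per-key sequences into a list of per-item dicts.

    With pad=True, shorter sequences are padded with fill_value up to the
    longest one; with pad=False, the output stops at the shortest sequence.
    """
    keys = list(batch)
    columns = [batch[k] for k in keys]
    rows = zip_longest(*columns, fillvalue=fill_value) if pad else zip(*columns)
    return [dict(zip(keys, row)) for row in rows]
```

On the {"image": [1, 2, 3], "meta": [4, 5]} batch, this sketch gives the same two results as the decollate_batch example.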

monai.data.utils.dense_patch_slices(image_size, patch_size, scan_interval)[source]#

Enumerate all slices defining ND patches of size patch_size from an image_size input image.

Parameters
  • image_size (Sequence[int]) – dimensions of image to iterate over

  • patch_size (Sequence[int]) – size of patches to generate slices

  • scan_interval (Sequence[int]) – dense patch sampling interval

Return type

List[Tuple[slice, …]]

Returns

a list of slice objects defining each patch
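A pure-Python sketch of dense patch enumeration, under the assumption that patch starts advance by scan_interval along each axis and that the last patch is shifted back so it ends exactly at the image border (a scan_interval of 0 yields a single start at 0). The function names are hypothetical.

```python
from itertools import product


def _dim_starts(size, patch, step):
    """Start offsets along one axis; the last patch is shifted back to end at the border."""
    if step == 0:
        return [0]
    starts, start = [], 0
    while True:
        clamped = min(start, size - patch)
        starts.append(clamped)
        if clamped + patch >= size:
            return starts
        start += step


def dense_patch_slices_sketch(image_size, patch_size, scan_interval):
    per_dim = [_dim_starts(s, p, i) for s, p, i in zip(image_size, patch_size, scan_interval)]
    # product() varies the last axis fastest, i.e. row-major order
    return [tuple(slice(st, st + p) for st, p in zip(starts, patch_size)) for starts in product(*per_dim)]
```

Shifting the final patch back, rather than padding, keeps every patch inside the image at the cost of a larger overlap for the last one.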

monai.data.utils.get_extra_metadata_keys()[source]#

Get a list of unnecessary keys for metadata that can be removed.

Return type

List[str]

Returns

List of keys to be removed.

monai.data.utils.get_random_patch(dims, patch_size, rand_state=None)[source]#

Returns a tuple of slices to define a random patch in an array of shape dims with size patch_size, or as close to it as possible within the given dimension. It is expected that patch_size is a valid patch for a source of shape dims as returned by get_valid_patch_size.

Parameters
  • dims (Sequence[int]) – shape of source array

  • patch_size (Sequence[int]) – shape of patch size to generate

  • rand_state (Optional[RandomState]) – a random state object to generate random numbers from

Returns

a tuple of slice objects defining the patch

Return type

(tuple of slice)

monai.data.utils.get_valid_patch_size(image_size, patch_size)[source]#

Given an image of dimensions image_size, return a patch size tuple taking the dimension from patch_size if this is not 0/None. Otherwise, or if patch_size is shorter than image_size, the dimension from image_size is taken. This ensures the returned patch size is within the bounds of image_size. If patch_size is a single number this is interpreted as a patch of the same dimensionality of image_size with that size in each dimension.

Return type

Tuple[int, …]
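The rules in the description above can be written out directly. This hypothetical sketch follows the documented behaviour (0/None or missing entries use the whole image dimension, a single number applies to every dimension, results are capped at image_size); MONAI's exact handling of edge cases may differ.

```python
def valid_patch_size_sketch(image_size, patch_size):
    ndim = len(image_size)
    if isinstance(patch_size, int):
        patch_size = (patch_size,) * ndim  # single number: same size in every dimension
    # missing trailing entries behave like 0, i.e. "use the whole dimension"
    padded = tuple(patch_size) + (0,) * (ndim - len(patch_size))
    return tuple(min(dim, p or dim) for dim, p in zip(image_size, padded))
```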

monai.data.utils.is_supported_format(filename, suffixes)[source]#

Verify whether the format of the specified file or files matches the supported suffixes. If suffixes is None, skip the verification and return True.

Parameters
  • filename (Union[Sequence[Union[str, PathLike]], str, PathLike]) – file name or a list of file names to read. if a list of files, verify all the suffixes.

  • suffixes (Sequence[str]) – all the supported image suffixes of current reader, must be a list of lower case suffixes.

Return type

bool
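A minimal sketch of the check using pathlib.PurePath.suffixes, assuming a file is accepted when one of the supported suffixes appears in its joined suffix string. The substring match means "nii" also accepts ".nii.gz"; the case-insensitive comparison is an assumption. is_supported_sketch is a hypothetical name.

```python
import os
from pathlib import PurePath


def is_supported_sketch(filename, suffixes):
    """Return True if every file name carries one of the given suffixes."""
    if suffixes is None:
        return True  # nothing to verify against
    names = [filename] if isinstance(filename, (str, os.PathLike)) else list(filename)
    for name in names:
        tokens = "".join(PurePath(name).suffixes).lower()  # e.g. ".nii.gz"
        # substring match, so "nii" also accepts ".nii.gz"
        if not tokens or all(f".{s.lower()}" not in tokens for s in suffixes):
            return False
    return True
```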

monai.data.utils.iter_patch(arr, patch_size=0, start_pos=(), overlap=0.0, copy_back=True, mode=NumpyPadMode.WRAP, **pad_opts)[source]#

Yield successive patches from arr of size patch_size. The iteration can start from position start_pos in arr but drawing from a padded array extended by the patch_size in each dimension (so these coordinates can be negative to start in the padded region). If copy_back is True the values from each patch are written back to arr.

Parameters
  • arr (ndarray) – array to iterate over

  • patch_size (Union[Sequence[int], int]) – size of patches to generate slices for, 0 or None selects whole dimension

  • start_pos (Sequence[int]) – starting position in the array, default is 0 for each dimension

  • overlap (Union[Sequence[float], float]) – the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.

  • copy_back (bool) – if True data from the yielded patches is copied back to arr once the generator completes

  • mode (Optional[str]) – One of the listed string values in monai.utils.NumpyPadMode or monai.utils.PytorchPadMode, or a user supplied function. If None, no wrapping is performed. Defaults to "wrap".

  • pad_opts (Dict) – padding options, see numpy.pad

Yields

Patches of array data from arr which are views into a padded array which can be modified, if copy_back is True these changes will be reflected in arr once the iteration completes.

Note

coordinate format is:

[1st_dim_start, 1st_dim_end, 2nd_dim_start, 2nd_dim_end, …, Nth_dim_start, Nth_dim_end]

monai.data.utils.iter_patch_position(image_size, patch_size, start_pos=(), overlap=0.0, padded=False)[source]#

Yield successive tuples of upper-left corners of patches of size patch_size from an array of dimensions image_size. The iteration starts from position start_pos in the array, or from the origin if this isn’t provided. Patches are chosen on a contiguous grid using row-major ordering.

Parameters
  • image_size (Sequence[int]) – dimensions of array to iterate over

  • patch_size (Union[Sequence[int], int]) – size of patches to generate slices for, 0 or None selects whole dimension

  • start_pos (Sequence[int]) – starting position in the array, default is 0 for each dimension

  • overlap (Union[Sequence[float], float]) – the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.

  • padded (bool) – whether the image is padded so the patches can go beyond the borders. Defaults to False.

Yields

Tuples of positions defining the upper left corner of each patch
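The grid traversal can be sketched with itertools.product over one range per dimension. Assumptions in this hypothetical sketch: patch_size has one entry per dimension, the step is round(patch * (1 - overlap)) (clamped to at least 1 so overlap=1.0 cannot stall), and without padding the last start must leave room for a full patch.

```python
from itertools import product


def iter_patch_position_sketch(image_size, patch_size, start_pos=(), overlap=0.0, padded=False):
    ndim = len(image_size)
    # 0 or None in patch_size selects the whole dimension
    patch = [p if p else s for s, p in zip(image_size, patch_size)]
    start = list(start_pos) + [0] * (ndim - len(start_pos))
    overlaps = list(overlap) if isinstance(overlap, (list, tuple)) else [overlap] * ndim
    # step between neighbouring patches shrinks as the overlap grows
    steps = [max(1, round(p * (1.0 - o))) for p, o in zip(patch, overlaps)]
    # without padding, the last patch must still fit inside the image
    end = list(image_size) if padded else [s - p + 1 for s, p in zip(image_size, patch)]
    return product(*(range(b, e, st) for b, e, st in zip(start, end, steps)))
```

Because product() varies the last dimension fastest, positions come out in row-major order.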

monai.data.utils.iter_patch_slices(image_size, patch_size, start_pos=(), overlap=0.0, padded=True)[source]#

Yield successive tuples of slices defining patches of size patch_size from an array of dimensions image_size. The iteration starts from position start_pos in the array, or from the origin if this isn’t provided. Patches are chosen on a contiguous grid using row-major ordering.

Parameters
  • image_size (Sequence[int]) – dimensions of array to iterate over

  • patch_size (Union[Sequence[int], int]) – size of patches to generate slices for, 0 or None selects whole dimension

  • start_pos (Sequence[int]) – starting position in the array, default is 0 for each dimension

  • overlap (Union[Sequence[float], float]) – the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.

  • padded (bool) – whether the image is padded so the patches can go beyond the borders. Defaults to True.

Yields

Tuples of slice objects defining each patch

Return type

Generator[Tuple[slice, …], None, None]

monai.data.utils.json_hashing(item)[source]#
Parameters

item – data item to be hashed

Returns

the corresponding hash key

Return type

bytes
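A minimal sketch of JSON-based hashing, assuming the item is serialised with sorted keys (so dictionary order does not affect the key) and digested with MD5; the choice of MD5 and the name json_hash_sketch are assumptions, not taken from this page.

```python
import hashlib
import json


def json_hash_sketch(item):
    """Sketch: MD5 hex digest of a key-sorted JSON dump, returned as bytes."""
    text = json.dumps(item, sort_keys=True)  # sort_keys makes dict order irrelevant
    return hashlib.md5(text.encode("utf-8")).hexdigest().encode("utf-8")
```

Such a key is typically used to name cached items; only JSON-serialisable items can be hashed this way.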

monai.data.utils.list_data_collate(batch)[source]#

Enhancement of the PyTorch DataLoader default collate. If the dataset already returns a list of batch data generated by transforms, all the data is first merged into one list; after that, the behavior is the same as the default collate.

Note

Use this collate function when applying transforms that can generate batch data.
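The merging step can be sketched in plain Python. This hypothetical helper only flattens the batch; the merged list would then be handed to PyTorch's default collate. Deciding based on the first element alone is an assumption of the sketch.

```python
def merge_list_batch(batch):
    """Flatten a batch whose items are themselves lists of samples."""
    if batch and isinstance(batch[0], list):
        # e.g. a cropping transform returned several samples per dataset item
        return [sample for item in batch for sample in item]
    return list(batch)
```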

monai.data.utils.no_collation(x)[source]#

No collation operation; the input is returned unchanged.

monai.data.utils.orientation_ras_lps(affine)[source]#

Convert the affine between the RAS and LPS orientation by flipping the first two spatial dimensions.

Parameters

affine (~NdarrayTensor) – a 2D affine matrix.

Return type

~NdarrayTensor
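Flipping the first two spatial dimensions amounts to left-multiplying the affine by a diagonal matrix with -1 for the x and y axes. A NumPy-only sketch (ras_lps_sketch is hypothetical; torch inputs are not handled):

```python
import numpy as np


def ras_lps_sketch(affine):
    """Negate the first two spatial axes (x and y) of a (d+1) x (d+1) affine."""
    affine = np.asarray(affine, dtype=float)
    spatial_rank = max(affine.shape[0] - 1, 1)
    flip = np.ones(spatial_rank + 1)
    flip[: min(2, spatial_rank)] = -1  # a 1D affine only has an x axis to flip
    return np.diag(flip) @ affine
```

Applying the conversion twice returns the original affine, since the flip matrix is its own inverse.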

monai.data.utils.pad_list_data_collate(batch, method=Method.SYMMETRIC, mode=NumpyPadMode.CONSTANT, **kwargs)[source]#

Function version of monai.transforms.croppad.batch.PadListDataCollate.

Same as MONAI’s list_data_collate, except any tensors are centrally padded to match the shape of the biggest tensor in each dimension. This transform is useful if some of the applied transforms generate batch data of different sizes.

This can be used on both list and dictionary data. Note that for dictionary data, this collate function may add the transform information of PadListDataCollate to the list of invertible transforms if the input batch items have different spatial shapes, so the static method monai.transforms.croppad.batch.PadListDataCollate.inverse must be called before inverting other transforms.

Parameters
  • batch (Sequence) – batch of data to pad-collate

  • method (str) – padding method (see monai.transforms.SpatialPad)

  • mode (str) – padding mode (see monai.transforms.SpatialPad)

  • kwargs – other arguments for the np.pad or torch.nn.functional.pad function. Note that np.pad treats the channel dimension as the first dimension.
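The central padding idea can be sketched with numpy.pad on channel-first arrays: each array is padded symmetrically (any odd remainder goes to the end) up to the largest spatial shape, and the results are stacked. pad_to_common_shape is a hypothetical helper, not MONAI's implementation, which also handles dictionaries and records invertible transform information.

```python
import numpy as np


def pad_to_common_shape(arrays, **pad_kwargs):
    """Centrally pad channel-first arrays to the largest spatial shape, then stack them."""
    target = np.max([a.shape[1:] for a in arrays], axis=0)
    padded = []
    for a in arrays:
        width = [(0, 0)]  # never pad the channel dimension
        for size, goal in zip(a.shape[1:], target):
            diff = int(goal - size)
            width.append((diff // 2, diff - diff // 2))  # symmetric split, extra on the end
        padded.append(np.pad(a, width, **pad_kwargs))
    return np.stack(padded)
```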

monai.data.utils.partition_dataset(data, ratios=None, num_partitions=None, shuffle=False, seed=0, drop_last=False, even_divisible=False)[source]#

Split the dataset into N partitions. Shuffling based on a specified random seed is supported. Returns a set of datasets, each containing one partition of the original dataset. The dataset can be split according to the specified ratios, or evenly into num_partitions. Refer to: https://pytorch.org/docs/stable/distributed.html#module-torch.distributed.launch.

Note

It can also be used to partition the dataset across ranks in distributed training. For example, partition the dataset before training and use CacheDataset (or a subclass such as SmartCacheDataset) so that every rank trains with its own data. This avoids duplicated cached content in each rank, but no global shuffle is performed before every epoch:

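import torch.distributed as dist
from monai.data import SmartCacheDataset, partition_dataset

# train_files (a list of data items) and train_transforms are assumed to be defined elsewhere
data_partition = partition_dataset(
    data=train_files,
    num_partitions=dist.get_world_size(),
    shuffle=True,
    even_divisible=True,
)[dist.get_rank()]

train_ds = SmartCacheDataset(
    data=data_partition,
    transform=train_transforms,
    replace_rate=0.2,
    cache_num=15,
)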
Parameters
  • data (Sequence) – input dataset to split, expect a list of data.

  • ratios (Optional[Sequence[float]]) – a list of ratios used to split the dataset, like [8, 1, 1].

  • num_partitions (Optional[int]) – expected number of the partitions to evenly split, only works when ratios not specified.

  • shuffle (bool) – whether to shuffle the original dataset before splitting.

  • seed (int) – random seed to shuffle the dataset, only works when shuffle is True.

  • drop_last (bool) – only used when even_divisible is True and no ratios are specified. If True, drop the tail of the data to make it evenly divisible across partitions. If False, add extra indices to make the data evenly divisible across partitions.

  • even_divisible (bool) – if True, guarantee that every partition has the same length.

Examples:

>>> data = [1, 2, 3, 4, 5]
>>> partition_dataset(data, ratios=[0.6, 0.2, 0.2], shuffle=False)
[[1, 2, 3], [4], [5]]
>>> partition_dataset(data, num_partitions=2, shuffle=False)
[[1, 3, 5], [2, 4]]
>>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=True, drop_last=True)
[[1, 3], [2, 4]]
>>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=True, drop_last=False)
[[1, 3, 5], [2, 4, 1]]
>>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=False, drop_last=False)
[[1, 3, 5], [2, 4]]
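The splitting rules illustrated by these examples can be reproduced with a pure-Python sketch. partition_sketch is hypothetical and omits shuffling; it assumes ratio chunks are taken consecutively with rounded sizes, and that evenly split partitions are interleaved (item i goes to partition i % num_partitions), with drop_last rounding the partition size down and padding by wrap-around otherwise.

```python
import math


def partition_sketch(data, ratios=None, num_partitions=None, drop_last=False, even_divisible=False):
    """Unshuffled sketch of the splitting rules illustrated in the examples."""
    n = len(data)
    indices = list(range(n))
    if ratios:
        # consecutive chunks, each sized by its (rounded) share of the ratio sum
        parts, end, total = [], 0, sum(ratios)
        for r in ratios:
            start, end = end, min(end + int(r / total * n + 0.5), n)
            parts.append([data[i] for i in indices[start:end]])
        return parts
    if not num_partitions:
        raise ValueError("specify ratios or num_partitions")
    if drop_last and not even_divisible:
        raise ValueError("drop_last only applies when even_divisible is True")
    if n < num_partitions:
        raise ValueError("not enough data for the requested number of partitions")
    # drop_last rounds the per-partition size down, otherwise up
    per_part = n // num_partitions if drop_last else math.ceil(n / num_partitions)
    total_size = per_part * num_partitions if even_divisible else n
    if total_size > n:
        indices += indices[: total_size - n]  # wrap around to pad the tail
    else:
        indices = indices[:total_size]  # drop the tail
    # interleaved assignment: item i goes to partition i % num_partitions
    return [[data[j] for j in indices[p:total_size:num_partitions]] for p in range(num_partitions)]
```

Running it on data = [1, 2, 3, 4, 5] reproduces the first four example outputs above.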
monai.data.utils.partition_dataset_classes(data, classes, ratios=None, num_partitions=None, shuffle=False, seed=0, drop_last=False, even_divisible=False)[source]#

Split the dataset into N partitions based on the given class labels, keeping the same class ratio in every partition. All other behaviors are the same as in monai.data.partition_dataset.

Parameters
  • data (Sequence) – input dataset to split, expect a list of data.

  • classes (Sequence[int]) – a list of labels to help split the data, the length must match the length of data.

  • ratios (Optional[Sequence[float]]) – a list of ratios used to split the dataset, like [8, 1, 1].

  • num_partitions (Optional[int]) – expected number of the partitions to evenly split, only used when ratios are not specified.

  • shuffle (bool) – whether to shuffle the original dataset before splitting.

  • seed (int) – random seed to shuffle the dataset, only works when shuffle is True.

  • drop_last (bool) – only used when even_divisible is True and no ratios are specified. If True, drop the tail of the data to make it evenly divisible across partitions. If False, add extra indices to make the data evenly divisible across partitions.

  • even_divisible (bool) – if True, guarantee that every partition has the same length.

Examples:

>>> data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
>>> classes = [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]
>>> partition_dataset_classes(data, classes, shuffle=False, ratios=[2, 1])
[[2, 8, 4, 1, 3, 6, 5, 11, 12], [10, 13, 7, 9, 14]]
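A pure-Python sketch of this stratified split for the ratios case: indices are grouped by class, each class is split consecutively by ratio (rounded chunk sizes), and the per-class chunks are concatenated, visiting classes in sorted order. The names are hypothetical and shuffling is omitted; with the data and classes above it reproduces the example output.

```python
from collections import defaultdict


def _ratio_split(items, ratios):
    """Consecutive chunks sized by each ratio's (rounded) share of the total."""
    parts, end, total = [], 0, sum(ratios)
    for r in ratios:
        start, end = end, min(end + int(r / total * len(items) + 0.5), len(items))
        parts.append(items[start:end])
    return parts


def stratified_partition(data, classes, ratios):
    """Split each class separately, then concatenate the per-class chunks."""
    by_class = defaultdict(list)
    for i, c in enumerate(classes):
        by_class[c].append(i)
    partitions = [[] for _ in ratios]
    for c in sorted(by_class):  # classes are visited in sorted order
        for part, chunk in zip(partitions, _ratio_split(by_class[c], ratios)):
            part.extend(chunk)
    return [[data[i] for i in part] for part in partitions]
```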
monai.data.utils.pickle_hashing(item, protocol=4)[source]#
Parameters
  • item – data item to be hashed

  • protocol – protocol version used for pickling, defaults to pickle.HIGHEST_PROTOCOL.

Returns

the corresponding hash key

Return type

bytes

monai.data.utils.rectify_header_sform_qform(img_nii)[source]#

Check the sform and qform of the NIfTI object and correct them if they are incompatible with the pixel dimensions.

Adapted from https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/io/misc_io.py

Parameters

img_nii – nifti image object

monai.data.utils.remove_extra_metadata(meta)[source]#

Remove extra metadata from the dictionary. Operates in-place so nothing is returned.

Parameters

meta (dict) – dictionary containing metadata to be modified.

Return type

None

Returns

None

monai.data.utils.remove_keys(data, keys)[source]#

Remove keys from a dictionary. Operates in-place so nothing is returned.

Parameters
  • data (dict) – dictionary to be modified.

  • keys (List[str]) – keys to be deleted from dictionary.

Return type

None

Returns

None

monai.data.utils.reorient_spatial_axes(data_shape, init_affine, target_affine)[source]#

Given the input init_affine, compute the orientation transform between it and target_affine by rearranging/flipping the axes.

Returns the orientation transform and the updated affine (a tensor or ndarray, depending on the input affine’s data type). Note that this function requires the external module nibabel.orientations.

Return type

Tuple[ndarray, Union[ndarray, Tensor]]

monai.data.utils.resample_datalist(data, factor, random_pick=False, seed=0)[source]#

Utility function to resample the loaded datalist for training. For example, if factor < 1.0, randomly pick part of the datalist to build the Dataset, which is useful for quickly testing the program; if factor > 1.0, repeat the datalist to enlarge the Dataset.

Parameters
  • data (Sequence) – original datalist to scale.

  • factor (float) – scale factor for the datalist. For example, factor=4.5 repeats the datalist 4 times and appends 50% of the original datalist.

  • random_pick (bool) – whether to randomly pick data if the scale factor has a fractional part.

  • seed (int) – random seed to randomly pick data.
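A sketch of the resampling rule: the integer part of factor gives the number of full repeats and the fractional part selects a rounded share of the datalist. resample_sketch is hypothetical; it uses Python's random.Random(seed) for the random pick (MONAI's own seeding may produce a different selection) and does not deep-copy the repeated items.

```python
import math
import random


def resample_sketch(data, factor, random_pick=False, seed=0):
    """Repeat the list int(factor) times, then append a fraction of it."""
    fraction, repeats = math.modf(factor)
    result = list(data) * int(repeats)  # items are shared, not deep-copied
    if fraction > 1e-6:
        items = list(data)
        if random_pick:
            random.Random(seed).shuffle(items)  # stand-in for a seeded random pick
        result += items[: int(fraction * len(items) + 0.5)]
    return result
```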

monai.data.utils.select_cross_validation_folds(partitions, folds)[source]#

Select cross-validation data based on data partitions and the specified fold index. If a list of fold indices is provided, concatenate the partitions of these folds.

Parameters
  • partitions (Sequence[Iterable]) – a sequence of datasets, each item is an iterable

  • folds (Union[Sequence[int], int]) – the indices of the partitions to be combined.

Return type

List

Returns

A list of combined datasets.

Example:

>>> partitions = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
>>> select_cross_validation_folds(partitions, 2)
[5, 6]
>>> select_cross_validation_folds(partitions, [1, 2])
[3, 4, 5, 6]
>>> select_cross_validation_folds(partitions, [-1, 2])
[9, 10, 5, 6]
monai.data.utils.set_rnd(obj, seed)[source]#

Set seed or random state for all randomizable properties of obj.

Parameters
  • obj – object to set seed or random state for.

  • seed (int) – set the random state with an integer seed.

Return type

int

monai.data.utils.sorted_dict(item, key=None, reverse=False)[source]#

Return a new sorted dictionary from the item.

monai.data.utils.to_affine_nd(r, affine, dtype=<class 'numpy.float64'>)[source]#

Create a new affine matrix from elements of affine by assigning the rotation/zoom/scaling matrix and the translation vector.

When r is an integer, output is an (r+1)x(r+1) matrix, where the top left kxk elements are copied from affine, the last column of the output affine is copied from affine’s last column. k is determined by min(r, len(affine) - 1).

When r is an affine matrix, the output has the same shape as r, and the top left kxk elements are copied from affine, the last column of the output affine is copied from affine’s last column. k is determined by min(len(r) - 1, len(affine) - 1).

Parameters
  • r (int or matrix) – number of spatial dimensions or an output affine to be filled.

  • affine (matrix) – 2D affine matrix

  • dtype – data type of the output array.

Raises
  • ValueError – When affine dimensions is not 2.

  • ValueError – When r is nonpositive.

Return type

~NdarrayTensor

Returns

an (r+1) x (r+1) matrix (tensor or ndarray depends on the input affine data type)
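For the integer-r case, the description above translates into a few NumPy slicing steps: start from an identity matrix, copy the top-left k x k block and the first k entries of the last column, with k = min(r, len(affine) - 1). to_affine_nd_sketch is a hypothetical, NumPy-only sketch.

```python
import numpy as np


def to_affine_nd_sketch(r, affine):
    """Build an (r+1) x (r+1) affine from the top-left block and last column of `affine`."""
    affine = np.asarray(affine, dtype=np.float64)
    if affine.ndim != 2:
        raise ValueError("affine must have 2 dimensions")
    if r <= 0:
        raise ValueError("r must be positive")
    new_affine = np.eye(r + 1)
    k = min(r, len(affine) - 1)
    new_affine[:k, :k] = affine[:k, :k]  # rotation/zoom/scaling part
    new_affine[:k, -1] = affine[:k, -1]  # translation part
    return new_affine
```

Reducing a 3D affine to 2D keeps the x/y scaling and translation; expanding a 2D affine to 3D adds an identity z axis.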

monai.data.utils.worker_init_fn(worker_id)[source]#

Callback function for the PyTorch DataLoader worker_init_fn. It sets a different random seed for the transforms in each worker.

Return type

None

monai.data.utils.zoom_affine(affine, scale, diagonal=True)[source]#

Make the column norms of affine equal to scale. If diagonal is False, return an affine that combines the orthogonal rotation and the new scale: affine is first decomposed, the zoom factors are set to scale, and a new affine is composed with the shearing factors removed. If diagonal is True, return a diagonal matrix whose diagonal elements are the scaling factors. This function always returns an affine with zero translations.

Parameters
  • affine (nxn matrix) – a square matrix.

  • scale (Union[ndarray, Sequence[float]]) – new scaling factor along each dimension. Non-positive components of scale are replaced by the corresponding components of the original pixdim, which is computed from the affine.

  • diagonal (bool) – whether to return a diagonal scaling matrix. Defaults to True.

Raises
  • ValueError – When affine is not a square matrix.

  • ValueError – When scale contains a nonpositive scalar.

Returns

the updated n x n affine.
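The diagonal=True branch can be illustrated without NumPy. The following hypothetical sketch (not MONAI's implementation) computes the original pixdim as the column norms of the top-left rotation/zoom block, lets non-positive or missing scale entries fall back to it, and returns a diagonal matrix with zero translation. The diagonal=False branch, which needs a matrix decomposition, is not covered.

```python
import math
from typing import List, Sequence


def zoom_affine_diagonal_sketch(affine: Sequence[Sequence[float]], scale: Sequence[float]) -> List[List[float]]:
    """Hypothetical illustration of zoom_affine(..., diagonal=True) on nested lists."""
    n = len(affine)
    if any(len(row) != n for row in affine):
        raise ValueError("affine must be a square matrix")
    d = n - 1  # number of spatial dimensions
    # original pixdim: the norm of each column of the top-left d x d rotation/zoom block
    pixdim = [math.sqrt(sum(affine[i][j] ** 2 for i in range(d))) for j in range(d)]
    # non-positive or missing scale components fall back to the original pixdim
    new_scale = [scale[j] if j < len(scale) and scale[j] > 0 else pixdim[j] for j in range(d)]
    diag = new_scale + [1.0]
    # diagonal matrix; the last column carries no translation
    return [[diag[i] if i == j else 0.0 for j in range(n)] for i in range(n)]
```

For example, an affine whose columns have norms 3, 2 and 4, zoomed with scale [1.5, 0, -1], becomes diag(1.5, 2, 4, 1).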

Partition Dataset#

monai.data.partition_dataset(data, ratios=None, num_partitions=None, shuffle=False, seed=0, drop_last=False, even_divisible=False)[source]#

Split the dataset into N partitions, either according to the specified ratios or evenly into num_partitions. Shuffling with a specified random seed is supported. Returns a list of datasets, each containing one partition of the original dataset. Refer to: https://pytorch.org/docs/stable/distributed.html#module-torch.distributed.launch.

Note

It can also be used to partition a dataset across ranks in distributed training. For example, partition the dataset before training and use CacheDataset, so that every rank trains with its own data. This avoids caching duplicated content on each rank, but does not perform a global shuffle before every epoch:

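import torch.distributed as dist

from monai.data import SmartCacheDataset, partition_dataset

# train_files and train_transforms are assumed to be defined elsewhere,
# and the process group must already be initialized
data_partition = partition_dataset(
    data=train_files,
    num_partitions=dist.get_world_size(),
    shuffle=True,
    even_divisible=True,
)[dist.get_rank()]  # keep only this rank's partition

train_ds = SmartCacheDataset(
    data=data_partition,
    transform=train_transforms,
    replace_rate=0.2,
    cache_num=15,
)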
Parameters
  • data (Sequence) – input dataset to split, expect a list of data.

  • ratios (Optional[Sequence[float]]) – a list of ratios used to split the dataset, like [8, 1, 1].

  • num_partitions (Optional[int]) – expected number of partitions for an even split; only used when ratios is not specified.

  • shuffle (bool) – whether to shuffle the original dataset before splitting.

  • seed (int) – random seed to shuffle the dataset, only works when shuffle is True.

  • drop_last (bool) – only used when even_divisible is True and no ratios are specified. If True, drop the tail of the data to make it evenly divisible across partitions; if False, add extra indices to make the data evenly divisible across partitions.

  • even_divisible (bool) – if True, guarantee that every partition has the same length.

Examples:

>>> data = [1, 2, 3, 4, 5]
>>> partition_dataset(data, ratios=[0.6, 0.2, 0.2], shuffle=False)
[[1, 2, 3], [4], [5]]
>>> partition_dataset(data, num_partitions=2, shuffle=False)
[[1, 3, 5], [2, 4]]
>>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=True, drop_last=True)
[[1, 3], [2, 4]]
>>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=True, drop_last=False)
[[1, 3, 5], [2, 4, 1]]
>>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=False, drop_last=False)
[[1, 3, 5], [2, 4]]
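The rules behind these examples can be reproduced with a short pure-Python sketch. partition_sketch is hypothetical and omits shuffling: ratios produce contiguous, rounded chunks, while num_partitions interleaves the items, and even_divisible first pads the index list (by wrapping around) or, with drop_last, truncates it.

```python
import math
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def partition_sketch(
    data: Sequence[T],
    ratios: Optional[Sequence[float]] = None,
    num_partitions: Optional[int] = None,
    drop_last: bool = False,
    even_divisible: bool = False,
) -> List[List[T]]:
    """Hypothetical re-implementation of the split rules shown above (no shuffling)."""
    n = len(data)
    if ratios:
        # contiguous chunks, each sized by its rounded share of the data
        parts: List[List[T]] = []
        end, total = 0, sum(ratios)
        for r in ratios:
            start = end
            end = min(start + int(r / total * n + 0.5), n)
            parts.append(list(data[start:end]))
        return parts
    if not num_partitions or num_partitions > n:
        raise ValueError("need ratios, or 0 < num_partitions <= len(data)")
    k = num_partitions
    indices = list(range(n))
    size = n
    if even_divisible:
        # equal-length partitions: truncate the tail (drop_last) or pad by wrapping around
        size = (n // k if drop_last else math.ceil(n / k)) * k
        indices = indices[:size] if size <= n else indices + indices[: size - n]
    # interleave: partition i takes every k-th index starting from i
    return [[data[j] for j in indices[i:size:k]] for i in range(k)]
```

The interleaved layout is the same one a distributed sampler uses, which is why the num_partitions results match the per-rank examples further below.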

Partition Dataset based on classes#

monai.data.partition_dataset_classes(data, classes, ratios=None, num_partitions=None, shuffle=False, seed=0, drop_last=False, even_divisible=False)[source]#

Split the dataset into N partitions based on the given class labels, keeping the same class ratios in every partition. The other parameters behave as in monai.data.partition_dataset.

Parameters
  • data (Sequence) – input dataset to split, expect a list of data.

  • classes (Sequence[int]) – a list of labels to help split the data, the length must match the length of data.

  • ratios (Optional[Sequence[float]]) – a list of ratios used to split the dataset, like [8, 1, 1].

  • num_partitions (Optional[int]) – expected number of partitions for an even split; only used when ratios is not specified.

  • shuffle (bool) – whether to shuffle the original dataset before splitting.

  • seed (int) – random seed to shuffle the dataset, only works when shuffle is True.

  • drop_last (bool) – only used when even_divisible is True and no ratios are specified. If True, drop the tail of the data to make it evenly divisible across partitions; if False, add extra indices to make the data evenly divisible across partitions.

  • even_divisible (bool) – if True, guarantee that every partition has the same length.

Examples:

>>> data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
>>> classes = [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]
>>> partition_dataset_classes(data, classes, shuffle=False, ratios=[2, 1])
[[2, 8, 4, 1, 3, 6, 5, 11, 12], [10, 13, 7, 9, 14]]
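The example output can be reproduced by assuming that each class is split on its own with the ratio rule and that the per-class pieces are concatenated in ascending class order. The following hypothetical pure-Python sketch (no shuffling, not MONAI's code) does exactly that:

```python
from collections import defaultdict
from typing import Dict, List, Sequence, TypeVar

T = TypeVar("T")


def split_by_ratios(items: Sequence[T], ratios: Sequence[float]) -> List[List[T]]:
    """Contiguous split; each chunk gets its rounded share of len(items)."""
    parts: List[List[T]] = []
    end, total, n = 0, sum(ratios), len(items)
    for r in ratios:
        start = end
        end = min(start + int(r / total * n + 0.5), n)
        parts.append(list(items[start:end]))
    return parts


def partition_by_class_sketch(data: Sequence[T], classes: Sequence[int], ratios: Sequence[float]) -> List[List[T]]:
    """Hypothetical illustration of a class-stratified split."""
    if len(data) != len(classes):
        raise ValueError("classes must have the same length as data")
    by_class: Dict[int, List[int]] = defaultdict(list)
    for i, c in enumerate(classes):
        by_class[c].append(i)  # dataset positions belonging to each class
    partitions: List[List[int]] = [[] for _ in ratios]
    for c in sorted(by_class):  # ascending class order
        for part, chunk in zip(partitions, split_by_ratios(by_class[c], ratios)):
            part.extend(chunk)
    return [[data[i] for i in part] for part in partitions]
```

With ratios=[2, 1], class 0 (items 2, 8, 10) contributes [2, 8] to the first partition and [10] to the second, and so on for the other classes.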

DistributedSampler#

class monai.data.DistributedSampler(dataset, even_divisible=True, num_replicas=None, rank=None, shuffle=True, **kwargs)[source]#

Enhance PyTorch DistributedSampler to support non-evenly divisible sampling.

Parameters
  • dataset (Dataset) – Dataset used for sampling.

  • even_divisible (bool) – if False, different ranks can have different data lengths; for example, with input data [1, 2, 3, 4, 5], rank 0 gets [1, 3, 5] and rank 1 gets [2, 4].

  • num_replicas (Optional[int]) – number of processes participating in distributed training. by default, world_size is retrieved from the current distributed group.

  • rank (Optional[int]) – rank of the current process within num_replicas. by default, rank is retrieved from the current distributed group.

  • shuffle (bool) – if True, sampler will shuffle the indices, default to True.

  • kwargs – additional arguments for DistributedSampler super class, can be seed and drop_last.

For more information about DistributedSampler, see: https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler.

DistributedWeightedRandomSampler#

class monai.data.DistributedWeightedRandomSampler(dataset, weights, num_samples_per_rank=None, generator=None, even_divisible=True, num_replicas=None, rank=None, shuffle=True, **kwargs)[source]#

Extend the DistributedSampler to support weighted sampling. Refer to torch.utils.data.WeightedRandomSampler, for more details please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler.

Parameters
  • dataset (Dataset) – Dataset used for sampling.

  • weights (Sequence[float]) – a sequence of weights, not necessarily summing to one; its length should exactly match the full dataset.

  • num_samples_per_rank (Optional[int]) – number of samples to draw for every rank, sampled from the distributed subset of the dataset. If None, defaults to the length of the dataset split assigned by DistributedSampler.

  • generator (Optional[Generator]) – PyTorch Generator used in sampling.

  • even_divisible (bool) – if False, different ranks can have different data lengths; for example, with input data [1, 2, 3, 4, 5], rank 0 gets [1, 3, 5] and rank 1 gets [2, 4].

  • num_replicas (Optional[int]) – number of processes participating in distributed training. by default, world_size is retrieved from the current distributed group.

  • rank (Optional[int]) – rank of the current process within num_replicas. by default, rank is retrieved from the current distributed group.

  • shuffle (bool) – if True, sampler will shuffle the indices, default to True.

  • kwargs – additional arguments for DistributedSampler super class, can be seed and drop_last.
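Conceptually, each rank first receives its interleaved subset of indices (as in the DistributedSampler example) and then draws from that subset in proportion to weights, assumed here to be with replacement as in WeightedRandomSampler's default. The sketch below is a hypothetical pure-Python illustration; Python's random.Random(seed).choices stands in for sampling with a PyTorch Generator, so the drawn indices will differ from MONAI's.

```python
import math
import random
from typing import List, Optional, Sequence


def rank_indices(n: int, num_replicas: int, rank: int, even_divisible: bool = True) -> List[int]:
    """0-based dataset positions assigned to one rank, interleaved (no shuffling)."""
    indices = list(range(n))
    if even_divisible:
        total = math.ceil(n / num_replicas) * num_replicas
        indices += indices[: total - n]  # pad by wrapping around so all ranks get equal length
    return indices[rank::num_replicas]


def weighted_draw(
    weights: Sequence[float], subset: Sequence[int], num_samples: Optional[int] = None, seed: int = 0
) -> List[int]:
    """Draw positions from this rank's subset with replacement, proportionally to their weights."""
    rng = random.Random(seed)  # stand-in for the PyTorch Generator
    k = len(subset) if num_samples is None else num_samples  # default: size of the rank's subset
    return rng.choices(list(subset), weights=[weights[i] for i in subset], k=k)
```

With data [1, 2, 3, 4, 5] and two ranks, rank 1 owns positions [1, 3]; if weights[1] is 0, only position 3 (the value 4) can ever be drawn on that rank.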

DatasetSummary#

class monai.data.DatasetSummary(dataset, image_key='image', label_key='label', meta_key=None, meta_key_postfix='meta_dict', num_workers=0, **kwargs)[source]#

This class calculates a reasonable output voxel spacing from the input dataset. The resulting values can be used to resample the input in 3D segmentation tasks (for example, as the pixdim parameter of monai.transforms.Spacingd). It can also compute the mean, std, min and max intensities of the input; these statistics are helpful for image normalization (as parameters of monai.transforms.ScaleIntensityRanged and monai.transforms.NormalizeIntensityd).

The calculation follows: Automated Design of Deep Learning Methods for Biomedical Image Segmentation.
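The exact spacing rule is not spelled out here. As a hedged illustration, the sketch below uses the per-axis median spacing (the basic target-spacing rule described in the nnU-Net paper cited above) and plain intensity statistics over all voxels. Both helper names are hypothetical, and any special handling of anisotropic data is omitted.

```python
import statistics
from typing import Dict, List, Sequence, Tuple


def target_spacing_sketch(spacings: Sequence[Sequence[float]]) -> Tuple[float, ...]:
    """Per-axis median voxel spacing over all cases (no anisotropy handling)."""
    return tuple(statistics.median(axis) for axis in zip(*spacings))


def intensity_stats_sketch(images: Sequence[Sequence[float]]) -> Dict[str, float]:
    """Mean, population std, min and max over all voxels of all images (each image a flat list)."""
    voxels: List[float] = [v for img in images for v in img]
    return {
        "mean": statistics.mean(voxels),
        "std": statistics.pstdev(voxels),
        "min": min(voxels),
        "max": max(voxels),
    }
```

For three cases with spacings (1.0, 1.0, 5.0), (0.8, 0.8, 3.0) and (1.2, 1.2, 2.5), the median rule gives (1.0, 1.0, 3.0).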

Decathlon Datalist#

monai.data.load_decathlon_datalist(data_list_file_path, is_segmentation=True, data_list_key='training', base_dir=None)[source]#

Load the image/label paths of the Medical Segmentation Decathlon challenge from a JSON file.

The JSON file is expected to follow the format of the dataset.json files distributed at http://medicaldecathlon.com/.

Parameters
  • data_list_file_path (Union[str, PathLike]) – the path to the json file of datalist.

  • is_segmentation (bool) – whether the datalist is for segmentation task, default is True.

  • data_list_key (str) – the key to get a list of dictionary to be used, default is “training”.

  • base_dir (Union[str, PathLike, None]) – the base directory of the dataset, if None, use the datalist directory.

Raises
  • ValueError – When data_list_file_path does not point to a file.

  • ValueError – When data_list_key is not specified in the data list file.

Returns a list of data items, each of which is a dict keyed by element names, for example:

[
    {'image': '/workspace/data/chest_19.nii.gz',  'label': 0},
    {'image': '/workspace/data/chest_31.nii.gz',  'label': 1}
]
Return type

List[Dict]
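A minimal sketch of the loading logic, assuming relative paths are resolved against base_dir (or the JSON file's own directory) and that only segmentation labels are treated as file paths. load_datalist_sketch is hypothetical, and the real function may handle additional cases.

```python
import json
import os
from typing import Dict, List, Optional


def load_datalist_sketch(
    path: str, is_segmentation: bool = True, data_list_key: str = "training", base_dir: Optional[str] = None
) -> List[Dict]:
    """Hypothetical illustration: read a dataset.json and resolve relative paths."""
    if not os.path.isfile(path):
        raise ValueError(f"data list file {path} does not exist.")
    with open(path) as f:
        json_data = json.load(f)
    if data_list_key not in json_data:
        raise ValueError(f"data list key {data_list_key!r} not found in {path}.")
    base_dir = os.path.dirname(path) if base_dir is None else base_dir
    items = []
    for item in json_data[data_list_key]:
        item = dict(item)
        item["image"] = os.path.normpath(os.path.join(base_dir, item["image"]))
        # segmentation labels are file paths; classification labels (e.g. 0 or 1) are kept as-is
        if is_segmentation and "label" in item:
            item["label"] = os.path.normpath(os.path.join(base_dir, item["label"]))
        items.append(item)
    return items
```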

monai.data.load_decathlon_properties(data_property_file_path, property_keys)[source]#

Load the specified property_keys from a JSON file that contains data properties.

Parameters
  • data_property_file_path (Union[str, PathLike]) – the path to the JSON file of data properties.

  • property_keys (Union[Sequence[str], str]) – keys to load from the JSON file. For example, the Decathlon challenge files contain these keys: name, description, reference, licence, tensorImageSize, modality, labels, numTraining, numTest, etc.

Return type

Dict
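A hypothetical sketch of selecting properties. It accepts a single key or a sequence of keys, and it assumes that a requested key absent from the file should raise an error (the KeyError type is an assumption here):

```python
import json
from typing import Dict, Sequence, Union


def load_properties_sketch(path: str, property_keys: Union[Sequence[str], str]) -> Dict:
    """Hypothetical illustration: pick the requested top-level keys from a dataset.json file."""
    keys = [property_keys] if isinstance(property_keys, str) else list(property_keys)
    with open(path) as f:
        json_data = json.load(f)
    missing = [k for k in keys if k not in json_data]
    if missing:
        raise KeyError(f"keys {missing} are not in the data property file.")
    return {k: json_data[k] for k in keys}
```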

monai.data.check_missing_files(datalist, keys, root_dir=None, allow_missing_keys=False)[source]#

Check whether any files listed in a Decathlon datalist are missing. This is useful to run before a long training job.

Parameters
  • datalist (List[Dict]) – a list of data items, every item is a dictionary. usually generated by load_decathlon_datalist API.

  • keys (Union[Collection[Hashable], Hashable]) – expected keys to check in the datalist.

  • root_dir (Union[str, PathLike, None]) – if not None, provides the root dir for the relative file paths in datalist.

  • allow_missing_keys (bool) – whether to allow missing keys in the datalist items. If False, raise an exception when a key is missing. Defaults to False.

Returns

A list of missing filenames.
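The check can be sketched with os.path.exists. check_missing_files_sketch is hypothetical: it treats each value under keys as a path or a list of paths, joins it with root_dir when given, and raises KeyError (an assumed exception type) for a missing key unless allow_missing_keys is True.

```python
import os
from typing import Dict, Hashable, Iterable, List, Optional, Union


def check_missing_files_sketch(
    datalist: List[Dict],
    keys: Union[Iterable[Hashable], str],
    root_dir: Optional[str] = None,
    allow_missing_keys: bool = False,
) -> List[str]:
    """Hypothetical illustration: collect every path under `keys` that does not exist on disk."""
    keys = [keys] if isinstance(keys, str) else list(keys)
    missing: List[str] = []
    for item in datalist:
        for k in keys:
            if k not in item:
                if allow_missing_keys:
                    continue
                raise KeyError(f"key {k!r} is missing from datalist item {item}")
            paths = item[k] if isinstance(item[k], (list, tuple)) else [item[k]]
            for p in paths:
                full = os.path.join(root_dir, p) if root_dir is not None else p
                if not os.path.exists(full):
                    missing.append(full)
    return missing
```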

monai.data.create_cross_validation_datalist(datalist, nfolds, train_folds, val_folds, train_key='training', val_key='validation', filename=None, shuffle=True, seed=0, check_missing=False, keys=None, root_dir=None, allow_missing_keys=False, raise_error=True)[source]#

Create a new Decathlon-style datalist based on a cross-validation partition.

Parameters
  • datalist (List[Dict]) – loaded list of dictionaries for all the items to partition.

  • nfolds (int) – number of folds for the k-fold split.

  • train_folds (Union[Sequence[int], int]) – indices of folds for training part.

  • val_folds (Union[Sequence[int], int]) – indices of folds for validation part.

  • train_key (str) – the key of train part in the new datalist, defaults to “training”.

  • val_key (str) – the key of validation part in the new datalist, defaults to “validation”.

  • filename (Union[Path, str, None]) – if not None and ends with “.json”, save the new datalist into JSON file.

  • shuffle (bool) – whether to shuffle the datalist before partition, defaults to True.

  • seed (int) – if shuffle is True, set the random seed, defaults to 0.

  • check_missing (bool) – whether to check that all the files specified by keys exist.

  • keys (Union[Collection[Hashable], Hashable, None]) – if not None and check_missing is True, the expected keys to check in the datalist.

  • root_dir (Optional[str]) – if not None, provides the root dir for the relative file paths in datalist.

  • allow_missing_keys (bool) – if check_missing is True, whether to allow missing keys in the datalist items. If False, raise an exception when a key is missing. Defaults to False.

  • raise_error (bool) – when missing files are found, if True, raise an exception and stop; if False, print a warning.
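The fold construction can be sketched in pure Python. This hypothetical sketch omits shuffling, JSON saving and the missing-file check, and it assumes the folds are interleaved the same way partition_dataset splits data with num_partitions=nfolds.

```python
from typing import Dict, List, Sequence, Union


def cross_validation_sketch(
    datalist: List[Dict],
    nfolds: int,
    train_folds: Union[Sequence[int], int],
    val_folds: Union[Sequence[int], int],
    train_key: str = "training",
    val_key: str = "validation",
) -> Dict[str, List[Dict]]:
    """Hypothetical illustration: interleaved folds, then concatenate the selected ones."""
    folds = [datalist[i::nfolds] for i in range(nfolds)]  # fold i takes every nfolds-th item from i
    train_ids = [train_folds] if isinstance(train_folds, int) else list(train_folds)
    val_ids = [val_folds] if isinstance(val_folds, int) else list(val_folds)
    return {
        train_key: [item for f in train_ids for item in folds[f]],
        val_key: [item for f in val_ids for item in folds[f]],
    }
```

Rotating val_folds through 0..nfolds-1 (with the remaining folds as train_folds) yields the usual k-fold setup.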

DataLoader#

class monai.data.DataLoader(dataset, num_workers=0, **kwargs)[source]#

Provides an iterable over the given dataset. It inherits from the PyTorch DataLoader and sets an enhanced collate_fn and worker_init_fn by default.

Although this class could be configured to be the same as torch.utils.data.DataLoader, its default configuration is recommended, mainly for the following extra features:

  • It handles MONAI randomizable objects with appropriate random state management for deterministic behaviour.

  • It is aware of patch-based transforms (such as monai.transforms.RandSpatialCropSamplesDict) that generate multiple samples per input during preprocessing, and collates those samples accordingly. See: monai.transforms.Compose.

For more details about torch.utils.data.DataLoader, please see: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader.

For example, to construct a randomized dataset and iterate with the data loader:

import torch

from monai.data import DataLoader
from monai.transforms import Randomizable


class RandomDataset(torch.utils.data.Dataset, Randomizable):
    # self.R is the random state provided by Randomizable; the default
    # worker_init_fn of monai.data.DataLoader seeds it differently per worker
    def __getitem__(self, index):
        return self.R.randint(0, 1000, (1,))

    def __len__(self):
        return 16


dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
for epoch in range(2):
    for i, batch in enumerate(dataloader):
        print(epoch, i, batch.data.numpy().flatten().tolist())
Parameters
  • dataset (Dataset) – dataset from which to load the data.

  • num_workers (int) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)

  • collate_fn – default to monai.data.utils.list_data_collate().

  • worker_init_fn – default to monai.data.utils.worker_init_fn().

  • kwargs – other parameters for PyTorch DataLoader.
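With a multi-sample transform such as RandSpatialCropSamplesDict, each dataset item becomes a list of patches, and these lists are flattened before the samples are stacked into a batch, so batch_size=2 with 4 patches per item yields 8 samples. A minimal pure-Python sketch of that flattening step, assuming items arrive as Python lists (flatten_patch_samples is a hypothetical name):

```python
from typing import Any, List, Sequence


def flatten_patch_samples(batch: Sequence[Any]) -> List[Any]:
    """If each dataset item is a list of samples (e.g. cropped patches), merge them into one flat batch."""
    if batch and isinstance(batch[0], list):
        return [sample for item in batch for sample in item]
    return list(batch)  # single-sample items pass through unchanged
```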

ThreadBuffer#

class monai.data.ThreadBuffer(src, buffer_size=1, timeout=0.01)[source]#

Iterates over values from self.src in a separate thread but yields them in the current thread. This allows values to be queued up asynchronously. The internal thread keeps running as long as the source has values or until the stop() method is called.

One issue with using a thread in this way is that the source object is being iterated over for the lifetime of the thread, so if the thread has not finished, another attempt to iterate over the source will raise an exception or yield unexpected results. To ensure that the thread releases the iteration and that proper cleanup is done, the stop() method must be called, which joins the thread.

Parameters
  • src – Source data iterable

  • buffer_size (int) – Number of items to buffer from the source

  • timeout (float) – Time to wait for an item from the buffer, or to wait while the buffer is full when adding items
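The mechanism described above can be sketched with the standard library's threading and queue modules. ThreadBufferSketch is a hypothetical, simplified class, not MONAI's implementation: a background thread fills a bounded queue of size buffer_size while the caller's thread consumes it, and stop() joins the thread.

```python
import queue
import threading
from typing import Any, Iterable, Iterator, Optional


class ThreadBufferSketch:
    """Hypothetical, simplified illustration of a thread-backed buffer."""

    def __init__(self, src: Iterable[Any], buffer_size: int = 1, timeout: float = 0.01) -> None:
        self.src = src
        self.buffer: "queue.Queue[Any]" = queue.Queue(buffer_size)
        self.timeout = timeout
        self.is_running = False
        self.thread: Optional[threading.Thread] = None

    def _enqueue_values(self) -> None:
        # runs in the background thread: push items until the source is exhausted or stop() is called
        for item in self.src:
            while self.is_running:
                try:
                    self.buffer.put(item, timeout=self.timeout)
                    break
                except queue.Full:
                    pass  # buffer is full; retry until there is room or we are stopped
            if not self.is_running:
                break
        self.is_running = False  # no more items will be added

    def stop(self) -> None:
        self.is_running = False
        if self.thread is not None:
            self.thread.join()
        self.thread = None

    def __iter__(self) -> Iterator[Any]:
        self.is_running = True
        self.thread = threading.Thread(target=self._enqueue_values, daemon=True)
        self.thread.start()
        try:
            # keep yielding while the producer is alive or items are still buffered
            while self.is_running or not self.buffer.empty():
                try:
                    yield self.buffer.get(timeout=self.timeout)
                except queue.Empty:
                    pass  # nothing ready yet; check again
        finally:
            self.stop()  # always join the thread so the source is released
```

Putting stop() in a finally clause means that even abandoning the iteration early (closing the generator) joins the thread and releases the source.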

ThreadDataLoader#

class monai.data.ThreadDataLoader(dataset, buffer_size=1, buffer_timeout=0.01, repeats=1, use_thread_workers=False, **kwargs)[source]#

Subclass of DataLoader using a ThreadBuffer object to implement __iter__ method asynchronously. This will iterate over data from the loader as expected however the data is generated on a separate thread. Use this class where a DataLoader instance is required and not just an iterable object.

With the default repeats=1, each batch is yielded as it is generated; with a higher value, each generated batch is yielded that many times while the underlying dataset asynchronously generates the next one. Typically not all relevant information is learned from a batch in a single iteration, so training multiple times on the same batch can still produce good training with minimal short-term overfitting, while giving a slow batch generation process more time to produce a result. The duplication is done by simply yielding the same object multiple times, not by regenerating the data.

Another typical use is to accelerate light-weight preprocessing (typically when all the deterministic transforms are cached and there are no I/O operations): the separate thread executes the preprocessing, which avoids unnecessary IPC between multiple DataLoader worker processes. And since CUDA may not work well with the multi-processing of DataLoader, ThreadDataLoader can be useful for GPU transforms. For more details: https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md.

Setting use_thread_workers causes the workers to be created as threads rather than processes; everything else about how the class works is unchanged. This allows multiple workers to be used on Windows, for example, or in any other situation where thread semantics are desired. Note that some MONAI components, such as several datasets and random transforms, are not thread-safe and cannot work as expected with thread workers, so check all the preprocessing components carefully before enabling use_thread_workers.

Parameters
  • dataset (Dataset) – input dataset.

  • buffer_size (int) – number of items to buffer from the data source.

  • buffer_timeout (float) – time to wait for an item from the buffer, or to wait while the buffer is full when adding items.

  • repeats (int) – number of times to yield the same batch.

  • use_thread_workers (bool) – if True and num_workers > 0, the workers are created as threads instead of processes.

  • kwargs – other arguments for DataLoader except for dataset.
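The repeats behaviour described above amounts to a nested loop over already generated batches. repeat_batches is a hypothetical generator that illustrates it, not MONAI's code:

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def repeat_batches(batches: Iterable[T], repeats: int = 1) -> Iterator[T]:
    """Yield every batch `repeats` times; the same object is re-yielded, nothing is regenerated."""
    for batch in batches:
        for _ in range(repeats):
            yield batch
```

Because the same object is re-yielded, an in-place modification of a batch during one training step is also visible in its later repeats.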

TestTimeAugmentation#

class monai.data.TestTimeAugmentation(transform, batch_size, num_workers=0, inferrer_fn=<function _identity>, device='cpu', image_key=CommonKeys.IMAGE, orig_key=CommonKeys.LABEL, nearest_interp=True, orig_meta_keys=None, meta_key_postfix='meta_dict', to_tensor=True, output_device='cpu', post_func=<function _identity>, return_full_data=False, progress=True)[source]#

Class for performing test time augmentations. This will pass the same image through the network multiple times.

The user passes transform(s) to be applied to each realisation, and provided that at least one of those transforms is random, the network’s output will vary. Provided that inverse transformations exist for all supplied spatial transforms, the inverse can be applied to each realisation of the network’s output. Once in the same spatial reference, the results can then be combined and metrics computed.

Test time augmentations are a useful feature for computing network uncertainty, as well as observing the network’s dependency on the applied random transforms.

Reference:

Wang et al., Aleatoric uncertainty estimation with test-time augmentation for medical image segmentation with convolutional neural networks, https://doi.org/10.1016/j.neucom.2019.01.103

Parameters
  • transform (InvertibleTransform) – transform (or composed) to be applied to each realisation. At least one transform must be of type Randomizable. All random transforms must be of type InvertibleTransform.

  • batch_size (int) – number of realisations to infer at once.

  • num_workers (int) – how many subprocesses to use for data loading.

  • inferrer_fn (Callable) – function to use to perform inference.

  • device (Union[str, device]) – device on which to perform inference.

  • image_key – key used to extract image from input dictionary.

  • orig_key – the key of the original input data in the dict. will get the applied transform information for this input data, then invert them for the expected data with image_key.

  • orig_meta_keys (Optional[str]) – the key of the metadata of the original input data, used to get the affine, data_shape, etc. The metadata is a dictionary object which contains: filename, original_shape, etc. If None, will try to construct meta_keys by {orig_key}_{meta_key_postfix}.

  • meta_key_postfix – use key_{postfix} to fetch the metadata according to the key data, default is meta_dict; the metadata is a dictionary object. For example, to handle key image, read/write affine matrices from the metadata image_meta_dict dictionary’s affine field. This arg only works when orig_meta_keys=None.

  • to_tensor (bool) – whether to convert the inverted data into PyTorch Tensor first, default to True.

  • output_device (Union[str, device]) – if converted the inverted data to Tensor, move the inverted results to target device before post_func, default to “cpu”.

  • post_func (Callable) – post processing for the inverted data, should be a callable function.

  • return_full_data (bool) – normally, metrics are returned (mode, mean, std, vvc). Setting this flag to True returns the full data instead. Its dimensions will be the same as when passing a single image through inferrer_fn, with an extra leading dimension of size num_examples (N), i.e., [N,C,H,W,[D]].

  • progress (bool) – whether to display a progress bar.

Example

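from monai.data import TestTimeAugmentation
from monai.networks.nets import UNet
from monai.transforms import Compose, RandAffined

# device, keys and test_data are assumed to be defined elsewhere
model = UNet(...).to(device)
transform = Compose([RandAffined(keys, ...), ...])
transform.set_random_state(seed=123)  # ensure deterministic evaluation

tt_aug = TestTimeAugmentation(
    transform, batch_size=5, num_workers=0, inferrer_fn=model, device=device
)
mode, mean, std, vvc = tt_aug(test_data)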
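The per-voxel summaries returned above are computed across the N realisations. A hedged pure-Python sketch of the mode, mean and (sample) standard deviation, assuming each realisation is a flattened list of integer labels; tta_summary_sketch is hypothetical, works on lists rather than tensors, and leaves out vvc.

```python
import statistics
from typing import Dict, List, Sequence


def tta_summary_sketch(realisations: Sequence[Sequence[int]]) -> Dict[str, List[float]]:
    """Per-voxel mode, mean and sample std across N realisations (each a flat list of voxel labels)."""
    per_voxel = list(zip(*realisations))  # group the N predictions of each voxel together
    return {
        "mode": [statistics.mode(v) for v in per_voxel],
        "mean": [statistics.mean(v) for v in per_voxel],
        "std": [statistics.stdev(v) for v in per_voxel],
    }
```

Voxels on which the augmented predictions disagree get a non-zero std, which is what makes the output usable as an uncertainty map.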

N-Dim Fourier Transform#

monai.data.fft_utils.fftn_centered(im, spatial_dims, is_complex=True)[source]#

PyTorch-based FFT for spatial_dims-dimensional signals. “Centered” means this function automatically takes care of the required ifft and fft shifts. This function calls monai.networks.blocks.fft_utils_t.fftn_centered_t. It is equivalent to a centered FFT in NumPy built from numpy.fft.fftn, numpy.fft.fftshift, and numpy.fft.ifftshift.

Parameters
  • im (Union[ndarray, Tensor]) – image that can be 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels.

  • spatial_dims (int) – number of spatial dimensions (e.g., 2 for an image and 3 for a volume)

  • is_complex (bool) – if True, then the last dimension of the input im is expected to be 2 (representing real and imaginary channels)

Return type

Union[ndarray, Tensor]

Returns

out, the output k-space (the Fourier transform of im)

Example

import torch

from monai.data.fft_utils import fftn_centered

im = torch.ones(1, 3, 3, 2)  # the last dim holds the real/imaginary parts
# output1 and output2 will be identical
output1 = torch.fft.fftn(torch.view_as_complex(torch.fft.ifftshift(im, dim=(-3, -2))), dim=(-2, -1), norm="ortho")
output1 = torch.fft.fftshift(torch.view_as_real(output1), dim=(-3, -2))

output2 = fftn_centered(im, spatial_dims=2, is_complex=True)
assert torch.allclose(output1, output2)
monai.data.fft_utils.ifftn_centered(ksp, spatial_dims, is_complex=True)[source]#

PyTorch-based inverse FFT for spatial_dims-dimensional signals. “Centered” means this function automatically takes care of the required ifft and fft shifts. This function calls monai.networks.blocks.fft_utils_t.ifftn_centered_t. It is equivalent to a centered inverse FFT in NumPy built from numpy.fft.ifftn, numpy.fft.fftshift, and numpy.fft.ifftshift.

Parameters
  • ksp (Union[ndarray, Tensor]) – k-space data that can be 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels.

  • spatial_dims (int) – number of spatial dimensions (e.g., 2 for an image and 3 for a volume)

  • is_complex (bool) – if True, then the last dimension of the input ksp is expected to be 2 (representing real and imaginary channels)

Return type

Union[ndarray, Tensor]

Returns

out, the output image (the inverse Fourier transform of ksp)

Example

import torch

from monai.data.fft_utils import ifftn_centered

ksp = torch.ones(1, 3, 3, 2)  # the last dim holds the real/imaginary parts
# output1 and output2 will be identical
output1 = torch.fft.ifftn(torch.view_as_complex(torch.fft.ifftshift(ksp, dim=(-3, -2))), dim=(-2, -1), norm="ortho")
output1 = torch.fft.fftshift(torch.view_as_real(output1), dim=(-3, -2))

output2 = ifftn_centered(ksp, spatial_dims=2, is_complex=True)
assert torch.allclose(output1, output2)
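The NumPy equivalents mentioned in the two docstrings above can be sketched as follows. This assumes NumPy complex arrays rather than a trailing real/imaginary dimension of size 2, and the helper names are hypothetical:

```python
import numpy as np


def fftn_centered_np(im: np.ndarray, spatial_dims: int) -> np.ndarray:
    """Centered, orthonormal FFT over the last `spatial_dims` axes."""
    axes = tuple(range(-spatial_dims, 0))
    return np.fft.fftshift(np.fft.fftn(np.fft.ifftshift(im, axes=axes), axes=axes, norm="ortho"), axes=axes)


def ifftn_centered_np(ksp: np.ndarray, spatial_dims: int) -> np.ndarray:
    """Centered, orthonormal inverse FFT over the last `spatial_dims` axes."""
    axes = tuple(range(-spatial_dims, 0))
    return np.fft.fftshift(np.fft.ifftn(np.fft.ifftshift(ksp, axes=axes), axes=axes, norm="ortho"), axes=axes)
```

The shifts place the zero frequency at the array centre: for a 3x3 image of ones, all energy (a value of 3 with norm="ortho") lands in the middle element of k-space. Applying ifftn_centered_np after fftn_centered_np recovers the input.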

Meta Object#

class monai.data.meta_obj.MetaObj[source]#

Abstract base class that stores data as well as any extra metadata.

This allows for subclassing torch.Tensor and np.ndarray through multiple inheritance.

Metadata is stored in the form of a dictionary.

Behavior should be the same as extended class (e.g., torch.Tensor or np.ndarray) aside from the extended meta functionality.

Copying of information:

  • For c = a + b, auxiliary data (e.g., metadata) is copied from the first instance of MetaObj if a.is_batch is False. (For batched data, the metadata is shallow copied for efficiency.)
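The copying rule can be illustrated with a toy type. MetaFloatSketch is hypothetical and unrelated to MONAI's classes: it is a float that carries a meta dictionary, and an arithmetic result takes a shallow copy of the metadata of the first metadata-carrying operand, found by recursively flattening the inputs (in the spirit of flatten_meta_objs).

```python
from typing import Any, Dict, Iterator, Optional


class MetaFloatSketch(float):
    """Hypothetical toy: a float that carries a metadata dictionary."""

    meta: Dict[str, Any]

    def __new__(cls, value: float, meta: Optional[Dict[str, Any]] = None) -> "MetaFloatSketch":
        obj = super().__new__(cls, value)
        obj.meta = dict(meta or {})  # shallow copy of the given metadata
        return obj

    def __add__(self, other: float) -> "MetaFloatSketch":
        source = first_meta_obj(self, other)
        return MetaFloatSketch(float(self) + float(other), source.meta if source is not None else None)

    __radd__ = __add__  # 5.0 + m also keeps m's metadata


def flatten(*args: Any) -> Iterator[Any]:
    """Recursively flatten lists/tuples of inputs."""
    for a in args:
        if isinstance(a, (list, tuple)):
            yield from flatten(*a)
        else:
            yield a


def first_meta_obj(*args: Any) -> Optional[MetaFloatSketch]:
    """Return the first metadata-carrying object among the (flattened) inputs."""
    return next((a for a in flatten(*args) if isinstance(a, MetaFloatSketch)), None)
```

With a carrying {"id": "a"} and b carrying {"id": "b"}, a + b carries {"id": "a"} in a new dictionary, so later edits to one object's metadata do not leak into the other.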

property applied_operations: list[dict]#

Get the applied operations. Defaults to [].

static copy_items(data)[source]#

Return a copy of the data. Lists and dicts are shallow copied for efficiency.

copy_meta_from(input_objs, copy_attr=True)[source]#

Copy metadata from a MetaObj or an iterable of MetaObj instances.

Parameters
  • input_objs – list of MetaObj to copy data from.

  • copy_attr – whether to copy each attribute with MetaObj.copy_items. Note that if the attribute is a nested list or dict, only a shallow copy will be done.

Return type

None

static flatten_meta_objs(*args)[source]#

Recursively flatten input and yield all instances of MetaObj. This means that for both torch.add(a, b), torch.stack([a, b]) (and their numpy equivalents), we return [a, b] if both a and b are of type MetaObj.

Parameters

args (Iterable) – Iterables of inputs to be flattened.

Returns

list of nested MetaObj from input.

static get_default_applied_operations()[source]#

Get the default applied operations.

Return type

list

Returns

default applied operations.

static get_default_meta()[source]#

Get the default meta.

Return type

dict

Returns

default metadata.

property is_batch: bool#

Return whether object is part of batch or not.

Return type

bool

property meta: dict#

Get the meta. Defaults to {}.

Return type

dict

monai.data.meta_obj.get_track_meta()[source]#

Return whether metadata is tracked. If True, metadata will be associated with its data by using subclasses of MetaObj. If False, data will be returned with empty metadata.

If metadata tracking has been disabled with set_track_meta(False), standard data objects (e.g., torch.Tensor and np.ndarray) are returned instead of MONAI’s enhanced objects.

By default, this is True, and most users will want to leave it this way. However, if you are experiencing any problems regarding metadata, and aren’t interested in preserving metadata, then you can disable it.

Return type

bool

monai.data.meta_obj.set_track_meta(val)[source]#

Set whether metadata is tracked. If True, metadata will be associated with its data by using subclasses of MetaObj. If False, data will be returned with empty metadata.

If val is False, standard data objects (e.g., torch.Tensor and np.ndarray) are returned instead of MONAI’s enhanced objects.

By default, this is True, and most users will want to leave it this way. However, if you are experiencing any problems regarding metadata, and aren’t interested in preserving metadata, then you can disable it.

Return type

None

MetaTensor#

class monai.data.MetaTensor(x, affine=None, meta=None, applied_operations=None, *_args, **_kwargs)[source]#

Bases: MetaObj, Tensor

Class that inherits from both torch.Tensor and MetaObj, adding support for metadata.

Metadata is stored in the form of a dictionary. The affine matrix is stored nested within this dictionary (as meta["affine"]) and should be a torch.Tensor.

Behavior should be the same as torch.Tensor aside from the extended meta functionality.

Copying of information:

  • For c = a + b, auxiliary data (e.g., metadata) is copied from the first instance of MetaTensor if a.is_batch is False. (For batched data, the metadata is shallow copied for efficiency.)

Example

import torch
from monai.data import MetaTensor

t = torch.tensor([1,2,3])
affine = torch.as_tensor([[2,0,0,0],
                          [0,2,0,0],
                          [0,0,2,0],
                          [0,0,0,1]], dtype=torch.float64)
meta = {"some": "info"}
m = MetaTensor(t, affine=affine, meta=meta)
m2 = m + m
assert isinstance(m2, MetaTensor)
assert m2.meta["some"] == "info"
assert torch.all(m2.affine == affine)

Notes

  • Requires PyTorch 1.9 or newer for full compatibility.

  • With older versions of PyTorch (<=1.8), torch.jit.trace(net, im) may not work if im is of type MetaTensor. This can be resolved with torch.jit.trace(net, im.as_tensor()).

  • For PyTorch < 1.8, sharing MetaTensor instances across processes may not be supported.

  • For PyTorch < 1.9, next(iter(meta_tensor)) returns a torch.Tensor. See: https://github.com/pytorch/pytorch/issues/54457

  • A warning will be raised if in the constructor affine is not None and meta already contains the key affine.

  • You can query whether the MetaTensor is a batch with the is_batch attribute.

  • With a batch of data, batch[0] will return the 0th image with the 0th metadata. When the batch dimension is non-singleton, e.g., batch[:, 0], batch[..., -1] and batch[1:3], then all (or a subset in the last example) of the metadata will be returned, and is_batch will return True.

  • When creating a batch with this class, use monai.data.DataLoader as opposed to torch.utils.data.DataLoader, as this will take care of collating the metadata properly.

__init__(x, affine=None, meta=None, applied_operations=None, *_args, **_kwargs)[source]#
Parameters
  • x – initial array for the MetaTensor. Can be a list, tuple, NumPy ndarray, scalar, and other types.

  • affine – optional 4x4 array.

  • meta – dictionary of metadata.

  • applied_operations – list of previously applied operations on the MetaTensor, the list is typically maintained by monai.transforms.TraceableTransform. See also: monai.transforms.TraceableTransform

  • _args – additional args (currently not in use in this constructor).

  • _kwargs – additional kwargs (currently not in use in this constructor).

Note

If a meta dictionary is given, it is used. Otherwise, if meta exists in the input tensor x, that is used; otherwise, the default value is used. The affine is resolved similarly, except that it can come from four places, in priority order: affine, meta["affine"], x.affine, get_default_affine.
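The affine lookup order can be expressed as a simple fallback chain. resolve_affine is a hypothetical pure-Python helper that uses nested lists instead of tensors and a 4x4 identity as the default:

```python
from typing import Any, Dict, List, Optional

Matrix = List[List[float]]


def default_affine(n: int = 4) -> Matrix:
    """Identity matrix, mirroring the documented default of torch.eye(4)."""
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


def resolve_affine(
    affine: Optional[Matrix] = None,
    meta: Optional[Dict[str, Any]] = None,
    x_affine: Optional[Matrix] = None,
) -> Matrix:
    """Pick the affine by the documented priority: affine, meta["affine"], x.affine, then the default."""
    if affine is not None:
        return affine
    if meta is not None and meta.get("affine") is not None:
        return meta["affine"]
    if x_affine is not None:
        return x_affine
    return default_affine()
```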

property affine: Tensor#

Get the affine. Defaults to torch.eye(4, dtype=torch.float64)

Return type

Tensor

property array#

Returns a numpy array of self. The array and self share the same underlying storage if self is on CPU, so changes to self (a subclass of torch.Tensor) are reflected in the ndarray and vice versa. If self is not on CPU, the data is moved to CPU and the storage is not shared.

Getter

see also: MetaTensor.get_array()

Setter

see also: MetaTensor.set_array()

as_dict(key, output_type=<class 'torch.Tensor'>, dtype=None)[source]#

Get the object as a dictionary for backwards compatibility. This method does not make a deep copy of the objects.

Parameters
  • key (str) – Base key to store main data. The key for the metadata will be determined using PostFix.

  • output_type – torch.Tensor or np.ndarray for the main data.

  • dtype – dtype of output data. Converted to correct library type (e.g., np.float32 is converted to torch.float32 if output type is torch.Tensor). If left blank, it remains unchanged.

Return type

dict

Returns

A dictionary with three keys: the main data (stored under key), the metadata, and the applied operations.

as_tensor()[source]#

Return the MetaTensor as a torch.Tensor. Whether this is a deep copy is OS dependent.

Return type

Tensor

astype(dtype, device=None, *_args, **_kwargs)[source]#

Cast to dtype, sharing data whenever possible.

Parameters
  • dtype – dtypes such as np.float32, torch.float, "np.float32", float.

  • device – the device if dtype is a torch data type.

  • _args – additional args (currently unused).

  • _kwargs – additional kwargs (currently unused).

Returns

data array instance

clone()[source]#

Return a copy of the MetaTensor instance.

static ensure_torch_and_prune_meta(im, meta, simple_keys=False, pattern=None, sep='.')[source]#

Convert the image to torch.Tensor. If affine is in the meta dictionary, convert that to torch.Tensor, too. Remove any superfluous metadata.

Parameters
  • im – Input image (np.ndarray or torch.Tensor)

  • meta – Metadata dictionary.

  • simple_keys – whether to keep only a simple subset of metadata keys.

  • pattern – combined with sep, a regular expression used to match and prune keys in the metadata (nested dictionary). Defaults to None (no keys are deleted).

  • sep – combined with pattern, used to match and delete keys in the metadata (nested dictionary). Defaults to ".". See also monai.transforms.DeleteItemsd. For example, pattern=".*_code$", sep=" " removes any metadata keys that end with "_code".
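To see how pattern and sep can interact, the pure-Python sketch below assumes (following the DeleteItemsd convention of nested keys written as "A{sep}B{sep}C") that sep splits pattern into one regular expression per nesting level, and deletes the keys matched by the last one. The helper prune_meta is hypothetical and simplified; MONAI's actual implementation may differ in details such as how intermediate keys are matched.

```python
import re
from typing import Any, Dict


def prune_meta(meta: Dict[str, Any], pattern: str, sep: str = ".") -> Dict[str, Any]:
    """Hypothetical sketch: delete keys of a nested dict that match `pattern`.

    `pattern` is split by `sep` into one regex per nesting level; keys matching
    the last regex are deleted, earlier regexes select which sub-dicts to enter.
    """
    levels = pattern.split(sep)

    def _prune(d: Dict[str, Any], depth: int) -> None:
        for key in list(d):  # copy the keys so we can delete while iterating
            if re.fullmatch(levels[depth], key) is None:
                continue
            if depth == len(levels) - 1:
                del d[key]
            elif isinstance(d[key], dict):
                _prune(d[key], depth + 1)

    _prune(meta, 0)
    return meta


meta = {"qform_code": 1, "sform_code": 1, "spacing": [1.0, 1.0], "extra": {"unit_code": 3, "unit": "mm"}}
prune_meta(meta, ".*_code$", sep=" ")  # top-level keys ending with "_code" are removed
print(meta)  # {'spacing': [1.0, 1.0], 'extra': {'unit_code': 3, 'unit': 'mm'}}
prune_meta(meta, "extra.unit_code")    # nested key addressed with the default sep
print(meta)  # {'spacing': [1.0, 1.0], 'extra': {'unit': 'mm'}}
```

The first call also suggests why the documented example passes sep=" ": with the default ".", the pattern ".*_code$" would be split at its leading dot.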

Returns

By default, a MetaTensor is returned. However, if get_track_meta() is False, a torch.Tensor is returned.

get_array(output_type=<class 'numpy.ndarray'>, dtype=None, device=None, *_args, **_kwargs)[source]#

Returns a new array in output_type. When the output is a numpy array, it shares the same underlying storage as self, so changes to the self tensor are reflected in the ndarray and vice versa.

Parameters
  • output_type – output type, see also: monai.utils.convert_data_type().

  • dtype – dtype of output data. Converted to correct library type (e.g., np.float32 is converted to torch.float32 if output type is torch.Tensor). If left blank, it remains unchanged.

  • device – if the output is a torch.Tensor, select device (if None, unchanged).

  • _args – currently unused parameters.

  • _kwargs – currently unused parameters.

new_empty(size, dtype=None, device=None, requires_grad=False)[source]#

Must be defined for deepcopy to work.

See:
property pixdim#

Get the spacing

set_array(src, non_blocking=False, *_args, **_kwargs)[source]#

Copies the elements from src into self tensor and returns self. The src tensor must be broadcastable with the self tensor. It may be of a different data type or reside on a different device.

See also: https://pytorch.org/docs/stable/generated/torch.Tensor.copy_.html

Parameters
  • src – the source tensor to copy from.

  • non_blocking – if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect.

  • _args – currently unused parameters.

  • _kwargs – currently unused parameters.

static update_meta(rets, func, args, kwargs)[source]#

Update the metadata from the output of MetaTensor.__torch_function__.

The output of torch.Tensor.__torch_function__ could be a single object or a sequence of them. Hence, in MetaTensor.__torch_function__ we convert them to a list if they are not one already, and then loop over each element, processing metadata as necessary. Elements that are not of type MetaTensor are left unchanged.

Parameters
  • rets (Sequence) – the output from torch.Tensor.__torch_function__, which has been converted to a list in MetaTensor.__torch_function__ if it wasn’t already a Sequence.

  • func – the torch function that was applied. Examples might be torch.squeeze or torch.Tensor.__add__. We need this because the metadata must be treated differently if a batch of data is considered. For example, slicing (torch.Tensor.__getitem__) the ith element of the 0th dimension of a batch of data should return the ith tensor with the ith metadata.

  • args – positional arguments that were passed to func.

  • kwargs – keyword arguments that were passed to func.

Return type

Sequence

Returns

A sequence with the same number of elements as rets. For each element, if the input type was not MetaTensor, then no modifications will have been made. If global parameters have been set to false (e.g., not get_track_meta()), then any MetaTensor will be converted to torch.Tensor. Else, metadata will be propagated as necessary (see MetaTensor._copy_meta()).

Whole slide image reader#

BaseWSIReader#

class monai.data.BaseWSIReader(level=0, channel_dim=0, **kwargs)[source]#

An abstract class that defines APIs to load patches from whole slide image files.

Typical usage of a concrete implementation of this class is:

image_reader = MyWSIReader()
wsi = image_reader.read(filepath, **kwargs)  # filepath: placeholder for the path to a whole slide image file
img_data, meta_data = image_reader.get_data(wsi)
  • The read call converts an image filename into a whole slide image object.

  • The get_data call fetches the image data, as well as metadata.

The following methods need to be implemented by any concrete implementation of this class:

  • read reads a whole slide image object from a given file

  • get_size returns the size of the whole slide image of a given wsi object at a given level.

  • get_level_count returns the number of levels in the whole slide image

  • _get_patch extracts and returns a patch image from the whole slide image

  • _get_metadata extracts and returns metadata for a whole slide image and a specific patch.
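To make the division of labour concrete, the sketch below implements these methods for a hypothetical in-memory image pyramid. It deliberately does not inherit from BaseWSIReader, so it runs without MONAI installed; the class ToyPyramidReader and its simplified signatures are illustrative assumptions, and the real abstract methods may take additional arguments (for example dtype and mode).

```python
from typing import Dict, List, Optional, Tuple

Image = List[List[int]]          # single-channel image stored as rows of pixel values
Slide = Tuple[str, List[Image]]  # (name, pyramid); pyramid[0] is full resolution


class ToyPyramidReader:
    """Hypothetical in-memory reader illustrating the methods listed above."""

    def __init__(self, store: Dict[str, List[Image]], level: int = 0):
        self.store = store
        self.level = level

    def read(self, name: str) -> Slide:
        # Convert a "file name" into a whole slide image object.
        return name, self.store[name]

    def get_level_count(self, wsi: Slide) -> int:
        return len(wsi[1])

    def get_size(self, wsi: Slide, level: Optional[int] = None) -> Tuple[int, int]:
        level = self.level if level is None else level
        img = wsi[1][level]
        return len(img), len(img[0])  # (height, width)

    def _get_patch(self, wsi: Slide, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Image:
        # `location` is in the level-0 frame; convert it to this level's pixels.
        ratio = self.get_size(wsi, 0)[0] // self.get_size(wsi, level)[0]
        top, left = location[0] // ratio, location[1] // ratio
        rows = wsi[1][level][top:top + size[0]]
        return [row[left:left + size[1]] for row in rows]

    def _get_metadata(self, wsi: Slide, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict:
        return {"path": wsi[0], "location": location, "size": size, "level": level}

    def get_data(self, wsi: Slide, location=(0, 0), size=None, level=None):
        # Simplified stand-in for the base-class method that ties the pieces together.
        level = self.level if level is None else level
        size = self.get_size(wsi, level) if size is None else size
        return self._get_patch(wsi, location, size, level), self._get_metadata(wsi, location, size, level)


level0 = [[4 * r + c for c in range(4)] for r in range(4)]  # 4x4 full-resolution image
level1 = [row[::2] for row in level0[::2]]                  # 2x2, every other pixel
reader = ToyPyramidReader({"slide": [level0, level1]})
wsi = reader.read("slide")
patch, meta = reader.get_data(wsi, location=(2, 2), size=(1, 1), level=1)
print(patch, meta["level"])  # [[10]] 1
```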

get_data(wsi, location=(0, 0), size=None, level=None, dtype=<class 'numpy.uint8'>, mode='RGB')[source]#

Verifies inputs, extracts patches from the WSI image, generates metadata, and returns them.

Parameters
  • wsi – a whole slide image object loaded from a file or a list of such objects

  • location (Tuple[int, int]) – (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).

  • size (Optional[Tuple[int, int]]) – (height, width) tuple giving the patch size at the given level (level). If None, it is set to the full image size at the given level.

  • level (Optional[int]) – the level number. Defaults to 0

  • dtype (Union[dtype, type, str, None]) – the data type of output image

  • mode (str) – the output image mode, ‘RGB’ or ‘RGBA’

Return type

Tuple[ndarray, Dict]

Returns

a tuple, where the first element is an image patch [CxHxW] or a stack of patches, and the second element is a dictionary of metadata
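Because location is given in the level-0 reference frame while size is given at the requested level, the part of the full-resolution slide covered by a patch grows with the downsample ratio of that level. The helper level0_extent below is a hypothetical illustration (not part of MONAI) and assumes the ratio is the value get_downsample_ratio would report for that level.

```python
from typing import Tuple


def level0_extent(location: Tuple[int, int], size: Tuple[int, int], downsample: float) -> Tuple[int, int, int, int]:
    """Hypothetical helper: (top, left, bottom, right) in level-0 pixels covered by a patch
    whose size is expressed at a level with the given downsample ratio."""
    top, left = location                   # already in the level-0 frame
    height0 = round(size[0] * downsample)  # patch height measured in level-0 pixels
    width0 = round(size[1] * downsample)   # patch width measured in level-0 pixels
    return top, left, top + height0, left + width0


# A 256x256 patch at a level that is 4x smaller than level 0
print(level0_extent((1000, 2000), (256, 256), 4.0))  # (1000, 2000, 2024, 3024)
```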

abstract get_downsample_ratio(wsi, level=None)[source]#

Returns the down-sampling ratio of the whole slide image at a given level.

Parameters
  • wsi – a whole slide image object loaded from a file

  • level (Optional[int]) – the level number where the size is calculated. If not provided the default level (from self.level) will be used.

Return type

float

abstract get_file_path(wsi)[source]#

Return the file path for the WSI object

Return type

str

abstract get_level_count(wsi)[source]#

Returns the number of levels in the whole slide image.

Parameters

wsi – a whole slide image object loaded from a file

Return type

int

abstract get_mpp(wsi, level=None)[source]#

Returns the microns-per-pixel (mpp) resolution of the whole slide image at a given level.

Parameters
  • wsi – a whole slide image object loaded from a file

  • level (Optional[int]) – the level number where the resolution is calculated

Return type

Tuple[float, float]
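One way to relate resolutions across levels, assuming each level is a uniform downsampling of level 0: a pixel at a coarser level covers downsample level-0 pixels along each axis, so its microns-per-pixel value scales by the same factor. The helper and the example numbers below are hypothetical; concrete readers may instead read per-level resolution tags from the file.

```python
from typing import Tuple


def mpp_at_level(base_mpp: Tuple[float, float], downsample: float) -> Tuple[float, float]:
    # Each pixel at a coarser level spans `downsample` level-0 pixels per axis,
    # so its physical size in microns grows by the same factor.
    return base_mpp[0] * downsample, base_mpp[1] * downsample


print(mpp_at_level((0.25, 0.25), 4.0))  # (1.0, 1.0)
```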

abstract get_size(wsi, level=None)[source]#

Returns the size (height, width) of the whole slide image at a given level.

Parameters
  • wsi – a whole slide image object loaded from a file

  • level (Optional[int]) – the level number where the size is calculated

Return type

Tuple[int, int]

verify_suffix(filename)[source]#

Verify whether the format of the specified file or files is supported by the WSI reader.

The list of supported suffixes is read from self.supported_suffixes.

Parameters

filename (Union[Sequence[Union[str, PathLike]], str, PathLike]) – filename or a list of filenames to read.

Return type

bool
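A simplified, hypothetical sketch of such a check, assuming that a file is accepted when any supported suffix appears among its extensions (so multi-part names like "slide.ome.tif" work) and that a list is accepted only if every file in it is. The suffix list in the example is illustrative, not any particular reader's supported_suffixes.

```python
from os import PathLike
from pathlib import Path
from typing import Sequence, Union

PathT = Union[str, PathLike]


def verify_suffix_sketch(filename: Union[Sequence[PathT], PathT], supported_suffixes: Sequence[str]) -> bool:
    names = [filename] if isinstance(filename, (str, PathLike)) else list(filename)
    for name in names:
        tokens = "".join(Path(name).suffixes).lower()  # e.g. ".ome.tif"
        if not tokens or not any("." + s.lower() in tokens for s in supported_suffixes):
            return False  # one unsupported file rejects the whole input
    return True


print(verify_suffix_sketch("slide.svs", ["tif", "tiff", "svs"]))             # True
print(verify_suffix_sketch(["a.tiff", "b.png"], ["tif", "tiff", "svs"]))     # False
```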

WSIReader#

class monai.data.WSIReader(backend='cucim', level=0, channel_dim=0, **kwargs)[source]#

Read whole slide images and extract patches using different backend libraries

Parameters
  • backend – the name of backend whole slide image reader library, the default is cuCIM.

  • level (int) – the level at which patches are extracted.

  • channel_dim (int) – the desired dimension for color channel. Defaults to 0 (channel first).

  • num_workers – number of workers for multi-thread image loading (cucim backend only).

  • kwargs – additional arguments to be passed to the backend library

get_downsample_ratio(wsi, level=None)[source]#

Returns the down-sampling ratio of the whole slide image at a given level.

Parameters
  • wsi – a whole slide image object loaded from a file

  • level (Optional[int]) – the level number where the size is calculated. If not provided the default level (from self.level) will be used.

Return type

float

get_file_path(wsi)[source]#

Return the file path for the WSI object

Return type

str

get_level_count(wsi)[source]#

Returns the number of levels in the whole slide image.

Parameters

wsi – a whole slide image object loaded from a file

Return type

int

get_mpp(wsi, level=None)[source]#

Returns the microns-per-pixel (mpp) resolution of the whole slide image at a given level.

Parameters
  • wsi – a whole slide image object loaded from a file

  • level (Optional[int]) – the level number where the resolution is calculated. If not provided, the default level (from self.level) will be used.

Return type

Tuple[float, float]

get_size(wsi, level=None)[source]#

Returns the size (height, width) of the whole slide image at a given level.

Parameters
  • wsi – a whole slide image object loaded from a file

  • level (Optional[int]) – the level number where the size is calculated. If not provided the default level (from self.level) will be used.

Return type

Tuple[int, int]

read(data, **kwargs)[source]#

Read whole slide image objects from given file or list of files.

Parameters
  • data (Union[Sequence[Union[str, PathLike]], str, PathLike, ndarray]) – file name or a list of file names to read.

  • kwargs – additional args for the reader module (overrides self.kwargs for existing keys).

Returns

whole slide image object or list of such objects

CuCIMWSIReader#

class monai.data.CuCIMWSIReader(level=0, channel_dim=0, num_workers=0, **kwargs)[source]#

Read whole slide images and extract patches using cuCIM library.

Parameters
  • level (int) – the whole slide image level at which the image is extracted (default=0). This is overridden if the level argument is provided in get_data.

  • channel_dim (int) – the desired dimension for color channel. Defaults to 0 (channel first).

  • num_workers (int) – number of workers for multi-thread image loading.

  • kwargs – additional args for cucim.CuImage module: https://github.com/rapidsai/cucim/blob/main/cpp/include/cucim/cuimage.h

get_downsample_ratio(wsi, level=None)[source]#

Returns the down-sampling ratio of the whole slide image at a given level.

Parameters
  • wsi – a whole slide image object loaded from a file

  • level (Optional[int]) – the level number where the size is calculated. If not provided the default level (from self.level) will be used.

Return type

float

static get_file_path(wsi)[source]#

Return the file path for the WSI object

Return type

str

static get_level_count(wsi)[source]#

Returns the number of levels in the whole slide image.

Parameters

wsi – a whole slide image object loaded from a file

Return type

int

get_mpp(wsi, level=None)[source]#

Returns the microns-per-pixel (mpp) resolution of the whole slide image at a given level.

Parameters
  • wsi – a whole slide image object loaded from a file

  • level (Optional[int]) – the level number where the resolution is calculated. If not provided, the default level (from self.level) will be used.

Return type

Tuple[float, float]

get_size(wsi, level=None)[source]#

Returns the size (height, width) of the whole slide image at a given level.

Parameters
  • wsi – a whole slide image object loaded from a file

  • level (Optional[int]) – the level number where the size is calculated. If not provided the default level (from self.level) will be used.

Return type

Tuple[int, int]

read(data, **kwargs)[source]#

Read whole slide image objects from given file or list of files.

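Parameters
  • data (Union[Sequence[Union[str, PathLike]], str, PathLike, ndarray]) – file name or a list of file names to read.

  • kwargs – additional args that override self.kwargs for existing keys.

Returns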

whole slide image object or list of such objects

OpenSlideWSIReader#

class monai.data.OpenSlideWSIReader(level=0, channel_dim=0, **kwargs)[source]#

Read whole slide images and extract patches using OpenSlide library.

Parameters
  • level (int) – the whole slide image level at which the image is extracted (default=0). This is overridden if the level argument is provided in get_data.

  • channel_dim (int) – the desired dimension for color channel. Defaults to 0 (channel first).

  • kwargs – additional args for openslide.OpenSlide module.

get_downsample_ratio(wsi, level=None)[source]#

Returns the down-sampling ratio of the whole slide image at a given level.

Parameters
  • wsi – a whole slide image object loaded from a file

  • level (Optional[int]) – the level number where the size is calculated. If not provided the default level (from self.level) will be used.

Return type

float

static get_file_path(wsi)[source]#

Return the file path for the WSI object

Return type

str

static get_level_count(wsi)[source]#

Returns the number of levels in the whole slide image.

Parameters

wsi – a whole slide image object loaded from a file

Return type

int

get_mpp(wsi, level=None)[source]#

Returns the microns-per-pixel (mpp) resolution of the whole slide image at a given level.

Parameters
  • wsi – a whole slide image object loaded from a file

  • level (Optional[int]) – the level number where the resolution is calculated. If not provided, the default level (from self.level) will be used.

Return type

Tuple[float, float]

get_size(wsi, level=None)[source]#

Returns the size (height, width) of the whole slide image at a given level.

Parameters
  • wsi – a whole slide image object loaded from a file

  • level (Optional[int]) – the level number where the size is calculated. If not provided the default level (from self.level) will be used.

Return type

Tuple[int, int]

read(data, **kwargs)[source]#

Read whole slide image objects from given file or list of files.

Parameters
  • data (Union[Sequence[Union[str, PathLike]], str, PathLike, ndarray]) – file name or a list of file names to read.

  • kwargs – additional args that override self.kwargs for existing keys.

Returns

whole slide image object or list of such objects

TiffFileWSIReader#

class monai.data.TiffFileWSIReader(level=0, channel_dim=0, **kwargs)[source]#

Read whole slide images and extract patches using TiffFile library.

Parameters
  • level (int) – the whole slide image level at which the image is extracted (default=0). This is overridden if the level argument is provided in get_data.

  • channel_dim (int) – the desired dimension for color channel. Defaults to 0 (channel first).

  • kwargs – additional args for tifffile.TiffFile module.

get_downsample_ratio(wsi, level=None)[source]#

Returns the down-sampling ratio of the whole slide image at a given level.

Parameters
  • wsi – a whole slide image object loaded from a file

  • level (Optional[int]) – the level number where the size is calculated. If not provided the default level (from self.level) will be used.

Return type

float

static get_file_path(wsi)[source]#

Return the file path for the WSI object

Return type

str

static get_level_count(wsi)[source]#

Returns the number of levels in the whole slide image.

Parameters

wsi – a whole slide image object loaded from a file

Return type

int

get_mpp(wsi, level=None)[source]#

Returns the microns-per-pixel (mpp) resolution of the whole slide image at a given level.

Parameters
  • wsi – a whole slide image object loaded from a file

  • level (Optional[int]) – the level number where the resolution is calculated. If not provided, the default level (from self.level) will be used.

Return type

Tuple[float, float]

get_size(wsi, level=None)[source]#

Returns the size (height, width) of the whole slide image at a given level.

Parameters
  • wsi – a whole slide image object loaded from a file

  • level (Optional[int]) – the level number where the size is calculated. If not provided the default level (from self.level) will be used.

Return type

Tuple[int, int]

read(data, **kwargs)[source]#

Read whole slide image objects from given file or list of files.

Parameters
  • data (Union[Sequence[Union[str, PathLike]], str, PathLike, ndarray]) – file name or a list of file names to read.

  • kwargs – additional args that override self.kwargs for existing keys.

Returns

whole slide image object or list of such objects

Whole slide image datasets#

PatchWSIDataset#

class monai.data.PatchWSIDataset(data, patch_size=None, patch_level=None, transform=None, include_label=True, center_location=True, additional_meta_keys=None, reader='cuCIM', **kwargs)[source]#

This dataset extracts patches from whole slide images (without loading the whole image). It also reads labels for each patch and provides each patch with its associated class labels.

Parameters
  • data (Sequence) – the list of input samples including image, location, and label (see the note below for more details).

  • patch_size – the size of patch to be extracted from the whole slide image.

  • patch_level – the level at which the patches are extracted (defaults to 0).

  • transform (Optional[Callable]) – transforms to be executed on input data.

  • include_label (bool) – whether to load and include labels in the output

  • center_location (bool) – whether the input location information is the position of the center of the patch

  • additional_meta_keys (Optional[Sequence[str]]) – the list of keys for items to be copied to the output metadata from the input data

  • reader

    the module to be used for loading whole slide imaging. If reader is

    • a string, it defines the backend of monai.data.WSIReader. Defaults to cuCIM.

    • a class (inherited from BaseWSIReader), it is initialized and set as wsi_reader.

    • an instance of a class inherited from BaseWSIReader, it is set as the wsi_reader.

  • kwargs – additional arguments to pass to WSIReader or provided whole slide reader class

Returns

a dictionary of loaded image (in MetaTensor format) along with the labels (if requested). {“image”: MetaTensor, “label”: torch.Tensor}

Return type

dict

Note

The input data has the following form as an example:

[
    {"image": "path/to/image1.tiff", "patch_location": [200, 500], "label": 0},
    {"image": "path/to/image2.tiff", "patch_location": [100, 700], "patch_size": [20, 20], "patch_level": 2, "label": 1}
]
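When center_location is True, the stored location marks the patch center rather than its top-left corner. One way to convert it is sketched below; it assumes, as in BaseWSIReader.get_data, that the location is in the level-0 frame and the size is given at the extraction level. The helper center_to_top_left is hypothetical, and the exact rounding used by PatchWSIDataset may differ.

```python
from typing import Sequence, Tuple


def center_to_top_left(center: Sequence[int], size: Sequence[int], downsample: float) -> Tuple[int, ...]:
    # Shift by half the patch extent, measured in level-0 pixels, along each axis.
    return tuple(int(c - s * downsample / 2) for c, s in zip(center, size))


sample = {"image": "path/to/image2.tiff", "patch_location": [100, 700], "patch_size": [20, 20], "patch_level": 2, "label": 1}
# Assumes level 2 is 4x smaller than level 0 (each level halving the resolution).
print(center_to_top_left(sample["patch_location"], sample["patch_size"], downsample=4.0))  # (60, 660)
```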

MaskedPatchWSIDataset#

class monai.data.MaskedPatchWSIDataset(data, patch_size=None, patch_level=None, mask_level=7, transform=None, include_label=False, center_location=False, additional_meta_keys=(mask_location, name), reader='cuCIM', **kwargs)[source]#

This dataset extracts patches from whole slide images at the locations where foreground mask at a given level is non-zero.

Parameters
  • data (Sequence) – the list of input samples including image, location, and label (see the note below for more details).

  • patch_size – the size of patch to be extracted from the whole slide image.

  • patch_level – the level at which the patches are extracted (defaults to 0).

  • mask_level (int) – the resolution level at which the mask is created.

  • transform (Optional[Callable]) – transforms to be executed on input data.

  • include_label (bool) – whether to load and include labels in the output

  • center_location (bool) – whether the input location information is the position of the center of the patch

  • additional_meta_keys (Sequence[str]) – the list of keys for items to be copied to the output metadata from the input data

  • reader

    the module to be used for loading whole slide imaging. Defaults to cuCIM. If reader is

    • a string, it defines the backend of monai.data.WSIReader.

    • a class (inherited from BaseWSIReader), it is initialized and set as wsi_reader,

    • an instance of a class inherited from BaseWSIReader, it is set as the wsi_reader.

  • kwargs – additional arguments to pass to WSIReader or provided whole slide reader class

Note

The input data has the following form as an example:

[
    {"image": "path/to/image1.tiff"},
    {"image": "path/to/image2.tiff", "patch_size": [20, 20], "patch_level": 2}
]
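The selection of patch locations can be sketched as follows: every non-zero pixel of a foreground mask at mask_level becomes a candidate, and its position is mapped back to the level-0 frame by scaling with the downsample ratio between the two levels. The hand-written mask and the helper are hypothetical; the ratio of 128 assumes each level halves the resolution, so that the default mask_level=7 gives 2**7 = 128, and MONAI may round or center the resulting locations differently.

```python
from typing import List, Tuple


def mask_to_locations(mask: List[List[int]], downsample: int) -> List[Tuple[int, int]]:
    """Return the level-0 (row, col) of the top-left corner of each non-zero mask pixel."""
    return [
        (r * downsample, c * downsample)
        for r, row in enumerate(mask)
        for c, value in enumerate(row)
        if value != 0
    ]


# 3x3 foreground mask at a hypothetical mask level that is 128x smaller than level 0
mask = [
    [0, 1, 0],
    [0, 1, 1],
    [0, 0, 0],
]
print(mask_to_locations(mask, 128))  # [(0, 128), (128, 128), (128, 256)]
```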

SlidingPatchWSIDataset#

class monai.data.SlidingPatchWSIDataset(data, patch_size=None, patch_level=None, mask_level=0, overlap=0.0, offset=(0, 0), offset_limits=None, transform=None, include_label=False, center_location=False, additional_meta_keys=(mask_location, mask_size, num_patches), reader='cuCIM', map_level=0, seed=0, **kwargs)[source]#

This dataset extracts patches from whole slide images (without loading the whole image). It also reads labels for each patch and provides each patch with its associated class labels.

Parameters
  • data (Sequence) – the list of input samples including image, location, and label (see the note below for more details).

  • patch_size – the size of patch to be extracted from the whole slide image.

  • patch_level – the level at which the patches are extracted (defaults to 0).

  • offset (Union[Tuple[int, int], int, str]) – the offset of image to extract patches (the starting position of the upper left patch).

  • offset_limits (Union[Tuple[Tuple[int, int], Tuple[int, int]], Tuple[int, int], None]) – if offset is set to “random”, a tuple of integers defining the lower and upper limit of the random offset for all dimensions, or a tuple of tuples that defines the limits for each dimension.

  • overlap (Union[Tuple[float, float], float]) – the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.

  • transform (Optional[Callable]) – transforms to be executed on input data.

  • reader

    the module to be used for loading whole slide imaging. Defaults to cuCIM. If reader is

    • a string, it defines the backend of monai.data.WSIReader.

    • a class (inherited from BaseWSIReader), it is initialized and set as wsi_reader,

    • an instance of a class inherited from BaseWSIReader, it is set as the wsi_reader.

  • map_level (int) – the resolution level at which the output map is created.

  • seed (int) – random seed to randomly generate offsets. Defaults to 0.

  • kwargs – additional arguments to pass to WSIReader or provided whole slide reader class

Note

The input data has the following form as an example:

[
    {"image": "path/to/image1.tiff"},
    {"image": "path/to/image2.tiff", "patch_size": [20, 20], "patch_level": 2}
]
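The effect of overlap and offset on patch placement can be sketched per axis: neighboring patches are stride = patch_size * (1 - overlap) pixels apart, starting at offset. This hypothetical helper keeps only patches that fit entirely inside the image and ignores the random-offset options; it illustrates the grid idea and is not SlidingPatchWSIDataset's exact algorithm.

```python
import math
from itertools import product
from typing import List, Tuple


def sliding_positions(image_size: Tuple[int, int], patch_size: Tuple[int, int],
                      overlap: float = 0.0, offset: Tuple[int, int] = (0, 0)) -> List[Tuple[int, int]]:
    axes = []
    for dim, patch, start in zip(image_size, patch_size, offset):
        stride = max(1, math.floor(patch * (1.0 - overlap)))  # step between neighboring patches
        axes.append(range(start, dim - patch + 1, stride))    # keep patches that fit entirely
    return list(product(*axes))  # top-left corners, row-major order


print(sliding_positions((8, 8), (4, 4)))                    # [(0, 0), (0, 4), (4, 0), (4, 4)]
print(len(sliding_positions((8, 8), (4, 4), overlap=0.5)))  # 9
```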

Bounding box#

This utility module mainly supports rectangular bounding boxes with a few different parameterizations and methods for converting between them. It provides reliable access to the spatial coordinates of the box vertices in the "canonical ordering": [xmin, ymin, xmax, ymax] for 2D and [xmin, ymin, zmin, xmax, ymax, zmax] for 3D. We currently define this ordering as monai.data.box_utils.StandardMode, and the rest of the detection pipelines mainly assume boxes in StandardMode.

class monai.data.box_utils.BoxMode[source]#

An abstract class of a BoxMode.

A BoxMode is a callable that converts the mode of boxes, which are Nx4 (2D) or Nx6 (3D) torch tensors or ndarrays. BoxMode has several subclasses that represent different box modes, including:

  • CornerCornerModeTypeA: represents [xmin, ymin, xmax, ymax] for 2D and [xmin, ymin, zmin, xmax, ymax, zmax] for 3D

  • CornerCornerModeTypeB: represents [xmin, xmax, ymin, ymax] for 2D and [xmin, xmax, ymin, ymax, zmin, zmax] for 3D

  • CornerCornerModeTypeC: represents [xmin, ymin, xmax, ymax] for 2D and [xmin, ymin, xmax, ymax, zmin, zmax] for 3D

  • CornerSizeMode: represents [xmin, ymin, xsize, ysize] for 2D and [xmin, ymin, zmin, xsize, ysize, zsize] for 3D

  • CenterSizeMode: represents [xcenter, ycenter, xsize, ysize] for 2D and [xcenter, ycenter, zcenter, xsize, ysize, zsize] for 3D

We currently define StandardMode = CornerCornerModeTypeA, and monai detection pipelines mainly assume boxes are in StandardMode.
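To make the orderings concrete, the sketch below converts a single 2D box between the listed modes by going through the corner representation (xmin, ymin, xmax, ymax), mirroring how boxes_to_corners and corners_to_boxes split the work. It uses plain Python tuples instead of Nx4 tensors, and the function names and mode strings are hypothetical; MONAI's own subclasses may additionally clamp sizes.

```python
from typing import Tuple

Box = Tuple[float, float, float, float]


def to_corners(box: Box, mode: str) -> Box:
    a, b, c, d = box
    if mode == "xyxy":  # CornerCornerModeTypeA (StandardMode); TypeC is identical in 2D
        return a, b, c, d
    if mode == "xxyy":  # CornerCornerModeTypeB
        return a, c, b, d
    if mode == "xywh":  # CornerSizeMode
        return a, b, a + c, b + d
    if mode == "ccwh":  # CenterSizeMode
        return a - c / 2, b - d / 2, a + c / 2, b + d / 2
    raise ValueError(f"unknown mode {mode}")


def from_corners(corners: Box, mode: str) -> Box:
    xmin, ymin, xmax, ymax = corners
    if mode == "xyxy":
        return xmin, ymin, xmax, ymax
    if mode == "xxyy":
        return xmin, xmax, ymin, ymax
    if mode == "xywh":
        return xmin, ymin, xmax - xmin, ymax - ymin
    if mode == "ccwh":
        return (xmin + xmax) / 2, (ymin + ymax) / 2, xmax - xmin, ymax - ymin
    raise ValueError(f"unknown mode {mode}")


def convert(box: Box, src: str, dst: str) -> Box:
    return from_corners(to_corners(box, src), dst)


print(convert((10, 20, 4, 6), "ccwh", "xyxy"))  # (8.0, 17.0, 12.0, 23.0)
print(convert((8, 17, 12, 23), "xyxy", "xxyy"))  # (8, 12, 17, 23)
```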

The implementation should be aware of:

abstract boxes_to_corners(boxes)[source]#

Convert the bounding boxes of the current mode to corners.

Parameters

boxes (Tensor) – bounding boxes, Nx4 or Nx6 torch tensor

Returns

corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor. It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax)

Return type

Tuple

Example

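import torch
from monai.data.box_utils import StandardMode
boxes = torch.ones(10,6)
boxmode = StandardMode()  # BoxMode itself is abstract, so use a concrete subclass
boxmode.boxes_to_corners(boxes) # will return a 6-element tuple, each element is a 10x1 tensor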
abstract corners_to_boxes(corners)[source]#

Convert the given box corners to the bounding boxes of the current mode.

Parameters

corners (Sequence) – corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor. It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax)

Returns

bounding boxes, Nx4 or Nx6 torch tensor

Return type

Tensor

Example

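import torch
from monai.data.box_utils import StandardMode
corners = (torch.ones(10,1), torch.ones(10,1), torch.ones(10,1), torch.ones(10,1))
boxmode = StandardMode()  # BoxMode itself is abstract, so use a concrete subclass
boxmode.corners_to_boxes(corners) # will return a 10x4 tensor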
classmethod get_name(spatial_dims)[source]#

Get the mode name for the given spatial dimension using class variable name.

Parameters

spatial_dims (int) – number of spatial dimensions of the bounding boxes.

Returns

mode string name

Return type

str

class monai.data.box_utils.CenterSizeMode[source]#

A subclass of BoxMode.

Also represented as “ccwh” or “cccwhd”, with format of [xcenter, ycenter, xsize, ysize] or [xcenter, ycenter, zcenter, xsize, ysize, zsize].

Example

CenterSizeMode.get_name(spatial_dims=2) # will return "ccwh"
CenterSizeMode.get_name(spatial_dims=3) # will return "cccwhd"
boxes_to_corners(boxes)[source]#

Convert the bounding boxes of the current mode to corners.

Parameters

boxes (Tensor) – bounding boxes, Nx4 or Nx6 torch tensor

Returns

corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor. It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax)

Return type

Tuple

Example

import torch
from monai.data.box_utils import CenterSizeMode
boxes = torch.ones(10,6)
boxmode = CenterSizeMode()
boxmode.boxes_to_corners(boxes) # will return a 6-element tuple, each element is a 10x1 tensor
corners_to_boxes(corners)[source]#

Convert the given box corners to the bounding boxes of the current mode.

Parameters

corners (Sequence) – corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor. It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax)

Returns

bounding boxes, Nx4 or Nx6 torch tensor

Return type

Tensor

Example

import torch
from monai.data.box_utils import CenterSizeMode
corners = (torch.ones(10,1), torch.ones(10,1), torch.ones(10,1), torch.ones(10,1))
boxmode = CenterSizeMode()
boxmode.corners_to_boxes(corners) # will return a 10x4 tensor
class monai.data.box_utils.CornerCornerModeTypeA[source]#

A subclass of BoxMode.

Also represented as “xyxy” or “xyzxyz”, with format of [xmin, ymin, xmax, ymax] or [xmin, ymin, zmin, xmax, ymax, zmax].

Example

CornerCornerModeTypeA.get_name(spatial_dims=2) # will return "xyxy"
CornerCornerModeTypeA.get_name(spatial_dims=3) # will return "xyzxyz"
boxes_to_corners(boxes)[source]#

Convert the bounding boxes of the current mode to corners.

Parameters

boxes (Tensor) – bounding boxes, Nx4 or Nx6 torch tensor

Returns

corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor. It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax)

Return type

Tuple

Example

import torch
from monai.data.box_utils import CornerCornerModeTypeA
boxes = torch.ones(10,6)
boxmode = CornerCornerModeTypeA()
boxmode.boxes_to_corners(boxes) # will return a 6-element tuple, each element is a 10x1 tensor
corners_to_boxes(corners)[source]#

Convert the given box corners to the bounding boxes of the current mode.

Parameters

corners (Sequence) – corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor. It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax)

Returns

bounding boxes, Nx4 or Nx6 torch tensor

Return type

Tensor

Example

import torch
from monai.data.box_utils import CornerCornerModeTypeA
corners = (torch.ones(10,1), torch.ones(10,1), torch.ones(10,1), torch.ones(10,1))
boxmode = CornerCornerModeTypeA()
boxmode.corners_to_boxes(corners) # will return a 10x4 tensor
class monai.data.box_utils.CornerCornerModeTypeB[source]#

A subclass of BoxMode.

Also represented as “xxyy” or “xxyyzz”, with format of [xmin, xmax, ymin, ymax] or [xmin, xmax, ymin, ymax, zmin, zmax].

Example

CornerCornerModeTypeB.get_name(spatial_dims=2) # will return "xxyy"
CornerCornerModeTypeB.get_name(spatial_dims=3) # will return "xxyyzz"
boxes_to_corners(boxes)[source]#

Convert the bounding boxes of the current mode to corners.

Parameters

boxes (Tensor) – bounding boxes, Nx4 or Nx6 torch tensor

Returns

corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor. It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax)

Return type

Tuple

Example

import torch
from monai.data.box_utils import CornerCornerModeTypeB
boxes = torch.ones(10,6)
boxmode = CornerCornerModeTypeB()
boxmode.boxes_to_corners(boxes) # will return a 6-element tuple, each element is a 10x1 tensor
corners_to_boxes(corners)[source]#

Convert the given box corners to the bounding boxes of the current mode.

Parameters

corners (Sequence) – corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor. It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax)

Returns

bounding boxes, Nx4 or Nx6 torch tensor

Return type

Tensor

Example

import torch
from monai.data.box_utils import CornerCornerModeTypeB
corners = (torch.ones(10,1), torch.ones(10,1), torch.ones(10,1), torch.ones(10,1))
boxmode = CornerCornerModeTypeB()
boxmode.corners_to_boxes(corners) # will return a 10x4 tensor
class monai.data.box_utils.CornerCornerModeTypeC[source]#

A subclass of BoxMode.

Also represented as “xyxy” or “xyxyzz”, with format of [xmin, ymin, xmax, ymax] or [xmin, ymin, xmax, ymax, zmin, zmax].

Example

CornerCornerModeTypeC.get_name(spatial_dims=2) # will return "xyxy"
CornerCornerModeTypeC.get_name(spatial_dims=3) # will return "xyxyzz"
boxes_to_corners(boxes)[source]#

Convert the bounding boxes of the current mode to corners.

Parameters

boxes (Tensor) – bounding boxes, Nx4 or Nx6 torch tensor

Returns

corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor. It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax)

Return type

Tuple

Example

import torch
from monai.data.box_utils import CornerCornerModeTypeC
boxes = torch.ones(10,6)
boxmode = CornerCornerModeTypeC()
boxmode.boxes_to_corners(boxes) # will return a 6-element tuple, each element is a 10x1 tensor
corners_to_boxes(corners)[source]#

Convert the given box corners to the bounding boxes of the current mode.

Parameters

corners (Sequence) – corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor. It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax)

Returns

bounding boxes, Nx4 or Nx6 torch tensor

Return type

Tensor

Example

import torch
from monai.data.box_utils import CornerCornerModeTypeC
corners = (torch.ones(10,1), torch.ones(10,1), torch.ones(10,1), torch.ones(10,1))
boxmode = CornerCornerModeTypeC()
boxmode.corners_to_boxes(corners) # will return a 10x4 tensor
class monai.data.box_utils.CornerSizeMode[source]#

A subclass of BoxMode.

Also represented as “xywh” or “xyzwhd”, with format of [xmin, ymin, xsize, ysize] or [xmin, ymin, zmin, xsize, ysize, zsize].

Example

CornerSizeMode.get_name(spatial_dims=2) # will return "xywh"
CornerSizeMode.get_name(spatial_dims=3) # will return "xyzwhd"
boxes_to_corners(boxes)[source]#

Convert the bounding boxes of the current mode to corners.

Parameters

boxes (Tensor) – bounding boxes, Nx4 or Nx6 torch tensor

Returns

corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor. It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax)

Return type

Tuple

Example

import torch
from monai.data.box_utils import CornerSizeMode
boxes = torch.ones(10,6)
boxmode = CornerSizeMode()
boxmode.boxes_to_corners(boxes) # will return a 6-element tuple, each element is a 10x1 tensor
corners_to_boxes(corners)[source]#

Convert the given box corners to the bounding boxes of the current mode.

Parameters

corners (Sequence) – corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor. It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax)

Returns

bounding boxes, Nx4 or Nx6 torch tensor

Return type

Tensor

Example

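import torch
from monai.data.box_utils import CornerSizeMode
corners = (torch.ones(10,1), torch.ones(10,1), torch.ones(10,1), torch.ones(10,1))
boxmode = CornerSizeMode()  # BoxMode itself is abstract, so use the concrete subclass
boxmode.corners_to_boxes(corners) # will return a 10x4 tensor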
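The arithmetic behind CornerSizeMode's two conversions can be sketched in NumPy for the 2D case. This is an illustrative re-implementation, not MONAI's code; it assumes sizes are plain differences (xmax = xmin + xsize), and the helper names `xywh_to_corners` and `corners_to_xywh` are hypothetical.

```python
import numpy as np


def xywh_to_corners(boxes):
    """Split Nx4 [xmin, ymin, xsize, ysize] boxes into (xmin, ymin, xmax, ymax), each Nx1."""
    xmin, ymin, xsize, ysize = np.split(boxes, 4, axis=-1)
    return xmin, ymin, xmin + xsize, ymin + ysize


def corners_to_xywh(corners):
    """Concatenate (xmin, ymin, xmax, ymax) Nx1 arrays back into Nx4 xywh boxes."""
    xmin, ymin, xmax, ymax = corners
    return np.concatenate([xmin, ymin, xmax - xmin, ymax - ymin], axis=-1)


boxes = np.array([[10.0, 20.0, 5.0, 8.0]])
corners = xywh_to_corners(boxes)
print([c.shape for c in corners])  # [(1, 1), (1, 1), (1, 1), (1, 1)]
print(corners_to_xywh(corners).tolist())  # [[10.0, 20.0, 5.0, 8.0]]
```

The round trip recovers the original boxes, which is the property the two methods are meant to preserve.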
monai.data.box_utils.StandardMode#

alias of CornerCornerModeTypeA

monai.data.box_utils.batched_nms(boxes, scores, labels, nms_thresh, max_proposals=-1, box_overlap_metric=<function box_iou>)[source]#

Performs non-maximum suppression in a batched fashion. Each value in labels corresponds to a category, and NMS is not applied between boxes of different categories.

Adapted from https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/nms.py

Parameters
  • boxes (Union[ndarray, Tensor]) – bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode

  • scores (Union[ndarray, Tensor]) – prediction scores of the boxes, sized (N,). This function keeps boxes with higher scores.

  • labels (Union[ndarray, Tensor]) – category index of each box, sized (N,), value range is (0, num_classes)

  • nms_thresh (float) – threshold of NMS. Discards all overlapping boxes with box_overlap > nms_thresh.

  • max_proposals (int) – maximum number of boxes it keeps. If max_proposals = -1, there is no limit on the number of boxes that are kept.

  • box_overlap_metric (Callable) – the metric to compute overlap between boxes.

Return type

Union[ndarray, Tensor]

Returns

Indexes of boxes that are kept after NMS.
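A common way to get per-category suppression out of a single-class NMS routine (torchvision's `batched_nms` uses this coordinate-offset trick) is to shift every box by an amount proportional to its label, so that boxes of different categories can no longer overlap. The NumPy sketch below shows only the offset step; it assumes StandardMode boxes with non-negative coordinates, the helper name `offset_boxes_by_label` is hypothetical, and it is not necessarily how MONAI implements `batched_nms`.

```python
import numpy as np


def offset_boxes_by_label(boxes, labels):
    """Shift each box by label * (max_coordinate + 1) so that boxes of different
    categories are disjoint (assumes all coordinates are non-negative)."""
    max_coordinate = boxes.max()
    offsets = labels.astype(boxes.dtype) * (max_coordinate + 1)
    return boxes + offsets[:, None]  # same offset added to every coordinate of a box


# Two identical xyxy boxes that belong to different categories.
boxes = np.array([[0.0, 0.0, 10.0, 10.0], [0.0, 0.0, 10.0, 10.0]])
labels = np.array([0, 1])
shifted = offset_boxes_by_label(boxes, labels)
print(shifted.tolist())  # [[0.0, 0.0, 10.0, 10.0], [11.0, 11.0, 21.0, 21.0]]
```

Running ordinary NMS on `shifted` with the original scores would keep both boxes, because their IoU is now zero; the kept indices still refer to rows of the original `boxes`.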

monai.data.box_utils.box_area(boxes)[source]#

This function computes the area (2D) or volume (3D) of each box. Half precision is not recommended for this function as it may cause overflow, especially for 3D images.

Parameters

boxes (Union[ndarray, Tensor]) – bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode

Return type

Union[ndarray, Tensor]

Returns

area (2D) or volume (3D) of boxes, with size of (N,).

Example

import torch
from monai.data.box_utils import box_area
boxes = torch.ones(10,6)
# we do computation with torch.float32 to avoid overflow
compute_dtype = torch.float32
area = box_area(boxes=boxes.to(dtype=compute_dtype))  # torch.float32, size of (10,)
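For StandardMode boxes the area or volume is the product of (max - min) along each spatial axis. A NumPy sketch of that formula (the helper name `box_area_xyxy` is hypothetical, and this is not MONAI's implementation) also shows why half precision is risky: float16 cannot represent values above 65504, so even a 300-voxel cube overflows unless the inputs are widened first.

```python
import numpy as np


def box_area_xyxy(boxes):
    """Area (Nx4 input) or volume (Nx6 input) of StandardMode boxes."""
    spatial_dims = boxes.shape[1] // 2
    # Widen before multiplying: in float16, 300 * 300 already exceeds 65504.
    boxes = boxes.astype(np.float64)
    sizes = boxes[:, spatial_dims:] - boxes[:, :spatial_dims]
    return np.prod(sizes, axis=1)


boxes = np.array([[0, 0, 0, 300, 300, 300]], dtype=np.float16)
print(float(box_area_xyxy(boxes)[0]))  # 27000000.0
```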
monai.data.box_utils.box_centers(boxes)[source]#

Compute the center points of boxes.

Parameters

boxes (Union[ndarray, Tensor]) – bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode

Return type

Union[ndarray, Tensor]

Returns

center points with size of (N, spatial_dims)

monai.data.box_utils.box_giou(boxes1, boxes2)[source]#

Compute the generalized intersection over union (GIoU) of two sets of boxes. The two inputs can contain different numbers of boxes, and the function returns an NxM matrix (in contrast to box_pair_giou(), which requires the inputs to have the same shape and returns N values).

Parameters
  • boxes1 (Union[ndarray, Tensor]) – bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode

  • boxes2 (Union[ndarray, Tensor]) – bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be StandardMode

Return type

Union[ndarray, Tensor]

Returns

GIoU, with size of (N,M) and same data type as boxes1

Reference:

https://giou.stanford.edu/GIoU.pdf

monai.data.box_utils.box_iou(boxes1, boxes2)[source]#

Compute the intersection over union (IoU) of two sets of boxes.

Parameters
  • boxes1 (Union[ndarray, Tensor]) – bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode

  • boxes2 (Union[ndarray, Tensor]) – bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be StandardMode

Return type

Union[ndarray, Tensor]

Returns

IoU, with size of (N,M) and same data type as boxes1
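The NxM result comes from comparing every box in boxes1 with every box in boxes2. A NumPy sketch of that computation for StandardMode boxes uses broadcasting to form all pairwise intersections at once. The helper name `box_iou_xyxy` is hypothetical, and unlike a production implementation it has no epsilon guard against zero-area unions.

```python
import numpy as np


def box_iou_xyxy(boxes1, boxes2):
    """IoU of every box in boxes1 (N rows) with every box in boxes2 (M rows) -> NxM."""
    d = boxes1.shape[1] // 2
    # broadcast N x 1 x d against 1 x M x d to get the intersection box of each pair
    lo = np.maximum(boxes1[:, None, :d], boxes2[None, :, :d])
    hi = np.minimum(boxes1[:, None, d:], boxes2[None, :, d:])
    inter = np.prod(np.clip(hi - lo, 0, None), axis=-1)  # N x M
    area1 = np.prod(boxes1[:, d:] - boxes1[:, :d], axis=-1)  # N
    area2 = np.prod(boxes2[:, d:] - boxes2[:, :d], axis=-1)  # M
    union = area1[:, None] + area2[None, :] - inter
    return inter / union


boxes1 = np.array([[0.0, 0.0, 2.0, 2.0]])
boxes2 = np.array([[1.0, 0.0, 3.0, 2.0], [0.0, 0.0, 2.0, 2.0], [5.0, 5.0, 6.0, 6.0]])
iou = box_iou_xyxy(boxes1, boxes2)
print(iou.shape)  # (1, 3)
```

Here the first pair overlaps in a 1x2 strip (IoU 2/6), the second pair is identical (IoU 1), and the third pair is disjoint (IoU 0).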

monai.data.box_utils.box_pair_giou(boxes1, boxes2)[source]#

Compute the generalized intersection over union (GIoU) of paired boxes. The two inputs should have the same shape, and the function returns an (N,) array (in contrast to box_giou(), which does not require the inputs to have the same shape and returns an NxM matrix).

Parameters
  • boxes1 (Union[ndarray, Tensor]) – bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode

  • boxes2 (Union[ndarray, Tensor]) – bounding boxes, same shape as boxes1. The box mode is assumed to be StandardMode

Return type

Union[ndarray, Tensor]

Returns

paired GIoU, with size of (N,) and same data type as boxes1

Reference:

https://giou.stanford.edu/GIoU.pdf
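GIoU subtracts from the IoU the fraction of the smallest enclosing box C that the union of the two boxes does not cover: GIoU = IoU - (|C| - |A ∪ B|) / |C|. A NumPy sketch for paired StandardMode boxes follows (hypothetical helper name, no epsilon guard, not MONAI's code); box_giou computes the same quantity for every (i, j) combination instead of only matching rows.

```python
import numpy as np


def box_pair_giou_xyxy(boxes1, boxes2):
    """GIoU of boxes1[i] with boxes2[i] for each row i -> shape (N,)."""
    d = boxes1.shape[1] // 2
    lo = np.maximum(boxes1[:, :d], boxes2[:, :d])
    hi = np.minimum(boxes1[:, d:], boxes2[:, d:])
    inter = np.prod(np.clip(hi - lo, 0, None), axis=1)
    area1 = np.prod(boxes1[:, d:] - boxes1[:, :d], axis=1)
    area2 = np.prod(boxes2[:, d:] - boxes2[:, :d], axis=1)
    union = area1 + area2 - inter
    # smallest axis-aligned box enclosing both boxes of each pair
    enclose_size = np.maximum(boxes1[:, d:], boxes2[:, d:]) - np.minimum(boxes1[:, :d], boxes2[:, :d])
    enclose = np.prod(enclose_size, axis=1)
    # GIoU = IoU - (|C| - |A union B|) / |C|
    return inter / union - (enclose - union) / enclose


boxes1 = np.array([[0.0, 0.0, 2.0, 2.0], [0.0, 0.0, 2.0, 2.0]])
boxes2 = np.array([[0.0, 0.0, 2.0, 2.0], [3.0, 0.0, 5.0, 2.0]])
print(box_pair_giou_xyxy(boxes1, boxes2).tolist())  # [1.0, -0.2]
```

Unlike IoU, GIoU stays informative for disjoint boxes: the second pair has IoU 0 but a negative GIoU that grows in magnitude as the boxes move apart.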

monai.data.box_utils.boxes_center_distance(boxes1, boxes2, euclidean=True)[source]#

Compute the distances between the center points of two sets of boxes.

Parameters
  • boxes1 (Union[ndarray, Tensor]) – bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode

  • boxes2 (Union[ndarray, Tensor]) – bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be StandardMode

  • euclidean (bool) – if True, compute the Euclidean distance; otherwise, use the L1 distance.

Return type

Tuple[Union[ndarray, Tensor], Union[ndarray, Tensor], Union[ndarray, Tensor]]

Returns

  • The pairwise distances for every element in boxes1 and boxes2, with size of (N,M) and same data type as boxes1.

  • Center points of boxes1, with size of (N,spatial_dims) and same data type as boxes1.

  • Center points of boxes2, with size of (M,spatial_dims) and same data type as boxes1.

Reference:

https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/ops.py
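The three return values can be sketched in NumPy for StandardMode boxes: centers are the midpoints of each (min, max) pair, and the NxM distance matrix comes from broadcasting the two sets of centers against each other. The helper name is hypothetical and this is not the referenced implementation.

```python
import numpy as np


def boxes_center_distance_xyxy(boxes1, boxes2, euclidean=True):
    """Pairwise NxM center distances plus the two sets of centers."""
    d = boxes1.shape[1] // 2
    centers1 = (boxes1[:, :d] + boxes1[:, d:]) / 2  # N x d
    centers2 = (boxes2[:, :d] + boxes2[:, d:]) / 2  # M x d
    diff = centers1[:, None, :] - centers2[None, :, :]  # N x M x d
    if euclidean:
        dists = np.sqrt((diff ** 2).sum(axis=-1))
    else:
        dists = np.abs(diff).sum(axis=-1)  # L1 distance
    return dists, centers1, centers2


boxes1 = np.array([[0.0, 0.0, 2.0, 2.0]])  # center (1, 1)
boxes2 = np.array([[3.0, 4.0, 5.0, 6.0]])  # center (4, 5)
print(boxes_center_distance_xyxy(boxes1, boxes2)[0].tolist())  # [[5.0]]
print(boxes_center_distance_xyxy(boxes1, boxes2, euclidean=False)[0].tolist())  # [[7.0]]
```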

monai.data.box_utils.centers_in_boxes(centers, boxes, eps=0.01)[source]#

Check which center points are within their corresponding boxes.

Parameters
  • boxes (Union[ndarray, Tensor]) – bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode.

  • centers (Union[ndarray, Tensor]) – center points, Nx2 or Nx3 torch tensor or ndarray.

  • eps (float) – minimum distance to the border of the boxes.

Return type

Union[ndarray, Tensor]

Returns

boolean array indicating which center points are within the boxes, sized (N,).

Reference:

https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/ops.py
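Given the shapes above (N centers, N boxes, an (N,) result), the check pairs centers[i] with boxes[i]. A NumPy sketch, assuming a point counts as inside only if its distance to every face of its box exceeds eps; the helper name is hypothetical and this is not the referenced implementation.

```python
import numpy as np


def centers_in_boxes_xyxy(centers, boxes, eps=0.01):
    """For each i, is centers[i] more than eps inside boxes[i] (StandardMode)?"""
    d = centers.shape[1]
    # distances from each center to the low and high faces of its own box: N x 2d
    to_border = np.concatenate([centers - boxes[:, :d], boxes[:, d:] - centers], axis=1)
    return to_border.min(axis=1) > eps


centers = np.array([[1.0, 1.0], [0.0, 1.0]])
boxes = np.array([[0.0, 0.0, 2.0, 2.0], [0.0, 0.0, 2.0, 2.0]])
print(centers_in_boxes_xyxy(centers, boxes).tolist())  # [True, False]
```

The second center lies exactly on the left border, so its minimum distance (0) does not exceed eps.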

monai.data.box_utils.clip_boxes_to_image(boxes, spatial_size, remove_empty=True)[source]#

This function clips the boxes to make sure the bounding boxes are within the image.

Parameters
  • boxes (Union[ndarray, Tensor]) – bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode

  • spatial_size (Union[Sequence[int], ndarray, Tensor]) – The spatial size of the image where the boxes are attached. len(spatial_size) should be in [2, 3].

  • remove_empty (bool) – whether to remove the boxes that are actually empty

Return type

Tuple[Union[ndarray, Tensor], Union[ndarray, Tensor]]

Returns

  • clipped boxes, boxes[keep], does not share memory with original boxes

  • keep, indicating whether each box in boxes is kept when remove_empty=True.
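A NumPy sketch of the clip-then-filter behaviour described above, assuming StandardMode boxes, a valid range of [0, spatial_size] along each axis, and that a box is empty when its size is not positive along some axis after clipping. The helper name `clip_boxes_xyxy` is hypothetical and MONAI's implementation may differ in such details.

```python
import numpy as np


def clip_boxes_xyxy(boxes, spatial_size, remove_empty=True):
    """Clip StandardMode boxes to [0, spatial_size] along each axis."""
    d = len(spatial_size)
    upper = np.asarray(spatial_size, dtype=boxes.dtype)
    clipped = boxes.copy()  # the result must not share memory with the input
    clipped[:, :d] = np.clip(clipped[:, :d], 0, upper)
    clipped[:, d:] = np.clip(clipped[:, d:], 0, upper)
    if remove_empty:
        # a box is empty when its size is not positive along some axis
        keep = np.all(clipped[:, d:] > clipped[:, :d], axis=1)
    else:
        keep = np.ones(len(boxes), dtype=bool)
    return clipped[keep], keep


boxes = np.array([[-2.0, -2.0, 5.0, 5.0], [12.0, 12.0, 15.0, 15.0]])
clipped, keep = clip_boxes_xyxy(boxes, spatial_size=(10, 10))
print(clipped.tolist(), keep.tolist())  # [[0.0, 0.0, 5.0, 5.0]] [True, False]
```

The second box lies entirely outside the 10x10 image, so clipping collapses it to a point and it is dropped.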

monai.data.box_utils.convert_box_mode(boxes, src_mode=None, dst_mode=None)[source]#

This function converts the boxes in src_mode to the dst_mode.

Parameters
  • boxes (Union[ndarray, Tensor]) – source bounding boxes, Nx4 or Nx6 torch tensor or ndarray.

  • src_mode (Union[str, BoxMode, Type[BoxMode], None]) – source box mode. If it is not given, this function assumes it is StandardMode(). It follows the same format as mode in get_boxmode().

  • dst_mode (Union[str, BoxMode, Type[BoxMode], None]) – target box mode. If it is not given, this function assumes it is StandardMode(). It follows the same format as mode in get_boxmode().

Return type

Union[ndarray, Tensor]

Returns

bounding boxes with target mode, with same data type as boxes, does not share memory with boxes

Example

import monai
import torch
from monai.data.box_utils import convert_box_mode
boxes = torch.ones(10,4)
# The following three lines are equivalent
# They convert boxes with format [xmin, ymin, xmax, ymax] to [xcenter, ycenter, xsize, ysize].
convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode="ccwh")
convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode=monai.data.box_utils.CenterSizeMode)
convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode=monai.data.box_utils.CenterSizeMode())
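The arithmetic behind the "xyxy" to "ccwh" conversion in the example above is just midpoints and differences. The NumPy sketch below shows only that formula for 2D boxes; the helper name `xyxy_to_ccwh` is hypothetical, and it is not the code path `convert_box_mode` takes internally.

```python
import numpy as np


def xyxy_to_ccwh(boxes):
    """[xmin, ymin, xmax, ymax] -> [xcenter, ycenter, xsize, ysize] for Nx4 boxes."""
    xmin, ymin, xmax, ymax = boxes.T  # each has shape (N,)
    return np.stack([(xmin + xmax) / 2, (ymin + ymax) / 2, xmax - xmin, ymax - ymin], axis=1)


boxes = np.array([[0.0, 0.0, 4.0, 2.0]])
print(xyxy_to_ccwh(boxes).tolist())  # [[2.0, 1.0, 4.0, 2.0]]
```

`np.stack` builds a new array, so, as with `convert_box_mode`, the result does not share memory with the input.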
monai.data.box_utils.convert_box_to_standard_mode(boxes, mode=None)[source]#

Convert given boxes to standard mode. Standard mode is “xyxy” or “xyzxyz”, representing box format of [xmin, ymin, xmax, ymax] or [xmin, ymin, zmin, xmax, ymax, zmax].

Parameters
  • boxes (Union[ndarray, Tensor]) – source bounding boxes, Nx4 or Nx6 torch tensor or ndarray.

  • mode (Union[str, BoxMode, Type[BoxMode], None]) – source box mode. If it is not given, this function assumes it is StandardMode(). It follows the same format as mode in get_boxmode().

Return type

Union[ndarray, Tensor]

Returns

bounding boxes with standard mode, with same data type as boxes, does not share memory with boxes

Example

import torch
from monai.data.box_utils import convert_box_mode, convert_box_to_standard_mode
boxes = torch.ones(10,6)
# The following two lines are equivalent
# They convert boxes with format [xmin, xmax, ymin, ymax, zmin, zmax] to [xmin, ymin, zmin, xmax, ymax, zmax]
convert_box_to_standard_mode(boxes=boxes, mode="xxyyzz")
convert_box_mode(boxes=boxes, src_mode="xxyyzz", dst_mode="xyzxyz")
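For corner-based modes such as "xxyyzz", converting to standard mode amounts to a column permutation. A NumPy sketch of that reordering (the constant name `XXYYZZ_TO_XYZXYZ` is hypothetical, and this is not MONAI's implementation):

```python
import numpy as np

# Column i of the output takes column XXYYZZ_TO_XYZXYZ[i] of the input:
# [xmin, xmax, ymin, ymax, zmin, zmax] -> [xmin, ymin, zmin, xmax, ymax, zmax]
XXYYZZ_TO_XYZXYZ = [0, 2, 4, 1, 3, 5]

boxes = np.array([[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]])
standard = boxes[:, XXYYZZ_TO_XYZXYZ]  # integer-list indexing returns a copy
print(standard.tolist())  # [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]
```

Because NumPy's advanced indexing always copies, the result does not share memory with `boxes`, matching the documented return behaviour.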
monai.data.box_utils.get_boxmode(mode=None, *args, **kwargs)[source]#

This function returns a BoxMode object given a representation of a box mode.

Parameters

mode (Union[str, BoxMode, Type[BoxMode], None]) – a representation of box mode. If it is not given, this function assumes it is StandardMode().

Note

StandardMode = CornerCornerModeTypeA, also represented as “xyxy” for 2D and “xyzxyz” for 3D.

mode can be:
  1. str: choose from BoxModeName, for example,
    • “xyxy”: boxes have format [xmin, ymin, xmax, ymax]

    • “xyzxyz”: boxes have format [xmin, ymin, zmin, xmax, ymax, zmax]

    • “xxyy”: boxes have format [xmin, xmax, ymin, ymax]

    • “xxyyzz”: boxes have format [xmin, xmax, ymin, ymax, zmin, zmax]

    • “xyxyzz”: boxes have format [xmin, ymin, xmax, ymax, zmin, zmax]

    • “xywh”: boxes have format [xmin, ymin, xsize, ysize]

    • “xyzwhd”: boxes have format [xmin, ymin, zmin, xsize, ysize, zsize]

    • “ccwh”: boxes have format [xcenter, ycenter, xsize, ysize]

    • “cccwhd”: boxes have format [xcenter, ycenter, zcenter, xsize, ysize, zsize]

  2. BoxMode class: choose from the subclasses of BoxMode, for example,
    • CornerCornerModeTypeA: equivalent to “xyxy” or “xyzxyz”

    • CornerCornerModeTypeB: equivalent to “xxyy” or “xxyyzz”

    • CornerCornerModeTypeC: equivalent to “xyxy” or “xyxyzz”

    • CornerSizeMode: equivalent to “xywh” or “xyzwhd”

    • CenterSizeMode: equivalent to “ccwh” or “cccwhd”

  3. BoxMode object: choose from the subclasses of BoxMode, for example,
    • CornerCornerModeTypeA(): equivalent to “xyxy” or “xyzxyz”

    • CornerCornerModeTypeB(): equivalent to “xxyy” or “xxyyzz”

    • CornerCornerModeTypeC(): equivalent to “xyxy” or “xyxyzz”

    • CornerSizeMode(): equivalent to “xywh” or “xyzwhd”

    • CenterSizeMode(): equivalent to “ccwh” or “cccwhd”

  4. None: will assume mode is StandardMode()

Return type

BoxMode

Returns

BoxMode object

Example

from monai.data.box_utils import get_boxmode
mode = "xyzxyz"
get_boxmode(mode) # will return CornerCornerModeTypeA()
monai.data.box_utils.get_spatial_dims(boxes=None, points=None, corners=None, spatial_size=None)[source]#

Get the spatial dimension for the given inputs and check their validity. Missing inputs are allowed, but at least one input must be given. It raises ValueError if the dimensions of multiple inputs do not match each other.

Parameters
  • boxes (Union[ndarray, Tensor, None]) – bounding boxes, Nx4 or Nx6 torch tensor or ndarray

  • points (Union[ndarray, Tensor, None]) – point coordinates, [x, y] or [x, y, z], Nx2 or Nx3 torch tensor or ndarray

  • corners (Optional[Sequence]) – corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor or ndarray

  • spatial_size (Union[Sequence[int], ndarray, Tensor, None]) – The spatial size of the image where the boxes are attached. len(spatial_size) should be in [2, 3].

Returns

spatial_dims, number of spatial dimensions of the bounding boxes.

Return type

int

Example

import torch
from monai.data.box_utils import get_spatial_dims
boxes = torch.ones(10,6)
get_spatial_dims(boxes, spatial_size=[100,200,200]) # will return 3
get_spatial_dims(boxes, spatial_size=[100,200]) # will raise ValueError
get_spatial_dims(boxes) # will return 3
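Each input implies a spatial dimension in its own way (Nx4 boxes and 4-element corners imply 2D, Nx3 points imply 3D, and so on), and the function checks that all given inputs agree. A sketch of that consistency check with a hypothetical helper `infer_spatial_dims`, written against NumPy arrays; it mirrors the documented behaviour but is not MONAI's code.

```python
import numpy as np


def infer_spatial_dims(boxes=None, points=None, corners=None, spatial_size=None):
    """Infer 2 or 3 from whichever inputs are given and check that they agree."""
    candidates = []
    if boxes is not None:
        candidates.append(boxes.shape[1] // 2)  # Nx4 -> 2, Nx6 -> 3
    if points is not None:
        candidates.append(points.shape[1])  # Nx2 -> 2, Nx3 -> 3
    if corners is not None:
        candidates.append(len(corners) // 2)  # 4-tuple -> 2, 6-tuple -> 3
    if spatial_size is not None:
        candidates.append(len(spatial_size))
    if not candidates:
        raise ValueError("At least one input must be given.")
    if len(set(candidates)) != 1:
        raise ValueError(f"Inputs disagree on the spatial dimension: {candidates}")
    if candidates[0] not in (2, 3):
        raise ValueError(f"Only 2D and 3D are supported, got {candidates[0]}.")
    return int(candidates[0])


boxes = np.ones((10, 6))
print(infer_spatial_dims(boxes, spatial_size=[100, 200, 200]))  # 3
print(infer_spatial_dims(boxes))  # 3
```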
monai.data.box_utils.is_valid_box_values(boxes)[source]#

This function checks whether the box size is non-negative.

Parameters

boxes (Union[ndarray, Tensor]) – bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode

Return type

bool

Returns

whether boxes are valid
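For StandardMode boxes, "non-negative size" means max >= min along every spatial axis. A NumPy sketch of that check with a hypothetical helper name (not MONAI's implementation); zero-size boxes, such as the `torch.ones(10,6)` boxes used in other examples here, count as valid.

```python
import numpy as np


def is_valid_xyxy(boxes):
    """True if every StandardMode box has max >= min along every spatial axis."""
    spatial_dims = boxes.shape[1] // 2
    return bool(np.all(boxes[:, spatial_dims:] >= boxes[:, :spatial_dims]))


print(is_valid_xyxy(np.array([[0.0, 0.0, 2.0, 2.0]])))  # True
print(is_valid_xyxy(np.array([[3.0, 0.0, 2.0, 2.0]])))  # False (xmax < xmin)
```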

monai.data.box_utils.non_max_suppression(boxes, scores, nms_thresh, max_proposals=-1, box_overlap_metric=<function box_iou>)[source]#

Non-maximum suppression (NMS).

Parameters
  • boxes (Union[ndarray, Tensor]) – bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode

  • scores (Union[ndarray, Tensor]) – prediction scores of the boxes, sized (N,). This function keeps boxes with higher scores.

  • nms_thresh (float) – threshold of NMS. Discards all overlapping boxes with box_overlap > nms_thresh.

  • max_proposals (int) – maximum number of boxes it keeps. If max_proposals = -1, there is no limit on the number of boxes that are kept.

  • box_overlap_metric (Callable) – the metric to compute overlap between boxes.

Return type

Union[ndarray, Tensor]

Returns

Indexes of boxes that are kept after NMS.

Example

import torch
from monai.data.box_utils import non_max_suppression
boxes = torch.ones(10,6)
scores = torch.ones(10)
keep = non_max_suppression(boxes, scores, nms_thresh=0.1)
boxes_after_nms = boxes[keep]
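The parameters map onto the standard greedy NMS loop: visit boxes from highest to lowest score, keep the current best, and discard every remaining box whose overlap with it exceeds nms_thresh. A NumPy sketch of that algorithm with IoU as the overlap metric follows. The helper name `greedy_nms` is hypothetical, and this is an illustration of the technique, not MONAI's implementation.

```python
import numpy as np


def greedy_nms(boxes, scores, nms_thresh, max_proposals=-1):
    """Greedy NMS on StandardMode boxes using IoU; returns kept indices, best first."""
    spatial_dims = boxes.shape[1] // 2
    areas = np.prod(boxes[:, spatial_dims:] - boxes[:, :spatial_dims], axis=1)
    order = np.argsort(-scores, kind="stable")  # highest score first
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        if 0 < max_proposals <= len(keep):
            break  # reached the requested number of boxes
        rest = order[1:]
        # IoU between the current best box and every remaining candidate
        lo = np.maximum(boxes[best, :spatial_dims], boxes[rest, :spatial_dims])
        hi = np.minimum(boxes[best, spatial_dims:], boxes[rest, spatial_dims:])
        inter = np.prod(np.clip(hi - lo, 0, None), axis=1)
        iou = inter / (areas[best] + areas[rest] - inter)
        # discard every remaining box whose overlap exceeds nms_thresh
        order = rest[iou <= nms_thresh]
    return np.array(keep, dtype=np.int64)


boxes = np.array([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 10.0, 10.0], [20.0, 20.0, 30.0, 30.0]])
scores = np.array([0.9, 0.8, 0.7])
keep = greedy_nms(boxes, scores, nms_thresh=0.5)
print(keep.tolist())  # [0, 2]
boxes_after_nms = boxes[keep]
```

Box 1 overlaps box 0 with IoU 0.81 and is suppressed; box 2 does not overlap either and survives.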
monai.data.box_utils.spatial_crop_boxes(boxes, roi_start, roi_end, remove_empty=True)[source]#

This function generates the new boxes when the corresponding image is cropped to the given ROI. When remove_empty=True, it makes sure the bounding boxes are within the new cropped image.

Parameters
  • boxes (Union[ndarray, Tensor]) – bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode

  • roi_start (Union[Sequence[int], ndarray, Tensor]) – voxel coordinates for start of the crop ROI, negative values allowed.

  • roi_end (Union[Sequence[int], ndarray, Tensor]) – voxel coordinates for end of the crop ROI, negative values allowed.

  • remove_empty (bool) – whether to remove the boxes that are actually empty

Return type

Tuple[Union[ndarray, Tensor], Union[ndarray, Tensor]]

Returns

  • cropped boxes, boxes[keep], does not share memory with original boxes

  • keep, indicating whether each box in boxes is kept when remove_empty=True.
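Compared with clip_boxes_to_image, cropping also shifts the boxes so that their coordinates are relative to roi_start. A NumPy sketch, assuming StandardMode boxes and that a box is empty when it has no positive size along some axis after clamping; the helper name is hypothetical, the guard against roi_end < roi_start is this sketch's own choice, and MONAI's implementation may differ in details such as dtype handling.

```python
import numpy as np


def spatial_crop_boxes_xyxy(boxes, roi_start, roi_end, remove_empty=True):
    """Express StandardMode boxes in the coordinate frame of the crop ROI."""
    d = len(roi_start)
    start = np.asarray(roi_start, dtype=boxes.dtype)
    end = np.maximum(np.asarray(roi_end, dtype=boxes.dtype), start)  # guard end < start
    cropped = boxes.copy()  # never modify or alias the input
    # clamp both corners into the ROI, then shift so the ROI starts at the origin
    cropped[:, :d] = np.clip(cropped[:, :d], start, end) - start
    cropped[:, d:] = np.clip(cropped[:, d:], start, end) - start
    if remove_empty:
        keep = np.all(cropped[:, d:] > cropped[:, :d], axis=1)
    else:
        keep = np.ones(len(boxes), dtype=bool)
    return cropped[keep], keep


boxes = np.array([[2.0, 2.0, 8.0, 8.0], [12.0, 12.0, 15.0, 15.0]])
cropped, keep = spatial_crop_boxes_xyxy(boxes, roi_start=[5, 5], roi_end=[10, 10])
print(cropped.tolist(), keep.tolist())  # [[0.0, 0.0, 3.0, 3.0]] [True, False]
```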

Video datasets#

VideoDataset#

class monai.data.VideoDataset(video_source, transform=None, max_num_frames=None, color_order=ColorOrder.RGB, multiprocessing=False, channel_dim=0)[source]#

VideoFileDataset#

class monai.data.VideoFileDataset(*args, **kwargs)[source]#

Video dataset from file.

This class requires that OpenCV be installed.

CameraDataset#

class monai.data.CameraDataset(video_source, transform=None, max_num_frames=None, color_order=ColorOrder.RGB, multiprocessing=False, channel_dim=0)[source]#

Video dataset from a capture device (e.g., webcam).

This class requires that OpenCV be installed.

Parameters
  • video_source (Union[str, int]) – index of capture device. get_num_devices can be used to determine possible devices.

  • transform (Optional[Callable]) – transform to be applied to each frame.

  • max_num_frames (Optional[int]) – Max number of frames to iterate across. If None is passed, then the dataset will iterate infinitely.

Raises

RuntimeError – OpenCV not installed.