# Source code for monai.data.nifti_reader

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import nibabel as nib
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data._utils.collate import np_str_obj_array_pattern

from monai.data.utils import correct_nifti_header_if_necessary
from monai.transforms.compose import Randomizable


def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dtype=None):
    """
    Loads a Nifti file from the given path or file-like object.

    Args:
        filename_or_obj (str or file): path to file or file-like object; it is passed directly to
            ``nib.load`` (NOTE(review): confirm nibabel accepts file-like objects here).
        as_closest_canonical (bool): if True, load the image as closest to canonical axis format.
        image_only (bool): if True return only the image volume, otherwise return the image volume
            and the header dict.
        dtype (np.dtype, optional): if not None convert the loaded image to this data type.
            Floating-point dtypes are read via ``img.get_fdata(dtype=dtype)``; any other dtype
            (e.g. integer label maps) is read from ``img.dataobj`` and cast with ``astype``.
            If None, the array from ``img.dataobj`` is returned without conversion.

    Returns:
        The loaded image volume if `image_only` is True, or a tuple containing the volume and the
        Nifti header in dict format otherwise.

    Note:
        header['original_affine'] stores the original affine loaded from `filename_or_obj`.
        header['affine'] stores the affine after the optional `as_closest_canonical` transform.
        Both are independent copies, so modifying one in place does not affect the other.
        All other header fields are read before the optional `as_closest_canonical` transform
        and are not updated by it.
    """
    img = nib.load(filename_or_obj)
    img = correct_nifti_header_if_necessary(img)

    header = dict(img.header)
    header['filename_or_obj'] = filename_or_obj
    # Copy so the two affine entries (and the image object) never share one mutable array.
    header['original_affine'] = np.array(img.affine, copy=True)
    header['affine'] = np.array(img.affine, copy=True)
    header['as_closest_canonical'] = as_closest_canonical

    if as_closest_canonical:
        img = nib.as_closest_canonical(img)
        header['affine'] = np.array(img.affine, copy=True)

    if dtype is None:
        dat = np.asanyarray(img.dataobj)
    elif np.issubdtype(np.dtype(dtype), np.inexact):
        # get_fdata only supports floating-point output dtypes.
        dat = img.get_fdata(dtype=dtype)
    else:
        # Non-float targets (e.g. integer segmentations) cannot go through get_fdata.
        dat = np.asanyarray(img.dataobj).astype(dtype)

    if image_only:
        return dat
    return dat, header


class NiftiDataset(Dataset):
    """
    Loads image/segmentation pairs of Nifti files from the given filename lists. Transformations can be specified
    for the image and segmentation arrays separately.
    """

    def __init__(self, image_files, seg_files=None, labels=None, as_closest_canonical=False, transform=None,
                 seg_transform=None, image_only=True, dtype=None):
        """
        Initializes the dataset with the image and segmentation filename lists. The transform `transform` is applied
        to the images and `seg_transform` to the segmentations.

        Args:
            image_files (list of str): list of image filenames
            seg_files (list of str): if in segmentation task, list of segmentation filenames
            labels (list or array): if in classification task, list of classification labels
                (ignored when `seg_files` is given)
            as_closest_canonical (bool): if True, load the image as closest to canonical orientation;
                segmentations are reoriented the same way so image/segmentation pairs stay aligned
            transform (Callable, optional): transform to apply to image arrays
            seg_transform (Callable, optional): transform to apply to segmentation arrays
            image_only (bool): if True return only the image volume, otherwise return image volume and header dict
            dtype (np.dtype, optional): if not None convert the loaded image to this data type
                (segmentations are loaded without dtype conversion)

        Raises:
            ValueError: if `seg_files` is given and its length differs from `image_files`.
        """
        if seg_files is not None and len(image_files) != len(seg_files):
            raise ValueError('Must have same number of image and segmentation files')

        self.image_files = image_files
        self.seg_files = seg_files
        self.labels = labels
        self.as_closest_canonical = as_closest_canonical
        self.transform = transform
        self.seg_transform = seg_transform
        self.image_only = image_only
        self.dtype = dtype

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        """
        Returns ``(img, target)`` when `image_only` is True, otherwise ``(img, target, meta_data)``.

        `target` is the loaded segmentation if `seg_files` was given, else ``labels[index]`` if `labels`
        was given, else None. `meta_data` is the image header dict with numpy arrays whose dtype matches
        ``np_str_obj_array_pattern`` removed (presumably string/object arrays that the default collate
        cannot batch -- confirm).
        """
        loaded = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
                            image_only=self.image_only, dtype=self.dtype)
        if self.image_only:
            img, meta_data = loaded, None
        else:
            img, meta_data = loaded

        target = None
        if self.seg_files is not None:
            # Reorient the segmentation exactly like the image, otherwise the pair is misaligned
            # whenever as_closest_canonical is True.
            target = load_nifti(self.seg_files[index], as_closest_canonical=self.as_closest_canonical)
        elif self.labels is not None:
            target = self.labels[index]

        # Seed both Randomizable transforms with the same value so their random draws can stay in sync.
        seed = np.random.randint(2147483647)
        if self.transform is not None:
            if isinstance(self.transform, Randomizable):
                self.transform.set_random_state(seed=seed)
            img = self.transform(img)
        if self.seg_transform is not None:
            if isinstance(self.seg_transform, Randomizable):
                self.seg_transform.set_random_state(seed=seed)
            target = self.seg_transform(target)

        if meta_data is None:
            return img, target

        compatible_meta = {
            key: value
            for key, value in meta_data.items()
            if not (isinstance(value, np.ndarray) and np_str_obj_array_pattern.search(value.dtype.str) is not None)
        }
        return img, target, compatible_meta