Source code for monai.apps.auto3dseg.data_analyzer

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from os import path
from typing import Dict, List, Optional, Union

import numpy as np
import torch

from monai.apps.utils import get_logger
from monai.auto3dseg import SegSummarizer
from monai.auto3dseg.utils import datafold_read
from monai.bundle import config_parser
from monai.bundle.config_parser import ConfigParser
from monai.data import DataLoader, Dataset
from monai.data.utils import no_collation
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    Lambdad,
    LoadImaged,
    Orientationd,
    SqueezeDimd,
    ToDeviced,
)
from monai.utils import StrEnum, min_version, optional_import
from monai.utils.enums import DataStatsKeys, ImageStatsKeys


def strenum_representer(dumper, data):
    return dumper.represent_scalar("tag:yaml.org,2002:str", data.value)


if optional_import("yaml")[1]:
    config_parser.yaml.SafeDumper.add_multi_representer(StrEnum, strenum_representer)

tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")
logger = get_logger(module_name=__name__)

__all__ = ["DataAnalyzer"]


def _argmax_if_multichannel(x):
    return torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x


[docs]class DataAnalyzer: """ The DataAnalyzer automatically analyzes given medical image dataset and reports the statistics. The module expects file paths to the image data and utilizes the LoadImaged transform to read the files, which supports nii, nii.gz, png, jpg, bmp, npz, npy, and dcm formats. Currently, only segmentation task is supported, so the user needs to provide paths to the image and label files (if have). Also, label data format is preferred to be (1,H,W,D), with the label index in the first dimension. If it is in onehot format, it will be converted to the preferred format. Args: datalist: a Python dictionary storing group, fold, and other information of the medical image dataset, or a string to the JSON file storing the dictionary. dataroot: user's local directory containing the datasets. output_path: path to save the analysis result. average: whether to average the statistical value across different image modalities. do_ccp: apply the connected component algorithm to process the labels/images device: a string specifying hardware (CUDA/CPU) utilized for the operations. worker: number of workers to use for parallel processing. If device is cuda/GPU, worker has to be 0. image_key: a string that user specify for the image. The DataAnalyzer will look it up in the datalist to locate the image files of the dataset. label_key: a string that user specify for the label. The DataAnalyzer will look it up in the datalist to locate the label files of the dataset. If label_key is None, the DataAnalyzer will skip looking for labels and all label-related operations. Raises: ValueError if device is GPU and worker > 0. Examples: .. code-block:: python from monai.apps.auto3dseg.data_analyzer import DataAnalyzer datalist = { "testing": [{"image": "image_003.nii.gz"}], "training": [ {"fold": 0, "image": "image_001.nii.gz", "label": "label_001.nii.gz"}, {"fold": 0, "image": "image_002.nii.gz", "label": "label_002.nii.gz"}, {"fold": 1, "image": "image_001.nii.gz", "label": "label_001.nii.gz"}, {"fold": 1, "image": "image_004.nii.gz", "label": "label_004.nii.gz"}, ], } dataroot = '/datasets' # the directory where you have the image files (nii.gz) DataAnalyzer(datalist, dataroot) Notes: The module can also be called from the command line interface (CLI). For example: .. code-block:: bash python -m monai.apps.auto3dseg \\ DataAnalyzer \\ get_all_case_stats \\ --datalist="my_datalist.json" \\ --dataroot="my_dataroot_dir" """ def __init__( self, datalist: Union[str, Dict], dataroot: str = "", output_path: str = "./data_stats.yaml", average: bool = True, do_ccp: bool = True, device: Union[str, torch.device] = "cuda", worker: int = 0, image_key: str = "image", label_key: Optional[str] = "label", ): if path.isfile(output_path): warnings.warn(f"File {output_path} already exists and will be overwritten.") logger.debug(f"{output_path} will be overwritten by a new datastat.") self.datalist = datalist self.dataroot = dataroot self.output_path = output_path self.average = average self.do_ccp = do_ccp self.device = torch.device(device) self.worker = worker self.image_key = image_key self.label_key = label_key if (self.device.type == "cuda") and (worker > 0): raise ValueError("CUDA does not support multiple subprocess. If device is GPU, please set worker to 0") @staticmethod def _check_data_uniformity(keys: List[str], result: Dict): """ Check data uniformity since DataAnalyzer provides no support to multi-modal images with different affine matrices/spacings due to monai transforms. Args: keys: a list of string-type keys under image_stats dictionary. Returns: False if one of the selected key values is not constant across the dataset images. """ constant_props = [result[DataStatsKeys.SUMMARY][DataStatsKeys.IMAGE_STATS][key] for key in keys] for prop in constant_props: if "stdev" in prop and np.any(prop["stdev"]): return False return True
[docs] def get_all_case_stats(self): """ Get all case stats. Caller of the DataAnalyser class. The function iterates datalist and call get_case_stats to generate stats. Then get_case_summary is called to combine results. Returns: A data statistics dictionary containing "stats_summary" (summary statistics of the entire datasets). Within stats_summary there are "image_stats" (summarizing info of shape, channel, spacing, and etc using operations_summary), "image_foreground_stats" (info of the intensity for the non-zero labeled voxels), and "label_stats" (info of the labels, pixel percentage, image_intensity, and each individual label in a list) "stats_by_cases" (List type value. Each element of the list is statistics of an image-label info. Within each element, there are: "image" (value is the path to an image), "label" (value is the path to the corresponding label), "image_stats" (summarizing info of shape, channel, spacing, and etc using operations), "image_foreground_stats" (similar to the previous one but one foreground image), and "label_stats" (stats of the individual labels ) Notes: Since the backend of the statistics computation are torch/numpy, nan/inf value may be generated and carried over in the computation. In such cases, the output dictionary will include .nan/.inf in the statistics. """ summarizer = SegSummarizer(self.image_key, self.label_key, average=self.average, do_ccp=self.do_ccp) keys = list(filter(None, [self.image_key, self.label_key])) transform_list = [ LoadImaged(keys=keys), EnsureChannelFirstd(keys=keys), # this creates label to be (1,H,W,D) Orientationd(keys=keys, axcodes="RAS"), EnsureTyped(keys=keys, data_type="tensor"), Lambdad(keys=self.label_key, func=_argmax_if_multichannel) if self.label_key else None, SqueezeDimd(keys=["label"], dim=0) if self.label_key else None, ToDeviced(keys=keys, device=self.device), summarizer, ] transform = Compose(transforms=list(filter(None, transform_list))) files, _ = datafold_read(datalist=self.datalist, basedir=self.dataroot, fold=-1) dataset = Dataset(data=files, transform=transform) dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.worker, collate_fn=no_collation) result = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} if not has_tqdm: warnings.warn("tqdm is not installed. not displaying the caching progress.") for batch_data in tqdm(dataloader) if has_tqdm else dataloader: d = batch_data[0] stats_by_cases = { DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH], DataStatsKeys.BY_CASE_LABEL_PATH: d[DataStatsKeys.BY_CASE_LABEL_PATH], DataStatsKeys.IMAGE_STATS: d[DataStatsKeys.IMAGE_STATS], } if self.label_key is not None: stats_by_cases.update( { DataStatsKeys.FG_IMAGE_STATS: d[DataStatsKeys.FG_IMAGE_STATS], DataStatsKeys.LABEL_STATS: d[DataStatsKeys.LABEL_STATS], } ) result[DataStatsKeys.BY_CASE].append(stats_by_cases) result[DataStatsKeys.SUMMARY] = summarizer.summarize(result[DataStatsKeys.BY_CASE]) if not self._check_data_uniformity([ImageStatsKeys.SPACING], result): logger.warning("Data is not completely uniform. MONAI transforms may provide unexpected result") ConfigParser.export_config_file(result, self.output_path, fmt="yaml", default_flow_style=None) return result