# Source code for monai.handlers.classification_saver

# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import warnings
from typing import TYPE_CHECKING, Callable, List, Optional

import torch

from monai.config import IgniteInfo
from monai.data import CSVSaver, decollate_batch
from monai.utils import ImageMetaKey as Key
from monai.utils import evenly_divisible_all_gather, min_version, optional_import, string_list_all_gather

# ignite is resolved lazily through `optional_import` rather than a plain import.
# Each call passes `IgniteInfo.OPT_IMPORT_VERSION` and `min_version`, presumably a
# minimum-version requirement (see `optional_import` for the exact semantics).
# The second element of each returned pair is discarded here.
idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
if TYPE_CHECKING:
    # static type checkers see the real `Engine` class
    from ignite.engine import Engine
else:
    # at runtime `Engine` comes from the optional import instead
    Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")


class ClassificationSaver:
    """
    Event handler triggered on completing every iteration to save the classification predictions as CSV file.
    If running in distributed data parallel, only saves CSV file in the specified rank.

    Lifecycle once attached to an engine:
        - ``EPOCH_STARTED``: clear the per-epoch buffers of outputs and filenames.
        - ``ITERATION_COMPLETED``: append one output and one filename per batch item.
        - ``EPOCH_COMPLETED``: gather the buffers across ranks and write them on ``save_rank`` only.

    """

    def __init__(
        self,
        output_dir: str = "./",
        filename: str = "predictions.csv",
        overwrite: bool = True,
        batch_transform: Callable = lambda x: x,
        output_transform: Callable = lambda x: x,
        name: Optional[str] = None,
        save_rank: int = 0,
        saver: Optional[CSVSaver] = None,
    ) -> None:
        """
        Args:
            output_dir: if `saver=None`, output CSV file directory.
            filename: if `saver=None`, name of the saved CSV file name.
            overwrite: if `saver=None`, whether to overwrite existing file content, if True,
                will clear the file before saving. otherwise, will append new content to the file.
            batch_transform: a callable that is used to extract the `meta_data` dictionary of the input images
                from `ignite.engine.state.batch`. the purpose is to get the input filenames from the `meta_data`
                and store with classification results together.
                `engine.state` and `batch_transform` inherit from the ignite concept:
                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
            output_transform: a callable that is used to extract the model prediction data from
                `ignite.engine.state.output`. the first dimension of its output will be treated as
                the batch dimension. each item in the batch will be saved individually.
                items that are not `torch.Tensor` are converted with `torch.as_tensor`.
                `engine.state` and `output_transform` inherit from the ignite concept:
                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
            name: identifier of logging.logger to use, defaulting to `engine.logger`.
            save_rank: only the handler on specified rank will save to CSV file in multi-gpus validation,
                default to 0. must satisfy ``0 <= save_rank < world_size``; this is checked when an epoch
                completes.
            saver: the saver instance to save classification results, if None, create a CSVSaver internally.
                the saver must provide `save_batch(batch_data, meta_data)` and `finalize()` APIs.

        """
        self.save_rank = save_rank
        self.output_dir = output_dir
        self.filename = filename
        self.overwrite = overwrite
        self.batch_transform = batch_transform
        self.output_transform = output_transform
        self.saver = saver

        self.logger = logging.getLogger(name)
        # remembered so `attach` can tell whether to fall back to `engine.logger`
        self._name = name
        # per-epoch buffers, reset by `_started` and filled by `__call__`
        self._outputs: List[torch.Tensor] = []
        self._filenames: List[str] = []

    def attach(self, engine: Engine) -> None:
        """
        Register this handler's callbacks on `engine`, skipping any that are already registered.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        if self._name is None:
            self.logger = engine.logger
        if not engine.has_event_handler(self._started, Events.EPOCH_STARTED):
            engine.add_event_handler(Events.EPOCH_STARTED, self._started)
        if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
            engine.add_event_handler(Events.ITERATION_COMPLETED, self)
        if not engine.has_event_handler(self._finalize, Events.EPOCH_COMPLETED):
            engine.add_event_handler(Events.EPOCH_COMPLETED, self._finalize)

    def _started(self, _engine: Engine) -> None:
        """
        Initialize internal buffers.

        Args:
            _engine: Ignite Engine, unused argument.

        """
        self._outputs = []
        self._filenames = []

    def __call__(self, engine: Engine) -> None:
        """
        This method assumes self.batch_transform will extract metadata from the input batch.

        Appends one filename string and one detached tensor per batch item to the internal buffers.
        Non-tensor items from `output_transform` are converted with `torch.as_tensor`, so that
        `torch.stack` in `_finalize` can combine them.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        meta_data = self.batch_transform(engine.state.batch)
        if isinstance(meta_data, dict):
            # decollate the `dictionary of list` to `list of dictionaries`
            meta_data = decollate_batch(meta_data)
        engine_output = self.output_transform(engine.state.output)
        # NOTE: `zip` stops at the shorter sequence, so unmatched items in either one are dropped.
        for m, o in zip(meta_data, engine_output):
            # a missing key becomes the string "None" rather than raising
            self._filenames.append(f"{m.get(Key.FILENAME_OR_OBJ)}")
            if isinstance(o, torch.Tensor):
                o = o.detach()
            else:
                # previously non-tensors were stored as-is and made `torch.stack` fail at epoch end
                o = torch.as_tensor(o)
            self._outputs.append(o)

    def _finalize(self, _engine: Engine) -> None:
        """
        All gather classification results from ranks and save to CSV file.

        Args:
            _engine: Ignite Engine, unused argument.

        Raises:
            ValueError: when `save_rank` is negative or not smaller than the distributed world size.
        """
        ws = idist.get_world_size()
        # a negative rank never matches `idist.get_rank()`, which would silently skip saving
        if self.save_rank < 0:
            raise ValueError(f"target save rank must be non-negative, got {self.save_rank}.")
        if self.save_rank >= ws:
            raise ValueError("target save rank is greater than the distributed group size.")

        # NOTE(review): `torch.stack` raises on an empty list, i.e. when this rank saw no iterations.
        outputs = torch.stack(self._outputs, dim=0)
        filenames = self._filenames
        if ws > 1:
            # these gathers run on every rank (presumably collective calls); only `save_rank` writes below
            outputs = evenly_divisible_all_gather(outputs, concat=True)
            filenames = string_list_all_gather(filenames)

        if len(filenames) == 0:
            meta_dict = None
        else:
            if len(filenames) != len(outputs):
                warnings.warn(f"filenames length: {len(filenames)} doesn't match outputs length: {len(outputs)}.")
            meta_dict = {Key.FILENAME_OR_OBJ: filenames}

        # save to CSV file only in the expected rank
        if idist.get_rank() == self.save_rank:
            saver = self.saver or CSVSaver(self.output_dir, self.filename, self.overwrite)
            saver.save_batch(outputs, meta_dict)
            saver.finalize()