# Source code for monai.handlers.metrics_saver

# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union

from monai.handlers.utils import string_list_all_gather, write_metrics_reports
from monai.utils import ImageMetaKey as Key
from monai.utils import ensure_tuple, exact_version, issequenceiterable, optional_import

Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed")
if TYPE_CHECKING:
    from ignite.engine import Engine
else:
    Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")


class MetricsSaver:
    """
    ignite handler to save metrics values and details into expected files.

    Args:
        save_dir: directory to save the metrics and metric details.
        metrics: expected final metrics to save into files, can be: None, "*" or list of strings.
            None - don't save any metrics into files.
            "*" - save all the existing metrics in `engine.state.metrics` dict into separate files.
            list of strings - specify the expected metrics to save.
            default to "*" to save all the metrics into `metrics.csv`.
        metric_details: expected metric details to save into files, the data comes from
            `engine.state.metric_details`, which should be provided by different `Metrics`,
            typically, it's some intermediate values in metric computation.
            for example: mean dice of every channel of every image in the validation dataset.
            it must contain at least 2 dims: (batch, classes, ...),
            if not, will unsqueeze to 2 dims.
            this arg can be: None, "*" or list of strings.
            None - don't save any metric_details into files.
            "*" - save all the existing metric_details in `engine.state.metric_details` dict into separate files.
            list of strings - specify the metric_details of expected metrics to save.
            if not None, every metric_details array will save a separate `{metric name}_raw.csv` file.
        batch_transform: callable function to extract the meta_dict from input batch data if saving metric details.
            used to extract filenames from input dict data.
        summary_ops: expected computation operations to generate the summary report based on specified metric_details.
            it can be: None, "*" or list of strings.
            None - don't generate summary report for every specified metric_details
            "*" - generate summary report for every metric_details with all the supported operations.
            list of strings - generate summary report for every metric_details with specified operations, they
            should be within this list: [`mean`, `median`, `max`, `min`, `90percent`, `std`].
            default to None.
        save_rank: only the handler on specified rank will save to files in multi-gpus validation, default to 0.
        delimiter: the delimiter character in CSV file, default to "\\t".
        output_type: expected output file type, supported types: ["csv"], default to "csv".

    """

    def __init__(
        self,
        save_dir: str,
        metrics: Optional[Union[str, Sequence[str]]] = "*",
        metric_details: Optional[Union[str, Sequence[str]]] = None,
        batch_transform: Callable = lambda x: x,
        summary_ops: Optional[Union[str, Sequence[str]]] = None,
        save_rank: int = 0,
        delimiter: str = "\t",
        output_type: str = "csv",
    ) -> None:
        self.save_dir = save_dir
        # a single string (e.g. "*") is normalised to a tuple; None is kept as "disabled"
        self.metrics = ensure_tuple(metrics) if metrics is not None else None
        self.metric_details = ensure_tuple(metric_details) if metric_details is not None else None
        self.batch_transform = batch_transform
        self.summary_ops = ensure_tuple(summary_ops) if summary_ops is not None else None
        self.save_rank = save_rank
        self.deli = delimiter
        self.output_type = output_type
        # file names collected during the current epoch, reset by `_started`
        self._filenames: List[str] = []

    def attach(self, engine: "Engine") -> None:
        """
        Register this handler on the engine: reset the file name cache when an epoch starts,
        collect file names after every iteration and write the reports when the epoch completes.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        engine.add_event_handler(Events.EPOCH_STARTED, self._started)
        engine.add_event_handler(Events.ITERATION_COMPLETED, self._get_filenames)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self)

    def _started(self, engine: "Engine") -> None:
        """Drop the file names collected during the previous epoch."""
        self._filenames = []

    def _get_filenames(self, engine: "Engine") -> None:
        """
        Append the file names of the current batch to the cache.

        Nothing is collected when `metric_details` is None; values for which
        `issequenceiterable` is false are skipped.
        """
        if self.metric_details is not None:
            filenames = self.batch_transform(engine.state.batch).get(Key.FILENAME_OR_OBJ)
            if issequenceiterable(filenames):
                self._filenames.extend(filenames)

    @staticmethod
    def _select_entries(available: Optional[dict], wanted: Optional[Sequence[str]]) -> dict:
        """
        Return the items of `available` whose key is listed in `wanted`.

        Args:
            available: mapping of name -> value, may be None or empty.
            wanted: names to keep, "*" keeps everything, None keeps nothing.

        Returns:
            a new dict, empty when nothing is requested or nothing is available.
        """
        if wanted is None or available is None or len(available) == 0:
            return {}
        # evaluate the wildcard once instead of once per key
        if "*" in wanted:
            return dict(available)
        return {k: v for k, v in available.items() if k in wanted}

    def __call__(self, engine: "Engine") -> None:
        """
        Write the selected metrics and metric details of the finished epoch to `save_dir`.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.

        Raises:
            ValueError: when `save_rank` is not smaller than the distributed world size.
        """
        ws = idist.get_world_size()
        if self.save_rank >= ws:
            raise ValueError("target save rank is greater than the distributed group size.")

        # all gather file names across ranks; this runs on every rank, before the rank check below
        _images = string_list_all_gather(strings=self._filenames) if ws > 1 else self._filenames

        # only save metrics to file in specified rank
        if idist.get_rank() != self.save_rank:
            return

        state = engine.state
        # `metric_details` only exists on the state if some metric populated it,
        # so read it defensively instead of raising AttributeError
        _metrics = self._select_entries(getattr(state, "metrics", None), self.metrics)
        _metric_details = self._select_entries(getattr(state, "metric_details", None), self.metric_details)

        write_metrics_reports(
            save_dir=self.save_dir,
            images=None if len(_images) == 0 else _images,
            metrics=_metrics,
            metric_details=_metric_details,
            summary_ops=self.summary_ops,
            deli=self.deli,
            output_type=self.output_type,
        )