# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING
from monai.config import IgniteInfo
from monai.data import decollate_batch
from monai.handlers.utils import write_metrics_reports
from monai.utils import ImageMetaKey as Key
from monai.utils import ensure_tuple, min_version, optional_import, string_list_all_gather
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")
if TYPE_CHECKING:
from ignite.engine import Engine
else:
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
[docs]
class MetricsSaver:
"""
ignite handler to save metrics values and details into expected files.
Args:
save_dir: directory to save the metrics and metric details.
metrics: expected final metrics to save into files, can be: None, "*" or list of strings.
None - don't save any metrics into files.
"*" - save all the existing metrics in `engine.state.metrics` dict into separate files.
list of strings - specify the expected metrics to save.
default to "*" to save all the metrics into `metrics.csv`.
metric_details: expected metric details to save into files, the data comes from
`engine.state.metric_details`, which should be provided by different `Metrics`,
typically, it's some intermediate values in metric computation.
for example: mean dice of every channel of every image in the validation dataset.
it must contain at least 2 dims: (batch, classes, ...),
if not, will unsqueeze to 2 dims.
this arg can be: None, "*" or list of strings.
None - don't save any metric_details into files.
"*" - save all the existing metric_details in `engine.state.metric_details` dict into separate files.
list of strings - specify the metric_details of expected metrics to save.
if not None, every metric_details array will save a separate `{metric name}_raw.csv` file.
batch_transform: a callable that is used to extract the `meta_data` dictionary of
the input images from `ignite.engine.state.batch` if saving metric details. the purpose is to get the
input filenames from the `meta_data` and store with metric details together.
`engine.state` and `batch_transform` inherit from the ignite concept:
https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
summary_ops: expected computation operations to generate the summary report.
it can be: None, "*" or list of strings, default to None.
None - don't generate summary report for every expected metric_details.
"*" - generate summary report for every metric_details with all the supported operations.
list of strings - generate summary report for every metric_details with specified operations, they
should be within list: ["mean", "median", "max", "min", "<int>percentile", "std", "notnans"].
the number in "<int>percentile" should be [0, 100], like: "15percentile". default: "90percentile".
for more details, please check: https://numpy.org/doc/stable/reference/generated/numpy.nanpercentile.html.
note that: for the overall summary, it computes `nanmean` of all classes for each image first,
then compute summary. example of the generated summary report::
class mean median max 5percentile 95percentile notnans
class0 6.0000 6.0000 7.0000 5.1000 6.9000 2.0000
class1 6.0000 6.0000 6.0000 6.0000 6.0000 1.0000
mean 6.2500 6.2500 7.0000 5.5750 6.9250 2.0000
save_rank: only the handler on specified rank will save to files in multi-gpus validation, default to 0.
delimiter: the delimiter character in the saved file, default to "," as the default output type is `csv`.
to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
output_type: expected output file type, supported types: ["csv"], default to "csv".
"""
def __init__(
self,
save_dir: str,
metrics: str | Sequence[str] | None = "*",
metric_details: str | Sequence[str] | None = None,
batch_transform: Callable = lambda x: x,
summary_ops: str | Sequence[str] | None = None,
save_rank: int = 0,
delimiter: str = ",",
output_type: str = "csv",
) -> None:
self.save_dir = save_dir
self.metrics = ensure_tuple(metrics) if metrics is not None else None
self.metric_details = ensure_tuple(metric_details) if metric_details is not None else None
self.batch_transform = batch_transform
self.summary_ops = ensure_tuple(summary_ops) if summary_ops is not None else None
self.save_rank = save_rank
self.deli = delimiter
self.output_type = output_type
self._filenames: list[str] = []
[docs]
def attach(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
engine.add_event_handler(Events.EPOCH_STARTED, self._started)
engine.add_event_handler(Events.ITERATION_COMPLETED, self._get_filenames)
engine.add_event_handler(Events.EPOCH_COMPLETED, self)
def _started(self, _engine: Engine) -> None:
"""
Initialize internal buffers.
Args:
_engine: Ignite Engine, unused argument.
"""
self._filenames = []
def _get_filenames(self, engine: Engine) -> None:
if self.metric_details is not None:
meta_data = self.batch_transform(engine.state.batch)
if isinstance(meta_data, dict):
# decollate the `dictionary of list` to `list of dictionaries`
meta_data = decollate_batch(meta_data)
for m in meta_data:
self._filenames.append(f"{m.get(Key.FILENAME_OR_OBJ)}")
def __call__(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
ws = idist.get_world_size()
if self.save_rank >= ws:
raise ValueError("target save rank is greater than the distributed group size.")
# all gather file names across ranks
_images = string_list_all_gather(strings=self._filenames) if ws > 1 else self._filenames
# only save metrics to file in specified rank
if idist.get_rank() == self.save_rank:
_metrics = {}
if self.metrics is not None and len(engine.state.metrics) > 0:
_metrics = {k: v for k, v in engine.state.metrics.items() if k in self.metrics or "*" in self.metrics}
_metric_details = {}
if hasattr(engine.state, "metric_details"):
details = engine.state.metric_details
if self.metric_details is not None and len(details) > 0:
for k, v in details.items():
if k in self.metric_details or "*" in self.metric_details:
_metric_details[k] = v
write_metrics_reports(
save_dir=self.save_dir,
images=None if len(_images) == 0 else _images,
metrics=_metrics,
metric_details=_metric_details,
summary_ops=self.summary_ops,
deli=self.deli,
output_type=self.output_type,
)