# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Callable, Optional, Union
import numpy as np
from monai.config import DtypeLike, IgniteInfo
from monai.data import decollate_batch
from monai.transforms import SaveImage
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, deprecated, min_version, optional_import
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
if TYPE_CHECKING:
from ignite.engine import Engine
else:
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
@deprecated(since="0.6.0", removed="0.7.0", msg_suffix="Please consider using `SaveImage[d]` transform instead.")
class SegmentationSaver:
    """
    Event handler triggered on completing every iteration to save the segmentation predictions into files.

    It can extract the input image meta data (filename, affine, original_shape, etc.) and resample the
    predictions based on the meta data.
    The name of saved file will be `{input_image_name}_{output_postfix}{output_ext}`,
    where the input image name is extracted from the meta data dictionary. If no meta data provided,
    use index from 0 as the filename prefix.
    The predictions can be PyTorch Tensor with [B, C, H, W, [D]] shape or a list of Tensor without batch dim.

    The actual file writing is delegated to a `SaveImage` transform built from the constructor arguments.
    """

    def __init__(
        self,
        output_dir: str = "./",
        output_postfix: str = "seg",
        output_ext: str = ".nii.gz",
        resample: bool = True,
        mode: Union[GridSampleMode, InterpolateMode, str] = "nearest",
        padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
        scale: Optional[int] = None,
        dtype: DtypeLike = np.float64,
        output_dtype: DtypeLike = np.float32,
        squeeze_end_dims: bool = True,
        data_root_dir: str = "",
        batch_transform: Callable = lambda x: x,
        output_transform: Callable = lambda x: x,
        name: Optional[str] = None,
    ) -> None:
        """
        Args:
            output_dir: output image directory.
            output_postfix: a string appended to all output file names, default to `seg`.
            output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`.
            resample: whether to resample before saving the data array.
                if saving PNG format image, based on the `spatial_shape` from metadata.
                if saving NIfTI format image, based on the `original_affine` from metadata.
            mode: This option is used when ``resample = True``. Defaults to ``"nearest"``.

                - NIfTI files {``"bilinear"``, ``"nearest"``}
                    Interpolation mode to calculate output values.
                    See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
                - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
                    The interpolation mode.
                    See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate

            padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``.

                - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``}
                    Padding mode for outside grid values.
                    See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
                - PNG files
                    This option is ignored.

            scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling
                [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.
                It's used for PNG format only.
            dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
                If None, use the data type of input data.
                It's used for NIfTI format only.
            output_dtype: data type for saving data. Defaults to ``np.float32``, it's used for NIfTI format only.
            squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel
                has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and
                then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false,
                image will always be saved as (H,W,D,C).
                It's used for NIfTI format only.
            data_root_dir: if not empty, it specifies the beginning parts of the input file's
                absolute path. It's used to compute `input_file_rel_path`, the relative path to the file from
                `data_root_dir` to preserve folder structure when saving in case there are files in different
                folders with the same file names. For example:
                input_file_name: /foo/bar/test1/image.nii,
                output_postfix: seg
                output_ext: nii.gz
                output_dir: /output,
                data_root_dir: /foo/bar,
                output will be: /output/test1/image/image_seg.nii.gz
            batch_transform: a callable that is used to extract the `meta_data` dictionary of the input images
                from `ignite.engine.state.batch`. The purpose is to extract necessary information from the
                meta data: filename, affine, original_shape, etc. It may return ``None`` when no meta data
                is available.
            output_transform: a callable that is used to extract the model prediction data from
                `ignite.engine.state.output`. The first dimension of its output will be treated as the batch
                dimension. Each item in the batch will be saved individually.
            name: identifier of logging.logger to use, defaulting to `engine.logger`.

        """
        # All saving options are forwarded unchanged to the underlying transform.
        self._saver = SaveImage(
            output_dir=output_dir,
            output_postfix=output_postfix,
            output_ext=output_ext,
            resample=resample,
            mode=mode,
            padding_mode=padding_mode,
            scale=scale,
            dtype=dtype,
            output_dtype=output_dtype,
            squeeze_end_dims=squeeze_end_dims,
            data_root_dir=data_root_dir,
        )
        self.batch_transform = batch_transform
        self.output_transform = output_transform

        # Placeholder logger; `attach` swaps in `engine.logger` when no explicit name was given.
        self.logger = logging.getLogger(name)
        self._name = name

    def attach(self, engine: Engine) -> None:
        """
        Register this handler on ``Events.ITERATION_COMPLETED`` of the given engine.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        if self._name is None:
            self.logger = engine.logger
        # Guard against double registration if `attach` is called more than once.
        if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
            engine.add_event_handler(Events.ITERATION_COMPLETED, self)

    def __call__(self, engine: Engine) -> None:
        """
        Save every item of the current iteration's output with its matching meta data.

        This method assumes self.batch_transform will extract metadata from the input batch.
        If it returns ``None``, every prediction is saved without meta data instead of failing.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        meta_data = self.batch_transform(engine.state.batch)
        if isinstance(meta_data, dict):
            # decollate the `dictionary of list` to `list of dictionaries`
            meta_data = decollate_batch(meta_data)
        engine_output = self.output_transform(engine.state.output)

        if meta_data is None:
            # `zip(None, ...)` would raise TypeError; pass no meta data so the
            # saver can fall back to index-based file names as the class documents.
            pairs = ((None, o) for o in engine_output)
        else:
            pairs = zip(meta_data, engine_output)

        for m, o in pairs:
            self._saver(o, m)
        self.logger.info("model outputs saved into files.")