Source code for monailabel.transform.writer

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import tempfile
from typing import Any, Dict, Iterable, List, Optional, Tuple

import itk
import nrrd
import numpy as np
import torch
from monai.data import MetaTensor
from monai.data.image_writer import NibabelWriter

from monailabel.utils.others.detection import create_slicer_detection_json
from monailabel.utils.others.generic import file_ext
from monailabel.utils.others.pathology import create_asap_annotations_xml, create_dsa_annotations_json

logger = logging.getLogger(__name__)


[docs]def write_nifti(data, filename, affine=None, original_affine=None, output_dtype=np.float32): writer = NibabelWriter(output_dtype=output_dtype) writer.set_data_array(data, channel_dim=None, spatial_ndim=None) writer.set_metadata({"affine": affine, "original_affine": original_affine}) writer.write(filename)
[docs]def write_itk(image_np, output_file, affine, dtype, compress): if isinstance(image_np, torch.Tensor): image_np = image_np.numpy() if isinstance(affine, torch.Tensor): affine = affine.numpy() if len(image_np.shape) >= 2: image_np = image_np.transpose().copy() if dtype: image_np = image_np.astype(dtype) result_image = itk.image_from_array(image_np) logger.debug(f"ITK Image size: {itk.size(result_image)}") # https://github.com/RSIP-Vision/medio/blob/master/medio/metadata/affine.py#L108-L121 if affine is not None: convert_aff_mat = np.diag([-1, -1, 1, 1]) if affine.shape[0] == 3: # Handle RGB (2D Image) convert_aff_mat = np.diag([-1, -1, 1]) affine = convert_aff_mat @ affine dim = affine.shape[0] - 1 _origin_key = (slice(-1), -1) _m_key = (slice(-1), slice(-1)) origin = affine[_origin_key] spacing = np.linalg.norm(affine[_m_key] @ np.eye(dim), axis=0) direction = affine[_m_key] @ np.diag(1 / spacing) logger.debug(f"Affine: {affine}") logger.debug(f"Origin: {origin}") logger.debug(f"Spacing: {spacing}") logger.debug(f"Direction: {direction}") result_image.SetDirection(itk.matrix_from_array(direction)) result_image.SetSpacing(spacing) result_image.SetOrigin(origin) itk.imwrite(result_image, output_file, compress)
[docs]def write_seg_nrrd( image_np: np.ndarray, output_file: str, dtype: type, affine: np.ndarray, labels: List[str], color_map: Optional[Dict[str, List[float]]] = None, index_order: str = "C", space: str = "left-posterior-superior", ) -> None: """Write multi-channel seg.nrrd file. Args: image_np: Image as numpy ndarray output_file: Output file path that the seg.nrrd file should be saved to dtype: numpy type e.g. float32 affine: Affine matrix labels: Labels of image segment which will be written to the nrrd header color_map: Mapping from segment_name(str) to it's color e.g. {'heart': [255/255, 244/255, 209/255]} index_order: Either 'C' or 'F' (see nrrd.write() documentation) Raises: ValueError: In case affine is not provided ValueError: In case labels are not provided """ if isinstance(image_np, torch.Tensor): image_np = image_np.numpy() if isinstance(affine, torch.Tensor): affine = affine.numpy() image_np = image_np.transpose().copy() if dtype: image_np = image_np.astype(dtype) if not isinstance(labels, Iterable): raise ValueError("Labels have to be defined, e.g. as a list") header: Dict[str, Any] = {} for i, segment_name in enumerate(labels): header.update( { f"Segment{i}_ID": segment_name, f"Segment{i}_Name": segment_name, } ) if color_map is not None: header[f"Segment{i}_Color"] = " ".join(list(map(str, color_map[segment_name]))) if affine is None: raise ValueError("Affine matrix has to be defined") kinds = ["list", "domain", "domain", "domain"] convert_aff_mat = np.diag([-1, -1, 1, 1]) affine = convert_aff_mat @ affine _origin_key = (slice(-1), -1) origin = affine[_origin_key] space_directions = np.array( [ [np.nan, np.nan, np.nan], affine[0, :3], affine[1, :3], affine[2, :3], ] ) header.update( { "kinds": kinds, "space directions": space_directions, "space origin": origin, "space": space, } ) nrrd.write( output_file, image_np, header=header, index_order=index_order, )
[docs]class Writer: def __init__( self, label="pred", json=None, ref_image=None, key_extension="result_extension", key_dtype="result_dtype", key_compress="result_compress", key_write_to_file="result_write_to_file", meta_key_postfix="meta_dict", nibabel=False, ): self.label = label self.json = json self.ref_image = ref_image if ref_image else label # User can specify through params self.key_extension = key_extension self.key_dtype = key_dtype self.key_compress = key_compress self.key_write_to_file = key_write_to_file self.meta_key_postfix = meta_key_postfix self.nibabel = nibabel def __call__(self, data) -> Tuple[Any, Any]: logger.setLevel(data.get("logging", "INFO").upper()) path = data.get("image_path") ext = file_ext(path) if path else None dtype = data.get(self.key_dtype, None) compress = data.get(self.key_compress, False) write_to_file = data.get(self.key_write_to_file, True) ext = data.get(self.key_extension) if data.get(self.key_extension) else ext write_to_file = write_to_file if ext else False logger.info(f"Result ext: {ext}; write_to_file: {write_to_file}; dtype: {dtype}") if isinstance(data[self.label], MetaTensor): image_np = data[self.label].array else: image_np = data[self.label] # Always using Restored as the last transform before writing meta_dict = data.get(f"{self.ref_image}_{self.meta_key_postfix}") affine = meta_dict.get("affine") if meta_dict else None if affine is None and isinstance(data[self.ref_image], MetaTensor): affine = data[self.ref_image].affine logger.debug(f"Image: {image_np.shape}; Data Image: {data[self.label].shape}") output_file = None output_json = data.get(self.json, {}) if write_to_file: output_file = tempfile.NamedTemporaryFile(suffix=ext).name logger.debug(f"Saving Image to: {output_file}") if self.is_multichannel_image(image_np): if ext != ".seg.nrrd": logger.warning( f"Using extension '{ext}' with multi-channel 4D label will probably fail" + "Consider to use extension '.seg.nrrd'" ) labels = data.get("labels") color_map = data.get("color_map") logger.debug("Using write_seg_nrrd...") write_seg_nrrd(image_np, output_file, dtype, affine, labels, color_map) # Issue with slicer:: https://discourse.itk.org/t/saving-non-orthogonal-volume-in-nifti-format/2760/22 elif self.nibabel and ext and ext.lower() in [".nii", ".nii.gz"]: logger.debug("Using MONAI write_nifti...") write_nifti(image_np, output_file, affine=affine, output_dtype=dtype) else: write_itk(image_np, output_file, affine if len(image_np.shape) > 2 else None, dtype, compress) else: output_file = image_np return output_file, output_json
[docs] def is_multichannel_image(self, image_np: np.ndarray) -> bool: """Check if the provided image contains multiple channels Args: image_np : Expected shape (channels, width, height, batch) Returns: bool: If this is a multi-channel image or not """ return len(image_np.shape) == 4 and image_np.shape[0] > 1
[docs]class ClassificationWriter: def __init__(self, label="pred", label_names=None): self.label = label self.label_names = label_names def __call__(self, data) -> Tuple[Any, Any]: logger.info(data[self.label].array) result = [] for idx, score in enumerate(data[self.label]): name = f"label_{idx}" name = self.label_names.get(idx) if self.label_names else name if name: result.append({"idx": idx, "label": name, "score": float(score)}) return None, {"prediction": result}
[docs]class PolygonWriter: def __init__( self, label="pred", json="result", key_write_to_file="result_write_to_file", key_annotations="annotations", key_label_colors="label_colors", key_output_format="output", ): self.label = label self.json = json self.key_write_to_file = key_write_to_file self.key_annotations = key_annotations self.key_label_colors = key_label_colors self.key_output_format = key_output_format self.format = format def __call__(self, data) -> Tuple[Any, Any]: loglevel = data.get("logging", "INFO").upper() logger.setLevel(loglevel) output = data.get(self.key_output_format, "dsa") logger.info(f"+++ Output Type: {output}") output_json = data.get(self.json, {}) write_to_file = data.get(self.key_write_to_file, True) if not write_to_file: return None, output_json res_json = { "name": f"MONAILabel Annotations - {data.get('model')}", "description": data.get("description"), "model": data.get("model"), "location": data.get("location"), "size": data.get("size"), "annotations": [output_json], "latencies": data.get("latencies"), } output_file = None if output == "asap": logger.info("+++ Generating ASAP XML Annotation") output_file, _ = create_asap_annotations_xml(res_json, loglevel=loglevel) elif output == "dsa": logger.info("+++ Generating DSA JSON Annotation") output_file, _ = create_dsa_annotations_json(res_json, loglevel=loglevel) else: logger.info("+++ Return Default JSON Annotation") return output_file, res_json
[docs]class DetectionWriter: def __init__( self, json="result", pred_box_key="box", pred_label_key="label", key_write_to_file="result_write_to_file", key_output_format="slicer", ): self.json = json self.pred_box_key = pred_box_key self.pred_label_key = pred_label_key self.key_write_to_file = key_write_to_file self.key_output_format = key_output_format self.format = format def __call__(self, data) -> Tuple[Any, Any]: loglevel = data.get("logging", "INFO").upper() logger.setLevel(loglevel) output = data.get(self.key_output_format, "slicer") logger.info(f"+++ Output Type: {output}") output_json = data.get(self.json, {}) write_to_file = data.get(self.key_write_to_file, True) if not write_to_file: return None, output_json res_json = { "name": f"MONAILabel Annotations - {data.get('model')}", "description": data.get("description"), "model": data.get("model"), "location": data.get("location"), "size": data.get("size"), "box": data.get(self.pred_box_key).cpu().detach().tolist(), "label": data.get(self.pred_label_key).cpu().detach().tolist(), "image": data.get("image_path", None), "latencies": data.get("latencies"), } output_files = None if output == "slicer": logger.info("+++ Generating Slicer Detection ROI Node JSON file") output_files, _ = create_slicer_detection_json(res_json, loglevel=loglevel) else: logger.info("+++ Return Default JSON Annotation") return output_files, res_json