Source code for monailabel.transform.post

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, Hashable, Mapping, Optional, Sequence, Union

import cv2
import nibabel as nib
import numpy as np
import skimage.measure as measure
import torch
from monai.config import KeysCollection, NdarrayOrTensor
from monai.data import MetaTensor
from monai.transforms import (
    MapTransform,
    Orientation,
    Resize,
    Transform,
    generate_spatial_bounding_box,
    get_extreme_points,
)
from monai.utils import InterpolateMode, convert_to_numpy, ensure_tuple_rep
from shapely.geometry import Point, Polygon
from torchvision.utils import make_grid, save_image

from monailabel.utils.others.label_colors import get_color

logger = logging.getLogger(__name__)


# TODO: Move to MONAI?


class LargestCCd(MapTransform):
    def __init__(self, keys: KeysCollection, has_channel: bool = True):
        super().__init__(keys)
        self.has_channel = has_channel

    @staticmethod
    def get_largest_cc(label):
        largest_cc = np.zeros(shape=label.shape, dtype=label.dtype)
        for i, item in enumerate(label):
            item = measure.label(item, connectivity=1)
            if item.max() != 0:
                largest_cc[i, ...] = item == (np.argmax(np.bincount(item.flat)[1:]) + 1)
        return largest_cc

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = self.get_largest_cc(d[key] if self.has_channel else d[key][np.newaxis])
            d[key] = result if self.has_channel else result[0]
        return d
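
# Usage sketch (illustrative, not part of the module; the key name "pred" and the
# mask shape are assumptions):
#
#   largest_cc = LargestCCd(keys="pred")
#   data = {"pred": np.random.randint(0, 2, size=(1, 64, 64)).astype(np.uint8)}
#   data = largest_cc(data)  # each channel keeps only its largest connected component
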
class ExtremePointsd(MapTransform):
    def __init__(self, keys: KeysCollection, result: str = "result", points: str = "points"):
        super().__init__(keys)
        self.result = result
        self.points = points

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            try:
                points = get_extreme_points(d[key])
                if d.get(self.result) is None:
                    d[self.result] = dict()
                d[self.result][self.points] = np.array(points).astype(int).tolist()
            except ValueError:
                pass
        return d
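
# Usage sketch (illustrative; assumes "pred" holds a binary foreground mask):
#
#   extreme_points = ExtremePointsd(keys="pred")
#   data = extreme_points({"pred": mask})
#   data["result"]["points"]  # extreme-point coordinates as a list of ints
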
class BoundingBoxd(MapTransform):
    def __init__(self, keys: KeysCollection, result: str = "result", bbox: str = "bbox"):
        super().__init__(keys)
        self.result = result
        self.bbox = bbox

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            bbox = generate_spatial_bounding_box(d[key])
            if d.get(self.result) is None:
                d[self.result] = dict()
            d[self.result][self.bbox] = np.array(bbox).astype(int).tolist()
        return d
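
# Usage sketch (illustrative; assumes "pred" is a channel-first binary mask, as
# expected by generate_spatial_bounding_box):
#
#   bbox = BoundingBoxd(keys="pred")
#   data = bbox({"pred": mask})
#   data["result"]["bbox"]  # [start_coords, end_coords] of the foreground region
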
class Restored(MapTransform):
    def __init__(
        self,
        keys: KeysCollection,
        ref_image: str,
        has_channel: bool = True,
        invert_orient: bool = False,
        mode: str = InterpolateMode.NEAREST,
        config_labels=None,
        align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None,
        meta_key_postfix: str = "meta_dict",
    ):
        super().__init__(keys)
        self.ref_image = ref_image
        self.has_channel = has_channel
        self.invert_orient = invert_orient
        self.config_labels = config_labels
        self.mode = ensure_tuple_rep(mode, len(self.keys))
        self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
        self.meta_key_postfix = meta_key_postfix

    def __call__(self, data):
        d = dict(data)
        meta_dict = (
            d[self.ref_image].meta
            if d.get(self.ref_image) is not None and isinstance(d[self.ref_image], MetaTensor)
            else d.get(f"{self.ref_image}_{self.meta_key_postfix}", {})
        )

        for idx, key in enumerate(self.keys):
            result = d[key]
            current_size = result.shape[1:] if self.has_channel else result.shape
            spatial_shape = meta_dict.get("spatial_shape", current_size)
            spatial_size = spatial_shape[-len(current_size):]

            # Undo Spacing
            if np.any(np.not_equal(current_size, spatial_size)):
                resizer = Resize(spatial_size=spatial_size, mode=self.mode[idx])
                result = resizer(result, mode=self.mode[idx], align_corners=self.align_corners[idx])

            if self.invert_orient:
                # Undo Orientation
                orig_affine = meta_dict.get("original_affine", None)
                if orig_affine is not None:
                    orig_axcodes = nib.orientations.aff2axcodes(orig_affine)
                    inverse_transform = Orientation(axcodes=orig_axcodes)
                    # Apply inverse
                    with inverse_transform.trace_transform(False):
                        result = inverse_transform(result)
                else:
                    logger.info("Failed to invert orientation: original_affine is not in the image header")

            # Convert label indexes back to the ones originally defined in the config file
            if self.config_labels is not None:
                new_pred = result * 0.0
                for j, (label_name, cfg_idx) in enumerate(self.config_labels.items(), 1):
                    # Consider only labels other than background
                    if label_name != "background":
                        new_pred[result == j] = cfg_idx
                result = new_pred

            d[key] = result if len(result.shape) <= 3 else result[0] if result.shape[0] == 1 else result

            meta = d.get(f"{key}_{self.meta_key_postfix}")
            if meta is None:
                meta = dict()
                d[f"{key}_{self.meta_key_postfix}"] = meta
            meta["affine"] = meta_dict.get("original_affine")
        return d
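
# Usage sketch (illustrative; assumes "image" is a MetaTensor carrying the
# original spatial_shape/affine and "pred" is the channel-first model output):
#
#   restore = Restored(keys="pred", ref_image="image")
#   data = restore(data)
#   # "pred" is resized back to the reference image's spatial shape, and the
#   # original affine is written to data["pred_meta_dict"]["affine"]
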
class FindContoursd(MapTransform):
    def __init__(
        self,
        keys: KeysCollection,
        min_positive=10,
        min_poly_area=80,
        max_poly_area=0,
        result="result",
        result_output_key="annotation",
        key_label_colors="label_colors",
        key_foreground_points=None,
        labels=None,
        colormap=None,
    ):
        super().__init__(keys)
        self.min_positive = min_positive
        self.min_poly_area = min_poly_area
        self.max_poly_area = max_poly_area
        self.result = result
        self.result_output_key = result_output_key
        self.key_label_colors = key_label_colors
        self.key_foreground_points = key_foreground_points
        self.colormap = colormap

        labels = labels if labels else dict()
        labels = [labels] if isinstance(labels, str) else labels
        if not isinstance(labels, dict):
            labels = {v: k + 1 for k, v in enumerate(labels)}
        labels = {v: k for k, v in labels.items()}
        self.labels = labels

    def __call__(self, data):
        d = dict(data)
        location = d.get("location", [0, 0])
        size = d.get("size", [0, 0])
        min_poly_area = d.get("min_poly_area", self.min_poly_area)
        max_poly_area = d.get("max_poly_area", self.max_poly_area)
        color_map = d.get(self.key_label_colors) if self.colormap is None else self.colormap

        foreground_points = d.get(self.key_foreground_points, []) if self.key_foreground_points else []
        foreground_points = [Point(pt[0], pt[1]) for pt in foreground_points]

        # polygons in (x, y) format
        elements = []
        label_names = set()
        for key in self.keys:
            p = d[key]
            if np.count_nonzero(p) < self.min_positive:
                continue

            labels = [label for label in np.unique(p).tolist() if label > 0]
            logger.debug(f"Total Unique Masks (excluding background): {labels}")

            for label_idx in labels:
                p = convert_to_numpy(d[key]) if isinstance(d[key], torch.Tensor) else d[key]
                p = np.where(p == label_idx, 1, 0).astype(np.uint8)
                p = np.moveaxis(p, 0, 1)  # for cv2

                label_name = self.labels.get(label_idx, label_idx)
                label_names.add(label_name)

                polygons = []
                contours, _ = cv2.findContours(p, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
                for contour in contours:
                    if len(contour) < 3:
                        continue

                    contour = np.squeeze(contour)
                    area = cv2.contourArea(contour)
                    if area < min_poly_area:  # Ignore polygons with a smaller area
                        continue
                    if 0 < max_poly_area < area:  # Ignore very large polygons (e.g. in case of nuclei)
                        continue

                    contour[:, 0] += location[0]  # X
                    contour[:, 1] += location[1]  # Y

                    coords = contour.astype(int).tolist()
                    if foreground_points:
                        for pt in foreground_points:
                            if Polygon(coords).contains(pt):
                                polygons.append(coords)
                                break
                    else:
                        polygons.append(coords)

                if len(polygons):
                    logger.debug(f"+++++ {label_idx} => Total Polygons Found: {len(polygons)}")
                    elements.append({"label": label_name, "contours": polygons})

        if elements:
            if d.get(self.result) is None:
                d[self.result] = dict()
            d[self.result][self.result_output_key] = {
                "location": location,
                "size": size,
                "elements": elements,
                "labels": {n: get_color(n, color_map) for n in label_names},
            }
            logger.debug(f"+++++ ALL => Total Annotation Elements Found: {len(elements)}")
        return d
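
# Usage sketch (illustrative; assumes a 2D label mask under "pred" where pixel
# value 1 means "Nuclei"; location/size describe the patch within the slide):
#
#   find_contours = FindContoursd(keys="pred", labels={"Nuclei": 1})
#   data = find_contours({"pred": mask, "location": [0, 0], "size": [1024, 1024]})
#   data["result"]["annotation"]["elements"]  # per-label polygon contours in (x, y)
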
class DumpImagePrediction2Dd(Transform):
    def __init__(self, image_path, pred_path, pred_only=True):
        self.image_path = image_path
        self.pred_path = pred_path
        self.pred_only = pred_only

    def __call__(self, data):
        d = dict(data)
        for bidx in range(d["image"].shape[0]):
            image = np.moveaxis(d["image"][bidx], 1, 2)
            pred = np.moveaxis(d["pred"][bidx], 0, 1)

            img_tensor = make_grid(torch.from_numpy(image[:3] * 128 + 128), normalize=True)
            save_image(img_tensor, self.image_path)

            if self.pred_only:
                pred_tensor = make_grid(torch.from_numpy(pred), normalize=True)
                save_image(pred_tensor[0], self.pred_path)
                return d

            image_pred = [pred[None], image[3][None], image[4][None]] if image.shape[0] == 5 else [pred[None]]
            image_pred_np = np.array(image_pred)
            image_pred_t = torch.from_numpy(image_pred_np)

            tensor = make_grid(
                tensor=image_pred_t,
                nrow=len(image_pred),
                normalize=True,
                pad_value=10,
            )
            save_image(tensor, self.pred_path)
        return d
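
# Usage sketch (illustrative; the batched shapes are assumptions: "image" of shape
# (B, C, H, W) with at least 3 channels and "pred" of shape (B, H, W)):
#
#   dump = DumpImagePrediction2Dd(image_path="image.png", pred_path="pred.png")
#   dump({"image": image_batch, "pred": pred_batch})
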
class MergeAllPreds(MapTransform):
    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
        """
        Merge all predictions into one channel.

        Args:
            keys: the ``keys`` parameter will be used to get and set the actual data item to transform
        """
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]):
        d: Dict = dict(data)
        merge_image = None
        for idx, key in enumerate(self.key_iterator(d)):
            if idx == 0:
                merge_image = d[key]
            else:
                merge_image = merge_image + d[key]
                # For labels that overlap, keep only the last label number
                merge_image[merge_image > d[key].max()] = d[key].max()
        return merge_image
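
# Usage sketch (illustrative key names). Note that this transform returns the
# merged array itself rather than the data dictionary:
#
#   merge = MergeAllPreds(keys=["pred_spleen", "pred_liver"])
#   merged = merge(data)
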
class RenameKeyd(Transform):
    def __init__(self, source_key, target_key):
        self.source_key = source_key
        self.target_key = target_key

    def __call__(self, data):
        d = dict(data)
        d[self.target_key] = d.pop(self.source_key)
        return d
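
# Usage sketch (illustrative key names):
#
#   rename = RenameKeyd(source_key="pred", target_key="seg")
#   data = rename(data)  # moves the value stored under "pred" to "seg"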