# Source code for monailabel.transform.post

# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Union

import numpy as np
import skimage.measure as measure
from monai.config import KeysCollection
from monai.transforms import MapTransform, Resize, generate_spatial_bounding_box, get_extreme_points
from monai.transforms.spatial.dictionary import InterpolateModeSequence
from monai.utils import InterpolateMode, ensure_tuple_rep

# TODO: Consider moving these transforms upstream into MONAI.


class LargestCCd(MapTransform):
    """Keep only the largest connected component of each label channel.

    Args:
        keys: keys of the items in the data dict to process.
        has_channel: when True, ``d[key]`` is expected to carry a leading
            channel axis; when False a temporary one is added for processing
            and removed again afterwards.
    """

    def __init__(self, keys: KeysCollection, has_channel: bool = True):
        super().__init__(keys)
        self.has_channel = has_channel

    @staticmethod
    def get_largest_cc(label):
        """Return a mask of the largest connected component per channel.

        Iterates over the first axis of ``label`` (treated as channels). Each
        channel is labelled with ``skimage.measure.label`` using
        ``connectivity=1``, and the output for that channel is set to the
        largest non-zero component. Channels without any component stay all
        zeros. The output has the same shape and dtype as ``label``.
        """
        output = np.zeros(shape=label.shape, dtype=label.dtype)
        for channel_idx, channel in enumerate(label):
            components = measure.label(channel, connectivity=1)
            if components.max() == 0:
                # No foreground in this channel: leave it as zeros.
                continue
            counts = np.bincount(components.flat)
            # Skip bin 0 (background); +1 maps back to the component id.
            biggest = np.argmax(counts[1:]) + 1
            output[channel_idx, ...] = components == biggest
        return output

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            img = d[key]
            if not self.has_channel:
                # get_largest_cc iterates over a leading channel axis.
                img = img[np.newaxis]
            largest = self.get_largest_cc(img)
            d[key] = largest if self.has_channel else largest[0]
        return d
class ExtremePointsd(MapTransform):
    """Store the extreme points of each keyed item in a nested result dict.

    For every key, ``get_extreme_points(d[key])`` is computed and written, as a
    list of int lists, to ``d[result][points]``. ``d[result]`` is created as an
    empty dict when missing or None. All keys write to the same slot, so with
    several keys only the last key's points remain.

    Args:
        keys: keys of the items in the data dict to process.
        result: key of the dict that receives the output.
        points: key inside ``d[result]`` under which the points are stored.
    """

    def __init__(self, keys: KeysCollection, result: str = "result", points: str = "points"):
        super().__init__(keys)
        self.result = result
        self.points = points

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            extremes = get_extreme_points(d[key])
            store = d.get(self.result)
            if store is None:
                store = d[self.result] = {}
            store[self.points] = np.array(extremes).astype(int).tolist()
        return d
class BoundingBoxd(MapTransform):
    """Store the spatial bounding box of each keyed item in a nested result dict.

    For every key, ``generate_spatial_bounding_box(d[key])`` is computed and
    written, as a list of int lists, to ``d[result][bbox]``. ``d[result]`` is
    created as an empty dict when missing or None. All keys write to the same
    slot, so with several keys only the last key's box remains.

    Args:
        keys: keys of the items in the data dict to process.
        result: key of the dict that receives the output.
        bbox: key inside ``d[result]`` under which the box is stored.
    """

    def __init__(self, keys: KeysCollection, result: str = "result", bbox: str = "bbox"):
        super().__init__(keys)
        self.result = result
        self.bbox = bbox

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            box = generate_spatial_bounding_box(d[key])
            store = d.get(self.result)
            if store is None:
                store = d[self.result] = {}
            store[self.bbox] = np.array(box).astype(int).tolist()
        return d
class Restored(MapTransform):
    """Resize predictions back to a reference image's spatial shape and copy its affine.

    For every key, the spatial dims of ``d[key]`` are compared against the
    trailing entries of ``d[f"{ref_image}_{meta_key_postfix}"]["spatial_shape"]``.
    When they differ, the data is resized with ``monai.transforms.Resize``.
    Afterwards ``d[f"{key}_{meta_key_postfix}"]["affine"]`` is set to the
    reference meta dict's ``original_affine``; that meta dict is created if
    missing.

    Args:
        keys: keys of the items in the data dict to restore.
        ref_image: key of the reference image whose meta dict provides
            ``spatial_shape`` and ``original_affine``.
        has_channel: whether ``d[key]`` has a leading channel axis.
        mode: interpolation mode(s) used for resizing, one per key.
        align_corners: ``align_corners`` value(s) used for resizing, one per key.
        meta_key_postfix: postfix used to build meta dict keys.
    """

    def __init__(
        self,
        keys: KeysCollection,
        ref_image: str,
        has_channel: bool = True,
        mode: InterpolateModeSequence = InterpolateMode.NEAREST,
        align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None,
        meta_key_postfix: str = "meta_dict",
    ):
        super().__init__(keys)
        self.ref_image = ref_image
        self.has_channel = has_channel
        self.mode = ensure_tuple_rep(mode, len(self.keys))
        self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
        self.meta_key_postfix = meta_key_postfix

    def __call__(self, data):
        d = dict(data)
        meta_dict = d[f"{self.ref_image}_{self.meta_key_postfix}"]
        for idx, key in enumerate(self.keys):
            result = d[key]
            current_size = result.shape[1:] if self.has_channel else result.shape
            spatial_shape = meta_dict["spatial_shape"]
            spatial_size = spatial_shape[-len(current_size) :]

            # Undo Spacing
            if np.any(np.not_equal(current_size, spatial_size)):
                # Resize operates on channel-first data. Without a channel axis
                # the first spatial dim would be treated as channels, so add a
                # temporary one and drop it again after resizing.
                img = result if self.has_channel else result[np.newaxis]
                resizer = Resize(spatial_size=spatial_size, mode=self.mode[idx])
                img = resizer(img, mode=self.mode[idx], align_corners=self.align_corners[idx])
                result = img if self.has_channel else img[0]

            # Results with more than 3 dims and a singleton leading axis have
            # that axis dropped; everything else is stored as-is.
            if len(result.shape) > 3 and result.shape[0] == 1:
                result = result[0]
            d[key] = result

            meta_key = f"{key}_{self.meta_key_postfix}"
            meta = d.get(meta_key)
            if meta is None:
                meta = {}
                d[meta_key] = meta
            meta["affine"] = meta_dict.get("original_affine")
        return d