import json
import logging
from typing import Callable, Dict, Hashable, List, Optional, Sequence, Union
import numpy as np
import torch
from monai.config import IndexSelection, KeysCollection
from monai.data import MetaTensor
from monai.transforms import MapTransform, Randomizable, Resize, SpatialCrop, generate_spatial_bounding_box, is_positive
from monai.utils import InterpolateMode, PostFix, ensure_tuple_rep
from scipy.ndimage import distance_transform_cdt, gaussian_filter
from skimage import measure
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class AddClickGuidanced(MapTransform):
    """Collect click coordinates from several keys into one guidance entry.

    For each key in ``keys`` (in iteration order) the stored value is converted
    to a list of integer coordinates and appended to a list, which is written to
    ``d[guidance]``. A key whose value is missing or falsy contributes ``[]``, so
    the output always has one entry per key.

    Args:
        keys: keys holding click lists (presumably positive clicks first, then
            negative clicks; the key order defines the guidance layout).
        allow_missing_keys: forwarded to ``MapTransform``.
        guidance: output key for the collected clicks.
    """

    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False, guidance="guidance"):
        super().__init__(keys, allow_missing_keys)
        self.guidance = guidance

    def __call__(self, data):
        d = dict(data)
        guidance = []
        for key in self.keys:
            g = d.get(key)
            # Truncate coordinates to ints; missing/empty entries become [].
            g = np.array(g).astype(int).tolist() if g else []
            # One entry per key, so positions in ``guidance`` match the key order.
            guidance.append(g)
        d[self.guidance] = guidance
        return d
class AddInitialSeedPointd(Randomizable, MapTransform):
    """Sample initial positive seed points from a label and store them as JSON guidance.

    One seed is drawn per selected connected region of ``d[label]``. Inside a
    region, voxels are sampled with probability proportional to ``exp(dist) - 1``,
    where ``dist`` is the chamfer distance (``distance_transform_cdt``) to the
    nearest zero voxel, so points far from the region boundary are preferred.
    Every key in ``keys`` receives the JSON string of ``[pos_points, neg_points]``,
    where ``neg_points`` holds one ``[-1, ..., -1]`` placeholder per positive point.

    Args:
        keys: output keys receiving the JSON-encoded guidance.
        allow_missing_keys: forwarded to ``MapTransform``.
        label: key of the label array used to sample seeds.
        connected_regions: if ``> 1``, connected components are labelled and the
            largest ``connected_regions`` of them (background excluded) each get a
            seed; otherwise the whole label is treated as a single region.
    """

    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False, label="label", connected_regions=1):
        super().__init__(keys, allow_missing_keys)
        self.label = label
        self.connected_regions = connected_regions

    def randomize(self, data=None):
        """No state is pre-drawn: all sampling happens in ``_apply`` via ``self.R``."""

    def _apply(self, label):
        """Return ``[pos_points, neg_placeholders]`` as nested int lists for ``label``."""
        if isinstance(label, torch.Tensor):
            # scipy/skimage operate on numpy arrays; this also handles non-CPU tensors.
            label = label.detach().cpu().numpy()
        default_guidance = [-1] * len(label.shape)

        if self.connected_regions > 1:
            blobs_labels = measure.label(label, background=0)
            region_ids, counts = np.unique(blobs_labels, return_counts=True)
            # Largest components first; id 0 is background and is dropped.
            ordered = region_ids[np.argsort(-counts)].astype(int).tolist()
            connected_regions = [r for r in ordered if r][: self.connected_regions]
        else:
            blobs_labels = None
            connected_regions = [1]

        pos_guidance = []
        for region in connected_regions:
            region_mask = label if blobs_labels is None else (blobs_labels == region).astype(int)
            if np.sum(region_mask) == 0:
                # Nothing to sample from (and the probabilities below would be all zero).
                continue

            distance = distance_transform_cdt(region_mask).flatten()
            probability = np.exp(distance) - 1.0

            idx = np.where(region_mask.flatten() > 0)[0]
            seed = self.R.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx]))
            dst = distance[seed]

            g = np.asarray(np.unravel_index(seed, region_mask.shape)).transpose().tolist()[0]
            # NOTE: the coordinate along axis 0 (presumably the channel axis) is
            # replaced by the seed's distance value ("for debug" upstream).
            g[0] = dst[0]
            pos_guidance.append(g)

        return np.asarray([pos_guidance, [default_guidance] * len(pos_guidance)]).astype(int, copy=False).tolist()

    def __call__(self, data):
        d = dict(data)
        self.randomize(data)
        for key in self.keys:
            d[key] = json.dumps(self._apply(d[self.label]))
        return d
class AddGuidanceSignald(MapTransform):
    """Append positive and negative guidance heat-map channels to an image.

    For each key, the first ``number_intensity_ch`` channels of the image are
    kept and two channels are appended: a smoothed map of the positive clicks
    (``guidance[0]``) and one of the negative clicks (``guidance[1]``). If there
    are no clicks at all, two zero channels are appended instead, so the output
    always has ``number_intensity_ch + 2`` channels. Guidance may be given as a
    nested list or as its JSON string.

    Args:
        keys: image keys to augment.
        allow_missing_keys: forwarded to ``MapTransform``.
        guidance: key holding ``[pos_points, neg_points]``.
        sigma: Gaussian smoothing sigma applied to the click maps.
        number_intensity_ch: number of leading image channels to keep.
    """

    def __init__(
        self,
        keys: KeysCollection,
        allow_missing_keys: bool = False,
        guidance: str = "guidance",
        sigma: int = 2,
        number_intensity_ch: int = 1,
    ):
        super().__init__(keys, allow_missing_keys)
        self.guidance = guidance
        self.sigma = sigma
        self.number_intensity_ch = number_intensity_ch

    def signal(self, shape, points):
        """Render ``points`` into a heat-map of spatial ``shape``.

        Points with any negative coordinate (e.g. ``-1`` placeholders) are skipped.
        For a 3-D ``shape`` the last three coordinates of each point are used,
        otherwise the last two. If at least one point was drawn, the map is
        Gaussian-smoothed with ``self.sigma`` and, when not constant, min-max
        normalised to ``[0, 1]``.

        Returns:
            float32 tensor of shape ``(1, *shape)``.
        """
        signal = np.zeros(shape, dtype=np.float32)
        drawn = False
        for p in points:
            if np.any(np.asarray(p) < 0):
                continue
            if len(shape) == 3:
                signal[int(p[-3]), int(p[-2]), int(p[-1])] = 1.0
            else:
                signal[int(p[-2]), int(p[-1])] = 1.0
            drawn = True

        if drawn:
            signal = gaussian_filter(signal, sigma=self.sigma)
            lo, hi = np.min(signal), np.max(signal)
            if hi > lo:  # avoid 0/0 (NaNs) for a constant map
                signal = (signal - lo) / (hi - lo)
        return torch.Tensor(signal)[None]

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            guidance = d[self.guidance]
            guidance = json.loads(guidance) if isinstance(guidance, str) else guidance

            # Keep only the intensity channels in both branches so the output
            # channel count does not depend on whether clicks are present.
            img = d[key][0 : self.number_intensity_ch, ...]
            if not isinstance(img, torch.Tensor):
                img = torch.Tensor(img)

            if guidance and (guidance[0] or guidance[1]):
                shape = img.shape[-2:] if len(img.shape) == 3 else img.shape[-3:]
                pos = self.signal(shape, guidance[0]).to(device=img.device)
                neg = self.signal(shape, guidance[1]).to(device=img.device)
            else:
                pos = neg = torch.zeros_like(img[0])[None]
            d[key] = torch.concat([img, pos, neg])
        return d
class SpatialCropForegroundd(MapTransform):
    """Crop images around the foreground of a source image.

    The bounding box of ``select_fn`` applied to ``d[source_key]`` (grown by
    ``margin``) is computed with ``generate_spatial_bounding_box``. If the box is
    smaller than ``spatial_size`` in every dimension, a ``spatial_size`` ROI
    centred on the box is cropped; otherwise the box itself is cropped. For each
    key, the crop start/end coordinates and the shapes before and after cropping
    are stored in the image's meta data under the configured keys.

    Args:
        keys: image keys to crop (each must carry a ``.meta`` dict).
        source_key: key of the image used to find the foreground.
        spatial_size: minimum ROI size used when the foreground box is smaller.
        select_fn: foreground selector passed to ``generate_spatial_bounding_box``.
        channel_indices: channels considered when computing the box.
        margin: margin added around the foreground box.
        allow_smaller: forwarded to ``generate_spatial_bounding_box``.
        start_coord_key / end_coord_key: meta keys for the crop box.
        original_shape_key / cropped_shape_key: meta keys for the shapes
            before and after cropping.
        allow_missing_keys: forwarded to ``MapTransform``.
    """

    def __init__(
        self,
        keys: KeysCollection,
        source_key: str,
        spatial_size: Union[Sequence[int], np.ndarray],
        select_fn: Callable = is_positive,
        channel_indices: Optional[IndexSelection] = None,
        margin: int = 0,
        allow_smaller: bool = True,
        start_coord_key: str = "foreground_start_coord",
        end_coord_key: str = "foreground_end_coord",
        original_shape_key: str = "foreground_original_shape",
        cropped_shape_key: str = "foreground_cropped_shape",
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.source_key = source_key
        self.spatial_size = list(spatial_size)
        self.select_fn = select_fn
        self.channel_indices = channel_indices
        self.margin = margin
        self.allow_smaller = allow_smaller
        self.start_coord_key = start_coord_key
        self.end_coord_key = end_coord_key
        self.original_shape_key = original_shape_key
        self.cropped_shape_key = cropped_shape_key

    def __call__(self, data):
        d = dict(data)
        box_start, box_end = generate_spatial_bounding_box(
            d[self.source_key], self.select_fn, self.channel_indices, self.margin, self.allow_smaller
        )

        center = list(np.mean([box_start, box_end], axis=0).astype(int, copy=False))
        current_size = list(np.subtract(box_end, box_start).astype(int, copy=False))

        if np.all(np.less(current_size, self.spatial_size)):
            cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size)
        else:
            cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
        # Record the box SpatialCrop actually uses (the centred ROI may be clipped at 0).
        box_start = np.array([s.start for s in cropper.slices])
        box_end = np.array([s.stop for s in cropper.slices])

        for key in self.keys:
            image = d[key]
            meta = image.meta
            meta[self.start_coord_key] = box_start
            meta[self.end_coord_key] = box_end
            meta[self.original_shape_key] = image.shape

            result = cropper(image)
            # Write onto the result's own meta: it need not share a dict with the input.
            result.meta[self.cropped_shape_key] = result.shape
            d[key] = result
        return d
class RestoreLabeld(MapTransform):
    """Paste (possibly resized) outputs back into the original, uncropped frame.

    For each key the image is first resized back to the shape stored under
    ``cropped_shape_key`` if its shape differs. It is then written into a
    zero-filled float32 array of the spatial shape stored under
    ``original_shape_key`` (its first/channel entry dropped), at the box
    ``[start_coord_key, end_coord_key)``. Meta data is read from
    ``d[ref_image].meta`` when ``d[ref_image]`` is a ``MetaTensor``, otherwise
    from ``d[f"{ref_image}_{meta_key_postfix}"]``.

    NOTE(review): the result has no channel axis, so the (resized) image is
    presumably single-channel — confirm against callers.

    Args:
        keys: keys of the arrays to restore.
        ref_image: key whose meta data holds the crop information.
        mode / align_corners: per-key interpolation settings for the resize.
        meta_key_postfix: postfix of the fallback meta-dict key.
        start_coord_key / end_coord_key / original_shape_key / cropped_shape_key:
            meta keys written by the cropping transforms.
        allow_missing_keys: forwarded to ``MapTransform``.
    """

    def __init__(
        self,
        keys: KeysCollection,
        ref_image: str,
        mode: Union[Sequence[Union[InterpolateMode, str]], InterpolateMode, str] = InterpolateMode.NEAREST,
        align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None,
        meta_key_postfix: str = PostFix.meta(),
        start_coord_key: str = "foreground_start_coord",
        end_coord_key: str = "foreground_end_coord",
        original_shape_key: str = "foreground_original_shape",
        cropped_shape_key: str = "foreground_cropped_shape",
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.ref_image = ref_image
        self.mode = ensure_tuple_rep(mode, len(self.keys))
        self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
        self.meta_key_postfix = meta_key_postfix
        self.start_coord_key = start_coord_key
        self.end_coord_key = end_coord_key
        self.original_shape_key = original_shape_key
        self.cropped_shape_key = cropped_shape_key

    def __call__(self, data):
        d = dict(data)
        ref = d[self.ref_image]
        meta_dict: Dict = (
            ref.meta if isinstance(ref, MetaTensor) else d[f"{self.ref_image}_{self.meta_key_postfix}"]
        )

        for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners):
            image = d[key]

            # Undo Resize
            current_shape = image.shape
            cropped_shape = meta_dict[self.cropped_shape_key]
            if np.any(np.not_equal(current_shape, cropped_shape)):
                resizer = Resize(spatial_size=cropped_shape[1:], mode=mode)
                image = resizer(image, mode=mode, align_corners=align_corners)

            # Undo Crop
            original_shape = meta_dict[self.original_shape_key][1:]
            result = np.zeros(original_shape, dtype=np.float32)
            box_start = meta_dict[self.start_coord_key]
            box_end = meta_dict[self.end_coord_key]

            spatial_dims = min(len(box_start), len(image.shape[1:]))
            slices = tuple(slice(s, e) for s, e in zip(box_start[:spatial_dims], box_end[:spatial_dims]))
            result[slices] = image.array if isinstance(image, MetaTensor) else image

            d[key] = result
        return d
class SpatialCropGuidanced(MapTransform):
    """Crop images around the clicks stored under ``guidance``.

    The bounding box of all positive and negative clicks, grown by ``margin`` and
    clipped to the image, is computed over the spatial shape of the first
    available key. If the box is smaller than ``spatial_size`` in every
    dimension, a ``spatial_size`` ROI centred on the box is cropped; otherwise
    the box itself is cropped. For each key the crop box and the shapes before
    and after cropping are stored in the image meta data. The clicks are shifted
    into the cropped frame and written back to ``d[guidance]``. Guidance may be a
    nested list or its JSON string.

    Args:
        keys: image keys to crop (each must carry a ``.meta`` dict).
        guidance: key holding ``[pos_clicks, neg_clicks]``.
        spatial_size: minimum ROI size; only its last ``ndim`` entries are used.
        margin: non-negative margin (scalar or per-dimension) around the clicks.
        start_coord_key / end_coord_key / original_shape_key / cropped_shape_key:
            meta keys for the crop box and shapes.
        allow_missing_keys: forwarded to ``MapTransform``.
    """

    def __init__(
        self,
        keys: KeysCollection,
        guidance: str,
        spatial_size,
        margin=20,
        start_coord_key: str = "foreground_start_coord",
        end_coord_key: str = "foreground_end_coord",
        original_shape_key: str = "foreground_original_shape",
        cropped_shape_key: str = "foreground_cropped_shape",
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.guidance = guidance
        self.spatial_size = list(spatial_size)
        self.margin = margin
        self.start_coord_key = start_coord_key
        self.end_coord_key = end_coord_key
        self.original_shape_key = original_shape_key
        self.cropped_shape_key = cropped_shape_key

    def bounding_box(self, points, img_shape):
        """Return ``(box_start, box_end)`` enclosing ``points`` grown by ``self.margin``.

        ``points`` is an ``(N, >=ndim)`` array whose first ``ndim = len(img_shape)``
        columns are used. The box is clipped to ``[0, img_shape]`` and ``box_end``
        is exclusive.

        Raises:
            ValueError: if any margin value is negative.
        """
        ndim = len(img_shape)
        margin = ensure_tuple_rep(self.margin, ndim)
        if any(m < 0 for m in margin):
            raise ValueError("margin value should not be negative number.")

        box_start = [0] * ndim
        box_end = [0] * ndim
        for di in range(ndim):
            dt = points[..., di]
            box_start[di] = max(min(dt - margin[di]), 0)
            box_end[di] = min(img_shape[di], max(dt + margin[di] + 1))
        return box_start, box_end

    def __call__(self, data):
        d: Dict = dict(data)
        first_key: Union[Hashable, List] = self.first_key(d)
        if first_key == []:  # no key present (a valid key may itself be falsy)
            return d

        guidance = d[self.guidance]
        guidance = json.loads(guidance) if isinstance(guidance, str) else guidance
        pos_clicks, neg_clicks = guidance[0], guidance[1]

        original_spatial_shape = d[first_key].shape[1:]
        # list() so numpy click arrays are concatenated rather than added element-wise.
        all_clicks = np.array(list(pos_clicks) + list(neg_clicks))
        box_start, box_end = self.bounding_box(all_clicks, original_spatial_shape)

        center = list(np.mean([box_start, box_end], axis=0).astype(int, copy=False))
        box_size = list(np.subtract(box_end, box_start).astype(int, copy=False))
        spatial_size = self.spatial_size[-len(box_size) :]

        if np.all(np.less(box_size, spatial_size)):
            cropper = SpatialCrop(roi_center=center, roi_size=spatial_size)
        else:
            cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)

        # update bounding box in case it was corrected by the SpatialCrop constructor
        box_start = np.array([s.start for s in cropper.slices])
        box_end = np.array([s.stop for s in cropper.slices])

        for key in self.keys:
            image = d[key]
            meta = image.meta
            meta[self.start_coord_key] = box_start
            meta[self.end_coord_key] = box_end
            meta[self.original_shape_key] = image.shape

            result = cropper(image)
            result.meta[self.cropped_shape_key] = result.shape
            d[key] = result

        # Express the clicks relative to the crop origin.
        pos = np.subtract(pos_clicks, box_start).tolist() if len(pos_clicks) else []
        neg = np.subtract(neg_clicks, box_start).tolist() if len(neg_clicks) else []
        d[self.guidance] = [pos, neg]
        return d
class ResizeGuidanced(MapTransform):
    """Rescale click coordinates to match a resized reference image.

    The scale factor per spatial dimension is ``current_shape / cropped_shape``.
    ``current_shape`` is the spatial shape of ``d[ref_image]``. ``cropped_shape``
    is the shape stored under ``cropped_shape_key`` in its meta data, with the
    first (channel) entry dropped. Failing that, ``meta["spatial_shape"]`` is
    used, or finally ``current_shape`` (factor 1). Each guidance key holds
    ``[pos_clicks, neg_clicks]`` (or its JSON string). Scaled coordinates are
    truncated to ints.

    Args:
        keys: guidance keys to rescale.
        ref_image: key of the (MetaTensor) reference image.
        cropped_shape_key: meta key holding the pre-resize shape.
        allow_missing_keys: forwarded to ``MapTransform``.
    """

    def __init__(
        self,
        keys: KeysCollection,
        ref_image: str,
        cropped_shape_key: str = "foreground_cropped_shape",
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.ref_image = ref_image
        self.cropped_shape_key = cropped_shape_key

    def __call__(self, data):
        d = dict(data)
        current_shape = d[self.ref_image].shape[1:]
        meta = d[self.ref_image].meta

        # ``is not None`` avoids ambiguous truthiness if the shape is stored as an array.
        if self.cropped_shape_key and meta.get(self.cropped_shape_key) is not None:
            cropped_shape = meta[self.cropped_shape_key][1:]
        else:
            cropped_shape = meta.get("spatial_shape", current_shape)
        factor = np.divide(current_shape, cropped_shape)

        for key in self.keys:
            guidance = d[key]
            guidance = json.loads(guidance) if isinstance(guidance, str) else guidance
            pos_clicks, neg_clicks = guidance[0], guidance[1]
            pos = np.multiply(pos_clicks, factor).astype(int, copy=False).tolist() if len(pos_clicks) else []
            neg = np.multiply(neg_clicks, factor).astype(int, copy=False).tolist() if len(neg_clicks) else []
            d[key] = [pos, neg]
        return d