# Source code for monai.apps.pathology.transforms.spatial.array

# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Tuple, Union

import numpy as np
import torch
from numpy.lib.stride_tricks import as_strided

from monai.config.type_definitions import NdarrayOrTensor
from monai.transforms.transform import Randomizable, Transform
from monai.utils import convert_data_type, convert_to_dst_type
from monai.utils.enums import TransformBackends

# Public names of this module: the ones exported by a star-import.
__all__ = ["SplitOnGrid", "TileOnGrid"]


class SplitOnGrid(Transform):
    """
    Split the image into patches based on the provided grid shape.
    This transform accepts both ``torch.Tensor`` and ``np.ndarray`` inputs. The image is indexed as a
    3D array whose first dimension is the channel and whose last two dimensions are split on the grid.

    Args:
        grid_size: a tuple or an integer that defines the shape of the grid upon which to extract patches.
            If it's an integer, the value will be repeated for each dimension. Default is 2x2.
        patch_size: a tuple or an integer that defines the output patch sizes.
            If it's an integer, the value will be repeated for each dimension.
            The default is ``None``, where the patch size is inferred from the grid shape
            (image size floor-divided by grid size).

    Note: the patch size and the steps are recomputed from the shape of every image passed to ``__call__``.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(
        self, grid_size: Union[int, Tuple[int, int]] = (2, 2), patch_size: Optional[Union[int, Tuple[int, int]]] = None
    ):
        # Grid size: a single integer is repeated for both spatial dimensions
        self.grid_size = (grid_size, grid_size) if isinstance(grid_size, int) else grid_size
        # Patch size: a single integer is repeated; ``None`` means "infer from the grid at call time"
        self.patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size

    def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Extract the grid of patches from ``image``.

        Args:
            image: the image to split; indexed as (channel, dim_0, dim_1).

        Returns:
            The patches stacked along a new first dimension, i.e. (n_patches, channel, patch_0, patch_1),
            of the same container type as the input. ``n_patches`` is ``grid_size[0] * grid_size[1]``.

        Raises:
            ValueError: if the input is neither a ``torch.Tensor`` nor a ``np.ndarray``, or if a numpy
                input is smaller than the requested patch size.
        """
        # Trivial case: a 1x1 grid without an explicit patch size is just the image with a leading axis
        if self.grid_size == (1, 1) and self.patch_size is None:
            if isinstance(image, torch.Tensor):
                return torch.stack([image])
            if isinstance(image, np.ndarray):
                return np.stack([image])
            raise ValueError(f"Input type [{type(image)}] is not supported.")

        patch_size, steps = self.get_params(image.shape[1:])
        patches: NdarrayOrTensor
        if isinstance(image, torch.Tensor):
            windows = image.unfold(1, patch_size[0], steps[0]).unfold(2, patch_size[1], steps[1])
            # The step is floor-divided, so ``unfold`` may yield more windows than the grid asks for;
            # keep only the first grid_size windows so the result matches the numpy branch.
            windows = windows[:, : self.grid_size[0], : self.grid_size[1]]
            patches = windows.flatten(1, 2).transpose(0, 1).contiguous()
        elif isinstance(image, np.ndarray):
            # ``as_strided`` performs no bounds checking: a patch larger than the image would read
            # memory outside of the array, so reject it explicitly.
            if any(patch_size[i] > image.shape[i + 1] for i in range(2)):
                raise ValueError(
                    f"patch_size {tuple(patch_size)} must not exceed the image size {tuple(image.shape[1:])}."
                )
            x_step, y_step = steps
            c_stride, x_stride, y_stride = image.strides
            n_channels = image.shape[0]
            patches = as_strided(
                image,
                shape=(*self.grid_size, n_channels, patch_size[0], patch_size[1]),
                strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride),
                writeable=False,
            )
            # flatten the first two (grid) dimensions
            patches = patches.reshape(patches.shape[0] * patches.shape[1], *patches.shape[2:])
            # make it a contiguous array
            patches = np.ascontiguousarray(patches)
        else:
            raise ValueError(f"Input type [{type(image)}] is not supported.")

        return patches

    def get_params(self, image_size):
        """
        Compute the patch size and the step between consecutive patches for a given image size.

        Args:
            image_size: the size of the two dimensions of the image that are split on the grid.

        Returns:
            A ``(patch_size, steps)`` pair of 2-element sequences.
        """
        if self.patch_size is None:
            patch_size = tuple(image_size[i] // self.grid_size[i] for i in range(2))
        else:
            patch_size = self.patch_size

        # Spread the patches so that the first starts at 0 and the last ends at (or just before) the
        # image border; a single patch along a dimension uses the whole dimension as its step.
        steps = tuple(
            (image_size[i] - patch_size[i]) // (self.grid_size[i] - 1) if self.grid_size[i] > 1 else image_size[i]
            for i in range(2)
        )

        return patch_size, steps
class TileOnGrid(Randomizable, Transform):
    """
    Tile the 2D image into patches on a grid and maintain a subset of it.
    The tiling is performed on a numpy array: the input is passed through ``convert_data_type`` first and
    the result is converted back with ``convert_to_dst_type``. The image is indexed as
    (channel, height, width).

    Args:
        tile_count: number of tiles to extract, if None extracts all non-background tiles
            Defaults to ``None``.
        tile_size: size of the square tile
            Defaults to ``256``.
        step: step size
            Defaults to ``None`` (same as tile_size)
        random_offset: Randomize position of the grid, instead of starting from the top-left corner
            Defaults to ``False``.
        pad_full: pad image to the size evenly divisible by tile_size
            Defaults to ``False``.
        background_val: the background constant (e.g. 255 for white background)
            Defaults to ``255``.
        filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is more than tile_count,
            then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for random) subset
            Defaults to ``min`` (which assumes background is high value)

    Raises:
        ValueError: if ``filter_mode`` is not one of "min", "max" or "random".
    """

    backend = [TransformBackends.NUMPY]

    def __init__(
        self,
        tile_count: Optional[int] = None,
        tile_size: int = 256,
        step: Optional[int] = None,
        random_offset: bool = False,
        pad_full: bool = False,
        background_val: int = 255,
        filter_mode: str = "min",
    ):
        self.tile_count = tile_count
        self.tile_size = tile_size
        self.random_offset = random_offset
        self.pad_full = pad_full
        self.background_val = background_val
        self.filter_mode = filter_mode

        # without an explicit step the grid is non-overlapping
        self.step = self.tile_size if step is None else step

        # state filled in by ``randomize``
        self.offset = (0, 0)
        self.random_idxs = np.array((0,))

        if self.filter_mode not in ["min", "max", "random"]:
            raise ValueError("Unsupported filter_mode, must be [min, max or random]: " + str(self.filter_mode))

    def randomize(self, img_size: Sequence[int]) -> None:
        """
        Draw the random grid offset and the random subset of tile indices for an image.

        Args:
            img_size: the (channel, height, width) shape of the image about to be tiled.
        """
        _, h, w = img_size

        self.offset = (0, 0)
        if self.random_offset:
            # the offset is drawn from the part of the image that does not fill a whole tile
            pad_h = h % self.tile_size
            pad_w = w % self.tile_size
            self.offset = (self.R.randint(pad_h) if pad_h > 0 else 0, self.R.randint(pad_w) if pad_w > 0 else 0)
            h = h - self.offset[0]
            w = w - self.offset[1]

        if self.pad_full:
            pad_h = (self.tile_size - h % self.tile_size) % self.tile_size
            pad_w = (self.tile_size - w % self.tile_size) % self.tile_size
            h = h + pad_h
            w = w + pad_w

        # number of tiles along each dimension; mirrors the shape used for tiling in ``__call__``
        h_n = (h - self.tile_size + self.step) // self.step
        w_n = (w - self.tile_size + self.step) // self.step
        tile_total = h_n * w_n

        if self.tile_count is not None and tile_total > self.tile_count:
            self.random_idxs = self.R.choice(range(tile_total), self.tile_count, replace=False)
        else:
            self.random_idxs = np.array((0,))

    def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Tile ``image`` and keep a subset of the tiles according to ``tile_count`` and ``filter_mode``.

        Args:
            image: the image to tile; indexed as (channel, height, width).

        Returns:
            The retained tiles stacked along a new first dimension, i.e.
            (n_tiles, channel, tile_size, tile_size), converted back with ``convert_to_dst_type``.
            When ``tile_count`` is None the number of tiles varies from image to image.
        """
        img_np: np.ndarray
        img_np, *_ = convert_data_type(image, np.ndarray)  # type: ignore

        # add random offset
        self.randomize(img_size=img_np.shape)

        if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0):
            img_np = img_np[:, self.offset[0] :, self.offset[1] :]

        # pad to full size, divisible by tile_size
        if self.pad_full:
            _, h, w = img_np.shape
            pad_h = (self.tile_size - h % self.tile_size) % self.tile_size
            pad_w = (self.tile_size - w % self.tile_size) % self.tile_size
            img_np = np.pad(
                img_np,
                [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]],
                constant_values=self.background_val,
            )

        # extract tiles
        x_step, y_step = self.step, self.step
        h_tile, w_tile = self.tile_size, self.tile_size
        c_image, h_image, w_image = img_np.shape
        c_stride, x_stride, y_stride = img_np.strides
        llw = as_strided(
            img_np,
            shape=((h_image - h_tile) // x_step + 1, (w_image - w_tile) // y_step + 1, c_image, h_tile, w_tile),
            strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride),
            writeable=False,
        )
        img_np = llw.reshape(-1, c_image, h_tile, w_tile)

        # if keeping all patches
        if self.tile_count is None:
            # retain only patches with significant foreground content to speed up inference
            # FYI, this returns a variable number of tiles, so the batch_size must be 1 (per gpu), e.g during inference
            # The threshold is 99.9% of the intensity sum of a tile made only of background; the sum runs
            # over every channel, so the actual channel count is used rather than assuming 3 channels.
            thresh = 0.999 * c_image * self.background_val * self.tile_size * self.tile_size
            if self.filter_mode == "min":
                # default, keep non-background tiles (small values)
                idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) < thresh)
                img_np = img_np[idxs.reshape(-1)]
            elif self.filter_mode == "max":
                idxs = np.argwhere(img_np.sum(axis=(1, 2, 3)) >= thresh)
                img_np = img_np[idxs.reshape(-1)]

        else:
            if len(img_np) > self.tile_count:

                if self.filter_mode == "min":
                    # default, keep non-background tiles (smallest values)
                    idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[: self.tile_count]
                    img_np = img_np[idxs]
                elif self.filter_mode == "max":
                    idxs = np.argsort(img_np.sum(axis=(1, 2, 3)))[-self.tile_count :]
                    img_np = img_np[idxs]
                else:
                    # random subset (more appropriate for WSIs without distinct background)
                    if self.random_idxs is not None:
                        img_np = img_np[self.random_idxs]

            elif len(img_np) < self.tile_count:
                # not enough tiles: append background-only tiles up to tile_count
                img_np = np.pad(
                    img_np,
                    [[0, self.tile_count - len(img_np)], [0, 0], [0, 0], [0, 0]],
                    constant_values=self.background_val,
                )

        image, *_ = convert_to_dst_type(src=img_np, dst=image, dtype=image.dtype)

        return image