Source code for monai.deploy.operators.monai_seg_inference_operator

# Copyright 2021-2023 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from pathlib import Path
from threading import Lock
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np

from monai.deploy.utils.importutil import optional_import
from monai.utils import StrEnum  # Will use the built-in StrEnum when SDK requires Python 3.11.

MONAI_UTILS = "monai.utils"
torch, _ = optional_import("torch", "1.5")
np_str_obj_array_pattern, _ = optional_import("torch.utils.data._utils.collate", name="np_str_obj_array_pattern")
Dataset, _ = optional_import("monai.data", name="Dataset")
DataLoader, _ = optional_import("monai.data", name="DataLoader")
ImageReader_, image_reader_ok_ = optional_import("monai.data", name="ImageReader")
# Dynamic class is not handled so make it Any for now: https://github.com/python/mypy/issues/2477
ImageReader: Any = ImageReader_
if not image_reader_ok_:
    ImageReader = object  # for 'class InMemImageReader(ImageReader):' to work
decollate_batch, _ = optional_import("monai.data", name="decollate_batch")
sliding_window_inference, _ = optional_import("monai.inferers", name="sliding_window_inference")
simple_inference, _ = optional_import("monai.inferers", name="SimpleInferer")
ensure_tuple, _ = optional_import(MONAI_UTILS, name="ensure_tuple")
MetaKeys, _ = optional_import(MONAI_UTILS, name="MetaKeys")
SpaceKeys, _ = optional_import(MONAI_UTILS, name="SpaceKeys")
Compose_, _ = optional_import("monai.transforms", name="Compose")
# Dynamic class is not handled so make it Any for now: https://github.com/python/mypy/issues/2477
Compose: Any = Compose_

from monai.deploy.core import AppContext, ConditionType, Fragment, Image, OperatorSpec

from .inference_operator import InferenceOperator

__all__ = ["MonaiSegInferenceOperator", "InfererType", "InMemImageReader"]


class InfererType(StrEnum):
    """Represents the supported types of the inferer, e.g. Simple and Sliding Window."""

    SIMPLE = "simple"
    SLIDING_WINDOW = "sliding_window"


# @md.env(pip_packages=["monai>=1.0.0", "torch>=1.10.2", "numpy>=1.21"])
class MonaiSegInferenceOperator(InferenceOperator):
    """This segmentation operator uses MONAI transforms and Sliding Window Inference.

    This operator performs pre-transforms on an input image, inference using a given model, and post-transforms.
    The segmentation image is saved as a named Image object in memory.
    If specified in the post transforms, results may also be saved to disk.

    Named Input:
        image: Image object of the input image.

    Named Output:
        seg_image: Image object of the segmentation image. Not requiring a ready receiver.
    """

    # For testing the app directly, the model should be at the following path.
    MODEL_LOCAL_PATH = Path(os.environ.get("HOLOSCAN_MODEL_PATH", Path.cwd() / "model/model.ts"))
    def __init__(
        self,
        fragment: Fragment,
        *args,
        roi_size: Optional[Union[Sequence[int], int]],
        pre_transforms: Compose,
        post_transforms: Compose,
        app_context: AppContext,
        model_name: Optional[str] = "",
        overlap: float = 0.25,
        sw_batch_size: int = 4,
        inferer: Union[InfererType, str] = InfererType.SLIDING_WINDOW,
        model_path: Path = MODEL_LOCAL_PATH,
        **kwargs,
    ):
        """Creates an instance of this class.

        Args:
            fragment (Fragment): An instance of the Application class which is derived from Fragment.
            roi_size (Union[Sequence[int], int]): The window size to execute "SLIDING_WINDOW" evaluation.
                Only needed for the "SLIDING_WINDOW" inferer; ignored when using the "SIMPLE" inferer.
            pre_transforms (Compose): MONAI Compose object used for pre-transforms.
            post_transforms (Compose): MONAI Compose object used for post-transforms.
            app_context (AppContext): Object holding the I/O and model paths, and potentially loaded models.
            model_name (str, optional): Name of the model. Defaults to "" for a single-model app.
            overlap (float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
                Applicable for "SLIDING_WINDOW" only.
            sw_batch_size (int): The batch size to run window slices. Defaults to 4.
                Applicable for "SLIDING_WINDOW" only.
            inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW".
                Defaults to "SLIDING_WINDOW".
            model_path (Path): Path to the model file. Defaults to model/model.ts in the current working dir.
        """

        self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
        self._executing = False
        self._lock = Lock()

        self._input_dataset_key = "image"
        self._pred_dataset_key = "pred"

        self._input_image = None  # Image will come in when compute is called.
        self._reader: Any = None
        self._roi_size = ensure_tuple(roi_size)
        self._pre_transform = pre_transforms
        self._post_transforms = post_transforms
        self._model_name = model_name.strip() if isinstance(model_name, str) else ""
        self._overlap = overlap
        self._sw_batch_size = sw_batch_size
        self._inferer = inferer

        # Add this so that the local model path can be set from the calling app.
        self.model_path = model_path

        self.input_name_image = "image"
        self.output_name_seg = "seg_image"

        # The execution context passed in on compute does not have the required model info, so need to
        # get and keep the model via the AppContext obj on construction.
        self.app_context = app_context
        self.model = self._get_model(self.app_context, self.model_path, self._model_name)

        super().__init__(fragment, *args, **kwargs)
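
    # A minimal usage sketch (illustrative only): `pre`, `post`, and `app_context` are assumed to be
    # built by the calling application, typically inside its compose() method.
    #
    #     seg_op = MonaiSegInferenceOperator(
    #         self,  # the Application (a Fragment) composing this operator
    #         roi_size=(96, 96, 96),
    #         pre_transforms=pre,
    #         post_transforms=post,
    #         app_context=app_context,
    #         inferer=InfererType.SLIDING_WINDOW,
    #         name="seg_inference_op",
    #     )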
    def _get_model(self, app_context: AppContext, model_path: Path, model_name: str):
        """Load the model with the given name from context or model path.

        Args:
            app_context (AppContext): The application context object holding the model(s).
            model_path (Path): The path to the model file, as a backup to load the model directly.
            model_name (str): The name of the model, when multiple are loaded in the context.
        """

        if app_context.models:
            # `app_context.models.get(model_name)` returns a model instance if it exists.
            # If model_name is not specified and only one model exists, it returns that model.
            model = app_context.models.get(model_name)
        else:
            self._logger.info(f"Loading TorchScript model from: {model_path!r}")
            model = torch.jit.load(
                model_path,  # Use the passed-in path, consistent with the log message above.
                map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            )

        return model
    def setup(self, spec: OperatorSpec):
        spec.input(self.input_name_image)
        spec.output(self.output_name_seg).condition(ConditionType.NONE)  # Downstream receiver optional.
    @property
    def roi_size(self):
        """The ROI size of tensors used in prediction."""
        return self._roi_size

    @roi_size.setter
    def roi_size(self, roi_size: Union[Sequence[int], int]):
        self._roi_size = ensure_tuple(roi_size)

    @property
    def input_dataset_key(self):
        """This is the input image key name used in dictionary-based MONAI pre-transforms."""
        return self._input_dataset_key

    @input_dataset_key.setter
    def input_dataset_key(self, val: str):
        if not val or len(val) < 1:
            raise ValueError("Value cannot be None or blank.")
        self._input_dataset_key = val

    @property
    def pred_dataset_key(self):
        """This is the prediction key name used in dictionary-based MONAI post-transforms."""
        return self._pred_dataset_key

    @pred_dataset_key.setter
    def pred_dataset_key(self, val: str):
        if not val or len(val) < 1:
            raise ValueError("Value cannot be None or blank.")
        self._pred_dataset_key = val

    @property
    def overlap(self):
        """This is the overlap used during sliding window inference."""
        return self._overlap

    @overlap.setter
    def overlap(self, val: float):
        if val < 0 or val > 1:
            raise ValueError("Overlap must be between 0 and 1.")
        self._overlap = val

    @property
    def sw_batch_size(self):
        """The batch size to run window slices."""
        return self._sw_batch_size

    @sw_batch_size.setter
    def sw_batch_size(self, val: int):
        if not isinstance(val, int) or val < 1:
            raise ValueError("sw_batch_size must be a positive integer.")
        self._sw_batch_size = val

    @property
    def inferer(self) -> Union[InfererType, str]:
        """The type of inferer to use."""
        return self._inferer

    @inferer.setter
    def inferer(self, val: InfererType):
        if not isinstance(val, InfererType):
            raise ValueError(f"Value must be of the correct type {InfererType}.")
        self._inferer = val

    def _convert_dicom_metadata_datatype(self, metadata: Dict):
        """Converts metadata in pydicom types to the corresponding native types.

        It is known that some values of the metadata are of the pydicom types, for images converted
        from DICOM series. Need to use this function to convert the types with best effort and for
        the few known metadata attributes, until the following issue is addressed:
        https://github.com/Project-MONAI/monai-deploy-app-sdk/issues/185

        Args:
            metadata (Dict): The metadata for an Image object.
        """

        if not metadata:
            return metadata

        # Try to convert data type for the well-known attributes. Add more as needed.
        if metadata.get("SeriesInstanceUID", None):
            try:
                metadata["SeriesInstanceUID"] = str(metadata["SeriesInstanceUID"])
            except Exception:
                pass
        if metadata.get("row_pixel_spacing", None):
            try:
                metadata["row_pixel_spacing"] = float(metadata["row_pixel_spacing"])
            except Exception:
                pass
        if metadata.get("col_pixel_spacing", None):
            try:
                metadata["col_pixel_spacing"] = float(metadata["col_pixel_spacing"])
            except Exception:
                pass

        self._logger.info("Converted Image object metadata:")
        for k, v in metadata.items():
            self._logger.info(f"{k}: {v}, type {type(v)}")

        return metadata
    def compute(self, op_input, op_output, context):
        """Infers with the input image and saves the predicted image to the output.

        Args:
            op_input (InputContext): An input context for the operator.
            op_output (OutputContext): An output context for the operator.
            context (ExecutionContext): An execution context for the operator.
        """

        with self._lock:
            if self._executing:
                raise RuntimeError("Operator is already executing.")
            else:
                self._executing = True
        try:
            input_image = op_input.receive(self.input_name_image)
            if not input_image:
                raise ValueError("Input is None.")
            op_output.emit(self.compute_impl(input_image, context), self.output_name_seg)
        finally:
            # Reset state on completing this method execution.
            with self._lock:
                self._executing = False
    def compute_impl(self, input_image, context):
        if not input_image:
            raise ValueError("Input is None.")

        # Need to try to convert the data type of a few metadata attributes.
        input_img_metadata = self._convert_dicom_metadata_datatype(input_image.metadata())
        # Need to give a name to the image as the in-memory Image obj has no name.
        img_name = str(input_img_metadata.get("SeriesInstanceUID", "Img_in_context"))

        self._reader = InMemImageReader(input_image)
        pre_transforms: Compose = self._pre_transform if self._pre_transform else self.pre_process(self._reader)
        post_transforms: Compose = (
            self._post_transforms if self._post_transforms else self.post_process(pre_transforms)
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        dataset = Dataset(data=[{self._input_dataset_key: img_name}], transform=pre_transforms)
        dataloader = DataLoader(
            dataset, batch_size=1, shuffle=False, num_workers=0
        )  # Should the batch_size be dynamic?

        with torch.no_grad():
            for d in dataloader:
                images = d[self._input_dataset_key].to(device)
                if self._inferer == InfererType.SLIDING_WINDOW:
                    d[self._pred_dataset_key] = sliding_window_inference(
                        inputs=images,
                        roi_size=self._roi_size,
                        sw_batch_size=self.sw_batch_size,
                        overlap=self.overlap,
                        predictor=self.model,
                    )
                elif self._inferer == InfererType.SIMPLE:
                    # Instantiates the SimpleInferer and directly uses its __call__ function.
                    d[self._pred_dataset_key] = simple_inference()(inputs=images, network=self.model)
                else:
                    raise ValueError(
                        f"Unknown inferer: {self._inferer!r}. Available options are "
                        f"{InfererType.SLIDING_WINDOW!r} and {InfererType.SIMPLE!r}."
                    )

                d = [post_transforms(i) for i in decollate_batch(d)]
                out_ndarray = d[0][self._pred_dataset_key].cpu().numpy()
                # Need to squeeze out the channel dim first.
                out_ndarray = np.squeeze(out_ndarray, 0)
                # NOTE: The domain Image object simply contains an array-like obj as image as of now.
                #       When the original DICOM series is converted by the Series to Volume operator,
                #       using pydicom pixel_array, the 2D ndarray of each slice has index order HW, and
                #       when all slices are stacked with depth as the first axis, DHW. In the pre-transforms,
                #       the image gets transposed to WHD and used as such in the inference pipeline.
                #       So once post-transforms have completed, and the channel is squeezed out,
                #       the resultant ndarray for the prediction image needs to be transposed back, so the
                #       array index order is back to DHW, the same order as the in-memory input Image obj.
                out_ndarray = out_ndarray.T.astype(np.uint8)
                self._logger.info(f"Output Seg image numpy array shape: {out_ndarray.shape}")
                self._logger.info(f"Output Seg image pixel max value: {np.amax(out_ndarray)}")
                return Image(out_ndarray, input_img_metadata)
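
    # Worked example of the index-order round trip described in compute_impl above, assuming a
    # volume of 64 slices of 512x512 pixels (shapes are illustrative):
    #
    #     in-memory Image array (DHW):                    shape (64, 512, 512)
    #     pre-transforms transpose to WHD, add a channel: shape (1, 512, 512, 64)
    #     after np.squeeze(out_ndarray, 0):               shape (512, 512, 64)
    #     after out_ndarray.T:                            shape (64, 512, 512), i.e. back to DHW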
    def pre_process(self, data: Any, *args, **kwargs) -> Union[Any, Image, Tuple[Any, ...], Dict[Any, Any]]:
        """Transforms input before being used for predicting on a model.

        This method must be overridden by a derived class.
        Expected return is monai.transforms.Compose.

        Args:
            data (monai.data.ImageReader): Reader used in LoadImage to load `monai.deploy.core.Image` as the input.

        Returns:
            monai.transforms.Compose encapsulating pre-processing transforms.

        Raises:
            NotImplementedError: When the subclass does not override this method.
        """

        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
    def post_process(self, data: Any, *args, **kwargs) -> Union[Any, Image, Tuple[Any, ...], Dict[Any, Any]]:
        """Transforms the prediction results from the model(s).

        This method must be overridden by a derived class.
        Expected return is monai.transforms.Compose.

        Args:
            data (monai.transforms.Compose): The pre-processing transforms in a Compose object.

        Returns:
            monai.transforms.Compose encapsulating post-processing transforms.

        Raises:
            NotImplementedError: When the subclass does not override this method.
        """

        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
    def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
        """Predicts results using the model(s) with input tensors.

        This method is currently not used in this class; instead, monai.inferers.sliding_window_inference is used.

        Raises:
            NotImplementedError: When the subclass does not override this method.
        """

        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

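# A minimal sketch of a subclass supplying pre/post transforms via the override hooks above, for
# the case where None is passed for pre_transforms/post_transforms at construction. The transform
# choices are illustrative, not prescriptive; LoadImaged, EnsureChannelFirstd, ScaleIntensityd,
# Activationsd, and AsDiscreted are standard MONAI transforms.
#
#     from monai.transforms import (
#         Activationsd, AsDiscreted, Compose, EnsureChannelFirstd, LoadImaged, ScaleIntensityd,
#     )
#
#     class MySegInferenceOperator(MonaiSegInferenceOperator):
#         def pre_process(self, data, *args, **kwargs) -> Compose:
#             # `data` is the InMemImageReader created in compute_impl.
#             return Compose(
#                 [
#                     LoadImaged(keys=self._input_dataset_key, reader=data),
#                     EnsureChannelFirstd(keys=self._input_dataset_key),
#                     ScaleIntensityd(keys=self._input_dataset_key),
#                 ]
#             )
#
#         def post_process(self, data, *args, **kwargs) -> Compose:
#             # `data` is the pre-transforms Compose, available for e.g. inverting transforms.
#             return Compose(
#                 [
#                     Activationsd(keys=self._pred_dataset_key, softmax=True),
#                     AsDiscreted(keys=self._pred_dataset_key, argmax=True),
#                 ]
#             )
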
class InMemImageReader(ImageReader):
    """Converts the App SDK Image object from memory.

    This is derived from MONAI ImageReader. Instead of reading an image from the file system, this
    class simply converts an in-memory SDK Image object to the formats expected of an ImageReader.

    The loaded data array will be in C order; for example, a 3D image NumPy array index order will be
    `WHDC`. The actual data array loaded is to be the same as that from the MONAI ITKReader, which can
    also load DICOM series. Furthermore, all readers need to return the array data the same way as the
    NibabelReader, i.e. a numpy array of index order WHDC with channel being the last dim if present.
    More details are in the get_data() function.
    """

    def __init__(self, input_image: Image, channel_dim: Optional[int] = None, **kwargs):
        super().__init__()
        self.input_image = input_image
        self.kwargs = kwargs
        self.channel_dim = channel_dim

    def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool:
        return True

    def read(self, data: Union[Sequence[str], str], **kwargs) -> Union[Sequence[Any], Any]:
        # Really does not have anything to do. Simply return the Image object.
        return self.input_image

    def get_data(self, input_image):
        """Extracts data array and metadata from the loaded image and returns them.

        This function returns two objects: the first is a numpy array of image data, and the second is
        a dict of metadata. It constructs `affine`, `original_affine`, and `spatial_shape` and stores
        them in the meta dict. A single image is loaded with a single set of metadata as of now.

        The App SDK Image asnumpy() function is expected to return a numpy array of index order `DHW`.
        This is because in the DICOM series to volume operator, the pydicom Dataset pixel_array is used
        to get the per-instance pixel numpy array, with index order of `HW`. When all instances are
        stacked along the first axis, the Image numpy array's index order is `DHW`. ITK
        array_view_from_image and SimpleITK GetArrayViewFromImage also return a numpy array with the
        index order of `DHW`. The channel would be the last dim/index if present. In the ITKReader
        get_data(), this numpy array is then transposed, and the channel axis moved to be the last dim
        post transpose; this is to be consistent with the numpy array returned from NibabelReader
        get_data(). The NibabelReader loads a NIfTI image and uses the get_fdata() function of the
        loaded image to get the numpy array, which has the index order WHD with the channel being the
        last dim if present.

        Args:
            input_image (Image): an App SDK Image object.
        """

        img_array: List[np.ndarray] = []
        compatible_meta: Dict = {}

        for i in ensure_tuple(input_image):
            if not isinstance(i, Image):
                raise TypeError("Only object of Image type is supported.")

            # The Image asnumpy() returns a NumPy array similar to ITK array_view_from_image.
            # The array then needs to be transposed, as is done in MONAI ITKReader, to align
            # with the output from the Nibabel reader loading NIfTI files.
            data = i.asnumpy().T
            img_array.append(data)
            header = self._get_meta_dict(i)
            _copy_compatible_dict(header, compatible_meta)

        # Stacking image is not really needed, as there is one image only.
        return _stack_images(img_array, compatible_meta), compatible_meta

    def _get_meta_dict(self, img: Image) -> Dict:
        """Gets the metadata of the image and converts it to dict type.

        Args:
            img: An SDK Image object.
        """

        img_meta_dict: Dict = img.metadata()
        meta_dict = {key: img_meta_dict[key] for key in img_meta_dict.keys()}

        # Will have to derive some key metadata as the SDK Image lacks the necessary interfaces.
        # So, for now, have to get to the Image generator, namely DICOMSeriesToVolumeOperator, and
        # rely on its published metadata.
        # Referring to the MONAI ITKReader, the spacing is simply a NumPy array from the ITK image
        # GetSpacing, in WHD.
        meta_dict["spacing"] = np.asarray(
            [
                img_meta_dict["row_pixel_spacing"],
                img_meta_dict["col_pixel_spacing"],
                img_meta_dict["depth_pixel_spacing"],
            ]
        )
        # Use the defined metadata keys directly.
        meta_dict[MetaKeys.ORIGINAL_AFFINE] = np.asarray(img_meta_dict.get("nifti_affine_transform", None))
        meta_dict[MetaKeys.AFFINE] = meta_dict[MetaKeys.ORIGINAL_AFFINE].copy()
        meta_dict[MetaKeys.SPACE] = SpaceKeys.LPS  # not using SpaceKeys.RAS or affine_lps_to_ras
        # The spatial shape, again referring to ITKReader, is in WHD.
        meta_dict[MetaKeys.SPATIAL_SHAPE] = np.asarray(img.asnumpy().T.shape)
        # No channel, as the image data shape is forced to be the same as the spatial shape.
        meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = "no_channel"

        return meta_dict


# Reuse MONAI code for the derived ImageReader
def _copy_compatible_dict(from_dict: Dict, to_dict: Dict):
    if not isinstance(to_dict, dict):
        raise ValueError(f"to_dict must be a Dict, got {type(to_dict)}.")
    if not to_dict:
        for key in from_dict:
            datum = from_dict[key]
            if isinstance(datum, np.ndarray) and np_str_obj_array_pattern.search(datum.dtype.str) is not None:
                continue
            to_dict[key] = datum
    else:
        affine_key, shape_key = MetaKeys.AFFINE, MetaKeys.SPATIAL_SHAPE
        if affine_key in from_dict and not np.allclose(from_dict[affine_key], to_dict[affine_key]):
            raise RuntimeError(
                "affine matrix of all images should be the same for channel-wise concatenation. "
                f"Got {from_dict[affine_key]} and {to_dict[affine_key]}."
            )
        if shape_key in from_dict and not np.allclose(from_dict[shape_key], to_dict[shape_key]):
            raise RuntimeError(
                "spatial_shape of all images should be the same for channel-wise concatenation. "
                f"Got {from_dict[shape_key]} and {to_dict[shape_key]}."
            )


def _stack_images(image_list: List, meta_dict: Dict):
    if len(image_list) <= 1:
        return image_list[0]
    if meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None) not in ("no_channel", None):
        channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
        return np.concatenate(image_list, axis=channel_dim)
    # Stack at a new first dim as the channel dim, if `original_channel_dim` is unspecified.
    meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
    return np.stack(image_list, axis=0)
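

# A minimal sketch of InMemImageReader feeding MONAI's LoadImaged directly, outside the operator.
# `some_image` is assumed to be a monai.deploy.core.Image, e.g. produced by
# DICOMSeriesToVolumeOperator; the transform choices are illustrative.
#
#     from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged
#
#     reader = InMemImageReader(some_image)
#     xform = Compose(
#         [
#             LoadImaged(keys="image", reader=reader),  # read() ignores the name, returns the Image
#             EnsureChannelFirstd(keys="image"),
#         ]
#     )
#     data = xform({"image": "unique_img_name"})["image"]  # array in channel-first WHD order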