Creating a Segmentation App Including Visualization with MONAI Deploy App SDK

This tutorial shows how to create an organ segmentation application for a PyTorch model that has been trained with MONAI, and visualize the segmentation and input images with Clara Viz integration.

Deploying AI models requires integration with clinical imaging networks, even in a for-research-use setting. This means that the AI deploy application needs to support standards-based imaging protocols and, specifically for radiological imaging, the DICOM protocol.

Typically, DICOM network communication, whether over the DICOM TCP/IP network protocol or DICOMweb, is handled by DICOM devices or services, e.g. MONAI Deploy Informatics Gateway, so the deploy application itself only needs to read DICOM Part 10 files as input and save the AI result in DICOM Part 10 file(s). For segmentation use cases, the DICOM instance file could be a DICOM Segmentation object or a DICOM RT Structure Set, and for classification, a DICOM Structured Report and/or DICOM Encapsulated PDF.

During model training, input and label images are typically in a non-DICOM volumetric image format, e.g., NIfTI or PNG, converted from a specific DICOM study series. Furthermore, the voxel spacings have most likely been re-sampled to be uniform across all images. When integrated with imaging networks and receiving DICOM instances from modalities and a Picture Archiving and Communication System (PACS), an AI deploy application may have to deal with a whole DICOM study with multiple series, whose image spacing may not match what the trained model expects. To address these cases consistently and efficiently, the MONAI Deploy Application SDK provides classes, called operators, to parse DICOM studies, select specific series with application-defined rules, and convert the selected DICOM series into a domain-specific image format along with metadata representing the pertinent DICOM attributes.

In the following sections, we will demonstrate how to create a MONAI Deploy application package using the MONAI Deploy App SDK.

Note

For local testing, if DICOM Part 10 files are not available, one can use open source programs, e.g. 3D Slicer, to convert NIfTI files to DICOM files.

Creating Operators and connecting them in Application class

We will implement an application that consists of six Operators:

  • DICOMDataLoaderOperator:

    • Input(dicom_files): a folder path (DataPath)

    • Output(dicom_study_list): a list of DICOM studies in memory (List[DICOMStudy])

  • DICOMSeriesSelectorOperator:

    • Input(dicom_study_list): a list of DICOM studies in memory (List[DICOMStudy])

    • Input(selection_rules): a selection rule (Dict)

    • Output(study_selected_series_list): a DICOM series object in memory (StudySelectedSeries)

  • DICOMSeriesToVolumeOperator:

    • Input(study_selected_series_list): a DICOM series object in memory (StudySelectedSeries)

    • Output(image): an image object in memory (Image)

  • SpleenSegOperator:

    • Input(image): an image object in memory (Image)

    • Output(seg_image): an image object in memory (Image)

  • DICOMSegmentationWriterOperator:

    • Input(seg_image): a segmentation image object in memory (Image)

    • Input(study_selected_series_list): a DICOM series object in memory (StudySelectedSeries)

    • Output(dicom_seg_instance): a file path (DataPath)

  • ClaraVizOperator:

    • Input(image): a volume image object in memory (Image)

    • Input(seg_image): a segmentation image object in memory (Image)

Note

The DICOMSegmentationWriterOperator needs both the segmentation image and the original DICOM series metadata, so that it can use the patient demographics and the DICOM Study level attributes.

The workflow of the application would look like this.

%%{init: {"theme": "base", "themeVariables": { "fontSize": "16px"}} }%%
classDiagram
    direction TB

    DICOMDataLoaderOperator --|> DICOMSeriesSelectorOperator : dicom_study_list...dicom_study_list
    DICOMSeriesSelectorOperator --|> DICOMSeriesToVolumeOperator : study_selected_series_list...study_selected_series_list
    DICOMSeriesToVolumeOperator --|> SpleenSegOperator : image...image
    DICOMSeriesSelectorOperator --|> DICOMSegmentationWriterOperator : study_selected_series_list...study_selected_series_list
    SpleenSegOperator --|> DICOMSegmentationWriterOperator : seg_image...seg_image
    DICOMSeriesToVolumeOperator --|> ClaraVizOperator : image...image
    SpleenSegOperator --|> ClaraVizOperator : seg_image...seg_image

    class DICOMDataLoaderOperator {
        <in>dicom_files : DISK
        dicom_study_list(out) IN_MEMORY
    }
    class DICOMSeriesSelectorOperator {
        <in>dicom_study_list : IN_MEMORY
        <in>selection_rules : IN_MEMORY
        study_selected_series_list(out) IN_MEMORY
    }
    class DICOMSeriesToVolumeOperator {
        <in>study_selected_series_list : IN_MEMORY
        image(out) IN_MEMORY
    }
    class SpleenSegOperator {
        <in>image : IN_MEMORY
        seg_image(out) IN_MEMORY
    }
    class DICOMSegmentationWriterOperator {
        <in>seg_image : IN_MEMORY
        <in>study_selected_series_list : IN_MEMORY
        dicom_seg_instance(out) DISK
    }
    class ClaraVizOperator {
        <in>image : IN_MEMORY
        <in>seg_image : IN_MEMORY
    }
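
The connections in this workflow form a directed acyclic graph. As a plain-Python sketch that does not use the App SDK at all (the FLOWS table and the execution_order helper are illustrative names, not SDK API), the edges can be listed and sorted topologically to see one valid execution order:

```python
from collections import defaultdict, deque

# (source operator, destination operator, port name shared by both ends)
FLOWS = [
    ("DICOMDataLoaderOperator", "DICOMSeriesSelectorOperator", "dicom_study_list"),
    ("DICOMSeriesSelectorOperator", "DICOMSeriesToVolumeOperator", "study_selected_series_list"),
    ("DICOMSeriesToVolumeOperator", "SpleenSegOperator", "image"),
    ("DICOMSeriesSelectorOperator", "DICOMSegmentationWriterOperator", "study_selected_series_list"),
    ("SpleenSegOperator", "DICOMSegmentationWriterOperator", "seg_image"),
    ("DICOMSeriesToVolumeOperator", "ClaraVizOperator", "image"),
    ("SpleenSegOperator", "ClaraVizOperator", "seg_image"),
]


def execution_order(flows):
    """Return one topological order of the operators (Kahn's algorithm)."""
    successors = defaultdict(list)
    in_degree = defaultdict(int)
    nodes = []
    for src, dst, _port in flows:
        for node in (src, dst):
            if node not in nodes:
                nodes.append(node)
        successors[src].append(dst)
        in_degree[dst] += 1

    # Operators with no incoming edge can run first.
    ready = deque(n for n in nodes if in_degree[n] == 0)
    order = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for nxt in successors[node]:
            in_degree[nxt] -= 1
            if in_degree[nxt] == 0:
                ready.append(nxt)
    if len(order) != len(nodes):
        raise ValueError("The flows contain a cycle.")
    return order


if __name__ == "__main__":
    for name in execution_order(FLOWS):
        print(name)
```

The two operators with two inputs each, DICOMSegmentationWriterOperator and ClaraVizOperator, only become ready after SpleenSegOperator has produced seg_image.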

Setup environment

# Install MONAI and other necessary image processing packages for the application
!python -c "import monai" || pip install --upgrade -q "monai"
!python -c "import torch" || pip install -q "torch>=1.10.2"
!python -c "import numpy" || pip install -q "numpy>=1.21"
!python -c "import nibabel" || pip install -q "nibabel>=3.2.1"
!python -c "import pydicom" || pip install -q "pydicom>=1.4.2"
!python -c "import highdicom" || pip install -q "highdicom>=0.18.2"
!python -c "import SimpleITK" || pip install -q "SimpleITK>=2.0.0"
!python -c "import typeguard" || pip install -q "typeguard~=2.12.1"

# Install MONAI Deploy App SDK package
!python -c "import monai.deploy" || pip install --upgrade -q "monai-deploy-app-sdk"

# Install Clara Viz package
!python -c "import clara.viz" || pip install --upgrade -q "clara-viz"

Note: you may need to restart the Jupyter kernel to use the updated packages.
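
After the installation, and a kernel restart if needed, it can help to confirm which versions are actually visible to the kernel. The following is a minimal sketch using only the standard library; the REQUIRED list simply repeats the distribution names from the pip commands above, and installed_versions is an illustrative helper, not part of any of these packages.

```python
from importlib import metadata

# Distribution names as used by the pip commands above.
REQUIRED = [
    "monai",
    "torch",
    "numpy",
    "nibabel",
    "pydicom",
    "highdicom",
    "SimpleITK",
    "typeguard",
    "monai-deploy-app-sdk",
    "clara-viz",
]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(REQUIRED).items():
        print(f"{name:24s} {version or 'NOT INSTALLED'}")
```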

Download/Extract ai_spleen_seg_bundle_data from Google Drive

# Download the ai_spleen_seg_bundle_data test data zip file
!pip install gdown
!gdown "https://drive.google.com/uc?id=1Uds8mEvdGNYUuvFpTtCQ8gNU97bAPCaQ"

# After downloading the ai_spleen_seg_bundle_data zip file, from the web browser or using gdown, extract it
!unzip -o "ai_spleen_seg_bundle_data.zip"
Requirement already satisfied: gdown in /home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages (4.5.1)
Requirement already satisfied: six in /home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages (from gdown) (1.16.0)
Requirement already satisfied: filelock in /home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages (from gdown) (3.8.0)
Requirement already satisfied: tqdm in /home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages (from gdown) (4.64.0)
Requirement already satisfied: requests[socks] in /home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages (from gdown) (2.28.1)
Requirement already satisfied: beautifulsoup4 in /home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages (from gdown) (4.11.1)
Requirement already satisfied: soupsieve>1.2 in /home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages (from beautifulsoup4->gdown) (2.3.2.post1)
Requirement already satisfied: charset-normalizer<3,>=2 in /home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages (from requests[socks]->gdown) (2.1.1)
Requirement already satisfied: idna<4,>=2.5 in /home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages (from requests[socks]->gdown) (3.3)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages (from requests[socks]->gdown) (1.26.12)
Requirement already satisfied: certifi>=2017.4.17 in /home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages (from requests[socks]->gdown) (2022.6.15)
Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages (from requests[socks]->gdown) (1.7.1)
Downloading...
From: https://drive.google.com/uc?id=1Uds8mEvdGNYUuvFpTtCQ8gNU97bAPCaQ
To: /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/ai_spleen_seg_bundle_data.zip
100%|███████████████████████████████████████| 79.4M/79.4M [00:00<00:00, 102MB/s]
Archive:  ai_spleen_seg_bundle_data.zip
  inflating: dcm/1-001.dcm           
  inflating: dcm/1-002.dcm           
  inflating: dcm/1-003.dcm           
  inflating: dcm/1-004.dcm           
  inflating: dcm/1-005.dcm           
  inflating: dcm/1-006.dcm           
  inflating: dcm/1-007.dcm           
  inflating: dcm/1-008.dcm           
  inflating: dcm/1-009.dcm           
  inflating: dcm/1-010.dcm           
  inflating: dcm/1-011.dcm           
  inflating: dcm/1-012.dcm           
  inflating: dcm/1-013.dcm           
  inflating: dcm/1-014.dcm           
  inflating: dcm/1-015.dcm           
  inflating: dcm/1-016.dcm           
  inflating: dcm/1-017.dcm           
  inflating: dcm/1-018.dcm           
  inflating: dcm/1-019.dcm           
  inflating: dcm/1-020.dcm           
  inflating: dcm/1-021.dcm           
  inflating: dcm/1-022.dcm           
  inflating: dcm/1-023.dcm           
  inflating: dcm/1-024.dcm           
  inflating: dcm/1-025.dcm           
  inflating: dcm/1-026.dcm           
  inflating: dcm/1-027.dcm           
  inflating: dcm/1-028.dcm           
  inflating: dcm/1-029.dcm           
  inflating: dcm/1-030.dcm           
  inflating: dcm/1-031.dcm           
  inflating: dcm/1-032.dcm           
  inflating: dcm/1-033.dcm           
  inflating: dcm/1-034.dcm           
  inflating: dcm/1-035.dcm           
  inflating: dcm/1-036.dcm           
  inflating: dcm/1-037.dcm           
  inflating: dcm/1-038.dcm           
  inflating: dcm/1-039.dcm           
  inflating: dcm/1-040.dcm           
  inflating: dcm/1-041.dcm           
  inflating: dcm/1-042.dcm           
  inflating: dcm/1-043.dcm           
  inflating: dcm/1-044.dcm           
  inflating: dcm/1-045.dcm           
  inflating: dcm/1-046.dcm           
  inflating: dcm/1-047.dcm           
  inflating: dcm/1-048.dcm           
  inflating: dcm/1-049.dcm           
  inflating: dcm/1-050.dcm           
  inflating: dcm/1-051.dcm           
  inflating: dcm/1-052.dcm           
  inflating: dcm/1-053.dcm           
  inflating: dcm/1-054.dcm           
  inflating: dcm/1-055.dcm           
  inflating: dcm/1-056.dcm           
  inflating: dcm/1-057.dcm           
  inflating: dcm/1-058.dcm           
  inflating: dcm/1-059.dcm           
  inflating: dcm/1-060.dcm           
  inflating: dcm/1-061.dcm           
  inflating: dcm/1-062.dcm           
  inflating: dcm/1-063.dcm           
  inflating: dcm/1-064.dcm           
  inflating: dcm/1-065.dcm           
  inflating: dcm/1-066.dcm           
  inflating: dcm/1-067.dcm           
  inflating: dcm/1-068.dcm           
  inflating: dcm/1-069.dcm           
  inflating: dcm/1-070.dcm           
  inflating: dcm/1-071.dcm           
  inflating: dcm/1-072.dcm           
  inflating: dcm/1-073.dcm           
  inflating: dcm/1-074.dcm           
  inflating: dcm/1-075.dcm           
  inflating: dcm/1-076.dcm           
  inflating: dcm/1-077.dcm           
  inflating: dcm/1-078.dcm           
  inflating: dcm/1-079.dcm           
  inflating: dcm/1-080.dcm           
  inflating: dcm/1-081.dcm           
  inflating: dcm/1-082.dcm           
  inflating: dcm/1-083.dcm           
  inflating: dcm/1-084.dcm           
  inflating: dcm/1-085.dcm           
  inflating: dcm/1-086.dcm           
  inflating: dcm/1-087.dcm           
  inflating: dcm/1-088.dcm           
  inflating: dcm/1-089.dcm           
  inflating: dcm/1-090.dcm           
  inflating: dcm/1-091.dcm           
  inflating: dcm/1-092.dcm           
  inflating: dcm/1-093.dcm           
  inflating: dcm/1-094.dcm           
  inflating: dcm/1-095.dcm           
  inflating: dcm/1-096.dcm           
  inflating: dcm/1-097.dcm           
  inflating: dcm/1-098.dcm           
  inflating: dcm/1-099.dcm           
  inflating: dcm/1-100.dcm           
  inflating: dcm/1-101.dcm           
  inflating: dcm/1-102.dcm           
  inflating: dcm/1-103.dcm           
  inflating: dcm/1-104.dcm           
  inflating: dcm/1-105.dcm           
  inflating: dcm/1-106.dcm           
  inflating: dcm/1-107.dcm           
  inflating: dcm/1-108.dcm           
  inflating: dcm/1-109.dcm           
  inflating: dcm/1-110.dcm           
  inflating: dcm/1-111.dcm           
  inflating: dcm/1-112.dcm           
  inflating: dcm/1-113.dcm           
  inflating: dcm/1-114.dcm           
  inflating: dcm/1-115.dcm           
  inflating: dcm/1-116.dcm           
  inflating: dcm/1-117.dcm           
  inflating: dcm/1-118.dcm           
  inflating: dcm/1-119.dcm           
  inflating: dcm/1-120.dcm           
  inflating: dcm/1-121.dcm           
  inflating: dcm/1-122.dcm           
  inflating: dcm/1-123.dcm           
  inflating: dcm/1-124.dcm           
  inflating: dcm/1-125.dcm           
  inflating: dcm/1-126.dcm           
  inflating: dcm/1-127.dcm           
  inflating: dcm/1-128.dcm           
  inflating: dcm/1-129.dcm           
  inflating: dcm/1-130.dcm           
  inflating: dcm/1-131.dcm           
  inflating: dcm/1-132.dcm           
  inflating: dcm/1-133.dcm           
  inflating: dcm/1-134.dcm           
  inflating: dcm/1-135.dcm           
  inflating: dcm/1-136.dcm           
  inflating: dcm/1-137.dcm           
  inflating: dcm/1-138.dcm           
  inflating: dcm/1-139.dcm           
  inflating: dcm/1-140.dcm           
  inflating: dcm/1-141.dcm           
  inflating: dcm/1-142.dcm           
  inflating: dcm/1-143.dcm           
  inflating: dcm/1-144.dcm           
  inflating: dcm/1-145.dcm           
  inflating: dcm/1-146.dcm           
  inflating: dcm/1-147.dcm           
  inflating: dcm/1-148.dcm           
  inflating: dcm/1-149.dcm           
  inflating: dcm/1-150.dcm           
  inflating: dcm/1-151.dcm           
  inflating: dcm/1-152.dcm           
  inflating: dcm/1-153.dcm           
  inflating: dcm/1-154.dcm           
  inflating: dcm/1-155.dcm           
  inflating: dcm/1-156.dcm           
  inflating: dcm/1-157.dcm           
  inflating: dcm/1-158.dcm           
  inflating: dcm/1-159.dcm           
  inflating: dcm/1-160.dcm           
  inflating: dcm/1-161.dcm           
  inflating: dcm/1-162.dcm           
  inflating: dcm/1-163.dcm           
  inflating: dcm/1-164.dcm           
  inflating: dcm/1-165.dcm           
  inflating: dcm/1-166.dcm           
  inflating: dcm/1-167.dcm           
  inflating: dcm/1-168.dcm           
  inflating: dcm/1-169.dcm           
  inflating: dcm/1-170.dcm           
  inflating: dcm/1-171.dcm           
  inflating: dcm/1-172.dcm           
  inflating: dcm/1-173.dcm           
  inflating: dcm/1-174.dcm           
  inflating: dcm/1-175.dcm           
  inflating: dcm/1-176.dcm           
  inflating: dcm/1-177.dcm           
  inflating: dcm/1-178.dcm           
  inflating: dcm/1-179.dcm           
  inflating: dcm/1-180.dcm           
  inflating: dcm/1-181.dcm           
  inflating: dcm/1-182.dcm           
  inflating: dcm/1-183.dcm           
  inflating: dcm/1-184.dcm           
  inflating: dcm/1-185.dcm           
  inflating: dcm/1-186.dcm           
  inflating: dcm/1-187.dcm           
  inflating: dcm/1-188.dcm           
  inflating: dcm/1-189.dcm           
  inflating: dcm/1-190.dcm           
  inflating: dcm/1-191.dcm           
  inflating: dcm/1-192.dcm           
  inflating: dcm/1-193.dcm           
  inflating: dcm/1-194.dcm           
  inflating: dcm/1-195.dcm           
  inflating: dcm/1-196.dcm           
  inflating: dcm/1-197.dcm           
  inflating: dcm/1-198.dcm           
  inflating: dcm/1-199.dcm           
  inflating: dcm/1-200.dcm           
  inflating: dcm/1-201.dcm           
  inflating: dcm/1-202.dcm           
  inflating: dcm/1-203.dcm           
  inflating: dcm/1-204.dcm           
  inflating: model.ts                
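
Before building the application, one may want to confirm that the archive was extracted as expected. The sketch below assumes the layout shown in the unzip listing above (a dcm folder of .dcm files and a model.ts file in the working directory); check_test_data is an illustrative helper name.

```python
from pathlib import Path


def check_test_data(root="."):
    """Return (number of .dcm files under root/dcm, whether root/model.ts exists)."""
    root = Path(root)
    dcm_files = sorted((root / "dcm").glob("*.dcm"))
    return len(dcm_files), (root / "model.ts").is_file()


if __name__ == "__main__":
    count, has_model = check_test_data(".")
    print(f"DICOM files: {count}, model.ts present: {has_model}")
```

With the listing above, the count should equal the number of inflated dcm/*.dcm entries, and model.ts should be reported as present.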

Setup imports

Let’s import the classes and decorators needed to define the Application and the Operators.

import logging
from os import path

from numpy import uint8

import monai.deploy.core as md
from monai.deploy.core import ExecutionContext, Image, InputContext, IOType, Operator, OutputContext
from monai.deploy.operators.monai_seg_inference_operator import InMemImageReader, MonaiSegInferenceOperator
from monai.transforms import (
    Activationsd,
    AsDiscreted,
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    Invertd,
    LoadImaged,
    Orientationd,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
)

# Required for setting SegmentDescription attributes. Direct import as this is not part of App SDK package.
from pydicom.sr.codedict import codes

from monai.deploy.core import Application, resource
from monai.deploy.operators.dicom_data_loader_operator import DICOMDataLoaderOperator
from monai.deploy.operators.dicom_seg_writer_operator import DICOMSegmentationWriterOperator, SegmentDescription
from monai.deploy.operators.dicom_series_selector_operator import DICOMSeriesSelectorOperator
from monai.deploy.operators.dicom_series_to_volume_operator import DICOMSeriesToVolumeOperator
from monai.deploy.operators.clara_viz_operator import ClaraVizOperator

Creating Model Specific Inference Operator classes

Each operator class inherits from the Operator base class, and its input/output properties are specified with the @input/@output decorators (used below as @md.input and @md.output).

The business logic is implemented in the compute() method.

The App SDK provides a MonaiSegInferenceOperator class to perform segmentation prediction with a TorchScript model. For consistency, this class uses MONAI dictionary-based transforms, as Compose objects, for pre- and post-transforms. The model-specific inference operator then only needs to create the pre- and post-transform Compose objects based on what was used in model training and validation. Note that for a deploy application, ignite is neither needed nor supported.

SpleenSegOperator

The SpleenSegOperator takes as input an in-memory Image object that has been converted from a DICOM CT series by the preceding DICOMSeriesToVolumeOperator, and produces as output an in-memory segmentation Image object.

The pre_process function creates the pre-transform Compose object. For LoadImaged, a specialized InMemImageReader, derived from the MONAI ImageReader, is used to convert the in-memory pixel data and return the numpy array as well as the metadata. Also, the DICOM input pixel spacings are often not the same as those expected by the model, so the Spacingd transform must be used to re-sample the image to the expected spacing.
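
To make the effect of two of these pre-transforms concrete, here is a pure-Python sketch of the arithmetic involved. It is not MONAI code: scale_intensity_range mirrors the linear mapping and clipping configured for ScaleIntensityRanged in the operator below, and estimate_resampled_shape is only a rough estimate of the grid size after Spacingd (MONAI derives the output shape from the affine, so the result may differ by a voxel). Both helper names are illustrative.

```python
def scale_intensity_range(value, a_min=-57.0, a_max=164.0, b_min=0.0, b_max=1.0, clip=True):
    """Linearly map [a_min, a_max] to [b_min, b_max], optionally clipping the result."""
    scaled = (value - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    if clip:
        scaled = min(max(scaled, b_min), b_max)
    return scaled


def estimate_resampled_shape(shape, spacing, target_spacing):
    """Roughly estimate the voxel grid size after re-sampling to target_spacing.

    The physical extent along each axis (voxels * spacing) is kept and divided
    by the new spacing. This is an approximation for illustration only.
    """
    return tuple(
        max(1, round(n * s / t)) for n, s, t in zip(shape, spacing, target_spacing)
    )


if __name__ == "__main__":
    # Hounsfield-like values below, inside, and above the configured window.
    for hu in (-1000.0, -57.0, 53.5, 164.0, 1000.0):
        print(hu, scale_intensity_range(hu))

    # Example geometry: 512 x 512 x 204 voxels at 0.7890625 x 0.7890625 x 1.5 mm,
    # re-sampled to the 1.5 x 1.5 x 2.9 mm spacing used in pre_process.
    print(estimate_resampled_shape((512, 512, 204), (0.7890625, 0.7890625, 1.5), (1.5, 1.5, 2.9)))
```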

The post_process function creates the post-transform Compose object. The SaveImaged transform is used to save the segmentation mask as a NIfTI image file; this step is optional, as the in-memory mask image is passed down to the DICOM Segmentation writer to create a DICOM Segmentation instance. The Invertd transform must also be used to revert the segmentation image’s orientation and spacing to match the input.
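
The Activationsd(softmax=True) and AsDiscreted(argmax=True) steps used in post_process can be pictured per voxel as follows. This is a pure-Python illustration rather than MONAI code, and it assumes a two-channel model output (background and spleen), which is an assumption made for the example.

```python
import math


def softmax(logits):
    """Convert per-channel logits for one voxel into probabilities."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Return the index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)


def voxel_label(logits):
    """Softmax followed by argmax yields the discrete label for one voxel."""
    return argmax(softmax(logits))


if __name__ == "__main__":
    # Assumed channel order for illustration: index 0 = background, index 1 = spleen.
    print(voxel_label([2.0, -1.0]))
    print(voxel_label([-0.5, 3.0]))
```

In the operator, Invertd is applied between these two steps, so the probabilities are mapped back to the input geometry before they are turned into discrete labels.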

When the MonaiSegInferenceOperator object is created, the ROI size is specified, as well as the transform Compose objects. Furthermore, the dataset image key names are set accordingly.

Loading of the model and performing the prediction are encapsulated in the MonaiSegInferenceOperator and other SDK classes. Once the inference is completed, the segmentation Image object is created and set to the output (op_output.set(value, label)), by the MonaiSegInferenceOperator.

@md.input("image", Image, IOType.IN_MEMORY)
@md.output("seg_image", Image, IOType.IN_MEMORY)
@md.env(pip_packages=["monai>=0.8.1", "torch>=1.5", "numpy>=1.21", "nibabel"])
class SpleenSegOperator(Operator):
    """Performs Spleen segmentation with a 3D image converted from a DICOM CT series.
    """

    def __init__(self):

        self.logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
        super().__init__()
        self._input_dataset_key = "image"
        self._pred_dataset_key = "pred"

    def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext):

        input_image = op_input.get("image")
        if not input_image:
            raise ValueError("Input image is not found.")

        output_path = context.output.get().path

        # This operator gets an in-memory Image object, so a specialized ImageReader is needed.
        _reader = InMemImageReader(input_image)
        pre_transforms = self.pre_process(_reader)
        post_transforms = self.post_process(pre_transforms, path.join(output_path, "prediction_output"))

        # Delegates inference and saving output to the built-in operator.
        infer_operator = MonaiSegInferenceOperator(
            (
                96,
                96,
                96,
            ),
            pre_transforms,
            post_transforms,
        )

        # The keys used in the dictionary-based transforms may change, so set them explicitly.
        infer_operator.input_dataset_key = self._input_dataset_key
        infer_operator.pred_dataset_key = self._pred_dataset_key

        # Now let the built-in operator handle the work with the I/O spec and execution context.
        infer_operator.compute(op_input, op_output, context)

    def pre_process(self, img_reader) -> Compose:
        """Composes transforms for preprocessing input before predicting on a model."""

        my_key = self._input_dataset_key
        return Compose(
            [
                LoadImaged(keys=my_key, reader=img_reader),
                EnsureChannelFirstd(keys=my_key),
                Orientationd(keys=my_key, axcodes="RAS"),
                Spacingd(keys=my_key, pixdim=[1.5, 1.5, 2.9], mode=["bilinear"]),
                ScaleIntensityRanged(keys=my_key, a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
                EnsureTyped(keys=my_key),
            ]
        )

    def post_process(self, pre_transforms: Compose, out_dir: str = "./prediction_output") -> Compose:
        """Composes transforms for postprocessing the prediction results."""

        pred_key = self._pred_dataset_key
        return Compose(
            [
                Activationsd(keys=pred_key, softmax=True),
                Invertd(
                    keys=pred_key,
                    transform=pre_transforms,
                    orig_keys=self._input_dataset_key,
                    nearest_interp=False,
                    to_tensor=True,
                ),
                AsDiscreted(keys=pred_key, argmax=True),
                SaveImaged(
                    keys=pred_key,
                    output_dir=out_dir,
                    output_postfix="seg",
                    output_dtype=uint8,
                ),
            ]
        )

Creating Application class

Our application class looks like the one below.

It defines the AISpleenSegApp class, which inherits from the Application class.

The requirements (resources and package dependencies) for the App can be specified with the @resource and @env decorators.

The base class method, compose, is overridden. The objects required for DICOM parsing, series selection (selecting the first series for the current release), pixel data conversion to a volume image, and segmentation instance creation are created, as are the model-specific SpleenSegOperator and the ClaraVizOperator. The execution pipeline, a Directed Acyclic Graph, is created by connecting these objects through self.add_flow().

@resource(cpu=1, gpu=1, memory="7Gi")
class AISpleenSegApp(Application):
    def __init__(self, *args, **kwargs):
        """Creates an application instance."""

        self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
        super().__init__(*args, **kwargs)

    def run(self, *args, **kwargs):
        # This method calls the base class to run. Can be omitted if simply calling through.
        self._logger.debug(f"Begin {self.run.__name__}")
        super().run(*args, **kwargs)
        self._logger.debug(f"End {self.run.__name__}")

    def compose(self):
        """Creates the app specific operators and chain them up in the processing DAG."""

        self._logger.debug(f"Begin {self.compose.__name__}")
        # Creates the custom operator(s) as well as SDK built-in operator(s).
        study_loader_op = DICOMDataLoaderOperator()
        series_selector_op = DICOMSeriesSelectorOperator(rules=Sample_Rules_Text)
        series_to_vol_op = DICOMSeriesToVolumeOperator()
        # Creates the model-specific segmentation operator, which supports MONAI transforms.
        spleen_seg_op = SpleenSegOperator()

        # Create DICOM Seg writer providing the required segment description for each segment with
        # the actual algorithm and the pertinent organ/tissue.
        # The segment_label, algorithm_name, and algorithm_version are limited to 64 chars.
        # https://dicom.nema.org/medical/dicom/current/output/chtml/part05/sect_6.2.html
        # Users can look up SNOMED CT codes at, e.g.
        # https://bioportal.bioontology.org/ontologies/SNOMEDCT

        _algorithm_name = "3D segmentation of the Spleen from a CT series"
        _algorithm_family = codes.DCM.ArtificialIntelligence
        _algorithm_version = "0.1.0"

        segment_descriptions = [
            SegmentDescription(
                segment_label="Spleen",
                segmented_property_category=codes.SCT.Organ,
                segmented_property_type=codes.SCT.Spleen,
                algorithm_name=_algorithm_name,
                algorithm_family=_algorithm_family,
                algorithm_version=_algorithm_version,
            ),
        ]

        dicom_seg_writer = DICOMSegmentationWriterOperator(segment_descriptions)

        # Create the processing pipeline, by specifying the source and destination operators, and
        # ensuring the output from the former matches the input of the latter, in both name and type.
        self.add_flow(study_loader_op, series_selector_op, {"dicom_study_list": "dicom_study_list"})
        self.add_flow(
            series_selector_op, series_to_vol_op, {"study_selected_series_list": "study_selected_series_list"}
        )
        self.add_flow(series_to_vol_op, spleen_seg_op, {"image": "image"})

        # Note below the dicom_seg_writer requires two inputs, each coming from a source operator.
        self.add_flow(
            series_selector_op, dicom_seg_writer, {"study_selected_series_list": "study_selected_series_list"}
        )
        self.add_flow(spleen_seg_op, dicom_seg_writer, {"seg_image": "seg_image"})

        # The Clara Viz operator also takes two inputs: the volume image and the segmentation image.
        viz_op = ClaraVizOperator()
        self.add_flow(series_to_vol_op, viz_op, {"image": "image"})
        self.add_flow(spleen_seg_op, viz_op, {"seg_image": "seg_image"})

        self._logger.debug(f"End {self.compose.__name__}")

# This is a sample series selection rule in JSON, simply selecting CT series.
# If the study has more than one CT series, then all of them will be selected.
# Please see DICOMSeriesSelectorOperator for more detail.
# For a list of string values, e.g. "ImageType": ["PRIMARY", "ORIGINAL"], it is a match if all
# elements are present in the multi-value attribute of the DICOM series.
Sample_Rules_Text = """
{
    "selections": [
        {
            "name": "CT Series",
            "conditions": {
                "StudyDescription": "(.*?)",
                "Modality": "(?i)CT",
                "SeriesDescription": "(.*?)",
                "ImageType": ["PRIMARY", "ORIGINAL"]
            }
        }
    ]
}
"""

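The matching semantics described in the comments above can be sketched in plain Python. This is a simplified illustration, not the DICOMSeriesSelectorOperator implementation, which may apply additional fallbacks (for example, when an attribute is missing at the series level). The helper names matches and select, and the sample attribute dictionary, are hypothetical.

```python
import json
import re

RULES_TEXT = """
{
    "selections": [
        {
            "name": "CT Series",
            "conditions": {
                "StudyDescription": "(.*?)",
                "Modality": "(?i)CT",
                "SeriesDescription": "(.*?)",
                "ImageType": ["PRIMARY", "ORIGINAL"]
            }
        }
    ]
}
"""


def matches(conditions, series_attrs):
    """Return True if every condition is met by the series attributes."""
    for key, expected in conditions.items():
        actual = series_attrs.get(key)
        if actual is None:
            return False  # simplification: a missing attribute never matches
        if isinstance(expected, list):
            # List condition: every expected element must be in the multi-value attribute.
            actual_values = actual if isinstance(actual, (list, tuple)) else [actual]
            if not all(item in actual_values for item in expected):
                return False
        else:
            # String condition: case-insensitive equality first, then a regular expression.
            text = str(actual)
            if text.casefold() != expected.casefold() and not re.search(expected, text, re.IGNORECASE):
                return False
    return True


def select(rules_text, series_attrs):
    """Return the names of all selections whose conditions the series meets."""
    rules = json.loads(rules_text)
    return [s["name"] for s in rules["selections"] if matches(s["conditions"], series_attrs)]


if __name__ == "__main__":
    series = {
        "StudyDescription": "CT ABDOMEN W IV CONTRAST",
        "Modality": "CT",
        "SeriesDescription": "ABD/PANC 3.0 B31f",
        "ImageType": ["ORIGINAL", "PRIMARY", "AXIAL"],
    }
    print(select(RULES_TEXT, series))
```
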
Executing app locally

We can execute the app in the Jupyter notebook. Note that the DICOM files of the CT Abdomen series must be present in the dcm folder and the TorchScript model at model.ts. Please use the actual paths in your environment.

app = AISpleenSegApp()

app.run(input="dcm", output="output", model="model.ts")
Going to initiate execution of operator DICOMDataLoaderOperator
Executing operator DICOMDataLoaderOperator (Process ID: 1084421, Operator ID: ad334e7d-979b-484b-b73d-70bb012cfe05)
[2022-10-18 21:33:47,082] [INFO] (root) - Finding series for Selection named: CT Series
[2022-10-18 21:33:47,084] [INFO] (root) - Searching study, : 1.3.6.1.4.1.14519.5.2.1.7085.2626.822645453932810382886582736291
  # of series: 1
[2022-10-18 21:33:47,085] [INFO] (root) - Working on series, instance UID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239
[2022-10-18 21:33:47,087] [INFO] (root) - On attribute: 'StudyDescription' to match value: '(.*?)'
[2022-10-18 21:33:47,087] [INFO] (root) -     Series attribute StudyDescription value: CT ABDOMEN W IV CONTRAST
[2022-10-18 21:33:47,088] [INFO] (root) - Series attribute string value did not match. Try regEx.
[2022-10-18 21:33:47,088] [INFO] (root) - On attribute: 'Modality' to match value: '(?i)CT'
[2022-10-18 21:33:47,089] [INFO] (root) -     Series attribute Modality value: CT
[2022-10-18 21:33:47,089] [INFO] (root) - Series attribute string value did not match. Try regEx.
[2022-10-18 21:33:47,090] [INFO] (root) - On attribute: 'SeriesDescription' to match value: '(.*?)'
[2022-10-18 21:33:47,091] [INFO] (root) -     Series attribute SeriesDescription value: ABD/PANC 3.0 B31f
[2022-10-18 21:33:47,091] [INFO] (root) - Series attribute string value did not match. Try regEx.
[2022-10-18 21:33:47,091] [INFO] (root) - On attribute: 'ImageType' to match value: '['PRIMARY', 'ORIGINAL']'
[2022-10-18 21:33:47,092] [INFO] (root) -     Series attribute ImageType value: None
[2022-10-18 21:33:47,092] [INFO] (root) - Selected Series, UID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239
Done performing execution of operator DICOMDataLoaderOperator

Going to initiate execution of operator DICOMSeriesSelectorOperator
Executing operator DICOMSeriesSelectorOperator (Process ID: 1084421, Operator ID: 52bc0253-84e7-4d83-bf67-aea10d6df3ae)
Done performing execution of operator DICOMSeriesSelectorOperator

Going to initiate execution of operator DICOMSeriesToVolumeOperator
Executing operator DICOMSeriesToVolumeOperator (Process ID: 1084421, Operator ID: 47a7415b-28f3-4a81-abf8-43fa2bd55389)
Done performing execution of operator DICOMSeriesToVolumeOperator

Going to initiate execution of operator SpleenSegOperator
Executing operator SpleenSegOperator (Process ID: 1084421, Operator ID: 2963c156-eb97-43ec-b48d-ddbc67984cc1)
Converted Image object metadata:
SeriesInstanceUID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239, type <class 'str'>
SeriesDate: 20090831, type <class 'str'>
SeriesTime: 101721.452, type <class 'str'>
Modality: CT, type <class 'str'>
SeriesDescription: ABD/PANC 3.0 B31f, type <class 'str'>
PatientPosition: HFS, type <class 'str'>
SeriesNumber: 8, type <class 'int'>
row_pixel_spacing: 0.7890625, type <class 'float'>
col_pixel_spacing: 0.7890625, type <class 'float'>
depth_pixel_spacing: 1.5, type <class 'float'>
row_direction_cosine: [1.0, 0.0, 0.0], type <class 'list'>
col_direction_cosine: [0.0, 1.0, 0.0], type <class 'list'>
depth_direction_cosine: [0.0, 0.0, 1.0], type <class 'list'>
dicom_affine_transform: [[   0.7890625    0.           0.        -197.60547  ]
 [   0.           0.7890625    0.        -398.60547  ]
 [   0.           0.           1.5       -383.       ]
 [   0.           0.           0.           1.       ]], type <class 'numpy.ndarray'>
nifti_affine_transform: [[  -0.7890625   -0.          -0.         197.60547  ]
 [  -0.          -0.7890625   -0.         398.60547  ]
 [   0.           0.           1.5       -383.       ]
 [   0.           0.           0.           1.       ]], type <class 'numpy.ndarray'>
StudyInstanceUID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.822645453932810382886582736291, type <class 'str'>
StudyID: , type <class 'str'>
StudyDate: 20090831, type <class 'str'>
StudyTime: 095948.599, type <class 'str'>
StudyDescription: CT ABDOMEN W IV CONTRAST, type <class 'str'>
AccessionNumber: 5471978513296937, type <class 'str'>
selection_name: CT Series, type <class 'str'>
2022-10-18 21:34:00,641 INFO image_writer.py:194 - writing: /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/output/prediction_output/1.3.6.1.4.1.14519.5.2.1.7085.2626/1.3.6.1.4.1.14519.5.2.1.7085.2626_seg.nii.gz
Output Seg image numpy array shaped: (204, 512, 512)
Output Seg image pixel max value: 1
Done performing execution of operator SpleenSegOperator

Going to initiate execution of operator DICOMSegmentationWriterOperator
Executing operator DICOMSegmentationWriterOperator (Process ID: 1084421, Operator ID: 8340781d-9a63-46f1-ae80-1405bb6a644d)
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages/highdicom/valuerep.py:54: UserWarning: The string "C3N-00198" is unlikely to represent the intended person name since it contains only a single component. Construct a person name according to the format in described in http://dicom.nema.org/dicom/2013/output/chtml/part05/sect_6.2.html#sect_6.2.1.2, or, in pydicom 2.2.0 or later, use the pydicom.valuerep.PersonName.from_named_components() method to construct the person name correctly. If a single-component name is really intended, add a trailing caret character to disambiguate the name.
  warnings.warn(
[2022-10-18 21:34:04,374] [INFO] (highdicom.seg.sop) - add plane #0 for segment #1
[2022-10-18 21:34:04,377] [INFO] (highdicom.seg.sop) - add plane #1 for segment #1
[2022-10-18 21:34:04,378] [INFO] (highdicom.seg.sop) - add plane #2 for segment #1
[2022-10-18 21:34:04,381] [INFO] (highdicom.seg.sop) - add plane #3 for segment #1
[2022-10-18 21:34:04,382] [INFO] (highdicom.seg.sop) - add plane #4 for segment #1
[2022-10-18 21:34:04,384] [INFO] (highdicom.seg.sop) - add plane #5 for segment #1
[2022-10-18 21:34:04,386] [INFO] (highdicom.seg.sop) - add plane #6 for segment #1
[2022-10-18 21:34:04,388] [INFO] (highdicom.seg.sop) - add plane #7 for segment #1
[2022-10-18 21:34:04,389] [INFO] (highdicom.seg.sop) - add plane #8 for segment #1
[2022-10-18 21:34:04,391] [INFO] (highdicom.seg.sop) - add plane #9 for segment #1
[2022-10-18 21:34:04,393] [INFO] (highdicom.seg.sop) - add plane #10 for segment #1
[2022-10-18 21:34:04,394] [INFO] (highdicom.seg.sop) - add plane #11 for segment #1
[2022-10-18 21:34:04,396] [INFO] (highdicom.seg.sop) - add plane #12 for segment #1
[2022-10-18 21:34:04,398] [INFO] (highdicom.seg.sop) - add plane #13 for segment #1
[2022-10-18 21:34:04,400] [INFO] (highdicom.seg.sop) - add plane #14 for segment #1
[2022-10-18 21:34:04,402] [INFO] (highdicom.seg.sop) - add plane #15 for segment #1
[2022-10-18 21:34:04,404] [INFO] (highdicom.seg.sop) - add plane #16 for segment #1
[2022-10-18 21:34:04,405] [INFO] (highdicom.seg.sop) - add plane #17 for segment #1
[2022-10-18 21:34:04,407] [INFO] (highdicom.seg.sop) - add plane #18 for segment #1
[2022-10-18 21:34:04,409] [INFO] (highdicom.seg.sop) - add plane #19 for segment #1
[2022-10-18 21:34:04,411] [INFO] (highdicom.seg.sop) - add plane #20 for segment #1
[2022-10-18 21:34:04,413] [INFO] (highdicom.seg.sop) - add plane #21 for segment #1
[2022-10-18 21:34:04,415] [INFO] (highdicom.seg.sop) - add plane #22 for segment #1
[2022-10-18 21:34:04,418] [INFO] (highdicom.seg.sop) - add plane #23 for segment #1
[2022-10-18 21:34:04,421] [INFO] (highdicom.seg.sop) - add plane #24 for segment #1
[2022-10-18 21:34:04,424] [INFO] (highdicom.seg.sop) - add plane #25 for segment #1
[2022-10-18 21:34:04,426] [INFO] (highdicom.seg.sop) - add plane #26 for segment #1
[2022-10-18 21:34:04,428] [INFO] (highdicom.seg.sop) - add plane #27 for segment #1
[2022-10-18 21:34:04,430] [INFO] (highdicom.seg.sop) - add plane #28 for segment #1
[2022-10-18 21:34:04,432] [INFO] (highdicom.seg.sop) - add plane #29 for segment #1
[2022-10-18 21:34:04,435] [INFO] (highdicom.seg.sop) - add plane #30 for segment #1
[2022-10-18 21:34:04,438] [INFO] (highdicom.seg.sop) - add plane #31 for segment #1
[2022-10-18 21:34:04,441] [INFO] (highdicom.seg.sop) - add plane #32 for segment #1
[2022-10-18 21:34:04,444] [INFO] (highdicom.seg.sop) - add plane #33 for segment #1
[2022-10-18 21:34:04,446] [INFO] (highdicom.seg.sop) - add plane #34 for segment #1
[2022-10-18 21:34:04,449] [INFO] (highdicom.seg.sop) - add plane #35 for segment #1
[2022-10-18 21:34:04,452] [INFO] (highdicom.seg.sop) - add plane #36 for segment #1
[2022-10-18 21:34:04,454] [INFO] (highdicom.seg.sop) - add plane #37 for segment #1
[2022-10-18 21:34:04,457] [INFO] (highdicom.seg.sop) - add plane #38 for segment #1
[2022-10-18 21:34:04,459] [INFO] (highdicom.seg.sop) - add plane #39 for segment #1
[2022-10-18 21:34:04,460] [INFO] (highdicom.seg.sop) - add plane #40 for segment #1
[2022-10-18 21:34:04,462] [INFO] (highdicom.seg.sop) - add plane #41 for segment #1
[2022-10-18 21:34:04,464] [INFO] (highdicom.seg.sop) - add plane #42 for segment #1
[2022-10-18 21:34:04,465] [INFO] (highdicom.seg.sop) - add plane #43 for segment #1
[2022-10-18 21:34:04,467] [INFO] (highdicom.seg.sop) - add plane #44 for segment #1
[2022-10-18 21:34:04,469] [INFO] (highdicom.seg.sop) - add plane #45 for segment #1
[2022-10-18 21:34:04,471] [INFO] (highdicom.seg.sop) - add plane #46 for segment #1
[2022-10-18 21:34:04,473] [INFO] (highdicom.seg.sop) - add plane #47 for segment #1
[2022-10-18 21:34:04,475] [INFO] (highdicom.seg.sop) - add plane #48 for segment #1
[2022-10-18 21:34:04,477] [INFO] (highdicom.seg.sop) - add plane #49 for segment #1
[2022-10-18 21:34:04,479] [INFO] (highdicom.seg.sop) - add plane #50 for segment #1
[2022-10-18 21:34:04,480] [INFO] (highdicom.seg.sop) - add plane #51 for segment #1
[2022-10-18 21:34:04,482] [INFO] (highdicom.seg.sop) - add plane #52 for segment #1
[2022-10-18 21:34:04,484] [INFO] (highdicom.seg.sop) - add plane #53 for segment #1
[2022-10-18 21:34:04,485] [INFO] (highdicom.seg.sop) - add plane #54 for segment #1
[2022-10-18 21:34:04,487] [INFO] (highdicom.seg.sop) - add plane #55 for segment #1
[2022-10-18 21:34:04,489] [INFO] (highdicom.seg.sop) - add plane #56 for segment #1
[2022-10-18 21:34:04,491] [INFO] (highdicom.seg.sop) - add plane #57 for segment #1
[2022-10-18 21:34:04,494] [INFO] (highdicom.seg.sop) - add plane #58 for segment #1
[2022-10-18 21:34:04,496] [INFO] (highdicom.seg.sop) - add plane #59 for segment #1
[2022-10-18 21:34:04,498] [INFO] (highdicom.seg.sop) - add plane #60 for segment #1
[2022-10-18 21:34:04,500] [INFO] (highdicom.seg.sop) - add plane #61 for segment #1
[2022-10-18 21:34:04,502] [INFO] (highdicom.seg.sop) - add plane #62 for segment #1
[2022-10-18 21:34:04,504] [INFO] (highdicom.seg.sop) - add plane #63 for segment #1
[2022-10-18 21:34:04,505] [INFO] (highdicom.seg.sop) - add plane #64 for segment #1
[2022-10-18 21:34:04,507] [INFO] (highdicom.seg.sop) - add plane #65 for segment #1
[2022-10-18 21:34:04,509] [INFO] (highdicom.seg.sop) - add plane #66 for segment #1
[2022-10-18 21:34:04,510] [INFO] (highdicom.seg.sop) - add plane #67 for segment #1
[2022-10-18 21:34:04,513] [INFO] (highdicom.seg.sop) - add plane #68 for segment #1
[2022-10-18 21:34:04,516] [INFO] (highdicom.seg.sop) - add plane #69 for segment #1
[2022-10-18 21:34:04,519] [INFO] (highdicom.seg.sop) - add plane #70 for segment #1
[2022-10-18 21:34:04,528] [INFO] (highdicom.seg.sop) - add plane #71 for segment #1
[2022-10-18 21:34:04,532] [INFO] (highdicom.seg.sop) - add plane #72 for segment #1
[2022-10-18 21:34:04,535] [INFO] (highdicom.seg.sop) - add plane #73 for segment #1
[2022-10-18 21:34:04,538] [INFO] (highdicom.seg.sop) - add plane #74 for segment #1
[2022-10-18 21:34:04,542] [INFO] (highdicom.seg.sop) - add plane #75 for segment #1
[2022-10-18 21:34:04,546] [INFO] (highdicom.seg.sop) - add plane #76 for segment #1
[2022-10-18 21:34:04,549] [INFO] (highdicom.seg.sop) - add plane #77 for segment #1
[2022-10-18 21:34:04,552] [INFO] (highdicom.seg.sop) - add plane #78 for segment #1
[2022-10-18 21:34:04,555] [INFO] (highdicom.seg.sop) - add plane #79 for segment #1
[2022-10-18 21:34:04,558] [INFO] (highdicom.seg.sop) - add plane #80 for segment #1
[2022-10-18 21:34:04,561] [INFO] (highdicom.seg.sop) - add plane #81 for segment #1
[2022-10-18 21:34:04,563] [INFO] (highdicom.seg.sop) - add plane #82 for segment #1
[2022-10-18 21:34:04,566] [INFO] (highdicom.seg.sop) - add plane #83 for segment #1
[2022-10-18 21:34:04,568] [INFO] (highdicom.seg.sop) - add plane #84 for segment #1
[2022-10-18 21:34:04,570] [INFO] (highdicom.seg.sop) - add plane #85 for segment #1
[2022-10-18 21:34:04,573] [INFO] (highdicom.seg.sop) - add plane #86 for segment #1
[2022-10-18 21:34:04,575] [INFO] (highdicom.seg.sop) - add plane #87 for segment #1
[2022-10-18 21:34:04,625] [INFO] (highdicom.base) - copy Image-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"
[2022-10-18 21:34:04,626] [INFO] (highdicom.base) - copy attributes of module "Specimen"
[2022-10-18 21:34:04,626] [INFO] (highdicom.base) - copy Patient-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"
[2022-10-18 21:34:04,626] [INFO] (highdicom.base) - copy attributes of module "Patient"
[2022-10-18 21:34:04,627] [INFO] (highdicom.base) - copy attributes of module "Clinical Trial Subject"
[2022-10-18 21:34:04,628] [INFO] (highdicom.base) - copy Study-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"
[2022-10-18 21:34:04,628] [INFO] (highdicom.base) - copy attributes of module "General Study"
[2022-10-18 21:34:04,629] [INFO] (highdicom.base) - copy attributes of module "Patient Study"
[2022-10-18 21:34:04,629] [INFO] (highdicom.base) - copy attributes of module "Clinical Trial Study"
Done performing execution of operator DICOMSegmentationWriterOperator

Going to initiate execution of operator ClaraVizOperator
Executing operator ClaraVizOperator (Process ID: 1084421, Operator ID: 22570e7b-4760-478a-b876-78cc803f90c3)
Done performing execution of operator ClaraVizOperator
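
The metadata printed by SpleenSegOperator in the log above contains two affine matrices, dicom_affine_transform and nifti_affine_transform, which differ only in the sign of the first two rows. That is consistent with the usual convention that DICOM patient coordinates are LPS (left, posterior, superior) while NIfTI uses RAS. The sketch below checks the relationship using the values copied from the log. It is written in plain Python so it runs without extra packages, and the helper names (matmul, lps_to_ras) are illustrative, not part of the App SDK. Whether the SDK computes the NIfTI affine in exactly this way is an assumption.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


# Values copied from the operator log above.
dicom_affine = [
    [0.7890625, 0.0, 0.0, -197.60547],
    [0.0, 0.7890625, 0.0, -398.60547],
    [0.0, 0.0, 1.5, -383.0],
    [0.0, 0.0, 0.0, 1.0],
]
nifti_affine_from_log = [
    [-0.7890625, -0.0, -0.0, 197.60547],
    [-0.0, -0.7890625, -0.0, 398.60547],
    [0.0, 0.0, 1.5, -383.0],
    [0.0, 0.0, 0.0, 1.0],
]

# Flipping the x and y axes converts LPS coordinates to RAS coordinates.
lps_to_ras = [
    [-1.0, 0.0, 0.0, 0.0],
    [0.0, -1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
]

computed = matmul(lps_to_ras, dicom_affine)
matches = all(
    abs(c - e) < 1e-9
    for row_c, row_e in zip(computed, nifti_affine_from_log)
    for c, e in zip(row_c, row_e)
)
print(matches)
```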

Once the application has been verified in the Jupyter notebook, we can write the Python code above into Python files in an application folder.

The application folder structure looks like this:

my_app
├── __main__.py
├── app.py
└── spleen_seg_operator.py

Note

Instead of creating multiple files, we can create a single application Python file (such as spleen_app.py) that includes the content of all of them. You will see such an example in the MedNist Classifier Tutorial.

# Create an application folder
!mkdir -p my_app

spleen_seg_operator.py

%%writefile my_app/spleen_seg_operator.py
import logging
from os import path

from numpy import uint8

import monai.deploy.core as md
from monai.deploy.core import ExecutionContext, Image, InputContext, IOType, Operator, OutputContext
from monai.deploy.operators.monai_seg_inference_operator import InMemImageReader, MonaiSegInferenceOperator
from monai.transforms import (
    Activationsd,
    AsDiscreted,
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    Invertd,
    LoadImaged,
    Orientationd,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
)


@md.input("image", Image, IOType.IN_MEMORY)
@md.output("seg_image", Image, IOType.IN_MEMORY)
@md.env(pip_packages=["monai>=0.8.1", "torch>=1.10.2", "numpy>=1.21", "nibabel"])
class SpleenSegOperator(Operator):
    """Performs Spleen segmentation with a 3D image converted from a DICOM CT series.
    """

    def __init__(self):

        self.logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
        super().__init__()
        self._input_dataset_key = "image"
        self._pred_dataset_key = "pred"

    def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext):

        input_image = op_input.get("image")
        if not input_image:
            raise ValueError("Input image is not found.")

        output_path = context.output.get().path

        # This operator gets an in-memory Image object, so a specialized ImageReader is needed.
        _reader = InMemImageReader(input_image)
        pre_transforms = self.pre_process(_reader)
        post_transforms = self.post_process(pre_transforms, path.join(output_path, "prediction_output"))

        # Delegates inference and saving output to the built-in operator.
        infer_operator = MonaiSegInferenceOperator(
            (
                96,
                96,
                96,
            ),
            pre_transforms,
            post_transforms,
        )

        # Set the keys used in the dictionary-based transforms, as they may change.
        infer_operator.input_dataset_key = self._input_dataset_key
        infer_operator.pred_dataset_key = self._pred_dataset_key

        # Now let the built-in operator handle the work with the I/O spec and execution context.
        infer_operator.compute(op_input, op_output, context)

    def pre_process(self, img_reader) -> Compose:
        """Composes transforms for preprocessing input before predicting on a model."""

        my_key = self._input_dataset_key
        return Compose(
            [
                LoadImaged(keys=my_key, reader=img_reader),
                EnsureChannelFirstd(keys=my_key),
                Orientationd(keys=my_key, axcodes="RAS"),
                Spacingd(keys=my_key, pixdim=[1.5, 1.5, 2.9], mode=["bilinear"]),
                ScaleIntensityRanged(keys=my_key, a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
                EnsureTyped(keys=my_key),
            ]
        )

    def post_process(self, pre_transforms: Compose, out_dir: str = "./prediction_output") -> Compose:
        """Composes transforms for postprocessing the prediction results."""

        pred_key = self._pred_dataset_key
        return Compose(
            [
                Activationsd(keys=pred_key, softmax=True),
                Invertd(
                    keys=pred_key,
                    transform=pre_transforms,
                    orig_keys=self._input_dataset_key,
                    nearest_interp=False,
                    to_tensor=True,
                ),
                AsDiscreted(keys=pred_key, argmax=True),
                SaveImaged(
                    keys=pred_key,
                    output_dir=out_dir,
                    output_postfix="seg",
                    output_dtype=uint8,
                ),
            ]
        )
Overwriting my_app/spleen_seg_operator.py
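
As a small illustration of one pre-processing step above, ScaleIntensityRanged with a_min=-57, a_max=164, b_min=0.0, b_max=1.0, and clip=True linearly maps the input intensity window to [0, 1] and clips everything outside it. The following is a minimal plain-Python sketch of that mapping, not MONAI's implementation; the function name scale_intensity_range and the sample values are made up for illustration, and the inputs are assumed to be CT intensities in Hounsfield units.

```python
def scale_intensity_range(values, a_min, a_max, b_min=0.0, b_max=1.0, clip=True):
    """Linearly map [a_min, a_max] to [b_min, b_max], optionally clipping the result."""
    scaled = []
    for v in values:
        s = (v - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
        if clip:
            s = min(max(s, b_min), b_max)
        scaled.append(s)
    return scaled


# Illustrative intensities: air, window lower bound, mid-window, upper bound, dense bone.
sample_values = [-1000.0, -57.0, 53.5, 164.0, 1000.0]
scaled = scale_intensity_range(sample_values, a_min=-57, a_max=164)
print(scaled)  # [0.0, 0.0, 0.5, 1.0, 1.0]
```

Values below -57 and above 164 collapse to 0 and 1 respectively, so the model only sees contrast within the soft-tissue window it was trained on.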

app.py

%%writefile my_app/app.py
import logging

from spleen_seg_operator import SpleenSegOperator

# Required for setting SegmentDescription attributes. Direct import as this is not part of App SDK package.
from pydicom.sr.codedict import codes

from monai.deploy.core import Application, resource
from monai.deploy.operators.dicom_data_loader_operator import DICOMDataLoaderOperator
from monai.deploy.operators.dicom_seg_writer_operator import DICOMSegmentationWriterOperator, SegmentDescription
from monai.deploy.operators.dicom_series_selector_operator import DICOMSeriesSelectorOperator
from monai.deploy.operators.dicom_series_to_volume_operator import DICOMSeriesToVolumeOperator
from monai.deploy.operators.clara_viz_operator import ClaraVizOperator

# This is a sample series selection rule in JSON, simply selecting CT series.
# If the study has more than 1 CT series, then all of them will be selected.
# Please see more detail in DICOMSeriesSelectorOperator.
Sample_Rules_Text = """
{
    "selections": [
        {
            "name": "CT Series",
            "conditions": {
                "StudyDescription": "(.*?)",
                "Modality": "(?i)CT",
                "SeriesDescription": "(.*?)",
                "ImageType": ["PRIMARY", "ORIGINAL"]
            }
        }
    ]
}
"""

@resource(cpu=1, gpu=1, memory="7Gi")
class AISpleenSegApp(Application):
    def __init__(self, *args, **kwargs):
        """Creates an application instance."""

        self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
        super().__init__(*args, **kwargs)

    def run(self, *args, **kwargs):
        # This method calls the base class to run. Can be omitted if simply calling through.
        self._logger.debug(f"Begin {self.run.__name__}")
        super().run(*args, **kwargs)
        self._logger.debug(f"End {self.run.__name__}")

    def compose(self):
        """Creates the app-specific operators and chains them up in the processing DAG."""

        self._logger.debug(f"Begin {self.compose.__name__}")
        # Creates the custom operator(s) as well as SDK built-in operator(s).
        study_loader_op = DICOMDataLoaderOperator()
        series_selector_op = DICOMSeriesSelectorOperator(rules=Sample_Rules_Text)
        series_to_vol_op = DICOMSeriesToVolumeOperator()
        # Model specific inference operator, supporting MONAI transforms.

        # Creates the model specific segmentation operator
        spleen_seg_op = SpleenSegOperator()

        # Create DICOM Seg writer providing the required segment description for each segment with
        # the actual algorithm and the pertinent organ/tissue.
        # The segment_label, algorithm_name, and algorithm_version are limited to 64 chars.
        # https://dicom.nema.org/medical/dicom/current/output/chtml/part05/sect_6.2.html
        # Users can look up SNOMED CT codes at, e.g.,
        # https://bioportal.bioontology.org/ontologies/SNOMEDCT

        _algorithm_name = "3D segmentation of the Spleen from a CT series"
        _algorithm_family = codes.DCM.ArtificialIntelligence
        _algorithm_version = "0.1.0"

        segment_descriptions = [
            SegmentDescription(
                segment_label="Spleen",
                segmented_property_category=codes.SCT.Organ,
                segmented_property_type=codes.SCT.Spleen,
                algorithm_name=_algorithm_name,
                algorithm_family=_algorithm_family,
                algorithm_version=_algorithm_version,
            ),
        ]

        dicom_seg_writer = DICOMSegmentationWriterOperator(segment_descriptions)

        # Create the processing pipeline, by specifying the source and destination operators, and
        # ensuring the output from the former matches the input of the latter, in both name and type.
        self.add_flow(study_loader_op, series_selector_op, {"dicom_study_list": "dicom_study_list"})
        self.add_flow(
            series_selector_op, series_to_vol_op, {"study_selected_series_list": "study_selected_series_list"}
        )
        self.add_flow(series_to_vol_op, spleen_seg_op, {"image": "image"})

        # Note below the dicom_seg_writer requires two inputs, each coming from a source operator.
        self.add_flow(
            series_selector_op, dicom_seg_writer, {"study_selected_series_list": "study_selected_series_list"}
        )
        self.add_flow(spleen_seg_op, dicom_seg_writer, {"seg_image": "seg_image"})

        viz_op = ClaraVizOperator()
        self.add_flow(series_to_vol_op, viz_op, {"image": "image"})
        self.add_flow(spleen_seg_op, viz_op, {"seg_image": "seg_image"})

        self._logger.debug(f"End {self.compose.__name__}")

# This is a sample series selection rule in JSON, simply selecting CT series.
# If the study has more than 1 CT series, then all of them will be selected.
# Please see more detail in DICOMSeriesSelectorOperator.
# For a list of string values, e.g. "ImageType": ["PRIMARY", "ORIGINAL"], it is a match if all
# elements are in the multi-value attribute of the DICOM series.

Sample_Rules_Text = """
{
    "selections": [
        {
            "name": "CT Series",
            "conditions": {
                "Modality": "(?i)CT",
                "ImageType": ["PRIMARY", "ORIGINAL"],
                "PhotometricInterpretation": "MONOCHROME2"
            }
        }
    ]
}
"""


if __name__ == "__main__":
    # Creates the app and tests it standalone. When running in this mode, please note the following:
    #     -i <DICOM folder>, for input DICOM CT series folder
    #     -o <output folder>, for the output folder, default $PWD/output
    #     -m <model file>, for model file path
    # e.g.
    #     python3 app.py -i input -m model.ts
    #
    AISpleenSegApp(do_run=True)
Overwriting my_app/app.py
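
To make the series selection rules in app.py more concrete, here is a simplified sketch of how such conditions could be evaluated against a series' attributes. It is not the DICOMSeriesSelectorOperator implementation: the function names, the sample attribute dictionaries, the use of re.fullmatch, and the treatment of a missing attribute as a non-match are all assumptions made for illustration. It follows the behavior described in the comments and logs: a string condition is compared exactly first and then as a regular expression, and a list condition matches when all of its elements are present in the multi-value attribute.

```python
import json
import re

rules_text = """
{
    "selections": [
        {
            "name": "CT Series",
            "conditions": {
                "Modality": "(?i)CT",
                "ImageType": ["PRIMARY", "ORIGINAL"]
            }
        }
    ]
}
"""


def condition_matches(expected, actual) -> bool:
    """Simplified matching of one condition; missing attributes count as a non-match here."""
    if actual is None:
        return False
    if isinstance(expected, list):
        # A list condition matches when every element is among the attribute's values.
        return all(item in actual for item in expected)
    # Try an exact string comparison first, then fall back to a regular expression.
    return str(actual) == expected or re.fullmatch(expected, str(actual)) is not None


def series_matches(conditions: dict, series_attributes: dict) -> bool:
    """Return True when every condition matches the given series attributes."""
    return all(
        condition_matches(expected, series_attributes.get(name))
        for name, expected in conditions.items()
    )


# json.loads raises json.JSONDecodeError on invalid JSON, such as a trailing comma.
rules = json.loads(rules_text)
conditions = rules["selections"][0]["conditions"]

# Hypothetical series attributes, for illustration only.
ct_series = {"Modality": "CT", "ImageType": ["ORIGINAL", "PRIMARY", "AXIAL"]}
mr_series = {"Modality": "MR", "ImageType": ["ORIGINAL", "PRIMARY"]}

print(series_matches(conditions, ct_series))  # True
print(series_matches(conditions, mr_series))  # False
```

Parsing the rule text with json.loads before handing it to the operator is a cheap way to catch syntax mistakes early.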
if __name__ == "__main__":
    AISpleenSegApp(do_run=True)

The lines above are needed to execute the application code with the Python interpreter.

__main__.py

__main__.py is needed for the MONAI Application Packager to detect the main application code (app.py) when the application is executed with the application folder path (e.g., python my_app).

%%writefile my_app/__main__.py
from app import AISpleenSegApp

if __name__ == "__main__":
    AISpleenSegApp(do_run=True)
Overwriting my_app/__main__.py
!ls my_app
app.py	__main__.py  __pycache__  spleen_seg_operator.py
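
When the Python interpreter is given a folder instead of a file, it runs the __main__.py found in that folder, with the folder itself added to the module search path. The sketch below demonstrates this with a throwaway folder; the folder name demo_app and its tiny app.py are hypothetical stand-ins for my_app and the real application, used only to show the mechanism.

```python
import subprocess
import sys
import tempfile
from pathlib import Path


def run_folder_as_program(folder: Path) -> str:
    """Run `python <folder>` and return what it printed."""
    result = subprocess.run(
        [sys.executable, str(folder)], capture_output=True, text=True, check=True
    )
    return result.stdout.strip()


with tempfile.TemporaryDirectory() as tmp:
    demo_app = Path(tmp) / "demo_app"  # hypothetical stand-in for my_app
    demo_app.mkdir()
    # Stand-in for app.py: defines the entry point of the "application".
    (demo_app / "app.py").write_text("def run():\n    print('app ran')\n")
    # __main__.py is what the interpreter executes for `python demo_app`.
    (demo_app / "__main__.py").write_text(
        "from app import run\n\nif __name__ == '__main__':\n    run()\n"
    )
    output = run_folder_as_program(demo_app)

print(output)  # app ran
```

The import of app inside __main__.py works because the folder is placed at the front of sys.path, which is why my_app/__main__.py can use from app import AISpleenSegApp without any package setup.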

This time, let’s execute the app from the command line.

!python my_app -i dcm -o output -m model.ts
Going to initiate execution of operator DICOMDataLoaderOperator
Executing operator DICOMDataLoaderOperator (Process ID: 1084688, Operator ID: 110251db-4c50-42ff-a56d-32bd876bb739)
Done performing execution of operator DICOMDataLoaderOperator

Going to initiate execution of operator DICOMSeriesSelectorOperator
Executing operator DICOMSeriesSelectorOperator (Process ID: 1084688, Operator ID: 741a6c5d-8439-414f-b4f0-c499fa9f85a9)
[2022-10-18 21:34:15,156] [INFO] (root) - Finding series for Selection named: CT Series
[2022-10-18 21:34:15,156] [INFO] (root) - Searching study, : 1.3.6.1.4.1.14519.5.2.1.7085.2626.822645453932810382886582736291
  # of series: 1
[2022-10-18 21:34:15,156] [INFO] (root) - Working on series, instance UID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239
[2022-10-18 21:34:15,156] [INFO] (root) - On attribute: 'Modality' to match value: '(?i)CT'
[2022-10-18 21:34:15,156] [INFO] (root) -     Series attribute Modality value: CT
[2022-10-18 21:34:15,156] [INFO] (root) - Series attribute string value did not match. Try regEx.
[2022-10-18 21:34:15,157] [INFO] (root) - On attribute: 'ImageType' to match value: '['PRIMARY', 'ORIGINAL']'
[2022-10-18 21:34:15,157] [INFO] (root) -     Series attribute ImageType value: None
[2022-10-18 21:34:15,157] [INFO] (root) - On attribute: 'PhotometricInterpretation' to match value: 'MONOCHROME2'
[2022-10-18 21:34:15,157] [INFO] (root) -     Series attribute PhotometricInterpretation value: None
[2022-10-18 21:34:15,157] [INFO] (root) - Selected Series, UID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239
Done performing execution of operator DICOMSeriesSelectorOperator

Going to initiate execution of operator DICOMSeriesToVolumeOperator
Executing operator DICOMSeriesToVolumeOperator (Process ID: 1084688, Operator ID: ec4bd664-8c8c-4d7e-a466-cbb55f6abb05)
Done performing execution of operator DICOMSeriesToVolumeOperator

Going to initiate execution of operator SpleenSegOperator
Executing operator SpleenSegOperator (Process ID: 1084688, Operator ID: 3931c216-5bbb-4c64-8bee-7144e7de0f9e)
Converted Image object metadata:
SeriesInstanceUID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239, type <class 'str'>
SeriesDate: 20090831, type <class 'str'>
SeriesTime: 101721.452, type <class 'str'>
Modality: CT, type <class 'str'>
SeriesDescription: ABD/PANC 3.0 B31f, type <class 'str'>
PatientPosition: HFS, type <class 'str'>
SeriesNumber: 8, type <class 'int'>
row_pixel_spacing: 0.7890625, type <class 'float'>
col_pixel_spacing: 0.7890625, type <class 'float'>
depth_pixel_spacing: 1.5, type <class 'float'>
row_direction_cosine: [1.0, 0.0, 0.0], type <class 'list'>
col_direction_cosine: [0.0, 1.0, 0.0], type <class 'list'>
depth_direction_cosine: [0.0, 0.0, 1.0], type <class 'list'>
dicom_affine_transform: [[   0.7890625    0.           0.        -197.60547  ]
 [   0.           0.7890625    0.        -398.60547  ]
 [   0.           0.           1.5       -383.       ]
 [   0.           0.           0.           1.       ]], type <class 'numpy.ndarray'>
nifti_affine_transform: [[  -0.7890625   -0.          -0.         197.60547  ]
 [  -0.          -0.7890625   -0.         398.60547  ]
 [   0.           0.           1.5       -383.       ]
 [   0.           0.           0.           1.       ]], type <class 'numpy.ndarray'>
StudyInstanceUID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.822645453932810382886582736291, type <class 'str'>
StudyID: , type <class 'str'>
StudyDate: 20090831, type <class 'str'>
StudyTime: 095948.599, type <class 'str'>
StudyDescription: CT ABDOMEN W IV CONTRAST, type <class 'str'>
AccessionNumber: 5471978513296937, type <class 'str'>
selection_name: CT Series, type <class 'str'>
2022-10-18 21:34:29,110 INFO image_writer.py:194 - writing: /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/output/prediction_output/1.3.6.1.4.1.14519.5.2.1.7085.2626/1.3.6.1.4.1.14519.5.2.1.7085.2626_seg.nii.gz
Output Seg image numpy array shaped: (204, 512, 512)
Output Seg image pixel max value: 1
Done performing execution of operator SpleenSegOperator

Going to initiate execution of operator DICOMSegmentationWriterOperator
Executing operator DICOMSegmentationWriterOperator (Process ID: 1084688, Operator ID: bb507096-3d21-419d-b746-cda41ab7a9f5)
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages/highdicom/valuerep.py:54: UserWarning: The string "C3N-00198" is unlikely to represent the intended person name since it contains only a single component. Construct a person name according to the format in described in http://dicom.nema.org/dicom/2013/output/chtml/part05/sect_6.2.html#sect_6.2.1.2, or, in pydicom 2.2.0 or later, use the pydicom.valuerep.PersonName.from_named_components() method to construct the person name correctly. If a single-component name is really intended, add a trailing caret character to disambiguate the name.
  warnings.warn(
[2022-10-18 21:34:32,757] [INFO] (highdicom.seg.sop) - add plane #0 for segment #1
[2022-10-18 21:34:32,758] [INFO] (highdicom.seg.sop) - add plane #1 for segment #1
[2022-10-18 21:34:32,760] [INFO] (highdicom.seg.sop) - add plane #2 for segment #1
[2022-10-18 21:34:32,760] [INFO] (highdicom.seg.sop) - add plane #3 for segment #1
[2022-10-18 21:34:32,761] [INFO] (highdicom.seg.sop) - add plane #4 for segment #1
[2022-10-18 21:34:32,762] [INFO] (highdicom.seg.sop) - add plane #5 for segment #1
[2022-10-18 21:34:32,763] [INFO] (highdicom.seg.sop) - add plane #6 for segment #1
[2022-10-18 21:34:32,764] [INFO] (highdicom.seg.sop) - add plane #7 for segment #1
[2022-10-18 21:34:32,765] [INFO] (highdicom.seg.sop) - add plane #8 for segment #1
[2022-10-18 21:34:32,766] [INFO] (highdicom.seg.sop) - add plane #9 for segment #1
[2022-10-18 21:34:32,767] [INFO] (highdicom.seg.sop) - add plane #10 for segment #1
[2022-10-18 21:34:32,768] [INFO] (highdicom.seg.sop) - add plane #11 for segment #1
[2022-10-18 21:34:32,769] [INFO] (highdicom.seg.sop) - add plane #12 for segment #1
[2022-10-18 21:34:32,770] [INFO] (highdicom.seg.sop) - add plane #13 for segment #1
[2022-10-18 21:34:32,772] [INFO] (highdicom.seg.sop) - add plane #14 for segment #1
[2022-10-18 21:34:32,773] [INFO] (highdicom.seg.sop) - add plane #15 for segment #1
[2022-10-18 21:34:32,774] [INFO] (highdicom.seg.sop) - add plane #16 for segment #1
[2022-10-18 21:34:32,774] [INFO] (highdicom.seg.sop) - add plane #17 for segment #1
[2022-10-18 21:34:32,775] [INFO] (highdicom.seg.sop) - add plane #18 for segment #1
[2022-10-18 21:34:32,776] [INFO] (highdicom.seg.sop) - add plane #19 for segment #1
[2022-10-18 21:34:32,777] [INFO] (highdicom.seg.sop) - add plane #20 for segment #1
[2022-10-18 21:34:32,778] [INFO] (highdicom.seg.sop) - add plane #21 for segment #1
[2022-10-18 21:34:32,779] [INFO] (highdicom.seg.sop) - add plane #22 for segment #1
[2022-10-18 21:34:32,780] [INFO] (highdicom.seg.sop) - add plane #23 for segment #1
[2022-10-18 21:34:32,781] [INFO] (highdicom.seg.sop) - add plane #24 for segment #1
[2022-10-18 21:34:32,782] [INFO] (highdicom.seg.sop) - add plane #25 for segment #1
[2022-10-18 21:34:32,783] [INFO] (highdicom.seg.sop) - add plane #26 for segment #1
[2022-10-18 21:34:32,784] [INFO] (highdicom.seg.sop) - add plane #27 for segment #1
[2022-10-18 21:34:32,785] [INFO] (highdicom.seg.sop) - add plane #28 for segment #1
[2022-10-18 21:34:32,786] [INFO] (highdicom.seg.sop) - add plane #29 for segment #1
[2022-10-18 21:34:32,787] [INFO] (highdicom.seg.sop) - add plane #30 for segment #1
[2022-10-18 21:34:32,788] [INFO] (highdicom.seg.sop) - add plane #31 for segment #1
[2022-10-18 21:34:32,789] [INFO] (highdicom.seg.sop) - add plane #32 for segment #1
[2022-10-18 21:34:32,790] [INFO] (highdicom.seg.sop) - add plane #33 for segment #1
[2022-10-18 21:34:32,791] [INFO] (highdicom.seg.sop) - add plane #34 for segment #1
[2022-10-18 21:34:32,792] [INFO] (highdicom.seg.sop) - add plane #35 for segment #1
[2022-10-18 21:34:32,793] [INFO] (highdicom.seg.sop) - add plane #36 for segment #1
[2022-10-18 21:34:32,794] [INFO] (highdicom.seg.sop) - add plane #37 for segment #1
[2022-10-18 21:34:32,795] [INFO] (highdicom.seg.sop) - add plane #38 for segment #1
[2022-10-18 21:34:32,796] [INFO] (highdicom.seg.sop) - add plane #39 for segment #1
[2022-10-18 21:34:32,797] [INFO] (highdicom.seg.sop) - add plane #40 for segment #1
[2022-10-18 21:34:32,798] [INFO] (highdicom.seg.sop) - add plane #41 for segment #1
[2022-10-18 21:34:32,799] [INFO] (highdicom.seg.sop) - add plane #42 for segment #1
[2022-10-18 21:34:32,800] [INFO] (highdicom.seg.sop) - add plane #43 for segment #1
[2022-10-18 21:34:32,801] [INFO] (highdicom.seg.sop) - add plane #44 for segment #1
[2022-10-18 21:34:32,802] [INFO] (highdicom.seg.sop) - add plane #45 for segment #1
[2022-10-18 21:34:32,803] [INFO] (highdicom.seg.sop) - add plane #46 for segment #1
[2022-10-18 21:34:32,804] [INFO] (highdicom.seg.sop) - add plane #47 for segment #1
[2022-10-18 21:34:32,805] [INFO] (highdicom.seg.sop) - add plane #48 for segment #1
[2022-10-18 21:34:32,806] [INFO] (highdicom.seg.sop) - add plane #49 for segment #1
[2022-10-18 21:34:32,807] [INFO] (highdicom.seg.sop) - add plane #50 for segment #1
[2022-10-18 21:34:32,808] [INFO] (highdicom.seg.sop) - add plane #51 for segment #1
[2022-10-18 21:34:32,810] [INFO] (highdicom.seg.sop) - add plane #52 for segment #1
[2022-10-18 21:34:32,811] [INFO] (highdicom.seg.sop) - add plane #53 for segment #1
[2022-10-18 21:34:32,812] [INFO] (highdicom.seg.sop) - add plane #54 for segment #1
[2022-10-18 21:34:32,813] [INFO] (highdicom.seg.sop) - add plane #55 for segment #1
[2022-10-18 21:34:32,814] [INFO] (highdicom.seg.sop) - add plane #56 for segment #1
[2022-10-18 21:34:32,815] [INFO] (highdicom.seg.sop) - add plane #57 for segment #1
[2022-10-18 21:34:32,816] [INFO] (highdicom.seg.sop) - add plane #58 for segment #1
[2022-10-18 21:34:32,817] [INFO] (highdicom.seg.sop) - add plane #59 for segment #1
[2022-10-18 21:34:32,818] [INFO] (highdicom.seg.sop) - add plane #60 for segment #1
[2022-10-18 21:34:32,819] [INFO] (highdicom.seg.sop) - add plane #61 for segment #1
[2022-10-18 21:34:32,820] [INFO] (highdicom.seg.sop) - add plane #62 for segment #1
[2022-10-18 21:34:32,821] [INFO] (highdicom.seg.sop) - add plane #63 for segment #1
[2022-10-18 21:34:32,822] [INFO] (highdicom.seg.sop) - add plane #64 for segment #1
[2022-10-18 21:34:32,823] [INFO] (highdicom.seg.sop) - add plane #65 for segment #1
[2022-10-18 21:34:32,824] [INFO] (highdicom.seg.sop) - add plane #66 for segment #1
[2022-10-18 21:34:32,825] [INFO] (highdicom.seg.sop) - add plane #67 for segment #1
[2022-10-18 21:34:32,826] [INFO] (highdicom.seg.sop) - add plane #68 for segment #1
[2022-10-18 21:34:32,828] [INFO] (highdicom.seg.sop) - add plane #69 for segment #1
[2022-10-18 21:34:32,829] [INFO] (highdicom.seg.sop) - add plane #70 for segment #1
[2022-10-18 21:34:32,830] [INFO] (highdicom.seg.sop) - add plane #71 for segment #1
[2022-10-18 21:34:32,831] [INFO] (highdicom.seg.sop) - add plane #72 for segment #1
[2022-10-18 21:34:32,832] [INFO] (highdicom.seg.sop) - add plane #73 for segment #1
[2022-10-18 21:34:32,833] [INFO] (highdicom.seg.sop) - add plane #74 for segment #1
[2022-10-18 21:34:32,834] [INFO] (highdicom.seg.sop) - add plane #75 for segment #1
[2022-10-18 21:34:32,835] [INFO] (highdicom.seg.sop) - add plane #76 for segment #1
[2022-10-18 21:34:32,836] [INFO] (highdicom.seg.sop) - add plane #77 for segment #1
[2022-10-18 21:34:32,837] [INFO] (highdicom.seg.sop) - add plane #78 for segment #1
[2022-10-18 21:34:32,839] [INFO] (highdicom.seg.sop) - add plane #79 for segment #1
[2022-10-18 21:34:32,840] [INFO] (highdicom.seg.sop) - add plane #80 for segment #1
[2022-10-18 21:34:32,841] [INFO] (highdicom.seg.sop) - add plane #81 for segment #1
[2022-10-18 21:34:32,842] [INFO] (highdicom.seg.sop) - add plane #82 for segment #1
[2022-10-18 21:34:32,843] [INFO] (highdicom.seg.sop) - add plane #83 for segment #1
[2022-10-18 21:34:32,844] [INFO] (highdicom.seg.sop) - add plane #84 for segment #1
[2022-10-18 21:34:32,845] [INFO] (highdicom.seg.sop) - add plane #85 for segment #1
[2022-10-18 21:34:32,846] [INFO] (highdicom.seg.sop) - add plane #86 for segment #1
[2022-10-18 21:34:32,847] [INFO] (highdicom.seg.sop) - add plane #87 for segment #1
[2022-10-18 21:34:32,931] [INFO] (highdicom.base) - copy Image-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"
[2022-10-18 21:34:32,932] [INFO] (highdicom.base) - copy attributes of module "Specimen"
[2022-10-18 21:34:32,932] [INFO] (highdicom.base) - copy Patient-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"
[2022-10-18 21:34:32,932] [INFO] (highdicom.base) - copy attributes of module "Patient"
[2022-10-18 21:34:32,932] [INFO] (highdicom.base) - copy attributes of module "Clinical Trial Subject"
[2022-10-18 21:34:32,932] [INFO] (highdicom.base) - copy Study-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"
[2022-10-18 21:34:32,932] [INFO] (highdicom.base) - copy attributes of module "General Study"
[2022-10-18 21:34:32,932] [INFO] (highdicom.base) - copy attributes of module "Patient Study"
[2022-10-18 21:34:32,932] [INFO] (highdicom.base) - copy attributes of module "Clinical Trial Study"
Done performing execution of operator DICOMSegmentationWriterOperator

Going to initiate execution of operator ClaraVizOperator
Executing operator ClaraVizOperator (Process ID: 1084688, Operator ID: d0e9f72b-6ff8-454b-9550-21bd4263a16a)
Box(children=(Widget(), VBox(children=(interactive(children=(Dropdown(description='View mode', index=2, options=(('Cinematic', 'CINEMATIC'), ('Slice', 'SLICE'), ('Slice Segmentation', 'SLICE_SEGMENTATION')), value='SLICE_SEGMENTATION'), Output()), _dom_classes=('widget-interact',)), interactive(children=(Dropdown(description='Camera', options=('Top', 'Right', 'Front'), value='Top'), Output()), _dom_classes=('widget-interact',))))))
Done performing execution of operator ClaraVizOperator

The command above is equivalent to the following command line:

import os
os.environ['MKL_THREADING_LAYER'] = 'GNU'
!monai-deploy exec my_app -i dcm -o output -m model.ts
Going to initiate execution of operator DICOMDataLoaderOperator
Executing operator DICOMDataLoaderOperator (Process ID: 1084773, Operator ID: 21b70bc7-bd07-4803-936b-aadc983343c8)
Done performing execution of operator DICOMDataLoaderOperator

Going to initiate execution of operator DICOMSeriesSelectorOperator
Executing operator DICOMSeriesSelectorOperator (Process ID: 1084773, Operator ID: 06cf792c-e49a-4a84-b04b-c38ec1e2830a)
[2022-10-18 21:34:41,042] [INFO] (root) - Finding series for Selection named: CT Series
[2022-10-18 21:34:41,042] [INFO] (root) - Searching study, : 1.3.6.1.4.1.14519.5.2.1.7085.2626.822645453932810382886582736291
  # of series: 1
[2022-10-18 21:34:41,042] [INFO] (root) - Working on series, instance UID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239
[2022-10-18 21:34:41,042] [INFO] (root) - On attribute: 'Modality' to match value: '(?i)CT'
[2022-10-18 21:34:41,042] [INFO] (root) -     Series attribute Modality value: CT
[2022-10-18 21:34:41,043] [INFO] (root) - Series attribute string value did not match. Try regEx.
[2022-10-18 21:34:41,043] [INFO] (root) - On attribute: 'ImageType' to match value: '['PRIMARY', 'ORIGINAL']'
[2022-10-18 21:34:41,043] [INFO] (root) -     Series attribute ImageType value: None
[2022-10-18 21:34:41,043] [INFO] (root) - On attribute: 'PhotometricInterpretation' to match value: 'MONOCHROME2'
[2022-10-18 21:34:41,043] [INFO] (root) -     Series attribute PhotometricInterpretation value: None
[2022-10-18 21:34:41,043] [INFO] (root) - Selected Series, UID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239
Done performing execution of operator DICOMSeriesSelectorOperator

Going to initiate execution of operator DICOMSeriesToVolumeOperator
Executing operator DICOMSeriesToVolumeOperator (Process ID: 1084773, Operator ID: 8017c508-ab6a-4969-a7fc-13a963d60bf6)
Done performing execution of operator DICOMSeriesToVolumeOperator

Going to initiate execution of operator SpleenSegOperator
Executing operator SpleenSegOperator (Process ID: 1084773, Operator ID: 6acc2418-dbe4-4d34-bde0-aea2461b3acf)
Converted Image object metadata:
SeriesInstanceUID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239, type <class 'str'>
SeriesDate: 20090831, type <class 'str'>
SeriesTime: 101721.452, type <class 'str'>
Modality: CT, type <class 'str'>
SeriesDescription: ABD/PANC 3.0 B31f, type <class 'str'>
PatientPosition: HFS, type <class 'str'>
SeriesNumber: 8, type <class 'int'>
row_pixel_spacing: 0.7890625, type <class 'float'>
col_pixel_spacing: 0.7890625, type <class 'float'>
depth_pixel_spacing: 1.5, type <class 'float'>
row_direction_cosine: [1.0, 0.0, 0.0], type <class 'list'>
col_direction_cosine: [0.0, 1.0, 0.0], type <class 'list'>
depth_direction_cosine: [0.0, 0.0, 1.0], type <class 'list'>
dicom_affine_transform: [[   0.7890625    0.           0.        -197.60547  ]
 [   0.           0.7890625    0.        -398.60547  ]
 [   0.           0.           1.5       -383.       ]
 [   0.           0.           0.           1.       ]], type <class 'numpy.ndarray'>
nifti_affine_transform: [[  -0.7890625   -0.          -0.         197.60547  ]
 [  -0.          -0.7890625   -0.         398.60547  ]
 [   0.           0.           1.5       -383.       ]
 [   0.           0.           0.           1.       ]], type <class 'numpy.ndarray'>
StudyInstanceUID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.822645453932810382886582736291, type <class 'str'>
StudyID: , type <class 'str'>
StudyDate: 20090831, type <class 'str'>
StudyTime: 095948.599, type <class 'str'>
StudyDescription: CT ABDOMEN W IV CONTRAST, type <class 'str'>
AccessionNumber: 5471978513296937, type <class 'str'>
selection_name: CT Series, type <class 'str'>
2022-10-18 21:34:54,946 INFO image_writer.py:194 - writing: /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/output/prediction_output/1.3.6.1.4.1.14519.5.2.1.7085.2626/1.3.6.1.4.1.14519.5.2.1.7085.2626_seg.nii.gz
Output Seg image numpy array shaped: (204, 512, 512)
Output Seg image pixel max value: 1
Done performing execution of operator SpleenSegOperator

Going to initiate execution of operator DICOMSegmentationWriterOperator
Executing operator DICOMSegmentationWriterOperator (Process ID: 1084773, Operator ID: 78e3d0dc-7d3d-4746-989c-8c1f9f749316)
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages/highdicom/valuerep.py:54: UserWarning: The string "C3N-00198" is unlikely to represent the intended person name since it contains only a single component. Construct a person name according to the format in described in http://dicom.nema.org/dicom/2013/output/chtml/part05/sect_6.2.html#sect_6.2.1.2, or, in pydicom 2.2.0 or later, use the pydicom.valuerep.PersonName.from_named_components() method to construct the person name correctly. If a single-component name is really intended, add a trailing caret character to disambiguate the name.
  warnings.warn(
[2022-10-18 21:34:58,408] [INFO] (highdicom.seg.sop) - add plane #0 for segment #1
[2022-10-18 21:34:58,409] [INFO] (highdicom.seg.sop) - add plane #1 for segment #1
[2022-10-18 21:34:58,410] [INFO] (highdicom.seg.sop) - add plane #2 for segment #1
[2022-10-18 21:34:58,411] [INFO] (highdicom.seg.sop) - add plane #3 for segment #1
[2022-10-18 21:34:58,412] [INFO] (highdicom.seg.sop) - add plane #4 for segment #1
[2022-10-18 21:34:58,413] [INFO] (highdicom.seg.sop) - add plane #5 for segment #1
[2022-10-18 21:34:58,414] [INFO] (highdicom.seg.sop) - add plane #6 for segment #1
[2022-10-18 21:34:58,415] [INFO] (highdicom.seg.sop) - add plane #7 for segment #1
[2022-10-18 21:34:58,416] [INFO] (highdicom.seg.sop) - add plane #8 for segment #1
[2022-10-18 21:34:58,417] [INFO] (highdicom.seg.sop) - add plane #9 for segment #1
[2022-10-18 21:34:58,418] [INFO] (highdicom.seg.sop) - add plane #10 for segment #1
[2022-10-18 21:34:58,419] [INFO] (highdicom.seg.sop) - add plane #11 for segment #1
[2022-10-18 21:34:58,420] [INFO] (highdicom.seg.sop) - add plane #12 for segment #1
[2022-10-18 21:34:58,421] [INFO] (highdicom.seg.sop) - add plane #13 for segment #1
[2022-10-18 21:34:58,422] [INFO] (highdicom.seg.sop) - add plane #14 for segment #1
[2022-10-18 21:34:58,423] [INFO] (highdicom.seg.sop) - add plane #15 for segment #1
[2022-10-18 21:34:58,424] [INFO] (highdicom.seg.sop) - add plane #16 for segment #1
[2022-10-18 21:34:58,425] [INFO] (highdicom.seg.sop) - add plane #17 for segment #1
[2022-10-18 21:34:58,426] [INFO] (highdicom.seg.sop) - add plane #18 for segment #1
[2022-10-18 21:34:58,427] [INFO] (highdicom.seg.sop) - add plane #19 for segment #1
[2022-10-18 21:34:58,428] [INFO] (highdicom.seg.sop) - add plane #20 for segment #1
[2022-10-18 21:34:58,429] [INFO] (highdicom.seg.sop) - add plane #21 for segment #1
[2022-10-18 21:34:58,430] [INFO] (highdicom.seg.sop) - add plane #22 for segment #1
[2022-10-18 21:34:58,431] [INFO] (highdicom.seg.sop) - add plane #23 for segment #1
[2022-10-18 21:34:58,432] [INFO] (highdicom.seg.sop) - add plane #24 for segment #1
[2022-10-18 21:34:58,434] [INFO] (highdicom.seg.sop) - add plane #25 for segment #1
[2022-10-18 21:34:58,435] [INFO] (highdicom.seg.sop) - add plane #26 for segment #1
[2022-10-18 21:34:58,436] [INFO] (highdicom.seg.sop) - add plane #27 for segment #1
[2022-10-18 21:34:58,437] [INFO] (highdicom.seg.sop) - add plane #28 for segment #1
[2022-10-18 21:34:58,438] [INFO] (highdicom.seg.sop) - add plane #29 for segment #1
[2022-10-18 21:34:58,439] [INFO] (highdicom.seg.sop) - add plane #30 for segment #1
[2022-10-18 21:34:58,440] [INFO] (highdicom.seg.sop) - add plane #31 for segment #1
[2022-10-18 21:34:58,441] [INFO] (highdicom.seg.sop) - add plane #32 for segment #1
[2022-10-18 21:34:58,442] [INFO] (highdicom.seg.sop) - add plane #33 for segment #1
[2022-10-18 21:34:58,443] [INFO] (highdicom.seg.sop) - add plane #34 for segment #1
[2022-10-18 21:34:58,444] [INFO] (highdicom.seg.sop) - add plane #35 for segment #1
[2022-10-18 21:34:58,445] [INFO] (highdicom.seg.sop) - add plane #36 for segment #1
[2022-10-18 21:34:58,446] [INFO] (highdicom.seg.sop) - add plane #37 for segment #1
[2022-10-18 21:34:58,447] [INFO] (highdicom.seg.sop) - add plane #38 for segment #1
[2022-10-18 21:34:58,448] [INFO] (highdicom.seg.sop) - add plane #39 for segment #1
[2022-10-18 21:34:58,449] [INFO] (highdicom.seg.sop) - add plane #40 for segment #1
[2022-10-18 21:34:58,450] [INFO] (highdicom.seg.sop) - add plane #41 for segment #1
[2022-10-18 21:34:58,451] [INFO] (highdicom.seg.sop) - add plane #42 for segment #1
[2022-10-18 21:34:58,452] [INFO] (highdicom.seg.sop) - add plane #43 for segment #1
[2022-10-18 21:34:58,453] [INFO] (highdicom.seg.sop) - add plane #44 for segment #1
[2022-10-18 21:34:58,454] [INFO] (highdicom.seg.sop) - add plane #45 for segment #1
[2022-10-18 21:34:58,455] [INFO] (highdicom.seg.sop) - add plane #46 for segment #1
[2022-10-18 21:34:58,456] [INFO] (highdicom.seg.sop) - add plane #47 for segment #1
[2022-10-18 21:34:58,457] [INFO] (highdicom.seg.sop) - add plane #48 for segment #1
[2022-10-18 21:34:58,458] [INFO] (highdicom.seg.sop) - add plane #49 for segment #1
[2022-10-18 21:34:58,459] [INFO] (highdicom.seg.sop) - add plane #50 for segment #1
[2022-10-18 21:34:58,460] [INFO] (highdicom.seg.sop) - add plane #51 for segment #1
[2022-10-18 21:34:58,461] [INFO] (highdicom.seg.sop) - add plane #52 for segment #1
[2022-10-18 21:34:58,462] [INFO] (highdicom.seg.sop) - add plane #53 for segment #1
[2022-10-18 21:34:58,463] [INFO] (highdicom.seg.sop) - add plane #54 for segment #1
[2022-10-18 21:34:58,465] [INFO] (highdicom.seg.sop) - add plane #55 for segment #1
[2022-10-18 21:34:58,466] [INFO] (highdicom.seg.sop) - add plane #56 for segment #1
[2022-10-18 21:34:58,467] [INFO] (highdicom.seg.sop) - add plane #57 for segment #1
[2022-10-18 21:34:58,468] [INFO] (highdicom.seg.sop) - add plane #58 for segment #1
[2022-10-18 21:34:58,469] [INFO] (highdicom.seg.sop) - add plane #59 for segment #1
[2022-10-18 21:34:58,470] [INFO] (highdicom.seg.sop) - add plane #60 for segment #1
[2022-10-18 21:34:58,471] [INFO] (highdicom.seg.sop) - add plane #61 for segment #1
[2022-10-18 21:34:58,472] [INFO] (highdicom.seg.sop) - add plane #62 for segment #1
[2022-10-18 21:34:58,473] [INFO] (highdicom.seg.sop) - add plane #63 for segment #1
[2022-10-18 21:34:58,474] [INFO] (highdicom.seg.sop) - add plane #64 for segment #1
[2022-10-18 21:34:58,475] [INFO] (highdicom.seg.sop) - add plane #65 for segment #1
[2022-10-18 21:34:58,476] [INFO] (highdicom.seg.sop) - add plane #66 for segment #1
[2022-10-18 21:34:58,477] [INFO] (highdicom.seg.sop) - add plane #67 for segment #1
[2022-10-18 21:34:58,478] [INFO] (highdicom.seg.sop) - add plane #68 for segment #1
[2022-10-18 21:34:58,479] [INFO] (highdicom.seg.sop) - add plane #69 for segment #1
[2022-10-18 21:34:58,480] [INFO] (highdicom.seg.sop) - add plane #70 for segment #1
[2022-10-18 21:34:58,481] [INFO] (highdicom.seg.sop) - add plane #71 for segment #1
[2022-10-18 21:34:58,482] [INFO] (highdicom.seg.sop) - add plane #72 for segment #1
[2022-10-18 21:34:58,484] [INFO] (highdicom.seg.sop) - add plane #73 for segment #1
[2022-10-18 21:34:58,485] [INFO] (highdicom.seg.sop) - add plane #74 for segment #1
[2022-10-18 21:34:58,486] [INFO] (highdicom.seg.sop) - add plane #75 for segment #1
[2022-10-18 21:34:58,487] [INFO] (highdicom.seg.sop) - add plane #76 for segment #1
[2022-10-18 21:34:58,488] [INFO] (highdicom.seg.sop) - add plane #77 for segment #1
[2022-10-18 21:34:58,489] [INFO] (highdicom.seg.sop) - add plane #78 for segment #1
[2022-10-18 21:34:58,490] [INFO] (highdicom.seg.sop) - add plane #79 for segment #1
[2022-10-18 21:34:58,491] [INFO] (highdicom.seg.sop) - add plane #80 for segment #1
[2022-10-18 21:34:58,492] [INFO] (highdicom.seg.sop) - add plane #81 for segment #1
[2022-10-18 21:34:58,493] [INFO] (highdicom.seg.sop) - add plane #82 for segment #1
[2022-10-18 21:34:58,494] [INFO] (highdicom.seg.sop) - add plane #83 for segment #1
[2022-10-18 21:34:58,495] [INFO] (highdicom.seg.sop) - add plane #84 for segment #1
[2022-10-18 21:34:58,497] [INFO] (highdicom.seg.sop) - add plane #85 for segment #1
[2022-10-18 21:34:58,498] [INFO] (highdicom.seg.sop) - add plane #86 for segment #1
[2022-10-18 21:34:58,499] [INFO] (highdicom.seg.sop) - add plane #87 for segment #1
[2022-10-18 21:34:58,546] [INFO] (highdicom.base) - copy Image-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"
[2022-10-18 21:34:58,546] [INFO] (highdicom.base) - copy attributes of module "Specimen"
[2022-10-18 21:34:58,546] [INFO] (highdicom.base) - copy Patient-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"
[2022-10-18 21:34:58,546] [INFO] (highdicom.base) - copy attributes of module "Patient"
[2022-10-18 21:34:58,547] [INFO] (highdicom.base) - copy attributes of module "Clinical Trial Subject"
[2022-10-18 21:34:58,547] [INFO] (highdicom.base) - copy Study-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"
[2022-10-18 21:34:58,547] [INFO] (highdicom.base) - copy attributes of module "General Study"
[2022-10-18 21:34:58,547] [INFO] (highdicom.base) - copy attributes of module "Patient Study"
[2022-10-18 21:34:58,547] [INFO] (highdicom.base) - copy attributes of module "Clinical Trial Study"
Done performing execution of operator DICOMSegmentationWriterOperator

Going to initiate execution of operator ClaraVizOperator
Executing operator ClaraVizOperator (Process ID: 1084773, Operator ID: 0b5669c2-1eba-4dfb-aca7-f5059aa855d9)
Box(children=(Widget(), VBox(children=(interactive(children=(Dropdown(description='View mode', index=2, options=(('Cinematic', 'CINEMATIC'), ('Slice', 'SLICE'), ('Slice Segmentation', 'SLICE_SEGMENTATION')), value='SLICE_SEGMENTATION'), Output()), _dom_classes=('widget-interact',)), interactive(children=(Dropdown(description='Camera', options=('Top', 'Right', 'Front'), value='Top'), Output()), _dom_classes=('widget-interact',))))))
Done performing execution of operator ClaraVizOperator
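
The metadata printed by SpleenSegOperator includes both a dicom_affine_transform and a nifti_affine_transform. The two differ only in the sign of the first two rows, which is consistent with DICOM using the LPS patient coordinate system and NIfTI using RAS. The following is a minimal sketch that checks this relationship with the values copied from the log; the helper names lps_to_ras and voxel_to_world are hypothetical and not part of the MONAI Deploy App SDK.

```python
# Affine values copied from the "Converted Image object metadata" log output.
DICOM_AFFINE = [
    [0.7890625, 0.0, 0.0, -197.60547],
    [0.0, 0.7890625, 0.0, -398.60547],
    [0.0, 0.0, 1.5, -383.0],
    [0.0, 0.0, 0.0, 1.0],
]
NIFTI_AFFINE = [
    [-0.7890625, -0.0, -0.0, 197.60547],
    [-0.0, -0.7890625, -0.0, 398.60547],
    [0.0, 0.0, 1.5, -383.0],
    [0.0, 0.0, 0.0, 1.0],
]


def lps_to_ras(affine):
    """Negate the x and y rows, i.e. left-multiply by diag(-1, -1, 1, 1).

    Hypothetical helper for illustration only.
    """
    return [
        [-value for value in row] if index < 2 else list(row)
        for index, row in enumerate(affine)
    ]


def voxel_to_world(affine, index):
    """Map a voxel index to patient coordinates (mm) using a 4x4 affine."""
    homogeneous = list(index) + [1.0]
    return [sum(a * b for a, b in zip(row, homogeneous)) for row in affine[:3]]


# The converted DICOM affine should match the NIfTI affine from the log.
print(lps_to_ras(DICOM_AFFINE) == NIFTI_AFFINE)

# The origin voxel lands at mirrored x/y positions in the two systems.
print(voxel_to_world(DICOM_AFFINE, (0, 0, 0)))
print(voxel_to_world(NIFTI_AFFINE, (0, 0, 0)))
```

The diagonal entries also match the logged row_pixel_spacing, col_pixel_spacing, and depth_pixel_spacing values (0.7890625, 0.7890625, and 1.5), as expected when the direction cosines form the identity matrix.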

!ls output
1.2.826.0.1.3680043.10.511.3.11636838214613793635775978376672891.dcm
1.2.826.0.1.3680043.10.511.3.17384468917290596349831996191635582.dcm
1.2.826.0.1.3680043.10.511.3.55545104192889656878608836519404425.dcm
prediction_output
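
The listing shows several .dcm files next to the prediction_output folder. Each execution appears to write a DICOM Segmentation instance with a newly generated UID as its file name, so repeated runs into the same output folder accumulate files. As a minimal sketch, the folder content can be summarized from Python with the standard library alone; the function name summarize_output is hypothetical, and the folder name output is assumed to be the one passed with -o above.

```python
from pathlib import Path


def summarize_output(output_dir):
    """Split the entries of an output folder into .dcm files and everything else.

    Hypothetical helper for illustration only.
    """
    entries = sorted(Path(output_dir).iterdir())
    dicom_files = [
        entry.name
        for entry in entries
        if entry.is_file() and entry.suffix.lower() == ".dcm"
    ]
    others = [entry.name for entry in entries if entry.name not in dicom_files]
    return {"dicom_files": dicom_files, "others": others}


if __name__ == "__main__":
    output_dir = Path("output")  # assumed: the folder passed with -o
    if output_dir.is_dir():
        summary = summarize_output(output_dir)
        print(f"{len(summary['dicom_files'])} DICOM file(s):")
        for name in summary["dicom_files"]:
            print(f"  {name}")
        print(f"Other entries: {summary['others']}")
    else:
        print(f"Folder not found: {output_dir}")
```

Clearing the output folder before each run keeps only the latest segmentation instance, which makes it easier to tell which file belongs to which execution.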

Packaging app

The Clara Viz operator is intended for interactive visualization, so an application that includes it should not be packaged with the MONAI Application Packager.