Creating a Segmentation App with MONAI Deploy App SDK

This tutorial shows how to create an organ segmentation application for a PyTorch model that has been trained with MONAI. Note that this tutorial does not require the model to be a MONAI Bundle.

Deploying AI models requires integration with the clinical imaging network, even if only in a for-research-use setting. This means that the AI deploy application will need to support standards-based imaging protocols, specifically, for radiological imaging, the DICOM protocol.

Typically, DICOM network communication, either via the DICOM TCP/IP network protocol or DICOMweb, is handled by DICOM devices or services, e.g., MONAI Deploy Informatics Gateway, so the deploy application itself only needs to use DICOM Part 10 files as input and save the AI results in DICOM Part 10 file(s). For segmentation use cases, the DICOM instance file for AI results could be a DICOM Segmentation object or a DICOM RT Structure Set; for classification, a DICOM Structured Report and/or a DICOM Encapsulated PDF.

During model training, input and label images are typically in a non-DICOM image format, e.g., NIfTI or PNG, converted from a specific series of a DICOM study. Furthermore, the voxel spacings have most likely been re-sampled to be uniform for all images. When integrated with imaging networks and receiving DICOM instances from modalities and Picture Archiving and Communication Systems (PACS), an AI deploy application has to deal with whole DICOM studies with multiple series, whose image spacing may differ from what the trained model expects. To address these cases consistently and efficiently, the MONAI Deploy App SDK provides classes, called operators, to parse DICOM studies, select specific series with application-defined rules, and convert the selected DICOM series into a domain-specific image format along with metadata representing the pertinent DICOM attributes. The image is then further processed in the pre-processing stage to normalize spacing, orientation, intensity, etc., before its pixel data, as tensors, is used for inference.
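Conceptually, parsing a study starts by grouping DICOM instances by their StudyInstanceUID and SeriesInstanceUID, the hierarchy that series selection then works on. The sketch below is a hypothetical, library-free illustration of that grouping, using plain dicts in place of parsed DICOM headers; it is not the DICOMDataLoaderOperator implementation. It orders slices by InstanceNumber only for simplicity; a real volume conversion would typically order them by their spatial position.

```python
from collections import defaultdict
from typing import Dict, List

# Hypothetical stand-ins for parsed DICOM headers (a real loader reads these from Part 10 files).
INSTANCES = [
    {"StudyInstanceUID": "1.2.3", "SeriesInstanceUID": "1.2.3.1", "Modality": "CT", "InstanceNumber": 2},
    {"StudyInstanceUID": "1.2.3", "SeriesInstanceUID": "1.2.3.1", "Modality": "CT", "InstanceNumber": 1},
    {"StudyInstanceUID": "1.2.3", "SeriesInstanceUID": "1.2.3.2", "Modality": "SR", "InstanceNumber": 1},
]


def group_by_study_and_series(instances: List[dict]) -> Dict[str, Dict[str, List[dict]]]:
    """Returns {study_uid: {series_uid: [instances sorted by InstanceNumber]}}."""
    studies = defaultdict(lambda: defaultdict(list))
    for inst in instances:
        studies[inst["StudyInstanceUID"]][inst["SeriesInstanceUID"]].append(inst)
    # Convert to plain dicts and sort each series (simplified slice ordering).
    return {
        study_uid: {
            series_uid: sorted(items, key=lambda i: i["InstanceNumber"])
            for series_uid, items in series_map.items()
        }
        for study_uid, series_map in studies.items()
    }


for study_uid, series_map in group_by_study_and_series(INSTANCES).items():
    for series_uid, items in series_map.items():
        print(study_uid, series_uid, items[0]["Modality"], len(items))
```

Series selection then only has to look at series-level attributes such as Modality, rather than at individual files.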

In the following sections, we will demonstrate how to create a MONAI Deploy application package using the MONAI Deploy App SDK.

Note

For local testing, if DICOM Part 10 files are not available, open-source programs such as 3D Slicer can be used to convert NIfTI files to DICOM files.

Creating Operators and connecting them in Application class

We will implement an application that consists of five Operators:

  • DICOMDataLoaderOperator:

    • Input(dicom_files): a folder path (Path)

    • Output(dicom_study_list): a list of DICOM studies in memory (List[DICOMStudy])

  • DICOMSeriesSelectorOperator:

    • Input(dicom_study_list): a list of DICOM studies in memory (List[DICOMStudy])

    • Input(selection_rules): a selection rule (Dict)

    • Output(study_selected_series_list): a list of studies, each with its selected DICOM series, in memory (List[StudySelectedSeries])

  • DICOMSeriesToVolumeOperator:

    • Input(study_selected_series_list): a list of studies, each with its selected DICOM series, in memory (List[StudySelectedSeries])

    • Output(image): an image object in memory (Image)

  • SpleenSegOperator:

    • Input(image): an image object in memory (Image)

    • Output(seg_image): an image object in memory (Image)

  • DICOMSegmentationWriterOperator:

    • Input(seg_image): a segmentation image object in memory (Image)

    • Input(study_selected_series_list): a list of studies, each with its selected DICOM series, in memory (List[StudySelectedSeries])

    • Output(dicom_seg_instance): a file path (Path)

Note

The DICOMSegmentationWriterOperator needs both the segmentation image and the original DICOM series metadata in order to use the patient demographics and the DICOM Study-level attributes.
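Before wiring operators, it can help to check that every connection pairs an existing output port with an existing input port. The following is a hypothetical, plain-Python model of the five operators listed above (port names copied from the list); it is not an App SDK API. It also reports input ports that no connection feeds, which here are dicom_files and selection_rules; in the application code later in this tutorial, those are supplied through constructor arguments (input_folder and rules) instead.

```python
from typing import Dict, List, Set, Tuple

# Port names taken from the operator list above (hypothetical data model, not an SDK API).
OPERATORS: Dict[str, Dict[str, Set[str]]] = {
    "DICOMDataLoaderOperator": {"in": {"dicom_files"}, "out": {"dicom_study_list"}},
    "DICOMSeriesSelectorOperator": {"in": {"dicom_study_list", "selection_rules"}, "out": {"study_selected_series_list"}},
    "DICOMSeriesToVolumeOperator": {"in": {"study_selected_series_list"}, "out": {"image"}},
    "SpleenSegOperator": {"in": {"image"}, "out": {"seg_image"}},
    "DICOMSegmentationWriterOperator": {"in": {"seg_image", "study_selected_series_list"}, "out": {"dicom_seg_instance"}},
}

# Each flow: (source operator, destination operator, (source output port, destination input port)).
FLOWS: List[Tuple[str, str, Tuple[str, str]]] = [
    ("DICOMDataLoaderOperator", "DICOMSeriesSelectorOperator", ("dicom_study_list", "dicom_study_list")),
    ("DICOMSeriesSelectorOperator", "DICOMSeriesToVolumeOperator", ("study_selected_series_list", "study_selected_series_list")),
    ("DICOMSeriesToVolumeOperator", "SpleenSegOperator", ("image", "image")),
    ("DICOMSeriesSelectorOperator", "DICOMSegmentationWriterOperator", ("study_selected_series_list", "study_selected_series_list")),
    ("SpleenSegOperator", "DICOMSegmentationWriterOperator", ("seg_image", "seg_image")),
]


def check_flows(operators, flows) -> List[str]:
    """Returns error messages for flows that reference a missing port; empty means consistent."""
    errors = []
    for src, dst, (out_port, in_port) in flows:
        if out_port not in operators[src]["out"]:
            errors.append(f"{src} has no output port {out_port!r}")
        if in_port not in operators[dst]["in"]:
            errors.append(f"{dst} has no input port {in_port!r}")
    return errors


def unconnected_inputs(operators, flows) -> Set[Tuple[str, str]]:
    """Returns (operator, input port) pairs that no flow feeds."""
    fed = {(dst, in_port) for _, dst, (_, in_port) in flows}
    return {(op, port) for op, ports in operators.items() for port in ports["in"]} - fed


print(check_flows(OPERATORS, FLOWS))
print(sorted(unconnected_inputs(OPERATORS, FLOWS)))
```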

The workflow of the application would look like this.

%%{init: {"theme": "base", "themeVariables": { "fontSize": "16px"}} }%%
classDiagram
    direction TB

    DICOMDataLoaderOperator --|> DICOMSeriesSelectorOperator : dicom_study_list...dicom_study_list
    DICOMSeriesSelectorOperator --|> DICOMSeriesToVolumeOperator : study_selected_series_list...study_selected_series_list
    DICOMSeriesToVolumeOperator --|> SpleenSegOperator : image...image
    DICOMSeriesSelectorOperator --|> DICOMSegmentationWriterOperator : study_selected_series_list...study_selected_series_list
    SpleenSegOperator --|> DICOMSegmentationWriterOperator : seg_image...seg_image

    class DICOMDataLoaderOperator {
        <in>dicom_files : DISK
        dicom_study_list(out) IN_MEMORY
    }
    class DICOMSeriesSelectorOperator {
        <in>dicom_study_list : IN_MEMORY
        <in>selection_rules : IN_MEMORY
        study_selected_series_list(out) IN_MEMORY
    }
    class DICOMSeriesToVolumeOperator {
        <in>study_selected_series_list : IN_MEMORY
        image(out) IN_MEMORY
    }
    class SpleenSegOperator {
        <in>image : IN_MEMORY
        seg_image(out) IN_MEMORY
    }
    class DICOMSegmentationWriterOperator {
        <in>seg_image : IN_MEMORY
        <in>study_selected_series_list : IN_MEMORY
        dicom_seg_instance(out) DISK
    }
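The workflow is a directed acyclic graph, so a valid execution order can be derived with a topological sort. The sketch below applies Python's standard graphlib module (Python 3.9+) to the edges of the diagram; it only illustrates the dependency ordering and is not how the SDK's scheduler works internally. Because DICOMSegmentationWriterOperator depends on both DICOMSeriesSelectorOperator and SpleenSegOperator, it can only run after both have produced their outputs.

```python
from graphlib import TopologicalSorter

# Maps each operator to the set of upstream operators it receives data from (edges in the diagram).
DEPENDENCIES = {
    "DICOMSeriesSelectorOperator": {"DICOMDataLoaderOperator"},
    "DICOMSeriesToVolumeOperator": {"DICOMSeriesSelectorOperator"},
    "SpleenSegOperator": {"DICOMSeriesToVolumeOperator"},
    "DICOMSegmentationWriterOperator": {"DICOMSeriesSelectorOperator", "SpleenSegOperator"},
}

# static_order() yields every node after all of its predecessors.
order = list(TopologicalSorter(DEPENDENCIES).static_order())
for step, op in enumerate(order, start=1):
    print(step, op)
```

For this graph the order is unique, since each operator depends on the one before it in the chain.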

Setup environment

# Install MONAI and other necessary image processing packages for the application
!python -c "import monai" || pip install --upgrade -q "monai"
!python -c "import torch" || pip install -q "torch>=1.10.2"
!python -c "import numpy" || pip install -q "numpy>=1.21"
!python -c "import nibabel" || pip install -q "nibabel>=3.2.1"
!python -c "import pydicom" || pip install -q "pydicom>=1.4.2"
!python -c "import highdicom" || pip install -q "highdicom>=0.18.2"
!python -c "import SimpleITK" || pip install -q "SimpleITK>=2.0.0"

# Install MONAI Deploy App SDK package
!python -c "import monai.deploy" || pip install --upgrade "monai-deploy-app-sdk"

Note: you may need to restart the Jupyter kernel to use the updated packages.
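To confirm which versions ended up installed, for example after restarting the kernel, the standard-library importlib.metadata module can report them. This is a small convenience sketch; installed_versions is a hypothetical helper, and the distribution names are the ones used in the pip commands above.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names as used in the pip install commands above.
PACKAGES = ["monai", "torch", "numpy", "nibabel", "pydicom", "highdicom", "SimpleITK", "monai-deploy-app-sdk"]


def installed_versions(names):
    """Maps each distribution name to its installed version string, or None if it is not installed."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:22} {ver or 'NOT INSTALLED'}")
```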

Download/Extract ai_spleen_seg_bundle_data from Google Drive

# Download ai_spleen_seg_bundle_data test data zip file
!pip install gdown
!gdown "https://drive.google.com/uc?id=1Uds8mEvdGNYUuvFpTtCQ8gNU97bAPCaQ"

# After downloading the ai_spleen_seg_bundle_data zip file from the web browser or using gdown, unzip it
!unzip -o "ai_spleen_seg_bundle_data.zip"

# Need to copy the model.ts file to its own clean subfolder for packaging, to work around an issue in the Packager
models_folder = "models"
!rm -rf {models_folder} && mkdir -p {models_folder}/model && cp model.ts {models_folder}/model && ls {models_folder}/model
Downloading...
From (original): https://drive.google.com/uc?id=1Uds8mEvdGNYUuvFpTtCQ8gNU97bAPCaQ
From (redirected): https://drive.google.com/uc?id=1Uds8mEvdGNYUuvFpTtCQ8gNU97bAPCaQ&confirm=t&uuid=03efbee4-6b67-4413-8b8e-522d9c7cc472
To: /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/ai_spleen_seg_bundle_data.zip
100%|███████████████████████████████████████| 79.4M/79.4M [00:00<00:00, 101MB/s]
Archive:  ai_spleen_seg_bundle_data.zip
  inflating: dcm/1-001.dcm
  inflating: dcm/1-002.dcm
  ... (dcm/1-003.dcm through dcm/1-203.dcm)
  inflating: dcm/1-204.dcm
  inflating: model.ts                
model.ts
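If shell escapes (!) are unavailable, e.g. when running these steps as a plain Python script, the same model-staging step can be written with pathlib and shutil. This is a sketch; stage_model is a hypothetical helper name, and the folder layout (models/model/model.ts) mirrors the shell command above.

```python
import shutil
from pathlib import Path


def stage_model(model_file: Path, models_folder: Path) -> Path:
    """Recreates models_folder/model, copies the TorchScript file into it, and returns the copy's path."""
    shutil.rmtree(models_folder, ignore_errors=True)  # equivalent of `rm -rf {models_folder}`
    target_dir = models_folder / "model"
    target_dir.mkdir(parents=True)  # equivalent of `mkdir -p {models_folder}/model`
    return Path(shutil.copy2(model_file, target_dir))  # equivalent of `cp model.ts {models_folder}/model`


# Usage in this notebook's working directory:
# stage_model(Path("model.ts"), Path("models"))
```

Removing the folder first keeps it "clean", so no stale files from an earlier run end up in the package.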
%env HOLOSCAN_INPUT_PATH dcm
%env HOLOSCAN_MODEL_PATH {models_folder}
%env HOLOSCAN_OUTPUT_PATH output
env: HOLOSCAN_INPUT_PATH=dcm
env: HOLOSCAN_MODEL_PATH=models
env: HOLOSCAN_OUTPUT_PATH=output
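The %env magic only works in IPython/Jupyter. When running the same steps as a plain Python script, the environment variables can be set with os.environ before the application is created; judging by the AppContext line in the run log further below, these are the values the app context picks up.

```python
import os

# Same values as the %env magics above; set them before the app context is initialized.
os.environ["HOLOSCAN_INPUT_PATH"] = "dcm"
os.environ["HOLOSCAN_MODEL_PATH"] = "models"  # the value of models_folder
os.environ["HOLOSCAN_OUTPUT_PATH"] = "output"

for name in ("HOLOSCAN_INPUT_PATH", "HOLOSCAN_MODEL_PATH", "HOLOSCAN_OUTPUT_PATH"):
    print(f"{name}={os.environ[name]}")
```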

Setup imports

Let’s import the necessary classes to define the Application and Operators.

import logging
from numpy import uint8  # Needed if SaveImaged is enabled
from pathlib import Path

# Required for setting SegmentDescription attributes. Direct import as this is not part of App SDK package.
from pydicom.sr.codedict import codes

from monai.deploy.conditions import CountCondition
from monai.deploy.core import AppContext, Application, ConditionType, Fragment, Operator, OperatorSpec
from monai.deploy.core.domain import Image
from monai.deploy.core.io_type import IOType
from monai.deploy.operators.dicom_data_loader_operator import DICOMDataLoaderOperator
from monai.deploy.operators.dicom_seg_writer_operator import DICOMSegmentationWriterOperator, SegmentDescription
from monai.deploy.operators.dicom_series_selector_operator import DICOMSeriesSelectorOperator
from monai.deploy.operators.dicom_series_to_volume_operator import DICOMSeriesToVolumeOperator
from monai.deploy.operators.monai_seg_inference_operator import InfererType, InMemImageReader, MonaiSegInferenceOperator

from monai.transforms import (
    Activationsd,
    AsDiscreted,
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    Invertd,
    LoadImaged,
    Orientationd,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
)

Creating Model Specific Inference Operator classes

Each Operator class inherits from the base Operator class. The input/output properties are specified by implementing the setup() method, and the business logic is implemented in the compute() method.

The App SDK provides a MonaiSegInferenceOperator class to perform segmentation prediction with a TorchScript model. For consistency, this class uses MONAI dictionary-based transforms, composed as a Compose object, for the pre- and post-transforms. The model-specific inference operator then only needs to create the pre- and post-transform Compose objects based on what was used for the model during training and validation. Note that PyTorch Ignite is neither needed nor supported in a deploy application.
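In the dictionary-based pattern, each transform receives a dict and only touches the entries named by its keys argument, which is how one pipeline can carry both the image and its metadata. The toy sketch below mimics that pattern in plain Python to make the data flow concrete; scale_d and compose are hypothetical names, not MONAI APIs (MONAI's real equivalents are classes such as ScaleIntensityRanged and Compose).

```python
from typing import Any, Callable, Dict, List, Sequence

Data = Dict[str, Any]
Transform = Callable[[Data], Data]


def scale_d(keys: Sequence[str], factor: float) -> Transform:
    """Toy dictionary transform: multiplies the lists stored under `keys`; other entries pass through."""
    def _apply(data: Data) -> Data:
        out = dict(data)  # shallow copy so the caller's dict is not modified
        for key in keys:
            out[key] = [v * factor for v in out[key]]
        return out
    return _apply


def compose(transforms: List[Transform]) -> Transform:
    """Chains transforms so that each one receives the previous one's output dict."""
    def _apply(data: Data) -> Data:
        for transform in transforms:
            data = transform(data)
        return data
    return _apply


pipeline = compose([scale_d(["image"], 2), scale_d(["image"], 10)])
sample = {"image": [1, 2, 3], "image_meta": {"spacing": (0.8, 0.8, 3.0)}}
print(pipeline(sample))
```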

SpleenSegOperator

The SpleenSegOperator takes as input an in-memory Image object that has been converted from a DICOM CT series by the preceding DICOMSeriesToVolumeOperator, and outputs an in-memory segmentation Image object.

The pre_process function creates the pre-transforms Compose object. For LoadImaged, a specialized InMemImageReader, derived from MONAI ImageReader, is used to convert the in-memory pixel data and return the NumPy array as well as the metadata. Also, the DICOM input pixel spacings are often not the same as expected by the model, so the Spacingd transform must be used to re-sample the image to the expected spacing.
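To get a feel for what Spacingd does to the array size: because the physical extent is preserved, the grid size after resampling is roughly the original size times the original spacing divided by the target spacing. The helper below is a hypothetical estimate (MONAI computes the exact output shape from the image affine, so results may differ by a voxel); it uses the pixdim [1.5, 1.5, 2.9] from pre_process and made-up input spacing values.

```python
from typing import Sequence, Tuple


def resampled_shape(
    shape: Sequence[int],
    spacing: Sequence[float],
    target_spacing: Sequence[float] = (1.5, 1.5, 2.9),  # pixdim used by Spacingd in pre_process
) -> Tuple[int, ...]:
    """Estimates the voxel grid size after resampling, keeping the physical extent (size x spacing) fixed."""
    return tuple(max(1, round(n * s / t)) for n, s, t in zip(shape, spacing, target_spacing))


# Hypothetical CT series: 512 x 512 pixels at 0.8 mm, 100 slices at 3.0 mm.
print(resampled_shape((512, 512, 100), (0.8, 0.8, 3.0)))
```

In-plane the volume shrinks considerably, which also reduces the work done by the sliding-window inference.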

The post_process function creates the post-transforms Compose object. The SaveImaged transform is used to save the segmentation mask as a NIfTI image file; this is optional, as the in-memory mask image will be passed down to the DICOM Segmentation writer for creating a DICOM Segmentation instance. The Invertd transform must also be used to revert the segmentation image’s orientation and spacing to match the input.
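The order of the post-transforms matters: the softmax probabilities are inverted back to the input grid with interpolation (nearest_interp=False), and only then turned into labels by argmax. Softmax alone never changes which channel is largest; it matters here because interpolating probabilities is better behaved than interpolating raw logits or integer labels. The NumPy sketch below mimics the Activationsd(softmax=True) and AsDiscreted(argmax=True) steps on a channel-first array; it is an illustration, not MONAI's code, and it assumes a two-channel output (background, spleen). It keeps a singleton channel dimension, which we understand to be AsDiscreted's default (keepdim=True).

```python
import numpy as np


def softmax_then_argmax(logits: np.ndarray) -> np.ndarray:
    """Channel-first logits (C, X, Y, Z) -> uint8 label map (1, X, Y, Z)."""
    shifted = logits - logits.max(axis=0, keepdims=True)  # subtract the max for numerical stability
    probs = np.exp(shifted)
    probs /= probs.sum(axis=0, keepdims=True)  # softmax across the channel axis
    # argmax over channels; with the assumed (background, spleen) channels, label 1 is spleen
    return probs.argmax(axis=0)[np.newaxis].astype(np.uint8)


# Hypothetical 2-channel logits for a 1 x 1 x 2 volume.
logits = np.array([[[[2.0, -1.0]]], [[[0.5, 3.0]]]])
labels = softmax_then_argmax(logits)
print(labels.shape, labels.ravel().tolist())
```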

When the MonaiSegInferenceOperator object is created, the ROI size, the sliding-window settings, and the transform Compose objects are specified. Furthermore, the dataset key names for the input image and the prediction are set accordingly.
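The roi_size, overlap, and sw_batch_size arguments determine how much work sliding-window inference does. The estimate below assumes, as we understand MONAI's sliding_window_inference, that windows step by int(roi * (1 - overlap)) voxels per axis and that axes smaller than the ROI are padded up to it; sliding_window_cost is a hypothetical helper and the input shape is made up.

```python
import math
from typing import Sequence, Tuple


def windows_per_dim(size: int, roi: int, overlap: float) -> int:
    """Number of window positions along one axis (assumed MONAI-style scan interval)."""
    if size <= roi:
        return 1  # the axis is padded up to the ROI size, so one window covers it
    interval = max(int(roi * (1 - overlap)), 1)
    # Smallest d with d * interval + roi >= size, plus the window at position 0.
    return math.ceil((size - roi) / interval) + 1


def sliding_window_cost(
    shape: Sequence[int], roi: Sequence[int] = (96, 96, 96), overlap: float = 0.6, sw_batch_size: int = 4
) -> Tuple[int, int]:
    """Returns (number of windows, number of forward passes) for the settings used in SpleenSegOperator."""
    n_windows = math.prod(windows_per_dim(s, r, overlap) for s, r in zip(shape, roi))
    return n_windows, math.ceil(n_windows / sw_batch_size)


# Hypothetical preprocessed volume of 240 x 240 x 176 voxels.
print(sliding_window_cost((240, 240, 176)))
```

With overlap=0.6 the step is only 38 voxels, so raising the overlap increases the window count quickly; a larger sw_batch_size reduces the number of forward passes at the cost of GPU memory.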

Loading of the model and performing the prediction are encapsulated in the MonaiSegInferenceOperator and other SDK classes. Once the inference is completed, the segmentation Image object is created and set to the output by the SpleenSegOperator.

class SpleenSegOperator(Operator):
    """Performs Spleen segmentation with a 3D image converted from a DICOM CT series.
    """

    DEFAULT_OUTPUT_FOLDER = Path.cwd() / "output/saved_images_folder"

    def __init__(
        self,
        fragment: Fragment,
        *args,
        app_context: AppContext,
        model_path: Path,
        output_folder: Path = DEFAULT_OUTPUT_FOLDER,
        **kwargs,
    ):

        self.logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
        self._input_dataset_key = "image"
        self._pred_dataset_key = "pred"

        self.model_path = model_path
        self.output_folder = output_folder
        self.output_folder.mkdir(parents=True, exist_ok=True)
        self.app_context = app_context
        self.input_name_image = "image"
        self.output_name_seg = "seg_image"
        self.output_name_saved_images_folder = "saved_images_folder"

        # The base class has an attribute called fragment to hold the reference to the fragment object
        super().__init__(fragment, *args, **kwargs)

    def setup(self, spec: OperatorSpec):
        spec.input(self.input_name_image)
        spec.output(self.output_name_seg)
        spec.output(self.output_name_saved_images_folder).condition(
            ConditionType.NONE
        )  # Output not requiring a receiver

    def compute(self, op_input, op_output, context):
        input_image = op_input.receive(self.input_name_image)
        if not input_image:
            raise ValueError("Input image is not found.")

        # This operator gets an in-memory Image object, so a specialized ImageReader is needed.
        _reader = InMemImageReader(input_image)

        pre_transforms = self.pre_process(_reader, str(self.output_folder))
        post_transforms = self.post_process(pre_transforms, str(self.output_folder))

        # Delegates inference and saving output to the built-in operator.
        infer_operator = MonaiSegInferenceOperator(
            self.fragment,
            roi_size=(
                96,
                96,
                96,
            ),
            pre_transforms=pre_transforms,
            post_transforms=post_transforms,
            overlap=0.6,
            app_context=self.app_context,
            model_name="",
            inferer=InfererType.SLIDING_WINDOW,
            sw_batch_size=4,
            model_path=self.model_path,
            name="monai_seg_inference_op",
        )

        # Set the keys used in the dictionary-based transforms, as they may differ from the defaults.
        infer_operator.input_dataset_key = self._input_dataset_key
        infer_operator.pred_dataset_key = self._pred_dataset_key

        # Now emit data to the output ports of this operator
        op_output.emit(infer_operator.compute_impl(input_image, context), self.output_name_seg)
        op_output.emit(self.output_folder, self.output_name_saved_images_folder)

    def pre_process(self, img_reader, out_dir: str = "./input_images") -> Compose:
        """Composes transforms for preprocessing input before predicting on a model."""

        Path(out_dir).mkdir(parents=True, exist_ok=True)
        my_key = self._input_dataset_key

        return Compose(
            [
                LoadImaged(keys=my_key, reader=img_reader),
                EnsureChannelFirstd(keys=my_key),
                # The SaveImaged transform can be commented out to save 5 seconds.
                # Uncompressed NIfTI (.nii) is used, favoring speed over size; it can be changed to .nii.gz.
                SaveImaged(
                    keys=my_key,
                    output_dir=out_dir,
                    output_postfix="",
                    resample=False,
                    output_ext=".nii",
                ),
                Orientationd(keys=my_key, axcodes="RAS"),
                Spacingd(keys=my_key, pixdim=[1.5, 1.5, 2.9], mode=["bilinear"]),
                ScaleIntensityRanged(keys=my_key, a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
                EnsureTyped(keys=my_key),
            ]
        )

    def post_process(self, pre_transforms: Compose, out_dir: str = "./prediction_output") -> Compose:
        """Composes transforms for postprocessing the prediction results."""

        Path(out_dir).mkdir(parents=True, exist_ok=True)
        pred_key = self._pred_dataset_key

        return Compose(
            [
                Activationsd(keys=pred_key, softmax=True),
                Invertd(
                    keys=pred_key,
                    transform=pre_transforms,
                    orig_keys=self._input_dataset_key,
                    nearest_interp=False,
                    to_tensor=True,
                ),
                AsDiscreted(keys=pred_key, argmax=True),
                # The SaveImaged transform can be commented out to save 5 seconds.
                # Uncompressed NIfTI (.nii) is used, favoring speed over size; it can be changed to .nii.gz.
                SaveImaged(
                    keys=pred_key,
                    output_dir=out_dir,
                    output_postfix="seg",
                    output_dtype=uint8,
                    resample=False,
                    output_ext=".nii",
                ),
            ]
        )

Creating Application class

Our application class looks like the following.

It defines the AISpleenSegApp class, inheriting from the base Application class.

The base class method, compose, is overridden. Objects required for DICOM parsing, series selection, pixel data conversion to a volume image, and segmentation instance creation are created, as is the model-specific SpleenSegOperator. The execution pipeline, as a Directed Acyclic Graph (DAG), is created by connecting these objects through the add_flow method.

class AISpleenSegApp(Application):
    def __init__(self, *args, **kwargs):
        """Creates an application instance."""

        super().__init__(*args, **kwargs)
        self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))

    def run(self, *args, **kwargs):
        # This method calls the base class to run. Can be omitted if simply calling through.
        self._logger.info(f"Begin {self.run.__name__}")
        super().run(*args, **kwargs)
        self._logger.info(f"End {self.run.__name__}")

    def compose(self):
        """Creates the app-specific operators and chains them up in the processing DAG."""

        self._logger.debug(f"Begin {self.compose.__name__}")
        app_context = Application.init_app_context({})  # Do not pass argv in Jupyter Notebook
        app_input_path = Path(app_context.input_path)
        app_output_path = Path(app_context.output_path)
        model_path = Path(app_context.model_path)

        self._logger.info(f"App input and output path: {app_input_path}, {app_output_path}")

        # instantiates the SDK built-in operator(s).
        study_loader_op = DICOMDataLoaderOperator(
            self, CountCondition(self, 1), input_folder=app_input_path, name="dcm_loader_op"
        )
        series_selector_op = DICOMSeriesSelectorOperator(self, rules=Sample_Rules_Text, name="series_selector_op")
        series_to_vol_op = DICOMSeriesToVolumeOperator(self, name="series_to_vol_op")

        # Model specific inference operator, supporting MONAI transforms.
        spleen_seg_op = SpleenSegOperator(self, app_context=app_context, model_path=model_path, name="seg_op")

        # Create DICOM Seg writer providing the required segment description for each segment with
        # the actual algorithm and the pertinent organ/tissue.
        # The segment_label, algorithm_name, and algorithm_version are limited to 64 chars.
        # https://dicom.nema.org/medical/dicom/current/output/chtml/part05/sect_6.2.html
        # Users can look up SNOMED CT codes at, e.g.
        # https://bioportal.bioontology.org/ontologies/SNOMEDCT

        _algorithm_name = "3D segmentation of the Spleen from a CT series"
        _algorithm_family = codes.DCM.ArtificialIntelligence
        _algorithm_version = "0.1.0"

        segment_descriptions = [
            SegmentDescription(
                segment_label="Spleen",
                segmented_property_category=codes.SCT.Organ,
                segmented_property_type=codes.SCT.Spleen,
                algorithm_name=_algorithm_name,
                algorithm_family=_algorithm_family,
                algorithm_version=_algorithm_version,
            ),
        ]

        custom_tags = {"SeriesDescription": "AI generated Seg, not for clinical use."}

        dicom_seg_writer = DICOMSegmentationWriterOperator(
            self,
            segment_descriptions=segment_descriptions,
            custom_tags=custom_tags,
            output_folder=app_output_path,
            name="dcm_seg_writer_op",
        )

        # Create the processing pipeline, by specifying the source and destination operators, and
        # ensuring the output from the former matches the input of the latter, in both name and type.
        self.add_flow(study_loader_op, series_selector_op, {("dicom_study_list", "dicom_study_list")})
        self.add_flow(
            series_selector_op, series_to_vol_op, {("study_selected_series_list", "study_selected_series_list")}
        )
        self.add_flow(series_to_vol_op, spleen_seg_op, {("image", "image")})

        # Note below the dicom_seg_writer requires two inputs, each coming from a source operator.
        self.add_flow(
            series_selector_op, dicom_seg_writer, {("study_selected_series_list", "study_selected_series_list")}
        )
        self.add_flow(spleen_seg_op, dicom_seg_writer, {("seg_image", "seg_image")})

        self._logger.debug(f"End {self.compose.__name__}")


# This is a sample series selection rule in JSON, simply selecting CT series.
# If the study has more than one CT series, all of them will be selected.
# Please see more details in DICOMSeriesSelectorOperator.
# For a list of string values, e.g. "ImageType": ["PRIMARY", "ORIGINAL"], it is a match if all elements
# are in the multi-value attribute of the DICOM series.

Sample_Rules_Text = """
{
    "selections": [
        {
            "name": "CT Series",
            "conditions": {
                "StudyDescription": "(.*?)",
                "Modality": "(?i)CT",
                "SeriesDescription": "(.*?)",
                "ImageType": ["PRIMARY", "ORIGINAL"]
            }
        }
    ]
}
"""
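To make the rule semantics concrete, the sketch below approximates how a series could be matched against the conditions in Sample_Rules_Text: a string condition is compared for equality first and then tried as a regular expression (the run log prints "did not match. Try regEx." in exactly that order), and a list condition requires every element to be present in the multi-valued attribute. This is a hypothetical re-implementation based on the comments above and the run log below, not the DICOMSeriesSelectorOperator source. In particular, the log shows a series whose ImageType is None still being selected, so this sketch does not reject missing attributes; the SDK's exact handling (and whether it uses search or match for regular expressions) may differ.

```python
import json
import re
from typing import Any, Dict

# Condition block copied from Sample_Rules_Text above.
RULES_JSON = """
{"selections": [{"name": "CT Series", "conditions": {
    "StudyDescription": "(.*?)", "Modality": "(?i)CT",
    "SeriesDescription": "(.*?)", "ImageType": ["PRIMARY", "ORIGINAL"]}}]}
"""


def attribute_matches(expected: Any, actual: Any) -> bool:
    """Approximate check of one condition against one series attribute value."""
    if actual is None:
        return True  # mimics the run log, where a missing ImageType did not block selection
    if isinstance(expected, list):
        values = actual if isinstance(actual, (list, tuple)) else [actual]
        return all(item in values for item in expected)
    text = str(actual)
    return text == expected or re.search(expected, text) is not None


def series_matches(conditions: Dict[str, Any], series_attrs: Dict[str, Any]) -> bool:
    return all(attribute_matches(value, series_attrs.get(key)) for key, value in conditions.items())


conditions = json.loads(RULES_JSON)["selections"][0]["conditions"]
# Attribute values as printed in the run log for the tutorial's CT series.
ct_series = {
    "StudyDescription": "CT ABDOMEN W IV CONTRAST",
    "Modality": "CT",
    "SeriesDescription": "ABD/PANC 3.0 B31f",
    "ImageType": None,
}
print(series_matches(conditions, ct_series))
```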

Executing app locally

We can execute the app in the Jupyter notebook. Note that the DICOM files of the CT Abdomen series must be present in the dcm folder, and the TorchScript model, model.ts, in the models folder, as pointed to by the HOLOSCAN_INPUT_PATH and HOLOSCAN_MODEL_PATH environment variables respectively.
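Before running, a quick check that the inputs are where the app expects them can save a confusing failure. check_inputs below is a hypothetical helper; it assumes the DICOM files carry a .dcm extension, as they do in this dataset, and that the TorchScript file ends in .ts somewhere under the model folder (here models/model/model.ts).

```python
import os
from pathlib import Path
from typing import List


def check_inputs(input_path: str, model_path: str) -> List[str]:
    """Returns a list of problems found; an empty list means the inputs look ready."""
    problems = []
    in_dir, model_dir = Path(input_path), Path(model_path)
    if not (in_dir.is_dir() and any(in_dir.glob("*.dcm"))):
        problems.append(f"no .dcm files found in {in_dir}")
    if not (model_dir.is_dir() and any(model_dir.rglob("*.ts"))):
        problems.append(f"no TorchScript (.ts) file found under {model_dir}")
    return problems


print(check_inputs(os.environ.get("HOLOSCAN_INPUT_PATH", "dcm"), os.environ.get("HOLOSCAN_MODEL_PATH", "models")))
```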

!rm -rf $HOLOSCAN_OUTPUT_PATH
app = AISpleenSegApp()
app.run()
[2024-04-10 16:41:09,106] [INFO] (root) - Parsed args: Namespace(log_level=None, input=None, output=None, model=None, workdir=None, argv=[])
[2024-04-10 16:41:09,114] [INFO] (root) - AppContext object: AppContext(input_path=dcm, output_path=output, model_path=models, workdir=)
[2024-04-10 16:41:09,116] [INFO] (__main__.AISpleenSegApp) - App input and output path: dcm, output
[info] [gxf_executor.cpp:211] Creating context
[info] [gxf_executor.cpp:1674] Loading extensions from configs...
[info] [gxf_executor.cpp:1864] Activating Graph...
[info] [gxf_executor.cpp:1894] Running Graph...
[info] [gxf_executor.cpp:1896] Waiting for completion...
[info] [gxf_executor.cpp:1897] Graph execution waiting. Fragment: 
[info] [greedy_scheduler.cpp:190] Scheduling 6 entities
[2024-04-10 16:41:09,164] [INFO] (monai.deploy.operators.dicom_data_loader_operator.DICOMDataLoaderOperator) - No or invalid input path from the optional input port: None
[2024-04-10 16:41:09,742] [INFO] (root) - Finding series for Selection named: CT Series
[2024-04-10 16:41:09,743] [INFO] (root) - Searching study, : 1.3.6.1.4.1.14519.5.2.1.7085.2626.822645453932810382886582736291
  # of series: 1
[2024-04-10 16:41:09,744] [INFO] (root) - Working on series, instance UID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239
[2024-04-10 16:41:09,745] [INFO] (root) - On attribute: 'StudyDescription' to match value: '(.*?)'
[2024-04-10 16:41:09,746] [INFO] (root) -     Series attribute StudyDescription value: CT ABDOMEN W IV CONTRAST
[2024-04-10 16:41:09,747] [INFO] (root) - Series attribute string value did not match. Try regEx.
[2024-04-10 16:41:09,748] [INFO] (root) - On attribute: 'Modality' to match value: '(?i)CT'
[2024-04-10 16:41:09,749] [INFO] (root) -     Series attribute Modality value: CT
[2024-04-10 16:41:09,750] [INFO] (root) - Series attribute string value did not match. Try regEx.
[2024-04-10 16:41:09,751] [INFO] (root) - On attribute: 'SeriesDescription' to match value: '(.*?)'
[2024-04-10 16:41:09,753] [INFO] (root) -     Series attribute SeriesDescription value: ABD/PANC 3.0 B31f
[2024-04-10 16:41:09,754] [INFO] (root) - Series attribute string value did not match. Try regEx.
[2024-04-10 16:41:09,755] [INFO] (root) - On attribute: 'ImageType' to match value: ['PRIMARY', 'ORIGINAL']
[2024-04-10 16:41:09,756] [INFO] (root) -     Series attribute ImageType value: None
[2024-04-10 16:41:09,757] [INFO] (root) - Selected Series, UID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239
[2024-04-10 16:41:10,007] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - Converted Image object metadata:
[2024-04-10 16:41:10,009] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - SeriesInstanceUID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239, type <class 'str'>
[2024-04-10 16:41:10,009] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - SeriesDate: 20090831, type <class 'str'>
[2024-04-10 16:41:10,010] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - SeriesTime: 101721.452, type <class 'str'>
[2024-04-10 16:41:10,011] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - Modality: CT, type <class 'str'>
[2024-04-10 16:41:10,012] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - SeriesDescription: ABD/PANC 3.0 B31f, type <class 'str'>
[2024-04-10 16:41:10,012] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - PatientPosition: HFS, type <class 'str'>
[2024-04-10 16:41:10,013] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - SeriesNumber: 8, type <class 'int'>
[2024-04-10 16:41:10,014] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - row_pixel_spacing: 0.7890625, type <class 'float'>
[2024-04-10 16:41:10,015] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - col_pixel_spacing: 0.7890625, type <class 'float'>
[2024-04-10 16:41:10,015] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - depth_pixel_spacing: 1.5, type <class 'float'>
[2024-04-10 16:41:10,016] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - row_direction_cosine: [1.0, 0.0, 0.0], type <class 'list'>
[2024-04-10 16:41:10,017] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - col_direction_cosine: [0.0, 1.0, 0.0], type <class 'list'>
[2024-04-10 16:41:10,018] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - depth_direction_cosine: [0.0, 0.0, 1.0], type <class 'list'>
[2024-04-10 16:41:10,019] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - dicom_affine_transform: [[   0.7890625    0.           0.        -197.60547  ]
 [   0.           0.7890625    0.        -398.60547  ]
 [   0.           0.           1.5       -383.       ]
 [   0.           0.           0.           1.       ]], type <class 'numpy.ndarray'>
[2024-04-10 16:41:10,020] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - nifti_affine_transform: [[  -0.7890625   -0.          -0.         197.60547  ]
 [  -0.          -0.7890625   -0.         398.60547  ]
 [   0.           0.           1.5       -383.       ]
 [   0.           0.           0.           1.       ]], type <class 'numpy.ndarray'>
[2024-04-10 16:41:10,021] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - StudyInstanceUID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.822645453932810382886582736291, type <class 'str'>
[2024-04-10 16:41:10,022] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - StudyID: , type <class 'str'>
[2024-04-10 16:41:10,023] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - StudyDate: 20090831, type <class 'str'>
[2024-04-10 16:41:10,025] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - StudyTime: 095948.599, type <class 'str'>
[2024-04-10 16:41:10,026] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - StudyDescription: CT ABDOMEN W IV CONTRAST, type <class 'str'>
[2024-04-10 16:41:10,027] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - AccessionNumber: 5471978513296937, type <class 'str'>
[2024-04-10 16:41:10,029] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - selection_name: CT Series, type <class 'str'>
2024-04-10 16:41:10,797 INFO image_writer.py:197 - writing: /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/output/saved_images_folder/1.3.6.1.4.1.14519.5.2.1.7085.2626/1.3.6.1.4.1.14519.5.2.1.7085.2626.nii
2024-04-10 16:41:17,104 INFO image_writer.py:197 - writing: /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/output/saved_images_folder/1.3.6.1.4.1.14519.5.2.1.7085.2626/1.3.6.1.4.1.14519.5.2.1.7085.2626_seg.nii
[2024-04-10 16:41:19,060] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - Output Seg image numpy array shaped: (204, 512, 512)
[2024-04-10 16:41:19,067] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - Output Seg image pixel max value: 1
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.10/site-packages/highdicom/valuerep.py:54: UserWarning: The string "C3N-00198" is unlikely to represent the intended person name since it contains only a single component. Construct a person name according to the format in described in https://dicom.nema.org/dicom/2013/output/chtml/part05/sect_6.2.html#sect_6.2.1.2, or, in pydicom 2.2.0 or later, use the pydicom.valuerep.PersonName.from_named_components() method to construct the person name correctly. If a single-component name is really intended, add a trailing caret character to disambiguate the name.
  warnings.warn(
[2024-04-10 16:41:20,496] [INFO] (highdicom.base) - copy Image-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"
[2024-04-10 16:41:20,497] [INFO] (highdicom.base) - copy attributes of module "Specimen"
[2024-04-10 16:41:20,498] [INFO] (highdicom.base) - copy Patient-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"
[2024-04-10 16:41:20,499] [INFO] (highdicom.base) - copy attributes of module "Patient"
[2024-04-10 16:41:20,500] [INFO] (highdicom.base) - copy attributes of module "Clinical Trial Subject"
[2024-04-10 16:41:20,501] [INFO] (highdicom.base) - copy Study-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"
[2024-04-10 16:41:20,502] [INFO] (highdicom.base) - copy attributes of module "General Study"
[2024-04-10 16:41:20,503] [INFO] (highdicom.base) - copy attributes of module "Patient Study"
[2024-04-10 16:41:20,503] [INFO] (highdicom.base) - copy attributes of module "Clinical Trial Study"
[info] [greedy_scheduler.cpp:369] Scheduler stopped: Some entities are waiting for execution, but there are no periodic or async entities to get out of the deadlock.
[info] [greedy_scheduler.cpp:398] Scheduler finished.
[info] [gxf_executor.cpp:1906] Graph execution deactivating. Fragment: 
[info] [gxf_executor.cpp:1907] Deactivating Graph...
[info] [gxf_executor.cpp:1910] Graph execution finished. Fragment: 
[2024-04-10 16:41:20,610] [INFO] (__main__.AISpleenSegApp) - End run
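The logged metadata contains both dicom_affine_transform and nifti_affine_transform. They differ only in the sign of their first two rows. This is the standard conversion from the DICOM patient coordinate system (LPS) to the RAS convention used by NIfTI, and also by the Orientationd(axcodes="RAS") transform in pre-processing. The following NumPy check reproduces the logged NIfTI matrix from the logged DICOM one, using only the values printed above:

```python
import numpy as np

# DICOM affine (LPS patient coordinates) as printed in the run log above.
dicom_affine = np.array([
    [0.7890625, 0.0, 0.0, -197.60547],
    [0.0, 0.7890625, 0.0, -398.60547],
    [0.0, 0.0, 1.5, -383.0],
    [0.0, 0.0, 0.0, 1.0],
])

# LPS -> RAS flips the sign of the first two world axes (x: L->R, y: P->A).
LPS_TO_RAS = np.diag([-1.0, -1.0, 1.0, 1.0])
nifti_affine = LPS_TO_RAS @ dicom_affine
print(nifti_affine)  # matches nifti_affine_transform in the log

# Voxel spacing is the length of each of the first three columns of the 3x3 block.
spacing = np.linalg.norm(nifti_affine[:3, :3], axis=0)
print(spacing)  # 0.7890625, 0.7890625 and 1.5, the logged row/col/depth spacings
```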

Once the application is verified inside the Jupyter notebook, we can write the above Python code into Python files in an application folder.

The application folder structure would look like this:

my_app
├── __main__.py
├── app.py
└── spleen_seg_operator.py

Note

We can create a single application Python file (such as spleen_app.py) that includes the content of these files, instead of creating multiple files. You will see such an example in the MedNist Classifier Tutorial.

# Create an application folder
!mkdir -p my_app && rm -rf my_app/*

spleen_seg_operator.py

%%writefile my_app/spleen_seg_operator.py
import logging

from numpy import uint8
from pathlib import Path

from monai.deploy.core import AppContext, ConditionType, Fragment, Operator, OperatorSpec
from monai.deploy.operators.monai_seg_inference_operator import InfererType, InMemImageReader, MonaiSegInferenceOperator
from monai.transforms import (
    Activationsd,
    AsDiscreted,
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    Invertd,
    LoadImaged,
    Orientationd,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
)

class SpleenSegOperator(Operator):
    """Performs Spleen segmentation with a 3D image converted from a DICOM CT series.
    """

    DEFAULT_OUTPUT_FOLDER = Path.cwd() / "output/saved_images_folder"

    def __init__(
        self,
        fragment: Fragment,
        *args,
        app_context: AppContext,
        model_path: Path,
        output_folder: Path = DEFAULT_OUTPUT_FOLDER,
        **kwargs,
    ):

        self.logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
        self._input_dataset_key = "image"
        self._pred_dataset_key = "pred"

        self.model_path = model_path
        self.output_folder = output_folder
        self.output_folder.mkdir(parents=True, exist_ok=True)
        self.app_context = app_context
        self.input_name_image = "image"
        self.output_name_seg = "seg_image"
        self.output_name_saved_images_folder = "saved_images_folder"

        # The base class has an attribute called fragment to hold the reference to the fragment object
        super().__init__(fragment, *args, **kwargs)

    def setup(self, spec: OperatorSpec):
        spec.input(self.input_name_image)
        spec.output(self.output_name_seg)
        spec.output(self.output_name_saved_images_folder).condition(
            ConditionType.NONE
        )  # Output not requiring a receiver

    def compute(self, op_input, op_output, context):
        input_image = op_input.receive(self.input_name_image)
        if not input_image:
            raise ValueError("Input image is not found.")

        # This operator gets an in-memory Image object, so a specialized ImageReader is needed.
        _reader = InMemImageReader(input_image)

        pre_transforms = self.pre_process(_reader, str(self.output_folder))
        post_transforms = self.post_process(pre_transforms, str(self.output_folder))

        # Delegates inference and saving output to the built-in operator.
        infer_operator = MonaiSegInferenceOperator(
            self.fragment,
            roi_size=(
                96,
                96,
                96,
            ),
            pre_transforms=pre_transforms,
            post_transforms=post_transforms,
            overlap=0.6,
            app_context=self.app_context,
            model_name="",
            inferer=InfererType.SLIDING_WINDOW,
            sw_batch_size=4,
            model_path=self.model_path,
            name="monai_seg_inference_op",
        )

        # Set the keys used by the dictionary-based transforms explicitly, since they may change.
        infer_operator.input_dataset_key = self._input_dataset_key
        infer_operator.pred_dataset_key = self._pred_dataset_key

        # Now emit data to the output ports of this operator
        op_output.emit(infer_operator.compute_impl(input_image, context), self.output_name_seg)
        op_output.emit(self.output_folder, self.output_name_saved_images_folder)

    def pre_process(self, img_reader, out_dir: str = "./input_images") -> Compose:
        """Composes transforms for preprocessing input before predicting on a model."""

        Path(out_dir).mkdir(parents=True, exist_ok=True)
        my_key = self._input_dataset_key

        return Compose(
            [
                LoadImaged(keys=my_key, reader=img_reader),
                EnsureChannelFirstd(keys=my_key),
                # The SaveImaged transform can be commented out to save 5 seconds.
                # An uncompressed NIfTI file (.nii) is used, favoring speed over size; it can be changed to .nii.gz.
                SaveImaged(
                    keys=my_key,
                    output_dir=out_dir,
                    output_postfix="",
                    resample=False,
                    output_ext=".nii",
                ),
                Orientationd(keys=my_key, axcodes="RAS"),
                Spacingd(keys=my_key, pixdim=[1.5, 1.5, 2.9], mode=["bilinear"]),
                ScaleIntensityRanged(keys=my_key, a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
                EnsureTyped(keys=my_key),
            ]
        )

    def post_process(self, pre_transforms: Compose, out_dir: str = "./prediction_output") -> Compose:
        """Composes transforms for postprocessing the prediction results."""

        Path(out_dir).mkdir(parents=True, exist_ok=True)
        pred_key = self._pred_dataset_key

        return Compose(
            [
                Activationsd(keys=pred_key, softmax=True),
                Invertd(
                    keys=pred_key,
                    transform=pre_transforms,
                    orig_keys=self._input_dataset_key,
                    nearest_interp=False,
                    to_tensor=True,
                ),
                AsDiscreted(keys=pred_key, argmax=True),
                # The SaveImaged transform can be commented out to save 5 seconds.
                # An uncompressed NIfTI file (.nii) is used, favoring speed over size; it can be changed to .nii.gz.
                SaveImaged(
                    keys=pred_key,
                    output_dir=out_dir,
                    output_postfix="seg",
                    output_dtype=uint8,
                    resample=False,
                    output_ext=".nii",
                ),
            ]
        )
Writing my_app/spleen_seg_operator.py
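In pre_process, ScaleIntensityRanged(a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True) maps CT intensities, assumed here to be in Hounsfield units, linearly from the window [-57, 164] onto [0, 1] and clips everything outside it. The following NumPy sketch restates that arithmetic. It is not MONAI's implementation, and the function name is hypothetical.

```python
import numpy as np


def scale_intensity_range(values, a_min=-57.0, a_max=164.0, b_min=0.0, b_max=1.0, clip=True):
    """Linearly map [a_min, a_max] onto [b_min, b_max], optionally clipping outside values."""
    x = np.asarray(values, dtype=np.float64)
    scaled = (x - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    return np.clip(scaled, b_min, b_max) if clip else scaled


# Sample values below, at, inside, and above the window (Hounsfield units).
hu = [-1000.0, -57.0, 53.5, 164.0, 400.0]
print(scale_intensity_range(hu))  # values 0.0, 0.0, 0.5, 1.0, 1.0
```

Because clip=True, everything outside the window collapses to 0 or 1. The model therefore only sees contrast inside [-57, 164].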

app.py

%%writefile my_app/app.py
import logging
from pathlib import Path

from spleen_seg_operator import SpleenSegOperator

from pydicom.sr.codedict import codes  # Required for setting SegmentDescription attributes.

from monai.deploy.conditions import CountCondition
from monai.deploy.core import AppContext, Application
from monai.deploy.operators.dicom_data_loader_operator import DICOMDataLoaderOperator
from monai.deploy.operators.dicom_seg_writer_operator import DICOMSegmentationWriterOperator, SegmentDescription
from monai.deploy.operators.dicom_series_selector_operator import DICOMSeriesSelectorOperator
from monai.deploy.operators.dicom_series_to_volume_operator import DICOMSeriesToVolumeOperator
from monai.deploy.operators.stl_conversion_operator import STLConversionOperator

class AISpleenSegApp(Application):
    def __init__(self, *args, **kwargs):
        """Creates an application instance."""

        super().__init__(*args, **kwargs)
        self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))

    def run(self, *args, **kwargs):
        # This method calls the base class to run. Can be omitted if simply calling through.
        self._logger.info(f"Begin {self.run.__name__}")
        super().run(*args, **kwargs)
        self._logger.info(f"End {self.run.__name__}")

    def compose(self):
        """Creates the app specific operators and chain them up in the processing DAG."""

        # Use command line options over environment variables to initialize the context.
        app_context = Application.init_app_context(self.argv)
        self._logger.debug(f"Begin {self.compose.__name__}")
        app_input_path = Path(app_context.input_path)
        app_output_path = Path(app_context.output_path)
        model_path = Path(app_context.model_path)

        self._logger.info(f"App input and output path: {app_input_path}, {app_output_path}")

        # Instantiate the SDK built-in operators.
        study_loader_op = DICOMDataLoaderOperator(
            self, CountCondition(self, 1), input_folder=app_input_path, name="dcm_loader_op"
        )
        series_selector_op = DICOMSeriesSelectorOperator(self, rules=Sample_Rules_Text, name="series_selector_op")
        series_to_vol_op = DICOMSeriesToVolumeOperator(self, name="series_to_vol_op")

        # Model specific inference operator, supporting MONAI transforms.
        spleen_seg_op = SpleenSegOperator(self, app_context=app_context, model_path=model_path, name="seg_op")

        # Create DICOM Seg writer providing the required segment description for each segment with
        # the actual algorithm and the pertinent organ/tissue.
        # The segment_label, algorithm_name, and algorithm_version are limited to 64 chars.
        # https://dicom.nema.org/medical/dicom/current/output/chtml/part05/sect_6.2.html
        # Users can look up SNOMED CT codes at, e.g.,
        # https://bioportal.bioontology.org/ontologies/SNOMEDCT

        _algorithm_name = "3D segmentation of the Spleen from a CT series"
        _algorithm_family = codes.DCM.ArtificialIntelligence
        _algorithm_version = "0.1.0"

        segment_descriptions = [
            SegmentDescription(
                segment_label="Spleen",
                segmented_property_category=codes.SCT.Organ,
                segmented_property_type=codes.SCT.Spleen,
                algorithm_name=_algorithm_name,
                algorithm_family=_algorithm_family,
                algorithm_version=_algorithm_version,
            ),
        ]

        custom_tags = {"SeriesDescription": "AI generated Seg, not for clinical use."}

        dicom_seg_writer = DICOMSegmentationWriterOperator(
            self,
            segment_descriptions=segment_descriptions,
            custom_tags=custom_tags,
            output_folder=app_output_path,
            name="dcm_seg_writer_op",
        )

        # Create the processing pipeline, by specifying the source and destination operators, and
        # ensuring the output from the former matches the input of the latter, in both name and type.
        self.add_flow(study_loader_op, series_selector_op, {("dicom_study_list", "dicom_study_list")})
        self.add_flow(
            series_selector_op, series_to_vol_op, {("study_selected_series_list", "study_selected_series_list")}
        )
        self.add_flow(series_to_vol_op, spleen_seg_op, {("image", "image")})

        # Note below the dicom_seg_writer requires two inputs, each coming from a source operator.
        self.add_flow(
            series_selector_op, dicom_seg_writer, {("study_selected_series_list", "study_selected_series_list")}
        )
        self.add_flow(spleen_seg_op, dicom_seg_writer, {("seg_image", "seg_image")})

        self._logger.debug(f"End {self.compose.__name__}")


# This is a sample series selection rule in JSON, simply selecting CT series.
# If the study has more than 1 CT series, then all of them will be selected.
# Please see more detail in DICOMSeriesSelectorOperator.
# For a list of string values, e.g. "ImageType": ["PRIMARY", "ORIGINAL"], it is a match if all elements
# are in the multi-value attribute of the DICOM series.

Sample_Rules_Text = """
{
    "selections": [
        {
            "name": "CT Series",
            "conditions": {
                "StudyDescription": "(.*?)",
                "Modality": "(?i)CT",
                "SeriesDescription": "(.*?)",
                "ImageType": ["PRIMARY", "ORIGINAL"]
            }
        }
    ]
}
"""

if __name__ == "__main__":
    # Creates the app and test it standalone.
    AISpleenSegApp().run()
Writing my_app/app.py
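The comment in compose() notes that segment_label, algorithm_name, and algorithm_version are limited to 64 characters. That is the maximum length of the DICOM LO (Long String) value representation. A small pre-flight check could catch an over-long value before the writer runs. The sketch below uses a hypothetical helper name and is not part of the SDK.

```python
MAX_LO_CHARS = 64  # maximum length of the DICOM LO (Long String) value representation


def overlong_fields(**fields: str) -> list:
    """Hypothetical pre-flight check: names of values longer than 64 characters."""
    return [name for name, value in fields.items() if len(value) > MAX_LO_CHARS]


print(overlong_fields(
    segment_label="Spleen",
    algorithm_name="3D segmentation of the Spleen from a CT series",
    algorithm_version="0.1.0",
))  # []
print(overlong_fields(algorithm_name="x" * 65))  # ['algorithm_name']
```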
if __name__ == "__main__":
    AISpleenSegApp().run()

The above lines are needed to execute the application code using the Python interpreter.
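The add_flow calls in compose() wire the operators into a DAG in which dcm_seg_writer_op has two upstream operators. The following plain-Python sketch does not use the SDK; the operator names are the name= values from app.py. It records those edges and derives one dependency-respecting order with the standard library's graphlib. At run time the SDK's scheduler decides the actual order, so the sort only illustrates which orders respect the dependencies.

```python
from graphlib import TopologicalSorter

# (source operator, destination operator, {(output port, input port)}), mirroring
# the add_flow calls in app.py.
FLOWS = [
    ("dcm_loader_op", "series_selector_op", {("dicom_study_list", "dicom_study_list")}),
    ("series_selector_op", "series_to_vol_op", {("study_selected_series_list", "study_selected_series_list")}),
    ("series_to_vol_op", "seg_op", {("image", "image")}),
    ("series_selector_op", "dcm_seg_writer_op", {("study_selected_series_list", "study_selected_series_list")}),
    ("seg_op", "dcm_seg_writer_op", {("seg_image", "seg_image")}),
]


def execution_order(flows) -> list:
    """Return one order of the operators in which every source precedes its destinations."""
    predecessors = {}
    for src, dst, _ports in flows:
        predecessors.setdefault(src, set())
        predecessors.setdefault(dst, set()).add(src)
    return list(TopologicalSorter(predecessors).static_order())


def input_ports(flows, operator: str) -> list:
    """Input port names of an operator that are fed by an upstream operator."""
    return sorted(in_port for _src, dst, ports in flows if dst == operator for _out_port, in_port in ports)


print(execution_order(FLOWS))
print(input_ports(FLOWS, "dcm_seg_writer_op"))  # ['seg_image', 'study_selected_series_list']
```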

__main__.py

__main__.py is needed for the MONAI Application Packager to detect the main application code (app.py) when the application is executed with the application folder path (e.g., python my_app).
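Running python with a folder path works because Python executes the folder's __main__.py as the main module. It also puts the folder itself at the front of sys.path, which is why __main__.py can use from app import AISpleenSegApp. The demonstration below is self-contained and uses a throwaway, hypothetical demo_app folder with a stand-in app.py.

```python
import subprocess
import sys
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    app_dir = Path(tmp) / "demo_app"
    app_dir.mkdir()
    # Stand-in for app.py; the real one defines AISpleenSegApp.
    (app_dir / "app.py").write_text("def main():\n    print('hello from app.py')\n")
    # __main__.py imports app as a top-level module, mirroring my_app/__main__.py.
    (app_dir / "__main__.py").write_text("from app import main\n\nif __name__ == '__main__':\n    main()\n")
    # Equivalent to running `python demo_app` from the parent folder.
    result = subprocess.run([sys.executable, str(app_dir)], capture_output=True, text=True, check=True)

print(result.stdout.strip())  # hello from app.py
```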

%%writefile my_app/__main__.py
from app import AISpleenSegApp

if __name__ == "__main__":
    AISpleenSegApp().run()
Writing my_app/__main__.py
!ls my_app
app.py	__main__.py  spleen_seg_operator.py

This time, let’s execute the app from the command line.

Note

Since the environment variables have been set and contain the correct paths, it is not necessary to provide the command line options when running the application.
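The comment in compose() says command line options take precedence over environment variables. The run logs are consistent with this: they show Namespace(input=None, ...) yet AppContext(input_path=dcm, ...). That precedence can be sketched as below. The helper name and fallback defaults are hypothetical, and this is not the SDK's Application.init_app_context. Only the HOLOSCAN_* variable names come from this tutorial.

```python
import os
from typing import Mapping, Optional


def resolve_path(cli_value: Optional[str], env_var: str, default: str,
                 env: Mapping[str, str] = os.environ) -> str:
    """Hypothetical helper: a command line value wins, then the environment variable, then a default."""
    if cli_value:
        return cli_value
    return env.get(env_var) or default


# A fake environment keeps the example independent of the real shell.
fake_env = {"HOLOSCAN_INPUT_PATH": "dcm", "HOLOSCAN_OUTPUT_PATH": "output"}
print(resolve_path(None, "HOLOSCAN_INPUT_PATH", "input", fake_env))  # dcm
print(resolve_path("/data/study", "HOLOSCAN_INPUT_PATH", "input", fake_env))  # /data/study
print(resolve_path(None, "HOLOSCAN_MODEL_PATH", "models", fake_env))  # models
```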

!rm -rf $HOLOSCAN_OUTPUT_PATH
!python my_app
[2024-04-10 16:41:25,305] [INFO] (root) - Parsed args: Namespace(log_level=None, input=None, output=None, model=None, workdir=None, argv=['my_app'])
[2024-04-10 16:41:25,472] [INFO] (root) - AppContext object: AppContext(input_path=dcm, output_path=output, model_path=models, workdir=)
[2024-04-10 16:41:25,472] [INFO] (app.AISpleenSegApp) - App input and output path: dcm, output
[info] [gxf_executor.cpp:211] Creating context
[info] [gxf_executor.cpp:1674] Loading extensions from configs...
[info] [gxf_executor.cpp:1864] Activating Graph...
[info] [gxf_executor.cpp:1894] Running Graph...
[info] [gxf_executor.cpp:1896] Waiting for completion...
[info] [gxf_executor.cpp:1897] Graph execution waiting. Fragment: 
[info] [greedy_scheduler.cpp:190] Scheduling 6 entities
[2024-04-10 16:41:25,505] [INFO] (monai.deploy.operators.dicom_data_loader_operator.DICOMDataLoaderOperator) - No or invalid input path from the optional input port: None
[2024-04-10 16:41:25,850] [INFO] (root) - Finding series for Selection named: CT Series
[2024-04-10 16:41:25,850] [INFO] (root) - Searching study, : 1.3.6.1.4.1.14519.5.2.1.7085.2626.822645453932810382886582736291
  # of series: 1
[2024-04-10 16:41:25,850] [INFO] (root) - Working on series, instance UID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239
[2024-04-10 16:41:25,850] [INFO] (root) - On attribute: 'StudyDescription' to match value: '(.*?)'
[2024-04-10 16:41:25,850] [INFO] (root) -     Series attribute StudyDescription value: CT ABDOMEN W IV CONTRAST
[2024-04-10 16:41:25,850] [INFO] (root) - Series attribute string value did not match. Try regEx.
[2024-04-10 16:41:25,851] [INFO] (root) - On attribute: 'Modality' to match value: '(?i)CT'
[2024-04-10 16:41:25,851] [INFO] (root) -     Series attribute Modality value: CT
[2024-04-10 16:41:25,851] [INFO] (root) - Series attribute string value did not match. Try regEx.
[2024-04-10 16:41:25,851] [INFO] (root) - On attribute: 'SeriesDescription' to match value: '(.*?)'
[2024-04-10 16:41:25,851] [INFO] (root) -     Series attribute SeriesDescription value: ABD/PANC 3.0 B31f
[2024-04-10 16:41:25,851] [INFO] (root) - Series attribute string value did not match. Try regEx.
[2024-04-10 16:41:25,851] [INFO] (root) - On attribute: 'ImageType' to match value: ['PRIMARY', 'ORIGINAL']
[2024-04-10 16:41:25,851] [INFO] (root) -     Series attribute ImageType value: None
[2024-04-10 16:41:25,851] [INFO] (root) - Selected Series, UID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239
[2024-04-10 16:41:26,402] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - Converted Image object metadata:
[2024-04-10 16:41:26,402] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - SeriesInstanceUID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239, type <class 'str'>
[2024-04-10 16:41:26,402] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - SeriesDate: 20090831, type <class 'str'>
[2024-04-10 16:41:26,402] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - SeriesTime: 101721.452, type <class 'str'>
[2024-04-10 16:41:26,402] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - Modality: CT, type <class 'str'>
[2024-04-10 16:41:26,402] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - SeriesDescription: ABD/PANC 3.0 B31f, type <class 'str'>
[2024-04-10 16:41:26,402] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - PatientPosition: HFS, type <class 'str'>
[2024-04-10 16:41:26,402] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - SeriesNumber: 8, type <class 'int'>
[2024-04-10 16:41:26,402] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - row_pixel_spacing: 0.7890625, type <class 'float'>
[2024-04-10 16:41:26,402] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - col_pixel_spacing: 0.7890625, type <class 'float'>
[2024-04-10 16:41:26,402] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - depth_pixel_spacing: 1.5, type <class 'float'>
[2024-04-10 16:41:26,402] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - row_direction_cosine: [1.0, 0.0, 0.0], type <class 'list'>
[2024-04-10 16:41:26,402] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - col_direction_cosine: [0.0, 1.0, 0.0], type <class 'list'>
[2024-04-10 16:41:26,402] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - depth_direction_cosine: [0.0, 0.0, 1.0], type <class 'list'>
[2024-04-10 16:41:26,402] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - dicom_affine_transform: [[   0.7890625    0.           0.        -197.60547  ]
 [   0.           0.7890625    0.        -398.60547  ]
 [   0.           0.           1.5       -383.       ]
 [   0.           0.           0.           1.       ]], type <class 'numpy.ndarray'>
[2024-04-10 16:41:26,403] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - nifti_affine_transform: [[  -0.7890625   -0.          -0.         197.60547  ]
 [  -0.          -0.7890625   -0.         398.60547  ]
 [   0.           0.           1.5       -383.       ]
 [   0.           0.           0.           1.       ]], type <class 'numpy.ndarray'>
[2024-04-10 16:41:26,403] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - StudyInstanceUID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.822645453932810382886582736291, type <class 'str'>
[2024-04-10 16:41:26,403] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - StudyID: , type <class 'str'>
[2024-04-10 16:41:26,403] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - StudyDate: 20090831, type <class 'str'>
[2024-04-10 16:41:26,403] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - StudyTime: 095948.599, type <class 'str'>
[2024-04-10 16:41:26,403] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - StudyDescription: CT ABDOMEN W IV CONTRAST, type <class 'str'>
[2024-04-10 16:41:26,403] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - AccessionNumber: 5471978513296937, type <class 'str'>
[2024-04-10 16:41:26,403] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - selection_name: CT Series, type <class 'str'>
2024-04-10 16:41:27,452 INFO image_writer.py:197 - writing: /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/output/saved_images_folder/1.3.6.1.4.1.14519.5.2.1.7085.2626/1.3.6.1.4.1.14519.5.2.1.7085.2626.nii
2024-04-10 16:41:34,092 INFO image_writer.py:197 - writing: /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/output/saved_images_folder/1.3.6.1.4.1.14519.5.2.1.7085.2626/1.3.6.1.4.1.14519.5.2.1.7085.2626_seg.nii
[2024-04-10 16:41:35,714] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - Output Seg image numpy array shaped: (204, 512, 512)
[2024-04-10 16:41:35,720] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - Output Seg image pixel max value: 1
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.10/site-packages/highdicom/valuerep.py:54: UserWarning: The string "C3N-00198" is unlikely to represent the intended person name since it contains only a single component. Construct a person name according to the format in described in https://dicom.nema.org/dicom/2013/output/chtml/part05/sect_6.2.html#sect_6.2.1.2, or, in pydicom 2.2.0 or later, use the pydicom.valuerep.PersonName.from_named_components() method to construct the person name correctly. If a single-component name is really intended, add a trailing caret character to disambiguate the name.
  warnings.warn(
[2024-04-10 16:41:37,387] [INFO] (highdicom.base) - copy Image-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"
[2024-04-10 16:41:37,387] [INFO] (highdicom.base) - copy attributes of module "Specimen"
[2024-04-10 16:41:37,387] [INFO] (highdicom.base) - copy Patient-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"
[2024-04-10 16:41:37,387] [INFO] (highdicom.base) - copy attributes of module "Patient"
[2024-04-10 16:41:37,387] [INFO] (highdicom.base) - copy attributes of module "Clinical Trial Subject"
[2024-04-10 16:41:37,388] [INFO] (highdicom.base) - copy Study-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"
[2024-04-10 16:41:37,388] [INFO] (highdicom.base) - copy attributes of module "General Study"
[2024-04-10 16:41:37,388] [INFO] (highdicom.base) - copy attributes of module "Patient Study"
[2024-04-10 16:41:37,388] [INFO] (highdicom.base) - copy attributes of module "Clinical Trial Study"
[info] [greedy_scheduler.cpp:369] Scheduler stopped: Some entities are waiting for execution, but there are no periodic or async entities to get out of the deadlock.
[info] [greedy_scheduler.cpp:398] Scheduler finished.
[info] [gxf_executor.cpp:1906] Graph execution deactivating. Fragment: 
[info] [gxf_executor.cpp:1907] Deactivating Graph...
[info] [gxf_executor.cpp:1910] Graph execution finished. Fragment: 
[2024-04-10 16:41:37,477] [INFO] (app.AISpleenSegApp) - End run
[info] [gxf_executor.cpp:230] Destroying context
!ls -R $HOLOSCAN_OUTPUT_PATH
output:
1.2.826.0.1.3680043.10.511.3.12733408477402210746640758069824301.dcm
saved_images_folder

output/saved_images_folder:
1.3.6.1.4.1.14519.5.2.1.7085.2626

output/saved_images_folder/1.3.6.1.4.1.14519.5.2.1.7085.2626:
1.3.6.1.4.1.14519.5.2.1.7085.2626.nii
1.3.6.1.4.1.14519.5.2.1.7085.2626_seg.nii
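To check the outputs from Python instead of with ls, a small standard-library sketch can count the files under the output folder by extension. The function name is hypothetical. For the listing above, you would expect one .dcm file (the DICOM SEG instance) and two .nii files (the saved input volume and the segmentation mask).

```python
from collections import Counter
from pathlib import Path


def summarize_outputs(root) -> Counter:
    """Count the files under root (recursively) by file extension."""
    return Counter(path.suffix for path in Path(root).rglob("*") if path.is_file())


# Run from the tutorial folder; skipped if the output folder does not exist.
if Path("output").is_dir():
    print(summarize_outputs("output"))
```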

Packaging app

Let’s package the app with the MONAI Application Packager.

In this version of the App SDK, we need to write out the configuration YAML file, as well as the package requirements file, in the application folder.

%%writefile my_app/app.yaml
%YAML 1.2
---
application:
  title: MONAI Deploy App Package - MONAI Bundle AI App
  version: 1.0
  inputFormats: ["file"]
  outputFormats: ["file"]

resources:
  cpu: 1
  gpu: 1
  memory: 1Gi
  gpuMemory: 6Gi
Writing my_app/app.yaml
%%writefile my_app/requirements.txt
highdicom>=0.18.2
monai>=1.0
nibabel>=3.2.1
numpy>=1.21.6
pydicom>=2.3.0
setuptools>=59.5.0 # for pkg_resources
SimpleITK>=2.0.0
torch>=1.12.0
Writing my_app/requirements.txt
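%%writefile is an IPython cell magic, so the two cells above only work inside Jupyter. When the packaging steps are scripted outside a notebook, a minimal equivalent is to write the same text with pathlib. The following is a sketch: the helper name write_packaging_files is hypothetical, and the file contents are copied verbatim from the cells above.

```python
from pathlib import Path

# Same content as the %%writefile my_app/app.yaml cell above.
APP_YAML = """\
%YAML 1.2
---
application:
  title: MONAI Deploy App Package - MONAI Bundle AI App
  version: 1.0
  inputFormats: ["file"]
  outputFormats: ["file"]

resources:
  cpu: 1
  gpu: 1
  memory: 1Gi
  gpuMemory: 6Gi
"""

# Same content as the %%writefile my_app/requirements.txt cell above.
REQUIREMENTS_TXT = """\
highdicom>=0.18.2
monai>=1.0
nibabel>=3.2.1
numpy>=1.21.6
pydicom>=2.3.0
setuptools>=59.5.0 # for pkg_resources
SimpleITK>=2.0.0
torch>=1.12.0
"""


def write_packaging_files(app_folder="my_app"):
    """Write app.yaml and requirements.txt into the application folder."""
    folder = Path(app_folder)
    folder.mkdir(parents=True, exist_ok=True)
    app_yaml = folder / "app.yaml"
    requirements = folder / "requirements.txt"
    app_yaml.write_text(APP_YAML)
    requirements.write_text(REQUIREMENTS_TXT)
    return app_yaml, requirements
```

Calling write_packaging_files("my_app") from the notebook's working directory produces the same two files that the cells above create.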

Now we can use the CLI package command to build the MONAI Application Package (MAP) container image based on a supported base image.

Note

Building a MONAI Application Package (Docker image) can take a while. Use the -l DEBUG option to see the progress.

tag_prefix = "my_app"

!monai-deploy package my_app -m {models_folder} -c my_app/app.yaml -t {tag_prefix}:1.0 --platform x64-workstation -l DEBUG
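The ! prefix and the {models_folder} interpolation are Jupyter features. Outside a notebook, the same packaging call can be made with subprocess. The following sketch only reassembles the flags used in the cell above. The function names are hypothetical, and no other CLI options are assumed.

```python
import shlex
import subprocess


def build_package_command(app_folder, models_folder, config_path, tag,
                          platform="x64-workstation", log_level="DEBUG"):
    """Return the argv list for the monai-deploy package call shown above."""
    return [
        "monai-deploy", "package", str(app_folder),
        "-m", str(models_folder),
        "-c", str(config_path),
        "-t", tag,
        "--platform", platform,
        "-l", log_level,
    ]


def run_command(cmd):
    """Echo the command as a shell string, then run it; raises on a non-zero exit."""
    print(shlex.join(cmd))
    subprocess.run(cmd, check=True)


# Example (models_folder is defined earlier in the notebook):
# cmd = build_package_command("my_app", models_folder, "my_app/app.yaml", "my_app:1.0")
# run_command(cmd)
```

Passing an argument list instead of a single shell string avoids quoting problems if a path contains spaces.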
[2024-04-10 16:41:40,386] [INFO] (packager.parameters) - Application: /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/my_app
[2024-04-10 16:41:40,386] [INFO] (packager.parameters) - Detected application type: Python Module
[2024-04-10 16:41:40,386] [INFO] (packager) - Scanning for models in /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/models...
[2024-04-10 16:41:40,386] [DEBUG] (packager) - Model model=/home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/models/model added.
[2024-04-10 16:41:40,386] [INFO] (packager) - Reading application configuration from /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/my_app/app.yaml...
[2024-04-10 16:41:40,387] [INFO] (packager) - Generating app.json...
[2024-04-10 16:41:40,387] [INFO] (packager) - Generating pkg.json...
[2024-04-10 16:41:40,394] [DEBUG] (common) - 
=============== Begin app.json ===============
{
    "apiVersion": "1.0.0",
    "command": "[\"python3\", \"/opt/holoscan/app\"]",
    "environment": {
        "HOLOSCAN_APPLICATION": "/opt/holoscan/app",
        "HOLOSCAN_INPUT_PATH": "input/",
        "HOLOSCAN_OUTPUT_PATH": "output/",
        "HOLOSCAN_WORKDIR": "/var/holoscan",
        "HOLOSCAN_MODEL_PATH": "/opt/holoscan/models",
        "HOLOSCAN_CONFIG_PATH": "/var/holoscan/app.yaml",
        "HOLOSCAN_APP_MANIFEST_PATH": "/etc/holoscan/app.json",
        "HOLOSCAN_PKG_MANIFEST_PATH": "/etc/holoscan/pkg.json",
        "HOLOSCAN_DOCS_PATH": "/opt/holoscan/docs",
        "HOLOSCAN_LOGS_PATH": "/var/holoscan/logs"
    },
    "input": {
        "path": "input/",
        "formats": null
    },
    "liveness": null,
    "output": {
        "path": "output/",
        "formats": null
    },
    "readiness": null,
    "sdk": "monai-deploy",
    "sdkVersion": "0.5.1",
    "timeout": 0,
    "version": 1.0,
    "workingDirectory": "/var/holoscan"
}
================ End app.json ================
                 
[2024-04-10 16:41:40,394] [DEBUG] (common) - 
=============== Begin pkg.json ===============
{
    "apiVersion": "1.0.0",
    "applicationRoot": "/opt/holoscan/app",
    "modelRoot": "/opt/holoscan/models",
    "models": {
        "model": "/opt/holoscan/models/model"
    },
    "resources": {
        "cpu": 1,
        "gpu": 1,
        "memory": "1Gi",
        "gpuMemory": "6Gi"
    },
    "version": 1.0,
    "platformConfig": "dgpu"
}
================ End pkg.json ================
                 
[2024-04-10 16:41:40,461] [DEBUG] (packager.builder) - 
========== Begin Dockerfile ==========


FROM nvcr.io/nvidia/clara-holoscan/holoscan:v1.0.3-dgpu

ENV DEBIAN_FRONTEND=noninteractive
ENV TERM=xterm-256color

ARG UNAME
ARG UID
ARG GID

RUN mkdir -p /etc/holoscan/ \
        && mkdir -p /opt/holoscan/ \
        && mkdir -p /var/holoscan \
        && mkdir -p /opt/holoscan/app \
        && mkdir -p /var/holoscan/input \
        && mkdir -p /var/holoscan/output

LABEL base="nvcr.io/nvidia/clara-holoscan/holoscan:v1.0.3-dgpu"
LABEL tag="my_app:1.0"
LABEL org.opencontainers.image.title="MONAI Deploy App Package - MONAI Bundle AI App"
LABEL org.opencontainers.image.version="1.0"
LABEL org.nvidia.holoscan="1.0.3"
LABEL org.monai.deploy.app-sdk="0.5.1"


ENV HOLOSCAN_ENABLE_HEALTH_CHECK=true
ENV HOLOSCAN_INPUT_PATH=/var/holoscan/input
ENV HOLOSCAN_OUTPUT_PATH=/var/holoscan/output
ENV HOLOSCAN_WORKDIR=/var/holoscan
ENV HOLOSCAN_APPLICATION=/opt/holoscan/app
ENV HOLOSCAN_TIMEOUT=0
ENV HOLOSCAN_MODEL_PATH=/opt/holoscan/models
ENV HOLOSCAN_DOCS_PATH=/opt/holoscan/docs
ENV HOLOSCAN_CONFIG_PATH=/var/holoscan/app.yaml
ENV HOLOSCAN_APP_MANIFEST_PATH=/etc/holoscan/app.json
ENV HOLOSCAN_PKG_MANIFEST_PATH=/etc/holoscan/pkg.json
ENV HOLOSCAN_LOGS_PATH=/var/holoscan/logs
ENV PATH=/root/.local/bin:/opt/nvidia/holoscan:$PATH
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/libtorch/1.13.1/lib/:/opt/nvidia/holoscan/lib

RUN apt-get update \
    && apt-get install -y curl jq \
    && rm -rf /var/lib/apt/lists/*

ENV PYTHONPATH="/opt/holoscan/app:$PYTHONPATH"



RUN groupadd -f -g $GID $UNAME
RUN useradd -rm -d /home/$UNAME -s /bin/bash -g $GID -G sudo -u $UID $UNAME
RUN chown -R holoscan /var/holoscan 
RUN chown -R holoscan /var/holoscan/input 
RUN chown -R holoscan /var/holoscan/output 

# Set the working directory
WORKDIR /var/holoscan

# Copy HAP/MAP tool script
COPY ./tools /var/holoscan/tools
RUN chmod +x /var/holoscan/tools


# Copy gRPC health probe

USER $UNAME

ENV PATH=/root/.local/bin:/home/holoscan/.local/bin:/opt/nvidia/holoscan:$PATH

COPY ./pip/requirements.txt /tmp/requirements.txt

RUN pip install --upgrade pip
RUN pip install --no-cache-dir --user -r /tmp/requirements.txt

# Install Holoscan from PyPI only when sdk_type is Holoscan. 
# For MONAI Deploy, the APP SDK will install it unless user specifies the Holoscan SDK file.

# Copy user-specified MONAI Deploy SDK file
COPY ./monai_deploy_app_sdk-0.5.1+25.g31e4165.dirty-py3-none-any.whl /tmp/monai_deploy_app_sdk-0.5.1+25.g31e4165.dirty-py3-none-any.whl
RUN pip install /tmp/monai_deploy_app_sdk-0.5.1+25.g31e4165.dirty-py3-none-any.whl




COPY ./models  /opt/holoscan/models

COPY ./map/app.json /etc/holoscan/app.json
COPY ./app.config /var/holoscan/app.yaml
COPY ./map/pkg.json /etc/holoscan/pkg.json

COPY ./app /opt/holoscan/app

ENTRYPOINT ["/var/holoscan/tools"]
=========== End Dockerfile ===========

[2024-04-10 16:41:40,461] [INFO] (packager.builder) - 
===============================================================================
Building image for:                 x64-workstation
    Architecture:                   linux/amd64
    Base Image:                     nvcr.io/nvidia/clara-holoscan/holoscan:v1.0.3-dgpu
    Build Image:                    N/A
    Cache:                          Enabled
    Configuration:                  dgpu
    Holoscan SDK Package:           pypi.org
    MONAI Deploy App SDK Package:   /home/mqin/src/monai-deploy-app-sdk/dist/monai_deploy_app_sdk-0.5.1+25.g31e4165.dirty-py3-none-any.whl
    gRPC Health Probe:              N/A
    SDK Version:                    1.0.3
    SDK:                            monai-deploy
    Tag:                            my_app-x64-workstation-dgpu-linux-amd64:1.0
    
[2024-04-10 16:41:41,020] [INFO] (common) - Using existing Docker BuildKit builder `holoscan_app_builder`
[2024-04-10 16:41:41,021] [DEBUG] (packager.builder) - Building Holoscan Application Package: tag=my_app-x64-workstation-dgpu-linux-amd64:1.0
#0 building with "holoscan_app_builder" instance using docker-container driver

#1 [internal] load build definition from Dockerfile
#1 transferring dockerfile: 2.80kB done
#1 DONE 0.1s

#2 [internal] load metadata for nvcr.io/nvidia/clara-holoscan/holoscan:v1.0.3-dgpu
#2 DONE 0.1s

#3 [internal] load .dockerignore
#3 transferring context: 1.79kB done
#3 DONE 0.1s

#4 [internal] load build context
#4 DONE 0.0s

#5 importing cache manifest from local:6394528277147153176
#5 inferred cache manifest type: application/vnd.oci.image.index.v1+json done
#5 DONE 0.0s

#6 [ 1/21] FROM nvcr.io/nvidia/clara-holoscan/holoscan:v1.0.3-dgpu@sha256:50343c616bf910e2a7651abb59db7833933e82cce64c3c4885f938d7e4af6155
#6 resolve nvcr.io/nvidia/clara-holoscan/holoscan:v1.0.3-dgpu@sha256:50343c616bf910e2a7651abb59db7833933e82cce64c3c4885f938d7e4af6155 0.0s done
#6 DONE 0.1s

#7 importing cache manifest from nvcr.io/nvidia/clara-holoscan/holoscan:v1.0.3-dgpu
#7 inferred cache manifest type: application/vnd.docker.distribution.manifest.list.v2+json done
#7 DONE 0.9s

#4 [internal] load build context
#4 transferring context: 19.56MB 0.1s done
#4 DONE 0.2s

#8 [ 3/21] RUN apt-get update     && apt-get install -y curl jq     && rm -rf /var/lib/apt/lists/*
#8 CACHED

#9 [14/21] RUN pip install --no-cache-dir --user -r /tmp/requirements.txt
#9 CACHED

#10 [ 4/21] RUN groupadd -f -g 1000 holoscan
#10 CACHED

#11 [11/21] RUN chmod +x /var/holoscan/tools
#11 CACHED

#12 [ 2/21] RUN mkdir -p /etc/holoscan/         && mkdir -p /opt/holoscan/         && mkdir -p /var/holoscan         && mkdir -p /opt/holoscan/app         && mkdir -p /var/holoscan/input         && mkdir -p /var/holoscan/output
#12 CACHED

#13 [ 9/21] WORKDIR /var/holoscan
#13 CACHED

#14 [ 6/21] RUN chown -R holoscan /var/holoscan
#14 CACHED

#15 [16/21] RUN pip install /tmp/monai_deploy_app_sdk-0.5.1+25.g31e4165.dirty-py3-none-any.whl
#15 CACHED

#16 [ 8/21] RUN chown -R holoscan /var/holoscan/output
#16 CACHED

#17 [18/21] COPY ./map/app.json /etc/holoscan/app.json
#17 CACHED

#18 [ 5/21] RUN useradd -rm -d /home/holoscan -s /bin/bash -g 1000 -G sudo -u 1000 holoscan
#18 CACHED

#19 [ 7/21] RUN chown -R holoscan /var/holoscan/input
#19 CACHED

#20 [12/21] COPY ./pip/requirements.txt /tmp/requirements.txt
#20 CACHED

#21 [17/21] COPY ./models  /opt/holoscan/models
#21 CACHED

#22 [10/21] COPY ./tools /var/holoscan/tools
#22 CACHED

#23 [15/21] COPY ./monai_deploy_app_sdk-0.5.1+25.g31e4165.dirty-py3-none-any.whl /tmp/monai_deploy_app_sdk-0.5.1+25.g31e4165.dirty-py3-none-any.whl
#23 CACHED

#24 [19/21] COPY ./app.config /var/holoscan/app.yaml
#24 CACHED

#25 [13/21] RUN pip install --upgrade pip
#25 CACHED

#26 [20/21] COPY ./map/pkg.json /etc/holoscan/pkg.json
#26 CACHED

#27 [21/21] COPY ./app /opt/holoscan/app
#27 DONE 0.3s

#28 exporting to docker image format
#28 exporting layers
#28 exporting layers 0.2s done
#28 exporting manifest sha256:814514f05787bbad414758721a0e6e6c3bc4c8e15be135868bd5d22125ba1323 0.0s done
#28 exporting config sha256:5effa5125f3b256ed2e1063cc763040b2eac0f3a49c281b3fecbe3ebbddedce9 0.0s done
#28 sending tarball
#28 ...

#29 importing to docker
#29 loading layer bc0556e272e1 3.91kB / 3.91kB
#29 loading layer bc0556e272e1 3.91kB / 3.91kB 0.7s done
#29 DONE 0.7s

#28 exporting to docker image format
#28 sending tarball 74.9s done
#28 DONE 75.2s

#30 exporting cache to client directory
#30 preparing build cache for export
#30 writing layer sha256:00bb4c1319ba1a33ac3edcb3aa1240d8abcb8d0383c6267ed8028d3b6228a8a4
#30 writing layer sha256:00bb4c1319ba1a33ac3edcb3aa1240d8abcb8d0383c6267ed8028d3b6228a8a4 done
#30 writing layer sha256:014cff740c9ec6e9a30d0b859219a700ae880eb385d62095d348f5ea136d6015 done
#30 writing layer sha256:0a1756432df4a4350712d8ae5c003f1526bd2180800b3ae6301cfc9ccf370254 done
#30 writing layer sha256:0a77dcbd0e648ddc4f8e5230ade8fdb781d99e24fa4f13ca96a360c7f7e6751f done
#30 writing layer sha256:0bf3a16e4f3f9ec99796b99e331a5c62472bc9377925e1fdc05f64709ed09895 done
#30 writing layer sha256:0ec682bf99715a9f88631226f3749e2271b8b9f254528ef61f65ed829984821c done
#30 writing layer sha256:1133dfcee0e851b490d17b3567f50c4b25ba5750da02ba4b3f3630655d0b1a7b done
#30 writing layer sha256:1294b2835667d633f938174d9fecb18a60bbbebb6fb49788a1f939893a25d1af done
#30 writing layer sha256:16a03c6e0373b62f9713416da0229bb7ce2585183141081d3ea8427ad2e84408 done
#30 writing layer sha256:183aa7032b52e859f5de3dac98da7c8398ed5f8a984d74865561f126c0eecef2 done
#30 writing layer sha256:20d331454f5fb557f2692dfbdbe092c718fd2cb55d5db9d661b62228dacca5c2 done
#30 writing layer sha256:2232aeb26b5b7ea57227e9a5b84da4fb229624d7bc976a5f7ce86d9c8653d277 done
#30 writing layer sha256:238f69a43816e481f0295995fcf5fe74d59facf0f9f99734c8d0a2fb140630e0 done
#30 writing layer sha256:2ad84487f9d4d31cd1e0a92697a5447dd241935253d036b272ef16d31620c1e7 done
#30 writing layer sha256:2bb73464628bd4a136c4937f42d522c847bea86b2215ae734949e24c1caf450e done
#30 writing layer sha256:3e3e04011ebdba380ab129f0ee390626cb2a600623815ca756340c18bedb9517 done
#30 writing layer sha256:3f0770bfaa7c2f6e0a801dbbdeb644aedfdfeccb547611d3bf9faef04222aeba 0.0s done
#30 writing layer sha256:42619ce4a0c9e54cfd0ee41a8e5f27d58b3f51becabd1ac6de725fbe6c42b14a done
#30 writing layer sha256:43a21fb6c76bd2b3715cc09d9f8c3865dc61c51dd9e2327b429f5bec8fff85d1 done
#30 writing layer sha256:4482079b5d33963eb55191bf404b70095535d4a8e2b64dab7373500515f896b4 done
#30 writing layer sha256:49bdc9abf8a437ccff67cc11490ba52c976577992909856a86be872a34d3b950 done
#30 writing layer sha256:4b691ba9f48b41eaa0c754feba8366f1c030464fcbc55eeffa6c86675990933a done
#30 writing layer sha256:4d04a8db404f16c2704fa10739cb6745a0187713a21a6ef0deb34b48629b54c1 done
#30 writing layer sha256:4f4fb700ef54461cfa02571ae0db9a0dc1e0cdb5577484a6d75e68dc38e8acc1 done
#30 writing layer sha256:5275a41be8f6691a490c0a15589e0910c73bf971169ad33a850ef570d37f63dd done
#30 writing layer sha256:52fbfeaf78318d843054ce2bfb5bfc9f71278939a815f6035ab5b14573ad017b done
#30 writing layer sha256:5792b18b6f162bae61ff5840cdb9e8567e6847a56ac886f940b47e7271c529a7 done
#30 writing layer sha256:57f244836ad318f9bbb3b29856ae1a5b31038bfbb9b43d2466d51c199eb55041 done
#30 writing layer sha256:5b5b131e0f20db4cb8e568b623a95f8fc16ed1c6b322a9366df70b59a881f24f done
#30 writing layer sha256:5ccb787d371fd3697122101438ddd0f55b537832e9756d2c51ab1d8158710ac5 done
#30 writing layer sha256:5ea668ffc2fc267d241dbf17ca283bc879643a189be4f7e3d9034a82fc64a1ea done
#30 writing layer sha256:62452179df7c18e292f141d4aec29e6aba9ff8270c893731169fc6f41dc07631 done
#30 writing layer sha256:6630c387f5f2115bca2e646fd0c2f64e1f3d5431c2e050abe607633883eda230 done
#30 writing layer sha256:69af4b756272a77f683a8d118fd5ca55c03ad5f1bacc673b463f54d16b833da5 done
#30 writing layer sha256:6ae1f1fb92c0cb2b6e219f687b08c8e511501a7af696c943ca20d119eba7cd02 done
#30 writing layer sha256:6deb3d550b15a5e099c0b3d0cbc242e351722ca16c058d3a6c28ba1a02824d0f done
#30 writing layer sha256:6e80a527af94a864094c4f9116c2d29d3d7548ec8388579d9cf3f8a39a4b8178 done
#30 writing layer sha256:7386814d57100e2c7389fbf4e16f140f5c549d31434c62c3884a85a3ee5cd2a7 done
#30 writing layer sha256:7852b73ea931e3a8d3287ee7ef3cf4bad068e44f046583bfc2b81336fb299284 done
#30 writing layer sha256:7e73869c74822e4539e104a3d2aff853f4622cd0bb873576db1db53c9e91f621 done
#30 writing layer sha256:7eae142b38745fe88962874372374deb672998600264a17e638c010b79e6b535 done
#30 writing layer sha256:7f2e5ab2c599fa36698918d3e73c991d8616fff9037077cd230529e7cd1c5e0e done
#30 writing layer sha256:82a3436133b2b17bb407c7fe488932aa0ca55411f23ab55c34a6134b287c6a27 done
#30 writing layer sha256:90eae6faa5cc5ba62f12c25915cdfb1a7a51abfba0d05cb5818c3f908f4e345f
#30 preparing build cache for export 0.7s done
#30 writing layer sha256:90eae6faa5cc5ba62f12c25915cdfb1a7a51abfba0d05cb5818c3f908f4e345f done
#30 writing layer sha256:9ac855545fa90ed2bf3b388fdff9ef06ac9427b0c0fca07c9e59161983d8827e done
#30 writing layer sha256:9d19ee268e0d7bcf6716e6658ee1b0384a71d6f2f9aa1ae2085610cf7c7b316f done
#30 writing layer sha256:a10c8d7d2714eabf661d1f43a1ccb87a51748cbb9094d5bc0b713e2481b5d329 done
#30 writing layer sha256:a1748eee9d376f97bd19225ba61dfada9986f063f4fc429e435f157abb629fc6 done
#30 writing layer sha256:a68f4e0ec09ec3b78cb4cf8e4511d658e34e7b6f676d7806ad9703194ff17604 done
#30 writing layer sha256:a8e4decc8f7289623b8fd7b9ba1ca555b5a755ebdbf81328d68209f148d9e602 done
#30 writing layer sha256:afde1c269453ce68a0f2b54c1ba8c5ecddeb18a19e5618a4acdef1f0fe3921af done
#30 writing layer sha256:b48a5fafcaba74eb5d7e7665601509e2889285b50a04b5b639a23f8adc818157 done
#30 writing layer sha256:ba9f7c75e4dd7942b944679995365aab766d3677da2e69e1d74472f471a484dd done
#30 writing layer sha256:bc42865e1c27a9b1bee751f3c99ad2c12a906d32aca396ace7a07231c9cafbd1 done
#30 writing layer sha256:bdfc73b2a0fa11b4086677e117a2f9feb6b4ffeccb23a3d58a30543339607e31 done
#30 writing layer sha256:c175bb235295e50de2961fa1e1a2235c57e6eba723a914287dfc26d3be0eac11 done
#30 writing layer sha256:c98533d2908f36a5e9b52faae83809b3b6865b50e90e2817308acfc64cd3655f done
#30 writing layer sha256:cb6c95b33bc30dd285c5b3cf99a05281b8f12decae1c932ab64bd58f56354021 done
#30 writing layer sha256:d6b5d6e098aacb316146a428c6b5aef9692011c6dce0932e3bbfbf27a514b7ed done
#30 writing layer sha256:d7da5c5e9a40c476c4b3188a845e3276dedfd752e015ea5113df5af64d4d43f7 done
#30 writing layer sha256:e4297ff4df6f7a8f25cb109e5b24483c314c2e72b8e824f9669173919fc159c9 done
#30 writing layer sha256:e4aedc686433c0ec5e676e6cc54a164345f7016aa0eb714f00c07e11664a1168 done
#30 writing layer sha256:e8640a108802cd7519cc53dceb74f7a5c94b562662f1c3c040c2aa6571acf0f3 done
#30 writing layer sha256:e8acb678f16bc0c369d5cf9c184f2d3a1c773986816526e5e3e9c0354f7e757f done
#30 writing layer sha256:e9225f7ab6606813ec9acba98a064826ebfd6713a9645a58cd068538af1ecddb done
#30 writing layer sha256:f33546e75bf1a7d9dc9e21b9a2c54c9d09b24790ad7a4192a8509002ceb14688 done
#30 writing layer sha256:f608e2fbff86e98627b7e462057e7d2416522096d73fe4664b82fe6ce8a4047d done
#30 writing layer sha256:f7702077ced42a1ee35e7f5e45f72634328ff3bcfe3f57735ba80baa5ec45daf done
#30 writing layer sha256:fa66a49172c6e821a1bace57c007c01da10cbc61507c44f8cdfeed8c4e5febab done
#30 writing config sha256:374c8d5f4f72f0b0a709492f153a59ebd070e903971483f2cefb3e9e45bda48a 0.0s done
#30 writing cache manifest sha256:00a618573e1678dbe93ffce1675eee120201710fa121207b43875632f6799a58 0.0s done
#30 DONE 0.7s
[2024-04-10 16:43:00,358] [INFO] (packager) - Build Summary:

Platform: x64-workstation/dgpu
    Status:     Succeeded
    Docker Tag: my_app-x64-workstation-dgpu-linux-amd64:1.0
    Tarball:    None

The build summary confirms that the MAP Docker image has been created.

The MAP manifests can be displayed by running the container with the show command. The manifests and other contents of the MAP can also be extracted with the extract command. Neither command is demonstrated in this example.

!docker image ls | grep {tag_prefix}
my_app-x64-workstation-dgpu-linux-amd64                                                   1.0                        5effa5125f3b   About a minute ago   17.5GB
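The image tag is not the my_app:1.0 passed with -t. The build log shows the packager expanding it to my_app-x64-workstation-dgpu-linux-amd64:1.0, and the run step below needs that expanded tag. The helper below is a hypothetical sketch that reproduces the pattern observed in this one build: the name, then the platform, the platform configuration, and the Docker platform with / replaced by -, then the version. The pattern is inferred from the log, not taken from the packager's source, so check it against docker image ls when targeting other platforms.

```python
def packaged_image_tag(tag, platform, platform_config, docker_platform):
    """Compose the expanded image tag in the form seen in the build log above.

    Expects tag as "name:version". Inferred from a single build where
    ("my_app:1.0", "x64-workstation", "dgpu", "linux/amd64") produced
    "my_app-x64-workstation-dgpu-linux-amd64:1.0".
    """
    if ":" not in tag:
        raise ValueError(f"expected name:version, got {tag!r}")
    name, _, version = tag.rpartition(":")
    arch = docker_platform.replace("/", "-")
    return f"{name}-{platform}-{platform_config}-{arch}:{version}"


if __name__ == "__main__":
    print(packaged_image_tag("my_app:1.0", "x64-workstation", "dgpu", "linux/amd64"))
```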

Executing packaged app locally

The packaged app can be run locally with the MONAI Application Runner.

# Clear the output folder and run the MAP. The input is expected to be a folder.
!echo $HOLOSCAN_OUTPUT_PATH
!echo $HOLOSCAN_INPUT_PATH
!rm -rf $HOLOSCAN_OUTPUT_PATH
!monai-deploy run -i $HOLOSCAN_INPUT_PATH -o $HOLOSCAN_OUTPUT_PATH my_app-x64-workstation-dgpu-linux-amd64:1.0
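For a scripted run outside Jupyter, the shell commands in the cell above can be reproduced with shutil and subprocess. The following sketch is an equivalent that uses hypothetical helper names. Like the cell comment, it expects the input to be a folder, and it uses the expanded image tag reported by the packager. The command is built before the output folder is cleared, so an invalid input path does not delete existing results.

```python
import shutil
import subprocess
from pathlib import Path


def clear_output(output_path):
    """Remove the output folder if present, like rm -rf $HOLOSCAN_OUTPUT_PATH."""
    shutil.rmtree(output_path, ignore_errors=True)


def build_run_command(input_path, output_path, image_tag):
    """Return the argv list for monai-deploy run as used in the cell above."""
    if not Path(input_path).is_dir():
        raise NotADirectoryError(f"input must be a folder: {input_path}")
    return ["monai-deploy", "run", "-i", str(input_path), "-o", str(output_path), image_tag]


def run_map(input_path, output_path, image_tag):
    """Validate the input, clear the output folder, then run the MAP."""
    cmd = build_run_command(input_path, output_path, image_tag)
    clear_output(output_path)
    subprocess.run(cmd, check=True)  # raises CalledProcessError on failure


# Example, reading the same environment variables as the cell above:
# import os
# run_map(os.environ["HOLOSCAN_INPUT_PATH"], os.environ["HOLOSCAN_OUTPUT_PATH"],
#         "my_app-x64-workstation-dgpu-linux-amd64:1.0")
```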
output
dcm
[2024-04-10 16:43:03,135] [INFO] (runner) - Checking dependencies...
[2024-04-10 16:43:03,135] [INFO] (runner) - --> Verifying if "docker" is installed...

[2024-04-10 16:43:03,135] [INFO] (runner) - --> Verifying if "docker-buildx" is installed...

[2024-04-10 16:43:03,135] [INFO] (runner) - --> Verifying if "my_app-x64-workstation-dgpu-linux-amd64:1.0" is available...

[2024-04-10 16:43:03,211] [INFO] (runner) - Reading HAP/MAP manifest...
Preparing to copy...
Copying from container - 0B
Successfully copied 2.56kB to /tmp/tmpbbeybrjp/app.json
Preparing to copy...
Copying from container - 0B
Successfully copied 2.05kB to /tmp/tmpbbeybrjp/pkg.json
[2024-04-10 16:43:03,567] [INFO] (runner) - --> Verifying if "nvidia-ctk" is installed...

[2024-04-10 16:43:03,568] [INFO] (runner) - --> Verifying "nvidia-ctk" version...

[2024-04-10 16:43:03,864] [INFO] (common) - Launching container (5135fc45ca94) using image 'my_app-x64-workstation-dgpu-linux-amd64:1.0'...
    container name:      optimistic_jang
    host name:           mingq-dt
    network:             host
    user:                1000:1000
    ulimits:             memlock=-1:-1, stack=67108864:67108864
    cap_add:             CAP_SYS_PTRACE
    ipc mode:            host
    shared memory size:  67108864
    devices:             
    group_add:           44
2024-04-10 23:43:05 [INFO] Launching application python3 /opt/holoscan/app ...

[2024-04-10 23:43:09,781] [INFO] (root) - Parsed args: Namespace(log_level=None, input=None, output=None, model=None, workdir=None, argv=['/opt/holoscan/app'])

[2024-04-10 23:43:09,784] [INFO] (root) - AppContext object: AppContext(input_path=/var/holoscan/input, output_path=/var/holoscan/output, model_path=/opt/holoscan/models, workdir=/var/holoscan)

[2024-04-10 23:43:09,784] [INFO] (app.AISpleenSegApp) - App input and output path: /var/holoscan/input, /var/holoscan/output

[info] [app_driver.cpp:1161] Launching the driver/health checking service

[info] [gxf_executor.cpp:211] Creating context

[info] [server.cpp:87] Health checking server listening on 0.0.0.0:8777

[info] [gxf_executor.cpp:1674] Loading extensions from configs...

[info] [gxf_executor.cpp:1864] Activating Graph...

[info] [gxf_executor.cpp:1894] Running Graph...

[info] [gxf_executor.cpp:1896] Waiting for completion...

[info] [gxf_executor.cpp:1897] Graph execution waiting. Fragment: 

[info] [greedy_scheduler.cpp:190] Scheduling 6 entities

[2024-04-10 23:43:09,895] [INFO] (monai.deploy.operators.dicom_data_loader_operator.DICOMDataLoaderOperator) - No or invalid input path from the optional input port: None

[2024-04-10 23:43:10,851] [INFO] (root) - Finding series for Selection named: CT Series

[2024-04-10 23:43:10,851] [INFO] (root) - Searching study, : 1.3.6.1.4.1.14519.5.2.1.7085.2626.822645453932810382886582736291

  # of series: 1

[2024-04-10 23:43:10,851] [INFO] (root) - Working on series, instance UID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239

[2024-04-10 23:43:10,851] [INFO] (root) - On attribute: 'StudyDescription' to match value: '(.*?)'

[2024-04-10 23:43:10,852] [INFO] (root) -     Series attribute StudyDescription value: CT ABDOMEN W IV CONTRAST

[2024-04-10 23:43:10,852] [INFO] (root) - Series attribute string value did not match. Try regEx.

[2024-04-10 23:43:10,852] [INFO] (root) - On attribute: 'Modality' to match value: '(?i)CT'

[2024-04-10 23:43:10,852] [INFO] (root) -     Series attribute Modality value: CT

[2024-04-10 23:43:10,852] [INFO] (root) - Series attribute string value did not match. Try regEx.

[2024-04-10 23:43:10,852] [INFO] (root) - On attribute: 'SeriesDescription' to match value: '(.*?)'

[2024-04-10 23:43:10,852] [INFO] (root) -     Series attribute SeriesDescription value: ABD/PANC 3.0 B31f

[2024-04-10 23:43:10,852] [INFO] (root) - Series attribute string value did not match. Try regEx.

[2024-04-10 23:43:10,852] [INFO] (root) - On attribute: 'ImageType' to match value: ['PRIMARY', 'ORIGINAL']

[2024-04-10 23:43:10,852] [INFO] (root) -     Series attribute ImageType value: None

[2024-04-10 23:43:10,852] [INFO] (root) - Selected Series, UID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239

[2024-04-10 23:43:11,263] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - Converted Image object metadata:

[2024-04-10 23:43:11,263] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - SeriesInstanceUID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.119403521930927333027265674239, type <class 'str'>

[2024-04-10 23:43:11,263] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - SeriesDate: 20090831, type <class 'str'>

[2024-04-10 23:43:11,263] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - SeriesTime: 101721.452, type <class 'str'>

[2024-04-10 23:43:11,263] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - Modality: CT, type <class 'str'>

[2024-04-10 23:43:11,263] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - SeriesDescription: ABD/PANC 3.0 B31f, type <class 'str'>

[2024-04-10 23:43:11,263] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - PatientPosition: HFS, type <class 'str'>

[2024-04-10 23:43:11,263] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - SeriesNumber: 8, type <class 'int'>

[2024-04-10 23:43:11,263] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - row_pixel_spacing: 0.7890625, type <class 'float'>

[2024-04-10 23:43:11,264] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - col_pixel_spacing: 0.7890625, type <class 'float'>

[2024-04-10 23:43:11,264] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - depth_pixel_spacing: 1.5, type <class 'float'>

[2024-04-10 23:43:11,264] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - row_direction_cosine: [1.0, 0.0, 0.0], type <class 'list'>

[2024-04-10 23:43:11,264] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - col_direction_cosine: [0.0, 1.0, 0.0], type <class 'list'>

[2024-04-10 23:43:11,264] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - depth_direction_cosine: [0.0, 0.0, 1.0], type <class 'list'>

[2024-04-10 23:43:11,264] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - dicom_affine_transform: [[   0.7890625    0.           0.        -197.60547  ]
 [   0.           0.7890625    0.        -398.60547  ]
 [   0.           0.           1.5       -383.       ]
 [   0.           0.           0.           1.       ]], type <class 'numpy.ndarray'>

[2024-04-10 23:43:11,264] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - nifti_affine_transform: [[  -0.7890625   -0.          -0.         197.60547  ]
 [  -0.          -0.7890625   -0.         398.60547  ]
 [   0.           0.           1.5       -383.       ]
 [   0.           0.           0.           1.       ]], type <class 'numpy.ndarray'>

[2024-04-10 23:43:11,264] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - StudyInstanceUID: 1.3.6.1.4.1.14519.5.2.1.7085.2626.822645453932810382886582736291, type <class 'str'>

[2024-04-10 23:43:11,264] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - StudyID: , type <class 'str'>

[2024-04-10 23:43:11,264] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - StudyDate: 20090831, type <class 'str'>

[2024-04-10 23:43:11,264] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - StudyTime: 095948.599, type <class 'str'>

[2024-04-10 23:43:11,264] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - StudyDescription: CT ABDOMEN W IV CONTRAST, type <class 'str'>

[2024-04-10 23:43:11,264] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - AccessionNumber: 5471978513296937, type <class 'str'>

[2024-04-10 23:43:11,264] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - selection_name: CT Series, type <class 'str'>

2024-04-10 23:43:12,277 INFO image_writer.py:197 - writing: /var/holoscan/output/saved_images_folder/1.3.6.1.4.1.14519.5.2.1.7085.2626/1.3.6.1.4.1.14519.5.2.1.7085.2626.nii

2024-04-10 23:43:16,177 INFO image_writer.py:197 - writing: /var/holoscan/output/saved_images_folder/1.3.6.1.4.1.14519.5.2.1.7085.2626/1.3.6.1.4.1.14519.5.2.1.7085.2626_seg.nii

[2024-04-10 23:43:17,870] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - Output Seg image numpy array shaped: (204, 512, 512)

[2024-04-10 23:43:17,876] [INFO] (monai.deploy.operators.monai_seg_inference_operator.MonaiSegInferenceOperator) - Output Seg image pixel max value: 1

/home/holoscan/.local/lib/python3.10/site-packages/highdicom/valuerep.py:54: UserWarning: The string "C3N-00198" is unlikely to represent the intended person name since it contains only a single component. Construct a person name according to the format in described in https://dicom.nema.org/dicom/2013/output/chtml/part05/sect_6.2.html#sect_6.2.1.2, or, in pydicom 2.2.0 or later, use the pydicom.valuerep.PersonName.from_named_components() method to construct the person name correctly. If a single-component name is really intended, add a trailing caret character to disambiguate the name.

  warnings.warn(

[2024-04-10 23:43:19,386] [INFO] (highdicom.base) - copy Image-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"

[2024-04-10 23:43:19,386] [INFO] (highdicom.base) - copy attributes of module "Specimen"

[2024-04-10 23:43:19,386] [INFO] (highdicom.base) - copy Patient-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"

[2024-04-10 23:43:19,386] [INFO] (highdicom.base) - copy attributes of module "Patient"

[2024-04-10 23:43:19,387] [INFO] (highdicom.base) - copy attributes of module "Clinical Trial Subject"

[2024-04-10 23:43:19,387] [INFO] (highdicom.base) - copy Study-related attributes from dataset "1.3.6.1.4.1.14519.5.2.1.7085.2626.936983343951485811186213470191"

[2024-04-10 23:43:19,387] [INFO] (highdicom.base) - copy attributes of module "General Study"

[2024-04-10 23:43:19,387] [INFO] (highdicom.base) - copy attributes of module "Patient Study"

[2024-04-10 23:43:19,388] [INFO] (highdicom.base) - copy attributes of module "Clinical Trial Study"

[info] [greedy_scheduler.cpp:369] Scheduler stopped: Some entities are waiting for execution, but there are no periodic or async entities to get out of the deadlock.

[info] [greedy_scheduler.cpp:398] Scheduler finished.

[info] [gxf_executor.cpp:1906] Graph execution deactivating. Fragment: 

[info] [gxf_executor.cpp:1907] Deactivating Graph...

[info] [gxf_executor.cpp:1910] Graph execution finished. Fragment: 

[2024-04-10 23:43:19,487] [INFO] (app.AISpleenSegApp) - End run

[info] [gxf_executor.cpp:230] Destroying context

[2024-04-10 16:43:21,271] [INFO] (common) - Container 'optimistic_jang'(5135fc45ca94) exited.
!ls -R $HOLOSCAN_OUTPUT_PATH
output:
1.2.826.0.1.3680043.10.511.3.10550615266418892085330010762562517.dcm
saved_images_folder

output/saved_images_folder:
1.3.6.1.4.1.14519.5.2.1.7085.2626

output/saved_images_folder/1.3.6.1.4.1.14519.5.2.1.7085.2626:
1.3.6.1.4.1.14519.5.2.1.7085.2626.nii
1.3.6.1.4.1.14519.5.2.1.7085.2626_seg.nii
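The packaged run logs two affine matrices for the converted image: dicom_affine_transform and nifti_affine_transform. Comparing them, the NIfTI affine equals the DICOM affine with its first two rows negated. A flip from DICOM's LPS patient coordinate system to the RAS convention used by NIfTI produces exactly this result. The sketch below checks that reading against the values copied from the log. It is a pure-Python illustration with hypothetical helper names, not the operator's actual implementation.

```python
# dicom_affine_transform as printed in the log (LPS patient coordinates, mm).
DICOM_AFFINE = [
    [0.7890625, 0.0, 0.0, -197.60547],
    [0.0, 0.7890625, 0.0, -398.60547],
    [0.0, 0.0, 1.5, -383.0],
    [0.0, 0.0, 0.0, 1.0],
]

# Flips the x (L->R) and y (P->A) world axes; z is unchanged.
LPS_TO_RAS = [
    [-1.0, 0.0, 0.0, 0.0],
    [0.0, -1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
]


def matmul4(a, b):
    """Multiply two 4x4 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(4)) for j in range(4)] for i in range(4)]


def dicom_to_nifti_affine(dicom_affine):
    """Left-multiply by the LPS->RAS flip to express the affine in RAS."""
    return matmul4(LPS_TO_RAS, dicom_affine)


if __name__ == "__main__":
    for row in dicom_to_nifti_affine(DICOM_AFFINE):
        print(row)
```

The printed rows match the nifti_affine_transform in the log. The voxel spacing (0.7890625, 0.7890625, 1.5) stays on the diagonal, and only the signs of the first two rows change.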