Deploying a MedNIST Classifier App with MONAI Deploy App SDK

This tutorial demonstrates how to package a trained model with the MONAI Deploy App SDK into an artifact that can be run as a local program performing inference, as a workflow job doing the same, and as a Docker-containerized workflow execution.

In this tutorial, we train a MedNIST classifier as in the MONAI Core MedNIST tutorial, then implement and package the inference application and execute it locally.

Train a MedNIST classifier model with MONAI Core

Setup environment

# Install necessary packages for MONAI Core
!python -c "import monai" || pip install -q "monai[pillow, tqdm]"
!python -c "import ignite" || pip install -q "monai[ignite]"
!python -c "import gdown" || pip install -q "monai[gdown]"
!python -c "import pydicom" || pip install -q "pydicom>=1.4.2"
!python -c "import highdicom" || pip install -q "highdicom>=0.18.2"  # for the use of DICOM Writer operators

# Install MONAI Deploy App SDK package
!python -c "import monai.deploy" || pip install -q "monai-deploy-app-sdk"
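Each `python -c "import X" || pip install ...` line above installs a package only when importing it fails. If you would rather get a report than install anything, a minimal stdlib sketch of the same check is shown below; the helper name `missing_packages` is hypothetical and not part of MONAI or the App SDK.

```python
import importlib.util
from typing import Iterable, List


def missing_packages(module_names: Iterable[str]) -> List[str]:
    """Return the module names that cannot be found by the current interpreter."""
    missing = []
    for name in module_names:
        try:
            # For top-level names, find_spec locates the module without importing it.
            found = importlib.util.find_spec(name) is not None
        except ModuleNotFoundError:
            # Raised for dotted names such as "monai.deploy" when the parent package is absent.
            found = False
        if not found:
            missing.append(name)
    return missing
```

For example, `missing_packages(["monai", "ignite", "gdown", "pydicom", "highdicom", "monai.deploy"])` lists whichever of this tutorial's dependencies still need installing.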

Setup imports

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import glob
import PIL.Image
import torch
import numpy as np

from ignite.engine import Events

from monai.apps import download_and_extract
from monai.config import print_config
from monai.networks.nets import DenseNet121
from monai.engines import SupervisedTrainer
from monai.transforms import (
    EnsureChannelFirst,
    Compose,
    LoadImage,
    RandFlip,
    RandRotate,
    RandZoom,
    ScaleIntensity,
    EnsureType,
)
from monai.utils import set_determinism

set_determinism(seed=0)

print_config()
MONAI version: 1.3.0
Numpy version: 1.26.4
Pytorch version: 2.0.1+cu117
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 865972f7a791bf7b42efbcd87c8402bd865b329e
MONAI __file__: /home/<username>/src/monai-deploy-app-sdk/.venv/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: 0.22.0
scipy version: 1.13.0
Pillow version: 10.3.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: 4.7.3
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.66.2
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.6
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

Download dataset

The MedNIST dataset was gathered from several sets from TCIA, the RSNA Bone Age Challenge (https://www.rsna.org/education/ai-resources-and-training/ai-image-challenge/rsna-pediatric-bone-age-challenge-2017), and the NIH Chest X-ray dataset.

The dataset is kindly made available by Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic) under the Creative Commons CC BY-SA 4.0 license.

If you use the MedNIST dataset, please acknowledge the source.

directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

resource = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE"
md5 = "0bc7306e7427e00ad1c5526a6677552d"

compressed_file = os.path.join(root_dir, "MedNIST.tar.gz")
data_dir = os.path.join(root_dir, "MedNIST")
if not os.path.exists(data_dir):
    download_and_extract(resource, compressed_file, root_dir, md5)
/tmp/tmpz1q9ch5_
Downloading...
From (original): https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE
From (redirected): https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE&confirm=t&uuid=6052caf3-cb8c-4cb3-b8b8-804d8dc90e06
To: /tmp/tmpq9pcg2c8/MedNIST.tar.gz
100%|██████████| 61.8M/61.8M [00:00<00:00, 81.1MB/s]
2024-04-10 16:28:05,802 - INFO - Downloaded: /tmp/tmpz1q9ch5_/MedNIST.tar.gz

2024-04-10 16:28:05,912 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
2024-04-10 16:28:05,914 - INFO - Writing into directory: /tmp/tmpz1q9ch5_.
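As the "Verified ... md5" log line shows, download_and_extract checks the archive against the md5 value before extracting it. A minimal stdlib sketch of that kind of check follows, reading the file in chunks so a large archive does not have to fit in memory; `file_md5` and `verify_md5` are hypothetical helper names, not MONAI functions.

```python
import hashlib


def file_md5(path, chunk_size: int = 1 << 20) -> str:
    """Compute the hex MD5 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps reading until f.read() returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_md5(path, expected: str) -> bool:
    """Return True when the file's MD5 matches the expected hex digest (case-insensitive)."""
    return file_md5(path) == expected.lower()
```

With the variables from the cell above, `verify_md5(compressed_file, md5)` would re-check the downloaded archive, provided it has not been deleted after extraction.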
subdirs = sorted(glob.glob(f"{data_dir}/*/"))

class_names = [os.path.basename(sd[:-1]) for sd in subdirs]
image_files = [glob.glob(f"{sb}/*") for sb in subdirs]

image_files_list = sum(image_files, [])
image_class = sum(([i] * len(f) for i, f in enumerate(image_files)), [])
image_width, image_height = PIL.Image.open(image_files_list[0]).size

print(f"Label names: {class_names}")
print(f"Label counts: {list(map(len, image_files))}")
print(f"Total image count: {len(image_class)}")
print(f"Image dimensions: {image_width} x {image_height}")
Label names: ['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT']
Label counts: [10000, 8954, 10000, 10000, 10000, 10000]
Total image count: 58954
Image dimensions: 64 x 64
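The cell above builds a flat list of files and a parallel list of integer labels, where each file's label is the index of its class subfolder. The following stdlib sketch applies the same idea to made-up in-memory file names. It uses `itertools.chain` rather than `sum(..., [])`, which builds a new list at every step. The helper name `flatten_with_labels` is hypothetical.

```python
from itertools import chain
from typing import List, Tuple


def flatten_with_labels(files_per_class: List[List[str]]) -> Tuple[List[str], List[int]]:
    """Flatten per-class file lists into (files, labels), where label = class index."""
    files = list(chain.from_iterable(files_per_class))
    # One label per file, repeated as many times as that class has files.
    labels = [i for i, class_files in enumerate(files_per_class) for _ in class_files]
    return files, labels
```

In the notebook, `image_files_list, image_class = flatten_with_labels(image_files)` should produce the same two lists as the cell above.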

Setup and train

Here we create a transform sequence and train the network. Validation and testing are omitted because this training setup is known to work and they are not needed for this demonstration:

train_transforms = Compose(
    [
        LoadImage(image_only=True),
        EnsureChannelFirst(channel_dim="no_channel"),
        ScaleIntensity(),
        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
        RandFlip(spatial_axis=0, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        EnsureType(),
    ]
)
class MedNISTDataset(torch.utils.data.Dataset):
    def __init__(self, image_files, labels, transforms):
        self.image_files = image_files
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        return self.transforms(self.image_files[index]), self.labels[index]


# just one dataset and loader, we won't bother with validation or testing 
train_ds = MedNISTDataset(image_files_list, image_class, train_transforms)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=len(class_names)).to(device)
loss_function = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(net.parameters(), 1e-5)
max_epochs = 5
def _prepare_batch(batch, device, non_blocking):
    return tuple(b.to(device) for b in batch)


trainer = SupervisedTrainer(device, max_epochs, train_loader, net, opt, loss_function, prepare_batch=_prepare_batch)


@trainer.on(Events.EPOCH_COMPLETED)
def _print_loss(engine):
    print(f"Epoch {engine.state.epoch}/{engine.state.max_epochs} Loss: {engine.state.output[0]['loss']}")


trainer.run()
Epoch 1/5 Loss: 0.18928290903568268
Epoch 2/5 Loss: 0.06710730493068695
Epoch 3/5 Loss: 0.029032323509454727
Epoch 4/5 Loss: 0.01877668686211109
Epoch 5/5 Loss: 0.01939055137336254

The network is saved here as a TorchScript object named classifier.zip:

torch.jit.script(net).save("classifier.zip")

Implementing and Packaging Application with MONAI Deploy App SDK

Based on the TorchScript model (classifier.zip), we will implement an application that processes an input JPEG image and writes the prediction (classification) result as a JSON file (output.json).

Creating Operators and connecting them in Application class

We used the following train transforms as pre-transforms during the training.

Train transforms used in training

train_transforms = Compose(
    [
        LoadImage(image_only=True),
        EnsureChannelFirst(channel_dim="no_channel"),
        ScaleIntensity(),
        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
        RandFlip(spatial_axis=0, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        EnsureType(),
    ]
)

The RandRotate, RandFlip, and RandZoom transforms are random augmentations used only during training; they are not needed at inference time.
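To make the remaining pre-processing concrete, here is a NumPy-only sketch of what the three inference transforms do to one 64 x 64 greyscale image, plus the batch dimension that is later added with `[None]` in `compute`. It assumes `ScaleIntensity()` performs its default min-max rescaling to [0, 1]. The zero-output handling of a constant image is also an assumption. This is an approximation for illustration, not a replacement for the MONAI transforms, and `preprocess_for_inference` is a hypothetical name.

```python
import numpy as np


def preprocess_for_inference(img: np.ndarray) -> np.ndarray:
    """Approximate EnsureChannelFirst(channel_dim="no_channel") + ScaleIntensity() + batching.

    Input: a 2D greyscale image (H, W). Output: float32 array of shape (1, 1, H, W) in [0, 1].
    """
    if img.ndim != 2:
        raise ValueError(f"expected a 2D greyscale image, got shape {img.shape}")
    arr = img.astype(np.float32)[np.newaxis]  # (1, H, W): add the channel dimension
    lo, hi = float(arr.min()), float(arr.max())
    if hi > lo:
        arr = (arr - lo) / (hi - lo)  # min-max rescale to [0, 1]
    else:
        arr = np.zeros_like(arr)  # constant image: nothing to rescale (assumed handling)
    return arr[np.newaxis]  # (1, 1, H, W): add the batch dimension
```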

In our inference application, we will define two operators:

  1. LoadPILOperator - Load a JPEG image from the input path and pass the loaded image object to the next operator.

    • This operator does a similar job to the LoadImage(image_only=True) transform in train_transforms, but handles only one image.

    • Input: a file path (Path)

    • Output: an image object in memory (Image)

  2. MedNISTClassifierOperator - Pre-transform the given image by using MONAI’s Compose class, feed it to the TorchScript model (classifier.zip), and write the prediction into a JSON file (output.json).

    • Pre-transforms consist of three transforms – EnsureChannelFirst, ScaleIntensity, and EnsureType.

    • Input: an image object in memory (Image)

    • Output: the predicted class name as text (result_text); the same result is also written to output.json in the output folder

The workflow of the application looks like this:

[Workflow diagram: LoadPILOperator reads the image from DISK and emits it IN_MEMORY on its image output, which is connected to the image input of MedNISTClassifierOperator; MedNISTClassifierOperator takes the IN_MEMORY image and writes its output to DISK.]

Set up environment variables

Before building and packaging the application, we first need to set the well-known environment variables, because the application parses them to locate the input, output, and model folders. Defaults are used if these environment variables are absent.

Set the environment variables corresponding to the extracted data path.

input_folder = "input"
output_folder = "output"
models_folder = "models"

# Choose a file as test input
test_input_path = image_files[0][0]
!rm -rf {input_folder} && mkdir -p {input_folder} && cp {test_input_path} {input_folder} && ls {input_folder}
# Need to copy the model file to its own clean subfolder for packaging, to work around an issue in the Packager
!rm -rf {models_folder} && mkdir -p {models_folder}/model && cp classifier.zip {models_folder}/model && ls {models_folder}/model

%env HOLOSCAN_INPUT_PATH {input_folder}
%env HOLOSCAN_OUTPUT_PATH {output_folder}
%env HOLOSCAN_MODEL_PATH {models_folder}
001420.jpeg
classifier.zip
env: HOLOSCAN_INPUT_PATH=input
env: HOLOSCAN_OUTPUT_PATH=output
env: HOLOSCAN_MODEL_PATH=models
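The application reads these variables through AppContext. The stdlib sketch below shows only the general "environment variable, else fallback" lookup. The helper name `resolve_app_paths` is hypothetical, and the fallback folder names in `ASSUMED_DEFAULTS` are assumptions chosen to match this notebook; they are not the SDK's documented defaults.

```python
import os
from pathlib import Path
from typing import Dict, Mapping, Optional

# Assumed fallback folders, for illustration only; the SDK's own defaults may differ.
ASSUMED_DEFAULTS = {
    "HOLOSCAN_INPUT_PATH": "input",
    "HOLOSCAN_OUTPUT_PATH": "output",
    "HOLOSCAN_MODEL_PATH": "models",
}


def resolve_app_paths(env: Optional[Mapping[str, str]] = None) -> Dict[str, Path]:
    """Return input/output/model folders from env vars, falling back to ASSUMED_DEFAULTS."""
    env = os.environ if env is None else env
    return {
        # Unset or empty variables fall back to the assumed default.
        key: Path(env.get(key) or default)
        for key, default in ASSUMED_DEFAULTS.items()
    }
```

After the `%env` cell above, `resolve_app_paths()["HOLOSCAN_MODEL_PATH"]` evaluates to `Path("models")`.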

Setup imports

Let’s import necessary classes/decorators and define MEDNIST_CLASSES.

import logging
import os
from pathlib import Path
from typing import Optional

import torch

from monai.deploy.conditions import CountCondition
from monai.deploy.core import AppContext, Application, ConditionType, Fragment, Image, Operator, OperatorSpec
from monai.deploy.operators.dicom_text_sr_writer_operator import DICOMTextSRWriterOperator, EquipmentInfo, ModelInfo
from monai.transforms import EnsureChannelFirst, Compose, EnsureType, ScaleIntensity

MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]

Creating Operator classes

LoadPILOperator
class LoadPILOperator(Operator):
    """Load image from the given input (DataPath) and set numpy array to the output (Image)."""

    DEFAULT_INPUT_FOLDER = Path.cwd() / "input"
    DEFAULT_OUTPUT_NAME = "image"

    # For now, need to have the input folder as an instance attribute, set on init.
    # If dynamically changing the input folder, per compute, then use a (optional) input port to convey the
    # value of the input folder, which is then emitted by an upstream operator.
    def __init__(
        self,
        fragment: Fragment,
        *args,
        input_folder: Path = DEFAULT_INPUT_FOLDER,
        output_name: str = DEFAULT_OUTPUT_NAME,
        **kwargs,
    ):
        """Creates a loader object with the input folder and the output port name overrides as needed.

        Args:
            fragment (Fragment): An instance of the Application class which is derived from Fragment.
            input_folder (Path): Folder from which to load input file(s).
                                 Defaults to `input` in the current working directory.
            output_name (str): Name of the output port, which is an image object. Defaults to `image`.
        """

        self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
        self.input_path = input_folder
        self.index = 0
        self.output_name_image = (
            output_name.strip() if output_name and len(output_name.strip()) > 0 else LoadPILOperator.DEFAULT_OUTPUT_NAME
        )

        super().__init__(fragment, *args, **kwargs)

    def setup(self, spec: OperatorSpec):
        """Set up the named input and output port(s)"""
        spec.output(self.output_name_image)

    def compute(self, op_input, op_output, context):
        import numpy as np
        from PIL import Image as PILImage

        # Input path is stored in the object attribute, but could change to use a named port if need be.
        input_path = self.input_path
        if input_path.is_dir():
            input_path = next(self.input_path.glob("*.*"))  # take the first file

        image = PILImage.open(input_path)
        image = image.convert("L")  # convert to greyscale image
        image_arr = np.asarray(image)

        output_image = Image(image_arr)  # create Image domain object with a numpy array
        op_output.emit(output_image, self.output_name_image)  # cannot omit the name even if single output.
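LoadPILOperator.compute picks its input with next(self.input_path.glob("*.*")). That pick depends on filesystem ordering, and an empty folder makes it raise StopIteration. The sketch below is a variant of the selection step that sorts the candidates and raises a clearer error. The helper `pick_input_file` is hypothetical and not part of the SDK.

```python
from pathlib import Path


def pick_input_file(input_path: Path) -> Path:
    """Return input_path itself if it is not a folder, else the first matching file in sorted order."""
    input_path = Path(input_path)
    if not input_path.is_dir():
        return input_path
    # Same "*.*" pattern as the operator, but sorted so the choice is deterministic.
    candidates = sorted(p for p in input_path.glob("*.*") if p.is_file())
    if not candidates:
        raise FileNotFoundError(f"No input file found in {input_path}")
    return candidates[0]
```

Inside compute, input_path = pick_input_file(self.input_path) could then replace the is_dir()/next(...) lines.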
MedNISTClassifierOperator
class MedNISTClassifierOperator(Operator):
    """Classifies the given image and returns the class name.

    Named inputs:
        image: Image object for which to generate the classification.
        output_folder: Optional, the path to save the results JSON file, overriding the one set on __init__

    Named output:
        result_text: The classification results in text.
    """

    DEFAULT_OUTPUT_FOLDER = Path.cwd() / "classification_results"
    # For testing the app directly, the model should be at the following path.
    MODEL_LOCAL_PATH = Path(os.environ.get("HOLOSCAN_MODEL_PATH", Path.cwd() / "model/model.ts"))

    def __init__(
        self,
        fragment: Fragment,
        *args,
        app_context: AppContext,
        model_name: Optional[str] = "",
        model_path: Path = MODEL_LOCAL_PATH,
        output_folder: Path = DEFAULT_OUTPUT_FOLDER,
        **kwargs,
    ):
        """Creates an instance with the reference back to the containing application/fragment.

        Args:
            fragment (Fragment): An instance of the Application class which is derived from Fragment.
            app_context (AppContext): The application context object, which may hold the loaded model(s).
            model_name (str, optional): Name of the model. Defaults to "" for a single-model app.
            model_path (Path): Path to the model file. Defaults to $HOLOSCAN_MODEL_PATH if set, else
                               model/model.ts in the current working dir.
            output_folder (Path, optional): Output folder for saving the classification results JSON file.
        """

        # the names used for the model inference input and output
        self._input_dataset_key = "image"
        self._pred_dataset_key = "pred"

        # The names used for the operator input and output
        self.input_name_image = "image"
        self.output_name_result = "result_text"

        # The name of the optional input port for passing data to override the output folder path.
        self.input_name_output_folder = "output_folder"

        # The output folder set on the object can be overridden at each compute by data in the optional named input
        self.output_folder = output_folder

        # Need the name when there are multiple models loaded
        self._model_name = model_name.strip() if isinstance(model_name, str) else ""
        # Need the path to load the models when they are not loaded in the execution context
        self.model_path = model_path
        self.app_context = app_context
        self.model = self._get_model(self.app_context, self.model_path, self._model_name)

        # This needs to be at the end of the constructor.
        super().__init__(fragment, *args, **kwargs)

    def _get_model(self, app_context: AppContext, model_path: Path, model_name: str):
        """Load the model with the given name from context or model path

        Args:
            app_context (AppContext): The application context object holding the model(s)
            model_path (Path): The path to the model file, as a backup to load model directly
            model_name (str): The name of the model, when multiples are loaded in the context
        """

        if app_context.models:
            # `app_context.models.get(model_name)` returns a model instance if exists.
            # If model_name is not specified and only one model exists, it returns that model.
            model = app_context.models.get(model_name)
        else:
            # Fall back to loading the TorchScript file from the model path passed in.
            model = torch.jit.load(
                model_path,
                map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            )

        return model

    def setup(self, spec: OperatorSpec):
        """Set up the operator named input and named output, both are in-memory objects."""

        spec.input(self.input_name_image)
        spec.input(self.input_name_output_folder).condition(ConditionType.NONE)  # Optional for overriding.
        spec.output(self.output_name_result).condition(ConditionType.NONE)  # Not forcing a downstream receiver.

    @property
    def transform(self):
        return Compose([EnsureChannelFirst(channel_dim="no_channel"), ScaleIntensity(), EnsureType()])

    def compute(self, op_input, op_output, context):
        import json

        import torch

        img = op_input.receive(self.input_name_image).asnumpy()  # (64, 64), uint8. Input validation can be added.
        image_tensor = self.transform(img)  # (1, 64, 64), torch.float64
        image_tensor = image_tensor[None].float()  # (1, 1, 64, 64), torch.float32

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        image_tensor = image_tensor.to(device)

        with torch.no_grad():
            outputs = self.model(image_tensor)

        _, output_classes = outputs.max(dim=1)

        result = MEDNIST_CLASSES[output_classes[0]]  # get the class name
        print(result)
        op_output.emit(result, self.output_name_result)

        # Get output folder, with value in optional input port overriding the obj attribute
        output_folder_on_compute = op_input.receive(self.input_name_output_folder) or self.output_folder
        Path.mkdir(output_folder_on_compute, parents=True, exist_ok=True)  # Let exception bubble up if raised.
        output_path = output_folder_on_compute / "output.json"
        with open(output_path, "w") as fp:
            json.dump(result, fp)
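The post-processing at the end of compute does not depend on PyTorch: take the index of the largest score along the class dimension, map that index to a class name, and write the name with json.dump. The stdlib sketch below shows those steps for a single image's scores. The helper names `logits_to_class_name` and `write_result_json` are hypothetical.

```python
import json
from pathlib import Path
from typing import Sequence

MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]


def logits_to_class_name(logits: Sequence[float], class_names: Sequence[str] = MEDNIST_CLASSES) -> str:
    """Return the class name with the highest score (ties go to the lowest index)."""
    if len(logits) != len(class_names):
        raise ValueError("expected one score per class")
    best = max(range(len(logits)), key=lambda i: logits[i])
    return class_names[best]


def write_result_json(result: str, output_folder: Path) -> Path:
    """Write the class name as a JSON string to output_folder/output.json and return the path."""
    output_folder = Path(output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)
    output_path = output_folder / "output.json"
    with open(output_path, "w") as fp:
        json.dump(result, fp)
    return output_path
```

Because json.dump serializes a bare string, the file contains the quoted class name, matching the "AbdomenCT" printed by cat after the application runs.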

Creating Application class

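Our application class looks like the following.

It defines an App class that inherits from the Application class.

In the compose() method of App, LoadPILOperator is connected to MedNISTClassifierOperator using self.add_flow(), and the classifier's result_text output is in turn connected to the text input of a DICOMTextSRWriterOperator, which saves the result as a DICOM Structured Report (SR) instance.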

class App(Application):
    """Application class for the MedNIST classifier."""

    def compose(self):
        app_context = Application.init_app_context({})  # Do not pass argv in Jupyter Notebook
        app_input_path = Path(app_context.input_path)
        app_output_path = Path(app_context.output_path)
        model_path = Path(app_context.model_path)
        load_pil_op = LoadPILOperator(self, CountCondition(self, 1), input_folder=app_input_path, name="pil_loader_op")
        classifier_op = MedNISTClassifierOperator(
            self, app_context=app_context, output_folder=app_output_path, model_path=model_path, name="classifier_op"
        )

        my_model_info = ModelInfo("MONAI WG Trainer", "MEDNIST Classifier", "0.1", "xyz")
        my_equipment = EquipmentInfo(manufacturer="MONAI Deploy App SDK", manufacturer_model="DICOM SR Writer")
        my_special_tags = {"SeriesDescription": "Not for clinical use. The result is for research use only."}
        dicom_sr_operator = DICOMTextSRWriterOperator(
            self,
            copy_tags=False,
            model_info=my_model_info,
            equipment_info=my_equipment,
            custom_tags=my_special_tags,
            output_folder=app_output_path,
        )

        self.add_flow(load_pil_op, classifier_op, {("image", "image")})
        self.add_flow(classifier_op, dicom_sr_operator, {("result_text", "text")})
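Each add_flow(upstream, downstream, {("out_port", "in_port")}) call in compose() maps a named output port of one operator to a named input port of another. To make that mapping concrete, the sketch below builds a toy, single-pass graph in plain Python. ToyGraph and its methods are hypothetical, and they say nothing about how the SDK actually schedules operators.

```python
from typing import Any, Callable, Dict, List, Set, Tuple

# A toy operator: a function from a dict of named inputs to a dict of named outputs.
ToyOp = Callable[[Dict[str, Any]], Dict[str, Any]]


class ToyGraph:
    """Toy illustration of port-to-port flows; NOT how the MONAI Deploy / Holoscan executor works."""

    def __init__(self) -> None:
        self.ops: Dict[str, ToyOp] = {}
        self.flows: List[Tuple[str, str, Set[Tuple[str, str]]]] = []

    def add_op(self, name: str, op: ToyOp) -> None:
        self.ops[name] = op

    def add_flow(self, src: str, dst: str, port_pairs: Set[Tuple[str, str]]) -> None:
        # Each (src_port, dst_port) pair routes one named output to one named input.
        self.flows.append((src, dst, port_pairs))

    def run(self, order: List[str]) -> Dict[str, Dict[str, Any]]:
        """Run each operator once in the given order, forwarding outputs along the declared flows."""
        inputs: Dict[str, Dict[str, Any]] = {name: {} for name in self.ops}
        outputs: Dict[str, Dict[str, Any]] = {}
        for name in order:
            outputs[name] = self.ops[name](inputs[name])
            for src, dst, pairs in self.flows:
                if src == name:
                    for src_port, dst_port in pairs:
                        inputs[dst][dst_port] = outputs[name][src_port]
        return outputs
```

Wiring three toy operators with the same port names as App (image to image, result_text to text) delivers the classifier's text to the last operator in one pass.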

Executing app locally

The test input file, output path, and model have been prepared, and their paths set in the environment variables, so we can go ahead and execute the application in the Jupyter notebook with a clean output folder.

!rm -rf $HOLOSCAN_OUTPUT_PATH
app = App().run()
[2024-04-10 16:35:53,768] [INFO] (root) - Parsed args: Namespace(log_level=None, input=None, output=None, model=None, workdir=None, argv=[])
[2024-04-10 16:35:53,778] [INFO] (root) - AppContext object: AppContext(input_path=input, output_path=output, model_path=models, workdir=)
[info] [gxf_executor.cpp:211] Creating context
[info] [gxf_executor.cpp:1674] Loading extensions from configs...
[info] [gxf_executor.cpp:1864] Activating Graph...
[info] [gxf_executor.cpp:1894] Running Graph...
[info] [gxf_executor.cpp:1896] Waiting for completion...
[info] [gxf_executor.cpp:1897] Graph execution waiting. Fragment: 
[info] [greedy_scheduler.cpp:190] Scheduling 3 entities
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.10/site-packages/monai/data/meta_tensor.py:116: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
  return torch.as_tensor(x, *args, **_kwargs).as_subclass(cls)
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.10/site-packages/pydicom/valuerep.py:443: UserWarning: Invalid value for VR UI: 'xyz'. Please see <https://dicom.nema.org/medical/dicom/current/output/html/part05.html#table_6.2-1> for allowed values for each VR.
  warnings.warn(msg)
[2024-04-10 16:35:54,545] [INFO] (root) - Finished writing DICOM instance to file output/1.2.826.0.1.3680043.8.498.91196297255331853052707757292596626343.dcm
[2024-04-10 16:35:54,548] [INFO] (monai.deploy.operators.dicom_text_sr_writer_operator.DICOMTextSRWriterOperator) - DICOM SOP instance saved in output/1.2.826.0.1.3680043.8.498.91196297255331853052707757292596626343.dcm
AbdomenCT
[info] [greedy_scheduler.cpp:369] Scheduler stopped: Some entities are waiting for execution, but there are no periodic or async entities to get out of the deadlock.
[info] [greedy_scheduler.cpp:398] Scheduler finished.
[info] [gxf_executor.cpp:1906] Graph execution deactivating. Fragment: 
[info] [gxf_executor.cpp:1907] Deactivating Graph...
[info] [gxf_executor.cpp:1910] Graph execution finished. Fragment: 
[info] [gxf_executor.cpp:230] Destroying context
!cat $HOLOSCAN_OUTPUT_PATH/output.json
"AbdomenCT"

Once the application is verified inside the Jupyter notebook, we can write the whole application to a file (mednist_classifier_monaideploy.py) by concatenating the code above, then adding the following lines:

if __name__ == "__main__":
    App().run()

These lines are needed so that running the file with the Python interpreter actually executes the application.

# Create an application folder
!mkdir -p mednist_app
!rm -rf mednist_app/*
%%writefile mednist_app/mednist_classifier_monaideploy.py

# Copyright 2021-2023 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from pathlib import Path
from typing import Optional

import torch

from monai.deploy.conditions import CountCondition
from monai.deploy.core import AppContext, Application, ConditionType, Fragment, Image, Operator, OperatorSpec
from monai.deploy.operators.dicom_text_sr_writer_operator import DICOMTextSRWriterOperator, EquipmentInfo, ModelInfo
from monai.transforms import EnsureChannelFirst, Compose, EnsureType, ScaleIntensity

MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]


# @md.env(pip_packages=["pillow"])
class LoadPILOperator(Operator):
    """Load image from the given input (DataPath) and set numpy array to the output (Image)."""

    DEFAULT_INPUT_FOLDER = Path.cwd() / "input"
    DEFAULT_OUTPUT_NAME = "image"

    # For now, need to have the input folder as an instance attribute, set on init.
    # If dynamically changing the input folder, per compute, then use a (optional) input port to convey the
    # value of the input folder, which is then emitted by an upstream operator.
    def __init__(
        self,
        fragment: Fragment,
        *args,
        input_folder: Path = DEFAULT_INPUT_FOLDER,
        output_name: str = DEFAULT_OUTPUT_NAME,
        **kwargs,
    ):
        """Creates a loader object with the input folder and the output port name overrides as needed.

        Args:
            fragment (Fragment): An instance of the Application class which is derived from Fragment.
            input_folder (Path): Folder from which to load input file(s).
                                 Defaults to `input` in the current working directory.
            output_name (str): Name of the output port, which is an image object. Defaults to `image`.
        """

        self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
        self.input_path = input_folder
        self.index = 0
        self.output_name_image = (
            output_name.strip() if output_name and len(output_name.strip()) > 0 else LoadPILOperator.DEFAULT_OUTPUT_NAME
        )

        super().__init__(fragment, *args, **kwargs)

    def setup(self, spec: OperatorSpec):
        """Set up the named input and output port(s)"""
        spec.output(self.output_name_image)

    def compute(self, op_input, op_output, context):
        import numpy as np
        from PIL import Image as PILImage

        # Input path is stored in the object attribute, but could change to use a named port if need be.
        input_path = self.input_path
        if input_path.is_dir():
            input_path = next(self.input_path.glob("*.*"))  # take the first file

        image = PILImage.open(input_path)
        image = image.convert("L")  # convert to greyscale image
        image_arr = np.asarray(image)

        output_image = Image(image_arr)  # create Image domain object with a numpy array
        op_output.emit(output_image, self.output_name_image)  # cannot omit the name even if single output.


# @md.env(pip_packages=["monai"])
class MedNISTClassifierOperator(Operator):
    """Classifies the given image and returns the class name.

    Named inputs:
        image: Image object for which to generate the classification.
        output_folder: Optional, the path to save the results JSON file, overriding the one set on __init__

    Named output:
        result_text: The classification results in text.
    """

    DEFAULT_OUTPUT_FOLDER = Path.cwd() / "classification_results"
    # For testing the app directly, the model should be at the following path.
    MODEL_LOCAL_PATH = Path(os.environ.get("HOLOSCAN_MODEL_PATH", Path.cwd() / "model/model.ts"))

    def __init__(
        self,
        fragment: Fragment,
        *args,
        app_context: AppContext,
        model_name: Optional[str] = "",
        model_path: Path = MODEL_LOCAL_PATH,
        output_folder: Path = DEFAULT_OUTPUT_FOLDER,
        **kwargs,
    ):
        """Creates an instance with the reference back to the containing application/fragment.

        Args:
            fragment (Fragment): An instance of the Application class which is derived from Fragment.
            app_context (AppContext): The application context object, which may hold the loaded model(s).
            model_name (str, optional): Name of the model. Defaults to "" for a single-model app.
            model_path (Path): Path to the model file. Defaults to $HOLOSCAN_MODEL_PATH if set, else
                               model/model.ts in the current working dir.
            output_folder (Path, optional): Output folder for saving the classification results JSON file.
        """

        # the names used for the model inference input and output
        self._input_dataset_key = "image"
        self._pred_dataset_key = "pred"

        # The names used for the operator input and output
        self.input_name_image = "image"
        self.output_name_result = "result_text"

        # The name of the optional input port for passing data to override the output folder path.
        self.input_name_output_folder = "output_folder"

        # The output folder set on the object can be overridden at each compute by data in the optional named input
        self.output_folder = output_folder

        # Need the name when there are multiple models loaded
        self._model_name = model_name.strip() if isinstance(model_name, str) else ""
        # Need the path to load the models when they are not loaded in the execution context
        self.model_path = model_path
        self.app_context = app_context
        self.model = self._get_model(self.app_context, self.model_path, self._model_name)

        # This needs to be at the end of the constructor.
        super().__init__(fragment, *args, **kwargs)

    def _get_model(self, app_context: AppContext, model_path: Path, model_name: str):
        """Load the model with the given name from context or model path

        Args:
            app_context (AppContext): The application context object holding the model(s)
            model_path (Path): The path to the model file, as a backup to load model directly
            model_name (str): The name of the model, when multiples are loaded in the context
        """

        if app_context.models:
            # `app_context.models.get(model_name)` returns a model instance if exists.
            # If model_name is not specified and only one model exists, it returns that model.
            model = app_context.models.get(model_name)
        else:
            # Fall back to loading the TorchScript file from the model path passed in.
            model = torch.jit.load(
                model_path,
                map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            )

        return model

    def setup(self, spec: OperatorSpec):
        """Set up the operator named input and named output, both are in-memory objects."""

        spec.input(self.input_name_image)
        spec.input(self.input_name_output_folder).condition(ConditionType.NONE)  # Optional for overriding.
        spec.output(self.output_name_result).condition(ConditionType.NONE)  # Not forcing a downstream receiver.

    @property
    def transform(self):
        return Compose([EnsureChannelFirst(channel_dim="no_channel"), ScaleIntensity(), EnsureType()])

    def compute(self, op_input, op_output, context):
        import json

        import torch

        img = op_input.receive(self.input_name_image).asnumpy()  # (64, 64), uint8. Input validation can be added.
        image_tensor = self.transform(img)  # (1, 64, 64), torch.float64
        image_tensor = image_tensor[None].float()  # (1, 1, 64, 64), torch.float32

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        image_tensor = image_tensor.to(device)

        with torch.no_grad():
            outputs = self.model(image_tensor)

        _, output_classes = outputs.max(dim=1)

        result = MEDNIST_CLASSES[output_classes[0]]  # get the class name
        print(result)
        op_output.emit(result, self.output_name_result)

        # Get output folder, with value in optional input port overriding the obj attribute
        output_folder_on_compute = Path(op_input.receive(self.input_name_output_folder) or self.output_folder)
        output_folder_on_compute.mkdir(parents=True, exist_ok=True)  # Let exception bubble up if raised.
        output_path = output_folder_on_compute / "output.json"
        with open(output_path, "w") as fp:
            json.dump(result, fp)


# @md.resource(cpu=1, gpu=1, memory="1Gi")
class App(Application):
    """Application class for the MedNIST classifier."""

    def compose(self):
        app_context = AppContext({})  # Let it figure out all the attributes without overriding
        app_input_path = Path(app_context.input_path)
        app_output_path = Path(app_context.output_path)
        model_path = Path(app_context.model_path)
        load_pil_op = LoadPILOperator(self, CountCondition(self, 1), input_folder=app_input_path, name="pil_loader_op")
        classifier_op = MedNISTClassifierOperator(
            self, app_context=app_context, output_folder=app_output_path, model_path=model_path, name="classifier_op"
        )

        my_model_info = ModelInfo("MONAI WG Trainer", "MEDNIST Classifier", "0.1", "xyz")
        my_equipment = EquipmentInfo(manufacturer="MONAI Deploy App SDK", manufacturer_model="DICOM SR Writer")
        my_special_tags = {"SeriesDescription": "Not for clinical use. The result is for research use only."}
        dicom_sr_operator = DICOMTextSRWriterOperator(
            self,
            copy_tags=False,
            model_info=my_model_info,
            equipment_info=my_equipment,
            custom_tags=my_special_tags,
            output_folder=app_output_path,
        )

        self.add_flow(load_pil_op, classifier_op, {("image", "image")})
        self.add_flow(classifier_op, dicom_sr_operator, {("result_text", "text")})


if __name__ == "__main__":
    App().run()
Writing mednist_app/mednist_classifier_monaideploy.py
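The operator's `transform` property and the `[None].float()` step in `compute()` together turn a (64, 64) uint8 image into a (1, 1, 64, 64) float32 batch. The NumPy sketch below approximates that shape and intensity handling so you can check it without MONAI or PyTorch. It is not the MONAI implementation. It assumes that `ScaleIntensity()` with default arguments performs min-max scaling into [0, 1], and that a constant image maps to all zeros.

```python
import numpy as np


def preprocess(img: np.ndarray) -> np.ndarray:
    """Approximate the operator's transform plus the batch axis added in compute()."""
    if img.ndim != 2:
        raise ValueError(f"expected a 2D (H, W) image, got shape {img.shape}")
    # EnsureChannelFirst(channel_dim="no_channel"): (H, W) -> (1, H, W)
    chw = img[np.newaxis, ...].astype(np.float64)
    # ScaleIntensity() with its default range: min-max scale into [0, 1]
    lo, hi = chw.min(), chw.max()
    scaled = (chw - lo) / (hi - lo) if hi > lo else np.zeros_like(chw)
    # image_tensor[None].float(): (1, H, W) -> (1, 1, H, W) as float32
    return scaled[np.newaxis, ...].astype(np.float32)


# A random stand-in for a MedNIST image (64 x 64, uint8)
example = np.random.default_rng(0).integers(0, 256, size=(64, 64), dtype=np.uint8)
batch = preprocess(example)
print(batch.shape, batch.dtype)  # (1, 1, 64, 64) float32
```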

This time, let's execute the app from the command line.

!python "mednist_app/mednist_classifier_monaideploy.py"
[info] [gxf_executor.cpp:211] Creating context
[info] [gxf_executor.cpp:1674] Loading extensions from configs...
[info] [gxf_executor.cpp:1864] Activating Graph...
[info] [gxf_executor.cpp:1894] Running Graph...
[info] [gxf_executor.cpp:1896] Waiting for completion...
[info] [gxf_executor.cpp:1897] Graph execution waiting. Fragment: 
[info] [greedy_scheduler.cpp:190] Scheduling 3 entities
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.10/site-packages/monai/data/meta_tensor.py:116: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
  return torch.as_tensor(x, *args, **_kwargs).as_subclass(cls)
AbdomenCT
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.10/site-packages/pydicom/valuerep.py:443: UserWarning: Invalid value for VR UI: 'xyz'. Please see <https://dicom.nema.org/medical/dicom/current/output/html/part05.html#table_6.2-1> for allowed values for each VR.
  warnings.warn(msg)
[info] [greedy_scheduler.cpp:369] Scheduler stopped: Some entities are waiting for execution, but there are no periodic or async entities to get out of the deadlock.
[info] [greedy_scheduler.cpp:398] Scheduler finished.
[info] [gxf_executor.cpp:1906] Graph execution deactivating. Fragment: 
[info] [gxf_executor.cpp:1907] Deactivating Graph...
[info] [gxf_executor.cpp:1910] Graph execution finished. Fragment: 
[info] [gxf_executor.cpp:230] Destroying context
!cat $HOLOSCAN_OUTPUT_PATH/output.json
"AbdomenCT"
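The printed class and the content of `output.json` both come from the last steps of `compute()`. The code takes the index of the highest score, looks it up in the class list, and writes the name as a JSON string. A folder received on the optional input port takes precedence over the constructor's `output_folder`.

The plain-Python sketch below mirrors that logic. The scores are hypothetical, and `CLASS_NAMES` is assumed to match the `MEDNIST_CLASSES` list defined earlier in the notebook.

```python
import json
import tempfile
from pathlib import Path
from typing import Optional, Sequence

# Assumption: same order as the MEDNIST_CLASSES list defined earlier in the notebook.
CLASS_NAMES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]


def logits_to_label(logits: Sequence[float], class_names: Sequence[str] = CLASS_NAMES) -> str:
    """Return the class name with the highest score for a single image."""
    if len(logits) != len(class_names):
        raise ValueError("one score per class is expected")
    best = max(range(len(logits)), key=lambda i: logits[i])
    return class_names[best]


def write_result(label: str, default_folder: Path, override_folder: Optional[Path] = None) -> Path:
    """Write the label as JSON; an override folder (the optional input port) wins over the default."""
    folder = Path(override_folder or default_folder)
    folder.mkdir(parents=True, exist_ok=True)
    out_path = folder / "output.json"
    out_path.write_text(json.dumps(label))
    return out_path


with tempfile.TemporaryDirectory() as tmp:
    label = logits_to_label([3.2, -1.0, 0.4, 1.1, -2.5, 0.0])  # hypothetical scores
    print(write_result(label, Path(tmp) / "output").read_text())  # "AbdomenCT"
```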

Packaging app

Let's package the app with the MONAI Application Packager.

In this version of the App SDK, we need to write out both the configuration YAML file and the package requirements file in the application folder.

%%writefile mednist_app/app.yaml
%YAML 1.2
---
application:
  title: MONAI Deploy App Package - MedNIST Classifier App
  version: 1.0
  inputFormats: ["file"]
  outputFormats: ["file"]

resources:
  cpu: 1
  gpu: 1
  memory: 1Gi
  gpuMemory: 1Gi
Writing mednist_app/app.yaml
%%writefile mednist_app/requirements.txt
monai>=1.2.0
Pillow>=8.4.0
pydicom>=2.3.0
highdicom>=0.18.2
SimpleITK>=2.0.0
setuptools>=59.5.0 # for pkg_resources
Writing mednist_app/requirements.txt
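The `%%writefile` magic only works inside Jupyter. If you run these steps from a plain Python script, you can write the same two files with `pathlib`. The sketch below copies the file contents verbatim from the two cells above. The function name `write_packaging_files` is hypothetical.

```python
import tempfile
from pathlib import Path
from typing import List

APP_YAML = """\
%YAML 1.2
---
application:
  title: MONAI Deploy App Package - MedNIST Classifier App
  version: 1.0
  inputFormats: ["file"]
  outputFormats: ["file"]

resources:
  cpu: 1
  gpu: 1
  memory: 1Gi
  gpuMemory: 1Gi
"""

REQUIREMENTS_TXT = """\
monai>=1.2.0
Pillow>=8.4.0
pydicom>=2.3.0
highdicom>=0.18.2
SimpleITK>=2.0.0
setuptools>=59.5.0 # for pkg_resources
"""


def write_packaging_files(app_dir: Path) -> List[Path]:
    """Write app.yaml and requirements.txt into the application folder, like the two %%writefile cells."""
    app_dir.mkdir(parents=True, exist_ok=True)
    written = []
    for name, text in (("app.yaml", APP_YAML), ("requirements.txt", REQUIREMENTS_TXT)):
        path = app_dir / name
        path.write_text(text)
        written.append(path)
    return written


# Demo in a throwaway folder; pass Path("mednist_app") to write next to the app script instead.
demo_files = write_packaging_files(Path(tempfile.mkdtemp()) / "mednist_app")
print([p.name for p in demo_files])  # ['app.yaml', 'requirements.txt']
```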
tag_prefix = "mednist_app"

!monai-deploy package "mednist_app/mednist_classifier_monaideploy.py" -m {models_folder} -c "mednist_app/app.yaml" -t {tag_prefix}:1.0 --platform x64-workstation -l DEBUG
[2024-04-10 16:36:05,007] [INFO] (packager.parameters) - Application: /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/mednist_app/mednist_classifier_monaideploy.py
[2024-04-10 16:36:05,007] [INFO] (packager.parameters) - Detected application type: Python File
[2024-04-10 16:36:05,007] [INFO] (packager) - Scanning for models in /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/models...
[2024-04-10 16:36:05,007] [DEBUG] (packager) - Model model=/home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/models/model added.
[2024-04-10 16:36:05,007] [INFO] (packager) - Reading application configuration from /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/mednist_app/app.yaml...
[2024-04-10 16:36:05,009] [INFO] (packager) - Generating app.json...
[2024-04-10 16:36:05,009] [INFO] (packager) - Generating pkg.json...
[2024-04-10 16:36:05,015] [DEBUG] (common) - 
=============== Begin app.json ===============
{
    "apiVersion": "1.0.0",
    "command": "[\"python3\", \"/opt/holoscan/app/mednist_classifier_monaideploy.py\"]",
    "environment": {
        "HOLOSCAN_APPLICATION": "/opt/holoscan/app",
        "HOLOSCAN_INPUT_PATH": "input/",
        "HOLOSCAN_OUTPUT_PATH": "output/",
        "HOLOSCAN_WORKDIR": "/var/holoscan",
        "HOLOSCAN_MODEL_PATH": "/opt/holoscan/models",
        "HOLOSCAN_CONFIG_PATH": "/var/holoscan/app.yaml",
        "HOLOSCAN_APP_MANIFEST_PATH": "/etc/holoscan/app.json",
        "HOLOSCAN_PKG_MANIFEST_PATH": "/etc/holoscan/pkg.json",
        "HOLOSCAN_DOCS_PATH": "/opt/holoscan/docs",
        "HOLOSCAN_LOGS_PATH": "/var/holoscan/logs"
    },
    "input": {
        "path": "input/",
        "formats": null
    },
    "liveness": null,
    "output": {
        "path": "output/",
        "formats": null
    },
    "readiness": null,
    "sdk": "monai-deploy",
    "sdkVersion": "0.5.1",
    "timeout": 0,
    "version": 1.0,
    "workingDirectory": "/var/holoscan"
}
================ End app.json ================
                 
[2024-04-10 16:36:05,015] [DEBUG] (common) - 
=============== Begin pkg.json ===============
{
    "apiVersion": "1.0.0",
    "applicationRoot": "/opt/holoscan/app",
    "modelRoot": "/opt/holoscan/models",
    "models": {
        "model": "/opt/holoscan/models/model"
    },
    "resources": {
        "cpu": 1,
        "gpu": 1,
        "memory": "1Gi",
        "gpuMemory": "1Gi"
    },
    "version": 1.0,
    "platformConfig": "dgpu"
}
================ End pkg.json ================
                 
[2024-04-10 16:36:05,050] [DEBUG] (packager.builder) - 
========== Begin Dockerfile ==========


FROM nvcr.io/nvidia/clara-holoscan/holoscan:v1.0.3-dgpu

ENV DEBIAN_FRONTEND=noninteractive
ENV TERM=xterm-256color

ARG UNAME
ARG UID
ARG GID

RUN mkdir -p /etc/holoscan/ \
        && mkdir -p /opt/holoscan/ \
        && mkdir -p /var/holoscan \
        && mkdir -p /opt/holoscan/app \
        && mkdir -p /var/holoscan/input \
        && mkdir -p /var/holoscan/output

LABEL base="nvcr.io/nvidia/clara-holoscan/holoscan:v1.0.3-dgpu"
LABEL tag="mednist_app:1.0"
LABEL org.opencontainers.image.title="MONAI Deploy App Package - MedNIST Classifier App"
LABEL org.opencontainers.image.version="1.0"
LABEL org.nvidia.holoscan="1.0.3"
LABEL org.monai.deploy.app-sdk="0.5.1"


ENV HOLOSCAN_ENABLE_HEALTH_CHECK=true
ENV HOLOSCAN_INPUT_PATH=/var/holoscan/input
ENV HOLOSCAN_OUTPUT_PATH=/var/holoscan/output
ENV HOLOSCAN_WORKDIR=/var/holoscan
ENV HOLOSCAN_APPLICATION=/opt/holoscan/app
ENV HOLOSCAN_TIMEOUT=0
ENV HOLOSCAN_MODEL_PATH=/opt/holoscan/models
ENV HOLOSCAN_DOCS_PATH=/opt/holoscan/docs
ENV HOLOSCAN_CONFIG_PATH=/var/holoscan/app.yaml
ENV HOLOSCAN_APP_MANIFEST_PATH=/etc/holoscan/app.json
ENV HOLOSCAN_PKG_MANIFEST_PATH=/etc/holoscan/pkg.json
ENV HOLOSCAN_LOGS_PATH=/var/holoscan/logs
ENV PATH=/root/.local/bin:/opt/nvidia/holoscan:$PATH
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/libtorch/1.13.1/lib/:/opt/nvidia/holoscan/lib

RUN apt-get update \
    && apt-get install -y curl jq \
    && rm -rf /var/lib/apt/lists/*

ENV PYTHONPATH="/opt/holoscan/app:$PYTHONPATH"



RUN groupadd -f -g $GID $UNAME
RUN useradd -rm -d /home/$UNAME -s /bin/bash -g $GID -G sudo -u $UID $UNAME
RUN chown -R holoscan /var/holoscan 
RUN chown -R holoscan /var/holoscan/input 
RUN chown -R holoscan /var/holoscan/output 

# Set the working directory
WORKDIR /var/holoscan

# Copy HAP/MAP tool script
COPY ./tools /var/holoscan/tools
RUN chmod +x /var/holoscan/tools


# Copy gRPC health probe

USER $UNAME

ENV PATH=/root/.local/bin:/home/holoscan/.local/bin:/opt/nvidia/holoscan:$PATH

COPY ./pip/requirements.txt /tmp/requirements.txt

RUN pip install --upgrade pip
RUN pip install --no-cache-dir --user -r /tmp/requirements.txt

# Install Holoscan from PyPI only when sdk_type is Holoscan. 
# For MONAI Deploy, the APP SDK will install it unless user specifies the Holoscan SDK file.

# Copy user-specified MONAI Deploy SDK file
COPY ./monai_deploy_app_sdk-0.5.1+25.g31e4165.dirty-py3-none-any.whl /tmp/monai_deploy_app_sdk-0.5.1+25.g31e4165.dirty-py3-none-any.whl
RUN pip install /tmp/monai_deploy_app_sdk-0.5.1+25.g31e4165.dirty-py3-none-any.whl




COPY ./models  /opt/holoscan/models

COPY ./map/app.json /etc/holoscan/app.json
COPY ./app.config /var/holoscan/app.yaml
COPY ./map/pkg.json /etc/holoscan/pkg.json

COPY ./app /opt/holoscan/app

ENTRYPOINT ["/var/holoscan/tools"]
=========== End Dockerfile ===========

[2024-04-10 16:36:05,050] [INFO] (packager.builder) - 
===============================================================================
Building image for:                 x64-workstation
    Architecture:                   linux/amd64
    Base Image:                     nvcr.io/nvidia/clara-holoscan/holoscan:v1.0.3-dgpu
    Build Image:                    N/A
    Cache:                          Enabled
    Configuration:                  dgpu
    Holoscan SDK Package:           pypi.org
    MONAI Deploy App SDK Package:   /home/mqin/src/monai-deploy-app-sdk/dist/monai_deploy_app_sdk-0.5.1+25.g31e4165.dirty-py3-none-any.whl
    gRPC Health Probe:              N/A
    SDK Version:                    1.0.3
    SDK:                            monai-deploy
    Tag:                            mednist_app-x64-workstation-dgpu-linux-amd64:1.0
    
[2024-04-10 16:36:05,416] [INFO] (common) - Using existing Docker BuildKit builder `holoscan_app_builder`
[2024-04-10 16:36:05,416] [DEBUG] (packager.builder) - Building Holoscan Application Package: tag=mednist_app-x64-workstation-dgpu-linux-amd64:1.0
#0 building with "holoscan_app_builder" instance using docker-container driver

#1 [internal] load build definition from Dockerfile
#1 transferring dockerfile: 2.81kB done
#1 DONE 0.1s

#2 [internal] load metadata for nvcr.io/nvidia/clara-holoscan/holoscan:v1.0.3-dgpu
#2 DONE 0.5s

#3 [internal] load .dockerignore
#3 transferring context: 1.79kB done
#3 DONE 0.1s

#4 importing cache manifest from nvcr.io/nvidia/clara-holoscan/holoscan:v1.0.3-dgpu
#4 ...

#5 [internal] load build context
#5 DONE 0.0s

#6 importing cache manifest from local:12491137658764693548
#6 inferred cache manifest type: application/vnd.oci.image.index.v1+json done
#6 DONE 0.0s

#7 [ 1/21] FROM nvcr.io/nvidia/clara-holoscan/holoscan:v1.0.3-dgpu@sha256:50343c616bf910e2a7651abb59db7833933e82cce64c3c4885f938d7e4af6155
#7 resolve nvcr.io/nvidia/clara-holoscan/holoscan:v1.0.3-dgpu@sha256:50343c616bf910e2a7651abb59db7833933e82cce64c3c4885f938d7e4af6155 0.0s done
#7 DONE 0.0s

#4 importing cache manifest from nvcr.io/nvidia/clara-holoscan/holoscan:v1.0.3-dgpu
#4 inferred cache manifest type: application/vnd.docker.distribution.manifest.list.v2+json done
#4 DONE 0.5s

#5 [internal] load build context
#5 transferring context: 28.75MB 0.2s done
#5 DONE 0.2s

#8 [13/21] RUN pip install --upgrade pip
#8 CACHED

#9 [10/21] COPY ./tools /var/holoscan/tools
#9 CACHED

#10 [ 9/21] WORKDIR /var/holoscan
#10 CACHED

#11 [ 7/21] RUN chown -R holoscan /var/holoscan/input
#11 CACHED

#12 [ 5/21] RUN useradd -rm -d /home/holoscan -s /bin/bash -g 1000 -G sudo -u 1000 holoscan
#12 CACHED

#13 [ 8/21] RUN chown -R holoscan /var/holoscan/output
#13 CACHED

#14 [ 3/21] RUN apt-get update     && apt-get install -y curl jq     && rm -rf /var/lib/apt/lists/*
#14 CACHED

#15 [12/21] COPY ./pip/requirements.txt /tmp/requirements.txt
#15 CACHED

#16 [11/21] RUN chmod +x /var/holoscan/tools
#16 CACHED

#17 [ 2/21] RUN mkdir -p /etc/holoscan/         && mkdir -p /opt/holoscan/         && mkdir -p /var/holoscan         && mkdir -p /opt/holoscan/app         && mkdir -p /var/holoscan/input         && mkdir -p /var/holoscan/output
#17 CACHED

#18 [ 4/21] RUN groupadd -f -g 1000 holoscan
#18 CACHED

#19 [ 6/21] RUN chown -R holoscan /var/holoscan
#19 CACHED

#20 [14/21] RUN pip install --no-cache-dir --user -r /tmp/requirements.txt
#20 CACHED

#21 [15/21] COPY ./monai_deploy_app_sdk-0.5.1+25.g31e4165.dirty-py3-none-any.whl /tmp/monai_deploy_app_sdk-0.5.1+25.g31e4165.dirty-py3-none-any.whl
#21 DONE 0.3s

#22 [16/21] RUN pip install /tmp/monai_deploy_app_sdk-0.5.1+25.g31e4165.dirty-py3-none-any.whl
#22 0.701 Defaulting to user installation because normal site-packages is not writeable
#22 0.799 Processing /tmp/monai_deploy_app_sdk-0.5.1+25.g31e4165.dirty-py3-none-any.whl
#22 0.810 Requirement already satisfied: numpy>=1.21.6 in /usr/local/lib/python3.10/dist-packages (from monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (1.23.5)
#22 0.996 Collecting holoscan~=1.0 (from monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty)
#22 1.066   Downloading holoscan-1.0.3-cp310-cp310-manylinux_2_35_x86_64.whl.metadata (4.1 kB)
#22 1.137 Collecting colorama>=0.4.1 (from monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty)
#22 1.141   Downloading colorama-0.4.6-py2.py3-none-any.whl.metadata (17 kB)
#22 1.222 Collecting typeguard>=3.0.0 (from monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty)
#22 1.227   Downloading typeguard-4.2.1-py3-none-any.whl.metadata (3.7 kB)
#22 1.322 Collecting pip==23.3.2 (from holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty)
#22 1.327   Downloading pip-23.3.2-py3-none-any.whl.metadata (3.5 kB)
#22 1.343 Requirement already satisfied: cupy-cuda12x==12.2 in /usr/local/lib/python3.10/dist-packages (from holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (12.2.0)
#22 1.343 Requirement already satisfied: cloudpickle==2.2.1 in /usr/local/lib/python3.10/dist-packages (from holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (2.2.1)
#22 1.344 Requirement already satisfied: python-on-whales==0.60.1 in /usr/local/lib/python3.10/dist-packages (from holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (0.60.1)
#22 1.346 Requirement already satisfied: Jinja2==3.1.2 in /usr/local/lib/python3.10/dist-packages (from holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (3.1.2)
#22 1.346 Requirement already satisfied: packaging==23.1 in /usr/local/lib/python3.10/dist-packages (from holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (23.1)
#22 1.347 Requirement already satisfied: pyyaml==6.0 in /usr/local/lib/python3.10/dist-packages (from holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (6.0)
#22 1.348 Requirement already satisfied: requests==2.28.2 in /usr/local/lib/python3.10/dist-packages (from holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (2.28.2)
#22 1.350 Requirement already satisfied: psutil==5.9.6 in /usr/local/lib/python3.10/dist-packages (from holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (5.9.6)
#22 1.461 Collecting wheel-axle-runtime<1.0 (from holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty)
#22 1.467   Downloading wheel_axle_runtime-0.0.5-py3-none-any.whl.metadata (7.7 kB)
#22 1.504 Requirement already satisfied: fastrlock>=0.5 in /usr/local/lib/python3.10/dist-packages (from cupy-cuda12x==12.2->holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (0.8.2)
#22 1.508 Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2==3.1.2->holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (2.1.3)
#22 1.523 Requirement already satisfied: pydantic<2,>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-on-whales==0.60.1->holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (1.10.14)
#22 1.524 Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from python-on-whales==0.60.1->holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (4.66.1)
#22 1.525 Requirement already satisfied: typer>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from python-on-whales==0.60.1->holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (0.9.0)
#22 1.526 Requirement already satisfied: typing-extensions in /home/holoscan/.local/lib/python3.10/site-packages (from python-on-whales==0.60.1->holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (4.10.0)
#22 1.533 Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests==2.28.2->holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (3.3.2)
#22 1.534 Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests==2.28.2->holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (3.6)
#22 1.535 Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests==2.28.2->holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (1.26.18)
#22 1.536 Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests==2.28.2->holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (2023.11.17)
#22 1.551 Requirement already satisfied: filelock in /home/holoscan/.local/lib/python3.10/site-packages (from wheel-axle-runtime<1.0->holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (3.13.3)
#22 1.587 Requirement already satisfied: click<9.0.0,>=7.1.1 in /usr/local/lib/python3.10/dist-packages (from typer>=0.4.1->python-on-whales==0.60.1->holoscan~=1.0->monai-deploy-app-sdk==0.5.1+25.g31e4165.dirty) (8.1.7)
#22 1.630 Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
#22 1.653 Downloading holoscan-1.0.3-cp310-cp310-manylinux_2_35_x86_64.whl (33.6 MB)
#22 2.391    ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 33.6/33.6 MB 36.3 MB/s eta 0:00:00
#22 2.399 Downloading pip-23.3.2-py3-none-any.whl (2.1 MB)
#22 2.452    ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 45.3 MB/s eta 0:00:00
#22 2.459 Downloading typeguard-4.2.1-py3-none-any.whl (34 kB)
#22 2.484 Downloading wheel_axle_runtime-0.0.5-py3-none-any.whl (12 kB)
#22 2.841 Installing collected packages: wheel-axle-runtime, typeguard, pip, colorama, holoscan, monai-deploy-app-sdk
#22 2.915   Attempting uninstall: pip
#22 2.916     Found existing installation: pip 24.0
#22 2.967     Uninstalling pip-24.0:
#22 3.365       Successfully uninstalled pip-24.0
#22 5.013 Successfully installed colorama-0.4.6 holoscan-1.0.3 monai-deploy-app-sdk-0.5.1+25.g31e4165.dirty pip-23.3.2 typeguard-4.2.1 wheel-axle-runtime-0.0.5
#22 DONE 5.6s

#23 [17/21] COPY ./models  /opt/holoscan/models
#23 DONE 0.3s

#24 [18/21] COPY ./map/app.json /etc/holoscan/app.json
#24 DONE 0.1s

#25 [19/21] COPY ./app.config /var/holoscan/app.yaml
#25 DONE 0.1s

#26 [20/21] COPY ./map/pkg.json /etc/holoscan/pkg.json
#26 DONE 0.1s

#27 [21/21] COPY ./app /opt/holoscan/app
#27 DONE 0.1s

#28 exporting to docker image format
#28 exporting layers
#28 exporting layers 5.9s done
#28 exporting manifest sha256:261bfd883479734974f9f01500b63a394537a84df186d0552794645b0152f0f5 0.0s done
#28 exporting config sha256:47a95542f89e8e3174bba11729dc605a923542cf7c48c180ae2eb42290619826 0.0s done
#28 sending tarball
#28 ...

#29 importing to docker
#29 loading layer 2c6ff491304f 32.77kB / 125.57kB
#29 loading layer 04072fb0fc22 557.06kB / 73.96MB
#29 loading layer 04072fb0fc22 71.86MB / 73.96MB 2.1s
#29 loading layer 1982c4813c35 262.14kB / 26.20MB
#29 loading layer dc0acf48e445 513B / 513B
#29 loading layer a50eb25f7721 320B / 320B
#29 loading layer 838ef774fdf1 298B / 298B
#29 loading layer 3401c98a4ff8 4.00kB / 4.00kB
#29 loading layer 838ef774fdf1 298B / 298B 0.8s done
#29 loading layer 2c6ff491304f 32.77kB / 125.57kB 4.0s done
#29 loading layer 04072fb0fc22 71.86MB / 73.96MB 3.8s done
#29 loading layer 1982c4813c35 262.14kB / 26.20MB 1.3s done
#29 loading layer dc0acf48e445 513B / 513B 0.9s done
#29 loading layer a50eb25f7721 320B / 320B 0.9s done
#29 loading layer 3401c98a4ff8 4.00kB / 4.00kB 0.8s done
#29 DONE 4.0s

#28 exporting to docker image format
#28 sending tarball 68.3s done
#28 DONE 74.3s

#30 exporting cache to client directory
#30 preparing build cache for export
#30 writing layer sha256:00bb4c1319ba1a33ac3edcb3aa1240d8abcb8d0383c6267ed8028d3b6228a8a4
#30 writing layer sha256:00bb4c1319ba1a33ac3edcb3aa1240d8abcb8d0383c6267ed8028d3b6228a8a4 done
#30 writing layer sha256:014cff740c9ec6e9a30d0b859219a700ae880eb385d62095d348f5ea136d6015 done
#30 writing layer sha256:0a1756432df4a4350712d8ae5c003f1526bd2180800b3ae6301cfc9ccf370254 done
#30 writing layer sha256:0a77dcbd0e648ddc4f8e5230ade8fdb781d99e24fa4f13ca96a360c7f7e6751f done
#30 writing layer sha256:0ec682bf99715a9f88631226f3749e2271b8b9f254528ef61f65ed829984821c done
#30 writing layer sha256:1133dfcee0e851b490d17b3567f50c4b25ba5750da02ba4b3f3630655d0b1a7b done
#30 writing layer sha256:1294b2835667d633f938174d9fecb18a60bbbebb6fb49788a1f939893a25d1af done
#30 writing layer sha256:16a03c6e0373b62f9713416da0229bb7ce2585183141081d3ea8427ad2e84408 done
#30 writing layer sha256:20d331454f5fb557f2692dfbdbe092c718fd2cb55d5db9d661b62228dacca5c2 done
#30 writing layer sha256:2232aeb26b5b7ea57227e9a5b84da4fb229624d7bc976a5f7ce86d9c8653d277 done
#30 writing layer sha256:238f69a43816e481f0295995fcf5fe74d59facf0f9f99734c8d0a2fb140630e0 done
#30 writing layer sha256:2ad84487f9d4d31cd1e0a92697a5447dd241935253d036b272ef16d31620c1e7 done
#30 writing layer sha256:2bb73464628bd4a136c4937f42d522c847bea86b2215ae734949e24c1caf450e done
#30 writing layer sha256:2ca59f23482f8bc9a313f15326cc9326efd2553b0480274dc62b6213b864e2ed 0.0s done
#30 writing layer sha256:32ccfe43297de5eb7d872ac37cb2e4b356a9fdd75b37a1d4e9c0a96f26d3a1eb 0.0s done
#30 writing layer sha256:3e3e04011ebdba380ab129f0ee390626cb2a600623815ca756340c18bedb9517 done
#30 writing layer sha256:42619ce4a0c9e54cfd0ee41a8e5f27d58b3f51becabd1ac6de725fbe6c42b14a done
#30 writing layer sha256:43a21fb6c76bd2b3715cc09d9f8c3865dc61c51dd9e2327b429f5bec8fff85d1 done
#30 writing layer sha256:49bdc9abf8a437ccff67cc11490ba52c976577992909856a86be872a34d3b950 done
#30 writing layer sha256:4b691ba9f48b41eaa0c754feba8366f1c030464fcbc55eeffa6c86675990933a done
#30 writing layer sha256:4d04a8db404f16c2704fa10739cb6745a0187713a21a6ef0deb34b48629b54c1 done
#30 writing layer sha256:4f4fb700ef54461cfa02571ae0db9a0dc1e0cdb5577484a6d75e68dc38e8acc1 done
#30 writing layer sha256:5275a41be8f6691a490c0a15589e0910c73bf971169ad33a850ef570d37f63dd done
#30 writing layer sha256:52fbfeaf78318d843054ce2bfb5bfc9f71278939a815f6035ab5b14573ad017b done
#30 writing layer sha256:5792b18b6f162bae61ff5840cdb9e8567e6847a56ac886f940b47e7271c529a7 done
#30 writing layer sha256:57f244836ad318f9bbb3b29856ae1a5b31038bfbb9b43d2466d51c199eb55041 done
#30 writing layer sha256:5b5b131e0f20db4cb8e568b623a95f8fc16ed1c6b322a9366df70b59a881f24f done
#30 writing layer sha256:5ccb787d371fd3697122101438ddd0f55b537832e9756d2c51ab1d8158710ac5 done
#30 writing layer sha256:62452179df7c18e292f141d4aec29e6aba9ff8270c893731169fc6f41dc07631 done
#30 writing layer sha256:6630c387f5f2115bca2e646fd0c2f64e1f3d5431c2e050abe607633883eda230 done
#30 writing layer sha256:69af4b756272a77f683a8d118fd5ca55c03ad5f1bacc673b463f54d16b833da5 done
#30 writing layer sha256:6ae1f1fb92c0cb2b6e219f687b08c8e511501a7af696c943ca20d119eba7cd02 done
#30 writing layer sha256:6deb3d550b15a5e099c0b3d0cbc242e351722ca16c058d3a6c28ba1a02824d0f done
#30 writing layer sha256:7386814d57100e2c7389fbf4e16f140f5c549d31434c62c3884a85a3ee5cd2a7 done
#30 writing layer sha256:7852b73ea931e3a8d3287ee7ef3cf4bad068e44f046583bfc2b81336fb299284 done
#30 writing layer sha256:7e73869c74822e4539e104a3d2aff853f4622cd0bb873576db1db53c9e91f621 done
#30 writing layer sha256:7eae142b38745fe88962874372374deb672998600264a17e638c010b79e6b535 done
#30 writing layer sha256:7f2e5ab2c599fa36698918d3e73c991d8616fff9037077cd230529e7cd1c5e0e done
#30 writing layer sha256:81b2d4e60f6b67ed37f95e3d15237a436e76056fb4babcb9a188fd2b337c897b 0.0s done
#30 writing layer sha256:82a3436133b2b17bb407c7fe488932aa0ca55411f23ab55c34a6134b287c6a27 done
#30 writing layer sha256:90eae6faa5cc5ba62f12c25915cdfb1a7a51abfba0d05cb5818c3f908f4e345f
#30 writing layer sha256:90eae6faa5cc5ba62f12c25915cdfb1a7a51abfba0d05cb5818c3f908f4e345f done
#30 writing layer sha256:93e2013abbc3bc85f24d4739ac397584f6332aec7d8e80f8d95d9c961978fe90 0.0s done
#30 writing layer sha256:9723201c31b4e56a2dff5c3769790d4d6a7c069d75bdd3996395600bd0d067cd done
#30 writing layer sha256:9ac855545fa90ed2bf3b388fdff9ef06ac9427b0c0fca07c9e59161983d8827e done
#30 writing layer sha256:9d19ee268e0d7bcf6716e6658ee1b0384a71d6f2f9aa1ae2085610cf7c7b316f done
#30 writing layer sha256:a10c8d7d2714eabf661d1f43a1ccb87a51748cbb9094d5bc0b713e2481b5d329 done
#30 writing layer sha256:a1748eee9d376f97bd19225ba61dfada9986f063f4fc429e435f157abb629fc6 done
#30 writing layer sha256:a68f4e0ec09ec3b78cb4cf8e4511d658e34e7b6f676d7806ad9703194ff17604 done
#30 writing layer sha256:a8e4decc8f7289623b8fd7b9ba1ca555b5a755ebdbf81328d68209f148d9e602 done
#30 writing layer sha256:a9cc9b4b42ca5455c9da9b048ab2cc36e82bd335f51c23817f4bcf330bbb96f1 done
#30 writing layer sha256:afde1c269453ce68a0f2b54c1ba8c5ecddeb18a19e5618a4acdef1f0fe3921af done
#30 writing layer sha256:b48a5fafcaba74eb5d7e7665601509e2889285b50a04b5b639a23f8adc818157 done
#30 writing layer sha256:ba9f7c75e4dd7942b944679995365aab766d3677da2e69e1d74472f471a484dd done
#30 writing layer sha256:bdfc73b2a0fa11b4086677e117a2f9feb6b4ffeccb23a3d58a30543339607e31 done
#30 writing layer sha256:c175bb235295e50de2961fa1e1a2235c57e6eba723a914287dfc26d3be0eac11 done
#30 writing layer sha256:c98533d2908f36a5e9b52faae83809b3b6865b50e90e2817308acfc64cd3655f done
#30 writing layer sha256:cb6c95b33bc30dd285c5b3cf99a05281b8f12decae1c932ab64bd58f56354021 done
#30 writing layer sha256:cc985f61e92a80cbc59a150c5758becb75f8eddbbbaf17d46374ede3cd01a51f
#30 writing layer sha256:cc985f61e92a80cbc59a150c5758becb75f8eddbbbaf17d46374ede3cd01a51f 0.5s done
#30 writing layer sha256:d7da5c5e9a40c476c4b3188a845e3276dedfd752e015ea5113df5af64d4d43f7
#30 writing layer sha256:d7da5c5e9a40c476c4b3188a845e3276dedfd752e015ea5113df5af64d4d43f7 done
#30 writing layer sha256:df3589199e830d446e82feab6d40fac58781a5bd8b2d206f25b85a317b994f93 0.0s done
#30 writing layer sha256:e434bbf389a48c6e211eca75d5ca50839cb622b1ba3a36c6b35d600e53e16b21 done
#30 writing layer sha256:e4aedc686433c0ec5e676e6cc54a164345f7016aa0eb714f00c07e11664a1168 done
#30 writing layer sha256:e5d1792b50654fc7f0eed206f4c91e95f8e4b107554a7296502020c7029a76b6
#30 writing layer sha256:e5d1792b50654fc7f0eed206f4c91e95f8e4b107554a7296502020c7029a76b6 1.3s done
#30 preparing build cache for export 2.2s done
#30 writing layer sha256:e8acb678f16bc0c369d5cf9c184f2d3a1c773986816526e5e3e9c0354f7e757f done
#30 writing layer sha256:e9225f7ab6606813ec9acba98a064826ebfd6713a9645a58cd068538af1ecddb done
#30 writing layer sha256:f33546e75bf1a7d9dc9e21b9a2c54c9d09b24790ad7a4192a8509002ceb14688 done
#30 writing layer sha256:f608e2fbff86e98627b7e462057e7d2416522096d73fe4664b82fe6ce8a4047d done
#30 writing layer sha256:f7702077ced42a1ee35e7f5e45f72634328ff3bcfe3f57735ba80baa5ec45daf done
#30 writing layer sha256:fa66a49172c6e821a1bace57c007c01da10cbc61507c44f8cdfeed8c4e5febab done
#30 writing config sha256:217441004720a68ddf80261db2a5b316ddba5c5bc611403e7439f6d0f6d2055d 0.0s done
#30 writing cache manifest sha256:eb0660732980435a67eb754f67f1d7b91fa92f577670a600c7a8a50a85b8f872 0.0s done
#30 DONE 2.2s
[2024-04-10 16:37:31,408] [INFO] (packager) - Build Summary:

Platform: x64-workstation/dgpu
    Status:     Succeeded
    Docker Tag: mednist_app-x64-workstation-dgpu-linux-amd64:1.0
    Tarball:    None

Note

Building a MONAI Application Package (Docker image) can take time. Use the -l DEBUG option if you want to see the build progress.
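In the packaging cell, IPython substitutes the Python variables `{models_folder}` and `{tag_prefix}` into the shell command before running it. Outside Jupyter, you can build the same invocation as an argument list and pass it to `subprocess.run`.

The sketch below uses only the flags that appear in the command above. The helper name `package_command` is hypothetical, and the `"models"` value stands in for the notebook's `models_folder`.

```python
import shlex
from typing import List


def package_command(app_file: str, models_folder: str, config_file: str, tag: str,
                    platform: str = "x64-workstation", log_level: str = "DEBUG") -> List[str]:
    """Build the same `monai-deploy package` invocation used in the packaging cell."""
    return [
        "monai-deploy", "package", app_file,
        "-m", models_folder,
        "-c", config_file,
        "-t", tag,
        "--platform", platform,
        "-l", log_level,
    ]


pkg_cmd = package_command(
    "mednist_app/mednist_classifier_monaideploy.py",
    "models",  # hypothetical; use the notebook's models_folder value
    "mednist_app/app.yaml",
    "mednist_app:1.0",
)
print(shlex.join(pkg_cmd))
# To actually build (needs Docker and the App SDK CLI): subprocess.run(pkg_cmd, check=True)
```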

We can see that the Docker image has been created.

!docker image ls | grep {tag_prefix}
mednist_app-x64-workstation-dgpu-linux-amd64                                              1.0                        47a95542f89e   About a minute ago   17.5GB
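The image is not tagged `mednist_app:1.0` as requested with `-t`. The packager expanded the tag with the platform, configuration and architecture, as shown in the build summary. The sketch below reproduces that naming pattern as it appears in this one log. It is inferred from the output, not taken from packager documentation, so treat it as an assumption when predicting tags for other platforms. The function name `expected_image_tag` is hypothetical.

```python
def expected_image_tag(user_tag: str, platform: str, config: str, arch: str) -> str:
    """Pattern inferred from this build log: <name>-<platform>-<config>-<os>-<arch>:<version>."""
    name, sep, version = user_tag.rpartition(":")
    if not sep:
        raise ValueError("expected a 'name:version' tag, e.g. 'mednist_app:1.0'")
    # "linux/amd64" -> "linux-amd64"
    return f"{name}-{platform}-{config}-{arch.replace('/', '-')}:{version}"


tag = expected_image_tag("mednist_app:1.0", "x64-workstation", "dgpu", "linux/amd64")
print(tag)  # mednist_app-x64-workstation-dgpu-linux-amd64:1.0
```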

Executing packaged app locally

We can choose to display and export the MAP manifests, but in this example we will just run the MAP through the MONAI Application Runner.
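The run step first removes the old output folder, so a stale `output.json` from the local run cannot be mistaken for the container's result. It then calls `monai-deploy run` with the input folder, output folder and image tag.

The Python sketch below mirrors those two steps. It reads `HOLOSCAN_INPUT_PATH` and `HOLOSCAN_OUTPUT_PATH` from the environment, like the shell cell does, and raises `KeyError` if they are unset. The helper names are hypothetical, and the command is only built, not executed.

```python
import os
import shlex
import shutil
import tempfile
from pathlib import Path
from typing import List, Optional


def clear_output_dir(output_dir) -> None:
    """Equivalent of `rm -rf $HOLOSCAN_OUTPUT_PATH`; missing folders are ignored."""
    shutil.rmtree(output_dir, ignore_errors=True)


def run_command(image_tag: str, input_dir: Optional[str] = None, output_dir: Optional[str] = None) -> List[str]:
    """Build the `monai-deploy run` argv, defaulting to the notebook's environment variables."""
    input_dir = input_dir or os.environ["HOLOSCAN_INPUT_PATH"]
    output_dir = output_dir or os.environ["HOLOSCAN_OUTPUT_PATH"]
    return ["monai-deploy", "run", "-i", input_dir, "-o", output_dir, image_tag]


# Demo with a throwaway output folder containing a stale result.
demo_out = Path(tempfile.mkdtemp()) / "output"
demo_out.mkdir()
(demo_out / "output.json").write_text('"stale"')
clear_output_dir(demo_out)
run_cmd = run_command("mednist_app-x64-workstation-dgpu-linux-amd64:1.0", "input", str(demo_out))
print(shlex.join(run_cmd))
# To actually run the MAP (needs Docker and the NVIDIA Container Toolkit): subprocess.run(run_cmd, check=True)
```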

# Clear the output folder and run the MAP. The input is expected to be a folder.
!rm -rf $HOLOSCAN_OUTPUT_PATH
!monai-deploy run -i $HOLOSCAN_INPUT_PATH -o $HOLOSCAN_OUTPUT_PATH mednist_app-x64-workstation-dgpu-linux-amd64:1.0
[2024-04-10 16:37:33,094] [INFO] (runner) - Checking dependencies...
[2024-04-10 16:37:33,094] [INFO] (runner) - --> Verifying if "docker" is installed...

[2024-04-10 16:37:33,094] [INFO] (runner) - --> Verifying if "docker-buildx" is installed...

[2024-04-10 16:37:33,094] [INFO] (runner) - --> Verifying if "mednist_app-x64-workstation-dgpu-linux-amd64:1.0" is available...

[2024-04-10 16:37:33,168] [INFO] (runner) - Reading HAP/MAP manifest...
Preparing to copy...
Copying from container - 0B
Successfully copied 2.56kB to /tmp/tmp96catisy/app.json
Preparing to copy...
Copying from container - 0B
Successfully copied 2.05kB to /tmp/tmp96catisy/pkg.json
[2024-04-10 16:37:33,777] [INFO] (runner) - --> Verifying if "nvidia-ctk" is installed...

[2024-04-10 16:37:33,778] [INFO] (runner) - --> Verifying "nvidia-ctk" version...

[2024-04-10 16:37:33,934] [INFO] (common) - Launching container (4b3bba81606f) using image 'mednist_app-x64-workstation-dgpu-linux-amd64:1.0'...
    container name:      flamboyant_jepsen
    host name:           mingq-dt
    network:             host
    user:                1000:1000
    ulimits:             memlock=-1:-1, stack=67108864:67108864
    cap_add:             CAP_SYS_PTRACE
    ipc mode:            host
    shared memory size:  67108864
    devices:             
    group_add:           44
2024-04-10 23:37:34 [INFO] Launching application python3 /opt/holoscan/app/mednist_classifier_monaideploy.py ...

[info] [app_driver.cpp:1161] Launching the driver/health checking service

[info] [gxf_executor.cpp:211] Creating context

[info] [server.cpp:87] Health checking server listening on 0.0.0.0:8777

[info] [gxf_executor.cpp:1674] Loading extensions from configs...

[info] [gxf_executor.cpp:1864] Activating Graph...

[info] [gxf_executor.cpp:1894] Running Graph...

[info] [gxf_executor.cpp:1896] Waiting for completion...

[info] [gxf_executor.cpp:1897] Graph execution waiting. Fragment: 

[info] [greedy_scheduler.cpp:190] Scheduling 3 entities

/home/holoscan/.local/lib/python3.10/site-packages/monai/data/meta_tensor.py:116: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)

  return torch.as_tensor(x, *args, **_kwargs).as_subclass(cls)

/home/holoscan/.local/lib/python3.10/site-packages/pydicom/valuerep.py:443: UserWarning: Invalid value for VR UI: 'xyz'. Please see <https://dicom.nema.org/medical/dicom/current/output/html/part05.html#table_6.2-1> for allowed values for each VR.

  warnings.warn(msg)

[info] [greedy_scheduler.cpp:369] Scheduler stopped: Some entities are waiting for execution, but there are no periodic or async entities to get out of the deadlock.

[info] [greedy_scheduler.cpp:398] Scheduler finished.

[info] [gxf_executor.cpp:1906] Graph execution deactivating. Fragment: 

[info] [gxf_executor.cpp:1907] Deactivating Graph...

[info] [gxf_executor.cpp:1910] Graph execution finished. Fragment: 

[info] [gxf_executor.cpp:230] Destroying context

AbdomenCT

[2024-04-10 16:37:42,228] [INFO] (common) - Container 'flamboyant_jepsen'(4b3bba81606f) exited.
!cat $HOLOSCAN_OUTPUT_PATH/output.json
"AbdomenCT"
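Both the local run and the containerized run wrote `"AbdomenCT"`. If you want to check this programmatically rather than by eye, you can read and validate `output.json`. For example, save the label after the local run and compare it with the label from the container run.

The sketch below assumes `output.json` holds a single JSON string, as shown above, and that `CLASS_NAMES` matches the notebook's `MEDNIST_CLASSES`. The helper name `read_label` is hypothetical.

```python
import json
import tempfile
from pathlib import Path

# Assumption: same content as the notebook's MEDNIST_CLASSES list.
CLASS_NAMES = ("AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT")


def read_label(output_dir) -> str:
    """Load output.json and check that it holds one known MedNIST class name."""
    label = json.loads((Path(output_dir) / "output.json").read_text())
    if label not in CLASS_NAMES:
        raise ValueError(f"unexpected label: {label!r}")
    return label


# Demo with a stand-in file; in the notebook, point read_label at $HOLOSCAN_OUTPUT_PATH.
demo_dir = Path(tempfile.mkdtemp())
(demo_dir / "output.json").write_text(json.dumps("AbdomenCT"))
print(read_label(demo_dir))  # AbdomenCT
```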

Note: Please execute the following cell once you have finished the exercise.

# Remove the data files, which are in the temporary folder
if directory is None:
    shutil.rmtree(root_dir)
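The cleanup only deletes `root_dir` when `directory` is `None`, that is, when the data folder was a temporary one created by the notebook rather than a folder you supplied. The sketch below assumes the setup pattern common in MONAI tutorials: read a directory from the `MONAI_DATA_DIRECTORY` environment variable, otherwise fall back to `tempfile.mkdtemp()`. That setup is outside this section, so treat it as an assumption. The helper names are hypothetical.

```python
import os
import shutil
import tempfile
from pathlib import Path
from typing import Optional, Tuple


def resolve_data_dir(env_var: str = "MONAI_DATA_DIRECTORY") -> Tuple[Optional[str], str]:
    """Use the directory named by env_var if set, else create a temporary one (assumed notebook pattern)."""
    directory = os.environ.get(env_var)
    root_dir = tempfile.mkdtemp() if directory is None else directory
    return directory, root_dir


def cleanup_data_dir(directory: Optional[str], root_dir: str) -> None:
    """Delete root_dir only if it was a temporary folder; a user-supplied directory is kept."""
    if directory is None:
        shutil.rmtree(root_dir)


# Demo with an environment variable name assumed to be unset here.
demo_directory, demo_root = resolve_data_dir("HYPOTHETICAL_UNSET_DATA_DIR")
existed = Path(demo_root).is_dir()
cleanup_data_dir(demo_directory, demo_root)
print(existed, Path(demo_root).exists())  # True False
```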