Deploying a MedNIST Classifier App with MONAI Deploy App SDK

This tutorial demonstrates how to package a trained model with the MONAI Deploy App SDK into an artifact that can be run as a local program performing inference, as a workflow job doing the same, or as a Docker-containerized workflow execution.

In this tutorial, we will train a MedNIST classifier following the MONAI MedNIST classification tutorial, then implement and package the inference application and execute it locally.

Train a MedNIST classifier model with MONAI Core

Setup environment

# Install necessary packages for MONAI Core
!python -c "import monai" || pip install -q "monai[pillow, tqdm]"
!python -c "import ignite" || pip install -q "monai[ignite]"
!python -c "import gdown" || pip install -q "monai[gdown]"
!python -c "import pydicom" || pip install -q "pydicom>=1.4.2"
!python -c "import highdicom" || pip install -q "highdicom>=0.18.2"  # for the use of DICOM Writer operators

# Install MONAI Deploy App SDK package
!python -c "import monai.deploy" || pip install -q "monai-deploy-app-sdk"

Setup imports

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import glob
import PIL.Image
import torch
import numpy as np

from ignite.engine import Events

from monai.apps import download_and_extract
from monai.config import print_config
from monai.networks.nets import DenseNet121
from monai.engines import SupervisedTrainer
from monai.transforms import (
    AddChannel,
    Compose,
    LoadImage,
    RandFlip,
    RandRotate,
    RandZoom,
    ScaleIntensity,
    EnsureType,
)
from monai.utils import set_determinism

set_determinism(seed=0)

print_config()
MONAI version: 1.2.0
Numpy version: 1.24.4
Pytorch version: 2.0.1+cu117
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: c33f1ba588ee00229a309000e888f9817b4f1934
MONAI __file__: /home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.1.0
scikit-image version: 0.21.0
Pillow version: 10.0.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: 4.7.1
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.65.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.5
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

Download dataset

The MedNIST dataset was gathered from several sets from TCIA, the RSNA Bone Age Challenge (https://www.rsna.org/education/ai-resources-and-training/ai-image-challenge/rsna-pediatric-bone-age-challenge-2017), and the NIH Chest X-ray dataset.

The dataset is kindly made available by Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic) under the Creative Commons CC BY-SA 4.0 license.

If you use the MedNIST dataset, please acknowledge the source.

directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

resource = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE"
md5 = "0bc7306e7427e00ad1c5526a6677552d"

compressed_file = os.path.join(root_dir, "MedNIST.tar.gz")
data_dir = os.path.join(root_dir, "MedNIST")
if not os.path.exists(data_dir):
    download_and_extract(resource, compressed_file, root_dir, md5)
/tmp/tmp_ijj195_
Downloading...
From (uriginal): https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE
From (redirected): https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE&confirm=t&uuid=d974f3a4-5d30-48b6-9b6d-9459b32b4cac
To: /tmp/tmp3aa3c3k6/MedNIST.tar.gz
100%|██████████| 61.8M/61.8M [00:03<00:00, 19.1MB/s]
2023-08-03 20:42:12,748 - INFO - Downloaded: /tmp/tmp_ijj195_/MedNIST.tar.gz
2023-08-03 20:42:12,856 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
2023-08-03 20:42:12,857 - INFO - Writing into directory: /tmp/tmp_ijj195_.
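The log above shows that download_and_extract verified the archive's md5 before extracting it. If you obtain the archive some other way, the equivalent chunked checksum check can be done with the standard library alone, as in the sketch below. The helper names file_md5 and verify_md5 are hypothetical and are not part of MONAI.

```python
import hashlib
import os
import tempfile


def file_md5(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex MD5 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_md5(path: str, expected: str) -> bool:
    """Compare a file's MD5 digest with the expected hex string (case-insensitive)."""
    return file_md5(path) == expected.lower()


# Demo on a throwaway file; for the real archive you would pass compressed_file and md5.
with tempfile.TemporaryDirectory() as tmp:
    sample = os.path.join(tmp, "sample.bin")
    with open(sample, "wb") as f:
        f.write(b"MedNIST")
    print(verify_md5(sample, hashlib.md5(b"MedNIST").hexdigest()))  # True
```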

subdirs = sorted(glob.glob(f"{data_dir}/*/"))

class_names = [os.path.basename(sd[:-1]) for sd in subdirs]
image_files = [glob.glob(f"{sb}/*") for sb in subdirs]

image_files_list = sum(image_files, [])
image_class = sum(([i] * len(f) for i, f in enumerate(image_files)), [])
image_width, image_height = PIL.Image.open(image_files_list[0]).size

print(f"Label names: {class_names}")
print(f"Label counts: {list(map(len, image_files))}")
print(f"Total image count: {len(image_class)}")
print(f"Image dimensions: {image_width} x {image_height}")
Label names: ['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT']
Label counts: [10000, 8954, 10000, 10000, 10000, 10000]
Total image count: 58954
Image dimensions: 64 x 64
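The cell above assigns class indices by sorted subfolder order. It flattens the per-class lists with sum(..., []), which copies the accumulated list at every step. The sketch below is a hypothetical stdlib-only variant that uses itertools.chain instead, shown on a tiny synthetic folder tree. The name list_labeled_files is illustrative and is not a MONAI API. Unlike the cell above, it also sorts files within each class and does not skip hidden files.

```python
import itertools
import os
import tempfile
from typing import List, Tuple


def list_labeled_files(data_dir: str) -> Tuple[List[str], List[str], List[int]]:
    """Return (class_names, files, labels) for a <data_dir>/<class_name>/<file> layout."""
    # Class indices follow the sorted subfolder names, as in the cell above.
    class_names = sorted(
        d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))
    )
    per_class = [
        sorted(os.path.join(data_dir, c, f) for f in os.listdir(os.path.join(data_dir, c)))
        for c in class_names
    ]
    # chain.from_iterable flattens in linear time, unlike sum(list_of_lists, []).
    files = list(itertools.chain.from_iterable(per_class))
    labels = list(itertools.chain.from_iterable([i] * len(fs) for i, fs in enumerate(per_class)))
    return class_names, files, labels


# Tiny synthetic tree: two "Hand" images and one "CXR" image.
with tempfile.TemporaryDirectory() as root:
    for name, count in [("Hand", 2), ("CXR", 1)]:
        os.makedirs(os.path.join(root, name))
        for k in range(count):
            open(os.path.join(root, name, f"{k:06d}.jpeg"), "wb").close()
    names, files, labels = list_labeled_files(root)
    print(names, labels)  # ['CXR', 'Hand'] [0, 1, 1]
```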

Setup and train

Here we create a transform sequence and train the network. Validation and testing are omitted because this training setup is already known to work and evaluating it is not needed for this tutorial:

train_transforms = Compose(
    [
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
        RandFlip(spatial_axis=0, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        EnsureType(),
    ]
)
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages/monai/utils/deprecate_utils.py:111: FutureWarning: <class 'monai.transforms.utility.array.AddChannel'>: Class `AddChannel` has been deprecated since version 0.8. It will be removed in version 1.3. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead with `channel_dim='no_channel'`.
  warn_deprecated(obj, msg, warning_category)
class MedNISTDataset(torch.utils.data.Dataset):
    def __init__(self, image_files, labels, transforms):
        self.image_files = image_files
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        return self.transforms(self.image_files[index]), self.labels[index]


# Just one dataset and loader; we won't bother with validation or testing
train_ds = MedNISTDataset(image_files_list, image_class, train_transforms)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=len(class_names)).to(device)
loss_function = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(net.parameters(), 1e-5)
max_epochs = 5
def _prepare_batch(batch, device, non_blocking):
    return tuple(b.to(device) for b in batch)


trainer = SupervisedTrainer(device, max_epochs, train_loader, net, opt, loss_function, prepare_batch=_prepare_batch)


@trainer.on(Events.EPOCH_COMPLETED)
def _print_loss(engine):
    print(f"Epoch {engine.state.epoch}/{engine.state.max_epochs} Loss: {engine.state.output[0]['loss']}")


trainer.run()
Epoch 1/5 Loss: 0.18928290903568268
Epoch 2/5 Loss: 0.06710730493068695
Epoch 3/5 Loss: 0.029032323509454727
Epoch 4/5 Loss: 0.01877668686211109
Epoch 5/5 Loss: 0.01939055137336254

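The trained network is saved as a TorchScript object named classifier.zip. The network is first switched to evaluation mode, so the exported model uses the BatchNorm running statistics at inference time:

net.eval()  # export in inference mode (BatchNorm uses running statistics)
torch.jit.script(net).save("classifier.zip")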

Implementing and Packaging Application with MONAI Deploy App SDK

Based on the TorchScript model (classifier.zip), we will implement an application that processes an input JPEG image and writes the prediction (classification) result as a JSON file (output.json).

Creating Operators and connecting them in Application class

We used the following transforms as pre-transforms during training.

Train transforms used in training

train_transforms = Compose(
    [
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
        RandFlip(spatial_axis=0, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        EnsureType(),
    ]
)

The RandRotate, RandFlip, and RandZoom transforms are random augmentations used only during training. They are not needed for inference.
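The following NumPy-only sketch shows the idea behind RandFlip(spatial_axis=0, prob=0.5): with probability prob, a channel-first image is flipped along one spatial axis. As a result, the same input can produce different outputs on different calls. This is an illustration, not MONAI's implementation, and random_flip is a hypothetical name. Randomness like this is useful for augmenting training data, but it would make inference non-deterministic.

```python
import numpy as np


def random_flip(img: np.ndarray, spatial_axis: int, prob: float, rng: np.random.Generator) -> np.ndarray:
    """Flip a channel-first image along one spatial axis with probability `prob`.

    Spatial axis k is array axis k + 1, because axis 0 holds the channels.
    """
    if rng.random() < prob:
        return np.flip(img, axis=spatial_axis + 1).copy()
    return img


rng = np.random.default_rng(0)
img = np.arange(6, dtype=np.float32).reshape(1, 2, 3)  # (channels, H, W)
flipped = random_flip(img, spatial_axis=0, prob=1.0, rng=rng)  # prob=1.0 always flips
print(flipped[0].tolist())  # [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]
```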

In our inference application, we will define two operators:

  1. LoadPILOperator - Loads a JPEG image from the input path and passes the loaded image object to the next operator.

    • This operator does a similar job to the LoadImage(image_only=True) transform in train_transforms, but handles only one image.

    • Input: a file path, or a folder whose first file is used (Path), set when the operator is created

    • Output: an image object in memory (Image)

  2. MedNISTClassifierOperator - Pre-transforms the given image using MONAI's Compose class, feeds it to the TorchScript model (classifier.zip), and writes the prediction to a JSON file (output.json).

    • The pre-transforms consist of three transforms: AddChannel, ScaleIntensity, and EnsureType.

    • Input: an image object in memory (Image)

    • Output: the predicted class name as text (result_text); the prediction is also written to output.json in the output folder
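Together with the batch axis added in compute(), the three pre-transforms amount to the NumPy-only sketch below. It assumes ScaleIntensity's default min-max rescaling to [0, 1]. It approximates EnsureType's tensor conversion with a float32 NumPy array. The name preprocess is hypothetical; the real operator uses the MONAI transforms.

```python
import numpy as np


def preprocess(img: np.ndarray) -> np.ndarray:
    """Approximate AddChannel -> ScaleIntensity -> add batch axis for a 2D greyscale image."""
    x = img.astype(np.float32)[None]  # AddChannel: (H, W) -> (1, H, W)
    lo, hi = x.min(), x.max()
    # Min-max rescale to [0, 1]; a constant image would divide by zero, so this sketch returns zeros.
    x = (x - lo) / (hi - lo) if hi > lo else np.zeros_like(x)
    return x[None]  # batch axis: (1, H, W) -> (1, 1, H, W)


img = np.random.default_rng(0).integers(0, 256, size=(64, 64), dtype=np.uint8)
batch = preprocess(img)
print(batch.shape, batch.dtype)  # (1, 1, 64, 64) float32
```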

The workflow of the application looks like this:

%%{init: {"theme": "base", "themeVariables": { "fontSize": "16px"}} }%%

classDiagram
    direction LR

    LoadPILOperator --|> MedNISTClassifierOperator : image...image

    class LoadPILOperator {
        <in>image : DISK
        image(out) IN_MEMORY
    }
    class MedNISTClassifierOperator {
        <in>image : IN_MEMORY
        output(out) DISK
    }

Set up environment variables

Before building and packaging the application, we first need to set the well-known environment variables, because the application reads them to locate the input, output, and model folders. Defaults are used if these environment variables are absent.

Prepare the input and model folders, then set the environment variables to point to them.

input_folder = "input"
output_folder = "output"
models_folder = "models"

# Choose a file as test input
test_input_path = image_files[0][0]
!rm -rf {input_folder} && mkdir -p {input_folder} && cp {test_input_path} {input_folder} && ls {input_folder}
# Need to copy the model file to its own clean subfolder for packaging, to work around an issue in the Packager
!rm -rf {models_folder} && mkdir -p {models_folder}/model && cp classifier.zip {models_folder}/model && ls {models_folder}/model

%env HOLOSCAN_INPUT_PATH {input_folder}
%env HOLOSCAN_OUTPUT_PATH {output_folder}
%env HOLOSCAN_MODEL_PATH {models_folder}
001420.jpeg
classifier.zip
env: HOLOSCAN_INPUT_PATH=input
env: HOLOSCAN_OUTPUT_PATH=output
env: HOLOSCAN_MODEL_PATH=models
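These variables are read by the application, not by the notebook. The hypothetical helper below illustrates the general lookup-with-fallback pattern. The fallback folder names in ASSUMED_DEFAULTS are assumptions chosen to match this notebook, not the App SDK's documented defaults.

```python
import os
from pathlib import Path
from typing import Dict, Mapping, Optional

# Assumed fallbacks for illustration only; the App SDK defines its own defaults.
ASSUMED_DEFAULTS = {
    "HOLOSCAN_INPUT_PATH": "input",
    "HOLOSCAN_OUTPUT_PATH": "output",
    "HOLOSCAN_MODEL_PATH": "models",
}


def resolve_app_paths(env: Optional[Mapping[str, str]] = None) -> Dict[str, Path]:
    """Map each well-known variable to a Path, using the environment value when set and non-empty."""
    env = os.environ if env is None else env
    return {name: Path(env.get(name) or default) for name, default in ASSUMED_DEFAULTS.items()}


paths = resolve_app_paths({"HOLOSCAN_INPUT_PATH": "input", "HOLOSCAN_OUTPUT_PATH": "output"})
print(paths["HOLOSCAN_MODEL_PATH"])  # models
```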

Setup imports

Let’s import the necessary classes and define MEDNIST_CLASSES.

import logging
import os
from pathlib import Path
from typing import Optional

import torch

from monai.deploy.conditions import CountCondition
from monai.deploy.core import AppContext, Application, ConditionType, Fragment, Image, Operator, OperatorSpec
from monai.deploy.operators.dicom_text_sr_writer_operator import DICOMTextSRWriterOperator, EquipmentInfo, ModelInfo
from monai.transforms import AddChannel, Compose, EnsureType, ScaleIntensity

MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]

Creating Operator classes

LoadPILOperator
class LoadPILOperator(Operator):
    """Load image from the given input (DataPath) and set numpy array to the output (Image)."""

    DEFAULT_INPUT_FOLDER = Path.cwd() / "input"
    DEFAULT_OUTPUT_NAME = "image"

    # For now, need to have the input folder as an instance attribute, set on init.
    # If dynamically changing the input folder, per compute, then use a (optional) input port to convey the
    # value of the input folder, which is then emitted by an upstream operator.
    def __init__(
        self,
        fragment: Fragment,
        *args,
        input_folder: Path = DEFAULT_INPUT_FOLDER,
        output_name: str = DEFAULT_OUTPUT_NAME,
        **kwargs,
    ):
        """Creates a loader object with the input folder and the output port name overrides as needed.

        Args:
            fragment (Fragment): An instance of the Application class which is derived from Fragment.
            input_folder (Path): Folder from which to load input file(s).
                                 Defaults to `input` in the current working directory.
            output_name (str): Name of the output port, which is an image object. Defaults to `image`.
        """

        self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
        self.input_path = input_folder
        self.index = 0
        self.output_name_image = (
            output_name.strip() if output_name and len(output_name.strip()) > 0 else LoadPILOperator.DEFAULT_OUTPUT_NAME
        )

        super().__init__(fragment, *args, **kwargs)

    def setup(self, spec: OperatorSpec):
        """Set up the named input and output port(s)"""
        spec.output(self.output_name_image)

    def compute(self, op_input, op_output, context):
        import numpy as np
        from PIL import Image as PILImage

        # Input path is stored in the object attribute, but could change to use a named port if need be.
        input_path = self.input_path
        if input_path.is_dir():
            input_path = next(self.input_path.glob("*.*"))  # take the first file

        image = PILImage.open(input_path)
        image = image.convert("L")  # convert to greyscale image
        image_arr = np.asarray(image)

        output_image = Image(image_arr)  # create Image domain object with a numpy array
        op_output.emit(output_image, self.output_name_image)  # cannot omit the name even if single output.
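When given a folder, compute() uses next(self.input_path.glob("*.*")). pathlib's glob yields entries in whatever order the file system returns them, and next() raises StopIteration if nothing matches. The stdlib sketch below makes the choice deterministic and raises a clearer error. The name pick_input_file is a hypothetical helper, not an SDK function.

```python
import tempfile
from pathlib import Path


def pick_input_file(input_path: Path) -> Path:
    """Return input_path if it is a file, else the first matching file in sorted order."""
    if not input_path.is_dir():
        return input_path
    candidates = sorted(p for p in input_path.glob("*.*") if p.is_file())
    if not candidates:
        raise FileNotFoundError(f"No input file found in {input_path}")
    return candidates[0]


with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp)
    for name in ("002000.jpeg", "001420.jpeg"):
        (folder / name).touch()
    print(pick_input_file(folder).name)  # 001420.jpeg
```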
MedNISTClassifierOperator
class MedNISTClassifierOperator(Operator):
    """Classifies the given image and returns the class name.

    Named inputs:
        image: Image object for which to generate the classification.
        output_folder: Optional, the path to save the results JSON file, overriding the one set in __init__.

    Named output:
        result_text: The classification results in text.
    """

    DEFAULT_OUTPUT_FOLDER = Path.cwd() / "classification_results"
    # For testing the app directly, the model should be at the following path.
    MODEL_LOCAL_PATH = Path(os.environ.get("HOLOSCAN_MODEL_PATH", Path.cwd() / "model/model.ts"))

    def __init__(
        self,
        fragment: Fragment,
        *args,
        app_context: AppContext,
        model_name: Optional[str] = "",
        model_path: Path = MODEL_LOCAL_PATH,
        output_folder: Path = DEFAULT_OUTPUT_FOLDER,
        **kwargs,
    ):
        """Creates an instance with the reference back to the containing application/fragment.

        Args:
            fragment (Fragment): An instance of the Application class which is derived from Fragment.
            app_context (AppContext): The application context, which may hold the loaded model(s).
            model_name (str, optional): Name of the model. Defaults to "" for a single-model app.
            model_path (Path): Path to the model file. Defaults to $HOLOSCAN_MODEL_PATH, else model/model.ts
                               in the current working dir.
            output_folder (Path, optional): Output folder for saving the classification results JSON file.
        """

        # the names used for the model inference input and output
        self._input_dataset_key = "image"
        self._pred_dataset_key = "pred"

        # The names used for the operator input and output
        self.input_name_image = "image"
        self.output_name_result = "result_text"

        # The name of the optional input port for passing data to override the output folder path.
        self.input_name_output_folder = "output_folder"

        # The output folder set on the object can be overridden at each compute by data in the optional named input
        self.output_folder = output_folder

        # Need the name when there are multiple models loaded
        self._model_name = model_name.strip() if isinstance(model_name, str) else ""
        # Need the path to load the models when they are not loaded in the execution context
        self.model_path = model_path
        self.app_context = app_context
        self.model = self._get_model(self.app_context, self.model_path, self._model_name)

        # This needs to be at the end of the constructor.
        super().__init__(fragment, *args, **kwargs)

    def _get_model(self, app_context: AppContext, model_path: Path, model_name: str):
        """Load the model with the given name from context or model path

        Args:
            app_context (AppContext): The application context object holding the model(s)
            model_path (Path): The path to the model file, as a backup to load model directly
            model_name (str): The name of the model, when multiples are loaded in the context
        """

        if app_context.models:
            # `app_context.models.get(model_name)` returns a model instance if exists.
            # If model_name is not specified and only one model exists, it returns that model.
            model = app_context.models.get(model_name)
        else:
            # Fall back to loading the TorchScript file from the given path, in inference mode.
            model = torch.jit.load(
                model_path,
                map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            ).eval()

        return model

    def setup(self, spec: OperatorSpec):
        """Set up the operator named input and named output, both are in-memory objects."""

        spec.input(self.input_name_image)
        spec.input(self.input_name_output_folder).condition(ConditionType.NONE)  # Optional for overriding.
        spec.output(self.output_name_result).condition(ConditionType.NONE)  # Not forcing a downstream receiver.

    @property
    def transform(self):
        return Compose([AddChannel(), ScaleIntensity(), EnsureType()])

    def compute(self, op_input, op_output, context):
        import json

        import torch

        img = op_input.receive(self.input_name_image).asnumpy()  # (64, 64), uint8. Input validation can be added.
        image_tensor = self.transform(img)  # (1, 64, 64), torch.float64
        image_tensor = image_tensor[None].float()  # (1, 1, 64, 64), torch.float32

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        image_tensor = image_tensor.to(device)

        with torch.no_grad():
            outputs = self.model(image_tensor)

        _, output_classes = outputs.max(dim=1)

        result = MEDNIST_CLASSES[output_classes[0]]  # get the class name
        print(result)
        op_output.emit(result, self.output_name_result)

        # Get output folder, with value in optional input port overriding the obj attribute
        output_folder_on_compute = op_input.receive(self.input_name_output_folder) or self.output_folder
        Path.mkdir(output_folder_on_compute, parents=True, exist_ok=True)  # Let exception bubble up if raised.
        output_path = output_folder_on_compute / "output.json"
        with open(output_path, "w") as fp:
            json.dump(result, fp)
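Without the SDK plumbing, the post-processing at the end of compute() does three things. It takes the index of the largest logit, maps it to a class name, and calls json.dump on the string, which is why output.json contains a quoted name. The NumPy/stdlib sketch below reproduces those steps, with the optional output_folder port modeled as a plain override argument. The name write_prediction is hypothetical.

```python
import json
import tempfile
from pathlib import Path
from typing import Optional, Sequence

import numpy as np

MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]


def write_prediction(
    logits: np.ndarray,
    default_folder: Path,
    override_folder: Optional[Path] = None,
    class_names: Sequence[str] = MEDNIST_CLASSES,
) -> str:
    """Pick the top class for the first item of a (batch, classes) array and save it as JSON."""
    result = class_names[int(np.argmax(logits, axis=1)[0])]
    folder = override_folder or default_folder  # mirrors `receive(...) or self.output_folder`
    folder.mkdir(parents=True, exist_ok=True)
    with open(folder / "output.json", "w") as fp:
        json.dump(result, fp)
    return result


with tempfile.TemporaryDirectory() as tmp:
    logits = np.array([[4.2, 0.1, -1.0, 0.3, -2.0, 0.5]])
    print(write_prediction(logits, Path(tmp)))  # AbdomenCT
    print((Path(tmp) / "output.json").read_text())  # "AbdomenCT"
```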

Creating Application class

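Our application class looks like the following.

It defines an App class that inherits from the Application class.

In the compose() method of App, LoadPILOperator is connected to MedNISTClassifierOperator using self.add_flow(). The classifier's result_text output is in turn connected to the text input of the SDK's DICOMTextSRWriterOperator, which writes the classification result as a DICOM Structured Report to the output folder.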

class App(Application):
    """Application class for the MedNIST classifier."""

    def compose(self):
        app_context = Application.init_app_context({})  # Do not pass argv in Jupyter Notebook
        app_input_path = Path(app_context.input_path)
        app_output_path = Path(app_context.output_path)
        model_path = Path(app_context.model_path)
        load_pil_op = LoadPILOperator(self, CountCondition(self, 1), input_folder=app_input_path, name="pil_loader_op")
        classifier_op = MedNISTClassifierOperator(
            self, app_context=app_context, output_folder=app_output_path, model_path=model_path, name="classifier_op"
        )

        my_model_info = ModelInfo("MONAI WG Trainer", "MEDNIST Classifier", "0.1", "xyz")
        my_equipment = EquipmentInfo(manufacturer="MONAI Deploy App SDK", manufacturer_model="DICOM SR Writer")
        my_special_tags = {"SeriesDescription": "Not for clinical use. The result is for research use only."}
        dicom_sr_operator = DICOMTextSRWriterOperator(
            self,
            copy_tags=False,
            model_info=my_model_info,
            equipment_info=my_equipment,
            custom_tags=my_special_tags,
            output_folder=app_output_path,
        )

        self.add_flow(load_pil_op, classifier_op, {("image", "image")})
        self.add_flow(classifier_op, dicom_sr_operator, {("result_text", "text")})
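Each add_flow(source, target, {(output_port, input_port)}) call connects a named output port of one operator to a named input port of the next. The SDK's Holoscan-based executor then schedules the resulting graph. The toy single-pass runner below illustrates in plain Python how the port mappings route values. The names toy_run, ToyOp and ToyFlow and the dict-based ports are hypothetical, and they say nothing about how the SDK is actually implemented. The runner assumes the operators are listed in dependency order.

```python
from typing import Callable, Dict, List, Set, Tuple

# A toy operator: a name plus a function from {input_port: value} to {output_port: value}.
ToyOp = Tuple[str, Callable[[Dict[str, object]], Dict[str, object]]]
# A toy flow: (source op name, target op name, {(source output port, target input port)}).
ToyFlow = Tuple[str, str, Set[Tuple[str, str]]]


def toy_run(ops: List[ToyOp], flows: List[ToyFlow]) -> Dict[str, Dict[str, object]]:
    """Run each operator once, in list order, routing emitted values along the port mappings."""
    inboxes: Dict[str, Dict[str, object]] = {name: {} for name, _ in ops}
    outputs: Dict[str, Dict[str, object]] = {}
    for name, fn in ops:
        outputs[name] = fn(inboxes[name])
        for src, dst, port_map in flows:
            if src == name:
                for out_port, in_port in port_map:
                    inboxes[dst][in_port] = outputs[name][out_port]
    return outputs


ops: List[ToyOp] = [
    ("loader", lambda _: {"image": [[0, 255]]}),
    ("classifier", lambda ins: {"result_text": "AbdomenCT" if ins["image"] else "unknown"}),
    ("writer", lambda ins: {"saved": f"SR: {ins['text']}"}),
]
flows: List[ToyFlow] = [
    ("loader", "classifier", {("image", "image")}),
    ("classifier", "writer", {("result_text", "text")}),
]
print(toy_run(ops, flows)["writer"])  # {'saved': 'SR: AbdomenCT'}
```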

Executing app locally

The test input file, output path, and model have been prepared, and their paths set in the environment variables, so we can go ahead and execute the application in the Jupyter notebook with a clean output folder.

!rm -rf $HOLOSCAN_OUTPUT_PATH
app = App().run()
[info] [gxf_executor.cpp:210] Creating context
[info] [gxf_executor.cpp:1595] Loading extensions from configs...
[info] [gxf_executor.cpp:1741] Activating Graph...
[info] [gxf_executor.cpp:1771] Running Graph...
[info] [gxf_executor.cpp:1773] Waiting for completion...
[info] [gxf_executor.cpp:1774] Graph execution waiting. Fragment: 
[info] [greedy_scheduler.cpp:190] Scheduling 3 entities
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages/monai/data/meta_tensor.py:116: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
  return torch.as_tensor(x, *args, **_kwargs).as_subclass(cls)  # type: ignore
AbdomenCT
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages/pydicom/valuerep.py:443: UserWarning: Invalid value for VR UI: 'xyz'. Please see <https://dicom.nema.org/medical/dicom/current/output/html/part05.html#table_6.2-1> for allowed values for each VR.
  warnings.warn(msg)
[info] [greedy_scheduler.cpp:369] Scheduler stopped: Some entities are waiting for execution, but there are no periodic or async entities to get out of the deadlock.
[info] [greedy_scheduler.cpp:398] Scheduler finished.
[info] [gxf_executor.cpp:1783] Graph execution deactivating. Fragment: 
[info] [gxf_executor.cpp:1784] Deactivating Graph...
[info] [gxf_executor.cpp:1787] Graph execution finished. Fragment: 
[info] [gxf_executor.cpp:229] Destroying context
!cat $HOLOSCAN_OUTPUT_PATH/output.json
"AbdomenCT"

Once the application is verified inside the Jupyter notebook, we can write the whole application to a file (mednist_classifier_monaideploy.py) by concatenating the code above, then adding the following lines:

if __name__ == "__main__":
    App().run()

These lines are needed to execute the application code with the Python interpreter.

# Create an application folder
!mkdir -p mednist_app
!rm -rf mednist_app/*
%%writefile mednist_app/mednist_classifier_monaideploy.py

# Copyright 2021-2023 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from pathlib import Path
from typing import Optional

import torch

from monai.deploy.conditions import CountCondition
from monai.deploy.core import AppContext, Application, ConditionType, Fragment, Image, Operator, OperatorSpec
from monai.deploy.operators.dicom_text_sr_writer_operator import DICOMTextSRWriterOperator, EquipmentInfo, ModelInfo
from monai.transforms import AddChannel, Compose, EnsureType, ScaleIntensity

MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]


# @md.env(pip_packages=["pillow"])
class LoadPILOperator(Operator):
    """Load image from the given input (DataPath) and set numpy array to the output (Image)."""

    DEFAULT_INPUT_FOLDER = Path.cwd() / "input"
    DEFAULT_OUTPUT_NAME = "image"

    # For now, need to have the input folder as an instance attribute, set on init.
    # If dynamically changing the input folder, per compute, then use a (optional) input port to convey the
    # value of the input folder, which is then emitted by an upstream operator.
    def __init__(
        self,
        fragment: Fragment,
        *args,
        input_folder: Path = DEFAULT_INPUT_FOLDER,
        output_name: str = DEFAULT_OUTPUT_NAME,
        **kwargs,
    ):
        """Creates a loader object with the input folder and the output port name overrides as needed.

        Args:
            fragment (Fragment): An instance of the Application class which is derived from Fragment.
            input_folder (Path): Folder from which to load input file(s).
                                 Defaults to `input` in the current working directory.
            output_name (str): Name of the output port, which is an image object. Defaults to `image`.
        """

        self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
        self.input_path = input_folder
        self.index = 0
        self.output_name_image = (
            output_name.strip() if output_name and len(output_name.strip()) > 0 else LoadPILOperator.DEFAULT_OUTPUT_NAME
        )

        super().__init__(fragment, *args, **kwargs)

    def setup(self, spec: OperatorSpec):
        """Set up the named input and output port(s)"""
        spec.output(self.output_name_image)

    def compute(self, op_input, op_output, context):
        import numpy as np
        from PIL import Image as PILImage

        # Input path is stored in the object attribute, but could change to use a named port if need be.
        input_path = self.input_path
        if input_path.is_dir():
            input_path = next(self.input_path.glob("*.*"))  # take the first file

        image = PILImage.open(input_path)
        image = image.convert("L")  # convert to greyscale image
        image_arr = np.asarray(image)

        output_image = Image(image_arr)  # create Image domain object with a numpy array
        op_output.emit(output_image, self.output_name_image)  # cannot omit the name even if single output.


# @md.env(pip_packages=["monai"])
class MedNISTClassifierOperator(Operator):
    """Classifies the given image and returns the class name.

    Named inputs:
        image: Image object for which to generate the classification.
        output_folder: Optional, the path to save the results JSON file, overriding the one set in __init__.

    Named output:
        result_text: The classification results in text.
    """

    DEFAULT_OUTPUT_FOLDER = Path.cwd() / "classification_results"
    # For testing the app directly, the model should be at the following path.
    MODEL_LOCAL_PATH = Path(os.environ.get("HOLOSCAN_MODEL_PATH", Path.cwd() / "model/model.ts"))

    def __init__(
        self,
        fragment: Fragment,
        *args,
        app_context: AppContext,
        model_name: Optional[str] = "",
        model_path: Path = MODEL_LOCAL_PATH,
        output_folder: Path = DEFAULT_OUTPUT_FOLDER,
        **kwargs,
    ):
        """Creates an instance with the reference back to the containing application/fragment.

        Args:
            fragment (Fragment): An instance of the Application class which is derived from Fragment.
            app_context (AppContext): The application context, which may hold the loaded model(s).
            model_name (str, optional): Name of the model. Defaults to "" for a single-model app.
            model_path (Path): Path to the model file. Defaults to $HOLOSCAN_MODEL_PATH, else model/model.ts
                               in the current working dir.
            output_folder (Path, optional): Output folder for saving the classification results JSON file.
        """

        # the names used for the model inference input and output
        self._input_dataset_key = "image"
        self._pred_dataset_key = "pred"

        # The names used for the operator input and output
        self.input_name_image = "image"
        self.output_name_result = "result_text"

        # The name of the optional input port for passing data to override the output folder path.
        self.input_name_output_folder = "output_folder"

        # The output folder set on the object can be overridden at each compute by data in the optional named input
        self.output_folder = output_folder

        # Need the name when there are multiple models loaded
        self._model_name = model_name.strip() if isinstance(model_name, str) else ""
        # Need the path to load the models when they are not loaded in the execution context
        self.model_path = model_path
        self.app_context = app_context
        self.model = self._get_model(self.app_context, self.model_path, self._model_name)

        # This needs to be at the end of the constructor.
        super().__init__(fragment, *args, **kwargs)

    def _get_model(self, app_context: AppContext, model_path: Path, model_name: str):
        """Load the model with the given name from context or model path

        Args:
            app_context (AppContext): The application context object holding the model(s)
            model_path (Path): The path to the model file, as a backup to load model directly
            model_name (str): The name of the model, when multiples are loaded in the context
        """

        if app_context.models:
            # `app_context.models.get(model_name)` returns a model instance if exists.
            # If model_name is not specified and only one model exists, it returns that model.
            model = app_context.models.get(model_name)
        else:
            # Fall back to loading the TorchScript file from the given path, in inference mode.
            model = torch.jit.load(
                model_path,
                map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            ).eval()

        return model

    def setup(self, spec: OperatorSpec):
        """Set up the operator named input and named output, both are in-memory objects."""

        spec.input(self.input_name_image)
        spec.input(self.input_name_output_folder).condition(ConditionType.NONE)  # Optional for overriding.
        spec.output(self.output_name_result).condition(ConditionType.NONE)  # Not forcing a downstream receiver.

    @property
    def transform(self):
        return Compose([AddChannel(), ScaleIntensity(), EnsureType()])

    def compute(self, op_input, op_output, context):
        import json

        import torch

        img = op_input.receive(self.input_name_image).asnumpy()  # (64, 64), uint8. Input validation can be added.
        image_tensor = self.transform(img)  # (1, 64, 64), torch.float64
        image_tensor = image_tensor[None].float()  # (1, 1, 64, 64), torch.float32

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        image_tensor = image_tensor.to(device)

        with torch.no_grad():
            outputs = self.model(image_tensor)

        _, output_classes = outputs.max(dim=1)

        result = MEDNIST_CLASSES[output_classes[0].item()]  # get the class name
        print(result)
        op_output.emit(result, self.output_name_result)

        # Get output folder, with value in optional input port overriding the obj attribute
        output_folder_on_compute = Path(op_input.receive(self.input_name_output_folder) or self.output_folder)
        output_folder_on_compute.mkdir(parents=True, exist_ok=True)  # Let exception bubble up if raised.
        output_path = output_folder_on_compute / "output.json"
        with open(output_path, "w") as fp:
            json.dump(result, fp)


# Resource requirements (CPU, GPU, memory) are declared in app.yaml when packaging.
class App(Application):
    """Application class for the MedNIST classifier."""

    def compose(self):
        app_context = AppContext({})  # Let it figure out all the attributes without overriding
        app_input_path = Path(app_context.input_path)
        app_output_path = Path(app_context.output_path)
        model_path = Path(app_context.model_path)
        load_pil_op = LoadPILOperator(self, CountCondition(self, 1), input_folder=app_input_path, name="pil_loader_op")
        classifier_op = MedNISTClassifierOperator(
            self, app_context=app_context, output_folder=app_output_path, model_path=model_path, name="classifier_op"
        )

        # "xyz" is a placeholder for the UID field; pydicom warns that it is not a valid UI value.
        my_model_info = ModelInfo("MONAI WG Trainer", "MEDNIST Classifier", "0.1", "xyz")
        my_equipment = EquipmentInfo(manufacturer="MONAI Deploy App SDK", manufacturer_model="DICOM SR Writer")
        my_special_tags = {"SeriesDescription": "Not for clinical use. The result is for research use only."}
        dicom_sr_operator = DICOMTextSRWriterOperator(
            self,
            copy_tags=False,
            model_info=my_model_info,
            equipment_info=my_equipment,
            custom_tags=my_special_tags,
            output_folder=app_output_path,
        )

        self.add_flow(load_pil_op, classifier_op, {("image", "image")})
        self.add_flow(classifier_op, dicom_sr_operator, {("result_text", "text")})


if __name__ == "__main__":
    App().run()
Writing mednist_app/mednist_classifier_monaideploy.py
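You can trace the data handling in compute without a GPU, PyTorch, or MONAI. The pure-Python sketch below approximates ScaleIntensity as a min-max rescale to [0, 1]. It approximates the outputs.max(dim=1) step plus the MEDNIST_CLASSES lookup as an argmax over a single sample's scores. The helper names scale_intensity and logits_to_label are hypothetical, and the constant-image branch is a choice made for this sketch. The class list is assumed to match the MEDNIST_CLASSES defined earlier in the tutorial, that is, the six MedNIST categories in sorted order.

```python
from typing import List, Sequence

# Assumption: same order as the MEDNIST_CLASSES defined earlier in the tutorial.
MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]


def scale_intensity(pixels: Sequence[Sequence[int]]) -> List[List[float]]:
    """Min-max rescale a 2D image to [0.0, 1.0] (a simplified ScaleIntensity)."""
    flat = [p for row in pixels for p in row]
    lo, hi = min(flat), max(flat)
    if hi == lo:
        # Constant image: nothing to stretch, so this sketch returns zeros.
        return [[0.0 for _ in row] for row in pixels]
    return [[(p - lo) / (hi - lo) for p in row] for row in pixels]


def logits_to_label(logits: Sequence[float], classes: Sequence[str] = MEDNIST_CLASSES) -> str:
    """Return the class with the highest score, like outputs.max(dim=1) for one sample."""
    if len(logits) != len(classes):
        raise ValueError(f"expected {len(classes)} scores, got {len(logits)}")
    best_index = max(range(len(logits)), key=lambda i: logits[i])
    return classes[best_index]


image = [[0, 128], [255, 64]]  # toy 2x2 image standing in for the 64x64 input
scaled = scale_intensity(image)
batched = [[scaled]]  # add channel and batch dimensions: (1, 1, H, W)
print(scaled[1][0], len(batched), len(batched[0]))
print(logits_to_label([3.2, -1.0, 0.5, 0.1, -2.3, 0.0]))  # AbdomenCT
```

Keeping the label lookup separate from the model call lets you unit-test the post-processing on fixed scores without loading the TorchScript model.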

This time, let’s execute the app from the command line.

!python "mednist_app/mednist_classifier_monaideploy.py"
[info] [gxf_executor.cpp:210] Creating context
[info] [gxf_executor.cpp:1595] Loading extensions from configs...
[info] [gxf_executor.cpp:1741] Activating Graph...
[info] [gxf_executor.cpp:1771] Running Graph...
[info] [gxf_executor.cpp:1773] Waiting for completion...
[info] [gxf_executor.cpp:1774] Graph execution waiting. Fragment: 
[info] [greedy_scheduler.cpp:190] Scheduling 3 entities
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages/monai/utils/deprecate_utils.py:111: FutureWarning: <class 'monai.transforms.utility.array.AddChannel'>: Class `AddChannel` has been deprecated since version 0.8. It will be removed in version 1.3. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead with `channel_dim='no_channel'`.
  warn_deprecated(obj, msg, warning_category)
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages/monai/data/meta_tensor.py:116: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
  return torch.as_tensor(x, *args, **_kwargs).as_subclass(cls)  # type: ignore
AbdomenCT
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages/pydicom/valuerep.py:443: UserWarning: Invalid value for VR UI: 'xyz'. Please see <https://dicom.nema.org/medical/dicom/current/output/html/part05.html#table_6.2-1> for allowed values for each VR.
  warnings.warn(msg)
[info] [greedy_scheduler.cpp:369] Scheduler stopped: Some entities are waiting for execution, but there are no periodic or async entities to get out of the deadlock.
[info] [greedy_scheduler.cpp:398] Scheduler finished.
[info] [gxf_executor.cpp:1783] Graph execution deactivating. Fragment: 
[info] [gxf_executor.cpp:1784] Deactivating Graph...
[info] [gxf_executor.cpp:1787] Graph execution finished. Fragment: 
[info] [gxf_executor.cpp:229] Destroying context
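The classifier writes the predicted label as a JSON string to output.json in its output folder. A folder received on the optional input port takes precedence over the one set at construction time. The stdlib-only sketch below reproduces both steps, which is useful when you want to check results from a script rather than with cat. The function names are hypothetical. Falling back to a local "output" folder when HOLOSCAN_OUTPUT_PATH is unset is an assumption of this sketch.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Optional, Union


def resolve_output_folder(received: Optional[Union[str, Path]], default: Path) -> Path:
    """Mirror `op_input.receive(...) or self.output_folder`, always returning a Path."""
    return Path(received) if received else Path(default)


def write_result(label: str, folder: Path) -> Path:
    """Write the label as a JSON string, as the classifier's compute() does."""
    folder.mkdir(parents=True, exist_ok=True)
    output_path = folder / "output.json"
    with open(output_path, "w") as fp:
        json.dump(label, fp)
    return output_path


def read_result(folder: Path) -> str:
    """Load the label back from output.json."""
    with open(folder / "output.json") as fp:
        return json.load(fp)


# Assumption: fall back to ./output when the environment variable is not set.
default_folder = Path(os.environ.get("HOLOSCAN_OUTPUT_PATH", "output"))
with tempfile.TemporaryDirectory() as tmp:
    folder = resolve_output_folder(tmp, default_folder)  # the received value wins
    write_result("AbdomenCT", folder)
    print(read_result(folder))
```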
!cat $HOLOSCAN_OUTPUT_PATH/output.json
"AbdomenCT"

Packaging app

Let’s package the app with the MONAI Application Packager.

In this version of the App SDK, we need to write the application configuration YAML file and the package requirements file into the application folder.

%%writefile mednist_app/app.yaml
%YAML 1.2
---
application:
  title: MONAI Deploy App Package - MedNIST Classifier App
  version: 1.0
  inputFormats: ["file"]
  outputFormats: ["file"]

resources:
  cpu: 1
  gpu: 1
  memory: 1Gi
  gpuMemory: 1Gi
Writing mednist_app/app.yaml
%%writefile mednist_app/requirements.txt
monai>=1.2.0
Pillow>=8.4.0
pydicom>=2.3.0
highdicom>=0.18.2
SimpleITK>=2.0.0
setuptools>=59.5.0 # for pkg_resources
Writing mednist_app/requirements.txt
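Outside Jupyter, the %%writefile magic is not available. The two cells above can be reproduced with pathlib, as sketched below. The file contents are copied from those cells. The helper name write_packaging_files is hypothetical, and this sketch assumes the files belong next to the application script, as in this tutorial.

```python
import tempfile
from pathlib import Path

APP_YAML = """\
%YAML 1.2
---
application:
  title: MONAI Deploy App Package - MedNIST Classifier App
  version: 1.0
  inputFormats: ["file"]
  outputFormats: ["file"]

resources:
  cpu: 1
  gpu: 1
  memory: 1Gi
  gpuMemory: 1Gi
"""

REQUIREMENTS = """\
monai>=1.2.0
Pillow>=8.4.0
pydicom>=2.3.0
highdicom>=0.18.2
SimpleITK>=2.0.0
setuptools>=59.5.0 # for pkg_resources
"""


def write_packaging_files(app_dir: Path) -> None:
    """Write app.yaml and requirements.txt into the application folder."""
    app_dir.mkdir(parents=True, exist_ok=True)
    (app_dir / "app.yaml").write_text(APP_YAML)
    (app_dir / "requirements.txt").write_text(REQUIREMENTS)


# Demo in a throwaway directory; in the tutorial you would pass Path("mednist_app").
with tempfile.TemporaryDirectory() as tmp:
    write_packaging_files(Path(tmp) / "mednist_app")
    print(sorted(p.name for p in (Path(tmp) / "mednist_app").iterdir()))
```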
tag_prefix = "mednist_app"

!monai-deploy package "mednist_app/mednist_classifier_monaideploy.py" -m {models_folder} -c "mednist_app/app.yaml" -t {tag_prefix}:1.0 --platform x64-workstation -l DEBUG
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages/pydantic/_internal/_config.py:269: UserWarning: Valid config keys have changed in V2:
* 'allow_population_by_field_name' has been renamed to 'populate_by_name'
  warnings.warn(message, UserWarning)
[2023-08-03 20:49:29,599] [INFO] (packager.parameters) - Application: /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/mednist_app/mednist_classifier_monaideploy.py
[2023-08-03 20:49:29,599] [INFO] (packager.parameters) - Detected application type: Python File
[2023-08-03 20:49:29,599] [INFO] (packager) - Scanning for models in {models_path}...
[2023-08-03 20:49:29,599] [DEBUG] (packager) - Model model=/home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/models/model added.
[2023-08-03 20:49:29,599] [INFO] (packager) - Reading application configuration from /home/mqin/src/monai-deploy-app-sdk/notebooks/tutorials/mednist_app/app.yaml...
[2023-08-03 20:49:29,601] [INFO] (packager) - Generating app.json...
[2023-08-03 20:49:29,601] [INFO] (packager) - Generating pkg.json...
[2023-08-03 20:49:29,601] [DEBUG] (common) - 
=============== Begin app.json ===============
{
    "apiVersion": "1.0.0",
    "command": "[\"python3\", \"/opt/holoscan/app/mednist_classifier_monaideploy.py\"]",
    "environment": {
        "HOLOSCAN_APPLICATION": "/opt/holoscan/app",
        "HOLOSCAN_INPUT_PATH": "input/",
        "HOLOSCAN_OUTPUT_PATH": "output/",
        "HOLOSCAN_WORKDIR": "/var/holoscan",
        "HOLOSCAN_MODEL_PATH": "/opt/holoscan/models",
        "HOLOSCAN_CONFIG_PATH": "/var/holoscan/app.yaml",
        "HOLOSCAN_APP_MANIFEST_PATH": "/etc/holoscan/app.json",
        "HOLOSCAN_PKG_MANIFEST_PATH": "/etc/holoscan/pkg.json",
        "HOLOSCAN_DOCS_PATH": "/opt/holoscan/docs",
        "HOLOSCAN_LOGS_PATH": "/var/holoscan/logs"
    },
    "input": {
        "path": "input/",
        "formats": null
    },
    "liveness": null,
    "output": {
        "path": "output/",
        "formats": null
    },
    "readiness": null,
    "sdk": "monai-deploy",
    "sdkVersion": "0.6.0",
    "timeout": 0,
    "version": 1.0,
    "workingDirectory": "/var/holoscan"
}
================ End app.json ================
                 
[2023-08-03 20:49:29,602] [DEBUG] (common) - 
=============== Begin pkg.json ===============
{
    "apiVersion": "1.0.0",
    "applicationRoot": "/opt/holoscan/app",
    "modelRoot": "/opt/holoscan/models",
    "models": {
        "model": "/opt/holoscan/models"
    },
    "resources": {
        "cpu": 1,
        "gpu": 1,
        "memory": "1Gi",
        "gpuMemory": "1Gi"
    },
    "version": 1.0
}
================ End pkg.json ================
                 
[2023-08-03 20:49:29,635] [DEBUG] (packager.builder) - 
========== Begin Dockerfile ==========


FROM nvcr.io/nvidia/clara-holoscan/holoscan:v0.6.0-dgpu

ENV DEBIAN_FRONTEND=noninteractive
ENV TERM=xterm-256color

ARG UNAME
ARG UID
ARG GID

RUN mkdir -p /etc/holoscan/ \
        && mkdir -p /opt/holoscan/ \
        && mkdir -p /var/holoscan \
        && mkdir -p /opt/holoscan/app \
        && mkdir -p /var/holoscan/input \
        && mkdir -p /var/holoscan/output

LABEL base="nvcr.io/nvidia/clara-holoscan/holoscan:v0.6.0-dgpu"
LABEL tag="mednist_app:1.0"
LABEL org.opencontainers.image.title="MONAI Deploy App Package - MedNIST Classifier App"
LABEL org.opencontainers.image.version="1.0"
LABEL org.nvidia.holoscan="0.6.0"

ENV HOLOSCAN_ENABLE_HEALTH_CHECK=true
ENV HOLOSCAN_INPUT_PATH=/var/holoscan/input
ENV HOLOSCAN_OUTPUT_PATH=/var/holoscan/output
ENV HOLOSCAN_WORKDIR=/var/holoscan
ENV HOLOSCAN_APPLICATION=/opt/holoscan/app
ENV HOLOSCAN_TIMEOUT=0
ENV HOLOSCAN_MODEL_PATH=/opt/holoscan/models
ENV HOLOSCAN_DOCS_PATH=/opt/holoscan/docs
ENV HOLOSCAN_CONFIG_PATH=/var/holoscan/app.yaml
ENV HOLOSCAN_APP_MANIFEST_PATH=/etc/holoscan/app.json
ENV HOLOSCAN_PKG_MANIFEST_PATH=/etc/holoscan/pkg.json
ENV HOLOSCAN_LOGS_PATH=/var/holoscan/logs
ENV PATH=/root/.local/bin:/opt/nvidia/holoscan:$PATH
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/libtorch/1.13.1/lib/:/opt/nvidia/holoscan/lib

RUN apt-get update \
    && apt-get install -y curl jq \
    && rm -rf /var/lib/apt/lists/*

ENV PYTHONPATH="/opt/holoscan/app:$PYTHONPATH"



RUN groupadd -g $GID $UNAME
RUN useradd -rm -d /home/$UNAME -s /bin/bash -g $GID -G sudo -u $UID $UNAME
RUN chown -R holoscan /var/holoscan 
RUN chown -R holoscan /var/holoscan/input 
RUN chown -R holoscan /var/holoscan/output 

# Set the working directory
WORKDIR /var/holoscan

# Copy HAP/MAP tool script
COPY ./tools /var/holoscan/tools
RUN chmod +x /var/holoscan/tools


# Copy gRPC health probe

USER $UNAME

ENV PATH=/root/.local/bin:/home/holoscan/.local/bin:/opt/nvidia/holoscan:$PATH

COPY ./pip/requirements.txt /tmp/requirements.txt

RUN pip install --upgrade pip
RUN pip install --no-cache-dir --user -r /tmp/requirements.txt

# Install Holoscan from PyPI org
RUN pip install holoscan==0.6.0


# Copy user-specified MONAI Deploy SDK file
COPY ./monai_deploy_app_sdk-0.5.1+7.g9fa1185.dirty-py3-none-any.whl /tmp/monai_deploy_app_sdk-0.5.1+7.g9fa1185.dirty-py3-none-any.whl
RUN pip install /tmp/monai_deploy_app_sdk-0.5.1+7.g9fa1185.dirty-py3-none-any.whl




COPY ./models  /opt/holoscan/models

COPY ./map/app.json /etc/holoscan/app.json
COPY ./app.config /var/holoscan/app.yaml
COPY ./map/pkg.json /etc/holoscan/pkg.json

COPY ./app /opt/holoscan/app

ENTRYPOINT ["/var/holoscan/tools"]
=========== End Dockerfile ===========

[2023-08-03 20:49:29,636] [INFO] (packager.builder) - 
===============================================================================
Building image for:                 x64-workstation
    Architecture:                   linux/amd64
    Base Image:                     nvcr.io/nvidia/clara-holoscan/holoscan:v0.6.0-dgpu
    Build Image:                    N/A  
    Cache:                          Enabled
    Configuration:                  dgpu
    Holoiscan SDK Package:          pypi.org
    MONAI Deploy App SDK Package:   /home/mqin/src/monai-deploy-app-sdk/dist/monai_deploy_app_sdk-0.5.1+7.g9fa1185.dirty-py3-none-any.whl
    gRPC Health Probe:              N/A
    SDK Version:                    0.6.0
    SDK:                            monai-deploy
    Tag:                            mednist_app-x64-workstation-dgpu-linux-amd64:1.0
    
[2023-08-03 20:49:30,337] [INFO] (common) - Using existing Docker BuildKit builder `holoscan_app_builder`
[2023-08-03 20:49:30,338] [DEBUG] (packager.builder) - Building Holoscan Application Package: tag=mednist_app-x64-workstation-dgpu-linux-amd64:1.0
#1 [internal] load .dockerignore
#1 transferring context: 1.79kB 0.0s done
#1 DONE 0.1s

#2 [internal] load build definition from Dockerfile
#2 transferring dockerfile: 2.67kB done
#2 DONE 0.1s

#3 [internal] load metadata for nvcr.io/nvidia/clara-holoscan/holoscan:v0.6.0-dgpu
#3 DONE 0.8s

#4 [internal] load build context
#4 DONE 0.0s

#5 importing cache manifest from local:9585092855700183608
#5 DONE 0.0s

#6 importing cache manifest from nvcr.io/nvidia/clara-holoscan/holoscan:v0.6.0-dgpu
#6 DONE 0.9s

#7 [ 1/22] FROM nvcr.io/nvidia/clara-holoscan/holoscan:v0.6.0-dgpu@sha256:9653f80f241fd542f25afbcbcf7a0d02ed7e5941c79763e69def5b1e6d9fb7bc
#7 resolve nvcr.io/nvidia/clara-holoscan/holoscan:v0.6.0-dgpu@sha256:9653f80f241fd542f25afbcbcf7a0d02ed7e5941c79763e69def5b1e6d9fb7bc
#7 resolve nvcr.io/nvidia/clara-holoscan/holoscan:v0.6.0-dgpu@sha256:9653f80f241fd542f25afbcbcf7a0d02ed7e5941c79763e69def5b1e6d9fb7bc 0.1s done
#7 DONE 0.1s

#4 [internal] load build context
#4 transferring context: 28.78MB 0.2s done
#4 DONE 0.3s

#8 [ 6/22] RUN chown -R holoscan /var/holoscan
#8 CACHED

#9 [12/22] COPY ./pip/requirements.txt /tmp/requirements.txt
#9 CACHED

#10 [10/22] COPY ./tools /var/holoscan/tools
#10 CACHED

#11 [15/22] RUN pip install holoscan==0.6.0
#11 CACHED

#12 [16/22] COPY ./monai_deploy_app_sdk-0.5.1+7.g9fa1185.dirty-py3-none-any.whl /tmp/monai_deploy_app_sdk-0.5.1+7.g9fa1185.dirty-py3-none-any.whl
#12 CACHED

#13 [ 7/22] RUN chown -R holoscan /var/holoscan/input
#13 CACHED

#14 [ 4/22] RUN groupadd -g 1000 holoscan
#14 CACHED

#15 [14/22] RUN pip install --no-cache-dir --user -r /tmp/requirements.txt
#15 CACHED

#16 [11/22] RUN chmod +x /var/holoscan/tools
#16 CACHED

#17 [ 5/22] RUN useradd -rm -d /home/holoscan -s /bin/bash -g 1000 -G sudo -u 1000 holoscan
#17 CACHED

#18 [ 9/22] WORKDIR /var/holoscan
#18 CACHED

#19 [ 8/22] RUN chown -R holoscan /var/holoscan/output
#19 CACHED

#20 [ 2/22] RUN mkdir -p /etc/holoscan/         && mkdir -p /opt/holoscan/         && mkdir -p /var/holoscan         && mkdir -p /opt/holoscan/app         && mkdir -p /var/holoscan/input         && mkdir -p /var/holoscan/output
#20 CACHED

#21 [ 3/22] RUN apt-get update     && apt-get install -y curl jq     && rm -rf /var/lib/apt/lists/*
#21 CACHED

#22 [13/22] RUN pip install --upgrade pip
#22 CACHED

#23 [17/22] RUN pip install /tmp/monai_deploy_app_sdk-0.5.1+7.g9fa1185.dirty-py3-none-any.whl
#23 sha256:b9874c0e107000cd0157c2f6c12f6b095ec79e92b293ef581984b94c75080523 2.40GB / 2.40GB 49.8s done
#23 extracting sha256:b9874c0e107000cd0157c2f6c12f6b095ec79e92b293ef581984b94c75080523
#23 extracting sha256:b9874c0e107000cd0157c2f6c12f6b095ec79e92b293ef581984b94c75080523 57.5s done
#23 sha256:f20d17e4fd485b1a37bb580c6b5e8b8d707b382d387df57004086b8036ddaded 105.68MB / 105.68MB 2.4s done
#23 extracting sha256:f20d17e4fd485b1a37bb580c6b5e8b8d707b382d387df57004086b8036ddaded
#23 extracting sha256:f20d17e4fd485b1a37bb580c6b5e8b8d707b382d387df57004086b8036ddaded 2.9s done
#23 sha256:55e32cef42f992f9c914515dc95457ad65a501d20fb8face7a82d51a620e8d0c 149.04kB / 149.04kB 0.0s done
#23 extracting sha256:55e32cef42f992f9c914515dc95457ad65a501d20fb8face7a82d51a620e8d0c 0.0s done
#23 sha256:806c67c703b35fc283718dc9a3a7062a0303aabbea1395b138e66c10cd915f56 48.57MB / 48.57MB 1.1s done
#23 extracting sha256:806c67c703b35fc283718dc9a3a7062a0303aabbea1395b138e66c10cd915f56
#23 extracting sha256:806c67c703b35fc283718dc9a3a7062a0303aabbea1395b138e66c10cd915f56 2.2s done
#23 CACHED

#24 [18/22] COPY ./models  /opt/holoscan/models
#24 DONE 3.1s

#25 [19/22] COPY ./map/app.json /etc/holoscan/app.json
#25 DONE 0.1s

#26 [20/22] COPY ./app.config /var/holoscan/app.yaml
#26 DONE 0.1s

#27 [21/22] COPY ./map/pkg.json /etc/holoscan/pkg.json
#27 DONE 0.1s

#28 [22/22] COPY ./app /opt/holoscan/app
#28 DONE 0.1s

#29 exporting to docker image format
#29 exporting layers
#29 exporting layers 1.1s done
#29 exporting manifest sha256:c788c08ceb970d2b6c8a36eaf5d5809a959ed6ac92e387bf964a3b4998d3f2af 0.0s done
#29 exporting config sha256:d22d232013f038d48c43d4caa2246268674e9c6c81083f7ab6c3f37ec7ce31e2 0.0s done
#29 sending tarball
#29 ...

#30 importing to docker
#30 DONE 1.3s

#29 exporting to docker image format
#29 sending tarball 52.6s done
#29 DONE 53.8s

#31 exporting content cache
#31 preparing build cache for export
#31 writing layer sha256:0709800848b4584780b40e7e81200689870e890c38b54e96b65cd0a3b1942f2d done
#31 writing layer sha256:0ce020987cfa5cd1654085af3bb40779634eb3d792c4a4d6059036463ae0040d done
#31 writing layer sha256:0f65089b284381bf795d15b1a186e2a8739ea957106fa526edef0d738e7cda70 done
#31 writing layer sha256:12a47450a9f9cc5d4edab65d0f600dbbe8b23a1663b0b3bb2c481d40e074b580 done
#31 writing layer sha256:1338fe24653eba781a71bd79902b5b905624589983ce80c816a09bda7b89e3bd
#31 writing layer sha256:1338fe24653eba781a71bd79902b5b905624589983ce80c816a09bda7b89e3bd 0.6s done
#31 writing layer sha256:1477e9e55f1216fe4085565e21baa742149b480d35141f298402b1e766fb58d3 0.0s done
#31 writing layer sha256:1de965777e2e37c7fabe00bdbf3d0203ca83ed30a71a5479c3113fe4fc48c4bb done
#31 writing layer sha256:24b5aa2448e920814dd67d7d3c0169b2cdacb13c4048d74ded3b4317843b13ff done
#31 writing layer sha256:2d42104dbf0a7cc962b791f6ab4f45a803f8a36d296f996aca180cfb2f3e30d0 done
#31 writing layer sha256:2fa1ce4fa3fec6f9723380dc0536b7c361d874add0baaddc4bbf2accac82d2ff
#31 writing layer sha256:2fa1ce4fa3fec6f9723380dc0536b7c361d874add0baaddc4bbf2accac82d2ff done
#31 writing layer sha256:38794be1b5dc99645feabf89b22cd34fb5bdffb5164ad920e7df94f353efe9c0 done
#31 writing layer sha256:38f963dc57c1e7b68a738fe39ed9f9345df7188111a047e2163a46648d7f1d88 done
#31 writing layer sha256:3e7e4c9bc2b136814c20c04feb4eea2b2ecf972e20182d88759931130cfb4181 done
#31 writing layer sha256:3fd77037ad585442cd82d64e337f49a38ddba50432b2a1e563a48401d25c79e6 done
#31 writing layer sha256:41814ed91034b30ac9c44dfc604a4bade6138005ccf682372c02e0bead66dbc0 done
#31 writing layer sha256:45893188359aca643d5918c9932da995364dc62013dfa40c075298b1baabece3 done
#31 writing layer sha256:49bc651b19d9e46715c15c41b7c0daa007e8e25f7d9518f04f0f06592799875a done
#31 writing layer sha256:4aeb0049534a685f9b8d851171ca3ee850fc1609d85e651ebdb0508d8d1e9403 0.0s done
#31 writing layer sha256:4c12db5118d8a7d909e4926d69a2192d2b3cd8b110d49c7504a4f701258c1ccc done
#31 writing layer sha256:4cc43a803109d6e9d1fd35495cef9b1257035f5341a2db54f7a1940815b6cc65 done
#31 writing layer sha256:4d32b49e2995210e8937f0898327f196d3fcc52486f0be920e8b2d65f150a7ab done
#31 writing layer sha256:4d6fe980bad9cd7b2c85a478c8033cae3d098a81f7934322fb64658b0c8f9854 done
#31 writing layer sha256:4f4fb700ef54461cfa02571ae0db9a0dc1e0cdb5577484a6d75e68dc38e8acc1 done
#31 writing layer sha256:5150182f1ff123399b300ca469e00f6c4d82e1b9b72652fb8ee7eab370245236 done
#31 writing layer sha256:55e32cef42f992f9c914515dc95457ad65a501d20fb8face7a82d51a620e8d0c done
#31 writing layer sha256:595c38fa102c61c3dda19bdab70dcd26a0e50465b986d022a84fa69023a05d0f done
#31 writing layer sha256:59d451175f6950740e26d38c322da0ef67cb59da63181eb32996f752ba8a2f17 done
#31 writing layer sha256:5ad1f2004580e415b998124ea394e9d4072a35d70968118c779f307204d6bd17 done
#31 writing layer sha256:62598eafddf023e7f22643485f4321cbd51ff7eee743b970db12454fd3c8c675 done
#31 writing layer sha256:63d7e616a46987136f4cc9eba95db6f6327b4854cfe3c7e20fed6db0c966e380 done
#31 writing layer sha256:6939d591a6b09b14a437e5cd2d6082a52b6d76bec4f72d960440f097721da34f done
#31 writing layer sha256:698318e5a60e5e0d48c45bf992f205a9532da567fdfe94bd59be2e192975dd6f done
#31 writing layer sha256:6ddc1d0f91833b36aac1c6f0c8cea005c87d94bab132d46cc06d9b060a81cca3 done
#31 writing layer sha256:74ac1f5a47c0926bff1e997bb99985a09926f43bd0895cb27ceb5fa9e95f8720 done
#31 writing layer sha256:7577973918dd30e764733a352a93f418000bc3181163ca451b2307492c1a6ba9 done
#31 writing layer sha256:806c67c703b35fc283718dc9a3a7062a0303aabbea1395b138e66c10cd915f56 done
#31 writing layer sha256:886c886d8a09d8befb92df75dd461d4f97b77d7cff4144c4223b0d2f6f2c17f2 done
#31 writing layer sha256:8a7451db9b4b817b3b33904abddb7041810a4ffe8ed4a034307d45d9ae9b3f2a done
#31 writing layer sha256:916f4054c6e7f10de4fd7c08ffc75fa23ebecca4eceb8183cb1023b33b1696c9 done
#31 writing layer sha256:9463aa3f56275af97693df69478a2dc1d171f4e763ca6f7b6f370a35e605c154 done
#31 writing layer sha256:955fd173ed884230c2eded4542d10a97384b408537be6bbb7c4ae09ccd6fb2d0 done
#31 writing layer sha256:99ef644a0c84a569b9692a76e0c6a1c3e9dedae5d551087be684b6bc1bea6f22 done
#31 writing layer sha256:9c42a4ee99755f441251e6043b2cbba16e49818a88775e7501ec17e379ce3cfd done
#31 writing layer sha256:9c63be0a86e3dc4168db3814bf464e40996afda0031649d9faa8ff7568c3154f done
#31 writing layer sha256:9e04bda98b05554953459b5edef7b2b14d32f1a00b979a23d04b6eb5c191e66b done
#31 writing layer sha256:a4a0c690bc7da07e592514dccaa26098a387e8457f69095e922b6d73f7852502 done
#31 writing layer sha256:a4aafbc094d78a85bef41036173eb816a53bcd3e2564594a32f542facdf2aba6 done
#31 writing layer sha256:ae36a4d38b76948e39a5957025c984a674d2de18ce162a8caaa536e6f06fccea done
#31 writing layer sha256:b2fa40114a4a0725c81b327df89c0c3ed5c05ca9aa7f1157394d5096cf5460ce done
#31 writing layer sha256:b48a5fafcaba74eb5d7e7665601509e2889285b50a04b5b639a23f8adc818157 done
#31 writing layer sha256:b8ec9058cfc8057a4989af89add416b6d4c425cb3e3a4542281d3b188ef8d97f 0.0s done
#31 writing layer sha256:b9874c0e107000cd0157c2f6c12f6b095ec79e92b293ef581984b94c75080523
#31 preparing build cache for export 1.3s done
#31 writing layer sha256:b9874c0e107000cd0157c2f6c12f6b095ec79e92b293ef581984b94c75080523 done
#31 writing layer sha256:c86976a083599e36a6441f36f553627194d05ea82bb82a78682e718fe62fccf6 done
#31 writing layer sha256:cb506fbdedc817e3d074f609e2edbf9655aacd7784610a1bbac52f2d7be25438 done
#31 writing layer sha256:d2a6fe65a1f84edb65b63460a75d1cac1aa48b72789006881b0bcfd54cd01ffd done
#31 writing layer sha256:d674572cae0440d01b016bd1a6cf88924f6067f38858706ad4856d78993a0a6e done
#31 writing layer sha256:d709c3fc82181d7bc3561f087363554add07059d1fc1fa014d3da3f9092a7524 0.0s done
#31 writing layer sha256:d8d16d6af76dc7c6b539422a25fdad5efb8ada5a8188069fcd9d113e3b783304 done
#31 writing layer sha256:ddc2ade4f6fe866696cb638c8a102cb644fa842c2ca578392802b3e0e5e3bcb7 done
#31 writing layer sha256:e2cfd7f6244d6f35befa6bda1caa65f1786cecf3f00ef99d7c9a90715ce6a03c done
#31 writing layer sha256:e94a4481e9334ff402bf90628594f64a426672debbdfb55f1290802e52013907 done
#31 writing layer sha256:eaf45e9f32d1f5a9983945a1a9f8dedbb475bc0f578337610e00b4dedec87c20 done
#31 writing layer sha256:eb411bef39c013c9853651e68f00965dbd826d829c4e478884a2886976e9c989 done
#31 writing layer sha256:edfe4a95eb6bd3142aeda941ab871ffcc8c19cf50c33561c210ba8ead2424759 done
#31 writing layer sha256:ef4466d6f927d29d404df9c5af3ef5733c86fa14e008762c90110b963978b1e7 done
#31 writing layer sha256:f20d17e4fd485b1a37bb580c6b5e8b8d707b382d387df57004086b8036ddaded done
#31 writing layer sha256:f346e3ecdf0bee048fa1e3baf1d3128ff0283b903f03e97524944949bd8882e5 done
#31 writing layer sha256:f3f9a00a1ce9aadda250aacb3e66a932676badc5d8519c41517fdf7ea14c13ed done
#31 writing layer sha256:fd849d9bd8889edd43ae38e9f21a912430c8526b2c18f3057a3b2cd74eb27b31 done
#31 writing config sha256:23fbbd00b006000bbd87f5dfe7e12fe71203e710a83580e1ee63125c214ff4d5 0.0s done
#31 writing manifest sha256:7b8dadf0182c3fbe7dede6a29963defd6eff4efa3a6b8e81e5f2dceaaf023210 0.0s done
#31 DONE 1.3s
[2023-08-03 20:52:32,339] [INFO] (packager) - Build Summary:

Platform: x64-workstation/dgpu
    Status:     Succeeded
    Docker Tag: mednist_app-x64-workstation-dgpu-linux-amd64:1.0
    Tarball:    None

Note

Building a MONAI Application Package (Docker image) can take time. Use the -l DEBUG option to see the build progress.
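The packager does not use the -t value verbatim. With -t mednist_app:1.0 and --platform x64-workstation, the build summary above reports the tag mednist_app-x64-workstation-dgpu-linux-amd64:1.0, and that full name is what monai-deploy run needs later. The helper below reproduces the naming pattern observed in this log only. It is a hypothetical illustration, not the packager's own code, and other platforms or SDK versions may name images differently.

```python
def expected_map_tag(
    tag: str,
    platform: str,
    config: str = "dgpu",
    os_name: str = "linux",
    arch: str = "amd64",
) -> str:
    """Hypothetical: rebuild the image tag pattern seen in this tutorial's build summary."""
    name, sep, version = tag.partition(":")
    if not sep:
        raise ValueError("tag is expected in the form name:version")
    # Observed pattern: <name>-<platform>-<configuration>-<os>-<arch>:<version>
    return f"{name}-{platform}-{config}-{os_name}-{arch}:{version}"


print(expected_map_tag("mednist_app:1.0", "x64-workstation"))
```

When in doubt, the Docker Tag line of the build summary, or docker image ls, is the authoritative source for the name.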

We can confirm that the Docker image has been created.

!docker image ls | grep {tag_prefix}
mednist_app-x64-workstation-dgpu-linux-amd64              1.0                        d22d232013f0   59 seconds ago   15.4GB

Executing packaged app locally

We could display and export the MAP manifests, but in this example we simply run the MAP with the MONAI Application Runner.

# Clear the output folder and run the MAP. The input is expected to be a folder.
!rm -rf $HOLOSCAN_OUTPUT_PATH
!monai-deploy run -i $HOLOSCAN_INPUT_PATH -o $HOLOSCAN_OUTPUT_PATH mednist_app-x64-workstation-dgpu-linux-amd64:1.0
/home/mqin/src/monai-deploy-app-sdk/.venv/lib/python3.8/site-packages/pydantic/_internal/_config.py:269: UserWarning: Valid config keys have changed in V2:
* 'allow_population_by_field_name' has been renamed to 'populate_by_name'
  warnings.warn(message, UserWarning)
[2023-08-03 20:52:37,269] [INFO] (runner) - Checking dependencies...
[2023-08-03 20:52:37,269] [INFO] (runner) - --> Verifying if "docker" is installed...

[2023-08-03 20:52:37,270] [INFO] (runner) - --> Verifying if "docker-buildx" is installed...

[2023-08-03 20:52:37,270] [INFO] (runner) - --> Verifying if "mednist_app-x64-workstation-dgpu-linux-amd64:1.0" is available...

[2023-08-03 20:52:37,348] [INFO] (runner) - Reading HAP/MAP manifest...
Successfully copied 2.56kB to /tmp/tmp4dplyjgr/app.json
Successfully copied 2.05kB to /tmp/tmp4dplyjgr/pkg.json
[2023-08-03 20:52:37,733] [INFO] (runner) - --> Verifying if "nvidia-ctk" is installed...

[2023-08-03 20:52:37,954] [INFO] (common) - Launching container (96d09cbab602) using image 'mednist_app-x64-workstation-dgpu-linux-amd64:1.0'...
    container name:      determined_maxwell
    host name:           mingq-dt
    network:             host
    user:                1000:1000
    ulimits:             memlock=-1:-1, stack=67108864:67108864
    cap_add:             CAP_SYS_PTRACE
    ipc mode:            host
    shared memory size:  67108864
    devices:             
2023-08-04 03:52:38 [INFO] Launching application python3 /opt/holoscan/app/mednist_classifier_monaideploy.py ...

[info] [app_driver.cpp:1025] Launching the driver/health checking service

[info] [gxf_executor.cpp:210] Creating context

[info] [server.cpp:73] Health checking server listening on 0.0.0.0:8777

[info] [gxf_executor.cpp:1595] Loading extensions from configs...

[info] [gxf_executor.cpp:1741] Activating Graph...

[info] [gxf_executor.cpp:1771] Running Graph...

[info] [gxf_executor.cpp:1773] Waiting for completion...

[info] [gxf_executor.cpp:1774] Graph execution waiting. Fragment: 

[info] [greedy_scheduler.cpp:190] Scheduling 3 entities

[info] [greedy_scheduler.cpp:369] Scheduler stopped: Some entities are waiting for execution, but there are no periodic or async entities to get out of the deadlock.

[info] [greedy_scheduler.cpp:398] Scheduler finished.

[info] [gxf_executor.cpp:1783] Graph execution deactivating. Fragment: 

[info] [gxf_executor.cpp:1784] Deactivating Graph...

[info] [gxf_executor.cpp:1787] Graph execution finished. Fragment: 

[info] [gxf_executor.cpp:229] Destroying context

/home/holoscan/.local/lib/python3.8/site-packages/monai/utils/deprecate_utils.py:111: FutureWarning: <class 'monai.transforms.utility.array.AddChannel'>: Class `AddChannel` has been deprecated since version 0.8. It will be removed in version 1.3. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead with `channel_dim='no_channel'`.

  warn_deprecated(obj, msg, warning_category)

/home/holoscan/.local/lib/python3.8/site-packages/monai/data/meta_tensor.py:116: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)

  return torch.as_tensor(x, *args, **_kwargs).as_subclass(cls)  # type: ignore

/home/holoscan/.local/lib/python3.8/site-packages/pydicom/valuerep.py:443: UserWarning: Invalid value for VR UI: 'xyz'. Please see <https://dicom.nema.org/medical/dicom/current/output/html/part05.html#table_6.2-1> for allowed values for each VR.

  warnings.warn(msg)

AbdomenCT

[2023-08-03 20:52:51,293] [INFO] (common) - Container 'determined_maxwell'(96d09cbab602) exited.
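If you prefer a plain Python script to notebook shell escapes, you can wrap the rm -rf and monai-deploy run commands above as shown below. This sketch uses only the -i and -o options shown in this tutorial. It assumes HOLOSCAN_INPUT_PATH and HOLOSCAN_OUTPUT_PATH are already set in the environment, as they are in this notebook. The helper names are hypothetical.

```python
import os
import shlex
import shutil
import subprocess
from typing import List


def build_run_command(input_path: str, output_path: str, image_tag: str) -> List[str]:
    """Assemble the same `monai-deploy run -i ... -o ... <tag>` call used in the notebook cell."""
    return ["monai-deploy", "run", "-i", input_path, "-o", output_path, image_tag]


def run_map(image_tag: str) -> int:
    """Clear the output folder, then run the MAP; returns the process exit code."""
    input_path = os.environ["HOLOSCAN_INPUT_PATH"]
    output_path = os.environ["HOLOSCAN_OUTPUT_PATH"]
    shutil.rmtree(output_path, ignore_errors=True)  # same effect as `rm -rf`
    cmd = build_run_command(input_path, output_path, image_tag)
    print("Running:", shlex.join(cmd))
    return subprocess.run(cmd, check=False).returncode
```

Usage, on a machine with Docker and the packaged image: run_map("mednist_app-x64-workstation-dgpu-linux-amd64:1.0"). Passing the arguments as a list rather than one shell string avoids quoting problems when the paths contain spaces.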
!cat $HOLOSCAN_OUTPUT_PATH/output.json
"AbdomenCT"

Note: Please run the following cell once you have finished the exercise.

# Remove the downloaded data files if they were placed in a temporary folder
if directory is None:
    shutil.rmtree(root_dir)