Deploying a MedNIST Classifier App with MONAI Deploy App SDK (Prebuilt Model)

This tutorial demonstrates how to use the MONAI Deploy App SDK to package a trained model into an artifact that can be run as a local program performing inference, as a workflow job doing the same, and as a Docker containerized workflow execution.

In this tutorial, we will take a prebuilt trained model, implement and package the inference application, and execute the application locally.

Clone the GitHub project (the latest version of the main branch only)

!git clone --branch main --depth 1 https://github.com/Project-MONAI/monai-deploy-app-sdk.git source \
 && rm -rf source/.git
Cloning into 'source'...
remote: Enumerating objects: 212, done.
remote: Counting objects: 100% (212/212), done.
remote: Compressing objects: 100% (188/188), done.
remote: Total 212 (delta 33), reused 79 (delta 7), pack-reused 0
Receiving objects: 100% (212/212), 546.28 KiB | 3.50 MiB/s, done.
Resolving deltas: 100% (33/33), done.
!ls source/examples/apps/mednist_classifier_monaideploy/
mednist_classifier_monaideploy.py

Install monai-deploy-app-sdk package

!pip install --upgrade monai-deploy-app-sdk
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting monai-deploy-app-sdk
  Downloading monai_deploy_app_sdk-0.1.0rc2-py3-none-any.whl (113 kB)
     |████████████████████████████████| 113 kB 2.6 MB/s eta 0:00:01
Collecting colorama>=0.4.1
  Downloading colorama-0.4.4-py2.py3-none-any.whl (16 kB)
Collecting networkx>=2.4
  Downloading networkx-2.5.1-py3-none-any.whl (1.6 MB)
     |████████████████████████████████| 1.6 MB 12.8 MB/s eta 0:00:01
Collecting typeguard~=2.12.1
  Downloading typeguard-2.12.1-py3-none-any.whl (17 kB)
Collecting numpy>=1.17
  Downloading numpy-1.19.5-cp36-cp36m-manylinux2010_x86_64.whl (14.8 MB)
     |████████████████████████████████| 14.8 MB 9.5 MB/s eta 0:00:01
Collecting decorator<5,>=4.3
  Downloading decorator-4.4.2-py2.py3-none-any.whl (9.2 kB)
Installing collected packages: decorator, typeguard, numpy, networkx, colorama, monai-deploy-app-sdk
  Attempting uninstall: decorator
    Found existing installation: decorator 5.1.0
    Uninstalling decorator-5.1.0:
      Successfully uninstalled decorator-5.1.0
Successfully installed colorama-0.4.4 decorator-4.4.2 monai-deploy-app-sdk-0.1.0rc2 networkx-2.5.1 numpy-1.19.5 typeguard-2.12.1

Install necessary packages for the app

!pip install monai Pillow  # for MONAI transforms and Pillow
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting monai
  Downloading monai-0.6.0-202107081903-py3-none-any.whl (584 kB)
     |████████████████████████████████| 584 kB 2.7 MB/s eta 0:00:01
Collecting Pillow
  Downloading Pillow-8.3.2-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)
     |████████████████████████████████| 3.0 MB 23.2 MB/s eta 0:00:01
Requirement already satisfied: torch>=1.5 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from monai) (1.9.1)
Requirement already satisfied: numpy>=1.17 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from monai) (1.19.5)
Requirement already satisfied: typing_extensions in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from torch>=1.5->monai) (3.10.0.0)
Requirement already satisfied: dataclasses in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from torch>=1.5->monai) (0.8)
Installing collected packages: Pillow, monai
Successfully installed Pillow-8.3.2 monai-0.6.0

Download/Extract mednist_classifier_data.zip from Google Drive

# Download mednist_classifier_data.zip
!pip install gdown 
!gdown "https://drive.google.com/uc?id=1yJ4P-xMNEfN6lIOq_u6x1eMAq1_MJu-E"
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting gdown
  Downloading gdown-3.13.1.tar.gz (10 kB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
Collecting filelock
  Downloading filelock-3.0.12-py3-none-any.whl (7.6 kB)
Requirement already satisfied: requests[socks]>=2.12.0 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from gdown) (2.26.0)
Requirement already satisfied: six in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from gdown) (1.16.0)
Collecting tqdm
  Downloading tqdm-4.62.3-py2.py3-none-any.whl (76 kB)
     |████████████████████████████████| 76 kB 2.8 MB/s eta 0:00:01
Requirement already satisfied: certifi>=2017.4.17 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from requests[socks]>=2.12.0->gdown) (2021.5.30)
Requirement already satisfied: idna<4,>=2.5 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from requests[socks]>=2.12.0->gdown) (3.1)
Requirement already satisfied: charset-normalizer~=2.0.0 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from requests[socks]>=2.12.0->gdown) (2.0.0)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from requests[socks]>=2.12.0->gdown) (1.26.6)
Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from requests[socks]>=2.12.0->gdown) (1.7.1)
Building wheels for collected packages: gdown
  Building wheel for gdown (PEP 517) ... done
  Created wheel for gdown: filename=gdown-3.13.1-py3-none-any.whl size=9907 sha256=34f13d3a73d5f3f25f15dd69606e75b7d211bb9cc638bc47b82043612514d1f4
  Stored in directory: /tmp/pip-ephem-wheel-cache-xtmpuwlo/wheels/6b/ba/3b/57c8250cc9279fb303e8bfa589361cbc58a1afb291475c4ddc
Successfully built gdown
Installing collected packages: tqdm, filelock, gdown
Successfully installed filelock-3.0.12 gdown-3.13.1 tqdm-4.62.3
Downloading...
From: https://drive.google.com/uc?id=1yJ4P-xMNEfN6lIOq_u6x1eMAq1_MJu-E
To: /home/gbae/mednist_app/mednist_classifier_data.zip
28.6MB [00:02, 10.3MB/s]
# After downloading mednist_classifier_data.zip from the web browser or using gdown, extract it
!unzip -o "mednist_classifier_data.zip"
Archive:  mednist_classifier_data.zip
 extracting: classifier.zip          
 extracting: input/AbdomenCT_007000.jpeg  
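If the unzip command is not available, the same extraction can be done from Python. The following is a minimal sketch using only the standard library; the helper name extract_archive is our own and is not part of the SDK.

```python
import zipfile
from pathlib import Path


def extract_archive(zip_path, dest_dir="."):
    """Extract every member of zip_path into dest_dir and return the member names."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        # extractall() overwrites existing files, similar to `unzip -o`.
        archive.extractall(dest)
        return archive.namelist()


# Only extract if the archive has already been downloaded.
archive_path = Path("mednist_classifier_data.zip")
if archive_path.exists():
    for name in extract_archive(archive_path):
        print("extracted:", name)
```

For the archive used here, the extracted members should match the two entries listed in the unzip output above.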

Package app (creating MAP Docker image)

This assumes that NVIDIA Docker is installed on the local machine.

Please see https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker to install nvidia-docker2.

Use the -l DEBUG option to see progress.

!monai-deploy package "source/examples/apps/mednist_classifier_monaideploy/mednist_classifier_monaideploy.py" \
    --tag mednist_app:latest \
    --model classifier.zip
Building MONAI Application Package... Done
[2021-09-21 03:07:51,614] [INFO] (app_packager) - Successfully built mednist_app:latest

Run the app with the Docker image and an input file locally

!monai-deploy run mednist_app:latest "input" "output"
Checking dependencies...
--> Verifying if "docker" is installed...

--> Verifying if "mednist_app:latest" is available...

Checking for MAP "mednist_app:latest" locally
"mednist_app:latest" found.

Reading MONAI App Package manifest...
 > export '/var/run/monai/export/' detected
--> Verifying if "nvidia-docker" is installed...

Going to initiate execution of operator LoadPILOperator
Executing operator LoadPILOperator (Process ID: 1, Operator ID: 2000b9d2-156f-4abd-8654-cf60219673ac)
Done performing execution of operator LoadPILOperator

Going to initiate execution of operator MedNISTClassifierOperator
Executing operator MedNISTClassifierOperator (Process ID: 1, Operator ID: 13deb10c-dd13-4af5-8a05-a72c07406c05)
AbdomenCT
Done performing execution of operator MedNISTClassifierOperator

!cat output/output.json
"AbdomenCT"
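Instead of inspecting output.json with cat, a downstream script could read the result back in Python. This is a minimal sketch; the helper read_prediction is hypothetical, and the class list is copied from the application code shown later in this tutorial.

```python
import json
from pathlib import Path

# Class names as defined in the application code of this tutorial.
MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]


def read_prediction(output_dir="output"):
    """Load the class name written by the app and check that it is a known class."""
    result_path = Path(output_dir) / "output.json"
    with open(result_path) as fp:
        prediction = json.load(fp)  # the file holds a single JSON string
    if prediction not in MEDNIST_CLASSES:
        raise ValueError(f"unexpected prediction: {prediction!r}")
    return prediction


# Only read the result if the app has already produced it.
if (Path("output") / "output.json").exists():
    print(read_prediction("output"))
```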

Implementing and Packaging Application with MONAI Deploy App SDK

Based on the TorchScript model (classifier.zip), we will implement an application that processes an input JPEG image and writes the prediction (classification) result to a JSON file (output.json).

In our inference application, we will define two operators:

  1. LoadPILOperator - Loads a JPEG image from the input path and passes the loaded image object to the next operator.

    • This operator does a job similar to the LoadImage(image_only=True) transform in train_transforms, but handles only one image.

    • Input: a file path (DataPath)

    • Output: an image object in memory (Image)

  2. MedNISTClassifierOperator - Pre-transforms the given image using MONAI’s Compose class, feeds it to the TorchScript model (classifier.zip), and writes the prediction to a JSON file (output.json).

    • The pre-transforms consist of three transforms – AddChannel, ScaleIntensity, and EnsureType.

    • Input: an image object in memory (Image)

    • Output: a folder path to which the prediction result (output.json) is written (DataPath)
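To make the pre-transforms more concrete, here is a rough pure-Python sketch of what AddChannel and ScaleIntensity do to a small image. It is only an illustration on nested lists, not MONAI's implementation: the real transforms operate on arrays and tensors, the handling of a constant image is our own assumption, and EnsureType (the conversion to a tensor) is left out because the sketch uses no tensor library.

```python
def add_channel(image):
    """Add a leading channel axis to a 2-D image: (H, W) -> (1, H, W)."""
    return [image]


def scale_intensity(volume, minv=0.0, maxv=1.0):
    """Linearly rescale all values of a (C, H, W) nested list to [minv, maxv]."""
    values = [v for channel in volume for row in channel for v in row]
    lo, hi = min(values), max(values)
    if hi == lo:
        # Assumption for this sketch: a constant image maps to minv everywhere.
        return [[[minv for _ in row] for row in channel] for channel in volume]
    return [
        [[(v - lo) / (hi - lo) * (maxv - minv) + minv for v in row] for row in channel]
        for channel in volume
    ]


def pre_transform(image):
    """Apply the two sketched transforms in the same order as the Compose chain."""
    return scale_intensity(add_channel(image))


pixels = [[0, 64], [128, 255]]  # a tiny 2x2 stand-in for a 64x64 greyscale image
scaled = pre_transform(pixels)
print(len(scaled), len(scaled[0]), len(scaled[0][0]))  # 1 2 2
```

The smallest pixel value ends up at 0.0 and the largest at 1.0, which is the range the classifier input is scaled to before inference.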

The workflow of the application looks like this:

Workflow

Setup imports

Let’s import the necessary classes and decorators and define MEDNIST_CLASSES.

import monai.deploy.core as md
from monai.deploy.core import (
    Application,
    DataPath,
    ExecutionContext,
    Image,
    InputContext,
    IOType,
    Operator,
    OutputContext,
)
from monai.transforms import AddChannel, Compose, EnsureType, ScaleIntensity

MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]

Creating Operator classes

LoadPILOperator

@md.input("image", DataPath, IOType.DISK)
@md.output("image", Image, IOType.IN_MEMORY)
@md.env(pip_packages=["pillow"])
class LoadPILOperator(Operator):
    """Load image from the given input (DataPath) and set numpy array to the output (Image)."""

    def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext):
        import numpy as np
        from PIL import Image as PILImage

        input_path = op_input.get().path
        if input_path.is_dir():
            input_path = next(input_path.glob("*.*"))  # take the first file

        image = PILImage.open(input_path)
        image = image.convert("L")  # convert to greyscale image
        image_arr = np.asarray(image)

        output_image = Image(image_arr)  # create Image domain object with a numpy array
        op_output.set(output_image)
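Note that next(input_path.glob("*.*")) raises StopIteration when the folder is empty, and the order in which glob yields files is not guaranteed. A more defensive way to pick the input file could look like the sketch below; the helper first_file is hypothetical and is not part of the SDK or of the example app.

```python
from pathlib import Path


def first_file(input_path):
    """Return input_path if it is a file, else the alphabetically first file inside it."""
    path = Path(input_path)
    if path.is_file():
        return path
    # Sort so that the choice is deterministic across runs and platforms.
    candidates = sorted(p for p in path.glob("*.*") if p.is_file())
    if not candidates:
        raise FileNotFoundError(f"no input files found in {path}")
    return candidates[0]


# Only look up a file if the tutorial's input folder exists.
if Path("input").is_dir():
    print(first_file("input"))
```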

MedNISTClassifierOperator

@md.input("image", Image, IOType.IN_MEMORY)
@md.output("output", DataPath, IOType.DISK)
@md.env(pip_packages=["monai"])
class MedNISTClassifierOperator(Operator):
    """Classifies the given image and returns the class name."""

    @property
    def transform(self):
        return Compose([AddChannel(), ScaleIntensity(), EnsureType()])

    def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext):
        import json

        import torch

        img = op_input.get().asnumpy()  # (64, 64), uint8
        image_tensor = self.transform(img)  # (1, 64, 64), torch.float64
        image_tensor = image_tensor[None].float()  # (1, 1, 64, 64), torch.float32

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        image_tensor = image_tensor.to(device)

        model = context.models.get()  # get a TorchScriptModel object

        with torch.no_grad():
            outputs = model(image_tensor)

        _, output_classes = outputs.max(dim=1)

        result = MEDNIST_CLASSES[output_classes[0]]  # get the class name
        print(result)

        # Get the output (folder) path and create the folder if it does not exist
        output_folder = op_output.get().path
        output_folder.mkdir(parents=True, exist_ok=True)

        # Write result to "output.json"
        output_path = output_folder / "output.json"
        with open(output_path, "w") as fp:
            json.dump(result, fp)
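The step from model output to class name is an argmax over the class scores: outputs.max(dim=1) returns the maximum values and their indices, and the index selects an entry of MEDNIST_CLASSES. The same idea for a single row of scores, sketched in plain Python with made-up score values:

```python
MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]


def predicted_class(scores):
    """Return the class name whose score is the largest."""
    if len(scores) != len(MEDNIST_CLASSES):
        raise ValueError("expected one score per class")
    best_index = max(range(len(scores)), key=lambda i: scores[i])
    return MEDNIST_CLASSES[best_index]


# Hypothetical scores for one image; the second entry is the largest.
print(predicted_class([0.1, 2.3, -0.5, 0.0, 1.1, 0.7]))  # BreastMRI
```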

Creating Application class

Our application class looks like the following.

It defines an App class that inherits from the Application class.

LoadPILOperator is connected to MedNISTClassifierOperator by calling self.add_flow() in the compose() method of App.

@md.resource(cpu=1, gpu=1, memory="1Gi")
class App(Application):
    """Application class for the MedNIST classifier."""

    def compose(self):
        load_pil_op = LoadPILOperator()
        classifier_op = MedNISTClassifierOperator()

        self.add_flow(load_pil_op, classifier_op)
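Conceptually, add_flow() records that the output of one operator becomes the input of the next, and running the app executes the operators in that order. The toy classes below illustrate this idea for a single linear chain. MiniOperator and MiniApp are hypothetical names for this sketch only and do not reflect how the SDK is implemented.

```python
class MiniOperator:
    """Hypothetical stand-in for an operator: wraps a function of one argument."""

    def __init__(self, name, func):
        self.name = name
        self.func = func


class MiniApp:
    """Toy application that runs one linear chain of operators."""

    def __init__(self):
        self.flows = []  # list of (upstream, downstream) pairs

    def add_flow(self, upstream, downstream):
        self.flows.append((upstream, downstream))

    def run(self, value):
        if not self.flows:
            raise ValueError("no flows defined")
        next_op = {id(up): down for up, down in self.flows}
        downstream_ids = {id(down) for _, down in self.flows}
        # The starting operator is the one that is never a downstream.
        roots = [up for up, _ in self.flows if id(up) not in downstream_ids]
        if len(roots) != 1:
            raise ValueError("this sketch supports exactly one linear chain")
        op = roots[0]
        while op is not None:
            value = op.func(value)  # the output of one step feeds the next
            op = next_op.get(id(op))
        return value


load = MiniOperator("load", lambda path: f"pixels({path})")
classify = MiniOperator("classify", lambda image: f"label_for[{image}]")

app = MiniApp()
app.add_flow(load, classify)
result = app.run("input.jpeg")
print(result)  # label_for[pixels(input.jpeg)]
```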

Executing app locally

We can execute the app in the Jupyter notebook.

app = App()
app.run(input="input/AbdomenCT_007000.jpeg", output="output", model="classifier.zip")
Going to initiate execution of operator LoadPILOperator
Executing operator LoadPILOperator (Process ID: 7041, Operator ID: 3aa42bbd-f8dd-4374-98ee-7b614979e75a)
Done performing execution of operator LoadPILOperator

Going to initiate execution of operator MedNISTClassifierOperator
Executing operator MedNISTClassifierOperator (Process ID: 7041, Operator ID: 7ee7dd5e-c042-4245-bb75-15ff064bd838)
AbdomenCT
Done performing execution of operator MedNISTClassifierOperator

!cat output/output.json
"AbdomenCT"

Once the application is verified inside the Jupyter notebook, we can write the whole application to a file (mednist_classifier_monaideploy.py) by concatenating the code above and then adding the following lines:

if __name__ == "__main__":
    App(do_run=True)

These lines are needed to execute the application code with the Python interpreter.
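When the file is run as a script, the input, output, and model paths come from the command line through the -i, -o, and -m options. The SDK handles this parsing for us. Purely as an illustration of the idea, a standalone script could parse the same short options with argparse as sketched below; the long option names and the default values are assumptions of this sketch, not the SDK's documented behavior.

```python
import argparse


def parse_app_args(argv):
    """Parse -i/-o/-m style options from a list of argument strings."""
    parser = argparse.ArgumentParser(description="Toy argument parser for an inference app")
    # Long names and defaults below are hypothetical choices for this sketch.
    parser.add_argument("-i", "--input", default="input", help="input file or folder")
    parser.add_argument("-o", "--output", default="output", help="output folder")
    parser.add_argument("-m", "--model", default=None, help="path to the model file")
    return parser.parse_args(argv)


args = parse_app_args(["-i", "input/AbdomenCT_007000.jpeg", "-o", "output", "-m", "classifier.zip"])
print(args.input, args.output, args.model)
```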

%%writefile mednist_classifier_monaideploy.py

# Copyright 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import monai.deploy.core as md  # 'md' stands for MONAI Deploy (or can use 'core' instead)
from monai.deploy.core import (
    Application,
    DataPath,
    ExecutionContext,
    Image,
    InputContext,
    IOType,
    Operator,
    OutputContext,
)
from monai.transforms import AddChannel, Compose, EnsureType, ScaleIntensity

MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]


@md.input("image", DataPath, IOType.DISK)
@md.output("image", Image, IOType.IN_MEMORY)
@md.env(pip_packages=["pillow"])
class LoadPILOperator(Operator):
    """Load image from the given input (DataPath) and set numpy array to the output (Image)."""

    def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext):
        import numpy as np
        from PIL import Image as PILImage

        input_path = op_input.get().path
        if input_path.is_dir():
            input_path = next(input_path.glob("*.*"))  # take the first file

        image = PILImage.open(input_path)
        image = image.convert("L")  # convert to greyscale image
        image_arr = np.asarray(image)

        output_image = Image(image_arr)  # create Image domain object with a numpy array
        op_output.set(output_image)


@md.input("image", Image, IOType.IN_MEMORY)
@md.output("output", DataPath, IOType.DISK)
@md.env(pip_packages=["monai"])
class MedNISTClassifierOperator(Operator):
    """Classifies the given image and returns the class name."""

    @property
    def transform(self):
        return Compose([AddChannel(), ScaleIntensity(), EnsureType()])

    def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext):
        import json

        import torch

        img = op_input.get().asnumpy()  # (64, 64), uint8
        image_tensor = self.transform(img)  # (1, 64, 64), torch.float64
        image_tensor = image_tensor[None].float()  # (1, 1, 64, 64), torch.float32

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        image_tensor = image_tensor.to(device)

        model = context.models.get()  # get a TorchScriptModel object

        with torch.no_grad():
            outputs = model(image_tensor)

        _, output_classes = outputs.max(dim=1)

        result = MEDNIST_CLASSES[output_classes[0]]  # get the class name
        print(result)

        # Get the output (folder) path and create the folder if it does not exist
        output_folder = op_output.get().path
        output_folder.mkdir(parents=True, exist_ok=True)

        # Write result to "output.json"
        output_path = output_folder / "output.json"
        with open(output_path, "w") as fp:
            json.dump(result, fp)


@md.resource(cpu=1, gpu=1, memory="1Gi")
class App(Application):
    """Application class for the MedNIST classifier."""

    def compose(self):
        load_pil_op = LoadPILOperator()
        classifier_op = MedNISTClassifierOperator()

        self.add_flow(load_pil_op, classifier_op)


if __name__ == "__main__":
    App(do_run=True)
Writing mednist_classifier_monaideploy.py

This time, let’s execute the app from the command line.

!python "mednist_classifier_monaideploy.py" -i "input/AbdomenCT_007000.jpeg" -o output -m "classifier.zip"
Going to initiate execution of operator LoadPILOperator
Executing operator LoadPILOperator (Process ID: 8412, Operator ID: 631a82bf-c90e-4217-a17c-831b2c74bc50)
Done performing execution of operator LoadPILOperator

Going to initiate execution of operator MedNISTClassifierOperator
Executing operator MedNISTClassifierOperator (Process ID: 8412, Operator ID: a8fe1121-68bb-463f-bf1c-beff38d4fe86)
AbdomenCT
Done performing execution of operator MedNISTClassifierOperator

The above command is equivalent to the following command line:

!monai-deploy exec "mednist_classifier_monaideploy.py" -i "input/AbdomenCT_007000.jpeg" -o output -m "classifier.zip"
Going to initiate execution of operator LoadPILOperator
Executing operator LoadPILOperator (Process ID: 8453, Operator ID: 7dec2a01-6d18-4104-b250-5b93d663ba4f)
Done performing execution of operator LoadPILOperator

Going to initiate execution of operator MedNISTClassifierOperator
Executing operator MedNISTClassifierOperator (Process ID: 8453, Operator ID: 5e83dd80-5b19-4c78-9382-3d181640b80c)
AbdomenCT
Done performing execution of operator MedNISTClassifierOperator

!cat output/output.json
"AbdomenCT"