Deploying a MedNIST Classifier App with MONAI Deploy App SDK (Prebuilt Model)¶
This tutorial demonstrates how to package a trained model with the MONAI Deploy App SDK into an artifact that can be run as a local program performing inference, as a workflow job doing the same, and as a Docker-containerized workflow execution.
In this tutorial, we will use a trained model, implement and package the inference application, and execute the application locally.
Clone the GitHub project (the latest version of the main branch only)¶
!git clone --branch main --depth 1 https://github.com/Project-MONAI/monai-deploy-app-sdk.git source \
&& rm -rf source/.git
Cloning into 'source'...
remote: Enumerating objects: 212, done.
remote: Counting objects: 100% (212/212), done.
remote: Compressing objects: 100% (188/188), done.
remote: Total 212 (delta 33), reused 79 (delta 7), pack-reused 0
Receiving objects: 100% (212/212), 546.28 KiB | 3.50 MiB/s, done.
Resolving deltas: 100% (33/33), done.
!ls source/examples/apps/mednist_classifier_monaideploy/
mednist_classifier_monaideploy.py
Install monai-deploy-app-sdk package¶
!pip install --upgrade monai-deploy-app-sdk
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting monai-deploy-app-sdk
Downloading monai_deploy_app_sdk-0.1.0rc2-py3-none-any.whl (113 kB)
|████████████████████████████████| 113 kB 2.6 MB/s eta 0:00:01
Collecting colorama>=0.4.1
Downloading colorama-0.4.4-py2.py3-none-any.whl (16 kB)
Collecting networkx>=2.4
Downloading networkx-2.5.1-py3-none-any.whl (1.6 MB)
|████████████████████████████████| 1.6 MB 12.8 MB/s eta 0:00:01
Collecting typeguard~=2.12.1
Downloading typeguard-2.12.1-py3-none-any.whl (17 kB)
Collecting numpy>=1.17
Downloading numpy-1.19.5-cp36-cp36m-manylinux2010_x86_64.whl (14.8 MB)
|████████████████████████████████| 14.8 MB 9.5 MB/s eta 0:00:01
Collecting decorator<5,>=4.3
Downloading decorator-4.4.2-py2.py3-none-any.whl (9.2 kB)
Installing collected packages: decorator, typeguard, numpy, networkx, colorama, monai-deploy-app-sdk
Attempting uninstall: decorator
Found existing installation: decorator 5.1.0
Uninstalling decorator-5.1.0:
Successfully uninstalled decorator-5.1.0
Successfully installed colorama-0.4.4 decorator-4.4.2 monai-deploy-app-sdk-0.1.0rc2 networkx-2.5.1 numpy-1.19.5 typeguard-2.12.1
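Optionally, you can confirm that the monai-deploy CLI entry point is now on your PATH by printing its help text (a quick sanity check only; the list of subcommands may vary between SDK versions):
!monai-deploy --help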
Install necessary packages for the app¶
!pip install monai Pillow # for MONAI transforms and Pillow
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting monai
Downloading monai-0.6.0-202107081903-py3-none-any.whl (584 kB)
|████████████████████████████████| 584 kB 2.7 MB/s eta 0:00:01
Collecting Pillow
Downloading Pillow-8.3.2-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)
|████████████████████████████████| 3.0 MB 23.2 MB/s eta 0:00:01
Requirement already satisfied: torch>=1.5 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from monai) (1.9.1)
Requirement already satisfied: numpy>=1.17 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from monai) (1.19.5)
Requirement already satisfied: typing_extensions in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from torch>=1.5->monai) (3.10.0.0)
Requirement already satisfied: dataclasses in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from torch>=1.5->monai) (0.8)
Installing collected packages: Pillow, monai
Successfully installed Pillow-8.3.2 monai-0.6.0
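As an optional sanity check (not part of the original tutorial), verify that both packages import cleanly in the current environment:
import monai
import PIL

print(monai.__version__, PIL.__version__)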
Download/Extract mednist_classifier_data.zip from Google Drive¶
# Download mednist_classifier_data.zip
!pip install gdown
!gdown "https://drive.google.com/uc?id=1yJ4P-xMNEfN6lIOq_u6x1eMAq1_MJu-E"
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting gdown
Downloading gdown-3.13.1.tar.gz (10 kB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing wheel metadata ... done
Collecting filelock
Downloading filelock-3.0.12-py3-none-any.whl (7.6 kB)
Requirement already satisfied: requests[socks]>=2.12.0 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from gdown) (2.26.0)
Requirement already satisfied: six in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from gdown) (1.16.0)
Collecting tqdm
Downloading tqdm-4.62.3-py2.py3-none-any.whl (76 kB)
|████████████████████████████████| 76 kB 2.8 MB/s eta 0:00:01
Requirement already satisfied: certifi>=2017.4.17 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from requests[socks]>=2.12.0->gdown) (2021.5.30)
Requirement already satisfied: idna<4,>=2.5 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from requests[socks]>=2.12.0->gdown) (3.1)
Requirement already satisfied: charset-normalizer~=2.0.0 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from requests[socks]>=2.12.0->gdown) (2.0.0)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from requests[socks]>=2.12.0->gdown) (1.26.6)
Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from requests[socks]>=2.12.0->gdown) (1.7.1)
Building wheels for collected packages: gdown
  Building wheel for gdown (PEP 517) ... done
  Created wheel for gdown: filename=gdown-3.13.1-py3-none-any.whl size=9907 sha256=34f13d3a73d5f3f25f15dd69606e75b7d211bb9cc638bc47b82043612514d1f4
Stored in directory: /tmp/pip-ephem-wheel-cache-xtmpuwlo/wheels/6b/ba/3b/57c8250cc9279fb303e8bfa589361cbc58a1afb291475c4ddc
Successfully built gdown
Installing collected packages: tqdm, filelock, gdown
Successfully installed filelock-3.0.12 gdown-3.13.1 tqdm-4.62.3
Downloading...
From: https://drive.google.com/uc?id=1yJ4P-xMNEfN6lIOq_u6x1eMAq1_MJu-E
To: /home/gbae/mednist_app/mednist_classifier_data.zip
28.6MB [00:02, 10.3MB/s]
# After downloading mednist_classifier_data.zip (via a web browser or gdown), extract it:
!unzip -o "mednist_classifier_data.zip"
Archive: mednist_classifier_data.zip
extracting: classifier.zip
extracting: input/AbdomenCT_007000.jpeg
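If the unzip command is not available on your system, the archive can also be extracted with Python's standard zipfile module; a minimal equivalent sketch, assuming mednist_classifier_data.zip sits in the current working directory:
import zipfile

with zipfile.ZipFile("mednist_classifier_data.zip") as zf:
    print(zf.namelist())  # expect classifier.zip and input/AbdomenCT_007000.jpeg
    zf.extractall(".")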
Package app (creating MAP Docker image)¶
This assumes that NVIDIA Docker is installed on the local machine.
Please see https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker to install nvidia-docker2.
Use the -l DEBUG option to see the build progress.
!monai-deploy package "source/examples/apps/mednist_classifier_monaideploy/mednist_classifier_monaideploy.py" \
--tag mednist_app:latest \
--model classifier.zip
Building MONAI Application Package... Done
[2021-09-21 03:07:51,614] [INFO] (app_packager) - Successfully built mednist_app:latest
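To confirm that the MAP image was created, you can optionally list the matching local Docker images (this assumes the docker CLI is available on your PATH):
!docker image ls mednist_app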
Run the app with the Docker image and input file locally¶
!monai-deploy run mednist_app:latest "input" "output"
Checking dependencies...
--> Verifying if "docker" is installed...
--> Verifying if "mednist_app:latest" is available...
Checking for MAP "mednist_app:latest" locally
"mednist_app:latest" found.
Reading MONAI App Package manifest...
> export '/var/run/monai/export/' detected
--> Verifying if "nvidia-docker" is installed...
Going to initiate execution of operator LoadPILOperator
Executing operator LoadPILOperator (Process ID: 1, Operator ID: 2000b9d2-156f-4abd-8654-cf60219673ac)
Done performing execution of operator LoadPILOperator
Going to initiate execution of operator MedNISTClassifierOperator
Executing operator MedNISTClassifierOperator (Process ID: 1, Operator ID: 13deb10c-dd13-4af5-8a05-a72c07406c05)
AbdomenCT
Done performing execution of operator MedNISTClassifierOperator
!cat output/output.json
"AbdomenCT"
Implementing and Packaging Application with MONAI Deploy App SDK¶
Based on the TorchScript model (classifier.zip), we will implement an application that processes an input JPEG image and writes the prediction (classification) result as a JSON file (output.json).
In our inference application, we will define two operators:

1. LoadPILOperator - Loads a JPEG image from the input path and passes the loaded image object to the next operator.
2. MedNISTClassifierOperator - Pre-transforms the given image by using MONAI's Compose class, feeds it to the TorchScript model (classifier.zip), and writes the prediction into a JSON file (output.json).
The workflow of the application would look like this.

Setup imports¶
Let’s import the necessary classes/decorators and define MEDNIST_CLASSES.
import monai.deploy.core as md
from monai.deploy.core import (
    Application,
    DataPath,
    ExecutionContext,
    Image,
    InputContext,
    IOType,
    Operator,
    OutputContext,
)
from monai.transforms import AddChannel, Compose, EnsureType, ScaleIntensity
MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]
Creating Operator classes¶
LoadPILOperator¶
@md.input("image", DataPath, IOType.DISK)
@md.output("image", Image, IOType.IN_MEMORY)
@md.env(pip_packages=["pillow"])
class LoadPILOperator(Operator):
    """Load image from the given input (DataPath) and set numpy array to the output (Image)."""

    def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext):
        import numpy as np
        from PIL import Image as PILImage

        input_path = op_input.get().path
        if input_path.is_dir():
            input_path = next(input_path.glob("*.*"))  # take the first file

        image = PILImage.open(input_path)
        image = image.convert("L")  # convert to greyscale image
        image_arr = np.asarray(image)

        output_image = Image(image_arr)  # create Image domain object with a numpy array
        op_output.set(output_image)
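Outside the SDK, the loading logic of this operator amounts to a few Pillow/NumPy calls. Below is a minimal standalone sketch of the same steps, assuming the extracted input/AbdomenCT_007000.jpeg from the data archive:
import numpy as np
from PIL import Image as PILImage

# Open the JPEG, convert it to greyscale, and view it as a numpy array,
# mirroring what LoadPILOperator does inside compute().
image_arr = np.asarray(PILImage.open("input/AbdomenCT_007000.jpeg").convert("L"))
print(image_arr.shape, image_arr.dtype)  # MedNIST images are 64x64, uint8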
MedNISTClassifierOperator¶
@md.input("image", Image, IOType.IN_MEMORY)
@md.output("output", DataPath, IOType.DISK)
@md.env(pip_packages=["monai"])
class MedNISTClassifierOperator(Operator):
    """Classifies the given image and returns the class name."""

    @property
    def transform(self):
        return Compose([AddChannel(), ScaleIntensity(), EnsureType()])

    def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext):
        import json
        import torch

        img = op_input.get().asnumpy()  # (64, 64), uint8
        image_tensor = self.transform(img)  # (1, 64, 64), torch.float64
        image_tensor = image_tensor[None].float()  # (1, 1, 64, 64), torch.float32

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        image_tensor = image_tensor.to(device)

        model = context.models.get()  # get a TorchScriptModel object

        with torch.no_grad():
            outputs = model(image_tensor)

        _, output_classes = outputs.max(dim=1)

        result = MEDNIST_CLASSES[output_classes[0]]  # get the class name
        print(result)

        # Get output (folder) path and create the folder if not exists
        output_folder = op_output.get().path
        output_folder.mkdir(parents=True, exist_ok=True)

        # Write result to "output.json"
        output_path = output_folder / "output.json"
        with open(output_path, "w") as fp:
            json.dump(result, fp)
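For reference, the same inference steps can be sketched outside the SDK by loading classifier.zip directly, assuming it is a plain TorchScript archive loadable with torch.jit.load. This rough sketch reuses image_arr from the previous snippet and approximates ScaleIntensity with a simple min-max rescale instead of MONAI's transform:
import torch

# Load the TorchScript model on CPU (assumption: classifier.zip is a TorchScript archive).
model = torch.jit.load("classifier.zip", map_location="cpu").eval()

# (64, 64) uint8 -> (1, 1, 64, 64) float32, scaled to [0, 1].
tensor = torch.tensor(image_arr, dtype=torch.float32)[None, None]
tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())

with torch.no_grad():
    pred_idx = model(tensor).argmax(dim=1).item()
print(MEDNIST_CLASSES[pred_idx])  # e.g. 'AbdomenCT'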
Creating Application class¶
Our application class is shown below.
It defines an App class that inherits from the Application class.
LoadPILOperator is connected to MedNISTClassifierOperator by using self.add_flow() in the compose() method of App.
@md.resource(cpu=1, gpu=1, memory="1Gi")
class App(Application):
    """Application class for the MedNIST classifier."""

    def compose(self):
        load_pil_op = LoadPILOperator()
        classifier_op = MedNISTClassifierOperator()

        self.add_flow(load_pil_op, classifier_op)
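When output and input labels need to be matched explicitly (for example, when an operator exposes multiple ports), add_flow() in this version of the SDK also accepts an io_map dictionary; with the single image port used here, the default call above should be equivalent to the explicit form below (shown for illustration only, as it would appear inside compose()):
self.add_flow(load_pil_op, classifier_op, {"image": "image"})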
Executing app locally¶
We can execute the app in the Jupyter notebook.
app = App()
app.run(input="input/AbdomenCT_007000.jpeg", output="output", model="classifier.zip")
Going to initiate execution of operator LoadPILOperator
Executing operator LoadPILOperator (Process ID: 7041, Operator ID: 3aa42bbd-f8dd-4374-98ee-7b614979e75a)
Done performing execution of operator LoadPILOperator
Going to initiate execution of operator MedNISTClassifierOperator
Executing operator MedNISTClassifierOperator (Process ID: 7041, Operator ID: 7ee7dd5e-c042-4245-bb75-15ff064bd838)
AbdomenCT
Done performing execution of operator MedNISTClassifierOperator
!cat output/output.json
"AbdomenCT"
Once the application is verified inside the Jupyter notebook, we can write the whole application into a file (mednist_classifier_monaideploy.py) by concatenating the code above, and then add the following lines:
if __name__ == "__main__":
    App(do_run=True)
The above lines are needed to execute the application code with the python interpreter.
%%writefile mednist_classifier_monaideploy.py
# Copyright 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import monai.deploy.core as md  # 'md' stands for MONAI Deploy (or can use 'core' instead)
from monai.deploy.core import (
    Application,
    DataPath,
    ExecutionContext,
    Image,
    InputContext,
    IOType,
    Operator,
    OutputContext,
)
from monai.transforms import AddChannel, Compose, EnsureType, ScaleIntensity

MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]


@md.input("image", DataPath, IOType.DISK)
@md.output("image", Image, IOType.IN_MEMORY)
@md.env(pip_packages=["pillow"])
class LoadPILOperator(Operator):
    """Load image from the given input (DataPath) and set numpy array to the output (Image)."""

    def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext):
        import numpy as np
        from PIL import Image as PILImage

        input_path = op_input.get().path
        if input_path.is_dir():
            input_path = next(input_path.glob("*.*"))  # take the first file

        image = PILImage.open(input_path)
        image = image.convert("L")  # convert to greyscale image
        image_arr = np.asarray(image)

        output_image = Image(image_arr)  # create Image domain object with a numpy array
        op_output.set(output_image)


@md.input("image", Image, IOType.IN_MEMORY)
@md.output("output", DataPath, IOType.DISK)
@md.env(pip_packages=["monai"])
class MedNISTClassifierOperator(Operator):
    """Classifies the given image and returns the class name."""

    @property
    def transform(self):
        return Compose([AddChannel(), ScaleIntensity(), EnsureType()])

    def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext):
        import json
        import torch

        img = op_input.get().asnumpy()  # (64, 64), uint8
        image_tensor = self.transform(img)  # (1, 64, 64), torch.float64
        image_tensor = image_tensor[None].float()  # (1, 1, 64, 64), torch.float32

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        image_tensor = image_tensor.to(device)

        model = context.models.get()  # get a TorchScriptModel object

        with torch.no_grad():
            outputs = model(image_tensor)

        _, output_classes = outputs.max(dim=1)

        result = MEDNIST_CLASSES[output_classes[0]]  # get the class name
        print(result)

        # Get output (folder) path and create the folder if not exists
        output_folder = op_output.get().path
        output_folder.mkdir(parents=True, exist_ok=True)

        # Write result to "output.json"
        output_path = output_folder / "output.json"
        with open(output_path, "w") as fp:
            json.dump(result, fp)


@md.resource(cpu=1, gpu=1, memory="1Gi")
class App(Application):
    """Application class for the MedNIST classifier."""

    def compose(self):
        load_pil_op = LoadPILOperator()
        classifier_op = MedNISTClassifierOperator()

        self.add_flow(load_pil_op, classifier_op)


if __name__ == "__main__":
    App(do_run=True)
Writing mednist_classifier_monaideploy.py
This time, let’s execute the app from the command line.
!python "mednist_classifier_monaideploy.py" -i "input/AbdomenCT_007000.jpeg" -o output -m "classifier.zip"
Going to initiate execution of operator LoadPILOperator
Executing operator LoadPILOperator (Process ID: 8412, Operator ID: 631a82bf-c90e-4217-a17c-831b2c74bc50)
Done performing execution of operator LoadPILOperator
Going to initiate execution of operator MedNISTClassifierOperator
Executing operator MedNISTClassifierOperator (Process ID: 8412, Operator ID: a8fe1121-68bb-463f-bf1c-beff38d4fe86)
AbdomenCT
Done performing execution of operator MedNISTClassifierOperator
The above command is equivalent to the following command line:
!monai-deploy exec "mednist_classifier_monaideploy.py" -i "input/AbdomenCT_007000.jpeg" -o output -m "classifier.zip"
Going to initiate execution of operator LoadPILOperator
Executing operator LoadPILOperator (Process ID: 8453, Operator ID: 7dec2a01-6d18-4104-b250-5b93d663ba4f)
Done performing execution of operator LoadPILOperator
Going to initiate execution of operator MedNISTClassifierOperator
Executing operator MedNISTClassifierOperator (Process ID: 8453, Operator ID: 5e83dd80-5b19-4c78-9382-3d181640b80c)
AbdomenCT
Done performing execution of operator MedNISTClassifierOperator
!cat output/output.json
"AbdomenCT"