Deploying a MedNIST Classifier App with MONAI Deploy App SDK
This tutorial demonstrates how to use the MONAI Deploy App SDK to package a trained model into an artifact that can run as a local program performing inference, as a workflow job doing the same, or as a Docker-containerized workflow execution.
We will train a MedNIST classifier as in the MONAI MedNIST tutorial, then implement and package the inference application and execute it locally.
Train a MedNIST classifier model with MONAI Core
Setup environment
# Install necessary packages for MONAI Core
!python -c "import monai" || pip install -q "monai[pillow, tqdm]"
!python -c "import ignite" || pip install -q "monai[ignite]"
!python -c "import gdown" || pip install -q "monai[gdown]"
# Install MONAI Deploy App SDK package
!python -c "import monai.deploy" || pip install -q "monai-deploy-app-sdk"
Setup imports
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import tempfile
import glob
import PIL.Image
import torch
import numpy as np
from ignite.engine import Events
from monai.apps import download_and_extract
from monai.config import print_config
from monai.networks.nets import DenseNet121
from monai.engines import SupervisedTrainer
from monai.transforms import (
    AddChannel,
    Compose,
    LoadImage,
    RandFlip,
    RandRotate,
    RandZoom,
    ScaleIntensity,
    EnsureType,
)
from monai.utils import set_determinism
set_determinism(seed=0)
print_config()
MONAI version: 0.6.0
Numpy version: 1.19.5
Pytorch version: 1.9.0
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: 0ad9e73639e30f4f1af5a1f4a45da9cb09930179
Optional dependencies:
Pytorch Ignite version: 0.4.5
Nibabel version: 3.2.1
scikit-image version: 0.17.2
Pillow version: 8.3.1
Tensorboard version: 2.6.0
gdown version: 3.13.0
TorchVision version: 0.10.0
ITK version: 5.2.0
tqdm version: 4.62.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.8.0
pandas version: 1.1.5
einops version: 0.3.2
For details about installing the optional dependencies, please visit:
https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
Download dataset
The MedNIST dataset was gathered from several sets from TCIA, the RSNA Bone Age Challenge (https://www.rsna.org/education/ai-resources-and-training/ai-image-challenge/rsna-pediatric-bone-age-challenge-2017), and the NIH Chest X-ray dataset.
The dataset is kindly made available by Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic) under the Creative Commons CC BY-SA 4.0 license.
If you use the MedNIST dataset, please acknowledge the source.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
resource = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE"
md5 = "0bc7306e7427e00ad1c5526a6677552d"
compressed_file = os.path.join(root_dir, "MedNIST.tar.gz")
data_dir = os.path.join(root_dir, "MedNIST")
if not os.path.exists(data_dir):
    download_and_extract(resource, compressed_file, root_dir, md5)
/tmp/tmpgh08b1ks
Downloading...
From: https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE
To: /tmp/tmpthbz6o8r/MedNIST.tar.gz
61.8MB [00:05, 10.7MB/s]
Downloaded: /tmp/tmpgh08b1ks/MedNIST.tar.gz
Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
Writing into directory: /tmp/tmpgh08b1ks.
subdirs = sorted(glob.glob(f"{data_dir}/*/"))
class_names = [os.path.basename(sd[:-1]) for sd in subdirs]
image_files = [glob.glob(f"{sb}/*") for sb in subdirs]
image_files_list = sum(image_files, [])
image_class = sum(([i] * len(f) for i, f in enumerate(image_files)), [])
image_width, image_height = PIL.Image.open(image_files_list[0]).size
print(f"Label names: {class_names}")
print(f"Label counts: {list(map(len, image_files))}")
print(f"Total image count: {len(image_class)}")
print(f"Image dimensions: {image_width} x {image_height}")
Label names: ['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT']
Label counts: [10000, 8954, 10000, 10000, 10000, 10000]
Total image count: 58954
Image dimensions: 64 x 64
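The listing code above derives integer labels from the folder layout: sub-folders are sorted by name, and every file inherits the index of its folder. A minimal sketch of the same idea on a throwaway directory follows. The helper name scan_class_folders and the tiny two-class layout are hypothetical and stand in for the real MedNIST download.

```python
import tempfile
from pathlib import Path


def scan_class_folders(data_dir):
    """Return (class_names, files, labels) for a folder-per-class dataset."""
    subdirs = sorted(p for p in Path(data_dir).iterdir() if p.is_dir())
    class_names = [p.name for p in subdirs]
    files, labels = [], []
    for label, subdir in enumerate(subdirs):
        for image_file in sorted(subdir.iterdir()):
            files.append(image_file)
            labels.append(label)  # the label is the index of the sorted folder
    return class_names, files, labels


# Build a tiny stand-in dataset (hypothetical class sizes, empty files).
with tempfile.TemporaryDirectory() as tmp:
    for name, count in [("Hand", 2), ("CXR", 3)]:
        (Path(tmp) / name).mkdir()
        for i in range(count):
            (Path(tmp) / name / f"{i:06d}.jpeg").touch()
    class_names, files, labels = scan_class_folders(tmp)

print(class_names)  # ['CXR', 'Hand']
print(labels)       # [0, 0, 0, 1, 1]
```

Because the label is only a position in the sorted folder list, the class-name list used at inference time has to be in exactly the same order as the one seen during training.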
Setup and train
Here we'll create a transform sequence and train the network. Validation and testing are omitted because they are not needed for this deployment walkthrough:
train_transforms = Compose(
    [
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
        RandFlip(spatial_axis=0, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        EnsureType(),
    ]
)
class MedNISTDataset(torch.utils.data.Dataset):
    def __init__(self, image_files, labels, transforms):
        self.image_files = image_files
        self.labels = labels
        self.transforms = transforms
    def __len__(self):
        return len(self.image_files)
    def __getitem__(self, index):
        return self.transforms(self.image_files[index]), self.labels[index]
# just one dataset and loader, we won't bother with validation or testing
train_ds = MedNISTDataset(image_files_list, image_class, train_transforms)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10)
device = torch.device("cuda:0")
net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=len(class_names)).to(device)
loss_function = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(net.parameters(), 1e-5)
max_epochs = 5
def _prepare_batch(batch, device, non_blocking):
    return tuple(b.to(device) for b in batch)
trainer = SupervisedTrainer(device, max_epochs, train_loader, net, opt, loss_function, prepare_batch=_prepare_batch)
@trainer.on(Events.EPOCH_COMPLETED)
def _print_loss(engine):
    print(f"Epoch {engine.state.epoch}/{engine.state.max_epochs} Loss: {engine.state.output[0]['loss']}")
trainer.run()
Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1623448272031/work/c10/core/TensorImpl.h:1156.)
Epoch 1/5 Loss: 0.1811893731355667
Epoch 2/5 Loss: 0.08026652783155441
Epoch 3/5 Loss: 0.05008228123188019
Epoch 4/5 Loss: 0.01724996417760849
Epoch 5/5 Loss: 0.029151903465390205
The network is saved here as a TorchScript object named classifier.zip:
torch.jit.script(net).save("classifier.zip")
Implementing and Packaging Application with MONAI Deploy App SDK
Based on the TorchScript model (classifier.zip), we will implement an application that processes an input JPEG image and writes the prediction (classification) result to a JSON file (output.json).
Creating Operators and connecting them in Application class
We used the following transforms as pre-transforms during training:
train_transforms = Compose(
    [
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
        RandFlip(spatial_axis=0, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        EnsureType(),
    ]
)
The RandRotate, RandFlip, and RandZoom transforms are used only for training and are not necessary during inference.
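What remains at inference time is the deterministic part of the pipeline: adding a channel dimension and rescaling intensities. The framework-free sketch below illustrates those two steps on a 2 x 2 image. The helpers add_channel and scale_intensity are hypothetical stand-ins rather than MONAI's implementations, and the handling of a constant image is an assumption of this sketch.

```python
def add_channel(image):
    """Wrap an H x W nested list as 1 x H x W (channel first)."""
    return [image]


def scale_intensity(image, minv=0.0, maxv=1.0):
    """Min-max rescale every value of a C x H x W nested list into [minv, maxv]."""
    flat = [v for channel in image for row in channel for v in row]
    lo, hi = min(flat), max(flat)
    if hi == lo:
        # Constant image: nothing to stretch (this sketch maps everything to minv).
        return [[[minv for _ in row] for row in channel] for channel in image]
    return [
        [[(v - lo) / (hi - lo) * (maxv - minv) + minv for v in row] for row in channel]
        for channel in image
    ]


pixels = [[0, 51], [102, 255]]  # a 2 x 2 greyscale image with uint8-style values
tensor = scale_intensity(add_channel(pixels))
print(tensor)  # [[[0.0, 0.2], [0.4, 1.0]]]
```

Since neither step involves randomness, the same input image always produces the same model input, which is what an inference application needs.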
In our inference application, we will define two operators:
- LoadPILOperator: Load a JPEG image from the input path and pass the loaded image object to the next operator.
- MedNISTClassifierOperator: Pre-transform the given image using MONAI's Compose class, feed it to the TorchScript model (classifier.zip), and write the prediction into a JSON file (output.json).
The workflow of the application is a two-step chain in which LoadPILOperator passes the loaded image to MedNISTClassifierOperator.
Setup imports
Let's import the necessary classes/decorators and define MEDNIST_CLASSES.
import monai.deploy.core as md # 'md' stands for MONAI Deploy (or can use 'core' instead)
from monai.deploy.core import (
    Application,
    DataPath,
    ExecutionContext,
    Image,
    InputContext,
    IOType,
    Operator,
    OutputContext,
)
from monai.transforms import AddChannel, Compose, EnsureType, ScaleIntensity
MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]
Creating Operator classes
LoadPILOperator
@md.input("image", DataPath, IOType.DISK)
@md.output("image", Image, IOType.IN_MEMORY)
@md.env(pip_packages=["pillow"])
class LoadPILOperator(Operator):
    """Load image from the given input (DataPath) and set numpy array to the output (Image)."""
    def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext):
        import numpy as np
        from PIL import Image as PILImage
        input_path = op_input.get().path
        if input_path.is_dir():
            input_path = next(input_path.glob("*.*"))  # take the first file
        image = PILImage.open(input_path)
        image = image.convert("L")  # convert to greyscale image
        image_arr = np.asarray(image)
        output_image = Image(image_arr)  # create Image domain object with a numpy array
        op_output.set(output_image)
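The operator accepts either a file or a folder, and for a folder it takes the first match of glob("*.*"). Glob results are not guaranteed to come back in any particular order, so with several files in the folder the choice may vary between systems. The sketch below shows one possible way to make the choice deterministic by sorting. The helper name resolve_input_file is hypothetical and not part of the SDK.

```python
import tempfile
from pathlib import Path


def resolve_input_file(input_path):
    """Return input_path if it is a file, else the first file (by name) inside it."""
    input_path = Path(input_path)
    if input_path.is_dir():
        candidates = sorted(p for p in input_path.glob("*.*") if p.is_file())
        if not candidates:
            raise FileNotFoundError(f"No input file found in {input_path}")
        return candidates[0]
    return input_path


with tempfile.TemporaryDirectory() as tmp:
    for name in ["b.jpeg", "a.jpeg"]:
        (Path(tmp) / name).touch()
    picked_from_dir = resolve_input_file(tmp).name
    picked_direct = resolve_input_file(Path(tmp) / "b.jpeg").name

print(picked_from_dir)  # a.jpeg
print(picked_direct)    # b.jpeg
```

Raising an explicit error for an empty folder is a design choice of this sketch. The original next(...) call would raise StopIteration in that case.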
MedNISTClassifierOperator
@md.input("image", Image, IOType.IN_MEMORY)
@md.output("output", DataPath, IOType.DISK)
@md.env(pip_packages=["monai"])
class MedNISTClassifierOperator(Operator):
    """Classifies the given image and returns the class name."""
    @property
    def transform(self):
        return Compose([AddChannel(), ScaleIntensity(), EnsureType()])
    def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext):
        import json
        import torch
        img = op_input.get().asnumpy()  # (64, 64), uint8
        image_tensor = self.transform(img)  # (1, 64, 64), torch.float64
        image_tensor = image_tensor[None].float()  # (1, 1, 64, 64), torch.float32
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        image_tensor = image_tensor.to(device)
        model = context.models.get()  # get a TorchScriptModel object
        with torch.no_grad():
            outputs = model(image_tensor)
        _, output_classes = outputs.max(dim=1)
        result = MEDNIST_CLASSES[output_classes[0]]  # get the class name
        print(result)
        # Get output (folder) path and create the folder if not exists
        output_folder = op_output.get().path
        output_folder.mkdir(parents=True, exist_ok=True)
        # Write result to "output.json"
        output_path = output_folder / "output.json"
        with open(output_path, "w") as fp:
            json.dump(result, fp)
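The last part of compute() reduces the model output to a class name and serializes it. A minimal sketch of that step without PyTorch follows: it takes the index of the largest score, looks up the name, and writes it with json.dump. The score values are made up for illustration, and the helper class_from_scores is hypothetical.

```python
import json
import tempfile
from pathlib import Path

MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]


def class_from_scores(scores):
    """Return the class name whose score is largest (argmax over one row)."""
    index = max(range(len(scores)), key=scores.__getitem__)
    return MEDNIST_CLASSES[index]


scores = [4.0, 0.5, -1.0, 0.0, 1.5, -2.0]  # made-up model output for one image
result = class_from_scores(scores)

with tempfile.TemporaryDirectory() as tmp:
    output_folder = Path(tmp) / "output"
    output_folder.mkdir(parents=True, exist_ok=True)
    output_path = output_folder / "output.json"
    with open(output_path, "w") as fp:
        json.dump(result, fp)  # a bare string is a valid JSON document
    raw_text = output_path.read_text()
    loaded = json.loads(raw_text)

print(result)    # AbdomenCT
print(raw_text)  # "AbdomenCT"
print(loaded)    # AbdomenCT
```

Because the result is dumped as a plain string, the file holds the class name wrapped in double quotes, and any consumer should parse it as JSON rather than read it as raw text.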
Creating Application class
Our application class looks like the following.
It defines an App class that inherits from the Application class.
LoadPILOperator is connected to MedNISTClassifierOperator by calling self.add_flow() in the compose() method of App.
@md.resource(cpu=1, gpu=1, memory="1Gi")
class App(Application):
    """Application class for the MedNIST classifier."""
    def compose(self):
        load_pil_op = LoadPILOperator()
        classifier_op = MedNISTClassifierOperator()
        self.add_flow(load_pil_op, classifier_op)
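To illustrate the idea behind add_flow(), the sketch below records upstream/downstream pairs and then pushes a value through them in order. MiniApp, load_step, and classify_step are hypothetical names invented for this illustration. The class is a toy that only handles one linear chain and says nothing about how the SDK's Application is actually implemented.

```python
class MiniApp:
    """Hypothetical stand-in for an application graph; not the SDK's Application."""

    def __init__(self):
        self.flows = []  # (upstream, downstream) pairs in insertion order

    def add_flow(self, upstream, downstream):
        self.flows.append((upstream, downstream))

    def run(self, value):
        # Only a single linear chain is supported: run the first upstream step,
        # then feed each result to the next downstream step.
        if not self.flows:
            return value
        value = self.flows[0][0](value)
        for _, downstream in self.flows:
            value = downstream(value)
        return value


def load_step(path):
    """Pretend to load an image; returns a tiny fixed pixel grid."""
    return {"source": path, "pixels": [[0, 255]]}


def classify_step(image):
    """Toy classifier: 'bright' if any pixel exceeds 127, else 'dark'."""
    return "bright" if max(max(row) for row in image["pixels"]) > 127 else "dark"


app = MiniApp()
app.add_flow(load_step, classify_step)
result = app.run("input/example.jpeg")
print(result)  # bright
```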
Executing app locally
Let’s find a test input file path to use.
test_input_path = image_files[0][0]
print(f"Test input file path: {test_input_path}")
Test input file path: /tmp/tmpgh08b1ks/MedNIST/AbdomenCT/007000.jpeg
We can execute the app in the Jupyter notebook.
app = App()
app.run(input=test_input_path, output="output", model="classifier.zip")
Going to initiate execution of operator LoadPILOperator
Executing operator LoadPILOperator (Process ID: 14835, Operator ID: dd5dee72-9764-458a-9719-dc89f3cd14ea)
Done performing execution of operator LoadPILOperator
Going to initiate execution of operator MedNISTClassifierOperator
Executing operator MedNISTClassifierOperator (Process ID: 14835, Operator ID: 9b032f84-6a73-4f59-9c56-d04efed5bdb5)
AbdomenCT
Done performing execution of operator MedNISTClassifierOperator
!cat output/output.json
"AbdomenCT"
Once the application is verified inside the Jupyter notebook, we can write the whole application to a file (mednist_classifier_monaideploy.py) by concatenating the code above and then adding the following lines:
if __name__ == "__main__":
    App(do_run=True)
These lines are needed to execute the application code with the python interpreter.
%%writefile mednist_classifier_monaideploy.py
# Copyright 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import monai.deploy.core as md # 'md' stands for MONAI Deploy (or can use 'core' instead)
from monai.deploy.core import (
    Application,
    DataPath,
    ExecutionContext,
    Image,
    InputContext,
    IOType,
    Operator,
    OutputContext,
)
from monai.transforms import AddChannel, Compose, EnsureType, ScaleIntensity
MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]
@md.input("image", DataPath, IOType.DISK)
@md.output("image", Image, IOType.IN_MEMORY)
@md.env(pip_packages=["pillow"])
class LoadPILOperator(Operator):
    """Load image from the given input (DataPath) and set numpy array to the output (Image)."""
    def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext):
        import numpy as np
        from PIL import Image as PILImage
        input_path = op_input.get().path
        if input_path.is_dir():
            input_path = next(input_path.glob("*.*"))  # take the first file
        image = PILImage.open(input_path)
        image = image.convert("L")  # convert to greyscale image
        image_arr = np.asarray(image)
        output_image = Image(image_arr)  # create Image domain object with a numpy array
        op_output.set(output_image)
@md.input("image", Image, IOType.IN_MEMORY)
@md.output("output", DataPath, IOType.DISK)
@md.env(pip_packages=["monai"])
class MedNISTClassifierOperator(Operator):
    """Classifies the given image and returns the class name."""
    @property
    def transform(self):
        return Compose([AddChannel(), ScaleIntensity(), EnsureType()])
    def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext):
        import json
        import torch
        img = op_input.get().asnumpy()  # (64, 64), uint8
        image_tensor = self.transform(img)  # (1, 64, 64), torch.float64
        image_tensor = image_tensor[None].float()  # (1, 1, 64, 64), torch.float32
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        image_tensor = image_tensor.to(device)
        model = context.models.get()  # get a TorchScriptModel object
        with torch.no_grad():
            outputs = model(image_tensor)
        _, output_classes = outputs.max(dim=1)
        result = MEDNIST_CLASSES[output_classes[0]]  # get the class name
        print(result)
        # Get output (folder) path and create the folder if not exists
        output_folder = op_output.get().path
        output_folder.mkdir(parents=True, exist_ok=True)
        # Write result to "output.json"
        output_path = output_folder / "output.json"
        with open(output_path, "w") as fp:
            json.dump(result, fp)
@md.resource(cpu=1, gpu=1, memory="1Gi")
class App(Application):
    """Application class for the MedNIST classifier."""
    def compose(self):
        load_pil_op = LoadPILOperator()
        classifier_op = MedNISTClassifierOperator()
        self.add_flow(load_pil_op, classifier_op)
if __name__ == "__main__":
    App(do_run=True)
Writing mednist_classifier_monaideploy.py
This time, let's execute the app from the command line.
!python mednist_classifier_monaideploy.py -i {test_input_path} -o output -m classifier.zip
Going to initiate execution of operator LoadPILOperator
Executing operator LoadPILOperator (Process ID: 18193, Operator ID: de9a33aa-0abb-4e64-88af-90b27617ff63)
Done performing execution of operator LoadPILOperator
Going to initiate execution of operator MedNISTClassifierOperator
Executing operator MedNISTClassifierOperator (Process ID: 18193, Operator ID: 73bfa497-459c-4ef3-998a-8d162be57687)
Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1623448272031/work/c10/core/TensorImpl.h:1156.)
AbdomenCT
Done performing execution of operator MedNISTClassifierOperator
The above command is equivalent to the following command line:
!monai-deploy exec mednist_classifier_monaideploy.py -i {test_input_path} -o output -m classifier.zip
Going to initiate execution of operator LoadPILOperator
Executing operator LoadPILOperator (Process ID: 18328, Operator ID: 70e92517-e6ad-4d0a-aaff-2141c672d587)
Done performing execution of operator LoadPILOperator
Going to initiate execution of operator MedNISTClassifierOperator
Executing operator MedNISTClassifierOperator (Process ID: 18328, Operator ID: a9a7fc21-b180-4981-b775-ea8736e805a2)
Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1623448272031/work/c10/core/TensorImpl.h:1156.)
AbdomenCT
Done performing execution of operator MedNISTClassifierOperator
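Both invocations pass the same three values: the input path (-i), the output folder (-o), and the model file (-m), which correspond to the input, output, and model arguments given to app.run() in the notebook. The sketch below is a hypothetical argparse stand-in, not the SDK's own parser, showing how such short flags could be turned into those keyword arguments. The dest names and help strings are assumptions of this sketch.

```python
import argparse


def parse_args(argv):
    """Parse -i/-o/-m into a namespace with input, output, and model attributes."""
    parser = argparse.ArgumentParser(description="Hypothetical stand-in parser")
    parser.add_argument("-i", dest="input", required=True, help="input file or folder")
    parser.add_argument("-o", dest="output", required=True, help="output folder")
    parser.add_argument("-m", dest="model", required=True, help="model file")
    return parser.parse_args(argv)


args = parse_args(["-i", "input/007000.jpeg", "-o", "output", "-m", "classifier.zip"])
run_kwargs = vars(args)  # a dict that could be passed on as keyword arguments
print(run_kwargs["input"], run_kwargs["output"], run_kwargs["model"])
```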
!cat output/output.json
"AbdomenCT"
Packaging app
Let's package the app with the MONAI Application Packager.
!monai-deploy package mednist_classifier_monaideploy.py --tag mednist_app:latest --model classifier.zip # -l DEBUG
Building MONAI Application Package... Done
[2021-09-20 17:01:24,898] [INFO] (app_packager) - Successfully built mednist_app:latest
Note
Building a MONAI Application Package (Docker image) can take time. Use the -l DEBUG option if you want to see the progress.
We can see that the Docker image has been created.
!docker image ls | grep mednist_app
mednist_app latest 8c78cc6e0966 3 seconds ago 15.3GB
Executing packaged app locally
The packaged app can be run locally through the MONAI Application Runner.
# Copy a test input file to 'input' folder
!mkdir -p input && rm -rf input/*
!cp {test_input_path} input/
# Launch the app
!monai-deploy run mednist_app:latest input output
Checking dependencies...
--> Verifying if "docker" is installed...
--> Verifying if "mednist_app:latest" is available...
Checking for MAP "mednist_app:latest" locally
"mednist_app:latest" found.
Reading MONAI App Package manifest...
> export '/var/run/monai/export/' detected
--> Verifying if "nvidia-docker" is installed...
Going to initiate execution of operator LoadPILOperator
Executing operator LoadPILOperator (Process ID: 1, Operator ID: 7bb4824c-ebc7-4801-a0c3-1c5525b132cf)
Done performing execution of operator LoadPILOperator
Going to initiate execution of operator MedNISTClassifierOperator
Executing operator MedNISTClassifierOperator (Process ID: 1, Operator ID: d27f4a05-e557-49c3-8adf-08f83a860d14)
AbdomenCT
Done performing execution of operator MedNISTClassifierOperator
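The runner takes an input folder rather than a single file, which is why the test image was first copied into a freshly emptied input folder. A rough Python equivalent of that staging step is sketched below. The helper name stage_input is hypothetical, and unlike `rm -rf input/*` it also removes hidden files because it recreates the folder from scratch.

```python
import shutil
import tempfile
from pathlib import Path


def stage_input(src_file, input_dir):
    """Empty input_dir (creating it if needed) and copy src_file into it."""
    input_dir = Path(input_dir)
    if input_dir.exists():
        shutil.rmtree(input_dir)
    input_dir.mkdir(parents=True)
    return Path(shutil.copy(src_file, input_dir))


with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "007000.jpeg"
    src.write_bytes(b"not a real image")  # placeholder content for the demo
    input_dir = Path(tmp) / "input"
    input_dir.mkdir()
    (input_dir / "stale.jpeg").touch()  # left over from a previous run
    staged = stage_input(src, input_dir)
    staged_names = sorted(p.name for p in input_dir.iterdir())

print(staged_names)  # ['007000.jpeg']
```

Starting from an empty folder matters here because LoadPILOperator picks only one file from the folder it is given.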
!cat output/output.json
"AbdomenCT"
Note: Please execute the following script once the exercise is done.
# Remove the data files if they are in a temporary folder
if directory is None:
    shutil.rmtree(root_dir)
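The cleanup only runs when MONAI_DATA_DIRECTORY was unset, so a user-supplied data folder is never deleted. One possible way to keep that decision next to the directory creation is sketched below. The helper data_root and the environment variable name used in the demo call are hypothetical.

```python
import os
import tempfile
from pathlib import Path


def data_root(env_var="MONAI_DATA_DIRECTORY"):
    """Return (root_dir, cleanup); cleanup removes root_dir only if it was temporary."""
    directory = os.environ.get(env_var)
    if directory is not None:
        return Path(directory), lambda: None  # user-owned folder: never delete it
    tmp = tempfile.TemporaryDirectory()
    return Path(tmp.name), tmp.cleanup


# Use a variable name that is assumed to be unset so the temporary branch runs.
root_dir, cleanup = data_root("MEDNIST_DEMO_UNSET_VARIABLE")
existed_before = root_dir.is_dir()
cleanup()
exists_after = root_dir.exists()
print(existed_before, exists_after)  # True False
```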