# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import os
import time
from abc import abstractmethod
from typing import Dict
import torch
from monailabel.interfaces.exception import MONAILabelError, MONAILabelException
from monailabel.interfaces.utils.transform import run_transforms
from monailabel.transform.writer import Writer
logger = logging.getLogger(__name__)
[docs]class InferType:
"""
Type of Inference Model
Attributes:
SEGMENTATION - Segmentation Model
CLASSIFICATION - Classification Model
DEEPGROW - Deepgrow Interactive Model
DEEPEDIT - DeepEdit Interactive Model
SCRIBBLES - Scribbles Model
OTHERS - Other Model Type
"""
SEGMENTATION = "segmentation"
CLASSIFICATION = "classification"
DEEPGROW = "deepgrow"
DEEPEDIT = "deepedit"
SCRIBBLES = "scribbles"
OTHERS = "others"
KNOWN_TYPES = [SEGMENTATION, CLASSIFICATION, DEEPGROW, DEEPEDIT, SCRIBBLES, OTHERS]
[docs]class InferTask:
"""
Basic Inference Task Helper
"""
[docs] def __init__(
self,
path,
network,
type: InferType,
labels,
dimension,
description,
model_state_dict="model",
input_key="image",
output_label_key="pred",
output_json_key="result",
config=None,
):
"""
:param path: Model File Path. Supports multiple paths to support versions (Last item will be picked as latest)
:param network: Model Network (e.g. monai.networks.xyz). None in case if you use TorchScript (torch.jit).
:param type: Type of Infer (segmentation, deepgrow etc..)
:param dimension: Input dimension
:param description: Description
:param model_state_dict: Key for loading the model state from checkpoint
:param input_key: Input key for running inference
:param output_label_key: Output key for storing result/label of inference
:param output_json_key: Output key for storing result/label of inference
:param config: K,V pairs to be part of user config
"""
self.path = path
self.network = network
self.type = type
self.labels = [] if labels is None else [labels] if isinstance(labels, str) else labels
self.dimension = dimension
self.description = description
self.model_state_dict = model_state_dict
self.input_key = input_key
self.output_label_key = output_label_key
self.output_json_key = output_json_key
self._networks: Dict = {}
self._config = {
# "device": "cuda",
# "result_extension": None,
# "result_dtype": None,
# "result_compress": False
}
if config:
self._config.update(config)
[docs] def info(self):
return {
"type": self.type,
"labels": self.labels,
"dimension": self.dimension,
"description": self.description,
"config": self.config(),
}
[docs] def config(self):
return self._config
[docs] def is_valid(self):
if self.network or self.type == InferType.SCRIBBLES:
return True
paths = [self.path] if isinstance(self.path, str) else self.path
for path in reversed(paths):
if os.path.exists(path):
return True
return False
[docs] def get_path(self):
if not self.path:
return None
paths = [self.path] if isinstance(self.path, str) else self.path
for path in reversed(paths):
if os.path.exists(path):
return path
return None
[docs] @abstractmethod
def post_transforms(self):
"""
Provide List of post-transforms
For Example::
return [
monai.transforms.AddChanneld(keys='pred'),
monai.transforms.Activationsd(keys='pred', softmax=True),
monai.transforms.AsDiscreted(keys='pred', argmax=True),
monai.transforms.SqueezeDimd(keys='pred', dim=0),
monai.transforms.ToNumpyd(keys='pred'),
monailabel.interface.utils.Restored(keys='pred', ref_image='image'),
monailabel.interface.utils.ExtremePointsd(keys='pred', result='result', points='points'),
monailabel.interface.utils.BoundingBoxd(keys='pred', result='result', bbox='bbox'),
]
"""
pass
[docs] @abstractmethod
def inferer(self):
"""
Provide Inferer Class
For Example::
return monai.inferers.SlidingWindowInferer(roi_size=[160, 160, 160])
"""
pass
def __call__(self, request):
"""
It provides basic implementation to run the following in order
- Run Pre Transforms
- Run Inferer
- Run Post Transforms
- Run Writer to save the label mask and result params
Returns: Label (File Path) and Result Params (JSON)
"""
begin = time.time()
req = copy.deepcopy(self._config)
req.update(copy.deepcopy(request))
logger.info(f"Infer Request (final): {req}")
data = copy.deepcopy(req)
data.update({"image_path": req.get("image")})
device = req.get("device", "cuda")
start = time.time()
pre_transforms = self.pre_transforms()
data = self.run_pre_transforms(data, pre_transforms)
latency_pre = time.time() - start
start = time.time()
data = self.run_inferer(data, device=device)
latency_inferer = time.time() - start
start = time.time()
data = self.run_invert_transforms(data, pre_transforms, self.inverse_transforms())
latency_invert = time.time() - start
start = time.time()
data = self.run_post_transforms(data, self.post_transforms())
latency_post = time.time() - start
start = time.time()
result_file_name, result_json = self.writer(data)
latency_write = time.time() - start
latency_total = time.time() - begin
logger.info(
"++ Latencies => Total: {:.4f}; "
"Pre: {:.4f}; Inferer: {:.4f}; Invert: {:.4f}; Post: {:.4f}; Write: {:.4f}".format(
latency_total,
latency_pre,
latency_inferer,
latency_invert,
latency_post,
latency_write,
)
)
result_json["label_names"] = self.labels
logger.info("Result File: {}".format(result_file_name))
logger.info("Result Json: {}".format(result_json))
return result_file_name, result_json
[docs] def run_post_transforms(self, data, transforms):
return run_transforms(data, transforms, log_prefix="POST")
def _get_network(self, device):
path = self.get_path()
logger.info("Infer model path: {}".format(path))
if not path and not self.network:
if self.type == InferType.SCRIBBLES:
return None
raise MONAILabelException(
MONAILabelError.INFERENCE_ERROR,
f"Model Path ({self.path}) does not exist/valid",
)
cached = self._networks.get(device)
statbuf = os.stat(path) if path else None
network = None
if cached:
if statbuf and statbuf.st_mtime == cached[1]:
network = cached[0]
elif statbuf:
logger.info(f"Reload model from cache. Prev ts: {cached[1]}; Current ts: {statbuf.st_mtime}")
if network is None:
if self.network:
network = self.network
if path:
checkpoint = torch.load(path)
model_state_dict = checkpoint.get(self.model_state_dict, checkpoint)
network.load_state_dict(model_state_dict)
else:
network = torch.jit.load(path)
network = network.cuda() if device == "cuda" else network
network.eval()
self._networks[device] = (network, statbuf.st_mtime if statbuf else 0)
return network
[docs] def run_inferer(self, data, convert_to_batch=True, device="cuda"):
"""
Run Inferer over pre-processed Data. Derive this logic to customize the normal behavior.
In some cases, you want to implement your own for running chained inferers over pre-processed data
:param data: pre-processed data
:param convert_to_batch: convert input to batched input
:param device: device type run load the model and run inferer
:return: updated data with output_key stored that will be used for post-processing
"""
inferer = self.inferer()
logger.info("Running Inferer:: {}".format(inferer.__class__.__name__))
network = self._get_network(device)
if network:
inputs = data[self.input_key]
inputs = inputs if torch.is_tensor(inputs) else torch.from_numpy(inputs)
inputs = inputs[None] if convert_to_batch else inputs
inputs = inputs.cuda() if device == "cuda" else inputs
with torch.no_grad():
outputs = inferer(inputs, network)
if device == "cuda":
torch.cuda.empty_cache()
outputs = outputs[0] if convert_to_batch else outputs
data[self.output_label_key] = outputs
else:
# consider them as callable transforms
data = run_transforms(data, inferer, log_prefix="INF", log_name="Inferer")
return data
[docs] def writer(self, data, extension=None, dtype=None):
"""
You can provide your own writer. However this writer saves the prediction/label mask to file
and fetches result json
:param data: typically it is post processed data
:param extension: output label extension
:param dtype: output label dtype
:return: tuple of output_file and result_json
"""
logger.info("Writing Result")
if extension is not None:
data["result_extension"] = extension
if dtype is not None:
data["result_dtype"] = dtype
writer = Writer(label=self.output_label_key, json=self.output_json_key)
return writer(data)
[docs] def clear(self):
self._networks.clear()