# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import os
import time
from abc import abstractmethod
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
from monai.data import decollate_batch
from monai.inferers import Inferer, SimpleInferer, SlidingWindowInferer
from monai.utils import deprecated
from monailabel.interfaces.exception import MONAILabelError, MONAILabelException
from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType
from monailabel.interfaces.utils.transform import dump_data, run_transforms
from monailabel.transform.cache import CacheTransformDatad
from monailabel.transform.writer import ClassificationWriter, DetectionWriter, Writer
from monailabel.utils.others.generic import device_list, device_map, name_to_device, strtobool
logger = logging.getLogger(__name__)
class CallBackTypes(str, Enum):
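    """Stages of the inference pipeline at which user callbacks are invoked by ``BasicInferTask.__call__``."""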
PRE_TRANSFORMS = "PRE_TRANSFORMS"
INFERER = "INFERER"
INVERT_TRANSFORMS = "INVERT_TRANSFORMS"
POST_TRANSFORMS = "POST_TRANSFORMS"
WRITER = "WRITER"
class BasicInferTask(InferTask):
"""
Basic Inference Task Helper
"""
    def __init__(
self,
path: Union[None, str, Sequence[str]],
network: Union[None, Any],
type: Union[str, InferType],
labels: Union[str, None, Sequence[str], Dict[Any, Any]],
dimension: int,
description: str,
model_state_dict: str = "model",
input_key: str = "image",
output_label_key: str = "pred",
output_json_key: str = "result",
config: Union[None, Dict[str, Any]] = None,
load_strict: bool = True,
roi_size=None,
preload=False,
train_mode=False,
skip_writer=False,
):
"""
        :param path: Model File Path. Multiple paths are supported for versioning (the last item is picked as the latest)
        :param network: Model Network (e.g. monai.networks.xyz). None if you use TorchScript (torch.jit)
        :param type: Type of Infer (segmentation, deepgrow, etc.)
:param labels: Labels associated to this Infer
:param dimension: Input dimension
:param description: Description
:param model_state_dict: Key for loading the model state from checkpoint
:param input_key: Input key for running inference
:param output_label_key: Output key for storing result/label of inference
        :param output_json_key: Output key for storing additional JSON result/params of inference
:param config: K,V pairs to be part of user config
:param load_strict: Load model in strict mode
        :param roi_size: ROI size for sliding window inference
:param preload: Preload model/network on all available GPU devices
:param train_mode: Run in Train mode instead of eval (when network has dropouts)
:param skip_writer: Skip Writer and return data dictionary
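
        For Example::

            # a minimal sketch of a derived task; ``MyUNet`` and the model path are hypothetical,
            # and the subclass must still implement pre_transforms/post_transforms
            class MySegmentation(BasicInferTask):
                def __init__(self):
                    super().__init__(
                        path="/path/to/model.pt",
                        network=MyUNet(),
                        type=InferType.SEGMENTATION,
                        labels=["spleen"],
                        dimension=3,
                        description="Spleen segmentation over 3D CT Images",
                    )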
"""
super().__init__(type, labels, dimension, description, config)
self.path = [] if not path else [path] if isinstance(path, str) else path
self.network = network
self.model_state_dict = model_state_dict
self.input_key = input_key
self.output_label_key = output_label_key
self.output_json_key = output_json_key
self.load_strict = load_strict
self.roi_size = roi_size
self.train_mode = train_mode
self.skip_writer = skip_writer
self._networks: Dict = {}
self._config.update(
{
"device": device_list(),
# "result_extension": None,
# "result_dtype": None,
# "result_compress": False
# "roi_size": self.roi_size,
# "sw_batch_size": 1,
# "sw_overlap": 0.25,
}
)
if config:
self._config.update(config)
if preload:
for device in device_map().values():
logger.info(f"Preload Network for device: {device}")
self._get_network(device, None)
    def info(self) -> Dict[str, Any]:
return {
"type": self.type,
"labels": self.labels,
"dimension": self.dimension,
"description": self.description,
"config": self.config(),
}
    def config(self) -> Dict[str, Any]:
return self._config
    def is_valid(self) -> bool:
if self.network or self.type == InferType.SCRIBBLES:
return True
paths = self.path
for path in reversed(paths):
if path and os.path.exists(path):
return True
return False
    def get_path(self, validate=True):
if not self.path:
return None
paths = self.path
for path in reversed(paths):
if path:
if not validate or os.path.exists(path):
return path
return None
# if data and data.get("cache_transforms", False):
# in_memory = data.get("cache_transforms_in_memory", True)
# ttl = data.get("cache_transforms_ttl", 300)
#
# t.append(CacheTransformDatad(keys=keys, hash_key=hash_key, in_memory=in_memory, ttl=ttl))
    @abstractmethod
def post_transforms(self, data=None) -> Sequence[Callable]:
"""
Provide List of post-transforms
        :param data: current data dictionary/request which can be helpful to define the transforms on a per-request basis
For Example::
return [
monai.transforms.EnsureChannelFirstd(keys='pred', channel_dim='no_channel'),
monai.transforms.Activationsd(keys='pred', softmax=True),
monai.transforms.AsDiscreted(keys='pred', argmax=True),
monai.transforms.SqueezeDimd(keys='pred', dim=0),
monai.transforms.ToNumpyd(keys='pred'),
monailabel.interface.utils.Restored(keys='pred', ref_image='image'),
monailabel.interface.utils.ExtremePointsd(keys='pred', result='result', points='points'),
monailabel.interface.utils.BoundingBoxd(keys='pred', result='result', bbox='bbox'),
]
"""
pass
    def inferer(self, data=None) -> Inferer:
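        """
        Choose the Inferer for this request: SlidingWindowInferer when an ROI size is configured
        and the input's trailing spatial dims exceed it; otherwise SimpleInferer.
        """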
input_shape = data[self.input_key].shape if data else None
roi_size = data.get("roi_size", self.roi_size) if data else self.roi_size
sw_batch_size = data.get("sw_batch_size", 1) if data else 1
sw_overlap = data.get("sw_overlap", 0.25) if data else 0.25
device = data.get("device")
sliding = False
if input_shape and roi_size:
            for i in range(len(roi_size)):
                # index from the end so the trailing (spatial) dims are compared;
                # note that -i would alias index 0 when i == 0
                if input_shape[-i - 1] > roi_size[-i - 1]:
                    sliding = True
if sliding:
return SlidingWindowInferer(
roi_size=roi_size,
overlap=sw_overlap,
sw_batch_size=sw_batch_size,
sw_device=device,
device=device,
)
return SimpleInferer()
    def detector(self, data=None) -> Optional[Callable]:
return None
def __call__(
self, request, callbacks: Union[Dict[CallBackTypes, Any], None] = None
) -> Tuple[Union[str, None], Dict]:
"""
        It provides a basic implementation to run the following in order
- Run Pre Transforms
- Run Inferer
- Run Invert Transforms
- Run Post Transforms
- Run Writer to save the label mask and result params
You can provide callbacks which can be useful while writing pipelines to consume intermediate outputs
Callback function should consume data and return data (modified/updated) e.g. `def my_cb(data): return data`
Returns: Label (File Path) and Result Params (JSON)
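
        For Example::

            # a minimal sketch; ``task`` is a BasicInferTask instance and ``request`` an infer request dict
            def print_keys(data):
                logger.info(f"Data keys after post transforms: {list(data.keys())}")
                return data

            label_file, result_json = task(request, callbacks={CallBackTypes.POST_TRANSFORMS: print_keys})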
"""
begin = time.time()
req = copy.deepcopy(self._config)
req.update(request)
# device
device = name_to_device(req.get("device", "cuda"))
req["device"] = device
logger.setLevel(req.get("logging", "INFO").upper())
if req.get("image") is not None and isinstance(req.get("image"), str):
logger.info(f"Infer Request (final): {req}")
data = copy.deepcopy(req)
data.update({"image_path": req.get("image")})
else:
dump_data(req, logger.level)
data = req
        # callbacks are useful in case of pipelines to consume intermediate output from each of the following stages
        # callback function should consume data and return data (modified/updated)
callbacks = callbacks if callbacks else {}
callback_run_pre_transforms = callbacks.get(CallBackTypes.PRE_TRANSFORMS)
callback_run_inferer = callbacks.get(CallBackTypes.INFERER)
callback_run_invert_transforms = callbacks.get(CallBackTypes.INVERT_TRANSFORMS)
callback_run_post_transforms = callbacks.get(CallBackTypes.POST_TRANSFORMS)
callback_writer = callbacks.get(CallBackTypes.WRITER)
start = time.time()
pre_transforms = self.pre_transforms(data)
data = self.run_pre_transforms(data, pre_transforms)
if callback_run_pre_transforms:
data = callback_run_pre_transforms(data)
latency_pre = time.time() - start
start = time.time()
if self.type == InferType.DETECTION:
data = self.run_detector(data, device=device)
else:
data = self.run_inferer(data, device=device)
if callback_run_inferer:
data = callback_run_inferer(data)
latency_inferer = time.time() - start
start = time.time()
data = self.run_invert_transforms(data, pre_transforms, self.inverse_transforms(data))
if callback_run_invert_transforms:
data = callback_run_invert_transforms(data)
latency_invert = time.time() - start
start = time.time()
data = self.run_post_transforms(data, self.post_transforms(data))
if callback_run_post_transforms:
data = callback_run_post_transforms(data)
latency_post = time.time() - start
        if self.skip_writer or strtobool(data.get("skip_writer", "false")):
return None, dict(data)
start = time.time()
result_file_name, result_json = self.writer(data)
if callback_writer:
data = callback_writer(data)
latency_write = time.time() - start
latency_total = time.time() - begin
logger.info(
"++ Latencies => Total: {:.4f}; "
"Pre: {:.4f}; Inferer: {:.4f}; Invert: {:.4f}; Post: {:.4f}; Write: {:.4f}".format(
latency_total,
latency_pre,
latency_inferer,
latency_invert,
latency_post,
latency_write,
)
)
result_json["label_names"] = self.labels
result_json["latencies"] = {
"pre": round(latency_pre, 2),
"infer": round(latency_inferer, 2),
"invert": round(latency_invert, 2),
"post": round(latency_post, 2),
"write": round(latency_write, 2),
"total": round(latency_total, 2),
"transform": data.get("latencies"),
}
# Add Centroids to the result json to consume in OHIF v3
centroids = data.get("centroids", None)
if centroids is not None:
centroids_dict = dict()
for c in centroids:
all_items = list(c.items())
centroids_dict[all_items[0][0]] = [str(i) for i in all_items[0][1]] # making it json compatible
result_json["centroids"] = centroids_dict
else:
result_json["centroids"] = dict()
if result_file_name is not None and isinstance(result_file_name, str):
logger.info(f"Result File: {result_file_name}")
logger.info(f"Result Json Keys: {list(result_json.keys())}")
return result_file_name, result_json
    def run_post_transforms(self, data: Dict[str, Any], transforms):
return run_transforms(data, transforms, log_prefix="POST")
    def clear_cache(self):
self._networks.clear()
def _get_network(self, device, data):
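        # Networks are cached per device along with the checkpoint's mtime (see below);
        # a changed checkpoint file on disk invalidates the cached entry and triggers a reload.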
path = self.get_path()
logger.info(f"Infer model path: {path}")
if data and self._config.get("model_filename"):
model_filename = data.get("model_filename")
model_filename = model_filename if isinstance(model_filename, str) else model_filename[0]
user_path = os.path.join(os.path.dirname(self.path[0]), model_filename)
if user_path and os.path.exists(user_path):
path = user_path
logger.info(f"Using <User> provided model_file: {user_path}")
else:
logger.info(f"Ignoring <User> provided model_file (not valid): {user_path}")
if not path and not self.network:
if self.type == InferType.SCRIBBLES:
return None
raise MONAILabelException(
MONAILabelError.INFERENCE_ERROR,
f"Model Path ({self.path}) does not exist/valid",
)
cached = self._networks.get(device)
statbuf = os.stat(path) if path else None
network = None
if cached:
if statbuf and statbuf.st_mtime == cached[1]:
network = cached[0]
elif statbuf:
logger.warning(f"Reload model from cache. Prev ts: {cached[1]}; Current ts: {statbuf.st_mtime}")
if network is None:
if self.network:
network = copy.deepcopy(self.network)
network.to(torch.device(device))
if path:
checkpoint = torch.load(path, map_location=torch.device(device))
model_state_dict = checkpoint.get(self.model_state_dict, checkpoint)
if set(self.network.state_dict().keys()) != set(model_state_dict.keys()):
logger.warning(
f"Checkpoint keys don't match network.state_dict()! Items that exist in only one dict"
f" but not in the other: {set(self.network.state_dict().keys()) ^ set(model_state_dict.keys())}"
)
logger.warning(
"The run will now continue unless load_strict is set to True. "
"If loading fails or the network behaves abnormally, please check the loaded weights"
)
network.load_state_dict(model_state_dict, strict=self.load_strict)
else:
network = torch.jit.load(path, map_location=torch.device(device))
if self.train_mode:
network.train()
else:
network.eval()
self._networks[device] = (network, statbuf.st_mtime if statbuf else 0)
return network
    def run_inferer(self, data: Dict[str, Any], convert_to_batch=True, device="cuda"):
"""
        Run Inferer over pre-processed Data. Override this method to customize the default behavior,
        for example to run chained inferers over the pre-processed data.
        :param data: pre-processed data
        :param convert_to_batch: convert input to batched input
        :param device: device used to load the model and run the inferer
        :return: updated data with output_key stored that will be used for post-processing
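
        For Example::

            # a minimal sketch of chaining a second pass in a derived task; ``refine`` is a hypothetical callable
            def run_inferer(self, data, convert_to_batch=True, device="cuda"):
                data = super().run_inferer(data, convert_to_batch, device)
                data[self.output_label_key] = refine(data[self.output_label_key])
                return data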
"""
inferer = self.inferer(data)
logger.info(f"Inferer:: {device} => {inferer.__class__.__name__} => {inferer.__dict__}")
network = self._get_network(device, data)
if network:
inputs = data[self.input_key]
inputs = inputs if torch.is_tensor(inputs) else torch.from_numpy(inputs)
inputs = inputs[None] if convert_to_batch else inputs
inputs = inputs.to(torch.device(device))
with torch.no_grad():
outputs = inferer(inputs, network)
if device.startswith("cuda"):
torch.cuda.empty_cache()
if convert_to_batch:
if isinstance(outputs, dict):
outputs_d = decollate_batch(outputs)
outputs = outputs_d[0]
else:
outputs = outputs[0]
data[self.output_label_key] = outputs
else:
# consider them as callable transforms
data = run_transforms(data, inferer, log_prefix="INF", log_name="Inferer")
return data
    def run_detector(self, data: Dict[str, Any], convert_to_batch=True, device="cuda"):
"""
Run Detector over pre-processed Data. Derive this logic to customize the normal behavior.
In some cases, you want to implement your own for running chained inferers over pre-processed data
:param data: pre-processed data
:param convert_to_batch: convert input to batched input
:param device: device type run load the model and run inferer
:return: updated data with output_key stored that will be used for post-processing
"""
"""
Run Detector over pre-processed Data. Derive this logic to customize the normal behavior.
In some cases, you want to implement your own for running chained detector ops over pre-processed data
:param data: pre-processed data
:param device: device type run load the model and run inferer
:return: updated data with output_key stored that will be used for post-processing
"""
detector = self.detector(data)
if detector is None:
raise ValueError("Detector is Not Provided")
if hasattr(detector, "inferer"):
logger.info(
f"Detector Inferer:: {device} => {detector.inferer.__class__.__name__} => {detector.inferer.__dict__}" # type: ignore
)
network = self._get_network(device, data)
if network:
inputs = data[self.input_key]
inputs = inputs if torch.is_tensor(inputs) else torch.from_numpy(inputs)
inputs = inputs[None] if convert_to_batch else inputs
inputs = inputs.to(torch.device(device))
if hasattr(detector, "network"):
detector.network = network # type: ignore
else:
logger.warning("Detector has no 'network' attribute defined; Running without pretrained network")
with torch.no_grad():
if callable(getattr(detector, "eval", None)):
detector.eval() # type: ignore
network.eval()
outputs = detector(inputs, use_inferer=True)
if device.startswith("cuda"):
torch.cuda.empty_cache()
if convert_to_batch:
if isinstance(outputs, dict):
outputs_d = decollate_batch(outputs)
outputs = outputs_d[0]
else:
outputs = outputs[0]
if isinstance(outputs, dict):
data.update(outputs)
else:
data[self.output_label_key] = outputs
return data
    def writer(self, data: Dict[str, Any], extension=None, dtype=None) -> Tuple[Any, Any]:
"""
        You can provide your own writer. By default, this writer saves the prediction/label mask
        to a file and collects the result json
:param data: typically it is post processed data
:param extension: output label extension
:param dtype: output label dtype
:return: tuple of output_file and result_json
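
        For Example::

            # a minimal sketch of forcing a specific label format in a derived task;
            # assumes the default Writer honors the ``.nii.gz`` extension
            def writer(self, data, extension=None, dtype=None):
                return super().writer(data, extension=".nii.gz")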
"""
logger.info("Writing Result...")
if extension is not None:
data["result_extension"] = extension
if dtype is not None:
data["result_dtype"] = dtype
if self.labels is not None:
data["labels"] = self.labels
if self.type == InferType.CLASSIFICATION:
if isinstance(self.labels, dict):
label_names = {v: k for k, v in self.labels.items()}
else:
label_names = {v: k for v, k in enumerate(self.labels)} if isinstance(self.labels, Sequence) else None
cw = ClassificationWriter(label=self.output_label_key, label_names=label_names)
return cw(data)
if self.type == InferType.DETECTION:
dw = DetectionWriter()
return dw(data)
writer = Writer(label=self.output_label_key, json=self.output_json_key)
return writer(data)
    def clear(self):
self._networks.clear()
    def set_loglevel(self, level: str):
logger.setLevel(level.upper())