# Source code for monai.deploy.operators.inference_operator

# Copyright 2021-2023 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Tuple, Union

from monai.deploy.core import Fragment, Image, Operator


class InferenceOperator(Operator):
    """The base operator for operators that perform AI inference.

    This operator performs pre-transforms on an input image, inference with
    a given model, post-transforms, and final results generation.

    The ``pre_process``, ``predict`` and ``post_process`` hooks are abstract by
    convention only: they are not decorated with ``abc.abstractmethod``, so a
    subclass that omits one can still be instantiated and the omission only
    surfaces as a ``NotImplementedError`` when the hook is called. The base
    ``compute`` does nothing.
    """

    def __init__(self, fragment: Fragment, *args, **kwargs):
        """Constructor of the operator.

        Args:
            fragment (Fragment): Forwarded unchanged to the ``Operator`` constructor.
            *args: Extra positional arguments forwarded to the ``Operator`` constructor.
            **kwargs: Extra keyword arguments forwarded to the ``Operator`` constructor.
        """
        super().__init__(fragment, *args, **kwargs)

    # NOTE: enforced at call time via NotImplementedError, not via @abstractmethod.
    def pre_process(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
        """Transforms input before being used for predicting on a model.

        This method must be overridden by a derived class.

        Args:
            data (Any): The input to transform.
            *args: Extra positional arguments for use by the overriding implementation.
            **kwargs: Extra keyword arguments for use by the overriding implementation.

        Returns:
            The transformed input, as produced by the overriding implementation.

        Raises:
            NotImplementedError: When the subclass does not override this method.
        """
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

    # NOTE: the base implementation is a silent no-op; nothing forces an override.
    def compute(self, op_input, op_output, context):
        """An abstract method that needs to be implemented by the user.

        Args:
            op_input (InputContext): An input context for the operator.
            op_output (OutputContext): An output context for the operator.
            context (ExecutionContext): An execution context for the operator.
        """
        pass

    # NOTE: enforced at call time via NotImplementedError, not via @abstractmethod.
    def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
        """Predicts results using the model(s) with input tensors.

        This method must be overridden by a derived class.

        Args:
            data (Any): The input to run through the model(s).
            *args: Extra positional arguments for use by the overriding implementation.
            **kwargs: Extra keyword arguments for use by the overriding implementation.

        Returns:
            The prediction results, as produced by the overriding implementation.

        Raises:
            NotImplementedError: When the subclass does not override this method.
        """
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

    # NOTE: enforced at call time via NotImplementedError, not via @abstractmethod.
    def post_process(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
        """Transforms the prediction results from the model(s).

        This method must be overridden by a derived class.

        Args:
            data (Any): The prediction results to transform.
            *args: Extra positional arguments for use by the overriding implementation.
            **kwargs: Extra keyword arguments for use by the overriding implementation.

        Returns:
            The transformed results, as produced by the overriding implementation.

        Raises:
            NotImplementedError: When the subclass does not override this method.
        """
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")