# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from enum import Enum
from typing import Callable
from monailabel.interfaces.datastore import Datastore
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class BatchInferImageType(str, Enum):
    """
    Named selectors for the subset of datastore images a batch-infer request targets.

    Because the enum mixes in ``str``, each member compares equal to its plain
    string value, e.g. ``BatchInferImageType.IMAGES_LABELED == "labeled"``. This
    lets raw request strings be compared directly against the members.
    """

    IMAGES_ALL = "all"  # resolved by BatchInferTask.get_images via datastore.list_images()
    IMAGES_LABELED = "labeled"  # resolved via datastore.get_labeled_images()
    IMAGES_UNLABELED = "unlabeled"  # resolved via datastore.get_unlabeled_images()
class BatchInferTask:
    """
    Basic Batch Infer Task.

    Resolves the image ids to process from the request (see ``get_images``), then
    calls the supplied ``infer`` callable once per image.
    """

    def get_images(self, request, datastore: Datastore):
        """
        Override this method to get all eligible images for your task to run batch infer.

        ``request["images"]`` selects the images:

        - missing, ``None`` or ``"all"``: ``datastore.list_images()``
        - ``"labeled"``: ``datastore.get_labeled_images()``
        - ``"unlabeled"``: ``datastore.get_unlabeled_images()``
        - any other string: treated as a single image id, i.e. ``[images]``
        - any other value (e.g. a list of image ids): returned unchanged

        Args:
            request: request dict; only the ``"images"`` key is read here.
            datastore: datastore queried for image ids.

        Returns:
            The image ids to run batch inference on.
        """
        images = request.get("images")
        if images is None:
            # An explicit None would otherwise be returned and crash len() in __call__.
            images = BatchInferImageType.IMAGES_ALL

        if isinstance(images, str):
            # BatchInferImageType mixes in str, so plain strings compare equal to members.
            if images == BatchInferImageType.IMAGES_LABELED:
                return datastore.get_labeled_images()
            if images == BatchInferImageType.IMAGES_UNLABELED:
                return datastore.get_unlabeled_images()
            if images == BatchInferImageType.IMAGES_ALL:
                return datastore.list_images()
            # Any other string is a single image id. Previously it silently fell back
            # to *all* images, turning a one-image request (or a typo) into a
            # full-datastore run.
            return [images]
        return images

    def __call__(self, request, datastore: Datastore, infer: Callable):
        """
        Run ``infer`` once for every image selected by ``get_images``.

        Args:
            request: base request dict. Each image gets its own deep copy, with the
                copy's ``"image"`` key set to the current image id.
            datastore: passed to ``get_images`` and to every ``infer`` call.
            infer: callable invoked as ``infer(req, datastore)`` per image.

        Returns:
            dict mapping each image id to the value ``infer`` returned for it.
        """
        # Materialise once so get_images overrides may return any iterable
        # (generator, set, ...) and len() below still works.
        image_ids = list(self.get_images(request, datastore))
        logger.info("Total number of images for batch inference: %d", len(image_ids))

        result = {}
        for image_id in image_ids:
            # Independent copy per image: changes infer() makes to req do not
            # carry over to later images or back into the caller's request.
            req = copy.deepcopy(request)
            req["image"] = image_id
            logger.info("Running inference for image id %s", image_id)
            result[image_id] = infer(req, datastore)
        return result