Source code for monailabel.datastore.dsa

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import pathlib
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional

import girder_client
import numpy as np
from PIL import Image

from monailabel.interfaces.datastore import Datastore, DefaultLabelTag
from monailabel.utils.others.generic import md5_digest

logger = logging.getLogger(__name__)


[docs]class DSADatastore(Datastore): def __init__(self, api_url, api_key=None, folder=None, annotation_groups=None, asset_store_path="", cache_path=""): self.api_url = api_url self.api_key = api_key self.folders = folder.split(",") if folder else [] self.folders = {f.strip() for f in self.folders} self.annotation_groups = [a.lower() if a else a for a in annotation_groups] if annotation_groups else [] self.asset_store_path = asset_store_path uri_hash = md5_digest(api_url) self.cache_path = ( os.path.join(cache_path, uri_hash) if cache_path else os.path.join(pathlib.Path.home(), ".cache", "monailabel", "dsa", uri_hash) ) logger.info(f"DSA:: Api Url: {api_url}") logger.info(f"DSA:: Api Key: {'*' * len(api_key) if api_key else ''}") logger.info(f"DSA:: Folder (Images): {folder}") logger.info(f"DSA:: Annotation Groups: {annotation_groups}") logger.info(f"DSA:: Local Asset Store Path: {asset_store_path}") self.gc = girder_client.GirderClient(apiUrl=api_url) if api_key: self.gc.authenticate(apiKey=api_key)
[docs] def name(self) -> str: return "DSA Datastore"
[docs] def set_name(self, name: str): pass
[docs] def description(self) -> str: return "Digital Slide Archive"
[docs] def set_description(self, description: str): pass
[docs] def datalist(self) -> List[Dict[str, Any]]: return [ { "api_url": self.api_url, "image": image_id, "label": image_id, "groups": self.annotation_groups, } for image_id in self.get_labeled_images() ]
[docs] def get_labels_by_image_id(self, image_id: str) -> Dict[str, str]: return {DefaultLabelTag.FINAL.name: image_id}
[docs] def get_label_by_image_id(self, image_id: str, tag: str) -> str: return image_id
[docs] def get_annotations_by_image_id(self, image_id: str) -> Dict[str, Dict[str, List]]: image_id, name = self._name_to_id(image_id) data = self.gc.get(f"annotation/item/{image_id}", parameters={"limit": 0}) result: Dict[str, Dict[str, List]] = {} # TODO(avirodov): probably can request only annotation for a given image_id, need to check how. # TODO(avirodov): download only "relevant" annotations. Maybe a flag to start_server? for d in data: if d["itemId"] == image_id: annotation_data = self.gc.get(f'annotation/{d["_id"]}') name = d["annotation"]["name"] result[name] = {} result[name]["points"] = [] for element in annotation_data["annotation"]["elements"]: # TODO(avirodov): support other elements for other training types. For now only NuClick points. if element["type"] == "point": # TODO(avirodov): Define a proper annotation model for Monai-Label (DSA's model could be it). result[name]["points"].append( (float(element["center"][0]), float(element["center"][1]), float(element["center"][2])) ) return result
[docs] def get_image(self, image_id: str, params=None) -> Any: try: name = self.get_image_info(image_id)["name"] except girder_client.HttpError: image_id, name = self._name_to_id(image_id) location = params.get("location", [0, 0]) size = params.get("size", [0, 0]) if sum(location) <= 0 and sum(size) <= 0: # whole side image dest = os.path.join(self.cache_path, name) if not os.path.exists(dest): logger.info(f"Downloading: {image_id} => {name} => {dest}") self.gc.downloadItem(itemId=image_id, dest=self.cache_path) return dest parameters = { "left": location[0], "top": location[1], "regionWidth": size[0], "regionHeight": size[1], "units": "base_pixels", "encoding": "PNG", } resp = self.gc.get(f"item/{image_id}/tiles/region", parameters=parameters, jsonResp=False) img = Image.open(BytesIO(resp.content)).convert("RGB") return np.asarray(img, dtype=np.uint8)
def _name_to_id(self, name): folders = self.folders if self.folders else self._get_all_folders() for folder in folders: # First check if the name is directly present data = self.gc.get("item", parameters={"folderId": folder, "name": name, "limit": 0}) for d in data: if d.get("largeImage"): return d["_id"], d["name"] # next check if the name is present in a stem form data = self.gc.get("item", parameters={"folderId": folder, "limit": 0}) for d in data: if d.get("largeImage") and d["name"] == name or Path(d["name"]).stem == name: return d["_id"], d["name"] # Next check if the name is anywhere in the system data = self.gc.get("item", parameters={"text": f'"{name}"' if '"' not in name else name, "limit": 0}) for d in data: if d.get("largeImage") and d["name"] == name or Path(d["name"]).stem == name: return d["_id"], d["name"] # If we fail to find the item, the best we can do is return the name return name, name
[docs] def get_image_uri(self, image_id: str) -> str: try: info = self.get_image_info(image_id) name = info["name"] file_id = info.get("largeImage", {}).get("fileId") except girder_client.HttpError: image_id, name = self._name_to_id(image_id) file_id = None if self.asset_store_path: if file_id is None: data = self.gc.get(f"item/{image_id}/files", parameters={"limit": 0}) file_id = data[0]["_id"] f = self.gc.get(f"resource/{file_id}?type=file") if "path" in f and os.path.exists(os.path.join(self.asset_store_path, f["path"])): return str(os.path.join(self.asset_store_path, f["path"])) else: cached = os.path.join(self.cache_path, name) if os.path.exists(cached): return str(cached) return f"{self.api_url}/item/{image_id}"
[docs] def get_label(self, label_id: str, label_tag: str, params=None) -> Any: return self.gc.get(f"annotation/item/{label_id}")
[docs] def get_label_uri(self, label_id: str, label_tag: str) -> str: return f"{self.api_url}/annotation/item/{label_id}"
[docs] def get_image_info(self, image_id: str) -> Dict[str, Any]: return self.gc.getItem(image_id) # type: ignore
[docs] def get_label_info(self, label_id: str, label_tag: str) -> Dict[str, Any]: return {}
def _get_annotated_images(self): data = self.gc.get("annotation", parameters={"limit": 0}) images = [] for d in data: if not self.annotation_groups: images.append(d["itemId"]) continue # get annotations and find if any matching groups exist matched = [ g for g in d["groups"] if g in self.annotation_groups or (g and g.lower() in self.annotation_groups) ] if matched: images.append(d["itemId"]) return images
[docs] def get_labeled_images(self, label_tag: Optional[str] = None, labels: Optional[List[str]] = None) -> List[str]: images = self.list_images() annotated = self._get_annotated_images() return [image for image in images if image in annotated]
[docs] def get_unlabeled_images(self, label_tag: Optional[str] = None, labels: Optional[List[str]] = None) -> List[str]: images = self.list_images() labeled = self.get_labeled_images() return [image for image in images if image not in labeled]
def _get_all_folders(self): folders = [] for collection in self.gc.listCollection(): for folder in self.gc.listFolder(parentId=collection["_id"], parentFolderType="collection"): folders.append(folder["_id"]) return folders
[docs] def list_images(self) -> List[str]: images = [] folders = self.folders if self.folders else self._get_all_folders() for folder in folders: for item in self.gc.get("item", parameters={"folderId": folder, "limit": 0}): if item.get("largeImage"): images.append(item["_id"]) return images
[docs] def refresh(self) -> None: pass
[docs] def add_image(self, image_id: str, image_filename: str, image_info: Dict[str, Any]) -> str: raise NotImplementedError
[docs] def remove_image(self, image_id: str) -> None: raise NotImplementedError
[docs] def save_label(self, image_id: str, label_filename: str, label_tag: str, label_info: Dict[str, Any]) -> str: raise NotImplementedError
[docs] def remove_label(self, label_id: str, label_tag: str) -> None: raise NotImplementedError
[docs] def update_image_info(self, image_id: str, info: Dict[str, Any]) -> None: raise NotImplementedError
[docs] def update_label_info(self, label_id: str, label_tag: str, info: Dict[str, Any]) -> None: raise NotImplementedError
[docs] def get_dataset_archive(self, limit_cases: Optional[int]) -> str: raise NotImplementedError
[docs] def status(self) -> Dict[str, Any]: # This is a very costly query, disable it for now # return { # "total": len(self.list_images()), # "completed": len(self.get_labeled_images()), # } return { "total": 0, "completed": 0, }
[docs] def json(self): return self.datalist()
""" def main(): import json logging.basicConfig( level=logging.INFO, format="[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) api_url = "http://0.0.0.0:8080/api/v1" folder = None annotation_groups = None asset_store_path = None api_key = None # "zBsr184BByiRK0BUyMMB01v3O8kTqkXPbqxndpfi" # api_url = "https://demo.kitware.com/histomicstk/api/v1" # folder = "5bbdeba3e629140048d017bb" # annotation_groups = ["mostly_tumor"] # asset_store_path = None # api_key = None ds = DSADatastore( api_url=api_url, api_key=api_key, folder=folder, annotation_groups=annotation_groups, asset_store_path=asset_store_path, ) images = ds.list_images() print(f"Images: {images}") labeled_images = ds.get_labeled_images() print(f"Labeled Images: {labeled_images}") unlabeled_images = ds.get_unlabeled_images() print(f"UnLabeled Images: {unlabeled_images}") image_id = images[0] print(f"Image Info: {json.dumps(ds.get_image_info(image_id), indent=2)}") print(f"Image URI: {ds.get_image_uri(image_id)}") print(f"Image URI (name): {ds.get_image_uri('TCGA-02-0010-01Z-00-DX4.07de2e55-a8fe-40ee-9e98-bcb78050b9f7')}") print(f"Labels: {ds.get_labels_by_image_id(image_id)}") if labeled_images: label_id = labeled_images[0] label_tag = "FINAL" print(f"Label Info: {json.dumps(ds.get_label_info(label_id, label_tag), indent=2)}") print(f"Label URI: {ds.get_label_uri(label_id, label_tag)}") print(f"Dataset for Training: \n{json.dumps(ds.datalist(), indent=2)}") img = ds.get_image( "TCGA-02-0010-01Z-00-DX4.07de2e55-a8fe-40ee-9e98-bcb78050b9f7", params={ "location": (6090, 15863), "size": (1071, 714), }, ) print(f"Fetched Region: {img.shape}") if __name__ == "__main__": main() """