Source code for monailabel.datastore.cvat

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import shutil
import tempfile
import time
import urllib.parse

import numpy as np
import requests
from PIL import Image
from requests.auth import HTTPBasicAuth

from monailabel.datastore.local import LocalDatastore
from monailabel.interfaces.datastore import DefaultLabelTag
from monailabel.utils.others.generic import get_mime_type

logger = logging.getLogger(__name__)


class CVATDatastore(LocalDatastore):
    """Local datastore that exchanges images and final labels with a CVAT server.

    Images selected for annotation are uploaded into a CVAT task named
    ``<task_prefix>_<version>`` inside ``project``.  Once that task reaches the
    ``completed`` status its segmentation masks are downloaded into the local
    datastore's final-label folder and the task is renamed with the
    ``done_prefix`` so it is not processed again.
    """

    def __init__(
        self,
        datastore_path,
        api_url,
        username=None,
        password=None,
        project="MONAILabel",
        task_prefix="ActiveLearning_Iteration",
        image_quality=70,
        labels=None,
        normalize_label=True,
        segment_size=1,
        **kwargs,
    ):
        """
        Args:
            datastore_path: forwarded to ``LocalDatastore`` as ``datastore_path``.
            api_url: base URL of the CVAT server; surrounding whitespace and
                trailing slashes are removed.
            username: CVAT user name; when empty, requests are sent without auth.
            password: CVAT password used together with ``username``.
            project: name of the CVAT project to use (created on demand).
            task_prefix: name prefix of the tasks created by this datastore.
            image_quality: value sent as ``image_quality`` when uploading images.
            labels: list of CVAT label dicts, or a JSON string encoding such a
                list; a built-in three-label default is used when empty.
            normalize_label: when true, downloaded colour masks are converted to
                single-channel masks holding the 1-based label index.
            segment_size: value sent as ``segment_size`` on task creation; a
                falsy value omits the field.
            **kwargs: forwarded to ``LocalDatastore.__init__``.
        """
        default_labels = [
            {"name": "Tool", "attributes": [], "color": "#66ff66"},
            {"name": "InBody", "attributes": [], "color": "#ff0000"},
            {"name": "OutBody", "attributes": [], "color": "#0000ff"},
        ]
        labels = labels if labels else default_labels
        labels = json.loads(labels) if isinstance(labels, str) else labels

        # Strip whitespace first so that a trailing "/" followed by blanks is removed too.
        self.api_url = api_url.strip().rstrip("/")
        self.auth = HTTPBasicAuth(username, password) if username else None
        self.project = project
        self.task_prefix = task_prefix
        self.image_quality = image_quality
        self.labels = labels
        # Label indices start at 1; 0 is left for "no label" in normalized masks.
        self.label_map = {label["name"]: idx for idx, label in enumerate(labels, start=1)}
        self.normalize_label = normalize_label
        self.segment_size = segment_size

        logger.info(f"CVAT:: API URL: {api_url}")
        logger.info(f"CVAT:: UserName: {username}")
        # Fixed-width mask: do not reveal the password length in the logs.
        logger.info(f"CVAT:: Password: {'******' if password else ''}")
        logger.info(f"CVAT:: Project: {project}")
        logger.info(f"CVAT:: Task Prefix: {task_prefix}")
        logger.info(f"CVAT:: Image Quality: {image_quality}")
        logger.info(f"CVAT:: Labels: {labels}")
        logger.info(f"CVAT:: Normalize Label: {normalize_label}")
        logger.info(f"CVAT:: Segment Size: {segment_size}")

        super().__init__(datastore_path=datastore_path, **kwargs)
        self.done_prefix = "DONE"

    def name(self) -> str:
        """Return the display name of this datastore."""
        return "CVAT+Local Datastore"

    def description(self) -> str:
        """Return the description of this datastore."""
        return "CVAT+Local Datastore"

    def get_cvat_project_id(self, create):
        """Return the id of the CVAT project named ``self.project``.

        Args:
            create: when true and no such project is listed, create it with
                ``self.labels``.

        Returns:
            The project id, or ``None`` when it is not found and ``create`` is false.
        """
        projects = requests.get(f"{self.api_url}/api/projects", auth=self.auth).json()
        logger.debug(projects)

        project_id = None
        for project in projects["results"]:
            if project["name"] == self.project:
                project_id = project["id"]
                break

        if create and project_id is None:
            body = {"name": self.project, "labels": self.labels}
            project = requests.post(f"{self.api_url}/api/projects", auth=self.auth, json=body).json()
            logger.info(project)
            project_id = project["id"]

        logger.debug(f"Using Project ID: {project_id}")
        return project_id

    @staticmethod
    def _task_version(task_name):
        """Return the integer after the last ``_`` of ``task_name`` (0 if there is none)."""
        try:
            return int(task_name.rsplit("_", 1)[-1])
        except ValueError:
            return 0

    def _task_order(self, task_name):
        """Sort key ranking task names by numeric version first, then by name."""
        return self._task_version(task_name), task_name

    def get_cvat_task_id(self, project_id, create):
        """Find the latest active task of the project, optionally creating the next one.

        Args:
            project_id: id of the CVAT project whose tasks are listed.
            create: when true, a new task named ``<task_prefix>_<version>`` is
                created, where ``version`` is one more than the version of the
                latest active task or, if there is none, of the latest task
                carrying the done prefix (1 when neither exists).

        Returns:
            Tuple ``(task_id, task_name)``; ``(None, "")`` when no task whose name
            starts with ``task_prefix`` exists and ``create`` is false.
        """
        task_filter = {"and": [{"==": [{"var": "project_id"}, project_id]}]}
        query = urllib.parse.quote_plus(json.dumps(task_filter))
        tasks = requests.get(f"{self.api_url}/api/tasks?filter={query}", auth=self.auth).json()

        task_id = None
        task_name = ""
        for task in tasks["results"]:
            candidate = task["name"]
            if not candidate.startswith(self.task_prefix):
                continue
            # Update id and name together so both always describe the same task,
            # and rank by numeric version so that "_10" is newer than "_9".
            if self._task_order(candidate) > self._task_order(task_name):
                task_id = task["id"]
                task_name = candidate

        # increment to next iteration based on latest done_xxx
        if create:
            if not task_name:
                done_name_prefix = f"{self.done_prefix}_{self.task_prefix}"
                for task in tasks["results"]:
                    candidate = task["name"]
                    if candidate.startswith(done_name_prefix) and self._task_order(candidate) > self._task_order(
                        task_name
                    ):
                        task_name = candidate

            version = self._task_version(task_name) + 1 if task_name else 1
            task_name = f"{self.task_prefix}_{version}"
            logger.info(f"Creating new CVAT Task: {task_name}; project: {self.project}")

            body = {"name": task_name, "labels": [], "project_id": project_id, "subset": "Train"}
            if self.segment_size:
                body["segment_size"] = self.segment_size
            task = requests.post(f"{self.api_url}/api/tasks", auth=self.auth, json=body).json()
            logger.debug(task)
            task_id = task["id"]

        logger.debug(f"Using Task ID: {task_id}; Task Name: {task_name}")
        return task_id, task_name

    def task_status(self):
        """Return the ``status`` field of the latest active task.

        Returns:
            The status reported by CVAT, or ``None`` when the project or an
            active task does not exist.
        """
        project_id = self.get_cvat_project_id(create=False)
        if project_id is None:
            return None

        task_id, _ = self.get_cvat_task_id(project_id, create=False)
        if task_id is None:
            return None

        r = requests.get(f"{self.api_url}/api/tasks/{task_id}", auth=self.auth).json()
        return r.get("status")

    def upload_to_cvat(self, samples):
        """Create the next CVAT task and upload the given image files into it.

        Args:
            samples: iterable of local image file paths.
        """
        project_id = self.get_cvat_project_id(create=True)
        task_id, _ = self.get_cvat_task_id(project_id, create=True)

        file_list = [("image_quality", (None, f"{self.image_quality}"))]
        handles = []
        try:
            for i, image in enumerate(samples):
                logger.info(f"Selected Image to upload to CVAT: {image}")
                fp = open(image, "rb")
                handles.append(fp)
                file_list.append((f"client_files[{i}]", (os.path.basename(image), fp, get_mime_type(image))))

            r = requests.post(f"{self.api_url}/api/tasks/{task_id}/data", files=file_list, auth=self.auth).json()
            logger.info(r)
        finally:
            # The multipart body has been sent (or the call failed); release every file handle.
            for fp in handles:
                fp.close()

    def trigger_automation(self, function):
        """Request the CVAT lambda ``function`` to run on the latest active task.

        Nothing is sent when the project or an active task does not exist.

        Args:
            function: value sent as ``function`` in the lambda request body.
        """
        project_id = self.get_cvat_project_id(create=False)
        if project_id is None:
            return

        task_id, _ = self.get_cvat_task_id(project_id, create=False)
        if task_id is None:
            return

        body = {"cleanup": True, "task": task_id, "function": function}
        r = requests.post(f"{self.api_url}/api/lambda/requests?org=", json=body, auth=self.auth).json()
        logger.info(r)

    def _load_labelmap_txt(self, file):
        """Parse a ``name:r,g,b:...`` label map file into ``{name: (r, g, b)}``.

        Blank lines, lines starting with ``#`` and lines without a colour field
        are skipped.  A missing file yields an empty dict.
        """
        labelmap = {}
        if not os.path.exists(file):
            return labelmap

        with open(file) as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith("#"):
                    continue
                fields = line.split(":")
                if len(fields) < 2 or not fields[1].strip():
                    continue
                labelmap[fields[0]] = tuple(int(c) for c in fields[1].split(","))
        return labelmap

    def _save_final_labels(self, export_dir):
        """Copy the PNG masks of an unpacked export into the final-label folder.

        With ``normalize_label`` enabled every colour listed in the export's
        ``labelmap.txt`` is replaced by the matching index of ``self.label_map``
        and the mask is stored as a single channel; otherwise the PNG is
        re-saved unchanged.
        """
        segmentations_dir = os.path.join(export_dir, "SegmentationClass")
        final_labels = self._datastore.label_path(DefaultLabelTag.FINAL)
        # The colour table is the same for every mask of the export: read it once.
        labelmap = self._load_labelmap_txt(os.path.join(export_dir, "labelmap.txt")) if self.normalize_label else {}

        for f in os.listdir(segmentations_dir):
            label = os.path.join(segmentations_dir, f)
            if not (os.path.isfile(label) and label.endswith(".png")):
                continue

            os.makedirs(final_labels, exist_ok=True)
            dest = os.path.join(final_labels, f)

            if self.normalize_label:
                # Force RGB so palette or RGBA PNGs compare correctly against (r, g, b) tuples.
                with Image.open(label) as im:
                    img = np.array(im.convert("RGB"))

                mask = np.zeros(img.shape[:2], dtype=np.uint8)
                for name, color in labelmap.items():
                    idx = self.label_map.get(name)
                    if idx is not None:
                        mask[np.all(img == color, axis=-1)] = idx
                Image.fromarray(mask).save(dest)  # single channel
                logger.info(f"Copy Final Label: {label} to {dest}; unique: {np.unique(mask)}")
            else:
                with Image.open(label) as im:
                    im.save(dest)
                logger.info(f"Copy Final Label: {label} to {dest}")

    def download_from_cvat(self, max_retry_count=5, retry_wait_time=10):
        """Download the final labels of the latest active task once it is completed.

        Args:
            max_retry_count: maximum number of download attempts.
            retry_wait_time: seconds to wait after each download request.

        Returns:
            The name the task had before being renamed with the done prefix, or
            ``None`` when the task is not completed or every attempt failed.
        """
        status = self.task_status()
        if status != "completed":
            logger.info(f"No Tasks with completed status (current: {status}) to refresh/download the final labels")
            return None

        project_id = self.get_cvat_project_id(create=False)
        task_id, task_name = self.get_cvat_task_id(project_id, create=False)
        logger.info(f"Preparing to download/update final labels from: {project_id} => {task_id} => {task_name}")

        download_url = f"{self.api_url}/api/tasks/{task_id}/annotations?action=download&format=Segmentation+mask+1.1"

        # The context manager removes the archive and the extracted files on every exit path.
        with tempfile.TemporaryDirectory() as tmp_root:
            tmp_zip = os.path.join(tmp_root, "annotations.zip")
            tmp_folder = os.path.join(tmp_root, "extracted")

            for retry in range(max_retry_count):
                try:
                    r = requests.get(download_url, allow_redirects=True, auth=self.auth)
                    time.sleep(retry_wait_time)

                    with open(tmp_zip, "wb") as fp:
                        fp.write(r.content)
                    os.makedirs(tmp_folder, exist_ok=True)
                    shutil.unpack_archive(tmp_zip, tmp_folder)

                    self._save_final_labels(tmp_folder)

                    # Rename after consuming/downloading the labels
                    patch_url = f"{self.api_url}/api/tasks/{task_id}"
                    body = {"name": f"{self.done_prefix}_{task_name}"}
                    requests.patch(patch_url, allow_redirects=True, auth=self.auth, json=body)
                    return task_name
                except Exception as e:
                    # NOTE(review): the first attempt presumably fails while the server
                    # is still preparing the export, so the traceback is logged only
                    # from the second failure on - confirm against the CVAT API.
                    if retry:
                        logger.exception(e)
                    logger.error(f"{retry} => Failed to download...")
        return None
""" def main(): from pathlib import Path from monailabel.config import settings settings.MONAI_LABEL_DATASTORE_AUTO_RELOAD = False settings.MONAI_LABEL_DATASTORE_FILE_EXT = ["*.png", "*.jpg", "*.jpeg", ".xml"] settings.MONAI_LABEL_DATASTORE = "cvat" settings.MONAI_LABEL_DATASTORE_URL = "http://10.117.19.88:8080" settings.MONAI_LABEL_DATASTORE_USERNAME = "sachi" settings.MONAI_LABEL_DATASTORE_PASSWORD = "sachi" os.putenv("MASTER_ADDR", "127.0.0.1") os.putenv("MASTER_PORT", "1234") logging.basicConfig( level=logging.INFO, format="[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", force=True, ) home = str(Path.home()) studies = f"{home}/Dataset/picked/all" ds = CVATDatastore( datastore_path=studies, api_url=settings.MONAI_LABEL_DATASTORE_URL, username=settings.MONAI_LABEL_DATASTORE_USERNAME, password=settings.MONAI_LABEL_DATASTORE_PASSWORD, project="MONAILabel", task_prefix="ActiveLearning_Iteration", image_quality=70, labels=None, normalize_label=True, segment_size=0, extensions=settings.MONAI_LABEL_DATASTORE_FILE_EXT, auto_reload=settings.MONAI_LABEL_DATASTORE_AUTO_RELOAD, ) ds.download_from_cvat() # studies = f"{home}/Dataset/Holoscan/flattened/images" if __name__ == "__main__": main() """