Source code for monailabel.datastore.xnat

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import logging
import os
import pathlib
import shutil
import time
from typing import Any, Dict, List, Optional
from urllib.parse import quote_plus
from xml.etree import ElementTree

import requests
from requests.auth import HTTPBasicAuth

from monailabel.datastore.utils.convert import nifti_to_dicom_seg
from monailabel.interfaces.datastore import Datastore
from monailabel.utils.others.generic import md5_digest

logger = logging.getLogger(__name__)
xnat_ns = {"xnat": "http://nrg.wustl.edu/xnat"}


[docs]class XNATDatastore(Datastore): def __init__(self, api_url, username=None, password=None, project=None, asset_path="", cache_path=""): self.api_url = api_url self.xnat_session = requests.sessions.session() self.auth = HTTPBasicAuth(username, password) if username else None self.xnat_csrf = "" self._login_xnat() self.projects = project.split(",") if project else [] self.projects = {p.strip() for p in self.projects} self.asset_path = asset_path uri_hash = md5_digest(api_url) cache_path = cache_path.strip() if cache_path else "" self.cache_path = ( os.path.join(cache_path, uri_hash) if cache_path else os.path.join(pathlib.Path.home(), ".cache", "monailabel", "xnat", uri_hash) ) logger.info(f"XNAT:: API URL: {api_url}") logger.info(f"XNAT:: UserName: {username}") logger.info(f"XNAT:: Password: {'*' * len(password) if password else ''}") logger.info(f"XNAT:: Project: {project}") logger.info(f"XNAT:: AssetPath: {asset_path}")
[docs] def name(self) -> str: return "XNAT Datastore"
[docs] def set_name(self, name: str): pass
[docs] def description(self) -> str: return "XNAT Datastore"
[docs] def set_description(self, description: str): pass
[docs] def datalist(self) -> List[Dict[str, Any]]: return [ { "api_url": self.api_url, "image": image_id, "label": image_id, } for image_id in self.get_labeled_images() ]
[docs] def get_labels_by_image_id(self, image_id: str) -> Dict[str, str]: raise NotImplementedError
[docs] def get_label_by_image_id(self, image_id: str, tag: str) -> str: raise NotImplementedError
[docs] def get_image(self, image_id: str, params=None) -> Any: p = self._download_image(image_id, check_zip=True) uri = os.path.join(os.path.dirname(p), "files.zip") return io.BytesIO(pathlib.Path(uri).read_bytes()) if uri else None
[docs] def get_image_uri(self, image_id: str) -> str: return self._download_image(image_id, check_zip=False)
[docs] def get_label(self, label_id: str, label_tag: str, params=None) -> Any: raise NotImplementedError
[docs] def get_label_uri(self, label_id: str, label_tag: str) -> str: raise NotImplementedError
[docs] def get_image_info(self, image_id: str) -> Dict[str, Any]: info = {} project, subject, experiment, scan = self._id_to_fields(image_id) url = "{}/data/projects/{}/subjects/{}/experiments/{}/scans/{}?format=xml".format( self.api_url, quote_plus(project), quote_plus(subject), quote_plus(experiment), quote_plus(scan), ) response = self._request_get(url) if response.ok: info.update({"project": project, "subject": subject, "experiment": experiment, "scan": scan}) return info
[docs] def get_label_info(self, label_id: str, label_tag: str) -> Dict[str, Any]: return {}
[docs] def get_labeled_images(self, label_tag: Optional[str] = None, labels: Optional[List[str]] = None) -> List[str]: return []
[docs] def get_unlabeled_images(self, label_tag: Optional[str] = None, labels: Optional[List[str]] = None) -> List[str]: return self.list_images()
[docs] def list_images(self) -> List[str]: image_ids: List[str] = [] response = self._request_get(f"{self.api_url}/data/projects?format=json") for p in response.json().get("ResultSet", {}).get("Result", []): project = p.get("ID") if self.projects and project not in self.projects: continue response = self._request_get(f"{self.api_url}/data/projects/{quote_plus(project)}/experiments?format=json") for e in response.json().get("ResultSet", {}).get("Result", []): experiment = e.get("ID") response = self._request_get(f"{self.api_url}/data/experiments/{quote_plus(experiment)}?format=xml") tree = ElementTree.fromstring(response.content) s = tree.find(".//xnat:subject_ID", namespaces=xnat_ns) if s is None: continue subject = s.text for n in tree.findall(".//xnat:scan", namespaces=xnat_ns): scan = n.get("ID") image_ids.append(f"{project}/{subject}/{experiment}/{scan}") return image_ids
[docs] def refresh(self) -> None: pass
[docs] def add_image(self, image_id: str, image_filename: str, image_info: Dict[str, Any]) -> str: raise NotImplementedError
[docs] def remove_image(self, image_id: str) -> None: raise NotImplementedError
def __convert_nifti_to_dcmseg(self, series_dir, nii_seg_path, model_name, label_names) -> str: label_info = [] for i, lb in enumerate(label_names): label_info.append( {"model_name": model_name, "name": str(i + 1) + "_" + lb, "description": "lb" + str(i + 1) + "_" + lb} ) dcmSegFile = nifti_to_dicom_seg(series_dir=series_dir, label=nii_seg_path, label_info=label_info) logging.info(f" converted nifti to dicom seg --- at {dcmSegFile}") return dcmSegFile
[docs] def save_label(self, image_id: str, label_filename: str, label_tag: str, label_info: Dict[str, Any]) -> str: aiaa_model_name = label_info.get("model", "NoModel") label_names = label_info.get("params", {}).get("label_names", {}) # save the nii.gz segmentation into Xnat project, subject, experiment, scan = self._id_to_fields(image_id) nameAtXnat = f"pat_{subject}_exp_{experiment}_S_{scan}_AI_{aiaa_model_name}.nii.gz" self._request_put_file( experiment, scan, name_at_xnat=nameAtXnat, file2send=label_filename, ai_model_name=aiaa_model_name ) # convert nii to dcm seg and upload to Xnat if label_filename.endswith(".nii") or label_filename.endswith(".nii.gz"): series_dir = self._download_image(image_id) tmp_dcm_segpath = self.__convert_nifti_to_dcmseg(series_dir, label_filename, aiaa_model_name, label_names) self.__upload_assessment(aiaa_model_name, image_id, tmp_dcm_segpath, "SEG") return image_id
[docs] def remove_label(self, label_id: str, label_tag: str) -> None: raise NotImplementedError
[docs] def update_image_info(self, image_id: str, info: Dict[str, Any]) -> None: pass
[docs] def update_label_info(self, label_id: str, label_tag: str, info: Dict[str, Any]) -> None: pass
[docs] def get_dataset_archive(self, limit_cases: Optional[int]) -> str: raise NotImplementedError
[docs] def status(self) -> Dict[str, Any]: return { "total": len(self.list_images()), "completed": len(self.get_labeled_images()), }
[docs] def json(self): return self.datalist()
def _find_in_asset_store(self, project, subject, experiment, scan) -> str: url = "{}/data/projects/{}/subjects/{}/experiments/{}/scans/{}?format=xml".format( self.api_url, quote_plus(project), quote_plus(subject), quote_plus(experiment), quote_plus(scan), ) response = self._request_get(url) if response.ok: tree = ElementTree.fromstring(response.content) ele = tree.find('.//xnat:file[@label="DICOM"]', namespaces=xnat_ns) path = ele.get("URI") if ele is not None else "" if path: dicom_dir = os.path.dirname(os.path.join(self.asset_path, path.replace("/data/xnat/archive/", ""))) if os.path.exists(dicom_dir) and len(os.listdir(dicom_dir)) > 0: return dicom_dir return "" def _download_zip(self, dest_dir, dest_zip, dicom_dir, project, subject, experiment, scan): url = "{}/data/projects/{}/subjects/{}/experiments/{}/scans/{}/files?format=zip".format( self.api_url, quote_plus(project), quote_plus(subject), quote_plus(experiment), quote_plus(scan), ) response = self._request_get(url) if not response.ok: logger.info(f"Image Fetch Failed: {response.status_code} {response.reason}") return "" os.makedirs(dest_dir, exist_ok=True) with open(dest_zip, "wb") as fp: fp.write(response.content) extract_dir = os.path.join(dest_dir, "temp") shutil.unpack_archive(dest_zip, extract_dir) os.makedirs(dicom_dir, exist_ok=True) for root, _, files in os.walk(extract_dir): for f in files: if f.endswith(".dcm"): shutil.move(os.path.join(root, f), dicom_dir) shutil.rmtree(extract_dir) return dicom_dir def _download_image(self, image_id, check_zip=False) -> str: project, subject, experiment, scan = self._id_to_fields(image_id) if self.projects and project not in self.projects: logger.info(f"Access to Project: {project} is restricted; Allowed: {self.projects}") return "" # Check in Asset Store if self.asset_path and not check_zip: dicom_dir = self._find_in_asset_store(project, subject, experiment, scan) if dicom_dir: logger.info(f"Exists in asset store: {self.asset_path}") return dicom_dir # Check in Local Cache dest_dir = os.path.join(self.cache_path, project, subject, experiment, scan) dest_zip = os.path.join(dest_dir, "files.zip") dicom_dir = os.path.join(dest_dir, "DICOM") if os.path.exists(dest_zip) and len(os.listdir(dicom_dir)) > 0: logger.info(f"Exists in cache: {dest_zip}") return dicom_dir # Download DICOM Zip logger.info(f"Downloading: {project} => {subject} => {experiment} => {scan} => {dest_zip}") start = time.time() self._download_zip(dest_dir, dest_zip, dicom_dir, project, subject, experiment, scan) logger.info(f"Download Time (ms) for {image_id}: {round(time.time() - start, 4)}") return dicom_dir def _id_to_fields(self, image_id): fields = image_id.split("/") project = fields[0] subject = fields[1] experiment = fields[2] scan = fields[3] return project, subject, experiment, scan def _login_xnat(self): # Get CSRF token url = "{}/data/JSESSION?CSRF=true".format( self.api_url, ) csrf_response = self._request_get(url) if not csrf_response.ok: logger.error("XNAT:: Could not get XNAT CSRF token") raise Exception("Could not get XNAT CSRF token") content = csrf_response.content self.xnat_csrf = content.decode("utf-8").strip().split("=")[1] # Log in to XNAT url = f"{self.api_url}/data/JSESSION?XNAT_CSRF={self.xnat_csrf}" login_response = self._request_post(url) if not login_response.ok: logger.error("XNAT:: Could not log in to XNAT") raise Exception("Could not log in to XNAT") logger.info("XNAT:: Logged in XNAT") def _request_get(self, url): return self.xnat_session.get(url, allow_redirects=True) def _request_post(self, url): return self.xnat_session.post(url, auth=self.auth, allow_redirects=True) def _request_put(self, url, data, type): response = self.xnat_session.put( url, data=data, params={"overwrite": "true", "type": type}, headers={"Content-Type": "application/octet-stream"}, allow_redirects=True, ) if response.status_code != 200: # failed call logger.error(f" xnat put call error status_code= {response.status_code} text ={response.text}") else: logger.info(f" xnat dcm-seg / measurement json put completed {response.text}") return response def _request_put_file(self, experiment, scan, file2send, name_at_xnat, ai_model_name): """ uploads file to xnat using REST API """ folder = "AI" url = "{}/REST/experiments/{}/scans/{}/resources/{}/files/{}".format( self.api_url, quote_plus(experiment), quote_plus(scan), quote_plus(folder), quote_plus(name_at_xnat), ) data = open(file2send, "rb") params = {"overwrite": "true", "description": name_at_xnat, "content": ai_model_name, "format": "nii"} response = self.xnat_session.put( url, params=params, data=data, headers={"Content-Type": "application/octet-stream"}, allow_redirects=True ) if response.status_code != 200: # failed call logger.error(f" put call error status_code= {response.status_code} text ={response.text}") else: logger.info(f" put completed {response.text}") return response def __upload_assessment(self, aiaa_model_name, image_id, file_path, type): """ uploads Assessments to xnat using xapi :param aiaa_model_name: model name used to generate this file :param image_id: image id to use :param file_path: file to upload :param type: "SEG" for dicom seg or "MEAS" for measurments """ if not os.path.exists(file_path): logging.error(f" file {file_path} does not exist! ") return data = open(file_path, "rb") project, subject, experiment, scan = self._id_to_fields(image_id) url = "{}/xapi/roi/projects/{}/sessions/{}/collections/Pat{}_S{}_{}".format( self.api_url, quote_plus(project), quote_plus(experiment), quote_plus(subject), quote_plus(scan), quote_plus(aiaa_model_name), ) self._request_put(url, data, type=type)
""" def main(): from monai.transforms import LoadImage logging.basicConfig( level=logging.INFO, format="[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) # Create alias token for user through instead of using direct username and password # http://127.0.0.1/data/services/tokens/issue ds = XNATDatastore( api_url="http://127.0.0.1", username="admin", # "a8a73c8d-a0bd-44d1-87af-244476072af4", password="admin", # "wGdawXhqo9Fhsh5p1pd6nGRloF99mxXYvvBGjCtTl1A9zYkk4mlaQJuvJQUcXL62", asset_path="/localhome/sachi/Projects/xnat-docker-compose/xnat-data/archive", project="Test", ) image_ids = ds.list_images() logger.info("\n" + "\n".join(image_ids)) image_id = "Test/XNAT01_S00003/XNAT01_E00004/1_2_826_0_1_3680043_8_274_1_1_8323329_10631_1656479315_17615" image_uri = ds.get_image_uri(image_id) logger.info(f"+++ Image URI: {image_uri}") if image_uri: loader = LoadImage(image_only=True) image_np = loader(image_uri) logger.info(f"Image Shape: {image_np.shape}") if __name__ == "__main__": main() """