Source code for monailabel.datastore.dicom

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import pathlib
import shutil
from typing import Any, Dict, Iterator, List, Optional, Tuple

import requests
from cachetools import TTLCache, cached
from dicomweb_client import DICOMwebClient
from pydicom.dataset import Dataset

from monailabel.config import settings
from monailabel.datastore.local import LocalDatastore
from monailabel.datastore.utils.convert import binary_to_image, dicom_to_nifti, nifti_to_dicom_seg
from monailabel.datastore.utils.dicom import dicom_web_download_series, dicom_web_upload_dcm
from monailabel.interfaces.datastore import DefaultLabelTag
from monailabel.utils.others.generic import md5_digest

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class DICOMwebClientX(DICOMwebClient):
    """DICOMweb client that tolerates responses not declared as ``multipart/related``."""

    def _decode_multipart_message(self, response: requests.Response, stream: bool) -> Iterator[bytes]:
        """Decode a response body through the parent class's multipart decoder.

        When the media type declared in the ``content-type`` header is anything
        other than ``multipart/related`` (compared case-insensitively), the
        header is overwritten with the bare value ``multipart/related`` before
        delegating. Any parameters that followed the original media type are
        discarded by that overwrite.

        Args:
            response: HTTP response whose ``content-type`` header is inspected
                and possibly rewritten in place.
            stream: Passed through unchanged to the parent implementation.

        Returns:
            Whatever the parent ``_decode_multipart_message`` returns.
        """
        headers = response.headers
        # Only the media type (the text before the first ';') decides the rewrite.
        declared_type = headers["content-type"].split(";")[0].strip().lower()
        is_multipart = declared_type == "multipart/related"
        if not is_multipart:
            headers["content-type"] = "multipart/related"
        return super()._decode_multipart_message(response, stream)  # type: ignore
class DICOMWebDatastore(LocalDatastore):
    """Datastore that reads series from a DICOMweb server and caches them locally.

    Series are looked up through the supplied DICOMweb client, downloaded on
    demand into a cache directory, and optionally converted to NIfTI. Local
    bookkeeping (image/label registration, datalist) is delegated to
    ``LocalDatastore``.
    """

    def __init__(
        self,
        client: DICOMwebClient,
        search_filter: Dict[str, Any],
        cache_path: Optional[str] = None,
        fetch_by_frame: bool = False,
        convert_to_nifti: bool = True,
    ):
        """Set up the cache directory and initialise the local datastore.

        Args:
            client: DICOMweb client used for every remote query/upload.
            search_filter: Filters passed to ``search_for_series`` by ``list_images``.
            cache_path: Parent directory for the cache. When falsy, the cache
                lives under ``~/.cache/monailabel/dicom``.
            fetch_by_frame: Forwarded to ``dicom_web_download_series``.
            convert_to_nifti: When false, ``get_image_uri``/``get_label_uri``
                return the downloaded DICOM directory instead of a NIfTI file.
        """
        self._client = client
        self._search_filter = search_filter
        self._fetch_by_frame = fetch_by_frame
        self._convert_to_nifti = convert_to_nifti

        # One cache sub-directory per server, named by the digest of its base URL.
        uri_hash = md5_digest(self._client.base_url)
        datastore_path = (
            os.path.join(cache_path, uri_hash)
            if cache_path
            else os.path.join(pathlib.Path.home(), ".cache", "monailabel", "dicom", uri_hash)
        )
        logger.info(f"DICOMWeb Datastore (cache) Path: {datastore_path}; FetchByFrame: {fetch_by_frame}")
        logger.info(f"DICOMWeb Convert To Nifti: {convert_to_nifti}")
        super().__init__(datastore_path=datastore_path, auto_reload=True)

    def name(self) -> str:
        """Return the base URL of the DICOMweb client as the datastore name."""
        base_url: str = self._client.base_url
        return base_url

    def _to_id(self, file: str) -> Tuple[str, str]:
        """Split ``file`` into ``(id, extension)``.

        Names ending in ``.nii``, ``.nii.gz`` or ``.nrrd`` are split at that
        suffix; anything else is handled by the parent implementation.
        """
        extensions = [".nii", ".nii.gz", ".nrrd"]
        for extension in extensions:
            if file.endswith(extension):
                # Strip only the trailing suffix; str.replace would also remove
                # the same text if it occurred earlier in the name.
                return file[: -len(extension)], extension
        return super()._to_id(file)

    def get_image_uri(self, image_id: str) -> str:
        """Return the local path for an image series, downloading it if needed.

        The series is downloaded with ``dicom_web_download_series`` when its
        cache directory is missing or empty. With NIfTI conversion disabled the
        directory itself is returned. Otherwise the path of ``<image_id>.nii.gz``
        is returned; when that file does not exist yet it is produced by
        ``dicom_to_nifti`` and registered through ``LocalDatastore.add_image``.
        """
        logger.info(f"Image ID: {image_id}")
        image_dir = os.path.realpath(os.path.join(self._datastore.image_path(), image_id))
        logger.info(f"Image Dir (cache): {image_dir}")

        if not os.path.exists(image_dir) or not os.listdir(image_dir):
            dicom_web_download_series(None, image_id, image_dir, self._client, self._fetch_by_frame)

        if not self._convert_to_nifti:
            return image_dir

        image_nii_gz = os.path.realpath(os.path.join(self._datastore.image_path(), f"{image_id}.nii.gz"))
        if not os.path.exists(image_nii_gz):
            image_nii_gz = dicom_to_nifti(image_dir)
            super().add_image(image_id, image_nii_gz, self._dicom_info(image_id))

        return image_nii_gz

    def get_label_uri(self, label_id: str, label_tag: str, image_id: str = "") -> str:
        """Return the local path for a label, downloading it if needed.

        Tags other than ``DefaultLabelTag.FINAL`` are served by the parent
        class. For final labels the series ``label_id`` is downloaded when its
        cache directory is missing or empty. With NIfTI conversion disabled the
        directory is returned. Otherwise the label is stored as
        ``<image_id>.nii.gz`` under the final-label path: when that file is
        missing it is produced by ``dicom_to_nifti(..., is_seg=True)`` and, if
        that call returned a truthy path, saved through
        ``LocalDatastore.save_label``.
        """
        if label_tag != DefaultLabelTag.FINAL:
            return super().get_label_uri(label_id, label_tag)

        logger.info(f"Label ID: {label_id} => {label_tag}")
        label_dir = os.path.realpath(os.path.join(self._datastore.label_path(label_tag), label_id))
        logger.info(f"Label Dir (cache): {label_dir}")

        if not os.path.exists(label_dir) or not os.listdir(label_dir):
            dicom_web_download_series(None, label_id, label_dir, self._client, self._fetch_by_frame)

        if not self._convert_to_nifti:
            return label_dir

        label_nii_gz = os.path.realpath(
            os.path.join(self._datastore.label_path(DefaultLabelTag.FINAL), f"{image_id}.nii.gz")
        )
        if not os.path.exists(label_nii_gz):
            label_nii_gz = dicom_to_nifti(label_dir, is_seg=True)
            if label_nii_gz:
                # Parent save_label is used on purpose: this class's override
                # would upload the label back to the server.
                super().save_label(image_id, label_nii_gz, label_tag, self._dicom_info(label_id))

        return label_nii_gz

    def _dicom_info(self, series_id):
        """Collect selected attributes of a series from the server.

        Uses the first result of ``search_for_series`` for ``series_id``.
        Attributes that are absent or falsy are reported as ``"UNK"``.
        """
        meta = Dataset.from_json(self._client.search_for_series(search_filters={"SeriesInstanceUID": series_id})[0])
        fields = ["StudyDate", "StudyTime", "Modality", "RetrieveURL", "PatientID", "StudyInstanceUID"]

        info = {"SeriesInstanceUID": series_id}
        for f in fields:
            info[f] = str(meta[f].value) if meta.get(f) else "UNK"
        return info

    @cached(cache=TTLCache(maxsize=16, ttl=settings.MONAI_LABEL_DICOMWEB_CACHE_EXPIRY))
    def list_images(self) -> List[str]:
        """Return the SeriesInstanceUIDs matching the configured search filter.

        Results are cached for ``settings.MONAI_LABEL_DICOMWEB_CACHE_EXPIRY``.
        """
        datasets = self._client.search_for_series(search_filters=self._search_filter)
        series = [str(Dataset.from_json(ds)["SeriesInstanceUID"].value) for ds in datasets]
        logger.debug("Total Series: {}\n{}".format(len(series), "\n".join(series)))
        return series

    def _remote_seg_references(self) -> List[Dict[str, str]]:
        """Resolve every remote SEG series to the image series it references.

        Returns:
            One ``{"image": <image series UID>, "label": <SEG series UID>}``
            entry per SEG series whose first ``ReferencedSeriesSequence`` item
            names a series present in ``list_images()``. SEG series without
            that sequence, or referencing a series outside the filtered list,
            are skipped with a warning.
        """
        datasets = self._client.search_for_series(search_filters={"Modality": "SEG"})
        all_segs = [Dataset.from_json(ds) for ds in datasets]

        # Fetched once: list_images() may hit the server again when its cache
        # entry has expired, and set membership is O(1).
        known_images = set(self.list_images())

        references: List[Dict[str, str]] = []
        for seg in all_segs:
            study_uid = str(seg["StudyInstanceUID"].value)
            seg_uid = str(seg["SeriesInstanceUID"].value)
            meta = self._client.retrieve_series_metadata(study_uid, seg_uid)
            seg_meta = Dataset.from_json(meta[0])

            if not seg_meta.get("ReferencedSeriesSequence"):
                logger.warning("Label Ignored:: ReferencedSeriesSequence is NOT found: {}".format(seg_uid))
                continue

            image_uid = str(seg_meta["ReferencedSeriesSequence"].value[0]["SeriesInstanceUID"].value)
            if image_uid not in known_images:
                logger.warning(
                    "Label Ignored:: ReferencedSeriesSequence is NOT in filtered image list: {}".format(seg_uid)
                )
                continue

            references.append({"image": image_uid, "label": seg_uid})
        return references

    # The result does not depend on label_tag/labels, so the cache is keyed on
    # the instance alone. The default key would hash every argument and raise
    # TypeError when ``labels`` is passed as a list.
    @cached(
        cache=TTLCache(maxsize=16, ttl=settings.MONAI_LABEL_DICOMWEB_CACHE_EXPIRY),
        key=lambda self, *_args, **_kwargs: self,
    )
    def get_labeled_images(self, label_tag: Optional[str] = None, labels: Optional[List[str]] = None) -> List[str]:
        """Return image series UIDs that have a SEG series on the server.

        ``label_tag`` and ``labels`` are accepted for interface compatibility
        and are not used. An image referenced by several SEG series appears
        once per SEG. Results are cached for
        ``settings.MONAI_LABEL_DICOMWEB_CACHE_EXPIRY``.
        """
        return [reference["image"] for reference in self._remote_seg_references()]

    def get_unlabeled_images(self, label_tag: Optional[str] = None, labels: Optional[List[str]] = None) -> List[str]:
        """Return image series UIDs from ``list_images`` that have no SEG series.

        ``label_tag`` and ``labels`` are accepted for interface compatibility
        and are not used. The result order is unspecified (set difference).
        """
        series = self.list_images()
        seg_series = self.get_labeled_images()
        logger.info("Total Series (with seg): {}\n{}".format(len(seg_series), "\n".join(seg_series)))
        return list(set(series) - set(seg_series))

    def save_label(
        self, image_id: str, label_filename: str, label_tag: str, label_info: Dict[str, Any], label_id: str = ""
    ) -> str:
        """Save a label locally and, for final labels, upload it as DICOM-SEG.

        A ``.bin`` label is first converted with ``binary_to_image``. For
        ``DefaultLabelTag.FINAL`` the label is converted with
        ``nifti_to_dicom_seg``, uploaded with ``dicom_web_upload_dcm``, and
        ``label_info`` is updated in place with the new series UID plus the
        image's Modality, PatientID and StudyInstanceUID. The label is then
        stored through ``LocalDatastore.save_label``.

        Temporary files created here are removed even when a later step raises.

        Args:
            image_id: Image series UID the label belongs to.
            label_filename: Path of the label file to save.
            label_tag: Tag of the label; only the final tag triggers an upload.
            label_info: Metadata for the label; mutated for final labels.
            label_id: Ignored on input; overwritten with the parent's result.

        Returns:
            The label id returned by ``LocalDatastore.save_label``.
        """
        logger.info(f"Input - Image Id: {image_id}")
        logger.info(f"Input - Label File: {label_filename}")
        logger.info(f"Input - Label Tag: {label_tag}")
        logger.info(f"Input - Label Info: {label_info}")

        image_uri = self.get_image_uri(image_id)
        label_ext = "".join(pathlib.Path(label_filename).suffixes)

        output_file = ""
        try:
            if label_ext == ".bin":
                output_file = binary_to_image(image_uri, label_filename)
                label_filename = output_file

            logger.info(f"Label File: {label_filename}")

            # Support DICOM-SEG uploading only final version
            if label_tag == DefaultLabelTag.FINAL:
                image_dir = os.path.realpath(os.path.join(self._datastore.image_path(), image_id))
                label_file = nifti_to_dicom_seg(image_dir, label_filename, label_info.get("label_info"))
                try:
                    label_series_id = dicom_web_upload_dcm(label_file, self._client)

                    image_info = self.get_image_info(image_id)
                    label_info.update(
                        {
                            "SeriesInstanceUID": label_series_id,
                            "Modality": image_info.get("Modality"),
                            "PatientID": image_info.get("PatientID"),
                            "StudyInstanceUID": image_info.get("StudyInstanceUID"),
                        }
                    )
                finally:
                    # Remove the generated DICOM-SEG file whether or not the upload succeeded.
                    if label_file and os.path.exists(label_file):
                        os.unlink(label_file)

            label_id = super().save_label(image_id, label_filename, label_tag, label_info)
            logger.info("Save completed!")
        finally:
            # Remove the intermediate image produced from a .bin label, if any.
            if output_file and os.path.exists(output_file):
                os.unlink(output_file)

        return label_id

    def _download_labeled_data(self):
        """Synchronise the local cache with the SEG series on the server.

        Local final labels whose image no longer has a SEG on the server are
        deleted, then every remaining image/label pair is fetched through
        ``get_image_uri`` and ``get_label_uri``.
        """
        image_labels = self._remote_seg_references()

        invalid = set(super().get_labeled_images()) - {image_label["image"] for image_label in image_labels}
        logger.info(f"Invalid Labels: {invalid}")
        for e in invalid:
            logger.info(f"Label {e} not exist on remote; Remove from local")
            label_uri = super().get_label_uri(e, DefaultLabelTag.FINAL)
            if label_uri and os.path.exists(label_uri):
                # Remove the sibling directory named after the image id as well as the label file.
                shutil.rmtree(os.path.join(os.path.dirname(label_uri), e), ignore_errors=True)
                os.unlink(label_uri)

        for image_label in image_labels:
            self.get_image_uri(image_id=image_label["image"])
            self.get_label_uri(
                label_id=image_label["label"], label_tag=DefaultLabelTag.FINAL, image_id=image_label["image"]
            )

    def datalist(self, full_path=True) -> List[Dict[str, Any]]:
        """Refresh labeled data from the server, then return the parent's datalist."""
        self._download_labeled_data()
        return super().datalist(full_path)