# Source code for monailabel.transform.cache

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import os
import pathlib
from typing import Hashable, Sequence, Tuple, Union

import torch
from expiring_dict import ExpiringDict
from monai.config import KeysCollection
from monai.data import MetaTensor
from monai.transforms import Transform
from monai.utils import ensure_tuple

from monailabel.utils.others.generic import md5_digest
from monailabel.utils.sessions import Sessions

logger = logging.getLogger(__name__)

_cache_path = None
_data_mem_cache = None
_data_file_cache = None


def init_cache():
    """Create the module-level cache stores on first use, then purge expired file entries.

    On the first call (while ``_cache_path`` is still unset) this sets
    ``_cache_path`` to ``<home>/.cache/monailabel/cacheT`` and builds the two
    backing stores: an ``ExpiringDict(ttl=600)`` for in-memory entries and a
    ``Sessions`` store rooted at that path with ``expiry=600`` for on-disk
    entries. Subsequent calls keep the existing stores.
    """
    global _cache_path, _data_mem_cache, _data_file_cache

    if not _cache_path:
        cache_root = os.path.join(pathlib.Path.home(), ".cache", "monailabel", "cacheT")
        _cache_path = cache_root
        _data_mem_cache = ExpiringDict(ttl=600)
        _data_file_cache = Sessions(store_path=cache_root, expiry=600)

    # NOTE(review): statement nesting was reconstructed from whitespace-collapsed
    # source; confirm the purge is meant to run on every call rather than only
    # inside the first-call branch above.
    _data_file_cache.remove_expired()
class CacheTransformDatad(Transform):
    """Dictionary transform that stores selected entries of ``data`` in a cache.

    Calling the transform saves the configured ``keys`` (or the whole dictionary
    when ``keys`` is empty) under a name derived from the MD5 digest of the
    concatenated ``hash_key`` values. ``load`` performs the reverse lookup and
    returns ``None`` on a cache miss.

    Args:
        keys: entries of the data dictionary to cache; an empty collection
            caches the full dictionary as a single entry.
        hash_key: one key or a sequence of keys whose (string) values identify
            the cached item.
        in_memory: when True entries go to the module-level in-memory store,
            otherwise they are written with ``torch.save`` under ``_cache_path``.
        ttl: expiry value passed to the backing store for every saved entry.
        reset_applied_operations_id: when True, ``load`` overwrites the ``id`` of
            every applied operation on loaded ``MetaTensor`` values with ``"none"``.
    """

    def __init__(
        self,
        keys: KeysCollection,
        hash_key: Union[str, Sequence[str]] = ("image_path", "model"),
        in_memory: bool = True,
        ttl: int = 600,
        reset_applied_operations_id: bool = True,
    ):
        self.keys: Tuple[Hashable, ...] = ensure_tuple(keys)
        self.hash_key = [hash_key] if isinstance(hash_key, str) else hash_key
        self.in_memory = in_memory
        self.ttl = ttl
        self.reset_applied_operations_id = reset_applied_operations_id

        # remove previous expired...
        init_cache()

    def __call__(self, data):
        """Cache the configured entries of ``data``; equivalent to ``save(data)``."""
        return self.save(data)

    def _hash_values(self, d):
        """Return the non-empty values present in ``d`` for the configured hash keys."""
        return [d[k] for k in self.hash_key if d.get(k)]

    def load(self, data):
        """Fetch previously cached entries for ``data``.

        Returns:
            With no ``keys`` configured, whatever was cached for the full
            dictionary (``None`` if nothing was). Otherwise a shallow copy of
            ``data`` with every configured key replaced by its cached value, or
            ``None`` as soon as one key is not found in the cache. ``None`` is
            also returned when any hash key is missing or empty in ``data``,
            because ``save`` never stores an entry in that situation.
        """
        d = dict(data)

        hash_values = self._hash_values(d)
        if len(hash_values) != len(self.hash_key):
            # Mirror save(): incomplete hash keys mean nothing can have been
            # cached, so report a miss instead of raising KeyError.
            logger.info(
                f"Ignore; Missing hash keys to load from Cache; Found: {hash_values}; Expected: {self.hash_key}"
            )
            return None
        hash_key_prefix = md5_digest("".join(hash_values))

        # full dictionary
        if not self.keys:
            return self._load(f"{hash_key_prefix}")

        # set of keys
        for key in self.keys:
            d[key] = self._load(f"{hash_key_prefix}_{key}")
            if d[key] is None:
                logger.info(f"Ignore; Failed to load {key} from Cache; memory:{self.in_memory}")
                return None

            # For Invert Transform (reset id)
            # NOTE(review): for in-memory hits the value comes back from the
            # store without a copy here, so this edit (and any later in-place
            # change by callers) presumably lands on the cached object itself.
            if self.reset_applied_operations_id and isinstance(d[key], MetaTensor):
                for o in d[key].applied_operations:
                    o["id"] = "none"
        return d

    def save(self, data):
        """Store the configured entries of ``data`` in the cache.

        Caching is skipped, with a warning, when any hash key is missing or
        empty in ``data``.

        Returns:
            A shallow copy of ``data``, whether or not anything was cached.
        """
        d = dict(data)

        hash_keys = self._hash_values(d)
        if len(hash_keys) != len(self.hash_key):
            logger.warning(f"Ignore caching; Missing hash keys; Found: {hash_keys}; Expected: {self.hash_key}")
            return d

        # Digest is computed only once the hash keys are known to be complete.
        hash_key_prefix = md5_digest("".join(hash_keys))

        # full dictionary
        if not self.keys:
            self._save(f"{hash_key_prefix}", d)
        else:
            for key in self.keys:
                self._save(f"{hash_key_prefix}_{key}", d[key])
        return d

    def _load(self, hash_key):
        """Return the object cached under ``hash_key``, or ``None`` when the file entry is absent."""
        if self.in_memory:
            return _data_mem_cache.get(hash_key)

        info = _data_file_cache.get_session(session_id=hash_key)
        if info and os.path.isfile(info.image):
            # SECURITY: torch.load unpickles the file. The path is expected to be
            # one written by _save() under the local cache directory; do not
            # point the cache at a location other users can write to.
            return torch.load(info.image)
        return None

    def _save(self, hash_key, obj):
        """Store ``obj`` under ``hash_key`` in the configured backend with ``self.ttl`` expiry."""
        if self.in_memory:
            # Deep copy so later in-place changes to ``obj`` do not alter the cached value.
            _data_mem_cache.ttl(key=hash_key, value=copy.deepcopy(obj), ttl=self.ttl)
        else:
            os.makedirs(_cache_path, exist_ok=True)
            cached_file = os.path.join(_cache_path, f"{hash_key}.tmp")
            torch.save(obj, cached_file)
            _data_file_cache.add_session(cached_file, expiry=self.ttl, session_id=hash_key)