# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for accessing Nvidia MMARs
See Also:
- https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html
"""
from __future__ import annotations
import json
import os
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import Any
import torch
import monai.networks.nets as monai_nets
from monai.apps.utils import download_and_extract, logger
from monai.config.type_definitions import PathLike
from monai.networks.utils import copy_model_state
from monai.utils.module import optional_import
from .model_desc import MODEL_DESC
from .model_desc import RemoteMMARKeys as Keys
# Public API of this module (the names exported by a wildcard import); underscore-prefixed helpers are private.
__all__ = ["get_model_spec", "download_mmar", "load_from_mmar"]
def get_model_spec(idx: int | str) -> dict | Any:
    """
    Look up a model specification in the constant ``MODEL_DESC`` tuple of dicts.

    Args:
        idx: either a positional index into ``MODEL_DESC`` or a model ID string. ID matching
            ignores surrounding whitespace and letter case.

    Returns:
        The matching specification dict. When ``idx`` is a string matching no ``Keys.ID`` entry,
        or is neither an ``int`` nor a ``str``, ``idx`` itself is returned unchanged.
    """
    if isinstance(idx, int):
        return MODEL_DESC[idx]
    if not isinstance(idx, str):
        return idx
    wanted_id = idx.strip().lower()
    # first spec whose normalised ID equals the normalised query, falling back to the query itself
    matches = (spec for spec in MODEL_DESC if str(spec.get(Keys.ID)).strip().lower() == wanted_id)
    return next(matches, idx)
def _get_all_ngc_models(pattern, page_index=0, page_size=50):
    """
    Search the NGC model catalog for models whose name contains ``pattern``.

    Args:
        pattern: substring matched against the NGC model ``name`` field (queried as ``*pattern*``).
        page_index: index of the result page to request.
        page_size: number of results per page.

    Returns:
        A dict keyed by NGC ``resourceId``; each value is ``{"name": <model name>}`` plus a
        ``"latest"`` entry holding the ``latestVersionIdStr`` attribute when the API reports one.

    Raises:
        ValueError: if the optional ``requests`` package is not installed.
        HTTP error statuses are re-raised by ``resp.raise_for_status()``.
    """
    url = "https://api.ngc.nvidia.com/v2/search/catalog/resources/MODEL"
    query_dict = {
        "query": "",
        "orderBy": [{"field": "score", "value": "DESC"}],
        "queryFields": ["all", "description", "displayName", "name", "resourceId"],
        "fields": [
            "isPublic",
            "attributes",
            "guestAccess",
            "name",
            "orgName",
            "teamName",
            "displayName",
            "dateModified",
            "labels",
            "description",
        ],
        "page": page_index,
        "pageSize": page_size,
        "filters": [{"field": "name", "value": f"*{pattern}*"}],
    }
    requests_get, has_requests = optional_import("requests", name="get")
    if not has_requests:
        raise ValueError("NGC API requires requests package. Please install it.")
    # Let `requests` percent-encode the JSON query: splicing it into the URL verbatim lets
    # characters such as '&', '#', '+' or '%' in `pattern` truncate or corrupt the query string.
    resp = requests_get(url, params={"q": json.dumps(query_dict)})
    resp.raise_for_status()
    model_list = json.loads(resp.text)

    model_dict = {}
    # tolerate responses that omit any of the optional collections instead of raising KeyError
    for result in model_list.get("results", []):
        for model in result.get("resources", []):
            entry = {"name": model["name"]}
            for attribute in model.get("attributes", []):
                if attribute.get("key") == "latestVersionIdStr":
                    entry["latest"] = attribute["value"]
            model_dict[model["resourceId"]] = entry
    return model_dict
def _get_ngc_url(model_name: str, version: str, model_prefix: str = "") -> str:
return f"https://api.ngc.nvidia.com/v2/models/{model_prefix}{model_name}/versions/{version}/zip"
def _get_ngc_doc_url(model_name: str, model_prefix: str = "") -> str:
return f"https://ngc.nvidia.com/catalog/models/{model_prefix}{model_name}"
def download_mmar(
    item: str | Mapping, mmar_dir: PathLike | None = None, progress: bool = True, api: bool = True, version: int = -1
) -> Path:
    """
    Download and extract Medical Model Archive (MMAR) from Nvidia Clara Train.

    See Also:
        - https://docs.nvidia.com/clara/
        - Nvidia NGC Registry CLI
        - https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html

    Args:
        item: the corresponding model item from `MODEL_DESC`, or a model ID accepted by `get_model_spec`.
            When `api` is True, the substring used to query NGC's model name field
            (for a mapping, its `Keys.NAME` entry is used).
        mmar_dir: target directory to store the MMAR, default is `mmars` subfolder under `torch.hub get_dir()`.
        progress: whether to display a progress bar.
        api: whether to query NGC and download via api.
        version: which version of MMAR to download. With `api=True`, -1 means the latest version on NGC
            and any other value is used as given. With `api=False`, values <= 0 mean `item[Keys.VERSION]`
            (defaulting to 1).

    Examples::

        >>> from monai.apps import download_mmar
        >>> download_mmar("clara_pt_prostate_mri_segmentation_1", mmar_dir=".")
        >>> download_mmar("prostate_mri_segmentation", mmar_dir=".", api=True)

    Returns:
        The local directory of the downloaded model. When `api` is True every model matching the
        query is downloaded, and the directory of the first one is returned.

    Raises:
        ValueError: when no default `mmar_dir` can be computed, when the NGC query matches nothing
            downloadable, or (with `api=False`) when `item` matches no entry of `MODEL_DESC`.
    """
    if not mmar_dir:
        get_dir, has_home = optional_import("torch.hub", name="get_dir")
        if has_home:
            mmar_dir = Path(get_dir()) / "mmars"
        else:
            raise ValueError("mmar_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?")
    _mmar_dir = Path(mmar_dir)
    model_dir: Path
    if api:
        # query NGC with the model name of a mapping, or with the string form of `item`
        pattern = item.get(Keys.NAME, f"{item}") if isinstance(item, Mapping) else f"{item}"
        model_dict = _get_all_ngc_models(pattern)
        if len(model_dict) == 0:
            raise ValueError(f"api query returns no item for pattern {item}. Please change or shorten it.")
        model_dir_list: list[Path] = []
        for k, v in model_dict.items():
            if version == -1 and "latest" not in v:
                # NGC reported no `latestVersionIdStr` for this match, so there is no version to fetch
                logger.warning(f"*** Skipping NGC model {v['name']}: no latest version reported, please set `version`.")
                continue
            ver = v["latest"] if version == -1 else str(version)
            download_url = _get_ngc_url(k, ver)
            model_dir = _mmar_dir / v["name"]
            download_and_extract(
                url=download_url,
                filepath=_mmar_dir / f'{v["name"]}_{ver}.zip',
                output_dir=model_dir,
                hash_val=None,
                hash_type="md5",
                file_type="zip",
                has_base=False,
                progress=progress,
            )
            model_dir_list.append(model_dir)
        if not model_dir_list:
            raise ValueError(f"api query download no item for pattern {item}. Please change or shorten it.")
        return model_dir_list[0]

    if not isinstance(item, Mapping):
        item = get_model_spec(item)
        # `get_model_spec` hands back the query itself when no entry matches it
        if not isinstance(item, Mapping):
            raise ValueError(f"Unknown MMAR model {item!r}: it matches no entry of MODEL_DESC.")
    ver = item.get(Keys.VERSION, 1)
    if version > 0:
        ver = str(version)
    model_fullname = f"{item[Keys.NAME]}_{ver}"
    model_dir = _mmar_dir / model_fullname
    model_url = item.get(Keys.URL) or _get_ngc_url(item[Keys.NAME], version=ver, model_prefix="nvidia/med/")
    download_and_extract(
        url=model_url,
        filepath=_mmar_dir / f"{model_fullname}.{item[Keys.FILE_TYPE]}",
        output_dir=model_dir,
        hash_val=item[Keys.HASH_VAL],
        hash_type=item[Keys.HASH_TYPE],
        file_type=item[Keys.FILE_TYPE],
        has_base=False,
        progress=progress,
    )
    return model_dir
def load_from_mmar(
    item: Mapping | str | int,
    mmar_dir: PathLike | None = None,
    progress: bool = True,
    version: int = -1,
    map_location: Any | None = None,
    pretrained: bool = True,
    weights_only: bool = False,
    model_key: str = "model",
    api: bool = True,
    model_file: PathLike | None = None,
) -> Any:
    """
    Download and extract Medical Model Archive (MMAR) model weights from Nvidia Clara Train.

    Args:
        item: the corresponding model item from `MODEL_DESC`, or an index / model ID accepted by
            `get_model_spec`. When `api` is True, only its model name is used to query NGC.
        mmar_dir: target directory to store the MMAR, default is mmars subfolder under `torch.hub get_dir()`.
        progress: whether to display a progress bar when downloading the content.
        version: version number of the MMAR. With `api=True`, `-1` selects the latest version on NGC;
            otherwise `-1` uses `item[Keys.VERSION]`.
        map_location: pytorch API parameter for `torch.load` or `torch.jit.load`.
        pretrained: whether to load the pretrained weights after initializing a network module.
        weights_only: whether to load only the weights instead of initializing the network module and assign weights.
        model_key: a key to search in the model file or config file for the model dictionary.
            Currently this function assumes that the model dictionary has
            `{"[name|path]": "test.module", "args": {'kw': 'test'}}`.
        api: whether to query NGC API to get model information.
        model_file: the relative path to the model file within an MMAR (default `models/model.pt`);
            `item[Keys.MODEL_FILE]` takes precedence when present.

    Returns:
        The `torch.jit.load` result for `.ts` model files; the loaded object's `model_key` entry
        (or the whole loaded object) when `weights_only` is True; otherwise an instantiated network,
        with the loaded weights copied in when `pretrained` is True.

    Raises:
        ValueError: if no model config can be found, or the configured model class cannot be imported.

    Examples::

        >>> from monai.apps import load_from_mmar
        >>> unet_model = load_from_mmar("clara_pt_prostate_mri_segmentation_1", mmar_dir=".", map_location="cpu")
        >>> print(unet_model)

    See Also:
        https://docs.nvidia.com/clara/
    """
    if api:
        # Only a model name is needed to query NGC. Mappings contribute their `Keys.NAME` entry
        # (stringifying the whole mapping would produce a meaningless query), ints are resolved
        # through `MODEL_DESC`, and anything else is used as the search string.
        if isinstance(item, Mapping):
            query_name = item.get(Keys.NAME, f"{item}")
        elif isinstance(item, int):
            query_name = get_model_spec(item)[Keys.NAME]
        else:
            query_name = f"{item}"
        item = {Keys.NAME: query_name}
    if not isinstance(item, Mapping):
        item = get_model_spec(item)
    model_dir = download_mmar(item=item, mmar_dir=mmar_dir, progress=progress, version=version, api=api)
    if model_file is None:
        model_file = os.path.join("models", "model.pt")
    _model_file = model_dir / item.get(Keys.MODEL_FILE, model_file)
    logger.info(f'\n*** "{item.get(Keys.NAME)}" available at {model_dir}.')

    # loading with `torch.jit.load`
    if _model_file.name.endswith(".ts"):
        if not pretrained:
            warnings.warn("Loading a ScriptModule, 'pretrained' option ignored.")
        if weights_only:
            warnings.warn("Loading a ScriptModule, 'weights_only' option ignored.")
        return torch.jit.load(_model_file, map_location=map_location)

    # loading with `torch.load`
    # NOTE(review): `torch.load` unpickles the downloaded file; only load MMARs from trusted sources.
    model_dict = torch.load(_model_file, map_location=map_location)
    if weights_only:
        return model_dict.get(model_key, model_dict)  # model_dict[model_key] or model_dict directly

    # 1. search `model_dict["train_conf"]` for model config spec.
    model_config = _get_val(dict(model_dict).get("train_conf", {}), key=model_key, default={})
    config_file = model_dir / item.get(Keys.CONFIG_FILE, os.path.join("config", "config_train.json"))
    if not model_config or not isinstance(model_config, Mapping):
        # 2. search json CONFIG_FILE for model config spec. A missing file is skipped rather than
        # raising FileNotFoundError, so that step 3 still gets a chance.
        model_config = {}
        if config_file.is_file():
            with open(config_file, encoding="utf-8") as f:
                conf_dict = dict(json.load(f))
            model_config = _get_val(conf_dict, key=model_key, default={})
    if not model_config:
        # 3. search `model_dict` for model config spec.
        model_config = _get_val(dict(model_dict), key=model_key, default={})
    if not (model_config and isinstance(model_config, Mapping)):
        raise ValueError(
            f"Could not load model config dictionary from config: {config_file}, "
            f"or from model file: {_model_file}."
        )

    # parse `model_config` for model class and model parameters
    if model_config.get("name"):  # model config section is a "name"
        model_name = model_config["name"]
        model_cls = monai_nets.__dict__[model_name]
    elif model_config.get("path"):  # model config section is a "path"
        # https://docs.nvidia.com/clara/clara-train-sdk/pt/byom.html
        model_module, model_name = model_config.get("path", ".").rsplit(".", 1)
        model_cls, has_cls = optional_import(module=model_module, name=model_name)
        if not has_cls:
            raise ValueError(
                f"Could not load MMAR model config {model_config.get('path', '')}, "
                f"Please make sure MMAR's sub-folders in '{model_dir}' is on the PYTHONPATH. "
                "See also: https://docs.nvidia.com/clara/clara-train-sdk/pt/byom.html"
            )
    else:
        raise ValueError(f"Could not load model config {model_config}.")

    logger.info(f"*** Model: {model_cls}")
    model_kwargs = model_config.get("args", None)
    if model_kwargs:
        model_inst = model_cls(**model_kwargs)
        logger.info(f"*** Model params: {model_kwargs}")
    else:
        model_inst = model_cls()
    if pretrained:
        _, changed, unchanged = copy_model_state(model_inst, model_dict.get(model_key, model_dict), inplace=True)
        if not (changed and not unchanged):  # not all model_inst variables are changed
            logger.warning(f"*** Loading model state -- unchanged: {len(unchanged)}, changed: {len(changed)}.")
    logger.info("\n---")
    doc_url = item.get(Keys.DOC) or _get_ngc_doc_url(item[Keys.NAME], model_prefix="nvidia:med:")
    logger.info(f"For more information, please visit {doc_url}\n")
    return model_inst
def _get_val(input_dict: Mapping, key: str = "model", default: Any | None = None) -> Any | None:
"""
Search for the item with `key` in `config_dict`.
Returns: the first occurrence of `key` in a breadth first search.
"""
if key in input_dict:
return input_dict[key]
for sub_dict in input_dict:
val = input_dict[sub_dict]
if isinstance(val, Mapping):
found_val = _get_val(val, key=key, default=None)
if found_val is not None:
return found_val
return default