Source code for monailabel.tasks.train.utils

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from monai.config import KeysCollection
from monai.handlers import MeanDice, RootMeanSquaredError, from_engine
from monai.utils import ensure_tuple


def region_wise_metrics(regions, metric, prefix, keys=("pred", "label")):
    """Build a dictionary of MeanDice handlers: one overall plus one per region.

    Args:
        regions: Either a dict mapping region name to index, or an iterable of
            region names. Names from a non-dict iterable are numbered from 1 in
            iteration order. A falsy value yields only the overall handler.
        metric: Dictionary key under which the overall MeanDice is stored.
        prefix: Prefix of every per-region key, which has the form
            ``"{prefix}_{name}_mean_dice"``.
        keys: Keys forwarded to ``from_engine`` / ``from_engine_idx`` to pull
            the values out of the engine output.

    Returns:
        A dict mapping metric name to a MeanDice handler.
    """
    # The overall score is always present, regardless of regions.
    handlers = {metric: MeanDice(output_transform=from_engine(keys), include_background=False)}
    if not regions:
        return handlers

    if isinstance(regions, dict):
        name_to_idx = regions
    else:
        # Plain sequences of names are numbered starting at 1.
        name_to_idx = {region: position for position, region in enumerate(regions, start=1)}

    for region, position in name_to_idx.items():
        region_key = f"{prefix}_{region}_mean_dice"
        # Each region gets its own handler fed only with the slice at `position`.
        handlers[region_key] = MeanDice(
            output_transform=from_engine_idx(keys, position),
            include_background=False,
        )
    return handlers
def from_engine_idx(keys: KeysCollection, idx, first: bool = False):
    """Create an output transform that extracts one index from a pair of values.

    The returned callable looks up ``keys`` in the engine output, unpacks the
    result into two collections (prediction and label), and for every item in
    each collection takes ``item[idx, ...]`` and re-adds a leading axis with
    ``[None]``.

    Args:
        keys: Key or keys to read from the engine output; normalised with
            ``ensure_tuple``. The lookup result is unpacked into exactly two
            values, so two keys are expected.
        idx: Index applied on the first axis of every item.
        first: Only used when the engine output is a list of dicts. If true,
            values are read from the first dict only; otherwise they are
            gathered from every dict in the list.

    Returns:
        A callable taking the engine output and returning a
        ``(predictions, labels)`` tuple of lists. It returns ``None`` when the
        output is neither a dict nor a list whose first element is a dict.
    """
    key_names = ensure_tuple(keys)

    def _select(items):
        # NOTE(review): assumes items support tensor-style indexing - confirm against callers.
        return [item[idx, ...][None] for item in items]

    def _split(pair):
        predictions, labels = pair
        return _select(predictions), _select(labels)

    def _wrapper(data):
        if isinstance(data, dict):
            return _split(tuple(data[k] for k in key_names))

        if isinstance(data, list) and isinstance(data[0], dict):
            if first:
                gathered = [data[0][k] for k in key_names]
            else:
                gathered = [[entry[k] for entry in data] for k in key_names]
            # With a single key the lone gathered value itself is unpacked as the pair.
            return _split(tuple(gathered) if len(gathered) > 1 else gathered[0])

        # Any other input shape falls through without a result.
        return None

    return _wrapper
# Regression counterpart of region_wise_metrics: uses RootMeanSquaredError instead of MeanDice.
def region_wise_rmse(regions, metric, prefix, keys=("pred", "label")):
    """Build a dictionary of RootMeanSquaredError handlers: one overall plus one per region.

    Args:
        regions: Either a dict mapping region name to index, or an iterable of
            region names. Names from a non-dict iterable are numbered from 1 in
            iteration order. A falsy value yields only the overall handler.
        metric: Dictionary key under which the overall RMSE is stored.
        prefix: Prefix of every per-region key, which has the form
            ``"{prefix}_{name}_rmse"``.
        keys: Keys forwarded to ``from_engine`` / ``from_engine_idx`` to pull
            the values out of the engine output.

    Returns:
        A dict mapping metric name to a RootMeanSquaredError handler.
    """
    # The overall error is always present, regardless of regions.
    handlers = {metric: RootMeanSquaredError(output_transform=from_engine(keys))}
    if not regions:
        return handlers

    if isinstance(regions, dict):
        name_to_idx = regions
    else:
        # Plain sequences of names are numbered starting at 1.
        name_to_idx = {region: position for position, region in enumerate(regions, start=1)}

    for region, position in name_to_idx.items():
        region_key = f"{prefix}_{region}_rmse"
        # Per-region handlers pass the "mean" reduction explicitly.
        handlers[region_key] = RootMeanSquaredError(
            reduction="mean",
            output_transform=from_engine_idx(keys, position),
        )
    return handlers