# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import time
from collections.abc import Mapping, MutableMapping
from typing import Any, cast
import torch
import torch.distributed as dist
from monai.apps.auto3dseg.data_analyzer import DataAnalyzer
from monai.apps.utils import get_logger
from monai.auto3dseg import SegSummarizer
from monai.bundle import BundleWorkflow, ConfigComponent, ConfigItem, ConfigParser, ConfigWorkflow
from monai.engines import SupervisedEvaluator, SupervisedTrainer, Trainer
from monai.fl.client import ClientAlgo, ClientAlgoStats
from monai.fl.utils.constants import ExtraItems, FiltersType, FlPhase, FlStatistics, ModelType, WeightType
from monai.fl.utils.exchange_object import ExchangeObject
from monai.networks.utils import copy_model_state, get_state_dict
from monai.utils import min_version, require_pkg
from monai.utils.enums import DataStatsKeys
logger = get_logger(__name__)
def convert_global_weights(global_weights: Mapping, local_var_dict: MutableMapping) -> tuple[MutableMapping, int]:
    """Cast and reshape received global weights into the local model's parameter format.

    Incoming tensors may arrive flattened/reshaped (e.g., to support HE for secure
    aggregation), so each matching entry is converted with ``torch.as_tensor`` and
    reshaped to the corresponding local parameter's shape.

    Args:
        global_weights: mapping of parameter name to received weight values.
        local_var_dict: local state dict; matching entries are overwritten in place.

    Returns:
        tuple of (the updated ``local_var_dict``, number of converted variables).

    Raises:
        ValueError: if a received weight cannot be converted/reshaped.
    """
    n_converted = 0
    for var_name, local_tensor in local_var_dict.items():
        if var_name not in global_weights:
            continue
        try:
            # reshape to the local parameter shape so weight differences can be computed later
            local_var_dict[var_name] = torch.reshape(torch.as_tensor(global_weights[var_name]), local_tensor.shape)
            n_converted += 1
        except Exception as e:
            raise ValueError(f"Convert weight from {var_name} failed.") from e
    return local_var_dict, n_converted
def compute_weight_diff(global_weights, local_var_dict):
    """Return the per-parameter difference ``local - global``, with all results on the CPU.

    The global model defines the primary key set; keys absent from the local dict
    are skipped.

    Raises:
        ValueError: if either input is None, or if any difference contains NaN.
        RuntimeError: if no common keys were found (nothing to diff).
    """
    if global_weights is None:
        raise ValueError("Cannot compute weight differences if `global_weights` is None!")
    if local_var_dict is None:
        raise ValueError("Cannot compute weight differences if `local_var_dict` is None!")
    weight_diff = {}
    for name, global_tensor in global_weights.items():
        if name not in local_var_dict:
            continue
        # returned weight diff will be on the cpu
        diff = local_var_dict[name].cpu() - global_tensor.cpu()
        if torch.any(torch.isnan(diff)):
            raise ValueError(f"Weights for {name} became NaN...")
        weight_diff[name] = diff
    if not weight_diff:
        raise RuntimeError("No weight differences computed!")
    return weight_diff
def disable_ckpt_loaders(parser: ConfigParser) -> None:
    """Mark every ``CheckpointLoader`` handler in the ``validate`` section as disabled."""
    if "validate#handlers" not in parser:
        return
    for handler in parser["validate#handlers"]:
        # only instantiable components carry a "_target_" entry to inspect
        if ConfigComponent.is_instantiable(handler) and "CheckpointLoader" in handler["_target_"]:
            handler["_disabled_"] = True
class MonaiAlgoStats(ClientAlgoStats):
"""
Implementation of ``ClientAlgoStats`` to allow federated learning with MONAI bundle configurations.
Args:
bundle_root: directory path of the bundle.
config_train_filename: bundle training config path relative to bundle_root. Can be a list of files;
defaults to "configs/train.json". only useful when `workflow` is None.
config_filters_filename: filter configuration file. Can be a list of files; defaults to `None`.
data_stats_transform_list: transforms to apply for the data stats result.
histogram_only: whether to only compute histograms. Defaults to False.
workflow: the bundle workflow to execute, usually it's training, evaluation or inference.
if None, will create an `ConfigWorkflow` internally based on `config_train_filename`.
"""
def __init__(
self,
bundle_root: str,
config_train_filename: str | list | None = "configs/train.json",
config_filters_filename: str | list | None = None,
data_stats_transform_list: list | None = None,
histogram_only: bool = False,
workflow: BundleWorkflow | None = None,
):
self.logger = logger
self.bundle_root = bundle_root
self.config_train_filename = config_train_filename
self.config_filters_filename = config_filters_filename
self.train_data_key = "train"
self.eval_data_key = "eval"
self.data_stats_transform_list = data_stats_transform_list
self.histogram_only = histogram_only
self.workflow = None
if workflow is not None:
if not isinstance(workflow, BundleWorkflow):
raise ValueError("workflow must be a subclass of BundleWorkflow.")
if workflow.get_workflow_type() is None:
raise ValueError("workflow doesn't specify the type.")
self.workflow = workflow
self.client_name: str | None = None
self.app_root: str = ""
self.post_statistics_filters: Any = None
self.phase = FlPhase.IDLE
self.dataset_root: Any = None
def initialize(self, extra=None):
"""
Initialize routine to parse configuration files and extract main components such as trainer, evaluator, and filters.
Args:
extra: Dict with additional information that should be provided by FL system,
i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`.
"""
if extra is None:
extra = {}
self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname")
self.logger.info(f"Initializing {self.client_name} ...")
# FL platform needs to provide filepath to configuration files
self.app_root = extra.get(ExtraItems.APP_ROOT, "")
self.bundle_root = os.path.join(self.app_root, self.bundle_root)
if self.workflow is None:
config_train_files = self._add_config_files(self.config_train_filename)
self.workflow = ConfigWorkflow(
config_file=config_train_files, meta_file=None, logging_file=None, workflow_type="train"
)
self.workflow.initialize()
self.workflow.bundle_root = self.bundle_root
# initialize the workflow as the content changed
self.workflow.initialize()
config_filter_files = self._add_config_files(self.config_filters_filename)
filter_parser = ConfigParser()
if len(config_filter_files) > 0:
filter_parser.read_config(config_filter_files)
# Get filters
self.post_statistics_filters = filter_parser.get_parsed_content(
FiltersType.POST_STATISTICS_FILTERS, default=ConfigItem(None, FiltersType.POST_STATISTICS_FILTERS)
)
self.logger.info(f"Initialized {self.client_name}.")
    def get_data_stats(self, extra: dict | None = None) -> ExchangeObject:
        """
        Returns summary statistics about the local data.

        Args:
            extra: Dict with additional information that can be provided by the FL system.
                Both FlStatistics.HIST_BINS and FlStatistics.HIST_RANGE must be provided.

        Returns:
            stats: ExchangeObject with summary statistics.

        Raises:
            ValueError: if `extra` is None, if the histogram settings are missing from `extra`,
                or if the workflow has no dataset directory set.
        """
        if extra is None:
            raise ValueError("`extra` has to be set")
        if self.workflow.dataset_dir:  # type: ignore
            self.phase = FlPhase.GET_DATA_STATS
            self.logger.info(f"Computing statistics on {self.workflow.dataset_dir}")  # type: ignore
            # histogram configuration must be supplied by the FL system
            if FlStatistics.HIST_BINS not in extra:
                raise ValueError("FlStatistics.NUM_OF_BINS not specified in `extra`")
            else:
                hist_bins = extra[FlStatistics.HIST_BINS]
            if FlStatistics.HIST_RANGE not in extra:
                raise ValueError("FlStatistics.HIST_RANGE not specified in `extra`")
            else:
                hist_range = extra[FlStatistics.HIST_RANGE]
            stats_dict = {}
            # train data stats
            train_summary_stats, train_case_stats = self._get_data_key_stats(
                data=self.workflow.train_dataset_data,  # type: ignore
                data_key=self.train_data_key,
                hist_bins=hist_bins,
                hist_range=hist_range,
                output_path=os.path.join(self.app_root, "train_data_stats.yaml"),
            )
            if train_case_stats:
                # Only return summary statistics to FL server
                stats_dict.update({self.train_data_key: train_summary_stats})
            # eval data stats (the validation section of the datalist is optional)
            eval_summary_stats = None
            eval_case_stats = None
            if self.workflow.val_dataset_data is not None:  # type: ignore
                eval_summary_stats, eval_case_stats = self._get_data_key_stats(
                    data=self.workflow.val_dataset_data,  # type: ignore
                    data_key=self.eval_data_key,
                    hist_bins=hist_bins,
                    hist_range=hist_range,
                    output_path=os.path.join(self.app_root, "eval_data_stats.yaml"),
                )
            else:
                self.logger.warning("the datalist doesn't contain validation section.")
            if eval_summary_stats:
                # Only return summary statistics to FL server
                stats_dict.update({self.eval_data_key: eval_summary_stats})
            # total stats (only when both train and eval case stats are available)
            if train_case_stats and eval_case_stats:
                # Compute total summary
                total_summary_stats = self._compute_total_stats(
                    [train_case_stats, eval_case_stats], hist_bins, hist_range
                )
                stats_dict.update({FlStatistics.TOTAL_DATA: total_summary_stats})
            # optional filter of data stats
            stats = ExchangeObject(statistics=stats_dict)
            if self.post_statistics_filters is not None:
                for _filter in self.post_statistics_filters:
                    stats = _filter(stats, extra)
            return stats
        else:
            raise ValueError("data_root not set!")
    def _get_data_key_stats(self, data, data_key, hist_bins, hist_range, output_path=None):
        """Run ``DataAnalyzer`` on one section of the datalist.

        Args:
            data: list of data entries for this section.
            data_key: name of the section (e.g., "train" or "eval").
            hist_bins: histogram bins setting forwarded to the analyzer.
            hist_range: histogram range setting forwarded to the analyzer.
            output_path: optional file path where the analyzer writes its stats report.

        Returns:
            tuple of (summary statistics dict for the FL server, per-case statistics).
        """
        analyzer = DataAnalyzer(
            datalist={data_key: data},
            dataroot=self.workflow.dataset_dir,  # type: ignore
            hist_bins=hist_bins,
            hist_range=hist_range,
            output_path=output_path,
            histogram_only=self.histogram_only,
        )
        self.logger.info(f"{self.client_name} compute data statistics on {data_key}...")
        all_stats = analyzer.get_all_case_stats(transform_list=self.data_stats_transform_list, key=data_key)
        case_stats = all_stats[DataStatsKeys.BY_CASE]
        summary_stats = {
            FlStatistics.DATA_STATS: all_stats[DataStatsKeys.SUMMARY],
            FlStatistics.DATA_COUNT: len(data),
            # entries missing from the per-case results are counted as failures
            FlStatistics.FAIL_COUNT: len(data) - len(case_stats),
            # TODO: add shapes, voxels sizes, etc.
        }
        return summary_stats, case_stats
@staticmethod
def _compute_total_stats(case_stats_lists, hist_bins, hist_range):
# Compute total summary
total_case_stats = []
for case_stats_list in case_stats_lists:
total_case_stats += case_stats_list
summarizer = SegSummarizer(
"image", "label", average=True, do_ccp=True, hist_bins=hist_bins, hist_range=hist_range
)
total_summary_stats = summarizer.summarize(total_case_stats)
summary_stats = {
FlStatistics.DATA_STATS: total_summary_stats,
FlStatistics.DATA_COUNT: len(total_case_stats),
FlStatistics.FAIL_COUNT: 0,
}
return summary_stats
def _add_config_files(self, config_files):
files = []
if config_files:
if isinstance(config_files, str):
files.append(os.path.join(self.bundle_root, config_files))
elif isinstance(config_files, list):
for file in config_files:
if isinstance(file, str):
files.append(os.path.join(self.bundle_root, file))
else:
raise ValueError(f"Expected config file to be of type str but got {type(file)}: {file}")
else:
raise ValueError(
f"Expected config files to be of type str or list but got {type(config_files)}: {config_files}"
)
return files
@require_pkg(pkg_name="ignite", version="0.4.10", version_checker=min_version)
class MonaiAlgo(ClientAlgo, MonaiAlgoStats):
"""
Implementation of ``ClientAlgo`` to allow federated learning with MONAI bundle configurations.
Args:
bundle_root: directory path of the bundle.
local_epochs: number of local epochs to execute during each round of local training; defaults to 1.
send_weight_diff: whether to send weight differences rather than full weights; defaults to `True`.
config_train_filename: bundle training config path relative to bundle_root. can be a list of files.
defaults to "configs/train.json". only useful when `train_workflow` is None.
train_kwargs: other args of the `ConfigWorkflow` of train, except for `config_file`, `meta_file`,
`logging_file`, `workflow_type`. only useful when `train_workflow` is None.
config_evaluate_filename: bundle evaluation config path relative to bundle_root. can be a list of files.
if "default", ["configs/train.json", "configs/evaluate.json"] will be used.
this arg is only useful when `eval_workflow` is None.
eval_kwargs: other args of the `ConfigWorkflow` of evaluation, except for `config_file`, `meta_file`,
`logging_file`, `workflow_type`. only useful when `eval_workflow` is None.
config_filters_filename: filter configuration file. Can be a list of files; defaults to `None`.
disable_ckpt_loading: do not use any CheckpointLoader if defined in train/evaluate configs; defaults to `True`.
best_model_filepath: location of best model checkpoint; defaults "models/model.pt" relative to `bundle_root`.
final_model_filepath: location of final model checkpoint; defaults "models/model_final.pt" relative to `bundle_root`.
save_dict_key: If a model checkpoint contains several state dicts,
the one defined by `save_dict_key` will be returned by `get_weights`; defaults to "model".
If all state dicts should be returned, set `save_dict_key` to None.
data_stats_transform_list: transforms to apply for the data stats result.
eval_workflow_name: the workflow name corresponding to the "config_evaluate_filename", default to "train"
as the default "config_evaluate_filename" overrides the train workflow config.
this arg is only useful when `eval_workflow` is None.
train_workflow: the bundle workflow to execute training, if None, will create a `ConfigWorkflow` internally
based on `config_train_filename` and `train_kwargs`.
eval_workflow: the bundle workflow to execute evaluation, if None, will create a `ConfigWorkflow` internally
based on `config_evaluate_filename`, `eval_kwargs`, `eval_workflow_name`.
"""
def __init__(
self,
bundle_root: str,
local_epochs: int = 1,
send_weight_diff: bool = True,
config_train_filename: str | list | None = "configs/train.json",
train_kwargs: dict | None = None,
config_evaluate_filename: str | list | None = "default",
eval_kwargs: dict | None = None,
config_filters_filename: str | list | None = None,
disable_ckpt_loading: bool = True,
best_model_filepath: str | None = "models/model.pt",
final_model_filepath: str | None = "models/model_final.pt",
save_dict_key: str | None = "model",
data_stats_transform_list: list | None = None,
eval_workflow_name: str = "train",
train_workflow: BundleWorkflow | None = None,
eval_workflow: BundleWorkflow | None = None,
):
self.logger = logger
self.bundle_root = bundle_root
self.local_epochs = local_epochs
self.send_weight_diff = send_weight_diff
self.config_train_filename = config_train_filename
self.train_kwargs = {} if train_kwargs is None else train_kwargs
if config_evaluate_filename == "default":
# by default, evaluator needs both training and evaluate to be instantiated
config_evaluate_filename = ["configs/train.json", "configs/evaluate.json"]
self.config_evaluate_filename = config_evaluate_filename
self.eval_kwargs = {} if eval_kwargs is None else eval_kwargs
self.config_filters_filename = config_filters_filename
self.disable_ckpt_loading = disable_ckpt_loading
self.model_filepaths = {ModelType.BEST_MODEL: best_model_filepath, ModelType.FINAL_MODEL: final_model_filepath}
self.save_dict_key = save_dict_key
self.data_stats_transform_list = data_stats_transform_list
self.eval_workflow_name = eval_workflow_name
self.train_workflow = None
self.eval_workflow = None
if train_workflow is not None:
if not isinstance(train_workflow, BundleWorkflow) or train_workflow.get_workflow_type() != "train":
raise ValueError(
f"train workflow must be BundleWorkflow and set type in {BundleWorkflow.supported_train_type}."
)
self.train_workflow = train_workflow
if eval_workflow is not None:
# evaluation workflow can be "train" type or "infer" type
if not isinstance(eval_workflow, BundleWorkflow) or eval_workflow.get_workflow_type() is None:
raise ValueError("train workflow must be BundleWorkflow and set type.")
self.eval_workflow = eval_workflow
self.stats_sender = None
self.app_root = ""
self.filter_parser: ConfigParser | None = None
self.trainer: SupervisedTrainer | None = None
self.evaluator: SupervisedEvaluator | None = None
self.pre_filters = None
self.post_weight_filters = None
self.post_evaluate_filters = None
self.iter_of_start_time = 0
self.global_weights: Mapping | None = None
self.phase = FlPhase.IDLE
self.client_name = None
self.dataset_root = None
    def initialize(self, extra=None):
        """
        Initialize routine to parse configuration files and extract main components such as trainer, evaluator, and filters.

        Args:
            extra: Dict with additional information that should be provided by FL system,
                i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`.
        """
        self._set_cuda_device()
        if extra is None:
            extra = {}
        self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname")
        # used to build a unique experiment-tracking run name for this initialization
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        self.logger.info(f"Initializing {self.client_name} ...")
        # FL platform needs to provide filepath to configuration files
        self.app_root = extra.get(ExtraItems.APP_ROOT, "")
        self.bundle_root = os.path.join(self.app_root, self.bundle_root)
        if self.train_workflow is None and self.config_train_filename is not None:
            config_train_files = self._add_config_files(self.config_train_filename)
            # if enabled experiment tracking, set the run name to the FL client name and timestamp,
            # expect the tracking settings use "run_name" to define the run name
            if "run_name" not in self.train_kwargs:
                self.train_kwargs["run_name"] = f"{self.client_name}_{timestamp}"
            self.train_workflow = ConfigWorkflow(
                config_file=config_train_files,
                meta_file=None,
                logging_file=None,
                workflow_type="train",
                **self.train_kwargs,
            )
        if self.train_workflow is not None:
            self.train_workflow.initialize()
            self.train_workflow.bundle_root = self.bundle_root
            # cap the number of epochs per FL round
            self.train_workflow.max_epochs = self.local_epochs
            if self.disable_ckpt_loading and isinstance(self.train_workflow, ConfigWorkflow):
                disable_ckpt_loaders(parser=self.train_workflow.parser)
            # initialize the workflow as the content changed
            self.train_workflow.initialize()
            self.trainer = self.train_workflow.trainer
            if not isinstance(self.trainer, SupervisedTrainer):
                raise ValueError(f"trainer must be SupervisedTrainer, but got: {type(self.trainer)}.")
        if self.eval_workflow is None and self.config_evaluate_filename is not None:
            config_eval_files = self._add_config_files(self.config_evaluate_filename)
            # if enabled experiment tracking, set the run name to the FL client name and timestamp,
            # expect the tracking settings use "run_name" to define the run name
            if "run_name" not in self.eval_kwargs:
                self.eval_kwargs["run_name"] = f"{self.client_name}_{timestamp}"
            self.eval_workflow = ConfigWorkflow(
                config_file=config_eval_files,
                meta_file=None,
                logging_file=None,
                workflow_type=self.eval_workflow_name,
                **self.eval_kwargs,
            )
        if self.eval_workflow is not None:
            self.eval_workflow.initialize()
            self.eval_workflow.bundle_root = self.bundle_root
            if self.disable_ckpt_loading and isinstance(self.eval_workflow, ConfigWorkflow):
                disable_ckpt_loaders(parser=self.eval_workflow.parser)
            # initialize the workflow as the content changed
            self.eval_workflow.initialize()
            self.evaluator = self.eval_workflow.evaluator
            if not isinstance(self.evaluator, SupervisedEvaluator):
                raise ValueError(f"evaluator must be SupervisedEvaluator, but got: {type(self.evaluator)}.")
        config_filter_files = self._add_config_files(self.config_filters_filename)
        self.filter_parser = ConfigParser()
        if len(config_filter_files) > 0:
            self.filter_parser.read_config(config_filter_files)
        # set stats sender for nvflare
        self.stats_sender = extra.get(ExtraItems.STATS_SENDER, self.stats_sender)
        if self.stats_sender is not None:
            self.stats_sender.attach(self.trainer)
            self.stats_sender.attach(self.evaluator)
        # Get filters
        self.pre_filters = self.filter_parser.get_parsed_content(
            FiltersType.PRE_FILTERS, default=ConfigItem(None, FiltersType.PRE_FILTERS)
        )
        self.post_weight_filters = self.filter_parser.get_parsed_content(
            FiltersType.POST_WEIGHT_FILTERS, default=ConfigItem(None, FiltersType.POST_WEIGHT_FILTERS)
        )
        self.post_evaluate_filters = self.filter_parser.get_parsed_content(
            FiltersType.POST_EVALUATE_FILTERS, default=ConfigItem(None, FiltersType.POST_EVALUATE_FILTERS)
        )
        self.post_statistics_filters = self.filter_parser.get_parsed_content(
            FiltersType.POST_STATISTICS_FILTERS, default=ConfigItem(None, FiltersType.POST_STATISTICS_FILTERS)
        )
        self.logger.info(f"Initialized {self.client_name}.")
    def train(self, data: ExchangeObject, extra: dict | None = None) -> None:
        """
        Train on client's local data.

        Args:
            data: `ExchangeObject` containing the current global model weights.
            extra: Dict with additional information that can be provided by the FL system.

        Raises:
            ValueError: if `data` is not an `ExchangeObject` or no trainer is available.
        """
        self._set_cuda_device()
        if extra is None:
            extra = {}
        if not isinstance(data, ExchangeObject):
            raise ValueError(f"expected data to be ExchangeObject but received {type(data)}")
        if self.trainer is None:
            raise ValueError("self.trainer should not be None.")
        # apply optional pre-processing filters to the incoming exchange object
        if self.pre_filters is not None:
            for _filter in self.pre_filters:
                data = _filter(data, extra)
        self.phase = FlPhase.TRAIN
        self.logger.info(f"Load {self.client_name} weights...")
        local_var_dict = get_state_dict(self.trainer.network)
        # keep the converted global weights so `get_weights` can compute weight differences later
        self.global_weights, n_converted = convert_global_weights(
            global_weights=cast(dict, data.weights), local_var_dict=local_var_dict
        )
        self._check_converted(data.weights, local_var_dict, n_converted)
        # set engine state max epochs.
        self.trainer.state.max_epochs = self.trainer.state.epoch + self.local_epochs
        # get current iteration when a round starts
        self.iter_of_start_time = self.trainer.state.iteration
        _, updated_keys, _ = copy_model_state(src=cast(Mapping, self.global_weights), dst=self.trainer.network)
        if len(updated_keys) == 0:
            self.logger.warning("No weights loaded!")
        self.logger.info(f"Start {self.client_name} training...")
        self.trainer.run()
    def get_weights(self, extra=None):
        """
        Returns the current weights of the model.

        Args:
            extra: Dict with additional information that can be provided by the FL system.
                If `ExtraItems.MODEL_TYPE` is present, the requested checkpoint is loaded
                from disk instead of the in-memory trainer weights.

        Returns:
            return_weights: `ExchangeObject` containing current weights (default)
                or load requested model type from disk (`ModelType.BEST_MODEL` or `ModelType.FINAL_MODEL`).
        """
        self._set_cuda_device()
        if extra is None:
            extra = {}
        # by default return current weights, return best if requested via model type.
        self.phase = FlPhase.GET_WEIGHTS
        if ExtraItems.MODEL_TYPE in extra:
            # checkpoint path requested: load the state dict from disk
            model_type = extra.get(ExtraItems.MODEL_TYPE)
            if not isinstance(model_type, ModelType):
                raise ValueError(
                    f"Expected requested model type to be of type `ModelType` but received {type(model_type)}"
                )
            if model_type in self.model_filepaths:
                model_path = os.path.join(self.bundle_root, cast(str, self.model_filepaths[model_type]))
                if not os.path.isfile(model_path):
                    raise ValueError(f"No best model checkpoint exists at {model_path}")
                # NOTE(review): `torch.load` unpickles arbitrary code; the checkpoint is
                # expected to be produced locally by the bundle workflow — confirm it is trusted.
                weights = torch.load(model_path, map_location="cpu")
                # if weights contain several state dicts, use the one defined by `save_dict_key`
                if isinstance(weights, dict) and self.save_dict_key in weights:
                    weights = weights.get(self.save_dict_key)
                weigh_type: WeightType | None = WeightType.WEIGHTS
                stats: dict = {}
                self.logger.info(f"Returning {model_type} checkpoint weights from {model_path}.")
            else:
                raise ValueError(
                    f"Requested model type {model_type} not specified in `model_filepaths`: {self.model_filepaths}"
                )
        else:
            if self.trainer:
                weights = get_state_dict(self.trainer.network)
                # returned weights will be on the cpu
                for k in weights.keys():
                    weights[k] = weights[k].cpu()
                weigh_type = WeightType.WEIGHTS
                stats = self.trainer.get_stats()
                # calculate current iteration and epoch data after training.
                stats[FlStatistics.NUM_EXECUTED_ITERATIONS] = self.trainer.state.iteration - self.iter_of_start_time
                # compute weight differences
                if self.send_weight_diff:
                    weights = compute_weight_diff(global_weights=self.global_weights, local_var_dict=weights)
                    weigh_type = WeightType.WEIGHT_DIFF
                    self.logger.info("Returning current weight differences.")
                else:
                    self.logger.info("Returning current weights.")
            else:
                # no trainer available (e.g., evaluation-only client): nothing to return
                weights = None
                weigh_type = None
                stats = dict()
        if not isinstance(stats, dict):
            raise ValueError(f"stats is not a dict, {stats}")
        return_weights = ExchangeObject(
            weights=weights,
            optim=None,  # could be self.optimizer.state_dict()
            weight_type=weigh_type,
            statistics=stats,
        )
        # filter weights if needed (use to apply differential privacy, encryption, compression, etc.)
        if self.post_weight_filters is not None:
            for _filter in self.post_weight_filters:
                return_weights = _filter(return_weights, extra)
        return return_weights
    def evaluate(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeObject:
        """
        Evaluate on client's local data.

        Args:
            data: `ExchangeObject` containing the current global model weights.
            extra: Dict with additional information that can be provided by the FL system.

        Returns:
            return_metrics: `ExchangeObject` containing evaluation metrics.

        Raises:
            ValueError: if `data` is not an `ExchangeObject` or no evaluator is available.
        """
        self._set_cuda_device()
        if extra is None:
            extra = {}
        if not isinstance(data, ExchangeObject):
            raise ValueError(f"expected data to be ExchangeObject but received {type(data)}")
        if self.evaluator is None:
            raise ValueError("self.evaluator should not be None.")
        # apply optional pre-processing filters to the incoming exchange object
        if self.pre_filters is not None:
            for _filter in self.pre_filters:
                data = _filter(data, extra)
        self.phase = FlPhase.EVALUATE
        self.logger.info(f"Load {self.client_name} weights...")
        local_var_dict = get_state_dict(self.evaluator.network)
        global_weights, n_converted = convert_global_weights(
            global_weights=cast(dict, data.weights), local_var_dict=local_var_dict
        )
        self._check_converted(data.weights, local_var_dict, n_converted)
        _, updated_keys, _ = copy_model_state(src=global_weights, dst=self.evaluator.network)
        if len(updated_keys) == 0:
            self.logger.warning("No weights loaded!")
        self.logger.info(f"Start {self.client_name} evaluating...")
        # when a trainer exists, tag the evaluation run with the next epoch number
        if isinstance(self.trainer, Trainer):
            self.evaluator.run(self.trainer.state.epoch + 1)
        else:
            self.evaluator.run()
        return_metrics = ExchangeObject(metrics=self.evaluator.state.metrics)
        # apply optional post-evaluate filters to the returned metrics
        if self.post_evaluate_filters is not None:
            for _filter in self.post_evaluate_filters:
                return_metrics = _filter(return_metrics, extra)
        return return_metrics
def abort(self, extra=None):
"""
Abort the training or evaluation.
Args:
extra: Dict with additional information that can be provided by the FL system.
"""
self.logger.info(f"Aborting {self.client_name} during {self.phase} phase.")
if isinstance(self.trainer, Trainer):
self.logger.info(f"Aborting {self.client_name} trainer...")
self.trainer.interrupt()
if isinstance(self.evaluator, Trainer):
self.logger.info(f"Aborting {self.client_name} evaluator...")
self.evaluator.interrupt()
def finalize(self, extra: dict | None = None) -> None:
"""
Finalize the training or evaluation.
Args:
extra: Dict with additional information that can be provided by the FL system.
"""
self.logger.info(f"Terminating {self.client_name} during {self.phase} phase.")
if isinstance(self.trainer, Trainer):
self.logger.info(f"Terminating {self.client_name} trainer...")
self.trainer.terminate()
if isinstance(self.evaluator, Trainer):
self.logger.info(f"Terminating {self.client_name} evaluator...")
self.evaluator.terminate()
if self.train_workflow is not None:
self.train_workflow.finalize()
if self.eval_workflow is not None:
self.eval_workflow.finalize()
def _check_converted(self, global_weights, local_var_dict, n_converted):
if n_converted == 0:
raise RuntimeError(
f"No global weights converted! Received weight dict keys are {list(global_weights.keys())}"
)
else:
self.logger.info(
f"Converted {n_converted} global variables to match {len(local_var_dict)} local variables."
)
    def _set_cuda_device(self):
        # in a torch.distributed run, pin this process to its local GPU so that all
        # subsequent CUDA work in the workflows lands on the correct device
        if dist.is_initialized():
            # assumes the launcher (e.g., torchrun) exports LOCAL_RANK — TODO confirm
            self.rank = int(os.environ["LOCAL_RANK"])
            torch.cuda.set_device(self.rank)