Federated Learning
Client Base Classes
- class monai.fl.client.BaseClient
Provide an abstract base class to allow the client to return summary statistics of the data.
To define a new stats script, subclass this class and implement the following abstract methods:
- self.get_data_stats()
initialize(), abort(), and finalize(), inherited from ClientAlgoStats, can optionally be implemented to help with lifecycle management of the class object.
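The lifecycle hooks are easiest to see in a small example. The sketch below is a hypothetical, MONAI-free stand-in (it does not subclass BaseClient) that assumes a common pattern: abort() only sets a flag, and a long-running loop checks that flag and stops early, while initialize() and finalize() bracket the object's lifetime. The class name LifecycleClient and its run_steps() method are illustrative only.

```python
from typing import Optional


class LifecycleClient:
    """Hypothetical stand-in showing initialize/abort/finalize; it does not import MONAI."""

    def __init__(self) -> None:
        self.state = "created"
        self.abort_requested = False

    def initialize(self, extra: Optional[dict] = None) -> None:
        # Acquire resources (data loaders, networks, ...) before the first round.
        self.state = "initialized"

    def abort(self, extra: Optional[dict] = None) -> None:
        # Only set a flag; the work loop below checks it and exits early.
        self.abort_requested = True

    def finalize(self, extra: Optional[dict] = None) -> None:
        # Release resources after the last round.
        self.state = "finalized"

    def run_steps(self, n_steps: int) -> int:
        """Run up to n_steps units of work, stopping early once abort() was called."""
        done = 0
        for _ in range(n_steps):
            if self.abort_requested:
                break
            done += 1
        return done


client = LifecycleClient()
client.initialize()
print(client.run_steps(5))  # 5
client.abort()
print(client.run_steps(5))  # 0
client.finalize()
print(client.state)  # finalized
```

Keeping abort() cheap (a flag rather than tearing down state) lets the FL system call it from another thread while training is in progress.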
- abort(extra=None)
Abort the ClientAlgo training or evaluation.
- Parameters
extra (Optional[dict]) – Dict with additional information that can be provided by the FL system.
- class monai.fl.client.ClientAlgo
Provide an abstract base class for defining an algorithm to run on any platform. To define a new algorithm script, subclass this class and implement the following abstract methods:
- self.train()
- self.get_weights()
- self.evaluate()
- self.get_data_stats() (optional, inherited from ClientAlgoStats)
initialize(), abort(), and finalize(), inherited from ClientAlgoStats, can optionally be implemented to help with lifecycle management of the class object.
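A minimal sketch of the subclassing pattern for the three required methods. To stay runnable without MONAI or PyTorch, it makes two loud assumptions: a hypothetical ExchangeObjectStub dataclass stands in for ExchangeObject, and a single scalar weight "w" fitted by gradient descent on mean squared error stands in for a network. ToyAlgo does not actually inherit from ClientAlgo; it only mirrors the method names and signatures listed above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExchangeObjectStub:
    """Hypothetical, simplified stand-in for ExchangeObject."""
    weights: Optional[dict] = None
    metrics: Optional[dict] = None
    weight_type: Optional[str] = None


class ToyAlgo:
    """Hypothetical ClientAlgo-style class: one scalar weight 'w' fitted to local targets."""

    def __init__(self, train_targets, test_targets, lr=0.25, local_steps=4):
        self.train_targets = train_targets
        self.test_targets = test_targets
        self.lr = lr
        self.local_steps = local_steps
        self.weights = {"w": 0.0}

    def train(self, data, extra=None):
        # Start from the weights sent by the FL system, then update locally.
        self.weights = dict(data.weights)
        mean = sum(self.train_targets) / len(self.train_targets)
        for _ in range(self.local_steps):
            grad = 2 * (self.weights["w"] - mean)  # derivative of the mean squared error
            self.weights["w"] -= self.lr * grad
        # Like ClientAlgo.train, this returns None; results are collected via get_weights().

    def get_weights(self, extra=None):
        return ExchangeObjectStub(weights=dict(self.weights), weight_type="full")

    def evaluate(self, data, extra=None):
        # Evaluate the weights passed in, not necessarily the locally trained ones.
        w = data.weights["w"]
        mse = sum((w - t) ** 2 for t in self.test_targets) / len(self.test_targets)
        return ExchangeObjectStub(metrics={"mse": mse})


algo = ToyAlgo(train_targets=[1.0, 3.0], test_targets=[2.0, 4.0])
algo.train(ExchangeObjectStub(weights={"w": 0.0}))
print(algo.get_weights().weights)  # {'w': 1.875}
print(algo.evaluate(ExchangeObjectStub(weights={"w": 2.0})).metrics)  # {'mse': 2.0}
```

The split between train() (returns None) and get_weights() means the FL system decides when to pull results, for example only after training finished without an abort.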
- evaluate(data, extra=None)
Get evaluation metrics on test data.
- Parameters
data (ExchangeObject) – ExchangeObject with network weights to use for evaluation.
extra (Optional[dict]) – Dict with additional information that can be provided by the FL system.
- Returns
metrics – ExchangeObject with evaluation metrics.
- Return type
ExchangeObject
- get_weights(extra=None)
Get current local weights or weight differences.
- Parameters
extra (Optional[dict]) – Dict with additional information that can be provided by the FL system.
- Returns
Current local weights or weight differences.
- Return type
ExchangeObject
ExchangeObject example:
ExchangeObject(
    weights=self.trainer.network.state_dict(),
    optim=None,  # could be self.optimizer.state_dict()
    weight_type=WeightType.WEIGHTS,
)
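When a client sends weight differences instead of full weights, the payload is the element-wise difference between the locally trained parameters and the global parameters the round started from; the receiver adds it back. The sketch below assumes state dicts are represented as plain lists of floats per parameter name (with real PyTorch state dicts the subtraction would be tensor-wise), and assumes both dicts share the same keys and shapes. The function names weight_diff and apply_diff are hypothetical. In MONAI the corresponding weight_type would presumably be WeightType.WEIGHT_DIFF rather than WeightType.WEIGHTS.

```python
def weight_diff(local: dict, global_: dict) -> dict:
    """Element-wise local - global; assumes identical keys and shapes in both dicts."""
    return {k: [a - b for a, b in zip(local[k], global_[k])] for k in local}


def apply_diff(global_: dict, diff: dict) -> dict:
    """Recover the local weights on the receiving side: global + diff."""
    return {k: [a + b for a, b in zip(global_[k], diff[k])] for k in global_}


global_weights = {"fc.weight": [0.5, -1.0], "fc.bias": [0.0]}
local_weights = {"fc.weight": [0.75, -0.5], "fc.bias": [0.25]}

diff = weight_diff(local_weights, global_weights)
print(diff)  # {'fc.weight': [0.25, 0.5], 'fc.bias': [0.25]}
print(apply_diff(global_weights, diff) == local_weights)  # True
```

Sending differences lets the aggregator combine client updates directly, without needing each client's starting point.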
- train(data, extra=None)
Train the network on the local training data and produce an updated network.
- Parameters
data (ExchangeObject) – ExchangeObject containing the current network weights to base training on.
extra (Optional[dict]) – Dict with additional information that can be provided by the FL system.
- Return type
None
- Returns
None
- class monai.fl.client.ClientAlgoStats
- get_data_stats(extra=None)
Get summary statistics about the local data.
- Parameters
extra (Optional[dict]) – Dict with additional information that can be provided by the FL system. For example, requested statistics.
- Returns
Summary statistics.
- Return type
ExchangeObject
Extra dict example:
requested_stats = {
    FlStatistics.STATISTICS: metrics,
    FlStatistics.NUM_OF_BINS: num_of_bins,
    FlStatistics.BIN_RANGES: bin_ranges,
}
Returned ExchangeObject example:
ExchangeObject(
    statistics={...}
)
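One way a client could honor a histogram request like the one above is sketched here. Assumptions: the hypothetical string constants NUM_OF_BINS and BIN_RANGES stand in for the FlStatistics members, the data is a flat list of intensity values, BIN_RANGES holds a single (low, high) pair, and the default of 10 bins is invented for the example rather than taken from MONAI. The returned dict is what would be placed in ExchangeObject(statistics=...).

```python
from typing import Optional

# Hypothetical stand-ins for FlStatistics.NUM_OF_BINS / FlStatistics.BIN_RANGES.
NUM_OF_BINS = "num_of_bins"
BIN_RANGES = "bin_ranges"


def local_histogram(values, num_of_bins, bin_range):
    """Count values into equal-width bins over [low, high]; values outside the range are skipped."""
    low, high = bin_range
    if num_of_bins < 1 or high <= low:
        raise ValueError("need num_of_bins >= 1 and high > low")
    width = (high - low) / num_of_bins
    counts = [0] * num_of_bins
    for v in values:
        if low <= v <= high:
            # v == high would map to index num_of_bins, so clamp it into the last bin.
            counts[min(int((v - low) / width), num_of_bins - 1)] += 1
    return counts


def get_data_stats(values, extra: Optional[dict] = None) -> dict:
    """Return the dict that would be placed in ExchangeObject(statistics=...)."""
    extra = extra or {}
    num_of_bins = extra.get(NUM_OF_BINS, 10)  # assumed default, not MONAI's
    bin_range = extra.get(BIN_RANGES, (min(values), max(values)))
    return {"count": len(values), "histogram": local_histogram(values, num_of_bins, bin_range)}


stats = get_data_stats([0.0, 0.1, 0.5, 0.9, 1.0, 1.5], {NUM_OF_BINS: 2, BIN_RANGES: (0.0, 1.0)})
print(stats)  # {'count': 6, 'histogram': [2, 3]}
```

Having the server dictate the bin count and ranges means every client bins its data identically, so the per-client histograms can later be summed into a global one without sharing raw values.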
MONAI Bundle Reference Implementations
- class monai.fl.client.MonaiAlgo(bundle_root, local_epochs=1, send_weight_diff=True, config_train_filename='configs/train.json', config_evaluate_filename='default', config_filters_filename=None, disable_ckpt_loading=True, best_model_filepath='models/model.pt', final_model_filepath='models/model_final.pt', save_dict_key='model', seed=None, benchmark=True, multi_gpu=False, backend='nccl', init_method='env://', train_data_key=BundleKeys.TRAIN_DATA, eval_data_key=BundleKeys.VALID_DATA, data_stats_transform_list=None)
Implementation of ClientAlgo to allow federated learning with MONAI bundle configurations.
- Parameters
bundle_root (str) – path of bundle.
local_epochs (int) – number of local epochs to execute during each round of local training; defaults to 1.
send_weight_diff (bool) – whether to send weight differences rather than full weights; defaults to True.
config_train_filename (Union[str, list, None]) – bundle training config path relative to bundle_root. Can be a list of files; defaults to “configs/train.json”.
config_evaluate_filename (Union[str, list, None]) – bundle evaluation config path relative to bundle_root. Can be a list of files. If “default”, config_evaluate_filename = [“configs/train.json”, “configs/evaluate.json”] will be used.
config_filters_filename (Union[str, list, None]) – filter configuration file. Can be a list of files; defaults to None.
disable_ckpt_loading (bool) – do not use any CheckpointLoader if defined in train/evaluate configs; defaults to True.
best_model_filepath (Optional[str]) – location of best model checkpoint; defaults to “models/model.pt” relative to bundle_root.
final_model_filepath (Optional[str]) – location of final model checkpoint; defaults to “models/model_final.pt” relative to bundle_root.
save_dict_key (Optional[str]) – if a model checkpoint contains several state dicts, the one defined by save_dict_key is returned by get_weights; defaults to “model”. If all state dicts should be returned, set save_dict_key to None.
seed (Optional[int]) – set random seed for modules to enable or disable deterministic training; defaults to None, i.e., non-deterministic training.
benchmark (bool) – set benchmark to False for fully deterministic behavior in cuDNN components; defaults to True. Note that full determinism in federated learning also depends on deterministic behavior of other FL components, e.g., the aggregator, which is not controlled by this class.
multi_gpu (bool) – whether to run MonaiAlgo in a multi-GPU setting; defaults to False.
backend (str) – backend to use for torch.distributed; defaults to “nccl”.
init_method (str) – init_method for torch.distributed; defaults to “env://”.
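The config_*_filename parameters accept a single path, a list of paths, or None, all relative to bundle_root, and config_evaluate_filename additionally accepts “default”. The sketch below shows one way such an argument could be normalized into a list of paths. It is a hypothetical helper (resolve_config_paths is not a MONAI function), and it only encodes the behavior stated in the parameter list above; how MonaiAlgo actually parses and merges the files is not shown.

```python
import os
from typing import List, Union


def resolve_config_paths(bundle_root: str, filenames: Union[str, list, None]) -> List[str]:
    """Normalize a config_*_filename argument into a list of paths joined onto bundle_root."""
    if filenames is None:
        return []  # e.g. config_filters_filename=None: no filter configs
    if filenames == "default":
        # Documented meaning of config_evaluate_filename="default".
        filenames = ["configs/train.json", "configs/evaluate.json"]
    if isinstance(filenames, str):
        filenames = [filenames]
    return [os.path.join(bundle_root, name) for name in filenames]


print(resolve_config_paths("my_bundle", "default"))
# on POSIX: ['my_bundle/configs/train.json', 'my_bundle/configs/evaluate.json']
print(resolve_config_paths("my_bundle", None))  # []
```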
- abort(extra=None)
Abort the training or evaluation.
- Parameters
extra – Dict with additional information that can be provided by the FL system.
- evaluate(data, extra=None)
Evaluate on the client’s local data.
- Parameters
data (ExchangeObject) – ExchangeObject containing the current global model weights.
extra – Dict with additional information that can be provided by the FL system.
- Returns
return_metrics – ExchangeObject containing evaluation metrics.
- Return type
ExchangeObject
- finalize(extra=None)
Finalize the training or evaluation.
- Parameters
extra – Dict with additional information that can be provided by the FL system.
- get_weights(extra=None)
Return the current weights of the model.
- Parameters
extra – Dict with additional information that can be provided by the FL system.
- Returns
return_weights – ExchangeObject containing the current weights (default), or the requested model type loaded from disk (ModelType.BEST_MODEL or ModelType.FINAL_MODEL).
- Return type
ExchangeObject
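When a saved model is requested instead of the in-memory weights, two documented settings interact: best_model_filepath / final_model_filepath pick the checkpoint file relative to bundle_root, and save_dict_key picks one state dict out of a checkpoint that holds several. The sketch below illustrates that selection logic with plain dicts. The string constants BEST_MODEL and FINAL_MODEL are hypothetical stand-ins for ModelType members, and checkpoint_path / select_state are illustrative helpers, not MonaiAlgo internals; actual file loading (e.g. with torch.load) is deliberately left out.

```python
import os
from typing import Optional

# Hypothetical string stand-ins for ModelType.BEST_MODEL / ModelType.FINAL_MODEL.
BEST_MODEL = "best_model"
FINAL_MODEL = "final_model"


def checkpoint_path(bundle_root: str, model_type: str,
                    best_model_filepath: str = "models/model.pt",
                    final_model_filepath: str = "models/model_final.pt") -> str:
    """Map a requested model type to its checkpoint file relative to bundle_root."""
    relpath = {BEST_MODEL: best_model_filepath, FINAL_MODEL: final_model_filepath}[model_type]
    return os.path.join(bundle_root, relpath)


def select_state(checkpoint: dict, save_dict_key: Optional[str] = "model") -> dict:
    """Return checkpoint[save_dict_key], or the whole checkpoint when save_dict_key is None."""
    return checkpoint if save_dict_key is None else checkpoint[save_dict_key]


# A checkpoint holding several state dicts, as described for save_dict_key.
ckpt = {"model": {"fc.weight": [0.5]}, "optimizer": {"lr": 0.01}}
print(checkpoint_path("my_bundle", FINAL_MODEL))  # on POSIX: my_bundle/models/model_final.pt
print(select_state(ckpt))  # {'fc.weight': [0.5]}
print(select_state(ckpt, None) is ckpt)  # True
```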
- class monai.fl.client.MonaiAlgoStats(bundle_root, config_train_filename='configs/train.json', config_filters_filename=None, train_data_key=BundleKeys.TRAIN_DATA, eval_data_key=BundleKeys.VALID_DATA, data_stats_transform_list=None, histogram_only=False)
Implementation of ClientAlgoStats to allow federated learning with MONAI bundle configurations.
- Parameters
bundle_root (str) – path of bundle.
config_train_filename (Union[str, list, None]) – bundle training config path relative to bundle_root. Can be a list of files; defaults to “configs/train.json”.
config_filters_filename (Union[str, list, None]) – filter configuration file. Can be a list of files; defaults to None.
histogram_only (bool) – whether to only compute histograms. Defaults to False.