Federated Learning
ClientAlgo
- class monai.fl.client.ClientAlgo
Objective: provide an abstract base class for defining an algorithm to run on any platform. To define a new algorithm, subclass this class and implement the following abstract methods:
self.train()
self.get_weights()
self.evaluate()
initialize(), abort(), and finalize() can optionally be implemented to help with lifecycle management of the class object.
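A minimal sketch of the subclassing pattern. To stay self-contained it does not import MONAI: `ClientAlgoSketch` is a hypothetical stand-in that mirrors the abstract methods listed above (plus the optional lifecycle hooks), and a plain dict of floats plays the role of network weights. In real code you would subclass `monai.fl.client.ClientAlgo` and exchange `ExchangeObject` instances instead; the `train` signature shown here is an assumption modeled on `evaluate`.

```python
from abc import ABC, abstractmethod
from typing import Optional


class ClientAlgoSketch(ABC):
    """Hypothetical stand-in for monai.fl.client.ClientAlgo (not the real class)."""

    @abstractmethod
    def train(self, data: dict, extra: Optional[dict] = None) -> None: ...

    @abstractmethod
    def get_weights(self, extra: Optional[dict] = None) -> dict: ...

    @abstractmethod
    def evaluate(self, data: dict, extra: Optional[dict] = None) -> dict: ...

    # Optional lifecycle hooks: no-ops unless a subclass overrides them.
    def initialize(self, extra: Optional[dict] = None) -> None:
        pass

    def abort(self, extra: Optional[dict] = None) -> None:
        pass

    def finalize(self, extra: Optional[dict] = None) -> None:
        pass


class ToyAlgo(ClientAlgoSketch):
    """Toy client whose 'model' is a dict of scalar weights."""

    def __init__(self) -> None:
        self.weights = {"w": 0.0}

    def train(self, data, extra=None):
        # Start from the received global weights, then take one fake "step".
        self.weights = dict(data)
        self.weights["w"] += 1.0

    def get_weights(self, extra=None):
        # Return a copy so callers cannot mutate the client's state.
        return dict(self.weights)

    def evaluate(self, data, extra=None):
        # Pretend the metric is the distance of w from a target value of 3.0.
        return {"abs_error": abs(data["w"] - 3.0)}
```

Because the three core methods are abstract, instantiating the base class directly fails with `TypeError`; only a subclass that implements all of them can be created.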
- abort(extra=None)
Call to abort the ClientAlgo training or evaluation.
- Parameters
extra (Optional[dict]) – optional extra information
- evaluate(data, extra=None)
Get evaluation metrics on test data.
- Parameters
data (ExchangeObject) – ExchangeObject with network weights to use for evaluation
extra (Optional[dict]) – optional extra information
- Returns
ExchangeObject with evaluation metrics.
- Return type
ExchangeObject
- finalize(extra=None)
Call to finalize the ClientAlgo class.
- Parameters
extra (Optional[dict]) – optional extra information
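One common way a subclass might honor abort() is a flag that the training loop polls between iterations, with initialize() resetting it and finalize() releasing resources. This is an assumption about a possible implementation, not MONAI's internal mechanism; `AbortableLoop` and its `abort_at` argument (used only to simulate an external abort request) are hypothetical.

```python
import threading
from typing import Optional


class AbortableLoop:
    """Hypothetical client illustrating the initialize/abort/finalize lifecycle."""

    def __init__(self) -> None:
        self._abort = threading.Event()
        self.steps_done = 0
        self.finalized = False

    def initialize(self, extra: Optional[dict] = None) -> None:
        # Clear any abort request left over from a previous round.
        self._abort.clear()
        self.steps_done = 0

    def abort(self, extra: Optional[dict] = None) -> None:
        # threading.Event is safe to set from another thread,
        # e.g. a control thread owned by the FL system.
        self._abort.set()

    def train(self, num_steps: int, abort_at: Optional[int] = None) -> None:
        for step in range(num_steps):
            if self._abort.is_set():
                break  # stop early, keeping the progress made so far
            self.steps_done += 1
            if abort_at is not None and step == abort_at:
                self.abort()  # simulate an abort request arriving mid-training

    def finalize(self, extra: Optional[dict] = None) -> None:
        # Release files, worker processes, or GPU memory here.
        self.finalized = True
```

Polling a flag keeps abort() cheap and non-blocking: the current iteration finishes, and the loop exits at the next check.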
- get_weights(extra=None)
Get current local weights or weight differences.
- Parameters
extra (Optional[dict]) – optional extra information
- Returns
current local weights or weight differences.
- Return type
ExchangeObject
ExchangeObject example:
    ExchangeObject(
        weights=self.trainer.network.state_dict(),
        optim=None,  # could be self.optimizer.state_dict()
        weight_type=WeightType.WEIGHTS,
    )
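get_weights may return either full weights or weight differences. A minimal sketch of the difference, assuming it is computed per parameter as the local weights after training minus the global weights received at the start of the round, so that the server can recover the local model by adding the difference back. Plain dicts of floats stand in for PyTorch state dicts; `weight_diff` and `apply_diff` are hypothetical helpers, not MONAI functions.

```python
from typing import Dict

Weights = Dict[str, float]


def weight_diff(local: Weights, global_: Weights) -> Weights:
    """Per-parameter difference: local (after training) minus global (received)."""
    if local.keys() != global_.keys():
        raise KeyError("local and global weights must have the same parameter names")
    return {k: local[k] - global_[k] for k in local}


def apply_diff(global_: Weights, diff: Weights) -> Weights:
    """Server-side reconstruction: global weights plus the received difference."""
    return {k: global_[k] + diff[k] for k in global_}
```

For example, with global weights {"conv.weight": 0.5, "conv.bias": 0.0} and local weights {"conv.weight": 0.75, "conv.bias": -0.25}, the difference is {"conv.weight": 0.25, "conv.bias": -0.25}, and adding it back to the global weights reproduces the local ones.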
- class monai.fl.client.MonaiAlgo(bundle_root, local_epochs=1, send_weight_diff=True, config_train_filename='configs/train.json', config_evaluate_filename='default', config_filters_filename=None, disable_ckpt_loading=True, best_model_filepath='models/model.pt', final_model_filepath='models/model_final.pt', save_dict_key='model', seed=None, benchmark=True)
Implementation of ClientAlgo to allow federated learning with MONAI bundle configurations.
- Parameters
bundle_root (str) – path of the bundle.
local_epochs (int) – number of local epochs to execute during each round of local training; defaults to 1.
send_weight_diff (bool) – whether to send weight differences rather than full weights; defaults to True.
config_train_filename (Union[str, list, None]) – bundle training config path relative to bundle_root. Can be a list of files; defaults to “configs/train.json”.
config_evaluate_filename (Union[str, list, None]) – bundle evaluation config path relative to bundle_root. Can be a list of files. If “default”, config_evaluate_filename = [“configs/train.json”, “configs/evaluate.json”] will be used.
config_filters_filename (Union[str, list, None]) – filter configuration file. Can be a list of files; defaults to None.
disable_ckpt_loading (bool) – do not use any CheckpointLoader if one is defined in the train/evaluate configs; defaults to True.
best_model_filepath (Optional[str]) – location of the best model checkpoint; defaults to “models/model.pt” relative to bundle_root.
final_model_filepath (Optional[str]) – location of the final model checkpoint; defaults to “models/model_final.pt” relative to bundle_root.
save_dict_key (Optional[str]) – if a model checkpoint contains several state dicts, the one defined by save_dict_key will be returned by get_weights; defaults to “model”. If all state dicts should be returned, set save_dict_key to None.
seed (Optional[int]) – random seed for modules; setting it enables deterministic training; defaults to None, i.e., non-deterministic training.
benchmark (bool) – set to False for fully deterministic behavior in cuDNN components; defaults to True. Note that full determinism in federated learning also depends on deterministic behavior of other FL components, e.g., the aggregator, which is not controlled by this class.
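A sketch of how the config-path and checkpoint-key parameters above could be interpreted, written only from the documented behavior: a single file or a list of files relative to bundle_root, the “default” fallback for the evaluation configs, and save_dict_key selecting one state dict or, when None, all of them. `resolve_configs` and `select_state_dict` are hypothetical helpers and do not reproduce MonaiAlgo's internals.

```python
import os
from typing import Dict, List, Optional, Union

# Documented fallback for config_evaluate_filename="default".
DEFAULT_EVAL_CONFIGS = ["configs/train.json", "configs/evaluate.json"]


def resolve_configs(bundle_root: str, filename: Union[str, List[str], None]) -> List[str]:
    """Hypothetical helper: normalize a config argument to paths under bundle_root."""
    if filename is None:
        return []  # e.g. config_filters_filename=None means "no filter files"
    if filename == "default":
        # "default" is only documented for config_evaluate_filename.
        filename = DEFAULT_EVAL_CONFIGS
    if isinstance(filename, str):
        filename = [filename]
    return [os.path.join(bundle_root, f) for f in filename]


def select_state_dict(checkpoint: Dict[str, dict], save_dict_key: Optional[str] = "model") -> dict:
    """Hypothetical helper: pick one state dict, or return all if the key is None."""
    if save_dict_key is None:
        return checkpoint
    return checkpoint[save_dict_key]
```

With MONAI installed, the same choices are passed straight to the constructor, for example MonaiAlgo(bundle_root="path/to/bundle", config_evaluate_filename="default", save_dict_key="model", seed=0, benchmark=False); every keyword comes from the signature above, and the bundle path is a placeholder.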
- abort(extra=None)
Abort the training or evaluation.
- Parameters
extra – Dict with additional information that can be provided by the FL system.
- evaluate(data, extra=None)
Evaluate on the client’s local data.
- Parameters
data (ExchangeObject) – ExchangeObject containing the current global model weights.
extra – Dict with additional information that can be provided by the FL system.
- Returns
ExchangeObject containing evaluation metrics.
- Return type
ExchangeObject
- finalize(extra=None)
Finalize the training or evaluation.
- Parameters
extra – Dict with additional information that can be provided by the FL system.
- get_weights(extra=None)
Returns the current weights of the model.
- Parameters
extra – Dict with additional information that can be provided by the FL system.
- Returns
ExchangeObject containing the current weights (default), or the requested model type (ModelType.BEST_MODEL or ModelType.FINAL_MODEL) loaded from disk.
- Return type
ExchangeObject
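A sketch of the selection logic described for get_weights, assuming the FL system passes the requested model type through extra under a key named "model_type" and that checkpoints sit at the documented default paths relative to bundle_root. The key name, the `ModelTypeSketch` enum and its string values, and `choose_weights_source` are hypothetical stand-ins for MONAI's own constants; check the MONAI source for the exact names before relying on them.

```python
import os
from enum import Enum
from typing import Optional


class ModelTypeSketch(str, Enum):
    """Hypothetical stand-in for MONAI's ModelType constants."""

    BEST_MODEL = "best_model"
    FINAL_MODEL = "final_model"


# Documented defaults for best_model_filepath and final_model_filepath.
MODEL_PATHS = {
    ModelTypeSketch.BEST_MODEL: "models/model.pt",
    ModelTypeSketch.FINAL_MODEL: "models/model_final.pt",
}


def choose_weights_source(bundle_root: str, extra: Optional[dict] = None) -> Optional[str]:
    """Return the checkpoint path to load, or None to use the in-memory weights."""
    model_type = (extra or {}).get("model_type")  # hypothetical key name
    if model_type is None:
        return None  # default: current weights of the running model
    model_type = ModelTypeSketch(model_type)  # raises ValueError for unknown types
    return os.path.join(bundle_root, MODEL_PATHS[model_type])
```

Returning a path rather than loading it keeps the sketch free of PyTorch; a real client would then load that checkpoint and, if save_dict_key is set, pick the matching state dict before wrapping it in an ExchangeObject.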