Federated Learning

ClientAlgo

class monai.fl.client.ClientAlgo

Objective: provide an abstract base class for defining an algorithm (algo) to run on any platform. To define a new algo script, subclass this class and implement the following abstract methods:

  • self.train()

  • self.get_weights()

  • self.evaluate()

initialize(), abort(), and finalize() can optionally be implemented to help with lifecycle management of the class object.
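The split between required methods and optional lifecycle hooks can be sketched with Python's abc module. This is a stdlib mirror for illustration, not MONAI's source: ClientAlgoSketch, IncompleteAlgo, CompleteAlgo, and can_instantiate are hypothetical names, and the real base class may enforce the contract differently (for example by raising NotImplementedError only when an unimplemented method is called).

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class ClientAlgoSketch(ABC):
    """Hypothetical stdlib mirror of the ClientAlgo contract (not MONAI's code)."""

    # Optional lifecycle hooks: no-ops by default, override only if needed.
    def initialize(self, extra: Optional[dict] = None) -> None:
        pass

    def abort(self, extra: Optional[dict] = None) -> None:
        pass

    def finalize(self, extra: Optional[dict] = None) -> None:
        pass

    # Required methods: a subclass must implement all three to be instantiable.
    @abstractmethod
    def train(self, data: Any, extra: Optional[dict] = None) -> None: ...

    @abstractmethod
    def get_weights(self, extra: Optional[dict] = None) -> Any: ...

    @abstractmethod
    def evaluate(self, data: Any, extra: Optional[dict] = None) -> Any: ...


class IncompleteAlgo(ClientAlgoSketch):
    # Implements only train(); get_weights() and evaluate() are still abstract.
    def train(self, data, extra=None):
        pass


class CompleteAlgo(IncompleteAlgo):
    # Fills in the remaining abstract methods with trivial placeholders.
    def get_weights(self, extra=None):
        return {}

    def evaluate(self, data, extra=None):
        return {}


def can_instantiate(cls) -> bool:
    """Return True if cls has no abstract methods left and can be created."""
    try:
        cls()
        return True
    except TypeError:
        return False
```

With this approach a missing required method is reported when the class is instantiated rather than later, in the middle of a federated round.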

abort(extra=None)

Call to abort the ClientAlgo training or evaluation.

Parameters

extra (Optional[dict]) – optional extra information

evaluate(data, extra=None)

Get evaluation metrics on test data.

Parameters
  • data (ExchangeObject) – ExchangeObject with network weights to use for evaluation

  • extra (Optional[dict]) – optional extra information

Returns

ExchangeObject with evaluation metrics.

Return type

ExchangeObject

finalize(extra=None)

Call to finalize the ClientAlgo class.

Parameters

extra (Optional[dict]) – optional extra information

get_weights(extra=None)

Get current local weights or weight differences.

Parameters

extra (Optional[dict]) – optional extra information

Returns

current local weights or weight differences.

Return type

ExchangeObject

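ExchangeObject example:

from monai.fl.utils.constants import WeightType
from monai.fl.utils.exchange_object import ExchangeObject

# Inside a ClientAlgo subclass, e.g. in get_weights(), where self.trainer exists
ExchangeObject(
    weights=self.trainer.network.state_dict(),
    optim=None,  # could be self.optimizer.state_dict()
    weight_type=WeightType.WEIGHTS,
)
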
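The difference between sending full weights and sending weight differences can be illustrated with plain dictionaries standing in for a state_dict. In this sketch Weights, weight_diff, and apply_diff are hypothetical; real payloads carry tensors inside an ExchangeObject, and its weight_type field (WeightType.WEIGHTS in the example above) tells the receiver which kind of payload it is getting.

```python
from typing import Dict, List

# Hypothetical stand-in for a state_dict: parameter name -> flat list of values.
Weights = Dict[str, List[float]]


def weight_diff(local: Weights, global_: Weights) -> Weights:
    """Per-parameter difference local - global (what a difference payload carries)."""
    if local.keys() != global_.keys():
        raise KeyError("local and global weights must have the same parameter names")
    return {
        name: [lv - gv for lv, gv in zip(local[name], global_[name])]
        for name in local
    }


def apply_diff(global_: Weights, diff: Weights) -> Weights:
    """Reconstruct the local weights on the receiving side as global + diff."""
    return {
        name: [gv + dv for gv, dv in zip(global_[name], diff[name])]
        for name in global_
    }
```

Sending only the difference lets the receiver recover the local model exactly as long as it still holds the same global weights the client started from.
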
initialize(extra=None)

Call to initialize the ClientAlgo class.

Parameters

extra (Optional[dict]) – optional extra information, e.g., a dict with ExtraItems.CLIENT_NAME and/or ExtraItems.APP_ROOT as keys

train(data, extra=None)

Train the network on local training data, starting from the network weights in data, to produce an updated network.

Parameters
  • data (ExchangeObject) – ExchangeObject containing current network weights to base training on.

  • extra (Optional[dict]) – optional extra information

Return type

None

Returns

None
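Because train() returns None, the updated model has to stay on the algo object until the FL system asks for it through get_weights(). A toy, stdlib-only sketch of that stateful pattern follows; ToyLinearAlgo, its one-parameter model y = w * x, and the plain-dict payloads are hypothetical stand-ins for a network and an ExchangeObject.

```python
from typing import Dict, List, Optional, Tuple


class ToyLinearAlgo:
    """Hypothetical client fitting y = w * x on its local data (not MONAI code)."""

    def __init__(self, local_data: List[Tuple[float, float]],
                 lr: float = 0.1, local_epochs: int = 1) -> None:
        self.local_data = local_data
        self.lr = lr
        self.local_epochs = local_epochs
        self.w = 0.0  # local model state, overwritten by train()

    def train(self, data: Dict[str, float], extra: Optional[dict] = None) -> None:
        # Start from the received global weights, then run local gradient descent.
        self.w = data["w"]
        n = len(self.local_data)
        for _ in range(self.local_epochs):
            # Gradient of the mean squared error (1/n) * sum((w*x - y)^2) w.r.t. w.
            grad = sum(2 * (self.w * x - y) * x for x, y in self.local_data) / n
            self.w -= self.lr * grad
        # Nothing is returned; the result is read later via get_weights().

    def get_weights(self, extra: Optional[dict] = None) -> Dict[str, float]:
        return {"w": self.w}
```

For example, with local data [(1.0, 2.0)], lr=0.25, and local_epochs=2, training from {"w": 0.0} leaves get_weights() returning {"w": 1.5}.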

class monai.fl.client.MonaiAlgo(bundle_root, local_epochs=1, send_weight_diff=True, config_train_filename='configs/train.json', config_evaluate_filename='default', config_filters_filename=None, disable_ckpt_loading=True, best_model_filepath='models/model.pt', final_model_filepath='models/model_final.pt', save_dict_key='model', seed=None, benchmark=True)

Implementation of ClientAlgo to allow federated learning with MONAI bundle configurations.

Parameters
  • bundle_root (str) – path of bundle.

  • local_epochs (int) – number of local epochs to execute during each round of local training; defaults to 1.

  • send_weight_diff (bool) – whether to send weight differences rather than full weights; defaults to True.

  • config_train_filename (Union[str, list, None]) – bundle training config path relative to bundle_root. Can be a list of files; defaults to “configs/train.json”.

  • config_evaluate_filename (Union[str, list, None]) – bundle evaluation config path relative to bundle_root. Can be a list of files. If “default”, [“configs/train.json”, “configs/evaluate.json”] will be used; defaults to “default”.

  • config_filters_filename (Union[str, list, None]) – filter configuration file. Can be a list of files; defaults to None.

  • disable_ckpt_loading (bool) – ignore any CheckpointLoader defined in the train/evaluate configs; defaults to True.

  • best_model_filepath (Optional[str]) – location of the best model checkpoint, relative to bundle_root; defaults to “models/model.pt”.

  • final_model_filepath (Optional[str]) – location of the final model checkpoint, relative to bundle_root; defaults to “models/model_final.pt”.

  • save_dict_key (Optional[str]) – If a model checkpoint contains several state dicts, the one defined by save_dict_key will be returned by get_weights; defaults to “model”. If all state dicts should be returned, set save_dict_key to None.

  • seed (Optional[int]) – random seed for modules; set it to enable deterministic training. Defaults to None, i.e., non-deterministic training.

  • benchmark (bool) – set to False for fully deterministic behavior in cuDNN components; defaults to True. Note that full determinism in federated learning also depends on deterministic behavior of other FL components, e.g., the aggregator, which is not controlled by this class.
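The config filename parameters accept a single path, a list of paths, None, or (for evaluation) the sentinel “default”. The helper below is a hypothetical sketch of how such values could be normalized into paths under bundle_root; resolve_configs is not part of MonaiAlgo, and how MonaiAlgo actually parses and merges multiple config files is not shown here.

```python
import os
from typing import List, Optional, Union

ConfigSpec = Union[str, List[str], None]

# Mirrors the documented evaluation default (paths relative to bundle_root).
DEFAULT_EVALUATE_CONFIGS = ["configs/train.json", "configs/evaluate.json"]


def resolve_configs(bundle_root: str, spec: ConfigSpec,
                    default: Optional[List[str]] = None) -> List[str]:
    """Hypothetical helper: normalize a config_*_filename value to a list of paths."""
    if spec is None:
        return []  # no config of this kind was requested
    if spec == "default":
        spec = list(default or [])  # expand the sentinel into its file list
    if isinstance(spec, str):
        spec = [spec]  # a single file becomes a one-element list
    return [os.path.join(bundle_root, name) for name in spec]
```

Normalizing to a list early means the rest of the setup code only has to handle one shape, regardless of how the parameter was passed.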

abort(extra=None)

Abort the training or evaluation.

Parameters

extra – Dict with additional information that can be provided by the FL system.
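One common way to make abort() take effect while train() is running is a thread-safe flag that the training loop checks between steps. The sketch below uses threading.Event; AbortableLoop and its step callback are hypothetical, and MonaiAlgo's actual abort mechanism may differ.

```python
import threading
from typing import Callable, Optional


class AbortableLoop:
    """Hypothetical sketch: abort() sets a flag that the training loop polls."""

    def __init__(self) -> None:
        self._abort = threading.Event()
        self.iterations_done = 0

    def abort(self, extra: Optional[dict] = None) -> None:
        self._abort.set()  # safe to call from another thread

    def train(self, max_iterations: int, step: Callable[[int], None]) -> None:
        self._abort.clear()  # each new training call starts un-aborted
        self.iterations_done = 0
        for i in range(max_iterations):
            if self._abort.is_set():
                break  # stop at the next safe point between steps
            step(i)  # one unit of work, e.g. one minibatch
            self.iterations_done += 1
```

In a real deployment abort() would typically be called from a different thread than the one running train(); threading.Event makes that safe without extra locking, at the cost of only stopping between steps.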

evaluate(data, extra=None)

Evaluate on client’s local data.

Parameters
  • data (ExchangeObject) – ExchangeObject containing the current global model weights.

  • extra – Dict with additional information that can be provided by the FL system.

Returns

ExchangeObject containing evaluation metrics.

Return type

ExchangeObject

finalize(extra=None)

Finalize the training or evaluation.

Parameters

extra – Dict with additional information that can be provided by the FL system.

get_weights(extra=None)

Returns the current weights of the model.

Parameters

extra – Dict with additional information that can be provided by the FL system.

Returns

ExchangeObject containing the current weights (default), or the requested model type (ModelType.BEST_MODEL or ModelType.FINAL_MODEL) loaded from disk.

Return type

ExchangeObject
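Choosing between in-memory weights and a checkpoint on disk, then picking one state dict out of a multi-dict checkpoint with save_dict_key, can be sketched as below. Everything here is hypothetical: JSON files stand in for saved checkpoints, and the string constants BEST_MODEL and FINAL_MODEL stand in for ModelType.BEST_MODEL and ModelType.FINAL_MODEL. How MonaiAlgo receives the requested model type through extra is not shown.

```python
import json
from typing import Dict, Optional

# Hypothetical stand-ins for ModelType.BEST_MODEL / ModelType.FINAL_MODEL.
BEST_MODEL = "best_model"
FINAL_MODEL = "final_model"


def select_weights(current: dict, paths: Dict[str, str],
                   model_type: Optional[str] = None,
                   save_dict_key: Optional[str] = "model") -> dict:
    """Return in-memory weights by default, else load the requested checkpoint."""
    if model_type is None:
        return current  # default: the current weights held in memory
    with open(paths[model_type]) as f:
        checkpoint = json.load(f)  # JSON stands in for a saved checkpoint file
    # A checkpoint may hold several state dicts; return one of them, or all if the key is None.
    return checkpoint if save_dict_key is None else checkpoint[save_dict_key]
```

Here paths would map BEST_MODEL and FINAL_MODEL to locations such as best_model_filepath and final_model_filepath resolved against bundle_root.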

initialize(extra=None)

Initialize routine to parse configuration files and extract main components such as trainer, evaluator, and filters.

Parameters

extra – Dict with additional information that should be provided by the FL system, i.e., ExtraItems.CLIENT_NAME and ExtraItems.APP_ROOT.

train(data, extra=None)

Train on client’s local data.

Parameters
  • data (ExchangeObject) – ExchangeObject containing the current global model weights.

  • extra – Dict with additional information that can be provided by the FL system.
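How local_epochs and send_weight_diff shape what a client hands back after local training can be sketched with scalar weights. client_round, the gradient callback, and the "full"/"diff" labels are hypothetical; in MonaiAlgo the training loop comes from the bundle's trainer, and the payload kind is recorded in an ExchangeObject's weight_type rather than a string.

```python
from typing import Callable, Dict

Weights = Dict[str, float]


def client_round(global_w: Weights,
                 grad_fn: Callable[[Weights], Weights],
                 local_epochs: int = 1,
                 lr: float = 0.1,
                 send_weight_diff: bool = True) -> dict:
    """Hypothetical client step: local gradient descent from the global weights, then package the result."""
    w = dict(global_w)  # start from a copy of the received global weights
    for _ in range(local_epochs):
        grads = grad_fn(w)
        w = {k: w[k] - lr * grads[k] for k in w}
    if send_weight_diff:
        # Only the change relative to the global model is sent back.
        return {"weight_type": "diff", "weights": {k: w[k] - global_w[k] for k in w}}
    return {"weight_type": "full", "weights": w}
```

With local_epochs=2 and lr=0.25 on the loss (w - 2)^2, a client starting from {"w": 1.0} ends at 1.75 and, with send_weight_diff=True, reports a difference of 0.75.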