monailabel.tasks.train.basic_train module

class monailabel.tasks.train.basic_train.BasicTrainTask(model_dir, description=None, config=None, amp=True, load_path=None, load_dict=None, publish_path=None, stats_path=None, train_save_interval=20, val_interval=1, n_saved=5, final_filename='checkpoint_final.pt', key_metric_filename='model.pt', model_dict_key='model', find_unused_parameters=False, load_strict=False, labels=None, disable_meta_tracking=False, tracking='mlflow', tracking_uri='', tracking_experiment_name=None)

Bases: monailabel.interfaces.tasks.train.TrainTask

Basic train task that trains a model using MONAI's SupervisedTrainer and SupervisedEvaluator.

Parameters
  • model_dir – Base model directory where model checkpoints, events, etc. are saved.

  • description – Description of this task.

  • config – Key/value pairs to include in the user config.

  • amp – Enable automatic mixed precision (AMP) for training.

  • load_path – Path to an existing (pre-trained) checkpoint used to initialize the model.

  • load_dict – Dictionary of objects to load from the checkpoint. If None, the network is loaded.

  • publish_path – Path to which the best trained model (by key metric) is published.

  • stats_path – Path where training statistics are saved.

  • train_save_interval – Interval at which checkpoints are saved during training.

  • val_interval – Validation interval; validation runs every val_interval epochs.

  • n_saved – Maximum number of checkpoints to keep.

  • final_filename – File name of the final checkpoint.

  • key_metric_filename – File name of the model with the best key metric.

  • model_dict_key – Key under which the network weights are stored in the checkpoint.

  • find_unused_parameters – Applies only to DDP (multi-GPU) training.

  • load_strict – Load the pre-trained model in strict mode.

  • labels – Labels to include in the training context (required by some transforms).

  • disable_meta_tracking – Disable metadata tracking for faster training (leave it enabled if you use MetaTensor or batched transforms).

  • tracking – Tracking manager for experiment management (only ‘mlflow’ is supported).

  • tracking_uri – Tracking URI for experiment management.

  • tracking_experiment_name – Name of the tracking experiment.
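To make the interval parameters concrete, here is a minimal pure-Python sketch of how val_interval, train_save_interval, and n_saved could play out over a run. It assumes epochs are counted from 1, that a checkpoint is written every train_save_interval epochs, and that only the n_saved most recent ones are kept; the helper names validation_epochs and retained_checkpoints and the checkpoint_epoch_{n}.pt naming are hypothetical and not part of MONAI Label.

```python
from collections import deque
from typing import List


def validation_epochs(max_epochs: int, val_interval: int) -> List[int]:
    """Epochs (1-based) on which validation would run, assuming it runs every `val_interval` epochs."""
    return [epoch for epoch in range(1, max_epochs + 1) if epoch % val_interval == 0]


def retained_checkpoints(max_epochs: int, train_save_interval: int, n_saved: int) -> List[str]:
    """Hypothetical checkpoint names left on disk if only the `n_saved` most recent are kept."""
    kept = deque(maxlen=n_saved)  # once full, appending drops the oldest entry
    for epoch in range(1, max_epochs + 1):
        if epoch % train_save_interval == 0:
            kept.append(f"checkpoint_epoch_{epoch}.pt")
    return list(kept)


print(validation_epochs(20, 5))  # [5, 10, 15, 20]
# 200 epochs with the default train_save_interval=20 and n_saved=5: 10 saves, last 5 kept
print(retained_checkpoints(200, 20, 5))
```

With the defaults (val_interval=1), validation runs after every epoch, so raising val_interval is the simplest way to trade validation frequency for training speed.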

TRAIN_METRIC_ACCURACY = 'train_acc'
TRAIN_METRIC_MEAN_DICE = 'train_mean_dice'
VAL_METRIC_ACCURACY = 'val_acc'
VAL_METRIC_MEAN_DICE = 'val_mean_dice'
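The *_MEAN_DICE names refer to a mean Dice score. For a predicted mask A and a reference mask B, Dice = 2|A ∩ B| / (|A| + |B|), and the mean is taken over the foreground labels. The pure-Python sketch below computes it on flattened label lists and stores the result under the same string as VAL_METRIC_MEAN_DICE. It is an illustration only: MONAI computes Dice on tensors with its own metric classes, and the function names and the choice to score a label absent from both masks as 1.0 are assumptions.

```python
from typing import List, Sequence

VAL_METRIC_MEAN_DICE = "val_mean_dice"  # same string as the class constant above


def dice(pred: Sequence[int], target: Sequence[int], label: int) -> float:
    """Dice coefficient for one label: 2|A∩B| / (|A| + |B|)."""
    a = [p == label for p in pred]
    b = [t == label for t in target]
    intersection = sum(x and y for x, y in zip(a, b))
    denom = sum(a) + sum(b)
    # Assumption: a label missing from both masks counts as a perfect match.
    return 1.0 if denom == 0 else 2.0 * intersection / denom


def mean_dice(pred: Sequence[int], target: Sequence[int], labels: List[int]) -> float:
    """Average the per-label Dice scores over the given labels."""
    return sum(dice(pred, target, lab) for lab in labels) / len(labels)


# Flattened voxel labels: 0 = background, 1 and 2 = foreground classes
pred = [0, 1, 1, 2, 2, 0]
target = [0, 1, 2, 2, 2, 0]
stats = {VAL_METRIC_MEAN_DICE: mean_dice(pred, target, labels=[1, 2])}
print(stats)  # label 1: 2/3, label 2: 0.8, mean ≈ 0.7333
```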
__init__(model_dir, description=None, config=None, amp=True, load_path=None, load_dict=None, publish_path=None, stats_path=None, train_save_interval=20, val_interval=1, n_saved=5, final_filename='checkpoint_final.pt', key_metric_filename='model.pt', model_dict_key='model', find_unused_parameters=False, load_strict=False, labels=None, disable_meta_tracking=False, tracking='mlflow', tracking_uri='', tracking_experiment_name=None)

cleanup(request)
config()
event_names(context)
finalize(context)
get_cache_dir(request)
info()
load_path(output_dir, pretrained=True)
abstract loss_function(context)
lr_scheduler_handler(context)
abstract network(context)
abstract optimizer(context)
partition_datalist(context, shuffle=False)
pre_process(request, datastore)
stats()
train(rank, world_size, request, datalist)
train_additional_metrics(context)
train_data_loader(context, num_workers=0, shuffle=True)
train_handlers(context)
train_inferer(context)
train_iteration_update(context)
train_key_metric(context)
abstract train_post_transforms(context)
abstract train_pre_transforms(context)
val_additional_metrics(context)
val_data_loader(context, num_workers=0)
val_handlers(context)
abstract val_inferer(context)
val_iteration_update(context)
val_key_metric(context)
val_post_transforms(context)
val_pre_transforms(context)
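The members marked abstract (network, optimizer, loss_function, train_pre_transforms, train_post_transforms, and val_inferer) are the hooks a concrete subclass must supply; the other methods can be overridden as needed. The sketch below mirrors that contract with Python's abc module and placeholder return values so it runs without MONAI installed. TrainTaskContract, MyTrainTask, and IncompleteTask are hypothetical stand-ins, not MONAI Label classes; a real task would subclass BasicTrainTask and return a PyTorch network, optimizer, and loss, plus transforms and an inferer.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace


class TrainTaskContract(ABC):
    """Hypothetical stand-in listing the abstract hooks documented for BasicTrainTask."""

    @abstractmethod
    def network(self, context): ...

    @abstractmethod
    def optimizer(self, context): ...

    @abstractmethod
    def loss_function(self, context): ...

    @abstractmethod
    def train_pre_transforms(self, context): ...

    @abstractmethod
    def train_post_transforms(self, context): ...

    @abstractmethod
    def val_inferer(self, context): ...


class MyTrainTask(TrainTaskContract):
    """Concrete task: every abstract hook is overridden (placeholders, not real objects)."""

    def network(self, context):
        return "network"  # real task: a torch.nn.Module

    def optimizer(self, context):
        return "optimizer"  # real task: a torch.optim optimizer

    def loss_function(self, context):
        return "loss"

    def train_pre_transforms(self, context):
        return []  # real task: list of pre-processing transforms

    def train_post_transforms(self, context):
        return []

    def val_inferer(self, context):
        return "inferer"


class IncompleteTask(TrainTaskContract):
    """Overrides only one hook, so it stays abstract and cannot be instantiated."""

    def network(self, context):
        return "network"


context = SimpleNamespace()  # hypothetical stand-in for the Context object passed to each hook
task = MyTrainTask()
print(task.network(context))  # network
print(sorted(IncompleteTask.__abstractmethods__))  # the five hooks still missing
```

Instantiating IncompleteTask raises TypeError, which is the same early failure you would expect from a subclass that forgets one of the abstract hooks.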
class monailabel.tasks.train.basic_train.Context

Bases: object

monailabel.tasks.train.basic_train.main_worker(rank, world_size, request, datalist, task)
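main_worker takes the same rank, world_size, request, and datalist arguments as train, plus the task itself, which suggests a per-process entry point that hands off to task.train. The sketch below illustrates that pattern under this assumption. main_worker_sketch, RecordingTask, the max_epochs request key, and the strided per-rank sharding of the datalist are all hypothetical; the ranks also run sequentially in one process here, whereas real multi-GPU training launches one process per rank (for example with torch.multiprocessing.spawn).

```python
from typing import Any, Dict, List


class RecordingTask:
    """Hypothetical task that records which rank trained on which data."""

    def __init__(self) -> None:
        self.calls: List[Dict[str, Any]] = []

    def train(self, rank: int, world_size: int, request: Dict[str, Any], datalist: List[Dict[str, str]]):
        # Assumption for illustration: each rank takes a strided shard of the datalist.
        shard = datalist[rank::world_size]
        self.calls.append({"rank": rank, "shard": shard})
        return {"rank": rank, "items": len(shard)}


def main_worker_sketch(rank, world_size, request, datalist, task):
    """Per-rank entry point: delegate to the task's train() with the same arguments."""
    return task.train(rank, world_size, request, datalist)


datalist = [{"image": f"img{i}.nii.gz", "label": f"lbl{i}.nii.gz"} for i in range(5)]
task = RecordingTask()
world_size = 2
# Sequential stand-in for launching one process per rank.
results = [main_worker_sketch(rank, world_size, {"max_epochs": 1}, datalist, task) for rank in range(world_size)]
print(results)  # [{'rank': 0, 'items': 3}, {'rank': 1, 'items': 2}]
```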