monailabel.tasks.train.basic_train module

class monailabel.tasks.train.basic_train.BasicTrainTask(model_dir, description=None, config=None, amp=True, load_path=None, load_dict=None, publish_path=None, stats_path=None, train_save_interval=50, val_interval=1, final_filename='checkpoint_final.pt', key_metric_filename='model.pt')

Bases: monailabel.interfaces.tasks.train.TrainTask

Provides a basic train task that trains a model using MONAI's SupervisedTrainer and SupervisedEvaluator.
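BasicTrainTask follows a template-method design: the base class wires up training and validation, and a concrete task supplies the components through the members marked abstract in the member list (network, optimizer, loss_function, the train pre/post transforms and val_inferer). The plain-Python sketch below illustrates only that contract. MiniTrainTask and ToyTask are hypothetical names, the string return values stand in for real PyTorch/MONAI objects, and nothing here is MONAI Label API.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class MiniTrainTask(ABC):
    """Hypothetical, stripped-down analogue of a template-method train task."""

    @abstractmethod
    def network(self) -> Any: ...

    @abstractmethod
    def optimizer(self) -> Any: ...

    @abstractmethod
    def loss_function(self) -> Any: ...

    def components(self) -> Dict[str, Any]:
        # The base class assembles the pieces; subclasses only supply them.
        return {
            "network": self.network(),
            "optimizer": self.optimizer(),
            "loss": self.loss_function(),
        }


class ToyTask(MiniTrainTask):
    # Placeholder strings stand in for a real network, optimizer and loss.
    def network(self) -> str:
        return "toy-unet"

    def optimizer(self) -> str:
        return "adam"

    def loss_function(self) -> str:
        return "dice"


if __name__ == "__main__":
    print(ToyTask().components())
```

Instantiating MiniTrainTask directly raises TypeError because its abstract methods are not implemented; ToyTask can be instantiated because it provides all three.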

Parameters
  • model_dir – Base model directory in which model checkpoints, events, etc. are saved.

  • description – Description of this task.

  • config – Key/value pairs to include in the user config.

  • amp – Enable automatic mixed precision (AMP) for training.

  • load_path – Path to an existing (pre-trained) checkpoint used to initialize the model.

  • load_dict – Dictionary of objects to load from the checkpoint. If None, the network is loaded.

  • publish_path – Path to which the best trained model (as measured by the key metric) is published.

  • stats_path – Path where the training stats are saved.

  • train_save_interval – Interval at which checkpoints are saved during training.

  • val_interval – Validation interval; validation runs every val_interval epochs.

  • final_filename – File name of the final checkpoint saved at the end of training.

  • key_metric_filename – File name of the model with the best key metric.
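To make the interval and file-name parameters concrete, the sketch below simulates which events fire at each epoch of a run. It is a plain-Python model, not MONAI Label code: simulate_run is a hypothetical helper, both intervals are assumed to count epochs (the parameter list states this only for val_interval), and a higher key metric is assumed to be better, as with Dice. Defaults mirror the constructor signature.

```python
from typing import Callable, Dict, List


def simulate_run(
    max_epochs: int,
    val_score: Callable[[int], float],
    val_interval: int = 1,
    train_save_interval: int = 50,
    final_filename: str = "checkpoint_final.pt",
    key_metric_filename: str = "model.pt",
) -> Dict[int, List[str]]:
    """Hypothetical simulation of per-epoch events; not MONAI Label code.

    Assumes both intervals count epochs and that a higher key metric is better.
    """
    if min(max_epochs, val_interval, train_save_interval) < 1:
        raise ValueError("max_epochs and both intervals must be >= 1")
    best = float("-inf")
    log: Dict[int, List[str]] = {}
    for epoch in range(1, max_epochs + 1):
        events: List[str] = []
        if epoch % val_interval == 0:  # evaluator runs every `val_interval` epochs
            events.append("validate")
            score = val_score(epoch)
            if score > best:  # new best key metric: best-metric model is rewritten
                best = score
                events.append(f"save {key_metric_filename}")
        if epoch % train_save_interval == 0:  # periodic training checkpoint
            events.append("save periodic checkpoint")
        if epoch == max_epochs:  # end of training
            events.append(f"save {final_filename}")
        if events:
            log[epoch] = events
    return log


if __name__ == "__main__":
    scores = {2: 0.61, 4: 0.74, 6: 0.70}  # made-up validation scores
    run = simulate_run(6, scores.__getitem__, val_interval=2, train_save_interval=3)
    for epoch, events in run.items():
        print(epoch, events)
```

In this toy run the best-metric file is written at epochs 2 and 4 but not at epoch 6, because the epoch-6 score does not improve on epoch 4.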

__init__(model_dir, description=None, config=None, amp=True, load_path=None, load_dict=None, publish_path=None, stats_path=None, train_save_interval=50, val_interval=1, final_filename='checkpoint_final.pt', key_metric_filename='model.pt')

config()
event_names()
load_path(output_dir, pretrained=True)
abstract loss_function()
abstract network()
abstract optimizer()
partition_datalist(request, datalist, shuffle=True)
stats()
train_additional_metrics()
train_data_loader(datalist, batch_size=1, num_workers=0, cached=False)
train_handlers(output_dir, events_dir, evaluator)
train_inferer()
train_iteration_update()
train_key_metric()
abstract train_post_transforms()
abstract train_pre_transforms()
val_additional_metrics()
val_data_loader(datalist, batch_size=1, num_workers=0, cached=False)
val_handlers(output_dir, events_dir)
abstract val_inferer()
val_iteration_update()
val_key_metric()
val_post_transforms()
val_pre_transforms()
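partition_datalist(request, datalist, shuffle=True) splits the datalist into training and validation subsets. A ratio-based split in plain Python is sketched below. partition_datalist_sketch is a hypothetical name, and the val_split request key, its 0.2 default, the seed argument and the rounding rule are assumptions for illustration rather than confirmed MONAI Label behavior.

```python
import random
from typing import Any, Dict, List, Optional, Sequence, Tuple


def partition_datalist_sketch(
    request: Dict[str, Any],
    datalist: Sequence[Dict[str, Any]],
    shuffle: bool = True,
    seed: Optional[int] = None,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Hypothetical train/val split; the `val_split` key and 0.2 default are assumptions."""
    val_split = float(request.get("val_split", 0.2))
    if not 0.0 <= val_split < 1.0:
        raise ValueError("val_split must be in [0, 1)")
    items = list(datalist)
    if shuffle:
        random.Random(seed).shuffle(items)  # local RNG: global random state is untouched
    n_val = int(round(len(items) * val_split))
    n_train = len(items) - n_val
    # The first n_train items are used for training, the remainder for validation.
    return items[:n_train], items[n_train:]


if __name__ == "__main__":
    data = [{"image": f"img{i}.nii.gz", "label": f"lab{i}.nii.gz"} for i in range(10)]
    train, val = partition_datalist_sketch({"val_split": 0.2}, data, seed=0)
    print(len(train), len(val))
```

With a fixed seed the shuffle is reproducible, and with val_split set to 0 every item goes to training.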