monailabel.tasks.train.basic_train module
- class monailabel.tasks.train.basic_train.BasicTrainTask(model_dir, description=None, config=None, amp=True, load_path=None, load_dict=None, publish_path=None, stats_path=None, train_save_interval=20, val_interval=1, n_saved=5, final_filename='checkpoint_final.pt', key_metric_filename='model.pt', model_dict_key='model', find_unused_parameters=False, load_strict=False, labels=None, disable_meta_tracking=False, tracking='mlflow', tracking_uri='', tracking_experiment_name=None)
Bases: monailabel.interfaces.tasks.train.TrainTask
Basic training task that trains a model with MONAI's SupervisedTrainer and validates it with MONAI's SupervisedEvaluator.
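The trainer/evaluator split described above can be pictured as a loop that trains every epoch and validates every `val_interval` epochs. The following is a minimal plain-Python sketch under that assumption; `run_training`, `train_one_epoch` and `evaluate` are hypothetical stand-ins, not MONAI or MONAI Label APIs.

```python
from typing import Callable, Dict, List


def run_training(
    max_epochs: int,
    val_interval: int,
    train_one_epoch: Callable[[int], Dict[str, float]],
    evaluate: Callable[[int], Dict[str, float]],
) -> List[Dict[str, float]]:
    """Alternate training epochs with periodic validation.

    Hypothetical stand-in for a trainer that triggers an evaluator
    every `val_interval` epochs; this is not MONAI's implementation.
    """
    history: List[Dict[str, float]] = []
    for epoch in range(1, max_epochs + 1):
        # Every epoch runs the training step and records its stats.
        stats = {"epoch": epoch, **train_one_epoch(epoch)}
        # Validation only runs on epochs that are a multiple of val_interval.
        if val_interval > 0 and epoch % val_interval == 0:
            stats.update(evaluate(epoch))
        history.append(stats)
    return history


# Usage: 4 epochs with validation every 2 epochs (dummy metric values).
history = run_training(
    max_epochs=4,
    val_interval=2,
    train_one_epoch=lambda e: {"train_mean_dice": 0.1 * e},
    evaluate=lambda e: {"val_mean_dice": 0.2 * e},
)
```

With `val_interval=1` (the documented default), every epoch would carry validation stats.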
- Parameters
model_dir – Base model directory where checkpoints, event files, etc. are saved
description – Description of this task
config – Key/value pairs exposed as part of the user config
amp – Enable automatic mixed precision (AMP) for training
load_path – Path to an existing (pre-trained) checkpoint used to initialize the model
load_dict – Dictionary of objects to load from the checkpoint; if None, only the network is loaded
publish_path – Path where the best trained model (based on the best key metric) is published
stats_path – Path where training statistics are saved
train_save_interval – Interval at which checkpoints are saved during training
val_interval – Validation interval (validation runs every val_interval epochs)
n_saved – Maximum number of checkpoints to keep
final_filename – File name of the final checkpoint saved at the end of training
key_metric_filename – File name of the model with the best key metric
model_dict_key – Key under which the network weights are stored in the checkpoint
find_unused_parameters – Applies only to DDP (multi-GPU) training
load_strict – Load the pre-trained model in strict mode (checkpoint keys must match the model exactly)
labels – Labels made available in the training context (some transforms may need them)
disable_meta_tracking – Disable metadata tracking for faster training (do not set this if you rely on MetaTensor/batched transforms)
tracking – Tracking manager for experiment management (only 'mlflow' is supported)
tracking_uri – Tracking URI for experiment management
tracking_experiment_name – Name of the tracking experiment
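The interplay of `train_save_interval`, `n_saved` and `final_filename` can be sketched as a small simulation. This assumes (not confirmed by the text above) that the interval counts epochs, that only the newest `n_saved` interval checkpoints are kept, and that the final checkpoint is always written; the `checkpoint_epoch=N.pt` naming is purely hypothetical.

```python
from collections import deque
from typing import Deque, List


def checkpoint_schedule(
    max_epochs: int,
    save_interval: int,
    n_saved: int,
    final_filename: str = "checkpoint_final.pt",
) -> List[str]:
    """Return the checkpoint files left on disk after training.

    Hypothetical model of the documented parameters: one checkpoint every
    `save_interval` epochs, at most `n_saved` kept, plus a final checkpoint.
    """
    # A bounded deque discards the oldest entry once n_saved is exceeded.
    kept: Deque[str] = deque(maxlen=n_saved)
    for epoch in range(1, max_epochs + 1):
        if save_interval > 0 and epoch % save_interval == 0:
            kept.append(f"checkpoint_epoch={epoch}.pt")  # hypothetical name
    # The final checkpoint is written regardless of the interval.
    return list(kept) + [final_filename]


# 100 epochs, saving every 20 epochs, keeping only the 2 newest:
print(checkpoint_schedule(100, 20, 2))
```

Under these assumptions the defaults (`train_save_interval=20`, `n_saved=5`) would keep all five interval checkpoints of a 100-epoch run plus the final one.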
- TRAIN_METRIC_ACCURACY = 'train_acc'
- TRAIN_METRIC_MEAN_DICE = 'train_mean_dice'
- VAL_METRIC_ACCURACY = 'val_acc'
- VAL_METRIC_MEAN_DICE = 'val_mean_dice'
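These class attributes are plain string keys for metrics. As an illustration of "best key metric" selection (the model published to `publish_path` / `key_metric_filename`), here is a sketch that assumes per-epoch stats are stored in dicts keyed by these names and that a higher mean Dice is better. The `best_epoch` helper and the stats layout are hypothetical, not part of MONAI Label.

```python
from typing import Dict, List, Optional, Tuple

# Copies of the BasicTrainTask class attribute values listed above.
TRAIN_METRIC_MEAN_DICE = "train_mean_dice"
VAL_METRIC_MEAN_DICE = "val_mean_dice"


def best_epoch(
    history: List[Dict[str, float]], key: str = VAL_METRIC_MEAN_DICE
) -> Optional[Tuple[int, float]]:
    """Return (epoch, score) with the highest value of `key`, or None."""
    best: Optional[Tuple[int, float]] = None
    for stats in history:
        if key not in stats:
            continue  # this epoch has no value for the metric (e.g. not validated)
        score = stats[key]
        if best is None or score > best[1]:
            best = (int(stats["epoch"]), score)
    return best


history = [
    {"epoch": 1, "train_mean_dice": 0.60},
    {"epoch": 2, "train_mean_dice": 0.70, "val_mean_dice": 0.65},
    {"epoch": 3, "train_mean_dice": 0.80},
    {"epoch": 4, "train_mean_dice": 0.85, "val_mean_dice": 0.62},
]
print(best_epoch(history))  # best validation epoch under these dummy values
```

Note that the best validation epoch need not be the best training epoch, which is why a key metric is tracked separately from the latest checkpoint.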
- __init__(model_dir, description=None, config=None, amp=True, load_path=None, load_dict=None, publish_path=None, stats_path=None, train_save_interval=20, val_interval=1, n_saved=5, final_filename='checkpoint_final.pt', key_metric_filename='model.pt', model_dict_key='model', find_unused_parameters=False, load_strict=False, labels=None, disable_meta_tracking=False, tracking='mlflow', tracking_uri='', tracking_experiment_name=None)
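The `load_path`, `model_dict_key` and `load_strict` parameters concern initializing the network from a checkpoint. The sketch below models strict versus non-strict loading with plain dicts standing in for state dicts, assuming PyTorch-like semantics (strict mode rejects missing or unexpected keys; non-strict mode loads only matching keys) and assuming weights may be nested under `model_dict_key`. The `load_weights` helper is hypothetical.

```python
from typing import Any, Dict, List, Tuple


def load_weights(
    model: Dict[str, Any],
    checkpoint: Dict[str, Any],
    strict: bool,
    model_dict_key: str = "model",
) -> Tuple[List[str], List[str]]:
    """Copy matching entries from `checkpoint` into `model`.

    Hypothetical illustration of strict vs non-strict loading; returns the
    (missing, unexpected) key lists.
    """
    # Assumption: weights may be nested under model_dict_key; else use as-is.
    weights = checkpoint.get(model_dict_key, checkpoint)
    missing = sorted(k for k in model if k not in weights)
    unexpected = sorted(k for k in weights if k not in model)
    if strict and (missing or unexpected):
        # Strict mode fails before modifying the model.
        raise KeyError(f"missing={missing}, unexpected={unexpected}")
    for key in model:
        if key in weights:
            model[key] = weights[key]
    return missing, unexpected


model = {"conv.weight": [0.0], "fc.weight": [0.0]}
ckpt = {"model": {"conv.weight": [1.0], "head.weight": [2.0]}}
print(load_weights(model, ckpt, strict=False))
```

Non-strict loading (the documented default, `load_strict=False`) suits fine-tuning from a pre-trained network whose head differs, at the cost of silently leaving unmatched layers at their initial values.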