Source code for monailabel.tasks.train.basic_train

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import logging
import os
import platform
import tempfile
import time
from abc import abstractmethod
from datetime import datetime
from typing import Any, List

import ignite
import torch
import torch.distributed
from ignite.engine import Events
from ignite.handlers import EarlyStopping
from monai.data import (
    CacheDataset,
    DataLoader,
    Dataset,
    PersistentDataset,
    SmartCacheDataset,
    ThreadDataLoader,
    get_track_meta,
    partition_dataset,
    set_track_meta,
)
from monai.engines import SupervisedEvaluator, SupervisedTrainer
from monai.handlers import (
    CheckpointLoader,
    CheckpointSaver,
    LrScheduleHandler,
    MeanDice,
    MLFlowHandler,
    StatsHandler,
    TensorBoardStatsHandler,
    ValidationHandler,
    from_engine,
    stopping_fn_from_metric,
)
from monai.inferers import SimpleInferer
from monai.transforms import Compose

from monailabel.config import settings
from monailabel.interfaces.datastore import Datastore
from monailabel.interfaces.tasks.train import TrainTask
from monailabel.tasks.train.handler import PublishStatsAndModel, prepare_stats
from monailabel.utils.others.generic import device_list, name_to_device, path_to_uri, remove_file

logger = logging.getLogger(__name__)


class Context:
    def __init__(self):
        self.start_ts = 0  # timestamp
        self.run_id = None  # unique run_id
        self.output_dir = None  # output dir for storing model
        self.cache_dir = None  # cache dir for saving/caching temp data
        self.events_dir = None  # events dir for storing tensorboard events
        self.datalist = None  # input datalist
        self.train_datalist = None  # train datalist
        self.train_batch_size = None  # train batch size
        self.val_datalist = None  # validation datalist
        self.val_batch_size = None  # validation batch size
        self.device = None  # device on which training will run
        self.network = None  # network
        self.optimizer = None  # optimizer
        self.dataset_type = "CacheDataset"  # dataset type
        self.dataloader_type = "ThreadDataLoader"  # dataloader type
        self.pretrained = False  # using pretrained model
        self.max_epochs = 1  # max epochs to run training
        self.multi_gpu = False  # multi gpu enabled
        self.local_rank = 0  # local rank in case of multi gpu
        self.world_size = 0  # world size in case of multi gpu
        self.request = None
        self.trainer = None
        self.evaluator = None
        self.tracking = None
        self.tracking_uri = None
        self.tracking_experiment_name = None
        self.tracking_run_name = None


class BasicTrainTask(TrainTask):
    """
    This provides Basic Train Task to train a model using SupervisedTrainer and SupervisedEvaluator from MONAI
    """

    TRAIN_METRIC_MEAN_DICE = "train_mean_dice"
    VAL_METRIC_MEAN_DICE = "val_mean_dice"
    TRAIN_METRIC_ACCURACY = "train_acc"
    VAL_METRIC_ACCURACY = "val_acc"

    def __init__(
        self,
        model_dir,
        description=None,
        config=None,
        amp=True,
        load_path=None,
        load_dict=None,
        publish_path=None,
        stats_path=None,
        train_save_interval=20,
        val_interval=1,
        n_saved=5,
        final_filename="checkpoint_final.pt",
        key_metric_filename="model.pt",
        model_dict_key="model",
        find_unused_parameters=False,
        load_strict=False,
        labels=None,
        disable_meta_tracking=False,
        tracking="mlflow" if settings.MONAI_LABEL_TRACKING_ENABLED else None,
        tracking_uri=settings.MONAI_LABEL_TRACKING_URI,
        tracking_experiment_name=None,
    ):
        """
        :param model_dir: Base model dir to save the model checkpoints, events, etc.
        :param description: Description for this task
        :param config: K,V pairs to be part of user config
        :param amp: Enable AMP for training
        :param load_path: Initialize model from existing checkpoint (pre-trained)
        :param load_dict: Provide dictionary to load from checkpoint.  If None, then `net` will be loaded
        :param publish_path: Publish path for best trained model (based on best key metric)
        :param stats_path: Path to save the train stats
        :param train_save_interval: checkpoint save interval for training
        :param val_interval: validation interval (run every x epochs)
        :param n_saved: max checkpoints to save
        :param final_filename: name of final checkpoint that will be saved
        :param key_metric_filename: best key metric model file name
        :param model_dict_key: key to save network weights into checkpoint
        :param find_unused_parameters: Applicable for DDP/Multi GPU training
        :param load_strict: Load pre-trained model in strict mode
        :param labels: Labels to be used as part of training context (some transforms might need them)
        :param disable_meta_tracking: Disable meta tracking for faster training (unless you are using MetaTensor/batched transforms)
        :param tracking: Tracking Manager for Experiment Management (only 'mlflow' is supported)
        :param tracking_uri: Tracking URI for Experiment Management
        :param tracking_experiment_name: Name for tracking experiment
        """
        super().__init__(description)

        self._model_dir = model_dir
        self._amp = amp
        self._config = {
            "name": "train_01",
            "pretrained": True,
            "device": device_list(),
            "max_epochs": 50,
            "early_stop_patience": -1,
            "val_split": 0.2,
            "train_batch_size": 1,
            "val_batch_size": 1,
            "multi_gpu": True,
            "gpus": "all",
            "dataset": ["SmartCacheDataset", "CacheDataset", "PersistentDataset", "Dataset"],
            "dataloader": ["ThreadDataLoader", "DataLoader"],
            "tracking": ["mlflow", "None"] if settings.MONAI_LABEL_TRACKING_ENABLED else ["None", "mlflow"],
            "tracking_uri": tracking_uri if tracking_uri else "",
            "tracking_experiment_name": "",
        }
        if config:
            self._config.update(config)

        self._load_path = load_path
        self._load_dict = load_dict
        self._publish_path = publish_path
        self._stats_path = stats_path if stats_path else os.path.join(model_dir, "train_stats.json")

        self._train_save_interval = train_save_interval
        self._val_interval = val_interval
        self._n_saved = n_saved
        self._final_filename = final_filename
        self._key_metric_filename = key_metric_filename
        self._model_dict_key = model_dict_key
        self._find_unused_parameters = find_unused_parameters
        self._load_strict = load_strict
        self._labels = [] if labels is None else [labels] if isinstance(labels, str) else labels
        self._disable_meta_tracking = disable_meta_tracking
        self._tracking = tracking
        self._tracking_uri = tracking_uri
        self._tracking_experiment_name = tracking_experiment_name

    def info(self):
        r = super().info()
        if self._labels:
            r["labels"] = self._labels
        return r

    @abstractmethod
    def network(self, context: Context):
        pass

    @abstractmethod
    def optimizer(self, context: Context):
        pass

    @abstractmethod
    def loss_function(self, context: Context):
        pass

    def lr_scheduler_handler(self, context: Context):
        # lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(context.optimizer, mode="min")
        # return LrScheduleHandler(lr_scheduler, print_lr=True, step_transform=lambda x: x.state.output[0]["loss"])
        lr_scheduler = torch.optim.lr_scheduler.StepLR(context.optimizer, step_size=1000, gamma=0.1)
        return LrScheduleHandler(lr_scheduler, print_lr=True)

    def _dataset(self, context, datalist, is_train, replace_rate=0.25):
        if context.multi_gpu:
            world_size = torch.distributed.get_world_size()
            if len(datalist) // world_size:  # every gpu gets full data when datalist is smaller
                datalist = partition_dataset(data=datalist, num_partitions=world_size, even_divisible=True)[
                    context.local_rank
                ]

        transforms = (
            self._validate_transforms(self.train_pre_transforms(context), "Training", "pre")
            if is_train
            else self._validate_transforms(self.val_pre_transforms(context), "Validation", "pre")
        )
        dataset = (
            CacheDataset(datalist, transforms)
            if context.dataset_type == "CacheDataset"
            else (
                SmartCacheDataset(datalist, transforms, replace_rate)
                if context.dataset_type == "SmartCacheDataset"
                else (
                    PersistentDataset(datalist, transforms, cache_dir=os.path.join(context.cache_dir, "pds"))
                    if context.dataset_type == "PersistentDataset"
                    else Dataset(datalist, transforms)
                )
            )
        )
        return dataset, datalist

    def _dataloader(self, context, dataset, batch_size, num_workers, shuffle=False):
        if context.dataloader_type == "ThreadDataLoader":
            return ThreadDataLoader(
                dataset=dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
            )
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    def train_data_loader(self, context, num_workers=0, shuffle=True):
        dataset, datalist = self._dataset(context, context.train_datalist, is_train=True)
        logger.info(f"{context.local_rank} - Records for Training: {len(datalist)}")
        logger.debug(f"{context.local_rank} - Training: {datalist}")
        return self._dataloader(context, dataset, context.train_batch_size, num_workers, shuffle)

    def train_inferer(self, context: Context):
        return SimpleInferer()

    def train_key_metric(self, context: Context):
        return {
            self.TRAIN_METRIC_MEAN_DICE: MeanDice(
                output_transform=from_engine(["pred", "label"]),
                include_background=False,
            )
        }

    def load_path(self, output_dir, pretrained=True):
        load_path = os.path.join(output_dir, self._key_metric_filename)
        if not os.path.exists(load_path) and pretrained:
            load_path = self._load_path
        return load_path

    def train_handlers(self, context: Context):
        handlers: List[Any] = []

        # LR Scheduler
        lr_scheduler = self.lr_scheduler_handler(context)
        if lr_scheduler:
            handlers.append(lr_scheduler)

        if context.local_rank == 0:
            handlers.extend(
                [
                    StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
                    TensorBoardStatsHandler(
                        log_dir=context.events_dir,
                        tag_name="train_loss",
                        output_transform=from_engine(["loss"], first=True),
                    ),
                ]
            )

            if context.tracking and context.tracking.lower() == "mlflow":
                handlers.append(
                    MLFlowHandler(
                        tracking_uri=context.tracking_uri,
                        experiment_name=context.tracking_experiment_name,
                        run_name=context.tracking_run_name,
                        iteration_log=True,
                        output_transform=from_engine(["loss"], first=True),
                        close_on_complete=True,
                    )
                )

        if context.evaluator:
            logger.info(f"{context.local_rank} - Adding Validation to run every '{self._val_interval}' interval")
            handlers.append(ValidationHandler(self._val_interval, validator=context.evaluator, epoch_level=True))

        return handlers

    def train_additional_metrics(self, context: Context):
        return None

    def val_data_loader(self, context: Context, num_workers=0):
        dataset, datalist = self._dataset(context, context.val_datalist, is_train=False)
        logger.info(f"{context.local_rank} - Records for Validation: {len(datalist)}")
        logger.debug(f"{context.local_rank} - Validation: {datalist}")
        return self._dataloader(context, dataset, context.val_batch_size, num_workers)

    def val_pre_transforms(self, context: Context):
        return self.train_pre_transforms(context)

    def val_post_transforms(self, context: Context):
        return self.train_post_transforms(context)

    def val_handlers(self, context: Context):
        handlers = [
            StatsHandler(output_transform=lambda x: None, iteration_log=False),
            TensorBoardStatsHandler(log_dir=context.events_dir, output_transform=lambda x: None, iteration_log=False),
        ]
        if context.tracking and context.tracking.lower() == "mlflow":
            handlers.append(
                MLFlowHandler(
                    tracking_uri=context.tracking_uri,
                    experiment_name=context.tracking_experiment_name,
                    run_name=context.tracking_run_name,
                    iteration_log=False,
                    close_on_complete=True,
                )
            )
        return handlers if context.local_rank == 0 else None

    def val_key_metric(self, context):
        return {
            self.VAL_METRIC_MEAN_DICE: MeanDice(
                output_transform=from_engine(["pred", "label"]),
                include_background=False,
            )
        }

    def train_iteration_update(self, context: Context):
        return None

    def val_iteration_update(self, context: Context):
        return None

    def event_names(self, context: Context):
        return None

    def val_additional_metrics(self, context: Context):
        return None

    @abstractmethod
    def train_pre_transforms(self, context: Context):
        pass

    @abstractmethod
    def train_post_transforms(self, context: Context):
        pass

    @abstractmethod
    def val_inferer(self, context: Context):
        pass

    def _load_external_ds(self, ds):
        if ds and isinstance(ds, str) and os.path.exists(ds):
            with open(ds) as fp:
                ds = json.load(fp)
        return ds

    def partition_datalist(self, context: Context, shuffle=False):
        # the user can pass an external validation/training datalist in the request
        val_datalist = self._load_external_ds(context.request.get("val_ds"))
        train_datalist = self._load_external_ds(context.request.get("train_ds", context.datalist))

        if not val_datalist:
            val_split = context.request.get("val_split", 0.0)
            if val_split > 0.0:
                train_datalist, val_datalist = partition_dataset(
                    train_datalist, ratios=[(1 - val_split), val_split], shuffle=shuffle
                )
            else:
                train_datalist = context.datalist
                val_datalist = []

        if context.local_rank == 0:
            logger.info(f"Total Records for Training: {len(train_datalist)}")
            logger.info(f"Total Records for Validation: {len(val_datalist)}")
        return train_datalist, val_datalist

    def stats(self):
        if self._stats_path and os.path.exists(self._stats_path):
            with open(self._stats_path) as fc:
                return json.load(fc)
        return {}

    def config(self):
        return self._config

    @staticmethod
    def _validate_transforms(transforms, step="Training", name="pre"):
        if not transforms or isinstance(transforms, Compose) or callable(transforms):
            return transforms
        if isinstance(transforms, list):
            return Compose(transforms)
        raise ValueError(f"{step} {name}-transforms are not of `list` or `Compose` type")

    def __call__(self, request, datastore: Datastore):
        logger.info(f"Train Request (input): {request}")

        req = copy.deepcopy(self._config)
        req.update(copy.deepcopy(request))
        req["run_id"] = datetime.now().strftime("%Y%m%d_%H%M%S")

        device = name_to_device(req.get("device", "cuda"))
        req["device"] = device

        multi_gpu = req["multi_gpu"]
        multi_gpus = req.get("gpus", "all")
        world_size = torch.cuda.device_count() if not multi_gpus or multi_gpus == "all" else len(multi_gpus.split(","))
        logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")

        datalist = self.pre_process(req, datastore)

        if multi_gpu and world_size < 2:
            logger.info("Distributed/Multi GPU is limited")
            multi_gpu = False
            req["multi_gpu"] = False

        if multi_gpu:
            logger.info("Distributed/Multi GPU Training = TRUE")
            tfile = tempfile.NamedTemporaryFile().name
            if any(platform.win32_ver()):
                req["distributed_backend"] = "gloo"
                req["distributed_url"] = f"file://{tfile}"
            logger.info(f"Total processes to spawn: {world_size}")
            torch.multiprocessing.spawn(main_worker, nprocs=world_size, args=(world_size, req, datalist, self))
            remove_file(tfile)
        else:
            logger.info("Distributed Training = FALSE")
            res = self.train(0, world_size, req, datalist)
            self.cleanup(req)
            return res

        self.cleanup(req)
        if os.path.exists(self._stats_path):
            with open(self._stats_path) as f:
                return json.load(f)
        return {}

    def train(self, rank, world_size, request, datalist):
        start_ts = time.time()

        context: Context = Context()
        context.start_ts = start_ts
        context.request = request
        context.datalist = datalist
        context.local_rank = rank
        context.world_size = world_size

        context.run_id = request["run_id"]
        context.multi_gpu = request["multi_gpu"]
        if context.multi_gpu:
            os.environ["LOCAL_RANK"] = str(context.local_rank)

        logger.info(f"{context.local_rank} - Train Request (final): {request}")

        if context.multi_gpu:
            distributed_backend = context.request.get("distributed_backend", "nccl")
            distributed_url = context.request.get("distributed_url", "env://")
            torch.distributed.init_process_group(
                backend=distributed_backend,
                init_method=distributed_url,
                world_size=context.world_size,
                rank=context.local_rank,
            )
            ignite.distributed.set_local_rank(rank)
            ignite.distributed.sync()

        context.device = self._device(context)
        context.max_epochs = request["max_epochs"]
        context.train_batch_size = request["train_batch_size"]
        context.val_batch_size = request["val_batch_size"]
        context.pretrained = request["pretrained"]
        context.dataset_type = request["dataset"]
        context.dataloader_type = request["dataloader"]

        name = request["name"]
        context.output_dir = os.path.join(self._model_dir, name)
        context.cache_dir = os.path.join(context.output_dir, f"cache_{context.run_id}")
        context.events_dir = os.path.join(context.output_dir, f"events_{context.run_id}")
        logger.info(f"Run/Output Path: {context.output_dir}")

        tracking_uri = request.get("tracking_uri", self._tracking_uri)
        if not tracking_uri:
            tracking_uri = path_to_uri(os.path.join(context.output_dir, "mlruns"))

        experiment_name = request.get("tracking_experiment_name")
        experiment_name = experiment_name if experiment_name else request.get("model")
        run_name = request.get("tracking_run_name")
        run_name = run_name if run_name else f"run_{context.run_id}"

        context.tracking = request.get("tracking", self._tracking)
        context.tracking = context.tracking[0] if isinstance(context.tracking, list) else context.tracking
        context.tracking_uri = tracking_uri
        context.tracking_experiment_name = experiment_name
        context.tracking_run_name = run_name

        logger.info(f"Tracking: {context.tracking}")
        logger.info(f"Tracking URI: {context.tracking_uri}")
        logger.info(f"Tracking Experiment Name: {experiment_name}; Run Name: {run_name}")

        if not os.path.exists(context.output_dir):
            os.makedirs(context.output_dir, exist_ok=True)

        context.train_datalist, context.val_datalist = self.partition_datalist(context)
        context.network, context.optimizer = self._create_network_and_optimizer(context)

        context.evaluator = self._create_evaluator(context)
        context.trainer = self._create_trainer(context)

        # Finalize and run training
        self.finalize(context)

        # Disable MetaTensor meta tracking (if requested) and restore it afterwards
        meta_tracking = get_track_meta()
        if self._disable_meta_tracking:
            set_track_meta(False)

        try:
            context.trainer.run()
        finally:
            set_track_meta(meta_tracking)

        # In case of same process (restore)
        if context.multi_gpu:
            torch.distributed.destroy_process_group()

        # Try to clear cuda cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return prepare_stats(start_ts, context.trainer, context.evaluator)

    def finalize(self, context):
        if context.local_rank == 0:
            publisher = PublishStatsAndModel(
                self._stats_path,
                self._publish_path,
                self._key_metric_filename,
                context.start_ts,
                context.run_id,
                context.output_dir,
                context.trainer,
                context.evaluator,
            )
            if context.evaluator:
                context.evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=publisher)
            else:
                context.trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=publisher)

        early_stop_patience = int(context.request.get("early_stop_patience", 0))
        if early_stop_patience > 0 and context.evaluator:
            kw = self.val_key_metric(context)
            metric_name = list(kw.keys())[0] if kw else None
            if metric_name:
                early_stopper = EarlyStopping(
                    patience=early_stop_patience,
                    score_function=stopping_fn_from_metric(metric_name),
                    trainer=context.trainer,
                )
                context.evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
            else:
                logger.warning("No Validation Key Metric has been defined to enable Early Stopper")

    def pre_process(self, request, datastore: Datastore):
        return datastore.datalist()

    def get_cache_dir(self, request):
        run_id = request["run_id"]
        output_dir = os.path.join(self._model_dir, request["name"])
        return os.path.join(output_dir, f"cache_{run_id}")

    def cleanup(self, request):
        logger.info("Running cleanup...")

        # delete/cleanup cache
        remove_file(self.get_cache_dir(request))

    def _device(self, context: Context):
        if context.multi_gpu:
            gpus = context.request.get("gpus", "all")
            multi_gpus = list(range(context.world_size)) if gpus == "all" else [int(g) for g in gpus.split(",")]
            gpu = multi_gpus[context.local_rank]
            logger.info(f"++++ Rank:{context.local_rank} => Using GPU-{gpu}")
            device = torch.device(f"cuda:{gpu}")
            torch.cuda.set_device(device)
        else:
            device = torch.device(context.request["device"] if torch.cuda.is_available() else "cpu")

        logger.info(f"{context.local_rank} - Using Device: {device}; IDX: {device.index}")
        return device

    def _create_network_and_optimizer(self, context: Context):
        network = self.network(context).to(context.device)

        # optimizer needs network parameters
        context.network = network
        optimizer = self.optimizer(context)

        if context.multi_gpu:
            network = torch.nn.parallel.DistributedDataParallel(
                network,
                device_ids=[context.device.index],
                output_device=context.device.index,
                find_unused_parameters=self._find_unused_parameters,
            )
        return network, optimizer

    def _create_evaluator(self, context: Context):
        evaluator = None
        if context.val_datalist and len(context.val_datalist) > 0:
            val_handlers: List = self.val_handlers(context)
            if context.local_rank == 0:
                val_handlers.append(
                    CheckpointSaver(
                        save_dir=context.output_dir,
                        save_dict={self._model_dict_key: context.network},
                        save_key_metric=True,
                        key_metric_filename=self._key_metric_filename,
                        n_saved=self._n_saved,
                    )
                )

            evaluator = SupervisedEvaluator(
                device=context.device,
                val_data_loader=self.val_data_loader(context),
                network=context.network,
                inferer=self.val_inferer(context),
                postprocessing=self._validate_transforms(self.val_post_transforms(context), "Validation", "post"),
                key_val_metric=self.val_key_metric(context),
                additional_metrics=self.val_additional_metrics(context),
                val_handlers=val_handlers,
                iteration_update=self.val_iteration_update(context),
                event_names=self.event_names(context),
            )
        return evaluator

    def _create_trainer(self, context: Context):
        train_handlers: List = self.train_handlers(context)
        if context.local_rank == 0:
            train_handlers.append(
                CheckpointSaver(
                    save_dir=context.output_dir,
                    save_dict={self._model_dict_key: context.network},
                    save_interval=self._train_save_interval,
                    save_final=True,
                    final_filename=self._final_filename,
                    save_key_metric=True,
                    key_metric_filename=(
                        f"train_{self._key_metric_filename}" if context.evaluator else self._key_metric_filename
                    ),
                    n_saved=self._n_saved,
                )
            )

        self._load_checkpoint(context, train_handlers)

        return SupervisedTrainer(
            device=context.device,
            max_epochs=context.max_epochs,
            train_data_loader=self.train_data_loader(context),
            network=context.network,
            optimizer=context.optimizer,
            loss_function=self.loss_function(context),
            inferer=self.train_inferer(context),
            amp=self._amp,
            postprocessing=self._validate_transforms(self.train_post_transforms(context), "Training", "post"),
            key_train_metric=self.train_key_metric(context),
            additional_metrics=self.train_additional_metrics(context),
            train_handlers=train_handlers,
            iteration_update=self.train_iteration_update(context),
            event_names=self.event_names(context),
        )

    def _load_checkpoint(self, context, train_handlers):
        load_path = self.load_path(context.output_dir, context.pretrained)
        if load_path and os.path.exists(load_path):
            logger.info(f"{context.local_rank} - Load Path {load_path}")

            load_dict = {self._model_dict_key: context.network} if self._load_dict is None else self._load_dict
            map_location = {"cuda:0": f"cuda:{context.device.index}"} if context.multi_gpu else None
            train_handlers.append(
                CheckpointLoader(load_path, load_dict, map_location=map_location, strict=self._load_strict)
            )


def main_worker(rank, world_size, request, datalist, task: BasicTrainTask):
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )
    logger.info(f"Main Worker: {rank}")
    task.train(rank, world_size, request, datalist)

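
BasicTrainTask is abstract: a concrete app subclasses it and supplies the network, optimizer, loss, pre/post transforms, and validation inferer, while the base class handles datasets, handlers, checkpointing, multi-GPU, and tracking. Below is a minimal illustrative sketch (not part of this module) of such a subclass, assuming a small 3D UNet with one foreground label; the class name, network configuration, and transforms are example choices, not MONAI Label defaults.

import torch
from monai.inferers import SlidingWindowInferer
from monai.losses import DiceCELoss
from monai.networks.nets import UNet
from monai.transforms import Activationsd, AsDiscreted, EnsureChannelFirstd, EnsureTyped, LoadImaged, ScaleIntensityd

from monailabel.tasks.train.basic_train import BasicTrainTask, Context


class MySegmentationTrainTask(BasicTrainTask):  # hypothetical example subclass
    def __init__(self, model_dir, **kwargs):
        # example network: 3D UNet, 1 input channel, background + 1 foreground class
        self._network = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128),
            strides=(2, 2, 2),
            num_res_units=2,
        )
        super().__init__(model_dir, description="Example segmentation train task", **kwargs)

    def network(self, context: Context):
        return self._network

    def optimizer(self, context: Context):
        # context.network is set before optimizer() is called (see _create_network_and_optimizer)
        return torch.optim.AdamW(context.network.parameters(), lr=1e-4)

    def loss_function(self, context: Context):
        return DiceCELoss(to_onehot_y=True, softmax=True)

    def train_pre_transforms(self, context: Context):
        # a list is accepted; _validate_transforms wraps it into Compose
        return [
            LoadImaged(keys=("image", "label")),
            EnsureChannelFirstd(keys=("image", "label")),
            ScaleIntensityd(keys="image"),
            EnsureTyped(keys=("image", "label")),
        ]

    def train_post_transforms(self, context: Context):
        return [
            Activationsd(keys="pred", softmax=True),
            AsDiscreted(keys=("pred", "label"), argmax=(True, False), to_onehot=(2, 2)),
        ]

    def val_inferer(self, context: Context):
        return SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=2, overlap=0.25)

A MONAI Label app would typically construct such a task with its model directory and expose it through the app's trainer registry; an incoming train request (for example {"max_epochs": 10, "dataset": "CacheDataset", "multi_gpu": False}) is then merged with the task's default config and handled by BasicTrainTask.__call__ above.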