monailabel.tasks.scoring.tta module

class monailabel.tasks.scoring.tta.TTAScoring(model, network=None, deepedit=True, num_samples=5, spatial_size=None, spacing=None)[source]

Bases: monailabel.interfaces.tasks.scoring.ScoringMethod

First version of test-time augmentation (TTA) scoring for active learning.

post_transforms()[source]
pre_transforms()[source]
class monailabel.tasks.scoring.tta.TestTimeAugmentation(transform, batch_size, num_workers, inferrer_fn, device='gpu', image_key='image', label_key='label', meta_keys=None, meta_key_postfix='meta_dict', return_full_data=False, progress=True)[source]

Bases: object

Class for performing test time augmentations. This will pass the same image through the network multiple times.

The user passes transform(s) to be applied to each realisation, and provided that at least one of those transforms is random, the network’s output will vary. Provided that inverse transformations exist for all supplied spatial transforms, the inverse can be applied to each realisation of the network’s output. Once in the same spatial reference, the results can then be combined and metrics computed.

Test time augmentations are a useful feature for computing network uncertainty, as well as observing the network’s dependency on the applied random transforms.

Reference:

Wang et al., Aleatoric uncertainty estimation with test-time augmentation for medical image segmentation with convolutional neural networks, https://doi.org/10.1016/j.neucom.2019.01.103
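
The augment, infer, invert and combine loop described above can be sketched with NumPy alone. This is a minimal illustration, not the class's implementation: toy_model, random_flip, invert_flip and tta_metrics are hypothetical names, random flips stand in for an arbitrary invertible random transform, and the VVC is taken here as the standard deviation over the mean of the per-realisation foreground volumes, which is an assumption about the definition.

```python
import numpy as np


def random_flip(image, rng):
    """Flip each axis with probability 0.5; return the result and the flipped axes."""
    axes = [ax for ax in range(image.ndim) if rng.random() < 0.5]
    for ax in axes:
        image = np.flip(image, axis=ax)
    return image, axes


def invert_flip(pred, axes):
    """Undo random_flip (a flip is its own inverse)."""
    for ax in axes:
        pred = np.flip(pred, axis=ax)
    return pred


def toy_model(image):
    """Hypothetical stand-in for a network: a position-dependent threshold,
    so the output depends on the orientation of the input."""
    ramp = np.linspace(0.3, 0.7, image.shape[-1])
    return (image > ramp).astype(np.int64)


def tta_metrics(image, num_samples, rng):
    """Run num_samples augmented passes and combine them in the original frame."""
    outputs = []
    for _ in range(num_samples):
        augmented, axes = random_flip(image, rng)
        # Map each prediction back to the original spatial reference.
        outputs.append(invert_flip(toy_model(augmented), axes))
    full = np.stack(outputs)                  # shape [N, H, W]
    mean = full.mean(axis=0)                  # per-pixel mean
    std = full.std(axis=0)                    # per-pixel standard deviation
    mode = (mean >= 0.5).astype(np.int64)     # majority vote for binary labels (ties -> 1)
    # Assumed VVC definition: spread of the segmented volume across realisations.
    volumes = full.reshape(num_samples, -1).sum(axis=1)
    vvc = float(volumes.std() / volumes.mean()) if volumes.mean() > 0 else 0.0
    return mode, mean, std, vvc, full


rng = np.random.default_rng(0)
image = rng.random((8, 8))
mode, mean, std, vvc, full = tta_metrics(image, num_samples=6, rng=rng)
```

Pixels with a high std are those where the toy prediction changes with the augmentation, which is the kind of signal an uncertainty estimate builds on.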

Parameters
  • transform (InvertibleTransform) – transform (or composition of transforms) to be applied to each realisation. At least one transform must be of type Randomizable. All random transforms must be of type InvertibleTransform.

  • batch_size (int) – number of realisations to infer at once.

  • num_workers (int) – how many subprocesses to use for data loading.

  • inferrer_fn (Callable) – function to use to perform inference.

  • device (Union[str, device]) – device on which to perform inference.

  • image_key – key used to extract image from input dictionary.

  • label_key – key used to extract label from input dictionary.

  • meta_keys (Optional[str]) – explicitly indicates the key of the expected metadata dictionary. For example, for data with key label, the metadata is stored by default in label_meta_dict. The metadata is a dictionary object containing fields such as filename and original_shape. If None, the key is constructed as key_{meta_key_postfix}.

  • meta_key_postfix – use key_{postfix} to fetch the metadata for a given data key; the default is meta_dict. The metadata is a dictionary object. For example, for key image, affine matrices are read from and written to the affine field of the image_meta_dict dictionary. This argument only takes effect when meta_keys=None.

  • return_full_data (bool) – by default, the metrics (mode, mean, std, vvc) are returned, where vvc is the volume variation coefficient. Setting this flag to True returns the full data instead. Its dimensions are the same as when passing a single image through inferrer_fn, with an additional leading dimension of size num_examples (N), i.e., [N,C,H,W,[D]].

  • progress (bool) – whether to display a progress bar.
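
The fallback rule for locating the metadata dictionary (meta_keys if given, otherwise key_{meta_key_postfix}) can be illustrated as follows. resolve_meta_key is a hypothetical helper written only for this sketch, and the sample dictionary contents are made up; neither is part of the library.

```python
def resolve_meta_key(key, meta_keys=None, meta_key_postfix="meta_dict"):
    """Hypothetical helper: return the dictionary key under which metadata is expected."""
    if meta_keys is not None:
        # An explicit meta_keys value takes precedence over the postfix rule.
        return meta_keys
    return f"{key}_{meta_key_postfix}"


# Illustrative input dictionary; the field values are placeholders.
data = {
    "image": "image array goes here",
    "image_meta_dict": {"filename": "example.nii.gz", "original_shape": (8, 8)},
}

image_meta_key = resolve_meta_key("image")
label_meta_key = resolve_meta_key("label")
custom_meta_key = resolve_meta_key("image", meta_keys="my_meta")
image_meta = data[image_meta_key]
```

With the defaults, image resolves to image_meta_dict and label to label_meta_dict; meta_key_postfix is ignored as soon as meta_keys is supplied.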

Example

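from monai.transforms import Activations, AsDiscrete, Compose, RandAffined
from monailabel.tasks.scoring.tta import TestTimeAugmentation

# `keys`, `model`, `device` and `test_data` are assumed to be defined by the caller.
transform = RandAffined(keys, ...)
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

tt_aug = TestTimeAugmentation(
    transform, batch_size=5, num_workers=0, inferrer_fn=lambda x: post_trans(model(x)), device=device
)
# With return_full_data=False (the default), the four metrics are returned.
mode, mean, std, vvc = tt_aug(test_data)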