# Source code for monai.metrics.cumulative_average
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from monai.transforms import isnan
from monai.utils import convert_data_type
from .metric import Cumulative
class CumulativeAverage(Cumulative):
    """
    Cumulatively record data values and aggregate them into an average value.

    It supports single-class or multi-class data; for example, a value can be
    0.44 (a loss value) or [0.3, 0.4] (metrics of two classes). NaN entries are
    excluded from both the running sum and the running count.
    It also supports distributed data parallel: data is synced when aggregating.

    For example, record loss values and print the average of the values recorded
    since the last ``reset()`` once every 5 iterations:

    .. code-block:: python

        average = CumulativeAverage()
        for i, d in enumerate(dataloader):
            loss = ...
            average.append(loss)
            if (i + 1) % 5 == 0:
                print(f"cumulative average of loss: {average.aggregate()}")
                average.reset()

    """

    def __init__(self) -> None:
        super().__init__()
        # Running element-wise sum of the non-NaN values folded in by aggregate().
        self.sum = None
        # Running element-wise count of the non-NaN values folded in by aggregate().
        self.not_nans = None

    def reset(self) -> None:
        """
        Reset all the running status, including buffers, sum, not-NaN count, etc.
        """
        super().reset()
        self.sum = None
        self.not_nans = None

    def aggregate(self):
        """
        Sync data from all the ranks and compute the average value with previous sum value.

        Values appended since the previous call are reduced along the first (batch)
        dimension and folded into the running sum and not-NaN count. The buffer is
        then cleared, so repeated calls keep averaging over everything recorded since
        the last ``reset()``.

        Returns:
            The running sum divided by the running not-NaN count, element-wise for
            multi-class data. An element whose recorded values were all NaN
            evaluates to 0 / 0 (NaN for floating-point data).

        Raises:
            ValueError: if no data has been recorded since the last ``reset()``.
        """
        data = self.get_buffer()
        if data is None:
            # NOTE(review): assumes get_buffer() yields None when nothing was appended
            # since the buffer was last cleared (e.g. by a previous aggregate() call);
            # confirm against Cumulative.get_buffer.
            if self.sum is None:
                raise ValueError("no data has been recorded since the last reset().")
            # Nothing new to fold in: the running average is unchanged.
            return self.sum / self.not_nans

        # Count the non-NaN entries and sum the values along the batch dimension,
        # zeroing NaNs so they contribute nothing to the sum.
        nans = isnan(data)
        batch_not_nans = convert_data_type((~nans), dtype=torch.float32)[0].sum(0)
        data[nans] = 0
        batch_sum = data.sum(0)

        # Clear the buffer so the next call only sees newly appended values.
        super().reset()

        self.sum = batch_sum if self.sum is None else (self.sum + batch_sum)
        self.not_nans = batch_not_nans if self.not_nans is None else (self.not_nans + batch_not_nans)
        return self.sum / self.not_nans