# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, Orientationd, ScaleIntensityRanged, Spacingd
from monailabel.interfaces.tasks.infer import InferTask, InferType
from monailabel.scribbles.transforms import (
AddBackgroundScribblesFromROId,
ApplyGraphCutOptimisationd,
MakeISegUnaryd,
MakeLikelihoodFromScribblesHistogramd,
)
from monailabel.transform.post import BoundingBoxd, Restored
class HistogramBasedGraphCut(InferTask):
    """
    Defines histogram-based GraphCut task for Generic segmentation from the following paper:

    Wang, Guotai, et al. "Interactive medical image segmentation using deep learning with image-specific fine tuning."
    IEEE transactions on medical imaging 37.7 (2018): 1562-1573. (preprint: https://arxiv.org/pdf/1710.04043.pdf)

    This task takes as input 1) the original image volume and 2) scribbles from the user
    indicating foreground and background regions. A likelihood volume is generated from the scribbles
    using a histogram method. User scribbles are incorporated using Equation 7 on page 4 of the paper.
    SimpleCRF's GraphCut layer is used to optimise Equation 5 from the paper, where the unaries come from
    Equation 7 and the pairwise term is the (intensity-scaled) input volume.

    Args:
        dimension: spatial dimensionality of the task, forwarded to ``InferTask``.
        description: human-readable task description, forwarded to ``InferTask``.
        intensity_range: ``(a_min, a_max, b_min, b_max)`` or ``(a_min, a_max, b_min, b_max, clip)`` used by
            ``ScaleIntensityRanged`` on the image; ``clip`` defaults to ``True`` when omitted.
        pix_dim: voxel spacing that image and scribbles are resampled to by ``Spacingd``.
        lamda: forwarded unchanged to ``ApplyGraphCutOptimisationd``.
        sigma: forwarded unchanged to ``ApplyGraphCutOptimisationd``.
        scribbles_bg_label: value in the ``label`` volume that marks background scribbles.
        scribbles_fg_label: value in the ``label`` volume that marks foreground scribbles.

    Raises:
        ValueError: if ``intensity_range`` does not have 4 or 5 elements.
    """

    def __init__(
        self,
        dimension=3,
        description="A post processing step with histogram-based GraphCut for Generic segmentation",
        intensity_range=(-300, 200, 0.0, 1.0),
        pix_dim=(2.5, 2.5, 5.0),
        lamda=1.0,
        sigma=0.1,
        scribbles_bg_label=2,
        scribbles_fg_label=3,
    ):
        super().__init__(
            path=None, network=None, labels=None, type=InferType.SCRIBBLES, dimension=dimension, description=description
        )
        # Validate early so a malformed range fails at construction, not deep inside the transform chain.
        if len(intensity_range) not in (4, 5):
            raise ValueError(
                "intensity_range must be (a_min, a_max, b_min, b_max) or (a_min, a_max, b_min, b_max, clip), "
                f"got {intensity_range!r}"
            )
        self.intensity_range = intensity_range
        self.pix_dim = pix_dim
        self.lamda = lamda
        self.sigma = sigma
        self.scribbles_bg_label = scribbles_bg_label
        self.scribbles_fg_label = scribbles_fg_label

    def pre_transforms(self):
        """
        Build the pre-processing chain that produces the keys ``inferer`` consumes.

        Loads ``image`` and ``label`` (scribbles), moves channels first, adds background scribbles
        from the ROI, resamples to ``self.pix_dim``, reorients to RAS, rescales image intensities with
        ``self.intensity_range`` and finally writes the histogram-based likelihood into ``prob``.
        """
        a_min, a_max, b_min, b_max = self.intensity_range[:4]
        # Optional 5th element controls clipping; clip by default.
        clip = self.intensity_range[4] if len(self.intensity_range) > 4 else True
        return [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            AddBackgroundScribblesFromROId(
                scribbles="label",
                scribbles_bg_label=self.scribbles_bg_label,
                scribbles_fg_label=self.scribbles_fg_label,
            ),
            # NOTE(review): the coarse default spacing presumably keeps GraphCut runtime manageable — confirm.
            # Scribbles are resampled with "nearest" so their label values are preserved.
            Spacingd(keys=["image", "label"], pixdim=self.pix_dim, mode=["bilinear", "nearest"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            ScaleIntensityRanged(keys="image", a_min=a_min, a_max=a_max, b_min=b_min, b_max=b_max, clip=clip),
            # Produces the "prob" key that MakeISegUnaryd reads as its logits in inferer().
            MakeLikelihoodFromScribblesHistogramd(
                image="image",
                scribbles="label",
                post_proc_label="prob",
                scribbles_bg_label=self.scribbles_bg_label,
                scribbles_fg_label=self.scribbles_fg_label,
                normalise=True,
            ),
        ]

    def inferer(self):
        """
        Return the optimisation pipeline: build ISeg unaries from ``prob`` + scribbles, then run GraphCut.

        The GraphCut result is written to ``pred``.
        """
        return Compose(
            [
                # unary term maker (Equation 7): combines the likelihood with the user scribbles
                MakeISegUnaryd(
                    image="image",
                    logits="prob",
                    scribbles="label",
                    unary="unary",
                    scribbles_bg_label=self.scribbles_bg_label,
                    scribbles_fg_label=self.scribbles_fg_label,
                ),
                # optimiser (Equation 5): pairwise term comes from the pre-processed image
                ApplyGraphCutOptimisationd(
                    unary="unary",
                    pairwise="image",
                    post_proc_label="pred",
                    lamda=self.lamda,
                    sigma=self.sigma,
                ),
            ]
        )

    def post_transforms(self):
        """Restore ``pred`` relative to ``image`` and compute its bounding box into ``result``/``bbox``."""
        return [
            Restored(keys="pred", ref_image="image"),
            BoundingBoxd(keys="pred", result="result", bbox="bbox"),
        ]