Source code for monai.visualize.utils

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
import torch

from monai.config.type_definitions import DtypeLike, NdarrayOrTensor
from monai.transforms.croppad.array import SpatialPad
from monai.transforms.utils import rescale_array
from monai.transforms.utils_pytorch_numpy_unification import repeat
from monai.utils.module import optional_import
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type

# Static type checkers take the first branch and see the real matplotlib
# modules; at runtime the same two names are bound through `optional_import`.
# NOTE(review): presumably `optional_import` defers the ImportError until the
# returned object is actually used, so this module can be imported without
# matplotlib installed -- confirm against `monai.utils.module`.
if TYPE_CHECKING:
    from matplotlib import cm
    from matplotlib import pyplot as plt
else:
    # `optional_import` returns a pair; only the first element is kept here.
    plt, _ = optional_import("matplotlib", name="pyplot")
    cm, _ = optional_import("matplotlib", name="cm")

# Names exported by `from monai.visualize.utils import *`.
__all__ = ["matshow3d", "blend_images"]


def matshow3d(
    volume: NdarrayOrTensor,
    fig: Any = None,
    title: str | None = None,
    figsize: tuple[int, int] = (10, 10),
    frames_per_row: int | None = None,
    frame_dim: int = -3,
    channel_dim: int | None = None,
    vmin: float | None = None,
    vmax: float | None = None,
    every_n: int = 1,
    interpolation: str = "none",
    show: bool = False,
    fill_value: Any = np.nan,
    margin: int = 1,
    dtype: DtypeLike = np.float32,
    **kwargs: Any,
) -> tuple[Any, np.ndarray]:
    """
    Draw the frames of a volume side by side as a single tiled 2D image.

    Args:
        volume: 3D volume to display. data shape can be `BCHWD`, `CHWD` or `HWD`.
            Higher dimensional arrays will be reshaped into (-1, H, W, [C]), `C` depends on `channel_dim` arg.
            A list of channel-first (C, H[, W, D]) arrays can also be passed in,
            in which case they will be displayed as a padded and stacked volume.
        fig: matplotlib figure or Axes to draw on. If None, a new figure is created.
            If a figure without axes is given, a single subplot is added to it.
        title: title of the figure; not set when None.
        figsize: size of the figure, applied when `fig` has a `set_size_inches` method.
        frames_per_row: number of frames to display in each row.
            If None (or 0), ceil(sqrt(number of displayed frames)) is used.
        frame_dim: for higher dimensional arrays, which dimension from (`-1`, `-2`, `-3`) is moved
            to the `-3` dimension before reshaping to (-1, H, W) to construct frames, default to `-3`.
        channel_dim: if not None, explicitly specify the channel dimension to be transposed to the
            last dimension as shape (-1, H, W, C). this can be used to plot RGB color image.
            if None, the channel dimension will be flattened with `frame_dim` and `batch_dim`
            as shape (-1, H, W). note that it can only support 3D input image. default is None.
        vmin: `vmin` for the matplotlib `imshow`; defaults to the NaN-ignoring minimum of the frames.
        vmax: `vmax` for the matplotlib `imshow`; defaults to the NaN-ignoring maximum of the frames.
        every_n: factor to subsample the frames so that only every n-th frame is displayed
            (values below 1 are treated as 1).
        interpolation: interpolation to use for the matplotlib `matshow`.
        show: if True, show the figure.
        fill_value: value to use for the empty part of the grid.
        margin: margin (in pixels, on every side of each frame) to use for the grid.
        dtype: data type of the output stacked frames.
        kwargs: additional keyword arguments to matplotlib `matshow` and `imshow`.

    Returns:
        A pair `(fig, im)`: the figure (or Axes) that was drawn on and the tiled image array
        that was passed to `matshow`.

    Raises:
        ValueError: when `channel_dim` is not None, 0 or 1, or the size of that dimension is not
            1, 3 or 4; or when a sequence is given whose first item is not a numpy array.

    See Also:
        - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html
        - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.matshow.html

    Example:

        >>> import numpy as np
        >>> import matplotlib.pyplot as plt
        >>> from monai.visualize import matshow3d
        # create a figure of a 3D volume
        >>> volume = np.random.rand(10, 10, 10)
        >>> fig = plt.figure()
        >>> matshow3d(volume, fig=fig, title="3D Volume")
        >>> plt.show()
        # create a figure of a list of channel-first 3D volumes
        >>> volumes = [np.random.rand(1, 10, 10, 10), np.random.rand(1, 10, 10, 10)]
        >>> fig = plt.figure()
        >>> matshow3d(volumes, fig=fig, title="List of Volumes")
        >>> plt.show()

    """
    with_channels = channel_dim is not None
    frames, *_ = convert_data_type(data=volume, output_type=np.ndarray)

    if with_channels and not (channel_dim in (0, 1) and frames.shape[channel_dim] in (1, 3, 4)):
        raise ValueError("channel_dim must be: None, 0 or 1, and channels of image must be 1, 3 or 4.")

    if isinstance(frames, (list, tuple)):
        # a sequence of channel-first volumes: pad them to a common spatial size and stack
        if not isinstance(frames[0], np.ndarray):
            raise ValueError("volume must be a list of arrays.")
        largest_shape = np.asarray([item.shape for item in frames]).max(axis=0)
        padder = SpatialPad(largest_shape[1:])  # first entry is the channel dim of each item
        frames = np.concatenate([padder(item) for item in frames], axis=0)
    else:
        # prepend singleton dims so that 1D/2D inputs can be displayed as well
        for _ in range(max(3 - frames.ndim, 0)):
            frames = np.expand_dims(frames, 0)  # type: ignore

    # Number of trailing dims that make up one frame: (C, H, W) with channels, (H, W) without.
    # The frame dim is moved right in front of them and everything before is flattened.
    frame_ndim = 3 if with_channels else 2
    frames = np.moveaxis(frames, frame_dim, -(frame_ndim + 1))  # type: ignore
    frames = frames.reshape((-1, *frames.shape[-frame_ndim:]))

    # the display range is taken over all frames, before any subsampling
    if vmin is None:
        vmin = np.nanmin(frames)
    if vmax is None:
        vmax = np.nanmax(frames)

    # keep only every n-th frame
    frames = frames[:: max(every_n, 1)]
    n_frames = len(frames)
    if not frames_per_row:
        frames_per_row = int(np.ceil(np.sqrt(n_frames)))

    # grid layout
    cols = max(min(n_frames, frames_per_row), 1)
    rows = int(np.ceil(n_frames / cols))

    # pad with blank frames to fill the last row, and with a margin around every frame
    pad_width = [[0, cols * rows - n_frames]]
    if with_channels:
        pad_width.append([0, 0])  # the channel dim is left untouched
    pad_width.extend([[margin, margin], [margin, margin]])
    padded = np.pad(
        frames.astype(dtype, copy=False), pad_width, mode="constant", constant_values=fill_value
    )  # type: ignore

    im = np.block([[padded[r * cols + c] for c in range(cols)] for r in range(rows)])
    if with_channels:
        im = np.moveaxis(im, 0, -1)  # channel-last for plotting

    # resolve the Axes to draw on
    if isinstance(fig, plt.Axes):
        ax = fig
    else:
        fig = plt.figure(tight_layout=True) if fig is None else fig
        if not fig.axes:
            fig.add_subplot(111)
        ax = fig.axes[0]

    ax.matshow(im, vmin=vmin, vmax=vmax, interpolation=interpolation, **kwargs)
    ax.axis("off")
    if title is not None:
        ax.set_title(title)
    if figsize is not None and hasattr(fig, "set_size_inches"):
        fig.set_size_inches(figsize)
    if show:
        plt.show()
    return fig, im
def blend_images(
    image: NdarrayOrTensor,
    label: NdarrayOrTensor,
    alpha: float | NdarrayOrTensor = 0.5,
    cmap: str = "hsv",
    rescale_arrays: bool = True,
    transparent_background: bool = True,
) -> NdarrayOrTensor:
    """
    Blend an image and a label. Both should have the shape CHW[D].
    The image may have C==1 or 3 channels (greyscale or RGB).
    The label is expected to have C==1.

    Args:
        image: the input image to blend with label data.
        label: the input label to blend with image data.
        alpha: this specifies the weighting given to the label, where 0 is completely
            transparent and 1 is completely opaque. This can be given as either a
            single value or an array/tensor that is the same size as the input image.
            An array/tensor passed here is not modified.
        cmap: specify colormap in the matplotlib, default to `hsv`, for more details, please refer to:
            https://matplotlib.org/2.0.2/users/colormaps.html.
        rescale_arrays: whether to rescale the array to [0, 1] first, default to `True`.
        transparent_background: if true, any zeros in the label field will not be colored.

    Returns:
        The weighted sum `(1 - w) * image + w * label_rgb`, where `image` has been expanded to
        3 channels if needed, `label_rgb` is the label mapped through `cmap`, and `w` is the
        label weight derived from `alpha`.

    Raises:
        ValueError: when `label` does not have exactly 1 channel, `image` does not have 1 or 3
            channels, the spatial sizes of `image` and `label` differ, or `alpha` is an
            array/tensor whose spatial size differs from that of `image`.

    .. image:: ../../docs/images/blend_images.png

    """
    if label.shape[0] != 1:
        raise ValueError("Label should have 1 channel.")
    if image.shape[0] not in (1, 3):
        raise ValueError("Image should have 1 or 3 channels.")
    if image.shape[1:] != label.shape[1:]:
        raise ValueError("image and label should have matching spatial sizes.")
    if isinstance(alpha, (np.ndarray, torch.Tensor)):
        if image.shape[1:] != alpha.shape[1:]:  # pytype: disable=attribute-error,invalid-directive
            raise ValueError("if alpha is image, size should match input image and label.")

    # rescale arrays to [0, 1] if desired
    if rescale_arrays:
        image = rescale_array(image)
        label = rescale_array(label)

    # convert a single-channel image to rgb
    if image.shape[0] == 1:
        image = repeat(image, 3, axis=0)

    # Map the single label channel through the colormap.
    # `plt.get_cmap` is used instead of `cm.get_cmap`, which was removed in matplotlib 3.9.
    label_np, *_ = convert_data_type(label, np.ndarray)
    label_rgb_np = plt.get_cmap(cmap)(label_np[0])
    # the colormap output carries the colour components on the last axis:
    # move them to the front and drop everything after the first three (RGB)
    label_rgb_np = np.moveaxis(label_rgb_np, -1, 0)[:3]
    label_rgb, *_ = convert_to_dst_type(label_rgb_np, label)

    # Build the per-element weight of the label.
    if isinstance(alpha, torch.Tensor):
        # copy before the in-place masking below so the caller's tensor is left untouched
        w_label = alpha.clone() if transparent_background else alpha
    elif isinstance(alpha, np.ndarray):
        # copy before the in-place masking below so the caller's array is left untouched
        w_label = alpha.copy() if transparent_background else alpha
    elif isinstance(label, torch.Tensor):
        w_label = torch.full_like(label, alpha)
    else:
        # NOTE(review): `full_like` takes the dtype of `label`; with an integer label and
        # `rescale_arrays=False` a fractional `alpha` would be truncated -- confirm intended.
        w_label = np.full_like(label, alpha)

    if transparent_background:
        # where label == 0 (background), set label alpha to 0
        w_label[label == 0] = 0  # pytype: disable=unsupported-operands

    w_image = 1 - w_label
    return w_image * image + w_label * label_rgb