Source code for pyrfu.plot.annotate_heatmap

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Built-in imports
import itertools
from typing import List, Union

# 3rd party imports
import numpy as np
from matplotlib import ticker

__author__ = "Louis Richard"
__email__ = "louisr@irfu.se"
__copyright__ = "Copyright 2020-2023"
__license__ = "MIT"
__version__ = "2.4.2"
__status__ = "Prototype"


def annotate_heatmap(
    im,
    data: Union[np.ndarray, List[List[float]], None] = None,
    valfmt: str = "{x:.2f}",
    textcolors: tuple = ("black", "white"),
    threshold: Union[float, None] = None,
    **textkw,
):
    r"""Annotate a heatmap.

    Parameters
    ----------
    im : matplotlib.image.AxesImage
        The AxesImage to be labeled.
    data : array_like, Optional
        Two-dimensional data used to annotate, given as a numpy array or a
        nested list. If None, the image's data (``im.get_array()``) is used.
    valfmt : str or matplotlib.ticker.Formatter, Optional
        The format of the annotations inside the heatmap. A string is
        wrapped in a `matplotlib.ticker.StrMethodFormatter`; any other value
        is called as ``valfmt(value, None)``.
    textcolors : tuple, Optional
        A pair of colors. The first is used for values below a threshold,
        the second for those above.
    threshold : float, Optional
        Value in data units according to which the colors from textcolors
        are applied. If None (the default) uses half of the normalized data
        maximum as separation.
    **textkw
        All other arguments are forwarded to each call to `text` used to
        create the text labels.

    Returns
    -------
    texts : list
        Cells labels, in row-major order.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor a list.

    """

    if data is None:
        data = im.get_array()

    if not isinstance(data, (np.ndarray, list)):
        raise TypeError("data must be a numpy array")

    # Nested lists pass the check above but have no `.max()`, `.shape` or
    # 2D indexing; convert them once. Arrays (including the masked arrays
    # returned by `im.get_array()`) are used as they are.
    if isinstance(data, list):
        data = np.asarray(data)

    # Normalize the threshold to the images color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max()) / 2.0

    # Set default alignment to center, but allow it to be
    # overwritten by textkw.
    kw = {"horizontalalignment": "center", "verticalalignment": "center"}
    kw.update(textkw)

    # Get the formatter in case a string is supplied
    if isinstance(valfmt, str):
        valfmt = ticker.StrMethodFormatter(valfmt)

    # Loop over the data and create a `Text` for each "pixel".
    # Change the text's color depending on the data.
    n_rows, n_cols = data.shape[0], data.shape[1]
    texts = []

    for i, j in itertools.product(range(n_rows), range(n_cols)):
        value = data[i, j]
        # The boolean comparison selects index 0 (below) or 1 (above).
        kw.update(color=textcolors[int(im.norm(value) > threshold)])
        # Text is placed at (x=column, y=row).
        texts.append(im.axes.text(j, i, valfmt(value, None), **kw))

    return texts