Source code for pyrfu.plot.plot_heatmap
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import List, Tuple, Union
import matplotlib.pyplot as plt
# 3rd party imports
import numpy as np
from matplotlib.axes import Axes
from matplotlib.colorbar import Colorbar
from matplotlib.image import AxesImage
from mpl_toolkits.axes_grid1 import make_axes_locatable
__author__ = "Louis Richard"
__email__ = "louisr@irfu.se"
__copyright__ = "Copyright 2020-2023"
__license__ = "MIT"
__version__ = "2.4.2"
__status__ = "Prototype"
[docs]def plot_heatmap(
ax: Axes,
data: Union[np.ndarray, List[List[float]]],
row_labels: Union[np.ndarray, List[str]],
col_labels: Union[np.ndarray, List[str]],
cbar_kw: dict = None,
cbarlabel: str = "",
**kwargs,
) -> Tuple[AxesImage, Colorbar]:
r"""Create a heatmap from a numpy array and two lists of labels.
Parameters
----------
ax : matplotlib.axes._axes.Axes
Axis to which the heatmap is plotted.
data : array_like
A 2D numpy array of shape (N, M).
row_labels : list or numpy.ndarray
A list or array of length N with the labels for the rows.
col_labels : list or numpy.ndarray
A list or array of length M with the labels for the columns.
cbar_kw : dict, Optional
A dictionary with arguments to `matplotlib.Figure.colorbar`.
cbarlabel : str, Optional
The label for the colorbar.
**kwargs
All other arguments are forwarded to `imshow`.
Returns
-------
im : matplotlib.image.AxesImage
The AxesImage of the data.
cbar : matplotlib.colorbar.Colorbar
Colorbar.
"""
if ax is None:
_, ax = plt.subplots(1)
if not isinstance(data, (list, np.ndarray)):
raise TypeError("row_labels must be a list or numpy.ndarray")
if not isinstance(row_labels, (list, np.ndarray)):
raise TypeError("row_labels must be a list or numpy.ndarray")
if not isinstance(col_labels, (list, np.ndarray)):
raise TypeError("col_labels must be a list or numpy.ndarray")
if data.shape[0] != len(row_labels):
raise ValueError("row_labels must have the same length as data")
if data.shape[1] != len(col_labels):
raise ValueError("col_labels must have the same length as data")
# Plot the heatmap
if cbar_kw is None:
cbar_kw = {}
im = ax.imshow(data, **kwargs)
divider = make_axes_locatable(ax)
colorbar_axes = divider.append_axes("right", size="2%", pad=0.1)
# Create colorbar
cbar = ax.figure.colorbar(im, ax=ax, cax=colorbar_axes, **cbar_kw)
cbar.ax.set_ylabel(cbarlabel)
# We want to show all ticks...
ax.set_xticks(np.arange(data.shape[1]))
ax.set_yticks(np.arange(data.shape[0]))
# ... and label them with the respective list entries.
ax.set_xticklabels(col_labels)
ax.set_yticklabels(row_labels)
# Let the horizontal axes labeling appear on top.
ax.tick_params(top=True, bottom=True, labeltop=True, labelbottom=False)
# Turn spines off and create white grid.
# ax.spines[:].set_visible(False)
ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
ax.grid(which="minor", color="k", linestyle="-", linewidth=0.5)
ax.tick_params(which="minor", bottom=False, left=False)
ax.set_axisbelow(False)
cbar.ax.set_axisbelow(False)
return im, cbar