from __future__ import annotations
import collections.abc
import inspect
import logging
from collections import OrderedDict, namedtuple
from typing import TYPE_CHECKING, Any, Union
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.offsetbox import AnchoredText
from matplotlib.transforms import Bbox
from mpl_toolkits.axes_grid1 import axes_size, make_axes_locatable
from .utils import (
Plottable,
get_histogram_axes_title,
get_plottable_protocol_bins,
hist_object_handler,
isLight,
process_histogram_parts,
align_marker,
to_padded2d,
)
if TYPE_CHECKING:
from numpy.typing import ArrayLike
# Containers for the matplotlib artists returned by the plotting functions.
# ``histplot`` returns a list of ``StairsArtists`` / ``ErrorBarArtists`` and
# ``hist2dplot`` returns a single ``ColormeshArtists``.
StairsArtists = namedtuple(
    "StairsArtists", ["stairs", "errorbar", "legend_artist"]
)
ErrorBarArtists = namedtuple("ErrorBarArtists", ["errorbar"])
ColormeshArtists = namedtuple("ColormeshArtists", ["pcolormesh", "cbar", "text"])

# Type aliases used in the return annotations / docstrings below.
Hist1DArtists = Union[StairsArtists, ErrorBarArtists]
Hist2DArtists = ColormeshArtists
def soft_update_kwargs(kwargs, mods, rc=True):
    """
    Return a copy of ``kwargs`` with defaults from ``mods`` filled in softly.

    A default from ``mods`` is only applied when the key was not supplied by
    the caller (after expanding the ``ls``/``lw`` aliases), ``rc`` is truthy
    and the matching matplotlib rcParam has not been changed from its
    default value.

    Parameters
    ----------
    kwargs : dict
        User supplied keyword arguments. Not modified in place.
    mods : dict
        Default values to add where appropriate.
    rc : bool, optional
        When falsy no default from ``mods`` is applied at all.

    Returns
    -------
    dict
        New keyword dictionary with aliases expanded and defaults applied.
    """
    # rcParams whose current value differs from the matplotlib default
    changed_rc = [
        name
        for name, default in mpl.rcParamsDefault.items()
        if default != mpl.rcParams[name]
    ]
    # rcParams that are also matched by their short (last dotted part) name
    respected = [
        "hatch.linewidth",
        "lines.linewidth",
        "patch.linewidth",
        "lines.linestyle",
    ]
    shorthand = {"ls": "linestyle", "lw": "linewidth"}
    updated = {shorthand.get(name, name): value for name, value in kwargs.items()}

    changed_short_names = [
        name.split(".")[-1] for name in changed_rc if name in respected
    ]
    for key, val in mods.items():
        rc_overridden = key in changed_rc or key in changed_short_names
        if key in updated or not rc or rc_overridden:
            continue
        updated[key] = val
    return updated
########################################
# Histogram plotter
def histplot(
    H,  # Histogram object, tuple or array
    bins=None,  # Bins to be supplied when h is a value array or iterable of array
    *,
    yerr: ArrayLike | bool | None = None,
    w2=None,
    w2method=None,
    stack: bool = False,
    density: bool = False,
    binwnorm=None,
    histtype: str = "step",
    xerr=False,
    label=None,
    sort=None,
    edges=True,
    binticks=False,
    ax: mpl.axes.Axes | None = None,
    flow="hint",
    **kwargs,
):
    """
    Create a 1D histogram plot from `np.histogram`-like inputs.

    Parameters
    ----------
        H : object
            Histogram object containing values and optionally bins. Can be:

            - `np.histogram` tuple
            - PlottableProtocol histogram object
            - `boost_histogram` classic (<0.13) histogram object
            - raw histogram values, provided `bins` is specified.

            Or list thereof.
        bins : iterable, optional
            Histogram bins, if not part of ``H``.
        yerr : iterable or bool, optional
            Histogram uncertainties. Following modes are supported:

            - True, sqrt(N) errors or poissonian interval when ``w2`` is specified
            - shape(N) array for one sided errors or list thereof
            - shape(Nx2) array for two sided errors or list thereof
        w2 : iterable, optional
            Sum of the histogram weights squared for poissonian interval error
            calculation. Any array-like that can be reshaped to
            ``(number of histograms, number of bins)``.
        w2method: callable, optional
            Function calculating CLs with signature ``low, high = fcn(w, w2)``. Here
            ``low`` and ``high`` are given in absolute terms, not relative to w.
            Default is ``None``. If w2 has integer values (likely to be data) poisson
            interval is calculated, otherwise the resulting error is symmetric
            ``sqrt(w2)``. Specifying ``poisson`` or ``sqrt`` will force that behaviours.
        stack : bool, optional
            Whether to stack or overlay non-axis dimension (if it exists). N.B. in
            contrast to ROOT, stacking is performed in a single call aka
            ``histplot([h1, h2, ...], stack=True)`` as opposed to multiple calls.
        density : bool, optional
            If true, convert sum weights to probability density (i.e. integrates to 1
            over domain of axis) (Note: this option conflicts with ``binwnorm``)
        binwnorm : float, optional
            If true, convert sum weights to bin-width-normalized, with unit equal to
            supplied value (usually you want to specify 1.)
        histtype: {'step', 'fill', 'band', 'errorbar'}, optional, default: "step"
            Type of histogram to plot:

            - "step": skyline/step/outline of a histogram using `plt.stairs <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.stairs.html#matplotlib-axes-axes-stairs>`_
            - "fill": filled histogram using `plt.stairs <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.stairs.html#matplotlib-axes-axes-stairs>`_
            - "band": filled band spanning the yerr range of the histogram using `plt.stairs <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.stairs.html#matplotlib-axes-axes-stairs>`_
            - "errorbar": single marker histogram using `plt.errorbar <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.errorbar.html#matplotlib-axes-axes-errorbar>`_
        xerr:  bool or float, optional
            Size of xerr if ``histtype == 'errorbar'``. If ``True``, bin-width will be used.
        label : str or list, optional
            Label for legend entry.
        sort: {'label'/'l', 'yield'/'y'}, optional
            Append '_r' for reverse. A list/array of indices is also accepted
            and used directly as the plotting order.
        edges : bool, default: True, optional
            Specifies whether to draw first and last edges of the histogram
        binticks : bool, default: False, optional
            Attempts to draw x-axis ticks coinciding with bin boundaries if feasible.
        ax : matplotlib.axes.Axes, optional
            Axes object (if None, last one is fetched or one is created)
        flow :  str, optional { "show", "sum", "hint", "none"}
            Whether plot the under/overflow bin. If "show", add additional under/overflow bin.
            If "sum", add the under/overflow bin content to first/last bin.
            If "hint", mark the axis edge where under/overflow content exists.
        **kwargs :
            Keyword arguments passed to underlying matplotlib functions -
            {'stairs', 'errorbar'}.

    Returns
    -------
        List[Hist1DArtists]

    Raises
    ------
        ValueError
            If ``ax`` is not a matplotlib Axes, if both ``w2`` and ``yerr`` are
            given, if ``yerr`` or ``sort`` cannot be interpreted, or if both
            ``density`` and ``binwnorm`` are requested.
    """
    # ax check
    if ax is None:
        ax = plt.gca()
    else:
        if not isinstance(ax, plt.Axes):
            raise ValueError("ax must be a matplotlib Axes object")

    # arg check
    _allowed_histtype = ["fill", "step", "errorbar", "band"]
    _err_message = f"Select 'histtype' from: {_allowed_histtype}"
    assert histtype in _allowed_histtype, _err_message
    assert flow is None or flow in {
        "show",
        "sum",
        "hint",
        "none",
    }, "flow must be show, sum, hint, or none"

    # Convert 1/0 etc to real bools
    stack = bool(stack)
    density = bool(density)
    edges = bool(edges)
    binticks = bool(binticks)

    # Process input
    hists = list(process_histogram_parts(H, bins))
    final_bins, xtick_labels = get_plottable_protocol_bins(hists[0].axes[0])
    _bin_widths = np.diff(final_bins)
    _bin_centers = final_bins[1:] - _bin_widths / float(2)
    assert final_bins.ndim == 1, "bins need to be 1 dimensional"
    # Keep an x label already set on the axes, otherwise take it from the histogram
    _x_axes_label = ax.get_xlabel()
    x_axes_label = (
        _x_axes_label
        if _x_axes_label != ""
        else get_histogram_axes_title(hists[0].axes[0])
    )

    plottables = []
    flow_bins = final_bins
    for h in hists:
        value, variance = np.copy(h.values()), h.variances()
        if has_variances := variance is not None:
            variance = np.copy(variance)
        underflow, overflow = 0.0, 0.0
        underflowv, overflowv = 0.0, 0.0
        # One sided flow bins - hist (uproot hist does not have the over- or underflow traits)
        if (
            hasattr(h, "axes")
            and (traits := getattr(h.axes[0], "traits", None)) is not None
            and hasattr(traits, "underflow")
            and hasattr(traits, "overflow")
        ):
            if traits.overflow:
                overflow = np.copy(h.values(flow=True))[-1]
                if has_variances:
                    overflowv = np.copy(h.variances(flow=True))[-1]
            if traits.underflow:
                underflow = np.copy(h.values(flow=True))[0]
                if has_variances:
                    underflowv = np.copy(h.variances(flow=True))[0]
        # Both flow bins exist - uproot
        elif hasattr(h, "values") and "flow" in inspect.getfullargspec(h.values).args:
            if len(h.values()) + 2 == len(
                h.values(flow=True)
            ):  # easy case, both over/under
                underflow, overflow = (
                    np.copy(h.values(flow=True))[0],
                    np.copy(h.values(flow=True))[-1],
                )
                if has_variances:
                    underflowv, overflowv = (
                        np.copy(h.variances(flow=True))[0],
                        np.copy(h.variances(flow=True))[-1],
                    )

        # Set plottables
        if flow == "none":
            plottables.append(Plottable(value, edges=final_bins, variances=variance))
        elif flow == "hint":  # modify plottable
            plottables.append(Plottable(value, edges=final_bins, variances=variance))
        elif flow == "show":
            # Extra bin is at least 5% of the axis range and at least one mean bin wide
            _flow_bin_size = np.max(
                [0.05 * (final_bins[-1] - final_bins[0]), np.mean(np.diff(final_bins))]
            )
            flow_bins = np.copy(final_bins)
            if underflow > 0:
                flow_bins = np.r_[flow_bins[0] - _flow_bin_size, flow_bins]
                value = np.r_[underflow, value]
                if has_variances:
                    variance = np.r_[underflowv, variance]
            if overflow > 0:
                flow_bins = np.r_[flow_bins, flow_bins[-1] + _flow_bin_size]
                value = np.r_[value, overflow]
                if has_variances:
                    variance = np.r_[variance, overflowv]
            plottables.append(Plottable(value, edges=flow_bins, variances=variance))
        elif flow == "sum":
            if underflow > 0:
                value[0] += underflow
                if has_variances:
                    variance[0] += underflowv
            if overflow > 0:
                value[-1] += overflow
                if has_variances:
                    variance[-1] += overflowv
            plottables.append(Plottable(value, edges=final_bins, variances=variance))
        else:
            plottables.append(Plottable(value, edges=final_bins, variances=variance))

    # Reject the conflicting combination before touching the plottables
    if w2 is not None and yerr is not None:
        raise ValueError("Can only supply errors or w2")

    if w2 is not None:
        # np.asarray lets lists / nested lists be passed, not only ndarrays
        for _w2, _plottable in zip(
            np.asarray(w2).reshape(len(plottables), len(final_bins) - 1), plottables
        ):
            _plottable.variances = _w2
            _plottable.method = w2method

    _labels: list[str | None]
    if label is None:
        _labels = [None] * len(plottables)
    elif isinstance(label, str):
        _labels = [label] * len(plottables)
    elif not np.iterable(label):
        _labels = [str(label)] * len(plottables)
    else:
        _labels = [str(lab) for lab in label]

    def iterable_not_string(arg):
        return isinstance(arg, collections.abc.Iterable) and not isinstance(arg, str)

    # Split kwargs into one dict per histogram: iterable values are distributed
    # element-wise, scalars (and numeric tuples, i.e. colors) are broadcast.
    _chunked_kwargs: list[dict[str, Any]] = []
    for _ in range(len(plottables)):
        _chunked_kwargs.append({})
    for kwarg in kwargs:
        # Check if iterable
        if iterable_not_string(kwargs[kwarg]):
            # Check if tuple of floats or ints (can be used for colors)
            if isinstance(kwargs[kwarg], tuple) and all(
                isinstance(x, int) or isinstance(x, float) for x in kwargs[kwarg]
            ):
                for i in range(len(_chunked_kwargs)):
                    _chunked_kwargs[i][kwarg] = kwargs[kwarg]
            else:
                for i, kw in enumerate(kwargs[kwarg]):
                    _chunked_kwargs[i][kwarg] = kw
        else:
            for i in range(len(_chunked_kwargs)):
                _chunked_kwargs[i][kwarg] = kwargs[kwarg]

    ############################
    # # yerr calculation
    _yerr: np.ndarray | None
    if yerr is not None:
        # yerr is array
        if hasattr(yerr, "__len__"):
            _yerr = np.asarray(yerr)
        # yerr is a number
        elif isinstance(yerr, (int, float)) and not isinstance(yerr, bool):
            _yerr = np.ones((len(plottables), len(final_bins) - 1)) * yerr
        # yerr is automatic
        else:
            _yerr = None
    else:
        _yerr = None

    if _yerr is not None:
        assert isinstance(_yerr, np.ndarray)
        if _yerr.ndim == 3:
            # Already correct format
            pass
        elif _yerr.ndim == 2 and len(plottables) == 1:
            # Broadcast ndim 2 to ndim 3
            if _yerr.shape[-2] == 2:  # [[1,1], [1,1]]
                _yerr = _yerr.reshape(len(plottables), 2, _yerr.shape[-1])
            elif _yerr.shape[-2] == 1:  # [[1,1]]
                _yerr = np.tile(_yerr, 2).reshape(len(plottables), 2, _yerr.shape[-1])
            else:
                raise ValueError("yerr format is not understood")
        elif _yerr.ndim == 2:
            # Broadcast yerr (nh, N) to (nh, 2, N)
            _yerr = np.tile(_yerr, 2).reshape(len(plottables), 2, _yerr.shape[-1])
        elif _yerr.ndim == 1:
            # Broadcast yerr (1, N) to (nh, 2, N)
            _yerr = np.tile(_yerr, 2 * len(plottables)).reshape(
                len(plottables), 2, _yerr.shape[-1]
            )
        else:
            raise ValueError("yerr format is not understood")

        assert _yerr is not None
        for yrs, _plottable in zip(_yerr, plottables):
            _plottable.fixed_errors(*yrs)

    # Sorting
    if sort is not None:
        if isinstance(sort, str):
            _sort_parts = sort.split("_")
            if _sort_parts[0] in ["l", "label"] and isinstance(_labels, list):
                order = np.argsort(label)  # [::-1]
            elif _sort_parts[0] in ["y", "yield"]:
                _yields = [np.sum(_h.values) for _h in plottables]
                order = np.argsort(_yields)
            else:
                # Previously fell through and failed later on an unbound ``order``
                raise ValueError(f"Sort type: {sort} not understood.")
            if len(_sort_parts) == 2 and _sort_parts[1] == "r":
                order = order[::-1]
        elif isinstance(sort, list) or isinstance(sort, np.ndarray):
            if len(sort) != len(plottables):
                raise ValueError(
                    f"Sort indexing array is of the wrong size - {len(sort)}, {len(plottables)} expected."
                )
            order = np.asarray(sort)
        else:
            raise ValueError(f"Sort type: {sort} not understood.")
        plottables = [plottables[ix] for ix in order]
        _chunked_kwargs = [_chunked_kwargs[ix] for ix in order]
        _labels = [_labels[ix] for ix in order]

    # ############################
    # # Stacking, norming, density
    if density is True and binwnorm is not None:
        raise ValueError("Can only set density or binwnorm.")
    if density is True:
        if stack:
            # Normalise every component by the integral of the stacked total
            _total = np.sum(
                np.array([plottable.values for plottable in plottables]), axis=0
            )
            for plottable in plottables:
                plottable.flat_scale(1.0 / np.sum(np.diff(final_bins) * _total))
        else:
            for plottable in plottables:
                plottable.density = True
    elif binwnorm is not None:
        for plottable, norm in zip(
            plottables, np.broadcast_to(binwnorm, (len(plottables),))
        ):
            plottable.flat_scale(norm)
            plottable.binwnorm()

    # Stack
    if stack and len(plottables) > 1:
        from .utils import stack as stack_fun

        plottables = stack_fun(*plottables)

    ##########
    # Plotting
    return_artists: list[StairsArtists | ErrorBarArtists] = []
    if histtype == "step":
        for i in range(len(plottables)):
            do_errors = yerr is not False and (
                (yerr is not None or w2 is not None)
                or (plottables[i].variances is not None)
            )

            _kwargs = _chunked_kwargs[i]
            # The legend entry goes to the errorbar proxy when errors are drawn
            _label = _labels[i] if do_errors else None
            _step_label = _labels[i] if not do_errors else None

            _kwargs = soft_update_kwargs(_kwargs, {"linewidth": 1.5})

            _plot_info = plottables[i].to_stairs()
            _plot_info["baseline"] = None if not edges else 0
            _s = ax.stairs(
                **_plot_info,
                label=_step_label,
                **_kwargs,
            )

            if do_errors:
                _kwargs = soft_update_kwargs(_kwargs, {"color": _s.get_edgecolor()})
                _kwargs["linestyle"] = "none"
                _plot_info = plottables[i].to_errorbar()
                _e = ax.errorbar(
                    **_plot_info,
                    **_kwargs,
                )
                # Empty errorbar used only as the legend handle
                _e_leg = ax.errorbar(
                    [], [], yerr=1, xerr=1, color=_s.get_edgecolor(), label=_label
                )
            return_artists.append(
                StairsArtists(
                    _s,
                    _e if do_errors else None,
                    _e_leg if do_errors else None,
                )
            )
        _artist = _s

    elif histtype == "fill":
        for i in range(len(plottables)):
            _kwargs = _chunked_kwargs[i]
            _f = ax.stairs(
                **plottables[i].to_stairs(), label=_labels[i], fill=True, **_kwargs
            )
            return_artists.append(StairsArtists(_f, None, None))
        _artist = _f

    elif histtype == "band":
        band_defaults = {
            "alpha": 0.5,
            "edgecolor": "darkgray",
            "facecolor": "whitesmoke",
            "hatch": "//// /",
        }
        for i in range(len(plottables)):
            _kwargs = _chunked_kwargs[i]
            _f = ax.stairs(
                **plottables[i].to_stairband(),
                label=_labels[i],
                fill=True,
                **soft_update_kwargs(_kwargs, band_defaults),
            )
            return_artists.append(StairsArtists(_f, None, None))
        _artist = _f

    elif histtype == "errorbar":
        err_defaults = {
            "linestyle": "none",
            "marker": ".",
            "markersize": 10.0,
            "elinewidth": 1,
        }

        _xerr: np.ndarray | float | int | None

        if xerr is True:
            _xerr = _bin_widths / 2
        elif isinstance(xerr, (int, float)) and not isinstance(xerr, bool):
            _xerr = xerr
        else:
            _xerr = None

        for i in range(len(plottables)):
            _kwargs = _chunked_kwargs[i]
            _plot_info = plottables[i].to_errorbar()
            if yerr is False:
                _plot_info["yerr"] = None
            _plot_info["xerr"] = _xerr
            _e = ax.errorbar(
                **_plot_info,
                label=_labels[i],
                **soft_update_kwargs(_kwargs, err_defaults),
            )
            return_artists.append(ErrorBarArtists(_e))

        _artist = _e[0]

    # Add sticky edges for autoscale
    assert hasattr(
        listy := _artist.sticky_edges.y, "append"
    ), "cannot append to sticky edges"
    listy.append(0)

    if xtick_labels is None or flow == "show":
        if binticks:
            _slice = int(round(float(len(final_bins)) / len(ax.get_xticks()))) + 1
            ax.set_xticks(final_bins[::_slice])
    else:
        ax.set_xticks(_bin_centers)
        ax.set_xticklabels(xtick_labels)

    if x_axes_label:
        ax.set_xlabel(x_axes_label)

    # Flow extra styling
    # NOTE: ``underflow`` / ``overflow`` below hold the values of the last
    # histogram processed in the loop above.
    if (fig := ax.figure) is None:
        raise ValueError("No figure found")
    if flow == "hint":
        _marker_size = (
            30
            * ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()).width
        )
        if underflow > 0.0:
            ax.scatter(
                final_bins[0],
                0,
                _marker_size,
                marker=align_marker("<", halign="right"),
                edgecolor="black",
                zorder=5,
                clip_on=False,
                facecolor="white",
                transform=ax.get_xaxis_transform(),
            )
        if overflow > 0.0:
            ax.scatter(
                final_bins[-1],
                0,
                _marker_size,
                marker=align_marker(">", halign="left"),
                edgecolor="black",
                zorder=5,
                clip_on=False,
                facecolor="white",
                transform=ax.get_xaxis_transform(),
            )

    elif flow == "show" and (underflow > 0.0 or overflow > 0.0):
        xticks = ax.get_xticks().tolist()
        lw = ax.spines["bottom"].get_linewidth()
        _edges = plottables[0].edges
        _centers = plottables[0].centers
        _marker_size = (
            20
            * ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()).width
        )
        if underflow > 0.0:
            xticks[0] = ""
            xticks[1] = f"<{flow_bins[2]}"
            ax.set_xticklabels(xticks)

            # Dashed white line over the axis spine marks the broken axis
            ax.plot(
                [_edges[0], _edges[1]],
                [0, 0],
                color="white",
                zorder=5,
                ls="--",
                lw=lw,
                transform=ax.get_xaxis_transform(),
                clip_on=False,
            )

            ax.scatter(
                _centers[0],
                0,
                _marker_size,
                marker=align_marker("d", valign="center"),
                edgecolor="black",
                zorder=5,
                clip_on=False,
                facecolor="white",
                transform=ax.get_xaxis_transform(),
            )
        if overflow > 0.0:
            xticks[-1] = ""
            xticks[-2] = f">{flow_bins[-3]}"
            ax.set_xticklabels(xticks)
            ax.plot(
                [_edges[-2], _edges[-1]],
                [0, 0],
                color="white",
                zorder=5,
                ls="--",
                lw=lw,
                transform=ax.get_xaxis_transform(),
                clip_on=False,
            )

            ax.scatter(
                _centers[-1],
                0,
                _marker_size,
                marker=align_marker("d", valign="center"),
                edgecolor="black",
                zorder=5,
                clip_on=False,
                facecolor="white",
                transform=ax.get_xaxis_transform(),
            )

    return return_artists
def hist2dplot(
    H,
    xbins=None,
    ybins=None,
    labels=None,
    cbar: bool = True,
    cbarsize="7%",
    cbarpad=0.2,
    cbarpos="right",
    cbarextend=True,
    cmin=None,
    cmax=None,
    ax: mpl.axes.Axes | None = None,
    flow="hint",
    **kwargs,
):
    """
    Create a 2D histogram plot from `np.histogram`-like inputs.

    Parameters
    ----------
    H : object
        Histogram object containing values and optionally bins. Can be:

        - `np.histogram` tuple
        - `boost_histogram` histogram object
        - raw histogram values as list of list or 2d-array

    xbins : 1D array-like, optional, default None
        Histogram bins along x axis, if not part of ``H``.
    ybins : 1D array-like, optional, default None
        Histogram bins along y axis, if not part of ``H``.
    labels : 2D array (H-like) or bool, default None, optional
        Array of per-bin labels to display. If ``True`` will
        display numerical values
    cbar : bool, optional, default True
        Draw a colorbar. In contrast to mpl behaviors the cbar axes is
        appended in such a way that it doesn't modify the original axes
        width:height ratio.
    cbarsize : str or float, optional, default "7%"
        Colorbar width.
    cbarpad : float, optional, default 0.2
        Colorbar distance from main axis.
    cbarpos : {'right', 'left', 'bottom', 'top'}, optional,  default "right"
        Colorbar position w.r.t main axis.
    cbarextend : bool, optional, default True
        Extends figure size to keep original axes size same as without cbar.
        Only safe for 1 axes per fig.
    cmin : float, optional
        Colorbar minimum. Bins below this value are masked (set to NaN).
    cmax : float, optional
        Colorbar maximum. Bins above this value are masked (set to NaN).
    ax : matplotlib.axes.Axes, optional
        Axes object (if None, last one is fetched or one is created)
    flow :  str, optional {"show", "sum","hint", None}
        Whether plot the under/overflow bin. If "show", add additional
        under/overflow bin. If "sum", add the under/overflow bin content to
        first/last bin. "hint" would highlight the bins with under/overflow
        contents
    **kwargs :
        Keyword arguments passed to underlying matplotlib function - pcolormesh.

    Returns
    -------
        Hist2DArtist

    Raises
    ------
    ValueError
        If ``ax`` is not a matplotlib Axes or ``labels`` cannot be interpreted.
    TypeError
        If ``flow="sum"`` is requested and the histogram's ``values`` call
        fails with a ``TypeError``.
    """
    # ax check
    if ax is None:
        ax = plt.gca()
    else:
        if not isinstance(ax, plt.Axes):
            raise ValueError("ax must be a matplotlib Axes object")

    h = hist_object_handler(H, xbins, ybins)

    # TODO: use Histogram everywhere
    H = np.copy(h.values())
    xbins, xtick_labels = get_plottable_protocol_bins(h.axes[0])
    ybins, ytick_labels = get_plottable_protocol_bins(h.axes[1])
    # Show under/overflow bins
    # "show": Add additional bin with 2 times bin width
    if (
        hasattr(h, "values")
        and "flow" not in inspect.getfullargspec(h.values).args
        and flow is not None
    ):
        print(
            f"Warning: {type(h)} is not allowed to get flow bins, flow bin option set to None"
        )
        flow = None
    elif flow in ["hint", "show"]:
        # Flow bins are drawn 5% of the axis range wide
        xwidth, ywidth = (xbins[-1] - xbins[0]) * 0.05, (ybins[-1] - ybins[0]) * 0.05
        pxbins = np.r_[xbins[0] - xwidth, xbins, xbins[-1] + xwidth]
        pybins = np.r_[ybins[0] - ywidth, ybins, ybins[-1] + ywidth]
        padded = to_padded2d(h)
        hint_xlo, hint_xhi, hint_ylo, hint_yhi = True, True, True, True
        # Drop every flow row/column that is completely empty
        if np.all(padded[0, :] == 0):
            padded = padded[1:, :]
            pxbins = pxbins[1:]
            hint_xlo = False
        if np.all(padded[-1, :] == 0):
            padded = padded[:-1, :]
            pxbins = pxbins[:-1]
            hint_xhi = False
        if np.all(padded[:, 0] == 0):
            padded = padded[:, 1:]
            pybins = pybins[1:]
            hint_ylo = False
        if np.all(padded[:, -1] == 0):
            padded = padded[:, :-1]
            pybins = pybins[:-1]
            hint_yhi = False
        if flow == "show":
            H = padded
            xbins, ybins = pxbins, pybins
    elif flow == "sum":
        H = np.copy(h.values())
        # Sum borders
        try:
            H[0], H[-1] = (
                H[0] + h.values(flow=True)[0, 1:-1],  # type: ignore[call-arg]
                H[-1] + h.values(flow=True)[-1, 1:-1],  # type: ignore[call-arg]
            )
            H[:, 0], H[:, -1] = (
                H[:, 0] + h.values(flow=True)[1:-1, 0],  # type: ignore[call-arg]
                H[:, -1] + h.values(flow=True)[1:-1, -1],  # type: ignore[call-arg]
            )
            # Sum corners to corners
            H[0, 0], H[-1, -1], H[0, -1], H[-1, 0] = (
                h.values(flow=True)[0, 0] + H[0, 0],  # type: ignore[call-arg]
                h.values(flow=True)[-1, -1] + H[-1, -1],  # type: ignore[call-arg]
                h.values(flow=True)[0, -1] + H[0, -1],  # type: ignore[call-arg]
                h.values(flow=True)[-1, 0] + H[-1, 0],  # type: ignore[call-arg]
            )
        except TypeError as error:
            if "got an unexpected keyword argument 'flow'" in str(error):
                raise TypeError(
                    f"The histograms value method {repr(h)} does not take a 'flow' argument. UHI Plottable doesn't require this to have, but it is required for this function."
                    f" Implementations like hist/boost-histogram support this argument."
                ) from error
            # Any other TypeError is a genuine failure and must not be hidden
            raise

    xbin_centers = xbins[1:] - np.diff(xbins) / float(2)
    ybin_centers = ybins[1:] - np.diff(ybins) / float(2)

    # Keep axis labels already set on the axes, otherwise take them from the histogram
    _x_axes_label = ax.get_xlabel()
    x_axes_label = (
        _x_axes_label if _x_axes_label != "" else get_histogram_axes_title(h.axes[0])
    )
    _y_axes_label = ax.get_ylabel()
    y_axes_label = (
        _y_axes_label if _y_axes_label != "" else get_histogram_axes_title(h.axes[1])
    )

    H = H.T

    # Masking uses NaN, which integer arrays cannot hold
    if (cmin is not None or cmax is not None) and not np.issubdtype(
        H.dtype, np.floating
    ):
        H = H.astype(float)
    if cmin is not None:
        H[H < cmin] = np.nan
    if cmax is not None:
        H[H > cmax] = np.nan

    X, Y = np.meshgrid(xbins, ybins)

    kwargs.setdefault("shading", "flat")
    pc = ax.pcolormesh(X, Y, H, vmin=cmin, vmax=cmax, **kwargs)

    if x_axes_label:
        ax.set_xlabel(x_axes_label)
    if y_axes_label:
        ax.set_ylabel(y_axes_label)

    ax.set_xlim(xbins[0], xbins[-1])
    ax.set_ylim(ybins[0], ybins[-1])

    if xtick_labels is None:  # Ordered axis
        if len(ax.get_xticks()) > len(xbins) * 0.7:
            ax.set_xticks(xbins)
    else:  # Categorical axis
        ax.set_xticks(xbin_centers)
        ax.set_xticklabels(xtick_labels)
    if ytick_labels is None:
        if len(ax.get_yticks()) > len(ybins) * 0.7:
            ax.set_yticks(ybins)
    else:  # Categorical axis
        ax.set_yticks(ybin_centers)
        ax.set_yticklabels(ytick_labels)

    if cbar:
        cax = append_axes(
            ax, size=cbarsize, pad=cbarpad, position=cbarpos, extend=cbarextend
        )
        cb_obj = plt.colorbar(pc, cax=cax)
    else:
        cb_obj = None

    plt.sca(ax)
    if flow == "show":
        # Dashed separators between regular bins and the added flow bins
        if hint_xlo:
            ax.plot(
                [xbins[1]] * 2,
                [0, 1],
                ls="--",
                color="lightgrey",
                clip_on=False,
                transform=ax.get_xaxis_transform(),
            )
        if hint_xhi:
            ax.plot(
                [xbins[-2]] * 2,
                [0, 1],
                ls="--",
                color="lightgrey",
                clip_on=False,
                transform=ax.get_xaxis_transform(),
            )
        if hint_ylo:
            ax.plot(
                [0, 1],
                [ybins[1]] * 2,
                ls="--",
                color="lightgrey",
                clip_on=False,
                transform=ax.get_yaxis_transform(),
            )
        if hint_yhi:
            ax.plot(
                [0, 1],
                [ybins[-2]] * 2,
                ls="--",
                color="lightgrey",
                clip_on=False,
                transform=ax.get_yaxis_transform(),
            )
    elif flow == "hint":
        if (fig := ax.figure) is None:
            raise ValueError("No figure found.")
        _marker_size = (
            30
            * ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()).width
        )
        if hint_xlo:
            ax.scatter(
                0,
                0,
                _marker_size,
                marker=align_marker("<", halign="right", valign="bottom"),
                edgecolor="black",
                zorder=5,
                clip_on=False,
                facecolor="white",
                transform=ax.transAxes,
            )
        if hint_xhi:
            ax.scatter(
                1,
                0,
                _marker_size,
                marker=align_marker(">", halign="left"),
                edgecolor="black",
                zorder=5,
                clip_on=False,
                facecolor="white",
                transform=ax.transAxes,
            )
        if hint_ylo:
            ax.scatter(
                0,
                0,
                _marker_size,
                marker=align_marker("v", valign="top", halign="left"),
                edgecolor="black",
                zorder=5,
                clip_on=False,
                facecolor="white",
                transform=ax.transAxes,
            )
        if hint_yhi:
            ax.scatter(
                0,
                1,
                _marker_size,
                marker=align_marker("^", valign="bottom"),
                edgecolor="black",
                zorder=5,
                clip_on=False,
                facecolor="white",
                transform=ax.transAxes,
            )

    _labels: np.ndarray | None = None
    if isinstance(labels, bool):
        _labels = H if labels else None
    elif np.iterable(labels):
        label_array = np.asarray(labels).T
        if H.shape == label_array.shape:
            _labels = label_array
        else:
            raise ValueError(
                f"Labels input has incorrect shape (expect: {H.shape}, got: {label_array.shape})"
            )
    elif labels is not None:
        raise ValueError(
            "Labels not understood, either specify a bool or a Hist-like array"
        )

    text_artists = []
    if _labels is not None:
        if (pccmap := pc.cmap) is None:
            raise ValueError("No colormap found.")
        for ix, xc in enumerate(xbin_centers):
            for iy, yc in enumerate(ybin_centers):
                # Pick a text colour that contrasts with the bin colour
                normedh = pc.norm(H[iy, ix])
                color = "black" if isLight(pccmap(normedh)[:-1]) else "lightgrey"
                text_artists.append(
                    ax.text(
                        xc, yc, _labels[iy, ix], ha="center", va="center", color=color
                    )
                )

    return ColormeshArtists(pc, cb_obj, text_artists)
#############################################
# Utils
def overlap(ax, bbox, get_vertices=False):
    """
    Find overlap of bbox for drawn elements an axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes whose lines, collections, patches and texts are inspected.
    bbox : matplotlib.transforms.Bbox
        Bounding box to test against. It is compared with the data vertices
        after they are transformed with ``ax.transData``.
    get_vertices : bool, optional
        If True, also return the untransformed (data space) vertices.

    Returns
    -------
    count or (count, vertices)
        Number of vertices contained in ``bbox`` plus the number of artist
        bounding boxes overlapping it. ``vertices`` is an ``(N, 2)`` array,
        empty when the axes holds no line-like artists.
    """
    from matplotlib.lines import Line2D
    from matplotlib.patches import Patch, Rectangle
    from matplotlib.text import Text

    # From
    # https://github.com/matplotlib/matplotlib/blob/08008d5cb4d1f27692e9aead9a76396adc8f0b19/lib/matplotlib/legend.py#L845
    lines = []
    bboxes = []
    for handle in ax.lines:
        assert isinstance(handle, Line2D)
        lines.append(handle.get_path())
    for handle in ax.collections:
        lines.extend(path.interpolated(20) for path in handle.get_paths())
    for handle in ax.patches:
        assert isinstance(handle, Patch)

        if isinstance(handle, Rectangle):
            transform = handle.get_data_transform()
            bboxes.append(handle.get_bbox().transformed(transform))
        else:
            if len(handle.get_path().vertices) == 0:
                continue
            lines.append(handle.get_path().interpolated(20))
    for handle in ax.texts:
        assert isinstance(handle, Text)
        bboxes.append(handle.get_window_extent())
    # TODO Possibly other objects

    if lines:
        vertices = np.concatenate([line.vertices for line in lines])
    else:
        # np.concatenate raises on an empty list, e.g. for axes that only
        # contain Rectangle patches or text
        vertices = np.empty((0, 2))

    if len(vertices) > 0:
        # One vectorised transform instead of one call per vertex
        n_contained = bbox.count_contains(ax.transData.transform(vertices))
    else:
        n_contained = 0
    n_overlap = n_contained + bbox.count_overlaps(bboxes)

    if get_vertices:
        return n_overlap, vertices
    else:
        return n_overlap
def _draw_leg_bbox(ax):
    """
    Draw the figure and return the bbox of the legend frame of ``ax``.

    The legend registered on the axes is used; when there is none, the first
    ``Legend`` instance found among the axes children is taken instead
    (an ``IndexError`` is raised if no such child exists).
    """
    figure = ax.figure
    legend = ax.get_legend()
    if legend is None:
        legend_children = [
            child
            for child in ax.get_children()
            if isinstance(child, plt.matplotlib.legend.Legend)
        ]
        legend = legend_children[0]
    # Drawing is required so the legend frame has its final extent
    figure.canvas.draw()
    frame = legend.get_frame()
    return frame.get_bbox()
def _draw_text_bbox(ax):
    """
    Draw the figure and fetch the tight bbox of an AnchoredText on ``ax``.

    With a single AnchoredText its bbox is returned. With several, a warning
    is logged and the last one anchored at location code 1 or 2 is used;
    if none is anchored there, the first AnchoredText is used.

    Raises
    ------
    IndexError
        If the axes contains no AnchoredText.
    """
    fig = ax.figure
    textboxes = [k for k in ax.get_children() if isinstance(k, AnchoredText)]
    if not textboxes:
        raise IndexError("No AnchoredText found on the axes")
    # Drawing is required before the renderer / text extents are available
    fig.canvas.draw()

    chosen = textboxes[0]
    if len(textboxes) > 1:
        logging.warning("More than one textbox found")
        for box in textboxes:
            if box.loc in [1, 2]:
                chosen = box
    return chosen.get_tightbbox(fig.canvas.renderer)
def yscale_legend(
    ax: mpl.axes.Axes | None = None,
    otol: float | int | None = None,
    soft_fail: bool = False,
) -> mpl.axes.Axes:
    """
    Automatically scale y-axis up to fit in legend().

    Parameters
    ----------
        ax : matplotlib.axes.Axes, optional
            Axes object (if None, last one is fetched or one is created)
        otol : float, optional
            Tolerance for overlap, default 0. Set ``otol > 0`` for less strict scaling.
        soft_fail : bool, optional
            Set ``soft_fail=True`` to return even if it could not fit the legend.

    Returns
    -------
        ax : matplotlib.axes.Axes
            The same axes. It is returned unchanged when it has no legend.

    Raises
    ------
        RuntimeError
            If the legend still overlaps after the allowed number of
            rescalings and ``soft_fail`` is False, or if the axes has no figure.
    """
    if ax is None:
        ax = plt.gca()
    if otol is None:
        otol = 0

    # Nothing to fit without a legend (previously failed with IndexError)
    if ax.get_legend() is None and not any(
        isinstance(child, plt.matplotlib.legend.Legend) for child in ax.get_children()
    ):
        logging.debug("No legend found, y-axis is left unchanged")
        return ax

    # On a log axis each step multiplies the upper limit by 10**1.05
    scale_factor = 10 ** (1.05) if ax.get_yscale() == "log" else 1.05
    max_scales = 0
    # The overlap is computed once per iteration; every evaluation redraws
    # the canvas, so it must not be repeated just for the log message.
    while (n_overlap := overlap(ax, _draw_leg_bbox(ax))) > otol:
        logging.debug("Legend overlap with other artists is %s.", n_overlap)
        logging.info("Scaling y-axis by 5% to fit legend")
        ax.set_ylim(ax.get_ylim()[0], ax.get_ylim()[-1] * scale_factor)
        if (fig := ax.figure) is None:
            raise RuntimeError("Could not fetch figure, maybe no plot is drawn yet?")
        fig.canvas.draw()
        if max_scales > 10:
            if not soft_fail:
                raise RuntimeError(
                    "Could not fit legend in 10 iterations, return anyway by passing `soft_fail=True`."
                )
            else:
                logging.warning("Could not fit legend in 10 iterations")
                break
        max_scales += 1
    return ax
def yscale_anchored_text(
    ax: mpl.axes.Axes | None = None,
    otol: float | int | None = None,
    soft_fail: bool = False,
) -> mpl.axes.Axes:
    """
    Automatically scale y-axis up to fit AnchoredText

    Parameters
    ----------
        ax : matplotlib.axes.Axes, optional
            Axes object (if None, last one is fetched or one is created)
        otol : float, optional
            Tolerance for overlap, default 0. Set ``otol > 0`` for less strict scaling.
        soft_fail : bool, optional
            Set ``soft_fail=True`` to return even if it could not fit the
            AnchoredText.

    Returns
    -------
        ax : matplotlib.axes.Axes
            The same axes. It is returned unchanged when it holds no
            AnchoredText.

    Raises
    ------
        RuntimeError
            If the text still overlaps after the allowed number of rescalings
            and ``soft_fail`` is False, or if the axes has no figure.
    """
    if ax is None:
        ax = plt.gca()
    if otol is None:
        otol = 0

    # Nothing to fit without an AnchoredText (previously failed with IndexError)
    if not any(isinstance(child, AnchoredText) for child in ax.get_children()):
        logging.debug("No AnchoredText found, y-axis is left unchanged")
        return ax

    # On a log axis each step multiplies the upper limit by 10**1.05
    scale_factor = 10 ** (1.05) if ax.get_yscale() == "log" else 1.05
    max_scales = 0
    # The overlap is computed once per iteration; every evaluation redraws
    # the canvas, so it must not be repeated just for the log message.
    while (n_overlap := overlap(ax, _draw_text_bbox(ax))) > otol:
        logging.debug("AnchoredText overlap with other artists is %s.", n_overlap)
        logging.info("Scaling y-axis by 5% to fit AnchoredText")
        ax.set_ylim(ax.get_ylim()[0], ax.get_ylim()[-1] * scale_factor)
        if (fig := ax.figure) is None:
            raise RuntimeError("Could not fetch figure, maybe no plot is drawn yet?")
        fig.canvas.draw()
        if max_scales > 10:
            if not soft_fail:
                raise RuntimeError(
                    "Could not fit AnchoredText in 10 iterations, return anyway by passing `soft_fail=True`."
                )
            else:
                logging.warning("Could not fit AnchoredText in 10 iterations")
                break
        max_scales += 1
    return ax
def ylow(ax: mpl.axes.Axes | None = None, ylow: float | None = None) -> mpl.axes.Axes:
    """
    Set lower y limit to 0 or a specific value if no data/errors go lower.

    Parameters
    ----------
        ax : matplotlib.axes.Axes, optional
            Axes object (if None, last one is fetched or one is created)
        ylow : float, optional
            Set lower y limit to a specific value. If None, the limit is set
            to 0 unless drawn elements reach below 0, in which case the
            smaller of the lowest vertex and the current lower limit is used.

    Returns
    -------
        ax : matplotlib.axes.Axes
            The same axes. Axes with a log-scaled y-axis are returned unchanged.
    """
    if ax is None:
        ax = plt.gca()

    if ax.get_yaxis().get_scale() == "log":
        return ax

    if ylow is None:
        # Check full figsize below 0
        bbox = Bbox.from_bounds(
            0, 0, ax.get_window_extent().width, -ax.get_window_extent().height
        )
        # Single call; the vertices are only needed when something overlaps
        n_overlap, vertices = overlap(ax, bbox, get_vertices=True)
        if n_overlap == 0:
            ax.set_ylim(0, None)
        else:
            ydata = vertices[:, 1]
            lowest = ax.get_ylim()[0]
            if len(ydata) > 0:
                lowest = np.min([np.min(ydata), lowest])
            ax.set_ylim(lowest, None)

    else:
        # Use the requested value (it used to be ignored in favour of 0)
        ax.set_ylim(ylow, ax.get_ylim()[-1])

    return ax
def mpl_magic(ax=None, info=True):
    """
    Consolidate all ex-post style adjustments:
        ylow
        yscale_legend
        yscale_anchored_text

    Parameters
    ----------
        ax : matplotlib.axes.Axes, optional
            Axes object (if None, the current axes is fetched)
        info : bool, optional
            Print a short notice before the adjustments are applied.

    Returns
    -------
        ax : matplotlib.axes.Axes
    """
    if ax is None:
        ax = plt.gca()
    if info:
        print("Running ROOT/CMS style adjustments (hide with info=False):")

    # Each adjustment returns the axes it was handed; apply them in order
    for adjust in (ylow, yscale_legend, yscale_anchored_text):
        ax = adjust(ax)
    return ax
########################################
# Figure/axes helpers
def rescale_to_axessize(ax, w, h):
    """
    Resize the figure so that the axes area has the requested size.

    The figure size is derived from the requested axes size and the
    figure's subplot parameters (left/right/top/bottom fractions).

    Parameters
    ----------
        ax : matplotlib.axes.Axes
            Axes object (if falsy, the current axes is fetched)
        w, h : float
            Axes width and height in inches.
    """
    target = ax if ax else plt.gca()
    pars = target.figure.subplotpars
    x_lo, x_hi, y_hi, y_lo = pars.left, pars.right, pars.top, pars.bottom
    # The axes spans (x_hi - x_lo) x (y_hi - y_lo) of the figure
    fig_width = float(w) / (x_hi - x_lo)
    fig_height = float(h) / (y_hi - y_lo)
    target.figure.set_size_inches(fig_width, fig_height)
def box_aspect(ax, aspect=1):
    """
    Shrink the axes position box to the requested aspect ratio.

    The current position is shrunk within itself via ``shrunk_to_aspect``,
    using the figure height/width ratio, and set back on the axes.

    Parameters
    ----------
        ax : matplotlib.axes.Axes
            Axes whose position is adjusted.
        aspect : float, optional
            Aspect ratio passed to ``shrunk_to_aspect``.
    """
    current = ax.get_position()
    width_in, height_in = ax.get_figure().get_size_inches()
    figure_ratio = height_in / width_in
    frozen = current.frozen()
    ax.set_position(frozen.shrunk_to_aspect(aspect, frozen, figure_ratio))
class RemainderFixed(axes_size.Scaled):
    """
    Axes size equal to the space left over after fixed-size elements.

    ``get_size`` reports an absolute size: the smaller of the remaining
    width and height of the divider's region once the summed absolute
    sizes of ``xsizes`` / ``ysizes`` are subtracted.
    """

    def __init__(self, xsizes, ysizes, divider):
        # NOTE: the parent initialiser is intentionally not called here,
        # matching the existing behaviour; only these attributes are set.
        self.xsizes, self.ysizes, self.div = xsizes, ysizes, divider

    def get_size(self, renderer):
        """Return ``(relative, absolute)`` size; the relative part is always 0."""
        _, fixed_x = sum(self.xsizes, start=axes_size.Fixed(0)).get_size(renderer)
        _, fixed_y = sum(self.ysizes, start=axes_size.Fixed(0)).get_size(renderer)
        # Divider position transformed with the figure's transFigure
        region = Bbox.from_bounds(*self.div.get_position()).transformed(
            self.div._fig.transFigure
        )
        free_width = region.width / self.div._fig.dpi - fixed_x
        free_height = region.height / self.div._fig.dpi - fixed_y
        return 0, min(free_width, free_height)
def make_square_add_cbar(ax, size=0.4, pad=0.1):
    """
    Make input axes square and return an appended axes to the right for
    a colorbar. Both axes resize together to fit figure automatically.
    Works with tight_layout().

    Parameters
    ----------
        ax : matplotlib.axes.Axes
            Axes to make square.
        size : float, optional
            Fixed size of the appended axes.
        pad : float, optional
            Fixed padding between the main and the appended axes.

    Returns
    -------
        matplotlib.axes.Axes
            The appended axes.
    """
    divider = make_axes_locatable(ax)
    bar_size = axes_size.Fixed(size)
    gap_size = axes_size.Fixed(pad)
    fixed = [gap_size, bar_size]

    colorbar_ax = divider.append_axes("right", size=bar_size, pad=gap_size)
    # The same fixed sizes are reserved horizontally and vertically; the
    # main axes takes the remainder in both directions.
    for set_layout in (divider.set_horizontal, divider.set_vertical):
        set_layout([RemainderFixed(fixed, fixed, divider)] + fixed)
    return colorbar_ax
def append_axes(ax, size=0.1, pad=0.1, position="right", extend=False):
    """
    Append a side ax to the current figure and return it.
    Figure is automatically extended along the direction of the added axes to
    accommodate it. Unfortunately can not be reliably chained.

    Parameters
    ----------
        ax : matplotlib.axes.Axes
            Axes to append to.
        size : float or str, optional
            Size of the new axes in inches, or a percentage string (e.g.
            ``"7%"``) of the width (left/right) or height (top/bottom) of ``ax``.
        pad : float or str, optional
            Padding between the axes, same conventions as ``size``.
        position : {'right', 'left', 'top', 'bottom'}, optional
            Side on which the new axes is appended.
        extend : bool, optional
            If True, the figure is enlarged so ``ax`` keeps its original size.

    Returns
    -------
        matplotlib.axes.Axes
            The appended axes.
    """
    fig = ax.figure
    bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    width, height = bbox.width, bbox.height

    def convert(fraction, position=position):
        # Percent strings are relative to the axes extent along ``position``
        if isinstance(fraction, str) and fraction.endswith("%"):
            if position in ["right", "left"]:
                fraction = width * float(fraction.strip("%")) / 100
            elif position in ["top", "bottom"]:
                fraction = height * float(fraction.strip("%")) / 100
        return fraction

    def share_x(main_ax, side_ax):
        # Grouper.join was deprecated in Matplotlib 3.6 and later removed;
        # Axes.sharex is the documented replacement.
        grouper = main_ax.get_shared_x_axes()
        if hasattr(grouper, "join"):
            grouper.join(main_ax, side_ax)
        else:
            side_ax.sharex(main_ax)

    size = convert(size)
    pad = convert(pad)

    divider = make_axes_locatable(ax)
    margin_size = axes_size.Fixed(size)
    pad_size = axes_size.Fixed(pad)
    xsizes = [pad_size, margin_size]
    if position in ["top", "bottom"]:
        xsizes = xsizes[::-1]
    yhax = divider.append_axes(position, size=margin_size, pad=pad_size)

    if extend:

        def extend_ratio(ax, yhax):
            # Ratio of the combined extent of both axes to the main axes alone
            ax.figure.canvas.draw()
            orig_size = ax.get_position().size
            new_size = sum(itax.get_position().size for itax in [ax, yhax])
            return new_size / orig_size

        if position in ["right"]:
            divider.set_horizontal([axes_size.Fixed(width)] + xsizes)
            fig.set_size_inches(
                fig.get_size_inches()[0] * extend_ratio(ax, yhax)[0],
                fig.get_size_inches()[1],
            )
        elif position in ["left"]:
            divider.set_horizontal(xsizes[::-1] + [axes_size.Fixed(width)])
            fig.set_size_inches(
                fig.get_size_inches()[0] * extend_ratio(ax, yhax)[0],
                fig.get_size_inches()[1],
            )
        elif position in ["top"]:
            divider.set_vertical([axes_size.Fixed(height)] + xsizes[::-1])
            fig.set_size_inches(
                fig.get_size_inches()[0],
                fig.get_size_inches()[1] * extend_ratio(ax, yhax)[1],
            )
            share_x(ax, yhax)
        elif position in ["bottom"]:
            divider.set_vertical(xsizes + [axes_size.Fixed(height)])
            fig.set_size_inches(
                fig.get_size_inches()[0],
                fig.get_size_inches()[1] * extend_ratio(ax, yhax)[1],
            )
            share_x(ax, yhax)

    return yhax
####################
# Legend Helpers
def hist_legend(ax=None, **kwargs):
    """
    Draw a legend in reversed order, showing Polygon handles as lines.

    Every ``Polygon`` legend handle is replaced by a ``Line2D`` with the
    polygon's edge color; other handles are kept as they are.

    Parameters
    ----------
        ax : matplotlib.axes.Axes, optional
            Axes object (if None, the current axes is fetched)
        **kwargs :
            Keyword arguments passed to ``ax.legend``.

    Returns
    -------
        ax : matplotlib.axes.Axes
    """
    from matplotlib.lines import Line2D

    if ax is None:
        ax = plt.gca()

    handles, labels = ax.get_legend_handles_labels()
    line_handles = []
    for handle in handles:
        if isinstance(handle, mpl.patches.Polygon):
            handle = Line2D([], [], c=handle.get_edgecolor())
        line_handles.append(handle)

    reversed_handles = line_handles[::-1]
    reversed_labels = labels[::-1]
    ax.legend(handles=reversed_handles, labels=reversed_labels, **kwargs)
    return ax
def sort_legend(ax, order=None):
    """
    Return the legend handles and labels of ``ax`` in a requested order.

    Parameters
    ----------
        ax : axes with legend labels in it
        order : Ordered dict with renames or array with order
            An ``OrderedDict`` maps existing labels to new labels and defines
            the order by its keys; a list/tuple/array only defines the order.
            Entries that are not labels on the axes are dropped. ``None``
            keeps the current order.

    Returns
    -------
        (handles, labels) : tuple of lists

    Raises
    ------
        TypeError
            If ``order`` is of an unsupported type.
    """
    handles, labels = ax.get_legend_handles_labels()
    handle_for = OrderedDict(zip(labels, handles))

    renames = order if isinstance(order, OrderedDict) else None
    if renames is not None:
        requested = list(renames.keys())
    elif isinstance(order, (list, tuple, np.ndarray)):
        requested = list(order)
    elif order is None:
        requested = labels
    else:
        raise TypeError(f"Unexpected values type of order: {type(order)}")

    # Only labels that actually exist on the axes can be placed
    present = [name for name in requested if name in labels]
    sorted_handles = [handle_for[name] for name in present]
    sorted_labels = (
        present if renames is None else [renames[name] for name in present]
    )
    return sorted_handles, sorted_labels