"""Energy plot code."""
from collections.abc import Mapping, Sequence
from importlib import import_module
from typing import Any, Literal
import numpy as np
import xarray as xr
from arviz_base import convert_to_dataset, rcParams
from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.dist_plot import plot_dist
from arviz_plots.plots.utils import filter_aes, get_visual_kwargs, set_grid_layout
from arviz_plots.visuals import labelled_title, labelled_y, scatter_xy, vline
[docs]
def plot_energy(
    dt,
    *,
    sample_dims=None,
    kind=None,
    show_bfmi=True,
    threshold=0.3,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "dist",
            "title",
            "bfmi_points",
            "ref_line",
        ],
        Sequence[str],
    ]
    | None = None,
    visuals: Mapping[
        Literal[
            "dist",
            "face",
            "title",
            "legend",
            "remove_axis",
            "bfmi_points",
            "ref_line",
            "ylabel",
        ],
        Mapping[str, Any] | bool,
    ]
    | None = None,
    stats: Mapping[Literal["dist"], Mapping[str, Any] | xr.Dataset] | None = None,
    **pc_kwargs,
):
    r"""Plot energy distributions and bfmi from gradient-based algorithms.

    Generate a figure with the marginal energy distribution and the energy transition
    distribution. Optionally, include a BFMI panel to inspect chain-wise Bayesian Fraction
    of Missing Information values. Values below the threshold indicate poor exploration
    of the energy distribution.

    For details on BFMI and energy diagnostics see [1]_. For a more practical overview check
    the EABM chapter on MCMC diagnostic `of gradient-based algorithms <https://arviz-devs.github.io/EABM/Chapters/MCMC_diagnostics.html#diagnosis-of-gradient-based-algorithms>`_.

    Parameters
    ----------
    dt : DataTree
        ``sample_stats`` group with an ``energy`` variable is mandatory.
    sample_dims : sequence of str, optional
        Dimensions to consider as sample dimensions when computing BFMI.
        Defaults to ``rcParams["data.sample_dims"]``
    kind : {"kde", "hist", "dot", "ecdf"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    show_bfmi : bool, default True
        Whether to include the BFMI scatter plot. If ``False``, only the energy plot will be shown.
    threshold : float, default 0.3
        Reference threshold for BFMI values, values below this indicate poor exploration of the
        energy distribution.
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh", "plotly"}, optional
    labeller : labeller, optional
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.
    visuals : mapping of {str : mapping or bool}, optional
        Valid keys are:

        * dist -> depending on the value of `kind` passed to:

          * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy`
          * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          * "hist" -> passed to :func:`~arviz_plots.visuals.step_hist`
          * "dot" -> passed to :func:`~arviz_plots.visuals.scatter_xy`

        * face -> :term:`visual` that fills the area under the energy distributions.
          Defaults to True. Depending on the value of `kind` it is passed to:

          * "kde" or "ecdf" -> passed to :func:`~arviz_plots.visuals.fill_between_y`
          * "hist" -> passed to :func:`~arviz_plots.visuals.hist`
          * "dot" -> ignored

        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * legend -> passed to :class:`arviz_plots.PlotCollection.add_legend`
        * remove_axis -> not passed anywhere, can only be ``False`` to skip calling this function
        * bfmi_points -> passed to :func:`~arviz_plots.visuals.scatter_xy` for BFMI scatter plot
        * ref_line -> passed to :func:`~arviz_plots.visuals.vline` for the BFMI threshold line
        * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y` for BFMI column y-axis label

        The ``credible_interval``, ``point_estimate`` and ``point_estimate_text`` keys
        default to ``False`` in this function.

    stats : mapping, optional
        Valid keys are:

        * dist -> passed to kde, ecdf, ...

    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.grid`

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If `kind` is not one of "kde", "hist", "ecdf" or "dot".

    Examples
    --------
    Plot an energy plot using ecdf for the energy distributions.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_energy, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> data = load_arviz_data('non_centered_eight')
        >>> plot_energy(data, kind="ecdf")

    .. minigallery:: plot_energy

    References
    ----------
    .. [1] Betancourt. Diagnosing Suboptimal Cotangent Disintegrations in
        Hamiltonian Monte Carlo. (2016) https://arxiv.org/abs/1604.00695
    """  # pylint: disable=line-too-long
    if kind is None:
        kind = rcParams["plot.density_kind"]
    # Work on copies so the caller's mappings are never mutated by the setdefault calls below.
    if visuals is None:
        visuals = {}
    else:
        visuals = visuals.copy()
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if kind not in ("kde", "hist", "ecdf", "dot"):
        raise ValueError("kind must be either 'kde', 'hist', 'ecdf' or 'dot'")

    energy_ds, bfmi_ds = _get_energy_ds(dt, sample_dims=sample_dims)

    if backend is None:
        backend = rcParams["plot.backend"]
    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")

    if plot_collection is None:
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        # A "column" dimension is added so the grid gets one plot per panel:
        # "bfmi" (optional, narrower) and "energy".
        if show_bfmi:
            new_ds = energy_ds.expand_dims(column=2).assign_coords(column=["bfmi", "energy"])
            pc_kwargs["figure_kwargs"].setdefault("width_ratios", [1, 3])
            num_cols = 2
        else:
            new_ds = energy_ds.expand_dims(column=1).assign_coords(column=["energy"])
            num_cols = 1
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        pc_kwargs.setdefault("cols", ["column"])
        pc_kwargs["aes"].setdefault("color", ["energy"])
        pc_kwargs = set_grid_layout(pc_kwargs, plot_bknd, new_ds, num_cols=num_cols, num_rows=1)
        plot_collection = PlotCollection.grid(
            new_ds,
            backend=backend,
            **pc_kwargs,
        )

    # plot_dist elements that are switched off (or on, for "face") unless the user asks otherwise
    visuals.setdefault("credible_interval", False)
    visuals.setdefault("point_estimate", False)
    visuals.setdefault("point_estimate_text", False)
    visuals.setdefault("face", True)

    if aes_by_visuals is None:
        aes_by_visuals = {}
    else:
        aes_by_visuals = aes_by_visuals.copy()

    # Energy distributions plot
    # `coords` restricts the following calls to the "energy" column; it is reset right after.
    plot_collection.coords = {"column": "energy"}
    plot_collection = plot_dist(
        energy_ds,
        var_names=None,
        filter_vars=None,
        group=None,
        coords=None,
        sample_dims=sample_dims,
        kind=kind,
        point_estimate=None,
        ci_kind=None,
        ci_prob=None,
        plot_collection=plot_collection,
        backend=backend,
        labeller=labeller,
        aes_by_visuals=aes_by_visuals,
        visuals=visuals,
        stats=stats,
    )
    plot_collection.coords = None

    # legend for energy distributions
    legend_kwargs = get_visual_kwargs(visuals, "legend")
    if legend_kwargs is not False:
        legend_kwargs.setdefault("dim", ["energy"])
        legend_kwargs.setdefault("title", "")
        plot_collection.add_legend(**legend_kwargs)

    if show_bfmi:
        # Scatter plot of BFMI values
        bfmi_ms_kwargs = get_visual_kwargs(visuals, "bfmi_points")
        if bfmi_ms_kwargs is not False:
            _, bfmi_aes, bfmi_ignore = filter_aes(
                plot_collection, aes_by_visuals, "bfmi_points", []
            )
            # Only fall back to a fixed color when no color aesthetic is active for
            # this visual, same as done for "ref_line" and "title" below.
            if "color" not in bfmi_aes:
                bfmi_ms_kwargs.setdefault("color", "B1")
            plot_collection.coords = {"column": "bfmi"}
            plot_collection.map(
                scatter_xy,
                "bfmi_points",
                data=bfmi_ds,
                ignore_aes=bfmi_ignore,
                **bfmi_ms_kwargs,
            )
            plot_collection.coords = None

        # Reference line for BFMI threshold
        ref_line_kwargs = get_visual_kwargs(visuals, "ref_line")
        if ref_line_kwargs is not False:
            _, ref_aes, ref_ignore = filter_aes(
                plot_collection, aes_by_visuals, "ref_line", sample_dims
            )
            if "color" not in ref_aes:
                ref_line_kwargs.setdefault("color", "B2")
            if "linestyle" not in ref_aes:
                ref_line_kwargs.setdefault("linestyle", "C1")
            # Wrap threshold into an xr.Dataset for PlotCollection.map
            ref_ds = xr.Dataset({"ref_line": xr.DataArray(threshold)})
            plot_collection.coords = {"column": "bfmi"}
            plot_collection.map(
                vline, "ref_line", data=ref_ds, ignore_aes=ref_ignore, **ref_line_kwargs
            )
            plot_collection.coords = None

        # Add title for BFMI plot
        title_kwargs = get_visual_kwargs(visuals, "title")
        if title_kwargs is not False:
            _, title_aes, title_ignore = filter_aes(
                plot_collection, aes_by_visuals, "title", sample_dims
            )
            if "color" not in title_aes:
                title_kwargs.setdefault("color", "B1")
            # setdefault (instead of a hard-coded keyword) lets users override the
            # text through `visuals` without triggering a duplicate-keyword error.
            title_kwargs.setdefault("text", "BFMI")
            plot_collection.coords = {"column": "bfmi"}
            plot_collection.map(
                labelled_title,
                "title",
                ignore_aes=title_ignore,
                subset_info=True,
                labeller=labeller,
                **title_kwargs,
            )
            plot_collection.coords = None

        # Add ylabel for BFMI plot
        ylabel_kwargs = get_visual_kwargs(visuals, "ylabel")
        if ylabel_kwargs is not False:
            ylabel_kwargs.setdefault("text", "Chain")
            # An empty mapping is passed here, so `aes_by_visuals` is not consulted for the ylabel.
            _, _, ylabel_ignore = filter_aes(plot_collection, {}, "ylabel", [])
            plot_collection.coords = {"column": "bfmi"}
            plot_collection.map(
                labelled_y,
                "ylabel",
                ignore_aes=ylabel_ignore,
                **ylabel_kwargs,
            )
            plot_collection.coords = None

    return plot_collection
def _get_energy_ds(dt, sample_dims):
    """Build the datasets used by the energy and BFMI panels.

    Parameters
    ----------
    dt : DataTree
        Must provide ``sample_stats`` with an ``energy`` variable.
    sample_dims : sequence of str
        Forwarded to the ``azstats.bfmi`` accessor as ``sample_dims``.

    Returns
    -------
    energy_ds : Dataset
        Dataset with an ``Energy`` variable whose ``energy`` dimension has the
        coordinate values "marginal" (energy minus its overall mean) and
        "transition" (first differences of the energy).
    bfmi_ds : Dataset
        Dataset with a ``bfmi`` variable of dims ``("chain", "plot_axis")`` where
        "x" holds the BFMI values and "y" a 0-based running index.
    """
    energy_da = dt["sample_stats"]["energy"]
    raw_energy = energy_da.values

    # --- BFMI panel data: one (x, y) pair per entry of the BFMI result ---
    # NOTE(review): presumably the BFMI result is 1-D over "chain" — confirm against azstats
    bfmi_da = energy_da.azstats.bfmi(sample_dims=sample_dims)
    positions = np.arange(len(bfmi_da))
    xy_pairs = np.column_stack([bfmi_da.values, positions])
    bfmi_variable = xr.DataArray(
        xy_pairs,
        dims=["chain", "plot_axis"],
        coords={"chain": bfmi_da.chain, "plot_axis": ["x", "y"]},
    )
    bfmi_ds = xr.Dataset({"bfmi": bfmi_variable})

    # --- Energy panel data: centered energy and energy transitions ---
    centered = raw_energy - raw_energy.mean()
    # Padding with NaN keeps the differenced array the same length as the input
    # along the last axis so both arrays can be stacked.
    transitions = np.diff(raw_energy, append=np.nan)
    stacked = np.dstack([centered, transitions])
    converted = convert_to_dataset(
        {"Energy": stacked},
        coords={"Energy_dim_0": ["marginal", "transition"]},
    )
    energy_ds = converted.rename({"Energy_dim_0": "energy"})

    return energy_ds, bfmi_ds