139 changes: 139 additions & 0 deletions docs/how_to_guide/24_plot_summary_training_logs.ipynb
Contributor
Good work on this! A couple of comments:

  • Let's keep it as concise as possible, but also make the how-to-guide code cells executable, i.e., not just plain text.
  • Guide 22 already covers TensorBoard setup, training, and viewing logs. This guide should focus narrowly on plot_summary usage and link to guide 22 for the setup, rather than re-explaining how to pass a Path or inference object.

I suggest the following outline / cells:

  1. One-liner intro: what plot_summary does, link to guide 22 for training/tracking setup
  2. One runnable setup cell: quick NPE training (reuse the pattern from guide 22: prior, simulator, train, get log_dir)
  3. Basic usage: single plot_summary call, show output
  4. Overlay with customization: one cell combining overlay, colors, labels, title, grid. This is the main feature to showcase
  5. Drop the standalone "Using a log directory" section; just mention in the intro that you can pass either an inference object or a Path.

@@ -0,0 +1,139 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "5c901931",
"metadata": {},
"source": [
"# How to visualize training logs with `plot_summary`\n",
"\n",
"The `plot_summary` function plots data logged by the tensorboard summary writer during inference training. It is useful for inspecting training and validation loss curves."
]
},
{
"cell_type": "markdown",
"id": "e6a2bb95",
"metadata": {},
"source": [
"## Basic usage"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4ee1a507",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"from sbi.analysis import plot_summary\n",
"\n",
"fig, axes = plot_summary(inference, tags=[\"validation_loss\"])"
]
},
{
"cell_type": "markdown",
"id": "0711dc51",
"metadata": {},
"source": [
"## Overlaying training and validation loss\n",
"\n",
"The most common use case is comparing training and validation loss to check for overfitting. Set `overlay=True` to plot all tags on a single axes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "77f1e3b8",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"fig, axes = plot_summary(\n",
" inference,\n",
" tags=[\"training_loss\", \"validation_loss\"],\n",
" overlay=True,\n",
" colors=[\"blue\", \"orange\"],\n",
" labels=[\"Train\", \"Validation\"],\n",
")"
]
},
{
"cell_type": "markdown",
"id": "22a8d01e",
"metadata": {},
"source": [
"## Customization\n",
"\n",
"You can add titles, grid lines, and adjust fonts:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5190ebe3",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"fig, axes = plot_summary(\n",
" inference,\n",
" tags=[\"training_loss\", \"validation_loss\"],\n",
" overlay=True,\n",
" colors=[\"blue\", \"orange\"],\n",
" labels=[\"Train\", \"Validation\"],\n",
" title=\"Loss Curves\",\n",
" grid=True,\n",
" ylabel=[\"Loss\"],\n",
" fontsize=14,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "566cb166",
"metadata": {},
"source": [
"## Using a log directory directly\n",
"\n",
"You can also pass a `Path` to a tensorboard log directory instead of an inference object:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a3116ef8",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"fig, axes = plot_summary(\n",
" Path(\"path/to/log/dir\"),\n",
" tags=[\"training_loss\", \"validation_loss\"],\n",
" overlay=True,\n",
")"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions docs/how_to_guide/visualization.rst
@@ -8,3 +8,4 @@ Visualization
:maxdepth: 1

05_conditional_distributions.ipynb
24_plot_summary_training_logs.ipynb
69 changes: 62 additions & 7 deletions sbi/analysis/tensorboard_output.py
Contributor
I made a couple of comments on the if-else structures below, but more generally, I think we can avoid them entirely. I suggest constructing a list of tag groups, which can also be a list with a single tag (for the overlay=False case); then you need just a single loop over this list to render the figure.

I suggest: Before creating the figure, convert the flat tags list into subplot_tags, a list of lists where each inner list is the group of tags for one subplot. In overlay mode, this is [tags] (one subplot, all tags). In non-overlay mode it's [[tag] for tag in tags] (one subplot per tag). Similarly, precompute subplot_ylabels: in overlay mode, join the ylabels if they differ (e.g. "training_loss / validation_loss"), otherwise use them directly. Does this make sense?

You would still need the helper function _build_plot_kwargs(tag_idx) that constructs the keyword arguments dict for a single ax.plot() call. It starts from a copy of plot_kwargs (or empty dict), then overrides color if colors was provided, overrides label if labels was provided, and falls back to the tag name as label if neither labels nor a plot_kwargs['label'] key exists. This avoids the duplicate-keyword problem (see comment below).

Overall, with this approach you can have a single loop and we avoid the current code duplication in the two branches.

@@ -38,6 +38,12 @@ def plot_summary(
xlabel: str = "epochs_trained",
ylabel: Optional[List[str]] = None,
plot_kwargs: Optional[Dict[str, Any]] = None,
overlay: bool = False,
colors: Optional[List[str]] = None,
labels: Optional[List[str]] = None,
legend: bool = True,
grid: bool = False,
title: Optional[Union[str, List[str]]] = None,
) -> Tuple[Figure, Axes]:
"""Plots data logged by the TensorBoard tracker of an inference object.

@@ -55,6 +61,13 @@
xlabel: x-axis label describing 'steps' attribute of tensorboards ScalarEvent.
ylabel: list of alternative ylabels for items in tags. Optional.
plot_kwargs: will be passed to ax.plot.
overlay: if True, plots all tags on a single axes intead of separate subplots.
Contributor
typo: instead.

colors: list of colors, one per tag. If None, uses matplotlib's default colors.
labels: list of legend labels, one per tag. If None, uses tag names.
legend: optionally shows a legend when overlay is True or when labels provided
grid: whether to show grid lines.
title: title for the figure or individual subplots. A string sets the title
for all subplots. A list of strings sets titles per subplot.

Returns a tuple of Figure and Axes objects.
"""
@@ -109,25 +122,67 @@

plot_options.update(figsize=figsize, fontsize=fontsize)
if fig is None or axes is None:
num_subplots = len(tags)
if overlay:
num_subplots = 1
fig, axes = plt.subplots( # pyright: ignore[reportAssignmentType]
1,
len(tags),
num_subplots,
figsize=plot_options["figsize"],
**plot_options["subplots"],
)
axes = np.atleast_1d(axes) # type: ignore
assert fig is not None and axes is not None
Contributor
it's better practice to do an if-else and raise a suitable error here, e.g., RuntimeError or ValueError with an informative error message.


_labels = labels or tags
ylabel = ylabel or tags

for i, ax in enumerate(axes): # type: ignore
ax.plot(
scalars[tags[i]]["step"], scalars[tags[i]]["value"], **plot_kwargs or {}
)

ax.set_ylabel(ylabel[i], fontsize=fontsize)
if overlay:
Contributor
I think we will have a kwargs problem here: both branches of this if-else always pass explicit color=color and label=_labels[i] keyword arguments to ax.plot(), as well as **(plot_kwargs or {}).
Thus, if a user passes plot_kwargs={"color": "red"} or plot_kwargs={"label": "foo"}, it raises TypeError: got multiple values for keyword argument 'color'.

The fix would be merging all kwargs into a single dict before passing to ax.plot(), implementing a hierarchy: explicit colors/labels params over plot_kwargs, over tag-name defaults. I suggest implementing a short helper function, e.g., _build_plot_kwargs, that does this and returns the final kwargs to be passed to ax.plot().
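A quick self-contained illustration of the failure and the merge fix (a stub stands in for ax.plot, since Python rejects the duplicate keyword at the call site regardless of the callee):

```python
def fake_plot(*args, **kwargs):
    """Stand-in for ax.plot; accepts any keyword arguments."""
    return kwargs

plot_kwargs = {"color": "red"}
try:
    # Duplicate 'color': one explicit keyword plus the same key via **plot_kwargs.
    fake_plot([0, 1], [1, 0], color="blue", **plot_kwargs)
    raised = False
except TypeError:
    # Python raises before fake_plot even runs.
    raised = True

# The fix: merge into one dict first; the explicit value takes precedence.
merged = {**plot_kwargs, "color": "blue"}
result = fake_plot([0, 1], [1, 0], **merged)
```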

ax = axes[0]
for i, tag in enumerate(tags):
color = colors[i] if colors else None
ax.plot(
scalars[tag]["step"],
scalars[tag]["value"],
color=color,
label=_labels[i],
**(plot_kwargs or {}),
)
ax.set_xlabel(xlabel, fontsize=fontsize)
# If overlay, we join all y labels
ax.set_ylabel(
ylabel[0] if len(set(ylabel)) == 1 else " / ".join(ylabel),
fontsize=fontsize,
)
ax.xaxis.set_tick_params(labelsize=fontsize)
ax.yaxis.set_tick_params(labelsize=fontsize)
if legend:
ax.legend(fontsize=fontsize)
if grid:
ax.grid(True)
if title:
t = title if isinstance(title, str) else title[0]
ax.set_title(t, fontsize=fontsize)
Comment on lines +159 to +165
Contributor
can we move these if-else cases outside of the overlay if-else case? this would reduce code duplication, plus at the moment the legend would not be rendered for the non-overlay case (or is this on purpose?).
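For instance, the shared styling could live in a single loop after plotting (a minimal standalone sketch; the variable values are stand-ins for plot_summary's parameters):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

# Stand-ins for plot_summary's parameters (hypothetical values).
xlabel, fontsize = "epochs_trained", 12
legend, grid, title = True, True, ["training_loss", "validation_loss"]

fig, axes = plt.subplots(1, 2)
axes = np.atleast_1d(axes)
for ax in axes:
    ax.plot([0, 1, 2], [1.0, 0.5, 0.2], label="loss")

# Shared styling: applied once per axes, outside any overlay branch.
for i, ax in enumerate(axes):
    ax.set_xlabel(xlabel, fontsize=fontsize)
    ax.tick_params(labelsize=fontsize)
    if legend:
        ax.legend(fontsize=fontsize)
    if grid:
        ax.grid(True)
    if title:
        ax.set_title(title if isinstance(title, str) else title[i], fontsize=fontsize)
```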

else:
for i, ax in enumerate(axes): # type: ignore
color = colors[i] if colors else None
ax.plot(
scalars[tags[i]]["step"],
scalars[tags[i]]["value"],
color=color,
label=_labels[i],
**plot_kwargs or {},
)

ax.set_ylabel(ylabel[i], fontsize=fontsize)
ax.set_xlabel(xlabel, fontsize=fontsize)
ax.xaxis.set_tick_params(labelsize=fontsize)
ax.yaxis.set_tick_params(labelsize=fontsize)
if grid:
ax.grid(True)
if title:
t = title if isinstance(title, str) else title[i]
ax.set_title(t, fontsize=fontsize)

plt.subplots_adjust(wspace=0.3)

135 changes: 135 additions & 0 deletions tests/test_plot_summary.py
@@ -0,0 +1,135 @@
"""Tests for the enhanced plot_summary function."""

from unittest.mock import patch

import matplotlib.pyplot as plt
import pytest

from sbi.analysis.tensorboard_output import plot_summary


@pytest.fixture
def mock_scalars():
"""Mock scalar data mimicking tensorboard event data."""
return {
"training_loss": {
"step": list(range(100)),
"value": [1.0 / (i + 1) for i in range(100)],
},
"validation_loss": {
"step": list(range(100)),
"value": [1.2 / (i + 1) for i in range(100)],
},
}


@pytest.fixture
def mock_inference(mock_scalars):
"""Patch event data loading so we don't need real tensorboard logs."""
with patch(
"sbi.analysis.tensorboard_output._get_event_data_from_log_dir"
) as mock_get:
mock_get.return_value = {"scalars": mock_scalars}
from pathlib import Path

yield Path("/fake/log/dir")


class TestPlotSummaryBackwardCompat:
Contributor
let's please use plain test functions instead of these test classes, unless we really need the classes, e.g., when we share setup or fixtures within a class.

"""Existing behavior should not change."""

def test_single_tag(self, mock_inference):
fig, axes = plot_summary(
mock_inference,
tags=["training_loss"],
disable_tensorboard_prompt=True,
)
assert axes.shape == (1,)
plt.close(fig)

def test_multiple_tags_separate_subplots(self, mock_inference):
fig, axes = plot_summary(
mock_inference,
tags=["training_loss", "validation_loss"],
disable_tensorboard_prompt=True,
)
assert axes.shape == (2,)
plt.close(fig)


class TestPlotSummaryOverlay:
"""New overlay functionality."""

def test_overlay_creates_single_axes(self, mock_inference):
fig, axes = plot_summary(
mock_inference,
tags=["training_loss", "validation_loss"],
overlay=True,
disable_tensorboard_prompt=True,
)
assert axes.shape == (1,)
# Should have 2 lines on the single axes
assert len(axes[0].get_lines()) == 2
plt.close(fig)

def test_overlay_with_colors(self, mock_inference):
fig, axes = plot_summary(
mock_inference,
tags=["training_loss", "validation_loss"],
overlay=True,
colors=["blue", "orange"],
disable_tensorboard_prompt=True,
)
lines = axes[0].get_lines()
assert lines[0].get_color() == "blue"
assert lines[1].get_color() == "orange"
plt.close(fig)

def test_overlay_with_labels(self, mock_inference):
fig, axes = plot_summary(
mock_inference,
tags=["training_loss", "validation_loss"],
overlay=True,
labels=["Train", "Val"],
disable_tensorboard_prompt=True,
)
legend = axes[0].get_legend()
assert legend is not None
texts = [t.get_text() for t in legend.get_texts()]
assert texts == ["Train", "Val"]
plt.close(fig)

def test_overlay_legend_disabled(self, mock_inference):
fig, axes = plot_summary(
mock_inference,
tags=["training_loss", "validation_loss"],
overlay=True,
legend=False,
disable_tensorboard_prompt=True,
)
assert axes[0].get_legend() is None
plt.close(fig)


class TestPlotSummaryGrid:
def test_grid_enabled(self, mock_inference):
fig, axes = plot_summary(
mock_inference,
tags=["training_loss"],
grid=True,
disable_tensorboard_prompt=True,
)
assert axes[0].xaxis.get_gridlines()[0].get_visible()
plt.close(fig)


class TestPlotSummaryTitle:
def test_single_title(self, mock_inference):
fig, axes = plot_summary(
mock_inference,
tags=["training_loss"],
title="My Plot",
disable_tensorboard_prompt=True,
)
assert axes[0].get_title() == "My Plot"
plt.close(fig)