diff --git a/docs/how_to_guide/24_plot_summary_training_logs.ipynb b/docs/how_to_guide/24_plot_summary_training_logs.ipynb new file mode 100644 index 000000000..4b7f771b8 --- /dev/null +++ b/docs/how_to_guide/24_plot_summary_training_logs.ipynb @@ -0,0 +1,139 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5c901931", + "metadata": {}, + "source": [ + "# How to visualize training logs with `plot_summary`\n", + "\n", + "The `plot_summary` function plots data logged by the tensorboard summary writer during inference training. It is useful for inspecting training and validation loss curves." + ] + }, + { + "cell_type": "markdown", + "id": "e6a2bb95", + "metadata": {}, + "source": [ + "## Basic usage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4ee1a507", + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "from sbi.analysis import plot_summary\n", + "\n", + "fig, axes = plot_summary(inference, tags=[\"validation_loss\"])" + ] + }, + { + "cell_type": "markdown", + "id": "0711dc51", + "metadata": {}, + "source": [ + "## Overlaying training and validation loss\n", + "\n", + "The most common use case is comparing training and validation loss to check for overfitting. 
Set `overlay=True` to plot all tags on a single axes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77f1e3b8", + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "fig, axes = plot_summary(\n", + " inference,\n", + " tags=[\"training_loss\", \"validation_loss\"],\n", + " overlay=True,\n", + " colors=[\"blue\", \"orange\"],\n", + " labels=[\"Train\", \"Validation\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "22a8d01e", + "metadata": {}, + "source": [ + "## Customization\n", + "\n", + "You can add titles, grid lines, and adjust fonts:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5190ebe3", + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "fig, axes = plot_summary(\n", + " inference,\n", + " tags=[\"training_loss\", \"validation_loss\"],\n", + " overlay=True,\n", + " colors=[\"blue\", \"orange\"],\n", + " labels=[\"Train\", \"Validation\"],\n", + " title=\"Loss Curves\",\n", + " grid=True,\n", + " ylabel=[\"Loss\"],\n", + " fontsize=14,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "566cb166", + "metadata": {}, + "source": [ + "## Using a log directory directly\n", + "\n", + "You can also pass a `Path` to a tensorboard log directory instead of an inference object:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a3116ef8", + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "fig, axes = plot_summary(\n", + " Path(\"path/to/log/dir\"),\n", + " tags=[\"training_loss\", \"validation_loss\"],\n", + " overlay=True,\n", + ")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/how_to_guide/visualization.rst b/docs/how_to_guide/visualization.rst index 0a8008ba6..3f2a41c39 100644 --- 
overlay: if True, plots all tags on a single axes instead of separate subplots. + colors: list of colors, one per tag. If None, uses matplotlib's default colors. + labels: list of legend labels, one per tag. If None, uses tag names. + legend: whether to show a legend. Only has an effect when overlay is True.
""" @@ -109,25 +122,67 @@ def plot_summary( plot_options.update(figsize=figsize, fontsize=fontsize) if fig is None or axes is None: + num_subplots = len(tags) + if overlay: + num_subplots = 1 fig, axes = plt.subplots( # pyright: ignore[reportAssignmentType] 1, - len(tags), + num_subplots, figsize=plot_options["figsize"], **plot_options["subplots"], ) axes = np.atleast_1d(axes) # type: ignore + assert fig is not None and axes is not None + _labels = labels or tags ylabel = ylabel or tags - for i, ax in enumerate(axes): # type: ignore - ax.plot( - scalars[tags[i]]["step"], scalars[tags[i]]["value"], **plot_kwargs or {} - ) - - ax.set_ylabel(ylabel[i], fontsize=fontsize) + if overlay: + ax = axes[0] + for i, tag in enumerate(tags): + color = colors[i] if colors else None + ax.plot( + scalars[tag]["step"], + scalars[tag]["value"], + color=color, + label=_labels[i], + **(plot_kwargs or {}), + ) ax.set_xlabel(xlabel, fontsize=fontsize) + # If overlay, we join all y labels + ax.set_ylabel( + ylabel[0] if len(set(ylabel)) == 1 else " / ".join(ylabel), + fontsize=fontsize, + ) ax.xaxis.set_tick_params(labelsize=fontsize) ax.yaxis.set_tick_params(labelsize=fontsize) + if legend: + ax.legend(fontsize=fontsize) + if grid: + ax.grid(True) + if title: + t = title if isinstance(title, str) else title[0] + ax.set_title(t, fontsize=fontsize) + else: + for i, ax in enumerate(axes): # type: ignore + color = colors[i] if colors else None + ax.plot( + scalars[tags[i]]["step"], + scalars[tags[i]]["value"], + color=color, + label=_labels[i], + **plot_kwargs or {}, + ) + + ax.set_ylabel(ylabel[i], fontsize=fontsize) + ax.set_xlabel(xlabel, fontsize=fontsize) + ax.xaxis.set_tick_params(labelsize=fontsize) + ax.yaxis.set_tick_params(labelsize=fontsize) + if grid: + ax.grid(True) + if title: + t = title if isinstance(title, str) else title[i] + ax.set_title(t, fontsize=fontsize) plt.subplots_adjust(wspace=0.3) diff --git a/tests/test_plot_summary.py b/tests/test_plot_summary.py new file mode 
100644 index 000000000..5a1d9cabc --- /dev/null +++ b/tests/test_plot_summary.py @@ -0,0 +1,135 @@ +"""Tests for the enhanced plot_summary function.""" + +from unittest.mock import patch + +import matplotlib.pyplot as plt +import pytest + +from sbi.analysis.tensorboard_output import plot_summary + + +@pytest.fixture +def mock_scalars(): + """Mock scalar data mimicking tensorboard event data.""" + return { + "training_loss": { + "step": list(range(100)), + "value": [1.0 / (i + 1) for i in range(100)], + }, + "validation_loss": { + "step": list(range(100)), + "value": [1.2 / (i + 1) for i in range(100)], + }, + } + + +@pytest.fixture +def mock_inference(mock_scalars): + """Patch event data loading so we don't need real tensorboard logs.""" + with patch( + "sbi.analysis.tensorboard_output._get_event_data_from_log_dir" + ) as mock_get: + mock_get.return_value = {"scalars": mock_scalars} + from pathlib import Path + + yield Path("/fake/log/dir") + + +class TestPlotSummaryBackwardCompat: + """Existing behavior should not change.""" + + def test_single_tag(self, mock_inference): + fig, axes = plot_summary( + mock_inference, + tags=["training_loss"], + disable_tensorboard_prompt=True, + ) + assert axes.shape == (1,) + plt.close(fig) + + def test_multiple_tags_separate_subplots(self, mock_inference): + fig, axes = plot_summary( + mock_inference, + tags=["training_loss", "validation_loss"], + disable_tensorboard_prompt=True, + ) + assert axes.shape == (2,) + plt.close(fig) + + +class TestPlotSummaryOverlay: + """New overlay functionality.""" + + def test_overlay_creates_single_axes(self, mock_inference): + fig, axes = plot_summary( + mock_inference, + tags=["training_loss", "validation_loss"], + overlay=True, + disable_tensorboard_prompt=True, + ) + assert axes.shape == (1,) + # Should have 2 lines on the single axes + assert len(axes[0].get_lines()) == 2 + plt.close(fig) + + def test_overlay_with_colors(self, mock_inference): + fig, axes = plot_summary( + mock_inference, + 
tags=["training_loss", "validation_loss"], + overlay=True, + colors=["blue", "orange"], + disable_tensorboard_prompt=True, + ) + lines = axes[0].get_lines() + assert lines[0].get_color() == "blue" + assert lines[1].get_color() == "orange" + plt.close(fig) + + def test_overlay_with_labels(self, mock_inference): + fig, axes = plot_summary( + mock_inference, + tags=["training_loss", "validation_loss"], + overlay=True, + labels=["Train", "Val"], + disable_tensorboard_prompt=True, + ) + legend = axes[0].get_legend() + assert legend is not None + texts = [t.get_text() for t in legend.get_texts()] + assert texts == ["Train", "Val"] + plt.close(fig) + + def test_overlay_legend_disabled(self, mock_inference): + fig, axes = plot_summary( + mock_inference, + tags=["training_loss", "validation_loss"], + overlay=True, + legend=False, + disable_tensorboard_prompt=True, + ) + assert axes[0].get_legend() is None + plt.close(fig) + + +class TestPlotSummaryGrid: + def test_grid_enabled(self, mock_inference): + fig, axes = plot_summary( + mock_inference, + tags=["training_loss"], + grid=True, + disable_tensorboard_prompt=True, + ) + assert axes[0].xaxis.get_gridlines()[0].get_visible() + plt.close(fig) + + +class TestPlotSummaryTitle: + def test_single_title(self, mock_inference): + fig, axes = plot_summary( + mock_inference, + tags=["training_loss"], + title="My Plot", + disable_tensorboard_prompt=True, + ) + assert axes[0].get_title() == "My Plot" + plt.close(fig)