Improve plot_summary with overlay support, colors, labels, grid, and title (#1733) #1814
New file — the tutorial notebook added by this PR:

```
@@ -0,0 +1,139 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5c901931",
   "metadata": {},
   "source": [
    "# How to visualize training logs with `plot_summary`\n",
    "\n",
    "The `plot_summary` function plots data logged by the tensorboard summary writer during inference training. It is useful for inspecting training and validation loss curves."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6a2bb95",
   "metadata": {},
   "source": [
    "## Basic usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ee1a507",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from sbi.analysis import plot_summary\n",
    "\n",
    "fig, axes = plot_summary(inference, tags=[\"validation_loss\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0711dc51",
   "metadata": {},
   "source": [
    "## Overlaying training and validation loss\n",
    "\n",
    "The most common use case is comparing training and validation loss to check for overfitting. Set `overlay=True` to plot all tags on a single axes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77f1e3b8",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "fig, axes = plot_summary(\n",
    "    inference,\n",
    "    tags=[\"training_loss\", \"validation_loss\"],\n",
    "    overlay=True,\n",
    "    colors=[\"blue\", \"orange\"],\n",
    "    labels=[\"Train\", \"Validation\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22a8d01e",
   "metadata": {},
   "source": [
    "## Customization\n",
    "\n",
    "You can add titles, grid lines, and adjust fonts:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5190ebe3",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "fig, axes = plot_summary(\n",
    "    inference,\n",
    "    tags=[\"training_loss\", \"validation_loss\"],\n",
    "    overlay=True,\n",
    "    colors=[\"blue\", \"orange\"],\n",
    "    labels=[\"Train\", \"Validation\"],\n",
    "    title=\"Loss Curves\",\n",
    "    grid=True,\n",
    "    ylabel=[\"Loss\"],\n",
    "    fontsize=14,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "566cb166",
   "metadata": {},
   "source": [
    "## Using a log directory directly\n",
    "\n",
    "You can also pass a `Path` to a tensorboard log directory instead of an inference object:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3116ef8",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "fig, axes = plot_summary(\n",
    "    Path(\"path/to/log/dir\"),\n",
    "    tags=[\"training_loss\", \"validation_loss\"],\n",
    "    overlay=True,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
```
Change to the docs index:

```diff
@@ -8,3 +8,4 @@ Visualization
    :maxdepth: 1

    05_conditional_distributions.ipynb
+   24_plot_summary_training_logs.ipynb
```
**Contributor:** I made a couple of comments on the if-else structures below, but more generally, I think we can avoid them entirely.

I suggest: before creating the figure, convert the flat tags list into `subplot_tags`, a list of lists where each inner list is the group of tags for one subplot. In overlay mode, this is `[tags]` (one subplot, all tags). In non-overlay mode it's `[[tag] for tag in tags]` (one subplot per tag). Similarly, precompute `subplot_ylabels`: in overlay mode, join the ylabels; you would still need the helper function for that.

Overall, with this approach you can have a single loop and we avoid the current code duplication in the two branches.
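The restructuring the reviewer describes could look roughly like this (a sketch, not code from the PR; `group_tags` and its exact signature are my assumption):

```python
def group_tags(tags, ylabels, overlay):
    """Return one group of tags (and one ylabel) per subplot, so a single
    plotting loop can handle both overlay and non-overlay modes."""
    if overlay:
        # One subplot containing all tags; join distinct ylabels for its axis.
        subplot_tags = [list(tags)]
        subplot_ylabels = [
            ylabels[0] if len(set(ylabels)) == 1 else " / ".join(ylabels)
        ]
    else:
        # One subplot per tag, each with its own ylabel.
        subplot_tags = [[tag] for tag in tags]
        subplot_ylabels = list(ylabels)
    return subplot_tags, subplot_ylabels
```

With this in place, the plotting code would iterate over `zip(axes, subplot_tags, subplot_ylabels)` once, with no `if overlay:` branching inside.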
Changes to the `plot_summary` signature and docstring:

```diff
@@ -38,6 +38,12 @@ def plot_summary(
     xlabel: str = "epochs_trained",
     ylabel: Optional[List[str]] = None,
     plot_kwargs: Optional[Dict[str, Any]] = None,
+    overlay: bool = False,
+    colors: Optional[List[str]] = None,
+    labels: Optional[List[str]] = None,
+    legend: bool = True,
+    grid: bool = False,
+    title: Optional[Union[str, List[str]]] = None,
 ) -> Tuple[Figure, Axes]:
     """Plots data logged by the TensorBoard tracker of an inference object.

@@ -55,6 +61,13 @@ def plot_summary(
         xlabel: x-axis label describing 'steps' attribute of tensorboards ScalarEvent.
         ylabel: list of alternative ylabels for items in tags. Optional.
         plot_kwargs: will be passed to ax.plot.
+        overlay: if True, plots all tags on a single axes intead of separate subplots.
```

**Contributor:** typo: instead.

```diff
+        colors: list of colors, one per tag. If None, uses matplotlib's default colors.
+        labels: list of legend labels, one per tag. If None, uses tag names.
+        legend: optionally shows a legend when overlay is True or when labels provided.
+        grid: whether to show grid lines.
+        title: title for the figure or individual subplots. A string sets the title
+            for all subplots. A list of strings sets titles per subplot.

         Returns a tuple of Figure and Axes objects.
     """
```
Changes to the plotting body:

```diff
@@ -109,25 +122,67 @@
     plot_options.update(figsize=figsize, fontsize=fontsize)
     if fig is None or axes is None:
+        num_subplots = len(tags)
+        if overlay:
+            num_subplots = 1
         fig, axes = plt.subplots(  # pyright: ignore[reportAssignmentType]
             1,
-            len(tags),
+            num_subplots,
             figsize=plot_options["figsize"],
             **plot_options["subplots"],
         )
         axes = np.atleast_1d(axes)  # type: ignore
     assert fig is not None and axes is not None
```

**Contributor:** it's better practice to do an if-else and raise a suitable error here.
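One possible shape for the reviewer's suggestion (a sketch; `ensure_fig_axes` is a hypothetical helper name, not part of the PR): a bare `assert` is stripped under `python -O` and gives an unhelpful message, whereas an explicit check raises an actionable error.

```python
def ensure_fig_axes(fig, axes):
    """Raise a descriptive error instead of asserting when the
    figure/axes pair could not be created or was passed inconsistently."""
    if fig is None or axes is None:
        raise ValueError(
            "`fig` and `axes` must either both be provided or both be "
            f"created internally; got fig={fig!r}, axes={axes!r}."
        )
    return fig, axes
```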
```diff
+    _labels = labels or tags
+    ylabel = ylabel or tags

-    for i, ax in enumerate(axes):  # type: ignore
-        ax.plot(
-            scalars[tags[i]]["step"], scalars[tags[i]]["value"], **plot_kwargs or {}
-        )
-
-        ax.set_ylabel(ylabel[i], fontsize=fontsize)
+    if overlay:
```

**Contributor:** I think we will have a clash of y-labels in overlay mode; the fix would be merging all of the `ylabel` entries into one.

```diff
+        ax = axes[0]
+        for i, tag in enumerate(tags):
+            color = colors[i] if colors else None
+            ax.plot(
+                scalars[tag]["step"],
+                scalars[tag]["value"],
+                color=color,
+                label=_labels[i],
+                **(plot_kwargs or {}),
+            )
+        ax.set_xlabel(xlabel, fontsize=fontsize)
+        # If overlay, we join all y labels
+        ax.set_ylabel(
+            ylabel[0] if len(set(ylabel)) == 1 else " / ".join(ylabel),
+            fontsize=fontsize,
+        )
+        ax.xaxis.set_tick_params(labelsize=fontsize)
+        ax.yaxis.set_tick_params(labelsize=fontsize)
+        if legend:
+            ax.legend(fontsize=fontsize)
+        if grid:
+            ax.grid(True)
+        if title:
+            t = title if isinstance(title, str) else title[0]
+            ax.set_title(t, fontsize=fontsize)
```

**Contributor** (on lines +159 to +165): can we move these if-else cases outside of the two branches?
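One way to hoist the per-subplot `title` handling out of both branches, as the reviewer asks, is to normalize the argument once before plotting (a sketch; `resolve_titles` is a hypothetical helper name):

```python
def resolve_titles(title, num_subplots):
    """Normalize `title` (None, a single string, or a list of strings)
    to exactly one entry per subplot, so the plotting loop needs no
    isinstance checks."""
    if title is None:
        return [None] * num_subplots
    if isinstance(title, str):
        # A single string applies to every subplot.
        return [title] * num_subplots
    if len(title) != num_subplots:
        raise ValueError("Need exactly one title per subplot.")
    return list(title)
```

The loop body then reduces to `if titles[i] is not None: ax.set_title(titles[i], fontsize=fontsize)` in both the overlay and non-overlay paths.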
```diff
+    else:
+        for i, ax in enumerate(axes):  # type: ignore
+            color = colors[i] if colors else None
+            ax.plot(
+                scalars[tags[i]]["step"],
+                scalars[tags[i]]["value"],
+                color=color,
+                label=_labels[i],
+                **plot_kwargs or {},
+            )
+
+            ax.set_ylabel(ylabel[i], fontsize=fontsize)
+            ax.set_xlabel(xlabel, fontsize=fontsize)
+            ax.xaxis.set_tick_params(labelsize=fontsize)
+            ax.yaxis.set_tick_params(labelsize=fontsize)
+            if grid:
+                ax.grid(True)
+            if title:
+                t = title if isinstance(title, str) else title[i]
+                ax.set_title(t, fontsize=fontsize)

     plt.subplots_adjust(wspace=0.3)
```
New test file:

```python
"""Tests for the enhanced plot_summary function."""

from unittest.mock import patch

import matplotlib.pyplot as plt
import pytest

from sbi.analysis.tensorboard_output import plot_summary


@pytest.fixture
def mock_scalars():
    """Mock scalar data mimicking tensorboard event data."""
    return {
        "training_loss": {
            "step": list(range(100)),
            "value": [1.0 / (i + 1) for i in range(100)],
        },
        "validation_loss": {
            "step": list(range(100)),
            "value": [1.2 / (i + 1) for i in range(100)],
        },
    }


@pytest.fixture
def mock_inference(mock_scalars):
    """Patch event data loading so we don't need real tensorboard logs."""
    with patch(
        "sbi.analysis.tensorboard_output._get_event_data_from_log_dir"
    ) as mock_get:
        mock_get.return_value = {"scalars": mock_scalars}
        from pathlib import Path

        yield Path("/fake/log/dir")


class TestPlotSummaryBackwardCompat:
```

**Contributor:** let's please use plain test functions instead of these test classes, unless we really need the classes, e.g., when we share init or fixtures within classes.
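The reviewer's preferred style would drop the class wrapper and keep fixture injection unchanged. A self-contained sketch (with a hypothetical `plot_summary_stub` standing in for the real `plot_summary` and fixtures, which require sbi and matplotlib):

```python
# Stand-in so the sketch runs on its own; in the real test module this
# would be sbi's plot_summary plus the mock_inference fixture above.
def plot_summary_stub(tags):
    return "fig", [f"ax_{tag}" for tag in tags]


# Reviewer-preferred style: a plain pytest function, no test class.
# pytest collects any module-level function named test_*.
def test_single_tag_plain():
    fig, axes = plot_summary_stub(["training_loss"])
    assert len(axes) == 1
```

Classes remain useful only when tests genuinely share setup state; here each test already gets its state from fixtures, so plain functions lose nothing.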
```python
    """Existing behavior should not change."""

    def test_single_tag(self, mock_inference):
        fig, axes = plot_summary(
            mock_inference,
            tags=["training_loss"],
            disable_tensorboard_prompt=True,
        )
        assert axes.shape == (1,)
        plt.close(fig)

    def test_multiple_tags_separate_subplots(self, mock_inference):
        fig, axes = plot_summary(
            mock_inference,
            tags=["training_loss", "validation_loss"],
            disable_tensorboard_prompt=True,
        )
        assert axes.shape == (2,)
        plt.close(fig)


class TestPlotSummaryOverlay:
    """New overlay functionality."""

    def test_overlay_creates_single_axes(self, mock_inference):
        fig, axes = plot_summary(
            mock_inference,
            tags=["training_loss", "validation_loss"],
            overlay=True,
            disable_tensorboard_prompt=True,
        )
        assert axes.shape == (1,)
        # Should have 2 lines on the single axes
        assert len(axes[0].get_lines()) == 2
        plt.close(fig)

    def test_overlay_with_colors(self, mock_inference):
        fig, axes = plot_summary(
            mock_inference,
            tags=["training_loss", "validation_loss"],
            overlay=True,
            colors=["blue", "orange"],
            disable_tensorboard_prompt=True,
        )
        lines = axes[0].get_lines()
        assert lines[0].get_color() == "blue"
        assert lines[1].get_color() == "orange"
        plt.close(fig)

    def test_overlay_with_labels(self, mock_inference):
        fig, axes = plot_summary(
            mock_inference,
            tags=["training_loss", "validation_loss"],
            overlay=True,
            labels=["Train", "Val"],
            disable_tensorboard_prompt=True,
        )
        legend = axes[0].get_legend()
        assert legend is not None
        texts = [t.get_text() for t in legend.get_texts()]
        assert texts == ["Train", "Val"]
        plt.close(fig)

    def test_overlay_legend_disabled(self, mock_inference):
        fig, axes = plot_summary(
            mock_inference,
            tags=["training_loss", "validation_loss"],
            overlay=True,
            legend=False,
            disable_tensorboard_prompt=True,
        )
        assert axes[0].get_legend() is None
        plt.close(fig)


class TestPlotSummaryGrid:
    def test_grid_enabled(self, mock_inference):
        fig, axes = plot_summary(
            mock_inference,
            tags=["training_loss"],
            grid=True,
            disable_tensorboard_prompt=True,
        )
        assert axes[0].xaxis.get_gridlines()[0].get_visible()
        plt.close(fig)


class TestPlotSummaryTitle:
    def test_single_title(self, mock_inference):
        fig, axes = plot_summary(
            mock_inference,
            tags=["training_loss"],
            title="My Plot",
            disable_tensorboard_prompt=True,
        )
        assert axes[0].get_title() == "My Plot"
        plt.close(fig)
```
**Contributor:** good to see this! A couple of comments:

- Guide 22 already covers `TensorBoard` setup, training, and viewing logs. This guide should focus narrowly on `plot_summary` usage and link to guide 22 for the setup, rather than re-explaining how to pass a Path or inference object.

I suggest the following outline / cells:

- intro: what `plot_summary` does, link to guide 22 for training/tracking setup
- basic `plot_summary` call, show output