Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions adaptive/learner/learnerND.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,7 @@ def plot_slice(self, cut_mapping, n=None):
else:
raise ValueError("Only 1 or 2-dimensional plots can be generated.")

def plot_3D(self, with_triangulation=False):
def plot_3D(self, with_triangulation=False, return_fig=False):
"""Plot the learner's data in 3D using plotly.

Does *not* work with the
Expand All @@ -919,6 +919,9 @@ def plot_3D(self, with_triangulation=False):
----------
with_triangulation : bool, default: False
Add the verticices to the plot.
return_fig : bool, default: False
Return the `plotly.graph_objs.Figure` object instead of showing
the rendered plot (default).

Returns
-------
Expand Down Expand Up @@ -989,7 +992,7 @@ def plot_3D(self, with_triangulation=False):

fig = plotly.graph_objs.Figure(data=plots, layout=layout)

return plotly.offline.iplot(fig)
return fig if return_fig else plotly.offline.iplot(fig)

def _get_iso(self, level=0.0, which="surface"):
if which == "surface":
Expand Down
19 changes: 16 additions & 3 deletions docs/source/docs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,22 @@ on the *Play* :fa:`play` button or move the sliders.
return np.exp(-(x**2 + y**2 + z**2 - 0.75**2)**2/a**4)

learner = adaptive.LearnerND(sphere, bounds=[(-1, 1), (-1, 1), (-1, 1)])
adaptive.runner.simple(learner, lambda l: l.npoints == 3000)

learner.plot_3D()
adaptive.runner.simple(learner, lambda l: l.npoints == 5000)

fig = learner.plot_3D(return_fig=True)

# Remove a slice from the plot to show the inside of the sphere
scatter = fig.data[0]
coords_col = [
(x, y, z, color)
for x, y, z, color in zip(
scatter["x"], scatter["y"], scatter["z"], scatter.marker["color"]
)
if not (x > 0 and y > 0)
]
scatter["x"], scatter["y"], scatter["z"], scatter.marker["color"] = zip(*coords_col)

fig

see more in the :ref:`Tutorial Adaptive`.

Expand Down