import param
import itertools
import os
from spb.defaults import cfg
from spb.doc_utils.ipython import modify_parameterized_doc
from spb.backends.base_backend import Plot
from spb.backends.utils import tick_formatter_multiples_of
from spb.backends.plotly.renderers import (
Line2DRenderer, Line3DRenderer, Vector2DRenderer, Vector3DRenderer,
ComplexRenderer, ContourRenderer, SurfaceRenderer, Implicit3DRenderer,
GeometryRenderer, GenericRenderer, HVLineRenderer, Arrow2DRenderer
)
from spb.series import (
LineOver1DRangeSeries, List2DSeries, Parametric2DLineSeries,
ColoredLineOver1DRangeSeries, AbsArgLineSeries, ComplexPointSeries,
Parametric3DLineSeries, ComplexParametric3DLineSeries,
List3DSeries, Vector2DSeries, Vector3DSeries, SliceVector3DSeries,
RiemannSphereSeries, Implicit3DSeries,
ComplexDomainColoringSeries, ComplexSurfaceSeries,
ContourSeries, SurfaceOver2DRangeSeries, ParametricSurfaceSeries,
PlaneSeries, Geometry2DSeries, Geometry3DSeries, GenericDataSeries,
HVLineSeries, Arrow2DSeries, HLineSeries, VLineSeries
)
from sympy.external import import_module
import warnings
@modify_parameterized_doc()
class PlotlyBackend(Plot):
    """
    A backend for plotting SymPy's symbolic expressions using Plotly.

    Notes
    =====

    A few bugs related to Plotly might prevent the correct visualization:

    * with 2D domain coloring, the vertical axis is reversed, with negative
      values on the top and positive values on the bottom.
    * with 3D complex plots: when hovering a point, the tooltip will display
      wrong information for the argument and the phase.
      https://github.com/plotly/plotly.js/issues/5003
      Hopefully, this bug will be fixed upstream.

    See also
    ========

    Plot, MatplotlibBackend, BokehBackend, K3DBackend
    """

    wireframe_color = "#000000"
    scattergl_threshold = 2000

    # color bar spacing
    _cbs = 0.15
    # color bar scale down factor
    _cbsdf = 0.75

    # data series type -> renderer class used to draw it
    renderers_map = {
        LineOver1DRangeSeries: Line2DRenderer,
        List2DSeries: Line2DRenderer,
        Parametric2DLineSeries: Line2DRenderer,
        ColoredLineOver1DRangeSeries: Line2DRenderer,
        AbsArgLineSeries: Line2DRenderer,
        ComplexPointSeries: Line2DRenderer,
        Parametric3DLineSeries: Line3DRenderer,
        ComplexParametric3DLineSeries: Line3DRenderer,
        List3DSeries: Line3DRenderer,
        Vector2DSeries: Vector2DRenderer,
        Vector3DSeries: Vector3DRenderer,
        SliceVector3DSeries: Vector3DRenderer,
        Implicit3DSeries: Implicit3DRenderer,
        ComplexDomainColoringSeries: ComplexRenderer,
        ComplexSurfaceSeries: ComplexRenderer,
        RiemannSphereSeries: ComplexRenderer,
        ContourSeries: ContourRenderer,
        SurfaceOver2DRangeSeries: SurfaceRenderer,
        ParametricSurfaceSeries: SurfaceRenderer,
        PlaneSeries: SurfaceRenderer,
        Geometry2DSeries: GeometryRenderer,
        Geometry3DSeries: GeometryRenderer,
        GenericDataSeries: GenericRenderer,
        HVLineSeries: HVLineRenderer,
        HLineSeries: HVLineRenderer,
        VLineSeries: HVLineRenderer,
        Arrow2DSeries: Arrow2DRenderer
    }

    quivers_colors = param.ClassSelector(
        default=[], class_=(list, tuple),
        doc="\nList of colors for rendering quivers.")

    def __init__(self, *series, **kwargs):
        """Create the backend.

        Plotly-specific defaults are filled into ``kwargs`` (only for keys
        the caller didn't provide), then ``series`` and ``kwargs`` are
        forwarded to :class:`Plot`.
        """
        self.np = import_module('numpy')
        self.plotly = import_module(
            'plotly',
            import_kwargs={'fromlist': ['graph_objects', 'figure_factory']},
            warn_not_installed=True,
            min_module_version='5.0.0')
        self.go = self.plotly.graph_objects
        self.create_quiver = self.plotly.figure_factory.create_quiver
        self.create_streamline = self.plotly.figure_factory.create_streamline

        kwargs["_library"] = "plotly"
        # The following colors correspond to the discrete color map
        # px.colors.qualitative.Plotly.
        kwargs.setdefault("colorloop", [
            "#636EFA", "#EF553B", "#00CC96", "#AB63FA", "#FFA15A",
            "#19D3F3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52"])
        kwargs.setdefault("colormaps", [
            "aggrnyl", "plotly3", "reds_r", "ice", "inferno",
            "deep_r", "turbid_r", "gnbu_r", "geyser_r", "oranges_r"])
        kwargs.setdefault("cyclic_colormaps", [
            "phase", "twilight", "hsv", "icefire"])
        # TODO: here I selected black and white, but they are not visible
        # with dark or light theme respectively... Need a better selection
        # of colors. Although, they are placed in the middle of the loop,
        # so they are unlikely going to be used.
        kwargs.setdefault("quivers_colors", [
            "magenta", "crimson", "darkorange", "dodgerblue", "wheat",
            "slategrey", "white", "black", "darkred", "indigo"])
        # remaining defaults come from the "plotly" section of the config
        kwargs.setdefault("update_event", cfg["plotly"]["update_event"])
        kwargs.setdefault("use_latex", cfg["plotly"]["use_latex"])
        kwargs.setdefault("grid", cfg["plotly"]["grid"])
        kwargs.setdefault("minor_grid", cfg["plotly"]["show_minor_grid"])
        kwargs.setdefault("theme", cfg["plotly"]["theme"])

        # _init_cyclers needs to know if an existing figure was provided
        self._use_existing_figure = "fig" in kwargs
        super().__init__(*series, **kwargs)

        if self.update_event and any(
            isinstance(s, Vector2DSeries) for s in series
        ):
            warnings.warn(
                "You are trying to use `update_event=True` with a 2D quiver "
                "plot. This is likely going to cause a render-loop. You might "
                "need to interrupt the kernel."
            )

        self._init_cyclers()

        if not self._use_existing_figure:
            # A FigureWidget is created for ipywidgets-driven interactivity
            # and when update_event is set (the layout ``on_change``
            # callback registered at the end of this method needs it).
            if (self._imodule == "ipywidgets") or self.update_event:
                self._fig = self.go.FigureWidget()
            else:
                self._fig = self.go.Figure()

        # NOTE: Plotly 3D currently doesn't support latex labels
        # https://github.com/plotly/plotly.js/issues/608
        self._set_labels()
        self._set_title()

        n_2D_lines = sum(1 for s in self._series if s.is_2Dline)
        if (
            (n_2D_lines > 10)
            and (not type(self).colorloop)
            and ("process_piecewise" not in kwargs)
        ):
            # add colors if needed
            # this corresponds to px.colors.qualitative.Light24
            self.colorloop = [
                "#FD3216", "#00FE35", "#6A76FC", "#FED4C4", "#FE00CE",
                "#0DF9FF", "#F6F926", "#FF9616", "#479B55", "#EEA6FB",
                "#DC587D", "#D626FF", "#6E899C", "#00B5F7", "#B68E00",
                "#C9FBE5", "#FF0092", "#22FFA7", "#E3EE9E", "#86CE00",
                "#BC7196", "#7E7DCD", "#FC6955", "#E48F72"
            ]

        self._colorbar_counter = 0
        # color bars are scaled down only when a legend is shown AND the
        # plot mixes colormapped and solid-colored series
        self._scale_down_colorbar = (
            self.legend and
            any(s.use_cm for s in self.series) and
            any((not s.use_cm) for s in self.series)
        )
        self._show_2D_vectors = any(s.is_2Dvector for s in self.series)
        self._create_renderers()
        self._n_annotations = 0

        if self.update_event:
            # re-evaluate the series whenever the user pans/zooms
            self._fig.layout.on_change(
                lambda obj, xrange, yrange: self._update_axis_limits(
                    xrange, yrange),
                ('xaxis', 'range'), ('yaxis', 'range'))

    def _update_axis_limits(self, *limits):
        """Update the ranges of data series in order to implement pan/zoom
        update events.

        Parameters
        ==========

        limits : iterable
            Tuples of (min, max) values.
        """
        params = self._update_series_ranges(*limits)
        self.update_interactive(params)

    def _draw_if_needed(self):
        """Call :meth:`draw` if the first renderer has no handles yet, i.e.
        if the numerical data has not been added to the figure."""
        if len(self.renderers) > 0 and len(self.renderers[0].handles) == 0:
            self.draw()

    @property
    def fig(self):
        """Return the Plotly figure, drawing the data first if needed
        (e.g. when the backend was created without showing it)."""
        self._draw_if_needed()
        return self._fig

    def draw(self):
        """ Loop over data renderers, generates numerical data and add it to
        the figure. Note that this method doesn't show the plot.
        """
        self._process_renderers()
        self._update_layout()
        self._execute_hooks()

    process_series = draw

    def _set_piecewise_color(self, s, color):
        """Set the color of the given series.

        Nothing is changed if the user already set ``line_color`` in the
        series' ``rendering_kw``. Otherwise ``line_color`` is set and, for
        non-filled series, a marker filled with ``#E5ECF6`` and outlined
        with ``color`` is set too.
        """
        if "line_color" not in s.rendering_kw:
            # only set the color if the user didn't do that already
            s.rendering_kw["line_color"] = color
            if not s.is_filled:
                s.rendering_kw["marker"] = dict(
                    color="#E5ECF6",
                    line=dict(color=color))

    @staticmethod
    def _do_sum_kwargs(p1, p2):
        """Return a copy of ``p1``'s keyword arguments; ``p2`` is ignored.

        NOTE(review): presumably used by the base class when two plots are
        combined -- confirm against ``Plot``.
        """
        return p1._copy_kwargs()

    def _init_cyclers(self):
        """(Re)initialize the color cyclers and the quiver-color cycler.

        With a user-provided figure, the number of ``Scatter`` and
        ``Surface`` traces already on it is passed to the base class as the
        first start index (presumably of the line-color cycler), so that
        new traces continue the color sequence.
        """
        start_index_cl = None
        if self._use_existing_figure:
            # attempt to determine how many lines or surfaces are plotted
            # on the user-provided figure
            # assume user plotted 3d surfaces using solid colors
            count_meshes = sum(
                isinstance(c, self.go.Surface) for c in self._fig.data)
            count_lines = sum(
                isinstance(c, self.go.Scatter) for c in self._fig.data)
            start_index_cl = count_lines + count_meshes
        super()._init_cyclers(start_index_cl, 0)

        # a non-empty class-level ``quivers_colors`` takes precedence over
        # the instance value
        tb = type(self)
        quivers_colors = (
            self.quivers_colors if not tb.quivers_colors
            else tb.quivers_colors
        )
        self._qc = itertools.cycle(quivers_colors)

    def _create_colorbar(self, label, sc=False):
        """Return the keyword arguments of a new color bar.

        Each call increments an internal counter and places the new color
        bar ``_cbs`` further to the right of the previous one.

        Parameters
        ==========

        label : str
            Name to display beside the color bar.
        sc : boolean
            Scale down the color bar to make room for the legend (applied
            only when ``legend`` is truthy or None).
            Defaults to False.
        """
        k = self._colorbar_counter
        self._colorbar_counter += 1
        return dict(
            x=1 + self._cbs * k,
            title=dict(
                text=label,
                side="right"
            ),
            # scale down the color bar to make room for legend
            len=(
                self._cbsdf if (sc and (self.legend or (self.legend is None)))
                else 1
            ),
            yanchor="bottom",
            y=0,
        )

    def _solid_colorscale(self, s):
        """Return a two-stop colorscale made of a single color, to be used
        when ``s.use_cm=False``.

        The color is ``s.line_color`` or, if that is None, the next color
        of the color loop.
        """
        col = s.line_color
        if col is None:
            col = next(self._cl)
        return [[0, col], [1, col]]

    def _process_renderers(self):
        """Reset the cyclers, clear the figure (unless it was provided by
        the user) and let every renderer draw its series."""
        self._init_cyclers()
        if not self._use_existing_figure:
            # If this instance visualizes only symbolic expressions,
            # I want to clear axes so that each time `.show()` is called there
            # won't be repeated handles.
            # On the other hand, if the current axes is provided by the user,
            # we don't want to erase its content.
            self._fig.data = []

        for r, s in zip(self.renderers, self.series):
            self._check_supported_series(r, s)
            r.draw()

    def update_interactive(self, params):
        """
        Implement the logic to update the data generated by
        interactive-widget plots.

        Parameters
        ==========

        params : dict
            Map parameter-symbols to numeric values.
        """
        # Because InteractivePlot doesn't call the show method, the following
        # line of code will add the numerical data (if not already present).
        self._draw_if_needed()

        if self._imodule == "ipywidgets":
            # group all trace updates into a single figure update
            with self._fig.batch_update():
                self._update_interactive_helper(params)
        else:
            self._update_interactive_helper(params)

        self._set_axes_texts()
        self._execute_hooks()

    def _update_interactive_helper(self, params):
        """Update every renderer whose series is interactive or carries
        ``_interactive_app_controls``."""
        for r in self.renderers:
            if (
                r.series.is_interactive
                or hasattr(r.series, "_interactive_app_controls")
            ):
                r.update(params)

    def _get_data_limits_for_custom_tickers(self):
        """Return ``(x_min, x_max, y_min, y_max)`` as floats.

        The x-limits span the first range of line and surface series, the
        y-limits span the second range of surface series, after substituting
        the series parameters. Each limit defaults to 0 when no series
        contributes to it.
        """
        x_min, x_max = [], []
        y_min, y_max = [], []
        for s in self.series:
            if isinstance(s, (LineOver1DRangeSeries, SurfaceOver2DRangeSeries)):
                x_min.append(s.ranges[0][1].subs(s.params))
                x_max.append(s.ranges[0][2].subs(s.params))
            if isinstance(s, SurfaceOver2DRangeSeries):
                y_min.append(s.ranges[1][1].subs(s.params))
                y_max.append(s.ranges[1][2].subs(s.params))
        return (
            float(min(x_min, default=0)), float(max(x_max, default=0)),
            float(min(y_min, default=0)), float(max(y_max, default=0)),
        )

    def _update_layout(self):
        """Apply title, labels, limits, scales, grids, custom ticks, legend,
        aspect ratio and camera to the 2D, polar and 3D-scene layouts."""
        title, xlabel, ylabel, zlabel = self._get_title_and_labels()
        show_major_grid = bool(self.grid)
        show_minor_grid = bool(self.minor_grid)
        # a dict value for ``grid``/``minor_grid`` holds extra keyword
        # arguments for the corresponding grid
        major_grid_line_kw = self.grid if isinstance(self.grid, dict) else {}
        minor_grid_line_kw = (
            self.minor_grid if isinstance(self.minor_grid, dict) else {})
        minor_grid_line_kw_x = minor_grid_line_kw.copy()
        minor_grid_line_kw_y = minor_grid_line_kw.copy()

        # if necessary, apply custom tick formatting
        x_tickvals, x_ticktext = None, None
        y_tickvals, y_ticktext = None, None
        polar_angular_dtick = 30
        is_formatter = lambda t: isinstance(t, tick_formatter_multiples_of)
        if any(is_formatter(t) for t in [
            self.x_ticks_formatter, self.y_ticks_formatter]
        ):
            x_min, x_max, y_min, y_max = (
                self._get_data_limits_for_custom_tickers())
            if is_formatter(self.x_ticks_formatter):
                # explicit tick values need a non-degenerate data range
                if not self.np.isclose(x_min, x_max):
                    x_tickvals, x_ticktext = self.x_ticks_formatter.PB_ticks(
                        x_min, x_max)
                q = self.x_ticks_formatter.quantity
                n = self.x_ticks_formatter.n
                n_minor = self.x_ticks_formatter.n_minor
                minor_grid_line_kw_x["dtick"] = (q / n) / (n_minor + 1)
                polar_angular_dtick = q / n
            if is_formatter(self.y_ticks_formatter):
                if not self.np.isclose(y_min, y_max):
                    y_tickvals, y_ticktext = self.y_ticks_formatter.PB_ticks(
                        y_min, y_max)
                q = self.y_ticks_formatter.quantity
                n = self.y_ticks_formatter.n
                n_minor = self.y_ticks_formatter.n_minor
                minor_grid_line_kw_y["dtick"] = (q / n) / (n_minor + 1)

        self._fig.update_layout(
            template=self.theme,
            width=None if not self.size else self.size[0],
            height=None if not self.size else self.size[1],
            title=r"<b>%s</b>" % ("" if not title else title),
            title_x=0.5,
            xaxis=dict(
                title="" if not xlabel else xlabel,
                range=None if not self.xlim else self.xlim,
                type=self.xscale,
                showgrid=show_major_grid,  # thin lines in the background
                zeroline=show_major_grid,  # thick line at x=0
                constrain="domain",
                visible=self.axis,
                autorange=None if not self.invert_x_axis else "reversed",
                tickvals=x_tickvals,
                ticktext=x_ticktext,
                **major_grid_line_kw
            ),
            yaxis=dict(
                title="" if not ylabel else ylabel,
                range=None if not self.ylim else self.ylim,
                type=self.yscale,
                showgrid=show_major_grid,  # thin lines in the background
                zeroline=show_major_grid,  # thick line at x=0
                scaleanchor="x" if self.aspect == "equal" else None,
                visible=self.axis,
                tickvals=y_tickvals,
                ticktext=y_ticktext,
                **major_grid_line_kw
            ),
            polar=dict(
                angularaxis=dict(
                    direction='counterclockwise',
                    rotation=0,
                    thetaunit=(
                        "radians" if is_formatter(self.x_ticks_formatter)
                        else None
                    ),
                    dtick=polar_angular_dtick,
                ),
                radialaxis=dict(
                    range=None if not self.ylim else self.ylim
                ),
                sector=None if not self.xlim else self.xlim
            ),
            margin=dict(
                t=50,
                l=0,
                b=0,
                r=40
            ),
            showlegend=bool(self.legend),
            scene=dict(
                xaxis=dict(
                    title="" if not xlabel else xlabel,
                    range=None if not self.xlim else self.xlim,
                    type=self.xscale,
                    showgrid=show_major_grid,  # thin lines in the background
                    zeroline=show_major_grid,  # thick line at x=0
                    visible=show_major_grid,  # numbers below,
                    tickvals=x_tickvals,
                    ticktext=x_ticktext,
                ),
                yaxis=dict(
                    title="" if not ylabel else ylabel,
                    range=None if not self.ylim else self.ylim,
                    type=self.yscale,
                    showgrid=show_major_grid,  # thin lines in the background
                    zeroline=show_major_grid,  # thick line at x=0
                    visible=show_major_grid,  # numbers below,
                    tickvals=y_tickvals,
                    ticktext=y_ticktext,
                ),
                zaxis=dict(
                    title="" if not zlabel else zlabel,
                    range=None if not self.zlim else self.zlim,
                    type=self.zscale,
                    showgrid=show_major_grid,  # thin lines in the background
                    zeroline=show_major_grid,  # thick line at x=0
                    visible=show_major_grid,  # numbers below
                ),
                # a dict aspect is an explicit aspect ratio; "equal" is
                # mapped to Plotly's "auto" aspect mode
                aspectmode=(
                    "manual" if isinstance(self.aspect, dict)
                    else (self.aspect if self.aspect != "equal" else "auto")
                ),
                aspectratio=(
                    self.aspect if isinstance(self.aspect, dict) else None
                ),
                camera=self.camera
            ),
        )
        self._fig.update_xaxes(minor=dict(
            showgrid=show_minor_grid, **minor_grid_line_kw_x))
        self._fig.update_yaxes(minor=dict(
            showgrid=show_minor_grid, **minor_grid_line_kw_y))

    def _set_axes_texts(self):
        """Refresh the title and the 2D/3D axis labels of the figure."""
        title, xlabel, ylabel, zlabel = self._get_title_and_labels()
        self._fig.update_layout(
            title=r"<b>%s</b>" % ("" if not title else title),
            xaxis=dict(
                title="" if not xlabel else xlabel,
            ),
            yaxis=dict(
                title="" if not ylabel else ylabel,
            ),
            scene=dict(
                xaxis=dict(
                    title="" if not xlabel else xlabel,
                ),
                yaxis=dict(
                    title="" if not ylabel else ylabel,
                ),
                zaxis=dict(
                    title="" if not zlabel else zlabel,
                ),
            ),
        )

    def show(self):
        """
        Visualize the plot on the screen.
        """
        self._draw_if_needed()
        self._fig.show()

    def save(self, path, **kwargs):
        """
        Export the plot to a static picture or to an interactive html file.

        Refer to [#fn11]_ and [#fn12]_ to visualize all the available keyword
        arguments.

        Notes
        =====

        In order to export static pictures, the user also needs to install
        the packages listed in [#fn11]_.

        References
        ==========

        .. [#fn11] https://plotly.com/python/static-image-export/
        .. [#fn12] https://plotly.com/python/interactive-html-export/
        """
        self._draw_if_needed()

        # ".htm"/".html" (case-insensitive) -> interactive html,
        # anything else -> static image
        ext = os.path.splitext(path)[1]
        if ext.lower() in (".htm", ".html"):
            self._fig.write_html(path, **kwargs)
        else:
            self._fig.write_image(path, **kwargs)
# Short alias for PlotlyBackend.
PB = PlotlyBackend