from itertools import cycle, islice
from spb.series import (
BaseSeries, LineOver1DRangeSeries, ComplexSurfaceBaseSeries
)
from spb.backends.utils import convert_colormap
from sympy import Symbol
from sympy.utilities.iterables import is_sequence
from sympy.external import import_module
[docs]
class Plot:
"""Base class for all backends. A backend represents the plotting library,
which implements the necessary functionalities in order to use SymPy
plotting functions.
How the plotting module works:
1. The user creates the symbolic expressions and calls one of the plotting
functions.
2. The plotting functions generate a list of instances of ``BaseSeries``,
containing the necessary information to generate the appropriate
numerical data and create the proper visualization
(eg the expression, ranges, series name, ...).
3. The plotting functions instantiate the ``Plot`` class, which stores the
list of series and the main attributes of the plot (eg axis labels,
title, etc.). Among the keyword arguments, there must be ``backend=``,
where a subclass of ``Plot`` can be specified in order to use a
particular plotting library.
4. Each data series will be associated to a corresponding renderer, which
receives the numerical data and add it to the figure. A renderer is also
responsible to keep objects up-to-date when interactive-widgets plots
are used. The figure is populated with numerical data when the
``show()`` method or the ``fig`` attribute are called.
The backend should check if it supports the data series that it's given.
Please, explore ``MatplotlibBackend`` source code to understand how a
backend should be coded.
Also note that setting attributes to plot objects or to data series after
they have been instantiated is strongly unrecommended, as it is not
guaranteed that the figure will be updated.
Notes
=====
In order to be used by SymPy plotting functions, a backend must implement
the following methods and attributes:
* ``show(self)``: used to loop over the data series, generate the
numerical data, plot it and set the axis labels, title, ...
* ``save(self, path, **kwargs)``: used to save the current plot to the
specified file path.
* ``self._fig``: an instance attribute to store the backend-specific plot
object, which can be retrieved with the ``Plot.fig`` attribute. This
object can then be used to further customize the resulting plot, using
backend-specific commands.
* ``update_interactive(self, params)``: this method receives a dictionary
mapping parameters to their values from the ``iplot`` function, which
are going to be used to update the objects of the figure.
Parameters
==========
title : str, optional
Set the title of the plot. Default to an empty string.
xlabel, ylabel, zlabel : str, optional
Set the labels of the plot. Default to an empty string.
legend : bool, optional
Show or hide the legend. By default, the backend will automatically
set it to True if multiple data series are shown.
xscale, yscale, zscale : str, optional
Discretization strategy for the provided domain along the specified
direction. Can be either `'linear'` or `'log'`. Default to
`'linear'`. If the backend supports it, the specified direction will
use the user-provided scale. By default, all backends uses linear
scales for both axis. None of the backends support logarithmic scale
for 3D plots.
grid : bool, optional
Show/Hide the grid. The default value depends on the backend.
xlim, ylim, zlim : (float, float), optional
Focus the plot to the specified range. The tuple must be in the form
`(min_val, max_val)`.
aspect : (float, float) or str, optional
Set the aspect ratio of the plot. It only works for 2D plots.
The values depends on the backend. Read the interested backend's
documentation to find out the possible values.
backend : Plot
The subclass to be used to generate the plot.
size : (float, float) or None, optional
Set the size of the plot, `(width, height)`. Default to None.
Examples
========
Combine multiple plots together to create a new plot:
.. plot::
:context: reset
:format: doctest
:include-source: True
>>> from sympy import symbols, sin, cos, log, S
>>> from spb import plot, plot3d
>>> x, y = symbols("x, y")
>>> p1 = plot(sin(x), cos(x), show=False)
>>> p2 = plot(sin(x) * cos(x), log(x), show=False)
>>> p3 = p1 + p2
>>> p3.show()
Use the index notation to access the data series. Let's generate the
numerical data associated to the first series:
.. plot::
:context: close-figs
:format: doctest
:include-source: True
>>> p1 = plot(sin(x), cos(x), show=False)
>>> xx, yy = p1[0].get_data()
Create a new backend with a custom colorloop:
.. plot::
:context: close-figs
:format: doctest
:include-source: True
>>> from spb.backends.matplotlib import MB
>>> class MBchild(MB):
... colorloop = ["r", "g", "b"]
>>> plot(sin(x) / 3, sin(x) * S(2) / 3, sin(x), backend=MBchild)
Plot object containing:
[0]: cartesian line: sin(x)/3 for x over (-10.0, 10.0)
[1]: cartesian line: 2*sin(x)/3 for x over (-10.0, 10.0)
[2]: cartesian line: sin(x) for x over (-10.0, 10.0)
Create a new backend with custom color maps for 3D plots. Note that
it's possible to use Plotly/Colorcet/Matplotlib colormaps interchangeably.
.. plot::
:context: close-figs
:format: doctest
:include-source: True
>>> from spb.backends.matplotlib import MB
>>> import colorcet as cc
>>> class MBchild(MB):
... colormaps = ["plotly3", cc.bmy]
>>> plot3d(
... (cos(x**2 + y**2), (x, -2, 0), (y, -2, 2)),
... (cos(x**2 + y**2), (x, 0, 2), (y, -2, 2)),
... backend=MBchild, n1=25, n2=50, use_cm=True)
Plot object containing:
[0]: cartesian surface: cos(x**2 + y**2) for x over (-2.0, 0.0) and y over (-2.0, 2.0)
[1]: cartesian surface: cos(x**2 + y**2) for x over (0.0, 2.0) and y over (-2.0, 2.0)
See also
========
MatplotlibBackend, PlotlyBackend, BokehBackend, K3DBackend
"""
# set the name of the plotting library being used. This is required in
# order to convert any colormap to the specified plotting library.
_library = ""
colorloop = []
"""List of colors to be used in line plots or solid color surfaces."""
colormaps = []
"""List of color maps to render surfaces."""
cyclic_colormaps = []
"""List of cyclic color maps to render complex series (the phase/argument
ranges over [-pi, pi]).
"""
_allowed_keys = [
"aspect", "axis", "axis_center", "backend",
"detect_poles", "grid", "legend", "show", "size", "title", "use_latex",
"xlabel", "ylabel", "zlabel", "xlim", "ylim", "zlim", "show_axis",
"xscale", "yscale", "zscale", "process_piecewise", "polar_axis",
"imodule", "update_event"
]
"""contains a list of public keyword arguments supported by the series.
It will be used to validate the user-provided keyword arguments.
"""
def _set_labels(self, wrapper="$%s$"):
"""Set the correct axis labels depending on wheter the backend support
Latex rendering.
Parameters
==========
use_latex : boolean
Wheter the backend is customized to show latex labels.
wrapper : str
Wrapper string for the latex labels. Default to '$%s$'.
"""
if not self._use_latex:
wrapper = "%s"
if callable(self.xlabel):
self.xlabel = wrapper % self.xlabel(self._use_latex)
if callable(self.ylabel):
self.ylabel = wrapper % self.ylabel(self._use_latex)
if callable(self.zlabel):
self.zlabel = wrapper % self.zlabel(self._use_latex)
def _set_title(self, wrapper="$%s$"):
"""Set the correct title depending on wheter the backend support
Latex rendering.
Parameters
==========
use_latex : boolean
Wheter the backend is customized to show latex labels.
wrapper : str
Wrapper string for the latex labels. Default to '$%s$'.
"""
if not self._use_latex:
wrapper = "%s"
if callable(self.title):
self.title = self.title(wrapper, self._use_latex)
def _create_parametric_text(self, t, params):
"""Given a tuple of the form `(str, symbol1, symbol2, ...)`
where `str` is a formatted string, read the values from the parameters
`symbol1, symbol2, ...`, pass them to the formatted string to create
an updated text.
"""
if isinstance(t, (tuple, list)):
t_symbols = set().union(*[e.free_symbols for e in t[1:]])
remaining_symbols = set(t_symbols).difference(params.keys())
# TODO: is it worth creating a `ptext` class so that this check
# is only run once at the beginning, and not at every update?
if len(remaining_symbols) > 0:
raise ValueError(
"This parametric text contains symbols that are "
"not part of the `params` dictionary:\n"
f"{t}.\nWhat are these symbols? {remaining_symbols}"
)
values = [
params[s] if isinstance(s, Symbol) else s.subs(params)
for s in t[1:]
]
return t[0].format(*values)
return t
def _get_title_and_labels(self):
"""Returns the appropriate text to be shown on the title and
axis labels.
"""
title = self.title
xlabel = self.xlabel
ylabel = self.ylabel
zlabel = self.zlabel
params = None
if len(self.series) > 0:
# assuming all data series received the same parameters
params = self.series[0].params
if params:
title = self._create_parametric_text(title, params)
xlabel = self._create_parametric_text(xlabel, params)
ylabel = self._create_parametric_text(ylabel, params)
zlabel = self._create_parametric_text(zlabel, params)
return title, xlabel, ylabel, zlabel
def __init__(self, *args, **kwargs):
# the merge function is used by all backends
self._mergedeep = import_module('mergedeep')
self.merge = self._mergedeep.merge
# Options for the graph as a whole.
# The possible values for each option are described in the docstring
# of Plot. They are based purely on convention, no checking is done.
self.title = kwargs.get("title", None)
self.xlabel = kwargs.get("xlabel", None)
self.ylabel = kwargs.get("ylabel", None)
self.zlabel = kwargs.get("zlabel", None)
self.aspect = kwargs.get("aspect", kwargs.get("aspect_ratio", "auto"))
self.axis_center = kwargs.get("axis_center", None)
self.camera = kwargs.get("camera", None)
self.grid = kwargs.get("grid", True)
self.xscale = kwargs.get("xscale", None)
self.yscale = kwargs.get("yscale", None)
self.zscale = kwargs.get("zscale", None)
self.polar_axis = kwargs.get("polar_axis", None)
# NOTE: it would be nice to have detect_poles=True by default.
# However, the correct detection also depends on the number of points
# and the value of `eps`. Getting the detection right is likely to
# be a trial-by-error procedure. Hence, keep this parameter to False.
self.detect_poles = kwargs.get("detect_poles", False)
# NOTE: matplotlib is not designed to be interactive, therefore it
# needs a way to detect where its figure is going to be displayed.
# For regular plots, plt.figure can be used. For interactive-parametric
# plots with holoviz panel, matplotlib.figure.Figure must be used.
self.is_iplot = kwargs.get("is_iplot", False)
# backend might need to create different types of figure depending on
# the interactive module being used
self.imodule = kwargs.get("imodule", None)
# Contains the data objects to be plotted. The backend should be smart
# enough to iterate over this list.
grid_series = [s for s in args if s.is_grid]
non_grid_series = [s for s in args if not s.is_grid]
# grid series must be the last to be rendered: they need to know
# the extension of the area to cover with grids, which can be obtained
# after plotting all other series.
self._series = non_grid_series + grid_series
if "process_piecewise" in kwargs.keys():
# if the backend was called by plot_piecewise, each piecewise
# function must use the same color. Here we preprocess each
# series to add the correct color
series = []
for idx, _series in kwargs["process_piecewise"].items():
color = next(self._cl)
for s in _series:
self._set_piecewise_color(s, color)
series.extend(_series)
self._series = series
# Automatic legend: if more than 1 data series has been provided
# and the user has not set legend=False, then show the legend for
# better clarity.
self.legend = kwargs.get("legend", None)
if self.legend is None:
series_to_show = [
s for s in self._series if (
s.show_in_legend and
(not s.use_cm) and
(not s.is_grid)
)
]
if len(series_to_show) > 1:
# don't show the legend if `plot_piecewise` created this
# backend
if not ("process_piecewise" in kwargs.keys()):
self.legend = True
# allow to invert x-axis if the range is given as (symbol, max, min)
# instead of (symbol, min, max).
# just check the first series.
self._invert_x_axis = False
if (
(len(self._series) > 0) and
isinstance(self._series[0], LineOver1DRangeSeries)
):
# elements of parametric ranges can't be compared because they
# are likely going to be symbolic expressions
if not self._series[0]._interactive_ranges:
r = self._series[0].ranges[0]
# Ranges can be real or complex. Cast them to complex and
# look at the real part.
if complex(r[1]).real > complex(r[2]).real:
self._invert_x_axis = True
# Objects used to render/display the plots, which depends on the
# plotting library.
self._fig = None
is_real = lambda lim: all(getattr(i, "is_real", True) for i in lim)
is_finite = lambda lim: all(getattr(i, "is_finite", True) for i in lim)
# reduce code repetition
def check_and_set(t_name, t):
if t:
if not is_real(t):
raise ValueError(
f"All numbers from {t_name}={t} must be real")
if not is_finite(t):
raise ValueError(
f"All numbers from {t_name}={t} must be finite")
setattr(self, t_name, (float(t[0]), float(t[1])))
self.xlim = None
check_and_set("xlim", kwargs.get("xlim", None))
self.ylim = None
check_and_set("ylim", kwargs.get("ylim", None))
self.zlim = None
check_and_set("zlim", kwargs.get("zlim", None))
self.size = None
check_and_set("size", kwargs.get("size", None))
self.axis = kwargs.get("show_axis", kwargs.get("axis", True))
self._update_event = kwargs.get("update_event", False)
def _copy_kwargs(self):
"""Copy the values of the plot attributes into a dictionary which will
be later used to create a new `Plot` object having the same attributes.
"""
return dict(
title=self.title,
xlabel=self.xlabel,
ylabel=self.ylabel,
zlabel=self.zlabel,
aspect=self.aspect,
axis_center=self.axis_center,
grid=self.grid,
xscale=self.xscale,
yscale=self.yscale,
zscale=self.zscale,
detect_poles=self.detect_poles,
legend=self.legend,
xlim=self.xlim,
ylim=self.ylim,
zlim=self.zlim,
size=self.size,
is_iplot=self.is_iplot,
use_latex=self._use_latex,
camera=self.camera,
polar_axis=self.polar_axis,
axis=self.axis
)
def _init_cyclers(self, start_index_cl=None, start_index_cm=None):
"""Create infinite loop iterators over the provided color maps."""
tb = type(self)
colorloop = self.colorloop if not tb.colorloop else tb.colorloop
colormaps = self.colormaps if not tb.colormaps else tb.colormaps
cyclic_colormaps = self.cyclic_colormaps
if tb.cyclic_colormaps:
cyclic_colormaps = tb.cyclic_colormaps
if not isinstance(colorloop, (list, tuple)):
# assume it is a matplotlib's ListedColormap
self.colorloop = colorloop.colors
self._cl = cycle(colorloop)
colormaps = [convert_colormap(cm, self._library) for cm in colormaps]
self._cm = cycle(colormaps)
cyclic_colormaps = [
convert_colormap(cm, self._library) for cm in cyclic_colormaps
]
self._cyccm = cycle(cyclic_colormaps)
if start_index_cl is not None:
self._cl = islice(self._cl, start_index_cl, None)
if start_index_cm is not None:
self._cm = islice(self._cm, start_index_cm, None)
def _create_renderers(self):
"""Connect data series to appropriate renderers."""
self._renderers = []
for s in self.series:
t = type(s)
if t in self.renderers_map.keys():
self._renderers.append(self.renderers_map[t](self, s))
else:
# NOTE: technically, I could just raise an error at this point.
# However, there are occasions where it might be useful to
# create data series starting from plotting functions, without
# showing the plot. Hence, I raise the error later, if needed.
self._renderers.append(None)
def _check_supported_series(self, renderer, series):
if renderer is None:
raise NotImplementedError(
f"{type(series).__name__} is not associated to any renderer "
f"compatible with {type(self).__name__}. Follow these "
"steps to make it works:\n"
"1. Code an appropriate rendeder class.\n"
f"2. Execute {type(self).__name__}.renderers_map.update"
"({%s})\n" % f"{type(series).__name__}: YourRendererClass"
+ "3. Execute again the plot statement."
)
def _use_cyclic_cm(self, param, is_complex):
"""When using complex_plot and `absarg=True`, it might happens that the
argument is not fully covering the range [-pi, pi]. In such occurences,
the use of a cyclic colormap would create a misleading plot.
"""
np = import_module('numpy')
eps = 0.1
use_cyclic_cm = False
if is_complex:
m, M = np.amin(param), np.amax(param)
if (
(m != M) and
(abs(abs(m) - np.pi) < eps) and
(abs(abs(M) - np.pi) < eps)
):
use_cyclic_cm = True
return use_cyclic_cm
def _set_piecewise_color(self, s, color):
"""Set the color to the given series of a piecewise function."""
raise NotImplementedError
def _update_series_ranges(self, *limits):
"""Update the ranges of data series in order to implement pan/zoom
update events.
Parameters
==========
limits : iterable
Each element is a tuple (min, max).
Returns
=======
all_params : dict
A dictionary containing all the parameters used by the series
"""
all_params = {}
# TODO: can ComplexSurfaceBaseSeries be modified such that it has
# two ranges instead of one? It would allow to simplify this code...
css = [s for s in self._series
if isinstance(s, ComplexSurfaceBaseSeries)]
ncss = [s for s in self._series
if not isinstance(s, ComplexSurfaceBaseSeries)]
for s in ncss:
# skip the ones that don't use `ranges`, like
# List2D/List3D/Arrow2D/Arrow3D
# as well as the parametric series
if (len(s.ranges) > 0) and (not s.is_parametric):
new_ranges = []
for r, l in zip(s.ranges, limits):
if any(len(t.free_symbols) > 0 for t in r[1:]):
# design choice: interactive ranges should not
# be modified
new_ranges.append(r)
else:
new_ranges.append((r[0], *l))
s.ranges = new_ranges
s._interactive_ranges = True
s.is_interactive = True
all_params = self.merge({}, all_params, s.params)
if len(css) > 0:
xlim, ylim = limits[:2]
lim = (xlim[0] + 1j * ylim[0], xlim[1] + 1j * ylim[1])
for s in css:
if all(len(t.free_symbols) == 0 for t in s.ranges[0][1:]):
# design choice: interactive ranges should not
# be modified
s.ranges = [(s.ranges[0][0], *lim)]
s._interactive_ranges = True
s.is_interactive = True
all_params = self.merge({}, all_params, s.params)
return all_params
@property
def fig(self):
"""Returns the figure used to render/display the plots."""
return self._fig
@property
def renderers(self):
"""Returns the renderers associated to each series."""
return self._renderers
@property
def series(self):
"""Returns the series associated to the current plot."""
return self._series
def update_interactive(self, params):
"""Implement the logic to update the data generated by
interactive-widget plots.
Parameters
==========
params : dict
Map parameter-symbols to numeric values.
"""
raise NotImplementedError
def show(self):
"""Implement the functionalities to display the plot."""
raise NotImplementedError
def save(self, path, **kwargs):
"""Implement the functionalities to save the plot.
Parameters
==========
path : str
File path with extension.
kwargs : dict
Optional backend-specific parameters.
"""
raise NotImplementedError
def __str__(self):
series_strs = [
("[%d]: " % i) + str(s) for i, s in enumerate(self._series)
]
return "Plot object containing:\n" + "\n".join(series_strs)
def __getitem__(self, index):
return self._series[index]
def __setitem__(self, index, *args):
if len(args) == 1 and isinstance(args[0], BaseSeries):
self._series[index] = args
def __delitem__(self, index):
del self._series[index]
def __add__(self, other):
return self._do_sum(other)
def __radd__(self, other):
if other == 0:
return self
return other._do_sum(self)
def _do_sum(self, other):
"""Differently from Plot.extend, this method creates a new plot
object, which uses the series of both plots and merges the _kwargs
dictionary of `self` with the one of `other`.
"""
if not isinstance(other, Plot):
raise TypeError(
"Both sides of the `+` operator must be instances of the Plot "
+ "class.\n Received: {} + {}".format(type(self), type(other))
)
series = []
series.extend(self.series)
series.extend(other.series)
kwargs = self._do_sum_kwargs(self, other)
return type(self)(*series, **kwargs)
def append(self, arg):
"""Adds an element from a plot's series to an existing plot.
Parameters
==========
arg : BaseSeries
An instance of `BaseSeries` which will be used to generate the
numerical data.
Examples
========
Consider two `Plot` objects, `p1` and `p2`. To add the
second plot's first series object to the first, use the
`append` method, like so:
.. plot::
:format: doctest
:include-source: True
>>> from sympy import symbols
>>> from spb import plot
>>> x = symbols('x')
>>> p1 = plot(x*x, show=False)
>>> p2 = plot(x, show=False)
>>> p1.append(p2[0])
>>> p1
Plot object containing:
[0]: cartesian line: x**2 for x over (-10.0, 10.0)
[1]: cartesian line: x for x over (-10.0, 10.0)
>>> p1.show()
See Also
========
extend
"""
if isinstance(arg, BaseSeries):
self._series.append(arg)
# auto legend
if len([not s.is_grid for s in self._series]) > 1:
self.legend = True
else:
raise TypeError("Must specify element of plot to append.")
# recreate renderers
self._create_renderers()
def extend(self, arg):
"""Adds all series from another plot.
Parameters
==========
arg : Plot or sequence of BaseSeries
Examples
========
Consider two `Plot` objects, `p1` and `p2`. To add the
second plot to the first, use the `extend` method, like so:
.. plot::
:format: doctest
:include-source: True
>>> from sympy import symbols
>>> from spb import plot
>>> x = symbols('x')
>>> p1 = plot(x**2, show=False)
>>> p2 = plot(x, -x, show=False)
>>> p1.extend(p2)
>>> p1
Plot object containing:
[0]: cartesian line: x**2 for x over (-10.0, 10.0)
[1]: cartesian line: x for x over (-10.0, 10.0)
[2]: cartesian line: -x for x over (-10.0, 10.0)
>>> p1.show()
See Also
========
append
"""
if isinstance(arg, Plot):
self._series.extend(arg._series)
elif (
is_sequence(arg) and
all([isinstance(a, BaseSeries) for a in arg])
):
self._series.extend(arg)
else:
raise TypeError("Expecting Plot or sequence of BaseSeries")
# auto legend
if len([not s.is_grid for s in self._series]) > 1:
self.legend = True
# recreate renderers
self._create_renderers()