Source code for spb.backends.base_backend

from itertools import cycle
from spb.series import BaseSeries
from spb.backends.utils import convert_colormap
from sympy.utilities.iterables import is_sequence
from sympy.external import import_module


class Plot:
    """Base class for all backends. A backend represents the plotting
    library, which implements the necessary functionalities in order to use
    SymPy plotting functions.

    How the plotting module works:

    1. The user creates the symbolic expressions and calls one of the plotting
       functions.
    2. The plotting functions generate a list of instances of the `BaseSeries`
       class, containing the necessary information to plot the expressions
       (e.g. the expression, ranges, series name, ...). Eventually, these
       objects will generate the numerical data to be plotted.
    3. The plotting functions instantiate the `Plot` class, which stores the
       list of series and the main attributes of the plot (e.g. axis labels,
       title, etc.). Among the keyword arguments, there must be `backend`,
       a subclass of `Plot` which specifies the backend to be used.
    4. The backend will render the numerical data to a plot and (optionally)
       show it on the screen.

    The figure is populated with numerical data once the `show()` method is
    called or the `fig` attribute is accessed.

    The backend should check if it supports the data series that it's given.
    Please explore the `MatplotlibBackend` source code to understand how a
    backend should be coded.

    Also note that setting attributes on plot objects or on data series after
    they have been instantiated is strongly discouraged, as it is not
    guaranteed that the figure will be updated.

    Notes
    =====

    In order to be used by SymPy plotting functions, a backend must implement
    the following methods and attributes:

    * ``show(self)``: used to loop over the data series, generate the
      numerical data, plot it and set the axis labels, title, ...
    * ``save(self, path, **kwargs)``: used to save the current plot to the
      specified file path.
    * ``self._fig``: an instance attribute to store the backend-specific plot
      object, which can be retrieved with the `Plot.fig` attribute. This
      object can then be used to further customize the resulting plot, using
      backend-specific commands.
    * ``_update_interactive(self, params)``: this method receives a dictionary
      mapping parameters to their values from the ``iplot`` function, which
      are going to be used to update the objects of the figure.

    Parameters
    ==========

    title : str, optional
        Set the title of the plot. Defaults to an empty string.

    xlabel, ylabel, zlabel : str, optional
        Set the labels of the plot. Defaults to an empty string.

    legend : bool, optional
        Show or hide the legend. By default, the backend will automatically
        set it to True if multiple data series are shown.

    xscale, yscale, zscale : str, optional
        Discretization strategy for the provided domain along the specified
        direction. Can be either `'linear'` or `'log'`. Defaults to
        `'linear'`. If the backend supports it, the specified direction will
        use the user-provided scale. By default, all backends use linear
        scales for both axes. None of the backends support logarithmic scale
        for 3D plots.

    grid : bool, optional
        Show/Hide the grid. The default value depends on the backend.

    xlim, ylim, zlim : (float, float), optional
        Focus the plot to the specified range. The tuple must be in the form
        `(min_val, max_val)`.

    aspect : (float, float) or str, optional
        Set the aspect ratio of the plot. It only works for 2D plots.
        The accepted values depend on the backend. Read the relevant
        backend's documentation to find out the possible values.

    backend : Plot
        The subclass to be used to generate the plot.

    size : (float, float) or None, optional
        Set the size of the plot, `(width, height)`. Defaults to None.

    Examples
    ========

    Combine multiple plots together to create a new plot:

    .. plot::
       :context: reset
       :format: doctest
       :include-source: True

       >>> from sympy import symbols, sin, cos, log, S
       >>> from spb import plot, plot3d
       >>> x, y = symbols("x, y")
       >>> p1 = plot(sin(x), cos(x), show=False)
       >>> p2 = plot(sin(x) * cos(x), log(x), show=False)
       >>> p3 = p1 + p2
       >>> p3.show()

    Use the index notation to access the data series.
    Let's generate the numerical data associated with the first series:

    .. plot::
       :context: close-figs
       :format: doctest
       :include-source: True

       >>> p1 = plot(sin(x), cos(x), show=False)
       >>> xx, yy = p1[0].get_data()

    Create a new backend with a custom colorloop:

    .. plot::
       :context: close-figs
       :format: doctest
       :include-source: True

       >>> from spb.backends.matplotlib import MB
       >>> class MBchild(MB):
       ...     colorloop = ["r", "g", "b"]
       >>> plot(sin(x) / 3, sin(x) * S(2) / 3, sin(x), backend=MBchild)

    Create a new backend with custom color maps for 3D plots. Note that
    it's possible to use Plotly/Colorcet/Matplotlib colormaps interchangeably.

    .. plot::
       :context: close-figs
       :format: doctest
       :include-source: True

       >>> from spb.backends.matplotlib import MB
       >>> import colorcet as cc
       >>> class MBchild(MB):
       ...     colormaps = ["plotly3", cc.bmy]
       >>> plot3d(
       ...     (cos(x**2 + y**2), (x, -2, 0), (y, -2, 2)),
       ...     (cos(x**2 + y**2), (x, 0, 2), (y, -2, 2)),
       ...     backend=MBchild, n1=25, n2=50, use_cm=True)

    See also
    ========

    MatplotlibBackend, PlotlyBackend, BokehBackend, K3DBackend
    """

    # set the name of the plotting library being used. This is required in
    # order to convert any colormap to the specified plotting library.
    _library = ""

    colorloop = []
    """List of colors to be used in line plots or solid color surfaces."""

    colormaps = []
    """List of color maps to render surfaces."""

    cyclic_colormaps = []
    """List of cyclic color maps to render complex series (the phase/argument
    ranges over [-pi, pi]).
    """

    _allowed_keys = ["aspect", "axis", "axis_center", "backend",
        "detect_poles", "grid", "legend", "show", "size", "title", "use_latex",
        "xlabel", "ylabel", "zlabel", "xlim", "ylim", "zlim",
        "xscale", "yscale", "zscale", "process_piecewise", "polar_axis"]
    """contains a list of public keyword arguments supported by the series.
    It will be used to validate the user-provided keyword arguments.
    """

    def __new__(cls, *args, **kwargs):
        # instantiate the class passed as `backend`, not `cls` itself
        backend = cls._get_backend(kwargs)
        return super().__new__(backend)

    @classmethod
    def _get_backend(cls, kwargs):
        backend = kwargs.get("backend", "matplotlib")
        if not (isinstance(backend, type) and issubclass(backend, cls)):
            raise TypeError("backend must be a subclass of Plot")
        return backend

    def _set_labels(self, wrapper="$%s$"):
        """Set the correct labels.

        Parameters
        ==========
        use_latex : boolean
            Whether the backend is customized to show latex labels.
        wrapper : str
            Wrapper string for the latex labels. Defaults to '$%s$'.
        """
        if not self._use_latex:
            wrapper = "%s"

        if callable(self.xlabel):
            self.xlabel = wrapper % self.xlabel(self._use_latex)
        if callable(self.ylabel):
            self.ylabel = wrapper % self.ylabel(self._use_latex)
        if callable(self.zlabel):
            self.zlabel = wrapper % self.zlabel(self._use_latex)

    def __init__(self, *args, **kwargs):
        # the merge function is used by all backends
        self._mergedeep = import_module('mergedeep')
        self.merge = self._mergedeep.merge

        # Options for the graph as a whole.
        # The possible values for each option are described in the docstring
        # of Plot. They are based purely on convention, no checking is done.
        self.title = kwargs.get("title", None)
        self.xlabel = kwargs.get("xlabel", None)
        self.ylabel = kwargs.get("ylabel", None)
        self.zlabel = kwargs.get("zlabel", None)
        self.aspect = kwargs.get("aspect", "auto")
        self.axis_center = kwargs.get("axis_center", None)
        self.camera = kwargs.get("camera", None)
        self.grid = kwargs.get("grid", True)
        self.xscale = kwargs.get("xscale", "linear")
        self.yscale = kwargs.get("yscale", "linear")
        self.zscale = kwargs.get("zscale", "linear")
        self.polar_axis = kwargs.get("polar_axis", None)
        # NOTE: it would be nice to have detect_poles=True by default.
        # However, the correct detection also depends on the number of points
        # and the value of `eps`. Getting the detection right is likely to
        # be a trial-and-error procedure. Hence, keep this parameter to False.
        self.detect_poles = kwargs.get("detect_poles", False)

        # NOTE: matplotlib is not designed to be interactive, therefore it
        # needs a way to detect where its figure is going to be displayed.
        # For regular plots, plt.figure can be used. For interactive-parametric
        # plots matplotlib.figure.Figure must be used.
        self.is_iplot = kwargs.get("is_iplot", False)

        # Contains the data objects to be plotted. The backend should be smart
        # enough to iterate over this list.
        self._series = []
        self._series.extend(args)

        if "process_piecewise" in kwargs.keys():
            # if the backend was called by plot_piecewise, each piecewise
            # function must use the same color. Here we preprocess each
            # series to add the correct color
            series = []
            for idx, _series in kwargs["process_piecewise"].items():
                color = next(self._cl)
                for s in _series:
                    self._set_piecewise_color(s, color)
                series.extend(_series)
            self._series = series

        # Automatic legend: if more than 1 data series has been provided
        # and the user has not set legend=False, then show the legend for
        # better clarity.
        self.legend = _legend = kwargs.get("legend", None)
        if not self.legend:
            self.legend = False
        if (len([s for s in self._series if s.show_in_legend]) > 1) or (
            any(s.is_parametric and s.use_cm for s in self._series)):
            # don't show the legend if `plot_piecewise` created this
            # backend
            if ((not ("process_piecewise" in kwargs.keys())) and
                ((_legend is True) or (_legend is None))):
                self.legend = True

        if self.legend and (len([s for s in self._series
                if s.is_3Dsurface and not s.use_cm]) > 1):
            self.legend = False

        # Objects used to render/display the plots, which depend on the
        # plotting library.
        self._fig = None

        is_real = lambda lim: all(getattr(i, "is_real", True) for i in lim)
        is_finite = lambda lim: all(getattr(i, "is_finite", True) for i in lim)

        # reduce code repetition
        def check_and_set(t_name, t):
            if t:
                if not is_real(t):
                    raise ValueError(
                        "All numbers from {}={} must be real".format(t_name, t))
                if not is_finite(t):
                    raise ValueError(
                        "All numbers from {}={} must be finite".format(t_name, t))
                setattr(self, t_name, (float(t[0]), float(t[1])))

        self.xlim = None
        check_and_set("xlim", kwargs.get("xlim", None))
        self.ylim = None
        check_and_set("ylim", kwargs.get("ylim", None))
        self.zlim = None
        check_and_set("zlim", kwargs.get("zlim", None))
        self.size = None
        check_and_set("size", kwargs.get("size", None))

    def _copy_kwargs(self):
        """Copy the values of the plot attributes into a dictionary which will
        be later used to create a new `Plot` object having the same attributes.
        """
        return dict(
            title=self.title,
            xlabel=self.xlabel,
            ylabel=self.ylabel,
            zlabel=self.zlabel,
            aspect=self.aspect,
            axis_center=self.axis_center,
            grid=self.grid,
            xscale=self.xscale,
            yscale=self.yscale,
            zscale=self.zscale,
            detect_poles=self.detect_poles,
            legend=self.legend,
            xlim=self.xlim,
            ylim=self.ylim,
            zlim=self.zlim,
            size=self.size,
            is_iplot=self.is_iplot,
            use_latex=self._use_latex,
            camera=self.camera,
            polar_axis=self.polar_axis
        )

    def _init_cyclers(self):
        """Create infinite loop iterators over the provided color maps."""
        tb = type(self)
        colorloop = self.colorloop if not tb.colorloop else tb.colorloop
        colormaps = self.colormaps if not tb.colormaps else tb.colormaps
        cyclic_colormaps = (self.cyclic_colormaps if not tb.cyclic_colormaps
            else tb.cyclic_colormaps)

        if not isinstance(colorloop, (list, tuple)):
            # assume it is a matplotlib's ListedColormap: cycle over its list
            # of colors, since the colormap object itself is not iterable
            colorloop = colorloop.colors
            self.colorloop = colorloop
        self._cl = cycle(colorloop)
        colormaps = [convert_colormap(cm, self._library) for cm in colormaps]
        self._cm = cycle(colormaps)
        cyclic_colormaps = [
            convert_colormap(cm, self._library) for cm in cyclic_colormaps
        ]
        self._cyccm = cycle(cyclic_colormaps)

    def _get_mode(self):
        """Verify which environment is used to run the code.

        Returns
        =======
        mode : int
            0 - the code is running on Jupyter Notebook or qtconsole
            1 - terminal running IPython
            2 - other type (?)
            3 - probably standard Python interpreter

        # TODO: detect if we are running in Jupyter Lab.
        """
        # https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
        try:
            # `get_ipython` only exists inside an IPython session; elsewhere
            # the lookup raises NameError, which is handled below.
            shell = get_ipython().__class__.__name__
            if shell == "ZMQInteractiveShell":
                return 0  # Jupyter notebook or qtconsole
            elif shell == "TerminalInteractiveShell":
                return 1  # Terminal running IPython
            else:
                return 2  # Other type (?)
        except NameError:
            return 3  # Probably standard Python interpreter

    def _use_cyclic_cm(self, param, is_complex):
        """When using complex_plot and `absarg=True`, it might happen that the
        argument does not fully cover the range [-pi, pi]. In such cases, the
        use of a cyclic colormap would create a misleading plot.
        """
        np = import_module('numpy')

        eps = 0.1
        use_cyclic_cm = False
        if is_complex:
            m, M = np.amin(param), np.amax(param)
            if ((m != M) and (abs(abs(m) - np.pi) < eps)
                    and (abs(abs(M) - np.pi) < eps)):
                use_cyclic_cm = True
        return use_cyclic_cm

    def _set_piecewise_color(self, s, color):
        """Set the color to the given series of a piecewise function."""
        raise NotImplementedError

    @property
    def fig(self):
        """Returns the figure used to render/display the plots."""
        return self._fig

    @property
    def series(self):
        """Returns the series associated with the current plot."""
        return self._series

    def _update_interactive(self, params):
        """Implement the logic to update the data generated by
        InteractiveSeries.
        """
        raise NotImplementedError

    def show(self):
        """Implement the functionalities to display the plot."""
        raise NotImplementedError

    def save(self, path, **kwargs):
        """Implement the functionalities to save the plot.

        Parameters
        ==========

        path : str
            File path with extension.
        kwargs : dict
            Optional backend-specific parameters.
        """
        raise NotImplementedError

    def __str__(self):
        series_strs = [("[%d]: " % i) + str(s)
            for i, s in enumerate(self._series)]
        return "Plot object containing:\n" + "\n".join(series_strs)

    def __getitem__(self, index):
        return self._series[index]

    def __setitem__(self, index, *args):
        if len(args) == 1 and isinstance(args[0], BaseSeries):
            # store the series itself, not the 1-element `args` tuple
            self._series[index] = args[0]

    def __delitem__(self, index):
        del self._series[index]

    def __add__(self, other):
        return self._do_sum(other)

    def __radd__(self, other):
        # `sum()` starts from 0, hence `0 + plot` must return the plot itself
        if other == 0:
            return self
        return other._do_sum(self)

    def _do_sum(self, other):
        """Unlike Plot.extend, this method creates a new plot object, which
        uses the series of both plots and merges the _kwargs dictionary of
        `self` with the one of `other`.
        """
        if not isinstance(other, Plot):
            raise TypeError(
                "Both sides of the `+` operator must be instances of the Plot "
                + "class.\n Received: {} + {}".format(type(self), type(other))
            )
        series = []
        series.extend(self.series)
        series.extend(other.series)
        kwargs = self._do_sum_kwargs(self, other)
        # If the first plot (`p1`) of the summation has been created without
        # specifying `legend`, then `p1.legend` might be False, hence
        # `kwargs["legend"]` might be False. But when adding multiple plots
        # it is very likely that the user expects a legend to be shown. Hence,
        # reset legend and let the backend decide if it needs one or not.
        kwargs["legend"] = None
        return type(self)(*series, **kwargs)

    def append(self, arg):
        """Adds an element from a plot's series to an existing plot.

        Parameters
        ==========

        arg : BaseSeries
            An instance of `BaseSeries` which will be used to generate the
            numerical data.

        Examples
        ========

        Consider two `Plot` objects, `p1` and `p2`. To add the
        second plot's first series object to the first, use the
        `append` method, like so:

        .. plot::
           :format: doctest
           :include-source: True

           >>> from sympy import symbols
           >>> from spb import plot
           >>> x = symbols('x')
           >>> p1 = plot(x*x, show=False)
           >>> p2 = plot(x, show=False)
           >>> p1.append(p2[0])
           >>> p1
           Plot object containing:
           [0]: cartesian line: x**2 for x over (-10.0, 10.0)
           [1]: cartesian line: x for x over (-10.0, 10.0)
           >>> p1.show()

        See Also
        ========

        extend

        """
        if isinstance(arg, BaseSeries):
            self._series.append(arg)
            # auto legend
            if len(self._series) > 1:
                self.legend = True
        else:
            raise TypeError("Must specify element of plot to append.")

    def extend(self, arg):
        """Adds all series from another plot.

        Parameters
        ==========

        arg : Plot or sequence of BaseSeries

        Examples
        ========

        Consider two `Plot` objects, `p1` and `p2`. To add the
        second plot to the first, use the `extend` method, like so:

        .. plot::
           :format: doctest
           :include-source: True

           >>> from sympy import symbols
           >>> from spb import plot
           >>> x = symbols('x')
           >>> p1 = plot(x**2, show=False)
           >>> p2 = plot(x, -x, show=False)
           >>> p1.extend(p2)
           >>> p1
           Plot object containing:
           [0]: cartesian line: x**2 for x over (-10.0, 10.0)
           [1]: cartesian line: x for x over (-10.0, 10.0)
           [2]: cartesian line: -x for x over (-10.0, 10.0)
           >>> p1.show()

        See Also
        ========

        append

        """
        if isinstance(arg, Plot):
            self._series.extend(arg._series)
        elif is_sequence(arg) and all([isinstance(a, BaseSeries) for a in arg]):
            self._series.extend(arg)
        else:
            raise TypeError("Expecting Plot or sequence of BaseSeries")
        # auto legend
        if len(self._series) > 1:
            self.legend = True
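
The module above relies on three small patterns: dispatching to a backend subclass inside `__new__`, cycling over a class-level `colorloop`, and supporting `sum()` through `__radd__`. The following is a minimal, dependency-free sketch of those patterns only. `ToyPlot`, `ToyBackend` and the `colors()` helper are hypothetical names invented for illustration; they are not part of spb, and the real `Plot` class does considerably more (keyword merging, legend handling, numerical data generation).

```python
from itertools import cycle


class ToyPlot:
    """Hypothetical stand-in for ``Plot``; not part of spb."""

    colorloop = []

    def __new__(cls, *args, **kwargs):
        # Create an instance of the class passed as ``backend`` rather than
        # of ``cls``. Assumption: default to ``cls`` when no backend is given.
        backend = kwargs.get("backend", cls)
        if not (isinstance(backend, type) and issubclass(backend, cls)):
            raise TypeError("backend must be a subclass of ToyPlot")
        return super().__new__(backend)

    def __init__(self, *series, **kwargs):
        self._series = list(series)
        # Infinite iterator over the colors defined on the backend class.
        self._cl = cycle(self.colorloop)

    @property
    def series(self):
        return self._series

    def __add__(self, other):
        if not isinstance(other, ToyPlot):
            raise TypeError("Both operands must be ToyPlot instances")
        # Build a new plot of the same backend holding both lists of series.
        return type(self)(*self.series, *other.series)

    def __radd__(self, other):
        # sum() starts from 0, so ``0 + plot`` returns the plot unchanged.
        if other == 0:
            return self
        return NotImplemented

    def colors(self):
        # One color per series, wrapping around when the loop is exhausted.
        return [next(self._cl) for _ in self._series]


class ToyBackend(ToyPlot):
    colorloop = ["r", "g", "b"]


p = ToyPlot("sin(x)", "cos(x)", backend=ToyBackend)
q = ToyPlot("tan(x)", "exp(x)", backend=ToyBackend)
total = sum([p, q])
palette = total.colors()
print(type(p).__name__)  # ToyBackend
print(palette)           # ['r', 'g', 'b', 'r']
```

Because `__new__` returns an instance of a subclass of the class being called, Python still runs `__init__` on it with the original arguments, which is why the `backend` keyword reaches `__init__` as well.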