Source code for spb.backends.matplotlib

import itertools
from spb.defaults import cfg
from spb.backends.base_backend import Plot
from spb.backends.utils import compute_streamtubes
from sympy import latex
from sympy.external import import_module
from packaging import version

# Global variable
# Set to False when running tests / doctests so that the plots don't show.
_show = True


def unset_show():
    """
    Disable show(). For use in the tests.
    """
    global _show
    _show = False


def _matplotlib_list(interval_list):
    """
    Returns lists for matplotlib `fill` command from a list of bounding
    rectangular intervals
    """
    xlist = []
    ylist = []
    if len(interval_list):
        for intervals in interval_list:
            intervalx = intervals[0]
            intervaly = intervals[1]
            xlist.extend(
                [intervalx.start, intervalx.start, intervalx.end, intervalx.end, None]
            )
            ylist.extend(
                [intervaly.start, intervaly.end, intervaly.end, intervaly.start, None]
            )
    else:
        # XXX Ugly hack. Matplotlib does not accept empty lists for `fill`
        xlist.extend([None, None, None, None])
        ylist.extend([None, None, None, None])
    return xlist, ylist


[docs]class MatplotlibBackend(Plot): """ A backend for plotting SymPy's symbolic expressions using Matplotlib. Parameters ========== aspect : (float, float) or str, optional Set the aspect ratio of a 2D plot. Possible values: * ``"auto"``: Matplotlib will fit the plot in the vibile area. * ``"equal"``: sets equal spacing. * tuple containing 2 float numbers, from which the aspect ratio is computed. This only works for 2D plots. axis_center : (float, float) or str or None, optional Set the location of the intersection between the horizontal and vertical axis in a 2D plot. It can be: * ``None``: traditional layout, with the horizontal axis fixed on the bottom and the vertical axis fixed on the left. This is the default value. * a tuple ``(x, y)`` specifying the exact intersection point. * ``'center'``: center of the current plot area. * ``'auto'``: the intersection point is automatically computed. camera : dict, optional A dictionary of keyword arguments that will be passed to the ``Axes3D.view_init`` method. Refer to [#fn9]_ for more information. rendering_kw : dict, optional A dictionary of keywords/values which is passed to Matplotlib's plot functions to customize the appearance of lines, surfaces, images, contours, quivers, streamlines... To learn more about customization: * Refer to [#fn1]_ to customize contour plots. * Refer to [#fn2]_ to customize image plots. * Refer to [#fn3]_ to customize solid line plots. * Refer to [#fn4]_ to customize colormap-based line plots. * Refer to [#fn5]_ to customize quiver plots. * Refer to [#fn6]_ to customize surface plots. * Refer to [#fn7]_ to customize stramline plots. * Refer to [#fn8]_ to customize 3D scatter plots. use_cm : boolean, optional If True, apply a color map to the mesh/surface or parametric lines. If False, solid colors will be used instead. Default to True. annotations : list, optional A list of dictionaries specifying the type of annotation required. The keys in the dictionary should be equivalent to the arguments of the `matplotlib.axes.Axes.annotate` method. This feature is experimental. It might get removed in the future. markers : list, optional A list of dictionaries specifying the type the markers required. The keys in the dictionary should be equivalent to the arguments of the `matplotlib.pyplot.plot()` function along with the marker related keyworded arguments. This feature is experimental. It might get removed in the future. rectangles : list, optional A list of dictionaries specifying the dimensions of the rectangles to be plotted. The keys in the dictionary should be equivalent to the arguments of the `matplotlib.patches.Rectangle` class. This feature is experimental. It might get removed in the future. fill : dict, optional A dictionary specifying the type of color filling required in the plot. The keys in the dictionary should be equivalent to the arguments of the `matplotlib.axes.Axes.fill_between` method. This feature is experimental. It might get removed in the future. References ========== .. [#fn1] https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.contourf.html .. [#fn2] https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html .. [#fn3] https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.plot.html .. [#fn4] https://matplotlib.org/stable/api/collections_api.html#matplotlib.collections.LineCollection .. [#fn5] https://matplotlib.org/stable/api/quiver_api.html#module-matplotlib.quiver .. [#fn6] https://matplotlib.org/stable/api/_as_gen/mpl_toolkits.mplot3d.axes3d.Axes3D.html#mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface .. [#fn7] https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.streamplot.html#matplotlib.axes.Axes.streamplot .. [#fn8] https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.scatter.html .. [#fn9] https://matplotlib.org/stable/api/_as_gen/mpl_toolkits.mplot3d.axes3d.Axes3D.html#mpl_toolkits.mplot3d.axes3d.Axes3D.view_init See also ======== Plot, PlotlyBackend, BokehBackend, K3DBackend """ _library = "matplotlib" _allowed_keys = Plot._allowed_keys + [ "markers", "annotations", "fill", "rectangles", "camera"] wireframe_color = "k" colormaps = [] cyclic_colormaps = [] def __new__(cls, *args, **kwargs): return object.__new__(cls) def __init__(self, *args, **kwargs): self.matplotlib = import_module( 'matplotlib', import_kwargs={'fromlist': ['pyplot', 'cm', 'collections', 'colors']}, warn_not_installed=True, min_module_version='1.1.0', catch=(RuntimeError,)) self.plt = self.matplotlib.pyplot self.cm = cm = self.matplotlib.cm self.LineCollection = self.matplotlib.collections.LineCollection self.ListedColormap = self.matplotlib.colors.ListedColormap self.Line2D = self.matplotlib.lines.Line2D self.Rectangle = self.matplotlib.patches.Rectangle self.Normalize = self.matplotlib.colors.Normalize # set default colors self.colormaps = [ cm.viridis, cm.autumn, cm.winter, cm.plasma, cm.jet, cm.gnuplot, cm.brg, cm.coolwarm, cm.cool, cm.summer] self.cyclic_colormaps = [cm.twilight, cm.hsv] # load default colorloop self.colorloop = self.plt.rcParams['axes.prop_cycle'].by_key()["color"] self._init_cyclers() super().__init__(*args, **kwargs) # set labels self._use_latex = kwargs.get("use_latex", cfg["matplotlib"]["use_latex"]) self._set_labels() if ((len([s for s in self._series if s.is_2Dline]) > 10) and (not type(self).colorloop) and not ("process_piecewise" in kwargs.keys())): # add colors if needed self.colorloop = cm.tab20.colors # plotgrid() can provide its figure and axes to be populated with # the data from the series. self._plotgrid_fig = kwargs.pop("fig", None) self._plotgrid_ax = kwargs.pop("ax", None) if self.axis_center is None: self.axis_center = cfg["matplotlib"]["axis_center"] self.grid = kwargs.get("grid", cfg["matplotlib"]["grid"]) self._show_minor_grid = kwargs.get("show_minor_grid", cfg["matplotlib"]["show_minor_grid"]) self._handles = dict() self._legend_handles = [] def _set_piecewise_color(self, s, color): """Set the color to the given series""" if "color" not in s.rendering_kw: # only set the color if the user didn't do that already s.rendering_kw["color"] = color @staticmethod def _do_sum_kwargs(p1, p2): return p1._copy_kwargs() def _init_cyclers(self): super()._init_cyclers() np = import_module('numpy') # For flexibily, spb.backends.utils.convert_colormap returns numpy # ndarrays whenever plotly/colorcet/k3d color map are given. Here we # create ListedColormap that can be used by Matplotlib def process_iterator(it, colormaps): cm = [] for i in range(len(colormaps)): c = next(it) cm.append(c if not isinstance(c, np.ndarray) else self.ListedColormap(c)) return itertools.cycle(cm) self._cm = process_iterator(self._cm, self.colormaps) self._cyccm = process_iterator(self._cyccm, self.cyclic_colormaps) def _create_figure(self): if self._plotgrid_fig is not None: self._fig = self._plotgrid_fig self._ax = self._plotgrid_ax else: if self.is_iplot and (self.imodule == "panel"): self._fig = self.matplotlib.figure.Figure(figsize=self.size) else: self._fig = self.plt.figure(figsize=self.size) is_3D = [s.is_3D for s in self.series] if any(is_3D) and (not all(is_3D)): # allow sum of 3D plots with contour plots if not all(s.is_3D or s.is_contour for s in self.series): raise ValueError( "MatplotlibBackend can not mix 2D and 3D.") kwargs = dict() if any(is_3D): kwargs["projection"] = "3d" elif (self.polar_axis and any(s.is_2Dline or s.is_contour for s in self.series)): kwargs["projection"] = "polar" self._ax = self._fig.add_subplot(1, 1, 1, **kwargs) def _create_ax_if_not_available(self): if (not hasattr(self, "_ax")): # if the backend was created without showing it self.process_series() @property def fig(self): """Returns the figure.""" self._create_ax_if_not_available() return self._fig @property def ax(self): """Returns the axis used for the plot. Notes ===== To get the axis of a colorbar, index ``p.fig.axes`` where ``p`` is a plot object. ``p.fig.axes[0]`` corresponds to ``p.ax``. """ self._create_ax_if_not_available() return self._ax @staticmethod def get_segments(x, y, z=None): """ Convert two list of coordinates to a list of segments to be used with Matplotlib's LineCollection. Parameters ========== x: list List of x-coordinates y: list List of y-coordinates z: list List of z-coordinates for a 3D line. """ np = import_module('numpy') if z is not None: dim = 3 points = (x, y, z) else: dim = 2 points = (x, y) points = np.ma.array(points).T.reshape(-1, 1, dim) return np.ma.concatenate([points[:-1], points[1:]], axis=1) def _add_colorbar(self, c, label, use_cm, override=False, norm=None, cmap=None): """Add a colorbar for the specificied collection Parameters ========== c : collection label : str override : boolean For parametric plots the colorbar acts like a legend. Hence, when legend=False we don't display the colorbar. However, for contour plots the colorbar is essential to understand it. Hence, to show it we set override=True. Default to False. """ np = import_module('numpy') # design choice: instead of showing a legend entry (which # would require to work with proxy artists and custom # classes in order to create a gradient line), just show a # colorbar with the name of the expression on the side. if (self.legend and use_cm) or override: if norm is None: cb = self._fig.colorbar(c, ax=self._ax) else: mappable = self.cm.ScalarMappable(cmap=cmap, norm=norm) cb = self._fig.colorbar(mappable) cb.set_label(label, rotation=90) return True return False def _add_handle(self, i, h, kw=None, *args): """self._handle is a dictionary which will be used with iplot. In particular: key: integer corresponding to the i-th series. value: a list of two elements: 1. handle of the object created by Matplotlib commands 2. optionally, keyword arguments used to create the handle. Some object can't be updated, hence we need to reconstruct it from scratch at every update. 3. anything else needed to reconstruct the object. """ self._handles[i] = [h if not isinstance(h, (list, tuple)) else h[0], kw, *args] def _process_series(self, series): np = import_module('numpy') mpl_toolkits = import_module( 'mpl_toolkits', # noqa import_kwargs={'fromlist': ['mplot3d']}, catch=(RuntimeError,)) Line3DCollection = mpl_toolkits.mplot3d.art3d.Line3DCollection merge = self.merge # XXX Workaround for matplotlib issue # https://github.com/matplotlib/matplotlib/issues/17130 xlims, ylims, zlims = [], [], [] self._ax.cla() self._init_cyclers() self._legend_handles = [] for i, s in enumerate(series): kw = None if s.is_2Dline: if s.is_parametric: x, y, param = s.get_data() else: x, y = s.get_data() if s.is_parametric and s.use_cm: colormap = ( next(self._cyccm) if self._use_cyclic_cm(param, s.is_complex) else next(self._cm) ) if not s.is_point: lkw = dict(array=param, cmap=colormap) kw = merge({}, lkw, s.rendering_kw) segments = self.get_segments(x, y) c = self.LineCollection(segments, **kw) self._ax.add_collection(c) else: lkw = dict(c=param, cmap=colormap) kw = merge({}, lkw, s.rendering_kw) c = self._ax.scatter(x, y, **kw) is_cb_added = self._add_colorbar(c, s.get_label(self._use_latex), s.use_cm) self._add_handle(i, c, kw, is_cb_added, self._fig.axes[-1]) else: color = next(self._cl) if s.line_color is None else s.line_color lkw = dict(label=s.get_label(self._use_latex), color=color) if s.is_point: lkw["marker"] = "o" lkw["linestyle"] = "None" if not s.is_filled: lkw["markerfacecolor"] = (1, 1, 1) kw = merge({}, lkw, s.rendering_kw) l = self._ax.plot(x, y, **kw) self._add_handle(i, l) elif s.is_contour: x, y, z = s.get_data() ckw = dict(cmap=next(self._cm)) if any(s.is_vector and (not s.is_streamlines) for s in self.series): # NOTE: # When plotting and updating a vector plot containing both # a contour series and a quiver series, because it's not # possible to update contour objects (we can only remove # and recreating them), the quiver series which is usually # after the contour plot (in terms of rendering order) will # be moved on top, resulting in the contour to hide the # quivers. Setting zorder appears to fix the problem. ckw["zorder"] = 0 kw = merge({}, ckw, s.rendering_kw) func = self._ax.contourf if s.is_filled else self._ax.contour c = func(x, y, z, **kw) clabel = None if s.is_filled: self._add_colorbar(c, s.get_label(self._use_latex), s.use_cm, True) else: if s.show_clabels: clabel = self._ax.clabel(c) self._add_handle(i, c, kw, self._fig.axes[-1], clabel) elif s.is_3Dline: if s.is_parametric: x, y, z, param = s.get_data() else: x, y, z = s.get_data() param = np.ones_like(x) lkw = dict() if not s.is_point: if s.use_cm: segments = self.get_segments(x, y, z) lkw["cmap"] = next(self._cm) lkw["array"] = param kw = merge({}, lkw, s.rendering_kw) c = Line3DCollection(segments, **kw) self._ax.add_collection(c) self._add_colorbar(c, s.get_label(self._use_latex), s.use_cm) self._add_handle(i, c, kw, self._fig.axes[-1]) else: lkw["label"] = s.get_label(self._use_latex) kw = merge({}, lkw, s.rendering_kw, ({} if s.line_color is None else {"color": s.line_color}) if s.show_in_legend else {"color": self.wireframe_color}) l = self._ax.plot(x, y, z, **kw) self._add_handle(i, l) else: if s.use_cm: lkw["cmap"] = next(self._cm) lkw["c"] = param else: # lkw["c"] = param lkw["color"] = next(self._cl) if s.line_color is None else s.line_color if not s.is_filled: lkw["facecolors"] = "none" lkw["alpha"] = 1 kw = merge({}, lkw, s.rendering_kw) l = self._ax.scatter(x, y, z, **kw) if s.use_cm: self._add_colorbar(l, s.get_label(self._use_latex), s.use_cm) self._add_handle(i, l, kw, self._fig.axes[-1]) else: self._add_handle(i, l) xlims.append((np.amin(x), np.amax(x))) ylims.append((np.amin(y), np.amax(y))) zlims.append((np.amin(z), np.amax(z))) elif (s.is_3Dsurface and (not s.is_domain_coloring) and (not s.is_implicit)): if not s.is_parametric: x, y, z = self.series[i].get_data() facecolors = s.eval_color_func(x, y, z) else: x, y, z, u, v = self.series[i].get_data() facecolors = s.eval_color_func(x, y, z, u, v) skw = dict(rstride=1, cstride=1, linewidth=0.1) norm, cmap = None, None if s.use_cm: vmin = s.rendering_kw.get("vmin", np.amin(facecolors)) vmax = s.rendering_kw.get("vmax", np.amax(facecolors)) norm = self.Normalize(vmin=vmin, vmax=vmax) cmap = next(self._cm) skw["cmap"] = cmap else: skw["color"] = next(self._cl) if s.surface_color is None else s.surface_color kw = merge({}, skw, s.rendering_kw) if s.use_cm: # facecolors must be computed here because s.rendering_kw # might have its own cmap cmap = kw["cmap"] if isinstance(cmap, str): cmap = self.cm.get_cmap(cmap) kw["facecolors"] = cmap(norm(facecolors)) c = self._ax.plot_surface(x, y, z, **kw) is_cb_added = self._add_colorbar(c, s.get_label(self._use_latex), s.use_cm, norm=norm, cmap=cmap) self._add_handle(i, c, kw, is_cb_added, self._fig.axes[-1]) xlims.append((np.amin(x), np.amax(x))) ylims.append((np.amin(y), np.amax(y))) zlims.append((np.amin(z), np.amax(z))) elif s.is_implicit and not s.is_3Dsurface: points = s.get_data() if len(points) == 2: # interval math plotting x, y = _matplotlib_list(points[0]) fkw = {"color": next(self._cl), "edgecolor": "None"} kw = merge({}, fkw, s.rendering_kw) c = self._ax.fill(x, y, **kw) self._add_handle(i, c, kw) proxy_artist = self.Rectangle((0, 0), 1, 1, color=kw["color"], label=s.get_label(self._use_latex)) else: # use contourf or contour depending on whether it is # an inequality or equality. xarray, yarray, zarray, plot_type = points color = next(self._cl) if plot_type == "contour": colormap = self.ListedColormap([color, color]) ckw = dict(cmap=colormap) kw = merge({}, ckw, s.rendering_kw) c = self._ax.contour(xarray, yarray, zarray, [0.0], **kw) proxy_artist = self.Line2D([], [], color=color, label=s.get_label(self._use_latex)) else: colormap = self.ListedColormap(["#ffffff00", color]) ckw = dict(cmap=colormap) kw = merge({}, ckw, s.rendering_kw) c = self._ax.contourf(xarray, yarray, zarray, **kw) proxy_artist = self.Rectangle((0, 0), 1, 1, color=color, label=s.get_label(self._use_latex)) self._add_handle(i, c, kw) self._legend_handles.append(proxy_artist) elif s.is_vector: if s.is_2Dvector: xx, yy, uu, vv = s.get_data() mag = np.sqrt(uu ** 2 + vv ** 2) uu0, vv0 = [t.copy() for t in [uu, vv]] if s.normalize: uu, vv = [t / mag for t in [uu, vv]] if s.is_streamlines: skw = dict() if (not s.use_quiver_solid_color) and s.use_cm: color_val = mag if s.color_func is not None: color_val = s.eval_color_func(xx, yy, uu0, vv0) skw["cmap"] = next(self._cm) skw["color"] = color_val kw = merge({}, skw, s.rendering_kw) sp = self._ax.streamplot(xx, yy, uu, vv, **kw) is_cb_added = self._add_colorbar( sp.lines, s.get_label(self._use_latex), s.use_cm) else: skw["color"] = next(self._cl) kw = merge({}, skw, s.rendering_kw) sp = self._ax.streamplot(xx, yy, uu, vv, **kw) is_cb_added = False self._add_handle(i, sp, kw, is_cb_added, self._fig.axes[-1]) else: qkw = dict() if any(s.is_contour for s in self.series): # NOTE: # When plotting and updating a vector plot # containing both a contour series and a quiver # series, because it's not possible to update # contour objects (we can only remove and # recreating them), the quiver series which is # usually after the contour plot (in terms of # rendering order) will be moved on top, resulting # in the contour to hide the quivers. Setting # zorder appears to fix the problem. qkw["zorder"] = 1 if (not s.use_quiver_solid_color) and s.use_cm: # don't use color map if a scalar field is # visible or if use_cm=False color_val = mag if s.color_func is not None: color_val = s.eval_color_func(xx, yy, uu0, vv0) qkw["cmap"] = next(self._cm) kw = merge({}, qkw, s.rendering_kw) q = self._ax.quiver(xx, yy, uu, vv, color_val, **kw) is_cb_added = self._add_colorbar( q, s.get_label(self._use_latex), s.use_cm) else: is_cb_added = False qkw["color"] = next(self._cl) kw = merge({}, qkw, s.rendering_kw) q = self._ax.quiver(xx, yy, uu, vv, **kw) self._add_handle(i, q, kw, is_cb_added, self._fig.axes[-1]) else: xx, yy, zz, uu, vv, ww = s.get_data() mag = np.sqrt(uu ** 2 + vv ** 2 + ww ** 2) uu0, vv0, zz0 = [t.copy() for t in [uu, vv, ww]] if s.normalize: uu, vv, ww = [t / mag for t in [uu, vv, ww]] if s.is_streamlines: vertices, color_val = compute_streamtubes( xx, yy, zz, uu, vv, ww, s.rendering_kw, s.color_func) lkw = dict() stream_kw = s.rendering_kw.copy() # remove rendering-unrelated keywords for k in ["starts", "max_prop", "npoints", "radius"]: if k in stream_kw.keys(): stream_kw.pop(k) if s.use_cm: segments = self.get_segments( vertices[:, 0], vertices[:, 1], vertices[:, 2]) lkw["cmap"] = next(self._cm) lkw["array"] = color_val kw = merge({}, lkw, stream_kw) c = Line3DCollection(segments, **kw) self._ax.add_collection(c) self._add_colorbar(c, s.get_label(self._use_latex), s.use_cm) self._add_handle(i, c) else: lkw["label"] = s.get_label(self._use_latex) kw = merge({}, lkw, stream_kw) l = self._ax.plot(vertices[:, 0], vertices[:, 1], vertices[:, 2], **kw) self._add_handle(i, l) xlims.append((np.amin(xx), np.amax(xx))) ylims.append((np.amin(yy), np.amax(yy))) zlims.append((np.amin(zz), np.amax(zz))) else: qkw = dict() if s.use_cm: # NOTE: each quiver is composed of 3 lines: the # stem and two segments for the head. I could set # the colors keyword argument in order to apply # the same color to the entire quiver, like this: # [c1, c2, ..., cn, c1, c1, c2, c2, ... cn, cn] # However, it doesn't appear to work reliably, so # I'll keep things simpler. color_val = mag if s.color_func is not None: color_val = s.eval_color_func( xx, yy, zz, uu0, vv0, zz0) qkw["cmap"] = next(self._cm) qkw["array"] = color_val.flatten() kw = merge({}, qkw, s.rendering_kw) q = self._ax.quiver(xx, yy, zz, uu, vv, ww, **kw) is_cb_added = self._add_colorbar( q, s.get_label(self._use_latex), s.use_cm) else: qkw["color"] = next(self._cl) kw = merge({}, qkw, s.rendering_kw) q = self._ax.quiver(xx, yy, zz, uu, vv, ww, **kw) is_cb_added = False self._add_handle(i, q, kw, is_cb_added, self._fig.axes[-1]) xlims.append((np.amin(xx), np.amax(xx))) ylims.append((np.amin(yy), np.amax(yy))) zlims.append((np.nanmin(zz), np.nanmax(zz))) elif s.is_complex: if not s.is_3Dsurface: x, y, _, _, img, colors = s.get_data() ikw = dict( extent=[np.amin(x), np.amax(x), np.amin(y), np.amax(y)], interpolation="nearest", origin="lower", ) kw = merge({}, ikw, s.rendering_kw) image = self._ax.imshow(img, **kw) self._add_handle(i, image, kw) # chroma/phase-colorbar if colors is not None: colors = colors / 255.0 colormap = self.ListedColormap(colors) norm = self.Normalize(vmin=-np.pi, vmax=np.pi) cb2 = self._fig.colorbar( self.cm.ScalarMappable(norm=norm, cmap=colormap), orientation="vertical", label="Argument", ticks=[-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi], ax=self._ax, ) cb2.ax.set_yticklabels( [r"-$\pi$", r"-$\pi / 2$", "0", r"$\pi / 2$", r"$\pi$"] ) else: x, y, mag, arg, facecolors, colorscale = s.get_data() skw = dict(rstride=1, cstride=1, linewidth=0.1) if s.use_cm: skw["facecolors"] = facecolors / 255 else: skw["color"] = next(self._cl) if s.surface_color is None else s.surface_color kw = merge({}, skw, s.rendering_kw) c = self._ax.plot_surface(x, y, mag, **kw) if s.use_cm and (colorscale is not None): if len(colorscale.shape) == 3: colorscale = colorscale.reshape((-1, 3)) else: colorscale = colorscale / 255.0 # this colorbar is essential to understand the plot. # Always show it, except when use_cm=False norm = self.Normalize(vmin=-np.pi, vmax=np.pi) mappable = self.cm.ScalarMappable( cmap=self.ListedColormap(colorscale), norm=norm ) cb = self._fig.colorbar( mappable, orientation="vertical", label="Argument", ticks=[-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi], ax=self._ax, ) cb.ax.set_yticklabels( [r"-$\pi$", r"-$\pi / 2$", "0", r"$\pi / 2$", r"$\pi$"] ) self._add_handle(i, c, kw) xlims.append((np.amin(x), np.amax(x))) ylims.append((np.amin(y), np.amax(y))) zlims.append((np.amin(mag), np.amax(mag))) elif s.is_geometry: x, y = s.get_data() color = next(self._cl) fkw = dict(facecolor=color, fill=s.is_filled, edgecolor=color) kw = merge({}, fkw, s.rendering_kw) c = self._ax.fill(x, y, **kw) self._add_handle(i, c, kw) elif s.is_generic: if s.type == "markers": kw = merge({}, {"color": next(self._cl)}, s.rendering_kw) self._ax.plot(*s.args, **kw) elif s.type == "annotations": self._ax.annotate(*s.args, **s.rendering_kw) elif s.type == "fill": kw = merge({}, {"color": next(self._cl)}, s.rendering_kw) self._ax.fill_between(*s.args, **kw) elif s.type == "rectangles": kw = merge({}, {"color": next(self._cl)}, s.rendering_kw) self._ax.add_patch( self.matplotlib.patches.Rectangle(*s.args, **kw)) else: raise NotImplementedError( "{} is not supported by {}\n".format(type(s), type(self).__name__) ) Axes3D = mpl_toolkits.mplot3d.Axes3D # Set global options. # TODO The 3D stuff # XXX The order of those is important. if self.xscale and not isinstance(self._ax, Axes3D): self._ax.set_xscale(self.xscale) if self.yscale and not isinstance(self._ax, Axes3D): self._ax.set_yscale(self.yscale) if self.axis_center: val = self.axis_center if isinstance(self._ax, Axes3D): pass elif val == "center": self._ax.spines["left"].set_position("center") self._ax.spines["bottom"].set_position("center") self._ax.yaxis.set_ticks_position("left") self._ax.xaxis.set_ticks_position("bottom") self._ax.spines["right"].set_visible(False) self._ax.spines["top"].set_visible(False) elif val == "auto": xl, xh = self._ax.get_xlim() yl, yh = self._ax.get_ylim() pos_left = ("data", 0) if xl * xh <= 0 else "center" pos_bottom = ("data", 0) if yl * yh <= 0 else "center" self._ax.spines["left"].set_position(pos_left) self._ax.spines["bottom"].set_position(pos_bottom) self._ax.yaxis.set_ticks_position("left") self._ax.xaxis.set_ticks_position("bottom") self._ax.spines["right"].set_visible(False) self._ax.spines["top"].set_visible(False) else: self._ax.spines["left"].set_position(("data", val[0])) self._ax.spines["bottom"].set_position(("data", val[1])) self._ax.yaxis.set_ticks_position("left") self._ax.xaxis.set_ticks_position("bottom") self._ax.spines["right"].set_visible(False) self._ax.spines["top"].set_visible(False) if self.grid: if isinstance(self._ax, Axes3D): self._ax.grid() else: self._ax.grid(visible=True, which='major', linestyle='-', linewidth=0.75, color='0.75') self._ax.grid(visible=True, which='minor', linestyle='--', linewidth=0.6, color='0.825') if self._show_minor_grid: self._ax.minorticks_on() if self.legend: if len(self._legend_handles) > 0: self._ax.legend(handles=self._legend_handles, loc="best") else: handles, _ = self._ax.get_legend_handles_labels() # Show the legend only if there are legend entries. # For example, if we are plotting only parametric expressions, # there will be only colorbars, no legend entries. if len(handles) > 0: self._ax.legend(loc="best") if self.title: self._ax.set_title(self.title) if self.xlabel: self._ax.set_xlabel( self.xlabel, position=(1, 0) if self.axis_center else (0.5, 0) ) if self.ylabel: self._ax.set_ylabel( self.ylabel, position=(0, 1) if self.axis_center else (0, 0.5) ) if isinstance(self._ax, Axes3D): if self.zlabel: self._ax.set_zlabel(self.zlabel, position=(0, 1)) if self.camera is not None: self._ax.view_init(**self.camera) self._set_lims(xlims, ylims, zlims) self._set_aspect() def _set_aspect(self): aspect = self.aspect current_version = version.parse(self.matplotlib.__version__) v_3_6_0 = version.parse("3.6.0") if isinstance(aspect, str): if (aspect == "equal") and (current_version < v_3_6_0): if any(s.is_3D for s in self.series): # plot_vector uses aspect="equal" by default. Older # matplotlib versions do not support equal 3D axis. aspect = "auto" elif hasattr(aspect, "__iter__"): aspect = float(aspect[1]) / aspect[0] else: aspect = "auto" self._ax.set_aspect(aspect) def _get_plotting_func_name(self, t): if t == "markers": return "scatter" elif t == "annotations": return "annotate" elif t == "fills": return "fill" elif t == "rectangles": return "rect" raise ValueError("%s is not supported by MatplotlibBackend" % t) def _set_lims(self, xlims, ylims, zlims): np = import_module('numpy') mpl_toolkits = import_module( 'mpl_toolkits', # noqa import_kwargs={'fromlist': ['mplot3d']}, catch=(RuntimeError,)) Axes3D = mpl_toolkits.mplot3d.Axes3D if not isinstance(self._ax, Axes3D): self._ax.autoscale_view( scalex=self._ax.get_autoscalex_on(), scaley=self._ax.get_autoscaley_on() ) # HACK: in order to make interactive contour plots to scale to # the appropriate range if xlims and ( any(s.is_contour for s in self.series) or any(s.is_vector and (not s.is_3D) for s in self.series) or any(s.is_2Dline and s.is_parametric for s in self.series) ): xlims = np.array(xlims) xlim = (np.nanmin(xlims[:, 0]), np.nanmax(xlims[:, 1])) self._ax.set_xlim(xlim) if ylims and ( any(s.is_contour for s in self.series) or any(s.is_2Dline and s.is_parametric for s in self.series) ): ylims = np.array(ylims) ylim = (np.nanmin(ylims[:, 0]), np.nanmax(ylims[:, 1])) self._ax.set_ylim(ylim) else: # XXX Workaround for matplotlib issue # https://github.com/matplotlib/matplotlib/issues/17130 if xlims: xlims = np.array(xlims) xlim = (np.nanmin(xlims[:, 0]), np.nanmax(xlims[:, 1])) self._ax.set_xlim(xlim) else: self._ax.set_xlim([0, 1]) if ylims: ylims = np.array(ylims) ylim = (np.nanmin(ylims[:, 0]), np.nanmax(ylims[:, 1])) self._ax.set_ylim(ylim) else: self._ax.set_ylim([0, 1]) if zlims: zlims = np.array(zlims) zlim = [np.nanmin(zlims[:, 0]), np.nanmax(zlims[:, 1])] if np.isnan(zlim[0]): zlim[0] = -10 if not np.isnan(zlim[1]): zlim[0] = zlim[1] - 10 if np.isnan(zlim[1]): zlim[1] = 10 zlim = (-10 if np.isnan(z) else z for z in zlim) self._ax.set_zlim(zlim) else: self._ax.set_zlim([0, 1]) # xlim and ylim should always be set at last so that plot limits # doesn't get altered during the process. if self.xlim: self._ax.set_xlim(self.xlim) if self.ylim: self._ax.set_ylim(self.ylim) if self.zlim: self._ax.set_zlim(self.zlim) def _update_colorbar(self, cax, cmap, label, param=None, norm=None): """This method reduces code repetition. The name is misleading: here we create a new colorbar which will be placed on the same colorbar axis as the original. """ np = import_module('numpy') cax.clear() if norm is None: norm = self.Normalize(vmin=np.amin(param), vmax=np.amax(param)) mappable = self.cm.ScalarMappable(cmap=cmap, norm=norm) self._fig.colorbar(mappable, orientation="vertical", label=label, cax=cax) def update_interactive(self, params): """Implement the logic to update the data generated by interactive-widget plots. Parameters ========== params : dict Map parameter-symbols to numeric values. """ np = import_module('numpy') mpl_toolkits = import_module( 'mpl_toolkits', # noqa import_kwargs={'fromlist': ['mplot3d']}, catch=(RuntimeError,)) Line3DCollection = mpl_toolkits.mplot3d.art3d.Line3DCollection Path3DCollection = mpl_toolkits.mplot3d.art3d.Path3DCollection # iplot doesn't call the show method. The following line of # code will add the numerical data (if not already present). if len(self._handles) == 0: self.process_series() xlims, ylims, zlims = [], [], [] for i, s in enumerate(self.series): if s.is_interactive: self.series[i].params = params if s.is_2Dline: if s.is_parametric and s.use_cm: x, y, param = self.series[i].get_data() kw, is_cb_added, cax = self._handles[i][1:] if not s.is_point: segments = self.get_segments(x, y) self._handles[i][0].set_segments(segments) else: self._handles[i][0].set_offsets(np.c_[x,y]) self._handles[i][0].set_array(param) self._handles[i][0].set_clim( vmin=min(param), vmax=max(param)) if is_cb_added: norm = self.Normalize(vmin=np.amin(param), vmax=np.amax(param)) self._update_colorbar(cax, kw["cmap"], s.get_label(self._use_latex), norm=norm) xlims.append((np.amin(x), np.amax(x))) ylims.append((np.amin(y), np.amax(y))) else: if s.is_parametric: x, y, param = self.series[i].get_data() else: x, y = self.series[i].get_data() # TODO: Point2D are updated but not visible. self._handles[i][0].set_data(x, y) elif s.is_3Dline: if s.is_parametric: x, y, z, param = self.series[i].get_data() else: x, y, z = self.series[i].get_data() if isinstance(self._handles[i][0], Line3DCollection): # gradient lines segments = self.get_segments(x, y, z) self._handles[i][0].set_segments(segments) elif isinstance(self._handles[i][0], Path3DCollection): # 3D points self._handles[i][0]._offsets3d = (x, y, z) else: if hasattr(self._handles[i][0], "set_data_3d"): # solid lines self._handles[i][0].set_data_3d(x, y, z) else: # scatter self._handles[i][0].set_offset(np.c_[x, y, z]) if s.is_parametric and s.use_cm: self._handles[i][0].set_array(param) kw, cax = self._handles[i][1:] self._update_colorbar(cax, kw["cmap"], s.get_label(self._use_latex), param=param) xlims.append((np.amin(x), np.amax(x))) ylims.append((np.amin(y), np.amax(y))) zlims.append((np.amin(z), np.amax(z))) elif s.is_contour and (not s.is_complex): x, y, z = self.series[i].get_data() kw, cax, clabels = self._handles[i][1:] for c in self._handles[i][0].collections: c.remove() if (not s.is_filled) and s.show_clabels: for cl in clabels: cl.remove() func = self._ax.contourf if s.is_filled else self._ax.contour self._handles[i][0] = func(x, y, z, **kw) if s.is_filled: self._update_colorbar(cax, kw["cmap"], s.get_label(self._use_latex), param=z) else: if s.show_clabels: clabels = self._ax.clabel(self._handles[i][0]) self._handles[i][-1] = clabels xlims.append((np.amin(x), np.amax(x))) ylims.append((np.amin(y), np.amax(y))) elif s.is_3Dsurface and (not s.is_domain_coloring) and (not s.is_implicit): if not s.is_parametric: x, y, z = self.series[i].get_data() facecolors = s.eval_color_func(x, y, z) else: x, y, z, u, v = self.series[i].get_data() facecolors = s.eval_color_func(x, y, z, u, v) # TODO: by setting the keyword arguments, somehow the # update becomes really really slow. kw, is_cb_added, cax = self._handles[i][1:] if is_cb_added: # TODO: if use_cm=True and a single 3D expression is # shown with legend=False, this won't get executed. # In widget plots, the surface will never change color. vmin = s.rendering_kw.get("vmin", np.amin(facecolors)) vmax = s.rendering_kw.get("vmax", np.amax(facecolors)) norm = self.Normalize(vmin=vmin, vmax=vmax) cmap = kw["cmap"] if isinstance(cmap, str): cmap = self.cm.get_cmap(cmap) kw["facecolors"] = cmap(norm(facecolors)) self._handles[i][0].remove() self._handles[i][0] = self._ax.plot_surface( x, y, z, **kw) if is_cb_added: self._update_colorbar(cax, kw["cmap"], s.get_label(self._use_latex), norm=norm) xlims.append((np.amin(x), np.amax(x))) ylims.append((np.amin(y), np.amax(y))) zlims.append((np.amin(z), np.amax(z))) elif s.is_implicit and not s.is_3Dsurface: points = s.get_data() if len(points) == 2: raise NotImplementedError else: for c in self._handles[i][0].collections: c.remove() xx, yy, zz, plot_type = points kw = self._handles[i][1] if plot_type == "contour": self._handles[i][0] = self._ax.contour( xx, yy, zz, [0.0], **kw ) else: self._handles[i][0] = self._ax.contourf(xx, yy, zz, **kw) xlims.append((np.amin(xx), np.amax(xx))) ylims.append((np.amin(yy), np.amax(yy))) elif s.is_vector and s.is_3D: if s.is_streamlines: raise NotImplementedError xx, yy, zz, uu, vv, ww = self.series[i].get_data() kw, is_cb_added, cax = self._handles[i][1:] mag = np.sqrt(uu ** 2 + vv ** 2 + ww ** 2) uu0, vv0, ww0 = [t.copy() for t in [uu, vv, ww]] if s.normalize: uu, vv, ww = [t / mag for t in [uu, vv, ww]] self._handles[i][0].remove() if "array" in kw.keys(): color_val = mag if s.color_func is not None: color_val = s.eval_color_func(xx, yy, zz, uu0, vv0, ww0) kw["array"] = color_val.flatten() self._handles[i][0] = self._ax.quiver(xx, yy, zz, uu, vv, ww, **kw) if is_cb_added: self._update_colorbar(cax, kw["cmap"], s.get_label(self._use_latex), param=mag) xlims.append((np.amin(xx), np.amax(xx))) ylims.append((np.amin(yy), np.amax(yy))) zlims.append((np.nanmin(zz), np.nanmax(zz))) elif s.is_vector: xx, yy, uu, vv = self.series[i].get_data() mag = np.sqrt(uu ** 2 + vv ** 2) uu0, vv0 = [t.copy() for t in [uu, vv]] if s.normalize: uu, vv = [t / mag for t in [uu, vv]] if s.is_streamlines: raise NotImplementedError # Streamlines are composed by lines and arrows. # Arrows belongs to a PatchCollection. Currently, # there is no way to remove a PatchCollection.... kw = self._handles[i][1] self._handles[i][0].lines.remove() self._handles[i][0].arrows.remove() self._handles[i][0] = self._ax.streamplot(xx, yy, uu, vv, **kw) else: kw, is_cb_added, cax = self._handles[i][1:] color_val = mag if s.color_func is not None: color_val = s.eval_color_func(xx, yy, uu0, vv0) if is_cb_added: self._handles[i][0].set_UVC(uu, vv, color_val) self._update_colorbar(cax, kw["cmap"], s.get_label(self._use_latex), color_val) else: self._handles[i][0].set_UVC(uu, vv) self._handles[i][0].set_offsets(np.c_[xx.flatten(), yy.flatten()]) xlims.append((np.amin(xx), np.amax(xx))) ylims.append((np.amin(yy), np.amax(yy))) elif s.is_complex: if not s.is_3Dsurface: x, y, _, _, img, colors = s.get_data() self._handles[i][0].set_data(img) self._handles[i][0].set_extent((x.min(), x.max(), y.min(), y.max())) else: x, y, mag, arg, facecolors, colorscale = s.get_data() self._handles[i][0].remove() kw = self._handles[i][1] if s.use_cm: kw["facecolors"] = facecolors / 255 self._handles[i][0] = self._ax.plot_surface(x, y, mag, **kw) xlims.append((np.amin(x), np.amax(x))) ylims.append((np.amin(y), np.amax(y))) zlims.append((np.amin(mag), np.amax(mag))) elif s.is_geometry and not (s.is_2Dline): # TODO: fill doesn't update x, y = self.series[i].get_data() self._handles[i][0].remove() self._handles[i][0] = self._ax.fill(x, y, **self._handles[i][1]) # Update the plot limits according to the new data Axes3D = mpl_toolkits.mplot3d.Axes3D if not isinstance(self._ax, Axes3D): # https://stackoverflow.com/questions/10984085/automatically-rescale-ylim-and-xlim-in-matplotlib # recompute the ax.dataLim self._ax.relim() # update ax.viewLim using the new dataLim self._ax.autoscale_view() else: pass self._set_lims(xlims, ylims, zlims) def process_series(self): """ Loop over data series, generates numerical data and add it to the figure. """ # create the figure from scratch every time, otherwise if the plot was # previously shown, it would not be possible to show it again. This # behaviour is specific to Matplotlib self._create_figure() self._process_series(self.series) def show(self, **kwargs): """Display the current plot. Parameters ========== **kwargs : dict Keyword arguments to be passed to plt.show(). """ self.process_series() if _show: self._fig.tight_layout() self.plt.show(**kwargs) else: self.close() def save(self, path, **kwargs): """Save the current plot at the specified location. Refer to [#fn10]_ to visualize all the available keyword arguments. References ========== .. [#fn10] https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html """ if self._fig is None: self.process_series() self._fig.savefig(path, **kwargs) def close(self): """Close the current plot.""" self.plt.close(self._fig)
MB = MatplotlibBackend