Source code for spb.series.base

import param
from spb.defaults import cfg
from spb.utils import _get_free_symbols, _check_misspelled_kwargs
from sympy import Tuple, sympify, Expr, Symbol
from param.parameterized import Undefined
import typing


def _check_misspelled_series_kwargs(series, **kwargs):
    """Run ``_check_misspelled_kwargs`` on the keyword arguments of a series.

    Parameters
    ==========

    series :
        Forwarded as first argument to ``_check_misspelled_kwargs``.
    **kwargs :
        Keyword arguments to be checked. The special key ``plot_function``
        (default ``False``) is always removed before the check. When it is
        truthy, the names of the ``Plot`` parameters, together with a fixed
        list of plot-level keywords, are removed too.
    """
    called_from_plot_function = kwargs.pop("plot_function", False)

    if called_from_plot_function:
        # imported here rather than at module level
        from spb.backends.base_backend import Plot

        ignored = set(Plot.param)
        ignored.update((
            "show", "backend", "imodule", "threed", "process_piecewise",
            "animation", "servable", "template", "ncols", "layout",
            "markers", "rectangles", "annotations"
        ))
        kwargs = {
            key: value for key, value in kwargs.items()
            if key not in ignored
        }

    _check_misspelled_kwargs(series, exclude_keys=["n"], **kwargs)


def _get_wrapper_for_expr(ret):
    wrapper = "%s"
    if ret == "real":
        wrapper = "re(%s)"
    elif ret == "imag":
        wrapper = "im(%s)"
    elif ret == "abs":
        wrapper = "abs(%s)"
    elif ret == "arg":
        wrapper = "arg(%s)"
    return wrapper


def _raise_color_func_error(series, nargs):
    """Raise an informative error about a badly-called ``color_func``.

    Nothing happens (``None`` is returned) when ``series`` is not a
    ``BaseSeries`` instance or when its ``color_func`` is None.

    Parameters
    ==========

    series :
        The object whose ``color_func`` was being processed.
    nargs :
        The number of arguments, reported in the error message.

    Raises
    ======

    ValueError
        With a message that includes the documentation of the
        ``color_func`` parameter of ``series``.
    """
    if not isinstance(series, BaseSeries) or series.color_func is None:
        return

    series_type = type(series).__name__
    message = (
        f"Error while processing the `color_func` of {series_type}:"
        f" wrong number of arguments ({nargs}).\n"
        " Here is the documentation of the `color_func` attribute:\n\n"
        f"{series.param.color_func.doc}"
    )
    raise ValueError(message)


class _ParametersDict(param.Dict):
    """Dictionary parameter that preprocesses its value before it is set.

    As of ``param`` 2.2.0, there is no mechanism to preprocess the value
    of an attribute just before it is set. This class allows to do just that.

    https://discourse.holoviz.org/t/what-is-the-best-way-to-make-custom-validation-and-transform-before-validation/3369
    """

    def __set__(self, obj, val):
        """Flatten multi-symbol keys, then set the dictionary.

        NOTE: when the data series is being instantiated, ``val`` is a
        dictionary with the following form:
            {
                symb1: (1, 0, 10, "label"),
                symb2: FloatSlider(value=2, min=0, max=5),
                (symb3, symb4): RangeSlider(...),
            }
        When the data series is being updated with new data coming from
        the widgets, ``val`` has this form instead:
            {
                symb1: val1,
                symb2: val2,
                (symb3, symb4): (val3, val4),
            }

        Keys like ``(symb3, symb4)`` are unpacked, so that the stored
        dictionary is keyed by symbols only. After the preprocessing,
        the first form becomes:
            {
                symb1: (1, 0, 10, "label"),
                symb2: FloatSlider(value=2, min=0, max=5),
                symb3: RangeSlider(...),
                symb4: RangeSlider(...),
            }
        and the second form (numeric values) becomes:
            {
                symb1: val1,
                symb2: val2,
                symb3: val3,
                symb4: val4,
            }
        """
        # NOTE: Data series are unable to extract numerical values
        # from widgets. This step is done by iplot(). Before executing
        # the get_data() method, be sure to provide a ``params``
        # dictionary mapping symbols to numeric values.

        def is_sequence(item):
            return isinstance(item, (list, tuple))

        if any(map(is_sequence, val.keys())):
            flattened = {}
            for key, value in val.items():
                if not is_sequence(key):
                    # ordinary entry: one symbol, one value
                    flattened[key] = value
                elif is_sequence(value):
                    # multivalued widget, update with new numerical data:
                    # pair each symbol with its own number
                    flattened.update(zip(key, value))
                else:
                    # multivalued widget, series instantiation: every
                    # symbol refers to the same widget
                    flattened.update(dict.fromkeys(key, value))
            val = flattened

        super().__set__(obj, val)


class _CastToInteger(param.Integer):
    """
    Integer parameter that converts the incoming value with ``int()``.

    ``n1, n2, n3`` (number of discretization points) should be integer
    for np.linspace to work properly, but can receive float numbers.
    For example, 1e04.
    """

    def __set__(self, obj, val):
        # convert first, so that the parent class validates an integer
        as_integer = int(val)
        super().__set__(obj, as_integer)


class _CastToFloat(param.Number):
    """
    Numeric parameter that converts the incoming value with ``float()``.

    This assures that instances of NumberSymbol are cast to float.
    """

    def __set__(self, obj, val):
        # convert first, so that the parent class validates a float
        as_float = float(val)
        super().__set__(obj, as_float)


class _RangeTuple(param.ClassSelector):
    """
    Represent a range for some variable. It must be a 3-elements tuple,
    `(symbol, min_val, max_val)`.
    """

    # Typing-only signature: it advertises the accepted keyword arguments,
    # while the implementation that follows forwards them with ``**params``.
    @typing.overload
    def __init__(
        self,
        default=None,
        *,
        is_instance=True,
        allow_None=False,
        doc=None,
        label=None,
        precedence=None,
        instantiate=True,
        constant=False,
        readonly=False,
        pickle_default_value=True,
        per_instance=True,
        allow_refs=False,
        nested_refs=False
    ):
        ...

    def __init__(self, default=Undefined, **params):
        # the accepted classes are fixed: python tuples and sympy Tuples
        accepted_classes = (tuple, Tuple)
        super().__init__(default=default, class_=accepted_classes, **params)

    def _validate(self, val):
        """Run the parent validation, then verify that a non-None value
        has exactly three elements, raising ValueError otherwise."""
        super()._validate(val)
        if val is None:
            return
        if len(val) != 3:
            raise ValueError(
                f"Parameter '{self.name}' must be a 3-elements tuple."
                f" Instead, {len(val)} elements were provided."
            )

    def __set__(self, obj, val):
        """Sympify the range before setting it. If the first element is
        a string, it is converted to a Symbol with that name."""
        if val is not None:
            first = val[0]
            if isinstance(first, str):
                val = (Symbol(first),) + tuple(val[1:])
        super().__set__(obj, sympify(val))


class BaseSeries(param.Parameterized):
    """Base class for the data objects containing stuff to be plotted.

    Notes
    =====

    The backend should check if it supports the data series that it's given.
    It's the backend responsibility to know how to use the data series that
    it's given.
    """

    # Some flags follow. The rationale for using flags instead of checking
    # base classes is that setting multiple flags is simpler than multiple
    # inheritance.
    is_2Dline = False
    is_3Dline = False
    is_3Dsurface = False
    is_contour = False
    # Both contour and implicit series use a colormap, but they are
    # different. Hence, a dedicated attribute.
    is_implicit = False
    is_parametric = False
    # Flags for series representing a 2D or 3D vector.
    is_vector = False
    is_2Dvector = False
    is_3Dvector = False
    is_slice = False
    # Represents a complex expression.
    is_complex = False
    is_domain_coloring = False
    # If True, it represents an object of the sympy.geometry module.
    is_geometry = False
    # Implements back-compatibility with sympy.plotting <= 1.11.
    # Please, read the NOTE section on GenericDataSeries.
    is_generic = False
    # Represents grids like s-grid, z-grid, n-grid, ...
    is_grid = False

    #####################
    # Instance Attributes
    #####################

    # NOTE: some data series should not be shown on the legend, for example
    # wireframe lines on 3D plots.
    show_in_legend = param.Boolean(True, doc="""
        Toggle the visibility of the data series on the legend.""")
    colorbar = param.Boolean(True, doc="""
        Toggle the visibility of the colorbar associated to the current data
        series. Note that a colorbar is only visible if ``use_cm=True`` and
        ``color_func`` is not None.""")
    use_cm = param.Boolean(False, doc="""
        Toggle the use of a colormap. By default, some series might use a
        colormap to display the necessary data. Setting this attribute
        to False will inform the associated renderer to use solid color.
        Related parameters: ``color_func``.""")
    # TODO: can I remove _label_str and only keep label?
    # NOTE: By default the data series stores two labels: one for the
    # string representation of the symbolic expression, the other for the
    # latex representation. The plotting library will then decide which one
    # is best to be shown. If the user sets this parameter, both labels will
    # receive the same value. To retrieve one or the other representation,
    # call the ``get_label`` method of the data series.
    label = param.String("", doc="""
        Set the label associated to this series, which will be eventually
        shown on the legend or colorbar.""")
    rendering_kw = param.Dict(doc="""
        Keyword arguments to be passed to the renderers of the selected
        plotting library in order to further customize the appearance of this
        data series.""")
    # TODO: can the code be modified so that series.params ALWAYS returns
    # a dictionary mapping symbols to numerical values? This requires
    # the extraction of values during series instantiation.
    params = _ParametersDict({}, doc="""
        A dictionary mapping parameters (symbols not being used in the
        ranges) to numerical values.""")
    _label_str = param.String("", doc="""Contains str representation.""")
    _label_latex = param.String("", doc="""Contains latex representation.""")
    _is_interactive = param.Boolean(False, constant=True, doc="""
        Verify if this data series is interactive or not. Each data series
        expect one (or more) symbols to be specified as a discretization
        variable (ie, the ranges of the data series). However, the symbolic
        expressions may contain more symbols than what is expected by
        the ranges. In that case, the additional symbols are considered
        parameters, which will receive numerical values from interactive
        widgets. If this parameter is True, then the ``params`` attributes
        contains a non-empty dictionary.""")
    # TODO: I probably don't need this if I can better implement ``params``
    # in the first place. See TODO on ``params``.
    _original_params = param.Dict({}, doc="""
        This stores a copy of the ``params`` dictionary, just as it was
        provided by the user during a plotting function call. It is used
        by spb.interactive to keep track of multi-values widgets,
        which allows the mapping of symbols to the appropriate numerical
        values.""")
    _parametric_ranges = param.Boolean(False, doc="""
        Whether the series contains any parametric range, which is a range
        depending on symbols contained in ``params.keys()``.""")
    _range_names = param.List(default=[], item_type=str, doc="""
        List of parameter names refering to ranges. This parameter allows
        to quickly retrieve all ranges associated to a particular
        data series.""")

    def __repr__(self):
        """Return the plain ``object`` representation when the
        ``"use_repr"`` configuration entry is False, otherwise the
        representation built by the parent class."""
        if cfg["use_repr"] is not False:
            return super().__repr__()
        return object.__repr__(self)

    def _enforce_dict_on_rendering_kw(self, rendering_kw):
        """Return ``rendering_kw``, replacing None with an empty dict."""
        if rendering_kw is None:
            return {}
        return rendering_kw

    @param.depends("_label_str", watch=True)
    def _update_label(self):
        # NOTE: this implements back-compatibility with sympy.plotting
        self.label = self._label_str

    @param.depends("label", watch=True)
    def _update_latex_and_str_labels(self):
        # this is triggered when someone changes the label after instantiating
        # the plot, like p[0].label = "something"
        new_label = self.label
        self._label_latex = new_label
        self._label_str = new_label

    def __init__(self, *args, **kwargs):
        # allow the user to specify the number of discretization points
        # using different keyword arguments
        kwargs = _set_discretization_points(kwargs.copy(), type(self))

        # user (or plotting function) may still provide None to rendering_kw.
        # here we prevent this event from raising errors.
        # This helps to maintain back-compatibility with the graphics module.
        if kwargs.get("rendering_kw", None) is None:
            kwargs["rendering_kw"] = {}

        # if user provides a label, overrides both the string and latex
        # representations
        label = kwargs.get("label", None)
        if label:
            kwargs["_label_str"] = label
            kwargs["_label_latex"] = label

        _params = kwargs.setdefault("params", {})
        # this is used by spb.interactive to keep track of multi-values widgets
        kwargs.setdefault("_original_params", _params)

        # remove keyword arguments that are not parameters of this series
        allowed = self._get_list_of_allowed_params()
        kwargs = {
            name: value for name, value in kwargs.items() if name in allowed
        }

        super().__init__(*args, **kwargs)

        if len(_params) > 0:
            with param.edit_constant(self):
                self._is_interactive = True

            # NOTE(review): the nesting of the following detection of
            # parametric ranges could not be recovered from the collapsed
            # listing. It is kept inside the "interactive" branch because a
            # parametric range depends on symbols of ``params`` - confirm
            # against the upstream file.
            bounds = set().union(*[rng[1:] for rng in self.ranges])
            bound_symbols = set().union(*[b.free_symbols for b in bounds])
            if len(bound_symbols) > 0:
                self._parametric_ranges = True

    def _post_init(self):
        """Validate the expressions against ``params``, adjust the label
        of numerical functions and run ``_check_fs``."""
        exprs = self.expr if hasattr(self.expr, "__iter__") else [self.expr]
        has_numerical_function = any(callable(e) for e in exprs)

        if has_numerical_function and self.params:
            raise TypeError(
                "`params` was provided, hence an interactive plot "
                "is expected. However, interactive plots do not support "
                "user-provided numerical functions.")

        # if the expressions is a lambda function and no label has been
        # provided, then its better to do the following in order to avoid
        # surprises on the backend
        if has_numerical_function and (self._label_str == str(self.expr)):
            self.label = ""

        self._check_fs()

    @classmethod
    def _get_list_of_allowed_params(cls):
        """Return the names of the parameters of this class, plus the
        ``nb_of_points*`` keyword arguments."""
        # NOTE(review): the original comment states that n1, n2, n3 are
        # also allowed and removed later on by _set_discretization_points.
        extra_keywords = [
            "nb_of_points", "nb_of_points_x", "nb_of_points_y",
            "nb_of_points_u", "nb_of_points_v"
        ]
        return list(cls.param) + extra_keywords

    def _block_lambda_functions(self, *exprs):
        """Raise TypeError if any of ``exprs`` is callable."""
        if any(callable(e) for e in exprs):
            series_type = type(self).__name__
            raise TypeError(
                series_type + " requires a symbolic expression.")

    def _check_fs(self):
        """ Checks if there are enough parameters and free symbols.
        """
        exprs, ranges = self.expr, self.ranges
        params, label = self.params, self.label
        if not hasattr(exprs, "__iter__"):
            exprs = [exprs]
        if any(callable(e) for e in exprs):
            # nothing to check on numerical functions
            return

        # from the expression's free symbols, remove the ones used in
        # the parameters and the ranges
        fs = _get_free_symbols(exprs)
        has_color_func = hasattr(self, "color_func")
        if has_color_func:
            fs = fs.union(_get_free_symbols(self.color_func))
        fs = fs.difference(params.keys())
        if ranges is not None:
            fs = fs.difference([r[0] for r in ranges])

        if len(fs) > 0:
            if (ranges is None) or (len(ranges) == 0):
                erl = "Expressions: %s\nLabel: %s\n" % (exprs, label)
            else:
                erl = f"Expressions: {exprs}\n"
                if has_color_func and isinstance(self.color_func, Expr):
                    erl += f"color_func: {self.color_func}\n"
                erl += f"Ranges: {ranges}\nLabel: {label}\n"
            raise ValueError(
                "Incompatible expression and parameters.\n%s"
                "params: %s\n"
                "Specify what these symbols represent: %s\n"
                "Are they ranges or parameters?" % (erl, params, fs)
            )

        # verify that all symbols are known (they either represent plotting
        # ranges or parameters)
        range_symbols = [r[0] for r in ranges]
        for r in ranges:
            bound_symbols = set().union(*[e.free_symbols for e in r[1:]])
            if any(t in bound_symbols for t in range_symbols):
                raise ValueError(
                    "Range symbols can't be included into minimum and maximum"
                    " of a range. Received range: %s" % str(r))
            remaining_fs = bound_symbols.difference(params.keys())
            if len(remaining_fs) > 0:
                raise ValueError(
                    "Unkown symbols found in plotting range: %s. " % (r,) +
                    "Are the following parameters? %s" % remaining_fs)

    @property
    def is_3D(self):
        """True if any of the 3D flags (line, surface, vector) is set."""
        return any((self.is_3Dline, self.is_3Dsurface, self.is_3Dvector))

    @property
    def is_line(self):
        """True if the series is flagged as a 2D or 3D line."""
        return any((self.is_2Dline, self.is_3Dline))

    @property
    def is_interactive(self):
        """Read-only access to the ``_is_interactive`` parameter."""
        return self._is_interactive

    def _line_surface_color(self, prop, val):
        """This method enables back-compatibility with old sympy.plotting"""
        # NOTE: color_func is set inside the init method of the series.
        # If line_color/surface_color is not a callable, then color_func will
        # be set to None.
        public_name = prop[1:]  # remove underscore
        is_function = callable(val) or isinstance(val, Expr)
        prop_val = None if is_function else val
        cf_val = val if is_function else None

        # prevents the triggering of events, which would cause recursion error
        # NOTE(review): the update of ``color_func`` is assumed to live
        # inside this context manager too; the collapsed listing does not
        # show the nesting - confirm against the upstream file.
        with param.discard_events(self):
            setattr(self, public_name, prop_val)
            if val is not None:
                # avoid resetting color_func when user writes line_color=None
                self.color_func = cf_val

    @property
    def scales(self):
        """Return ``[xscale, yscale, zscale]``, using ``"linear"`` for
        every scale attribute that the series doesn't have."""
        return [
            getattr(self, name, "linear")
            for name in ("xscale", "yscale", "zscale")
        ]

    def eval_color_func(self, *args):
        """Evaluate the color function. Depending on the data series,
        either the data series itself or the backend will eventually execute
        this function to generate the appropriate coloring value.

        Parameters
        ----------
        args : tuple
            Arguments to be passed to the coloring function. Can be numerical
            coordinates or parameters or both. Read the documentation of each
            data series `color_func` attribute to find out what the arguments
            should be.

        Returns
        -------
        color : np.ndarray or float
            Results of the numerical evaluation of the ``color_func``
            attribute.
        """
        evaluator = getattr(self, "evaluator", None)
        if evaluator is not None:
            color = evaluator.eval_color_func(*args)
            if color is not None:
                return color

        # fall back to the helper of the series, if there is one
        if hasattr(self, "_eval_color_func_helper"):
            return self._eval_color_func_helper(*args)
        raise NotImplementedError

    def get_data(self):
        """Compute and returns the numerical data.

        The number of arrays returned by this method depends on the
        specific instance. Let ``s`` be an instance of ``BaseSeries``.
        Make sure to read ``help(s.get_data)`` to understand what it returns.
        """
        raise NotImplementedError

    def _get_wrapped_label(self, label, wrapper):
        """Given a latex representation of an expression, wrap it inside
        some characters. For example, Matplotlib needs "$%s$", whereas
        K3D-Jupyter needs "%s".
        """
        return wrapper % label

    def get_label(self, use_latex=False, wrapper="$%s$"):
        """Return the label to be used to display the expression.

        Parameters
        ==========

        use_latex : bool
            If False, the string representation of the expression is returned.
            If True, the latex representation is returned.
        wrapper : str
            The backend might need the latex representation to be wrapped by
            some characters. Default to ``"$%s$"``.

        Returns
        =======

        label : str
        """
        if use_latex is False:
            return self._label_str

        latex_label = self._label_latex
        # the wrapper is only applied when the string label is still the
        # string representation of the expression
        if self._label_str == str(self.expr):
            return self._get_wrapped_label(latex_label, wrapper)
        return latex_label

    @property
    def ranges(self):
        """Return a list of up to three 3-elements tuples, each one having
        the form (symbol, min, max), representing the ranges of numerical
        values used by each of the specified symbols.
        """
        return [getattr(self, name) for name in self._range_names]

    @ranges.setter
    def ranges(self, values):
        # values are matched by position with the names in ``_range_names``
        for name, value in zip(self._range_names, values):
            setattr(self, name, value)

    def _apply_transform(self, *args):
        """Apply transformations to the results of numerical evaluation.

        Parameters
        ==========

        args : tuple
            Results of the numerical evaluation.

        Returns
        =======

        transformed_args : tuple
            Tuple containing the transformed results.
        """
        raise NotImplementedError

    def _get_transform_helper(self):
        """Return a function ``t(x, transform)`` giving ``x`` unchanged
        when ``transform`` is None, ``transform(x)`` otherwise."""
        def t(x, transform):
            if transform is None:
                return x
            return transform(x)
        return t

    def _str_helper(self, s):
        """Decorate the description ``s`` of an interactive series with
        the "interactive" prefix and the symbols of ``params``."""
        if not self.is_interactive:
            return "" + s + ""
        pre = "interactive "
        post = " and parameters " + str(tuple(self.params.keys()))
        return pre + s + post
def _set_discretization_points(kwargs, Series): """Allow the use of the keyword arguments n, n1 and n2 (and n3) to specify the number of discretization points in two (or three) directions. Parameters ========== kwargs : dict Series : BaseSeries The type of the series, which indicates the kind of plot we are trying to create. Returns ======= kwargs : dict """ deprecated_keywords = { "nb_of_points": "n1", "nb_of_points_x": "n1", "nb_of_points_y": "n2", "nb_of_points_u": "n1", "nb_of_points_v": "n2", "points": "n" } for k, v in deprecated_keywords.items(): if k in kwargs.keys(): kwargs[v] = kwargs.pop(k) n = [None] * 3 provided_n = kwargs.pop("n", None) if provided_n is not None: if hasattr(provided_n, "__iter__"): for i in range(min(len(provided_n), 3)): n[i] = int(provided_n[i]) else: n = [int(provided_n)] * 3 if n[0] is not None: kwargs.setdefault("n1", n[0]) if n[1] is not None: kwargs.setdefault("n2", n[1]) if n[2] is not None: kwargs.setdefault("n3", n[2]) return kwargs