# Source code for spb.utils

from spb.defaults import cfg
from sympy import (
    Tuple, sympify, Expr, Dummy, sin, cos, Symbol, Indexed, ImageSet,
    FiniteSet, Basic, Float, Integer, Rational, Poly, exp,
    NumberSymbol, IndexedBase
)
from sympy.vector import BaseScalar
from sympy.core.function import AppliedUndef
from sympy.core.relational import Relational
from sympy.logic.boolalg import BooleanFunction
from sympy.external import import_module
import param
import warnings


def _create_missing_ranges(exprs, ranges, npar, params=None, imaginary=False):
    """Validate the free symbols of ``exprs`` and add the missing ranges.

    This function does two things:

    1. Check if the number of free symbols is in agreement with the type of
       plot chosen. For example, plot() requires 1 free symbol;
       plot3d() requires 2 free symbols.
    2. Sometime users create plots without providing ranges for the
       variables. Here we create the necessary ranges, using the limits
       stored in ``cfg["plot_range"]``.

    Parameters
    ==========

    exprs : iterable
        The expressions from which to extract the free symbols
    ranges : list
        The limiting ranges provided by the user. NOTE: this list is
        modified in place (missing ranges are appended to it) and it is
        also the object being returned.
    npar : int
        The number of free symbols required by the plot functions.
        For example,
        npar=1 for plot, npar=2 for plot3d, ...
    params : dict
        A dictionary mapping symbols to parameters for iplot. A key can
        also be a list/tuple of symbols (RangeSlider).
    imaginary : bool
        Include the imaginary part. Default to False.

    Returns
    =======

    ranges : list
        The same list received as input, eventually extended.

    Raises
    ======

    ValueError
        If there are too many free symbols, too many ranges, multiple
        ranges sharing the same symbol, or if the free symbols are
        incompatible with the symbols of the ranges.
    """

    def get_default_range(symbol):
        _min = cfg["plot_range"]["min"]
        _max = cfg["plot_range"]["max"]
        if not imaginary:
            return Tuple(symbol, _min, _max)
        return Tuple(symbol, _min + _min * 1j, _max + _max * 1j)

    free_symbols = _get_free_symbols(exprs)
    if params is not None:
        # parameters are not plotting variables. A key might be a
        # list/tuple of symbols (RangeSlider): flatten them.
        p_symbols = set()
        for k in params:
            if isinstance(k, (list, tuple)):
                p_symbols.update(k)
            else:
                p_symbols.add(k)
        free_symbols = free_symbols.difference(p_symbols)

    if len(free_symbols) > npar:
        raise ValueError(
            "Too many free symbols.\n"
            + "Expected {} free symbols.\n".format(npar)
            + "Received {}: {}".format(len(free_symbols), free_symbols)
        )

    if len(ranges) > npar:
        raise ValueError(
            "Too many ranges. Received %s, expected %s" % (len(ranges), npar))

    # free symbols in the ranges provided by the user
    rfs = {r[0] for r in ranges}
    if len(rfs) != len(ranges):
        raise ValueError("Multiple ranges with the same symbol")

    if len(ranges) < npar:
        # add a range for each missing free symbol. The symbols are sorted
        # by name because the iteration order of a set is not guaranteed to
        # be the same between different runs: without sorting, the order of
        # the automatically created ranges would not be reproducible.
        for s in sorted(free_symbols.difference(rfs), key=str):
            ranges.append(get_default_range(s))
        # if there is still room, fill them with dummys
        for _ in range(npar - len(ranges)):
            ranges.append(get_default_range(Dummy()))

    if len(free_symbols) == npar:
        # there could be times when this condition is not met, for example
        # plotting the function f(x, y) = x (which is a plane); in this case,
        # free_symbols = {x} whereas rfs = {x, y} (or x and Dummy)
        rfs = {r[0] for r in ranges}
        if len(free_symbols.difference(rfs)) > 0:
            raise ValueError(
                "Incompatible free symbols of the expressions with "
                "the ranges.\n"
                + "Free symbols in the expressions: {}\n".format(free_symbols)
                + "Free symbols in the ranges: {}".format(rfs)
            )
    return ranges


def _create_ranges_iterable(*ranges):
    """Create a list of ranges. If a range is not provided, it won't be
    included in this list.

    Returns
    -------
    provided_ranges : list
        A list, for example `[r1, r2, r3]`. If `r2` is not provide, the list
        looks like `[r1, r2]`. If no range is provided, `[]` is returned.
    mapping : dict
        Maps the i-th provided range to its position in `provided_ranges`.
    """
    provided_ranges = []
    mapping = {}
    for i, r in enumerate(ranges):
        if r is not None:
            provided_ranges.append(r)
            mapping[i] = len(provided_ranges) - 1
    return provided_ranges, mapping


def _preprocess_multiple_ranges(exprs, ranges, npar, params={}):
    """Complete and order the ranges of a plot requiring ``npar`` ranges.

    Users might not provide all the necessary ranges (some of them can be
    ``None``). The ranges that were provided keep the position they had in
    ``ranges``; the ones created by ``_create_missing_ranges`` fill the
    remaining free positions, from left to right.

    Parameters
    ----------
    exprs : iterable
        The expressions from which to extract the free symbols
    ranges : iterable
        The limiting ranges provided by the user. Missing ranges are
        represented by ``None``.
    npar : int
        The number of free symbols required by the plot functions.
        For example, npar=1 for plot, npar=2 for plot3d, ...
    params : dict
        A dictionary mapping symbols to parameters for iplot.

    Returns
    -------
    sorted_ranges : list
        A list of ``npar`` ranges.
    """
    given, positions = _create_ranges_iterable(*ranges)
    # work on a copy, because `_create_missing_ranges` extends its input
    remaining = _create_missing_ranges(exprs, given.copy(), npar, params)

    ordered = [None] * npar
    # user-provided ranges go back to their original slot...
    for slot, idx in positions.items():
        rng = given[idx]
        ordered[slot] = rng
        remaining.remove(rng)
    # ...while the automatically created ones take the first empty slot
    for rng in remaining:
        ordered[ordered.index(None)] = rng
    return ordered


def _get_free_symbols(exprs):
    """
    Returns the free symbols of a symbolic expression.

    If the expression contains any of these elements, assume that they are
    the "free symbols" of the expression:

    * indexed objects
    * applied undefined function (useful for sympy.physics.mechanics module)
    """
    # TODO: this function gets called 3 times to generate a single plot.
    # See if its possible to remove one functions call inside series.py
    if exprs is None:
        # this case happens when we are retrieving the free symbols from
        # color_func, whose default value is None
        return set()
    if not isinstance(exprs, (list, tuple, set)):
        exprs = [exprs]
    if all(callable(e) for e in exprs):
        return set()

    # NOTE:
    # 1. srepr(IndexedBase("a")) is "IndexedBase(Symbol('a'))"
    #    So, if expr = IndexedBase("a")[0] + 1, it follows that
    #    expr.free_symbols is {IndexedBase("a")[0], Symbol("a")}
    #    This must be filtered to {IndexedBase("a")[0]}
    # 2. Let a = IndexedBase("a"). Even though as of sympy 1.14.0 it is
    #    possible to write expressions like a + 1, for simplicity,
    #    I don't allow them, because of Note 1, which would increase
    #    complexity in this code.

    undefined_func = set().union(*[e.atoms(AppliedUndef) for e in exprs])
    undefined_func_args = set().union(*[f.args for f in undefined_func])
    indexed_base = set().union(*[e.atoms(IndexedBase) for e in exprs])
    indexed_base_args = set().union(*[i.args for i in indexed_base])

    # select all free symbols, be them instances of Symbol, Indexed
    # or the arguments of IndexedBase
    free_symbols = set().union(*[e.free_symbols for e in exprs])
    # remove instances of IndexedBase
    free_symbols = free_symbols.difference(indexed_base)
    # remove free symbols that are arguments of applied undef functions
    # it is unlikely that these symbols are being used as parameters as well.
    free_symbols = free_symbols.difference(undefined_func_args)
    # remove free symbols that are arguments of indexed base
    free_symbols = free_symbols.difference(indexed_base_args)

    free = free_symbols.union(undefined_func)

    return free


def _check_arguments(args, nexpr, npar, **kwargs):
    """Checks the arguments and converts into tuples of the
    form (exprs, ranges, label, rendering_kw).

    Parameters
    ==========

    args
        The arguments provided to the plot functions
    nexpr
        The number of sub-expression forming an expression to be plotted.
        For example:
        nexpr=1 for plot.
        nexpr=2 for plot_parametric: a curve is represented by a tuple of two
            elements.
        nexpr=1 for plot3d.
        nexpr=3 for plot3d_parametric_line: a curve is represented by a tuple
            of three elements.
    npar
        The number of free symbols required by the plot functions. For example,
        npar=1 for plot, npar=2 for plot3d, ...
    **kwargs :
        keyword arguments passed to the plotting function. It will be used to
        verify if ``params`` has ben provided.

    Examples
    ========

    .. plot::
       :context: reset
       :format: doctest
       :include-source: True

       >>> from sympy import cos, sin, symbols
       >>> from sympy.plotting.plot import _check_arguments
       >>> x = symbols('x')
       >>> _check_arguments([cos(x), sin(x)], 2, 1)
           [(cos(x), sin(x), (x, -10, 10), '(cos(x), sin(x))')]

       >>> _check_arguments([x, x**2], 1, 1)
           [(x, (x, -10, 10), 'x'), (x**2, (x, -10, 10), 'x**2')]
    """
    if not args:
        return []
    output = []
    params = kwargs.get("params", None)

    if all([isinstance(a, (Expr, Relational, BooleanFunction)) for a in args[:nexpr]]):
        # In this case, with a single plot command, we are plotting either:
        #   1. one expression
        #   2. multiple expressions over the same range

        exprs, ranges, label, rendering_kw = _unpack_args(*args)
        free_symbols = set().union(*[e.free_symbols for e in exprs])
        ranges = _create_missing_ranges(exprs, ranges, npar, params)

        if nexpr > 1:
            # in case of plot_parametric or plot3d_parametric_line, there will
            # be 2 or 3 expressions defining a curve. Group them together.
            if len(exprs) == nexpr:
                exprs = (tuple(exprs),)
        for expr in exprs:
            # need this if-else to deal with both plot/plot3d and
            # plot_parametric/plot3d_parametric_line
            is_expr = isinstance(expr, (Expr, Relational, BooleanFunction))
            e = (expr,) if is_expr else expr
            output.append((*e, *ranges, label, rendering_kw))

    else:
        # In this case, we are plotting multiple expressions, each one with its
        # range. Each "expression" to be plotted has the following form:
        # (expr, range, label) where label is optional

        _, ranges, labels, rendering_kw = _unpack_args(*args)
        labels = [labels] if labels else []

        # number of expressions
        n = (
            len(ranges) + len(labels) +
            (len(rendering_kw) if rendering_kw is not None else 0)
        )
        new_args = args[:-n] if n > 0 else args

        # at this point, new_args might just be [expr]. But I need it to be
        # [[expr]] in order to be able to loop over
        # [expr, range [opt], label [opt]]
        if not isinstance(new_args[0], (list, tuple, Tuple)):
            new_args = [new_args]

        # Each arg has the form (expr1, expr2, ..., range1 [optional], ...,
        #   label [optional], rendering_kw [optional])
        for arg in new_args:
            # look for "local" range and label. If there is not, use "global".
            l = [a for a in arg if isinstance(a, str)]
            if not l:
                l = labels
            r = [a for a in arg if _is_range(a)]
            if not r:
                r = ranges.copy()
            rend_kw = [a for a in arg if isinstance(a, dict)]
            rend_kw = rendering_kw if len(rend_kw) == 0 else rend_kw[0]

            # NOTE: arg = arg[:nexpr] may raise an exception if lambda
            # functions are used. Execute the following instead:
            arg = [arg[i] for i in range(nexpr)]
            free_symbols = set()
            if all(not callable(a) for a in arg):
                free_symbols = free_symbols.union(*[
                    a.free_symbols for a in arg])
            if len(r) != npar:
                r = _create_missing_ranges(arg, r, npar, params)

            label = None if not l else l[0]
            output.append((*arg, *r, label, rend_kw))
    return output


def _plot_sympify(args):
    """Apply ``sympify`` to the arguments of a plot function, recursively,
    leaving untouched the ones that must not be converted.

    By allowing the users to set custom labels to the expressions being
    plotted, a critical issue is raised: whenever a special character like $,
    {, }, ... is used in the label (type string), sympify will raise an error.
    For this reason strings are skipped, together with dictionaries,
    callables and objects whose class is named ``Vector`` but are not
    instances of ``Basic``. Lists and tuples are processed recursively and
    converted to ``Tuple``.

    If ``args`` is an ``Expr`` it is returned as is; otherwise a new list
    is returned.
    """
    if isinstance(args, Expr):
        return args

    def _convert(item):
        if isinstance(item, (list, tuple)):
            return Tuple(*_plot_sympify(item), sympify=False)
        if isinstance(item, (str, dict)) or callable(item):
            return item
        if (item.__class__.__name__ == "Vector") and (
            not isinstance(item, Basic)
        ):
            return item
        return sympify(item)

    return [_convert(a) for a in list(args)]


def _is_range(r):
    """Tell whether ``r`` represents a plotting range.

    A range is either an instance of ``prange`` or a 3-elements ``Tuple``
    of the form (symbol, start, end), in which start and end are numbers
    (their ``is_number`` attribute is True).
    """
    if isinstance(r, prange):
        return True
    if not (isinstance(r, Tuple) and (len(r) == 3)):
        return False

    start, end = r.args[1], r.args[2]
    # strings don't expose `is_number`: they can't be range limits
    if isinstance(start, str):
        return False
    start_is_number = start.is_number
    if not start_is_number:
        return start_is_number
    if isinstance(end, str):
        return False
    return end.is_number


class prange(Tuple):
    """Represents a plot range, an entity describing what interval a
    particular variable is allowed to vary. It is a 3-elements tuple:
    (symbol, minimum, maximum).

    Notes
    =====

    Why does the plotting module needs this class instead of providing a
    plotting range with ordinary tuple/list? After all, ordinary plots works
    just fine.

    If a plotting range is provided with a 3-elements tuple/list, the internal
    algorithm looks at the tuple and tries to determine what it is.
    If minimum and maximum are numeric values, then it is a plotting range.

    However, there are some plotting functions in which the expression
    consists of 3-elements tuple/list. The plotting module is also
    interactive, meaning that minimum and maximum can also be expressions
    containing parameters. In these cases, the plotting range is
    indistinguishable from a 3-elements tuple describing an expression.

    This class is meant to solve that ambiguity: it only represents a plotting
    range.

    Examples
    ========

    Let x be a symbol and u, v, t be parameters. An example plotting range is:

    .. doctest::

       >>> from sympy import symbols
       >>> from spb import prange
       >>> x, u, v, t = symbols("x, u, v, t")
       >>> prange(x, u * v, v**2 + t)
       (x, u*v, t + v**2)

    """

    def __new__(cls, *args):
        """Validate the three elements and create the range.

        Raises
        ======

        ValueError
            If the number of elements is different from 3, or if the symbol
            of the range also appears in the minimum or in the maximum.
        TypeError
            If the first element is not a string, Symbol, BaseScalar or
            Indexed.
        """
        if len(args) != 3:
            raise ValueError(
                "`%s` requires 3 elements. Received " % cls.__name__
                + "%s elements: %s" % (len(args), args))
        if not isinstance(args[0], (str, Symbol, BaseScalar, Indexed)):
            raise TypeError(
                "The first element of a plotting range must "
                "be a symbol. Received: %s" % type(args[0])
            )
        # elements are sympified here, hence `sympify=False` down below
        args = [sympify(a) for a in args]
        if (
            (args[0] in args[1].free_symbols)
            or (args[0] in args[2].free_symbols)
        ):
            raise ValueError(
                "Symbol `%s` representing the range can only " % args[0]
                + "be specified in the first element of %s" % cls.__name__)
        return Tuple.__new__(cls, *args, sympify=False)
def _unpack_args(*args): """Given a list/tuple of arguments previously processed by _plot_sympify() and/or _check_arguments(), separates and returns its components: expressions, ranges, label and rendering keywords. Examples ======== >>> from sympy import cos, sin, symbols >>> x, y = symbols('x, y') >>> args = (sin(x), (x, -10, 10), "f1") >>> args = _plot_sympify(args) >>> _unpack_args(*args) ([sin(x)], [(x, -2, 2)], 'f1') >>> args = (sin(x**2 + y**2), (x, -2, 2), (y, -3, 3), "f2") >>> args = _plot_sympify(args) >>> _unpack_args(*args) ([sin(x**2 + y**2)], [(x, -2, 2), (y, -3, 3)], 'f2') >>> args = (sin(x + y), cos(x - y), x + y, (x, -2, 2), (y, -3, 3), "f3") >>> args = _plot_sympify(args) >>> _unpack_args(*args) ([sin(x + y), cos(x - y), x + y], [(x, -2, 2), (y, -3, 3)], 'f3') """ ranges = [t for t in args if _is_range(t)] labels = [t for t in args if isinstance(t, str)] label = None if not labels else labels[0] rendering_kw = [t for t in args if isinstance(t, dict)] rendering_kw = None if not rendering_kw else rendering_kw[0] # NOTE: why None? because args might have been preprocessed by # _check_arguments, so None might represent the rendering_kw results = [ not (_is_range(a) or isinstance(a, (str, dict)) or (a is None)) for a in args ] exprs = [a for a, b in zip(args, results) if b] return exprs, ranges, label, rendering_kw def ij2k(cols, i, j): """Create the connectivity for the mesh. https://github.com/K3D-tools/K3D-jupyter/issues/273 """ return cols * i + j def get_vertices_indices(x, y, z): """Compute the vertices matrix (Nx3) and the connectivity list for triangular faces. 
Parameters ========== x, y, z : np.array 2D arrays """ np = import_module('numpy') rows, cols = x.shape x = x.flatten() y = y.flatten() z = z.flatten() vertices = np.vstack([x, y, z]).T indices = [] for i in range(1, rows): for j in range(1, cols): indices.append( [ij2k(cols, i, j), ij2k(cols, i - 1, j), ij2k(cols, i, j - 1)] ) indices.append( [ij2k(cols, i - 1, j - 1), ij2k(cols, i, j - 1), ij2k(cols, i - 1, j)] ) return vertices, indices def _instantiate_backend(Backend, *series, **kwargs): show = kwargs.pop("show", True) p = Backend(*series, **kwargs) if show: p.show() return p def _check_misspelled_kwargs( obj, additional_keys=[], exclude_keys=[], **kwargs ): """Find the user-provided keywords arguments that might contain spelling errors and informs the user of possible alternatives. Parameters ========== obj : param.Parameterized The object holding the correct parameter names. additional_keys : list List of string representing additional keyword arguments that might be involved in the instantiation of `obj`. exclude_keys : list List of string representing parameter names that should not be considered while performing the validation. **kwargs : dict Keyword arguments passed to `obj` __init__ method. Notes ===== Within this module, there are "multiple levels" of keyword arguments: * some keyword arguments get intercepted at the plotting function level. Think for example to ``scalar`` in ``plot_vector``, or ``sum_bound`` in ``plot``. * some plotting function might insert useful keyword arguments, for example ``real``, ``imag``, etc., on complex-related functions. * many of the keyword arguments get passed down to the Series and/or to the Backend classes (for example, ``xscale, ...``). After porting this module to param, I have implemented the validation of keyword arguments at the ``*Series`` and ``graphics``. I was unable to perform it inside the ``Plot`` class because the ``graphics`` function removes unused keyword arguments. 
Hopefully one day I'll implement on the interactive level too (interactive plots and animations). The plotting module offers two main approaches: * spb.plot_function, inherithed from sympy.plotting. Here, the problem is that keyword arguments from a specific plot function get directed both at series as well as the backend. For example, the Plot class could receive arguments that are meant to go to LineOver1DRangeSeries, and vice-versa. It's a mess. In order to deal with this mess, I introduced the `plot_function` keyword argument: this will enable validation on data series but not on the ``graphics`` function. * spb.graphics: here, there is a clear separation between data series and backend. I can implement the validation on both ends. With this in mind, this function is executed at the ``__init__`` of *Series, and inside the ``graphics`` function. """ if isinstance(obj, param.Parameterized): # do not consider private attributes allowed_keys = [ t for t in obj.param.objects('existing') if t[0] != "_" ] + additional_keys else: allowed_keys = additional_keys allowed_keys = list(set(allowed_keys)) kwargs = [k for k in kwargs if k[0] != "_"] user_provided_keys = set(kwargs).difference(exclude_keys) unused_keys = user_provided_keys.difference(allowed_keys) if len(unused_keys) > 0: t = type(obj).__name__ msg = f"The following keyword arguments are unused by `{t}`.\n" for k in unused_keys: possible_match = find_closest_string(k, allowed_keys) msg += "* '%s'" % k msg += ": did you mean '%s'?\n" % possible_match warnings.warn(msg, stacklevel=3) # taken from # https://en.wikibooks.org/wiki/Algorithm_Implementation/Strings/Levenshtein_distance#Python def levenshtein(s1, s2): if len(s1) < len(s2): return levenshtein(s2, s1) # len(s1) >= len(s2) if len(s2) == 0: return len(s1) previous_row = range(len(s2) + 1) for i, c1 in enumerate(s1): current_row = [i + 1] for j, c2 in enumerate(s2): # j+1 instead of j since previous_row and current_row are # one character longer than s2 
insertions = previous_row[j + 1] + 1 deletions = current_row[j] + 1 substitutions = previous_row[j] + (c1 != c2) current_row.append(min(insertions, deletions, substitutions)) previous_row = current_row return previous_row[-1] # taken from plotly.py/packages/python/plotly/_plotly_utils/utils.py def find_closest_string(string, strings): def _key(s): # sort by levenshtein distance and lexographically to maintain a stable # sort for different keys with the same levenshtein distance return (levenshtein(s, string), s) return sorted(strings, key=_key)[0] def spherical_to_cartesian(r, theta, phi): """Convert spherical coordinates to cartesian coordinates. Parameters ========== r : Radius. theta : Polar angle. Must be in [0, pi]. 0 is the north pole, pi/2 is the equator, pi is the south pole. phi : Azimuthal angle. Must be in [0, 2*pi]. Returns ======= x, y, z """ if callable(r): np = import_module('numpy') x = lambda t, p: r(t, p) * np.sin(t) * np.cos(p) y = lambda t, p: r(t, p) * np.sin(t) * np.sin(p) z = lambda t, p: r(t, p) * np.cos(t) else: x = r * sin(theta) * cos(phi) y = r * sin(theta) * sin(phi) z = r * cos(theta) return x, y, z def unwrap(angle, period=None): """Unwrap a phase angle to give a continuous curve Parameters ---------- angle : array_like Array of angles to be unwrapped period : float, optional Period (defaults to `2*pi`) Returns ------- angle_out : array_like Output array, with jumps of period/2 eliminated Examples -------- >>> # Already continuous >>> theta1 = np.array([1.0, 1.5, 2.0, 2.5, 3.0]) * np.pi >>> theta2 = ct.unwrap(theta1) >>> theta2/np.pi # doctest: +SKIP array([1. , 1.5, 2. , 2.5, 3. ]) >>> # Wrapped, discontinuous >>> theta1 = np.array([1.0, 1.5, 0.0, 0.5, 1.0]) * np.pi >>> theta2 = ct.unwrap(theta1) >>> theta2/np.pi # doctest: +SKIP array([1. , 1.5, 2. , 2.5, 3. ]) Notes ----- This function comes from the `control` package, specifically the `control.ctrlutil.py` module. 
""" np = import_module('numpy') if period is None: period = 2 * np.pi dangle = np.diff(angle) dangle_desired = (dangle + period/2.) % period - period/2. correction = np.cumsum(dangle_desired - dangle) angle[1:] += correction return angle def extract_solution(set_sol, n=10): """Extract numerical solutions from a set solution (computed by solveset, linsolve, nonlinsolve). Often, it is not trivial do get something useful out of them. Parameters ========== n : int, optional In order to replace ImageSet with FiniteSet, an iterator is created for each ImageSet contained in `set_sol`, starting from 0 up to `n`. Default value: 10. """ images = set_sol.find(ImageSet) for im in images: it = iter(im) s = FiniteSet(*[next(it) for n in range(0, n)]) set_sol = set_sol.subs(im, s) return set_sol def is_number(t, allow_complex=True): if allow_complex: number_types = ( NumberSymbol, Float, Integer, Rational, float, int, complex ) else: number_types = (NumberSymbol, Float, Integer, Rational, float, int) return isinstance(t, number_types) or (isinstance(t, Expr) and t.is_number) def tf_to_control(tf): """Convert a transfer function to a ``control.TransferFunction``. Parameters ========== tf : The transfer function's type can be: * an instance of :py:class:`sympy.physics.control.lti.TransferFunction` or :py:class:`sympy.physics.control.lti.TransferFunctionMatrix` * an instance of :py:class:`scipy.signal.TransferFunction` * a symbolic expression in rational form. * a tuple of two or three elements: ``(num, den, generator [opt])``. ``num, den`` can be symbolic expressions or list of coefficients. 
Returns ======= tf : :py:class:`ct.TransferFunction` """ ct = import_module("control") sp = import_module("scipy") sm = import_module("sympy.physics", import_kwargs={'fromlist': ['control']}) def _from_sympy_to_ct(num, den): fs = num.free_symbols.union(den.free_symbols) if len(fs) != 1: raise ValueError( "In order to convert a SymPy trasfer function to a " "``control`` transfer function, there must only be " "one free-symbol.\nReceived: %s" % fs ) s = fs.pop() delays = tf_find_time_delay(num / den) if len(delays) > 0: raise ValueError( "The symbolic transfer function contains the following " "time delays: %s. " "Time delays are not supported by the ``control`` module. " "Consider applying a Padé approximation." % delays ) n, d = [Poly(t, s).all_coeffs() for t in [num, den]] try: n = [float(t) for t in n] d = [float(t) for t in d] except TypeError as err: raise TypeError( str(err) + "\nYou are trying to convert a transfer function to " "``control.TransferFunction``. It appears like some of the " "coefficients are complex. At the time of coding this " "message, the ``control`` module doesn't support complex " "coefficents. You might still be able to achieve your goal " "by setting ``control=False`` in your function call." 
) return ct.tf(n, d) if ct and isinstance(tf, ct.TransferFunction): return tf if isinstance(tf, Expr): gen = tf.free_symbols.pop() tf = sm.control.TransferFunction.from_rational_expression( tf, gen) return _from_sympy_to_ct(tf.num, tf.den) elif isinstance(tf, sm.control.TransferFunction): return _from_sympy_to_ct(tf.num, tf.den) elif isinstance(tf, sm.control.TransferFunctionMatrix): num, den = [], [] for i in range(tf.num_outputs): row_num, row_den = [], [] for j in range(tf.num_inputs): tmp = _from_sympy_to_ct(tf[i, j].num, tf[i, j].den) row_num.append(list(tmp.num[0][0])) row_den.append(list(tmp.den[0][0])) num.append(row_num) den.append(row_den) return ct.tf(num, den) elif sp and isinstance(tf, sp.signal.TransferFunction): return ct.tf(tf.num, tf.den, dt=0 if tf.dt is None else tf.dt) elif isinstance(tf, (list, tuple)): tf = tf_to_sympy(tf) return _from_sympy_to_ct(tf.num, tf.den) else: raise TypeError( "Transfer function's type not recognized.\n" + "Received: type(tf) = %s\n" % type(tf) + "Expected: Expr or sympy.physics.control.TransferFunction or " + "sympy.physics.control.TransferFunctionMatrix" ) def tf_to_sympy(tf, var=None, skip_check_dt=False, params={}): """Convert a transfer function from the control module or from scipy.signal to a sympy ``TransferFunction`` or ``TransferFunctionMatrix``. Parameters ========== tf : control.TransferFunction, scipy.signal.TransferFunction * an instance of :py:class:`sympy.physics.control.lti.TransferFunction` or :py:class:`sympy.physics.control.lti.TransferFunctionMatrix` * an instance of :py:class:`control.TransferFunction` * an instance of :py:class:`scipy.signal.TransferFunction` * a symbolic expression in rational form. * a tuple of two or three elements: ``(num, den, generator [opt])``. ``num, den`` can be symbolic expressions or list of coefficients. var : Symbol or None The s-variable (or z-variable) when ``tf`` is a symbolic expression. If not provided, it will be automatically selected. 
skip_check_dt : bool If True, don't raise a warning about sympy not supporting discrete-time systems. params : dict A dictionary whose keys are symbols. Returns ======= tf : TransferFunction or TransferFunctionMatrix """ ct = import_module("control") sp = import_module("scipy") sm = import_module("sympy.physics", import_kwargs={'fromlist': ['control']}) gen = Symbol("z") if is_discrete_time(tf) else Symbol("s") TransferFunction = sm.control.lti.TransferFunction TransferFunctionMatrix = sm.control.lti.TransferFunctionMatrix Series = sm.control.lti.Series Parallel = sm.control.lti.Parallel def _check_dt(system): if system.dt and (not skip_check_dt): warnings.warn( "At the time of writing this message, SymPy doesn't " "implement discrete-time transfer functions. Returning " "a continuous-time transfer function." ) if isinstance(tf, (TransferFunction, TransferFunctionMatrix)): return tf elif isinstance(tf, Expr): if var is None: fs = list(tf.free_symbols.difference(params.keys())) if len(fs) > 1: warnings.warn( "Multiple free symbols found in transfer function: %s. " "Selecting the first as the s-variable " "(or z-variable). Use the ``var=`` keyword argument " "to specify the appropriate symbol." 
% fs ) var = fs[0] if len(tf.free_symbols) > 0 else Symbol("s") return TransferFunction.from_rational_expression(tf, var) elif isinstance(tf, (Series, Parallel)): return tf.doit() if (ct is not None) and isinstance(tf, ct.TransferFunction): if (tf.ninputs == 1) and (tf.noutputs == 1): n, d = tf.num[0][0], tf.den[0][0] n = Poly.from_list(n, gen).as_expr() d = Poly.from_list(d, gen).as_expr() _check_dt(tf) return TransferFunction(n, d, gen) rows = [] for o in range(tf.noutputs): row = [] for i in range(tf.ninputs): n = tf.num[o][i] d = tf.den[o][i] new_tf = tf_to_sympy(ct.tf(n, d, dt=tf.dt)) row.append(new_tf) rows.append(row) return TransferFunctionMatrix(rows) elif (sp is not None) and isinstance(tf, sp.signal.TransferFunction): n = Poly.from_list(tf.num, gen).as_expr() d = Poly.from_list(tf.den, gen).as_expr() _check_dt(tf) return TransferFunction(n, d, gen) if isinstance(tf, (list, tuple)): powers = lambda e, s: [t * s**(len(e) - (k + 1)) for k, t in enumerate(e)] if len(tf) == 2: num, den = tf if all(isinstance(e, Expr) for e in tf): gen = Tuple(num, den).free_symbols.difference(params.keys()).pop() else: num = sum(powers(num, gen)) den = sum(powers(den, gen)) return TransferFunction(num, den, gen) elif len(tf) == 3: num, den, gen = tf if not all(isinstance(e, Expr) for e in tf): num = sum(powers(num, gen)) den = sum(powers(den, gen)) return TransferFunction(num, den, gen) else: raise ValueError( "If a tuple/list is provided, it must have " "two or three elements: (num, den, free_symbol [opt]). " f"Received len(system) = {len(tf)}, system = {tf}" ) else: raise TypeError( "Transfer function's type not recognized.\n" + "Received: type(tf) = %s\n" % type(tf) + "Expected: Expr or sympy.physics.control.TransferFunction" ) def _get_initial_params(params): """Extract the initial values of parameters from the ``params`` dictionary used on interactive-widget plots. 
""" return { k: (v[0] if hasattr(v, "__iter__") else v) for k, v in params.items() } def is_discrete_time(system): """Verify if ``system`` is a discrete-time control system. """ ct = import_module("control") sp = import_module("scipy") sm = import_module("sympy.physics", import_kwargs={'fromlist': ['control']}) if isinstance(system, sm.control.lti.SISOLinearTimeInvariant): return False if (sp is not None) and isinstance(system, sp.signal.TransferFunction): return False if system.dt is None else True if (ct is not None) and isinstance(system, ct.TransferFunction): return system.isdtime() return False def tf_find_time_delay(tf, var=None): """Find time delays contained in a symbolic TransferFunction. """ sympy = import_module("sympy") if isinstance(tf, Expr): tf = tf_to_sympy(tf, var=var) if not isinstance(tf, sympy.physics.control.TransferFunction): raise TypeError( "``tf_find_time_delay`` only works with instances of " "sympy.physics.control.lti.TransferFunction." ) num, den, s = tf.args exp_num = [t for t in num.find(exp) if t.has(s)] exp_den = [t for t in den.find(exp) if t.has(s)] return exp_num + exp_den def is_siso(system): """Check if a control system is SISO or not. """ ct = import_module("control") sp = import_module("scipy") sm = import_module("sympy.physics", import_kwargs={'fromlist': ['control']}) if isinstance(system, sm.control.lti.SISOLinearTimeInvariant): return True if sp and isinstance(system, sp.signal.TransferFunction): return True if ( ct and isinstance(system, ct.TransferFunction) and (system.ninputs == 1) and (system.noutputs == 1) ): return True if isinstance(system, Expr): return True return False def _aggregate_parameters(params, series): """Loop over data series to extract the `params` dictionaries provided by the user. This is necessary when dealing with the ``graphics`` module. 
Parameters ========== params : dict Whatever was provided by the user in the main function call (be it plot(), plot_paramentric(), ..., graphics()) series : list Data series of the current interactive widget plot. Returns ======= params : dict """ if params is None: params = {} # if len(params) == 0: # # this is the case when an interactive widget plot is build with # # the `graphics` interface. for s in series: if s.is_interactive: # use s._original_params instead of s.params in order to # keep track of multi-values widgets params.update(s._original_params) # if len(params) == 0: # raise ValueError( # "In order to create an interactive plot, " # "the `params` dictionary must be provided.") return params def get_environment(): """Find which environment is used to run the code. Returns ======= mode : int 0 - the code is running on Jupyter Notebook or qtconsole 1 - terminal running IPython 2 - other type (?) 3 - probably standard Python interpreter References ========== https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook """ try: shell = get_ipython().__class__.__name__ if shell == "ZMQInteractiveShell": return 0 # Jupyter notebook or qtconsole elif shell == "TerminalInteractiveShell": return 1 # Terminal running IPython else: return 2 # Other type (?) except NameError: return 3 # Probably standard Python interpreter def _correct_shape(a, b): """Convert ``a`` to a np.ndarray of the same shape of ``b``. Parameters ========== a : int, float, complex, np.ndarray Usually, this is the result of a numerical evaluation of a symbolic expression. Even if a discretized domain was used to evaluate the function, the result can be a scalar (int, float, complex). b : np.ndarray It represents the correct shape that ``a`` should have. Returns ======= new_a : np.ndarray An array with the correct shape. 
""" np = import_module('numpy') if not isinstance(a, np.ndarray): a = np.array(a) if a.shape != b.shape: if a.shape == (): a = a * np.ones_like(b) else: a = a.reshape(b.shape) return a