# Source code for spb.animation

import io
import os
import shutil
from spb.interactive import IPlot
from spb.utils import _aggregate_parameters, get_environment
from sympy import Symbol
from sympy.external import import_module
from tempfile import TemporaryDirectory


class BaseAnimation:
    """Implements the base functionalities to create animations.

    The methods of this class read ``self._backend`` and ``self.backend``,
    which are not defined here: the concrete class mixing this in must
    provide them, and must call either ``_post_init_plot`` or
    ``_post_init_plotgrid`` once the backend has been set.
    """

    def _load_imageio(self):
        """Bind imageio's image read/write helpers to the instance.

        ``imread``/``imwrite`` come from the ``imageio.v3`` namespace,
        ``mimwrite`` (used to assemble gif/mp4 files) from the top-level
        ``imageio`` namespace.
        """
        imageio = import_module(
            'imageio',
            import_kwargs={'fromlist': ['v3']},
            warn_not_installed=True,
        )
        self.imwrite = imageio.v3.imwrite
        self.imread = imageio.v3.imread
        self.mimwrite = imageio.mimwrite

    def _post_init_plot(self, *args, **kwargs):
        """Build the animation data of a single plot.

        This method has to be executed after ``self._backend`` has been set.

        The ``animation`` keyword argument (default ``False``) controls
        what happens:

        * an ``AnimationData`` instance is used as-is;
        * any other truthy value aggregates the parameters of the backend's
          series; if it is a dict, it is deep-merged on top of
          ``{"params": aggregated_params}`` and forwarded to
          ``AnimationData``;
        * a falsy value leaves ``self._animation_data`` set to ``None``.
        """
        animation = kwargs.get("animation", False)
        self._animation_data = None
        if isinstance(animation, AnimationData):
            self._animation_data = animation
        elif animation:
            params = _aggregate_parameters({}, self._backend.series)
            animation_data_kwargs = {"params": params}
            if isinstance(animation, dict):
                merge = import_module("mergedeep").merge
                # user-provided options are merged over the defaults
                animation_data_kwargs = merge(
                    {}, animation_data_kwargs, animation)
            self._animation_data = AnimationData(**animation_data_kwargs)

        if self._animation_data is not None:
            # update series with proper initial values before plotting
            initial_params = self._animation_data[0]
            for s in self._backend.series:
                if s.is_interactive:
                    s.params = initial_params
        self._load_imageio()

    def _post_init_plotgrid(self, *args, **kwargs):
        """Build the animation data of a grid of plots.

        This method has to be executed after ``self._backend`` has been set.
        The parameters of every animated subplot are aggregated, and the
        overall animation uses the largest ``fps`` and ``time`` among them.
        """
        self._animation_data = None
        original_params, fps_values, durations = {}, [], []
        for p in self._backend._all_plots:
            # static subplots, or animation objects created without
            # animation data, have no fps/time to contribute
            if not isinstance(p, BaseAnimation) or p._animation_data is None:
                continue
            original_params = _aggregate_parameters(
                original_params.copy(), p.backend.series)
            fps_values.append(p._animation_data.fps)
            durations.append(p._animation_data.time)

        if original_params:
            self._animation_data = AnimationData(
                fps=max(fps_values), time=max(durations),
                params=original_params)
        self._load_imageio()

    @property
    def animation_data(self):
        """The ``AnimationData`` of this animation, or ``None``."""
        return self._animation_data

    def _require_animation_data(self):
        """Return the animation data, raising if it was never created.

        Raises
        ======
        RuntimeError
            If no animation data is available.
        """
        data = self.animation_data
        if data is None:
            raise RuntimeError(
                "The data necessary to build the animation has not been "
                "provided. You must set `animation=True` to the function call."
            )
        return data

    @staticmethod
    def _frame_basename(path):
        """Return the file name of ``path`` without directory and extension.

        Only the last extension is stripped: ``"out/my.anim.gif"`` gives
        ``"my.anim"``.
        """
        return os.path.splitext(os.path.basename(path))[0]

    @staticmethod
    def _frame_range(n_frames):
        """Return an iterable over the frame indices ``0..n_frames-1``.

        When ``get_environment()`` returns 0 (the case in which a
        ``tqdm.notebook`` progress bar is shown, i.e. Jupyter), a ``trange``
        is returned; otherwise a plain ``range``. tqdm is only imported when
        the progress bar is actually needed.
        """
        if get_environment() != 0:
            return range(n_frames)
        from tqdm.notebook import trange
        return trange(n_frames)

    def update_animation(self, frame_idx):
        """Update the figure in order to obtain the visualization at a
        specified frame of the animation.

        Parameters
        ==========
        frame_idx : int
            Must be ``0 <= frame_idx < fps*time``.

        Raises
        ======
        RuntimeError
            If the animation data has not been created.
        """
        params = self._require_animation_data()[frame_idx]
        self.backend.update_interactive(params)

    def get_FuncAnimation(self):
        """Return a Matplotlib's ``FuncAnimation`` object. It only works if
        this animation is showing a Matplotlib's figure.

        Raises
        ======
        TypeError
            If the backend is not a Matplotlib backend.
        RuntimeError
            If the animation data has not been created.
        """
        from spb import MB
        if not isinstance(self.backend, MB):
            raise TypeError(
                "FuncAnimation can only be created when the backend produced "
                "Matplotlib's figure. "
                f"`{type(self.backend).__name__}` does not."
            )
        data = self._require_animation_data()
        matplotlib = import_module(
            'matplotlib',
            import_kwargs={'fromlist': ['animation']},
            warn_not_installed=True,
            min_module_version='1.1.0',
            catch=(RuntimeError,))
        return matplotlib.animation.FuncAnimation(
            fig=self.backend.fig,
            func=self.update_animation,
            frames=data.n_frames,
            # delay between frames, in milliseconds
            interval=int(1000 / data.fps)
        )

    def save(self, path, save_frames=False, **kwargs):
        """Save the animation to a file.

        Parameters
        ==========
        path : str
            Where to save the animation on the disk. Supported formats are
            ``.gif`` or ``.mp4``.
        save_frames : bool, optional
            Default to False. If True, save individual frames into png files.
        **kwargs :
            Keyword arguments to customize the gif/video creation process.
            Both gif/video animations are created using ``imageio.mimwrite``.
            In particular:

            * GIFs are created with
              :py:class:`imageio.plugins.pillowmulti.GIFFormat`
            * MP4s are created with
              :py:class:`imageio.plugins.ffmpeg.FfmpegFormat`.

            If a video seems to be low-quality, try to increase the bitrate.
            Its default value is ``bitrate=3000000``.

        Raises
        ======
        ValueError
            If ``path`` has no file extension.
        RuntimeError
            If a plotgrid animation is not a Matplotlib figure, or if the
            animation data has not been created.

        Notes
        =====

        * When saving an animation from Jupyter Notebook/Lab, a progress bar
          will be visible, indicating how many frames have been generated.
        * Saving K3D-Jupyter animations is particularly slow.
        """
        ext = os.path.splitext(path)[1]
        if len(ext) == 0:
            raise ValueError("Please, provide a file extension.")

        # avoid circular imports
        from spb.plotgrid import PlotGrid
        if (
            isinstance(self._backend, PlotGrid)
            and (not self._backend.is_matplotlib_fig)
        ):
            raise RuntimeError(
                "Saving plotgrid animation is only supported when the overall "
                "figure is a Matplotlib's figure."
            )
        self._require_animation_data()

        from spb import KB
        if isinstance(self._backend, KB):
            self._save_k3d_animation(path, save_frames, **kwargs)
        else:
            self._save_other_backends_animation(path, save_frames, **kwargs)

    def _save_k3d_animation(self, path, save_frames, **kwargs):
        """Save a K3D-Jupyter animation from screenshots of each frame.

        The inner generator asks the figure for a screenshot with
        ``fetch_screenshot()`` and receives its bytes as the value of the
        ``yield`` expression; ``fig.yield_screenshots`` (K3D API) is
        presumably what drives that exchange.
        """
        n_frames = self.animation_data.n_frames
        base = self._frame_basename(path)
        dest = os.path.dirname(path)

        @self._backend.fig.yield_screenshots
        def inner_func():
            frames = []
            for i in self._frame_range(n_frames):
                self.update_animation(i)
                self._backend.fig.fetch_screenshot()
                screenshot_bytes = yield
                img = self.imread(io.BytesIO(screenshot_bytes))
                frames.append(img)
                if save_frames:
                    name = f"{base}_{i}.png"
                    self.imwrite(os.path.join(dest, name), img)
            self._save_helper(path, frames, **kwargs)

        inner_func()

    def _save_other_backends_animation(self, path, save_frames, **kwargs):
        """Save the animation by writing each frame to a temporary png file
        with the backend's ``save`` method, then assembling them.
        """
        n_frames = self.animation_data.n_frames
        base = self._frame_basename(path)
        dest = os.path.dirname(path) or "."

        with TemporaryDirectory(prefix="animation") as tmpdir:
            tmp_filenames = []
            for i in self._frame_range(n_frames):
                self.update_animation(i)
                tmp_filename = os.path.join(tmpdir, f"{base}_{i}.png")
                tmp_filenames.append(tmp_filename)
                self._backend.save(tmp_filename)
                if save_frames:
                    shutil.copy2(tmp_filename, dest)
            # read the frames back before the temporary directory is removed
            frames = [self.imread(f) for f in tmp_filenames]
        self._save_helper(path, frames, **kwargs)

    def _save_helper(self, path, frames, **kwargs):
        """Write ``frames`` to ``path`` with ``self.mimwrite``.

        Defaults are filled in for ``.gif`` and ``.mp4`` files (extension
        matched case-insensitively); explicit ``kwargs`` take precedence.
        """
        ext = os.path.splitext(path)[1].lower()
        fps = self.animation_data.fps
        if ext == ".gif":
            kwargs.setdefault("loop", 0)  # loop=0 means loops continuously
            kwargs.setdefault("fps", fps)
            # NOTE: from my tests on 3D plots with colorbars, 2 works best.
            kwargs.setdefault("quantizer", 2)
        elif ext == ".mp4":
            kwargs.setdefault("fps", fps)
            # NOTE: setting quality=something would use variable bitrate.
            # However, this creates artifacts between consecutive frames.
            # Instead, let's use constant bitrate.
            kwargs.setdefault("bitrate", 3000000)
        self.mimwrite(path, frames, **kwargs)
class AnimationData:
    """Verify that the user provided the appropriate parameters. If so,
    creates a matrix with the following form:

    .. code-block:: text

               | param 1 | param 2 | ... | param M |
       --------|------------------------------------
       frame 1 |   val   |   val   | ... |   val   |
       frame 2 |   val   |   val   | ... |   val   |
       ...     |   ...   |   ...   | ... |   ...   |
       frame N |   val   |   val   | ... |   val   |

    Where each column represents a time-series of values associated to
    a particular symbol. Each row represents the values of all symbols at
    a particular time.

    Parameters
    ==========
    fps : int or float, optional
        Frames per second. Default to 30.
    time : int or float, optional
        Duration of the animation. Default to 5. The number of frames is
        ``int(fps * time)``.
    params : dict
        Maps each sympy ``Symbol`` to one of:

        * a dict ``{time: value}`` describing a step function: from ``time``
          onward the parameter takes ``value``. Steps are applied in
          chronological order; before the first step the value is 0.
        * a list/tuple ``(start, end)`` or ``(start, end, strategy)``, where
          ``strategy`` is ``"linear"`` (default) or ``"log"``.
        * a numpy array with exactly ``int(fps * time)`` elements.

    Raises
    ======
    TypeError
        If ``params`` is not a dict, or one of its values has an
        unsupported type.
    ValueError
        If ``params`` is empty, has non-symbol keys, or one of its values
        is malformed.
    """

    def __init__(self, fps=30, time=5, params=None):
        if not isinstance(params, dict):
            raise TypeError("``params`` must be a dictionary.")
        if len(params) == 0:
            raise ValueError(
                "In order to build an animation, at least one "
                "parameter must be provided.")
        not_symbols = [k for k in params if not isinstance(k, Symbol)]
        if not_symbols:
            raise ValueError(
                "All keys of ``params`` must be a single symbol. The "
                "following keys are something else: %s" % not_symbols)

        self.parameters = list(params.keys())
        self.n_frames = int(fps * time)
        self.time = time
        self.fps = fps

        np = import_module("numpy")
        values = []
        for k, v in params.items():
            if isinstance(v, dict):
                values.append(self._create_steps(v))
            elif isinstance(v, (list, tuple)) and (len(v) <= 3):
                values.append(self._create_interpolation(v))
            elif isinstance(v, np.ndarray):
                if len(v) != self.n_frames:
                    raise ValueError(
                        "The length of the values associated to `%s` must "
                        "be %s. Instead, an array of length %s was received."
                        "" % (k, self.n_frames, len(v))
                    )
                values.append(v)
            else:
                raise TypeError(
                    "The value associated to '%s' is not supported. "
                    "Expected an instance of `dict`, or `list` or `tuple`. "
                    "Received: %s" % (k, type(v))
                )
        # rows are frames, columns are parameters
        self.matrix = np.array(values).T

    def _create_steps(self, d):
        """Return a step-function time series from a ``{time: value}`` dict.
        """
        np = import_module("numpy")
        values = np.zeros(self.n_frames)
        # apply steps chronologically, so the result does not depend on the
        # insertion order of the dictionary
        for step_time, val in sorted(d.items(), key=lambda item: item[0]):
            # clamp: a negative start would index from the end of the array
            n_frame = max(0, int(step_time / self.time * self.n_frames))
            values[n_frame:] = val
        return values

    def _create_interpolation(self, v):
        """Return ``n_frames`` values going from ``start`` to ``end``.

        ``v`` is ``(start, end)`` or ``(start, end, strategy)``; ``"linear"``
        uses ``numpy.linspace``, ``"log"`` uses ``numpy.geomspace``.
        """
        if len(v) not in (2, 3):
            raise ValueError(
                "Expected a list/tuple of the form `(start, end)` or "
                "`(start, end, strategy)`. Received: %s" % (v,))
        if len(v) == 2:
            start, end = v
            strategy = "linear"
        else:
            start, end, strategy = v
            strategy = strategy.lower()
        start, end = float(start), float(end)

        allowed_strategies = ["linear", "log"]
        if strategy not in allowed_strategies:
            raise ValueError(
                "Discretization strategy must be either one of the "
                "following: %s" % allowed_strategies)

        np = import_module("numpy")
        if strategy == "linear":
            return np.linspace(start, end, self.n_frames)
        return np.geomspace(start, end, self.n_frames)

    def __getitem__(self, index):
        """Returns a dictionary mapping parameters to values at the specified
        animation frame.
        """
        return dict(zip(self.parameters, self.matrix[index, :]))