Source code for epipack.interactive

"""
Interactive Jupyter widgets for SymbolicEpiModels.
"""

import copy
from collections import OrderedDict
from math import log10

import numpy as np
import sympy

import ipywidgets as widgets

import matplotlib.pyplot as pl

from epipack.colors import palettes, hex_colors

[docs]def get_box_layout(): """Return default box layout""" return widgets.Layout( margin='0px 10px 10px 0px', padding='5px 5px 5px 5px' )
[docs]class Range(dict): """ Defines a value range for an interactive linear value slider. Parameters ========== min : float Minimal value of parameter range max : float Maximal value of parameter range step_count : int, default = 100 Divide the parameter space into that many intervals value : float, default = None Initial value. If ``None``, defaults to the mean of ``min`` and ``max``. """ def __init__(self, min, max, step_count=100, value=None): super().__init__() assert(max > min) assert(step_count>0) self['min'] = min self['max'] = max if value is None: self['value'] = 0.5*(max+min) else: assert(min <= value and max >= value) self['value'] = value self['step'] = (max-min)/step_count def __float__(self): return float(self['value']) def __add__(self, other): return other + float(self) def __radd__(self, other): return other + float(self) def __mul__(self, other): return other * float(self) def __rmul__(self, other): return other * float(self) def __truediv__(self, other): return float(self) / other def __rtruediv__(self, other): return other / float(self) def __pow__(self, other): return float(self)**other def __rpow__(self, other): return other**float(self) def __sub__(self, other): return float(self) - other def __rsub__(self, other): return other - float(self)
[docs]class LogRange(dict): """ Defines a value range for an interactive logarithmic value slider. Parameters ========== min : float Minimal value of parameter range max : float Maximal value of parameter range step_count : int, default = 100 Divide the exponent space into that many intervals base : float, default = 10 Base of the logarithm value : float, default = None Initial value. If ``None``, defaults to the geometric mean of ``min`` and ``max``. """ def __init__(self, min, max, step_count=100, value=None, base=10, ): super().__init__() assert(max > min) assert(step_count>0) assert(base>0) def logB(x): return np.log(x) / np.log(base) self['min'] = logB(min) self['max'] = logB(max) if value is None: self['value'] = np.sqrt(max*min) else: assert(min <= value and max >= value) self['value'] = value self['step'] = (logB(max)-logB(min))/step_count self['base'] = base def __float__(self): return float(self['value']) def __add__(self, other): return other + float(self) def __radd__(self, other): return other + float(self) def __mul__(self, other): return other * float(self) def __rmul__(self, other): return other * float(self) def __truediv__(self, other): return float(self) / other def __rtruediv__(self, other): return other / float(self) def __pow__(self, other): return float(self)**other def __rpow__(self, other): return other**float(self) def __sub__(self, other): return float(self) - other def __rsub__(self, other): return other - float(self)
[docs]class InteractiveIntegrator(widgets.HBox): """ An interactive widget that lets you control parameters of a SymbolicEpiModel and shows you the output. Based on this tutorial: https://kapernikov.com/ipywidgets-with-matplotlib/ Parameters ========== model : epipack.symbolic_epi_models.SymbolicEpiModel An instance of ``SymbolicEpiModel`` that has been initiated with initial conditions parameter_values : dict A dictionary that maps parameter symbols to single, fixed values or ranges (instances of :class:`epipack.interactive.Range` or :class:`epipack.interactive.LogRange`). t : numpy.ndarray The time points over which the model will be integrated return_compartments : list, default = None A list of compartments that should be displayed. If ``None``, all compartments will be displayed. return_derivatives : list, default = None A list of derivatives that should be displayed If ``None``, no derivatives will be displayed. figsize : tuple, default = (4,4) Width and height of the created figure. palette : str, default = 'dark' A palette from ``epipack.colors``. Choose from .. code:: python [ 'dark', 'light', 'dark pastel', 'light pastel', 'french79', 'french79 pastel', 'brewer light', 'brewer dark', 'brewer dark pastel', 'brewer light pastel' ] integrator : str, default = 'dopri5' Either ``euler`` or ``dopri5``. continuous_update : bool, default = False If ``False``, curves will be updated only if the mouse button is released. If ``True``, curves will be continuously updated. show_grid : bool, default = False Whether or not to display a grid Attributes ========== model : epipack.symbolic_epi_models.SymbolicEpiModel An instance of ``SymbolicEpiModel`` that has been initiated with initial conditions. fixed_parameters : dict A dictionary that maps parameter symbols to single, fixed values t : numpy.ndarray The time points over which the model will be integrated return_compartments : list A list of compartments that will be displayed. colors : list A list of hexstrings. fig : matplotlib.Figure The figure that will be displayed. ax : matplotlib.Axis The axis that will be displayed. lines : dict Maps compartments to line objects children : list Contains two displayed boxes (controls and output) continuous_update : bool, default = False If ``False``, curves will be updated only if the mouse button is released. If ``True``, curves will be continuously updated. """ def __init__(self, model, parameter_values, t, return_compartments=None, return_derivatives=None, figsize=(4,4), palette='dark', integrator='dopri5', continuous_update=False, show_grid=False, ): super().__init__() self.model = model self.t = np.array(t) self.colors = [ hex_colors[colorname] for colorname in palettes[palette] ] if return_compartments is None: self.return_compartments = self.model.compartments else: self.return_compartments = return_compartments self.return_derivatives = return_derivatives self.integrator = integrator self.lines = None self.continuous_update = continuous_update output = widgets.Output() with output: self.fig, self.ax = pl.subplots(constrained_layout=True, figsize=figsize) self.ax.set_xlabel('time') self.ax.set_ylabel('frequency') self.ax.grid(show_grid) self.fig.canvas.toolbar_position = 'bottom' # define widgets self.fixed_parameters = {} self.sliders = {} for parameter, value in parameter_values.items(): self.fixed_parameters[parameter] = float(value) if type(value) not in [Range, LogRange]: continue else: these_vals = copy.deepcopy(value) these_vals['description'] = r'\(' + sympy.latex(parameter) + r'\)' these_vals['continuous_update'] = self.continuous_update if type(value) == LogRange: slider = widgets.FloatLogSlider(**these_vals) else: slider = widgets.FloatSlider(**these_vals) self.sliders[parameter] = slider checkb_xscale = widgets.Checkbox( value=False, description='logscale time', ) checkb_yscale = widgets.Checkbox( value=False, description='logscale frequency', ) controls = widgets.VBox( list(self.sliders.values()) + [ checkb_xscale, checkb_yscale, ]) controls.layout = get_box_layout() out_box = widgets.Box([output]) output.layout = get_box_layout() for parameter, slider in self.sliders.items(): slider.observe(self.update_parameters, 'value') checkb_xscale.observe(self.update_xscale, 'value') checkb_yscale.observe(self.update_yscale, 'value') self.children = [controls, output] self.update_parameters()
[docs] def update_parameters(self, *args, **kwargs): """Update the current values of parameters as given by slider positions.""" parameters = copy.deepcopy(self.fixed_parameters) for parameter, slider in self.sliders.items(): parameters[parameter] = slider.value self.update_plot(parameters)
[docs] def update_plot(self, parameters): """Recompute and -draw the epidemic curves with updated parameter values""" self.model.set_parameter_values(parameters) if self.return_derivatives is None: res = self.model.integrate( self.t, return_compartments=self.return_compartments, integrator=self.integrator) else: res = self.model.integrate_and_return_by_index( self.t, integrator=self.integrator) ndx = [ self.model.get_compartment_id(C) for C in self.return_derivatives ] dydt = self.model.get_numerical_dydt() derivatives = np.array([ dydt(t,res[:,it]) for it, t in enumerate(self.t) ]).T res = {C: res[self.model.get_compartment_id(C),:] for C in self.return_compartments} der = {C: derivatives[self.model.get_compartment_id(C),:] for C in self.return_derivatives} is_initial_run = self.lines is None if is_initial_run: self.lines = {} # plot compartments for iC, C in enumerate(self.return_compartments): ydata = res[C] if is_initial_run: self.lines[C], = self.ax.plot(self.t,ydata,label=str(C),color=self.colors[iC]) else: self.lines[C].set_ydata(ydata) # plot derivatives if self.return_derivatives is not None: for iC, C in enumerate(self.return_derivatives): ydata = der[C] _C = 'd' + str(C) + '/dt' if is_initial_run: self.lines[_C], = self.ax.plot(self.t,ydata,ls='--',label=_C,color=self.colors[iC]) else: self.lines[_C].set_ydata(ydata) if is_initial_run: self.ax.legend() self.fig.canvas.draw()
[docs] def update_xscale(self, change): """Update the scale of the x-axis. For "log", pass an object ``change`` that has ``change.new=True``""" scale = 'linear' if change.new: scale = 'log' self.ax.set_xscale(scale)
[docs] def update_yscale(self, change): """Update the scale of the y-axis. For "log", pass an object ``change`` that has ``change.new=True``""" scale = 'linear' if change.new: scale = 'log' self.ax.set_yscale(scale)
[docs]class GeneralInteractiveWidget(widgets.HBox): """ An interactive widget that lets you control parameters that are passed to a custom function which returns a result dictionary. Based on this tutorial: https://kapernikov.com/ipywidgets-with-matplotlib/ Parameters ========== result_function : func A function that returns a result dictionary when passed parameter values as ``result_function(**parameter_values)``. parameter_values : dict A dictionary that maps parameter names to single, fixed values or ranges (instances of :class:`epipack.interactive.Range` or :class:`epipack.interactive.LogRange`). t : numpy.ndarray The time points corresponding to values in the result dictionary. return_keys : list, default = None A list of result keys that should be shown. If ``None``, all compartments will be displayed. figsize : tuple, default = (4,4) Width and height of the created figure. palette : str, default = 'dark' A palette from ``epipack.colors``. Choose from .. code:: python [ 'dark', 'light', 'dark pastel', 'light pastel', 'french79', 'french79 pastel', 'brewer light', 'brewer dark', 'brewer dark pastel', 'brewer light pastel' ] continuous_update : bool, default = False If ``False``, curves will be updated only if the mouse button is released. If ``True``, curves will be continuously updated. show_grid : bool, default = False Whether or not to display a grid ylabel : str, default = 'frequency' What to name the yaxis label_converter : func, default = str A function that returns a string when passed a result key or parameter name. Attributes ========== result_function : func A function that returns a result dictionary when passed parameter values as ``result_function(**parameter_values)``. fixed_parameters : dict A dictionary that maps parameter names to fixed values t : numpy.ndarray The time points corresponding to values in the result dictionary. return_keys : list A list of result dictionary keys of which the result will be displayed. colors : list A list of hexstrings. fig : matplotlib.Figure The figure that will be displayed. ax : matplotlib.Axis The axis that will be displayed. lines : dict Maps compartments to line objects children : list Contains two displayed boxes (controls and output) continuous_update : bool, default = False If ``False``, curves will be updated only if the mouse button is released. If ``True``, curves will be continuously updated. lbl : func, default = str A function that returns a string when passed a result key or parameter name. """ def __init__(self, result_function, parameter_values, t, return_keys=None, figsize=(4,4), palette='dark', continuous_update=False, show_grid=False, ylabel='frequency', label_converter=str, ): super().__init__() self.t = t self.get_result = result_function self.colors = [ hex_colors[colorname] for colorname in palettes[palette] ] self.return_keys = return_keys self.lines = None self.continuous_update = continuous_update self.lbl = label_converter output = widgets.Output() with output: self.fig, self.ax = pl.subplots(constrained_layout=True, figsize=figsize) self.ax.set_xlabel('time') self.ax.set_ylabel(ylabel) self.ax.grid(show_grid) self.fig.canvas.toolbar_position = 'bottom' # define widgets self.fixed_parameters = {} self.sliders = {} for parameter, value in parameter_values.items(): self.fixed_parameters[parameter] = float(value) if type(value) not in [Range, LogRange]: continue else: these_vals = copy.deepcopy(value) these_vals['description'] = self.lbl(parameter) or parameter these_vals['continuous_update'] = self.continuous_update if type(value) == LogRange: slider = widgets.FloatLogSlider(**these_vals) else: slider = widgets.FloatSlider(**these_vals) self.sliders[parameter] = slider checkb_xscale = widgets.Checkbox( value=False, description='logscale time', ) checkb_yscale = widgets.Checkbox( value=False, description='logscale frequency', ) controls = widgets.VBox( list(self.sliders.values()) + [ checkb_xscale, checkb_yscale, ]) controls.layout = get_box_layout() out_box = widgets.Box([output]) output.layout = get_box_layout() for parameter, slider in self.sliders.items(): slider.observe(self.update_parameters, 'value') checkb_xscale.observe(self.update_xscale, 'value') checkb_yscale.observe(self.update_yscale, 'value') self.children = [controls, output] self.update_parameters()
[docs] def update_parameters(self, *args, **kwargs): """Update the current values of parameters as given by slider positions.""" parameters = copy.deepcopy(self.fixed_parameters) for parameter, slider in self.sliders.items(): parameters[parameter] = slider.value self.update_plot(parameters)
[docs] def update_plot(self, parameters): """Recompute and -draw the epidemic curves with updated parameter values""" res = self.get_result(**parameters) is_initial_run = self.lines is None if is_initial_run: self.lines = {} if self.return_keys is None: keys = res.keys() else: keys = self.return_keys # plot compartments for iC, C in enumerate(keys): ydata = res[C] if is_initial_run: self.lines[C], = self.ax.plot(self.t,ydata,label=self.lbl(C),color=self.colors[iC]) else: self.lines[C].set_ydata(ydata) if is_initial_run: self.ax.legend() self.fig.canvas.draw()
[docs] def update_xscale(self, change): """Update the scale of the x-axis. For "log", pass an object ``change`` that has ``change.new=True``""" scale = 'linear' if change.new: scale = 'log' self.ax.set_xscale(scale)
[docs] def update_yscale(self, change): """Update the scale of the y-axis. For "log", pass an object ``change`` that has ``change.new=True``""" scale = 'linear' if change.new: scale = 'log' self.ax.set_yscale(scale)
if __name__=="__main__": # pragma: no cover A = LogRange(0.1,1,value=0.5) print(A + 2) print(2 + A) print(A * 2) print(2 * A) print(A / 2) print(2 / A) print(A**2) print(2**A) print(A - 2) print(2 - A)