import abc
from functools import partial
import matplotlib.animation as mplanim
import matplotlib.pyplot as plt
import matplotlib.widgets as widgets
import mpl_toolkits.axes_grid1.axes_size as Size
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
try:
from astropy import units
except ImportError:
units = None
__all__ = ['BaseFuncAnimator', 'ArrayAnimator']
[docs]
class BaseFuncAnimator(metaclass=abc.ABCMeta):
    """
    Create a Matplotlib backend independent data explorer which allows
    definition of figure update functions for each slider.

    The following keyboard shortcuts are defined in the viewer:

    * 'left': previous step on active slider.
    * 'right': next step on active slider.
    * 'up': change the active slider up one.
    * 'down': change the active slider down one.
    * 'p': play/pause active slider.

    User-defined buttons can be added to the viewer by specifying the button labels
    and functions called when those buttons are clicked.
    See the descriptions of the ``button_labels`` and ``button_func`` keyword arguments.

    To make this class useful the subclass must implement ``plot_start_image``
    which must return the artist to animate (usually an instance of
    `matplotlib.image.AxesImage`). The returned artist is stored as ``self.im``
    and handed to the slider functions.

    Parameters
    ----------
    data: `iterable`
        Some arbitrary data.
    slider_functions: `list`
        A list of functions to call when that slider is changed.
        These functions will have ``val``, the axes image object and the slider
        widget instance passed to them, e.g., ``update_slider(val, im, slider)``
    slider_ranges: `list`
        A list of ``[min,max]`` pairs to set the ranges for each slider or an array
        of values for all points of the slider.
        (The slider update function decides which to support.)
    fig: `matplotlib.figure.Figure`, optional
        `~matplotlib.figure.Figure` to use. Defaults to `None`, in which case a new
        figure is created.
    interval: `int`, optional
        Animation interval in milliseconds. Defaults to 200.
    colorbar: `bool`, optional
        Plot a colorbar. Defaults to `False`.
    button_labels: `list`, optional
        A list of strings to label buttons. Defaults to `None`. If `None`
        and ``button_func`` is specified, it will default to the
        names of the functions.
    button_func: `list`, optional
        A list of functions to map to the buttons. These functions are called
        with two arguments, ``(animator, event)`` where the first argument is
        the animator object, and the second is a
        `matplotlib.backend_bases.MouseEvent` object.
        Defaults to `None`.
    start_image_func: optional
        Accepted for backwards compatibility; it is not used by this class.
    slider_labels: `list`, optional
        A list of labels to draw in the slider, must be the same length as
        ``slider_functions``.

    Attributes
    ----------
    fig : `matplotlib.figure.Figure`
    axes : `matplotlib.axes.Axes`

    Raises
    ------
    ValueError
        If ``slider_ranges`` or ``slider_labels`` does not have the same length
        as ``slider_functions``.

    Notes
    -----
    Extra keywords are stored on ``self.imshow_kwargs`` so that subclasses can
    forward them to the plotting call (e.g. `matplotlib.pyplot.imshow`).
    """

    def __init__(self, data, slider_functions, slider_ranges, fig=None,
                 interval=200, colorbar=False, button_func=None, button_labels=None,
                 start_image_func=None, slider_labels=None, **kwargs):
        # Validate the slider arguments before touching Matplotlib so that a bad
        # call does not leave an orphaned figure behind.
        if len(slider_functions) != len(slider_ranges):
            raise ValueError("slider_functions and slider_ranges must be the same length.")
        if slider_labels is not None:
            if len(slider_labels) != len(slider_functions):
                raise ValueError("slider_functions and slider_labels must be the same length.")

        # Allow the user to specify the button func:
        self.button_func = button_func or []
        if button_func and not button_labels:
            button_labels = [a.__name__ for a in button_func]
        self.button_labels = button_labels or []
        self.num_buttons = len(self.button_func)

        if fig is None:
            fig = plt.figure()
        self.fig = fig

        self.data = data
        self.interval = interval
        self.if_colorbar = colorbar
        self.imshow_kwargs = kwargs

        self.num_sliders = len(slider_functions)
        self.slider_functions = slider_functions
        self.slider_ranges = slider_ranges
        self.slider_labels = slider_labels or [''] * len(slider_functions)

        # Set active slider
        self.active_slider = 0

        # Set a blank timer
        self.timer = None

        # Set up axes
        self.axes = None
        self._make_axes_grid()
        self._add_widgets()
        self._set_active_slider(0)

        # Set the current axes to the main axes so commands like plt.ylabel() work.
        #
        # Only do this if figure has a manager, so directly constructed figures
        # (ie. via matplotlib.figure.Figure()) work.
        if hasattr(self.fig.canvas, "manager") and self.fig.canvas.manager is not None:
            plt.sca(self.axes)

        # Do Plot
        self.im = self.plot_start_image(self.axes)

        # Connect fig events
        self._connect_fig_events()

    def label_slider(self, i, label):
        """
        Change the slider label.

        Parameters
        ----------
        i: `int`
            The index of the slider to change (0 is bottom).
        label: `str`
            The label to set.
        """
        self.sliders[i]._slider.label.set_text(label)

    def get_animation(self, axes=None, slider=0, startframe=0, endframe=None,
                      stepframe=1, **kwargs):
        """
        Return a `~matplotlib.animation.FuncAnimation` instance for the
        selected slider.

        This will allow easy saving of the animation to a file.

        Parameters
        ----------
        axes: `matplotlib.axes.Axes`, optional
            The `matplotlib.axes.Axes` to animate. Defaults to `None`, in which
            case the Axes associated with this animator are used. Passing a
            custom Axes can be useful if you want to create the animation on
            a custom figure that is not the figure set up by this Animator.
        slider: `int`, optional
            The slider to animate along. Defaults to 0.
        startframe: `int`, optional
            The frame to start the animation. Defaults to 0.
        endframe: `int`, optional
            The frame to end the animation (exclusive). Defaults to `None`,
            in which case the last entry of the slider's range is used.
        stepframe: `int`, optional
            The step between frames. Defaults to 1.

        Returns
        -------
        `matplotlib.animation.FuncAnimation`
            The animation driving the slider function of the selected slider.

        Notes
        -----
        Extra keywords are passed to `matplotlib.animation.FuncAnimation`.
        """
        if axes is None:
            axes = self.axes
        anim_fig = axes.get_figure()

        if endframe is None:
            # Use the last entry, matching how ``_add_widgets`` derives the
            # slider maximum, so array-valued ranges are handled consistently.
            endframe = self.slider_ranges[slider][-1]

        im = self.plot_start_image(axes)

        anim_kwargs = {'frames': list(range(startframe, endframe, stepframe)),
                       'fargs': [im, self.sliders[slider]._slider]}
        anim_kwargs.update(kwargs)

        ani = mplanim.FuncAnimation(anim_fig, self.slider_functions[slider],
                                    **anim_kwargs)
        return ani

    @abc.abstractmethod
    def plot_start_image(self, ax):
        """
        This method creates the initial image on the `matplotlib.axes.Axes`.

        .. warning::
            This method needs to be implemented in subclasses.

        Parameters
        ----------
        ax: `matplotlib.axes.Axes`
            This is the axes on which to plot the image.

        Returns
        -------
        `matplotlib.artist.Artist`
            The matplotlib object to be animated, this is usually either a
            `~matplotlib.image.AxesImage` object, or a
            `~matplotlib.lines.Line2D`.
        """
        raise NotImplementedError("Please define this function.")

    def _connect_fig_events(self):
        """Register the mouse-click and key-press handlers on the figure canvas."""
        self.fig.canvas.mpl_connect('button_press_event', self._mouse_click)
        self.fig.canvas.mpl_connect('key_press_event', self._key_press)

    def _add_colorbar(self, im):
        """Draw a colorbar for ``im`` in the reserved colorbar axes ``self.cax``."""
        self.colorbar = self.fig.colorbar(im, self.cax)

    # =============================================================================
    # Figure event callback functions
    # =============================================================================
    def _mouse_click(self, event):
        """Make the slider whose axes were clicked the active slider."""
        if event.inaxes in self.sliders:
            slider = self.sliders.index(event.inaxes)
            self._set_active_slider(slider)

    def _key_press(self, event):
        """Dispatch the keyboard shortcuts listed in the class docstring."""
        if event.key == 'left':
            self._previous(self.sliders[self.active_slider]._slider)
        elif event.key == 'right':
            self._step(self.sliders[self.active_slider]._slider)
        elif event.key == 'up':
            self._set_active_slider((self.active_slider+1) % self.num_sliders)
        elif event.key == 'down':
            self._set_active_slider((self.active_slider-1) % self.num_sliders)
        elif event.key == 'p':
            self._click_slider_button(event, self.slider_buttons[self.active_slider]._button,
                                      self.sliders[self.active_slider]._slider)

    # =============================================================================
    # Active Slider methods
    # =============================================================================
    def _set_active_slider(self, ind):
        """Move the highlight from the current active slider to slider ``ind``."""
        self._dehighlight_slider(self.active_slider)
        self._highlight_slider(ind)
        self.active_slider = ind

    def _highlight_slider(self, ind):
        """Thicken the border of slider ``ind`` and redraw the canvas."""
        ax = self.sliders[ind]
        for spine in ax.spines.values():
            spine.set_linewidth(2.0)
        self.fig.canvas.draw()

    def _dehighlight_slider(self, ind):
        """Restore the normal border width of slider ``ind`` and redraw the canvas."""
        ax = self.sliders[ind]
        for spine in ax.spines.values():
            spine.set_linewidth(1.0)
        self.fig.canvas.draw()

    # =============================================================================
    # Build the figure and place the widgets
    # =============================================================================
    def _setup_main_axes(self):
        """
        Allow replacement of main axes by subclassing.

        This method must set the ``axes`` attribute.
        """
        if self.axes is None:
            self.axes = self.fig.add_subplot(111)

    def _make_axes_grid(self):
        """
        Split the main axes into a grid with room for the sliders, the
        play buttons, the user buttons and (optionally) a colorbar.

        Sets ``self.divider``, ``self.button_ny`` and, when a colorbar was
        requested, ``self.cax``.
        """
        self._setup_main_axes()

        # Split up the current axes so there is space for start & stop buttons
        self.divider = make_axes_locatable(self.axes)
        pad = 0.01  # Padding between axes
        pad_size = Size.Fraction(pad, Size.AxesX(self.axes))
        large_pad_size = Size.Fraction(0.1, Size.AxesY(self.axes))

        # Always lay out at least 7 button columns so few buttons stay small.
        button_grid = max((7, self.num_buttons))

        # Size of one slider row (in y) and one button column (in x),
        # expressed as fractions of the main axes.
        ysize = Size.Fraction((1.-2.*pad)/15., Size.AxesY(self.axes))
        xsize = Size.Fraction((1.-2.*pad)/button_grid, Size.AxesX(self.axes))

        # Set up the grid, with cells for padding between the useful cells.
        if self.num_buttons > 0:
            horiz = [xsize] + [pad_size, xsize]*(button_grid-1)
            vert = [ysize, pad_size] * self.num_sliders + \
                   [large_pad_size, large_pad_size, Size.AxesY(self.axes)]
        else:
            vert = [ysize, large_pad_size] * self.num_sliders + \
                   [large_pad_size, Size.AxesY(self.axes)]
            horiz = [Size.Fraction(0.1, Size.AxesX(self.axes))] + \
                    [Size.Fraction(0.05, Size.AxesX(self.axes))] + \
                    [Size.Fraction(0.65, Size.AxesX(self.axes))] + \
                    [Size.Fraction(0.1, Size.AxesX(self.axes))] + \
                    [Size.Fraction(0.1, Size.AxesX(self.axes))]

        self.divider.set_horizontal(horiz)
        self.divider.set_vertical(vert)
        self.button_ny = len(vert) - 3

        # If we are going to add a colorbar it'll need an axis next to the plot
        if self.if_colorbar:
            nx1 = -3
            self.cax = self.fig.add_axes((0., 0., 0.141, 1.))
            locator = self.divider.new_locator(nx=-2, ny=len(vert)-1, nx1=-1)
            self.cax.set_axes_locator(locator)
        else:
            # Main figure spans all horiz and is in the top cell in vert.
            nx1 = -1

        self.axes.set_axes_locator(
            self.divider.new_locator(nx=0, ny=len(vert)-1, nx1=nx1))

    def _add_widgets(self):
        """
        Create the user buttons, the sliders and the per-slider play buttons.

        Populates ``self.buttons``, ``self.sliders`` and ``self.slider_buttons``
        with the axes holding each widget; the widget itself is attached to its
        axes as ``_button`` or ``_slider``.
        """
        self.buttons = []
        for i in range(0, self.num_buttons):
            x = i * 2
            # Give every axes a slightly different initial rectangle: if two axes
            # are created directly on top of one another the divider doesn't work.
            self.buttons.append(self.fig.add_axes((0., 0., 0.+i/10., 1.)))
            locator = self.divider.new_locator(nx=x, ny=self.button_ny)
            self.buttons[-1].set_axes_locator(locator)
            self.buttons[-1]._button = widgets.Button(self.buttons[-1],
                                                      self.button_labels[i])
            # Bind the animator as the first argument of the user callback.
            self.buttons[-1]._button.on_clicked(partial(self.button_func[i], self))

        self.sliders = []
        self.slider_buttons = []
        for i in range(self.num_sliders):
            y = i * 2
            self.sliders.append(self.fig.add_axes((0., 0., 0.01+i/10., 1.)))
            if self.num_buttons == 0:
                nx1 = 3
            else:
                nx1 = -2
            locator = self.divider.new_locator(nx=2, ny=y, nx1=nx1)
            self.sliders[-1].set_axes_locator(locator)
            self.sliders[-1].text(0.5, 0.5, self.slider_labels[i],
                                  transform=self.sliders[-1].transAxes,
                                  horizontalalignment="center",
                                  verticalalignment="center")

            # The slider runs from the first range entry to the last entry minus one.
            sframe = widgets.Slider(self.sliders[-1], "",
                                    self.slider_ranges[i][0],
                                    self.slider_ranges[i][-1]-1,
                                    valinit=self.slider_ranges[i][0],
                                    valfmt='%4.1f')
            sframe.on_changed(partial(self._slider_changed, slider=sframe))
            sframe.slider_ind = i
            sframe.cval = sframe.val
            self.sliders[-1]._slider = sframe

            self.slider_buttons.append(
                self.fig.add_axes((0., 0., 0.05+y/10., 1.)))
            locator = self.divider.new_locator(nx=0, ny=y)
            self.slider_buttons[-1].set_axes_locator(locator)
            butt = widgets.Button(self.slider_buttons[-1], ">")
            butt.on_clicked(partial(self._click_slider_button, button=butt, slider=sframe))
            butt.clicked = False
            self.slider_buttons[-1]._button = butt

    # =============================================================================
    # Widget callbacks
    # =============================================================================
    def _slider_changed(self, val, slider):
        """Call the update function registered for ``slider`` with the new value."""
        self.slider_functions[slider.slider_ind](val, self.im, slider)

    def _click_slider_button(self, event, button, slider):
        """Toggle play/pause for ``slider`` and update the button label to match."""
        self._set_active_slider(slider.slider_ind)
        if button.clicked:
            self._stop_play(event)
            button.clicked = False
            button.label.set_text(">")
        else:
            button.clicked = True
            self._start_play(event, button, slider)
            button.label.set_text("||")
        self.fig.canvas.draw()

    def _start_play(self, event, button, slider):
        """Start a timer stepping ``slider`` every ``self.interval`` ms, unless one exists."""
        if not self.timer:
            self.timer = self.fig.canvas.new_timer()
            self.timer.interval = self.interval
            self.timer.add_callback(self._step, slider)
            self.timer.start()

    def _stop_play(self, event):
        """Stop and discard the play timer, if one is running."""
        if self.timer:
            self.timer.remove_callback(self._step)
            # Stop the timer explicitly rather than leaving it to keep firing
            # after the reference to it has been dropped.
            self.timer.stop()
            self.timer = None

    def _step(self, slider):
        """Advance ``slider`` by one, wrapping to its minimum after the maximum."""
        s = slider
        if s.val >= s.valmax:
            s.set_val(s.valmin)
        else:
            s.set_val(s.val+1)
        self.fig.canvas.draw()

    def _previous(self, slider):
        """Move ``slider`` back by one, wrapping to its maximum before the minimum."""
        s = slider
        if s.val <= s.valmin:
            s.set_val(s.valmax)
        else:
            s.set_val(s.val-1)
        self.fig.canvas.draw()
[docs]
class ArrayAnimator(BaseFuncAnimator, metaclass=abc.ABCMeta):
    """
    Create a Matplotlib backend independent data explorer.

    The following keyboard shortcuts are defined in the viewer:

    * 'left': previous step on active slider.
    * 'right': next step on active slider.
    * 'up': change the active slider up one.
    * 'down': change the active slider down one.
    * 'p': play/pause active slider.

    This viewer can have user defined buttons added by specifying the labels
    and functions called when those buttons are clicked as keyword arguments.

    Parameters
    ----------
    data: `numpy.ndarray`
        The data to be visualized.
    image_axes: `list`, optional
        A list of the axes order that make up the image.
        Defaults to the last two axes.
    axis_ranges: `list` of physical coordinates for the `numpy.ndarray`, optional
        Defaults to `None` and array indices will be used for all axes.
        The `list` should contain one element for each axis of the `numpy.ndarray`.
        For the image axes a ``[min, max]`` pair should be specified which is
        used to build the plot extent.
        For the slider axes a ``[min, max]`` pair can be specified, or an array
        of pixel edges (length of the axis + 1), or a callable mapping a pixel
        index to a world value.

    Notes
    -----
    Extra keywords are passed to `~sunpy.visualization.animator.BaseFuncAnimator`.

    This constructor reads ``self.naxis`` and ``self.num_sliders``. Nothing in
    this class sets them beforehand, so when a subclass has not provided them
    they are derived from the shape of ``data`` and from ``image_axes``.
    """

    def __init__(self, data, image_axes=(-2, -1), axis_ranges=None, **kwargs):
        # Derive the number of axes from the data unless a subclass already set it.
        if not hasattr(self, "naxis"):
            self.naxis = len(data.shape)

        all_axes = list(range(self.naxis))
        # Handle negative indexes
        self.image_axes = [all_axes[i] for i in image_axes]

        slider_axes = list(range(self.naxis))
        for x in self.image_axes:
            slider_axes.remove(x)

        # Every non-image axis gets a slider unless a subclass said otherwise.
        if not hasattr(self, "num_sliders"):
            self.num_sliders = len(slider_axes)

        if len(slider_axes) != self.num_sliders:
            raise ValueError("Number of sliders doesn't match the number of slider axes.")
        self.slider_axes = slider_axes

        # Verify that combined slider_axes and image_axes make all axes
        ax = self.slider_axes + self.image_axes
        ax.sort()
        if ax != list(range(self.naxis)):
            raise ValueError("Number of image and slider axes do not match total number of axes.")

        self.axis_ranges, self.extent = self._sanitize_axis_ranges(axis_ranges, data.shape)

        # create data slice: full slices on image axes, index 0 on slider axes
        self.frame_slice = [slice(None)] * self.naxis
        for i in self.slider_axes:
            self.frame_slice[i] = 0

        # Extra user sliders are appended after the ones created for the array axes.
        slider_functions = list(kwargs.pop("slider_functions", []))
        slider_ranges = list(kwargs.pop("slider_ranges", []))
        base_kwargs = {
            'slider_functions': ([self.update_plot] * self.num_sliders) + slider_functions,
            'slider_ranges': [[0, dim] for dim in np.array(data.shape)[self.slider_axes]] + slider_ranges
        }
        self.num_sliders = len(base_kwargs["slider_functions"])
        base_kwargs.update(kwargs)
        super().__init__(data, **base_kwargs)

    @property
    def frame_index(self):
        """
        A tuple version of ``frame_slice`` to be used when indexing arrays.
        """
        return tuple(self.frame_slice)

    def label_slider(self, i, label):
        """
        Change the Slider label.

        Parameters
        ----------
        i: `int`
            The index of the slider to change (0 is bottom).
        label: `str`
            The label to set.
        """
        self.sliders[i]._slider.label.set_text(label)

    def _sanitize_axis_ranges(self, axis_ranges, data_shape):
        """
        This method takes the various allowed values of ``axis_ranges`` and
        returns them in a standardized way for the rest of the class to use.

        The outputted axis range describes the physical coordinates of the
        array axes.

        The allowed values of axis range is either `None` or a `list`.
        If ``axis_ranges`` is `None` then all axis are assumed to be not scaled and
        will use array indices.

        Where ``axis_ranges`` is a `list` it must have the same length as the number
        of axis as the array and each element must be one of the following:

        * `None`: use array indices (pixel centers on integer values).
        * ``[min, max]``: used as the extent for image axes; for slider axes it is
          converted to equally spaced pixel centers between ``min`` and ``max``.
        * 1-D array of pixel edges (length of the axis + 1): converted to pixel centers.
        * N-D array of pixel edges, with the same number of dimensions as the data
          and length of the axis + 1 along that axis: converted to pixel centers.
        * For slider axes: a function which maps from pixel to world value.

        Parameters
        ----------
        axis_ranges: `list` or `None`
            The user supplied axis ranges. The input is not modified.
        data_shape: `tuple`
            The shape of the data array.

        Returns
        -------
        axis_ranges: `list`
            One entry per axis: an array for image axes and a callable mapping
            a pixel index to a world value for slider axes.
        extent: `list`
            The plot extent built from the image axes, iterated in reverse order.

        Raises
        ------
        ValueError
            If the number of entries does not match the number of axes, or an
            entry is in none of the formats above.
        """
        ndim = len(data_shape)
        # If no axis range at all make it all [min,max] pairs
        if axis_ranges is None:
            axis_ranges = [None] * ndim
        else:
            # Work on a shallow copy so the caller's list is not modified in place.
            axis_ranges = list(axis_ranges)

        # need the same number of axis ranges as axes
        if len(axis_ranges) != ndim:
            raise ValueError("Length of axis_ranges must equal number of axes")

        # Define error message for incompatible axis_range input.
        def incompatible_axis_ranges_error_message(j):
            return (f"Unrecognized format for {j}th entry in axis_ranges: {axis_ranges[j]}. "
                    "axis_ranges must be None, a ``[min, max]`` pair, or "
                    "an array-like giving the edge values of each pixel, "
                    "i.e. length must be length of axis + 1.")

        # If axis range not given, define a function such that the range goes
        # from -0.5 to number of pixels-0.5. Thus, the center of the pixels
        # along the axis will correspond to integer values.
        def none_image_axis_range(j):
            return [-0.5, data_shape[j]-0.5]

        # For each axis validate and translate the axis_ranges. For image axes,
        # also determine the plot extent. To do this, iterate through image and slider
        # axes separately. Iterate through image axes in reverse order
        # because numpy is in y-x and extent is x-y.
        extent = []
        for i in self.image_axes[::-1]:
            if axis_ranges[i] is None:
                extent = extent + none_image_axis_range(i)
                axis_ranges[i] = np.array(none_image_axis_range(i))
            else:
                # Depending on length of axis_ranges[i], leave unchanged,
                # convert to pixel centers or raise an error due to incompatible format.
                axis_ranges[i] = np.asarray(axis_ranges[i])
                if len(axis_ranges[i]) == 2:
                    # Set extent.
                    extent += [axis_ranges[i][0], axis_ranges[i][-1]]
                elif axis_ranges[i].ndim == 1 and len(axis_ranges[i]) == data_shape[i]+1:
                    # If array of individual pixel edges supplied, first set extent
                    # from first and last pixel edge, then convert axis_ranges to pixel centers.
                    # The reason that pixel edges are required as input rather than centers
                    # is so that the plot extent can be derived from axis_ranges (above)
                    # and APIs using both [min, max] pair and manual definition of each pixel
                    # values can be unambiguously and simultaneously supported.
                    extent += [axis_ranges[i][0], axis_ranges[i][-1]]
                    axis_ranges[i] = edges_to_centers_nd(axis_ranges[i], 0)
                elif axis_ranges[i].ndim == ndim and axis_ranges[i].shape[i] == data_shape[i]+1:
                    extent += [axis_ranges[i].min(), axis_ranges[i].max()]
                    axis_ranges[i] = edges_to_centers_nd(axis_ranges[i], i)
                else:
                    raise ValueError(incompatible_axis_ranges_error_message(i))

        # For each slider axis validate and translate the axis_ranges.
        def get_pixel_to_world_callable(array):
            def pixel_to_world(pixel):
                return array[pixel]
            return pixel_to_world

        for sidx in self.slider_axes:
            if axis_ranges[sidx] is None:
                # If axis range not supplied, set pixel center values as integers starting at 0.
                axis_ranges[sidx] = get_pixel_to_world_callable(np.arange(data_shape[sidx]))
            elif not callable(axis_ranges[sidx]):
                axis_ranges[sidx] = np.array(axis_ranges[sidx])
                if len(axis_ranges[sidx]) == 2:
                    # If axis range given as a min, max pair, derive the center of each pixel
                    # assuming they are equally spaced.
                    axis_ranges[sidx] = np.linspace(axis_ranges[sidx][0], axis_ranges[sidx][-1],
                                                    data_shape[sidx]+1)
                    # The linspace result is 1-D, so the edges lie along axis 0
                    # regardless of which data axis this slider controls.
                    axis_ranges[sidx] = get_pixel_to_world_callable(
                        edges_to_centers_nd(axis_ranges[sidx], 0))
                elif axis_ranges[sidx].ndim == 1 and len(axis_ranges[sidx]) == data_shape[sidx]+1:
                    # If axis range given as 1D array of pixel edges (i.e. axis is independent),
                    # derive pixel centers.
                    axis_ranges[sidx] = get_pixel_to_world_callable(
                        edges_to_centers_nd(np.asarray(axis_ranges[sidx]), 0))
                elif axis_ranges[sidx].ndim == ndim and axis_ranges[sidx].shape[sidx] == data_shape[sidx]+1:
                    # If axis range given as array of pixel edges the same shape as
                    # the data array (i.e. axis is not independent), derive pixel centers
                    # along the slider's own axis.
                    axis_ranges[sidx] = get_pixel_to_world_callable(
                        edges_to_centers_nd(np.asarray(axis_ranges[sidx]), sidx))
                else:
                    raise ValueError(incompatible_axis_ranges_error_message(sidx))

        return axis_ranges, extent

    @abc.abstractmethod
    def plot_start_image(self, ax):
        """
        Abstract method for plotting first slice of array.

        Must exist here but be defined in subclass.
        """

    @abc.abstractmethod
    def update_plot(self, val, artist, slider):
        """
        Abstract method for updating the plot.

        Must be overridden in a subclass. The implementation provided here
        only updates the text shown next to the slider so that it reflects the
        world value in ``axis_ranges`` for the current pixel index.

        Parameters
        ----------
        val: `float`
            The new slider value; it is truncated to an integer pixel index.
        artist: `matplotlib.artist.Artist`
            The artist being animated (not used by this base implementation).
        slider: `matplotlib.widgets.Slider`
            The slider that changed; its ``slider_ind`` attribute selects the
            data axis.
        """
        ind = int(val)
        ax_ind = self.slider_axes[slider.slider_ind]
        # Update slider label to reflect real world values in axis_ranges.
        label = self.axis_ranges[ax_ind](ind)
        if units is not None and isinstance(label, units.Quantity):
            slider.valtext.set_text(label.to_string(precision=5,
                                                    format='latex',
                                                    subfmt='inline'))
        elif isinstance(label, str):
            slider.valtext.set_text(label)
        else:
            slider.valtext.set_text(f"{label:10.2f}")
def edges_to_centers_nd(axis_range, edges_axis):
    """
    Compute pixel centers from an ND array of pixel edges along one axis.

    Each center is the midpoint of two neighbouring edges, so the result is
    one element shorter than the input along ``edges_axis``.

    Parameters
    ----------
    axis_range: `numpy.ndarray`
        Array of pixel edges.
    edges_axis: `int`
        Index of axis along which centers are to be calculated.

    Returns
    -------
    `numpy.ndarray`
        Array of pixel centers.
    """
    n_edges = axis_range.shape[edges_axis]

    def take_along_edges_axis(selection):
        # Select everything on all axes except ``edges_axis``.
        index = [slice(None) for _ in range(axis_range.ndim)]
        index[edges_axis] = selection
        return axis_range[tuple(index)]

    upper_edges = take_along_edges_axis(slice(1, n_edges))
    lower_edges = take_along_edges_axis(slice(0, -1))
    half_width = (upper_edges - lower_edges) / 2
    return half_width + lower_edges