from __future__ import absolute_import, division, print_function
import copy
import numpy as np
from matplotlib.patches import Ellipse, Polygon, Rectangle, Path as mplPath
from matplotlib.transforms import IdentityTransform, blended_transform_factory
from glue.core.exceptions import UndefinedROI
from glue.utils import points_inside_poly
# Tell NumPy to ignore every category of floating-point error (divide,
# overflow, underflow, invalid). This runs as an import-time side effect.
np.seterr(all='ignore')
# Names exported by ``from <this module> import *``.
__all__ = ['Roi', 'RectangularROI', 'CircularROI', 'PolygonalROI',
'AbstractMplRoi', 'MplRectangularROI', 'MplCircularROI',
'MplPolygonalROI', 'MplXRangeROI', 'MplYRangeROI',
'XRangeROI', 'RangeROI', 'YRangeROI', 'VertexROIBase',
'CategoricalROI']
# Colour used as both ``edgecolor`` and ``facecolor`` in the default
# ``plot_opts`` of the Matplotlib ROI tools below.
PATCH_COLOR = '#FFFF00'
# Value compared against ``event.key`` by the Matplotlib ROI tools: when it
# matches, a drag moves ("scrubs") the existing ROI instead of starting a
# new one.
SCRUBBING_KEY = 'control'
def aspect_ratio(axes):
    """Return the height / width ratio of a box that spans one data unit
    along both x and y.

    :param axes: A Matplotlib Axes-like object exposing ``get_position``,
                 ``figure``, ``get_xlim`` and ``get_ylim``
    :returns: (axes height / axes width) / (y extent) * (x extent)
    """
    fig = axes.figure
    # Physical size of the axes: fractional extent times the figure size.
    box_width = axes.get_position().width * fig.get_figwidth()
    box_height = axes.get_position().height * fig.get_figheight()

    x_lo, x_hi = axes.get_xlim()
    y_lo, y_hi = axes.get_ylim()
    x_span = x_hi - x_lo
    y_span = y_hi - y_lo

    return box_height / box_width / y_span * x_span
def data_to_norm(axes, x, y):
    """Map data coordinates through ``axes.transData`` and then through the
    inverse of ``axes.transAxes``.

    :param axes: Object providing ``transData`` and ``transAxes`` transforms
    :param x: Scalar or array-like of x data values (flattened)
    :param y: Scalar or array-like of y data values (flattened)
    :returns: The transformed points, one (x, y) row per input pair
    """
    flat_x = np.asarray(x).ravel()
    flat_y = np.asarray(y).ravel()
    points = np.column_stack((flat_x, flat_y))
    in_pixels = axes.transData.transform(points)
    return axes.transAxes.inverted().transform(in_pixels)
def data_to_pixel(axes, x, y):
    """Map data coordinates through ``axes.transData``.

    :param axes: Object providing a ``transData`` transform
    :param x: Scalar or array-like of x data values (flattened)
    :param y: Scalar or array-like of y data values (flattened)
    :returns: The transformed points, one (x, y) row per input pair
    """
    flat_x = np.asarray(x).ravel()
    flat_y = np.asarray(y).ravel()
    points = np.column_stack((flat_x, flat_y))
    return axes.transData.transform(points)
def pixel_to_data(axes, x, y):
    """Map coordinates through the inverse of ``axes.transData``.

    :param axes: Object providing a ``transData`` transform
    :param x: Scalar or array-like of x values (flattened)
    :param y: Scalar or array-like of y values (flattened)
    :returns: The transformed points, one (x, y) row per input pair
    """
    flat_x = np.asarray(x).ravel()
    flat_y = np.asarray(y).ravel()
    points = np.column_stack((flat_x, flat_y))
    to_data = axes.transData.inverted()
    return to_data.transform(points)
class Roi(object):  # pragma: no cover
    """
    Base class for a geometrical 2D region of interest.

    Glue uses Roi's to represent user-drawn regions on plots. Every method
    defined here except :meth:`copy` raises ``NotImplementedError``;
    subclasses supply the real behaviour, most importantly ``contains``,
    which tests whether a collection of 2D points lies inside the region.
    """

    def contains(self, x, y):
        """Test each (x, y) pair for membership in the ROI.

        :param x: Array of X locations
        :param y: Array of Y locations

        :returns: A Boolean array, where each element is True
                  if the corresponding (x,y) tuple is inside the Roi.

        :raises: UndefinedROI exception if not defined
        """
        raise NotImplementedError

    def center(self):
        """Return the (x,y) coordinates of the ROI center"""
        raise NotImplementedError

    def move_to(self, x, y):
        """Translate the ROI to a center of (x, y)"""
        raise NotImplementedError

    def defined(self):
        """Return whether or not the ROI is properly defined"""
        raise NotImplementedError

    def to_polygon(self):
        """Return a tuple of x and y points, approximating the ROI
        as a polygon."""
        raise NotImplementedError

    def copy(self):
        """
        Return a clone of the ROI, produced with :func:`copy.copy`
        (a shallow copy).
        """
        clone = copy.copy(self)
        return clone
class PointROI(Roi):
    """
    A ROI that records a single (x, y) position.

    ``contains`` always reports ``False``.
    """

    def __init__(self, x=None, y=None):
        self.x, self.y = x, y

    def contains(self, x, y):
        """Always return ``False``, regardless of the inputs."""
        return False

    def move_to(self, x, y):
        """Set the stored position to (x, y)."""
        self.x, self.y = x, y

    def defined(self):
        """Return whether both coordinates are finite numbers.

        Coordinates that ``np.isfinite`` rejects with a ``TypeError``
        (e.g. ``None``) make the ROI undefined.
        """
        try:
            finite = np.isfinite([self.x, self.y])
        except TypeError:
            return False
        return finite.all()

    def center(self):
        """Return the stored (x, y) position."""
        return self.x, self.y

    def reset(self):
        """Forget the stored position."""
        self.x = None
        self.y = None
class RectangularROI(Roi):
    """
    A 2D rectangular region of interest, described by its
    ``xmin`` / ``xmax`` / ``ymin`` / ``ymax`` limits.
    """

    def __init__(self, xmin=None, xmax=None, ymin=None, ymax=None):
        super(RectangularROI, self).__init__()
        self.xmin, self.xmax = xmin, xmax
        self.ymin, self.ymax = ymin, ymax

    def __str__(self):
        if not self.defined():
            return "Undefined Rectangular ROI"
        limits = (self.xmin, self.xmax, self.ymin, self.ymax)
        return "x=[%0.3f, %0.3f], y=[%0.3f, %0.3f]" % limits

    def center(self):
        """Return the (x, y) midpoint of the rectangle."""
        half_width = self.width() / 2
        half_height = self.height() / 2
        return self.xmin + half_width, self.ymin + half_height

    def move_to(self, x, y):
        """Translate the rectangle so that its center lands on (x, y)."""
        current_x, current_y = self.center()
        shift_x = x - current_x
        shift_y = y - current_y
        self.xmin += shift_x
        self.xmax += shift_x
        self.ymin += shift_y
        self.ymax += shift_y

    def transpose(self, copy=True):
        """Swap the x and y limits.

        :param copy: If true, leave this ROI untouched and return a
                     transposed clone; otherwise transpose in place and
                     return None.
        """
        if not copy:
            self.xmin, self.ymin = self.ymin, self.xmin
            self.xmax, self.ymax = self.ymax, self.xmax
            return None
        flipped = self.copy()
        flipped.xmin, flipped.xmax = self.ymin, self.ymax
        flipped.ymin, flipped.ymax = self.xmin, self.xmax
        return flipped

    def corner(self):
        """Return the (xmin, ymin) corner."""
        return (self.xmin, self.ymin)

    def width(self):
        """Return ``xmax - xmin``."""
        return self.xmax - self.xmin

    def height(self):
        """Return ``ymax - ymin``."""
        return self.ymax - self.ymin

    def contains(self, x, y):
        """
        Test whether a set of (x,y) points falls strictly within
        the region of interest (points on the edges are excluded)

        :param x: A scalar or numpy array of x points
        :param y: A scalar or numpy array of y points

        *Returns*

            True/False for each (x,y) point, for whether it
            falls within the ROI

        :raises: UndefinedROI if the ROI is not defined
        """
        if not self.defined():
            raise UndefinedROI

        inside_x = (x > self.xmin) & (x < self.xmax)
        inside_y = (y > self.ymin) & (y < self.ymax)
        return inside_x & inside_y

    def update_limits(self, xmin, ymin, xmax, ymax):
        """
        Update the limits of the rectangle. Each pair is sorted, so the
        smaller value always becomes the ``*min`` limit.
        """
        self.xmin, self.xmax = min(xmin, xmax), max(xmin, xmax)
        self.ymin, self.ymax = min(ymin, ymax), max(ymin, ymax)

    def reset(self):
        """
        Reset the rectangular region by clearing all four limits.
        """
        self.xmin = self.xmax = None
        self.ymin = self.ymax = None

    def defined(self):
        """Return True if ``xmin`` has been set (only ``xmin`` is checked)."""
        return self.xmin is not None

    def to_polygon(self):
        """Return (x, y) lists tracing the closed rectangle outline, or two
        empty lists if the ROI is undefined."""
        if not self.defined():
            return [], []
        xs = [self.xmin, self.xmax, self.xmax, self.xmin, self.xmin]
        ys = [self.ymin, self.ymin, self.ymax, self.ymax, self.ymin]
        return xs, ys

    def __gluestate__(self, context):
        return {'xmin': self.xmin, 'xmax': self.xmax,
                'ymin': self.ymin, 'ymax': self.ymax}

    @classmethod
    def __setgluestate__(cls, rec, context):
        limits = dict((key, rec[key])
                      for key in ('xmin', 'xmax', 'ymin', 'ymax'))
        return cls(**limits)
class RangeROI(Roi):
    """
    A ROI that selects an open interval ``(min, max)`` along a single
    axis, chosen by ``orientation``.
    """

    def __init__(self, orientation, min=None, max=None):
        """:param orientation: 'x' or 'y'. Sets which axis to range"""
        super(RangeROI, self).__init__()
        self.min, self.max = min, max
        self.ori = orientation

    @property
    def ori(self):
        """The axis being ranged: 'x' or 'y'."""
        return self._ori

    @ori.setter
    def ori(self, value):
        if value not in set('xy'):
            raise ValueError("Orientation must be one of 'x', 'y'")
        self._ori = value

    def __str__(self):
        if not self.defined():
            return "Undefined %s" % type(self).__name__
        return "%0.3f < %s < %0.3f" % (self.min, self.ori, self.max)

    def range(self):
        """Return the ``(min, max)`` pair."""
        return self.min, self.max

    def center(self):
        """Return the midpoint of the range (a single number)."""
        total = self.min + self.max
        return total / 2

    def set_range(self, lo, hi):
        """Set the range limits to ``lo`` and ``hi`` (stored as given)."""
        self.min = lo
        self.max = hi

    def move_to(self, center):
        """Shift the range so that its midpoint lands on ``center``."""
        shift = center - self.center()
        self.min += shift
        self.max += shift

    def contains(self, x, y):
        """Test which points fall strictly inside the range along the
        ROI's orientation; the other coordinate is ignored.

        :raises: UndefinedROI if the ROI is not defined
        """
        if not self.defined():
            raise UndefinedROI()
        if self.ori == 'x':
            values = x
        else:
            values = y
        return (values > self.min) & (values < self.max)

    def reset(self):
        """Clear both limits."""
        self.min = self.max = None

    def defined(self):
        """Return True only if both limits are set."""
        return not (self.min is None or self.max is None)

    def to_polygon(self):
        """Return (x, y) lists for a closed rectangle spanning the range
        along the ranged axis and +/-1e100 along the other one, or two
        empty lists if the ROI is undefined."""
        if not self.defined():
            return [], []
        ranged = [self.min, self.max, self.max, self.min, self.min]
        unbounded = [-1e100, -1e100, 1e100, 1e100, -1e100]
        if self.ori == 'x':
            return ranged, unbounded
        return unbounded, ranged

    def __gluestate__(self, context):
        return {'ori': self.ori, 'min': self.min, 'max': self.max}

    @classmethod
    def __setgluestate__(cls, rec, context):
        orientation = rec['ori']
        return cls(orientation, min=rec['min'], max=rec['max'])
class XRangeROI(RangeROI):
    """A :class:`RangeROI` whose orientation is fixed to 'x'."""

    def __init__(self, min=None, max=None):
        super(XRangeROI, self).__init__(orientation='x', min=min, max=max)
class YRangeROI(RangeROI):
    """A :class:`RangeROI` whose orientation is fixed to 'y'."""

    def __init__(self, min=None, max=None):
        super(YRangeROI, self).__init__(orientation='y', min=min, max=max)
class CircularROI(Roi):
    """
    A 2D circular region of interest, described by a center
    (``xc``, ``yc``) and a ``radius``.
    """

    def __init__(self, xc=None, yc=None, radius=None):
        super(CircularROI, self).__init__()
        self.xc, self.yc = xc, yc
        self.radius = radius

    def contains(self, x, y):
        """
        Test whether a set of (x,y) points falls strictly within
        the region of interest

        :param x: A list of x points
        :param y: A list of y points

        *Returns*

            True/False for each (x,y) point, for whether it
            falls within the ROI

        :raises: UndefinedROI if the ROI is not defined
        """
        if not self.defined():
            raise UndefinedROI

        # Only convert non-arrays, so ndarray subclasses pass through as-is.
        if not isinstance(x, np.ndarray):
            x = np.asarray(x)
        if not isinstance(y, np.ndarray):
            y = np.asarray(y)

        dist_sq = (x - self.xc) ** 2 + (y - self.yc) ** 2
        return dist_sq < self.radius ** 2

    def set_center(self, x, y):
        """
        Set the center of the circular region
        """
        self.xc, self.yc = x, y

    def set_radius(self, radius):
        """
        Set the radius of the circular region
        """
        self.radius = radius

    def get_center(self):
        """Return the center as an (xc, yc) tuple."""
        return self.xc, self.yc

    def get_radius(self):
        """Return the radius."""
        return self.radius

    def reset(self):
        """
        Reset the circular region: clear the center and set the radius to 0.
        """
        self.xc = self.yc = None
        self.radius = 0.

    def defined(self):
        """ Returns True if the ROI is defined, i.e. the center
        coordinates and the radius are all set """
        return not (self.xc is None or
                    self.yc is None or
                    self.radius is None)

    def to_polygon(self):
        """ Returns x, y, where each holds 20 points sampled around the
        circle (two empty lists if the ROI is undefined) """
        if not self.defined():
            return [], []
        angles = np.linspace(0, 2 * np.pi, num=20)
        xs = self.xc + self.radius * np.cos(angles)
        ys = self.yc + self.radius * np.sin(angles)
        return xs, ys

    def __gluestate__(self, context):
        return {'xc': self.xc, 'yc': self.yc, 'radius': self.radius}

    @classmethod
    def __setgluestate__(cls, rec, context):
        center_x, center_y = rec['xc'], rec['yc']
        return cls(xc=center_x, yc=center_y, radius=rec['radius'])
class VertexROIBase(Roi):
    """
    Base class for ROIs described by a list of vertices, stored as the
    parallel sequences ``vx`` and ``vy``.
    """

    def __init__(self, vx=None, vy=None):
        """
        :param vx: initial x vertices
        :type vx: list
        :param vy: initial y vertices
        :type vy: list
        """
        super(VertexROIBase, self).__init__()
        # The passed-in sequences are stored as-is (not copied).
        self.vx = [] if vx is None else vx
        self.vy = [] if vy is None else vy

    def add_point(self, x, y):
        """
        Add another vertex to the ROI

        :param x: The x coordinate
        :param y: The y coordinate
        """
        self.vx.append(x)
        self.vy.append(y)

    def reset(self):
        """
        Reset the vertex list.
        """
        self.vx, self.vy = [], []

    def replace_last_point(self, x, y):
        """Overwrite the most recently added vertex with (x, y).
        Does nothing when there are no vertices."""
        if len(self.vx) > 0:
            self.vx[-1], self.vy[-1] = x, y

    def remove_point(self, x, y, thresh=None):
        """Remove the vertex closest to a reference (xy) point

        :param x: The x coordinate of the reference point
        :param y: The y coordinate of the reference point

        :param thresh: An optional threshold. If present, the vertex
                closest to (x,y) will only be removed if the distance
                is less than thresh

        """
        if len(self.vx) == 0:
            return

        # Squared distance from the reference point to every vertex
        sq_dist = [(x - px) ** 2 + (y - py) ** 2
                   for px, py in zip(self.vx, self.vy)]
        positions = range(len(sq_dist))
        nearest = min(positions, key=lambda pos: sq_dist[pos])

        too_far = thresh is not None and sq_dist[nearest] > (thresh ** 2)
        if too_far:
            return

        # Build fresh lists rather than deleting in place
        self.vx = [self.vx[pos] for pos in positions if pos != nearest]
        self.vy = [self.vy[pos] for pos in positions if pos != nearest]

    def defined(self):
        """Return True if the ROI has at least one vertex."""
        return len(self.vx) > 0

    def to_polygon(self):
        """Return the stored ``(vx, vy)`` vertex sequences."""
        return self.vx, self.vy

    def __gluestate__(self, context):
        xs = np.asarray(self.vx).tolist()
        ys = np.asarray(self.vy).tolist()
        return {'vx': xs, 'vy': ys}

    @classmethod
    def __setgluestate__(cls, rec, context):
        xs, ys = rec['vx'], rec['vy']
        return cls(vx=xs, vy=ys)
class PolygonalROI(VertexROIBase):
    """
    A class to define 2D polygonal regions-of-interest
    """

    def __str__(self):
        vertices = ','.join(['(%s, %s)' % (px, py)
                             for px, py in zip(self.vx, self.vy)])
        return 'Polygonal ROI (' + vertices + ')'

    def contains(self, x, y):
        """
        Test whether a set of (x,y) points falls within
        the region of interest

        :param x: A list of x points
        :param y: A list of y points

        *Returns*

            An array of True/False values with the shape of ``x``, for
            whether each (x,y) point falls within the ROI. Points with a
            non-finite coordinate are always reported as outside.

        :raises: UndefinedROI if the ROI has no vertices
        """
        if not self.defined():
            raise UndefinedROI
        if not isinstance(x, np.ndarray):
            x = np.asarray(x)
        if not isinstance(y, np.ndarray):
            y = np.asarray(y)

        inside = points_inside_poly(x.flat, y.flat, self.vx, self.vy)
        finite = np.isfinite(x.flat) & np.isfinite(y.flat)
        inside[~finite] = False
        # Give the flat result the shape of the input
        inside.shape = x.shape
        return inside

    def move_to(self, xdelta, ydelta):
        """Shift every vertex by (xdelta, ydelta).

        Note that the arguments are offsets, not a new center position.
        """
        self.vx = [px + xdelta for px in self.vx]
        self.vy = [py + ydelta for py in self.vy]
class Path(VertexROIBase):
    """A vertex-based ROI that only adds a ``Path (...)`` string form to
    :class:`VertexROIBase`."""

    def __str__(self):
        vertices = ','.join(['(%s, %s)' % (px, py)
                             for px, py in zip(self.vx, self.vy)])
        return 'Path (' + vertices + ')'
class AbstractMplRoi(object):  # pragma: no cover
    """ Base class for objects which use
    Matplotlib user events to edit/display ROIs

    Subclasses must provide ``_roi_factory`` and ``_sync_patch`` and the
    ``start_selection`` / ``update_selection`` / ``finalize_selection``
    event handlers.
    """

    def __init__(self, axes):
        """
        :param axes: The Matplotlib Axes object to draw to
        """
        self._axes = axes
        self._roi = self._roi_factory()
        self._previous_roi = None
        self._mid_selection = self._scrubbing = False

    def _draw(self):
        """Redraw the canvas that owns the axes."""
        canvas = self._axes.figure.canvas
        canvas.draw()

    def _roi_factory(self):
        """Create the ROI object edited by this tool (subclass hook)."""
        raise NotImplementedError

    def roi(self):
        """Return a copy of the ROI currently held by the tool."""
        return self._roi.copy()

    def reset(self, include_roi=True):
        """Clear the selection/scrubbing flags and re-sync the patch.

        :param include_roi: If true, the underlying ROI is reset as well
        """
        self._mid_selection = self._scrubbing = False
        if include_roi:
            self._roi.reset()
        self._sync_patch()

    def active(self):
        """Return whether a selection is currently in progress."""
        return self._mid_selection

    def start_selection(self, event):
        raise NotImplementedError

    def update_selection(self, event):
        raise NotImplementedError

    def finalize_selection(self, event):
        raise NotImplementedError

    def abort_selection(self, event):
        """Cancel the selection: if one is in progress, bring back the ROI
        saved by ``_roi_store``; then clear the flags without resetting
        the ROI."""
        in_progress = self._mid_selection
        if in_progress:
            self._roi_restore()
        self.reset(include_roi=False)

    def _sync_patch(self):
        """Update the displayed patch from the ROI (subclass hook)."""
        raise NotImplementedError

    def _roi_store(self):
        """Remember a copy of the current ROI for a later restore."""
        snapshot = self._roi.copy()
        self._previous_roi = snapshot

    def _roi_restore(self):
        """Replace the current ROI with the remembered one."""
        self._roi = self._previous_roi
class MplPickROI(AbstractMplRoi):
    """
    Tool that tracks a single point: every selection event copies the
    event's data coordinates into a :class:`PointROI`. Nothing is drawn.
    """

    def _draw(self):
        # Nothing to redraw: this tool has no patch.
        pass

    def _roi_factory(self):
        return PointROI()

    def start_selection(self, event):
        self._roi.x, self._roi.y = event.xdata, event.ydata

    def update_selection(self, event):
        self._roi.x, self._roi.y = event.xdata, event.ydata

    def finalize_selection(self, event):
        self._roi.x, self._roi.y = event.xdata, event.ydata

    def _sync_patch(self):
        # No patch to keep in sync.
        pass
class MplRectangularROI(AbstractMplRoi):
    """
    Matplotlib tool that edits a :class:`RectangularROI` and renders it
    on a plot as a ``Rectangle`` patch

    *Attributes*:

        plot_opts:

                   Dictionary instance
                   A dictionary of plot keywords that are passed to
                   the patch representing the ROI. These control
                   the visual properties of the ROI
    """

    def __init__(self, axes):
        """
        :param axes: A matplotlib Axes object to attach the graphical ROI to
        """
        AbstractMplRoi.__init__(self, axes)

        # Data coordinates at which the current drag started
        self._xi = self._yi = None

        self.plot_opts = dict(edgecolor=PATCH_COLOR,
                              facecolor=PATCH_COLOR,
                              alpha=0.3)

        patch = Rectangle((0., 0.), 1., 1.)
        patch.set_zorder(100)
        self._patch = patch
        self._setup_patch()

    def _setup_patch(self):
        """Attach the (initially hidden) patch to the axes."""
        self._axes.add_patch(self._patch)
        self._patch.set_visible(False)
        self._sync_patch()

    def _roi_factory(self):
        return RectangularROI()

    def start_selection(self, event):
        """Begin drawing a new rectangle, or begin moving the existing one
        when the scrubbing key is held.

        Returns False when the event is ignored: it is outside the axes,
        or a scrub was requested but the ROI is undefined or does not
        contain the event position.
        """
        if event.inaxes != self._axes:
            return False

        scrub_requested = event.key == SCRUBBING_KEY
        if scrub_requested and not (
                self._roi.defined() and
                self._roi.contains(event.xdata, event.ydata)):
            return False

        self._roi_store()
        self._xi, self._yi = event.xdata, event.ydata

        if scrub_requested:
            self._scrubbing = True
            # Remember where the center was when the drag began
            self._cx, self._cy = self._roi.center()
        else:
            self.reset()
            self._roi.update_limits(event.xdata, event.ydata,
                                    event.xdata, event.ydata)

        self._mid_selection = True
        self._sync_patch()

    def update_selection(self, event):
        """Follow the drag: move the rectangle when scrubbing, otherwise
        stretch it between the drag start and the event position.

        Returns False when the event is ignored.
        """
        if not self._mid_selection:
            return False
        if event.inaxes != self._axes:
            return False
        if event.key == SCRUBBING_KEY and not self._roi.defined():
            return False

        if self._scrubbing:
            new_cx = self._cx + event.xdata - self._xi
            new_cy = self._cy + event.ydata - self._yi
            self._roi.move_to(new_cx, new_cy)
        else:
            x_lo = min(event.xdata, self._xi)
            y_lo = min(event.ydata, self._yi)
            x_hi = max(event.xdata, self._xi)
            y_hi = max(event.ydata, self._yi)
            self._roi.update_limits(x_lo, y_lo, x_hi, y_hi)

        self._sync_patch()

    def finalize_selection(self, event):
        """End the selection and hide the patch."""
        self._mid_selection = self._scrubbing = False
        self._patch.set_visible(False)
        self._draw()

    def _sync_patch(self):
        """Make the patch mirror the ROI (hidden if undefined), then redraw."""
        if not self._roi.defined():
            self._patch.set_visible(False)
        else:
            origin = self._roi.corner()
            span_x = self._roi.width()
            span_y = self._roi.height()
            self._patch.set_xy(origin)
            self._patch.set_width(span_x)
            self._patch.set_height(span_y)
            self._patch.set(**self.plot_opts)
            self._patch.set_visible(True)
        self._draw()

    def __str__(self):
        return "MPL Rectangle: %s" % self._patch
class MplXRangeROI(AbstractMplRoi):
    """
    Matplotlib tool that edits an :class:`XRangeROI` and renders it as a
    ``Rectangle`` patch whose x coordinates use the axes' data transform
    and whose y coordinates use the axes transform (spanning 0 to 1).
    """

    def __init__(self, axes):
        """
        :param axes: A matplotlib Axes object to attach the graphical ROI to
        """
        AbstractMplRoi.__init__(self, axes)

        # x data coordinate at which the current drag started
        self._xi = None

        self.plot_opts = dict(edgecolor=PATCH_COLOR,
                              facecolor=PATCH_COLOR,
                              alpha=0.3)

        blended = blended_transform_factory(self._axes.transData,
                                            self._axes.transAxes)
        patch = Rectangle((0., 0.), 1., 1., transform=blended)
        patch.set_zorder(100)
        self._patch = patch
        self._setup_patch()

    def _setup_patch(self):
        """Attach the (initially hidden) patch to the axes."""
        self._axes.add_patch(self._patch)
        self._patch.set_visible(False)
        self._sync_patch()

    def _roi_factory(self):
        return XRangeROI()

    def start_selection(self, event):
        """Begin drawing a new range, or begin moving the existing one
        when the scrubbing key is held.

        Returns False when the event is ignored: it is outside the axes,
        or a scrub was requested but the ROI is undefined or does not
        contain the event position.
        """
        if event.inaxes != self._axes:
            return False

        scrub_requested = event.key == SCRUBBING_KEY
        if scrub_requested and not (
                self._roi.defined() and
                self._roi.contains(event.xdata, event.ydata)):
            return False

        self._roi_store()

        if scrub_requested:
            self._scrubbing = True
            self._dx = event.xdata - self._roi.center()
        else:
            self.reset()
            self._roi.set_range(event.xdata, event.xdata)
            self._xi = event.xdata

        self._mid_selection = True
        self._sync_patch()

    def update_selection(self, event):
        """Follow the drag: move the range when scrubbing, otherwise
        stretch it between the drag start and the event position.

        Returns False when the event is ignored.
        """
        if not self._mid_selection:
            return False
        if event.inaxes != self._axes:
            return False
        if event.key == SCRUBBING_KEY and not self._roi.defined():
            return False

        if self._scrubbing:
            self._roi.move_to(event.xdata + self._dx)
        else:
            lower = min(event.xdata, self._xi)
            upper = max(event.xdata, self._xi)
            self._roi.set_range(lower, upper)

        self._sync_patch()

    def finalize_selection(self, event):
        """End the selection and hide the patch."""
        self._mid_selection = self._scrubbing = False
        self._patch.set_visible(False)
        self._draw()

    def _sync_patch(self):
        """Make the patch mirror the ROI (hidden if undefined), then redraw."""
        if not self._roi.defined():
            self._patch.set_visible(False)
        else:
            lower, upper = self._roi.range()
            self._patch.set_xy((lower, 0))
            self._patch.set_width(upper - lower)
            self._patch.set_height(1)
            self._patch.set(**self.plot_opts)
            self._patch.set_visible(True)
        self._draw()
class MplYRangeROI(AbstractMplRoi):
    """
    Matplotlib tool that edits a :class:`YRangeROI` and renders it as a
    ``Rectangle`` patch whose y coordinates use the axes' data transform
    and whose x coordinates use the axes transform (spanning 0 to 1).
    """

    def __init__(self, axes):
        """
        :param axes: A matplotlib Axes object to attach the graphical ROI to
        """
        AbstractMplRoi.__init__(self, axes)

        # y data coordinate at which the current drag started (the
        # attribute keeps the ``_xi`` name used by the x-range tool)
        self._xi = None

        self.plot_opts = dict(edgecolor=PATCH_COLOR,
                              facecolor=PATCH_COLOR,
                              alpha=0.3)

        blended = blended_transform_factory(self._axes.transAxes,
                                            self._axes.transData)
        patch = Rectangle((0., 0.), 1., 1., transform=blended)
        patch.set_zorder(100)
        self._patch = patch
        self._setup_patch()

    def _setup_patch(self):
        """Attach the (initially hidden) patch to the axes."""
        self._axes.add_patch(self._patch)
        self._patch.set_visible(False)
        self._sync_patch()

    def _roi_factory(self):
        return YRangeROI()

    def start_selection(self, event):
        """Begin drawing a new range, or begin moving the existing one
        when the scrubbing key is held.

        Returns False when the event is ignored: it is outside the axes,
        or a scrub was requested but the ROI is undefined or does not
        contain the event position.
        """
        if event.inaxes != self._axes:
            return False

        scrub_requested = event.key == SCRUBBING_KEY
        if scrub_requested and not (
                self._roi.defined() and
                self._roi.contains(event.xdata, event.ydata)):
            return False

        self._roi_store()

        if scrub_requested:
            self._scrubbing = True
            self._dy = event.ydata - self._roi.center()
        else:
            self.reset()
            self._roi.set_range(event.ydata, event.ydata)
            self._xi = event.ydata

        self._mid_selection = True
        self._sync_patch()

    def update_selection(self, event):
        """Follow the drag: move the range when scrubbing, otherwise
        stretch it between the drag start and the event position.

        Returns False when the event is ignored.
        """
        if not self._mid_selection:
            return False
        if event.inaxes != self._axes:
            return False
        if event.key == SCRUBBING_KEY and not self._roi.defined():
            return False

        if self._scrubbing:
            self._roi.move_to(event.ydata + self._dy)
        else:
            lower = min(event.ydata, self._xi)
            upper = max(event.ydata, self._xi)
            self._roi.set_range(lower, upper)

        self._sync_patch()

    def finalize_selection(self, event):
        """End the selection and hide the patch."""
        self._mid_selection = self._scrubbing = False
        self._patch.set_visible(False)
        self._draw()

    def _sync_patch(self):
        """Make the patch mirror the ROI (hidden if undefined), then redraw."""
        if not self._roi.defined():
            self._patch.set_visible(False)
        else:
            lower, upper = self._roi.range()
            self._patch.set_xy((0, lower))
            self._patch.set_height(upper - lower)
            self._patch.set_width(1)
            self._patch.set(**self.plot_opts)
            self._patch.set_visible(True)
        self._draw()
class MplCircularROI(AbstractMplRoi):
    """
    Class to display / edit circular ROIs using matplotlib

    Since circles on the screen may not be circles in the data
    (due, e.g., to logarithmic scalings on the axes), the
    ultimate ROI that is created is a polygonal ROI. Internally the
    circle is tracked in the coordinates produced by ``data_to_pixel``
    and only converted back to data coordinates in :meth:`roi`.

    :param plot_opts:

               A dictionary of plot keywords that are passed to
               the patch representing the ROI. These control
               the visual properties of the ROI
    """

    def __init__(self, axes):
        """
        :param axes: A matplotlib Axes object to attach the graphical ROI to
        """
        AbstractMplRoi.__init__(self, axes)
        self.plot_opts = dict(edgecolor=PATCH_COLOR,
                              facecolor=PATCH_COLOR,
                              alpha=0.3)

        # Position (from data_to_pixel) at which the current drag started
        self._xi = self._yi = None

        self._setup_patch()

    def _setup_patch(self):
        """Create the (initially hidden) ellipse patch and add it to the
        axes. The patch uses an identity transform."""
        patch = Ellipse((0., 0.), transform=IdentityTransform(),
                        width=0., height=0.,)
        self._patch = patch
        patch.set_zorder(100)
        patch.set(**self.plot_opts)
        self._axes.add_patch(patch)
        patch.set_visible(False)
        self._sync_patch()

    def _roi_factory(self):
        return CircularROI()

    def _sync_patch(self):
        """Make the patch mirror the ROI (hidden if undefined), then redraw."""
        # Update geometry
        if self._roi.defined():
            center = self._roi.get_center()
            radius = self._roi.get_radius()
            self._patch.center = center
            self._patch.width = 2. * radius
            self._patch.height = 2. * radius
            self._patch.set_visible(True)
        else:
            self._patch.set_visible(False)

        # Update appearance
        self._patch.set(**self.plot_opts)

        # Refresh
        self._axes.figure.canvas.draw()

    def start_selection(self, event):
        """Begin drawing a new circle, or begin moving the existing one
        when the scrubbing key is held.

        Returns False when the event is ignored: it is outside the axes,
        or a scrub was requested but the ROI is undefined or does not
        contain the event position.
        """
        if event.inaxes != self._axes:
            return False

        pixel = data_to_pixel(self._axes, [event.xdata], [event.ydata])
        px, py = pixel[0, 0], pixel[0, 1]

        scrub_requested = event.key == SCRUBBING_KEY
        if scrub_requested and not (
                self._roi.defined() and self._roi.contains(px, py)):
            return False

        self._roi_store()

        if scrub_requested:
            self._scrubbing = True
            # Offset from the grab point to the circle center
            center_x, center_y = self._roi.get_center()
            self._dx = center_x - px
            self._dy = center_y - py
        else:
            self.reset()
            self._roi.set_center(px, py)
            self._roi.set_radius(0.)
            self._xi, self._yi = px, py

        self._mid_selection = True
        self._sync_patch()

    def update_selection(self, event):
        """Follow the drag: move the circle when scrubbing, otherwise set
        its radius to the distance from the drag start.

        Returns False when the event is ignored.
        """
        if not self._mid_selection:
            return False
        if event.inaxes != self._axes:
            return False

        pixel = data_to_pixel(self._axes, [event.xdata], [event.ydata])
        px, py = pixel[0, 0], pixel[0, 1]

        if event.key == SCRUBBING_KEY and not self._roi.defined():
            return False

        if self._scrubbing:
            self._roi.set_center(px + self._dx, py + self._dy)
        else:
            offset_x = px - self._xi
            offset_y = py - self._yi
            self._roi.set_radius(np.hypot(offset_x, offset_y))

        self._sync_patch()

    def roi(self):
        """Return the circle as a :class:`PolygonalROI` in data
        coordinates, built from 200 points sampled around the circle
        (an empty polygon if the ROI is undefined)."""
        if not self._roi.defined():
            return PolygonalROI()

        angles = np.linspace(0, 2 * np.pi, num=200)
        center = self._roi.get_center()
        radius = self._roi.get_radius()
        ring_x = center[0] + radius * np.cos(angles)
        ring_y = center[1] + radius * np.sin(angles)

        in_data = pixel_to_data(self._axes, ring_x, ring_y)
        vertices_x = in_data[:, 0].ravel().tolist()
        vertices_y = in_data[:, 1].ravel().tolist()
        return PolygonalROI(vertices_x, vertices_y)

    def finalize_selection(self, event):
        """End the selection and hide the patch."""
        self._mid_selection = self._scrubbing = False
        self._patch.set_visible(False)
        self._axes.figure.canvas.draw()
class MplPolygonalROI(AbstractMplRoi):
    """
    Defines and displays polygonal ROIs on matplotlib plots

    Attributes:

        plot_opts: Dictionary instance
                   A dictionary of plot keywords that are passed to
                   the patch representing the ROI. These control
                   the visual properties of the ROI
    """

    def __init__(self, axes):
        """
        :param axes: A matplotlib Axes object to attach the graphical ROI to
        """
        AbstractMplRoi.__init__(self, axes)
        self.plot_opts = dict(edgecolor=PATCH_COLOR,
                              facecolor=PATCH_COLOR,
                              alpha=0.3)
        self._setup_patch()

    def _setup_patch(self):
        """Create the (initially hidden) polygon patch and add it to the
        axes."""
        placeholder = np.array(list(zip([0, 1], [0, 1])))
        patch = Polygon(placeholder)
        self._patch = patch
        patch.set_zorder(100)
        patch.set(**self.plot_opts)
        self._axes.add_patch(patch)
        patch.set_visible(False)
        self._sync_patch()

    def _roi_factory(self):
        return PolygonalROI()

    def _sync_patch(self):
        """Make the patch mirror the ROI (hidden if undefined), then redraw."""
        # Update geometry
        if self._roi.defined():
            xs, ys = self._roi.to_polygon()
            # Repeat the first vertex at the end to close the outline
            closed = list(zip(xs + [xs[0]], ys + [ys[0]]))
            self._patch.set_xy(closed)
            self._patch.set_visible(True)
        else:
            self._patch.set_visible(False)

        # Update appearance
        self._patch.set(**self.plot_opts)

        # Refresh
        self._axes.figure.canvas.draw()

    def start_selection(self, event):
        """Begin drawing a new polygon, or begin moving the existing one
        when the scrubbing key is held.

        Returns False when the event is ignored: it is outside the axes,
        or a scrub was requested but the ROI is undefined or does not
        contain the event position.
        """
        if event.inaxes != self._axes:
            return False

        scrub_requested = event.key == SCRUBBING_KEY
        if scrub_requested and not (
                self._roi.defined() and
                self._roi.contains(event.xdata, event.ydata)):
            return False

        self._roi_store()

        if scrub_requested:
            self._scrubbing = True
            # Last position seen, used to compute per-event offsets
            self._cx, self._cy = event.xdata, event.ydata
        else:
            self.reset()
            self._roi.add_point(event.xdata, event.ydata)

        self._mid_selection = True
        self._sync_patch()

    def update_selection(self, event):
        """Follow the drag: shift the polygon by the movement since the
        previous event when scrubbing, otherwise append the event position
        as a new vertex.

        Returns False when the event is ignored.
        """
        if not self._mid_selection:
            return False
        if event.inaxes != self._axes:
            return False
        if event.key == SCRUBBING_KEY and not self._roi.defined():
            return False

        if self._scrubbing:
            step_x = event.xdata - self._cx
            step_y = event.ydata - self._cy
            self._roi.move_to(step_x, step_y)
            self._cx, self._cy = event.xdata, event.ydata
        else:
            self._roi.add_point(event.xdata, event.ydata)

        self._sync_patch()

    def finalize_selection(self, event):
        """End the selection and hide the patch."""
        self._mid_selection = self._scrubbing = False
        self._patch.set_visible(False)
        self._axes.figure.canvas.draw()
class MplPathROI(MplPolygonalROI):
    """
    Variant of :class:`MplPolygonalROI` that displays the vertices as an
    open matplotlib path instead of a closed polygon.

    The patch is rebuilt from scratch on every ``_sync_patch`` call;
    ``self._patch`` is ``None`` whenever nothing is displayed.
    """

    def roi_factory(self):
        # NOTE(review): AbstractMplRoi.__init__ calls ``_roi_factory`` (with
        # a leading underscore), so this method does not override it and the
        # tool actually edits the PolygonalROI created by MplPolygonalROI.
        # Confirm whether a Path ROI was intended before renaming: Path
        # defines neither ``contains`` nor ``move_to``, which scrubbing uses.
        return Path()

    def _setup_patch(self):
        """Start without any patch; one is created on demand."""
        self._patch = None

    def _sync_patch(self):
        """Replace the displayed patch with one matching the current ROI.

        Any existing patch is removed first. If the ROI is undefined,
        nothing else happens (no new patch, no redraw).
        """
        # Imported locally so this block does not depend on the module-level
        # matplotlib aliases; the names previously used here (``MplPath``
        # and ``PatchPath``) were never defined and raised NameError.
        from matplotlib.patches import PathPatch
        from matplotlib.path import Path as MplPath

        if self._patch is not None:
            self._patch.remove()
            self._patch = None

        # Update geometry
        if not self._roi.defined():
            return

        x, y = self._roi.to_polygon()
        outline = MplPath(np.column_stack((x, y)))
        self._patch = PathPatch(outline)
        # The patch has to belong to the axes, both to be rendered and so
        # that ``remove()`` works on the next sync.
        self._axes.add_patch(self._patch)
        self._patch.set_visible(True)

        # Update appearance
        self._patch.set(**self.plot_opts)

        # Refresh
        self._axes.figure.canvas.draw()

    def finalize_selection(self, event):
        """End the selection and hide the patch, if there is one."""
        # Clear the scrubbing flag too, as every other tool does.
        self._scrubbing = False
        self._mid_selection = False
        if self._patch is not None:
            self._patch.set_visible(False)
        self._axes.figure.canvas.draw()
class CategoricalROI(Roi):
    """
    A ROI abstraction to represent selections of categorical data.

    The selection is stored in ``categories``: ``None`` while undefined,
    otherwise the sorted unique values produced by ``np.unique``.
    """

    def __init__(self, categories=None):
        if categories is None:
            self.categories = None
        else:
            self.update_categories(categories)

    def to_polygon(self):
        """ Just not possible.
        """
        raise NotImplementedError

    def _categorical_helper(self, indata):
        """
        A helper function to do the rigamaroll of getting categorical data.

        :param indata: Any type of input data
        :return: The best guess at the categorical data associated with indata
        """
        try:
            if indata.categorical:
                return indata._categorical_data
            else:
                return indata[:]
        except AttributeError:
            # Plain sequences / arrays have no ``categorical`` attribute
            return np.asarray(indata)

    def contains(self, x, y):
        """
        Test whether a set categorical elements fall within
        the region of interest

        :param x: Any array-like object of categories
                  (includes CategoricalComponenets)
        :param y: Unused but required for compatibility

        *Returns*

            An array of True/False values, for whether each x value falls
            within the ROI. If the ROI holds no categories at all, every
            value is False.

        :raises: UndefinedROI if the ROI is not defined
        """
        if not self.defined():
            raise UndefinedROI

        check = self._categorical_helper(x)

        # With no categories the index clamp below would produce -1 and
        # indexing the empty array would fail; nothing can match anyway.
        if len(self.categories) == 0:
            return np.zeros(np.shape(check), dtype=bool)

        # ``categories`` is sorted, so locate each value by binary search
        # (clamped to a valid index) and compare it with what is there.
        index = np.minimum(np.searchsorted(self.categories, check),
                           len(self.categories) - 1)
        return self.categories[index] == check

    def update_categories(self, categories):
        """Replace the selection with the unique values of ``categories``."""
        self.categories = np.unique(self._categorical_helper(categories))

    def defined(self):
        """ Returns True if the ROI is defined """
        return self.categories is not None

    def reset(self):
        """Clear the selection, leaving the ROI undefined."""
        self.categories = None

    @staticmethod
    def from_range(cat_comp, lo, hi):
        """
        Utility function to help construct the Roi from a range.

        :param cat_comp: Object whose ``categories`` attribute is sliced to
                         build the selection
        :param lo: lower bound of the range
        :param hi: upper bound of the range
        :return: CategoricalROI object
        """

        # Convert lo and hi to integers. Note that if lo or hi are negative,
        # which can happen if the user zoomed out, we need to reset them to
        # zero otherwise they will have strange effects when slicing the
        # categories. Note that we used ceil for lo, because if lo is 0.9
        # then we should only select 1 and above.
        lo = np.intp(np.ceil(lo) if lo > 0 else 0)
        hi = np.intp(np.ceil(hi) if hi > 0 else 0)

        roi = CategoricalROI()
        cat_data = cat_comp.categories
        roi.update_categories(cat_data[lo:hi])
        return roi