Source code for glue.viewers.image.state

import uuid
from collections import defaultdict

from glue.core import BaseData
from glue.config import colormaps
from glue.viewers.matplotlib.state import (MatplotlibDataViewerState,
                                           MatplotlibLayerState,
                                           DeferredDrawCallbackProperty as DDCProperty,
                                           DeferredDrawSelectionCallbackProperty as DDSCProperty)
from glue.core.state_objects import StateAttributeLimitsHelper
from glue.utils import defer_draw, view_shape
from echo import delay_callback
from glue.core.data_combo_helper import ManualDataComboHelper, ComponentIDComboHelper
from glue.core.exceptions import IncompatibleDataException

__all__ = ['ImageViewerState', 'ImageLayerState', 'ImageSubsetLayerState', 'AggregateSlice']


def get_sliced_data_maker(x_axis=None, y_axis=None, slices=None, data=None,
                          target_cid=None, reference_data=None, transpose=False):
    """
    Convenience function for use in exported Python scripts.
    """

    if reference_data is None:
        reference_data = data

    def get_array(bounds=None):

        full_bounds = list(slices)
        full_bounds[y_axis] = bounds[0]
        full_bounds[x_axis] = bounds[1]

        if isinstance(data, BaseData):
            array = data.compute_fixed_resolution_buffer(full_bounds, target_data=reference_data,
                                                         target_cid=target_cid, broadcast=False)
        else:
            array = data.data.compute_fixed_resolution_buffer(full_bounds, target_data=reference_data,
                                                              subset_state=data.subset_state, broadcast=False)

        if transpose:
            array = array.transpose()

        return array

    return get_array


[docs]class AggregateSlice(object): def __init__(self, slice=None, center=None, function=None): self.slice = slice self.center = center self.function = function def __gluestate__(self, context): state = dict(slice=context.do(self.slice), center=self.center, function=context.do(self.function)) return state @classmethod def __setgluestate__(cls, rec, context): return cls(slice=context.object(rec['slice']), center=rec['center'], function=context.object(rec['function']))
[docs]class ImageViewerState(MatplotlibDataViewerState): """ A state class that includes all the attributes for an image viewer. """ x_att = DDCProperty(docstring='The component ID giving the pixel component ' 'shown on the x axis') y_att = DDCProperty(docstring='The component ID giving the pixel component ' 'shown on the y axis') x_att_world = DDSCProperty(docstring='The component ID giving the world component ' 'shown on the x axis', default_index=-1) y_att_world = DDSCProperty(docstring='The component ID giving the world component ' 'shown on the y axis', default_index=-2) aspect = DDSCProperty(0, docstring='Whether to enforce square pixels (``equal``) ' 'or fill the axes (``auto``)') reference_data = DDSCProperty(docstring='The dataset that is used to define the ' 'available pixel/world components, and ' 'which defines the coordinate frame in ' 'which the images are shown') slices = DDCProperty(docstring='The current slice along all dimensions') color_mode = DDSCProperty(0, docstring='Whether each layer can have ' 'its own colormap (``Colormaps``) or ' 'whether each layer is assigned ' 'a single color (``One color per layer``)') dpi = DDCProperty(72, docstring='The resolution (in dots per inch) of density maps, if present') def __init__(self, **kwargs): super(ImageViewerState, self).__init__() self.limits_cache = {} # NOTE: we don't need to use StateAttributeLimitsHelper here because # we can simply call reset_limits below when x/y attributes change. # Using StateAttributeLimitsHelper makes things a lot slower. self.ref_data_helper = ManualDataComboHelper(self, 'reference_data') self.xw_att_helper = ComponentIDComboHelper(self, 'x_att_world', numeric=False, datetime=False, categorical=False) self.yw_att_helper = ComponentIDComboHelper(self, 'y_att_world', numeric=False, datetime=False, categorical=False) self.add_callback('reference_data', self._reference_data_changed, priority=1000) self.add_callback('layers', self._layers_changed, priority=1000) self.add_callback('x_att', self._on_xatt_change, priority=500) self.add_callback('y_att', self._on_yatt_change, priority=500) self.add_callback('x_att_world', self._on_xatt_world_change, priority=1000) self.add_callback('y_att_world', self._on_yatt_world_change, priority=1000) aspect_display = {'equal': 'Square Pixels', 'auto': 'Automatic'} ImageViewerState.aspect.set_choices(self, ['equal', 'auto']) ImageViewerState.aspect.set_display_func(self, aspect_display.get) ImageViewerState.color_mode.set_choices(self, ['Colormaps', 'One color per layer']) self.update_from_dict(kwargs)
[docs] def reset_limits(self): if self.reference_data is None or self.x_att is None or self.y_att is None: return nx = self.reference_data.shape[self.x_att.axis] ny = self.reference_data.shape[self.y_att.axis] with delay_callback(self, 'x_min', 'x_max', 'y_min', 'y_max'): self.x_min = -0.5 self.x_max = nx - 0.5 self.y_min = -0.5 self.y_max = ny - 0.5 # We need to adjust the limits in here to avoid triggering all # the update events then changing the limits again. self._adjust_limits_aspect()
@property def _display_world(self): return getattr(self.reference_data, 'coords', None) is not None def _reference_data_changed(self, *args): # This signal can get emitted if just the choices but not the actual # reference data change, so we check here that the reference data has # actually changed if self.reference_data is not getattr(self, '_last_reference_data', None): self._last_reference_data = self.reference_data # Note that we deliberately use nested delay_callback here, because # we want to make sure that x_att_world and y_att_world both get # updated first, then x_att and y_att can be changed, before # subsequent events are fired. with delay_callback(self, 'x_att', 'y_att'): with delay_callback(self, 'x_att_world', 'y_att_world', 'slices'): if self._display_world: self.xw_att_helper.pixel_coord = False self.yw_att_helper.pixel_coord = False self.xw_att_helper.world_coord = True self.yw_att_helper.world_coord = True else: self.xw_att_helper.pixel_coord = True self.yw_att_helper.pixel_coord = True self.xw_att_helper.world_coord = False self.yw_att_helper.world_coord = False self._update_combo_att() self._set_default_slices() # We need to make sure that we update x_att and y_att # at the same time before any other callbacks get called, # so we do this here manually. self._on_xatt_world_change() self._on_yatt_world_change() def _layers_changed(self, *args): # The layers callback gets executed if anything in the layers changes, # but we only care about whether the actual set of 'layer' attributes # for all layers change. layers_data = self.layers_data layers_data_cache = getattr(self, '_layers_data_cache', []) if layers_data == layers_data_cache: return self._update_combo_ref_data() self._set_reference_data() self._update_syncing() self._layers_data_cache = layers_data def _update_syncing(self): # If there are multiple layers for a given dataset, we disable the # syncing by default. layer_state_by_data = defaultdict(list) for layer_state in self.layers: if isinstance(layer_state.layer, BaseData): layer_state_by_data[layer_state.layer].append(layer_state) for data, layer_states in layer_state_by_data.items(): if len(layer_states) > 1: for layer_state in layer_states: # Scatter layers don't have global_sync so we need to be # careful here and make sure we return a default value if getattr(layer_state, 'global_sync', False): layer_state.global_sync = False def _update_combo_ref_data(self): self.ref_data_helper.set_multiple_data(self.layers_data) def _update_combo_att(self): with delay_callback(self, 'x_att_world', 'y_att_world'): if self.reference_data is None: self.xw_att_helper.set_multiple_data([]) self.yw_att_helper.set_multiple_data([]) else: self.xw_att_helper.set_multiple_data([self.reference_data]) self.yw_att_helper.set_multiple_data([self.reference_data]) def _update_priority(self, name): if name == 'layers': return 3 elif name == 'reference_data': return 2 elif name.endswith(('_min', '_max')): return 0 else: return 1 @defer_draw def _on_xatt_change(self, *args): if self.x_att is not None: if self._display_world: self.x_att_world = self.reference_data.world_component_ids[self.x_att.axis] else: self.x_att_world = self.x_att @defer_draw def _on_yatt_change(self, *args): if self.y_att is not None: if self._display_world: self.y_att_world = self.reference_data.world_component_ids[self.y_att.axis] else: self.y_att_world = self.y_att @defer_draw def _on_xatt_world_change(self, *args): if self.x_att_world is not None: with delay_callback(self, 'y_att_world', 'x_att'): if self.x_att_world == self.y_att_world: if self._display_world: world_ids = self.reference_data.world_component_ids else: world_ids = self.reference_data.pixel_component_ids if self.x_att_world == world_ids[-1]: self.y_att_world = world_ids[-2] else: self.y_att_world = world_ids[-1] if self._display_world: index = self.reference_data.world_component_ids.index(self.x_att_world) self.x_att = self.reference_data.pixel_component_ids[index] else: self.x_att = self.x_att_world @defer_draw def _on_yatt_world_change(self, *args): if self.y_att_world is not None: with delay_callback(self, 'x_att_world', 'y_att'): if self.y_att_world == self.x_att_world: if self._display_world: world_ids = self.reference_data.world_component_ids else: world_ids = self.reference_data.pixel_component_ids if self.y_att_world == world_ids[-1]: self.x_att_world = world_ids[-2] else: self.x_att_world = world_ids[-1] if self._display_world: index = self.reference_data.world_component_ids.index(self.y_att_world) self.y_att = self.reference_data.pixel_component_ids[index] else: self.y_att = self.y_att_world def _set_reference_data(self): if self.reference_data is None: for layer in self.layers: if isinstance(layer.layer, BaseData): self.reference_data = layer.layer return def _set_default_slices(self): # Need to make sure this gets called immediately when reference_data is changed if self.reference_data is None: self.slices = () else: self.slices = (0,) * self.reference_data.ndim
[docs] @property def numpy_slice_aggregation_transpose(self): """ Returns slicing information usable by Numpy. This returns two objects: the first is an object that can be used to slice Numpy arrays and return a 2D array, and the second object is a boolean indicating whether to transpose the result. """ if self.reference_data is None: return None slices = [] agg_func = [] for i in range(self.reference_data.ndim): if i == self.x_att.axis or i == self.y_att.axis: slices.append(slice(None)) agg_func.append(None) else: if isinstance(self.slices[i], AggregateSlice): slices.append(self.slices[i].slice) agg_func.append(self.slices[i].function) else: slices.append(self.slices[i]) transpose = self.y_att.axis > self.x_att.axis return slices, agg_func, transpose
[docs] @property def wcsaxes_slice(self): """ Returns slicing information usable by WCSAxes. This returns an iterable of slices, and including ``'x'`` and ``'y'`` for the dimensions along which we are not slicing. """ if self.reference_data is None: return None slices = [] for i in range(self.reference_data.ndim): if i == self.x_att.axis: slices.append('x') elif i == self.y_att.axis: slices.append('y') else: if isinstance(self.slices[i], AggregateSlice): slices.append(self.slices[i].center) else: slices.append(self.slices[i]) return slices[::-1]
[docs] def flip_x(self): """ Flip the x_min/x_max limits. """ with delay_callback(self, 'x_min', 'x_max'): self.x_min, self.x_max = self.x_max, self.x_min
[docs] def flip_y(self): """ Flip the y_min/y_max limits. """ with delay_callback(self, 'y_min', 'y_max'): self.y_min, self.y_max = self.y_max, self.y_min
class BaseImageLayerState(MatplotlibLayerState): _viewer_callbacks_set = False _image_cache = None _pixel_cache = None def get_sliced_data_shape(self, view=None): if (self.viewer_state.reference_data is None or self.viewer_state.x_att is None or self.viewer_state.y_att is None): return None x_axis = self.viewer_state.x_att.axis y_axis = self.viewer_state.y_att.axis shape = self.viewer_state.reference_data.shape shape_slice = shape[y_axis], shape[x_axis] if view is None: return shape_slice else: return view_shape(shape_slice, view) def get_sliced_data(self, view=None, bounds=None): full_view, agg_func, transpose = self.viewer_state.numpy_slice_aggregation_transpose x_axis = self.viewer_state.x_att.axis y_axis = self.viewer_state.y_att.axis # For this method, we make use of Data.compute_fixed_resolution_buffer, # which requires us to specify bounds in the form (min, max, nsteps). # We also allow view to be passed here (which is a normal Numpy view) # and, if given, translate it to bounds. If neither are specified, # we behave as if view was [slice(None), slice(None)]. def slice_to_bound(slc, size): min, max, step = slc.indices(size) n = (max - min - 1) // step max = min + step * n return (min, max, n + 1) if bounds is None: # The view should be that which should just be applied to the data # slice, not to all the dimensions of the data - thus it should have at # most two dimensions if view is None: view = [slice(None), slice(None)] elif len(view) == 1: view = view + [slice(None)] elif len(view) > 2: raise ValueError('view should have at most two elements') full_view[x_axis] = view[1] full_view[y_axis] = view[0] else: full_view[x_axis] = bounds[1] full_view[y_axis] = bounds[0] for i in range(self.viewer_state.reference_data.ndim): if isinstance(full_view[i], slice): full_view[i] = slice_to_bound(full_view[i], self.viewer_state.reference_data.shape[i]) # We now get the fixed resolution buffer if isinstance(self.layer, BaseData): image = self.layer.compute_fixed_resolution_buffer(full_view, target_data=self.viewer_state.reference_data, target_cid=self.attribute, broadcast=False, cache_id=self.uuid) else: image = self.layer.data.compute_fixed_resolution_buffer(full_view, target_data=self.viewer_state.reference_data, subset_state=self.layer.subset_state, broadcast=False, cache_id=self.uuid) # We apply aggregation functions if needed if agg_func is None: if image.ndim != 2: raise IncompatibleDataException() else: if image.ndim != len(agg_func): raise ValueError("Sliced image dimensions ({0}) does not match " "aggregation function list ({1})" .format(image.ndim, len(agg_func))) for axis in range(image.ndim - 1, -1, -1): func = agg_func[axis] if func is not None: image = func(image, axis=axis) if image.ndim != 2: raise ValueError("Image after aggregation should have two dimensions") # And finally we transpose the data if the order of x/y is different # from the native order. if transpose: image = image.transpose() return image
[docs]class ImageLayerState(BaseImageLayerState): """ A state class that includes all the attributes for data layers in an image plot. """ attribute = DDSCProperty(docstring='The attribute shown in the layer') v_min = DDCProperty(docstring='The lower level shown') v_max = DDCProperty(docstring='The upper level shown') percentile = DDSCProperty(docstring='The percentile value used to ' 'automatically calculate levels') contrast = DDCProperty(1, docstring='The contrast of the layer') bias = DDCProperty(0.5, docstring='A constant value that is added to the ' 'layer before rendering') cmap = DDCProperty(docstring='The colormap used to render the layer') stretch = DDSCProperty(docstring='The stretch used to render the layer, ' 'which should be one of ``linear``, ' '``sqrt``, ``log``, or ``arcsinh``') global_sync = DDCProperty(False, docstring='Whether the color and transparency ' 'should be synced with the global ' 'color and transparency for the data') def __init__(self, layer=None, viewer_state=None, **kwargs): self.uuid = str(uuid.uuid4()) super(ImageLayerState, self).__init__(layer=layer, viewer_state=viewer_state) self.attribute_lim_helper = StateAttributeLimitsHelper(self, attribute='attribute', percentile='percentile', lower='v_min', upper='v_max') self.attribute_att_helper = ComponentIDComboHelper(self, 'attribute', numeric=True, categorical=False) percentile_display = {100: 'Min/Max', 99.5: '99.5%', 99: '99%', 95: '95%', 90: '90%', 'Custom': 'Custom'} ImageLayerState.percentile.set_choices(self, [100, 99.5, 99, 95, 90, 'Custom']) ImageLayerState.percentile.set_display_func(self, percentile_display.get) stretch_display = {'linear': 'Linear', 'sqrt': 'Square Root', 'arcsinh': 'Arcsinh', 'log': 'Logarithmic'} ImageLayerState.stretch.set_choices(self, ['linear', 'sqrt', 'arcsinh', 'log']) ImageLayerState.stretch.set_display_func(self, stretch_display.get) self.add_callback('global_sync', self._update_syncing) self.add_callback('layer', self._update_attribute) self._update_syncing() if layer is not None: self._update_attribute() self.update_from_dict(kwargs) if self.cmap is None: self.cmap = self.layer.style.preferred_cmap or colormaps.members[0][1] def _update_attribute(self, *args): if self.layer is not None: self.attribute_att_helper.set_multiple_data([self.layer]) self.attribute = self.layer.main_components[0] def _update_priority(self, name): if name == 'layer': return 3 elif name == 'attribute': return 2 elif name == 'global_sync': return 1.5 elif name.endswith(('_min', '_max')): return 0 else: return 1 def _update_syncing(self, *args): if self.global_sync: self._sync_color.enable_syncing() self._sync_alpha.enable_syncing() else: self._sync_color.disable_syncing() self._sync_alpha.disable_syncing() def _get_image(self, view=None): return self.layer[self.attribute, view]
[docs] def flip_limits(self): """ Flip the image levels. """ self.attribute_lim_helper.flip_limits()
[docs] def reset_contrast_bias(self): with delay_callback(self, 'contrast', 'bias'): self.contrast = 1 self.bias = 0.5
[docs]class ImageSubsetLayerState(BaseImageLayerState): """ A state class that includes all the attributes for subset layers in an image plot. """ # TODO: we can save memory by not showing subset multiple times for # different image datasets since the footprint should be the same. def __init__(self, *args, **kwargs): self.uuid = str(uuid.uuid4()) super(ImageSubsetLayerState, self).__init__(*args, **kwargs) def _get_image(self, view=None): return self.layer.to_mask(view=view)