Source code for mdt.visualization.maps.matplotlib_renderer

import os
import numpy as np
from matplotlib import pyplot as plt, patches
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.cm import get_cmap
from mdt import get_slice_in_dimension
from mdt.visualization.maps.base import Clipping, Scale, Point2d
from mdt.visualization.utils import MyColourBarTickLocator

__author__ = 'Robbert Harms'
__date__ = "2016-09-02"
__maintainer__ = "Robbert Harms"
__email__ = "robbert@xkls.nl"


[docs]class MapsVisualizer: def __init__(self, data_info, figure): self._data_info = data_info self._figure = figure
[docs] def render(self, plot_config): """Render all the maps to the figure. This is for use in GUI embedded situations. Returns: list of AxisData: the list with the drawn axes and the accompanying data """ renderer = Renderer(self._data_info, self._figure, plot_config) renderer.render() return renderer.image_axes
[docs] def to_file(self, file_name, plot_config, **kwargs): """Renders the figures to the given filename.""" Renderer(self._data_info, self._figure, plot_config).render() if not os.path.dirname(file_name).strip(): return if not os.path.isdir(os.path.dirname(file_name)): os.makedirs(os.path.dirname(file_name)) kwargs['dpi'] = kwargs.get('dpi') or 100 self._figure.savefig(file_name, **kwargs)
[docs] def show(self, plot_config, block=True, maximize=False, window_title=None): """Show the data contained in this visualizer using the specifics in this function call. Args: plot_config (mdt.visualization.maps.base.MapPlotConfig): the plot configuration block (boolean): If we want to block after calling the plots or not. Set this to False if you do not want the routine to block after drawing. In doing so you manually need to block. maximize (boolean): if we want to display the window maximized or not window_title (str): the title of the window. If None, the default title is used """ Renderer(self._data_info, self._figure, plot_config).render() if maximize: mng = plt.get_current_fig_manager() mng.window.showMaximized() if window_title: mng = plt.get_current_fig_manager() mng.canvas.set_window_title(window_title) if block: plt.show(True)
[docs]class AxisData: def __init__(self, axis, map_name, map_info, plot_config): """Contains a reference to a drawn matpotlib axis and to the accompanying data. Args: axis (Axis): the matpotlib axis map_name (str): the name/key of this map map_info (mdt.visualization.maps.base.SingleMapInfo): the map information plot_config (mdt.visualization.maps.base.MapPlotConfig): the map plot configuration """ self.axis = axis self.map_name = map_name self._map_info = map_info self._plot_config = plot_config
[docs] def coordinates_to_index(self, x, y): """Converts data coordinates to index coordinates of the array. Args: x (int): The x-coordinate in data coordinates. y (int): The y-coordinate in data coordinates. Returns tuple: Index coordinates of the map associated with the image (x, y, z, d). """ return _coordinates_to_index(self._map_info, self._plot_config, x, y)
[docs] def get_value(self, index): """Get the value of this axis data at the given index. Args: index (tuple)): the 3d or 4d index to the map corresponding to this axis data (x, y, z, [v]) Returns: float: the value at the given index. """ return self._map_info.data[tuple(index)]
[docs]class Renderer: def __init__(self, data_info, figure, plot_config): """Create a new renderer for the given volumes on the given figure using the given configuration. This renders the images with flipped upside down with the origin at the bottom left. The upside down flip is necessary to allow counter-clockwise rotation. Args: data_info (DataInfo): the information about the maps to show figure (Figure): the matplotlib figure to draw on plot_config (mdt.visualization.maps.base.MapPlotConfig): the plot configuration """ self._data_info = data_info self._figure = figure self._plot_config = plot_config self.image_axes = []
[docs] def render(self): """Render the maps""" grid_layout_specifier = self._plot_config.grid_layout.get_gridspec( self._figure, len(self._plot_config.maps_to_show)) if self._plot_config.title: self._figure.suptitle(self._plot_config.title, fontsize=self._plot_config.font.size, family=self._plot_config.font.name) for ind, map_name in enumerate(self._plot_config.maps_to_show): axis = grid_layout_specifier.get_axis(ind) axis_data = self._render_map(map_name, axis) self.image_axes.append(axis_data)
def _render_map(self, map_name, axis): """Render a single map to the given axis""" if self._get_map_attr(map_name, 'show_title', self._plot_config.show_titles): axis.set_title(self._get_map_attr(map_name, 'title', map_name), y=self._get_title_spacing(map_name)) axis.axis('on' if self._plot_config.show_axis else 'off') data = self._get_image(map_name) plot_options = self._get_map_plot_options(map_name) plot_options['origin'] = 'lower' plot_options['interpolation'] = self._plot_config.interpolation vf = axis.imshow(data, **plot_options) self._add_annotations(map_name, axis) self._add_colorbar(axis, map_name, vf, self._get_map_attr(map_name, 'colorbar_label')) self._apply_font_general_axis(axis) return AxisData(axis, map_name, self._data_info.get_single_map_info(map_name), self._plot_config) def _apply_font_general_axis(self, image_axis): """Apply the font from the plot configuration to the image axis""" items = [image_axis.xaxis.label, image_axis.yaxis.label] items.extend(image_axis.get_xticklabels()) items.extend(image_axis.get_yticklabels()) for item in items: item.set_fontsize(self._plot_config.font.size - 2) item.set_family(self._plot_config.font.name) image_axis.title.set_fontsize(self._plot_config.font.size) image_axis.title.set_family(self._plot_config.font.name) def _apply_font_colorbar_axis(self, colorbar_axis): """Apply the font from the plot configuration to the colorbar axis""" for axes in [colorbar_axis.xaxis, colorbar_axis.yaxis]: items = axes.get_ticklabels() for item in items: item.set_fontsize(self._plot_config.font.size - 2) item.set_family(self._plot_config.font.name) axes.label.set_fontsize(self._plot_config.font.size) axes.label.set_family(self._plot_config.font.name) axes.offsetText.set_fontsize(self._plot_config.font.size - 3) axes.offsetText.set_family(self._plot_config.font.name) def _add_annotations(self, map_name, axis): def get_value(annotation): data = self._data_info.get_map_data(map_name) index = tuple(annotation.voxel_index) if len(data.shape) > 3 and data.shape[3] > 1: if len(index) < 4: index = index + (self._plot_config.volume_index,) return float(data[index]) return float(data[index[:3]]) def get_font_size(annotation): font_size = self._plot_config.font.size - 3 if annotation.font_size is not None: font_size = annotation.font_size return font_size def get_coordinate(annotation): coordinate = _index_to_coordinates(self._data_info.get_single_map_info(map_name), self._plot_config, annotation.voxel_index) if coordinate is None: return None return np.array(coordinate) def get_text_box_location(annotation): axis_to_data = axis.transData + axis.transAxes.inverted() coordinate_normalized = axis_to_data.transform(get_coordinate(annotation)) delta = annotation.text_distance + axis_to_data.transform([annotation.marker_size] * 2)[0] definitions = { 'upper left': [(-delta, +delta), 'right', 'bottom'], 'top left': [(-delta, +delta), 'right', 'bottom'], 'upper right': [(+delta, +delta), 'left', 'bottom'], 'top right': [(+delta, +delta), 'left', 'bottom'], 'lower left': [(-delta, -delta), 'right', 'top'], 'bottom left': [(-delta, -delta), 'right', 'top'], 'lower right': [(+delta, -delta), 'left', 'top'], 'bottom right': [(+delta, -delta), 'left', 'top'], 'top': [(0, +np.sqrt(2) * delta), 'center', 'bottom'], 'north': [(0, +np.sqrt(2) * delta), 'center', 'bottom'], 'bottom': [(0, -np.sqrt(2) * delta), 'center', 'top'], 'south': [(0, -np.sqrt(2) * delta), 'center', 'top'], 'left': [(-np.sqrt(2) * delta, 0), 'right', 'center'], 'west': [(-np.sqrt(2) * delta, 0), 'right', 'center'], 'right': [(+np.sqrt(2) * delta, 0), 'left', 'center'], 'east': [(+np.sqrt(2) * delta, 0), 'left', 'center'], } transform, halign, valign = definitions[annotation.text_location] return axis_to_data.inverted().transform(coordinate_normalized + transform), \ halign, valign def get_arrow_head_location(annotation): coordinate = get_coordinate(annotation) delta = annotation.marker_size / 2. definitions = { 'upper left': (-delta, +delta), 'top left': (-delta, +delta), 'upper right': (+delta, +delta), 'top right': (+delta, +delta), 'lower left': (-delta, -delta), 'bottom left': (-delta, -delta), 'lower right': (+delta, -delta), 'bottom right': (+delta, -delta), 'top': (0, delta), 'north': (0, delta), 'bottom': (0, -delta), 'south': (0, -delta), 'left': (-delta, 0), 'west': (-delta, 0), 'right': (delta, 0), 'east': (delta, 0), } return coordinate + definitions[annotation.text_location] def draw_annotation(annotation): coordinate = get_coordinate(annotation) if coordinate is None: return axis.add_patch(patches.Rectangle(coordinate - (annotation.marker_size / 2.), annotation.marker_size, annotation.marker_size, linewidth=1, edgecolor='white', facecolor='#0066ff')) xy_text, horizontal_alignment, vertical_alignment = get_text_box_location(annotation) axis.annotate( annotation.text_template.format(voxel_index=tuple(annotation.voxel_index), value=get_value(annotation)), xy=get_arrow_head_location(annotation), xytext=xy_text, horizontalalignment=horizontal_alignment, verticalalignment=vertical_alignment, multialignment='left', arrowprops=dict(color='white', connectionstyle="arc3", width=annotation.arrow_width, headwidth=8 + annotation.arrow_width, headlength=10), color='black', size=get_font_size(annotation), family=self._plot_config.font.name, bbox=dict(facecolor='white')) for annotation in self._plot_config.annotations: draw_annotation(annotation) def _get_colorbar_setting(self, map_name, attr_name): map_specific_colorbar_settings = self._get_map_attr(map_name, 'colorbar_settings') global_colorbar_settings = self._plot_config.colorbar_settings return map_specific_colorbar_settings.get_preferred(attr_name, [global_colorbar_settings]) def _add_colorbar(self, axis, map_name, image_figure, colorbar_label): """Add a colorbar to the axis Returns: axis: the image axis """ divider = make_axes_locatable(axis) colorbar_location = self._get_colorbar_setting(map_name, 'location') show_colorbar = self._get_colorbar_setting(map_name, 'visible') at_least_one_colorbar = any(self._get_colorbar_setting(name, 'visible') for name in self._plot_config.map_plot_options) if at_least_one_colorbar: axis_kwargs = dict(size="5%", pad=0.1) if self._plot_config.show_axis and colorbar_location in ['bottom', 'left']: axis_kwargs['pad'] = 0.3 if show_colorbar and colorbar_location in ('left', 'right'): colorbar_axis = divider.append_axes(colorbar_location, **axis_kwargs) else: fake_axis = divider.append_axes('right', **axis_kwargs) fake_axis.axis('off') if show_colorbar and colorbar_location in ('top', 'bottom'): colorbar_axis = divider.append_axes(colorbar_location, **axis_kwargs) else: fake_axis = divider.append_axes('bottom', **axis_kwargs) fake_axis.axis('off') if show_colorbar: kwargs = dict(cax=colorbar_axis, ticks=self._get_tick_locator(map_name)) if colorbar_label: kwargs.update(dict(label=colorbar_label)) if colorbar_location in ['top', 'bottom']: kwargs['orientation'] = 'horizontal' cbar = plt.colorbar(image_figure, **kwargs) cbar.formatter.set_powerlimits(self._get_colorbar_setting(map_name, 'power_limits')) colorbar_axis.yaxis.set_offset_position('left') if colorbar_location == 'left': colorbar_axis.yaxis.set_ticks_position('left') colorbar_axis.yaxis.set_label_position('left') if colorbar_location == 'top': colorbar_axis.xaxis.set_ticks_position('top') colorbar_axis.xaxis.set_label_position('top') cbar.update_ticks() if cbar.ax.get_yticklabels(): cbar.ax.get_yticklabels()[-1].set_verticalalignment('top') elif cbar.ax.get_xticklabels(): cbar.ax.get_xticklabels()[0].set_horizontalalignment('left') cbar.ax.get_xticklabels()[-1].set_horizontalalignment('right') self._apply_font_colorbar_axis(colorbar_axis) def _get_map_attr(self, map_name, option, default=None): if map_name in self._plot_config.map_plot_options: value = getattr(self._plot_config.map_plot_options[map_name], option) if value is not None: return value return default def _get_title_spacing(self, map_name): spacing = 1 + self._get_map_attr(map_name, 'title_spacing', self._plot_config.title_spacing or 0) if self._get_colorbar_setting(map_name, 'location') == 'top': spacing += 0.15 return spacing def _get_map_plot_options(self, map_name): cmap = get_cmap(self._get_map_attr(map_name, 'colormap', self._plot_config.colormap)) masked_color = self._get_map_attr(map_name, 'colormap_masked_color', self._plot_config.colormap_masked_color) if masked_color is not None: cmap.set_bad(color=masked_color) output_dict = {'vmin': self._data_info.get_single_map_info(map_name).min(), 'vmax': self._data_info.get_single_map_info(map_name).max(), 'cmap': cmap} scale = self._get_map_attr(map_name, 'scale', Scale()) if scale.use_max: output_dict['vmax'] = scale.vmax if scale.use_min: output_dict['vmin'] = scale.vmin return output_dict def _get_image(self, map_name): """Get the 2d image to display for the given data.""" def get_slice(data): """Get the slice of the data in the desired dimension.""" dimension = self._plot_config.dimension slice_index = self._plot_config.slice_index volume_index = self._plot_config.volume_index data_slice = get_slice_in_dimension(data, dimension, slice_index) if len(data_slice.shape) > 2: if data_slice.shape[2] == 3 and self._get_map_attr(map_name, 'interpret_as_colormap'): data_slice = np.abs(data_slice[:, :, :]) if self._get_map_attr(map_name, 'colormap_order'): order = self._get_map_attr(map_name, 'colormap_order') data_slice[..., :] = data_slice[..., [order.index(color) for color in 'rbg']] elif volume_index < data_slice.shape[2]: data_slice = np.squeeze(data_slice[:, :, volume_index]) else: data_slice = np.squeeze(data_slice[:, :, data_slice.shape[2] - 1]) if len(data_slice.shape) == 1: data_slice = data_slice[:, None] return data_slice def apply_modifications(data_slice): """Modify the data using (possible) value clipping and masking""" data_slice = self._get_map_attr(map_name, 'clipping', Clipping()).apply(data_slice) mask_name = self._get_map_attr(map_name, 'mask_name', self._plot_config.mask_name) if mask_name: mask = (get_slice(self._data_info.get_map_data(mask_name)) > 0) if len(data_slice.shape) != len(mask.shape): mask = mask[..., None] data_slice = data_slice * mask return data_slice data = self._data_info.get_map_data(map_name) if len(data.shape) == 2: data_slice = data else: data_slice = get_slice(data) data_slice = apply_modifications(data_slice) data_slice = _apply_transformations(self._plot_config, data_slice) if len(data_slice.shape) == 3: multiplication_map = self._get_map_attr(map_name, 'colormap_weight_map', None) if multiplication_map is not None and map_name != multiplication_map: multiplication_map = self._get_image(multiplication_map) if len(multiplication_map.shape) != 3: multiplication_map = multiplication_map[..., None] scale = self._get_map_attr(map_name, 'scale', Scale()) new_vmax = np.min(((scale.vmax if scale.use_max else 1), 1)) new_vmin = np.max(((scale.vmin if scale.use_min else 0), 0)) multiplication_map = multiplication_map + (1 - (new_vmax - new_vmin)) data_slice = data_slice * multiplication_map if np.max(data_slice) > 1: data_slice /= np.max(data_slice) return data_slice def _get_tick_locator(self, map_name): min_val, max_val = self._data_info.get_single_map_info(map_name).min_max() scale = self._get_map_attr(map_name, 'scale', Scale()) if scale.use_max: max_val = scale.vmax if scale.use_min: min_val = scale.vmin return MyColourBarTickLocator( min_val, max_val, numticks=self._get_colorbar_setting(map_name, 'nmr_ticks'), round_precision=self._get_colorbar_setting(map_name, 'round_precision') )
def _apply_transformations(plot_config, data_slice): """Rotate, flip and zoom the data slice. Depending on the plot configuration, this will apply some transformations to the given data slice. Args: plot_config (mdt.visualization.maps.base.MapPlotConfig): the plot configuration data_slice (ndarray): the 2d slice of data to transform Returns: ndarray: the transformed 2d slice of data """ if plot_config.rotate: data_slice = np.rot90(data_slice, plot_config.rotate // 90) if not plot_config.flipud: # by default we flipud to correct for matplotlib lower origin. If the user # sets flipud, we do not need to to it data_slice = np.flipud(data_slice) data_slice = plot_config.zoom.apply(data_slice) return data_slice def _coordinates_to_index(map_info, plot_config, x, y): """Converts data coordinates to index coordinates of the array. This is the inverse of :func:`_index_to_coordinates`. Args: map_info (mdt.visualization.maps.base.SingleMapInfo): the map information plot_config (mdt.visualization.maps.base.MapPlotConfig): the map plot configuration x (int): The x-coordinate in data coordinates. y (int): The y-coordinate in data coordinates. Returns tuple: Index coordinates of the map associated with the image (x, y, z, w). """ shape = map_info.get_size_in_dimension(plot_config.dimension, plot_config.rotate) # correct for zoom x += plot_config.zoom.p0.x y += plot_config.zoom.p0.y # correct for flip upside down if not plot_config.flipud: y = map_info.get_max_y_index(plot_config.dimension, plot_config.rotate) - y # correct for displayed axis, the view is x-data on y-image and y-data on x-image x, y = y, x # rotate the point rotated = Point2d(x, y).rotate90((-1 * plot_config.rotate % 360) // 90) # translate the point back to a new origin if plot_config.rotate == 90: rotated.y = shape[1] + rotated.y elif plot_config.rotate == 180: rotated.x = shape[1] + rotated.x rotated.y = shape[0] + rotated.y elif plot_config.rotate == 270: rotated.x = shape[0] + rotated.x if len(map_info.shape) == 2: return [rotated.x, rotated.y] # create the index index = [rotated.x, rotated.y] index.insert(plot_config.dimension, plot_config.slice_index) if len(map_info.data.shape) > 3: if plot_config.volume_index < map_info.data.shape[3]: index.append(plot_config.volume_index) else: index.append(map_info.data.shape[3] - 1) # fixes issue with off-by-one errors for ind in range(len(index)): if index[ind] >= map_info.shape[ind]: index[ind] = map_info.shape[ind] - 1 if index[ind] < 0: index[ind] = 0 return [int(el) for el in index] def _index_to_coordinates(map_info, plot_config, index): """Converts data index coordinates to an onscreen position. This is the inverse of :func:`_coordinates_to_index`. Args: map_info (mdt.visualization.maps.base.SingleMapInfo): the map information plot_config (mdt.visualization.maps.base.MapPlotConfig): the map plot configuration index (tuple of int): The data index we wish to translate to an onscreen position. Returns tuple or None: Coordinates of the given index on the screen (x, y) or None if coordinate not visible. """ dimension = plot_config.dimension if index[dimension] != plot_config.slice_index: return None reduced_index = list(index[:3]) del reduced_index[dimension] img_shape = list(map_info.shape[:3]) del img_shape[dimension] index_slice = np.arange(0, int(np.prod(img_shape)), 1).reshape(img_shape) index_nmr = index_slice[tuple(reduced_index)] coordinates = np.where(_apply_transformations(plot_config, index_slice) == index_nmr) if len(coordinates[0]) and len(coordinates[1]): return coordinates[1][0], coordinates[0][0] return None