Source code for mdt.visualization.layouts

import itertools
import numpy as np
from matplotlib.gridspec import GridSpec
from mdt.visualization.dict_conversion import SimpleClassConversion, IntConversion, SimpleDictConversion

# Module authorship metadata.
__author__ = 'Robbert Harms'
__date__ = "2016-09-02"
__maintainer__ = "Robbert Harms"
__email__ = ""

class GridLayout:
    """Base class of the grid layouts; it holds the spacings that are passed on to the grid spec."""

    def __init__(self, spacings=None):
        """Store the spacings and verify that they describe a non-empty plotting area.

        Args:
            spacings (dict): the spacings to use, it needs at least the keys ``left``, ``right``, ``top``
                and ``bottom``. If not given (or empty) a default set of spacings is used.

        Raises:
            ValueError: if ``top`` is not larger than ``bottom`` or if ``left`` is not smaller than ``right``.
        """
        super().__init__()

        if not spacings:
            spacings = dict(left=0.10, right=0.86, top=0.97, bottom=0.03, wspace=0.40, hspace=0.15)
        self.spacings = spacings

        top, bottom = self.spacings['top'], self.spacings['bottom']
        if top <= bottom:
            raise ValueError(
                'The top ({}) can not be smaller than the bottom ({}) in the spacings'.format(top, bottom))

        left, right = self.spacings['left'], self.spacings['right']
        if left >= right:
            raise ValueError(
                'Left ({}) can not be larger than right ({}) in the spacings'.format(left, right))

    @classmethod
    def get_conversion_info(cls):
        """Get the ``SimpleClassConversion`` of this class, built from its attribute conversions."""
        attribute_conversions = cls._get_attribute_conversions()
        return SimpleClassConversion(cls, attribute_conversions)

    @classmethod
    def _get_attribute_conversions(cls):
        """Get the conversion object to use for every attribute, as a dictionary keyed by attribute name."""
        return dict(spacings=SimpleDictConversion(conversion_func=float))

    def get_gridspec(self, figure, nmr_plots):
        """Get the grid layout specifier for the given figure using the given number of plots.

        This base implementation consists of the docstring only and hence returns None, the subclasses
        provide the actual layouts.

        Args:
            figure (Figure): the figure to add the axis to
            nmr_plots (int): the total number of plots

        Returns:
            GridLayoutSpecifier: the specifier we can ask new subplot axis from
        """

    def __eq__(self, other):
        """Layouts are equal if the other is an instance of our own type and has equal spacings."""
        return isinstance(other, GridLayout) \
            and isinstance(other, type(self)) \
            and other.spacings == self.spacings

    def __ne__(self, other):
        """The exact negation of :meth:`__eq__`."""
        # call __eq__ directly, the == operator could dispatch to the __eq__ of a subclass first
        is_equal = self.__eq__(other)
        return not is_equal
class GridLayoutSpecifier:
    """Hands out subplot axes of a figure, located by indexing a grid spec."""

    def __init__(self, gridspec, figure, positions=None):
        """Create a grid layout specifier using the given gridspec and the given figure.

        Args:
            gridspec (GridSpec): the gridspec to use
            figure (Figure): the figure to generate subplots for
            positions (:class:`list`): if given, a list with grid spec indices for every requested axis
                can be logical indices or (x, y) coordinate indices (choose one and stick with it).
        """
        self.gridspec = gridspec
        self.figure = figure
        self.positions = positions

    def get_axis(self, index):
        """Add the subplot for the given plot index to the figure and return it.

        Args:
            index (int): the index of the plot. Without ``positions`` it is used to index the grid spec
                directly, with ``positions`` it is first translated to ``positions[index]``.

        Returns:
            the result of ``figure.add_subplot`` for the selected grid spec element
        """
        # Resolve the grid spec key first, such that the grid spec is indexed exactly once. Indexing it with
        # the raw plot index while positions are in use is wasted work, and it fails when that index is
        # only meaningful as an index into the positions list.
        gridspec_key = index if self.positions is None else self.positions[index]
        return self.figure.add_subplot(self.gridspec[gridspec_key])
class AutoGridLayout(GridLayout):
    """Grid layout that derives the number of rows and columns from the number of plots."""

    def get_gridspec(self, figure, nmr_plots):
        """Get the grid layout specifier for a grid sized by :meth:`_get_square_size`.

        Args:
            figure (Figure): the figure to add the axis to
            nmr_plots (int): the total number of plots

        Returns:
            GridLayoutSpecifier: the specifier we can ask new subplot axis from
        """
        nmr_rows, nmr_cols = self._get_square_size(nmr_plots)
        grid = GridSpec(nmr_rows, nmr_cols, **self.spacings)
        return GridLayoutSpecifier(grid, figure)

    def _get_square_size(self, nmr_plots):
        """Get the grid size to use for the given number of plots.

        Small numbers of plots (less than the length of the preset table) are looked up in that table at
        position ``nmr_plots - 1``; note that a value of 0 therefore selects the last entry by negative
        indexing. For all other values the number of columns is ``ceil(nmr_plots / 3)`` and the number of
        rows is the least number needed to fit all plots in those columns.

        Args:
            nmr_plots (int): the total number of plots

        Returns:
            tuple: the number of rows and the number of columns
        """
        presets = ((1, 1), (1, 2), (2, 2), (2, 2), (2, 3), (2, 3), (2, 3))

        if nmr_plots >= len(presets):
            nmr_cols = np.ceil(nmr_plots / 3.0)
            nmr_rows = np.ceil(float(nmr_plots) / nmr_cols)
            return int(nmr_rows), int(nmr_cols)

        return presets[nmr_plots - 1]
class Rectangular(GridLayout):
    """Grid layout with an optionally fixed number of rows and/or columns."""

    def __init__(self, rows=None, cols=None, spacings=None):
        """Create a rectangular layout.

        Args:
            rows (int): the number of rows, if None it is derived when the grid spec is requested
            cols (int): the number of columns, if None it is derived when the grid spec is requested
            spacings (dict): the spacings around each plot

        Raises:
            ValueError: if the given number of rows or the given number of columns is smaller than 1.
        """
        super().__init__(spacings=spacings)
        self.rows = rows
        self.cols = cols

        if self.rows is not None:
            self.rows = int(self.rows)
            if self.rows < 1:
                raise ValueError('The number of rows ({}) can not be smaller than 1.'.format(self.rows))

        if self.cols is not None:
            self.cols = int(self.cols)
            if self.cols < 1:
                # report the offending number of columns (this message used to print the number of rows)
                raise ValueError('The number of columns ({}) can not be smaller than 1.'.format(self.cols))

    @classmethod
    def _get_attribute_conversions(cls):
        """Extend the attribute conversions of the parent class with those for ``rows`` and ``cols``."""
        conversions = super()._get_attribute_conversions()
        conversions.update({'rows': IntConversion(), 'cols': IntConversion()})
        return conversions

    def get_gridspec(self, figure, nmr_plots):
        """Get the grid layout specifier for the given figure using the given number of plots.

        If neither rows nor columns are set this defers to :class:`AutoGridLayout`. If only one of the two
        is set, the other is the least number needed to fit all the plots. If both are set but the grid is
        too small for the number of plots, the number of columns is enlarged.

        Args:
            figure (Figure): the figure to add the axis to
            nmr_plots (int): the total number of plots

        Returns:
            GridLayoutSpecifier: the specifier we can ask new subplot axis from
        """
        rows = self.rows
        cols = self.cols

        if rows is None and cols is None:
            return AutoGridLayout(spacings=self.spacings).get_gridspec(figure, nmr_plots)

        if rows is None:
            rows = int(np.ceil(nmr_plots / cols))
        if cols is None:
            cols = int(np.ceil(nmr_plots / rows))

        # the requested grid can not hold all the plots, grow it in the column direction
        if rows * cols < nmr_plots:
            cols = int(np.ceil(nmr_plots / rows))

        return GridLayoutSpecifier(GridSpec(rows, cols, **self.spacings), figure)

    def __eq__(self, other):
        """Layouts are equal if the other is of our own type and has equal rows, columns and spacings."""
        if not isinstance(other, Rectangular):
            return False
        return isinstance(other, type(self)) and other.rows == self.rows and other.cols == self.cols \
            and other.spacings == self.spacings
class LowerTriangular(GridLayout):
    """Grid layout that places the plots in the lower triangle of a square grid."""

    def __init__(self, padding=0, spacings=None):
        """Create a lower triangular plot layout.

        Args:
            padding (int): normally we will construct the lower triangle from the top, like::

                    *
                    * *
                    * * *

                If padding is enabled, we will pad as many images from the top as specified. For example,
                a padding of 1, with 5 images yields::

                    * *
                    * * *

                Or a padding of 2 with 4 images::

                      *
                    * * *

            spacings (dict): the spacings around each plot
        """
        super().__init__(spacings=spacings)
        self.padding = padding or 0

    def get_gridspec(self, figure, nmr_plots):
        """Get the grid layout specifier for the given figure using the given number of plots.

        Args:
            figure (Figure): the figure to add the axis to
            nmr_plots (int): the total number of plots

        Returns:
            GridLayoutSpecifier: the specifier we can ask new subplot axis from, it carries the
                (row, column) position of every plot.
        """
        rows, columns, positions = self._get_size_and_position(nmr_plots)
        return GridLayoutSpecifier(GridSpec(rows, columns, **self.spacings), figure, positions=positions)

    @classmethod
    def _get_attribute_conversions(cls):
        """Extend the attribute conversions of the parent class with the one for ``padding``."""
        conversions = super()._get_attribute_conversions()
        conversions.update({'padding': IntConversion()})
        return conversions

    def _get_size_and_position(self, nmr_plots):
        """Get the size of the grid and the grid position of every plot.

        Args:
            nmr_plots (int): the total number of plots

        Returns:
            tuple: the number of rows, the number of columns and the list with a (row, column) tuple for
                every plot.
        """
        # side length of the smallest triangle that holds both the plots and the padded cells
        size = self._get_lowest_triangle_length(nmr_plots + self.padding)

        # number of leading rows that are completely filled by the padding, these rows are dropped
        row_shift = self._get_biggest_triangle_length(self.padding)

        positions = [(x - row_shift, y)
                     for x, y in itertools.product(range(size), range(size))
                     if x >= y]

        # the first ``padding`` cells of the triangle stay empty
        return size - row_shift, size, positions[self.padding:]

    def __eq__(self, other):
        """Layouts are equal if the other is of our own type and has equal padding and spacings."""
        if not isinstance(other, LowerTriangular):
            return False
        # also compare the spacings, in agreement with the equality checks of the other layouts
        return isinstance(other, type(self)) and other.padding == self.padding \
            and other.spacings == self.spacings

    @staticmethod
    def _get_lowest_triangle_length(nmr_plots):
        """Get the side length of the smallest triangle with at least ``nmr_plots`` cells.

        This is the positive solution of ``n * (n + 1) / 2 = nmr_plots``, rounded up.
        """
        return int(np.ceil((-1 + np.sqrt(1 + 8 * nmr_plots)) / 2.))

    @staticmethod
    def _get_biggest_triangle_length(nmr_plots):
        """Get the side length of the largest triangle with at most ``nmr_plots`` cells.

        This is the positive solution of ``n * (n + 1) / 2 = nmr_plots``, rounded down.
        """
        return int(np.floor((-1 + np.sqrt(1 + 8 * nmr_plots)) / 2.))
class SingleColumn(GridLayout):
    """Grid layout that stacks all the plots in one column."""

    def get_gridspec(self, figure, nmr_plots):
        """Get the grid layout specifier for a grid of ``nmr_plots`` rows and a single column.

        Args:
            figure (Figure): the figure to add the axis to
            nmr_plots (int): the total number of plots

        Returns:
            GridLayoutSpecifier: the specifier we can ask new subplot axis from
        """
        column_grid = GridSpec(nmr_plots, 1, **self.spacings)
        return GridLayoutSpecifier(column_grid, figure)
class SingleRow(GridLayout):
    """Grid layout that places all the plots next to each other in one row."""

    def get_gridspec(self, figure, nmr_plots):
        """Get the grid layout specifier for a grid of a single row and ``nmr_plots`` columns.

        Args:
            figure (Figure): the figure to add the axis to
            nmr_plots (int): the total number of plots

        Returns:
            GridLayoutSpecifier: the specifier we can ask new subplot axis from
        """
        row_grid = GridSpec(1, nmr_plots, **self.spacings)
        return GridLayoutSpecifier(row_grid, figure)