diff --git a/gridplot/__init__.py b/gridplot/__init__.py index 6530b8115dcfaf78aef5cbb4d10b54d571b1c7bc..3edf8f61e70ed11006cd7670c57e2a14d996ed6c 100644 --- a/gridplot/__init__.py +++ b/gridplot/__init__.py @@ -1,6 +1,7 @@ #!/usr/bin/python3 # -*- coding: utf-8 -*- +import logging from functools import partial, reduce from typing import TypeVar, Any, Union, Optional, Callable @@ -39,6 +40,7 @@ T = TypeVar("T") T2 = TypeVar("T2") Tree = Optional[Union[list[str], dict[str, "Tree"]]] +logger = logging.getLogger(__name__) # All levels, in dependency order hierarchy = [ @@ -55,7 +57,7 @@ hierarchy = [ def axes_levels(levels: list[str]): """Levels used *inside* an axes.""" - return [ lvl for lvl in levels if lvl not in ["row", "column"] ] + return [lvl for lvl in levels if lvl not in ["row", "column"]] def values(xs, deps: Optional[list[str]] = None, **kw): @@ -86,9 +88,10 @@ def unique_leaf_values(xs: Tree): elif isinstance(xs, dict): return reduce( lambda x1, x2: list(dict.fromkeys(x1 + x2)), - [ unique_leaf_values(x) for x in xs.values() ], + [unique_leaf_values(x) for x in xs.values()], ) else: + logger.error(f"unique_leaf_values not implemented for xs of type {type(xs)}: {xs}") raise NotImplementedError @@ -130,7 +133,7 @@ def validate_deps(deps: Optional[dict[str, list[str]]], levels: list[str] = hier else: # it's ok to depend on previous set levels in the hierarchy only ok_levels = levels[:i] - bad_levels = [ lv for lv in lvl_deps if lv not in ok_levels ] + bad_levels = [lv for lv in lvl_deps if lv not in ok_levels] if bad_levels: # not empty raise ValueError( @@ -173,7 +176,7 @@ def set_default_levelmaps( elif k == "marker": xs = list(MarkerStyle.markers.keys())[:n] elif k == "alpha": - xs = np.arange(start=1., stop=0., step=-1/n)[:-1] + xs = np.arange(start=1.0, stop=0.0, step=-1 / n)[:-1] elif k == "linewidth": # FIXME improve xs = range(n, 0, -1) @@ -182,7 +185,7 @@ def set_default_levelmaps( xs = range(n, 0, -1) else: raise ValueError("level not recognized") - levelmaps[k] = { v: x for v, x in zip(vs, xs) } + levelmaps[k] = {v: x for v, x in zip(vs, xs)} def empty_handle() -> Patch: @@ -199,11 +202,12 @@ def legend_columns( relabel: Optional[Callable[[str], str]] = None, ): if relabel is None: + def relabel(s): return s # `1 + ...` for the subtitle - lens = { lv: 1 + len(li) for lv, li in handlabs.items() } + lens = {lv: 1 + len(li) for lv, li in handlabs.items()} max_len = max(lens.values()) clen = col_len(max_len) @@ -264,22 +268,14 @@ def horizontal_legend_columns(handlabs, relabel): def col_len(max_len): return min(max_len, 5) - return legend_columns( - handlabs=handlabs, - col_len=col_len, - relabel=relabel - ) + return legend_columns(handlabs=handlabs, col_len=col_len, relabel=relabel) def vertical_legend_columns(handlabs, relabel): def col_len(max_len): return max(max_len, 10) - return legend_columns( - handlabs=handlabs, - col_len=col_len, - relabel=relabel - ) + return legend_columns(handlabs=handlabs, col_len=col_len, relabel=relabel) class Mapper(object): @@ -294,15 +290,31 @@ class Mapper(object): """ def __init__( - self, deps=None, - rows=None, columns=None, colors=None, linestyles=None, markers=None, - alphas=None, linewidths=None, markersizes=None, - palette=None, stylemap=None, markermap=None, - alphamap=None, linewidthmap=None, markersizemap=None, - legend=None, legend_title=None, label_legend=None, - sharex=None, sharey=None, - aspect=None, figsize=None - ) : + self, + deps=None, + rows=None, + columns=None, + colors=None, + linestyles=None, + markers=None, + alphas=None, + linewidths=None, + markersizes=None, + palette=None, + stylemap=None, + markermap=None, + alphamap=None, + linewidthmap=None, + markersizemap=None, + legend=None, + legend_layout=None, + legend_title=None, + label_legend=None, + sharex=None, + sharey=None, + aspect=None, + figsize=None, + ): """Set up the figure, its layout and its subplots.""" self.levels = { "row": rows, @@ -315,7 +327,7 @@ class Mapper(object): "markersize": markersizes, } - self.used_levels = [ lvl for lvl in hierarchy if self.levels[lvl] is not None ] + self.used_levels = [lvl for lvl in hierarchy if self.levels[lvl] is not None] self.deps = validate_deps(deps, self.used_levels) self.levelmaps = { @@ -343,23 +355,26 @@ class Mapper(object): self.fig = f self.spec = self.set_legend_grid( - legend, - title=legend_title, - relabel=label_legend + legend, layout=legend_layout, title=legend_title, relabel=label_legend ) # improve constrained layout f.set_constrained_layout_pads( w_pad=0, h_pad=0, - wspace=0.01, - hspace=0.01, + wspace=0.05, + hspace=0.05, ) self.add_subplots(sharex, sharey, subplot_kw) - def set_legend_grid(self, legend, title, relabel): - handlabs = self.legend_entries(relabel=relabel) + def set_legend_grid(self, legend, layout, title, relabel): + if layout is None or layout == "by_level": + handlabs = self.legend_entries_by_level(relabel=relabel) + elif layout == "by_element": + handlabs = self.legend_entries_by_element(relabel=relabel) + else: + raise ValueError("invalid legend_layout value") # FIXME to decide the number of columns, # also pass the figure size to determine the space available if legend == "horizontal": @@ -381,11 +396,8 @@ class Mapper(object): # make a grid to partition between plots and legend w, h = self.fig.get_size_inches() needed = lh / h + 0.01 - print("Space needed:", needed) - g_all = self.fig.add_gridspec( - 2, 1, - height_ratios=[1 - needed, needed] - ) + logger.debug("Space needed:", needed) + g_all = self.fig.add_gridspec(2, 1, height_ratios=[1 - needed, needed]) # put the plots into the first row g = g_all[0].subgridspec(self.nrows, self.ncols) # put the legend into the second row @@ -417,11 +429,8 @@ class Mapper(object): # make a grid to partition between plots and legend w, h = self.fig.get_size_inches() needed = lw / w + 0.01 - print("Space needed:", needed) - g_all = self.fig.add_gridspec( - 1, 2, - width_ratios=[1 - needed, needed] - ) + logger.debug("Space needed:", needed) + g_all = self.fig.add_gridspec(1, 2, width_ratios=[1 - needed, needed]) # put the plots into the first column g = g_all[0].subgridspec(self.nrows, self.ncols) # put the legend into the second column @@ -438,10 +447,10 @@ class Mapper(object): # return grid for plots return g - elif legend is None: + elif legend is None or not legend: return self.fig.add_gridspec(self.nrows, self.ncols) else: - raise ValueError("legend not 'vertical', 'horizontal', or None") + raise ValueError("legend not 'vertical', 'horizontal', False, or None") def shared_with(self, axes, row, col, share): if isinstance(share, np.ndarray): @@ -469,11 +478,7 @@ class Mapper(object): if xs is None: return [None] else: - return values( - self.levels[level], - deps=deps, - **kw - ) + return values(self.levels[level], deps=deps, **kw) def add_subplots(self, sharex, sharey, subplot_kw): # copied from Figure.subplots @@ -482,20 +487,13 @@ class Mapper(object): for col in range(self.ncols): subplot_kw["sharex"] = self.shared_with(axes, row, col, sharex) subplot_kw["sharey"] = self.shared_with(axes, row, col, sharey) - axes[row, col] = self.fig.add_subplot( - self.spec[row, col], - **subplot_kw - ) + axes[row, col] = self.fig.add_subplot(self.spec[row, col], **subplot_kw) # turn off redundant tick labeling if sharex in ["col", "all"]: # turn off all but the bottom row for ax in axes[:-1, :].flat: - ax.xaxis.set_tick_params( - which='both', - labelbottom=False, - labeltop=False - ) + ax.xaxis.set_tick_params(which="both", labelbottom=False, labeltop=False) ax.xaxis.offsetText.set_visible(False) # I cheat with the isinstance, for the more complex case @@ -503,11 +501,7 @@ class Mapper(object): if isinstance(sharey, np.ndarray) or sharey in ["row", "all"]: # turn off all but the first column for ax in axes[:, 1:].flat: - ax.yaxis.set_tick_params( - which='both', - labelleft=False, - labelright=False - ) + ax.yaxis.set_tick_params(which="both", labelleft=False, labelright=False) ax.yaxis.offsetText.set_visible(False) # despine all axes (top and right) @@ -574,24 +568,27 @@ class Mapper(object): ax = self.axes.loc[column, row] f(ax=ax, **kw) - def iter(self, f, levels=None): - """Call `f` for every combination of level values.""" + def sub_levels(self, levels=None): if levels is None: levels = self.used_levels else: # put in order and remove bad levels - levels = [ lvl for lvl in self.used_levels if lvl in levels ] + levels = [lvl for lvl in self.used_levels if lvl in levels] + + return levels - levelvs = [ lvl + "v" for lvl in levels ] + def iter(self, f, levels=None): + """Call `f` for every combination of level values.""" + levels = self.sub_levels(levels) + + levelvs = [lvl + "v" for lvl in levels] def get_values(lvl, **kw): return self.level_values(lvl[:-1], **kw) def g(**kw): self.call_with_maps( - f=partial(self.call_with_ax, f=f), - levels=axes_levels(levels), - **kw + f=partial(self.call_with_ax, f=f), levels=axes_levels(levels), **kw ) reciter( @@ -620,8 +617,9 @@ class Mapper(object): for columnv in self.level_values("column", rowv=rowv): f(ax=self.axes.loc[columnv, rowv], columnv=columnv) - def legend_entries(self, levels=None, relabel=None): + def legend_entries_by_level(self, levels=None, relabel=None): """Dict of legend entries optionally relabeled using function `relabel` by level.""" + # legend labels def lab(c): if relabel: @@ -638,7 +636,8 @@ class Mapper(object): # linestyle level def line(ls): return Line2D( - [0.], [0.], + [0.0], + [0.0], linestyle=ls, color="black", ) @@ -646,7 +645,8 @@ class Mapper(object): # marker level def mark(mk): return Line2D( - [0.], [0.], + [0.0], + [0.0], linestyle="", marker=mk, markerfacecolor="black", @@ -654,15 +654,13 @@ class Mapper(object): # alpha level def alpha_patch(a): - return Patch( - alpha=a, - color="black" - ) + return Patch(alpha=a, color="black") # linewidth level def linew(w): return Line2D( - [0.], [0.], + [0.0], + [0.0], linewidth=w, color="black", ) @@ -670,18 +668,17 @@ class Mapper(object): # markersize level def point(sz): return Line2D( - [0.], [0.], + [0.0], + [0.0], linestyle="", marker="o", markersize=sz, - markerfacecolor=(1., 1., 1., 0.), + markerfacecolor=(1.0, 1.0, 1.0, 0.0), color="black", ) - if levels is None: - levels = self.used_levels - # exclude 'row'/'column', for which we don't need a legend - levels = axes_levels(levels) + # only keep valid levels, in order, and exclude 'row'/'column' levels for which we don't need a legend + levels = axes_levels(self.sub_levels(levels)) handles = { "color": patch, @@ -696,15 +693,63 @@ class Mapper(object): return handles[lvl](self.levelmaps[lvl][v]) handlabs = { - lvl: [ - (handle(lvl, v), lab(v)) - for v in unique_leaf_values(self.levels[lvl]) - ] + lvl: [(handle(lvl, v), lab(v)) for v in unique_leaf_values(self.levels[lvl])] for lvl in levels } return handlabs + def legend_entries_by_element(self, levels=None, relabel=None): + # legend labels + def lab(*cs): + if relabel: + return relabel(*cs) + else: + return " | ".join([str(c) for c in cs]) + + def rich_handle( + color="black", + linestyle="", + linewidth=None, + marker=None, + markersize=None, + alpha=None, + ): + return Line2D( + [0.0], + [0.0], + color=color, + linestyle=linestyle, + linewidth=linewidth, + marker=marker, + markersize=markersize, + markerfacecolor=color, + alpha=alpha, + ) + + # only keep valid levels, in order, and exclude 'row'/'column' levels for which we don't need a legend + levels = axes_levels(self.sub_levels(levels)) + levelvs = [lvl + "v" for lvl in levels] + + handlabs = list() + + def f(**kw): + labs = [v for k, v in kw.items() if k.endswith("v")] + lvs = {k: v for k, v in kw.items() if not k.endswith("v")} + label = lab(*labs) + handle = rich_handle(**lvs) + handlabs.append((handle, label)) + + def get_values(lvl, **kw): + return self.level_values(lvl[:-1], **kw) + + def g(**kw): + self.call_with_maps(f=f, levels=levels, **kw) + + reciter(g, levels=levelvs, get_values=get_values) + + return {"": handlabs} + def _redraw(self): self.fig.draw(self.fig.canvas.get_renderer())