From 1327cd1fa6cbdda7a7593a5100df261bb594a6b1 Mon Sep 17 00:00:00 2001
From: Benjamin Nguyen-Van-Yen <benjamin.nguyen-van-yen@inria.fr>
Date: Mon, 15 Jul 2024 12:51:25 +0200
Subject: [PATCH 1/4] Auto format

---
 gridplot/__init__.py | 95 +++++++++++++++-----------------------------
 1 file changed, 31 insertions(+), 64 deletions(-)

diff --git a/gridplot/__init__.py b/gridplot/__init__.py
index 6530b81..d1b171e 100644
--- a/gridplot/__init__.py
+++ b/gridplot/__init__.py
@@ -55,7 +55,7 @@ hierarchy = [
 
 def axes_levels(levels: list[str]):
   """Levels used *inside* an axes."""
-  return [ lvl for lvl in levels if lvl not in ["row", "column"] ]
+  return [lvl for lvl in levels if lvl not in ["row", "column"]]
 
 
 def values(xs, deps: Optional[list[str]] = None, **kw):
@@ -86,7 +86,7 @@ def unique_leaf_values(xs: Tree):
   elif isinstance(xs, dict):
     return reduce(
       lambda x1, x2: list(dict.fromkeys(x1 + x2)),
-      [ unique_leaf_values(x) for x in xs.values() ],
+      [unique_leaf_values(x) for x in xs.values()],
     )
   else:
     raise NotImplementedError
@@ -130,7 +130,7 @@ def validate_deps(deps: Optional[dict[str, list[str]]], levels: list[str] = hier
     else:
       # it's ok to depend on previous set levels in the hierarchy only
       ok_levels = levels[:i]
-      bad_levels = [ lv for lv in lvl_deps if lv not in ok_levels ]
+      bad_levels = [lv for lv in lvl_deps if lv not in ok_levels]
       if bad_levels:
         # not empty
         raise ValueError(
@@ -173,7 +173,7 @@ def set_default_levelmaps(
       elif k == "marker":
         xs = list(MarkerStyle.markers.keys())[:n]
       elif k == "alpha":
-        xs = np.arange(start=1., stop=0., step=-1/n)[:-1]
+        xs = np.arange(start=1.0, stop=0.0, step=-1 / n)[:-1]
       elif k == "linewidth":
         # FIXME improve
         xs = range(n, 0, -1)
@@ -182,7 +182,7 @@ def set_default_levelmaps(
         xs = range(n, 0, -1)
       else:
         raise ValueError("level not recognized")
-      levelmaps[k] = { v: x for v, x in zip(vs, xs) }
+      levelmaps[k] = {v: x for v, x in zip(vs, xs)}
 
 
 def empty_handle() -> Patch:
@@ -199,11 +199,12 @@ def legend_columns(
   relabel: Optional[Callable[[str], str]] = None,
 ):
   if relabel is None:
+
     def relabel(s):
       return s
 
   # `1 + ...` for the subtitle
-  lens = { lv: 1 + len(li) for lv, li in handlabs.items() }
+  lens = {lv: 1 + len(li) for lv, li in handlabs.items()}
   max_len = max(lens.values())
   clen = col_len(max_len)
 
@@ -264,22 +265,14 @@ def horizontal_legend_columns(handlabs, relabel):
   def col_len(max_len):
     return min(max_len, 5)
 
-  return legend_columns(
-    handlabs=handlabs,
-    col_len=col_len,
-    relabel=relabel
-  )
+  return legend_columns(handlabs=handlabs, col_len=col_len, relabel=relabel)
 
 
 def vertical_legend_columns(handlabs, relabel):
   def col_len(max_len):
     return max(max_len, 10)
 
-  return legend_columns(
-    handlabs=handlabs,
-    col_len=col_len,
-    relabel=relabel
-  )
+  return legend_columns(handlabs=handlabs, col_len=col_len, relabel=relabel)
 
 
 class Mapper(object):
@@ -315,7 +308,7 @@ class Mapper(object):
       "markersize": markersizes,
     }
 
-    self.used_levels = [ lvl for lvl in hierarchy if self.levels[lvl] is not None ]
+    self.used_levels = [lvl for lvl in hierarchy if self.levels[lvl] is not None]
     self.deps = validate_deps(deps, self.used_levels)
 
     self.levelmaps = {
@@ -343,9 +336,7 @@ class Mapper(object):
     self.fig = f
 
     self.spec = self.set_legend_grid(
-      legend,
-      title=legend_title,
-      relabel=label_legend
+      legend, layout=legend_layout, title=legend_title, relabel=label_legend
     )
 
     # improve constrained layout
@@ -382,10 +373,7 @@ class Mapper(object):
       w, h = self.fig.get_size_inches()
       needed = lh / h + 0.01
       print("Space needed:", needed)
-      g_all = self.fig.add_gridspec(
-        2, 1,
-        height_ratios=[1 - needed, needed]
-      )
+      g_all = self.fig.add_gridspec(2, 1, height_ratios=[1 - needed, needed])
       # put the plots into the first row
       g = g_all[0].subgridspec(self.nrows, self.ncols)
       # put the legend into the second row
@@ -418,10 +406,7 @@ class Mapper(object):
       w, h = self.fig.get_size_inches()
       needed = lw / w + 0.01
       print("Space needed:", needed)
-      g_all = self.fig.add_gridspec(
-        1, 2,
-        width_ratios=[1 - needed, needed]
-      )
+      g_all = self.fig.add_gridspec(1, 2, width_ratios=[1 - needed, needed])
       # put the plots into the first column
       g = g_all[0].subgridspec(self.nrows, self.ncols)
       # put the legend into the second column
@@ -469,11 +454,7 @@ class Mapper(object):
     if xs is None:
       return [None]
     else:
-      return values(
-        self.levels[level],
-        deps=deps,
-        **kw
-      )
+      return values(self.levels[level], deps=deps, **kw)
 
   def add_subplots(self, sharex, sharey, subplot_kw):
     # copied from Figure.subplots
@@ -482,20 +463,13 @@ class Mapper(object):
       for col in range(self.ncols):
         subplot_kw["sharex"] = self.shared_with(axes, row, col, sharex)
         subplot_kw["sharey"] = self.shared_with(axes, row, col, sharey)
-        axes[row, col] = self.fig.add_subplot(
-          self.spec[row, col],
-          **subplot_kw
-        )
+        axes[row, col] = self.fig.add_subplot(self.spec[row, col], **subplot_kw)
 
     # turn off redundant tick labeling
     if sharex in ["col", "all"]:
       # turn off all but the bottom row
       for ax in axes[:-1, :].flat:
-        ax.xaxis.set_tick_params(
-          which='both',
-          labelbottom=False,
-          labeltop=False
-        )
+        ax.xaxis.set_tick_params(which="both", labelbottom=False, labeltop=False)
         ax.xaxis.offsetText.set_visible(False)
 
     # I cheat with the isinstance, for the more complex case
@@ -503,11 +477,7 @@ class Mapper(object):
     if isinstance(sharey, np.ndarray) or sharey in ["row", "all"]:
       # turn off all but the first column
       for ax in axes[:, 1:].flat:
-        ax.yaxis.set_tick_params(
-          which='both',
-          labelleft=False,
-          labelright=False
-        )
+        ax.yaxis.set_tick_params(which="both", labelleft=False, labelright=False)
         ax.yaxis.offsetText.set_visible(False)
 
     # despine all axes (top and right)
@@ -582,16 +552,14 @@ class Mapper(object):
       # put in order and remove bad levels
       levels = [ lvl for lvl in self.used_levels if lvl in levels ]
 
-    levelvs = [ lvl + "v" for lvl in levels ]
+    levelvs = [lvl + "v" for lvl in levels]
 
     def get_values(lvl, **kw):
       return self.level_values(lvl[:-1], **kw)
 
     def g(**kw):
       self.call_with_maps(
-        f=partial(self.call_with_ax, f=f),
-        levels=axes_levels(levels),
-        **kw
+        f=partial(self.call_with_ax, f=f), levels=axes_levels(levels), **kw
       )
 
     reciter(
@@ -622,6 +590,7 @@ class Mapper(object):
 
   def legend_entries(self, levels=None, relabel=None):
     """Dict of legend entries optionally relabeled using function `relabel` by level."""
+
     # legend labels
     def lab(c):
       if relabel:
@@ -638,7 +607,8 @@ class Mapper(object):
     # linestyle level
     def line(ls):
       return Line2D(
-        [0.], [0.],
+        [0.0],
+        [0.0],
         linestyle=ls,
         color="black",
       )
@@ -646,7 +616,8 @@ class Mapper(object):
     # marker level
     def mark(mk):
       return Line2D(
-        [0.], [0.],
+        [0.0],
+        [0.0],
         linestyle="",
         marker=mk,
         markerfacecolor="black",
@@ -654,15 +625,13 @@ class Mapper(object):
 
     # alpha level
     def alpha_patch(a):
-      return Patch(
-        alpha=a,
-        color="black"
-      )
+      return Patch(alpha=a, color="black")
 
     # linewidth level
     def linew(w):
       return Line2D(
-        [0.], [0.],
+        [0.0],
+        [0.0],
         linewidth=w,
         color="black",
       )
@@ -670,11 +639,12 @@ class Mapper(object):
     # markersize level
     def point(sz):
       return Line2D(
-        [0.], [0.],
+        [0.0],
+        [0.0],
         linestyle="",
         marker="o",
         markersize=sz,
-        markerfacecolor=(1., 1., 1., 0.),
+        markerfacecolor=(1.0, 1.0, 1.0, 0.0),
         color="black",
       )
 
@@ -696,10 +666,7 @@ class Mapper(object):
       return handles[lvl](self.levelmaps[lvl][v])
 
     handlabs = {
-      lvl: [
-        (handle(lvl, v), lab(v))
-        for v in unique_leaf_values(self.levels[lvl])
-      ]
+      lvl: [(handle(lvl, v), lab(v)) for v in unique_leaf_values(self.levels[lvl])]
       for lvl in levels
     }
 
-- 
GitLab


From 34c6ad4e08479f8457f0d56185d018c409e0fb31 Mon Sep 17 00:00:00 2001
From: Benjamin Nguyen-Van-Yen <benjamin.nguyen-van-yen@inria.fr>
Date: Mon, 15 Jul 2024 12:53:04 +0200
Subject: [PATCH 2/4] Add legend_layout arg to layout legend by element

Adds the possibility to layout legend by element instead of by level.
---
 gridplot/__init__.py | 117 +++++++++++++++++++++++++++++++++++--------
 1 file changed, 96 insertions(+), 21 deletions(-)

diff --git a/gridplot/__init__.py b/gridplot/__init__.py
index d1b171e..7ca1bf9 100644
--- a/gridplot/__init__.py
+++ b/gridplot/__init__.py
@@ -287,15 +287,31 @@ class Mapper(object):
   """
 
   def __init__(
-    self, deps=None,
-    rows=None, columns=None, colors=None, linestyles=None, markers=None,
-    alphas=None, linewidths=None, markersizes=None,
-    palette=None, stylemap=None, markermap=None,
-    alphamap=None, linewidthmap=None, markersizemap=None,
-    legend=None, legend_title=None, label_legend=None,
-    sharex=None, sharey=None,
-    aspect=None, figsize=None
-  ) :
+    self,
+    deps=None,
+    rows=None,
+    columns=None,
+    colors=None,
+    linestyles=None,
+    markers=None,
+    alphas=None,
+    linewidths=None,
+    markersizes=None,
+    palette=None,
+    stylemap=None,
+    markermap=None,
+    alphamap=None,
+    linewidthmap=None,
+    markersizemap=None,
+    legend=None,
+    legend_layout=None,
+    legend_title=None,
+    label_legend=None,
+    sharex=None,
+    sharey=None,
+    aspect=None,
+    figsize=None,
+  ):
     """Set up the figure, its layout and its subplots."""
     self.levels = {
       "row": rows,
@@ -349,8 +365,13 @@ class Mapper(object):
 
     self.add_subplots(sharex, sharey, subplot_kw)
 
-  def set_legend_grid(self, legend, title, relabel):
-    handlabs = self.legend_entries(relabel=relabel)
+  def set_legend_grid(self, legend, layout, title, relabel):
+    if layout is None or layout == "by_level":
+      handlabs = self.legend_entries_by_level(relabel=relabel)
+    elif layout == "by_element":
+      handlabs = self.legend_entries_by_element(relabel=relabel)
+    else:
+      raise ValueError("invalid legend_layout value")
     # FIXME to decide the number of columns,
     # also pass the figure size to determine the space available
     if legend == "horizontal":
@@ -423,10 +444,10 @@ class Mapper(object):
 
       # return grid for plots
       return g
-    elif legend is None:
+    elif legend is None or not legend:
       return self.fig.add_gridspec(self.nrows, self.ncols)
     else:
-      raise ValueError("legend not 'vertical', 'horizontal', or None")
+      raise ValueError("legend not 'vertical', 'horizontal', False, or None")
 
   def shared_with(self, axes, row, col, share):
     if isinstance(share, np.ndarray):
@@ -544,13 +565,18 @@ class Mapper(object):
     ax = self.axes.loc[column, row]
     f(ax=ax, **kw)
 
-  def iter(self, f, levels=None):
-    """Call `f` for every combination of level values."""
+  def sub_levels(self, levels=None):
     if levels is None:
       levels = self.used_levels
     else:
       # put in order and remove bad levels
-      levels = [ lvl for lvl in self.used_levels if lvl in levels ]
+      levels = [lvl for lvl in self.used_levels if lvl in levels]
+
+    return levels
+
+  def iter(self, f, levels=None):
+    """Call `f` for every combination of level values."""
+    levels = self.sub_levels(levels)
 
     levelvs = [lvl + "v" for lvl in levels]
 
@@ -588,7 +614,7 @@ class Mapper(object):
     for columnv in self.level_values("column", rowv=rowv):
       f(ax=self.axes.loc[columnv, rowv], columnv=columnv)
 
-  def legend_entries(self, levels=None, relabel=None):
+  def legend_entries_by_level(self, levels=None, relabel=None):
     """Dict of legend entries optionally relabeled using function `relabel` by level."""
 
     # legend labels
@@ -648,10 +674,8 @@ class Mapper(object):
         color="black",
       )
 
-    if levels is None:
-      levels = self.used_levels
-    # exclude 'row'/'column', for which we don't need a legend
-    levels = axes_levels(levels)
+    # only keep valid levels, in order, and exclude 'row'/'column' levels for which we don't need a legend
+    levels = axes_levels(self.sub_levels(levels))
 
     handles = {
       "color": patch,
@@ -672,6 +696,57 @@ class Mapper(object):
 
     return handlabs
 
+  def legend_entries_by_element(self, levels=None, relabel=None):
+    # legend labels
+    def lab(*cs):
+      if relabel:
+        return relabel(*cs)
+      else:
+        return " | ".join([str(c) for c in cs])
+
+    def rich_handle(
+      color="black",
+      linestyle="",
+      linewidth=None,
+      marker=None,
+      markersize=None,
+      alpha=None,
+    ):
+      return Line2D(
+        [0.0],
+        [0.0],
+        color=color,
+        linestyle=linestyle,
+        linewidth=linewidth,
+        marker=marker,
+        markersize=markersize,
+        markerfacecolor=color,
+        alpha=alpha,
+      )
+
+    # only keep valid levels, in order, and exclude 'row'/'column' levels for which we don't need a legend
+    levels = axes_levels(self.sub_levels(levels))
+    levelvs = [lvl + "v" for lvl in levels]
+
+    handlabs = list()
+
+    def f(**kw):
+      labs = [v for k, v in kw.items() if k.endswith("v")]
+      lvs = {k: v for k, v in kw.items() if not k.endswith("v")}
+      label = lab(*labs)
+      handle = rich_handle(**lvs)
+      handlabs.append((handle, label))
+
+    def get_values(lvl, **kw):
+      return self.level_values(lvl[:-1], **kw)
+
+    def g(**kw):
+      self.call_with_maps(f=f, levels=levels, **kw)
+
+    reciter(g, levels=levelvs, get_values=get_values)
+
+    return {"": handlabs}
+
   def _redraw(self):
     self.fig.draw(self.fig.canvas.get_renderer())
 
-- 
GitLab


From 4a48e0036c9315ea2fb78fc5aadf345bb5b79623 Mon Sep 17 00:00:00 2001
From: Benjamin Nguyen-Van-Yen <benjamin.nguyen-van-yen@inria.fr>
Date: Mon, 15 Jul 2024 14:03:47 +0200
Subject: [PATCH 3/4] Use logging

---
 gridplot/__init__.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/gridplot/__init__.py b/gridplot/__init__.py
index 7ca1bf9..3b33efc 100644
--- a/gridplot/__init__.py
+++ b/gridplot/__init__.py
@@ -1,6 +1,7 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 
+import logging
 from functools import partial, reduce
 from typing import TypeVar, Any, Union, Optional, Callable
 
@@ -39,6 +40,7 @@ T = TypeVar("T")
 T2 = TypeVar("T2")
 Tree = Optional[Union[list[str], dict[str, "Tree"]]]
 
+logger = logging.getLogger(__name__)
 
 # All levels, in dependency order
 hierarchy = [
@@ -89,6 +91,7 @@ def unique_leaf_values(xs: Tree):
       [unique_leaf_values(x) for x in xs.values()],
     )
   else:
+    logger.error(f"unique_leaf_values not implemented for xs of type {type(xs)}: {xs}")
     raise NotImplementedError
 
 
@@ -393,7 +396,7 @@ class Mapper(object):
       # make a grid to partition between plots and legend
       w, h = self.fig.get_size_inches()
       needed = lh / h + 0.01
-      print("Space needed:", needed)
+      logger.debug("Space needed:", needed)
       g_all = self.fig.add_gridspec(2, 1, height_ratios=[1 - needed, needed])
       # put the plots into the first row
       g = g_all[0].subgridspec(self.nrows, self.ncols)
@@ -426,7 +429,7 @@ class Mapper(object):
       # make a grid to partition between plots and legend
       w, h = self.fig.get_size_inches()
       needed = lw / w + 0.01
-      print("Space needed:", needed)
+      logger.debug("Space needed:", needed)
       g_all = self.fig.add_gridspec(1, 2, width_ratios=[1 - needed, needed])
       # put the plots into the first column
       g = g_all[0].subgridspec(self.nrows, self.ncols)
-- 
GitLab


From 65c4c1bb2da43f152174478103b08670f1f760b4 Mon Sep 17 00:00:00 2001
From: Benjamin Nguyen-Van-Yen <benjamin.nguyen-van-yen@inria.fr>
Date: Mon, 15 Jul 2024 14:03:57 +0200
Subject: [PATCH 4/4] Adjust facet margins

---
 gridplot/__init__.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/gridplot/__init__.py b/gridplot/__init__.py
index 3b33efc..3edf8f6 100644
--- a/gridplot/__init__.py
+++ b/gridplot/__init__.py
@@ -362,8 +362,8 @@ class Mapper(object):
     f.set_constrained_layout_pads(
       w_pad=0,
       h_pad=0,
-      wspace=0.01,
-      hspace=0.01,
+      wspace=0.05,
+      hspace=0.05,
     )
 
     self.add_subplots(sharex, sharey, subplot_kw)
-- 
GitLab