diff --git a/gridplot/__init__.py b/gridplot/__init__.py
index 6530b8115dcfaf78aef5cbb4d10b54d571b1c7bc..3edf8f61e70ed11006cd7670c57e2a14d996ed6c 100644
--- a/gridplot/__init__.py
+++ b/gridplot/__init__.py
@@ -1,6 +1,7 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 
+import logging
 from functools import partial, reduce
 from typing import TypeVar, Any, Union, Optional, Callable
 
@@ -39,6 +40,7 @@ T = TypeVar("T")
 T2 = TypeVar("T2")
 Tree = Optional[Union[list[str], dict[str, "Tree"]]]
 
+logger = logging.getLogger(__name__)
 
 # All levels, in dependency order
 hierarchy = [
@@ -55,7 +57,7 @@ hierarchy = [
 
 def axes_levels(levels: list[str]):
   """Levels used *inside* an axes."""
-  return [ lvl for lvl in levels if lvl not in ["row", "column"] ]
+  return [lvl for lvl in levels if lvl not in ["row", "column"]]
 
 
 def values(xs, deps: Optional[list[str]] = None, **kw):
@@ -86,9 +88,10 @@ def unique_leaf_values(xs: Tree):
   elif isinstance(xs, dict):
     return reduce(
       lambda x1, x2: list(dict.fromkeys(x1 + x2)),
-      [ unique_leaf_values(x) for x in xs.values() ],
+      [unique_leaf_values(x) for x in xs.values()],
     )
   else:
+    logger.error(f"unique_leaf_values not implemented for xs of type {type(xs)}: {xs}")
     raise NotImplementedError
 
 
@@ -130,7 +133,7 @@ def validate_deps(deps: Optional[dict[str, list[str]]], levels: list[str] = hier
     else:
       # it's ok to depend on previous set levels in the hierarchy only
       ok_levels = levels[:i]
-      bad_levels = [ lv for lv in lvl_deps if lv not in ok_levels ]
+      bad_levels = [lv for lv in lvl_deps if lv not in ok_levels]
       if bad_levels:
         # not empty
         raise ValueError(
@@ -173,7 +176,7 @@ def set_default_levelmaps(
       elif k == "marker":
         xs = list(MarkerStyle.markers.keys())[:n]
       elif k == "alpha":
-        xs = np.arange(start=1., stop=0., step=-1/n)[:-1]
+        xs = np.arange(start=1.0, stop=0.0, step=-1 / n)[:-1]
       elif k == "linewidth":
         # FIXME improve
         xs = range(n, 0, -1)
@@ -182,7 +185,7 @@ def set_default_levelmaps(
         xs = range(n, 0, -1)
       else:
         raise ValueError("level not recognized")
-      levelmaps[k] = { v: x for v, x in zip(vs, xs) }
+      levelmaps[k] = {v: x for v, x in zip(vs, xs)}
 
 
 def empty_handle() -> Patch:
@@ -199,11 +202,12 @@ def legend_columns(
   relabel: Optional[Callable[[str], str]] = None,
 ):
   if relabel is None:
+
     def relabel(s):
       return s
 
   # `1 + ...` for the subtitle
-  lens = { lv: 1 + len(li) for lv, li in handlabs.items() }
+  lens = {lv: 1 + len(li) for lv, li in handlabs.items()}
   max_len = max(lens.values())
   clen = col_len(max_len)
 
@@ -264,22 +268,14 @@ def horizontal_legend_columns(handlabs, relabel):
   def col_len(max_len):
     return min(max_len, 5)
 
-  return legend_columns(
-    handlabs=handlabs,
-    col_len=col_len,
-    relabel=relabel
-  )
+  return legend_columns(handlabs=handlabs, col_len=col_len, relabel=relabel)
 
 
 def vertical_legend_columns(handlabs, relabel):
   def col_len(max_len):
     return max(max_len, 10)
 
-  return legend_columns(
-    handlabs=handlabs,
-    col_len=col_len,
-    relabel=relabel
-  )
+  return legend_columns(handlabs=handlabs, col_len=col_len, relabel=relabel)
 
 
 class Mapper(object):
@@ -294,15 +290,31 @@ class Mapper(object):
   """
 
   def __init__(
-    self, deps=None,
-    rows=None, columns=None, colors=None, linestyles=None, markers=None,
-    alphas=None, linewidths=None, markersizes=None,
-    palette=None, stylemap=None, markermap=None,
-    alphamap=None, linewidthmap=None, markersizemap=None,
-    legend=None, legend_title=None, label_legend=None,
-    sharex=None, sharey=None,
-    aspect=None, figsize=None
-  ) :
+    self,
+    deps=None,
+    rows=None,
+    columns=None,
+    colors=None,
+    linestyles=None,
+    markers=None,
+    alphas=None,
+    linewidths=None,
+    markersizes=None,
+    palette=None,
+    stylemap=None,
+    markermap=None,
+    alphamap=None,
+    linewidthmap=None,
+    markersizemap=None,
+    legend=None,
+    legend_layout=None,
+    legend_title=None,
+    label_legend=None,
+    sharex=None,
+    sharey=None,
+    aspect=None,
+    figsize=None,
+  ):
     """Set up the figure, its layout and its subplots."""
     self.levels = {
       "row": rows,
@@ -315,7 +327,7 @@ class Mapper(object):
       "markersize": markersizes,
     }
 
-    self.used_levels = [ lvl for lvl in hierarchy if self.levels[lvl] is not None ]
+    self.used_levels = [lvl for lvl in hierarchy if self.levels[lvl] is not None]
     self.deps = validate_deps(deps, self.used_levels)
 
     self.levelmaps = {
@@ -343,23 +355,26 @@ class Mapper(object):
     self.fig = f
 
     self.spec = self.set_legend_grid(
-      legend,
-      title=legend_title,
-      relabel=label_legend
+      legend, layout=legend_layout, title=legend_title, relabel=label_legend
     )
 
     # improve constrained layout
     f.set_constrained_layout_pads(
       w_pad=0,
       h_pad=0,
-      wspace=0.01,
-      hspace=0.01,
+      wspace=0.05,
+      hspace=0.05,
     )
 
     self.add_subplots(sharex, sharey, subplot_kw)
 
-  def set_legend_grid(self, legend, title, relabel):
-    handlabs = self.legend_entries(relabel=relabel)
+  def set_legend_grid(self, legend, layout, title, relabel):
+    if layout is None or layout == "by_level":
+      handlabs = self.legend_entries_by_level(relabel=relabel)
+    elif layout == "by_element":
+      handlabs = self.legend_entries_by_element(relabel=relabel)
+    else:
+      raise ValueError("invalid legend_layout value")
     # FIXME to decide the number of columns,
     # also pass the figure size to determine the space available
     if legend == "horizontal":
@@ -381,11 +396,8 @@ class Mapper(object):
       # make a grid to partition between plots and legend
       w, h = self.fig.get_size_inches()
       needed = lh / h + 0.01
-      print("Space needed:", needed)
-      g_all = self.fig.add_gridspec(
-        2, 1,
-        height_ratios=[1 - needed, needed]
-      )
+      logger.debug("Space needed:", needed)
+      g_all = self.fig.add_gridspec(2, 1, height_ratios=[1 - needed, needed])
       # put the plots into the first row
       g = g_all[0].subgridspec(self.nrows, self.ncols)
       # put the legend into the second row
@@ -417,11 +429,8 @@ class Mapper(object):
       # make a grid to partition between plots and legend
       w, h = self.fig.get_size_inches()
       needed = lw / w + 0.01
-      print("Space needed:", needed)
-      g_all = self.fig.add_gridspec(
-        1, 2,
-        width_ratios=[1 - needed, needed]
-      )
+      logger.debug("Space needed:", needed)
+      g_all = self.fig.add_gridspec(1, 2, width_ratios=[1 - needed, needed])
       # put the plots into the first column
       g = g_all[0].subgridspec(self.nrows, self.ncols)
       # put the legend into the second column
@@ -438,10 +447,10 @@ class Mapper(object):
 
       # return grid for plots
       return g
-    elif legend is None:
+    elif legend is None or not legend:
       return self.fig.add_gridspec(self.nrows, self.ncols)
     else:
-      raise ValueError("legend not 'vertical', 'horizontal', or None")
+      raise ValueError("legend not 'vertical', 'horizontal', False, or None")
 
   def shared_with(self, axes, row, col, share):
     if isinstance(share, np.ndarray):
@@ -469,11 +478,7 @@ class Mapper(object):
     if xs is None:
       return [None]
     else:
-      return values(
-        self.levels[level],
-        deps=deps,
-        **kw
-      )
+      return values(self.levels[level], deps=deps, **kw)
 
   def add_subplots(self, sharex, sharey, subplot_kw):
     # copied from Figure.subplots
@@ -482,20 +487,13 @@ class Mapper(object):
       for col in range(self.ncols):
         subplot_kw["sharex"] = self.shared_with(axes, row, col, sharex)
         subplot_kw["sharey"] = self.shared_with(axes, row, col, sharey)
-        axes[row, col] = self.fig.add_subplot(
-          self.spec[row, col],
-          **subplot_kw
-        )
+        axes[row, col] = self.fig.add_subplot(self.spec[row, col], **subplot_kw)
 
     # turn off redundant tick labeling
     if sharex in ["col", "all"]:
       # turn off all but the bottom row
       for ax in axes[:-1, :].flat:
-        ax.xaxis.set_tick_params(
-          which='both',
-          labelbottom=False,
-          labeltop=False
-        )
+        ax.xaxis.set_tick_params(which="both", labelbottom=False, labeltop=False)
         ax.xaxis.offsetText.set_visible(False)
 
     # I cheat with the isinstance, for the more complex case
@@ -503,11 +501,7 @@ class Mapper(object):
     if isinstance(sharey, np.ndarray) or sharey in ["row", "all"]:
       # turn off all but the first column
       for ax in axes[:, 1:].flat:
-        ax.yaxis.set_tick_params(
-          which='both',
-          labelleft=False,
-          labelright=False
-        )
+        ax.yaxis.set_tick_params(which="both", labelleft=False, labelright=False)
         ax.yaxis.offsetText.set_visible(False)
 
     # despine all axes (top and right)
@@ -574,24 +568,27 @@ class Mapper(object):
     ax = self.axes.loc[column, row]
     f(ax=ax, **kw)
 
-  def iter(self, f, levels=None):
-    """Call `f` for every combination of level values."""
+  def sub_levels(self, levels=None):
     if levels is None:
       levels = self.used_levels
     else:
       # put in order and remove bad levels
-      levels = [ lvl for lvl in self.used_levels if lvl in levels ]
+      levels = [lvl for lvl in self.used_levels if lvl in levels]
+
+    return levels
 
-    levelvs = [ lvl + "v" for lvl in levels ]
+  def iter(self, f, levels=None):
+    """Call `f` for every combination of level values."""
+    levels = self.sub_levels(levels)
+
+    levelvs = [lvl + "v" for lvl in levels]
 
     def get_values(lvl, **kw):
       return self.level_values(lvl[:-1], **kw)
 
     def g(**kw):
       self.call_with_maps(
-        f=partial(self.call_with_ax, f=f),
-        levels=axes_levels(levels),
-        **kw
+        f=partial(self.call_with_ax, f=f), levels=axes_levels(levels), **kw
       )
 
     reciter(
@@ -620,8 +617,9 @@ class Mapper(object):
     for columnv in self.level_values("column", rowv=rowv):
       f(ax=self.axes.loc[columnv, rowv], columnv=columnv)
 
-  def legend_entries(self, levels=None, relabel=None):
+  def legend_entries_by_level(self, levels=None, relabel=None):
     """Dict of legend entries optionally relabeled using function `relabel` by level."""
+
     # legend labels
     def lab(c):
       if relabel:
@@ -638,7 +636,8 @@ class Mapper(object):
     # linestyle level
     def line(ls):
       return Line2D(
-        [0.], [0.],
+        [0.0],
+        [0.0],
         linestyle=ls,
         color="black",
       )
@@ -646,7 +645,8 @@ class Mapper(object):
     # marker level
     def mark(mk):
       return Line2D(
-        [0.], [0.],
+        [0.0],
+        [0.0],
         linestyle="",
         marker=mk,
         markerfacecolor="black",
@@ -654,15 +654,13 @@ class Mapper(object):
 
     # alpha level
     def alpha_patch(a):
-      return Patch(
-        alpha=a,
-        color="black"
-      )
+      return Patch(alpha=a, color="black")
 
     # linewidth level
     def linew(w):
       return Line2D(
-        [0.], [0.],
+        [0.0],
+        [0.0],
         linewidth=w,
         color="black",
       )
@@ -670,18 +668,17 @@ class Mapper(object):
     # markersize level
     def point(sz):
       return Line2D(
-        [0.], [0.],
+        [0.0],
+        [0.0],
         linestyle="",
         marker="o",
         markersize=sz,
-        markerfacecolor=(1., 1., 1., 0.),
+        markerfacecolor=(1.0, 1.0, 1.0, 0.0),
         color="black",
       )
 
-    if levels is None:
-      levels = self.used_levels
-    # exclude 'row'/'column', for which we don't need a legend
-    levels = axes_levels(levels)
+    # only keep valid levels, in order, and exclude 'row'/'column' levels for which we don't need a legend
+    levels = axes_levels(self.sub_levels(levels))
 
     handles = {
       "color": patch,
@@ -696,15 +693,63 @@ class Mapper(object):
       return handles[lvl](self.levelmaps[lvl][v])
 
     handlabs = {
-      lvl: [
-        (handle(lvl, v), lab(v))
-        for v in unique_leaf_values(self.levels[lvl])
-      ]
+      lvl: [(handle(lvl, v), lab(v)) for v in unique_leaf_values(self.levels[lvl])]
       for lvl in levels
     }
 
     return handlabs
 
+  def legend_entries_by_element(self, levels=None, relabel=None):
+    # legend labels
+    def lab(*cs):
+      if relabel:
+        return relabel(*cs)
+      else:
+        return " | ".join([str(c) for c in cs])
+
+    def rich_handle(
+      color="black",
+      linestyle="",
+      linewidth=None,
+      marker=None,
+      markersize=None,
+      alpha=None,
+    ):
+      return Line2D(
+        [0.0],
+        [0.0],
+        color=color,
+        linestyle=linestyle,
+        linewidth=linewidth,
+        marker=marker,
+        markersize=markersize,
+        markerfacecolor=color,
+        alpha=alpha,
+      )
+
+    # only keep valid levels, in order, and exclude 'row'/'column' levels for which we don't need a legend
+    levels = axes_levels(self.sub_levels(levels))
+    levelvs = [lvl + "v" for lvl in levels]
+
+    handlabs = list()
+
+    def f(**kw):
+      labs = [v for k, v in kw.items() if k.endswith("v")]
+      lvs = {k: v for k, v in kw.items() if not k.endswith("v")}
+      label = lab(*labs)
+      handle = rich_handle(**lvs)
+      handlabs.append((handle, label))
+
+    def get_values(lvl, **kw):
+      return self.level_values(lvl[:-1], **kw)
+
+    def g(**kw):
+      self.call_with_maps(f=f, levels=levels, **kw)
+
+    reciter(g, levels=levelvs, get_values=get_values)
+
+    return {"": handlabs}
+
   def _redraw(self):
     self.fig.draw(self.fig.canvas.get_renderer())