diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 07d14a8bac2f1bdb46a69cc9b09c493e6e15e8e4..b5b3d9165689d8cf2e25b91490f45a9e6580d070 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -1,5 +1,6 @@
 # make sure to pin image tag for ease of reproducibility
-image: continuumio/miniconda3:4.9.2
+image: continuumio/miniconda3
+#image: continuumio/miniconda3:4.9.2
 # image: mambaorg/micromamba
 
 variables:
@@ -69,7 +70,8 @@ workflow:
     - mkdir -p ${CONDA_ENVS}
     - conda config --prepend envs_dirs ${CONDA_ENVS}
     - conda update -n base conda &&
-    - conda install 'mamba<=1.4.5' -n base -c conda-forge &&
+    #- conda install 'mamba<=1.4.5' -n base -c conda-forge &&
+    - conda install mamba -n base -c conda-forge &&
     - conda update mamba -n base -c conda-forge
 
 .conda_setup: &conda_setup
@@ -186,7 +188,7 @@ docs:
 
         # libenchant is needed for sphinxcontrib.spelling by way of PyEnchant,
         # and is not easily installed via conda
-        - apt-get update --allow-releaseinfo-change && apt-get install -yqq libenchant-dev
+        - apt-get update --allow-releaseinfo-change && apt-get install -yqq libenchant-2-dev
     script:
         - cd docs/
         - make html
diff --git a/dnadna/__init__.py b/dnadna/__init__.py
index 91affc411c3410c2b4aafdf6819706443df967db..3bb76ace57fa7a6c9038cb8cb22b8d38fedc606a 100644
--- a/dnadna/__init__.py
+++ b/dnadna/__init__.py
@@ -24,7 +24,8 @@ BUILTIN_PLUGINS = [
     'dnadna.nets',
     'dnadna.optim',
     'dnadna.simulator',
-    'dnadna.transforms'
+    'dnadna.transforms',
+    'dnadna.collate'
 ]
 """
 List all internal modules that provide a `Pluggable` class.
diff --git a/dnadna/collate.py b/dnadna/collate.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6fa2042395ac3d055446c9af8ed3e4d1190c96f
--- /dev/null
+++ b/dnadna/collate.py
@@ -0,0 +1,334 @@
+import abc
+from math import ceil, floor
+
+import numpy as np
+import torch
+
+from .utils.config import Config
+from .utils.plugins import Pluggable
+
+
+class Collate(Pluggable, metaclass=abc.ABCMeta):
+    """
+    Parent `Pluggable` class to inherit from when users want to write their
+    own collate function.
+
+    Subclasses must define a schema and can use the implementation of
+    `CollateTraining` as an example of how the class should look.  The main
+    method to implement is ``_collate_batch``, which is called by default by
+    ``__call__``.
+
+    Here is an example using the `CollateTraining` subclass::
+
+        from dnadna.collate import CollateTraining
+        from dnadna.utils.config import Config
+
+        collate = CollateTraining.from_config(config=Config({
+            'collate': {
+                'snp_dim': 'max',
+                'indiv_dim': 'max',
+                'value_fill': -1,
+                'pad_right': True,
+                'pad_left': False,
+                'pad_bottom': True,
+                'pad_top': False
+            }
+        }))
+        collated_batch = collate(batch)
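+
+    To write a custom collate plugin, subclass `Collate`, give it a schema,
+    and implement ``_collate_batch``.  The sketch below is illustrative only
+    (``MyCollate`` is not part of DNADNA) and assumes that every sample in
+    the batch already has the same shape and comes with a target::
+
+        import torch
+
+        from dnadna.collate import Collate
+
+        class MyCollate(Collate):
+            # only the fill value is declared in this toy schema; the
+            # collate below never actually needs to pad
+            schema = {
+                'properties': {
+                    'value_fill': {'type': 'number'}
+                }
+            }
+
+            def _collate_batch(self, batch):
+                # drop missing samples, then stack scenario indices, inputs
+                # and targets into batch tensors
+                batch = [b for b in batch if b[2] is not None]
+                scen_idxs, _, samples, targets = zip(*batch)
+                inputs = torch.stack([s.tensor for s in samples]).float()
+                targets = torch.stack([t.float() for t in targets])
+                return [torch.tensor(scen_idxs, dtype=torch.long),
+                        inputs, targets]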
+    """
+
+    schema = {}
+
+    def __init__(self, config):
+        self.padding_params = config
+
+    def __repr__(self):
+        return str(self.padding_params)
+
+    def __str__(self):
+        return str(self.padding_params)
+
+    def __call__(self, batch):
+        return self._collate_batch(batch)
+
+    @abc.abstractmethod
+    def _collate_batch(self, batch):
+        """
+        Specifies how multiple scenario samples are collated into batches.
+
+        Each batch element is a single element as returned by
+        ``DNATrainingDataset.__getitem__``: ``(scenario_idx, replicate_idx,
+        snp_sample, target)``.
+
+        This is the main method to implement when inheriting from `Collate`;
+        see `CollateTraining` for an example implementation.
+        """
+
+    @classmethod
+    def from_config(cls, config, validate=True):
+        """
+        Initialize a `Collate` from the ``collate`` section of a
+        configuration file.
+
+        Examples
+        --------
+
+        >>> from dnadna.collate import CollateTraining
+        >>> from dnadna.utils.config import Config
+        >>> collate = CollateTraining.from_config(config = Config({
+        ...     'collate': {
+        ...         'snp_dim': 'max',
+        ...         'indiv_dim': 'max',
+        ...         'value_fill': -1,
+        ...         'pad_right': True,
+        ...         'pad_left': False,
+        ...         'pad_bottom': True,
+        ...         'pad_top': False
+        ...      }
+        ... }))
+        >>> collate
+        Config({'snp_dim': 'max', 'indiv_dim': 'max', 'value_fill': -1, 'pad_right': True,
+                'pad_left': False, 'pad_bottom': True, 'pad_top': False})
+        """
+
+        if not isinstance(config, Config):
+            config = Config(config)
+
+        if validate:
+            config.validate(cls.get_schema())
+        config = config['collate']
+
+        return cls(config)
+
+    @classmethod
+    def get_schema(cls):
+        """
+        Return the schema of this `Collate` implementation.
+        """
+
+        return cls.schema
+
+
+class CollateTraining(Collate):
+
+    schema = {
+        'properties': {
+            'value_fill': {'type': 'number'},
+            'snp_dim': {'type': ['string', 'number']},
+            'indiv_dim': {'type': ['string', 'number']},
+            'pad_right': {'type': 'boolean'},
+            'pad_left': {'type': 'boolean'},
+            'pad_bottom': {'type': 'boolean'},
+            'pad_top': {'type': 'boolean'}
+        }
+    }
+
+    def _collate_batch(self, batch):
+        """
+        Specifies how multiple scenario samples are collated into batches.
+
+        Each batch element is a single element as returned by
+        ``DNATrainingDataset.__getitem__``: ``(scenario_idx, replicate_idx,
+        snp_sample, target)``.
+
+        Input samples and targets are collated into batches "vertically",
+        so that the size of the first dimension is the number of items in
+        the batch.
+
+        Examples
+        --------
+
+        >>> import torch
+        >>> from dnadna.collate import CollateTraining
+        >>> from dnadna.utils.config import Config
+        >>> from dnadna.snp_sample import SNPSample
+        >>> fake_snps = [torch.rand(3, 3 + i) for i in range(5)]
+        >>> fake_snps = [SNPSample(s[1:], s[0]) for s in fake_snps]
+        >>> fake_params = [torch.rand(4, dtype=torch.float64) for _ in range(5)]
+        >>> fake_batch = list(zip(range(5), [0] * 5, fake_snps, fake_params))
+        >>> config = Config({
+        ...     'collate': {
+        ...         'snp_dim': 'max',
+        ...         'indiv_dim': 'max',
+        ...         'value_fill': -1,
+        ...         'pad_right': True,
+        ...         'pad_left': False,
+        ...         'pad_bottom': True,
+        ...         'pad_top': False
+        ...      }
+        ... })
+        >>> collate_func = CollateTraining.from_config(config=config)
+        >>> collated_batch = collate_func(fake_batch)
+        >>> scenario_idxs, inputs, targets = collated_batch
+        >>> bool((torch.arange(5) == scenario_idxs).all())
+        True
+        >>> inputs.shape  # last dim should be num SNPs in largest fake SNP
+        torch.Size([5, 3, 7])
+        >>> bool((inputs[0,:3,:3] == fake_snps[0].tensor).all())
+        True
+        >>> bool((inputs[0,3:,3:] == -1).all())
+        True
+        >>> bool((inputs[-1] == fake_snps[-1].tensor).all())
+        True
+        >>> targets.shape
+        torch.Size([5, 4])
+        >>> [bool((fake_params[bat].float() == targets[bat]).all())
+        ...  for bat in range(targets.shape[0])]
+        [True, True, True, True, True]
+        """
+
+        # filter any missing samples out of the batch
+        batch = list(filter(lambda it: it[2] is not None, batch))
+
+        # we have to consider the possibility of an empty batch (which could
+        # happen if batch_size=1 and the sample was missing)
+        if not batch:
+            return None
+
+        scen_idxs, _, samples, targets = zip(*batch)
+
+        # add padding so that all input SNP matrices end up with the same
+        # shape (samples may differ in their number of rows and in their
+        # number of SNPs (columns)); the configured ``value_fill`` (default
+        # -1) is used for the padded regions
+        # TODO: Question: Must nets explicitly account for this padding, and if
+        # so, how?  ReLU?
+        # NOTE: This extra step of filling unevenly-sized matrices with the
+        # fill value is unnecessary when using a dataset with uniform=True,
+        # so there should be an option to skip this step entirely.
+        inputs = [s.tensor for s in samples]
+
+        pad_top = self.padding_params['pad_top']
+        pad_bottom = self.padding_params['pad_bottom']
+        pad_left = self.padding_params['pad_left']
+        pad_right = self.padding_params['pad_right']
+
+        if not (pad_top or pad_bottom) or not (pad_left or pad_right):
+            raise ValueError("Invalid padding configuration. At least one horizontal "
+                             "and one vertical value must be set to true.")
+
+        snp_dim = self.padding_params['snp_dim']
+        indiv_dim = self.padding_params['indiv_dim']
+        filler = self.padding_params["value_fill"]
+        max_indiv_dim = max(input.shape[1] for input in inputs)
+        max_snp_dim = max(input.shape[0] for input in inputs)
+
+        # Determine target dimensions based on input parameters
+        target_indiv_dim = indiv_dim if isinstance(indiv_dim, int) else max_indiv_dim
+        target_snp_dim = snp_dim if isinstance(snp_dim, int) else max_snp_dim
+
+        # Initialize a tensor to store the padded inputs
+        new_inputs = torch.zeros((len(inputs), target_snp_dim, target_indiv_dim))
+
+        for i, input in enumerate(inputs):
+
+            # This case is when "snp_dim" and "indiv_dim" are both set to "max"
+            if target_indiv_dim == max_indiv_dim and target_snp_dim == max_snp_dim:
+                # Calculate padding dimensions based on boolean parameters
+                if not (pad_bottom and pad_top) and not (pad_left and pad_right):
+                    pad_dims = (
+                        (0 if not pad_top else max_snp_dim - input.shape[0],
+                         0 if not pad_bottom else max_snp_dim - input.shape[0]),
+
+                        (0 if not pad_left else max_indiv_dim - input.shape[1],
+                         0 if not pad_right else max_indiv_dim - input.shape[1])
+                    )
+                elif (pad_bottom and pad_top) and not (pad_left and pad_right):
+                    pad_dims = (
+                        # top then bottom
+                        # If the padding amount is odd, the extra row is added at the bottom
+                        (floor((max_snp_dim - input.shape[0]) / 2),
+                         ceil((max_snp_dim - input.shape[0]) / 2)),
+
+                        (0 if not pad_left else max_indiv_dim - input.shape[1],
+                         0 if not pad_right else max_indiv_dim - input.shape[1])
+                    )
+                elif not (pad_bottom and pad_top) and (pad_left and pad_right):
+                    # left then right
+                    # If the padding amount is odd, the extra column is added on the right
+                    pad_dims = (
+                        (0 if not pad_top else max_snp_dim - input.shape[0],
+                         0 if not pad_bottom else max_snp_dim - input.shape[0]),
+
+                        (floor((max_indiv_dim - input.shape[1]) / 2),
+                         ceil((max_indiv_dim - input.shape[1]) / 2)),
+                    )
+                else:
+                    pad_dims = (
+                        (floor((max_snp_dim - input.shape[0]) / 2),
+                         ceil((max_snp_dim - input.shape[0]) / 2)),
+
+                        (floor((max_indiv_dim - input.shape[1]) / 2),
+                         ceil((max_indiv_dim - input.shape[1]) / 2)),
+                    )
+
+            # This case is when "snp_dim" and/or "indiv_dim" are set to integer values
+            else:
+                if not (pad_bottom and pad_top) and not (pad_left and pad_right):
+                    pad_dims = (
+                        (max(0, (0 if not pad_top else target_snp_dim - input.shape[0])),
+                         max(0, 0 if not pad_bottom else target_snp_dim - input.shape[0])),
+
+                        (max(0, (0 if not pad_left else target_indiv_dim - input.shape[1])),
+                         max(0, 0 if not pad_right else target_indiv_dim - input.shape[1]))
+                    )
+                elif (pad_bottom and pad_top) and not (pad_left and pad_right):
+                    pad_dims = (
+                        # top then bottom
+                        # If the padding amount is odd, the extra row is added at the bottom
+                        (max(0, (floor((target_snp_dim - input.shape[0]) / 2))),
+                         max(0, ceil((target_snp_dim - input.shape[0]) / 2))),
+
+                        (max(0, (0 if not pad_left else target_indiv_dim - input.shape[1])),
+                         max(0, 0 if not pad_right else target_indiv_dim - input.shape[1]))
+                    )
+                elif not (pad_bottom and pad_top) and (pad_left and pad_right):
+                    # left then right
+                    # If the padding amount is odd, the extra column is added on the right
+                    pad_dims = (
+                        (max(0, (0 if not pad_top else target_snp_dim - input.shape[0])),
+                         max(0, 0 if not pad_bottom else target_snp_dim - input.shape[0])),
+
+                        (max(0, (floor((target_indiv_dim - input.shape[1]) / 2))),
+                         max(0, ceil((target_indiv_dim - input.shape[1]) / 2))),
+                    )
+                else:
+                    pad_dims = (
+                        (max(0, (floor((target_snp_dim - input.shape[0]) / 2))),
+                         max(0, ceil((target_snp_dim - input.shape[0]) / 2))),
+
+                        (max(0, (floor((target_indiv_dim - input.shape[1]) / 2))),
+                         max(0, ceil((target_indiv_dim - input.shape[1]) / 2))),
+                    )
+
+            # Pad with the fill value, then crop to the target dimensions if needed
+            padded_input = np.pad(input, pad_dims, mode='constant', constant_values=filler)
+            adjusted_input = padded_input[:target_snp_dim, :target_indiv_dim]
+
+            new_inputs[i, :, :] = torch.as_tensor(adjusted_input).clone().detach()
+
+        # Concatenate targets and ensure they are converted to single-precision
+        # floats for passing to GPU devices
+        # NOTE: targets can be all None if the Dataset was initialized without
+        # scenario_params
+        if targets[0] is not None:
+            n_parameters = len(targets[0])
+            targets = torch.cat(targets).reshape(-1, n_parameters).float()
+
+        return [torch.tensor(scen_idxs, dtype=torch.long), new_inputs, targets]
+
+
+class CollateInference(CollateTraining):
+
+    def _collate_batch(self, batch):
+        """
+        Like `CollateTraining._collate_batch`, but also returns the
+        replicate indices and sample paths (e.g. filenames), which are used
+        in the prediction output.  The returned list is ``[scenario_idxs,
+        replicate_idxs, paths, inputs, targets]``.
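+
+        A minimal usage sketch (``dataset`` and ``config`` are placeholder
+        names here, not specific DNADNA objects)::
+
+            from torch.utils.data import DataLoader
+
+            from dnadna.collate import CollateInference
+
+            collate = CollateInference.from_config(config)
+            loader = DataLoader(dataset, batch_size=8, collate_fn=collate)
+            for scen_idxs, rep_idxs, paths, inputs, targets in loader:
+                ...  # run the network on ``inputs``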
+        """
+
+        rep_idxs = [b[1] for b in batch]
+        paths = [b[2].path for b in batch]
+        out = super()._collate_batch(batch)
+        return (out[:1] +
+                [torch.tensor(rep_idxs, dtype=torch.long), paths] +
+                out[1:])
diff --git a/dnadna/datasets.py b/dnadna/datasets.py
index 7f8985fc3b164b1c68a28504c909b419e319750c..1c30b504ff7852fd5775ddb32e6779e502fc3bab 100644
--- a/dnadna/datasets.py
+++ b/dnadna/datasets.py
@@ -1283,94 +1283,6 @@
 
         return xfs
 
-    @staticmethod
-    def collate_batch(batch):
-        """
-        Specifies how multiple scenario samples are collated into batches.
-
-        Each batch element is a single element as returned by
-        ``DNATrainingDataset.__getitem__``: ``(scenario_idx, replicate_idx,
-        snp_sample, target)``.
-
-        For input samples and targets are collated into batches "vertically",
-        so that the size of the first dimension represents the number of items
-        in a batch.
-
-        Examples
-        --------
-
-        >>> import torch
-        >>> from dnadna.datasets import DNATrainingDataset
-        >>> from dnadna.snp_sample import SNPSample
-        >>> fake_snps = [torch.rand(3, 3 + i) for i in range(5)]
-        >>> fake_snps = [SNPSample(s[1:], s[0]) for s in fake_snps]
-        >>> fake_params = [torch.rand(4, dtype=torch.float64) for _ in range(5)]
-        >>> fake_batch = list(zip(range(5), [0] * 5, fake_snps, fake_params))
-        >>> collated_batch = DNATrainingDataset.collate_batch(fake_batch)
-        >>> scenario_idxs, inputs, targets = collated_batch
-        >>> bool((torch.arange(5) == scenario_idxs).all())
-        True
-        >>> inputs.shape  # last dim should be num SNPs in largest fake SNP
-        torch.Size([5, 3, 7])
-        >>> bool((inputs[0,:3,:3] == fake_snps[0].tensor).all())
-        True
-        >>> bool((inputs[0,3:,3:] == -1).all())
-        True
-        >>> bool((inputs[-1] == fake_snps[-1].tensor).all())
-        True
-        >>> targets.shape
-        torch.Size([5, 4])
-        >>> [bool((fake_params[bat].float() == targets[bat]).all())
-        ...  for bat in range(targets.shape[0])]
-        [True, True, True, True, True]
-        """
-
-        # filter any missing samples out of the batch
-        batch = list(filter(lambda it: it[2] is not None, batch))
-
-        # we have to consider the possibility of an empty batch (which could
-        # happen if batch_size=1 and the sample was missing)
-        if not batch:
-            return None
-
-        scen_idxs, _, samples, targets = zip(*batch)
-
-        # add padding so that all input SNPs are the same size and shape
-        # (though should all have the same number of rows, but may have
-        # different number of SNPs (columns)
-        # the value -1 is used for padded regions
-        # TODO: Question: Must nets explicitly account for this padding, and if
-        # so, how?  ReLU?
-        # NOTE: This extra step of filling unevenly-sized matrices with -1
-        # is unnecessary if we are using a dataset with uniform=True, so
-        # there should be an option to skip this step entirely.  For now we
-        # just check if all inputs are the same size
-        inputs = [s.tensor for s in samples]
-        example_shape = inputs[0].shape
-        if any(inp.shape != example_shape for inp in inputs):
-            max_ind_batch = np.max([inp.shape[0] for inp in inputs])
-            max_snp_batch = np.max([inp.shape[1] for inp in inputs])
-            new_inputs_shape = (len(batch), max_ind_batch, max_snp_batch)
-            new_inputs = torch.full(new_inputs_shape, -1, dtype=torch.float)
-            for idx, inp in enumerate(inputs):
-                # NOTE: I feel like there should be a more efficient way to fill an
-                # N-D tensor for a sequence of (N - 1)-D tensors, but at the moment
-                # I can't find it; to revisit
-                new_inputs[idx, :inp.shape[0], :inp.shape[1]] = inp
-        else:
-            # Just stack the inputs
-            new_inputs = torch.stack(inputs).float()
-
-        # Concatenate targets and ensure they are converted to single-precision
-        # floats for passing to GPU devices
-        # NOTE: targets can be all None if the Dataset was initialized without
-        # scenario_params
-        if targets[0] is not None:
-            n_parameters = len(targets[0])
-            targets = torch.cat(targets).reshape(-1, n_parameters).float()
-
-        return [torch.tensor(scen_idxs, dtype=torch.long), new_inputs, targets]
-
     def _iter_samples(self):
         """
         Iterates over each replicate for each scenario in
diff --git a/dnadna/defaults/training.yml b/dnadna/defaults/training.yml
index 30ba6db72a897bd87d83498957c8f336f475fc11..ba9268486dee8cc2820996d0ff615ce8a1520333 100644
--- a/dnadna/defaults/training.yml
+++ b/dnadna/defaults/training.yml
@@ -14,6 +14,14 @@ dataset_transforms:
     - snp_format: concat
     - validate_snp:
           uniform_shape: false
+collate:
+    snp_dim: max
+    indiv_dim: max
+    value_fill: -1
+    pad_right: true
+    pad_left: false
+    pad_bottom: true
+    pad_top: false
 n_epochs: 1
 evaluation_interval: 1
 batch_size: 8
diff --git a/dnadna/examples/one_event.py b/dnadna/examples/one_event.py
index 6bc3d8381ad8a4c20db8f16133812d1ba62c57a3..e351e8e3a047feae5e9f0ddd30f462ea651795ae 100644
--- a/dnadna/examples/one_event.py
+++ b/dnadna/examples/one_event.py
@@ -132,6 +132,14 @@ DEFAULT_ONE_EVENT_TRAINING_CONFIG = Config({
     'evaluation_interval': 10,
     'loader_num_workers': 4,
     'device': '',
+    'collate': {
+        'snp_dim': 'max',
+        'indiv_dim': 'max',
+        'value_fill': -1,
+        'pad_right': True,
+        'pad_left': False,
+        'pad_bottom': True,
+        'pad_top': False}
 })
 
 
diff --git a/dnadna/nets.py b/dnadna/nets.py
index 82506f244fb353e61f744b9bcb79e3913a991c6b..c1a2a8e7fbdf394868a4df01991695649c529da7 100644
--- a/dnadna/nets.py
+++ b/dnadna/nets.py
@@ -163,8 +163,8 @@ class SPIDNA(Network):
     Publication
     -----------
     T. Sanchez, J. Cury, G. Charpiat, et F. Jay,
-    « Deep learning for population size history inference: Design, comparison and
-    combination with approximate Bayesian computation »,
+    « Deep learning for population size history inference: Design, comparison and
+    combination with approximate Bayesian computation »,
     Mol Ecol Resour, p. 1755‑0998.13224, juill. 2020, doi: 10.1111/1755-0998.13224.
 
     Parameters
@@ -236,8 +236,8 @@ class CustomCNN(Network):
     Publication
     -----------
     T. Sanchez, J. Cury, G. Charpiat, et F. Jay,
-    « Deep learning for population size history inference: Design, comparison and
-    combination with approximate Bayesian computation »,
+    « Deep learning for population size history inference: Design, comparison and
+    combination with approximate Bayesian computation »,
     Mol Ecol Resour, p. 1755‑0998.13224, juill. 2020, doi: 10.1111/1755-0998.13224.
 
     Parameters
diff --git a/dnadna/prediction.py b/dnadna/prediction.py
index 9f8945c9db686e16eb5a90f47912a64a3ee21303..2b022178ddefea56214b242780568ca12b763d45 100644
--- a/dnadna/prediction.py
+++ b/dnadna/prediction.py
@@ -22,6 +22,7 @@ from .nets import Network
 from .params import ParamSet
 from .utils.config import Config, ConfigMixIn
 from .utils.misc import stdio_redirect_tqdm
+from .collate import CollateInference
 
 log = logging.getLogger(__name__)
 
@@ -52,26 +53,6 @@ class DNAPredictionDataset(DatasetTransformationMixIn):
             self.preprocess = False
             self.min_snp = self.min_indiv = None
 
-    @classmethod
-    def collate_batch(cls, batch):
-        """
-        Like `DatasetTransformationMixIn.collate_batch` but also returns the
-        replicate indices and sample paths (e.g. filenames) which are used
-        in the prediction output.
-
-        .. note::
-
-            This is implemented as a `classmethod` which is necessary to be
-            able to call the super-class's ``collate_batch``.
-        """
-
-        rep_idxs = [b[1] for b in batch]
-        paths = [b[2].path for b in batch]
-        out = super().collate_batch(batch)
-        return (out[:1] +
-                [torch.tensor(rep_idxs, dtype=torch.long), paths] +
-                out[1:])
-
     def _iter_samples(self):
         idx = 0
         for _, sample in super()._iter_samples():
@@ -99,6 +80,7 @@ class Predictor(ConfigMixIn):
         self.net = net
         # Place the network in evaluation mode
         self.net.eval()
+        self.collate_func = CollateInference.from_config(config)
         self.params = ParamSet(self.learned_params)
         # Update predict_transforms attribute with the merge of
         # dataset_transforms and predict_transforms (which has priority):
@@ -194,9 +176,17 @@ class Predictor(ConfigMixIn):
         ...             'type': 'classification',
         ...             'classes': ['yes', 'no']
         ...         }
+        ...     },
+        ...     'collate': {
+        ...         'snp_dim': 'max',
+        ...         'indiv_dim': 'max',
+        ...         'value_fill': -1,
+        ...         'pad_right': True,
+        ...         'pad_left': False,
+        ...         'pad_bottom': True,
+        ...         'pad_top': False
         ...     }
         ... })
-        ...
 
         Initialize a fake instance of the net; this along with the config are
         the bare minimum needed to instantiate a `Predictor`:
@@ -474,7 +464,7 @@ class Predictor(ConfigMixIn):
         loader = DataLoader(dataset=dataset, batch_size=batch_size,
                             num_workers=loader_num_workers,
                             pin_memory=device.type == 'cuda',
-                            collate_fn=DNAPredictionDataset.collate_batch,
+                            collate_fn=self.collate_func,
                             sampler=sampler)
         if device.type == 'cuda':
             self.net = nn.DataParallel(self.net, device_ids=os.environ['CUDA_VISIBLE_DEVICES'])
diff --git a/dnadna/schemas/training.yml b/dnadna/schemas/training.yml
index 8ff49d1fb60f7f5986ff05c07c8d9a21556e500b..bf9d29d57efc8e1292be8b9d0292ec10c5be29e2 100644
--- a/dnadna/schemas/training.yml
+++ b/dnadna/schemas/training.yml
@@ -61,11 +61,26 @@ allOf:
               minimum: 1
               default: 1
 
-          batch_size:
-              description: "sample batch size to train on"
-              type: "integer"
-              minimum: 1
-              default: 1
+          collate:
+              description: >-
+                options for the collate function applied to each batch. snp_dim is the number
+                of SNPs per batch element: if set to max, each element of the batch is padded
+                up to the largest SNP dimension in the batch; if set to an integer, all
+                elements are padded or truncated to that dimension. indiv_dim works the same
+                way for the number of individuals. value_fill is the value used for padding,
+                -1 by default. It is also possible to choose on which sides the padding is
+                applied; the default is pad_right and pad_bottom. If two opposite sides are
+                both set to true (e.g. pad_right and pad_left), the padding is split equally
+                between them, with any odd remainder going to the right and to the bottom.
+              default:
+                snp_dim: max
+                indiv_dim: max
+                value_fill: -1
+                pad_right: true
+                pad_left: false
+                pad_bottom: true
+                pad_top: false
+              "$ref": "py-obj:dnadna.schemas.plugins.collate"
 
           loader_num_workers:
               description: "number of subprocesses to use for data loading"
diff --git a/dnadna/training.py b/dnadna/training.py
index a58589dd96c597f2d0c55ab0de091210560145d0..723862a15269a16244a6404ac299353f2c464d38 100644
--- a/dnadna/training.py
+++ b/dnadna/training.py
@@ -24,6 +24,7 @@ import tqdm
 from torch.utils.data import DataLoader, sampler
 
 from . import __version__, DNADNAWarning
+from .collate import CollateTraining
 from .datasets import DNATrainingDataset
 from .nets import Network
 from .optim import Optimizer, LRScheduler
@@ -190,6 +191,12 @@ class ModelTrainer(ConfigMixIn):
         config = Config.from_file(filename, validate=False, **kwargs)
         return cls(config=config, validate=True, progress_bar=progress_bar)
 
+    @prepared_property
+    def collate(self):
+        """
+        The `Collate` instance used to collate samples into training batches.
+        """
+
     @prepared_property
     def net(self):
         """
@@ -380,6 +387,8 @@ class ModelTrainer(ConfigMixIn):
                                          learned_params=self._learned_params,
                                          validate=False)
 
+        self._collate = CollateTraining.from_config(self.config)
+
         self._dataset = dataset
 
         # Load the dataset
@@ -544,7 +553,7 @@ class ModelTrainer(ConfigMixIn):
                               num_workers=num_workers,
                               worker_init_fn=self._worker_init_fn,
                               pin_memory=self.device.type == 'cuda',
-                              collate_fn=DNATrainingDataset.collate_batch)
+                              collate_fn=self.collate)
 
         training_sampler = init_sampler(dataset.training_set)
         validation_sampler = init_sampler(dataset.validation_set)
@@ -768,7 +777,8 @@ class ModelTrainer(ConfigMixIn):
             'dnadna_version': __version__,
             'state_dict': model_state_dict,
             'state_dict_optimizer': self.optimizer.state_dict(),
-            'epoch': self.current_epoch
+            'epoch': self.current_epoch,
+            'collate': self.config['collate']
         })
         # In case of classification task, there are no train_mean nor train_std
         try:
@@ -978,12 +988,16 @@ class ModelTrainer(ConfigMixIn):
             param = self.learned_params.params[param_name]
             target_slice, output_slice = \
                 self.learned_params.param_slices[param_name]
+
             target = targets[:, target_slice]
+
             output = outputs[:, output_slice]
 
             if param['type'] == 'regression':
                 # Compute error metrics
+
                 squared_err = (output - target) ** 2
+
                 errors[:, target_slice] = squared_err
 
                 # For all regression params, omit samples where the target
diff --git a/dnadna/transforms.py b/dnadna/transforms.py
index e4f7f5e3d4d6556ba712eebea2e6dc581fa82cb5..8fdd01a08561ecdfae2647e25b8ef36db2268381 100644
--- a/dnadna/transforms.py
+++ b/dnadna/transforms.py
@@ -613,6 +613,7 @@ class ReformatPosition(Transform):
 
     def __call__(self, data):
         snp, lp, scen = data
+        pos = snp.pos
         input_pos_format = snp.full_pos_format
         new_pos_format = {}
         chromosome_size = input_pos_format.get('chromosome_size',
@@ -620,8 +621,6 @@ class ReformatPosition(Transform):
         initial_position = input_pos_format.get('initial_position',
                                                 self.initial_position)
 
-        pos = snp.pos
-
         # Normalize/de-normalize first; this makes it easier to specify
         # normalization when converting to/from distances
         if (self.normalized is not None and
diff --git a/dnadna/utils/plugins.py b/dnadna/utils/plugins.py
index 07eb624c664014d51471cca39fae6b12a2492865..88afc7ff4b7b5bd56ffd163447a87f08be390ac9 100644
--- a/dnadna/utils/plugins.py
+++ b/dnadna/utils/plugins.py
@@ -242,7 +242,6 @@ class Pluggable:
 
         >>> del Pluggable._registry['my_pluggable']  # cleanup
         """
-
         if Pluggable in cls.__bases__:
             registry = Pluggable._registry
             # Give the pluggable its own registry of plugins overriding
@@ -251,7 +250,6 @@ class Pluggable:
             cls.pluggable = cls
         else:
             registry = cls._registry
-
         plugin_name = decamelcase(cls.__name__)
 
         if '_plugin_name' not in cls.__dict__:
diff --git a/docs/training.rst b/docs/training.rst
index 19d4b0fb8d72645743ef2dc35451cb9457decc8a..48d7aa4e57a4a0e74abcc859e9911a3caef1c425 100644
--- a/docs/training.rst
+++ b/docs/training.rst
@@ -144,6 +144,7 @@ They are specified through each of the following:
     * ``batch_size``: number of examples in a batch
     * ``n_epochs``: number of epochs
     * ``evaluation_interval``: interval (number of batches processed) between two validation steps.
+    * ``collate``: parameters to configure the batching process, mostly padding.
 
 .. code-block:: yaml
 
@@ -178,6 +179,24 @@ They are specified through each of the following:
             betas: [0.9, 0.999]
             eps: 1.0e-08
             amsgrad: false
+
+    # options for the collate function applied to each batch. snp_dim is the number of
+    # SNPs per batch element: if set to max, each element of the batch is padded up to
+    # the largest SNP dimension in the batch; if set to an integer, all elements are
+    # padded or truncated to that dimension. indiv_dim works the same way for the
+    # number of individuals. value_fill is the value used for padding, -1 by default.
+    # It is also possible to choose on which sides the padding is applied; the default
+    # is pad_right and pad_bottom. If two opposite sides are both set to true (e.g.
+    # pad_right and pad_left), the padding is split equally between them, with any odd
+    # remainder going to the right and to the bottom (for example, padding a 3-column
+    # sample up to 8 columns with both pad_left and pad_right set adds 2 columns on
+    # the left and 3 on the right).
+    collate:
+        snp_dim: max
+        indiv_dim: max
+        value_fill: -1
+        pad_right: true
+        pad_left: false
+        pad_bottom: true
+        pad_top: false
 
     # name and parameters of the scheduler to use; all built-in schedulers from the
     # torch.optim.lr_scheduler package are available for use here, and you can also
diff --git a/tests/test_collate.py b/tests/test_collate.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b09bacb48d42ede51504c225f8437370bff7454
--- /dev/null
+++ b/tests/test_collate.py
@@ -0,0 +1,942 @@
+import unittest
+from dnadna.utils.config import Config
+from dnadna.collate import CollateTraining
+from dnadna.snp_sample import SNPSample
+import torch
+
+fake_snps = [torch.zeros(3 + i, 3 + i) for i in range(5)]
+fake_snps = [SNPSample(s[1:], s[0]) for s in fake_snps]
+fake_params = [torch.rand(4, dtype=torch.float64) for _ in range(5)]
+fake_batch = list(zip(range(5), [0] * 5, fake_snps, fake_params))
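+# NOTE: each fake SNPSample above wraps a (3 + i) x (3 + i) all-zero tensor
+# (i = 0..4), so with snp_dim and indiv_dim set to 'max' the collated inputs
+# are padded up to a 7 x 7 matrix, with -1 as the fill value in the configs
+# below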
+
+
+class TestCollate(unittest.TestCase):
+
+    def test_collate_batch(self):
+        """Test if the error is working properly"""
+        config_0 = Config({
+            'collate': {
+                'snp_dim': 'max',
+                'indiv_dim': 'max',
+                'value_fill': -1,
+                'pad_right': False,
+                'pad_left': False,
+                'pad_bottom': False,
+                'pad_top': True
+            }
+        })
+
+        collate = CollateTraining.from_config(config_0)
+        self.assertRaises(ValueError, collate, fake_batch)
+
+        """
+        this format should be
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+         0  0  0 -1 -1 -1 -1
+         0  0  0 -1 -1 -1 -1
+         0  0  0 -1 -1 -1 -1
+
+        and then increase one by one for each input,
+        to eventually reach
+
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        """
+
+        config_1 = Config({
+            'collate': {
+                'snp_dim': 'max',
+                'indiv_dim': 'max',
+                'value_fill': -1,
+                'pad_right': True,
+                'pad_left': False,
+                'pad_bottom': False,
+                'pad_top': True
+            }
+        })
+
+        tensor1 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [0, 0, 0, -1, -1, -1, -1],
+            [0, 0, 0, -1, -1, -1, -1],
+            [0, 0, 0, -1, -1, -1, -1]
+        ])
+
+        tensor2 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [0, 0, 0, 0, -1, -1, -1],
+            [0, 0, 0, 0, -1, -1, -1],
+            [0, 0, 0, 0, -1, -1, -1],
+            [0, 0, 0, 0, -1, -1, -1]
+        ])
+
+        tensor3 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [0, 0, 0, 0, 0, -1, -1],
+            [0, 0, 0, 0, 0, -1, -1],
+            [0, 0, 0, 0, 0, -1, -1],
+            [0, 0, 0, 0, 0, -1, -1],
+            [0, 0, 0, 0, 0, -1, -1]
+        ])
+
+        tensor4 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1]
+        ])
+
+        tensor5 = torch.tensor([
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0]
+        ])
+
+        collate = CollateTraining.from_config(config_1)
+        collated_batch = collate(fake_batch)
+
+        assert torch.equal(collated_batch[1][0], tensor1)
+        assert torch.equal(collated_batch[1][1], tensor2)
+        assert torch.equal(collated_batch[1][2], tensor3)
+        assert torch.equal(collated_batch[1][3], tensor4)
+        assert torch.equal(collated_batch[1][4], tensor5)
+
+        """
+        this format should be
+         X  X  X -1 -1 -1 -1
+         X  X  X -1 -1 -1 -1
+         X  X  X -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+
+        and then increase one by one for each input,
+        to eventually reach
+
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        """
+
+        tensor1 = torch.tensor([
+            [0, 0, 0, -1, -1, -1, -1],
+            [0, 0, 0, -1, -1, -1, -1],
+            [0, 0, 0, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor2 = torch.tensor([
+            [0, 0, 0, 0, -1, -1, -1],
+            [0, 0, 0, 0, -1, -1, -1],
+            [0, 0, 0, 0, -1, -1, -1],
+            [0, 0, 0, 0, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor3 = torch.tensor([
+            [0, 0, 0, 0, 0, -1, -1],
+            [0, 0, 0, 0, 0, -1, -1],
+            [0, 0, 0, 0, 0, -1, -1],
+            [0, 0, 0, 0, 0, -1, -1],
+            [0, 0, 0, 0, 0, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor4 = torch.tensor([
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor5 = torch.tensor([
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0]
+        ])
+
+        config_2 = Config({
+            'collate': {
+                'snp_dim': 'max',
+                'indiv_dim': 'max',
+                'value_fill': -1,
+                'pad_right': True,
+                'pad_left': False,
+                'pad_bottom': True,
+                'pad_top': False
+            }
+        })
+
+        collate = CollateTraining.from_config(config_2)
+        collated_batch = collate(fake_batch)
+
+        assert torch.equal(collated_batch[1][0], tensor1)
+        assert torch.equal(collated_batch[1][1], tensor2)
+        assert torch.equal(collated_batch[1][2], tensor3)
+        assert torch.equal(collated_batch[1][3], tensor4)
+        assert torch.equal(collated_batch[1][4], tensor5)
+
+        """
+        this format should be
+        -1 -1 -1 -1  X  X  X
+        -1 -1 -1 -1  X  X  X
+        -1 -1 -1 -1  X  X  X
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+
+        and then increase one by one for each input,
+        to eventually reach
+
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        """
+
+        tensor1 = torch.tensor([
+            [-1, -1, -1, -1, 0, 0, 0],
+            [-1, -1, -1, -1, 0, 0, 0],
+            [-1, -1, -1, -1, 0, 0, 0],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor2 = torch.tensor([
+            [-1, -1, -1, 0, 0, 0, 0],
+            [-1, -1, -1, 0, 0, 0, 0],
+            [-1, -1, -1, 0, 0, 0, 0],
+            [-1, -1, -1, 0, 0, 0, 0],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor3 = torch.tensor([
+            [-1, -1, 0, 0, 0, 0, 0],
+            [-1, -1, 0, 0, 0, 0, 0],
+            [-1, -1, 0, 0, 0, 0, 0],
+            [-1, -1, 0, 0, 0, 0, 0],
+            [-1, -1, 0, 0, 0, 0, 0],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor4 = torch.tensor([
+            [-1, 0, 0, 0, 0, 0, 0],
+            [-1, 0, 0, 0, 0, 0, 0],
+            [-1, 0, 0, 0, 0, 0, 0],
+            [-1, 0, 0, 0, 0, 0, 0],
+            [-1, 0, 0, 0, 0, 0, 0],
+            [-1, 0, 0, 0, 0, 0, 0],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor5 = torch.tensor([
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+        ])
+
+        config_3 = Config({
+            'collate': {
+                'snp_dim': 'max',
+                'indiv_dim': 'max',
+                'value_fill': -1,
+                'pad_right': False,
+                'pad_left': True,
+                'pad_bottom': True,
+                'pad_top': False
+            }
+        })
+
+        collate = CollateTraining.from_config(config_3)
+        collated_batch = collate(fake_batch)
+
+        assert torch.equal(collated_batch[1][0], tensor1)
+        assert torch.equal(collated_batch[1][1], tensor2)
+        assert torch.equal(collated_batch[1][2], tensor3)
+        assert torch.equal(collated_batch[1][3], tensor4)
+        assert torch.equal(collated_batch[1][4], tensor5)
+
+        """
+        this format should be
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1  X  X  X
+        -1 -1 -1 -1  X  X  X
+        -1 -1 -1 -1  X  X  X
+
+        and then increase one by one for each input,
+        to eventually reach
+
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        """
+
+        tensor1 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, 0, 0, 0],
+            [-1, -1, -1, -1, 0, 0, 0],
+            [-1, -1, -1, -1, 0, 0, 0]
+        ])
+
+        tensor2 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, 0, 0, 0, 0],
+            [-1, -1, -1, 0, 0, 0, 0],
+            [-1, -1, -1, 0, 0, 0, 0],
+            [-1, -1, -1, 0, 0, 0, 0]
+        ])
+
+        tensor3 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, 0, 0, 0, 0, 0],
+            [-1, -1, 0, 0, 0, 0, 0],
+            [-1, -1, 0, 0, 0, 0, 0],
+            [-1, -1, 0, 0, 0, 0, 0],
+            [-1, -1, 0, 0, 0, 0, 0]
+        ])
+
+        tensor4 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, 0, 0, 0, 0, 0, 0],
+            [-1, 0, 0, 0, 0, 0, 0],
+            [-1, 0, 0, 0, 0, 0, 0],
+            [-1, 0, 0, 0, 0, 0, 0],
+            [-1, 0, 0, 0, 0, 0, 0],
+            [-1, 0, 0, 0, 0, 0, 0]
+        ])
+
+        tensor5 = torch.tensor([
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0]
+        ])
+
+        config_4 = Config({
+            'collate': {
+                'snp_dim': 'max',
+                'indiv_dim': 'max',
+                'value_fill': -1,
+                'pad_right': False,
+                'pad_left': True,
+                'pad_bottom': False,
+                'pad_top': True
+            }
+        })
+
+        collate = CollateTraining.from_config(config_4)
+        collated_batch = collate(fake_batch)
+
+        assert torch.equal(collated_batch[1][0], tensor1)
+        assert torch.equal(collated_batch[1][1], tensor2)
+        assert torch.equal(collated_batch[1][2], tensor3)
+        assert torch.equal(collated_batch[1][3], tensor4)
+        assert torch.equal(collated_batch[1][4], tensor5)
+
+        """
+        Now we test the cases where two opposite sides are both set to true,
+        e.g. pad_right = True and pad_left = True.
+        In this case, the input is centered in the array.  If the padding amount
+        is odd, the extra column goes on the right and the extra row on the bottom.
+        """
+
+        """
+        this format should be
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1  X  X  X -1 -1
+        -1 -1  X  X  X -1 -1
+        -1 -1  X  X  X -1 -1
+
+        and then increase one by one for each input,
+        to eventually reach
+
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        """
+
+        tensor1 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, 0, 0, 0, -1, -1],
+            [-1, -1, 0, 0, 0, -1, -1],
+            [-1, -1, 0, 0, 0, -1, -1]
+        ])
+
+        tensor2 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, 0, 0, 0, 0, -1, -1],
+            [-1, 0, 0, 0, 0, -1, -1],
+            [-1, 0, 0, 0, 0, -1, -1],
+            [-1, 0, 0, 0, 0, -1, -1]
+        ])
+
+        tensor3 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, 0, 0, 0, 0, 0, -1],
+            [-1, 0, 0, 0, 0, 0, -1],
+            [-1, 0, 0, 0, 0, 0, -1],
+            [-1, 0, 0, 0, 0, 0, -1],
+            [-1, 0, 0, 0, 0, 0, -1]
+        ])
+
+        tensor4 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1]
+        ])
+
+        tensor5 = torch.tensor([
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0]
+        ])
+
+        config_5 = Config({
+            'collate': {
+                'snp_dim': 'max',
+                'indiv_dim': 'max',
+                'value_fill': -1,
+                'pad_right': True,
+                'pad_left': True,
+                'pad_bottom': False,
+                'pad_top': True
+            }
+        })
+
+        collate = CollateTraining.from_config(config_5)
+        collated_batch = collate(fake_batch)
+
+        assert torch.equal(collated_batch[1][0], tensor1)
+        assert torch.equal(collated_batch[1][1], tensor2)
+        assert torch.equal(collated_batch[1][2], tensor3)
+        assert torch.equal(collated_batch[1][3], tensor4)
+        assert torch.equal(collated_batch[1][4], tensor5)
+
+        """
+        this format should be
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+         X  X  X -1 -1 -1 -1
+         X  X  X -1 -1 -1 -1
+         X  X  X -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+
+        and then increase one by one for each input,
+        to eventually reach
+
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        """
+
+        tensor1 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [0, 0, 0, -1, -1, -1, -1],
+            [0, 0, 0, -1, -1, -1, -1],
+            [0, 0, 0, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor2 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [0, 0, 0, 0, -1, -1, -1],
+            [0, 0, 0, 0, -1, -1, -1],
+            [0, 0, 0, 0, -1, -1, -1],
+            [0, 0, 0, 0, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor3 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [0, 0, 0, 0, 0, -1, -1],
+            [0, 0, 0, 0, 0, -1, -1],
+            [0, 0, 0, 0, 0, -1, -1],
+            [0, 0, 0, 0, 0, -1, -1],
+            [0, 0, 0, 0, 0, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor4 = torch.tensor([
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor5 = torch.tensor([
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0]
+        ])
+
+        config_6 = Config({
+            'collate': {
+                'snp_dim': 'max',
+                'indiv_dim': 'max',
+                'value_fill': -1,
+                'pad_right': True,
+                'pad_left': False,
+                'pad_bottom': True,
+                'pad_top': True
+            }
+        })
+
+        collate = CollateTraining.from_config(config_6)
+        collated_batch = collate(fake_batch)
+
+        assert torch.equal(collated_batch[1][0], tensor1)
+        assert torch.equal(collated_batch[1][1], tensor2)
+        assert torch.equal(collated_batch[1][2], tensor3)
+        assert torch.equal(collated_batch[1][3], tensor4)
+        assert torch.equal(collated_batch[1][4], tensor5)
+
+        """
+        this format should be
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1  X  X  X -1 -1
+        -1 -1  X  X  X -1 -1
+        -1 -1  X  X  X -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+
+        and then increase one by one for each input,
+        to eventually reach
+
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        """
+
+        tensor1 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, 0, 0, 0, -1, -1],
+            [-1, -1, 0, 0, 0, -1, -1],
+            [-1, -1, 0, 0, 0, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor2 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, 0, 0, 0, 0, -1, -1],
+            [-1, 0, 0, 0, 0, -1, -1],
+            [-1, 0, 0, 0, 0, -1, -1],
+            [-1, 0, 0, 0, 0, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor3 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, 0, 0, 0, 0, 0, -1],
+            [-1, 0, 0, 0, 0, 0, -1],
+            [-1, 0, 0, 0, 0, 0, -1],
+            [-1, 0, 0, 0, 0, 0, -1],
+            [-1, 0, 0, 0, 0, 0, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor4 = torch.tensor([
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor5 = torch.tensor([
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0]
+        ])
+
+        config_7 = Config({
+            'collate': {
+                'snp_dim': 'max',
+                'indiv_dim': 'max',
+                'value_fill': -1,
+                'pad_right': True,
+                'pad_left': True,
+                'pad_bottom': True,
+                'pad_top': True
+            }
+        })
+
+        collate = CollateTraining.from_config(config_7)
+        collated_batch = collate(fake_batch)
+
+        assert torch.equal(collated_batch[1][0], tensor1)
+        assert torch.equal(collated_batch[1][1], tensor2)
+        assert torch.equal(collated_batch[1][2], tensor3)
+        assert torch.equal(collated_batch[1][3], tensor4)
+        assert torch.equal(collated_batch[1][4], tensor5)
+
+        """
+        Last batch of tests:
+        We will test the cases where snp_dim or indiv_dim is set to an
+        integer instead of "max".
+        """
+
+        """
+        this format should be
+        -1 -1 -1 -1 -1 -1 -1
+        -1 -1  X  X  X -1 -1
+        -1 -1  X  X  X -1 -1
+        -1 -1  X  X  X -1 -1
+        -1 -1 -1 -1 -1 -1 -1
+
+        and then increase one by one for each input,
+        to eventually reach
+
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        X  X  X  X  X  X  X
+        """
+
+        tensor1 = torch.tensor([
+            [-1, -1, -1, -1, -1, -1, -1],
+            [-1, -1, 0, 0, 0, -1, -1],
+            [-1, -1, 0, 0, 0, -1, -1],
+            [-1, -1, 0, 0, 0, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor2 = torch.tensor([
+            [-1, 0, 0, 0, 0, -1, -1],
+            [-1, 0, 0, 0, 0, -1, -1],
+            [-1, 0, 0, 0, 0, -1, -1],
+            [-1, 0, 0, 0, 0, -1, -1],
+            [-1, -1, -1, -1, -1, -1, -1]
+        ])
+
+        tensor3 = torch.tensor([
+            [-1, 0, 0, 0, 0, 0, -1],
+            [-1, 0, 0, 0, 0, 0, -1],
+            [-1, 0, 0, 0, 0, 0, -1],
+            [-1, 0, 0, 0, 0, 0, -1],
+            [-1, 0, 0, 0, 0, 0, -1]
+        ])
+
+        tensor4 = torch.tensor([
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1],
+            [0, 0, 0, 0, 0, 0, -1]
+        ])
+
+        tensor5 = torch.tensor([
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0, 0, 0]
+        ])
+
+        config_8 = Config({
+            'collate': {
+                'snp_dim': 5,
+                'indiv_dim': 'max',
+                'value_fill': -1,
+                'pad_right': True,
+                'pad_left': True,
+                'pad_bottom': True,
+                'pad_top': True
+            }
+        })
+
+        collate = CollateTraining.from_config(config_8)
+        collated_batch = collate(fake_batch)
+
+        assert torch.equal(collated_batch[1][0], tensor1)
+        assert torch.equal(collated_batch[1][1], tensor2)
+        assert torch.equal(collated_batch[1][2], tensor3)
+        assert torch.equal(collated_batch[1][3], tensor4)
+        assert torch.equal(collated_batch[1][4], tensor5)
+
+        """
+        The expected format is
+        -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1
+        -1  X  X  X -1
+        -1  X  X  X -1
+        -1  X  X  X -1
+        -1 -1 -1 -1 -1
+        -1 -1 -1 -1 -1
+
+        with the data block growing for each subsequent input,
+        eventually reaching
+
+        X  X  X  X  X
+        X  X  X  X  X
+        X  X  X  X  X
+        X  X  X  X  X
+        X  X  X  X  X
+        X  X  X  X  X
+        X  X  X  X  X
+        """
+
+        tensor1 = torch.tensor([
+            [-1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1],
+            [-1, 0, 0, 0, -1],
+            [-1, 0, 0, 0, -1],
+            [-1, 0, 0, 0, -1],
+            [-1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1]
+        ])
+
+        tensor2 = torch.tensor([
+            [-1, -1, -1, -1, -1],
+            [0, 0, 0, 0, -1],
+            [0, 0, 0, 0, -1],
+            [0, 0, 0, 0, -1],
+            [0, 0, 0, 0, -1],
+            [-1, -1, -1, -1, -1],
+            [-1, -1, -1, -1, -1]
+        ])
+
+        tensor3 = torch.tensor([
+            [-1, -1, -1, -1, -1],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [-1, -1, -1, -1, -1]
+        ])
+
+        tensor4 = torch.tensor([
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [-1, -1, -1, -1, -1]
+        ])
+
+        tensor5 = torch.tensor([
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+        ])
+
+        config_9 = Config({
+            'collate': {
+                'snp_dim': 'max',
+                'indiv_dim': 5,
+                'value_fill': -1,
+                'pad_right': True,
+                'pad_left': True,
+                'pad_bottom': True,
+                'pad_top': True
+            }
+        })
+
+        collate = CollateTraining.from_config(config_9)
+        collated_batch = collate(fake_batch)
+
+        assert torch.equal(collated_batch[1][0], tensor1)
+        assert torch.equal(collated_batch[1][1], tensor2)
+        assert torch.equal(collated_batch[1][2], tensor3)
+        assert torch.equal(collated_batch[1][3], tensor4)
+        assert torch.equal(collated_batch[1][4], tensor5)
+
+        """
+        The expected format is
+        -1 -1 -1 -1 -1
+        -1  X  X  X -1
+        -1  X  X  X -1
+        -1  X  X  X -1
+        -1 -1 -1 -1 -1
+
+        with the data block growing for each subsequent input,
+        eventually reaching
+
+        X  X  X  X  X
+        X  X  X  X  X
+        X  X  X  X  X
+        X  X  X  X  X
+        X  X  X  X  X
+        """
+
+        tensor1 = torch.tensor([
+            [-1, -1, -1, -1, -1],
+            [-1, 0, 0, 0, -1],
+            [-1, 0, 0, 0, -1],
+            [-1, 0, 0, 0, -1],
+            [-1, -1, -1, -1, -1]
+        ])
+
+        tensor2 = torch.tensor([
+            [0, 0, 0, 0, -1],
+            [0, 0, 0, 0, -1],
+            [0, 0, 0, 0, -1],
+            [0, 0, 0, 0, -1],
+            [-1, -1, -1, -1, -1]
+        ])
+
+        tensor3 = torch.tensor([
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0]
+        ])
+
+        tensor4 = torch.tensor([
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0]
+        ])
+
+        tensor5 = torch.tensor([
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0]
+        ])
+
+        config_10 = Config({
+            'collate': {
+                'snp_dim': 5,
+                'indiv_dim': 5,
+                'value_fill': -1,
+                'pad_right': True,
+                'pad_left': True,
+                'pad_bottom': True,
+                'pad_top': True
+            }
+        })
+
+        collate = CollateTraining.from_config(config_10)
+        collated_batch = collate(fake_batch)
+
+        assert torch.equal(collated_batch[1][0], tensor1)
+        assert torch.equal(collated_batch[1][1], tensor2)
+        assert torch.equal(collated_batch[1][2], tensor3)
+        assert torch.equal(collated_batch[1][3], tensor4)
+        assert torch.equal(collated_batch[1][4], tensor5)
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 7f82a84b3ab87462f8e37b51b65fbfaf200236cc..eca45f8ddf71b38396914e39b010c77853e1d121 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -18,7 +18,9 @@ from dnadna.examples import one_event
 from dnadna.params import LearnedParams
 from dnadna.snp_sample import SNPSample
 from dnadna.transforms import ValidateSnp
+from dnadna.utils.config import Config
 from dnadna.utils.misc import zero_pad_format
+from dnadna.collate import CollateTraining
 
 
 @pytest.fixture(scope='module')
@@ -345,11 +347,24 @@ def test_invalid_uniform_dataloader(num_workers):
     exception is raised.
     """
 
+    # Needed since collation is now handled by a CollateTraining instance
+    # built from a config (replacing DNATrainingDataset.collate_batch)
+    config = Config({
+        'collate': {
+            'snp_dim': 'max',
+            'indiv_dim': 'max',
+            'value_fill': -1,
+            'pad_right': True,
+            'pad_left': False,
+            'pad_bottom': True,
+            'pad_top': False
+        }})
+
+    collate = CollateTraining.from_config(config)
     _, dataset = nonuniform_dataset(nonuniformity='both', n_scenarios=10)
     loader = DataLoader(dataset=dataset,
                         batch_size=1,
                         num_workers=num_workers,
-                        collate_fn=DNATrainingDataset.collate_batch)
+                        collate_fn=collate)
 
     # Attempting to iterate over the loader should result in an
     # InvalidSNPSample error eventually
@@ -371,7 +386,7 @@ def test_invalid_uniform_dataloader(num_workers):
     loader = DataLoader(dataset=dataset,
                         batch_size=1,
                         num_workers=num_workers,
-                        collate_fn=DNATrainingDataset.collate_batch)
+                        collate_fn=collate)
 
     for r_idx, datum in enumerate(loader):
         s_idx, snp = datum[:2]
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index 2c9335a7ef028cf27b262858c32981763b1ccf3f..39ce9c04f496dee9c913560cbaf8632920e7e4ec 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -29,6 +29,15 @@ def test_predict_with_classification_params():
                 'type': 'classification',
                 'classes': ['yes', 'no']
             }
+        },
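+        # Collate settings matching those used in tests/test_datasets.py;
+        # presumably required now that collation is configured via the
+        # config's 'collate' section.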
+        'collate': {
+            'snp_dim': 'max',
+            'indiv_dim': 'max',
+            'value_fill': -1,
+            'pad_right': True,
+            'pad_left': False,
+            'pad_bottom': True,
+            'pad_top': False
         }
     })
 
@@ -60,6 +69,15 @@ def test_predict_with_log_transform(device):
         },
         'learned_params': {
             'position': {'type': 'regression', 'log_transform': True},
+        },
+        'collate': {
+            'snp_dim': 'max',
+            'indiv_dim': 'max',
+            'value_fill': -1,
+            'pad_right': True,
+            'pad_left': False,
+            'pad_bottom': True,
+            'pad_top': False
         }
     })