diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 07d14a8bac2f1bdb46a69cc9b09c493e6e15e8e4..b5b3d9165689d8cf2e25b91490f45a9e6580d070 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,5 +1,6 @@ # make sure to pin image tag for ease of reproducibility -image: continuumio/miniconda3:4.9.2 +image: continuumio/miniconda3 +#image: continuumio/miniconda3:4.9.2 # image: mambaorg/micromamba variables: @@ -69,7 +70,8 @@ workflow: - mkdir -p ${CONDA_ENVS} - conda config --prepend envs_dirs ${CONDA_ENVS} - conda update -n base conda && - - conda install 'mamba<=1.4.5' -n base -c conda-forge && + #- conda install 'mamba<=1.4.5' -n base -c conda-forge && + - conda install mamba -n base -c conda-forge && - conda update mamba -n base -c conda-forge .conda_setup: &conda_setup @@ -186,7 +188,7 @@ docs: # libenchant is needed for sphinxcontrib.spelling by way of PyEnchant, # and is not easily installed via conda - - apt-get update --allow-releaseinfo-change && apt-get install -yqq libenchant-dev + - apt-get update --allow-releaseinfo-change && apt-get install -yqq libenchant-2-dev script: - cd docs/ - make html diff --git a/dnadna/__init__.py b/dnadna/__init__.py index 91affc411c3410c2b4aafdf6819706443df967db..3bb76ace57fa7a6c9038cb8cb22b8d38fedc606a 100644 --- a/dnadna/__init__.py +++ b/dnadna/__init__.py @@ -24,7 +24,8 @@ BUILTIN_PLUGINS = [ 'dnadna.nets', 'dnadna.optim', 'dnadna.simulator', - 'dnadna.transforms' + 'dnadna.transforms', + 'dnadna.collate' ] """ List all internal modules that provide a `Pluggable` class. diff --git a/dnadna/collate.py b/dnadna/collate.py new file mode 100644 index 0000000000000000000000000000000000000000..f6fa2042395ac3d055446c9af8ed3e4d1190c96f --- /dev/null +++ b/dnadna/collate.py @@ -0,0 +1,334 @@ +import abc +from math import ceil, floor + +import numpy as np +import torch + +from .utils.config import Config +from .utils.plugins import Pluggable + + +class Collate(Pluggable, metaclass=abc.ABCMeta): + """ + Parent Pluggable class to inherent from when the user want to write their + own collate_batch function. + + The user must define a schema and can use the implementation of CollateTraining + as an example of how the class should look. Then the main function to implement is + _collate_batch, which will be called by default by the __call__ function. + + Here is an example using the class CollateTraining: + + from dnadna.collate import CollateTraining + from dnadna.utils.config import Config + collate = CollateTraining.from_config(config = Config({ + 'collate': { + 'snp_dim': 'max', + 'indiv_dim': 'max', + 'value_fill': -1, + 'pad_right': True, + 'pad_left': False, + 'pad_bottom': True, + 'pad_top': False + } + })) + collated_batch = collate(batch) + """ + + schema = {} + + def __init__(self, config): + self.padding_params = config + + def __repr__(self): + return str(self.padding_params) + + def __str__(self): + return str(self.padding_params) + + def __call__(self, batch): + return self._collate_batch(batch) + + @abc.abstractmethod + def _collate_batch(self, batch): + """ + Specifies how multiple scenario samples are collated into batches. + + Each batch element is a single element as returned by + ``DNATrainingDataset.__getitem__``: ``(scenario_idx, replicate_idx, + snp_sample, target)``. + + It is the main function to implement when heriting from Collate. + See CollateTraining for an example of implementation + """ + + @classmethod + def from_config(cls, config, validate=True): + """ + Initialize a `Collate` from the batch configuration from a + config file. + + Examples + -------- + + >>> from dnadna.collate import CollateTraining + >>> from dnadna.utils.config import Config + >>> collate = CollateTraining.from_config(config = Config({ + ... 'collate': { + ... 'snp_dim': 'max', + ... 'indiv_dim': 'max', + ... 'value_fill': -1, + ... 'pad_right': True, + ... 'pad_left': False, + ... 'pad_bottom': True, + ... 'pad_top': False + ... } + ... })) + >>> collate + Config({'snp_dim': 'max', 'indiv_dim': 'max', 'value_fill': -1, 'pad_right': True, + 'pad_left': False, 'pad_bottom': True, 'pad_top': False}) + """ + + if not isinstance(config, Config): + config = Config(config) + + if validate: + config.validate(cls.get_schema()) + config = config['collate'] + + return cls(config) + + @classmethod + def get_schema(cls): + """ + Returns the schema of the Collate implemented + """ + + return cls.schema + + +class CollateTraining(Collate): + + schema = { + 'properties': { + 'value_fill': {'type': 'number'}, + 'snp_dim': {'type': ['string', 'number']}, + 'indiv_dim': {'type': ['string', 'number']}, + 'pad_right': {'type': 'boolean'}, + 'pad_left': {'type': 'boolean'}, + 'pad_bottom': {'type': 'boolean'}, + 'pad_top': {'type': 'boolean'} + } + } + + def _collate_batch(self, batch): + """ + Specifies how multiple scenario samples are collated into batches. + + Each batch element is a single element as returned by + ``DNATrainingDataset.__getitem__``: ``(scenario_idx, replicate_idx, + snp_sample, target)``. + + For input samples and targets are collated into batches "vertically", + so that the size of the first dimension represents the number of items + in a batch. + + Examples + -------- + + >>> import torch + >>> from dnadna.collate import CollateTraining + >>> from dnadna.utils.config import Config + >>> from dnadna.snp_sample import SNPSample + >>> fake_snps = [torch.rand(3, 3 + i) for i in range(5)] + >>> fake_snps = [SNPSample(s[1:], s[0]) for s in fake_snps] + >>> fake_params = [torch.rand(4, dtype=torch.float64) for _ in range(5)] + >>> fake_batch = list(zip(range(5), [0] * 5, fake_snps, fake_params)) + >>> config = Config({ + ... 'collate': { + ... 'snp_dim': 'max', + ... 'indiv_dim': 'max', + ... 'value_fill': -1, + ... 'pad_right': True, + ... 'pad_left': False, + ... 'pad_bottom': True, + ... 'pad_top': False + ... } + ... }) + >>> collate_func = CollateTraining.from_config(config=config) + >>> collated_batch = collate_func(fake_batch) + >>> scenario_idxs, inputs, targets = collated_batch + >>> bool((torch.arange(5) == scenario_idxs).all()) + True + >>> inputs.shape # last dim should be num SNPs in largest fake SNP + torch.Size([5, 3, 7]) + >>> bool((inputs[0,:3,:3] == fake_snps[0].tensor).all()) + True + >>> bool((inputs[0,3:,3:] == -1).all()) + True + >>> bool((inputs[-1] == fake_snps[-1].tensor).all()) + True + >>> targets.shape + torch.Size([5, 4]) + >>> [bool((fake_params[bat].float() == targets[bat]).all()) + ... for bat in range(targets.shape[0])] + [True, True, True, True, True] + """ + + # filter any missing samples out of the batch + batch = list(filter(lambda it: it[2] is not None, batch)) + + # we have to consider the possibility of an empty batch (which could + # happen if batch_size=1 and the sample was missing) + if not batch: + return None + + scen_idxs, _, samples, targets = zip(*batch) + + # add padding so that all input SNPs are the same size and shape + # (though should all have the same number of rows, but may have + # different number of SNPs (columns) + # the value -1 is used for padded regions + # TODO: Question: Must nets explicitly account for this padding, and if + # so, how? ReLU? + # NOTE: This extra step of filling unevenly-sized matrices with -1 + # is unnecessary if we are using a dataset with uniform=True, so + # there should be an option to skip this step entirely. For now we + # just check if all inputs are the same size + inputs = [s.tensor for s in samples] + + pad_top = self.padding_params['pad_top'] + pad_bottom = self.padding_params['pad_bottom'] + pad_left = self.padding_params['pad_left'] + pad_right = self.padding_params['pad_right'] + + if not (pad_top or pad_bottom) or not (pad_left or pad_right): + raise ValueError("Invalid padding configuration. At least one horizontal" + "and one vertical value must be set to true.") + + snp_dim = self.padding_params['snp_dim'] + indiv_dim = self.padding_params['indiv_dim'] + filler = self.padding_params["value_fill"] + max_indiv_dim = max(input.shape[1] for input in inputs) + max_snp_dim = max(input.shape[0] for input in inputs) + + # Determine target dimensions based on input parameters + target_indiv_dim = indiv_dim if isinstance(indiv_dim, int) else max_indiv_dim + target_snp_dim = snp_dim if isinstance(snp_dim, int) else max_snp_dim + + # Initialize a tensor to store the padded inputs + new_inputs = torch.zeros((len(inputs), target_snp_dim, target_indiv_dim)) + + for i, input in enumerate(inputs): + + # This case is when "dim_snp" and "dim_indiv" are set to max + if target_indiv_dim == max_indiv_dim and target_snp_dim == max_snp_dim: + # Calculate padding dimensions based on boolean parameters + if not (pad_bottom and pad_top) and not (pad_left and pad_right): + pad_dims = ( + (0 if not pad_top else max_snp_dim - input.shape[0], + 0 if not pad_bottom else max_snp_dim - input.shape[0]), + + (0 if not pad_left else max_indiv_dim - input.shape[1], + 0 if not pad_right else max_indiv_dim - input.shape[1]) + ) + elif (pad_bottom and pad_top) and not (pad_left and pad_right): + pad_dims = ( + # top then bottom + # If dimension doesn't divide by 2, the extra padding is done on the bottom + (floor((max_snp_dim - input.shape[0]) / 2), + ceil((max_snp_dim - input.shape[0]) / 2)), + + (0 if not pad_left else max_indiv_dim - input.shape[1], + 0 if not pad_right else max_indiv_dim - input.shape[1]) + ) + elif not (pad_bottom and pad_top) and (pad_left and pad_right): + # left then right + # If dimension doesn't divide by 2, the extra padding is done on the right + pad_dims = ( + (0 if not pad_top else max_snp_dim - input.shape[0], + 0 if not pad_bottom else max_snp_dim - input.shape[0]), + + (floor((max_indiv_dim - input.shape[0]) / 2), + ceil((max_indiv_dim - input.shape[0]) / 2)), + ) + else: + pad_dims = ( + (floor((max_snp_dim - input.shape[0]) / 2), + ceil((max_snp_dim - input.shape[0]) / 2)), + + (floor((max_indiv_dim - input.shape[0]) / 2), + ceil((max_indiv_dim - input.shape[0]) / 2)), + ) + + # This case is when "dim_snp" and "dim_indiv" are set to integer values + else: + if not (pad_bottom and pad_top) and not (pad_left and pad_right): + pad_dims = ( + (max(0, (0 if not pad_top else target_snp_dim - input.shape[0])), + max(0, 0 if not pad_bottom else target_snp_dim - input.shape[0])), + + (max(0, (0 if not pad_left else target_indiv_dim - input.shape[1])), + max(0, 0 if not pad_right else target_indiv_dim - input.shape[1])) + ) + elif (pad_bottom and pad_top) and not (pad_left and pad_right): + pad_dims = ( + # top then bottom + # If dimension doesn't divide by 2, the extra padding is done on the bottom + (max(0, (floor((target_snp_dim - input.shape[0]) / 2))), + max(0, ceil((target_snp_dim - input.shape[0]) / 2))), + + (max(0, (0 if not pad_left else target_indiv_dim - input.shape[1])), + max(0, 0 if not pad_right else target_indiv_dim - input.shape[1])) + ) + elif not (pad_bottom and pad_top) and (pad_left and pad_right): + # left then right + # If dimension doesn't divide by 2, the extra padding is done on the right + pad_dims = ( + (max(0, (0 if not pad_top else target_snp_dim - input.shape[0])), + max(0, 0 if not pad_bottom else target_snp_dim - input.shape[0])), + + (max(0, (floor((target_indiv_dim - input.shape[0]) / 2))), + max(0, ceil((target_indiv_dim - input.shape[0]) / 2))), + ) + else: + pad_dims = ( + (max(0, (floor((target_snp_dim - input.shape[0]) / 2))), + max(0, ceil((target_snp_dim - input.shape[0]) / 2))), + + (max(0, (floor((target_indiv_dim - input.shape[0]) / 2))), + max(0, ceil((target_indiv_dim - input.shape[0]) / 2))), + ) + + # Adjust dimensions if needed + padded_input = np.pad(input, pad_dims, mode='constant', constant_values=filler) + adjusted_input = padded_input[:target_snp_dim, :target_indiv_dim] + + new_inputs[i, :, :] = torch.as_tensor(adjusted_input).clone().detach() + + # Concatenate targets and ensure they are converted to single-precision + # floats for passing to GPU devices + # NOTE: targets can be all None if the Dataset was initialized without + # scenario_params + if targets[0] is not None: + n_parameters = len(targets[0]) + targets = torch.cat(targets).reshape(-1, n_parameters).float() + + return [torch.tensor(scen_idxs, dtype=torch.long), new_inputs, targets] + + +class CollateInference(CollateTraining): + + def _collate_batch(self, batch): + """ + Like `CollateTraining.collate_batch` but also returns the + replicate indices and sample paths (e.g. filenames) which are used + in the prediction output. + """ + + rep_idxs = [b[1] for b in batch] + paths = [b[2].path for b in batch] + out = super()._collate_batch(batch) + return (out[:1] + + [torch.tensor(rep_idxs, dtype=torch.long), paths] + + out[1:]) diff --git a/dnadna/datasets.py b/dnadna/datasets.py index 7f8985fc3b164b1c68a28504c909b419e319750c..1c30b504ff7852fd5775ddb32e6779e502fc3bab 100644 --- a/dnadna/datasets.py +++ b/dnadna/datasets.py @@ -1283,94 +1283,6 @@ how to know if ubunt return xfs - @staticmethod - def collate_batch(batch): - """ - Specifies how multiple scenario samples are collated into batches. - - Each batch element is a single element as returned by - ``DNATrainingDataset.__getitem__``: ``(scenario_idx, replicate_idx, - snp_sample, target)``. - - For input samples and targets are collated into batches "vertically", - so that the size of the first dimension represents the number of items - in a batch. - - Examples - -------- - - >>> import torch - >>> from dnadna.datasets import DNATrainingDataset - >>> from dnadna.snp_sample import SNPSample - >>> fake_snps = [torch.rand(3, 3 + i) for i in range(5)] - >>> fake_snps = [SNPSample(s[1:], s[0]) for s in fake_snps] - >>> fake_params = [torch.rand(4, dtype=torch.float64) for _ in range(5)] - >>> fake_batch = list(zip(range(5), [0] * 5, fake_snps, fake_params)) - >>> collated_batch = DNATrainingDataset.collate_batch(fake_batch) - >>> scenario_idxs, inputs, targets = collated_batch - >>> bool((torch.arange(5) == scenario_idxs).all()) - True - >>> inputs.shape # last dim should be num SNPs in largest fake SNP - torch.Size([5, 3, 7]) - >>> bool((inputs[0,:3,:3] == fake_snps[0].tensor).all()) - True - >>> bool((inputs[0,3:,3:] == -1).all()) - True - >>> bool((inputs[-1] == fake_snps[-1].tensor).all()) - True - >>> targets.shape - torch.Size([5, 4]) - >>> [bool((fake_params[bat].float() == targets[bat]).all()) - ... for bat in range(targets.shape[0])] - [True, True, True, True, True] - """ - - # filter any missing samples out of the batch - batch = list(filter(lambda it: it[2] is not None, batch)) - - # we have to consider the possibility of an empty batch (which could - # happen if batch_size=1 and the sample was missing) - if not batch: - return None - - scen_idxs, _, samples, targets = zip(*batch) - - # add padding so that all input SNPs are the same size and shape - # (though should all have the same number of rows, but may have - # different number of SNPs (columns) - # the value -1 is used for padded regions - # TODO: Question: Must nets explicitly account for this padding, and if - # so, how? ReLU? - # NOTE: This extra step of filling unevenly-sized matrices with -1 - # is unnecessary if we are using a dataset with uniform=True, so - # there should be an option to skip this step entirely. For now we - # just check if all inputs are the same size - inputs = [s.tensor for s in samples] - example_shape = inputs[0].shape - if any(inp.shape != example_shape for inp in inputs): - max_ind_batch = np.max([inp.shape[0] for inp in inputs]) - max_snp_batch = np.max([inp.shape[1] for inp in inputs]) - new_inputs_shape = (len(batch), max_ind_batch, max_snp_batch) - new_inputs = torch.full(new_inputs_shape, -1, dtype=torch.float) - for idx, inp in enumerate(inputs): - # NOTE: I feel like there should be a more efficient way to fill an - # N-D tensor for a sequence of (N - 1)-D tensors, but at the moment - # I can't find it; to revisit - new_inputs[idx, :inp.shape[0], :inp.shape[1]] = inp - else: - # Just stack the inputs - new_inputs = torch.stack(inputs).float() - - # Concatenate targets and ensure they are converted to single-precision - # floats for passing to GPU devices - # NOTE: targets can be all None if the Dataset was initialized without - # scenario_params - if targets[0] is not None: - n_parameters = len(targets[0]) - targets = torch.cat(targets).reshape(-1, n_parameters).float() - - return [torch.tensor(scen_idxs, dtype=torch.long), new_inputs, targets] - def _iter_samples(self): """ Iterates over each replicate for each scenario in diff --git a/dnadna/defaults/training.yml b/dnadna/defaults/training.yml index 30ba6db72a897bd87d83498957c8f336f475fc11..ba9268486dee8cc2820996d0ff615ce8a1520333 100644 --- a/dnadna/defaults/training.yml +++ b/dnadna/defaults/training.yml @@ -14,6 +14,14 @@ dataset_transforms: - snp_format: concat - validate_snp: uniform_shape: false +collate: + snp_dim: max + indiv_dim: max + value_fill: -1 + pad_right: True + pad_left: False + pad_bottom: True + pad_top: False n_epochs: 1 evaluation_interval: 1 batch_size: 8 diff --git a/dnadna/examples/one_event.py b/dnadna/examples/one_event.py index 6bc3d8381ad8a4c20db8f16133812d1ba62c57a3..e351e8e3a047feae5e9f0ddd30f462ea651795ae 100644 --- a/dnadna/examples/one_event.py +++ b/dnadna/examples/one_event.py @@ -132,6 +132,14 @@ DEFAULT_ONE_EVENT_TRAINING_CONFIG = Config({ 'evaluation_interval': 10, 'loader_num_workers': 4, 'device': '', + 'collate': + {'snp_dim': 'max', + 'indiv_dim': 'max', + 'value_fill': -1, + 'pad_right': True, + 'pad_left': False, + 'pad_bottom': True, + 'pad_top': False} }) diff --git a/dnadna/nets.py b/dnadna/nets.py index 82506f244fb353e61f744b9bcb79e3913a991c6b..c1a2a8e7fbdf394868a4df01991695649c529da7 100644 --- a/dnadna/nets.py +++ b/dnadna/nets.py @@ -163,8 +163,8 @@ class SPIDNA(Network): Publication ----------- T. Sanchez, J. Cury, G. Charpiat, et F. Jay, - « Deep learning for population size history inference: Design, comparison and - combination with approximate Bayesian computation », + « Deep learning for population size history inference: Design, comparison and + combination with approximate Bayesian computation », Mol Ecol Resour, p. 1755‑0998.13224, juill. 2020, doi: 10.1111/1755-0998.13224. Parameters @@ -236,8 +236,8 @@ class CustomCNN(Network): Publication ----------- T. Sanchez, J. Cury, G. Charpiat, et F. Jay, - « Deep learning for population size history inference: Design, comparison and - combination with approximate Bayesian computation », + « Deep learning for population size history inference: Design, comparison and + combination with approximate Bayesian computation », Mol Ecol Resour, p. 1755‑0998.13224, juill. 2020, doi: 10.1111/1755-0998.13224. Parameters diff --git a/dnadna/prediction.py b/dnadna/prediction.py index 9f8945c9db686e16eb5a90f47912a64a3ee21303..2b022178ddefea56214b242780568ca12b763d45 100644 --- a/dnadna/prediction.py +++ b/dnadna/prediction.py @@ -22,6 +22,7 @@ from .nets import Network from .params import ParamSet from .utils.config import Config, ConfigMixIn from .utils.misc import stdio_redirect_tqdm +from .collate import CollateInference log = logging.getLogger(__name__) @@ -52,26 +53,6 @@ class DNAPredictionDataset(DatasetTransformationMixIn): self.preprocess = False self.min_snp = self.min_indiv = None - @classmethod - def collate_batch(cls, batch): - """ - Like `DatasetTransformationMixIn.collate_batch` but also returns the - replicate indices and sample paths (e.g. filenames) which are used - in the prediction output. - - .. note:: - - This is implemented as a `classmethod` which is necessary to be - able to call the super-class's ``collate_batch``. - """ - - rep_idxs = [b[1] for b in batch] - paths = [b[2].path for b in batch] - out = super().collate_batch(batch) - return (out[:1] + - [torch.tensor(rep_idxs, dtype=torch.long), paths] + - out[1:]) - def _iter_samples(self): idx = 0 for _, sample in super()._iter_samples(): @@ -99,6 +80,7 @@ class Predictor(ConfigMixIn): self.net = net # Place the network in evaluation mode self.net.eval() + self.collate_func = CollateInference.from_config(config) self.params = ParamSet(self.learned_params) # Update predict_transforms attribute with the merge of # dataset_transforms and predict_transforms (which has priority): @@ -194,9 +176,17 @@ class Predictor(ConfigMixIn): ... 'type': 'classification', ... 'classes': ['yes', 'no'] ... } + ... }, + ... 'collate': { + ... 'snp_dim': 'max', + ... 'indiv_dim': 'max', + ... 'value_fill': -1, + ... 'pad_right': True, + ... 'pad_left': False, + ... 'pad_bottom': True, + ... 'pad_top': False ... } ... }) - ... Initialize a fake instance of the net; this along with the config are the bare minimum needed to instantiate a `Predictor`: @@ -474,7 +464,7 @@ class Predictor(ConfigMixIn): loader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=loader_num_workers, pin_memory=device.type == 'cuda', - collate_fn=DNAPredictionDataset.collate_batch, + collate_fn=self.collate_func, sampler=sampler) if device.type == 'cuda': self.net = nn.DataParallel(self.net, device_ids=os.environ['CUDA_VISIBLE_DEVICES']) diff --git a/dnadna/schemas/training.yml b/dnadna/schemas/training.yml index 8ff49d1fb60f7f5986ff05c07c8d9a21556e500b..bf9d29d57efc8e1292be8b9d0292ec10c5be29e2 100644 --- a/dnadna/schemas/training.yml +++ b/dnadna/schemas/training.yml @@ -61,11 +61,26 @@ allOf: minimum: 1 default: 1 - batch_size: - description: "sample batch size to train on" - type: "integer" - minimum: 1 - default: 1 + collate: + description: >- + options for the collate function applied to the batch. Different parameters + are possible. snp_dim is the number of snp per batch. If set to max, each + element of the batch is padded up to the maximum dimension of snp of the batch. + If set to an int, all snp are padded/cutted to this dimension. indiv_dim + works the same way for the number of individuals. value_fill is the number to + fill with, set to -1 by default. It is also possible to chose from which sad the + padding is done. The default is pad_right and pad_bottom. If two dual variables are + set to true (i.e. pad_right and pad_left), the padding will be done on both sides equally. + If the size does not match, it will pad in priority on the right side and on the bottom side. + default: + snp_dim: max + indiv_dim: max + value_fill: -1 + pad_right: True + pad_left: False + pad_bottom: True + pad_top: False + "$ref": "py-obj:dnadna.schemas.plugins.collate" loader_num_workers: description: "number of subprocesses to use for data loading" diff --git a/dnadna/training.py b/dnadna/training.py index a58589dd96c597f2d0c55ab0de091210560145d0..723862a15269a16244a6404ac299353f2c464d38 100644 --- a/dnadna/training.py +++ b/dnadna/training.py @@ -24,6 +24,7 @@ import tqdm from torch.utils.data import DataLoader, sampler from . import __version__, DNADNAWarning +from .collate import CollateTraining from .datasets import DNATrainingDataset from .nets import Network from .optim import Optimizer, LRScheduler @@ -190,6 +191,12 @@ class ModelTrainer(ConfigMixIn): config = Config.from_file(filename, validate=False, **kwargs) return cls(config=config, validate=True, progress_bar=progress_bar) + @prepared_property + def collate(self): + """ + Returns the arguments that the `Collate` of the model trainer. + """ + @prepared_property def net(self): """ @@ -380,6 +387,8 @@ class ModelTrainer(ConfigMixIn): learned_params=self._learned_params, validate=False) + self._collate = CollateTraining.from_config(self.config) + self._dataset = dataset # Load the dataset @@ -544,7 +553,7 @@ class ModelTrainer(ConfigMixIn): num_workers=num_workers, worker_init_fn=self._worker_init_fn, pin_memory=self.device.type == 'cuda', - collate_fn=DNATrainingDataset.collate_batch) + collate_fn=self.collate) training_sampler = init_sampler(dataset.training_set) validation_sampler = init_sampler(dataset.validation_set) @@ -768,7 +777,8 @@ class ModelTrainer(ConfigMixIn): 'dnadna_version': __version__, 'state_dict': model_state_dict, 'state_dict_optimizer': self.optimizer.state_dict(), - 'epoch': self.current_epoch + 'epoch': self.current_epoch, + 'collate': self.config['collate'] }) # In case of classification task, there are no train_mean nor train_std try: @@ -978,12 +988,16 @@ class ModelTrainer(ConfigMixIn): param = self.learned_params.params[param_name] target_slice, output_slice = \ self.learned_params.param_slices[param_name] + target = targets[:, target_slice] + output = outputs[:, output_slice] if param['type'] == 'regression': # Compute error metrics + squared_err = (output - target) ** 2 + errors[:, target_slice] = squared_err # For all regression params, omit samples where the target diff --git a/dnadna/transforms.py b/dnadna/transforms.py index e4f7f5e3d4d6556ba712eebea2e6dc581fa82cb5..8fdd01a08561ecdfae2647e25b8ef36db2268381 100644 --- a/dnadna/transforms.py +++ b/dnadna/transforms.py @@ -613,6 +613,7 @@ class ReformatPosition(Transform): def __call__(self, data): snp, lp, scen = data + pos = snp.pos input_pos_format = snp.full_pos_format new_pos_format = {} chromosome_size = input_pos_format.get('chromosome_size', @@ -620,8 +621,6 @@ class ReformatPosition(Transform): initial_position = input_pos_format.get('initial_position', self.initial_position) - pos = snp.pos - # Normalize/de-normalize first; this makes it easier to specify # normalization when converting to/from distances if (self.normalized is not None and diff --git a/dnadna/utils/plugins.py b/dnadna/utils/plugins.py index 07eb624c664014d51471cca39fae6b12a2492865..88afc7ff4b7b5bd56ffd163447a87f08be390ac9 100644 --- a/dnadna/utils/plugins.py +++ b/dnadna/utils/plugins.py @@ -242,7 +242,6 @@ class Pluggable: >>> del Pluggable._registry['my_pluggable'] # cleanup """ - if Pluggable in cls.__bases__: registry = Pluggable._registry # Give the pluggable its own registry of plugins overriding @@ -251,7 +250,6 @@ class Pluggable: cls.pluggable = cls else: registry = cls._registry - plugin_name = decamelcase(cls.__name__) if '_plugin_name' not in cls.__dict__: diff --git a/docs/training.rst b/docs/training.rst index 19d4b0fb8d72645743ef2dc35451cb9457decc8a..48d7aa4e57a4a0e74abcc859e9911a3caef1c425 100644 --- a/docs/training.rst +++ b/docs/training.rst @@ -144,6 +144,7 @@ They are specified through each of the following: * ``batch_size``: number of examples in a batch * ``n_epochs``: number of epochs * ``evaluation_interval``: interval (number of batches processed) between two validation steps. + * ``collate``: parameters to configure the batching process, mostly padding. .. code-block:: yaml @@ -178,6 +179,24 @@ They are specified through each of the following: betas: [0.9, 0.999] eps: 1.0e-08 amsgrad: false + + # options for the collate function applied to the batch. Different parameters + # are possible. snp_dim is the number of snp per batch. If set to max, each + # element of the batch is padded up to the maximum dimension of snp of the batch. + # If set to an int, all snp are padded/cutted to this dimension. indiv_dim + # works the same way for the number of individuals. value_fill is the number to + # fill with, set to -1 by default. It is also possible to chose from which sad the + # padding is done. The default is pad_right and pad_bottom. If two dual variables are + # set to true (i.e. pad_right and pad_left), the padding will be done on both sides equally. + # If the size does not match, it will pad in priority on the right side and on the bottom side. + collate: + snp_dim: max + indiv_dim: max + value_fill: -1 + pad_right: True + pad_left: False + pad_bottom: True + pad_top: False # name and parameters of the scheduler to use; all built-in schedulers from the # torch.optim.lr_scheduler package are available for use here, and you can also diff --git a/tests/test_collate.py b/tests/test_collate.py new file mode 100644 index 0000000000000000000000000000000000000000..3b09bacb48d42ede51504c225f8437370bff7454 --- /dev/null +++ b/tests/test_collate.py @@ -0,0 +1,942 @@ +import unittest +from dnadna.utils.config import Config +from dnadna.collate import CollateTraining +from dnadna.snp_sample import SNPSample +import torch + +fake_snps = [torch.zeros(3 + i, 3 + i) for i in range(5)] +fake_snps = [SNPSample(s[1:], s[0]) for s in fake_snps] +fake_params = [torch.rand(4, dtype=torch.float64) for _ in range(5)] +fake_batch = list(zip(range(5), [0] * 5, fake_snps, fake_params)) + + +class TestCollate(unittest.TestCase): + + def test_collate_batch(self): + """Test if the error is working properly""" + config_0 = Config({ + 'collate': { + 'snp_dim': 'max', + 'indiv_dim': 'max', + 'value_fill': -1, + 'pad_right': False, + 'pad_left': False, + 'pad_bottom': False, + 'pad_top': True + } + }) + + collate = CollateTraining.from_config(config_0) + self.assertRaises(ValueError, collate, fake_batch) + + """ + this format should be + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + 0 0 0 -1 -1 -1 -1 + 0 0 0 -1 -1 -1 -1 + 0 0 0 -1 -1 -1 -1 + + and then increase one by one for each input, + to enventually reach + + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + """ + + config_1 = Config({ + 'collate': { + 'snp_dim': 'max', + 'indiv_dim': 'max', + 'value_fill': -1, + 'pad_right': True, + 'pad_left': False, + 'pad_bottom': False, + 'pad_top': True + } + }) + + tensor1 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [0, 0, 0, -1, -1, -1, -1], + [0, 0, 0, -1, -1, -1, -1], + [0, 0, 0, -1, -1, -1, -1] + ]) + + tensor2 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [0, 0, 0, 0, -1, -1, -1], + [0, 0, 0, 0, -1, -1, -1], + [0, 0, 0, 0, -1, -1, -1], + [0, 0, 0, 0, -1, -1, -1] + ]) + + tensor3 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [0, 0, 0, 0, 0, -1, -1], + [0, 0, 0, 0, 0, -1, -1], + [0, 0, 0, 0, 0, -1, -1], + [0, 0, 0, 0, 0, -1, -1], + [0, 0, 0, 0, 0, -1, -1] + ]) + + tensor4 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1] + ]) + + tensor5 = torch.tensor([ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0] + ]) + + collate = CollateTraining.from_config(config_1) + collated_batch = collate(fake_batch) + + assert torch.equal(collated_batch[1][0], tensor1) + assert torch.equal(collated_batch[1][1], tensor2) + assert torch.equal(collated_batch[1][2], tensor3) + assert torch.equal(collated_batch[1][3], tensor4) + assert torch.equal(collated_batch[1][4], tensor5) + + """ + this format should be + X X X -1 -1 -1 -1 + X X X -1 -1 -1 -1 + X X X -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + + and then increase one by one for each input, + to enventually reach + + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + """ + + tensor1 = torch.tensor([ + [0, 0, 0, -1, -1, -1, -1], + [0, 0, 0, -1, -1, -1, -1], + [0, 0, 0, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor2 = torch.tensor([ + [0, 0, 0, 0, -1, -1, -1], + [0, 0, 0, 0, -1, -1, -1], + [0, 0, 0, 0, -1, -1, -1], + [0, 0, 0, 0, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor3 = torch.tensor([ + [0, 0, 0, 0, 0, -1, -1], + [0, 0, 0, 0, 0, -1, -1], + [0, 0, 0, 0, 0, -1, -1], + [0, 0, 0, 0, 0, -1, -1], + [0, 0, 0, 0, 0, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor4 = torch.tensor([ + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor5 = torch.tensor([ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0] + ]) + + config_2 = Config({ + 'collate': { + 'snp_dim': 'max', + 'indiv_dim': 'max', + 'value_fill': -1, + 'pad_right': True, + 'pad_left': False, + 'pad_bottom': True, + 'pad_top': False + } + }) + + collate = CollateTraining.from_config(config_2) + collated_batch = collate(fake_batch) + + assert torch.equal(collated_batch[1][0], tensor1) + assert torch.equal(collated_batch[1][1], tensor2) + assert torch.equal(collated_batch[1][2], tensor3) + assert torch.equal(collated_batch[1][3], tensor4) + assert torch.equal(collated_batch[1][4], tensor5) + + """ + this format should be + -1 -1 -1 -1 X X X + -1 -1 -1 -1 X X X + -1 -1 -1 -1 X X X + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + + and then increase one by one for each input, + to enventually reach + + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + """ + + tensor1 = torch.tensor([ + [-1, -1, -1, -1, 0, 0, 0], + [-1, -1, -1, -1, 0, 0, 0], + [-1, -1, -1, -1, 0, 0, 0], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor2 = torch.tensor([ + [-1, -1, -1, 0, 0, 0, 0], + [-1, -1, -1, 0, 0, 0, 0], + [-1, -1, -1, 0, 0, 0, 0], + [-1, -1, -1, 0, 0, 0, 0], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor3 = torch.tensor([ + [-1, -1, 0, 0, 0, 0, 0], + [-1, -1, 0, 0, 0, 0, 0], + [-1, -1, 0, 0, 0, 0, 0], + [-1, -1, 0, 0, 0, 0, 0], + [-1, -1, 0, 0, 0, 0, 0], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor4 = torch.tensor([ + [-1, 0, 0, 0, 0, 0, 0], + [-1, 0, 0, 0, 0, 0, 0], + [-1, 0, 0, 0, 0, 0, 0], + [-1, 0, 0, 0, 0, 0, 0], + [-1, 0, 0, 0, 0, 0, 0], + [-1, 0, 0, 0, 0, 0, 0], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor5 = torch.tensor([ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ]) + + config_3 = Config({ + 'collate': { + 'snp_dim': 'max', + 'indiv_dim': 'max', + 'value_fill': -1, + 'pad_right': False, + 'pad_left': True, + 'pad_bottom': True, + 'pad_top': False + } + }) + + collate = CollateTraining.from_config(config_3) + collated_batch = collate(fake_batch) + + assert torch.equal(collated_batch[1][0], tensor1) + assert torch.equal(collated_batch[1][1], tensor2) + assert torch.equal(collated_batch[1][2], tensor3) + assert torch.equal(collated_batch[1][3], tensor4) + assert torch.equal(collated_batch[1][4], tensor5) + + """ + this format should be + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 X X X + -1 -1 -1 -1 X X X + -1 -1 -1 -1 X X X + + and then increase one by one for each input, + to enventually reach + + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + """ + + tensor1 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, 0, 0, 0], + [-1, -1, -1, -1, 0, 0, 0], + [-1, -1, -1, -1, 0, 0, 0] + ]) + + tensor2 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, 0, 0, 0, 0], + [-1, -1, -1, 0, 0, 0, 0], + [-1, -1, -1, 0, 0, 0, 0], + [-1, -1, -1, 0, 0, 0, 0] + ]) + + tensor3 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, 0, 0, 0, 0, 0], + [-1, -1, 0, 0, 0, 0, 0], + [-1, -1, 0, 0, 0, 0, 0], + [-1, -1, 0, 0, 0, 0, 0], + [-1, -1, 0, 0, 0, 0, 0] + ]) + + tensor4 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [-1, 0, 0, 0, 0, 0, 0], + [-1, 0, 0, 0, 0, 0, 0], + [-1, 0, 0, 0, 0, 0, 0], + [-1, 0, 0, 0, 0, 0, 0], + [-1, 0, 0, 0, 0, 0, 0], + [-1, 0, 0, 0, 0, 0, 0] + ]) + + tensor5 = torch.tensor([ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0] + ]) + + config_4 = Config({ + 'collate': { + 'snp_dim': 'max', + 'indiv_dim': 'max', + 'value_fill': -1, + 'pad_right': False, + 'pad_left': True, + 'pad_bottom': False, + 'pad_top': True + } + }) + + collate = CollateTraining.from_config(config_4) + collated_batch = collate(fake_batch) + + assert torch.equal(collated_batch[1][0], tensor1) + assert torch.equal(collated_batch[1][1], tensor2) + assert torch.equal(collated_batch[1][2], tensor3) + assert torch.equal(collated_batch[1][3], tensor4) + assert torch.equal(collated_batch[1][4], tensor5) + + """ + Now we test the cases where two dual variables are set to true, + i.e. pad_right = True and pad_left = True + In this case, the input is in the middle of the array. If the dimension doesn't fit, + the padding on the right and the padding on the bottom is privileged + """ + + """ + this format should be + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + -1 -1 X X X -1 -1 + -1 -1 X X X -1 -1 + -1 -1 X X X -1 -1 + + and then increase one by one for each input, + to enventually reach + + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + """ + + tensor1 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, 0, 0, 0, -1, -1], + [-1, -1, 0, 0, 0, -1, -1], + [-1, -1, 0, 0, 0, -1, -1] + ]) + + tensor2 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, 0, 0, 0, 0, -1, -1], + [-1, 0, 0, 0, 0, -1, -1], + [-1, 0, 0, 0, 0, -1, -1], + [-1, 0, 0, 0, 0, -1, -1] + ]) + + tensor3 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, 0, 0, 0, 0, 0, -1], + [-1, 0, 0, 0, 0, 0, -1], + [-1, 0, 0, 0, 0, 0, -1], + [-1, 0, 0, 0, 0, 0, -1], + [-1, 0, 0, 0, 0, 0, -1] + ]) + + tensor4 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1] + ]) + + tensor5 = torch.tensor([ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0] + ]) + + config_5 = Config({ + 'collate': { + 'snp_dim': 'max', + 'indiv_dim': 'max', + 'value_fill': -1, + 'pad_right': True, + 'pad_left': True, + 'pad_bottom': False, + 'pad_top': True + } + }) + + collate = CollateTraining.from_config(config_5) + collated_batch = collate(fake_batch) + + assert torch.equal(collated_batch[1][0], tensor1) + assert torch.equal(collated_batch[1][1], tensor2) + assert torch.equal(collated_batch[1][2], tensor3) + assert torch.equal(collated_batch[1][3], tensor4) + assert torch.equal(collated_batch[1][4], tensor5) + + """ + this format should be + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + X X X -1 -1 -1 -1 + X X X -1 -1 -1 -1 + X X X -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + + and then increase one by one for each input, + to enventually reach + + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + """ + + tensor1 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [0, 0, 0, -1, -1, -1, -1], + [0, 0, 0, -1, -1, -1, -1], + [0, 0, 0, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor2 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [0, 0, 0, 0, -1, -1, -1], + [0, 0, 0, 0, -1, -1, -1], + [0, 0, 0, 0, -1, -1, -1], + [0, 0, 0, 0, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor3 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [0, 0, 0, 0, 0, -1, -1], + [0, 0, 0, 0, 0, -1, -1], + [0, 0, 0, 0, 0, -1, -1], + [0, 0, 0, 0, 0, -1, -1], + [0, 0, 0, 0, 0, -1, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor4 = torch.tensor([ + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor5 = torch.tensor([ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0] + ]) + + config_6 = Config({ + 'collate': { + 'snp_dim': 'max', + 'indiv_dim': 'max', + 'value_fill': -1, + 'pad_right': True, + 'pad_left': False, + 'pad_bottom': True, + 'pad_top': True + } + }) + + collate = CollateTraining.from_config(config_6) + collated_batch = collate(fake_batch) + + assert torch.equal(collated_batch[1][0], tensor1) + assert torch.equal(collated_batch[1][1], tensor2) + assert torch.equal(collated_batch[1][2], tensor3) + assert torch.equal(collated_batch[1][3], tensor4) + assert torch.equal(collated_batch[1][4], tensor5) + + """ + this format should be + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + -1 -1 X X X -1 -1 + -1 -1 X X X -1 -1 + -1 -1 X X X -1 -1 + -1 -1 -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 -1 -1 + + and then increase one by one for each input, + to enventually reach + + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + """ + + tensor1 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, 0, 0, 0, -1, -1], + [-1, -1, 0, 0, 0, -1, -1], + [-1, -1, 0, 0, 0, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor2 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [-1, 0, 0, 0, 0, -1, -1], + [-1, 0, 0, 0, 0, -1, -1], + [-1, 0, 0, 0, 0, -1, -1], + [-1, 0, 0, 0, 0, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor3 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [-1, 0, 0, 0, 0, 0, -1], + [-1, 0, 0, 0, 0, 0, -1], + [-1, 0, 0, 0, 0, 0, -1], + [-1, 0, 0, 0, 0, 0, -1], + [-1, 0, 0, 0, 0, 0, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor4 = torch.tensor([ + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor5 = torch.tensor([ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0] + ]) + + config_7 = Config({ + 'collate': { + 'snp_dim': 'max', + 'indiv_dim': 'max', + 'value_fill': -1, + 'pad_right': True, + 'pad_left': True, + 'pad_bottom': True, + 'pad_top': True + } + }) + + collate = CollateTraining.from_config(config_7) + collated_batch = collate(fake_batch) + + assert torch.equal(collated_batch[1][0], tensor1) + assert torch.equal(collated_batch[1][1], tensor2) + assert torch.equal(collated_batch[1][2], tensor3) + assert torch.equal(collated_batch[1][3], tensor4) + assert torch.equal(collated_batch[1][4], tensor5) + + """ + Last batch of tests: + We will test different case where the snp_dim or the + indiv_dim is not set to max + """ + + """ + this format should be + -1 -1 -1 -1 -1 -1 -1 + -1 -1 X X X -1 -1 + -1 -1 X X X -1 -1 + -1 -1 X X X -1 -1 + -1 -1 -1 -1 -1 -1 -1 + + and then increase one by one for each input, + to enventually reach + + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + X X X X X X X + """ + + tensor1 = torch.tensor([ + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, 0, 0, 0, -1, -1], + [-1, -1, 0, 0, 0, -1, -1], + [-1, -1, 0, 0, 0, -1, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor2 = torch.tensor([ + [-1, 0, 0, 0, 0, -1, -1], + [-1, 0, 0, 0, 0, -1, -1], + [-1, 0, 0, 0, 0, -1, -1], + [-1, 0, 0, 0, 0, -1, -1], + [-1, -1, -1, -1, -1, -1, -1] + ]) + + tensor3 = torch.tensor([ + [-1, 0, 0, 0, 0, 0, -1], + [-1, 0, 0, 0, 0, 0, -1], + [-1, 0, 0, 0, 0, 0, -1], + [-1, 0, 0, 0, 0, 0, -1], + [-1, 0, 0, 0, 0, 0, -1] + ]) + + tensor4 = torch.tensor([ + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1], + [0, 0, 0, 0, 0, 0, -1] + ]) + + tensor5 = torch.tensor([ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0] + ]) + + config_8 = Config({ + 'collate': { + 'snp_dim': 5, + 'indiv_dim': 'max', + 'value_fill': -1, + 'pad_right': True, + 'pad_left': True, + 'pad_bottom': True, + 'pad_top': True + } + }) + + collate = CollateTraining.from_config(config_8) + collated_batch = collate(fake_batch) + + assert torch.equal(collated_batch[1][0], tensor1) + assert torch.equal(collated_batch[1][1], tensor2) + assert torch.equal(collated_batch[1][2], tensor3) + assert torch.equal(collated_batch[1][3], tensor4) + assert torch.equal(collated_batch[1][4], tensor5) + + """ + this format should be + -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 + -1 X X X -1 + -1 X X X -1 + -1 X X X -1 + -1 -1 -1 -1 -1 + -1 -1 -1 -1 -1 + + and then increase one by one for each input, + to enventually reach + + X X X X X + X X X X X + X X X X X + X X X X X + X X X X X + X X X X X + X X X X X + """ + + tensor1 = torch.tensor([ + [-1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1], + [-1, 0, 0, 0, -1], + [-1, 0, 0, 0, -1], + [-1, 0, 0, 0, -1], + [-1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1] + ]) + + tensor2 = torch.tensor([ + [-1, -1, -1, -1, -1], + [0, 0, 0, 0, -1], + [0, 0, 0, 0, -1], + [0, 0, 0, 0, -1], + [0, 0, 0, 0, -1], + [-1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1] + ]) + + tensor3 = torch.tensor([ + [-1, -1, -1, -1, -1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [-1, -1, -1, -1, -1] + ]) + + tensor4 = torch.tensor([ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [-1, -1, -1, -1, -1] + ]) + + tensor5 = torch.tensor([ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ]) + + config_9 = Config({ + 'collate': { + 'snp_dim': 'max', + 'indiv_dim': 5, + 'value_fill': -1, + 'pad_right': True, + 'pad_left': True, + 'pad_bottom': True, + 'pad_top': True + } + }) + + collate = CollateTraining.from_config(config_9) + collated_batch = collate(fake_batch) + + assert torch.equal(collated_batch[1][0], tensor1) + assert torch.equal(collated_batch[1][1], tensor2) + assert torch.equal(collated_batch[1][2], tensor3) + assert torch.equal(collated_batch[1][3], tensor4) + assert torch.equal(collated_batch[1][4], tensor5) + + """ + this format should be + -1 -1 -1 -1 -1 + -1 X X X -1 + -1 X X X -1 + -1 X X X -1 + -1 -1 -1 -1 -1 + + and then increase one by one for each input, + to enventually reach + + X X X X X + X X X X X + X X X X X + X X X X X + X X X X X + """ + + tensor1 = torch.tensor([ + [-1, -1, -1, -1, -1], + [-1, 0, 0, 0, -1], + [-1, 0, 0, 0, -1], + [-1, 0, 0, 0, -1], + [-1, -1, -1, -1, -1,] + ]) + + tensor2 = torch.tensor([ + [0, 0, 0, 0, -1], + [0, 0, 0, 0, -1], + [0, 0, 0, 0, -1], + [0, 0, 0, 0, -1], + [-1, -1, -1, -1, -1,] + ]) + + tensor3 = torch.tensor([ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0] + ]) + + tensor4 = torch.tensor([ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0] + ]) + + tensor5 = torch.tensor([ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0] + ]) + + config_10 = Config({ + 'collate': { + 'snp_dim': 5, + 'indiv_dim': 5, + 'value_fill': -1, + 'pad_right': True, + 'pad_left': True, + 'pad_bottom': True, + 'pad_top': True + } + }) + + collate = CollateTraining.from_config(config_10) + collated_batch = collate(fake_batch) + + assert torch.equal(collated_batch[1][0], tensor1) + assert torch.equal(collated_batch[1][1], tensor2) + assert torch.equal(collated_batch[1][2], tensor3) + assert torch.equal(collated_batch[1][3], tensor4) + assert torch.equal(collated_batch[1][4], tensor5) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 7f82a84b3ab87462f8e37b51b65fbfaf200236cc..eca45f8ddf71b38396914e39b010c77853e1d121 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -18,7 +18,9 @@ from dnadna.examples import one_event from dnadna.params import LearnedParams from dnadna.snp_sample import SNPSample from dnadna.transforms import ValidateSnp +from dnadna.utils.config import Config from dnadna.utils.misc import zero_pad_format +from dnadna.collate import CollateTraining @pytest.fixture(scope='module') @@ -345,11 +347,24 @@ def test_invalid_uniform_dataloader(num_workers): exception is raised. """ + # Necessary since the change of collate_func to a class + config = Config({ + 'collate': { + 'snp_dim': 'max', + 'indiv_dim': 'max', + 'value_fill': -1, + 'pad_right': True, + 'pad_left': False, + 'pad_bottom': True, + 'pad_top': False + }}) + + collate = CollateTraining.from_config(config) _, dataset = nonuniform_dataset(nonuniformity='both', n_scenarios=10) loader = DataLoader(dataset=dataset, batch_size=1, num_workers=num_workers, - collate_fn=DNATrainingDataset.collate_batch) + collate_fn=collate) # Attempting to iterate over the loader should result in an # InvalidSNPSample error eventually @@ -371,7 +386,7 @@ def test_invalid_uniform_dataloader(num_workers): loader = DataLoader(dataset=dataset, batch_size=1, num_workers=num_workers, - collate_fn=DNATrainingDataset.collate_batch) + collate_fn=collate) for r_idx, datum in enumerate(loader): s_idx, snp = datum[:2] diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 2c9335a7ef028cf27b262858c32981763b1ccf3f..39ce9c04f496dee9c913560cbaf8632920e7e4ec 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -29,6 +29,15 @@ def test_predict_with_classification_params(): 'type': 'classification', 'classes': ['yes', 'no'] } + }, + 'collate': { + 'snp_dim': 'max', + 'indiv_dim': 'max', + 'value_fill': -1, + 'pad_right': True, + 'pad_left': False, + 'pad_bottom': True, + 'pad_top': False } }) @@ -60,6 +69,15 @@ def test_predict_with_log_transform(device): }, 'learned_params': { 'position': {'type': 'regression', 'log_transform': True}, + }, + 'collate': { + 'snp_dim': 'max', + 'indiv_dim': 'max', + 'value_fill': -1, + 'pad_right': True, + 'pad_left': False, + 'pad_bottom': True, + 'pad_top': False } })