diff --git a/src/training.py b/src/training.py index ad04cbf7a6d78646bf58b74b06870143c6299ebe..aa212c72921d80f0adbba271c90aae347ca7c612 100755 --- a/src/training.py +++ b/src/training.py @@ -46,6 +46,7 @@ def main(): print('Number of images in VALID:', len(valid_dataset.data_B)) elif dataset_type == 'unpaired': + df_file = f'{dpath}/data/train-dataset_rh_4class-jeanzay.csv' train_dataset = ds.UnpairedImageDataset(subset_A, subset_B, df_file, contrast_list) print('Number of images in TRAIN:', len(train_dataset.data))