diff --git a/declearn/dataset/_inmemory.py b/declearn/dataset/_inmemory.py index 7f3b33fa73844f1a5cc285c2ebf95de3e7065063..f8e5f54df08b82ea538fc2a90eda41f71a244598 100644 --- a/declearn/dataset/_inmemory.py +++ b/declearn/dataset/_inmemory.py @@ -149,6 +149,11 @@ class InMemoryDataset(Dataset): target = self.data[target] else: target = load_data_array(target) + if ( + isinstance(target, pd.DataFrame) + and len(target.columns) == 1 + ): + target = target.iloc[:, 0] self.target = target # Assign the (optional) sample weights data array. if isinstance(s_wght, str):