diff --git a/declearn/dataset/examples/_heart_uci.py b/declearn/dataset/examples/_heart_uci.py
index 457ee2c8bff2755b878b633b71128c7458abbebe..a290f0cfd49cb8b1a224320369fa500c24f2f5d3 100644
--- a/declearn/dataset/examples/_heart_uci.py
+++ b/declearn/dataset/examples/_heart_uci.py
@@ -17,10 +17,13 @@
 
 """Util to download and pre-process the UCI Heart Disease dataset."""
 
+import io
 import os
-from typing import Literal, Optional, Tuple
+import zipfile
+from typing import Literal, Optional, Tuple, Union
 
-import pandas as pd  # type: ignore
+import pandas as pd
+import requests
 
 __all__ = [
     "load_heart_uci",
@@ -33,7 +36,7 @@ def load_heart_uci(
 ) -> Tuple[pd.DataFrame, str]:
     """Load and/or download a pre-processed UCI Heart Disease dataset.
 
-    See [https://archive.ics.uci.edu/ml/datasets/Heart+Disease] for
+    See [https://archive.ics.uci.edu/dataset/45/heart+disease] for
     information on the UCI Heart Disease dataset.
 
     Arguments
@@ -53,31 +56,74 @@
         Name of the target column in `data`.
         May be passed as `target` of a declearn `InMemoryDataset`.
     """
-    # If the file already exists, read and return it.
+    # If the pre-processed file already exists, read and return it.
     if folder is not None:
         path = os.path.join(folder, f"data_{name}.csv")
         if os.path.isfile(path):
             data = pd.read_csv(path)
             return data, "num"
-    # Otherwise, download and pre-process the data, and optionally save it.
-    data = download_heart_uci_shard(name)
+    # Download (and optionally save) or read from the source zip file.
+    source = get_heart_uci_zipfile(folder)
+    # Extract the target shard and preprocess it.
+    data = extract_heart_uci_shard(name, source)
+    data = preprocess_heart_uci_dataframe(data)
+    # Optionally save the preprocessed shard to disk.
     if folder is not None:
-        os.makedirs(folder, exist_ok=True)
-        data.to_csv(path, index=False)
+        path = os.path.join(folder, f"data_{name}.csv")
+        data.to_csv(path, sep=",", encoding="utf-8", index=False)
     return data, "num"
 
 
-def download_heart_uci_shard(
+def get_heart_uci_zipfile(folder: Optional[str]) -> Union[str, bytes]:
+    """Download and opt. save the Heart Dataset zip file, or return its path.
+
+    Return either the path to the zip file, or its contents.
+    """
+    # Case when the data is to be downloaded and kept only in memory.
+    if folder is None:
+        return download_heart_uci()
+    # Case when the data can be read from a pre-existing file on disk.
+    path = os.path.join(folder, "heart+disease.zip")
+    if os.path.isfile(path):
+        return path
+    # Case when the data is to be downloaded and saved on disk for re-use.
+    os.makedirs(folder, exist_ok=True)
+    data = download_heart_uci()
+    with open(path, "wb") as file:
+        file.write(data)
+    return data
+
+
+def download_heart_uci() -> bytes:
+    """Download the Heart Disease UCI dataset source file."""
+    print("Downloading Heart Disease UCI dataset.")
+    url = "https://archive.ics.uci.edu/static/public/45/heart+disease.zip"
+    reply = requests.get(url, timeout=300)
+    try:
+        reply.raise_for_status()
+    except requests.HTTPError as exc:
+        raise RuntimeError(
+            "Failed to download Heart Disease UCI source file."
+        ) from exc
+    return reply.content
+
+
+def extract_heart_uci_shard(
     name: Literal["cleveland", "hungarian", "switzerland", "va"],
+    source: Union[str, bytes],
+) -> pd.DataFrame:
+    """Read a subset of the Heart UCI data, from in-memory or on-disk data."""
+    zdat = source if isinstance(source, str) else io.BytesIO(source)
+    with zipfile.ZipFile(zdat) as archive:  # type: ignore
+        with archive.open(f"processed.{name}.data") as path:
+            data = pd.read_csv(path, header=None, na_values="?")
+    return data
+
+
+def preprocess_heart_uci_dataframe(
+    data: pd.DataFrame,
 ) -> pd.DataFrame:
-    """Download and preprocess a subset of the Heart UCI dataset."""
-    print(f"Downloading Heart Disease UCI dataset from center {name}.")
-    url = (
-        "https://archive.ics.uci.edu/ml/machine-learning-databases/"
-        f"heart-disease/processed.{name}.data"
-    )
-    # Download the dataaset.
-    data = pd.read_csv(url, header=None, na_values="?")
+    """Preprocess a subset of the Heart UCI dataset."""
     columns = [
         # fmt: off
         "age", "sex", "cp", "trestbps", "chol", "fbs", "restecg",
@@ -90,7 +136,7 @@ def download_heart_uci_shard(
     data.reset_index(inplace=True, drop=True)
     # Normalize quantitative variables.
     for col in ("age", "trestbps", "thalach", "oldpeak"):
-        data[col] = (  # type: ignore
-            data[col] - data[col].mean() / data[col].std()  # type: ignore
-        )
+        data[col] = (
+            data[col] - data[col].mean()  # type: ignore
+        ) / data[col].std()
     # Binarize the target variable.
diff --git a/test/dataset/test_examples.py b/test/dataset/test_examples.py
index 84c6790f316ca05bc7d1ad5194c7e4079019f112..0615b59d0e95e0375fb178b2e4053e3406e70350 100644
--- a/test/dataset/test_examples.py
+++ b/test/dataset/test_examples.py
@@ -36,7 +36,7 @@ def test_load_heart_uci(tmpdir: str) -> None:
     assert tcol in data.columns
     # Test that re-loading the dataset works.
     with mock.patch(
-        "declearn.dataset.examples._heart_uci.download_heart_uci_shard"
+        "declearn.dataset.examples._heart_uci.download_heart_uci"
     ) as patch_download:
         data_bis, tcol_bis = load_heart_uci("va", folder=tmpdir)
         patch_download.assert_not_called()
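For reference, a minimal usage sketch of the re-worked loader. It assumes, as the docstring above suggests, that the returned dataframe and target column name can be wrapped into a declearn `InMemoryDataset` (here taken to be exposed under `declearn.dataset`); the `folder` value is purely illustrative.

    from declearn.dataset import InMemoryDataset
    from declearn.dataset.examples import load_heart_uci

    # First call: downloads the source zip file, extracts and pre-processes
    # the "cleveland" shard, and caches both under the (illustrative) folder.
    data, target = load_heart_uci("cleveland", folder="data")
    # Later calls with the same folder re-read the cached csv file and
    # trigger no new download (which is what the updated test asserts).
    data, target = load_heart_uci("cleveland", folder="data")
    # Wrap the shard for use in a declearn federated-learning process.
    dataset = InMemoryDataset(data, target=target)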