Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 3b687a3a authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Fix Heart Disease UCI Dataset loader due to website change.

parent afc29510
No related branches found
No related tags found
No related merge requests found
Pipeline #816406 passed
......@@ -17,10 +17,13 @@
"""Util to download and pre-process the UCI Heart Disease dataset."""
import io
import os
from typing import Literal, Optional, Tuple
import zipfile
from typing import Literal, Optional, Tuple, Union
import pandas as pd # type: ignore
import pandas as pd
import requests
__all__ = [
"load_heart_uci",
......@@ -33,7 +36,7 @@ def load_heart_uci(
) -> Tuple[pd.DataFrame, str]:
"""Load and/or download a pre-processed UCI Heart Disease dataset.
See [https://archive.ics.uci.edu/ml/datasets/Heart+Disease] for
See [https://archive.ics.uci.edu/dataset/45/heart+disease] for
information on the UCI Heart Disease dataset.
Arguments
......@@ -53,31 +56,73 @@ def load_heart_uci(
Name of the target column in `data`.
May be passed as `target` of a declearn `InMemoryDataset`.
"""
# If the file already exists, read and return it.
# If the pre-processed file already exists, read and return it.
if folder is not None:
path = os.path.join(folder, f"data_{name}.csv")
if os.path.isfile(path):
data = pd.read_csv(path)
return data, "num"
# Otherwise, download and pre-process the data, and optionally save it.
data = download_heart_uci_shard(name)
# Download (and optionally save) or read from the source zip file.
source = get_heart_uci_zipfile(folder)
# Extract the target shard and preprocess it.
data = extract_heart_uci_shard(name, source)
data = preprocess_heart_uci_dataframe(data)
# Optionally save the preprocessed shard to disk.
if folder is not None:
os.makedirs(folder, exist_ok=True)
data.to_csv(path, index=False)
path = os.path.join(folder, f"data_{name}.csv")
data.to_csv(path, sep=",", encoding="utf-8", index=False)
return data, "num"
def download_heart_uci_shard(
def get_heart_uci_zipfile(folder: Optional[str]) -> Union[str, bytes]:
    """Download and opt. save the Heart Dataset zip file, or return its path.

    Return either the path to the zip file, or its contents.
    """
    # Without a target folder, keep the downloaded archive in memory only.
    if folder is None:
        return download_heart_uci()
    path = os.path.join(folder, "heart+disease.zip")
    # Re-use a previously-downloaded copy of the archive when one exists.
    if os.path.isfile(path):
        return path
    # Otherwise, download the archive and save it to disk for later re-use.
    contents = download_heart_uci()
    with open(path, "wb") as file:
        file.write(contents)
    return contents
def download_heart_uci() -> bytes:
    """Download the Heart Disease UCI dataset source file."""
    print("Downloading Heart Disease UCI dataset.")
    # Fetch the zipped dataset from the UCI static-files endpoint.
    reply = requests.get(
        "https://archive.ics.uci.edu/static/public/45/heart+disease.zip",
        timeout=300,
    )
    # Re-raise HTTP-level failures as a more explicit runtime error.
    try:
        reply.raise_for_status()
    except requests.HTTPError as exc:
        raise RuntimeError(
            "Failed to download Heart Disease UCI source file."
        ) from exc
    return reply.content
def extract_heart_uci_shard(
    name: Literal["cleveland", "hungarian", "switzerland", "va"],
    source: Union[str, bytes],
) -> pd.DataFrame:
    """Read a subset of the Heart UCI data, from in-memory or on-disk data."""
    # Wrap raw bytes in a buffer; a str is treated as a path to the archive.
    if isinstance(source, bytes):
        archive_source = io.BytesIO(source)
    else:
        archive_source = source
    # Parse the centre-specific CSV member straight out of the zip archive.
    with zipfile.ZipFile(archive_source) as archive:  # type: ignore
        with archive.open(f"processed.{name}.data") as member:
            return pd.read_csv(member, header=None, na_values="?")
def preprocess_heart_uci_dataframe(
data: pd.DataFrame,
) -> pd.DataFrame:
"""Download and preprocess a subset of the Heart UCI dataset."""
print(f"Downloading Heart Disease UCI dataset from center {name}.")
url = (
"https://archive.ics.uci.edu/ml/machine-learning-databases/"
f"heart-disease/processed.{name}.data"
)
    # Download the dataset.
data = pd.read_csv(url, header=None, na_values="?")
"""Preprocess a subset of the Heart UCI dataset."""
columns = [
# fmt: off
"age", "sex", "cp", "trestbps", "chol", "fbs", "restecg",
......@@ -90,7 +135,7 @@ def download_heart_uci_shard(
data.reset_index(inplace=True, drop=True)
# Normalize quantitative variables.
for col in ("age", "trestbps", "thalach", "oldpeak"):
data[col] = ( # type: ignore
data[col] = (
data[col] - data[col].mean() / data[col].std() # type: ignore
)
# Binarize the target variable.
......
......@@ -36,7 +36,7 @@ def test_load_heart_uci(tmpdir: str) -> None:
assert tcol in data.columns
# Test that re-loading the dataset works.
with mock.patch(
"declearn.dataset.examples._heart_uci.download_heart_uci_shard"
"declearn.dataset.examples._heart_uci.download_heart_uci"
) as patch_download:
data_bis, tcol_bis = load_heart_uci("va", folder=tmpdir)
patch_download.assert_not_called()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment