From 098a16a77815d4697b840c72b85eeb7e0bccfcf1 Mon Sep 17 00:00:00 2001 From: Artur Galstyan Date: Fri, 5 Apr 2024 17:18:19 +0200 Subject: [PATCH] jaxtyping added --- jaxonloader/__init__.py | 9 +++++---- jaxonloader/datasets/__init__.py | 6 +++++- jaxonloader/datasets/_datasets.py | 21 +++++++++++++++------ pyproject.toml | 2 +- 4 files changed, 26 insertions(+), 12 deletions(-) diff --git a/jaxonloader/__init__.py b/jaxonloader/__init__.py index ca6e67f..32256dd 100644 --- a/jaxonloader/__init__.py +++ b/jaxonloader/__init__.py @@ -1,5 +1,6 @@ -from jaxonloader.dataset import JaxonDataset, SingleArrayDataset, DataTargetDataset # noqa -from jaxonloader.dataloader import JaxonDataLoader # noqa -from beartype.claw import beartype_this_package +from jaxtyping import install_import_hook -beartype_this_package() + +with install_import_hook(modules=["jaxonloader"], typechecker="beartype.beartype"): + from jaxonloader.dataset import JaxonDataset, SingleArrayDataset, DataTargetDataset # noqa + from jaxonloader.dataloader import JaxonDataLoader # noqa diff --git a/jaxonloader/datasets/__init__.py b/jaxonloader/datasets/__init__.py index f21b139..d057450 100644 --- a/jaxonloader/datasets/__init__.py +++ b/jaxonloader/datasets/__init__.py @@ -1 +1,5 @@ -from jaxonloader.datasets._datasets import * # noqa +from jaxtyping import install_import_hook + + +with install_import_hook(modules=["jaxonloader"], typechecker="beartype.beartype"): + from jaxonloader.datasets._datasets import * # noqa diff --git a/jaxonloader/datasets/_datasets.py b/jaxonloader/datasets/_datasets.py index 9545daf..06298db 100644 --- a/jaxonloader/datasets/_datasets.py +++ b/jaxonloader/datasets/_datasets.py @@ -186,18 +186,27 @@ def decode(latent: NDArray) -> str: @jaxonloader_cache(dataset_name="titanic") -def get_titanic(): +def get_titanic() -> JaxonDataset: data_url = "https://omnisium.eu-central-1.linodeobjects.com/titanic/titanic.zip" data_path = pathlib.Path(JAXONLOADER_PATH) / "titanic" download_and_extract_zip(data_url, data_path) + train_df = pl.read_csv(data_path / "train.csv") - train = pd.read_csv(data_path / "train.csv").to_numpy() - test = pd.read_csv(data_path / "test.csv").to_numpy() + def _gender_to_int(df: pl.DataFrame) -> pl.DataFrame: + df = df.with_columns( + pl.col("Sex") + .apply(lambda gender: 0 if gender == "male" else 1) + .alias("Sex") + ) + return df - train_dataset = SingleArrayDataset(train) - test_dataset = SingleArrayDataset(test) + train = _gender_to_int(train_df) + train_data = train.select(pl.exclude("Survived")).to_numpy() + train_target = train.select(pl.col("Survived")).to_numpy() - return train_dataset, test_dataset + train_dataset = DataTargetDataset(train_data, train_target) + + return train_dataset def from_dataframe(dataframe: pl.DataFrame | pd.DataFrame) -> JaxonDataset: diff --git a/pyproject.toml b/pyproject.toml index cfce443..46772b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "jaxonloader" -version = "0.2.9" +version = "0.3.0" description = "A dataloader, but for JAX" readme = "README.md" requires-python ="~=3.10"