Skip to content

Commit

Permalink
jaxtyping added
Browse files: browse the repository at this point in the history
  • Loading branch information
Artur-Galstyan committed Apr 5, 2024
1 parent f71d752 commit 098a16a
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 12 deletions.
9 changes: 5 additions & 4 deletions jaxonloader/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from jaxonloader.dataset import JaxonDataset, SingleArrayDataset, DataTargetDataset # noqa
from jaxonloader.dataloader import JaxonDataLoader # noqa
from beartype.claw import beartype_this_package
from jaxtyping import install_import_hook

beartype_this_package()

with install_import_hook(modules=["jaxonloader"], typechecker="beartype.beartype"):
from jaxonloader.dataset import JaxonDataset, SingleArrayDataset, DataTargetDataset # noqa
from jaxonloader.dataloader import JaxonDataLoader # noqa
6 changes: 5 additions & 1 deletion jaxonloader/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from jaxonloader.datasets._datasets import * # noqa
from jaxtyping import install_import_hook


with install_import_hook(modules=["jaxonloader"], typechecker="beartype.beartype"):
from jaxonloader.datasets._datasets import * # noqa
21 changes: 15 additions & 6 deletions jaxonloader/datasets/_datasets.py
Original file line number Diff line number Diff line change
@@ -186,18 +186,27 @@ def decode(latent: NDArray) -> str:


@jaxonloader_cache(dataset_name="titanic")
def get_titanic():
def get_titanic() -> JaxonDataset:
data_url = "https://omnisium.eu-central-1.linodeobjects.com/titanic/titanic.zip"
data_path = pathlib.Path(JAXONLOADER_PATH) / "titanic"
download_and_extract_zip(data_url, data_path)
train_df = pl.read_csv(data_path / "train.csv")

train = pd.read_csv(data_path / "train.csv").to_numpy()
test = pd.read_csv(data_path / "test.csv").to_numpy()
def _gender_to_int(df: pl.DataFrame) -> pl.DataFrame:
df = df.with_columns(
pl.col("Sex")
.apply(lambda gender: 0 if gender == "male" else 1)
.alias("Sex")
)
return df

train_dataset = SingleArrayDataset(train)
test_dataset = SingleArrayDataset(test)
train = _gender_to_int(train_df)
train_data = train.select(pl.exclude("Survived")).to_numpy()
train_target = train.select(pl.col("Survived")).to_numpy()

return train_dataset, test_dataset
train_dataset = DataTargetDataset(train_data, train_target)

return train_dataset


def from_dataframe(dataframe: pl.DataFrame | pd.DataFrame) -> JaxonDataset:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "jaxonloader"
version = "0.2.9"
version = "0.3.0"
description = "A dataloader, but for JAX"
readme = "README.md"
requires-python ="~=3.10"
Expand Down

0 comments on commit 098a16a

Please sign in to comment.