✨ Support modern datasets (Kedro 0.19.7+) (#590)
* ✨ Support modern datasets (Kedro 0.19.7+)

* ✨ Add test for modern and legacy datasets

* ✨ Add change description to release notes
deepyaman authored Sep 9, 2024
1 parent bb46df5 commit ed43e0d
Showing 3 changed files with 94 additions and 3 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -2,11 +2,15 @@

 ## [Unreleased]

+### Added
+
+- :sparkles: Add support for "modern" datasets ([introduced in Kedro 0.19.7](https://github.com/kedro-org/kedro/commit/52458c2addd1827623d06c20228b709052a5fdf3)) that expose `load` and `save` publicly ([#590, deepyaman](https://github.com/Galileo-Galilei/kedro-mlflow/pull/590))
+
 ## [0.13.0] - 2024-09-01

 ### Added

-- :sparkles: Add support for loading model with alias in `MlflowModelRegistryDataset` [#553](https://github.com/Galileo-Galilei/kedro-mlflow/issues/553)
+- :sparkles: Add support for loading model with alias in `MlflowModelRegistryDataset` ([#553](https://github.com/Galileo-Galilei/kedro-mlflow/issues/553))

 ### Changed
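To make the changelog entry concrete: a "legacy" dataset implements Kedro's protected `_load`/`_save` hooks, while a "modern" dataset (Kedro 0.19.7+) overrides the public `load` and `save` methods directly. A minimal sketch of the two styles, mirroring the tests added below (the CSV datasets here are illustrative, not part of this commit):

import pandas as pd
from kedro.io import AbstractDataset


class LegacyCSVDataset(AbstractDataset):
    """Pre-0.19.7 style: protected hook methods."""

    def __init__(self, filepath: str):
        self._filepath = filepath

    def _load(self) -> pd.DataFrame:
        return pd.read_csv(self._filepath)

    def _save(self, df: pd.DataFrame) -> None:
        df.to_csv(self._filepath, index=False)

    def _describe(self) -> dict:
        return {"filepath": self._filepath}


class ModernCSVDataset(AbstractDataset):
    """0.19.7+ style: public load/save, wrapped by Kedro at class creation."""

    def __init__(self, filepath: str):
        self._filepath = filepath

    def load(self) -> pd.DataFrame:
        return pd.read_csv(self._filepath)

    def save(self, df: pd.DataFrame) -> None:
        df.to_csv(self._filepath, index=False)

    def _describe(self) -> dict:
        return {"filepath": self._filepath}

Both styles can now be passed to `MlflowArtifactDataset`; before this commit, only the legacy style worked.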
10 changes: 8 additions & 2 deletions kedro_mlflow/io/artifacts/mlflow_artifact_dataset.py
@@ -62,7 +62,10 @@ def _save(self, data: Any):
     # for logging on remote storage like Azure S3
     local_path = local_path.as_posix()

-    super()._save(data)
+    if getattr(super().save, "__savewrapped__", False):  # modern dataset
+        super().save.__wrapped__(self, data)
+    else:  # legacy dataset
+        super()._save(data)

     if self._logging_activated:
         if self.run_id:
@@ -131,7 +134,10 @@ def _load(self) -> Any:  # pragma: no cover
         shutil.copy(src=temp_download_filepath, dst=local_path)

     # finally, read locally
-    return super()._load()
+    if getattr(super().load, "__loadwrapped__", False):  # modern dataset
+        return super().load.__wrapped__(self)
+    else:  # legacy dataset
+        return super()._load()

# rename the class
parent_name = dataset_obj.__name__
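Background on the `__savewrapped__` / `__loadwrapped__` checks above: since Kedro 0.19.7, `AbstractDataset` wraps a subclass's public `save` and `load` at class-creation time, marks the wrapper with these attributes, and (via `functools.wraps`) keeps the user's original method reachable as `__wrapped__`. The patch therefore detects a modern dataset by the marker attribute and calls the unwrapped original, rather than going through Kedro's wrapper a second time. A simplified, self-contained imitation of that mechanism (not Kedro's actual source, just the shape the patch relies on):

from functools import wraps
from typing import Any, Callable


def save_wrapper(save_fn: Callable[[Any, Any], None]) -> Callable[[Any, Any], None]:
    # Imitates Kedro's save wrapping: adds bookkeeping around the user's save.
    @wraps(save_fn)  # exposes the original function as save.__wrapped__
    def save(self: Any, data: Any) -> None:
        print(f"[wrapper] saving via {save_fn.__qualname__}")
        save_fn(self, data)

    save.__savewrapped__ = True  # the marker the patch checks with getattr(...)
    return save


class WrappingBase:
    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        cls.save = save_wrapper(cls.save)  # wrap once, at class creation


class MyDataset(WrappingBase):
    def save(self, data: Any) -> None:
        print(f"writing {data!r}")


ds = MyDataset()
ds.save("x")  # runs through the wrapper
MyDataset.save.__wrapped__(ds, "y")  # bypasses it, as the patched _save does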
81 changes: 81 additions & 0 deletions tests/io/artifacts/test_mlflow_artifact_dataset.py
@@ -3,6 +3,7 @@
 import mlflow
 import pandas as pd
 import pytest
+from kedro.io import AbstractDataset
 from kedro_datasets.pandas import CSVDataset
 from kedro_datasets.partitions import PartitionedDataset
 from kedro_datasets.pickle import PickleDataset
@@ -289,3 +290,83 @@ def test_partitioned_dataset_save_and_reload(
     reloaded_data = {k: loader() for k, loader in mlflow_dataset.load().items()}
     for k, df in data.items():
         pd.testing.assert_frame_equal(df, reloaded_data[k])
+
+
+def test_modern_dataset(tmp_path, mlflow_client, df1):
+    class MyOwnDatasetWithoutUnderscoreMethods(AbstractDataset):
+        def __init__(self, filepath):
+            self._filepath = Path(filepath)
+
+        def load(self) -> pd.DataFrame:
+            return pd.read_csv(self._filepath)
+
+        def save(self, df: pd.DataFrame) -> None:
+            df.to_csv(str(self._filepath), index=False)
+
+        def _exists(self) -> bool:
+            return Path(self._filepath.as_posix()).exists()
+
+        def _describe(self):
+            return dict(param1=self._filepath)
+
+    filepath = tmp_path / "data.csv"
+
+    mlflow_dataset = MlflowArtifactDataset(
+        artifact_path="artifact_dir",
+        dataset=dict(
+            type=MyOwnDatasetWithoutUnderscoreMethods, filepath=filepath.as_posix()
+        ),
+    )
+
+    with mlflow.start_run():
+        mlflow_dataset.save(df1)
+        run_id = mlflow.active_run().info.run_id
+
+    # the artifact must be properly uploaded to "mlruns" and reloadable
+    run_artifacts = [
+        fileinfo.path
+        for fileinfo in mlflow_client.list_artifacts(run_id=run_id, path="artifact_dir")
+    ]
+    remote_path = (Path("artifact_dir") / filepath.name).as_posix()
+    assert remote_path in run_artifacts
+    assert df1.equals(mlflow_dataset.load())
+
+
+def test_legacy_dataset(tmp_path, mlflow_client, df1):
+    class MyOwnDatasetWithUnderscoreMethods(AbstractDataset):
+        def __init__(self, filepath):
+            self._filepath = Path(filepath)
+
+        def _load(self) -> pd.DataFrame:
+            return pd.read_csv(self._filepath)
+
+        def _save(self, df: pd.DataFrame) -> None:
+            df.to_csv(str(self._filepath), index=False)
+
+        def _exists(self) -> bool:
+            return Path(self._filepath.as_posix()).exists()
+
+        def _describe(self):
+            return dict(param1=self._filepath)
+
+    filepath = tmp_path / "data.csv"
+
+    mlflow_dataset = MlflowArtifactDataset(
+        artifact_path="artifact_dir",
+        dataset=dict(
+            type=MyOwnDatasetWithUnderscoreMethods, filepath=filepath.as_posix()
+        ),
+    )
+
+    with mlflow.start_run():
+        mlflow_dataset.save(df1)
+        run_id = mlflow.active_run().info.run_id
+
+    # the artifact must be properly uploaded to "mlruns" and reloadable
+    run_artifacts = [
+        fileinfo.path
+        for fileinfo in mlflow_client.list_artifacts(run_id=run_id, path="artifact_dir")
+    ]
+    remote_path = (Path("artifact_dir") / filepath.name).as_posix()
+    assert remote_path in run_artifacts
+    assert df1.equals(mlflow_dataset.load())
