Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Prototype] Make catalog CLI logic reusable #4480

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
3 changes: 1 addition & 2 deletions kedro/framework/cli/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,8 @@ def _add_missing_datasets_to_catalog(missing_ds: list[str], catalog_path: Path)
def rank_catalog_factories(metadata: ProjectMetadata, env: str) -> None:
"""List all dataset factories in the catalog, ranked by priority by which they are matched."""
session = _create_session(metadata.package_name, env=env)
context = session.load_context()
catalog_factories = session.list_catalog_patterns()

catalog_factories = context.catalog.config_resolver.list_patterns()
if catalog_factories:
click.echo(yaml.dump(catalog_factories))
else:
Expand Down
14 changes: 12 additions & 2 deletions kedro/framework/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,10 @@ class KedroContext:
_extra_params: dict[str, Any] | None = field(
init=True, default=None, converter=deepcopy
)
_catalog: CatalogProtocol | None = None

@property
def catalog(self) -> CatalogProtocol:
def catalog(self) -> CatalogProtocol | None:
"""Read-only property referring to Kedro's catalog for this context.

Returns:
Expand All @@ -187,7 +188,8 @@ def catalog(self) -> CatalogProtocol:
KedroContextError: Incorrect catalog registered for the project.

"""
return self._get_catalog()
self._catalog = self._catalog or self._get_catalog()
return self._catalog

@property
def params(self) -> dict[str, Any]:
Expand All @@ -209,6 +211,14 @@ def params(self) -> dict[str, Any]:

return OmegaConf.to_container(params) if OmegaConf.is_config(params) else params # type: ignore[return-value]

def get_catalog(
    self,
    save_version: str | None = None,
    load_versions: dict[str, str] | None = None,
) -> CatalogProtocol:
    """Return the project catalog, creating and caching it on first access.

    Args:
        save_version: Version string to use when saving versioned datasets;
            only applied when the catalog is first created.
        load_versions: Mapping of dataset name to the version to load;
            only applied when the catalog is first created.

    Returns:
        The cached catalog instance for this context.
    """
    # NOTE(review): if a catalog is already cached, `save_version` and
    # `load_versions` are silently ignored on subsequent calls — confirm
    # callers never pass different versions after first creation.
    self._catalog = self._catalog or self._get_catalog(save_version, load_versions)
    return self._catalog

def _get_catalog(
self,
save_version: str | None = None,
Expand Down
150 changes: 150 additions & 0 deletions kedro/framework/session/catalog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import logging
from typing import Any

from kedro.framework.context import KedroContext
from kedro.framework.project import pipelines as _pipelines
from kedro.io import KedroDataCatalog


class CatalogCommandsMixin:
    """Mixin exposing catalog-inspection commands for reuse outside the CLI.

    Host classes (e.g. ``KedroSession``) must provide ``context`` and
    ``_logger``; the stub properties below only declare that contract for
    type checkers.
    """

    @property
    def context(self) -> KedroContext: ...  # type: ignore[empty-body]

    @property
    def _logger(self) -> logging.Logger: ...  # type: ignore[empty-body]

    def list_catalog_datasets(self, pipelines: list[str] | None = None) -> dict:
        """Show datasets per type.

        Args:
            pipelines: Names of registered pipelines to inspect; defaults to
                all registered pipelines.

        Returns:
            Mapping of pipeline name to datasets grouped into "mentioned",
            "generated from factories" and "not mentioned" categories, or an
            empty dict when the catalog is not a ``KedroDataCatalog``.

        Raises:
            ValueError: If a requested pipeline is not registered.
        """
        catalog = self.context.catalog
        # TODO: remove after moving to new catalog
        if not isinstance(catalog, KedroDataCatalog):
            self._logger.warning(
                "This method is available for `KedroDataCatalog` only."
            )
            return {}

        # TODO: revise setting default pattern logic based on https://github.com/kedro-org/kedro/issues/4475
        runtime_pattern = {"{default}": {"type": "MemoryDataset"}}

        not_mentioned = "Datasets not mentioned in pipeline"
        mentioned = "Datasets mentioned in pipeline"
        factories = "Datasets generated from factories"

        target_pipelines = pipelines or _pipelines.keys()

        # Loop-invariant: the explicitly registered dataset names do not
        # change per pipeline, so compute them once.
        catalog_ds = set(catalog.keys())

        result = {}
        for pipe in target_pipelines:
            pl_obj = _pipelines.get(pipe)
            if pl_obj:
                pipeline_ds = pl_obj.datasets()
            else:
                existing_pls = ", ".join(sorted(_pipelines.keys()))
                raise ValueError(
                    f"'{pipe}' pipeline not found! Existing pipelines: {existing_pls}"
                )

            unused_ds = catalog_ds - pipeline_ds
            default_ds = pipeline_ds - catalog_ds
            used_ds = catalog_ds - unused_ds

            # Split pipeline-only datasets into factory-matched vs default.
            patterns_ds = set()
            for ds_name in default_ds:
                if catalog.config_resolver.match_pattern(ds_name):
                    patterns_ds.add(ds_name)

            default_ds -= patterns_ds
            used_ds.update(default_ds)

            # Temporarily register the default runtime pattern so unmatched
            # datasets resolve to MemoryDataset; always remove it again,
            # even if grouping raises, to avoid leaking it into the resolver.
            catalog.config_resolver.add_runtime_patterns(runtime_pattern)
            try:
                used_ds_by_type = _group_ds_by_type(used_ds, catalog)
                patterns_ds_by_type = _group_ds_by_type(patterns_ds, catalog)
                unused_ds_by_type = _group_ds_by_type(unused_ds, catalog)
            finally:
                catalog.config_resolver.remove_runtime_patterns(runtime_pattern)

            data = (
                (mentioned, used_ds_by_type),
                (factories, patterns_ds_by_type),
                (not_mentioned, unused_ds_by_type),
            )
            result[pipe] = {key: value for key, value in data if value}

        return result

    def list_catalog_patterns(self) -> list[str]:
        """List all dataset factories in the catalog, ranked by priority
        by which they are matched.
        """
        return self.context.catalog.config_resolver.list_patterns()

    def resolve_catalog_patterns(self, include_default: bool = False) -> dict[str, Any]:
        """Resolve catalog factories against pipeline datasets.

        Args:
            include_default: Whether to also resolve datasets that match only
                the default runtime ``MemoryDataset`` pattern.

        Returns:
            Mapping of dataset name to its configuration (credentials
            factored back out via the resolver), or an empty dict when the
            catalog is not a ``KedroDataCatalog``.
        """
        catalog = self.context.catalog

        # TODO: remove after moving to new catalog
        if not isinstance(catalog, KedroDataCatalog):
            self._logger.warning(
                "This method is available for `KedroDataCatalog` only."
            )
            return {}

        # TODO: revise setting default pattern logic based on https://github.com/kedro-org/kedro/issues/4475
        runtime_pattern = {"{default}": {"type": "MemoryDataset"}}
        if include_default:
            catalog.config_resolver.add_runtime_patterns(runtime_pattern)

        # Ensure the temporary runtime pattern is removed even on error.
        try:
            pipeline_datasets = set()
            for pipe in _pipelines.keys():
                pl_obj = _pipelines.get(pipe)
                if pl_obj:
                    pipeline_datasets.update(pl_obj.datasets())

            # We need to include datasets defined in the catalog.yaml and
            # datasets added manually to the catalog
            explicit_datasets = {}
            for ds_name, ds in catalog.items():
                # TODO: when breaking change replace with is_parameter() from kedro/io/core.py
                if ds_name.startswith("params:") or ds_name == "parameters":
                    continue

                unresolved_config, _ = catalog.config_resolver.unresolve_credentials(
                    ds_name, ds.to_config()
                )
                explicit_datasets[ds_name] = unresolved_config

            for ds_name in pipeline_datasets:
                # TODO: when breaking change replace with is_parameter() from kedro/io/core.py
                if (
                    ds_name in explicit_datasets
                    or ds_name.startswith("params:")
                    or ds_name == "parameters"
                ):
                    continue

                ds_config = catalog.config_resolver.resolve_pattern(ds_name)
                if ds_config:
                    explicit_datasets[ds_name] = ds_config
        finally:
            if include_default:
                catalog.config_resolver.remove_runtime_patterns(runtime_pattern)

        return explicit_datasets


def _group_ds_by_type(datasets: set[str], catalog: KedroDataCatalog) -> dict[str, dict]:
    """Map each non-parameter dataset name to the config produced by
    ``unresolve_credentials`` for that dataset.
    """
    grouped: dict[str, dict] = {}
    resolver = catalog.config_resolver
    for name in datasets:
        # TODO: when breaking change replace with is_parameter() from kedro/io/core.py
        is_param = name == "parameters" or name.startswith("params:")
        if is_param:
            continue

        config, _credentials = resolver.unresolve_credentials(
            name, catalog[name].to_config()
        )
        grouped[name] = config

    return grouped
18 changes: 14 additions & 4 deletions kedro/framework/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
settings,
validate_settings,
)
from kedro.framework.session.catalog import CatalogCommandsMixin
from kedro.io.core import generate_timestamp
from kedro.runner import AbstractRunner, SequentialRunner
from kedro.utils import _find_kedro_project
Expand Down Expand Up @@ -78,7 +79,7 @@ class KedroSessionError(Exception):
pass


class KedroSession:
class KedroSession(CatalogCommandsMixin):
"""``KedroSession`` is the object that is responsible for managing the lifecycle
of a Kedro run. Use `KedroSession.create()` as
a context manager to construct a new KedroSession with session data
Expand Down Expand Up @@ -118,6 +119,7 @@ def __init__(
self._package_name = package_name
self._store = self._init_store()
self._run_called = False
self._context = None

hook_manager = _create_hook_manager()
_register_hooks(hook_manager, settings.HOOKS)
Expand All @@ -128,6 +130,11 @@ def __init__(
self._project_path / settings.CONF_SOURCE
)

@property
def context(self) -> KedroContext:
    """The project context, loaded lazily on first access and then cached."""
    self._context = self._context or self.load_context()
    return self._context

@classmethod
def create(
cls,
Expand Down Expand Up @@ -233,6 +240,8 @@ def store(self) -> dict[str, Any]:

def load_context(self) -> KedroContext:
"""An instance of the project context."""
if self._context:
return self._context
env = self.store.get("env")
extra_params = self.store.get("extra_params")
config_loader = self._get_config_loader()
Expand All @@ -247,6 +256,8 @@ def load_context(self) -> KedroContext:
)
self._hook_manager.hook.after_context_created(context=context)

self._context = context

return context # type: ignore[no-any-return]

def _get_config_loader(self) -> AbstractConfigLoader:
Expand Down Expand Up @@ -338,7 +349,6 @@ def run( # noqa: PLR0913
session_id = self.store["session_id"]
save_version = session_id
extra_params = self.store.get("extra_params") or {}
context = self.load_context()

name = pipeline_name or "__default__"

Expand All @@ -364,7 +374,7 @@ def run( # noqa: PLR0913
record_data = {
"session_id": session_id,
"project_path": self._project_path.as_posix(),
"env": context.env,
"env": self.context.env,
"kedro_version": kedro_version,
"tags": tags,
"from_nodes": from_nodes,
Expand All @@ -379,7 +389,7 @@ def run( # noqa: PLR0913
"runner": getattr(runner, "__name__", str(runner)),
}

catalog = context._get_catalog(
catalog = self.context.get_catalog(
save_version=save_version,
load_versions=load_versions,
)
Expand Down
6 changes: 4 additions & 2 deletions kedro/io/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,16 +252,18 @@ def to_config(self) -> dict[str, Any]:
return_config[VERSIONED_FLAG_KEY] = cached_ds_return_config.pop(
VERSIONED_FLAG_KEY
)
# Pop metadata from configuration
# Pop metadata and data from configuration
cached_ds_return_config.pop("metadata", None)
cached_ds_return_config.pop("data", None)
return_config["dataset"] = cached_ds_return_config

# Set `versioned` key if version present in the dataset
if return_config.pop(VERSION_KEY, None):
return_config[VERSIONED_FLAG_KEY] = True

# Pop metadata from configuration
# Pop metadata and data from configuration
return_config.pop("metadata", None)
return_config.pop("data", None)

return return_config

Expand Down
Loading