Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Misc bugfixes #3234

Merged
merged 9 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 46 additions & 79 deletions src/zenml/cli/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,33 @@
logger = get_logger(__name__)


def _import_pipeline(source: str) -> Pipeline:
    """Import a pipeline from its source path.

    Args:
        source: The pipeline source, passed unchanged to
            `source_utils.load`.

    Returns:
        The loaded pipeline instance.
    """
    # NOTE(review): `cli_utils.error` is used here as if it never returns
    # (presumably it aborts the CLI command) -- confirm against cli_utils.
    try:
        pipeline_instance = source_utils.load(source)
    except ModuleNotFoundError as e:
        source_root = source_utils.get_source_root()
        cli_utils.error(
            f"Unable to import module `{e.name}`. Make sure the source path is "
            f"relative to your source root `{source_root}`."
        )
    except AttributeError as e:
        cli_utils.error("Unable to load attribute from module: " + str(e))

    if not isinstance(pipeline_instance, Pipeline):
        cli_utils.error(
            f"The given source path `{source}` does not resolve to a pipeline "
            "object."
        )

    # Callers assign the result and immediately call `.with_options(...)` on
    # it, so the validated instance must be handed back explicitly; without
    # this the function would implicitly return `None`.
    return pipeline_instance


@cli.group(cls=TagGroup, tag=CliCategories.MANAGEMENT_TOOLS)
def pipeline() -> None:
"""Interact with pipelines, runs and schedules."""
Expand Down Expand Up @@ -85,22 +112,7 @@ def register_pipeline(
"source code root."
)

try:
pipeline_instance = source_utils.load(source)
except ModuleNotFoundError as e:
source_root = source_utils.get_source_root()
cli_utils.error(
f"Unable to import module `{e.name}`. Make sure the source path is "
f"relative to your source root `{source_root}`."
)
except AttributeError as e:
cli_utils.error("Unable to load attribute from module: " + str(e))

if not isinstance(pipeline_instance, Pipeline):
cli_utils.error(
f"The given source path `{source}` does not resolve to a pipeline "
"object."
)
pipeline_instance = _import_pipeline(source=source)

parameters: Dict[str, Any] = {}
if parameters_path:
Expand Down Expand Up @@ -176,24 +188,9 @@ def build_pipeline(
"your source code root."
)

try:
pipeline_instance = source_utils.load(source)
except ModuleNotFoundError as e:
source_root = source_utils.get_source_root()
cli_utils.error(
f"Unable to import module `{e.name}`. Make sure the source path is "
f"relative to your source root `{source_root}`."
)
except AttributeError as e:
cli_utils.error("Unable to load attribute from module: " + str(e))

if not isinstance(pipeline_instance, Pipeline):
cli_utils.error(
f"The given source path `{source}` does not resolve to a pipeline "
"object."
)

with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id):
pipeline_instance = _import_pipeline(source=source)

pipeline_instance = pipeline_instance.with_options(
config_path=config_path
)
Expand Down Expand Up @@ -277,36 +274,21 @@ def run_pipeline(
"your source code root."
)

try:
pipeline_instance = source_utils.load(source)
except ModuleNotFoundError as e:
source_root = source_utils.get_source_root()
cli_utils.error(
f"Unable to import module `{e.name}`. Make sure the source path is "
f"relative to your source root `{source_root}`."
)
except AttributeError as e:
cli_utils.error("Unable to load attribute from module: " + str(e))

if not isinstance(pipeline_instance, Pipeline):
cli_utils.error(
f"The given source path `{source}` does not resolve to a pipeline "
"object."
)

build: Union[str, PipelineBuildBase, None] = None
if build_path_or_id:
if uuid_utils.is_valid_uuid(build_path_or_id):
build = build_path_or_id
elif os.path.exists(build_path_or_id):
build = PipelineBuildBase.from_yaml(build_path_or_id)
else:
cli_utils.error(
f"The specified build {build_path_or_id} is not a valid UUID "
"or file path."
)

with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id):
pipeline_instance = _import_pipeline(source=source)

build: Union[str, PipelineBuildBase, None] = None
if build_path_or_id:
if uuid_utils.is_valid_uuid(build_path_or_id):
build = build_path_or_id
elif os.path.exists(build_path_or_id):
build = PipelineBuildBase.from_yaml(build_path_or_id)
else:
cli_utils.error(
f"The specified build {build_path_or_id} is not a valid UUID "
"or file path."
)

pipeline_instance = pipeline_instance.with_options(
config_path=config_path,
build=build,
Expand Down Expand Up @@ -369,24 +351,9 @@ def create_run_template(
"init` at your source code root."
)

try:
pipeline_instance = source_utils.load(source)
except ModuleNotFoundError as e:
source_root = source_utils.get_source_root()
cli_utils.error(
f"Unable to import module `{e.name}`. Make sure the source path is "
f"relative to your source root `{source_root}`."
)
except AttributeError as e:
cli_utils.error("Unable to load attribute from module: " + str(e))

if not isinstance(pipeline_instance, Pipeline):
cli_utils.error(
f"The given source path `{source}` does not resolve to a pipeline "
"object."
)

with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id):
pipeline_instance = _import_pipeline(source=source)

pipeline_instance = pipeline_instance.with_options(
config_path=config_path
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,12 @@ def prepare_step_run(self, info: "StepRunInfo") -> None:
NeptuneExperimentTrackerSettings, self.get_settings(info)
)

self.run_state.token = self.config.api_token
self.run_state.project = self.config.project
self.run_state.run_name = info.run_name
self.run_state.tags = list(settings.tags)
self.run_state.initialize(
project=self.config.project,
token=self.config.api_token,
run_name=info.run_name,
tags=list(settings.tags),
)

def get_step_run_metadata(
self, info: "StepRunInfo"
Expand All @@ -107,4 +109,4 @@ def cleanup_step_run(self, info: "StepRunInfo", step_failed: bool) -> None:
"""
self.run_state.active_run.sync()
self.run_state.active_run.stop()
self.run_state.reset_active_run()
self.run_state.reset()
122 changes: 69 additions & 53 deletions src/zenml/integrations/neptune/experiment_trackers/run_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import zenml
from zenml.client import Client
from zenml.integrations.constants import NEPTUNE
from zenml.utils.singleton import SingletonMetaClass

if TYPE_CHECKING:
Expand All @@ -29,20 +28,38 @@
_INTEGRATION_VERSION_KEY = "source_code/integrations/zenml"


class InvalidExperimentTrackerSelected(Exception):
"""Raised if a Neptune run is fetched while using a different experiment tracker."""


class RunProvider(metaclass=SingletonMetaClass):
"""Singleton object used to store and persist a Neptune run state across the pipeline."""

def __init__(self) -> None:
"""Initialize RunProvider. Called with no arguments."""
self._active_run: Optional["Run"] = None
self._project: Optional[str]
self._run_name: Optional[str]
self._token: Optional[str]
self._tags: Optional[List[str]]
self._project: Optional[str] = None
self._run_name: Optional[str] = None
self._token: Optional[str] = None
self._tags: Optional[List[str]] = None
self._initialized = False

def initialize(
    self,
    project: Optional[str] = None,
    token: Optional[str] = None,
    run_name: Optional[str] = None,
    tags: Optional[List[str]] = None,
) -> None:
    """Store the Neptune run settings and mark the state as initialized.

    Args:
        project: The neptune project.
        token: The neptune token.
        run_name: The neptune run name.
        tags: Tags for the neptune run.
    """
    self._project, self._token = project, token
    self._run_name, self._tags = run_name, tags
    # Flag checked through the `initialized` property.
    self._initialized = True

@property
def project(self) -> Optional[Any]:
Expand All @@ -53,15 +70,6 @@ def project(self) -> Optional[Any]:
"""
return self._project

@project.setter
def project(self, project: str) -> None:
"""Setter for project name.

Args:
project: Neptune project name
"""
self._project = project

@property
def token(self) -> Optional[Any]:
"""Getter for API token.
Expand All @@ -71,15 +79,6 @@ def token(self) -> Optional[Any]:
"""
return self._token

@token.setter
def token(self, token: str) -> None:
"""Setter for API token.

Args:
token: Neptune API token
"""
self._token = token

@property
def run_name(self) -> Optional[Any]:
"""Getter for run name.
Expand All @@ -89,15 +88,6 @@ def run_name(self) -> Optional[Any]:
"""
return self._run_name

@run_name.setter
def run_name(self, run_name: str) -> None:
"""Setter for run name.

Args:
run_name: name of the pipeline run
"""
self._run_name = run_name

@property
def tags(self) -> Optional[Any]:
"""Getter for run tags.
Expand All @@ -107,14 +97,14 @@ def tags(self) -> Optional[Any]:
"""
return self._tags

@tags.setter
def tags(self, tags: List[str]) -> None:
"""Setter for run tags.
@property
def initialized(self) -> bool:
"""If the run state is initialized.

Args:
tags: list of tags associated with a Neptune run
Returns:
If the run state is initialized.
"""
self._tags = tags
return self._initialized

@property
def active_run(self) -> "Run":
Expand All @@ -137,9 +127,14 @@ def active_run(self) -> "Run":
self._active_run = run
return self._active_run

def reset_active_run(self) -> None:
"""Resets the active run state to None."""
def reset(self) -> None:
    """Clear the active run and all stored settings.

    After this call the state is marked as not initialized.
    """
    for field_name in (
        "_active_run",
        "_project",
        "_run_name",
        "_token",
        "_tags",
    ):
        setattr(self, field_name, None)
    self._initialized = False


def get_neptune_run() -> "Run":
Expand All @@ -149,14 +144,35 @@ def get_neptune_run() -> "Run":
Neptune run object

Raises:
InvalidExperimentTrackerSelected: when called while using an experiment tracker other than Neptune
RuntimeError: When unable to fetch the active neptune run.
"""
client = Client()
experiment_tracker = client.active_stack.experiment_tracker
if experiment_tracker.flavor == NEPTUNE: # type: ignore
return experiment_tracker.run_state.active_run # type: ignore
raise InvalidExperimentTrackerSelected(
"Fetching a Neptune run works only with the 'neptune' flavor of "
"the experiment tracker. The flavor currently selected is %s"
% experiment_tracker.flavor # type: ignore
from zenml.integrations.neptune.experiment_trackers import (
NeptuneExperimentTracker,
)

experiment_tracker = Client().active_stack.experiment_tracker

if not experiment_tracker:
raise RuntimeError(
"Unable to get neptune run: Missing experiment tracker in the "
"active stack."
)

if not isinstance(experiment_tracker, NeptuneExperimentTracker):
raise RuntimeError(
"Unable to get neptune run: Experiment tracker in the active "
f"stack ({experiment_tracker.flavor}) is not a neptune experiment "
"tracker."
)

run_state = experiment_tracker.run_state
if not run_state.initialized:
raise RuntimeError(
"Unable to get neptune run: The experiment tracker has not been "
"initialized. To solve this, make sure you use the experiment "
"tracker in your step. See "
"https://docs.zenml.io/stack-components/experiment-trackers/neptune#how-do-you-use-it "
"for more information."
)

return experiment_tracker.run_state.active_run
4 changes: 2 additions & 2 deletions src/zenml/integrations/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def select_integration_requirements(
)
else:
raise KeyError(
f"Version {integration_name} does not exist. "
f"Integration {integration_name} does not exist. "
f"Currently the following integrations are implemented. "
f"{self.list_integration_names}"
)
Expand Down Expand Up @@ -148,7 +148,7 @@ def select_uninstall_requirements(
].get_uninstall_requirements(target_os=target_os)
else:
raise KeyError(
f"Version {integration_name} does not exist. "
f"Integration {integration_name} does not exist. "
f"Currently the following integrations are implemented. "
f"{self.list_integration_names}"
)
Expand Down
Loading