From 96c75601dbc26a60e3ba5fe10d251a3d662ae8fc Mon Sep 17 00:00:00 2001
From: Shashank Srikanth
Date: Tue, 14 Jan 2025 12:29:53 -0800
Subject: [PATCH] Address comments

---
 metaflow/client/core.py                      | 47 +++++++++-----------
 metaflow/plugins/metadata_providers/local.py |  4 +-
 metaflow/task.py                             |  8 ++--
 3 files changed, 27 insertions(+), 32 deletions(-)

diff --git a/metaflow/client/core.py b/metaflow/client/core.py
index fdbc4ad5ed6..aba68a8f0df 100644
--- a/metaflow/client/core.py
+++ b/metaflow/client/core.py
@@ -1119,6 +1119,8 @@ class Task(MetaflowObject):
 
     def __init__(self, *args, **kwargs):
         super(Task, self).__init__(*args, **kwargs)
+        # We want to cache the metadata dictionary since it's used in many places
+        self._metadata_dict = None
 
     def _iter_filter(self, x):
         # exclude private data artifacts
@@ -1129,12 +1131,7 @@ def _get_task_for_queried_step(self, flow_id, run_id, query_step):
         Returns a Task object corresponding to the queried step.
         If the queried step has several tasks, the first task is returned.
         """
-        # Find any task corresponding to the queried step
-        step = Step(f"{flow_id}/{run_id}/{query_step}", _namespace_check=False)
-        task = next(iter(step.tasks()), None)
-        if task:
-            return task
-        raise MetaflowNotFound(f"No task found for the queried step {query_step}")
+        return Step(f"{flow_id}/{run_id}/{query_step}", _namespace_check=False).task
 
     def _get_metadata_query_vals(
         self,
@@ -1142,7 +1139,7 @@ def _get_metadata_query_vals(
         run_id: str,
         cur_foreach_stack_len: int,
         steps: List[str],
-        query_type: str,
+        is_ancestor: bool,
     ):
         """
         Returns the field name and field value to be used for querying metadata of successor or ancestor tasks.
@@ -1158,15 +1155,15 @@ def _get_metadata_query_vals(
         steps : List[str]
             List of step names whose tasks will be returned. For static joins, and static splits, we can have
             ancestors and successors across multiple steps.
-        query_type : str
-            Type of query. Can be 'ancestor' or 'successor'.
+        is_ancestor : bool
+            Set to True to query for ancestor tasks; set to False to query for successor tasks.
         """
         # For each task, we also log additional metadata fields such as foreach-indices and foreach-indices-truncated
         # which help us in querying ancestor and successor tasks.
         # `foreach-indices`: contains the indices of the foreach stack at the time of task execution.
         # `foreach-indices-truncated`: contains the indices of the foreach stack at the time of task execution but
         # truncated by 1
-        # For example, a task thats nested 3 levels deep in a foreach stack may have the following values:
+        # For example, a task that's nested 3 levels deep in a foreach stack may have the following values:
         # foreach-indices = [0, 1, 2]
         # foreach-indices-truncated = [0, 1]
 
@@ -1185,7 +1182,7 @@ def _get_metadata_query_vals(
             # The successor or ancestor tasks belong to the same foreach stack level
             field_name = "foreach-indices"
             field_value = self.metadata_dict.get(field_name)
-        elif query_type == "ancestor":
+        elif is_ancestor:
             if query_foreach_stack_len > cur_foreach_stack_len:
                 # This is a foreach join
                 # Current Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
@@ -1199,7 +1196,7 @@ def _get_metadata_query_vals(
                 # Current Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1]
                 # Ancestor Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
                 # We will compare the foreach-indices value of ancestor task with the
-                # foreach-indices value of current task
+                # foreach-indices-truncated value of current task
                 field_name = "foreach-indices"
                 field_value = self.metadata_dict.get("foreach-indices-truncated")
             else:
@@ -1221,13 +1218,12 @@ def _get_metadata_query_vals(
             field_value = self.metadata_dict.get("foreach-indices-truncated")
         return field_name, field_value
 
-    def _get_related_tasks(self, relation_type: str) -> Dict[str, List[str]]:
-        start_time = time.time()
+    def _get_related_tasks(self, is_ancestor: bool) -> Dict[str, List[str]]:
         flow_id, run_id, _, _ = self.path_components
         steps = (
-            self.metadata_dict.get("previous_steps")
-            if relation_type == "ancestor"
-            else self.metadata_dict.get("successor_steps")
+            self.metadata_dict.get("previous-steps")
+            if is_ancestor
+            else self.metadata_dict.get("successor-steps")
         )
 
         if not steps:
@@ -1238,11 +1234,9 @@ def _get_related_tasks(self, relation_type: str) -> Dict[str, List[str]]:
             run_id,
             len(self.metadata_dict.get("foreach-indices", [])),
             steps,
-            relation_type,
+            is_ancestor=is_ancestor,
         )
 
-        cur_time = time.time()
-
         return {
             step: [
                 f"{flow_id}/{run_id}/{step}/{task_id}"
@@ -1266,7 +1260,7 @@ def immediate_ancestors(self) -> Dict[str, List[str]]:
             names of the ancestor steps and the values are the corresponding task pathspecs of the
             ancestors.
         """
-        return self._get_related_tasks("ancestor")
+        return self._get_related_tasks(is_ancestor=True)
 
     @property
     def immediate_successors(self) -> Dict[str, List[str]]:
@@ -1281,7 +1275,7 @@ def immediate_successors(self) -> Dict[str, List[str]]:
             names of the successor steps and the values are the corresponding task pathspecs of the
             successors.
""" - return self._get_related_tasks("successor") + return self._get_related_tasks(is_ancestor=False) @property def immediate_siblings(self) -> Dict[str, List[str]]: @@ -1408,9 +1402,12 @@ def metadata_dict(self) -> Dict[str, str]: Dictionary mapping metadata name with value """ # use the newest version of each key, hence sorting - return { - m.name: m.value for m in sorted(self.metadata, key=lambda m: m.created_at) - } + if self._metadata_dict is None: + self._metadata_dict = { + m.name: m.value + for m in sorted(self.metadata, key=lambda m: m.created_at) + } + return self._metadata_dict @property def index(self) -> Optional[int]: diff --git a/metaflow/plugins/metadata_providers/local.py b/metaflow/plugins/metadata_providers/local.py index 88d0810ef5f..995e589daf6 100644 --- a/metaflow/plugins/metadata_providers/local.py +++ b/metaflow/plugins/metadata_providers/local.py @@ -247,7 +247,7 @@ def _get_latest_metadata_file(path: str, field_prefix: str) -> tuple: # and the artifact files are saved as: _artifact__.json # We loop over all the JSON files in the directory and find the latest one # that matches the field prefix. - json_files = glob.glob(os.path.join(path, "*.json")) + json_files = glob.glob(os.path.join(path, f"{field_prefix}*.json")) matching_files = [] for file_path in json_files: @@ -287,8 +287,6 @@ def _read_metadata_value(file_path: str) -> dict: # Filter tasks based on metadata for task in tasks: task_id = task.get("task_id") - if not task_id: - continue meta_path = LocalMetadataProvider._get_metadir( flow_id, run_id, query_step, task_id diff --git a/metaflow/task.py b/metaflow/task.py index 6785c2a424a..3e9b553d03b 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -781,15 +781,15 @@ def run_step( tags=metadata_tags, ), MetaDatum( - field="previous_steps", + field="previous-steps", value=previous_steps, - type="previous_steps", + type="previous-steps", tags=metadata_tags, ), MetaDatum( - field="successor_steps", + field="successor-steps", value=successor_steps, - type="successor_steps", + type="successor-steps", tags=metadata_tags, ), ],