
Commit

Address comments
talsperre committed Jan 15, 2025
1 parent 81c9e4e commit 96c7560
Showing 3 changed files with 27 additions and 32 deletions.
47 changes: 22 additions & 25 deletions metaflow/client/core.py
@@ -1119,6 +1119,8 @@ class Task(MetaflowObject):

def __init__(self, *args, **kwargs):
super(Task, self).__init__(*args, **kwargs)
+ # We want to cache metadata dictionary since it's used in many places
+ self._metadata_dict = None

def _iter_filter(self, x):
# exclude private data artifacts
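
The new _metadata_dict attribute backs the lazily cached metadata_dict property further down in this diff. A minimal sketch of the pattern, with a hypothetical _load_metadata standing in for the real lookup:

    # Sketch of the lazy-caching pattern; _load_metadata is a hypothetical
    # stand-in for the comparatively expensive metadata service call.
    class Example:
        def __init__(self):
            self._metadata_dict = None  # filled in on first access

        @property
        def metadata_dict(self):
            if self._metadata_dict is None:
                self._metadata_dict = self._load_metadata()
            return self._metadata_dict

        def _load_metadata(self):
            return {"foreach-indices": "[0, 1, 2]"}  # invented payload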
@@ -1129,20 +1131,15 @@ def _get_task_for_queried_step(self, flow_id, run_id, query_step):
Returns a Task object corresponding to the queried step.
If the queried step has several tasks, the first task is returned.
"""
- # Find any task corresponding to the queried step
- step = Step(f"{flow_id}/{run_id}/{query_step}", _namespace_check=False)
- task = next(iter(step.tasks()), None)
- if task:
-     return task
- raise MetaflowNotFound(f"No task found for the queried step {query_step}")
+ return Step(f"{flow_id}/{run_id}/{query_step}", _namespace_check=False).task

def _get_metadata_query_vals(
self,
flow_id: str,
run_id: str,
cur_foreach_stack_len: int,
steps: List[str],
- query_type: str,
+ is_ancestor: bool,
):
"""
Returns the field name and field value to be used for querying metadata of successor or ancestor tasks.
@@ -1158,15 +1155,15 @@ def _get_metadata_query_vals(
steps : List[str]
List of step names whose tasks will be returned. For static joins, and static splits, we can have
ancestors and successors across multiple steps.
- query_type : str
-     Type of query. Can be 'ancestor' or 'successor'.
+ is_ancestor : bool
+     If we are querying for ancestor tasks, set this to True.
"""
# For each task, we also log additional metadata fields such as foreach-indices and foreach-indices-truncated
# which help us in querying ancestor and successor tasks.
# `foreach-indices`: contains the indices of the foreach stack at the time of task execution.
# `foreach-indices-truncated`: contains the indices of the foreach stack at the time of task execution but
# truncated by 1
- # For example, a task thats nested 3 levels deep in a foreach stack may have the following values:
+ # For example, a task that's nested 3 levels deep in a foreach stack may have the following values:
# foreach-indices = [0, 1, 2]
# foreach-indices-truncated = [0, 1]

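A toy example of how the two fields described above relate to each other (the field names come from this diff; the snippet itself is only illustrative):

    # Illustration only: the truncated list is the foreach stack minus its last entry.
    foreach_indices = [0, 1, 2]                        # three levels of foreach nesting
    foreach_indices_truncated = foreach_indices[:-1]   # truncated by 1
    assert foreach_indices_truncated == [0, 1]
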
@@ -1185,7 +1182,7 @@ def _get_metadata_query_vals(
# The successor or ancestor tasks belong to the same foreach stack level
field_name = "foreach-indices"
field_value = self.metadata_dict.get(field_name)
- elif query_type == "ancestor":
+ elif is_ancestor:
if query_foreach_stack_len > cur_foreach_stack_len:
# This is a foreach join
# Current Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
@@ -1199,7 +1196,7 @@ def _get_metadata_query_vals(
# Current Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1]
# Ancestor Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
# We will compare the foreach-indices value of ancestor task with the
- # foreach-indices value of current task
+ # foreach-indices-truncated value of current task
field_name = "foreach-indices"
field_value = self.metadata_dict.get("foreach-indices-truncated")
else:
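
In other words, the ancestor lookup reduces to matching one foreach-index list against another. A hedged sketch of the rule described in the comments above, using plain lists rather than the Metaflow internals:

    # Illustration of the matching rule for an ancestor one foreach level up.
    current = {"foreach-indices": [0, 1, 2], "foreach-indices-truncated": [0, 1]}
    ancestor = {"foreach-indices": [0, 1], "foreach-indices-truncated": [0]}
    # The ancestor's own foreach-indices must equal the current task's
    # truncated indices for the two tasks to be related.
    assert ancestor["foreach-indices"] == current["foreach-indices-truncated"]
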
@@ -1221,13 +1218,12 @@ def _get_related_tasks(self, relation_type: str) -> Dict[str, List[str]]:
field_value = self.metadata_dict.get("foreach-indices-truncated")
return field_name, field_value

- def _get_related_tasks(self, relation_type: str) -> Dict[str, List[str]]:
-     start_time = time.time()
+ def _get_related_tasks(self, is_ancestor: bool) -> Dict[str, List[str]]:
flow_id, run_id, _, _ = self.path_components
steps = (
self.metadata_dict.get("previous_steps")
if relation_type == "ancestor"
else self.metadata_dict.get("successor_steps")
self.metadata_dict.get("previous-steps")
if is_ancestor
else self.metadata_dict.get("successor-steps")
)

if not steps:
@@ -1238,11 +1234,9 @@ def _get_related_tasks(self, relation_type: str) -> Dict[str, List[str]]:
run_id,
len(self.metadata_dict.get("foreach-indices", [])),
steps,
- relation_type,
+ is_ancestor=is_ancestor,
)

- cur_time = time.time()

return {
step: [
f"{flow_id}/{run_id}/{step}/{task_id}"
@@ -1266,7 +1260,7 @@ def immediate_ancestors(self) -> Dict[str, List[str]]:
names of the ancestor steps and the values are the corresponding
task pathspecs of the ancestors.
"""
return self._get_related_tasks("ancestor")
return self._get_related_tasks(is_ancestor=True)

@property
def immediate_successors(self) -> Dict[str, List[str]]:
@@ -1281,7 +1275,7 @@ def immediate_successors(self) -> Dict[str, List[str]]:
names of the successor steps and the values are the corresponding
task pathspecs of the successors.
"""
return self._get_related_tasks("successor")
return self._get_related_tasks(is_ancestor=False)

@property
def immediate_siblings(self) -> Dict[str, List[str]]:
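
With these properties in place, related tasks can be looked up straight from the client API. A usage sketch; the flow name, run id, and step names are invented for illustration:

    # Hypothetical usage of the two properties defined above.
    from metaflow import Task

    task = Task("MyFlow/1234/join_step/890")
    parents = task.immediate_ancestors    # e.g. {"process_item": ["MyFlow/1234/process_item/567", ...]}
    children = task.immediate_successors  # e.g. {"end": ["MyFlow/1234/end/901", ...]}
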
@@ -1408,9 +1402,12 @@ def metadata_dict(self) -> Dict[str, str]:
Dictionary mapping metadata name with value
"""
# use the newest version of each key, hence sorting
- return {
-     m.name: m.value for m in sorted(self.metadata, key=lambda m: m.created_at)
- }
+ if self._metadata_dict is None:
+     self._metadata_dict = {
+         m.name: m.value
+         for m in sorted(self.metadata, key=lambda m: m.created_at)
+     }
+ return self._metadata_dict

@property
def index(self) -> Optional[int]:
4 changes: 1 addition & 3 deletions metaflow/plugins/metadata_providers/local.py
@@ -247,7 +247,7 @@ def _get_latest_metadata_file(path: str, field_prefix: str) -> tuple:
# and the artifact files are saved as: <attempt>_artifact__<artifact_name>.json
# We loop over all the JSON files in the directory and find the latest one
# that matches the field prefix.
json_files = glob.glob(os.path.join(path, "*.json"))
json_files = glob.glob(os.path.join(path, f"{field_prefix}*.json"))
matching_files = []

for file_path in json_files:
@@ -287,8 +287,6 @@ def _read_metadata_value(file_path: str) -> dict:
# Filter tasks based on metadata
for task in tasks:
task_id = task.get("task_id")
- if not task_id:
-     continue

meta_path = LocalMetadataProvider._get_metadir(
flow_id, run_id, query_step, task_id
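
Constraining the glob by field_prefix means only the JSON files for the requested field are listed, instead of every metadata and artifact file in the task directory. A small sketch of the difference; the directory path and prefix are invented for illustration:

    # Illustration only: a prefix-constrained glob versus matching every JSON file.
    import glob
    import os

    path = "/tmp/metaflow/example_task_metadir"  # hypothetical metadata directory
    field_prefix = "0_foreach-indices"           # hypothetical prefix passed by the caller
    every_json = glob.glob(os.path.join(path, "*.json"))
    matching = glob.glob(os.path.join(path, f"{field_prefix}*.json"))
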
8 changes: 4 additions & 4 deletions metaflow/task.py
@@ -781,15 +781,15 @@ def run_step(
tags=metadata_tags,
),
MetaDatum(
field="previous_steps",
field="previous-steps",
value=previous_steps,
type="previous_steps",
type="previous-steps",
tags=metadata_tags,
),
MetaDatum(
field="successor_steps",
field="successor-steps",
value=successor_steps,
type="successor_steps",
type="successor-steps",
tags=metadata_tags,
),
],
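
The renamed fields use hyphens so they line up with the other foreach-related metadata keys (foreach-indices, foreach-indices-truncated) that core.py reads back through Task.metadata_dict. A hedged sketch of the read side; the pathspec is invented for illustration:

    # Hypothetical read-back of the metadata registered above.
    from metaflow import Task

    task = Task("MyFlow/1234/join_step/890")
    prev_steps = task.metadata_dict.get("previous-steps")
    next_steps = task.metadata_dict.get("successor-steps")
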
