Skip to content

Commit 96c7560

Browse files
committed
Address comments
1 parent 81c9e4e commit 96c7560

File tree

3 files changed

+27
-32
lines changed

3 files changed

+27
-32
lines changed

metaflow/client/core.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,8 @@ class Task(MetaflowObject):
11191119

11201120
def __init__(self, *args, **kwargs):
11211121
super(Task, self).__init__(*args, **kwargs)
1122+
# We want to cache metadata dictionary since it's used in many places
1123+
self._metadata_dict = None
11221124

11231125
def _iter_filter(self, x):
11241126
# exclude private data artifacts
@@ -1129,20 +1131,15 @@ def _get_task_for_queried_step(self, flow_id, run_id, query_step):
11291131
Returns a Task object corresponding to the queried step.
11301132
If the queried step has several tasks, the first task is returned.
11311133
"""
1132-
# Find any task corresponding to the queried step
1133-
step = Step(f"{flow_id}/{run_id}/{query_step}", _namespace_check=False)
1134-
task = next(iter(step.tasks()), None)
1135-
if task:
1136-
return task
1137-
raise MetaflowNotFound(f"No task found for the queried step {query_step}")
1134+
return Step(f"{flow_id}/{run_id}/{query_step}", _namespace_check=False).task
11381135

11391136
def _get_metadata_query_vals(
11401137
self,
11411138
flow_id: str,
11421139
run_id: str,
11431140
cur_foreach_stack_len: int,
11441141
steps: List[str],
1145-
query_type: str,
1142+
is_ancestor: bool,
11461143
):
11471144
"""
11481145
Returns the field name and field value to be used for querying metadata of successor or ancestor tasks.
@@ -1158,15 +1155,15 @@ def _get_metadata_query_vals(
11581155
steps : List[str]
11591156
List of step names whose tasks will be returned. For static joins and static splits, we can have
11601157
ancestors and successors across multiple steps.
1161-
query_type : str
1162-
Type of query. Can be 'ancestor' or 'successor'.
1158+
is_ancestor : bool
1159+
True when querying for ancestor tasks; False when querying for successor tasks.
11631160
"""
11641161
# For each task, we also log additional metadata fields such as foreach-indices and foreach-indices-truncated
11651162
# which help us in querying ancestor and successor tasks.
11661163
# `foreach-indices`: contains the indices of the foreach stack at the time of task execution.
11671164
# `foreach-indices-truncated`: contains the indices of the foreach stack at the time of task execution but
11681165
# truncated by 1 (i.e., with the last index dropped)
1169-
# For example, a task thats nested 3 levels deep in a foreach stack may have the following values:
1166+
# For example, a task that's nested 3 levels deep in a foreach stack may have the following values:
11701167
# foreach-indices = [0, 1, 2]
11711168
# foreach-indices-truncated = [0, 1]
11721169

@@ -1185,7 +1182,7 @@ def _get_metadata_query_vals(
11851182
# The successor or ancestor tasks belong to the same foreach stack level
11861183
field_name = "foreach-indices"
11871184
field_value = self.metadata_dict.get(field_name)
1188-
elif query_type == "ancestor":
1185+
elif is_ancestor:
11891186
if query_foreach_stack_len > cur_foreach_stack_len:
11901187
# This is a foreach join
11911188
# Current Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
@@ -1199,7 +1196,7 @@ def _get_metadata_query_vals(
11991196
# Current Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1]
12001197
# Ancestor Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
12011198
# We will compare the foreach-indices value of ancestor task with the
1202-
# foreach-indices value of current task
1199+
# foreach-indices-truncated value of current task
12031200
field_name = "foreach-indices"
12041201
field_value = self.metadata_dict.get("foreach-indices-truncated")
12051202
else:
@@ -1221,13 +1218,12 @@ def _get_metadata_query_vals(
12211218
field_value = self.metadata_dict.get("foreach-indices-truncated")
12221219
return field_name, field_value
12231220

1224-
def _get_related_tasks(self, relation_type: str) -> Dict[str, List[str]]:
1225-
start_time = time.time()
1221+
def _get_related_tasks(self, is_ancestor: bool) -> Dict[str, List[str]]:
12261222
flow_id, run_id, _, _ = self.path_components
12271223
steps = (
1228-
self.metadata_dict.get("previous_steps")
1229-
if relation_type == "ancestor"
1230-
else self.metadata_dict.get("successor_steps")
1224+
self.metadata_dict.get("previous-steps")
1225+
if is_ancestor
1226+
else self.metadata_dict.get("successor-steps")
12311227
)
12321228

12331229
if not steps:
@@ -1238,11 +1234,9 @@ def _get_related_tasks(self, relation_type: str) -> Dict[str, List[str]]:
12381234
run_id,
12391235
len(self.metadata_dict.get("foreach-indices", [])),
12401236
steps,
1241-
relation_type,
1237+
is_ancestor=is_ancestor,
12421238
)
12431239

1244-
cur_time = time.time()
1245-
12461240
return {
12471241
step: [
12481242
f"{flow_id}/{run_id}/{step}/{task_id}"
@@ -1266,7 +1260,7 @@ def immediate_ancestors(self) -> Dict[str, List[str]]:
12661260
names of the ancestor steps and the values are the corresponding
12671261
task pathspecs of the ancestors.
12681262
"""
1269-
return self._get_related_tasks("ancestor")
1263+
return self._get_related_tasks(is_ancestor=True)
12701264

12711265
@property
12721266
def immediate_successors(self) -> Dict[str, List[str]]:
@@ -1281,7 +1275,7 @@ def immediate_successors(self) -> Dict[str, List[str]]:
12811275
names of the successor steps and the values are the corresponding
12821276
task pathspecs of the successors.
12831277
"""
1284-
return self._get_related_tasks("successor")
1278+
return self._get_related_tasks(is_ancestor=False)
12851279

12861280
@property
12871281
def immediate_siblings(self) -> Dict[str, List[str]]:
@@ -1408,9 +1402,12 @@ def metadata_dict(self) -> Dict[str, str]:
14081402
Dictionary mapping each metadata name to its value
14091403
"""
14101404
# Sort by creation time so that the newest value of each key wins
1411-
return {
1412-
m.name: m.value for m in sorted(self.metadata, key=lambda m: m.created_at)
1413-
}
1405+
if self._metadata_dict is None:
1406+
self._metadata_dict = {
1407+
m.name: m.value
1408+
for m in sorted(self.metadata, key=lambda m: m.created_at)
1409+
}
1410+
return self._metadata_dict
14141411

14151412
@property
14161413
def index(self) -> Optional[int]:

metaflow/plugins/metadata_providers/local.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def _get_latest_metadata_file(path: str, field_prefix: str) -> tuple:
247247
# and the artifact files are saved as: <attempt>_artifact__<artifact_name>.json
248248
# We loop over all the JSON files in the directory and find the latest one
249249
# that matches the field prefix.
250-
json_files = glob.glob(os.path.join(path, "*.json"))
250+
json_files = glob.glob(os.path.join(path, f"{field_prefix}*.json"))
251251
matching_files = []
252252

253253
for file_path in json_files:
@@ -287,8 +287,6 @@ def _read_metadata_value(file_path: str) -> dict:
287287
# Filter tasks based on metadata
288288
for task in tasks:
289289
task_id = task.get("task_id")
290-
if not task_id:
291-
continue
292290

293291
meta_path = LocalMetadataProvider._get_metadir(
294292
flow_id, run_id, query_step, task_id

metaflow/task.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -781,15 +781,15 @@ def run_step(
781781
tags=metadata_tags,
782782
),
783783
MetaDatum(
784-
field="previous_steps",
784+
field="previous-steps",
785785
value=previous_steps,
786-
type="previous_steps",
786+
type="previous-steps",
787787
tags=metadata_tags,
788788
),
789789
MetaDatum(
790-
field="successor_steps",
790+
field="successor-steps",
791791
value=successor_steps,
792-
type="successor_steps",
792+
type="successor-steps",
793793
tags=metadata_tags,
794794
),
795795
],

0 commit comments

Comments
 (0)