@@ -1119,6 +1119,8 @@ class Task(MetaflowObject):
1119
1119
1120
1120
def __init__ (self , * args , ** kwargs ):
1121
1121
super (Task , self ).__init__ (* args , ** kwargs )
1122
+ # We want to cache metadata dictionary since it's used in many places
1123
+ self ._metadata_dict = None
1122
1124
1123
1125
def _iter_filter (self , x ):
1124
1126
# exclude private data artifacts
@@ -1129,20 +1131,15 @@ def _get_task_for_queried_step(self, flow_id, run_id, query_step):
1129
1131
Returns a Task object corresponding to the queried step.
1130
1132
If the queried step has several tasks, the first task is returned.
1131
1133
"""
1132
- # Find any task corresponding to the queried step
1133
- step = Step (f"{ flow_id } /{ run_id } /{ query_step } " , _namespace_check = False )
1134
- task = next (iter (step .tasks ()), None )
1135
- if task :
1136
- return task
1137
- raise MetaflowNotFound (f"No task found for the queried step { query_step } " )
1134
+ return Step (f"{ flow_id } /{ run_id } /{ query_step } " , _namespace_check = False ).task
1138
1135
1139
1136
def _get_metadata_query_vals (
1140
1137
self ,
1141
1138
flow_id : str ,
1142
1139
run_id : str ,
1143
1140
cur_foreach_stack_len : int ,
1144
1141
steps : List [str ],
1145
- query_type : str ,
1142
+ is_ancestor : bool ,
1146
1143
):
1147
1144
"""
1148
1145
Returns the field name and field value to be used for querying metadata of successor or ancestor tasks.
@@ -1158,15 +1155,15 @@ def _get_metadata_query_vals(
1158
1155
steps : List[str]
1159
1156
List of step names whose tasks will be returned. For static joins, and static splits, we can have
1160
1157
ancestors and successors across multiple steps.
1161
- query_type : str
1162
- Type of query. Can be ' ancestor' or 'successor' .
1158
+ is_ancestor : bool
1159
+ If we are querying for ancestor tasks, set this to True .
1163
1160
"""
1164
1161
# For each task, we also log additional metadata fields such as foreach-indices and foreach-indices-truncated
1165
1162
# which help us in querying ancestor and successor tasks.
1166
1163
# `foreach-indices`: contains the indices of the foreach stack at the time of task execution.
1167
1164
# `foreach-indices-truncated`: contains the indices of the foreach stack at the time of task execution but
1168
1165
# truncated by 1
1169
- # For example, a task thats nested 3 levels deep in a foreach stack may have the following values:
1166
+ # For example, a task that's nested 3 levels deep in a foreach stack may have the following values:
1170
1167
# foreach-indices = [0, 1, 2]
1171
1168
# foreach-indices-truncated = [0, 1]
1172
1169
@@ -1185,7 +1182,7 @@ def _get_metadata_query_vals(
1185
1182
# The successor or ancestor tasks belong to the same foreach stack level
1186
1183
field_name = "foreach-indices"
1187
1184
field_value = self .metadata_dict .get (field_name )
1188
- elif query_type == "ancestor" :
1185
+ elif is_ancestor :
1189
1186
if query_foreach_stack_len > cur_foreach_stack_len :
1190
1187
# This is a foreach join
1191
1188
# Current Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
@@ -1199,7 +1196,7 @@ def _get_metadata_query_vals(
1199
1196
# Current Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1]
1200
1197
# Ancestor Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
1201
1198
# We will compare the foreach-indices value of ancestor task with the
1202
- # foreach-indices value of current task
1199
+ # foreach-indices-truncated value of current task
1203
1200
field_name = "foreach-indices"
1204
1201
field_value = self .metadata_dict .get ("foreach-indices-truncated" )
1205
1202
else :
@@ -1221,13 +1218,12 @@ def _get_metadata_query_vals(
1221
1218
field_value = self .metadata_dict .get ("foreach-indices-truncated" )
1222
1219
return field_name , field_value
1223
1220
1224
- def _get_related_tasks (self , relation_type : str ) -> Dict [str , List [str ]]:
1225
- start_time = time .time ()
1221
+ def _get_related_tasks (self , is_ancestor : bool ) -> Dict [str , List [str ]]:
1226
1222
flow_id , run_id , _ , _ = self .path_components
1227
1223
steps = (
1228
- self .metadata_dict .get ("previous_steps " )
1229
- if relation_type == "ancestor"
1230
- else self .metadata_dict .get ("successor_steps " )
1224
+ self .metadata_dict .get ("previous-steps " )
1225
+ if is_ancestor
1226
+ else self .metadata_dict .get ("successor-steps " )
1231
1227
)
1232
1228
1233
1229
if not steps :
@@ -1238,11 +1234,9 @@ def _get_related_tasks(self, relation_type: str) -> Dict[str, List[str]]:
1238
1234
run_id ,
1239
1235
len (self .metadata_dict .get ("foreach-indices" , [])),
1240
1236
steps ,
1241
- relation_type ,
1237
+ is_ancestor = is_ancestor ,
1242
1238
)
1243
1239
1244
- cur_time = time .time ()
1245
-
1246
1240
return {
1247
1241
step : [
1248
1242
f"{ flow_id } /{ run_id } /{ step } /{ task_id } "
@@ -1266,7 +1260,7 @@ def immediate_ancestors(self) -> Dict[str, List[str]]:
1266
1260
names of the ancestor steps and the values are the corresponding
1267
1261
task pathspecs of the ancestors.
1268
1262
"""
1269
- return self ._get_related_tasks ("ancestor" )
1263
+ return self ._get_related_tasks (is_ancestor = True )
1270
1264
1271
1265
@property
1272
1266
def immediate_successors (self ) -> Dict [str , List [str ]]:
@@ -1281,7 +1275,7 @@ def immediate_successors(self) -> Dict[str, List[str]]:
1281
1275
names of the successor steps and the values are the corresponding
1282
1276
task pathspecs of the successors.
1283
1277
"""
1284
- return self ._get_related_tasks ("successor" )
1278
+ return self ._get_related_tasks (is_ancestor = False )
1285
1279
1286
1280
@property
1287
1281
def immediate_siblings (self ) -> Dict [str , List [str ]]:
@@ -1408,9 +1402,12 @@ def metadata_dict(self) -> Dict[str, str]:
1408
1402
Dictionary mapping metadata name with value
1409
1403
"""
1410
1404
# use the newest version of each key, hence sorting
1411
- return {
1412
- m .name : m .value for m in sorted (self .metadata , key = lambda m : m .created_at )
1413
- }
1405
+ if self ._metadata_dict is None :
1406
+ self ._metadata_dict = {
1407
+ m .name : m .value
1408
+ for m in sorted (self .metadata , key = lambda m : m .created_at )
1409
+ }
1410
+ return self ._metadata_dict
1414
1411
1415
1412
@property
1416
1413
def index (self ) -> Optional [int ]:
0 commit comments