diff --git a/metaflow/client/core.py b/metaflow/client/core.py index aba68a8f0df..e5273db9e33 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1278,41 +1278,58 @@ def immediate_successors(self) -> Dict[str, List[str]]: return self._get_related_tasks(is_ancestor=False) @property - def immediate_siblings(self) -> Dict[str, List[str]]: + def siblings(self) -> Dict[str, List[str]]: """ - Returns a dictionary of closest sibling task pathspecs of this task for each - sibling step. + Returns a dictionary of sibling task pathspecs of this task. Siblings of a task have + the same common parent task. Returns ------- Dict[str, List[str]] - Dictionary of closest siblings of this task. The keys are the + Dictionary of siblings task pathspecs of this task. The keys are the names of the current step and the values are the corresponding task pathspecs of the siblings. """ flow_id, run_id, step_name, _ = self.path_components - foreach_stack = self.metadata_dict.get("foreach-stack", []) - foreach_step_names = self.metadata_dict.get("foreach-step-names", []) - if len(foreach_stack) == 0: - raise MetaflowInternalError("Task is not part of any foreach split") - if step_name != foreach_step_names[-1]: - raise MetaflowInternalError( - f"Step {step_name} does not have any direct siblings since it is not part " - f"of a new foreach split." - ) + ancestor_steps = self.metadata_dict.get("previous-steps") + cur_foreach_stack_len = len(self.metadata_dict.get("foreach-indices", [])) + if len(ancestor_steps) > 1 or step_name in ("start", "end"): + # This is a static join, or a start/end step. The current task will have no siblings. + return { + step_name: [f"{flow_id}/{run_id}/{step_name}/{self.id}"], + } - field_name = "foreach-indices-truncated" - field_value = self.metadata_dict.get("foreach-indices-truncated") - # We find all tasks of the same step that have the same foreach-indices-truncated value - return { - step_name: [ - f"{flow_id}/{run_id}/{step_name}/{task_id}" - for task_id in self._metaflow.metadata.filter_tasks_by_metadata( - flow_id, run_id, step_name, field_name, field_value - ) - ] - } + # This can be a linear step, a foreach split, a foreach join, or a static split. + query_task = self._get_task_for_queried_step(flow_id, run_id, ancestor_steps[0]) + query_foreach_stack_len = len( + query_task.metadata_dict.get("foreach-indices", []) + ) + if query_foreach_stack_len > cur_foreach_stack_len: + # This is a foreach join, there will be no siblings + return { + step_name: [f"{flow_id}/{run_id}/{step_name}/{self.id}"], + } + elif query_foreach_stack_len < cur_foreach_stack_len: + # This is a foreach split, there will be multiple siblings + field_name = "foreach-indices-truncated" + field_value = self.metadata_dict.get("foreach-indices-truncated") + # We find all tasks of the same step that have the same foreach-indices-truncated value + return { + step_name: [ + f"{flow_id}/{run_id}/{step_name}/{task_id}" + for task_id in self._metaflow.metadata.filter_tasks_by_metadata( + flow_id, run_id, step_name, field_name, field_value + ) + ] + } + + # Logic for static splits, and linear steps + # To find siblings, we first find the single ancestor task of the current task. + # And then we find all the successor tasks of this ancestor task. + ancestor_task_pathspecs = self.immediate_ancestors.get(ancestor_steps[0]) + ancestor_task = Task(ancestor_task_pathspecs[0], _namespace_check=False) + return ancestor_task.immediate_successors @property def metadata(self) -> List[Metadata]: