Skip to content

Commit f3241d3

Browse files
committed
fix langgraph node invoke trigger detection from __pregel_push
1 parent 4f8d78a commit f3241d3

File tree

2 files changed

+45
-21
lines changed

2 files changed

+45
-21
lines changed

ddtrace/llmobs/_integrations/langchain.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from typing import Any
55
from typing import Dict
66
from typing import List
7-
from typing import Set
87
from typing import Optional
8+
from typing import Set
99
from typing import Union
1010
from weakref import WeakKeyDictionary
1111

@@ -227,11 +227,16 @@ def _set_input_links(self, instance: Any, span: Span, parent_span: Union[Span, N
227227
links = []
228228

229229
if not is_step:
230-
self._set_span_links(span, [{
231-
"trace_id": "{:x}".format(span.trace_id),
232-
"span_id": str(invoker_spans[0].span_id),
233-
"attributes": invoker_links_attributes[0],
234-
}])
230+
self._set_span_links(
231+
span,
232+
[
233+
{
234+
"trace_id": "{:x}".format(span.trace_id),
235+
"span_id": str(invoker_spans[0].span_id),
236+
"attributes": invoker_links_attributes[0],
237+
}
238+
],
239+
)
235240

236241
return step_idx
237242

@@ -250,7 +255,7 @@ def _set_input_links(self, instance: Any, span: Span, parent_span: Union[Span, N
250255
invoker_span = self._spans[id(step)]
251256
invoker_link_attributes = {"from": "output", "to": "input"}
252257
break
253-
if isinstance(step, list): # parallel steps in the list
258+
if isinstance(step, list): # parallel steps in the list
254259
for parallel_step in step:
255260
if id(parallel_step) in self._spans:
256261
if not has_parallel_steps:
@@ -277,16 +282,16 @@ def _set_input_links(self, instance: Any, span: Span, parent_span: Union[Span, N
277282
self._set_span_links(span, links)
278283

279284
return step_idx
280-
285+
281286
def _set_output_links(self, span: Span, parent_span: Union[Span, None], step_idx: int) -> None:
282287
"""
283288
Sets the output links for the parent span of the given span (to: output)
284289
This is done by removing repeated span links from steps in a chain.
285-
We add output->output span links at every step
290+
We add output->output span links at every step
286291
"""
287292
if parent_span is None:
288293
return
289-
294+
290295
parent_links = parent_span._get_ctx_item(SPAN_LINKS) or []
291296
pop_indecies = self._get_popped_span_link_indecies(parent_span, parent_links, step_idx)
292297
parent_links = [link for i, link in enumerate(parent_links) if i not in pop_indecies]
@@ -303,7 +308,9 @@ def _set_output_links(self, span: Span, parent_span: Union[Span, None], step_idx
303308
],
304309
)
305310

306-
def _get_popped_span_link_indecies(self, parent_span: Span, parent_links: List[Dict[str, Any]], step_idx: int) -> List[int]:
311+
def _get_popped_span_link_indecies(
312+
self, parent_span: Span, parent_links: List[Dict[str, Any]], step_idx: int
313+
) -> List[int]:
307314
"""
308315
Returns a list of indecies to pop from the parent span links list
309316
This is determined by if the parent span represents a chain, and if there are steps before the step
@@ -316,11 +323,11 @@ def _get_popped_span_link_indecies(self, parent_span: Span, parent_links: List[D
316323
parent_instance = self._instances.get(parent_span)
317324
if not parent_instance:
318325
return pop_indecies
319-
326+
320327
parent_instance = _extract_bound(parent_instance)
321328
if not hasattr(parent_instance, "steps"): # chain instance
322329
return pop_indecies
323-
330+
324331
steps = getattr(parent_instance, "steps", [])
325332
flatmap_chain_steps = _flattened_chain_steps(steps)
326333
for i in range(step_idx - 1, -1, -1):
@@ -338,19 +345,15 @@ def _get_popped_span_link_indecies(self, parent_span: Span, parent_links: List[D
338345
if id(parallel_step) in self._spans:
339346
invoker_span_id = self._spans[id(parallel_step)].span_id
340347
link_idx = next(
341-
(
342-
i
343-
for i, link in enumerate(parent_links)
344-
if link["span_id"] == str(invoker_span_id)
345-
),
348+
(i for i, link in enumerate(parent_links) if link["span_id"] == str(invoker_span_id)),
346349
None,
347350
)
348351
if link_idx is not None:
349352
pop_indecies.append(link_idx)
350353
break
351354

352355
return pop_indecies
353-
356+
354357
def _set_span_links(self, span: Span, links: List[Dict[str, Any]]) -> None:
355358
"""Sets the span links on the given span along with the existing links."""
356359
existing_links = span._get_ctx_item(SPAN_LINKS) or []
@@ -452,7 +455,7 @@ def _llmobs_set_tags_from_chat_model(
452455
content = (
453456
message.get("content", "") if isinstance(message, dict) else getattr(message, "content", "")
454457
)
455-
role = getattr(message, "role", ROLE_MAPPING.get(message.type, ""))
458+
role = getattr(message, "role", ROLE_MAPPING.get(getattr(message, "type", None), ""))
456459
input_messages.append({"content": str(content), "role": str(role)})
457460
span._set_ctx_item(input_tag_key, input_messages)
458461

ddtrace/llmobs/_integrations/langgraph.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,11 @@ def _handle_finished_graph(self, graph_span, finished_tasks, is_subgraph_node):
107107
def _link_task_to_parent(self, task_id, task, finished_task_names_to_ids):
108108
"""Create the span links for a queued task from its triggering parent tasks."""
109109
task_config = getattr(task, "config", {})
110-
task_triggers = task_config.get("metadata", {}).get("langgraph_triggers", [])
110+
task_triggers = _normalize_triggers(
111+
triggers=task_config.get("metadata", {}).get("langgraph_triggers", []),
112+
finished_tasks=finished_task_names_to_ids,
113+
next_task=task,
114+
)
111115

112116
trigger_node_names = [_extract_parent(trigger) for trigger in task_triggers]
113117
trigger_node_ids: List[str] = [
@@ -132,6 +136,23 @@ def _link_task_to_parent(self, task_id, task, finished_task_names_to_ids):
132136
span_links.append(span_link)
133137

134138

139+
def _normalize_triggers(triggers, finished_tasks, next_task) -> List[str]:
140+
"""
141+
Return the default triggers for a LangGraph node.
142+
143+
For nodes queued up with `langgraph.types.Send`, the triggers are an unhelpful ['__pregel_push'].
144+
In this case (and in any case with 1 finished task and 1 trigger), we can infer the trigger from
145+
the one finished task.
146+
"""
147+
if len(finished_tasks) != 1 or len(triggers) != 1:
148+
return []
149+
150+
finished_task_name = list(finished_tasks.keys())[0]
151+
next_task_name = getattr(next_task, "name", "")
152+
153+
return [f"{finished_task_name}:{next_task_name}"]
154+
155+
135156
def _extract_parent(trigger: str) -> str:
136157
"""
137158
Extract the parent node name from a trigger string.

0 commit comments

Comments
 (0)