# Post-patch form of the span-link helpers changed by this diff of
# ddtrace/llmobs/_integrations/langchain.py. Each one takes `self`: they are methods of a class whose
# header lies outside the visible hunks, so they are shown dedented here.
#
# The visible tail of the caller, `_set_links`, feeds the index returned by `_set_input_links`
# straight into `_set_output_links`:
#     prev_traced_step_idx = self._set_input_links(instance, span, parent_span)
#     self._set_output_links(span, parent_span, prev_traced_step_idx)


def _set_input_links(self, instance: Any, span: Span, parent_span: Union[Span, None]) -> int:
    """
    Sets the input links for the given span (to: input) and returns the index of the previously traced step.

    1. If there is no parent span, no link is set.
    2. If the instance associated with the span is not a step in a chain, link from its parent span (input->input)
    3. If the instance associated with the span is a step in a chain, link from the last traced step in the chain
        a. This could be multiple steps, if the last step was a RunnableParallel
        b. If there was no previous traced step, link from the parent span (input->input)
        c. Otherwise, it would be an output->input relationship with the previously traced span(s)

    Returns -1 in cases 1, 2 and 3b; otherwise the index (into the flattened chain steps) of the
    step the links were taken from.
    """
    if parent_span is None:
        return -1

    if id(instance) not in self._chain_steps:
        self._set_span_links(span, [parent_span], "input", "input")
        return -1

    chain_instance = _extract_bound(self._instances.get(parent_span))
    flatmap_chain_steps = _flattened_chain_steps(getattr(chain_instance, "steps", []))
    prev_traced_step_idx = self._find_previous_traced_step_index(instance, flatmap_chain_steps)

    if prev_traced_step_idx == -1:
        # Nothing before this step was traced, so the parent span is the only available source.
        self._set_span_links(span, [parent_span], "input", "input")
        return -1

    invoker_spans = self._spans_for_step(flatmap_chain_steps[prev_traced_step_idx])
    self._set_span_links(span, invoker_spans, "output", "input")
    return prev_traced_step_idx


def _spans_for_step(self, step):
    """
    Returns the spans recorded in `self._spans` for the given step.

    A step that is a list is treated as a group of parallel sub-steps: one span is returned for each
    sub-step that has an entry in `self._spans`. Steps without an entry contribute nothing.
    """
    sub_steps = step if isinstance(step, list) else [step]
    return [self._spans[id(sub_step)] for sub_step in sub_steps if id(sub_step) in self._spans]


def _find_previous_traced_step_index(self, instance, flatmap_chain_steps):
    """
    Finds the index in the list of steps of the last traced step in the chain before the current instance.

    A step counts as traced when its id, or the id of any of its sub-steps when the step is a list,
    is a key of `self._spans`. The scan stops at the first step that is the instance or contains it.

    Returns -1 when no earlier step is traced, when the list of steps is empty, or when the instance
    is not found in the list at all (callers then fall back to linking from the parent span).
    """
    prev_traced_step_idx = -1
    for idx, step in enumerate(flatmap_chain_steps):
        # Consider the step object itself and, for a list of parallel steps, each member.
        candidates = [step, *step] if isinstance(step, list) else [step]
        if any(candidate is instance for candidate in candidates):
            return prev_traced_step_idx
        if any(id(candidate) in self._spans for candidate in candidates):
            prev_traced_step_idx = idx

    # The instance is not one of these steps; indexing past the end here used to raise IndexError.
    return -1


def _set_output_links(self, span: Span, parent_span: Union[Span, None], prev_traced_step_idx: int) -> None:
    """
    Sets the output links for the parent span of the given span (to: output)

    This is done by removing repeated span links from steps in a chain.
    We add output->output span links at every step.
    """
    if parent_span is None:
        return

    parent_links = parent_span._get_ctx_item(SPAN_LINKS) or []
    pop_indecies = self._get_popped_span_link_indecies(parent_span, parent_links, prev_traced_step_idx)

    self._set_span_links(parent_span, [span], "output", "output", popped_span_link_indecies=pop_indecies)


def _get_popped_span_link_indecies(
    self, parent_span: Span, parent_links: List[Dict[str, Any]], prev_traced_step_idx: int
) -> List[int]:
    """
    Returns a list of indices to pop from the parent span links list.

    The indices are the positions in `parent_links` of the links whose "span_id" matches a span of
    the previously traced step (or of each traced sub-step when that step is a list). The list is
    empty when the parent span has no registered instance, when `prev_traced_step_idx` is -1, or
    when the index does not address one of the parent's flattened steps.
    """
    pop_indecies: List[int] = []
    parent_instance = self._instances.get(parent_span)
    if not parent_instance or prev_traced_step_idx == -1:
        return pop_indecies

    parent_instance = _extract_bound(parent_instance)
    # NOTE(review): the diff elides two lines between the statement above and the one below (hunk
    # boundary) - confirm against the full file that nothing beyond the bounds guard is needed.
    flatmap_chain_steps = _flattened_chain_steps(getattr(parent_instance, "steps", []))
    if not 0 <= prev_traced_step_idx < len(flatmap_chain_steps):
        return pop_indecies

    for invoker_span in self._spans_for_step(flatmap_chain_steps[prev_traced_step_idx]):
        invoker_span_id = str(invoker_span.span_id)
        link_idx = next((i for i, link in enumerate(parent_links) if link["span_id"] == invoker_span_id), None)
        if link_idx is not None:
            pop_indecies.append(link_idx)

    return pop_indecies


def _set_span_links(
    self,
    span: Span,
    from_spans: List[Span],
    link_from: str,
    link_to: str,
    popped_span_link_indecies: Optional[List[int]] = None,
) -> None:
    """
    Sets the span links on the given span along with the existing links.

    One link is appended per entry of `from_spans`, with attributes {"from": link_from, "to": link_to}.
    When `popped_span_link_indecies` is given, the existing links at those positions are dropped first.
    """
    existing_links = span._get_ctx_item(SPAN_LINKS) or []

    if popped_span_link_indecies:
        popped = set(popped_span_link_indecies)
        existing_links = [link for i, link in enumerate(existing_links) if i not in popped]

    links = [
        {
            "trace_id": "{:x}".format(from_span.trace_id),
            "span_id": str(from_span.span_id),
            "attributes": {"from": link_from, "to": link_to},
        }
        for from_span in from_spans
        # The pre-refactor code skipped missing invoker spans; keep that tolerance.
        if from_span is not None
    ]
    span._set_ctx_item(SPAN_LINKS, existing_links + links)