@@ -201,115 +201,90 @@ def _set_links(self, span: Span):
201201 instance = _extract_bound (instance )
202202 parent_span = _get_nearest_llmobs_ancestor (span )
203203
204- step_idx = self ._set_input_links (instance , span , parent_span )
204+ prev_traced_step_idx = self ._set_input_links (instance , span , parent_span )
205205
206- self ._set_output_links (span , parent_span , step_idx )
206+ self ._set_output_links (span , parent_span , prev_traced_step_idx )
207207
208208 def _set_input_links (self , instance : Any , span : Span , parent_span : Union [Span , None ]) -> int :
209209 """
210210 Sets input links (to: input) on the given span
211211 1. If the instance associated with the span is not a step in a chain, link from its parent span (input->input)
212212 2. If the instance associated with the span is a step in a chain, link from the last traced step in the chain
213213 a. This could be multiple steps, if the last step was a RunnableParallel
214- b. In this case, it would be an output->input relationship
214+ b. If there was no previous traced step, link from the parent span (input->input)
215+ b. Otherwise, it would be an output->input relationship with the previously traced span(s)
215216 """
216217 if parent_span is None :
217218 return - 1
218219
219220 is_step = id (instance ) in self ._chain_steps
220221
221- # defaults
222- invoker_spans = [parent_span ]
223- invoker_links_attributes = [{"from" : "input" , "to" : "input" }]
224- has_parallel_steps = False
225- step_idx = - 1
226-
227- links = []
228-
229222 if not is_step :
230- self ._set_span_links (
231- span ,
232- [
233- {
234- "trace_id" : "{:x}" .format (span .trace_id ),
235- "span_id" : str (invoker_spans [0 ].span_id ),
236- "attributes" : invoker_links_attributes [0 ],
237- }
238- ],
239- )
223+ self ._set_span_links (span , [parent_span ], "input" , "input" )
240224
241- return step_idx
225+ return - 1
242226
243- chain_instance = _extract_bound (self ._instances .get (invoker_spans [ 0 ] ))
227+ chain_instance = _extract_bound (self ._instances .get (parent_span ))
244228 steps = getattr (chain_instance , "steps" , [])
245229 flatmap_chain_steps = _flattened_chain_steps (steps )
246- for i , step in enumerate (flatmap_chain_steps ):
247- if id (step ) == id (instance ) or (
248- isinstance (step , list ) and any (id (sub_step ) == id (instance ) for sub_step in step )
249- ):
250- step_idx = i
251- break
252- for i in range (step_idx - 1 , - 1 , - 1 ):
253- step = flatmap_chain_steps [i ]
254- if id (step ) in self ._spans :
255- invoker_span = self ._spans [id (step )]
256- invoker_link_attributes = {"from" : "output" , "to" : "input" }
257- break
258- if isinstance (step , list ): # parallel steps in the list
259- for parallel_step in step :
260- if id (parallel_step ) in self ._spans :
261- if not has_parallel_steps :
262- invoker_spans = []
263- invoker_links_attributes = []
264- has_parallel_steps = True
265-
266- invoker_spans .append (self ._spans [id (parallel_step )])
267- invoker_links_attributes .append ({"from" : "output" , "to" : "input" })
268- break
269-
270- for link_data in zip (invoker_spans , invoker_links_attributes ):
271- invoker_span , invoker_link_attributes = link_data
272- if invoker_span is None :
273- continue
274- links .append (
275- {
276- "trace_id" : "{:x}" .format (span .trace_id ),
277- "span_id" : str (invoker_span .span_id ),
278- "attributes" : invoker_link_attributes ,
279- }
280- )
230+ prev_traced_step_idx = self ._find_previous_traced_step_index (instance , flatmap_chain_steps )
281231
282- self ._set_span_links (span , links )
232+ if prev_traced_step_idx == - 1 :
233+ self ._set_span_links (span , [parent_span ], "input" , "input" )
283234
284- return step_idx
235+ return prev_traced_step_idx
285236
286- def _set_output_links (self , span : Span , parent_span : Union [Span , None ], step_idx : int ) -> None :
237+ invoker_spans = []
238+ prev_traced_step = flatmap_chain_steps [prev_traced_step_idx ]
239+ if isinstance (prev_traced_step , list ):
240+ for parallel_step in prev_traced_step :
241+ if id (parallel_step ) in self ._spans :
242+ invoker_spans .append (self ._spans [id (parallel_step )])
243+ else :
244+ invoker_spans .append (self ._spans [id (prev_traced_step )])
245+
246+ self ._set_span_links (span , invoker_spans , "output" , "input" )
247+
248+ return prev_traced_step_idx
249+
250+ def _find_previous_traced_step_index (self , instance , flatmap_chain_steps ):
251+ """
252+ Finds the index in the list of steps of the last traced step in the chain before the current instance.
253+ """
254+ curr_idx = 0
255+ curr_step = flatmap_chain_steps [0 ]
256+ prev_traced_step_idx = - 1
257+
258+ while (
259+ curr_idx < len (flatmap_chain_steps )
260+ and id (curr_step ) != id (instance )
261+ and not (isinstance (curr_step , list ) and any (id (sub_step ) == id (instance ) for sub_step in curr_step ))
262+ ):
263+ if id (curr_step ) in self ._spans or (
264+ isinstance (curr_step , list ) and any (id (sub_step ) in self ._spans for sub_step in curr_step )
265+ ):
266+ prev_traced_step_idx = curr_idx
267+ curr_idx += 1
268+ curr_step = flatmap_chain_steps [curr_idx ]
269+
270+ return prev_traced_step_idx
271+
272+ def _set_output_links (self , span : Span , parent_span : Union [Span , None ], prev_traced_step_idx : int ) -> None :
287273 """
288274 Sets the output links for the parent span of the given span (to: output)
289275 This is done by removing repeated span links from steps in a chain.
290- We add output->output span links at every step
276+ We add output->output span links at every step.
291277 """
292278 if parent_span is None :
293279 return
294280
295281 parent_links = parent_span ._get_ctx_item (SPAN_LINKS ) or []
296- pop_indecies = self ._get_popped_span_link_indecies (parent_span , parent_links , step_idx )
297- parent_links = [link for i , link in enumerate (parent_links ) if i not in pop_indecies ]
298-
299- parent_span ._set_ctx_item (
300- SPAN_LINKS ,
301- parent_links
302- + [
303- {
304- "trace_id" : "{:x}" .format (span .trace_id ),
305- "span_id" : str (span .span_id ),
306- "attributes" : {"from" : "output" , "to" : "output" },
307- }
308- ],
309- )
282+ pop_indecies = self ._get_popped_span_link_indecies (parent_span , parent_links , prev_traced_step_idx )
283+
284+ self ._set_span_links (parent_span , [span ], "output" , "output" , popped_span_link_indecies = pop_indecies )
310285
311286 def _get_popped_span_link_indecies (
312- self , parent_span : Span , parent_links : List [Dict [str , Any ]], step_idx : int
287+ self , parent_span : Span , parent_links : List [Dict [str , Any ]], prev_traced_step_idx : int
313288 ) -> List [int ]:
314289 """
315290 Returns a list of indecies to pop from the parent span links list
@@ -321,7 +296,7 @@ def _get_popped_span_link_indecies(
321296 """
322297 pop_indecies : List [int ] = []
323298 parent_instance = self ._instances .get (parent_span )
324- if not parent_instance :
299+ if not parent_instance or prev_traced_step_idx == - 1 :
325300 return pop_indecies
326301
327302 parent_instance = _extract_bound (parent_instance )
@@ -330,33 +305,47 @@ def _get_popped_span_link_indecies(
330305
331306 steps = getattr (parent_instance , "steps" , [])
332307 flatmap_chain_steps = _flattened_chain_steps (steps )
333- for i in range (step_idx - 1 , - 1 , - 1 ):
334- step = flatmap_chain_steps [i ]
335- if id (step ) in self ._spans :
336- invoker_span_id = self ._spans [id (step )].span_id
337- link_idx = next (
338- (i for i , link in enumerate (parent_links ) if link ["span_id" ] == str (invoker_span_id )), None
339- )
340- if link_idx is not None :
341- pop_indecies .append (link_idx )
342- break
343- if isinstance (step , list ): # parallel steps in the list
344- for parallel_step in step :
345- if id (parallel_step ) in self ._spans :
346- invoker_span_id = self ._spans [id (parallel_step )].span_id
347- link_idx = next (
348- (i for i , link in enumerate (parent_links ) if link ["span_id" ] == str (invoker_span_id )),
349- None ,
350- )
351- if link_idx is not None :
352- pop_indecies .append (link_idx )
353- break
308+ prev_traced_step = flatmap_chain_steps [prev_traced_step_idx ]
309+
310+ if isinstance (prev_traced_step , list ):
311+ for parallel_step in prev_traced_step :
312+ if id (parallel_step ) in self ._spans :
313+ invoker_span_id = self ._spans [id (parallel_step )].span_id
314+ link_idx = next (
315+ (i for i , link in enumerate (parent_links ) if link ["span_id" ] == str (invoker_span_id )), None
316+ )
317+ if link_idx is not None :
318+ pop_indecies .append (link_idx )
319+ else :
320+ invoker_span_id = self ._spans [id (prev_traced_step )].span_id
321+ link_idx = next ((i for i , link in enumerate (parent_links ) if link ["span_id" ] == str (invoker_span_id )), None )
322+ if link_idx is not None :
323+ pop_indecies .append (link_idx )
354324
355325 return pop_indecies
356326
357- def _set_span_links (self , span : Span , links : List [Dict [str , Any ]]) -> None :
327+ def _set_span_links (
328+ self ,
329+ span : Span ,
330+ from_spans : List [Span ],
331+ link_from : str ,
332+ link_to : str ,
333+ popped_span_link_indecies : Optional [List [int ]] = None ,
334+ ) -> None :
358335 """Sets the span links on the given span along with the existing links."""
359336 existing_links = span ._get_ctx_item (SPAN_LINKS ) or []
337+
338+ if popped_span_link_indecies :
339+ existing_links = [link for i , link in enumerate (existing_links ) if i not in popped_span_link_indecies ]
340+
341+ links = [
342+ {
343+ "trace_id" : "{:x}" .format (from_span .trace_id ),
344+ "span_id" : str (from_span .span_id ),
345+ "attributes" : {"from" : link_from , "to" : link_to },
346+ }
347+ for from_span in from_spans
348+ ]
360349 span ._set_ctx_item (SPAN_LINKS , existing_links + links )
361350
362351 def _llmobs_set_metadata (self , span : Span , model_provider : Optional [str ] = None ) -> None :
0 commit comments