@@ -496,12 +496,20 @@ def __init__(self, base_path: str):
496
496
self .parties = ["alice" , "bob" ]
497
497
self .serving_cmd_dict = {}
498
498
self .serving_config_dict = {}
499
+
500
+ self .trace_configs = {}
501
+
502
+ self .trace_files = {}
499
503
for p in self .parties :
500
504
party_base_path = os .path .join (base_path , p )
501
505
serving_config_file = os .path .join (party_base_path , "serving.config" )
502
506
logging_config_file = os .path .join (party_base_path , "logging.config" )
507
+ trace_config_file = os .path .join (party_base_path , "trace.config" )
508
+
509
+ self .trace_configs [p ] = trace_config_file
510
+
503
511
self .serving_cmd_dict [p ] = (
504
- f"./bazel-bin/secretflow_serving/server/secretflow_serving --serving_config_file={ serving_config_file } --logging_config_file={ logging_config_file } "
512
+ f"./bazel-bin/secretflow_serving/server/secretflow_serving --serving_config_file={ serving_config_file } --logging_config_file={ logging_config_file } --trace_config_file= { trace_config_file } "
505
513
)
506
514
507
515
with open (serving_config_file , "r" ) as file :
@@ -526,13 +534,21 @@ def exec(self):
526
534
},
527
535
}
528
536
537
+ trace_id_prefix = "1234567890abcdef1234567890abcde"
538
+ self .span_id = "1234567890abcdef"
539
+ self .trace_id_map = {}
540
+ parties_index = 0
529
541
# make request
530
542
for p in self .parties :
543
+ trace_id = f"{ trace_id_prefix } { parties_index } "
544
+ self .trace_id_map [p ] = trace_id
545
+ parties_index += 1
531
546
res = self .run_cmd (
532
547
build_predict_cmd (
533
548
"127.0.0.1" ,
534
549
self .serving_config_dict [p ]['serverConf' ]['servicePort' ],
535
550
json .dumps (body_dict ),
551
+ {'X-B3-TraceId' : trace_id , 'X-B3-SpanId' : self .span_id },
536
552
)
537
553
)
538
554
out = res .stdout .decode ()
@@ -541,11 +557,85 @@ def exec(self):
541
557
assert (
542
558
res ["status" ]["code" ] == 1
543
559
), f'return status code({ res ["status" ]["code" ]} ) should be OK(1)'
560
+
561
+ # check trace log
562
+ self .check_trace_log ()
563
+
544
564
finally :
545
565
self .cleanup_sub_procs ()
566
+ self .cleanup_trace_files ()
567
+
568
+ def cleanup_trace_files (self ):
569
+ for p in self .parties :
570
+ trace_config_file = self .trace_configs [p ]
571
+ with open (trace_config_file , "r" ) as f :
572
+ trace_config = json .load (f )
573
+ trace_dir = trace_config ["traceLogConf" ]["traceLogPath" ]
574
+ if os .path .exists (trace_dir ):
575
+ os .remove (trace_dir )
576
+
577
+ def check_trace_log (self ):
578
+ def decode_bytes (proto_bytes ):
579
+ import base64
580
+
581
+ return base64 .b16encode (base64 .b64decode (proto_bytes )).lower ().decode ()
582
+
583
+ def get_spans_from_trace_file (trace_filename ):
584
+ spans = []
585
+ with open (trace_filename , 'r' ) as trace_file :
586
+ for line in trace_file :
587
+ start_index = line .find ('{' )
588
+ if start_index == - 1 :
589
+ continue
590
+ resource_span = json .loads (line [start_index :])
591
+ for scopeSpan in resource_span ['scopeSpans' ]:
592
+ spans .extend (scopeSpan ['spans' ])
593
+ return spans
594
+
595
+ stub_span_id_dict = {}
596
+ service_span_id_dict = {}
597
+
598
+ for p , config in self .trace_configs .items ():
599
+ stub_span_id_dict [p ] = set ()
600
+ service_span_id_dict [p ] = set ()
601
+
602
+ with open (config , "r" ) as f :
603
+ trace_config = json .load (f )
604
+ trace_dir = trace_config ["traceLogConf" ]["traceLogPath" ]
605
+ spans = get_spans_from_trace_file (trace_dir )
606
+ intrest_span_found = False
607
+ for span in spans :
608
+ if span ['name' ] == "PredictionService/Predict" :
609
+ assert self .trace_id_map [p ] == decode_bytes (
610
+ span ['traceId' ]
611
+ ), f"trace id mismatch, expected: { self .trace_id_map [p ]} , actual: { decode_bytes (span ['traceId' ])} "
612
+ assert self .span_id == decode_bytes (
613
+ span ['parentSpanId' ]
614
+ ), f"parent span id mismatch, expected: { self .span_id } , actual: { decode_bytes (span ['parentSpanId' ])} "
615
+ intrest_span_found = True
616
+ if (
617
+ span ['name' ].startswith ("ExecutionService" )
618
+ and span ['kind' ] == "SPAN_KIND_SERVER"
619
+ ):
620
+ service_span_id_dict [p ].add (decode_bytes (span ['parentSpanId' ]))
621
+ if (
622
+ span ['name' ].startswith ("ExecutionService" )
623
+ and span ['kind' ] == "SPAN_KIND_CLIENT"
624
+ ):
625
+ stub_span_id_dict [p ].add (decode_bytes (span ['spanId' ]))
626
+ assert intrest_span_found
627
+ assert (
628
+ service_span_id_dict ["alice" ] == stub_span_id_dict ["bob" ]
629
+ ), f"execution parent span id mismatch, expected: { service_span_id_dict ['alice' ]} , actual: { stub_span_id_dict ['bob' ]} "
630
+
631
+ assert (
632
+ service_span_id_dict ["bob" ] == stub_span_id_dict ["alice" ]
633
+ ), f"execution parent span id mismatch, expected: { service_span_id_dict ['bob' ]} , actual: { stub_span_id_dict ['alice' ]} "
546
634
547
635
548
636
if __name__ == "__main__" :
637
+ ExampleTest ('examples' ).exec ()
638
+
549
639
# glm
550
640
with open (".ci/simple_test/node_processing_alice.json" , "rb" ) as f :
551
641
alice_trace_content = f .read ()
@@ -1904,5 +1994,3 @@ def exec(self):
1904
1994
PredefineTest ('model_path' ).exec ()
1905
1995
CsvTest ('model_path' ).exec ()
1906
1996
SpecificTest ('model_path' ).exec ()
1907
-
1908
- ExampleTest ('examples' ).exec ()
0 commit comments