@@ -10,7 +10,7 @@ use crate::http::types::{
10
10
VertexResponse ,
11
11
} ;
12
12
use crate :: {
13
- shutdown, ClassifierModel , EmbeddingModel , ErrorResponse , ErrorType , Info , ModelType ,
13
+ logging , shutdown, ClassifierModel , EmbeddingModel , ErrorResponse , ErrorType , Info , ModelType ,
14
14
ResponseMetadata ,
15
15
} ;
16
16
use :: http:: HeaderMap ;
@@ -39,6 +39,7 @@ use text_embeddings_core::TextEmbeddingsError;
39
39
use tokio:: sync:: OwnedSemaphorePermit ;
40
40
use tower_http:: cors:: { AllowOrigin , CorsLayer } ;
41
41
use tracing:: instrument;
42
+ use tracing_opentelemetry:: OpenTelemetrySpanExt ;
42
43
use utoipa:: OpenApi ;
43
44
use utoipa_swagger_ui:: SwaggerUi ;
44
45
@@ -103,9 +104,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
103
104
async fn predict (
104
105
infer : Extension < Infer > ,
105
106
info : Extension < Info > ,
107
+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
106
108
Json ( req) : Json < PredictRequest > ,
107
109
) -> Result < ( HeaderMap , Json < PredictResponse > ) , ( StatusCode , Json < ErrorResponse > ) > {
108
110
let span = tracing:: Span :: current ( ) ;
111
+ if let Some ( context) = context {
112
+ span. set_parent ( context) ;
113
+ }
114
+
109
115
let start_time = Instant :: now ( ) ;
110
116
111
117
// Closure for predict
@@ -301,9 +307,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
301
307
async fn rerank (
302
308
infer : Extension < Infer > ,
303
309
info : Extension < Info > ,
310
+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
304
311
Json ( req) : Json < RerankRequest > ,
305
312
) -> Result < ( HeaderMap , Json < RerankResponse > ) , ( StatusCode , Json < ErrorResponse > ) > {
306
313
let span = tracing:: Span :: current ( ) ;
314
+ if let Some ( context) = context {
315
+ span. set_parent ( context) ;
316
+ }
317
+
307
318
let start_time = Instant :: now ( ) ;
308
319
309
320
if req. texts . is_empty ( ) {
@@ -489,6 +500,7 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
489
500
async fn similarity (
490
501
infer : Extension < Infer > ,
491
502
info : Extension < Info > ,
503
+ context : Extension < Option < opentelemetry:: Context > > ,
492
504
Json ( req) : Json < SimilarityRequest > ,
493
505
) -> Result < ( HeaderMap , Json < SimilarityResponse > ) , ( StatusCode , Json < ErrorResponse > ) > {
494
506
if req. inputs . sentences . is_empty ( ) {
@@ -535,7 +547,7 @@ async fn similarity(
535
547
} ;
536
548
537
549
// Get embeddings
538
- let ( header_map, embed_response) = embed ( infer, info, Json ( embed_req) ) . await ?;
550
+ let ( header_map, embed_response) = embed ( infer, info, context , Json ( embed_req) ) . await ?;
539
551
let embeddings = embed_response. 0 . 0 ;
540
552
541
553
// Compute cosine
@@ -573,9 +585,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
573
585
async fn embed (
574
586
infer : Extension < Infer > ,
575
587
info : Extension < Info > ,
588
+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
576
589
Json ( req) : Json < EmbedRequest > ,
577
590
) -> Result < ( HeaderMap , Json < EmbedResponse > ) , ( StatusCode , Json < ErrorResponse > ) > {
578
591
let span = tracing:: Span :: current ( ) ;
592
+ if let Some ( context) = context {
593
+ span. set_parent ( context) ;
594
+ }
595
+
579
596
let start_time = Instant :: now ( ) ;
580
597
581
598
let truncate = req. truncate . unwrap_or ( info. auto_truncate ) ;
@@ -742,9 +759,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
742
759
async fn embed_sparse (
743
760
infer : Extension < Infer > ,
744
761
info : Extension < Info > ,
762
+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
745
763
Json ( req) : Json < EmbedSparseRequest > ,
746
764
) -> Result < ( HeaderMap , Json < EmbedSparseResponse > ) , ( StatusCode , Json < ErrorResponse > ) > {
747
765
let span = tracing:: Span :: current ( ) ;
766
+ if let Some ( context) = context {
767
+ span. set_parent ( context) ;
768
+ }
769
+
748
770
let start_time = Instant :: now ( ) ;
749
771
750
772
let sparsify = |values : Vec < f32 > | {
@@ -920,9 +942,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
920
942
async fn embed_all (
921
943
infer : Extension < Infer > ,
922
944
info : Extension < Info > ,
945
+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
923
946
Json ( req) : Json < EmbedAllRequest > ,
924
947
) -> Result < ( HeaderMap , Json < EmbedAllResponse > ) , ( StatusCode , Json < ErrorResponse > ) > {
925
948
let span = tracing:: Span :: current ( ) ;
949
+ if let Some ( context) = context {
950
+ span. set_parent ( context) ;
951
+ }
952
+
926
953
let start_time = Instant :: now ( ) ;
927
954
928
955
let truncate = req. truncate . unwrap_or ( info. auto_truncate ) ;
@@ -1087,6 +1114,7 @@ example = json ! ({"message": "Batch size error", "type": "validation"})),
1087
1114
async fn openai_embed (
1088
1115
infer : Extension < Infer > ,
1089
1116
info : Extension < Info > ,
1117
+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
1090
1118
Json ( req) : Json < OpenAICompatRequest > ,
1091
1119
) -> Result < ( HeaderMap , Json < OpenAICompatResponse > ) , ( StatusCode , Json < OpenAICompatErrorResponse > ) >
1092
1120
{
@@ -1106,6 +1134,10 @@ async fn openai_embed(
1106
1134
} ;
1107
1135
1108
1136
let span = tracing:: Span :: current ( ) ;
1137
+ if let Some ( context) = context {
1138
+ span. set_parent ( context) ;
1139
+ }
1140
+
1109
1141
let start_time = Instant :: now ( ) ;
1110
1142
1111
1143
let truncate = info. auto_truncate ;
@@ -1469,54 +1501,71 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
1469
1501
async fn vertex_compatibility (
1470
1502
infer : Extension < Infer > ,
1471
1503
info : Extension < Info > ,
1504
+ context : Extension < Option < opentelemetry:: Context > > ,
1472
1505
Json ( req) : Json < VertexRequest > ,
1473
1506
) -> Result < Json < VertexResponse > , ( StatusCode , Json < ErrorResponse > ) > {
1474
- let embed_future = move |infer : Extension < Infer > , info : Extension < Info > , req : EmbedRequest | async move {
1475
- let result = embed ( infer, info, Json ( req) ) . await ?;
1507
+ let embed_future = move |infer : Extension < Infer > ,
1508
+ info : Extension < Info > ,
1509
+ context : Extension < Option < opentelemetry:: Context > > ,
1510
+ req : EmbedRequest | async move {
1511
+ let result = embed ( infer, info, context, Json ( req) ) . await ?;
1476
1512
Ok ( VertexPrediction :: Embed ( result. 1 . 0 ) )
1477
1513
} ;
1478
- let embed_sparse_future =
1479
- move |infer : Extension < Infer > , info : Extension < Info > , req : EmbedSparseRequest | async move {
1480
- let result = embed_sparse ( infer, info, Json ( req) ) . await ?;
1481
- Ok ( VertexPrediction :: EmbedSparse ( result. 1 . 0 ) )
1482
- } ;
1483
- let predict_future =
1484
- move |infer : Extension < Infer > , info : Extension < Info > , req : PredictRequest | async move {
1485
- let result = predict ( infer, info, Json ( req) ) . await ?;
1486
- Ok ( VertexPrediction :: Predict ( result. 1 . 0 ) )
1487
- } ;
1488
- let rerank_future =
1489
- move |infer : Extension < Infer > , info : Extension < Info > , req : RerankRequest | async move {
1490
- let result = rerank ( infer, info, Json ( req) ) . await ?;
1491
- Ok ( VertexPrediction :: Rerank ( result. 1 . 0 ) )
1492
- } ;
1514
+ let embed_sparse_future = move |infer : Extension < Infer > ,
1515
+ info : Extension < Info > ,
1516
+ context : Extension < Option < opentelemetry:: Context > > ,
1517
+ req : EmbedSparseRequest | async move {
1518
+ let result = embed_sparse ( infer, info, context, Json ( req) ) . await ?;
1519
+ Ok ( VertexPrediction :: EmbedSparse ( result. 1 . 0 ) )
1520
+ } ;
1521
+ let predict_future = move |infer : Extension < Infer > ,
1522
+ info : Extension < Info > ,
1523
+ context : Extension < Option < opentelemetry:: Context > > ,
1524
+ req : PredictRequest | async move {
1525
+ let result = predict ( infer, info, context, Json ( req) ) . await ?;
1526
+ Ok ( VertexPrediction :: Predict ( result. 1 . 0 ) )
1527
+ } ;
1528
+ let rerank_future = move |infer : Extension < Infer > ,
1529
+ info : Extension < Info > ,
1530
+ context : Extension < Option < opentelemetry:: Context > > ,
1531
+ req : RerankRequest | async move {
1532
+ let result = rerank ( infer, info, context, Json ( req) ) . await ?;
1533
+ Ok ( VertexPrediction :: Rerank ( result. 1 . 0 ) )
1534
+ } ;
1493
1535
1494
1536
let mut futures = Vec :: with_capacity ( req. instances . len ( ) ) ;
1495
1537
for instance in req. instances {
1496
1538
let local_infer = infer. clone ( ) ;
1497
1539
let local_info = info. clone ( ) ;
1540
+ let local_context = context. clone ( ) ;
1498
1541
1499
1542
// Rerank is the only payload that can be matched safely
1500
1543
if let Ok ( instance) = serde_json:: from_value :: < RerankRequest > ( instance. clone ( ) ) {
1501
- futures. push ( rerank_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1544
+ futures. push ( rerank_future ( local_infer, local_info, local_context , instance) . boxed ( ) ) ;
1502
1545
continue ;
1503
1546
}
1504
1547
1505
1548
match info. model_type {
1506
1549
ModelType :: Classifier ( _) | ModelType :: Reranker ( _) => {
1507
1550
let instance = serde_json:: from_value :: < PredictRequest > ( instance)
1508
1551
. map_err ( ErrorResponse :: from) ?;
1509
- futures. push ( predict_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1552
+ futures
1553
+ . push ( predict_future ( local_infer, local_info, local_context, instance) . boxed ( ) ) ;
1510
1554
}
1511
1555
ModelType :: Embedding ( _) => {
1512
1556
if infer. is_splade ( ) {
1513
1557
let instance = serde_json:: from_value :: < EmbedSparseRequest > ( instance)
1514
1558
. map_err ( ErrorResponse :: from) ?;
1515
- futures. push ( embed_sparse_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1559
+ futures. push (
1560
+ embed_sparse_future ( local_infer, local_info, local_context, instance)
1561
+ . boxed ( ) ,
1562
+ ) ;
1516
1563
} else {
1517
1564
let instance = serde_json:: from_value :: < EmbedRequest > ( instance)
1518
1565
. map_err ( ErrorResponse :: from) ?;
1519
- futures. push ( embed_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1566
+ futures. push (
1567
+ embed_future ( local_infer, local_info, local_context, instance) . boxed ( ) ,
1568
+ ) ;
1520
1569
}
1521
1570
}
1522
1571
}
@@ -1784,6 +1833,7 @@ pub async fn run(
1784
1833
. layer ( Extension ( info) )
1785
1834
. layer ( Extension ( prom_handle. clone ( ) ) )
1786
1835
. layer ( OtelAxumLayer :: default ( ) )
1836
+ . layer ( axum:: middleware:: from_fn ( logging:: trace_context_middleware) )
1787
1837
. layer ( DefaultBodyLimit :: max ( payload_limit) )
1788
1838
. layer ( cors_layer) ;
1789
1839
0 commit comments