Commit 35d9a34

Get opentelemetry trace id from request headers (#425)

1 parent 9621564 · commit 35d9a34

2 files changed: +128 −23 lines changed
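
Summary: this commit propagates W3C Trace Context from incoming HTTP requests into the router's tracing spans. A new Axum middleware in router/src/logging.rs parses the traceparent request header into an opentelemetry::Context and stores it as an Option<opentelemetry::Context> request extension; each handler in router/src/http/server.rs then extracts that extension and calls span.set_parent() so its span joins the caller's trace instead of starting a new root trace.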

router/src/http/server.rs
+73 −23

@@ -10,7 +10,7 @@ use crate::http::types::{
     VertexResponse,
 };
 use crate::{
-    shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
+    logging, shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
     ResponseMetadata,
 };
 use ::http::HeaderMap;
@@ -39,6 +39,7 @@ use text_embeddings_core::TextEmbeddingsError;
 use tokio::sync::OwnedSemaphorePermit;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::instrument;
+use tracing_opentelemetry::OpenTelemetrySpanExt;
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;

@@ -103,9 +104,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
 async fn predict(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<PredictRequest>,
 ) -> Result<(HeaderMap, Json<PredictResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     let start_time = Instant::now();

     // Closure for predict
@@ -301,9 +307,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
 async fn rerank(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<RerankRequest>,
 ) -> Result<(HeaderMap, Json<RerankResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     let start_time = Instant::now();

     if req.texts.is_empty() {
@@ -489,6 +500,7 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
 async fn similarity(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    context: Extension<Option<opentelemetry::Context>>,
     Json(req): Json<SimilarityRequest>,
 ) -> Result<(HeaderMap, Json<SimilarityResponse>), (StatusCode, Json<ErrorResponse>)> {
     if req.inputs.sentences.is_empty() {
@@ -535,7 +547,7 @@ async fn similarity(
     };

     // Get embeddings
-    let (header_map, embed_response) = embed(infer, info, Json(embed_req)).await?;
+    let (header_map, embed_response) = embed(infer, info, context, Json(embed_req)).await?;
     let embeddings = embed_response.0 .0;

     // Compute cosine
@@ -573,9 +585,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
 async fn embed(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<EmbedRequest>,
 ) -> Result<(HeaderMap, Json<EmbedResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     let start_time = Instant::now();

     let truncate = req.truncate.unwrap_or(info.auto_truncate);
@@ -742,9 +759,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
 async fn embed_sparse(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<EmbedSparseRequest>,
 ) -> Result<(HeaderMap, Json<EmbedSparseResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     let start_time = Instant::now();

     let sparsify = |values: Vec<f32>| {
@@ -920,9 +942,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
 async fn embed_all(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<EmbedAllRequest>,
 ) -> Result<(HeaderMap, Json<EmbedAllResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     let start_time = Instant::now();

     let truncate = req.truncate.unwrap_or(info.auto_truncate);
@@ -1087,6 +1114,7 @@ example = json ! ({"message": "Batch size error", "type": "validation"})),
 async fn openai_embed(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<OpenAICompatRequest>,
 ) -> Result<(HeaderMap, Json<OpenAICompatResponse>), (StatusCode, Json<OpenAICompatErrorResponse>)>
 {
@@ -1106,6 +1134,10 @@ async fn openai_embed(
     };

     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     let start_time = Instant::now();

     let truncate = info.auto_truncate;
@@ -1469,54 +1501,71 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
 async fn vertex_compatibility(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    context: Extension<Option<opentelemetry::Context>>,
     Json(req): Json<VertexRequest>,
 ) -> Result<Json<VertexResponse>, (StatusCode, Json<ErrorResponse>)> {
-    let embed_future = move |infer: Extension<Infer>, info: Extension<Info>, req: EmbedRequest| async move {
-        let result = embed(infer, info, Json(req)).await?;
+    let embed_future = move |infer: Extension<Infer>,
+                             info: Extension<Info>,
+                             context: Extension<Option<opentelemetry::Context>>,
+                             req: EmbedRequest| async move {
+        let result = embed(infer, info, context, Json(req)).await?;
         Ok(VertexPrediction::Embed(result.1 .0))
     };
-    let embed_sparse_future =
-        move |infer: Extension<Infer>, info: Extension<Info>, req: EmbedSparseRequest| async move {
-            let result = embed_sparse(infer, info, Json(req)).await?;
-            Ok(VertexPrediction::EmbedSparse(result.1 .0))
-        };
-    let predict_future =
-        move |infer: Extension<Infer>, info: Extension<Info>, req: PredictRequest| async move {
-            let result = predict(infer, info, Json(req)).await?;
-            Ok(VertexPrediction::Predict(result.1 .0))
-        };
-    let rerank_future =
-        move |infer: Extension<Infer>, info: Extension<Info>, req: RerankRequest| async move {
-            let result = rerank(infer, info, Json(req)).await?;
-            Ok(VertexPrediction::Rerank(result.1 .0))
-        };
+    let embed_sparse_future = move |infer: Extension<Infer>,
+                                    info: Extension<Info>,
+                                    context: Extension<Option<opentelemetry::Context>>,
+                                    req: EmbedSparseRequest| async move {
+        let result = embed_sparse(infer, info, context, Json(req)).await?;
+        Ok(VertexPrediction::EmbedSparse(result.1 .0))
+    };
+    let predict_future = move |infer: Extension<Infer>,
+                               info: Extension<Info>,
+                               context: Extension<Option<opentelemetry::Context>>,
+                               req: PredictRequest| async move {
+        let result = predict(infer, info, context, Json(req)).await?;
+        Ok(VertexPrediction::Predict(result.1 .0))
+    };
+    let rerank_future = move |infer: Extension<Infer>,
+                              info: Extension<Info>,
+                              context: Extension<Option<opentelemetry::Context>>,
+                              req: RerankRequest| async move {
+        let result = rerank(infer, info, context, Json(req)).await?;
+        Ok(VertexPrediction::Rerank(result.1 .0))
+    };

     let mut futures = Vec::with_capacity(req.instances.len());
     for instance in req.instances {
         let local_infer = infer.clone();
         let local_info = info.clone();
+        let local_context = context.clone();

         // Rerank is the only payload that can be matched safely
         if let Ok(instance) = serde_json::from_value::<RerankRequest>(instance.clone()) {
-            futures.push(rerank_future(local_infer, local_info, instance).boxed());
+            futures.push(rerank_future(local_infer, local_info, local_context, instance).boxed());
             continue;
         }

         match info.model_type {
             ModelType::Classifier(_) | ModelType::Reranker(_) => {
                 let instance = serde_json::from_value::<PredictRequest>(instance)
                     .map_err(ErrorResponse::from)?;
-                futures.push(predict_future(local_infer, local_info, instance).boxed());
+                futures
+                    .push(predict_future(local_infer, local_info, local_context, instance).boxed());
             }
             ModelType::Embedding(_) => {
                 if infer.is_splade() {
                     let instance = serde_json::from_value::<EmbedSparseRequest>(instance)
                         .map_err(ErrorResponse::from)?;
-                    futures.push(embed_sparse_future(local_infer, local_info, instance).boxed());
+                    futures.push(
+                        embed_sparse_future(local_infer, local_info, local_context, instance)
+                            .boxed(),
+                    );
                 } else {
                     let instance = serde_json::from_value::<EmbedRequest>(instance)
                         .map_err(ErrorResponse::from)?;
-                    futures.push(embed_future(local_infer, local_info, instance).boxed());
+                    futures.push(
+                        embed_future(local_infer, local_info, local_context, instance).boxed(),
+                    );
                 }
             }
         }
@@ -1784,6 +1833,7 @@ pub async fn run(
         .layer(Extension(info))
         .layer(Extension(prom_handle.clone()))
         .layer(OtelAxumLayer::default())
+        .layer(axum::middleware::from_fn(logging::trace_context_middleware))
        .layer(DefaultBodyLimit::max(payload_limit))
        .layer(cors_layer);
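
In isolation, the handler-side pattern repeated throughout server.rs looks like the following. This is a minimal sketch, not code from the commit: it assumes the same axum, tracing, tracing-opentelemetry, and opentelemetry crates the router uses, and the hello handler is hypothetical.

use axum::{extract::Extension, Json};
use tracing_opentelemetry::OpenTelemetrySpanExt;

// Hypothetical handler showing the pattern added to every route above:
// the middleware stores an Option<opentelemetry::Context> in the request
// extensions, and the handler re-parents its tracing span onto it.
async fn hello(
    Extension(context): Extension<Option<opentelemetry::Context>>,
) -> Json<&'static str> {
    let span = tracing::Span::current();
    // If the caller sent a valid traceparent header, this span joins the
    // remote trace; otherwise it remains a local root span.
    if let Some(context) = context {
        span.set_parent(context);
    }
    Json("hello")
}

Because the extension is an Option, routes keep working unchanged for callers that send no trace context.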

router/src/logging.rs
+55 −0

@@ -1,3 +1,6 @@
+use axum::{extract::Request, middleware::Next, response::Response};
+use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId};
+use opentelemetry::Context;
 use opentelemetry::{global, KeyValue};
 use opentelemetry_otlp::WithExportConfig;
 use opentelemetry_sdk::propagation::TraceContextPropagator;
@@ -7,6 +10,58 @@ use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
 use tracing_subscriber::{EnvFilter, Layer};

+struct TraceParent {
+    #[allow(dead_code)]
+    version: u8,
+    trace_id: TraceId,
+    parent_id: SpanId,
+    trace_flags: TraceFlags,
+}
+
+fn parse_traceparent(header_value: &str) -> Option<TraceParent> {
+    let parts: Vec<&str> = header_value.split('-').collect();
+    if parts.len() != 4 {
+        return None;
+    }
+
+    let version = u8::from_str_radix(parts[0], 16).ok()?;
+    if version == 0xff {
+        return None;
+    }
+
+    let trace_id = TraceId::from_hex(parts[1]).ok()?;
+    let parent_id = SpanId::from_hex(parts[2]).ok()?;
+    let trace_flags = u8::from_str_radix(parts[3], 16).ok()?;
+
+    Some(TraceParent {
+        version,
+        trace_id,
+        parent_id,
+        trace_flags: TraceFlags::new(trace_flags),
+    })
+}
+
+pub async fn trace_context_middleware(mut request: Request, next: Next) -> Response {
+    let context = request
+        .headers()
+        .get("traceparent")
+        .and_then(|v| v.to_str().ok())
+        .and_then(parse_traceparent)
+        .map(|traceparent| {
+            Context::new().with_remote_span_context(SpanContext::new(
+                traceparent.trace_id,
+                traceparent.parent_id,
+                traceparent.trace_flags,
+                true,
+                Default::default(),
+            ))
+        });
+
+    request.extensions_mut().insert(context);
+
+    next.run(request).await
+}
+
 /// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
 /// - otlp_endpoint is an optional URL to an Open Telemetry collector
 /// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
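
For reference, the W3C traceparent header parsed above has the form version-trace_id-parent_id-flags: four hex fields separated by dashes. A small test sketch (not part of the commit) exercising parse_traceparent with a value in the spec's example style could look like this:

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_w3c_traceparent() {
        // version "00", 16-byte trace id, 8-byte parent span id,
        // flags "01" (sampled).
        let header = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
        let parsed = parse_traceparent(header).expect("valid traceparent");
        assert_eq!(
            parsed.trace_id,
            TraceId::from_hex("0af7651916cd43dd8448eb211c80319c").unwrap()
        );
        assert_eq!(parsed.parent_id, SpanId::from_hex("b7ad6b7169203331").unwrap());
        assert!(parsed.trace_flags.is_sampled());

        // Malformed values and the reserved version 0xff are rejected.
        assert!(parse_traceparent("not-a-traceparent").is_none());
        assert!(
            parse_traceparent("ff-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01")
                .is_none()
        );
    }
}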

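On the wire, a caller joins its trace to the router's by sending the header itself. A hypothetical client sketch using reqwest (not part of this commit, and assuming the router's /embed route on its default port 8080; the traceparent value is hard-coded for illustration):

use reqwest::Client;

// Hypothetical caller: a service with a live span would normally
// serialize that span into the traceparent value instead.
async fn call_router() -> Result<(), reqwest::Error> {
    let body = serde_json::json!({ "inputs": "What is Deep Learning?" });
    let response = Client::new()
        .post("http://localhost:8080/embed")
        .header(
            "traceparent",
            "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
        )
        .json(&body)
        .send()
        .await?;
    // With this commit, the router's /embed span becomes a child of span
    // b7ad6b7169203331 in trace 0af7651916cd43dd8448eb211c80319c.
    println!("status: {}", response.status());
    Ok(())
}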