diff --git a/router/src/lib.rs b/router/src/lib.rs
index 8a218874..6aa50f3c 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -95,6 +95,16 @@ pub(crate) struct EmbedRequest {
 #[schema(example = json!([["0.0", "1.0", "2.0"]]))]
 pub(crate) struct EmbedResponse(Vec<Vec<f32>>);
 
+#[derive(Deserialize)]
+pub(crate) struct HFECompatRequest {
+    pub inputs: String,
+}
+
+#[derive(Serialize)]
+pub(crate) struct HFECompatResponse {
+    embeddings: Vec<f32>,
+}
+
 #[derive(Serialize, ToSchema)]
 pub(crate) enum ErrorType {
     Unhealthy,
diff --git a/router/src/server.rs b/router/src/server.rs
index 2cf9cdf5..496161db 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,7 +1,8 @@
 /// HTTP Server logic
 use crate::{
-    EmbedRequest, EmbedResponse, ErrorResponse, ErrorType, Info, Input, OpenAICompatEmbedding,
-    OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage,
+    EmbedRequest, EmbedResponse, ErrorResponse, ErrorType, HFECompatRequest, HFECompatResponse,
+    Info, Input, OpenAICompatEmbedding, OpenAICompatErrorResponse, OpenAICompatRequest,
+    OpenAICompatResponse, OpenAICompatUsage,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -433,6 +434,93 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()
 }
 
+/// Huggingface Inference Endpoint compatibility route
+async fn hfe_embed(
+    infer: Extension<Infer>,
+    Json(req): Json<HFECompatRequest>,
+) -> Result<(HeaderMap, Json<HFECompatResponse>), (StatusCode, Json<ErrorResponse>)> {
+    let span = tracing::Span::current();
+    let start_time = Instant::now();
+
+    metrics::increment_counter!("te_request_count", "method" => "single");
+
+    let compute_chars = req.inputs.chars().count();
+
+    let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
+    let response = infer
+        .embed(req.inputs, false, permit)
+        .await
+        .map_err(ErrorResponse::from)?;
+
+    metrics::increment_counter!("te_request_success", "method" => "single");
+
+    let compute_tokens = response.prompt_tokens;
+    let tokenization_time = response.tokenization;
+    let queue_time = response.queue;
+    let inference_time = response.inference;
+
+    let total_time = start_time.elapsed();
+
+    // Tracing metadata
+    span.record("total_time", format!("{total_time:?}"));
+    span.record("tokenization_time", format!("{tokenization_time:?}"));
+    span.record("queue_time", format!("{queue_time:?}"));
+    span.record("inference_time", format!("{inference_time:?}"));
+
+    // Headers
+    let mut headers = HeaderMap::new();
+    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
+    headers.insert(
+        "x-compute-time",
+        total_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-compute-characters",
+        compute_chars.to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-compute-tokens",
+        compute_tokens.to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-total-time",
+        total_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-tokenization-time",
+        tokenization_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-queue-time",
+        queue_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-inference-time",
+        inference_time.as_millis().to_string().parse().unwrap(),
+    );
+
+    // Metrics
+    metrics::histogram!("te_request_duration", total_time.as_secs_f64());
+    metrics::histogram!(
+        "te_request_tokenization_duration",
+        tokenization_time.as_secs_f64()
+    );
+    metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64());
+    metrics::histogram!(
+        "te_request_inference_duration",
+        inference_time.as_secs_f64()
+    );
+
+    tracing::info!("Success");
+
+    Ok((
+        headers,
+        Json(HFECompatResponse {
+            embeddings: response.embeddings,
+        }),
+    ))
+}
+
 /// Serving method
 pub async fn run(
     infer: Infer,
@@ -529,7 +617,6 @@ pub async fn run(
     let app = Router::new()
         .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
         // Base routes
-        .route("/", post(embed))
        .route("/info", get(get_model_info))
         .route("/embed", post(embed))
         // OpenAI compat route
@@ -543,7 +630,21 @@
         // AWS Sagemaker health route
         .route("/ping", get(health))
         // Prometheus metrics route
-        .route("/metrics", get(metrics))
+        .route("/metrics", get(metrics));
+
+    let app = if &std::env::var("HFE_COMPATIBILITY")
+        .unwrap_or("False".to_string())
+        .to_lowercase()
+        == "true"
+    {
+        // HuggingFace endpoint compatibility API
+        app.route("/", post(hfe_embed))
+    } else {
+        // Default API
+        app.route("/", post(embed))
+    };
+
+    let app = app
         .layer(Extension(infer))
         .layer(Extension(info))
         .layer(Extension(prom_handle.clone()))
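
For quick verification of the new contract, a minimal client-side sketch follows. It assumes the router was started with HFE_COMPATIBILITY=true and listens on localhost:8080, and that the client has reqwest (with the "blocking" and "json" features) and serde available; the port, dependency choices, and client-side struct names are illustrative assumptions, not taken from this patch.

// Client sketch: POST a single input to "/" and read back one embedding
// vector, mirroring HFECompatRequest / HFECompatResponse from the diff.
use serde::{Deserialize, Serialize};

#[derive(Serialize)]
struct HfeRequest {
    // Matches the server's HFECompatRequest: a single input string.
    inputs: String,
}

#[derive(Deserialize)]
struct HfeResponse {
    // Matches HFECompatResponse: one embedding vector for the input.
    embeddings: Vec<f32>,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Assumed address; adjust to wherever the router is listening.
    let client = reqwest::blocking::Client::new();
    let response = client
        .post("http://localhost:8080/")
        .json(&HfeRequest {
            inputs: "What is Deep Learning?".to_string(),
        })
        .send()?
        .error_for_status()?;
    let body: HfeResponse = response.json()?;
    println!("embedding dimension: {}", body.embeddings.len());
    Ok(())
}

Note that the handler also sets the x-compute-* and x-*-time response headers shown in the diff, so a client can read per-request timing directly instead of scraping /metrics.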