Skip to content

Commit 14e9c6c

Browse files
authored
Merge pull request #157 from doringeman/requests-sse
feat: add streaming endpoint for inference requests
2 parents 38bb017 + 296301f commit 14e9c6c

File tree

1 file changed

+185
-23
lines changed

1 file changed

+185
-23
lines changed

pkg/metrics/openai_recorder.go

Lines changed: 185 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ import (
1818
// per model.
1919
const maximumRecordsPerModel = 10
2020

21+
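// subscriberChannelBuffer is the capacity of the channel created for each
// streaming (SSE) subscriber. Broadcasts use a non-blocking send, so an update
// that arrives while a subscriber's buffer is full is dropped for that subscriber.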
22+
const subscriberChannelBuffer = 100
23+
2124
type responseRecorder struct {
2225
http.ResponseWriter
2326
body *bytes.Buffer
@@ -69,13 +72,18 @@ type OpenAIRecorder struct {
6972
records map[string]*ModelData // key is model ID
7073
modelManager *models.Manager // for resolving model tags to IDs
7174
m sync.RWMutex
75+
76+
// streaming
77+
subscribers map[string]chan []ModelRecordsResponse
78+
subMutex sync.RWMutex
7279
}
7380

7481
func NewOpenAIRecorder(log logging.Logger, modelManager *models.Manager) *OpenAIRecorder {
7582
return &OpenAIRecorder{
7683
log: log,
7784
modelManager: modelManager,
7885
records: make(map[string]*ModelData),
86+
subscribers: make(map[string]chan []ModelRecordsResponse),
7987
}
8088
}
8189

@@ -188,6 +196,18 @@ func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter
188196
record.Response = response
189197
record.Error = "" // Ensure Error is empty for successful responses
190198
}
199+
// Create ModelRecordsResponse with this single updated record to match
200+
// what the non-streaming endpoint returns - []ModelRecordsResponse.
201+
// See getAllRecords and getRecordsByModel.
202+
modelResponse := []ModelRecordsResponse{{
203+
Count: 1,
204+
Model: model,
205+
ModelData: ModelData{
206+
Config: modelData.Config,
207+
Records: []*RequestResponsePair{record},
208+
},
209+
}}
210+
go r.broadcastToSubscribers(modelResponse)
191211
return
192212
}
193213
}
@@ -274,36 +294,124 @@ func (r *OpenAIRecorder) convertStreamingResponse(streamingBody string) string {
274294

275295
func (r *OpenAIRecorder) GetRecordsHandler() http.HandlerFunc {
276296
return func(w http.ResponseWriter, req *http.Request) {
277-
w.Header().Set("Content-Type", "application/json")
297+
acceptHeader := req.Header.Get("Accept")
298+
299+
// Check if client wants Server-Sent Events
300+
if acceptHeader == "text/event-stream" {
301+
r.handleStreamingRequests(w, req)
302+
return
303+
}
304+
305+
// Default to JSON response
306+
r.handleJSONRequests(w, req)
307+
}
308+
}
309+
310+
func (r *OpenAIRecorder) handleJSONRequests(w http.ResponseWriter, req *http.Request) {
311+
w.Header().Set("Content-Type", "application/json")
312+
313+
model := req.URL.Query().Get("model")
314+
315+
if model == "" {
316+
// Retrieve all records for all models.
317+
allRecords := r.getAllRecords()
318+
if allRecords == nil {
319+
allRecords = []ModelRecordsResponse{}
320+
}
321+
if err := json.NewEncoder(w).Encode(allRecords); err != nil {
322+
http.Error(w, fmt.Sprintf("Failed to encode all records: %v", err),
323+
http.StatusInternalServerError)
324+
return
325+
}
326+
} else {
327+
// Retrieve records for the specified model.
328+
records := r.getRecordsByModel(model)
329+
if records == nil {
330+
records = []ModelRecordsResponse{}
331+
}
332+
if err := json.NewEncoder(w).Encode(records); err != nil {
333+
http.Error(w, fmt.Sprintf("Failed to encode records for model '%s': %v", model, err),
334+
http.StatusInternalServerError)
335+
return
336+
}
337+
}
338+
}
278339

279-
model := req.URL.Query().Get("model")
340+
func (r *OpenAIRecorder) handleStreamingRequests(w http.ResponseWriter, req *http.Request) {
341+
// Set SSE headers.
342+
w.Header().Set("Content-Type", "text/event-stream")
343+
w.Header().Set("Cache-Control", "no-cache")
344+
w.Header().Set("Connection", "keep-alive")
345+
346+
// Create subscriber channel.
347+
subscriberID := fmt.Sprintf("sub_%d", time.Now().UnixNano())
348+
ch := make(chan []ModelRecordsResponse, subscriberChannelBuffer)
349+
350+
// Register subscriber.
351+
r.subMutex.Lock()
352+
r.subscribers[subscriberID] = ch
353+
r.subMutex.Unlock()
354+
355+
// Clean up on disconnect.
356+
defer func() {
357+
r.subMutex.Lock()
358+
delete(r.subscribers, subscriberID)
359+
close(ch)
360+
r.subMutex.Unlock()
361+
}()
362+
363+
// Optional: Send existing records first.
364+
model := req.URL.Query().Get("model")
365+
if includeExisting := req.URL.Query().Get("include_existing"); includeExisting == "true" {
366+
r.sendExistingRecords(w, model)
367+
}
280368

281-
if model == "" {
282-
// Retrieve all records for all models.
283-
allRecords := r.getAllRecords()
284-
if allRecords == nil {
285-
// No records found.
286-
http.Error(w, "No records found", http.StatusNotFound)
369+
flusher, ok := w.(http.Flusher)
370+
if !ok {
371+
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
372+
return
373+
}
374+
375+
// Send heartbeat to establish connection.
376+
if _, err := fmt.Fprintf(w, "event: connected\ndata: {\"status\": \"connected\"}\n\n"); err != nil {
377+
r.log.Errorf("Failed to write connected event to response: %v", err)
378+
}
379+
flusher.Flush()
380+
381+
for {
382+
select {
383+
case modelRecords, ok := <-ch:
384+
if !ok {
287385
return
288386
}
289-
if err := json.NewEncoder(w).Encode(allRecords); err != nil {
290-
http.Error(w, fmt.Sprintf("Failed to encode all records: %v", err),
291-
http.StatusInternalServerError)
292-
return
387+
388+
// Filter by model if specified.
389+
// modelRecords is assumed to have size 1 because that's how we call broadcastToSubscribers.
390+
// We do this so we don't need to query a 2nd time for the model config.
391+
if model != "" && len(modelRecords) > 0 && modelRecords[0].Model != model {
392+
continue
293393
}
294-
} else {
295-
// Retrieve records for the specified model.
296-
records := r.getRecordsByModel(model)
297-
if records == nil {
298-
// No records found for the specified model.
299-
http.Error(w, fmt.Sprintf("No records found for model '%s'", model), http.StatusNotFound)
300-
return
394+
395+
// Send as SSE event.
396+
jsonData, err := json.Marshal(modelRecords)
397+
if err != nil {
398+
r.log.Errorf("Failed to marshal record for streaming: %v", err)
399+
errorMsg := fmt.Sprintf(`{"error": "Failed to marshal record: %v"}`, err)
400+
if _, writeErr := fmt.Fprintf(w, "event: error\ndata: %s\n\n", errorMsg); writeErr != nil {
401+
r.log.Errorf("Failed to write error event to response: %v", writeErr)
402+
}
403+
flusher.Flush()
404+
continue
301405
}
302-
if err := json.NewEncoder(w).Encode(records); err != nil {
303-
http.Error(w, fmt.Sprintf("Failed to encode records for model '%s': %v", model, err),
304-
http.StatusInternalServerError)
305-
return
406+
407+
if _, err := fmt.Fprintf(w, "event: new_request\ndata: %s\n\n", jsonData); err != nil {
408+
r.log.Errorf("Failed to write new_request event to response: %v", err)
306409
}
410+
flusher.Flush()
411+
412+
case <-req.Context().Done():
413+
// Client disconnected.
414+
return
307415
}
308416
}
309417
}
@@ -352,6 +460,60 @@ func (r *OpenAIRecorder) getRecordsByModel(model string) []ModelRecordsResponse
352460
return nil
353461
}
354462

463+
func (r *OpenAIRecorder) broadcastToSubscribers(modelResponses []ModelRecordsResponse) {
464+
r.subMutex.RLock()
465+
defer r.subMutex.RUnlock()
466+
467+
for _, ch := range r.subscribers {
468+
select {
469+
case ch <- modelResponses:
470+
default:
471+
// The channel is full, skip this subscriber.
472+
}
473+
}
474+
}
475+
476+
func (r *OpenAIRecorder) sendExistingRecords(w http.ResponseWriter, model string) {
477+
var records []ModelRecordsResponse
478+
479+
if model == "" {
480+
records = r.getAllRecords()
481+
} else {
482+
records = r.getRecordsByModel(model)
483+
}
484+
485+
if records != nil {
486+
// Send each individual request-response pair as a separate event.
487+
for _, modelRecord := range records {
488+
for _, requestRecord := range modelRecord.Records {
489+
// Create a ModelRecordsResponse with a single record to match
490+
// what the non-streaming endpoint returns - []ModelRecordsResponse.
491+
// See getAllRecords and getRecordsByModel.
492+
singleRecord := []ModelRecordsResponse{{
493+
Count: 1,
494+
Model: modelRecord.Model,
495+
ModelData: ModelData{
496+
Config: modelRecord.Config,
497+
Records: []*RequestResponsePair{requestRecord},
498+
},
499+
}}
500+
jsonData, err := json.Marshal(singleRecord)
501+
if err != nil {
502+
r.log.Errorf("Failed to marshal existing record for streaming: %v", err)
503+
errorMsg := fmt.Sprintf(`{"error": "Failed to marshal existing record: %v"}`, err)
504+
if _, writeErr := fmt.Fprintf(w, "event: error\ndata: %s\n\n", errorMsg); writeErr != nil {
505+
r.log.Errorf("Failed to write error event to response: %v", writeErr)
506+
}
507+
} else {
508+
if _, writeErr := fmt.Fprintf(w, "event: existing_request\ndata: %s\n\n", jsonData); writeErr != nil {
509+
r.log.Errorf("Failed to write existing_request event to response: %v", writeErr)
510+
}
511+
}
512+
}
513+
}
514+
}
515+
}
516+
355517
func (r *OpenAIRecorder) RemoveModel(model string) {
356518
modelID := r.modelManager.ResolveModelID(model)
357519

0 commit comments

Comments
 (0)