@@ -18,6 +18,9 @@ import (
// per model.
const maximumRecordsPerModel = 10

+// subscriberChannelBuffer is the buffer size for subscriber channels.
+const subscriberChannelBuffer = 100
+
type responseRecorder struct {
    http.ResponseWriter
    body *bytes.Buffer
@@ -69,13 +72,18 @@ type OpenAIRecorder struct {
    records      map[string]*ModelData // key is model ID
    modelManager *models.Manager       // for resolving model tags to IDs
    m            sync.RWMutex
+
+   // streaming
+   subscribers map[string]chan []ModelRecordsResponse
+   subMutex    sync.RWMutex
}

func NewOpenAIRecorder(log logging.Logger, modelManager *models.Manager) *OpenAIRecorder {
    return &OpenAIRecorder{
        log:          log,
        modelManager: modelManager,
        records:      make(map[string]*ModelData),
+       subscribers:  make(map[string]chan []ModelRecordsResponse),
    }
}

@@ -188,6 +196,18 @@ func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter
            record.Response = response
            record.Error = "" // Ensure Error is empty for successful responses
        }
+       // Create ModelRecordsResponse with this single updated record to match
+       // what the non-streaming endpoint returns - []ModelRecordsResponse.
+       // See getAllRecords and getRecordsByModel.
+       modelResponse := []ModelRecordsResponse{{
+           Count: 1,
+           Model: model,
+           ModelData: ModelData{
+               Config:  modelData.Config,
+               Records: []*RequestResponsePair{record},
+           },
+       }}
+       go r.broadcastToSubscribers(modelResponse)
        return
    }
}
@@ -274,36 +294,124 @@ func (r *OpenAIRecorder) convertStreamingResponse(streamingBody string) string {

func (r *OpenAIRecorder) GetRecordsHandler() http.HandlerFunc {
    return func(w http.ResponseWriter, req *http.Request) {
-       w.Header().Set("Content-Type", "application/json")
+       acceptHeader := req.Header.Get("Accept")
+
+       // Check if client wants Server-Sent Events
+       if acceptHeader == "text/event-stream" {
+           r.handleStreamingRequests(w, req)
+           return
+       }
+
+       // Default to JSON response
+       r.handleJSONRequests(w, req)
+   }
+}
+
+func (r *OpenAIRecorder) handleJSONRequests(w http.ResponseWriter, req *http.Request) {
+   w.Header().Set("Content-Type", "application/json")
+
+   model := req.URL.Query().Get("model")
+
+   if model == "" {
+       // Retrieve all records for all models.
+       allRecords := r.getAllRecords()
+       if allRecords == nil {
+           allRecords = []ModelRecordsResponse{}
+       }
+       if err := json.NewEncoder(w).Encode(allRecords); err != nil {
+           http.Error(w, fmt.Sprintf("Failed to encode all records: %v", err),
+               http.StatusInternalServerError)
+           return
+       }
+   } else {
+       // Retrieve records for the specified model.
+       records := r.getRecordsByModel(model)
+       if records == nil {
+           records = []ModelRecordsResponse{}
+       }
+       if err := json.NewEncoder(w).Encode(records); err != nil {
+           http.Error(w, fmt.Sprintf("Failed to encode records for model '%s': %v", model, err),
+               http.StatusInternalServerError)
+           return
+       }
+   }
+}

-       model := req.URL.Query().Get("model")
+func (r *OpenAIRecorder) handleStreamingRequests(w http.ResponseWriter, req *http.Request) {
+   // Set SSE headers.
+   w.Header().Set("Content-Type", "text/event-stream")
+   w.Header().Set("Cache-Control", "no-cache")
+   w.Header().Set("Connection", "keep-alive")
+
+   // Create subscriber channel.
+   subscriberID := fmt.Sprintf("sub_%d", time.Now().UnixNano())
+   ch := make(chan []ModelRecordsResponse, subscriberChannelBuffer)
+
+   // Register subscriber.
+   r.subMutex.Lock()
+   r.subscribers[subscriberID] = ch
+   r.subMutex.Unlock()
+
+   // Clean up on disconnect.
+   defer func() {
+       r.subMutex.Lock()
+       delete(r.subscribers, subscriberID)
+       close(ch)
+       r.subMutex.Unlock()
+   }()
+
+   // Optional: Send existing records first.
+   model := req.URL.Query().Get("model")
+   if includeExisting := req.URL.Query().Get("include_existing"); includeExisting == "true" {
+       r.sendExistingRecords(w, model)
+   }

-       if model == "" {
-           // Retrieve all records for all models.
-           allRecords := r.getAllRecords()
-           if allRecords == nil {
-               // No records found.
-               http.Error(w, "No records found", http.StatusNotFound)
+   flusher, ok := w.(http.Flusher)
+   if !ok {
+       http.Error(w, "Streaming not supported", http.StatusInternalServerError)
+       return
+   }
+
+   // Send heartbeat to establish connection.
+   if _, err := fmt.Fprintf(w, "event: connected\ndata: {\"status\": \"connected\"}\n\n"); err != nil {
+       r.log.Errorf("Failed to write connected event to response: %v", err)
+   }
+   flusher.Flush()
+
+   for {
+       select {
+       case modelRecords, ok := <-ch:
+           if !ok {
                return
            }
-           if err := json.NewEncoder(w).Encode(allRecords); err != nil {
-               http.Error(w, fmt.Sprintf("Failed to encode all records: %v", err),
-                   http.StatusInternalServerError)
-               return
+
+           // Filter by model if specified.
+           // modelRecords is assumed to have size 1 because that's how we call broadcastToSubscribers.
+           // We do this so we don't need to query a 2nd time for the model config.
+           if model != "" && len(modelRecords) > 0 && modelRecords[0].Model != model {
+               continue
            }
-       } else {
-           // Retrieve records for the specified model.
-           records := r.getRecordsByModel(model)
-           if records == nil {
-               // No records found for the specified model.
-               http.Error(w, fmt.Sprintf("No records found for model '%s'", model), http.StatusNotFound)
-               return
+
+           // Send as SSE event.
+           jsonData, err := json.Marshal(modelRecords)
+           if err != nil {
+               r.log.Errorf("Failed to marshal record for streaming: %v", err)
+               errorMsg := fmt.Sprintf(`{"error": "Failed to marshal record: %v"}`, err)
+               if _, writeErr := fmt.Fprintf(w, "event: error\ndata: %s\n\n", errorMsg); writeErr != nil {
+                   r.log.Errorf("Failed to write error event to response: %v", writeErr)
+               }
+               flusher.Flush()
+               continue
            }
-           if err := json.NewEncoder(w).Encode(records); err != nil {
-               http.Error(w, fmt.Sprintf("Failed to encode records for model '%s': %v", model, err),
-                   http.StatusInternalServerError)
-               return
+
+           if _, err := fmt.Fprintf(w, "event: new_request\ndata: %s\n\n", jsonData); err != nil {
+               r.log.Errorf("Failed to write new_request event to response: %v", err)
            }
+           flusher.Flush()
+
+       case <-req.Context().Done():
+           // Client disconnected.
+           return
        }
    }
}
@@ -352,6 +460,60 @@ func (r *OpenAIRecorder) getRecordsByModel(model string) []ModelRecordsResponse
    return nil
}

+func (r *OpenAIRecorder) broadcastToSubscribers(modelResponses []ModelRecordsResponse) {
+   r.subMutex.RLock()
+   defer r.subMutex.RUnlock()
+
+   for _, ch := range r.subscribers {
+       select {
+       case ch <- modelResponses:
+       default:
+           // The channel is full, skip this subscriber.
+       }
+   }
+}
+
+func (r *OpenAIRecorder) sendExistingRecords(w http.ResponseWriter, model string) {
+   var records []ModelRecordsResponse
+
+   if model == "" {
+       records = r.getAllRecords()
+   } else {
+       records = r.getRecordsByModel(model)
+   }
+
+   if records != nil {
+       // Send each individual request-response pair as a separate event.
+       for _, modelRecord := range records {
+           for _, requestRecord := range modelRecord.Records {
+               // Create a ModelRecordsResponse with a single record to match
+               // what the non-streaming endpoint returns - []ModelRecordsResponse.
+               // See getAllRecords and getRecordsByModel.
+               singleRecord := []ModelRecordsResponse{{
+                   Count: 1,
+                   Model: modelRecord.Model,
+                   ModelData: ModelData{
+                       Config:  modelRecord.Config,
+                       Records: []*RequestResponsePair{requestRecord},
+                   },
+               }}
+               jsonData, err := json.Marshal(singleRecord)
+               if err != nil {
+                   r.log.Errorf("Failed to marshal existing record for streaming: %v", err)
+                   errorMsg := fmt.Sprintf(`{"error": "Failed to marshal existing record: %v"}`, err)
+                   if _, writeErr := fmt.Fprintf(w, "event: error\ndata: %s\n\n", errorMsg); writeErr != nil {
+                       r.log.Errorf("Failed to write error event to response: %v", writeErr)
+                   }
+               } else {
+                   if _, writeErr := fmt.Fprintf(w, "event: existing_request\ndata: %s\n\n", jsonData); writeErr != nil {
+                       r.log.Errorf("Failed to write existing_request event to response: %v", writeErr)
+                   }
+               }
+           }
+       }
+   }
+}
+

func (r *OpenAIRecorder) RemoveModel(model string) {
    modelID := r.modelManager.ResolveModelID(model)