@@ -18,6 +18,9 @@ import (
18
18
// per model.
19
19
const maximumRecordsPerModel = 10
20
20
21
+ // subscriberChannelBuffer is the buffer size for subscriber channels.
22
+ const subscriberChannelBuffer = 100
23
+
21
24
type responseRecorder struct {
22
25
http.ResponseWriter
23
26
body * bytes.Buffer
@@ -71,7 +74,7 @@ type OpenAIRecorder struct {
71
74
m sync.RWMutex
72
75
73
76
// streaming
74
- subscribers map [string ]chan * RequestResponsePair
77
+ subscribers map [string ]chan [] ModelRecordsResponse
75
78
subMutex sync.RWMutex
76
79
}
77
80
@@ -80,7 +83,7 @@ func NewOpenAIRecorder(log logging.Logger, modelManager *models.Manager) *OpenAI
80
83
log : log ,
81
84
modelManager : modelManager ,
82
85
records : make (map [string ]* ModelData ),
83
- subscribers : make (map [string ]chan * RequestResponsePair ),
86
+ subscribers : make (map [string ]chan [] ModelRecordsResponse ),
84
87
}
85
88
}
86
89
@@ -193,7 +196,18 @@ func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter
193
196
record .Response = response
194
197
record .Error = "" // Ensure Error is empty for successful responses
195
198
}
196
- go r .broadcastToSubscribers (record )
199
+ // Create ModelRecordsResponse with this single updated record to match
200
+ // what the non-streaming endpoint returns - []ModelRecordsResponse.
201
+ // See getAllRecords and getRecordsByModel.
202
+ modelResponse := []ModelRecordsResponse {{
203
+ Count : 1 ,
204
+ Model : model ,
205
+ ModelData : ModelData {
206
+ Config : modelData .Config ,
207
+ Records : []* RequestResponsePair {record },
208
+ },
209
+ }}
210
+ go r .broadcastToSubscribers (modelResponse )
197
211
return
198
212
}
199
213
}
@@ -335,7 +349,7 @@ func (r *OpenAIRecorder) handleStreamingRequests(w http.ResponseWriter, req *htt
335
349
336
350
// Create subscriber channel.
337
351
subscriberID := fmt .Sprintf ("sub_%d" , time .Now ().UnixNano ())
338
- ch := make (chan * RequestResponsePair , 100 )
352
+ ch := make (chan [] ModelRecordsResponse , subscriberChannelBuffer )
339
353
340
354
// Register subscriber.
341
355
r .subMutex .Lock ()
@@ -368,18 +382,18 @@ func (r *OpenAIRecorder) handleStreamingRequests(w http.ResponseWriter, req *htt
368
382
369
383
for {
370
384
select {
371
- case record , ok := <- ch :
385
+ case modelRecords , ok := <- ch :
372
386
if ! ok {
373
387
return
374
388
}
375
389
376
390
// Filter by model if specified.
377
- if model != "" && record .Model != model {
391
+ if model != "" && len ( modelRecords ) > 0 && modelRecords [ 0 ] .Model != model {
378
392
continue
379
393
}
380
394
381
395
// Send as SSE event.
382
- jsonData , err := json .Marshal (record )
396
+ jsonData , err := json .Marshal (modelRecords )
383
397
if err != nil {
384
398
continue
385
399
}
@@ -438,13 +452,13 @@ func (r *OpenAIRecorder) getRecordsByModel(model string) []ModelRecordsResponse
438
452
return nil
439
453
}
440
454
441
- func (r * OpenAIRecorder ) broadcastToSubscribers (record * RequestResponsePair ) {
455
+ func (r * OpenAIRecorder ) broadcastToSubscribers (modelResponses [] ModelRecordsResponse ) {
442
456
r .subMutex .RLock ()
443
457
defer r .subMutex .RUnlock ()
444
458
445
459
for _ , ch := range r .subscribers {
446
460
select {
447
- case ch <- record :
461
+ case ch <- modelResponses :
448
462
default :
449
463
// The channel is full, skip this subscriber.
450
464
}
@@ -461,13 +475,24 @@ func (r *OpenAIRecorder) sendExistingRecords(w http.ResponseWriter, model string
461
475
}
462
476
463
477
if records != nil {
464
- for _ , modelResponse := range records {
465
- for _ , record := range modelResponse .Records {
466
- jsonData , err := json .Marshal (record )
467
- if err != nil {
468
- continue
478
+ // Send each individual request-response pair as a separate event.
479
+ for _ , modelRecord := range records {
480
+ for _ , requestRecord := range modelRecord .Records {
481
+ // Create a ModelRecordsResponse with a single record to match
482
+ // what the non-streaming endpoint returns - []ModelRecordsResponse.
483
+ // See getAllRecords and getRecordsByModel.
484
+ singleRecord := []ModelRecordsResponse {{
485
+ Count : 1 ,
486
+ Model : modelRecord .Model ,
487
+ ModelData : ModelData {
488
+ Config : modelRecord .Config ,
489
+ Records : []* RequestResponsePair {requestRecord },
490
+ },
491
+ }}
492
+ jsonData , err := json .Marshal (singleRecord )
493
+ if err == nil {
494
+ fmt .Fprintf (w , "event: existing_request\n data: %s\n \n " , jsonData )
469
495
}
470
- fmt .Fprintf (w , "event: existing_request\n data: %s\n \n " , jsonData )
471
496
}
472
497
}
473
498
}
0 commit comments