Skip to content

Commit d61cffb

Browse files
committed
refactor: on streaming return the same struct as on regular
Signed-off-by: Dorin Geman <[email protected]>
1 parent a0ddf2e commit d61cffb

File tree

1 file changed

+40
-15
lines changed

1 file changed

+40
-15
lines changed

pkg/metrics/openai_recorder.go

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ import (
1818
// per model.
1919
const maximumRecordsPerModel = 10
2020

21+
// subscriberChannelBuffer is the buffer size for subscriber channels.
22+
const subscriberChannelBuffer = 100
23+
2124
type responseRecorder struct {
2225
http.ResponseWriter
2326
body *bytes.Buffer
@@ -71,7 +74,7 @@ type OpenAIRecorder struct {
7174
m sync.RWMutex
7275

7376
// streaming
74-
subscribers map[string]chan *RequestResponsePair
77+
subscribers map[string]chan []ModelRecordsResponse
7578
subMutex sync.RWMutex
7679
}
7780

@@ -80,7 +83,7 @@ func NewOpenAIRecorder(log logging.Logger, modelManager *models.Manager) *OpenAI
8083
log: log,
8184
modelManager: modelManager,
8285
records: make(map[string]*ModelData),
83-
subscribers: make(map[string]chan *RequestResponsePair),
86+
subscribers: make(map[string]chan []ModelRecordsResponse),
8487
}
8588
}
8689

@@ -193,7 +196,18 @@ func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter
193196
record.Response = response
194197
record.Error = "" // Ensure Error is empty for successful responses
195198
}
196-
go r.broadcastToSubscribers(record)
199+
// Create ModelRecordsResponse with this single updated record to match
200+
// what the non-streaming endpoint returns - []ModelRecordsResponse.
201+
// See getAllRecords and getRecordsByModel.
202+
modelResponse := []ModelRecordsResponse{{
203+
Count: 1,
204+
Model: model,
205+
ModelData: ModelData{
206+
Config: modelData.Config,
207+
Records: []*RequestResponsePair{record},
208+
},
209+
}}
210+
go r.broadcastToSubscribers(modelResponse)
197211
return
198212
}
199213
}
@@ -335,7 +349,7 @@ func (r *OpenAIRecorder) handleStreamingRequests(w http.ResponseWriter, req *htt
335349

336350
// Create subscriber channel.
337351
subscriberID := fmt.Sprintf("sub_%d", time.Now().UnixNano())
338-
ch := make(chan *RequestResponsePair, 100)
352+
ch := make(chan []ModelRecordsResponse, subscriberChannelBuffer)
339353

340354
// Register subscriber.
341355
r.subMutex.Lock()
@@ -368,18 +382,18 @@ func (r *OpenAIRecorder) handleStreamingRequests(w http.ResponseWriter, req *htt
368382

369383
for {
370384
select {
371-
case record, ok := <-ch:
385+
case modelRecords, ok := <-ch:
372386
if !ok {
373387
return
374388
}
375389

376390
// Filter by model if specified.
377-
if model != "" && record.Model != model {
391+
if model != "" && len(modelRecords) > 0 && modelRecords[0].Model != model {
378392
continue
379393
}
380394

381395
// Send as SSE event.
382-
jsonData, err := json.Marshal(record)
396+
jsonData, err := json.Marshal(modelRecords)
383397
if err != nil {
384398
continue
385399
}
@@ -438,13 +452,13 @@ func (r *OpenAIRecorder) getRecordsByModel(model string) []ModelRecordsResponse
438452
return nil
439453
}
440454

441-
func (r *OpenAIRecorder) broadcastToSubscribers(record *RequestResponsePair) {
455+
func (r *OpenAIRecorder) broadcastToSubscribers(modelResponses []ModelRecordsResponse) {
442456
r.subMutex.RLock()
443457
defer r.subMutex.RUnlock()
444458

445459
for _, ch := range r.subscribers {
446460
select {
447-
case ch <- record:
461+
case ch <- modelResponses:
448462
default:
449463
// The channel is full, skip this subscriber.
450464
}
@@ -461,13 +475,24 @@ func (r *OpenAIRecorder) sendExistingRecords(w http.ResponseWriter, model string
461475
}
462476

463477
if records != nil {
464-
for _, modelResponse := range records {
465-
for _, record := range modelResponse.Records {
466-
jsonData, err := json.Marshal(record)
467-
if err != nil {
468-
continue
478+
// Send each individual request-response pair as a separate event.
479+
for _, modelRecord := range records {
480+
for _, requestRecord := range modelRecord.Records {
481+
// Create a ModelRecordsResponse with a single record to match
482+
// what the non-streaming endpoint returns - []ModelRecordsResponse.
483+
// See getAllRecords and getRecordsByModel.
484+
singleRecord := []ModelRecordsResponse{{
485+
Count: 1,
486+
Model: modelRecord.Model,
487+
ModelData: ModelData{
488+
Config: modelRecord.Config,
489+
Records: []*RequestResponsePair{requestRecord},
490+
},
491+
}}
492+
jsonData, err := json.Marshal(singleRecord)
493+
if err == nil {
494+
fmt.Fprintf(w, "event: existing_request\ndata: %s\n\n", jsonData)
469495
}
470-
fmt.Fprintf(w, "event: existing_request\ndata: %s\n\n", jsonData)
471496
}
472497
}
473498
}

0 commit comments

Comments
 (0)