Skip to content

Commit 38bb017

Browse files
authored
Merge pull request #155 from doringeman/requests
OpenAIRecorder: Allow getting all records
2 parents 239d3e1 + 051e832 commit 38bb017

File tree

2 files changed: 55 additions (+55) and 16 deletions (-16).

pkg/inference/scheduling/scheduler.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -117,7 +117,7 @@ func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc {
117117
m["POST "+inference.InferencePrefix+"/unload"] = s.Unload
118118
m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = s.Configure
119119
m["POST "+inference.InferencePrefix+"/_configure"] = s.Configure
120-
m["GET "+inference.InferencePrefix+"/requests"] = s.openAIRecorder.GetRecordsByModelHandler()
120+
m["GET "+inference.InferencePrefix+"/requests"] = s.openAIRecorder.GetRecordsHandler()
121121
return m
122122
}
123123

pkg/metrics/openai_recorder.go

Lines changed: 54 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -58,6 +58,12 @@ type ModelData struct {
5858
Records []*RequestResponsePair `json:"records"`
5959
}
6060

61+
type ModelRecordsResponse struct {
62+
Count int `json:"count"`
63+
Model string `json:"model"`
64+
ModelData
65+
}
66+
6167
type OpenAIRecorder struct {
6268
log logging.Logger
6369
records map[string]*ModelData // key is model ID
@@ -266,30 +272,34 @@ func (r *OpenAIRecorder) convertStreamingResponse(streamingBody string) string {
266272
return string(jsonResult)
267273
}
268274

269-
func (r *OpenAIRecorder) GetRecordsByModelHandler() http.HandlerFunc {
275+
func (r *OpenAIRecorder) GetRecordsHandler() http.HandlerFunc {
270276
return func(w http.ResponseWriter, req *http.Request) {
271277
w.Header().Set("Content-Type", "application/json")
272278

273279
model := req.URL.Query().Get("model")
274280

275281
if model == "" {
276-
http.Error(w, "A 'model' query parameter is required", http.StatusBadRequest)
282+
// Retrieve all records for all models.
283+
allRecords := r.getAllRecords()
284+
if allRecords == nil {
285+
// No records found.
286+
http.Error(w, "No records found", http.StatusNotFound)
287+
return
288+
}
289+
if err := json.NewEncoder(w).Encode(allRecords); err != nil {
290+
http.Error(w, fmt.Sprintf("Failed to encode all records: %v", err),
291+
http.StatusInternalServerError)
292+
return
293+
}
277294
} else {
278295
// Retrieve records for the specified model.
279-
records := r.GetRecordsByModel(model)
296+
records := r.getRecordsByModel(model)
280297
if records == nil {
281298
// No records found for the specified model.
282299
http.Error(w, fmt.Sprintf("No records found for model '%s'", model), http.StatusNotFound)
283300
return
284301
}
285-
286-
modelID := r.modelManager.ResolveModelID(model)
287-
if err := json.NewEncoder(w).Encode(map[string]interface{}{
288-
"model": model,
289-
"records": records,
290-
"count": len(records),
291-
"config": r.records[modelID].Config,
292-
}); err != nil {
302+
if err := json.NewEncoder(w).Encode(records); err != nil {
293303
http.Error(w, fmt.Sprintf("Failed to encode records for model '%s': %v", model, err),
294304
http.StatusInternalServerError)
295305
return
@@ -298,16 +308,45 @@ func (r *OpenAIRecorder) GetRecordsByModelHandler() http.HandlerFunc {
298308
}
299309
}
300310

301-
func (r *OpenAIRecorder) GetRecordsByModel(model string) []*RequestResponsePair {
311+
func (r *OpenAIRecorder) getAllRecords() []ModelRecordsResponse {
312+
r.m.RLock()
313+
defer r.m.RUnlock()
314+
315+
if len(r.records) == 0 {
316+
return nil
317+
}
318+
319+
result := make([]ModelRecordsResponse, 0, len(r.records))
320+
321+
for modelID, modelData := range r.records {
322+
result = append(result, ModelRecordsResponse{
323+
Count: len(modelData.Records),
324+
Model: modelID,
325+
ModelData: ModelData{
326+
Config: modelData.Config,
327+
Records: modelData.Records,
328+
},
329+
})
330+
}
331+
332+
return result
333+
}
334+
335+
func (r *OpenAIRecorder) getRecordsByModel(model string) []ModelRecordsResponse {
302336
modelID := r.modelManager.ResolveModelID(model)
303337

304338
r.m.RLock()
305339
defer r.m.RUnlock()
306340

307341
if modelData, exists := r.records[modelID]; exists {
308-
result := make([]*RequestResponsePair, len(modelData.Records))
309-
copy(result, modelData.Records)
310-
return result
342+
return []ModelRecordsResponse{{
343+
Count: len(modelData.Records),
344+
Model: modelID,
345+
ModelData: ModelData{
346+
Config: modelData.Config,
347+
Records: modelData.Records,
348+
},
349+
}}
311350
}
312351

313352
return nil

0 commit comments

Comments
 (0)