Skip to content

Commit 8e165dc

Browse files
authored
Feat Add headers to openai responses (#506)
* feat: add headers to http response * chore: add test * fix: rename to httpHeader
1 parent 533935e commit 8e165dc

14 files changed

+107
-2
lines changed

audio.go

+18-1
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,21 @@ type AudioResponse struct {
6363
Transient bool `json:"transient"`
6464
} `json:"segments"`
6565
Text string `json:"text"`
66+
67+
httpHeader
68+
}
69+
70+
type audioTextResponse struct {
71+
Text string `json:"text"`
72+
73+
httpHeader
74+
}
75+
76+
func (r *audioTextResponse) ToAudioResponse() AudioResponse {
77+
return AudioResponse{
78+
Text: r.Text,
79+
httpHeader: r.httpHeader,
80+
}
6681
}
6782

6883
// CreateTranscription — API call to create a transcription. Returns transcribed text.
@@ -104,7 +119,9 @@ func (c *Client) callAudioAPI(
104119
if request.HasJSONResponse() {
105120
err = c.sendRequest(req, &response)
106121
} else {
107-
err = c.sendRequest(req, &response.Text)
122+
var textResponse audioTextResponse
123+
err = c.sendRequest(req, &textResponse)
124+
response = textResponse.ToAudioResponse()
108125
}
109126
if err != nil {
110127
return AudioResponse{}, err

chat.go

+2
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ type ChatCompletionResponse struct {
142142
Model string `json:"model"`
143143
Choices []ChatCompletionChoice `json:"choices"`
144144
Usage Usage `json:"usage"`
145+
146+
httpHeader
145147
}
146148

147149
// CreateChatCompletion — API call to Create a completion for the chat message.

chat_test.go

+30
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ import (
1616
"github.com/sashabaranov/go-openai/jsonschema"
1717
)
1818

19+
const (
20+
xCustomHeader = "X-CUSTOM-HEADER"
21+
xCustomHeaderValue = "test"
22+
)
23+
1924
func TestChatCompletionsWrongModel(t *testing.T) {
2025
config := DefaultConfig("whatever")
2126
config.BaseURL = "http://localhost/v1"
@@ -68,6 +73,30 @@ func TestChatCompletions(t *testing.T) {
6873
checks.NoError(t, err, "CreateChatCompletion error")
6974
}
7075

76+
// TestCompletions Tests the completions endpoint of the API using the mocked server.
77+
func TestChatCompletionsWithHeaders(t *testing.T) {
78+
client, server, teardown := setupOpenAITestServer()
79+
defer teardown()
80+
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
81+
resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
82+
MaxTokens: 5,
83+
Model: GPT3Dot5Turbo,
84+
Messages: []ChatCompletionMessage{
85+
{
86+
Role: ChatMessageRoleUser,
87+
Content: "Hello!",
88+
},
89+
},
90+
})
91+
checks.NoError(t, err, "CreateChatCompletion error")
92+
93+
a := resp.Header().Get(xCustomHeader)
94+
_ = a
95+
if resp.Header().Get(xCustomHeader) != xCustomHeaderValue {
96+
t.Errorf("expected header %s to be %s", xCustomHeader, xCustomHeaderValue)
97+
}
98+
}
99+
71100
// TestChatCompletionsFunctions tests including a function call.
72101
func TestChatCompletionsFunctions(t *testing.T) {
73102
client, server, teardown := setupOpenAITestServer()
@@ -281,6 +310,7 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
281310
TotalTokens: inputTokens + completionTokens,
282311
}
283312
resBytes, _ = json.Marshal(res)
313+
w.Header().Set(xCustomHeader, xCustomHeaderValue)
284314
fmt.Fprintln(w, string(resBytes))
285315
}
286316

client.go

+19-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,20 @@ type Client struct {
2020
createFormBuilder func(io.Writer) utils.FormBuilder
2121
}
2222

23+
type Response interface {
24+
SetHeader(http.Header)
25+
}
26+
27+
type httpHeader http.Header
28+
29+
func (h *httpHeader) SetHeader(header http.Header) {
30+
*h = httpHeader(header)
31+
}
32+
33+
func (h httpHeader) Header() http.Header {
34+
return http.Header(h)
35+
}
36+
2337
// NewClient creates new OpenAI API client.
2438
func NewClient(authToken string) *Client {
2539
config := DefaultConfig(authToken)
@@ -82,7 +96,7 @@ func (c *Client) newRequest(ctx context.Context, method, url string, setters ...
8296
return req, nil
8397
}
8498

85-
func (c *Client) sendRequest(req *http.Request, v any) error {
99+
func (c *Client) sendRequest(req *http.Request, v Response) error {
86100
req.Header.Set("Accept", "application/json; charset=utf-8")
87101

88102
// Check whether Content-Type is already set, Upload Files API requires
@@ -103,6 +117,10 @@ func (c *Client) sendRequest(req *http.Request, v any) error {
103117
return c.handleErrorResp(res)
104118
}
105119

120+
if v != nil {
121+
v.SetHeader(res.Header)
122+
}
123+
106124
return decodeResponse(res.Body, v)
107125
}
108126

completion.go

+2
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ type CompletionResponse struct {
154154
Model string `json:"model"`
155155
Choices []CompletionChoice `json:"choices"`
156156
Usage Usage `json:"usage"`
157+
158+
httpHeader
157159
}
158160

159161
// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well

edits.go

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ type EditsResponse struct {
2828
Created int64 `json:"created"`
2929
Usage Usage `json:"usage"`
3030
Choices []EditsChoice `json:"choices"`
31+
32+
httpHeader
3133
}
3234

3335
// Edits Perform an API call to the Edits endpoint.

embeddings.go

+4
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ type EmbeddingResponse struct {
150150
Data []Embedding `json:"data"`
151151
Model EmbeddingModel `json:"model"`
152152
Usage Usage `json:"usage"`
153+
154+
httpHeader
153155
}
154156

155157
type base64String string
@@ -182,6 +184,8 @@ type EmbeddingResponseBase64 struct {
182184
Data []Base64Embedding `json:"data"`
183185
Model EmbeddingModel `json:"model"`
184186
Usage Usage `json:"usage"`
187+
188+
httpHeader
185189
}
186190

187191
// ToEmbeddingResponse converts an embeddingResponseBase64 to an EmbeddingResponse.

engines.go

+4
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,15 @@ type Engine struct {
1212
Object string `json:"object"`
1313
Owner string `json:"owner"`
1414
Ready bool `json:"ready"`
15+
16+
httpHeader
1517
}
1618

1719
// EnginesList is a list of engines.
1820
type EnginesList struct {
1921
Engines []Engine `json:"data"`
22+
23+
httpHeader
2024
}
2125

2226
// ListEngines Lists the currently available engines, and provides basic

files.go

+4
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,15 @@ type File struct {
2525
Status string `json:"status"`
2626
Purpose string `json:"purpose"`
2727
StatusDetails string `json:"status_details"`
28+
29+
httpHeader
2830
}
2931

3032
// FilesList is a list of files that belong to the user or organization.
3133
type FilesList struct {
3234
Files []File `json:"data"`
35+
36+
httpHeader
3337
}
3438

3539
// CreateFile uploads a jsonl file to GPT3

fine_tunes.go

+8
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ type FineTune struct {
4141
ValidationFiles []File `json:"validation_files"`
4242
TrainingFiles []File `json:"training_files"`
4343
UpdatedAt int64 `json:"updated_at"`
44+
45+
httpHeader
4446
}
4547

4648
// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API.
@@ -69,6 +71,8 @@ type FineTuneHyperParams struct {
6971
type FineTuneList struct {
7072
Object string `json:"object"`
7173
Data []FineTune `json:"data"`
74+
75+
httpHeader
7276
}
7377

7478
// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API.
@@ -77,6 +81,8 @@ type FineTuneList struct {
7781
type FineTuneEventList struct {
7882
Object string `json:"object"`
7983
Data []FineTuneEvent `json:"data"`
84+
85+
httpHeader
8086
}
8187

8288
// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API.
@@ -86,6 +92,8 @@ type FineTuneDeleteResponse struct {
8692
ID string `json:"id"`
8793
Object string `json:"object"`
8894
Deleted bool `json:"deleted"`
95+
96+
httpHeader
8997
}
9098

9199
// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API.

fine_tuning_job.go

+4
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ type FineTuningJob struct {
2121
ValidationFile string `json:"validation_file,omitempty"`
2222
ResultFiles []string `json:"result_files"`
2323
TrainedTokens int `json:"trained_tokens"`
24+
25+
httpHeader
2426
}
2527

2628
type Hyperparameters struct {
@@ -39,6 +41,8 @@ type FineTuningJobEventList struct {
3941
Object string `json:"object"`
4042
Data []FineTuneEvent `json:"data"`
4143
HasMore bool `json:"has_more"`
44+
45+
httpHeader
4246
}
4347

4448
type FineTuningJobEvent struct {

image.go

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ type ImageRequest struct {
3333
type ImageResponse struct {
3434
Created int64 `json:"created,omitempty"`
3535
Data []ImageResponseDataInner `json:"data,omitempty"`
36+
37+
httpHeader
3638
}
3739

3840
// ImageResponseDataInner represents a response data structure for image API.

models.go

+6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ type Model struct {
1515
Permission []Permission `json:"permission"`
1616
Root string `json:"root"`
1717
Parent string `json:"parent"`
18+
19+
httpHeader
1820
}
1921

2022
// Permission struct represents an OpenAPI permission.
@@ -38,11 +40,15 @@ type FineTuneModelDeleteResponse struct {
3840
ID string `json:"id"`
3941
Object string `json:"object"`
4042
Deleted bool `json:"deleted"`
43+
44+
httpHeader
4145
}
4246

4347
// ModelsList is a list of models, including those that belong to the user or organization.
4448
type ModelsList struct {
4549
Models []Model `json:"data"`
50+
51+
httpHeader
4652
}
4753

4854
// ListModels Lists the currently available models,

moderation.go

+2
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ type ModerationResponse struct {
6969
ID string `json:"id"`
7070
Model string `json:"model"`
7171
Results []Result `json:"results"`
72+
73+
httpHeader
7274
}
7375

7476
// Moderations — perform a moderation api call over a string.

0 commit comments

Comments
 (0)