Skip to content

Commit 062d703

Browse files
committed
Chore: Support openai o1 model
Signed-off-by: Daishan Peng <[email protected]>
1 parent 3f876b2 commit 062d703

File tree

2 files changed

+44
-13
lines changed

2 files changed

+44
-13
lines changed

pkg/openai/client.go

+43-13
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ func toToolCall(call types.CompletionToolCall) openai.ToolCall {
240240
}
241241
}
242242

243-
func toMessages(request types.CompletionRequest, compat bool) (result []openai.ChatCompletionMessage, err error) {
243+
func toMessages(request types.CompletionRequest, compat, useO1Model bool) (result []openai.ChatCompletionMessage, err error) {
244244
var (
245245
systemPrompts []string
246246
msgs []types.CompletionMessage
@@ -259,8 +259,12 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
259259
}
260260

261261
if len(systemPrompts) > 0 {
262+
role := types.CompletionMessageRoleTypeSystem
263+
if useO1Model {
264+
role = types.CompletionMessageRoleTypeDeveloper
265+
}
262266
msgs = slices.Insert(msgs, 0, types.CompletionMessage{
263-
Role: types.CompletionMessageRoleTypeSystem,
267+
Role: role,
264268
Content: types.Text(strings.Join(systemPrompts, "\n")),
265269
})
266270
}
@@ -306,9 +310,9 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
306310
return
307311
}
308312

309-
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
313+
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, envs []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
310314
if err := c.ValidAuth(); err != nil {
311-
if err := c.RetrieveAPIKey(ctx, env); err != nil {
315+
if err := c.RetrieveAPIKey(ctx, envs); err != nil {
312316
return nil, err
313317
}
314318
}
@@ -317,7 +321,9 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
317321
messageRequest.Model = c.defaultModel
318322
}
319323

320-
msgs, err := toMessages(messageRequest, !c.setSeed)
324+
useO1Model := isO1Model(messageRequest.Model, envs)
325+
326+
msgs, err := toMessages(messageRequest, !c.setSeed, useO1Model)
321327
if err != nil {
322328
return nil, err
323329
}
@@ -348,10 +354,13 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
348354
MaxTokens: messageRequest.MaxTokens,
349355
}
350356

351-
if messageRequest.Temperature == nil {
352-
request.Temperature = new(float32)
353-
} else {
354-
request.Temperature = messageRequest.Temperature
357+
// openai O1 doesn't support setting temperature
358+
if !useO1Model {
359+
if messageRequest.Temperature == nil {
360+
messageRequest.Temperature = new(float32)
361+
} else {
362+
request.Temperature = messageRequest.Temperature
363+
}
355364
}
356365

357366
if messageRequest.JSONResponse {
@@ -404,15 +413,15 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
404413
if err != nil {
405414
return nil, err
406415
} else if !ok {
407-
result, err = c.call(ctx, request, id, env, status)
416+
result, err = c.call(ctx, request, id, envs, status)
408417

409418
// If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass.
410419
var apiError *openai.APIError
411420
if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
412421
// Decrease maxTokens by 10% to make garbage collection more aggressive.
413422
// The retry loop will further decrease maxTokens if needed.
414423
maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
415-
result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, status)
424+
result, err = c.contextLimitRetryLoop(ctx, request, id, envs, maxTokens, status)
416425
}
417426
if err != nil {
418427
return nil, err
@@ -446,6 +455,22 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
446455
return &result, nil
447456
}
448457

458+
func isO1Model(model string, envs []string) bool {
459+
if model == "o1" {
460+
return true
461+
}
462+
463+
o1Model := false
464+
for _, env := range envs {
465+
k, v, _ := strings.Cut(env, "=")
466+
if k == "OPENAI_MODEL_NAME" && v == "o1" {
467+
o1Model = true
468+
}
469+
}
470+
471+
return o1Model
472+
}
473+
449474
func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
450475
var (
451476
response types.CompletionMessage
@@ -545,9 +570,14 @@ func override(left, right string) string {
545570
return left
546571
}
547572

548-
func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, env []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
573+
func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, envs []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
549574
streamResponse := os.Getenv("GPTSCRIPT_INTERNAL_OPENAI_STREAMING") != "false"
550575

576+
useO1Model := isO1Model(request.Model, envs)
577+
if useO1Model {
578+
streamResponse = false
579+
}
580+
551581
partial <- types.CompletionStatus{
552582
CompletionID: transactionID,
553583
PartialResponse: &types.CompletionMessage{
@@ -567,7 +597,7 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
567597
},
568598
}
569599
)
570-
for _, e := range env {
600+
for _, e := range envs {
571601
if strings.HasPrefix(e, "GPTSCRIPT_MODEL_PROVIDER_") {
572602
modelProviderEnv = append(modelProviderEnv, e)
573603
} else if strings.HasPrefix(e, "GPTSCRIPT_DISABLE_RETRIES") {

pkg/types/completion.go

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ type CompletionFunctionDefinition struct {
4141
const (
4242
CompletionMessageRoleTypeUser = CompletionMessageRoleType("user")
4343
CompletionMessageRoleTypeSystem = CompletionMessageRoleType("system")
44+
CompletionMessageRoleTypeDeveloper = CompletionMessageRoleType("developer")
4445
CompletionMessageRoleTypeAssistant = CompletionMessageRoleType("assistant")
4546
CompletionMessageRoleTypeTool = CompletionMessageRoleType("tool")
4647
)

0 commit comments

Comments
 (0)