@@ -240,7 +240,7 @@ func toToolCall(call types.CompletionToolCall) openai.ToolCall {
 	}
 }
 
-func toMessages(request types.CompletionRequest, compat bool) (result []openai.ChatCompletionMessage, err error) {
+func toMessages(request types.CompletionRequest, compat, useO1Model bool) (result []openai.ChatCompletionMessage, err error) {
 	var (
 		systemPrompts []string
 		msgs          []types.CompletionMessage
@@ -259,8 +259,12 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
 	}
 
 	if len(systemPrompts) > 0 {
+		role := types.CompletionMessageRoleTypeSystem
+		if useO1Model {
+			role = types.CompletionMessageRoleTypeDeveloper
+		}
 		msgs = slices.Insert(msgs, 0, types.CompletionMessage{
-			Role:    types.CompletionMessageRoleTypeSystem,
+			Role:    role,
 			Content: types.Text(strings.Join(systemPrompts, "\n")),
 		})
 	}
@@ -306,9 +310,9 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
 	return
 }
 
-func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
+func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, envs []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
 	if err := c.ValidAuth(); err != nil {
-		if err := c.RetrieveAPIKey(ctx, env); err != nil {
+		if err := c.RetrieveAPIKey(ctx, envs); err != nil {
 			return nil, err
 		}
 	}
@@ -317,7 +321,9 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 		messageRequest.Model = c.defaultModel
 	}
 
-	msgs, err := toMessages(messageRequest, !c.setSeed)
+	useO1Model := isO1Model(messageRequest.Model, envs)
+
+	msgs, err := toMessages(messageRequest, !c.setSeed, useO1Model)
 	if err != nil {
 		return nil, err
 	}
@@ -348,10 +354,13 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 		MaxTokens: messageRequest.MaxTokens,
 	}
 
-	if messageRequest.Temperature == nil {
-		request.Temperature = new(float32)
-	} else {
-		request.Temperature = messageRequest.Temperature
+	// openai O1 doesn't support setting temperature
+	if !useO1Model {
+		if messageRequest.Temperature == nil {
+			messageRequest.Temperature = new(float32)
+		} else {
+			request.Temperature = messageRequest.Temperature
+		}
 	}
 
 	if messageRequest.JSONResponse {
@@ -404,15 +413,15 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	if err != nil {
 		return nil, err
 	} else if !ok {
-		result, err = c.call(ctx, request, id, env, status)
+		result, err = c.call(ctx, request, id, envs, status)
 
 		// If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass.
 		var apiError *openai.APIError
 		if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
 			// Decrease maxTokens by 10% to make garbage collection more aggressive.
 			// The retry loop will further decrease maxTokens if needed.
 			maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
-			result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, status)
+			result, err = c.contextLimitRetryLoop(ctx, request, id, envs, maxTokens, status)
 		}
 		if err != nil {
 			return nil, err
@@ -446,6 +455,22 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	return &result, nil
 }
 
+func isO1Model(model string, envs []string) bool {
+	if model == "o1" {
+		return true
+	}
+
+	o1Model := false
+	for _, env := range envs {
+		k, v, _ := strings.Cut(env, "=")
+		if k == "OPENAI_MODEL_NAME" && v == "o1" {
+			o1Model = true
+		}
+	}
+
+	return o1Model
+}
+
 func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
 	var (
 		response types.CompletionMessage
@@ -545,9 +570,14 @@ func override(left, right string) string {
 	return left
 }
 
-func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, env []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
+func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, envs []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
 	streamResponse := os.Getenv("GPTSCRIPT_INTERNAL_OPENAI_STREAMING") != "false"
 
+	useO1Model := isO1Model(request.Model, envs)
+	if useO1Model {
+		streamResponse = false
+	}
+
 	partial <- types.CompletionStatus{
 		CompletionID:    transactionID,
 		PartialResponse: &types.CompletionMessage{
@@ -567,7 +597,7 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
 			},
 		}
 	)
-	for _, e := range env {
+	for _, e := range envs {
 		if strings.HasPrefix(e, "GPTSCRIPT_MODEL_PROVIDER_") {
 			modelProviderEnv = append(modelProviderEnv, e)
 		} else if strings.HasPrefix(e, "GPTSCRIPT_DISABLE_RETRIES") {
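For reference, a minimal standalone sketch of how the o1 detection introduced in this diff behaves. It restates the isO1Model logic with an early return instead of the flag variable; the environment entries in main are hypothetical values used only for illustration.

package main

import (
	"fmt"
	"strings"
)

// isO1Model mirrors the helper added in the diff: a request targets o1 either
// when the model name itself is "o1" or when the environment slice carries an
// OPENAI_MODEL_NAME=o1 entry.
func isO1Model(model string, envs []string) bool {
	if model == "o1" {
		return true
	}
	for _, env := range envs {
		k, v, _ := strings.Cut(env, "=")
		if k == "OPENAI_MODEL_NAME" && v == "o1" {
			return true
		}
	}
	return false
}

func main() {
	// Hypothetical environment entries, for illustration only.
	envs := []string{"OPENAI_MODEL_NAME=o1", "OPENAI_API_KEY=sk-example"}

	fmt.Println(isO1Model("o1", nil))      // true: matched by model name
	fmt.Println(isO1Model("gpt-4o", envs)) // true: matched by env override
	fmt.Println(isO1Model("gpt-4o", nil))  // false
}

When detection succeeds, the diff routes system prompts to the developer role, skips setting temperature, and disables streaming, since those request options are handled differently for the o1 family.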