Add non-breaking TemperatureOpt field to ChatCompletionRequest that can be set to explicit zero. #983

Open · wants to merge 2 commits into base: master
71 changes: 61 additions & 10 deletions chat.go
@@ -232,16 +232,21 @@
MaxTokens int `json:"max_tokens,omitempty"`
// MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion,
// including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Stream bool `json:"stream,omitempty"`
Stop []string `json:"stop,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"`
Seed *int `json:"seed,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`

// Deprecated: Use TemperatureOpt instead. When TemperatureOpt is set, Temperature is ignored
// regardless of its value. Otherwise (if TemperatureOpt is nil), Temperature is used when
// non-zero.
Temperature float32 `json:"-"`
TemperatureOpt *float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Stream bool `json:"stream,omitempty"`
Stop []string `json:"stop,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"`
Seed *int `json:"seed,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
// LogitBias must be a token id string (specified by their token ID in the tokenizer), not a word string.
// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}`
// refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias
@@ -277,6 +282,52 @@
Prediction *Prediction `json:"prediction,omitempty"`
}

func (r *ChatCompletionRequest) UnmarshalJSON(data []byte) error {
type plainChatCompletionRequest ChatCompletionRequest
if err := json.Unmarshal(data, (*plainChatCompletionRequest)(r)); err != nil {
return err
}

if r.TemperatureOpt != nil {
if *r.TemperatureOpt == 0 {
// Explicit zero. This can only be represented in the TemperatureOpt field, so
// we need to preserve it.
// We still link r.TemperatureOpt to r.Temperature, such that legacy code modifying
// temperature after unmarshaling will continue to work correctly.
r.Temperature = 0
r.TemperatureOpt = &r.Temperature
} else {
// Non-zero temperature. This can be represented in the legacy field, and in order
// to minimize incompatibilities, we use the legacy field exclusively.
// New code should use `GetTemperature()` to retrieve the temperature, and explicitly
// setting TemperatureOpt will still be respected.
r.Temperature = *r.TemperatureOpt
r.TemperatureOpt = nil
}
} else {
r.Temperature = 0
}
return nil
}

func (r ChatCompletionRequest) MarshalJSON() ([]byte, error) {
type plainChatCompletionRequest ChatCompletionRequest
plainR := plainChatCompletionRequest(r)
if plainR.TemperatureOpt == nil && plainR.Temperature != 0 {
plainR.TemperatureOpt = &plainR.Temperature
}
return json.Marshal(&plainR)
}

func (r *ChatCompletionRequest) GetTemperature() *float32 {
if r.TemperatureOpt != nil {
return r.TemperatureOpt
}
if r.Temperature != 0 {
return &r.Temperature
}
return nil
}

type StreamOptions struct {
// If set, an additional chunk will be streamed before the data: [DONE] message.
// The usage field on this chunk shows the token usage statistics for the entire request,
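
For context (not part of the diff): a minimal usage sketch, assuming the field and helper above land as shown and the usual `github.com/sashabaranov/go-openai` import path. It illustrates the three ways a caller can express temperature after this change.

```go
package main

import (
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// Legacy style: a non-zero temperature via the deprecated field keeps working.
	legacy := openai.ChatCompletionRequest{Temperature: 0.7}

	// New style: an explicit zero can only be expressed through TemperatureOpt.
	zero := float32(0)
	explicitZero := openai.ChatCompletionRequest{TemperatureOpt: &zero}

	// Nothing set at all: GetTemperature reports nil, so the API default applies.
	var unset openai.ChatCompletionRequest

	fmt.Println(*legacy.GetTemperature())       // 0.7
	fmt.Println(*explicitZero.GetTemperature()) // 0
	fmt.Println(unset.GetTemperature())         // <nil>
}
```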
108 changes: 108 additions & 0 deletions chat_test.go
@@ -123,6 +123,23 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_temperature_unsupported_new",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
TemperatureOpt: &[]float32{2}[0],
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_top_unsupported",
in: openai.ChatCompletionRequest{
@@ -946,3 +963,94 @@ func TestFinishReason(t *testing.T) {
}
}
}

func TestTemperature(t *testing.T) {
tests := []struct {
name string
in openai.ChatCompletionRequest
expectedTemperature *float32
}{
{
name: "not_set",
in: openai.ChatCompletionRequest{},
expectedTemperature: nil,
},
{
name: "set_legacy",
in: openai.ChatCompletionRequest{
Temperature: 0.5,
},
expectedTemperature: &[]float32{0.5}[0],
},
{
name: "set_new",
in: openai.ChatCompletionRequest{
TemperatureOpt: &[]float32{0.5}[0],
},
expectedTemperature: &[]float32{0.5}[0],
},
{
name: "set_both",
in: openai.ChatCompletionRequest{
Temperature: 0.4,
TemperatureOpt: &[]float32{0.5}[0],
},
expectedTemperature: &[]float32{0.5}[0],
},
{
name: "set_new_explicit_zero",
in: openai.ChatCompletionRequest{
TemperatureOpt: &[]float32{0}[0],
},
expectedTemperature: &[]float32{0}[0],
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.in)
checks.NoError(t, err, "failed to marshal request to JSON")

var req openai.ChatCompletionRequest
err = json.Unmarshal(data, &req)
checks.NoError(t, err, "failed to unmarshal request from JSON")

temp := req.GetTemperature()
if tt.expectedTemperature == nil {
if temp != nil {
t.Error("expected temperature to be nil")
}
} else {
if temp == nil {
t.Error("expected temperature to be set")
} else if *tt.expectedTemperature != *temp {
t.Errorf("expected temperature to be %v but was %v", *tt.expectedTemperature, *temp)
}
}
})
}
}

func TestTemperature_ModifyLegacyAfterUnmarshal(t *testing.T) {
req := openai.ChatCompletionRequest{
TemperatureOpt: &[]float32{0.5}[0],
}

data, err := json.Marshal(req)
checks.NoError(t, err, "failed to marshal request to JSON")

var req2 openai.ChatCompletionRequest
err = json.Unmarshal(data, &req2)
checks.NoError(t, err, "failed to unmarshal request from JSON")

if temp := req2.GetTemperature(); temp == nil || *temp != 0.5 {
t.Errorf("expected temperature to be 0.5 but was %v", temp)
}

// Modify the legacy temperature field
req2.Temperature = 0.4

if temp := req2.GetTemperature(); temp == nil || *temp != 0.4 {
t.Errorf("expected temperature to be 0.4 but was %v", temp)
}
}
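
Not part of the diff: a small sketch of the resulting wire format under the MarshalJSON logic above (import path assumed as `github.com/sashabaranov/go-openai`). The point of the PR is visible here: an unset temperature is omitted, a legacy non-zero value is still emitted, and an explicit zero now survives serialization.

```go
package main

import (
	"encoding/json"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func show(req openai.ChatCompletionRequest) {
	data, err := json.Marshal(req)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(data))
}

func main() {
	zero := float32(0)

	// Unset: no "temperature" key at all.
	show(openai.ChatCompletionRequest{})

	// Legacy non-zero: MarshalJSON copies Temperature into TemperatureOpt, so
	// the output contains "temperature":0.7.
	show(openai.ChatCompletionRequest{Temperature: 0.7})

	// Explicit zero: only representable via TemperatureOpt; the output contains
	// "temperature":0.
	show(openai.ChatCompletionRequest{TemperatureOpt: &zero})
}
```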
2 changes: 1 addition & 1 deletion reasoning_validator.go
@@ -61,7 +61,7 @@ func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletion
if request.LogProbs {
return ErrReasoningModelLimitationsLogprobs
}
if request.Temperature > 0 && request.Temperature != 1 {
if temp := request.GetTemperature(); temp != nil && *temp != 1 {
return ErrReasoningModelLimitationsOther
}
if request.TopP > 0 && request.TopP != 1 {
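
Not part of the diff: a standalone sketch of why the validator change matters. The helper below is hypothetical; it simply mirrors the new condition in validateReasoningModelParams using the exported GetTemperature accessor.

```go
package main

import (
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

// rejectsTemperature mirrors the updated check: any temperature other than 1 is
// rejected for reasoning models, and an explicit zero is now caught because
// GetTemperature returns a non-nil pointer for it.
func rejectsTemperature(req openai.ChatCompletionRequest) bool {
	temp := req.GetTemperature()
	return temp != nil && *temp != 1
}

func main() {
	zero := float32(0)

	fmt.Println(rejectsTemperature(openai.ChatCompletionRequest{}))                      // false: unset
	fmt.Println(rejectsTemperature(openai.ChatCompletionRequest{Temperature: 1}))        // false: the only allowed value
	fmt.Println(rejectsTemperature(openai.ChatCompletionRequest{Temperature: 0.5}))      // true: legacy non-zero
	fmt.Println(rejectsTemperature(openai.ChatCompletionRequest{TemperatureOpt: &zero})) // true: explicit zero, now expressible and rejected
}
```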