
Commit ce08807

Make MaxNumTokens work with newer models
Most models accept both `max_tokens` and `max_completion_tokens`; the newer reasoning models only accept `max_completion_tokens`, while GPT-3.5 deployments error when given `max_completion_tokens`. Since we cannot tell which model sits behind a given deployment name, we first try the newer name and retry with the old name if we get the corresponding error message. The test point also works with o1-mini; gpt-4o is just cheaper.
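The retry strategy described above can be sketched language-neutrally: send the request with `max_completion_tokens`, and only if the service rejects that exact argument, rename the field to `max_tokens` and retry once. This is an illustrative Python sketch, not the MATLAB implementation in this commit; `send_request` is a hypothetical callable standing in for `llms.internal.sendRequestWrapper`, while the error string is the one the Azure service actually returns.

```python
def send_with_token_fallback(send_request, params):
    """Try the newer parameter name first; fall back once on the known error.

    `send_request` is a hypothetical callable returning (status, body).
    """
    status, body = send_request(params)
    if (status == 400
            and body.get("error", {}).get("message")
            == "Unrecognized request argument supplied: max_completion_tokens"):
        # Older deployments (e.g. GPT-3.5) only accept `max_tokens`:
        # rename the field and retry the identical request once.
        params = dict(params)
        params["max_tokens"] = params.pop("max_completion_tokens")
        status, body = send_request(params)
    return status, body
```

At most one extra round trip is paid, and only for old deployments hit with the new parameter name; all other errors pass through unchanged.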
1 parent fd757c3 commit ce08807

File tree: 3 files changed, +40 −8 lines

Diff for: +llms/+internal/callAzureChatAPI.m (+19 −3)

```diff
@@ -37,7 +37,7 @@
 % % Send a request
 % [text, message] = llms.internal.callAzureChatAPI(messages, functions, APIKey=apiKey)
 
-% Copyright 2023-2024 The MathWorks, Inc.
+% Copyright 2023-2025 The MathWorks, Inc.
 
 arguments
     endpoint
@@ -66,6 +66,17 @@
 
 [response, streamedText] = llms.internal.sendRequestWrapper(parameters,nvp.APIKey, URL, nvp.TimeOut, nvp.StreamFun);
 
+% For old models like GPT-3.5, we may have to change the request sent a
+% little. Since we cannot detect the model used other than trying to send a
+% request, we have to analyze the response instead.
+if response.StatusCode=="BadRequest" && ...
+        isfield(response.Body.Data,"error") && ...
+        isfield(response.Body.Data.error,"message") && ...
+        response.Body.Data.error.message == "Unrecognized request argument supplied: max_completion_tokens"
+    parameters = renameStructField(parameters,'max_completion_tokens','max_tokens');
+    [response, streamedText] = llms.internal.sendRequestWrapper(parameters,nvp.APIKey, URL, nvp.TimeOut, nvp.StreamFun);
+end
+
 % If call errors, "choices" will not be part of response.Body.Data, instead
 % we get response.Body.Data.error
 if response.StatusCode=="OK"
@@ -136,10 +147,15 @@
 
 nvpOptions = keys(dict);
 for opt = nvpOptions.'
-    if isfield(nvp, opt)
+    if isfield(nvp, opt) && ~isempty(nvp.(opt))
         parameters.(dict(opt)) = nvp.(opt);
     end
 end
+
+if nvp.MaxNumTokens == Inf
+    parameters = rmfield(parameters,dict("MaxNumTokens"));
+end
+
 end
 
 function dict = mapNVPToParameters()
@@ -148,7 +164,7 @@
 dict("TopP") = "top_p";
 dict("NumCompletions") = "n";
 dict("StopSequences") = "stop";
-dict("MaxNumTokens") = "max_tokens";
+dict("MaxNumTokens") = "max_completion_tokens";
 dict("PresencePenalty") = "presence_penalty";
 dict("FrequencyPenalty") = "frequency_penalty";
 end
```

Diff for: azureChat.m (+3 −1)

```diff
@@ -276,7 +276,9 @@
 
 if isfield(response.Body.Data,"error")
     err = response.Body.Data.error.message;
-    if startsWith(err,"'json_schema' is not one of ['json_object', 'text']")
+    if startsWith(err,"'json_schema' is not one of ['json_object', 'text']") || ...
+            startsWith(replace(err,newline," "),...
+            "Invalid parameter: 'response_format' of type 'json_schema' is not supported with this model.")
         error("llms:noStructuredOutputForAzureDeployment", ...
             llms.utils.errorMessageCatalog.getMessage( ...
             "llms:noStructuredOutputForAzureDeployment",this.DeploymentID));
```

Diff for: tests/tazureChat.m (+18 −4)

```diff
@@ -1,7 +1,7 @@
 classdef tazureChat < hopenAIChat
 % Tests for azureChat
 
-% Copyright 2024 The MathWorks, Inc.
+% Copyright 2024-2025 The MathWorks, Inc.
 
 properties(TestParameter)
     ValidConstructorInput = iGetValidConstructorInput();
@@ -69,6 +69,20 @@ function responseFormatRequiresNewAPI(testCase)
         "llms:structuredOutputRequiresAPI");
 end
 
+function maxNumTokensWithReasoningModel(testCase)
+    % Unlike OpenAI, Azure requires different parameter names for
+    % different models (max_tokens vs max_completion_tokens). Since
+    % we do not even know what model some deployment uses (us naming
+    % them after the model deployed is not a guarantee), that is a
+    % somewhat painful distinction.
+    testCase.verifyWarningFree(@() generate( ...
+        azureChat(DeploymentID="gpt-35-turbo-16k-0613"), ...
+        "What is object oriented design?", MaxNumTokens=23));
+    testCase.verifyWarningFree(@() generate( ...
+        azureChat(DeploymentID="o1-mini"), ...
+        "What is object oriented design?", MaxNumTokens=23));
+end
+
 function generateWithImage(testCase)
     chat = azureChat(DeploymentID="gpt-4o");
     image_path = "peppers.png";
@@ -123,10 +137,10 @@ function canUseAPIVersions(testCase, APIVersions)
 end
 
 function specialErrorForUnsupportedResponseFormat(testCase)
-    testCase.assumeFail("Disabled until `llms.internal.callAzureChat` is updated to use `max_completion_tokens` instead of the deprecated `max_tokens` in the OpenAI API.")
-
+    % Our "gpt-4o" deployment has the model version 2024-05-13,
+    % which does not support structured output
     testCase.verifyError(@() generate(...
-        azureChat(DeploymentID="o1-mini"), ...
+        azureChat(DeploymentID="gpt-4o"), ...
         "What is the smallest prime?", ...
         ResponseFormat=struct("number",1)), ...
         "llms:noStructuredOutputForAzureDeployment");
```
