Skip to content

Commit aabf577

Browse files
authored
Merge pull request #27 from matlab-deep-learning/tests_using_key
OPENAI_KEY as env variable and extra tests
2 parents acae9cf + 78279d2 commit aabf577

File tree

5 files changed

+281
-142
lines changed

5 files changed

+281
-142
lines changed

.github/workflows/ci.yml

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ jobs:
1313
products: Text_Analytics_Toolbox
1414
cache: true
1515
- name: Run tests and generate artifacts
16+
env:
17+
OPENAI_KEY: ${{ secrets.OPENAI_KEY }}
1618
uses: matlab-actions/run-tests@v2
1719
with:
1820
test-results-junit: test-results/results.xml

tests/textractOpenAIEmbeddings.m

+27-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ function saveEnvVar(testCase)
1616
end
1717

1818
properties(TestParameter)
19-
InvalidInput = iGetInvalidInput;
19+
InvalidInput = iGetInvalidInput();
20+
ValidDimensionsModelCombinations = iGetValidDimensionsModelCombinations();
2021
end
2122

2223
methods(Test)
@@ -31,6 +32,18 @@ function keyNotFound(testCase)
3132
testCase.verifyError(@()extractOpenAIEmbeddings("bla"), "llms:keyMustBeSpecified");
3233
end
3334

35+
function validCombinationOfModelAndDimension(testCase, ValidDimensionsModelCombinations)
36+
testCase.verifyWarningFree(@()extractOpenAIEmbeddings("bla", ...
37+
Dimensions=ValidDimensionsModelCombinations.Dimensions,...
38+
ModelName=ValidDimensionsModelCombinations.ModelName, ...
39+
ApiKey="not-real"));
40+
end
41+
42+
function embedStringWithSuccessfulOpenAICall(testCase)
43+
testCase.verifyWarningFree(@()extractOpenAIEmbeddings("bla", ...
44+
ApiKey=getenv("OPENAI_KEY")));
45+
end
46+
3447
function invalidCombinationOfModelAndDimension(testCase)
3548
testCase.verifyError(@()extractOpenAIEmbeddings("bla", ...
3649
Dimensions=10,...
@@ -54,7 +67,7 @@ function testInvalidInputs(testCase, InvalidInput)
5467
end
5568
end
5669

57-
function invalidInput = iGetInvalidInput
70+
function invalidInput = iGetInvalidInput()
5871
invalidInput = struct( ...
5972
"InvalidEmptyText", struct( ...
6073
"Input",{{ "" }},...
@@ -117,4 +130,15 @@ function testInvalidInputs(testCase, InvalidInput)
117130
"InvalidApiKeySize",struct( ...
118131
"Input",{{"bla", "ApiKey" ["abc" "abc"] }},...
119132
"Error","MATLAB:validators:mustBeTextScalar"));
120-
end
133+
end
134+
135+
function validDimensionsModelCombinations = iGetValidDimensionsModelCombinations()
136+
validDimensionsModelCombinations = struct( ...
137+
"CaseTextEmbedding3Small", struct( ...
138+
"Dimensions",10,...
139+
"ModelName", "text-embedding-3-small"), ...
140+
...
141+
"CaseTextEmbedding3Large", struct( ...
142+
"Dimensions",10,...
143+
"ModelName", "text-embedding-3-large"));
144+
end

tests/topenAIChat.m

+65-11
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ function saveEnvVar(testCase)
1616
end
1717

1818
properties(TestParameter)
19-
InvalidConstructorInput = iGetInvalidConstructorInput;
20-
InvalidGenerateInput = iGetInvalidGenerateInput;
21-
InvalidValuesSetters = iGetInvalidValuesSetters;
19+
InvalidConstructorInput = iGetInvalidConstructorInput();
20+
InvalidGenerateInput = iGetInvalidGenerateInput();
21+
InvalidValuesSetters = iGetInvalidValuesSetters();
2222
end
2323

2424
methods(Test)
@@ -34,11 +34,8 @@ function generateAcceptsMessagesAsInput(testCase)
3434
chat = openAIChat(ApiKey="this-is-not-a-real-key");
3535
messages = openAIMessages;
3636
messages = addUserMessage(messages, "This should be okay.");
37-
testCase.verifyWarningFree(@()generate(chat,messages));
38-
end
3937

40-
function constructMdlWithInvalidParameters(testCase)
41-
testCase.verifyError(@()openAIChat(ApiKey="this-is-not-a-real-key", ModelName="gpt-4", ResponseFormat="json"), "llms:invalidOptionAndValueForModel");
38+
testCase.verifyWarningFree(@()generate(chat,messages));
4239
end
4340

4441
function keyNotFound(testCase)
@@ -59,6 +56,7 @@ function constructChatWithAllNVP(testCase)
5956
chat = openAIChat(systemPrompt, Tools=functions, ModelName=modelName, ...
6057
Temperature=temperature, TopProbabilityMass=topP, StopSequences=stop, ApiKey=apiKey,...
6158
FrequencyPenalty=frequenceP, PresencePenalty=presenceP, TimeOut=timeout);
59+
6260
testCase.verifyEqual(chat.ModelName, modelName);
6361
testCase.verifyEqual(chat.Temperature, temperature);
6462
testCase.verifyEqual(chat.TopProbabilityMass, topP);
@@ -69,27 +67,37 @@ function constructChatWithAllNVP(testCase)
6967

7068
function verySmallTimeOutErrors(testCase)
7169
chat = openAIChat(TimeOut=0.0001, ApiKey="false-key");
70+
7271
testCase.verifyError(@()generate(chat, "hi"), "MATLAB:webservices:Timeout")
7372
end
7473

7574
function errorsWhenPassingToolChoiceWithEmptyTools(testCase)
7675
chat = openAIChat(ApiKey="this-is-not-a-real-key");
76+
7777
testCase.verifyError(@()generate(chat,"input", ToolChoice="bla"), "llms:mustSetFunctionsForCall");
7878
end
7979

8080
function settingToolChoiceWithNone(testCase)
8181
functions = openAIFunction("funName");
8282
chat = openAIChat(ApiKey="this-is-not-a-real-key",Tools=functions);
83+
8384
testCase.verifyWarningFree(@()generate(chat,"This is okay","ToolChoice","none"));
8485
end
8586

87+
function settingSeedToInteger(testCase)
88+
chat = openAIChat(ApiKey="this-is-not-a-real-key");
89+
90+
testCase.verifyWarningFree(@()generate(chat,"This is okay", "Seed", 2));
91+
end
92+
8693
function invalidInputsConstructor(testCase, InvalidConstructorInput)
8794
testCase.verifyError(@()openAIChat(InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error);
8895
end
8996

9097
function invalidInputsGenerate(testCase, InvalidGenerateInput)
9198
f = openAIFunction("validfunction");
9299
chat = openAIChat(Tools=f, ApiKey="this-is-not-a-real-key");
100+
93101
testCase.verifyError(@()generate(chat,InvalidGenerateInput.Input{:}), InvalidGenerateInput.Error);
94102
end
95103

@@ -107,18 +115,56 @@ function invalidGenerateInputforModel(testCase)
107115
image_path = "peppers.png";
108116
emptyMessages = openAIMessages;
109117
inValidMessages = addUserMessageWithImages(emptyMessages,"What is in the image?",image_path);
118+
110119
testCase.verifyError(@()generate(chat,inValidMessages), "llms:invalidContentTypeForModel")
111120
end
112121

113122
function noStopSequencesNoMaxNumTokens(testCase)
114123
chat = openAIChat(ApiKey="this-is-not-a-real-key");
124+
115125
testCase.verifyWarningFree(@()generate(chat,"This is okay"));
116126
end
117127

128+
function createOpenAIChatWithStreamFunc(testCase)
129+
130+
function seen = sf(str)
131+
persistent data;
132+
if isempty(data)
133+
data = strings(1, 0);
134+
end
135+
% Append streamed text to an empty string array of length 1
136+
data = [data, str];
137+
seen = data;
138+
end
139+
chat = openAIChat(ApiKey=getenv("OPENAI_KEY"), StreamFun=@sf);
140+
141+
testCase.verifyWarningFree(@()generate(chat, "Hello world."));
142+
% Checking that persistent data, which is still stored in
143+
% memory, is greater than 1. This would mean that the stream
144+
% function has been called and streamed some text.
145+
testCase.verifyGreaterThan(numel(sf("")), 1);
146+
end
147+
148+
function warningJSONResponseFormatGPT35(testCase)
149+
chat = @() openAIChat("You are a useful assistant", ...
150+
ApiKey="this-is-not-a-real-key", ...
151+
ResponseFormat="json", ...
152+
ModelName="gpt-3.5-turbo");
153+
154+
testCase.verifyWarning(@()chat(), "llms:warningJsonInstruction");
155+
end
156+
157+
function createOpenAIChatWithOpenAIKey(testCase)
158+
chat = openAIChat("You are a useful assistant", ...
159+
ApiKey=getenv("OPENAI_KEY"));
160+
161+
testCase.verifyWarningFree(@()generate(chat, "Hello world."));
162+
end
163+
118164
end
119165
end
120166

121-
function invalidValuesSetters = iGetInvalidValuesSetters
167+
function invalidValuesSetters = iGetInvalidValuesSetters()
122168

123169
invalidValuesSetters = struct( ...
124170
"InvalidTemperatureType", struct( ...
@@ -222,7 +268,7 @@ function noStopSequencesNoMaxNumTokens(testCase)
222268
"Error", "MATLAB:notGreaterEqual"));
223269
end
224270

225-
function invalidConstructorInput = iGetInvalidConstructorInput
271+
function invalidConstructorInput = iGetInvalidConstructorInput()
226272
validFunction = openAIFunction("funName");
227273
invalidConstructorInput = struct( ...
228274
"InvalidResponseFormatValue", struct( ...
@@ -233,6 +279,10 @@ function noStopSequencesNoMaxNumTokens(testCase)
233279
"Input",{{"ResponseFormat", ["text" "text"] }},...
234280
"Error", "MATLAB:validation:IncompatibleSize"), ...
235281
...
282+
"InvalidResponseFormatModelCombination", struct( ...
283+
"Input", {{"ApiKey", "this-is-not-a-real-key", "ModelName", "gpt-4", "ResponseFormat", "json"}}, ...
284+
"Error", "llms:invalidOptionAndValueForModel"), ...
285+
...
236286
"InvalidStreamFunType", struct( ...
237287
"Input",{{"StreamFun", "2" }},...
238288
"Error", "MATLAB:validators:mustBeA"), ...
@@ -366,7 +416,7 @@ function noStopSequencesNoMaxNumTokens(testCase)
366416
"Error","MATLAB:validators:mustBeTextScalar"));
367417
end
368418

369-
function invalidGenerateInput = iGetInvalidGenerateInput
419+
function invalidGenerateInput = iGetInvalidGenerateInput()
370420
emptyMessages = openAIMessages;
371421
validMessages = addUserMessage(emptyMessages,"Who invented the telephone?");
372422

@@ -409,5 +459,9 @@ function noStopSequencesNoMaxNumTokens(testCase)
409459
...
410460
"InvalidToolChoiceSize",struct( ...
411461
"Input",{{ validMessages "ToolChoice" ["validfunction", "validfunction"] }},...
412-
"Error","MATLAB:validators:mustBeTextScalar"));
462+
"Error","MATLAB:validators:mustBeTextScalar"),...
463+
...
464+
"InvalidSeed",struct( ...
465+
"Input",{{ validMessages "Seed" "2" }},...
466+
"Error","MATLAB:validators:mustBeNumericOrLogical"));
413467
end

tests/topenAIImages.m

+12
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,18 @@ function invalidInputsVariation(testCase, InvalidVariationInput)
129129
mdl = openAIImages(ApiKey="this-is-not-a-real-key");
130130
testCase.verifyError(@()createVariation(mdl,InvalidVariationInput.Input{:}), InvalidVariationInput.Error);
131131
end
132+
133+
function testThatImageIsReturned(testCase)
134+
mdl = openAIImages(ApiKey=getenv("OPENAI_KEY"));
135+
136+
[images, response] = generate(mdl, ...
137+
"Create a 3D avatar of a whimsical sushi on the beach. " + ...
138+
"He is decorated with various sushi elements and is " + ...
139+
"playfully interacting with the beach environment.");
140+
141+
testCase.verifySize(images{:}, [1024, 1024, 3]);
142+
testCase.verifyEqual(response.StatusLine.ReasonPhrase, "OK");
143+
end
132144
end
133145
end
134146

0 commit comments

Comments
 (0)