Skip to content

Commit 69ca8b6

Browse files
authored
Merge pull request #19 from matlab-deep-learning/fix-embedding-bugs
Fix argument validation for extractOpenAIEmbeddings
2 parents d3e7389 + 48a7073 commit 69ca8b6

File tree

3 files changed

+42
-5
lines changed

3 files changed

+42
-5
lines changed

Diff for: +llms/+utils/errorMessageCatalog.m

+1
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,5 @@
5555
catalog("llms:pngExpected") = "Argument must be a PNG image.";
5656
catalog("llms:warningJsonInstruction") = "When using JSON mode, you must also prompt the model to produce JSON yourself via a system or user message.";
5757
catalog("llms:apiReturnedError") = "OpenAI API Error: {1}";
58+
catalog("llms:dimensionsMustBeSmallerThan") = "Dimensions must be less than or equal to {1}.";
5859
end

Diff for: extractOpenAIEmbeddings.m

+17-3
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
% [emb, response] = EXTRACTOPENAIEMBEDDINGS(...) also returns the full
2121
% response from the OpenAI API call.
2222
%
23-
% Copyright 2023 The MathWorks, Inc.
23+
% Copyright 2023-2024 The MathWorks, Inc.
2424

2525
arguments
26-
text (1,:) {mustBeText}
26+
text (1,:) {mustBeNonzeroLengthText}
2727
nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["text-embedding-ada-002", ...
2828
"text-embedding-3-large", "text-embedding-3-small"])} = "text-embedding-ada-002"
2929
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10
30-
nvp.Dimensions (1,1) {mustBeInteger}
30+
nvp.Dimensions (1,1) {mustBeInteger,mustBePositive}
3131
nvp.ApiKey {llms.utils.mustBeNonzeroLengthTextScalar}
3232
end
3333

@@ -42,6 +42,7 @@
4242
error("llms:invalidOptionForModel", ...
4343
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "Dimensions", nvp.ModelName));
4444
end
45+
mustBeCorrectDimensions(nvp.Dimensions,nvp.ModelName);
4546
parameters.dimensions = nvp.Dimensions;
4647
end
4748

@@ -53,4 +54,17 @@
5354
emb = emb';
5455
else
5556
emb = [];
57+
end
58+
end
59+
60+
function mustBeCorrectDimensions(dimensions,modelName)
% mustBeCorrectDimensions   Validate the requested embedding size for a model.
%   mustBeCorrectDimensions(DIMENSIONS, MODELNAME) throws the error
%   "llms:dimensionsMustBeSmallerThan" when DIMENSIONS is greater than the
%   limit listed for MODELNAME: 3072 for "text-embedding-3-large" and 1536
%   for "text-embedding-3-small". Only these two models have an entry in
%   the lookup table below.

supportedModels = ["text-embedding-3-large", "text-embedding-3-small"];
upperBounds = [3072,1536];
limitByModel = dictionary(supportedModels, upperBounds);

% Look the limit up once; it is needed for both the check and the message.
maxDimensions = limitByModel(modelName);

% Values up to and including the limit are accepted.
if dimensions <= maxDimensions
    return
end

errorId = "llms:dimensionsMustBeSmallerThan";
error(errorId, ...
    llms.utils.errorMessageCatalog.getMessage(errorId, string(maxDimensions)));
end

Diff for: tests/textractOpenAIEmbeddings.m

+24-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
classdef textractOpenAIEmbeddings < matlab.unittest.TestCase
22
% Tests for extractOpenAIEmbeddings
33

4-
% Copyright 2023 The MathWorks, Inc.
4+
% Copyright 2023-2024 The MathWorks, Inc.
55

66
methods (TestClassSetup)
77
function saveEnvVar(testCase)
@@ -56,6 +56,14 @@ function testInvalidInputs(testCase, InvalidInput)
5656

5757
function invalidInput = iGetInvalidInput
5858
invalidInput = struct( ...
59+
"InvalidEmptyText", struct( ...
60+
"Input",{{ "" }},...
61+
"Error", "MATLAB:validators:mustBeNonzeroLengthText"), ...
62+
...
63+
"InvalidEmptyTextArray", struct( ...
64+
"Input",{{ ["", ""] }},...
65+
"Error", "MATLAB:validators:mustBeNonzeroLengthText"), ...
66+
...
5967
"InvalidTimeOutType", struct( ...
6068
"Input",{{ "bla", "TimeOut", "2" }},...
6169
"Error", "MATLAB:validators:mustBeReal"), ...
@@ -66,7 +74,7 @@ function testInvalidInputs(testCase, InvalidInput)
6674
...
6775
"WrongTypeText",struct( ...
6876
"Input",{{ 123 }},...
69-
"Error","MATLAB:validators:mustBeText"),...
77+
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
7078
...
7179
"InvalidModelNameType",struct( ...
7280
"Input",{{"bla", "ModelName", 0 }},...
@@ -84,6 +92,20 @@ function testInvalidInputs(testCase, InvalidInput)
8492
"Input",{{"bla", "Dimensions", "123" }},...
8593
"Error","MATLAB:validators:mustBeNumericOrLogical"),...
8694
...
95+
"InvalidDimensionValue",struct( ...
96+
"Input",{{"bla", "Dimensions", "-11" }},...
97+
"Error","MATLAB:validators:mustBeNumericOrLogical"),...
98+
...
99+
"LargeDimensionValueForModelLarge",struct( ...
100+
"Input",{{"bla", "ModelName", "text-embedding-3-large", ...
101+
"Dimensions", 3073, "ApiKey", "fake-key" }},...
102+
"Error","llms:dimensionsMustBeSmallerThan"),...
103+
...
104+
"LargeDimensionValueForModelSmall",struct( ...
105+
"Input",{{"bla", "ModelName", "text-embedding-3-small", ...
106+
"Dimensions", 1537, "ApiKey", "fake-key" }},...
107+
"Error","llms:dimensionsMustBeSmallerThan"),...
108+
...
87109
"InvalidDimensionSize",struct( ...
88110
"Input",{{"bla", "Dimensions", [123, 123] }},...
89111
"Error","MATLAB:validation:IncompatibleSize"),...

0 commit comments

Comments
 (0)