Skip to content

Commit

Permalink
Merge pull request #33201 from vespa-engine/glebashnik/fix-openai-lm-component
Browse files Browse the repository at this point in the history

Glebashnik/fix openai lm component
  • Loading branch information
bjorncs authored Jan 28, 2025
2 parents 1520b06 + 908700b commit 965ed30
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 37 deletions.
30 changes: 17 additions & 13 deletions model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
/**
 * A configurable OpenAI client.
 *
 * @author lesters
 * @author glebashnik
 */
@Beta
public class OpenAI extends ConfigurableLanguageModel {
Expand All @@ -45,23 +46,26 @@ public OpenAI(LlmClientConfig config, Secrets secretStore) {
}

}

@Override
public List<Completion> complete(Prompt prompt, InferenceParameters parameters) {
var combinedParameters = parameters.withDefaultOptions(configOptions::get);

private InferenceParameters prepareParameters(InferenceParameters parameters) {
setApiKey(parameters);
setEndpoint(parameters);
return client.complete(prompt, combinedParameters);
var combinedParameters = parameters.withDefaultOptions(configOptions::get);
return combinedParameters;
}

@Override
public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
InferenceParameters parameters,
Consumer<Completion> consumer) {
var combinedParameters = parameters.withDefaultOptions(configOptions::get);
setApiKey(parameters);
setEndpoint(parameters);
return client.completeAsync(prompt, combinedParameters, consumer);
public List<Completion> complete(
Prompt prompt, InferenceParameters parameters) {
var preparedParameters = prepareParameters(parameters);
return client.complete(prompt, preparedParameters);
}

/**
 * Streams a completion for the given prompt, feeding partial completions to
 * the consumer as they arrive.
 *
 * @param prompt the prompt to complete
 * @param parameters inference parameters; config defaults are applied underneath
 * @param consumer receives each partial {@link Completion}
 * @return a future resolving to the finish reason of the completion
 */
@Override
public CompletableFuture<Completion.FinishReason> completeAsync(
        Prompt prompt, InferenceParameters parameters, Consumer<Completion> consumer) {
    return client.completeAsync(prompt, prepareParameters(parameters), consumer);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,84 @@

import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.completion.StringPrompt;
import com.yahoo.container.jdisc.SecretsProvider;
import ai.vespa.secret.Secret;
import ai.vespa.secret.Secrets;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.Map;

public class OpenAITest {

private static final String apiKey = "<your-api-key>";

private static final String API_KEY = "<YOUR_API_KEY>";
@Test
@Disabled
public void testOpenAIGeneration() {
var config = new LlmClientConfig.Builder().build();
var openai = new OpenAI(config, new SecretsProvider().get());
public void testComplete() {
var config = new LlmClientConfig.Builder()
.apiKeySecretName("openai")
.maxTokens(10)
.build();
var openai = new OpenAI(config, new MockSecrets());
var options = Map.of(
"maxTokens", "10"
"model", "gpt-4o-mini"
);

var prompt = StringPrompt.from("why are ducks better than cats?");
var future = openai.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> {
System.out.print(completion.text());
}).exceptionally(exception -> {
System.out.println("Error: " + exception);
return null;
});
future.join();
var prompt = StringPrompt.from("Explain why ducks better than cats in 20 words?");
var completions = openai.complete(prompt, new InferenceParameters(options::get));
var text = completions.get(0).text();

System.out.print(text);
assertNumTokens(text, 3, 10);
}

@Test
@Disabled
public void testComplete() {
var config = new LlmClientConfig.Builder().maxTokens(10).build();
var openai = new OpenAI(config, new SecretsProvider().get());
public void testCompleteAsync() {
var config = new LlmClientConfig.Builder()
.apiKeySecretName("openai")
.maxTokens(10)
.build();
var openai = new OpenAI(config, new MockSecrets());
var options = Map.of(
"model", "gpt-4o-mini"
);
var prompt = StringPrompt.from("Explain why ducks better than cats in 20 words?");
var completions = openai.complete(prompt, new InferenceParameters(apiKey, options::get));
assertFalse(completions.isEmpty());
var text = new StringBuilder();

// Token is smaller than word.
var future = openai.completeAsync(prompt, new InferenceParameters(API_KEY, options::get), completion -> {
text.append(completion.text());
}).exceptionally(exception -> {
System.out.println("Error: " + exception);
return null;
});
future.join();

System.out.print(text);
assertNumTokens(text.toString(), 3, 10);
}

/**
 * Asserts the completion's token count lies in [minTokens, maxTokens].
 * Splitting by space is a poor tokenizer but it is good enough for this test.
 */
private void assertNumTokens(String completion, int minTokens, int maxTokens) {
    var numTokens = completion.split(" ").length;
    assertTrue(minTokens <= numTokens && numTokens <= maxTokens);
}

/** Test double that serves the hard-coded API key for the "openai" secret only. */
static class MockSecrets implements Secrets {
    @Override
    public Secret get(String key) {
        // Guard clause: any secret name other than "openai" is unknown.
        if (!key.equals("openai")) {
            return null;
        }
        return new Secret() {
            @Override
            public String current() {
                return API_KEY;
            }
        };
    }
}

}

0 comments on commit 965ed30

Please sign in to comment.