diff --git a/frontend/src/components/experimenter/experimenter_data_editor.ts b/frontend/src/components/experimenter/experimenter_data_editor.ts
index 9d508d749..e67198aa8 100644
--- a/frontend/src/components/experimenter/experimenter_data_editor.ts
+++ b/frontend/src/components/experimenter/experimenter_data_editor.ts
@@ -177,7 +177,7 @@ export class ExperimenterDataEditor extends MobxLitElement {
// ============ Local Ollama server ============
private renderServerSettings() {
- const updateServerSettings = (e: InputEvent, field: 'url' | 'llmType') => {
+ const updateServerSettings = (e: InputEvent, field: 'url') => {
const oldData = this.authService.experimenterData;
if (!oldData) return;
@@ -196,18 +196,6 @@ export class ExperimenterDataEditor extends MobxLitElement {
},
});
break;
-
- case "llmType":
- newData = updateExperimenterData(oldData, {
- apiKeys: {
- ...oldData.apiKeys,
- ollamaApiKey: {
- ...oldData.apiKeys.ollamaApiKey,
- llmType: value,
- },
- },
- });
- break;
default:
console.error("Error: field type not found: ", field);
return;
@@ -228,18 +216,6 @@ export class ExperimenterDataEditor extends MobxLitElement {
       >
       Please ensure that the URL is valid before proceeding.
-          @input=${(e: InputEvent) => updateServerSettings(e, 'llmType')}
-        >
-
-        All supported LLM types can be found
-        here.
-        Make sure the LLM type has been deployed on the server prior to selecting it here.
-
`;
}
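With `llmType` gone from the Ollama settings UI, the `field` union narrows to `'url'` and the switch is left with a single arm. A minimal sketch of what the handler reduces to (assumes the surrounding MobxLitElement class from the original file; `persist` is a placeholder for whatever the real code does with `newData` after the switch, which this hunk does not show):

```ts
// Sketch only: `persist` is a stand-in, not the real follow-up call.
const updateServerSettings = (e: InputEvent) => {
  const oldData = this.authService.experimenterData;
  if (!oldData) return;
  const url = (e.target as HTMLInputElement).value;
  const newData = updateExperimenterData(oldData, {
    apiKeys: {
      ...oldData.apiKeys,
      ollamaApiKey: {...oldData.apiKeys.ollamaApiKey, url},
    },
  });
  persist(newData); // placeholder
};
```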
diff --git a/frontend/src/components/stages/chat_editor.ts b/frontend/src/components/stages/chat_editor.ts
index 0786ba090..be3c150c8 100644
--- a/frontend/src/components/stages/chat_editor.ts
+++ b/frontend/src/components/stages/chat_editor.ts
@@ -214,11 +214,8 @@ export class ChatEditor extends MobxLitElement {
};
return html`
-        Model
-
-        Model ID for the agent. Currently only used for OpenAI and OAI-compatible APIs.
-
diff --git a/functions/package-lock.json b/functions/package-lock.json
--- a/functions/package-lock.json
+++ b/functions/package-lock.json
+        "node": ">=18"
+      }
+    },
"node_modules/@nodelib/fs.scandir": {
"version": "2.1.5",
"dev": true,
@@ -2642,6 +2661,31 @@
"node": ">= 8"
}
},
+ "node_modules/@open-draft/deferred-promise": {
+ "version": "2.2.0",
+ "resolved": "https://registry.npmjs.org/@open-draft/deferred-promise/-/deferred-promise-2.2.0.tgz",
+ "integrity": "sha512-CecwLWx3rhxVQF6V4bAgPS5t+So2sTbPgAzafKkVizyi7tlwpcFpdFqq+wqF2OwNBmqFuu6tOyouTuxgpMfzmA==",
+ "dev": true,
+ "license": "MIT"
+ },
+ "node_modules/@open-draft/logger": {
+ "version": "0.3.0",
+ "resolved": "https://registry.npmjs.org/@open-draft/logger/-/logger-0.3.0.tgz",
+ "integrity": "sha512-X2g45fzhxH238HKO4xbSr7+wBS8Fvw6ixhTDuvLd5mqh6bJJCFAPwU9mPDxbcrRtfxv4u5IHCEH77BmxvXmmxQ==",
+ "dev": true,
+ "license": "MIT",
+ "dependencies": {
+ "is-node-process": "^1.2.0",
+ "outvariant": "^1.4.0"
+ }
+ },
+ "node_modules/@open-draft/until": {
+ "version": "2.1.0",
+ "resolved": "https://registry.npmjs.org/@open-draft/until/-/until-2.1.0.tgz",
+ "integrity": "sha512-U69T3ItWHvLwGg5eJ0n3I62nWuE6ilHlmz7zM0npLBRvPRd7e6NYmg54vvRtP5mZG7kZqZCFVdsTWo7BPtBujg==",
+ "dev": true,
+ "license": "MIT"
+ },
"node_modules/@pkgr/core": {
"version": "0.1.1",
"resolved": "https://registry.npmjs.org/@pkgr/core/-/core-0.1.1.tgz",
@@ -6136,6 +6180,13 @@
"url": "https://github.com/sponsors/ljharb"
}
},
+ "node_modules/is-node-process": {
+ "version": "1.2.0",
+ "resolved": "https://registry.npmjs.org/is-node-process/-/is-node-process-1.2.0.tgz",
+ "integrity": "sha512-Vg4o6/fqPxIjtxgUH5QLJhwZ7gW5diGCVlXpuUfELC62CuxM1iHcRe51f2W1FDy04Ai4KJkagKjx3XaqyfRKXw==",
+ "dev": true,
+ "license": "MIT"
+ },
"node_modules/is-number": {
"version": "7.0.0",
"resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz",
@@ -6998,6 +7049,13 @@
"dev": true,
"license": "MIT"
},
+ "node_modules/json-stringify-safe": {
+ "version": "5.0.1",
+ "resolved": "https://registry.npmjs.org/json-stringify-safe/-/json-stringify-safe-5.0.1.tgz",
+ "integrity": "sha512-ZClg6AaYvamvYEE82d3Iyd3vSSIjQ+odgjaTzRuO3s7toCdFKczob2i0zCh7JE8kWn17yvAWhUVxvqGwUalsRA==",
+ "dev": true,
+ "license": "ISC"
+ },
"node_modules/json5": {
"version": "2.2.3",
"resolved": "https://registry.npmjs.org/json5/-/json5-2.2.3.tgz",
@@ -7413,6 +7471,21 @@
"node": ">= 0.6"
}
},
+ "node_modules/nock": {
+ "version": "14.0.1",
+ "resolved": "https://registry.npmjs.org/nock/-/nock-14.0.1.tgz",
+ "integrity": "sha512-IJN4O9pturuRdn60NjQ7YkFt6Rwei7ZKaOwb1tvUIIqTgeD0SDDAX3vrqZD4wcXczeEy/AsUXxpGpP/yHqV7xg==",
+ "dev": true,
+ "license": "MIT",
+ "dependencies": {
+ "@mswjs/interceptors": "^0.37.3",
+ "json-stringify-safe": "^5.0.1",
+ "propagate": "^2.0.0"
+ },
+ "engines": {
+ "node": ">=18.20.0 <20 || >=20.12.1"
+ }
+ },
"node_modules/node-abi": {
"version": "3.62.0",
"license": "MIT",
@@ -7683,6 +7756,13 @@
"node": ">= 0.8.0"
}
},
+ "node_modules/outvariant": {
+ "version": "1.4.3",
+ "resolved": "https://registry.npmjs.org/outvariant/-/outvariant-1.4.3.tgz",
+ "integrity": "sha512-+Sl2UErvtsoajRDKCE5/dBz4DIvHXQQnAxtQTF04OJxY0+DyZXSo5P5Bb7XYWOh81syohlYL24hbDwxedPUJCA==",
+ "dev": true,
+ "license": "MIT"
+ },
"node_modules/p-limit": {
"version": "3.1.0",
"devOptional": true,
@@ -8003,6 +8083,16 @@
"node": ">= 6"
}
},
+ "node_modules/propagate": {
+ "version": "2.0.1",
+ "resolved": "https://registry.npmjs.org/propagate/-/propagate-2.0.1.tgz",
+ "integrity": "sha512-vGrhOavPSTz4QVNuBNdcNXePNdNMaO1xj9yBeH1ScQPjk/rhg9sSlCXPhMkFuaNNW/syTvYqsnbIJxMBfRbbag==",
+ "dev": true,
+ "license": "MIT",
+ "engines": {
+ "node": ">= 8"
+ }
+ },
"node_modules/proto3-json-serializer": {
"version": "2.0.1",
"license": "Apache-2.0",
@@ -8734,6 +8824,13 @@
"license": "MIT",
"optional": true
},
+ "node_modules/strict-event-emitter": {
+ "version": "0.5.1",
+ "resolved": "https://registry.npmjs.org/strict-event-emitter/-/strict-event-emitter-0.5.1.tgz",
+ "integrity": "sha512-vMgjE/GGEPEFnhFub6pa4FmJBRBVOLpIII2hvCZ8Kzb7K0hlHo7mQv6xYrBvCL2LtAIBwFUK8wvuJgTVSQ5MFQ==",
+ "dev": true,
+ "license": "MIT"
+ },
"node_modules/string_decoder": {
"version": "1.3.0",
"license": "MIT",
diff --git a/functions/package.json b/functions/package.json
index 7cb8809d1..1d13ace00 100644
--- a/functions/package.json
+++ b/functions/package.json
@@ -49,6 +49,7 @@
"eslint-plugin-only-warn": "^1.1.0",
"eslint-plugin-prettier": "^5.1.3",
"firebase-functions-test": "^3.1.0",
+ "nock": "^14.0.1",
"ts-jest": "^29.2.5",
"typescript": "^4.9.0"
},
diff --git a/functions/src/agent.utils.ts b/functions/src/agent.utils.ts
index f246c9ea6..326a02ba0 100644
--- a/functions/src/agent.utils.ts
+++ b/functions/src/agent.utils.ts
@@ -17,11 +17,11 @@ export async function getAgentResponse(data: ExperimenterData, prompt: string, a
process.env.OPENAI_MODEL_NAME,
prompt)
} else if (keyType === ApiKeyType.GEMINI_API_KEY) {
- response = getGeminiResponse(data, prompt);
+ response = getGeminiResponse(data, agent.model, prompt);
} else if (keyType === ApiKeyType.OPENAI_API_KEY) {
response = getOpenAIAPIResponse(data, agent.model, prompt, agent.generationConfig);
} else if (keyType === ApiKeyType.OLLAMA_CUSTOM_URL) {
- response = await getOllamaResponse(data, prompt);
+ response = await getOllamaResponse(data, agent.model, prompt);
} else {
console.error("Error: invalid apiKey type: ", keyType)
response = {text: ""};
@@ -30,8 +30,8 @@ export async function getAgentResponse(data: ExperimenterData, prompt: string, a
return response
}
-export async function getGeminiResponse(data: ExperimenterData, prompt: string): Promise<ModelResponse> {
-  return await getGeminiAPIResponse(data.apiKeys.geminiApiKey, prompt);
+export async function getGeminiResponse(data: ExperimenterData, modelName: string, prompt: string): Promise<ModelResponse> {
+  return await getGeminiAPIResponse(data.apiKeys.geminiApiKey, modelName, prompt);
}
async function getOpenAIAPIResponse(
@@ -46,6 +46,8 @@ async function getOpenAIAPIResponse(
);
}
-export async function getOllamaResponse(data: ExperimenterData, prompt: string): Promise<ModelResponse> {
-  return await ollamaChat([prompt], data.apiKeys.ollamaApiKey);
+export async function getOllamaResponse(
+  data: ExperimenterData, modelName: string, prompt: string
+): Promise<ModelResponse> {
+  return await ollamaChat([prompt], modelName, data.apiKeys.ollamaApiKey);
 }
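The net effect in `agent.utils.ts` is that every provider branch now receives the per-agent model name instead of a hard-coded default. A hedged sketch of the resulting dispatch (the `AgentConfigLike` shape is an assumption; the diff only shows that `agent.model` exists):

```ts
// Sketch: how the model name flows from the agent config to each API call.
interface AgentConfigLike {
  model: string; // e.g. 'gemini-1.5-pro' or 'llama3.2'
}

async function routeByKeyType(
  data: ExperimenterData, prompt: string, agent: AgentConfigLike
) {
  switch (data.apiKeys.activeApiKeyType) {
    case ApiKeyType.GEMINI_API_KEY:
      return getGeminiResponse(data, agent.model, prompt);
    case ApiKeyType.OLLAMA_CUSTOM_URL:
      return getOllamaResponse(data, agent.model, prompt);
    default:
      return {text: ''};
  }
}
```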
diff --git a/functions/src/api/gemini.api.ts b/functions/src/api/gemini.api.ts
index cf6137a11..22633100a 100644
--- a/functions/src/api/gemini.api.ts
+++ b/functions/src/api/gemini.api.ts
@@ -65,6 +65,7 @@ export async function callGemini(
/** Constructs Gemini API query and returns response. */
export async function getGeminiAPIResponse(
apiKey: string,
+ modelName: string,
promptText: string,
stopSequences: string[] = [],
maxOutputTokens = 300,
@@ -86,7 +87,7 @@ export async function getGeminiAPIResponse(
apiKey,
promptText,
generationConfig,
- GEMINI_DEFAULT_MODEL
+ modelName
);
} catch (error: any) {
if (error.message.includes(QUOTA_ERROR_CODE.toString())) {
@@ -98,4 +99,4 @@ export async function getGeminiAPIResponse(
}
return response;
-}
\ No newline at end of file
+}
diff --git a/functions/src/api/model.response.ts b/functions/src/api/model.response.ts
index 7ba675698..44e135908 100644
--- a/functions/src/api/model.response.ts
+++ b/functions/src/api/model.response.ts
@@ -1,7 +1,7 @@
/**
* Common interface for all model responses.
*/
-interface ModelResponse {
+export interface ModelResponse {
score?: number;
text: string;
}
diff --git a/functions/src/api/ollama.api.test.ts b/functions/src/api/ollama.api.test.ts
index fcb497027..e0c733ff7 100644
--- a/functions/src/api/ollama.api.test.ts
+++ b/functions/src/api/ollama.api.test.ts
@@ -1,22 +1,32 @@
-import { ollamaChat } from "./ollama.api";
+import nock = require('nock');
-/**
- * Test assumes a container with ollama is running on port 11434
- * Download the docker image to run :
- * https://ollama.com/blog/ollama-is-now-available-as-an-official-docker-image
- *
- * Example docker instance hosting an ollama server: https://github.com/dimits-ts/deliberate-lab-utils/tree/master/llm_server
- */
+import { ollamaChat } from "./ollama.api";
-const MODEL_TYPE = "llama3.2";
+const MODEL_NAME = "llama3.2";
const LLM_SERVER_ENDPOINT = "http://localhost:11434/api/chat";
+const LLM_SERVER_HOST = "http://localhost:11434";
+const LLM_SERVER_PATH = "/api/chat";
const TEST_MESSAGE = "Say hello!";
describe("OllamaChat Client", () => {
- it("should return a response containing 'hello' (case insensitive)", async () => {
- const response = await ollamaChat([TEST_MESSAGE], {url: LLM_SERVER_ENDPOINT, llmType: MODEL_TYPE});
- expect(response.text.toLowerCase()).toContain("hello");
- console.log(response);
- });
-});
\ No newline at end of file
+ it("should return a response containing 'hello' (case insensitive)", async () => {
+ nock(LLM_SERVER_HOST)
+ .post(LLM_SERVER_PATH, body => body.model == MODEL_NAME)
+ .reply(200, {
+ 'created_at': Date.now(),
+ 'model': MODEL_NAME,
+ 'message': {
+ 'role': 'assistant',
+ 'content': 'Hello!',
+ },
+ 'done': true,
+ });
+
+ const response = await ollamaChat([TEST_MESSAGE], MODEL_NAME, {url: LLM_SERVER_ENDPOINT});
+ expect(response.text.toLowerCase()).toContain("hello");
+ console.log(response);
+
+ nock.cleanAll();
+ });
+});
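One caveat with the test above: `nock.cleanAll()` sits at the end of the test body, so a failing `expect` would throw past it and leave the interceptor registered for later tests. A common alternative (an editorial suggestion, not part of this change):

```ts
// Runs even when an assertion throws, so interceptors never leak across tests.
afterEach(() => {
  nock.cleanAll();
});
```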
diff --git a/functions/src/api/ollama.api.ts b/functions/src/api/ollama.api.ts
index 272b57c33..ed846e592 100644
--- a/functions/src/api/ollama.api.ts
+++ b/functions/src/api/ollama.api.ts
@@ -9,6 +9,7 @@
*/
import { OllamaServerConfig } from "@deliberation-lab/utils"
+import { ModelResponse } from './model.response';
/**
@@ -37,9 +38,10 @@ type OllamaMessage = {
* @returns the model's response as a string, or empty string if an error occured
*/
export async function ollamaChat(messages: string[],
+ modelName: string,
serverConfig: OllamaServerConfig)
  : Promise<ModelResponse> {
- const messageObjects = encodeMessages(messages, serverConfig.llmType);
+ const messageObjects = encodeMessages(messages, modelName);
const response = await fetch(serverConfig.url, { method: "POST", body: JSON.stringify(messageObjects) });
const responseMessage = await decodeResponse(response);
return { text: responseMessage };
@@ -56,7 +58,7 @@ async function decodeResponse(response: Response): Promise<string> {
throw new Error("Failed to read response body");
}
- const { done, value } = await reader.read();
+ const { done: _, value } = await reader.read();
const rawjson = new TextDecoder().decode(value);
if (isError(rawjson)) {
@@ -72,14 +74,14 @@ async function decodeResponse(response: Response): Promise<string> {
/**
* Transform string-messages to JSON objects appropriate for the model's API.
* @param messages a list of string-messages to be sent to the LLM
- * @param modelType the type of llm running in the server (e.g. "llama3.2").
+ * @param modelName the name of the model running on the server (e.g. "llama3.2").
* Keep in mind that the model must have been loaded server-side in order to be used.
* @returns appropriate JSON objects which the model can understand
*/
-function encodeMessages(messages: string[], modelType: string): OutgoingMessage {
+function encodeMessages(messages: string[], modelName: string): OutgoingMessage {
const messageObjs: OllamaMessage[] = messages.map((message) => ({ role: "user", content: message }));
return {
- model: modelType,
+ model: modelName,
messages: messageObjs,
stream: false
};
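For reference, the POST body `ollamaChat` now sends is what the nock matcher in the test (`body => body.model == MODEL_NAME`) inspects. A sketch of the payload `encodeMessages` produces for a single user message:

```ts
// Result of encodeMessages(['Say hello!'], 'llama3.2'):
const payload = {
  model: 'llama3.2',
  messages: [{role: 'user', content: 'Say hello!'}],
  stream: false,
};
```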
diff --git a/functions/src/api/openai.api.test.ts b/functions/src/api/openai.api.test.ts
new file mode 100644
index 000000000..f47d5b246
--- /dev/null
+++ b/functions/src/api/openai.api.test.ts
@@ -0,0 +1,47 @@
+import nock = require('nock');
+
+import { AgentGenerationConfig } from '@deliberation-lab/utils';
+import { getOpenAIAPITextCompletionResponse } from './openai.api';
+import { ModelResponse } from './model.response';
+
+
+describe('OpenAI-compatible API', () => {
+ it('handles text completion request', async () => {
+ nock('https://test.uri')
+ .post('/v1/completions', body => body.model == 'test-model')
+ .reply(200, {
+ 'id': 'test-id',
+ 'object': 'text_completion',
+ 'created': Date.now(),
+ 'model': 'test-model',
+ 'choices': [
+ {
+ 'text': 'test output',
+ 'index': 0,
+ 'logprobs': null,
+ 'finish_reason': 'stop',
+ }
+ ],
+ });
+
+ const generationConfig: AgentGenerationConfig = {
+ temperature: 0.7,
+ topP: 1,
+ frequencyPenalty: 0,
+ presencePenalty: 0,
+ customRequestBodyFields: [ {name: 'foo', value: 'bar'} ]
+ };
+
+ const response: ModelResponse = await getOpenAIAPITextCompletionResponse(
+ 'testapikey',
+ 'https://test.uri/v1/',
+ 'test-model',
+ 'This is a test prompt.',
+ generationConfig
+ );
+
+ expect(response.text).toEqual('test output');
+
+ nock.cleanAll();
+ });
+});
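The matcher above only checks `body.model`. Since `customRequestBodyFields` is spread into the request (see the `...customFields` line in `openai.api.ts` below), the test could be tightened to verify that too; a sketch of the stricter matcher:

```ts
// Stricter variant: also assert the custom field was merged into the body.
nock('https://test.uri')
  .post('/v1/completions', body => body.model === 'test-model' && body.foo === 'bar')
  .reply(200, {/* same canned completion as above */});
```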
diff --git a/functions/src/api/openai.api.ts b/functions/src/api/openai.api.ts
index 18b3524a5..686515a8d 100644
--- a/functions/src/api/openai.api.ts
+++ b/functions/src/api/openai.api.ts
@@ -2,6 +2,7 @@ import OpenAI from "openai"
import {
AgentGenerationConfig
} from '@deliberation-lab/utils';
+import { ModelResponse } from './model.response';
const MAX_TOKENS_FINISH_REASON = "length";
@@ -10,7 +11,7 @@ export async function callOpenAITextCompletion(
baseUrl: string | null,
modelName: string,
prompt: string,
- generationConfig: agentGenerationConfig
+ generationConfig: AgentGenerationConfig
) {
const client = new OpenAI({
apiKey: apiKey,
@@ -27,7 +28,6 @@ export async function callOpenAITextCompletion(
top_p: generationConfig.topP,
frequency_penalty: generationConfig.frequencyPenalty,
presence_penalty: generationConfig.presencePenalty,
- // @ts-expect-error allow extra request fields
...customFields
});
@@ -37,7 +37,7 @@ export async function callOpenAITextCompletion(
return { text: '' };
}
- const finishReason = response.choices[0].finishReason;
+ const finishReason = response.choices[0].finish_reason;
if (finishReason === MAX_TOKENS_FINISH_REASON) {
console.error(
`Error: Token limit exceeded`
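The `finish_reason` fix matters because the OpenAI SDK returns snake_case response fields; the old camelCase read was always `undefined`, so the token-limit warning could never fire. A sketch of the corrected check:

```ts
// With the snake_case field, the guard actually triggers on truncation.
const choice = response.choices[0];
if (choice.finish_reason === MAX_TOKENS_FINISH_REASON) {
  console.error('Error: Token limit exceeded');
}
```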
diff --git a/utils/src/experimenter.test.ts b/utils/src/experimenter.test.ts
index 5cbe9e8c5..7948809ff 100644
--- a/utils/src/experimenter.test.ts
+++ b/utils/src/experimenter.test.ts
@@ -33,21 +33,21 @@ describe('checkApiKeyExists', () => {
test('returns false if active API key type is Ollama and ollamaApiKey is invalid', () => {
experimenterData.apiKeys.activeApiKeyType = ApiKeyType.OLLAMA_CUSTOM_URL;
- experimenterData.apiKeys.ollamaApiKey = { url: '' , llmType: "llama3.2"};
+ experimenterData.apiKeys.ollamaApiKey = { url: '' };
expect(checkApiKeyExists(experimenterData)).toBe(false);
});
test('returns true if active API key type is Ollama and ollamaApiKey is valid', () => {
experimenterData.apiKeys.activeApiKeyType = ApiKeyType.OLLAMA_CUSTOM_URL;
- experimenterData.apiKeys.ollamaApiKey = { url: 'http://valid-url.com' , llmType: "llama3.2" };
+ experimenterData.apiKeys.ollamaApiKey = { url: 'http://valid-url.com' };
expect(checkApiKeyExists(experimenterData)).toBe(true);
});
-	test('returns false if active API key type is Ollama and ollamaApiKey is invalid', () => {
-		experimenterData.apiKeys.activeApiKeyType = ApiKeyType.OLLAMA_CUSTOM_URL;
-		experimenterData.apiKeys.ollamaApiKey = { url: 'http://valid-url.com' , llmType: '' };
-		expect(checkApiKeyExists(experimenterData)).toBe(false);
-	});
diff --git a/utils/src/experimenter.ts b/utils/src/experimenter.ts
index 8dcb110ae..853c2cbea 100644
--- a/utils/src/experimenter.ts
+++ b/utils/src/experimenter.ts
@@ -52,11 +52,6 @@ export interface OpenAIServerConfig {
export interface OllamaServerConfig {
url: string;
- /*
- * The type of llm running in the server (e.g. "llama3.2").
- * Keep in mind that the model must have been loaded server-side in order to be used.
- */
- llmType: string
// port: number; // apparently not needed? https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
// will probably need more data for server-side auth?
}
@@ -67,7 +62,6 @@ export interface OllamaServerConfig {
// ************************************************************************* //
const INVALID_API_KEY = ''
-const INVALID_LLM_TYPE = ''
const EMPTY_BASE_URL = ''
// ************************************************************************* //
@@ -101,7 +95,7 @@ export function createExperimenterData(
apiKeys: {
geminiApiKey: INVALID_API_KEY,
openAIApiKey: createOpenAIServerConfig(),
- ollamaApiKey: { url: INVALID_API_KEY, llmType: INVALID_LLM_TYPE },
+ ollamaApiKey: { url: INVALID_API_KEY },
activeApiKeyType: ApiKeyType.GEMINI_API_KEY
},
email: experimenterEmail
@@ -133,8 +127,7 @@ export function checkApiKeyExists(experimenterData: ExperimenterData | null | un
if (experimenterData.apiKeys.activeApiKeyType === ApiKeyType.OLLAMA_CUSTOM_URL) {
// implicitly checks if llamaApiKey exists
return (
- (experimenterData.apiKeys.ollamaApiKey.url !== INVALID_API_KEY) &&
- (experimenterData.apiKeys.ollamaApiKey.llmType !== INVALID_LLM_TYPE)
+ (experimenterData.apiKeys.ollamaApiKey.url !== INVALID_API_KEY)
);
}
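After this change an Ollama configuration is just the endpoint URL, and the model is chosen per agent rather than per server. A usage sketch (the `createExperimenterData` argument list is abbreviated; only fields shown in this diff are assumed):

```ts
// Only the URL is validated now; no llmType check remains.
const data = createExperimenterData('experimenter-id', 'person@example.com');
data.apiKeys.activeApiKeyType = ApiKeyType.OLLAMA_CUSTOM_URL;
data.apiKeys.ollamaApiKey = {url: 'http://localhost:11434/api/chat'};
checkApiKeyExists(data); // true
```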