Read model name from agent config in Gemini and Ollama APIs. #447

Merged · 4 commits · Feb 12, 2025
26 changes: 1 addition & 25 deletions frontend/src/components/experimenter/experimenter_data_editor.ts
@@ -177,7 +177,7 @@ export class ExperimenterDataEditor extends MobxLitElement {

// ============ Local Ollama server ============
private renderServerSettings() {
const updateServerSettings = (e: InputEvent, field: 'url' | 'llmType') => {
const updateServerSettings = (e: InputEvent, field: 'url') => {
const oldData = this.authService.experimenterData;
if (!oldData) return;

@@ -196,18 +196,6 @@ export class ExperimenterDataEditor extends MobxLitElement {
},
});
break;

case "llmType":
newData = updateExperimenterData(oldData, {
apiKeys: {
...oldData.apiKeys,
ollamaApiKey: {
...oldData.apiKeys.ollamaApiKey,
llmType: value,
},
},
});
break;
default:
console.error("Error: field type not found: ", field);
return;
@@ -228,18 +216,6 @@ export class ExperimenterDataEditor extends MobxLitElement {
></pr-textarea>
<p>Please ensure that the URL is valid before proceeding.</p>

<pr-textarea
label="LLM type"
placeholder="llama3.2"
variant="outlined"
.value=${data?.apiKeys.ollamaApiKey?.llmType ?? ""}
@input=${(e: InputEvent) => updateServerSettings(e, 'llmType')}
></pr-textarea>
<p>
All supported LLM types can be found
<a target="_blank" href="https://ollama.com/library">here</a>.
Make sure the LLM type has been deployed on the server prior to selecting it here.
</p>
</div>
`;
}
5 changes: 1 addition & 4 deletions frontend/src/components/stages/chat_editor.ts
@@ -214,11 +214,8 @@ export class ChatEditor extends MobxLitElement {
};

return html`
<div class="question-label">Model</div>
<div class="description">
Model ID for the agent. Currently only used for OpenAI and OAI-compatible APIs.
</div>
<pr-textarea
label="Model"
placeholder="Model ID"
variant="outlined"
.value=${agent.model}
97 changes: 97 additions & 0 deletions functions/package-lock.json

Generated lockfile; diff not rendered.

1 change: 1 addition & 0 deletions functions/package.json
@@ -49,6 +49,7 @@
"eslint-plugin-only-warn": "^1.1.0",
"eslint-plugin-prettier": "^5.1.3",
"firebase-functions-test": "^3.1.0",
"nock": "^14.0.1",
"ts-jest": "^29.2.5",
"typescript": "^4.9.0"
},
12 changes: 7 additions & 5 deletions functions/src/agent.utils.ts
@@ -17,11 +17,11 @@ export async function getAgentResponse(data: ExperimenterData, prompt: string, a
process.env.OPENAI_MODEL_NAME,
prompt)
} else if (keyType === ApiKeyType.GEMINI_API_KEY) {
response = getGeminiResponse(data, prompt);
response = getGeminiResponse(data, agent.model, prompt);
} else if (keyType === ApiKeyType.OPENAI_API_KEY) {
response = getOpenAIAPIResponse(data, agent.model, prompt, agent.generationConfig);
} else if (keyType === ApiKeyType.OLLAMA_CUSTOM_URL) {
response = await getOllamaResponse(data, prompt);
response = await getOllamaResponse(data, agent.model, prompt);
} else {
console.error("Error: invalid apiKey type: ", keyType)
response = {text: ""};
@@ -30,8 +30,8 @@ export async function getAgentResponse(data: ExperimenterData, prompt: string, a
return response
}

export async function getGeminiResponse(data: ExperimenterData, prompt: string): Promise<ModelResponse> {
return await getGeminiAPIResponse(data.apiKeys.geminiApiKey, prompt);
export async function getGeminiResponse(data: ExperimenterData, modelName: string, prompt: string): Promise<ModelResponse> {
return await getGeminiAPIResponse(data.apiKeys.geminiApiKey, modelName, prompt);
}

async function getOpenAIAPIResponse(
@@ -46,6 +46,8 @@ async function getOpenAIAPIResponse(
);
}

export async function getOllamaResponse(data: ExperimenterData, prompt: string): Promise<ModelResponse> {
export async function getOllamaResponse(
data: ExperimenterData, modelName: string, prompt: string
): Promise<ModelResponse> {
return await ollamaChat([prompt], data.apiKeys.ollamaApiKey);
}
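
Note on usage: both helpers now take the model name from the per-agent config instead of a global default. A minimal sketch of how they might be called — assuming `ExperimenterData` is exported from `@deliberation-lab/utils` alongside the other shared types, and using purely illustrative model IDs:

```typescript
// Sketch only: the model IDs below are examples, not values mandated by this PR.
import { ExperimenterData } from "@deliberation-lab/utils";
import { getGeminiResponse, getOllamaResponse } from "./agent.utils";

async function demo(data: ExperimenterData) {
  // Model names now come from the agent config (agent.model) rather than a constant.
  const gemini = await getGeminiResponse(data, "gemini-1.5-flash", "Say hello!");
  const ollama = await getOllamaResponse(data, "llama3.2", "Say hello!");
  console.log(gemini.text, ollama.text);
}
```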
5 changes: 3 additions & 2 deletions functions/src/api/gemini.api.ts
@@ -65,6 +65,7 @@ export async function callGemini(
/** Constructs Gemini API query and returns response. */
export async function getGeminiAPIResponse(
apiKey: string,
modelName: string,
promptText: string,
stopSequences: string[] = [],
maxOutputTokens = 300,
@@ -86,7 +87,7 @@
apiKey,
promptText,
generationConfig,
GEMINI_DEFAULT_MODEL
modelName
);
} catch (error: any) {
if (error.message.includes(QUOTA_ERROR_CODE.toString())) {
@@ -98,4 +99,4 @@
}

return response;
}
}
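
With `modelName` threaded through, a direct call no longer falls back to `GEMINI_DEFAULT_MODEL`. A hedged sketch of calling the helper — the model ID is only an example, and the optional parameters keep their defaults:

```typescript
import { getGeminiAPIResponse } from "./gemini.api";

async function demo(apiKey: string) {
  // stopSequences and maxOutputTokens fall back to their defaults ([], 300).
  const response = await getGeminiAPIResponse(apiKey, "gemini-1.5-flash", "Say hello!");
  console.log(response.text);
}
```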
2 changes: 1 addition & 1 deletion functions/src/api/model.response.ts
@@ -1,7 +1,7 @@
/**
* Common interface for all model responses.
*/
interface ModelResponse {
export interface ModelResponse {
score?: number;
text: string;
}
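
Exporting the interface lets other modules type their results against it; `ollama.api.ts` below does exactly this. A minimal illustration:

```typescript
import { ModelResponse } from "./model.response";

// Any helper can now declare the shared response shape explicitly.
function emptyResponse(): ModelResponse {
  return { text: "" };
}
```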
40 changes: 25 additions & 15 deletions functions/src/api/ollama.api.test.ts
@@ -1,22 +1,32 @@
import { ollamaChat } from "./ollama.api";
import nock = require('nock');

/**
* Test assumes a container with ollama is running on port 11434
* Download the docker image to run :
* https://ollama.com/blog/ollama-is-now-available-as-an-official-docker-image
*
* Example docker instance hosting an ollama server: https://github.com/dimits-ts/deliberate-lab-utils/tree/master/llm_server
*/
import { ollamaChat } from "./ollama.api";

const MODEL_TYPE = "llama3.2";
const MODEL_NAME = "llama3.2";
const LLM_SERVER_ENDPOINT = "http://localhost:11434/api/chat";
const LLM_SERVER_HOST = "http://localhost:11434";
const LLM_SERVER_PATH = "/api/chat";
const TEST_MESSAGE = "Say hello!";


describe("OllamaChat Client", () => {
it("should return a response containing 'hello' (case insensitive)", async () => {
const response = await ollamaChat([TEST_MESSAGE], {url: LLM_SERVER_ENDPOINT, llmType: MODEL_TYPE});
expect(response.text.toLowerCase()).toContain("hello");
console.log(response);
});
});
it("should return a response containing 'hello' (case insensitive)", async () => {
nock(LLM_SERVER_HOST)
.post(LLM_SERVER_PATH, body => body.model == MODEL_NAME)
.reply(200, {
'created_at': Date.now(),
'model': MODEL_NAME,
'message': {
'role': 'assistant',
'content': 'Hello!',
},
'done': true,
});

const response = await ollamaChat([TEST_MESSAGE], MODEL_NAME, {url: LLM_SERVER_ENDPOINT});
expect(response.text.toLowerCase()).toContain("hello");
console.log(response);

nock.cleanAll();
});
});
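
Since nock intercepts the POST, the test no longer needs a live Ollama container. If stricter coverage were wanted, the matcher could also assert the full request body built by `encodeMessages` — a sketch only, with field names taken from `ollama.api.ts` below:

```typescript
// Hypothetical stricter interceptor: checks model, stream flag, and message payload.
nock(LLM_SERVER_HOST)
  .post(LLM_SERVER_PATH, (body) =>
    body.model === MODEL_NAME &&
    body.stream === false &&
    Array.isArray(body.messages) &&
    body.messages[0]?.role === "user" &&
    body.messages[0]?.content === TEST_MESSAGE)
  .reply(200, {
    model: MODEL_NAME,
    message: { role: "assistant", content: "Hello!" },
    done: true,
  });
```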
12 changes: 7 additions & 5 deletions functions/src/api/ollama.api.ts
@@ -9,6 +9,7 @@
*/

import { OllamaServerConfig } from "@deliberation-lab/utils"
import { ModelResponse } from './model.response';


/**
@@ -37,9 +38,10 @@ type OllamaMessage = {
* @returns the model's response as a string, or empty string if an error occurred
*/
export async function ollamaChat(messages: string[],
modelName: string,
serverConfig: OllamaServerConfig)
: Promise<ModelResponse> {
const messageObjects = encodeMessages(messages, serverConfig.llmType);
const messageObjects = encodeMessages(messages, modelName);
const response = await fetch(serverConfig.url, { method: "POST", body: JSON.stringify(messageObjects) });
const responseMessage = await decodeResponse(response);
return { text: responseMessage };
@@ -56,7 +58,7 @@ async function decodeResponse(response: Response): Promise<string> {
throw new Error("Failed to read response body");
}

const { done, value } = await reader.read();
const { done: _, value } = await reader.read();
const rawjson = new TextDecoder().decode(value);

if (isError(rawjson)) {
@@ -72,14 +74,14 @@
/**
* Transform string-messages to JSON objects appropriate for the model's API.
* @param messages a list of string-messages to be sent to the LLM
* @param modelType the type of llm running in the server (e.g. "llama3.2").
* @param modelName the type of llm running in the server (e.g. "llama3.2").
* Keep in mind that the model must have been loaded server-side in order to be used.
* @returns appropriate JSON objects which the model can understand
*/
function encodeMessages(messages: string[], modelType: string): OutgoingMessage {
function encodeMessages(messages: string[], modelName: string): OutgoingMessage {
const messageObjs: OllamaMessage[] = messages.map((message) => ({ role: "user", content: message }));
return {
model: modelType,
model: modelName,
messages: messageObjs,
stream: false
};
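
For reference, a usage sketch of the updated `ollamaChat` signature and the request body `encodeMessages` now produces. This assumes `OllamaServerConfig` still accepts a plain `{ url }` object now that `llmType` is unused; the model ID is illustrative:

```typescript
import { ollamaChat } from "./ollama.api";

async function demo() {
  const serverConfig = { url: "http://localhost:11434/api/chat" }; // assumption: URL-only config
  const response = await ollamaChat(["Say hello!"], "llama3.2", serverConfig);
  // encodeMessages serializes the request as:
  // { "model": "llama3.2", "messages": [{ "role": "user", "content": "Say hello!" }], "stream": false }
  console.log(response.text);
}
```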