diff --git a/lib/prompt.js b/lib/prompt.js index fb5cc80..390aca3 100644 --- a/lib/prompt.js +++ b/lib/prompt.js @@ -4,7 +4,7 @@ const DEFAULT_ENDPOINT = "https://api.githubcopilot.com/chat/completions"; const DEFAULT_MODEL = "gpt-4o"; /** @type {import('..').PromptInterface} */ -function parsePromptArguments(userPrompt, promptOptions) { +export function parsePromptArguments(userPrompt, promptOptions) { const { request: requestOptions, ...options } = typeof userPrompt === "string" ? promptOptions : userPrompt; diff --git a/test/prompt.test.js b/test/prompt.test.js index 12df603..054297e 100644 --- a/test/prompt.test.js +++ b/test/prompt.test.js @@ -3,6 +3,7 @@ import test from "ava"; import { MockAgent } from "undici"; import { prompt, getFunctionCalls } from "../index.js"; +import { parsePromptArguments } from "../lib/prompt.js"; test("smoke", (t) => { t.is(typeof prompt, "function"); @@ -479,3 +480,58 @@ test("does not include function calls", async (t) => { t.deepEqual(result, []); }); + +test("parsePromptArguments - uses Node fetch if no options.fetch passed as argument", (t) => { + const [parsedFetch] = parsePromptArguments( + "What is the capital of France?", + {} + ); + + t.deepEqual(fetch, parsedFetch); +}); + +test("prompt.stream", async (t) => { + const mockAgent = new MockAgent(); + function fetchMock(url, opts) { + opts ||= {}; + opts.dispatcher = mockAgent; + return fetch(url, opts); + } + + mockAgent.disableNetConnect(); + const mockPool = mockAgent.get("https://api.githubcopilot.com"); + mockPool + .intercept({ + method: "post", + path: `/chat/completions`, + }) + .reply(200, "", { + headers: { + "content-type": "text/plain", + "x-request-id": "", + }, + }); + + const { requestId, stream } = await prompt.stream( + "What is the capital of France?", + { + token: "secret", + request: { + fetch: fetchMock, + }, + } + ); + + t.is(requestId, ""); + + let data = ""; + const reader = stream.getReader(); + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + data += new TextDecoder().decode(value); + } + + t.deepEqual(data, ""); +});