diff --git a/packages/inference/README.md b/packages/inference/README.md
index 83bfdc7f01..ad5c9fc1a6 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -48,6 +48,7 @@ You can send inference requests to third-party providers with the inference clie
 
 Currently, we support the following providers:
 - [Fal.ai](https://fal.ai)
+- [Featherless AI](https://featherless.ai)
 - [Fireworks AI](https://fireworks.ai)
 - [Hyperbolic](https://hyperbolic.xyz)
 - [Nebius](https://studio.nebius.ai)
@@ -78,6 +79,7 @@ When authenticated with a third-party provider key, the request is made directly
 
 Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline tasks here:
 - [Fal.ai supported models](https://huggingface.co/api/partners/fal-ai/models)
+- [Featherless AI supported models](https://huggingface.co/api/partners/featherless-ai/models)
 - [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)
 - [Hyperbolic supported models](https://huggingface.co/api/partners/hyperbolic/models)
 - [Nebius supported models](https://huggingface.co/api/partners/nebius/models)
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 0c8de60326..14bd941987 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -2,10 +2,10 @@ import * as BlackForestLabs from "../providers/black-forest-labs";
 import * as Cerebras from "../providers/cerebras";
 import * as Cohere from "../providers/cohere";
 import * as FalAI from "../providers/fal-ai";
+import * as FeatherlessAI from "../providers/featherless-ai";
 import * as Fireworks from "../providers/fireworks-ai";
 import * as Groq from "../providers/groq";
 import * as HFInference from "../providers/hf-inference";
-
 import * as Hyperbolic from "../providers/hyperbolic";
 import * as Nebius from "../providers/nebius";
 import * as Novita from "../providers/novita";
@@ -64,6 +64,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
 		"text-to-speech": new FalAI.FalAITextToSpeechTask(),
 		"text-to-video": new FalAI.FalAITextToVideoTask(),
 	},
+	"featherless-ai": {
+		conversational: new FeatherlessAI.FeatherlessAIConversationalTask(),
+		"text-generation": new FeatherlessAI.FeatherlessAITextGenerationTask(),
+	},
 	"fireworks-ai": {
 		conversational: new Fireworks.FireworksConversationalTask(),
 	},
diff --git a/packages/inference/src/providers/featherless-ai.ts b/packages/inference/src/providers/featherless-ai.ts
new file mode 100644
--- /dev/null
+++ b/packages/inference/src/providers/featherless-ai.ts
@@ -0,0 +1,52 @@
+import type { ChatCompletionOutput, TextGenerationOutput, TextGenerationOutputFinishReason } from "@huggingface/tasks";
+import { InferenceOutputError } from "../lib/InferenceOutputError";
+import type { BodyParams } from "../types";
+import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";
+
+interface FeatherlessAITextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
+	choices: Array<{
+		text: string;
+		finish_reason: TextGenerationOutputFinishReason;
+		seed: number;
+		logprobs: unknown;
+		index: number;
+	}>;
+}
+
+const FEATHERLESS_API_BASE_URL = "https://api.featherless.ai";
+
+export class FeatherlessAIConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("featherless-ai", FEATHERLESS_API_BASE_URL);
+	}
+}
+
+export class FeatherlessAITextGenerationTask extends BaseTextGenerationTask {
+	constructor() {
+		super("featherless-ai", FEATHERLESS_API_BASE_URL);
+	}
+
+	override preparePayload(params: BodyParams): Record<string, unknown> {
+		return {
+			...params.args,
+			...params.args.parameters,
+			model: params.model,
+			prompt: params.args.inputs,
+		};
+	}
+
+	override async getResponse(response: FeatherlessAITextCompletionOutput): Promise<TextGenerationOutput> {
+		if (
+			typeof response === "object" &&
+			"choices" in response &&
+			Array.isArray(response?.choices) &&
+			typeof response?.model === "string"
+		) {
+			const completion = response.choices[0];
+			return {
+				generated_text: completion.text,
+			};
+		}
+		throw new InferenceOutputError("Expected Featherless AI text generation response format");
+	}
+}
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index b1db14c567..e5870f6ef3 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -42,6 +42,7 @@ export const INFERENCE_PROVIDERS = [
 	"cerebras",
 	"cohere",
 	"fal-ai",
+	"featherless-ai",
 	"fireworks-ai",
"groq", "hf-inference", diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index 60be17b186..c64a396d37 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -1045,6 +1045,79 @@ describe.skip("InferenceClient", () => { TIMEOUT ); + describe.concurrent( + "Featherless", + () => { + HARDCODED_MODEL_INFERENCE_MAPPING["featherless-ai"] = { + "meta-llama/Llama-3.1-8B": { + providerId: "meta-llama/Meta-Llama-3.1-8B", + hfModelId: "meta-llama/Llama-3.1-8B", + task: "text-generation", + status: "live", + }, + "meta-llama/Llama-3.1-8B-Instruct": { + providerId: "meta-llama/Meta-Llama-3.1-8B-Instruct", + hfModelId: "meta-llama/Llama-3.1-8B-Instruct", + task: "text-generation", + status: "live", + }, + }; + + it("chatCompletion", async () => { + const res = await chatCompletion({ + accessToken: env.HF_FEATHERLESS_KEY ?? "dummy", + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "featherless-ai", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + temperature: 0.1, + }); + + expect(res).toBeDefined(); + expect(res.choices).toBeDefined(); + expect(res.choices?.length).toBeGreaterThan(0); + + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toBeDefined(); + expect(typeof completion).toBe("string"); + expect(completion).toContain("two"); + } + }); + + it("chatCompletion stream", async () => { + const stream = chatCompletionStream({ + accessToken: env.HF_FEATHERLESS_KEY ?? "dummy", + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "featherless-ai", + messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }], + }) as AsyncGenerator; + let out = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + out += chunk.choices[0].delta.content; + } + } + expect(out).toContain("2"); + }); + + it("textGeneration", async () => { + const res = await textGeneration({ + accessToken: env.HF_FEATHERLESS_KEY ?? "dummy", + model: "meta-llama/Llama-3.1-8B", + provider: "featherless-ai", + inputs: "Paris is a city of ", + parameters: { + temperature: 0, + top_p: 0.01, + max_tokens: 10, + }, + }); + expect(res).toMatchObject({ generated_text: "2.2 million people, and it is the" }); + }); + }, + TIMEOUT + ); + describe.concurrent( "Replicate", () => {