Skip to content

[Inference] Implement a "1 class = 1 provider<>task pair" logic to isolate provider-specific code #1315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 32 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
fc0926c
refactor providers
hanouticelina Mar 26, 2025
4e6cf94
nit
hanouticelina Mar 26, 2025
5514cde
nit
hanouticelina Mar 26, 2025
84745e3
remove unnecessary check
hanouticelina Mar 26, 2025
9466000
fix linting
hanouticelina Mar 27, 2025
fa3f8f0
implement individual classes for sambanova, cohere and cerebras
hanouticelina Mar 27, 2025
a4b4682
add hf-inference helpers
hanouticelina Mar 27, 2025
4dd7e02
fix text-to-image
hanouticelina Mar 27, 2025
e3cb303
Merge branch 'main' of github.com:huggingface/huggingface.js into ref…
hanouticelina Mar 27, 2025
213e658
use conversational task
hanouticelina Mar 27, 2025
6c3823e
backward compatibility hf-inference tasks
hanouticelina Mar 27, 2025
8ccab73
fix tests
hanouticelina Mar 28, 2025
59d1457
Merge branch 'main' of github.com:huggingface/huggingface.js into ref…
hanouticelina Mar 28, 2025
c6252ae
fixes
hanouticelina Mar 28, 2025
e34f2a2
add text-to-audio
hanouticelina Mar 28, 2025
677afc0
improvements and lint
hanouticelina Mar 28, 2025
cc4b255
remove code and add missing tasks for replicate
hanouticelina Mar 28, 2025
b374452
nit
hanouticelina Mar 28, 2025
d0d0f73
Merge branch 'main' of github.com:huggingface/huggingface.js into ref…
hanouticelina Apr 1, 2025
d138032
regenerate fal-ai snippet
hanouticelina Apr 1, 2025
20a8864
nit
hanouticelina Apr 1, 2025
a9ddea1
no need to 'override'
hanouticelina Apr 1, 2025
5a8c576
fix
hanouticelina Apr 1, 2025
6da3a75
apply suggestions
hanouticelina Apr 1, 2025
f87081b
Merge branch 'main' into refactor-providers
hanouticelina Apr 1, 2025
2276e94
some fixes
hanouticelina Apr 2, 2025
6835182
Merge branch 'refactor-providers' of github.com:huggingface/huggingfa…
hanouticelina Apr 2, 2025
67afa27
fix code style
hanouticelina Apr 2, 2025
9d16e7c
group abstract methods
hanouticelina Apr 2, 2025
7782f2b
Better typing + make task functions generic (#1338)
hanouticelina Apr 4, 2025
7753620
Merge branch 'main' of github.com:huggingface/huggingface.js into ref…
hanouticelina Apr 4, 2025
21f9d34
nit
hanouticelina Apr 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import { BlackForestLabsTextToImageTask } from "../providers/black-forest-labs";
import { CerebrasConversationalTask } from "../providers/cerebras";
import { CohereConversationalTask } from "../providers/cohere";
import { FalAITask, FalAITextToImageTask, FalAITextToVideoTask } from "../providers/fal-ai";
import { FireworksConversationalTask } from "../providers/fireworks-ai";
import {
HyperbolicConversationalTask,
HyperbolicTextGenerationTask,
HyperbolicTextToImageTask,
} from "../providers/hyperbolic";
import { NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask } from "../providers/nebius";
import { NovitaConversationalTask, NovitaTextGenerationTask } from "../providers/novita";
import { OpenAIConversationalTask } from "../providers/openai";
import type { TaskProviderHelper } from "../providers/providerHelper";
import { ReplicateTextToImageTask, ReplicateTextToSpeechTask, ReplicateTextToVideoTask } from "../providers/replicate";
import { SambanovaConversationalTask } from "../providers/sambanova";
import { TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask } from "../providers/together";
import type { InferenceProvider, InferenceTask } from "../types";

/**
 * Registry of provider helpers: one helper instance per (provider, task) pair.
 * Each entry is a partial map, since most providers support only a subset of tasks.
 * NOTE: key order matters for error messages built from Object.keys(PROVIDERS).
 */
export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
	"black-forest-labs": {
		"text-to-image": new BlackForestLabsTextToImageTask(),
	},
	cerebras: {
		conversational: new CerebrasConversationalTask(),
	},
	cohere: {
		conversational: new CohereConversationalTask(),
	},
	"fal-ai": {
		// TODO: Add automatic-speech-recognition task helper
		// "automatic-speech-recognition": new FalAIAutomaticSpeechRecognitionTask(),
		"text-to-image": new FalAITextToImageTask(),
		// text-to-speech has no dedicated subclass yet; the generic FalAITask is parameterized by task name.
		"text-to-speech": new FalAITask("text-to-speech"),
		"text-to-video": new FalAITextToVideoTask(),
	},
	"fireworks-ai": {
		conversational: new FireworksConversationalTask(),
	},
	"hf-inference": {
		//TODO: Add the correct provider helpers for hf-inference
	},
	hyperbolic: {
		"text-to-image": new HyperbolicTextToImageTask(),
		conversational: new HyperbolicConversationalTask(),
		"text-generation": new HyperbolicTextGenerationTask(),
	},
	nebius: {
		"text-to-image": new NebiusTextToImageTask(),
		conversational: new NebiusConversationalTask(),
		"text-generation": new NebiusTextGenerationTask(),
	},
	novita: {
		"text-generation": new NovitaTextGenerationTask(),
		conversational: new NovitaConversationalTask(),
	},
	openai: {
		conversational: new OpenAIConversationalTask(),
	},
	replicate: {
		"text-to-image": new ReplicateTextToImageTask(),
		"text-to-speech": new ReplicateTextToSpeechTask(),
		"text-to-video": new ReplicateTextToVideoTask(),
	},
	sambanova: {
		conversational: new SambanovaConversationalTask(),
	},
	together: {
		"text-to-image": new TogetherTextToImageTask(),
		"text-generation": new TogetherTextGenerationTask(),
		conversational: new TogetherConversationalTask(),
	},
};

/**
 * Resolve the helper instance for a given provider/task pair.
 *
 * @param provider - inference provider name (must be a key of PROVIDERS)
 * @param task - task name, e.g. "text-to-image"; required
 * @returns the TaskProviderHelper registered for that pair
 * @throws Error when the task is missing, the provider is unknown,
 *         or the provider does not support the requested task
 */
export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper {
	if (!task) {
		throw new Error("you need to provide a task name, e.g. 'text-to-image'");
	}
	const isKnownProvider = provider in PROVIDERS;
	if (!isKnownProvider) {
		throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
	}
	const helpersByTask = PROVIDERS[provider];
	const taskIsSupported = !!helpersByTask && task in helpersByTask;
	if (!taskIsSupported) {
		throw new Error(
			`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(helpersByTask ?? {})}`
		);
	}
	return helpersByTask[task] as TaskProviderHelper;
}
68 changes: 15 additions & 53 deletions packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,9 @@
import { name as packageName, version as packageVersion } from "../../package.json";
import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
import { BLACK_FOREST_LABS_CONFIG } from "../providers/black-forest-labs";
import { CEREBRAS_CONFIG } from "../providers/cerebras";
import { COHERE_CONFIG } from "../providers/cohere";
import { FAL_AI_CONFIG } from "../providers/fal-ai";
import { FIREWORKS_AI_CONFIG } from "../providers/fireworks-ai";
import { HF_INFERENCE_CONFIG } from "../providers/hf-inference";
import { HYPERBOLIC_CONFIG } from "../providers/hyperbolic";
import { NEBIUS_CONFIG } from "../providers/nebius";
import { NOVITA_CONFIG } from "../providers/novita";
import { REPLICATE_CONFIG } from "../providers/replicate";
import { SAMBANOVA_CONFIG } from "../providers/sambanova";
import { TOGETHER_CONFIG } from "../providers/together";
import { OPENAI_CONFIG } from "../providers/openai";
import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
import { isUrl } from "./isUrl";
import { version as packageVersion, name as packageName } from "../../package.json";
import type { InferenceTask, Options, RequestArgs } from "../types";
import { getProviderHelper } from "./getProviderHelper";
import { getProviderModelId } from "./getProviderModelId";
import { isUrl } from "./isUrl";

const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;

Expand All @@ -25,25 +13,6 @@ const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
*/
let tasks: Record<string, { models: { id: string }[] }> | null = null;

/**
* Config to define how to serialize requests for each provider
*/
const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
"black-forest-labs": BLACK_FOREST_LABS_CONFIG,
cerebras: CEREBRAS_CONFIG,
cohere: COHERE_CONFIG,
"fal-ai": FAL_AI_CONFIG,
"fireworks-ai": FIREWORKS_AI_CONFIG,
"hf-inference": HF_INFERENCE_CONFIG,
hyperbolic: HYPERBOLIC_CONFIG,
openai: OPENAI_CONFIG,
nebius: NEBIUS_CONFIG,
novita: NOVITA_CONFIG,
replicate: REPLICATE_CONFIG,
sambanova: SAMBANOVA_CONFIG,
together: TOGETHER_CONFIG,
};

/**
* Helper that prepares request arguments.
* This async version handle the model ID resolution step.
Expand All @@ -61,9 +30,8 @@ export async function makeRequestOptions(
): Promise<{ url: string; info: RequestInit }> {
const { provider: maybeProvider, model: maybeModel } = args;
const provider = maybeProvider ?? "hf-inference";
const providerConfig = providerConfigs[provider];
const { task, chatCompletion } = options ?? {};

const providerHelper = getProviderHelper(provider, task);
// Validate inputs
if (args.endpointUrl && provider !== "hf-inference") {
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
Expand All @@ -74,16 +42,16 @@ export async function makeRequestOptions(
if (!maybeModel && !task) {
throw new Error("No model provided, and no task has been specified.");
}
if (!providerConfig) {
throw new Error(`No provider config found for provider ${provider}`);
if (!providerHelper) {
throw new Error(`No provider helper found for provider ${provider}`);
}
if (providerConfig.clientSideRoutingOnly && !maybeModel) {
if (providerHelper.clientSideRoutingOnly && !maybeModel) {
throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
}

// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const hfModel = maybeModel ?? (await loadDefaultModel(task!));
const resolvedModel = providerConfig.clientSideRoutingOnly
const resolvedModel = providerHelper.clientSideRoutingOnly
? // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
removeProviderPrefix(maybeModel!, provider)
: await getProviderModelId({ model: hfModel, provider }, args, {
Expand Down Expand Up @@ -115,12 +83,10 @@ export function makeRequestOptionsFromResolvedModel(
void model;

const provider = maybeProvider ?? "hf-inference";
const providerConfig = providerConfigs[provider];

const { includeCredentials, task, chatCompletion, signal } = options ?? {};

const providerHelper = getProviderHelper(provider, task);
const authMethod = (() => {
if (providerConfig.clientSideRoutingOnly) {
if (providerHelper.clientSideRoutingOnly) {
// Closed-source providers require an accessToken (cannot be routed).
if (accessToken && accessToken.startsWith("hf_")) {
throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
Expand All @@ -142,24 +108,22 @@ export function makeRequestOptionsFromResolvedModel(
? chatCompletion
? endpointUrl + `/v1/chat/completions`
: endpointUrl
: providerConfig.makeUrl({
: providerHelper.makeUrl({
authMethod,
baseUrl:
authMethod !== "provider-key"
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider)
: providerConfig.makeBaseUrl(task),
: providerHelper.makeBaseUrl(),
model: resolvedModel,
chatCompletion,
task,
});

// Make headers
const binary = "data" in args && !!args.data;
const headers = providerConfig.makeHeaders({
const headers = providerHelper.prepareHeaders({
accessToken,
authMethod,
});

// Add content-type to headers
if (!binary) {
headers["Content-Type"] = "application/json";
Expand All @@ -177,14 +141,13 @@ export function makeRequestOptionsFromResolvedModel(
const body = binary
? args.data
: JSON.stringify(
providerConfig.makeBody({
providerHelper.makeBody({
args: remainingArgs as Record<string, unknown>,
model: resolvedModel,
task,
chatCompletion,
})
);

/**
* For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
*/
Expand All @@ -202,7 +165,6 @@ export function makeRequestOptionsFromResolvedModel(
...(credentials ? { credentials } : undefined),
signal,
};

return { url, info };
}

Expand Down
83 changes: 61 additions & 22 deletions packages/inference/src/providers/black-forest-labs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,72 @@
*
* Thanks!
*/
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
import { InferenceOutputError } from "../lib/InferenceOutputError";
import type { BodyParams, HeaderParams, UrlParams } from "../types";
import { delay } from "../utils/delay";
import { omit } from "../utils/omit";
import { TaskProviderHelper } from "./providerHelper";

const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";

const makeBaseUrl = (): string => {
return BLACK_FOREST_LABS_AI_API_BASE_URL;
};
/**
 * text-to-image helper for the Black Forest Labs provider.
 * Provider name, base URL and task are fixed via the TaskProviderHelper constructor.
 */
export class BlackForestLabsTextToImageTask extends TaskProviderHelper {
	constructor() {
		super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL, "text-to-image");
	}

const makeBody = (params: BodyParams): Record<string, unknown> => {
return params.args;
};
override makeBody(params: BodyParams): Record<string, unknown> {
return {
...omit(params.args, ["inputs", "parameters"]),
...(params.args.parameters as Record<string, unknown>),
prompt: params.args.inputs,
};
}

const makeHeaders = (params: HeaderParams): Record<string, string> => {
if (params.authMethod === "provider-key") {
return { "X-Key": `${params.accessToken}` };
} else {
return { Authorization: `Bearer ${params.accessToken}` };
override prepareHeaders(params: HeaderParams): Record<string, string> {
if (params.authMethod !== "provider-key") {
return { Authorization: `Bearer ${params.accessToken}` };
} else {
return { "X-Key": `${params.accessToken}` };
}
}
};

const makeUrl = (params: UrlParams): string => {
return `${params.baseUrl}/v1/${params.model}`;
};
override makeRoute(params: UrlParams): string {
if (!params) {
throw new Error("Params are required");
}
return `/v1/${params.model}`;
}

export const BLACK_FOREST_LABS_CONFIG: ProviderConfig = {
makeBaseUrl,
makeBody,
makeHeaders,
makeUrl,
};
	/**
	 * Polls the Black Forest Labs API until the generated image is ready.
	 *
	 * The initial response URL is polled up to 5 times, 1 second apart, with an
	 * `attempt` query parameter added for each try. A payload with
	 * `status === "Ready"` and a string `result.sample` (a URL) is considered done.
	 *
	 * @param res - the initial HTTP response whose URL is the polling endpoint
	 * @param outputType - "url" returns the sample URL as-is; otherwise the sample is downloaded and returned as a Blob
	 * @throws InferenceOutputError on a failed poll request or when no result is ready after all attempts
	 */
	async getResponse(res: Response, outputType?: "url" | "blob"): Promise<string | Blob> {
		const urlObj = new URL(res.url);
		for (let step = 0; step < 5; step++) {
			// Wait before each poll (including the first one).
			await delay(1000);
			console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
			urlObj.searchParams.set("attempt", step.toString(10));
			const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
			if (!resp.ok) {
				throw new InferenceOutputError("Failed to fetch result from black forest labs API");
			}
			const payload = await resp.json();
			// Manually narrow the untyped JSON: require status === "Ready" and a string result.sample.
			if (
				typeof payload === "object" &&
				payload &&
				"status" in payload &&
				typeof payload.status === "string" &&
				payload.status === "Ready" &&
				"result" in payload &&
				typeof payload.result === "object" &&
				payload.result &&
				"sample" in payload.result &&
				typeof payload.result.sample === "string"
			) {
				if (outputType === "url") {
					return payload.result.sample;
				}
				// Default output: download the sample and hand back its bytes as a Blob.
				const image = await fetch(payload.result.sample);
				return await image.blob();
			}
		}
		// All polling attempts exhausted without a ready result.
		throw new InferenceOutputError("Failed to fetch result from black forest labs API");
	}
}
33 changes: 6 additions & 27 deletions packages/inference/src/providers/cerebras.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,11 @@
*
* Thanks!
*/
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";

const CEREBRAS_API_BASE_URL = "https://api.cerebras.ai";
import { BaseConversationalTask } from "./providerHelper";

const makeBaseUrl = (): string => {
return CEREBRAS_API_BASE_URL;
};

const makeBody = (params: BodyParams): Record<string, unknown> => {
return {
...params.args,
model: params.model,
};
};

const makeHeaders = (params: HeaderParams): Record<string, string> => {
return { Authorization: `Bearer ${params.accessToken}` };
};

const makeUrl = (params: UrlParams): string => {
return `${params.baseUrl}/v1/chat/completions`;
};

export const CEREBRAS_CONFIG: ProviderConfig = {
makeBaseUrl,
makeBody,
makeHeaders,
makeUrl,
};
/**
 * Conversational (chat-completion) helper for the Cerebras provider.
 * All request/response handling is inherited from BaseConversationalTask;
 * only the provider name and API base URL are specific to Cerebras.
 */
export class CerebrasConversationalTask extends BaseConversationalTask {
	constructor() {
		super("cerebras", "https://api.cerebras.ai");
	}
}
Loading
Loading