
Commit e4be4e4

Merge remote-tracking branch 'origin/main' into provider-featherless-ai
2 parents: 903cace + fe0e5e6

81 files changed: +1270 −187 lines

README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -97,7 +97,7 @@ You can run our packages with vanilla JS, without any bundler, by using a CDN or
 
 ```html
 <script type="module">
-  import { InferenceClient } from 'https://cdn.jsdelivr.net/npm/@huggingface/inference@3.7.1/+esm';
+  import { InferenceClient } from 'https://cdn.jsdelivr.net/npm/@huggingface/inference@3.8.0/+esm';
   import { createRepo, commit, deleteRepo, listFiles } from "https://cdn.jsdelivr.net/npm/@huggingface/[email protected]/+esm";
 </script>
 ```
````

packages/inference/package.json

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 {
 	"name": "@huggingface/inference",
-	"version": "3.7.1",
+	"version": "3.8.0",
 	"packageManager": "[email protected]",
 	"license": "MIT",
 	"author": "Hugging Face and Tim Mikeladze <[email protected]>",
```
packages/inference/src/lib/getInferenceProviderMapping.ts (new file)

Lines changed: 96 additions & 0 deletions

```ts
import type { WidgetType } from "@huggingface/tasks";
import type { InferenceProvider, ModelId } from "../types";
import { HF_HUB_URL } from "../config";
import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts";
import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference";
import { typedInclude } from "../utils/typedInclude";

export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();

export type InferenceProviderMapping = Partial<
	Record<InferenceProvider, Omit<InferenceProviderModelMapping, "hfModelId" | "adapterWeightsPath">>
>;

export interface InferenceProviderModelMapping {
	adapter?: string;
	adapterWeightsPath?: string;
	hfModelId: ModelId;
	providerId: string;
	status: "live" | "staging";
	task: WidgetType;
}

export async function getInferenceProviderMapping(
	params: {
		accessToken?: string;
		modelId: ModelId;
		provider: InferenceProvider;
		task: WidgetType;
	},
	options: {
		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
	}
): Promise<InferenceProviderModelMapping | null> {
	if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
		return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
	}
	let inferenceProviderMapping: InferenceProviderMapping | null;
	if (inferenceProviderMappingCache.has(params.modelId)) {
		// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
		inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId)!;
	} else {
		const resp = await (options?.fetch ?? fetch)(
			`${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
			{
				headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {},
			}
		);
		if (resp.status === 404) {
			throw new Error(`Model ${params.modelId} does not exist`);
		}
		inferenceProviderMapping = await resp
			.json()
			.then((json) => json.inferenceProviderMapping)
			.catch(() => null);
	}

	if (!inferenceProviderMapping) {
		throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
	}

	const providerMapping = inferenceProviderMapping[params.provider];
	if (providerMapping) {
		const equivalentTasks =
			params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task)
				? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS
				: [params.task];
		if (!typedInclude(equivalentTasks, providerMapping.task)) {
			throw new Error(
				`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
			);
		}
		if (providerMapping.status === "staging") {
			console.warn(
				`Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
			);
		}
		if (providerMapping.adapter === "lora") {
			const treeResp = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${params.modelId}/tree/main`);
			if (!treeResp.ok) {
				throw new Error(`Unable to fetch the model tree for ${params.modelId}.`);
			}
			const tree: Array<{ type: "file" | "directory"; path: string }> = await treeResp.json();
			const adapterWeightsPath = tree.find(({ type, path }) => type === "file" && path.endsWith(".safetensors"))?.path;
			if (!adapterWeightsPath) {
				throw new Error(`No .safetensors file found in the model tree for ${params.modelId}.`);
			}
			return {
				...providerMapping,
				hfModelId: params.modelId,
				adapterWeightsPath,
			};
		}
		return { ...providerMapping, hfModelId: params.modelId };
	}
	return null;
}
```
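For orientation, a minimal sketch of how this new helper might be called. The model ID, token, and provider below are invented placeholders, not values from this commit:

```ts
import { getInferenceProviderMapping } from "./getInferenceProviderMapping";

// Placeholder values for illustration only.
const mapping = await getInferenceProviderMapping(
	{
		accessToken: "hf_xxx", // only tokens prefixed with "hf_" are forwarded to the Hub
		modelId: "my-org/my-lora-model",
		provider: "fal-ai",
		task: "text-to-image",
	},
	{} // falls back to the global fetch when no custom fetch is supplied
);

if (mapping) {
	// For LoRA adapters, adapterWeightsPath points at the .safetensors file in the repo.
	console.log(mapping.providerId, mapping.adapterWeightsPath);
}
```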

packages/inference/src/lib/getProviderModelId.ts

Lines changed: 0 additions & 74 deletions
This file was deleted.

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 41 additions & 10 deletions
```diff
@@ -1,8 +1,9 @@
 import { name as packageName, version as packageVersion } from "../../package.json";
 import { HF_HEADER_X_BILL_TO, HF_HUB_URL } from "../config";
 import type { InferenceTask, Options, RequestArgs } from "../types";
+import type { InferenceProviderModelMapping } from "./getInferenceProviderMapping";
+import { getInferenceProviderMapping } from "./getInferenceProviderMapping";
 import type { getProviderHelper } from "./getProviderHelper";
-import { getProviderModelId } from "./getProviderModelId";
 import { isUrl } from "./isUrl";
 
 /**
@@ -40,7 +41,13 @@ export async function makeRequestOptions(
 
 	if (args.endpointUrl) {
 		// No need to have maybeModel, or to load default model for a task
-		return makeRequestOptionsFromResolvedModel(maybeModel ?? args.endpointUrl, providerHelper, args, options);
+		return makeRequestOptionsFromResolvedModel(
+			maybeModel ?? args.endpointUrl,
+			providerHelper,
+			args,
+			undefined,
+			options
+		);
 	}
 
 	if (!maybeModel && !task) {
@@ -54,16 +61,38 @@ export async function makeRequestOptions(
 		throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
 	}
 
-	const resolvedModel = providerHelper.clientSideRoutingOnly
-		? // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-		  removeProviderPrefix(maybeModel!, provider)
-		: await getProviderModelId({ model: hfModel, provider }, args, {
-				task,
-				fetch: options?.fetch,
-		  });
+	const inferenceProviderMapping = providerHelper.clientSideRoutingOnly
+		? ({
+				// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+				providerId: removeProviderPrefix(maybeModel!, provider),
+				// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+				hfModelId: maybeModel!,
+				status: "live",
+				// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+				task: task!,
+		  } satisfies InferenceProviderModelMapping)
+		: await getInferenceProviderMapping(
+				{
+					modelId: hfModel,
+					// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+					task: task!,
+					provider,
+					accessToken: args.accessToken,
+				},
+				{ fetch: options?.fetch }
+		  );
+	if (!inferenceProviderMapping) {
+		throw new Error(`We have not been able to find inference provider information for model ${hfModel}.`);
+	}
 
 	// Use the sync version with the resolved model
-	return makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, options);
+	return makeRequestOptionsFromResolvedModel(
+		inferenceProviderMapping.providerId,
+		providerHelper,
+		args,
+		inferenceProviderMapping,
+		options
+	);
 }
 
 /**
@@ -77,6 +106,7 @@ export function makeRequestOptionsFromResolvedModel(
 		data?: Blob | ArrayBuffer;
 		stream?: boolean;
 	},
+	mapping: InferenceProviderModelMapping | undefined,
 	options?: Options & {
 		task?: InferenceTask;
 	}
@@ -138,6 +168,7 @@ export function makeRequestOptionsFromResolvedModel(
 		args: remainingArgs as Record<string, unknown>,
 		model: resolvedModel,
 		task,
+		mapping,
 	});
 	/**
 	 * For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
```

packages/inference/src/providers/consts.ts

Lines changed: 5 additions & 2 deletions
```diff
@@ -1,15 +1,18 @@
+import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping";
 import type { InferenceProvider } from "../types";
 import { type ModelId } from "../types";
 
-type ProviderId = string;
 /**
  * If you want to try to run inference for a new model locally before it's registered on huggingface.co
  * for a given Inference Provider,
  * you can add it to the following dictionary, for dev purposes.
  *
  * We also inject into this dictionary from tests.
  */
-export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelId, ProviderId>> = {
+export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
+	InferenceProvider,
+	Record<ModelId, InferenceProviderModelMapping>
+> = {
 	/**
 	 * "HF model ID" => "Model ID on Inference Provider's side"
 	 *
```
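As a concrete illustration of the dev-time escape hatch described in the comment above, an override now registers the full mapping object rather than a bare provider ID. Every identifier in this sketch is invented:

```ts
// Hypothetical dev-time entry; "my-org/new-model" and "new-model-v1" are invented IDs.
HARDCODED_MODEL_INFERENCE_MAPPING["fal-ai"]["my-org/new-model"] = {
	providerId: "new-model-v1",
	hfModelId: "my-org/new-model",
	status: "staging",
	task: "text-to-image",
};
```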

packages/inference/src/providers/fal-ai.ts

Lines changed: 31 additions & 2 deletions
```diff
@@ -17,7 +17,7 @@
 import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
 import { InferenceOutputError } from "../lib/InferenceOutputError";
 import { isUrl } from "../lib/isUrl";
-import type { BodyParams, HeaderParams, UrlParams } from "../types";
+import type { BodyParams, HeaderParams, ModelId, UrlParams } from "../types";
 import { delay } from "../utils/delay";
 import { omit } from "../utils/omit";
 import {
@@ -26,6 +26,7 @@ import {
 	type TextToImageTaskHelper,
 	type TextToVideoTaskHelper,
 } from "./providerHelper";
+import { HF_HUB_URL } from "../config";
 
 export interface FalAiQueueOutput {
 	request_id: string;
@@ -74,14 +75,42 @@ abstract class FalAITask extends TaskProviderHelper {
 	}
 }
 
+function buildLoraPath(modelId: ModelId, adapterWeightsPath: string): string {
+	return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
+}
+
 export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHelper {
 	override preparePayload(params: BodyParams): Record<string, unknown> {
-		return {
+		const payload: Record<string, unknown> = {
 			...omit(params.args, ["inputs", "parameters"]),
 			...(params.args.parameters as Record<string, unknown>),
 			sync_mode: true,
 			prompt: params.args.inputs,
+			...(params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath
+				? {
+						loras: [
+							{
+								path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
+								scale: 1,
+							},
+						],
+				  }
+				: undefined),
 		};
+
+		if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
+			payload.loras = [
+				{
+					path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
+					scale: 1,
+				},
+			];
+			if (params.mapping.providerId === "fal-ai/lora") {
+				payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
+			}
+		}
+
+		return payload;
 	}
 
 	override async getResponse(response: FalAITextToImageOutput, outputType?: "url" | "blob"): Promise<string | Blob> {
```
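The net effect for a LoRA model is that preparePayload attaches the Hub-hosted adapter weights to the fal.ai request. A rough sketch of the resulting body, with invented IDs and a prompt, assuming HF_HUB_URL resolves to https://huggingface.co:

```ts
// Approximate request body for a hypothetical LoRA text-to-image call;
// all IDs and the prompt are invented for illustration.
const body = {
	sync_mode: true,
	prompt: "an astronaut riding a horse",
	loras: [
		{
			// buildLoraPath("my-org/my-lora", "adapter.safetensors")
			path: "https://huggingface.co/my-org/my-lora/resolve/main/adapter.safetensors",
			scale: 1,
		},
	],
};
```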

packages/inference/src/providers/hf-inference.ts

Lines changed: 2 additions & 0 deletions
```diff
@@ -87,6 +87,8 @@ interface AudioToAudioOutput {
 	label: string;
 }
 
+export const EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"] as const;
+
 export class HFInferenceTask extends TaskProviderHelper {
 	constructor() {
 		super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
```
