Skip to content

Commit f7608fb

Browse files
authored
[Inference] Make CI green on main (#1359)
Best reviewed commit per commit - Some missing code for HF Inference tests - Run the format and lint scripts
1 parent fbf3a89 commit f7608fb

File tree

6 files changed

+274
-73
lines changed

6 files changed

+274
-73
lines changed

packages/inference/src/lib/getInferenceProviderMapping.ts

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts";
55
import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference";
66
import { typedInclude } from "../utils/typedInclude";
77

8-
98
export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
109

1110
export type InferenceProviderMapping = Partial<
@@ -24,12 +23,12 @@ export interface InferenceProviderModelMapping {
2423
export async function getInferenceProviderMapping(
2524
params: {
2625
accessToken?: string;
27-
modelId: ModelId,
28-
provider: InferenceProvider,
29-
task: WidgetType
26+
modelId: ModelId;
27+
provider: InferenceProvider;
28+
task: WidgetType;
3029
},
3130
options: {
32-
fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>
31+
fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
3332
}
3433
): Promise<InferenceProviderModelMapping | null> {
3534
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
@@ -49,7 +48,8 @@ export async function getInferenceProviderMapping(
4948
if (resp.status === 404) {
5049
throw new Error(`Model ${params.modelId} does not exist`);
5150
}
52-
inferenceProviderMapping = await resp.json()
51+
inferenceProviderMapping = await resp
52+
.json()
5353
.then((json) => json.inferenceProviderMapping)
5454
.catch(() => null);
5555
}
@@ -60,7 +60,10 @@ export async function getInferenceProviderMapping(
6060

6161
const providerMapping = inferenceProviderMapping[params.provider];
6262
if (providerMapping) {
63-
const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task]
63+
const equivalentTasks =
64+
params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task)
65+
? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS
66+
: [params.task];
6467
if (!typedInclude(equivalentTasks, providerMapping.task)) {
6568
throw new Error(
6669
`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
@@ -76,7 +79,7 @@ export async function getInferenceProviderMapping(
7679
if (!treeResp.ok) {
7780
throw new Error(`Unable to fetch the model tree for ${params.modelId}.`);
7881
}
79-
const tree: Array<{ type: "file" | "directory"; path: string; }> = await treeResp.json();
82+
const tree: Array<{ type: "file" | "directory"; path: string }> = await treeResp.json();
8083
const adapterWeightsPath = tree.find(({ type, path }) => type === "file" && path.endsWith(".safetensors"))?.path;
8184
if (!adapterWeightsPath) {
8285
throw new Error(`No .safetensors file found in the model tree for ${params.modelId}.`);
@@ -85,9 +88,9 @@ export async function getInferenceProviderMapping(
8588
...providerMapping,
8689
hfModelId: params.modelId,
8790
adapterWeightsPath,
88-
}
91+
};
8992
}
9093
return { ...providerMapping, hfModelId: params.modelId };
9194
}
9295
return null;
93-
}
96+
}

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,13 @@ export async function makeRequestOptions(
4141

4242
if (args.endpointUrl) {
4343
// No need to have maybeModel, or to load default model for a task
44-
return makeRequestOptionsFromResolvedModel(maybeModel ?? args.endpointUrl, providerHelper, args, undefined, options);
44+
return makeRequestOptionsFromResolvedModel(
45+
maybeModel ?? args.endpointUrl,
46+
providerHelper,
47+
args,
48+
undefined,
49+
options
50+
);
4551
}
4652

4753
if (!maybeModel && !task) {
@@ -55,23 +61,26 @@ export async function makeRequestOptions(
5561
throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
5662
}
5763

58-
const inferenceProviderMapping = providerHelper.clientSideRoutingOnly ?
59-
{
60-
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
61-
providerId: removeProviderPrefix(maybeModel!, provider),
62-
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
63-
hfModelId: maybeModel!,
64-
status: "live",
65-
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
66-
task: task!
67-
} satisfies InferenceProviderModelMapping
68-
: await getInferenceProviderMapping({
69-
modelId: hfModel,
70-
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
71-
task: task!,
72-
provider,
73-
accessToken: args.accessToken,
74-
}, { fetch: options?.fetch });
64+
const inferenceProviderMapping = providerHelper.clientSideRoutingOnly
65+
? ({
66+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
67+
providerId: removeProviderPrefix(maybeModel!, provider),
68+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
69+
hfModelId: maybeModel!,
70+
status: "live",
71+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
72+
task: task!,
73+
} satisfies InferenceProviderModelMapping)
74+
: await getInferenceProviderMapping(
75+
{
76+
modelId: hfModel,
77+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
78+
task: task!,
79+
provider,
80+
accessToken: args.accessToken,
81+
},
82+
{ fetch: options?.fetch }
83+
);
7584
if (!inferenceProviderMapping) {
7685
throw new Error(`We have not been able to find inference provider information for model ${hfModel}.`);
7786
}

packages/inference/src/providers/consts.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ import { type ModelId } from "../types";
99
*
1010
* We also inject into this dictionary from tests.
1111
*/
12-
export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<InferenceProvider, Record<ModelId, InferenceProviderModelMapping>> = {
12+
export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
13+
InferenceProvider,
14+
Record<ModelId, InferenceProviderModelMapping>
15+
> = {
1316
/**
1417
* "HF model ID" => "Model ID on Inference Provider's side"
1518
*

packages/inference/src/providers/fal-ai.ts

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ abstract class FalAITask extends TaskProviderHelper {
7676
}
7777

7878
function buildLoraPath(modelId: ModelId, adapterWeightsPath: string): string {
79-
return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`
79+
return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
8080
}
8181

8282
export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHelper {
@@ -86,25 +86,31 @@ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHe
8686
...(params.args.parameters as Record<string, unknown>),
8787
sync_mode: true,
8888
prompt: params.args.inputs,
89-
...(params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath ? {
90-
loras: [{
91-
path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
92-
scale: 1
93-
}]
94-
} : undefined)
89+
...(params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath
90+
? {
91+
loras: [
92+
{
93+
path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
94+
scale: 1,
95+
},
96+
],
97+
}
98+
: undefined),
9599
};
96100

97101
if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
98-
payload.loras = [{
99-
path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
100-
scale: 1
101-
}]
102+
payload.loras = [
103+
{
104+
path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
105+
scale: 1,
106+
},
107+
];
102108
if (params.mapping.providerId === "fal-ai/lora") {
103109
payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
104110
}
105111
}
106112

107-
return payload
113+
return payload;
108114
}
109115

110116
override async getResponse(response: FalAITextToImageOutput, outputType?: "url" | "blob"): Promise<string | Blob> {
@@ -160,8 +166,9 @@ export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHe
160166
let status = response.status;
161167

162168
const parsedUrl = new URL(url);
163-
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
164-
}`;
169+
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
170+
parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
171+
}`;
165172

166173
// extracting the provider model id for status and result urls
167174
// from the response as it might be different from the mapped model in `url`
@@ -253,7 +260,8 @@ export class FalAITextToSpeechTask extends FalAITask {
253260
return await urlResponse.blob();
254261
} catch (error) {
255262
throw new InferenceOutputError(
256-
`Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)
263+
`Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${
264+
error instanceof Error ? error.message : String(error)
257265
}`
258266
);
259267
}

packages/inference/src/providers/hf-inference.ts

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ interface AudioToAudioOutput {
8787
label: string;
8888
}
8989

90-
9190
export const EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"] as const;
9291

9392
export class HFInferenceTask extends TaskProviderHelper {
@@ -217,7 +216,8 @@ export class HFInferenceAudioClassificationTask extends HFInferenceTask implemen
217216

218217
export class HFInferenceAutomaticSpeechRecognitionTask
219218
extends HFInferenceTask
220-
implements AutomaticSpeechRecognitionTaskHelper {
219+
implements AutomaticSpeechRecognitionTaskHelper
220+
{
221221
override async getResponse(response: AutomaticSpeechRecognitionOutput): Promise<AutomaticSpeechRecognitionOutput> {
222222
return response;
223223
}
@@ -250,7 +250,8 @@ export class HFInferenceAudioToAudioTask extends HFInferenceTask implements Audi
250250

251251
export class HFInferenceDocumentQuestionAnsweringTask
252252
extends HFInferenceTask
253-
implements DocumentQuestionAnsweringTaskHelper {
253+
implements DocumentQuestionAnsweringTaskHelper
254+
{
254255
override async getResponse(
255256
response: DocumentQuestionAnsweringOutput
256257
): Promise<DocumentQuestionAnsweringOutput[number]> {
@@ -352,7 +353,8 @@ export class HFInferenceObjectDetectionTask extends HFInferenceTask implements O
352353

353354
export class HFInferenceZeroShotImageClassificationTask
354355
extends HFInferenceTask
355-
implements ZeroShotImageClassificationTaskHelper {
356+
implements ZeroShotImageClassificationTaskHelper
357+
{
356358
override async getResponse(response: ZeroShotImageClassificationOutput): Promise<ZeroShotImageClassificationOutput> {
357359
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
358360
return response;
@@ -378,20 +380,20 @@ export class HFInferenceQuestionAnsweringTask extends HFInferenceTask implements
378380
if (
379381
Array.isArray(response)
380382
? response.every(
381-
(elem) =>
382-
typeof elem === "object" &&
383-
!!elem &&
384-
typeof elem.answer === "string" &&
385-
typeof elem.end === "number" &&
386-
typeof elem.score === "number" &&
387-
typeof elem.start === "number"
388-
)
383+
(elem) =>
384+
typeof elem === "object" &&
385+
!!elem &&
386+
typeof elem.answer === "string" &&
387+
typeof elem.end === "number" &&
388+
typeof elem.score === "number" &&
389+
typeof elem.start === "number"
390+
)
389391
: typeof response === "object" &&
390-
!!response &&
391-
typeof response.answer === "string" &&
392-
typeof response.end === "number" &&
393-
typeof response.score === "number" &&
394-
typeof response.start === "number"
392+
!!response &&
393+
typeof response.answer === "string" &&
394+
typeof response.end === "number" &&
395+
typeof response.score === "number" &&
396+
typeof response.start === "number"
395397
) {
396398
return Array.isArray(response) ? response[0] : response;
397399
}
@@ -536,7 +538,8 @@ export class HFInferenceTabularClassificationTask extends HFInferenceTask implem
536538

537539
export class HFInferenceVisualQuestionAnsweringTask
538540
extends HFInferenceTask
539-
implements VisualQuestionAnsweringTaskHelper {
541+
implements VisualQuestionAnsweringTaskHelper
542+
{
540543
override async getResponse(response: VisualQuestionAnsweringOutput): Promise<VisualQuestionAnsweringOutput[number]> {
541544
if (
542545
Array.isArray(response) &&

0 commit comments

Comments
 (0)