Skip to content

Commit fe0e5e6

Browse files
authored
[InferenceSnippets] move as snippet option + more accurate typing (#1360)
Related to #1349 (review) (requested changes after initial approval, sorry about that 😬) cc @kefranabg This PR: - moves `billTo` inside `opts` => doesn't break the method signature + feels more appropriate - adds some typing to `opts` => not perfect but better than nothing
1 parent 0455210 commit fe0e5e6

File tree

3 files changed

+14
-18
lines changed

3 files changed

+14
-18
lines changed

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import { templates } from "./templates.exported";
1414
import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping";
1515
import { getProviderHelper } from "../lib/getProviderHelper";
1616

17+
export type InferenceSnippetOptions = { streaming?: boolean; billTo?: string } & Record<string, unknown>;
18+
1719
const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
1820
const JS_CLIENTS = ["fetch", "huggingface.js", "openai"] as const;
1921
const SH_CLIENTS = ["curl"] as const;
@@ -120,8 +122,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
120122
accessToken: string,
121123
provider: InferenceProvider,
122124
inferenceProviderMapping?: InferenceProviderModelMapping,
123-
billTo?: string,
124-
opts?: Record<string, unknown>
125+
opts?: InferenceSnippetOptions
125126
): InferenceSnippet[] => {
126127
const providerModelId = inferenceProviderMapping?.providerId ?? model.id;
127128
/// Hacky: hard-code conversational templates here
@@ -155,7 +156,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
155156
inferenceProviderMapping,
156157
{
157158
task,
158-
billTo,
159+
billTo: opts?.billTo,
159160
}
160161
);
161162

@@ -194,7 +195,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
194195
model,
195196
provider,
196197
providerModelId: providerModelId ?? model.id,
197-
billTo,
198+
billTo: opts?.billTo,
198199
};
199200

200201
/// Iterate over clients => check if a snippet exists => generate
@@ -283,8 +284,7 @@ const snippets: Partial<
283284
accessToken: string,
284285
provider: InferenceProvider,
285286
inferenceProviderMapping?: InferenceProviderModelMapping,
286-
billTo?: string,
287-
opts?: Record<string, unknown>
287+
opts?: InferenceSnippetOptions
288288
) => InferenceSnippet[]
289289
>
290290
> = {
@@ -324,11 +324,10 @@ export function getInferenceSnippets(
324324
accessToken: string,
325325
provider: InferenceProvider,
326326
inferenceProviderMapping?: InferenceProviderModelMapping,
327-
billTo?: string,
328327
opts?: Record<string, unknown>
329328
): InferenceSnippet[] {
330329
return model.pipeline_tag && model.pipeline_tag in snippets
331-
? snippets[model.pipeline_tag]?.(model, accessToken, provider, inferenceProviderMapping, billTo, opts) ?? []
330+
? snippets[model.pipeline_tag]?.(model, accessToken, provider, inferenceProviderMapping, opts) ?? []
332331
: [];
333332
}
334333

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
export { getInferenceSnippets } from "./getInferenceSnippets.js";
1+
export { getInferenceSnippets, type InferenceSnippetOptions } from "./getInferenceSnippets.js";

packages/tasks-gen/scripts/generate-snippets-fixtures.ts

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ const TEST_CASES: {
3333
model: ModelDataMinimal;
3434
providers: SnippetInferenceProvider[];
3535
lora?: boolean;
36-
billTo?: string;
37-
opts?: Record<string, unknown>;
36+
opts?: snippets.InferenceSnippetOptions;
3837
}[] = [
3938
{
4039
testName: "automatic-speech-recognition",
@@ -237,8 +236,8 @@ const TEST_CASES: {
237236
tags: ["conversational"],
238237
inference: "",
239238
},
240-
billTo: "huggingface",
241239
providers: ["hf-inference"],
240+
opts: { billTo: "huggingface" },
242241
},
243242
] as const;
244243

@@ -266,7 +265,6 @@ function generateInferenceSnippet(
266265
provider: SnippetInferenceProvider,
267266
task: WidgetType,
268267
lora: boolean = false,
269-
billTo?: string,
270268
opts?: Record<string, unknown>
271269
): InferenceSnippet[] {
272270
const allSnippets = snippets.getInferenceSnippets(
@@ -285,7 +283,6 @@ function generateInferenceSnippet(
285283
}
286284
: {}),
287285
},
288-
billTo,
289286
opts
290287
);
291288
return allSnippets
@@ -341,12 +338,12 @@ if (import.meta.vitest) {
341338
const { describe, expect, it } = import.meta.vitest;
342339

343340
describe("inference API snippets", () => {
344-
TEST_CASES.forEach(({ testName, task, model, providers, lora, billTo, opts }) => {
341+
TEST_CASES.forEach(({ testName, task, model, providers, lora, opts }) => {
345342
describe(testName, () => {
346343
inferenceSnippetLanguages.forEach((language) => {
347344
providers.forEach((provider) => {
348345
it(language, async () => {
349-
const generatedSnippets = generateInferenceSnippet(model, language, provider, task, lora, billTo, opts);
346+
const generatedSnippets = generateInferenceSnippet(model, language, provider, task, lora, opts);
350347
const expectedSnippets = await getExpectedInferenceSnippet(testName, language, provider);
351348
expect(generatedSnippets).toEqual(expectedSnippets);
352349
});
@@ -362,11 +359,11 @@ if (import.meta.vitest) {
362359
await fs.rm(path.join(rootDirFinder(), "snippets-fixtures"), { recursive: true, force: true });
363360

364361
console.debug(" 🏭 Generating new fixtures...");
365-
TEST_CASES.forEach(({ testName, task, model, providers, lora, billTo, opts }) => {
362+
TEST_CASES.forEach(({ testName, task, model, providers, lora, opts }) => {
366363
console.debug(` ${testName} (${providers.join(", ")})`);
367364
inferenceSnippetLanguages.forEach(async (language) => {
368365
providers.forEach(async (provider) => {
369-
const generatedSnippets = generateInferenceSnippet(model, language, provider, task, lora, billTo, opts);
366+
const generatedSnippets = generateInferenceSnippet(model, language, provider, task, lora, opts);
370367
await saveExpectedInferenceSnippet(testName, language, provider, generatedSnippets);
371368
});
372369
});

0 commit comments

Comments
 (0)