Skip to content

Commit 985e5fd

Browse files
committed
[Inference Snippet] Add a directRequest option (false by default)
1 parent 1131b56 commit 985e5fd

File tree

28 files changed

+203
-61
lines changed

28 files changed

+203
-61
lines changed

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions.j
1414
import type { InferenceProviderOrPolicy, InferenceTask, RequestArgs } from "../types.js";
1515
import { templates } from "./templates.exported.js";
1616

17-
export type InferenceSnippetOptions = { streaming?: boolean; billTo?: string; accessToken?: string } & Record<
18-
string,
19-
unknown
20-
>;
17+
export type InferenceSnippetOptions = {
18+
streaming?: boolean;
19+
billTo?: string;
20+
accessToken?: string;
21+
directRequest?: boolean;
22+
} & Record<string, unknown>;
2123

2224
const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
2325
const JS_CLIENTS = ["fetch", "huggingface.js", "openai"] as const;
@@ -124,7 +126,10 @@ const HF_JS_METHODS: Partial<Record<WidgetType, string>> = {
124126
translation: "translation",
125127
};
126128

127-
const ACCESS_TOKEN_PLACEHOLDER = "<ACCESS_TOKEN>"; // Placeholder to replace with env variable in snippets
129+
// Placeholders to replace with env variable in snippets
130+
// little hack to support both direct requests and routing => routed requests should start with "hf_"
131+
const ACCESS_TOKEN_ROUTING_PLACEHOLDER = "hf_token_placeholder";
132+
const ACCESS_TOKEN_DIRECT_REQUEST_PLACEHOLDER = "not_hf_token_placeholder";
128133

129134
// Snippet generators
130135
const snippetGenerator = (templateName: string, inputPreparationFn?: InputPreparationFn) => {
@@ -153,7 +158,11 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
153158
console.error(`Failed to get provider helper for ${provider} (${task})`, e);
154159
return [];
155160
}
156-
const accessTokenOrPlaceholder = opts?.accessToken ?? ACCESS_TOKEN_PLACEHOLDER;
161+
162+
const placeholder = opts?.directRequest
163+
? ACCESS_TOKEN_DIRECT_REQUEST_PLACEHOLDER
164+
: ACCESS_TOKEN_ROUTING_PLACEHOLDER;
165+
const accessTokenOrPlaceholder = opts?.accessToken ?? placeholder;
157166

158167
/// Prepare inputs + make request
159168
const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
@@ -255,8 +264,8 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
255264
}
256265

257266
/// Replace access token placeholder
258-
if (snippet.includes(ACCESS_TOKEN_PLACEHOLDER)) {
259-
snippet = replaceAccessTokenPlaceholder(snippet, language, provider);
267+
if (snippet.includes(placeholder)) {
268+
snippet = replaceAccessTokenPlaceholder(opts?.directRequest, placeholder, snippet, language, provider);
260269
}
261270

262271
/// Snippet is ready!
@@ -431,6 +440,8 @@ function removeSuffix(str: string, suffix: string) {
431440
}
432441

433442
function replaceAccessTokenPlaceholder(
443+
directRequest: boolean | undefined,
444+
placeholder: string,
434445
snippet: string,
435446
language: InferenceSnippetLanguage,
436447
provider: InferenceProviderOrPolicy
@@ -439,46 +450,57 @@ function replaceAccessTokenPlaceholder(
439450
// Once snippets are rendered, we replace the placeholder with code to fetch the access token from an environment variable.
440451

441452
// Determine if HF_TOKEN or specific provider token should be used
442-
const accessTokenEnvVar =
443-
!snippet.includes("https://") || // no URL provided => using a client => use $HF_TOKEN
444-
snippet.includes("https://router.huggingface.co") || // explicit routed request => use $HF_TOKEN
445-
provider == "hf-inference" // hf-inference provider => use $HF_TOKEN
446-
? "HF_TOKEN"
447-
: provider.toUpperCase().replace("-", "_") + "_API_KEY"; // e.g. "REPLICATE_API_KEY"
453+
const useHfToken =
454+
provider == "hf-inference" || // hf-inference provider => use $HF_TOKEN
455+
(!directRequest && // if explicit directRequest => use provider-specific token
456+
(!snippet.includes("https://") || // no URL provided => using a client => use $HF_TOKEN
457+
snippet.includes("https://router.huggingface.co"))); // explicit routed request => use $HF_TOKEN
458+
459+
const accessTokenEnvVar = useHfToken
460+
? "HF_TOKEN" // e.g. routed request or hf-inference
461+
: provider.toUpperCase().replace("-", "_") + "_API_KEY"; // e.g. "REPLICATE_API_KEY"
448462

449463
// Replace the placeholder with the env variable
450464
if (language === "sh") {
451465
snippet = snippet.replace(
452-
`'Authorization: Bearer ${ACCESS_TOKEN_PLACEHOLDER}'`,
466+
`'Authorization: Bearer ${placeholder}'`,
453467
`"Authorization: Bearer $${accessTokenEnvVar}"` // e.g. "Authorization: Bearer $HF_TOKEN"
454468
);
455469
} else if (language === "python") {
456470
snippet = "import os\n" + snippet;
457471
snippet = snippet.replace(
458-
`"${ACCESS_TOKEN_PLACEHOLDER}"`,
472+
`"${placeholder}"`,
459473
`os.environ["${accessTokenEnvVar}"]` // e.g. os.environ["HF_TOKEN")
460474
);
461475
snippet = snippet.replace(
462-
`"Bearer ${ACCESS_TOKEN_PLACEHOLDER}"`,
476+
`"Bearer ${placeholder}"`,
463477
`f"Bearer {os.environ['${accessTokenEnvVar}']}"` // e.g. f"Bearer {os.environ['HF_TOKEN']}"
464478
);
465479
snippet = snippet.replace(
466-
`"Key ${ACCESS_TOKEN_PLACEHOLDER}"`,
480+
`"Key ${placeholder}"`,
467481
`f"Key {os.environ['${accessTokenEnvVar}']}"` // e.g. f"Key {os.environ['FAL_AI_API_KEY']}"
468482
);
483+
snippet = snippet.replace(
484+
`"X-Key ${placeholder}"`,
485+
`f"X-Key {os.environ['${accessTokenEnvVar}']}"` // e.g. f"X-Key {os.environ['BLACK_FOREST_LABS_API_KEY']}"
486+
);
469487
} else if (language === "js") {
470488
snippet = snippet.replace(
471-
`"${ACCESS_TOKEN_PLACEHOLDER}"`,
489+
`"${placeholder}"`,
472490
`process.env.${accessTokenEnvVar}` // e.g. process.env.HF_TOKEN
473491
);
474492
snippet = snippet.replace(
475-
`Authorization: "Bearer ${ACCESS_TOKEN_PLACEHOLDER}",`,
493+
`Authorization: "Bearer ${placeholder}",`,
476494
`Authorization: \`Bearer $\{process.env.${accessTokenEnvVar}}\`,` // e.g. Authorization: `Bearer ${process.env.HF_TOKEN}`,
477495
);
478496
snippet = snippet.replace(
479-
`Authorization: "Key ${ACCESS_TOKEN_PLACEHOLDER}",`,
497+
`Authorization: "Key ${placeholder}",`,
480498
`Authorization: \`Key $\{process.env.${accessTokenEnvVar}}\`,` // e.g. Authorization: `Key ${process.env.FAL_AI_API_KEY}`,
481499
);
500+
snippet = snippet.replace(
501+
`Authorization: "X-Key ${placeholder}",`,
502+
`Authorization: \`X-Key $\{process.env.${accessTokenEnvVar}}\`,` // e.g. Authorization: `X-Key ${process.env.BLACK_FOREST_LABS_AI_API_KEY}`,
503+
);
482504
}
483505
return snippet;
484506
}

packages/tasks-gen/scripts/generate-snippets-fixtures.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,18 @@ const TEST_CASES: {
252252
providers: ["hf-inference"],
253253
opts: { accessToken: "hf_xxx" },
254254
},
255+
{
256+
testName: "explicit-direct-request",
257+
task: "conversational",
258+
model: {
259+
id: "meta-llama/Llama-3.1-8B-Instruct",
260+
pipeline_tag: "text-generation",
261+
tags: ["conversational"],
262+
inference: "",
263+
},
264+
providers: ["together"],
265+
opts: { directRequest: true },
266+
},
255267
{
256268
testName: "text-to-speech",
257269
task: "text-to-speech",

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/js/openai/0.together.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import { OpenAI } from "openai";
22

33
const client = new OpenAI({
4-
baseURL: "https://api.together.xyz/v1",
5-
apiKey: process.env.TOGETHER_API_KEY,
4+
baseURL: "https://router.huggingface.co/together/v1",
5+
apiKey: process.env.HF_TOKEN,
66
});
77

88
const chatCompletion = await client.chat.completions.create({

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/python/openai/0.together.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from openai import OpenAI
33

44
client = OpenAI(
5-
base_url="https://api.together.xyz/v1",
6-
api_key=os.environ["TOGETHER_API_KEY"],
5+
base_url="https://router.huggingface.co/together/v1",
6+
api_key=os.environ["HF_TOKEN"],
77
)
88

99
completion = client.chat.completions.create(

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/python/requests/0.together.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import os
22
import requests
33

4-
API_URL = "https://api.together.xyz/v1/chat/completions"
4+
API_URL = "https://router.huggingface.co/together/v1/chat/completions"
55
headers = {
6-
"Authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}",
6+
"Authorization": f"Bearer {os.environ['HF_TOKEN']}",
77
}
88

99
def query(payload):

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/sh/curl/0.together.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
curl https://api.together.xyz/v1/chat/completions \
2-
-H "Authorization: Bearer $TOGETHER_API_KEY" \
1+
curl https://router.huggingface.co/together/v1/chat/completions \
2+
-H "Authorization: Bearer $HF_TOKEN" \
33
-H 'Content-Type: application/json' \
44
-d '{
55
"messages": [

packages/tasks-gen/snippets-fixtures/conversational-llm-stream/js/openai/0.together.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import { OpenAI } from "openai";
22

33
const client = new OpenAI({
4-
baseURL: "https://api.together.xyz/v1",
5-
apiKey: process.env.TOGETHER_API_KEY,
4+
baseURL: "https://router.huggingface.co/together/v1",
5+
apiKey: process.env.HF_TOKEN,
66
});
77

88
const stream = await client.chat.completions.create({

packages/tasks-gen/snippets-fixtures/conversational-llm-stream/python/openai/0.together.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from openai import OpenAI
33

44
client = OpenAI(
5-
base_url="https://api.together.xyz/v1",
6-
api_key=os.environ["TOGETHER_API_KEY"],
5+
base_url="https://router.huggingface.co/together/v1",
6+
api_key=os.environ["HF_TOKEN"],
77
)
88

99
stream = client.chat.completions.create(

packages/tasks-gen/snippets-fixtures/conversational-llm-stream/python/requests/0.together.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import json
33
import requests
44

5-
API_URL = "https://api.together.xyz/v1/chat/completions"
5+
API_URL = "https://router.huggingface.co/together/v1/chat/completions"
66
headers = {
7-
"Authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}",
7+
"Authorization": f"Bearer {os.environ['HF_TOKEN']}",
88
}
99

1010
def query(payload):

packages/tasks-gen/snippets-fixtures/conversational-llm-stream/sh/curl/0.together.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
curl https://api.together.xyz/v1/chat/completions \
2-
-H "Authorization: Bearer $TOGETHER_API_KEY" \
1+
curl https://router.huggingface.co/together/v1/chat/completions \
2+
-H "Authorization: Bearer $HF_TOKEN" \
33
-H 'Content-Type: application/json' \
44
-d '{
55
"messages": [

packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/js/openai/0.fireworks-ai.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import { OpenAI } from "openai";
22

33
const client = new OpenAI({
4-
baseURL: "https://api.fireworks.ai/inference/v1",
5-
apiKey: process.env.FIREWORKS_AI_API_KEY,
4+
baseURL: "https://router.huggingface.co/fireworks-ai/inference/v1",
5+
apiKey: process.env.HF_TOKEN,
66
});
77

88
const chatCompletion = await client.chat.completions.create({

packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/python/openai/0.fireworks-ai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from openai import OpenAI
33

44
client = OpenAI(
5-
base_url="https://api.fireworks.ai/inference/v1",
6-
api_key=os.environ["FIREWORKS_AI_API_KEY"],
5+
base_url="https://router.huggingface.co/fireworks-ai/inference/v1",
6+
api_key=os.environ["HF_TOKEN"],
77
)
88

99
completion = client.chat.completions.create(

packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/python/requests/0.fireworks-ai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import os
22
import requests
33

4-
API_URL = "https://api.fireworks.ai/inference/v1/chat/completions"
4+
API_URL = "https://router.huggingface.co/fireworks-ai/inference/v1/chat/completions"
55
headers = {
6-
"Authorization": f"Bearer {os.environ['FIREWORKS_AI_API_KEY']}",
6+
"Authorization": f"Bearer {os.environ['HF_TOKEN']}",
77
}
88

99
def query(payload):

packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/sh/curl/0.fireworks-ai.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
curl https://api.fireworks.ai/inference/v1/chat/completions \
2-
-H "Authorization: Bearer $FIREWORKS_AI_API_KEY" \
1+
curl https://router.huggingface.co/fireworks-ai/inference/v1/chat/completions \
2+
-H "Authorization: Bearer $HF_TOKEN" \
33
-H 'Content-Type: application/json' \
44
-d '{
55
"messages": [

packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/js/openai/0.fireworks-ai.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import { OpenAI } from "openai";
22

33
const client = new OpenAI({
4-
baseURL: "https://api.fireworks.ai/inference/v1",
5-
apiKey: process.env.FIREWORKS_AI_API_KEY,
4+
baseURL: "https://router.huggingface.co/fireworks-ai/inference/v1",
5+
apiKey: process.env.HF_TOKEN,
66
});
77

88
const stream = await client.chat.completions.create({

packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/python/openai/0.fireworks-ai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from openai import OpenAI
33

44
client = OpenAI(
5-
base_url="https://api.fireworks.ai/inference/v1",
6-
api_key=os.environ["FIREWORKS_AI_API_KEY"],
5+
base_url="https://router.huggingface.co/fireworks-ai/inference/v1",
6+
api_key=os.environ["HF_TOKEN"],
77
)
88

99
stream = client.chat.completions.create(

packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/python/requests/0.fireworks-ai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import json
33
import requests
44

5-
API_URL = "https://api.fireworks.ai/inference/v1/chat/completions"
5+
API_URL = "https://router.huggingface.co/fireworks-ai/inference/v1/chat/completions"
66
headers = {
7-
"Authorization": f"Bearer {os.environ['FIREWORKS_AI_API_KEY']}",
7+
"Authorization": f"Bearer {os.environ['HF_TOKEN']}",
88
}
99

1010
def query(payload):

packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/sh/curl/0.fireworks-ai.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
curl https://api.fireworks.ai/inference/v1/chat/completions \
2-
-H "Authorization: Bearer $FIREWORKS_AI_API_KEY" \
1+
curl https://router.huggingface.co/fireworks-ai/inference/v1/chat/completions \
2+
-H "Authorization: Bearer $HF_TOKEN" \
33
-H 'Content-Type: application/json' \
44
-d '{
55
"messages": [
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import { InferenceClient } from "@huggingface/inference";
2+
3+
const client = new InferenceClient(process.env.TOGETHER_API_KEY);
4+
5+
const chatCompletion = await client.chatCompletion({
6+
provider: "together",
7+
model: "meta-llama/Llama-3.1-8B-Instruct",
8+
messages: [
9+
{
10+
role: "user",
11+
content: "What is the capital of France?",
12+
},
13+
],
14+
});
15+
16+
console.log(chatCompletion.choices[0].message);
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import { OpenAI } from "openai";
2+
3+
const client = new OpenAI({
4+
baseURL: "https://api.together.xyz/v1",
5+
apiKey: process.env.TOGETHER_API_KEY,
6+
});
7+
8+
const chatCompletion = await client.chat.completions.create({
9+
model: "<together alias for meta-llama/Llama-3.1-8B-Instruct>",
10+
messages: [
11+
{
12+
role: "user",
13+
content: "What is the capital of France?",
14+
},
15+
],
16+
});
17+
18+
console.log(chatCompletion.choices[0].message);
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import os
2+
from huggingface_hub import InferenceClient
3+
4+
client = InferenceClient(
5+
provider="together",
6+
api_key=os.environ["TOGETHER_API_KEY"],
7+
)
8+
9+
completion = client.chat.completions.create(
10+
model="meta-llama/Llama-3.1-8B-Instruct",
11+
messages=[
12+
{
13+
"role": "user",
14+
"content": "What is the capital of France?"
15+
}
16+
],
17+
)
18+
19+
print(completion.choices[0].message)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import os
2+
from openai import OpenAI
3+
4+
client = OpenAI(
5+
base_url="https://api.together.xyz/v1",
6+
api_key=os.environ["TOGETHER_API_KEY"],
7+
)
8+
9+
completion = client.chat.completions.create(
10+
model="<together alias for meta-llama/Llama-3.1-8B-Instruct>",
11+
messages=[
12+
{
13+
"role": "user",
14+
"content": "What is the capital of France?"
15+
}
16+
],
17+
)
18+
19+
print(completion.choices[0].message)

0 commit comments

Comments
 (0)