Skip to content

Commit c9ba631

Browse files
committed
Make TS typing deduce when the model may be undefined
1 parent e09802b commit c9ba631

File tree

5 files changed

+44
-23
lines changed

5 files changed

+44
-23
lines changed

packages/inference/src/lib/makeRequestOptions.ts

+31-17
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ import { REPLICATE_CONFIG } from "../providers/replicate";
1212
import { SAMBANOVA_CONFIG } from "../providers/sambanova";
1313
import { TOGETHER_CONFIG } from "../providers/together";
1414
import { OPENAI_CONFIG } from "../providers/openai";
15-
import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
15+
import type { InferenceProvider, InferenceTask, Options, RequestArgs } from "../types";
1616
import { isUrl } from "./isUrl";
1717
import { version as packageVersion, name as packageName } from "../../package.json";
1818
import { getProviderModelId } from "./getProviderModelId";
19+
import type { InferenceProviderTypes } from "../providers/types";
1920

2021
const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
2122

@@ -28,7 +29,7 @@ let tasks: Record<string, { models: { id: string }[] }> | null = null;
2829
/**
2930
* Config to define how to serialize requests for each provider
3031
*/
31-
const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
32+
const providerConfigs = {
3233
"black-forest-labs": BLACK_FOREST_LABS_CONFIG,
3334
cerebras: CEREBRAS_CONFIG,
3435
cohere: COHERE_CONFIG,
@@ -42,7 +43,8 @@ const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
4243
replicate: REPLICATE_CONFIG,
4344
sambanova: SAMBANOVA_CONFIG,
4445
together: TOGETHER_CONFIG,
45-
};
46+
} satisfies Record<Exclude<InferenceProvider, "hf-inference">, InferenceProviderTypes.Config> &
47+
Record<Extract<InferenceProvider, "hf-inference">, InferenceProviderTypes.ConfigWithOptionalModel>;
4648

4749
/**
4850
* Helper that prepares request arguments.
@@ -82,8 +84,11 @@ export async function makeRequestOptions(
8284
}
8385

8486
if (args.endpointUrl) {
87+
if (provider !== "hf-inference") {
88+
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
89+
}
8590
return makeRequestOptionsFromResolvedModel(
86-
{ endpointUrl: args.endpointUrl, resolvedModel: maybeModel },
91+
{ endpointUrl: args.endpointUrl, resolvedModel: maybeModel, provider },
8792
args,
8893
options
8994
);
@@ -94,7 +99,7 @@ export async function makeRequestOptions(
9499
throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
95100
}
96101
return makeRequestOptionsFromResolvedModel(
97-
{ resolvedModel: removeProviderPrefix(maybeModel, provider) },
102+
{ resolvedModel: removeProviderPrefix(maybeModel, provider), provider },
98103
args,
99104
options
100105
);
@@ -109,7 +114,7 @@ export async function makeRequestOptions(
109114
});
110115

111116
// Use the sync version with the resolved model
112-
return makeRequestOptionsFromResolvedModel({ resolvedModel }, args, options);
117+
return makeRequestOptionsFromResolvedModel({ resolvedModel, provider }, args, options);
113118
}
114119

115120
/**
@@ -120,7 +125,9 @@ export function makeRequestOptionsFromResolvedModel(
120125
/**
121126
* `resolvedModel` should only be undefined if `endpointUrl` is provided
122127
*/
123-
input: { endpointUrl: string; resolvedModel?: string } | { endpointUrl?: undefined; resolvedModel: string },
128+
input:
129+
| { endpointUrl: string; resolvedModel?: string; provider: Extract<InferenceProvider, "hf-inference"> }
130+
| { endpointUrl?: undefined; resolvedModel: string; provider: InferenceProvider },
124131
args: RequestArgs & {
125132
data?: Blob | ArrayBuffer;
126133
stream?: boolean;
@@ -133,17 +140,17 @@ export function makeRequestOptionsFromResolvedModel(
133140
const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
134141
void model;
135142
void endpointUrl;
143+
void maybeProvider;
136144

137-
const provider = maybeProvider ?? "hf-inference";
138-
const providerConfig = providerConfigs[provider];
145+
const providerConfig = providerConfigs[input.provider];
139146

140147
const { includeCredentials, task, chatCompletion, signal, billTo } = options ?? {};
141148

142149
const authMethod = (() => {
143150
if (providerConfig.clientSideRoutingOnly) {
144151
// Closed-source providers require an accessToken (cannot be routed).
145152
if (accessToken && accessToken.startsWith("hf_")) {
146-
throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
153+
throw new Error(`Provider ${input.provider} is closed-source and does not support HF tokens.`);
147154
}
148155
return "provider-key";
149156
}
@@ -170,7 +177,7 @@ export function makeRequestOptionsFromResolvedModel(
170177
authMethod,
171178
baseUrl:
172179
authMethod !== "provider-key"
173-
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider)
180+
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", input.provider)
174181
: providerConfig.makeBaseUrl(task),
175182
model: input.resolvedModel,
176183
chatCompletion,
@@ -205,12 +212,19 @@ export function makeRequestOptionsFromResolvedModel(
205212
const body = binary
206213
? args.data
207214
: JSON.stringify(
208-
providerConfig.makeBody({
209-
args: remainingArgs as Record<string, unknown>,
210-
model: input.resolvedModel,
211-
task,
212-
chatCompletion,
213-
})
215+
input.provider === "hf-inference"
216+
? providerConfigs[input.provider].makeBody({
217+
args: remainingArgs as Record<string, unknown>,
218+
model: input.resolvedModel,
219+
task,
220+
chatCompletion,
221+
})
222+
: providerConfig.makeBody({
223+
args: remainingArgs as Record<string, unknown>,
224+
model: input.resolvedModel,
225+
task,
226+
chatCompletion,
227+
})
214228
);
215229

216230
/**

packages/inference/src/providers/black-forest-labs.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ const makeBody: InferenceProviderTypes.MakeBody = (params) => {
2626
return params.args;
2727
};
2828

29-
const makeHeaders: InferenceProviderTypes.MakeHeaders = (params) => {
29+
const makeHeaders: InferenceProviderTypes.MakeHeaders = (params): Record<string, string> => {
3030
if (params.authMethod === "provider-key") {
3131
return { "X-Key": `${params.accessToken}` };
3232
} else {

packages/inference/src/providers/hf-inference.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ const makeBaseUrl: InferenceProviderTypes.MakeBaseUrl = () => {
1717
return `${HF_ROUTER_URL}/hf-inference`;
1818
};
1919

20-
const makeBody: InferenceProviderTypes.MakeBody = (params) => {
20+
const makeBody: InferenceProviderTypes.MakeBodyWithOptionalModel = (params) => {
2121
return {
2222
...params.args,
2323
...(params.chatCompletion ? { model: params.model } : undefined),
@@ -39,7 +39,7 @@ const makeUrl: InferenceProviderTypes.MakeUrl = (params) => {
3939
return `${params.baseUrl}/models/${params.model}`;
4040
};
4141

42-
export const HF_INFERENCE_CONFIG: InferenceProviderTypes.Config = {
42+
export const HF_INFERENCE_CONFIG: InferenceProviderTypes.ConfigWithOptionalModel = {
4343
makeBaseUrl,
4444
makeBody,
4545
makeHeaders,

packages/inference/src/providers/replicate.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ const makeBaseUrl: InferenceProviderTypes.MakeBaseUrl = () => {
2525
const makeBody: InferenceProviderTypes.MakeBody = (params) => {
2626
return {
2727
input: params.args,
28-
version: params.model?.includes(":") ? params.model.split(":")[1] : undefined,
28+
version: params.model.includes(":") ? params.model.split(":")[1] : undefined,
2929
};
3030
};
3131

packages/inference/src/providers/types.ts

+9-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,15 @@ export namespace InferenceProviderTypes {
1717
clientSideRoutingOnly?: boolean;
1818
}
1919

20+
export interface ConfigWithOptionalModel extends Config {
21+
makeBody: MakeBodyWithOptionalModel;
22+
}
23+
2024
export type MakeBody = (params: BodyParams) => Record<string, unknown>;
21-
export type MakeHeaders = (params: HeaderParams) => Record<string, string | undefined>;
25+
export type MakeBodyWithOptionalModel = (
26+
params: Omit<BodyParams, "model"> & { model?: string }
27+
) => Record<string, unknown>;
28+
export type MakeHeaders = (params: HeaderParams) => Record<string, string>;
2229
export type MakeUrl = (params: UrlParams) => string;
2330
export type MakeBaseUrl = (() => string) | ((task?: InferenceTask) => string);
2431

@@ -38,7 +45,7 @@ export namespace InferenceProviderTypes {
3845
export interface BodyParams {
3946
args: Record<string, unknown>;
4047
chatCompletion?: boolean;
41-
model?: string;
48+
model: string;
4249
task?: InferenceTask;
4350
}
4451
}

0 commit comments

Comments
 (0)