@@ -12,10 +12,11 @@ import { REPLICATE_CONFIG } from "../providers/replicate";
 import { SAMBANOVA_CONFIG } from "../providers/sambanova";
 import { TOGETHER_CONFIG } from "../providers/together";
 import { OPENAI_CONFIG } from "../providers/openai";
-import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
+import type { InferenceProvider, InferenceTask, Options, RequestArgs } from "../types";
 import { isUrl } from "./isUrl";
 import { version as packageVersion, name as packageName } from "../../package.json";
 import { getProviderModelId } from "./getProviderModelId";
+import type { InferenceProviderTypes } from "../providers/types";

 const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;

@@ -28,7 +29,7 @@ let tasks: Record<string, { models: { id: string }[] }> | null = null;
 /**
  * Config to define how to serialize requests for each provider
  */
-const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
+const providerConfigs = {
 	"black-forest-labs": BLACK_FOREST_LABS_CONFIG,
 	cerebras: CEREBRAS_CONFIG,
 	cohere: COHERE_CONFIG,
@@ -42,7 +43,8 @@ const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
 	replicate: REPLICATE_CONFIG,
 	sambanova: SAMBANOVA_CONFIG,
 	together: TOGETHER_CONFIG,
-};
+} satisfies Record<Exclude<InferenceProvider, "hf-inference">, InferenceProviderTypes.Config> &
+	Record<Extract<InferenceProvider, "hf-inference">, InferenceProviderTypes.ConfigWithOptionalModel>;

 /**
  * Helper that prepares request arguments.
@@ -82,8 +84,11 @@ export async function makeRequestOptions(
 	}

 	if (args.endpointUrl) {
+		if (provider !== "hf-inference") {
+			throw new Error(`Cannot use endpointUrl with a third-party provider.`);
+		}
 		return makeRequestOptionsFromResolvedModel(
-			{ endpointUrl: args.endpointUrl, resolvedModel: maybeModel },
+			{ endpointUrl: args.endpointUrl, resolvedModel: maybeModel, provider },
 			args,
 			options
 		);
@@ -94,7 +99,7 @@ export async function makeRequestOptions(
 		throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
 	}
 	return makeRequestOptionsFromResolvedModel(
-		{ resolvedModel: removeProviderPrefix(maybeModel, provider) },
+		{ resolvedModel: removeProviderPrefix(maybeModel, provider), provider },
 		args,
 		options
 	);
@@ -109,7 +114,7 @@ export async function makeRequestOptions(
 	});

 	// Use the sync version with the resolved model
-	return makeRequestOptionsFromResolvedModel({ resolvedModel }, args, options);
+	return makeRequestOptionsFromResolvedModel({ resolvedModel, provider }, args, options);
 }

 /**
@@ -120,7 +125,9 @@ export function makeRequestOptionsFromResolvedModel(
 	/**
 	 * Should only be undefined if the endpointUrl is provided
 	 */
-	input: { endpointUrl: string; resolvedModel?: string } | { endpointUrl?: undefined; resolvedModel: string },
+	input:
+		| { endpointUrl: string; resolvedModel?: string; provider: Extract<InferenceProvider, "hf-inference"> }
+		| { endpointUrl?: undefined; resolvedModel: string; provider: InferenceProvider },
 	args: RequestArgs & {
 		data?: Blob | ArrayBuffer;
 		stream?: boolean;
@@ -133,17 +140,17 @@ export function makeRequestOptionsFromResolvedModel(
 	const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
 	void model;
 	void endpointUrl;
+	void maybeProvider;

-	const provider = maybeProvider ?? "hf-inference";
-	const providerConfig = providerConfigs[provider];
+	const providerConfig = providerConfigs[input.provider];

 	const { includeCredentials, task, chatCompletion, signal, billTo } = options ?? {};

 	const authMethod = (() => {
 		if (providerConfig.clientSideRoutingOnly) {
 			// Closed-source providers require an accessToken (cannot be routed).
 			if (accessToken && accessToken.startsWith("hf_")) {
-				throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
+				throw new Error(`Provider ${input.provider} is closed-source and does not support HF tokens.`);
 			}
 			return "provider-key";
 		}
@@ -170,7 +177,7 @@ export function makeRequestOptionsFromResolvedModel(
 		authMethod,
 		baseUrl:
 			authMethod !== "provider-key"
				? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", input.provider)
 				: providerConfig.makeBaseUrl(task),
 		model: input.resolvedModel,
 		chatCompletion,
@@ -205,12 +212,19 @@ export function makeRequestOptionsFromResolvedModel(
 	const body = binary
 		? args.data
 		: JSON.stringify(
-				providerConfig.makeBody({
-					args: remainingArgs as Record<string, unknown>,
-					model: input.resolvedModel,
-					task,
-					chatCompletion,
-				})
+				input.provider === "hf-inference"
+					? providerConfigs[input.provider].makeBody({
+							args: remainingArgs as Record<string, unknown>,
+							model: input.resolvedModel,
+							task,
+							chatCompletion,
+					  })
+					: providerConfig.makeBody({
+							args: remainingArgs as Record<string, unknown>,
+							model: input.resolvedModel,
+							task,
+							chatCompletion,
+					  })
 		  );

 	/**
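
For context on the typing change in this diff: the `satisfies` clause keeps `providerConfigs` inferred as a literal object while still checking every entry, so indexing it with the literal `"hf-inference"` (as the new `makeBody` branch does) yields the config variant whose model is optional, while indexing with an arbitrary provider falls back to the stricter common shape. Below is a minimal sketch of that pattern; the names `Provider`, `Config`, `ConfigWithOptionalModel`, `configs`, and `bodyFor` are simplified stand-ins for illustration, not the actual exports of this package.

// Sketch only: placeholder types, not the real @huggingface/inference types.
type Provider = "hf-inference" | "together" | "replicate";

interface Config {
	makeBody(params: { args: Record<string, unknown>; model: string }): unknown;
}

interface ConfigWithOptionalModel {
	makeBody(params: { args: Record<string, unknown>; model?: string }): unknown;
}

const configs = {
	// hf-inference may omit the model (e.g. when an endpointUrl is used)
	"hf-inference": { makeBody: ({ args }) => args },
	together: { makeBody: ({ args, model }) => ({ ...args, model }) },
	replicate: { makeBody: ({ args, model }) => ({ ...args, version: model }) },
} satisfies Record<Exclude<Provider, "hf-inference">, Config> &
	Record<Extract<Provider, "hf-inference">, ConfigWithOptionalModel>;

// `satisfies` validates each entry without widening the object, so the
// literal key keeps the variant whose `model` is optional:
const hfBody = configs["hf-inference"].makeBody({ args: {} });
void hfBody;

// Indexing with a union key requires arguments accepted by every variant,
// i.e. the stricter shape with a required model:
function bodyFor(provider: Provider) {
	return configs[provider].makeBody({ args: {}, model: "some-model" });
}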