Skip to content

Commit 9b3e2cf

Browse files
[Inference] Implement a "1 class = 1 provider<>task pair" logic to isolate provider-specific code (#1315)
The Goal of this PR is to isolate provider-specific logic into its own class/file. Similar to what we have in the Python client, this gives a clear boundary between generic task implementation and provider-specific implementation details. This makes updating or adding a new provider much easier, adding support for a new provider will be as much as easier as: 1) register the provider in the mapping. 2) create a new `packages/inference/src/providers/{provider_name}.ts`. 3) implement the methods that require custom handling. ## Main changes - Added a new `TaskProviderHelper` abstract class that defines a common interface (`makeBody`, `preparePayload` `makeRoute`, `getResponse`, `prepareHeaders`, etc.) that all provider-specific helpers must implement. - Added `PROVIDERS` registry that maps any provider<>task to the corresponding `TaskProviderHelper` subclass that implements the provider<>task logic. - Added a lookup function (`getProviderHelper`) that takes a provider and task and returns the appropriate TaskProviderHelper instance from the `PROVIDERS` registry. - Implemented the task helpers for providers: - [x] All task helpers have been implemented for all providers. 🙏 Any feedback is welcome as the implementation is quite biased towards the Python but we might want to do some things differently here. --------- Co-authored-by: SBrandeis <[email protected]>
1 parent ae23913 commit 9b3e2cf

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+2113
-1159
lines changed

Diff for: packages/inference/src/lib/getProviderHelper.ts

+270
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
import * as BlackForestLabs from "../providers/black-forest-labs";
2+
import * as Cerebras from "../providers/cerebras";
3+
import * as Cohere from "../providers/cohere";
4+
import * as FalAI from "../providers/fal-ai";
5+
import * as Fireworks from "../providers/fireworks-ai";
6+
import * as HFInference from "../providers/hf-inference";
7+
8+
import * as Hyperbolic from "../providers/hyperbolic";
9+
import * as Nebius from "../providers/nebius";
10+
import * as Novita from "../providers/novita";
11+
import * as OpenAI from "../providers/openai";
12+
import type {
13+
AudioClassificationTaskHelper,
14+
AudioToAudioTaskHelper,
15+
AutomaticSpeechRecognitionTaskHelper,
16+
ConversationalTaskHelper,
17+
DocumentQuestionAnsweringTaskHelper,
18+
FeatureExtractionTaskHelper,
19+
FillMaskTaskHelper,
20+
ImageClassificationTaskHelper,
21+
ImageSegmentationTaskHelper,
22+
ImageToImageTaskHelper,
23+
ImageToTextTaskHelper,
24+
ObjectDetectionTaskHelper,
25+
QuestionAnsweringTaskHelper,
26+
SentenceSimilarityTaskHelper,
27+
SummarizationTaskHelper,
28+
TableQuestionAnsweringTaskHelper,
29+
TabularClassificationTaskHelper,
30+
TabularRegressionTaskHelper,
31+
TaskProviderHelper,
32+
TextClassificationTaskHelper,
33+
TextGenerationTaskHelper,
34+
TextToAudioTaskHelper,
35+
TextToImageTaskHelper,
36+
TextToSpeechTaskHelper,
37+
TextToVideoTaskHelper,
38+
TokenClassificationTaskHelper,
39+
TranslationTaskHelper,
40+
VisualQuestionAnsweringTaskHelper,
41+
ZeroShotClassificationTaskHelper,
42+
ZeroShotImageClassificationTaskHelper,
43+
} from "../providers/providerHelper";
44+
import * as Replicate from "../providers/replicate";
45+
import * as Sambanova from "../providers/sambanova";
46+
import * as Together from "../providers/together";
47+
import type { InferenceProvider, InferenceTask } from "../types";
48+
49+
export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
50+
"black-forest-labs": {
51+
"text-to-image": new BlackForestLabs.BlackForestLabsTextToImageTask(),
52+
},
53+
cerebras: {
54+
conversational: new Cerebras.CerebrasConversationalTask(),
55+
},
56+
cohere: {
57+
conversational: new Cohere.CohereConversationalTask(),
58+
},
59+
"fal-ai": {
60+
"text-to-image": new FalAI.FalAITextToImageTask(),
61+
"text-to-speech": new FalAI.FalAITextToSpeechTask(),
62+
"text-to-video": new FalAI.FalAITextToVideoTask(),
63+
"automatic-speech-recognition": new FalAI.FalAIAutomaticSpeechRecognitionTask(),
64+
},
65+
"hf-inference": {
66+
"text-to-image": new HFInference.HFInferenceTextToImageTask(),
67+
conversational: new HFInference.HFInferenceConversationalTask(),
68+
"text-generation": new HFInference.HFInferenceTextGenerationTask(),
69+
"text-classification": new HFInference.HFInferenceTextClassificationTask(),
70+
"question-answering": new HFInference.HFInferenceQuestionAnsweringTask(),
71+
"audio-classification": new HFInference.HFInferenceAudioClassificationTask(),
72+
"automatic-speech-recognition": new HFInference.HFInferenceAutomaticSpeechRecognitionTask(),
73+
"fill-mask": new HFInference.HFInferenceFillMaskTask(),
74+
"feature-extraction": new HFInference.HFInferenceFeatureExtractionTask(),
75+
"image-classification": new HFInference.HFInferenceImageClassificationTask(),
76+
"image-segmentation": new HFInference.HFInferenceImageSegmentationTask(),
77+
"document-question-answering": new HFInference.HFInferenceDocumentQuestionAnsweringTask(),
78+
"image-to-text": new HFInference.HFInferenceImageToTextTask(),
79+
"object-detection": new HFInference.HFInferenceObjectDetectionTask(),
80+
"audio-to-audio": new HFInference.HFInferenceAudioToAudioTask(),
81+
"zero-shot-image-classification": new HFInference.HFInferenceZeroShotImageClassificationTask(),
82+
"zero-shot-classification": new HFInference.HFInferenceZeroShotClassificationTask(),
83+
"image-to-image": new HFInference.HFInferenceImageToImageTask(),
84+
"sentence-similarity": new HFInference.HFInferenceSentenceSimilarityTask(),
85+
"table-question-answering": new HFInference.HFInferenceTableQuestionAnsweringTask(),
86+
"tabular-classification": new HFInference.HFInferenceTabularClassificationTask(),
87+
"text-to-speech": new HFInference.HFInferenceTextToSpeechTask(),
88+
"token-classification": new HFInference.HFInferenceTokenClassificationTask(),
89+
translation: new HFInference.HFInferenceTranslationTask(),
90+
summarization: new HFInference.HFInferenceSummarizationTask(),
91+
"visual-question-answering": new HFInference.HFInferenceVisualQuestionAnsweringTask(),
92+
"tabular-regression": new HFInference.HFInferenceTabularRegressionTask(),
93+
"text-to-audio": new HFInference.HFInferenceTextToAudioTask(),
94+
},
95+
"fireworks-ai": {
96+
conversational: new Fireworks.FireworksConversationalTask(),
97+
},
98+
hyperbolic: {
99+
"text-to-image": new Hyperbolic.HyperbolicTextToImageTask(),
100+
conversational: new Hyperbolic.HyperbolicConversationalTask(),
101+
"text-generation": new Hyperbolic.HyperbolicTextGenerationTask(),
102+
},
103+
nebius: {
104+
"text-to-image": new Nebius.NebiusTextToImageTask(),
105+
conversational: new Nebius.NebiusConversationalTask(),
106+
"text-generation": new Nebius.NebiusTextGenerationTask(),
107+
},
108+
novita: {
109+
conversational: new Novita.NovitaConversationalTask(),
110+
"text-generation": new Novita.NovitaTextGenerationTask(),
111+
},
112+
openai: {
113+
conversational: new OpenAI.OpenAIConversationalTask(),
114+
},
115+
replicate: {
116+
"text-to-image": new Replicate.ReplicateTextToImageTask(),
117+
"text-to-speech": new Replicate.ReplicateTextToSpeechTask(),
118+
"text-to-video": new Replicate.ReplicateTextToVideoTask(),
119+
},
120+
sambanova: {
121+
conversational: new Sambanova.SambanovaConversationalTask(),
122+
},
123+
together: {
124+
"text-to-image": new Together.TogetherTextToImageTask(),
125+
conversational: new Together.TogetherConversationalTask(),
126+
"text-generation": new Together.TogetherTextGenerationTask(),
127+
},
128+
};
129+
130+
/**
131+
* Get provider helper instance by name and task
132+
*/
133+
export function getProviderHelper(
134+
provider: InferenceProvider,
135+
task: "text-to-image"
136+
): TextToImageTaskHelper & TaskProviderHelper;
137+
export function getProviderHelper(
138+
provider: InferenceProvider,
139+
task: "conversational"
140+
): ConversationalTaskHelper & TaskProviderHelper;
141+
export function getProviderHelper(
142+
provider: InferenceProvider,
143+
task: "text-generation"
144+
): TextGenerationTaskHelper & TaskProviderHelper;
145+
export function getProviderHelper(
146+
provider: InferenceProvider,
147+
task: "text-to-speech"
148+
): TextToSpeechTaskHelper & TaskProviderHelper;
149+
export function getProviderHelper(
150+
provider: InferenceProvider,
151+
task: "text-to-audio"
152+
): TextToAudioTaskHelper & TaskProviderHelper;
153+
export function getProviderHelper(
154+
provider: InferenceProvider,
155+
task: "automatic-speech-recognition"
156+
): AutomaticSpeechRecognitionTaskHelper & TaskProviderHelper;
157+
export function getProviderHelper(
158+
provider: InferenceProvider,
159+
task: "text-to-video"
160+
): TextToVideoTaskHelper & TaskProviderHelper;
161+
export function getProviderHelper(
162+
provider: InferenceProvider,
163+
task: "text-classification"
164+
): TextClassificationTaskHelper & TaskProviderHelper;
165+
export function getProviderHelper(
166+
provider: InferenceProvider,
167+
task: "question-answering"
168+
): QuestionAnsweringTaskHelper & TaskProviderHelper;
169+
export function getProviderHelper(
170+
provider: InferenceProvider,
171+
task: "audio-classification"
172+
): AudioClassificationTaskHelper & TaskProviderHelper;
173+
export function getProviderHelper(
174+
provider: InferenceProvider,
175+
task: "audio-to-audio"
176+
): AudioToAudioTaskHelper & TaskProviderHelper;
177+
export function getProviderHelper(
178+
provider: InferenceProvider,
179+
task: "fill-mask"
180+
): FillMaskTaskHelper & TaskProviderHelper;
181+
export function getProviderHelper(
182+
provider: InferenceProvider,
183+
task: "feature-extraction"
184+
): FeatureExtractionTaskHelper & TaskProviderHelper;
185+
export function getProviderHelper(
186+
provider: InferenceProvider,
187+
task: "image-classification"
188+
): ImageClassificationTaskHelper & TaskProviderHelper;
189+
export function getProviderHelper(
190+
provider: InferenceProvider,
191+
task: "image-segmentation"
192+
): ImageSegmentationTaskHelper & TaskProviderHelper;
193+
export function getProviderHelper(
194+
provider: InferenceProvider,
195+
task: "document-question-answering"
196+
): DocumentQuestionAnsweringTaskHelper & TaskProviderHelper;
197+
export function getProviderHelper(
198+
provider: InferenceProvider,
199+
task: "image-to-text"
200+
): ImageToTextTaskHelper & TaskProviderHelper;
201+
export function getProviderHelper(
202+
provider: InferenceProvider,
203+
task: "object-detection"
204+
): ObjectDetectionTaskHelper & TaskProviderHelper;
205+
export function getProviderHelper(
206+
provider: InferenceProvider,
207+
task: "zero-shot-image-classification"
208+
): ZeroShotImageClassificationTaskHelper & TaskProviderHelper;
209+
export function getProviderHelper(
210+
provider: InferenceProvider,
211+
task: "zero-shot-classification"
212+
): ZeroShotClassificationTaskHelper & TaskProviderHelper;
213+
export function getProviderHelper(
214+
provider: InferenceProvider,
215+
task: "image-to-image"
216+
): ImageToImageTaskHelper & TaskProviderHelper;
217+
export function getProviderHelper(
218+
provider: InferenceProvider,
219+
task: "sentence-similarity"
220+
): SentenceSimilarityTaskHelper & TaskProviderHelper;
221+
export function getProviderHelper(
222+
provider: InferenceProvider,
223+
task: "table-question-answering"
224+
): TableQuestionAnsweringTaskHelper & TaskProviderHelper;
225+
export function getProviderHelper(
226+
provider: InferenceProvider,
227+
task: "tabular-classification"
228+
): TabularClassificationTaskHelper & TaskProviderHelper;
229+
export function getProviderHelper(
230+
provider: InferenceProvider,
231+
task: "tabular-regression"
232+
): TabularRegressionTaskHelper & TaskProviderHelper;
233+
export function getProviderHelper(
234+
provider: InferenceProvider,
235+
task: "token-classification"
236+
): TokenClassificationTaskHelper & TaskProviderHelper;
237+
export function getProviderHelper(
238+
provider: InferenceProvider,
239+
task: "translation"
240+
): TranslationTaskHelper & TaskProviderHelper;
241+
export function getProviderHelper(
242+
provider: InferenceProvider,
243+
task: "summarization"
244+
): SummarizationTaskHelper & TaskProviderHelper;
245+
export function getProviderHelper(
246+
provider: InferenceProvider,
247+
task: "visual-question-answering"
248+
): VisualQuestionAnsweringTaskHelper & TaskProviderHelper;
249+
export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper;
250+
251+
export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper {
252+
if (provider === "hf-inference") {
253+
if (!task) {
254+
return new HFInference.HFInferenceTask();
255+
}
256+
}
257+
if (!task) {
258+
throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
259+
}
260+
if (!(provider in PROVIDERS)) {
261+
throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
262+
}
263+
const providerTasks = PROVIDERS[provider];
264+
if (!providerTasks || !(task in providerTasks)) {
265+
throw new Error(
266+
`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}`
267+
);
268+
}
269+
return providerTasks[task] as TaskProviderHelper;
270+
}

0 commit comments

Comments
 (0)