Skip to content

Commit 7782f2b

Browse files
Better typing + make task functions generic (#1338)
This PR introduces task helpers to achieve better typing of `getProviderHelper`. It is inspired by @SBrandeis's demonstration PR #1332, but does not reverse the mapping. The diff is large (sorry for that) because I isolated all HF Inference API-specific code into the provider file, and there are a lot of supported tasks 😬 --------- Co-authored-by: SBrandeis <[email protected]>
1 parent 9d16e7c commit 7782f2b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1089
-494
lines changed
+226-87
Original file line numberDiff line numberDiff line change
@@ -1,131 +1,270 @@
1-
import { BlackForestLabsTextToImageTask } from "../providers/black-forest-labs";
2-
import { CerebrasConversationalTask } from "../providers/cerebras";
3-
import { CohereConversationalTask } from "../providers/cohere";
4-
import {
5-
FalAIAutomaticSpeechRecognitionTask,
6-
FalAITextToImageTask,
7-
FalAITextToSpeechTask,
8-
FalAITextToVideoTask,
9-
} from "../providers/fal-ai";
10-
import { FireworksConversationalTask } from "../providers/fireworks-ai";
11-
import {
12-
HFInferenceConversationalTask,
13-
HFInferenceTask,
14-
HFInferenceTextGenerationTask,
15-
HFInferenceTextToImageTask,
16-
} from "../providers/hf-inference";
17-
import {
18-
HyperbolicConversationalTask,
19-
HyperbolicTextGenerationTask,
20-
HyperbolicTextToImageTask,
21-
} from "../providers/hyperbolic";
22-
import { NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask } from "../providers/nebius";
23-
import { NovitaConversationalTask, NovitaTextGenerationTask } from "../providers/novita";
24-
import { OpenAIConversationalTask } from "../providers/openai";
25-
import type { TaskProviderHelper } from "../providers/providerHelper";
26-
import { ReplicateTextToImageTask, ReplicateTextToSpeechTask, ReplicateTextToVideoTask } from "../providers/replicate";
27-
import { SambanovaConversationalTask } from "../providers/sambanova";
28-
import { TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask } from "../providers/together";
1+
import * as BlackForestLabs from "../providers/black-forest-labs";
2+
import * as Cerebras from "../providers/cerebras";
3+
import * as Cohere from "../providers/cohere";
4+
import * as FalAI from "../providers/fal-ai";
5+
import * as Fireworks from "../providers/fireworks-ai";
6+
import * as HFInference from "../providers/hf-inference";
7+
8+
import * as Hyperbolic from "../providers/hyperbolic";
9+
import * as Nebius from "../providers/nebius";
10+
import * as Novita from "../providers/novita";
11+
import * as OpenAI from "../providers/openai";
12+
import type {
13+
AudioClassificationTaskHelper,
14+
AudioToAudioTaskHelper,
15+
AutomaticSpeechRecognitionTaskHelper,
16+
ConversationalTaskHelper,
17+
DocumentQuestionAnsweringTaskHelper,
18+
FeatureExtractionTaskHelper,
19+
FillMaskTaskHelper,
20+
ImageClassificationTaskHelper,
21+
ImageSegmentationTaskHelper,
22+
ImageToImageTaskHelper,
23+
ImageToTextTaskHelper,
24+
ObjectDetectionTaskHelper,
25+
QuestionAnsweringTaskHelper,
26+
SentenceSimilarityTaskHelper,
27+
SummarizationTaskHelper,
28+
TableQuestionAnsweringTaskHelper,
29+
TabularClassificationTaskHelper,
30+
TabularRegressionTaskHelper,
31+
TaskProviderHelper,
32+
TextClassificationTaskHelper,
33+
TextGenerationTaskHelper,
34+
TextToAudioTaskHelper,
35+
TextToImageTaskHelper,
36+
TextToSpeechTaskHelper,
37+
TextToVideoTaskHelper,
38+
TokenClassificationTaskHelper,
39+
TranslationTaskHelper,
40+
VisualQuestionAnsweringTaskHelper,
41+
ZeroShotClassificationTaskHelper,
42+
ZeroShotImageClassificationTaskHelper,
43+
} from "../providers/providerHelper";
44+
import * as Replicate from "../providers/replicate";
45+
import * as Sambanova from "../providers/sambanova";
46+
import * as Together from "../providers/together";
2947
import type { InferenceProvider, InferenceTask } from "../types";
3048

3149
export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
3250
"black-forest-labs": {
33-
"text-to-image": new BlackForestLabsTextToImageTask(),
51+
"text-to-image": new BlackForestLabs.BlackForestLabsTextToImageTask(),
3452
},
3553
cerebras: {
36-
conversational: new CerebrasConversationalTask(),
54+
conversational: new Cerebras.CerebrasConversationalTask(),
3755
},
3856
cohere: {
39-
conversational: new CohereConversationalTask(),
57+
conversational: new Cohere.CohereConversationalTask(),
4058
},
4159
"fal-ai": {
42-
"automatic-speech-recognition": new FalAIAutomaticSpeechRecognitionTask(),
43-
"text-to-image": new FalAITextToImageTask(),
44-
"text-to-speech": new FalAITextToSpeechTask(),
45-
"text-to-video": new FalAITextToVideoTask(),
46-
},
47-
"fireworks-ai": {
48-
conversational: new FireworksConversationalTask(),
60+
"text-to-image": new FalAI.FalAITextToImageTask(),
61+
"text-to-speech": new FalAI.FalAITextToSpeechTask(),
62+
"text-to-video": new FalAI.FalAITextToVideoTask(),
63+
"automatic-speech-recognition": new FalAI.FalAIAutomaticSpeechRecognitionTask(),
4964
},
5065
"hf-inference": {
51-
"text-to-image": new HFInferenceTextToImageTask(),
52-
conversational: new HFInferenceConversationalTask(),
53-
"text-generation": new HFInferenceTextGenerationTask(),
54-
"text-classification": new HFInferenceTask("text-classification"),
55-
"text-to-audio": new HFInferenceTask("text-to-audio"),
56-
"question-answering": new HFInferenceTask("question-answering"),
57-
"audio-classification": new HFInferenceTask("audio-classification"),
58-
"automatic-speech-recognition": new HFInferenceTask("automatic-speech-recognition"),
59-
"fill-mask": new HFInferenceTask("fill-mask"),
60-
"feature-extraction": new HFInferenceTask("feature-extraction"),
61-
"image-classification": new HFInferenceTask("image-classification"),
62-
"image-segmentation": new HFInferenceTask("image-segmentation"),
63-
"document-question-answering": new HFInferenceTask("document-question-answering"),
64-
"image-to-text": new HFInferenceTask("image-to-text"),
65-
"object-detection": new HFInferenceTask("object-detection"),
66-
"audio-to-audio": new HFInferenceTask("audio-to-audio"),
67-
"zero-shot-image-classification": new HFInferenceTask("zero-shot-image-classification"),
68-
"zero-shot-classification": new HFInferenceTask("zero-shot-classification"),
69-
"image-to-image": new HFInferenceTask("image-to-image"),
70-
"sentence-similarity": new HFInferenceTask("sentence-similarity"),
71-
"table-question-answering": new HFInferenceTask("table-question-answering"),
72-
"tabular-classification": new HFInferenceTask("tabular-classification"),
73-
"text-to-speech": new HFInferenceTask("text-to-speech"),
74-
"token-classification": new HFInferenceTask("token-classification"),
75-
translation: new HFInferenceTask("translation"),
76-
summarization: new HFInferenceTask("summarization"),
77-
"visual-question-answering": new HFInferenceTask("visual-question-answering"),
66+
"text-to-image": new HFInference.HFInferenceTextToImageTask(),
67+
conversational: new HFInference.HFInferenceConversationalTask(),
68+
"text-generation": new HFInference.HFInferenceTextGenerationTask(),
69+
"text-classification": new HFInference.HFInferenceTextClassificationTask(),
70+
"question-answering": new HFInference.HFInferenceQuestionAnsweringTask(),
71+
"audio-classification": new HFInference.HFInferenceAudioClassificationTask(),
72+
"automatic-speech-recognition": new HFInference.HFInferenceAutomaticSpeechRecognitionTask(),
73+
"fill-mask": new HFInference.HFInferenceFillMaskTask(),
74+
"feature-extraction": new HFInference.HFInferenceFeatureExtractionTask(),
75+
"image-classification": new HFInference.HFInferenceImageClassificationTask(),
76+
"image-segmentation": new HFInference.HFInferenceImageSegmentationTask(),
77+
"document-question-answering": new HFInference.HFInferenceDocumentQuestionAnsweringTask(),
78+
"image-to-text": new HFInference.HFInferenceImageToTextTask(),
79+
"object-detection": new HFInference.HFInferenceObjectDetectionTask(),
80+
"audio-to-audio": new HFInference.HFInferenceAudioToAudioTask(),
81+
"zero-shot-image-classification": new HFInference.HFInferenceZeroShotImageClassificationTask(),
82+
"zero-shot-classification": new HFInference.HFInferenceZeroShotClassificationTask(),
83+
"image-to-image": new HFInference.HFInferenceImageToImageTask(),
84+
"sentence-similarity": new HFInference.HFInferenceSentenceSimilarityTask(),
85+
"table-question-answering": new HFInference.HFInferenceTableQuestionAnsweringTask(),
86+
"tabular-classification": new HFInference.HFInferenceTabularClassificationTask(),
87+
"text-to-speech": new HFInference.HFInferenceTextToSpeechTask(),
88+
"token-classification": new HFInference.HFInferenceTokenClassificationTask(),
89+
translation: new HFInference.HFInferenceTranslationTask(),
90+
summarization: new HFInference.HFInferenceSummarizationTask(),
91+
"visual-question-answering": new HFInference.HFInferenceVisualQuestionAnsweringTask(),
92+
"tabular-regression": new HFInference.HFInferenceTabularRegressionTask(),
93+
"text-to-audio": new HFInference.HFInferenceTextToAudioTask(),
94+
},
95+
"fireworks-ai": {
96+
conversational: new Fireworks.FireworksConversationalTask(),
7897
},
7998
hyperbolic: {
80-
"text-to-image": new HyperbolicTextToImageTask(),
81-
conversational: new HyperbolicConversationalTask(),
82-
"text-generation": new HyperbolicTextGenerationTask(),
99+
"text-to-image": new Hyperbolic.HyperbolicTextToImageTask(),
100+
conversational: new Hyperbolic.HyperbolicConversationalTask(),
101+
"text-generation": new Hyperbolic.HyperbolicTextGenerationTask(),
83102
},
84103
nebius: {
85-
"text-to-image": new NebiusTextToImageTask(),
86-
conversational: new NebiusConversationalTask(),
87-
"text-generation": new NebiusTextGenerationTask(),
104+
"text-to-image": new Nebius.NebiusTextToImageTask(),
105+
conversational: new Nebius.NebiusConversationalTask(),
106+
"text-generation": new Nebius.NebiusTextGenerationTask(),
88107
},
89108
novita: {
90-
"text-generation": new NovitaTextGenerationTask(),
91-
conversational: new NovitaConversationalTask(),
109+
conversational: new Novita.NovitaConversationalTask(),
110+
"text-generation": new Novita.NovitaTextGenerationTask(),
92111
},
93112
openai: {
94-
conversational: new OpenAIConversationalTask(),
113+
conversational: new OpenAI.OpenAIConversationalTask(),
95114
},
96115
replicate: {
97-
"text-to-image": new ReplicateTextToImageTask(),
98-
"text-to-speech": new ReplicateTextToSpeechTask(),
99-
"text-to-video": new ReplicateTextToVideoTask(),
116+
"text-to-image": new Replicate.ReplicateTextToImageTask(),
117+
"text-to-speech": new Replicate.ReplicateTextToSpeechTask(),
118+
"text-to-video": new Replicate.ReplicateTextToVideoTask(),
100119
},
101120
sambanova: {
102-
conversational: new SambanovaConversationalTask(),
121+
conversational: new Sambanova.SambanovaConversationalTask(),
103122
},
104123
together: {
105-
"text-to-image": new TogetherTextToImageTask(),
106-
"text-generation": new TogetherTextGenerationTask(),
107-
conversational: new TogetherConversationalTask(),
124+
"text-to-image": new Together.TogetherTextToImageTask(),
125+
conversational: new Together.TogetherConversationalTask(),
126+
"text-generation": new Together.TogetherTextGenerationTask(),
108127
},
109128
};
110129

111130
/**
112131
* Get provider helper instance by name and task
113132
*/
133+
export function getProviderHelper(
134+
provider: InferenceProvider,
135+
task: "text-to-image"
136+
): TextToImageTaskHelper & TaskProviderHelper;
137+
export function getProviderHelper(
138+
provider: InferenceProvider,
139+
task: "conversational"
140+
): ConversationalTaskHelper & TaskProviderHelper;
141+
export function getProviderHelper(
142+
provider: InferenceProvider,
143+
task: "text-generation"
144+
): TextGenerationTaskHelper & TaskProviderHelper;
145+
export function getProviderHelper(
146+
provider: InferenceProvider,
147+
task: "text-to-speech"
148+
): TextToSpeechTaskHelper & TaskProviderHelper;
149+
export function getProviderHelper(
150+
provider: InferenceProvider,
151+
task: "text-to-audio"
152+
): TextToAudioTaskHelper & TaskProviderHelper;
153+
export function getProviderHelper(
154+
provider: InferenceProvider,
155+
task: "automatic-speech-recognition"
156+
): AutomaticSpeechRecognitionTaskHelper & TaskProviderHelper;
157+
export function getProviderHelper(
158+
provider: InferenceProvider,
159+
task: "text-to-video"
160+
): TextToVideoTaskHelper & TaskProviderHelper;
161+
export function getProviderHelper(
162+
provider: InferenceProvider,
163+
task: "text-classification"
164+
): TextClassificationTaskHelper & TaskProviderHelper;
165+
export function getProviderHelper(
166+
provider: InferenceProvider,
167+
task: "question-answering"
168+
): QuestionAnsweringTaskHelper & TaskProviderHelper;
169+
export function getProviderHelper(
170+
provider: InferenceProvider,
171+
task: "audio-classification"
172+
): AudioClassificationTaskHelper & TaskProviderHelper;
173+
export function getProviderHelper(
174+
provider: InferenceProvider,
175+
task: "audio-to-audio"
176+
): AudioToAudioTaskHelper & TaskProviderHelper;
177+
export function getProviderHelper(
178+
provider: InferenceProvider,
179+
task: "fill-mask"
180+
): FillMaskTaskHelper & TaskProviderHelper;
181+
export function getProviderHelper(
182+
provider: InferenceProvider,
183+
task: "feature-extraction"
184+
): FeatureExtractionTaskHelper & TaskProviderHelper;
185+
export function getProviderHelper(
186+
provider: InferenceProvider,
187+
task: "image-classification"
188+
): ImageClassificationTaskHelper & TaskProviderHelper;
189+
export function getProviderHelper(
190+
provider: InferenceProvider,
191+
task: "image-segmentation"
192+
): ImageSegmentationTaskHelper & TaskProviderHelper;
193+
export function getProviderHelper(
194+
provider: InferenceProvider,
195+
task: "document-question-answering"
196+
): DocumentQuestionAnsweringTaskHelper & TaskProviderHelper;
197+
export function getProviderHelper(
198+
provider: InferenceProvider,
199+
task: "image-to-text"
200+
): ImageToTextTaskHelper & TaskProviderHelper;
201+
export function getProviderHelper(
202+
provider: InferenceProvider,
203+
task: "object-detection"
204+
): ObjectDetectionTaskHelper & TaskProviderHelper;
205+
export function getProviderHelper(
206+
provider: InferenceProvider,
207+
task: "zero-shot-image-classification"
208+
): ZeroShotImageClassificationTaskHelper & TaskProviderHelper;
209+
export function getProviderHelper(
210+
provider: InferenceProvider,
211+
task: "zero-shot-classification"
212+
): ZeroShotClassificationTaskHelper & TaskProviderHelper;
213+
export function getProviderHelper(
214+
provider: InferenceProvider,
215+
task: "image-to-image"
216+
): ImageToImageTaskHelper & TaskProviderHelper;
217+
export function getProviderHelper(
218+
provider: InferenceProvider,
219+
task: "sentence-similarity"
220+
): SentenceSimilarityTaskHelper & TaskProviderHelper;
221+
export function getProviderHelper(
222+
provider: InferenceProvider,
223+
task: "table-question-answering"
224+
): TableQuestionAnsweringTaskHelper & TaskProviderHelper;
225+
export function getProviderHelper(
226+
provider: InferenceProvider,
227+
task: "tabular-classification"
228+
): TabularClassificationTaskHelper & TaskProviderHelper;
229+
export function getProviderHelper(
230+
provider: InferenceProvider,
231+
task: "tabular-regression"
232+
): TabularRegressionTaskHelper & TaskProviderHelper;
233+
export function getProviderHelper(
234+
provider: InferenceProvider,
235+
task: "token-classification"
236+
): TokenClassificationTaskHelper & TaskProviderHelper;
237+
export function getProviderHelper(
238+
provider: InferenceProvider,
239+
task: "translation"
240+
): TranslationTaskHelper & TaskProviderHelper;
241+
export function getProviderHelper(
242+
provider: InferenceProvider,
243+
task: "summarization"
244+
): SummarizationTaskHelper & TaskProviderHelper;
245+
export function getProviderHelper(
246+
provider: InferenceProvider,
247+
task: "visual-question-answering"
248+
): VisualQuestionAnsweringTaskHelper & TaskProviderHelper;
249+
export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper;
250+
114251
export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper {
115-
// special case for hf-inference, where the task is optional
116252
if (provider === "hf-inference") {
117253
if (!task) {
118-
return new HFInferenceTask();
254+
return new HFInference.HFInferenceTask();
119255
}
120256
}
121257
if (!task) {
122258
throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
123259
}
124-
const helper = PROVIDERS[provider][task];
125-
if (!helper) {
260+
if (!(provider in PROVIDERS)) {
261+
throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
262+
}
263+
const providerTasks = PROVIDERS[provider];
264+
if (!providerTasks || !(task in providerTasks)) {
126265
throw new Error(
127-
`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(PROVIDERS[provider])}`
266+
`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}`
128267
);
129268
}
130-
return helper;
269+
return providerTasks[task] as TaskProviderHelper;
131270
}

packages/inference/src/providers/black-forest-labs.ts

+5-5
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,17 @@ import { InferenceOutputError } from "../lib/InferenceOutputError";
1818
import type { BodyParams, HeaderParams, UrlParams } from "../types";
1919
import { delay } from "../utils/delay";
2020
import { omit } from "../utils/omit";
21-
import { TaskProviderHelper } from "./providerHelper";
21+
import { TaskProviderHelper, type TextToImageTaskHelper } from "./providerHelper";
2222

2323
const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
2424
interface BlackForestLabsResponse {
2525
id: string;
2626
polling_url: string;
2727
}
2828

29-
export class BlackForestLabsTextToImageTask extends TaskProviderHelper {
29+
export class BlackForestLabsTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
3030
constructor() {
31-
super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL, "text-to-image");
31+
super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
3232
}
3333

3434
preparePayload(params: BodyParams): Record<string, unknown> {
@@ -59,8 +59,8 @@ export class BlackForestLabsTextToImageTask extends TaskProviderHelper {
5959

6060
async getResponse(
6161
response: BlackForestLabsResponse,
62-
url: string,
63-
headers: Record<string, string>,
62+
url?: string,
63+
headers?: HeadersInit,
6464
outputType?: "url" | "blob"
6565
): Promise<string | Blob> {
6666
const urlObj = new URL(response.polling_url);

0 commit comments

Comments
 (0)