Skip to content

Better typing + make task functions generic #1338

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 4, 2025
313 changes: 226 additions & 87 deletions packages/inference/src/lib/getProviderHelper.ts
Original file line number Diff line number Diff line change
@@ -1,131 +1,270 @@
import { BlackForestLabsTextToImageTask } from "../providers/black-forest-labs";
import { CerebrasConversationalTask } from "../providers/cerebras";
import { CohereConversationalTask } from "../providers/cohere";
import {
FalAIAutomaticSpeechRecognitionTask,
FalAITextToImageTask,
FalAITextToSpeechTask,
FalAITextToVideoTask,
} from "../providers/fal-ai";
import { FireworksConversationalTask } from "../providers/fireworks-ai";
import {
HFInferenceConversationalTask,
HFInferenceTask,
HFInferenceTextGenerationTask,
HFInferenceTextToImageTask,
} from "../providers/hf-inference";
import {
HyperbolicConversationalTask,
HyperbolicTextGenerationTask,
HyperbolicTextToImageTask,
} from "../providers/hyperbolic";
import { NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask } from "../providers/nebius";
import { NovitaConversationalTask, NovitaTextGenerationTask } from "../providers/novita";
import { OpenAIConversationalTask } from "../providers/openai";
import type { TaskProviderHelper } from "../providers/providerHelper";
import { ReplicateTextToImageTask, ReplicateTextToSpeechTask, ReplicateTextToVideoTask } from "../providers/replicate";
import { SambanovaConversationalTask } from "../providers/sambanova";
import { TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask } from "../providers/together";
import * as BlackForestLabs from "../providers/black-forest-labs";
import * as Cerebras from "../providers/cerebras";
import * as Cohere from "../providers/cohere";
import * as FalAI from "../providers/fal-ai";
import * as Fireworks from "../providers/fireworks-ai";
import * as HFInference from "../providers/hf-inference";

import * as Hyperbolic from "../providers/hyperbolic";
import * as Nebius from "../providers/nebius";
import * as Novita from "../providers/novita";
import * as OpenAI from "../providers/openai";
import type {
AudioClassificationTaskHelper,
AudioToAudioTaskHelper,
AutomaticSpeechRecognitionTaskHelper,
ConversationalTaskHelper,
DocumentQuestionAnsweringTaskHelper,
FeatureExtractionTaskHelper,
FillMaskTaskHelper,
ImageClassificationTaskHelper,
ImageSegmentationTaskHelper,
ImageToImageTaskHelper,
ImageToTextTaskHelper,
ObjectDetectionTaskHelper,
QuestionAnsweringTaskHelper,
SentenceSimilarityTaskHelper,
SummarizationTaskHelper,
TableQuestionAnsweringTaskHelper,
TabularClassificationTaskHelper,
TabularRegressionTaskHelper,
TaskProviderHelper,
TextClassificationTaskHelper,
TextGenerationTaskHelper,
TextToAudioTaskHelper,
TextToImageTaskHelper,
TextToSpeechTaskHelper,
TextToVideoTaskHelper,
TokenClassificationTaskHelper,
TranslationTaskHelper,
VisualQuestionAnsweringTaskHelper,
ZeroShotClassificationTaskHelper,
ZeroShotImageClassificationTaskHelper,
} from "../providers/providerHelper";
import * as Replicate from "../providers/replicate";
import * as Sambanova from "../providers/sambanova";
import * as Together from "../providers/together";
import type { InferenceProvider, InferenceTask } from "../types";

export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
"black-forest-labs": {
"text-to-image": new BlackForestLabsTextToImageTask(),
"text-to-image": new BlackForestLabs.BlackForestLabsTextToImageTask(),
},
cerebras: {
conversational: new CerebrasConversationalTask(),
conversational: new Cerebras.CerebrasConversationalTask(),
},
cohere: {
conversational: new CohereConversationalTask(),
conversational: new Cohere.CohereConversationalTask(),
},
"fal-ai": {
"automatic-speech-recognition": new FalAIAutomaticSpeechRecognitionTask(),
"text-to-image": new FalAITextToImageTask(),
"text-to-speech": new FalAITextToSpeechTask(),
"text-to-video": new FalAITextToVideoTask(),
},
"fireworks-ai": {
conversational: new FireworksConversationalTask(),
"text-to-image": new FalAI.FalAITextToImageTask(),
"text-to-speech": new FalAI.FalAITextToSpeechTask(),
"text-to-video": new FalAI.FalAITextToVideoTask(),
"automatic-speech-recognition": new FalAI.FalAIAutomaticSpeechRecognitionTask(),
},
"hf-inference": {
"text-to-image": new HFInferenceTextToImageTask(),
conversational: new HFInferenceConversationalTask(),
"text-generation": new HFInferenceTextGenerationTask(),
"text-classification": new HFInferenceTask("text-classification"),
"text-to-audio": new HFInferenceTask("text-to-audio"),
"question-answering": new HFInferenceTask("question-answering"),
"audio-classification": new HFInferenceTask("audio-classification"),
"automatic-speech-recognition": new HFInferenceTask("automatic-speech-recognition"),
"fill-mask": new HFInferenceTask("fill-mask"),
"feature-extraction": new HFInferenceTask("feature-extraction"),
"image-classification": new HFInferenceTask("image-classification"),
"image-segmentation": new HFInferenceTask("image-segmentation"),
"document-question-answering": new HFInferenceTask("document-question-answering"),
"image-to-text": new HFInferenceTask("image-to-text"),
"object-detection": new HFInferenceTask("object-detection"),
"audio-to-audio": new HFInferenceTask("audio-to-audio"),
"zero-shot-image-classification": new HFInferenceTask("zero-shot-image-classification"),
"zero-shot-classification": new HFInferenceTask("zero-shot-classification"),
"image-to-image": new HFInferenceTask("image-to-image"),
"sentence-similarity": new HFInferenceTask("sentence-similarity"),
"table-question-answering": new HFInferenceTask("table-question-answering"),
"tabular-classification": new HFInferenceTask("tabular-classification"),
"text-to-speech": new HFInferenceTask("text-to-speech"),
"token-classification": new HFInferenceTask("token-classification"),
translation: new HFInferenceTask("translation"),
summarization: new HFInferenceTask("summarization"),
"visual-question-answering": new HFInferenceTask("visual-question-answering"),
"text-to-image": new HFInference.HFInferenceTextToImageTask(),
conversational: new HFInference.HFInferenceConversationalTask(),
"text-generation": new HFInference.HFInferenceTextGenerationTask(),
"text-classification": new HFInference.HFInferenceTextClassificationTask(),
"question-answering": new HFInference.HFInferenceQuestionAnsweringTask(),
"audio-classification": new HFInference.HFInferenceAudioClassificationTask(),
"automatic-speech-recognition": new HFInference.HFInferenceAutomaticSpeechRecognitionTask(),
"fill-mask": new HFInference.HFInferenceFillMaskTask(),
"feature-extraction": new HFInference.HFInferenceFeatureExtractionTask(),
"image-classification": new HFInference.HFInferenceImageClassificationTask(),
"image-segmentation": new HFInference.HFInferenceImageSegmentationTask(),
"document-question-answering": new HFInference.HFInferenceDocumentQuestionAnsweringTask(),
"image-to-text": new HFInference.HFInferenceImageToTextTask(),
"object-detection": new HFInference.HFInferenceObjectDetectionTask(),
"audio-to-audio": new HFInference.HFInferenceAudioToAudioTask(),
"zero-shot-image-classification": new HFInference.HFInferenceZeroShotImageClassificationTask(),
"zero-shot-classification": new HFInference.HFInferenceZeroShotClassificationTask(),
"image-to-image": new HFInference.HFInferenceImageToImageTask(),
"sentence-similarity": new HFInference.HFInferenceSentenceSimilarityTask(),
"table-question-answering": new HFInference.HFInferenceTableQuestionAnsweringTask(),
"tabular-classification": new HFInference.HFInferenceTabularClassificationTask(),
"text-to-speech": new HFInference.HFInferenceTextToSpeechTask(),
"token-classification": new HFInference.HFInferenceTokenClassificationTask(),
translation: new HFInference.HFInferenceTranslationTask(),
summarization: new HFInference.HFInferenceSummarizationTask(),
"visual-question-answering": new HFInference.HFInferenceVisualQuestionAnsweringTask(),
"tabular-regression": new HFInference.HFInferenceTabularRegressionTask(),
"text-to-audio": new HFInference.HFInferenceTextToAudioTask(),
},
"fireworks-ai": {
conversational: new Fireworks.FireworksConversationalTask(),
},
hyperbolic: {
"text-to-image": new HyperbolicTextToImageTask(),
conversational: new HyperbolicConversationalTask(),
"text-generation": new HyperbolicTextGenerationTask(),
"text-to-image": new Hyperbolic.HyperbolicTextToImageTask(),
conversational: new Hyperbolic.HyperbolicConversationalTask(),
"text-generation": new Hyperbolic.HyperbolicTextGenerationTask(),
},
nebius: {
"text-to-image": new NebiusTextToImageTask(),
conversational: new NebiusConversationalTask(),
"text-generation": new NebiusTextGenerationTask(),
"text-to-image": new Nebius.NebiusTextToImageTask(),
conversational: new Nebius.NebiusConversationalTask(),
"text-generation": new Nebius.NebiusTextGenerationTask(),
},
novita: {
"text-generation": new NovitaTextGenerationTask(),
conversational: new NovitaConversationalTask(),
conversational: new Novita.NovitaConversationalTask(),
"text-generation": new Novita.NovitaTextGenerationTask(),
},
openai: {
conversational: new OpenAIConversationalTask(),
conversational: new OpenAI.OpenAIConversationalTask(),
},
replicate: {
"text-to-image": new ReplicateTextToImageTask(),
"text-to-speech": new ReplicateTextToSpeechTask(),
"text-to-video": new ReplicateTextToVideoTask(),
"text-to-image": new Replicate.ReplicateTextToImageTask(),
"text-to-speech": new Replicate.ReplicateTextToSpeechTask(),
"text-to-video": new Replicate.ReplicateTextToVideoTask(),
},
sambanova: {
conversational: new SambanovaConversationalTask(),
conversational: new Sambanova.SambanovaConversationalTask(),
},
together: {
"text-to-image": new TogetherTextToImageTask(),
"text-generation": new TogetherTextGenerationTask(),
conversational: new TogetherConversationalTask(),
"text-to-image": new Together.TogetherTextToImageTask(),
conversational: new Together.TogetherConversationalTask(),
"text-generation": new Together.TogetherTextGenerationTask(),
},
};

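/**
* Get the provider helper instance registered in PROVIDERS for the given provider and task.
*
* The overloads narrow the return type to the task-specific helper interface when `task` is a
* literal task name (e.g. "text-to-image"); the last overload returns a plain TaskProviderHelper.
*
* For "hf-inference" the task is optional: when it is omitted, a generic HFInferenceTask is returned.
* For every other provider a task name is required.
*
* @throws Error if no task is given for a provider other than "hf-inference", if the provider
* is not a key of PROVIDERS, or if the task is not registered for that provider.
*/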
export function getProviderHelper(
provider: InferenceProvider,
task: "text-to-image"
): TextToImageTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "conversational"
): ConversationalTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "text-generation"
): TextGenerationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "text-to-speech"
): TextToSpeechTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "text-to-audio"
): TextToAudioTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "automatic-speech-recognition"
): AutomaticSpeechRecognitionTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "text-to-video"
): TextToVideoTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "text-classification"
): TextClassificationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "question-answering"
): QuestionAnsweringTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "audio-classification"
): AudioClassificationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "audio-to-audio"
): AudioToAudioTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "fill-mask"
): FillMaskTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "feature-extraction"
): FeatureExtractionTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "image-classification"
): ImageClassificationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "image-segmentation"
): ImageSegmentationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "document-question-answering"
): DocumentQuestionAnsweringTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "image-to-text"
): ImageToTextTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "object-detection"
): ObjectDetectionTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "zero-shot-image-classification"
): ZeroShotImageClassificationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "zero-shot-classification"
): ZeroShotClassificationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "image-to-image"
): ImageToImageTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "sentence-similarity"
): SentenceSimilarityTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "table-question-answering"
): TableQuestionAnsweringTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "tabular-classification"
): TabularClassificationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "tabular-regression"
): TabularRegressionTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "token-classification"
): TokenClassificationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "translation"
): TranslationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "summarization"
): SummarizationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "visual-question-answering"
): VisualQuestionAnsweringTaskHelper & TaskProviderHelper;
export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper;

export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper {
// special case for hf-inference, where the task is optional
if (provider === "hf-inference") {
if (!task) {
return new HFInferenceTask();
return new HFInference.HFInferenceTask();
}
}
if (!task) {
throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
}
const helper = PROVIDERS[provider][task];
if (!helper) {
if (!(provider in PROVIDERS)) {
throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
}
const providerTasks = PROVIDERS[provider];
if (!providerTasks || !(task in providerTasks)) {
throw new Error(
`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(PROVIDERS[provider])}`
`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}`
);
}
return helper;
return providerTasks[task] as TaskProviderHelper;
}
10 changes: 5 additions & 5 deletions packages/inference/src/providers/black-forest-labs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ import { InferenceOutputError } from "../lib/InferenceOutputError";
import type { BodyParams, HeaderParams, UrlParams } from "../types";
import { delay } from "../utils/delay";
import { omit } from "../utils/omit";
import { TaskProviderHelper } from "./providerHelper";
import { TaskProviderHelper, type TextToImageTaskHelper } from "./providerHelper";

const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
interface BlackForestLabsResponse {
id: string;
polling_url: string;
}

export class BlackForestLabsTextToImageTask extends TaskProviderHelper {
export class BlackForestLabsTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
constructor() {
super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL, "text-to-image");
super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
}

preparePayload(params: BodyParams): Record<string, unknown> {
Expand Down Expand Up @@ -59,8 +59,8 @@ export class BlackForestLabsTextToImageTask extends TaskProviderHelper {

async getResponse(
response: BlackForestLabsResponse,
url: string,
headers: Record<string, string>,
url?: string,
headers?: HeadersInit,
outputType?: "url" | "blob"
): Promise<string | Blob> {
const urlObj = new URL(response.polling_url);
Expand Down
Loading