diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts index acc6fdf09d..8c345208a4 100644 --- a/packages/inference/src/lib/getProviderHelper.ts +++ b/packages/inference/src/lib/getProviderHelper.ts @@ -1,131 +1,270 @@ -import { BlackForestLabsTextToImageTask } from "../providers/black-forest-labs"; -import { CerebrasConversationalTask } from "../providers/cerebras"; -import { CohereConversationalTask } from "../providers/cohere"; -import { - FalAIAutomaticSpeechRecognitionTask, - FalAITextToImageTask, - FalAITextToSpeechTask, - FalAITextToVideoTask, -} from "../providers/fal-ai"; -import { FireworksConversationalTask } from "../providers/fireworks-ai"; -import { - HFInferenceConversationalTask, - HFInferenceTask, - HFInferenceTextGenerationTask, - HFInferenceTextToImageTask, -} from "../providers/hf-inference"; -import { - HyperbolicConversationalTask, - HyperbolicTextGenerationTask, - HyperbolicTextToImageTask, -} from "../providers/hyperbolic"; -import { NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask } from "../providers/nebius"; -import { NovitaConversationalTask, NovitaTextGenerationTask } from "../providers/novita"; -import { OpenAIConversationalTask } from "../providers/openai"; -import type { TaskProviderHelper } from "../providers/providerHelper"; -import { ReplicateTextToImageTask, ReplicateTextToSpeechTask, ReplicateTextToVideoTask } from "../providers/replicate"; -import { SambanovaConversationalTask } from "../providers/sambanova"; -import { TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask } from "../providers/together"; +import * as BlackForestLabs from "../providers/black-forest-labs"; +import * as Cerebras from "../providers/cerebras"; +import * as Cohere from "../providers/cohere"; +import * as FalAI from "../providers/fal-ai"; +import * as Fireworks from "../providers/fireworks-ai"; +import * as HFInference from 
"../providers/hf-inference"; + +import * as Hyperbolic from "../providers/hyperbolic"; +import * as Nebius from "../providers/nebius"; +import * as Novita from "../providers/novita"; +import * as OpenAI from "../providers/openai"; +import type { + AudioClassificationTaskHelper, + AudioToAudioTaskHelper, + AutomaticSpeechRecognitionTaskHelper, + ConversationalTaskHelper, + DocumentQuestionAnsweringTaskHelper, + FeatureExtractionTaskHelper, + FillMaskTaskHelper, + ImageClassificationTaskHelper, + ImageSegmentationTaskHelper, + ImageToImageTaskHelper, + ImageToTextTaskHelper, + ObjectDetectionTaskHelper, + QuestionAnsweringTaskHelper, + SentenceSimilarityTaskHelper, + SummarizationTaskHelper, + TableQuestionAnsweringTaskHelper, + TabularClassificationTaskHelper, + TabularRegressionTaskHelper, + TaskProviderHelper, + TextClassificationTaskHelper, + TextGenerationTaskHelper, + TextToAudioTaskHelper, + TextToImageTaskHelper, + TextToSpeechTaskHelper, + TextToVideoTaskHelper, + TokenClassificationTaskHelper, + TranslationTaskHelper, + VisualQuestionAnsweringTaskHelper, + ZeroShotClassificationTaskHelper, + ZeroShotImageClassificationTaskHelper, +} from "../providers/providerHelper"; +import * as Replicate from "../providers/replicate"; +import * as Sambanova from "../providers/sambanova"; +import * as Together from "../providers/together"; import type { InferenceProvider, InferenceTask } from "../types"; export const PROVIDERS: Record>> = { "black-forest-labs": { - "text-to-image": new BlackForestLabsTextToImageTask(), + "text-to-image": new BlackForestLabs.BlackForestLabsTextToImageTask(), }, cerebras: { - conversational: new CerebrasConversationalTask(), + conversational: new Cerebras.CerebrasConversationalTask(), }, cohere: { - conversational: new CohereConversationalTask(), + conversational: new Cohere.CohereConversationalTask(), }, "fal-ai": { - "automatic-speech-recognition": new FalAIAutomaticSpeechRecognitionTask(), - "text-to-image": new FalAITextToImageTask(), - 
"text-to-speech": new FalAITextToSpeechTask(), - "text-to-video": new FalAITextToVideoTask(), - }, - "fireworks-ai": { - conversational: new FireworksConversationalTask(), + "text-to-image": new FalAI.FalAITextToImageTask(), + "text-to-speech": new FalAI.FalAITextToSpeechTask(), + "text-to-video": new FalAI.FalAITextToVideoTask(), + "automatic-speech-recognition": new FalAI.FalAIAutomaticSpeechRecognitionTask(), }, "hf-inference": { - "text-to-image": new HFInferenceTextToImageTask(), - conversational: new HFInferenceConversationalTask(), - "text-generation": new HFInferenceTextGenerationTask(), - "text-classification": new HFInferenceTask("text-classification"), - "text-to-audio": new HFInferenceTask("text-to-audio"), - "question-answering": new HFInferenceTask("question-answering"), - "audio-classification": new HFInferenceTask("audio-classification"), - "automatic-speech-recognition": new HFInferenceTask("automatic-speech-recognition"), - "fill-mask": new HFInferenceTask("fill-mask"), - "feature-extraction": new HFInferenceTask("feature-extraction"), - "image-classification": new HFInferenceTask("image-classification"), - "image-segmentation": new HFInferenceTask("image-segmentation"), - "document-question-answering": new HFInferenceTask("document-question-answering"), - "image-to-text": new HFInferenceTask("image-to-text"), - "object-detection": new HFInferenceTask("object-detection"), - "audio-to-audio": new HFInferenceTask("audio-to-audio"), - "zero-shot-image-classification": new HFInferenceTask("zero-shot-image-classification"), - "zero-shot-classification": new HFInferenceTask("zero-shot-classification"), - "image-to-image": new HFInferenceTask("image-to-image"), - "sentence-similarity": new HFInferenceTask("sentence-similarity"), - "table-question-answering": new HFInferenceTask("table-question-answering"), - "tabular-classification": new HFInferenceTask("tabular-classification"), - "text-to-speech": new HFInferenceTask("text-to-speech"), - 
"token-classification": new HFInferenceTask("token-classification"), - translation: new HFInferenceTask("translation"), - summarization: new HFInferenceTask("summarization"), - "visual-question-answering": new HFInferenceTask("visual-question-answering"), + "text-to-image": new HFInference.HFInferenceTextToImageTask(), + conversational: new HFInference.HFInferenceConversationalTask(), + "text-generation": new HFInference.HFInferenceTextGenerationTask(), + "text-classification": new HFInference.HFInferenceTextClassificationTask(), + "question-answering": new HFInference.HFInferenceQuestionAnsweringTask(), + "audio-classification": new HFInference.HFInferenceAudioClassificationTask(), + "automatic-speech-recognition": new HFInference.HFInferenceAutomaticSpeechRecognitionTask(), + "fill-mask": new HFInference.HFInferenceFillMaskTask(), + "feature-extraction": new HFInference.HFInferenceFeatureExtractionTask(), + "image-classification": new HFInference.HFInferenceImageClassificationTask(), + "image-segmentation": new HFInference.HFInferenceImageSegmentationTask(), + "document-question-answering": new HFInference.HFInferenceDocumentQuestionAnsweringTask(), + "image-to-text": new HFInference.HFInferenceImageToTextTask(), + "object-detection": new HFInference.HFInferenceObjectDetectionTask(), + "audio-to-audio": new HFInference.HFInferenceAudioToAudioTask(), + "zero-shot-image-classification": new HFInference.HFInferenceZeroShotImageClassificationTask(), + "zero-shot-classification": new HFInference.HFInferenceZeroShotClassificationTask(), + "image-to-image": new HFInference.HFInferenceImageToImageTask(), + "sentence-similarity": new HFInference.HFInferenceSentenceSimilarityTask(), + "table-question-answering": new HFInference.HFInferenceTableQuestionAnsweringTask(), + "tabular-classification": new HFInference.HFInferenceTabularClassificationTask(), + "text-to-speech": new HFInference.HFInferenceTextToSpeechTask(), + "token-classification": new 
HFInference.HFInferenceTokenClassificationTask(), + translation: new HFInference.HFInferenceTranslationTask(), + summarization: new HFInference.HFInferenceSummarizationTask(), + "visual-question-answering": new HFInference.HFInferenceVisualQuestionAnsweringTask(), + "tabular-regression": new HFInference.HFInferenceTabularRegressionTask(), + "text-to-audio": new HFInference.HFInferenceTextToAudioTask(), + }, + "fireworks-ai": { + conversational: new Fireworks.FireworksConversationalTask(), }, hyperbolic: { - "text-to-image": new HyperbolicTextToImageTask(), - conversational: new HyperbolicConversationalTask(), - "text-generation": new HyperbolicTextGenerationTask(), + "text-to-image": new Hyperbolic.HyperbolicTextToImageTask(), + conversational: new Hyperbolic.HyperbolicConversationalTask(), + "text-generation": new Hyperbolic.HyperbolicTextGenerationTask(), }, nebius: { - "text-to-image": new NebiusTextToImageTask(), - conversational: new NebiusConversationalTask(), - "text-generation": new NebiusTextGenerationTask(), + "text-to-image": new Nebius.NebiusTextToImageTask(), + conversational: new Nebius.NebiusConversationalTask(), + "text-generation": new Nebius.NebiusTextGenerationTask(), }, novita: { - "text-generation": new NovitaTextGenerationTask(), - conversational: new NovitaConversationalTask(), + conversational: new Novita.NovitaConversationalTask(), + "text-generation": new Novita.NovitaTextGenerationTask(), }, openai: { - conversational: new OpenAIConversationalTask(), + conversational: new OpenAI.OpenAIConversationalTask(), }, replicate: { - "text-to-image": new ReplicateTextToImageTask(), - "text-to-speech": new ReplicateTextToSpeechTask(), - "text-to-video": new ReplicateTextToVideoTask(), + "text-to-image": new Replicate.ReplicateTextToImageTask(), + "text-to-speech": new Replicate.ReplicateTextToSpeechTask(), + "text-to-video": new Replicate.ReplicateTextToVideoTask(), }, sambanova: { - conversational: new SambanovaConversationalTask(), + 
conversational: new Sambanova.SambanovaConversationalTask(), }, together: { - "text-to-image": new TogetherTextToImageTask(), - "text-generation": new TogetherTextGenerationTask(), - conversational: new TogetherConversationalTask(), + "text-to-image": new Together.TogetherTextToImageTask(), + conversational: new Together.TogetherConversationalTask(), + "text-generation": new Together.TogetherTextGenerationTask(), }, }; /** * Get provider helper instance by name and task */ +export function getProviderHelper( + provider: InferenceProvider, + task: "text-to-image" +): TextToImageTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "conversational" +): ConversationalTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "text-generation" +): TextGenerationTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "text-to-speech" +): TextToSpeechTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "text-to-audio" +): TextToAudioTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "automatic-speech-recognition" +): AutomaticSpeechRecognitionTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "text-to-video" +): TextToVideoTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "text-classification" +): TextClassificationTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "question-answering" +): QuestionAnsweringTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "audio-classification" +): AudioClassificationTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: 
InferenceProvider, + task: "audio-to-audio" +): AudioToAudioTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "fill-mask" +): FillMaskTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "feature-extraction" +): FeatureExtractionTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "image-classification" +): ImageClassificationTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "image-segmentation" +): ImageSegmentationTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "document-question-answering" +): DocumentQuestionAnsweringTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "image-to-text" +): ImageToTextTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "object-detection" +): ObjectDetectionTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "zero-shot-image-classification" +): ZeroShotImageClassificationTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "zero-shot-classification" +): ZeroShotClassificationTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "image-to-image" +): ImageToImageTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "sentence-similarity" +): SentenceSimilarityTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "table-question-answering" +): TableQuestionAnsweringTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: 
"tabular-classification" +): TabularClassificationTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "tabular-regression" +): TabularRegressionTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "token-classification" +): TokenClassificationTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "translation" +): TranslationTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "summarization" +): SummarizationTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProvider, + task: "visual-question-answering" +): VisualQuestionAnsweringTaskHelper & TaskProviderHelper; +export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper; + export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper { - // special case for hf-inference, where the task is optional if (provider === "hf-inference") { if (!task) { - return new HFInferenceTask(); + return new HFInference.HFInferenceTask(); } } if (!task) { throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'"); } - const helper = PROVIDERS[provider][task]; - if (!helper) { + if (!(provider in PROVIDERS)) { + throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`); + } + const providerTasks = PROVIDERS[provider]; + if (!providerTasks || !(task in providerTasks)) { throw new Error( - `Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(PROVIDERS[provider])}` + `Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? 
{})}` ); } - return helper; + return providerTasks[task] as TaskProviderHelper; } diff --git a/packages/inference/src/providers/black-forest-labs.ts b/packages/inference/src/providers/black-forest-labs.ts index 68a542e479..58ccf5bcde 100644 --- a/packages/inference/src/providers/black-forest-labs.ts +++ b/packages/inference/src/providers/black-forest-labs.ts @@ -18,7 +18,7 @@ import { InferenceOutputError } from "../lib/InferenceOutputError"; import type { BodyParams, HeaderParams, UrlParams } from "../types"; import { delay } from "../utils/delay"; import { omit } from "../utils/omit"; -import { TaskProviderHelper } from "./providerHelper"; +import { TaskProviderHelper, type TextToImageTaskHelper } from "./providerHelper"; const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai"; interface BlackForestLabsResponse { @@ -26,9 +26,9 @@ interface BlackForestLabsResponse { polling_url: string; } -export class BlackForestLabsTextToImageTask extends TaskProviderHelper { +export class BlackForestLabsTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper { constructor() { - super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL, "text-to-image"); + super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL); } preparePayload(params: BodyParams): Record { @@ -59,8 +59,8 @@ export class BlackForestLabsTextToImageTask extends TaskProviderHelper { async getResponse( response: BlackForestLabsResponse, - url: string, - headers: Record, + url?: string, + headers?: HeadersInit, outputType?: "url" | "blob" ): Promise { const urlObj = new URL(response.polling_url); diff --git a/packages/inference/src/providers/cohere.ts b/packages/inference/src/providers/cohere.ts index a80037e7d7..e688083718 100644 --- a/packages/inference/src/providers/cohere.ts +++ b/packages/inference/src/providers/cohere.ts @@ -14,15 +14,13 @@ * * Thanks! 
*/ -import type { UrlParams } from "../types"; import { BaseConversationalTask } from "./providerHelper"; export class CohereConversationalTask extends BaseConversationalTask { constructor() { super("cohere", "https://api.cohere.com"); } - override makeRoute(params: UrlParams): string { - void params; + override makeRoute(): string { return "/compatibility/v1/chat/completions"; } } diff --git a/packages/inference/src/providers/fal-ai.ts b/packages/inference/src/providers/fal-ai.ts index abbd4ea66a..5e0669c8d1 100644 --- a/packages/inference/src/providers/fal-ai.ts +++ b/packages/inference/src/providers/fal-ai.ts @@ -17,10 +17,15 @@ import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks"; import { InferenceOutputError } from "../lib/InferenceOutputError"; import { isUrl } from "../lib/isUrl"; -import type { BodyParams, HeaderParams, InferenceTask, UrlParams } from "../types"; +import type { BodyParams, HeaderParams, UrlParams } from "../types"; import { delay } from "../utils/delay"; import { omit } from "../utils/omit"; -import { TaskProviderHelper } from "./providerHelper"; +import { + type AutomaticSpeechRecognitionTaskHelper, + TaskProviderHelper, + type TextToImageTaskHelper, + type TextToVideoTaskHelper, +} from "./providerHelper"; export interface FalAiQueueOutput { request_id: string; @@ -47,8 +52,8 @@ interface FalAITextToSpeechOutput { export const FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"]; abstract class FalAITask extends TaskProviderHelper { - constructor(task: InferenceTask, url?: string) { - super("fal-ai", url || "https://fal.run", task); + constructor(url?: string) { + super("fal-ai", url || "https://fal.run"); } preparePayload(params: BodyParams): Record { @@ -69,10 +74,7 @@ abstract class FalAITask extends TaskProviderHelper { } } -export class FalAITextToImageTask extends FalAITask { - constructor() { - super("text-to-image"); - } +export class FalAITextToImageTask extends FalAITask 
implements TextToImageTaskHelper { override preparePayload(params: BodyParams): Record { return { ...omit(params.args, ["inputs", "parameters"]), @@ -102,9 +104,9 @@ export class FalAITextToImageTask extends FalAITask { } } -export class FalAITextToVideoTask extends FalAITask { +export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHelper { constructor() { - super("text-to-video", "https://queue.fal.run"); + super("https://queue.fal.run"); } override makeRoute(params: UrlParams): string { if (params.authMethod !== "provider-key") { @@ -188,16 +190,13 @@ export class FalAITextToVideoTask extends FalAITask { } } -export class FalAIAutomaticSpeechRecognitionTask extends FalAITask { - constructor() { - super("automatic-speech-recognition"); - } +export class FalAIAutomaticSpeechRecognitionTask extends FalAITask implements AutomaticSpeechRecognitionTaskHelper { override prepareHeaders(params: HeaderParams, binary: boolean): Record { const headers = super.prepareHeaders(params, binary); headers["Content-Type"] = "application/json"; return headers; } - override getResponse(response: unknown): AutomaticSpeechRecognitionOutput { + override async getResponse(response: unknown): Promise { const res = response as FalAIAutomaticSpeechRecognitionOutput; if (typeof res?.text !== "string") { throw new InferenceOutputError( @@ -209,10 +208,6 @@ export class FalAIAutomaticSpeechRecognitionTask extends FalAITask { } export class FalAITextToSpeechTask extends FalAITask { - constructor() { - super("text-to-speech"); - } - override preparePayload(params: BodyParams): Record { return { ...omit(params.args, ["inputs", "parameters"]), diff --git a/packages/inference/src/providers/fireworks-ai.ts b/packages/inference/src/providers/fireworks-ai.ts index 5716ce6ceb..f2e721f7bd 100644 --- a/packages/inference/src/providers/fireworks-ai.ts +++ b/packages/inference/src/providers/fireworks-ai.ts @@ -14,15 +14,14 @@ * * Thanks! 
*/ -import type { UrlParams } from "../types"; import { BaseConversationalTask } from "./providerHelper"; export class FireworksConversationalTask extends BaseConversationalTask { constructor() { super("fireworks-ai", "https://api.fireworks.ai"); } - override makeRoute(params: UrlParams): string { - void params; + + override makeRoute(): string { return "/inference/v1/chat/completions"; } } diff --git a/packages/inference/src/providers/hf-inference.ts b/packages/inference/src/providers/hf-inference.ts index 9fc4a693f5..496519e5f0 100644 --- a/packages/inference/src/providers/hf-inference.ts +++ b/packages/inference/src/providers/hf-inference.ts @@ -9,11 +9,66 @@ * and we will tag HF team members. * * Thanks! - */ import type { TextGenerationOutput } from "@huggingface/tasks"; + */ +import type { + AudioClassificationOutput, + AutomaticSpeechRecognitionOutput, + ChatCompletionOutput, + DocumentQuestionAnsweringOutput, + FeatureExtractionOutput, + FillMaskOutput, + ImageClassificationOutput, + ImageSegmentationOutput, + ImageToTextOutput, + ObjectDetectionOutput, + QuestionAnsweringOutput, + SentenceSimilarityOutput, + SummarizationOutput, + TableQuestionAnsweringOutput, + TextClassificationOutput, + TextGenerationOutput, + TokenClassificationOutput, + TranslationOutput, + VisualQuestionAnsweringOutput, + ZeroShotClassificationOutput, + ZeroShotImageClassificationOutput, +} from "@huggingface/tasks"; import { HF_ROUTER_URL } from "../config"; import { InferenceOutputError } from "../lib/InferenceOutputError"; -import type { BodyParams, InferenceTask, UrlParams } from "../types"; +import type { TabularClassificationOutput } from "../tasks/tabular/tabularClassification"; +import type { BodyParams, UrlParams } from "../types"; import { toArray } from "../utils/toArray"; +import type { + AudioClassificationTaskHelper, + AudioToAudioTaskHelper, + AutomaticSpeechRecognitionTaskHelper, + ConversationalTaskHelper, + DocumentQuestionAnsweringTaskHelper, + 
FeatureExtractionTaskHelper, + FillMaskTaskHelper, + ImageClassificationTaskHelper, + ImageSegmentationTaskHelper, + ImageToImageTaskHelper, + ImageToTextTaskHelper, + ObjectDetectionTaskHelper, + QuestionAnsweringTaskHelper, + SentenceSimilarityTaskHelper, + SummarizationTaskHelper, + TableQuestionAnsweringTaskHelper, + TabularClassificationTaskHelper, + TabularRegressionTaskHelper, + TextClassificationTaskHelper, + TextGenerationTaskHelper, + TextToAudioTaskHelper, + TextToImageTaskHelper, + TextToSpeechTaskHelper, + TokenClassificationTaskHelper, + TranslationTaskHelper, + VisualQuestionAnsweringTaskHelper, + ZeroShotClassificationTaskHelper, + ZeroShotImageClassificationTaskHelper, +} from "./providerHelper"; + import { TaskProviderHelper } from "./providerHelper"; interface Base64ImageGeneration { @@ -25,10 +80,15 @@ interface Base64ImageGeneration { interface OutputUrlImageGeneration { output: string[]; } +interface AudioToAudioOutput { + blob: string; + "content-type": string; + label: string; +} export class HFInferenceTask extends TaskProviderHelper { - constructor(task?: InferenceTask) { - super("hf-inference", `${HF_ROUTER_URL}/hf-inference`, task); + constructor() { + super("hf-inference", `${HF_ROUTER_URL}/hf-inference`); } preparePayload(params: BodyParams): Record { return params.args; @@ -48,22 +108,18 @@ export class HFInferenceTask extends TaskProviderHelper { return `models/${params.model}`; } - override getResponse(response: unknown): unknown { + override async getResponse(response: unknown): Promise { return response; } } -export class HFInferenceTextToImageTask extends HFInferenceTask { - constructor() { - super("text-to-image"); - } - +export class HFInferenceTextToImageTask extends HFInferenceTask implements TextToImageTaskHelper { override async getResponse( response: Base64ImageGeneration | OutputUrlImageGeneration, url?: string, - headers?: Record, + headers?: HeadersInit, outputType?: "url" | "blob" - ): Promise { + ): Promise { if 
(!response) { throw new InferenceOutputError("response is undefined"); } @@ -96,10 +152,7 @@ export class HFInferenceTextToImageTask extends HFInferenceTask { } } -export class HFInferenceConversationalTask extends HFInferenceTask { - constructor() { - super("conversational"); - } +export class HFInferenceConversationalTask extends HFInferenceTask implements ConversationalTaskHelper { override makeUrl(params: UrlParams): string { let url: string; if (params.model.startsWith("http://") || params.model.startsWith("https://")) { @@ -124,14 +177,14 @@ export class HFInferenceConversationalTask extends HFInferenceTask { model: params.model, }; } -} -export class HFInferenceTextGenerationTask extends HFInferenceTask { - constructor() { - super("text-generation"); + override async getResponse(response: ChatCompletionOutput): Promise { + return response; } +} - override getResponse(response: TextGenerationOutput | TextGenerationOutput[]): unknown { +export class HFInferenceTextGenerationTask extends HFInferenceTask implements TextGenerationTaskHelper { + override async getResponse(response: TextGenerationOutput | TextGenerationOutput[]): Promise { const res = toArray(response); if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) { return (res as TextGenerationOutput[])?.[0]; @@ -139,3 +192,373 @@ export class HFInferenceTextGenerationTask extends HFInferenceTask { throw new InferenceOutputError("Expected Array<{generated_text: string}>"); } } + +export class HFInferenceAudioClassificationTask extends HFInferenceTask implements AudioClassificationTaskHelper { + override async getResponse(response: unknown): Promise { + // Add type checking/validation for the 'unknown' input + if ( + Array.isArray(response) && + response.every( + (x): x is { label: string; score: number } => + typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number" + ) + ) { + // If validation passes, it's safe 
to return as AudioClassificationOutput + return response; + } + // If validation fails, throw an error + throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format"); + } +} + +export class HFInferenceAutomaticSpeechRecognitionTask + extends HFInferenceTask + implements AutomaticSpeechRecognitionTaskHelper +{ + override async getResponse(response: AutomaticSpeechRecognitionOutput): Promise { + return response; + } +} +export class HFInferenceAudioToAudioTask extends HFInferenceTask implements AudioToAudioTaskHelper { + override async getResponse(response: AudioToAudioOutput[]): Promise { + if (!Array.isArray(response)) { + throw new InferenceOutputError("Expected Array"); + } + if ( + !response.every((elem): elem is AudioToAudioOutput => { + return ( + typeof elem === "object" && + elem && + "label" in elem && + typeof elem.label === "string" && + "content-type" in elem && + typeof elem["content-type"] === "string" && + "blob" in elem && + typeof elem.blob === "string" + ); + }) + ) { + throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>"); + } + return response; + } +} +export class HFInferenceDocumentQuestionAnsweringTask + extends HFInferenceTask + implements DocumentQuestionAnsweringTaskHelper +{ + override async getResponse( + response: DocumentQuestionAnsweringOutput + ): Promise { + if ( + Array.isArray(response) && + response.every( + (elem) => + typeof elem === "object" && + !!elem && + typeof elem?.answer === "string" && + (typeof elem.end === "number" || typeof elem.end === "undefined") && + (typeof elem.score === "number" || typeof elem.score === "undefined") && + (typeof elem.start === "number" || typeof elem.start === "undefined") + ) + ) { + return response[0]; + } + throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>"); + } +} + +export class HFInferenceFeatureExtractionTask extends HFInferenceTask implements 
FeatureExtractionTaskHelper { + override async getResponse(response: FeatureExtractionOutput): Promise { + const isNumArrayRec = (arr: unknown[], maxDepth: number, curDepth = 0): boolean => { + if (curDepth > maxDepth) return false; + if (arr.every((x) => Array.isArray(x))) { + return arr.every((x) => isNumArrayRec(x as unknown[], maxDepth, curDepth + 1)); + } else { + return arr.every((x) => typeof x === "number"); + } + }; + if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) { + return response; + } + throw new InferenceOutputError("Expected Array"); + } +} + +export class HFInferenceImageClassificationTask extends HFInferenceTask implements ImageClassificationTaskHelper { + override async getResponse(response: ImageClassificationOutput): Promise { + if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) { + return response; + } + throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); + } +} + +export class HFInferenceImageSegmentationTask extends HFInferenceTask implements ImageSegmentationTaskHelper { + override async getResponse(response: ImageSegmentationOutput): Promise { + if ( + Array.isArray(response) && + response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number") + ) { + return response; + } + throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>"); + } +} + +export class HFInferenceImageToTextTask extends HFInferenceTask implements ImageToTextTaskHelper { + override async getResponse(response: ImageToTextOutput): Promise { + if (typeof response?.generated_text !== "string") { + throw new InferenceOutputError("Expected {generated_text: string}"); + } + return response; + } +} + +export class HFInferenceImageToImageTask extends HFInferenceTask implements ImageToImageTaskHelper { + override async getResponse(response: Blob): Promise { + if (response instanceof Blob) { + 
return response; + } + throw new InferenceOutputError("Expected Blob"); + } +} + +export class HFInferenceObjectDetectionTask extends HFInferenceTask implements ObjectDetectionTaskHelper { + override async getResponse(response: ObjectDetectionOutput): Promise { + if ( + Array.isArray(response) && + response.every( + (x) => + typeof x.label === "string" && + typeof x.score === "number" && + typeof x.box.xmin === "number" && + typeof x.box.ymin === "number" && + typeof x.box.xmax === "number" && + typeof x.box.ymax === "number" + ) + ) { + return response; + } + throw new InferenceOutputError( + "Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>" + ); + } +} + +export class HFInferenceZeroShotImageClassificationTask + extends HFInferenceTask + implements ZeroShotImageClassificationTaskHelper +{ + override async getResponse(response: ZeroShotImageClassificationOutput): Promise { + if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) { + return response; + } + throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); + } +} + +export class HFInferenceTextClassificationTask extends HFInferenceTask implements TextClassificationTaskHelper { + override async getResponse(response: TextClassificationOutput): Promise { + if (Array.isArray(response) && response.every((x) => typeof x?.label === "string" && typeof x.score === "number")) { + return response; + } + throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); + } +} + +export class HFInferenceQuestionAnsweringTask extends HFInferenceTask implements QuestionAnsweringTaskHelper { + override async getResponse( + response: QuestionAnsweringOutput | QuestionAnsweringOutput[number] + ): Promise { + if ( + Array.isArray(response) + ? 
response.every( + (elem) => + typeof elem === "object" && + !!elem && + typeof elem.answer === "string" && + typeof elem.end === "number" && + typeof elem.score === "number" && + typeof elem.start === "number" + ) + : typeof response === "object" && + !!response && + typeof response.answer === "string" && + typeof response.end === "number" && + typeof response.score === "number" && + typeof response.start === "number" + ) { + return Array.isArray(response) ? response[0] : response; + } + throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>"); + } +} + +export class HFInferenceFillMaskTask extends HFInferenceTask implements FillMaskTaskHelper { + override async getResponse(response: FillMaskOutput): Promise { + if ( + Array.isArray(response) && + response.every( + (x) => + typeof x.score === "number" && + typeof x.sequence === "string" && + typeof x.token === "number" && + typeof x.token_str === "string" + ) + ) { + return response; + } + throw new InferenceOutputError( + "Expected Array<{score: number, sequence: string, token: number, token_str: string}>" + ); + } +} + +export class HFInferenceZeroShotClassificationTask extends HFInferenceTask implements ZeroShotClassificationTaskHelper { + override async getResponse(response: ZeroShotClassificationOutput): Promise { + if ( + Array.isArray(response) && + response.every( + (x) => + Array.isArray(x.labels) && + x.labels.every((_label) => typeof _label === "string") && + Array.isArray(x.scores) && + x.scores.every((_score) => typeof _score === "number") && + typeof x.sequence === "string" + ) + ) { + return response; + } + throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>"); + } +} + +export class HFInferenceSentenceSimilarityTask extends HFInferenceTask implements SentenceSimilarityTaskHelper { + override async getResponse(response: SentenceSimilarityOutput): Promise { + if (Array.isArray(response) && 
response.every((x) => typeof x === "number")) { + return response; + } + throw new InferenceOutputError("Expected Array"); + } +} + +export class HFInferenceTableQuestionAnsweringTask extends HFInferenceTask implements TableQuestionAnsweringTaskHelper { + static validate(elem: unknown): elem is TableQuestionAnsweringOutput[number] { + return ( + typeof elem === "object" && + !!elem && + "aggregator" in elem && + typeof elem.aggregator === "string" && + "answer" in elem && + typeof elem.answer === "string" && + "cells" in elem && + Array.isArray(elem.cells) && + elem.cells.every((x: unknown): x is string => typeof x === "string") && + "coordinates" in elem && + Array.isArray(elem.coordinates) && + elem.coordinates.every( + (coord: unknown): coord is number[] => Array.isArray(coord) && coord.every((x) => typeof x === "number") + ) + ); + } + override async getResponse(response: TableQuestionAnsweringOutput): Promise { + if ( + Array.isArray(response) && Array.isArray(response) + ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) + : HFInferenceTableQuestionAnsweringTask.validate(response) + ) { + return Array.isArray(response) ? 
response[0] : response; + } + throw new InferenceOutputError( + "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}" + ); + } +} + +export class HFInferenceTokenClassificationTask extends HFInferenceTask implements TokenClassificationTaskHelper { + override async getResponse(response: TokenClassificationOutput): Promise { + if ( + Array.isArray(response) && + response.every( + (x) => + typeof x.end === "number" && + typeof x.entity_group === "string" && + typeof x.score === "number" && + typeof x.start === "number" && + typeof x.word === "string" + ) + ) { + return response; + } + throw new InferenceOutputError( + "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>" + ); + } +} + +export class HFInferenceTranslationTask extends HFInferenceTask implements TranslationTaskHelper { + override async getResponse(response: TranslationOutput): Promise { + if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) { + return response?.length === 1 ? 
response?.[0] : response; + } + throw new InferenceOutputError("Expected Array<{translation_text: string}>"); + } +} + +export class HFInferenceSummarizationTask extends HFInferenceTask implements SummarizationTaskHelper { + override async getResponse(response: SummarizationOutput): Promise { + if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) { + return response?.[0]; + } + throw new InferenceOutputError("Expected Array<{summary_text: string}>"); + } +} + +export class HFInferenceTextToSpeechTask extends HFInferenceTask implements TextToSpeechTaskHelper { + override async getResponse(response: Blob): Promise { + return response; + } +} + +export class HFInferenceTabularClassificationTask extends HFInferenceTask implements TabularClassificationTaskHelper { + override async getResponse(response: TabularClassificationOutput): Promise { + if (Array.isArray(response) && response.every((x) => typeof x === "number")) { + return response; + } + throw new InferenceOutputError("Expected Array"); + } +} + +export class HFInferenceVisualQuestionAnsweringTask + extends HFInferenceTask + implements VisualQuestionAnsweringTaskHelper +{ + override async getResponse(response: VisualQuestionAnsweringOutput): Promise { + if ( + Array.isArray(response) && + response.every( + (elem) => + typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number" + ) + ) { + return response[0]; + } + throw new InferenceOutputError("Expected Array<{answer: string, score: number}>"); + } +} + +export class HFInferenceTabularRegressionTask extends HFInferenceTask implements TabularRegressionTaskHelper { + override async getResponse(response: number[]): Promise { + if (Array.isArray(response) && response.every((x) => typeof x === "number")) { + return response; + } + throw new InferenceOutputError("Expected Array"); + } +} + +export class HFInferenceTextToAudioTask extends HFInferenceTask implements TextToAudioTaskHelper 
{ + override async getResponse(response: Blob): Promise { + return response; + } +} diff --git a/packages/inference/src/providers/hyperbolic.ts b/packages/inference/src/providers/hyperbolic.ts index 419f0d0b1e..56d7bc31d9 100644 --- a/packages/inference/src/providers/hyperbolic.ts +++ b/packages/inference/src/providers/hyperbolic.ts @@ -18,7 +18,12 @@ import type { ChatCompletionOutput, TextGenerationOutput } from "@huggingface/ta import { InferenceOutputError } from "../lib/InferenceOutputError"; import type { BodyParams, UrlParams } from "../types"; import { omit } from "../utils/omit"; -import { BaseConversationalTask, BaseTextGenerationTask, TaskProviderHelper } from "./providerHelper"; +import { + BaseConversationalTask, + BaseTextGenerationTask, + TaskProviderHelper, + type TextToImageTaskHelper, +} from "./providerHelper"; const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz"; @@ -43,8 +48,7 @@ export class HyperbolicTextGenerationTask extends BaseTextGenerationTask { super("hyperbolic", HYPERBOLIC_API_BASE_URL); } - override makeRoute(params: UrlParams): string { - void params; + override makeRoute(): string { return "v1/chat/completions"; } @@ -62,7 +66,7 @@ export class HyperbolicTextGenerationTask extends BaseTextGenerationTask { }; } - override getResponse(response: HyperbolicTextCompletionOutput): TextGenerationOutput { + override async getResponse(response: HyperbolicTextCompletionOutput): Promise { if ( typeof response === "object" && "choices" in response && @@ -79,9 +83,9 @@ export class HyperbolicTextGenerationTask extends BaseTextGenerationTask { } } -export class HyperbolicTextToImageTask extends TaskProviderHelper { +export class HyperbolicTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper { constructor() { - super("hyperbolic", HYPERBOLIC_API_BASE_URL, "text-to-image"); + super("hyperbolic", HYPERBOLIC_API_BASE_URL); } makeRoute(params: UrlParams): string { @@ -98,7 +102,12 @@ export class 
HyperbolicTextToImageTask extends TaskProviderHelper { }; } - getResponse(response: HyperbolicTextToImageOutput, outputType?: "url" | "blob"): Promise | string { + async getResponse( + response: HyperbolicTextToImageOutput, + url?: string, + headers?: HeadersInit, + outputType?: "url" | "blob" + ): Promise { if ( typeof response === "object" && "images" in response && diff --git a/packages/inference/src/providers/nebius.ts b/packages/inference/src/providers/nebius.ts index c3ffcd4900..c5b6eae3ea 100644 --- a/packages/inference/src/providers/nebius.ts +++ b/packages/inference/src/providers/nebius.ts @@ -17,7 +17,12 @@ import { InferenceOutputError } from "../lib/InferenceOutputError"; import type { BodyParams, UrlParams } from "../types"; import { omit } from "../utils/omit"; -import { BaseConversationalTask, BaseTextGenerationTask, TaskProviderHelper } from "./providerHelper"; +import { + BaseConversationalTask, + BaseTextGenerationTask, + TaskProviderHelper, + type TextToImageTaskHelper, +} from "./providerHelper"; const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai"; @@ -39,9 +44,9 @@ export class NebiusTextGenerationTask extends BaseTextGenerationTask { } } -export class NebiusTextToImageTask extends TaskProviderHelper { +export class NebiusTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper { constructor() { - super("nebius", NEBIUS_API_BASE_URL, "text-to-image"); + super("nebius", NEBIUS_API_BASE_URL); } preparePayload(params: BodyParams): Record { @@ -59,7 +64,12 @@ export class NebiusTextToImageTask extends TaskProviderHelper { return "v1/images/generations"; } - getResponse(response: NebiusBase64ImageGeneration, outputType?: "url" | "blob"): string | Promise { + async getResponse( + response: NebiusBase64ImageGeneration, + url?: string, + headers?: HeadersInit, + outputType?: "url" | "blob" + ): Promise { if ( typeof response === "object" && "data" in response && diff --git a/packages/inference/src/providers/novita.ts 
b/packages/inference/src/providers/novita.ts index 8cb024f89d..bc66e0936d 100644 --- a/packages/inference/src/providers/novita.ts +++ b/packages/inference/src/providers/novita.ts @@ -18,7 +18,12 @@ import { InferenceOutputError } from "../lib/InferenceOutputError"; import { isUrl } from "../lib/isUrl"; import type { BodyParams, UrlParams } from "../types"; import { omit } from "../utils/omit"; -import { BaseConversationalTask, BaseTextGenerationTask, TaskProviderHelper } from "./providerHelper"; +import { + BaseConversationalTask, + BaseTextGenerationTask, + TaskProviderHelper, + type TextToVideoTaskHelper, +} from "./providerHelper"; const NOVITA_API_BASE_URL = "https://api.novita.ai"; export interface NovitaOutput { @@ -31,8 +36,7 @@ export class NovitaTextGenerationTask extends BaseTextGenerationTask { super("novita", NOVITA_API_BASE_URL); } - override makeRoute(params: UrlParams): string { - void params; + override makeRoute(): string { return "/v3/openai/chat/completions"; } } @@ -42,21 +46,20 @@ export class NovitaConversationalTask extends BaseConversationalTask { super("novita", NOVITA_API_BASE_URL); } - override makeRoute(params: UrlParams): string { - void params; + override makeRoute(): string { return "/v3/openai/chat/completions"; } } -export class NovitaTextToVideoTask extends TaskProviderHelper { +export class NovitaTextToVideoTask extends TaskProviderHelper implements TextToVideoTaskHelper { constructor() { - super("novita", NOVITA_API_BASE_URL, "text-to-video"); + super("novita", NOVITA_API_BASE_URL); } makeRoute(params: UrlParams): string { return `/v3/hf/${params.model}`; } - preparePayload(params: BodyParams): unknown { + preparePayload(params: BodyParams): Record { return { ...omit(params.args, ["inputs", "parameters"]), ...(params.args.parameters as Record), diff --git a/packages/inference/src/providers/providerHelper.ts b/packages/inference/src/providers/providerHelper.ts index 94ccc62c99..a0da0f7c6f 100644 --- 
a/packages/inference/src/providers/providerHelper.ts +++ b/packages/inference/src/providers/providerHelper.ts @@ -1,7 +1,54 @@ -import type { ChatCompletionOutput, TextGenerationOutput } from "@huggingface/tasks"; +import type { + AudioClassificationInput, + AudioClassificationOutput, + AutomaticSpeechRecognitionInput, + AutomaticSpeechRecognitionOutput, + ChatCompletionInput, + ChatCompletionOutput, + DocumentQuestionAnsweringInput, + DocumentQuestionAnsweringOutput, + FeatureExtractionInput, + FeatureExtractionOutput, + FillMaskInput, + FillMaskOutput, + ImageClassificationInput, + ImageClassificationOutput, + ImageSegmentationInput, + ImageSegmentationOutput, + ImageToImageInput, + ImageToTextInput, + ImageToTextOutput, + ObjectDetectionInput, + ObjectDetectionOutput, + QuestionAnsweringInput, + QuestionAnsweringOutput, + SentenceSimilarityInput, + SentenceSimilarityOutput, + SummarizationInput, + SummarizationOutput, + TableQuestionAnsweringInput, + TableQuestionAnsweringOutput, + TextClassificationOutput, + TextGenerationInput, + TextGenerationOutput, + TextToImageInput, + TextToSpeechInput, + TextToVideoInput, + TokenClassificationInput, + TokenClassificationOutput, + TranslationInput, + TranslationOutput, + VisualQuestionAnsweringInput, + VisualQuestionAnsweringOutput, + ZeroShotClassificationInput, + ZeroShotClassificationOutput, + ZeroShotImageClassificationInput, + ZeroShotImageClassificationOutput, +} from "@huggingface/tasks"; import { HF_ROUTER_URL } from "../config"; import { InferenceOutputError } from "../lib/InferenceOutputError"; -import type { BodyParams, HeaderParams, UrlParams } from "../types"; +import type { AudioToAudioOutput } from "../tasks/audio/audioToAudio"; +import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, UrlParams } from "../types"; import { toArray } from "../utils/toArray"; /** @@ -9,9 +56,8 @@ import { toArray } from "../utils/toArray"; */ export abstract class TaskProviderHelper { constructor( - private 
provider: string, + private provider: InferenceProvider, private baseUrl: string, - private task?: string, readonly clientSideRoutingOnly: boolean = false ) {} @@ -22,9 +68,9 @@ export abstract class TaskProviderHelper { abstract getResponse( response: unknown, url?: string, - headers?: Record, + headers?: HeadersInit, outputType?: "url" | "blob" - ): unknown; + ): Promise; /** * Prepare the route for the request @@ -75,9 +121,172 @@ export abstract class TaskProviderHelper { } } -export class BaseConversationalTask extends TaskProviderHelper { - constructor(provider: string, baseUrl: string, clientSideRoutingOnly: boolean = false) { - super(provider, baseUrl, "conversational", clientSideRoutingOnly); +// PER-TASK PROVIDER HELPER INTERFACES + +// CV Tasks +export interface TextToImageTaskHelper { + getResponse( + response: unknown, + url?: string, + headers?: HeadersInit, + outputType?: "url" | "blob" + ): Promise; + preparePayload(params: BodyParams): Record; +} + +export interface TextToVideoTaskHelper { + getResponse(response: unknown, url?: string, headers?: Record): Promise; + preparePayload(params: BodyParams): Record; +} + +export interface ImageToImageTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record; +} + +export interface ImageSegmentationTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record | BodyInit; +} + +export interface ImageClassificationTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record | BodyInit; +} + +export interface ObjectDetectionTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record | BodyInit; +} + +export interface ImageToTextTaskHelper { + getResponse(response: unknown, url?: string, headers?: 
HeadersInit): Promise; + preparePayload(params: BodyParams): Record | BodyInit; +} + +export interface ZeroShotImageClassificationTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record | BodyInit; +} + +// NLP Tasks +export interface TextGenerationTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record; +} + +export interface ConversationalTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record; +} + +export interface TextClassificationTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record; +} + +export interface QuestionAnsweringTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record; +} + +export interface FillMaskTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record; +} + +export interface ZeroShotClassificationTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record; +} + +export interface SentenceSimilarityTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record; +} + +export interface TableQuestionAnsweringTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record; +} + +export interface TokenClassificationTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record; +} + +export interface TranslationTaskHelper { + getResponse(response: unknown, url?: string, 
headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record; +} + +export interface SummarizationTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record; +} + +// Audio Tasks +export interface TextToSpeechTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record; +} + +export interface TextToAudioTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams & BaseArgs>): Record; +} + +export interface AudioToAudioTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload( + params: BodyParams> + ): Record | BodyInit; +} +export interface AutomaticSpeechRecognitionTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record | BodyInit; +} + +export interface AudioClassificationTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record | BodyInit; +} + +// Multimodal Tasks +export interface DocumentQuestionAnsweringTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record | BodyInit; +} + +export interface FeatureExtractionTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record; +} + +export interface VisualQuestionAnsweringTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload(params: BodyParams): Record | BodyInit; +} + +export interface TabularClassificationTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload( + params: BodyParams } } & Record> + ): 
Record | BodyInit; +} + +export interface TabularRegressionTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise; + preparePayload( + params: BodyParams } } & Record> + ): Record | BodyInit; +} + +// BASE IMPLEMENTATIONS FOR COMMON PATTERNS + +export class BaseConversationalTask extends TaskProviderHelper implements ConversationalTaskHelper { + constructor(provider: InferenceProvider, baseUrl: string, clientSideRoutingOnly: boolean = false) { + super(provider, baseUrl, clientSideRoutingOnly); } makeRoute(): string { @@ -91,7 +300,7 @@ export class BaseConversationalTask extends TaskProviderHelper { }; } - getResponse(response: ChatCompletionOutput): ChatCompletionOutput { + async getResponse(response: ChatCompletionOutput): Promise { if ( typeof response === "object" && Array.isArray(response?.choices) && @@ -111,25 +320,33 @@ export class BaseConversationalTask extends TaskProviderHelper { } } -export class BaseTextGenerationTask extends TaskProviderHelper { - constructor(provider: string, baseUrl: string, clientSideRoutingOnly: boolean = false) { - super(provider, baseUrl, "text-generation", clientSideRoutingOnly); +export class BaseTextGenerationTask extends TaskProviderHelper implements TextGenerationTaskHelper { + constructor(provider: InferenceProvider, baseUrl: string, clientSideRoutingOnly: boolean = false) { + super(provider, baseUrl, clientSideRoutingOnly); } + preparePayload(params: BodyParams): Record { return { ...params.args, model: params.model, }; } + makeRoute(): string { return "v1/completions"; } - getResponse(response: unknown): TextGenerationOutput { + async getResponse(response: unknown): Promise { const res = toArray(response); - // @ts-expect-error - We need to check properties on unknown type - if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) { - return (res as TextGenerationOutput[])?.[0]; + if ( + Array.isArray(res) && + res.length > 0 && + 
res.every( + (x): x is { generated_text: string } => + typeof x === "object" && !!x && "generated_text" in x && typeof x.generated_text === "string" + ) + ) { + return res[0]; } throw new InferenceOutputError("Expected Array<{generated_text: string}>"); diff --git a/packages/inference/src/providers/replicate.ts b/packages/inference/src/providers/replicate.ts index cb01b0a37a..5a732dcf5e 100644 --- a/packages/inference/src/providers/replicate.ts +++ b/packages/inference/src/providers/replicate.ts @@ -16,17 +16,17 @@ */ import { InferenceOutputError } from "../lib/InferenceOutputError"; import { isUrl } from "../lib/isUrl"; -import type { BodyParams, HeaderParams, InferenceTask, UrlParams } from "../types"; +import type { BodyParams, HeaderParams, UrlParams } from "../types"; import { omit } from "../utils/omit"; -import { TaskProviderHelper } from "./providerHelper"; +import { TaskProviderHelper, type TextToImageTaskHelper, type TextToVideoTaskHelper } from "./providerHelper"; export interface ReplicateOutput { output?: string | string[]; } abstract class ReplicateTask extends TaskProviderHelper { - constructor(task: InferenceTask, url?: string) { - super("replicate", url || "https://api.replicate.com", task); + constructor(url?: string) { + super("replicate", url || "https://api.replicate.com"); } makeRoute(params: UrlParams): string { @@ -62,10 +62,7 @@ abstract class ReplicateTask extends TaskProviderHelper { } } -export class ReplicateTextToImageTask extends ReplicateTask { - constructor() { - super("text-to-image"); - } +export class ReplicateTextToImageTask extends ReplicateTask implements TextToImageTaskHelper { override async getResponse( res: ReplicateOutput | Blob, url?: string, @@ -93,10 +90,6 @@ export class ReplicateTextToImageTask extends ReplicateTask { } export class ReplicateTextToSpeechTask extends ReplicateTask { - constructor() { - super("text-to-speech"); - } - override preparePayload(params: BodyParams): Record { const payload = 
super.preparePayload(params); @@ -129,11 +122,7 @@ export class ReplicateTextToSpeechTask extends ReplicateTask { } } -export class ReplicateTextToVideoTask extends ReplicateTask { - constructor() { - super("text-to-video"); - } - +export class ReplicateTextToVideoTask extends ReplicateTask implements TextToVideoTaskHelper { override async getResponse(response: ReplicateOutput): Promise { if ( typeof response === "object" && diff --git a/packages/inference/src/providers/together.ts b/packages/inference/src/providers/together.ts index 2bb1565b5f..b475bd7d18 100644 --- a/packages/inference/src/providers/together.ts +++ b/packages/inference/src/providers/together.ts @@ -18,7 +18,12 @@ import type { ChatCompletionOutput, TextGenerationOutput, TextGenerationOutputFi import { InferenceOutputError } from "../lib/InferenceOutputError"; import type { BodyParams } from "../types"; import { omit } from "../utils/omit"; -import { BaseConversationalTask, BaseTextGenerationTask, TaskProviderHelper } from "./providerHelper"; +import { + BaseConversationalTask, + BaseTextGenerationTask, + TaskProviderHelper, + type TextToImageTaskHelper, +} from "./providerHelper"; const TOGETHER_API_BASE_URL = "https://api.together.xyz"; @@ -57,7 +62,7 @@ export class TogetherTextGenerationTask extends BaseTextGenerationTask { }; } - override getResponse(response: TogetherTextCompletionOutput): TextGenerationOutput { + override async getResponse(response: TogetherTextCompletionOutput): Promise { if ( typeof response === "object" && "choices" in response && @@ -73,9 +78,9 @@ export class TogetherTextGenerationTask extends BaseTextGenerationTask { } } -export class TogetherTextToImageTask extends TaskProviderHelper { +export class TogetherTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper { constructor() { - super("together", TOGETHER_API_BASE_URL, "text-to-image"); + super("together", TOGETHER_API_BASE_URL); } makeRoute(): string { @@ -92,7 +97,7 @@ export class 
TogetherTextToImageTask extends TaskProviderHelper { }; } - getResponse(response: TogetherBase64ImageGeneration, outputType?: "url" | "blob"): string | Promise { + async getResponse(response: TogetherBase64ImageGeneration, outputType?: "url" | "blob"): Promise { if ( typeof response === "object" && "data" in response && diff --git a/packages/inference/src/snippets/getInferenceSnippets.ts b/packages/inference/src/snippets/getInferenceSnippets.ts index 3a93d0264d..0d701a43ad 100644 --- a/packages/inference/src/snippets/getInferenceSnippets.ts +++ b/packages/inference/src/snippets/getInferenceSnippets.ts @@ -9,7 +9,7 @@ import { import type { PipelineType, WidgetType } from "@huggingface/tasks/src/pipelines.js"; import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingface/tasks/src/tasks/index.js"; import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions"; -import type { InferenceProvider, RequestArgs } from "../types"; +import type { InferenceProvider, InferenceTask, RequestArgs } from "../types"; import { templates } from "./templates.exported"; const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const; @@ -120,6 +120,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar opts?: Record ): InferenceSnippet[] => { /// Hacky: hard-code conversational templates here + let task = model.pipeline_tag as InferenceTask; if ( model.pipeline_tag && ["text-generation", "image-text-to-text"].includes(model.pipeline_tag) && @@ -127,14 +128,21 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar ) { templateName = opts?.streaming ? "conversationalStream" : "conversational"; inputPreparationFn = prepareConversationalInput; + task = "conversational"; } /// Prepare inputs + make request const inputs = inputPreparationFn ? 
inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) }; - const request = makeRequestOptionsFromResolvedModel(providerModelId ?? model.id, { - accessToken: accessToken, - provider: provider, - ...inputs, - } as RequestArgs); + const request = makeRequestOptionsFromResolvedModel( + providerModelId ?? model.id, + { + accessToken: accessToken, + provider: provider, + ...inputs, + } as RequestArgs, + { + task: task, + } + ); /// Parse request.info.body if not a binary. /// This is the body sent to the provider. Important for snippets with raw payload (e.g curl, requests, etc.) diff --git a/packages/inference/src/tasks/audio/audioClassification.ts b/packages/inference/src/tasks/audio/audioClassification.ts index b17c875051..1489c37112 100644 --- a/packages/inference/src/tasks/audio/audioClassification.ts +++ b/packages/inference/src/tasks/audio/audioClassification.ts @@ -1,5 +1,5 @@ import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; import type { LegacyAudioInput } from "./utils"; @@ -15,15 +15,12 @@ export async function audioClassification( args: AudioClassificationArgs, options?: Options ): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "audio-classification"); const payload = preparePayload(args); - const res = await request(payload, { + const res = await request(payload, { ...options, task: "audio-classification", }); - const isValidOutput = - Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); - } - return res; + + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/audio/audioToAudio.ts b/packages/inference/src/tasks/audio/audioToAudio.ts index 84e4e79641..5db31d3007 100644 --- a/packages/inference/src/tasks/audio/audioToAudio.ts +++ b/packages/inference/src/tasks/audio/audioToAudio.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; import type { LegacyAudioInput } from "./utils"; @@ -36,34 +36,11 @@ export interface AudioToAudioOutput { * Example model: speechbrain/sepformer-wham does audio source separation. */ export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "audio-to-audio"); const payload = preparePayload(args); const res = await request(payload, { ...options, task: "audio-to-audio", }); - - return validateOutput(res); -} - -function validateOutput(output: unknown): AudioToAudioOutput[] { - if (!Array.isArray(output)) { - throw new InferenceOutputError("Expected Array"); - } - if ( - !output.every((elem): elem is AudioToAudioOutput => { - return ( - typeof elem === "object" && - elem && - "label" in elem && - typeof elem.label === "string" && - "content-type" in elem && - typeof elem["content-type"] === "string" && - "blob" in elem && - typeof elem.blob === "string" - ); - }) - ) { - throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>"); - } - return output; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts b/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts index d12c8eac21..d0c14c56ba 100644 --- a/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts +++ b/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts @@ -27,7 +27,7 @@ export async function automaticSpeechRecognition( if (!isValidOutput) { throw new InferenceOutputError("Expected {text: string}"); } - return providerHelper.getResponse(res) as AutomaticSpeechRecognitionOutput; + return providerHelper.getResponse(res); } async function buildPayload(args: AutomaticSpeechRecognitionArgs): Promise { diff --git a/packages/inference/src/tasks/audio/textToSpeech.ts b/packages/inference/src/tasks/audio/textToSpeech.ts index 0a89746b62..905c1fdcc4 100644 --- a/packages/inference/src/tasks/audio/textToSpeech.ts +++ b/packages/inference/src/tasks/audio/textToSpeech.ts @@ -18,5 +18,5 @@ export async function textToSpeech(args: TextToSpeechArgs, options?: Options): P ...options, task: "text-to-speech", }); - return providerHelper.getResponse(res) as Promise; + return providerHelper.getResponse(res); } diff --git 
a/packages/inference/src/tasks/audio/utils.ts b/packages/inference/src/tasks/audio/utils.ts index b8dcd10501..0a84a76305 100644 --- a/packages/inference/src/tasks/audio/utils.ts +++ b/packages/inference/src/tasks/audio/utils.ts @@ -1,4 +1,4 @@ -import type { BaseArgs, RequestArgs } from "../../types"; +import type { BaseArgs, InferenceProvider, RequestArgs } from "../../types"; import { omit } from "../../utils/omit"; /** @@ -6,6 +6,7 @@ import { omit } from "../../utils/omit"; */ export interface LegacyAudioInput { data: Blob | ArrayBuffer; + provider?: InferenceProvider; } export function preparePayload(args: BaseArgs & ({ inputs: Blob } | LegacyAudioInput)): RequestArgs { diff --git a/packages/inference/src/tasks/cv/imageClassification.ts b/packages/inference/src/tasks/cv/imageClassification.ts index 4f7a6e6b04..b5f1b9d9f1 100644 --- a/packages/inference/src/tasks/cv/imageClassification.ts +++ b/packages/inference/src/tasks/cv/imageClassification.ts @@ -1,5 +1,5 @@ import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; import { preparePayload, type LegacyImageInput } from "./utils"; @@ -14,15 +14,11 @@ export async function imageClassification( args: ImageClassificationArgs, options?: Options ): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "image-classification"); const payload = preparePayload(args); const res = await request(payload, { ...options, task: "image-classification", }); - const isValidOutput = - Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); - } - return res; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/cv/imageSegmentation.ts b/packages/inference/src/tasks/cv/imageSegmentation.ts index abbc808bf4..dd0c55d545 100644 --- a/packages/inference/src/tasks/cv/imageSegmentation.ts +++ b/packages/inference/src/tasks/cv/imageSegmentation.ts @@ -1,5 +1,5 @@ import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; import { preparePayload, type LegacyImageInput } from "./utils"; @@ -14,16 +14,11 @@ export async function imageSegmentation( args: ImageSegmentationArgs, options?: Options ): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "image-segmentation"); const payload = preparePayload(args); const res = await request(payload, { ...options, task: "image-segmentation", }); - const isValidOutput = - Array.isArray(res) && - res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>"); - } - return res; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/cv/imageToImage.ts b/packages/inference/src/tasks/cv/imageToImage.ts index 41a098f797..b331bf4f99 100644 --- a/packages/inference/src/tasks/cv/imageToImage.ts +++ b/packages/inference/src/tasks/cv/imageToImage.ts @@ -1,5 +1,5 @@ import type { ImageToImageInput } from "@huggingface/tasks"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options, RequestArgs } from "../../types"; import { base64FromBytes } from "../../utils/base64FromBytes"; import { request } from "../custom/request"; @@ -11,6 +11,7 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput; * Recommended model: lllyasviel/sd-controlnet-depth */ export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "image-to-image"); let reqArgs: RequestArgs; if (!args.parameters) { reqArgs = { @@ -30,9 +31,5 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P ...options, task: "image-to-image", }); - const isValidOutput = res && res instanceof Blob; - if (!isValidOutput) { - throw new InferenceOutputError("Expected Blob"); - } - return res; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/cv/imageToText.ts b/packages/inference/src/tasks/cv/imageToText.ts index 52cefd1fc4..9dcbe84a71 100644 --- a/packages/inference/src/tasks/cv/imageToText.ts +++ b/packages/inference/src/tasks/cv/imageToText.ts @@ -1,5 +1,5 @@ import type { ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; import type { LegacyImageInput } from "./utils"; @@ -10,6 +10,7 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput); * This task reads some image input and outputs the text caption. */ export async function imageToText(args: ImageToTextArgs, options?: Options): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "image-to-text"); const payload = preparePayload(args); const res = ( await request<[ImageToTextOutput]>(payload, { @@ -18,9 +19,5 @@ export async function imageToText(args: ImageToTextArgs, options?: Options): Pro }) )?.[0]; - if (typeof res?.generated_text !== "string") { - throw new InferenceOutputError("Expected {generated_text: string}"); - } - - return res; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/cv/objectDetection.ts b/packages/inference/src/tasks/cv/objectDetection.ts index e66372e1d8..416ee4a6d4 100644 --- a/packages/inference/src/tasks/cv/objectDetection.ts +++ b/packages/inference/src/tasks/cv/objectDetection.ts @@ -1,7 +1,7 @@ -import { request } from "../custom/request"; -import type { BaseArgs, Options } from "../../types"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { ObjectDetectionInput, ObjectDetectionOutput } from "@huggingface/tasks"; +import { getProviderHelper } from "../../lib/getProviderHelper"; +import type { BaseArgs, Options } from "../../types"; +import { request } from "../custom/request"; import { preparePayload, type LegacyImageInput } from "./utils"; export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImageInput); @@ -11,26 +11,11 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImage * Recommended model: facebook/detr-resnet-50 */ export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "object-detection"); const payload = preparePayload(args); const res = await request(payload, { ...options, task: "object-detection", }); - const isValidOutput = - Array.isArray(res) && - res.every( - (x) => - typeof x.label === "string" && - typeof x.score === "number" && - typeof x.box.xmin === "number" && - typeof x.box.ymin === "number" && - typeof x.box.xmax === "number" && - typeof x.box.ymax === "number" - ); - if (!isValidOutput) { - throw new InferenceOutputError( - "Expected Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>" - ); - } - return res; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/cv/textToImage.ts b/packages/inference/src/tasks/cv/textToImage.ts index 686e7c46a4..b6338a3a8b 100644 --- a/packages/inference/src/tasks/cv/textToImage.ts +++ b/packages/inference/src/tasks/cv/textToImage.ts @@ -30,6 +30,5 @@ export async function textToImage(args: TextToImageArgs, options?: TextToImageOp task: "text-to-image", }); const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-image" }); - // @ts-expect-error - Provider-specific implementations accept the outputType parameter return providerHelper.getResponse(res, url, info.headers as Record, options?.outputType); } diff --git a/packages/inference/src/tasks/cv/textToVideo.ts b/packages/inference/src/tasks/cv/textToVideo.ts index c32d842e40..d99051af69 100644 --- a/packages/inference/src/tasks/cv/textToVideo.ts +++ b/packages/inference/src/tasks/cv/textToVideo.ts @@ -19,5 +19,5 @@ export async function textToVideo(args: TextToVideoArgs, options?: Options): Pro task: "text-to-video", }); const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-video" }); - return (await providerHelper.getResponse(response, url, info.headers as Record)) as TextToVideoOutput; + return providerHelper.getResponse(response, url, info.headers as Record); } diff --git 
a/packages/inference/src/tasks/cv/zeroShotImageClassification.ts b/packages/inference/src/tasks/cv/zeroShotImageClassification.ts index ca80c9c37f..6207ebe18c 100644 --- a/packages/inference/src/tasks/cv/zeroShotImageClassification.ts +++ b/packages/inference/src/tasks/cv/zeroShotImageClassification.ts @@ -1,9 +1,8 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; -import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; -import type { RequestArgs } from "../../types"; -import { base64FromBytes } from "../../utils/base64FromBytes"; import type { ZeroShotImageClassificationInput, ZeroShotImageClassificationOutput } from "@huggingface/tasks"; +import { getProviderHelper } from "../../lib/getProviderHelper"; +import type { BaseArgs, Options, RequestArgs } from "../../types"; +import { base64FromBytes } from "../../utils/base64FromBytes"; +import { request } from "../custom/request"; /** * @deprecated @@ -45,15 +44,11 @@ export async function zeroShotImageClassification( args: ZeroShotImageClassificationArgs, options?: Options ): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "zero-shot-image-classification"); const payload = await preparePayload(args); const res = await request(payload, { ...options, task: "zero-shot-image-classification", }); - const isValidOutput = - Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); - } - return res; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/index.ts b/packages/inference/src/tasks/index.ts index 91de1d940f..6d7b954e02 100644 --- a/packages/inference/src/tasks/index.ts +++ b/packages/inference/src/tasks/index.ts @@ -4,21 +4,23 @@ export * from "./custom/streamingRequest"; // Audio tasks export * from "./audio/audioClassification"; +export * from "./audio/audioToAudio"; export * from "./audio/automaticSpeechRecognition"; export * from "./audio/textToSpeech"; -export * from "./audio/audioToAudio"; // Computer Vision tasks export * from "./cv/imageClassification"; export * from "./cv/imageSegmentation"; +export * from "./cv/imageToImage"; export * from "./cv/imageToText"; export * from "./cv/objectDetection"; export * from "./cv/textToImage"; -export * from "./cv/imageToImage"; -export * from "./cv/zeroShotImageClassification"; export * from "./cv/textToVideo"; +export * from "./cv/zeroShotImageClassification"; // Natural Language Processing tasks +export * from "./nlp/chatCompletion"; +export * from "./nlp/chatCompletionStream"; export * from "./nlp/featureExtraction"; export * from "./nlp/fillMask"; export * from "./nlp/questionAnswering"; @@ -31,13 +33,11 @@ export * from "./nlp/textGenerationStream"; export * from "./nlp/tokenClassification"; export * from "./nlp/translation"; export * from "./nlp/zeroShotClassification"; -export * from "./nlp/chatCompletion"; -export * from "./nlp/chatCompletionStream"; // Multimodal tasks export * from "./multimodal/documentQuestionAnswering"; export 
* from "./multimodal/visualQuestionAnswering"; // Tabular tasks -export * from "./tabular/tabularRegression"; export * from "./tabular/tabularClassification"; +export * from "./tabular/tabularRegression"; diff --git a/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts b/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts index b0fe1af3fa..fa6b573afc 100644 --- a/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts +++ b/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts @@ -1,14 +1,13 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; -import type { BaseArgs, Options } from "../../types"; -import { request } from "../custom/request"; -import type { RequestArgs } from "../../types"; -import { toArray } from "../../utils/toArray"; -import { base64FromBytes } from "../../utils/base64FromBytes"; import type { DocumentQuestionAnsweringInput, DocumentQuestionAnsweringInputData, DocumentQuestionAnsweringOutput, } from "@huggingface/tasks"; +import { getProviderHelper } from "../../lib/getProviderHelper"; +import type { BaseArgs, Options, RequestArgs } from "../../types"; +import { base64FromBytes } from "../../utils/base64FromBytes"; +import { toArray } from "../../utils/toArray"; +import { request } from "../custom/request"; /// Override the type to properly set inputs.image as Blob export type DocumentQuestionAnsweringArgs = BaseArgs & @@ -21,6 +20,7 @@ export async function documentQuestionAnswering( args: DocumentQuestionAnsweringArgs, options?: Options ): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "document-question-answering"); const reqArgs: RequestArgs = { ...args, inputs: { @@ -35,21 +35,5 @@ export async function documentQuestionAnswering( task: "document-question-answering", }) ); - - const isValidOutput = - Array.isArray(res) && - res.every( - (elem) => - typeof elem === "object" && - !!elem && - typeof elem?.answer === "string" && - (typeof elem.end === "number" || typeof elem.end === "undefined") && - (typeof elem.score === "number" || typeof elem.score === "undefined") && - (typeof elem.start === "number" || typeof elem.start === "undefined") - ); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>"); - } - - return res[0]; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts b/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts index 536fab8340..6fc85c0c3f 100644 --- a/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts +++ b/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts @@ -3,7 +3,7 @@ import type { VisualQuestionAnsweringInputData, VisualQuestionAnsweringOutput, } from "@huggingface/tasks"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options, RequestArgs } from "../../types"; import { base64FromBytes } from "../../utils/base64FromBytes"; import { request } from "../custom/request"; @@ -19,6 +19,7 @@ export async function visualQuestionAnswering( args: VisualQuestionAnsweringArgs, options?: Options ): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "visual-question-answering"); const reqArgs: RequestArgs = { ...args, inputs: { @@ -31,13 +32,5 @@ export async function visualQuestionAnswering( ...options, task: "visual-question-answering", }); - const isValidOutput = - Array.isArray(res) && - res.every( - (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number" - ); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{answer: string, score: number}>"); - } - return res[0]; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/nlp/chatCompletion.ts b/packages/inference/src/tasks/nlp/chatCompletion.ts index 40ee3b29a1..ee7dd34c81 100644 --- a/packages/inference/src/tasks/nlp/chatCompletion.ts +++ b/packages/inference/src/tasks/nlp/chatCompletion.ts @@ -15,5 +15,5 @@ export async function chatCompletion( ...options, task: "conversational", }); - return providerHelper.getResponse(response) as ChatCompletionOutput; + return providerHelper.getResponse(response); } diff --git a/packages/inference/src/tasks/nlp/featureExtraction.ts b/packages/inference/src/tasks/nlp/featureExtraction.ts index 25a6695a2c..fffa81a375 100644 --- a/packages/inference/src/tasks/nlp/featureExtraction.ts +++ b/packages/inference/src/tasks/nlp/featureExtraction.ts @@ -1,5 +1,5 @@ import type { FeatureExtractionInput } from "@huggingface/tasks"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -17,25 +17,10 @@ export async function featureExtraction( args: FeatureExtractionArgs, options?: Options ): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "feature-extraction"); const res = await request(args, { ...options, task: "feature-extraction", }); - let isValidOutput = true; - - const isNumArrayRec = (arr: unknown[], maxDepth: number, curDepth = 0): boolean => { - if (curDepth > maxDepth) return false; - if (arr.every((x) => Array.isArray(x))) { - return arr.every((x) => isNumArrayRec(x as unknown[], maxDepth, curDepth + 1)); - } else { - return arr.every((x) => typeof x === "number"); - } - }; - - isValidOutput = Array.isArray(res) && isNumArrayRec(res, 3, 0); - - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array"); - } - return res; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/nlp/fillMask.ts b/packages/inference/src/tasks/nlp/fillMask.ts index 9a30b056e3..21674cb6f8 100644 --- a/packages/inference/src/tasks/nlp/fillMask.ts +++ b/packages/inference/src/tasks/nlp/fillMask.ts @@ -1,5 +1,5 @@ import type { FillMaskInput, FillMaskOutput } from "@huggingface/tasks"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -9,23 +9,10 @@ export type FillMaskArgs = BaseArgs & FillMaskInput; * Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models. */ export async function fillMask(args: FillMaskArgs, options?: Options): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "fill-mask"); const res = await request(args, { ...options, task: "fill-mask", }); - const isValidOutput = - Array.isArray(res) && - res.every( - (x) => - typeof x.score === "number" && - typeof x.sequence === "string" && - typeof x.token === "number" && - typeof x.token_str === "string" - ); - if (!isValidOutput) { - throw new InferenceOutputError( - "Expected Array<{score: number, sequence: string, token: number, token_str: string}>" - ); - } - return res; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/nlp/questionAnswering.ts b/packages/inference/src/tasks/nlp/questionAnswering.ts index 4141c193a2..a9b1506cd8 100644 --- a/packages/inference/src/tasks/nlp/questionAnswering.ts +++ b/packages/inference/src/tasks/nlp/questionAnswering.ts @@ -1,5 +1,5 @@ import type { QuestionAnsweringInput, QuestionAnsweringOutput } from "@huggingface/tasks"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -12,28 +12,10 @@ export async function questionAnswering( args: QuestionAnsweringArgs, options?: Options ): Promise { + const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering"); const res = await request(args, { ...options, task: "question-answering", }); - const isValidOutput = Array.isArray(res) - ? 
res.every( - (elem) => - typeof elem === "object" && - !!elem && - typeof elem.answer === "string" && - typeof elem.end === "number" && - typeof elem.score === "number" && - typeof elem.start === "number" - ) - : typeof res === "object" && - !!res && - typeof res.answer === "string" && - typeof res.end === "number" && - typeof res.score === "number" && - typeof res.start === "number"; - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>"); - } - return Array.isArray(res) ? res[0] : res; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/nlp/sentenceSimilarity.ts b/packages/inference/src/tasks/nlp/sentenceSimilarity.ts index a2d365b4fb..4fa45a7b52 100644 --- a/packages/inference/src/tasks/nlp/sentenceSimilarity.ts +++ b/packages/inference/src/tasks/nlp/sentenceSimilarity.ts @@ -1,5 +1,5 @@ import type { SentenceSimilarityInput, SentenceSimilarityOutput } from "@huggingface/tasks"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -12,14 +12,10 @@ export async function sentenceSimilarity( args: SentenceSimilarityArgs, options?: Options ): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "sentence-similarity"); const res = await request(args, { ...options, task: "sentence-similarity", }); - - const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected number[]"); - } - return res; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/nlp/summarization.ts b/packages/inference/src/tasks/nlp/summarization.ts index bf7439ba73..710916ddcf 100644 --- a/packages/inference/src/tasks/nlp/summarization.ts +++ b/packages/inference/src/tasks/nlp/summarization.ts @@ -1,5 +1,5 @@ import type { SummarizationInput, SummarizationOutput } from "@huggingface/tasks"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -9,13 +9,10 @@ export type SummarizationArgs = BaseArgs & SummarizationInput; * This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model. */ export async function summarization(args: SummarizationArgs, options?: Options): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "summarization"); const res = await request(args, { ...options, task: "summarization", }); - const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{summary_text: string}>"); - } - return res?.[0]; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts b/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts index 0ac08c5897..0febda6fde 100644 --- a/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts +++ b/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts @@ -1,5 +1,5 @@ import type { TableQuestionAnsweringInput, TableQuestionAnsweringOutput } from "@huggingface/tasks"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -12,34 +12,10 @@ export async function tableQuestionAnswering( args: TableQuestionAnsweringArgs, options?: Options ): Promise { + const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering"); const res = await request(args, { ...options, task: "table-question-answering", }); - const isValidOutput = Array.isArray(res) ? res.every((elem) => validate(elem)) : validate(res); - if (!isValidOutput) { - throw new InferenceOutputError( - "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}" - ); - } - return Array.isArray(res) ? 
res[0] : res; -} - -function validate(elem: unknown): elem is TableQuestionAnsweringOutput[number] { - return ( - typeof elem === "object" && - !!elem && - "aggregator" in elem && - typeof elem.aggregator === "string" && - "answer" in elem && - typeof elem.answer === "string" && - "cells" in elem && - Array.isArray(elem.cells) && - elem.cells.every((x: unknown): x is string => typeof x === "string") && - "coordinates" in elem && - Array.isArray(elem.coordinates) && - elem.coordinates.every( - (coord: unknown): coord is number[] => Array.isArray(coord) && coord.every((x) => typeof x === "number") - ) - ); + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/nlp/textClassification.ts b/packages/inference/src/tasks/nlp/textClassification.ts index 7c99ddeece..995825e405 100644 --- a/packages/inference/src/tasks/nlp/textClassification.ts +++ b/packages/inference/src/tasks/nlp/textClassification.ts @@ -1,5 +1,5 @@ import type { TextClassificationInput, TextClassificationOutput } from "@huggingface/tasks"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -12,16 +12,12 @@ export async function textClassification( args: TextClassificationArgs, options?: Options ): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "text-classification"); const res = ( await request(args, { ...options, task: "text-classification", }) )?.[0]; - const isValidOutput = - Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); - } - return res; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/nlp/textGeneration.ts b/packages/inference/src/tasks/nlp/textGeneration.ts index b8541c63f0..7ede69e646 100644 --- a/packages/inference/src/tasks/nlp/textGeneration.ts +++ b/packages/inference/src/tasks/nlp/textGeneration.ts @@ -19,5 +19,5 @@ export async function textGeneration( ...options, task: "text-generation", }); - return providerHelper.getResponse(response) as TextGenerationOutput; + return providerHelper.getResponse(response); } diff --git a/packages/inference/src/tasks/nlp/tokenClassification.ts b/packages/inference/src/tasks/nlp/tokenClassification.ts index 46d53ffcbd..b97690268d 100644 --- a/packages/inference/src/tasks/nlp/tokenClassification.ts +++ b/packages/inference/src/tasks/nlp/tokenClassification.ts @@ -1,5 +1,5 @@ import type { TokenClassificationInput, TokenClassificationOutput } from "@huggingface/tasks"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { toArray } from "../../utils/toArray"; import { request } from "../custom/request"; @@ -13,26 +13,12 @@ export async function tokenClassification( args: TokenClassificationArgs, options?: Options ): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "token-classification"); const res = toArray( await request(args, { ...options, task: "token-classification", }) ); - const isValidOutput = - Array.isArray(res) && - res.every( - (x) => - typeof x.end === "number" && - typeof x.entity_group === "string" && - typeof x.score === "number" && - typeof x.start === "number" && - typeof x.word === "string" - ); - if (!isValidOutput) { - throw new InferenceOutputError( - "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>" - ); - } - return res; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/nlp/translation.ts b/packages/inference/src/tasks/nlp/translation.ts index a05b228eaa..a47e46fe0b 100644 --- a/packages/inference/src/tasks/nlp/translation.ts +++ b/packages/inference/src/tasks/nlp/translation.ts @@ -1,5 +1,5 @@ import type { TranslationInput, TranslationOutput } from "@huggingface/tasks"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -8,13 +8,10 @@ export type TranslationArgs = BaseArgs & TranslationInput; * This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en. */ export async function translation(args: TranslationArgs, options?: Options): Promise { + const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation"); const res = await request(args, { ...options, task: "translation", }); - const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected type Array<{translation_text: string}>"); - } - return res?.length === 1 ? 
res?.[0] : res; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/nlp/zeroShotClassification.ts b/packages/inference/src/tasks/nlp/zeroShotClassification.ts index 769315ac22..c3d015f0df 100644 --- a/packages/inference/src/tasks/nlp/zeroShotClassification.ts +++ b/packages/inference/src/tasks/nlp/zeroShotClassification.ts @@ -1,5 +1,5 @@ import type { ZeroShotClassificationInput, ZeroShotClassificationOutput } from "@huggingface/tasks"; -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { toArray } from "../../utils/toArray"; import { request } from "../custom/request"; @@ -13,24 +13,12 @@ export async function zeroShotClassification( args: ZeroShotClassificationArgs, options?: Options ): Promise { + const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification"); const res = toArray( await request(args, { ...options, task: "zero-shot-classification", }) ); - const isValidOutput = - Array.isArray(res) && - res.every( - (x) => - Array.isArray(x.labels) && - x.labels.every((_label) => typeof _label === "string") && - Array.isArray(x.scores) && - x.scores.every((_score) => typeof _score === "number") && - typeof x.sequence === "string" - ); - if (!isValidOutput) { - throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>"); - } - return res; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/tabular/tabularClassification.ts b/packages/inference/src/tasks/tabular/tabularClassification.ts index aa5f0c72d2..d8182249a1 100644 --- a/packages/inference/src/tasks/tabular/tabularClassification.ts +++ b/packages/inference/src/tasks/tabular/tabularClassification.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } 
from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -25,13 +25,10 @@ export async function tabularClassification( args: TabularClassificationArgs, options?: Options ): Promise { + const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification"); const res = await request(args, { ...options, task: "tabular-classification", }); - const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected number[]"); - } - return res; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/tabular/tabularRegression.ts b/packages/inference/src/tasks/tabular/tabularRegression.ts index 102f1dc3c8..84754b72ed 100644 --- a/packages/inference/src/tasks/tabular/tabularRegression.ts +++ b/packages/inference/src/tasks/tabular/tabularRegression.ts @@ -1,4 +1,4 @@ -import { InferenceOutputError } from "../../lib/InferenceOutputError"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { request } from "../custom/request"; @@ -25,13 +25,10 @@ export async function tabularRegression( args: TabularRegressionArgs, options?: Options ): Promise { + const providerHelper = getProviderHelper(args.provider ?? 
"hf-inference", "tabular-regression"); const res = await request(args, { ...options, task: "tabular-regression", }); - const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number"); - if (!isValidOutput) { - throw new InferenceOutputError("Expected number[]"); - } - return res; + return providerHelper.getResponse(res); } diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts index d36f34007f..390b55fc1b 100644 --- a/packages/inference/src/types.ts +++ b/packages/inference/src/types.ts @@ -114,8 +114,8 @@ export interface UrlParams { task?: InferenceTask; } -export interface BodyParams { - args: Record<string, unknown>; +export interface BodyParams<T extends Record<string, unknown> = Record<string, unknown>> { + args: T; model: string; task?: InferenceTask; }