From f1c729dfb3b6c2e9264f9fe9a1d099eaae78c0cd Mon Sep 17 00:00:00 2001 From: Olivier Toromanoff Date: Wed, 5 Feb 2025 14:49:17 +0000 Subject: [PATCH] text-generation for replicate --- packages/inference/src/providers/replicate.ts | 5 +++ .../inference/src/tasks/nlp/textGeneration.ts | 36 +++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/packages/inference/src/providers/replicate.ts b/packages/inference/src/providers/replicate.ts index cd55c2c3e..62f27402a 100644 --- a/packages/inference/src/providers/replicate.ts +++ b/packages/inference/src/providers/replicate.ts @@ -5,6 +5,11 @@ export const REPLICATE_API_BASE_URL = "https://api.replicate.com"; type ReplicateId = string; export const REPLICATE_SUPPORTED_MODEL_IDS: ProviderMapping = { + "text-generation": { + "deepseek-ai/DeepSeek-R1": "deepseek-ai/deepseek-r1", + "meta/meta-llama-3-70b": "meta/meta-llama-3-70b", + "meta/meta-llama-3-8b": "meta/meta-llama-3-8b", + }, "text-to-image": { "black-forest-labs/FLUX.1-dev": "black-forest-labs/flux-dev", "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell", diff --git a/packages/inference/src/tasks/nlp/textGeneration.ts b/packages/inference/src/tasks/nlp/textGeneration.ts index 7d5906f12..468ead0f1 100644 --- a/packages/inference/src/tasks/nlp/textGeneration.ts +++ b/packages/inference/src/tasks/nlp/textGeneration.ts @@ -6,11 +6,17 @@ import type { } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; +import { omit } from "../../utils/omit"; import { toArray } from "../../utils/toArray"; import { request } from "../custom/request"; export type { TextGenerationInput, TextGenerationOutput }; +interface ReplicateTextCompletionOutput { + status: string; + output?: string[]; +} + interface TogeteherTextCompletionOutput extends Omit { choices: Array<{ text: string; @@ -43,6 +49,36 @@ export async function textGeneration( return { generated_text: completion.text, }; + } else if (args.provider === "replicate") { + const payload = { + ...omit(args, ["inputs", "parameters"]), + ...args.parameters, + prompt: args.inputs, + }; + + const raw = await request(payload, { + ...options, + taskHint: "text-generation", + }); + + if (typeof raw !== "object" || !("status" in raw)) { + throw new InferenceOutputError("Incomplete response"); + } + + const status = raw.status; + if (status === "starting") { + throw new InferenceOutputError("Replicate server-side time out"); + } + + if (!("output" in raw && Array.isArray(raw?.output))) { + throw new InferenceOutputError("Invalid response: no output"); + } + + const joined_output = raw.output.join(""); + + return { + generated_text: joined_output, + }; } else { const res = toArray( await request(args, {