
Commit 42e0636

benank authored and committed, with co-authors SBrandeis and hanouticelina.
Add Groq as an inference provider (#1352)
This PR adds Groq as a fast inference provider for conversational and text-generation tasks.

Co-authored-by: SBrandeis <[email protected]>
Co-authored-by: Celina Hanouti <[email protected]>
1 parent f75d9db · commit 42e0636

File tree

6 files changed: +100 −0 lines changed

- packages/inference/README.md
- packages/inference/src/lib/getProviderHelper.ts
- packages/inference/src/providers/consts.ts
- packages/inference/src/providers/groq.ts
- packages/inference/src/types.ts
- packages/inference/test/InferenceClient.spec.ts

packages/inference/README.md

Lines changed: 2 additions & 0 deletions
@@ -59,6 +59,7 @@ Currently, we support the following providers:
 - [Blackforestlabs](https://blackforestlabs.ai)
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
+- [Groq](https://groq.com)
 
 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
 
 ```ts
@@ -86,6 +87,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Together supported models](https://huggingface.co/api/partners/together/models)
 - [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
+- [Groq supported models](https://console.groq.com/docs/models)
 - [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
 
 **Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
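
As the README notes, routing to Groq just means passing `provider: "groq"` to the inference call. A minimal sketch using this package's `InferenceClient`, mirroring the test file at the end of this commit (reading the token from `HF_TOKEN` is an assumption):

```ts
import { InferenceClient } from "@huggingface/inference";

// Token source is an assumption; any valid HF access token works here.
const client = new InferenceClient(process.env.HF_TOKEN);

const out = await client.chatCompletion({
	model: "meta-llama/Llama-3.3-70B-Instruct",
	provider: "groq",
	messages: [{ role: "user", content: "One plus one is equal to" }],
});

console.log(out.choices[0]?.message?.content);
```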

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 5 additions & 0 deletions
@@ -3,6 +3,7 @@ import * as Cerebras from "../providers/cerebras";
 import * as Cohere from "../providers/cohere";
 import * as FalAI from "../providers/fal-ai";
 import * as Fireworks from "../providers/fireworks-ai";
+import * as Groq from "../providers/groq";
 import * as HFInference from "../providers/hf-inference";
 
 import * as Hyperbolic from "../providers/hyperbolic";
@@ -96,6 +97,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
 	"fireworks-ai": {
 		conversational: new Fireworks.FireworksConversationalTask(),
 	},
+	groq: {
+		conversational: new Groq.GroqConversationalTask(),
+		"text-generation": new Groq.GroqTextGenerationTask(),
+	},
 	hyperbolic: {
 		"text-to-image": new Hyperbolic.HyperbolicTextToImageTask(),
 		conversational: new Hyperbolic.HyperbolicConversationalTask(),
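
For orientation, `PROVIDERS` maps provider → task → helper instance. A hypothetical lookup sketch (the function name and error handling here are illustrative, not this file's actual `getProviderHelper` implementation):

```ts
// Illustrative only; the real getProviderHelper in this file may differ.
function resolveHelper(provider: InferenceProvider, task: InferenceTask) {
	const helper = PROVIDERS[provider]?.[task];
	if (!helper) {
		throw new Error(`Provider '${provider}' has no helper for task '${task}'`);
	}
	return helper;
}
```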

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
 	cohere: {},
 	"fal-ai": {},
 	"fireworks-ai": {},
+	groq: {},
 	"hf-inference": {},
 	hyperbolic: {},
 	nebius: {},
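
The empty `groq: {}` entry is the hook for local development: before a model is registered upstream, a mapping can be added here by hand. A sketch of such an entry, copied from the shape used in the test file below:

```ts
// Dev-only sketch; this exact shape appears in InferenceClient.spec.ts in this commit.
HARDCODED_MODEL_INFERENCE_MAPPING["groq"] = {
	"meta-llama/Llama-3.3-70B-Instruct": {
		hfModelId: "meta-llama/Llama-3.3-70B-Instruct",
		providerId: "llama-3.3-70b-versatile",
		status: "live",
		task: "conversational",
	},
};
```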
packages/inference/src/providers/groq.ts

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";
+
+/**
+ * See the registered mapping of HF model ID => Groq model ID here:
+ *
+ * https://huggingface.co/api/partners/groq/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at Groq and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to Groq, please open an issue on the present repo
+ *   and we will tag Groq team members.
+ *
+ * Thanks!
+ */
+
+const GROQ_API_BASE_URL = "https://api.groq.com";
+
+export class GroqTextGenerationTask extends BaseTextGenerationTask {
+	constructor() {
+		super("groq", GROQ_API_BASE_URL);
+	}
+
+	override makeRoute(): string {
+		return "/openai/v1/chat/completions";
+	}
+}
+
+export class GroqConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("groq", GROQ_API_BASE_URL);
+	}
+
+	override makeRoute(): string {
+		return "/openai/v1/chat/completions";
+	}
+}
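
Both task classes override `makeRoute` to target Groq's OpenAI-compatible chat completions endpoint. Joined with the base URL, requests would resolve roughly as below (the composition shown is illustrative; the base class's actual URL building may differ):

```ts
// Illustrative only; shows the endpoint the route override points at.
const url = new URL("/openai/v1/chat/completions", "https://api.groq.com").toString();
console.log(url); // https://api.groq.com/openai/v1/chat/completions
```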

packages/inference/src/types.ts

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ export const INFERENCE_PROVIDERS = [
 	"cohere",
 	"fal-ai",
 	"fireworks-ai",
+	"groq",
 	"hf-inference",
 	"hyperbolic",
 	"nebius",

packages/inference/test/InferenceClient.spec.ts

Lines changed: 51 additions & 0 deletions
@@ -1751,4 +1751,55 @@ describe.skip("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
+	describe.concurrent(
+		"Groq",
+		() => {
+			const client = new InferenceClient(env.HF_GROQ_KEY ?? "dummy");
+
+			HARDCODED_MODEL_INFERENCE_MAPPING["groq"] = {
+				"meta-llama/Llama-3.3-70B-Instruct": {
+					hfModelId: "meta-llama/Llama-3.3-70B-Instruct",
+					providerId: "llama-3.3-70b-versatile",
+					status: "live",
+					task: "conversational",
+				},
+			};
+
+			it("chatCompletion", async () => {
+				const res = await client.chatCompletion({
+					model: "meta-llama/Llama-3.3-70B-Instruct",
+					provider: "groq",
+					messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+				});
+				if (res.choices && res.choices.length > 0) {
+					const completion = res.choices[0].message?.content;
+					expect(completion).toContain("two");
+				}
+			});
+
+			it("chatCompletion stream", async () => {
+				const stream = client.chatCompletionStream({
+					model: "meta-llama/Llama-3.3-70B-Instruct",
+					provider: "groq",
+					messages: [{ role: "user", content: "Say 'this is a test'" }],
+					stream: true,
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+
+				let fullResponse = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						const content = chunk.choices[0].delta?.content;
+						if (content) {
+							fullResponse += content;
+						}
+					}
+				}
+
+				// Verify we got a meaningful response
+				expect(fullResponse).toBeTruthy();
+				expect(fullResponse.length).toBeGreaterThan(0);
+			});
+		},
+		TIMEOUT
+	);
 });
