Skip to content

Commit ad2eafc

Browse files
authored
feat: allow passing context.call settings when defining agent (#90)
1 parent f868cf1 commit ad2eafc

File tree

7 files changed

+72
-62
lines changed

7 files changed

+72
-62
lines changed

src/agents/adapters.ts

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
* this file contains adapters which convert tools and models
44
* to workflow tools and models.
55
*/
6-
import { createOpenAI } from "@ai-sdk/openai";
76
import { HTTPMethods } from "@upstash/qstash";
87
import { WorkflowContext } from "../context";
98
import { tool } from "ai";
10-
import { AISDKTool, LangchainTool, ProviderFunction } from "./types";
9+
import { AgentCallParams, AISDKTool, LangchainTool, ProviderFunction } from "./types";
1110
import { AGENT_NAME_HEADER } from "./constants";
1211
import { z, ZodType } from "zod";
1312

1413
export const fetchWithContextCall = async (
1514
context: WorkflowContext,
15+
agentCallParams?: AgentCallParams,
1616
...params: Parameters<typeof fetch>
1717
) => {
1818
const [input, init] = params;
@@ -33,6 +33,9 @@ export const fetchWithContextCall = async (
3333
method: init?.method as HTTPMethods,
3434
headers,
3535
body,
36+
timeout: agentCallParams?.timeout,
37+
retries: agentCallParams?.retries,
38+
flowControl: agentCallParams?.flowControl,
3639
});
3740

3841
// Construct headers for the response
@@ -61,37 +64,19 @@ export const fetchWithContextCall = async (
6164
}
6265
};
6366

64-
/**
65-
* creates an AI SDK openai client with a custom
66-
* fetch implementation which uses context.call.
67-
*
68-
* @param context workflow context
69-
* @returns ai sdk openai
70-
*/
71-
export const createWorkflowOpenAI = (
72-
context: WorkflowContext,
73-
config?: { baseURL?: string; apiKey?: string }
74-
) => {
75-
const { baseURL, apiKey } = config ?? {};
76-
return createOpenAI({
77-
baseURL,
78-
apiKey,
79-
compatibility: "strict",
80-
fetch: async (...params) => fetchWithContextCall(context, ...params),
81-
});
82-
};
83-
8467
export const createWorkflowModel = <TProvider extends ProviderFunction>({
8568
context,
8669
provider,
8770
providerParams,
71+
agentCallParams,
8872
}: {
8973
context: WorkflowContext;
9074
provider: TProvider;
9175
providerParams?: Omit<Required<Parameters<TProvider>>[0], "fetch">;
76+
agentCallParams?: AgentCallParams;
9277
}): ReturnType<TProvider> => {
9378
return provider({
94-
fetch: (...params) => fetchWithContextCall(context, ...params),
79+
fetch: (...params) => fetchWithContextCall(context, agentCallParams, ...params),
9580
...providerParams,
9681
});
9782
};

src/agents/agent.test.ts

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,18 @@ describe("agents", () => {
3333
const maxSteps = 2;
3434
const name = "my agent";
3535
const temparature = 0.4;
36-
const model = agentsApi.openai("gpt-3.5-turbo");
36+
37+
const flowControlKey = "flowControlKey";
38+
const model = agentsApi.openai("gpt-3.5-turbo", {
39+
callSettings: {
40+
flowControl: {
41+
key: flowControlKey,
42+
parallelism: 2,
43+
},
44+
retries: 5,
45+
timeout: 10,
46+
},
47+
});
3748

3849
const agent = new Agent(
3950
{
@@ -103,12 +114,15 @@ describe("agents", () => {
103114
"upstash-forward-content-type": "application/json",
104115
"upstash-forward-upstash-agent-name": "my agent",
105116
"upstash-method": "POST",
106-
"upstash-retries": "0",
107117
"upstash-workflow-calltype": "toCallback",
108118
"upstash-workflow-init": "false",
109119
"upstash-workflow-runid": workflowRunId,
110120
"upstash-workflow-url": "https://requestcatcher.com/api",
111121
"upstash-callback-retries": "5",
122+
"upstash-flow-control-key": "flowControlKey",
123+
"upstash-flow-control-value": "parallelism=2",
124+
"upstash-retries": "5",
125+
"upstash-timeout": "10",
112126
},
113127
},
114128
],
@@ -168,11 +182,14 @@ describe("agents", () => {
168182
"upstash-forward-content-type": "application/json",
169183
"upstash-forward-upstash-agent-name": "my agent",
170184
"upstash-method": "POST",
171-
"upstash-retries": "0",
172185
"upstash-workflow-calltype": "toCallback",
173186
"upstash-workflow-init": "false",
174187
"upstash-workflow-runid": workflowRunId,
175188
"upstash-workflow-url": "https://requestcatcher.com/api",
189+
"upstash-flow-control-key": "flowControlKey",
190+
"upstash-flow-control-value": "parallelism=2",
191+
"upstash-retries": "5",
192+
"upstash-timeout": "10",
176193
},
177194
},
178195
],
@@ -231,11 +248,14 @@ describe("agents", () => {
231248
"upstash-forward-content-type": "application/json",
232249
"upstash-forward-upstash-agent-name": "manager llm",
233250
"upstash-method": "POST",
234-
"upstash-retries": "0",
235251
"upstash-workflow-calltype": "toCallback",
236252
"upstash-workflow-init": "false",
237253
"upstash-workflow-runid": workflowRunId,
238254
"upstash-workflow-url": "https://requestcatcher.com/api",
255+
"upstash-flow-control-key": "flowControlKey",
256+
"upstash-flow-control-value": "parallelism=2",
257+
"upstash-retries": "5",
258+
"upstash-timeout": "10",
239259
},
240260
},
241261
],

src/agents/index.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,13 @@ export class WorkflowAgents {
9595
*/
9696
public openai(...params: CustomModelParams) {
9797
const [model, settings] = params;
98-
const { baseURL, apiKey, ...otherSettings } = settings ?? {};
98+
const { baseURL, apiKey, callSettings, ...otherSettings } = settings ?? {};
9999

100100
const openaiModel = this.AISDKModel({
101101
context: this.context,
102102
provider: createOpenAI,
103103
providerParams: { baseURL, apiKey, compatibility: "strict" },
104+
agentCallParams: callSettings,
104105
});
105106

106107
return openaiModel(model, otherSettings);

src/agents/types.ts

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import type { CoreTool, generateText } from "ai";
22
import { Agent } from "./agent";
3-
import { createWorkflowOpenAI, WorkflowTool } from "./adapters";
3+
import { WorkflowTool } from "./adapters";
4+
import { CallSettings } from "../types";
5+
import { createOpenAI } from "@ai-sdk/openai";
46

57
export type AISDKTool = CoreTool;
68
export type LangchainTool = {
@@ -90,8 +92,11 @@ export type ManagerAgentParameters = {
9092
} & Pick<Partial<AgentParameters>, "name" | "background"> &
9193
Pick<AgentParameters, "maxSteps">;
9294

93-
type ModelParams = Parameters<ReturnType<typeof createWorkflowOpenAI>>;
94-
type CustomModelSettings = ModelParams["1"] & { baseURL?: string; apiKey?: string };
95+
type ModelParams = Parameters<ReturnType<typeof createOpenAI>>;
96+
export type AgentCallParams = Pick<CallSettings, "flowControl" | "retries" | "timeout">;
97+
type CustomModelSettings = ModelParams["1"] & { baseURL?: string; apiKey?: string } & {
98+
callSettings: AgentCallParams;
99+
};
95100
export type CustomModelParams = [ModelParams[0], CustomModelSettings?];
96101

97102
export type ProviderFunction = (params: {

src/context/auto-executor.ts

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ export class AutoExecutor {
6060
if (this.executingStep) {
6161
throw new WorkflowError(
6262
"A step can not be run inside another step." +
63-
` Tried to run '${stepInfo.stepName}' inside '${this.executingStep}'`
63+
` Tried to run '${stepInfo.stepName}' inside '${this.executingStep}'`
6464
);
6565
}
6666

@@ -173,7 +173,7 @@ export class AutoExecutor {
173173
// user has added/removed a parallel step
174174
throw new WorkflowError(
175175
`Incompatible number of parallel steps when call state was '${parallelCallState}'.` +
176-
` Expected ${parallelSteps.length}, got ${plannedParallelStepCount} from the request.`
176+
` Expected ${parallelSteps.length}, got ${plannedParallelStepCount} from the request.`
177177
);
178178
}
179179

@@ -193,7 +193,7 @@ export class AutoExecutor {
193193
initialStepCount,
194194
invokeCount: this.invokeCount,
195195
telemetry: this.telemetry,
196-
debug: this.debug
196+
debug: this.debug,
197197
});
198198
break;
199199
}
@@ -208,7 +208,7 @@ export class AutoExecutor {
208208
if (!planStep || planStep.targetStep === undefined) {
209209
throw new WorkflowError(
210210
`There must be a last step and it should have targetStep larger than 0.` +
211-
`Received: ${JSON.stringify(planStep)}`
211+
`Received: ${JSON.stringify(planStep)}`
212212
);
213213
}
214214
const stepIndex = planStep.targetStep - initialStepCount;
@@ -384,14 +384,14 @@ const validateStep = (lazyStep: BaseLazyStep, stepFromRequest: Step): void => {
384384
if (lazyStep.stepName !== stepFromRequest.stepName) {
385385
throw new WorkflowError(
386386
`Incompatible step name. Expected '${lazyStep.stepName}',` +
387-
` got '${stepFromRequest.stepName}' from the request`
387+
` got '${stepFromRequest.stepName}' from the request`
388388
);
389389
}
390390
// check type name
391391
if (lazyStep.stepType !== stepFromRequest.stepType) {
392392
throw new WorkflowError(
393393
`Incompatible step type. Expected '${lazyStep.stepType}',` +
394-
` got '${stepFromRequest.stepType}' from the request`
394+
` got '${stepFromRequest.stepType}' from the request`
395395
);
396396
}
397397
};
@@ -419,10 +419,10 @@ const validateParallelSteps = (lazySteps: BaseLazyStep[], stepsFromRequest: Step
419419
const requestStepTypes = stepsFromRequest.map((step) => step.stepType);
420420
throw new WorkflowError(
421421
`Incompatible steps detected in parallel execution: ${error.message}` +
422-
`\n > Step Names from the request: ${JSON.stringify(requestStepNames)}` +
423-
`\n Step Types from the request: ${JSON.stringify(requestStepTypes)}` +
424-
`\n > Step Names expected: ${JSON.stringify(lazyStepNames)}` +
425-
`\n Step Types expected: ${JSON.stringify(lazyStepTypes)}`
422+
`\n > Step Names from the request: ${JSON.stringify(requestStepNames)}` +
423+
`\n Step Types from the request: ${JSON.stringify(requestStepTypes)}` +
424+
`\n > Step Names expected: ${JSON.stringify(lazyStepNames)}` +
425+
`\n Step Types expected: ${JSON.stringify(lazyStepTypes)}`
426426
);
427427
}
428428
throw error;

src/context/steps.ts

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,14 @@ export abstract class BaseLazyStep<TResult = unknown> {
151151
}
152152

153153
async submitStep({ context, body, headers }: SubmitStepParams) {
154-
return await context.qstashClient.batch([
154+
return (await context.qstashClient.batch([
155155
{
156156
body,
157157
headers,
158158
method: "POST",
159159
url: context.url,
160160
},
161-
]) as { messageId: string }[];
161+
])) as { messageId: string }[];
162162
}
163163
}
164164

@@ -236,15 +236,15 @@ export class LazySleepStep extends BaseLazyStep {
236236
}
237237

238238
async submitStep({ context, body, headers, isParallel }: SubmitStepParams) {
239-
return await context.qstashClient.batch([
239+
return (await context.qstashClient.batch([
240240
{
241241
body,
242242
headers,
243243
method: "POST",
244244
url: context.url,
245245
delay: isParallel ? undefined : this.sleep,
246246
},
247-
]) as { messageId: string }[]
247+
])) as { messageId: string }[];
248248
}
249249
}
250250

@@ -287,15 +287,15 @@ export class LazySleepUntilStep extends BaseLazyStep {
287287
}
288288

289289
async submitStep({ context, body, headers, isParallel }: SubmitStepParams) {
290-
return await context.qstashClient.batch([
290+
return (await context.qstashClient.batch([
291291
{
292292
body,
293293
headers,
294294
method: "POST",
295295
url: context.url,
296296
notBefore: isParallel ? undefined : this.sleepUntil,
297297
},
298-
]) as { messageId: string }[];
298+
])) as { messageId: string }[];
299299
}
300300
}
301301

@@ -463,14 +463,14 @@ export class LazyCallStep<TResult = unknown, TBody = unknown> extends BaseLazySt
463463
}
464464

465465
async submitStep({ context, headers }: SubmitStepParams) {
466-
return await context.qstashClient.batch([
466+
return (await context.qstashClient.batch([
467467
{
468468
headers,
469469
body: JSON.stringify(this.body),
470470
method: this.method,
471471
url: this.url,
472472
},
473-
]) as { messageId: string }[];
473+
])) as { messageId: string }[];
474474
}
475475
}
476476

@@ -538,11 +538,11 @@ export class LazyWaitForEventStep extends BaseLazyStep<WaitStepResponse> {
538538
// to include telemetry headers:
539539
...(telemetry
540540
? Object.fromEntries(
541-
Object.entries(getTelemetryHeaders(telemetry)).map(([header, value]) => [
542-
header,
543-
[value],
544-
])
545-
)
541+
Object.entries(getTelemetryHeaders(telemetry)).map(([header, value]) => [
542+
header,
543+
[value],
544+
])
545+
)
546546
: {}),
547547

548548
// note: using WORKFLOW_ID_HEADER doesn't work, because Runid -> RunId:
@@ -571,14 +571,14 @@ export class LazyWaitForEventStep extends BaseLazyStep<WaitStepResponse> {
571571
}
572572

573573
async submitStep({ context, body, headers }: SubmitStepParams) {
574-
const result = await context.qstashClient.http.request({
574+
const result = (await context.qstashClient.http.request({
575575
path: ["v2", "wait", this.eventId],
576576
body: body,
577577
headers,
578578
method: "POST",
579579
parseResponseAsJson: false,
580-
}) as { messageId: string };
581-
return [result]
580+
})) as { messageId: string };
581+
return [result];
582582
}
583583
}
584584

@@ -748,12 +748,12 @@ export class LazyInvokeStep<TResult = unknown, TBody = unknown> extends BaseLazy
748748

749749
async submitStep({ context, body, headers }: SubmitStepParams) {
750750
const newUrl = context.url.replace(/[^/]+$/, this.workflowId);
751-
const result = await context.qstashClient.publish({
751+
const result = (await context.qstashClient.publish({
752752
headers,
753753
method: "POST",
754754
body,
755755
url: newUrl,
756-
}) as { messageId: string };
756+
})) as { messageId: string };
757757
return [result];
758758
}
759759
}

src/qstash/submit-steps.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ export const submitParallelSteps = async ({
3030
steps: planSteps,
3131
});
3232

33-
const result = await context.qstashClient.batch(
33+
const result = (await context.qstashClient.batch(
3434
planSteps.map((planStep) => {
3535
const { headers } = getHeaders({
3636
initHeaderValue: "false",
@@ -55,7 +55,7 @@ export const submitParallelSteps = async ({
5555
delay: planStep.sleepFor,
5656
};
5757
})
58-
) as { messageId: string }[];
58+
)) as { messageId: string }[];
5959

6060
await debug?.log("INFO", "SUBMIT_STEP", {
6161
messageIds: result.map((message) => {
@@ -65,7 +65,6 @@ export const submitParallelSteps = async ({
6565
}),
6666
});
6767

68-
6968
throw new WorkflowAbort(planSteps[0].stepName, planSteps[0]);
7069
};
7170

0 commit comments

Comments
 (0)