Skip to content

Commit 6b57fd3

Browse files
authored
feat: make preferred model form consistent with the other forms (#309)
* only show notice on actual default workspace * fix test assertion * invalidate muxes after changing preferred model
1 parent 5200192 commit 6b57fd3

6 files changed

+93
-46
lines changed

src/features/workspace/components/__tests__/workspace-preferred-model.test.tsx

+5-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { screen, waitFor } from "@testing-library/react";
33
import { WorkspacePreferredModel } from "../workspace-preferred-model";
44
import userEvent from "@testing-library/user-event";
55

6-
test("render model overrides", () => {
6+
test("render model overrides", async () => {
77
render(
88
<WorkspacePreferredModel
99
isArchived={false}
@@ -19,7 +19,10 @@ test("render model overrides", () => {
1919
expect(
2020
screen.getByRole("button", { name: /select the model/i }),
2121
).toBeVisible();
22-
expect(screen.getByRole("button", { name: /save/i })).toBeVisible();
22+
23+
await waitFor(() => {
24+
expect(screen.getByRole("button", { name: /save/i })).toBeVisible();
25+
});
2326
});
2427

2528
test("submit preferred model", async () => {

src/features/workspace/components/workspace-custom-instructions.tsx

+5-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,11 @@ function useCustomInstructionsValue({
123123
options: V1GetWorkspaceCustomInstructionsData;
124124
queryClient: QueryClient;
125125
}) {
126-
const formState = useFormState({ prompt: initialValue });
126+
const initialFormValues = useMemo(
127+
() => ({ prompt: initialValue }),
128+
[initialValue],
129+
);
130+
const formState = useFormState(initialFormValues);
127131
const { values, updateFormValues } = formState;
128132

129133
// Subscribe to changes in the workspace system prompt value in the query cache

src/features/workspace/components/workspace-name.tsx

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import { useNavigate } from "react-router-dom";
1212
import { twMerge } from "tailwind-merge";
1313
import { useFormState } from "@/hooks/useFormState";
1414
import { FormButtons } from "@/components/FormButtons";
15+
import { FormEvent } from "react";
1516

1617
export function WorkspaceName({
1718
className,
@@ -32,7 +33,7 @@ export function WorkspaceName({
3233
const isDefault = workspaceName === "default";
3334
const isUneditable = isArchived || isPending || isDefault;
3435

35-
const handleSubmit = (event: { preventDefault: () => void }) => {
36+
const handleSubmit = (event: FormEvent) => {
3637
event.preventDefault();
3738

3839
mutateAsync(

src/features/workspace/components/workspace-preferred-model.tsx

+38-23
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import {
22
Alert,
3-
Button,
43
Card,
54
CardBody,
65
CardFooter,
@@ -16,6 +15,10 @@ import { FormEvent } from "react";
1615
import { usePreferredModelWorkspace } from "../hooks/use-preferred-preferred-model";
1716
import { Select, SelectButton } from "@stacklok/ui-kit";
1817
import { useQueryListAllModelsForAllProviders } from "@/hooks/use-query-list-all-models-for-all-providers";
18+
import { FormButtons } from "@/components/FormButtons";
19+
import { invalidateQueries } from "@/lib/react-query-utils";
20+
import { v1GetWorkspaceMuxesQueryKey } from "@/api/generated/@tanstack/react-query.gen";
21+
import { useQueryClient } from "@tanstack/react-query";
1922

2023
function MissingProviderBanner() {
2124
return (
@@ -39,30 +42,38 @@ export function WorkspacePreferredModel({
3942
workspaceName: string;
4043
isArchived: boolean | undefined;
4144
}) {
42-
const { preferredModel, setPreferredModel, isPending } =
43-
usePreferredModelWorkspace(workspaceName);
45+
const queryClient = useQueryClient();
46+
const { formState, isPending } = usePreferredModelWorkspace(workspaceName);
4447
const { mutateAsync } = useMutationPreferredModelWorkspace();
4548
const { data: providerModels = [] } = useQueryListAllModelsForAllProviders();
46-
const { model, provider_id } = preferredModel;
4749
const isModelsEmpty = !isPending && providerModels.length === 0;
4850

4951
const handleSubmit = (event: FormEvent) => {
5052
event.preventDefault();
51-
mutateAsync({
52-
path: { workspace_name: workspaceName },
53-
body: [
54-
{
55-
matcher: "",
56-
provider_id,
57-
model,
58-
matcher_type: MuxMatcherType.CATCH_ALL,
59-
},
60-
],
61-
});
53+
mutateAsync(
54+
{
55+
path: { workspace_name: workspaceName },
56+
body: [
57+
{
58+
matcher: "",
59+
matcher_type: MuxMatcherType.CATCH_ALL,
60+
...formState.values.preferredModel,
61+
},
62+
],
63+
},
64+
{
65+
onSuccess: () =>
66+
invalidateQueries(queryClient, [v1GetWorkspaceMuxesQueryKey]),
67+
},
68+
);
6269
};
6370

6471
return (
65-
<Form onSubmit={handleSubmit} validationBehavior="aria">
72+
<Form
73+
onSubmit={handleSubmit}
74+
validationBehavior="aria"
75+
data-testid="preferred-model"
76+
>
6677
<Card className={twMerge(className, "shrink-0")}>
6778
<CardBody className="flex flex-col gap-6">
6879
<div className="flex flex-col justify-start">
@@ -84,16 +95,18 @@ export function WorkspacePreferredModel({
8495
isRequired
8596
isDisabled={isModelsEmpty}
8697
className="w-full"
87-
selectedKey={preferredModel?.model}
98+
selectedKey={formState.values.preferredModel?.model}
8899
placeholder="Select the model"
89100
onSelectionChange={(model) => {
90101
const preferredModelProvider = providerModels.find(
91102
(item) => item.name === model,
92103
);
93104
if (preferredModelProvider) {
94-
setPreferredModel({
95-
model: preferredModelProvider.name,
96-
provider_id: preferredModelProvider.provider_id,
105+
formState.updateFormValues({
106+
preferredModel: {
107+
model: preferredModelProvider.name,
108+
provider_id: preferredModelProvider.provider_id,
109+
},
97110
});
98111
}
99112
}}
@@ -109,9 +122,11 @@ export function WorkspacePreferredModel({
109122
</div>
110123
</CardBody>
111124
<CardFooter className="justify-end">
112-
<Button isDisabled={isArchived || isModelsEmpty} type="submit">
113-
Save
114-
</Button>
125+
<FormButtons
126+
isPending={isPending}
127+
formState={formState}
128+
canSubmit={!isArchived}
129+
/>
115130
</CardFooter>
116131
</Card>
117132
</Form>

src/features/workspace/hooks/use-preferred-preferred-model.ts

+7-10
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import { MuxRule, V1GetWorkspaceMuxesData } from "@/api/generated";
22
import { v1GetWorkspaceMuxesOptions } from "@/api/generated/@tanstack/react-query.gen";
3+
import { useFormState } from "@/hooks/useFormState";
34
import { useQuery } from "@tanstack/react-query";
4-
import { useEffect, useMemo, useState } from "react";
5+
import { useMemo } from "react";
56

67
type ModelRule = Omit<MuxRule, "matcher_type" | "matcher"> & {};
78

@@ -21,8 +22,6 @@ const usePreferredModel = (options: {
2122
};
2223

2324
export const usePreferredModelWorkspace = (workspaceName: string) => {
24-
const [preferredModel, setPreferredModel] =
25-
useState<ModelRule>(DEFAULT_STATE);
2625
const options: V1GetWorkspaceMuxesData &
2726
Omit<V1GetWorkspaceMuxesData, "body"> = useMemo(
2827
() => ({
@@ -31,12 +30,10 @@ export const usePreferredModelWorkspace = (workspaceName: string) => {
3130
[workspaceName],
3231
);
3332
const { data, isPending } = usePreferredModel(options);
33+
const providerModel = data?.[0];
34+
const formState = useFormState<{ preferredModel: ModelRule }>({
35+
preferredModel: providerModel ?? DEFAULT_STATE,
36+
});
3437

35-
useEffect(() => {
36-
const providerModel = data?.[0];
37-
38-
setPreferredModel(providerModel ?? DEFAULT_STATE);
39-
}, [data, setPreferredModel]);
40-
41-
return { preferredModel, setPreferredModel, isPending };
38+
return { isPending, formState };
4239
};

src/hooks/useFormState.ts

+36-9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { isEqual } from "lodash";
2-
import { useState } from "react";
2+
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
33

44
export type FormState<T> = {
55
values: T;
@@ -8,23 +8,50 @@ export type FormState<T> = {
88
isDirty: boolean;
99
};
1010

11+
function useDeepMemo<T>(value: T): T {
12+
const ref = useRef<T>(value);
13+
if (!isEqual(ref.current, value)) {
14+
ref.current = value;
15+
}
16+
return ref.current;
17+
}
18+
1119
export function useFormState<Values extends Record<string, unknown>>(
1220
initialValues: Values,
1321
): FormState<Values> {
22+
const memoizedInitialValues = useDeepMemo(initialValues);
23+
1424
// this could be replaced with some form library later
15-
const [values, setValues] = useState<Values>(initialValues);
16-
const updateFormValues = (newState: Partial<Values>) => {
25+
const [values, setValues] = useState<Values>(memoizedInitialValues);
26+
const [originalValues, setOriginalValues] = useState<Values>(values);
27+
28+
useEffect(() => {
29+
// this logic supports the use case when the initialValues change
30+
// due to an async request for instance
31+
setOriginalValues(memoizedInitialValues);
32+
setValues(memoizedInitialValues);
33+
}, [memoizedInitialValues]);
34+
35+
const updateFormValues = useCallback((newState: Partial<Values>) => {
1736
setValues((prevState: Values) => ({
1837
...prevState,
1938
...newState,
2039
}));
21-
};
40+
}, []);
41+
42+
const resetForm = useCallback(() => {
43+
setValues(originalValues);
44+
}, [originalValues]);
2245

23-
const resetForm = () => {
24-
setValues(initialValues);
25-
};
46+
const isDirty = useMemo(
47+
() => !isEqual(values, originalValues),
48+
[values, originalValues],
49+
);
2650

27-
const isDirty = !isEqual(values, initialValues);
51+
const formState = useMemo(
52+
() => ({ values, updateFormValues, resetForm, isDirty }),
53+
[values, updateFormValues, resetForm, isDirty],
54+
);
2855

29-
return { values, updateFormValues, resetForm, isDirty };
56+
return formState;
3057
}

0 commit comments

Comments
 (0)