Skip to content

Commit d1669b9

Browse files
committed
Seed new thread drafts from sticky settings
1 parent 867f55a commit d1669b9

File tree

3 files changed

+85
-48
lines changed

3 files changed

+85
-48
lines changed

apps/web/src/components/Sidebar.tsx

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ import { isElectron } from "../env";
4646
import { APP_STAGE_LABEL, APP_VERSION } from "../branding";
4747
import { isLinuxPlatform, isMacPlatform, newCommandId, newProjectId } from "../lib/utils";
4848
import { newThreadId } from "../lib/utils";
49+
import { seedNewThreadDraft } from "../lib/newThreadDraft";
4950
import { useStore } from "../store";
5051
import { isChatNewLocalShortcut, isChatNewShortcut, shortcutLabelForCommand } from "../keybindings";
5152
import { derivePendingApprovals, derivePendingUserInputs } from "../session-logic";
@@ -264,6 +265,8 @@ export default function Sidebar() {
264265
const reorderProjects = useStore((store) => store.reorderProjects);
265266
const clearComposerDraftForThread = useComposerDraftStore((store) => store.clearThreadDraft);
266267
const draftByThreadId = useComposerDraftStore((store) => store.draftsByThreadId);
268+
const stickyModel = useComposerDraftStore((store) => store.stickyModel);
269+
const stickyModelOptions = useComposerDraftStore((store) => store.stickyModelOptions);
267270
const getDraftThreadByProjectId = useComposerDraftStore(
268271
(store) => store.getDraftThreadByProjectId,
269272
);
@@ -274,6 +277,7 @@ export default function Sidebar() {
274277
const setDraftThreadContext = useComposerDraftStore((store) => store.setDraftThreadContext);
275278
const setDraftProvider = useComposerDraftStore((store) => store.setProvider);
276279
const setDraftModel = useComposerDraftStore((store) => store.setModel);
280+
const setDraftModelOptions = useComposerDraftStore((store) => store.setModelOptions);
277281
const clearProjectDraftThreadId = useComposerDraftStore(
278282
(store) => store.clearProjectDraftThreadId,
279283
);
@@ -404,31 +408,14 @@ export default function Sidebar() {
404408
model?: ModelSlug | null;
405409
},
406410
): Promise<void> => {
407-
const activeThread = routeThreadId
408-
? (threads.find((thread) => thread.id === routeThreadId) ?? null)
409-
: null;
410-
const activeDraftState = routeThreadId ? (draftByThreadId[routeThreadId] ?? null) : null;
411-
const activeDraftThread = routeThreadId ? getDraftThread(routeThreadId) : null;
412-
const sourceProjectId = activeThread?.projectId ?? activeDraftThread?.projectId ?? null;
413-
const shouldSeedFromActiveContext = sourceProjectId === projectId;
414-
const nextProvider =
415-
options?.provider !== undefined
416-
? (options.provider ?? null)
417-
: shouldSeedFromActiveContext
418-
? (activeDraftState?.provider ?? activeThread?.session?.provider ?? null)
419-
: null;
420-
const nextModel =
421-
options?.model !== undefined
422-
? (options.model ?? null)
423-
: shouldSeedFromActiveContext
424-
? (activeDraftState?.model ?? activeThread?.model ?? null)
425-
: null;
426411
const hasBranchOption = options?.branch !== undefined;
427412
const hasWorktreePathOption = options?.worktreePath !== undefined;
428413
const hasEnvModeOption = options?.envMode !== undefined;
429-
const hasProviderOption = nextProvider !== null;
430-
const hasModelOption = nextModel !== null;
414+
const activeDraftThread = routeThreadId ? (getDraftThread(routeThreadId) ?? null) : null;
431415
const storedDraftThread = getDraftThreadByProjectId(projectId);
416+
const storedComposerDraft = storedDraftThread
417+
? (draftByThreadId[storedDraftThread.threadId] ?? null)
418+
: null;
432419
if (storedDraftThread) {
433420
return (async () => {
434421
if (hasBranchOption || hasWorktreePathOption || hasEnvModeOption) {
@@ -438,12 +425,16 @@ export default function Sidebar() {
438425
...(hasEnvModeOption ? { envMode: options?.envMode } : {}),
439426
});
440427
}
441-
if (hasProviderOption) {
442-
setDraftProvider(storedDraftThread.threadId, nextProvider);
443-
}
444-
if (hasModelOption) {
445-
setDraftModel(storedDraftThread.threadId, nextModel, nextProvider);
446-
}
428+
seedNewThreadDraft({
429+
threadId: storedDraftThread.threadId,
430+
...(options?.provider !== undefined ? { provider: options.provider } : {}),
431+
...(options?.model !== undefined ? { model: options.model } : {}),
432+
stickyModel: storedComposerDraft ? null : stickyModel,
433+
stickyModelOptions: storedComposerDraft ? {} : stickyModelOptions,
434+
setProvider: setDraftProvider,
435+
setModel: setDraftModel,
436+
setModelOptions: setDraftModelOptions,
437+
});
447438
setProjectDraftThreadId(projectId, storedDraftThread.threadId);
448439
if (routeThreadId === storedDraftThread.threadId) {
449440
return;
@@ -464,12 +455,16 @@ export default function Sidebar() {
464455
...(hasEnvModeOption ? { envMode: options?.envMode } : {}),
465456
});
466457
}
467-
if (hasProviderOption) {
468-
setDraftProvider(routeThreadId, nextProvider);
469-
}
470-
if (hasModelOption) {
471-
setDraftModel(routeThreadId, nextModel, nextProvider);
472-
}
458+
seedNewThreadDraft({
459+
threadId: routeThreadId,
460+
...(options?.provider !== undefined ? { provider: options.provider } : {}),
461+
...(options?.model !== undefined ? { model: options.model } : {}),
462+
stickyModel: draftByThreadId[routeThreadId] ? null : stickyModel,
463+
stickyModelOptions: draftByThreadId[routeThreadId] ? {} : stickyModelOptions,
464+
setProvider: setDraftProvider,
465+
setModel: setDraftModel,
466+
setModelOptions: setDraftModelOptions,
467+
});
473468
setProjectDraftThreadId(projectId, routeThreadId);
474469
return Promise.resolve();
475470
}
@@ -483,12 +478,16 @@ export default function Sidebar() {
483478
envMode: options?.envMode ?? "local",
484479
runtimeMode: DEFAULT_RUNTIME_MODE,
485480
});
486-
if (hasProviderOption) {
487-
setDraftProvider(threadId, nextProvider);
488-
}
489-
if (hasModelOption) {
490-
setDraftModel(threadId, nextModel, nextProvider);
491-
}
481+
seedNewThreadDraft({
482+
threadId,
483+
...(options?.provider !== undefined ? { provider: options.provider } : {}),
484+
...(options?.model !== undefined ? { model: options.model } : {}),
485+
stickyModel,
486+
stickyModelOptions,
487+
setProvider: setDraftProvider,
488+
setModel: setDraftModel,
489+
setModelOptions: setDraftModelOptions,
490+
});
492491

493492
await navigate({
494493
to: "/$threadId",
@@ -504,10 +503,12 @@ export default function Sidebar() {
504503
getDraftThread,
505504
routeThreadId,
506505
setDraftModel,
506+
setDraftModelOptions,
507507
setDraftThreadContext,
508508
setDraftProvider,
509509
setProjectDraftThreadId,
510-
threads,
510+
stickyModel,
511+
stickyModelOptions,
511512
],
512513
);
513514

apps/web/src/hooks/useHandleNewThread.ts

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import { DEFAULT_RUNTIME_MODE, type ProjectId, ThreadId } from "@t3tools/contracts";
22
import { useNavigate, useParams } from "@tanstack/react-router";
33
import { useCallback } from "react";
4-
import { inferProviderForModel } from "@t3tools/shared/model";
54
import {
65
type DraftThreadEnvMode,
76
type DraftThreadState,
87
useComposerDraftStore,
98
} from "../composerDraftStore";
9+
import { seedNewThreadDraft } from "../lib/newThreadDraft";
1010
import { newThreadId } from "../lib/utils";
1111
import { useStore } from "../store";
1212

@@ -102,13 +102,14 @@ export function useHandleNewThread() {
102102
envMode: options?.envMode ?? "local",
103103
runtimeMode: DEFAULT_RUNTIME_MODE,
104104
});
105-
if (stickyModel) {
106-
setProvider(threadId, inferProviderForModel(stickyModel));
107-
setModel(threadId, stickyModel);
108-
}
109-
if (Object.keys(stickyModelOptions).length > 0) {
110-
setModelOptions(threadId, stickyModelOptions);
111-
}
105+
seedNewThreadDraft({
106+
threadId,
107+
stickyModel,
108+
stickyModelOptions,
109+
setProvider,
110+
setModel,
111+
setModelOptions,
112+
});
112113

113114
await navigate({
114115
to: "/$threadId",

apps/web/src/lib/newThreadDraft.ts

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import { type ProviderKind, type ProviderModelOptions, type ThreadId } from "@t3tools/contracts";
2+
import { inferProviderForModel } from "@t3tools/shared/model";
3+
4+
interface SeedNewThreadDraftInput {
5+
readonly threadId: ThreadId;
6+
readonly provider?: ProviderKind | null;
7+
readonly model?: string | null;
8+
readonly stickyModel: string | null;
9+
readonly stickyModelOptions: ProviderModelOptions;
10+
readonly setProvider: (threadId: ThreadId, provider: ProviderKind | null | undefined) => void;
11+
readonly setModel: (
12+
threadId: ThreadId,
13+
model: string | null | undefined,
14+
provider?: ProviderKind | null | undefined,
15+
) => void;
16+
readonly setModelOptions: (
17+
threadId: ThreadId,
18+
modelOptions: ProviderModelOptions | null | undefined,
19+
) => void;
20+
}
21+
22+
export function seedNewThreadDraft(input: SeedNewThreadDraftInput): void {
23+
const nextModel = input.model ?? input.stickyModel ?? null;
24+
const nextProvider = input.provider ?? (nextModel ? inferProviderForModel(nextModel) : null);
25+
26+
if (nextProvider) {
27+
input.setProvider(input.threadId, nextProvider);
28+
}
29+
if (nextModel) {
30+
input.setModel(input.threadId, nextModel, nextProvider);
31+
}
32+
if (Object.keys(input.stickyModelOptions).length > 0) {
33+
input.setModelOptions(input.threadId, input.stickyModelOptions);
34+
}
35+
}

0 commit comments

Comments
 (0)