diff --git a/src/__tests__/integration/api/agent.test.ts b/src/__tests__/integration/api/agent.test.ts index 1f407ec1a9..8fa5973c98 100644 --- a/src/__tests__/integration/api/agent.test.ts +++ b/src/__tests__/integration/api/agent.test.ts @@ -149,6 +149,48 @@ describe("POST /api/agent Integration Tests", () => { expect(sessionBody.agent_name).toBeUndefined(); expect(sessionBody.model).toBe("sonnet"); }); + + test("passes full provider/name model string through to session payload (openrouter)", async () => { + const user = await createTestUser(); + const workspace = await createTestWorkspace({ ownerId: user.id }); + const task = await createTestTask({ + workspaceId: workspace.id, + createdById: user.id, + title: "Kimi test task", + }); + + await db.task.update({ + where: { id: task.id }, + data: { mode: "agent", model: "openrouter/moonshotai/kimi-k2.6" }, + }); + + process.env.OPENROUTER_API_KEY = "test-openrouter-key"; + + getMockedSession().mockResolvedValue(createAuthenticatedSession(user)); + + const request = createPostRequest("http://localhost/api/agent", { + message: "Help me write a feature", + taskId: task.id, + model: "openrouter/moonshotai/kimi-k2.6", + }); + + const response = await POST(request); + const data = await response.json(); + + expect(response.status).toBe(200); + expect(data.success).toBe(true); + + const sessionCall = mockFetch.mock.calls.find(([url]: [string]) => + url.includes("/session"), + ); + expect(sessionCall).toBeDefined(); + + const sessionBody = JSON.parse(sessionCall[1].body); + expect(sessionBody.model).toBe("openrouter/moonshotai/kimi-k2.6"); + expect(sessionBody.apiKey).toBe("test-openrouter-key"); + + delete process.env.OPENROUTER_API_KEY; + }); }); // NOTE: Most tests commented out due to significant implementation gaps: diff --git a/src/__tests__/integration/api/pool-manager/create-pool.test.ts b/src/__tests__/integration/api/pool-manager/create-pool.test.ts index ed94dd6f68..3fca1cf106 100644 --- a/src/__tests__/integration/api/pool-manager/create-pool.test.ts +++ b/src/__tests__/integration/api/pool-manager/create-pool.test.ts @@ -213,7 +213,7 @@ describe("POST /api/pool-manager/create-pool", () => { expect(data.pool).toEqual(mockPool); }); - test("allows workspace member to create pool", async () => { + test("denies workspace member (DEVELOPER role) from creating pool — requires ADMIN or OWNER", async () => { const member = await createTestUser({ email: "member@test.com" }); await db.workspaceMember.create({ data: { @@ -225,20 +225,6 @@ describe("POST /api/pool-manager/create-pool", () => { getMockedSession().mockResolvedValue(createAuthenticatedSession(member)); - const mockPool = { - id: "pool-456", - name: swarm.id, - status: "active" as const, - owner_id: member.id, - created_at: new Date().toISOString(), - updated_at: new Date().toISOString(), - }; - - mockPoolManagerService.mockReturnValue({ - createPool: vi.fn().mockResolvedValue(mockPool), - updateApiKey: vi.fn(), - } as any); - const request = createPostRequest( "http://localhost/api/pool-manager/create-pool", { @@ -248,8 +234,7 @@ describe("POST /api/pool-manager/create-pool", () => { ); const response = await POST(request); - const data = await expectSuccess(response, 201); - expect(data.pool).toEqual(mockPool); + await expectNotFound(response, "Workspace not found or access denied"); }); }); diff --git a/src/__tests__/unit/lib/ai/models.test.ts b/src/__tests__/unit/lib/ai/models.test.ts index 7d2d5d336b..4a68ce723a 100644 --- a/src/__tests__/unit/lib/ai/models.test.ts +++ b/src/__tests__/unit/lib/ai/models.test.ts @@ -13,8 +13,30 @@ describe("models", () => { expect(isValidModel("haiku")).toBe(true); }); - test("returns false for unknown model", () => { - expect(isValidModel("unknown-model")).toBe(false); + test("returns true for full provider/name format strings", () => { + expect(isValidModel("openrouter/moonshotai/kimi-k2.6")).toBe(true); + expect(isValidModel("anthropic/claude-sonnet-4-6")).toBe(true); + expect(isValidModel("openai/gpt-4o")).toBe(true); + expect(isValidModel("google/gemini-pro")).toBe(true); + }); + + test("returns true for legacy short aliases", () => { + expect(isValidModel("sonnet")).toBe(true); + expect(isValidModel("gpt")).toBe(true); + expect(isValidModel("gemini")).toBe(true); + }); + + test("returns true for any non-empty string (previously unknown model)", () => { + // isValidModel now accepts any non-empty string; the admin panel is source of truth + expect(isValidModel("unknown-model")).toBe(true); + }); + + test("returns false for empty string", () => { + expect(isValidModel("")).toBe(false); + }); + + test("returns false for whitespace-only string", () => { + expect(isValidModel(" ")).toBe(false); }); test("returns false for non-string values", () => { diff --git a/src/__tests__/unit/pages/task-reconciliation.test.tsx b/src/__tests__/unit/pages/task-reconciliation.test.tsx index 69d403a65f..02a1c2c3c3 100644 --- a/src/__tests__/unit/pages/task-reconciliation.test.tsx +++ b/src/__tests__/unit/pages/task-reconciliation.test.tsx @@ -98,25 +98,43 @@ vi.mock("@/components/ui/resizable", () => { return { ResizablePanel: ({ children }: any) => React.createElement("div", null, children), ResizablePanelGroup: ({ children }: any) => React.createElement("div", null, children), - ResizableHandle: () => React.createElement("div", null), + ResizableHandle: () => React.createElement("div"), }; }); -vi.mock("framer-motion", () => { - const React = require("react"); - return { - motion: { - div: ({ children, ...props }: any) => React.createElement("div", props, children), - }, - AnimatePresence: ({ children }: any) => - React.createElement(React.Fragment, null, children), - }; -}); +vi.mock("@/hooks/useWorkspaceAccess", () => ({ + useWorkspaceAccess: () => ({ + canRead: true, + canWrite: true, + canAdmin: false, + permissions: {}, + }), +})); + +vi.mock("@/contexts/StreamContext", () => ({ + useStreamContext: () => ({ + streamContext: null, + onMessage: vi.fn(), + onWorkflowStatusUpdate: vi.fn(), + }), +})); + +vi.mock("sonner", () => ({ + toast: { error: vi.fn(), success: vi.fn() }, +})); -vi.mock("sonner", () => ({ toast: { error: vi.fn(), success: vi.fn() } })); +vi.mock("framer-motion", () => ({ + motion: { + div: ({ children, ...props }: any) => { + const React = require("react"); + return React.createElement("div", props, children); + }, + }, + AnimatePresence: ({ children }: any) => children, +})); // --------------------------------------------------------------------------- -// useWorkflowPolling mock — controllable per-test +// Workflow polling mock — tests can set mockWorkflowPollingData to simulate results // --------------------------------------------------------------------------- let mockWorkflowPollingData: any = null; vi.mock("@/hooks/useWorkflowPolling", () => ({ @@ -167,16 +185,49 @@ function makeMessagesResponse(overrides: { describe("TaskChatPage — reconciliation polling", () => { const mockFetch = vi.fn(); + // Per-test URL-keyed response queues. Fetch calls are routed by URL substring. + // More specific keys should be pushed first so they match before shorter keys. + const urlQueues: Map Promise }>> = new Map(); + + function pushFetchResponse( + urlSubstring: string, + response: { ok: boolean; json: () => Promise }, + ) { + if (!urlQueues.has(urlSubstring)) urlQueues.set(urlSubstring, []); + urlQueues.get(urlSubstring)!.push(response); + } + beforeEach(() => { vi.clearAllMocks(); capturedOnWorkflowStatusUpdate = null; mockWorkflowPollingData = null; + urlQueues.clear(); + + // Route fetch calls by URL substring. This avoids FIFO queue conflicts between + // the /api/llm-models fetch (added for the LLM model selector) and the task + // messages fetch — both fire on mount. + mockFetch.mockImplementation((url: string) => { + if (typeof url === "string") { + for (const [key, queue] of urlQueues.entries()) { + if (url.includes(key) && queue.length > 0) { + return Promise.resolve(queue.shift()!); + } + } + // Fallback by URL pattern + if (url.includes("/api/llm-models")) { + return Promise.resolve({ ok: true, json: async () => ({ models: [] }) }); + } + } + return Promise.resolve({ ok: true, json: async () => ({}) }); + }); + global.fetch = mockFetch; }); it("starts reconciliation when task loads with IN_PROGRESS + stakworkProjectId", async () => { - mockFetch.mockResolvedValue( - makeMessagesResponse({ workflowStatus: "IN_PROGRESS", stakworkProjectId: 42 }) + pushFetchResponse( + "/api/tasks/task-abc/messages", + makeMessagesResponse({ workflowStatus: "IN_PROGRESS", stakworkProjectId: 42 }), ); const { useWorkflowPolling } = await import("@/hooks/useWorkflowPolling"); @@ -204,8 +255,9 @@ describe("TaskChatPage — reconciliation polling", () => { }); it("does not start reconciliation when task loads with COMPLETED status", async () => { - mockFetch.mockResolvedValue( - makeMessagesResponse({ workflowStatus: "COMPLETED", stakworkProjectId: 42 }) + pushFetchResponse( + "/api/tasks/task-abc/messages", + makeMessagesResponse({ workflowStatus: "COMPLETED", stakworkProjectId: 42 }), ); const { useWorkflowPolling } = await import("@/hooks/useWorkflowPolling"); @@ -233,8 +285,9 @@ describe("TaskChatPage — reconciliation polling", () => { }); it("does not start reconciliation when IN_PROGRESS but no stakworkProjectId", async () => { - mockFetch.mockResolvedValue( - makeMessagesResponse({ workflowStatus: "IN_PROGRESS", stakworkProjectId: null }) + pushFetchResponse( + "/api/tasks/task-abc/messages", + makeMessagesResponse({ workflowStatus: "IN_PROGRESS", stakworkProjectId: null }), ); const { useWorkflowPolling } = await import("@/hooks/useWorkflowPolling"); @@ -262,12 +315,12 @@ describe("TaskChatPage — reconciliation polling", () => { }); it("patches workflowStatus to COMPLETED and stops reconciling when polling returns 'completed'", async () => { - // Page loads with IN_PROGRESS - mockFetch.mockResolvedValueOnce( - makeMessagesResponse({ workflowStatus: "IN_PROGRESS", stakworkProjectId: 42 }) + // Push more-specific URL first so messages fetch matches before the PATCH URL key + pushFetchResponse( + "/api/tasks/task-abc/messages", + makeMessagesResponse({ workflowStatus: "IN_PROGRESS", stakworkProjectId: 42 }), ); - // PATCH call - mockFetch.mockResolvedValueOnce({ ok: true, json: async () => ({}) }); + pushFetchResponse("/api/tasks/task-abc", { ok: true, json: async () => ({}) }); // Simulate polling returning completed mockWorkflowPollingData = { @@ -304,10 +357,11 @@ describe("TaskChatPage — reconciliation polling", () => { }); it("patches workflowStatus to FAILED and stops reconciling when polling returns 'failed'", async () => { - mockFetch.mockResolvedValueOnce( - makeMessagesResponse({ workflowStatus: "IN_PROGRESS", stakworkProjectId: 99 }) + pushFetchResponse( + "/api/tasks/task-abc/messages", + makeMessagesResponse({ workflowStatus: "IN_PROGRESS", stakworkProjectId: 99 }), ); - mockFetch.mockResolvedValueOnce({ ok: true, json: async () => ({}) }); + pushFetchResponse("/api/tasks/task-abc", { ok: true, json: async () => ({}) }); mockWorkflowPollingData = { status: "failed", @@ -335,8 +389,9 @@ describe("TaskChatPage — reconciliation polling", () => { it("stops reconciliation when Pusher WORKFLOW_STATUS_UPDATE fires, with no extra PATCH", async () => { // Task loads with IN_PROGRESS — reconciliation starts - mockFetch.mockResolvedValueOnce( - makeMessagesResponse({ workflowStatus: "IN_PROGRESS", stakworkProjectId: 77 }) + pushFetchResponse( + "/api/tasks/task-abc/messages", + makeMessagesResponse({ workflowStatus: "IN_PROGRESS", stakworkProjectId: 77 }), ); // No terminal polling data — reconciliation is active but hasn't resolved yet diff --git a/src/app/api/agent/route.ts b/src/app/api/agent/route.ts index 23282788d1..315fa0746d 100644 --- a/src/app/api/agent/route.ts +++ b/src/app/api/agent/route.ts @@ -64,7 +64,7 @@ import { db } from "@/lib/db"; import { EncryptionService } from "@/lib/encryption"; import { ChatRole, ChatStatus, ArtifactType } from "@prisma/client"; import { createWebhookToken, generateWebhookSecret } from "@/lib/auth/agent-jwt"; -import { isValidModel, getApiKeyForModel, type ModelName } from "@/lib/ai/models"; +import { isValidModel, getApiKeyForModel } from "@/lib/ai/models"; import { canAccessServerFeature, FEATURE_FLAGS } from "@/lib/feature-flags"; import { claimPodAndGetFrontend, updatePodRepositories, POD_PORTS, releasePodById } from "@/lib/pods"; @@ -325,7 +325,7 @@ async function createAgentSession( agentPassword: string | null, taskId: string, webhookUrl: string, - effectiveModel: ModelName | undefined, + effectiveModel: string | undefined, ): Promise { const sessionUrl = agentUrl.replace(/\/$/, "") + "/session"; @@ -409,7 +409,7 @@ export async function POST(request: NextRequest) { const { message, taskId, artifacts = [], model } = body; // Validate model parameter if provided - const requestModel: ModelName | undefined = isValidModel(model) ? model : undefined; + const requestModel: string | undefined = isValidModel(model) ? model : undefined; // 1. Authenticate user const session = await getServerSession(authOptions); @@ -475,8 +475,8 @@ export async function POST(request: NextRequest) { } // Determine effective model: request > task > default - const taskModel: ModelName | undefined = isValidModel(task.model) ? task.model : undefined; - const effectiveModel: ModelName | undefined = requestModel || taskModel; + const taskModel: string | undefined = isValidModel(task.model) ? task.model : undefined; + const effectiveModel: string | undefined = requestModel || taskModel; // 3. Ensure pod is available (claim if needed) let agentCredentials: AgentCredentials; diff --git a/src/app/api/pool-manager/create-pool/route.ts b/src/app/api/pool-manager/create-pool/route.ts index 31d8b65a70..4ad813b659 100644 --- a/src/app/api/pool-manager/create-pool/route.ts +++ b/src/app/api/pool-manager/create-pool/route.ts @@ -4,6 +4,7 @@ import { EncryptionService } from "@/lib/encryption"; import { poolManagerService } from "@/lib/service-factory"; import { saveOrUpdateSwarm } from "@/services/swarm/db"; import { getSwarmPoolApiKeyFor, updateSwarmPoolApiKeyFor } from "@/services/swarm/secrets"; +import { validateWorkspaceAccessById } from "@/services/workspace"; import { isApiError } from "@/types/errors"; import { getServerSession } from "next-auth/next"; import { NextRequest, NextResponse } from "next/server"; @@ -88,7 +89,9 @@ export async function POST(request: NextRequest) { return NextResponse.json({ error: "Invalid user session" }, { status: 401 }); } - // Find the swarm and verify user has access to the workspace + // Find the swarm to resolve the canonical workspaceId (never trust + // the body-supplied workspaceId alone — an attacker could pass their + // own workspaceId with a victim's swarmId to pass the auth check). const swarm = await db.swarm.findFirst({ where: { ...(swarmId ? { swarmId } : {}), @@ -99,11 +102,6 @@ export async function POST(request: NextRequest) { select: { id: true, slug: true, - ownerId: true, - members: { - where: { userId, leftAt: null }, - select: { role: true }, - }, }, }, }, @@ -113,11 +111,6 @@ export async function POST(request: NextRequest) { return NextResponse.json({ error: "Swarm not found" }, { status: 404 }); } - // IDOR guard: previously the owner/member check ran further down, - // AFTER the handler had already called `saveOrUpdateSwarm({ containerFiles })` - // with attacker-controlled content. Run the authz check immediately - // after the swarm lookup so no swarm row or secret decryption work - // happens on behalf of a non-member. if (!swarm.workspace) { return NextResponse.json( { error: "Workspace not found or access denied" }, @@ -125,9 +118,12 @@ export async function POST(request: NextRequest) { ); } - const isOwner = swarm.workspace.ownerId === userId; - const isMember = swarm.workspace.members.length > 0; - if (!isOwner && !isMember) { + // IDOR + privilege guard: pool creation is an infrastructure-level + // operation (equivalent to delete). Require ADMIN or OWNER on the + // swarm's actual workspace — any lesser role (VIEWER, DEVELOPER, etc.) + // must not be able to provision compute resources. + const access = await validateWorkspaceAccessById(swarm.workspaceId, userId); + if (!access.hasAccess || !access.canAdmin) { return NextResponse.json( { error: "Workspace not found or access denied" }, { status: 404 }, diff --git a/src/app/api/tasks/route.ts b/src/app/api/tasks/route.ts index ed41ec9755..be99914fcb 100644 --- a/src/app/api/tasks/route.ts +++ b/src/app/api/tasks/route.ts @@ -2,7 +2,7 @@ import { db } from "@/lib/db"; import { extractPrArtifact, sanitizeTask } from "@/lib/helpers/tasks"; import { Priority, Prisma, TaskSourceType, TaskStatus, WorkflowStatus } from "@prisma/client"; import { NextRequest, NextResponse } from "next/server"; -import { VALID_MODELS } from "@/lib/ai/models"; +import { isValidModel } from "@/lib/ai/models"; import { getMiddlewareContext, requireAuth } from "@/lib/middleware/utils"; import { resolveWorkspaceAccess, requireReadAccess, isPublicViewer } from "@/lib/auth/workspace-access"; import { toPublicTasks } from "@/lib/auth/public-redact"; @@ -544,7 +544,7 @@ export async function POST(request: NextRequest) { } // Validate model if provided - const taskModel = model && VALID_MODELS.includes(model) ? model : null; + const taskModel = isValidModel(model) ? model : null; // Create the task const task = await db.task.create({ diff --git a/src/app/w/[slug]/task/[...taskParams]/components/TaskStartInput.tsx b/src/app/w/[slug]/task/[...taskParams]/components/TaskStartInput.tsx index 990f1adc1d..83b0d3e8c1 100644 --- a/src/app/w/[slug]/task/[...taskParams]/components/TaskStartInput.tsx +++ b/src/app/w/[slug]/task/[...taskParams]/components/TaskStartInput.tsx @@ -51,7 +51,7 @@ import { WorkflowVersionSelector } from "@/components/workflow/WorkflowVersionSe import { PromptsPanel } from "@/components/prompts"; import { toast } from "sonner"; import { cn } from "@/lib/utils"; -import { VALID_MODELS, type ModelName } from "@/lib/ai/models"; +import { VALID_MODELS, getModelValue, type LlmModelOption } from "@/lib/ai/models"; import { useWorkspace } from "@/hooks/useWorkspace"; import { useRepoBranches } from "@/hooks/useRepoBranches"; @@ -65,7 +65,7 @@ interface PendingImage { } interface TaskStartInputProps { - onStart: (task: string, model?: ModelName, autoMerge?: boolean, images?: File[], repositoryId?: string, branch?: string, runBuild?: boolean, runTestSuite?: boolean) => void; + onStart: (task: string, model?: string, autoMerge?: boolean, images?: File[], repositoryId?: string, branch?: string, runBuild?: boolean, runTestSuite?: boolean) => void; taskMode: string; onModeChange: (mode: string) => void; isLoading?: boolean; @@ -81,8 +81,9 @@ interface TaskStartInputProps { // Project debugger props onProjectSelect?: (projectId: string, projectData: any) => void; // Model selection for agent mode - selectedModel?: ModelName; - onModelChange?: (model: ModelName) => void; + selectedModel?: string; + onModelChange?: (model: string) => void; + llmModels?: LlmModelOption[]; } export function TaskStartInput({ @@ -101,6 +102,7 @@ export function TaskStartInput({ onProjectSelect, selectedModel = "sonnet", onModelChange, + llmModels, }: TaskStartInputProps) { const searchParams = useSearchParams(); const { workspace } = useWorkspace(); @@ -831,7 +833,7 @@ export function TaskStartInput({ )} {taskMode === "agent" && onModelChange && ( - onModelChange(value)}>
@@ -839,13 +841,21 @@ export function TaskStartInput({
- {VALID_MODELS.map((model) => ( - + {(llmModels && llmModels.length > 0 ? llmModels : []).map((m) => ( +
- {model} + {m.name}
))} + {(!llmModels || llmModels.length === 0) && + VALID_MODELS.map((model) => ( + +
+ {model} +
+
+ ))}
)} diff --git a/src/app/w/[slug]/task/[...taskParams]/page.tsx b/src/app/w/[slug]/task/[...taskParams]/page.tsx index 826290d51d..150c537d4a 100644 --- a/src/app/w/[slug]/task/[...taskParams]/page.tsx +++ b/src/app/w/[slug]/task/[...taskParams]/page.tsx @@ -45,7 +45,7 @@ import { useStreamContext } from "@/hooks/useStreamContext"; import { FEATURE_FLAGS } from "@/lib/feature-flags"; import { useSession } from "next-auth/react"; import { WorkflowTransition, getStepType } from "@/types/stakwork/workflow"; -import type { ModelName } from "@/lib/ai/models"; +import { getModelValue, type LlmModelOption } from "@/lib/ai/models"; // Generate unique IDs to prevent collisions function generateUniqueId() { @@ -149,7 +149,8 @@ export default function TaskChatPage() { const [isSubsequentCommit, setIsSubsequentCommit] = useState(false); const [showPreview, setShowPreview] = useState(false); const [showBountyModal, setShowBountyModal] = useState(false); - const [selectedModel, setSelectedModel] = useState("sonnet"); + const [llmModels, setLlmModels] = useState([]); + const [selectedModel, setSelectedModel] = useState("sonnet"); const [isPrototypeTask, setIsPrototypeTask] = useState(false); const [isSavingPlan, setIsSavingPlan] = useState(false); const [isReconciling, setIsReconciling] = useState(false); @@ -222,6 +223,27 @@ export default function TaskChatPage() { onStreamStatusUpdate(update); }, [taskMode, onStreamStatusUpdate]); + // Fetch active LLM models for task mode selector + useEffect(() => { + const fetchLlmModels = async () => { + try { + const response = await fetch("/api/llm-models"); + if (response.ok) { + const data = await response.json(); + const models: LlmModelOption[] = data.models ?? []; + setLlmModels(models); + if (models.length > 0) { + const defaultModel = models.find((m) => m.isTaskDefault); + setSelectedModel(getModelValue(defaultModel ?? models[0])); + } + } + } catch (error) { + console.error("Error fetching LLM models:", error); + } + }; + fetchLlmModels(); + }, []); // eslint-disable-line react-hooks/exhaustive-deps + // When reconciliation polling returns a terminal state, persist it to the DB and update UI useEffect(() => { if (!reconcilingWorkflowData || !currentTaskId || !isReconciling) return; @@ -911,7 +933,7 @@ export default function TaskChatPage() { } }; - const handleStart = async (msg: string, model?: ModelName, autoMerge?: boolean, images?: File[], repositoryId?: string, branch?: string, runBuild?: boolean, runTestSuite?: boolean) => { + const handleStart = async (msg: string, model?: string, autoMerge?: boolean, images?: File[], repositoryId?: string, branch?: string, runBuild?: boolean, runTestSuite?: boolean) => { if (isLoading) return; // Prevent duplicate sends setIsLoading(true); @@ -1887,6 +1909,7 @@ Plan and implement the real feature from this branch.`; workflowsError={workflowsError} selectedModel={selectedModel} onModelChange={setSelectedModel} + llmModels={llmModels} /> ) : ( diff --git a/src/lib/ai/models.ts b/src/lib/ai/models.ts index 4f269d0192..79a9dc3715 100644 --- a/src/lib/ai/models.ts +++ b/src/lib/ai/models.ts @@ -5,9 +5,14 @@ */ // Valid model names that can be passed from frontend -export type ModelName = "sonnet" | "opus" | "haiku" | "kimi" | "gemini" | "gpt"; +// Using string & {} preserves autocomplete for known aliases while accepting full provider/name strings +export type ModelName = (string & {}); -export const VALID_MODELS: ModelName[] = ["sonnet", "opus", "haiku", "kimi", "gemini", "gpt"]; +/** + * @deprecated Use `isValidModel` instead. Callers checking VALID_MODELS.includes() should + * migrate to `isValidModel()` which accepts both legacy short aliases and full provider/name strings. + */ +export const VALID_MODELS: string[] = ["sonnet", "opus", "haiku", "kimi", "gemini", "gpt"]; // Map model names to their API key environment variables export const API_KEY_ENV_VARS: Record = { @@ -19,8 +24,8 @@ export const API_KEY_ENV_VARS: Record = { kimi: "OPENROUTER_API_KEY", }; -export function isValidModel(model: unknown): model is ModelName { - return typeof model === "string" && VALID_MODELS.includes(model as ModelName); +export function isValidModel(model: unknown): model is string { + return typeof model === "string" && model.trim().length > 0; } // Map LlmProvider enum values to their API key environment variables