Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 1092093

Browse files
authored
feat: add internal subagent model override wiring (#24399)
> Mux working on behalf of Mike. ## Summary - add an enabled chat model config lookup by ID for internal callers - keep `spawn_agent` unchanged while threading an internal model override through child subagent chat creation - extend chatd coverage for inherited bindings, plan mode, and internal override behavior ## Validation - `go test ./coderd/x/chatd ./coderd/database/dbauthz` - `make lint`
1 parent eae9444 commit 1092093

9 files changed

Lines changed: 256 additions & 9 deletions

File tree

coderd/database/dbauthz/dbauthz.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3026,6 +3026,13 @@ func (q *querier) GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx context.C
30263026
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetEligibleProvisionerDaemonsByProvisionerJobIDs)(ctx, provisionerJobIDs)
30273027
}
30283028

3029+
func (q *querier) GetEnabledChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
3030+
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
3031+
return database.ChatModelConfig{}, err
3032+
}
3033+
return q.db.GetEnabledChatModelConfigByID(ctx, id)
3034+
}
3035+
30293036
func (q *querier) GetEnabledChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
30303037
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
30313038
return nil, err

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,11 @@ func (s *MethodTestSuite) TestChats() {
846846
dbm.EXPECT().GetChatWorkspaceTTL(gomock.Any()).Return("1h", nil).AnyTimes()
847847
check.Args().Asserts()
848848
}))
849+
s.Run("GetEnabledChatModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
850+
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
851+
dbm.EXPECT().GetEnabledChatModelConfigByID(gomock.Any(), config.ID).Return(config, nil).AnyTimes()
852+
check.Args(config.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
853+
}))
849854
s.Run("GetEnabledChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
850855
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
851856
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})

coderd/database/dbmetrics/querymetrics.go

Lines changed: 8 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier.go

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 39 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/chatmodelconfigs.sql

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,21 @@ ORDER BY
4646
cmc.updated_at DESC,
4747
cmc.id DESC;
4848

49+
-- name: GetEnabledChatModelConfigByID :one
50+
SELECT
51+
cmc.*
52+
FROM
53+
chat_model_configs cmc
54+
-- Providers can be disabled independently of their model configs.
55+
-- Check both to ensure the selected config is actually usable.
56+
JOIN
57+
chat_providers cp ON cp.provider = cmc.provider
58+
WHERE
59+
cmc.id = @id::uuid
60+
AND cmc.deleted = FALSE
61+
AND cmc.enabled = TRUE
62+
AND cp.enabled = TRUE;
63+
4964
-- name: InsertChatModelConfig :one
5065
INSERT INTO chat_model_configs (
5166
provider,

coderd/x/chatd/subagent.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -411,8 +411,9 @@ func parseSubagentToolChatID(raw string) (uuid.UUID, error) {
411411
}
412412

413413
type childSubagentChatOptions struct {
414-
chatMode database.NullChatMode
415-
systemPrompt string
414+
chatMode database.NullChatMode
415+
systemPrompt string
416+
modelConfigIDOverride *uuid.UUID
416417
}
417418

418419
func (p *Server) createChildSubagentChat(
@@ -449,8 +450,13 @@ func (p *Server) createChildSubagentChatWithOptions(
449450
if parent.RootChatID.Valid {
450451
rootChatID = parent.RootChatID.UUID
451452
}
452-
if parent.LastModelConfigID == uuid.Nil {
453-
return database.Chat{}, xerrors.New("parent chat model config id is required")
453+
454+
modelConfigID := parent.LastModelConfigID
455+
if opts.modelConfigIDOverride != nil {
456+
modelConfigID = *opts.modelConfigIDOverride
457+
}
458+
if modelConfigID == uuid.Nil {
459+
return database.Chat{}, xerrors.New("model config is required")
454460
}
455461

456462
mcpServerIDs := parent.MCPServerIDs
@@ -482,7 +488,7 @@ func (p *Server) createChildSubagentChatWithOptions(
482488
AgentID: parent.AgentID,
483489
ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
484490
RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true},
485-
LastModelConfigID: parent.LastModelConfigID,
491+
LastModelConfigID: modelConfigID,
486492
Title: title,
487493
Mode: opts.chatMode,
488494
PlanMode: parent.PlanMode,
@@ -528,7 +534,7 @@ func (p *Server) createChildSubagentChatWithOptions(
528534
database.ChatMessageRoleSystem,
529535
deploymentContent,
530536
database.ChatMessageVisibilityModel,
531-
parent.LastModelConfigID,
537+
modelConfigID,
532538
chatprompt.CurrentContentVersion,
533539
))
534540
}
@@ -543,15 +549,15 @@ func (p *Server) createChildSubagentChatWithOptions(
543549
database.ChatMessageRoleSystem,
544550
childSystemPromptContent,
545551
database.ChatMessageVisibilityModel,
546-
parent.LastModelConfigID,
552+
modelConfigID,
547553
chatprompt.CurrentContentVersion,
548554
))
549555
}
550556
appendChatMessage(&systemParams, newChatMessage(
551557
database.ChatMessageRoleSystem,
552558
workspaceAwarenessContent,
553559
database.ChatMessageVisibilityModel,
554-
parent.LastModelConfigID,
560+
modelConfigID,
555561
chatprompt.CurrentContentVersion,
556562
))
557563
if _, err := tx.InsertChatMessages(ctx, systemParams); err != nil {
@@ -578,7 +584,7 @@ func (p *Server) createChildSubagentChatWithOptions(
578584
database.ChatMessageRoleUser,
579585
userContent,
580586
database.ChatMessageVisibilityBoth,
581-
parent.LastModelConfigID,
587+
modelConfigID,
582588
chatprompt.CurrentContentVersion,
583589
).withCreatedBy(parent.OwnerID))
584590
if _, err := tx.InsertChatMessages(ctx, userParams); err != nil {

coderd/x/chatd/subagent_internal_test.go

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,33 @@ func seedInternalChatDeps(
156156
return user, org, model
157157
}
158158

159+
func insertInternalChatModelConfig(
160+
ctx context.Context,
161+
t *testing.T,
162+
db database.Store,
163+
userID uuid.UUID,
164+
model string,
165+
enabled bool,
166+
) database.ChatModelConfig {
167+
t.Helper()
168+
169+
modelConfig, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
170+
Provider: "openai",
171+
Model: model,
172+
DisplayName: model,
173+
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
174+
UpdatedBy: uuid.NullUUID{UUID: userID, Valid: true},
175+
Enabled: enabled,
176+
IsDefault: false,
177+
ContextLimit: 128000,
178+
CompressionThreshold: 70,
179+
Options: json.RawMessage(`{}`),
180+
})
181+
require.NoError(t, err)
182+
183+
return modelConfig
184+
}
185+
159186
func seedWorkspaceBinding(
160187
t *testing.T,
161188
db database.Store,
@@ -256,6 +283,74 @@ func TestCreateChildSubagentChatInheritsWorkspaceBinding(t *testing.T) {
256283
require.Equal(t, parentChat.AgentID, childChat.AgentID)
257284
}
258285

286+
func createInternalParentChat(
287+
ctx context.Context,
288+
t *testing.T,
289+
server *Server,
290+
db database.Store,
291+
orgID uuid.UUID,
292+
userID uuid.UUID,
293+
modelConfigID uuid.UUID,
294+
title string,
295+
) database.Chat {
296+
t.Helper()
297+
298+
parent, err := server.CreateChat(ctx, CreateOptions{
299+
OrganizationID: orgID,
300+
OwnerID: userID,
301+
Title: title,
302+
ModelConfigID: modelConfigID,
303+
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
304+
})
305+
require.NoError(t, err)
306+
307+
parentChat, err := db.GetChatByID(ctx, parent.ID)
308+
require.NoError(t, err)
309+
310+
return parentChat
311+
}
312+
313+
func runSpawnAgentTool(
314+
ctx context.Context,
315+
t *testing.T,
316+
server *Server,
317+
parentChat database.Chat,
318+
args spawnAgentArgs,
319+
) fantasy.ToolResponse {
320+
t.Helper()
321+
322+
tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
323+
tool := findToolByName(tools, "spawn_agent")
324+
require.NotNil(t, tool, "spawn_agent tool must be present")
325+
326+
input, err := json.Marshal(args)
327+
require.NoError(t, err)
328+
329+
resp, err := tool.Run(ctx, fantasy.ToolCall{
330+
ID: uuid.NewString(),
331+
Name: "spawn_agent",
332+
Input: string(input),
333+
})
334+
require.NoError(t, err)
335+
336+
return resp
337+
}
338+
339+
func requireSpawnAgentChildChatID(t *testing.T, resp fantasy.ToolResponse) uuid.UUID {
340+
t.Helper()
341+
require.False(t, resp.IsError, "expected success but got: %s", resp.Content)
342+
343+
var result struct {
344+
ChatID string `json:"chat_id"`
345+
}
346+
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
347+
require.NotEmpty(t, result.ChatID, "response must contain chat_id")
348+
349+
childID, err := uuid.Parse(result.ChatID)
350+
require.NoError(t, err)
351+
return childID
352+
}
353+
259354
func TestCreateChildSubagentChatCopiesPlanMode(t *testing.T) {
260355
t.Parallel()
261356

@@ -293,6 +388,60 @@ func TestCreateChildSubagentChatCopiesPlanMode(t *testing.T) {
293388
require.Equal(t, planMode, childChat.PlanMode)
294389
}
295390

391+
func TestSpawnAgent_InheritsParentModelWhenOmitted(t *testing.T) {
392+
t.Parallel()
393+
394+
db, ps := dbtestutil.NewDB(t)
395+
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
396+
397+
ctx := chatdTestContext(t)
398+
user, org, model := seedInternalChatDeps(ctx, t, db)
399+
parentChat := createInternalParentChat(
400+
ctx, t, server, db, org.ID, user.ID, model.ID, "parent-inherited-model",
401+
)
402+
403+
resp := runSpawnAgentTool(ctx, t, server, parentChat, spawnAgentArgs{
404+
Prompt: "delegate work",
405+
})
406+
childID := requireSpawnAgentChildChatID(t, resp)
407+
408+
childChat, err := db.GetChatByID(ctx, childID)
409+
require.NoError(t, err)
410+
require.Equal(t, parentChat.LastModelConfigID, childChat.LastModelConfigID)
411+
}
412+
413+
func TestCreateChildSubagentChat_OverrideWorksWhenParentHasNoModel(t *testing.T) {
414+
t.Parallel()
415+
416+
db, ps := dbtestutil.NewDB(t)
417+
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
418+
419+
ctx := chatdTestContext(t)
420+
user, org, model := seedInternalChatDeps(ctx, t, db)
421+
overrideModel := insertInternalChatModelConfig(
422+
ctx, t, db, user.ID, "override-no-parent-model-"+uuid.NewString(), true,
423+
)
424+
parentChat := createInternalParentChat(
425+
ctx, t, server, db, org.ID, user.ID, model.ID, "parent-no-model",
426+
)
427+
428+
// The chats table enforces a foreign key for last_model_config_id, so
429+
// use a synthetic parent value here to exercise the override path.
430+
parentChat.LastModelConfigID = uuid.Nil
431+
child, err := server.createChildSubagentChatWithOptions(
432+
ctx,
433+
parentChat,
434+
"delegate work",
435+
"",
436+
childSubagentChatOptions{modelConfigIDOverride: &overrideModel.ID},
437+
)
438+
require.NoError(t, err)
439+
440+
childChat, err := db.GetChatByID(ctx, child.ID)
441+
require.NoError(t, err)
442+
require.Equal(t, overrideModel.ID, childChat.LastModelConfigID)
443+
}
444+
296445
func TestSpawnComputerUseAgent_NoAnthropicProvider(t *testing.T) {
297446
t.Parallel()
298447

0 commit comments

Comments
 (0)