Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 1235550

Browse files
johnstcnkylecarbs
andauthored
feat(codersdk): add toolsdk and replace existing mcp server tool impl (#17343)
- Refactors existing `mcp` package to use `kylecarbs/aisdk-go` and moves to `codersdk/toolsdk` package. - Updates existing MCP server implementation to use `codersdk/toolsdk` Co-authored-by: Kyle Carberry <[email protected]>
1 parent 2c573dc commit 1235550

File tree

9 files changed

+1774
-1096
lines changed

9 files changed

+1774
-1096
lines changed

cli/exp_mcp.go

+65-26
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,19 @@ import (
66
"errors"
77
"os"
88
"path/filepath"
9+
"slices"
910
"strings"
1011

12+
"github.com/mark3labs/mcp-go/mcp"
1113
"github.com/mark3labs/mcp-go/server"
1214
"github.com/spf13/afero"
1315
"golang.org/x/xerrors"
1416

15-
"cdr.dev/slog"
16-
"cdr.dev/slog/sloggers/sloghuman"
1717
"github.com/coder/coder/v2/buildinfo"
1818
"github.com/coder/coder/v2/cli/cliui"
1919
"github.com/coder/coder/v2/codersdk"
2020
"github.com/coder/coder/v2/codersdk/agentsdk"
21-
codermcp "github.com/coder/coder/v2/mcp"
21+
"github.com/coder/coder/v2/codersdk/toolsdk"
2222
"github.com/coder/serpent"
2323
)
2424

@@ -365,6 +365,8 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct
365365
ctx, cancel := context.WithCancel(inv.Context())
366366
defer cancel()
367367

368+
fs := afero.NewOsFs()
369+
368370
me, err := client.User(ctx, codersdk.Me)
369371
if err != nil {
370372
cliui.Errorf(inv.Stderr, "Failed to log in to the Coder deployment.")
@@ -397,40 +399,36 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct
397399
server.WithInstructions(instructions),
398400
)
399401

400-
// Create a separate logger for the tools.
401-
toolLogger := slog.Make(sloghuman.Sink(invStderr))
402-
403-
toolDeps := codermcp.ToolDeps{
404-
Client: client,
405-
Logger: &toolLogger,
406-
AppStatusSlug: appStatusSlug,
407-
AgentClient: agentsdk.New(client.URL),
408-
}
409-
402+
// Create a new context for the tools with all relevant information.
403+
clientCtx := toolsdk.WithClient(ctx, client)
410404
// Get the workspace agent token from the environment.
411-
agentToken, ok := os.LookupEnv("CODER_AGENT_TOKEN")
412-
if ok && agentToken != "" {
413-
toolDeps.AgentClient.SetSessionToken(agentToken)
405+
if agentToken, err := getAgentToken(fs); err == nil && agentToken != "" {
406+
agentClient := agentsdk.New(client.URL)
407+
agentClient.SetSessionToken(agentToken)
408+
clientCtx = toolsdk.WithAgentClient(clientCtx, agentClient)
414409
} else {
415410
cliui.Warnf(inv.Stderr, "CODER_AGENT_TOKEN is not set, task reporting will not be available")
416411
}
417-
if appStatusSlug == "" {
412+
if appStatusSlug != "" {
418413
cliui.Warnf(inv.Stderr, "CODER_MCP_APP_STATUS_SLUG is not set, task reporting will not be available.")
414+
} else {
415+
clientCtx = toolsdk.WithWorkspaceAppStatusSlug(clientCtx, appStatusSlug)
419416
}
420417

421418
// Register tools based on the allowlist (if specified)
422-
reg := codermcp.AllTools()
423-
if len(allowedTools) > 0 {
424-
reg = reg.WithOnlyAllowed(allowedTools...)
419+
for _, tool := range toolsdk.All {
420+
if len(allowedTools) == 0 || slices.ContainsFunc(allowedTools, func(t string) bool {
421+
return t == tool.Tool.Name
422+
}) {
423+
mcpSrv.AddTools(mcpFromSDK(tool))
424+
}
425425
}
426426

427-
reg.Register(mcpSrv, toolDeps)
428-
429427
srv := server.NewStdioServer(mcpSrv)
430428
done := make(chan error)
431429
go func() {
432430
defer close(done)
433-
srvErr := srv.Listen(ctx, invStdin, invStdout)
431+
srvErr := srv.Listen(clientCtx, invStdin, invStdout)
434432
done <- srvErr
435433
}()
436434

@@ -527,8 +525,8 @@ func configureClaude(fs afero.Fs, cfg ClaudeConfig) error {
527525
if !ok {
528526
mcpServers = make(map[string]any)
529527
}
530-
for name, mcp := range cfg.MCPServers {
531-
mcpServers[name] = mcp
528+
for name, cfgmcp := range cfg.MCPServers {
529+
mcpServers[name] = cfgmcp
532530
}
533531
project["mcpServers"] = mcpServers
534532
// Prevents Claude from asking the user to complete the project onboarding.
@@ -674,7 +672,7 @@ func indexOf(s, substr string) int {
674672

675673
func getAgentToken(fs afero.Fs) (string, error) {
676674
token, ok := os.LookupEnv("CODER_AGENT_TOKEN")
677-
if ok {
675+
if ok && token != "" {
678676
return token, nil
679677
}
680678
tokenFile, ok := os.LookupEnv("CODER_AGENT_TOKEN_FILE")
@@ -687,3 +685,44 @@ func getAgentToken(fs afero.Fs) (string, error) {
687685
}
688686
return string(bs), nil
689687
}
688+
689+
// mcpFromSDK adapts a toolsdk.Tool to go-mcp's server.ServerTool.
690+
// It assumes that the tool responds with a valid JSON object.
691+
func mcpFromSDK(sdkTool toolsdk.Tool[any]) server.ServerTool {
692+
return server.ServerTool{
693+
Tool: mcp.Tool{
694+
Name: sdkTool.Tool.Name,
695+
Description: sdkTool.Description,
696+
InputSchema: mcp.ToolInputSchema{
697+
Type: "object", // Default of mcp.NewTool()
698+
Properties: sdkTool.Schema.Properties,
699+
Required: sdkTool.Schema.Required,
700+
},
701+
},
702+
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
703+
result, err := sdkTool.Handler(ctx, request.Params.Arguments)
704+
if err != nil {
705+
return nil, err
706+
}
707+
var sb strings.Builder
708+
if err := json.NewEncoder(&sb).Encode(result); err == nil {
709+
return &mcp.CallToolResult{
710+
Content: []mcp.Content{
711+
mcp.NewTextContent(sb.String()),
712+
},
713+
}, nil
714+
}
715+
// If the result is not JSON, return it as a string.
716+
// This is a fallback for tools that return non-JSON data.
717+
resultStr, ok := result.(string)
718+
if !ok {
719+
return nil, xerrors.Errorf("tool call result is neither valid JSON or a string, got: %T", result)
720+
}
721+
return &mcp.CallToolResult{
722+
Content: []mcp.Content{
723+
mcp.NewTextContent(resultStr),
724+
},
725+
}, nil
726+
},
727+
}
728+
}

cli/exp_mcp_test.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,13 @@ func TestExpMcpServer(t *testing.T) {
3939
_ = coderdtest.CreateFirstUser(t, client)
4040

4141
// Given: we run the exp mcp command with allowed tools set
42-
inv, root := clitest.New(t, "exp", "mcp", "server", "--allowed-tools=coder_whoami,coder_list_templates")
42+
inv, root := clitest.New(t, "exp", "mcp", "server", "--allowed-tools=coder_get_authenticated_user")
4343
inv = inv.WithContext(cancelCtx)
4444

4545
pty := ptytest.New(t)
4646
inv.Stdin = pty.Input()
4747
inv.Stdout = pty.Output()
48+
// nolint: gocritic // not the focus of this test
4849
clitest.SetupConfig(t, client, root)
4950

5051
cmdDone := make(chan struct{})
@@ -73,13 +74,13 @@ func TestExpMcpServer(t *testing.T) {
7374
}
7475
err := json.Unmarshal([]byte(output), &toolsResponse)
7576
require.NoError(t, err)
76-
require.Len(t, toolsResponse.Result.Tools, 2, "should have exactly 2 tools")
77+
require.Len(t, toolsResponse.Result.Tools, 1, "should have exactly 1 tool")
7778
foundTools := make([]string, 0, 2)
7879
for _, tool := range toolsResponse.Result.Tools {
7980
foundTools = append(foundTools, tool.Name)
8081
}
8182
slices.Sort(foundTools)
82-
require.Equal(t, []string{"coder_list_templates", "coder_whoami"}, foundTools)
83+
require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools)
8384
})
8485

8586
t.Run("OK", func(t *testing.T) {

coderd/database/dbfake/dbfake.go

+29-22
Original file line numberDiff line numberDiff line change
@@ -287,23 +287,25 @@ type TemplateVersionResponse struct {
287287
}
288288

289289
type TemplateVersionBuilder struct {
290-
t testing.TB
291-
db database.Store
292-
seed database.TemplateVersion
293-
fileID uuid.UUID
294-
ps pubsub.Pubsub
295-
resources []*sdkproto.Resource
296-
params []database.TemplateVersionParameter
297-
promote bool
290+
t testing.TB
291+
db database.Store
292+
seed database.TemplateVersion
293+
fileID uuid.UUID
294+
ps pubsub.Pubsub
295+
resources []*sdkproto.Resource
296+
params []database.TemplateVersionParameter
297+
promote bool
298+
autoCreateTemplate bool
298299
}
299300

300301
// TemplateVersion generates a template version and optionally a parent
301302
// template if no template ID is set on the seed.
302303
func TemplateVersion(t testing.TB, db database.Store) TemplateVersionBuilder {
303304
return TemplateVersionBuilder{
304-
t: t,
305-
db: db,
306-
promote: true,
305+
t: t,
306+
db: db,
307+
promote: true,
308+
autoCreateTemplate: true,
307309
}
308310
}
309311

@@ -337,6 +339,13 @@ func (t TemplateVersionBuilder) Params(ps ...database.TemplateVersionParameter)
337339
return t
338340
}
339341

342+
func (t TemplateVersionBuilder) SkipCreateTemplate() TemplateVersionBuilder {
343+
// nolint: revive // returns modified struct
344+
t.autoCreateTemplate = false
345+
t.promote = false
346+
return t
347+
}
348+
340349
func (t TemplateVersionBuilder) Do() TemplateVersionResponse {
341350
t.t.Helper()
342351

@@ -347,7 +356,7 @@ func (t TemplateVersionBuilder) Do() TemplateVersionResponse {
347356
t.fileID = takeFirst(t.fileID, uuid.New())
348357

349358
var resp TemplateVersionResponse
350-
if t.seed.TemplateID.UUID == uuid.Nil {
359+
if t.seed.TemplateID.UUID == uuid.Nil && t.autoCreateTemplate {
351360
resp.Template = dbgen.Template(t.t, t.db, database.Template{
352361
ActiveVersionID: t.seed.ID,
353362
OrganizationID: t.seed.OrganizationID,
@@ -360,16 +369,14 @@ func (t TemplateVersionBuilder) Do() TemplateVersionResponse {
360369
}
361370

362371
version := dbgen.TemplateVersion(t.t, t.db, t.seed)
363-
364-
// Always make this version the active version. We can easily
365-
// add a conditional to the builder to opt out of this when
366-
// necessary.
367-
err := t.db.UpdateTemplateActiveVersionByID(ownerCtx, database.UpdateTemplateActiveVersionByIDParams{
368-
ID: t.seed.TemplateID.UUID,
369-
ActiveVersionID: t.seed.ID,
370-
UpdatedAt: dbtime.Now(),
371-
})
372-
require.NoError(t.t, err)
372+
if t.promote {
373+
err := t.db.UpdateTemplateActiveVersionByID(ownerCtx, database.UpdateTemplateActiveVersionByIDParams{
374+
ID: t.seed.TemplateID.UUID,
375+
ActiveVersionID: t.seed.ID,
376+
UpdatedAt: dbtime.Now(),
377+
})
378+
require.NoError(t.t, err)
379+
}
373380

374381
payload, err := json.Marshal(provisionerdserver.TemplateVersionImportJob{
375382
TemplateVersionID: t.seed.ID,

0 commit comments

Comments
 (0)