Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 035ab2c

Browse files
committed
improve argument handling by abusing json.Marshal/Unmarshal
1 parent 3d810c0 commit 035ab2c

File tree

1 file changed

+67
-68
lines changed

1 file changed

+67
-68
lines changed

mcp/tools/tools_coder.go

+67-68
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ import (
2020
"github.com/coder/coder/v2/codersdk/workspacesdk"
2121
)
2222

23+
type handleCoderReportTaskArgs struct {
24+
Summary string `json:"summary"`
25+
Link string `json:"link"`
26+
Emoji string `json:"emoji"`
27+
Done bool `json:"done"`
28+
}
29+
2330
// Example payload:
2431
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_report_task", "arguments": {"summary": "I'm working on the login page.", "link": "https://github.com/coder/coder/pull/1234", "emoji": "🔍", "done": false}}}
2532
func handleCoderReportTask(deps ToolDeps) mcpserver.ToolHandlerFunc {
@@ -28,30 +35,15 @@ func handleCoderReportTask(deps ToolDeps) mcpserver.ToolHandlerFunc {
2835
return nil, xerrors.New("developer error: client is required")
2936
}
3037

31-
args := request.Params.Arguments
32-
33-
summary, ok := args["summary"].(string)
34-
if !ok {
35-
return nil, xerrors.New("summary is required")
36-
}
37-
38-
link, ok := args["link"].(string)
39-
if !ok {
40-
return nil, xerrors.New("link is required")
41-
}
42-
43-
emoji, ok := args["emoji"].(string)
44-
if !ok {
45-
return nil, xerrors.New("emoji is required")
46-
}
47-
48-
done, ok := args["done"].(bool)
49-
if !ok {
50-
return nil, xerrors.New("done is required")
38+
// Convert the request parameters to a json.RawMessage so we can unmarshal
39+
// them into the correct struct.
40+
args, err := unmarshalArgs[handleCoderReportTaskArgs](request.Params.Arguments)
41+
if err != nil {
42+
return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err)
5143
}
5244

5345
// TODO: Waiting on support for tasks.
54-
deps.Logger.Info(ctx, "report task tool called", slog.F("summary", summary), slog.F("link", link), slog.F("done", done), slog.F("emoji", emoji))
46+
deps.Logger.Info(ctx, "report task tool called", slog.F("summary", args.Summary), slog.F("link", args.Link), slog.F("done", args.Done), slog.F("emoji", args.Emoji))
5547
/*
5648
err := sdk.PostTask(ctx, agentsdk.PostTaskRequest{
5749
Reporter: "claude",
@@ -98,33 +90,28 @@ func handleCoderWhoami(deps ToolDeps) mcpserver.ToolHandlerFunc {
9890
}
9991
}
10092

93+
type handleCoderListWorkspacesArgs struct {
94+
Owner string `json:"owner"`
95+
Offset int `json:"offset"`
96+
Limit int `json:"limit"`
97+
}
98+
10199
// Example payload:
102100
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_list_workspaces", "arguments": {"owner": "me", "offset": 0, "limit": 10}}}
103101
func handleCoderListWorkspaces(deps ToolDeps) mcpserver.ToolHandlerFunc {
104102
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
105103
if deps.Client == nil {
106104
return nil, xerrors.New("developer error: client is required")
107105
}
108-
args := request.Params.Arguments
109-
110-
owner, ok := args["owner"].(string)
111-
if !ok {
112-
owner = codersdk.Me
113-
}
114-
115-
offset, ok := args["offset"].(int)
116-
if !ok || offset < 0 {
117-
offset = 0
118-
}
119-
limit, ok := args["limit"].(int)
120-
if !ok || limit <= 0 {
121-
limit = 10
106+
args, err := unmarshalArgs[handleCoderListWorkspacesArgs](request.Params.Arguments)
107+
if err != nil {
108+
return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err)
122109
}
123110

124111
workspaces, err := deps.Client.Workspaces(ctx, codersdk.WorkspaceFilter{
125-
Owner: owner,
126-
Offset: offset,
127-
Limit: limit,
112+
Owner: args.Owner,
113+
Offset: args.Offset,
114+
Limit: args.Limit,
128115
})
129116
if err != nil {
130117
return nil, xerrors.Errorf("failed to fetch workspaces: %w", err)
@@ -144,21 +131,23 @@ func handleCoderListWorkspaces(deps ToolDeps) mcpserver.ToolHandlerFunc {
144131
}
145132
}
146133

134+
type handleCoderGetWorkspaceArgs struct {
135+
Workspace string `json:"workspace"`
136+
}
137+
147138
// Example payload:
148139
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_get_workspace", "arguments": {"workspace": "dev"}}}
149140
func handleCoderGetWorkspace(deps ToolDeps) mcpserver.ToolHandlerFunc {
150141
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
151142
if deps.Client == nil {
152143
return nil, xerrors.New("developer error: client is required")
153144
}
154-
args := request.Params.Arguments
155-
156-
wsArg, ok := args["workspace"].(string)
157-
if !ok {
158-
return nil, xerrors.New("workspace is required")
145+
args, err := unmarshalArgs[handleCoderGetWorkspaceArgs](request.Params.Arguments)
146+
if err != nil {
147+
return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err)
159148
}
160149

161-
workspace, err := getWorkspaceByIDOrOwnerName(ctx, deps.Client, wsArg)
150+
workspace, err := getWorkspaceByIDOrOwnerName(ctx, deps.Client, args.Workspace)
162151
if err != nil {
163152
return nil, xerrors.Errorf("failed to fetch workspace: %w", err)
164153
}
@@ -176,28 +165,26 @@ func handleCoderGetWorkspace(deps ToolDeps) mcpserver.ToolHandlerFunc {
176165
}
177166
}
178167

168+
type handleCoderWorkspaceExecArgs struct {
169+
Workspace string `json:"workspace"`
170+
Command string `json:"command"`
171+
}
172+
179173
// Example payload:
180174
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_workspace_exec", "arguments": {"workspace": "dev", "command": "ps -ef"}}}
181175
func handleCoderWorkspaceExec(deps ToolDeps) mcpserver.ToolHandlerFunc {
182176
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
183177
if deps.Client == nil {
184178
return nil, xerrors.New("developer error: client is required")
185179
}
186-
args := request.Params.Arguments
187-
188-
wsArg, ok := args["workspace"].(string)
189-
if !ok {
190-
return nil, xerrors.New("workspace is required")
191-
}
192-
193-
command, ok := args["command"].(string)
194-
if !ok {
195-
return nil, xerrors.New("command is required")
180+
args, err := unmarshalArgs[handleCoderWorkspaceExecArgs](request.Params.Arguments)
181+
if err != nil {
182+
return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err)
196183
}
197184

198185
// Attempt to fetch the workspace. We may get a UUID or a name, so try to
199186
// handle both.
200-
ws, err := getWorkspaceByIDOrOwnerName(ctx, deps.Client, wsArg)
187+
ws, err := getWorkspaceByIDOrOwnerName(ctx, deps.Client, args.Workspace)
201188
if err != nil {
202189
return nil, xerrors.Errorf("failed to fetch workspace: %w", err)
203190
}
@@ -224,7 +211,7 @@ func handleCoderWorkspaceExec(deps ToolDeps) mcpserver.ToolHandlerFunc {
224211
Reconnect: uuid.New(),
225212
Width: 80,
226213
Height: 24,
227-
Command: command,
214+
Command: args.Command,
228215
BackendType: "buffered", // the screen backend is annoying to use here.
229216
})
230217
if err != nil {
@@ -288,6 +275,11 @@ func handleCoderListTemplates(deps ToolDeps) mcpserver.ToolHandlerFunc {
288275
}
289276
}
290277

278+
type handleCoderWorkspaceTransitionArgs struct {
279+
Workspace string `json:"workspace"`
280+
Transition string `json:"transition"`
281+
}
282+
291283
// Example payload:
292284
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name":
293285
// "coder_workspace_transition", "arguments": {"workspace": "dev", "transition": "stop"}}}
@@ -296,24 +288,17 @@ func handleCoderWorkspaceTransition(deps ToolDeps) mcpserver.ToolHandlerFunc {
296288
if deps.Client == nil {
297289
return nil, xerrors.New("developer error: client is required")
298290
}
299-
300-
args := request.Params.Arguments
301-
302-
wsArg, ok := args["workspace"].(string)
303-
if !ok {
304-
return nil, xerrors.New("workspace is required")
291+
args, err := unmarshalArgs[handleCoderWorkspaceTransitionArgs](request.Params.Arguments)
292+
if err != nil {
293+
return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err)
305294
}
306295

307-
workspace, err := getWorkspaceByIDOrOwnerName(ctx, deps.Client, wsArg)
296+
workspace, err := getWorkspaceByIDOrOwnerName(ctx, deps.Client, args.Workspace)
308297
if err != nil {
309298
return nil, xerrors.Errorf("failed to fetch workspace: %w", err)
310299
}
311300

312-
transition, ok := args["transition"].(string)
313-
if !ok {
314-
return nil, xerrors.New("transition is required")
315-
}
316-
wsTransition := codersdk.WorkspaceTransition(transition)
301+
wsTransition := codersdk.WorkspaceTransition(args.Transition)
317302
switch wsTransition {
318303
case codersdk.WorkspaceTransitionStart:
319304
case codersdk.WorkspaceTransitionStop:
@@ -350,3 +335,17 @@ func getWorkspaceByIDOrOwnerName(ctx context.Context, client *codersdk.Client, i
350335
}
351336
return client.WorkspaceByOwnerAndName(ctx, codersdk.Me, identifier, codersdk.WorkspaceOptions{})
352337
}
338+
339+
// unmarshalArgs is a helper function to convert the map[string]any we get from
340+
// the MCP server into a typed struct. It does this by marshaling and unmarshalling
341+
// the arguments.
342+
func unmarshalArgs[T any](args map[string]interface{}) (t T, err error) {
343+
argsJSON, err := json.Marshal(args)
344+
if err != nil {
345+
return t, xerrors.Errorf("failed to marshal arguments: %w", err)
346+
}
347+
if err := json.Unmarshal(argsJSON, &t); err != nil {
348+
return t, xerrors.Errorf("failed to unmarshal arguments: %w", err)
349+
}
350+
return t, nil
351+
}

0 commit comments

Comments
 (0)