Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f10c081

Browse files
committed
add WithCleanContext middleware func
1 parent e440e69 commit f10c081

File tree

3 files changed

+201
-82
lines changed

3 files changed

+201
-82
lines changed

cli/exp_mcp.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -713,8 +713,8 @@ func mcpFromSDK(sdkTool toolsdk.Tool[any, any], tb toolsdk.Deps) server.ServerTo
713713
Required: sdkTool.Schema.Required,
714714
},
715715
},
716-
Handler: func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
717-
result, err := sdkTool.Handler(tb, request.Params.Arguments)
716+
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
717+
result, err := sdkTool.Handler(ctx, tb, request.Params.Arguments)
718718
if err != nil {
719719
return nil, err
720720
}

codersdk/toolsdk/toolsdk.go

Lines changed: 69 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ type Deps struct {
2121
}
2222

2323
// HandlerFunc is a function that handles a tool call.
24-
type HandlerFunc[Arg, Ret any] func(tb Deps, args Arg) (Ret, error)
24+
type HandlerFunc[Arg, Ret any] func(context.Context, Deps, Arg) (Ret, error)
2525

2626
type Tool[Arg, Ret any] struct {
2727
aisdk.Tool
@@ -32,12 +32,12 @@ type Tool[Arg, Ret any] struct {
3232
func (t Tool[Arg, Ret]) Generic() Tool[any, any] {
3333
return Tool[any, any]{
3434
Tool: t.Tool,
35-
Handler: func(tb Deps, args any) (any, error) {
35+
Handler: func(ctx context.Context, tb Deps, args any) (any, error) {
3636
typedArg, ok := args.(Arg)
3737
if !ok {
3838
return nil, xerrors.Errorf("developer error: invalid argument type for tool %s", t.Tool.Name)
3939
}
40-
return t.Handler(tb, typedArg)
40+
return t.Handler(ctx, tb, typedArg)
4141
},
4242
}
4343
}
@@ -115,13 +115,41 @@ type UploadTarFileArgs struct {
115115

116116
// WithRecover wraps a HandlerFunc to recover from panics and return an error.
117117
func WithRecover[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] {
118-
return func(tb Deps, args Arg) (ret Ret, err error) {
118+
return func(ctx context.Context, tb Deps, args Arg) (ret Ret, err error) {
119119
defer func() {
120120
if r := recover(); r != nil {
121121
err = xerrors.Errorf("tool handler panic: %v", r)
122122
}
123123
}()
124-
return h(tb, args)
124+
return h(ctx, tb, args)
125+
}
126+
}
127+
128+
// WithCleanContext wraps a HandlerFunc to provide it with a new context.
129+
// This ensures that no data is passed using context.Value.
130+
// If a deadline is set on the parent context, it will be passed to the child
131+
// context.
132+
func WithCleanContext[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] {
133+
return func(parent context.Context, tb Deps, args Arg) (ret Ret, err error) {
134+
child, childCancel := context.WithCancel(context.Background())
135+
defer childCancel()
136+
// Ensure that cancellation propagates from the parent context to the child context.
137+
go func() {
138+
select {
139+
case <-child.Done():
140+
return
141+
case <-parent.Done():
142+
childCancel()
143+
}
144+
}()
145+
// Also ensure that the child context has the same deadline as the parent
146+
// context.
147+
if deadline, ok := parent.Deadline(); ok {
148+
deadlineCtx, deadlineCancel := context.WithDeadline(child, deadline)
149+
defer deadlineCancel()
150+
child = deadlineCtx
151+
}
152+
return h(child, tb, args)
125153
}
126154
}
127155

@@ -137,7 +165,7 @@ func wrapAll(mw func(HandlerFunc[any, any]) HandlerFunc[any, any], tools ...Tool
137165
var (
138166
// All is a list of all tools that can be used in the Coder CLI.
139167
// When you add a new tool, be sure to include it here!
140-
All = wrapAll(WithRecover,
168+
All = wrapAll(WithCleanContext, wrapAll(WithRecover,
141169
CreateTemplate.Generic(),
142170
CreateTemplateVersion.Generic(),
143171
CreateWorkspace.Generic(),
@@ -154,7 +182,7 @@ var (
154182
ReportTask.Generic(),
155183
UploadTarFile.Generic(),
156184
UpdateTemplateActiveVersion.Generic(),
157-
)
185+
)...)
158186

159187
ReportTask = Tool[ReportTaskArgs, string]{
160188
Tool: aisdk.Tool{
@@ -183,14 +211,14 @@ var (
183211
Required: []string{"summary", "link", "state"},
184212
},
185213
},
186-
Handler: func(tb Deps, args ReportTaskArgs) (string, error) {
214+
Handler: func(ctx context.Context, tb Deps, args ReportTaskArgs) (string, error) {
187215
if tb.AgentClient == nil {
188216
return "", xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set")
189217
}
190218
if tb.AppStatusSlug == "" {
191219
return "", xerrors.New("workspace app status slug not found in toolbox")
192220
}
193-
if err := tb.AgentClient.PatchAppStatus(context.TODO(), agentsdk.PatchAppStatus{
221+
if err := tb.AgentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
194222
AppSlug: tb.AppStatusSlug,
195223
Message: args.Summary,
196224
URI: args.Link,
@@ -217,12 +245,12 @@ This returns more data than list_workspaces to reduce token usage.`,
217245
Required: []string{"workspace_id"},
218246
},
219247
},
220-
Handler: func(tb Deps, args GetWorkspaceArgs) (codersdk.Workspace, error) {
248+
Handler: func(ctx context.Context, tb Deps, args GetWorkspaceArgs) (codersdk.Workspace, error) {
221249
wsID, err := uuid.Parse(args.WorkspaceID)
222250
if err != nil {
223251
return codersdk.Workspace{}, xerrors.New("workspace_id must be a valid UUID")
224252
}
225-
return tb.CoderClient.Workspace(context.TODO(), wsID)
253+
return tb.CoderClient.Workspace(ctx, wsID)
226254
},
227255
}
228256

@@ -257,7 +285,7 @@ is provisioned correctly and the agent can connect to the control plane.
257285
Required: []string{"user", "template_version_id", "name", "rich_parameters"},
258286
},
259287
},
260-
Handler: func(tb Deps, args CreateWorkspaceArgs) (codersdk.Workspace, error) {
288+
Handler: func(ctx context.Context, tb Deps, args CreateWorkspaceArgs) (codersdk.Workspace, error) {
261289
tvID, err := uuid.Parse(args.TemplateVersionID)
262290
if err != nil {
263291
return codersdk.Workspace{}, xerrors.New("template_version_id must be a valid UUID")
@@ -272,7 +300,7 @@ is provisioned correctly and the agent can connect to the control plane.
272300
Value: v,
273301
})
274302
}
275-
workspace, err := tb.CoderClient.CreateUserWorkspace(context.TODO(), args.User, codersdk.CreateWorkspaceRequest{
303+
workspace, err := tb.CoderClient.CreateUserWorkspace(ctx, args.User, codersdk.CreateWorkspaceRequest{
276304
TemplateVersionID: tvID,
277305
Name: args.Name,
278306
RichParameterValues: buildParams,
@@ -297,12 +325,12 @@ is provisioned correctly and the agent can connect to the control plane.
297325
},
298326
},
299327
},
300-
Handler: func(tb Deps, args ListWorkspacesArgs) ([]MinimalWorkspace, error) {
328+
Handler: func(ctx context.Context, tb Deps, args ListWorkspacesArgs) ([]MinimalWorkspace, error) {
301329
owner := args.Owner
302330
if owner == "" {
303331
owner = codersdk.Me
304332
}
305-
workspaces, err := tb.CoderClient.Workspaces(context.TODO(), codersdk.WorkspaceFilter{
333+
workspaces, err := tb.CoderClient.Workspaces(ctx, codersdk.WorkspaceFilter{
306334
Owner: owner,
307335
})
308336
if err != nil {
@@ -334,8 +362,8 @@ is provisioned correctly and the agent can connect to the control plane.
334362
Required: []string{},
335363
},
336364
},
337-
Handler: func(tb Deps, _ NoArgs) ([]MinimalTemplate, error) {
338-
templates, err := tb.CoderClient.Templates(context.TODO(), codersdk.TemplateFilter{})
365+
Handler: func(ctx context.Context, tb Deps, _ NoArgs) ([]MinimalTemplate, error) {
366+
templates, err := tb.CoderClient.Templates(ctx, codersdk.TemplateFilter{})
339367
if err != nil {
340368
return nil, err
341369
}
@@ -367,12 +395,12 @@ is provisioned correctly and the agent can connect to the control plane.
367395
Required: []string{"template_version_id"},
368396
},
369397
},
370-
Handler: func(tb Deps, args ListTemplateVersionParametersArgs) ([]codersdk.TemplateVersionParameter, error) {
398+
Handler: func(ctx context.Context, tb Deps, args ListTemplateVersionParametersArgs) ([]codersdk.TemplateVersionParameter, error) {
371399
templateVersionID, err := uuid.Parse(args.TemplateVersionID)
372400
if err != nil {
373401
return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err)
374402
}
375-
parameters, err := tb.CoderClient.TemplateVersionRichParameters(context.TODO(), templateVersionID)
403+
parameters, err := tb.CoderClient.TemplateVersionRichParameters(ctx, templateVersionID)
376404
if err != nil {
377405
return nil, err
378406
}
@@ -389,8 +417,8 @@ is provisioned correctly and the agent can connect to the control plane.
389417
Required: []string{},
390418
},
391419
},
392-
Handler: func(tb Deps, _ NoArgs) (codersdk.User, error) {
393-
return tb.CoderClient.User(context.TODO(), "me")
420+
Handler: func(ctx context.Context, tb Deps, _ NoArgs) (codersdk.User, error) {
421+
return tb.CoderClient.User(ctx, "me")
394422
},
395423
}
396424

@@ -416,7 +444,7 @@ is provisioned correctly and the agent can connect to the control plane.
416444
Required: []string{"workspace_id", "transition"},
417445
},
418446
},
419-
Handler: func(tb Deps, args CreateWorkspaceBuildArgs) (codersdk.WorkspaceBuild, error) {
447+
Handler: func(ctx context.Context, tb Deps, args CreateWorkspaceBuildArgs) (codersdk.WorkspaceBuild, error) {
420448
workspaceID, err := uuid.Parse(args.WorkspaceID)
421449
if err != nil {
422450
return codersdk.WorkspaceBuild{}, xerrors.Errorf("workspace_id must be a valid UUID: %w", err)
@@ -435,7 +463,7 @@ is provisioned correctly and the agent can connect to the control plane.
435463
if templateVersionID != uuid.Nil {
436464
cbr.TemplateVersionID = templateVersionID
437465
}
438-
return tb.CoderClient.CreateWorkspaceBuild(context.TODO(), workspaceID, cbr)
466+
return tb.CoderClient.CreateWorkspaceBuild(ctx, workspaceID, cbr)
439467
},
440468
}
441469

@@ -897,8 +925,8 @@ The file_id provided is a reference to a tar file you have uploaded containing t
897925
Required: []string{"file_id"},
898926
},
899927
},
900-
Handler: func(tb Deps, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) {
901-
me, err := tb.CoderClient.User(context.TODO(), "me")
928+
Handler: func(ctx context.Context, tb Deps, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) {
929+
me, err := tb.CoderClient.User(ctx, "me")
902930
if err != nil {
903931
return codersdk.TemplateVersion{}, err
904932
}
@@ -910,7 +938,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
910938
if err != nil {
911939
return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err)
912940
}
913-
templateVersion, err := tb.CoderClient.CreateTemplateVersion(context.TODO(), me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{
941+
templateVersion, err := tb.CoderClient.CreateTemplateVersion(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{
914942
Message: "Created by AI",
915943
StorageMethod: codersdk.ProvisionerStorageMethodFile,
916944
FileID: fileID,
@@ -939,12 +967,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t
939967
Required: []string{"workspace_agent_id"},
940968
},
941969
},
942-
Handler: func(tb Deps, args GetWorkspaceAgentLogsArgs) ([]string, error) {
970+
Handler: func(ctx context.Context, tb Deps, args GetWorkspaceAgentLogsArgs) ([]string, error) {
943971
workspaceAgentID, err := uuid.Parse(args.WorkspaceAgentID)
944972
if err != nil {
945973
return nil, xerrors.Errorf("workspace_agent_id must be a valid UUID: %w", err)
946974
}
947-
logs, closer, err := tb.CoderClient.WorkspaceAgentLogsAfter(context.TODO(), workspaceAgentID, 0, false)
975+
logs, closer, err := tb.CoderClient.WorkspaceAgentLogsAfter(ctx, workspaceAgentID, 0, false)
948976
if err != nil {
949977
return nil, err
950978
}
@@ -974,12 +1002,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t
9741002
Required: []string{"workspace_build_id"},
9751003
},
9761004
},
977-
Handler: func(tb Deps, args GetWorkspaceBuildLogsArgs) ([]string, error) {
1005+
Handler: func(ctx context.Context, tb Deps, args GetWorkspaceBuildLogsArgs) ([]string, error) {
9781006
workspaceBuildID, err := uuid.Parse(args.WorkspaceBuildID)
9791007
if err != nil {
9801008
return nil, xerrors.Errorf("workspace_build_id must be a valid UUID: %w", err)
9811009
}
982-
logs, closer, err := tb.CoderClient.WorkspaceBuildLogsAfter(context.TODO(), workspaceBuildID, 0)
1010+
logs, closer, err := tb.CoderClient.WorkspaceBuildLogsAfter(ctx, workspaceBuildID, 0)
9831011
if err != nil {
9841012
return nil, err
9851013
}
@@ -1005,13 +1033,13 @@ The file_id provided is a reference to a tar file you have uploaded containing t
10051033
Required: []string{"template_version_id"},
10061034
},
10071035
},
1008-
Handler: func(tb Deps, args GetTemplateVersionLogsArgs) ([]string, error) {
1036+
Handler: func(ctx context.Context, tb Deps, args GetTemplateVersionLogsArgs) ([]string, error) {
10091037
templateVersionID, err := uuid.Parse(args.TemplateVersionID)
10101038
if err != nil {
10111039
return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err)
10121040
}
10131041

1014-
logs, closer, err := tb.CoderClient.TemplateVersionLogsAfter(context.TODO(), templateVersionID, 0)
1042+
logs, closer, err := tb.CoderClient.TemplateVersionLogsAfter(ctx, templateVersionID, 0)
10151043
if err != nil {
10161044
return nil, err
10171045
}
@@ -1040,7 +1068,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
10401068
Required: []string{"template_id", "template_version_id"},
10411069
},
10421070
},
1043-
Handler: func(tb Deps, args UpdateTemplateActiveVersionArgs) (string, error) {
1071+
Handler: func(ctx context.Context, tb Deps, args UpdateTemplateActiveVersionArgs) (string, error) {
10441072
templateID, err := uuid.Parse(args.TemplateID)
10451073
if err != nil {
10461074
return "", xerrors.Errorf("template_id must be a valid UUID: %w", err)
@@ -1049,7 +1077,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
10491077
if err != nil {
10501078
return "", xerrors.Errorf("template_version_id must be a valid UUID: %w", err)
10511079
}
1052-
err = tb.CoderClient.UpdateActiveTemplateVersion(context.TODO(), templateID, codersdk.UpdateActiveTemplateVersion{
1080+
err = tb.CoderClient.UpdateActiveTemplateVersion(ctx, templateID, codersdk.UpdateActiveTemplateVersion{
10531081
ID: templateVersionID,
10541082
})
10551083
if err != nil {
@@ -1073,7 +1101,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
10731101
Required: []string{"mime_type", "files"},
10741102
},
10751103
},
1076-
Handler: func(tb Deps, args UploadTarFileArgs) (codersdk.UploadResponse, error) {
1104+
Handler: func(ctx context.Context, tb Deps, args UploadTarFileArgs) (codersdk.UploadResponse, error) {
10771105
pipeReader, pipeWriter := io.Pipe()
10781106
go func() {
10791107
defer pipeWriter.Close()
@@ -1098,7 +1126,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t
10981126
}
10991127
}()
11001128

1101-
resp, err := tb.CoderClient.Upload(context.TODO(), codersdk.ContentTypeTar, pipeReader)
1129+
resp, err := tb.CoderClient.Upload(ctx, codersdk.ContentTypeTar, pipeReader)
11021130
if err != nil {
11031131
return codersdk.UploadResponse{}, err
11041132
}
@@ -1133,16 +1161,16 @@ The file_id provided is a reference to a tar file you have uploaded containing t
11331161
Required: []string{"name", "display_name", "description", "version_id"},
11341162
},
11351163
},
1136-
Handler: func(tb Deps, args CreateTemplateArgs) (codersdk.Template, error) {
1137-
me, err := tb.CoderClient.User(context.TODO(), "me")
1164+
Handler: func(ctx context.Context, tb Deps, args CreateTemplateArgs) (codersdk.Template, error) {
1165+
me, err := tb.CoderClient.User(ctx, "me")
11381166
if err != nil {
11391167
return codersdk.Template{}, err
11401168
}
11411169
versionID, err := uuid.Parse(args.VersionID)
11421170
if err != nil {
11431171
return codersdk.Template{}, xerrors.Errorf("version_id must be a valid UUID: %w", err)
11441172
}
1145-
template, err := tb.CoderClient.CreateTemplate(context.TODO(), me.OrganizationIDs[0], codersdk.CreateTemplateRequest{
1173+
template, err := tb.CoderClient.CreateTemplate(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateRequest{
11461174
Name: args.Name,
11471175
DisplayName: args.DisplayName,
11481176
Description: args.Description,
@@ -1167,12 +1195,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t
11671195
},
11681196
},
11691197
},
1170-
Handler: func(tb Deps, args DeleteTemplateArgs) (string, error) {
1198+
Handler: func(ctx context.Context, tb Deps, args DeleteTemplateArgs) (string, error) {
11711199
templateID, err := uuid.Parse(args.TemplateID)
11721200
if err != nil {
11731201
return "", xerrors.Errorf("template_id must be a valid UUID: %w", err)
11741202
}
1175-
err = tb.CoderClient.DeleteTemplate(context.TODO(), templateID)
1203+
err = tb.CoderClient.DeleteTemplate(ctx, templateID)
11761204
if err != nil {
11771205
return "", err
11781206
}

0 commit comments

Comments
 (0)