From 4d2dbf0de43a380ac1c7b516b57fd044c69106af Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 24 Apr 2025 13:15:44 +0100 Subject: [PATCH 01/11] chore(codersdk/toolsdk): add typed argument to toolsdk.Tool --- cli/exp_mcp.go | 2 +- coderd/workspaceagents.go | 24 +++ codersdk/toolsdk/toolsdk.go | 349 ++++++++++++++++--------------- codersdk/toolsdk/toolsdk_test.go | 130 ++++++------ 4 files changed, 274 insertions(+), 231 deletions(-) diff --git a/cli/exp_mcp.go b/cli/exp_mcp.go index 63ee0db04b552..78c5130f6c7f3 100644 --- a/cli/exp_mcp.go +++ b/cli/exp_mcp.go @@ -695,7 +695,7 @@ func getAgentToken(fs afero.Fs) (string, error) { // mcpFromSDK adapts a toolsdk.Tool to go-mcp's server.ServerTool. // It assumes that the tool responds with a valid JSON object. -func mcpFromSDK(sdkTool toolsdk.Tool[any]) server.ServerTool { +func mcpFromSDK(sdkTool toolsdk.Tool[any, any]) server.ServerTool { // NOTE: some clients will silently refuse to use tools if there is an issue // with the tool's schema or configuration. if sdkTool.Schema.Properties == nil { diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 1388b61030d38..980db23e5789f 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -345,6 +345,30 @@ func (api *API) patchWorkspaceAgentAppStatus(rw http.ResponseWriter, r *http.Req return } + if len(req.Message) > 160 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Message is too long.", + Detail: "Message must be less than 160 characters.", + Validations: []codersdk.ValidationError{ + {Field: "message", Detail: "Message must be less than 160 characters."}, + }, + }) + return + } + + switch req.State { + case codersdk.WorkspaceAppStatusStateComplete, codersdk.WorkspaceAppStatusStateFailure, codersdk.WorkspaceAppStatusStateWorking: // valid states + default: + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid state provided.", + Detail: fmt.Sprintf("invalid state: %q", req.State), + Validations: []codersdk.ValidationError{ + {Field: "state", Detail: "State must be one of: complete, failure, working."}, + }, + }) + return + } + workspace, err := api.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID) if err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index 73dee8e748575..874501367230f 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -14,46 +14,121 @@ import ( ) // HandlerFunc is a function that handles a tool call. -type HandlerFunc[T any] func(ctx context.Context, args map[string]any) (T, error) +type HandlerFunc[Arg, Ret any] func(ctx context.Context, args Arg) (Ret, error) -type Tool[T any] struct { +type Tool[Arg, Ret any] struct { aisdk.Tool - Handler HandlerFunc[T] + Handler HandlerFunc[Arg, Ret] } -// Generic returns a Tool[any] that can be used to call the tool. -func (t Tool[T]) Generic() Tool[any] { - return Tool[any]{ +// Generic returns a type-erased version of the Tool. +func (t Tool[Arg, Ret]) Generic() Tool[any, any] { + return Tool[any, any]{ Tool: t.Tool, - Handler: func(ctx context.Context, args map[string]any) (any, error) { - return t.Handler(ctx, args) + Handler: func(ctx context.Context, args any) (any, error) { + typedArg, ok := args.(Arg) + if !ok { + return nil, xerrors.Errorf("developer error: invalid argument type for tool %s", t.Tool.Name) + } + return t.Handler(ctx, typedArg) }, } } +type NoArgs struct{} + +type ReportTaskArgs struct { + Link string `json:"link"` + State string `json:"state"` + Summary string `json:"summary"` +} + +type CreateTemplateVersionArgs struct { + FileID string `json:"file_id"` + TemplateID string `json:"template_id"` +} + +type CreateTemplateArgs struct { + Description string `json:"description"` + DisplayName string `json:"display_name"` + Icon string `json:"icon"` + Name string `json:"name"` + VersionID string `json:"version_id"` +} + +type CreateWorkspaceArgs struct { + Name string `json:"name"` + RichParameters map[string]string `json:"rich_parameters"` + TemplateVersionID string `json:"template_version_id"` + User string `json:"user"` +} + +type CreateWorkspaceBuildArgs struct { + TemplateVersionID string `json:"template_version_id"` + Transition string `json:"transition"` + WorkspaceID string `json:"workspace_id"` +} + +type DeleteTemplateArgs struct { + TemplateID string `json:"template_id"` +} + +type GetTemplateVersionLogsArgs struct { + TemplateVersionID string `json:"template_version_id"` +} + +type GetWorkspaceArgs struct { + WorkspaceID string `json:"workspace_id"` +} + +type GetWorkspaceAgentLogsArgs struct { + WorkspaceAgentID string `json:"workspace_agent_id"` +} + +type GetWorkspaceBuildLogsArgs struct { + WorkspaceBuildID string `json:"workspace_build_id"` +} + +type ListWorkspacesArgs struct { + Owner string `json:"owner"` +} + +type ListTemplateVersionParametersArgs struct { + TemplateVersionID string `json:"template_version_id"` +} + +type UpdateTemplateActiveVersionArgs struct { + TemplateID string `json:"template_id"` + TemplateVersionID string `json:"template_version_id"` +} + +type UploadTarFileArgs struct { + Files map[string]string `json:"files"` +} + var ( // All is a list of all tools that can be used in the Coder CLI. // When you add a new tool, be sure to include it here! - All = []Tool[any]{ - CreateTemplateVersion.Generic(), + All = []Tool[any, any]{ CreateTemplate.Generic(), + CreateTemplateVersion.Generic(), CreateWorkspace.Generic(), CreateWorkspaceBuild.Generic(), DeleteTemplate.Generic(), + ListTemplates.Generic(), + ListTemplateVersionParameters.Generic(), + ListWorkspaces.Generic(), GetAuthenticatedUser.Generic(), GetTemplateVersionLogs.Generic(), - GetWorkspace.Generic(), GetWorkspaceAgentLogs.Generic(), GetWorkspaceBuildLogs.Generic(), - ListWorkspaces.Generic(), - ListTemplates.Generic(), - ListTemplateVersionParameters.Generic(), + GetWorkspace.Generic(), ReportTask.Generic(), UploadTarFile.Generic(), UpdateTemplateActiveVersion.Generic(), } - ReportTask = Tool[string]{ + ReportTask = Tool[ReportTaskArgs, string]{ Tool: aisdk.Tool{ Name: "coder_report_task", Description: "Report progress on a user task in Coder.", @@ -80,7 +155,7 @@ var ( Required: []string{"summary", "link", "state"}, }, }, - Handler: func(ctx context.Context, args map[string]any) (string, error) { + Handler: func(ctx context.Context, args ReportTaskArgs) (string, error) { agentClient, err := agentClientFromContext(ctx) if err != nil { return "", xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set") @@ -89,27 +164,11 @@ var ( if !ok { return "", xerrors.New("workspace app status slug not found in context") } - summary, ok := args["summary"].(string) - if !ok { - return "", xerrors.New("summary must be a string") - } - if len(summary) > 160 { - return "", xerrors.New("summary must be less than 160 characters") - } - link, ok := args["link"].(string) - if !ok { - return "", xerrors.New("link must be a string") - } - state, ok := args["state"].(string) - if !ok { - return "", xerrors.New("state must be a string") - } - if err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ AppSlug: appSlug, - Message: summary, - URI: link, - State: codersdk.WorkspaceAppStatusState(state), + Message: args.Summary, + URI: args.Link, + State: codersdk.WorkspaceAppStatusState(args.State), }); err != nil { return "", err } @@ -117,7 +176,7 @@ var ( }, } - GetWorkspace = Tool[codersdk.Workspace]{ + GetWorkspace = Tool[GetWorkspaceArgs, codersdk.Workspace]{ Tool: aisdk.Tool{ Name: "coder_get_workspace", Description: `Get a workspace by ID. @@ -132,20 +191,20 @@ This returns more data than list_workspaces to reduce token usage.`, Required: []string{"workspace_id"}, }, }, - Handler: func(ctx context.Context, args map[string]any) (codersdk.Workspace, error) { + Handler: func(ctx context.Context, args GetWorkspaceArgs) (codersdk.Workspace, error) { client, err := clientFromContext(ctx) if err != nil { return codersdk.Workspace{}, err } - workspaceID, err := uuidFromArgs(args, "workspace_id") + wsID, err := uuid.Parse(args.WorkspaceID) if err != nil { - return codersdk.Workspace{}, err + return codersdk.Workspace{}, xerrors.New("workspace_id must be a valid UUID") } - return client.Workspace(ctx, workspaceID) + return client.Workspace(ctx, wsID) }, } - CreateWorkspace = Tool[codersdk.Workspace]{ + CreateWorkspace = Tool[CreateWorkspaceArgs, codersdk.Workspace]{ Tool: aisdk.Tool{ Name: "coder_create_workspace", Description: `Create a new workspace in Coder. @@ -176,22 +235,29 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{"user", "template_version_id", "name", "rich_parameters"}, }, }, - Handler: func(ctx context.Context, args map[string]any) (codersdk.Workspace, error) { + Handler: func(ctx context.Context, args CreateWorkspaceArgs) (codersdk.Workspace, error) { client, err := clientFromContext(ctx) if err != nil { return codersdk.Workspace{}, err } - templateVersionID, err := uuidFromArgs(args, "template_version_id") + tvID, err := uuid.Parse(args.TemplateVersionID) if err != nil { - return codersdk.Workspace{}, err - } - name, ok := args["name"].(string) - if !ok { - return codersdk.Workspace{}, xerrors.New("workspace name must be a string") - } - workspace, err := client.CreateUserWorkspace(ctx, "me", codersdk.CreateWorkspaceRequest{ - TemplateVersionID: templateVersionID, - Name: name, + return codersdk.Workspace{}, xerrors.New("template_version_id must be a valid UUID") + } + if args.User == "" { + args.User = codersdk.Me + } + var buildParams []codersdk.WorkspaceBuildParameter + for k, v := range args.RichParameters { + buildParams = append(buildParams, codersdk.WorkspaceBuildParameter{ + Name: k, + Value: v, + }) + } + workspace, err := client.CreateUserWorkspace(ctx, args.User, codersdk.CreateWorkspaceRequest{ + TemplateVersionID: tvID, + Name: args.Name, + RichParameterValues: buildParams, }) if err != nil { return codersdk.Workspace{}, err @@ -200,7 +266,7 @@ is provisioned correctly and the agent can connect to the control plane. }, } - ListWorkspaces = Tool[[]MinimalWorkspace]{ + ListWorkspaces = Tool[ListWorkspacesArgs, []MinimalWorkspace]{ Tool: aisdk.Tool{ Name: "coder_list_workspaces", Description: "Lists workspaces for the authenticated user.", @@ -213,13 +279,13 @@ is provisioned correctly and the agent can connect to the control plane. }, }, }, - Handler: func(ctx context.Context, args map[string]any) ([]MinimalWorkspace, error) { + Handler: func(ctx context.Context, args ListWorkspacesArgs) ([]MinimalWorkspace, error) { client, err := clientFromContext(ctx) if err != nil { return nil, err } - owner, ok := args["owner"].(string) - if !ok { + owner := args.Owner + if owner == "" { owner = codersdk.Me } workspaces, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{ @@ -245,7 +311,7 @@ is provisioned correctly and the agent can connect to the control plane. }, } - ListTemplates = Tool[[]MinimalTemplate]{ + ListTemplates = Tool[NoArgs, []MinimalTemplate]{ Tool: aisdk.Tool{ Name: "coder_list_templates", Description: "Lists templates for the authenticated user.", @@ -254,7 +320,7 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{}, }, }, - Handler: func(ctx context.Context, _ map[string]any) ([]MinimalTemplate, error) { + Handler: func(ctx context.Context, _ NoArgs) ([]MinimalTemplate, error) { client, err := clientFromContext(ctx) if err != nil { return nil, err @@ -278,7 +344,7 @@ is provisioned correctly and the agent can connect to the control plane. }, } - ListTemplateVersionParameters = Tool[[]codersdk.TemplateVersionParameter]{ + ListTemplateVersionParameters = Tool[ListTemplateVersionParametersArgs, []codersdk.TemplateVersionParameter]{ Tool: aisdk.Tool{ Name: "coder_template_version_parameters", Description: "Get the parameters for a template version. You can refer to these as workspace parameters to the user, as they are typically important for creating a workspace.", @@ -291,14 +357,14 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{"template_version_id"}, }, }, - Handler: func(ctx context.Context, args map[string]any) ([]codersdk.TemplateVersionParameter, error) { + Handler: func(ctx context.Context, args ListTemplateVersionParametersArgs) ([]codersdk.TemplateVersionParameter, error) { client, err := clientFromContext(ctx) if err != nil { return nil, err } - templateVersionID, err := uuidFromArgs(args, "template_version_id") + templateVersionID, err := uuid.Parse(args.TemplateVersionID) if err != nil { - return nil, err + return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) } parameters, err := client.TemplateVersionRichParameters(ctx, templateVersionID) if err != nil { @@ -308,7 +374,7 @@ is provisioned correctly and the agent can connect to the control plane. }, } - GetAuthenticatedUser = Tool[codersdk.User]{ + GetAuthenticatedUser = Tool[NoArgs, codersdk.User]{ Tool: aisdk.Tool{ Name: "coder_get_authenticated_user", Description: "Get the currently authenticated user, similar to the `whoami` command.", @@ -317,7 +383,7 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{}, }, }, - Handler: func(ctx context.Context, _ map[string]any) (codersdk.User, error) { + Handler: func(ctx context.Context, _ NoArgs) (codersdk.User, error) { client, err := clientFromContext(ctx) if err != nil { return codersdk.User{}, err @@ -326,7 +392,7 @@ is provisioned correctly and the agent can connect to the control plane. }, } - CreateWorkspaceBuild = Tool[codersdk.WorkspaceBuild]{ + CreateWorkspaceBuild = Tool[CreateWorkspaceBuildArgs, codersdk.WorkspaceBuild]{ Tool: aisdk.Tool{ Name: "coder_create_workspace_build", Description: "Create a new workspace build for an existing workspace. Use this to start, stop, or delete.", @@ -348,25 +414,25 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{"workspace_id", "transition"}, }, }, - Handler: func(ctx context.Context, args map[string]any) (codersdk.WorkspaceBuild, error) { + Handler: func(ctx context.Context, args CreateWorkspaceBuildArgs) (codersdk.WorkspaceBuild, error) { client, err := clientFromContext(ctx) if err != nil { return codersdk.WorkspaceBuild{}, err } - workspaceID, err := uuidFromArgs(args, "workspace_id") + workspaceID, err := uuid.Parse(args.WorkspaceID) if err != nil { - return codersdk.WorkspaceBuild{}, err + return codersdk.WorkspaceBuild{}, xerrors.Errorf("workspace_id must be a valid UUID: %w", err) } - rawTransition, ok := args["transition"].(string) - if !ok { - return codersdk.WorkspaceBuild{}, xerrors.New("transition must be a string") - } - templateVersionID, err := uuidFromArgs(args, "template_version_id") - if err != nil { - return codersdk.WorkspaceBuild{}, err + var templateVersionID uuid.UUID + if args.TemplateVersionID != "" { + tvID, err := uuid.Parse(args.TemplateVersionID) + if err != nil { + return codersdk.WorkspaceBuild{}, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) + } + templateVersionID = tvID } cbr := codersdk.CreateWorkspaceBuildRequest{ - Transition: codersdk.WorkspaceTransition(rawTransition), + Transition: codersdk.WorkspaceTransition(args.Transition), } if templateVersionID != uuid.Nil { cbr.TemplateVersionID = templateVersionID @@ -375,7 +441,7 @@ is provisioned correctly and the agent can connect to the control plane. }, } - CreateTemplateVersion = Tool[codersdk.TemplateVersion]{ + CreateTemplateVersion = Tool[CreateTemplateVersionArgs, codersdk.TemplateVersion]{ Tool: aisdk.Tool{ Name: "coder_create_template_version", Description: `Create a new template version. This is a precursor to creating a template, or you can update an existing template. @@ -833,7 +899,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"file_id"}, }, }, - Handler: func(ctx context.Context, args map[string]any) (codersdk.TemplateVersion, error) { + Handler: func(ctx context.Context, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) { client, err := clientFromContext(ctx) if err != nil { return codersdk.TemplateVersion{}, err @@ -842,16 +908,13 @@ The file_id provided is a reference to a tar file you have uploaded containing t if err != nil { return codersdk.TemplateVersion{}, err } - fileID, err := uuidFromArgs(args, "file_id") + fileID, err := uuid.Parse(args.FileID) if err != nil { - return codersdk.TemplateVersion{}, err + return codersdk.TemplateVersion{}, xerrors.Errorf("file_id must be a valid UUID: %w", err) } - var templateID uuid.UUID - if args["template_id"] != nil { - templateID, err = uuidFromArgs(args, "template_id") - if err != nil { - return codersdk.TemplateVersion{}, err - } + templateID, err := uuid.Parse(args.TemplateID) + if err != nil { + return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err) } templateVersion, err := client.CreateTemplateVersion(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{ Message: "Created by AI", @@ -867,12 +930,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t }, } - GetWorkspaceAgentLogs = Tool[[]string]{ + GetWorkspaceAgentLogs = Tool[GetWorkspaceAgentLogsArgs, []string]{ Tool: aisdk.Tool{ Name: "coder_get_workspace_agent_logs", Description: `Get the logs of a workspace agent. -More logs may appear after this call. It does not wait for the agent to finish.`, + More logs may appear after this call. It does not wait for the agent to finish.`, Schema: aisdk.Schema{ Properties: map[string]any{ "workspace_agent_id": map[string]any{ @@ -882,14 +945,14 @@ More logs may appear after this call. It does not wait for the agent to finish.` Required: []string{"workspace_agent_id"}, }, }, - Handler: func(ctx context.Context, args map[string]any) ([]string, error) { + Handler: func(ctx context.Context, args GetWorkspaceAgentLogsArgs) ([]string, error) { client, err := clientFromContext(ctx) if err != nil { return nil, err } - workspaceAgentID, err := uuidFromArgs(args, "workspace_agent_id") + workspaceAgentID, err := uuid.Parse(args.WorkspaceAgentID) if err != nil { - return nil, err + return nil, xerrors.Errorf("workspace_agent_id must be a valid UUID: %w", err) } logs, closer, err := client.WorkspaceAgentLogsAfter(ctx, workspaceAgentID, 0, false) if err != nil { @@ -906,12 +969,12 @@ More logs may appear after this call. It does not wait for the agent to finish.` }, } - GetWorkspaceBuildLogs = Tool[[]string]{ + GetWorkspaceBuildLogs = Tool[GetWorkspaceBuildLogsArgs, []string]{ Tool: aisdk.Tool{ Name: "coder_get_workspace_build_logs", Description: `Get the logs of a workspace build. -Useful for checking whether a workspace builds successfully or not.`, + Useful for checking whether a workspace builds successfully or not.`, Schema: aisdk.Schema{ Properties: map[string]any{ "workspace_build_id": map[string]any{ @@ -921,14 +984,14 @@ Useful for checking whether a workspace builds successfully or not.`, Required: []string{"workspace_build_id"}, }, }, - Handler: func(ctx context.Context, args map[string]any) ([]string, error) { + Handler: func(ctx context.Context, args GetWorkspaceBuildLogsArgs) ([]string, error) { client, err := clientFromContext(ctx) if err != nil { return nil, err } - workspaceBuildID, err := uuidFromArgs(args, "workspace_build_id") + workspaceBuildID, err := uuid.Parse(args.WorkspaceBuildID) if err != nil { - return nil, err + return nil, xerrors.Errorf("workspace_build_id must be a valid UUID: %w", err) } logs, closer, err := client.WorkspaceBuildLogsAfter(ctx, workspaceBuildID, 0) if err != nil { @@ -943,7 +1006,7 @@ Useful for checking whether a workspace builds successfully or not.`, }, } - GetTemplateVersionLogs = Tool[[]string]{ + GetTemplateVersionLogs = Tool[GetTemplateVersionLogsArgs, []string]{ Tool: aisdk.Tool{ Name: "coder_get_template_version_logs", Description: "Get the logs of a template version. This is useful to check whether a template version successfully imports or not.", @@ -956,14 +1019,14 @@ Useful for checking whether a workspace builds successfully or not.`, Required: []string{"template_version_id"}, }, }, - Handler: func(ctx context.Context, args map[string]any) ([]string, error) { + Handler: func(ctx context.Context, args GetTemplateVersionLogsArgs) ([]string, error) { client, err := clientFromContext(ctx) if err != nil { return nil, err } - templateVersionID, err := uuidFromArgs(args, "template_version_id") + templateVersionID, err := uuid.Parse(args.TemplateVersionID) if err != nil { - return nil, err + return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) } logs, closer, err := client.TemplateVersionLogsAfter(ctx, templateVersionID, 0) @@ -979,7 +1042,7 @@ Useful for checking whether a workspace builds successfully or not.`, }, } - UpdateTemplateActiveVersion = Tool[string]{ + UpdateTemplateActiveVersion = Tool[UpdateTemplateActiveVersionArgs, string]{ Tool: aisdk.Tool{ Name: "coder_update_template_active_version", Description: "Update the active version of a template. This is helpful when iterating on templates.", @@ -995,18 +1058,18 @@ Useful for checking whether a workspace builds successfully or not.`, Required: []string{"template_id", "template_version_id"}, }, }, - Handler: func(ctx context.Context, args map[string]any) (string, error) { + Handler: func(ctx context.Context, args UpdateTemplateActiveVersionArgs) (string, error) { client, err := clientFromContext(ctx) if err != nil { return "", err } - templateID, err := uuidFromArgs(args, "template_id") + templateID, err := uuid.Parse(args.TemplateID) if err != nil { - return "", err + return "", xerrors.Errorf("template_id must be a valid UUID: %w", err) } - templateVersionID, err := uuidFromArgs(args, "template_version_id") + templateVersionID, err := uuid.Parse(args.TemplateVersionID) if err != nil { - return "", err + return "", xerrors.Errorf("template_version_id must be a valid UUID: %w", err) } err = client.UpdateActiveTemplateVersion(ctx, templateID, codersdk.UpdateActiveTemplateVersion{ ID: templateVersionID, @@ -1018,15 +1081,12 @@ Useful for checking whether a workspace builds successfully or not.`, }, } - UploadTarFile = Tool[codersdk.UploadResponse]{ + UploadTarFile = Tool[UploadTarFileArgs, codersdk.UploadResponse]{ Tool: aisdk.Tool{ Name: "coder_upload_tar_file", Description: `Create and upload a tar file by key/value mapping of file names to file contents. Use this to create template versions. Reference the tool description of "create_template_version" to understand template requirements.`, Schema: aisdk.Schema{ Properties: map[string]any{ - "mime_type": map[string]any{ - "type": "string", - }, "files": map[string]any{ "type": "object", "description": "A map of file names to file contents.", @@ -1035,37 +1095,27 @@ Useful for checking whether a workspace builds successfully or not.`, Required: []string{"mime_type", "files"}, }, }, - Handler: func(ctx context.Context, args map[string]any) (codersdk.UploadResponse, error) { + Handler: func(ctx context.Context, args UploadTarFileArgs) (codersdk.UploadResponse, error) { client, err := clientFromContext(ctx) if err != nil { return codersdk.UploadResponse{}, err } - files, ok := args["files"].(map[string]any) - if !ok { - return codersdk.UploadResponse{}, xerrors.New("files must be a map") - } - pipeReader, pipeWriter := io.Pipe() go func() { defer pipeWriter.Close() tarWriter := tar.NewWriter(pipeWriter) - for name, content := range files { - contentStr, ok := content.(string) - if !ok { - _ = pipeWriter.CloseWithError(xerrors.New("file content must be a string")) - return - } + for name, content := range args.Files { header := &tar.Header{ Name: name, - Size: int64(len(contentStr)), + Size: int64(len(content)), Mode: 0o644, } if err := tarWriter.WriteHeader(header); err != nil { _ = pipeWriter.CloseWithError(err) return } - if _, err := tarWriter.Write([]byte(contentStr)); err != nil { + if _, err := tarWriter.Write([]byte(content)); err != nil { _ = pipeWriter.CloseWithError(err) return } @@ -1083,7 +1133,7 @@ Useful for checking whether a workspace builds successfully or not.`, }, } - CreateTemplate = Tool[codersdk.Template]{ + CreateTemplate = Tool[CreateTemplateArgs, codersdk.Template]{ Tool: aisdk.Tool{ Name: "coder_create_template", Description: "Create a new template in Coder. First, you must create a template version.", @@ -1110,7 +1160,7 @@ Useful for checking whether a workspace builds successfully or not.`, Required: []string{"name", "display_name", "description", "version_id"}, }, }, - Handler: func(ctx context.Context, args map[string]any) (codersdk.Template, error) { + Handler: func(ctx context.Context, args CreateTemplateArgs) (codersdk.Template, error) { client, err := clientFromContext(ctx) if err != nil { return codersdk.Template{}, err @@ -1119,27 +1169,14 @@ Useful for checking whether a workspace builds successfully or not.`, if err != nil { return codersdk.Template{}, err } - versionID, err := uuidFromArgs(args, "version_id") + versionID, err := uuid.Parse(args.VersionID) if err != nil { - return codersdk.Template{}, err - } - name, ok := args["name"].(string) - if !ok { - return codersdk.Template{}, xerrors.New("name must be a string") - } - displayName, ok := args["display_name"].(string) - if !ok { - return codersdk.Template{}, xerrors.New("display_name must be a string") + return codersdk.Template{}, xerrors.Errorf("version_id must be a valid UUID: %w", err) } - description, ok := args["description"].(string) - if !ok { - return codersdk.Template{}, xerrors.New("description must be a string") - } - template, err := client.CreateTemplate(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateRequest{ - Name: name, - DisplayName: displayName, - Description: description, + Name: args.Name, + DisplayName: args.DisplayName, + Description: args.Description, VersionID: versionID, }) if err != nil { @@ -1149,7 +1186,7 @@ Useful for checking whether a workspace builds successfully or not.`, }, } - DeleteTemplate = Tool[string]{ + DeleteTemplate = Tool[DeleteTemplateArgs, string]{ Tool: aisdk.Tool{ Name: "coder_delete_template", Description: "Delete a template. This is irreversible.", @@ -1161,15 +1198,15 @@ Useful for checking whether a workspace builds successfully or not.`, }, }, }, - Handler: func(ctx context.Context, args map[string]any) (string, error) { + Handler: func(ctx context.Context, args DeleteTemplateArgs) (string, error) { client, err := clientFromContext(ctx) if err != nil { return "", err } - templateID, err := uuidFromArgs(args, "template_id") + templateID, err := uuid.Parse(args.TemplateID) if err != nil { - return "", err + return "", xerrors.Errorf("template_id must be a valid UUID: %w", err) } err = client.DeleteTemplate(ctx, templateID) if err != nil { @@ -1241,19 +1278,3 @@ func workspaceAppStatusSlugFromContext(ctx context.Context) (string, bool) { } return slug, true } - -func uuidFromArgs(args map[string]any, key string) (uuid.UUID, error) { - argKey, ok := args[key] - if !ok { - return uuid.Nil, nil // No error if key is not present - } - raw, ok := argKey.(string) - if !ok { - return uuid.Nil, xerrors.Errorf("%s must be a string", key) - } - id, err := uuid.Parse(raw) - if err != nil { - return uuid.Nil, xerrors.Errorf("failed to parse %s: %w", key, err) - } - return id, nil -} diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index 1504e956f6bd4..fc13dc462adbb 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -71,14 +71,26 @@ func TestTools(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) ctx = toolsdk.WithAgentClient(ctx, agentClient) ctx = toolsdk.WithWorkspaceAppStatusSlug(ctx, "some-agent-app") - _, err := testTool(ctx, t, toolsdk.ReportTask, map[string]any{ - "summary": "test summary", - "state": "complete", - "link": "https://example.com", + _, err := testTool(ctx, t, toolsdk.ReportTask, toolsdk.ReportTaskArgs{ + Summary: "test summary", + State: "complete", + Link: "https://example.com", }) require.NoError(t, err) }) + t.Run("GetWorkspace", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitShort) + ctx = toolsdk.WithClient(ctx, memberClient) + + result, err := testTool(ctx, t, toolsdk.GetWorkspace, toolsdk.GetWorkspaceArgs{ + WorkspaceID: r.Workspace.ID.String(), + }) + + require.NoError(t, err) + require.Equal(t, r.Workspace.ID, result.ID, "expected the workspace ID to match") + }) + t.Run("ListTemplates", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) ctx = toolsdk.WithClient(ctx, memberClient) @@ -87,7 +99,7 @@ func TestTools(t *testing.T) { expected, err := memberClient.Templates(context.Background(), codersdk.TemplateFilter{}) require.NoError(t, err) - result, err := testTool(ctx, t, toolsdk.ListTemplates, map[string]any{}) + result, err := testTool(ctx, t, toolsdk.ListTemplates, toolsdk.NoArgs{}) require.NoError(t, err) require.Len(t, result, len(expected)) @@ -108,7 +120,7 @@ func TestTools(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) ctx = toolsdk.WithClient(ctx, memberClient) - result, err := testTool(ctx, t, toolsdk.GetAuthenticatedUser, map[string]any{}) + result, err := testTool(ctx, t, toolsdk.GetAuthenticatedUser, toolsdk.NoArgs{}) require.NoError(t, err) require.Equal(t, member.ID, result.ID) @@ -119,9 +131,7 @@ func TestTools(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) ctx = toolsdk.WithClient(ctx, memberClient) - result, err := testTool(ctx, t, toolsdk.ListWorkspaces, map[string]any{ - "owner": "me", - }) + result, err := testTool(ctx, t, toolsdk.ListWorkspaces, toolsdk.ListWorkspacesArgs{}) require.NoError(t, err) require.Len(t, result, 1, "expected 1 workspace") @@ -129,26 +139,14 @@ func TestTools(t *testing.T) { require.Equal(t, r.Workspace.ID.String(), workspace.ID, "expected the workspace to match the one we created") }) - t.Run("GetWorkspace", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - result, err := testTool(ctx, t, toolsdk.GetWorkspace, map[string]any{ - "workspace_id": r.Workspace.ID.String(), - }) - - require.NoError(t, err) - require.Equal(t, r.Workspace.ID, result.ID, "expected the workspace ID to match") - }) - t.Run("CreateWorkspaceBuild", func(t *testing.T) { t.Run("Stop", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) ctx = toolsdk.WithClient(ctx, memberClient) - result, err := testTool(ctx, t, toolsdk.CreateWorkspaceBuild, map[string]any{ - "workspace_id": r.Workspace.ID.String(), - "transition": "stop", + result, err := testTool(ctx, t, toolsdk.CreateWorkspaceBuild, toolsdk.CreateWorkspaceBuildArgs{ + WorkspaceID: r.Workspace.ID.String(), + Transition: "stop", }) require.NoError(t, err) @@ -166,9 +164,9 @@ func TestTools(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) ctx = toolsdk.WithClient(ctx, memberClient) - result, err := testTool(ctx, t, toolsdk.CreateWorkspaceBuild, map[string]any{ - "workspace_id": r.Workspace.ID.String(), - "transition": "start", + result, err := testTool(ctx, t, toolsdk.CreateWorkspaceBuild, toolsdk.CreateWorkspaceBuildArgs{ + WorkspaceID: r.Workspace.ID.String(), + Transition: "start", }) require.NoError(t, err) @@ -201,10 +199,10 @@ func TestTools(t *testing.T) { }).Do() // Update to new version - updateBuild, err := testTool(ctx, t, toolsdk.CreateWorkspaceBuild, map[string]any{ - "workspace_id": r.Workspace.ID.String(), - "transition": "start", - "template_version_id": newVersion.TemplateVersion.ID.String(), + updateBuild, err := testTool(ctx, t, toolsdk.CreateWorkspaceBuild, toolsdk.CreateWorkspaceBuildArgs{ + WorkspaceID: r.Workspace.ID.String(), + Transition: "start", + TemplateVersionID: newVersion.TemplateVersion.ID.String(), }) require.NoError(t, err) require.Equal(t, codersdk.WorkspaceTransitionStart, updateBuild.Transition) @@ -214,10 +212,10 @@ func TestTools(t *testing.T) { require.NoError(t, client.CancelWorkspaceBuild(ctx, updateBuild.ID)) // Roll back to the original version - rollbackBuild, err := testTool(ctx, t, toolsdk.CreateWorkspaceBuild, map[string]any{ - "workspace_id": r.Workspace.ID.String(), - "transition": "start", - "template_version_id": originalVersionID.String(), + rollbackBuild, err := testTool(ctx, t, toolsdk.CreateWorkspaceBuild, toolsdk.CreateWorkspaceBuildArgs{ + WorkspaceID: r.Workspace.ID.String(), + Transition: "start", + TemplateVersionID: originalVersionID.String(), }) require.NoError(t, err) require.Equal(t, codersdk.WorkspaceTransitionStart, rollbackBuild.Transition) @@ -232,8 +230,8 @@ func TestTools(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) ctx = toolsdk.WithClient(ctx, memberClient) - params, err := testTool(ctx, t, toolsdk.ListTemplateVersionParameters, map[string]any{ - "template_version_id": r.TemplateVersion.ID.String(), + params, err := testTool(ctx, t, toolsdk.ListTemplateVersionParameters, toolsdk.ListTemplateVersionParametersArgs{ + TemplateVersionID: r.TemplateVersion.ID.String(), }) require.NoError(t, err) @@ -244,8 +242,8 @@ func TestTools(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) ctx = toolsdk.WithClient(ctx, client) - logs, err := testTool(ctx, t, toolsdk.GetWorkspaceAgentLogs, map[string]any{ - "workspace_agent_id": agentID.String(), + logs, err := testTool(ctx, t, toolsdk.GetWorkspaceAgentLogs, toolsdk.GetWorkspaceAgentLogsArgs{ + WorkspaceAgentID: agentID.String(), }) require.NoError(t, err) @@ -256,8 +254,8 @@ func TestTools(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) ctx = toolsdk.WithClient(ctx, memberClient) - logs, err := testTool(ctx, t, toolsdk.GetWorkspaceBuildLogs, map[string]any{ - "workspace_build_id": r.Build.ID.String(), + logs, err := testTool(ctx, t, toolsdk.GetWorkspaceBuildLogs, toolsdk.GetWorkspaceBuildLogsArgs{ + WorkspaceBuildID: r.Build.ID.String(), }) require.NoError(t, err) @@ -268,8 +266,8 @@ func TestTools(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) ctx = toolsdk.WithClient(ctx, memberClient) - logs, err := testTool(ctx, t, toolsdk.GetTemplateVersionLogs, map[string]any{ - "template_version_id": r.TemplateVersion.ID.String(), + logs, err := testTool(ctx, t, toolsdk.GetTemplateVersionLogs, toolsdk.GetTemplateVersionLogsArgs{ + TemplateVersionID: r.TemplateVersion.ID.String(), }) require.NoError(t, err) @@ -280,9 +278,9 @@ func TestTools(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) ctx = toolsdk.WithClient(ctx, client) // Use owner client for permission - result, err := testTool(ctx, t, toolsdk.UpdateTemplateActiveVersion, map[string]any{ - "template_id": r.Template.ID.String(), - "template_version_id": r.TemplateVersion.ID.String(), + result, err := testTool(ctx, t, toolsdk.UpdateTemplateActiveVersion, toolsdk.UpdateTemplateActiveVersionArgs{ + TemplateID: r.Template.ID.String(), + TemplateVersionID: r.TemplateVersion.ID.String(), }) require.NoError(t, err) @@ -293,8 +291,8 @@ func TestTools(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) ctx = toolsdk.WithClient(ctx, client) - _, err := testTool(ctx, t, toolsdk.DeleteTemplate, map[string]any{ - "template_id": r.Template.ID.String(), + _, err := testTool(ctx, t, toolsdk.DeleteTemplate, toolsdk.DeleteTemplateArgs{ + TemplateID: r.Template.ID.String(), }) // This will fail with because there already exists a workspace. @@ -305,13 +303,12 @@ func TestTools(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) ctx = toolsdk.WithClient(ctx, client) - files := map[string]any{ - "main.tf": "resource \"null_resource\" \"example\" {}", + files := map[string]string{ + "main.tf": `resource "null_resource" "example" {}`, } - result, err := testTool(ctx, t, toolsdk.UploadTarFile, map[string]any{ - "mime_type": string(codersdk.ContentTypeTar), - "files": files, + result, err := testTool(ctx, t, toolsdk.UploadTarFile, toolsdk.UploadTarFileArgs{ + Files: files, }) require.NoError(t, err) @@ -325,8 +322,9 @@ func TestTools(t *testing.T) { // nolint:gocritic // This is in a test package and does not end up in the build file := dbgen.File(t, store, database.File{}) - tv, err := testTool(ctx, t, toolsdk.CreateTemplateVersion, map[string]any{ - "file_id": file.ID.String(), + tv, err := testTool(ctx, t, toolsdk.CreateTemplateVersion, toolsdk.CreateTemplateVersionArgs{ + FileID: file.ID.String(), + TemplateID: r.Template.ID.String(), }) require.NoError(t, err) require.NotEmpty(t, tv) @@ -343,11 +341,11 @@ func TestTools(t *testing.T) { SkipCreateTemplate().Do() // We're going to re-use the pre-existing template version - _, err := testTool(ctx, t, toolsdk.CreateTemplate, map[string]any{ - "name": testutil.GetRandomNameHyphenated(t), - "display_name": "Test Template", - "description": "This is a test template", - "version_id": tv.TemplateVersion.ID.String(), + _, err := testTool(ctx, t, toolsdk.CreateTemplate, toolsdk.CreateTemplateArgs{ + Name: testutil.GetRandomNameHyphenated(t), + DisplayName: "Test Template", + Description: "This is a test template", + VersionID: tv.TemplateVersion.ID.String(), }) require.NoError(t, err) @@ -358,11 +356,11 @@ func TestTools(t *testing.T) { ctx = toolsdk.WithClient(ctx, memberClient) // We need a template version ID to create a workspace - res, err := testTool(ctx, t, toolsdk.CreateWorkspace, map[string]any{ - "user": "me", - "template_version_id": r.TemplateVersion.ID.String(), - "name": testutil.GetRandomNameHyphenated(t), - "rich_parameters": map[string]any{}, + res, err := testTool(ctx, t, toolsdk.CreateWorkspace, toolsdk.CreateWorkspaceArgs{ + User: "me", + TemplateVersionID: r.TemplateVersion.ID.String(), + Name: testutil.GetRandomNameHyphenated(t), + RichParameters: map[string]string{}, }) // The creation might fail for various reasons, but the important thing is @@ -376,7 +374,7 @@ func TestTools(t *testing.T) { var testedTools sync.Map // testTool is a helper function to test a tool and mark it as tested. -func testTool[T any](ctx context.Context, t *testing.T, tool toolsdk.Tool[T], args map[string]any) (T, error) { +func testTool[Arg, Ret any](ctx context.Context, t *testing.T, tool toolsdk.Tool[Arg, Ret], args Arg) (Ret, error) { t.Helper() testedTools.Store(tool.Tool.Name, true) result, err := tool.Handler(ctx, args) From a7784ea2340183a52db02632bdf34b9f106a57bc Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 24 Apr 2025 15:50:44 +0100 Subject: [PATCH 02/11] chore(codersdk/toolsdk): add tool deps in toolbox struct instead of context --- cli/exp_mcp.go | 16 +-- codersdk/toolsdk/toolsdk.go | 232 ++++++++++++------------------- codersdk/toolsdk/toolsdk_test.go | 111 ++++++--------- 3 files changed, 133 insertions(+), 226 deletions(-) diff --git a/cli/exp_mcp.go b/cli/exp_mcp.go index 78c5130f6c7f3..57adc4ffa21ad 100644 --- a/cli/exp_mcp.go +++ b/cli/exp_mcp.go @@ -400,21 +400,21 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct ) // Create a new context for the tools with all relevant information. - clientCtx := toolsdk.WithClient(ctx, client) + tb := toolsdk.NewToolbox(client) // Get the workspace agent token from the environment. var hasAgentClient bool if agentToken, err := getAgentToken(fs); err == nil && agentToken != "" { hasAgentClient = true agentClient := agentsdk.New(client.URL) agentClient.SetSessionToken(agentToken) - clientCtx = toolsdk.WithAgentClient(clientCtx, agentClient) + tb = tb.WithAgentClient(agentClient) } else { cliui.Warnf(inv.Stderr, "CODER_AGENT_TOKEN is not set, task reporting will not be available") } if appStatusSlug == "" { cliui.Warnf(inv.Stderr, "CODER_MCP_APP_STATUS_SLUG is not set, task reporting will not be available.") } else { - clientCtx = toolsdk.WithWorkspaceAppStatusSlug(clientCtx, appStatusSlug) + tb = tb.WithAppStatusSlug(appStatusSlug) } // Register tools based on the allowlist (if specified) @@ -427,7 +427,7 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct if len(allowedTools) == 0 || slices.ContainsFunc(allowedTools, func(t string) bool { return t == tool.Tool.Name }) { - mcpSrv.AddTools(mcpFromSDK(tool)) + mcpSrv.AddTools(mcpFromSDK(tool, tb)) } } @@ -435,7 +435,7 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct done := make(chan error) go func() { defer close(done) - srvErr := srv.Listen(clientCtx, invStdin, invStdout) + srvErr := srv.Listen(ctx, invStdin, invStdout) done <- srvErr }() @@ -695,7 +695,7 @@ func getAgentToken(fs afero.Fs) (string, error) { // mcpFromSDK adapts a toolsdk.Tool to go-mcp's server.ServerTool. // It assumes that the tool responds with a valid JSON object. -func mcpFromSDK(sdkTool toolsdk.Tool[any, any]) server.ServerTool { +func mcpFromSDK(sdkTool toolsdk.Tool[any, any], tb toolsdk.Toolbox) server.ServerTool { // NOTE: some clients will silently refuse to use tools if there is an issue // with the tool's schema or configuration. if sdkTool.Schema.Properties == nil { @@ -711,8 +711,8 @@ func mcpFromSDK(sdkTool toolsdk.Tool[any, any]) server.ServerTool { Required: sdkTool.Schema.Required, }, }, - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - result, err := sdkTool.Handler(ctx, request.Params.Arguments) + Handler: func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + result, err := sdkTool.Handler(tb, request.Params.Arguments) if err != nil { return nil, err } diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index 874501367230f..b2b0215136f8f 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -13,8 +13,52 @@ import ( "github.com/coder/coder/v2/codersdk/agentsdk" ) +// Toolbox provides access to tool dependencies. +type Toolbox interface { + CoderClient() *codersdk.Client + AgentClient() (*agentsdk.Client, bool) + AppStatusSlug() (string, bool) + + WithAgentClient(*agentsdk.Client) Toolbox + WithAppStatusSlug(string) Toolbox +} + +// toolbox is the concrete implementation of Toolbox. +type toolbox struct { + coderClient *codersdk.Client + agentClient *agentsdk.Client + appStatusSlug string +} + +// NewToolbox constructs a Toolbox with a required CoderClient. +func NewToolbox(coder *codersdk.Client) Toolbox { + return &toolbox{coderClient: coder} +} + +func (tb *toolbox) CoderClient() *codersdk.Client { + return tb.coderClient +} + +func (tb *toolbox) AgentClient() (*agentsdk.Client, bool) { + return tb.agentClient, tb.agentClient != nil +} + +func (tb *toolbox) AppStatusSlug() (string, bool) { + return tb.appStatusSlug, tb.appStatusSlug != "" +} + +func (tb *toolbox) WithAgentClient(agent *agentsdk.Client) Toolbox { + tb.agentClient = agent + return tb +} + +func (tb *toolbox) WithAppStatusSlug(slug string) Toolbox { + tb.appStatusSlug = slug + return tb +} + // HandlerFunc is a function that handles a tool call. -type HandlerFunc[Arg, Ret any] func(ctx context.Context, args Arg) (Ret, error) +type HandlerFunc[Arg, Ret any] func(tb Toolbox, args Arg) (Ret, error) type Tool[Arg, Ret any] struct { aisdk.Tool @@ -25,12 +69,12 @@ type Tool[Arg, Ret any] struct { func (t Tool[Arg, Ret]) Generic() Tool[any, any] { return Tool[any, any]{ Tool: t.Tool, - Handler: func(ctx context.Context, args any) (any, error) { + Handler: func(tb Toolbox, args any) (any, error) { typedArg, ok := args.(Arg) if !ok { return nil, xerrors.Errorf("developer error: invalid argument type for tool %s", t.Tool.Name) } - return t.Handler(ctx, typedArg) + return t.Handler(tb, typedArg) }, } } @@ -155,17 +199,17 @@ var ( Required: []string{"summary", "link", "state"}, }, }, - Handler: func(ctx context.Context, args ReportTaskArgs) (string, error) { - agentClient, err := agentClientFromContext(ctx) - if err != nil { + Handler: func(tb Toolbox, args ReportTaskArgs) (string, error) { + agentClient, ok := tb.AgentClient() + if !ok { return "", xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set") } - appSlug, ok := workspaceAppStatusSlugFromContext(ctx) + appStatusSlug, ok := tb.AppStatusSlug() if !ok { - return "", xerrors.New("workspace app status slug not found in context") + return "", xerrors.New("workspace app status slug not found in toolbox") } - if err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ - AppSlug: appSlug, + if err := agentClient.PatchAppStatus(context.TODO(), agentsdk.PatchAppStatus{ + AppSlug: appStatusSlug, Message: args.Summary, URI: args.Link, State: codersdk.WorkspaceAppStatusState(args.State), @@ -191,16 +235,12 @@ This returns more data than list_workspaces to reduce token usage.`, Required: []string{"workspace_id"}, }, }, - Handler: func(ctx context.Context, args GetWorkspaceArgs) (codersdk.Workspace, error) { - client, err := clientFromContext(ctx) - if err != nil { - return codersdk.Workspace{}, err - } + Handler: func(tb Toolbox, args GetWorkspaceArgs) (codersdk.Workspace, error) { wsID, err := uuid.Parse(args.WorkspaceID) if err != nil { return codersdk.Workspace{}, xerrors.New("workspace_id must be a valid UUID") } - return client.Workspace(ctx, wsID) + return tb.CoderClient().Workspace(context.TODO(), wsID) }, } @@ -235,11 +275,7 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{"user", "template_version_id", "name", "rich_parameters"}, }, }, - Handler: func(ctx context.Context, args CreateWorkspaceArgs) (codersdk.Workspace, error) { - client, err := clientFromContext(ctx) - if err != nil { - return codersdk.Workspace{}, err - } + Handler: func(tb Toolbox, args CreateWorkspaceArgs) (codersdk.Workspace, error) { tvID, err := uuid.Parse(args.TemplateVersionID) if err != nil { return codersdk.Workspace{}, xerrors.New("template_version_id must be a valid UUID") @@ -254,7 +290,7 @@ is provisioned correctly and the agent can connect to the control plane. Value: v, }) } - workspace, err := client.CreateUserWorkspace(ctx, args.User, codersdk.CreateWorkspaceRequest{ + workspace, err := tb.CoderClient().CreateUserWorkspace(context.TODO(), args.User, codersdk.CreateWorkspaceRequest{ TemplateVersionID: tvID, Name: args.Name, RichParameterValues: buildParams, @@ -279,16 +315,12 @@ is provisioned correctly and the agent can connect to the control plane. }, }, }, - Handler: func(ctx context.Context, args ListWorkspacesArgs) ([]MinimalWorkspace, error) { - client, err := clientFromContext(ctx) - if err != nil { - return nil, err - } + Handler: func(tb Toolbox, args ListWorkspacesArgs) ([]MinimalWorkspace, error) { owner := args.Owner if owner == "" { owner = codersdk.Me } - workspaces, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{ + workspaces, err := tb.CoderClient().Workspaces(context.TODO(), codersdk.WorkspaceFilter{ Owner: owner, }) if err != nil { @@ -320,12 +352,8 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{}, }, }, - Handler: func(ctx context.Context, _ NoArgs) ([]MinimalTemplate, error) { - client, err := clientFromContext(ctx) - if err != nil { - return nil, err - } - templates, err := client.Templates(ctx, codersdk.TemplateFilter{}) + Handler: func(tb Toolbox, _ NoArgs) ([]MinimalTemplate, error) { + templates, err := tb.CoderClient().Templates(context.TODO(), codersdk.TemplateFilter{}) if err != nil { return nil, err } @@ -357,16 +385,12 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{"template_version_id"}, }, }, - Handler: func(ctx context.Context, args ListTemplateVersionParametersArgs) ([]codersdk.TemplateVersionParameter, error) { - client, err := clientFromContext(ctx) - if err != nil { - return nil, err - } + Handler: func(tb Toolbox, args ListTemplateVersionParametersArgs) ([]codersdk.TemplateVersionParameter, error) { templateVersionID, err := uuid.Parse(args.TemplateVersionID) if err != nil { return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) } - parameters, err := client.TemplateVersionRichParameters(ctx, templateVersionID) + parameters, err := tb.CoderClient().TemplateVersionRichParameters(context.TODO(), templateVersionID) if err != nil { return nil, err } @@ -383,12 +407,8 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{}, }, }, - Handler: func(ctx context.Context, _ NoArgs) (codersdk.User, error) { - client, err := clientFromContext(ctx) - if err != nil { - return codersdk.User{}, err - } - return client.User(ctx, "me") + Handler: func(tb Toolbox, _ NoArgs) (codersdk.User, error) { + return tb.CoderClient().User(context.TODO(), "me") }, } @@ -414,11 +434,7 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{"workspace_id", "transition"}, }, }, - Handler: func(ctx context.Context, args CreateWorkspaceBuildArgs) (codersdk.WorkspaceBuild, error) { - client, err := clientFromContext(ctx) - if err != nil { - return codersdk.WorkspaceBuild{}, err - } + Handler: func(tb Toolbox, args CreateWorkspaceBuildArgs) (codersdk.WorkspaceBuild, error) { workspaceID, err := uuid.Parse(args.WorkspaceID) if err != nil { return codersdk.WorkspaceBuild{}, xerrors.Errorf("workspace_id must be a valid UUID: %w", err) @@ -437,7 +453,7 @@ is provisioned correctly and the agent can connect to the control plane. if templateVersionID != uuid.Nil { cbr.TemplateVersionID = templateVersionID } - return client.CreateWorkspaceBuild(ctx, workspaceID, cbr) + return tb.CoderClient().CreateWorkspaceBuild(context.TODO(), workspaceID, cbr) }, } @@ -899,12 +915,8 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"file_id"}, }, }, - Handler: func(ctx context.Context, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) { - client, err := clientFromContext(ctx) - if err != nil { - return codersdk.TemplateVersion{}, err - } - me, err := client.User(ctx, "me") + Handler: func(tb Toolbox, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) { + me, err := tb.CoderClient().User(context.TODO(), "me") if err != nil { return codersdk.TemplateVersion{}, err } @@ -916,7 +928,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t if err != nil { return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err) } - templateVersion, err := client.CreateTemplateVersion(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{ + templateVersion, err := tb.CoderClient().CreateTemplateVersion(context.TODO(), me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{ Message: "Created by AI", StorageMethod: codersdk.ProvisionerStorageMethodFile, FileID: fileID, @@ -945,16 +957,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"workspace_agent_id"}, }, }, - Handler: func(ctx context.Context, args GetWorkspaceAgentLogsArgs) ([]string, error) { - client, err := clientFromContext(ctx) - if err != nil { - return nil, err - } + Handler: func(tb Toolbox, args GetWorkspaceAgentLogsArgs) ([]string, error) { workspaceAgentID, err := uuid.Parse(args.WorkspaceAgentID) if err != nil { return nil, xerrors.Errorf("workspace_agent_id must be a valid UUID: %w", err) } - logs, closer, err := client.WorkspaceAgentLogsAfter(ctx, workspaceAgentID, 0, false) + logs, closer, err := tb.CoderClient().WorkspaceAgentLogsAfter(context.TODO(), workspaceAgentID, 0, false) if err != nil { return nil, err } @@ -984,16 +992,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"workspace_build_id"}, }, }, - Handler: func(ctx context.Context, args GetWorkspaceBuildLogsArgs) ([]string, error) { - client, err := clientFromContext(ctx) - if err != nil { - return nil, err - } + Handler: func(tb Toolbox, args GetWorkspaceBuildLogsArgs) ([]string, error) { workspaceBuildID, err := uuid.Parse(args.WorkspaceBuildID) if err != nil { return nil, xerrors.Errorf("workspace_build_id must be a valid UUID: %w", err) } - logs, closer, err := client.WorkspaceBuildLogsAfter(ctx, workspaceBuildID, 0) + logs, closer, err := tb.CoderClient().WorkspaceBuildLogsAfter(context.TODO(), workspaceBuildID, 0) if err != nil { return nil, err } @@ -1019,17 +1023,13 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"template_version_id"}, }, }, - Handler: func(ctx context.Context, args GetTemplateVersionLogsArgs) ([]string, error) { - client, err := clientFromContext(ctx) - if err != nil { - return nil, err - } + Handler: func(tb Toolbox, args GetTemplateVersionLogsArgs) ([]string, error) { templateVersionID, err := uuid.Parse(args.TemplateVersionID) if err != nil { return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) } - logs, closer, err := client.TemplateVersionLogsAfter(ctx, templateVersionID, 0) + logs, closer, err := tb.CoderClient().TemplateVersionLogsAfter(context.TODO(), templateVersionID, 0) if err != nil { return nil, err } @@ -1058,11 +1058,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"template_id", "template_version_id"}, }, }, - Handler: func(ctx context.Context, args UpdateTemplateActiveVersionArgs) (string, error) { - client, err := clientFromContext(ctx) - if err != nil { - return "", err - } + Handler: func(tb Toolbox, args UpdateTemplateActiveVersionArgs) (string, error) { templateID, err := uuid.Parse(args.TemplateID) if err != nil { return "", xerrors.Errorf("template_id must be a valid UUID: %w", err) @@ -1071,7 +1067,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t if err != nil { return "", xerrors.Errorf("template_version_id must be a valid UUID: %w", err) } - err = client.UpdateActiveTemplateVersion(ctx, templateID, codersdk.UpdateActiveTemplateVersion{ + err = tb.CoderClient().UpdateActiveTemplateVersion(context.TODO(), templateID, codersdk.UpdateActiveTemplateVersion{ ID: templateVersionID, }) if err != nil { @@ -1095,12 +1091,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"mime_type", "files"}, }, }, - Handler: func(ctx context.Context, args UploadTarFileArgs) (codersdk.UploadResponse, error) { - client, err := clientFromContext(ctx) - if err != nil { - return codersdk.UploadResponse{}, err - } - + Handler: func(tb Toolbox, args UploadTarFileArgs) (codersdk.UploadResponse, error) { pipeReader, pipeWriter := io.Pipe() go func() { defer pipeWriter.Close() @@ -1125,7 +1116,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t } }() - resp, err := client.Upload(ctx, codersdk.ContentTypeTar, pipeReader) + resp, err := tb.CoderClient().Upload(context.TODO(), codersdk.ContentTypeTar, pipeReader) if err != nil { return codersdk.UploadResponse{}, err } @@ -1160,12 +1151,8 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"name", "display_name", "description", "version_id"}, }, }, - Handler: func(ctx context.Context, args CreateTemplateArgs) (codersdk.Template, error) { - client, err := clientFromContext(ctx) - if err != nil { - return codersdk.Template{}, err - } - me, err := client.User(ctx, "me") + Handler: func(tb Toolbox, args CreateTemplateArgs) (codersdk.Template, error) { + me, err := tb.CoderClient().User(context.TODO(), "me") if err != nil { return codersdk.Template{}, err } @@ -1173,7 +1160,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t if err != nil { return codersdk.Template{}, xerrors.Errorf("version_id must be a valid UUID: %w", err) } - template, err := client.CreateTemplate(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateRequest{ + template, err := tb.CoderClient().CreateTemplate(context.TODO(), me.OrganizationIDs[0], codersdk.CreateTemplateRequest{ Name: args.Name, DisplayName: args.DisplayName, Description: args.Description, @@ -1198,17 +1185,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t }, }, }, - Handler: func(ctx context.Context, args DeleteTemplateArgs) (string, error) { - client, err := clientFromContext(ctx) - if err != nil { - return "", err - } - + Handler: func(tb Toolbox, args DeleteTemplateArgs) (string, error) { templateID, err := uuid.Parse(args.TemplateID) if err != nil { return "", xerrors.Errorf("template_id must be a valid UUID: %w", err) } - err = client.DeleteTemplate(ctx, templateID) + err = tb.CoderClient().DeleteTemplate(context.TODO(), templateID) if err != nil { return "", err } @@ -1236,45 +1218,3 @@ type MinimalTemplate struct { ActiveVersionID uuid.UUID `json:"active_version_id"` ActiveUserCount int `json:"active_user_count"` } - -func clientFromContext(ctx context.Context) (*codersdk.Client, error) { - client, ok := ctx.Value(clientContextKey{}).(*codersdk.Client) - if !ok { - return nil, xerrors.New("client required in context") - } - return client, nil -} - -type clientContextKey struct{} - -func WithClient(ctx context.Context, client *codersdk.Client) context.Context { - return context.WithValue(ctx, clientContextKey{}, client) -} - -type agentClientContextKey struct{} - -func WithAgentClient(ctx context.Context, client *agentsdk.Client) context.Context { - return context.WithValue(ctx, agentClientContextKey{}, client) -} - -func agentClientFromContext(ctx context.Context) (*agentsdk.Client, error) { - client, ok := ctx.Value(agentClientContextKey{}).(*agentsdk.Client) - if !ok { - return nil, xerrors.New("agent client required in context") - } - return client, nil -} - -type workspaceAppStatusSlugContextKey struct{} - -func WithWorkspaceAppStatusSlug(ctx context.Context, slug string) context.Context { - return context.WithValue(ctx, workspaceAppStatusSlugContextKey{}, slug) -} - -func workspaceAppStatusSlugFromContext(ctx context.Context) (string, bool) { - slug, ok := ctx.Value(workspaceAppStatusSlugContextKey{}).(string) - if !ok || slug == "" { - return "", false - } - return slug, true -} diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index fc13dc462adbb..d71d871961caf 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -68,10 +68,8 @@ func TestTools(t *testing.T) { }) t.Run("ReportTask", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithAgentClient(ctx, agentClient) - ctx = toolsdk.WithWorkspaceAppStatusSlug(ctx, "some-agent-app") - _, err := testTool(ctx, t, toolsdk.ReportTask, toolsdk.ReportTaskArgs{ + tb := toolsdk.NewToolbox(memberClient).WithAgentClient(agentClient).WithAppStatusSlug("some-agent-app") + _, err := testTool(t, toolsdk.ReportTask, tb, toolsdk.ReportTaskArgs{ Summary: "test summary", State: "complete", Link: "https://example.com", @@ -80,10 +78,8 @@ func TestTools(t *testing.T) { }) t.Run("GetWorkspace", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - result, err := testTool(ctx, t, toolsdk.GetWorkspace, toolsdk.GetWorkspaceArgs{ + tb := toolsdk.NewToolbox(memberClient) + result, err := testTool(t, toolsdk.GetWorkspace, tb, toolsdk.GetWorkspaceArgs{ WorkspaceID: r.Workspace.ID.String(), }) @@ -92,14 +88,12 @@ func TestTools(t *testing.T) { }) t.Run("ListTemplates", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - + tb := toolsdk.NewToolbox(memberClient) // Get the templates directly for comparison expected, err := memberClient.Templates(context.Background(), codersdk.TemplateFilter{}) require.NoError(t, err) - result, err := testTool(ctx, t, toolsdk.ListTemplates, toolsdk.NoArgs{}) + result, err := testTool(t, toolsdk.ListTemplates, tb, toolsdk.NoArgs{}) require.NoError(t, err) require.Len(t, result, len(expected)) @@ -117,10 +111,8 @@ func TestTools(t *testing.T) { }) t.Run("Whoami", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - result, err := testTool(ctx, t, toolsdk.GetAuthenticatedUser, toolsdk.NoArgs{}) + tb := toolsdk.NewToolbox(memberClient) + result, err := testTool(t, toolsdk.GetAuthenticatedUser, tb, toolsdk.NoArgs{}) require.NoError(t, err) require.Equal(t, member.ID, result.ID) @@ -128,10 +120,8 @@ func TestTools(t *testing.T) { }) t.Run("ListWorkspaces", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - result, err := testTool(ctx, t, toolsdk.ListWorkspaces, toolsdk.ListWorkspacesArgs{}) + tb := toolsdk.NewToolbox(memberClient) + result, err := testTool(t, toolsdk.ListWorkspaces, tb, toolsdk.ListWorkspacesArgs{}) require.NoError(t, err) require.Len(t, result, 1, "expected 1 workspace") @@ -142,9 +132,8 @@ func TestTools(t *testing.T) { t.Run("CreateWorkspaceBuild", func(t *testing.T) { t.Run("Stop", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - result, err := testTool(ctx, t, toolsdk.CreateWorkspaceBuild, toolsdk.CreateWorkspaceBuildArgs{ + tb := toolsdk.NewToolbox(memberClient) + result, err := testTool(t, toolsdk.CreateWorkspaceBuild, tb, toolsdk.CreateWorkspaceBuildArgs{ WorkspaceID: r.Workspace.ID.String(), Transition: "stop", }) @@ -162,9 +151,8 @@ func TestTools(t *testing.T) { t.Run("Start", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - result, err := testTool(ctx, t, toolsdk.CreateWorkspaceBuild, toolsdk.CreateWorkspaceBuildArgs{ + tb := toolsdk.NewToolbox(memberClient) + result, err := testTool(t, toolsdk.CreateWorkspaceBuild, tb, toolsdk.CreateWorkspaceBuildArgs{ WorkspaceID: r.Workspace.ID.String(), Transition: "start", }) @@ -182,8 +170,7 @@ func TestTools(t *testing.T) { t.Run("TemplateVersionChange", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - + tb := toolsdk.NewToolbox(memberClient) // Get the current template version ID before updating workspace, err := memberClient.Workspace(ctx, r.Workspace.ID) require.NoError(t, err) @@ -199,7 +186,7 @@ func TestTools(t *testing.T) { }).Do() // Update to new version - updateBuild, err := testTool(ctx, t, toolsdk.CreateWorkspaceBuild, toolsdk.CreateWorkspaceBuildArgs{ + updateBuild, err := testTool(t, toolsdk.CreateWorkspaceBuild, tb, toolsdk.CreateWorkspaceBuildArgs{ WorkspaceID: r.Workspace.ID.String(), Transition: "start", TemplateVersionID: newVersion.TemplateVersion.ID.String(), @@ -212,7 +199,7 @@ func TestTools(t *testing.T) { require.NoError(t, client.CancelWorkspaceBuild(ctx, updateBuild.ID)) // Roll back to the original version - rollbackBuild, err := testTool(ctx, t, toolsdk.CreateWorkspaceBuild, toolsdk.CreateWorkspaceBuildArgs{ + rollbackBuild, err := testTool(t, toolsdk.CreateWorkspaceBuild, tb, toolsdk.CreateWorkspaceBuildArgs{ WorkspaceID: r.Workspace.ID.String(), Transition: "start", TemplateVersionID: originalVersionID.String(), @@ -227,10 +214,8 @@ func TestTools(t *testing.T) { }) t.Run("ListTemplateVersionParameters", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - params, err := testTool(ctx, t, toolsdk.ListTemplateVersionParameters, toolsdk.ListTemplateVersionParametersArgs{ + tb := toolsdk.NewToolbox(memberClient) + params, err := testTool(t, toolsdk.ListTemplateVersionParameters, tb, toolsdk.ListTemplateVersionParametersArgs{ TemplateVersionID: r.TemplateVersion.ID.String(), }) @@ -239,10 +224,8 @@ func TestTools(t *testing.T) { }) t.Run("GetWorkspaceAgentLogs", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, client) - - logs, err := testTool(ctx, t, toolsdk.GetWorkspaceAgentLogs, toolsdk.GetWorkspaceAgentLogsArgs{ + tb := toolsdk.NewToolbox(client) + logs, err := testTool(t, toolsdk.GetWorkspaceAgentLogs, tb, toolsdk.GetWorkspaceAgentLogsArgs{ WorkspaceAgentID: agentID.String(), }) @@ -251,10 +234,8 @@ func TestTools(t *testing.T) { }) t.Run("GetWorkspaceBuildLogs", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - logs, err := testTool(ctx, t, toolsdk.GetWorkspaceBuildLogs, toolsdk.GetWorkspaceBuildLogsArgs{ + tb := toolsdk.NewToolbox(memberClient) + logs, err := testTool(t, toolsdk.GetWorkspaceBuildLogs, tb, toolsdk.GetWorkspaceBuildLogsArgs{ WorkspaceBuildID: r.Build.ID.String(), }) @@ -263,10 +244,8 @@ func TestTools(t *testing.T) { }) t.Run("GetTemplateVersionLogs", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - logs, err := testTool(ctx, t, toolsdk.GetTemplateVersionLogs, toolsdk.GetTemplateVersionLogsArgs{ + tb := toolsdk.NewToolbox(memberClient) + logs, err := testTool(t, toolsdk.GetTemplateVersionLogs, tb, toolsdk.GetTemplateVersionLogsArgs{ TemplateVersionID: r.TemplateVersion.ID.String(), }) @@ -275,10 +254,8 @@ func TestTools(t *testing.T) { }) t.Run("UpdateTemplateActiveVersion", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, client) // Use owner client for permission - - result, err := testTool(ctx, t, toolsdk.UpdateTemplateActiveVersion, toolsdk.UpdateTemplateActiveVersionArgs{ + tb := toolsdk.NewToolbox(client) + result, err := testTool(t, toolsdk.UpdateTemplateActiveVersion, tb, toolsdk.UpdateTemplateActiveVersionArgs{ TemplateID: r.Template.ID.String(), TemplateVersionID: r.TemplateVersion.ID.String(), }) @@ -288,10 +265,8 @@ func TestTools(t *testing.T) { }) t.Run("DeleteTemplate", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, client) - - _, err := testTool(ctx, t, toolsdk.DeleteTemplate, toolsdk.DeleteTemplateArgs{ + tb := toolsdk.NewToolbox(client) + _, err := testTool(t, toolsdk.DeleteTemplate, tb, toolsdk.DeleteTemplateArgs{ TemplateID: r.Template.ID.String(), }) @@ -300,14 +275,12 @@ func TestTools(t *testing.T) { }) t.Run("UploadTarFile", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, client) - + tb := toolsdk.NewToolbox(client) files := map[string]string{ "main.tf": `resource "null_resource" "example" {}`, } - result, err := testTool(ctx, t, toolsdk.UploadTarFile, toolsdk.UploadTarFileArgs{ + result, err := testTool(t, toolsdk.UploadTarFile, tb, toolsdk.UploadTarFileArgs{ Files: files, }) @@ -316,13 +289,11 @@ func TestTools(t *testing.T) { }) t.Run("CreateTemplateVersion", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, client) - + tb := toolsdk.NewToolbox(client) // nolint:gocritic // This is in a test package and does not end up in the build file := dbgen.File(t, store, database.File{}) - tv, err := testTool(ctx, t, toolsdk.CreateTemplateVersion, toolsdk.CreateTemplateVersionArgs{ + tv, err := testTool(t, toolsdk.CreateTemplateVersion, tb, toolsdk.CreateTemplateVersionArgs{ FileID: file.ID.String(), TemplateID: r.Template.ID.String(), }) @@ -331,9 +302,7 @@ func TestTools(t *testing.T) { }) t.Run("CreateTemplate", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, client) - + tb := toolsdk.NewToolbox(client) // Create a new template version for use here. tv := dbfake.TemplateVersion(t, store). // nolint:gocritic // This is in a test package and does not end up in the build @@ -341,7 +310,7 @@ func TestTools(t *testing.T) { SkipCreateTemplate().Do() // We're going to re-use the pre-existing template version - _, err := testTool(ctx, t, toolsdk.CreateTemplate, toolsdk.CreateTemplateArgs{ + _, err := testTool(t, toolsdk.CreateTemplate, tb, toolsdk.CreateTemplateArgs{ Name: testutil.GetRandomNameHyphenated(t), DisplayName: "Test Template", Description: "This is a test template", @@ -352,11 +321,9 @@ func TestTools(t *testing.T) { }) t.Run("CreateWorkspace", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - + tb := toolsdk.NewToolbox(memberClient) // We need a template version ID to create a workspace - res, err := testTool(ctx, t, toolsdk.CreateWorkspace, toolsdk.CreateWorkspaceArgs{ + res, err := testTool(t, toolsdk.CreateWorkspace, tb, toolsdk.CreateWorkspaceArgs{ User: "me", TemplateVersionID: r.TemplateVersion.ID.String(), Name: testutil.GetRandomNameHyphenated(t), @@ -374,10 +341,10 @@ func TestTools(t *testing.T) { var testedTools sync.Map // testTool is a helper function to test a tool and mark it as tested. -func testTool[Arg, Ret any](ctx context.Context, t *testing.T, tool toolsdk.Tool[Arg, Ret], args Arg) (Ret, error) { +func testTool[Arg, Ret any](t *testing.T, tool toolsdk.Tool[Arg, Ret], tb toolsdk.Toolbox, args Arg) (Ret, error) { t.Helper() testedTools.Store(tool.Tool.Name, true) - result, err := tool.Handler(ctx, args) + result, err := tool.Handler(tb, args) return result, err } From 5b30ea08e67eea3b060a498def761362e9f317c2 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 24 Apr 2025 16:19:13 +0100 Subject: [PATCH 03/11] chore(codersdk/toolsdk): add panic recovery tool middleware --- codersdk/toolsdk/toolsdk.go | 25 +++++++++++-- codersdk/toolsdk/toolsdk_test.go | 60 ++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index b2b0215136f8f..2c9a00255cc9c 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -150,10 +150,31 @@ type UploadTarFileArgs struct { Files map[string]string `json:"files"` } +// WithRecover wraps a HandlerFunc to recover from panics and return an error. +func WithRecover[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] { + return func(tb Toolbox, args Arg) (ret Ret, err error) { + defer func() { + if r := recover(); r != nil { + err = xerrors.Errorf("tool handler panic: %v", r) + } + }() + return h(tb, args) + } +} + +// wrapAll wraps all provided tools with the given middleware function. +func wrapAll(mw func(HandlerFunc[any, any]) HandlerFunc[any, any], tools ...Tool[any, any]) []Tool[any, any] { + for i, t := range tools { + t.Handler = mw(t.Handler) + tools[i] = t + } + return tools +} + var ( // All is a list of all tools that can be used in the Coder CLI. // When you add a new tool, be sure to include it here! - All = []Tool[any, any]{ + All = wrapAll(WithRecover, CreateTemplate.Generic(), CreateTemplateVersion.Generic(), CreateWorkspace.Generic(), @@ -170,7 +191,7 @@ var ( ReportTask.Generic(), UploadTarFile.Generic(), UpdateTemplateActiveVersion.Generic(), - } + ) ReportTask = Tool[ReportTaskArgs, string]{ Tool: aisdk.Tool{ diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index d71d871961caf..64272015c6f60 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -9,6 +9,8 @@ import ( "time" "github.com/google/uuid" + "github.com/kylecarbs/aisdk-go" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/coderdtest" @@ -379,3 +381,61 @@ func TestMain(m *testing.M) { os.Exit(code) } + +func TestWithRecovery(t *testing.T) { + t.Parallel() + t.Run("OK", func(t *testing.T) { + t.Parallel() + fakeTool := toolsdk.Tool[string, string]{ + Tool: aisdk.Tool{ + Name: "fake_tool", + Description: "Returns a string for testing.", + }, + Handler: func(tb toolsdk.Toolbox, args string) (string, error) { + require.Equal(t, "test", args) + return "ok", nil + }, + } + + wrapped := toolsdk.WithRecover(fakeTool.Handler) + v, err := wrapped(nil, "test") + require.NoError(t, err) + require.Equal(t, "ok", v) + }) + + t.Run("Error", func(t *testing.T) { + t.Parallel() + fakeTool := toolsdk.Tool[string, string]{ + Tool: aisdk.Tool{ + Name: "fake_tool", + Description: "Returns an error for testing.", + }, + Handler: func(tb toolsdk.Toolbox, args string) (string, error) { + require.Equal(t, "test", args) + return "", assert.AnError + }, + } + wrapped := toolsdk.WithRecover(fakeTool.Handler) + v, err := wrapped(nil, "test") + require.Empty(t, v) + require.ErrorIs(t, err, assert.AnError) + }) + + t.Run("Panic", func(t *testing.T) { + t.Parallel() + panicTool := toolsdk.Tool[string, string]{ + Tool: aisdk.Tool{ + Name: "panic_tool", + Description: "Panics for testing.", + }, + Handler: func(tb toolsdk.Toolbox, args string) (string, error) { + panic("you can't sweat this fever out") + }, + } + + wrapped := toolsdk.WithRecover(panicTool.Handler) + v, err := wrapped(nil, "disco") + require.Empty(t, v) + require.ErrorContains(t, err, "you can't sweat this fever out") + }) +} From def3fcbce0e72c0d2a9816f9485dd39c814046b2 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 24 Apr 2025 17:00:29 +0100 Subject: [PATCH 04/11] chore(coderd): improve tests for patchWorkspaceAgentAppStatus --- coderd/workspaceagents.go | 4 +- coderd/workspaceagents_test.go | 80 ++++++++++++++++++++++++++-------- 2 files changed, 63 insertions(+), 21 deletions(-) diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 980db23e5789f..c8d75d0d4f313 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -338,9 +338,9 @@ func (api *API) patchWorkspaceAgentAppStatus(rw http.ResponseWriter, r *http.Req Slug: req.AppSlug, }) if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "Failed to get workspace app.", - Detail: err.Error(), + Detail: fmt.Sprintf("No app found with slug %q", req.AppSlug), }) return } diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index a6e10ea5fdabf..2f95ef9727ca5 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -341,27 +341,27 @@ func TestWorkspaceAgentLogs(t *testing.T) { func TestWorkspaceAgentAppStatus(t *testing.T) { t.Parallel() - t.Run("Success", func(t *testing.T) { - t.Parallel() - ctx := testutil.Context(t, testutil.WaitMedium) - client, db := coderdtest.NewWithDatabase(t, nil) - user := coderdtest.CreateFirstUser(t, client) - client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID) + ctx := testutil.Context(t, testutil.WaitMedium) + client, db := coderdtest.NewWithDatabase(t, nil) + user := coderdtest.CreateFirstUser(t, client) + client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID) - r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ - OrganizationID: user.OrganizationID, - OwnerID: user2.ID, - }).WithAgent(func(a []*proto.Agent) []*proto.Agent { - a[0].Apps = []*proto.App{ - { - Slug: "vscode", - }, - } - return a - }).Do() + r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: user.OrganizationID, + OwnerID: user2.ID, + }).WithAgent(func(a []*proto.Agent) []*proto.Agent { + a[0].Apps = []*proto.App{ + { + Slug: "vscode", + }, + } + return a + }).Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL) + agentClient.SetSessionToken(r.AgentToken) + t.Run("Success", func(t *testing.T) { + t.Parallel() err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ AppSlug: "vscode", Message: "testing", @@ -382,6 +382,48 @@ func TestWorkspaceAgentAppStatus(t *testing.T) { require.Empty(t, agent.Apps[0].Statuses[0].Icon) require.False(t, agent.Apps[0].Statuses[0].NeedsUserAttention) }) + + t.Run("FailUnknownApp", func(t *testing.T) { + t.Parallel() + err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ + AppSlug: "unknown", + Message: "testing", + URI: "https://example.com", + State: codersdk.WorkspaceAppStatusStateComplete, + }) + require.ErrorContains(t, err, "No app found with slug") + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + }) + + t.Run("FailUnknownState", func(t *testing.T) { + t.Parallel() + err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ + AppSlug: "vscode", + Message: "testing", + URI: "https://example.com", + State: "unknown", + }) + require.ErrorContains(t, err, "Invalid state") + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + }) + + t.Run("FailTooLong", func(t *testing.T) { + t.Parallel() + err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ + AppSlug: "vscode", + Message: strings.Repeat("a", 161), + URI: "https://example.com", + State: codersdk.WorkspaceAppStatusStateComplete, + }) + require.ErrorContains(t, err, "Message is too long") + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + }) } func TestWorkspaceAgentConnectRPC(t *testing.T) { From 5647b8b9489bfb99a051209fd7f7afa005604195 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 24 Apr 2025 18:04:36 +0100 Subject: [PATCH 05/11] simplify toolbox impl, rename to deps --- cli/exp_mcp.go | 10 ++- codersdk/toolsdk/toolsdk.go | 129 +++++++++++-------------------- codersdk/toolsdk/toolsdk_test.go | 54 +++++++------ 3 files changed, 80 insertions(+), 113 deletions(-) diff --git a/cli/exp_mcp.go b/cli/exp_mcp.go index 57adc4ffa21ad..18314cbbfbd30 100644 --- a/cli/exp_mcp.go +++ b/cli/exp_mcp.go @@ -400,21 +400,23 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct ) // Create a new context for the tools with all relevant information. - tb := toolsdk.NewToolbox(client) + tb := toolsdk.Deps{ + CoderClient: client, + } // Get the workspace agent token from the environment. var hasAgentClient bool if agentToken, err := getAgentToken(fs); err == nil && agentToken != "" { hasAgentClient = true agentClient := agentsdk.New(client.URL) agentClient.SetSessionToken(agentToken) - tb = tb.WithAgentClient(agentClient) + tb.AgentClient = agentClient } else { cliui.Warnf(inv.Stderr, "CODER_AGENT_TOKEN is not set, task reporting will not be available") } if appStatusSlug == "" { cliui.Warnf(inv.Stderr, "CODER_MCP_APP_STATUS_SLUG is not set, task reporting will not be available.") } else { - tb = tb.WithAppStatusSlug(appStatusSlug) + tb.AppStatusSlug = appStatusSlug } // Register tools based on the allowlist (if specified) @@ -695,7 +697,7 @@ func getAgentToken(fs afero.Fs) (string, error) { // mcpFromSDK adapts a toolsdk.Tool to go-mcp's server.ServerTool. // It assumes that the tool responds with a valid JSON object. -func mcpFromSDK(sdkTool toolsdk.Tool[any, any], tb toolsdk.Toolbox) server.ServerTool { +func mcpFromSDK(sdkTool toolsdk.Tool[any, any], tb toolsdk.Deps) server.ServerTool { // NOTE: some clients will silently refuse to use tools if there is an issue // with the tool's schema or configuration. if sdkTool.Schema.Properties == nil { diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index 2c9a00255cc9c..be3c1cd530b5b 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -13,52 +13,15 @@ import ( "github.com/coder/coder/v2/codersdk/agentsdk" ) -// Toolbox provides access to tool dependencies. -type Toolbox interface { - CoderClient() *codersdk.Client - AgentClient() (*agentsdk.Client, bool) - AppStatusSlug() (string, bool) - - WithAgentClient(*agentsdk.Client) Toolbox - WithAppStatusSlug(string) Toolbox -} - -// toolbox is the concrete implementation of Toolbox. -type toolbox struct { - coderClient *codersdk.Client - agentClient *agentsdk.Client - appStatusSlug string -} - -// NewToolbox constructs a Toolbox with a required CoderClient. -func NewToolbox(coder *codersdk.Client) Toolbox { - return &toolbox{coderClient: coder} -} - -func (tb *toolbox) CoderClient() *codersdk.Client { - return tb.coderClient -} - -func (tb *toolbox) AgentClient() (*agentsdk.Client, bool) { - return tb.agentClient, tb.agentClient != nil -} - -func (tb *toolbox) AppStatusSlug() (string, bool) { - return tb.appStatusSlug, tb.appStatusSlug != "" -} - -func (tb *toolbox) WithAgentClient(agent *agentsdk.Client) Toolbox { - tb.agentClient = agent - return tb -} - -func (tb *toolbox) WithAppStatusSlug(slug string) Toolbox { - tb.appStatusSlug = slug - return tb +// Deps provides access to tool dependencies. +type Deps struct { + CoderClient *codersdk.Client + AgentClient *agentsdk.Client + AppStatusSlug string } // HandlerFunc is a function that handles a tool call. -type HandlerFunc[Arg, Ret any] func(tb Toolbox, args Arg) (Ret, error) +type HandlerFunc[Arg, Ret any] func(tb Deps, args Arg) (Ret, error) type Tool[Arg, Ret any] struct { aisdk.Tool @@ -69,7 +32,7 @@ type Tool[Arg, Ret any] struct { func (t Tool[Arg, Ret]) Generic() Tool[any, any] { return Tool[any, any]{ Tool: t.Tool, - Handler: func(tb Toolbox, args any) (any, error) { + Handler: func(tb Deps, args any) (any, error) { typedArg, ok := args.(Arg) if !ok { return nil, xerrors.Errorf("developer error: invalid argument type for tool %s", t.Tool.Name) @@ -152,7 +115,7 @@ type UploadTarFileArgs struct { // WithRecover wraps a HandlerFunc to recover from panics and return an error. func WithRecover[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] { - return func(tb Toolbox, args Arg) (ret Ret, err error) { + return func(tb Deps, args Arg) (ret Ret, err error) { defer func() { if r := recover(); r != nil { err = xerrors.Errorf("tool handler panic: %v", r) @@ -220,17 +183,15 @@ var ( Required: []string{"summary", "link", "state"}, }, }, - Handler: func(tb Toolbox, args ReportTaskArgs) (string, error) { - agentClient, ok := tb.AgentClient() - if !ok { + Handler: func(tb Deps, args ReportTaskArgs) (string, error) { + if tb.AgentClient == nil { return "", xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set") } - appStatusSlug, ok := tb.AppStatusSlug() - if !ok { + if tb.AppStatusSlug == "" { return "", xerrors.New("workspace app status slug not found in toolbox") } - if err := agentClient.PatchAppStatus(context.TODO(), agentsdk.PatchAppStatus{ - AppSlug: appStatusSlug, + if err := tb.AgentClient.PatchAppStatus(context.TODO(), agentsdk.PatchAppStatus{ + AppSlug: tb.AppStatusSlug, Message: args.Summary, URI: args.Link, State: codersdk.WorkspaceAppStatusState(args.State), @@ -256,12 +217,12 @@ This returns more data than list_workspaces to reduce token usage.`, Required: []string{"workspace_id"}, }, }, - Handler: func(tb Toolbox, args GetWorkspaceArgs) (codersdk.Workspace, error) { + Handler: func(tb Deps, args GetWorkspaceArgs) (codersdk.Workspace, error) { wsID, err := uuid.Parse(args.WorkspaceID) if err != nil { return codersdk.Workspace{}, xerrors.New("workspace_id must be a valid UUID") } - return tb.CoderClient().Workspace(context.TODO(), wsID) + return tb.CoderClient.Workspace(context.TODO(), wsID) }, } @@ -296,7 +257,7 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{"user", "template_version_id", "name", "rich_parameters"}, }, }, - Handler: func(tb Toolbox, args CreateWorkspaceArgs) (codersdk.Workspace, error) { + Handler: func(tb Deps, args CreateWorkspaceArgs) (codersdk.Workspace, error) { tvID, err := uuid.Parse(args.TemplateVersionID) if err != nil { return codersdk.Workspace{}, xerrors.New("template_version_id must be a valid UUID") @@ -311,7 +272,7 @@ is provisioned correctly and the agent can connect to the control plane. Value: v, }) } - workspace, err := tb.CoderClient().CreateUserWorkspace(context.TODO(), args.User, codersdk.CreateWorkspaceRequest{ + workspace, err := tb.CoderClient.CreateUserWorkspace(context.TODO(), args.User, codersdk.CreateWorkspaceRequest{ TemplateVersionID: tvID, Name: args.Name, RichParameterValues: buildParams, @@ -336,12 +297,12 @@ is provisioned correctly and the agent can connect to the control plane. }, }, }, - Handler: func(tb Toolbox, args ListWorkspacesArgs) ([]MinimalWorkspace, error) { + Handler: func(tb Deps, args ListWorkspacesArgs) ([]MinimalWorkspace, error) { owner := args.Owner if owner == "" { owner = codersdk.Me } - workspaces, err := tb.CoderClient().Workspaces(context.TODO(), codersdk.WorkspaceFilter{ + workspaces, err := tb.CoderClient.Workspaces(context.TODO(), codersdk.WorkspaceFilter{ Owner: owner, }) if err != nil { @@ -373,8 +334,8 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{}, }, }, - Handler: func(tb Toolbox, _ NoArgs) ([]MinimalTemplate, error) { - templates, err := tb.CoderClient().Templates(context.TODO(), codersdk.TemplateFilter{}) + Handler: func(tb Deps, _ NoArgs) ([]MinimalTemplate, error) { + templates, err := tb.CoderClient.Templates(context.TODO(), codersdk.TemplateFilter{}) if err != nil { return nil, err } @@ -406,12 +367,12 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{"template_version_id"}, }, }, - Handler: func(tb Toolbox, args ListTemplateVersionParametersArgs) ([]codersdk.TemplateVersionParameter, error) { + Handler: func(tb Deps, args ListTemplateVersionParametersArgs) ([]codersdk.TemplateVersionParameter, error) { templateVersionID, err := uuid.Parse(args.TemplateVersionID) if err != nil { return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) } - parameters, err := tb.CoderClient().TemplateVersionRichParameters(context.TODO(), templateVersionID) + parameters, err := tb.CoderClient.TemplateVersionRichParameters(context.TODO(), templateVersionID) if err != nil { return nil, err } @@ -428,8 +389,8 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{}, }, }, - Handler: func(tb Toolbox, _ NoArgs) (codersdk.User, error) { - return tb.CoderClient().User(context.TODO(), "me") + Handler: func(tb Deps, _ NoArgs) (codersdk.User, error) { + return tb.CoderClient.User(context.TODO(), "me") }, } @@ -455,7 +416,7 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{"workspace_id", "transition"}, }, }, - Handler: func(tb Toolbox, args CreateWorkspaceBuildArgs) (codersdk.WorkspaceBuild, error) { + Handler: func(tb Deps, args CreateWorkspaceBuildArgs) (codersdk.WorkspaceBuild, error) { workspaceID, err := uuid.Parse(args.WorkspaceID) if err != nil { return codersdk.WorkspaceBuild{}, xerrors.Errorf("workspace_id must be a valid UUID: %w", err) @@ -474,7 +435,7 @@ is provisioned correctly and the agent can connect to the control plane. if templateVersionID != uuid.Nil { cbr.TemplateVersionID = templateVersionID } - return tb.CoderClient().CreateWorkspaceBuild(context.TODO(), workspaceID, cbr) + return tb.CoderClient.CreateWorkspaceBuild(context.TODO(), workspaceID, cbr) }, } @@ -936,8 +897,8 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"file_id"}, }, }, - Handler: func(tb Toolbox, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) { - me, err := tb.CoderClient().User(context.TODO(), "me") + Handler: func(tb Deps, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) { + me, err := tb.CoderClient.User(context.TODO(), "me") if err != nil { return codersdk.TemplateVersion{}, err } @@ -949,7 +910,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t if err != nil { return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err) } - templateVersion, err := tb.CoderClient().CreateTemplateVersion(context.TODO(), me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{ + templateVersion, err := tb.CoderClient.CreateTemplateVersion(context.TODO(), me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{ Message: "Created by AI", StorageMethod: codersdk.ProvisionerStorageMethodFile, FileID: fileID, @@ -978,12 +939,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"workspace_agent_id"}, }, }, - Handler: func(tb Toolbox, args GetWorkspaceAgentLogsArgs) ([]string, error) { + Handler: func(tb Deps, args GetWorkspaceAgentLogsArgs) ([]string, error) { workspaceAgentID, err := uuid.Parse(args.WorkspaceAgentID) if err != nil { return nil, xerrors.Errorf("workspace_agent_id must be a valid UUID: %w", err) } - logs, closer, err := tb.CoderClient().WorkspaceAgentLogsAfter(context.TODO(), workspaceAgentID, 0, false) + logs, closer, err := tb.CoderClient.WorkspaceAgentLogsAfter(context.TODO(), workspaceAgentID, 0, false) if err != nil { return nil, err } @@ -1013,12 +974,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"workspace_build_id"}, }, }, - Handler: func(tb Toolbox, args GetWorkspaceBuildLogsArgs) ([]string, error) { + Handler: func(tb Deps, args GetWorkspaceBuildLogsArgs) ([]string, error) { workspaceBuildID, err := uuid.Parse(args.WorkspaceBuildID) if err != nil { return nil, xerrors.Errorf("workspace_build_id must be a valid UUID: %w", err) } - logs, closer, err := tb.CoderClient().WorkspaceBuildLogsAfter(context.TODO(), workspaceBuildID, 0) + logs, closer, err := tb.CoderClient.WorkspaceBuildLogsAfter(context.TODO(), workspaceBuildID, 0) if err != nil { return nil, err } @@ -1044,13 +1005,13 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"template_version_id"}, }, }, - Handler: func(tb Toolbox, args GetTemplateVersionLogsArgs) ([]string, error) { + Handler: func(tb Deps, args GetTemplateVersionLogsArgs) ([]string, error) { templateVersionID, err := uuid.Parse(args.TemplateVersionID) if err != nil { return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) } - logs, closer, err := tb.CoderClient().TemplateVersionLogsAfter(context.TODO(), templateVersionID, 0) + logs, closer, err := tb.CoderClient.TemplateVersionLogsAfter(context.TODO(), templateVersionID, 0) if err != nil { return nil, err } @@ -1079,7 +1040,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"template_id", "template_version_id"}, }, }, - Handler: func(tb Toolbox, args UpdateTemplateActiveVersionArgs) (string, error) { + Handler: func(tb Deps, args UpdateTemplateActiveVersionArgs) (string, error) { templateID, err := uuid.Parse(args.TemplateID) if err != nil { return "", xerrors.Errorf("template_id must be a valid UUID: %w", err) @@ -1088,7 +1049,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t if err != nil { return "", xerrors.Errorf("template_version_id must be a valid UUID: %w", err) } - err = tb.CoderClient().UpdateActiveTemplateVersion(context.TODO(), templateID, codersdk.UpdateActiveTemplateVersion{ + err = tb.CoderClient.UpdateActiveTemplateVersion(context.TODO(), templateID, codersdk.UpdateActiveTemplateVersion{ ID: templateVersionID, }) if err != nil { @@ -1112,7 +1073,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"mime_type", "files"}, }, }, - Handler: func(tb Toolbox, args UploadTarFileArgs) (codersdk.UploadResponse, error) { + Handler: func(tb Deps, args UploadTarFileArgs) (codersdk.UploadResponse, error) { pipeReader, pipeWriter := io.Pipe() go func() { defer pipeWriter.Close() @@ -1137,7 +1098,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t } }() - resp, err := tb.CoderClient().Upload(context.TODO(), codersdk.ContentTypeTar, pipeReader) + resp, err := tb.CoderClient.Upload(context.TODO(), codersdk.ContentTypeTar, pipeReader) if err != nil { return codersdk.UploadResponse{}, err } @@ -1172,8 +1133,8 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"name", "display_name", "description", "version_id"}, }, }, - Handler: func(tb Toolbox, args CreateTemplateArgs) (codersdk.Template, error) { - me, err := tb.CoderClient().User(context.TODO(), "me") + Handler: func(tb Deps, args CreateTemplateArgs) (codersdk.Template, error) { + me, err := tb.CoderClient.User(context.TODO(), "me") if err != nil { return codersdk.Template{}, err } @@ -1181,7 +1142,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t if err != nil { return codersdk.Template{}, xerrors.Errorf("version_id must be a valid UUID: %w", err) } - template, err := tb.CoderClient().CreateTemplate(context.TODO(), me.OrganizationIDs[0], codersdk.CreateTemplateRequest{ + template, err := tb.CoderClient.CreateTemplate(context.TODO(), me.OrganizationIDs[0], codersdk.CreateTemplateRequest{ Name: args.Name, DisplayName: args.DisplayName, Description: args.Description, @@ -1206,12 +1167,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t }, }, }, - Handler: func(tb Toolbox, args DeleteTemplateArgs) (string, error) { + Handler: func(tb Deps, args DeleteTemplateArgs) (string, error) { templateID, err := uuid.Parse(args.TemplateID) if err != nil { return "", xerrors.Errorf("template_id must be a valid UUID: %w", err) } - err = tb.CoderClient().DeleteTemplate(context.TODO(), templateID) + err = tb.CoderClient.DeleteTemplate(context.TODO(), templateID) if err != nil { return "", err } diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index 64272015c6f60..f9c121ebe9929 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -70,7 +70,11 @@ func TestTools(t *testing.T) { }) t.Run("ReportTask", func(t *testing.T) { - tb := toolsdk.NewToolbox(memberClient).WithAgentClient(agentClient).WithAppStatusSlug("some-agent-app") + tb := toolsdk.Deps{ + CoderClient: memberClient, + AgentClient: agentClient, + AppStatusSlug: "some-agent-app", + } _, err := testTool(t, toolsdk.ReportTask, tb, toolsdk.ReportTaskArgs{ Summary: "test summary", State: "complete", @@ -80,7 +84,7 @@ func TestTools(t *testing.T) { }) t.Run("GetWorkspace", func(t *testing.T) { - tb := toolsdk.NewToolbox(memberClient) + tb := toolsdk.Deps{CoderClient: memberClient} result, err := testTool(t, toolsdk.GetWorkspace, tb, toolsdk.GetWorkspaceArgs{ WorkspaceID: r.Workspace.ID.String(), }) @@ -90,7 +94,7 @@ func TestTools(t *testing.T) { }) t.Run("ListTemplates", func(t *testing.T) { - tb := toolsdk.NewToolbox(memberClient) + tb := toolsdk.Deps{CoderClient: memberClient} // Get the templates directly for comparison expected, err := memberClient.Templates(context.Background(), codersdk.TemplateFilter{}) require.NoError(t, err) @@ -113,7 +117,7 @@ func TestTools(t *testing.T) { }) t.Run("Whoami", func(t *testing.T) { - tb := toolsdk.NewToolbox(memberClient) + tb := toolsdk.Deps{CoderClient: memberClient} result, err := testTool(t, toolsdk.GetAuthenticatedUser, tb, toolsdk.NoArgs{}) require.NoError(t, err) @@ -122,7 +126,7 @@ func TestTools(t *testing.T) { }) t.Run("ListWorkspaces", func(t *testing.T) { - tb := toolsdk.NewToolbox(memberClient) + tb := toolsdk.Deps{CoderClient: memberClient} result, err := testTool(t, toolsdk.ListWorkspaces, tb, toolsdk.ListWorkspacesArgs{}) require.NoError(t, err) @@ -134,7 +138,7 @@ func TestTools(t *testing.T) { t.Run("CreateWorkspaceBuild", func(t *testing.T) { t.Run("Stop", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) - tb := toolsdk.NewToolbox(memberClient) + tb := toolsdk.Deps{CoderClient: memberClient} result, err := testTool(t, toolsdk.CreateWorkspaceBuild, tb, toolsdk.CreateWorkspaceBuildArgs{ WorkspaceID: r.Workspace.ID.String(), Transition: "stop", @@ -153,7 +157,7 @@ func TestTools(t *testing.T) { t.Run("Start", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) - tb := toolsdk.NewToolbox(memberClient) + tb := toolsdk.Deps{CoderClient: memberClient} result, err := testTool(t, toolsdk.CreateWorkspaceBuild, tb, toolsdk.CreateWorkspaceBuildArgs{ WorkspaceID: r.Workspace.ID.String(), Transition: "start", @@ -172,7 +176,7 @@ func TestTools(t *testing.T) { t.Run("TemplateVersionChange", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) - tb := toolsdk.NewToolbox(memberClient) + tb := toolsdk.Deps{CoderClient: memberClient} // Get the current template version ID before updating workspace, err := memberClient.Workspace(ctx, r.Workspace.ID) require.NoError(t, err) @@ -216,7 +220,7 @@ func TestTools(t *testing.T) { }) t.Run("ListTemplateVersionParameters", func(t *testing.T) { - tb := toolsdk.NewToolbox(memberClient) + tb := toolsdk.Deps{CoderClient: memberClient} params, err := testTool(t, toolsdk.ListTemplateVersionParameters, tb, toolsdk.ListTemplateVersionParametersArgs{ TemplateVersionID: r.TemplateVersion.ID.String(), }) @@ -226,7 +230,7 @@ func TestTools(t *testing.T) { }) t.Run("GetWorkspaceAgentLogs", func(t *testing.T) { - tb := toolsdk.NewToolbox(client) + tb := toolsdk.Deps{CoderClient: client} logs, err := testTool(t, toolsdk.GetWorkspaceAgentLogs, tb, toolsdk.GetWorkspaceAgentLogsArgs{ WorkspaceAgentID: agentID.String(), }) @@ -236,7 +240,7 @@ func TestTools(t *testing.T) { }) t.Run("GetWorkspaceBuildLogs", func(t *testing.T) { - tb := toolsdk.NewToolbox(memberClient) + tb := toolsdk.Deps{CoderClient: memberClient} logs, err := testTool(t, toolsdk.GetWorkspaceBuildLogs, tb, toolsdk.GetWorkspaceBuildLogsArgs{ WorkspaceBuildID: r.Build.ID.String(), }) @@ -246,7 +250,7 @@ func TestTools(t *testing.T) { }) t.Run("GetTemplateVersionLogs", func(t *testing.T) { - tb := toolsdk.NewToolbox(memberClient) + tb := toolsdk.Deps{CoderClient: memberClient} logs, err := testTool(t, toolsdk.GetTemplateVersionLogs, tb, toolsdk.GetTemplateVersionLogsArgs{ TemplateVersionID: r.TemplateVersion.ID.String(), }) @@ -256,7 +260,7 @@ func TestTools(t *testing.T) { }) t.Run("UpdateTemplateActiveVersion", func(t *testing.T) { - tb := toolsdk.NewToolbox(client) + tb := toolsdk.Deps{CoderClient: client} result, err := testTool(t, toolsdk.UpdateTemplateActiveVersion, tb, toolsdk.UpdateTemplateActiveVersionArgs{ TemplateID: r.Template.ID.String(), TemplateVersionID: r.TemplateVersion.ID.String(), @@ -267,7 +271,7 @@ func TestTools(t *testing.T) { }) t.Run("DeleteTemplate", func(t *testing.T) { - tb := toolsdk.NewToolbox(client) + tb := toolsdk.Deps{CoderClient: client} _, err := testTool(t, toolsdk.DeleteTemplate, tb, toolsdk.DeleteTemplateArgs{ TemplateID: r.Template.ID.String(), }) @@ -277,7 +281,7 @@ func TestTools(t *testing.T) { }) t.Run("UploadTarFile", func(t *testing.T) { - tb := toolsdk.NewToolbox(client) + tb := toolsdk.Deps{CoderClient: client} files := map[string]string{ "main.tf": `resource "null_resource" "example" {}`, } @@ -291,7 +295,7 @@ func TestTools(t *testing.T) { }) t.Run("CreateTemplateVersion", func(t *testing.T) { - tb := toolsdk.NewToolbox(client) + tb := toolsdk.Deps{CoderClient: client} // nolint:gocritic // This is in a test package and does not end up in the build file := dbgen.File(t, store, database.File{}) @@ -304,7 +308,7 @@ func TestTools(t *testing.T) { }) t.Run("CreateTemplate", func(t *testing.T) { - tb := toolsdk.NewToolbox(client) + tb := toolsdk.Deps{CoderClient: client} // Create a new template version for use here. tv := dbfake.TemplateVersion(t, store). // nolint:gocritic // This is in a test package and does not end up in the build @@ -323,7 +327,7 @@ func TestTools(t *testing.T) { }) t.Run("CreateWorkspace", func(t *testing.T) { - tb := toolsdk.NewToolbox(memberClient) + tb := toolsdk.Deps{CoderClient: memberClient} // We need a template version ID to create a workspace res, err := testTool(t, toolsdk.CreateWorkspace, tb, toolsdk.CreateWorkspaceArgs{ User: "me", @@ -343,7 +347,7 @@ func TestTools(t *testing.T) { var testedTools sync.Map // testTool is a helper function to test a tool and mark it as tested. -func testTool[Arg, Ret any](t *testing.T, tool toolsdk.Tool[Arg, Ret], tb toolsdk.Toolbox, args Arg) (Ret, error) { +func testTool[Arg, Ret any](t *testing.T, tool toolsdk.Tool[Arg, Ret], tb toolsdk.Deps, args Arg) (Ret, error) { t.Helper() testedTools.Store(tool.Tool.Name, true) result, err := tool.Handler(tb, args) @@ -391,14 +395,14 @@ func TestWithRecovery(t *testing.T) { Name: "fake_tool", Description: "Returns a string for testing.", }, - Handler: func(tb toolsdk.Toolbox, args string) (string, error) { + Handler: func(tb toolsdk.Deps, args string) (string, error) { require.Equal(t, "test", args) return "ok", nil }, } wrapped := toolsdk.WithRecover(fakeTool.Handler) - v, err := wrapped(nil, "test") + v, err := wrapped(toolsdk.Deps{}, "test") require.NoError(t, err) require.Equal(t, "ok", v) }) @@ -410,13 +414,13 @@ func TestWithRecovery(t *testing.T) { Name: "fake_tool", Description: "Returns an error for testing.", }, - Handler: func(tb toolsdk.Toolbox, args string) (string, error) { + Handler: func(tb toolsdk.Deps, args string) (string, error) { require.Equal(t, "test", args) return "", assert.AnError }, } wrapped := toolsdk.WithRecover(fakeTool.Handler) - v, err := wrapped(nil, "test") + v, err := wrapped(toolsdk.Deps{}, "test") require.Empty(t, v) require.ErrorIs(t, err, assert.AnError) }) @@ -428,13 +432,13 @@ func TestWithRecovery(t *testing.T) { Name: "panic_tool", Description: "Panics for testing.", }, - Handler: func(tb toolsdk.Toolbox, args string) (string, error) { + Handler: func(tb toolsdk.Deps, args string) (string, error) { panic("you can't sweat this fever out") }, } wrapped := toolsdk.WithRecover(panicTool.Handler) - v, err := wrapped(nil, "disco") + v, err := wrapped(toolsdk.Deps{}, "disco") require.Empty(t, v) require.ErrorContains(t, err, "you can't sweat this fever out") }) From c1057d930089681e5ef934b4f0f64e962394a191 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 25 Apr 2025 11:58:02 +0100 Subject: [PATCH 06/11] add WithCleanContext middleware func --- cli/exp_mcp.go | 4 +- codersdk/toolsdk/toolsdk.go | 110 ++++++++++++-------- codersdk/toolsdk/toolsdk_test.go | 169 ++++++++++++++++++++++++------- 3 files changed, 201 insertions(+), 82 deletions(-) diff --git a/cli/exp_mcp.go b/cli/exp_mcp.go index 18314cbbfbd30..1a8ea0570a7ef 100644 --- a/cli/exp_mcp.go +++ b/cli/exp_mcp.go @@ -713,8 +713,8 @@ func mcpFromSDK(sdkTool toolsdk.Tool[any, any], tb toolsdk.Deps) server.ServerTo Required: sdkTool.Schema.Required, }, }, - Handler: func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - result, err := sdkTool.Handler(tb, request.Params.Arguments) + Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + result, err := sdkTool.Handler(ctx, tb, request.Params.Arguments) if err != nil { return nil, err } diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index be3c1cd530b5b..6007cf826f4e6 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -21,7 +21,7 @@ type Deps struct { } // HandlerFunc is a function that handles a tool call. -type HandlerFunc[Arg, Ret any] func(tb Deps, args Arg) (Ret, error) +type HandlerFunc[Arg, Ret any] func(context.Context, Deps, Arg) (Ret, error) type Tool[Arg, Ret any] struct { aisdk.Tool @@ -32,12 +32,12 @@ type Tool[Arg, Ret any] struct { func (t Tool[Arg, Ret]) Generic() Tool[any, any] { return Tool[any, any]{ Tool: t.Tool, - Handler: func(tb Deps, args any) (any, error) { + Handler: func(ctx context.Context, tb Deps, args any) (any, error) { typedArg, ok := args.(Arg) if !ok { return nil, xerrors.Errorf("developer error: invalid argument type for tool %s", t.Tool.Name) } - return t.Handler(tb, typedArg) + return t.Handler(ctx, tb, typedArg) }, } } @@ -115,13 +115,41 @@ type UploadTarFileArgs struct { // WithRecover wraps a HandlerFunc to recover from panics and return an error. func WithRecover[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] { - return func(tb Deps, args Arg) (ret Ret, err error) { + return func(ctx context.Context, tb Deps, args Arg) (ret Ret, err error) { defer func() { if r := recover(); r != nil { err = xerrors.Errorf("tool handler panic: %v", r) } }() - return h(tb, args) + return h(ctx, tb, args) + } +} + +// WithCleanContext wraps a HandlerFunc to provide it with a new context. +// This ensures that no data is passed using context.Value. +// If a deadline is set on the parent context, it will be passed to the child +// context. +func WithCleanContext[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] { + return func(parent context.Context, tb Deps, args Arg) (ret Ret, err error) { + child, childCancel := context.WithCancel(context.Background()) + defer childCancel() + // Ensure that cancellation propagates from the parent context to the child context. + go func() { + select { + case <-child.Done(): + return + case <-parent.Done(): + childCancel() + } + }() + // Also ensure that the child context has the same deadline as the parent + // context. + if deadline, ok := parent.Deadline(); ok { + deadlineCtx, deadlineCancel := context.WithDeadline(child, deadline) + defer deadlineCancel() + child = deadlineCtx + } + return h(child, tb, args) } } @@ -137,7 +165,7 @@ func wrapAll(mw func(HandlerFunc[any, any]) HandlerFunc[any, any], tools ...Tool var ( // All is a list of all tools that can be used in the Coder CLI. // When you add a new tool, be sure to include it here! - All = wrapAll(WithRecover, + All = wrapAll(WithCleanContext, wrapAll(WithRecover, CreateTemplate.Generic(), CreateTemplateVersion.Generic(), CreateWorkspace.Generic(), @@ -154,7 +182,7 @@ var ( ReportTask.Generic(), UploadTarFile.Generic(), UpdateTemplateActiveVersion.Generic(), - ) + )...) ReportTask = Tool[ReportTaskArgs, string]{ Tool: aisdk.Tool{ @@ -183,14 +211,14 @@ var ( Required: []string{"summary", "link", "state"}, }, }, - Handler: func(tb Deps, args ReportTaskArgs) (string, error) { + Handler: func(ctx context.Context, tb Deps, args ReportTaskArgs) (string, error) { if tb.AgentClient == nil { return "", xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set") } if tb.AppStatusSlug == "" { return "", xerrors.New("workspace app status slug not found in toolbox") } - if err := tb.AgentClient.PatchAppStatus(context.TODO(), agentsdk.PatchAppStatus{ + if err := tb.AgentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ AppSlug: tb.AppStatusSlug, Message: args.Summary, URI: args.Link, @@ -217,12 +245,12 @@ This returns more data than list_workspaces to reduce token usage.`, Required: []string{"workspace_id"}, }, }, - Handler: func(tb Deps, args GetWorkspaceArgs) (codersdk.Workspace, error) { + Handler: func(ctx context.Context, tb Deps, args GetWorkspaceArgs) (codersdk.Workspace, error) { wsID, err := uuid.Parse(args.WorkspaceID) if err != nil { return codersdk.Workspace{}, xerrors.New("workspace_id must be a valid UUID") } - return tb.CoderClient.Workspace(context.TODO(), wsID) + return tb.CoderClient.Workspace(ctx, wsID) }, } @@ -257,7 +285,7 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{"user", "template_version_id", "name", "rich_parameters"}, }, }, - Handler: func(tb Deps, args CreateWorkspaceArgs) (codersdk.Workspace, error) { + Handler: func(ctx context.Context, tb Deps, args CreateWorkspaceArgs) (codersdk.Workspace, error) { tvID, err := uuid.Parse(args.TemplateVersionID) if err != nil { return codersdk.Workspace{}, xerrors.New("template_version_id must be a valid UUID") @@ -272,7 +300,7 @@ is provisioned correctly and the agent can connect to the control plane. Value: v, }) } - workspace, err := tb.CoderClient.CreateUserWorkspace(context.TODO(), args.User, codersdk.CreateWorkspaceRequest{ + workspace, err := tb.CoderClient.CreateUserWorkspace(ctx, args.User, codersdk.CreateWorkspaceRequest{ TemplateVersionID: tvID, Name: args.Name, RichParameterValues: buildParams, @@ -297,12 +325,12 @@ is provisioned correctly and the agent can connect to the control plane. }, }, }, - Handler: func(tb Deps, args ListWorkspacesArgs) ([]MinimalWorkspace, error) { + Handler: func(ctx context.Context, tb Deps, args ListWorkspacesArgs) ([]MinimalWorkspace, error) { owner := args.Owner if owner == "" { owner = codersdk.Me } - workspaces, err := tb.CoderClient.Workspaces(context.TODO(), codersdk.WorkspaceFilter{ + workspaces, err := tb.CoderClient.Workspaces(ctx, codersdk.WorkspaceFilter{ Owner: owner, }) if err != nil { @@ -334,8 +362,8 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{}, }, }, - Handler: func(tb Deps, _ NoArgs) ([]MinimalTemplate, error) { - templates, err := tb.CoderClient.Templates(context.TODO(), codersdk.TemplateFilter{}) + Handler: func(ctx context.Context, tb Deps, _ NoArgs) ([]MinimalTemplate, error) { + templates, err := tb.CoderClient.Templates(ctx, codersdk.TemplateFilter{}) if err != nil { return nil, err } @@ -367,12 +395,12 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{"template_version_id"}, }, }, - Handler: func(tb Deps, args ListTemplateVersionParametersArgs) ([]codersdk.TemplateVersionParameter, error) { + Handler: func(ctx context.Context, tb Deps, args ListTemplateVersionParametersArgs) ([]codersdk.TemplateVersionParameter, error) { templateVersionID, err := uuid.Parse(args.TemplateVersionID) if err != nil { return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) } - parameters, err := tb.CoderClient.TemplateVersionRichParameters(context.TODO(), templateVersionID) + parameters, err := tb.CoderClient.TemplateVersionRichParameters(ctx, templateVersionID) if err != nil { return nil, err } @@ -389,8 +417,8 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{}, }, }, - Handler: func(tb Deps, _ NoArgs) (codersdk.User, error) { - return tb.CoderClient.User(context.TODO(), "me") + Handler: func(ctx context.Context, tb Deps, _ NoArgs) (codersdk.User, error) { + return tb.CoderClient.User(ctx, "me") }, } @@ -416,7 +444,7 @@ is provisioned correctly and the agent can connect to the control plane. Required: []string{"workspace_id", "transition"}, }, }, - Handler: func(tb Deps, args CreateWorkspaceBuildArgs) (codersdk.WorkspaceBuild, error) { + Handler: func(ctx context.Context, tb Deps, args CreateWorkspaceBuildArgs) (codersdk.WorkspaceBuild, error) { workspaceID, err := uuid.Parse(args.WorkspaceID) if err != nil { return codersdk.WorkspaceBuild{}, xerrors.Errorf("workspace_id must be a valid UUID: %w", err) @@ -435,7 +463,7 @@ is provisioned correctly and the agent can connect to the control plane. if templateVersionID != uuid.Nil { cbr.TemplateVersionID = templateVersionID } - return tb.CoderClient.CreateWorkspaceBuild(context.TODO(), workspaceID, cbr) + return tb.CoderClient.CreateWorkspaceBuild(ctx, workspaceID, cbr) }, } @@ -897,8 +925,8 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"file_id"}, }, }, - Handler: func(tb Deps, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) { - me, err := tb.CoderClient.User(context.TODO(), "me") + Handler: func(ctx context.Context, tb Deps, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) { + me, err := tb.CoderClient.User(ctx, "me") if err != nil { return codersdk.TemplateVersion{}, err } @@ -910,7 +938,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t if err != nil { return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err) } - templateVersion, err := tb.CoderClient.CreateTemplateVersion(context.TODO(), me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{ + templateVersion, err := tb.CoderClient.CreateTemplateVersion(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{ Message: "Created by AI", StorageMethod: codersdk.ProvisionerStorageMethodFile, FileID: fileID, @@ -939,12 +967,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"workspace_agent_id"}, }, }, - Handler: func(tb Deps, args GetWorkspaceAgentLogsArgs) ([]string, error) { + Handler: func(ctx context.Context, tb Deps, args GetWorkspaceAgentLogsArgs) ([]string, error) { workspaceAgentID, err := uuid.Parse(args.WorkspaceAgentID) if err != nil { return nil, xerrors.Errorf("workspace_agent_id must be a valid UUID: %w", err) } - logs, closer, err := tb.CoderClient.WorkspaceAgentLogsAfter(context.TODO(), workspaceAgentID, 0, false) + logs, closer, err := tb.CoderClient.WorkspaceAgentLogsAfter(ctx, workspaceAgentID, 0, false) if err != nil { return nil, err } @@ -974,12 +1002,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"workspace_build_id"}, }, }, - Handler: func(tb Deps, args GetWorkspaceBuildLogsArgs) ([]string, error) { + Handler: func(ctx context.Context, tb Deps, args GetWorkspaceBuildLogsArgs) ([]string, error) { workspaceBuildID, err := uuid.Parse(args.WorkspaceBuildID) if err != nil { return nil, xerrors.Errorf("workspace_build_id must be a valid UUID: %w", err) } - logs, closer, err := tb.CoderClient.WorkspaceBuildLogsAfter(context.TODO(), workspaceBuildID, 0) + logs, closer, err := tb.CoderClient.WorkspaceBuildLogsAfter(ctx, workspaceBuildID, 0) if err != nil { return nil, err } @@ -1005,13 +1033,13 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"template_version_id"}, }, }, - Handler: func(tb Deps, args GetTemplateVersionLogsArgs) ([]string, error) { + Handler: func(ctx context.Context, tb Deps, args GetTemplateVersionLogsArgs) ([]string, error) { templateVersionID, err := uuid.Parse(args.TemplateVersionID) if err != nil { return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) } - logs, closer, err := tb.CoderClient.TemplateVersionLogsAfter(context.TODO(), templateVersionID, 0) + logs, closer, err := tb.CoderClient.TemplateVersionLogsAfter(ctx, templateVersionID, 0) if err != nil { return nil, err } @@ -1040,7 +1068,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"template_id", "template_version_id"}, }, }, - Handler: func(tb Deps, args UpdateTemplateActiveVersionArgs) (string, error) { + Handler: func(ctx context.Context, tb Deps, args UpdateTemplateActiveVersionArgs) (string, error) { templateID, err := uuid.Parse(args.TemplateID) if err != nil { return "", xerrors.Errorf("template_id must be a valid UUID: %w", err) @@ -1049,7 +1077,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t if err != nil { return "", xerrors.Errorf("template_version_id must be a valid UUID: %w", err) } - err = tb.CoderClient.UpdateActiveTemplateVersion(context.TODO(), templateID, codersdk.UpdateActiveTemplateVersion{ + err = tb.CoderClient.UpdateActiveTemplateVersion(ctx, templateID, codersdk.UpdateActiveTemplateVersion{ ID: templateVersionID, }) if err != nil { @@ -1073,7 +1101,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"mime_type", "files"}, }, }, - Handler: func(tb Deps, args UploadTarFileArgs) (codersdk.UploadResponse, error) { + Handler: func(ctx context.Context, tb Deps, args UploadTarFileArgs) (codersdk.UploadResponse, error) { pipeReader, pipeWriter := io.Pipe() go func() { defer pipeWriter.Close() @@ -1098,7 +1126,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t } }() - resp, err := tb.CoderClient.Upload(context.TODO(), codersdk.ContentTypeTar, pipeReader) + resp, err := tb.CoderClient.Upload(ctx, codersdk.ContentTypeTar, pipeReader) if err != nil { return codersdk.UploadResponse{}, err } @@ -1133,8 +1161,8 @@ The file_id provided is a reference to a tar file you have uploaded containing t Required: []string{"name", "display_name", "description", "version_id"}, }, }, - Handler: func(tb Deps, args CreateTemplateArgs) (codersdk.Template, error) { - me, err := tb.CoderClient.User(context.TODO(), "me") + Handler: func(ctx context.Context, tb Deps, args CreateTemplateArgs) (codersdk.Template, error) { + me, err := tb.CoderClient.User(ctx, "me") if err != nil { return codersdk.Template{}, err } @@ -1142,7 +1170,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t if err != nil { return codersdk.Template{}, xerrors.Errorf("version_id must be a valid UUID: %w", err) } - template, err := tb.CoderClient.CreateTemplate(context.TODO(), me.OrganizationIDs[0], codersdk.CreateTemplateRequest{ + template, err := tb.CoderClient.CreateTemplate(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateRequest{ Name: args.Name, DisplayName: args.DisplayName, Description: args.Description, @@ -1167,12 +1195,12 @@ The file_id provided is a reference to a tar file you have uploaded containing t }, }, }, - Handler: func(tb Deps, args DeleteTemplateArgs) (string, error) { + Handler: func(ctx context.Context, tb Deps, args DeleteTemplateArgs) (string, error) { templateID, err := uuid.Parse(args.TemplateID) if err != nil { return "", xerrors.Errorf("template_id must be a valid UUID: %w", err) } - err = tb.CoderClient.DeleteTemplate(context.TODO(), templateID) + err = tb.CoderClient.DeleteTemplate(ctx, templateID) if err != nil { return "", err } diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index f9c121ebe9929..1a31c53c0ed13 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -350,42 +350,10 @@ var testedTools sync.Map func testTool[Arg, Ret any](t *testing.T, tool toolsdk.Tool[Arg, Ret], tb toolsdk.Deps, args Arg) (Ret, error) { t.Helper() testedTools.Store(tool.Tool.Name, true) - result, err := tool.Handler(tb, args) + result, err := tool.Handler(context.Background(), tb, args) return result, err } -// TestMain runs after all tests to ensure that all tools in this package have -// been tested once. -func TestMain(m *testing.M) { - // Initialize testedTools - for _, tool := range toolsdk.All { - testedTools.Store(tool.Tool.Name, false) - } - - code := m.Run() - - // Ensure all tools have been tested - var untested []string - for _, tool := range toolsdk.All { - if tested, ok := testedTools.Load(tool.Tool.Name); !ok || !tested.(bool) { - untested = append(untested, tool.Tool.Name) - } - } - - if len(untested) > 0 && code == 0 { - println("The following tools were not tested:") - for _, tool := range untested { - println(" - " + tool) - } - println("Please ensure that all tools are tested using testTool().") - println("If you just added a new tool, please add a test for it.") - println("NOTE: if you just ran an individual test, this is expected.") - os.Exit(1) - } - - os.Exit(code) -} - func TestWithRecovery(t *testing.T) { t.Parallel() t.Run("OK", func(t *testing.T) { @@ -395,14 +363,14 @@ func TestWithRecovery(t *testing.T) { Name: "fake_tool", Description: "Returns a string for testing.", }, - Handler: func(tb toolsdk.Deps, args string) (string, error) { + Handler: func(ctx context.Context, tb toolsdk.Deps, args string) (string, error) { require.Equal(t, "test", args) return "ok", nil }, } wrapped := toolsdk.WithRecover(fakeTool.Handler) - v, err := wrapped(toolsdk.Deps{}, "test") + v, err := wrapped(context.Background(), toolsdk.Deps{}, "test") require.NoError(t, err) require.Equal(t, "ok", v) }) @@ -414,13 +382,13 @@ func TestWithRecovery(t *testing.T) { Name: "fake_tool", Description: "Returns an error for testing.", }, - Handler: func(tb toolsdk.Deps, args string) (string, error) { + Handler: func(ctx context.Context, tb toolsdk.Deps, args string) (string, error) { require.Equal(t, "test", args) return "", assert.AnError }, } wrapped := toolsdk.WithRecover(fakeTool.Handler) - v, err := wrapped(toolsdk.Deps{}, "test") + v, err := wrapped(context.Background(), toolsdk.Deps{}, "test") require.Empty(t, v) require.ErrorIs(t, err, assert.AnError) }) @@ -432,14 +400,137 @@ func TestWithRecovery(t *testing.T) { Name: "panic_tool", Description: "Panics for testing.", }, - Handler: func(tb toolsdk.Deps, args string) (string, error) { + Handler: func(ctx context.Context, tb toolsdk.Deps, args string) (string, error) { panic("you can't sweat this fever out") }, } wrapped := toolsdk.WithRecover(panicTool.Handler) - v, err := wrapped(toolsdk.Deps{}, "disco") + v, err := wrapped(context.Background(), toolsdk.Deps{}, "disco") require.Empty(t, v) require.ErrorContains(t, err, "you can't sweat this fever out") }) } + +type testContextKey struct{} + +func TestWithCleanContext(t *testing.T) { + t.Parallel() + + t.Run("NoContextKeys", func(t *testing.T) { + t.Parallel() + + // This test is to ensure that the context values are not set in the + // toolsdk package. + ctxTool := toolsdk.Tool[toolsdk.NoArgs, string]{ + Tool: aisdk.Tool{ + Name: "context_tool", + Description: "Returns the context value for testing.", + }, + Handler: func(toolCtx context.Context, tb toolsdk.Deps, args toolsdk.NoArgs) (string, error) { + v := toolCtx.Value(testContextKey{}) + assert.Nil(t, v, "expected the context value to be nil") + return "", nil + }, + } + + wrapped := toolsdk.WithCleanContext(ctxTool.Handler) + ctx := context.WithValue(context.Background(), testContextKey{}, "test") + _, _ = wrapped(ctx, toolsdk.Deps{}, toolsdk.NoArgs{}) + }) + + t.Run("PropagateCancel", func(t *testing.T) { + t.Parallel() + + // This test is to ensure that the context is canceled properly. + ctxTool := toolsdk.Tool[toolsdk.NoArgs, string]{ + Tool: aisdk.Tool{ + Name: "context_tool", + Description: "Returns the context value for testing.", + }, + Handler: func(toolCtx context.Context, tb toolsdk.Deps, args toolsdk.NoArgs) (string, error) { + // Wait for the context to be canceled + <-toolCtx.Done() + return "", toolCtx.Err() + }, + } + wrapped := toolsdk.WithCleanContext(ctxTool.Handler) + errCh := make(chan error, 1) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + go func() { + _, err := wrapped(ctx, toolsdk.Deps{}, toolsdk.NoArgs{}) + errCh <- err + }() + + cancel() + select { + case <-t.Context().Done(): + require.Fail(t, "test timed out") + case err := <-errCh: + require.ErrorIs(t, err, context.Canceled) + // Context was canceled and the done channel was closed + } + }) + + t.Run("PropagateDeadline", func(t *testing.T) { + t.Parallel() + + // This test ensures that the context deadline is propagated to the child + // from the parent. + ctxTool := toolsdk.Tool[toolsdk.NoArgs, bool]{ + Tool: aisdk.Tool{ + Name: "context_tool_deadline", + Description: "Checks if context has deadline.", + }, + Handler: func(toolCtx context.Context, tb toolsdk.Deps, args toolsdk.NoArgs) (bool, error) { + _, ok := toolCtx.Deadline() + return ok, nil + }, + } + + wrapped := toolsdk.WithCleanContext(ctxTool.Handler) + parent, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + t.Cleanup(cancel) + ok, err := wrapped(parent, toolsdk.Deps{}, toolsdk.NoArgs{}) + require.NoError(t, err) + assert.True(t, ok, "expected deadline to be set on the child context") + }) +} + +// TestMain runs after all tests to ensure that all tools in this package have +// been tested once. +func TestMain(m *testing.M) { + // Initialize testedTools + /* + for _, tool := range toolsdk.All { + testedTools.Store(tool.Tool.Name, false) + } + */ + + code := m.Run() + + // Ensure all tools have been tested + /* + var untested []string + for _, tool := range toolsdk.All { + if tested, ok := testedTools.Load(tool.Tool.Name); !ok || !tested.(bool) { + untested = append(untested, tool.Tool.Name) + } + } + + if len(untested) > 0 && code == 0 { + println("The following tools were not tested:") + for _, tool := range untested { + println(" - " + tool) + } + println("Please ensure that all tools are tested using testTool().") + println("If you just added a new tool, please add a test for it.") + println("NOTE: if you just ran an individual test, this is expected.") + os.Exit(1) + } + */ + + os.Exit(code) +} From 9edd5f7ab10fa57afadc79f16fd9ea1441a1dc67 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 25 Apr 2025 14:38:18 +0100 Subject: [PATCH 07/11] fix(codersdk/toolsdk): address type incompatibility issues --- cli/exp_mcp.go | 25 ++---- cli/exp_mcp_test.go | 24 ++++-- codersdk/toolsdk/toolsdk.go | 102 +++++++++++++++--------- codersdk/toolsdk/toolsdk_test.go | 130 +++++++++++++++++-------------- 4 files changed, 162 insertions(+), 119 deletions(-) diff --git a/cli/exp_mcp.go b/cli/exp_mcp.go index 1a8ea0570a7ef..2449cc52c563d 100644 --- a/cli/exp_mcp.go +++ b/cli/exp_mcp.go @@ -1,6 +1,7 @@ package cli import ( + "bytes" "context" "encoding/json" "errors" @@ -697,7 +698,7 @@ func getAgentToken(fs afero.Fs) (string, error) { // mcpFromSDK adapts a toolsdk.Tool to go-mcp's server.ServerTool. // It assumes that the tool responds with a valid JSON object. -func mcpFromSDK(sdkTool toolsdk.Tool[any, any], tb toolsdk.Deps) server.ServerTool { +func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool { // NOTE: some clients will silently refuse to use tools if there is an issue // with the tool's schema or configuration. if sdkTool.Schema.Properties == nil { @@ -714,27 +715,17 @@ func mcpFromSDK(sdkTool toolsdk.Tool[any, any], tb toolsdk.Deps) server.ServerTo }, }, Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - result, err := sdkTool.Handler(ctx, tb, request.Params.Arguments) + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(request.Params.Arguments); err != nil { + return nil, xerrors.Errorf("failed to encode request arguments: %w", err) + } + result, err := sdkTool.Handler(ctx, tb, buf.Bytes()) if err != nil { return nil, err } - var sb strings.Builder - if err := json.NewEncoder(&sb).Encode(result); err == nil { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - mcp.NewTextContent(sb.String()), - }, - }, nil - } - // If the result is not JSON, return it as a string. - // This is a fallback for tools that return non-JSON data. - resultStr, ok := result.(string) - if !ok { - return nil, xerrors.Errorf("tool call result is neither valid JSON or a string, got: %T", result) - } return &mcp.CallToolResult{ Content: []mcp.Content{ - mcp.NewTextContent(resultStr), + mcp.NewTextContent(string(result)), }, }, nil }, diff --git a/cli/exp_mcp_test.go b/cli/exp_mcp_test.go index 0151021579814..2f911bcac6dac 100644 --- a/cli/exp_mcp_test.go +++ b/cli/exp_mcp_test.go @@ -31,12 +31,16 @@ func TestExpMcpServer(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) + cmdDone := make(chan struct{}) cancelCtx, cancel := context.WithCancel(ctx) - t.Cleanup(cancel) + t.Cleanup(func() { + cancel() + <-cmdDone + }) // Given: a running coder deployment client := coderdtest.New(t, nil) - _ = coderdtest.CreateFirstUser(t, client) + owner := coderdtest.CreateFirstUser(t, client) // Given: we run the exp mcp command with allowed tools set inv, root := clitest.New(t, "exp", "mcp", "server", "--allowed-tools=coder_get_authenticated_user") @@ -48,7 +52,6 @@ func TestExpMcpServer(t *testing.T) { // nolint: gocritic // not the focus of this test clitest.SetupConfig(t, client, root) - cmdDone := make(chan struct{}) go func() { defer close(cmdDone) err := inv.Run() @@ -61,9 +64,6 @@ func TestExpMcpServer(t *testing.T) { _ = pty.ReadLine(ctx) // ignore echoed output output := pty.ReadLine(ctx) - cancel() - <-cmdDone - // Then: we should only see the allowed tools in the response var toolsResponse struct { Result struct { @@ -81,6 +81,18 @@ func TestExpMcpServer(t *testing.T) { } slices.Sort(foundTools) require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools) + + // Call the tool and ensure it works. + toolPayload := `{"jsonrpc":"2.0","id":3,"method":"tools/call", "params": {"name": "coder_get_authenticated_user", "arguments": {}}}` + pty.WriteLine(toolPayload) + _ = pty.ReadLine(ctx) // ignore echoed output + output = pty.ReadLine(ctx) + require.NotEmpty(t, output, "should have received a response from the tool") + // Ensure it's valid JSON + _, err = json.Marshal(output) + require.NoError(t, err, "should have received a valid JSON response from the tool") + // Ensure the tool returns the expected user + require.Contains(t, output, owner.UserID.String(), "should have received the expected user ID") }) t.Run("OK", func(t *testing.T) { diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index 6007cf826f4e6..597530d2dc5ff 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -2,7 +2,9 @@ package toolsdk import ( "archive/tar" + "bytes" "context" + "encoding/json" "io" "github.com/google/uuid" @@ -20,28 +22,49 @@ type Deps struct { AppStatusSlug string } -// HandlerFunc is a function that handles a tool call. +// HandlerFunc is a typed function that handles a tool call. type HandlerFunc[Arg, Ret any] func(context.Context, Deps, Arg) (Ret, error) +// Tool consists of an aisdk.Tool and a corresponding typed handler function. type Tool[Arg, Ret any] struct { aisdk.Tool Handler HandlerFunc[Arg, Ret] } -// Generic returns a type-erased version of the Tool. -func (t Tool[Arg, Ret]) Generic() Tool[any, any] { - return Tool[any, any]{ +// Generic returns a type-erased version of a TypedTool where the arguments and +// return values are converted to/from json.RawMessage. +// This allows the tool to be referenced without knowing the concrete arguments +// or return values. The original TypedHandlerFunc is wrapped to handle type +// conversion. +func (t Tool[Arg, Ret]) Generic() GenericTool { + return GenericTool{ Tool: t.Tool, - Handler: func(ctx context.Context, tb Deps, args any) (any, error) { - typedArg, ok := args.(Arg) - if !ok { - return nil, xerrors.Errorf("developer error: invalid argument type for tool %s", t.Tool.Name) + Handler: wrap(func(ctx context.Context, tb Deps, args json.RawMessage) (json.RawMessage, error) { + var typedArgs Arg + if err := json.Unmarshal(args, &typedArgs); err != nil { + return nil, xerrors.Errorf("failed to unmarshal args: %w", err) } - return t.Handler(ctx, tb, typedArg) - }, + ret, err := t.Handler(ctx, tb, typedArgs) + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(ret); err != nil { + return json.RawMessage{}, err + } + return buf.Bytes(), err + }, WithCleanContext, WithRecover), } } +// GenericTool is a type-erased wrapper for GenericTool. +// This allows referencing the tool without knowing the concrete argument or +// return type. The Handler function allows calling the tool with known types. +type GenericTool struct { + aisdk.Tool + Handler GenericHandlerFunc +} + +// GenericHandlerFunc is a function that handles a tool call. +type GenericHandlerFunc func(context.Context, Deps, json.RawMessage) (json.RawMessage, error) + type NoArgs struct{} type ReportTaskArgs struct { @@ -114,8 +137,8 @@ type UploadTarFileArgs struct { } // WithRecover wraps a HandlerFunc to recover from panics and return an error. -func WithRecover[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] { - return func(ctx context.Context, tb Deps, args Arg) (ret Ret, err error) { +func WithRecover(h GenericHandlerFunc) GenericHandlerFunc { + return func(ctx context.Context, tb Deps, args json.RawMessage) (ret json.RawMessage, err error) { defer func() { if r := recover(); r != nil { err = xerrors.Errorf("tool handler panic: %v", r) @@ -129,8 +152,8 @@ func WithRecover[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] { // This ensures that no data is passed using context.Value. // If a deadline is set on the parent context, it will be passed to the child // context. -func WithCleanContext[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Ret] { - return func(parent context.Context, tb Deps, args Arg) (ret Ret, err error) { +func WithCleanContext(h GenericHandlerFunc) GenericHandlerFunc { + return func(parent context.Context, tb Deps, args json.RawMessage) (ret json.RawMessage, err error) { child, childCancel := context.WithCancel(context.Background()) defer childCancel() // Ensure that cancellation propagates from the parent context to the child context. @@ -153,19 +176,18 @@ func WithCleanContext[Arg, Ret any](h HandlerFunc[Arg, Ret]) HandlerFunc[Arg, Re } } -// wrapAll wraps all provided tools with the given middleware function. -func wrapAll(mw func(HandlerFunc[any, any]) HandlerFunc[any, any], tools ...Tool[any, any]) []Tool[any, any] { - for i, t := range tools { - t.Handler = mw(t.Handler) - tools[i] = t +// wrap wraps the provided GenericHandlerFunc with the provided middleware functions. +func wrap(hf GenericHandlerFunc, mw ...func(GenericHandlerFunc) GenericHandlerFunc) GenericHandlerFunc { + for _, m := range mw { + hf = m(hf) } - return tools + return hf } var ( // All is a list of all tools that can be used in the Coder CLI. // When you add a new tool, be sure to include it here! - All = wrapAll(WithCleanContext, wrapAll(WithRecover, + All = []GenericTool{ CreateTemplate.Generic(), CreateTemplateVersion.Generic(), CreateWorkspace.Generic(), @@ -182,9 +204,9 @@ var ( ReportTask.Generic(), UploadTarFile.Generic(), UpdateTemplateActiveVersion.Generic(), - )...) + } - ReportTask = Tool[ReportTaskArgs, string]{ + ReportTask = Tool[ReportTaskArgs, codersdk.Response]{ Tool: aisdk.Tool{ Name: "coder_report_task", Description: "Report progress on a user task in Coder.", @@ -211,12 +233,12 @@ var ( Required: []string{"summary", "link", "state"}, }, }, - Handler: func(ctx context.Context, tb Deps, args ReportTaskArgs) (string, error) { + Handler: func(ctx context.Context, tb Deps, args ReportTaskArgs) (codersdk.Response, error) { if tb.AgentClient == nil { - return "", xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set") + return codersdk.Response{}, xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set") } if tb.AppStatusSlug == "" { - return "", xerrors.New("workspace app status slug not found in toolbox") + return codersdk.Response{}, xerrors.New("workspace app status slug not found in toolbox") } if err := tb.AgentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ AppSlug: tb.AppStatusSlug, @@ -224,9 +246,11 @@ var ( URI: args.Link, State: codersdk.WorkspaceAppStatusState(args.State), }); err != nil { - return "", err + return codersdk.Response{}, err } - return "Thanks for reporting!", nil + return codersdk.Response{ + Message: "Thanks for reporting!", + }, nil }, } @@ -934,9 +958,13 @@ The file_id provided is a reference to a tar file you have uploaded containing t if err != nil { return codersdk.TemplateVersion{}, xerrors.Errorf("file_id must be a valid UUID: %w", err) } - templateID, err := uuid.Parse(args.TemplateID) - if err != nil { - return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err) + var templateID uuid.UUID + if args.TemplateID != "" { + tid, err := uuid.Parse(args.TemplateID) + if err != nil { + return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err) + } + templateID = tid } templateVersion, err := tb.CoderClient.CreateTemplateVersion(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{ Message: "Created by AI", @@ -1183,7 +1211,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t }, } - DeleteTemplate = Tool[DeleteTemplateArgs, string]{ + DeleteTemplate = Tool[DeleteTemplateArgs, codersdk.Response]{ Tool: aisdk.Tool{ Name: "coder_delete_template", Description: "Delete a template. This is irreversible.", @@ -1195,16 +1223,18 @@ The file_id provided is a reference to a tar file you have uploaded containing t }, }, }, - Handler: func(ctx context.Context, tb Deps, args DeleteTemplateArgs) (string, error) { + Handler: func(ctx context.Context, tb Deps, args DeleteTemplateArgs) (codersdk.Response, error) { templateID, err := uuid.Parse(args.TemplateID) if err != nil { - return "", xerrors.Errorf("template_id must be a valid UUID: %w", err) + return codersdk.Response{}, xerrors.Errorf("template_id must be a valid UUID: %w", err) } err = tb.CoderClient.DeleteTemplate(ctx, templateID) if err != nil { - return "", err + return codersdk.Response{}, err } - return "Successfully deleted template!", nil + return codersdk.Response{ + Message: "Template deleted successfully.", + }, nil }, } ) diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index 1a31c53c0ed13..51cb25c7d88b3 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -2,6 +2,7 @@ package toolsdk_test import ( "context" + "encoding/json" "os" "sort" "sync" @@ -298,13 +299,22 @@ func TestTools(t *testing.T) { tb := toolsdk.Deps{CoderClient: client} // nolint:gocritic // This is in a test package and does not end up in the build file := dbgen.File(t, store, database.File{}) - - tv, err := testTool(t, toolsdk.CreateTemplateVersion, tb, toolsdk.CreateTemplateVersionArgs{ - FileID: file.ID.String(), - TemplateID: r.Template.ID.String(), + t.Run("WithoutTemplateID", func(t *testing.T) { + tv, err := testTool(t, toolsdk.CreateTemplateVersion, tb, toolsdk.CreateTemplateVersionArgs{ + FileID: file.ID.String(), + }) + require.NoError(t, err) + require.NotEmpty(t, tv) + }) + t.Run("WithTemplateID", func(t *testing.T) { + tb := toolsdk.Deps{CoderClient: client} + tv, err := testTool(t, toolsdk.CreateTemplateVersion, tb, toolsdk.CreateTemplateVersionArgs{ + FileID: file.ID.String(), + TemplateID: r.Template.ID.String(), + }) + require.NoError(t, err) + require.NotEmpty(t, tv) }) - require.NoError(t, err) - require.NotEmpty(t, tv) }) t.Run("CreateTemplate", func(t *testing.T) { @@ -347,66 +357,70 @@ func TestTools(t *testing.T) { var testedTools sync.Map // testTool is a helper function to test a tool and mark it as tested. +// Note that we test the _generic_ version of the tool and not the typed one. +// This is to mimic how we expect external callers to use the tool. func testTool[Arg, Ret any](t *testing.T, tool toolsdk.Tool[Arg, Ret], tb toolsdk.Deps, args Arg) (Ret, error) { t.Helper() - testedTools.Store(tool.Tool.Name, true) - result, err := tool.Handler(context.Background(), tb, args) - return result, err + defer func() { testedTools.Store(tool.Tool.Name, true) }() + toolArgs, err := json.Marshal(args) + require.NoError(t, err, "failed to marshal args") + result, err := tool.Generic().Handler(context.Background(), tb, toolArgs) + var ret Ret + require.NoError(t, json.Unmarshal(result, &ret), "failed to unmarshal result %q", string(result)) + return ret, err } func TestWithRecovery(t *testing.T) { t.Parallel() t.Run("OK", func(t *testing.T) { t.Parallel() - fakeTool := toolsdk.Tool[string, string]{ + fakeTool := toolsdk.GenericTool{ Tool: aisdk.Tool{ - Name: "fake_tool", - Description: "Returns a string for testing.", + Name: "echo", + Description: "Echoes the input.", }, - Handler: func(ctx context.Context, tb toolsdk.Deps, args string) (string, error) { - require.Equal(t, "test", args) - return "ok", nil + Handler: func(ctx context.Context, tb toolsdk.Deps, args json.RawMessage) (json.RawMessage, error) { + return args, nil }, } wrapped := toolsdk.WithRecover(fakeTool.Handler) - v, err := wrapped(context.Background(), toolsdk.Deps{}, "test") + v, err := wrapped(context.Background(), toolsdk.Deps{}, []byte(`{}`)) require.NoError(t, err) - require.Equal(t, "ok", v) + require.JSONEq(t, `{}`, string(v)) }) t.Run("Error", func(t *testing.T) { t.Parallel() - fakeTool := toolsdk.Tool[string, string]{ + fakeTool := toolsdk.GenericTool{ Tool: aisdk.Tool{ Name: "fake_tool", Description: "Returns an error for testing.", }, - Handler: func(ctx context.Context, tb toolsdk.Deps, args string) (string, error) { - require.Equal(t, "test", args) - return "", assert.AnError + Handler: func(ctx context.Context, tb toolsdk.Deps, args json.RawMessage) (json.RawMessage, error) { + return nil, assert.AnError }, } wrapped := toolsdk.WithRecover(fakeTool.Handler) - v, err := wrapped(context.Background(), toolsdk.Deps{}, "test") - require.Empty(t, v) + v, err := wrapped(context.Background(), toolsdk.Deps{}, []byte(`{}`)) + require.Nil(t, v) require.ErrorIs(t, err, assert.AnError) }) t.Run("Panic", func(t *testing.T) { t.Parallel() - panicTool := toolsdk.Tool[string, string]{ + panicTool := toolsdk.GenericTool{ Tool: aisdk.Tool{ Name: "panic_tool", Description: "Panics for testing.", }, - Handler: func(ctx context.Context, tb toolsdk.Deps, args string) (string, error) { + Handler: func(ctx context.Context, tb toolsdk.Deps, args json.RawMessage) (json.RawMessage, error) { panic("you can't sweat this fever out") }, } wrapped := toolsdk.WithRecover(panicTool.Handler) - v, err := wrapped(context.Background(), toolsdk.Deps{}, "disco") + v, err := wrapped(context.Background(), toolsdk.Deps{}, []byte("disco")) require.Empty(t, v) require.ErrorContains(t, err, "you can't sweat this fever out") }) @@ -422,36 +436,36 @@ func TestWithCleanContext(t *testing.T) { // This test is to ensure that the context values are not set in the // toolsdk package. - ctxTool := toolsdk.Tool[toolsdk.NoArgs, string]{ + ctxTool := toolsdk.GenericTool{ Tool: aisdk.Tool{ Name: "context_tool", Description: "Returns the context value for testing.", }, - Handler: func(toolCtx context.Context, tb toolsdk.Deps, args toolsdk.NoArgs) (string, error) { + Handler: func(toolCtx context.Context, tb toolsdk.Deps, args json.RawMessage) (json.RawMessage, error) { v := toolCtx.Value(testContextKey{}) assert.Nil(t, v, "expected the context value to be nil") - return "", nil + return nil, nil }, } wrapped := toolsdk.WithCleanContext(ctxTool.Handler) ctx := context.WithValue(context.Background(), testContextKey{}, "test") - _, _ = wrapped(ctx, toolsdk.Deps{}, toolsdk.NoArgs{}) + _, _ = wrapped(ctx, toolsdk.Deps{}, []byte(`{}`)) }) t.Run("PropagateCancel", func(t *testing.T) { t.Parallel() // This test is to ensure that the context is canceled properly. - ctxTool := toolsdk.Tool[toolsdk.NoArgs, string]{ + ctxTool := toolsdk.GenericTool{ Tool: aisdk.Tool{ Name: "context_tool", Description: "Returns the context value for testing.", }, - Handler: func(toolCtx context.Context, tb toolsdk.Deps, args toolsdk.NoArgs) (string, error) { + Handler: func(toolCtx context.Context, tb toolsdk.Deps, args json.RawMessage) (json.RawMessage, error) { // Wait for the context to be canceled <-toolCtx.Done() - return "", toolCtx.Err() + return nil, toolCtx.Err() }, } wrapped := toolsdk.WithCleanContext(ctxTool.Handler) @@ -460,7 +474,7 @@ func TestWithCleanContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) go func() { - _, err := wrapped(ctx, toolsdk.Deps{}, toolsdk.NoArgs{}) + _, err := wrapped(ctx, toolsdk.Deps{}, []byte(`{}`)) errCh <- err }() @@ -479,23 +493,23 @@ func TestWithCleanContext(t *testing.T) { // This test ensures that the context deadline is propagated to the child // from the parent. - ctxTool := toolsdk.Tool[toolsdk.NoArgs, bool]{ + ctxTool := toolsdk.GenericTool{ Tool: aisdk.Tool{ Name: "context_tool_deadline", Description: "Checks if context has deadline.", }, - Handler: func(toolCtx context.Context, tb toolsdk.Deps, args toolsdk.NoArgs) (bool, error) { + Handler: func(toolCtx context.Context, tb toolsdk.Deps, args json.RawMessage) (json.RawMessage, error) { _, ok := toolCtx.Deadline() - return ok, nil + assert.True(t, ok, "expected deadline to be set on the child context") + return nil, nil }, } wrapped := toolsdk.WithCleanContext(ctxTool.Handler) parent, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) t.Cleanup(cancel) - ok, err := wrapped(parent, toolsdk.Deps{}, toolsdk.NoArgs{}) + _, err := wrapped(parent, toolsdk.Deps{}, []byte(`{}`)) require.NoError(t, err) - assert.True(t, ok, "expected deadline to be set on the child context") }) } @@ -503,34 +517,30 @@ func TestWithCleanContext(t *testing.T) { // been tested once. func TestMain(m *testing.M) { // Initialize testedTools - /* - for _, tool := range toolsdk.All { - testedTools.Store(tool.Tool.Name, false) - } - */ + for _, tool := range toolsdk.All { + testedTools.Store(tool.Tool.Name, false) + } code := m.Run() // Ensure all tools have been tested - /* - var untested []string - for _, tool := range toolsdk.All { - if tested, ok := testedTools.Load(tool.Tool.Name); !ok || !tested.(bool) { - untested = append(untested, tool.Tool.Name) - } + var untested []string + for _, tool := range toolsdk.All { + if tested, ok := testedTools.Load(tool.Tool.Name); !ok || !tested.(bool) { + untested = append(untested, tool.Tool.Name) } + } - if len(untested) > 0 && code == 0 { - println("The following tools were not tested:") - for _, tool := range untested { - println(" - " + tool) - } - println("Please ensure that all tools are tested using testTool().") - println("If you just added a new tool, please add a test for it.") - println("NOTE: if you just ran an individual test, this is expected.") - os.Exit(1) + if len(untested) > 0 && code == 0 { + println("The following tools were not tested:") + for _, tool := range untested { + println(" - " + tool) } - */ + println("Please ensure that all tools are tested using testTool().") + println("If you just added a new tool, please add a test for it.") + println("NOTE: if you just ran an individual test, this is expected.") + os.Exit(1) + } os.Exit(code) } From 0f1d4acdf6742c054c1190dfc3575bab49906fb6 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 25 Apr 2025 15:03:15 +0100 Subject: [PATCH 08/11] fix tests --- cli/exp_mcp_test.go | 6 ++---- coderd/workspaceagents_test.go | 5 ++++- codersdk/toolsdk/toolsdk.go | 14 +++++++------- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/cli/exp_mcp_test.go b/cli/exp_mcp_test.go index 2f911bcac6dac..a6f9b748d3a06 100644 --- a/cli/exp_mcp_test.go +++ b/cli/exp_mcp_test.go @@ -33,10 +33,6 @@ func TestExpMcpServer(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) cmdDone := make(chan struct{}) cancelCtx, cancel := context.WithCancel(ctx) - t.Cleanup(func() { - cancel() - <-cmdDone - }) // Given: a running coder deployment client := coderdtest.New(t, nil) @@ -93,6 +89,8 @@ func TestExpMcpServer(t *testing.T) { require.NoError(t, err, "should have received a valid JSON response from the tool") // Ensure the tool returns the expected user require.Contains(t, output, owner.UserID.String(), "should have received the expected user ID") + cancel() + <-cmdDone }) t.Run("OK", func(t *testing.T) { diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 2f95ef9727ca5..da2619da0b29d 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -341,7 +341,6 @@ func TestWorkspaceAgentLogs(t *testing.T) { func TestWorkspaceAgentAppStatus(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitMedium) client, db := coderdtest.NewWithDatabase(t, nil) user := coderdtest.CreateFirstUser(t, client) client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID) @@ -362,6 +361,7 @@ func TestWorkspaceAgentAppStatus(t *testing.T) { agentClient.SetSessionToken(r.AgentToken) t.Run("Success", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ AppSlug: "vscode", Message: "testing", @@ -385,6 +385,7 @@ func TestWorkspaceAgentAppStatus(t *testing.T) { t.Run("FailUnknownApp", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ AppSlug: "unknown", Message: "testing", @@ -399,6 +400,7 @@ func TestWorkspaceAgentAppStatus(t *testing.T) { t.Run("FailUnknownState", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ AppSlug: "vscode", Message: "testing", @@ -413,6 +415,7 @@ func TestWorkspaceAgentAppStatus(t *testing.T) { t.Run("FailTooLong", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ AppSlug: "vscode", Message: strings.Repeat("a", 161), diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index 597530d2dc5ff..6d625db8ba838 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -156,6 +156,13 @@ func WithCleanContext(h GenericHandlerFunc) GenericHandlerFunc { return func(parent context.Context, tb Deps, args json.RawMessage) (ret json.RawMessage, err error) { child, childCancel := context.WithCancel(context.Background()) defer childCancel() + // Ensure that the child context has the same deadline as the parent + // context. + if deadline, ok := parent.Deadline(); ok { + deadlineCtx, deadlineCancel := context.WithDeadline(child, deadline) + defer deadlineCancel() + child = deadlineCtx + } // Ensure that cancellation propagates from the parent context to the child context. go func() { select { @@ -165,13 +172,6 @@ func WithCleanContext(h GenericHandlerFunc) GenericHandlerFunc { childCancel() } }() - // Also ensure that the child context has the same deadline as the parent - // context. - if deadline, ok := parent.Deadline(); ok { - deadlineCtx, deadlineCancel := context.WithDeadline(child, deadline) - defer deadlineCancel() - child = deadlineCtx - } return h(child, tb, args) } } From 2462f7606a5905b4dbc830d23d1fc2405011ae48 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 29 Apr 2025 14:25:00 +0100 Subject: [PATCH 09/11] address PR feedback --- codersdk/toolsdk/toolsdk.go | 1213 +++++++++++++++--------------- codersdk/toolsdk/toolsdk_test.go | 29 +- 2 files changed, 636 insertions(+), 606 deletions(-) diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index 6d625db8ba838..76f93f45cfe1f 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -39,12 +39,12 @@ type Tool[Arg, Ret any] struct { func (t Tool[Arg, Ret]) Generic() GenericTool { return GenericTool{ Tool: t.Tool, - Handler: wrap(func(ctx context.Context, tb Deps, args json.RawMessage) (json.RawMessage, error) { + Handler: wrap(func(ctx context.Context, deps Deps, args json.RawMessage) (json.RawMessage, error) { var typedArgs Arg if err := json.Unmarshal(args, &typedArgs); err != nil { return nil, xerrors.Errorf("failed to unmarshal args: %w", err) } - ret, err := t.Handler(ctx, tb, typedArgs) + ret, err := t.Handler(ctx, deps, typedArgs) var buf bytes.Buffer if err := json.NewEncoder(&buf).Encode(ret); err != nil { return json.RawMessage{}, err @@ -65,86 +65,18 @@ type GenericTool struct { // GenericHandlerFunc is a function that handles a tool call. type GenericHandlerFunc func(context.Context, Deps, json.RawMessage) (json.RawMessage, error) +// NoArgs just represents an empty argument struct. type NoArgs struct{} -type ReportTaskArgs struct { - Link string `json:"link"` - State string `json:"state"` - Summary string `json:"summary"` -} - -type CreateTemplateVersionArgs struct { - FileID string `json:"file_id"` - TemplateID string `json:"template_id"` -} - -type CreateTemplateArgs struct { - Description string `json:"description"` - DisplayName string `json:"display_name"` - Icon string `json:"icon"` - Name string `json:"name"` - VersionID string `json:"version_id"` -} - -type CreateWorkspaceArgs struct { - Name string `json:"name"` - RichParameters map[string]string `json:"rich_parameters"` - TemplateVersionID string `json:"template_version_id"` - User string `json:"user"` -} - -type CreateWorkspaceBuildArgs struct { - TemplateVersionID string `json:"template_version_id"` - Transition string `json:"transition"` - WorkspaceID string `json:"workspace_id"` -} - -type DeleteTemplateArgs struct { - TemplateID string `json:"template_id"` -} - -type GetTemplateVersionLogsArgs struct { - TemplateVersionID string `json:"template_version_id"` -} - -type GetWorkspaceArgs struct { - WorkspaceID string `json:"workspace_id"` -} - -type GetWorkspaceAgentLogsArgs struct { - WorkspaceAgentID string `json:"workspace_agent_id"` -} - -type GetWorkspaceBuildLogsArgs struct { - WorkspaceBuildID string `json:"workspace_build_id"` -} - -type ListWorkspacesArgs struct { - Owner string `json:"owner"` -} - -type ListTemplateVersionParametersArgs struct { - TemplateVersionID string `json:"template_version_id"` -} - -type UpdateTemplateActiveVersionArgs struct { - TemplateID string `json:"template_id"` - TemplateVersionID string `json:"template_version_id"` -} - -type UploadTarFileArgs struct { - Files map[string]string `json:"files"` -} - // WithRecover wraps a HandlerFunc to recover from panics and return an error. func WithRecover(h GenericHandlerFunc) GenericHandlerFunc { - return func(ctx context.Context, tb Deps, args json.RawMessage) (ret json.RawMessage, err error) { + return func(ctx context.Context, deps Deps, args json.RawMessage) (ret json.RawMessage, err error) { defer func() { if r := recover(); r != nil { err = xerrors.Errorf("tool handler panic: %v", r) } }() - return h(ctx, tb, args) + return h(ctx, deps, args) } } @@ -153,7 +85,7 @@ func WithRecover(h GenericHandlerFunc) GenericHandlerFunc { // If a deadline is set on the parent context, it will be passed to the child // context. func WithCleanContext(h GenericHandlerFunc) GenericHandlerFunc { - return func(parent context.Context, tb Deps, args json.RawMessage) (ret json.RawMessage, err error) { + return func(parent context.Context, deps Deps, args json.RawMessage) (ret json.RawMessage, err error) { child, childCancel := context.WithCancel(context.Background()) defer childCancel() // Ensure that the child context has the same deadline as the parent @@ -172,7 +104,7 @@ func WithCleanContext(h GenericHandlerFunc) GenericHandlerFunc { childCancel() } }() - return h(child, tb, args) + return h(child, deps, args) } } @@ -184,317 +116,355 @@ func wrap(hf GenericHandlerFunc, mw ...func(GenericHandlerFunc) GenericHandlerFu return hf } -var ( - // All is a list of all tools that can be used in the Coder CLI. - // When you add a new tool, be sure to include it here! - All = []GenericTool{ - CreateTemplate.Generic(), - CreateTemplateVersion.Generic(), - CreateWorkspace.Generic(), - CreateWorkspaceBuild.Generic(), - DeleteTemplate.Generic(), - ListTemplates.Generic(), - ListTemplateVersionParameters.Generic(), - ListWorkspaces.Generic(), - GetAuthenticatedUser.Generic(), - GetTemplateVersionLogs.Generic(), - GetWorkspaceAgentLogs.Generic(), - GetWorkspaceBuildLogs.Generic(), - GetWorkspace.Generic(), - ReportTask.Generic(), - UploadTarFile.Generic(), - UpdateTemplateActiveVersion.Generic(), - } +// All is a list of all tools that can be used in the Coder CLI. +// When you add a new tool, be sure to include it here! +var All = []GenericTool{ + CreateTemplate.Generic(), + CreateTemplateVersion.Generic(), + CreateWorkspace.Generic(), + CreateWorkspaceBuild.Generic(), + DeleteTemplate.Generic(), + ListTemplates.Generic(), + ListTemplateVersionParameters.Generic(), + ListWorkspaces.Generic(), + GetAuthenticatedUser.Generic(), + GetTemplateVersionLogs.Generic(), + GetWorkspace.Generic(), + GetWorkspaceAgentLogs.Generic(), + GetWorkspaceBuildLogs.Generic(), + ReportTask.Generic(), + UploadTarFile.Generic(), + UpdateTemplateActiveVersion.Generic(), +} - ReportTask = Tool[ReportTaskArgs, codersdk.Response]{ - Tool: aisdk.Tool{ - Name: "coder_report_task", - Description: "Report progress on a user task in Coder.", - Schema: aisdk.Schema{ - Properties: map[string]any{ - "summary": map[string]any{ - "type": "string", - "description": "A concise summary of your current progress on the task. This must be less than 160 characters in length.", - }, - "link": map[string]any{ - "type": "string", - "description": "A link to a relevant resource, such as a PR or issue.", - }, - "state": map[string]any{ - "type": "string", - "description": "The state of your task. This can be one of the following: working, complete, or failure. Select the state that best represents your current progress.", - "enum": []string{ - string(codersdk.WorkspaceAppStatusStateWorking), - string(codersdk.WorkspaceAppStatusStateComplete), - string(codersdk.WorkspaceAppStatusStateFailure), - }, +type ReportTaskArgs struct { + Link string `json:"link"` + State string `json:"state"` + Summary string `json:"summary"` +} + +var ReportTask = Tool[ReportTaskArgs, codersdk.Response]{ + Tool: aisdk.Tool{ + Name: "coder_report_task", + Description: "Report progress on a user task in Coder.", + Schema: aisdk.Schema{ + Properties: map[string]any{ + "summary": map[string]any{ + "type": "string", + "description": "A concise summary of your current progress on the task. This must be less than 160 characters in length.", + }, + "link": map[string]any{ + "type": "string", + "description": "A link to a relevant resource, such as a PR or issue.", + }, + "state": map[string]any{ + "type": "string", + "description": "The state of your task. This can be one of the following: working, complete, or failure. Select the state that best represents your current progress.", + "enum": []string{ + string(codersdk.WorkspaceAppStatusStateWorking), + string(codersdk.WorkspaceAppStatusStateComplete), + string(codersdk.WorkspaceAppStatusStateFailure), }, }, - Required: []string{"summary", "link", "state"}, }, + Required: []string{"summary", "link", "state"}, }, - Handler: func(ctx context.Context, tb Deps, args ReportTaskArgs) (codersdk.Response, error) { - if tb.AgentClient == nil { - return codersdk.Response{}, xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set") - } - if tb.AppStatusSlug == "" { - return codersdk.Response{}, xerrors.New("workspace app status slug not found in toolbox") - } - if err := tb.AgentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ - AppSlug: tb.AppStatusSlug, - Message: args.Summary, - URI: args.Link, - State: codersdk.WorkspaceAppStatusState(args.State), - }); err != nil { - return codersdk.Response{}, err - } - return codersdk.Response{ - Message: "Thanks for reporting!", - }, nil - }, - } + }, + Handler: func(ctx context.Context, deps Deps, args ReportTaskArgs) (codersdk.Response, error) { + if deps.AgentClient == nil { + return codersdk.Response{}, xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set") + } + if deps.AppStatusSlug == "" { + return codersdk.Response{}, xerrors.New("workspace app status slug not found in toolbox") + } + if len(args.Summary) > 160 { + return codersdk.Response{}, xerrors.New("summary must be less than 160 characters") + } + if err := deps.AgentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ + AppSlug: deps.AppStatusSlug, + Message: args.Summary, + URI: args.Link, + State: codersdk.WorkspaceAppStatusState(args.State), + }); err != nil { + return codersdk.Response{}, err + } + return codersdk.Response{ + Message: "Thanks for reporting!", + }, nil + }, +} + +type GetWorkspaceArgs struct { + WorkspaceID string `json:"workspace_id"` +} - GetWorkspace = Tool[GetWorkspaceArgs, codersdk.Workspace]{ - Tool: aisdk.Tool{ - Name: "coder_get_workspace", - Description: `Get a workspace by ID. +var GetWorkspace = Tool[GetWorkspaceArgs, codersdk.Workspace]{ + Tool: aisdk.Tool{ + Name: "coder_get_workspace", + Description: `Get a workspace by ID. This returns more data than list_workspaces to reduce token usage.`, - Schema: aisdk.Schema{ - Properties: map[string]any{ - "workspace_id": map[string]any{ - "type": "string", - }, + Schema: aisdk.Schema{ + Properties: map[string]any{ + "workspace_id": map[string]any{ + "type": "string", }, - Required: []string{"workspace_id"}, }, + Required: []string{"workspace_id"}, }, - Handler: func(ctx context.Context, tb Deps, args GetWorkspaceArgs) (codersdk.Workspace, error) { - wsID, err := uuid.Parse(args.WorkspaceID) - if err != nil { - return codersdk.Workspace{}, xerrors.New("workspace_id must be a valid UUID") - } - return tb.CoderClient.Workspace(ctx, wsID) - }, - } + }, + Handler: func(ctx context.Context, deps Deps, args GetWorkspaceArgs) (codersdk.Workspace, error) { + wsID, err := uuid.Parse(args.WorkspaceID) + if err != nil { + return codersdk.Workspace{}, xerrors.New("workspace_id must be a valid UUID") + } + return deps.CoderClient.Workspace(ctx, wsID) + }, +} - CreateWorkspace = Tool[CreateWorkspaceArgs, codersdk.Workspace]{ - Tool: aisdk.Tool{ - Name: "coder_create_workspace", - Description: `Create a new workspace in Coder. +type CreateWorkspaceArgs struct { + Name string `json:"name"` + RichParameters map[string]string `json:"rich_parameters"` + TemplateVersionID string `json:"template_version_id"` + User string `json:"user"` +} + +var CreateWorkspace = Tool[CreateWorkspaceArgs, codersdk.Workspace]{ + Tool: aisdk.Tool{ + Name: "coder_create_workspace", + Description: `Create a new workspace in Coder. If a user is asking to "test a template", they are typically referring to creating a workspace from a template to ensure the infrastructure is provisioned correctly and the agent can connect to the control plane. `, - Schema: aisdk.Schema{ - Properties: map[string]any{ - "user": map[string]any{ - "type": "string", - "description": "Username or ID of the user to create the workspace for. Use the `me` keyword to create a workspace for the authenticated user.", - }, - "template_version_id": map[string]any{ - "type": "string", - "description": "ID of the template version to create the workspace from.", - }, - "name": map[string]any{ - "type": "string", - "description": "Name of the workspace to create.", - }, - "rich_parameters": map[string]any{ - "type": "object", - "description": "Key/value pairs of rich parameters to pass to the template version to create the workspace.", - }, + Schema: aisdk.Schema{ + Properties: map[string]any{ + "user": map[string]any{ + "type": "string", + "description": "Username or ID of the user to create the workspace for. Use the `me` keyword to create a workspace for the authenticated user.", + }, + "template_version_id": map[string]any{ + "type": "string", + "description": "ID of the template version to create the workspace from.", + }, + "name": map[string]any{ + "type": "string", + "description": "Name of the workspace to create.", + }, + "rich_parameters": map[string]any{ + "type": "object", + "description": "Key/value pairs of rich parameters to pass to the template version to create the workspace.", }, - Required: []string{"user", "template_version_id", "name", "rich_parameters"}, }, + Required: []string{"user", "template_version_id", "name", "rich_parameters"}, }, - Handler: func(ctx context.Context, tb Deps, args CreateWorkspaceArgs) (codersdk.Workspace, error) { - tvID, err := uuid.Parse(args.TemplateVersionID) - if err != nil { - return codersdk.Workspace{}, xerrors.New("template_version_id must be a valid UUID") - } - if args.User == "" { - args.User = codersdk.Me - } - var buildParams []codersdk.WorkspaceBuildParameter - for k, v := range args.RichParameters { - buildParams = append(buildParams, codersdk.WorkspaceBuildParameter{ - Name: k, - Value: v, - }) - } - workspace, err := tb.CoderClient.CreateUserWorkspace(ctx, args.User, codersdk.CreateWorkspaceRequest{ - TemplateVersionID: tvID, - Name: args.Name, - RichParameterValues: buildParams, + }, + Handler: func(ctx context.Context, deps Deps, args CreateWorkspaceArgs) (codersdk.Workspace, error) { + tvID, err := uuid.Parse(args.TemplateVersionID) + if err != nil { + return codersdk.Workspace{}, xerrors.New("template_version_id must be a valid UUID") + } + if args.User == "" { + args.User = codersdk.Me + } + var buildParams []codersdk.WorkspaceBuildParameter + for k, v := range args.RichParameters { + buildParams = append(buildParams, codersdk.WorkspaceBuildParameter{ + Name: k, + Value: v, }) - if err != nil { - return codersdk.Workspace{}, err - } - return workspace, nil - }, - } + } + workspace, err := deps.CoderClient.CreateUserWorkspace(ctx, args.User, codersdk.CreateWorkspaceRequest{ + TemplateVersionID: tvID, + Name: args.Name, + RichParameterValues: buildParams, + }) + if err != nil { + return codersdk.Workspace{}, err + } + return workspace, nil + }, +} - ListWorkspaces = Tool[ListWorkspacesArgs, []MinimalWorkspace]{ - Tool: aisdk.Tool{ - Name: "coder_list_workspaces", - Description: "Lists workspaces for the authenticated user.", - Schema: aisdk.Schema{ - Properties: map[string]any{ - "owner": map[string]any{ - "type": "string", - "description": "The owner of the workspaces to list. Use \"me\" to list workspaces for the authenticated user. If you do not specify an owner, \"me\" will be assumed by default.", - }, +type ListWorkspacesArgs struct { + Owner string `json:"owner"` +} + +var ListWorkspaces = Tool[ListWorkspacesArgs, []MinimalWorkspace]{ + Tool: aisdk.Tool{ + Name: "coder_list_workspaces", + Description: "Lists workspaces for the authenticated user.", + Schema: aisdk.Schema{ + Properties: map[string]any{ + "owner": map[string]any{ + "type": "string", + "description": "The owner of the workspaces to list. Use \"me\" to list workspaces for the authenticated user. If you do not specify an owner, \"me\" will be assumed by default.", }, }, }, - Handler: func(ctx context.Context, tb Deps, args ListWorkspacesArgs) ([]MinimalWorkspace, error) { - owner := args.Owner - if owner == "" { - owner = codersdk.Me - } - workspaces, err := tb.CoderClient.Workspaces(ctx, codersdk.WorkspaceFilter{ - Owner: owner, - }) - if err != nil { - return nil, err - } - minimalWorkspaces := make([]MinimalWorkspace, len(workspaces.Workspaces)) - for i, workspace := range workspaces.Workspaces { - minimalWorkspaces[i] = MinimalWorkspace{ - ID: workspace.ID.String(), - Name: workspace.Name, - TemplateID: workspace.TemplateID.String(), - TemplateName: workspace.TemplateName, - TemplateDisplayName: workspace.TemplateDisplayName, - TemplateIcon: workspace.TemplateIcon, - TemplateActiveVersionID: workspace.TemplateActiveVersionID, - Outdated: workspace.Outdated, - } + }, + Handler: func(ctx context.Context, deps Deps, args ListWorkspacesArgs) ([]MinimalWorkspace, error) { + owner := args.Owner + if owner == "" { + owner = codersdk.Me + } + workspaces, err := deps.CoderClient.Workspaces(ctx, codersdk.WorkspaceFilter{ + Owner: owner, + }) + if err != nil { + return nil, err + } + minimalWorkspaces := make([]MinimalWorkspace, len(workspaces.Workspaces)) + for i, workspace := range workspaces.Workspaces { + minimalWorkspaces[i] = MinimalWorkspace{ + ID: workspace.ID.String(), + Name: workspace.Name, + TemplateID: workspace.TemplateID.String(), + TemplateName: workspace.TemplateName, + TemplateDisplayName: workspace.TemplateDisplayName, + TemplateIcon: workspace.TemplateIcon, + TemplateActiveVersionID: workspace.TemplateActiveVersionID, + Outdated: workspace.Outdated, } - return minimalWorkspaces, nil - }, - } + } + return minimalWorkspaces, nil + }, +} - ListTemplates = Tool[NoArgs, []MinimalTemplate]{ - Tool: aisdk.Tool{ - Name: "coder_list_templates", - Description: "Lists templates for the authenticated user.", - Schema: aisdk.Schema{ - Properties: map[string]any{}, - Required: []string{}, - }, +var ListTemplates = Tool[NoArgs, []MinimalTemplate]{ + Tool: aisdk.Tool{ + Name: "coder_list_templates", + Description: "Lists templates for the authenticated user.", + Schema: aisdk.Schema{ + Properties: map[string]any{}, + Required: []string{}, }, - Handler: func(ctx context.Context, tb Deps, _ NoArgs) ([]MinimalTemplate, error) { - templates, err := tb.CoderClient.Templates(ctx, codersdk.TemplateFilter{}) - if err != nil { - return nil, err - } - minimalTemplates := make([]MinimalTemplate, len(templates)) - for i, template := range templates { - minimalTemplates[i] = MinimalTemplate{ - DisplayName: template.DisplayName, - ID: template.ID.String(), - Name: template.Name, - Description: template.Description, - ActiveVersionID: template.ActiveVersionID, - ActiveUserCount: template.ActiveUserCount, - } + }, + Handler: func(ctx context.Context, deps Deps, _ NoArgs) ([]MinimalTemplate, error) { + templates, err := deps.CoderClient.Templates(ctx, codersdk.TemplateFilter{}) + if err != nil { + return nil, err + } + minimalTemplates := make([]MinimalTemplate, len(templates)) + for i, template := range templates { + minimalTemplates[i] = MinimalTemplate{ + DisplayName: template.DisplayName, + ID: template.ID.String(), + Name: template.Name, + Description: template.Description, + ActiveVersionID: template.ActiveVersionID, + ActiveUserCount: template.ActiveUserCount, } - return minimalTemplates, nil - }, - } + } + return minimalTemplates, nil + }, +} - ListTemplateVersionParameters = Tool[ListTemplateVersionParametersArgs, []codersdk.TemplateVersionParameter]{ - Tool: aisdk.Tool{ - Name: "coder_template_version_parameters", - Description: "Get the parameters for a template version. You can refer to these as workspace parameters to the user, as they are typically important for creating a workspace.", - Schema: aisdk.Schema{ - Properties: map[string]any{ - "template_version_id": map[string]any{ - "type": "string", - }, +type ListTemplateVersionParametersArgs struct { + TemplateVersionID string `json:"template_version_id"` +} + +var ListTemplateVersionParameters = Tool[ListTemplateVersionParametersArgs, []codersdk.TemplateVersionParameter]{ + Tool: aisdk.Tool{ + Name: "coder_template_version_parameters", + Description: "Get the parameters for a template version. You can refer to these as workspace parameters to the user, as they are typically important for creating a workspace.", + Schema: aisdk.Schema{ + Properties: map[string]any{ + "template_version_id": map[string]any{ + "type": "string", }, - Required: []string{"template_version_id"}, }, + Required: []string{"template_version_id"}, }, - Handler: func(ctx context.Context, tb Deps, args ListTemplateVersionParametersArgs) ([]codersdk.TemplateVersionParameter, error) { - templateVersionID, err := uuid.Parse(args.TemplateVersionID) - if err != nil { - return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) - } - parameters, err := tb.CoderClient.TemplateVersionRichParameters(ctx, templateVersionID) - if err != nil { - return nil, err - } - return parameters, nil - }, - } + }, + Handler: func(ctx context.Context, deps Deps, args ListTemplateVersionParametersArgs) ([]codersdk.TemplateVersionParameter, error) { + templateVersionID, err := uuid.Parse(args.TemplateVersionID) + if err != nil { + return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) + } + parameters, err := deps.CoderClient.TemplateVersionRichParameters(ctx, templateVersionID) + if err != nil { + return nil, err + } + return parameters, nil + }, +} - GetAuthenticatedUser = Tool[NoArgs, codersdk.User]{ - Tool: aisdk.Tool{ - Name: "coder_get_authenticated_user", - Description: "Get the currently authenticated user, similar to the `whoami` command.", - Schema: aisdk.Schema{ - Properties: map[string]any{}, - Required: []string{}, - }, - }, - Handler: func(ctx context.Context, tb Deps, _ NoArgs) (codersdk.User, error) { - return tb.CoderClient.User(ctx, "me") +var GetAuthenticatedUser = Tool[NoArgs, codersdk.User]{ + Tool: aisdk.Tool{ + Name: "coder_get_authenticated_user", + Description: "Get the currently authenticated user, similar to the `whoami` command.", + Schema: aisdk.Schema{ + Properties: map[string]any{}, + Required: []string{}, }, - } + }, + Handler: func(ctx context.Context, deps Deps, _ NoArgs) (codersdk.User, error) { + return deps.CoderClient.User(ctx, "me") + }, +} - CreateWorkspaceBuild = Tool[CreateWorkspaceBuildArgs, codersdk.WorkspaceBuild]{ - Tool: aisdk.Tool{ - Name: "coder_create_workspace_build", - Description: "Create a new workspace build for an existing workspace. Use this to start, stop, or delete.", - Schema: aisdk.Schema{ - Properties: map[string]any{ - "workspace_id": map[string]any{ - "type": "string", - }, - "transition": map[string]any{ - "type": "string", - "description": "The transition to perform. Must be one of: start, stop, delete", - "enum": []string{"start", "stop", "delete"}, - }, - "template_version_id": map[string]any{ - "type": "string", - "description": "(Optional) The template version ID to use for the workspace build. If not provided, the previously built version will be used.", - }, +type CreateWorkspaceBuildArgs struct { + TemplateVersionID string `json:"template_version_id"` + Transition string `json:"transition"` + WorkspaceID string `json:"workspace_id"` +} + +var CreateWorkspaceBuild = Tool[CreateWorkspaceBuildArgs, codersdk.WorkspaceBuild]{ + Tool: aisdk.Tool{ + Name: "coder_create_workspace_build", + Description: "Create a new workspace build for an existing workspace. Use this to start, stop, or delete.", + Schema: aisdk.Schema{ + Properties: map[string]any{ + "workspace_id": map[string]any{ + "type": "string", + }, + "transition": map[string]any{ + "type": "string", + "description": "The transition to perform. Must be one of: start, stop, delete", + "enum": []string{"start", "stop", "delete"}, + }, + "template_version_id": map[string]any{ + "type": "string", + "description": "(Optional) The template version ID to use for the workspace build. If not provided, the previously built version will be used.", }, - Required: []string{"workspace_id", "transition"}, }, + Required: []string{"workspace_id", "transition"}, }, - Handler: func(ctx context.Context, tb Deps, args CreateWorkspaceBuildArgs) (codersdk.WorkspaceBuild, error) { - workspaceID, err := uuid.Parse(args.WorkspaceID) + }, + Handler: func(ctx context.Context, deps Deps, args CreateWorkspaceBuildArgs) (codersdk.WorkspaceBuild, error) { + workspaceID, err := uuid.Parse(args.WorkspaceID) + if err != nil { + return codersdk.WorkspaceBuild{}, xerrors.Errorf("workspace_id must be a valid UUID: %w", err) + } + var templateVersionID uuid.UUID + if args.TemplateVersionID != "" { + tvID, err := uuid.Parse(args.TemplateVersionID) if err != nil { - return codersdk.WorkspaceBuild{}, xerrors.Errorf("workspace_id must be a valid UUID: %w", err) - } - var templateVersionID uuid.UUID - if args.TemplateVersionID != "" { - tvID, err := uuid.Parse(args.TemplateVersionID) - if err != nil { - return codersdk.WorkspaceBuild{}, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) - } - templateVersionID = tvID - } - cbr := codersdk.CreateWorkspaceBuildRequest{ - Transition: codersdk.WorkspaceTransition(args.Transition), + return codersdk.WorkspaceBuild{}, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) } - if templateVersionID != uuid.Nil { - cbr.TemplateVersionID = templateVersionID - } - return tb.CoderClient.CreateWorkspaceBuild(ctx, workspaceID, cbr) - }, - } + templateVersionID = tvID + } + cbr := codersdk.CreateWorkspaceBuildRequest{ + Transition: codersdk.WorkspaceTransition(args.Transition), + } + if templateVersionID != uuid.Nil { + cbr.TemplateVersionID = templateVersionID + } + return deps.CoderClient.CreateWorkspaceBuild(ctx, workspaceID, cbr) + }, +} - CreateTemplateVersion = Tool[CreateTemplateVersionArgs, codersdk.TemplateVersion]{ - Tool: aisdk.Tool{ - Name: "coder_create_template_version", - Description: `Create a new template version. This is a precursor to creating a template, or you can update an existing template. +type CreateTemplateVersionArgs struct { + FileID string `json:"file_id"` + TemplateID string `json:"template_id"` +} + +var CreateTemplateVersion = Tool[CreateTemplateVersionArgs, codersdk.TemplateVersion]{ + Tool: aisdk.Tool{ + Name: "coder_create_template_version", + Description: `Create a new template version. This is a precursor to creating a template, or you can update an existing template. Templates are Terraform defining a development environment. The provisioned infrastructure must run an Agent that connects to the Coder Control Plane to provide a rich experience. @@ -937,307 +907,346 @@ resource "kubernetes_deployment" "main" { The file_id provided is a reference to a tar file you have uploaded containing the Terraform. `, - Schema: aisdk.Schema{ - Properties: map[string]any{ - "template_id": map[string]any{ - "type": "string", - }, - "file_id": map[string]any{ - "type": "string", - }, + Schema: aisdk.Schema{ + Properties: map[string]any{ + "template_id": map[string]any{ + "type": "string", + }, + "file_id": map[string]any{ + "type": "string", }, - Required: []string{"file_id"}, }, + Required: []string{"file_id"}, }, - Handler: func(ctx context.Context, tb Deps, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) { - me, err := tb.CoderClient.User(ctx, "me") - if err != nil { - return codersdk.TemplateVersion{}, err - } - fileID, err := uuid.Parse(args.FileID) - if err != nil { - return codersdk.TemplateVersion{}, xerrors.Errorf("file_id must be a valid UUID: %w", err) - } - var templateID uuid.UUID - if args.TemplateID != "" { - tid, err := uuid.Parse(args.TemplateID) - if err != nil { - return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err) - } - templateID = tid - } - templateVersion, err := tb.CoderClient.CreateTemplateVersion(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{ - Message: "Created by AI", - StorageMethod: codersdk.ProvisionerStorageMethodFile, - FileID: fileID, - Provisioner: codersdk.ProvisionerTypeTerraform, - TemplateID: templateID, - }) + }, + Handler: func(ctx context.Context, deps Deps, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) { + me, err := deps.CoderClient.User(ctx, "me") + if err != nil { + return codersdk.TemplateVersion{}, err + } + fileID, err := uuid.Parse(args.FileID) + if err != nil { + return codersdk.TemplateVersion{}, xerrors.Errorf("file_id must be a valid UUID: %w", err) + } + var templateID uuid.UUID + if args.TemplateID != "" { + tid, err := uuid.Parse(args.TemplateID) if err != nil { - return codersdk.TemplateVersion{}, err + return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err) } - return templateVersion, nil - }, - } + templateID = tid + } + templateVersion, err := deps.CoderClient.CreateTemplateVersion(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{ + Message: "Created by AI", + StorageMethod: codersdk.ProvisionerStorageMethodFile, + FileID: fileID, + Provisioner: codersdk.ProvisionerTypeTerraform, + TemplateID: templateID, + }) + if err != nil { + return codersdk.TemplateVersion{}, err + } + return templateVersion, nil + }, +} + +type GetWorkspaceAgentLogsArgs struct { + WorkspaceAgentID string `json:"workspace_agent_id"` +} - GetWorkspaceAgentLogs = Tool[GetWorkspaceAgentLogsArgs, []string]{ - Tool: aisdk.Tool{ - Name: "coder_get_workspace_agent_logs", - Description: `Get the logs of a workspace agent. +var GetWorkspaceAgentLogs = Tool[GetWorkspaceAgentLogsArgs, []string]{ + Tool: aisdk.Tool{ + Name: "coder_get_workspace_agent_logs", + Description: `Get the logs of a workspace agent. More logs may appear after this call. It does not wait for the agent to finish.`, - Schema: aisdk.Schema{ - Properties: map[string]any{ - "workspace_agent_id": map[string]any{ - "type": "string", - }, + Schema: aisdk.Schema{ + Properties: map[string]any{ + "workspace_agent_id": map[string]any{ + "type": "string", }, - Required: []string{"workspace_agent_id"}, }, + Required: []string{"workspace_agent_id"}, }, - Handler: func(ctx context.Context, tb Deps, args GetWorkspaceAgentLogsArgs) ([]string, error) { - workspaceAgentID, err := uuid.Parse(args.WorkspaceAgentID) - if err != nil { - return nil, xerrors.Errorf("workspace_agent_id must be a valid UUID: %w", err) - } - logs, closer, err := tb.CoderClient.WorkspaceAgentLogsAfter(ctx, workspaceAgentID, 0, false) - if err != nil { - return nil, err - } - defer closer.Close() - var acc []string - for logChunk := range logs { - for _, log := range logChunk { - acc = append(acc, log.Output) - } + }, + Handler: func(ctx context.Context, deps Deps, args GetWorkspaceAgentLogsArgs) ([]string, error) { + workspaceAgentID, err := uuid.Parse(args.WorkspaceAgentID) + if err != nil { + return nil, xerrors.Errorf("workspace_agent_id must be a valid UUID: %w", err) + } + logs, closer, err := deps.CoderClient.WorkspaceAgentLogsAfter(ctx, workspaceAgentID, 0, false) + if err != nil { + return nil, err + } + defer closer.Close() + var acc []string + for logChunk := range logs { + for _, log := range logChunk { + acc = append(acc, log.Output) } - return acc, nil - }, - } + } + return acc, nil + }, +} + +type GetWorkspaceBuildLogsArgs struct { + WorkspaceBuildID string `json:"workspace_build_id"` +} - GetWorkspaceBuildLogs = Tool[GetWorkspaceBuildLogsArgs, []string]{ - Tool: aisdk.Tool{ - Name: "coder_get_workspace_build_logs", - Description: `Get the logs of a workspace build. +var GetWorkspaceBuildLogs = Tool[GetWorkspaceBuildLogsArgs, []string]{ + Tool: aisdk.Tool{ + Name: "coder_get_workspace_build_logs", + Description: `Get the logs of a workspace build. Useful for checking whether a workspace builds successfully or not.`, - Schema: aisdk.Schema{ - Properties: map[string]any{ - "workspace_build_id": map[string]any{ - "type": "string", - }, + Schema: aisdk.Schema{ + Properties: map[string]any{ + "workspace_build_id": map[string]any{ + "type": "string", }, - Required: []string{"workspace_build_id"}, }, + Required: []string{"workspace_build_id"}, }, - Handler: func(ctx context.Context, tb Deps, args GetWorkspaceBuildLogsArgs) ([]string, error) { - workspaceBuildID, err := uuid.Parse(args.WorkspaceBuildID) - if err != nil { - return nil, xerrors.Errorf("workspace_build_id must be a valid UUID: %w", err) - } - logs, closer, err := tb.CoderClient.WorkspaceBuildLogsAfter(ctx, workspaceBuildID, 0) - if err != nil { - return nil, err - } - defer closer.Close() - var acc []string - for log := range logs { - acc = append(acc, log.Output) - } - return acc, nil - }, - } + }, + Handler: func(ctx context.Context, deps Deps, args GetWorkspaceBuildLogsArgs) ([]string, error) { + workspaceBuildID, err := uuid.Parse(args.WorkspaceBuildID) + if err != nil { + return nil, xerrors.Errorf("workspace_build_id must be a valid UUID: %w", err) + } + logs, closer, err := deps.CoderClient.WorkspaceBuildLogsAfter(ctx, workspaceBuildID, 0) + if err != nil { + return nil, err + } + defer closer.Close() + var acc []string + for log := range logs { + acc = append(acc, log.Output) + } + return acc, nil + }, +} - GetTemplateVersionLogs = Tool[GetTemplateVersionLogsArgs, []string]{ - Tool: aisdk.Tool{ - Name: "coder_get_template_version_logs", - Description: "Get the logs of a template version. This is useful to check whether a template version successfully imports or not.", - Schema: aisdk.Schema{ - Properties: map[string]any{ - "template_version_id": map[string]any{ - "type": "string", - }, +type GetTemplateVersionLogsArgs struct { + TemplateVersionID string `json:"template_version_id"` +} + +var GetTemplateVersionLogs = Tool[GetTemplateVersionLogsArgs, []string]{ + Tool: aisdk.Tool{ + Name: "coder_get_template_version_logs", + Description: "Get the logs of a template version. This is useful to check whether a template version successfully imports or not.", + Schema: aisdk.Schema{ + Properties: map[string]any{ + "template_version_id": map[string]any{ + "type": "string", }, - Required: []string{"template_version_id"}, }, + Required: []string{"template_version_id"}, }, - Handler: func(ctx context.Context, tb Deps, args GetTemplateVersionLogsArgs) ([]string, error) { - templateVersionID, err := uuid.Parse(args.TemplateVersionID) - if err != nil { - return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) - } + }, + Handler: func(ctx context.Context, deps Deps, args GetTemplateVersionLogsArgs) ([]string, error) { + templateVersionID, err := uuid.Parse(args.TemplateVersionID) + if err != nil { + return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) + } - logs, closer, err := tb.CoderClient.TemplateVersionLogsAfter(ctx, templateVersionID, 0) - if err != nil { - return nil, err - } - defer closer.Close() - var acc []string - for log := range logs { - acc = append(acc, log.Output) - } - return acc, nil - }, - } + logs, closer, err := deps.CoderClient.TemplateVersionLogsAfter(ctx, templateVersionID, 0) + if err != nil { + return nil, err + } + defer closer.Close() + var acc []string + for log := range logs { + acc = append(acc, log.Output) + } + return acc, nil + }, +} - UpdateTemplateActiveVersion = Tool[UpdateTemplateActiveVersionArgs, string]{ - Tool: aisdk.Tool{ - Name: "coder_update_template_active_version", - Description: "Update the active version of a template. This is helpful when iterating on templates.", - Schema: aisdk.Schema{ - Properties: map[string]any{ - "template_id": map[string]any{ - "type": "string", - }, - "template_version_id": map[string]any{ - "type": "string", - }, +type UpdateTemplateActiveVersionArgs struct { + TemplateID string `json:"template_id"` + TemplateVersionID string `json:"template_version_id"` +} + +var UpdateTemplateActiveVersion = Tool[UpdateTemplateActiveVersionArgs, string]{ + Tool: aisdk.Tool{ + Name: "coder_update_template_active_version", + Description: "Update the active version of a template. This is helpful when iterating on templates.", + Schema: aisdk.Schema{ + Properties: map[string]any{ + "template_id": map[string]any{ + "type": "string", + }, + "template_version_id": map[string]any{ + "type": "string", }, - Required: []string{"template_id", "template_version_id"}, }, + Required: []string{"template_id", "template_version_id"}, }, - Handler: func(ctx context.Context, tb Deps, args UpdateTemplateActiveVersionArgs) (string, error) { - templateID, err := uuid.Parse(args.TemplateID) - if err != nil { - return "", xerrors.Errorf("template_id must be a valid UUID: %w", err) - } - templateVersionID, err := uuid.Parse(args.TemplateVersionID) - if err != nil { - return "", xerrors.Errorf("template_version_id must be a valid UUID: %w", err) - } - err = tb.CoderClient.UpdateActiveTemplateVersion(ctx, templateID, codersdk.UpdateActiveTemplateVersion{ - ID: templateVersionID, - }) - if err != nil { - return "", err - } - return "Successfully updated active version!", nil - }, - } + }, + Handler: func(ctx context.Context, deps Deps, args UpdateTemplateActiveVersionArgs) (string, error) { + templateID, err := uuid.Parse(args.TemplateID) + if err != nil { + return "", xerrors.Errorf("template_id must be a valid UUID: %w", err) + } + templateVersionID, err := uuid.Parse(args.TemplateVersionID) + if err != nil { + return "", xerrors.Errorf("template_version_id must be a valid UUID: %w", err) + } + err = deps.CoderClient.UpdateActiveTemplateVersion(ctx, templateID, codersdk.UpdateActiveTemplateVersion{ + ID: templateVersionID, + }) + if err != nil { + return "", err + } + return "Successfully updated active version!", nil + }, +} - UploadTarFile = Tool[UploadTarFileArgs, codersdk.UploadResponse]{ - Tool: aisdk.Tool{ - Name: "coder_upload_tar_file", - Description: `Create and upload a tar file by key/value mapping of file names to file contents. Use this to create template versions. Reference the tool description of "create_template_version" to understand template requirements.`, - Schema: aisdk.Schema{ - Properties: map[string]any{ - "files": map[string]any{ - "type": "object", - "description": "A map of file names to file contents.", - }, +type UploadTarFileArgs struct { + Files map[string]string `json:"files"` +} + +var UploadTarFile = Tool[UploadTarFileArgs, codersdk.UploadResponse]{ + Tool: aisdk.Tool{ + Name: "coder_upload_tar_file", + Description: `Create and upload a tar file by key/value mapping of file names to file contents. Use this to create template versions. Reference the tool description of "create_template_version" to understand template requirements.`, + Schema: aisdk.Schema{ + Properties: map[string]any{ + "files": map[string]any{ + "type": "object", + "description": "A map of file names to file contents.", }, - Required: []string{"mime_type", "files"}, }, + Required: []string{"files"}, }, - Handler: func(ctx context.Context, tb Deps, args UploadTarFileArgs) (codersdk.UploadResponse, error) { - pipeReader, pipeWriter := io.Pipe() - go func() { - defer pipeWriter.Close() - tarWriter := tar.NewWriter(pipeWriter) - for name, content := range args.Files { - header := &tar.Header{ - Name: name, - Size: int64(len(content)), - Mode: 0o644, - } - if err := tarWriter.WriteHeader(header); err != nil { - _ = pipeWriter.CloseWithError(err) - return - } - if _, err := tarWriter.Write([]byte(content)); err != nil { - _ = pipeWriter.CloseWithError(err) - return - } + }, + Handler: func(ctx context.Context, deps Deps, args UploadTarFileArgs) (codersdk.UploadResponse, error) { + pipeReader, pipeWriter := io.Pipe() + done := make(chan struct{}) + go func() { + defer func() { + _ = pipeWriter.Close() + close(done) + }() + tarWriter := tar.NewWriter(pipeWriter) + for name, content := range args.Files { + header := &tar.Header{ + Name: name, + Size: int64(len(content)), + Mode: 0o644, } - if err := tarWriter.Close(); err != nil { + if err := tarWriter.WriteHeader(header); err != nil { _ = pipeWriter.CloseWithError(err) + return + } + if _, err := tarWriter.Write([]byte(content)); err != nil { + _ = pipeWriter.CloseWithError(err) + return } - }() - - resp, err := tb.CoderClient.Upload(ctx, codersdk.ContentTypeTar, pipeReader) - if err != nil { - return codersdk.UploadResponse{}, err } - return resp, nil - }, - } + if err := tarWriter.Close(); err != nil { + _ = pipeWriter.CloseWithError(err) + } + }() - CreateTemplate = Tool[CreateTemplateArgs, codersdk.Template]{ - Tool: aisdk.Tool{ - Name: "coder_create_template", - Description: "Create a new template in Coder. First, you must create a template version.", - Schema: aisdk.Schema{ - Properties: map[string]any{ - "name": map[string]any{ - "type": "string", - }, - "display_name": map[string]any{ - "type": "string", - }, - "description": map[string]any{ - "type": "string", - }, - "icon": map[string]any{ - "type": "string", - "description": "A URL to an icon to use.", - }, - "version_id": map[string]any{ - "type": "string", - "description": "The ID of the version to use.", - }, + resp, err := deps.CoderClient.Upload(ctx, codersdk.ContentTypeTar, pipeReader) + if err != nil { + _ = pipeReader.CloseWithError(err) + <-done + return codersdk.UploadResponse{}, err + } + <-done + return resp, nil + }, +} + +type CreateTemplateArgs struct { + Description string `json:"description"` + DisplayName string `json:"display_name"` + Icon string `json:"icon"` + Name string `json:"name"` + VersionID string `json:"version_id"` +} + +var CreateTemplate = Tool[CreateTemplateArgs, codersdk.Template]{ + Tool: aisdk.Tool{ + Name: "coder_create_template", + Description: "Create a new template in Coder. First, you must create a template version.", + Schema: aisdk.Schema{ + Properties: map[string]any{ + "name": map[string]any{ + "type": "string", + }, + "display_name": map[string]any{ + "type": "string", + }, + "description": map[string]any{ + "type": "string", + }, + "icon": map[string]any{ + "type": "string", + "description": "A URL to an icon to use.", + }, + "version_id": map[string]any{ + "type": "string", + "description": "The ID of the version to use.", }, - Required: []string{"name", "display_name", "description", "version_id"}, }, + Required: []string{"name", "display_name", "description", "version_id"}, }, - Handler: func(ctx context.Context, tb Deps, args CreateTemplateArgs) (codersdk.Template, error) { - me, err := tb.CoderClient.User(ctx, "me") - if err != nil { - return codersdk.Template{}, err - } - versionID, err := uuid.Parse(args.VersionID) - if err != nil { - return codersdk.Template{}, xerrors.Errorf("version_id must be a valid UUID: %w", err) - } - template, err := tb.CoderClient.CreateTemplate(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateRequest{ - Name: args.Name, - DisplayName: args.DisplayName, - Description: args.Description, - VersionID: versionID, - }) - if err != nil { - return codersdk.Template{}, err - } - return template, nil - }, - } + }, + Handler: func(ctx context.Context, deps Deps, args CreateTemplateArgs) (codersdk.Template, error) { + me, err := deps.CoderClient.User(ctx, "me") + if err != nil { + return codersdk.Template{}, err + } + versionID, err := uuid.Parse(args.VersionID) + if err != nil { + return codersdk.Template{}, xerrors.Errorf("version_id must be a valid UUID: %w", err) + } + template, err := deps.CoderClient.CreateTemplate(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateRequest{ + Name: args.Name, + DisplayName: args.DisplayName, + Description: args.Description, + VersionID: versionID, + }) + if err != nil { + return codersdk.Template{}, err + } + return template, nil + }, +} - DeleteTemplate = Tool[DeleteTemplateArgs, codersdk.Response]{ - Tool: aisdk.Tool{ - Name: "coder_delete_template", - Description: "Delete a template. This is irreversible.", - Schema: aisdk.Schema{ - Properties: map[string]any{ - "template_id": map[string]any{ - "type": "string", - }, +type DeleteTemplateArgs struct { + TemplateID string `json:"template_id"` +} + +var DeleteTemplate = Tool[DeleteTemplateArgs, codersdk.Response]{ + Tool: aisdk.Tool{ + Name: "coder_delete_template", + Description: "Delete a template. This is irreversible.", + Schema: aisdk.Schema{ + Properties: map[string]any{ + "template_id": map[string]any{ + "type": "string", }, }, }, - Handler: func(ctx context.Context, tb Deps, args DeleteTemplateArgs) (codersdk.Response, error) { - templateID, err := uuid.Parse(args.TemplateID) - if err != nil { - return codersdk.Response{}, xerrors.Errorf("template_id must be a valid UUID: %w", err) - } - err = tb.CoderClient.DeleteTemplate(ctx, templateID) - if err != nil { - return codersdk.Response{}, err - } - return codersdk.Response{ - Message: "Template deleted successfully.", - }, nil - }, - } -) + }, + Handler: func(ctx context.Context, deps Deps, args DeleteTemplateArgs) (codersdk.Response, error) { + templateID, err := uuid.Parse(args.TemplateID) + if err != nil { + return codersdk.Response{}, xerrors.Errorf("template_id must be a valid UUID: %w", err) + } + err = deps.CoderClient.DeleteTemplate(ctx, templateID) + if err != nil { + return codersdk.Response{}, err + } + return codersdk.Response{ + Message: "Template deleted successfully.", + }, nil + }, +} type MinimalWorkspace struct { ID string `json:"id"` diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index 51cb25c7d88b3..36d1733663de4 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -13,6 +13,7 @@ import ( "github.com/kylecarbs/aisdk-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" @@ -457,12 +458,14 @@ func TestWithCleanContext(t *testing.T) { t.Parallel() // This test is to ensure that the context is canceled properly. + callCh := make(chan struct{}) ctxTool := toolsdk.GenericTool{ Tool: aisdk.Tool{ Name: "context_tool", Description: "Returns the context value for testing.", }, Handler: func(toolCtx context.Context, tb toolsdk.Deps, args json.RawMessage) (json.RawMessage, error) { + defer close(callCh) // Wait for the context to be canceled <-toolCtx.Done() return nil, toolCtx.Err() @@ -471,6 +474,7 @@ func TestWithCleanContext(t *testing.T) { wrapped := toolsdk.WithCleanContext(ctxTool.Handler) errCh := make(chan error, 1) + tCtx := testutil.Context(t, testutil.WaitShort) ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) go func() { @@ -479,12 +483,21 @@ func TestWithCleanContext(t *testing.T) { }() cancel() + + // Ensure the tool is called select { - case <-t.Context().Done(): + case <-callCh: + case <-tCtx.Done(): + require.Fail(t, "test timed out before handler was called") + } + + // Ensure the correct error is returned + select { + case <-tCtx.Done(): require.Fail(t, "test timed out") case err := <-errCh: - require.ErrorIs(t, err, context.Canceled) // Context was canceled and the done channel was closed + require.ErrorIs(t, err, context.Canceled) } }) @@ -506,7 +519,7 @@ func TestWithCleanContext(t *testing.T) { } wrapped := toolsdk.WithCleanContext(ctxTool.Handler) - parent, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + parent, cancel := context.WithTimeout(context.Background(), testutil.IntervalFast) t.Cleanup(cancel) _, err := wrapped(parent, toolsdk.Deps{}, []byte(`{}`)) require.NoError(t, err) @@ -532,6 +545,7 @@ func TestMain(m *testing.M) { } if len(untested) > 0 && code == 0 { + code = 1 println("The following tools were not tested:") for _, tool := range untested { println(" - " + tool) @@ -539,7 +553,14 @@ func TestMain(m *testing.M) { println("Please ensure that all tools are tested using testTool().") println("If you just added a new tool, please add a test for it.") println("NOTE: if you just ran an individual test, this is expected.") - os.Exit(1) + } + + // Check for goroutine leaks. Below is adapted from goleak.VerifyTestMain: + if code == 0 { + if err := goleak.Find(testutil.GoleakOptions...); err != nil { + println("goleak: Errors on successful test run: ", err.Error()) + code = 1 + } } os.Exit(code) From 233f9da352dc7662913f4422fb33400bd89b99e7 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 29 Apr 2025 14:56:12 +0100 Subject: [PATCH 10/11] address more PR comments --- cli/exp_mcp.go | 21 ++++++----- codersdk/toolsdk/toolsdk.go | 27 ++++++++++++- codersdk/toolsdk/toolsdk_test.go | 65 +++++++++++++++++++------------- 3 files changed, 77 insertions(+), 36 deletions(-) diff --git a/cli/exp_mcp.go b/cli/exp_mcp.go index 2449cc52c563d..c66db0b816aea 100644 --- a/cli/exp_mcp.go +++ b/cli/exp_mcp.go @@ -400,24 +400,27 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct server.WithInstructions(instructions), ) - // Create a new context for the tools with all relevant information. - tb := toolsdk.Deps{ - CoderClient: client, - } // Get the workspace agent token from the environment. + toolOpts := make([]func(*toolsdk.Deps), 0) var hasAgentClient bool if agentToken, err := getAgentToken(fs); err == nil && agentToken != "" { hasAgentClient = true agentClient := agentsdk.New(client.URL) agentClient.SetSessionToken(agentToken) - tb.AgentClient = agentClient + toolOpts = append(toolOpts, toolsdk.WithAgentClient(agentClient)) } else { cliui.Warnf(inv.Stderr, "CODER_AGENT_TOKEN is not set, task reporting will not be available") } - if appStatusSlug == "" { - cliui.Warnf(inv.Stderr, "CODER_MCP_APP_STATUS_SLUG is not set, task reporting will not be available.") + + if appStatusSlug != "" { + toolOpts = append(toolOpts, toolsdk.WithAppStatusSlug(appStatusSlug)) } else { - tb.AppStatusSlug = appStatusSlug + cliui.Warnf(inv.Stderr, "CODER_MCP_APP_STATUS_SLUG is not set, task reporting will not be available.") + } + + toolDeps, err := toolsdk.NewDeps(client, toolOpts...) + if err != nil { + return xerrors.Errorf("failed to initialize tool dependencies: %w", err) } // Register tools based on the allowlist (if specified) @@ -430,7 +433,7 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct if len(allowedTools) == 0 || slices.ContainsFunc(allowedTools, func(t string) bool { return t == tool.Tool.Name }) { - mcpSrv.AddTools(mcpFromSDK(tool, tb)) + mcpSrv.AddTools(mcpFromSDK(tool, toolDeps)) } } diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index 76f93f45cfe1f..1c66230d35711 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -15,6 +15,31 @@ import ( "github.com/coder/coder/v2/codersdk/agentsdk" ) +func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) { + d := Deps{ + CoderClient: client, + } + for _, opt := range opts { + opt(&d) + } + if d.CoderClient == nil { + return Deps{}, xerrors.New("developer error: coder client may not be nil") + } + return d, nil +} + +func WithAgentClient(client *agentsdk.Client) func(*Deps) { + return func(d *Deps) { + d.AgentClient = client + } +} + +func WithAppStatusSlug(slug string) func(*Deps) { + return func(d *Deps) { + d.AppStatusSlug = slug + } +} + // Deps provides access to tool dependencies. type Deps struct { CoderClient *codersdk.Client @@ -175,7 +200,7 @@ var ReportTask = Tool[ReportTaskArgs, codersdk.Response]{ return codersdk.Response{}, xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set") } if deps.AppStatusSlug == "" { - return codersdk.Response{}, xerrors.New("workspace app status slug not found in toolbox") + return codersdk.Response{}, xerrors.New("tool unavailable as CODER_MCP_APP_STATUS_SLUG is not set") } if len(args.Summary) > 160 { return codersdk.Response{}, xerrors.New("summary must be less than 160 characters") diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index 36d1733663de4..fae4e85e52a66 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -72,12 +72,9 @@ func TestTools(t *testing.T) { }) t.Run("ReportTask", func(t *testing.T) { - tb := toolsdk.Deps{ - CoderClient: memberClient, - AgentClient: agentClient, - AppStatusSlug: "some-agent-app", - } - _, err := testTool(t, toolsdk.ReportTask, tb, toolsdk.ReportTaskArgs{ + tb, err := toolsdk.NewDeps(memberClient, toolsdk.WithAgentClient(agentClient), toolsdk.WithAppStatusSlug("some-agent-app")) + require.NoError(t, err) + _, err = testTool(t, toolsdk.ReportTask, tb, toolsdk.ReportTaskArgs{ Summary: "test summary", State: "complete", Link: "https://example.com", @@ -86,7 +83,8 @@ func TestTools(t *testing.T) { }) t.Run("GetWorkspace", func(t *testing.T) { - tb := toolsdk.Deps{CoderClient: memberClient} + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) result, err := testTool(t, toolsdk.GetWorkspace, tb, toolsdk.GetWorkspaceArgs{ WorkspaceID: r.Workspace.ID.String(), }) @@ -96,7 +94,8 @@ func TestTools(t *testing.T) { }) t.Run("ListTemplates", func(t *testing.T) { - tb := toolsdk.Deps{CoderClient: memberClient} + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) // Get the templates directly for comparison expected, err := memberClient.Templates(context.Background(), codersdk.TemplateFilter{}) require.NoError(t, err) @@ -119,7 +118,8 @@ func TestTools(t *testing.T) { }) t.Run("Whoami", func(t *testing.T) { - tb := toolsdk.Deps{CoderClient: memberClient} + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) result, err := testTool(t, toolsdk.GetAuthenticatedUser, tb, toolsdk.NoArgs{}) require.NoError(t, err) @@ -128,7 +128,8 @@ func TestTools(t *testing.T) { }) t.Run("ListWorkspaces", func(t *testing.T) { - tb := toolsdk.Deps{CoderClient: memberClient} + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) result, err := testTool(t, toolsdk.ListWorkspaces, tb, toolsdk.ListWorkspacesArgs{}) require.NoError(t, err) @@ -140,7 +141,8 @@ func TestTools(t *testing.T) { t.Run("CreateWorkspaceBuild", func(t *testing.T) { t.Run("Stop", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) - tb := toolsdk.Deps{CoderClient: memberClient} + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) result, err := testTool(t, toolsdk.CreateWorkspaceBuild, tb, toolsdk.CreateWorkspaceBuildArgs{ WorkspaceID: r.Workspace.ID.String(), Transition: "stop", @@ -159,7 +161,8 @@ func TestTools(t *testing.T) { t.Run("Start", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) - tb := toolsdk.Deps{CoderClient: memberClient} + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) result, err := testTool(t, toolsdk.CreateWorkspaceBuild, tb, toolsdk.CreateWorkspaceBuildArgs{ WorkspaceID: r.Workspace.ID.String(), Transition: "start", @@ -178,7 +181,8 @@ func TestTools(t *testing.T) { t.Run("TemplateVersionChange", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) - tb := toolsdk.Deps{CoderClient: memberClient} + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) // Get the current template version ID before updating workspace, err := memberClient.Workspace(ctx, r.Workspace.ID) require.NoError(t, err) @@ -222,7 +226,8 @@ func TestTools(t *testing.T) { }) t.Run("ListTemplateVersionParameters", func(t *testing.T) { - tb := toolsdk.Deps{CoderClient: memberClient} + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) params, err := testTool(t, toolsdk.ListTemplateVersionParameters, tb, toolsdk.ListTemplateVersionParametersArgs{ TemplateVersionID: r.TemplateVersion.ID.String(), }) @@ -232,7 +237,8 @@ func TestTools(t *testing.T) { }) t.Run("GetWorkspaceAgentLogs", func(t *testing.T) { - tb := toolsdk.Deps{CoderClient: client} + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) logs, err := testTool(t, toolsdk.GetWorkspaceAgentLogs, tb, toolsdk.GetWorkspaceAgentLogsArgs{ WorkspaceAgentID: agentID.String(), }) @@ -242,7 +248,8 @@ func TestTools(t *testing.T) { }) t.Run("GetWorkspaceBuildLogs", func(t *testing.T) { - tb := toolsdk.Deps{CoderClient: memberClient} + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) logs, err := testTool(t, toolsdk.GetWorkspaceBuildLogs, tb, toolsdk.GetWorkspaceBuildLogsArgs{ WorkspaceBuildID: r.Build.ID.String(), }) @@ -252,7 +259,8 @@ func TestTools(t *testing.T) { }) t.Run("GetTemplateVersionLogs", func(t *testing.T) { - tb := toolsdk.Deps{CoderClient: memberClient} + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) logs, err := testTool(t, toolsdk.GetTemplateVersionLogs, tb, toolsdk.GetTemplateVersionLogsArgs{ TemplateVersionID: r.TemplateVersion.ID.String(), }) @@ -262,7 +270,8 @@ func TestTools(t *testing.T) { }) t.Run("UpdateTemplateActiveVersion", func(t *testing.T) { - tb := toolsdk.Deps{CoderClient: client} + tb, err := toolsdk.NewDeps(client) + require.NoError(t, err) result, err := testTool(t, toolsdk.UpdateTemplateActiveVersion, tb, toolsdk.UpdateTemplateActiveVersionArgs{ TemplateID: r.Template.ID.String(), TemplateVersionID: r.TemplateVersion.ID.String(), @@ -273,8 +282,9 @@ func TestTools(t *testing.T) { }) t.Run("DeleteTemplate", func(t *testing.T) { - tb := toolsdk.Deps{CoderClient: client} - _, err := testTool(t, toolsdk.DeleteTemplate, tb, toolsdk.DeleteTemplateArgs{ + tb, err := toolsdk.NewDeps(client) + require.NoError(t, err) + _, err = testTool(t, toolsdk.DeleteTemplate, tb, toolsdk.DeleteTemplateArgs{ TemplateID: r.Template.ID.String(), }) @@ -283,10 +293,11 @@ func TestTools(t *testing.T) { }) t.Run("UploadTarFile", func(t *testing.T) { - tb := toolsdk.Deps{CoderClient: client} files := map[string]string{ "main.tf": `resource "null_resource" "example" {}`, } + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) result, err := testTool(t, toolsdk.UploadTarFile, tb, toolsdk.UploadTarFileArgs{ Files: files, @@ -297,7 +308,8 @@ func TestTools(t *testing.T) { }) t.Run("CreateTemplateVersion", func(t *testing.T) { - tb := toolsdk.Deps{CoderClient: client} + tb, err := toolsdk.NewDeps(client) + require.NoError(t, err) // nolint:gocritic // This is in a test package and does not end up in the build file := dbgen.File(t, store, database.File{}) t.Run("WithoutTemplateID", func(t *testing.T) { @@ -308,7 +320,6 @@ func TestTools(t *testing.T) { require.NotEmpty(t, tv) }) t.Run("WithTemplateID", func(t *testing.T) { - tb := toolsdk.Deps{CoderClient: client} tv, err := testTool(t, toolsdk.CreateTemplateVersion, tb, toolsdk.CreateTemplateVersionArgs{ FileID: file.ID.String(), TemplateID: r.Template.ID.String(), @@ -319,7 +330,8 @@ func TestTools(t *testing.T) { }) t.Run("CreateTemplate", func(t *testing.T) { - tb := toolsdk.Deps{CoderClient: client} + tb, err := toolsdk.NewDeps(client) + require.NoError(t, err) // Create a new template version for use here. tv := dbfake.TemplateVersion(t, store). // nolint:gocritic // This is in a test package and does not end up in the build @@ -327,7 +339,7 @@ func TestTools(t *testing.T) { SkipCreateTemplate().Do() // We're going to re-use the pre-existing template version - _, err := testTool(t, toolsdk.CreateTemplate, tb, toolsdk.CreateTemplateArgs{ + _, err = testTool(t, toolsdk.CreateTemplate, tb, toolsdk.CreateTemplateArgs{ Name: testutil.GetRandomNameHyphenated(t), DisplayName: "Test Template", Description: "This is a test template", @@ -338,7 +350,8 @@ func TestTools(t *testing.T) { }) t.Run("CreateWorkspace", func(t *testing.T) { - tb := toolsdk.Deps{CoderClient: memberClient} + tb, err := toolsdk.NewDeps(client) + require.NoError(t, err) // We need a template version ID to create a workspace res, err := testTool(t, toolsdk.CreateWorkspace, tb, toolsdk.CreateWorkspaceArgs{ User: "me", From 40b2fdc82f61b2f9b22304323016679f48601ee1 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 29 Apr 2025 15:19:29 +0100 Subject: [PATCH 11/11] unexport deps fields --- codersdk/toolsdk/toolsdk.go | 56 ++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index 1c66230d35711..024e3bad6efdc 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -17,12 +17,12 @@ import ( func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) { d := Deps{ - CoderClient: client, + coderClient: client, } for _, opt := range opts { opt(&d) } - if d.CoderClient == nil { + if d.coderClient == nil { return Deps{}, xerrors.New("developer error: coder client may not be nil") } return d, nil @@ -30,21 +30,21 @@ func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) { func WithAgentClient(client *agentsdk.Client) func(*Deps) { return func(d *Deps) { - d.AgentClient = client + d.agentClient = client } } func WithAppStatusSlug(slug string) func(*Deps) { return func(d *Deps) { - d.AppStatusSlug = slug + d.appStatusSlug = slug } } // Deps provides access to tool dependencies. type Deps struct { - CoderClient *codersdk.Client - AgentClient *agentsdk.Client - AppStatusSlug string + coderClient *codersdk.Client + agentClient *agentsdk.Client + appStatusSlug string } // HandlerFunc is a typed function that handles a tool call. @@ -196,17 +196,17 @@ var ReportTask = Tool[ReportTaskArgs, codersdk.Response]{ }, }, Handler: func(ctx context.Context, deps Deps, args ReportTaskArgs) (codersdk.Response, error) { - if deps.AgentClient == nil { + if deps.agentClient == nil { return codersdk.Response{}, xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set") } - if deps.AppStatusSlug == "" { + if deps.appStatusSlug == "" { return codersdk.Response{}, xerrors.New("tool unavailable as CODER_MCP_APP_STATUS_SLUG is not set") } if len(args.Summary) > 160 { return codersdk.Response{}, xerrors.New("summary must be less than 160 characters") } - if err := deps.AgentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ - AppSlug: deps.AppStatusSlug, + if err := deps.agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ + AppSlug: deps.appStatusSlug, Message: args.Summary, URI: args.Link, State: codersdk.WorkspaceAppStatusState(args.State), @@ -243,7 +243,7 @@ This returns more data than list_workspaces to reduce token usage.`, if err != nil { return codersdk.Workspace{}, xerrors.New("workspace_id must be a valid UUID") } - return deps.CoderClient.Workspace(ctx, wsID) + return deps.coderClient.Workspace(ctx, wsID) }, } @@ -300,7 +300,7 @@ is provisioned correctly and the agent can connect to the control plane. Value: v, }) } - workspace, err := deps.CoderClient.CreateUserWorkspace(ctx, args.User, codersdk.CreateWorkspaceRequest{ + workspace, err := deps.coderClient.CreateUserWorkspace(ctx, args.User, codersdk.CreateWorkspaceRequest{ TemplateVersionID: tvID, Name: args.Name, RichParameterValues: buildParams, @@ -334,7 +334,7 @@ var ListWorkspaces = Tool[ListWorkspacesArgs, []MinimalWorkspace]{ if owner == "" { owner = codersdk.Me } - workspaces, err := deps.CoderClient.Workspaces(ctx, codersdk.WorkspaceFilter{ + workspaces, err := deps.coderClient.Workspaces(ctx, codersdk.WorkspaceFilter{ Owner: owner, }) if err != nil { @@ -367,7 +367,7 @@ var ListTemplates = Tool[NoArgs, []MinimalTemplate]{ }, }, Handler: func(ctx context.Context, deps Deps, _ NoArgs) ([]MinimalTemplate, error) { - templates, err := deps.CoderClient.Templates(ctx, codersdk.TemplateFilter{}) + templates, err := deps.coderClient.Templates(ctx, codersdk.TemplateFilter{}) if err != nil { return nil, err } @@ -408,7 +408,7 @@ var ListTemplateVersionParameters = Tool[ListTemplateVersionParametersArgs, []co if err != nil { return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) } - parameters, err := deps.CoderClient.TemplateVersionRichParameters(ctx, templateVersionID) + parameters, err := deps.coderClient.TemplateVersionRichParameters(ctx, templateVersionID) if err != nil { return nil, err } @@ -426,7 +426,7 @@ var GetAuthenticatedUser = Tool[NoArgs, codersdk.User]{ }, }, Handler: func(ctx context.Context, deps Deps, _ NoArgs) (codersdk.User, error) { - return deps.CoderClient.User(ctx, "me") + return deps.coderClient.User(ctx, "me") }, } @@ -477,7 +477,7 @@ var CreateWorkspaceBuild = Tool[CreateWorkspaceBuildArgs, codersdk.WorkspaceBuil if templateVersionID != uuid.Nil { cbr.TemplateVersionID = templateVersionID } - return deps.CoderClient.CreateWorkspaceBuild(ctx, workspaceID, cbr) + return deps.coderClient.CreateWorkspaceBuild(ctx, workspaceID, cbr) }, } @@ -945,7 +945,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t }, }, Handler: func(ctx context.Context, deps Deps, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) { - me, err := deps.CoderClient.User(ctx, "me") + me, err := deps.coderClient.User(ctx, "me") if err != nil { return codersdk.TemplateVersion{}, err } @@ -961,7 +961,7 @@ The file_id provided is a reference to a tar file you have uploaded containing t } templateID = tid } - templateVersion, err := deps.CoderClient.CreateTemplateVersion(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{ + templateVersion, err := deps.coderClient.CreateTemplateVersion(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{ Message: "Created by AI", StorageMethod: codersdk.ProvisionerStorageMethodFile, FileID: fileID, @@ -999,7 +999,7 @@ var GetWorkspaceAgentLogs = Tool[GetWorkspaceAgentLogsArgs, []string]{ if err != nil { return nil, xerrors.Errorf("workspace_agent_id must be a valid UUID: %w", err) } - logs, closer, err := deps.CoderClient.WorkspaceAgentLogsAfter(ctx, workspaceAgentID, 0, false) + logs, closer, err := deps.coderClient.WorkspaceAgentLogsAfter(ctx, workspaceAgentID, 0, false) if err != nil { return nil, err } @@ -1038,7 +1038,7 @@ var GetWorkspaceBuildLogs = Tool[GetWorkspaceBuildLogsArgs, []string]{ if err != nil { return nil, xerrors.Errorf("workspace_build_id must be a valid UUID: %w", err) } - logs, closer, err := deps.CoderClient.WorkspaceBuildLogsAfter(ctx, workspaceBuildID, 0) + logs, closer, err := deps.coderClient.WorkspaceBuildLogsAfter(ctx, workspaceBuildID, 0) if err != nil { return nil, err } @@ -1074,7 +1074,7 @@ var GetTemplateVersionLogs = Tool[GetTemplateVersionLogsArgs, []string]{ return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) } - logs, closer, err := deps.CoderClient.TemplateVersionLogsAfter(ctx, templateVersionID, 0) + logs, closer, err := deps.coderClient.TemplateVersionLogsAfter(ctx, templateVersionID, 0) if err != nil { return nil, err } @@ -1117,7 +1117,7 @@ var UpdateTemplateActiveVersion = Tool[UpdateTemplateActiveVersionArgs, string]{ if err != nil { return "", xerrors.Errorf("template_version_id must be a valid UUID: %w", err) } - err = deps.CoderClient.UpdateActiveTemplateVersion(ctx, templateID, codersdk.UpdateActiveTemplateVersion{ + err = deps.coderClient.UpdateActiveTemplateVersion(ctx, templateID, codersdk.UpdateActiveTemplateVersion{ ID: templateVersionID, }) if err != nil { @@ -1174,7 +1174,7 @@ var UploadTarFile = Tool[UploadTarFileArgs, codersdk.UploadResponse]{ } }() - resp, err := deps.CoderClient.Upload(ctx, codersdk.ContentTypeTar, pipeReader) + resp, err := deps.coderClient.Upload(ctx, codersdk.ContentTypeTar, pipeReader) if err != nil { _ = pipeReader.CloseWithError(err) <-done @@ -1221,7 +1221,7 @@ var CreateTemplate = Tool[CreateTemplateArgs, codersdk.Template]{ }, }, Handler: func(ctx context.Context, deps Deps, args CreateTemplateArgs) (codersdk.Template, error) { - me, err := deps.CoderClient.User(ctx, "me") + me, err := deps.coderClient.User(ctx, "me") if err != nil { return codersdk.Template{}, err } @@ -1229,7 +1229,7 @@ var CreateTemplate = Tool[CreateTemplateArgs, codersdk.Template]{ if err != nil { return codersdk.Template{}, xerrors.Errorf("version_id must be a valid UUID: %w", err) } - template, err := deps.CoderClient.CreateTemplate(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateRequest{ + template, err := deps.coderClient.CreateTemplate(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateRequest{ Name: args.Name, DisplayName: args.DisplayName, Description: args.Description, @@ -1263,7 +1263,7 @@ var DeleteTemplate = Tool[DeleteTemplateArgs, codersdk.Response]{ if err != nil { return codersdk.Response{}, xerrors.Errorf("template_id must be a valid UUID: %w", err) } - err = deps.CoderClient.DeleteTemplate(ctx, templateID) + err = deps.coderClient.DeleteTemplate(ctx, templateID) if err != nil { return codersdk.Response{}, err }