diff --git a/cli/exp_mcp.go b/cli/exp_mcp.go index 2d38d0417194d..40192c0e72cec 100644 --- a/cli/exp_mcp.go +++ b/cli/exp_mcp.go @@ -1,6 +1,7 @@ package cli import ( + "bytes" "context" "encoding/json" "errors" @@ -427,22 +428,27 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct server.WithInstructions(instructions), ) - // Create a new context for the tools with all relevant information. - clientCtx := toolsdk.WithClient(ctx, client) // Get the workspace agent token from the environment. + toolOpts := make([]func(*toolsdk.Deps), 0) var hasAgentClient bool if agentToken, err := getAgentToken(fs); err == nil && agentToken != "" { hasAgentClient = true agentClient := agentsdk.New(client.URL) agentClient.SetSessionToken(agentToken) - clientCtx = toolsdk.WithAgentClient(clientCtx, agentClient) + toolOpts = append(toolOpts, toolsdk.WithAgentClient(agentClient)) } else { cliui.Warnf(inv.Stderr, "CODER_AGENT_TOKEN is not set, task reporting will not be available") } - if appStatusSlug == "" { - cliui.Warnf(inv.Stderr, "CODER_MCP_APP_STATUS_SLUG is not set, task reporting will not be available.") + + if appStatusSlug != "" { + toolOpts = append(toolOpts, toolsdk.WithAppStatusSlug(appStatusSlug)) } else { - clientCtx = toolsdk.WithWorkspaceAppStatusSlug(clientCtx, appStatusSlug) + cliui.Warnf(inv.Stderr, "CODER_MCP_APP_STATUS_SLUG is not set, task reporting will not be available.") + } + + toolDeps, err := toolsdk.NewDeps(client, toolOpts...) + if err != nil { + return xerrors.Errorf("failed to initialize tool dependencies: %w", err) } // Register tools based on the allowlist (if specified) @@ -455,7 +461,7 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct if len(allowedTools) == 0 || slices.ContainsFunc(allowedTools, func(t string) bool { return t == tool.Tool.Name }) { - mcpSrv.AddTools(mcpFromSDK(tool)) + mcpSrv.AddTools(mcpFromSDK(tool, toolDeps)) } } @@ -463,7 +469,7 @@ func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instruct done := make(chan error) go func() { defer close(done) - srvErr := srv.Listen(clientCtx, invStdin, invStdout) + srvErr := srv.Listen(ctx, invStdin, invStdout) done <- srvErr }() @@ -726,7 +732,7 @@ func getAgentToken(fs afero.Fs) (string, error) { // mcpFromSDK adapts a toolsdk.Tool to go-mcp's server.ServerTool. // It assumes that the tool responds with a valid JSON object. -func mcpFromSDK(sdkTool toolsdk.Tool[any]) server.ServerTool { +func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool { // NOTE: some clients will silently refuse to use tools if there is an issue // with the tool's schema or configuration. if sdkTool.Schema.Properties == nil { @@ -743,27 +749,17 @@ func mcpFromSDK(sdkTool toolsdk.Tool[any]) server.ServerTool { }, }, Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - result, err := sdkTool.Handler(ctx, request.Params.Arguments) + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(request.Params.Arguments); err != nil { + return nil, xerrors.Errorf("failed to encode request arguments: %w", err) + } + result, err := sdkTool.Handler(ctx, tb, buf.Bytes()) if err != nil { return nil, err } - var sb strings.Builder - if err := json.NewEncoder(&sb).Encode(result); err == nil { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - mcp.NewTextContent(sb.String()), - }, - }, nil - } - // If the result is not JSON, return it as a string. - // This is a fallback for tools that return non-JSON data. - resultStr, ok := result.(string) - if !ok { - return nil, xerrors.Errorf("tool call result is neither valid JSON or a string, got: %T", result) - } return &mcp.CallToolResult{ Content: []mcp.Content{ - mcp.NewTextContent(resultStr), + mcp.NewTextContent(string(result)), }, }, nil }, diff --git a/cli/exp_mcp_test.go b/cli/exp_mcp_test.go index 35676cd81de91..93c7acea74f22 100644 --- a/cli/exp_mcp_test.go +++ b/cli/exp_mcp_test.go @@ -31,12 +31,12 @@ func TestExpMcpServer(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) + cmdDone := make(chan struct{}) cancelCtx, cancel := context.WithCancel(ctx) - t.Cleanup(cancel) // Given: a running coder deployment client := coderdtest.New(t, nil) - _ = coderdtest.CreateFirstUser(t, client) + owner := coderdtest.CreateFirstUser(t, client) // Given: we run the exp mcp command with allowed tools set inv, root := clitest.New(t, "exp", "mcp", "server", "--allowed-tools=coder_get_authenticated_user") @@ -48,7 +48,6 @@ func TestExpMcpServer(t *testing.T) { // nolint: gocritic // not the focus of this test clitest.SetupConfig(t, client, root) - cmdDone := make(chan struct{}) go func() { defer close(cmdDone) err := inv.Run() @@ -61,9 +60,6 @@ func TestExpMcpServer(t *testing.T) { _ = pty.ReadLine(ctx) // ignore echoed output output := pty.ReadLine(ctx) - cancel() - <-cmdDone - // Then: we should only see the allowed tools in the response var toolsResponse struct { Result struct { @@ -81,6 +77,20 @@ func TestExpMcpServer(t *testing.T) { } slices.Sort(foundTools) require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools) + + // Call the tool and ensure it works. + toolPayload := `{"jsonrpc":"2.0","id":3,"method":"tools/call", "params": {"name": "coder_get_authenticated_user", "arguments": {}}}` + pty.WriteLine(toolPayload) + _ = pty.ReadLine(ctx) // ignore echoed output + output = pty.ReadLine(ctx) + require.NotEmpty(t, output, "should have received a response from the tool") + // Ensure it's valid JSON + _, err = json.Marshal(output) + require.NoError(t, err, "should have received a valid JSON response from the tool") + // Ensure the tool returns the expected user + require.Contains(t, output, owner.UserID.String(), "should have received the expected user ID") + cancel() + <-cmdDone }) t.Run("OK", func(t *testing.T) { diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 1388b61030d38..c8d75d0d4f313 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -338,9 +338,33 @@ func (api *API) patchWorkspaceAgentAppStatus(rw http.ResponseWriter, r *http.Req Slug: req.AppSlug, }) if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "Failed to get workspace app.", - Detail: err.Error(), + Detail: fmt.Sprintf("No app found with slug %q", req.AppSlug), + }) + return + } + + if len(req.Message) > 160 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Message is too long.", + Detail: "Message must be less than 160 characters.", + Validations: []codersdk.ValidationError{ + {Field: "message", Detail: "Message must be less than 160 characters."}, + }, + }) + return + } + + switch req.State { + case codersdk.WorkspaceAppStatusStateComplete, codersdk.WorkspaceAppStatusStateFailure, codersdk.WorkspaceAppStatusStateWorking: // valid states + default: + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid state provided.", + Detail: fmt.Sprintf("invalid state: %q", req.State), + Validations: []codersdk.ValidationError{ + {Field: "state", Detail: "State must be one of: complete, failure, working."}, + }, }) return } diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index a6e10ea5fdabf..da2619da0b29d 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -341,27 +341,27 @@ func TestWorkspaceAgentLogs(t *testing.T) { func TestWorkspaceAgentAppStatus(t *testing.T) { t.Parallel() - t.Run("Success", func(t *testing.T) { - t.Parallel() - ctx := testutil.Context(t, testutil.WaitMedium) - client, db := coderdtest.NewWithDatabase(t, nil) - user := coderdtest.CreateFirstUser(t, client) - client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID) + client, db := coderdtest.NewWithDatabase(t, nil) + user := coderdtest.CreateFirstUser(t, client) + client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID) - r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ - OrganizationID: user.OrganizationID, - OwnerID: user2.ID, - }).WithAgent(func(a []*proto.Agent) []*proto.Agent { - a[0].Apps = []*proto.App{ - { - Slug: "vscode", - }, - } - return a - }).Do() + r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: user.OrganizationID, + OwnerID: user2.ID, + }).WithAgent(func(a []*proto.Agent) []*proto.Agent { + a[0].Apps = []*proto.App{ + { + Slug: "vscode", + }, + } + return a + }).Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL) + agentClient.SetSessionToken(r.AgentToken) + t.Run("Success", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ AppSlug: "vscode", Message: "testing", @@ -382,6 +382,51 @@ func TestWorkspaceAgentAppStatus(t *testing.T) { require.Empty(t, agent.Apps[0].Statuses[0].Icon) require.False(t, agent.Apps[0].Statuses[0].NeedsUserAttention) }) + + t.Run("FailUnknownApp", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ + AppSlug: "unknown", + Message: "testing", + URI: "https://example.com", + State: codersdk.WorkspaceAppStatusStateComplete, + }) + require.ErrorContains(t, err, "No app found with slug") + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + }) + + t.Run("FailUnknownState", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ + AppSlug: "vscode", + Message: "testing", + URI: "https://example.com", + State: "unknown", + }) + require.ErrorContains(t, err, "Invalid state") + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + }) + + t.Run("FailTooLong", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ + AppSlug: "vscode", + Message: strings.Repeat("a", 161), + URI: "https://example.com", + State: codersdk.WorkspaceAppStatusStateComplete, + }) + require.ErrorContains(t, err, "Message is too long") + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + }) } func TestWorkspaceAgentConnectRPC(t *testing.T) { diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index 73dee8e748575..024e3bad6efdc 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -2,7 +2,9 @@ package toolsdk import ( "archive/tar" + "bytes" "context" + "encoding/json" "io" "github.com/google/uuid" @@ -13,372 +15,481 @@ import ( "github.com/coder/coder/v2/codersdk/agentsdk" ) -// HandlerFunc is a function that handles a tool call. -type HandlerFunc[T any] func(ctx context.Context, args map[string]any) (T, error) +func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) { + d := Deps{ + coderClient: client, + } + for _, opt := range opts { + opt(&d) + } + if d.coderClient == nil { + return Deps{}, xerrors.New("developer error: coder client may not be nil") + } + return d, nil +} + +func WithAgentClient(client *agentsdk.Client) func(*Deps) { + return func(d *Deps) { + d.agentClient = client + } +} + +func WithAppStatusSlug(slug string) func(*Deps) { + return func(d *Deps) { + d.appStatusSlug = slug + } +} -type Tool[T any] struct { +// Deps provides access to tool dependencies. +type Deps struct { + coderClient *codersdk.Client + agentClient *agentsdk.Client + appStatusSlug string +} + +// HandlerFunc is a typed function that handles a tool call. +type HandlerFunc[Arg, Ret any] func(context.Context, Deps, Arg) (Ret, error) + +// Tool consists of an aisdk.Tool and a corresponding typed handler function. +type Tool[Arg, Ret any] struct { aisdk.Tool - Handler HandlerFunc[T] + Handler HandlerFunc[Arg, Ret] } -// Generic returns a Tool[any] that can be used to call the tool. -func (t Tool[T]) Generic() Tool[any] { - return Tool[any]{ +// Generic returns a type-erased version of a TypedTool where the arguments and +// return values are converted to/from json.RawMessage. +// This allows the tool to be referenced without knowing the concrete arguments +// or return values. The original TypedHandlerFunc is wrapped to handle type +// conversion. +func (t Tool[Arg, Ret]) Generic() GenericTool { + return GenericTool{ Tool: t.Tool, - Handler: func(ctx context.Context, args map[string]any) (any, error) { - return t.Handler(ctx, args) - }, + Handler: wrap(func(ctx context.Context, deps Deps, args json.RawMessage) (json.RawMessage, error) { + var typedArgs Arg + if err := json.Unmarshal(args, &typedArgs); err != nil { + return nil, xerrors.Errorf("failed to unmarshal args: %w", err) + } + ret, err := t.Handler(ctx, deps, typedArgs) + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(ret); err != nil { + return json.RawMessage{}, err + } + return buf.Bytes(), err + }, WithCleanContext, WithRecover), } } -var ( - // All is a list of all tools that can be used in the Coder CLI. - // When you add a new tool, be sure to include it here! - All = []Tool[any]{ - CreateTemplateVersion.Generic(), - CreateTemplate.Generic(), - CreateWorkspace.Generic(), - CreateWorkspaceBuild.Generic(), - DeleteTemplate.Generic(), - GetAuthenticatedUser.Generic(), - GetTemplateVersionLogs.Generic(), - GetWorkspace.Generic(), - GetWorkspaceAgentLogs.Generic(), - GetWorkspaceBuildLogs.Generic(), - ListWorkspaces.Generic(), - ListTemplates.Generic(), - ListTemplateVersionParameters.Generic(), - ReportTask.Generic(), - UploadTarFile.Generic(), - UpdateTemplateActiveVersion.Generic(), +// GenericTool is a type-erased wrapper for GenericTool. +// This allows referencing the tool without knowing the concrete argument or +// return type. The Handler function allows calling the tool with known types. +type GenericTool struct { + aisdk.Tool + Handler GenericHandlerFunc +} + +// GenericHandlerFunc is a function that handles a tool call. +type GenericHandlerFunc func(context.Context, Deps, json.RawMessage) (json.RawMessage, error) + +// NoArgs just represents an empty argument struct. +type NoArgs struct{} + +// WithRecover wraps a HandlerFunc to recover from panics and return an error. +func WithRecover(h GenericHandlerFunc) GenericHandlerFunc { + return func(ctx context.Context, deps Deps, args json.RawMessage) (ret json.RawMessage, err error) { + defer func() { + if r := recover(); r != nil { + err = xerrors.Errorf("tool handler panic: %v", r) + } + }() + return h(ctx, deps, args) } +} - ReportTask = Tool[string]{ - Tool: aisdk.Tool{ - Name: "coder_report_task", - Description: "Report progress on a user task in Coder.", - Schema: aisdk.Schema{ - Properties: map[string]any{ - "summary": map[string]any{ - "type": "string", - "description": "A concise summary of your current progress on the task. This must be less than 160 characters in length.", - }, - "link": map[string]any{ - "type": "string", - "description": "A link to a relevant resource, such as a PR or issue.", - }, - "state": map[string]any{ - "type": "string", - "description": "The state of your task. This can be one of the following: working, complete, or failure. Select the state that best represents your current progress.", - "enum": []string{ - string(codersdk.WorkspaceAppStatusStateWorking), - string(codersdk.WorkspaceAppStatusStateComplete), - string(codersdk.WorkspaceAppStatusStateFailure), - }, +// WithCleanContext wraps a HandlerFunc to provide it with a new context. +// This ensures that no data is passed using context.Value. +// If a deadline is set on the parent context, it will be passed to the child +// context. +func WithCleanContext(h GenericHandlerFunc) GenericHandlerFunc { + return func(parent context.Context, deps Deps, args json.RawMessage) (ret json.RawMessage, err error) { + child, childCancel := context.WithCancel(context.Background()) + defer childCancel() + // Ensure that the child context has the same deadline as the parent + // context. + if deadline, ok := parent.Deadline(); ok { + deadlineCtx, deadlineCancel := context.WithDeadline(child, deadline) + defer deadlineCancel() + child = deadlineCtx + } + // Ensure that cancellation propagates from the parent context to the child context. + go func() { + select { + case <-child.Done(): + return + case <-parent.Done(): + childCancel() + } + }() + return h(child, deps, args) + } +} + +// wrap wraps the provided GenericHandlerFunc with the provided middleware functions. +func wrap(hf GenericHandlerFunc, mw ...func(GenericHandlerFunc) GenericHandlerFunc) GenericHandlerFunc { + for _, m := range mw { + hf = m(hf) + } + return hf +} + +// All is a list of all tools that can be used in the Coder CLI. +// When you add a new tool, be sure to include it here! +var All = []GenericTool{ + CreateTemplate.Generic(), + CreateTemplateVersion.Generic(), + CreateWorkspace.Generic(), + CreateWorkspaceBuild.Generic(), + DeleteTemplate.Generic(), + ListTemplates.Generic(), + ListTemplateVersionParameters.Generic(), + ListWorkspaces.Generic(), + GetAuthenticatedUser.Generic(), + GetTemplateVersionLogs.Generic(), + GetWorkspace.Generic(), + GetWorkspaceAgentLogs.Generic(), + GetWorkspaceBuildLogs.Generic(), + ReportTask.Generic(), + UploadTarFile.Generic(), + UpdateTemplateActiveVersion.Generic(), +} + +type ReportTaskArgs struct { + Link string `json:"link"` + State string `json:"state"` + Summary string `json:"summary"` +} + +var ReportTask = Tool[ReportTaskArgs, codersdk.Response]{ + Tool: aisdk.Tool{ + Name: "coder_report_task", + Description: "Report progress on a user task in Coder.", + Schema: aisdk.Schema{ + Properties: map[string]any{ + "summary": map[string]any{ + "type": "string", + "description": "A concise summary of your current progress on the task. This must be less than 160 characters in length.", + }, + "link": map[string]any{ + "type": "string", + "description": "A link to a relevant resource, such as a PR or issue.", + }, + "state": map[string]any{ + "type": "string", + "description": "The state of your task. This can be one of the following: working, complete, or failure. Select the state that best represents your current progress.", + "enum": []string{ + string(codersdk.WorkspaceAppStatusStateWorking), + string(codersdk.WorkspaceAppStatusStateComplete), + string(codersdk.WorkspaceAppStatusStateFailure), }, }, - Required: []string{"summary", "link", "state"}, }, + Required: []string{"summary", "link", "state"}, }, - Handler: func(ctx context.Context, args map[string]any) (string, error) { - agentClient, err := agentClientFromContext(ctx) - if err != nil { - return "", xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set") - } - appSlug, ok := workspaceAppStatusSlugFromContext(ctx) - if !ok { - return "", xerrors.New("workspace app status slug not found in context") - } - summary, ok := args["summary"].(string) - if !ok { - return "", xerrors.New("summary must be a string") - } - if len(summary) > 160 { - return "", xerrors.New("summary must be less than 160 characters") - } - link, ok := args["link"].(string) - if !ok { - return "", xerrors.New("link must be a string") - } - state, ok := args["state"].(string) - if !ok { - return "", xerrors.New("state must be a string") - } + }, + Handler: func(ctx context.Context, deps Deps, args ReportTaskArgs) (codersdk.Response, error) { + if deps.agentClient == nil { + return codersdk.Response{}, xerrors.New("tool unavailable as CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE not set") + } + if deps.appStatusSlug == "" { + return codersdk.Response{}, xerrors.New("tool unavailable as CODER_MCP_APP_STATUS_SLUG is not set") + } + if len(args.Summary) > 160 { + return codersdk.Response{}, xerrors.New("summary must be less than 160 characters") + } + if err := deps.agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ + AppSlug: deps.appStatusSlug, + Message: args.Summary, + URI: args.Link, + State: codersdk.WorkspaceAppStatusState(args.State), + }); err != nil { + return codersdk.Response{}, err + } + return codersdk.Response{ + Message: "Thanks for reporting!", + }, nil + }, +} - if err := agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{ - AppSlug: appSlug, - Message: summary, - URI: link, - State: codersdk.WorkspaceAppStatusState(state), - }); err != nil { - return "", err - } - return "Thanks for reporting!", nil - }, - } +type GetWorkspaceArgs struct { + WorkspaceID string `json:"workspace_id"` +} - GetWorkspace = Tool[codersdk.Workspace]{ - Tool: aisdk.Tool{ - Name: "coder_get_workspace", - Description: `Get a workspace by ID. +var GetWorkspace = Tool[GetWorkspaceArgs, codersdk.Workspace]{ + Tool: aisdk.Tool{ + Name: "coder_get_workspace", + Description: `Get a workspace by ID. This returns more data than list_workspaces to reduce token usage.`, - Schema: aisdk.Schema{ - Properties: map[string]any{ - "workspace_id": map[string]any{ - "type": "string", - }, + Schema: aisdk.Schema{ + Properties: map[string]any{ + "workspace_id": map[string]any{ + "type": "string", }, - Required: []string{"workspace_id"}, }, + Required: []string{"workspace_id"}, }, - Handler: func(ctx context.Context, args map[string]any) (codersdk.Workspace, error) { - client, err := clientFromContext(ctx) - if err != nil { - return codersdk.Workspace{}, err - } - workspaceID, err := uuidFromArgs(args, "workspace_id") - if err != nil { - return codersdk.Workspace{}, err - } - return client.Workspace(ctx, workspaceID) - }, - } + }, + Handler: func(ctx context.Context, deps Deps, args GetWorkspaceArgs) (codersdk.Workspace, error) { + wsID, err := uuid.Parse(args.WorkspaceID) + if err != nil { + return codersdk.Workspace{}, xerrors.New("workspace_id must be a valid UUID") + } + return deps.coderClient.Workspace(ctx, wsID) + }, +} - CreateWorkspace = Tool[codersdk.Workspace]{ - Tool: aisdk.Tool{ - Name: "coder_create_workspace", - Description: `Create a new workspace in Coder. +type CreateWorkspaceArgs struct { + Name string `json:"name"` + RichParameters map[string]string `json:"rich_parameters"` + TemplateVersionID string `json:"template_version_id"` + User string `json:"user"` +} + +var CreateWorkspace = Tool[CreateWorkspaceArgs, codersdk.Workspace]{ + Tool: aisdk.Tool{ + Name: "coder_create_workspace", + Description: `Create a new workspace in Coder. If a user is asking to "test a template", they are typically referring to creating a workspace from a template to ensure the infrastructure is provisioned correctly and the agent can connect to the control plane. `, - Schema: aisdk.Schema{ - Properties: map[string]any{ - "user": map[string]any{ - "type": "string", - "description": "Username or ID of the user to create the workspace for. Use the `me` keyword to create a workspace for the authenticated user.", - }, - "template_version_id": map[string]any{ - "type": "string", - "description": "ID of the template version to create the workspace from.", - }, - "name": map[string]any{ - "type": "string", - "description": "Name of the workspace to create.", - }, - "rich_parameters": map[string]any{ - "type": "object", - "description": "Key/value pairs of rich parameters to pass to the template version to create the workspace.", - }, + Schema: aisdk.Schema{ + Properties: map[string]any{ + "user": map[string]any{ + "type": "string", + "description": "Username or ID of the user to create the workspace for. Use the `me` keyword to create a workspace for the authenticated user.", + }, + "template_version_id": map[string]any{ + "type": "string", + "description": "ID of the template version to create the workspace from.", + }, + "name": map[string]any{ + "type": "string", + "description": "Name of the workspace to create.", + }, + "rich_parameters": map[string]any{ + "type": "object", + "description": "Key/value pairs of rich parameters to pass to the template version to create the workspace.", }, - Required: []string{"user", "template_version_id", "name", "rich_parameters"}, }, + Required: []string{"user", "template_version_id", "name", "rich_parameters"}, }, - Handler: func(ctx context.Context, args map[string]any) (codersdk.Workspace, error) { - client, err := clientFromContext(ctx) - if err != nil { - return codersdk.Workspace{}, err - } - templateVersionID, err := uuidFromArgs(args, "template_version_id") - if err != nil { - return codersdk.Workspace{}, err - } - name, ok := args["name"].(string) - if !ok { - return codersdk.Workspace{}, xerrors.New("workspace name must be a string") - } - workspace, err := client.CreateUserWorkspace(ctx, "me", codersdk.CreateWorkspaceRequest{ - TemplateVersionID: templateVersionID, - Name: name, + }, + Handler: func(ctx context.Context, deps Deps, args CreateWorkspaceArgs) (codersdk.Workspace, error) { + tvID, err := uuid.Parse(args.TemplateVersionID) + if err != nil { + return codersdk.Workspace{}, xerrors.New("template_version_id must be a valid UUID") + } + if args.User == "" { + args.User = codersdk.Me + } + var buildParams []codersdk.WorkspaceBuildParameter + for k, v := range args.RichParameters { + buildParams = append(buildParams, codersdk.WorkspaceBuildParameter{ + Name: k, + Value: v, }) - if err != nil { - return codersdk.Workspace{}, err - } - return workspace, nil - }, - } + } + workspace, err := deps.coderClient.CreateUserWorkspace(ctx, args.User, codersdk.CreateWorkspaceRequest{ + TemplateVersionID: tvID, + Name: args.Name, + RichParameterValues: buildParams, + }) + if err != nil { + return codersdk.Workspace{}, err + } + return workspace, nil + }, +} - ListWorkspaces = Tool[[]MinimalWorkspace]{ - Tool: aisdk.Tool{ - Name: "coder_list_workspaces", - Description: "Lists workspaces for the authenticated user.", - Schema: aisdk.Schema{ - Properties: map[string]any{ - "owner": map[string]any{ - "type": "string", - "description": "The owner of the workspaces to list. Use \"me\" to list workspaces for the authenticated user. If you do not specify an owner, \"me\" will be assumed by default.", - }, +type ListWorkspacesArgs struct { + Owner string `json:"owner"` +} + +var ListWorkspaces = Tool[ListWorkspacesArgs, []MinimalWorkspace]{ + Tool: aisdk.Tool{ + Name: "coder_list_workspaces", + Description: "Lists workspaces for the authenticated user.", + Schema: aisdk.Schema{ + Properties: map[string]any{ + "owner": map[string]any{ + "type": "string", + "description": "The owner of the workspaces to list. Use \"me\" to list workspaces for the authenticated user. If you do not specify an owner, \"me\" will be assumed by default.", }, }, }, - Handler: func(ctx context.Context, args map[string]any) ([]MinimalWorkspace, error) { - client, err := clientFromContext(ctx) - if err != nil { - return nil, err - } - owner, ok := args["owner"].(string) - if !ok { - owner = codersdk.Me - } - workspaces, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{ - Owner: owner, - }) - if err != nil { - return nil, err - } - minimalWorkspaces := make([]MinimalWorkspace, len(workspaces.Workspaces)) - for i, workspace := range workspaces.Workspaces { - minimalWorkspaces[i] = MinimalWorkspace{ - ID: workspace.ID.String(), - Name: workspace.Name, - TemplateID: workspace.TemplateID.String(), - TemplateName: workspace.TemplateName, - TemplateDisplayName: workspace.TemplateDisplayName, - TemplateIcon: workspace.TemplateIcon, - TemplateActiveVersionID: workspace.TemplateActiveVersionID, - Outdated: workspace.Outdated, - } - } - return minimalWorkspaces, nil - }, - } + }, + Handler: func(ctx context.Context, deps Deps, args ListWorkspacesArgs) ([]MinimalWorkspace, error) { + owner := args.Owner + if owner == "" { + owner = codersdk.Me + } + workspaces, err := deps.coderClient.Workspaces(ctx, codersdk.WorkspaceFilter{ + Owner: owner, + }) + if err != nil { + return nil, err + } + minimalWorkspaces := make([]MinimalWorkspace, len(workspaces.Workspaces)) + for i, workspace := range workspaces.Workspaces { + minimalWorkspaces[i] = MinimalWorkspace{ + ID: workspace.ID.String(), + Name: workspace.Name, + TemplateID: workspace.TemplateID.String(), + TemplateName: workspace.TemplateName, + TemplateDisplayName: workspace.TemplateDisplayName, + TemplateIcon: workspace.TemplateIcon, + TemplateActiveVersionID: workspace.TemplateActiveVersionID, + Outdated: workspace.Outdated, + } + } + return minimalWorkspaces, nil + }, +} - ListTemplates = Tool[[]MinimalTemplate]{ - Tool: aisdk.Tool{ - Name: "coder_list_templates", - Description: "Lists templates for the authenticated user.", - Schema: aisdk.Schema{ - Properties: map[string]any{}, - Required: []string{}, - }, +var ListTemplates = Tool[NoArgs, []MinimalTemplate]{ + Tool: aisdk.Tool{ + Name: "coder_list_templates", + Description: "Lists templates for the authenticated user.", + Schema: aisdk.Schema{ + Properties: map[string]any{}, + Required: []string{}, }, - Handler: func(ctx context.Context, _ map[string]any) ([]MinimalTemplate, error) { - client, err := clientFromContext(ctx) - if err != nil { - return nil, err - } - templates, err := client.Templates(ctx, codersdk.TemplateFilter{}) - if err != nil { - return nil, err - } - minimalTemplates := make([]MinimalTemplate, len(templates)) - for i, template := range templates { - minimalTemplates[i] = MinimalTemplate{ - DisplayName: template.DisplayName, - ID: template.ID.String(), - Name: template.Name, - Description: template.Description, - ActiveVersionID: template.ActiveVersionID, - ActiveUserCount: template.ActiveUserCount, - } - } - return minimalTemplates, nil - }, - } + }, + Handler: func(ctx context.Context, deps Deps, _ NoArgs) ([]MinimalTemplate, error) { + templates, err := deps.coderClient.Templates(ctx, codersdk.TemplateFilter{}) + if err != nil { + return nil, err + } + minimalTemplates := make([]MinimalTemplate, len(templates)) + for i, template := range templates { + minimalTemplates[i] = MinimalTemplate{ + DisplayName: template.DisplayName, + ID: template.ID.String(), + Name: template.Name, + Description: template.Description, + ActiveVersionID: template.ActiveVersionID, + ActiveUserCount: template.ActiveUserCount, + } + } + return minimalTemplates, nil + }, +} - ListTemplateVersionParameters = Tool[[]codersdk.TemplateVersionParameter]{ - Tool: aisdk.Tool{ - Name: "coder_template_version_parameters", - Description: "Get the parameters for a template version. You can refer to these as workspace parameters to the user, as they are typically important for creating a workspace.", - Schema: aisdk.Schema{ - Properties: map[string]any{ - "template_version_id": map[string]any{ - "type": "string", - }, +type ListTemplateVersionParametersArgs struct { + TemplateVersionID string `json:"template_version_id"` +} + +var ListTemplateVersionParameters = Tool[ListTemplateVersionParametersArgs, []codersdk.TemplateVersionParameter]{ + Tool: aisdk.Tool{ + Name: "coder_template_version_parameters", + Description: "Get the parameters for a template version. You can refer to these as workspace parameters to the user, as they are typically important for creating a workspace.", + Schema: aisdk.Schema{ + Properties: map[string]any{ + "template_version_id": map[string]any{ + "type": "string", }, - Required: []string{"template_version_id"}, }, + Required: []string{"template_version_id"}, }, - Handler: func(ctx context.Context, args map[string]any) ([]codersdk.TemplateVersionParameter, error) { - client, err := clientFromContext(ctx) - if err != nil { - return nil, err - } - templateVersionID, err := uuidFromArgs(args, "template_version_id") - if err != nil { - return nil, err - } - parameters, err := client.TemplateVersionRichParameters(ctx, templateVersionID) - if err != nil { - return nil, err - } - return parameters, nil - }, - } + }, + Handler: func(ctx context.Context, deps Deps, args ListTemplateVersionParametersArgs) ([]codersdk.TemplateVersionParameter, error) { + templateVersionID, err := uuid.Parse(args.TemplateVersionID) + if err != nil { + return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) + } + parameters, err := deps.coderClient.TemplateVersionRichParameters(ctx, templateVersionID) + if err != nil { + return nil, err + } + return parameters, nil + }, +} - GetAuthenticatedUser = Tool[codersdk.User]{ - Tool: aisdk.Tool{ - Name: "coder_get_authenticated_user", - Description: "Get the currently authenticated user, similar to the `whoami` command.", - Schema: aisdk.Schema{ - Properties: map[string]any{}, - Required: []string{}, - }, +var GetAuthenticatedUser = Tool[NoArgs, codersdk.User]{ + Tool: aisdk.Tool{ + Name: "coder_get_authenticated_user", + Description: "Get the currently authenticated user, similar to the `whoami` command.", + Schema: aisdk.Schema{ + Properties: map[string]any{}, + Required: []string{}, }, - Handler: func(ctx context.Context, _ map[string]any) (codersdk.User, error) { - client, err := clientFromContext(ctx) - if err != nil { - return codersdk.User{}, err - } - return client.User(ctx, "me") - }, - } + }, + Handler: func(ctx context.Context, deps Deps, _ NoArgs) (codersdk.User, error) { + return deps.coderClient.User(ctx, "me") + }, +} - CreateWorkspaceBuild = Tool[codersdk.WorkspaceBuild]{ - Tool: aisdk.Tool{ - Name: "coder_create_workspace_build", - Description: "Create a new workspace build for an existing workspace. Use this to start, stop, or delete.", - Schema: aisdk.Schema{ - Properties: map[string]any{ - "workspace_id": map[string]any{ - "type": "string", - }, - "transition": map[string]any{ - "type": "string", - "description": "The transition to perform. Must be one of: start, stop, delete", - "enum": []string{"start", "stop", "delete"}, - }, - "template_version_id": map[string]any{ - "type": "string", - "description": "(Optional) The template version ID to use for the workspace build. If not provided, the previously built version will be used.", - }, +type CreateWorkspaceBuildArgs struct { + TemplateVersionID string `json:"template_version_id"` + Transition string `json:"transition"` + WorkspaceID string `json:"workspace_id"` +} + +var CreateWorkspaceBuild = Tool[CreateWorkspaceBuildArgs, codersdk.WorkspaceBuild]{ + Tool: aisdk.Tool{ + Name: "coder_create_workspace_build", + Description: "Create a new workspace build for an existing workspace. Use this to start, stop, or delete.", + Schema: aisdk.Schema{ + Properties: map[string]any{ + "workspace_id": map[string]any{ + "type": "string", + }, + "transition": map[string]any{ + "type": "string", + "description": "The transition to perform. Must be one of: start, stop, delete", + "enum": []string{"start", "stop", "delete"}, + }, + "template_version_id": map[string]any{ + "type": "string", + "description": "(Optional) The template version ID to use for the workspace build. If not provided, the previously built version will be used.", }, - Required: []string{"workspace_id", "transition"}, }, + Required: []string{"workspace_id", "transition"}, }, - Handler: func(ctx context.Context, args map[string]any) (codersdk.WorkspaceBuild, error) { - client, err := clientFromContext(ctx) - if err != nil { - return codersdk.WorkspaceBuild{}, err - } - workspaceID, err := uuidFromArgs(args, "workspace_id") - if err != nil { - return codersdk.WorkspaceBuild{}, err - } - rawTransition, ok := args["transition"].(string) - if !ok { - return codersdk.WorkspaceBuild{}, xerrors.New("transition must be a string") - } - templateVersionID, err := uuidFromArgs(args, "template_version_id") + }, + Handler: func(ctx context.Context, deps Deps, args CreateWorkspaceBuildArgs) (codersdk.WorkspaceBuild, error) { + workspaceID, err := uuid.Parse(args.WorkspaceID) + if err != nil { + return codersdk.WorkspaceBuild{}, xerrors.Errorf("workspace_id must be a valid UUID: %w", err) + } + var templateVersionID uuid.UUID + if args.TemplateVersionID != "" { + tvID, err := uuid.Parse(args.TemplateVersionID) if err != nil { - return codersdk.WorkspaceBuild{}, err - } - cbr := codersdk.CreateWorkspaceBuildRequest{ - Transition: codersdk.WorkspaceTransition(rawTransition), - } - if templateVersionID != uuid.Nil { - cbr.TemplateVersionID = templateVersionID - } - return client.CreateWorkspaceBuild(ctx, workspaceID, cbr) - }, - } + return codersdk.WorkspaceBuild{}, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) + } + templateVersionID = tvID + } + cbr := codersdk.CreateWorkspaceBuildRequest{ + Transition: codersdk.WorkspaceTransition(args.Transition), + } + if templateVersionID != uuid.Nil { + cbr.TemplateVersionID = templateVersionID + } + return deps.coderClient.CreateWorkspaceBuild(ctx, workspaceID, cbr) + }, +} + +type CreateTemplateVersionArgs struct { + FileID string `json:"file_id"` + TemplateID string `json:"template_id"` +} - CreateTemplateVersion = Tool[codersdk.TemplateVersion]{ - Tool: aisdk.Tool{ - Name: "coder_create_template_version", - Description: `Create a new template version. This is a precursor to creating a template, or you can update an existing template. +var CreateTemplateVersion = Tool[CreateTemplateVersionArgs, codersdk.TemplateVersion]{ + Tool: aisdk.Tool{ + Name: "coder_create_template_version", + Description: `Create a new template version. This is a precursor to creating a template, or you can update an existing template. Templates are Terraform defining a development environment. The provisioned infrastructure must run an Agent that connects to the Coder Control Plane to provide a rich experience. @@ -821,364 +932,346 @@ resource "kubernetes_deployment" "main" { The file_id provided is a reference to a tar file you have uploaded containing the Terraform. `, - Schema: aisdk.Schema{ - Properties: map[string]any{ - "template_id": map[string]any{ - "type": "string", - }, - "file_id": map[string]any{ - "type": "string", - }, + Schema: aisdk.Schema{ + Properties: map[string]any{ + "template_id": map[string]any{ + "type": "string", + }, + "file_id": map[string]any{ + "type": "string", }, - Required: []string{"file_id"}, }, + Required: []string{"file_id"}, }, - Handler: func(ctx context.Context, args map[string]any) (codersdk.TemplateVersion, error) { - client, err := clientFromContext(ctx) - if err != nil { - return codersdk.TemplateVersion{}, err - } - me, err := client.User(ctx, "me") - if err != nil { - return codersdk.TemplateVersion{}, err - } - fileID, err := uuidFromArgs(args, "file_id") - if err != nil { - return codersdk.TemplateVersion{}, err - } - var templateID uuid.UUID - if args["template_id"] != nil { - templateID, err = uuidFromArgs(args, "template_id") - if err != nil { - return codersdk.TemplateVersion{}, err - } - } - templateVersion, err := client.CreateTemplateVersion(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{ - Message: "Created by AI", - StorageMethod: codersdk.ProvisionerStorageMethodFile, - FileID: fileID, - Provisioner: codersdk.ProvisionerTypeTerraform, - TemplateID: templateID, - }) + }, + Handler: func(ctx context.Context, deps Deps, args CreateTemplateVersionArgs) (codersdk.TemplateVersion, error) { + me, err := deps.coderClient.User(ctx, "me") + if err != nil { + return codersdk.TemplateVersion{}, err + } + fileID, err := uuid.Parse(args.FileID) + if err != nil { + return codersdk.TemplateVersion{}, xerrors.Errorf("file_id must be a valid UUID: %w", err) + } + var templateID uuid.UUID + if args.TemplateID != "" { + tid, err := uuid.Parse(args.TemplateID) if err != nil { - return codersdk.TemplateVersion{}, err - } - return templateVersion, nil - }, - } + return codersdk.TemplateVersion{}, xerrors.Errorf("template_id must be a valid UUID: %w", err) + } + templateID = tid + } + templateVersion, err := deps.coderClient.CreateTemplateVersion(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateVersionRequest{ + Message: "Created by AI", + StorageMethod: codersdk.ProvisionerStorageMethodFile, + FileID: fileID, + Provisioner: codersdk.ProvisionerTypeTerraform, + TemplateID: templateID, + }) + if err != nil { + return codersdk.TemplateVersion{}, err + } + return templateVersion, nil + }, +} - GetWorkspaceAgentLogs = Tool[[]string]{ - Tool: aisdk.Tool{ - Name: "coder_get_workspace_agent_logs", - Description: `Get the logs of a workspace agent. +type GetWorkspaceAgentLogsArgs struct { + WorkspaceAgentID string `json:"workspace_agent_id"` +} -More logs may appear after this call. It does not wait for the agent to finish.`, - Schema: aisdk.Schema{ - Properties: map[string]any{ - "workspace_agent_id": map[string]any{ - "type": "string", - }, +var GetWorkspaceAgentLogs = Tool[GetWorkspaceAgentLogsArgs, []string]{ + Tool: aisdk.Tool{ + Name: "coder_get_workspace_agent_logs", + Description: `Get the logs of a workspace agent. + + More logs may appear after this call. It does not wait for the agent to finish.`, + Schema: aisdk.Schema{ + Properties: map[string]any{ + "workspace_agent_id": map[string]any{ + "type": "string", }, - Required: []string{"workspace_agent_id"}, }, + Required: []string{"workspace_agent_id"}, }, - Handler: func(ctx context.Context, args map[string]any) ([]string, error) { - client, err := clientFromContext(ctx) - if err != nil { - return nil, err - } - workspaceAgentID, err := uuidFromArgs(args, "workspace_agent_id") - if err != nil { - return nil, err - } - logs, closer, err := client.WorkspaceAgentLogsAfter(ctx, workspaceAgentID, 0, false) - if err != nil { - return nil, err - } - defer closer.Close() - var acc []string - for logChunk := range logs { - for _, log := range logChunk { - acc = append(acc, log.Output) - } + }, + Handler: func(ctx context.Context, deps Deps, args GetWorkspaceAgentLogsArgs) ([]string, error) { + workspaceAgentID, err := uuid.Parse(args.WorkspaceAgentID) + if err != nil { + return nil, xerrors.Errorf("workspace_agent_id must be a valid UUID: %w", err) + } + logs, closer, err := deps.coderClient.WorkspaceAgentLogsAfter(ctx, workspaceAgentID, 0, false) + if err != nil { + return nil, err + } + defer closer.Close() + var acc []string + for logChunk := range logs { + for _, log := range logChunk { + acc = append(acc, log.Output) } - return acc, nil - }, - } + } + return acc, nil + }, +} - GetWorkspaceBuildLogs = Tool[[]string]{ - Tool: aisdk.Tool{ - Name: "coder_get_workspace_build_logs", - Description: `Get the logs of a workspace build. +type GetWorkspaceBuildLogsArgs struct { + WorkspaceBuildID string `json:"workspace_build_id"` +} -Useful for checking whether a workspace builds successfully or not.`, - Schema: aisdk.Schema{ - Properties: map[string]any{ - "workspace_build_id": map[string]any{ - "type": "string", - }, +var GetWorkspaceBuildLogs = Tool[GetWorkspaceBuildLogsArgs, []string]{ + Tool: aisdk.Tool{ + Name: "coder_get_workspace_build_logs", + Description: `Get the logs of a workspace build. + + Useful for checking whether a workspace builds successfully or not.`, + Schema: aisdk.Schema{ + Properties: map[string]any{ + "workspace_build_id": map[string]any{ + "type": "string", }, - Required: []string{"workspace_build_id"}, }, + Required: []string{"workspace_build_id"}, }, - Handler: func(ctx context.Context, args map[string]any) ([]string, error) { - client, err := clientFromContext(ctx) - if err != nil { - return nil, err - } - workspaceBuildID, err := uuidFromArgs(args, "workspace_build_id") - if err != nil { - return nil, err - } - logs, closer, err := client.WorkspaceBuildLogsAfter(ctx, workspaceBuildID, 0) - if err != nil { - return nil, err - } - defer closer.Close() - var acc []string - for log := range logs { - acc = append(acc, log.Output) - } - return acc, nil - }, - } + }, + Handler: func(ctx context.Context, deps Deps, args GetWorkspaceBuildLogsArgs) ([]string, error) { + workspaceBuildID, err := uuid.Parse(args.WorkspaceBuildID) + if err != nil { + return nil, xerrors.Errorf("workspace_build_id must be a valid UUID: %w", err) + } + logs, closer, err := deps.coderClient.WorkspaceBuildLogsAfter(ctx, workspaceBuildID, 0) + if err != nil { + return nil, err + } + defer closer.Close() + var acc []string + for log := range logs { + acc = append(acc, log.Output) + } + return acc, nil + }, +} - GetTemplateVersionLogs = Tool[[]string]{ - Tool: aisdk.Tool{ - Name: "coder_get_template_version_logs", - Description: "Get the logs of a template version. This is useful to check whether a template version successfully imports or not.", - Schema: aisdk.Schema{ - Properties: map[string]any{ - "template_version_id": map[string]any{ - "type": "string", - }, +type GetTemplateVersionLogsArgs struct { + TemplateVersionID string `json:"template_version_id"` +} + +var GetTemplateVersionLogs = Tool[GetTemplateVersionLogsArgs, []string]{ + Tool: aisdk.Tool{ + Name: "coder_get_template_version_logs", + Description: "Get the logs of a template version. This is useful to check whether a template version successfully imports or not.", + Schema: aisdk.Schema{ + Properties: map[string]any{ + "template_version_id": map[string]any{ + "type": "string", }, - Required: []string{"template_version_id"}, }, + Required: []string{"template_version_id"}, }, - Handler: func(ctx context.Context, args map[string]any) ([]string, error) { - client, err := clientFromContext(ctx) - if err != nil { - return nil, err - } - templateVersionID, err := uuidFromArgs(args, "template_version_id") - if err != nil { - return nil, err - } + }, + Handler: func(ctx context.Context, deps Deps, args GetTemplateVersionLogsArgs) ([]string, error) { + templateVersionID, err := uuid.Parse(args.TemplateVersionID) + if err != nil { + return nil, xerrors.Errorf("template_version_id must be a valid UUID: %w", err) + } + + logs, closer, err := deps.coderClient.TemplateVersionLogsAfter(ctx, templateVersionID, 0) + if err != nil { + return nil, err + } + defer closer.Close() + var acc []string + for log := range logs { + acc = append(acc, log.Output) + } + return acc, nil + }, +} - logs, closer, err := client.TemplateVersionLogsAfter(ctx, templateVersionID, 0) - if err != nil { - return nil, err - } - defer closer.Close() - var acc []string - for log := range logs { - acc = append(acc, log.Output) - } - return acc, nil - }, - } +type UpdateTemplateActiveVersionArgs struct { + TemplateID string `json:"template_id"` + TemplateVersionID string `json:"template_version_id"` +} - UpdateTemplateActiveVersion = Tool[string]{ - Tool: aisdk.Tool{ - Name: "coder_update_template_active_version", - Description: "Update the active version of a template. This is helpful when iterating on templates.", - Schema: aisdk.Schema{ - Properties: map[string]any{ - "template_id": map[string]any{ - "type": "string", - }, - "template_version_id": map[string]any{ - "type": "string", - }, +var UpdateTemplateActiveVersion = Tool[UpdateTemplateActiveVersionArgs, string]{ + Tool: aisdk.Tool{ + Name: "coder_update_template_active_version", + Description: "Update the active version of a template. This is helpful when iterating on templates.", + Schema: aisdk.Schema{ + Properties: map[string]any{ + "template_id": map[string]any{ + "type": "string", + }, + "template_version_id": map[string]any{ + "type": "string", }, - Required: []string{"template_id", "template_version_id"}, }, + Required: []string{"template_id", "template_version_id"}, }, - Handler: func(ctx context.Context, args map[string]any) (string, error) { - client, err := clientFromContext(ctx) - if err != nil { - return "", err - } - templateID, err := uuidFromArgs(args, "template_id") - if err != nil { - return "", err - } - templateVersionID, err := uuidFromArgs(args, "template_version_id") - if err != nil { - return "", err - } - err = client.UpdateActiveTemplateVersion(ctx, templateID, codersdk.UpdateActiveTemplateVersion{ - ID: templateVersionID, - }) - if err != nil { - return "", err - } - return "Successfully updated active version!", nil - }, - } + }, + Handler: func(ctx context.Context, deps Deps, args UpdateTemplateActiveVersionArgs) (string, error) { + templateID, err := uuid.Parse(args.TemplateID) + if err != nil { + return "", xerrors.Errorf("template_id must be a valid UUID: %w", err) + } + templateVersionID, err := uuid.Parse(args.TemplateVersionID) + if err != nil { + return "", xerrors.Errorf("template_version_id must be a valid UUID: %w", err) + } + err = deps.coderClient.UpdateActiveTemplateVersion(ctx, templateID, codersdk.UpdateActiveTemplateVersion{ + ID: templateVersionID, + }) + if err != nil { + return "", err + } + return "Successfully updated active version!", nil + }, +} - UploadTarFile = Tool[codersdk.UploadResponse]{ - Tool: aisdk.Tool{ - Name: "coder_upload_tar_file", - Description: `Create and upload a tar file by key/value mapping of file names to file contents. Use this to create template versions. Reference the tool description of "create_template_version" to understand template requirements.`, - Schema: aisdk.Schema{ - Properties: map[string]any{ - "mime_type": map[string]any{ - "type": "string", - }, - "files": map[string]any{ - "type": "object", - "description": "A map of file names to file contents.", - }, +type UploadTarFileArgs struct { + Files map[string]string `json:"files"` +} + +var UploadTarFile = Tool[UploadTarFileArgs, codersdk.UploadResponse]{ + Tool: aisdk.Tool{ + Name: "coder_upload_tar_file", + Description: `Create and upload a tar file by key/value mapping of file names to file contents. Use this to create template versions. Reference the tool description of "create_template_version" to understand template requirements.`, + Schema: aisdk.Schema{ + Properties: map[string]any{ + "files": map[string]any{ + "type": "object", + "description": "A map of file names to file contents.", }, - Required: []string{"mime_type", "files"}, }, + Required: []string{"files"}, }, - Handler: func(ctx context.Context, args map[string]any) (codersdk.UploadResponse, error) { - client, err := clientFromContext(ctx) - if err != nil { - return codersdk.UploadResponse{}, err - } - - files, ok := args["files"].(map[string]any) - if !ok { - return codersdk.UploadResponse{}, xerrors.New("files must be a map") - } - - pipeReader, pipeWriter := io.Pipe() - go func() { - defer pipeWriter.Close() - tarWriter := tar.NewWriter(pipeWriter) - for name, content := range files { - contentStr, ok := content.(string) - if !ok { - _ = pipeWriter.CloseWithError(xerrors.New("file content must be a string")) - return - } - header := &tar.Header{ - Name: name, - Size: int64(len(contentStr)), - Mode: 0o644, - } - if err := tarWriter.WriteHeader(header); err != nil { - _ = pipeWriter.CloseWithError(err) - return - } - if _, err := tarWriter.Write([]byte(contentStr)); err != nil { - _ = pipeWriter.CloseWithError(err) - return - } + }, + Handler: func(ctx context.Context, deps Deps, args UploadTarFileArgs) (codersdk.UploadResponse, error) { + pipeReader, pipeWriter := io.Pipe() + done := make(chan struct{}) + go func() { + defer func() { + _ = pipeWriter.Close() + close(done) + }() + tarWriter := tar.NewWriter(pipeWriter) + for name, content := range args.Files { + header := &tar.Header{ + Name: name, + Size: int64(len(content)), + Mode: 0o644, } - if err := tarWriter.Close(); err != nil { + if err := tarWriter.WriteHeader(header); err != nil { _ = pipeWriter.CloseWithError(err) + return + } + if _, err := tarWriter.Write([]byte(content)); err != nil { + _ = pipeWriter.CloseWithError(err) + return } - }() - - resp, err := client.Upload(ctx, codersdk.ContentTypeTar, pipeReader) - if err != nil { - return codersdk.UploadResponse{}, err } - return resp, nil - }, - } + if err := tarWriter.Close(); err != nil { + _ = pipeWriter.CloseWithError(err) + } + }() - CreateTemplate = Tool[codersdk.Template]{ - Tool: aisdk.Tool{ - Name: "coder_create_template", - Description: "Create a new template in Coder. First, you must create a template version.", - Schema: aisdk.Schema{ - Properties: map[string]any{ - "name": map[string]any{ - "type": "string", - }, - "display_name": map[string]any{ - "type": "string", - }, - "description": map[string]any{ - "type": "string", - }, - "icon": map[string]any{ - "type": "string", - "description": "A URL to an icon to use.", - }, - "version_id": map[string]any{ - "type": "string", - "description": "The ID of the version to use.", - }, + resp, err := deps.coderClient.Upload(ctx, codersdk.ContentTypeTar, pipeReader) + if err != nil { + _ = pipeReader.CloseWithError(err) + <-done + return codersdk.UploadResponse{}, err + } + <-done + return resp, nil + }, +} + +type CreateTemplateArgs struct { + Description string `json:"description"` + DisplayName string `json:"display_name"` + Icon string `json:"icon"` + Name string `json:"name"` + VersionID string `json:"version_id"` +} + +var CreateTemplate = Tool[CreateTemplateArgs, codersdk.Template]{ + Tool: aisdk.Tool{ + Name: "coder_create_template", + Description: "Create a new template in Coder. First, you must create a template version.", + Schema: aisdk.Schema{ + Properties: map[string]any{ + "name": map[string]any{ + "type": "string", + }, + "display_name": map[string]any{ + "type": "string", + }, + "description": map[string]any{ + "type": "string", + }, + "icon": map[string]any{ + "type": "string", + "description": "A URL to an icon to use.", + }, + "version_id": map[string]any{ + "type": "string", + "description": "The ID of the version to use.", }, - Required: []string{"name", "display_name", "description", "version_id"}, }, + Required: []string{"name", "display_name", "description", "version_id"}, }, - Handler: func(ctx context.Context, args map[string]any) (codersdk.Template, error) { - client, err := clientFromContext(ctx) - if err != nil { - return codersdk.Template{}, err - } - me, err := client.User(ctx, "me") - if err != nil { - return codersdk.Template{}, err - } - versionID, err := uuidFromArgs(args, "version_id") - if err != nil { - return codersdk.Template{}, err - } - name, ok := args["name"].(string) - if !ok { - return codersdk.Template{}, xerrors.New("name must be a string") - } - displayName, ok := args["display_name"].(string) - if !ok { - return codersdk.Template{}, xerrors.New("display_name must be a string") - } - description, ok := args["description"].(string) - if !ok { - return codersdk.Template{}, xerrors.New("description must be a string") - } + }, + Handler: func(ctx context.Context, deps Deps, args CreateTemplateArgs) (codersdk.Template, error) { + me, err := deps.coderClient.User(ctx, "me") + if err != nil { + return codersdk.Template{}, err + } + versionID, err := uuid.Parse(args.VersionID) + if err != nil { + return codersdk.Template{}, xerrors.Errorf("version_id must be a valid UUID: %w", err) + } + template, err := deps.coderClient.CreateTemplate(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateRequest{ + Name: args.Name, + DisplayName: args.DisplayName, + Description: args.Description, + VersionID: versionID, + }) + if err != nil { + return codersdk.Template{}, err + } + return template, nil + }, +} - template, err := client.CreateTemplate(ctx, me.OrganizationIDs[0], codersdk.CreateTemplateRequest{ - Name: name, - DisplayName: displayName, - Description: description, - VersionID: versionID, - }) - if err != nil { - return codersdk.Template{}, err - } - return template, nil - }, - } +type DeleteTemplateArgs struct { + TemplateID string `json:"template_id"` +} - DeleteTemplate = Tool[string]{ - Tool: aisdk.Tool{ - Name: "coder_delete_template", - Description: "Delete a template. This is irreversible.", - Schema: aisdk.Schema{ - Properties: map[string]any{ - "template_id": map[string]any{ - "type": "string", - }, +var DeleteTemplate = Tool[DeleteTemplateArgs, codersdk.Response]{ + Tool: aisdk.Tool{ + Name: "coder_delete_template", + Description: "Delete a template. This is irreversible.", + Schema: aisdk.Schema{ + Properties: map[string]any{ + "template_id": map[string]any{ + "type": "string", }, }, }, - Handler: func(ctx context.Context, args map[string]any) (string, error) { - client, err := clientFromContext(ctx) - if err != nil { - return "", err - } - - templateID, err := uuidFromArgs(args, "template_id") - if err != nil { - return "", err - } - err = client.DeleteTemplate(ctx, templateID) - if err != nil { - return "", err - } - return "Successfully deleted template!", nil - }, - } -) + }, + Handler: func(ctx context.Context, deps Deps, args DeleteTemplateArgs) (codersdk.Response, error) { + templateID, err := uuid.Parse(args.TemplateID) + if err != nil { + return codersdk.Response{}, xerrors.Errorf("template_id must be a valid UUID: %w", err) + } + err = deps.coderClient.DeleteTemplate(ctx, templateID) + if err != nil { + return codersdk.Response{}, err + } + return codersdk.Response{ + Message: "Template deleted successfully.", + }, nil + }, +} type MinimalWorkspace struct { ID string `json:"id"` @@ -1199,61 +1292,3 @@ type MinimalTemplate struct { ActiveVersionID uuid.UUID `json:"active_version_id"` ActiveUserCount int `json:"active_user_count"` } - -func clientFromContext(ctx context.Context) (*codersdk.Client, error) { - client, ok := ctx.Value(clientContextKey{}).(*codersdk.Client) - if !ok { - return nil, xerrors.New("client required in context") - } - return client, nil -} - -type clientContextKey struct{} - -func WithClient(ctx context.Context, client *codersdk.Client) context.Context { - return context.WithValue(ctx, clientContextKey{}, client) -} - -type agentClientContextKey struct{} - -func WithAgentClient(ctx context.Context, client *agentsdk.Client) context.Context { - return context.WithValue(ctx, agentClientContextKey{}, client) -} - -func agentClientFromContext(ctx context.Context) (*agentsdk.Client, error) { - client, ok := ctx.Value(agentClientContextKey{}).(*agentsdk.Client) - if !ok { - return nil, xerrors.New("agent client required in context") - } - return client, nil -} - -type workspaceAppStatusSlugContextKey struct{} - -func WithWorkspaceAppStatusSlug(ctx context.Context, slug string) context.Context { - return context.WithValue(ctx, workspaceAppStatusSlugContextKey{}, slug) -} - -func workspaceAppStatusSlugFromContext(ctx context.Context) (string, bool) { - slug, ok := ctx.Value(workspaceAppStatusSlugContextKey{}).(string) - if !ok || slug == "" { - return "", false - } - return slug, true -} - -func uuidFromArgs(args map[string]any, key string) (uuid.UUID, error) { - argKey, ok := args[key] - if !ok { - return uuid.Nil, nil // No error if key is not present - } - raw, ok := argKey.(string) - if !ok { - return uuid.Nil, xerrors.Errorf("%s must be a string", key) - } - id, err := uuid.Parse(raw) - if err != nil { - return uuid.Nil, xerrors.Errorf("failed to parse %s: %w", key, err) - } - return id, nil -} diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index 1504e956f6bd4..fae4e85e52a66 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -2,6 +2,7 @@ package toolsdk_test import ( "context" + "encoding/json" "os" "sort" "sync" @@ -9,7 +10,10 @@ import ( "time" "github.com/google/uuid" + "github.com/kylecarbs/aisdk-go" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" @@ -68,26 +72,35 @@ func TestTools(t *testing.T) { }) t.Run("ReportTask", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithAgentClient(ctx, agentClient) - ctx = toolsdk.WithWorkspaceAppStatusSlug(ctx, "some-agent-app") - _, err := testTool(ctx, t, toolsdk.ReportTask, map[string]any{ - "summary": "test summary", - "state": "complete", - "link": "https://example.com", + tb, err := toolsdk.NewDeps(memberClient, toolsdk.WithAgentClient(agentClient), toolsdk.WithAppStatusSlug("some-agent-app")) + require.NoError(t, err) + _, err = testTool(t, toolsdk.ReportTask, tb, toolsdk.ReportTaskArgs{ + Summary: "test summary", + State: "complete", + Link: "https://example.com", }) require.NoError(t, err) }) - t.Run("ListTemplates", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) + t.Run("GetWorkspace", func(t *testing.T) { + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) + result, err := testTool(t, toolsdk.GetWorkspace, tb, toolsdk.GetWorkspaceArgs{ + WorkspaceID: r.Workspace.ID.String(), + }) + + require.NoError(t, err) + require.Equal(t, r.Workspace.ID, result.ID, "expected the workspace ID to match") + }) + t.Run("ListTemplates", func(t *testing.T) { + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) // Get the templates directly for comparison expected, err := memberClient.Templates(context.Background(), codersdk.TemplateFilter{}) require.NoError(t, err) - result, err := testTool(ctx, t, toolsdk.ListTemplates, map[string]any{}) + result, err := testTool(t, toolsdk.ListTemplates, tb, toolsdk.NoArgs{}) require.NoError(t, err) require.Len(t, result, len(expected)) @@ -105,10 +118,9 @@ func TestTools(t *testing.T) { }) t.Run("Whoami", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - result, err := testTool(ctx, t, toolsdk.GetAuthenticatedUser, map[string]any{}) + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) + result, err := testTool(t, toolsdk.GetAuthenticatedUser, tb, toolsdk.NoArgs{}) require.NoError(t, err) require.Equal(t, member.ID, result.ID) @@ -116,12 +128,9 @@ func TestTools(t *testing.T) { }) t.Run("ListWorkspaces", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - result, err := testTool(ctx, t, toolsdk.ListWorkspaces, map[string]any{ - "owner": "me", - }) + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) + result, err := testTool(t, toolsdk.ListWorkspaces, tb, toolsdk.ListWorkspacesArgs{}) require.NoError(t, err) require.Len(t, result, 1, "expected 1 workspace") @@ -129,26 +138,14 @@ func TestTools(t *testing.T) { require.Equal(t, r.Workspace.ID.String(), workspace.ID, "expected the workspace to match the one we created") }) - t.Run("GetWorkspace", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - result, err := testTool(ctx, t, toolsdk.GetWorkspace, map[string]any{ - "workspace_id": r.Workspace.ID.String(), - }) - - require.NoError(t, err) - require.Equal(t, r.Workspace.ID, result.ID, "expected the workspace ID to match") - }) - t.Run("CreateWorkspaceBuild", func(t *testing.T) { t.Run("Stop", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - result, err := testTool(ctx, t, toolsdk.CreateWorkspaceBuild, map[string]any{ - "workspace_id": r.Workspace.ID.String(), - "transition": "stop", + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) + result, err := testTool(t, toolsdk.CreateWorkspaceBuild, tb, toolsdk.CreateWorkspaceBuildArgs{ + WorkspaceID: r.Workspace.ID.String(), + Transition: "stop", }) require.NoError(t, err) @@ -164,11 +161,11 @@ func TestTools(t *testing.T) { t.Run("Start", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - result, err := testTool(ctx, t, toolsdk.CreateWorkspaceBuild, map[string]any{ - "workspace_id": r.Workspace.ID.String(), - "transition": "start", + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) + result, err := testTool(t, toolsdk.CreateWorkspaceBuild, tb, toolsdk.CreateWorkspaceBuildArgs{ + WorkspaceID: r.Workspace.ID.String(), + Transition: "start", }) require.NoError(t, err) @@ -184,8 +181,8 @@ func TestTools(t *testing.T) { t.Run("TemplateVersionChange", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) // Get the current template version ID before updating workspace, err := memberClient.Workspace(ctx, r.Workspace.ID) require.NoError(t, err) @@ -201,10 +198,10 @@ func TestTools(t *testing.T) { }).Do() // Update to new version - updateBuild, err := testTool(ctx, t, toolsdk.CreateWorkspaceBuild, map[string]any{ - "workspace_id": r.Workspace.ID.String(), - "transition": "start", - "template_version_id": newVersion.TemplateVersion.ID.String(), + updateBuild, err := testTool(t, toolsdk.CreateWorkspaceBuild, tb, toolsdk.CreateWorkspaceBuildArgs{ + WorkspaceID: r.Workspace.ID.String(), + Transition: "start", + TemplateVersionID: newVersion.TemplateVersion.ID.String(), }) require.NoError(t, err) require.Equal(t, codersdk.WorkspaceTransitionStart, updateBuild.Transition) @@ -214,10 +211,10 @@ func TestTools(t *testing.T) { require.NoError(t, client.CancelWorkspaceBuild(ctx, updateBuild.ID)) // Roll back to the original version - rollbackBuild, err := testTool(ctx, t, toolsdk.CreateWorkspaceBuild, map[string]any{ - "workspace_id": r.Workspace.ID.String(), - "transition": "start", - "template_version_id": originalVersionID.String(), + rollbackBuild, err := testTool(t, toolsdk.CreateWorkspaceBuild, tb, toolsdk.CreateWorkspaceBuildArgs{ + WorkspaceID: r.Workspace.ID.String(), + Transition: "start", + TemplateVersionID: originalVersionID.String(), }) require.NoError(t, err) require.Equal(t, codersdk.WorkspaceTransitionStart, rollbackBuild.Transition) @@ -229,11 +226,10 @@ func TestTools(t *testing.T) { }) t.Run("ListTemplateVersionParameters", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - params, err := testTool(ctx, t, toolsdk.ListTemplateVersionParameters, map[string]any{ - "template_version_id": r.TemplateVersion.ID.String(), + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) + params, err := testTool(t, toolsdk.ListTemplateVersionParameters, tb, toolsdk.ListTemplateVersionParametersArgs{ + TemplateVersionID: r.TemplateVersion.ID.String(), }) require.NoError(t, err) @@ -241,11 +237,10 @@ func TestTools(t *testing.T) { }) t.Run("GetWorkspaceAgentLogs", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, client) - - logs, err := testTool(ctx, t, toolsdk.GetWorkspaceAgentLogs, map[string]any{ - "workspace_agent_id": agentID.String(), + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) + logs, err := testTool(t, toolsdk.GetWorkspaceAgentLogs, tb, toolsdk.GetWorkspaceAgentLogsArgs{ + WorkspaceAgentID: agentID.String(), }) require.NoError(t, err) @@ -253,11 +248,10 @@ func TestTools(t *testing.T) { }) t.Run("GetWorkspaceBuildLogs", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - logs, err := testTool(ctx, t, toolsdk.GetWorkspaceBuildLogs, map[string]any{ - "workspace_build_id": r.Build.ID.String(), + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) + logs, err := testTool(t, toolsdk.GetWorkspaceBuildLogs, tb, toolsdk.GetWorkspaceBuildLogsArgs{ + WorkspaceBuildID: r.Build.ID.String(), }) require.NoError(t, err) @@ -265,11 +259,10 @@ func TestTools(t *testing.T) { }) t.Run("GetTemplateVersionLogs", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - - logs, err := testTool(ctx, t, toolsdk.GetTemplateVersionLogs, map[string]any{ - "template_version_id": r.TemplateVersion.ID.String(), + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) + logs, err := testTool(t, toolsdk.GetTemplateVersionLogs, tb, toolsdk.GetTemplateVersionLogsArgs{ + TemplateVersionID: r.TemplateVersion.ID.String(), }) require.NoError(t, err) @@ -277,12 +270,11 @@ func TestTools(t *testing.T) { }) t.Run("UpdateTemplateActiveVersion", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, client) // Use owner client for permission - - result, err := testTool(ctx, t, toolsdk.UpdateTemplateActiveVersion, map[string]any{ - "template_id": r.Template.ID.String(), - "template_version_id": r.TemplateVersion.ID.String(), + tb, err := toolsdk.NewDeps(client) + require.NoError(t, err) + result, err := testTool(t, toolsdk.UpdateTemplateActiveVersion, tb, toolsdk.UpdateTemplateActiveVersionArgs{ + TemplateID: r.Template.ID.String(), + TemplateVersionID: r.TemplateVersion.ID.String(), }) require.NoError(t, err) @@ -290,11 +282,10 @@ func TestTools(t *testing.T) { }) t.Run("DeleteTemplate", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, client) - - _, err := testTool(ctx, t, toolsdk.DeleteTemplate, map[string]any{ - "template_id": r.Template.ID.String(), + tb, err := toolsdk.NewDeps(client) + require.NoError(t, err) + _, err = testTool(t, toolsdk.DeleteTemplate, tb, toolsdk.DeleteTemplateArgs{ + TemplateID: r.Template.ID.String(), }) // This will fail with because there already exists a workspace. @@ -302,16 +293,14 @@ func TestTools(t *testing.T) { }) t.Run("UploadTarFile", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, client) - - files := map[string]any{ - "main.tf": "resource \"null_resource\" \"example\" {}", + files := map[string]string{ + "main.tf": `resource "null_resource" "example" {}`, } + tb, err := toolsdk.NewDeps(memberClient) + require.NoError(t, err) - result, err := testTool(ctx, t, toolsdk.UploadTarFile, map[string]any{ - "mime_type": string(codersdk.ContentTypeTar), - "files": files, + result, err := testTool(t, toolsdk.UploadTarFile, tb, toolsdk.UploadTarFileArgs{ + Files: files, }) require.NoError(t, err) @@ -319,23 +308,30 @@ func TestTools(t *testing.T) { }) t.Run("CreateTemplateVersion", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, client) - + tb, err := toolsdk.NewDeps(client) + require.NoError(t, err) // nolint:gocritic // This is in a test package and does not end up in the build file := dbgen.File(t, store, database.File{}) - - tv, err := testTool(ctx, t, toolsdk.CreateTemplateVersion, map[string]any{ - "file_id": file.ID.String(), + t.Run("WithoutTemplateID", func(t *testing.T) { + tv, err := testTool(t, toolsdk.CreateTemplateVersion, tb, toolsdk.CreateTemplateVersionArgs{ + FileID: file.ID.String(), + }) + require.NoError(t, err) + require.NotEmpty(t, tv) + }) + t.Run("WithTemplateID", func(t *testing.T) { + tv, err := testTool(t, toolsdk.CreateTemplateVersion, tb, toolsdk.CreateTemplateVersionArgs{ + FileID: file.ID.String(), + TemplateID: r.Template.ID.String(), + }) + require.NoError(t, err) + require.NotEmpty(t, tv) }) - require.NoError(t, err) - require.NotEmpty(t, tv) }) t.Run("CreateTemplate", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, client) - + tb, err := toolsdk.NewDeps(client) + require.NoError(t, err) // Create a new template version for use here. tv := dbfake.TemplateVersion(t, store). // nolint:gocritic // This is in a test package and does not end up in the build @@ -343,26 +339,25 @@ func TestTools(t *testing.T) { SkipCreateTemplate().Do() // We're going to re-use the pre-existing template version - _, err := testTool(ctx, t, toolsdk.CreateTemplate, map[string]any{ - "name": testutil.GetRandomNameHyphenated(t), - "display_name": "Test Template", - "description": "This is a test template", - "version_id": tv.TemplateVersion.ID.String(), + _, err = testTool(t, toolsdk.CreateTemplate, tb, toolsdk.CreateTemplateArgs{ + Name: testutil.GetRandomNameHyphenated(t), + DisplayName: "Test Template", + Description: "This is a test template", + VersionID: tv.TemplateVersion.ID.String(), }) require.NoError(t, err) }) t.Run("CreateWorkspace", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitShort) - ctx = toolsdk.WithClient(ctx, memberClient) - + tb, err := toolsdk.NewDeps(client) + require.NoError(t, err) // We need a template version ID to create a workspace - res, err := testTool(ctx, t, toolsdk.CreateWorkspace, map[string]any{ - "user": "me", - "template_version_id": r.TemplateVersion.ID.String(), - "name": testutil.GetRandomNameHyphenated(t), - "rich_parameters": map[string]any{}, + res, err := testTool(t, toolsdk.CreateWorkspace, tb, toolsdk.CreateWorkspaceArgs{ + User: "me", + TemplateVersionID: r.TemplateVersion.ID.String(), + Name: testutil.GetRandomNameHyphenated(t), + RichParameters: map[string]string{}, }) // The creation might fail for various reasons, but the important thing is @@ -376,11 +371,172 @@ func TestTools(t *testing.T) { var testedTools sync.Map // testTool is a helper function to test a tool and mark it as tested. -func testTool[T any](ctx context.Context, t *testing.T, tool toolsdk.Tool[T], args map[string]any) (T, error) { +// Note that we test the _generic_ version of the tool and not the typed one. +// This is to mimic how we expect external callers to use the tool. +func testTool[Arg, Ret any](t *testing.T, tool toolsdk.Tool[Arg, Ret], tb toolsdk.Deps, args Arg) (Ret, error) { t.Helper() - testedTools.Store(tool.Tool.Name, true) - result, err := tool.Handler(ctx, args) - return result, err + defer func() { testedTools.Store(tool.Tool.Name, true) }() + toolArgs, err := json.Marshal(args) + require.NoError(t, err, "failed to marshal args") + result, err := tool.Generic().Handler(context.Background(), tb, toolArgs) + var ret Ret + require.NoError(t, json.Unmarshal(result, &ret), "failed to unmarshal result %q", string(result)) + return ret, err +} + +func TestWithRecovery(t *testing.T) { + t.Parallel() + t.Run("OK", func(t *testing.T) { + t.Parallel() + fakeTool := toolsdk.GenericTool{ + Tool: aisdk.Tool{ + Name: "echo", + Description: "Echoes the input.", + }, + Handler: func(ctx context.Context, tb toolsdk.Deps, args json.RawMessage) (json.RawMessage, error) { + return args, nil + }, + } + + wrapped := toolsdk.WithRecover(fakeTool.Handler) + v, err := wrapped(context.Background(), toolsdk.Deps{}, []byte(`{}`)) + require.NoError(t, err) + require.JSONEq(t, `{}`, string(v)) + }) + + t.Run("Error", func(t *testing.T) { + t.Parallel() + fakeTool := toolsdk.GenericTool{ + Tool: aisdk.Tool{ + Name: "fake_tool", + Description: "Returns an error for testing.", + }, + Handler: func(ctx context.Context, tb toolsdk.Deps, args json.RawMessage) (json.RawMessage, error) { + return nil, assert.AnError + }, + } + wrapped := toolsdk.WithRecover(fakeTool.Handler) + v, err := wrapped(context.Background(), toolsdk.Deps{}, []byte(`{}`)) + require.Nil(t, v) + require.ErrorIs(t, err, assert.AnError) + }) + + t.Run("Panic", func(t *testing.T) { + t.Parallel() + panicTool := toolsdk.GenericTool{ + Tool: aisdk.Tool{ + Name: "panic_tool", + Description: "Panics for testing.", + }, + Handler: func(ctx context.Context, tb toolsdk.Deps, args json.RawMessage) (json.RawMessage, error) { + panic("you can't sweat this fever out") + }, + } + + wrapped := toolsdk.WithRecover(panicTool.Handler) + v, err := wrapped(context.Background(), toolsdk.Deps{}, []byte("disco")) + require.Empty(t, v) + require.ErrorContains(t, err, "you can't sweat this fever out") + }) +} + +type testContextKey struct{} + +func TestWithCleanContext(t *testing.T) { + t.Parallel() + + t.Run("NoContextKeys", func(t *testing.T) { + t.Parallel() + + // This test is to ensure that the context values are not set in the + // toolsdk package. + ctxTool := toolsdk.GenericTool{ + Tool: aisdk.Tool{ + Name: "context_tool", + Description: "Returns the context value for testing.", + }, + Handler: func(toolCtx context.Context, tb toolsdk.Deps, args json.RawMessage) (json.RawMessage, error) { + v := toolCtx.Value(testContextKey{}) + assert.Nil(t, v, "expected the context value to be nil") + return nil, nil + }, + } + + wrapped := toolsdk.WithCleanContext(ctxTool.Handler) + ctx := context.WithValue(context.Background(), testContextKey{}, "test") + _, _ = wrapped(ctx, toolsdk.Deps{}, []byte(`{}`)) + }) + + t.Run("PropagateCancel", func(t *testing.T) { + t.Parallel() + + // This test is to ensure that the context is canceled properly. + callCh := make(chan struct{}) + ctxTool := toolsdk.GenericTool{ + Tool: aisdk.Tool{ + Name: "context_tool", + Description: "Returns the context value for testing.", + }, + Handler: func(toolCtx context.Context, tb toolsdk.Deps, args json.RawMessage) (json.RawMessage, error) { + defer close(callCh) + // Wait for the context to be canceled + <-toolCtx.Done() + return nil, toolCtx.Err() + }, + } + wrapped := toolsdk.WithCleanContext(ctxTool.Handler) + errCh := make(chan error, 1) + + tCtx := testutil.Context(t, testutil.WaitShort) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + go func() { + _, err := wrapped(ctx, toolsdk.Deps{}, []byte(`{}`)) + errCh <- err + }() + + cancel() + + // Ensure the tool is called + select { + case <-callCh: + case <-tCtx.Done(): + require.Fail(t, "test timed out before handler was called") + } + + // Ensure the correct error is returned + select { + case <-tCtx.Done(): + require.Fail(t, "test timed out") + case err := <-errCh: + // Context was canceled and the done channel was closed + require.ErrorIs(t, err, context.Canceled) + } + }) + + t.Run("PropagateDeadline", func(t *testing.T) { + t.Parallel() + + // This test ensures that the context deadline is propagated to the child + // from the parent. + ctxTool := toolsdk.GenericTool{ + Tool: aisdk.Tool{ + Name: "context_tool_deadline", + Description: "Checks if context has deadline.", + }, + Handler: func(toolCtx context.Context, tb toolsdk.Deps, args json.RawMessage) (json.RawMessage, error) { + _, ok := toolCtx.Deadline() + assert.True(t, ok, "expected deadline to be set on the child context") + return nil, nil + }, + } + + wrapped := toolsdk.WithCleanContext(ctxTool.Handler) + parent, cancel := context.WithTimeout(context.Background(), testutil.IntervalFast) + t.Cleanup(cancel) + _, err := wrapped(parent, toolsdk.Deps{}, []byte(`{}`)) + require.NoError(t, err) + }) } // TestMain runs after all tests to ensure that all tools in this package have @@ -402,6 +558,7 @@ func TestMain(m *testing.M) { } if len(untested) > 0 && code == 0 { + code = 1 println("The following tools were not tested:") for _, tool := range untested { println(" - " + tool) @@ -409,7 +566,14 @@ func TestMain(m *testing.M) { println("Please ensure that all tools are tested using testTool().") println("If you just added a new tool, please add a test for it.") println("NOTE: if you just ran an individual test, this is expected.") - os.Exit(1) + } + + // Check for goroutine leaks. Below is adapted from goleak.VerifyTestMain: + if code == 0 { + if err := goleak.Find(testutil.GoleakOptions...); err != nil { + println("goleak: Errors on successful test run: ", err.Error()) + code = 1 + } } os.Exit(code)