diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index ac362295f0e00..732f5aa218de7 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -83,6 +83,7 @@ import ( "github.com/coder/coder/v2/coderd/schedule" "github.com/coder/coder/v2/coderd/telemetry" "github.com/coder/coder/v2/coderd/updatecheck" + "github.com/coder/coder/v2/coderd/usage" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/coderd/webpush" "github.com/coder/coder/v2/coderd/workspaceapps" @@ -186,6 +187,7 @@ type Options struct { TelemetryReporter telemetry.Reporter ProvisionerdServerMetrics *provisionerdserver.Metrics + UsageInserter usage.Inserter } // New constructs a codersdk client connected to an in-memory API instance. @@ -266,6 +268,11 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can } } + var usageInserter *atomic.Pointer[usage.Inserter] + if options.UsageInserter != nil { + usageInserter = &atomic.Pointer[usage.Inserter]{} + usageInserter.Store(&options.UsageInserter) + } if options.Database == nil { options.Database, options.Pubsub = dbtestutil.NewDB(t) } @@ -559,6 +566,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can Database: options.Database, Pubsub: options.Pubsub, ExternalAuthConfigs: options.ExternalAuthConfigs, + UsageInserter: usageInserter, Auditor: options.Auditor, ConnectionLogger: options.ConnectionLogger, diff --git a/coderd/coderdtest/usage.go b/coderd/coderdtest/usage.go new file mode 100644 index 0000000000000..4da724b1779cd --- /dev/null +++ b/coderd/coderdtest/usage.go @@ -0,0 +1,44 @@ +package coderdtest + +import ( + "context" + "sync" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/usage" + "github.com/coder/coder/v2/coderd/usage/usagetypes" +) + +var _ usage.Inserter = (*UsageInserter)(nil) + +type UsageInserter struct { + sync.Mutex + events []usagetypes.DiscreteEvent +} + +func NewUsageInserter() *UsageInserter { + return &UsageInserter{ + events: []usagetypes.DiscreteEvent{}, + } +} + +func (u *UsageInserter) InsertDiscreteUsageEvent(_ context.Context, _ database.Store, event usagetypes.DiscreteEvent) error { + u.Lock() + defer u.Unlock() + u.events = append(u.events, event) + return nil +} + +func (u *UsageInserter) GetEvents() []usagetypes.DiscreteEvent { + u.Lock() + defer u.Unlock() + eventsCopy := make([]usagetypes.DiscreteEvent, len(u.events)) + copy(eventsCopy, u.events) + return eventsCopy +} + +func (u *UsageInserter) Reset() { + u.Lock() + defer u.Unlock() + u.events = []usagetypes.DiscreteEvent{} +} diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index c4598beaf8399..95a950f67a20d 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -2026,13 +2026,11 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro } var ( - hasAITask bool unknownAppID string taskAppID uuid.NullUUID taskAgentID uuid.NullUUID ) if tasks := jobType.WorkspaceBuild.GetAiTasks(); len(tasks) > 0 { - hasAITask = true task := tasks[0] if task == nil { return xerrors.Errorf("update ai task: task is nil") @@ -2048,7 +2046,6 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro if !slices.Contains(appIDs, appID) { unknownAppID = appID - hasAITask = false } else { // Only parse for valid app and agent to avoid fk violation. id, err := uuid.Parse(appID) @@ -2083,7 +2080,7 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro Level: []database.LogLevel{database.LogLevelWarn, database.LogLevelWarn, database.LogLevelWarn, database.LogLevelWarn}, Stage: []string{"Cleaning Up", "Cleaning Up", "Cleaning Up", "Cleaning Up"}, Output: []string{ - fmt.Sprintf("Unknown ai_task_app_id %q. This workspace will be unable to run AI tasks. This may be due to a template configuration issue, please check with the template author.", taskAppID.UUID.String()), + fmt.Sprintf("Unknown ai_task_app_id %q. This workspace will be unable to run AI tasks. This may be due to a template configuration issue, please check with the template author.", unknownAppID), "Template author: double-check the following:", " - You have associated the coder_ai_task with a valid coder_app in your template (ref: https://registry.terraform.io/providers/coder/coder/latest/docs/resources/ai_task).", " - You have associated the coder_agent with at least one other compute resource. Agents with no other associated resources are not inserted into the database.", @@ -2098,21 +2095,23 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro } } - if hasAITask && workspaceBuild.Transition == database.WorkspaceTransitionStart { - // Insert usage event for managed agents. - usageInserter := s.UsageInserter.Load() - if usageInserter != nil { - event := usagetypes.DCManagedAgentsV1{ - Count: 1, - } - err = (*usageInserter).InsertDiscreteUsageEvent(ctx, db, event) - if err != nil { - return xerrors.Errorf("insert %q event: %w", event.EventType(), err) + var hasAITask bool + if task, err := db.GetTaskByWorkspaceID(ctx, workspace.ID); err == nil { + hasAITask = true + if workspaceBuild.Transition == database.WorkspaceTransitionStart { + // Insert usage event for managed agents. + usageInserter := s.UsageInserter.Load() + if usageInserter != nil { + event := usagetypes.DCManagedAgentsV1{ + Count: 1, + } + err = (*usageInserter).InsertDiscreteUsageEvent(ctx, db, event) + if err != nil { + return xerrors.Errorf("insert %q event: %w", event.EventType(), err) + } } } - } - if task, err := db.GetTaskByWorkspaceID(ctx, workspace.ID); err == nil { // Irrespective of whether the agent or sidebar app is present, // perform the upsert to ensure a link between the task and // workspace build. Linking the task to the build is typically diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index 4dc8621736b5c..c151b73aefdd0 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -2878,7 +2878,7 @@ func TestCompleteJob(t *testing.T) { sidebarAppID := uuid.New() for _, tc := range []testcase{ { - name: "has_ai_task is false by default", + name: "has_ai_task is false if task_id is nil", transition: database.WorkspaceTransitionStart, input: &proto.CompletedJob_WorkspaceBuild{ // No AiTasks defined. @@ -2887,6 +2887,37 @@ func TestCompleteJob(t *testing.T) { expectHasAiTask: false, expectUsageEvent: false, }, + { + name: "has_ai_task is false even if there are coder_ai_task resources, but no task_id", + transition: database.WorkspaceTransitionStart, + input: &proto.CompletedJob_WorkspaceBuild{ + AiTasks: []*sdkproto.AITask{ + { + Id: uuid.NewString(), + AppId: sidebarAppID.String(), + }, + }, + Resources: []*sdkproto.Resource{ + { + Agents: []*sdkproto.Agent{ + { + Id: uuid.NewString(), + Name: "a", + Apps: []*sdkproto.App{ + { + Id: sidebarAppID.String(), + Slug: "test-app", + }, + }, + }, + }, + }, + }, + }, + isTask: false, + expectHasAiTask: false, + expectUsageEvent: false, + }, { name: "has_ai_task is set to true", transition: database.WorkspaceTransitionStart, @@ -2964,15 +2995,17 @@ func TestCompleteJob(t *testing.T) { { Id: uuid.NewString(), // Non-existing app ID would previously trigger a FK violation. - // Now it should just be ignored. + // Now it will trigger a warning instead in the provisioner logs. AppId: sidebarAppID.String(), }, }, }, isTask: true, expectTaskStatus: database.TaskStatusInitializing, - expectHasAiTask: false, - expectUsageEvent: false, + // You can still "sort of" use a task in this state, but as we don't have + // the correct app ID you won't be able to communicate with it via Coder. + expectHasAiTask: true, + expectUsageEvent: true, }, { name: "has_ai_task is set to true, but transition is not start", @@ -3007,19 +3040,6 @@ func TestCompleteJob(t *testing.T) { expectHasAiTask: true, expectUsageEvent: false, }, - { - name: "current build does not have ai task but previous build did", - seedFunc: seedPreviousWorkspaceStartWithAITask, - transition: database.WorkspaceTransitionStop, - input: &proto.CompletedJob_WorkspaceBuild{ - AiTasks: []*sdkproto.AITask{}, - Resources: []*sdkproto.Resource{}, - }, - isTask: true, - expectTaskStatus: database.TaskStatusPaused, - expectHasAiTask: false, // We no longer inherit this from the previous build. - expectUsageEvent: false, - }, } { t.Run(tc.name, func(t *testing.T) { t.Parallel() @@ -4410,62 +4430,3 @@ func (f *fakeUsageInserter) InsertDiscreteUsageEvent(_ context.Context, _ databa f.collectedEvents = append(f.collectedEvents, event) return nil } - -func seedPreviousWorkspaceStartWithAITask(ctx context.Context, t testing.TB, db database.Store) error { - t.Helper() - // If the below looks slightly convoluted, that's because it is. - // The workspace doesn't yet have a latest build, so querying all - // workspaces will fail. - tpls, err := db.GetTemplates(ctx) - if err != nil { - return xerrors.Errorf("seedFunc: get template: %w", err) - } - if len(tpls) != 1 { - return xerrors.Errorf("seedFunc: expected exactly one template, got %d", len(tpls)) - } - ws, err := db.GetWorkspacesByTemplateID(ctx, tpls[0].ID) - if err != nil { - return xerrors.Errorf("seedFunc: get workspaces: %w", err) - } - if len(ws) != 1 { - return xerrors.Errorf("seedFunc: expected exactly one workspace, got %d", len(ws)) - } - w := ws[0] - prevJob := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{ - OrganizationID: w.OrganizationID, - InitiatorID: w.OwnerID, - Type: database.ProvisionerJobTypeWorkspaceBuild, - }) - tvs, err := db.GetTemplateVersionsByTemplateID(ctx, database.GetTemplateVersionsByTemplateIDParams{ - TemplateID: tpls[0].ID, - }) - if err != nil { - return xerrors.Errorf("seedFunc: get template version: %w", err) - } - if len(tvs) != 1 { - return xerrors.Errorf("seedFunc: expected exactly one template version, got %d", len(tvs)) - } - if tpls[0].ActiveVersionID == uuid.Nil { - return xerrors.Errorf("seedFunc: active version id is nil") - } - res := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{ - JobID: prevJob.ID, - }) - agt := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ - ResourceID: res.ID, - }) - _ = dbgen.WorkspaceApp(t, db, database.WorkspaceApp{ - AgentID: agt.ID, - }) - _ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ - BuildNumber: 1, - HasAITask: sql.NullBool{Valid: true, Bool: true}, - ID: w.ID, - InitiatorID: w.OwnerID, - JobID: prevJob.ID, - TemplateVersionID: tvs[0].ID, - Transition: database.WorkspaceTransitionStart, - WorkspaceID: w.ID, - }) - return nil -} diff --git a/coderd/wsbuilder/wsbuilder.go b/coderd/wsbuilder/wsbuilder.go index 6aef8c2c2aa17..0eeddd2ba29fd 100644 --- a/coderd/wsbuilder/wsbuilder.go +++ b/coderd/wsbuilder/wsbuilder.go @@ -6,7 +6,6 @@ import ( "context" "database/sql" "encoding/json" - "errors" "fmt" "net/http" "time" @@ -87,13 +86,15 @@ type Builder struct { templateVersionPresetParameterValues *[]database.TemplateVersionPresetParameter parameterRender dynamicparameters.Renderer workspaceTags *map[string]string + task *database.Task + hasTask *bool // A workspace without a task will have a nil `task` and false `hasTask`. prebuiltWorkspaceBuildStage sdkproto.PrebuiltWorkspaceBuildStage verifyNoLegacyParametersOnce bool } type UsageChecker interface { - CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (UsageCheckResponse, error) + CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, task *database.Task, transition database.WorkspaceTransition) (UsageCheckResponse, error) } type UsageCheckResponse struct { @@ -105,7 +106,7 @@ type NoopUsageChecker struct{} var _ UsageChecker = NoopUsageChecker{} -func (NoopUsageChecker) CheckBuildUsage(_ context.Context, _ database.Store, _ *database.TemplateVersion) (UsageCheckResponse, error) { +func (NoopUsageChecker) CheckBuildUsage(_ context.Context, _ database.Store, _ *database.TemplateVersion, _ *database.Task, _ database.WorkspaceTransition) (UsageCheckResponse, error) { return UsageCheckResponse{ Permitted: true, }, nil @@ -489,8 +490,12 @@ func (b *Builder) buildTx(authFunc func(action policy.Action, object rbac.Object return BuildError{code, "insert workspace build", err} } + task, err := b.getWorkspaceTask() + if err != nil { + return BuildError{http.StatusInternalServerError, "get task by workspace id", err} + } // If this is a task workspace, link it to the latest workspace build. - if task, err := store.GetTaskByWorkspaceID(b.ctx, b.workspace.ID); err == nil { + if task != nil { _, err = store.UpsertTaskWorkspaceApp(b.ctx, database.UpsertTaskWorkspaceAppParams{ TaskID: task.ID, WorkspaceBuildNumber: buildNum, @@ -500,8 +505,6 @@ func (b *Builder) buildTx(authFunc func(action policy.Action, object rbac.Object if err != nil { return BuildError{http.StatusInternalServerError, "upsert task workspace app", err} } - } else if !errors.Is(err, sql.ErrNoRows) { - return BuildError{http.StatusInternalServerError, "get task by workspace id", err} } err = store.InsertWorkspaceBuildParameters(b.ctx, database.InsertWorkspaceBuildParametersParams{ @@ -634,6 +637,27 @@ func (b *Builder) getTemplateVersionID() (uuid.UUID, error) { return bld.TemplateVersionID, nil } +// getWorkspaceTask returns the task associated with the workspace, if any. +// If no task exists, it returns (nil, nil). +func (b *Builder) getWorkspaceTask() (*database.Task, error) { + if b.hasTask != nil { + return b.task, nil + } + t, err := b.store.GetTaskByWorkspaceID(b.ctx, b.workspace.ID) + if err != nil { + if xerrors.Is(err, sql.ErrNoRows) { + b.hasTask = ptr.Ref(false) + //nolint:nilnil // No task exists. + return nil, nil + } + return nil, xerrors.Errorf("get task: %w", err) + } + + b.task = &t + b.hasTask = ptr.Ref(true) + return b.task, nil +} + func (b *Builder) getTemplateTerraformValues() (*database.TemplateVersionTerraformValue, error) { if b.terraformValues != nil { return b.terraformValues, nil @@ -1307,7 +1331,12 @@ func (b *Builder) checkUsage() error { return BuildError{http.StatusInternalServerError, "Failed to fetch template version", err} } - resp, err := b.usageChecker.CheckBuildUsage(b.ctx, b.store, templateVersion) + task, err := b.getWorkspaceTask() + if err != nil { + return BuildError{http.StatusInternalServerError, "Failed to fetch workspace task", err} + } + + resp, err := b.usageChecker.CheckBuildUsage(b.ctx, b.store, templateVersion, task, b.trans) if err != nil { return BuildError{http.StatusInternalServerError, "Failed to check build usage", err} } diff --git a/coderd/wsbuilder/wsbuilder_test.go b/coderd/wsbuilder/wsbuilder_test.go index 3a8921dd6dcd9..38f88f7508a67 100644 --- a/coderd/wsbuilder/wsbuilder_test.go +++ b/coderd/wsbuilder/wsbuilder_test.go @@ -570,6 +570,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { mDB := expectDB(t, // Inputs withTemplate, + withNoTask, withInactiveVersionNoParams(), withLastBuildFound, withTemplateVersionVariables(inactiveVersionID, nil), @@ -605,6 +606,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { withTemplate, withInactiveVersion(richParameters), withLastBuildFound, + withNoTask, withTemplateVersionVariables(inactiveVersionID, nil), withRichParameters(initialBuildParameters), withParameterSchemas(inactiveJobID, nil), @@ -1049,7 +1051,7 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) { var calls int64 fakeUsageChecker := &fakeUsageChecker{ - checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { + checkBuildUsageFunc: func(_ context.Context, _ database.Store, _ *database.TemplateVersion, _ *database.Task, _ database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { atomic.AddInt64(&calls, 1) return wsbuilder.UsageCheckResponse{Permitted: true}, nil }, @@ -1126,7 +1128,7 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) { var calls int64 fakeUsageChecker := &fakeUsageChecker{ - checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { + checkBuildUsageFunc: func(_ context.Context, _ database.Store, _ *database.TemplateVersion, _ *database.Task, _ database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { atomic.AddInt64(&calls, 1) return c.response, c.responseErr }, @@ -1134,6 +1136,7 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) { mDB := expectDB(t, withTemplate, + withNoTask, withInactiveVersionNoParams(), ) fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) @@ -1577,11 +1580,11 @@ func expectFindMatchingPresetID(id uuid.UUID, err error) func(mTx *dbmock.MockSt } type fakeUsageChecker struct { - checkBuildUsageFunc func(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) + checkBuildUsageFunc func(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, task *database.Task, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) } -func (f *fakeUsageChecker) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { - return f.checkBuildUsageFunc(ctx, store, templateVersion) +func (f *fakeUsageChecker) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, task *database.Task, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { + return f.checkBuildUsageFunc(ctx, store, templateVersion, task, transition) } func withNoTask(mTx *dbmock.MockStore) { diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 9a7b1f318f7c2..2dc10de7c3ae6 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -971,7 +971,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { var _ wsbuilder.UsageChecker = &API{} -func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { +func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, task *database.Task, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { // If the template version has an external agent, we need to check that the // license is entitled to this feature. if templateVersion.HasExternalAgent.Valid && templateVersion.HasExternalAgent.Bool { @@ -984,16 +984,31 @@ func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templ } } - // If the template version doesn't have an AI task, we don't need to check - // usage. - if !templateVersion.HasAITask.Valid || !templateVersion.HasAITask.Bool { - return wsbuilder.UsageCheckResponse{ - Permitted: true, - }, nil + resp, err := api.checkAIBuildUsage(ctx, store, task, transition) + if err != nil { + return wsbuilder.UsageCheckResponse{}, err + } + if !resp.Permitted { + return resp, nil + } + + return wsbuilder.UsageCheckResponse{Permitted: true}, nil +} + +// checkAIBuildUsage validates AI-related usage constraints. It is a no-op +// unless the transition is "start" and the template version has an AI task. +func (api *API) checkAIBuildUsage(ctx context.Context, store database.Store, task *database.Task, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { + // Only check AI usage rules for start transitions. + if transition != database.WorkspaceTransitionStart { + return wsbuilder.UsageCheckResponse{Permitted: true}, nil + } + + // If the template version doesn't have an AI task, we don't need to check usage. + if task == nil { + return wsbuilder.UsageCheckResponse{Permitted: true}, nil } - // When unlicensed, we need to check that we haven't breached the managed agent - // limit. + // When licensed, ensure we haven't breached the managed agent limit. // Unlicensed deployments are allowed to use unlimited managed agents. if api.Entitlements.HasLicense() { managedAgentLimit, ok := api.Entitlements.Feature(codersdk.FeatureManagedAgentLimit) @@ -1004,8 +1019,9 @@ func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templ }, nil } - // This check is intentionally not committed to the database. It's fine if - // it's not 100% accurate or allows for minor breaches due to build races. + // This check is intentionally not committed to the database. It's fine + // if it's not 100% accurate or allows for minor breaches due to build + // races. // nolint:gocritic // Requires permission to read all usage events. managedAgentCount, err := store.GetTotalUsageDCManagedAgentsV1(agpldbauthz.AsSystemRestricted(ctx), database.GetTotalUsageDCManagedAgentsV1Params{ StartDate: managedAgentLimit.UsagePeriod.Start, @@ -1023,9 +1039,7 @@ func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templ } } - return wsbuilder.UsageCheckResponse{ - Permitted: true, - }, nil + return wsbuilder.UsageCheckResponse{Permitted: true}, nil } // getProxyDERPStartingRegionID returns the starting region ID that should be diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index c3e6e1579fe91..81d7cdcd92c57 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -3,6 +3,7 @@ package coderd_test import ( "bytes" "context" + "database/sql" "encoding/json" "fmt" "io" @@ -21,6 +22,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" + "go.uber.org/mock/gomock" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" @@ -39,13 +41,16 @@ import ( "github.com/coder/retry" "github.com/coder/serpent" + agplcoderd "github.com/coder/coder/v2/coderd" agplaudit "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbfake" + "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" @@ -621,7 +626,7 @@ func TestManagedAgentLimit(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) - cli, _ := coderdenttest.New(t, &coderdenttest.Options{ + cli, owner := coderdenttest.New(t, &coderdenttest.Options{ Options: &coderdtest.Options{ IncludeProvisionerDaemon: true, }, @@ -635,18 +640,18 @@ func TestManagedAgentLimit(t *testing.T) { }) // Get entitlements to check that the license is a-ok. - entitlements, err := cli.Entitlements(ctx) //nolint:gocritic // we're not testing authz on the entitlements endpoint, so using owner is fine + sdkEntitlements, err := cli.Entitlements(ctx) //nolint:gocritic // we're not testing authz on the entitlements endpoint, so using owner is fine require.NoError(t, err) - require.True(t, entitlements.HasLicense) - agentLimit := entitlements.Features[codersdk.FeatureManagedAgentLimit] + require.True(t, sdkEntitlements.HasLicense) + agentLimit := sdkEntitlements.Features[codersdk.FeatureManagedAgentLimit] require.True(t, agentLimit.Enabled) require.NotNil(t, agentLimit.Limit) require.EqualValues(t, 1, *agentLimit.Limit) require.NotNil(t, agentLimit.SoftLimit) require.EqualValues(t, 1, *agentLimit.SoftLimit) - require.Empty(t, entitlements.Errors) + require.Empty(t, sdkEntitlements.Errors) // There should be a warning since we're really close to our agent limit. - require.Equal(t, entitlements.Warnings[0], "You are approaching the managed agent limit in your license. Please refer to the Deployment Licenses page for more information.") + require.Equal(t, sdkEntitlements.Warnings[0], "You are approaching the managed agent limit in your license. Please refer to the Deployment Licenses page for more information.") // Create a fake provision response that claims there are agents in the // template and every built workspace. @@ -706,15 +711,25 @@ func TestManagedAgentLimit(t *testing.T) { noAiTemplate := coderdtest.CreateTemplate(t, cli, uuid.Nil, noAiVersion.ID) // Create one AI workspace, which should succeed. - workspace := coderdtest.CreateWorkspace(t, cli, aiTemplate.ID) + task, err := cli.CreateTask(ctx, owner.UserID.String(), codersdk.CreateTaskRequest{ + Name: "workspace-1", + TemplateVersionID: aiTemplate.ActiveVersionID, + TemplateVersionPresetID: uuid.Nil, + Input: "hi", + DisplayName: "cool task 1", + }) + require.NoError(t, err, "creating task for AI workspace must succeed") + workspace, err := cli.Workspace(ctx, task.WorkspaceID.UUID) + require.NoError(t, err, "fetching AI workspace must succeed") coderdtest.AwaitWorkspaceBuildJobCompleted(t, cli, workspace.LatestBuild.ID) - // Create a second AI workspace, which should fail. This needs to be done - // manually because coderdtest.CreateWorkspace expects it to succeed. - _, err = cli.CreateUserWorkspace(ctx, codersdk.Me, codersdk.CreateWorkspaceRequest{ //nolint:gocritic // owners must still be subject to the limit - TemplateID: aiTemplate.ID, - Name: coderdtest.RandomUsername(t), - AutomaticUpdates: codersdk.AutomaticUpdatesNever, + // Create a second AI workspace, which should fail. + _, err = cli.CreateTask(ctx, owner.UserID.String(), codersdk.CreateTaskRequest{ + Name: "workspace-2", + TemplateVersionID: aiTemplate.ActiveVersionID, + TemplateVersionPresetID: uuid.Nil, + Input: "hi", + DisplayName: "bad task 2", }) require.ErrorContains(t, err, "You have breached the managed agent limit in your license") @@ -723,6 +738,73 @@ func TestManagedAgentLimit(t *testing.T) { coderdtest.AwaitWorkspaceBuildJobCompleted(t, cli, workspace.LatestBuild.ID) } +func TestCheckBuildUsage_SkipsAIForNonStartTransitions(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Prepare entitlements with a managed agent limit to enforce. + entSet := entitlements.New() + entSet.Modify(func(e *codersdk.Entitlements) { + e.HasLicense = true + limit := int64(1) + issuedAt := time.Now().Add(-2 * time.Hour) + start := time.Now().Add(-time.Hour) + end := time.Now().Add(time.Hour) + e.Features[codersdk.FeatureManagedAgentLimit] = codersdk.Feature{ + Enabled: true, + Limit: &limit, + UsagePeriod: &codersdk.UsagePeriod{IssuedAt: issuedAt, Start: start, End: end}, + } + }) + + // Enterprise API instance with entitlements injected. + agpl := &agplcoderd.API{ + Options: &agplcoderd.Options{ + Entitlements: entSet, + }, + } + eapi := &coderd.API{ + AGPL: agpl, + Options: &coderd.Options{Options: agpl.Options}, + } + + // Template version that has an AI task. + tv := &database.TemplateVersion{ + HasAITask: sql.NullBool{Valid: true, Bool: true}, + HasExternalAgent: sql.NullBool{Valid: true, Bool: false}, + } + + task := &database.Task{ + TemplateVersionID: tv.ID, + } + + // Mock DB: expect exactly one count call for the "start" transition. + mDB := dbmock.NewMockStore(ctrl) + mDB.EXPECT(). + GetTotalUsageDCManagedAgentsV1(gomock.Any(), gomock.Any()). + Times(1). + Return(int64(1), nil) // equal to limit -> should breach + + ctx := context.Background() + + // Start transition: should be not permitted due to limit breach. + startResp, err := eapi.CheckBuildUsage(ctx, mDB, tv, task, database.WorkspaceTransitionStart) + require.NoError(t, err) + require.False(t, startResp.Permitted) + require.Contains(t, startResp.Message, "breached the managed agent limit") + + // Stop transition: should be permitted and must not trigger additional DB calls. + stopResp, err := eapi.CheckBuildUsage(ctx, mDB, tv, task, database.WorkspaceTransitionStop) + require.NoError(t, err) + require.True(t, stopResp.Permitted) + + // Delete transition: should be permitted and must not trigger additional DB calls. + deleteResp, err := eapi.CheckBuildUsage(ctx, mDB, tv, task, database.WorkspaceTransitionDelete) + require.NoError(t, err) + require.True(t, deleteResp.Permitted) +} + // testDBAuthzRole returns a context with a subject that has a role // with permissions required for test setup. func testDBAuthzRole(ctx context.Context) context.Context { diff --git a/enterprise/coderd/workspaces_test.go b/enterprise/coderd/workspaces_test.go index 7cf9cd890b6df..6ffb8b4c30901 100644 --- a/enterprise/coderd/workspaces_test.go +++ b/enterprise/coderd/workspaces_test.go @@ -4477,3 +4477,124 @@ func TestDeleteWorkspaceACL(t *testing.T) { require.Equal(t, acl.Groups[0].ID, group.ID) }) } + +// Unfortunately this test is incompatible with 2.29, so it's commented out in +// this backport PR. +/* +func TestWorkspaceAITask(t *testing.T) { + t.Parallel() + + usage := coderdtest.NewUsageInserter() + owner, _, first := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + UsageInserter: usage, + IncludeProvisionerDaemon: true, + }, + LicenseOptions: (&coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureTemplateRBAC: 1, + }, + }).ManagedAgentLimit(10, 20), + }) + + client, _ := coderdtest.CreateAnotherUser(t, owner, first.OrganizationID, + rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()) + + graphWithTask := []*proto.Response{{ + Type: &proto.Response_Graph{ + Graph: &proto.GraphComplete{ + Error: "", + Timings: nil, + Resources: nil, + Parameters: nil, + ExternalAuthProviders: nil, + Presets: nil, + HasAiTasks: true, + AiTasks: []*proto.AITask{ + { + Id: "test", + SidebarApp: nil, + AppId: "test", + }, + }, + HasExternalAgents: false, + }, + }, + }} + planWithTask := []*proto.Response{{ + Type: &proto.Response_Plan{ + Plan: &proto.PlanComplete{ + Plan: []byte("{}"), + AiTaskCount: 1, + }, + }, + }} + + t.Run("CreateWorkspaceWithTaskNormally", func(t *testing.T) { + // Creating a workspace that has agentic tasks, but is not launced via task + // should not count towards the usage. + t.Cleanup(usage.Reset) + version := coderdtest.CreateTemplateVersion(t, client, first.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionInit: echo.InitComplete, + ProvisionPlan: planWithTask, + ProvisionApply: echo.ApplyComplete, + ProvisionGraph: graphWithTask, + }) + _ = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, first.OrganizationID, version.ID) + wrk := coderdtest.CreateWorkspace(t, client, template.ID) + build := coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, wrk.LatestBuild.ID) + require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status) + require.Len(t, usage.GetEvents(), 0) + }) + + t.Run("CreateTaskWorkspace", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitMedium) + t.Cleanup(usage.Reset) + version := coderdtest.CreateTemplateVersion(t, client, first.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionInit: echo.InitComplete, + ProvisionPlan: planWithTask, + ProvisionApply: echo.ApplyComplete, + ProvisionGraph: graphWithTask, + }) + _ = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, first.OrganizationID, version.ID) + + task, err := client.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{ + TemplateVersionID: template.ActiveVersionID, + Name: "istask", + }) + require.NoError(t, err) + + wrk, err := client.Workspace(ctx, task.WorkspaceID.UUID) + require.NoError(t, err) + + build := coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, wrk.LatestBuild.ID) + require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status) + require.Len(t, usage.GetEvents(), 1) + + usage.Reset() // Clean slate for easy checks + // Stopping the workspace should not create additional usage. + build, err = client.CreateWorkspaceBuild(ctx, wrk.ID, codersdk.CreateWorkspaceBuildRequest{ + TemplateVersionID: wrk.LatestBuild.TemplateVersionID, + Transition: codersdk.WorkspaceTransitionStop, + }) + require.NoError(t, err) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, build.ID) + require.Len(t, usage.GetEvents(), 0) + + usage.Reset() // Clean slate for easy checks + // Starting the workspace manually **WILL** create usage, as it's + // still a task workspace. + build, err = client.CreateWorkspaceBuild(ctx, wrk.ID, codersdk.CreateWorkspaceBuildRequest{ + TemplateVersionID: wrk.LatestBuild.TemplateVersionID, + Transition: codersdk.WorkspaceTransitionStart, + }) + require.NoError(t, err) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, build.ID) + require.Len(t, usage.GetEvents(), 1) + }) +} +*/