diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 91722c141ade5..5ba613044f661 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -62,6 +62,10 @@ func IsNotAuthorizedError(err error) bool { if err == nil { return false } + if xerrors.Is(err, NoActorError) { + return true + } + return xerrors.As(err, &NotAuthorizedError{}) } @@ -1338,7 +1342,7 @@ func (q *querier) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) func (q *querier) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) { // Used by TemplateAppInsights endpoint // For auditors, check read template_insights, and fall back to update template. - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplateInsights); IsNotAuthorizedError(err) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplateInsights); err != nil { for _, templateID := range arg.TemplateIDs { template, err := q.db.GetTemplateByID(ctx, templateID) if err != nil { @@ -1393,7 +1397,7 @@ func (q *querier) GetTemplateDAUs(ctx context.Context, arg database.GetTemplateD func (q *querier) GetTemplateInsights(ctx context.Context, arg database.GetTemplateInsightsParams) (database.GetTemplateInsightsRow, error) { // Used by TemplateInsights endpoint // For auditors, check read template_insights, and fall back to update template. - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplateInsights); IsNotAuthorizedError(err) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplateInsights); err != nil { for _, templateID := range arg.TemplateIDs { template, err := q.db.GetTemplateByID(ctx, templateID) if err != nil { @@ -1416,7 +1420,7 @@ func (q *querier) GetTemplateInsights(ctx context.Context, arg database.GetTempl func (q *querier) GetTemplateInsightsByInterval(ctx context.Context, arg database.GetTemplateInsightsByIntervalParams) ([]database.GetTemplateInsightsByIntervalRow, error) { // Used by TemplateInsights endpoint // For auditors, check read template_insights, and fall back to update template. - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplateInsights); IsNotAuthorizedError(err) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplateInsights); err != nil { for _, templateID := range arg.TemplateIDs { template, err := q.db.GetTemplateByID(ctx, templateID) if err != nil { @@ -1447,7 +1451,7 @@ func (q *querier) GetTemplateInsightsByTemplate(ctx context.Context, arg databas func (q *querier) GetTemplateParameterInsights(ctx context.Context, arg database.GetTemplateParameterInsightsParams) ([]database.GetTemplateParameterInsightsRow, error) { // Used by both insights endpoint and prometheus collector. // For auditors, check read template_insights, and fall back to update template. - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplateInsights); IsNotAuthorizedError(err) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplateInsights); err != nil { for _, templateID := range arg.TemplateIDs { template, err := q.db.GetTemplateByID(ctx, templateID) if err != nil { @@ -1620,7 +1624,7 @@ func (q *querier) GetUnexpiredLicenses(ctx context.Context) ([]database.License, func (q *querier) GetUserActivityInsights(ctx context.Context, arg database.GetUserActivityInsightsParams) ([]database.GetUserActivityInsightsRow, error) { // Used by insights endpoints. Need to check both for auditors and for regular users with template acl perms. - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplateInsights); IsNotAuthorizedError(err) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplateInsights); err != nil { for _, templateID := range arg.TemplateIDs { template, err := q.db.GetTemplateByID(ctx, templateID) if err != nil { @@ -1657,7 +1661,7 @@ func (q *querier) GetUserCount(ctx context.Context) (int64, error) { func (q *querier) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) { // Used by insights endpoints. Need to check both for auditors and for regular users with template acl perms. - if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplateInsights); IsNotAuthorizedError(err) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplateInsights); err != nil { for _, templateID := range arg.TemplateIDs { template, err := q.db.GetTemplateByID(ctx, templateID) if err != nil { @@ -2266,10 +2270,12 @@ func (q *querier) InsertWorkspaceAgent(ctx context.Context, arg database.InsertW } func (q *querier) InsertWorkspaceAgentLogSources(ctx context.Context, arg database.InsertWorkspaceAgentLogSourcesParams) ([]database.WorkspaceAgentLogSource, error) { + // TODO: This is used by the agent, should we have an rbac check here? return q.db.InsertWorkspaceAgentLogSources(ctx, arg) } func (q *querier) InsertWorkspaceAgentLogs(ctx context.Context, arg database.InsertWorkspaceAgentLogsParams) ([]database.WorkspaceAgentLog, error) { + // TODO: This is used by the agent, should we have an rbac check here? return q.db.InsertWorkspaceAgentLogs(ctx, arg) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 3e42ec46ac2fd..5f40fe936cb63 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -225,6 +225,26 @@ func (s *MethodTestSuite) TestAPIKey() { ID: a.ID, }).Asserts(a, rbac.ActionUpdate).Returns() })) + s.Run("DeleteApplicationConnectAPIKeysByUserID", s.Subtest(func(db database.Store, check *expects) { + a, _ := dbgen.APIKey(s.T(), db, database.APIKey{ + Scope: database.APIKeyScopeApplicationConnect, + }) + check.Args(a.UserID).Asserts(rbac.ResourceAPIKey.WithOwner(a.UserID.String()), rbac.ActionDelete).Returns() + })) + s.Run("DeleteExternalAuthLink", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.ExternalAuthLink(s.T(), db, database.ExternalAuthLink{}) + check.Args(database.DeleteExternalAuthLinkParams{ + ProviderID: a.ProviderID, + UserID: a.UserID, + }).Asserts(a, rbac.ActionDelete).Returns() + })) + s.Run("GetExternalAuthLinksByUserID", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.ExternalAuthLink(s.T(), db, database.ExternalAuthLink{}) + b := dbgen.ExternalAuthLink(s.T(), db, database.ExternalAuthLink{ + UserID: a.UserID, + }) + check.Args(a.UserID).Asserts(a, rbac.ActionRead, b, rbac.ActionRead) + })) } func (s *MethodTestSuite) TestAuditLogs() { @@ -645,6 +665,10 @@ func (s *MethodTestSuite) TestWorkspaceProxy() { p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) check.Args(p.ID).Asserts(p, rbac.ActionRead).Returns(p) })) + s.Run("GetWorkspaceProxyByName", s.Subtest(func(db database.Store, check *expects) { + p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) + check.Args(p.Name).Asserts(p, rbac.ActionRead).Returns(p) + })) s.Run("UpdateWorkspaceProxyDeleted", s.Subtest(func(db database.Store, check *expects) { p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) check.Args(database.UpdateWorkspaceProxyDeletedParams{ @@ -652,6 +676,12 @@ func (s *MethodTestSuite) TestWorkspaceProxy() { Deleted: true, }).Asserts(p, rbac.ActionDelete) })) + s.Run("UpdateWorkspaceProxy", s.Subtest(func(db database.Store, check *expects) { + p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) + check.Args(database.UpdateWorkspaceProxyParams{ + ID: p.ID, + }).Asserts(p, rbac.ActionUpdate) + })) s.Run("GetWorkspaceProxies", s.Subtest(func(db database.Store, check *expects) { p1, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) p2, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{}) @@ -813,6 +843,30 @@ func (s *MethodTestSuite) TestTemplate() { ID: t1.ID, }).Asserts(t1, rbac.ActionCreate) })) + s.Run("UpdateTemplateAccessControlByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(database.UpdateTemplateAccessControlByIDParams{ + ID: t1.ID, + }).Asserts(t1, rbac.ActionUpdate) + })) + s.Run("UpdateTemplateScheduleByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(database.UpdateTemplateScheduleByIDParams{ + ID: t1.ID, + }).Asserts(t1, rbac.ActionUpdate) + })) + s.Run("UpdateTemplateWorkspacesLastUsedAt", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(database.UpdateTemplateWorkspacesLastUsedAtParams{ + TemplateID: t1.ID, + }).Asserts(t1, rbac.ActionUpdate) + })) + s.Run("UpdateWorkspacesDormantDeletingAtByTemplateID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(database.UpdateWorkspacesDormantDeletingAtByTemplateIDParams{ + TemplateID: t1.ID, + }).Asserts(t1, rbac.ActionUpdate) + })) s.Run("UpdateTemplateActiveVersionByID", s.Subtest(func(db database.Store, check *expects) { t1 := dbgen.Template(s.T(), db, database.Template{ ActiveVersionID: uuid.New(), @@ -875,9 +929,39 @@ func (s *MethodTestSuite) TestTemplate() { ExternalAuthProviders: []string{}, }).Asserts(t1, rbac.ActionUpdate).Returns() })) + s.Run("GetTemplateInsights", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.GetTemplateInsightsParams{}).Asserts(rbac.ResourceTemplateInsights, rbac.ActionRead) + })) + s.Run("GetUserLatencyInsights", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.GetUserLatencyInsightsParams{}).Asserts(rbac.ResourceTemplateInsights, rbac.ActionRead) + })) + s.Run("GetUserActivityInsights", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.GetUserActivityInsightsParams{}).Asserts(rbac.ResourceTemplateInsights, rbac.ActionRead) + })) + s.Run("GetTemplateParameterInsights", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.GetTemplateParameterInsightsParams{}).Asserts(rbac.ResourceTemplateInsights, rbac.ActionRead) + })) + s.Run("GetTemplateInsightsByInterval", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.GetTemplateInsightsByIntervalParams{}).Asserts(rbac.ResourceTemplateInsights, rbac.ActionRead) + })) + s.Run("GetTemplateInsightsByTemplate", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.GetTemplateInsightsByTemplateParams{}).Asserts(rbac.ResourceTemplateInsights, rbac.ActionRead) + })) + s.Run("GetTemplateAppInsights", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.GetTemplateAppInsightsParams{}).Asserts(rbac.ResourceTemplateInsights, rbac.ActionRead) + })) + s.Run("GetTemplateAppInsightsByTemplate", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.GetTemplateAppInsightsByTemplateParams{}).Asserts(rbac.ResourceTemplateInsights, rbac.ActionRead) + })) } func (s *MethodTestSuite) TestUser() { + s.Run("GetAuthorizedUsers", s.Subtest(func(db database.Store, check *expects) { + dbgen.User(s.T(), db, database.User{}) + // No asserts because SQLFilter. + check.Args(database.GetUsersParams{}, emptyPreparedAuthorized{}). + Asserts() + })) s.Run("DeleteAPIKeysByUserID", s.Subtest(func(db database.Store, check *expects) { u := dbgen.User(s.T(), db, database.User{}) check.Args(u.ID).Asserts(rbac.ResourceAPIKey.WithOwner(u.ID.String()), rbac.ActionDelete).Returns() @@ -945,6 +1029,12 @@ func (s *MethodTestSuite) TestUser() { ID: u.ID, }).Asserts(u.UserDataRBACObject(), rbac.ActionUpdate).Returns() })) + s.Run("UpdateUserQuietHoursSchedule", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.UpdateUserQuietHoursScheduleParams{ + ID: u.ID, + }).Asserts(u.UserDataRBACObject(), rbac.ActionUpdate) + })) s.Run("UpdateUserLastSeenAt", s.Subtest(func(db database.Store, check *expects) { u := dbgen.User(s.T(), db, database.User{}) check.Args(database.UpdateUserLastSeenAtParams{ @@ -1048,6 +1138,11 @@ func (s *MethodTestSuite) TestUser() { rbac.ResourceRoleAssignment, rbac.ActionDelete, ).Returns(o) })) + s.Run("AllUserIDs", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.User(s.T(), db, database.User{}) + b := dbgen.User(s.T(), db, database.User{}) + check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead).Returns(slice.New(a.ID, b.ID)) + })) } func (s *MethodTestSuite) TestWorkspace() { @@ -1082,6 +1177,34 @@ func (s *MethodTestSuite) TestWorkspace() { agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) check.Args(agt.ID).Asserts(ws, rbac.ActionRead).Returns(agt) })) + s.Run("GetWorkspaceAgentLifecycleStateByID", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{}) + ws := dbgen.Workspace(s.T(), db, database.Workspace{ + TemplateID: tpl.ID, + }) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(agt.ID).Asserts(ws, rbac.ActionRead) + })) + s.Run("GetWorkspaceAgentMetadata", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{}) + ws := dbgen.Workspace(s.T(), db, database.Workspace{ + TemplateID: tpl.ID, + }) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + _ = db.InsertWorkspaceAgentMetadata(context.Background(), database.InsertWorkspaceAgentMetadataParams{ + WorkspaceAgentID: agt.ID, + DisplayName: "test", + Key: "test", + }) + check.Args(database.GetWorkspaceAgentMetadataParams{ + WorkspaceAgentID: agt.ID, + Keys: []string{"test"}, + }).Asserts(ws, rbac.ActionRead) + })) s.Run("GetWorkspaceAgentByInstanceID", s.Subtest(func(db database.Store, check *expects) { tpl := dbgen.Template(s.T(), db, database.Template{}) ws := dbgen.Workspace(s.T(), db, database.Workspace{ @@ -1105,6 +1228,18 @@ func (s *MethodTestSuite) TestWorkspace() { LifecycleState: database.WorkspaceAgentLifecycleStateCreated, }).Asserts(ws, rbac.ActionUpdate).Returns() })) + s.Run("UpdateWorkspaceAgentMetadata", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{}) + ws := dbgen.Workspace(s.T(), db, database.Workspace{ + TemplateID: tpl.ID, + }) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(database.UpdateWorkspaceAgentMetadataParams{ + WorkspaceAgentID: agt.ID, + }).Asserts(ws, rbac.ActionUpdate).Returns() + })) s.Run("UpdateWorkspaceAgentLogOverflowByID", s.Subtest(func(db database.Store, check *expects) { tpl := dbgen.Template(s.T(), db, database.Template{}) ws := dbgen.Workspace(s.T(), db, database.Workspace{ @@ -1217,6 +1352,16 @@ func (s *MethodTestSuite) TestWorkspace() { TemplateName: tpl.Name, }) })) + s.Run("GetWorkspaceAgentsInLatestBuildByWorkspaceID", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{}) + ws := dbgen.Workspace(s.T(), db, database.Workspace{ + TemplateID: tpl.ID, + }) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(ws.ID).Asserts(ws, rbac.ActionRead) + })) s.Run("GetWorkspaceByOwnerIDAndName", s.Subtest(func(db database.Store, check *expects) { ws := dbgen.Workspace(s.T(), db, database.Workspace{}) check.Args(database.GetWorkspaceByOwnerIDAndNameParams{ @@ -1341,6 +1486,19 @@ func (s *MethodTestSuite) TestWorkspace() { ID: w.ID, }).Asserts(w, rbac.ActionUpdate).Returns(expected) })) + s.Run("UpdateWorkspaceDormantDeletingAt", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.UpdateWorkspaceDormantDeletingAtParams{ + ID: w.ID, + }).Asserts(w, rbac.ActionUpdate) + })) + s.Run("UpdateWorkspaceAutomaticUpdates", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.UpdateWorkspaceAutomaticUpdatesParams{ + ID: w.ID, + AutomaticUpdates: database.AutomaticUpdatesAlways, + }).Asserts(w, rbac.ActionUpdate) + })) s.Run("InsertWorkspaceAgentStat", s.Subtest(func(db database.Store, check *expects) { ws := dbgen.Workspace(s.T(), db, database.Workspace{}) check.Args(database.InsertWorkspaceAgentStatParams{ @@ -1405,6 +1563,14 @@ func (s *MethodTestSuite) TestWorkspace() { app := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID}) check.Args(app.ID).Asserts(ws, rbac.ActionRead).Returns(ws) })) + s.Run("ActivityBumpWorkspace", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild}) + check.Args(database.ActivityBumpWorkspaceParams{ + WorkspaceID: ws.ID, + }).Asserts(ws, rbac.ActionUpdate).Returns() + })) } func (s *MethodTestSuite) TestExtraMethods() { @@ -1417,6 +1583,174 @@ func (s *MethodTestSuite) TestExtraMethods() { s.NoError(err, "insert provisioner daemon") check.Args().Asserts(d, rbac.ActionRead) })) + s.Run("DeleteOldProvisionerDaemons", s.Subtest(func(db database.Store, check *expects) { + _, err := db.UpsertProvisionerDaemon(context.Background(), database.UpsertProvisionerDaemonParams{ + Tags: database.StringMap(map[string]string{ + provisionersdk.TagScope: provisionersdk.ScopeOrganization, + }), + }) + s.NoError(err, "insert provisioner daemon") + check.Args().Asserts(rbac.ResourceSystem, rbac.ActionDelete) + })) +} + +// All functions in this method test suite are not implemented in dbmem, but +// we still want to assert RBAC checks. +func (s *MethodTestSuite) TestTailnetFunctions() { + s.Run("CleanTailnetCoordinators", s.Subtest(func(db database.Store, check *expects) { + check.Args(). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionDelete). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("CleanTailnetLostPeers", s.Subtest(func(db database.Store, check *expects) { + check.Args(). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionDelete). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("CleanTailnetTunnels", s.Subtest(func(db database.Store, check *expects) { + check.Args(). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionDelete). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("DeleteAllTailnetClientSubscriptions", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.DeleteAllTailnetClientSubscriptionsParams{}). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionDelete). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("DeleteAllTailnetTunnels", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.DeleteAllTailnetTunnelsParams{}). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionDelete). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("DeleteCoordinator", s.Subtest(func(db database.Store, check *expects) { + check.Args(uuid.New()). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionDelete). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("DeleteTailnetAgent", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.DeleteTailnetAgentParams{}). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionUpdate). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("DeleteTailnetClient", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.DeleteTailnetClientParams{}). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionDelete). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("DeleteTailnetClientSubscription", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.DeleteTailnetClientSubscriptionParams{}). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionDelete). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("DeleteTailnetPeer", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.DeleteTailnetPeerParams{}). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionDelete). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("DeleteTailnetTunnel", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.DeleteTailnetTunnelParams{}). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionDelete). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("GetAllTailnetAgents", s.Subtest(func(db database.Store, check *expects) { + check.Args(). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionRead). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("GetTailnetAgents", s.Subtest(func(db database.Store, check *expects) { + check.Args(uuid.New()). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionRead). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("GetTailnetClientsForAgent", s.Subtest(func(db database.Store, check *expects) { + check.Args(uuid.New()). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionRead). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("GetTailnetPeers", s.Subtest(func(db database.Store, check *expects) { + check.Args(uuid.New()). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionRead). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("GetTailnetTunnelPeerBindings", s.Subtest(func(db database.Store, check *expects) { + check.Args(uuid.New()). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionRead). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("GetTailnetTunnelPeerIDs", s.Subtest(func(db database.Store, check *expects) { + check.Args(uuid.New()). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionRead). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("GetAllTailnetCoordinators", s.Subtest(func(db database.Store, check *expects) { + check.Args(). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionRead). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("GetAllTailnetPeers", s.Subtest(func(db database.Store, check *expects) { + check.Args(). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionRead). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("GetAllTailnetTunnels", s.Subtest(func(db database.Store, check *expects) { + check.Args(). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionRead). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("UpsertTailnetAgent", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.UpsertTailnetAgentParams{}). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionUpdate). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("UpsertTailnetClient", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.UpsertTailnetClientParams{}). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionUpdate). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("UpsertTailnetClientSubscription", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.UpsertTailnetClientSubscriptionParams{}). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionUpdate). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("UpsertTailnetCoordinator", s.Subtest(func(db database.Store, check *expects) { + check.Args(uuid.New()). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionUpdate). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("UpsertTailnetPeer", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.UpsertTailnetPeerParams{ + Status: database.TailnetStatusOk, + }). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionCreate). + Errors(dbmem.ErrUnimplemented) + })) + s.Run("UpsertTailnetTunnel", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.UpsertTailnetTunnelParams{}). + Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionCreate). + Errors(dbmem.ErrUnimplemented) + })) +} + +func (s *MethodTestSuite) TestDBCrypt() { + s.Run("GetDBCryptKeys", s.Subtest(func(db database.Store, check *expects) { + check.Args(). + Asserts(rbac.ResourceSystem, rbac.ActionRead). + Returns([]database.DBCryptKey{}) + })) + s.Run("InsertDBCryptKey", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertDBCryptKeyParams{}). + Asserts(rbac.ResourceSystem, rbac.ActionCreate). + Returns() + })) + s.Run("RevokeDBCryptKey", s.Subtest(func(db database.Store, check *expects) { + err := db.InsertDBCryptKey(context.Background(), database.InsertDBCryptKeyParams{ + ActiveKeyDigest: "revoke me", + }) + s.NoError(err) + check.Args("revoke me"). + Asserts(rbac.ResourceSystem, rbac.ActionUpdate). + Returns() + })) } func (s *MethodTestSuite) TestSystemFunctions() { @@ -1571,6 +1905,15 @@ func (s *MethodTestSuite) TestSystemFunctions() { Asserts(rbac.ResourceSystem, rbac.ActionRead). Returns(slice.New(tv1, tv2, tv3)) })) + s.Run("GetParameterSchemasByJobID", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{}) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, + }) + job := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ID: tv.JobID}) + check.Args(job.ID). + Asserts(tpl, rbac.ActionRead).Errors(sql.ErrNoRows) + })) s.Run("GetWorkspaceAppsByAgentIDs", s.Subtest(func(db database.Store, check *expects) { aWs := dbgen.Workspace(s.T(), db, database.Workspace{}) aBuild := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: aWs.ID, JobID: uuid.New()}) @@ -1718,4 +2061,130 @@ func (s *MethodTestSuite) TestSystemFunctions() { Transition: database.WorkspaceTransitionStart, }).Asserts(rbac.ResourceSystem, rbac.ActionCreate) })) + s.Run("DeleteOldWorkspaceAgentLogs", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts(rbac.ResourceSystem, rbac.ActionDelete) + })) + s.Run("InsertWorkspaceAgentStats", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertWorkspaceAgentStatsParams{}).Asserts(rbac.ResourceSystem, rbac.ActionCreate).Errors(errMatchAny) + })) + s.Run("InsertWorkspaceAppStats", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertWorkspaceAppStatsParams{}).Asserts(rbac.ResourceSystem, rbac.ActionCreate) + })) + s.Run("InsertWorkspaceAgentScripts", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertWorkspaceAgentScriptsParams{}).Asserts(rbac.ResourceSystem, rbac.ActionCreate) + })) + s.Run("InsertWorkspaceAgentMetadata", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertWorkspaceAgentMetadataParams{}).Asserts(rbac.ResourceSystem, rbac.ActionCreate) + })) + s.Run("InsertWorkspaceAgentLogs", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertWorkspaceAgentLogsParams{}).Asserts() + })) + s.Run("InsertWorkspaceAgentLogSources", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertWorkspaceAgentLogSourcesParams{}).Asserts() + })) + s.Run("GetTemplateDAUs", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.GetTemplateDAUsParams{}).Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetActiveWorkspaceBuildsByTemplateID", s.Subtest(func(db database.Store, check *expects) { + check.Args(uuid.New()).Asserts(rbac.ResourceSystem, rbac.ActionRead).Errors(sql.ErrNoRows) + })) + s.Run("GetDeploymentDAUs", s.Subtest(func(db database.Store, check *expects) { + check.Args(int32(0)).Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetAppSecurityKey", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts() + })) + s.Run("UpsertAppSecurityKey", s.Subtest(func(db database.Store, check *expects) { + check.Args("").Asserts() + })) + s.Run("GetApplicationName", s.Subtest(func(db database.Store, check *expects) { + db.UpsertApplicationName(context.Background(), "foo") + check.Args().Asserts() + })) + s.Run("UpsertApplicationName", s.Subtest(func(db database.Store, check *expects) { + check.Args("").Asserts(rbac.ResourceDeploymentValues, rbac.ActionCreate) + })) + s.Run("GetHealthSettings", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts() + })) + s.Run("UpsertHealthSettings", s.Subtest(func(db database.Store, check *expects) { + check.Args("foo").Asserts(rbac.ResourceDeploymentValues, rbac.ActionCreate) + })) + s.Run("GetDeploymentWorkspaceAgentStats", s.Subtest(func(db database.Store, check *expects) { + check.Args(time.Time{}).Asserts() + })) + s.Run("GetDeploymentWorkspaceStats", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts() + })) + s.Run("GetFileTemplates", s.Subtest(func(db database.Store, check *expects) { + check.Args(uuid.New()).Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetHungProvisionerJobs", s.Subtest(func(db database.Store, check *expects) { + check.Args(time.Time{}).Asserts() + })) + s.Run("UpsertOAuthSigningKey", s.Subtest(func(db database.Store, check *expects) { + check.Args("foo").Asserts(rbac.ResourceSystem, rbac.ActionUpdate) + })) + s.Run("GetOAuthSigningKey", s.Subtest(func(db database.Store, check *expects) { + db.UpsertOAuthSigningKey(context.Background(), "foo") + check.Args().Asserts(rbac.ResourceSystem, rbac.ActionUpdate) + })) + s.Run("InsertMissingGroups", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertMissingGroupsParams{}).Asserts(rbac.ResourceSystem, rbac.ActionCreate).Errors(errMatchAny) + })) + s.Run("UpdateUserLoginType", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.UpdateUserLoginTypeParams{ + NewLoginType: database.LoginTypePassword, + UserID: u.ID, + }).Asserts(rbac.ResourceSystem, rbac.ActionUpdate) + })) + s.Run("GetWorkspaceAgentStatsAndLabels", s.Subtest(func(db database.Store, check *expects) { + check.Args(time.Time{}).Asserts() + })) + s.Run("GetWorkspaceAgentStats", s.Subtest(func(db database.Store, check *expects) { + check.Args(time.Time{}).Asserts() + })) + s.Run("GetWorkspaceProxyByHostname", s.Subtest(func(db database.Store, check *expects) { + p, _ := dbgen.WorkspaceProxy(s.T(), db, database.WorkspaceProxy{ + WildcardHostname: "*.example.com", + }) + check.Args(database.GetWorkspaceProxyByHostnameParams{ + Hostname: "foo.example.com", + AllowWildcardHostname: true, + }).Asserts(rbac.ResourceSystem, rbac.ActionRead).Returns(p) + })) + s.Run("GetTemplateAverageBuildTime", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.GetTemplateAverageBuildTimeParams{}).Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetWorkspacesEligibleForTransition", s.Subtest(func(db database.Store, check *expects) { + check.Args(time.Time{}).Asserts() + })) + s.Run("InsertTemplateVersionVariable", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertTemplateVersionVariableParams{}).Asserts(rbac.ResourceSystem, rbac.ActionCreate) + })) + s.Run("UpdateInactiveUsersToDormant", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.UpdateInactiveUsersToDormantParams{}).Asserts(rbac.ResourceSystem, rbac.ActionCreate).Errors(sql.ErrNoRows) + })) + s.Run("GetWorkspaceUniqueOwnerCountByTemplateIDs", s.Subtest(func(db database.Store, check *expects) { + check.Args([]uuid.UUID{uuid.New()}).Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetWorkspaceAgentScriptsByAgentIDs", s.Subtest(func(db database.Store, check *expects) { + check.Args([]uuid.UUID{uuid.New()}).Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetWorkspaceAgentLogSourcesByAgentIDs", s.Subtest(func(db database.Store, check *expects) { + check.Args([]uuid.UUID{uuid.New()}).Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) + s.Run("GetProvisionerJobsByIDsWithQueuePosition", s.Subtest(func(db database.Store, check *expects) { + check.Args([]uuid.UUID{}).Asserts() + })) + s.Run("GetReplicaByID", s.Subtest(func(db database.Store, check *expects) { + check.Args(uuid.New()).Asserts(rbac.ResourceSystem, rbac.ActionRead).Errors(sql.ErrNoRows) + })) + s.Run("GetWorkspaceAgentAndOwnerByAuthToken", s.Subtest(func(db database.Store, check *expects) { + check.Args(uuid.New()).Asserts(rbac.ResourceSystem, rbac.ActionRead).Errors(sql.ErrNoRows) + })) + s.Run("GetUserLinksByUserID", s.Subtest(func(db database.Store, check *expects) { + check.Args(uuid.New()).Asserts(rbac.ResourceSystem, rbac.ActionRead) + })) } diff --git a/coderd/database/dbauthz/setup_test.go b/coderd/database/dbauthz/setup_test.go index 33eccfe09ff3b..3c54d8be4e345 100644 --- a/coderd/database/dbauthz/setup_test.go +++ b/coderd/database/dbauthz/setup_test.go @@ -2,6 +2,7 @@ package dbauthz_test import ( "context" + "errors" "fmt" "reflect" "sort" @@ -27,10 +28,14 @@ import ( "github.com/coder/coder/v2/coderd/util/slice" ) +var errMatchAny = errors.New("match any error") + var skipMethods = map[string]string{ - "InTx": "Not relevant", - "Ping": "Not relevant", - "Wrappers": "Not relevant", + "InTx": "Not relevant", + "Ping": "Not relevant", + "Wrappers": "Not relevant", + "AcquireLock": "Not relevant", + "TryAcquireLock": "Not relevant", } // TestMethodTestSuite runs MethodTestSuite. @@ -62,7 +67,8 @@ func (s *MethodTestSuite) SetupSuite() { mockStore.EXPECT().Wrappers().Return([]string{}).AnyTimes() az := dbauthz.New(mockStore, nil, slog.Make(), coderdtest.AccessControlStorePointer()) // Take the underlying type of the interface. - azt := reflect.TypeOf(az).Elem() + azt := reflect.TypeOf(az) + require.Greater(s.T(), azt.NumMethod(), 0, "no methods found on querier") s.methodAccounting = make(map[string]int) for i := 0; i < azt.NumMethod(); i++ { method := azt.Method(i) @@ -168,7 +174,16 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expec fakeAuthorizer.AlwaysReturn = nil outputs, err := callMethod(ctx) - s.NoError(err, "method %q returned an error", methodName) + if testCase.err == nil { + s.NoError(err, "method %q returned an error", methodName) + } else { + if errors.Is(testCase.err, errMatchAny) { + // This means we do not care exactly what the error is. + s.Error(err, "method %q returned an error", methodName) + } else { + s.EqualError(err, testCase.err.Error(), "method %q returned an unexpected error", methodName) + } + } // Some tests may not care about the outputs, so we only assert if // they are provided. @@ -289,6 +304,7 @@ type expects struct { assertions []AssertRBAC // outputs is optional. Can assert non-error return values. outputs []reflect.Value + err error } // Asserts is required. Asserts the RBAC authorize calls that should be made. @@ -313,6 +329,12 @@ func (m *expects) Returns(rets ...any) *expects { return m } +// Errors is optional. If it is never called, it will not be asserted. +func (m *expects) Errors(err error) *expects { + m.err = err + return m +} + // AssertRBAC contains the object and actions to be asserted. type AssertRBAC struct { Object rbac.Object diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 2e85bda1ffcbd..95e5528dbba0a 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -880,7 +880,7 @@ func (q *FakeQuerier) AllUserIDs(_ context.Context) ([]uuid.UUID, error) { defer q.mutex.RUnlock() userIDs := make([]uuid.UUID, 0, len(q.users)) for idx := range q.users { - userIDs[idx] = q.users[idx].ID + userIDs = append(userIDs, q.users[idx].ID) } return userIDs, nil }