diff --git a/coderd/activitybump.go b/coderd/activitybump.go index 059655ed8f33e..63cfacb528c2f 100644 --- a/coderd/activitybump.go +++ b/coderd/activitybump.go @@ -15,10 +15,10 @@ import ( // activityBumpWorkspace automatically bumps the workspace's auto-off timer // if it is set to expire soon. -func activityBumpWorkspace(log slog.Logger, db database.Store, workspaceID uuid.UUID) { +func activityBumpWorkspace(ctx context.Context, log slog.Logger, db database.Store, workspaceID uuid.UUID) { // We set a short timeout so if the app is under load, these // low priority operations fail first. - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + ctx, cancel := context.WithTimeout(ctx, time.Second*15) defer cancel() err := db.InTx(func(s database.Store) error { @@ -82,9 +82,12 @@ func activityBumpWorkspace(log slog.Logger, db database.Store, workspaceID uuid. return nil }, nil) if err != nil { - log.Error(ctx, "bump failed", slog.Error(err), - slog.F("workspace_id", workspaceID), - ) + if !xerrors.Is(err, context.Canceled) { + // Bump will fail if the context is cancelled, but this is ok. + log.Error(ctx, "bump failed", slog.Error(err), + slog.F("workspace_id", workspaceID), + ) + } return } diff --git a/coderd/authorize.go b/coderd/authorize.go index ab1f3a39fd542..d75cb043bbea9 100644 --- a/coderd/authorize.go +++ b/coderd/authorize.go @@ -51,6 +51,28 @@ type HTTPAuthorizer struct { // return // } func (api *API) Authorize(r *http.Request, action rbac.Action, object rbac.Objecter) bool { + // The experiment does not replace ALL rbac checks, but does replace most. + // This statement aborts early on the checks that will be removed in the + // future when this experiment is default. + if api.Experiments.Enabled(codersdk.ExperimentAuthzQuerier) { + // Some resource types do not interact with the persistent layer and + // we need to keep these checks happening in the API layer. + switch object.RBACObject().Type { + case rbac.ResourceWorkspaceExecution.Type: + // This is not a db resource, always in API layer + case rbac.ResourceDeploymentConfig.Type: + // For metric cache items like DAU, we do not hit the DB. + // Some db actions are in asserted in the authz layer. + case rbac.ResourceReplicas.Type: + // Replica rbac is checked for adding and removing replicas. + case rbac.ResourceProvisionerDaemon.Type: + // Provisioner rbac is checked for adding and removing provisioners. + case rbac.ResourceDebugInfo.Type: + // This is not a db resource, always in API layer. + default: + return true + } + } return api.HTTPAuth.Authorize(r, action, object) } diff --git a/coderd/autobuild/executor/lifecycle_executor.go b/coderd/autobuild/executor/lifecycle_executor.go index 40d5ccbdc6626..5af701de4b89d 100644 --- a/coderd/autobuild/executor/lifecycle_executor.go +++ b/coderd/autobuild/executor/lifecycle_executor.go @@ -12,6 +12,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/coderd/autobuild/schedule" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" ) // Executor automatically starts or stops workspaces. @@ -33,7 +34,8 @@ type Stats struct { // New returns a new autobuild executor. func New(ctx context.Context, db database.Store, log slog.Logger, tick <-chan time.Time) *Executor { le := &Executor{ - ctx: ctx, + //nolint:gocritic // TODO: make an autostart role instead of using System + ctx: dbauthz.AsSystem(ctx), db: db, tick: tick, log: log, diff --git a/coderd/coderd.go b/coderd/coderd.go index 319b65893c743..532eba43bf711 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -42,6 +42,7 @@ import ( "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/awsidentity" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/database/dbtype" "github.com/coder/coder/coderd/gitauth" "github.com/coder/coder/coderd/gitsshkey" @@ -157,13 +158,6 @@ func New(options *Options) *API { options = &Options{} } experiments := initExperiments(options.Logger, options.DeploymentConfig.Experiments.Value, options.DeploymentConfig.Experimental.Value) - // TODO: remove this once we promote authz_querier out of experiments. - if experiments.Enabled(codersdk.ExperimentAuthzQuerier) { - panic("Coming soon!") - // if _, ok := (options.Database).(*authzquery.AuthzQuerier); !ok { - // options.Database = authzquery.NewAuthzQuerier(options.Database, options.Authorizer) - // } - } if options.AppHostname != "" && options.AppHostnameRegex == nil || options.AppHostname == "" && options.AppHostnameRegex != nil { panic("coderd: both AppHostname and AppHostnameRegex must be set or unset") } @@ -204,6 +198,14 @@ func New(options *Options) *API { if options.Auditor == nil { options.Auditor = audit.NewNop() } + // TODO: remove this once we promote authz_querier out of experiments. + if experiments.Enabled(codersdk.ExperimentAuthzQuerier) { + options.Database = dbauthz.New( + options.Database, + options.Authorizer, + options.Logger.Named("authz_querier"), + ) + } if options.SetUserGroups == nil { options.SetUserGroups = func(context.Context, database.Store, uuid.UUID, []string) error { return nil } } @@ -304,8 +306,10 @@ func New(options *Options) *API { DisableSessionExpiryRefresh: options.DeploymentConfig.DisableSessionExpiryRefresh.Value, Optional: true, }), - httpmw.ExtractUserParam(api.Database, false), - httpmw.ExtractWorkspaceAndAgentParam(api.Database), + httpmw.AsAuthzSystem( + httpmw.ExtractUserParam(api.Database, false), + httpmw.ExtractWorkspaceAndAgentParam(api.Database), + ), ), // Build-Version is helpful for debugging. func(next http.Handler) http.Handler { @@ -332,11 +336,13 @@ func New(options *Options) *API { DisableSessionExpiryRefresh: options.DeploymentConfig.DisableSessionExpiryRefresh.Value, Optional: true, }), - // Redirect to the login page if the user tries to open an app with - // "me" as the username and they are not logged in. - httpmw.ExtractUserParam(api.Database, true), - // Extracts the from the url - httpmw.ExtractWorkspaceAndAgentParam(api.Database), + httpmw.AsAuthzSystem( + // Redirect to the login page if the user tries to open an app with + // "me" as the username and they are not logged in. + httpmw.ExtractUserParam(api.Database, true), + // Extracts the from the url + httpmw.ExtractWorkspaceAndAgentParam(api.Database), + ), ) r.HandleFunc("/*", api.workspaceAppsProxyPath) } diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index 275ac56cd1a2b..294ac80c08859 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -12,7 +12,6 @@ import ( "testing" "time" - "github.com/coder/coder/cryptorand" "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/moby/moby/pkg/namesgenerator" @@ -20,8 +19,9 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/xerrors" + "github.com/coder/coder/cryptorand" + "github.com/coder/coder/coderd" - "github.com/coder/coder/coderd/database/dbfake" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/coderd/rbac/regosql" "github.com/coder/coder/codersdk" @@ -30,12 +30,6 @@ import ( ) func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { - // For any route using SQL filters, we need to know if the database is an - // in memory fake. This is because the in memory fake does not use SQL, and - // still uses rego. So this boolean indicates how to assert the expected - // behavior. - _, isMemoryDB := a.api.Database.(dbfake.FakeDatabase) - // Some quick reused objects workspaceRBACObj := rbac.ResourceWorkspace.WithID(a.Workspace.ID).InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String()) workspaceExecObj := rbac.ResourceWorkspaceExecution.WithID(a.Workspace.ID).InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String()) @@ -269,16 +263,17 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { "POST:/api/v2/workspaces/{workspace}/builds": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, "POST:/api/v2/organizations/{organization}/templateversions": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, - // Endpoints that use the SQLQuery filter. + // For any route using SQL filters, we do not check authorization. + // This is because the in memory fake does not use SQL. "GET:/api/v2/workspaces/": { StatusCode: http.StatusOK, - NoAuthorize: !isMemoryDB, + NoAuthorize: true, AssertAction: rbac.ActionRead, AssertObject: rbac.ResourceWorkspace, }, "GET:/api/v2/organizations/{organization}/templates": { StatusCode: http.StatusOK, - NoAuthorize: !isMemoryDB, + NoAuthorize: true, AssertAction: rbac.ActionRead, AssertObject: rbac.ResourceTemplate, }, diff --git a/coderd/coderdtest/authorize_test.go b/coderd/coderdtest/authorize_test.go index 422c29ee63563..8ef2ef05d9b7b 100644 --- a/coderd/coderdtest/authorize_test.go +++ b/coderd/coderdtest/authorize_test.go @@ -2,15 +2,21 @@ package coderdtest_test import ( "context" + "os" + "strings" "testing" "github.com/stretchr/testify/require" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/codersdk" ) func TestAuthorizeAllEndpoints(t *testing.T) { + if strings.Contains(os.Getenv("CODER_EXPERIMENTS_TEST"), string(codersdk.ExperimentAuthzQuerier)) { + t.Skip("Skipping TestAuthorizeAllEndpoints for authz_querier experiment") + } t.Parallel() client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ // Required for any subdomain-based proxy tests to pass. diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 9324675ccdd81..3938c64fd5a85 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -35,6 +35,7 @@ import ( "github.com/golang-jwt/jwt" "github.com/google/uuid" "github.com/moby/moby/pkg/namesgenerator" + "github.com/prometheus/client_golang/prometheus" "github.com/spf13/afero" "github.com/spf13/pflag" "github.com/stretchr/testify/assert" @@ -58,6 +59,7 @@ import ( "github.com/coder/coder/coderd/autobuild/executor" "github.com/coder/coder/coderd/awsidentity" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/database/dbtestutil" "github.com/coder/coder/coderd/gitauth" "github.com/coder/coder/coderd/gitsshkey" @@ -179,12 +181,13 @@ func NewOptions(t *testing.T, options *Options) (func(http.Handler), context.Can options.Database, options.Pubsub = dbtestutil.NewDB(t) } // TODO: remove this once we're ready to enable authz querier by default. - if strings.Contains(os.Getenv("CODER_EXPERIMENTS_TEST"), "authz_querier") { - panic("Coming soon!") - // if options.Authorizer != nil { - // options.Authorizer = &RecordingAuthorizer{} - // } - // options.Database = authzquery.NewAuthzQuerier(options.Database, options.Authorizer) + if strings.Contains(os.Getenv("CODER_EXPERIMENTS_TEST"), string(codersdk.ExperimentAuthzQuerier)) { + if options.Authorizer == nil { + options.Authorizer = &RecordingAuthorizer{ + Wrapped: rbac.NewAuthorizer(prometheus.NewRegistry()), + } + } + options.Database = dbauthz.New(options.Database, options.Authorizer, slogtest.Make(t, nil).Leveled(slog.LevelDebug)) } if options.DeploymentConfig == nil { options.DeploymentConfig = DeploymentConfig(t) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go new file mode 100644 index 0000000000000..b3f80cd4a5468 --- /dev/null +++ b/coderd/database/dbauthz/dbauthz.go @@ -0,0 +1,387 @@ +package dbauthz + +import ( + "context" + "database/sql" + "fmt" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/open-policy-agent/opa/topdown" + + "cdr.dev/slog" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/rbac" +) + +var _ database.Store = (*querier)(nil) + +var ( + // NoActorError wraps ErrNoRows for the api to return a 404. This is the correct + // response when the user is not authorized. + NoActorError = xerrors.Errorf("no authorization actor in context: %w", sql.ErrNoRows) +) + +// NotAuthorizedError is a sentinel error that unwraps to sql.ErrNoRows. +// This allows the internal error to be read by the caller if needed. Otherwise +// it will be handled as a 404. +type NotAuthorizedError struct { + Err error +} + +func (e NotAuthorizedError) Error() string { + return fmt.Sprintf("unauthorized: %s", e.Err.Error()) +} + +// Unwrap will always unwrap to a sql.ErrNoRows so the API returns a 404. +// So 'errors.Is(err, sql.ErrNoRows)' will always be true. +func (NotAuthorizedError) Unwrap() error { + return sql.ErrNoRows +} + +func logNotAuthorizedError(ctx context.Context, logger slog.Logger, err error) error { + // Only log the errors if it is an UnauthorizedError error. + internalError := new(rbac.UnauthorizedError) + if err != nil && xerrors.As(err, &internalError) { + e := new(topdown.Error) + if xerrors.As(err, &e) || e.Code == topdown.CancelErr { + // For some reason rego changes a cancelled context to a topdown.CancelErr. We + // expect to check for cancelled context errors if the user cancels the request, + // so we should change the error to a context.Canceled error. + // + // NotAuthorizedError is == to sql.ErrNoRows, which is not correct + // if it's actually a cancelled context. + internalError.SetInternal(context.Canceled) + return internalError + } + logger.Debug(ctx, "unauthorized", + slog.F("internal", internalError.Internal()), + slog.F("input", internalError.Input()), + slog.Error(err), + ) + } + return NotAuthorizedError{ + Err: err, + } +} + +// querier is a wrapper around the database store that performs authorization +// checks before returning data. All querier methods expect an authorization +// subject present in the context. If no subject is present, most methods will +// fail. +// +// Use WithAuthorizeContext to set the authorization subject in the context for +// the common user case. +type querier struct { + db database.Store + auth rbac.Authorizer + log slog.Logger +} + +func New(db database.Store, authorizer rbac.Authorizer, logger slog.Logger) database.Store { + // If the underlying db store is already a querier, return it. + // Do not double wrap. + if _, ok := db.(*querier); ok { + return db + } + return &querier{ + db: db, + auth: authorizer, + log: logger, + } +} + +// authorizeContext is a helper function to authorize an action on an object. +func (q *querier) authorizeContext(ctx context.Context, action rbac.Action, object rbac.Objecter) error { + act, ok := ActorFromContext(ctx) + if !ok { + return NoActorError + } + + err := q.auth.Authorize(ctx, act, action, object.RBACObject()) + if err != nil { + return logNotAuthorizedError(ctx, q.log, err) + } + return nil +} + +type authContextKey struct{} + +// ActorFromContext returns the authorization subject from the context. +// All authentication flows should set the authorization subject in the context. +// If no actor is present, the function returns false. +func ActorFromContext(ctx context.Context) (rbac.Subject, bool) { + a, ok := ctx.Value(authContextKey{}).(rbac.Subject) + return a, ok +} + +// AsSystem returns a context with a system actor. This is used for internal +// system operations that are not tied to any particular actor. +// When you use this function, be sure to add a //nolint comment +// explaining why it is necessary. +// +// We trust you have received the usual lecture from the local System +// Administrator. It usually boils down to these three things: +// #1) Respect the privacy of others. +// #2) Think before you type. +// #3) With great power comes great responsibility. +func AsSystem(ctx context.Context) context.Context { + return context.WithValue(ctx, authContextKey{}, rbac.Subject{ + ID: uuid.Nil.String(), + Roles: rbac.Roles([]rbac.Role{ + { + Name: "system", + DisplayName: "System", + Site: []rbac.Permission{ + { + ResourceType: rbac.ResourceWildcard.Type, + Action: rbac.WildcardSymbol, + }, + }, + Org: map[string][]rbac.Permission{}, + User: []rbac.Permission{}, + }, + }), + Scope: rbac.ScopeAll, + }, + ) +} + +var AsRemoveActor = rbac.Subject{ + ID: "remove-actor", +} + +// As returns a context with the given actor stored in the context. +// This is used for cases where the actor touching the database is not the +// actor stored in the context. +// When you use this function, be sure to add a //nolint comment +// explaining why it is necessary. +func As(ctx context.Context, actor rbac.Subject) context.Context { + if actor.Equal(AsRemoveActor) { + // AsRemoveActor is a special case that is used to indicate that the actor + // should be removed from the context. + return context.WithValue(ctx, authContextKey{}, nil) + } + return context.WithValue(ctx, authContextKey{}, actor) +} + +// +// Generic functions used to implement the database.Store methods. +// + +// insert runs an rbac.ActionCreate on the rbac object argument before +// running the insertFunc. The insertFunc is expected to return the object that +// was inserted. +func insert[ + ObjectType any, + ArgumentType any, + Insert func(ctx context.Context, arg ArgumentType) (ObjectType, error), +]( + logger slog.Logger, + authorizer rbac.Authorizer, + object rbac.Objecter, + insertFunc Insert, +) Insert { + return func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) { + // Fetch the rbac subject + act, ok := ActorFromContext(ctx) + if !ok { + return empty, NoActorError + } + + // Authorize the action + err = authorizer.Authorize(ctx, act, rbac.ActionCreate, object.RBACObject()) + if err != nil { + return empty, logNotAuthorizedError(ctx, logger, err) + } + + // Insert the database object + return insertFunc(ctx, arg) + } +} + +func deleteQ[ + ObjectType rbac.Objecter, + ArgumentType any, + Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error), + Delete func(ctx context.Context, arg ArgumentType) error, +]( + logger slog.Logger, + authorizer rbac.Authorizer, + fetchFunc Fetch, + deleteFunc Delete, +) Delete { + return fetchAndExec(logger, authorizer, + rbac.ActionDelete, fetchFunc, deleteFunc) +} + +func updateWithReturn[ + ObjectType rbac.Objecter, + ArgumentType any, + Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error), + UpdateQuery func(ctx context.Context, arg ArgumentType) (ObjectType, error), +]( + logger slog.Logger, + authorizer rbac.Authorizer, + fetchFunc Fetch, + updateQuery UpdateQuery, +) UpdateQuery { + return fetchAndQuery(logger, authorizer, rbac.ActionUpdate, fetchFunc, updateQuery) +} + +func update[ + ObjectType rbac.Objecter, + ArgumentType any, + Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error), + Exec func(ctx context.Context, arg ArgumentType) error, +]( + logger slog.Logger, + authorizer rbac.Authorizer, + fetchFunc Fetch, + updateExec Exec, +) Exec { + return fetchAndExec(logger, authorizer, rbac.ActionUpdate, fetchFunc, updateExec) +} + +// fetch is a generic function that wraps a database +// query function (returns an object and an error) with authorization. The +// returned function has the same arguments as the database function. +// +// The database query function will **ALWAYS** hit the database, even if the +// user cannot read the resource. This is because the resource details are +// required to run a proper authorization check. +func fetch[ + ArgumentType any, + ObjectType rbac.Objecter, + DatabaseFunc func(ctx context.Context, arg ArgumentType) (ObjectType, error), +]( + logger slog.Logger, + authorizer rbac.Authorizer, + f DatabaseFunc, +) DatabaseFunc { + return func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) { + // Fetch the rbac subject + act, ok := ActorFromContext(ctx) + if !ok { + return empty, NoActorError + } + + // Fetch the database object + object, err := f(ctx, arg) + if err != nil { + return empty, xerrors.Errorf("fetch object: %w", err) + } + + // Authorize the action + err = authorizer.Authorize(ctx, act, rbac.ActionRead, object.RBACObject()) + if err != nil { + return empty, logNotAuthorizedError(ctx, logger, err) + } + + return object, nil + } +} + +// fetchAndExec uses fetchAndQuery but only returns the error. The naming comes +// from SQL 'exec' functions which only return an error. +// See fetchAndQuery for more information. +func fetchAndExec[ + ObjectType rbac.Objecter, + ArgumentType any, + Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error), + Exec func(ctx context.Context, arg ArgumentType) error, +]( + logger slog.Logger, + authorizer rbac.Authorizer, + action rbac.Action, + fetchFunc Fetch, + execFunc Exec, +) Exec { + f := fetchAndQuery(logger, authorizer, action, fetchFunc, func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) { + return empty, execFunc(ctx, arg) + }) + return func(ctx context.Context, arg ArgumentType) error { + _, err := f(ctx, arg) + return err + } +} + +// fetchAndQuery is a generic function that wraps a database fetch and query. +// A query has potential side effects in the database (update, delete, etc). +// The fetch is used to know which rbac object the action should be asserted on +// **before** the query runs. The returns from the fetch are only used to +// assert rbac. The final return of this function comes from the Query function. +func fetchAndQuery[ + ObjectType rbac.Objecter, + ArgumentType any, + Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error), + Query func(ctx context.Context, arg ArgumentType) (ObjectType, error), +]( + logger slog.Logger, + authorizer rbac.Authorizer, + action rbac.Action, + fetchFunc Fetch, + queryFunc Query, +) Query { + return func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) { + // Fetch the rbac subject + act, ok := ActorFromContext(ctx) + if !ok { + return empty, NoActorError + } + + // Fetch the database object + object, err := fetchFunc(ctx, arg) + if err != nil { + return empty, xerrors.Errorf("fetch object: %w", err) + } + + // Authorize the action + err = authorizer.Authorize(ctx, act, action, object.RBACObject()) + if err != nil { + return empty, logNotAuthorizedError(ctx, logger, err) + } + + return queryFunc(ctx, arg) + } +} + +// fetchWithPostFilter is like fetch, but works with lists of objects. +// SQL filters are much more optimal. +func fetchWithPostFilter[ + ArgumentType any, + ObjectType rbac.Objecter, + DatabaseFunc func(ctx context.Context, arg ArgumentType) ([]ObjectType, error), +]( + authorizer rbac.Authorizer, + f DatabaseFunc, +) DatabaseFunc { + return func(ctx context.Context, arg ArgumentType) (empty []ObjectType, err error) { + // Fetch the rbac subject + act, ok := ActorFromContext(ctx) + if !ok { + return empty, NoActorError + } + + // Fetch the database object + objects, err := f(ctx, arg) + if err != nil { + return nil, xerrors.Errorf("fetch object: %w", err) + } + + // Authorize the action + return rbac.Filter(ctx, authorizer, act, rbac.ActionRead, objects) + } +} + +// prepareSQLFilter is a helper function that prepares a SQL filter using the +// given authorization context. +func prepareSQLFilter(ctx context.Context, authorizer rbac.Authorizer, action rbac.Action, resourceType string) (rbac.PreparedAuthorized, error) { + act, ok := ActorFromContext(ctx) + if !ok { + return nil, NoActorError + } + + return authorizer.Prepare(ctx, act, action, resourceType) +} diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go new file mode 100644 index 0000000000000..ab4da817599db --- /dev/null +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -0,0 +1,151 @@ +package dbauthz_test + +import ( + "context" + "reflect" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" + "github.com/coder/coder/coderd/database/dbfake" + "github.com/coder/coder/coderd/database/dbgen" + "github.com/coder/coder/coderd/rbac" +) + +func TestAsNoActor(t *testing.T) { + t.Parallel() + + t.Run("AsRemoveActor", func(t *testing.T) { + t.Parallel() + _, ok := dbauthz.ActorFromContext(context.Background()) + require.False(t, ok, "no actor should be present") + }) + + t.Run("AsActor", func(t *testing.T) { + t.Parallel() + ctx := dbauthz.As(context.Background(), coderdtest.RandomRBACSubject()) + _, ok := dbauthz.ActorFromContext(ctx) + require.True(t, ok, "actor present") + }) + + t.Run("DeleteActor", func(t *testing.T) { + t.Parallel() + // First set an actor + ctx := dbauthz.As(context.Background(), coderdtest.RandomRBACSubject()) + _, ok := dbauthz.ActorFromContext(ctx) + require.True(t, ok, "actor present") + + // Delete the actor + ctx = dbauthz.As(ctx, dbauthz.AsRemoveActor) + _, ok = dbauthz.ActorFromContext(ctx) + require.False(t, ok, "actor should be deleted") + }) +} + +func TestPing(t *testing.T) { + t.Parallel() + + q := dbauthz.New(dbfake.New(), &coderdtest.RecordingAuthorizer{}, slog.Make()) + _, err := q.Ping(context.Background()) + require.NoError(t, err, "must not error") +} + +// TestInTX is not perfect, just checks that it properly checks auth. +func TestInTX(t *testing.T) { + t.Parallel() + + db := dbfake.New() + q := dbauthz.New(db, &coderdtest.RecordingAuthorizer{ + Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: xerrors.New("custom error")}, + }, slog.Make()) + actor := rbac.Subject{ + ID: uuid.NewString(), + Roles: rbac.RoleNames{rbac.RoleOwner()}, + Groups: []string{}, + Scope: rbac.ScopeAll, + } + + w := dbgen.Workspace(t, db, database.Workspace{}) + ctx := dbauthz.As(context.Background(), actor) + err := q.InTx(func(tx database.Store) error { + // The inner tx should use the parent's authz + _, err := tx.GetWorkspaceByID(ctx, w.ID) + return err + }, nil) + require.Error(t, err, "must error") + require.ErrorAs(t, err, &dbauthz.NotAuthorizedError{}, "must be an authorized error") +} + +// TestNew should not double wrap a querier. +func TestNew(t *testing.T) { + t.Parallel() + + var ( + db = dbfake.New() + exp = dbgen.Workspace(t, db, database.Workspace{}) + rec = &coderdtest.RecordingAuthorizer{ + Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil}, + } + subj = rbac.Subject{} + ctx = dbauthz.As(context.Background(), rbac.Subject{}) + ) + + // Double wrap should not cause an actual double wrap. So only 1 rbac call + // should be made. + az := dbauthz.New(db, rec, slog.Make()) + az = dbauthz.New(az, rec, slog.Make()) + + w, err := az.GetWorkspaceByID(ctx, exp.ID) + require.NoError(t, err, "must not error") + require.Equal(t, exp, w, "must be equal") + + rec.AssertActor(t, subj, rec.Pair(rbac.ActionRead, exp)) + require.NoError(t, rec.AllAsserted(), "should only be 1 rbac call") +} + +// TestDBAuthzRecursive is a simple test to search for infinite recursion +// bugs. It isn't perfect, and only catches a subset of the possible bugs +// as only the first db call will be made. But it is better than nothing. +func TestDBAuthzRecursive(t *testing.T) { + t.Parallel() + q := dbauthz.New(dbfake.New(), &coderdtest.RecordingAuthorizer{ + Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil}, + }, slog.Make()) + actor := rbac.Subject{ + ID: uuid.NewString(), + Roles: rbac.RoleNames{rbac.RoleOwner()}, + Groups: []string{}, + Scope: rbac.ScopeAll, + } + for i := 0; i < reflect.TypeOf(q).NumMethod(); i++ { + var ins []reflect.Value + ctx := dbauthz.As(context.Background(), actor) + + ins = append(ins, reflect.ValueOf(ctx)) + method := reflect.TypeOf(q).Method(i) + for i := 2; i < method.Type.NumIn(); i++ { + ins = append(ins, reflect.New(method.Type.In(i)).Elem()) + } + if method.Name == "InTx" || method.Name == "Ping" { + continue + } + // Log the name of the last method, so if there is a panic, it is + // easy to know which method failed. + // t.Log(method.Name) + // Call the function. Any infinite recursion will stack overflow. + reflect.ValueOf(q).Method(i).Call(ins) + } +} + +func must[T any](value T, err error) T { + if err != nil { + panic(err) + } + return value +} diff --git a/coderd/database/dbauthz/doc.go b/coderd/database/dbauthz/doc.go new file mode 100644 index 0000000000000..31af28bb951ef --- /dev/null +++ b/coderd/database/dbauthz/doc.go @@ -0,0 +1,17 @@ +// Package dbauthz provides an authorization layer on top of the database. This +// package exposes an interface that is currently a 1:1 mapping with +// database.Store. +// +// The same cultural rules apply to this package as they do to database.Store. +// Meaning that each method implemented should keep the number of database +// queries as close to 1 as possible. Each method should do 1 thing, with no +// unexpected side effects (eg: updating multiple tables in a single method). +// +// Do not implement business logic in this package. Only authorization related +// logic should be implemented here. In most cases, this should only be a call to +// the rbac authorizer. +// +// When a new database method is added to database.Store, it should be added to +// this package as well. The unit test "Accounting" will ensure all methods are +// tested. See other unit tests for examples on how to write these. +package dbauthz diff --git a/coderd/database/dbauthz/querier.go b/coderd/database/dbauthz/querier.go new file mode 100644 index 0000000000000..4442619ef3850 --- /dev/null +++ b/coderd/database/dbauthz/querier.go @@ -0,0 +1,1622 @@ +package dbauthz + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/coderd/util/slice" +) + +func (q *querier) Ping(ctx context.Context) (time.Duration, error) { + return q.db.Ping(ctx) +} + +// InTx runs the given function in a transaction. +func (q *querier) InTx(function func(querier database.Store) error, txOpts *sql.TxOptions) error { + return q.db.InTx(func(tx database.Store) error { + // Wrap the transaction store in a querier. + wrapped := New(tx, q.auth, q.log) + return function(wrapped) + }, txOpts) +} + +func (q *querier) DeleteAPIKeyByID(ctx context.Context, id string) error { + return deleteQ(q.log, q.auth, q.db.GetAPIKeyByID, q.db.DeleteAPIKeyByID)(ctx, id) +} + +func (q *querier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { + return fetch(q.log, q.auth, q.db.GetAPIKeyByID)(ctx, id) +} + +func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) { + return fetchWithPostFilter(q.auth, q.db.GetAPIKeysByLoginType)(ctx, loginType) +} + +func (q *querier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]database.APIKey, error) { + return fetchWithPostFilter(q.auth, q.db.GetAPIKeysLastUsedAfter)(ctx, lastUsed) +} + +func (q *querier) InsertAPIKey(ctx context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { + return insert(q.log, q.auth, + rbac.ResourceAPIKey.WithOwner(arg.UserID.String()), + q.db.InsertAPIKey)(ctx, arg) +} + +func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error { + fetch := func(ctx context.Context, arg database.UpdateAPIKeyByIDParams) (database.APIKey, error) { + return q.db.GetAPIKeyByID(ctx, arg.ID) + } + return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg) +} + +func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLogParams) (database.AuditLog, error) { + return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg) +} + +func (q *querier) GetAuditLogsOffset(ctx context.Context, arg database.GetAuditLogsOffsetParams) ([]database.GetAuditLogsOffsetRow, error) { + // To optimize audit logs, we only check the global audit log permission once. + // This is because we expect a large unbounded set of audit logs, and applying a SQL + // filter would slow down the query for no benefit. + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceAuditLog); err != nil { + return nil, err + } + return q.db.GetAuditLogsOffset(ctx, arg) +} + +func (q *querier) GetFileByHashAndCreator(ctx context.Context, arg database.GetFileByHashAndCreatorParams) (database.File, error) { + return fetch(q.log, q.auth, q.db.GetFileByHashAndCreator)(ctx, arg) +} + +func (q *querier) GetFileByID(ctx context.Context, id uuid.UUID) (database.File, error) { + return fetch(q.log, q.auth, q.db.GetFileByID)(ctx, id) +} + +func (q *querier) InsertFile(ctx context.Context, arg database.InsertFileParams) (database.File, error) { + return insert(q.log, q.auth, rbac.ResourceFile.WithOwner(arg.CreatedBy.String()), q.db.InsertFile)(ctx, arg) +} + +func (q *querier) DeleteGroupByID(ctx context.Context, id uuid.UUID) error { + return deleteQ(q.log, q.auth, q.db.GetGroupByID, q.db.DeleteGroupByID)(ctx, id) +} + +func (q *querier) DeleteGroupMemberFromGroup(ctx context.Context, arg database.DeleteGroupMemberFromGroupParams) error { + // Deleting a group member counts as updating a group. + fetch := func(ctx context.Context, arg database.DeleteGroupMemberFromGroupParams) (database.Group, error) { + return q.db.GetGroupByID(ctx, arg.GroupID) + } + return update(q.log, q.auth, fetch, q.db.DeleteGroupMemberFromGroup)(ctx, arg) +} + +func (q *querier) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error { + // This will add the user to all named groups. This counts as updating a group. + // NOTE: instead of checking if the user has permission to update each group, we instead + // check if the user has permission to update *a* group in the org. + fetch := func(ctx context.Context, arg database.InsertUserGroupsByNameParams) (rbac.Objecter, error) { + return rbac.ResourceGroup.InOrg(arg.OrganizationID), nil + } + return update(q.log, q.auth, fetch, q.db.InsertUserGroupsByName)(ctx, arg) +} + +func (q *querier) DeleteGroupMembersByOrgAndUser(ctx context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) error { + // This will remove the user from all groups in the org. This counts as updating a group. + // NOTE: instead of fetching all groups in the org with arg.UserID as a member, we instead + // check if the caller has permission to update any group in the org. + fetch := func(ctx context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) (rbac.Objecter, error) { + return rbac.ResourceGroup.InOrg(arg.OrganizationID), nil + } + return update(q.log, q.auth, fetch, q.db.DeleteGroupMembersByOrgAndUser)(ctx, arg) +} + +func (q *querier) GetGroupByID(ctx context.Context, id uuid.UUID) (database.Group, error) { + return fetch(q.log, q.auth, q.db.GetGroupByID)(ctx, id) +} + +func (q *querier) GetGroupByOrgAndName(ctx context.Context, arg database.GetGroupByOrgAndNameParams) (database.Group, error) { + return fetch(q.log, q.auth, q.db.GetGroupByOrgAndName)(ctx, arg) +} + +func (q *querier) GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]database.User, error) { + if _, err := q.GetGroupByID(ctx, groupID); err != nil { // AuthZ check + return nil, err + } + return q.db.GetGroupMembers(ctx, groupID) +} + +func (q *querier) InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (database.Group, error) { + // This method creates a new group. + return insert(q.log, q.auth, rbac.ResourceGroup.InOrg(organizationID), q.db.InsertAllUsersGroup)(ctx, organizationID) +} + +func (q *querier) InsertGroup(ctx context.Context, arg database.InsertGroupParams) (database.Group, error) { + return insert(q.log, q.auth, rbac.ResourceGroup.InOrg(arg.OrganizationID), q.db.InsertGroup)(ctx, arg) +} + +func (q *querier) InsertGroupMember(ctx context.Context, arg database.InsertGroupMemberParams) error { + fetch := func(ctx context.Context, arg database.InsertGroupMemberParams) (database.Group, error) { + return q.db.GetGroupByID(ctx, arg.GroupID) + } + return update(q.log, q.auth, fetch, q.db.InsertGroupMember)(ctx, arg) +} + +func (q *querier) UpdateGroupByID(ctx context.Context, arg database.UpdateGroupByIDParams) (database.Group, error) { + fetch := func(ctx context.Context, arg database.UpdateGroupByIDParams) (database.Group, error) { + return q.db.GetGroupByID(ctx, arg.ID) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateGroupByID)(ctx, arg) +} + +func (q *querier) UpdateProvisionerJobWithCancelByID(ctx context.Context, arg database.UpdateProvisionerJobWithCancelByIDParams) error { + job, err := q.db.GetProvisionerJobByID(ctx, arg.ID) + if err != nil { + return err + } + + switch job.Type { + case database.ProvisionerJobTypeWorkspaceBuild: + build, err := q.db.GetWorkspaceBuildByJobID(ctx, arg.ID) + if err != nil { + return err + } + workspace, err := q.db.GetWorkspaceByID(ctx, build.WorkspaceID) + if err != nil { + return err + } + + template, err := q.db.GetTemplateByID(ctx, workspace.TemplateID) + if err != nil { + return err + } + + // Template can specify if cancels are allowed. + // Would be nice to have a way in the rbac rego to do this. + if !template.AllowUserCancelWorkspaceJobs { + // Only owners can cancel workspace builds + actor, ok := ActorFromContext(ctx) + if !ok { + return NoActorError + } + if !slice.Contains(actor.Roles.Names(), rbac.RoleOwner()) { + return xerrors.Errorf("only owners can cancel workspace builds") + } + } + + err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace) + if err != nil { + return err + } + case database.ProvisionerJobTypeTemplateVersionDryRun, database.ProvisionerJobTypeTemplateVersionImport: + // Authorized call to get template version. + templateVersion, err := authorizedTemplateVersionFromJob(ctx, q, job) + if err != nil { + return err + } + + if templateVersion.TemplateID.Valid { + template, err := q.db.GetTemplateByID(ctx, templateVersion.TemplateID.UUID) + if err != nil { + return err + } + err = q.authorizeContext(ctx, rbac.ActionUpdate, templateVersion.RBACObject(template)) + if err != nil { + return err + } + } else { + err = q.authorizeContext(ctx, rbac.ActionUpdate, templateVersion.RBACObjectNoTemplate()) + if err != nil { + return err + } + } + default: + return xerrors.Errorf("unknown job type: %q", job.Type) + } + return q.db.UpdateProvisionerJobWithCancelByID(ctx, arg) +} + +func (q *querier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) { + job, err := q.db.GetProvisionerJobByID(ctx, id) + if err != nil { + return database.ProvisionerJob{}, err + } + + switch job.Type { + case database.ProvisionerJobTypeWorkspaceBuild: + // Authorized call to get workspace build. If we can read the build, we + // can read the job. + _, err := q.GetWorkspaceBuildByJobID(ctx, id) + if err != nil { + return database.ProvisionerJob{}, err + } + case database.ProvisionerJobTypeTemplateVersionDryRun, database.ProvisionerJobTypeTemplateVersionImport: + // Authorized call to get template version. + _, err := authorizedTemplateVersionFromJob(ctx, q, job) + if err != nil { + return database.ProvisionerJob{}, err + } + default: + return database.ProvisionerJob{}, xerrors.Errorf("unknown job type: %q", job.Type) + } + + return job, nil +} + +func (q *querier) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) { + // TODO: This is missing authorization and is incorrect. This call is used by telemetry, and by 1 http route. + // That http handler should find a better way to fetch these jobs with easier rbac authz. + return q.db.GetProvisionerJobsByIDs(ctx, ids) +} + +func (q *querier) GetProvisionerLogsByIDBetween(ctx context.Context, arg database.GetProvisionerLogsByIDBetweenParams) ([]database.ProvisionerJobLog, error) { + // Authorized read on job lets the actor also read the logs. + _, err := q.GetProvisionerJobByID(ctx, arg.JobID) + if err != nil { + return nil, err + } + return q.db.GetProvisionerLogsByIDBetween(ctx, arg) +} + +func (q *querier) GetLicenses(ctx context.Context) ([]database.License, error) { + fetch := func(ctx context.Context, _ interface{}) ([]database.License, error) { + return q.db.GetLicenses(ctx) + } + return fetchWithPostFilter(q.auth, fetch)(ctx, nil) +} + +func (q *querier) InsertLicense(ctx context.Context, arg database.InsertLicenseParams) (database.License, error) { + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceLicense); err != nil { + return database.License{}, err + } + return q.db.InsertLicense(ctx, arg) +} + +func (q *querier) InsertOrUpdateLogoURL(ctx context.Context, value string) error { + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceDeploymentConfig); err != nil { + return err + } + return q.db.InsertOrUpdateLogoURL(ctx, value) +} + +func (q *querier) InsertOrUpdateServiceBanner(ctx context.Context, value string) error { + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceDeploymentConfig); err != nil { + return err + } + return q.db.InsertOrUpdateServiceBanner(ctx, value) +} + +func (q *querier) GetLicenseByID(ctx context.Context, id int32) (database.License, error) { + return fetch(q.log, q.auth, q.db.GetLicenseByID)(ctx, id) +} + +func (q *querier) DeleteLicense(ctx context.Context, id int32) (int32, error) { + err := deleteQ(q.log, q.auth, q.db.GetLicenseByID, func(ctx context.Context, id int32) error { + _, err := q.db.DeleteLicense(ctx, id) + return err + })(ctx, id) + if err != nil { + return -1, err + } + return id, nil +} + +func (q *querier) GetDeploymentID(ctx context.Context) (string, error) { + // No authz checks + return q.db.GetDeploymentID(ctx) +} + +func (q *querier) GetLogoURL(ctx context.Context) (string, error) { + // No authz checks + return q.db.GetLogoURL(ctx) +} + +func (q *querier) GetServiceBanner(ctx context.Context) (string, error) { + // No authz checks + return q.db.GetServiceBanner(ctx) +} + +func (q *querier) GetProvisionerDaemons(ctx context.Context) ([]database.ProvisionerDaemon, error) { + fetch := func(ctx context.Context, _ interface{}) ([]database.ProvisionerDaemon, error) { + return q.db.GetProvisionerDaemons(ctx) + } + return fetchWithPostFilter(q.auth, fetch)(ctx, nil) +} + +func (q *querier) GetDeploymentDAUs(ctx context.Context) ([]database.GetDeploymentDAUsRow, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceUser.All()); err != nil { + return nil, err + } + return q.db.GetDeploymentDAUs(ctx) +} + +func (q *querier) GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]database.Group, error) { + return fetchWithPostFilter(q.auth, q.db.GetGroupsByOrganizationID)(ctx, organizationID) +} + +func (q *querier) GetOrganizationByID(ctx context.Context, id uuid.UUID) (database.Organization, error) { + return fetch(q.log, q.auth, q.db.GetOrganizationByID)(ctx, id) +} + +func (q *querier) GetOrganizationByName(ctx context.Context, name string) (database.Organization, error) { + return fetch(q.log, q.auth, q.db.GetOrganizationByName)(ctx, name) +} + +func (q *querier) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid.UUID) ([]database.GetOrganizationIDsByMemberIDsRow, error) { + // TODO: This should be rewritten to return a list of database.OrganizationMember for consistent RBAC objects. + // Currently this row returns a list of org ids per user, which is challenging to check against the RBAC system. + return fetchWithPostFilter(q.auth, q.db.GetOrganizationIDsByMemberIDs)(ctx, ids) +} + +func (q *querier) GetOrganizationMemberByUserID(ctx context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) { + return fetch(q.log, q.auth, q.db.GetOrganizationMemberByUserID)(ctx, arg) +} + +func (q *querier) GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) { + return fetchWithPostFilter(q.auth, q.db.GetOrganizationMembershipsByUserID)(ctx, userID) +} + +func (q *querier) GetOrganizations(ctx context.Context) ([]database.Organization, error) { + fetch := func(ctx context.Context, _ interface{}) ([]database.Organization, error) { + return q.db.GetOrganizations(ctx) + } + return fetchWithPostFilter(q.auth, fetch)(ctx, nil) +} + +func (q *querier) GetOrganizationsByUserID(ctx context.Context, userID uuid.UUID) ([]database.Organization, error) { + return fetchWithPostFilter(q.auth, q.db.GetOrganizationsByUserID)(ctx, userID) +} + +func (q *querier) InsertOrganization(ctx context.Context, arg database.InsertOrganizationParams) (database.Organization, error) { + return insert(q.log, q.auth, rbac.ResourceOrganization, q.db.InsertOrganization)(ctx, arg) +} + +func (q *querier) InsertOrganizationMember(ctx context.Context, arg database.InsertOrganizationMemberParams) (database.OrganizationMember, error) { + // All roles are added roles. Org member is always implied. + addedRoles := append(arg.Roles, rbac.RoleOrgMember(arg.OrganizationID)) + err := q.canAssignRoles(ctx, &arg.OrganizationID, addedRoles, []string{}) + if err != nil { + return database.OrganizationMember{}, err + } + + obj := rbac.ResourceOrganizationMember.InOrg(arg.OrganizationID).WithID(arg.UserID) + return insert(q.log, q.auth, obj, q.db.InsertOrganizationMember)(ctx, arg) +} + +func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { + // Authorized fetch will check that the actor has read access to the org member since the org member is returned. + member, err := q.GetOrganizationMemberByUserID(ctx, database.GetOrganizationMemberByUserIDParams{ + OrganizationID: arg.OrgID, + UserID: arg.UserID, + }) + if err != nil { + return database.OrganizationMember{}, err + } + + // The org member role is always implied. + impliedTypes := append(arg.GrantedRoles, rbac.RoleOrgMember(arg.OrgID)) + added, removed := rbac.ChangeRoleSet(member.Roles, impliedTypes) + err = q.canAssignRoles(ctx, &arg.OrgID, added, removed) + if err != nil { + return database.OrganizationMember{}, err + } + + return q.db.UpdateMemberRoles(ctx, arg) +} + +func (q *querier) canAssignRoles(ctx context.Context, orgID *uuid.UUID, added, removed []string) error { + actor, ok := ActorFromContext(ctx) + if !ok { + return NoActorError + } + + roleAssign := rbac.ResourceRoleAssignment + shouldBeOrgRoles := false + if orgID != nil { + roleAssign = roleAssign.InOrg(*orgID) + shouldBeOrgRoles = true + } + + grantedRoles := append(added, removed...) + // Validate that the roles being assigned are valid. + for _, r := range grantedRoles { + _, isOrgRole := rbac.IsOrgRole(r) + if shouldBeOrgRoles && !isOrgRole { + return xerrors.Errorf("Must only update org roles") + } + if !shouldBeOrgRoles && isOrgRole { + return xerrors.Errorf("Must only update site wide roles") + } + + // All roles should be valid roles + if _, err := rbac.RoleByName(r); err != nil { + return xerrors.Errorf("%q is not a supported role", r) + } + } + + if len(added) > 0 && q.authorizeContext(ctx, rbac.ActionCreate, roleAssign) != nil { + return logNotAuthorizedError(ctx, q.log, xerrors.Errorf("not authorized to assign roles")) + } + + if len(removed) > 0 && q.authorizeContext(ctx, rbac.ActionDelete, roleAssign) != nil { + return logNotAuthorizedError(ctx, q.log, xerrors.Errorf("not authorized to delete roles")) + } + + for _, roleName := range grantedRoles { + if !rbac.CanAssignRole(actor.Roles, roleName) { + return xerrors.Errorf("not authorized to assign role %q", roleName) + } + } + + return nil +} + +func (q *querier) parameterRBACResource(ctx context.Context, scope database.ParameterScope, scopeID uuid.UUID) (rbac.Objecter, error) { + var resource rbac.Objecter + var err error + switch scope { + case database.ParameterScopeWorkspace: + return q.db.GetWorkspaceByID(ctx, scopeID) + case database.ParameterScopeImportJob: + var version database.TemplateVersion + version, err = q.db.GetTemplateVersionByJobID(ctx, scopeID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + resource = version.RBACObjectNoTemplate() + + var template database.Template + template, err = q.db.GetTemplateByID(ctx, version.TemplateID.UUID) + if err == nil { + resource = version.RBACObject(template) + } else if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + return nil, err + } + return resource, nil + case database.ParameterScopeTemplate: + return q.db.GetTemplateByID(ctx, scopeID) + default: + return nil, xerrors.Errorf("Parameter scope %q unsupported", scope) + } +} + +func (q *querier) InsertParameterValue(ctx context.Context, arg database.InsertParameterValueParams) (database.ParameterValue, error) { + resource, err := q.parameterRBACResource(ctx, arg.Scope, arg.ScopeID) + if err != nil { + return database.ParameterValue{}, err + } + + err = q.authorizeContext(ctx, rbac.ActionUpdate, resource) + if err != nil { + return database.ParameterValue{}, err + } + + return q.db.InsertParameterValue(ctx, arg) +} + +func (q *querier) ParameterValue(ctx context.Context, id uuid.UUID) (database.ParameterValue, error) { + parameter, err := q.db.ParameterValue(ctx, id) + if err != nil { + return database.ParameterValue{}, err + } + + resource, err := q.parameterRBACResource(ctx, parameter.Scope, parameter.ScopeID) + if err != nil { + return database.ParameterValue{}, err + } + + err = q.authorizeContext(ctx, rbac.ActionRead, resource) + if err != nil { + return database.ParameterValue{}, err + } + + return parameter, nil +} + +// ParameterValues is implemented as an all or nothing query. If the user is not +// able to read a single parameter value, then the entire query is denied. +// This should likely be revisited and see if the usage of this function cannot be changed. +func (q *querier) ParameterValues(ctx context.Context, arg database.ParameterValuesParams) ([]database.ParameterValue, error) { + // This is a bit of a special case. Each parameter value returned might have a different scope. This could likely + // be implemented in a more efficient manner. + values, err := q.db.ParameterValues(ctx, arg) + if err != nil { + return nil, err + } + + cached := make(map[uuid.UUID]bool) + for _, value := range values { + // If we already checked this scopeID, then we can skip it. + // All scope ids are uuids of objects and universally unique. + if allowed := cached[value.ScopeID]; allowed { + continue + } + rbacObj, err := q.parameterRBACResource(ctx, value.Scope, value.ScopeID) + if err != nil { + return nil, err + } + err = q.authorizeContext(ctx, rbac.ActionRead, rbacObj) + if err != nil { + return nil, err + } + cached[value.ScopeID] = true + } + + return values, nil +} + +func (q *querier) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) { + version, err := q.db.GetTemplateVersionByJobID(ctx, jobID) + if err != nil { + return nil, err + } + object := version.RBACObjectNoTemplate() + if version.TemplateID.Valid { + tpl, err := q.db.GetTemplateByID(ctx, version.TemplateID.UUID) + if err != nil { + return nil, err + } + object = version.RBACObject(tpl) + } + + err = q.authorizeContext(ctx, rbac.ActionRead, object) + if err != nil { + return nil, err + } + return q.db.GetParameterSchemasByJobID(ctx, jobID) +} + +func (q *querier) GetParameterValueByScopeAndName(ctx context.Context, arg database.GetParameterValueByScopeAndNameParams) (database.ParameterValue, error) { + resource, err := q.parameterRBACResource(ctx, arg.Scope, arg.ScopeID) + if err != nil { + return database.ParameterValue{}, err + } + + err = q.authorizeContext(ctx, rbac.ActionRead, resource) + if err != nil { + return database.ParameterValue{}, err + } + + return q.db.GetParameterValueByScopeAndName(ctx, arg) +} + +func (q *querier) DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error { + parameter, err := q.db.ParameterValue(ctx, id) + if err != nil { + return err + } + + resource, err := q.parameterRBACResource(ctx, parameter.Scope, parameter.ScopeID) + if err != nil { + return err + } + + // A deleted param is still updating the underlying resource for the scope. + err = q.authorizeContext(ctx, rbac.ActionUpdate, resource) + if err != nil { + return err + } + + return q.db.DeleteParameterValueByID(ctx, id) +} + +func (q *querier) GetPreviousTemplateVersion(ctx context.Context, arg database.GetPreviousTemplateVersionParams) (database.TemplateVersion, error) { + // An actor can read the previous template version if they can read the related template. + // If no linked template exists, we check if the actor can read *a* template. + if !arg.TemplateID.Valid { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.InOrg(arg.OrganizationID)); err != nil { + return database.TemplateVersion{}, err + } + } + if _, err := q.GetTemplateByID(ctx, arg.TemplateID.UUID); err != nil { + return database.TemplateVersion{}, err + } + return q.db.GetPreviousTemplateVersion(ctx, arg) +} + +func (q *querier) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) { + // An actor can read the average build time if they can read the related template. + // It doesn't make any sense to get the average build time for a template that doesn't + // exist, so omitting this check here. + if _, err := q.GetTemplateByID(ctx, arg.TemplateID.UUID); err != nil { + return database.GetTemplateAverageBuildTimeRow{}, err + } + return q.db.GetTemplateAverageBuildTime(ctx, arg) +} + +func (q *querier) GetTemplateByID(ctx context.Context, id uuid.UUID) (database.Template, error) { + return fetch(q.log, q.auth, q.db.GetTemplateByID)(ctx, id) +} + +func (q *querier) GetTemplateByOrganizationAndName(ctx context.Context, arg database.GetTemplateByOrganizationAndNameParams) (database.Template, error) { + return fetch(q.log, q.auth, q.db.GetTemplateByOrganizationAndName)(ctx, arg) +} + +func (q *querier) GetTemplateDAUs(ctx context.Context, templateID uuid.UUID) ([]database.GetTemplateDAUsRow, error) { + // An actor can read the DAUs if they can read the related template. + // Again, it doesn't make sense to get DAUs for a template that doesn't exist. + if _, err := q.GetTemplateByID(ctx, templateID); err != nil { + return nil, err + } + return q.db.GetTemplateDAUs(ctx, templateID) +} + +func (q *querier) GetTemplateVersionByID(ctx context.Context, tvid uuid.UUID) (database.TemplateVersion, error) { + tv, err := q.db.GetTemplateVersionByID(ctx, tvid) + if err != nil { + return database.TemplateVersion{}, err + } + if !tv.TemplateID.Valid { + // If no linked template exists, check if the actor can read a template in the organization. + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.InOrg(tv.OrganizationID)); err != nil { + return database.TemplateVersion{}, err + } + } else if _, err := q.GetTemplateByID(ctx, tv.TemplateID.UUID); err != nil { + // An actor can read the template version if they can read the related template. + return database.TemplateVersion{}, err + } + return tv, nil +} + +func (q *querier) GetTemplateVersionByJobID(ctx context.Context, jobID uuid.UUID) (database.TemplateVersion, error) { + tv, err := q.db.GetTemplateVersionByJobID(ctx, jobID) + if err != nil { + return database.TemplateVersion{}, err + } + if !tv.TemplateID.Valid { + // If no linked template exists, check if the actor can read a template in the organization. + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.InOrg(tv.OrganizationID)); err != nil { + return database.TemplateVersion{}, err + } + } else if _, err := q.GetTemplateByID(ctx, tv.TemplateID.UUID); err != nil { + // An actor can read the template version if they can read the related template. + return database.TemplateVersion{}, err + } + return tv, nil +} + +func (q *querier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg database.GetTemplateVersionByTemplateIDAndNameParams) (database.TemplateVersion, error) { + tv, err := q.db.GetTemplateVersionByTemplateIDAndName(ctx, arg) + if err != nil { + return database.TemplateVersion{}, err + } + if !tv.TemplateID.Valid { + // If no linked template exists, check if the actor can read a template in the organization. + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.InOrg(tv.OrganizationID)); err != nil { + return database.TemplateVersion{}, err + } + } else if _, err := q.GetTemplateByID(ctx, tv.TemplateID.UUID); err != nil { + // An actor can read the template version if they can read the related template. + return database.TemplateVersion{}, err + } + return tv, nil +} + +func (q *querier) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) { + // An actor can read template version parameters if they can read the related template. + tv, err := q.db.GetTemplateVersionByID(ctx, templateVersionID) + if err != nil { + return nil, err + } + + var object rbac.Objecter + template, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + object = rbac.ResourceTemplate.InOrg(tv.OrganizationID) + } else { + object = tv.RBACObject(template) + } + + if err := q.authorizeContext(ctx, rbac.ActionRead, object); err != nil { + return nil, err + } + return q.db.GetTemplateVersionParameters(ctx, templateVersionID) +} + +func (q *querier) GetTemplateVersionsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.TemplateVersion, error) { + // TODO: This is so inefficient + versions, err := q.db.GetTemplateVersionsByIDs(ctx, ids) + if err != nil { + return nil, err + } + checked := make(map[uuid.UUID]bool) + for _, v := range versions { + if _, ok := checked[v.TemplateID.UUID]; ok { + continue + } + + obj := v.RBACObjectNoTemplate() + template, err := q.db.GetTemplateByID(ctx, v.TemplateID.UUID) + if err == nil { + obj = v.RBACObject(template) + } + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + return nil, err + } + if err := q.authorizeContext(ctx, rbac.ActionRead, obj); err != nil { + return nil, err + } + checked[v.TemplateID.UUID] = true + } + + return versions, nil +} + +func (q *querier) GetTemplateVersionsByTemplateID(ctx context.Context, arg database.GetTemplateVersionsByTemplateIDParams) ([]database.TemplateVersion, error) { + // An actor can read template versions if they can read the related template. + template, err := q.db.GetTemplateByID(ctx, arg.TemplateID) + if err != nil { + return nil, err + } + + if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { + return nil, err + } + + return q.db.GetTemplateVersionsByTemplateID(ctx, arg) +} + +func (q *querier) GetTemplateVersionsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.TemplateVersion, error) { + // An actor can read execute this query if they can read all templates. + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTemplate.All()); err != nil { + return nil, err + } + return q.db.GetTemplateVersionsCreatedAfter(ctx, createdAt) +} + +func (q *querier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, _ rbac.PreparedAuthorized) ([]database.Template, error) { + // TODO Delete this function, all GetTemplates should be authorized. For now just call getTemplates on the authz querier. + return q.GetTemplatesWithFilter(ctx, arg) +} + +func (q *querier) GetTemplatesWithFilter(ctx context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) { + prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceTemplate.Type) + if err != nil { + return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + return q.db.GetAuthorizedTemplates(ctx, arg, prep) +} + +func (q *querier) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) (database.Template, error) { + obj := rbac.ResourceTemplate.InOrg(arg.OrganizationID) + return insert(q.log, q.auth, obj, q.db.InsertTemplate)(ctx, arg) +} + +func (q *querier) InsertTemplateVersion(ctx context.Context, arg database.InsertTemplateVersionParams) (database.TemplateVersion, error) { + if !arg.TemplateID.Valid { + // Making a new template version is the same permission as creating a new template. + err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceTemplate.InOrg(arg.OrganizationID)) + if err != nil { + return database.TemplateVersion{}, err + } + } else { + // Must do an authorized fetch to prevent leaking template ids this way. + tpl, err := q.GetTemplateByID(ctx, arg.TemplateID.UUID) + if err != nil { + return database.TemplateVersion{}, err + } + // Check the create permission on the template. + err = q.authorizeContext(ctx, rbac.ActionCreate, tpl) + if err != nil { + return database.TemplateVersion{}, err + } + } + + return q.db.InsertTemplateVersion(ctx, arg) +} + +func (q *querier) UpdateTemplateACLByID(ctx context.Context, arg database.UpdateTemplateACLByIDParams) (database.Template, error) { + // UpdateTemplateACL uses the ActionCreate action. Only users that can create the template + // may update the ACL. + fetch := func(ctx context.Context, arg database.UpdateTemplateACLByIDParams) (database.Template, error) { + return q.db.GetTemplateByID(ctx, arg.ID) + } + return fetchAndQuery(q.log, q.auth, rbac.ActionCreate, fetch, q.db.UpdateTemplateACLByID)(ctx, arg) +} + +func (q *querier) UpdateTemplateActiveVersionByID(ctx context.Context, arg database.UpdateTemplateActiveVersionByIDParams) error { + fetch := func(ctx context.Context, arg database.UpdateTemplateActiveVersionByIDParams) (database.Template, error) { + return q.db.GetTemplateByID(ctx, arg.ID) + } + return update(q.log, q.auth, fetch, q.db.UpdateTemplateActiveVersionByID)(ctx, arg) +} + +func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) error { + deleteF := func(ctx context.Context, id uuid.UUID) error { + return q.db.UpdateTemplateDeletedByID(ctx, database.UpdateTemplateDeletedByIDParams{ + ID: id, + Deleted: true, + UpdatedAt: database.Now(), + }) + } + return deleteQ(q.log, q.auth, q.db.GetTemplateByID, deleteF)(ctx, id) +} + +// Deprecated: use SoftDeleteTemplateByID instead. +func (q *querier) UpdateTemplateDeletedByID(ctx context.Context, arg database.UpdateTemplateDeletedByIDParams) error { + return q.SoftDeleteTemplateByID(ctx, arg.ID) +} + +func (q *querier) UpdateTemplateMetaByID(ctx context.Context, arg database.UpdateTemplateMetaByIDParams) (database.Template, error) { + fetch := func(ctx context.Context, arg database.UpdateTemplateMetaByIDParams) (database.Template, error) { + return q.db.GetTemplateByID(ctx, arg.ID) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateTemplateMetaByID)(ctx, arg) +} + +func (q *querier) UpdateTemplateVersionByID(ctx context.Context, arg database.UpdateTemplateVersionByIDParams) error { + template, err := q.db.GetTemplateByID(ctx, arg.TemplateID.UUID) + if err != nil { + return err + } + if err := q.authorizeContext(ctx, rbac.ActionUpdate, template); err != nil { + return err + } + return q.db.UpdateTemplateVersionByID(ctx, arg) +} + +func (q *querier) UpdateTemplateVersionDescriptionByJobID(ctx context.Context, arg database.UpdateTemplateVersionDescriptionByJobIDParams) error { + // An actor is allowed to update the template version description if they are authorized to update the template. + tv, err := q.db.GetTemplateVersionByJobID(ctx, arg.JobID) + if err != nil { + return err + } + var obj rbac.Objecter + if !tv.TemplateID.Valid { + obj = rbac.ResourceTemplate.InOrg(tv.OrganizationID) + } else { + tpl, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID) + if err != nil { + return err + } + obj = tpl + } + if err := q.authorizeContext(ctx, rbac.ActionUpdate, obj); err != nil { + return err + } + return q.db.UpdateTemplateVersionDescriptionByJobID(ctx, arg) +} + +func (q *querier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { + // An actor is authorized to read template group roles if they are authorized to read the template. + template, err := q.db.GetTemplateByID(ctx, id) + if err != nil { + return nil, err + } + if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { + return nil, err + } + return q.db.GetTemplateGroupRoles(ctx, id) +} + +func (q *querier) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) { + // An actor is authorized to query template user roles if they are authorized to read the template. + template, err := q.db.GetTemplateByID(ctx, id) + if err != nil { + return nil, err + } + if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { + return nil, err + } + return q.db.GetTemplateUserRoles(ctx, id) +} + +func (q *querier) DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error { + // TODO: This is not 100% correct because it omits apikey IDs. + err := q.authorizeContext(ctx, rbac.ActionDelete, + rbac.ResourceAPIKey.WithOwner(userID.String())) + if err != nil { + return err + } + return q.db.DeleteAPIKeysByUserID(ctx, userID) +} + +func (q *querier) GetQuotaAllowanceForUser(ctx context.Context, userID uuid.UUID) (int64, error) { + err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceUser.WithID(userID)) + if err != nil { + return -1, err + } + return q.db.GetQuotaAllowanceForUser(ctx, userID) +} + +func (q *querier) GetQuotaConsumedForUser(ctx context.Context, userID uuid.UUID) (int64, error) { + err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceUser.WithID(userID)) + if err != nil { + return -1, err + } + return q.db.GetQuotaConsumedForUser(ctx, userID) +} + +func (q *querier) GetUserByEmailOrUsername(ctx context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) { + return fetch(q.log, q.auth, q.db.GetUserByEmailOrUsername)(ctx, arg) +} + +func (q *querier) GetUserByID(ctx context.Context, id uuid.UUID) (database.User, error) { + return fetch(q.log, q.auth, q.db.GetUserByID)(ctx, id) +} + +func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { + return q.db.GetAuthorizedUserCount(ctx, arg, prepared) +} + +func (q *querier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) { + prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type) + if err != nil { + return -1, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + // TODO: This should be the only implementation. + return q.GetAuthorizedUserCount(ctx, arg, prep) +} + +func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) { + // TODO: We should use GetUsersWithCount with a better method signature. + return fetchWithPostFilter(q.auth, q.db.GetUsers)(ctx, arg) +} + +func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersParams) ([]database.User, int64, error) { + // TODO Implement this with a SQL filter. The count is incorrect without it. + rowUsers, err := q.db.GetUsers(ctx, arg) + if err != nil { + return nil, -1, err + } + + if len(rowUsers) == 0 { + return []database.User{}, 0, nil + } + + act, ok := ActorFromContext(ctx) + if !ok { + return nil, -1, NoActorError + } + + // TODO: Is this correct? Should we return a restricted user? + users := database.ConvertUserRows(rowUsers) + users, err = rbac.Filter(ctx, q.auth, act, rbac.ActionRead, users) + if err != nil { + return nil, -1, err + } + + return users, rowUsers[0].Count, nil +} + +// TODO: Remove this and use a filter on GetUsers +func (q *querier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]database.User, error) { + return fetchWithPostFilter(q.auth, q.db.GetUsersByIDs)(ctx, ids) +} + +func (q *querier) InsertUser(ctx context.Context, arg database.InsertUserParams) (database.User, error) { + // Always check if the assigned roles can actually be assigned by this actor. + impliedRoles := append([]string{rbac.RoleMember()}, arg.RBACRoles...) + err := q.canAssignRoles(ctx, nil, impliedRoles, []string{}) + if err != nil { + return database.User{}, err + } + obj := rbac.ResourceUser + return insert(q.log, q.auth, obj, q.db.InsertUser)(ctx, arg) +} + +// TODO: Should this be in system.go? +func (q *querier) InsertUserLink(ctx context.Context, arg database.InsertUserLinkParams) (database.UserLink, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceUser.WithID(arg.UserID)); err != nil { + return database.UserLink{}, err + } + return q.db.InsertUserLink(ctx, arg) +} + +func (q *querier) SoftDeleteUserByID(ctx context.Context, id uuid.UUID) error { + deleteF := func(ctx context.Context, id uuid.UUID) error { + return q.db.UpdateUserDeletedByID(ctx, database.UpdateUserDeletedByIDParams{ + ID: id, + Deleted: true, + }) + } + return deleteQ(q.log, q.auth, q.db.GetUserByID, deleteF)(ctx, id) +} + +// UpdateUserDeletedByID +// Deprecated: Delete this function in favor of 'SoftDeleteUserByID'. Deletes are +// irreversible. +func (q *querier) UpdateUserDeletedByID(ctx context.Context, arg database.UpdateUserDeletedByIDParams) error { + fetch := func(ctx context.Context, arg database.UpdateUserDeletedByIDParams) (database.User, error) { + return q.db.GetUserByID(ctx, arg.ID) + } + // This uses the rbac.ActionDelete action always as this function should always delete. + // We should delete this function in favor of 'SoftDeleteUserByID'. + return deleteQ(q.log, q.auth, fetch, q.db.UpdateUserDeletedByID)(ctx, arg) +} + +func (q *querier) UpdateUserHashedPassword(ctx context.Context, arg database.UpdateUserHashedPasswordParams) error { + user, err := q.db.GetUserByID(ctx, arg.ID) + if err != nil { + return err + } + + err = q.authorizeContext(ctx, rbac.ActionUpdate, user.UserDataRBACObject()) + if err != nil { + return err + } + + return q.db.UpdateUserHashedPassword(ctx, arg) +} + +func (q *querier) UpdateUserLastSeenAt(ctx context.Context, arg database.UpdateUserLastSeenAtParams) (database.User, error) { + fetch := func(ctx context.Context, arg database.UpdateUserLastSeenAtParams) (database.User, error) { + return q.db.GetUserByID(ctx, arg.ID) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateUserLastSeenAt)(ctx, arg) +} + +func (q *querier) UpdateUserProfile(ctx context.Context, arg database.UpdateUserProfileParams) (database.User, error) { + u, err := q.db.GetUserByID(ctx, arg.ID) + if err != nil { + return database.User{}, err + } + if err := q.authorizeContext(ctx, rbac.ActionUpdate, u.UserDataRBACObject()); err != nil { + return database.User{}, err + } + return q.db.UpdateUserProfile(ctx, arg) +} + +func (q *querier) UpdateUserStatus(ctx context.Context, arg database.UpdateUserStatusParams) (database.User, error) { + fetch := func(ctx context.Context, arg database.UpdateUserStatusParams) (database.User, error) { + return q.db.GetUserByID(ctx, arg.ID) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateUserStatus)(ctx, arg) +} + +func (q *querier) DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error { + return deleteQ(q.log, q.auth, q.db.GetGitSSHKey, q.db.DeleteGitSSHKey)(ctx, userID) +} + +func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) { + return fetch(q.log, q.auth, q.db.GetGitSSHKey)(ctx, userID) +} + +func (q *querier) InsertGitSSHKey(ctx context.Context, arg database.InsertGitSSHKeyParams) (database.GitSSHKey, error) { + return insert(q.log, q.auth, rbac.ResourceUserData.WithOwner(arg.UserID.String()).WithID(arg.UserID), q.db.InsertGitSSHKey)(ctx, arg) +} + +func (q *querier) UpdateGitSSHKey(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { + fetch := func(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { + return q.db.GetGitSSHKey(ctx, arg.UserID) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateGitSSHKey)(ctx, arg) +} + +func (q *querier) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) { + return fetch(q.log, q.auth, q.db.GetGitAuthLink)(ctx, arg) +} + +func (q *querier) InsertGitAuthLink(ctx context.Context, arg database.InsertGitAuthLinkParams) (database.GitAuthLink, error) { + return insert(q.log, q.auth, rbac.ResourceUserData.WithOwner(arg.UserID.String()).WithID(arg.UserID), q.db.InsertGitAuthLink)(ctx, arg) +} + +func (q *querier) UpdateGitAuthLink(ctx context.Context, arg database.UpdateGitAuthLinkParams) error { + fetch := func(ctx context.Context, arg database.UpdateGitAuthLinkParams) (database.GitAuthLink, error) { + return q.db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{UserID: arg.UserID, ProviderID: arg.ProviderID}) + } + return update(q.log, q.auth, fetch, q.db.UpdateGitAuthLink)(ctx, arg) +} + +func (q *querier) UpdateUserLink(ctx context.Context, arg database.UpdateUserLinkParams) (database.UserLink, error) { + fetch := func(ctx context.Context, arg database.UpdateUserLinkParams) (database.UserLink, error) { + return q.db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ + UserID: arg.UserID, + LoginType: arg.LoginType, + }) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateUserLink)(ctx, arg) +} + +// UpdateUserRoles updates the site roles of a user. The validation for this function include more than +// just a basic RBAC check. +func (q *querier) UpdateUserRoles(ctx context.Context, arg database.UpdateUserRolesParams) (database.User, error) { + // We need to fetch the user being updated to identify the change in roles. + // This requires read access on the user in question, since the user is + // returned from this function. + user, err := fetch(q.log, q.auth, q.db.GetUserByID)(ctx, arg.ID) + if err != nil { + return database.User{}, err + } + + // The member role is always implied. + impliedTypes := append(arg.GrantedRoles, rbac.RoleMember()) + // If the changeset is nothing, less rbac checks need to be done. + added, removed := rbac.ChangeRoleSet(user.RBACRoles, impliedTypes) + err = q.canAssignRoles(ctx, nil, added, removed) + if err != nil { + return database.User{}, err + } + + return q.db.UpdateUserRoles(ctx, arg) +} + +func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, _ rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { + // TODO Delete this function, all GetWorkspaces should be authorized. For now just call GetWorkspaces on the authz querier. + return q.GetWorkspaces(ctx, arg) +} + +func (q *querier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.GetWorkspacesRow, error) { + prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceWorkspace.Type) + if err != nil { + return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + return q.db.GetAuthorizedWorkspaces(ctx, arg, prep) +} + +func (q *querier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { + if _, err := q.GetWorkspaceByID(ctx, workspaceID); err != nil { + return database.WorkspaceBuild{}, err + } + return q.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspaceID) +} + +func (q *querier) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]database.WorkspaceBuild, error) { + // This is not ideal as not all builds will be returned if the workspace cannot be read. + // This should probably be handled differently? Maybe join workspace builds with workspace + // ownership properties and filter on that. + for _, id := range ids { + _, err := q.GetWorkspaceByID(ctx, id) + if err != nil { + return nil, err + } + } + + return q.db.GetLatestWorkspaceBuildsByWorkspaceIDs(ctx, ids) +} + +func (q *querier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { + if _, err := q.GetWorkspaceByAgentID(ctx, id); err != nil { + return database.WorkspaceAgent{}, err + } + return q.db.GetWorkspaceAgentByID(ctx, id) +} + +// GetWorkspaceAgentByInstanceID might want to be a system call? Unsure exactly, +// but this will fail. Need to figure out what AuthInstanceID is, and if it +// is essentially an auth token. But the caller using this function is not +// an authenticated user. So this authz check will fail. +func (q *querier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (database.WorkspaceAgent, error) { + agent, err := q.db.GetWorkspaceAgentByInstanceID(ctx, authInstanceID) + if err != nil { + return database.WorkspaceAgent{}, err + } + _, err = q.GetWorkspaceByAgentID(ctx, agent.ID) + if err != nil { + return database.WorkspaceAgent{}, err + } + return agent, nil +} + +// GetWorkspaceAgentsByResourceIDs is an all or nothing call. If the user cannot read +// a single agent, the entire call will fail. +func (q *querier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]database.WorkspaceAgent, error) { + if _, ok := ActorFromContext(ctx); !ok { + return nil, NoActorError + } + // TODO: Make this more efficient. This is annoying because all these resources should be owned by the same workspace. + // So the authz check should just be 1 check, but we cannot do that easily here. We should see if all callers can + // instead do something like GetWorkspaceAgentsByWorkspaceID. + agents, err := q.db.GetWorkspaceAgentsByResourceIDs(ctx, ids) + if err != nil { + return nil, err + } + + for _, a := range agents { + // Check if we can fetch the workspace by the agent ID. + _, err := q.GetWorkspaceByAgentID(ctx, a.ID) + if err == nil { + continue + } + if errors.Is(err, sql.ErrNoRows) && !errors.As(err, &NotAuthorizedError{}) { + // The agent is not tied to a workspace, likely from an orphaned template version. + // Just return it. + continue + } + // Otherwise, we cannot read the workspace, so we cannot read the agent. + return nil, logNotAuthorizedError(ctx, q.log, err) + } + return agents, nil +} + +func (q *querier) UpdateWorkspaceAgentLifecycleStateByID(ctx context.Context, arg database.UpdateWorkspaceAgentLifecycleStateByIDParams) error { + agent, err := q.db.GetWorkspaceAgentByID(ctx, arg.ID) + if err != nil { + return err + } + + workspace, err := q.db.GetWorkspaceByAgentID(ctx, agent.ID) + if err != nil { + return err + } + + if err := q.authorizeContext(ctx, rbac.ActionUpdate, workspace); err != nil { + return err + } + + return q.db.UpdateWorkspaceAgentLifecycleStateByID(ctx, arg) +} + +func (q *querier) UpdateWorkspaceAgentStartupByID(ctx context.Context, arg database.UpdateWorkspaceAgentStartupByIDParams) error { + agent, err := q.db.GetWorkspaceAgentByID(ctx, arg.ID) + if err != nil { + return err + } + + workspace, err := q.db.GetWorkspaceByAgentID(ctx, agent.ID) + if err != nil { + return err + } + + if err := q.authorizeContext(ctx, rbac.ActionUpdate, workspace); err != nil { + return err + } + + return q.db.UpdateWorkspaceAgentStartupByID(ctx, arg) +} + +func (q *querier) GetWorkspaceAppByAgentIDAndSlug(ctx context.Context, arg database.GetWorkspaceAppByAgentIDAndSlugParams) (database.WorkspaceApp, error) { + // If we can fetch the workspace, we can fetch the apps. Use the authorized call. + if _, err := q.GetWorkspaceByAgentID(ctx, arg.AgentID); err != nil { + return database.WorkspaceApp{}, err + } + + return q.db.GetWorkspaceAppByAgentIDAndSlug(ctx, arg) +} + +func (q *querier) GetWorkspaceAppsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.WorkspaceApp, error) { + if _, err := q.GetWorkspaceByAgentID(ctx, agentID); err != nil { + return nil, err + } + return q.db.GetWorkspaceAppsByAgentID(ctx, agentID) +} + +// GetWorkspaceAppsByAgentIDs is an all or nothing call. If the user cannot read a single app, the entire call will fail. +func (q *querier) GetWorkspaceAppsByAgentIDs(ctx context.Context, ids []uuid.UUID) ([]database.WorkspaceApp, error) { + // TODO: This should be reworked. All these apps are likely owned by the same workspace, so we should be able to + // do 1 authz call. We should refactor this to be GetWorkspaceAppsByWorkspaceID. + for _, id := range ids { + _, err := q.GetWorkspaceAgentByID(ctx, id) + if err != nil { + return nil, err + } + } + + return q.db.GetWorkspaceAppsByAgentIDs(ctx, ids) +} + +func (q *querier) GetWorkspaceBuildByID(ctx context.Context, buildID uuid.UUID) (database.WorkspaceBuild, error) { + build, err := q.db.GetWorkspaceBuildByID(ctx, buildID) + if err != nil { + return database.WorkspaceBuild{}, err + } + if _, err := q.GetWorkspaceByID(ctx, build.WorkspaceID); err != nil { + return database.WorkspaceBuild{}, err + } + return build, nil +} + +func (q *querier) GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UUID) (database.WorkspaceBuild, error) { + build, err := q.db.GetWorkspaceBuildByJobID(ctx, jobID) + if err != nil { + return database.WorkspaceBuild{}, err + } + // Authorized fetch + _, err = q.GetWorkspaceByID(ctx, build.WorkspaceID) + if err != nil { + return database.WorkspaceBuild{}, err + } + return build, nil +} + +func (q *querier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx context.Context, arg database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams) (database.WorkspaceBuild, error) { + if _, err := q.GetWorkspaceByID(ctx, arg.WorkspaceID); err != nil { + return database.WorkspaceBuild{}, err + } + return q.db.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, arg) +} + +func (q *querier) GetWorkspaceBuildParameters(ctx context.Context, workspaceBuildID uuid.UUID) ([]database.WorkspaceBuildParameter, error) { + // Authorized call to get the workspace build. If we can read the build, + // we can read the params. + _, err := q.GetWorkspaceBuildByID(ctx, workspaceBuildID) + if err != nil { + return nil, err + } + + return q.db.GetWorkspaceBuildParameters(ctx, workspaceBuildID) +} + +func (q *querier) GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg database.GetWorkspaceBuildsByWorkspaceIDParams) ([]database.WorkspaceBuild, error) { + if _, err := q.GetWorkspaceByID(ctx, arg.WorkspaceID); err != nil { + return nil, err + } + return q.db.GetWorkspaceBuildsByWorkspaceID(ctx, arg) +} + +func (q *querier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.Workspace, error) { + return fetch(q.log, q.auth, q.db.GetWorkspaceByAgentID)(ctx, agentID) +} + +func (q *querier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (database.Workspace, error) { + return fetch(q.log, q.auth, q.db.GetWorkspaceByID)(ctx, id) +} + +func (q *querier) GetWorkspaceByOwnerIDAndName(ctx context.Context, arg database.GetWorkspaceByOwnerIDAndNameParams) (database.Workspace, error) { + return fetch(q.log, q.auth, q.db.GetWorkspaceByOwnerIDAndName)(ctx, arg) +} + +func (q *querier) GetWorkspaceResourceByID(ctx context.Context, id uuid.UUID) (database.WorkspaceResource, error) { + // TODO: Optimize this + resource, err := q.db.GetWorkspaceResourceByID(ctx, id) + if err != nil { + return database.WorkspaceResource{}, err + } + + _, err = q.GetProvisionerJobByID(ctx, resource.JobID) + if err != nil { + return database.WorkspaceResource{}, err + } + + return resource, nil +} + +// GetWorkspaceResourceMetadataByResourceIDs is an all or nothing call. If a single resource is not authorized, then +// an error is returned. +func (q *querier) GetWorkspaceResourceMetadataByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]database.WorkspaceResourceMetadatum, error) { + // TODO: This is very inefficient. Since all these resources are likely asscoiated with the same workspace. + for _, id := range ids { + // If we can read the resource, we can read the metadata. + _, err := q.GetWorkspaceResourceByID(ctx, id) + if err != nil { + return nil, err + } + } + + return q.db.GetWorkspaceResourceMetadataByResourceIDs(ctx, ids) +} + +func (q *querier) GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { + job, err := q.db.GetProvisionerJobByID(ctx, jobID) + if err != nil { + return nil, err + } + var obj rbac.Objecter + switch job.Type { + case database.ProvisionerJobTypeTemplateVersionDryRun, database.ProvisionerJobTypeTemplateVersionImport: + // We don't need to do an authorized check, but this helper function + // handles the job type for us. + // TODO: Do not duplicate auth checks. + tv, err := authorizedTemplateVersionFromJob(ctx, q, job) + if err != nil { + return nil, err + } + if !tv.TemplateID.Valid { + // Orphaned template version + obj = tv.RBACObjectNoTemplate() + } else { + template, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID) + if err != nil { + return nil, err + } + obj = template.RBACObject() + } + case database.ProvisionerJobTypeWorkspaceBuild: + build, err := q.db.GetWorkspaceBuildByJobID(ctx, jobID) + if err != nil { + return nil, err + } + workspace, err := q.db.GetWorkspaceByID(ctx, build.WorkspaceID) + if err != nil { + return nil, err + } + obj = workspace + default: + return nil, xerrors.Errorf("unknown job type: %s", job.Type) + } + + if err := q.authorizeContext(ctx, rbac.ActionRead, obj); err != nil { + return nil, err + } + return q.db.GetWorkspaceResourcesByJobID(ctx, jobID) +} + +// GetWorkspaceResourcesByJobIDs is an all or nothing call. If a single resource is not authorized, then +// an error is returned. +func (q *querier) GetWorkspaceResourcesByJobIDs(ctx context.Context, ids []uuid.UUID) ([]database.WorkspaceResource, error) { + // TODO: This is very inefficient. Since all these resources are likely asscoiated with the same workspace. + for _, id := range ids { + // If we can read the resource, we can read the metadata. + _, err := q.GetProvisionerJobByID(ctx, id) + if err != nil { + return nil, err + } + } + + return q.db.GetWorkspaceResourcesByJobIDs(ctx, ids) +} + +func (q *querier) InsertWorkspace(ctx context.Context, arg database.InsertWorkspaceParams) (database.Workspace, error) { + obj := rbac.ResourceWorkspace.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID) + return insert(q.log, q.auth, obj, q.db.InsertWorkspace)(ctx, arg) +} + +func (q *querier) InsertWorkspaceBuild(ctx context.Context, arg database.InsertWorkspaceBuildParams) (database.WorkspaceBuild, error) { + w, err := q.db.GetWorkspaceByID(ctx, arg.WorkspaceID) + if err != nil { + return database.WorkspaceBuild{}, err + } + + var action rbac.Action = rbac.ActionUpdate + if arg.Transition == database.WorkspaceTransitionDelete { + action = rbac.ActionDelete + } + + if err = q.authorizeContext(ctx, action, w); err != nil { + return database.WorkspaceBuild{}, err + } + + return q.db.InsertWorkspaceBuild(ctx, arg) +} + +func (q *querier) InsertWorkspaceBuildParameters(ctx context.Context, arg database.InsertWorkspaceBuildParametersParams) error { + // TODO: Optimize this. We always have the workspace and build already fetched. + build, err := q.db.GetWorkspaceBuildByID(ctx, arg.WorkspaceBuildID) + if err != nil { + return err + } + + workspace, err := q.db.GetWorkspaceByID(ctx, build.WorkspaceID) + if err != nil { + return err + } + + err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace) + if err != nil { + return err + } + + return q.db.InsertWorkspaceBuildParameters(ctx, arg) +} + +func (q *querier) UpdateWorkspace(ctx context.Context, arg database.UpdateWorkspaceParams) (database.Workspace, error) { + fetch := func(ctx context.Context, arg database.UpdateWorkspaceParams) (database.Workspace, error) { + return q.db.GetWorkspaceByID(ctx, arg.ID) + } + return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateWorkspace)(ctx, arg) +} + +func (q *querier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg database.UpdateWorkspaceAgentConnectionByIDParams) error { + // TODO: This is a workspace agent operation. Should users be able to query this? + fetch := func(ctx context.Context, arg database.UpdateWorkspaceAgentConnectionByIDParams) (database.Workspace, error) { + return q.db.GetWorkspaceByAgentID(ctx, arg.ID) + } + return update(q.log, q.auth, fetch, q.db.UpdateWorkspaceAgentConnectionByID)(ctx, arg) +} + +func (q *querier) InsertAgentStat(ctx context.Context, arg database.InsertAgentStatParams) (database.AgentStat, error) { + // TODO: This is a workspace agent operation. Should users be able to query this? + // Not really sure what this is for. + workspace, err := q.db.GetWorkspaceByID(ctx, arg.WorkspaceID) + if err != nil { + return database.AgentStat{}, err + } + err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace) + if err != nil { + return database.AgentStat{}, err + } + return q.db.InsertAgentStat(ctx, arg) +} + +func (q *querier) UpdateWorkspaceAppHealthByID(ctx context.Context, arg database.UpdateWorkspaceAppHealthByIDParams) error { + // TODO: This is a workspace agent operation. Should users be able to query this? + workspace, err := q.db.GetWorkspaceByWorkspaceAppID(ctx, arg.ID) + if err != nil { + return err + } + + err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace.RBACObject()) + if err != nil { + return err + } + return q.db.UpdateWorkspaceAppHealthByID(ctx, arg) +} + +func (q *querier) UpdateWorkspaceAutostart(ctx context.Context, arg database.UpdateWorkspaceAutostartParams) error { + fetch := func(ctx context.Context, arg database.UpdateWorkspaceAutostartParams) (database.Workspace, error) { + return q.db.GetWorkspaceByID(ctx, arg.ID) + } + return update(q.log, q.auth, fetch, q.db.UpdateWorkspaceAutostart)(ctx, arg) +} + +func (q *querier) UpdateWorkspaceBuildByID(ctx context.Context, arg database.UpdateWorkspaceBuildByIDParams) (database.WorkspaceBuild, error) { + build, err := q.db.GetWorkspaceBuildByID(ctx, arg.ID) + if err != nil { + return database.WorkspaceBuild{}, err + } + + workspace, err := q.db.GetWorkspaceByID(ctx, build.WorkspaceID) + if err != nil { + return database.WorkspaceBuild{}, err + } + err = q.authorizeContext(ctx, rbac.ActionUpdate, workspace.RBACObject()) + if err != nil { + return database.WorkspaceBuild{}, err + } + + return q.db.UpdateWorkspaceBuildByID(ctx, arg) +} + +func (q *querier) SoftDeleteWorkspaceByID(ctx context.Context, id uuid.UUID) error { + return deleteQ(q.log, q.auth, q.db.GetWorkspaceByID, func(ctx context.Context, id uuid.UUID) error { + return q.db.UpdateWorkspaceDeletedByID(ctx, database.UpdateWorkspaceDeletedByIDParams{ + ID: id, + Deleted: true, + }) + })(ctx, id) +} + +// Deprecated: Use SoftDeleteWorkspaceByID +func (q *querier) UpdateWorkspaceDeletedByID(ctx context.Context, arg database.UpdateWorkspaceDeletedByIDParams) error { + // TODO deleteQ me, placeholder for database.Store + fetch := func(ctx context.Context, arg database.UpdateWorkspaceDeletedByIDParams) (database.Workspace, error) { + return q.db.GetWorkspaceByID(ctx, arg.ID) + } + // This function is always used to deleteQ. + return deleteQ(q.log, q.auth, fetch, q.db.UpdateWorkspaceDeletedByID)(ctx, arg) +} + +func (q *querier) UpdateWorkspaceLastUsedAt(ctx context.Context, arg database.UpdateWorkspaceLastUsedAtParams) error { + fetch := func(ctx context.Context, arg database.UpdateWorkspaceLastUsedAtParams) (database.Workspace, error) { + return q.db.GetWorkspaceByID(ctx, arg.ID) + } + return update(q.log, q.auth, fetch, q.db.UpdateWorkspaceLastUsedAt)(ctx, arg) +} + +func (q *querier) UpdateWorkspaceTTL(ctx context.Context, arg database.UpdateWorkspaceTTLParams) error { + fetch := func(ctx context.Context, arg database.UpdateWorkspaceTTLParams) (database.Workspace, error) { + return q.db.GetWorkspaceByID(ctx, arg.ID) + } + return update(q.log, q.auth, fetch, q.db.UpdateWorkspaceTTL)(ctx, arg) +} + +func (q *querier) GetWorkspaceByWorkspaceAppID(ctx context.Context, workspaceAppID uuid.UUID) (database.Workspace, error) { + return fetch(q.log, q.auth, q.db.GetWorkspaceByWorkspaceAppID)(ctx, workspaceAppID) +} + +func authorizedTemplateVersionFromJob(ctx context.Context, q *querier, job database.ProvisionerJob) (database.TemplateVersion, error) { + switch job.Type { + case database.ProvisionerJobTypeTemplateVersionDryRun: + // TODO: This is really unfortunate that we need to inspect the json + // payload. We should fix this. + tmp := struct { + TemplateVersionID uuid.UUID `json:"template_version_id"` + }{} + err := json.Unmarshal(job.Input, &tmp) + if err != nil { + return database.TemplateVersion{}, xerrors.Errorf("dry-run unmarshal: %w", err) + } + // Authorized call to get template version. + tv, err := q.GetTemplateVersionByID(ctx, tmp.TemplateVersionID) + if err != nil { + return database.TemplateVersion{}, err + } + return tv, nil + case database.ProvisionerJobTypeTemplateVersionImport: + // Authorized call to get template version. + tv, err := q.GetTemplateVersionByJobID(ctx, job.ID) + if err != nil { + return database.TemplateVersion{}, err + } + return tv, nil + default: + return database.TemplateVersion{}, xerrors.Errorf("unknown job type: %q", job.Type) + } +} diff --git a/coderd/database/dbauthz/querier_test.go b/coderd/database/dbauthz/querier_test.go new file mode 100644 index 0000000000000..96290f57745ab --- /dev/null +++ b/coderd/database/dbauthz/querier_test.go @@ -0,0 +1,1226 @@ +package dbauthz_test + +import ( + "context" + "encoding/json" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbgen" + "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/coderd/util/slice" +) + +func (s *MethodTestSuite) TestAPIKey() { + s.Run("DeleteAPIKeyByID", s.Subtest(func(db database.Store, check *expects) { + key, _ := dbgen.APIKey(s.T(), db, database.APIKey{}) + check.Args(key.ID).Asserts(key, rbac.ActionDelete).Returns() + })) + s.Run("GetAPIKeyByID", s.Subtest(func(db database.Store, check *expects) { + key, _ := dbgen.APIKey(s.T(), db, database.APIKey{}) + check.Args(key.ID).Asserts(key, rbac.ActionRead).Returns(key) + })) + s.Run("GetAPIKeysByLoginType", s.Subtest(func(db database.Store, check *expects) { + a, _ := dbgen.APIKey(s.T(), db, database.APIKey{LoginType: database.LoginTypePassword}) + b, _ := dbgen.APIKey(s.T(), db, database.APIKey{LoginType: database.LoginTypePassword}) + _, _ = dbgen.APIKey(s.T(), db, database.APIKey{LoginType: database.LoginTypeGithub}) + check.Args(database.LoginTypePassword). + Asserts(a, rbac.ActionRead, b, rbac.ActionRead). + Returns(slice.New(a, b)) + })) + s.Run("GetAPIKeysLastUsedAfter", s.Subtest(func(db database.Store, check *expects) { + a, _ := dbgen.APIKey(s.T(), db, database.APIKey{LastUsed: time.Now().Add(time.Hour)}) + b, _ := dbgen.APIKey(s.T(), db, database.APIKey{LastUsed: time.Now().Add(time.Hour)}) + _, _ = dbgen.APIKey(s.T(), db, database.APIKey{LastUsed: time.Now().Add(-time.Hour)}) + check.Args(time.Now()). + Asserts(a, rbac.ActionRead, b, rbac.ActionRead). + Returns(slice.New(a, b)) + })) + s.Run("InsertAPIKey", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.InsertAPIKeyParams{ + UserID: u.ID, + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + }).Asserts(rbac.ResourceAPIKey.WithOwner(u.ID.String()), rbac.ActionCreate) + })) + s.Run("UpdateAPIKeyByID", s.Subtest(func(db database.Store, check *expects) { + a, _ := dbgen.APIKey(s.T(), db, database.APIKey{}) + check.Args(database.UpdateAPIKeyByIDParams{ + ID: a.ID, + }).Asserts(a, rbac.ActionUpdate).Returns() + })) +} + +func (s *MethodTestSuite) TestAuditLogs() { + s.Run("InsertAuditLog", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertAuditLogParams{ + ResourceType: database.ResourceTypeOrganization, + Action: database.AuditActionCreate, + }).Asserts(rbac.ResourceAuditLog, rbac.ActionCreate) + })) + s.Run("GetAuditLogsOffset", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.AuditLog(s.T(), db, database.AuditLog{}) + _ = dbgen.AuditLog(s.T(), db, database.AuditLog{}) + check.Args(database.GetAuditLogsOffsetParams{ + Limit: 10, + }).Asserts(rbac.ResourceAuditLog, rbac.ActionRead) + })) +} + +func (s *MethodTestSuite) TestFile() { + s.Run("GetFileByHashAndCreator", s.Subtest(func(db database.Store, check *expects) { + f := dbgen.File(s.T(), db, database.File{}) + check.Args(database.GetFileByHashAndCreatorParams{ + Hash: f.Hash, + CreatedBy: f.CreatedBy, + }).Asserts(f, rbac.ActionRead).Returns(f) + })) + s.Run("GetFileByID", s.Subtest(func(db database.Store, check *expects) { + f := dbgen.File(s.T(), db, database.File{}) + check.Args(f.ID).Asserts(f, rbac.ActionRead).Returns(f) + })) + s.Run("InsertFile", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.InsertFileParams{ + CreatedBy: u.ID, + }).Asserts(rbac.ResourceFile.WithOwner(u.ID.String()), rbac.ActionCreate) + })) +} + +func (s *MethodTestSuite) TestGroup() { + s.Run("DeleteGroupByID", s.Subtest(func(db database.Store, check *expects) { + g := dbgen.Group(s.T(), db, database.Group{}) + check.Args(g.ID).Asserts(g, rbac.ActionDelete).Returns() + })) + s.Run("DeleteGroupMemberFromGroup", s.Subtest(func(db database.Store, check *expects) { + g := dbgen.Group(s.T(), db, database.Group{}) + m := dbgen.GroupMember(s.T(), db, database.GroupMember{ + GroupID: g.ID, + }) + check.Args(database.DeleteGroupMemberFromGroupParams{ + UserID: m.UserID, + GroupID: g.ID, + }).Asserts(g, rbac.ActionUpdate).Returns() + })) + s.Run("GetGroupByID", s.Subtest(func(db database.Store, check *expects) { + g := dbgen.Group(s.T(), db, database.Group{}) + check.Args(g.ID).Asserts(g, rbac.ActionRead).Returns(g) + })) + s.Run("GetGroupByOrgAndName", s.Subtest(func(db database.Store, check *expects) { + g := dbgen.Group(s.T(), db, database.Group{}) + check.Args(database.GetGroupByOrgAndNameParams{ + OrganizationID: g.OrganizationID, + Name: g.Name, + }).Asserts(g, rbac.ActionRead).Returns(g) + })) + s.Run("GetGroupMembers", s.Subtest(func(db database.Store, check *expects) { + g := dbgen.Group(s.T(), db, database.Group{}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMember{}) + check.Args(g.ID).Asserts(g, rbac.ActionRead) + })) + s.Run("InsertAllUsersGroup", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + check.Args(o.ID).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionCreate) + })) + s.Run("InsertGroup", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + check.Args(database.InsertGroupParams{ + OrganizationID: o.ID, + Name: "test", + }).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionCreate) + })) + s.Run("InsertGroupMember", s.Subtest(func(db database.Store, check *expects) { + g := dbgen.Group(s.T(), db, database.Group{}) + check.Args(database.InsertGroupMemberParams{ + UserID: uuid.New(), + GroupID: g.ID, + }).Asserts(g, rbac.ActionUpdate).Returns() + })) + s.Run("InsertUserGroupsByName", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u1 := dbgen.User(s.T(), db, database.User{}) + g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g1.ID, UserID: u1.ID}) + check.Args(database.InsertUserGroupsByNameParams{ + OrganizationID: o.ID, + UserID: u1.ID, + GroupNames: slice.New(g1.Name, g2.Name), + }).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionUpdate).Returns() + })) + s.Run("DeleteGroupMembersByOrgAndUser", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u1 := dbgen.User(s.T(), db, database.User{}) + g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g1.ID, UserID: u1.ID}) + _ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g2.ID, UserID: u1.ID}) + check.Args(database.DeleteGroupMembersByOrgAndUserParams{ + OrganizationID: o.ID, + UserID: u1.ID, + }).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionUpdate).Returns() + })) + s.Run("UpdateGroupByID", s.Subtest(func(db database.Store, check *expects) { + g := dbgen.Group(s.T(), db, database.Group{}) + check.Args(database.UpdateGroupByIDParams{ + ID: g.ID, + }).Asserts(g, rbac.ActionUpdate) + })) +} + +func (s *MethodTestSuite) TestProvsionerJob() { + s.Run("Build/GetProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeWorkspaceBuild, + }) + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID}) + check.Args(j.ID).Asserts(w, rbac.ActionRead).Returns(j) + })) + s.Run("TemplateVersion/GetProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) { + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeTemplateVersionImport, + }) + tpl := dbgen.Template(s.T(), db, database.Template{}) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, + JobID: j.ID, + }) + check.Args(j.ID).Asserts(v.RBACObject(tpl), rbac.ActionRead).Returns(j) + })) + s.Run("TemplateVersionDryRun/GetProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{}) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, + }) + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeTemplateVersionDryRun, + Input: must(json.Marshal(struct { + TemplateVersionID uuid.UUID `json:"template_version_id"` + }{TemplateVersionID: v.ID})), + }) + check.Args(j.ID).Asserts(v.RBACObject(tpl), rbac.ActionRead).Returns(j) + })) + s.Run("Build/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{AllowUserCancelWorkspaceJobs: true}) + w := dbgen.Workspace(s.T(), db, database.Workspace{TemplateID: tpl.ID}) + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeWorkspaceBuild, + }) + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID}) + check.Args(database.UpdateProvisionerJobWithCancelByIDParams{ID: j.ID}).Asserts(w, rbac.ActionUpdate).Returns() + })) + s.Run("BuildFalseCancel/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{AllowUserCancelWorkspaceJobs: false}) + w := dbgen.Workspace(s.T(), db, database.Workspace{TemplateID: tpl.ID}) + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeWorkspaceBuild, + }) + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID}) + check.Args(database.UpdateProvisionerJobWithCancelByIDParams{ID: j.ID}).Asserts(w, rbac.ActionUpdate).Returns() + })) + s.Run("TemplateVersion/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) { + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeTemplateVersionImport, + }) + tpl := dbgen.Template(s.T(), db, database.Template{}) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, + JobID: j.ID, + }) + check.Args(database.UpdateProvisionerJobWithCancelByIDParams{ID: j.ID}). + Asserts(v.RBACObject(tpl), []rbac.Action{rbac.ActionRead, rbac.ActionUpdate}).Returns() + })) + s.Run("TemplateVersionNoTemplate/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) { + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeTemplateVersionImport, + }) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: uuid.Nil, Valid: false}, + JobID: j.ID, + }) + check.Args(database.UpdateProvisionerJobWithCancelByIDParams{ID: j.ID}). + Asserts(v.RBACObjectNoTemplate(), []rbac.Action{rbac.ActionRead, rbac.ActionUpdate}).Returns() + })) + s.Run("TemplateVersionDryRun/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{}) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, + }) + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeTemplateVersionDryRun, + Input: must(json.Marshal(struct { + TemplateVersionID uuid.UUID `json:"template_version_id"` + }{TemplateVersionID: v.ID})), + }) + check.Args(database.UpdateProvisionerJobWithCancelByIDParams{ID: j.ID}). + Asserts(v.RBACObject(tpl), []rbac.Action{rbac.ActionRead, rbac.ActionUpdate}).Returns() + })) + s.Run("GetProvisionerJobsByIDs", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) + b := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) + check.Args([]uuid.UUID{a.ID, b.ID}).Asserts().Returns(slice.New(a, b)) + })) + s.Run("GetProvisionerLogsByIDBetween", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeWorkspaceBuild, + }) + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID}) + check.Args(database.GetProvisionerLogsByIDBetweenParams{ + JobID: j.ID, + }).Asserts(w, rbac.ActionRead).Returns([]database.ProvisionerJobLog{}) + })) +} + +func (s *MethodTestSuite) TestLicense() { + s.Run("GetLicenses", s.Subtest(func(db database.Store, check *expects) { + l, err := db.InsertLicense(context.Background(), database.InsertLicenseParams{ + UUID: uuid.New(), + }) + require.NoError(s.T(), err) + check.Args().Asserts(l, rbac.ActionRead). + Returns([]database.License{l}) + })) + s.Run("InsertLicense", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertLicenseParams{}). + Asserts(rbac.ResourceLicense, rbac.ActionCreate) + })) + s.Run("InsertOrUpdateLogoURL", s.Subtest(func(db database.Store, check *expects) { + check.Args("value").Asserts(rbac.ResourceDeploymentConfig, rbac.ActionCreate) + })) + s.Run("InsertOrUpdateServiceBanner", s.Subtest(func(db database.Store, check *expects) { + check.Args("value").Asserts(rbac.ResourceDeploymentConfig, rbac.ActionCreate) + })) + s.Run("GetLicenseByID", s.Subtest(func(db database.Store, check *expects) { + l, err := db.InsertLicense(context.Background(), database.InsertLicenseParams{ + UUID: uuid.New(), + }) + require.NoError(s.T(), err) + check.Args(l.ID).Asserts(l, rbac.ActionRead).Returns(l) + })) + s.Run("DeleteLicense", s.Subtest(func(db database.Store, check *expects) { + l, err := db.InsertLicense(context.Background(), database.InsertLicenseParams{ + UUID: uuid.New(), + }) + require.NoError(s.T(), err) + check.Args(l.ID).Asserts(l, rbac.ActionDelete) + })) + s.Run("GetDeploymentID", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts().Returns("") + })) + s.Run("GetLogoURL", s.Subtest(func(db database.Store, check *expects) { + err := db.InsertOrUpdateLogoURL(context.Background(), "value") + require.NoError(s.T(), err) + check.Args().Asserts().Returns("value") + })) + s.Run("GetServiceBanner", s.Subtest(func(db database.Store, check *expects) { + err := db.InsertOrUpdateServiceBanner(context.Background(), "value") + require.NoError(s.T(), err) + check.Args().Asserts().Returns("value") + })) +} + +func (s *MethodTestSuite) TestOrganization() { + s.Run("GetGroupsByOrganizationID", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + a := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + b := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID}) + check.Args(o.ID).Asserts(a, rbac.ActionRead, b, rbac.ActionRead). + Returns([]database.Group{a, b}) + })) + s.Run("GetOrganizationByID", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + check.Args(o.ID).Asserts(o, rbac.ActionRead).Returns(o) + })) + s.Run("GetOrganizationByName", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + check.Args(o.Name).Asserts(o, rbac.ActionRead).Returns(o) + })) + s.Run("GetOrganizationIDsByMemberIDs", s.Subtest(func(db database.Store, check *expects) { + oa := dbgen.Organization(s.T(), db, database.Organization{}) + ob := dbgen.Organization(s.T(), db, database.Organization{}) + ma := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{OrganizationID: oa.ID}) + mb := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{OrganizationID: ob.ID}) + check.Args([]uuid.UUID{ma.UserID, mb.UserID}). + Asserts(rbac.ResourceUser.WithID(ma.UserID), rbac.ActionRead, rbac.ResourceUser.WithID(mb.UserID), rbac.ActionRead) + })) + s.Run("GetOrganizationMemberByUserID", s.Subtest(func(db database.Store, check *expects) { + mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{}) + check.Args(database.GetOrganizationMemberByUserIDParams{ + OrganizationID: mem.OrganizationID, + UserID: mem.UserID, + }).Asserts(mem, rbac.ActionRead).Returns(mem) + })) + s.Run("GetOrganizationMembershipsByUserID", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + a := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID}) + b := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID}) + check.Args(u.ID).Asserts(a, rbac.ActionRead, b, rbac.ActionRead).Returns(slice.New(a, b)) + })) + s.Run("GetOrganizations", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.Organization(s.T(), db, database.Organization{}) + b := dbgen.Organization(s.T(), db, database.Organization{}) + check.Args().Asserts(a, rbac.ActionRead, b, rbac.ActionRead).Returns(slice.New(a, b)) + })) + s.Run("GetOrganizationsByUserID", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + a := dbgen.Organization(s.T(), db, database.Organization{}) + _ = dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID, OrganizationID: a.ID}) + b := dbgen.Organization(s.T(), db, database.Organization{}) + _ = dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID, OrganizationID: b.ID}) + check.Args(u.ID).Asserts(a, rbac.ActionRead, b, rbac.ActionRead).Returns(slice.New(a, b)) + })) + s.Run("InsertOrganization", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertOrganizationParams{ + ID: uuid.New(), + Name: "random", + }).Asserts(rbac.ResourceOrganization, rbac.ActionCreate) + })) + s.Run("InsertOrganizationMember", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u := dbgen.User(s.T(), db, database.User{}) + + check.Args(database.InsertOrganizationMemberParams{ + OrganizationID: o.ID, + UserID: u.ID, + Roles: []string{rbac.RoleOrgAdmin(o.ID)}, + }).Asserts( + rbac.ResourceRoleAssignment.InOrg(o.ID), rbac.ActionCreate, + rbac.ResourceOrganizationMember.InOrg(o.ID).WithID(u.ID), rbac.ActionCreate) + })) + s.Run("UpdateMemberRoles", s.Subtest(func(db database.Store, check *expects) { + o := dbgen.Organization(s.T(), db, database.Organization{}) + u := dbgen.User(s.T(), db, database.User{}) + mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{ + OrganizationID: o.ID, + UserID: u.ID, + Roles: []string{rbac.RoleOrgAdmin(o.ID)}, + }) + out := mem + out.Roles = []string{} + + check.Args(database.UpdateMemberRolesParams{ + GrantedRoles: []string{}, + UserID: u.ID, + OrgID: o.ID, + }).Asserts( + mem, rbac.ActionRead, + rbac.ResourceRoleAssignment.InOrg(o.ID), rbac.ActionCreate, // org-mem + rbac.ResourceRoleAssignment.InOrg(o.ID), rbac.ActionDelete, // org-admin + ).Returns(out) + })) +} + +func (s *MethodTestSuite) TestParameters() { + s.Run("Workspace/InsertParameterValue", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.InsertParameterValueParams{ + ScopeID: w.ID, + Scope: database.ParameterScopeWorkspace, + SourceScheme: database.ParameterSourceSchemeNone, + DestinationScheme: database.ParameterDestinationSchemeNone, + }).Asserts(w, rbac.ActionUpdate) + })) + s.Run("TemplateVersionNoTemplate/InsertParameterValue", s.Subtest(func(db database.Store, check *expects) { + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{JobID: j.ID, TemplateID: uuid.NullUUID{Valid: false}}) + check.Args(database.InsertParameterValueParams{ + ScopeID: j.ID, + Scope: database.ParameterScopeImportJob, + SourceScheme: database.ParameterSourceSchemeNone, + DestinationScheme: database.ParameterDestinationSchemeNone, + }).Asserts(v.RBACObjectNoTemplate(), rbac.ActionUpdate) + })) + s.Run("TemplateVersionTemplate/InsertParameterValue", s.Subtest(func(db database.Store, check *expects) { + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) + tpl := dbgen.Template(s.T(), db, database.Template{}) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{JobID: j.ID, + TemplateID: uuid.NullUUID{ + UUID: tpl.ID, + Valid: true, + }}, + ) + check.Args(database.InsertParameterValueParams{ + ScopeID: j.ID, + Scope: database.ParameterScopeImportJob, + SourceScheme: database.ParameterSourceSchemeNone, + DestinationScheme: database.ParameterDestinationSchemeNone, + }).Asserts(v.RBACObject(tpl), rbac.ActionUpdate) + })) + s.Run("Template/InsertParameterValue", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{}) + check.Args(database.InsertParameterValueParams{ + ScopeID: tpl.ID, + Scope: database.ParameterScopeTemplate, + SourceScheme: database.ParameterSourceSchemeNone, + DestinationScheme: database.ParameterDestinationSchemeNone, + }).Asserts(tpl, rbac.ActionUpdate) + })) + s.Run("Template/ParameterValue", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{}) + pv := dbgen.ParameterValue(s.T(), db, database.ParameterValue{ + ScopeID: tpl.ID, + Scope: database.ParameterScopeTemplate, + }) + check.Args(pv.ID).Asserts(tpl, rbac.ActionRead).Returns(pv) + })) + s.Run("ParameterValues", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{}) + a := dbgen.ParameterValue(s.T(), db, database.ParameterValue{ + ScopeID: tpl.ID, + Scope: database.ParameterScopeTemplate, + }) + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + b := dbgen.ParameterValue(s.T(), db, database.ParameterValue{ + ScopeID: w.ID, + Scope: database.ParameterScopeWorkspace, + }) + check.Args(database.ParameterValuesParams{ + IDs: []uuid.UUID{a.ID, b.ID}, + }).Asserts(tpl, rbac.ActionRead, w, rbac.ActionRead).Returns(slice.New(a, b)) + })) + s.Run("GetParameterSchemasByJobID", s.Subtest(func(db database.Store, check *expects) { + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) + tpl := dbgen.Template(s.T(), db, database.Template{}) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{JobID: j.ID, TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}}) + a := dbgen.ParameterSchema(s.T(), db, database.ParameterSchema{JobID: j.ID}) + check.Args(j.ID).Asserts(tv.RBACObject(tpl), rbac.ActionRead). + Returns([]database.ParameterSchema{a}) + })) + s.Run("Workspace/GetParameterValueByScopeAndName", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + v := dbgen.ParameterValue(s.T(), db, database.ParameterValue{ + Scope: database.ParameterScopeWorkspace, + ScopeID: w.ID, + }) + check.Args(database.GetParameterValueByScopeAndNameParams{ + Scope: v.Scope, + ScopeID: v.ScopeID, + Name: v.Name, + }).Asserts(w, rbac.ActionRead).Returns(v) + })) + s.Run("Workspace/DeleteParameterValueByID", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + v := dbgen.ParameterValue(s.T(), db, database.ParameterValue{ + Scope: database.ParameterScopeWorkspace, + ScopeID: w.ID, + }) + check.Args(v.ID).Asserts(w, rbac.ActionUpdate).Returns() + })) +} + +func (s *MethodTestSuite) TestTemplate() { + s.Run("GetPreviousTemplateVersion", s.Subtest(func(db database.Store, check *expects) { + tvid := uuid.New() + now := time.Now() + o1 := dbgen.Organization(s.T(), db, database.Organization{}) + t1 := dbgen.Template(s.T(), db, database.Template{ + OrganizationID: o1.ID, + ActiveVersionID: tvid, + }) + _ = dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + CreatedAt: now.Add(-time.Hour), + ID: tvid, + Name: t1.Name, + OrganizationID: o1.ID, + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}}) + b := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + CreatedAt: now.Add(-2 * time.Hour), + Name: t1.Name, + OrganizationID: o1.ID, + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}}) + check.Args(database.GetPreviousTemplateVersionParams{ + Name: t1.Name, + OrganizationID: o1.ID, + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }).Asserts(t1, rbac.ActionRead).Returns(b) + })) + s.Run("GetTemplateAverageBuildTime", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(database.GetTemplateAverageBuildTimeParams{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }).Asserts(t1, rbac.ActionRead) + })) + s.Run("GetTemplateByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(t1.ID).Asserts(t1, rbac.ActionRead).Returns(t1) + })) + s.Run("GetTemplateByOrganizationAndName", s.Subtest(func(db database.Store, check *expects) { + o1 := dbgen.Organization(s.T(), db, database.Organization{}) + t1 := dbgen.Template(s.T(), db, database.Template{ + OrganizationID: o1.ID, + }) + check.Args(database.GetTemplateByOrganizationAndNameParams{ + Name: t1.Name, + OrganizationID: o1.ID, + }).Asserts(t1, rbac.ActionRead).Returns(t1) + })) + s.Run("GetTemplateDAUs", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(t1.ID).Asserts(t1, rbac.ActionRead) + })) + s.Run("GetTemplateVersionByJobID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + check.Args(tv.JobID).Asserts(t1, rbac.ActionRead).Returns(tv) + })) + s.Run("GetTemplateVersionByTemplateIDAndName", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + check.Args(database.GetTemplateVersionByTemplateIDAndNameParams{ + Name: tv.Name, + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }).Asserts(t1, rbac.ActionRead).Returns(tv) + })) + s.Run("GetTemplateVersionParameters", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + check.Args(tv.ID).Asserts(t1, rbac.ActionRead).Returns([]database.TemplateVersionParameter{}) + })) + s.Run("GetTemplateGroupRoles", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(t1.ID).Asserts(t1, rbac.ActionRead) + })) + s.Run("GetTemplateUserRoles", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(t1.ID).Asserts(t1, rbac.ActionRead) + })) + s.Run("GetTemplateVersionByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + check.Args(tv.ID).Asserts(t1, rbac.ActionRead).Returns(tv) + })) + s.Run("GetTemplateVersionsByIDs", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + t2 := dbgen.Template(s.T(), db, database.Template{}) + tv1 := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + tv2 := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t2.ID, Valid: true}, + }) + tv3 := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t2.ID, Valid: true}, + }) + check.Args([]uuid.UUID{tv1.ID, tv2.ID, tv3.ID}). + Asserts(t1, rbac.ActionRead, t2, rbac.ActionRead). + Returns(slice.New(tv1, tv2, tv3)) + })) + s.Run("GetTemplateVersionsByTemplateID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + a := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + b := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + check.Args(database.GetTemplateVersionsByTemplateIDParams{ + TemplateID: t1.ID, + }).Asserts(t1, rbac.ActionRead). + Returns(slice.New(a, b)) + })) + s.Run("GetTemplateVersionsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { + now := time.Now() + t1 := dbgen.Template(s.T(), db, database.Template{}) + _ = dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + CreatedAt: now.Add(-time.Hour), + }) + _ = dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + CreatedAt: now.Add(-2 * time.Hour), + }) + check.Args(now.Add(-time.Hour)).Asserts(rbac.ResourceTemplate.All(), rbac.ActionRead) + })) + s.Run("GetTemplatesWithFilter", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.Template(s.T(), db, database.Template{}) + // No asserts because SQLFilter. + check.Args(database.GetTemplatesWithFilterParams{}). + Asserts().Returns(slice.New(a)) + })) + s.Run("GetAuthorizedTemplates", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.Template(s.T(), db, database.Template{}) + // No asserts because SQLFilter. + check.Args(database.GetTemplatesWithFilterParams{}, emptyPreparedAuthorized{}). + Asserts(). + Returns(slice.New(a)) + })) + s.Run("InsertTemplate", s.Subtest(func(db database.Store, check *expects) { + orgID := uuid.New() + check.Args(database.InsertTemplateParams{ + Provisioner: "echo", + OrganizationID: orgID, + }).Asserts(rbac.ResourceTemplate.InOrg(orgID), rbac.ActionCreate) + })) + s.Run("InsertTemplateVersion", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(database.InsertTemplateVersionParams{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + OrganizationID: t1.OrganizationID, + }).Asserts(t1, rbac.ActionRead, t1, rbac.ActionCreate) + })) + s.Run("SoftDeleteTemplateByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(t1.ID).Asserts(t1, rbac.ActionDelete) + })) + s.Run("UpdateTemplateACLByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(database.UpdateTemplateACLByIDParams{ + ID: t1.ID, + }).Asserts(t1, rbac.ActionCreate).Returns(t1) + })) + s.Run("UpdateTemplateActiveVersionByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{ + ActiveVersionID: uuid.New(), + }) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + ID: t1.ActiveVersionID, + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + check.Args(database.UpdateTemplateActiveVersionByIDParams{ + ID: t1.ID, + ActiveVersionID: tv.ID, + }).Asserts(t1, rbac.ActionUpdate).Returns() + })) + s.Run("UpdateTemplateDeletedByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(database.UpdateTemplateDeletedByIDParams{ + ID: t1.ID, + Deleted: true, + }).Asserts(t1, rbac.ActionDelete).Returns() + })) + s.Run("UpdateTemplateMetaByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + check.Args(database.UpdateTemplateMetaByIDParams{ + ID: t1.ID, + }).Asserts(t1, rbac.ActionUpdate) + })) + s.Run("UpdateTemplateVersionByID", s.Subtest(func(db database.Store, check *expects) { + t1 := dbgen.Template(s.T(), db, database.Template{}) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }) + check.Args(database.UpdateTemplateVersionByIDParams{ + ID: tv.ID, + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + }).Asserts(t1, rbac.ActionUpdate).Returns() + })) + s.Run("UpdateTemplateVersionDescriptionByJobID", s.Subtest(func(db database.Store, check *expects) { + jobID := uuid.New() + t1 := dbgen.Template(s.T(), db, database.Template{}) + _ = dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: t1.ID, Valid: true}, + JobID: jobID, + }) + check.Args(database.UpdateTemplateVersionDescriptionByJobIDParams{ + JobID: jobID, + Readme: "foo", + }).Asserts(t1, rbac.ActionUpdate).Returns() + })) +} + +func (s *MethodTestSuite) TestUser() { + s.Run("DeleteAPIKeysByUserID", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(u.ID).Asserts(rbac.ResourceAPIKey.WithOwner(u.ID.String()), rbac.ActionDelete).Returns() + })) + s.Run("GetQuotaAllowanceForUser", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(u.ID).Asserts(u, rbac.ActionRead).Returns(int64(0)) + })) + s.Run("GetQuotaConsumedForUser", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(u.ID).Asserts(u, rbac.ActionRead).Returns(int64(0)) + })) + s.Run("GetUserByEmailOrUsername", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.GetUserByEmailOrUsernameParams{ + Username: u.Username, + Email: u.Email, + }).Asserts(u, rbac.ActionRead).Returns(u) + })) + s.Run("GetUserByID", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(u.ID).Asserts(u, rbac.ActionRead).Returns(u) + })) + s.Run("GetAuthorizedUserCount", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.User(s.T(), db, database.User{}) + check.Args(database.GetFilteredUserCountParams{}, emptyPreparedAuthorized{}).Asserts().Returns(int64(1)) + })) + s.Run("GetFilteredUserCount", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.User(s.T(), db, database.User{}) + check.Args(database.GetFilteredUserCountParams{}).Asserts().Returns(int64(1)) + })) + s.Run("GetUsers", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.User(s.T(), db, database.User{CreatedAt: database.Now().Add(-time.Hour)}) + b := dbgen.User(s.T(), db, database.User{CreatedAt: database.Now()}) + check.Args(database.GetUsersParams{}). + Asserts(a, rbac.ActionRead, b, rbac.ActionRead) + })) + s.Run("GetUsersWithCount", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.User(s.T(), db, database.User{CreatedAt: database.Now().Add(-time.Hour)}) + b := dbgen.User(s.T(), db, database.User{CreatedAt: database.Now()}) + check.Args(database.GetUsersParams{}).Asserts(a, rbac.ActionRead, b, rbac.ActionRead) + })) + s.Run("GetUsersByIDs", s.Subtest(func(db database.Store, check *expects) { + a := dbgen.User(s.T(), db, database.User{CreatedAt: database.Now().Add(-time.Hour)}) + b := dbgen.User(s.T(), db, database.User{CreatedAt: database.Now()}) + check.Args([]uuid.UUID{a.ID, b.ID}). + Asserts(a, rbac.ActionRead, b, rbac.ActionRead). + Returns(slice.New(a, b)) + })) + s.Run("InsertUser", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertUserParams{ + ID: uuid.New(), + LoginType: database.LoginTypePassword, + }).Asserts(rbac.ResourceRoleAssignment, rbac.ActionCreate, rbac.ResourceUser, rbac.ActionCreate) + })) + s.Run("InsertUserLink", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.InsertUserLinkParams{ + UserID: u.ID, + LoginType: database.LoginTypeOIDC, + }).Asserts(u, rbac.ActionUpdate) + })) + s.Run("SoftDeleteUserByID", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(u.ID).Asserts(u, rbac.ActionDelete).Returns() + })) + s.Run("UpdateUserDeletedByID", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{Deleted: true}) + check.Args(database.UpdateUserDeletedByIDParams{ + ID: u.ID, + Deleted: true, + }).Asserts(u, rbac.ActionDelete).Returns() + })) + s.Run("UpdateUserHashedPassword", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.UpdateUserHashedPasswordParams{ + ID: u.ID, + }).Asserts(u.UserDataRBACObject(), rbac.ActionUpdate).Returns() + })) + s.Run("UpdateUserLastSeenAt", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.UpdateUserLastSeenAtParams{ + ID: u.ID, + UpdatedAt: u.UpdatedAt, + LastSeenAt: u.LastSeenAt, + }).Asserts(u, rbac.ActionUpdate).Returns(u) + })) + s.Run("UpdateUserProfile", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.UpdateUserProfileParams{ + ID: u.ID, + Email: u.Email, + Username: u.Username, + UpdatedAt: u.UpdatedAt, + }).Asserts(u.UserDataRBACObject(), rbac.ActionUpdate).Returns(u) + })) + s.Run("UpdateUserStatus", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.UpdateUserStatusParams{ + ID: u.ID, + Status: u.Status, + UpdatedAt: u.UpdatedAt, + }).Asserts(u, rbac.ActionUpdate).Returns(u) + })) + s.Run("DeleteGitSSHKey", s.Subtest(func(db database.Store, check *expects) { + key := dbgen.GitSSHKey(s.T(), db, database.GitSSHKey{}) + check.Args(key.UserID).Asserts(key, rbac.ActionDelete).Returns() + })) + s.Run("GetGitSSHKey", s.Subtest(func(db database.Store, check *expects) { + key := dbgen.GitSSHKey(s.T(), db, database.GitSSHKey{}) + check.Args(key.UserID).Asserts(key, rbac.ActionRead).Returns(key) + })) + s.Run("InsertGitSSHKey", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.InsertGitSSHKeyParams{ + UserID: u.ID, + }).Asserts(rbac.ResourceUserData.WithID(u.ID).WithOwner(u.ID.String()), rbac.ActionCreate) + })) + s.Run("UpdateGitSSHKey", s.Subtest(func(db database.Store, check *expects) { + key := dbgen.GitSSHKey(s.T(), db, database.GitSSHKey{}) + check.Args(database.UpdateGitSSHKeyParams{ + UserID: key.UserID, + UpdatedAt: key.UpdatedAt, + }).Asserts(key, rbac.ActionUpdate).Returns(key) + })) + s.Run("GetGitAuthLink", s.Subtest(func(db database.Store, check *expects) { + link := dbgen.GitAuthLink(s.T(), db, database.GitAuthLink{}) + check.Args(database.GetGitAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + }).Asserts(link, rbac.ActionRead).Returns(link) + })) + s.Run("InsertGitAuthLink", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(database.InsertGitAuthLinkParams{ + ProviderID: uuid.NewString(), + UserID: u.ID, + }).Asserts(rbac.ResourceUserData.WithOwner(u.ID.String()).WithID(u.ID), rbac.ActionCreate) + })) + s.Run("UpdateGitAuthLink", s.Subtest(func(db database.Store, check *expects) { + link := dbgen.GitAuthLink(s.T(), db, database.GitAuthLink{}) + check.Args(database.UpdateGitAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + }).Asserts(link, rbac.ActionUpdate).Returns() + })) + s.Run("UpdateUserLink", s.Subtest(func(db database.Store, check *expects) { + link := dbgen.UserLink(s.T(), db, database.UserLink{}) + check.Args(database.UpdateUserLinkParams{ + OAuthAccessToken: link.OAuthAccessToken, + OAuthRefreshToken: link.OAuthRefreshToken, + OAuthExpiry: link.OAuthExpiry, + UserID: link.UserID, + LoginType: link.LoginType, + }).Asserts(link, rbac.ActionUpdate).Returns(link) + })) + s.Run("UpdateUserRoles", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{RBACRoles: []string{rbac.RoleTemplateAdmin()}}) + o := u + o.RBACRoles = []string{rbac.RoleUserAdmin()} + check.Args(database.UpdateUserRolesParams{ + GrantedRoles: []string{rbac.RoleUserAdmin()}, + ID: u.ID, + }).Asserts( + u, rbac.ActionRead, + rbac.ResourceRoleAssignment, rbac.ActionCreate, + rbac.ResourceRoleAssignment, rbac.ActionDelete, + ).Returns(o) + })) +} + +func (s *MethodTestSuite) TestWorkspace() { + s.Run("GetWorkspaceByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(ws.ID).Asserts(ws, rbac.ActionRead) + })) + s.Run("GetWorkspaces", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.Workspace(s.T(), db, database.Workspace{}) + _ = dbgen.Workspace(s.T(), db, database.Workspace{}) + // No asserts here because SQLFilter. + check.Args(database.GetWorkspacesParams{}).Asserts() + })) + s.Run("GetAuthorizedWorkspaces", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.Workspace(s.T(), db, database.Workspace{}) + _ = dbgen.Workspace(s.T(), db, database.Workspace{}) + // No asserts here because SQLFilter. + check.Args(database.GetWorkspacesParams{}, emptyPreparedAuthorized{}).Asserts() + })) + s.Run("GetLatestWorkspaceBuildByWorkspaceID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID}) + check.Args(ws.ID).Asserts(ws, rbac.ActionRead).Returns(b) + })) + s.Run("GetLatestWorkspaceBuildsByWorkspaceIDs", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID}) + check.Args([]uuid.UUID{ws.ID}).Asserts(ws, rbac.ActionRead).Returns(slice.New(b)) + })) + s.Run("GetWorkspaceAgentByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(agt.ID).Asserts(ws, rbac.ActionRead).Returns(agt) + })) + s.Run("GetWorkspaceAgentByInstanceID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(agt.AuthInstanceID.String).Asserts(ws, rbac.ActionRead).Returns(agt) + })) + s.Run("GetWorkspaceAgentsByResourceIDs", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args([]uuid.UUID{res.ID}).Asserts(ws, rbac.ActionRead). + Returns([]database.WorkspaceAgent{agt}) + })) + s.Run("UpdateWorkspaceAgentLifecycleStateByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agt.ID, + LifecycleState: database.WorkspaceAgentLifecycleStateCreated, + }).Asserts(ws, rbac.ActionUpdate).Returns() + })) + s.Run("UpdateWorkspaceAgentStartupByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(database.UpdateWorkspaceAgentStartupByIDParams{ + ID: agt.ID, + }).Asserts(ws, rbac.ActionUpdate).Returns() + })) + s.Run("GetWorkspaceAppByAgentIDAndSlug", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + app := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID}) + + check.Args(database.GetWorkspaceAppByAgentIDAndSlugParams{ + AgentID: agt.ID, + Slug: app.Slug, + }).Asserts(ws, rbac.ActionRead).Returns(app) + })) + s.Run("GetWorkspaceAppsByAgentID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + a := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID}) + b := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID}) + + check.Args(agt.ID).Asserts(ws, rbac.ActionRead).Returns(slice.New(a, b)) + })) + s.Run("GetWorkspaceAppsByAgentIDs", s.Subtest(func(db database.Store, check *expects) { + aWs := dbgen.Workspace(s.T(), db, database.Workspace{}) + aBuild := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: aWs.ID, JobID: uuid.New()}) + aRes := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: aBuild.JobID}) + aAgt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: aRes.ID}) + a := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: aAgt.ID}) + + bWs := dbgen.Workspace(s.T(), db, database.Workspace{}) + bBuild := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: bWs.ID, JobID: uuid.New()}) + bRes := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: bBuild.JobID}) + bAgt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: bRes.ID}) + b := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: bAgt.ID}) + + check.Args([]uuid.UUID{a.AgentID, b.AgentID}). + Asserts(aWs, rbac.ActionRead, bWs, rbac.ActionRead). + Returns([]database.WorkspaceApp{a, b}) + })) + s.Run("GetWorkspaceBuildByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID}) + check.Args(build.ID).Asserts(ws, rbac.ActionRead).Returns(build) + })) + s.Run("GetWorkspaceBuildByJobID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID}) + check.Args(build.JobID).Asserts(ws, rbac.ActionRead).Returns(build) + })) + s.Run("GetWorkspaceBuildByWorkspaceIDAndBuildNumber", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, BuildNumber: 10}) + check.Args(database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{ + WorkspaceID: ws.ID, + BuildNumber: build.BuildNumber, + }).Asserts(ws, rbac.ActionRead).Returns(build) + })) + s.Run("GetWorkspaceBuildParameters", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID}) + check.Args(build.ID).Asserts(ws, rbac.ActionRead). + Returns([]database.WorkspaceBuildParameter{}) + })) + s.Run("GetWorkspaceBuildsByWorkspaceID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, BuildNumber: 1}) + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, BuildNumber: 2}) + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, BuildNumber: 3}) + check.Args(database.GetWorkspaceBuildsByWorkspaceIDParams{WorkspaceID: ws.ID}).Asserts(ws, rbac.ActionRead) // ordering + })) + s.Run("GetWorkspaceByAgentID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(agt.ID).Asserts(ws, rbac.ActionRead).Returns(ws) + })) + s.Run("GetWorkspaceByOwnerIDAndName", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.GetWorkspaceByOwnerIDAndNameParams{ + OwnerID: ws.OwnerID, + Deleted: ws.Deleted, + Name: ws.Name, + }).Asserts(ws, rbac.ActionRead).Returns(ws) + })) + s.Run("GetWorkspaceResourceByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + _ = dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + check.Args(res.ID).Asserts(ws, rbac.ActionRead).Returns(res) + })) + s.Run("GetWorkspaceResourceMetadataByResourceIDs", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + _ = dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild}) + a := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + b := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + check.Args([]uuid.UUID{a.ID, b.ID}). + Asserts(ws, []rbac.Action{rbac.ActionRead, rbac.ActionRead}) + })) + s.Run("Build/GetWorkspaceResourcesByJobID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + job := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild}) + check.Args(job.ID).Asserts(ws, rbac.ActionRead).Returns([]database.WorkspaceResource{}) + })) + s.Run("Template/GetWorkspaceResourcesByJobID", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{}) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, JobID: uuid.New()}) + job := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: v.JobID, Type: database.ProvisionerJobTypeTemplateVersionImport}) + check.Args(job.ID).Asserts(v.RBACObject(tpl), []rbac.Action{rbac.ActionRead, rbac.ActionRead}).Returns([]database.WorkspaceResource{}) + })) + s.Run("GetWorkspaceResourcesByJobIDs", s.Subtest(func(db database.Store, check *expects) { + tpl := dbgen.Template(s.T(), db, database.Template{}) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, JobID: uuid.New()}) + tJob := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: v.JobID, Type: database.ProvisionerJobTypeTemplateVersionImport}) + + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + wJob := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild}) + check.Args([]uuid.UUID{tJob.ID, wJob.ID}).Asserts(v.RBACObject(tpl), rbac.ActionRead, ws, rbac.ActionRead).Returns([]database.WorkspaceResource{}) + })) + s.Run("InsertWorkspace", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + o := dbgen.Organization(s.T(), db, database.Organization{}) + check.Args(database.InsertWorkspaceParams{ + ID: uuid.New(), + OwnerID: u.ID, + OrganizationID: o.ID, + }).Asserts(rbac.ResourceWorkspace.WithOwner(u.ID.String()).InOrg(o.ID), rbac.ActionCreate) + })) + s.Run("Start/InsertWorkspaceBuild", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.InsertWorkspaceBuildParams{ + WorkspaceID: w.ID, + Transition: database.WorkspaceTransitionStart, + Reason: database.BuildReasonInitiator, + }).Asserts(w, rbac.ActionUpdate) + })) + s.Run("Delete/InsertWorkspaceBuild", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.InsertWorkspaceBuildParams{ + WorkspaceID: w.ID, + Transition: database.WorkspaceTransitionDelete, + Reason: database.BuildReasonInitiator, + }).Asserts(w, rbac.ActionDelete) + })) + s.Run("InsertWorkspaceBuildParameters", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: w.ID}) + check.Args(database.InsertWorkspaceBuildParametersParams{ + WorkspaceBuildID: b.ID, + Name: []string{"foo", "bar"}, + Value: []string{"baz", "qux"}, + }).Asserts(w, rbac.ActionUpdate) + })) + s.Run("UpdateWorkspace", s.Subtest(func(db database.Store, check *expects) { + w := dbgen.Workspace(s.T(), db, database.Workspace{}) + expected := w + expected.Name = "" + check.Args(database.UpdateWorkspaceParams{ + ID: w.ID, + }).Asserts(w, rbac.ActionUpdate).Returns(expected) + })) + s.Run("UpdateWorkspaceAgentConnectionByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + check.Args(database.UpdateWorkspaceAgentConnectionByIDParams{ + ID: agt.ID, + }).Asserts(ws, rbac.ActionUpdate).Returns() + })) + s.Run("InsertAgentStat", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.InsertAgentStatParams{ + WorkspaceID: ws.ID, + }).Asserts(ws, rbac.ActionUpdate) + })) + s.Run("UpdateWorkspaceAppHealthByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + app := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID}) + check.Args(database.UpdateWorkspaceAppHealthByIDParams{ + ID: app.ID, + Health: database.WorkspaceAppHealthDisabled, + }).Asserts(ws, rbac.ActionUpdate).Returns() + })) + s.Run("UpdateWorkspaceAutostart", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.UpdateWorkspaceAutostartParams{ + ID: ws.ID, + }).Asserts(ws, rbac.ActionUpdate).Returns() + })) + s.Run("UpdateWorkspaceBuildByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + check.Args(database.UpdateWorkspaceBuildByIDParams{ + ID: build.ID, + UpdatedAt: build.UpdatedAt, + Deadline: build.Deadline, + }).Asserts(ws, rbac.ActionUpdate).Returns(build) + })) + s.Run("SoftDeleteWorkspaceByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + ws.Deleted = true + check.Args(ws.ID).Asserts(ws, rbac.ActionDelete).Returns() + })) + s.Run("UpdateWorkspaceDeletedByID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{Deleted: true}) + check.Args(database.UpdateWorkspaceDeletedByIDParams{ + ID: ws.ID, + Deleted: true, + }).Asserts(ws, rbac.ActionDelete).Returns() + })) + s.Run("UpdateWorkspaceLastUsedAt", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.UpdateWorkspaceLastUsedAtParams{ + ID: ws.ID, + }).Asserts(ws, rbac.ActionUpdate).Returns() + })) + s.Run("UpdateWorkspaceTTL", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + check.Args(database.UpdateWorkspaceTTLParams{ + ID: ws.ID, + }).Asserts(ws, rbac.ActionUpdate).Returns() + })) + s.Run("GetWorkspaceByWorkspaceAppID", s.Subtest(func(db database.Store, check *expects) { + ws := dbgen.Workspace(s.T(), db, database.Workspace{}) + build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()}) + res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID}) + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID}) + app := dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{AgentID: agt.ID}) + check.Args(app.ID).Asserts(ws, rbac.ActionRead).Returns(ws) + })) +} + +func (s *MethodTestSuite) TestExtraMethods() { + s.Run("GetProvisionerDaemons", s.Subtest(func(db database.Store, check *expects) { + d, err := db.InsertProvisionerDaemon(context.Background(), database.InsertProvisionerDaemonParams{ + ID: uuid.New(), + }) + s.NoError(err, "insert provisioner daemon") + check.Args().Asserts(d, rbac.ActionRead) + })) + s.Run("GetDeploymentDAUs", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts(rbac.ResourceUser.All(), rbac.ActionRead) + })) +} diff --git a/coderd/database/dbauthz/setup_test.go b/coderd/database/dbauthz/setup_test.go new file mode 100644 index 0000000000000..6fe03e52d0ebe --- /dev/null +++ b/coderd/database/dbauthz/setup_test.go @@ -0,0 +1,377 @@ +package dbauthz_test + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "sort" + "strings" + "testing" + + "golang.org/x/xerrors" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "cdr.dev/slog" + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" + "github.com/coder/coder/coderd/database/dbfake" + "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/coderd/rbac/regosql" + "github.com/coder/coder/coderd/util/slice" +) + +var ( + skipMethods = map[string]string{ + "InTx": "Not relevant", + "Ping": "Not relevant", + } +) + +// TestMethodTestSuite runs MethodTestSuite. +// In order for 'go test' to run this suite, we need to create +// a normal test function and pass our suite to suite.Run +// nolint: paralleltest +func TestMethodTestSuite(t *testing.T) { + suite.Run(t, new(MethodTestSuite)) +} + +// MethodTestSuite runs all methods tests for querier. We use +// a test suite so we can account for all functions tested on the querier. +// We can then assert all methods were tested and asserted for proper RBAC +// checks. This forces RBAC checks to be written for all methods. +// Additionally, the way unit tests are written allows for easily executing +// a single test for debugging. +type MethodTestSuite struct { + suite.Suite + // methodAccounting counts all methods called by a 'RunMethodTest' + methodAccounting map[string]int +} + +// SetupSuite sets up the suite by creating a map of all methods on querier +// and setting their count to 0. +func (s *MethodTestSuite) SetupSuite() { + az := dbauthz.New(nil, nil, slog.Make()) + // Take the underlying type of the interface. + azt := reflect.TypeOf(az).Elem() + s.methodAccounting = make(map[string]int) + for i := 0; i < azt.NumMethod(); i++ { + method := azt.Method(i) + if _, ok := skipMethods[method.Name]; ok { + // We can't use s.T().Skip as this will skip the entire suite. + s.T().Logf("Skipping method %q: %s", method.Name, skipMethods[method.Name]) + continue + } + s.methodAccounting[method.Name] = 0 + } +} + +// TearDownSuite asserts that all methods were called at least once. +func (s *MethodTestSuite) TearDownSuite() { + s.Run("Accounting", func() { + t := s.T() + notCalled := []string{} + for m, c := range s.methodAccounting { + if c <= 0 { + notCalled = append(notCalled, m) + } + } + sort.Strings(notCalled) + for _, m := range notCalled { + t.Errorf("Method never called: %q", m) + } + }) +} + +// Subtest is a helper function that returns a function that can be passed to +// s.Run(). This function will run the test case for the method that is being +// tested. The check parameter is used to assert the results of the method. +// If the caller does not use the `check` parameter, the test will fail. +func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expects)) func() { + return func() { + t := s.T() + testName := s.T().Name() + names := strings.Split(testName, "/") + methodName := names[len(names)-1] + s.methodAccounting[methodName]++ + + db := dbfake.New() + fakeAuthorizer := &coderdtest.FakeAuthorizer{ + AlwaysReturn: nil, + } + rec := &coderdtest.RecordingAuthorizer{ + Wrapped: fakeAuthorizer, + } + az := dbauthz.New(db, rec, slog.Make()) + actor := rbac.Subject{ + ID: uuid.NewString(), + Roles: rbac.RoleNames{rbac.RoleOwner()}, + Groups: []string{}, + Scope: rbac.ScopeAll, + } + ctx := dbauthz.As(context.Background(), actor) + + var testCase expects + testCaseF(db, &testCase) + // Check the developer added assertions. If there are no assertions, + // an empty list should be passed. + s.Require().False(testCase.assertions == nil, "rbac assertions not set, use the 'check' parameter") + + // Find the method with the name of the test. + var callMethod func(ctx context.Context) ([]reflect.Value, error) + azt := reflect.TypeOf(az) + MethodLoop: + for i := 0; i < azt.NumMethod(); i++ { + method := azt.Method(i) + if method.Name == methodName { + methodF := reflect.ValueOf(az).Method(i) + + callMethod = func(ctx context.Context) ([]reflect.Value, error) { + resp := methodF.Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.inputs...)) + return splitResp(t, resp) + } + break MethodLoop + } + } + + require.NotNil(t, callMethod, "method %q does not exist", methodName) + + if len(testCase.assertions) > 0 { + // Only run these tests if we know the underlying call makes + // rbac assertions. + s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, callMethod) + } + + if len(testCase.assertions) > 0 || + slice.Contains([]string{ + "GetAuthorizedWorkspaces", + "GetAuthorizedTemplates", + }, methodName) { + // Some methods do not make RBAC assertions because they use + // SQL. We still want to test that they return an error if the + // actor is not set. + s.NoActorErrorTest(callMethod) + } + + // Always run + s.Run("Success", func() { + rec.Reset() + fakeAuthorizer.AlwaysReturn = nil + + outputs, err := callMethod(ctx) + s.NoError(err, "method %q returned an error", methodName) + + // Some tests may not care about the outputs, so we only assert if + // they are provided. + if testCase.outputs != nil { + // Assert the required outputs + s.Equal(len(testCase.outputs), len(outputs), "method %q returned unexpected number of outputs", methodName) + for i := range outputs { + a, b := testCase.outputs[i].Interface(), outputs[i].Interface() + if reflect.TypeOf(a).Kind() == reflect.Slice || reflect.TypeOf(a).Kind() == reflect.Array { + // Order does not matter + s.ElementsMatch(a, b, "method %q returned unexpected output %d", methodName, i) + } else { + s.Equal(a, b, "method %q returned unexpected output %d", methodName, i) + } + } + } + + var pairs []coderdtest.ActionObjectPair + for _, assrt := range testCase.assertions { + for _, action := range assrt.Actions { + pairs = append(pairs, coderdtest.ActionObjectPair{ + Action: action, + Object: assrt.Object, + }) + } + } + + rec.AssertActor(s.T(), actor, pairs...) + s.NoError(rec.AllAsserted(), "all rbac calls must be asserted") + }) + } +} + +func (s *MethodTestSuite) NoActorErrorTest(callMethod func(ctx context.Context) ([]reflect.Value, error)) { + s.Run("AsRemoveActor", func() { + // Call without any actor + _, err := callMethod(context.Background()) + s.ErrorIs(err, dbauthz.NoActorError, "method should return NoActorError error when no actor is provided") + }) +} + +// NotAuthorizedErrorTest runs the given method with an authorizer that will fail authz. +// Asserts that the error returned is a NotAuthorizedError. +func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, callMethod func(ctx context.Context) ([]reflect.Value, error)) { + s.Run("NotAuthorized", func() { + az.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("Always fail authz"), rbac.Subject{}, "", rbac.Object{}, nil) + + // If we have assertions, that means the method should FAIL + // if RBAC will disallow the request. The returned error should + // be expected to be a NotAuthorizedError. + resp, err := callMethod(ctx) + + // This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out + // any case where the error is nil and the response is an empty slice. + if err != nil || !hasEmptySliceResponse(resp) { + s.ErrorContainsf(err, "unauthorized", "error string should have a good message") + s.Errorf(err, "method should an error with disallow authz") + s.ErrorIsf(err, sql.ErrNoRows, "error should match sql.ErrNoRows") + s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError") + } + }) +} + +func hasEmptySliceResponse(values []reflect.Value) bool { + for _, r := range values { + if r.Kind() == reflect.Slice || r.Kind() == reflect.Array { + if r.Len() == 0 { + return true + } + } + } + return false +} + +func splitResp(t *testing.T, values []reflect.Value) ([]reflect.Value, error) { + outputs := []reflect.Value{} + for _, r := range values { + if r.Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) { + if r.IsNil() { + // Error is found, but it's nil! + return outputs, nil + } + err, ok := r.Interface().(error) + if !ok { + t.Fatal("error is not an error?!") + } + return outputs, err + } + outputs = append(outputs, r) + } //nolint: unreachable + t.Fatal("no expected error value found in responses (error can be nil)") + return nil, nil // unreachable, required to compile +} + +// expects is used to build a test case for a method. +// It includes the expected inputs, rbac assertions, and expected outputs. +type expects struct { + inputs []reflect.Value + assertions []AssertRBAC + // outputs is optional. Can assert non-error return values. + outputs []reflect.Value +} + +// Asserts is required. Asserts the RBAC authorize calls that should be made. +// If no RBAC calls are expected, pass an empty list: 'm.Asserts()' +func (m *expects) Asserts(pairs ...any) *expects { + m.assertions = asserts(pairs...) + return m +} + +// Args is required. The arguments to be provided to the method. +// If there are no arguments, pass an empty list: 'm.Args()' +// The first context argument should not be included, as the test suite +// will provide it. +func (m *expects) Args(args ...any) *expects { + m.inputs = values(args...) + return m +} + +// Returns is optional. If it is never called, it will not be asserted. +func (m *expects) Returns(rets ...any) *expects { + m.outputs = values(rets...) + return m +} + +// AssertRBAC contains the object and actions to be asserted. +type AssertRBAC struct { + Object rbac.Object + Actions []rbac.Action +} + +// values is a convenience method for creating []reflect.Value. +// +// values(workspace, template, ...) +// +// is equivalent to +// +// []reflect.Value{ +// reflect.ValueOf(workspace), +// reflect.ValueOf(template), +// ... +// } +func values(ins ...any) []reflect.Value { + out := make([]reflect.Value, 0) + for _, input := range ins { + input := input + out = append(out, reflect.ValueOf(input)) + } + return out +} + +// asserts is a convenience method for creating AssertRBACs. +// +// The number of inputs must be an even number. +// asserts() will panic if this is not the case. +// +// Even-numbered inputs are the objects, and odd-numbered inputs are the actions. +// Objects must implement rbac.Objecter. +// Inputs can be a single rbac.Action, or a slice of rbac.Action. +// +// asserts(workspace, rbac.ActionRead, template, slice(rbac.ActionRead, rbac.ActionWrite), ...) +// +// is equivalent to +// +// []AssertRBAC{ +// {Object: workspace, Actions: []rbac.Action{rbac.ActionRead}}, +// {Object: template, Actions: []rbac.Action{rbac.ActionRead, rbac.ActionWrite)}}, +// ... +// } +func asserts(inputs ...any) []AssertRBAC { + if len(inputs)%2 != 0 { + panic(fmt.Sprintf("Must be an even length number of args, found %d", len(inputs))) + } + + out := make([]AssertRBAC, 0) + for i := 0; i < len(inputs); i += 2 { + obj, ok := inputs[i].(rbac.Objecter) + if !ok { + panic(fmt.Sprintf("object type '%T' does not implement rbac.Objecter", inputs[i])) + } + rbacObj := obj.RBACObject() + + var actions []rbac.Action + actions, ok = inputs[i+1].([]rbac.Action) + if !ok { + action, ok := inputs[i+1].(rbac.Action) + if !ok { + // Could be the string type. + actionAsString, ok := inputs[i+1].(string) + if !ok { + panic(fmt.Sprintf("action '%q' not a supported action", actionAsString)) + } + action = rbac.Action(actionAsString) + } + actions = []rbac.Action{action} + } + + out = append(out, AssertRBAC{ + Object: rbacObj, + Actions: actions, + }) + } + return out +} + +type emptyPreparedAuthorized struct{} + +func (emptyPreparedAuthorized) Authorize(_ context.Context, _ rbac.Object) error { return nil } +func (emptyPreparedAuthorized) CompileToSQL(_ context.Context, _ regosql.ConvertConfig) (string, error) { + return "", nil +} diff --git a/coderd/database/dbauthz/system.go b/coderd/database/dbauthz/system.go new file mode 100644 index 0000000000000..bec4a6ae052e0 --- /dev/null +++ b/coderd/database/dbauthz/system.go @@ -0,0 +1,194 @@ +package dbauthz + +import ( + "context" + "time" + + "github.com/google/uuid" + + "github.com/coder/coder/coderd/database" +) + +// TODO: All these system functions should have rbac objects created to allow +// only system roles to call them. No user roles should ever have the permission +// to these objects. Might need a negative permission on the `Owner` role to +// prevent owners. + +func (q *querier) UpdateUserLinkedID(ctx context.Context, arg database.UpdateUserLinkedIDParams) (database.UserLink, error) { + return q.db.UpdateUserLinkedID(ctx, arg) +} + +func (q *querier) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (database.UserLink, error) { + return q.db.GetUserLinkByLinkedID(ctx, linkedID) +} + +func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) { + return q.db.GetUserLinkByUserIDLoginType(ctx, arg) +} + +func (q *querier) GetLatestWorkspaceBuilds(ctx context.Context) ([]database.WorkspaceBuild, error) { + // This function is a system function until we implement a join for workspace builds. + // This is because we need to query for all related workspaces to the returned builds. + // This is a very inefficient method of fetching the latest workspace builds. + // We should just join the rbac properties. + return q.db.GetLatestWorkspaceBuilds(ctx) +} + +// GetWorkspaceAgentByAuthToken is used in http middleware to get the workspace agent. +// This should only be used by a system user in that middleware. +func (q *querier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (database.WorkspaceAgent, error) { + return q.db.GetWorkspaceAgentByAuthToken(ctx, authToken) +} + +func (q *querier) GetActiveUserCount(ctx context.Context) (int64, error) { + return q.db.GetActiveUserCount(ctx) +} + +func (q *querier) GetUnexpiredLicenses(ctx context.Context) ([]database.License, error) { + return q.db.GetUnexpiredLicenses(ctx) +} + +func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (database.GetAuthorizationUserRolesRow, error) { + return q.db.GetAuthorizationUserRoles(ctx, userID) +} + +func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) { + // TODO Implement authz check for system user. + return q.db.GetDERPMeshKey(ctx) +} + +func (q *querier) InsertDERPMeshKey(ctx context.Context, value string) error { + // TODO Implement authz check for system user. + return q.db.InsertDERPMeshKey(ctx, value) +} + +func (q *querier) InsertDeploymentID(ctx context.Context, value string) error { + // TODO Implement authz check for system user. + return q.db.InsertDeploymentID(ctx, value) +} + +func (q *querier) InsertReplica(ctx context.Context, arg database.InsertReplicaParams) (database.Replica, error) { + // TODO Implement authz check for system user. + return q.db.InsertReplica(ctx, arg) +} + +func (q *querier) UpdateReplica(ctx context.Context, arg database.UpdateReplicaParams) (database.Replica, error) { + // TODO Implement authz check for system user. + return q.db.UpdateReplica(ctx, arg) +} + +func (q *querier) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error { + // TODO Implement authz check for system user. + return q.db.DeleteReplicasUpdatedBefore(ctx, updatedAt) +} + +func (q *querier) GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]database.Replica, error) { + // TODO Implement authz check for system user. + return q.db.GetReplicasUpdatedAfter(ctx, updatedAt) +} + +func (q *querier) GetUserCount(ctx context.Context) (int64, error) { + return q.db.GetUserCount(ctx) +} + +func (q *querier) GetTemplates(ctx context.Context) ([]database.Template, error) { + // TODO Implement authz check for system user. + return q.db.GetTemplates(ctx) +} + +// UpdateWorkspaceBuildCostByID is used by the provisioning system to update the cost of a workspace build. +func (q *querier) UpdateWorkspaceBuildCostByID(ctx context.Context, arg database.UpdateWorkspaceBuildCostByIDParams) (database.WorkspaceBuild, error) { + return q.db.UpdateWorkspaceBuildCostByID(ctx, arg) +} + +func (q *querier) InsertOrUpdateLastUpdateCheck(ctx context.Context, value string) error { + return q.db.InsertOrUpdateLastUpdateCheck(ctx, value) +} + +func (q *querier) GetLastUpdateCheck(ctx context.Context) (string, error) { + return q.db.GetLastUpdateCheck(ctx) +} + +// Telemetry related functions. These functions are system functions for returning +// telemetry data. Never called by a user. + +func (q *querier) GetWorkspaceBuildsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceBuild, error) { + return q.db.GetWorkspaceBuildsCreatedAfter(ctx, createdAt) +} + +func (q *querier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceAgent, error) { + return q.db.GetWorkspaceAgentsCreatedAfter(ctx, createdAt) +} + +func (q *querier) GetWorkspaceAppsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceApp, error) { + return q.db.GetWorkspaceAppsCreatedAfter(ctx, createdAt) +} + +func (q *querier) GetWorkspaceResourcesCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceResource, error) { + return q.db.GetWorkspaceResourcesCreatedAfter(ctx, createdAt) +} + +func (q *querier) GetWorkspaceResourceMetadataCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceResourceMetadatum, error) { + return q.db.GetWorkspaceResourceMetadataCreatedAfter(ctx, createdAt) +} + +func (q *querier) DeleteOldAgentStats(ctx context.Context) error { + return q.db.DeleteOldAgentStats(ctx) +} + +func (q *querier) GetParameterSchemasCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.ParameterSchema, error) { + return q.db.GetParameterSchemasCreatedAfter(ctx, createdAt) +} +func (q *querier) GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.ProvisionerJob, error) { + return q.db.GetProvisionerJobsCreatedAfter(ctx, createdAt) +} + +// Provisionerd server functions + +func (q *querier) InsertWorkspaceAgent(ctx context.Context, arg database.InsertWorkspaceAgentParams) (database.WorkspaceAgent, error) { + return q.db.InsertWorkspaceAgent(ctx, arg) +} + +func (q *querier) InsertWorkspaceApp(ctx context.Context, arg database.InsertWorkspaceAppParams) (database.WorkspaceApp, error) { + return q.db.InsertWorkspaceApp(ctx, arg) +} + +func (q *querier) InsertWorkspaceResourceMetadata(ctx context.Context, arg database.InsertWorkspaceResourceMetadataParams) ([]database.WorkspaceResourceMetadatum, error) { + return q.db.InsertWorkspaceResourceMetadata(ctx, arg) +} + +func (q *querier) AcquireProvisionerJob(ctx context.Context, arg database.AcquireProvisionerJobParams) (database.ProvisionerJob, error) { + return q.db.AcquireProvisionerJob(ctx, arg) +} + +func (q *querier) UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg database.UpdateProvisionerJobWithCompleteByIDParams) error { + return q.db.UpdateProvisionerJobWithCompleteByID(ctx, arg) +} + +func (q *querier) UpdateProvisionerJobByID(ctx context.Context, arg database.UpdateProvisionerJobByIDParams) error { + return q.db.UpdateProvisionerJobByID(ctx, arg) +} + +func (q *querier) InsertProvisionerJob(ctx context.Context, arg database.InsertProvisionerJobParams) (database.ProvisionerJob, error) { + return q.db.InsertProvisionerJob(ctx, arg) +} + +func (q *querier) InsertProvisionerJobLogs(ctx context.Context, arg database.InsertProvisionerJobLogsParams) ([]database.ProvisionerJobLog, error) { + return q.db.InsertProvisionerJobLogs(ctx, arg) +} + +func (q *querier) InsertProvisionerDaemon(ctx context.Context, arg database.InsertProvisionerDaemonParams) (database.ProvisionerDaemon, error) { + return q.db.InsertProvisionerDaemon(ctx, arg) +} + +func (q *querier) InsertTemplateVersionParameter(ctx context.Context, arg database.InsertTemplateVersionParameterParams) (database.TemplateVersionParameter, error) { + return q.db.InsertTemplateVersionParameter(ctx, arg) +} + +func (q *querier) InsertWorkspaceResource(ctx context.Context, arg database.InsertWorkspaceResourceParams) (database.WorkspaceResource, error) { + return q.db.InsertWorkspaceResource(ctx, arg) +} + +func (q *querier) InsertParameterSchema(ctx context.Context, arg database.InsertParameterSchemaParams) (database.ParameterSchema, error) { + return q.db.InsertParameterSchema(ctx, arg) +} diff --git a/coderd/database/dbauthz/system_test.go b/coderd/database/dbauthz/system_test.go new file mode 100644 index 0000000000000..aa3baa179c82d --- /dev/null +++ b/coderd/database/dbauthz/system_test.go @@ -0,0 +1,219 @@ +package dbauthz_test + +import ( + "context" + "database/sql" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbgen" +) + +func (s *MethodTestSuite) TestSystemFunctions() { + s.Run("UpdateUserLinkedID", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + l := dbgen.UserLink(s.T(), db, database.UserLink{UserID: u.ID}) + check.Args(database.UpdateUserLinkedIDParams{ + UserID: u.ID, + LinkedID: l.LinkedID, + LoginType: database.LoginTypeGithub, + }).Asserts().Returns(l) + })) + s.Run("GetUserLinkByLinkedID", s.Subtest(func(db database.Store, check *expects) { + l := dbgen.UserLink(s.T(), db, database.UserLink{}) + check.Args(l.LinkedID).Asserts().Returns(l) + })) + s.Run("GetUserLinkByUserIDLoginType", s.Subtest(func(db database.Store, check *expects) { + l := dbgen.UserLink(s.T(), db, database.UserLink{}) + check.Args(database.GetUserLinkByUserIDLoginTypeParams{ + UserID: l.UserID, + LoginType: l.LoginType, + }).Asserts().Returns(l) + })) + s.Run("GetLatestWorkspaceBuilds", s.Subtest(func(db database.Store, check *expects) { + dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{}) + dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{}) + check.Args().Asserts() + })) + s.Run("GetWorkspaceAgentByAuthToken", s.Subtest(func(db database.Store, check *expects) { + agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{}) + check.Args(agt.AuthToken).Asserts().Returns(agt) + })) + s.Run("GetActiveUserCount", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts().Returns(int64(0)) + })) + s.Run("GetUnexpiredLicenses", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts() + })) + s.Run("GetAuthorizationUserRoles", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + check.Args(u.ID).Asserts() + })) + s.Run("GetDERPMeshKey", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts() + })) + s.Run("InsertDERPMeshKey", s.Subtest(func(db database.Store, check *expects) { + check.Args("value").Asserts().Returns() + })) + s.Run("InsertDeploymentID", s.Subtest(func(db database.Store, check *expects) { + check.Args("value").Asserts().Returns() + })) + s.Run("InsertReplica", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertReplicaParams{ + ID: uuid.New(), + }).Asserts() + })) + s.Run("UpdateReplica", s.Subtest(func(db database.Store, check *expects) { + replica, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New()}) + require.NoError(s.T(), err) + check.Args(database.UpdateReplicaParams{ + ID: replica.ID, + DatabaseLatency: 100, + }).Asserts() + })) + s.Run("DeleteReplicasUpdatedBefore", s.Subtest(func(db database.Store, check *expects) { + _, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New(), UpdatedAt: time.Now()}) + require.NoError(s.T(), err) + check.Args(time.Now().Add(time.Hour)).Asserts() + })) + s.Run("GetReplicasUpdatedAfter", s.Subtest(func(db database.Store, check *expects) { + _, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New(), UpdatedAt: time.Now()}) + require.NoError(s.T(), err) + check.Args(time.Now().Add(time.Hour * -1)).Asserts() + })) + s.Run("GetUserCount", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts().Returns(int64(0)) + })) + s.Run("GetTemplates", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.Template(s.T(), db, database.Template{}) + check.Args().Asserts() + })) + s.Run("UpdateWorkspaceBuildCostByID", s.Subtest(func(db database.Store, check *expects) { + b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{}) + o := b + o.DailyCost = 10 + check.Args(database.UpdateWorkspaceBuildCostByIDParams{ + ID: b.ID, + DailyCost: 10, + }).Asserts().Returns(o) + })) + s.Run("InsertOrUpdateLastUpdateCheck", s.Subtest(func(db database.Store, check *expects) { + check.Args("value").Asserts() + })) + s.Run("GetLastUpdateCheck", s.Subtest(func(db database.Store, check *expects) { + err := db.InsertOrUpdateLastUpdateCheck(context.Background(), "value") + require.NoError(s.T(), err) + check.Args().Asserts() + })) + s.Run("GetWorkspaceBuildsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{CreatedAt: time.Now().Add(-time.Hour)}) + check.Args(time.Now()).Asserts() + })) + s.Run("GetWorkspaceAgentsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{CreatedAt: time.Now().Add(-time.Hour)}) + check.Args(time.Now()).Asserts() + })) + s.Run("GetWorkspaceAppsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{CreatedAt: time.Now().Add(-time.Hour)}) + check.Args(time.Now()).Asserts() + })) + s.Run("GetWorkspaceResourcesCreatedAfter", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{CreatedAt: time.Now().Add(-time.Hour)}) + check.Args(time.Now()).Asserts() + })) + s.Run("GetWorkspaceResourceMetadataCreatedAfter", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.WorkspaceResourceMetadatums(s.T(), db, database.WorkspaceResourceMetadatum{}) + check.Args(time.Now()).Asserts() + })) + s.Run("DeleteOldAgentStats", s.Subtest(func(db database.Store, check *expects) { + check.Args().Asserts() + })) + s.Run("GetParameterSchemasCreatedAfter", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.ParameterSchema(s.T(), db, database.ParameterSchema{CreatedAt: time.Now().Add(-time.Hour)}) + check.Args(time.Now()).Asserts() + })) + s.Run("GetProvisionerJobsCreatedAfter", s.Subtest(func(db database.Store, check *expects) { + _ = dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{CreatedAt: time.Now().Add(-time.Hour)}) + check.Args(time.Now()).Asserts() + })) + s.Run("InsertWorkspaceAgent", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertWorkspaceAgentParams{ + ID: uuid.New(), + }).Asserts() + })) + s.Run("InsertWorkspaceApp", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertWorkspaceAppParams{ + ID: uuid.New(), + Health: database.WorkspaceAppHealthDisabled, + SharingLevel: database.AppSharingLevelOwner, + }).Asserts() + })) + s.Run("InsertWorkspaceResourceMetadata", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertWorkspaceResourceMetadataParams{ + WorkspaceResourceID: uuid.New(), + }).Asserts() + })) + s.Run("AcquireProvisionerJob", s.Subtest(func(db database.Store, check *expects) { + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ + StartedAt: sql.NullTime{Valid: false}, + }) + check.Args(database.AcquireProvisionerJobParams{Types: []database.ProvisionerType{j.Provisioner}}). + Asserts() + })) + s.Run("UpdateProvisionerJobWithCompleteByID", s.Subtest(func(db database.Store, check *expects) { + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) + check.Args(database.UpdateProvisionerJobWithCompleteByIDParams{ + ID: j.ID, + }).Asserts() + })) + s.Run("UpdateProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) { + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) + check.Args(database.UpdateProvisionerJobByIDParams{ + ID: j.ID, + UpdatedAt: time.Now(), + }).Asserts() + })) + s.Run("InsertProvisionerJob", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertProvisionerJobParams{ + ID: uuid.New(), + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + Type: database.ProvisionerJobTypeWorkspaceBuild, + }).Asserts() + })) + s.Run("InsertProvisionerJobLogs", s.Subtest(func(db database.Store, check *expects) { + j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{}) + check.Args(database.InsertProvisionerJobLogsParams{ + JobID: j.ID, + }).Asserts() + })) + s.Run("InsertProvisionerDaemon", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertProvisionerDaemonParams{ + ID: uuid.New(), + }).Asserts() + })) + s.Run("InsertTemplateVersionParameter", s.Subtest(func(db database.Store, check *expects) { + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{}) + check.Args(database.InsertTemplateVersionParameterParams{ + TemplateVersionID: v.ID, + }).Asserts() + })) + s.Run("InsertWorkspaceResource", s.Subtest(func(db database.Store, check *expects) { + r := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{}) + check.Args(database.InsertWorkspaceResourceParams{ + ID: r.ID, + Transition: database.WorkspaceTransitionStart, + }).Asserts() + })) + s.Run("InsertParameterSchema", s.Subtest(func(db database.Store, check *expects) { + check.Args(database.InsertParameterSchemaParams{ + ID: uuid.New(), + DefaultSourceScheme: database.ParameterSourceSchemeNone, + DefaultDestinationScheme: database.ParameterDestinationSchemeNone, + ValidationTypeSystem: database.ParameterTypeSystemNone, + }).Asserts() + })) +} diff --git a/coderd/database/dbfake/databasefake.go b/coderd/database/dbfake/databasefake.go index 872ce87e9049a..f38edca9bc5fb 100644 --- a/coderd/database/dbfake/databasefake.go +++ b/coderd/database/dbfake/databasefake.go @@ -614,6 +614,14 @@ func (q *fakeQuerier) GetAuthorizedUserCount(ctx context.Context, params databas q.mutex.RLock() defer q.mutex.RUnlock() + // Call this to match the same function calls as the SQL implementation. + if prepared != nil { + _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) + if err != nil { + return -1, err + } + } + users := make([]database.User, 0, len(q.users)) for _, user := range q.users { @@ -892,6 +900,14 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database. q.mutex.RLock() defer q.mutex.RUnlock() + if prepared != nil { + // Call this to match the same function calls as the SQL implementation. + _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) + if err != nil { + return nil, err + } + } + workspaces := make([]database.Workspace, 0) for _, workspace := range q.workspaces { if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID { @@ -1230,6 +1246,23 @@ func (q *fakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg databa return database.Workspace{}, sql.ErrNoRows } +func (q *fakeQuerier) GetWorkspaceByWorkspaceAppID(_ context.Context, workspaceAppID uuid.UUID) (database.Workspace, error) { + if err := validateDatabaseType(workspaceAppID); err != nil { + return database.Workspace{}, err + } + + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, workspaceApp := range q.workspaceApps { + workspaceApp := workspaceApp + if workspaceApp.ID == workspaceAppID { + return q.GetWorkspaceByAgentID(context.Background(), workspaceApp.AgentID) + } + } + return database.Workspace{}, sql.ErrNoRows +} + func (q *fakeQuerier) GetWorkspaceAppsByAgentID(_ context.Context, id uuid.UUID) ([]database.WorkspaceApp, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -1646,6 +1679,14 @@ func (q *fakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.G q.mutex.RLock() defer q.mutex.RUnlock() + // Call this to match the same function calls as the SQL implementation. + if prepared != nil { + _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithACL()) + if err != nil { + return nil, err + } + } + var templates []database.Template for _, template := range q.templates { if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil { @@ -3819,6 +3860,18 @@ func (q *fakeQuerier) InsertLicense( return l, nil } +func (q *fakeQuerier) GetLicenseByID(_ context.Context, id int32) (database.License, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, license := range q.licenses { + if license.ID == id { + return license, nil + } + } + return database.License{}, sql.ErrNoRows +} + func (q *fakeQuerier) GetLicenses(_ context.Context) ([]database.License, error) { q.mutex.RLock() defer q.mutex.RUnlock() diff --git a/coderd/database/dbgen/generator.go b/coderd/database/dbgen/generator.go index ea39d2e8e34a1..545bc681d0112 100644 --- a/coderd/database/dbgen/generator.go +++ b/coderd/database/dbgen/generator.go @@ -66,7 +66,7 @@ func Template(t testing.TB, db database.Store, seed database.Template) database. UserACL: seed.UserACL, GroupACL: seed.GroupACL, DisplayName: takeFirst(seed.DisplayName, namesgenerator.GetRandomName(1)), - AllowUserCancelWorkspaceJobs: takeFirst(seed.AllowUserCancelWorkspaceJobs, true), + AllowUserCancelWorkspaceJobs: seed.AllowUserCancelWorkspaceJobs, }) require.NoError(t, err, "insert template") return template @@ -369,11 +369,8 @@ func GitAuthLink(t testing.TB, db database.Store, orig database.GitAuthLink) dat func TemplateVersion(t testing.TB, db database.Store, orig database.TemplateVersion) database.TemplateVersion { version, err := db.InsertTemplateVersion(context.Background(), database.InsertTemplateVersionParams{ - ID: takeFirst(orig.ID, uuid.New()), - TemplateID: uuid.NullUUID{ - UUID: takeFirst(orig.TemplateID.UUID, uuid.New()), - Valid: takeFirst(orig.TemplateID.Valid, true), - }, + ID: takeFirst(orig.ID, uuid.New()), + TemplateID: orig.TemplateID, OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), CreatedAt: takeFirst(orig.CreatedAt, database.Now()), UpdatedAt: takeFirst(orig.UpdatedAt, database.Now()), diff --git a/coderd/database/dbgen/generator_test.go b/coderd/database/dbgen/generator_test.go index 6ae00e5672793..c09cc6df8a466 100644 --- a/coderd/database/dbgen/generator_test.go +++ b/coderd/database/dbgen/generator_test.go @@ -68,7 +68,7 @@ func TestGenerator(t *testing.T) { require.Equal(t, exp, must(db.GetWorkspaceAppsByAgentID(context.Background(), exp.AgentID))[0]) }) - t.Run("WorkspaceResourceMetadatum", func(t *testing.T) { + t.Run("WorkspaceResourceMetadata", func(t *testing.T) { t.Parallel() db := dbfake.New() exp := dbgen.WorkspaceResourceMetadatums(t, db, database.WorkspaceResourceMetadatum{}) diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index 487e8a7e6a250..44c598697ef8b 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -2,6 +2,7 @@ package database import ( "sort" + "strconv" "github.com/coder/coder/coderd/rbac" ) @@ -63,6 +64,11 @@ func (TemplateVersion) RBACObject(template Template) rbac.Object { return template.RBACObject() } +// RBACObjectNoTemplate is for orphaned template versions. +func (v TemplateVersion) RBACObjectNoTemplate() rbac.Object { + return rbac.ResourceTemplate.InOrg(v.OrganizationID) +} + func (g Group) RBACObject() rbac.Object { return rbac.ResourceGroup.WithID(g.ID). InOrg(g.OrganizationID) @@ -94,6 +100,13 @@ func (m OrganizationMember) RBACObject() rbac.Object { InOrg(m.OrganizationID) } +func (m GetOrganizationIDsByMemberIDsRow) RBACObject() rbac.Object { + // TODO: This feels incorrect as we are really returning a list of orgmembers. + // This return type should be refactored to return a list of orgmembers, not this + // special type. + return rbac.ResourceUser.WithID(m.UserID) +} + func (o Organization) RBACObject() rbac.Object { return rbac.ResourceOrganization. WithID(o.ID). @@ -118,11 +131,29 @@ func (u User) RBACObject() rbac.Object { } func (u User) UserDataRBACObject() rbac.Object { - return rbac.ResourceUser.WithID(u.ID).WithOwner(u.ID.String()) + return rbac.ResourceUserData.WithID(u.ID).WithOwner(u.ID.String()) +} + +func (u GetUsersRow) RBACObject() rbac.Object { + return rbac.ResourceUser.WithID(u.ID) +} + +func (u GitSSHKey) RBACObject() rbac.Object { + return rbac.ResourceUserData.WithID(u.UserID).WithOwner(u.UserID.String()) +} + +func (u GitAuthLink) RBACObject() rbac.Object { + // I assume UserData is ok? + return rbac.ResourceUserData.WithID(u.UserID).WithOwner(u.UserID.String()) +} + +func (u UserLink) RBACObject() rbac.Object { + // I assume UserData is ok? + return rbac.ResourceUserData.WithOwner(u.UserID.String()).WithID(u.UserID) } -func (License) RBACObject() rbac.Object { - return rbac.ResourceLicense +func (l License) RBACObject() rbac.Object { + return rbac.ResourceLicense.WithIDString(strconv.FormatInt(int64(l.ID), 10)) } func ConvertUserRows(rows []GetUsersRow) []User { diff --git a/coderd/database/querier.go b/coderd/database/querier.go index a588b28076233..d11ad35999c3c 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -56,6 +56,7 @@ type sqlcQuerier interface { GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error) GetLatestWorkspaceBuilds(ctx context.Context) ([]WorkspaceBuild, error) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceBuild, error) + GetLicenseByID(ctx context.Context, id int32) (License, error) GetLicenses(ctx context.Context) ([]License, error) GetLogoURL(ctx context.Context) (string, error) GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error) @@ -121,6 +122,7 @@ type sqlcQuerier interface { GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (Workspace, error) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (Workspace, error) GetWorkspaceByOwnerIDAndName(ctx context.Context, arg GetWorkspaceByOwnerIDAndNameParams) (Workspace, error) + GetWorkspaceByWorkspaceAppID(ctx context.Context, workspaceAppID uuid.UUID) (Workspace, error) GetWorkspaceResourceByID(ctx context.Context, id uuid.UUID) (WorkspaceResource, error) GetWorkspaceResourceMetadataByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceResourceMetadatum, error) GetWorkspaceResourceMetadataCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceResourceMetadatum, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index a1806b128909c..a41ae0b363f28 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1343,6 +1343,30 @@ func (q *sqlQuerier) DeleteLicense(ctx context.Context, id int32) (int32, error) return id, err } +const getLicenseByID = `-- name: GetLicenseByID :one +SELECT + id, uploaded_at, jwt, exp, uuid +FROM + licenses +WHERE + id = $1 +LIMIT + 1 +` + +func (q *sqlQuerier) GetLicenseByID(ctx context.Context, id int32) (License, error) { + row := q.db.QueryRowContext(ctx, getLicenseByID, id) + var i License + err := row.Scan( + &i.ID, + &i.UploadedAt, + &i.JWT, + &i.Exp, + &i.UUID, + ) + return i, err +} + const getLicenses = `-- name: GetLicenses :many SELECT id, uploaded_at, jwt, exp, uuid FROM licenses @@ -6513,6 +6537,62 @@ func (q *sqlQuerier) GetWorkspaceByOwnerIDAndName(ctx context.Context, arg GetWo return i, err } +const getWorkspaceByWorkspaceAppID = `-- name: GetWorkspaceByWorkspaceAppID :one +SELECT + id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at +FROM + workspaces +WHERE + workspaces.id = ( + SELECT + workspace_id + FROM + workspace_builds + WHERE + workspace_builds.job_id = ( + SELECT + job_id + FROM + workspace_resources + WHERE + workspace_resources.id = ( + SELECT + resource_id + FROM + workspace_agents + WHERE + workspace_agents.id = ( + SELECT + agent_id + FROM + workspace_apps + WHERE + workspace_apps.id = $1 + ) + ) + ) + ) +` + +func (q *sqlQuerier) GetWorkspaceByWorkspaceAppID(ctx context.Context, workspaceAppID uuid.UUID) (Workspace, error) { + row := q.db.QueryRowContext(ctx, getWorkspaceByWorkspaceAppID, workspaceAppID) + var i Workspace + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.OwnerID, + &i.OrganizationID, + &i.TemplateID, + &i.Deleted, + &i.Name, + &i.AutostartSchedule, + &i.Ttl, + &i.LastUsedAt, + ) + return i, err +} + const getWorkspaces = `-- name: GetWorkspaces :many SELECT workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, COUNT(*) OVER () as count diff --git a/coderd/database/queries/licenses.sql b/coderd/database/queries/licenses.sql index 1622151a477f1..3512a46514787 100644 --- a/coderd/database/queries/licenses.sql +++ b/coderd/database/queries/licenses.sql @@ -14,6 +14,16 @@ SELECT * FROM licenses ORDER BY (id); +-- name: GetLicenseByID :one +SELECT + * +FROM + licenses +WHERE + id = $1 +LIMIT + 1; + -- name: GetUnexpiredLicenses :many SELECT * FROM licenses diff --git a/coderd/database/queries/workspaces.sql b/coderd/database/queries/workspaces.sql index e32e68a6fc142..def4436bed94c 100644 --- a/coderd/database/queries/workspaces.sql +++ b/coderd/database/queries/workspaces.sql @@ -8,6 +8,42 @@ WHERE LIMIT 1; +-- name: GetWorkspaceByWorkspaceAppID :one +SELECT + * +FROM + workspaces +WHERE + workspaces.id = ( + SELECT + workspace_id + FROM + workspace_builds + WHERE + workspace_builds.job_id = ( + SELECT + job_id + FROM + workspace_resources + WHERE + workspace_resources.id = ( + SELECT + resource_id + FROM + workspace_agents + WHERE + workspace_agents.id = ( + SELECT + agent_id + FROM + workspace_apps + WHERE + workspace_apps.id = @workspace_app_id + ) + ) + ) + ); + -- name: GetWorkspaceByAgentID :one SELECT * diff --git a/coderd/files.go b/coderd/files.go index 57e919d7dab3d..91858eb3ca06e 100644 --- a/coderd/files.go +++ b/coderd/files.go @@ -76,7 +76,14 @@ func (api *API) postFile(rw http.ResponseWriter, r *http.Request) { ID: file.ID, }) return + } else if !errors.Is(err, sql.ErrNoRows) { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error getting file.", + Detail: err.Error(), + }) + return } + id := uuid.New() file, err = api.Database.InsertFile(ctx, database.InsertFileParams{ ID: id, diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index a498399f48967..553fe43d89ef9 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -19,6 +19,7 @@ import ( "golang.org/x/xerrors" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/codersdk" @@ -159,7 +160,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler { return } - key, err := cfg.DB.GetAPIKeyByID(r.Context(), keyID) + //nolint:gocritic // System needs to fetch API key to check if it's valid. + key, err := cfg.DB.GetAPIKeyByID(dbauthz.AsSystem(ctx), keyID) if err != nil { if errors.Is(err, sql.ErrNoRows) { optionalWrite(http.StatusUnauthorized, codersdk.Response{ @@ -192,7 +194,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler { changed = false ) if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC { - link, err = cfg.DB.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{ + //nolint:gocritic // System needs to fetch UserLink to check if it's valid. + link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystem(ctx), database.GetUserLinkByUserIDLoginTypeParams{ UserID: key.UserID, LoginType: key.LoginType, }) @@ -275,7 +278,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler { } } if changed { - err := cfg.DB.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{ + //nolint:gocritic // System needs to update API Key LastUsed + err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystem(ctx), database.UpdateAPIKeyByIDParams{ ID: key.ID, LastUsed: key.LastUsed, ExpiresAt: key.ExpiresAt, @@ -291,7 +295,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler { // If the API Key is associated with a user_link (e.g. Github/OIDC) // then we want to update the relevant oauth fields. if link.UserID != uuid.Nil { - link, err = cfg.DB.UpdateUserLink(r.Context(), database.UpdateUserLinkParams{ + // nolint:gocritic + link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystem(ctx), database.UpdateUserLinkParams{ UserID: link.UserID, LoginType: link.LoginType, OAuthAccessToken: link.OAuthAccessToken, @@ -310,7 +315,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler { // We only want to update this occasionally to reduce DB write // load. We update alongside the UserLink and APIKey since it's // easier on the DB to colocate writes. - _, err = cfg.DB.UpdateUserLastSeenAt(ctx, database.UpdateUserLastSeenAtParams{ + // nolint:gocritic + _, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystem(ctx), database.UpdateUserLastSeenAtParams{ ID: key.UserID, LastSeenAt: database.Now(), UpdatedAt: database.Now(), @@ -327,7 +333,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler { // If the key is valid, we also fetch the user roles and status. // The roles are used for RBAC authorize checks, and the status // is to block 'suspended' users from accessing the platform. - roles, err := cfg.DB.GetAuthorizationUserRoles(r.Context(), key.UserID) + // nolint:gocritic + roles, err := cfg.DB.GetAuthorizationUserRoles(dbauthz.AsSystem(ctx), key.UserID) if err != nil { write(http.StatusUnauthorized, codersdk.Response{ Message: internalErrorMessage, @@ -343,16 +350,20 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler { return } + // Actor is the user's authorization context. + actor := rbac.Subject{ + ID: key.UserID.String(), + Roles: rbac.RoleNames(roles.Roles), + Groups: roles.Groups, + Scope: rbac.ScopeName(key.Scope), + } ctx = context.WithValue(ctx, apiKeyContextKey{}, key) ctx = context.WithValue(ctx, userAuthKey{}, Authorization{ Username: roles.Username, - Actor: rbac.Subject{ - ID: key.UserID.String(), - Roles: rbac.RoleNames(roles.Roles), - Groups: roles.Groups, - Scope: rbac.ScopeName(key.Scope), - }, + Actor: actor, }) + // Set the auth context for the authzquerier as well. + ctx = dbauthz.As(ctx, actor) next.ServeHTTP(rw, r.WithContext(ctx)) }) diff --git a/coderd/httpmw/authz.go b/coderd/httpmw/authz.go new file mode 100644 index 0000000000000..5bfe69d47c956 --- /dev/null +++ b/coderd/httpmw/authz.go @@ -0,0 +1,37 @@ +package httpmw + +import ( + "net/http" + + "github.com/coder/coder/coderd/database/dbauthz" + + "github.com/go-chi/chi/v5" +) + +// AsAuthzSystem is a chained handler that temporarily sets the dbauthz context +// to System for the inner handlers, and resets the context afterwards. +// +// TODO: Refactor the middleware functions to not require this. +// This is a bit of a kludge for now as some middleware functions require +// usage as a system user in some cases, but not all cases. To avoid large +// refactors, we use this middleware to temporarily set the context to a system. +func AsAuthzSystem(mws ...func(http.Handler) http.Handler) func(http.Handler) http.Handler { + chain := chi.Chain(mws...) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + before, beforeExists := dbauthz.ActorFromContext(r.Context()) + if !beforeExists { + // AsRemoveActor will actually remove the actor from the context. + before = dbauthz.AsRemoveActor + } + + // nolint:gocritic // AsAuthzSystem needs to do this. + r = r.WithContext(dbauthz.AsSystem(ctx)) + chain.Handler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + r = r.WithContext(dbauthz.As(r.Context(), before)) + next.ServeHTTP(rw, r) + })).ServeHTTP(rw, r) + }) + } +} diff --git a/coderd/httpmw/authz_test.go b/coderd/httpmw/authz_test.go new file mode 100644 index 0000000000000..29474aa264bd9 --- /dev/null +++ b/coderd/httpmw/authz_test.go @@ -0,0 +1,97 @@ +package httpmw_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/coderd/database/dbauthz" + "github.com/coder/coder/coderd/httpmw" +) + +func TestAsAuthzSystem(t *testing.T) { + t.Parallel() + userActor := coderdtest.RandomRBACSubject() + + base := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + actor, ok := dbauthz.ActorFromContext(r.Context()) + assert.True(t, ok, "actor should exist") + assert.True(t, userActor.Equal(actor), "actor should be the user actor") + }) + + mwSetUser := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + r = r.WithContext(dbauthz.As(r.Context(), userActor)) + next.ServeHTTP(rw, r) + }) + } + + mwAssertSystem := mwAssert(func(req *http.Request) { + actor, ok := dbauthz.ActorFromContext(req.Context()) + assert.True(t, ok, "actor should exist") + assert.False(t, userActor.Equal(actor), "systemActor should not be the user actor") + assert.Contains(t, actor.Roles.Names(), "system", "should have system role") + }) + + mwAssertUser := mwAssert(func(req *http.Request) { + actor, ok := dbauthz.ActorFromContext(req.Context()) + assert.True(t, ok, "actor should exist") + assert.True(t, userActor.Equal(actor), "should be the useractor") + }) + + mwAssertNoUser := mwAssert(func(req *http.Request) { + _, ok := dbauthz.ActorFromContext(req.Context()) + assert.False(t, ok, "actor should not exist") + }) + + // Request as the user actor + const pattern = "/" + req := httptest.NewRequest("GET", pattern, nil) + res := httptest.NewRecorder() + + handler := chi.NewRouter() + handler.Route(pattern, func(r chi.Router) { + r.Use( + // First assert there is no actor context + mwAssertNoUser, + httpmw.AsAuthzSystem( + // Assert the system actor + mwAssertSystem, + mwAssertSystem, + ), + // Assert no user present outside of the AsAuthzSystem chain + mwAssertNoUser, + // ---- + // Set to the user actor + mwSetUser, + // Assert the user actor + mwAssertUser, + httpmw.AsAuthzSystem( + // Assert the system actor + mwAssertSystem, + mwAssertSystem, + ), + // Check the user actor was returned to the context + mwAssertUser, + ) + r.Handle("/", base) + r.NotFound(func(writer http.ResponseWriter, request *http.Request) { + assert.Fail(t, "should not hit not found, the route should be correct") + }) + }) + + handler.ServeHTTP(res, req) +} + +func mwAssert(assertF func(req *http.Request)) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + assertF(r) + next.ServeHTTP(rw, r) + }) + } +} diff --git a/coderd/httpmw/userparam.go b/coderd/httpmw/userparam.go index 74119d503a97b..4cbec80c695f6 100644 --- a/coderd/httpmw/userparam.go +++ b/coderd/httpmw/userparam.go @@ -11,6 +11,7 @@ import ( "github.com/google/uuid" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" ) @@ -68,7 +69,8 @@ func ExtractUserParam(db database.Store, redirectToLoginOnMe bool) func(http.Han }) return } - user, err = db.GetUserByID(ctx, apiKey.UserID) + //nolint:gocritic // System needs to be able to get user from param. + user, err = db.GetUserByID(dbauthz.AsSystem(ctx), apiKey.UserID) if xerrors.Is(err, sql.ErrNoRows) { httpapi.ResourceNotFound(rw) return @@ -81,8 +83,8 @@ func ExtractUserParam(db database.Store, redirectToLoginOnMe bool) func(http.Han return } } else if userID, err := uuid.Parse(userQuery); err == nil { - // If the userQuery is a valid uuid - user, err = db.GetUserByID(ctx, userID) + //nolint:gocritic // If the userQuery is a valid uuid + user, err = db.GetUserByID(dbauthz.AsSystem(ctx), userID) if err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: userErrorMessage, @@ -90,8 +92,8 @@ func ExtractUserParam(db database.Store, redirectToLoginOnMe bool) func(http.Han return } } else { - // Try as a username last - user, err = db.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{ + // nolint:gocritic // Try as a username last + user, err = db.GetUserByEmailOrUsername(dbauthz.AsSystem(ctx), database.GetUserByEmailOrUsernameParams{ Username: userQuery, }) if err != nil { diff --git a/coderd/httpmw/workspaceagent.go b/coderd/httpmw/workspaceagent.go index ea76ac8ad08f6..980872434d114 100644 --- a/coderd/httpmw/workspaceagent.go +++ b/coderd/httpmw/workspaceagent.go @@ -10,7 +10,9 @@ import ( "github.com/google/uuid" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/codersdk" ) @@ -45,7 +47,8 @@ func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler { }) return } - agent, err := db.GetWorkspaceAgentByAuthToken(ctx, token) + //nolint:gocritic // System needs to be able to get workspace agents. + agent, err := db.GetWorkspaceAgentByAuthToken(dbauthz.AsSystem(ctx), token) if err != nil { if errors.Is(err, sql.ErrNoRows) { httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{ @@ -62,8 +65,50 @@ func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler { return } + //nolint:gocritic // System needs to be able to get workspace agents. + subject, err := getAgentSubject(dbauthz.AsSystem(ctx), db, agent) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching workspace agent.", + Detail: err.Error(), + }) + return + } + ctx = context.WithValue(ctx, workspaceAgentContextKey{}, agent) + // Also set the dbauthz actor for the request. + ctx = dbauthz.As(ctx, subject) next.ServeHTTP(rw, r.WithContext(ctx)) }) } } + +func getAgentSubject(ctx context.Context, db database.Store, agent database.WorkspaceAgent) (rbac.Subject, error) { + // TODO: make a different query that gets the workspace owner and roles along with the agent. + workspace, err := db.GetWorkspaceByAgentID(ctx, agent.ID) + if err != nil { + return rbac.Subject{}, err + } + + user, err := db.GetUserByID(ctx, workspace.OwnerID) + if err != nil { + return rbac.Subject{}, err + } + + roles, err := db.GetAuthorizationUserRoles(ctx, user.ID) + if err != nil { + return rbac.Subject{}, err + } + + // A user that creates a workspace can use this agent auth token and + // impersonate the workspace. So to prevent privilege escalation, the + // subject inherits the roles of the user that owns the workspace. + // We then add a workspace-agent scope to limit the permissions + // to only what the workspace agent needs. + return rbac.Subject{ + ID: user.ID.String(), + Roles: rbac.RoleNames(roles.Roles), + Groups: roles.Groups, + Scope: rbac.WorkspaceAgentScope(workspace.ID, user.ID), + }, nil +} diff --git a/coderd/httpmw/workspaceagent_test.go b/coderd/httpmw/workspaceagent_test.go index 85800a6a71d66..bcf6ee2f7e0e2 100644 --- a/coderd/httpmw/workspaceagent_test.go +++ b/coderd/httpmw/workspaceagent_test.go @@ -19,11 +19,10 @@ import ( func TestWorkspaceAgent(t *testing.T) { t.Parallel() - setup := func(db database.Store) (*http.Request, uuid.UUID) { - token := uuid.New() + setup := func(db database.Store, token uuid.UUID) *http.Request { r := httptest.NewRequest("GET", "/", nil) r.Header.Set(codersdk.SessionTokenHeader, token.String()) - return r, token + return r } t.Run("None", func(t *testing.T) { @@ -34,7 +33,7 @@ func TestWorkspaceAgent(t *testing.T) { httpmw.ExtractWorkspaceAgent(db), ) rtr.Get("/", nil) - r, _ := setup(db) + r := setup(db, uuid.New()) rw := httptest.NewRecorder() rtr.ServeHTTP(rw, r) @@ -46,6 +45,24 @@ func TestWorkspaceAgent(t *testing.T) { t.Run("Found", func(t *testing.T) { t.Parallel() db := dbfake.New() + var ( + user = dbgen.User(t, db, database.User{}) + workspace = dbgen.Workspace(t, db, database.Workspace{ + OwnerID: user.ID, + }) + job = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{}) + resource = dbgen.WorkspaceResource(t, db, database.WorkspaceResource{ + JobID: job.ID, + }) + _ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + WorkspaceID: workspace.ID, + JobID: job.ID, + }) + agent = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ + ResourceID: resource.ID, + }) + ) + rtr := chi.NewRouter() rtr.Use( httpmw.ExtractWorkspaceAgent(db), @@ -54,10 +71,7 @@ func TestWorkspaceAgent(t *testing.T) { _ = httpmw.WorkspaceAgent(r) rw.WriteHeader(http.StatusOK) }) - r, token := setup(db) - _ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ - AuthToken: token, - }) + r := setup(db, agent.AuthToken) rw := httptest.NewRecorder() rtr.ServeHTTP(rw, r) diff --git a/coderd/members.go b/coderd/members.go index c67937423dd15..c3e4607b0f9e5 100644 --- a/coderd/members.go +++ b/coderd/members.go @@ -55,20 +55,20 @@ func (api *API) putMemberRoles(rw http.ResponseWriter, r *http.Request) { // Assigning a role requires the create permission. if len(added) > 0 && !api.Authorize(r, rbac.ActionCreate, rbac.ResourceOrgRoleAssignment.InOrg(organization.ID)) { - httpapi.Forbidden(rw) + httpapi.ResourceNotFound(rw) return } // Removing a role requires the delete permission. if len(removed) > 0 && !api.Authorize(r, rbac.ActionDelete, rbac.ResourceOrgRoleAssignment.InOrg(organization.ID)) { - httpapi.Forbidden(rw) + httpapi.ResourceNotFound(rw) return } // Just treat adding & removing as "assigning" for now. for _, roleName := range append(added, removed...) { if !rbac.CanAssignRole(actorRoles.Actor.Roles, roleName) { - httpapi.Forbidden(rw) + httpapi.ResourceNotFound(rw) return } } diff --git a/coderd/metricscache/metricscache.go b/coderd/metricscache/metricscache.go index 66742e3c71bb2..7c073a7e8200b 100644 --- a/coderd/metricscache/metricscache.go +++ b/coderd/metricscache/metricscache.go @@ -14,6 +14,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/codersdk" "github.com/coder/retry" ) @@ -142,6 +143,8 @@ func countUniqueUsers(rows []database.GetTemplateDAUsRow) int { } func (c *Cache) refresh(ctx context.Context) error { + //nolint:gocritic // This is a system service. + ctx = dbauthz.AsSystem(ctx) err := c.database.DeleteOldAgentStats(ctx) if err != nil { return xerrors.Errorf("delete old stats: %w", err) diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index 08c60d8b50f89..b97cc8594a573 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -25,6 +25,7 @@ import ( "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/parameter" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/codersdk" @@ -56,6 +57,8 @@ type Server struct { // AcquireJob queries the database to lock a job. func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + //nolint:gocritic //TODO: make a provisionerd role + ctx = dbauthz.AsSystem(ctx) // This prevents loads of provisioner daemons from consistently // querying the database when no jobs are available. // @@ -270,6 +273,8 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac } func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) { + //nolint:gocritic //TODO: make a provisionerd role + ctx = dbauthz.AsSystem(ctx) jobID, err := uuid.Parse(request.JobId) if err != nil { return nil, xerrors.Errorf("parse job id: %w", err) @@ -299,6 +304,8 @@ func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuot } func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { + //nolint:gocritic //TODO: make a provisionerd role + ctx = dbauthz.AsSystem(ctx) parsedID, err := uuid.Parse(request.JobId) if err != nil { return nil, xerrors.Errorf("parse job id: %w", err) @@ -345,7 +352,8 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq slog.F("stage", log.Stage), slog.F("output", log.Output)) } - logs, err := server.Database.InsertProvisionerJobLogs(context.Background(), insertParams) + //nolint:gocritic //TODO: make a provisionerd role + logs, err := server.Database.InsertProvisionerJobLogs(dbauthz.AsSystem(context.Background()), insertParams) if err != nil { server.Logger.Error(ctx, "failed to insert job logs", slog.F("job_id", parsedID), slog.Error(err)) return nil, xerrors.Errorf("insert job logs: %w", err) @@ -470,6 +478,8 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq } func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*proto.Empty, error) { + //nolint:gocritic // TODO: make a provisionerd role + ctx = dbauthz.AsSystem(ctx) jobID, err := uuid.Parse(failJob.JobId) if err != nil { return nil, xerrors.Errorf("parse job id: %w", err) @@ -596,6 +606,8 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p // CompleteJob is triggered by a provision daemon to mark a provisioner job as completed. func (server *Server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) { + //nolint:gocritic // TODO: make a provisionerd role + ctx = dbauthz.AsSystem(ctx) jobID, err := uuid.Parse(completed.JobId) if err != nil { return nil, xerrors.Errorf("parse job id: %w", err) diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index 35841b6a7ce56..3770cf217d0a0 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -16,9 +16,10 @@ import ( "nhooyr.io/websocket" "cdr.dev/slog" - "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/codersdk" ) @@ -32,6 +33,7 @@ import ( func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) { var ( ctx = r.Context() + actor, _ = dbauthz.ActorFromContext(ctx) logger = api.Logger.With(slog.F("job_id", job.ID)) follow = r.URL.Query().Has("follow") afterRaw = r.URL.Query().Get("after") @@ -49,7 +51,7 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job // of processed IDs. var bufferedLogs <-chan database.ProvisionerJobLog if follow { - bl, closeFollow, err := api.followLogs(job.ID) + bl, closeFollow, err := api.followLogs(actor, job.ID) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error watching provisioner logs.", @@ -367,7 +369,7 @@ type provisionerJobLogsMessage struct { EndOfLogs bool `json:"end_of_logs,omitempty"` } -func (api *API) followLogs(jobID uuid.UUID) (<-chan database.ProvisionerJobLog, func(), error) { +func (api *API) followLogs(actor rbac.Subject, jobID uuid.UUID) (<-chan database.ProvisionerJobLog, func(), error) { logger := api.Logger.With(slog.F("job_id", jobID)) var ( @@ -392,7 +394,7 @@ func (api *API) followLogs(jobID uuid.UUID) (<-chan database.ProvisionerJobLog, } if jlMsg.CreatedAfter != 0 { - logs, err := api.Database.GetProvisionerLogsByIDBetween(ctx, database.GetProvisionerLogsByIDBetweenParams{ + logs, err := api.Database.GetProvisionerLogsByIDBetween(dbauthz.As(ctx, actor), database.GetProvisionerLogsByIDBetweenParams{ JobID: jobID, CreatedAfter: jlMsg.CreatedAfter, }) diff --git a/coderd/rbac/authz_internal_test.go b/coderd/rbac/authz_internal_test.go index 6e6d812dc4306..b9332f459c81a 100644 --- a/coderd/rbac/authz_internal_test.go +++ b/coderd/rbac/authz_internal_test.go @@ -1039,7 +1039,6 @@ func testAuthorize(t *testing.T, name string, subject Subject, sets ...[]authTes } } - func must[T any](value T, err error) T { if err != nil { panic(err) diff --git a/coderd/rbac/builtin.go b/coderd/rbac/builtin.go index c5b0396629270..b644a03e03695 100644 --- a/coderd/rbac/builtin.go +++ b/coderd/rbac/builtin.go @@ -133,6 +133,8 @@ var ( ResourceWorkspace.Type: {ActionRead}, // CRUD to provisioner daemons for now. ResourceProvisionerDaemon.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete}, + // Needs to read all organizations since + ResourceOrganization.Type: {ActionRead}, }), Org: map[string][]Permission{}, User: []Permission{}, @@ -217,6 +219,12 @@ var ( // The first key is the actor role, the second is the roles they can assign. // map[actor_role][assign_role] assignRoles = map[string]map[string]bool{ + "system": { + owner: true, + member: true, + orgAdmin: true, + orgMember: true, + }, owner: { owner: true, auditor: true, diff --git a/coderd/rbac/builtin_internal_test.go b/coderd/rbac/builtin_internal_test.go index c91100ab5fb40..d7289fcdd2446 100644 --- a/coderd/rbac/builtin_internal_test.go +++ b/coderd/rbac/builtin_internal_test.go @@ -10,7 +10,7 @@ import ( // BenchmarkRBACValueAllocation benchmarks the cost of allocating a rego input // value. By default, `ast.InterfaceToValue` is used to convert the input, -// which uses json marshalling under the hood. +// which uses json marshaling under the hood. // // Currently ast.Object.insert() is the slowest part of the process and allocates // the most amount of bytes. This general approach copies all of our struct diff --git a/coderd/rbac/builtin_test.go b/coderd/rbac/builtin_test.go index 92ed3da501356..6e5b67b6474a8 100644 --- a/coderd/rbac/builtin_test.go +++ b/coderd/rbac/builtin_test.go @@ -19,6 +19,7 @@ type authSubject struct { Actor rbac.Subject } +// TODO: add the SYSTEM to the MATRIX func TestRolePermissions(t *testing.T) { t.Parallel() @@ -183,8 +184,8 @@ func TestRolePermissions(t *testing.T) { Actions: []rbac.Action{rbac.ActionRead}, Resource: rbac.ResourceOrganization.WithID(orgID).InOrg(orgID), AuthorizeMap: map[bool][]authSubject{ - true: {owner, orgAdmin, orgMemberMe}, - false: {otherOrgAdmin, otherOrgMember, memberMe, templateAdmin, userAdmin}, + true: {owner, orgAdmin, orgMemberMe, templateAdmin}, + false: {otherOrgAdmin, otherOrgMember, memberMe, userAdmin}, }, }, { diff --git a/coderd/rbac/error.go b/coderd/rbac/error.go index ec0bf02f8f21f..dafd08af2e6b7 100644 --- a/coderd/rbac/error.go +++ b/coderd/rbac/error.go @@ -1,6 +1,10 @@ package rbac -import "github.com/open-policy-agent/opa/rego" +import ( + "errors" + + "github.com/open-policy-agent/opa/rego" +) const ( // errUnauthorized is the error message that should be returned to @@ -24,6 +28,12 @@ type UnauthorizedError struct { output rego.ResultSet } +// IsUnauthorizedError is a convenience function to check if err is UnauthorizedError. +// It is equivalent to errors.As(err, &UnauthorizedError{}). +func IsUnauthorizedError(err error) bool { + return errors.As(err, &UnauthorizedError{}) +} + // ForbiddenWithInternal creates a new error that will return a simple // "forbidden" to the client, logging internally the more detailed message // provided. @@ -37,6 +47,10 @@ func ForbiddenWithInternal(internal error, subject Subject, action Action, objec } } +func (e UnauthorizedError) Unwrap() error { + return e.internal +} + // Error implements the error interface. func (UnauthorizedError) Error() string { return errUnauthorized @@ -47,6 +61,10 @@ func (e *UnauthorizedError) Internal() error { return e.internal } +func (e *UnauthorizedError) SetInternal(err error) { + e.internal = err +} + func (e *UnauthorizedError) Input() map[string]interface{} { return map[string]interface{}{ "subject": e.subject, @@ -59,3 +77,11 @@ func (e *UnauthorizedError) Input() map[string]interface{} { func (e *UnauthorizedError) Output() rego.ResultSet { return e.output } + +// As implements the errors.As interface. +func (*UnauthorizedError) As(target interface{}) bool { + if _, ok := target.(*UnauthorizedError); ok { + return true + } + return false +} diff --git a/coderd/rbac/error_test.go b/coderd/rbac/error_test.go new file mode 100644 index 0000000000000..23bbc7b3bc54c --- /dev/null +++ b/coderd/rbac/error_test.go @@ -0,0 +1,32 @@ +package rbac_test + +import ( + "testing" + + "github.com/coder/coder/coderd/rbac" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" +) + +func TestIsUnauthorizedError(t *testing.T) { + t.Parallel() + t.Run("NotWrapped", func(t *testing.T) { + t.Parallel() + errFunc := func() error { + return rbac.UnauthorizedError{} + } + + err := errFunc() + require.True(t, rbac.IsUnauthorizedError(err)) + }) + + t.Run("Wrapped", func(t *testing.T) { + t.Parallel() + errFunc := func() error { + return xerrors.Errorf("test error: %w", rbac.UnauthorizedError{}) + } + err := errFunc() + require.True(t, rbac.IsUnauthorizedError(err)) + }) +} diff --git a/coderd/rbac/scopes.go b/coderd/rbac/scopes.go index 15cdeb2da8c88..82b64f7179135 100644 --- a/coderd/rbac/scopes.go +++ b/coderd/rbac/scopes.go @@ -3,6 +3,8 @@ package rbac import ( "fmt" + "github.com/google/uuid" + "golang.org/x/xerrors" ) @@ -41,6 +43,29 @@ func (s Scope) Name() string { return s.Role.Name } +// WorkspaceAgentScope returns a scope that is the same as ScopeAll but can only +// affect resources in the allow list. Only a scope is returned as the roles +// should come from the workspace owner. +func WorkspaceAgentScope(workspaceID, ownerID uuid.UUID) Scope { + allScope, err := ScopeAll.Expand() + if err != nil { + panic("failed to expand scope all, this should never happen") + } + return Scope{ + // TODO: We want to limit the role too to be extra safe. + // Even though the allowlist blocks anything else, it is still good + // incase we change the behavior of the allowlist. The allowlist is new + // and evolving. + Role: allScope.Role, + // This prevents the agent from being able to access any other resource. + AllowIDList: []string{ + workspaceID.String(), + ownerID.String(), + // TODO: Might want to include the template the workspace uses too? + }, + } +} + const ( ScopeAll ScopeName = "all" ScopeApplicationConnect ScopeName = "application_connect" diff --git a/coderd/roles.go b/coderd/roles.go index 743d2bdba8a6f..a067173300e43 100644 --- a/coderd/roles.go +++ b/coderd/roles.go @@ -47,7 +47,7 @@ func (api *API) assignableOrgRoles(rw http.ResponseWriter, r *http.Request) { actorRoles := httpmw.UserAuthorization(r) if !api.Authorize(r, rbac.ActionRead, rbac.ResourceOrgRoleAssignment.InOrg(organization.ID)) { - httpapi.Forbidden(rw) + httpapi.ResourceNotFound(rw) return } diff --git a/coderd/roles_test.go b/coderd/roles_test.go index 5e5bad8c455ef..2b9eb35f34e15 100644 --- a/coderd/roles_test.go +++ b/coderd/roles_test.go @@ -30,7 +30,7 @@ func TestListRoles(t *testing.T) { }) require.NoError(t, err, "create org") - const forbidden = "Forbidden" + const notFound = "Resource not found" testCases := []struct { Name string Client *codersdk.Client @@ -66,7 +66,7 @@ func TestListRoles(t *testing.T) { APICall: func(ctx context.Context) ([]codersdk.AssignableRoles, error) { return member.ListOrganizationRoles(ctx, otherOrg.ID) }, - AuthorizedError: forbidden, + AuthorizedError: notFound, }, // Org admin { @@ -95,7 +95,7 @@ func TestListRoles(t *testing.T) { APICall: func(ctx context.Context) ([]codersdk.AssignableRoles, error) { return orgAdmin.ListOrganizationRoles(ctx, otherOrg.ID) }, - AuthorizedError: forbidden, + AuthorizedError: notFound, }, // Admin { @@ -133,7 +133,7 @@ func TestListRoles(t *testing.T) { if c.AuthorizedError != "" { var apiErr *codersdk.Error require.ErrorAs(t, err, &apiErr) - require.Equal(t, http.StatusForbidden, apiErr.StatusCode()) + require.Equal(t, http.StatusNotFound, apiErr.StatusCode()) require.Contains(t, apiErr.Message, c.AuthorizedError) } else { require.NoError(t, err) diff --git a/coderd/templates.go b/coderd/templates.go index febe8ab84c93b..564ccd74946f5 100644 --- a/coderd/templates.go +++ b/coderd/templates.go @@ -82,6 +82,10 @@ func (api *API) deleteTemplate(rw http.ResponseWriter, r *http.Request) { return } + // TODO: This just returns the workspaces a user can view. We should use + // a system function to get all workspaces that use this template. + // This data should never be exposed to the user aside from a non-zero count. + // Or we move this into a postgres constraint. workspaces, err := api.Database.GetWorkspaces(ctx, database.GetWorkspacesParams{ TemplateIds: []uuid.UUID{template.ID}, }) diff --git a/coderd/userauth.go b/coderd/userauth.go index 3518481d508fb..a59b89f08ebee 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -18,9 +18,9 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" - "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/rbac" @@ -57,7 +57,8 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) { return } - user, err := api.Database.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{ + //nolint:gocritic // In order to login, we need to get the user first! + user, err := api.Database.GetUserByEmailOrUsername(dbauthz.AsSystem(ctx), database.GetUserByEmailOrUsernameParams{ Email: loginWithPassword.Email, }) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { @@ -111,15 +112,32 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) { return } + //nolint:gocritic // System needs to fetch user roles in order to login user. + roles, err := api.Database.GetAuthorizationUserRoles(dbauthz.AsSystem(ctx), user.ID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error.", + }) + return + } + // If the user logged into a suspended account, reject the login request. - if user.Status != database.UserStatusActive { + if roles.Status != database.UserStatusActive { httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{ Message: "Your account is suspended. Contact an admin to reactivate your account.", }) return } - cookie, key, err := api.createAPIKey(ctx, createAPIKeyParams{ + userSubj := rbac.Subject{ + ID: user.ID.String(), + Roles: rbac.RoleNames(roles.Roles), + Groups: roles.Groups, + Scope: rbac.ScopeAll, + } + + //nolint:gocritic // Creating the API key as the user instead of as system. + cookie, key, err := api.createAPIKey(dbauthz.As(ctx, userSubj), createAPIKeyParams{ UserID: user.ID, LoginType: database.LoginTypePassword, RemoteAddr: r.RemoteAddr, @@ -765,7 +783,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook // with OIDC for the first time. if user.ID == uuid.Nil { var organizationID uuid.UUID - organizations, _ := tx.GetOrganizations(ctx) + //nolint:gocritic + organizations, _ := tx.GetOrganizations(dbauthz.AsSystem(ctx)) if len(organizations) > 0 { // Add the user to the first organization. Once multi-organization // support is added, we should enable a configuration map of user @@ -773,7 +792,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook organizationID = organizations[0].ID } - _, err := tx.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{ + //nolint:gocritic + _, err := tx.GetUserByEmailOrUsername(dbauthz.AsSystem(ctx), database.GetUserByEmailOrUsernameParams{ Username: params.Username, }) if err == nil { @@ -786,7 +806,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook params.Username = httpapi.UsernameFrom(alternate) - _, err := tx.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{ + //nolint:gocritic + _, err := tx.GetUserByEmailOrUsername(dbauthz.AsSystem(ctx), database.GetUserByEmailOrUsernameParams{ Username: params.Username, }) if xerrors.Is(err, sql.ErrNoRows) { @@ -805,7 +826,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook } } - user, _, err = api.CreateUser(ctx, tx, CreateUserRequest{ + //nolint:gocritic + user, _, err = api.CreateUser(dbauthz.AsSystem(ctx), tx, CreateUserRequest{ CreateUserRequest: codersdk.CreateUserRequest{ Email: params.Email, Username: params.Username, @@ -819,7 +841,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook } if link.UserID == uuid.Nil { - link, err = tx.InsertUserLink(ctx, database.InsertUserLinkParams{ + //nolint:gocritic + link, err = tx.InsertUserLink(dbauthz.AsSystem(ctx), database.InsertUserLinkParams{ UserID: user.ID, LoginType: params.LoginType, LinkedID: params.LinkedID, @@ -833,7 +856,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook } if link.UserID != uuid.Nil { - link, err = tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{ + //nolint:gocritic + link, err = tx.UpdateUserLink(dbauthz.AsSystem(ctx), database.UpdateUserLinkParams{ UserID: user.ID, LoginType: params.LoginType, OAuthAccessToken: params.State.Token.AccessToken, @@ -847,7 +871,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook // Ensure groups are correct. if len(params.Groups) > 0 { - err := api.Options.SetUserGroups(ctx, tx, user.ID, params.Groups) + //nolint:gocritic + err := api.Options.SetUserGroups(dbauthz.AsSystem(ctx), tx, user.ID, params.Groups) if err != nil { return xerrors.Errorf("set user groups: %w", err) } @@ -880,7 +905,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook // In such cases in the current implementation this user can now no // longer sign in until an administrator finds the offending built-in // user and changes their username. - user, err = tx.UpdateUserProfile(ctx, database.UpdateUserProfileParams{ + //nolint:gocritic + user, err = tx.UpdateUserProfile(dbauthz.AsSystem(ctx), database.UpdateUserProfileParams{ ID: user.ID, Email: user.Email, Username: user.Username, @@ -898,7 +924,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook return nil, database.APIKey{}, xerrors.Errorf("in tx: %w", err) } - cookie, key, err := api.createAPIKey(ctx, createAPIKeyParams{ + //nolint:gocritic + cookie, key, err := api.createAPIKey(dbauthz.AsSystem(ctx), createAPIKeyParams{ UserID: user.ID, LoginType: params.LoginType, RemoteAddr: r.RemoteAddr, diff --git a/coderd/users.go b/coderd/users.go index 9f49a74aec64b..ed79fc43d7d3c 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -16,6 +16,7 @@ import ( "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/gitsshkey" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" @@ -37,7 +38,8 @@ import ( // @Router /users/first [get] func (api *API) firstUser(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - userCount, err := api.Database.GetUserCount(ctx) + //nolint:gocritic // needed for first user check + userCount, err := api.Database.GetUserCount(dbauthz.AsSystem(ctx)) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching user count.", @@ -70,6 +72,7 @@ func (api *API) firstUser(rw http.ResponseWriter, r *http.Request) { // @Success 201 {object} codersdk.CreateFirstUserResponse // @Router /users/first [post] func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) { + // TODO: Should this admin system context be in a middleware? ctx := r.Context() var createUser codersdk.CreateFirstUserRequest if !httpapi.Read(ctx, rw, r, &createUser) { @@ -77,7 +80,8 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) { } // This should only function for the first user. - userCount, err := api.Database.GetUserCount(ctx) + //nolint:gocritic // needed to create first user + userCount, err := api.Database.GetUserCount(dbauthz.AsSystem(ctx)) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching user count.", @@ -117,7 +121,8 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) { return } - user, organizationID, err := api.CreateUser(ctx, api.Database, CreateUserRequest{ + //nolint:gocritic // needed to create first user + user, organizationID, err := api.CreateUser(dbauthz.AsSystem(ctx), api.Database, CreateUserRequest{ CreateUserRequest: codersdk.CreateUserRequest{ Email: createUser.Email, Username: createUser.Username, @@ -146,7 +151,8 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) { // the user. Maybe I add this ability to grant roles in the createUser api // and add some rbac bypass when calling api functions this way?? // Add the admin role to this first user. - _, err = api.Database.UpdateUserRoles(ctx, database.UpdateUserRolesParams{ + //nolint:gocritic // needed to create first user + _, err = api.Database.UpdateUserRoles(dbauthz.AsSystem(ctx), database.UpdateUserRolesParams{ GrantedRoles: []string{rbac.RoleOwner()}, ID: user.ID, }) @@ -987,7 +993,7 @@ func (api *API) organizationByUserAndName(rw http.ResponseWriter, r *http.Reques ctx := r.Context() organizationName := chi.URLParam(r, "organizationname") organization, err := api.Database.GetOrganizationByName(ctx, organizationName) - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, sql.ErrNoRows) || rbac.IsUnauthorizedError(err) { httpapi.ResourceNotFound(rw) return } diff --git a/coderd/users_test.go b/coderd/users_test.go index 283a50a25e9d8..ad262802f7789 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -854,7 +854,7 @@ func TestGrantSiteRoles(t *testing.T) { AssignToUser: randOrgUser.ID.String(), Roles: []string{rbac.RoleOrgMember(randOrg.ID)}, Error: true, - StatusCode: http.StatusForbidden, + StatusCode: http.StatusNotFound, }, { Name: "AdminUpdateOrgSelf", diff --git a/coderd/util/slice/slice.go b/coderd/util/slice/slice.go index d13162cb4fa57..9909fe2b72c21 100644 --- a/coderd/util/slice/slice.go +++ b/coderd/util/slice/slice.go @@ -1,10 +1,5 @@ package slice -// New is a convenience method for creating []T. -func New[T any](items ...T) []T { - return items -} - // SameElements returns true if the 2 lists have the same elements in any // order. func SameElements[T comparable](a []T, b []T) bool { @@ -67,3 +62,8 @@ func OverlapCompare[T any](a []T, b []T, equal func(a, b T) bool) bool { } return false } + +// New is a convenience method for creating []T. +func New[T any](items ...T) []T { + return items +} diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index f8f40d362e999..ed7e09bb61d1c 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -26,6 +26,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/agent" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/gitauth" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" @@ -625,14 +626,29 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request // inactive disconnect timeout we ensure that we don't block but // also guarantee that the agent will be considered disconnected // by normal status check. - ctx, cancel := context.WithTimeout(api.ctx, api.AgentInactiveDisconnectTimeout) + // + // Use a system context as the agent has disconnected and that token + // may no longer be valid. + //nolint:gocritic + ctx, cancel := context.WithTimeout(dbauthz.AsSystem(api.ctx), api.AgentInactiveDisconnectTimeout) defer cancel() disconnectedAt = sql.NullTime{ Time: database.Now(), Valid: true, } - _ = updateConnectionTimes(ctx) + err := updateConnectionTimes(ctx) + if err != nil { + // This is a bug with unit tests that cancel the app context and + // cause this error log to be generated. We should fix the unit tests + // as this is a valid log. + if !xerrors.Is(err, context.Canceled) { + api.Logger.Error(ctx, "failed to update agent disconnect time", + slog.Error(err), + slog.F("workspace", build.WorkspaceID), + ) + } + } api.publishWorkspaceUpdate(ctx, build.WorkspaceID) }() @@ -907,7 +923,7 @@ func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Reques slog.F("payload", req), ) - activityBumpWorkspace(api.Logger.Named("activity_bump"), api.Database, workspace.ID) + activityBumpWorkspace(ctx, api.Logger.Named("activity_bump"), api.Database, workspace.ID) payload, err := json.Marshal(req) if err != nil { diff --git a/coderd/workspaceapps.go b/coderd/workspaceapps.go index 4c6f4822f3d10..43714d089f9e8 100644 --- a/coderd/workspaceapps.go +++ b/coderd/workspaceapps.go @@ -24,6 +24,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/rbac" @@ -330,7 +331,8 @@ func (api *API) handleWorkspaceAppLogout(rw http.ResponseWriter, r *http.Request // different auth formats, and tricks this endpoint into deleting an // unchecked API key, we validate that the secret matches the secret // we store in the database. - apiKey, err := api.Database.GetAPIKeyByID(ctx, id) + //nolint:gocritic // needed for workspace app logout + apiKey, err := api.Database.GetAPIKeyByID(dbauthz.AsSystem(ctx), id) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to lookup API key.", @@ -349,7 +351,8 @@ func (api *API) handleWorkspaceAppLogout(rw http.ResponseWriter, r *http.Request }) return } - err = api.Database.DeleteAPIKeyByID(ctx, id) + //nolint:gocritic // needed for workspace app logout + err = api.Database.DeleteAPIKeyByID(dbauthz.AsSystem(ctx), id) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to delete API key.", @@ -409,7 +412,10 @@ func (api *API) handleWorkspaceAppLogout(rw http.ResponseWriter, r *http.Request // error while looking it up, an HTML error page is returned and false is // returned so the caller can return early. func (api *API) lookupWorkspaceApp(rw http.ResponseWriter, r *http.Request, agentID uuid.UUID, appSlug string) (database.WorkspaceApp, bool) { - app, err := api.Database.GetWorkspaceAppByAgentIDAndSlug(r.Context(), database.GetWorkspaceAppByAgentIDAndSlugParams{ + // dbauthz.AsSystem is allowed here as the app authz is checked later. + // The app authz is determined by the sharing level. + //nolint:gocritic + app, err := api.Database.GetWorkspaceAppByAgentIDAndSlug(dbauthz.AsSystem(r.Context()), database.GetWorkspaceAppByAgentIDAndSlugParams{ AgentID: agentID, Slug: appSlug, }) @@ -1019,7 +1025,8 @@ func decryptAPIKey(ctx context.Context, db database.Store, encryptedAPIKey strin // Lookup the API key so we can decrypt it. keyID := object.Header.KeyID - key, err := db.GetAPIKeyByID(ctx, keyID) + //nolint:gocritic // needed to check API key + key, err := db.GetAPIKeyByID(dbauthz.AsSystem(ctx), keyID) if err != nil { return database.APIKey{}, "", xerrors.Errorf("get API key by key ID: %w", err) } diff --git a/coderd/workspaceresourceauth.go b/coderd/workspaceresourceauth.go index f8f9b3a32b795..7fa8e8aa8907d 100644 --- a/coderd/workspaceresourceauth.go +++ b/coderd/workspaceresourceauth.go @@ -10,6 +10,7 @@ import ( "github.com/coder/coder/coderd/awsidentity" "github.com/coder/coder/coderd/azureidentity" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/provisionerdserver" "github.com/coder/coder/codersdk" @@ -126,7 +127,8 @@ func (api *API) postWorkspaceAuthGoogleInstanceIdentity(rw http.ResponseWriter, func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, instanceID string) { ctx := r.Context() - agent, err := api.Database.GetWorkspaceAgentByInstanceID(ctx, instanceID) + //nolint:gocritic // needed for auth instance id + agent, err := api.Database.GetWorkspaceAgentByInstanceID(dbauthz.AsSystem(ctx), instanceID) if errors.Is(err, sql.ErrNoRows) { httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ Message: fmt.Sprintf("Instance with id %q not found.", instanceID), @@ -140,7 +142,8 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in }) return } - resource, err := api.Database.GetWorkspaceResourceByID(ctx, agent.ResourceID) + //nolint:gocritic // needed for auth instance id + resource, err := api.Database.GetWorkspaceResourceByID(dbauthz.AsSystem(ctx), agent.ResourceID) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching provisioner job resource.", @@ -148,7 +151,8 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in }) return } - job, err := api.Database.GetProvisionerJobByID(ctx, resource.JobID) + //nolint:gocritic // needed for auth instance id + job, err := api.Database.GetProvisionerJobByID(dbauthz.AsSystem(ctx), resource.JobID) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching provisioner job.", @@ -171,7 +175,8 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in }) return } - resourceHistory, err := api.Database.GetWorkspaceBuildByID(ctx, jobData.WorkspaceBuildID) + //nolint:gocritic // needed for auth instance id + resourceHistory, err := api.Database.GetWorkspaceBuildByID(dbauthz.AsSystem(ctx), jobData.WorkspaceBuildID) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspace build.", @@ -182,7 +187,8 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in // This token should only be exchanged if the instance ID is valid // for the latest history. If an instance ID is recycled by a cloud, // we'd hate to leak access to a user's workspace. - latestHistory, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, resourceHistory.WorkspaceID) + //nolint:gocritic // needed for auth instance id + latestHistory, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(dbauthz.AsSystem(ctx), resourceHistory.WorkspaceID) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching the latest workspace build.", diff --git a/coderd/workspaces.go b/coderd/workspaces.go index f945cf39cf4bc..00a10e3dbbb71 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -371,6 +371,9 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req return } + // TODO: This should be a system call as the actor might not be able to + // read other workspaces. Ideally we check the error on create and look for + // a postgres conflict error. workspace, err := api.Database.GetWorkspaceByOwnerIDAndName(ctx, database.GetWorkspaceByOwnerIDAndNameParams{ OwnerID: user.ID, Name: createWorkspace.Name, diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 67d5b543aef31..20d984a3b946c 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -144,7 +144,9 @@ func New(ctx context.Context, options *Options) (*API, error) { if len(options.SCIMAPIKey) != 0 { api.AGPL.RootHandler.Route("/scim/v2", func(r chi.Router) { - r.Use(api.scimEnabledMW) + r.Use( + api.scimEnabledMW, + ) r.Post("/Users", api.scimPostUser) r.Route("/Users", func(r chi.Router) { r.Get("/", api.scimGetUsers) diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index 4d67d97029830..1cba0a1f633c0 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -6,6 +6,8 @@ import ( "testing" "time" + "github.com/coder/coder/coderd/database/dbauthz" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" @@ -100,7 +102,9 @@ func TestEntitlements(t *testing.T) { require.NoError(t, err) require.False(t, entitlements.HasLicense) coderdtest.CreateFirstUser(t, client) - _, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{ + //nolint:gocritic // unit test + ctx := dbauthz.AsSystem(context.Background()) + _, err = api.Database.InsertLicense(ctx, database.InsertLicenseParams{ UploadedAt: database.Now(), Exp: database.Now().AddDate(1, 0, 0), JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ @@ -128,7 +132,9 @@ func TestEntitlements(t *testing.T) { require.False(t, entitlements.HasLicense) coderdtest.CreateFirstUser(t, client) // Valid - _, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{ + ctx := context.Background() + //nolint:gocritic // unit test + _, err = api.Database.InsertLicense(dbauthz.AsSystem(ctx), database.InsertLicenseParams{ UploadedAt: database.Now(), Exp: database.Now().AddDate(1, 0, 0), JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ @@ -139,7 +145,8 @@ func TestEntitlements(t *testing.T) { }) require.NoError(t, err) // Expired - _, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{ + //nolint:gocritic // unit test + _, err = api.Database.InsertLicense(dbauthz.AsSystem(ctx), database.InsertLicenseParams{ UploadedAt: database.Now(), Exp: database.Now().AddDate(-1, 0, 0), JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ @@ -148,7 +155,8 @@ func TestEntitlements(t *testing.T) { }) require.NoError(t, err) // Invalid - _, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{ + //nolint:gocritic // unit test + _, err = api.Database.InsertLicense(dbauthz.AsSystem(ctx), database.InsertLicenseParams{ UploadedAt: database.Now(), Exp: database.Now().AddDate(1, 0, 0), JWT: "invalid", diff --git a/enterprise/coderd/coderdenttest/coderdenttest_test.go b/enterprise/coderd/coderdenttest/coderdenttest_test.go index 8fdfbd0a8c9e2..aa32c582e67f9 100644 --- a/enterprise/coderd/coderdenttest/coderdenttest_test.go +++ b/enterprise/coderd/coderdenttest/coderdenttest_test.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "net/http" + "os" + "strings" "testing" "github.com/stretchr/testify/require" @@ -22,6 +24,9 @@ func TestNew(t *testing.T) { } func TestAuthorizeAllEndpoints(t *testing.T) { + if strings.Contains(os.Getenv("CODER_EXPERIMENTS_TEST"), string(codersdk.ExperimentAuthzQuerier)) { + t.Skip("Skipping TestAuthorizeAllEndpoints for authz_querier experiment") + } t.Parallel() client, _, api := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ Options: &coderdtest.Options{ diff --git a/enterprise/coderd/scim.go b/enterprise/coderd/scim.go index 8912c3fafaf07..b0ad00e72fd3a 100644 --- a/enterprise/coderd/scim.go +++ b/enterprise/coderd/scim.go @@ -14,6 +14,7 @@ import ( agpl "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" ) @@ -155,7 +156,8 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { return } - user, _, err := api.AGPL.CreateUser(ctx, api.Database, agpl.CreateUserRequest{ + //nolint:gocritic // needed for SCIM + user, _, err := api.AGPL.CreateUser(dbauthz.AsSystem(ctx), api.Database, agpl.CreateUserRequest{ CreateUserRequest: codersdk.CreateUserRequest{ Username: sUser.UserName, Email: email, @@ -207,7 +209,8 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { return } - dbUser, err := api.Database.GetUserByID(ctx, uid) + //nolint:gocritic // needed for SCIM + dbUser, err := api.Database.GetUserByID(dbauthz.AsSystem(ctx), uid) if err != nil { _ = handlerutil.WriteError(rw, err) return @@ -220,7 +223,8 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { status = database.UserStatusSuspended } - _, err = api.Database.UpdateUserStatus(r.Context(), database.UpdateUserStatusParams{ + //nolint:gocritic // needed for SCIM + _, err = api.Database.UpdateUserStatus(dbauthz.AsSystem(r.Context()), database.UpdateUserStatusParams{ ID: dbUser.ID, Status: status, UpdatedAt: database.Now(), diff --git a/enterprise/coderd/templates_test.go b/enterprise/coderd/templates_test.go index e77bafd550540..7c9303f188310 100644 --- a/enterprise/coderd/templates_test.go +++ b/enterprise/coderd/templates_test.go @@ -921,6 +921,10 @@ func TestTemplateAccess(t *testing.T) { testTemplateRead := func(t *testing.T, org orgSetup, usr *codersdk.Client, read []codersdk.Template) { found, err := usr.TemplatesByOrganization(ctx, org.Org.ID) + if len(read) == 0 && err != nil { + require.ErrorContains(t, err, "Resource not found") + return + } require.NoError(t, err, "failed to get templates") exp := make(map[uuid.UUID]codersdk.Template) diff --git a/provisionerd/runner/runner.go b/provisionerd/runner/runner.go index 4947df09350cf..4526da1fce58e 100644 --- a/provisionerd/runner/runner.go +++ b/provisionerd/runner/runner.go @@ -24,6 +24,7 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" + "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/tracing" "github.com/coder/coder/provisionerd/proto" sdkproto "github.com/coder/coder/provisionersdk/proto" @@ -886,7 +887,8 @@ func (r *Runner) commitQuota(ctx context.Context, resources []*sdkproto.Resource const stage = "Commit quota" - resp, err := r.quotaCommitter.CommitQuota(ctx, &proto.CommitQuotaRequest{ + //nolint:gocritic // TODO: make a provisionerd role + resp, err := r.quotaCommitter.CommitQuota(dbauthz.AsSystem(ctx), &proto.CommitQuotaRequest{ JobId: r.job.JobId, DailyCost: int32(cost), }) diff --git a/scripts/rules.go b/scripts/rules.go index 1d83dd2315bf3..414167da15036 100644 --- a/scripts/rules.go +++ b/scripts/rules.go @@ -20,6 +20,29 @@ import ( "github.com/quasilyte/go-ruleguard/dsl/types" ) +// dbauthzAuthorizationContext is a lint rule that protects the usage of +// system contexts. This is a dangerous pattern that can lead to +// leaking database information as a system context can be essentially +// "sudo". +// +// Anytime a function like "AsSystem" is used, it should be accompanied by a comment +// explaining why it's ok and a nolint. +func dbauthzAuthorizationContext(m dsl.Matcher) { + m.Import("context") + m.Import("github.com/coder/coder/coderd/database/dbauthz") + + m.Match( + `dbauthz.$f($c)`, + ). + Where( + m["c"].Type.Implements("context.Context") && + // Only report on functions that start with "As". + m["f"].Text.Matches("^As"), + ). + // Instructions for fixing the lint error should be included on the dangerous function. + Report("Using '$f' is dangerous and should be accompanied by a comment explaining why it's ok and a nolint.") +} + // Use xerrors everywhere! It provides additional stacktrace info! // //nolint:unused,deadcode,varnamelen