Thanks to visit codestin.com
Credit goes to github.com

Skip to content

chore: Allow RecordingAuthorizer to record multiple rbac authz calls #6024

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 185 additions & 50 deletions coderd/coderdtest/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@ import (
"net/http"
"strconv"
"strings"
"sync"
"testing"
"time"

"github.com/coder/coder/coderd/database/dbfake"

"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"

"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/database/dbfake"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/rbac/regosql"
"github.com/coder/coder/codersdk"
Expand Down Expand Up @@ -443,7 +443,9 @@ func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, a

func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck, skipRoutes map[string]string) {
// Always fail auth from this point forward
a.authorizer.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil)
a.authorizer.Wrapped = &FakeAuthorizer{
AlwaysReturn: rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil),
}

routeMissing := make(map[string]bool)
for k, v := range assertRoute {
Expand Down Expand Up @@ -483,7 +485,7 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
return nil
}
a.t.Run(name, func(t *testing.T) {
a.authorizer.reset()
a.authorizer.Reset()
routeKey := strings.TrimRight(name, "/")

routeAssertions, ok := assertRoute[routeKey]
Expand Down Expand Up @@ -514,18 +516,19 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
assert.Equal(t, http.StatusForbidden, resp.StatusCode, "expect unauthorized")
}
}
if a.authorizer.Called != nil {
if a.authorizer.lastCall() != nil {
last := a.authorizer.lastCall()
if routeAssertions.AssertAction != "" {
assert.Equal(t, routeAssertions.AssertAction, a.authorizer.Called.Action, "resource action")
assert.Equal(t, routeAssertions.AssertAction, last.Action, "resource action")
}
if routeAssertions.AssertObject.Type != "" {
assert.Equal(t, routeAssertions.AssertObject.Type, a.authorizer.Called.Object.Type, "resource type")
assert.Equal(t, routeAssertions.AssertObject.Type, last.Object.Type, "resource type")
}
if routeAssertions.AssertObject.Owner != "" {
assert.Equal(t, routeAssertions.AssertObject.Owner, a.authorizer.Called.Object.Owner, "resource owner")
assert.Equal(t, routeAssertions.AssertObject.Owner, last.Object.Owner, "resource owner")
}
if routeAssertions.AssertObject.OrgID != "" {
assert.Equal(t, routeAssertions.AssertObject.OrgID, a.authorizer.Called.Object.OrgID, "resource org")
assert.Equal(t, routeAssertions.AssertObject.OrgID, last.Object.OrgID, "resource org")
}
}
} else {
Expand All @@ -539,52 +542,195 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
}

type authCall struct {
Subject rbac.Subject
Action rbac.Action
Object rbac.Object
Actor rbac.Subject
Action rbac.Action
Object rbac.Object

asserted bool
}

var _ rbac.Authorizer = (*RecordingAuthorizer)(nil)

// RecordingAuthorizer wraps any rbac.Authorizer and records all Authorize()
// calls made. This is useful for testing as these calls can later be asserted.
type RecordingAuthorizer struct {
Called *authCall
AlwaysReturn error
sync.RWMutex
Called []authCall
Wrapped rbac.Authorizer
}

var _ rbac.Authorizer = (*RecordingAuthorizer)(nil)
type ActionObjectPair struct {
Action rbac.Action
Object rbac.Object
}

// AuthorizeSQL does not record the call. This matches the postgres behavior
// of not calling Authorize()
func (r *RecordingAuthorizer) AuthorizeSQL(_ context.Context, _ rbac.Subject, _ rbac.Action, _ rbac.Object) error {
return r.AlwaysReturn
// Pair is on the RecordingAuthorizer to be easy to find and keep the pkg
// interface smaller.
func (*RecordingAuthorizer) Pair(action rbac.Action, object rbac.Objecter) ActionObjectPair {
return ActionObjectPair{
Action: action,
Object: object.RBACObject(),
}
}

func (r *RecordingAuthorizer) Authorize(_ context.Context, subject rbac.Subject, action rbac.Action, object rbac.Object) error {
r.Called = &authCall{
Subject: subject,
Action: action,
Object: object,
// AllAsserted returns an error if all calls to Authorize() have not been
// asserted and checked. This is useful for testing to ensure that all
// Authorize() calls are checked in the unit test.
func (r *RecordingAuthorizer) AllAsserted() error {
r.RLock()
defer r.RUnlock()
missed := []authCall{}
for _, c := range r.Called {
if !c.asserted {
missed = append(missed, c)
}
}
return r.AlwaysReturn

if len(missed) > 0 {
return xerrors.Errorf("missed calls: %+v", missed)
}
return nil
}

func (r *RecordingAuthorizer) Prepare(_ context.Context, subject rbac.Subject, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
return &fakePreparedAuthorizer{
Original: r,
Subject: subject,
Action: action,
HardCodedSQLString: "true",
// AssertActor asserts in order. If the order of authz calls does not match,
// this will fail.
func (r *RecordingAuthorizer) AssertActor(t *testing.T, actor rbac.Subject, did ...ActionObjectPair) {
r.RLock()
defer r.RUnlock()
ptr := 0
for i, call := range r.Called {
if ptr == len(did) {
// Finished all assertions
return
}
if call.Actor.ID == actor.ID {
action, object := did[ptr].Action, did[ptr].Object
assert.Equalf(t, action, call.Action, "assert action %d", ptr)
assert.Equalf(t, object, call.Object, "assert object %d", ptr)
r.Called[i].asserted = true
ptr++
}
}

assert.Equalf(t, len(did), ptr, "assert actor: didn't find all actions, %d missing actions", len(did)-ptr)
}

// recordAuthorize is the internal method that records the Authorize() call.
func (r *RecordingAuthorizer) recordAuthorize(subject rbac.Subject, action rbac.Action, object rbac.Object) {
r.Lock()
defer r.Unlock()
r.Called = append(r.Called, authCall{
Actor: subject,
Action: action,
Object: object,
})
}

func (r *RecordingAuthorizer) Authorize(ctx context.Context, subject rbac.Subject, action rbac.Action, object rbac.Object) error {
r.recordAuthorize(subject, action, object)
if r.Wrapped == nil {
panic("Developer error: RecordingAuthorizer.Wrapped is nil")
}
return r.Wrapped.Authorize(ctx, subject, action, object)
}

func (r *RecordingAuthorizer) Prepare(ctx context.Context, subject rbac.Subject, action rbac.Action, objectType string) (rbac.PreparedAuthorized, error) {
r.RLock()
defer r.RUnlock()
if r.Wrapped == nil {
panic("Developer error: RecordingAuthorizer.Wrapped is nil")
}

prep, err := r.Wrapped.Prepare(ctx, subject, action, objectType)
if err != nil {
return nil, err
}
return &PreparedRecorder{
rec: r,
prepped: prep,
subject: subject,
action: action,
}, nil
}

func (r *RecordingAuthorizer) reset() {
// Reset clears the recorded Authorize() calls.
func (r *RecordingAuthorizer) Reset() {
r.Lock()
defer r.Unlock()
r.Called = nil
}

// lastCall is implemented to support legacy tests.
// Deprecated
func (r *RecordingAuthorizer) lastCall() *authCall {
r.RLock()
defer r.RUnlock()
if len(r.Called) == 0 {
return nil
}
return &r.Called[len(r.Called)-1]
}

// PreparedRecorder is the prepared version of the RecordingAuthorizer.
// It records the Authorize() calls to the original recorder. If the caller
// uses CompileToSQL, all recording stops. This is to support parity between
// memory and SQL backed dbs.
type PreparedRecorder struct {
rec *RecordingAuthorizer
prepped rbac.PreparedAuthorized
subject rbac.Subject
action rbac.Action

rw sync.Mutex
usingSQL bool
}

func (s *PreparedRecorder) Authorize(ctx context.Context, object rbac.Object) error {
s.rw.Lock()
defer s.rw.Unlock()

if !s.usingSQL {
s.rec.recordAuthorize(s.subject, s.action, object)
}
return s.prepped.Authorize(ctx, object)
}
func (s *PreparedRecorder) CompileToSQL(ctx context.Context, cfg regosql.ConvertConfig) (string, error) {
s.rw.Lock()
defer s.rw.Unlock()

s.usingSQL = true
return s.prepped.CompileToSQL(ctx, cfg)
}

// FakeAuthorizer is an Authorizer that always returns the same error.
type FakeAuthorizer struct {
// AlwaysReturn is the error that will be returned by Authorize.
AlwaysReturn error
}

var _ rbac.Authorizer = (*FakeAuthorizer)(nil)

func (d *FakeAuthorizer) Authorize(_ context.Context, _ rbac.Subject, _ rbac.Action, _ rbac.Object) error {
return d.AlwaysReturn
}

func (d *FakeAuthorizer) Prepare(_ context.Context, subject rbac.Subject, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
return &fakePreparedAuthorizer{
Original: d,
Subject: subject,
Action: action,
}, nil
}

var _ rbac.PreparedAuthorized = (*fakePreparedAuthorizer)(nil)

// fakePreparedAuthorizer is the prepared version of a FakeAuthorizer. It will
// return the same error as the original FakeAuthorizer.
type fakePreparedAuthorizer struct {
Original *RecordingAuthorizer
Subject rbac.Subject
Action rbac.Action
HardCodedSQLString string
HardCodedRegoString string
sync.RWMutex
Original *FakeAuthorizer
Subject rbac.Subject
Action rbac.Action
}

func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error {
Expand All @@ -593,17 +739,6 @@ func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Obje

// CompileToSQL returns a compiled version of the authorizer that will work for
// in memory databases. This fake version will not work against a SQL database.
func (fakePreparedAuthorizer) CompileToSQL(_ context.Context, _ regosql.ConvertConfig) (string, error) {
return "", xerrors.New("not implemented")
}

func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool {
return f.Original.AuthorizeSQL(context.Background(), f.Subject, f.Action, object) == nil
}

func (f fakePreparedAuthorizer) RegoString() string {
if f.HardCodedRegoString != "" {
return f.HardCodedRegoString
}
panic("not implemented")
func (*fakePreparedAuthorizer) CompileToSQL(_ context.Context, _ regosql.ConvertConfig) (string, error) {
return "not a valid sql string", nil
}
Loading