Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 339f1de

Browse files
committed
Extract count from audit logs query to a separate one and optimize audit logs query with conditional joins
1 parent 688d2ee commit 339f1de

File tree

12 files changed

+699
-238
lines changed

12 files changed

+699
-238
lines changed

coderd/audit.go

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,21 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
6464
filter.Username = ""
6565
}
6666

67-
dblogs, err := api.Database.GetAuditLogsOffset(ctx, filter)
67+
// Use the same filters to count the number of audit logs
68+
count, err := api.Database.CountAuditLogs(ctx, database.CountAuditLogsParams{
69+
ResourceType: filter.ResourceType,
70+
ResourceID: filter.ResourceID,
71+
OrganizationID: filter.OrganizationID,
72+
ResourceTarget: filter.ResourceTarget,
73+
Action: filter.Action,
74+
UserID: filter.UserID,
75+
Username: filter.Username,
76+
Email: filter.Email,
77+
DateFrom: filter.DateFrom,
78+
DateTo: filter.DateTo,
79+
BuildReason: filter.BuildReason,
80+
RequestID: filter.RequestID,
81+
})
6882
if dbauthz.IsNotAuthorizedError(err) {
6983
httpapi.Forbidden(rw)
7084
return
@@ -73,19 +87,28 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
7387
httpapi.InternalServerError(rw, err)
7488
return
7589
}
76-
// GetAuditLogsOffset does not return ErrNoRows because it uses a window function to get the count.
77-
// So we need to check if the dblogs is empty and return an empty array if so.
78-
if len(dblogs) == 0 {
90+
// If count is 0, then we don't need to query audit logs
91+
if count == 0 {
7992
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
8093
AuditLogs: []codersdk.AuditLog{},
8194
Count: 0,
8295
})
8396
return
8497
}
8598

99+
dblogs, err := api.Database.GetAuditLogsOffset(ctx, filter)
100+
if dbauthz.IsNotAuthorizedError(err) {
101+
httpapi.Forbidden(rw)
102+
return
103+
}
104+
if err != nil {
105+
httpapi.InternalServerError(rw, err)
106+
return
107+
}
108+
86109
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
87110
AuditLogs: api.convertAuditLogs(ctx, dblogs),
88-
Count: dblogs[0].Count,
111+
Count: count,
89112
})
90113
}
91114

coderd/database/dbauthz/dbauthz.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,6 +1301,22 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error {
13011301
return q.db.CleanTailnetTunnels(ctx)
13021302
}
13031303

1304+
func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
1305+
// Shortcut if the user is an owner. The SQL filter is noticeable,
1306+
// and this is an easy win for owners. Which is the common case.
1307+
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAuditLog)
1308+
if err == nil {
1309+
return q.db.CountAuditLogs(ctx, arg)
1310+
}
1311+
1312+
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAuditLog.Type)
1313+
if err != nil {
1314+
return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
1315+
}
1316+
1317+
return q.db.CountAuthorizedAuditLogs(ctx, arg, prep)
1318+
}
1319+
13041320
func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
13051321
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
13061322
return nil, err
@@ -5256,3 +5272,7 @@ func (q *querier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersP
52565272
func (q *querier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg database.GetAuditLogsOffsetParams, _ rbac.PreparedAuthorized) ([]database.GetAuditLogsOffsetRow, error) {
52575273
return q.GetAuditLogsOffset(ctx, arg)
52585274
}
5275+
5276+
func (q *querier) CountAuthorizedAuditLogs(ctx context.Context, arg database.CountAuditLogsParams, _ rbac.PreparedAuthorized) (int64, error) {
5277+
return q.CountAuditLogs(ctx, arg)
5278+
}

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,16 @@ func (s *MethodTestSuite) TestAuditLogs() {
327327
LimitOpt: 10,
328328
}, emptyPreparedAuthorized{}).Asserts(rbac.ResourceAuditLog, policy.ActionRead)
329329
}))
330+
s.Run("CountAuditLogs", s.Subtest(func(db database.Store, check *expects) {
331+
_ = dbgen.AuditLog(s.T(), db, database.AuditLog{})
332+
_ = dbgen.AuditLog(s.T(), db, database.AuditLog{})
333+
check.Args(database.CountAuditLogsParams{}).Asserts(rbac.ResourceAuditLog, policy.ActionRead).WithNotAuthorized("nil")
334+
}))
335+
s.Run("CountAuthorizedAuditLogs", s.Subtest(func(db database.Store, check *expects) {
336+
_ = dbgen.AuditLog(s.T(), db, database.AuditLog{})
337+
_ = dbgen.AuditLog(s.T(), db, database.AuditLog{})
338+
check.Args(database.CountAuditLogsParams{}, emptyPreparedAuthorized{}).Asserts(rbac.ResourceAuditLog, policy.ActionRead)
339+
}))
330340
}
331341

332342
func (s *MethodTestSuite) TestFile() {

coderd/database/dbauthz/setup_test.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
271271

272272
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
273273
// any case where the error is nil and the response is an empty slice.
274-
if err != nil || !hasEmptySliceResponse(resp) {
274+
if err != nil || !hasEmptyResponse(resp) {
275275
// Expect the default error
276276
if testCase.notAuthorizedExpect == "" {
277277
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
@@ -297,7 +297,7 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
297297

298298
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
299299
// any case where the error is nil and the response is an empty slice.
300-
if err != nil || !hasEmptySliceResponse(resp) {
300+
if err != nil || !hasEmptyResponse(resp) {
301301
if testCase.cancelledCtxExpect == "" {
302302
s.Errorf(err, "method should an error with cancellation")
303303
s.ErrorIsf(err, context.Canceled, "error should match context.Canceled")
@@ -308,13 +308,20 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
308308
})
309309
}
310310

311-
func hasEmptySliceResponse(values []reflect.Value) bool {
311+
func hasEmptyResponse(values []reflect.Value) bool {
312312
for _, r := range values {
313313
if r.Kind() == reflect.Slice || r.Kind() == reflect.Array {
314314
if r.Len() == 0 {
315315
return true
316316
}
317317
}
318+
319+
// Special case for int64, as it's the return type for count query.
320+
if r.Kind() == reflect.Int64 {
321+
if r.Int() == 0 {
322+
return true
323+
}
324+
}
318325
}
319326
return false
320327
}

coderd/database/dbmem/dbmem.go

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,6 +1779,10 @@ func (*FakeQuerier) CleanTailnetTunnels(context.Context) error {
17791779
return ErrUnimplemented
17801780
}
17811781

1782+
func (q *FakeQuerier) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
1783+
return q.CountAuthorizedAuditLogs(ctx, arg, nil)
1784+
}
1785+
17821786
func (q *FakeQuerier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
17831787
return nil, ErrUnimplemented
17841788
}
@@ -13930,18 +13934,89 @@ func (q *FakeQuerier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg data
1393013934
UserQuietHoursSchedule: sql.NullString{String: user.QuietHoursSchedule, Valid: userValid},
1393113935
UserStatus: database.NullUserStatus{UserStatus: user.Status, Valid: userValid},
1393213936
UserRoles: user.RBACRoles,
13933-
Count: 0,
1393413937
})
1393513938

1393613939
if len(logs) >= int(arg.LimitOpt) {
1393713940
break
1393813941
}
1393913942
}
1394013943

13941-
count := int64(len(logs))
13942-
for i := range logs {
13943-
logs[i].Count = count
13944+
return logs, nil
13945+
}
13946+
13947+
func (q *FakeQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg database.CountAuditLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
13948+
if err := validateDatabaseType(arg); err != nil {
13949+
return 0, err
1394413950
}
1394513951

13946-
return logs, nil
13952+
// Call this to match the same function calls as the SQL implementation.
13953+
// It functionally does nothing for filtering.
13954+
if prepared != nil {
13955+
_, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
13956+
VariableConverter: regosql.AuditLogConverter(),
13957+
})
13958+
if err != nil {
13959+
return 0, err
13960+
}
13961+
}
13962+
13963+
q.mutex.RLock()
13964+
defer q.mutex.RUnlock()
13965+
13966+
var count int64
13967+
13968+
// q.auditLogs are already sorted by time DESC, so no need to sort after the fact.
13969+
for _, alog := range q.auditLogs {
13970+
if arg.RequestID != uuid.Nil && arg.RequestID != alog.RequestID {
13971+
continue
13972+
}
13973+
if arg.OrganizationID != uuid.Nil && arg.OrganizationID != alog.OrganizationID {
13974+
continue
13975+
}
13976+
if arg.Action != "" && string(alog.Action) != arg.Action {
13977+
continue
13978+
}
13979+
if arg.ResourceType != "" && !strings.Contains(string(alog.ResourceType), arg.ResourceType) {
13980+
continue
13981+
}
13982+
if arg.ResourceID != uuid.Nil && alog.ResourceID != arg.ResourceID {
13983+
continue
13984+
}
13985+
if arg.Username != "" {
13986+
user, err := q.getUserByIDNoLock(alog.UserID)
13987+
if err == nil && !strings.EqualFold(arg.Username, user.Username) {
13988+
continue
13989+
}
13990+
}
13991+
if arg.Email != "" {
13992+
user, err := q.getUserByIDNoLock(alog.UserID)
13993+
if err == nil && !strings.EqualFold(arg.Email, user.Email) {
13994+
continue
13995+
}
13996+
}
13997+
if !arg.DateFrom.IsZero() {
13998+
if alog.Time.Before(arg.DateFrom) {
13999+
continue
14000+
}
14001+
}
14002+
if !arg.DateTo.IsZero() {
14003+
if alog.Time.After(arg.DateTo) {
14004+
continue
14005+
}
14006+
}
14007+
if arg.BuildReason != "" {
14008+
workspaceBuild, err := q.getWorkspaceBuildByIDNoLock(context.Background(), alog.ResourceID)
14009+
if err == nil && !strings.EqualFold(arg.BuildReason, string(workspaceBuild.Reason)) {
14010+
continue
14011+
}
14012+
}
14013+
// If the filter exists, ensure the object is authorized.
14014+
if prepared != nil && prepared.Authorize(ctx, alog.RBACObject()) != nil {
14015+
continue
14016+
}
14017+
14018+
count++
14019+
}
14020+
14021+
return count, nil
1394714022
}

coderd/database/dbmetrics/querymetrics.go

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 30 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/modelqueries.go

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,7 @@ func (q *sqlQuerier) GetAuthorizedUsers(ctx context.Context, arg GetUsersParams,
478478

479479
type auditLogQuerier interface {
480480
GetAuthorizedAuditLogsOffset(ctx context.Context, arg GetAuditLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetAuditLogsOffsetRow, error)
481+
CountAuthorizedAuditLogs(ctx context.Context, arg CountAuditLogsParams, prepared rbac.PreparedAuthorized) (int64, error)
481482
}
482483

483484
func (q *sqlQuerier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg GetAuditLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetAuditLogsOffsetRow, error) {
@@ -548,7 +549,6 @@ func (q *sqlQuerier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg GetAu
548549
&i.OrganizationName,
549550
&i.OrganizationDisplayName,
550551
&i.OrganizationIcon,
551-
&i.Count,
552552
); err != nil {
553553
return nil, err
554554
}
@@ -563,6 +563,52 @@ func (q *sqlQuerier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg GetAu
563563
return items, nil
564564
}
565565

566+
func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAuditLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
567+
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
568+
VariableConverter: regosql.AuditLogConverter(),
569+
})
570+
if err != nil {
571+
return 0, xerrors.Errorf("compile authorized filter: %w", err)
572+
}
573+
574+
filtered, err := insertAuthorizedFilter(countAuditLogs, fmt.Sprintf(" AND %s", authorizedFilter))
575+
if err != nil {
576+
return 0, xerrors.Errorf("insert authorized filter: %w", err)
577+
}
578+
579+
query := fmt.Sprintf("-- name: CountAuthorizedAuditLogs :one\n%s", filtered)
580+
581+
rows, err := q.db.QueryContext(ctx, query,
582+
arg.ResourceType,
583+
arg.ResourceID,
584+
arg.OrganizationID,
585+
arg.ResourceTarget,
586+
arg.Action,
587+
arg.UserID,
588+
arg.Username,
589+
arg.Email,
590+
arg.DateFrom,
591+
arg.DateTo,
592+
arg.BuildReason,
593+
arg.RequestID,
594+
)
595+
if err != nil {
596+
return 0, err
597+
}
598+
defer rows.Close()
599+
var count int64
600+
for rows.Next() {
601+
count++
602+
}
603+
if err := rows.Close(); err != nil {
604+
return 0, err
605+
}
606+
if err := rows.Err(); err != nil {
607+
return 0, err
608+
}
609+
return count, nil
610+
}
611+
566612
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
567613
if !strings.Contains(query, authorizedQueryPlaceholder) {
568614
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")

coderd/database/querier.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)