Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 7918e65

Browse files
johnstcnkylecarbs
andauthored
feat(coderd): add dbcrypt package (#9522)
- Adds package enterprise/dbcrypt to implement database encryption/decryption - Adds table dbcrypt_keys and associated queries - Adds columns oauth_access_token_key_id and oauth_refresh_token_key_id to tables git_auth_links and user_links Co-authored-by: Kyle Carberry <[email protected]>
1 parent 3bd0fd3 commit 7918e65

22 files changed

+1996
-72
lines changed

coderd/database/dbauthz/dbauthz.go

+35
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,13 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI
838838
return q.db.GetAuthorizationUserRoles(ctx, userID)
839839
}
840840

841+
func (q *querier) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) {
842+
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
843+
return nil, err
844+
}
845+
return q.db.GetDBCryptKeys(ctx)
846+
}
847+
841848
func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) {
842849
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
843850
return "", err
@@ -914,6 +921,13 @@ func (q *querier) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLin
914921
return fetch(q.log, q.auth, q.db.GetGitAuthLink)(ctx, arg)
915922
}
916923

924+
func (q *querier) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) {
925+
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
926+
return nil, err
927+
}
928+
return q.db.GetGitAuthLinksByUserID(ctx, userID)
929+
}
930+
917931
func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
918932
return fetch(q.log, q.auth, q.db.GetGitSSHKey)(ctx, userID)
919933
}
@@ -1482,6 +1496,13 @@ func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database
14821496
return q.db.GetUserLinkByUserIDLoginType(ctx, arg)
14831497
}
14841498

1499+
func (q *querier) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) {
1500+
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
1501+
return nil, err
1502+
}
1503+
return q.db.GetUserLinksByUserID(ctx, userID)
1504+
}
1505+
14851506
func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) {
14861507
// This does the filtering in SQL.
14871508
prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type)
@@ -1845,6 +1866,13 @@ func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLo
18451866
return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg)
18461867
}
18471868

1869+
func (q *querier) InsertDBCryptKey(ctx context.Context, arg database.InsertDBCryptKeyParams) error {
1870+
if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil {
1871+
return err
1872+
}
1873+
return q.db.InsertDBCryptKey(ctx, arg)
1874+
}
1875+
18481876
func (q *querier) InsertDERPMeshKey(ctx context.Context, value string) error {
18491877
if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil {
18501878
return err
@@ -2144,6 +2172,13 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis
21442172
return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg)
21452173
}
21462174

2175+
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
2176+
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil {
2177+
return err
2178+
}
2179+
return q.db.RevokeDBCryptKey(ctx, activeKeyDigest)
2180+
}
2181+
21472182
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
21482183
return q.db.TryAcquireLock(ctx, id)
21492184
}

coderd/database/dbfake/dbfake.go

+133-13
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ import (
3131

3232
var validProxyByHostnameRegex = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)
3333

34+
var errForeignKeyConstraint = &pq.Error{
35+
Code: "23503",
36+
Message: "update or delete on table violates foreign key constraint",
37+
}
38+
3439
var errDuplicateKey = &pq.Error{
3540
Code: "23505",
3641
Message: "duplicate key value violates unique constraint",
@@ -45,6 +50,7 @@ func New() database.Store {
4550
organizationMembers: make([]database.OrganizationMember, 0),
4651
organizations: make([]database.Organization, 0),
4752
users: make([]database.User, 0),
53+
dbcryptKeys: make([]database.DBCryptKey, 0),
4854
gitAuthLinks: make([]database.GitAuthLink, 0),
4955
groups: make([]database.Group, 0),
5056
groupMembers: make([]database.GroupMember, 0),
@@ -117,6 +123,7 @@ type data struct {
117123
// New tables
118124
workspaceAgentStats []database.WorkspaceAgentStat
119125
auditLogs []database.AuditLog
126+
dbcryptKeys []database.DBCryptKey
120127
files []database.File
121128
gitAuthLinks []database.GitAuthLink
122129
gitSSHKey []database.GitSSHKey
@@ -665,6 +672,19 @@ func (q *FakeQuerier) isEveryoneGroup(id uuid.UUID) bool {
665672
return false
666673
}
667674

675+
func (q *FakeQuerier) GetActiveDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) {
676+
q.mutex.RLock()
677+
defer q.mutex.RUnlock()
678+
ks := make([]database.DBCryptKey, 0, len(q.dbcryptKeys))
679+
for _, k := range q.dbcryptKeys {
680+
if !k.ActiveKeyDigest.Valid {
681+
continue
682+
}
683+
ks = append([]database.DBCryptKey{}, k)
684+
}
685+
return ks, nil
686+
}
687+
668688
func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error {
669689
return xerrors.New("AcquireLock must only be called within a transaction")
670690
}
@@ -1151,6 +1171,14 @@ func (q *FakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
11511171
}, nil
11521172
}
11531173

1174+
func (q *FakeQuerier) GetDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) {
1175+
q.mutex.RLock()
1176+
defer q.mutex.RUnlock()
1177+
ks := make([]database.DBCryptKey, 0)
1178+
ks = append(ks, q.dbcryptKeys...)
1179+
return ks, nil
1180+
}
1181+
11541182
func (q *FakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) {
11551183
q.mutex.RLock()
11561184
defer q.mutex.RUnlock()
@@ -1393,6 +1421,18 @@ func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthL
13931421
return database.GitAuthLink{}, sql.ErrNoRows
13941422
}
13951423

1424+
func (q *FakeQuerier) GetGitAuthLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) {
1425+
q.mutex.RLock()
1426+
defer q.mutex.RUnlock()
1427+
gals := make([]database.GitAuthLink, 0)
1428+
for _, gal := range q.gitAuthLinks {
1429+
if gal.UserID == userID {
1430+
gals = append(gals, gal)
1431+
}
1432+
}
1433+
return gals, nil
1434+
}
1435+
13961436
func (q *FakeQuerier) GetGitSSHKey(_ context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
13971437
q.mutex.RLock()
13981438
defer q.mutex.RUnlock()
@@ -2833,6 +2873,18 @@ func (q *FakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params dat
28332873
return database.UserLink{}, sql.ErrNoRows
28342874
}
28352875

2876+
func (q *FakeQuerier) GetUserLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.UserLink, error) {
2877+
q.mutex.RLock()
2878+
defer q.mutex.RUnlock()
2879+
uls := make([]database.UserLink, 0)
2880+
for _, ul := range q.userLinks {
2881+
if ul.UserID == userID {
2882+
uls = append(uls, ul)
2883+
}
2884+
}
2885+
return uls, nil
2886+
}
2887+
28362888
func (q *FakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.GetUsersRow, error) {
28372889
if err := validateDatabaseType(params); err != nil {
28382890
return nil, err
@@ -3846,6 +3898,26 @@ func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAudit
38463898
return alog, nil
38473899
}
38483900

3901+
func (q *FakeQuerier) InsertDBCryptKey(_ context.Context, arg database.InsertDBCryptKeyParams) error {
3902+
err := validateDatabaseType(arg)
3903+
if err != nil {
3904+
return err
3905+
}
3906+
3907+
for _, key := range q.dbcryptKeys {
3908+
if key.Number == arg.Number {
3909+
return errDuplicateKey
3910+
}
3911+
}
3912+
3913+
q.dbcryptKeys = append(q.dbcryptKeys, database.DBCryptKey{
3914+
Number: arg.Number,
3915+
ActiveKeyDigest: sql.NullString{String: arg.ActiveKeyDigest, Valid: true},
3916+
Test: arg.Test,
3917+
})
3918+
return nil
3919+
}
3920+
38493921
func (q *FakeQuerier) InsertDERPMeshKey(_ context.Context, id string) error {
38503922
q.mutex.Lock()
38513923
defer q.mutex.Unlock()
@@ -3892,13 +3964,15 @@ func (q *FakeQuerier) InsertGitAuthLink(_ context.Context, arg database.InsertGi
38923964
defer q.mutex.Unlock()
38933965
// nolint:gosimple
38943966
gitAuthLink := database.GitAuthLink{
3895-
ProviderID: arg.ProviderID,
3896-
UserID: arg.UserID,
3897-
CreatedAt: arg.CreatedAt,
3898-
UpdatedAt: arg.UpdatedAt,
3899-
OAuthAccessToken: arg.OAuthAccessToken,
3900-
OAuthRefreshToken: arg.OAuthRefreshToken,
3901-
OAuthExpiry: arg.OAuthExpiry,
3967+
ProviderID: arg.ProviderID,
3968+
UserID: arg.UserID,
3969+
CreatedAt: arg.CreatedAt,
3970+
UpdatedAt: arg.UpdatedAt,
3971+
OAuthAccessToken: arg.OAuthAccessToken,
3972+
OAuthAccessTokenKeyID: arg.OAuthAccessTokenKeyID,
3973+
OAuthRefreshToken: arg.OAuthRefreshToken,
3974+
OAuthRefreshTokenKeyID: arg.OAuthRefreshTokenKeyID,
3975+
OAuthExpiry: arg.OAuthExpiry,
39023976
}
39033977
q.gitAuthLinks = append(q.gitAuthLinks, gitAuthLink)
39043978
return gitAuthLink, nil
@@ -4362,12 +4436,14 @@ func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUser
43624436

43634437
//nolint:gosimple
43644438
link := database.UserLink{
4365-
UserID: args.UserID,
4366-
LoginType: args.LoginType,
4367-
LinkedID: args.LinkedID,
4368-
OAuthAccessToken: args.OAuthAccessToken,
4369-
OAuthRefreshToken: args.OAuthRefreshToken,
4370-
OAuthExpiry: args.OAuthExpiry,
4439+
UserID: args.UserID,
4440+
LoginType: args.LoginType,
4441+
LinkedID: args.LinkedID,
4442+
OAuthAccessToken: args.OAuthAccessToken,
4443+
OAuthAccessTokenKeyID: args.OAuthAccessTokenKeyID,
4444+
OAuthRefreshToken: args.OAuthRefreshToken,
4445+
OAuthRefreshTokenKeyID: args.OAuthRefreshTokenKeyID,
4446+
OAuthExpiry: args.OAuthExpiry,
43714447
}
43724448

43734449
q.userLinks = append(q.userLinks, link)
@@ -4793,6 +4869,46 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg
47934869
return database.WorkspaceProxy{}, sql.ErrNoRows
47944870
}
47954871

4872+
func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error {
4873+
q.mutex.Lock()
4874+
defer q.mutex.Unlock()
4875+
4876+
for i := range q.dbcryptKeys {
4877+
key := q.dbcryptKeys[i]
4878+
4879+
// Is the key already revoked?
4880+
if !key.ActiveKeyDigest.Valid {
4881+
continue
4882+
}
4883+
4884+
if key.ActiveKeyDigest.String != activeKeyDigest {
4885+
continue
4886+
}
4887+
4888+
// Check for foreign key constraints.
4889+
for _, ul := range q.userLinks {
4890+
if (ul.OAuthAccessTokenKeyID.Valid && ul.OAuthAccessTokenKeyID.String == activeKeyDigest) ||
4891+
(ul.OAuthRefreshTokenKeyID.Valid && ul.OAuthRefreshTokenKeyID.String == activeKeyDigest) {
4892+
return errForeignKeyConstraint
4893+
}
4894+
}
4895+
for _, gal := range q.gitAuthLinks {
4896+
if (gal.OAuthAccessTokenKeyID.Valid && gal.OAuthAccessTokenKeyID.String == activeKeyDigest) ||
4897+
(gal.OAuthRefreshTokenKeyID.Valid && gal.OAuthRefreshTokenKeyID.String == activeKeyDigest) {
4898+
return errForeignKeyConstraint
4899+
}
4900+
}
4901+
4902+
// Revoke the key.
4903+
q.dbcryptKeys[i].RevokedAt = sql.NullTime{Time: dbtime.Now(), Valid: true}
4904+
q.dbcryptKeys[i].RevokedKeyDigest = sql.NullString{String: key.ActiveKeyDigest.String, Valid: true}
4905+
q.dbcryptKeys[i].ActiveKeyDigest = sql.NullString{}
4906+
return nil
4907+
}
4908+
4909+
return sql.ErrNoRows
4910+
}
4911+
47964912
func (*FakeQuerier) TryAcquireLock(_ context.Context, _ int64) (bool, error) {
47974913
return false, xerrors.New("TryAcquireLock must only be called within a transaction")
47984914
}
@@ -4834,7 +4950,9 @@ func (q *FakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGi
48344950
}
48354951
gitAuthLink.UpdatedAt = arg.UpdatedAt
48364952
gitAuthLink.OAuthAccessToken = arg.OAuthAccessToken
4953+
gitAuthLink.OAuthAccessTokenKeyID = arg.OAuthAccessTokenKeyID
48374954
gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken
4955+
gitAuthLink.OAuthRefreshTokenKeyID = arg.OAuthRefreshTokenKeyID
48384956
gitAuthLink.OAuthExpiry = arg.OAuthExpiry
48394957
q.gitAuthLinks[index] = gitAuthLink
48404958

@@ -5306,7 +5424,9 @@ func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
53065424
for i, link := range q.userLinks {
53075425
if link.UserID == params.UserID && link.LoginType == params.LoginType {
53085426
link.OAuthAccessToken = params.OAuthAccessToken
5427+
link.OAuthAccessTokenKeyID = params.OAuthAccessTokenKeyID
53095428
link.OAuthRefreshToken = params.OAuthRefreshToken
5429+
link.OAuthRefreshTokenKeyID = params.OAuthRefreshTokenKeyID
53105430
link.OAuthExpiry = params.OAuthExpiry
53115431

53125432
q.userLinks[i] = link

coderd/database/dbgen/dbgen.go

+17-13
Original file line numberDiff line numberDiff line change
@@ -470,12 +470,14 @@ func File(t testing.TB, db database.Store, orig database.File) database.File {
470470

471471
func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.UserLink {
472472
link, err := db.InsertUserLink(genCtx, database.InsertUserLinkParams{
473-
UserID: takeFirst(orig.UserID, uuid.New()),
474-
LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub),
475-
LinkedID: takeFirst(orig.LinkedID),
476-
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
477-
OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
478-
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
473+
UserID: takeFirst(orig.UserID, uuid.New()),
474+
LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub),
475+
LinkedID: takeFirst(orig.LinkedID),
476+
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
477+
OAuthAccessTokenKeyID: takeFirst(orig.OAuthAccessTokenKeyID, sql.NullString{}),
478+
OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()),
479+
OAuthRefreshTokenKeyID: takeFirst(orig.OAuthRefreshTokenKeyID, sql.NullString{}),
480+
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
479481
})
480482

481483
require.NoError(t, err, "insert link")
@@ -484,13 +486,15 @@ func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.
484486

485487
func GitAuthLink(t testing.TB, db database.Store, orig database.GitAuthLink) database.GitAuthLink {
486488
link, err := db.InsertGitAuthLink(genCtx, database.InsertGitAuthLinkParams{
487-
ProviderID: takeFirst(orig.ProviderID, uuid.New().String()),
488-
UserID: takeFirst(orig.UserID, uuid.New()),
489-
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
490-
OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
491-
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
492-
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
493-
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
489+
ProviderID: takeFirst(orig.ProviderID, uuid.New().String()),
490+
UserID: takeFirst(orig.UserID, uuid.New()),
491+
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
492+
OAuthAccessTokenKeyID: takeFirst(orig.OAuthAccessTokenKeyID, sql.NullString{}),
493+
OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()),
494+
OAuthRefreshTokenKeyID: takeFirst(orig.OAuthRefreshTokenKeyID, sql.NullString{}),
495+
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
496+
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
497+
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
494498
})
495499

496500
require.NoError(t, err, "insert git auth link")

0 commit comments

Comments
 (0)