@@ -31,6 +31,11 @@ import (
31
31
32
32
var validProxyByHostnameRegex = regexp .MustCompile (`^[a-zA-Z0-9._-]+$` )
33
33
34
+ var errForeignKeyConstraint = & pq.Error {
35
+ Code : "23503" ,
36
+ Message : "update or delete on table violates foreign key constraint" ,
37
+ }
38
+
34
39
var errDuplicateKey = & pq.Error {
35
40
Code : "23505" ,
36
41
Message : "duplicate key value violates unique constraint" ,
@@ -45,6 +50,7 @@ func New() database.Store {
45
50
organizationMembers : make ([]database.OrganizationMember , 0 ),
46
51
organizations : make ([]database.Organization , 0 ),
47
52
users : make ([]database.User , 0 ),
53
+ dbcryptKeys : make ([]database.DBCryptKey , 0 ),
48
54
gitAuthLinks : make ([]database.GitAuthLink , 0 ),
49
55
groups : make ([]database.Group , 0 ),
50
56
groupMembers : make ([]database.GroupMember , 0 ),
@@ -117,6 +123,7 @@ type data struct {
117
123
// New tables
118
124
workspaceAgentStats []database.WorkspaceAgentStat
119
125
auditLogs []database.AuditLog
126
+ dbcryptKeys []database.DBCryptKey
120
127
files []database.File
121
128
gitAuthLinks []database.GitAuthLink
122
129
gitSSHKey []database.GitSSHKey
@@ -665,6 +672,19 @@ func (q *FakeQuerier) isEveryoneGroup(id uuid.UUID) bool {
665
672
return false
666
673
}
667
674
675
+ func (q * FakeQuerier ) GetActiveDBCryptKeys (_ context.Context ) ([]database.DBCryptKey , error ) {
676
+ q .mutex .RLock ()
677
+ defer q .mutex .RUnlock ()
678
+ ks := make ([]database.DBCryptKey , 0 , len (q .dbcryptKeys ))
679
+ for _ , k := range q .dbcryptKeys {
680
+ if ! k .ActiveKeyDigest .Valid {
681
+ continue
682
+ }
683
+ ks = append ([]database.DBCryptKey {}, k )
684
+ }
685
+ return ks , nil
686
+ }
687
+
668
688
func (* FakeQuerier ) AcquireLock (_ context.Context , _ int64 ) error {
669
689
return xerrors .New ("AcquireLock must only be called within a transaction" )
670
690
}
@@ -1151,6 +1171,14 @@ func (q *FakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
1151
1171
}, nil
1152
1172
}
1153
1173
1174
+ func (q * FakeQuerier ) GetDBCryptKeys (_ context.Context ) ([]database.DBCryptKey , error ) {
1175
+ q .mutex .RLock ()
1176
+ defer q .mutex .RUnlock ()
1177
+ ks := make ([]database.DBCryptKey , 0 )
1178
+ ks = append (ks , q .dbcryptKeys ... )
1179
+ return ks , nil
1180
+ }
1181
+
1154
1182
func (q * FakeQuerier ) GetDERPMeshKey (_ context.Context ) (string , error ) {
1155
1183
q .mutex .RLock ()
1156
1184
defer q .mutex .RUnlock ()
@@ -1393,6 +1421,18 @@ func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthL
1393
1421
return database.GitAuthLink {}, sql .ErrNoRows
1394
1422
}
1395
1423
1424
+ func (q * FakeQuerier ) GetGitAuthLinksByUserID (_ context.Context , userID uuid.UUID ) ([]database.GitAuthLink , error ) {
1425
+ q .mutex .RLock ()
1426
+ defer q .mutex .RUnlock ()
1427
+ gals := make ([]database.GitAuthLink , 0 )
1428
+ for _ , gal := range q .gitAuthLinks {
1429
+ if gal .UserID == userID {
1430
+ gals = append (gals , gal )
1431
+ }
1432
+ }
1433
+ return gals , nil
1434
+ }
1435
+
1396
1436
func (q * FakeQuerier ) GetGitSSHKey (_ context.Context , userID uuid.UUID ) (database.GitSSHKey , error ) {
1397
1437
q .mutex .RLock ()
1398
1438
defer q .mutex .RUnlock ()
@@ -2833,6 +2873,18 @@ func (q *FakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params dat
2833
2873
return database.UserLink {}, sql .ErrNoRows
2834
2874
}
2835
2875
2876
+ func (q * FakeQuerier ) GetUserLinksByUserID (_ context.Context , userID uuid.UUID ) ([]database.UserLink , error ) {
2877
+ q .mutex .RLock ()
2878
+ defer q .mutex .RUnlock ()
2879
+ uls := make ([]database.UserLink , 0 )
2880
+ for _ , ul := range q .userLinks {
2881
+ if ul .UserID == userID {
2882
+ uls = append (uls , ul )
2883
+ }
2884
+ }
2885
+ return uls , nil
2886
+ }
2887
+
2836
2888
func (q * FakeQuerier ) GetUsers (_ context.Context , params database.GetUsersParams ) ([]database.GetUsersRow , error ) {
2837
2889
if err := validateDatabaseType (params ); err != nil {
2838
2890
return nil , err
@@ -3846,6 +3898,26 @@ func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAudit
3846
3898
return alog , nil
3847
3899
}
3848
3900
3901
+ func (q * FakeQuerier ) InsertDBCryptKey (_ context.Context , arg database.InsertDBCryptKeyParams ) error {
3902
+ err := validateDatabaseType (arg )
3903
+ if err != nil {
3904
+ return err
3905
+ }
3906
+
3907
+ for _ , key := range q .dbcryptKeys {
3908
+ if key .Number == arg .Number {
3909
+ return errDuplicateKey
3910
+ }
3911
+ }
3912
+
3913
+ q .dbcryptKeys = append (q .dbcryptKeys , database.DBCryptKey {
3914
+ Number : arg .Number ,
3915
+ ActiveKeyDigest : sql.NullString {String : arg .ActiveKeyDigest , Valid : true },
3916
+ Test : arg .Test ,
3917
+ })
3918
+ return nil
3919
+ }
3920
+
3849
3921
func (q * FakeQuerier ) InsertDERPMeshKey (_ context.Context , id string ) error {
3850
3922
q .mutex .Lock ()
3851
3923
defer q .mutex .Unlock ()
@@ -3892,13 +3964,15 @@ func (q *FakeQuerier) InsertGitAuthLink(_ context.Context, arg database.InsertGi
3892
3964
defer q .mutex .Unlock ()
3893
3965
// nolint:gosimple
3894
3966
gitAuthLink := database.GitAuthLink {
3895
- ProviderID : arg .ProviderID ,
3896
- UserID : arg .UserID ,
3897
- CreatedAt : arg .CreatedAt ,
3898
- UpdatedAt : arg .UpdatedAt ,
3899
- OAuthAccessToken : arg .OAuthAccessToken ,
3900
- OAuthRefreshToken : arg .OAuthRefreshToken ,
3901
- OAuthExpiry : arg .OAuthExpiry ,
3967
+ ProviderID : arg .ProviderID ,
3968
+ UserID : arg .UserID ,
3969
+ CreatedAt : arg .CreatedAt ,
3970
+ UpdatedAt : arg .UpdatedAt ,
3971
+ OAuthAccessToken : arg .OAuthAccessToken ,
3972
+ OAuthAccessTokenKeyID : arg .OAuthAccessTokenKeyID ,
3973
+ OAuthRefreshToken : arg .OAuthRefreshToken ,
3974
+ OAuthRefreshTokenKeyID : arg .OAuthRefreshTokenKeyID ,
3975
+ OAuthExpiry : arg .OAuthExpiry ,
3902
3976
}
3903
3977
q .gitAuthLinks = append (q .gitAuthLinks , gitAuthLink )
3904
3978
return gitAuthLink , nil
@@ -4362,12 +4436,14 @@ func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUser
4362
4436
4363
4437
//nolint:gosimple
4364
4438
link := database.UserLink {
4365
- UserID : args .UserID ,
4366
- LoginType : args .LoginType ,
4367
- LinkedID : args .LinkedID ,
4368
- OAuthAccessToken : args .OAuthAccessToken ,
4369
- OAuthRefreshToken : args .OAuthRefreshToken ,
4370
- OAuthExpiry : args .OAuthExpiry ,
4439
+ UserID : args .UserID ,
4440
+ LoginType : args .LoginType ,
4441
+ LinkedID : args .LinkedID ,
4442
+ OAuthAccessToken : args .OAuthAccessToken ,
4443
+ OAuthAccessTokenKeyID : args .OAuthAccessTokenKeyID ,
4444
+ OAuthRefreshToken : args .OAuthRefreshToken ,
4445
+ OAuthRefreshTokenKeyID : args .OAuthRefreshTokenKeyID ,
4446
+ OAuthExpiry : args .OAuthExpiry ,
4371
4447
}
4372
4448
4373
4449
q .userLinks = append (q .userLinks , link )
@@ -4793,6 +4869,46 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg
4793
4869
return database.WorkspaceProxy {}, sql .ErrNoRows
4794
4870
}
4795
4871
4872
+ func (q * FakeQuerier ) RevokeDBCryptKey (_ context.Context , activeKeyDigest string ) error {
4873
+ q .mutex .Lock ()
4874
+ defer q .mutex .Unlock ()
4875
+
4876
+ for i := range q .dbcryptKeys {
4877
+ key := q .dbcryptKeys [i ]
4878
+
4879
+ // Is the key already revoked?
4880
+ if ! key .ActiveKeyDigest .Valid {
4881
+ continue
4882
+ }
4883
+
4884
+ if key .ActiveKeyDigest .String != activeKeyDigest {
4885
+ continue
4886
+ }
4887
+
4888
+ // Check for foreign key constraints.
4889
+ for _ , ul := range q .userLinks {
4890
+ if (ul .OAuthAccessTokenKeyID .Valid && ul .OAuthAccessTokenKeyID .String == activeKeyDigest ) ||
4891
+ (ul .OAuthRefreshTokenKeyID .Valid && ul .OAuthRefreshTokenKeyID .String == activeKeyDigest ) {
4892
+ return errForeignKeyConstraint
4893
+ }
4894
+ }
4895
+ for _ , gal := range q .gitAuthLinks {
4896
+ if (gal .OAuthAccessTokenKeyID .Valid && gal .OAuthAccessTokenKeyID .String == activeKeyDigest ) ||
4897
+ (gal .OAuthRefreshTokenKeyID .Valid && gal .OAuthRefreshTokenKeyID .String == activeKeyDigest ) {
4898
+ return errForeignKeyConstraint
4899
+ }
4900
+ }
4901
+
4902
+ // Revoke the key.
4903
+ q .dbcryptKeys [i ].RevokedAt = sql.NullTime {Time : dbtime .Now (), Valid : true }
4904
+ q .dbcryptKeys [i ].RevokedKeyDigest = sql.NullString {String : key .ActiveKeyDigest .String , Valid : true }
4905
+ q .dbcryptKeys [i ].ActiveKeyDigest = sql.NullString {}
4906
+ return nil
4907
+ }
4908
+
4909
+ return sql .ErrNoRows
4910
+ }
4911
+
4796
4912
func (* FakeQuerier ) TryAcquireLock (_ context.Context , _ int64 ) (bool , error ) {
4797
4913
return false , xerrors .New ("TryAcquireLock must only be called within a transaction" )
4798
4914
}
@@ -4834,7 +4950,9 @@ func (q *FakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGi
4834
4950
}
4835
4951
gitAuthLink .UpdatedAt = arg .UpdatedAt
4836
4952
gitAuthLink .OAuthAccessToken = arg .OAuthAccessToken
4953
+ gitAuthLink .OAuthAccessTokenKeyID = arg .OAuthAccessTokenKeyID
4837
4954
gitAuthLink .OAuthRefreshToken = arg .OAuthRefreshToken
4955
+ gitAuthLink .OAuthRefreshTokenKeyID = arg .OAuthRefreshTokenKeyID
4838
4956
gitAuthLink .OAuthExpiry = arg .OAuthExpiry
4839
4957
q .gitAuthLinks [index ] = gitAuthLink
4840
4958
@@ -5306,7 +5424,9 @@ func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
5306
5424
for i , link := range q .userLinks {
5307
5425
if link .UserID == params .UserID && link .LoginType == params .LoginType {
5308
5426
link .OAuthAccessToken = params .OAuthAccessToken
5427
+ link .OAuthAccessTokenKeyID = params .OAuthAccessTokenKeyID
5309
5428
link .OAuthRefreshToken = params .OAuthRefreshToken
5429
+ link .OAuthRefreshTokenKeyID = params .OAuthRefreshTokenKeyID
5310
5430
link .OAuthExpiry = params .OAuthExpiry
5311
5431
5312
5432
q .userLinks [i ] = link
0 commit comments