Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ea659b0

Browse files
committed
Refactor DBKeyCache with robust invalidation timing
Introduce a timer-based cache invalidation system in DBCache to enhance reliability. The new implementation minimizes cache invalidation race conditions during key fetching, ensuring consistent cache state management. Additions include a 'Close' method for releasing resources, such as timers, and test improvements for timer behavior validation.
1 parent 5803092 commit ea659b0

File tree

5 files changed

+188
-110
lines changed

5 files changed

+188
-110
lines changed

coderd/cryptokeys/dbkeycache.go

Lines changed: 46 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ import (
1414
"github.com/coder/quartz"
1515
)
1616

17+
// never represents the maximum value for a time.Duration.
18+
const never = 1<<63 - 1
19+
1720
// DBCache implements Keycache for callers with access to the database.
1821
type DBCache struct {
1922
db database.Store
@@ -25,7 +28,9 @@ type DBCache struct {
2528
keysMu sync.RWMutex
2629
keys map[int32]database.CryptoKey
2730
latestKey database.CryptoKey
28-
fetched chan struct{}
31+
timer *quartz.Timer
32+
// invalidateAt is the time at which the keys cache should be invalidated.
33+
invalidateAt time.Time
2934
}
3035

3136
type DBCacheOption func(*DBCache)
@@ -36,23 +41,22 @@ func WithDBCacheClock(clock quartz.Clock) DBCacheOption {
3641
}
3742
}
3843

39-
// NewDBCache creates a new DBCache. It starts a background
40-
// process that periodically refreshes the cache. The context should
41-
// be canceled to stop the background process.
42-
func NewDBCache(ctx context.Context, logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*DBCache)) *DBCache {
44+
// NewDBCache creates a new DBCache. Close should be called to
45+
// release resources associated with its internal timer.
46+
func NewDBCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*DBCache)) *DBCache {
4347
d := &DBCache{
4448
db: db,
4549
feature: feature,
4650
clock: quartz.NewReal(),
4751
logger: logger,
48-
fetched: make(chan struct{}),
4952
}
5053

5154
for _, opt := range opts {
5255
opt(d)
5356
}
5457

55-
go d.clear(ctx)
58+
d.timer = d.clock.AfterFunc(never, d.clear)
59+
5660
return d
5761
}
5862

@@ -76,11 +80,10 @@ func (d *DBCache) Verifying(ctx context.Context, sequence int32) (codersdk.Crypt
7680
return checkKey(key, now)
7781
}
7882

79-
cache, latest, err := d.fetch(ctx)
83+
err := d.fetch(ctx)
8084
if err != nil {
8185
return codersdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err)
8286
}
83-
d.keys, d.latestKey = cache, latest
8487

8588
key, ok = d.keys[sequence]
8689
if !ok {
@@ -110,47 +113,42 @@ func (d *DBCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) {
110113
}
111114

112115
// Refetch all keys for this feature so we can find the latest valid key.
113-
cache, latest, err := d.fetch(ctx)
116+
err := d.fetch(ctx)
114117
if err != nil {
115118
return codersdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err)
116119
}
117-
d.keys, d.latestKey = cache, latest
118120

119121
return db2sdk.CryptoKey(d.latestKey), nil
120122
}
121123

122-
func (d *DBCache) clear(ctx context.Context) {
123-
for {
124-
fired := make(chan struct{})
125-
timer := d.clock.AfterFunc(time.Minute*10, func() {
126-
defer close(fired)
127-
128-
// There's a small window where the timer fires as we're fetching
129-
// keys that could result in us immediately invalidating the cache that we just populated.
130-
d.keysMu.Lock()
131-
defer d.keysMu.Unlock()
132-
d.keys = nil
133-
d.latestKey = database.CryptoKey{}
134-
})
135-
136-
select {
137-
case <-ctx.Done():
138-
return
139-
case <-d.fetched:
140-
timer.Stop()
141-
case <-fired:
142-
}
143-
}
124+
// clear invalidates the cache. This forces the subsequent call to fetch fresh keys.
125+
func (d *DBCache) clear() {
126+
now := d.clock.Now("DBCache", "clear")
127+
d.keysMu.Lock()
128+
defer d.keysMu.Unlock()
129+
// Check if we raced with a fetch. It's possible that the timer fired and we
130+
// lost the race to the mutex. We want to avoid invalidating
131+
// a cache that was just refetched.
132+
if now.Before(d.invalidateAt) {
133+
return
134+
}
135+
d.keys = nil
136+
d.latestKey = database.CryptoKey{}
144137
}
145138

146139
// fetch fetches all keys for the given feature and determines the latest key.
147-
func (d *DBCache) fetch(ctx context.Context) (map[int32]database.CryptoKey, database.CryptoKey, error) {
148-
now := d.clock.Now()
140+
// It must be called while holding the keysMu lock.
141+
func (d *DBCache) fetch(ctx context.Context) error {
149142
keys, err := d.db.GetCryptoKeysByFeature(ctx, d.feature)
150143
if err != nil {
151-
return nil, database.CryptoKey{}, xerrors.Errorf("get crypto keys by feature: %w", err)
144+
return xerrors.Errorf("get crypto keys by feature: %w", err)
152145
}
153146

147+
now := d.clock.Now()
148+
d.timer.Stop()
149+
d.timer = d.newTimer()
150+
d.invalidateAt = now.Add(time.Minute * 10)
151+
154152
cache := make(map[int32]database.CryptoKey)
155153
var latest database.CryptoKey
156154
for _, key := range keys {
@@ -161,20 +159,15 @@ func (d *DBCache) fetch(ctx context.Context) (map[int32]database.CryptoKey, data
161159
}
162160

163161
if len(cache) == 0 {
164-
return nil, database.CryptoKey{}, ErrKeyNotFound
162+
return ErrKeyNotFound
165163
}
166164

167165
if !latest.CanSign(now) {
168-
return nil, database.CryptoKey{}, ErrKeyInvalid
166+
return ErrKeyInvalid
169167
}
170168

171-
select {
172-
case <-ctx.Done():
173-
return nil, database.CryptoKey{}, ctx.Err()
174-
case d.fetched <- struct{}{}:
175-
}
176-
177-
return cache, latest, nil
169+
d.keys, d.latestKey = cache, latest
170+
return nil
178171
}
179172

180173
func checkKey(key database.CryptoKey, now time.Time) (codersdk.CryptoKey, error) {
@@ -184,3 +177,11 @@ func checkKey(key database.CryptoKey, now time.Time) (codersdk.CryptoKey, error)
184177

185178
return db2sdk.CryptoKey(key), nil
186179
}
180+
181+
func (d *DBCache) newTimer() *quartz.Timer {
182+
return d.clock.AfterFunc(time.Minute*10, d.clear)
183+
}
184+
185+
func (d *DBCache) Close() {
186+
d.timer.Stop()
187+
}

0 commit comments

Comments
 (0)