Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 384873a

Browse files
authored
feat: add wsproxy implementation for key fetching (#14917)
1 parent 5315656 commit 384873a

File tree

2 files changed

+709
-0
lines changed

2 files changed

+709
-0
lines changed

enterprise/wsproxy/keycache.go

+224
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
package wsproxy
2+
3+
import (
4+
"context"
5+
"sync"
6+
"time"
7+
8+
"golang.org/x/xerrors"
9+
10+
"cdr.dev/slog"
11+
12+
"github.com/coder/coder/v2/coderd/cryptokeys"
13+
"github.com/coder/coder/v2/codersdk"
14+
"github.com/coder/quartz"
15+
)
16+
17+
const (
18+
// latestSequence is a special sequence number that represents the latest key.
19+
latestSequence = -1
20+
// refreshInterval is the interval at which the key cache will refresh.
21+
refreshInterval = time.Minute * 10
22+
)
23+
24+
type Fetcher interface {
25+
Fetch(ctx context.Context) ([]codersdk.CryptoKey, error)
26+
}
27+
28+
type CryptoKeyCache struct {
29+
Clock quartz.Clock
30+
refreshCtx context.Context
31+
refreshCancel context.CancelFunc
32+
fetcher Fetcher
33+
logger slog.Logger
34+
35+
mu sync.Mutex
36+
keys map[int32]codersdk.CryptoKey
37+
lastFetch time.Time
38+
refresher *quartz.Timer
39+
fetching bool
40+
closed bool
41+
cond *sync.Cond
42+
}
43+
44+
func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client Fetcher, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) {
45+
cache := &CryptoKeyCache{
46+
Clock: quartz.NewReal(),
47+
logger: log,
48+
fetcher: client,
49+
}
50+
51+
for _, opt := range opts {
52+
opt(cache)
53+
}
54+
55+
cache.cond = sync.NewCond(&cache.mu)
56+
cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx)
57+
cache.refresher = cache.Clock.AfterFunc(refreshInterval, cache.refresh)
58+
59+
keys, err := cache.cryptoKeys(ctx)
60+
if err != nil {
61+
cache.refreshCancel()
62+
return nil, xerrors.Errorf("initial fetch: %w", err)
63+
}
64+
cache.keys = keys
65+
66+
return cache, nil
67+
}
68+
69+
func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) {
70+
return k.cryptoKey(ctx, latestSequence)
71+
}
72+
73+
func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) {
74+
return k.cryptoKey(ctx, sequence)
75+
}
76+
77+
func (k *CryptoKeyCache) cryptoKey(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) {
78+
k.mu.Lock()
79+
defer k.mu.Unlock()
80+
81+
if k.closed {
82+
return codersdk.CryptoKey{}, cryptokeys.ErrClosed
83+
}
84+
85+
var key codersdk.CryptoKey
86+
var ok bool
87+
for key, ok = k.key(sequence); !ok && k.fetching && !k.closed; {
88+
k.cond.Wait()
89+
}
90+
91+
if k.closed {
92+
return codersdk.CryptoKey{}, cryptokeys.ErrClosed
93+
}
94+
95+
if ok {
96+
return checkKey(key, sequence, k.Clock.Now())
97+
}
98+
99+
k.fetching = true
100+
k.mu.Unlock()
101+
102+
keys, err := k.cryptoKeys(ctx)
103+
if err != nil {
104+
return codersdk.CryptoKey{}, xerrors.Errorf("get keys: %w", err)
105+
}
106+
107+
k.mu.Lock()
108+
k.lastFetch = k.Clock.Now()
109+
k.refresher.Reset(refreshInterval)
110+
k.keys = keys
111+
k.fetching = false
112+
k.cond.Broadcast()
113+
114+
key, ok = k.key(sequence)
115+
if !ok {
116+
return codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound
117+
}
118+
119+
return checkKey(key, sequence, k.Clock.Now())
120+
}
121+
122+
func (k *CryptoKeyCache) key(sequence int32) (codersdk.CryptoKey, bool) {
123+
if sequence == latestSequence {
124+
return k.keys[latestSequence], k.keys[latestSequence].CanSign(k.Clock.Now())
125+
}
126+
127+
key, ok := k.keys[sequence]
128+
return key, ok
129+
}
130+
131+
func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (codersdk.CryptoKey, error) {
132+
if sequence == latestSequence {
133+
if !key.CanSign(now) {
134+
return codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid
135+
}
136+
return key, nil
137+
}
138+
139+
if !key.CanVerify(now) {
140+
return codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid
141+
}
142+
143+
return key, nil
144+
}
145+
146+
// refresh fetches the keys from the control plane and updates the cache.
147+
func (k *CryptoKeyCache) refresh() {
148+
now := k.Clock.Now("CryptoKeyCache", "refresh")
149+
k.mu.Lock()
150+
151+
if k.closed {
152+
k.mu.Unlock()
153+
return
154+
}
155+
156+
// If something's already fetching, we don't need to do anything.
157+
if k.fetching {
158+
k.mu.Unlock()
159+
return
160+
}
161+
162+
// There's a window we must account for where the timer fires while a fetch
163+
// is ongoing but prior to the timer getting reset. In this case we want to
164+
// avoid double fetching.
165+
if now.Sub(k.lastFetch) < refreshInterval {
166+
k.mu.Unlock()
167+
return
168+
}
169+
170+
k.fetching = true
171+
172+
k.mu.Unlock()
173+
keys, err := k.cryptoKeys(k.refreshCtx)
174+
if err != nil {
175+
k.logger.Error(k.refreshCtx, "fetch crypto keys", slog.Error(err))
176+
return
177+
}
178+
179+
k.mu.Lock()
180+
defer k.mu.Unlock()
181+
182+
k.lastFetch = k.Clock.Now()
183+
k.refresher.Reset(refreshInterval)
184+
k.keys = keys
185+
k.fetching = false
186+
k.cond.Broadcast()
187+
}
188+
189+
// cryptoKeys queries the control plane for the crypto keys.
190+
// Outside of initialization, this should only be called by fetch.
191+
func (k *CryptoKeyCache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) {
192+
keys, err := k.fetcher.Fetch(ctx)
193+
if err != nil {
194+
return nil, xerrors.Errorf("crypto keys: %w", err)
195+
}
196+
cache := toKeyMap(keys, k.Clock.Now())
197+
return cache, nil
198+
}
199+
200+
func toKeyMap(keys []codersdk.CryptoKey, now time.Time) map[int32]codersdk.CryptoKey {
201+
m := make(map[int32]codersdk.CryptoKey)
202+
var latest codersdk.CryptoKey
203+
for _, key := range keys {
204+
m[key.Sequence] = key
205+
if key.Sequence > latest.Sequence && key.CanSign(now) {
206+
m[latestSequence] = key
207+
}
208+
}
209+
return m
210+
}
211+
212+
func (k *CryptoKeyCache) Close() {
213+
k.mu.Lock()
214+
defer k.mu.Unlock()
215+
216+
if k.closed {
217+
return
218+
}
219+
220+
k.closed = true
221+
k.refreshCancel()
222+
k.refresher.Stop()
223+
k.cond.Broadcast()
224+
}

0 commit comments

Comments
 (0)