From 91c4f660265151fc229392ca4c0e4ff31fd25836 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Tue, 20 Jul 2021 23:42:27 +0000 Subject: [PATCH 1/7] feat: Add DialCache for key-based connection caching --- wsnet/cache.go | 154 ++++++++++++++++++++++++++++++++++++++++++++ wsnet/cache_test.go | 67 +++++++++++++++++++ wsnet/dial.go | 6 ++ 3 files changed, 227 insertions(+) create mode 100644 wsnet/cache.go create mode 100644 wsnet/cache_test.go diff --git a/wsnet/cache.go b/wsnet/cache.go new file mode 100644 index 00000000..e62a2904 --- /dev/null +++ b/wsnet/cache.go @@ -0,0 +1,154 @@ +package wsnet + +import ( + "context" + "sync" + "time" + + "golang.org/x/sync/singleflight" +) + +// dialerFunc is used to reference a dialer returned for caching. +type dialerFunc func(ctx context.Context, key string, options *DialOptions) (*Dialer, error) + +// DialCache constructs a new DialerCache. +// The cache clears connections that: +// 1. Are older than the TTL and have no active user-created connections. +// 2. Have been closed. +func DialCache(ttl time.Duration, dialer dialerFunc) *DialerCache { + dc := &DialerCache{ + ttl: ttl, + dialerFunc: dialer, + closed: make(chan struct{}), + flightGroup: &singleflight.Group{}, + mut: sync.RWMutex{}, + dialers: make(map[string]*Dialer), + atime: make(map[string]time.Time), + } + go dc.init() + return dc +} + +type DialerCache struct { + dialerFunc dialerFunc + ttl time.Duration + flightGroup *singleflight.Group + + closed chan struct{} + mut sync.RWMutex + dialers map[string]*Dialer + atime map[string]time.Time +} + +// init starts the ticker for evicting connections. +func (d *DialerCache) init() { + ticker := time.NewTicker(time.Second * 30) + defer ticker.Stop() + for { + select { + case <-d.closed: + return + case <-ticker.C: + d.evict() + } + } +} + +// evict removes lost/broken/expired connections from the cache. +func (d *DialerCache) evict() { + var wg sync.WaitGroup + d.mut.RLock() + for key, dialer := range d.dialers { + wg.Add(1) + key := key + dialer := dialer + go func() { + defer wg.Done() + + evict := false + select { + case <-dialer.Closed(): + evict = true + default: + } + if dialer.ActiveConnections() == 0 && time.Since(d.atime[key]) >= d.ttl { + evict = true + } + if !evict { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + err := dialer.Ping(ctx) + if err != nil { + evict = true + } + } + + if evict { + _ = dialer.Close() + d.mut.Lock() + delete(d.atime, key) + delete(d.dialers, key) + d.mut.Unlock() + } + }() + } + d.mut.RUnlock() + wg.Wait() +} + +// Dial returns a Dialer from the cache if one exists with the key provided, +// or dials a new connection using the dialerFunc. +func (d *DialerCache) Dial(ctx context.Context, key string, options *DialOptions) (*Dialer, bool, error) { + d.mut.RLock() + if dialer, ok := d.dialers[key]; ok { + closed := false + select { + case <-dialer.Closed(): + closed = true + default: + } + if !closed { + d.mut.RUnlock() + d.mut.Lock() + d.atime[key] = time.Now() + d.mut.Unlock() + + return dialer, true, nil + } + } + d.mut.RUnlock() + + dialer, err, _ := d.flightGroup.Do(key, func() (interface{}, error) { + dialer, err := d.dialerFunc(ctx, key, options) + if err != nil { + return nil, err + } + d.mut.Lock() + d.dialers[key] = dialer + d.atime[key] = time.Now() + d.mut.Unlock() + + return dialer, nil + }) + if err != nil { + return nil, false, err + } + return dialer.(*Dialer), false, nil +} + +// Close closes all cached dialers. +func (d *DialerCache) Close() error { + d.mut.Lock() + defer d.mut.Unlock() + + for key, dialer := range d.dialers { + d.flightGroup.Forget(key) + + err := dialer.Close() + if err != nil { + return err + } + } + close(d.closed) + return nil +} diff --git a/wsnet/cache_test.go b/wsnet/cache_test.go new file mode 100644 index 00000000..70921674 --- /dev/null +++ b/wsnet/cache_test.go @@ -0,0 +1,67 @@ +package wsnet + +import ( + "context" + "testing" + "time" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/stretchr/testify/require" +) + +func TestCache(t *testing.T) { + t.Run("Caches", func(t *testing.T) { + connectAddr, listenAddr := createDumbBroker(t) + l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") + require.NoError(t, err) + defer l.Close() + + cache := DialCache(time.Hour, func(ctx context.Context, key string, options *DialOptions) (*Dialer, error) { + return DialWebsocket(ctx, connectAddr, options) + }) + _, cached, err := cache.Dial(context.Background(), "example", nil) + require.NoError(t, err) + require.Equal(t, cached, false) + _, cached, err = cache.Dial(context.Background(), "example", nil) + require.NoError(t, err) + require.Equal(t, cached, true) + }) + + t.Run("Create If Closed", func(t *testing.T) { + connectAddr, listenAddr := createDumbBroker(t) + l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") + require.NoError(t, err) + defer l.Close() + + cache := DialCache(time.Hour, func(ctx context.Context, key string, options *DialOptions) (*Dialer, error) { + return DialWebsocket(ctx, connectAddr, options) + }) + + conn, cached, err := cache.Dial(context.Background(), "example", nil) + require.NoError(t, err) + require.Equal(t, cached, false) + require.NoError(t, conn.Close()) + _, cached, err = cache.Dial(context.Background(), "example", nil) + require.NoError(t, err) + require.Equal(t, cached, false) + }) + + t.Run("Evict No Connections", func(t *testing.T) { + connectAddr, listenAddr := createDumbBroker(t) + l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") + require.NoError(t, err) + defer l.Close() + + cache := DialCache(0, func(ctx context.Context, key string, options *DialOptions) (*Dialer, error) { + return DialWebsocket(ctx, connectAddr, options) + }) + + _, cached, err := cache.Dial(context.Background(), "example", nil) + require.NoError(t, err) + require.Equal(t, cached, false) + cache.evict() + _, cached, err = cache.Dial(context.Background(), "example", nil) + require.NoError(t, err) + require.Equal(t, cached, false) + }) +} diff --git a/wsnet/dial.go b/wsnet/dial.go index 050bc574..98de094e 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -246,6 +246,12 @@ func (d *Dialer) ActiveConnections() int { // Close closes the RTC connection. // All data channels dialed will be closed. func (d *Dialer) Close() error { + select { + case <-d.closedChan: + return nil + default: + } + close(d.closedChan) return d.rtc.Close() } From 857a743c14f57ef75d9705daa60473397cdfb5fe Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Tue, 20 Jul 2021 23:58:02 +0000 Subject: [PATCH 2/7] Remove DialOptions --- wsnet/cache.go | 6 +++--- wsnet/cache_test.go | 24 ++++++++++++------------ 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/wsnet/cache.go b/wsnet/cache.go index e62a2904..052ba415 100644 --- a/wsnet/cache.go +++ b/wsnet/cache.go @@ -9,7 +9,7 @@ import ( ) // dialerFunc is used to reference a dialer returned for caching. -type dialerFunc func(ctx context.Context, key string, options *DialOptions) (*Dialer, error) +type dialerFunc func(ctx context.Context, key string) (*Dialer, error) // DialCache constructs a new DialerCache. // The cache clears connections that: @@ -98,7 +98,7 @@ func (d *DialerCache) evict() { // Dial returns a Dialer from the cache if one exists with the key provided, // or dials a new connection using the dialerFunc. -func (d *DialerCache) Dial(ctx context.Context, key string, options *DialOptions) (*Dialer, bool, error) { +func (d *DialerCache) Dial(ctx context.Context, key string) (*Dialer, bool, error) { d.mut.RLock() if dialer, ok := d.dialers[key]; ok { closed := false @@ -119,7 +119,7 @@ func (d *DialerCache) Dial(ctx context.Context, key string, options *DialOptions d.mut.RUnlock() dialer, err, _ := d.flightGroup.Do(key, func() (interface{}, error) { - dialer, err := d.dialerFunc(ctx, key, options) + dialer, err := d.dialerFunc(ctx, key) if err != nil { return nil, err } diff --git a/wsnet/cache_test.go b/wsnet/cache_test.go index 70921674..4b607792 100644 --- a/wsnet/cache_test.go +++ b/wsnet/cache_test.go @@ -16,13 +16,13 @@ func TestCache(t *testing.T) { require.NoError(t, err) defer l.Close() - cache := DialCache(time.Hour, func(ctx context.Context, key string, options *DialOptions) (*Dialer, error) { - return DialWebsocket(ctx, connectAddr, options) + cache := DialCache(time.Hour, func(ctx context.Context, key string) (*Dialer, error) { + return DialWebsocket(ctx, connectAddr, nil) }) - _, cached, err := cache.Dial(context.Background(), "example", nil) + _, cached, err := cache.Dial(context.Background(), "example") require.NoError(t, err) require.Equal(t, cached, false) - _, cached, err = cache.Dial(context.Background(), "example", nil) + _, cached, err = cache.Dial(context.Background(), "example") require.NoError(t, err) require.Equal(t, cached, true) }) @@ -33,15 +33,15 @@ func TestCache(t *testing.T) { require.NoError(t, err) defer l.Close() - cache := DialCache(time.Hour, func(ctx context.Context, key string, options *DialOptions) (*Dialer, error) { - return DialWebsocket(ctx, connectAddr, options) + cache := DialCache(time.Hour, func(ctx context.Context, key string) (*Dialer, error) { + return DialWebsocket(ctx, connectAddr, nil) }) - conn, cached, err := cache.Dial(context.Background(), "example", nil) + conn, cached, err := cache.Dial(context.Background(), "example") require.NoError(t, err) require.Equal(t, cached, false) require.NoError(t, conn.Close()) - _, cached, err = cache.Dial(context.Background(), "example", nil) + _, cached, err = cache.Dial(context.Background(), "example") require.NoError(t, err) require.Equal(t, cached, false) }) @@ -52,15 +52,15 @@ func TestCache(t *testing.T) { require.NoError(t, err) defer l.Close() - cache := DialCache(0, func(ctx context.Context, key string, options *DialOptions) (*Dialer, error) { - return DialWebsocket(ctx, connectAddr, options) + cache := DialCache(0, func(ctx context.Context, key string) (*Dialer, error) { + return DialWebsocket(ctx, connectAddr, nil) }) - _, cached, err := cache.Dial(context.Background(), "example", nil) + _, cached, err := cache.Dial(context.Background(), "example") require.NoError(t, err) require.Equal(t, cached, false) cache.evict() - _, cached, err = cache.Dial(context.Background(), "example", nil) + _, cached, err = cache.Dial(context.Background(), "example") require.NoError(t, err) require.Equal(t, cached, false) }) From 0af778936b2642f0e6678110d4de69d8b14bccbe Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Wed, 21 Jul 2021 00:04:13 +0000 Subject: [PATCH 3/7] Move DialFunc to Dial --- wsnet/cache.go | 10 ++++------ wsnet/cache_test.go | 30 +++++++++++++++--------------- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/wsnet/cache.go b/wsnet/cache.go index 052ba415..fdbbc215 100644 --- a/wsnet/cache.go +++ b/wsnet/cache.go @@ -9,16 +9,15 @@ import ( ) // dialerFunc is used to reference a dialer returned for caching. -type dialerFunc func(ctx context.Context, key string) (*Dialer, error) +type dialerFunc func() (*Dialer, error) // DialCache constructs a new DialerCache. // The cache clears connections that: // 1. Are older than the TTL and have no active user-created connections. // 2. Have been closed. -func DialCache(ttl time.Duration, dialer dialerFunc) *DialerCache { +func DialCache(ttl time.Duration) *DialerCache { dc := &DialerCache{ ttl: ttl, - dialerFunc: dialer, closed: make(chan struct{}), flightGroup: &singleflight.Group{}, mut: sync.RWMutex{}, @@ -30,7 +29,6 @@ func DialCache(ttl time.Duration, dialer dialerFunc) *DialerCache { } type DialerCache struct { - dialerFunc dialerFunc ttl time.Duration flightGroup *singleflight.Group @@ -98,7 +96,7 @@ func (d *DialerCache) evict() { // Dial returns a Dialer from the cache if one exists with the key provided, // or dials a new connection using the dialerFunc. -func (d *DialerCache) Dial(ctx context.Context, key string) (*Dialer, bool, error) { +func (d *DialerCache) Dial(ctx context.Context, key string, dialerFunc func() (*Dialer, error)) (*Dialer, bool, error) { d.mut.RLock() if dialer, ok := d.dialers[key]; ok { closed := false @@ -119,7 +117,7 @@ func (d *DialerCache) Dial(ctx context.Context, key string) (*Dialer, bool, erro d.mut.RUnlock() dialer, err, _ := d.flightGroup.Do(key, func() (interface{}, error) { - dialer, err := d.dialerFunc(ctx, key) + dialer, err := dialerFunc() if err != nil { return nil, err } diff --git a/wsnet/cache_test.go b/wsnet/cache_test.go index 4b607792..b84d886c 100644 --- a/wsnet/cache_test.go +++ b/wsnet/cache_test.go @@ -10,19 +10,23 @@ import ( ) func TestCache(t *testing.T) { + dialFunc := func(connectAddr string) func() (*Dialer, error) { + return func() (*Dialer, error) { + return DialWebsocket(context.Background(), connectAddr, nil) + } + } + t.Run("Caches", func(t *testing.T) { connectAddr, listenAddr := createDumbBroker(t) l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") require.NoError(t, err) defer l.Close() - cache := DialCache(time.Hour, func(ctx context.Context, key string) (*Dialer, error) { - return DialWebsocket(ctx, connectAddr, nil) - }) - _, cached, err := cache.Dial(context.Background(), "example") + cache := DialCache(time.Hour) + _, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) require.NoError(t, err) require.Equal(t, cached, false) - _, cached, err = cache.Dial(context.Background(), "example") + _, cached, err = cache.Dial(context.Background(), "example", dialFunc(connectAddr)) require.NoError(t, err) require.Equal(t, cached, true) }) @@ -33,15 +37,13 @@ func TestCache(t *testing.T) { require.NoError(t, err) defer l.Close() - cache := DialCache(time.Hour, func(ctx context.Context, key string) (*Dialer, error) { - return DialWebsocket(ctx, connectAddr, nil) - }) + cache := DialCache(time.Hour) - conn, cached, err := cache.Dial(context.Background(), "example") + conn, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) require.NoError(t, err) require.Equal(t, cached, false) require.NoError(t, conn.Close()) - _, cached, err = cache.Dial(context.Background(), "example") + _, cached, err = cache.Dial(context.Background(), "example", dialFunc(connectAddr)) require.NoError(t, err) require.Equal(t, cached, false) }) @@ -52,15 +54,13 @@ func TestCache(t *testing.T) { require.NoError(t, err) defer l.Close() - cache := DialCache(0, func(ctx context.Context, key string) (*Dialer, error) { - return DialWebsocket(ctx, connectAddr, nil) - }) + cache := DialCache(0) - _, cached, err := cache.Dial(context.Background(), "example") + _, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) require.NoError(t, err) require.Equal(t, cached, false) cache.evict() - _, cached, err = cache.Dial(context.Background(), "example") + _, cached, err = cache.Dial(context.Background(), "example", dialFunc(connectAddr)) require.NoError(t, err) require.Equal(t, cached, false) }) From 6aa80482fcce34aba5c7281aa79417efc0fb5433 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Wed, 21 Jul 2021 00:08:26 +0000 Subject: [PATCH 4/7] Add WS options to dial --- internal/cmd/tunnel.go | 1 + wsnet/cache_test.go | 2 +- wsnet/dial.go | 6 +++--- wsnet/dial_test.go | 22 +++++++++++----------- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/internal/cmd/tunnel.go b/internal/cmd/tunnel.go index 9c12dd37..956e9fd2 100644 --- a/internal/cmd/tunnel.go +++ b/internal/cmd/tunnel.go @@ -112,6 +112,7 @@ func (c *tunnneler) start(ctx context.Context) error { TURNProxyURL: c.brokerAddr, ICEServers: []webrtc.ICEServer{wsnet.TURNProxyICECandidate()}, }, + nil, ) if err != nil { return xerrors.Errorf("creating workspace dialer: %w", err) diff --git a/wsnet/cache_test.go b/wsnet/cache_test.go index b84d886c..798920fc 100644 --- a/wsnet/cache_test.go +++ b/wsnet/cache_test.go @@ -12,7 +12,7 @@ import ( func TestCache(t *testing.T) { dialFunc := func(connectAddr string) func() (*Dialer, error) { return func() (*Dialer, error) { - return DialWebsocket(context.Background(), connectAddr, nil) + return DialWebsocket(context.Background(), connectAddr, nil, nil) } } diff --git a/wsnet/dial.go b/wsnet/dial.go index 98de094e..af4b422c 100644 --- a/wsnet/dial.go +++ b/wsnet/dial.go @@ -35,8 +35,8 @@ type DialOptions struct { } // DialWebsocket dials the broker with a WebSocket and negotiates a connection. -func DialWebsocket(ctx context.Context, broker string, options *DialOptions) (*Dialer, error) { - conn, resp, err := websocket.Dial(ctx, broker, nil) +func DialWebsocket(ctx context.Context, broker string, netOpts *DialOptions, wsOpts *websocket.DialOptions) (*Dialer, error) { + conn, resp, err := websocket.Dial(ctx, broker, wsOpts) if err != nil { if resp != nil { defer func() { @@ -52,7 +52,7 @@ func DialWebsocket(ctx context.Context, broker string, options *DialOptions) (*D // We should close the socket intentionally. _ = conn.Close(websocket.StatusInternalError, "an error occurred") }() - return Dial(nconn, options) + return Dial(nconn, netOpts) } // Dial negotiates a connection to a listener. diff --git a/wsnet/dial_test.go b/wsnet/dial_test.go index a5d33b96..8a6486ba 100644 --- a/wsnet/dial_test.go +++ b/wsnet/dial_test.go @@ -39,7 +39,7 @@ func ExampleDial_basic() { dialer, err := DialWebsocket(context.Background(), "wss://master.cdr.dev/agent/workspace/connect", &DialOptions{ ICEServers: servers, - }) + }, nil) if err != nil { // Do something... } @@ -60,7 +60,7 @@ func TestDial(t *testing.T) { require.NoError(t, err) defer l.Close() - dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) require.NoError(t, err) err = dialer.Ping(context.Background()) @@ -83,7 +83,7 @@ func TestDial(t *testing.T) { Credential: testPass, CredentialType: webrtc.ICECredentialTypePassword, }}, - }) + }, nil) require.NoError(t, err) _ = dialer.Ping(context.Background()) @@ -100,7 +100,7 @@ func TestDial(t *testing.T) { require.NoError(t, err) defer l.Close() - dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) require.NoError(t, err) _, err = dialer.DialContext(context.Background(), "tcp", "localhost:100") @@ -130,7 +130,7 @@ func TestDial(t *testing.T) { require.NoError(t, err) defer l.Close() - dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) require.NoError(t, err) conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) @@ -158,7 +158,7 @@ func TestDial(t *testing.T) { require.NoError(t, err) defer l.Close() - dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) require.NoError(t, err) conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) @@ -178,7 +178,7 @@ func TestDial(t *testing.T) { require.NoError(t, err) defer l.Close() - dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) require.NoError(t, err) err = dialer.Close() @@ -210,7 +210,7 @@ func TestDial(t *testing.T) { Credential: testPass, CredentialType: webrtc.ICECredentialTypePassword, }}, - }) + }, nil) require.NoError(t, err) conn, err := dialer.DialContext(context.Background(), "tcp", tcpListener.Addr().String()) @@ -231,7 +231,7 @@ func TestDial(t *testing.T) { require.NoError(t, err) defer l.Close() - dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) require.NoError(t, err) go func() { _ = dialer.Close() @@ -261,7 +261,7 @@ func TestDial(t *testing.T) { t.Error(err) return } - dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) if err != nil { t.Error(err) } @@ -314,7 +314,7 @@ func BenchmarkThroughput(b *testing.B) { } defer l.Close() - dialer, err := DialWebsocket(context.Background(), connectAddr, nil) + dialer, err := DialWebsocket(context.Background(), connectAddr, nil, nil) if err != nil { b.Error(err) return From b0c01672f2be2795853045358c520caa877d2230 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Wed, 21 Jul 2021 00:24:55 +0000 Subject: [PATCH 5/7] Requested changes --- wsnet/cache.go | 32 +++++++++++++++++++++----------- wsnet/cache_test.go | 13 ++++++++----- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/wsnet/cache.go b/wsnet/cache.go index fdbbc215..54fc80fb 100644 --- a/wsnet/cache.go +++ b/wsnet/cache.go @@ -2,15 +2,13 @@ package wsnet import ( "context" + "errors" "sync" "time" "golang.org/x/sync/singleflight" ) -// dialerFunc is used to reference a dialer returned for caching. -type dialerFunc func() (*Dialer, error) - // DialCache constructs a new DialerCache. // The cache clears connections that: // 1. Are older than the TTL and have no active user-created connections. @@ -72,6 +70,7 @@ func (d *DialerCache) evict() { if dialer.ActiveConnections() == 0 && time.Since(d.atime[key]) >= d.ttl { evict = true } + // If we're already evicting there's no point in trying to ping. if !evict { ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() @@ -96,9 +95,18 @@ func (d *DialerCache) evict() { // Dial returns a Dialer from the cache if one exists with the key provided, // or dials a new connection using the dialerFunc. +// The bool returns whether the connection was found in the cache or not. func (d *DialerCache) Dial(ctx context.Context, key string, dialerFunc func() (*Dialer, error)) (*Dialer, bool, error) { + select { + case <-d.closed: + return nil, false, errors.New("cache closed") + default: + } + d.mut.RLock() - if dialer, ok := d.dialers[key]; ok { + dialer, ok := d.dialers[key] + d.mut.RUnlock() + if ok { closed := false select { case <-dialer.Closed(): @@ -106,7 +114,6 @@ func (d *DialerCache) Dial(ctx context.Context, key string, dialerFunc func() (* default: } if !closed { - d.mut.RUnlock() d.mut.Lock() d.atime[key] = time.Now() d.mut.Unlock() @@ -114,9 +121,8 @@ func (d *DialerCache) Dial(ctx context.Context, key string, dialerFunc func() (* return dialer, true, nil } } - d.mut.RUnlock() - dialer, err, _ := d.flightGroup.Do(key, func() (interface{}, error) { + rawDialer, err, _ := d.flightGroup.Do(key, func() (interface{}, error) { dialer, err := dialerFunc() if err != nil { return nil, err @@ -131,7 +137,13 @@ func (d *DialerCache) Dial(ctx context.Context, key string, dialerFunc func() (* if err != nil { return nil, false, err } - return dialer.(*Dialer), false, nil + select { + case <-d.closed: + return nil, false, errors.New("cache closed") + default: + } + + return rawDialer.(*Dialer), false, nil } // Close closes all cached dialers. @@ -139,9 +151,7 @@ func (d *DialerCache) Close() error { d.mut.Lock() defer d.mut.Unlock() - for key, dialer := range d.dialers { - d.flightGroup.Forget(key) - + for _, dialer := range d.dialers { err := dialer.Close() if err != nil { return err diff --git a/wsnet/cache_test.go b/wsnet/cache_test.go index 798920fc..04319541 100644 --- a/wsnet/cache_test.go +++ b/wsnet/cache_test.go @@ -6,6 +6,7 @@ import ( "time" "cdr.dev/slog/sloggers/slogtest" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -23,12 +24,13 @@ func TestCache(t *testing.T) { defer l.Close() cache := DialCache(time.Hour) - _, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) + c1, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) require.NoError(t, err) require.Equal(t, cached, false) - _, cached, err = cache.Dial(context.Background(), "example", dialFunc(connectAddr)) + c2, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) require.NoError(t, err) require.Equal(t, cached, true) + assert.Same(t, c1, c2) }) t.Run("Create If Closed", func(t *testing.T) { @@ -39,13 +41,14 @@ func TestCache(t *testing.T) { cache := DialCache(time.Hour) - conn, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) + c1, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) require.NoError(t, err) require.Equal(t, cached, false) - require.NoError(t, conn.Close()) - _, cached, err = cache.Dial(context.Background(), "example", dialFunc(connectAddr)) + require.NoError(t, c1.Close()) + c2, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) require.NoError(t, err) require.Equal(t, cached, false) + assert.NotSame(t, c1, c2) }) t.Run("Evict No Connections", func(t *testing.T) { From 1199814fb6a49f63b88eba3787c3685399dce5fe Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Wed, 21 Jul 2021 00:25:48 +0000 Subject: [PATCH 6/7] Add comment --- wsnet/cache.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/wsnet/cache.go b/wsnet/cache.go index 54fc80fb..ef0da6f6 100644 --- a/wsnet/cache.go +++ b/wsnet/cache.go @@ -29,9 +29,10 @@ func DialCache(ttl time.Duration) *DialerCache { type DialerCache struct { ttl time.Duration flightGroup *singleflight.Group + closed chan struct{} + mut sync.RWMutex - closed chan struct{} - mut sync.RWMutex + // Key is the "key" of a dialer. dialers map[string]*Dialer atime map[string]time.Time } From 844385e02b8636b0b5d359541189f02211a06dc7 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Wed, 21 Jul 2021 01:53:05 +0000 Subject: [PATCH 7/7] Fixup --- wsnet/cache.go | 22 +++++++++++++++------- wsnet/cache_test.go | 5 +++-- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/wsnet/cache.go b/wsnet/cache.go index ef0da6f6..e62aa0a9 100644 --- a/wsnet/cache.go +++ b/wsnet/cache.go @@ -32,7 +32,7 @@ type DialerCache struct { closed chan struct{} mut sync.RWMutex - // Key is the "key" of a dialer. + // Key is the "key" of a dialer, which is usually the workspace ID. dialers map[string]*Dialer atime map[string]time.Time } @@ -81,13 +81,21 @@ func (d *DialerCache) evict() { } } - if evict { - _ = dialer.Close() - d.mut.Lock() - delete(d.atime, key) - delete(d.dialers, key) - d.mut.Unlock() + if !evict { + return + } + + _ = dialer.Close() + // Ensure after Ping and potential delays that we're still testing against + // the proper dialer. + if dialer != d.dialers[key] { + return } + + d.mut.Lock() + defer d.mut.Unlock() + delete(d.atime, key) + delete(d.dialers, key) }() } d.mut.RUnlock() diff --git a/wsnet/cache_test.go b/wsnet/cache_test.go index 04319541..44edb608 100644 --- a/wsnet/cache_test.go +++ b/wsnet/cache_test.go @@ -59,12 +59,13 @@ func TestCache(t *testing.T) { cache := DialCache(0) - _, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) + c1, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) require.NoError(t, err) require.Equal(t, cached, false) cache.evict() - _, cached, err = cache.Dial(context.Background(), "example", dialFunc(connectAddr)) + c2, cached, err := cache.Dial(context.Background(), "example", dialFunc(connectAddr)) require.NoError(t, err) require.Equal(t, cached, false) + assert.NotSame(t, c1, c2) }) }