Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 518300a

Browse files
authored
fix(coderd/database): improve pubsub closure and context cancellation (#7993)
1 parent aba5cb8 commit 518300a

File tree

2 files changed

+108
-9
lines changed

2 files changed

+108
-9
lines changed

coderd/database/pubsub.go

+21-7
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ func (q *msgQueue) dropped() {
163163
// Pubsub implementation using PostgreSQL.
164164
type pgPubsub struct {
165165
ctx context.Context
166+
cancel context.CancelFunc
167+
listenDone chan struct{}
166168
pgListener *pq.Listener
167169
db *sql.DB
168170
mut sync.Mutex
@@ -228,7 +230,7 @@ func (p *pgPubsub) Publish(event string, message []byte) error {
228230
// This is safe because we are calling pq.QuoteLiteral. pg_notify doesn't
229231
// support the first parameter being a prepared statement.
230232
//nolint:gosec
231-
_, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
233+
_, err := p.db.ExecContext(p.ctx, `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
232234
if err != nil {
233235
return xerrors.Errorf("exec pg_notify: %w", err)
234236
}
@@ -237,19 +239,24 @@ func (p *pgPubsub) Publish(event string, message []byte) error {
237239

238240
// Close closes the pubsub instance.
239241
func (p *pgPubsub) Close() error {
240-
return p.pgListener.Close()
242+
p.cancel()
243+
err := p.pgListener.Close()
244+
<-p.listenDone
245+
return err
241246
}
242247

243248
// listen begins receiving messages on the pq listener.
244-
func (p *pgPubsub) listen(ctx context.Context) {
249+
func (p *pgPubsub) listen() {
250+
defer close(p.listenDone)
251+
defer p.pgListener.Close()
252+
245253
var (
246254
notif *pq.Notification
247255
ok bool
248256
)
249-
defer p.pgListener.Close()
250257
for {
251258
select {
252-
case <-ctx.Done():
259+
case <-p.ctx.Done():
253260
return
254261
case notif, ok = <-p.pgListener.Notify:
255262
if !ok {
@@ -292,7 +299,7 @@ func (p *pgPubsub) recordReconnect() {
292299
func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
293300
// Creates a new listener using pq.
294301
errCh := make(chan error)
295-
listener := pq.NewListener(connectURL, time.Second, time.Minute, func(event pq.ListenerEventType, err error) {
302+
listener := pq.NewListener(connectURL, time.Second, time.Minute, func(_ pq.ListenerEventType, err error) {
296303
// This callback gets events whenever the connection state changes.
297304
// Don't send if the errChannel has already been closed.
298305
select {
@@ -306,18 +313,25 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub
306313
select {
307314
case err := <-errCh:
308315
if err != nil {
316+
_ = listener.Close()
309317
return nil, xerrors.Errorf("create pq listener: %w", err)
310318
}
311319
case <-ctx.Done():
320+
_ = listener.Close()
312321
return nil, ctx.Err()
313322
}
323+
324+
// Start a new context that will be canceled when the pubsub is closed.
325+
ctx, cancel := context.WithCancel(context.Background())
314326
pgPubsub := &pgPubsub{
315327
ctx: ctx,
328+
cancel: cancel,
329+
listenDone: make(chan struct{}),
316330
db: database,
317331
pgListener: listener,
318332
queues: make(map[string]map[uuid.UUID]*msgQueue),
319333
}
320-
go pgPubsub.listen(ctx)
334+
go pgPubsub.listen()
321335

322336
return pgPubsub, nil
323337
}

coderd/database/pubsub_test.go

+87-2
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ func TestPubsub(t *testing.T) {
4545
event := "test"
4646
data := "testing"
4747
messageChannel := make(chan []byte)
48-
cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
48+
unsub, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
4949
messageChannel <- message
5050
})
5151
require.NoError(t, err)
52-
defer cancelFunc()
52+
defer unsub()
5353
go func() {
5454
err = pubsub.Publish(event, []byte(data))
5555
assert.NoError(t, err)
@@ -72,6 +72,91 @@ func TestPubsub(t *testing.T) {
7272
defer pubsub.Close()
7373
cancelFunc()
7474
})
75+
76+
t.Run("NotClosedOnCancelContext", func(t *testing.T) {
77+
ctx, cancel := context.WithCancel(context.Background())
78+
defer cancel()
79+
connectionURL, closePg, err := postgres.Open()
80+
require.NoError(t, err)
81+
defer closePg()
82+
db, err := sql.Open("postgres", connectionURL)
83+
require.NoError(t, err)
84+
defer db.Close()
85+
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
86+
require.NoError(t, err)
87+
defer pubsub.Close()
88+
89+
// Provided context must only be active during NewPubsub, not after.
90+
cancel()
91+
92+
event := "test"
93+
data := "testing"
94+
messageChannel := make(chan []byte)
95+
unsub, err := pubsub.Subscribe(event, func(_ context.Context, message []byte) {
96+
messageChannel <- message
97+
})
98+
require.NoError(t, err)
99+
defer unsub()
100+
go func() {
101+
err = pubsub.Publish(event, []byte(data))
102+
assert.NoError(t, err)
103+
}()
104+
message := <-messageChannel
105+
assert.Equal(t, string(message), data)
106+
})
107+
108+
t.Run("ClosePropagatesContextCancellationToSubscription", func(t *testing.T) {
109+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
110+
defer cancel()
111+
connectionURL, closePg, err := postgres.Open()
112+
require.NoError(t, err)
113+
defer closePg()
114+
db, err := sql.Open("postgres", connectionURL)
115+
require.NoError(t, err)
116+
defer db.Close()
117+
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
118+
require.NoError(t, err)
119+
defer pubsub.Close()
120+
121+
event := "test"
122+
done := make(chan struct{})
123+
called := make(chan struct{})
124+
unsub, err := pubsub.Subscribe(event, func(subCtx context.Context, _ []byte) {
125+
defer close(done)
126+
select {
127+
case <-subCtx.Done():
128+
assert.Fail(t, "context should not be canceled")
129+
default:
130+
}
131+
close(called)
132+
select {
133+
case <-subCtx.Done():
134+
case <-ctx.Done():
135+
assert.Fail(t, "timeout waiting for sub context to be canceled")
136+
}
137+
})
138+
require.NoError(t, err)
139+
defer unsub()
140+
141+
go func() {
142+
err := pubsub.Publish(event, nil)
143+
assert.NoError(t, err)
144+
}()
145+
146+
select {
147+
case <-called:
148+
case <-ctx.Done():
149+
require.Fail(t, "timeout waiting for handler to be called")
150+
}
151+
err = pubsub.Close()
152+
require.NoError(t, err)
153+
154+
select {
155+
case <-done:
156+
case <-ctx.Done():
157+
require.Fail(t, "timeout waiting for handler to finish")
158+
}
159+
})
75160
}
76161

77162
func TestPubsub_ordering(t *testing.T) {

0 commit comments

Comments
 (0)