@@ -3,6 +3,7 @@ package pubsub
3
3
import (
4
4
"context"
5
5
"database/sql"
6
+ "database/sql/driver"
6
7
"errors"
7
8
"io"
8
9
"net"
@@ -15,6 +16,8 @@ import (
15
16
"github.com/prometheus/client_golang/prometheus"
16
17
"golang.org/x/xerrors"
17
18
19
+ "github.com/coder/coder/v2/coderd/database"
20
+
18
21
"cdr.dev/slog"
19
22
)
20
23
@@ -432,9 +435,35 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
432
435
// pq.defaultDialer uses a zero net.Dialer as well.
433
436
d : net.Dialer {},
434
437
}
438
+ connector driver.Connector
439
+ err error
435
440
)
441
+
442
+ // Create a custom connector if the database driver supports it.
443
+ connectorCreator , ok := p .db .Driver ().(database.ConnectorCreator )
444
+ if ok {
445
+ connector , err = connectorCreator .Connector (connectURL )
446
+ if err != nil {
447
+ return xerrors .Errorf ("create custom connector: %w" , err )
448
+ }
449
+ } else {
450
+ // use the default pq connector otherwise
451
+ connector , err = pq .NewConnector (connectURL )
452
+ if err != nil {
453
+ return xerrors .Errorf ("create pq connector: %w" , err )
454
+ }
455
+ }
456
+
457
+ // Set the dialer if the connector supports it.
458
+ dc , ok := connector .(database.DialerConnector )
459
+ if ! ok {
460
+ p .logger .Critical (ctx , "connector does not support setting log dialer, database connection debug logs will be missing" )
461
+ } else {
462
+ dc .Dialer (dialer )
463
+ }
464
+
436
465
p .pgListener = pqListenerShim {
437
- Listener : pq .NewDialListener ( dialer , connectURL , time .Second , time .Minute , func (t pq.ListenerEventType , err error ) {
466
+ Listener : pq .NewConnectorListener ( connector , connectURL , time .Second , time .Minute , func (t pq.ListenerEventType , err error ) {
438
467
switch t {
439
468
case pq .ListenerEventConnected :
440
469
p .logger .Info (ctx , "pubsub connected to postgres" )
@@ -583,8 +612,8 @@ func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) {
583
612
}
584
613
585
614
// New creates a new Pubsub implementation using a PostgreSQL connection.
586
- func New (startCtx context.Context , logger slog.Logger , database * sql.DB , connectURL string ) (* PGPubsub , error ) {
587
- p := newWithoutListener (logger , database )
615
+ func New (startCtx context.Context , logger slog.Logger , db * sql.DB , connectURL string ) (* PGPubsub , error ) {
616
+ p := newWithoutListener (logger , db )
588
617
if err := p .startListener (startCtx , connectURL ); err != nil {
589
618
return nil , err
590
619
}
@@ -594,11 +623,11 @@ func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connect
594
623
}
595
624
596
625
// newWithoutListener creates a new PGPubsub without creating the pqListener.
597
- func newWithoutListener (logger slog.Logger , database * sql.DB ) * PGPubsub {
626
+ func newWithoutListener (logger slog.Logger , db * sql.DB ) * PGPubsub {
598
627
return & PGPubsub {
599
628
logger : logger ,
600
629
listenDone : make (chan struct {}),
601
- db : database ,
630
+ db : db ,
602
631
queues : make (map [string ]map [uuid.UUID ]* msgQueue ),
603
632
latencyMeasurer : NewLatencyMeasurer (logger .Named ("latency-measurer" )),
604
633
0 commit comments