Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ded612d

Browse files
authored
fix: use authenticated urls for pubsub (coder#14261)
1 parent 6914862 commit ded612d

File tree

9 files changed

+290
-14
lines changed

9 files changed

+290
-14
lines changed

coderd/database/awsiamrds/awsiamrds.go

+51-1
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,21 @@ import (
1010
"github.com/aws/aws-sdk-go-v2/aws"
1111
"github.com/aws/aws-sdk-go-v2/config"
1212
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
13+
"github.com/lib/pq"
1314
"golang.org/x/xerrors"
15+
16+
"github.com/coder/coder/v2/coderd/database"
1417
)
1518

1619
type awsIamRdsDriver struct {
1720
parent driver.Driver
1821
cfg aws.Config
1922
}
2023

21-
var _ driver.Driver = &awsIamRdsDriver{}
24+
var (
25+
_ driver.Driver = &awsIamRdsDriver{}
26+
_ database.ConnectorCreator = &awsIamRdsDriver{}
27+
)
2228

2329
// Register initializes and registers our aws iam rds wrapped database driver.
2430
func Register(ctx context.Context, parentName string) (string, error) {
@@ -65,6 +71,16 @@ func (d *awsIamRdsDriver) Open(name string) (driver.Conn, error) {
6571
return conn, nil
6672
}
6773

74+
// Connector returns a driver.Connector that fetches a new authentication token for each connection.
75+
func (d *awsIamRdsDriver) Connector(name string) (driver.Connector, error) {
76+
connector := &connector{
77+
url: name,
78+
cfg: d.cfg,
79+
}
80+
81+
return connector, nil
82+
}
83+
6884
func getAuthenticatedURL(cfg aws.Config, dbURL string) (string, error) {
6985
nURL, err := url.Parse(dbURL)
7086
if err != nil {
@@ -82,3 +98,37 @@ func getAuthenticatedURL(cfg aws.Config, dbURL string) (string, error) {
8298

8399
return nURL.String(), nil
84100
}
101+
102+
type connector struct {
103+
url string
104+
cfg aws.Config
105+
dialer pq.Dialer
106+
}
107+
108+
var _ database.DialerConnector = &connector{}
109+
110+
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
111+
nURL, err := getAuthenticatedURL(c.cfg, c.url)
112+
if err != nil {
113+
return nil, xerrors.Errorf("assigning authentication token to url: %w", err)
114+
}
115+
116+
nc, err := pq.NewConnector(nURL)
117+
if err != nil {
118+
return nil, xerrors.Errorf("creating new connector: %w", err)
119+
}
120+
121+
if c.dialer != nil {
122+
nc.Dialer(c.dialer)
123+
}
124+
125+
return nc.Connect(ctx)
126+
}
127+
128+
func (*connector) Driver() driver.Driver {
129+
return &pq.Driver{}
130+
}
131+
132+
func (c *connector) Dialer(dialer pq.Dialer) {
133+
c.dialer = dialer
134+
}

coderd/database/awsiamrds/awsiamrds_test.go

+25-3
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ import (
77

88
"github.com/stretchr/testify/require"
99

10+
"cdr.dev/slog"
1011
"cdr.dev/slog/sloggers/slogtest"
11-
1212
"github.com/coder/coder/v2/cli"
13-
awsrdsiam "github.com/coder/coder/v2/coderd/database/awsiamrds"
13+
"github.com/coder/coder/v2/coderd/database/awsiamrds"
14+
"github.com/coder/coder/v2/coderd/database/pubsub"
1415
"github.com/coder/coder/v2/testutil"
1516
)
1617

@@ -22,13 +23,15 @@ func TestDriver(t *testing.T) {
2223
// export DBAWSIAMRDS_TEST_URL="postgres://user@host:5432/dbname";
2324
url := os.Getenv("DBAWSIAMRDS_TEST_URL")
2425
if url == "" {
26+
t.Log("skipping test; no DBAWSIAMRDS_TEST_URL set")
2527
t.Skip()
2628
}
2729

30+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
2831
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
2932
defer cancel()
3033

31-
sqlDriver, err := awsrdsiam.Register(ctx, "postgres")
34+
sqlDriver, err := awsiamrds.Register(ctx, "postgres")
3235
require.NoError(t, err)
3336

3437
db, err := cli.ConnectToPostgres(ctx, slogtest.Make(t, nil), sqlDriver, url)
@@ -47,4 +50,23 @@ func TestDriver(t *testing.T) {
4750
var one int
4851
require.NoError(t, i.Scan(&one))
4952
require.Equal(t, 1, one)
53+
54+
ps, err := pubsub.New(ctx, logger, db, url)
55+
require.NoError(t, err)
56+
57+
gotChan := make(chan struct{})
58+
subCancel, err := ps.Subscribe("test", func(_ context.Context, _ []byte) {
59+
close(gotChan)
60+
})
61+
defer subCancel()
62+
require.NoError(t, err)
63+
64+
err = ps.Publish("test", []byte("hello"))
65+
require.NoError(t, err)
66+
67+
select {
68+
case <-gotChan:
69+
case <-ctx.Done():
70+
require.Fail(t, "timed out waiting for message")
71+
}
5072
}

coderd/database/connector.go

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package database
2+
3+
import (
4+
"database/sql/driver"
5+
6+
"github.com/lib/pq"
7+
)
8+
9+
// ConnectorCreator is a driver.Driver that can create a driver.Connector.
10+
type ConnectorCreator interface {
11+
driver.Driver
12+
Connector(name string) (driver.Connector, error)
13+
}
14+
15+
// DialerConnector is a driver.Connector that can set a pq.Dialer.
16+
type DialerConnector interface {
17+
driver.Connector
18+
Dialer(dialer pq.Dialer)
19+
}

coderd/database/dbtestutil/driver.go

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package dbtestutil
2+
3+
import (
4+
"context"
5+
"database/sql/driver"
6+
7+
"github.com/lib/pq"
8+
"golang.org/x/xerrors"
9+
10+
"github.com/coder/coder/v2/coderd/database"
11+
)
12+
13+
var _ database.DialerConnector = &Connector{}
14+
15+
type Connector struct {
16+
name string
17+
driver *Driver
18+
dialer pq.Dialer
19+
}
20+
21+
func (c *Connector) Connect(_ context.Context) (driver.Conn, error) {
22+
if c.dialer != nil {
23+
conn, err := pq.DialOpen(c.dialer, c.name)
24+
if err != nil {
25+
return nil, xerrors.Errorf("failed to dial open connection: %w", err)
26+
}
27+
28+
c.driver.Connections <- conn
29+
30+
return conn, nil
31+
}
32+
33+
conn, err := pq.Driver{}.Open(c.name)
34+
if err != nil {
35+
return nil, xerrors.Errorf("failed to open connection: %w", err)
36+
}
37+
38+
c.driver.Connections <- conn
39+
40+
return conn, nil
41+
}
42+
43+
func (c *Connector) Driver() driver.Driver {
44+
return c.driver
45+
}
46+
47+
func (c *Connector) Dialer(dialer pq.Dialer) {
48+
c.dialer = dialer
49+
}
50+
51+
type Driver struct {
52+
Connections chan driver.Conn
53+
}
54+
55+
func NewDriver() *Driver {
56+
return &Driver{
57+
Connections: make(chan driver.Conn, 1),
58+
}
59+
}
60+
61+
func (d *Driver) Connector(name string) (driver.Connector, error) {
62+
return &Connector{
63+
name: name,
64+
driver: d,
65+
}, nil
66+
}
67+
68+
func (d *Driver) Open(name string) (driver.Conn, error) {
69+
c, err := d.Connector(name)
70+
if err != nil {
71+
return nil, err
72+
}
73+
74+
return c.Connect(context.Background())
75+
}
76+
77+
func (d *Driver) Close() {
78+
close(d.Connections)
79+
}

coderd/database/pubsub/pubsub.go

+34-5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package pubsub
33
import (
44
"context"
55
"database/sql"
6+
"database/sql/driver"
67
"errors"
78
"io"
89
"net"
@@ -15,6 +16,8 @@ import (
1516
"github.com/prometheus/client_golang/prometheus"
1617
"golang.org/x/xerrors"
1718

19+
"github.com/coder/coder/v2/coderd/database"
20+
1821
"cdr.dev/slog"
1922
)
2023

@@ -432,9 +435,35 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
432435
// pq.defaultDialer uses a zero net.Dialer as well.
433436
d: net.Dialer{},
434437
}
438+
connector driver.Connector
439+
err error
435440
)
441+
442+
// Create a custom connector if the database driver supports it.
443+
connectorCreator, ok := p.db.Driver().(database.ConnectorCreator)
444+
if ok {
445+
connector, err = connectorCreator.Connector(connectURL)
446+
if err != nil {
447+
return xerrors.Errorf("create custom connector: %w", err)
448+
}
449+
} else {
450+
// use the default pq connector otherwise
451+
connector, err = pq.NewConnector(connectURL)
452+
if err != nil {
453+
return xerrors.Errorf("create pq connector: %w", err)
454+
}
455+
}
456+
457+
// Set the dialer if the connector supports it.
458+
dc, ok := connector.(database.DialerConnector)
459+
if !ok {
460+
p.logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
461+
} else {
462+
dc.Dialer(dialer)
463+
}
464+
436465
p.pgListener = pqListenerShim{
437-
Listener: pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) {
466+
Listener: pq.NewConnectorListener(connector, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) {
438467
switch t {
439468
case pq.ListenerEventConnected:
440469
p.logger.Info(ctx, "pubsub connected to postgres")
@@ -583,8 +612,8 @@ func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) {
583612
}
584613

585614
// New creates a new Pubsub implementation using a PostgreSQL connection.
586-
func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connectURL string) (*PGPubsub, error) {
587-
p := newWithoutListener(logger, database)
615+
func New(startCtx context.Context, logger slog.Logger, db *sql.DB, connectURL string) (*PGPubsub, error) {
616+
p := newWithoutListener(logger, db)
588617
if err := p.startListener(startCtx, connectURL); err != nil {
589618
return nil, err
590619
}
@@ -594,11 +623,11 @@ func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connect
594623
}
595624

596625
// newWithoutListener creates a new PGPubsub without creating the pqListener.
597-
func newWithoutListener(logger slog.Logger, database *sql.DB) *PGPubsub {
626+
func newWithoutListener(logger slog.Logger, db *sql.DB) *PGPubsub {
598627
return &PGPubsub{
599628
logger: logger,
600629
listenDone: make(chan struct{}),
601-
db: database,
630+
db: db,
602631
queues: make(map[string]map[uuid.UUID]*msgQueue),
603632
latencyMeasurer: NewLatencyMeasurer(logger.Named("latency-measurer")),
604633

0 commit comments

Comments
 (0)