Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a7ed977

Browse files
authored
chore: prevent db migrations from running on all cli commands (#15980)
1 parent 813270d commit a7ed977

File tree

8 files changed

+115
-33
lines changed

8 files changed

+115
-33
lines changed

cli/resetpassword.go

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,27 @@
33
package cli
44

55
import (
6-
"database/sql"
76
"fmt"
87

98
"golang.org/x/xerrors"
109

10+
"cdr.dev/slog"
11+
"cdr.dev/slog/sloggers/sloghuman"
12+
"github.com/coder/coder/v2/coderd/database/awsiamrds"
13+
"github.com/coder/coder/v2/codersdk"
1114
"github.com/coder/pretty"
1215
"github.com/coder/serpent"
1316

1417
"github.com/coder/coder/v2/cli/cliui"
1518
"github.com/coder/coder/v2/coderd/database"
16-
"github.com/coder/coder/v2/coderd/database/migrations"
1719
"github.com/coder/coder/v2/coderd/userpassword"
1820
)
1921

2022
func (*RootCmd) resetPassword() *serpent.Command {
21-
var postgresURL string
23+
var (
24+
postgresURL string
25+
postgresAuth string
26+
)
2227

2328
root := &serpent.Command{
2429
Use: "reset-password <username>",
@@ -27,20 +32,26 @@ func (*RootCmd) resetPassword() *serpent.Command {
2732
Handler: func(inv *serpent.Invocation) error {
2833
username := inv.Args[0]
2934

30-
sqlDB, err := sql.Open("postgres", postgresURL)
31-
if err != nil {
32-
return xerrors.Errorf("dial postgres: %w", err)
35+
logger := slog.Make(sloghuman.Sink(inv.Stdout))
36+
if ok, _ := inv.ParsedFlags().GetBool("verbose"); ok {
37+
logger = logger.Leveled(slog.LevelDebug)
3338
}
34-
defer sqlDB.Close()
35-
err = sqlDB.Ping()
36-
if err != nil {
37-
return xerrors.Errorf("ping postgres: %w", err)
39+
40+
sqlDriver := "postgres"
41+
if codersdk.PostgresAuth(postgresAuth) == codersdk.PostgresAuthAWSIAMRDS {
42+
var err error
43+
sqlDriver, err = awsiamrds.Register(inv.Context(), sqlDriver)
44+
if err != nil {
45+
return xerrors.Errorf("register aws rds iam auth: %w", err)
46+
}
3847
}
3948

40-
err = migrations.EnsureClean(sqlDB)
49+
sqlDB, err := ConnectToPostgres(inv.Context(), logger, sqlDriver, postgresURL, nil)
4150
if err != nil {
42-
return xerrors.Errorf("database needs migration: %w", err)
51+
return xerrors.Errorf("dial postgres: %w", err)
4352
}
53+
defer sqlDB.Close()
54+
4455
db := database.New(sqlDB)
4556

4657
user, err := db.GetUserByEmailOrUsername(inv.Context(), database.GetUserByEmailOrUsernameParams{
@@ -97,6 +108,14 @@ func (*RootCmd) resetPassword() *serpent.Command {
97108
Env: "CODER_PG_CONNECTION_URL",
98109
Value: serpent.StringOf(&postgresURL),
99110
},
111+
serpent.Option{
112+
Name: "Postgres Connection Auth",
113+
Description: "Type of auth to use when connecting to postgres.",
114+
Flag: "postgres-connection-auth",
115+
Env: "CODER_PG_CONNECTION_AUTH",
116+
Default: "password",
117+
Value: serpent.EnumOf(&postgresAuth, codersdk.PostgresAuthDrivers...),
118+
},
100119
}
101120

102121
return root

cli/server.go

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
697697
options.Database = dbmem.New()
698698
options.Pubsub = pubsub.NewInMemory()
699699
} else {
700-
sqlDB, dbURL, err := getPostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver)
700+
sqlDB, dbURL, err := getAndMigratePostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver)
701701
if err != nil {
702702
return xerrors.Errorf("connect to postgres: %w", err)
703703
}
@@ -2090,9 +2090,18 @@ func IsLocalhost(host string) bool {
20902090
return host == "localhost" || host == "127.0.0.1" || host == "::1"
20912091
}
20922092

2093-
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string) (sqlDB *sql.DB, err error) {
2093+
// ConnectToPostgres takes in the migration command to run on the database once
2094+
// it connects. To avoid running migrations, pass in `nil` or a no-op function.
2095+
// Regardless of the passed in migration function, if the database is not fully
2096+
// migrated, an error will be returned. This can happen if the database is on a
2097+
// future or past migration version.
2098+
//
2099+
// If no error is returned, the database is fully migrated and up to date.
2100+
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string, migrate func(db *sql.DB) error) (*sql.DB, error) {
20942101
logger.Debug(ctx, "connecting to postgresql")
20952102

2103+
var err error
2104+
var sqlDB *sql.DB
20962105
// Try to connect for 30 seconds.
20972106
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
20982107
defer cancel()
@@ -2155,9 +2164,16 @@ func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, d
21552164
}
21562165
logger.Debug(ctx, "connected to postgresql", slog.F("version", versionNum))
21572166

2158-
err = migrations.Up(sqlDB)
2167+
if migrate != nil {
2168+
err = migrate(sqlDB)
2169+
if err != nil {
2170+
return nil, xerrors.Errorf("migrate up: %w", err)
2171+
}
2172+
}
2173+
2174+
err = migrations.EnsureClean(sqlDB)
21592175
if err != nil {
2160-
return nil, xerrors.Errorf("migrate up: %w", err)
2176+
return nil, xerrors.Errorf("migrations in database: %w", err)
21612177
}
21622178
// The default is 0 but the request will fail with a 500 if the DB
21632179
// cannot accept new connections, so we try to limit that here.
@@ -2561,7 +2577,7 @@ func signalNotifyContext(ctx context.Context, inv *serpent.Invocation, sig ...os
25612577
return inv.SignalNotifyContext(ctx, sig...)
25622578
}
25632579

2564-
func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) {
2580+
func getAndMigratePostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) {
25652581
dbURL, err := escapePostgresURLUserInfo(postgresURL)
25662582
if err != nil {
25672583
return nil, "", xerrors.Errorf("escaping postgres URL: %w", err)
@@ -2574,7 +2590,7 @@ func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string,
25742590
}
25752591
}
25762592

2577-
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL)
2593+
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL, migrations.Up)
25782594
if err != nil {
25792595
return nil, "", xerrors.Errorf("connect to postgres: %w", err)
25802596
}

cli/server_createadminuser.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *serpent.Command {
7272
}
7373
}
7474

75-
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL)
75+
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL, nil)
7676
if err != nil {
7777
return xerrors.Errorf("connect to postgres: %w", err)
7878
}

cli/server_test.go

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ import (
3838
"tailscale.com/derp/derphttp"
3939
"tailscale.com/types/key"
4040

41+
"cdr.dev/slog/sloggers/slogtest"
4142
"github.com/coder/coder/v2/cli"
4243
"github.com/coder/coder/v2/cli/clitest"
4344
"github.com/coder/coder/v2/cli/config"
4445
"github.com/coder/coder/v2/coderd/coderdtest"
4546
"github.com/coder/coder/v2/coderd/database/dbtestutil"
47+
"github.com/coder/coder/v2/coderd/database/migrations"
4648
"github.com/coder/coder/v2/coderd/httpapi"
4749
"github.com/coder/coder/v2/coderd/telemetry"
4850
"github.com/coder/coder/v2/codersdk"
@@ -1828,20 +1830,51 @@ func TestConnectToPostgres(t *testing.T) {
18281830
if !dbtestutil.WillUsePostgres() {
18291831
t.Skip("this test does not make sense without postgres")
18301832
}
1831-
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
1832-
t.Cleanup(cancel)
18331833

1834-
log := testutil.Logger(t)
1834+
t.Run("Migrate", func(t *testing.T) {
1835+
t.Parallel()
18351836

1836-
dbURL, err := dbtestutil.Open(t)
1837-
require.NoError(t, err)
1837+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
1838+
t.Cleanup(cancel)
18381839

1839-
sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL)
1840-
require.NoError(t, err)
1841-
t.Cleanup(func() {
1842-
_ = sqlDB.Close()
1840+
log := testutil.Logger(t)
1841+
1842+
dbURL, err := dbtestutil.Open(t)
1843+
require.NoError(t, err)
1844+
1845+
sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, migrations.Up)
1846+
require.NoError(t, err)
1847+
t.Cleanup(func() {
1848+
_ = sqlDB.Close()
1849+
})
1850+
require.NoError(t, sqlDB.PingContext(ctx))
1851+
})
1852+
1853+
t.Run("NoMigrate", func(t *testing.T) {
1854+
t.Parallel()
1855+
1856+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
1857+
t.Cleanup(cancel)
1858+
1859+
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
1860+
1861+
dbURL, err := dbtestutil.Open(t)
1862+
require.NoError(t, err)
1863+
1864+
okDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil)
1865+
require.NoError(t, err)
1866+
defer okDB.Close()
1867+
1868+
// Set the migration number forward
1869+
_, err = okDB.Exec(`UPDATE schema_migrations SET version = version + 1`)
1870+
require.NoError(t, err)
1871+
1872+
_, err = cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil)
1873+
require.Error(t, err)
1874+
require.ErrorContains(t, err, "database needs migration")
1875+
1876+
require.NoError(t, okDB.PingContext(ctx))
18431877
})
1844-
require.NoError(t, sqlDB.PingContext(ctx))
18451878
}
18461879

18471880
func TestServer_InvalidDERP(t *testing.T) {

cli/testdata/coder_reset-password_--help.golden

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ USAGE:
66
Directly connect to the database to reset a user's password
77

88
OPTIONS:
9+
--postgres-connection-auth password|awsiamrds, $CODER_PG_CONNECTION_AUTH (default: password)
10+
Type of auth to use when connecting to postgres.
11+
912
--postgres-url string, $CODER_PG_CONNECTION_URL
1013
URL of a PostgreSQL database to connect to.
1114

coderd/database/awsiamrds/awsiamrds_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/coder/coder/v2/cli"
1111
"github.com/coder/coder/v2/coderd/database/awsiamrds"
12+
"github.com/coder/coder/v2/coderd/database/migrations"
1213
"github.com/coder/coder/v2/coderd/database/pubsub"
1314
"github.com/coder/coder/v2/testutil"
1415
)
@@ -32,7 +33,7 @@ func TestDriver(t *testing.T) {
3233
sqlDriver, err := awsiamrds.Register(ctx, "postgres")
3334
require.NoError(t, err)
3435

35-
db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url)
36+
db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url, migrations.Up)
3637
require.NoError(t, err)
3738
defer func() {
3839
_ = db.Close()

docs/reference/cli/reset-password.md

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

enterprise/cli/server_dbcrypt.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ func (*RootCmd) dbcryptRotateCmd() *serpent.Command {
9898
}
9999
}
100100

101-
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
101+
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
102102
if err != nil {
103103
return xerrors.Errorf("connect to postgres: %w", err)
104104
}
@@ -163,7 +163,7 @@ func (*RootCmd) dbcryptDecryptCmd() *serpent.Command {
163163
}
164164
}
165165

166-
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
166+
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
167167
if err != nil {
168168
return xerrors.Errorf("connect to postgres: %w", err)
169169
}
@@ -219,7 +219,7 @@ Are you sure you want to continue?`
219219
}
220220
}
221221

222-
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
222+
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
223223
if err != nil {
224224
return xerrors.Errorf("connect to postgres: %w", err)
225225
}

0 commit comments

Comments
 (0)