Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 9c74819

Browse files
committed
chore: switch from control flag to migration func
Add unit test for missing migration
1 parent 4241993 commit 9c74819

File tree

6 files changed

+66
-29
lines changed

6 files changed

+66
-29
lines changed

cli/resetpassword.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func (*RootCmd) resetPassword() *serpent.Command {
4646
}
4747
}
4848

49-
sqlDB, err := ConnectToPostgres(inv.Context(), logger, false, sqlDriver, postgresURL)
49+
sqlDB, err := ConnectToPostgres(inv.Context(), logger, sqlDriver, postgresURL, nil)
5050
if err != nil {
5151
return xerrors.Errorf("dial postgres: %w", err)
5252
}

cli/server.go

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2090,15 +2090,18 @@ func IsLocalhost(host string) bool {
20902090
return host == "localhost" || host == "127.0.0.1" || host == "::1"
20912091
}
20922092

2093-
// ConnectToPostgres takes a control flag "migrate". If true, `migrations.Up` will be applied
2094-
// to the database, potentially making schema changes.
2095-
// If set to false, no database changes will be applied, however the migration version
2096-
// will be checked. If the database is not fully up to date with its migrations, then
2097-
// an error will be returned.
2098-
// nolint:revive // 'migrate' is a control flag.
2099-
func ConnectToPostgres(ctx context.Context, logger slog.Logger, migrate bool, driver string, dbURL string) (sqlDB *sql.DB, err error) {
2093+
// ConnectToPostgres takes in the migration command to run on the database once
2094+
// it connects. To avoid running migrations, pass in `nil` or a no-op function.
2095+
// Regardless of the passed in migration function, if the database is not fully
2096+
// migrated, an error will be returned. This can happen if the database is on a
2097+
// future or past migration version.
2098+
//
2099+
// If no error is returned, the database is fully migrated and up to date.
2100+
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string, migrate func(db *sql.DB) error) (*sql.DB, error) {
21002101
logger.Debug(ctx, "connecting to postgresql")
21012102

2103+
var err error
2104+
var sqlDB *sql.DB
21022105
// Try to connect for 30 seconds.
21032106
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
21042107
defer cancel()
@@ -2161,13 +2164,16 @@ func ConnectToPostgres(ctx context.Context, logger slog.Logger, migrate bool, dr
21612164
}
21622165
logger.Debug(ctx, "connected to postgresql", slog.F("version", versionNum))
21632166

2164-
if migrate {
2165-
err = migrations.Up(sqlDB)
2166-
} else {
2167-
err = migrations.EnsureClean(sqlDB)
2167+
if migrate != nil {
2168+
err = migrate(sqlDB)
2169+
if err != nil {
2170+
return nil, xerrors.Errorf("migrate up: %w", err)
2171+
}
21682172
}
2173+
2174+
err = migrations.EnsureClean(sqlDB)
21692175
if err != nil {
2170-
return nil, xerrors.Errorf("migrate up: %w", err)
2176+
return nil, xerrors.Errorf("migrations in database: %w", err)
21712177
}
21722178
// The default is 0 but the request will fail with a 500 if the DB
21732179
// cannot accept new connections, so we try to limit that here.
@@ -2584,7 +2590,7 @@ func getAndMigratePostgresDB(ctx context.Context, logger slog.Logger, postgresUR
25842590
}
25852591
}
25862592

2587-
sqlDB, err := ConnectToPostgres(ctx, logger, true, sqlDriver, dbURL)
2593+
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL, migrations.Up)
25882594
if err != nil {
25892595
return nil, "", xerrors.Errorf("connect to postgres: %w", err)
25902596
}

cli/server_createadminuser.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *serpent.Command {
7272
}
7373
}
7474

75-
sqlDB, err := ConnectToPostgres(ctx, logger, false, sqlDriver, newUserDBURL)
75+
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL, nil)
7676
if err != nil {
7777
return xerrors.Errorf("connect to postgres: %w", err)
7878
}

cli/server_test.go

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ import (
3838
"tailscale.com/derp/derphttp"
3939
"tailscale.com/types/key"
4040

41+
"cdr.dev/slog/sloggers/slogtest"
4142
"github.com/coder/coder/v2/cli"
4243
"github.com/coder/coder/v2/cli/clitest"
4344
"github.com/coder/coder/v2/cli/config"
4445
"github.com/coder/coder/v2/coderd/coderdtest"
4546
"github.com/coder/coder/v2/coderd/database/dbtestutil"
47+
"github.com/coder/coder/v2/coderd/database/migrations"
4648
"github.com/coder/coder/v2/coderd/httpapi"
4749
"github.com/coder/coder/v2/coderd/telemetry"
4850
"github.com/coder/coder/v2/codersdk"
@@ -1828,20 +1830,48 @@ func TestConnectToPostgres(t *testing.T) {
18281830
if !dbtestutil.WillUsePostgres() {
18291831
t.Skip("this test does not make sense without postgres")
18301832
}
1831-
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
1832-
t.Cleanup(cancel)
18331833

1834-
log := testutil.Logger(t)
1834+
t.Run("Migrate", func(t *testing.T) {
1835+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
1836+
t.Cleanup(cancel)
18351837

1836-
dbURL, err := dbtestutil.Open(t)
1837-
require.NoError(t, err)
1838+
log := testutil.Logger(t)
18381839

1839-
sqlDB, err := cli.ConnectToPostgres(ctx, log, true, "postgres", dbURL)
1840-
require.NoError(t, err)
1841-
t.Cleanup(func() {
1842-
_ = sqlDB.Close()
1840+
dbURL, err := dbtestutil.Open(t)
1841+
require.NoError(t, err)
1842+
1843+
sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, migrations.Up)
1844+
require.NoError(t, err)
1845+
t.Cleanup(func() {
1846+
_ = sqlDB.Close()
1847+
})
1848+
require.NoError(t, sqlDB.PingContext(ctx))
18431849
})
1844-
require.NoError(t, sqlDB.PingContext(ctx))
1850+
1851+
t.Run("NoMigrate", func(t *testing.T) {
1852+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
1853+
t.Cleanup(cancel)
1854+
1855+
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
1856+
1857+
dbURL, err := dbtestutil.Open(t)
1858+
require.NoError(t, err)
1859+
1860+
okDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil)
1861+
require.NoError(t, err)
1862+
defer okDB.Close()
1863+
1864+
// Set the migration number forward
1865+
_, err = okDB.Exec(`UPDATE schema_migrations SET version = version + 1`)
1866+
require.NoError(t, err)
1867+
1868+
_, err = cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil)
1869+
require.Error(t, err)
1870+
require.ErrorContains(t, err, "database needs migration")
1871+
1872+
require.NoError(t, okDB.PingContext(ctx))
1873+
})
1874+
18451875
}
18461876

18471877
func TestServer_InvalidDERP(t *testing.T) {

coderd/database/awsiamrds/awsiamrds_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/coder/coder/v2/cli"
1111
"github.com/coder/coder/v2/coderd/database/awsiamrds"
12+
"github.com/coder/coder/v2/coderd/database/migrations"
1213
"github.com/coder/coder/v2/coderd/database/pubsub"
1314
"github.com/coder/coder/v2/testutil"
1415
)
@@ -32,7 +33,7 @@ func TestDriver(t *testing.T) {
3233
sqlDriver, err := awsiamrds.Register(ctx, "postgres")
3334
require.NoError(t, err)
3435

35-
db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), true, sqlDriver, url)
36+
db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url, migrations.Up)
3637
require.NoError(t, err)
3738
defer func() {
3839
_ = db.Close()

enterprise/cli/server_dbcrypt.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ func (*RootCmd) dbcryptRotateCmd() *serpent.Command {
9898
}
9999
}
100100

101-
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, false, sqlDriver, flags.PostgresURL)
101+
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
102102
if err != nil {
103103
return xerrors.Errorf("connect to postgres: %w", err)
104104
}
@@ -163,7 +163,7 @@ func (*RootCmd) dbcryptDecryptCmd() *serpent.Command {
163163
}
164164
}
165165

166-
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, false, sqlDriver, flags.PostgresURL)
166+
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
167167
if err != nil {
168168
return xerrors.Errorf("connect to postgres: %w", err)
169169
}
@@ -219,7 +219,7 @@ Are you sure you want to continue?`
219219
}
220220
}
221221

222-
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, false, sqlDriver, flags.PostgresURL)
222+
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
223223
if err != nil {
224224
return xerrors.Errorf("connect to postgres: %w", err)
225225
}

0 commit comments

Comments
 (0)