Thanks to visit codestin.com
Credit goes to github.com

Skip to content

chore: prevent db migrations from running on all cli commands #15980

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions cli/resetpassword.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,27 @@
package cli

import (
"database/sql"
"fmt"

"golang.org/x/xerrors"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/v2/coderd/database/awsiamrds"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/pretty"
"github.com/coder/serpent"

"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/migrations"
"github.com/coder/coder/v2/coderd/userpassword"
)

func (*RootCmd) resetPassword() *serpent.Command {
var postgresURL string
var (
postgresURL string
postgresAuth string
)

root := &serpent.Command{
Use: "reset-password <username>",
Expand All @@ -27,20 +32,26 @@ func (*RootCmd) resetPassword() *serpent.Command {
Handler: func(inv *serpent.Invocation) error {
username := inv.Args[0]

sqlDB, err := sql.Open("postgres", postgresURL)
if err != nil {
return xerrors.Errorf("dial postgres: %w", err)
logger := slog.Make(sloghuman.Sink(inv.Stdout))
if ok, _ := inv.ParsedFlags().GetBool("verbose"); ok {
logger = logger.Leveled(slog.LevelDebug)
}
defer sqlDB.Close()
err = sqlDB.Ping()
if err != nil {
return xerrors.Errorf("ping postgres: %w", err)

sqlDriver := "postgres"
if codersdk.PostgresAuth(postgresAuth) == codersdk.PostgresAuthAWSIAMRDS {
var err error
sqlDriver, err = awsiamrds.Register(inv.Context(), sqlDriver)
if err != nil {
return xerrors.Errorf("register aws rds iam auth: %w", err)
}
}

err = migrations.EnsureClean(sqlDB)
sqlDB, err := ConnectToPostgres(inv.Context(), logger, sqlDriver, postgresURL, nil)
if err != nil {
return xerrors.Errorf("database needs migration: %w", err)
return xerrors.Errorf("dial postgres: %w", err)
}
defer sqlDB.Close()

db := database.New(sqlDB)

user, err := db.GetUserByEmailOrUsername(inv.Context(), database.GetUserByEmailOrUsernameParams{
Expand Down Expand Up @@ -97,6 +108,14 @@ func (*RootCmd) resetPassword() *serpent.Command {
Env: "CODER_PG_CONNECTION_URL",
Value: serpent.StringOf(&postgresURL),
},
serpent.Option{
Name: "Postgres Connection Auth",
Description: "Type of auth to use when connecting to postgres.",
Flag: "postgres-connection-auth",
Env: "CODER_PG_CONNECTION_AUTH",
Default: "password",
Value: serpent.EnumOf(&postgresAuth, codersdk.PostgresAuthDrivers...),
},
}

return root
Expand Down
28 changes: 22 additions & 6 deletions cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
options.Database = dbmem.New()
options.Pubsub = pubsub.NewInMemory()
} else {
sqlDB, dbURL, err := getPostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver)
sqlDB, dbURL, err := getAndMigratePostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver)
if err != nil {
return xerrors.Errorf("connect to postgres: %w", err)
}
Expand Down Expand Up @@ -2090,9 +2090,18 @@ func IsLocalhost(host string) bool {
return host == "localhost" || host == "127.0.0.1" || host == "::1"
}

func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string) (sqlDB *sql.DB, err error) {
// ConnectToPostgres takes in the migration command to run on the database once
// it connects. To avoid running migrations, pass in `nil` or a no-op function.
// Regardless of the passed in migration function, if the database is not fully
// migrated, an error will be returned. This can happen if the database is on a
// future or past migration version.
//
// If no error is returned, the database is fully migrated and up to date.
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string, migrate func(db *sql.DB) error) (*sql.DB, error) {
logger.Debug(ctx, "connecting to postgresql")

var err error
var sqlDB *sql.DB
// Try to connect for 30 seconds.
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
Expand Down Expand Up @@ -2155,9 +2164,16 @@ func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, d
}
logger.Debug(ctx, "connected to postgresql", slog.F("version", versionNum))

err = migrations.Up(sqlDB)
if migrate != nil {
err = migrate(sqlDB)
if err != nil {
return nil, xerrors.Errorf("migrate up: %w", err)
}
}

err = migrations.EnsureClean(sqlDB)
if err != nil {
return nil, xerrors.Errorf("migrate up: %w", err)
return nil, xerrors.Errorf("migrations in database: %w", err)
}
// The default is 0 but the request will fail with a 500 if the DB
// cannot accept new connections, so we try to limit that here.
Expand Down Expand Up @@ -2561,7 +2577,7 @@ func signalNotifyContext(ctx context.Context, inv *serpent.Invocation, sig ...os
return inv.SignalNotifyContext(ctx, sig...)
}

func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) {
func getAndMigratePostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) {
dbURL, err := escapePostgresURLUserInfo(postgresURL)
if err != nil {
return nil, "", xerrors.Errorf("escaping postgres URL: %w", err)
Expand All @@ -2574,7 +2590,7 @@ func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string,
}
}

sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL)
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL, migrations.Up)
if err != nil {
return nil, "", xerrors.Errorf("connect to postgres: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cli/server_createadminuser.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *serpent.Command {
}
}

sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL)
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL, nil)
if err != nil {
return xerrors.Errorf("connect to postgres: %w", err)
}
Expand Down
53 changes: 43 additions & 10 deletions cli/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ import (
"tailscale.com/derp/derphttp"
"tailscale.com/types/key"

"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/cli"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/cli/config"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/migrations"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/codersdk"
Expand Down Expand Up @@ -1828,20 +1830,51 @@ func TestConnectToPostgres(t *testing.T) {
if !dbtestutil.WillUsePostgres() {
t.Skip("this test does not make sense without postgres")
}
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
t.Cleanup(cancel)

log := testutil.Logger(t)
t.Run("Migrate", func(t *testing.T) {
t.Parallel()

dbURL, err := dbtestutil.Open(t)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
t.Cleanup(cancel)

sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL)
require.NoError(t, err)
t.Cleanup(func() {
_ = sqlDB.Close()
log := testutil.Logger(t)

dbURL, err := dbtestutil.Open(t)
require.NoError(t, err)

sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, migrations.Up)
require.NoError(t, err)
t.Cleanup(func() {
_ = sqlDB.Close()
})
require.NoError(t, sqlDB.PingContext(ctx))
})

t.Run("NoMigrate", func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
t.Cleanup(cancel)

log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})

dbURL, err := dbtestutil.Open(t)
require.NoError(t, err)

okDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil)
require.NoError(t, err)
defer okDB.Close()

// Set the migration number forward
_, err = okDB.Exec(`UPDATE schema_migrations SET version = version + 1`)
require.NoError(t, err)

_, err = cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil)
require.Error(t, err)
require.ErrorContains(t, err, "database needs migration")

require.NoError(t, okDB.PingContext(ctx))
})
require.NoError(t, sqlDB.PingContext(ctx))
}

func TestServer_InvalidDERP(t *testing.T) {
Expand Down
3 changes: 3 additions & 0 deletions cli/testdata/coder_reset-password_--help.golden
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ USAGE:
Directly connect to the database to reset a user's password

OPTIONS:
--postgres-connection-auth password|awsiamrds, $CODER_PG_CONNECTION_AUTH (default: password)
Type of auth to use when connecting to postgres.

--postgres-url string, $CODER_PG_CONNECTION_URL
URL of a PostgreSQL database to connect to.

Expand Down
3 changes: 2 additions & 1 deletion coderd/database/awsiamrds/awsiamrds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/coder/coder/v2/cli"
"github.com/coder/coder/v2/coderd/database/awsiamrds"
"github.com/coder/coder/v2/coderd/database/migrations"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/testutil"
)
Expand All @@ -32,7 +33,7 @@ func TestDriver(t *testing.T) {
sqlDriver, err := awsiamrds.Register(ctx, "postgres")
require.NoError(t, err)

db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url)
db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url, migrations.Up)
require.NoError(t, err)
defer func() {
_ = db.Close()
Expand Down
10 changes: 10 additions & 0 deletions docs/reference/cli/reset-password.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions enterprise/cli/server_dbcrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (*RootCmd) dbcryptRotateCmd() *serpent.Command {
}
}

sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
if err != nil {
return xerrors.Errorf("connect to postgres: %w", err)
}
Expand Down Expand Up @@ -163,7 +163,7 @@ func (*RootCmd) dbcryptDecryptCmd() *serpent.Command {
}
}

sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
if err != nil {
return xerrors.Errorf("connect to postgres: %w", err)
}
Expand Down Expand Up @@ -219,7 +219,7 @@ Are you sure you want to continue?`
}
}

sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
if err != nil {
return xerrors.Errorf("connect to postgres: %w", err)
}
Expand Down
Loading