diff --git a/cli/resetpassword.go b/cli/resetpassword.go index 2aacc8a6e6c44..f77ed81d14db4 100644 --- a/cli/resetpassword.go +++ b/cli/resetpassword.go @@ -3,22 +3,27 @@ package cli import ( - "database/sql" "fmt" "golang.org/x/xerrors" + "cdr.dev/slog" + "cdr.dev/slog/sloggers/sloghuman" + "github.com/coder/coder/v2/coderd/database/awsiamrds" + "github.com/coder/coder/v2/codersdk" "github.com/coder/pretty" "github.com/coder/serpent" "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/migrations" "github.com/coder/coder/v2/coderd/userpassword" ) func (*RootCmd) resetPassword() *serpent.Command { - var postgresURL string + var ( + postgresURL string + postgresAuth string + ) root := &serpent.Command{ Use: "reset-password ", @@ -27,20 +32,26 @@ func (*RootCmd) resetPassword() *serpent.Command { Handler: func(inv *serpent.Invocation) error { username := inv.Args[0] - sqlDB, err := sql.Open("postgres", postgresURL) - if err != nil { - return xerrors.Errorf("dial postgres: %w", err) + logger := slog.Make(sloghuman.Sink(inv.Stdout)) + if ok, _ := inv.ParsedFlags().GetBool("verbose"); ok { + logger = logger.Leveled(slog.LevelDebug) } - defer sqlDB.Close() - err = sqlDB.Ping() - if err != nil { - return xerrors.Errorf("ping postgres: %w", err) + + sqlDriver := "postgres" + if codersdk.PostgresAuth(postgresAuth) == codersdk.PostgresAuthAWSIAMRDS { + var err error + sqlDriver, err = awsiamrds.Register(inv.Context(), sqlDriver) + if err != nil { + return xerrors.Errorf("register aws rds iam auth: %w", err) + } } - err = migrations.EnsureClean(sqlDB) + sqlDB, err := ConnectToPostgres(inv.Context(), logger, sqlDriver, postgresURL, nil) if err != nil { - return xerrors.Errorf("database needs migration: %w", err) + return xerrors.Errorf("dial postgres: %w", err) } + defer sqlDB.Close() + db := database.New(sqlDB) user, err := db.GetUserByEmailOrUsername(inv.Context(), database.GetUserByEmailOrUsernameParams{ @@ -97,6 +108,14 @@ func (*RootCmd) resetPassword() *serpent.Command { Env: "CODER_PG_CONNECTION_URL", Value: serpent.StringOf(&postgresURL), }, + serpent.Option{ + Name: "Postgres Connection Auth", + Description: "Type of auth to use when connecting to postgres.", + Flag: "postgres-connection-auth", + Env: "CODER_PG_CONNECTION_AUTH", + Default: "password", + Value: serpent.EnumOf(&postgresAuth, codersdk.PostgresAuthDrivers...), + }, } return root diff --git a/cli/server.go b/cli/server.go index ff8b2963e0eb4..9bb4cfb0a72f2 100644 --- a/cli/server.go +++ b/cli/server.go @@ -697,7 +697,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. options.Database = dbmem.New() options.Pubsub = pubsub.NewInMemory() } else { - sqlDB, dbURL, err := getPostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver) + sqlDB, dbURL, err := getAndMigratePostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver) if err != nil { return xerrors.Errorf("connect to postgres: %w", err) } @@ -2090,9 +2090,18 @@ func IsLocalhost(host string) bool { return host == "localhost" || host == "127.0.0.1" || host == "::1" } -func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string) (sqlDB *sql.DB, err error) { +// ConnectToPostgres takes in the migration command to run on the database once +// it connects. To avoid running migrations, pass in `nil` or a no-op function. +// Regardless of the passed in migration function, if the database is not fully +// migrated, an error will be returned. This can happen if the database is on a +// future or past migration version. +// +// If no error is returned, the database is fully migrated and up to date. +func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string, migrate func(db *sql.DB) error) (*sql.DB, error) { logger.Debug(ctx, "connecting to postgresql") + var err error + var sqlDB *sql.DB // Try to connect for 30 seconds. ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -2155,9 +2164,16 @@ func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, d } logger.Debug(ctx, "connected to postgresql", slog.F("version", versionNum)) - err = migrations.Up(sqlDB) + if migrate != nil { + err = migrate(sqlDB) + if err != nil { + return nil, xerrors.Errorf("migrate up: %w", err) + } + } + + err = migrations.EnsureClean(sqlDB) if err != nil { - return nil, xerrors.Errorf("migrate up: %w", err) + return nil, xerrors.Errorf("migrations in database: %w", err) } // The default is 0 but the request will fail with a 500 if the DB // cannot accept new connections, so we try to limit that here. @@ -2561,7 +2577,7 @@ func signalNotifyContext(ctx context.Context, inv *serpent.Invocation, sig ...os return inv.SignalNotifyContext(ctx, sig...) } -func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) { +func getAndMigratePostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) { dbURL, err := escapePostgresURLUserInfo(postgresURL) if err != nil { return nil, "", xerrors.Errorf("escaping postgres URL: %w", err) @@ -2574,7 +2590,7 @@ func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, } } - sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL) + sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL, migrations.Up) if err != nil { return nil, "", xerrors.Errorf("connect to postgres: %w", err) } diff --git a/cli/server_createadminuser.go b/cli/server_createadminuser.go index 7ef95e7e093e6..ed9c7b9bcc921 100644 --- a/cli/server_createadminuser.go +++ b/cli/server_createadminuser.go @@ -72,7 +72,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *serpent.Command { } } - sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL) + sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL, nil) if err != nil { return xerrors.Errorf("connect to postgres: %w", err) } diff --git a/cli/server_test.go b/cli/server_test.go index 9ba963d484548..0dba63e7c2fe3 100644 --- a/cli/server_test.go +++ b/cli/server_test.go @@ -38,11 +38,13 @@ import ( "tailscale.com/derp/derphttp" "tailscale.com/types/key" + "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/cli" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/config" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/migrations" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/telemetry" "github.com/coder/coder/v2/codersdk" @@ -1828,20 +1830,51 @@ func TestConnectToPostgres(t *testing.T) { if !dbtestutil.WillUsePostgres() { t.Skip("this test does not make sense without postgres") } - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - t.Cleanup(cancel) - log := testutil.Logger(t) + t.Run("Migrate", func(t *testing.T) { + t.Parallel() - dbURL, err := dbtestutil.Open(t) - require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + t.Cleanup(cancel) - sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL) - require.NoError(t, err) - t.Cleanup(func() { - _ = sqlDB.Close() + log := testutil.Logger(t) + + dbURL, err := dbtestutil.Open(t) + require.NoError(t, err) + + sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, migrations.Up) + require.NoError(t, err) + t.Cleanup(func() { + _ = sqlDB.Close() + }) + require.NoError(t, sqlDB.PingContext(ctx)) + }) + + t.Run("NoMigrate", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + t.Cleanup(cancel) + + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + + dbURL, err := dbtestutil.Open(t) + require.NoError(t, err) + + okDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil) + require.NoError(t, err) + defer okDB.Close() + + // Set the migration number forward + _, err = okDB.Exec(`UPDATE schema_migrations SET version = version + 1`) + require.NoError(t, err) + + _, err = cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil) + require.Error(t, err) + require.ErrorContains(t, err, "database needs migration") + + require.NoError(t, okDB.PingContext(ctx)) }) - require.NoError(t, sqlDB.PingContext(ctx)) } func TestServer_InvalidDERP(t *testing.T) { diff --git a/cli/testdata/coder_reset-password_--help.golden b/cli/testdata/coder_reset-password_--help.golden index a7d53df12ad90..ccefb412d8fb7 100644 --- a/cli/testdata/coder_reset-password_--help.golden +++ b/cli/testdata/coder_reset-password_--help.golden @@ -6,6 +6,9 @@ USAGE: Directly connect to the database to reset a user's password OPTIONS: + --postgres-connection-auth password|awsiamrds, $CODER_PG_CONNECTION_AUTH (default: password) + Type of auth to use when connecting to postgres. + --postgres-url string, $CODER_PG_CONNECTION_URL URL of a PostgreSQL database to connect to. diff --git a/coderd/database/awsiamrds/awsiamrds_test.go b/coderd/database/awsiamrds/awsiamrds_test.go index 844b85b119850..d52da4aab7bfe 100644 --- a/coderd/database/awsiamrds/awsiamrds_test.go +++ b/coderd/database/awsiamrds/awsiamrds_test.go @@ -9,6 +9,7 @@ import ( "github.com/coder/coder/v2/cli" "github.com/coder/coder/v2/coderd/database/awsiamrds" + "github.com/coder/coder/v2/coderd/database/migrations" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/testutil" ) @@ -32,7 +33,7 @@ func TestDriver(t *testing.T) { sqlDriver, err := awsiamrds.Register(ctx, "postgres") require.NoError(t, err) - db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url) + db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url, migrations.Up) require.NoError(t, err) defer func() { _ = db.Close() diff --git a/docs/reference/cli/reset-password.md b/docs/reference/cli/reset-password.md index 75e94821cdb31..ada9ad7e7db3e 100644 --- a/docs/reference/cli/reset-password.md +++ b/docs/reference/cli/reset-password.md @@ -19,3 +19,13 @@ coder reset-password [flags] | Environment | $CODER_PG_CONNECTION_URL | URL of a PostgreSQL database to connect to. + +### --postgres-connection-auth + +| | | +|-------------|----------------------------------------| +| Type | password\|awsiamrds | +| Environment | $CODER_PG_CONNECTION_AUTH | +| Default | password | + +Type of auth to use when connecting to postgres. diff --git a/enterprise/cli/server_dbcrypt.go b/enterprise/cli/server_dbcrypt.go index 148303f85402d..72ac6cc6e82b0 100644 --- a/enterprise/cli/server_dbcrypt.go +++ b/enterprise/cli/server_dbcrypt.go @@ -98,7 +98,7 @@ func (*RootCmd) dbcryptRotateCmd() *serpent.Command { } } - sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL) + sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil) if err != nil { return xerrors.Errorf("connect to postgres: %w", err) } @@ -163,7 +163,7 @@ func (*RootCmd) dbcryptDecryptCmd() *serpent.Command { } } - sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL) + sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil) if err != nil { return xerrors.Errorf("connect to postgres: %w", err) } @@ -219,7 +219,7 @@ Are you sure you want to continue?` } } - sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL) + sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil) if err != nil { return xerrors.Errorf("connect to postgres: %w", err) }