diff --git a/coderd/database/dbtestutil/dbpool/main.go b/coderd/database/dbtestutil/dbpool/main.go new file mode 100644 index 0000000000000..1e264273e911e --- /dev/null +++ b/coderd/database/dbtestutil/dbpool/main.go @@ -0,0 +1,31 @@ +package dbpool + +import "net/rpc" + +type Client struct { + rpcClient *rpc.Client +} + +func NewClient(addr string) (*Client, error) { + rpcClient, err := rpc.DialHTTP("tcp", addr) + if err != nil { + return nil, err + } + return &Client{rpcClient: rpcClient}, nil +} + +func (c *Client) GetDB() (string, error) { + var arg int + var reply string + err := c.rpcClient.Call("DBPool.GetDB", &arg, &reply) + return reply, err +} + +func (c *Client) DisposeDB(dbURL string) error { + var reply int + return c.rpcClient.Call("DBPool.DisposeDB", &dbURL, &reply) +} + +func (c *Client) Close() error { + return c.rpcClient.Close() +} diff --git a/coderd/database/dbtestutil/postgres.go b/coderd/database/dbtestutil/postgres.go index c0b35a03529ca..169a21668d315 100644 --- a/coderd/database/dbtestutil/postgres.go +++ b/coderd/database/dbtestutil/postgres.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "net" + "net/url" "os" "path/filepath" "strconv" @@ -21,6 +22,7 @@ import ( "github.com/ory/dockertest/v3/docker" "golang.org/x/xerrors" + "github.com/coder/coder/v2/coderd/database/dbtestutil/dbpool" "github.com/coder/coder/v2/coderd/database/migrations" "github.com/coder/coder/v2/cryptorand" "github.com/coder/retry" @@ -38,6 +40,39 @@ func (p ConnectionParams) DSN() string { return fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", p.Username, p.Password, p.Host, p.Port, p.DBName) } +func ParseDSN(dsn string) (ConnectionParams, error) { + u, err := url.Parse(dsn) + if err != nil { + return ConnectionParams{}, xerrors.Errorf("parse dsn: %w", err) + } + + if u.Scheme != "postgres" { + return ConnectionParams{}, xerrors.Errorf("invalid dsn scheme: %s", u.Scheme) + } + + var params ConnectionParams + if u.User != nil { + params.Username = u.User.Username() + params.Password, _ = u.User.Password() + } + + params.Host = u.Hostname() + params.Port = u.Port() + if params.Port == "" { + // Default PostgreSQL port + params.Port = "5432" + } + + // The path includes a leading slash, remove it. + if len(u.Path) > 1 { + params.DBName = u.Path[1:] + } else { + return ConnectionParams{}, xerrors.New("database name missing in dsn") + } + + return params, nil +} + // These variables are global because all tests share them. var ( connectionParamsInitOnce sync.Once @@ -138,12 +173,68 @@ type TBSubset interface { Logf(format string, args ...any) } +func RemoveDB(t TBSubset, dbName string) error { + cleanupDbURL := defaultConnectionParams.DSN() + cleanupConn, err := sql.Open("postgres", cleanupDbURL) + if err != nil { + return xerrors.Errorf("cleanup database %q: failed to connect to postgres: %w", dbName, err) + } + defer func() { + if err := cleanupConn.Close(); err != nil { + t.Logf("cleanup database %q: failed to close connection: %s\n", dbName, err.Error()) + } + }() + _, err = cleanupConn.Exec("DROP DATABASE " + dbName + ";") + if err != nil { + return xerrors.Errorf("cleanup database %q: failed to drop database: %w", dbName, err) + } + return nil +} + +func getDBPoolClient() (*dbpool.Client, error) { + dbpoolURL := os.Getenv("DBPOOL") + if dbpoolURL == "" { + return nil, nil //nolint:nilnil + } + client, err := dbpool.NewClient(dbpoolURL) + if err != nil { + return nil, xerrors.Errorf("create db pool client: %w", err) + } + return client, nil +} + // Open creates a new PostgreSQL database instance. // If there's a database running at localhost:5432, it will use that. // Otherwise, it will start a new postgres container. func Open(t TBSubset, opts ...OpenOption) (string, error) { t.Helper() + openOptions := OpenOptions{} + for _, opt := range opts { + opt(&openOptions) + } + + if openOptions.DBFrom == nil { + dbPoolClient, err := getDBPoolClient() + if err != nil { + return "", xerrors.Errorf("get db pool client: %w", err) + } + if dbPoolClient != nil { + dbURL, err := dbPoolClient.GetDB() + if err != nil { + return "", xerrors.Errorf("get db from pool: %w", err) + } + t.Cleanup(func() { + defer dbPoolClient.Close() + err := dbPoolClient.DisposeDB(dbURL) + if err != nil { + t.Logf("cleanup database %s: failed to dispose db: %+v\n", dbURL, err) + } + }) + return dbURL, nil + } + } + connectionParamsInitOnce.Do(func() { errDefaultConnectionParamsInit = initDefaultConnection(t) }) @@ -151,11 +242,6 @@ func Open(t TBSubset, opts ...OpenOption) (string, error) { return "", xerrors.Errorf("init default connection params: %w", errDefaultConnectionParamsInit) } - openOptions := OpenOptions{} - for _, opt := range opts { - opt(&openOptions) - } - var ( username = defaultConnectionParams.Username password = defaultConnectionParams.Password @@ -182,22 +268,7 @@ func Open(t TBSubset, opts ...OpenOption) (string, error) { } t.Cleanup(func() { - cleanupDbURL := defaultConnectionParams.DSN() - cleanupConn, err := sql.Open("postgres", cleanupDbURL) - if err != nil { - t.Logf("cleanup database %q: failed to connect to postgres: %s\n", dbName, err.Error()) - return - } - defer func() { - if err := cleanupConn.Close(); err != nil { - t.Logf("cleanup database %q: failed to close connection: %s\n", dbName, err.Error()) - } - }() - _, err = cleanupConn.Exec("DROP DATABASE " + dbName + ";") - if err != nil { - t.Logf("failed to clean up database %q: %s\n", dbName, err.Error()) - return - } + RemoveDB(t, dbName) }) dsn := ConnectionParams{ diff --git a/scripts/dbpool/main.go b/scripts/dbpool/main.go new file mode 100644 index 0000000000000..5a9661d544fb5 --- /dev/null +++ b/scripts/dbpool/main.go @@ -0,0 +1,305 @@ +package main + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "net/rpc" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogjson" + + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database/dbtestutil" +) + +type mockTB struct{} + +func (*mockTB) Cleanup(_ func()) { + // noop, we won't be running cleanup +} + +func (*mockTB) Helper() { + // noop +} + +func (*mockTB) Logf(format string, args ...any) { + _, _ = fmt.Printf(format, args...) +} + +type DBPool struct { + numCleanupWorkers int + + availableDBs chan string + garbageDBs chan string + dbRequests chan struct{} + ctx context.Context + Cancel context.CancelFunc + logger *slog.Logger +} + +type DBPoolArgs struct { + PoolSize int + NumCleanupWorkers int + Logger *slog.Logger +} + +func NewDBPool(args DBPoolArgs) *DBPool { + ctx, cancel := context.WithCancel(context.Background()) + dbRequests := make(chan struct{}, args.PoolSize) + for i := 0; i < args.PoolSize; i++ { + dbRequests <- struct{}{} + } + args.Logger.Info(ctx, "starting db pool", slog.F("size", args.PoolSize), slog.F("action", "start")) + return &DBPool{ + availableDBs: make(chan string, args.PoolSize), + garbageDBs: make(chan string, args.PoolSize), + dbRequests: dbRequests, + ctx: ctx, + Cancel: cancel, + logger: args.Logger, + numCleanupWorkers: args.NumCleanupWorkers, + } +} + +func (m *DBPool) GetDB(_ *int, reply *string) error { + select { + case dbURL := <-m.availableDBs: + *reply = dbURL + m.logger.Info(m.ctx, "db lease started", slog.F("action", "GetDB"), slog.F("db_url", dbURL)) + return nil + case <-m.ctx.Done(): + return xerrors.Errorf("server context canceled while waiting for DB: %w", m.ctx.Err()) + } +} + +func (m *DBPool) DisposeDB(dbURL *string, _ *int) error { + select { + case m.garbageDBs <- *dbURL: + m.logger.Info(m.ctx, "db returned to pool for disposal", slog.F("action", "DisposeDB"), slog.F("db_url", *dbURL)) + return nil + case <-m.ctx.Done(): + return xerrors.Errorf("could not dispose DB %s, server context canceled: %w", *dbURL, m.ctx.Err()) + } +} + +func (m *DBPool) createDB() error { + t := &mockTB{} + dbURL, err := dbtestutil.Open(t) + if err != nil { + return xerrors.Errorf("open db: %w", err) + } + m.availableDBs <- dbURL + m.logger.Info(m.ctx, "created db and added to pool", slog.F("action", "createDB"), slog.F("db_url", dbURL)) + return nil +} + +func (m *DBPool) destroyDB(dbURL string) error { + t := &mockTB{} + connParams, err := dbtestutil.ParseDSN(dbURL) + if err != nil { + return xerrors.Errorf("parse dsn: %w", err) + } + if err := dbtestutil.RemoveDB(t, connParams.DBName); err != nil { + return xerrors.Errorf("remove db: %w", err) + } + m.dbRequests <- struct{}{} + m.logger.Info(m.ctx, "removed db from pool", slog.F("action", "destroyDB"), slog.F("db_url", dbURL)) + return nil +} + +func (m *DBPool) cleanup() { + wg := sync.WaitGroup{} + for range m.numCleanupWorkers { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case dbURL := <-m.availableDBs: + err := m.destroyDB(dbURL) + if err != nil { + m.logger.Error(m.ctx, "error destroying db in cleanup", slog.Error(err)) + } + default: + return + } + } + }() + } + wg.Wait() +} + +func (m *DBPool) Start(numCreateWorkers int, numDestroyWorkers int) { + wg := sync.WaitGroup{} + errChan := make(chan error, 1) + for range numCreateWorkers { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-m.ctx.Done(): + return + case <-m.dbRequests: + if err := m.createDB(); err != nil { + // we only care about the first error + select { + case errChan <- xerrors.Errorf("create db: %w", err): + default: + } + } + } + } + }() + } + for range numDestroyWorkers { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-m.ctx.Done(): + return + case dbURL := <-m.garbageDBs: + if err := m.destroyDB(dbURL); err != nil { + // we only care about the first error + select { + case errChan <- xerrors.Errorf("destroy db: %w", err): + default: + } + } + } + } + }() + } + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-m.ctx.Done(): + return + case err := <-errChan: + m.logger.Error(m.ctx, "received error over channel", slog.Error(err)) + m.Cancel() + return + } + } + }() + wg.Wait() + + m.cleanup() +} + +var errAlreadyPrinted = xerrors.New("error already printed") + +func inner(logger *slog.Logger) error { + dbPool := NewDBPool(DBPoolArgs{ + PoolSize: 250, + NumCleanupWorkers: 16, + Logger: logger, + }) + + osSignalChan := make(chan os.Signal, 1) + signal.Notify(osSignalChan, syscall.SIGINT) + + // for both errChan and shutdownSignalChan, we buffer 16 to avoid deadlocks + errChan := make(chan error, 16) + shutdownSignalChan := make(chan struct{}, 16) + dbPoolStoppedChan := make(chan struct{}) + shutdownTimeoutChan := make(chan struct{}) + + go func() { + <-osSignalChan + shutdownSignalChan <- struct{}{} + }() + go func() { + defer func() { + shutdownSignalChan <- struct{}{} + }() + l, err := net.Listen("tcp", "localhost:8080") + if err != nil { + select { + case errChan <- xerrors.Errorf("listen: %w", err): + default: + } + return + } + if err := rpc.Register(dbPool); err != nil { + select { + case errChan <- xerrors.Errorf("register db manager: %w", err): + default: + } + return + } + rpc.HandleHTTP() + server := &http.Server{ + Addr: l.Addr().String(), + ReadTimeout: 15 * time.Second, + WriteTimeout: 15 * time.Second, + } + dbPool.logger.Info(dbPool.ctx, "serving on port 8080") + if err := server.Serve(l); err != nil && err != http.ErrServerClosed { + select { + case errChan <- xerrors.Errorf("serve: %w", err): + default: + } + } + }() + go func() { + <-shutdownSignalChan + dbPool.Cancel() + time.Sleep(15 * time.Second) + close(shutdownTimeoutChan) + }() + go func() { + dbPool.Start(10, 10) + close(dbPoolStoppedChan) + }() + + select { + case <-dbPoolStoppedChan: + dbPool.logger.Info(dbPool.ctx, "cleaned up, exiting gracefully") + case <-shutdownTimeoutChan: + select { + case errChan <- xerrors.Errorf("timed out waiting for server to clean up"): + default: + } + } + + errorPrinted := false + for { + select { + case err := <-errChan: + dbPool.logger.Error(dbPool.ctx, "an error occurred", slog.Error(err)) + errorPrinted = true + default: + goto finishLine + } + } + +finishLine: + if errorPrinted { + return errAlreadyPrinted + } + return nil +} + +func main() { + logger := slog.Make(slogjson.Sink(os.Stdout)) + if err := inner(&logger); err != nil { + if !errors.Is(err, errAlreadyPrinted) { + logger.Error(context.Background(), "an error occurred, exiting", slog.Error(err)) + } + os.Exit(1) + } +} diff --git a/scripts/dbpoolclient/main.go b/scripts/dbpoolclient/main.go new file mode 100644 index 0000000000000..631eb8665fd03 --- /dev/null +++ b/scripts/dbpoolclient/main.go @@ -0,0 +1,54 @@ +package main + +import ( + "fmt" + "os" + + "github.com/coder/coder/v2/coderd/database/dbtestutil/dbpool" +) + +func main() { + if len(os.Args) < 2 { + fmt.Println("Usage: dbpoolclient [args]") + fmt.Println("Commands:") + fmt.Println(" getdb") + fmt.Println(" dispose ") + os.Exit(1) + } + + client, err := dbpool.NewClient("localhost:8080") + if err != nil { + panic(err) + } + + command := os.Args[1] + + switch command { + case "getdb": + if len(os.Args) != 2 { + fmt.Println("Usage: dbpoolclient getdb") + os.Exit(1) + } + fmt.Println("getting db") + dbURL, err := client.GetDB() + if err != nil { + panic(err) + } + fmt.Println(dbURL) + case "dispose": + if len(os.Args) != 3 { + fmt.Println("Usage: dbpoolclient dispose ") + os.Exit(1) + } + dbURL := os.Args[2] + fmt.Printf("disposing db: %s\n", dbURL) + err := client.DisposeDB(dbURL) + if err != nil { + panic(err) + } + fmt.Println("db disposed successfully") + default: + fmt.Printf("Unknown command: %s\n", command) + os.Exit(1) + } +}