From f98ad3a72a666f838c57382ad748dcd8566f083b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20CAPARROS?= Date: Thu, 17 Jun 2021 16:06:35 +0900 Subject: [PATCH 1/4] Added a method to create a mysql database from a connection object --- database/mysql/mysql.go | 67 +++++++++++++++++++++++++++++++---------- 1 file changed, 51 insertions(+), 16 deletions(-) diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 586df2494..102cf2cd8 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -54,32 +54,37 @@ type Mysql struct { config *Config } -// instance must have `multiStatements` set to true -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +// connection instance must have `multiStatements` set to true +func WithConnection(conn *sql.Conn, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig } - if err := instance.Ping(); err != nil { + mx := &Mysql{ + conn: conn, + db: nil, + config: config, + } + + if err := mx.setupDefaultConfig(); err != nil { return nil, err } - if config.DatabaseName == "" { - query := `SELECT DATABASE()` - var databaseName sql.NullString - if err := instance.QueryRow(query).Scan(&databaseName); err != nil { - return nil, &database.Error{OrigErr: err, Query: []byte(query)} - } + if err := mx.ensureVersionTable(); err != nil { + return nil, err + } - if len(databaseName.String) == 0 { - return nil, ErrNoDatabaseName - } + return mx, nil +} - config.DatabaseName = databaseName.String +// instance must have `multiStatements` set to true +func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { + if config == nil { + return nil, ErrNilConfig } - if len(config.MigrationsTable) == 0 { - config.MigrationsTable = DefaultMigrationsTable + if err := instance.Ping(); err != nil { + return nil, err } conn, err := instance.Conn(context.Background()) @@ -93,6 +98,10 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config: config, } + if err := mx.setupDefaultConfig(); err != nil { + return nil, err + } + if err := mx.ensureVersionTable(); err != nil { return nil, err } @@ -100,6 +109,28 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return mx, nil } +func (m *Mysql) setupDefaultConfig() error { + if m.config.DatabaseName == "" { + query := `SELECT DATABASE()` + var databaseName sql.NullString + if err := m.conn.QueryRowContext(context.Background(), query).Scan(&databaseName); err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} + } + + if len(databaseName.String) == 0 { + return ErrNoDatabaseName + } + + m.config.DatabaseName = databaseName.String + } + + if len(m.config.MigrationsTable) == 0 { + m.config.MigrationsTable = DefaultMigrationsTable + } + + return nil +} + // extractCustomQueryParams extracts the custom query params (ones that start with "x-") from // mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) { @@ -243,7 +274,11 @@ func (m *Mysql) Open(url string) (database.Driver, error) { func (m *Mysql) Close() error { connErr := m.conn.Close() - dbErr := m.db.Close() + var dbErr error + if m.db != nil { + dbErr = m.db.Close() + } + if connErr != nil || dbErr != nil { return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) } From 394279a90dfbc532d47af90bd26e5a9e5629941f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20CAPARROS?= Date: Tue, 29 Jun 2021 16:28:28 +0900 Subject: [PATCH 2/4] Calling WithConnection from WithInstance to de-duplicate code --- database/mysql/mysql.go | 59 ++++++++++++++--------------------------- 1 file changed, 20 insertions(+), 39 deletions(-) diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 102cf2cd8..1a47a6f05 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -55,7 +55,7 @@ type Mysql struct { } // connection instance must have `multiStatements` set to true -func WithConnection(conn *sql.Conn, config *Config) (database.Driver, error) { +func WithConnection(conn *sql.Conn, config *Config) (*Mysql, error) { if config == nil { return nil, ErrNilConfig } @@ -66,8 +66,22 @@ func WithConnection(conn *sql.Conn, config *Config) (database.Driver, error) { config: config, } - if err := mx.setupDefaultConfig(); err != nil { - return nil, err + if config.DatabaseName == "" { + query := `SELECT DATABASE()` + var databaseName sql.NullString + if err := conn.QueryRowContext(context.Background(), query).Scan(&databaseName); err != nil { + return nil, &database.Error{OrigErr: err, Query: []byte(query)} + } + + if len(databaseName.String) == 0 { + return nil, ErrNoDatabaseName + } + + config.DatabaseName = databaseName.String + } + + if len(config.MigrationsTable) == 0 { + config.MigrationsTable = DefaultMigrationsTable } if err := mx.ensureVersionTable(); err != nil { @@ -79,10 +93,6 @@ func WithConnection(conn *sql.Conn, config *Config) (database.Driver, error) { // instance must have `multiStatements` set to true func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { - if config == nil { - return nil, ErrNilConfig - } - if err := instance.Ping(); err != nil { return nil, err } @@ -92,45 +102,16 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return nil, err } - mx := &Mysql{ - conn: conn, - db: instance, - config: config, - } - - if err := mx.setupDefaultConfig(); err != nil { + mx, err := WithConnection(conn, config) + if err != nil { return nil, err } - if err := mx.ensureVersionTable(); err != nil { - return nil, err - } + mx.db = instance return mx, nil } -func (m *Mysql) setupDefaultConfig() error { - if m.config.DatabaseName == "" { - query := `SELECT DATABASE()` - var databaseName sql.NullString - if err := m.conn.QueryRowContext(context.Background(), query).Scan(&databaseName); err != nil { - return &database.Error{OrigErr: err, Query: []byte(query)} - } - - if len(databaseName.String) == 0 { - return ErrNoDatabaseName - } - - m.config.DatabaseName = databaseName.String - } - - if len(m.config.MigrationsTable) == 0 { - m.config.MigrationsTable = DefaultMigrationsTable - } - - return nil -} - // extractCustomQueryParams extracts the custom query params (ones that start with "x-") from // mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) { From 73081e2d41a5045e34ba0b043daf3d16535d9aaa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20CAPARROS?= Date: Tue, 6 Jul 2021 10:16:34 +0900 Subject: [PATCH 3/4] Adding context and ping to mysql.WithConnection --- database/mysql/mysql.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 1a47a6f05..8316bb7cc 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -55,11 +55,15 @@ type Mysql struct { } // connection instance must have `multiStatements` set to true -func WithConnection(conn *sql.Conn, config *Config) (*Mysql, error) { +func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql, error) { if config == nil { return nil, ErrNilConfig } + if err := conn.PingContext(ctx); err != nil { + return nil, err + } + mx := &Mysql{ conn: conn, db: nil, @@ -69,7 +73,7 @@ func WithConnection(conn *sql.Conn, config *Config) (*Mysql, error) { if config.DatabaseName == "" { query := `SELECT DATABASE()` var databaseName sql.NullString - if err := conn.QueryRowContext(context.Background(), query).Scan(&databaseName); err != nil { + if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -93,16 +97,18 @@ func WithConnection(conn *sql.Conn, config *Config) (*Mysql, error) { // instance must have `multiStatements` set to true func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { + ctx := context.Background() + if err := instance.Ping(); err != nil { return nil, err } - conn, err := instance.Conn(context.Background()) + conn, err := instance.Conn(ctx) if err != nil { return nil, err } - mx, err := WithConnection(conn, config) + mx, err := WithConnection(ctx, conn, config) if err != nil { return nil, err } From 4bfe1920aae0ca6d26b46c6de4f0afe03c43c67a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20CAPARROS?= Date: Tue, 6 Jul 2021 10:18:21 +0900 Subject: [PATCH 4/4] Interface type check at compile time --- database/mysql/mysql.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 8316bb7cc..102b03de8 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -24,6 +24,8 @@ import ( "github.com/golang-migrate/migrate/v4/database" ) +var _ database.Driver = (*Mysql)(nil) // explicit compile time type check + func init() { database.Register("mysql", &Mysql{}) }