diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 586df2494..102b03de8 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -24,6 +24,8 @@ import ( "github.com/golang-migrate/migrate/v4/database" ) +var _ database.Driver = (*Mysql)(nil) // explicit compile time type check + func init() { database.Register("mysql", &Mysql{}) } @@ -54,20 +56,26 @@ type Mysql struct { config *Config } -// instance must have `multiStatements` set to true -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +// connection instance must have `multiStatements` set to true +func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql, error) { if config == nil { return nil, ErrNilConfig } - if err := instance.Ping(); err != nil { + if err := conn.PingContext(ctx); err != nil { return nil, err } + mx := &Mysql{ + conn: conn, + db: nil, + config: config, + } + if config.DatabaseName == "" { query := `SELECT DATABASE()` var databaseName sql.NullString - if err := instance.QueryRow(query).Scan(&databaseName); err != nil { + if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -82,21 +90,33 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config.MigrationsTable = DefaultMigrationsTable } - conn, err := instance.Conn(context.Background()) - if err != nil { + if err := mx.ensureVersionTable(); err != nil { return nil, err } - mx := &Mysql{ - conn: conn, - db: instance, - config: config, + return mx, nil +} + +// instance must have `multiStatements` set to true +func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { + ctx := context.Background() + + if err := instance.Ping(); err != nil { + return nil, err } - if err := mx.ensureVersionTable(); err != nil { + conn, err := instance.Conn(ctx) + if err != nil { return nil, err } + mx, err := WithConnection(ctx, conn, config) + if err != nil { + return nil, err + } + + mx.db = instance + return mx, nil } @@ -243,7 +263,11 @@ func (m *Mysql) Open(url string) (database.Driver, error) { func (m *Mysql) Close() error { connErr := m.conn.Close() - dbErr := m.db.Close() + var dbErr error + if m.db != nil { + dbErr = m.db.Close() + } + if connErr != nil || dbErr != nil { return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) }