Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d0af582

Browse files
postgres,pgx: fix SchemaName parameter is ignored #golang-migrate#547
1 parent 4720914 commit d0af582

File tree

4 files changed

+50
-150
lines changed

4 files changed

+50
-150
lines changed

database/pgx/pgx.go

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ type Config struct {
4646
MigrationsTable string
4747
DatabaseName string
4848
SchemaName string
49+
migrationsSchemaName string
50+
migrationsTableName string
4951
StatementTimeout time.Duration
5052
MigrationsTableQuoted bool
5153
MultiStatementEnabled bool
@@ -103,6 +105,19 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
103105
config.MigrationsTable = DefaultMigrationsTable
104106
}
105107

108+
config.migrationsSchemaName = config.SchemaName
109+
config.migrationsTableName = config.MigrationsTable
110+
if config.MigrationsTableQuoted {
111+
re := regexp.MustCompile(`"(.*?)"`)
112+
result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
113+
config.migrationsTableName = result[len(result)-1][1]
114+
if len(result) == 2 {
115+
config.migrationsSchemaName = result[0][1]
116+
} else if len(result) > 2 {
117+
return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
118+
}
119+
}
120+
106121
conn, err := instance.Conn(context.Background())
107122

108123
if err != nil {
@@ -209,7 +224,7 @@ func (p *Postgres) Lock() error {
209224
return database.ErrLocked
210225
}
211226

212-
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName)
227+
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
213228
if err != nil {
214229
return err
215230
}
@@ -229,7 +244,7 @@ func (p *Postgres) Unlock() error {
229244
return nil
230245
}
231246

232-
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName)
247+
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
233248
if err != nil {
234249
return err
235250
}
@@ -335,7 +350,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
335350
return &database.Error{OrigErr: err, Err: "transaction start failed"}
336351
}
337352

338-
query := `TRUNCATE ` + p.quoteIdentifier(p.config.MigrationsTable)
353+
query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName)
339354
if _, err := tx.Exec(query); err != nil {
340355
if errRollback := tx.Rollback(); errRollback != nil {
341356
err = multierror.Append(err, errRollback)
@@ -347,8 +362,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
347362
// empty schema version for failed down migration on the first migration
348363
// See: https://github.com/golang-migrate/migrate/issues/330
349364
if version >= 0 || (version == database.NilVersion && dirty) {
350-
query = `INSERT INTO ` + p.quoteIdentifier(p.config.MigrationsTable) +
351-
` (version, dirty) VALUES ($1, $2)`
365+
query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
352366
if _, err := tx.Exec(query, version, dirty); err != nil {
353367
if errRollback := tx.Rollback(); errRollback != nil {
354368
err = multierror.Append(err, errRollback)
@@ -365,7 +379,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
365379
}
366380

367381
func (p *Postgres) Version() (version int, dirty bool, err error) {
368-
query := `SELECT version, dirty FROM ` + p.quoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1`
382+
query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
369383
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
370384
switch {
371385
case err == sql.ErrNoRows:
@@ -415,7 +429,7 @@ func (p *Postgres) Drop() (err error) {
415429
if len(tableNames) > 0 {
416430
// delete one by one ...
417431
for _, t := range tableNames {
418-
query = `DROP TABLE IF EXISTS ` + p.quoteIdentifier(t) + ` CASCADE`
432+
query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
419433
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
420434
return &database.Error{OrigErr: err, Query: []byte(query)}
421435
}
@@ -447,27 +461,8 @@ func (p *Postgres) ensureVersionTable() (err error) {
447461
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
448462
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
449463
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
450-
var row *sql.Row
451-
tableName := p.config.MigrationsTable
452-
schemaName := ""
453-
if p.config.MigrationsTableQuoted {
454-
re := regexp.MustCompile(`"(.*?)"`)
455-
result := re.FindAllStringSubmatch(p.config.MigrationsTable, -1)
456-
tableName = result[len(result)-1][1]
457-
if len(result) == 2 {
458-
schemaName = result[0][1]
459-
} else if len(result) > 2 {
460-
return fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", p.config.MigrationsTable)
461-
}
462-
}
463-
var query string
464-
if len(schemaName) > 0 {
465-
query = `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2 LIMIT 1`
466-
row = p.conn.QueryRowContext(context.Background(), query, tableName, schemaName)
467-
} else {
468-
query = `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
469-
row = p.conn.QueryRowContext(context.Background(), query, tableName)
470-
}
464+
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
465+
row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
471466

472467
var count int
473468
err = row.Scan(&count)
@@ -479,7 +474,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
479474
return nil
480475
}
481476

482-
query = `CREATE TABLE IF NOT EXISTS ` + p.quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
477+
query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
483478
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
484479
return &database.Error{OrigErr: err, Query: []byte(query)}
485480
}
@@ -488,10 +483,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
488483
}
489484

490485
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
491-
func (p *Postgres) quoteIdentifier(name string) string {
492-
if p.config.MigrationsTableQuoted {
493-
return name
494-
}
486+
func quoteIdentifier(name string) string {
495487
end := strings.IndexRune(name, 0)
496488
if end > -1 {
497489
name = name[:end]

database/pgx/pgx_test.go

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -760,43 +760,3 @@ func Test_computeLineFromPos(t *testing.T) {
760760
})
761761
}
762762
}
763-
764-
func Test_quoteIdentifier(t *testing.T) {
765-
testcases := []struct {
766-
migrationsTableQuoted bool
767-
migrationsTable string
768-
expected string
769-
}{
770-
{
771-
false,
772-
"schema_name.table_name",
773-
"\"schema_name.table_name\"",
774-
},
775-
{
776-
false,
777-
"table_name",
778-
"\"table_name\"",
779-
},
780-
{
781-
true,
782-
"\"schema_name\".\"table_name\"",
783-
"\"schema_name\".\"table_name\"",
784-
},
785-
{
786-
true,
787-
"\"table_name\"",
788-
"\"table_name\"",
789-
},
790-
}
791-
p := &Postgres{
792-
config: &Config{},
793-
}
794-
795-
for _, tc := range testcases {
796-
p.config.MigrationsTableQuoted = tc.migrationsTableQuoted
797-
got := p.quoteIdentifier(tc.migrationsTable)
798-
if tc.expected != got {
799-
t.Fatalf("expected %s but got %s", tc.expected, got)
800-
}
801-
}
802-
}

database/postgres/postgres.go

Lines changed: 25 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ type Config struct {
4747
MultiStatementEnabled bool
4848
DatabaseName string
4949
SchemaName string
50+
migrationsSchemaName string
51+
migrationsTableName string
5052
StatementTimeout time.Duration
5153
MultiStatementMaxSize int
5254
}
@@ -101,6 +103,20 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
101103
if len(config.MigrationsTable) == 0 {
102104
config.MigrationsTable = DefaultMigrationsTable
103105
}
106+
107+
config.migrationsSchemaName = config.SchemaName
108+
config.migrationsTableName = config.MigrationsTable
109+
if config.MigrationsTableQuoted {
110+
re := regexp.MustCompile(`"(.*?)"`)
111+
result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
112+
config.migrationsTableName = result[len(result)-1][1]
113+
if len(result) == 2 {
114+
config.migrationsSchemaName = result[0][1]
115+
} else if len(result) > 2 {
116+
return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
117+
}
118+
}
119+
104120
conn, err := instance.Conn(context.Background())
105121

106122
if err != nil {
@@ -202,7 +218,7 @@ func (p *Postgres) Lock() error {
202218
return database.ErrLocked
203219
}
204220

205-
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName)
221+
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
206222
if err != nil {
207223
return err
208224
}
@@ -222,7 +238,7 @@ func (p *Postgres) Unlock() error {
222238
return nil
223239
}
224240

225-
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName)
241+
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
226242
if err != nil {
227243
return err
228244
}
@@ -325,20 +341,13 @@ func runesLastIndex(input []rune, target rune) int {
325341
return -1
326342
}
327343

328-
func (p *Postgres) quoteIdentifier(name string) string {
329-
if p.config.MigrationsTableQuoted {
330-
return name
331-
}
332-
return pq.QuoteIdentifier(name)
333-
}
334-
335344
func (p *Postgres) SetVersion(version int, dirty bool) error {
336345
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
337346
if err != nil {
338347
return &database.Error{OrigErr: err, Err: "transaction start failed"}
339348
}
340349

341-
query := `TRUNCATE ` + p.quoteIdentifier(p.config.MigrationsTable)
350+
query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName)
342351
if _, err := tx.Exec(query); err != nil {
343352
if errRollback := tx.Rollback(); errRollback != nil {
344353
err = multierror.Append(err, errRollback)
@@ -350,8 +359,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
350359
// empty schema version for failed down migration on the first migration
351360
// See: https://github.com/golang-migrate/migrate/issues/330
352361
if version >= 0 || (version == database.NilVersion && dirty) {
353-
query = `INSERT INTO ` + p.quoteIdentifier(p.config.MigrationsTable) +
354-
` (version, dirty) VALUES ($1, $2)`
362+
query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
355363
if _, err := tx.Exec(query, version, dirty); err != nil {
356364
if errRollback := tx.Rollback(); errRollback != nil {
357365
err = multierror.Append(err, errRollback)
@@ -368,7 +376,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
368376
}
369377

370378
func (p *Postgres) Version() (version int, dirty bool, err error) {
371-
query := `SELECT version, dirty FROM ` + p.quoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1`
379+
query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
372380
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
373381
switch {
374382
case err == sql.ErrNoRows:
@@ -418,7 +426,7 @@ func (p *Postgres) Drop() (err error) {
418426
if len(tableNames) > 0 {
419427
// delete one by one ...
420428
for _, t := range tableNames {
421-
query = `DROP TABLE IF EXISTS ` + p.quoteIdentifier(t) + ` CASCADE`
429+
query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE`
422430
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
423431
return &database.Error{OrigErr: err, Query: []byte(query)}
424432
}
@@ -450,27 +458,8 @@ func (p *Postgres) ensureVersionTable() (err error) {
450458
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
451459
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
452460
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
453-
var row *sql.Row
454-
tableName := p.config.MigrationsTable
455-
schemaName := ""
456-
if p.config.MigrationsTableQuoted {
457-
re := regexp.MustCompile(`"(.*?)"`)
458-
result := re.FindAllStringSubmatch(p.config.MigrationsTable, -1)
459-
tableName = result[len(result)-1][1]
460-
if len(result) == 2 {
461-
schemaName = result[0][1]
462-
} else if len(result) > 2 {
463-
return fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", p.config.MigrationsTable)
464-
}
465-
}
466-
var query string
467-
if len(schemaName) > 0 {
468-
query = `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2 LIMIT 1`
469-
row = p.conn.QueryRowContext(context.Background(), query, tableName, schemaName)
470-
} else {
471-
query = `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
472-
row = p.conn.QueryRowContext(context.Background(), query, tableName)
473-
}
461+
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
462+
row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
474463

475464
var count int
476465
err = row.Scan(&count)
@@ -482,7 +471,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
482471
return nil
483472
}
484473

485-
query = `CREATE TABLE IF NOT EXISTS ` + p.quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
474+
query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
486475
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
487476
return &database.Error{OrigErr: err, Query: []byte(query)}
488477
}

database/postgres/postgres_test.go

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -762,45 +762,4 @@ func Test_computeLineFromPos(t *testing.T) {
762762
run(true, true)
763763
})
764764
}
765-
766-
}
767-
768-
func Test_quoteIdentifier(t *testing.T) {
769-
testcases := []struct {
770-
migrationsTableQuoted bool
771-
migrationsTable string
772-
expected string
773-
}{
774-
{
775-
false,
776-
"schema_name.table_name",
777-
"\"schema_name.table_name\"",
778-
},
779-
{
780-
false,
781-
"table_name",
782-
"\"table_name\"",
783-
},
784-
{
785-
true,
786-
"\"schema_name\".\"table_name\"",
787-
"\"schema_name\".\"table_name\"",
788-
},
789-
{
790-
true,
791-
"\"table_name\"",
792-
"\"table_name\"",
793-
},
794-
}
795-
p := &Postgres{
796-
config: &Config{},
797-
}
798-
799-
for _, tc := range testcases {
800-
p.config.MigrationsTableQuoted = tc.migrationsTableQuoted
801-
got := p.quoteIdentifier(tc.migrationsTable)
802-
if tc.expected != got {
803-
t.Fatalf("expected %s but got %s", tc.expected, got)
804-
}
805-
}
806765
}

0 commit comments

Comments
 (0)