@@ -46,6 +46,8 @@ type Config struct {
46
46
MigrationsTable string
47
47
DatabaseName string
48
48
SchemaName string
49
+ migrationsSchemaName string
50
+ migrationsTableName string
49
51
StatementTimeout time.Duration
50
52
MigrationsTableQuoted bool
51
53
MultiStatementEnabled bool
@@ -103,6 +105,19 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
103
105
config .MigrationsTable = DefaultMigrationsTable
104
106
}
105
107
108
+ config .migrationsSchemaName = config .SchemaName
109
+ config .migrationsTableName = config .MigrationsTable
110
+ if config .MigrationsTableQuoted {
111
+ re := regexp .MustCompile (`"(.*?)"` )
112
+ result := re .FindAllStringSubmatch (config .MigrationsTable , - 1 )
113
+ config .migrationsTableName = result [len (result )- 1 ][1 ]
114
+ if len (result ) == 2 {
115
+ config .migrationsSchemaName = result [0 ][1 ]
116
+ } else if len (result ) > 2 {
117
+ return nil , fmt .Errorf ("\" %s\" MigrationsTable contains too many dot characters" , config .MigrationsTable )
118
+ }
119
+ }
120
+
106
121
conn , err := instance .Conn (context .Background ())
107
122
108
123
if err != nil {
@@ -209,7 +224,7 @@ func (p *Postgres) Lock() error {
209
224
return database .ErrLocked
210
225
}
211
226
212
- aid , err := database .GenerateAdvisoryLockId (p .config .DatabaseName , p .config .SchemaName )
227
+ aid , err := database .GenerateAdvisoryLockId (p .config .DatabaseName , p .config .migrationsSchemaName , p . config . migrationsTableName )
213
228
if err != nil {
214
229
return err
215
230
}
@@ -229,7 +244,7 @@ func (p *Postgres) Unlock() error {
229
244
return nil
230
245
}
231
246
232
- aid , err := database .GenerateAdvisoryLockId (p .config .DatabaseName , p .config .SchemaName )
247
+ aid , err := database .GenerateAdvisoryLockId (p .config .DatabaseName , p .config .migrationsSchemaName , p . config . migrationsTableName )
233
248
if err != nil {
234
249
return err
235
250
}
@@ -335,7 +350,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
335
350
return & database.Error {OrigErr : err , Err : "transaction start failed" }
336
351
}
337
352
338
- query := `TRUNCATE ` + p . quoteIdentifier (p .config .MigrationsTable )
353
+ query := `TRUNCATE ` + quoteIdentifier ( p . config . migrationsSchemaName ) + `.` + quoteIdentifier (p .config .migrationsTableName )
339
354
if _ , err := tx .Exec (query ); err != nil {
340
355
if errRollback := tx .Rollback (); errRollback != nil {
341
356
err = multierror .Append (err , errRollback )
@@ -347,8 +362,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
347
362
// empty schema version for failed down migration on the first migration
348
363
// See: https://github.com/golang-migrate/migrate/issues/330
349
364
if version >= 0 || (version == database .NilVersion && dirty ) {
350
- query = `INSERT INTO ` + p .quoteIdentifier (p .config .MigrationsTable ) +
351
- ` (version, dirty) VALUES ($1, $2)`
365
+ query = `INSERT INTO ` + quoteIdentifier (p .config .migrationsSchemaName ) + `.` + quoteIdentifier (p .config .migrationsTableName ) + ` (version, dirty) VALUES ($1, $2)`
352
366
if _ , err := tx .Exec (query , version , dirty ); err != nil {
353
367
if errRollback := tx .Rollback (); errRollback != nil {
354
368
err = multierror .Append (err , errRollback )
@@ -365,7 +379,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
365
379
}
366
380
367
381
func (p * Postgres ) Version () (version int , dirty bool , err error ) {
368
- query := `SELECT version, dirty FROM ` + p . quoteIdentifier (p .config .MigrationsTable ) + ` LIMIT 1`
382
+ query := `SELECT version, dirty FROM ` + quoteIdentifier ( p . config . migrationsSchemaName ) + `.` + quoteIdentifier (p .config .migrationsTableName ) + ` LIMIT 1`
369
383
err = p .conn .QueryRowContext (context .Background (), query ).Scan (& version , & dirty )
370
384
switch {
371
385
case err == sql .ErrNoRows :
@@ -415,7 +429,7 @@ func (p *Postgres) Drop() (err error) {
415
429
if len (tableNames ) > 0 {
416
430
// delete one by one ...
417
431
for _ , t := range tableNames {
418
- query = `DROP TABLE IF EXISTS ` + p . quoteIdentifier (t ) + ` CASCADE`
432
+ query = `DROP TABLE IF EXISTS ` + quoteIdentifier (t ) + ` CASCADE`
419
433
if _ , err := p .conn .ExecContext (context .Background (), query ); err != nil {
420
434
return & database.Error {OrigErr : err , Query : []byte (query )}
421
435
}
@@ -447,27 +461,8 @@ func (p *Postgres) ensureVersionTable() (err error) {
447
461
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
448
462
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
449
463
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
450
- var row * sql.Row
451
- tableName := p .config .MigrationsTable
452
- schemaName := ""
453
- if p .config .MigrationsTableQuoted {
454
- re := regexp .MustCompile (`"(.*?)"` )
455
- result := re .FindAllStringSubmatch (p .config .MigrationsTable , - 1 )
456
- tableName = result [len (result )- 1 ][1 ]
457
- if len (result ) == 2 {
458
- schemaName = result [0 ][1 ]
459
- } else if len (result ) > 2 {
460
- return fmt .Errorf ("\" %s\" MigrationsTable contains too many dot characters" , p .config .MigrationsTable )
461
- }
462
- }
463
- var query string
464
- if len (schemaName ) > 0 {
465
- query = `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2 LIMIT 1`
466
- row = p .conn .QueryRowContext (context .Background (), query , tableName , schemaName )
467
- } else {
468
- query = `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
469
- row = p .conn .QueryRowContext (context .Background (), query , tableName )
470
- }
464
+ query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
465
+ row := p .conn .QueryRowContext (context .Background (), query , p .config .migrationsSchemaName , p .config .migrationsTableName )
471
466
472
467
var count int
473
468
err = row .Scan (& count )
@@ -479,7 +474,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
479
474
return nil
480
475
}
481
476
482
- query = `CREATE TABLE IF NOT EXISTS ` + p . quoteIdentifier (p .config .MigrationsTable ) + ` (version bigint not null primary key, dirty boolean not null)`
477
+ query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier ( p . config . migrationsSchemaName ) + `.` + quoteIdentifier (p .config .migrationsTableName ) + ` (version bigint not null primary key, dirty boolean not null)`
483
478
if _ , err = p .conn .ExecContext (context .Background (), query ); err != nil {
484
479
return & database.Error {OrigErr : err , Query : []byte (query )}
485
480
}
@@ -488,10 +483,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
488
483
}
489
484
490
485
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
491
- func (p * Postgres ) quoteIdentifier (name string ) string {
492
- if p .config .MigrationsTableQuoted {
493
- return name
494
- }
486
+ func quoteIdentifier (name string ) string {
495
487
end := strings .IndexRune (name , 0 )
496
488
if end > - 1 {
497
489
name = name [:end ]
0 commit comments