Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 6204921

Browse files
committed
Add helper package to parse multi-statement migrations
Addresses: golang-migrate#406
1 parent 8eb1d30 commit 6204921

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

database/multistmt/parse.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Package multistmt provides methods for parsing multi-statement database migrations
2+
package multistmt
3+
4+
import (
5+
"bufio"
6+
"bytes"
7+
"io"
8+
)
9+
10+
// StartBufSize is the default starting size of the buffer used to scan and parse multi-statement migrations
11+
var StartBufSize = 4096
12+
13+
// Handler handles a single migration parsed from a multi-statement migration.
14+
// It's given the single migration to handle and returns whether or not further statements
15+
// from the multi-statement migration should be parsed and handled.
16+
type Handler func(migration []byte) bool
17+
18+
func splitWithDelimiter(delimiter []byte) func(d []byte, atEOF bool) (int, []byte, error) {
19+
return func(d []byte, atEOF bool) (int, []byte, error) {
20+
// SplitFunc inspired by bufio.ScanLines() implementation
21+
if atEOF {
22+
if len(d) == 0 {
23+
return 0, nil, nil
24+
}
25+
return len(d), d, nil
26+
}
27+
if i := bytes.Index(d, delimiter); i >= 0 {
28+
return i + len(delimiter), d[:i+len(delimiter)], nil
29+
}
30+
return 0, nil, nil
31+
}
32+
}
33+
34+
// Parse parses the given multi-statement migration
35+
func Parse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler) error {
36+
scanner := bufio.NewScanner(reader)
37+
scanner.Buffer(make([]byte, 0, StartBufSize), maxMigrationSize)
38+
scanner.Split(splitWithDelimiter(delimiter))
39+
for scanner.Scan() {
40+
cont := h(scanner.Bytes())
41+
if !cont {
42+
break
43+
}
44+
}
45+
return scanner.Err()
46+
}

database/multistmt/parse_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package multistmt_test
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
9+
"github.com/golang-migrate/migrate/v4/database/multistmt"
10+
)
11+
12+
const maxMigrationSize = 1024
13+
14+
func TestParse(t *testing.T) {
15+
testCases := []struct {
16+
name string
17+
multiStmt string
18+
delimiter string
19+
expected []string
20+
expectedErr error
21+
}{
22+
{name: "single statement, no delimiter", multiStmt: "single statement, no delimiter", delimiter: ";",
23+
expected: []string{"single statement, no delimiter"}, expectedErr: nil},
24+
{name: "single statement, one delimiter", multiStmt: "single statement, one delimiter;", delimiter: ";",
25+
expected: []string{"single statement, one delimiter;"}, expectedErr: nil},
26+
{name: "two statements, no trailing delimiter", multiStmt: "statement one; statement two", delimiter: ";",
27+
expected: []string{"statement one;", " statement two"}, expectedErr: nil},
28+
{name: "two statements, with trailing delimiter", multiStmt: "statement one; statement two;", delimiter: ";",
29+
expected: []string{"statement one;", " statement two;"}, expectedErr: nil},
30+
}
31+
32+
for _, tc := range testCases {
33+
t.Run(tc.name, func(t *testing.T) {
34+
stmts := make([]string, 0, len(tc.expected))
35+
err := multistmt.Parse(strings.NewReader(tc.multiStmt), []byte(tc.delimiter), maxMigrationSize, func(b []byte) bool {
36+
stmts = append(stmts, string(b))
37+
return true
38+
})
39+
assert.Equal(t, tc.expectedErr, err)
40+
assert.Equal(t, tc.expected, stmts)
41+
})
42+
}
43+
}
44+
45+
func TestParseDiscontinue(t *testing.T) {
46+
multiStmt := "statement one; statement two"
47+
delimiter := ";"
48+
expected := []string{"statement one;"}
49+
50+
stmts := make([]string, 0, len(expected))
51+
err := multistmt.Parse(strings.NewReader(multiStmt), []byte(delimiter), maxMigrationSize, func(b []byte) bool {
52+
stmts = append(stmts, string(b))
53+
return false
54+
})
55+
assert.Nil(t, err)
56+
assert.Equal(t, expected, stmts)
57+
}

0 commit comments

Comments
 (0)