|
| 1 | +// Package database connects to external services for stateful storage. |
| 2 | +// |
| 3 | +// Query functions are generated using sqlc. |
| 4 | +// |
| 5 | +// To modify the database schema: |
| 6 | +// 1. Add a new migration using "create_migration.sh" in database/migrations/ |
| 7 | +// 2. Run "make database/generate" in the root to generate models. |
| 8 | +// 3. Add/Edit queries in "query.sql" and run "make database/generate" to create Go code. |
| 9 | +package database |
| 10 | + |
| 11 | +import ( |
| 12 | + "context" |
| 13 | + "database/sql" |
| 14 | + "errors" |
| 15 | + |
| 16 | + "golang.org/x/xerrors" |
| 17 | +) |
| 18 | + |
| 19 | +// Store contains all queryable database functions. |
| 20 | +// It extends the generated interface to add transaction support. |
| 21 | +type Store interface { |
| 22 | + querier |
| 23 | + |
| 24 | + InTx(context.Context, func(Store) error) error |
| 25 | +} |
| 26 | + |
| 27 | +// DBTX represents a database connection or transaction. |
| 28 | +type DBTX interface { |
| 29 | + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) |
| 30 | + PrepareContext(context.Context, string) (*sql.Stmt, error) |
| 31 | + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) |
| 32 | + QueryRowContext(context.Context, string, ...interface{}) *sql.Row |
| 33 | +} |
| 34 | + |
| 35 | +// New creates a new database store using a SQL database connection. |
| 36 | +func New(sdb *sql.DB) Store { |
| 37 | + return &sqlQuerier{ |
| 38 | + db: sdb, |
| 39 | + sdb: sdb, |
| 40 | + } |
| 41 | +} |
| 42 | + |
| 43 | +type sqlQuerier struct { |
| 44 | + sdb *sql.DB |
| 45 | + db DBTX |
| 46 | +} |
| 47 | + |
| 48 | +// InTx performs database operations inside a transaction. |
| 49 | +func (q *sqlQuerier) InTx(ctx context.Context, fn func(Store) error) error { |
| 50 | + if q.sdb == nil { |
| 51 | + return nil |
| 52 | + } |
| 53 | + tx, err := q.sdb.Begin() |
| 54 | + if err != nil { |
| 55 | + return xerrors.Errorf("begin transaction: %w", err) |
| 56 | + } |
| 57 | + defer func() { |
| 58 | + rerr := tx.Rollback() |
| 59 | + if rerr == nil || errors.Is(rerr, sql.ErrTxDone) { |
| 60 | + // no need to do anything, tx committed successfully |
| 61 | + return |
| 62 | + } |
| 63 | + // couldn't roll back for some reason, extend returned error |
| 64 | + err = xerrors.Errorf("defer (%s): %w", rerr.Error(), err) |
| 65 | + }() |
| 66 | + err = fn(&sqlQuerier{db: tx}) |
| 67 | + if err != nil { |
| 68 | + return xerrors.Errorf("execute transaction: %w", err) |
| 69 | + } |
| 70 | + err = tx.Commit() |
| 71 | + if err != nil { |
| 72 | + return xerrors.Errorf("commit transaction: %w", err) |
| 73 | + } |
| 74 | + return nil |
| 75 | +} |
0 commit comments