Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Migrate from lib/pq to pgx to fix context cancellation data race #18492

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions coderd/database/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package database

import (
"database/sql/driver"

"github.com/lib/pq"
)

// ConnectorCreator is a driver.Driver that can create a driver.Connector.
Expand All @@ -12,8 +10,10 @@ type ConnectorCreator interface {
Connector(name string) (driver.Connector, error)
}

// DialerConnector is a driver.Connector that can set a pq.Dialer.
// DialerConnector is a driver.Connector that can set a dialer.
// Note: pgx uses a different approach for custom dialers via config
type DialerConnector interface {
driver.Connector
Dialer(dialer pq.Dialer)
// Dialer functionality is handled differently in pgx
// Use stdlib.RegisterConnConfig for custom connection configuration
}
24 changes: 7 additions & 17 deletions coderd/database/dbtestutil/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"context"
"database/sql/driver"

"github.com/lib/pq"
"github.com/jackc/pgx/v5/stdlib"
"golang.org/x/xerrors"

"github.com/coder/coder/v2/coderd/database"
Expand All @@ -15,22 +15,11 @@ var _ database.DialerConnector = &Connector{}
type Connector struct {
name string
driver *Driver
dialer pq.Dialer
// Note: pgx handles dialing differently via config
}

func (c *Connector) Connect(_ context.Context) (driver.Conn, error) {
if c.dialer != nil {
conn, err := pq.DialOpen(c.dialer, c.name)
if err != nil {
return nil, xerrors.Errorf("failed to dial open connection: %w", err)
}

c.driver.Connections <- conn

return conn, nil
}

conn, err := pq.Driver{}.Open(c.name)
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
conn, err := stdlib.GetDefaultDriver().Open(c.name)
if err != nil {
return nil, xerrors.Errorf("failed to open connection: %w", err)
}
Expand All @@ -44,8 +33,9 @@ func (c *Connector) Driver() driver.Driver {
return c.driver
}

func (c *Connector) Dialer(dialer pq.Dialer) {
c.dialer = dialer
func (c *Connector) Dialer(dialer interface{}) {
// Note: pgx handles dialing differently via config
// This method is kept for interface compatibility but is a no-op
}

type Driver struct {
Expand Down
2 changes: 1 addition & 1 deletion coderd/database/dbtestutil/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"database/sql"
"testing"

_ "github.com/lib/pq"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"

Expand Down
36 changes: 18 additions & 18 deletions coderd/database/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ import (
"context"
"errors"

"github.com/lib/pq"
"github.com/jackc/pgx/v5/pgconn"
)

func IsSerializedError(err error) bool {
var pqErr *pq.Error
if errors.As(err, &pqErr) {
return pqErr.Code.Name() == "serialization_failure"
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
return pgErr.Code == "40001" // serialization_failure
}
return false
}
Expand All @@ -20,14 +20,14 @@ func IsSerializedError(err error) bool {
// the error must be caused by one of them. If no constraints are given,
// this function returns true for any unique violation.
func IsUniqueViolation(err error, uniqueConstraints ...UniqueConstraint) bool {
var pqErr *pq.Error
if errors.As(err, &pqErr) {
if pqErr.Code.Name() == "unique_violation" {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
if pgErr.Code == "23505" { // unique_violation
if len(uniqueConstraints) == 0 {
return true
}
for _, uc := range uniqueConstraints {
if pqErr.Constraint == string(uc) {
if pgErr.ConstraintName == string(uc) {
return true
}
}
Expand All @@ -42,14 +42,14 @@ func IsUniqueViolation(err error, uniqueConstraints ...UniqueConstraint) bool {
// the error must be caused by one of them. If no constraints are given,
// this function returns true for any foreign key violation.
func IsForeignKeyViolation(err error, foreignKeyConstraints ...ForeignKeyConstraint) bool {
var pqErr *pq.Error
if errors.As(err, &pqErr) {
if pqErr.Code.Name() == "foreign_key_violation" {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
if pgErr.Code == "23503" { // foreign_key_violation
if len(foreignKeyConstraints) == 0 {
return true
}
for _, fc := range foreignKeyConstraints {
if pqErr.Constraint == string(fc) {
if pgErr.ConstraintName == string(fc) {
return true
}
}
Expand All @@ -61,9 +61,9 @@ func IsForeignKeyViolation(err error, foreignKeyConstraints ...ForeignKeyConstra

// IsQueryCanceledError checks if the error is due to a query being canceled.
func IsQueryCanceledError(err error) bool {
var pqErr *pq.Error
if errors.As(err, &pqErr) {
return pqErr.Code == "57014" // query_canceled
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
return pgErr.Code == "57014" // query_canceled
} else if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return true
}
Expand All @@ -72,9 +72,9 @@ func IsQueryCanceledError(err error) bool {
}

func IsWorkspaceAgentLogsLimitError(err error) bool {
var pqErr *pq.Error
if errors.As(err, &pqErr) {
return pqErr.Constraint == "max_logs_length" && pqErr.Table == "workspace_agents"
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
return pgErr.ConstraintName == "max_logs_length" && pgErr.TableName == "workspace_agents"
}

return false
Expand Down
8 changes: 4 additions & 4 deletions coderd/database/migrations/txnmigrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"strings"

"github.com/golang-migrate/migrate/v4/database"
"github.com/lib/pq"
"github.com/jackc/pgx/v5/pgconn"
"golang.org/x/xerrors"
)

Expand Down Expand Up @@ -81,7 +81,7 @@ func (d *pgTxnDriver) runStatement(statement []byte) error {
return nil
}
if _, err := d.tx.ExecContext(ctx, query); err != nil {
var pgErr *pq.Error
var pgErr *pgconn.PgError
if xerrors.As(err, &pgErr) {
var line uint
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
Expand Down Expand Up @@ -131,9 +131,9 @@ func (d *pgTxnDriver) Version() (version int, dirty bool, err error) {
return database.NilVersion, false, nil

case err != nil:
var pgErr *pq.Error
var pgErr *pgconn.PgError
if xerrors.As(err, &pgErr) {
if pgErr.Code.Name() == "undefined_table" {
if pgErr.Code == "42P01" { // undefined_table
return database.NilVersion, false, nil
}
}
Expand Down
21 changes: 10 additions & 11 deletions coderd/database/modelqueries.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"strings"

"github.com/google/uuid"
"github.com/lib/pq"
"golang.org/x/xerrors"

"github.com/coder/coder/v2/coderd/rbac"
Expand Down Expand Up @@ -78,7 +77,7 @@ func (q *sqlQuerier) GetAuthorizedTemplates(ctx context.Context, arg GetTemplate
arg.OrganizationID,
arg.ExactName,
arg.FuzzyName,
pq.Array(arg.IDs),
arg.IDs,
arg.Deprecated,
arg.HasAITask,
)
Expand Down Expand Up @@ -247,17 +246,17 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
// The name comment is for metric tracking
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filtered)
rows, err := q.db.QueryContext(ctx, query,
pq.Array(arg.ParamNames),
pq.Array(arg.ParamValues),
arg.ParamNames,
arg.ParamValues,
arg.Deleted,
arg.Status,
arg.OwnerID,
arg.OrganizationID,
pq.Array(arg.HasParam),
arg.HasParam,
arg.OwnerUsername,
arg.TemplateName,
pq.Array(arg.TemplateIDs),
pq.Array(arg.WorkspaceIds),
arg.TemplateIDs,
arg.WorkspaceIds,
arg.Name,
arg.HasAgent,
arg.AgentInactiveDisconnectTimeoutSeconds,
Expand Down Expand Up @@ -357,7 +356,7 @@ func (q *sqlQuerier) GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx context.Conte
&i.Name,
&i.JobStatus,
&i.Transition,
pq.Array(&i.Agents),
&i.Agents,
); err != nil {
return nil, err
}
Expand Down Expand Up @@ -393,15 +392,15 @@ func (q *sqlQuerier) GetAuthorizedUsers(ctx context.Context, arg GetUsersParams,
rows, err := q.db.QueryContext(ctx, query,
arg.AfterID,
arg.Search,
pq.Array(arg.Status),
pq.Array(arg.RbacRole),
arg.Status,
arg.RbacRole,
arg.LastSeenBefore,
arg.LastSeenAfter,
arg.CreatedBefore,
arg.CreatedAfter,
arg.IncludeSystem,
arg.GithubComUserID,
pq.Array(arg.LoginType),
arg.LoginType,
arg.OffsetOpt,
arg.LimitOpt,
)
Expand Down
Loading
Loading