Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit cffcd71

Browse files
committed
fix: Use correct db inside InTx
Add rules.go for catching this
1 parent 303feb1 commit cffcd71

File tree

5 files changed

+60
-9
lines changed

5 files changed

+60
-9
lines changed

coderd/organizations.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) {
5757
}
5858

5959
var organization database.Organization
60-
err = api.Database.InTx(func(db database.Store) error {
61-
organization, err = api.Database.InsertOrganization(r.Context(), database.InsertOrganizationParams{
60+
err = api.Database.InTx(func(store database.Store) error {
61+
organization, err = store.InsertOrganization(r.Context(), database.InsertOrganizationParams{
6262
ID: uuid.New(),
6363
Name: req.Name,
6464
CreatedAt: database.Now(),
@@ -67,7 +67,7 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) {
6767
if err != nil {
6868
return xerrors.Errorf("create organization: %w", err)
6969
}
70-
_, err = api.Database.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
70+
_, err = store.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
7171
OrganizationID: organization.ID,
7272
UserID: apiKey.UserID,
7373
CreatedAt: database.Now(),

coderd/templateversions.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ func (api *API) templateVersionsByTemplate(rw http.ResponseWriter, r *http.Reque
391391
if paginationParams.AfterID != uuid.Nil {
392392
// See if the record exists first. If the record does not exist, the pagination
393393
// query will not work.
394-
_, err := api.Database.GetTemplateVersionByID(r.Context(), paginationParams.AfterID)
394+
_, err := store.GetTemplateVersionByID(r.Context(), paginationParams.AfterID)
395395
if err != nil && xerrors.Is(err, sql.ErrNoRows) {
396396
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
397397
Message: fmt.Sprintf("record at \"after_id\" (%q) does not exists", paginationParams.AfterID.String()),
@@ -405,7 +405,7 @@ func (api *API) templateVersionsByTemplate(rw http.ResponseWriter, r *http.Reque
405405
}
406406
}
407407

408-
versions, err := api.Database.GetTemplateVersionsByTemplateID(r.Context(), database.GetTemplateVersionsByTemplateIDParams{
408+
versions, err := store.GetTemplateVersionsByTemplateID(r.Context(), database.GetTemplateVersionsByTemplateIDParams{
409409
TemplateID: template.ID,
410410
AfterID: paginationParams.AfterID,
411411
LimitOpt: int32(paginationParams.Limit),
@@ -426,7 +426,7 @@ func (api *API) templateVersionsByTemplate(rw http.ResponseWriter, r *http.Reque
426426
for _, version := range versions {
427427
jobIDs = append(jobIDs, version.JobID)
428428
}
429-
jobs, err := api.Database.GetProvisionerJobsByIDs(r.Context(), jobIDs)
429+
jobs, err := store.GetProvisionerJobsByIDs(r.Context(), jobIDs)
430430
if err != nil {
431431
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
432432
Message: fmt.Sprintf("get jobs: %s", err),
@@ -608,7 +608,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
608608
}
609609
}
610610

611-
provisionerJob, err = api.Database.InsertProvisionerJob(r.Context(), database.InsertProvisionerJobParams{
611+
provisionerJob, err = db.InsertProvisionerJob(r.Context(), database.InsertProvisionerJobParams{
612612
ID: jobID,
613613
CreatedAt: database.Now(),
614614
UpdatedAt: database.Now(),
@@ -632,7 +632,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
632632
}
633633
}
634634

635-
templateVersion, err = api.Database.InsertTemplateVersion(r.Context(), database.InsertTemplateVersionParams{
635+
templateVersion, err = db.InsertTemplateVersion(r.Context(), database.InsertTemplateVersionParams{
636636
ID: uuid.New(),
637637
TemplateID: templateID,
638638
OrganizationID: organization.ID,

coderd/workspacebuilds.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func (api *API) workspaceBuilds(rw http.ResponseWriter, r *http.Request) {
8585
OffsetOpt: int32(paginationParams.Offset),
8686
LimitOpt: int32(paginationParams.Limit),
8787
}
88-
builds, err = api.Database.GetWorkspaceBuildByWorkspaceID(r.Context(), req)
88+
builds, err = store.GetWorkspaceBuildByWorkspaceID(r.Context(), req)
8989
if xerrors.Is(err, sql.ErrNoRows) {
9090
err = nil
9191
}

docker-compose.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ services:
1818
condition: service_healthy
1919
database:
2020
image: "postgres:14.2"
21+
ports:
22+
- "5432:5432"
2123
environment:
2224
POSTGRES_USER: ${POSTGRES_USER:-username} # The PostgreSQL user (useful to connect to the database)
2325
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-password} # The PostgreSQL password (useful to connect to the database)

scripts/rules.go

+49
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,52 @@ func doNotCallTFailNowInsideGoroutine(m dsl.Matcher) {
6969
Where(m["t"].Type.Implements("testing.TB") && m["fail"].Text.Matches("^(FailNow|Fatal|Fatalf)$")).
7070
Report("Do not call functions that may call t.FailNow in a goroutine, as this can cause data races (see testing.go:834)")
7171
}
72+
73+
// InTx checks to ensure the database used inside the transaction closure is the transaction
74+
// database, and not the original database that creates the tx.
75+
func InTx(m dsl.Matcher) {
76+
m.Import("github.com/coder/coder/coderd/database")
77+
78+
// ':=' and '=' are different matches. Really...
79+
m.Match(`
80+
$x.InTx(func($y database.Store) error {
81+
$*_
82+
$*_ := $x.$f($*_)
83+
$*_
84+
})
85+
`, `
86+
$x.InTx(func($y database.Store) error {
87+
$*_
88+
$*_ = $x.$f($*_)
89+
$*_
90+
})
91+
`).Where(m["x"].Text != m["y"].Text && m["x"].Type.Implements("database.Store")).
92+
At(m["f"]).
93+
Report("Do not use the database directly within the InTx closure. Use '$y' instead of '$x'.")
94+
95+
// When using a tx closure, ensure that if you pass the db to another
96+
// function inside the closure, it is the tx.
97+
// This will miss more complex cases such as passing the db as apart
98+
// of another struct.
99+
m.Match(
100+
`
101+
$x.InTx(func($y database.Store) error {
102+
$*_
103+
$*_ := $f($*_, $x, $*_)
104+
$*_
105+
})
106+
`, `
107+
$x.InTx(func($y database.Store) error {
108+
$*_
109+
$*_ = $f($*_, $x, $*_)
110+
$*_
111+
})
112+
`, `
113+
$x.InTx(func($y database.Store) error {
114+
$*_
115+
$f($*_, $x, $*_)
116+
$*_
117+
})
118+
`).Where(m["x"].Text != m["y"].Text && m["x"].Type.Implements("database.Store")).
119+
At(m["f"]).Report("Pass the tx database into the '$f' function inside the closure. Use '$y' over $x'")
120+
}

0 commit comments

Comments
 (0)