From effcadc580dfe1f2250d1ce1c7bb234845b82950 Mon Sep 17 00:00:00 2001
From: sters
Date: Mon, 5 Sep 2022 17:22:52 +0900
Subject: [PATCH] fix for BatchReadOnlyTransaction check

Look up the methods named in closeMethods on *BatchReadOnlyTransaction
as well as on *ReadOnlyTransaction, so that code which closes a batch
transaction (testdata f7: defer ro.Close()) is accepted.

*BatchReadOnlyTransaction is optional: if that type cannot be resolved,
the *ReadOnlyTransaction check still runs with that type's close
methods only.
---
 passes/unclosetx/testdata/src/a/a.go | 19 +++++++++++++++++++
 passes/unclosetx/unclosetx.go        | 19 +++++++++++++++----
 2 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/passes/unclosetx/testdata/src/a/a.go b/passes/unclosetx/testdata/src/a/a.go
index bd4d2a8..17efadf 100644
--- a/passes/unclosetx/testdata/src/a/a.go
+++ b/passes/unclosetx/testdata/src/a/a.go
@@ -73,3 +73,22 @@ func f6(ctx context.Context, client *spanner.Client) error {
 	}
 	return nil
 }
+
+func f7(ctx context.Context, client *spanner.Client) error {
+	ro, _ := client.BatchReadOnlyTransaction(ctx, spanner.StrongRead())
+	defer ro.Close()
+
+	stmt := spanner.Statement{SQL: `SELECT 1`}
+
+	iter := ro.Query(ctx, stmt)
+	defer iter.Stop()
+
+	for {
+		_, err := iter.Next()
+		if err != nil {
+			break
+		}
+	}
+
+	return nil
+}
diff --git a/passes/unclosetx/unclosetx.go b/passes/unclosetx/unclosetx.go
index 72b7d41..d915e72 100644
--- a/passes/unclosetx/unclosetx.go
+++ b/passes/unclosetx/unclosetx.go
@@ -33,15 +33,26 @@ func run(pass *analysis.Pass) (interface{}, error) {
 	funcs := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA).SrcFuncs
 	cmaps := pass.ResultOf[commentmap.Analyzer].(comment.Maps)
 
-	txTyp := zaganeutils.TypeOf(pass, "*ReadOnlyTransaction")
-	if txTyp == nil {
+	txTypeRo := zaganeutils.TypeOf(pass, "*ReadOnlyTransaction")
+	if txTypeRo == nil {
 		// skip checking
 		return nil, nil
 	}
 
+	// The methods named in closeMethods are also looked up on
+	// *BatchReadOnlyTransaction. That type is optional: if it cannot be
+	// resolved, checking still runs with *ReadOnlyTransaction's methods.
+	txTypeBatch := zaganeutils.TypeOf(pass, "*BatchReadOnlyTransaction")
+
 	var methods []*types.Func
 	for _, s := range strings.Split(closeMethods, ",") {
-		if m := analysisutil.MethodOf(txTyp, s); m != nil {
+		if m := analysisutil.MethodOf(txTypeRo, s); m != nil {
+			methods = append(methods, m)
+		}
+		if txTypeBatch == nil {
+			continue
+		}
+		if m := analysisutil.MethodOf(txTypeBatch, s); m != nil {
 			methods = append(methods, m)
 		}
 	}
@@ -59,7 +70,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
 			// skip this
 			continue
 		}
-		instrs := analysisutil.NotCalledIn(f, txTyp, methods...)
+		instrs := analysisutil.NotCalledIn(f, txTypeRo, methods...)
 		for _, instr := range instrs {
 			pos := instr.Pos()
 			if pos == token.NoPos {