diff --git a/README.md b/README.md index a10e1f7eacf..a68594df96b 100644 --- a/README.md +++ b/README.md @@ -76,10 +76,12 @@ golang.org/x/tools/txtar ## Contributing This repository uses Gerrit for code changes. -To learn how to submit changes, see https://golang.org/doc/contribute.html. +To learn how to submit changes, see https://go.dev/doc/contribute. + +The git repository is https://go.googlesource.com/tools. The main issue tracker for the tools repository is located at -https://github.com/golang/go/issues. Prefix your issue with "x/tools/(your +https://go.dev/issues. Prefix your issue with "x/tools/(your subdir):" in the subject line, so it is easy to find. ### JavaScript and CSS Formatting diff --git a/go.mod b/go.mod index 0fbd072dd43..d7b6f18ddc1 100644 --- a/go.mod +++ b/go.mod @@ -6,9 +6,9 @@ require ( github.com/google/go-cmp v0.6.0 github.com/yuin/goldmark v1.4.13 golang.org/x/mod v0.22.0 - golang.org/x/net v0.31.0 - golang.org/x/sync v0.9.0 + golang.org/x/net v0.32.0 + golang.org/x/sync v0.10.0 golang.org/x/telemetry v0.0.0-20240521205824-bda55230c457 ) -require golang.org/x/sys v0.27.0 // indirect +require golang.org/x/sys v0.28.0 // indirect diff --git a/go.sum b/go.sum index 1d4e510a2c0..9b25a309b97 100644 --- a/go.sum +++ b/go.sum @@ -4,11 +4,11 @@ github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= -golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= -golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= -golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= -golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= 
-golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/net v0.32.0 h1:ZqPmj8Kzc+Y6e0+skZsuACbx+wzMgo5MQsJh9Qd6aYI= +golang.org/x/net v0.32.0/go.mod h1:CwU0IoeOlnQQWJ6ioyFrfRuomB8GKF6KbYXZVyeXNfs= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/telemetry v0.0.0-20240521205824-bda55230c457 h1:zf5N6UOrA487eEFacMePxjXAJctxKmyjKUsjA11Uzuk= golang.org/x/telemetry v0.0.0-20240521205824-bda55230c457/go.mod h1:pRgIJT+bRLFKnoM1ldnzKoxTIn14Yxz928LQRYYgIN0= diff --git a/go/analysis/analysis.go b/go/analysis/analysis.go index aa02eeda680..d384aa89b8e 100644 --- a/go/analysis/analysis.go +++ b/go/analysis/analysis.go @@ -50,7 +50,7 @@ type Analyzer struct { // RunDespiteErrors allows the driver to invoke // the Run method of this analyzer even on a // package that contains parse or type errors. - // The Pass.TypeErrors field may consequently be non-empty. + // The [Pass.TypeErrors] field may consequently be non-empty. 
RunDespiteErrors bool // Requires is a set of analyzers that must run successfully diff --git a/go/analysis/analysistest/analysistest.go b/go/analysis/analysistest/analysistest.go index c1b2dd4fa1b..6aa04ed1502 100644 --- a/go/analysis/analysistest/analysistest.go +++ b/go/analysis/analysistest/analysistest.go @@ -23,7 +23,8 @@ import ( "text/scanner" "golang.org/x/tools/go/analysis" - "golang.org/x/tools/go/analysis/internal/checker" + "golang.org/x/tools/go/analysis/checker" + "golang.org/x/tools/go/analysis/internal" "golang.org/x/tools/go/packages" "golang.org/x/tools/internal/diff" "golang.org/x/tools/internal/testenv" @@ -137,7 +138,7 @@ type Testing interface { // analyzers that offer alternative fixes are advised to put each fix // in a separate .go file in the testdata. func RunWithSuggestedFixes(t Testing, dir string, a *analysis.Analyzer, patterns ...string) []*Result { - r := Run(t, dir, a, patterns...) + results := Run(t, dir, a, patterns...) // If the immediate caller of RunWithSuggestedFixes is in // x/tools, we apply stricter checks as required by gopls. @@ -162,7 +163,9 @@ func RunWithSuggestedFixes(t Testing, dir string, a *analysis.Analyzer, patterns // Validating the results separately means as long as the two analyses // don't produce conflicting suggestions for a single file, everything // should match up. 
- for _, act := range r { + for _, result := range results { + act := result.Action + // file -> message -> edits fileEdits := make(map[*token.File]map[string][]diff.Edit) fileContents := make(map[*token.File][]byte) @@ -185,14 +188,14 @@ func RunWithSuggestedFixes(t Testing, dir string, a *analysis.Analyzer, patterns if start > end { t.Errorf( "diagnostic for analysis %v contains Suggested Fix with malformed edit: pos (%v) > end (%v)", - act.Pass.Analyzer.Name, start, end) + act.Analyzer.Name, start, end) continue } - file, endfile := act.Pass.Fset.File(start), act.Pass.Fset.File(end) + file, endfile := act.Package.Fset.File(start), act.Package.Fset.File(end) if file == nil || endfile == nil || file != endfile { t.Errorf( "diagnostic for analysis %v contains Suggested Fix with malformed spanning files %v and %v", - act.Pass.Analyzer.Name, file.Name(), endfile.Name()) + act.Analyzer.Name, file.Name(), endfile.Name()) continue } if _, ok := fileContents[file]; !ok { @@ -275,7 +278,7 @@ func RunWithSuggestedFixes(t Testing, dir string, a *analysis.Analyzer, patterns } } } - return r + return results } // applyDiffsAndCompare applies edits to src and compares the results against @@ -355,24 +358,76 @@ func Run(t Testing, dir string, a *analysis.Analyzer, patterns ...string) []*Res return nil } - if err := analysis.Validate([]*analysis.Analyzer{a}); err != nil { - t.Errorf("Validate: %v", err) + // Print parse and type errors to the test log. + // (Do not print them to stderr, which would pollute + // the log in cases where the tests pass.) 
+ if t, ok := t.(testing.TB); ok && !a.RunDespiteErrors { + packages.Visit(pkgs, nil, func(pkg *packages.Package) { + for _, err := range pkg.Errors { + t.Log(err) + } + }) + } + + res, err := checker.Analyze([]*analysis.Analyzer{a}, pkgs, nil) + if err != nil { + t.Errorf("Analyze: %v", err) return nil } - results := checker.TestAnalyzer(a, pkgs) - for _, result := range results { - if result.Err != nil { - t.Errorf("error analyzing %s: %v", result.Pass, result.Err) + var results []*Result + for _, act := range res.Roots { + if act.Err != nil { + t.Errorf("error analyzing %s: %v", act, act.Err) } else { - check(t, dir, result.Pass, result.Diagnostics, result.Facts) + check(t, dir, act) } + + // Compute legacy map of facts relating to this package. + facts := make(map[types.Object][]analysis.Fact) + for _, objFact := range act.AllObjectFacts() { + if obj := objFact.Object; obj.Pkg() == act.Package.Types { + facts[obj] = append(facts[obj], objFact.Fact) + } + } + for _, pkgFact := range act.AllPackageFacts() { + if pkgFact.Package == act.Package.Types { + facts[nil] = append(facts[nil], pkgFact.Fact) + } + } + + // Construct the legacy result. + results = append(results, &Result{ + Pass: internal.Pass(act), + Diagnostics: act.Diagnostics, + Facts: facts, + Result: act.Result, + Err: act.Err, + Action: act, + }) } return results } // A Result holds the result of applying an analyzer to a package. -type Result = checker.TestAnalyzerResult +// +// Facts contains only facts associated with the package and its objects. +// +// This internal type was inadvertently and regrettably exposed +// through a public type alias. It is essentially redundant with +// [checker.Action], but must be retained for compatibility. Clients may +// access the public fields of the Pass but must not invoke any of +// its "verbs", since the pass is already complete. 
+type Result struct { + Action *checker.Action + + // legacy fields + Facts map[types.Object][]analysis.Fact // nil key => package fact + Pass *analysis.Pass + Diagnostics []analysis.Diagnostic // see Action.Diagnostics + Result any // see Action.Result + Err error // see Action.Err +} // loadPackages uses go/packages to load a specified packages (from source, with // dependencies) from dir, which is the root of a GOPATH-style project tree. @@ -421,16 +476,6 @@ func loadPackages(a *analysis.Analyzer, dir string, patterns ...string) ([]*pack } } - // Do NOT print errors if the analyzer will continue running. - // It is incredibly confusing for tests to be printing to stderr - // willy-nilly instead of their test logs, especially when the - // errors are expected and are going to be fixed. - if !a.RunDespiteErrors { - if packages.PrintErrors(pkgs) > 0 { - return nil, fmt.Errorf("there were package loading errors (and RunDespiteErrors is false)") - } - } - if len(pkgs) == 0 { return nil, fmt.Errorf("no packages matched %s", patterns) } @@ -441,7 +486,7 @@ func loadPackages(a *analysis.Analyzer, dir string, patterns ...string) ([]*pack // been run, and verifies that all reported diagnostics and facts match // specified by the contents of "// want ..." comments in the package's // source files, which must have been parsed with comments enabled. -func check(t Testing, gopath string, pass *analysis.Pass, diagnostics []analysis.Diagnostic, facts map[types.Object][]analysis.Fact) { +func check(t Testing, gopath string, act *checker.Action) { type key struct { file string line int @@ -468,7 +513,7 @@ func check(t Testing, gopath string, pass *analysis.Pass, diagnostics []analysis } // Extract 'want' comments from parsed Go files. 
- for _, f := range pass.Files { + for _, f := range act.Package.Syntax { for _, cgroup := range f.Comments { for _, c := range cgroup.List { @@ -491,7 +536,7 @@ func check(t Testing, gopath string, pass *analysis.Pass, diagnostics []analysis // once outside the loop, but it's // incorrect because it can change due // to //line directives. - posn := pass.Fset.Position(c.Pos()) + posn := act.Package.Fset.Position(c.Pos()) filename := sanitize(gopath, posn.Filename) processComment(filename, posn.Line, text) } @@ -500,7 +545,17 @@ func check(t Testing, gopath string, pass *analysis.Pass, diagnostics []analysis // Extract 'want' comments from non-Go files. // TODO(adonovan): we may need to handle //line directives. - for _, filename := range pass.OtherFiles { + files := act.Package.OtherFiles + + // Hack: these two analyzers need to extract expectations from + // all configurations, so include the files are are usually + // ignored. (This was previously a hack in the respective + // analyzers' tests.) + if act.Analyzer.Name == "buildtag" || act.Analyzer.Name == "directive" { + files = append(files[:len(files):len(files)], act.Package.IgnoredFiles...) + } + + for _, filename := range files { data, err := os.ReadFile(filename) if err != nil { t.Errorf("can't read '// want' comments from %s: %v", filename, err) @@ -553,45 +608,38 @@ func check(t Testing, gopath string, pass *analysis.Pass, diagnostics []analysis } // Check the diagnostics match expectations. - for _, f := range diagnostics { + for _, f := range act.Diagnostics { // TODO(matloob): Support ranges in analysistest. - posn := pass.Fset.Position(f.Pos) + posn := act.Package.Fset.Position(f.Pos) checkMessage(posn, "diagnostic", "", f.Message) } // Check the facts match expectations. - // Report errors in lexical order for determinism. + // We check only facts relating to the current package. + // + // We report errors in lexical order for determinism. 
// (It's only deterministic within each file, not across files, // because go/packages does not guarantee file.Pos is ascending // across the files of a single compilation unit.) - var objects []types.Object - for obj := range facts { - objects = append(objects, obj) - } - sort.Slice(objects, func(i, j int) bool { - // Package facts compare less than object facts. - ip, jp := objects[i] == nil, objects[j] == nil // whether i, j is a package fact - if ip != jp { - return ip && !jp - } - return objects[i].Pos() < objects[j].Pos() - }) - for _, obj := range objects { - var posn token.Position - var name string - if obj != nil { - // Object facts are reported on the declaring line. - name = obj.Name() - posn = pass.Fset.Position(obj.Pos()) - } else { - // Package facts are reported at the start of the file. - name = "package" - posn = pass.Fset.Position(pass.Files[0].Pos()) - posn.Line = 1 + + // package facts: reported at start of first file + for _, pkgFact := range act.AllPackageFacts() { + if pkgFact.Package == act.Package.Types { + posn := act.Package.Fset.Position(act.Package.Syntax[0].Pos()) + posn.Line, posn.Column = 1, 1 + checkMessage(posn, "fact", "package", fmt.Sprint(pkgFact)) } + } - for _, fact := range facts[obj] { - checkMessage(posn, "fact", name, fmt.Sprint(fact)) + // object facts: reported at line of object declaration + objFacts := act.AllObjectFacts() + sort.Slice(objFacts, func(i, j int) bool { + return objFacts[i].Object.Pos() < objFacts[j].Object.Pos() + }) + for _, objFact := range objFacts { + if obj := objFact.Object; obj.Pkg() == act.Package.Types { + posn := act.Package.Fset.Position(obj.Pos()) + checkMessage(posn, "fact", obj.Name(), fmt.Sprint(objFact.Fact)) } } diff --git a/go/analysis/checker/checker.go b/go/analysis/checker/checker.go new file mode 100644 index 00000000000..5935a62abaf --- /dev/null +++ b/go/analysis/checker/checker.go @@ -0,0 +1,625 @@ +// Copyright 2024 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package checker provides an analysis driver based on the +// [golang.org/x/tools/go/packages] representation of a set of +// packages and all their dependencies, as produced by +// [packages.Load]. +// +// It is the core of multichecker (the multi-analyzer driver), +// singlechecker (the single-analyzer driver often used to provide a +// convenient command alongside each analyzer), and analysistest, the +// test driver. +// +// By contrast, the 'go vet' command is based on unitchecker, an +// analysis driver that uses separate analysis--analogous to separate +// compilation--with file-based intermediate results. Like separate +// compilation, it is more scalable, especially for incremental +// analysis of large code bases. Commands based on multichecker and +// singlechecker are capable of detecting when they are being invoked +// by "go vet -vettool=exe" and instead dispatching to unitchecker. +// +// Programs built using this package will, in general, not be usable +// in that way. This package is intended only for use in applications +// that invoke the analysis driver as a subroutine, and need to insert +// additional steps before or after the analysis. +// +// See the Example of how to build a complete analysis driver program. +package checker + +import ( + "bytes" + "encoding/gob" + "fmt" + "go/types" + "io" + "log" + "reflect" + "sort" + "strings" + "sync" + "time" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/internal" + "golang.org/x/tools/go/analysis/internal/analysisflags" + "golang.org/x/tools/go/packages" + "golang.org/x/tools/internal/analysisinternal" +) + +// Options specifies options that control the analysis driver. 
+type Options struct { + // These options correspond to existing flags exposed by multichecker: + Sequential bool // disable parallelism + SanityCheck bool // check fact encoding is ok and deterministic + FactLog io.Writer // if non-nil, log each exported fact to it + + // TODO(adonovan): add ReadFile so that an Overlay specified + // in the [packages.Config] can be communicated via + // Pass.ReadFile to each Analyzer. +} + +// Graph holds the results of a round of analysis, including the graph +// of requested actions (analyzers applied to packages) plus any +// dependent actions that it was necessary to compute. +type Graph struct { + // Roots contains the roots of the action graph. + // Each node (a, p) in the action graph represents the + // application of one analyzer a to one package p. + // (A node thus corresponds to one analysis.Pass instance.) + // Roots holds one action per element of the product + // of the analyzers × packages arguments to Analyze, + // in unspecified order. + // + // Each element of Action.Deps represents an edge in the + // action graph: a dependency from one action to another. + // An edge of the form (a, p) -> (a, p2) indicates that the + // analysis of package p requires information ("facts") from + // the same analyzer applied to one of p's dependencies, p2. + // An edge of the form (a, p) -> (a2, p) indicates that the + // analysis of package p requires information ("results") + // from a different analyzer a2 applied to the same package. + // These two kind of edges are called "vertical" and "horizontal", + // respectively. + Roots []*Action +} + +// All returns an iterator over the action graph in depth-first postorder. +// +// Example: +// +// for act := range graph.All() { +// ... +// } +// +// Clients using go1.22 should iterate using the code below and may +// not assume anything else about the result: +// +// graph.All()(func (act *Action) bool { +// ... 
+// }) +func (g *Graph) All() actionSeq { + return func(yield func(*Action) bool) { + forEach(g.Roots, func(act *Action) error { + if !yield(act) { + return io.EOF // any error will do + } + return nil + }) + } +} + +// An Action represents one unit of analysis work by the driver: the +// application of one analysis to one package. It provides the inputs +// to and records the outputs of a single analysis.Pass. +// +// Actions form a DAG, both within a package (as different analyzers +// are applied, either in sequence or parallel), and across packages +// (as dependencies are analyzed). +type Action struct { + Analyzer *analysis.Analyzer + Package *packages.Package + IsRoot bool // whether this is a root node of the graph + Deps []*Action + Result any // computed result of Analyzer.run, if any (and if IsRoot) + Err error // error result of Analyzer.run + Diagnostics []analysis.Diagnostic + Duration time.Duration // execution time of this step + + opts *Options + once sync.Once + pass *analysis.Pass + objectFacts map[objectFactKey]analysis.Fact + packageFacts map[packageFactKey]analysis.Fact + inputs map[*analysis.Analyzer]any +} + +func (act *Action) String() string { + return fmt.Sprintf("%s@%s", act.Analyzer, act.Package) +} + +// Analyze runs the specified analyzers on the initial packages. +// +// The initial packages and all dependencies must have been loaded +// using the [packages.LoadAllSyntax] flag, Analyze may need to run +// some analyzer (those that consume and produce facts) on +// dependencies too. +// +// On success, it returns a Graph of actions whose Roots hold one +// item per (a, p) in the cross-product of analyzers and pkgs. +// +// If opts is nil, it is equivalent to new(Options). +func Analyze(analyzers []*analysis.Analyzer, pkgs []*packages.Package, opts *Options) (*Graph, error) { + if opts == nil { + opts = new(Options) + } + + if err := analysis.Validate(analyzers); err != nil { + return nil, err + } + + // Construct the action graph. 
+ // + // Each graph node (action) is one unit of analysis. + // Edges express package-to-package (vertical) dependencies, + // and analysis-to-analysis (horizontal) dependencies. + type key struct { + a *analysis.Analyzer + pkg *packages.Package + } + actions := make(map[key]*Action) + + var mkAction func(a *analysis.Analyzer, pkg *packages.Package) *Action + mkAction = func(a *analysis.Analyzer, pkg *packages.Package) *Action { + k := key{a, pkg} + act, ok := actions[k] + if !ok { + act = &Action{Analyzer: a, Package: pkg, opts: opts} + + // Add a dependency on each required analyzers. + for _, req := range a.Requires { + act.Deps = append(act.Deps, mkAction(req, pkg)) + } + + // An analysis that consumes/produces facts + // must run on the package's dependencies too. + if len(a.FactTypes) > 0 { + paths := make([]string, 0, len(pkg.Imports)) + for path := range pkg.Imports { + paths = append(paths, path) + } + sort.Strings(paths) // for determinism + for _, path := range paths { + dep := mkAction(a, pkg.Imports[path]) + act.Deps = append(act.Deps, dep) + } + } + + actions[k] = act + } + return act + } + + // Build nodes for initial packages. + var roots []*Action + for _, a := range analyzers { + for _, pkg := range pkgs { + root := mkAction(a, pkg) + root.IsRoot = true + roots = append(roots, root) + } + } + + // Execute the graph in parallel. + execAll(roots) + + // Ensure that only root Results are visible to caller. + // (The others are considered temporary intermediaries.) + // TODO(adonovan): opt: clear them earlier, so we can + // release large data structures like SSA sooner. + for _, act := range actions { + if !act.IsRoot { + act.Result = nil + } + } + + return &Graph{Roots: roots}, nil +} + +func init() { + // Allow analysistest to access Action.pass, + // for its legacy Result data type. 
+ internal.Pass = func(x any) *analysis.Pass { return x.(*Action).pass } +} + +type objectFactKey struct { + obj types.Object + typ reflect.Type +} + +type packageFactKey struct { + pkg *types.Package + typ reflect.Type +} + +func execAll(actions []*Action) { + var wg sync.WaitGroup + for _, act := range actions { + wg.Add(1) + work := func(act *Action) { + act.exec() + wg.Done() + } + if act.opts.Sequential { + work(act) + } else { + go work(act) + } + } + wg.Wait() +} + +func (act *Action) exec() { act.once.Do(act.execOnce) } + +func (act *Action) execOnce() { + // Analyze dependencies. + execAll(act.Deps) + + // Record time spent in this node but not its dependencies. + // In parallel mode, due to GC/scheduler contention, the + // time is 5x higher than in sequential mode, even with a + // semaphore limiting the number of threads here. + // So use -debug=tp. + t0 := time.Now() + defer func() { act.Duration = time.Since(t0) }() + + // Report an error if any dependency failed. + var failed []string + for _, dep := range act.Deps { + if dep.Err != nil { + failed = append(failed, dep.String()) + } + } + if failed != nil { + sort.Strings(failed) + act.Err = fmt.Errorf("failed prerequisites: %s", strings.Join(failed, ", ")) + return + } + + // Plumb the output values of the dependencies + // into the inputs of this action. Also facts. + inputs := make(map[*analysis.Analyzer]any) + act.objectFacts = make(map[objectFactKey]analysis.Fact) + act.packageFacts = make(map[packageFactKey]analysis.Fact) + for _, dep := range act.Deps { + if dep.Package == act.Package { + // Same package, different analysis (horizontal edge): + // in-memory outputs of prerequisite analyzers + // become inputs to this analysis pass. + inputs[dep.Analyzer] = dep.Result + + } else if dep.Analyzer == act.Analyzer { // (always true) + // Same analysis, different package (vertical edge): + // serialized facts produced by prerequisite analysis + // become available to this analysis pass. 
+ inheritFacts(act, dep) + } + } + + // Quick (nonexhaustive) check that the correct go/packages mode bits were used. + // (If there were errors, all bets are off.) + if pkg := act.Package; pkg.Errors == nil { + if pkg.Name == "" || pkg.PkgPath == "" || pkg.Types == nil || pkg.Fset == nil || pkg.TypesSizes == nil { + panic("packages must be loaded with packages.LoadSyntax mode") + } + } + + module := &analysis.Module{} // possibly empty (non nil) in go/analysis drivers. + if mod := act.Package.Module; mod != nil { + module.Path = mod.Path + module.Version = mod.Version + module.GoVersion = mod.GoVersion + } + + // Run the analysis. + pass := &analysis.Pass{ + Analyzer: act.Analyzer, + Fset: act.Package.Fset, + Files: act.Package.Syntax, + OtherFiles: act.Package.OtherFiles, + IgnoredFiles: act.Package.IgnoredFiles, + Pkg: act.Package.Types, + TypesInfo: act.Package.TypesInfo, + TypesSizes: act.Package.TypesSizes, + TypeErrors: act.Package.TypeErrors, + Module: module, + + ResultOf: inputs, + Report: func(d analysis.Diagnostic) { act.Diagnostics = append(act.Diagnostics, d) }, + ImportObjectFact: act.ObjectFact, + ExportObjectFact: act.exportObjectFact, + ImportPackageFact: act.PackageFact, + ExportPackageFact: act.exportPackageFact, + AllObjectFacts: act.AllObjectFacts, + AllPackageFacts: act.AllPackageFacts, + } + pass.ReadFile = analysisinternal.MakeReadFile(pass) + act.pass = pass + + act.Result, act.Err = func() (any, error) { + if act.Package.IllTyped && !pass.Analyzer.RunDespiteErrors { + return nil, fmt.Errorf("analysis skipped due to errors in package") + } + + result, err := pass.Analyzer.Run(pass) + if err != nil { + return nil, err + } + + // correct result type? 
+ if got, want := reflect.TypeOf(result), pass.Analyzer.ResultType; got != want { + return nil, fmt.Errorf( + "internal error: on package %s, analyzer %s returned a result of type %v, but declared ResultType %v", + pass.Pkg.Path(), pass.Analyzer, got, want) + } + + // resolve diagnostic URLs + for i := range act.Diagnostics { + url, err := analysisflags.ResolveURL(act.Analyzer, act.Diagnostics[i]) + if err != nil { + return nil, err + } + act.Diagnostics[i].URL = url + } + return result, nil + }() + + // Help detect (disallowed) calls after Run. + pass.ExportObjectFact = nil + pass.ExportPackageFact = nil +} + +// inheritFacts populates act.facts with +// those it obtains from its dependency, dep. +func inheritFacts(act, dep *Action) { + for key, fact := range dep.objectFacts { + // Filter out facts related to objects + // that are irrelevant downstream + // (equivalently: not in the compiler export data). + if !exportedFrom(key.obj, dep.Package.Types) { + if false { + log.Printf("%v: discarding %T fact from %s for %s: %s", act, fact, dep, key.obj, fact) + } + continue + } + + // Optionally serialize/deserialize fact + // to verify that it works across address spaces. + if act.opts.SanityCheck { + encodedFact, err := codeFact(fact) + if err != nil { + log.Panicf("internal error: encoding of %T fact failed in %v", fact, act) + } + fact = encodedFact + } + + if false { + log.Printf("%v: inherited %T fact for %s: %s", act, fact, key.obj, fact) + } + act.objectFacts[key] = fact + } + + for key, fact := range dep.packageFacts { + // TODO: filter out facts that belong to + // packages not mentioned in the export data + // to prevent side channels. + // + // The Pass.All{Object,Package}Facts accessors expose too much: + // all facts, of all types, for all dependencies in the action + // graph. 
Not only does the representation grow quadratically, + // but it violates the separate compilation paradigm, allowing + // analysis implementations to communicate with indirect + // dependencies that are not mentioned in the export data. + // + // It's not clear how to fix this short of a rather expensive + // filtering step after each action that enumerates all the + // objects that would appear in export data, and deletes + // facts associated with objects not in this set. + + // Optionally serialize/deserialize fact + // to verify that it works across address spaces + // and is deterministic. + if act.opts.SanityCheck { + encodedFact, err := codeFact(fact) + if err != nil { + log.Panicf("internal error: encoding of %T fact failed in %v", fact, act) + } + fact = encodedFact + } + + if false { + log.Printf("%v: inherited %T fact for %s: %s", act, fact, key.pkg.Path(), fact) + } + act.packageFacts[key] = fact + } +} + +// codeFact encodes then decodes a fact, +// just to exercise that logic. +func codeFact(fact analysis.Fact) (analysis.Fact, error) { + // We encode facts one at a time. + // A real modular driver would emit all facts + // into one encoder to improve gob efficiency. + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(fact); err != nil { + return nil, err + } + + // Encode it twice and assert that we get the same bits. + // This helps detect nondeterministic Gob encoding (e.g. of maps). + var buf2 bytes.Buffer + if err := gob.NewEncoder(&buf2).Encode(fact); err != nil { + return nil, err + } + if !bytes.Equal(buf.Bytes(), buf2.Bytes()) { + return nil, fmt.Errorf("encoding of %T fact is nondeterministic", fact) + } + + new := reflect.New(reflect.TypeOf(fact).Elem()).Interface().(analysis.Fact) + if err := gob.NewDecoder(&buf).Decode(new); err != nil { + return nil, err + } + return new, nil +} + +// exportedFrom reports whether obj may be visible to a package that imports pkg. 
+// This includes not just the exported members of pkg, but also unexported +// constants, types, fields, and methods, perhaps belonging to other packages, +// that find there way into the API. +// This is an overapproximation of the more accurate approach used by +// gc export data, which walks the type graph, but it's much simpler. +// +// TODO(adonovan): do more accurate filtering by walking the type graph. +func exportedFrom(obj types.Object, pkg *types.Package) bool { + switch obj := obj.(type) { + case *types.Func: + return obj.Exported() && obj.Pkg() == pkg || + obj.Type().(*types.Signature).Recv() != nil + case *types.Var: + if obj.IsField() { + return true + } + // we can't filter more aggressively than this because we need + // to consider function parameters exported, but have no way + // of telling apart function parameters from local variables. + return obj.Pkg() == pkg + case *types.TypeName, *types.Const: + return true + } + return false // Nil, Builtin, Label, or PkgName +} + +// ObjectFact retrieves a fact associated with obj, +// and returns true if one was found. +// Given a value ptr of type *T, where *T satisfies Fact, +// ObjectFact copies the value to *ptr. +// +// See documentation at ImportObjectFact field of [analysis.Pass]. +func (act *Action) ObjectFact(obj types.Object, ptr analysis.Fact) bool { + if obj == nil { + panic("nil object") + } + key := objectFactKey{obj, factType(ptr)} + if v, ok := act.objectFacts[key]; ok { + reflect.ValueOf(ptr).Elem().Set(reflect.ValueOf(v).Elem()) + return true + } + return false +} + +// exportObjectFact implements Pass.ExportObjectFact. 
+func (act *Action) exportObjectFact(obj types.Object, fact analysis.Fact) { + if act.pass.ExportObjectFact == nil { + log.Panicf("%s: Pass.ExportObjectFact(%s, %T) called after Run", act, obj, fact) + } + + if obj.Pkg() != act.Package.Types { + log.Panicf("internal error: in analysis %s of package %s: Fact.Set(%s, %T): can't set facts on objects belonging another package", + act.Analyzer, act.Package, obj, fact) + } + + key := objectFactKey{obj, factType(fact)} + act.objectFacts[key] = fact // clobber any existing entry + if log := act.opts.FactLog; log != nil { + objstr := types.ObjectString(obj, (*types.Package).Name) + fmt.Fprintf(log, "%s: object %s has fact %s\n", + act.Package.Fset.Position(obj.Pos()), objstr, fact) + } +} + +// AllObjectFacts returns a new slice containing all object facts of +// the analysis's FactTypes in unspecified order. +// +// See documentation at AllObjectFacts field of [analysis.Pass]. +func (act *Action) AllObjectFacts() []analysis.ObjectFact { + facts := make([]analysis.ObjectFact, 0, len(act.objectFacts)) + for k, fact := range act.objectFacts { + facts = append(facts, analysis.ObjectFact{Object: k.obj, Fact: fact}) + } + return facts +} + +// PackageFact retrieves a fact associated with package pkg, +// which must be this package or one of its dependencies. +// +// See documentation at ImportObjectFact field of [analysis.Pass]. +func (act *Action) PackageFact(pkg *types.Package, ptr analysis.Fact) bool { + if pkg == nil { + panic("nil package") + } + key := packageFactKey{pkg, factType(ptr)} + if v, ok := act.packageFacts[key]; ok { + reflect.ValueOf(ptr).Elem().Set(reflect.ValueOf(v).Elem()) + return true + } + return false +} + +// exportPackageFact implements Pass.ExportPackageFact. 
+func (act *Action) exportPackageFact(fact analysis.Fact) { + if act.pass.ExportPackageFact == nil { + log.Panicf("%s: Pass.ExportPackageFact(%T) called after Run", act, fact) + } + + key := packageFactKey{act.pass.Pkg, factType(fact)} + act.packageFacts[key] = fact // clobber any existing entry + if log := act.opts.FactLog; log != nil { + fmt.Fprintf(log, "%s: package %s has fact %s\n", + act.Package.Fset.Position(act.pass.Files[0].Pos()), act.pass.Pkg.Path(), fact) + } +} + +func factType(fact analysis.Fact) reflect.Type { + t := reflect.TypeOf(fact) + if t.Kind() != reflect.Ptr { + log.Fatalf("invalid Fact type: got %T, want pointer", fact) + } + return t +} + +// AllPackageFacts returns a new slice containing all package +// facts of the analysis's FactTypes in unspecified order. +// +// See documentation at AllPackageFacts field of [analysis.Pass]. +func (act *Action) AllPackageFacts() []analysis.PackageFact { + facts := make([]analysis.PackageFact, 0, len(act.packageFacts)) + for k, fact := range act.packageFacts { + facts = append(facts, analysis.PackageFact{Package: k.pkg, Fact: fact}) + } + return facts +} + +// forEach is a utility function for traversing the action graph. It +// applies function f to each action in the graph reachable from +// roots, in depth-first postorder. If any call to f returns an error, +// the traversal is aborted and ForEach returns the error. 
+func forEach(roots []*Action, f func(*Action) error) error { + seen := make(map[*Action]bool) + var visitAll func(actions []*Action) error + visitAll = func(actions []*Action) error { + for _, act := range actions { + if !seen[act] { + seen[act] = true + if err := visitAll(act.Deps); err != nil { + return err + } + if err := f(act); err != nil { + return err + } + } + } + return nil + } + return visitAll(roots) +} diff --git a/go/analysis/checker/example_test.go b/go/analysis/checker/example_test.go new file mode 100644 index 00000000000..91beeb1ed3f --- /dev/null +++ b/go/analysis/checker/example_test.go @@ -0,0 +1,104 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !wasm + +// The example command demonstrates a simple go/packages-based +// analysis driver program. +package checker_test + +import ( + "fmt" + "log" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/checker" + "golang.org/x/tools/go/packages" +) + +func Example() { + // Load packages: just this one. + // + // There may be parse or type errors among the + // initial packages or their dependencies, + // but the analysis driver can handle faulty inputs, + // as can some analyzers. + cfg := &packages.Config{Mode: packages.LoadAllSyntax} + initial, err := packages.Load(cfg, ".") + if err != nil { + log.Fatal(err) // failure to enumerate packages + } + if len(initial) == 0 { + log.Fatalf("no initial packages") + } + + // Run analyzers (just one) on packages. + analyzers := []*analysis.Analyzer{minmaxpkg} + graph, err := checker.Analyze(analyzers, initial, nil) + if err != nil { + log.Fatal(err) + } + + // Print information about the results of each + // analysis action, including all dependencies. + // + // Clients using Go 1.23 can say: + // for act := range graph.All() { ... 
} + graph.All()(func(act *checker.Action) bool { + // Print information about the Action, e.g. + // + // act.String() + // act.Result + // act.Err + // act.Diagnostics + // + // (We don't actually print anything here + // as the output would vary over time, + // which is unsuitable for a test.) + return true + }) + + // Print the minmaxpkg package fact computed for this package. + root := graph.Roots[0] + fact := new(minmaxpkgFact) + if root.PackageFact(root.Package.Types, fact) { + fmt.Printf("min=%s max=%s", fact.min, fact.max) + } + // Output: + // min=bufio max=unsafe +} + +// minmaxpkg is a trival example analyzer that uses package facts to +// compute information from the entire dependency graph. +var minmaxpkg = &analysis.Analyzer{ + Name: "minmaxpkg", + Doc: "Finds the min- and max-named packages among our dependencies.", + Run: run, + FactTypes: []analysis.Fact{(*minmaxpkgFact)(nil)}, +} + +// A package fact that records the alphabetically min and max-named +// packages among the dependencies of this package. +// (This property was chosen because it is relatively stable +// as the codebase evolves, avoiding frequent test breakage.) +type minmaxpkgFact struct{ min, max string } + +func (*minmaxpkgFact) AFact() {} + +func run(pass *analysis.Pass) (any, error) { + // Compute the min and max of the facts from our direct imports. + f := &minmaxpkgFact{min: pass.Pkg.Path(), max: pass.Pkg.Path()} + for _, imp := range pass.Pkg.Imports() { + if f2 := new(minmaxpkgFact); pass.ImportPackageFact(imp, f2) { + if f2.min < f.min { + f.min = f2.min + } + if f2.max > f.max { + f.max = f2.max + } + } + } + pass.ExportPackageFact(f) + return nil, nil +} diff --git a/go/analysis/checker/iter_go122.go b/go/analysis/checker/iter_go122.go new file mode 100644 index 00000000000..cd25cce035c --- /dev/null +++ b/go/analysis/checker/iter_go122.go @@ -0,0 +1,10 @@ +// Copyright 2024 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !go1.23 + +package checker + +// This type is a placeholder for go1.23's iter.Seq[*Action]. +type actionSeq func(yield func(*Action) bool) diff --git a/internal/versions/constraint_go121.go b/go/analysis/checker/iter_go123.go similarity index 53% rename from internal/versions/constraint_go121.go rename to go/analysis/checker/iter_go123.go index 38011407d5f..e8278a9c1a4 100644 --- a/internal/versions/constraint_go121.go +++ b/go/analysis/checker/iter_go123.go @@ -2,13 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build go1.21 -// +build go1.21 +//go:build go1.23 -package versions +package checker -import "go/build/constraint" +import "iter" -func init() { - ConstraintGoVersion = constraint.GoVersion -} +type actionSeq = iter.Seq[*Action] diff --git a/go/analysis/checker/print.go b/go/analysis/checker/print.go new file mode 100644 index 00000000000..d7c0430117f --- /dev/null +++ b/go/analysis/checker/print.go @@ -0,0 +1,88 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package checker + +// This file defines helpers for printing analysis results. +// They should all be pure functions. + +import ( + "bytes" + "fmt" + "go/token" + "io" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/internal/analysisflags" +) + +// PrintText emits diagnostics as plain text to w. +// +// If contextLines is nonnegative, it also prints the +// offending line, plus that many lines of context +// before and after the line. 
+func (g *Graph) PrintText(w io.Writer, contextLines int) error { + return writeTextDiagnostics(w, g.Roots, contextLines) +} + +func writeTextDiagnostics(w io.Writer, roots []*Action, contextLines int) error { + // De-duplicate diagnostics by position (not token.Pos) to + // avoid double-reporting in source files that belong to + // multiple packages, such as foo and foo.test. + // (We cannot assume that such repeated files were parsed + // only once and use syntax nodes as the key.) + type key struct { + pos token.Position + end token.Position + *analysis.Analyzer + message string + } + seen := make(map[key]bool) + + // TODO(adonovan): opt: plumb errors back from PrintPlain and avoid buffer. + buf := new(bytes.Buffer) + forEach(roots, func(act *Action) error { + if act.Err != nil { + fmt.Fprintf(w, "%s: %v\n", act.Analyzer.Name, act.Err) + } else if act.IsRoot { + for _, diag := range act.Diagnostics { + // We don't display Analyzer.Name/diag.Category + // as most users don't care. + + posn := act.Package.Fset.Position(diag.Pos) + end := act.Package.Fset.Position(diag.End) + k := key{posn, end, act.Analyzer, diag.Message} + if seen[k] { + continue // duplicate + } + seen[k] = true + + analysisflags.PrintPlain(buf, act.Package.Fset, contextLines, diag) + } + } + return nil + }) + _, err := w.Write(buf.Bytes()) + return err +} + +// PrintJSON emits diagnostics in JSON form to w. +// Diagnostics are shown only for the root nodes, +// but errors (if any) are shown for all dependencies. 
+func (g *Graph) PrintJSON(w io.Writer) error { + return writeJSONDiagnostics(w, g.Roots) +} + +func writeJSONDiagnostics(w io.Writer, roots []*Action) error { + tree := make(analysisflags.JSONTree) + forEach(roots, func(act *Action) error { + var diags []analysis.Diagnostic + if act.IsRoot { + diags = act.Diagnostics + } + tree.Add(act.Package.Fset, act.Package.ID, act.Analyzer.Name, diags, act.Err) + return nil + }) + return tree.Print(w) +} diff --git a/go/analysis/internal/analysisflags/flags.go b/go/analysis/internal/analysisflags/flags.go index ff14ff58f9c..1282e70d41f 100644 --- a/go/analysis/internal/analysisflags/flags.go +++ b/go/analysis/internal/analysisflags/flags.go @@ -316,15 +316,22 @@ var vetLegacyFlags = map[string]string{ } // ---- output helpers common to all drivers ---- +// +// These functions should not depend on global state (flags)! +// Really they belong in a different package. + +// TODO(adonovan): don't accept an io.Writer if we don't report errors. +// Either accept a bytes.Buffer (infallible), or return a []byte. -// PrintPlain prints a diagnostic in plain text form, -// with context specified by the -c flag. -func PrintPlain(fset *token.FileSet, diag analysis.Diagnostic) { +// PrintPlain prints a diagnostic in plain text form. +// If contextLines is nonnegative, it also prints the +// offending line plus this many lines of context. +func PrintPlain(out io.Writer, fset *token.FileSet, contextLines int, diag analysis.Diagnostic) { posn := fset.Position(diag.Pos) - fmt.Fprintf(os.Stderr, "%s: %s\n", posn, diag.Message) + fmt.Fprintf(out, "%s: %s\n", posn, diag.Message) - // -c=N: show offending line plus N lines of context. - if Context >= 0 { + // show offending line plus N lines of context. 
+ if contextLines >= 0 { posn := fset.Position(diag.Pos) end := fset.Position(diag.End) if !end.IsValid() { @@ -332,9 +339,9 @@ func PrintPlain(fset *token.FileSet, diag analysis.Diagnostic) { } data, _ := os.ReadFile(posn.Filename) lines := strings.Split(string(data), "\n") - for i := posn.Line - Context; i <= end.Line+Context; i++ { + for i := posn.Line - contextLines; i <= end.Line+contextLines; i++ { if 1 <= i && i <= len(lines) { - fmt.Fprintf(os.Stderr, "%d\t%s\n", i, lines[i-1]) + fmt.Fprintf(out, "%d\t%s\n", i, lines[i-1]) } } } @@ -438,10 +445,11 @@ func (tree JSONTree) Add(fset *token.FileSet, id, name string, diags []analysis. } } -func (tree JSONTree) Print() { +func (tree JSONTree) Print(out io.Writer) error { data, err := json.MarshalIndent(tree, "", "\t") if err != nil { log.Panicf("internal error: JSON marshaling failed: %v", err) } - fmt.Printf("%s\n", data) + _, err = fmt.Fprintf(out, "%s\n", data) + return err } diff --git a/go/analysis/internal/checker/checker.go b/go/analysis/internal/checker/checker.go index 8a802831c39..0c2fc5e59db 100644 --- a/go/analysis/internal/checker/checker.go +++ b/go/analysis/internal/checker/checker.go @@ -2,35 +2,37 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package checker defines the implementation of the checker commands. -// The same code drives the multi-analysis driver, the single-analysis -// driver that is conventionally provided for convenience along with -// each analysis package, and the test driver. +// Package internal/checker defines various implementation helpers for +// the singlechecker and multichecker packages, which provide the +// complete main function for an analysis driver executable +// based on go/packages. +// +// (Note: it is not used by the public 'checker' package, since the +// latter provides a set of pure functions for use as building blocks.) 
package checker +// TODO(adonovan): publish the JSON schema in go/analysis or analysisjson. + import ( - "bytes" - "encoding/gob" "flag" "fmt" "go/format" "go/token" - "go/types" + "io" + "io/ioutil" "log" "os" - "reflect" "runtime" "runtime/pprof" "runtime/trace" "sort" "strings" - "sync" "time" "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/checker" "golang.org/x/tools/go/analysis/internal/analysisflags" "golang.org/x/tools/go/packages" - "golang.org/x/tools/internal/analysisinternal" "golang.org/x/tools/internal/diff" "golang.org/x/tools/internal/robustio" ) @@ -78,7 +80,7 @@ func RegisterFlags() { // It provides most of the logic for the main functions of both the // singlechecker and the multi-analysis commands. // It returns the appropriate exit code. -func Run(args []string, analyzers []*analysis.Analyzer) (exitcode int) { +func Run(args []string, analyzers []*analysis.Analyzer) int { if CPUProfile != "" { f, err := os.Create(CPUProfile) if err != nil { @@ -144,15 +146,29 @@ func Run(args []string, analyzers []*analysis.Analyzer) (exitcode int) { pkgsExitCode = 1 } - // Run the analyzers. On each package with (transitive) - // errors, we run only the subset of analyzers that are - // marked (and whose transitive requirements are also - // marked) with RunDespiteErrors. - roots := analyze(initial, analyzers) + var factLog io.Writer + if dbg('f') { + factLog = os.Stderr + } + + // Run the analysis. + opts := &checker.Options{ + SanityCheck: dbg('s'), + Sequential: dbg('p'), + FactLog: factLog, + } + if dbg('v') { + log.Printf("building graph of analysis passes") + } + graph, err := checker.Analyze(analyzers, initial, opts) + if err != nil { + log.Print(err) + return 1 + } - // Apply fixes. + // Apply all fixes from the root actions. if Fix { - if err := applyFixes(roots); err != nil { + if err := applyFixes(graph.Roots); err != nil { // Fail when applying fixes failed. 
log.Print(err) return 1 @@ -163,12 +179,79 @@ func Run(args []string, analyzers []*analysis.Analyzer) (exitcode int) { // are errors in the packages, this will have 0 exit // code. Otherwise, we prefer to return exit code // indicating diagnostics. - if diagExitCode := printDiagnostics(roots); diagExitCode != 0 { + if diagExitCode := printDiagnostics(graph); diagExitCode != 0 { return diagExitCode // there were diagnostics } return pkgsExitCode // package errors but no diagnostics } +// printDiagnostics prints diagnostics in text or JSON form +// and returns the appropriate exit code. +func printDiagnostics(graph *checker.Graph) (exitcode int) { + // Print the results. + // With -json, the exit code is always zero. + if analysisflags.JSON { + if err := graph.PrintJSON(os.Stdout); err != nil { + return 1 + } + } else { + if err := graph.PrintText(os.Stderr, analysisflags.Context); err != nil { + return 1 + } + + // Compute the exit code. + var numErrors, rootDiags int + // TODO(adonovan): use "for act := range graph.All() { ... }" in go1.23. + graph.All()(func(act *checker.Action) bool { + if act.Err != nil { + numErrors++ + } else if act.IsRoot { + rootDiags += len(act.Diagnostics) + } + return true + }) + if numErrors > 0 { + exitcode = 1 // analysis failed, at least partially + } else if rootDiags > 0 { + exitcode = 3 // successfully produced diagnostics + } + } + + // Print timing info. + if dbg('t') { + if !dbg('p') { + log.Println("Warning: times are mostly GC/scheduler noise; use -debug=tp to disable parallelism") + } + + var list []*checker.Action + var total time.Duration + // TODO(adonovan): use "for act := range graph.All() { ... }" in go1.23. + graph.All()(func(act *checker.Action) bool { + list = append(list, act) + total += act.Duration + return true + }) + + // Print actions accounting for 90% of the total. 
+ sort.Slice(list, func(i, j int) bool { + return list[i].Duration > list[j].Duration + }) + var sum time.Duration + for _, act := range list { + fmt.Fprintf(os.Stderr, "%s\t%s\n", act.Duration, act) + sum += act.Duration + if sum >= total*9/10 { + break + } + } + if total > sum { + fmt.Fprintf(os.Stderr, "%s\tall others\n", total-sum) + } + } + + return exitcode +} + // load loads the initial packages. Returns only top-level loading // errors. Does not consider errors in packages. func load(patterns []string, allSyntax bool) ([]*packages.Package, error) { @@ -188,149 +271,37 @@ func load(patterns []string, allSyntax bool) ([]*packages.Package, error) { return initial, err } -// TestAnalyzer applies an analyzer to a set of packages (and their -// dependencies if necessary) and returns the results. -// The analyzer must be valid according to [analysis.Validate]. -// -// Facts about pkg are returned in a map keyed by object; package facts -// have a nil key. -// -// This entry point is used only by analysistest. -func TestAnalyzer(a *analysis.Analyzer, pkgs []*packages.Package) []*TestAnalyzerResult { - var results []*TestAnalyzerResult - for _, act := range analyze(pkgs, []*analysis.Analyzer{a}) { - facts := make(map[types.Object][]analysis.Fact) - for key, fact := range act.objectFacts { - if key.obj.Pkg() == act.pass.Pkg { - facts[key.obj] = append(facts[key.obj], fact) - } - } - for key, fact := range act.packageFacts { - if key.pkg == act.pass.Pkg { - facts[nil] = append(facts[nil], fact) - } - } - - results = append(results, &TestAnalyzerResult{act.pass, act.diagnostics, facts, act.result, act.err}) - } - return results -} - -type TestAnalyzerResult struct { - Pass *analysis.Pass - Diagnostics []analysis.Diagnostic - Facts map[types.Object][]analysis.Fact - Result interface{} - Err error -} - -func analyze(pkgs []*packages.Package, analyzers []*analysis.Analyzer) []*action { - // Construct the action graph. 
- if dbg('v') { - log.Printf("building graph of analysis passes") - } - - // Each graph node (action) is one unit of analysis. - // Edges express package-to-package (vertical) dependencies, - // and analysis-to-analysis (horizontal) dependencies. - type key struct { - *analysis.Analyzer - *packages.Package - } - actions := make(map[key]*action) - - var mkAction func(a *analysis.Analyzer, pkg *packages.Package) *action - mkAction = func(a *analysis.Analyzer, pkg *packages.Package) *action { - k := key{a, pkg} - act, ok := actions[k] - if !ok { - act = &action{a: a, pkg: pkg} - - // Add a dependency on each required analyzers. - for _, req := range a.Requires { - act.deps = append(act.deps, mkAction(req, pkg)) - } - - // An analysis that consumes/produces facts - // must run on the package's dependencies too. - if len(a.FactTypes) > 0 { - paths := make([]string, 0, len(pkg.Imports)) - for path := range pkg.Imports { - paths = append(paths, path) - } - sort.Strings(paths) // for determinism - for _, path := range paths { - dep := mkAction(a, pkg.Imports[path]) - act.deps = append(act.deps, dep) - } - } - - actions[k] = act - } - return act - } - - // Build nodes for initial packages. - var roots []*action - for _, a := range analyzers { - for _, pkg := range pkgs { - root := mkAction(a, pkg) - root.isroot = true - roots = append(roots, root) - } - } - - // Execute the graph in parallel. - execAll(roots) - - return roots -} - -func applyFixes(roots []*action) error { - // visit all of the actions and accumulate the suggested edits. +// applyFixes applies suggested fixes associated with diagnostics +// reported by the specified actions. It verifies that edits do not +// conflict, even through file-system level aliases such as symbolic +// links, and then edits the files. +func applyFixes(actions []*checker.Action) error { + // Visit all of the actions and accumulate the suggested edits. 
paths := make(map[robustio.FileID]string) - editsByAction := make(map[robustio.FileID]map[*action][]diff.Edit) - visited := make(map[*action]bool) - var apply func(*action) error - var visitAll func(actions []*action) error - visitAll = func(actions []*action) error { - for _, act := range actions { - if !visited[act] { - visited[act] = true - if err := visitAll(act.deps); err != nil { - return err - } - if err := apply(act); err != nil { - return err - } - } - } - return nil - } - - apply = func(act *action) error { + editsByAction := make(map[robustio.FileID]map[*checker.Action][]diff.Edit) + for _, act := range actions { editsForTokenFile := make(map[*token.File][]diff.Edit) - for _, diag := range act.diagnostics { + for _, diag := range act.Diagnostics { for _, sf := range diag.SuggestedFixes { for _, edit := range sf.TextEdits { // Validate the edit. // Any error here indicates a bug in the analyzer. start, end := edit.Pos, edit.End - file := act.pkg.Fset.File(start) + file := act.Package.Fset.File(start) if file == nil { return fmt.Errorf("analysis %q suggests invalid fix: missing file info for pos (%v)", - act.a.Name, start) + act.Analyzer.Name, edit.Pos) } if !end.IsValid() { end = start } if start > end { return fmt.Errorf("analysis %q suggests invalid fix: pos (%v) > end (%v)", - act.a.Name, start, end) + act.Analyzer.Name, edit.Pos, edit.End) } if eof := token.Pos(file.Base() + file.Size()); end > eof { return fmt.Errorf("analysis %q suggests invalid fix: end (%v) past end of file (%v)", - act.a.Name, end, eof) + act.Analyzer.Name, edit.End, eof) } edit := diff.Edit{ Start: file.Offset(start), @@ -349,22 +320,17 @@ func applyFixes(roots []*action) error { } if _, hasId := paths[id]; !hasId { paths[id] = f.Name() - editsByAction[id] = make(map[*action][]diff.Edit) + editsByAction[id] = make(map[*checker.Action][]diff.Edit) } editsByAction[id][act] = edits } - return nil - } - - if err := visitAll(roots); err != nil { - return err } // Validate and group 
the edits to each actual file. editsByPath := make(map[string][]diff.Edit) for id, actToEdits := range editsByAction { path := paths[id] - actions := make([]*action, 0, len(actToEdits)) + actions := make([]*checker.Action, 0, len(actToEdits)) for act := range actToEdits { actions = append(actions, act) } @@ -373,7 +339,7 @@ func applyFixes(roots []*action) error { for _, act := range actions { edits := actToEdits[act] if _, invalid := validateEdits(edits); invalid > 0 { - name, x, y := act.a.Name, edits[invalid-1], edits[invalid] + name, x, y := act.Analyzer.Name, edits[invalid-1], edits[invalid] return diff3Conflict(path, name, name, []diff.Edit{x}, []diff.Edit{y}) } } @@ -382,7 +348,7 @@ func applyFixes(roots []*action) error { for j := range actions { for k := range actions[:j] { x, y := actions[j], actions[k] - if x.a.Name > y.a.Name { + if x.Analyzer.Name > y.Analyzer.Name { x, y = y, x } xedits, yedits := actToEdits[x], actToEdits[y] @@ -391,7 +357,7 @@ func applyFixes(roots []*action) error { // TODO: consider applying each action's consistent list of edits entirely, // and then using a three-way merge (such as GNU diff3) on the resulting // files to report more precisely the parts that actually conflict. - return diff3Conflict(path, x.a.Name, y.a.Name, xedits, yedits) + return diff3Conflict(path, x.Analyzer.Name, y.Analyzer.Name, xedits, yedits) } } } @@ -404,6 +370,7 @@ func applyFixes(roots []*action) error { } // Now we've got a set of valid edits for each file. Apply them. + // TODO(adonovan): don't abort the operation partway just because one file fails. for path, edits := range editsByPath { // TODO(adonovan): this should really work on the same // gulp from the file system that fed the analyzer (see #62292). @@ -462,7 +429,7 @@ func validateEdits(edits []diff.Edit) ([]diff.Edit, int) { // diff3Conflict returns an error describing two conflicting sets of // edits on a file at path. 
func diff3Conflict(path string, xlabel, ylabel string, xedits, yedits []diff.Edit) error { - contents, err := os.ReadFile(path) + contents, err := ioutil.ReadFile(path) if err != nil { return err } @@ -481,117 +448,6 @@ func diff3Conflict(path string, xlabel, ylabel string, xedits, yedits []diff.Edi xlabel, ylabel, path, xdiff, ydiff) } -// printDiagnostics prints the diagnostics for the root packages in either -// plain text or JSON format. JSON format also includes errors for any -// dependencies. -// -// It returns the exitcode: in plain mode, 0 for success, 1 for analysis -// errors, and 3 for diagnostics. We avoid 2 since the flag package uses -// it. JSON mode always succeeds at printing errors and diagnostics in a -// structured form to stdout. -func printDiagnostics(roots []*action) (exitcode int) { - // Print the output. - // - // Print diagnostics only for root packages, - // but errors for all packages. - printed := make(map[*action]bool) - var print func(*action) - var visitAll func(actions []*action) - visitAll = func(actions []*action) { - for _, act := range actions { - if !printed[act] { - printed[act] = true - visitAll(act.deps) - print(act) - } - } - } - - if analysisflags.JSON { - // JSON output - tree := make(analysisflags.JSONTree) - print = func(act *action) { - var diags []analysis.Diagnostic - if act.isroot { - diags = act.diagnostics - } - tree.Add(act.pkg.Fset, act.pkg.ID, act.a.Name, diags, act.err) - } - visitAll(roots) - tree.Print() - } else { - // plain text output - - // De-duplicate diagnostics by position (not token.Pos) to - // avoid double-reporting in source files that belong to - // multiple packages, such as foo and foo.test. 
- type key struct { - pos token.Position - end token.Position - *analysis.Analyzer - message string - } - seen := make(map[key]bool) - - print = func(act *action) { - if act.err != nil { - fmt.Fprintf(os.Stderr, "%s: %v\n", act.a.Name, act.err) - exitcode = 1 // analysis failed, at least partially - return - } - if act.isroot { - for _, diag := range act.diagnostics { - // We don't display a.Name/f.Category - // as most users don't care. - - posn := act.pkg.Fset.Position(diag.Pos) - end := act.pkg.Fset.Position(diag.End) - k := key{posn, end, act.a, diag.Message} - if seen[k] { - continue // duplicate - } - seen[k] = true - - analysisflags.PrintPlain(act.pkg.Fset, diag) - } - } - } - visitAll(roots) - - if exitcode == 0 && len(seen) > 0 { - exitcode = 3 // successfully produced diagnostics - } - } - - // Print timing info. - if dbg('t') { - if !dbg('p') { - log.Println("Warning: times are mostly GC/scheduler noise; use -debug=tp to disable parallelism") - } - var all []*action - var total time.Duration - for act := range printed { - all = append(all, act) - total += act.duration - } - sort.Slice(all, func(i, j int) bool { - return all[i].duration > all[j].duration - }) - - // Print actions accounting for 90% of the total. - var sum time.Duration - for _, act := range all { - fmt.Fprintf(os.Stderr, "%s\t%s\n", act.duration, act) - sum += act.duration - if sum >= total*9/10 { - break - } - } - } - - return exitcode -} - // needFacts reports whether any analysis required by the specified set // needs facts. If so, we must load the entire program from source. func needFacts(analyzers []*analysis.Analyzer) bool { @@ -612,373 +468,4 @@ func needFacts(analyzers []*analysis.Analyzer) bool { return false } -// An action represents one unit of analysis work: the application of -// one analysis to one package. 
Actions form a DAG, both within a -// package (as different analyzers are applied, either in sequence or -// parallel), and across packages (as dependencies are analyzed). -type action struct { - once sync.Once - a *analysis.Analyzer - pkg *packages.Package - pass *analysis.Pass - isroot bool - deps []*action - objectFacts map[objectFactKey]analysis.Fact - packageFacts map[packageFactKey]analysis.Fact - result interface{} - diagnostics []analysis.Diagnostic - err error - duration time.Duration -} - -type objectFactKey struct { - obj types.Object - typ reflect.Type -} - -type packageFactKey struct { - pkg *types.Package - typ reflect.Type -} - -func (act *action) String() string { - return fmt.Sprintf("%s@%s", act.a, act.pkg) -} - -func execAll(actions []*action) { - sequential := dbg('p') - var wg sync.WaitGroup - for _, act := range actions { - wg.Add(1) - work := func(act *action) { - act.exec() - wg.Done() - } - if sequential { - work(act) - } else { - go work(act) - } - } - wg.Wait() -} - -func (act *action) exec() { act.once.Do(act.execOnce) } - -func (act *action) execOnce() { - // Analyze dependencies. - execAll(act.deps) - - // TODO(adonovan): uncomment this during profiling. - // It won't build pre-go1.11 but conditional compilation - // using build tags isn't warranted. - // - // ctx, task := trace.NewTask(context.Background(), "exec") - // trace.Log(ctx, "pass", act.String()) - // defer task.End() - - // Record time spent in this node but not its dependencies. - // In parallel mode, due to GC/scheduler contention, the - // time is 5x higher than in sequential mode, even with a - // semaphore limiting the number of threads here. - // So use -debug=tp. - if dbg('t') { - t0 := time.Now() - defer func() { act.duration = time.Since(t0) }() - } - - // Report an error if any dependency failed. 
- var failed []string - for _, dep := range act.deps { - if dep.err != nil { - failed = append(failed, dep.String()) - } - } - if failed != nil { - sort.Strings(failed) - act.err = fmt.Errorf("failed prerequisites: %s", strings.Join(failed, ", ")) - return - } - - // Plumb the output values of the dependencies - // into the inputs of this action. Also facts. - inputs := make(map[*analysis.Analyzer]interface{}) - act.objectFacts = make(map[objectFactKey]analysis.Fact) - act.packageFacts = make(map[packageFactKey]analysis.Fact) - for _, dep := range act.deps { - if dep.pkg == act.pkg { - // Same package, different analysis (horizontal edge): - // in-memory outputs of prerequisite analyzers - // become inputs to this analysis pass. - inputs[dep.a] = dep.result - - } else if dep.a == act.a { // (always true) - // Same analysis, different package (vertical edge): - // serialized facts produced by prerequisite analysis - // become available to this analysis pass. - inheritFacts(act, dep) - } - } - - module := &analysis.Module{} // possibly empty (non nil) in go/analysis drivers. - if mod := act.pkg.Module; mod != nil { - module.Path = mod.Path - module.Version = mod.Version - module.GoVersion = mod.GoVersion - } - - // Run the analysis. 
- pass := &analysis.Pass{ - Analyzer: act.a, - Fset: act.pkg.Fset, - Files: act.pkg.Syntax, - OtherFiles: act.pkg.OtherFiles, - IgnoredFiles: act.pkg.IgnoredFiles, - Pkg: act.pkg.Types, - TypesInfo: act.pkg.TypesInfo, - TypesSizes: act.pkg.TypesSizes, - TypeErrors: act.pkg.TypeErrors, - Module: module, - - ResultOf: inputs, - Report: func(d analysis.Diagnostic) { act.diagnostics = append(act.diagnostics, d) }, - ImportObjectFact: act.importObjectFact, - ExportObjectFact: act.exportObjectFact, - ImportPackageFact: act.importPackageFact, - ExportPackageFact: act.exportPackageFact, - AllObjectFacts: act.allObjectFacts, - AllPackageFacts: act.allPackageFacts, - } - pass.ReadFile = analysisinternal.MakeReadFile(pass) - act.pass = pass - - var err error - if act.pkg.IllTyped && !pass.Analyzer.RunDespiteErrors { - err = fmt.Errorf("analysis skipped due to errors in package") - } else { - act.result, err = pass.Analyzer.Run(pass) - if err == nil { - if got, want := reflect.TypeOf(act.result), pass.Analyzer.ResultType; got != want { - err = fmt.Errorf( - "internal error: on package %s, analyzer %s returned a result of type %v, but declared ResultType %v", - pass.Pkg.Path(), pass.Analyzer, got, want) - } - } - } - if err == nil { // resolve diagnostic URLs - for i := range act.diagnostics { - if url, uerr := analysisflags.ResolveURL(act.a, act.diagnostics[i]); uerr == nil { - act.diagnostics[i].URL = url - } else { - err = uerr // keep the last error - } - } - } - act.err = err - - // disallow calls after Run - pass.ExportObjectFact = nil - pass.ExportPackageFact = nil -} - -// inheritFacts populates act.facts with -// those it obtains from its dependency, dep. -func inheritFacts(act, dep *action) { - serialize := dbg('s') - - for key, fact := range dep.objectFacts { - // Filter out facts related to objects - // that are irrelevant downstream - // (equivalently: not in the compiler export data). 
- if !exportedFrom(key.obj, dep.pkg.Types) { - if false { - log.Printf("%v: discarding %T fact from %s for %s: %s", act, fact, dep, key.obj, fact) - } - continue - } - - // Optionally serialize/deserialize fact - // to verify that it works across address spaces. - if serialize { - encodedFact, err := codeFact(fact) - if err != nil { - log.Panicf("internal error: encoding of %T fact failed in %v: %v", fact, act, err) - } - fact = encodedFact - } - - if false { - log.Printf("%v: inherited %T fact for %s: %s", act, fact, key.obj, fact) - } - act.objectFacts[key] = fact - } - - for key, fact := range dep.packageFacts { - // TODO: filter out facts that belong to - // packages not mentioned in the export data - // to prevent side channels. - - // Optionally serialize/deserialize fact - // to verify that it works across address spaces - // and is deterministic. - if serialize { - encodedFact, err := codeFact(fact) - if err != nil { - log.Panicf("internal error: encoding of %T fact failed in %v", fact, act) - } - fact = encodedFact - } - - if false { - log.Printf("%v: inherited %T fact for %s: %s", act, fact, key.pkg.Path(), fact) - } - act.packageFacts[key] = fact - } -} - -// codeFact encodes then decodes a fact, -// just to exercise that logic. -func codeFact(fact analysis.Fact) (analysis.Fact, error) { - // We encode facts one at a time. - // A real modular driver would emit all facts - // into one encoder to improve gob efficiency. - var buf bytes.Buffer - if err := gob.NewEncoder(&buf).Encode(fact); err != nil { - return nil, err - } - - // Encode it twice and assert that we get the same bits. - // This helps detect nondeterministic Gob encoding (e.g. of maps). 
- var buf2 bytes.Buffer - if err := gob.NewEncoder(&buf2).Encode(fact); err != nil { - return nil, err - } - if !bytes.Equal(buf.Bytes(), buf2.Bytes()) { - return nil, fmt.Errorf("encoding of %T fact is nondeterministic", fact) - } - - new := reflect.New(reflect.TypeOf(fact).Elem()).Interface().(analysis.Fact) - if err := gob.NewDecoder(&buf).Decode(new); err != nil { - return nil, err - } - return new, nil -} - -// exportedFrom reports whether obj may be visible to a package that imports pkg. -// This includes not just the exported members of pkg, but also unexported -// constants, types, fields, and methods, perhaps belonging to other packages, -// that find there way into the API. -// This is an overapproximation of the more accurate approach used by -// gc export data, which walks the type graph, but it's much simpler. -// -// TODO(adonovan): do more accurate filtering by walking the type graph. -func exportedFrom(obj types.Object, pkg *types.Package) bool { - switch obj := obj.(type) { - case *types.Func: - return obj.Exported() && obj.Pkg() == pkg || - obj.Type().(*types.Signature).Recv() != nil - case *types.Var: - if obj.IsField() { - return true - } - // we can't filter more aggressively than this because we need - // to consider function parameters exported, but have no way - // of telling apart function parameters from local variables. - return obj.Pkg() == pkg - case *types.TypeName, *types.Const: - return true - } - return false // Nil, Builtin, Label, or PkgName -} - -// importObjectFact implements Pass.ImportObjectFact. -// Given a non-nil pointer ptr of type *T, where *T satisfies Fact, -// importObjectFact copies the fact value to *ptr. 
-func (act *action) importObjectFact(obj types.Object, ptr analysis.Fact) bool { - if obj == nil { - panic("nil object") - } - key := objectFactKey{obj, factType(ptr)} - if v, ok := act.objectFacts[key]; ok { - reflect.ValueOf(ptr).Elem().Set(reflect.ValueOf(v).Elem()) - return true - } - return false -} - -// exportObjectFact implements Pass.ExportObjectFact. -func (act *action) exportObjectFact(obj types.Object, fact analysis.Fact) { - if act.pass.ExportObjectFact == nil { - log.Panicf("%s: Pass.ExportObjectFact(%s, %T) called after Run", act, obj, fact) - } - - if obj.Pkg() != act.pkg.Types { - log.Panicf("internal error: in analysis %s of package %s: Fact.Set(%s, %T): can't set facts on objects belonging another package", - act.a, act.pkg, obj, fact) - } - - key := objectFactKey{obj, factType(fact)} - act.objectFacts[key] = fact // clobber any existing entry - if dbg('f') { - objstr := types.ObjectString(obj, (*types.Package).Name) - fmt.Fprintf(os.Stderr, "%s: object %s has fact %s\n", - act.pkg.Fset.Position(obj.Pos()), objstr, fact) - } -} - -// allObjectFacts implements Pass.AllObjectFacts. -func (act *action) allObjectFacts() []analysis.ObjectFact { - facts := make([]analysis.ObjectFact, 0, len(act.objectFacts)) - for k := range act.objectFacts { - facts = append(facts, analysis.ObjectFact{Object: k.obj, Fact: act.objectFacts[k]}) - } - return facts -} - -// importPackageFact implements Pass.ImportPackageFact. -// Given a non-nil pointer ptr of type *T, where *T satisfies Fact, -// fact copies the fact value to *ptr. -func (act *action) importPackageFact(pkg *types.Package, ptr analysis.Fact) bool { - if pkg == nil { - panic("nil package") - } - key := packageFactKey{pkg, factType(ptr)} - if v, ok := act.packageFacts[key]; ok { - reflect.ValueOf(ptr).Elem().Set(reflect.ValueOf(v).Elem()) - return true - } - return false -} - -// exportPackageFact implements Pass.ExportPackageFact. 
-func (act *action) exportPackageFact(fact analysis.Fact) { - if act.pass.ExportPackageFact == nil { - log.Panicf("%s: Pass.ExportPackageFact(%T) called after Run", act, fact) - } - - key := packageFactKey{act.pass.Pkg, factType(fact)} - act.packageFacts[key] = fact // clobber any existing entry - if dbg('f') { - fmt.Fprintf(os.Stderr, "%s: package %s has fact %s\n", - act.pkg.Fset.Position(act.pass.Files[0].Pos()), act.pass.Pkg.Path(), fact) - } -} - -func factType(fact analysis.Fact) reflect.Type { - t := reflect.TypeOf(fact) - if t.Kind() != reflect.Ptr { - log.Fatalf("invalid Fact type: got %T, want pointer", fact) - } - return t -} - -// allPackageFacts implements Pass.AllPackageFacts. -func (act *action) allPackageFacts() []analysis.PackageFact { - facts := make([]analysis.PackageFact, 0, len(act.packageFacts)) - for k := range act.packageFacts { - facts = append(facts, analysis.PackageFact{Package: k.pkg, Fact: act.packageFacts[k]}) - } - return facts -} - func dbg(b byte) bool { return strings.IndexByte(Debug, b) >= 0 } diff --git a/go/analysis/internal/checker/checker_test.go b/go/analysis/internal/checker/checker_test.go index b0d711b4182..77a57f5119c 100644 --- a/go/analysis/internal/checker/checker_test.go +++ b/go/analysis/internal/checker/checker_test.go @@ -168,6 +168,7 @@ func NewT1() *T1 { return &T1{T} } // parse or type errors in the code. noop := &analysis.Analyzer{ Name: "noop", + Doc: "noop", Requires: []*analysis.Analyzer{inspect.Analyzer}, Run: func(pass *analysis.Pass) (interface{}, error) { return nil, nil @@ -179,6 +180,7 @@ func NewT1() *T1 { return &T1{T} } // regardless of parse or type errors in the code. 
noopWithFact := &analysis.Analyzer{ Name: "noopfact", + Doc: "noopfact", Requires: []*analysis.Analyzer{inspect.Analyzer}, Run: func(pass *analysis.Pass) (interface{}, error) { return nil, nil diff --git a/go/analysis/internal/checker/start_test.go b/go/analysis/internal/checker/start_test.go index 6b0df3033ed..af4dc42c85c 100644 --- a/go/analysis/internal/checker/start_test.go +++ b/go/analysis/internal/checker/start_test.go @@ -55,6 +55,7 @@ package comment var commentAnalyzer = &analysis.Analyzer{ Name: "comment", + Doc: "comment", Requires: []*analysis.Analyzer{inspect.Analyzer}, Run: commentRun, } diff --git a/go/analysis/internal/internal.go b/go/analysis/internal/internal.go new file mode 100644 index 00000000000..e7c8247fd33 --- /dev/null +++ b/go/analysis/internal/internal.go @@ -0,0 +1,12 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import "golang.org/x/tools/go/analysis" + +// This function is set by the checker package to provide +// backdoor access to the private Pass field +// of the checker.Action type, for use by analysistest. +var Pass func(interface{}) *analysis.Pass diff --git a/go/analysis/passes/buildtag/buildtag.go b/go/analysis/passes/buildtag/buildtag.go index b5a2d2775f4..e7434e8fed2 100644 --- a/go/analysis/passes/buildtag/buildtag.go +++ b/go/analysis/passes/buildtag/buildtag.go @@ -15,7 +15,6 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/internal/analysisutil" - "golang.org/x/tools/internal/versions" ) const Doc = "check //go:build and // +build directives" @@ -371,11 +370,6 @@ func (check *checker) finish() { // tags reports issues in go versions in tags within the expression e. func (check *checker) tags(pos token.Pos, e constraint.Expr) { - // Check that constraint.GoVersion is meaningful (>= go1.21). 
- if versions.ConstraintGoVersion == nil { - return - } - // Use Eval to visit each tag. _ = e.Eval(func(tag string) bool { if malformedGoTag(tag) { @@ -393,10 +387,8 @@ func malformedGoTag(tag string) bool { // Check for close misspellings of the "go1." prefix. for _, pre := range []string{"go.", "g1.", "go"} { suffix := strings.TrimPrefix(tag, pre) - if suffix != tag { - if valid, ok := validTag("go1." + suffix); ok && valid { - return true - } + if suffix != tag && validGoVersion("go1."+suffix) { + return true } } return false @@ -404,15 +396,10 @@ func malformedGoTag(tag string) bool { // The tag starts with "go1" so it is almost certainly a GoVersion. // Report it if it is not a valid build constraint. - valid, ok := validTag(tag) - return ok && !valid + return !validGoVersion(tag) } -// validTag returns (valid, ok) where valid reports when a tag is valid, -// and ok reports determining if the tag is valid succeeded. -func validTag(tag string) (valid bool, ok bool) { - if versions.ConstraintGoVersion != nil { - return versions.ConstraintGoVersion(&constraint.TagExpr{Tag: tag}) != "", true - } - return false, false +// validGoVersion reports when a tag is a valid go version. +func validGoVersion(tag string) bool { + return constraint.GoVersion(&constraint.TagExpr{Tag: tag}) != "" } diff --git a/go/analysis/passes/buildtag/buildtag_test.go b/go/analysis/passes/buildtag/buildtag_test.go index 6109cba3ddd..9f0b9f5e957 100644 --- a/go/analysis/passes/buildtag/buildtag_test.go +++ b/go/analysis/passes/buildtag/buildtag_test.go @@ -7,33 +7,14 @@ package buildtag_test import ( "testing" - "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/analysistest" "golang.org/x/tools/go/analysis/passes/buildtag" - "golang.org/x/tools/internal/versions" ) func Test(t *testing.T) { - analyzer := *buildtag.Analyzer - analyzer.Run = func(pass *analysis.Pass) (interface{}, error) { - defer func() { - // The buildtag pass is unusual in that it checks the IgnoredFiles. 
- // After analysis, add IgnoredFiles to OtherFiles so that - // the test harness checks for expected diagnostics in those. - // (The test harness shouldn't do this by default because most - // passes can't do anything with the IgnoredFiles without type - // information, which is unavailable because they are ignored.) - var files []string - files = append(files, pass.OtherFiles...) - files = append(files, pass.IgnoredFiles...) - pass.OtherFiles = files - }() - - return buildtag.Analyzer.Run(pass) - } - patterns := []string{"a"} - if versions.ConstraintGoVersion != nil { - patterns = append(patterns, "b") - } - analysistest.Run(t, analysistest.TestData(), &analyzer, patterns...) + // This test has a dedicated hack in the analysistest package: + // Because it cares about IgnoredFiles, which most analyzers + // ignore, the test framework will consider expectations in + // ignore files too, but only for this analyzer. + analysistest.Run(t, analysistest.TestData(), buildtag.Analyzer, "a", "b") } diff --git a/go/analysis/passes/copylock/copylock_test.go b/go/analysis/passes/copylock/copylock_test.go index c22001ca3ea..ae249b1acad 100644 --- a/go/analysis/passes/copylock/copylock_test.go +++ b/go/analysis/passes/copylock/copylock_test.go @@ -16,7 +16,7 @@ import ( func Test(t *testing.T) { testdata := analysistest.TestData() - analysistest.Run(t, testdata, copylock.Analyzer, "a", "typeparams", "issue67787") + analysistest.Run(t, testdata, copylock.Analyzer, "a", "typeparams", "issue67787", "unfortunate") } func TestVersions22(t *testing.T) { diff --git a/go/analysis/passes/copylock/testdata/src/a/copylock_func.go b/go/analysis/passes/copylock/testdata/src/a/copylock_func.go index 0d3168f1ef1..c27862627b9 100644 --- a/go/analysis/passes/copylock/testdata/src/a/copylock_func.go +++ b/go/analysis/passes/copylock/testdata/src/a/copylock_func.go @@ -5,20 +5,25 @@ // This file contains tests for the copylock checker's // function declaration analysis. 
+// There are two cases missing from this file which +// are located in the "unfortunate" package in the +// testdata directory. Once the go.mod >= 1.26 for this +// repository, merge local_go124.go back into this file. + package a import "sync" func OkFunc(*sync.Mutex) {} func BadFunc(sync.Mutex) {} // want "BadFunc passes lock by value: sync.Mutex" -func BadFunc2(sync.Map) {} // want "BadFunc2 passes lock by value: sync.Map contains sync.Mutex" +func BadFunc2(sync.Map) {} // want "BadFunc2 passes lock by value: sync.Map contains sync.(Mutex|noCopy)" func OkRet() *sync.Mutex {} func BadRet() sync.Mutex {} // Don't warn about results var ( OkClosure = func(*sync.Mutex) {} BadClosure = func(sync.Mutex) {} // want "func passes lock by value: sync.Mutex" - BadClosure2 = func(sync.Map) {} // want "func passes lock by value: sync.Map contains sync.Mutex" + BadClosure2 = func(sync.Map) {} // want "func passes lock by value: sync.Map contains sync.(Mutex|noCopy)" ) type EmbeddedRWMutex struct { @@ -118,19 +123,3 @@ func AcceptedCases() { x = BadRet() // function call on RHS is OK (#16227) x = *OKRet() // indirection of function call on RHS is OK (#16227) } - -// TODO: Unfortunate cases - -// Non-ideal error message: -// Since we're looking for Lock methods, sync.Once's underlying -// sync.Mutex gets called out, but without any reference to the sync.Once. -type LocalOnce sync.Once - -func (LocalOnce) Bad() {} // want `Bad passes lock by value: a.LocalOnce contains sync.\b.*` - -// False negative: -// LocalMutex doesn't have a Lock method. -// Nevertheless, it is probably a bad idea to pass it by value. 
-type LocalMutex sync.Mutex - -func (LocalMutex) Bad() {} // WANTED: An error here :( diff --git a/go/analysis/passes/copylock/testdata/src/unfortunate/local_go123.go b/go/analysis/passes/copylock/testdata/src/unfortunate/local_go123.go new file mode 100644 index 00000000000..c6bc0256b02 --- /dev/null +++ b/go/analysis/passes/copylock/testdata/src/unfortunate/local_go123.go @@ -0,0 +1,25 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !go1.24 + +package unfortunate + +import "sync" + +// TODO: Unfortunate cases + +// Non-ideal error message: +// Since we're looking for Lock methods, sync.Once's underlying +// sync.Mutex gets called out, but without any reference to the sync.Once. +type LocalOnce sync.Once + +func (LocalOnce) Bad() {} // want `Bad passes lock by value: unfortunate.LocalOnce contains sync.\b.*` + +// False negative: +// LocalMutex doesn't have a Lock method. +// Nevertheless, it is probably a bad idea to pass it by value. +type LocalMutex sync.Mutex + +func (LocalMutex) Bad() {} // WANTED: An error here :( diff --git a/go/analysis/passes/copylock/testdata/src/unfortunate/local_go124.go b/go/analysis/passes/copylock/testdata/src/unfortunate/local_go124.go new file mode 100644 index 00000000000..5f45402f792 --- /dev/null +++ b/go/analysis/passes/copylock/testdata/src/unfortunate/local_go124.go @@ -0,0 +1,19 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.24 + +package unfortunate + +import "sync" + +// Cases where the interior sync.noCopy shows through. 
+ +type LocalOnce sync.Once + +func (LocalOnce) Bad() {} // want "Bad passes lock by value: unfortunate.LocalOnce contains sync.noCopy" + +type LocalMutex sync.Mutex + +func (LocalMutex) Bad() {} // want "Bad passes lock by value: unfortunate.LocalMutex contains sync.noCopy" diff --git a/go/analysis/passes/ctrlflow/ctrlflow_test.go b/go/analysis/passes/ctrlflow/ctrlflow_test.go index 5afd01cc918..6fa764eb2d0 100644 --- a/go/analysis/passes/ctrlflow/ctrlflow_test.go +++ b/go/analysis/passes/ctrlflow/ctrlflow_test.go @@ -21,11 +21,11 @@ func Test(t *testing.T) { for _, result := range results { cfgs := result.Result.(*ctrlflow.CFGs) - for _, decl := range result.Pass.Files[0].Decls { + for _, decl := range result.Action.Package.Syntax[0].Decls { if decl, ok := decl.(*ast.FuncDecl); ok && decl.Body != nil { if cfgs.FuncDecl(decl) == nil { t.Errorf("%s: no CFG for func %s", - result.Pass.Fset.Position(decl.Pos()), decl.Name.Name) + result.Action.Package.Fset.Position(decl.Pos()), decl.Name.Name) } } } diff --git a/go/analysis/passes/directive/directive_test.go b/go/analysis/passes/directive/directive_test.go index 8f6ae0578e5..f9620473519 100644 --- a/go/analysis/passes/directive/directive_test.go +++ b/go/analysis/passes/directive/directive_test.go @@ -7,28 +7,14 @@ package directive_test import ( "testing" - "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/analysistest" "golang.org/x/tools/go/analysis/passes/directive" ) func Test(t *testing.T) { - analyzer := *directive.Analyzer - analyzer.Run = func(pass *analysis.Pass) (interface{}, error) { - defer func() { - // The directive pass is unusual in that it checks the IgnoredFiles. - // After analysis, add IgnoredFiles to OtherFiles so that - // the test harness checks for expected diagnostics in those. - // (The test harness shouldn't do this by default because most - // passes can't do anything with the IgnoredFiles without type - // information, which is unavailable because they are ignored.) 
- var files []string - files = append(files, pass.OtherFiles...) - files = append(files, pass.IgnoredFiles...) - pass.OtherFiles = files - }() - - return directive.Analyzer.Run(pass) - } - analysistest.Run(t, analysistest.TestData(), &analyzer, "a") + // This test has a dedicated hack in the analysistest package: + // Because it cares about IgnoredFiles, which most analyzers + // ignore, the test framework will consider expectations in + // ignore files too, but only for this analyzer. + analysistest.Run(t, analysistest.TestData(), directive.Analyzer, "a") } diff --git a/go/analysis/passes/printf/printf.go b/go/analysis/passes/printf/printf.go index 2d79d0b0334..171ad201372 100644 --- a/go/analysis/passes/printf/printf.go +++ b/go/analysis/passes/printf/printf.go @@ -433,6 +433,9 @@ func printfNameAndKind(pass *analysis.Pass, call *ast.CallExpr) (fn *types.Func, return nil, 0 } + // Facts are associated with generic declarations, not instantiations. + fn = fn.Origin() + _, ok := isPrint[fn.FullName()] if !ok { // Next look up just "printf", for use with -printf.funcs. 
diff --git a/go/analysis/passes/printf/printf_test.go b/go/analysis/passes/printf/printf_test.go index b27cef51983..198cf6ec549 100644 --- a/go/analysis/passes/printf/printf_test.go +++ b/go/analysis/passes/printf/printf_test.go @@ -16,6 +16,6 @@ func Test(t *testing.T) { printf.Analyzer.Flags.Set("funcs", "Warn,Warnf") analysistest.Run(t, testdata, printf.Analyzer, - "a", "b", "nofmt", "typeparams", "issue68744") + "a", "b", "nofmt", "typeparams", "issue68744", "issue70572") analysistest.RunWithSuggestedFixes(t, testdata, printf.Analyzer, "fix") } diff --git a/go/analysis/passes/printf/testdata/src/issue70572/issue70572.go b/go/analysis/passes/printf/testdata/src/issue70572/issue70572.go new file mode 100644 index 00000000000..b9959afeafd --- /dev/null +++ b/go/analysis/passes/printf/testdata/src/issue70572/issue70572.go @@ -0,0 +1,25 @@ +package issue70572 + +// Regression test for failure to detect that a call to B[bool].Printf +// was printf-like, because of a missing call to types.Func.Origin. + +import "fmt" + +type A struct{} + +func (v A) Printf(format string, values ...any) { // want Printf:"printfWrapper" + fmt.Printf(format, values...) +} + +type B[T any] struct{} + +func (v B[T]) Printf(format string, values ...any) { // want Printf:"printfWrapper" + fmt.Printf(format, values...) +} + +func main() { + var a A + var b B[bool] + a.Printf("x", 1) // want "arguments but no formatting directives" + b.Printf("x", 1) // want "arguments but no formatting directives" +} diff --git a/go/analysis/passes/waitgroup/doc.go b/go/analysis/passes/waitgroup/doc.go new file mode 100644 index 00000000000..207f7418307 --- /dev/null +++ b/go/analysis/passes/waitgroup/doc.go @@ -0,0 +1,34 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package waitgroup defines an Analyzer that detects simple misuses +// of sync.WaitGroup. 
+// +// # Analyzer waitgroup +// +// waitgroup: check for misuses of sync.WaitGroup +// +// This analyzer detects mistaken calls to the (*sync.WaitGroup).Add +// method from inside a new goroutine, causing Add to race with Wait: +// +// // WRONG +// var wg sync.WaitGroup +// go func() { +// wg.Add(1) // "WaitGroup.Add called from inside new goroutine" +// defer wg.Done() +// ... +// }() +// wg.Wait() // (may return prematurely before new goroutine starts) +// +// The correct code calls Add before starting the goroutine: +// +// // RIGHT +// var wg sync.WaitGroup +// wg.Add(1) +// go func() { +// defer wg.Done() +// ... +// }() +// wg.Wait() +package waitgroup diff --git a/go/analysis/passes/waitgroup/main.go b/go/analysis/passes/waitgroup/main.go new file mode 100644 index 00000000000..785eadd9fcc --- /dev/null +++ b/go/analysis/passes/waitgroup/main.go @@ -0,0 +1,16 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build ignore + +// The waitgroup command applies the golang.org/x/tools/go/analysis/passes/waitgroup +// analysis to the specified packages of Go source code. +package main + +import ( + "golang.org/x/tools/go/analysis/passes/waitgroup" + "golang.org/x/tools/go/analysis/singlechecker" +) + +func main() { singlechecker.Main(waitgroup.Analyzer) } diff --git a/go/analysis/passes/waitgroup/testdata/src/a/a.go b/go/analysis/passes/waitgroup/testdata/src/a/a.go new file mode 100644 index 00000000000..c1fecc2e121 --- /dev/null +++ b/go/analysis/passes/waitgroup/testdata/src/a/a.go @@ -0,0 +1,14 @@ +package a + +import "sync" + +func f() { + var wg sync.WaitGroup + wg.Add(1) // ok + go func() { + wg.Add(1) // want "WaitGroup.Add called from inside new goroutine" + // ... 
+ wg.Add(1) // ok + }() + wg.Add(1) // ok +} diff --git a/go/analysis/passes/waitgroup/waitgroup.go b/go/analysis/passes/waitgroup/waitgroup.go new file mode 100644 index 00000000000..cbb0bfc9e6b --- /dev/null +++ b/go/analysis/passes/waitgroup/waitgroup.go @@ -0,0 +1,105 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package waitgroup defines an Analyzer that detects simple misuses +// of sync.WaitGroup. +package waitgroup + +import ( + _ "embed" + "go/ast" + "go/types" + "reflect" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/analysis/passes/internal/analysisutil" + "golang.org/x/tools/go/ast/inspector" + "golang.org/x/tools/go/types/typeutil" + "golang.org/x/tools/internal/typesinternal" +) + +//go:embed doc.go +var doc string + +var Analyzer = &analysis.Analyzer{ + Name: "waitgroup", + Doc: analysisutil.MustExtractDoc(doc, "waitgroup"), + URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/waitgroup", + Requires: []*analysis.Analyzer{inspect.Analyzer}, + Run: run, +} + +func run(pass *analysis.Pass) (any, error) { + if !analysisutil.Imports(pass.Pkg, "sync") { + return nil, nil // doesn't directly import sync + } + + inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + nodeFilter := []ast.Node{ + (*ast.CallExpr)(nil), + } + + inspect.WithStack(nodeFilter, func(n ast.Node, push bool, stack []ast.Node) (proceed bool) { + if push { + call := n.(*ast.CallExpr) + if fn, ok := typeutil.Callee(pass.TypesInfo, call).(*types.Func); ok && + isMethodNamed(fn, "sync", "WaitGroup", "Add") && + hasSuffix(stack, wantSuffix) && + backindex(stack, 1) == backindex(stack, 2).(*ast.BlockStmt).List[0] { // ExprStmt must be Block's first stmt + + pass.Reportf(call.Lparen, "WaitGroup.Add called from inside new goroutine") + } + } + return true + }) + + return nil, nil +} 
+ +// go func() { +// wg.Add(1) +// ... +// }() +var wantSuffix = []ast.Node{ + (*ast.GoStmt)(nil), + (*ast.CallExpr)(nil), + (*ast.FuncLit)(nil), + (*ast.BlockStmt)(nil), + (*ast.ExprStmt)(nil), + (*ast.CallExpr)(nil), +} + +// hasSuffix reports whether stack has the matching suffix, +// considering only node types. +func hasSuffix(stack, suffix []ast.Node) bool { + // TODO(adonovan): the inspector could implement this for us. + if len(stack) < len(suffix) { + return false + } + for i := range len(suffix) { + if reflect.TypeOf(backindex(stack, i)) != reflect.TypeOf(backindex(suffix, i)) { + return false + } + } + return true +} + +// isMethodNamed reports whether f is a method with the specified +// package, receiver type, and method names. +func isMethodNamed(fn *types.Func, pkg, recv, name string) bool { + if fn.Pkg() != nil && fn.Pkg().Path() == pkg && fn.Name() == name { + if r := fn.Type().(*types.Signature).Recv(); r != nil { + if _, gotRecv := typesinternal.ReceiverNamed(r); gotRecv != nil { + return gotRecv.Obj().Name() == recv + } + } + } + return false +} + +// backindex is like [slices.Index] but from the back of the slice. +func backindex[T any](slice []T, i int) T { + return slice[len(slice)-1-i] +} diff --git a/go/analysis/passes/waitgroup/waitgroup_test.go b/go/analysis/passes/waitgroup/waitgroup_test.go new file mode 100644 index 00000000000..bd6443acd69 --- /dev/null +++ b/go/analysis/passes/waitgroup/waitgroup_test.go @@ -0,0 +1,16 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package waitgroup_test + +import ( + "testing" + + "golang.org/x/tools/go/analysis/analysistest" + "golang.org/x/tools/go/analysis/passes/waitgroup" +) + +func Test(t *testing.T) { + analysistest.Run(t, analysistest.TestData(), waitgroup.Analyzer, "a") +} diff --git a/go/analysis/unitchecker/unitchecker.go b/go/analysis/unitchecker/unitchecker.go index 2301ccfc0e4..1a9b3094e5e 100644 --- a/go/analysis/unitchecker/unitchecker.go +++ b/go/analysis/unitchecker/unitchecker.go @@ -144,7 +144,7 @@ func Run(configFile string, analyzers []*analysis.Analyzer) { for _, res := range results { tree.Add(fset, cfg.ID, res.a.Name, res.diagnostics, res.err) } - tree.Print() + tree.Print(os.Stdout) } else { // plain text exit := 0 @@ -156,7 +156,7 @@ func Run(configFile string, analyzers []*analysis.Analyzer) { } for _, res := range results { for _, diag := range res.diagnostics { - analysisflags.PrintPlain(fset, diag) + analysisflags.PrintPlain(os.Stderr, fset, analysisflags.Context, diag) exit = 1 } } diff --git a/go/callgraph/vta/testdata/src/callgraph_interfaces.go b/go/callgraph/vta/testdata/src/callgraph_interfaces.go index c272ff05701..8a9b51fb2ae 100644 --- a/go/callgraph/vta/testdata/src/callgraph_interfaces.go +++ b/go/callgraph/vta/testdata/src/callgraph_interfaces.go @@ -49,11 +49,11 @@ func Baz(b bool) { // func Do(b bool) I: // ... 
-// t1 = (C).Foo(struct{}{}:C) +// t1 = (C).Foo(C{}:C) // t2 = NewB() // t3 = make I <- B (t2) // return t3 // WANT: // Baz: Do(b) -> Do; invoke t0.Foo() -> A.Foo, B.Foo -// Do: (C).Foo(struct{}{}:C) -> C.Foo; NewB() -> NewB +// Do: (C).Foo(C{}:C) -> C.Foo; NewB() -> NewB diff --git a/go/callgraph/vta/testdata/src/callgraph_static.go b/go/callgraph/vta/testdata/src/callgraph_static.go index 147ed46435b..62e31a4f320 100644 --- a/go/callgraph/vta/testdata/src/callgraph_static.go +++ b/go/callgraph/vta/testdata/src/callgraph_static.go @@ -22,7 +22,7 @@ func Baz(a A) { // func Baz(a A): // t0 = (A).foo(a) // t1 = Bar() -// t2 = Baz(struct{}{}:A) +// t2 = Baz(A{}:A) // WANT: -// Baz: (A).foo(a) -> A.foo; Bar() -> Bar; Baz(struct{}{}:A) -> Baz +// Baz: (A).foo(a) -> A.foo; Bar() -> Bar; Baz(A{}:A) -> Baz diff --git a/go/callgraph/vta/testdata/src/callgraph_type_aliases.go b/go/callgraph/vta/testdata/src/callgraph_type_aliases.go index 9b32109a828..3624adfdb46 100644 --- a/go/callgraph/vta/testdata/src/callgraph_type_aliases.go +++ b/go/callgraph/vta/testdata/src/callgraph_type_aliases.go @@ -58,11 +58,11 @@ func Baz(b bool) { // func Do(b bool) I: // ... 
-// t1 = (C).Foo(struct{}{}:Z) +// t1 = (C).Foo(Z{}:Z) // t2 = NewY() // t3 = make I <- B (t2) // return t3 // WANT: // Baz: Do(b) -> Do; invoke t0.Foo() -> A.Foo, B.Foo -// Do: (C).Foo(struct{}{}:Z) -> C.Foo; NewY() -> NewY +// Do: (C).Foo(Z{}:Z) -> C.Foo; NewY() -> NewY diff --git a/go/callgraph/vta/vta_test.go b/go/callgraph/vta/vta_test.go index ce441eb7e1b..ea7d584d2d9 100644 --- a/go/callgraph/vta/vta_test.go +++ b/go/callgraph/vta/vta_test.go @@ -21,7 +21,7 @@ import ( ) func TestVTACallGraph(t *testing.T) { - errDiff := func(want, got, missing []string) { + errDiff := func(t *testing.T, want, got, missing []string) { t.Errorf("got:\n%s\n\nwant:\n%s\n\nmissing:\n%s\n\ndiff:\n%s", strings.Join(got, "\n"), strings.Join(want, "\n"), @@ -60,14 +60,14 @@ func TestVTACallGraph(t *testing.T) { g := CallGraph(ssautil.AllFunctions(prog), nil) got := callGraphStr(g) if missing := setdiff(want, got); len(missing) > 0 { - errDiff(want, got, missing) + errDiff(t, want, got, missing) } // Repeat the test with explicit CHA initial call graph. g = CallGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog)) got = callGraphStr(g) if missing := setdiff(want, got); len(missing) > 0 { - errDiff(want, got, missing) + errDiff(t, want, got, missing) } }) } @@ -140,7 +140,7 @@ func TestVTAPanicMissingDefinitions(t *testing.T) { } for _, r := range res { if r.Err != nil { - t.Errorf("want no error for package %v; got %v", r.Pass.Pkg.Path(), r.Err) + t.Errorf("want no error for package %v; got %v", r.Action.Package.Types.Path(), r.Err) } } } diff --git a/go/gcexportdata/gcexportdata.go b/go/gcexportdata/gcexportdata.go index f3ab0a2e126..65fe2628e90 100644 --- a/go/gcexportdata/gcexportdata.go +++ b/go/gcexportdata/gcexportdata.go @@ -106,24 +106,18 @@ func Find(importPath, srcDir string) (filename, path string) { // additional trailing data beyond the end of the export data. 
func NewReader(r io.Reader) (io.Reader, error) { buf := bufio.NewReader(r) - _, size, err := gcimporter.FindExportData(buf) + size, err := gcimporter.FindExportData(buf) if err != nil { return nil, err } - if size >= 0 { - // We were given an archive and found the __.PKGDEF in it. - // This tells us the size of the export data, and we don't - // need to return the entire file. - return &io.LimitedReader{ - R: buf, - N: size, - }, nil - } else { - // We were given an object file. As such, we don't know how large - // the export data is and must return the entire file. - return buf, nil - } + // We were given an archive and found the __.PKGDEF in it. + // This tells us the size of the export data, and we don't + // need to return the entire file. + return &io.LimitedReader{ + R: buf, + N: size, + }, nil } // readAll works the same way as io.ReadAll, but avoids allocations and copies diff --git a/go/packages/external.go b/go/packages/external.go index 96db9daf314..91bd62e83b1 100644 --- a/go/packages/external.go +++ b/go/packages/external.go @@ -13,6 +13,7 @@ import ( "fmt" "os" "os/exec" + "slices" "strings" ) @@ -131,7 +132,7 @@ func findExternalDriver(cfg *Config) driver { // command. // // (See similar trick in Invocation.run in ../../internal/gocommand/invoke.go) - cmd.Env = append(slicesClip(cfg.Env), "PWD="+cfg.Dir) + cmd.Env = append(slices.Clip(cfg.Env), "PWD="+cfg.Dir) cmd.Stdin = bytes.NewReader(req) cmd.Stdout = buf cmd.Stderr = stderr @@ -150,7 +151,3 @@ func findExternalDriver(cfg *Config) driver { return &response, nil } } - -// slicesClip removes unused capacity from the slice, returning s[:len(s):len(s)]. -// TODO(adonovan): use go1.21 slices.Clip. 
-func slicesClip[S ~[]E, E any](s S) S { return s[:len(s):len(s)] } diff --git a/go/packages/golist.go b/go/packages/golist.go index 76f910ecec9..870271ed51f 100644 --- a/go/packages/golist.go +++ b/go/packages/golist.go @@ -505,13 +505,14 @@ func (state *golistState) createDriverResponse(words ...string) (*DriverResponse pkg := &Package{ Name: p.Name, ID: p.ImportPath, + Dir: p.Dir, GoFiles: absJoin(p.Dir, p.GoFiles, p.CgoFiles), CompiledGoFiles: absJoin(p.Dir, p.CompiledGoFiles), OtherFiles: absJoin(p.Dir, otherFiles(p)...), EmbedFiles: absJoin(p.Dir, p.EmbedFiles), EmbedPatterns: absJoin(p.Dir, p.EmbedPatterns), IgnoredFiles: absJoin(p.Dir, p.IgnoredGoFiles, p.IgnoredOtherFiles), - forTest: p.ForTest, + ForTest: p.ForTest, depsErrors: p.DepsErrors, Module: p.Module, } @@ -795,7 +796,7 @@ func jsonFlag(cfg *Config, goVersion int) string { // Request Dir in the unlikely case Export is not absolute. addFields("Dir", "Export") } - if cfg.Mode&needInternalForTest != 0 { + if cfg.Mode&NeedForTest != 0 { addFields("ForTest") } if cfg.Mode&needInternalDepsErrors != 0 { diff --git a/go/packages/loadmode_string.go b/go/packages/loadmode_string.go index 5fcad6ea6db..969da4c263c 100644 --- a/go/packages/loadmode_string.go +++ b/go/packages/loadmode_string.go @@ -23,6 +23,7 @@ var modes = [...]struct { {NeedSyntax, "NeedSyntax"}, {NeedTypesInfo, "NeedTypesInfo"}, {NeedTypesSizes, "NeedTypesSizes"}, + {NeedForTest, "NeedForTest"}, {NeedModule, "NeedModule"}, {NeedEmbedFiles, "NeedEmbedFiles"}, {NeedEmbedPatterns, "NeedEmbedPatterns"}, diff --git a/go/packages/packages.go b/go/packages/packages.go index 2ecc64238e8..9dedf9777dc 100644 --- a/go/packages/packages.go +++ b/go/packages/packages.go @@ -43,6 +43,20 @@ import ( // ID and Errors (if present) will always be filled. // [Load] may return more information than requested. 
// +// The Mode flag is a union of several bits named NeedName, +// NeedFiles, and so on, each of which determines whether +// a given field of Package (Name, Files, etc) should be +// populated. +// +// For convenience, we provide named constants for the most +// common combinations of Need flags: +// +// [LoadFiles] lists of files in each package +// [LoadImports] ... plus imports +// [LoadTypes] ... plus type information +// [LoadSyntax] ... plus type-annotated syntax +// [LoadAllSyntax] ... for all dependencies +// // Unfortunately there are a number of open bugs related to // interactions among the LoadMode bits: // - https://github.com/golang/go/issues/56633 @@ -55,7 +69,7 @@ const ( // NeedName adds Name and PkgPath. NeedName LoadMode = 1 << iota - // NeedFiles adds GoFiles, OtherFiles, and IgnoredFiles + // NeedFiles adds Dir, GoFiles, OtherFiles, and IgnoredFiles NeedFiles // NeedCompiledGoFiles adds CompiledGoFiles. @@ -86,9 +100,10 @@ const ( // needInternalDepsErrors adds the internal deps errors field for use by gopls. needInternalDepsErrors - // needInternalForTest adds the internal forTest field. + // NeedForTest adds ForTest. + // // Tests must also be set on the context for this field to be populated. - needInternalForTest + NeedForTest // typecheckCgo enables full support for type checking cgo. Requires Go 1.15+. // Modifies CompiledGoFiles and Types, and has no effect on its own. @@ -108,33 +123,18 @@ const ( const ( // LoadFiles loads the name and file names for the initial packages. - // - // Deprecated: LoadFiles exists for historical compatibility - // and should not be used. Please directly specify the needed fields using the Need values. LoadFiles = NeedName | NeedFiles | NeedCompiledGoFiles // LoadImports loads the name, file names, and import mapping for the initial packages. - // - // Deprecated: LoadImports exists for historical compatibility - // and should not be used. Please directly specify the needed fields using the Need values. 
LoadImports = LoadFiles | NeedImports // LoadTypes loads exported type information for the initial packages. - // - // Deprecated: LoadTypes exists for historical compatibility - // and should not be used. Please directly specify the needed fields using the Need values. LoadTypes = LoadImports | NeedTypes | NeedTypesSizes // LoadSyntax loads typed syntax for the initial packages. - // - // Deprecated: LoadSyntax exists for historical compatibility - // and should not be used. Please directly specify the needed fields using the Need values. LoadSyntax = LoadTypes | NeedSyntax | NeedTypesInfo // LoadAllSyntax loads typed syntax for the initial packages and all dependencies. - // - // Deprecated: LoadAllSyntax exists for historical compatibility - // and should not be used. Please directly specify the needed fields using the Need values. LoadAllSyntax = LoadSyntax | NeedDeps // Deprecated: NeedExportsFile is a historical misspelling of NeedExportFile. @@ -434,6 +434,12 @@ type Package struct { // PkgPath is the package path as used by the go/types package. PkgPath string + // Dir is the directory associated with the package, if it exists. + // + // For packages listed by the go command, this is the directory containing + // the package files. + Dir string + // Errors contains any errors encountered querying the metadata // of the package, or while parsing or type-checking its files. Errors []Error @@ -521,8 +527,8 @@ type Package struct { // -- internal -- - // forTest is the package under test, if any. - forTest string + // ForTest is the package under test, if any. + ForTest string // depsErrors is the DepsErrors field from the go list response, if any. 
depsErrors []*packagesinternal.PackageError @@ -551,9 +557,6 @@ type ModuleError struct { } func init() { - packagesinternal.GetForTest = func(p interface{}) string { - return p.(*Package).forTest - } packagesinternal.GetDepsErrors = func(p interface{}) []*packagesinternal.PackageError { return p.(*Package).depsErrors } @@ -565,7 +568,6 @@ func init() { } packagesinternal.TypecheckCgo = int(typecheckCgo) packagesinternal.DepsErrors = int(needInternalDepsErrors) - packagesinternal.ForTest = int(needInternalForTest) } // An Error describes a problem with a package's metadata, syntax, or types. diff --git a/go/packages/packages_test.go b/go/packages/packages_test.go index 939f2df2da4..11c4f77dce4 100644 --- a/go/packages/packages_test.go +++ b/go/packages/packages_test.go @@ -26,11 +26,13 @@ import ( "testing/fstest" "time" + "github.com/google/go-cmp/cmp" "golang.org/x/tools/go/packages" "golang.org/x/tools/internal/packagesinternal" "golang.org/x/tools/internal/packagestest" "golang.org/x/tools/internal/testenv" "golang.org/x/tools/internal/testfiles" + "golang.org/x/tools/txtar" ) // testCtx is canceled when the test binary is about to time out. @@ -509,11 +511,6 @@ func testConfigDir(t *testing.T, exporter packagestest.Exporter) { test.dir, test.pattern, got, test.want) } if fails != test.fails { - // TODO: remove when go#28023 is fixed - if test.fails && strings.HasPrefix(test.pattern, "./") && exporter == packagestest.Modules { - // Currently go list in module mode does not handle missing directories correctly. 
- continue - } t.Errorf("dir %q, pattern %q: error %v, want %v", test.dir, test.pattern, fails, test.fails) } @@ -2324,6 +2321,10 @@ func TestLoadModeStrings(t *testing.T) { packages.NeedName | packages.NeedExportFile, "(NeedName|NeedExportFile)", }, + { + packages.NeedForTest | packages.NeedEmbedFiles | packages.NeedEmbedPatterns, + "(NeedForTest|NeedEmbedFiles|NeedEmbedPatterns)", + }, { packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedExportFile | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedTypesSizes, "(NeedName|NeedFiles|NeedCompiledGoFiles|NeedImports|NeedDeps|NeedExportFile|NeedTypes|NeedSyntax|NeedTypesInfo|NeedTypesSizes)", @@ -2425,8 +2426,7 @@ func testForTestField(t *testing.T, exporter packagestest.Exporter) { if !hasTestFile { continue } - got := packagesinternal.GetForTest(pkg) - if got != forTest { + if got := pkg.ForTest; got != forTest { t.Errorf("expected %q, got %q", forTest, got) } } @@ -3177,6 +3177,51 @@ func TestIssue69606b(t *testing.T) { } } +// TestIssue70394 tests materializing an alias type defined in a package (m/a) +// in another package (m/b) where the types for m/b are coming from the compiler, +// e.g. `go list -compiled=true ... m/b`. +func TestIssue70394(t *testing.T) { + // TODO(taking): backport https://go.dev/cl/604099 so that this works on 23. + testenv.NeedsGo1Point(t, 24) + testenv.NeedsTool(t, "go") // requires go list. + testenv.NeedsGoBuild(t) // requires the compiler for export data. + + t.Setenv("GODEBUG", "gotypesalias=1") + + dir := t.TempDir() + overlay := map[string][]byte{ + filepath.Join(dir, "go.mod"): []byte("module m"), // go version of the module does not matter. 
+ filepath.Join(dir, "a/a.go"): []byte(`package a; type A = int32`), + filepath.Join(dir, "b/b.go"): []byte(`package b; import "m/a"; var V a.A`), + } + cfg := &packages.Config{ + Dir: dir, + Mode: packages.NeedTypes, // just NeedsTypes allows for loading export data. + Overlay: overlay, + Env: append(os.Environ(), "GOFLAGS=-mod=vendor", "GOWORK=off"), + } + pkgs, err := packages.Load(cfg, "m/b") + if err != nil { + t.Fatal(err) + } + if errs := packages.PrintErrors(pkgs); errs > 0 { + t.Fatalf("Got %d errors while loading packages.", errs) + } + if len(pkgs) != 1 { + t.Fatalf("Loaded %d packages. expected 1", len(pkgs)) + } + + pkg := pkgs[0] + scope := pkg.Types.Scope() + obj := scope.Lookup("V") + if obj == nil { + t.Fatalf("Failed to find object %q in package %q", "V", pkg) + } + if _, ok := obj.Type().(*types.Alias); !ok { + t.Errorf("Object %q has type %q. expected an alias", obj, obj.Type()) + } +} + // TestNeedTypesInfoOnly tests when NeedTypesInfo was set and NeedSyntax & NeedTypes were not, // Load should include the TypesInfo of packages properly func TestLoadTypesInfoWithoutSyntaxOrTypes(t *testing.T) { @@ -3209,3 +3254,102 @@ func foo() int { t.Errorf("expected types info to be present but got nil") } } + +// TestDirAndForTest tests the new fields added as part of golang/go#38445. 
+func TestDirAndForTest(t *testing.T) { + testenv.NeedsGoPackages(t) + + dir := writeTree(t, ` +-- go.mod -- +module example.com + +go 1.18 + +-- a/a.go -- +package a + +func Foo() int { return 1 } + +-- a/a_test.go -- +package a + +func Bar() int { return 2 } + +-- a/a_x_test.go -- +package a_test + +import ( + "example.com/a" + "example.com/b" +) + +func _() { + if got := a.Foo() + a.Bar() + b.Baz(); got != 6 { + panic("whoops") + } +} + +-- b/b.go -- +package b + +import "example.com/a" + +func Baz() int { return 3 } + +func Foo() int { return a.Foo() } +`) + + pkgs, err := packages.Load(&packages.Config{ + Mode: packages.NeedName | + packages.NeedFiles | + packages.NeedForTest | + packages.NeedImports, + Dir: dir, + Tests: true, + }, "./...") + if err != nil { + t.Fatal(err) + } + type result struct{ Dir, ForTest string } + got := make(map[string]result) + packages.Visit(pkgs, nil, func(pkg *packages.Package) { + if strings.Contains(pkg.PkgPath, ".") { // ignore std + rel, err := filepath.Rel(dir, pkg.Dir) + if err != nil { + t.Errorf("Rel(%q, %q) failed: %v", dir, pkg.Dir, err) + return + } + got[pkg.ID] = result{ + Dir: rel, + ForTest: pkg.ForTest, + } + } + }) + want := map[string]result{ + "example.com/a": {"a", ""}, + "example.com/a.test": {"a", ""}, + "example.com/a [example.com/a.test]": {"a", "example.com/a"}, // test variant + "example.com/a_test [example.com/a.test]": {"a", "example.com/a"}, // x_test + "example.com/b [example.com/a.test]": {"b", "example.com/a"}, // intermediate test variant + "example.com/b": {"b", ""}, + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("Load returned mismatching ForTest fields (ID->result -want +got):\n%s", diff) + } + t.Logf("Packages: %+v", pkgs) +} + +func writeTree(t *testing.T, archive string) string { + root := t.TempDir() + + for _, f := range txtar.Parse([]byte(archive)).Files { + filename := filepath.Join(root, f.Name) + if err := os.MkdirAll(filepath.Dir(filename), 0777); err != nil { + 
t.Fatal(err) + } + if err := os.WriteFile(filename, f.Data, 0666); err != nil { + t.Fatal(err) + } + } + return root +} diff --git a/go/ssa/const.go b/go/ssa/const.go index 865329bfd34..4dc53ef83cc 100644 --- a/go/ssa/const.go +++ b/go/ssa/const.go @@ -12,9 +12,9 @@ import ( "go/token" "go/types" "strconv" - "strings" "golang.org/x/tools/internal/typeparams" + "golang.org/x/tools/internal/typesinternal" ) // NewConst returns a new constant of the specified value and type. @@ -78,7 +78,7 @@ func zeroConst(t types.Type) *Const { func (c *Const) RelString(from *types.Package) string { var s string if c.Value == nil { - s = zeroString(c.typ, from) + s = typesinternal.ZeroString(c.typ, types.RelativeTo(from)) } else if c.Value.Kind() == constant.String { s = constant.StringVal(c.Value) const max = 20 @@ -93,44 +93,6 @@ func (c *Const) RelString(from *types.Package) string { return s + ":" + relType(c.Type(), from) } -// zeroString returns the string representation of the "zero" value of the type t. -func zeroString(t types.Type, from *types.Package) string { - switch t := t.(type) { - case *types.Basic: - switch { - case t.Info()&types.IsBoolean != 0: - return "false" - case t.Info()&types.IsNumeric != 0: - return "0" - case t.Info()&types.IsString != 0: - return `""` - case t.Kind() == types.UnsafePointer: - fallthrough - case t.Kind() == types.UntypedNil: - return "nil" - default: - panic(fmt.Sprint("zeroString for unexpected type:", t)) - } - case *types.Pointer, *types.Slice, *types.Interface, *types.Chan, *types.Map, *types.Signature: - return "nil" - case *types.Named, *types.Alias: - return zeroString(t.Underlying(), from) - case *types.Array, *types.Struct: - return relType(t, from) + "{}" - case *types.Tuple: - // Tuples are not normal values. - // We are currently format as "(t[0], ..., t[n])". Could be something else. 
- components := make([]string, t.Len()) - for i := 0; i < t.Len(); i++ { - components[i] = zeroString(t.At(i).Type(), from) - } - return "(" + strings.Join(components, ", ") + ")" - case *types.TypeParam: - return "*new(" + relType(t, from) + ")" - } - panic(fmt.Sprint("zeroString: unexpected ", t)) -} - func (c *Const) Name() string { return c.RelString(nil) } diff --git a/go/ssa/interp/interp.go b/go/ssa/interp/interp.go index 3ba78fbd89e..f80db0676c7 100644 --- a/go/ssa/interp/interp.go +++ b/go/ssa/interp/interp.go @@ -48,9 +48,11 @@ import ( "fmt" "go/token" "go/types" + "log" "os" "reflect" "runtime" + "slices" "sync/atomic" _ "unsafe" @@ -108,6 +110,7 @@ type frame struct { result value panicking bool panic interface{} + phitemps []value // temporaries for parallel phi assignment } func (fr *frame) get(key ssa.Value) value { @@ -379,12 +382,7 @@ func visitInstr(fr *frame, instr ssa.Instruction) continuation { fr.env[instr] = &closure{instr.Fn.(*ssa.Function), bindings} case *ssa.Phi: - for i, pred := range instr.Block().Preds { - if fr.prevBlock == pred { - fr.env[instr] = fr.get(instr.Edges[i]) - break - } - } + log.Fatal("unreachable") // phis are processed at block entry case *ssa.Select: var cases []reflect.SelectCase @@ -589,8 +587,9 @@ func runFrame(fr *frame) { if fr.i.mode&EnableTracing != 0 { fmt.Fprintf(os.Stderr, ".%s:\n", fr.block) } - block: - for _, instr := range fr.block.Instrs { + + nonPhis := executePhis(fr) + for _, instr := range nonPhis { if fr.i.mode&EnableTracing != 0 { if v, ok := instr.(ssa.Value); ok { fmt.Fprintln(os.Stderr, "\t", v.Name(), "=", instr) @@ -598,16 +597,47 @@ func runFrame(fr *frame) { fmt.Fprintln(os.Stderr, "\t", instr) } } - switch visitInstr(fr, instr) { - case kReturn: + if visitInstr(fr, instr) == kReturn { return - case kNext: - // no-op - case kJump: - break block } + // Inv: kNext (continue) or kJump (last instr) + } + } +} + +// executePhis executes the phi-nodes at the start of the current +// block and 
returns the non-phi instructions. +func executePhis(fr *frame) []ssa.Instruction { + firstNonPhi := -1 + for i, instr := range fr.block.Instrs { + if _, ok := instr.(*ssa.Phi); !ok { + firstNonPhi = i + break + } + } + // Inv: 0 <= firstNonPhi; every block contains a non-phi. + + nonPhis := fr.block.Instrs[firstNonPhi:] + if firstNonPhi > 0 { + phis := fr.block.Instrs[:firstNonPhi] + // Execute parallel assignment of phis. + // + // See "the swap problem" in Briggs et al's "Practical Improvements + // to the Construction and Destruction of SSA Form" for discussion. + predIndex := slices.Index(fr.block.Preds, fr.prevBlock) + fr.phitemps = fr.phitemps[:0] + for _, phi := range phis { + phi := phi.(*ssa.Phi) + if fr.i.mode&EnableTracing != 0 { + fmt.Fprintln(os.Stderr, "\t", phi.Name(), "=", phi) + } + fr.phitemps = append(fr.phitemps, fr.get(phi.Edges[predIndex])) + } + for i, phi := range phis { + fr.env[phi.(*ssa.Phi)] = fr.phitemps[i] } } + return nonPhis } // doRecover implements the recover() built-in. diff --git a/go/ssa/interp/interp_test.go b/go/ssa/interp/interp_test.go index f382c61f223..2aaecb850e7 100644 --- a/go/ssa/interp/interp_test.go +++ b/go/ssa/interp/interp_test.go @@ -21,7 +21,6 @@ import ( "go/build" "go/types" "io" - "log" "os" "path/filepath" "runtime" @@ -137,6 +136,7 @@ var testdataTests = []string{ "fixedbugs/issue52835.go", "fixedbugs/issue55086.go", "fixedbugs/issue66783.go", + "fixedbugs/issue69929.go", "typeassert.go", "zeros.go", "slice2array.go", @@ -152,7 +152,8 @@ func init() { os.Setenv("GOARCH", runtime.GOARCH) } -func run(t *testing.T, input string, goroot string) { +// run runs a single test. On success it returns the captured std{out,err}. 
+func run(t *testing.T, input string, goroot string) string { testenv.NeedsExec(t) // really we just need os.Pipe, but os/exec uses pipes t.Logf("Input: %s\n", input) @@ -182,14 +183,13 @@ func run(t *testing.T, input string, goroot string) { // Print a helpful hint if we don't make it to the end. var hint string defer func() { + t.Logf("Duration: %v", time.Since(start)) if hint != "" { t.Log("FAIL") t.Log(hint) } else { t.Log("PASS") } - - interp.CapturedOutput = nil }() hint = fmt.Sprintf("To dump SSA representation, run:\n%% go build golang.org/x/tools/cmd/ssadump && ./ssadump -test -build=CFP %s\n", input) @@ -209,8 +209,6 @@ func run(t *testing.T, input string, goroot string) { t.Fatalf("not a main package: %s", input) } - interp.CapturedOutput = new(bytes.Buffer) - sizes := types.SizesFor("gc", ctx.GOARCH) if sizes.Sizeof(types.Typ[types.Int]) < 4 { panic("bogus SizesFor") @@ -222,12 +220,8 @@ func run(t *testing.T, input string, goroot string) { // // While capturing is in effect, we must not write any // test-related stuff to stderr (including log.Print, t.Log, etc). - // - // Suppress capturing if we are the child process of TestRangeFunc. - // TODO(adonovan): simplify that test using this mechanism. - // Also eliminate the redundant interp.CapturedOutput mechanism. - restore := func() {} // restore files and log the mixed out/err. - if os.Getenv("INTERPTEST_CHILD") == "" { + var restore func() string // restore files and log+return the mixed out/err. + { // Connect std{out,err} to pipe. r, w, err := os.Pipe() if err != nil { @@ -239,7 +233,7 @@ func run(t *testing.T, input string, goroot string) { os.Stderr = w // Buffer what is written. - var buf bytes.Buffer + var buf strings.Builder done := make(chan struct{}) go func() { if _, err := io.Copy(&buf, r); err != nil { @@ -249,12 +243,14 @@ func run(t *testing.T, input string, goroot string) { }() // Finally, restore the files and log what was captured. 
- restore = func() { + restore = func() string { os.Stdout = savedStdout os.Stderr = savedStderr w.Close() <-done - t.Logf("Interpreter's stdout+stderr:\n%s", &buf) + captured := buf.String() + t.Logf("Interpreter's stdout+stderr:\n%s", captured) + return captured } } @@ -262,20 +258,18 @@ func run(t *testing.T, input string, goroot string) { // imode |= interp.DisableRecover // enable for debugging // imode |= interp.EnableTracing // enable for debugging exitCode := interp.Interpret(mainPkg, imode, sizes, input, []string{}) - restore() + capturedOutput := restore() if exitCode != 0 { t.Fatalf("interpreting %s: exit code was %d", input, exitCode) } // $GOROOT/test tests use this convention: - if strings.Contains(interp.CapturedOutput.String(), "BUG") { + if strings.Contains(capturedOutput, "BUG") { t.Fatalf("interpreting %s: exited zero but output contained 'BUG'", input) } hint = "" // call off the hounds - if false { - t.Log(input, time.Since(start)) // test profiling - } + return capturedOutput } // makeGoroot copies testdata/src into the "src" directory of a temporary @@ -327,13 +321,9 @@ const GOARCH = %q // TestTestdataFiles runs the interpreter on testdata/*.go. func TestTestdataFiles(t *testing.T) { goroot := makeGoroot(t) - cwd, err := os.Getwd() - if err != nil { - log.Fatal(err) - } for _, input := range testdataTests { t.Run(input, func(t *testing.T) { - run(t, filepath.Join(cwd, "testdata", input), goroot) + run(t, filepath.Join("testdata", input), goroot) }) } } diff --git a/go/ssa/interp/ops.go b/go/ssa/interp/ops.go index 7254676a4d0..d03aeace8f6 100644 --- a/go/ssa/interp/ops.go +++ b/go/ssa/interp/ops.go @@ -13,7 +13,6 @@ import ( "os" "reflect" "strings" - "sync" "unsafe" "golang.org/x/tools/go/ssa" @@ -950,27 +949,8 @@ func typeAssert(i *interpreter, instr *ssa.TypeAssert, itf iface) value { return v } -// If CapturedOutput is non-nil, all writes by the interpreted program -// to file descriptors 1 and 2 will also be written to CapturedOutput. 
-// -// (The $GOROOT/test system requires that the test be considered a -// failure if "BUG" appears in the combined stdout/stderr output, even -// if it exits zero. This is a global variable shared by all -// interpreters in the same process.) +// This variable is no longer used but remains to prevent build breakage. var CapturedOutput *bytes.Buffer -var capturedOutputMu sync.Mutex - -// write writes bytes b to the target program's standard output. -// The print/println built-ins and the write() system call funnel -// through here so they can be captured by the test driver. -func print(b []byte) (int, error) { - if CapturedOutput != nil { - capturedOutputMu.Lock() - CapturedOutput.Write(b) // ignore errors - capturedOutputMu.Unlock() - } - return os.Stdout.Write(b) -} // callBuiltin interprets a call to builtin fn with arguments args, // returning its result. @@ -1026,7 +1006,7 @@ func callBuiltin(caller *frame, callpos token.Pos, fn *ssa.Builtin, args []value if ln { buf.WriteRune('\n') } - print(buf.Bytes()) + os.Stderr.Write(buf.Bytes()) return nil case "len": diff --git a/go/ssa/interp/rangefunc_test.go b/go/ssa/interp/rangefunc_test.go index 58b7f43eca4..434468ff1f9 100644 --- a/go/ssa/interp/rangefunc_test.go +++ b/go/ssa/interp/rangefunc_test.go @@ -5,12 +5,9 @@ package interp_test import ( - "bytes" - "log" - "os" - "os/exec" "path/filepath" "reflect" + "strings" "testing" "golang.org/x/tools/internal/testenv" @@ -19,34 +16,15 @@ import ( func TestIssue69298(t *testing.T) { testenv.NeedsGo1Point(t, 23) - // TODO: Is cwd actually needed here? goroot := makeGoroot(t) - cwd, err := os.Getwd() - if err != nil { - log.Fatal(err) - } - run(t, filepath.Join(cwd, "testdata", "fixedbugs/issue69298.go"), goroot) + run(t, filepath.Join("testdata", "fixedbugs", "issue69298.go"), goroot) } -// TestRangeFunc tests range-over-func in a subprocess. 
func TestRangeFunc(t *testing.T) { testenv.NeedsGo1Point(t, 23) - // TODO(taking): Remove subprocess from the test and capture output another way. - if os.Getenv("INTERPTEST_CHILD") == "1" { - testRangeFunc(t) - return - } - - testenv.NeedsExec(t) - testenv.NeedsTool(t, "go") - - cmd := exec.Command(os.Args[0], "-test.run=TestRangeFunc") - cmd.Env = append(os.Environ(), "INTERPTEST_CHILD=1") - out, err := cmd.CombinedOutput() - if len(out) > 0 { - t.Logf("out=<<%s>>", out) - } + goroot := makeGoroot(t) + out := run(t, filepath.Join("testdata", "rangefunc.go"), goroot) // Check the output of the tests. const ( @@ -62,14 +40,14 @@ func TestRangeFunc(t *testing.T) { ) expected := map[string][]string{ // rangefunc.go - "TestCheck": []string{"i = 45", CERR_DONE}, - "TestCooperativeBadOfSliceIndex": []string{RERR_EXHAUSTED, "i = 36"}, - "TestCooperativeBadOfSliceIndexCheck": []string{CERR_EXHAUSTED, "i = 36"}, - "TestTrickyIterAll": []string{"i = 36", RERR_EXHAUSTED}, - "TestTrickyIterOne": []string{"i = 1", RERR_EXHAUSTED}, - "TestTrickyIterZero": []string{"i = 0", RERR_EXHAUSTED}, - "TestTrickyIterZeroCheck": []string{"i = 0", CERR_EXHAUSTED}, - "TestTrickyIterEcho": []string{ + "TestCheck": {"i = 45", CERR_DONE}, + "TestCooperativeBadOfSliceIndex": {RERR_EXHAUSTED, "i = 36"}, + "TestCooperativeBadOfSliceIndexCheck": {CERR_EXHAUSTED, "i = 36"}, + "TestTrickyIterAll": {"i = 36", RERR_EXHAUSTED}, + "TestTrickyIterOne": {"i = 1", RERR_EXHAUSTED}, + "TestTrickyIterZero": {"i = 0", RERR_EXHAUSTED}, + "TestTrickyIterZeroCheck": {"i = 0", CERR_EXHAUSTED}, + "TestTrickyIterEcho": { "first loop i=0", "first loop i=1", "first loop i=3", @@ -79,7 +57,7 @@ func TestRangeFunc(t *testing.T) { RERR_EXHAUSTED, "end i=0", }, - "TestTrickyIterEcho2": []string{ + "TestTrickyIterEcho2": { "k=0,x=1,i=0", "k=0,x=2,i=1", "k=0,x=3,i=3", @@ -89,37 +67,37 @@ func TestRangeFunc(t *testing.T) { RERR_EXHAUSTED, "end i=1", }, - "TestBreak1": []string{"[1 2 -1 1 2 -2 1 2 -3]"}, - "TestBreak2": 
[]string{"[1 2 -1 1 2 -2 1 2 -3]"}, - "TestContinue": []string{"[-1 1 2 -2 1 2 -3 1 2 -4]"}, - "TestBreak3": []string{"[100 10 2 4 200 10 2 4 20 2 4 300 10 2 4 20 2 4 30]"}, - "TestBreak1BadA": []string{"[1 2 -1 1 2 -2 1 2 -3]", RERR_DONE}, - "TestBreak1BadB": []string{"[1 2]", RERR_DONE}, - "TestMultiCont0": []string{"[1000 10 2 4 2000]"}, - "TestMultiCont1": []string{"[1000 10 2 4]", RERR_DONE}, - "TestMultiCont2": []string{"[1000 10 2 4]", RERR_DONE}, - "TestMultiCont3": []string{"[1000 10 2 4]", RERR_DONE}, - "TestMultiBreak0": []string{"[1000 10 2 4]", RERR_DONE}, - "TestMultiBreak1": []string{"[1000 10 2 4]", RERR_DONE}, - "TestMultiBreak2": []string{"[1000 10 2 4]", RERR_DONE}, - "TestMultiBreak3": []string{"[1000 10 2 4]", RERR_DONE}, - "TestPanickyIterator1": []string{panickyIterMsg}, - "TestPanickyIterator1Check": []string{panickyIterMsg}, - "TestPanickyIterator2": []string{RERR_MISSING}, - "TestPanickyIterator2Check": []string{CERR_MISSING}, - "TestPanickyIterator3": []string{"[100 10 1 2 200 10 1 2]"}, - "TestPanickyIterator3Check": []string{"[100 10 1 2 200 10 1 2]"}, - "TestPanickyIterator4": []string{RERR_MISSING}, - "TestPanickyIterator4Check": []string{CERR_MISSING}, - "TestVeryBad1": []string{"[1 10]"}, - "TestVeryBad2": []string{"[1 10]"}, - "TestVeryBadCheck": []string{"[1 10]"}, - "TestOk": []string{"[1 10]"}, - "TestBreak1BadDefer": []string{RERR_DONE, "[1 2 -1 1 2 -2 1 2 -3 -30 -20 -10]"}, - "TestReturns": []string{"[-1 1 2 -10]", "[-1 1 2 -10]", RERR_DONE, "[-1 1 2 -10]", RERR_DONE}, - "TestGotoA": []string{"testGotoA1[-1 1 2 -2 1 2 -3 1 2 -4 -30 -20 -10]", "testGotoA2[-1 1 2 -2 1 2 -3 1 2 -4 -30 -20 -10]", RERR_DONE, "testGotoA3[-1 1 2 -10]", RERR_DONE}, - "TestGotoB": []string{"testGotoB1[-1 1 2 999 -10]", "testGotoB2[-1 1 2 -10]", RERR_DONE, "testGotoB3[-1 1 2 -10]", RERR_DONE}, - "TestPanicReturns": []string{ + "TestBreak1": {"[1 2 -1 1 2 -2 1 2 -3]"}, + "TestBreak2": {"[1 2 -1 1 2 -2 1 2 -3]"}, + "TestContinue": {"[-1 1 2 -2 1 2 -3 1 2 
-4]"}, + "TestBreak3": {"[100 10 2 4 200 10 2 4 20 2 4 300 10 2 4 20 2 4 30]"}, + "TestBreak1BadA": {"[1 2 -1 1 2 -2 1 2 -3]", RERR_DONE}, + "TestBreak1BadB": {"[1 2]", RERR_DONE}, + "TestMultiCont0": {"[1000 10 2 4 2000]"}, + "TestMultiCont1": {"[1000 10 2 4]", RERR_DONE}, + "TestMultiCont2": {"[1000 10 2 4]", RERR_DONE}, + "TestMultiCont3": {"[1000 10 2 4]", RERR_DONE}, + "TestMultiBreak0": {"[1000 10 2 4]", RERR_DONE}, + "TestMultiBreak1": {"[1000 10 2 4]", RERR_DONE}, + "TestMultiBreak2": {"[1000 10 2 4]", RERR_DONE}, + "TestMultiBreak3": {"[1000 10 2 4]", RERR_DONE}, + "TestPanickyIterator1": {panickyIterMsg}, + "TestPanickyIterator1Check": {panickyIterMsg}, + "TestPanickyIterator2": {RERR_MISSING}, + "TestPanickyIterator2Check": {CERR_MISSING}, + "TestPanickyIterator3": {"[100 10 1 2 200 10 1 2]"}, + "TestPanickyIterator3Check": {"[100 10 1 2 200 10 1 2]"}, + "TestPanickyIterator4": {RERR_MISSING}, + "TestPanickyIterator4Check": {CERR_MISSING}, + "TestVeryBad1": {"[1 10]"}, + "TestVeryBad2": {"[1 10]"}, + "TestVeryBadCheck": {"[1 10]"}, + "TestOk": {"[1 10]"}, + "TestBreak1BadDefer": {RERR_DONE, "[1 2 -1 1 2 -2 1 2 -3 -30 -20 -10]"}, + "TestReturns": {"[-1 1 2 -10]", "[-1 1 2 -10]", RERR_DONE, "[-1 1 2 -10]", RERR_DONE}, + "TestGotoA": {"testGotoA1[-1 1 2 -2 1 2 -3 1 2 -4 -30 -20 -10]", "testGotoA2[-1 1 2 -2 1 2 -3 1 2 -4 -30 -20 -10]", RERR_DONE, "testGotoA3[-1 1 2 -10]", RERR_DONE}, + "TestGotoB": {"testGotoB1[-1 1 2 999 -10]", "testGotoB2[-1 1 2 -10]", RERR_DONE, "testGotoB3[-1 1 2 -10]", RERR_DONE}, + "TestPanicReturns": { "Got expected 'f return'", "Got expected 'g return'", "Got expected 'h return'", @@ -130,9 +108,9 @@ func TestRangeFunc(t *testing.T) { }, } got := make(map[string][]string) - for _, ln := range bytes.Split(out, []byte("\n")) { - if ind := bytes.Index(ln, []byte(" \t ")); ind >= 0 { - n, m := string(ln[:ind]), string(ln[ind+3:]) + for _, ln := range strings.Split(out, "\n") { + if ind := strings.Index(ln, " \t "); ind >= 0 { + n, m := 
ln[:ind], ln[ind+3:] got[n] = append(got[n], m) } } @@ -146,24 +124,4 @@ func TestRangeFunc(t *testing.T) { t.Errorf("No expected output for test %s. got %v", n, gs) } } - - var exitcode int - if err, ok := err.(*exec.ExitError); ok { - exitcode = err.ExitCode() - } - const want = 0 - if exitcode != want { - t.Errorf("exited %d, want %d", exitcode, want) - } -} - -func testRangeFunc(t *testing.T) { - goroot := makeGoroot(t) - cwd, err := os.Getwd() - if err != nil { - log.Fatal(err) - } - - input := "rangefunc.go" - run(t, filepath.Join(cwd, "testdata", input), goroot) } diff --git a/go/ssa/interp/testdata/fixedbugs/issue69929.go b/go/ssa/interp/testdata/fixedbugs/issue69929.go new file mode 100644 index 00000000000..8e91a89c640 --- /dev/null +++ b/go/ssa/interp/testdata/fixedbugs/issue69929.go @@ -0,0 +1,67 @@ +package main + +// This is a regression test for a bug (#69929) in +// the SSA interpreter in which it would not execute phis in parallel. +// +// The insert function below has interdependent phi nodes: +// +// entry: +// t0 = *root // t0 is x or y before loop +// jump test +// body: +// print(t5) // t5 is x at loop entry +// t3 = t5.Child // t3 is x after loop +// jump test +// test: +// t5 = phi(t0, t3) // t5 is x at loop entry +// t6 = phi(t0, t5) // t6 is y at loop entry +// if t5 != nil goto body else done +// done: +// print(t6) +// return +// +// The two phis: +// +// t5 = phi(t0, t3) +// t6 = phi(t0, t5) +// +// must be executed in parallel as if they were written in Go +// as: +// +// t5, t6 = phi(t0, t3), phi(t0, t5) +// +// with the second phi node observing the original, not +// updated, value of t5. (In more complex examples, the phi +// nodes may be mutually recursive, breaking partial solutions +// based on simple reordering of the phi instructions. See the +// Briggs paper for detail.) +// +// The correct behavior is print(1, root); print(2, root); print(3, root). +// The previous incorrect behavior had print(2, nil). 
+ +func main() { + insert() + print(3, root) +} + +var root = new(node) + +type node struct{ child *node } + +func insert() { + x := root + y := x + for x != nil { + y = x + print(1, y) + x = x.child + } + print(2, y) +} + +func print(order int, ptr *node) { + println(order, ptr) + if ptr != root { + panic(ptr) + } +} diff --git a/gopls/doc/analyzers.md b/gopls/doc/analyzers.md index f7083bb0e89..38e246ecb47 100644 --- a/gopls/doc/analyzers.md +++ b/gopls/doc/analyzers.md @@ -842,26 +842,6 @@ Default: on. Package documentation: [timeformat](https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/timeformat) - -## `undeclaredname`: suggested fixes for "undeclared name: <>" - - -This checker provides suggested fixes for type errors of the -type "undeclared name: <>". It will either insert a new statement, -such as: - - <> := - -or a new function declaration, such as: - - func <>(inferred parameters) { - panic("implement me!") - } - -Default: on. - -Package documentation: [undeclaredname](https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/undeclaredname) - ## `unmarshal`: report passing non-pointer or non-interface values to unmarshal @@ -996,4 +976,70 @@ Default: off. Enable by setting `"analyses": {"useany": true}`. Package documentation: [useany](https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/useany) + +## `waitgroup`: check for misuses of sync.WaitGroup + + +This analyzer detects mistaken calls to the (*sync.WaitGroup).Add +method from inside a new goroutine, causing Add to race with Wait: + + // WRONG + var wg sync.WaitGroup + go func() { + wg.Add(1) // "WaitGroup.Add called from inside new goroutine" + defer wg.Done() + ... + }() + wg.Wait() // (may return prematurely before new goroutine starts) + +The correct code calls Add before starting the goroutine: + + // RIGHT + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + ... + }() + wg.Wait() + +Default: on. 
+ +Package documentation: [waitgroup](https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/waitgroup) + + +## `yield`: report calls to yield where the result is ignored + + +After a yield function returns false, the caller should not call +the yield function again; generally the iterator should return +promptly. + +This example fails to check the result of the call to yield, +causing this analyzer to report a diagnostic: + + yield(1) // yield may be called again (on L2) after returning false + yield(2) + +The corrected code is either this: + + if yield(1) { yield(2) } + +or simply: + + _ = yield(1) && yield(2) + +It is not always a mistake to ignore the result of yield. +For example, this is a valid single-element iterator: + + yield(1) // ok to ignore result + return + +It is only a mistake when the yield call that returned false may be +followed by another call. + +Default: on. + +Package documentation: [yield](https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/yield) + diff --git a/gopls/doc/assets/add-test-for-func.png b/gopls/doc/assets/add-test-for-func.png new file mode 100644 index 00000000000..ddfe7c656d8 Binary files /dev/null and b/gopls/doc/assets/add-test-for-func.png differ diff --git a/gopls/doc/codelenses.md b/gopls/doc/codelenses.md index 0930076bec6..b7687bb3b30 100644 --- a/gopls/doc/codelenses.md +++ b/gopls/doc/codelenses.md @@ -97,16 +97,15 @@ Default: off File type: Go -## `run_govulncheck`: Run govulncheck +## `run_govulncheck`: Run govulncheck (legacy) -This codelens source annotates the `module` directive in a -go.mod file with a command to run Govulncheck. +This codelens source annotates the `module` directive in a go.mod file +with a command to run Govulncheck asynchronously. 
-[Govulncheck](https://go.dev/blog/vuln) is a static -analysis tool that computes the set of functions reachable -within your application, including dependencies; -queries a database of known security vulnerabilities; and +[Govulncheck](https://go.dev/blog/vuln) is a static analysis tool that +computes the set of functions reachable within your application, including +dependencies; queries a database of known security vulnerabilities; and reports any potential problems it finds. @@ -157,4 +156,20 @@ Default: on File type: go.mod +## `vulncheck`: Run govulncheck + + +This codelens source annotates the `module` directive in a go.mod file +with a command to run govulncheck synchronously. + +[Govulncheck](https://go.dev/blog/vuln) is a static analysis tool that +computes the set of functions reachable within your application, including +dependencies; queries a database of known security vulnerabilities; and +reports any potential problems it finds. + + +Default: off + +File type: go.mod + diff --git a/gopls/doc/contributing.md b/gopls/doc/contributing.md index 914794aee71..94752c5394d 100644 --- a/gopls/doc/contributing.md +++ b/gopls/doc/contributing.md @@ -3,9 +3,52 @@ This documentation augments the general documentation for contributing to the x/tools repository, described at the [repository root](../../CONTRIBUTING.md). -Contributions are welcome, but since development is so active, we request that -you file an issue and claim it before starting to work on something. Otherwise, -it is likely that we might already be working on a fix for your issue. +Contributions are welcome! However, development is fast moving, +and we are limited in our capacity to review contributions. +So, before sending a CL, please please please: + +- **file an issue** for a bug or feature request, if one does not + exist already. This allows us to identify redundant requests, or to + merge a specific problem into a more general one, and to assess the + importance of the problem. 
+ +- **claim it for yourself** by commenting on the issue or, if you are + able, by assigning the issue to yourself. This helps us avoid two + people working on the same problem. + +- **propose an implementation plan** in the issue tracker for CLs of + any complexity. It is much more efficient to discuss the plan at a + high level before we start getting bogged down in the details of + a code review. + +When you send a CL, it should include: + +- a **CL description** that summarizes the change, + motivates why it is necessary, + explains it at a high level, + contrasts it with more obvious or simpler approaches, and + links to relevant issues; +- **tests** (integration tests or marker tests); +- **documentation**, for new or modified features; and +- **release notes**, for new features or significant changes. + +During code review, please address all reviewer comments. +Some comments result in straightforward code changes; +others demand a more complex response. +When a reviewer asks a question, the best response is +often not to respond to it directly, but to change the +code to avoid raising the question, +for example by making the code self-explanatory. +It's fine to disagree with a comment, +point out a reviewer's mistake, +or offer to address a comment in a follow-up change, +leaving a TODO comment in the current CL. +But please don't dismiss or quietly ignore a comment without action, +as it may lead reviewers to repeat themselves, +or to serious problems being neglected. + +For more detail, see the Go project's +[contribution guidelines](https://golang.org/doc/contribute.html). ## Finding issues diff --git a/gopls/doc/features/README.md b/gopls/doc/features/README.md index 92203ed677a..c78bb5c687d 100644 --- a/gopls/doc/features/README.md +++ b/gopls/doc/features/README.md @@ -50,6 +50,7 @@ when making significant changes to existing features or when adding new ones. 
- [Extract](transformation.md#refactor.extract): extract selection to a new file/function/variable - [Inline](transformation.md#refactor.inline.call): inline a call to a function or method - [Miscellaneous rewrites](transformation.md#refactor.rewrite): various Go-specific refactorings + - [Add test for func](transformation.md#source.addTest): create a test for the selected function - [Web-based queries](web.md): commands that open a browser page - [Package documentation](web.md#doc): browse documentation for current Go package - [Free symbols](web.md#freesymbols): show symbols used by a selected block of code diff --git a/gopls/doc/features/diagnostics.md b/gopls/doc/features/diagnostics.md index 5955a55d8b3..21015bcaa35 100644 --- a/gopls/doc/features/diagnostics.md +++ b/gopls/doc/features/diagnostics.md @@ -197,6 +197,57 @@ func (f Foo) bar(s string, i int) string { } ``` +### `CreateUndeclared`: Create missing declaration for "undeclared name: X" + +A Go compiler error "undeclared name: X" indicates that a variable or function is being used before +it has been declared in the current scope. In this scenario, gopls offers a quick fix to create the declaration. 
+ +#### Declare a new variable + +When you reference a variable that hasn't been declared: + +```go +func main() { + x := 42 + min(x, y) // error: undefined: y +} +``` + +The quick fix would insert a declaration with a default +value inferring its type from the context: + +```go +func main() { + x := 42 + y := 0 + min(x, y) +} +``` + +#### Declare a new function + +Similarly, if you call a function that hasn't been declared: + +```go +func main() { + var s string + s = doSomething(42) // error: undefined: doSomething +} +``` + +Gopls will insert a new function declaration below, +inferring its type from the call: + +```go +func main() { + var s string + s = doSomething(42) +} + +func doSomething(i int) string { + panic("unimplemented") +} +``` +- [`source.addTest`](#source.addTest) - [`gopls.doc.features`](README.md), which opens gopls' index of features in a browser +- [`refactor.extract.constant`](#extract) - [`refactor.extract.function`](#extract) - [`refactor.extract.method`](#extract) - [`refactor.extract.toNewFile`](#extract.toNewFile) @@ -210,6 +212,45 @@ Client support: ``` - **CLI**: `gopls fix -a file.go:#offset source.organizeImports` + +## `source.addTest`: Add test for function or method + +If the selected chunk of code is part of a function or method declaration F, +gopls will offer the "Add test for F" code action, which adds a new test for the +selected function in the corresponding `_test.go` file. The generated test takes +into account its signature, including input parameters and results. + +**Test file**: if the `_test.go` file does not exist, gopls creates it, based on +the name of the current file (`a.go` -> `a_test.go`), copying any copyright and +build constraint comments from the original file. + +**Test package**: for new files that test code in package `p`, the test file +uses `p_test` package name whenever possible, to encourage testing only exported +functions. (If the test file already exists, the new test is added to that file.) 
+ +**Parameters**: each of the function's non-blank parameters becomes an item in +the struct used for the table-driven test. (For each blank `_` parameter, the +value has no effect, so the test provides a zero-valued argument.) + +**Contexts**: If the first parameter is `context.Context`, the test passes +`context.Background()`. + +**Results**: the function's results are assigned to variables (`got`, `got2`, +and so on) and compared with expected values (`want`, `want2`, etc.`) defined in +the test case struct. The user should edit the logic to perform the appropriate +comparison. If the final result is an `error`, the test case defines a `wantErr` +boolean. + +**Method receivers**: When testing a method `T.F` or `(*T).F`, the test must +construct an instance of T to pass as the receiver. Gopls searches the package +for a suitable function that constructs a value of type T or *T, optionally with +an error, preferring a function named `NewT`. + +**Imports**: Gopls adds missing imports to the test file, using the last +corresponding import specifier from the original file. It avoids duplicate +imports, preserving any existing imports in the test file. + + ## Rename @@ -318,6 +359,9 @@ newly created declaration that contains the selected code: ![Before extracting a var](../assets/extract-var-before.png) ![After extracting a var](../assets/extract-var-after.png) +- **`refactor.extract.constant** does the same thing for a constant + expression, introducing a local const declaration. + If the default name for the new declaration is already in use, gopls generates a fresh name. @@ -340,11 +384,9 @@ number of cases where it falls short, including: The following Extract features are planned for 2024 but not yet supported: -- **Extract constant** is a variant of "Extract variable" to be - offered when the expression is constant; see golang/go#37170. 
- **Extract parameter struct** will replace two or more parameters of a function by a struct type with one field per parameter; see golang/go#65552. - + - **Extract interface for type** will create a declaration of an interface type with all the methods of the selected concrete type; diff --git a/gopls/doc/release/v0.17.0.md b/gopls/doc/release/v0.17.0.md index a3e8b1b34e0..1a278b013cb 100644 --- a/gopls/doc/release/v0.17.0.md +++ b/gopls/doc/release/v0.17.0.md @@ -13,6 +13,15 @@ # New features +## Change signature refactoring + +TODO(rfindley): document the state of change signature refactoring once the +feature set stabilizes. + +## Improvements to existing refactoring operations + +TODO(rfindley): document the full set of improvements to rename/extract/inline. + ## Extract declarations to new file Gopls now offers another code action, @@ -24,10 +33,19 @@ removed as needed. The user can invoke this code action by selecting a function name, the keywords `func`, `const`, `var`, `type`, or by placing the caret on them without selecting, -or by selecting a whole declaration or multiple declrations. +or by selecting a whole declaration or multiple declarations. In order to avoid ambiguity and surprise about what to extract, some kinds -of paritial selection of a declration cannot invoke this code action. +of paritial selection of a declaration cannot invoke this code action. + +## Extract constant + +When the selection is a constant expression, gopls now offers "Extract +constant" instead of "Extract variable", and generates a `const` +declaration instead of a local variable. + +Also, extraction of a constant or variable now works at top-level, +outside of any function. ## Pull diagnostics @@ -78,3 +96,30 @@ Gopls now offers a new code action, “Declare missing method of T.f”, where T is the concrete type and f is the undefined method. The stub method's signature is inferred from the context of the call. 
+ +## `yield` analyzer + +The new `yield` analyzer detects mistakes using the `yield` function +in a Go 1.23 iterator, such as failure to check its boolean result and +break out of a loop. + +## `waitgroup` analyzer + +The new `waitgroup` analyzer detects calls to the `Add` method of +`sync.WaitGroup` that are (mistakenly) made within the new goroutine, +causing `Add` to race with `Wait`. +(This check is equivalent to +[staticcheck's SA2000](https://staticcheck.dev/docs/checks#SA2000), +but is enabled by default.) + +## Add test for function or method + +If the selected chunk of code is part of a function or method declaration F, +gopls will offer the "Add test for F" code action, which adds a new test for the +selected function in the corresponding `_test.go` file. The generated test takes +into account its signature, including input parameters and results. + +Since this feature is implemented by the server (gopls), it is compatible with +all LSP-compliant editors. VS Code users may continue to use the client-side +`Go: Generate Unit Tests For file/function/package` command which utilizes the +[gotests](https://github.com/cweill/gotests) tool. 
\ No newline at end of file diff --git a/gopls/go.mod b/gopls/go.mod index beea37161db..03f7956025d 100644 --- a/gopls/go.mod +++ b/gopls/go.mod @@ -8,10 +8,10 @@ require ( github.com/google/go-cmp v0.6.0 github.com/jba/templatecheck v0.7.0 golang.org/x/mod v0.22.0 - golang.org/x/sync v0.9.0 - golang.org/x/sys v0.27.0 + golang.org/x/sync v0.10.0 + golang.org/x/sys v0.28.0 golang.org/x/telemetry v0.0.0-20241106142447-58a1122356f5 - golang.org/x/text v0.20.0 + golang.org/x/text v0.21.0 golang.org/x/tools v0.21.1-0.20240531212143-b6235391adb3 golang.org/x/vuln v1.0.4 gopkg.in/yaml.v3 v3.0.1 diff --git a/gopls/go.sum b/gopls/go.sum index 3321a78e6f0..7785bbed7f6 100644 --- a/gopls/go.sum +++ b/gopls/go.sum @@ -16,7 +16,7 @@ github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= +golang.org/x/crypto v0.30.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp/typeparams v0.0.0-20231108232855-2478ac86f678 h1:1P7xPZEwZMoBoz0Yze5Nx2/4pxj6nw9ZqHWXqP0iRgQ= golang.org/x/exp/typeparams v0.0.0-20231108232855-2478ac86f678/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= @@ -25,27 +25,27 @@ golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= +golang.org/x/net v0.32.0/go.mod 
h1:CwU0IoeOlnQQWJ6ioyFrfRuomB8GKF6KbYXZVyeXNfs= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= -golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= -golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/telemetry v0.0.0-20240521205824-bda55230c457/go.mod h1:pRgIJT+bRLFKnoM1ldnzKoxTIn14Yxz928LQRYYgIN0= golang.org/x/telemetry v0.0.0-20241106142447-58a1122356f5 h1:TCDqnvbBsFapViksHcHySl/sW4+rTGNIAoJJesHRuMM= golang.org/x/telemetry v0.0.0-20241106142447-58a1122356f5/go.mod h1:8nZWdGp9pq73ZI//QJyckMQab3yq7hoWi7SI0UIusVI= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= -golang.org/x/term v0.26.0/go.mod h1:Si5m1o57C5nBNQo5z1iq+XDijt21BDBDp2bK0QI8e3E= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.20.0 
h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= -golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/vuln v1.0.4 h1:SP0mPeg2PmGCu03V+61EcQiOjmpri2XijexKdzv8Z1I= golang.org/x/vuln v1.0.4/go.mod h1:NbJdUQhX8jY++FtuhrXs2Eyx0yePo9pF7nPlIjo9aaQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/gopls/internal/analysis/fillreturns/fillreturns.go b/gopls/internal/analysis/fillreturns/fillreturns.go index 5ebfc2013bd..145b03f4a42 100644 --- a/gopls/internal/analysis/fillreturns/fillreturns.go +++ b/gopls/internal/analysis/fillreturns/fillreturns.go @@ -18,6 +18,7 @@ import ( "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/gopls/internal/fuzzy" "golang.org/x/tools/internal/analysisinternal" + "golang.org/x/tools/internal/typesinternal" ) //go:embed doc.go @@ -161,7 +162,7 @@ outer: if t := info.TypeOf(val); t == nil || !matchingTypes(t, retTyp) { continue } - if !analysisinternal.IsZeroValue(val) { + if !typesinternal.IsZeroExpr(val) { match, idx = val, j break } @@ -183,7 +184,7 @@ outer: // If no identifier matches the pattern, generate a zero value. if best := fuzzy.BestMatch(retTyp.String(), names); best != "" { fixed[i] = ast.NewIdent(best) - } else if zero := analysisinternal.ZeroValue(file, pass.Pkg, retTyp); zero != nil { + } else if zero := typesinternal.ZeroExpr(file, pass.Pkg, retTyp); zero != nil { fixed[i] = zero } else { return nil, nil @@ -194,7 +195,7 @@ outer: // Remove any non-matching "zero values" from the leftover values. 
var nonZeroRemaining []ast.Expr for _, expr := range remaining { - if !analysisinternal.IsZeroValue(expr) { + if !typesinternal.IsZeroExpr(expr) { nonZeroRemaining = append(nonZeroRemaining, expr) } } diff --git a/gopls/internal/analysis/fillreturns/testdata/src/a/a.go.golden b/gopls/internal/analysis/fillreturns/testdata/src/a/a.go.golden index 27353f5fbab..6d9e3e161dc 100644 --- a/gopls/internal/analysis/fillreturns/testdata/src/a/a.go.golden +++ b/gopls/internal/analysis/fillreturns/testdata/src/a/a.go.golden @@ -67,7 +67,7 @@ func basic() (uint8, uint16, uint32, uint64, int8, int16, int32, int64, float32, } func complex() (*int, []int, [2]int, map[int]int) { - return nil, nil, nil, nil // want "return values" + return nil, nil, [2]int{}, nil // want "return values" } func structsAndInterfaces() (T, url.URL, T1, I, I1, io.Reader, Client, ast2.Stmt) { diff --git a/gopls/internal/analysis/fillstruct/fillstruct.go b/gopls/internal/analysis/fillstruct/fillstruct.go index 629dfdfc797..6fa64182a07 100644 --- a/gopls/internal/analysis/fillstruct/fillstruct.go +++ b/gopls/internal/analysis/fillstruct/fillstruct.go @@ -349,8 +349,8 @@ func populateValue(f *ast.File, pkg *types.Package, typ types.Type) ast.Expr { } case *types.Map: - k := analysisinternal.TypeExpr(f, pkg, u.Key()) - v := analysisinternal.TypeExpr(f, pkg, u.Elem()) + k := typesinternal.TypeExpr(f, pkg, u.Key()) + v := typesinternal.TypeExpr(f, pkg, u.Elem()) if k == nil || v == nil { return nil } @@ -361,7 +361,7 @@ func populateValue(f *ast.File, pkg *types.Package, typ types.Type) ast.Expr { }, } case *types.Slice: - s := analysisinternal.TypeExpr(f, pkg, u.Elem()) + s := typesinternal.TypeExpr(f, pkg, u.Elem()) if s == nil { return nil } @@ -372,7 +372,7 @@ func populateValue(f *ast.File, pkg *types.Package, typ types.Type) ast.Expr { } case *types.Array: - a := analysisinternal.TypeExpr(f, pkg, u.Elem()) + a := typesinternal.TypeExpr(f, pkg, u.Elem()) if a == nil { return nil } @@ -386,7 +386,7 @@ func 
populateValue(f *ast.File, pkg *types.Package, typ types.Type) ast.Expr { } case *types.Chan: - v := analysisinternal.TypeExpr(f, pkg, u.Elem()) + v := typesinternal.TypeExpr(f, pkg, u.Elem()) if v == nil { return nil } @@ -405,7 +405,7 @@ func populateValue(f *ast.File, pkg *types.Package, typ types.Type) ast.Expr { } case *types.Struct: - s := analysisinternal.TypeExpr(f, pkg, typ) + s := typesinternal.TypeExpr(f, pkg, typ) if s == nil { return nil } @@ -416,7 +416,7 @@ func populateValue(f *ast.File, pkg *types.Package, typ types.Type) ast.Expr { case *types.Signature: var params []*ast.Field for i := 0; i < u.Params().Len(); i++ { - p := analysisinternal.TypeExpr(f, pkg, u.Params().At(i).Type()) + p := typesinternal.TypeExpr(f, pkg, u.Params().At(i).Type()) if p == nil { return nil } @@ -431,7 +431,7 @@ func populateValue(f *ast.File, pkg *types.Package, typ types.Type) ast.Expr { } var returns []*ast.Field for i := 0; i < u.Results().Len(); i++ { - r := analysisinternal.TypeExpr(f, pkg, u.Results().At(i).Type()) + r := typesinternal.TypeExpr(f, pkg, u.Results().At(i).Type()) if r == nil { return nil } diff --git a/gopls/internal/analysis/undeclaredname/doc.go b/gopls/internal/analysis/undeclaredname/doc.go deleted file mode 100644 index 02989c9d75b..00000000000 --- a/gopls/internal/analysis/undeclaredname/doc.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2023 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package undeclaredname defines an Analyzer that applies suggested fixes -// to errors of the type "undeclared name: %s". -// -// # Analyzer undeclaredname -// -// undeclaredname: suggested fixes for "undeclared name: <>" -// -// This checker provides suggested fixes for type errors of the -// type "undeclared name: <>". 
It will either insert a new statement, -// such as: -// -// <> := -// -// or a new function declaration, such as: -// -// func <>(inferred parameters) { -// panic("implement me!") -// } -package undeclaredname diff --git a/gopls/internal/analysis/undeclaredname/testdata/src/a/a.go b/gopls/internal/analysis/undeclaredname/testdata/src/a/a.go deleted file mode 100644 index c5d8a2d789c..00000000000 --- a/gopls/internal/analysis/undeclaredname/testdata/src/a/a.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2020 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package undeclared - -func x() int { - var z int - z = y // want "(undeclared name|undefined): y" - - if z == m { // want "(undeclared name|undefined): m" - z = 1 - } - - if z == 1 { - z = 1 - } else if z == n+1 { // want "(undeclared name|undefined): n" - z = 1 - } - - switch z { - case 10: - z = 1 - case a: // want "(undeclared name|undefined): a" - z = 1 - } - return z -} diff --git a/gopls/internal/analysis/undeclaredname/testdata/src/a/channels.go b/gopls/internal/analysis/undeclaredname/testdata/src/a/channels.go deleted file mode 100644 index 76c7ba685e1..00000000000 --- a/gopls/internal/analysis/undeclaredname/testdata/src/a/channels.go +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright 2021 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package undeclared - -func channels(s string) { - undefinedChannels(c()) // want "(undeclared name|undefined): undefinedChannels" -} - -func c() (<-chan string, chan string) { - return make(<-chan string), make(chan string) -} diff --git a/gopls/internal/analysis/undeclaredname/testdata/src/a/consecutive_params.go b/gopls/internal/analysis/undeclaredname/testdata/src/a/consecutive_params.go deleted file mode 100644 index 73beace102c..00000000000 --- a/gopls/internal/analysis/undeclaredname/testdata/src/a/consecutive_params.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright 2021 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package undeclared - -func consecutiveParams() { - var s string - undefinedConsecutiveParams(s, s) // want "(undeclared name|undefined): undefinedConsecutiveParams" -} diff --git a/gopls/internal/analysis/undeclaredname/testdata/src/a/error_param.go b/gopls/internal/analysis/undeclaredname/testdata/src/a/error_param.go deleted file mode 100644 index 5de9254112d..00000000000 --- a/gopls/internal/analysis/undeclaredname/testdata/src/a/error_param.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright 2021 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package undeclared - -func errorParam() { - var err error - undefinedErrorParam(err) // want "(undeclared name|undefined): undefinedErrorParam" -} diff --git a/gopls/internal/analysis/undeclaredname/testdata/src/a/literals.go b/gopls/internal/analysis/undeclaredname/testdata/src/a/literals.go deleted file mode 100644 index c62174ec947..00000000000 --- a/gopls/internal/analysis/undeclaredname/testdata/src/a/literals.go +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright 2021 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package undeclared - -type T struct{} - -func literals() { - undefinedLiterals("hey compiler", T{}, &T{}) // want "(undeclared name|undefined): undefinedLiterals" -} diff --git a/gopls/internal/analysis/undeclaredname/testdata/src/a/operation.go b/gopls/internal/analysis/undeclaredname/testdata/src/a/operation.go deleted file mode 100644 index 9396da4bd9d..00000000000 --- a/gopls/internal/analysis/undeclaredname/testdata/src/a/operation.go +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright 2021 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package undeclared - -import "time" - -func operation() { - undefinedOperation(10 * time.Second) // want "(undeclared name|undefined): undefinedOperation" -} diff --git a/gopls/internal/analysis/undeclaredname/testdata/src/a/selector.go b/gopls/internal/analysis/undeclaredname/testdata/src/a/selector.go deleted file mode 100644 index a4ed290d466..00000000000 --- a/gopls/internal/analysis/undeclaredname/testdata/src/a/selector.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright 2021 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package undeclared - -func selector() { - m := map[int]bool{} - undefinedSelector(m[1]) // want "(undeclared name|undefined): undefinedSelector" -} diff --git a/gopls/internal/analysis/undeclaredname/testdata/src/a/slice.go b/gopls/internal/analysis/undeclaredname/testdata/src/a/slice.go deleted file mode 100644 index 5cde299add3..00000000000 --- a/gopls/internal/analysis/undeclaredname/testdata/src/a/slice.go +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright 2021 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package undeclared - -func slice() { - undefinedSlice([]int{1, 2}) // want "(undeclared name|undefined): undefinedSlice" -} diff --git a/gopls/internal/analysis/undeclaredname/testdata/src/a/tuple.go b/gopls/internal/analysis/undeclaredname/testdata/src/a/tuple.go deleted file mode 100644 index 9e91c59c25e..00000000000 --- a/gopls/internal/analysis/undeclaredname/testdata/src/a/tuple.go +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright 2021 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package undeclared - -func tuple() { - undefinedTuple(b()) // want "(undeclared name|undefined): undefinedTuple" -} - -func b() (string, error) { - return "", nil -} diff --git a/gopls/internal/analysis/undeclaredname/testdata/src/a/unique_params.go b/gopls/internal/analysis/undeclaredname/testdata/src/a/unique_params.go deleted file mode 100644 index 5b4241425e5..00000000000 --- a/gopls/internal/analysis/undeclaredname/testdata/src/a/unique_params.go +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright 2021 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package undeclared - -func uniqueArguments() { - var s string - var i int - undefinedUniqueArguments(s, i, s) // want "(undeclared name|undefined): undefinedUniqueArguments" -} diff --git a/gopls/internal/analysis/yield/doc.go b/gopls/internal/analysis/yield/doc.go new file mode 100644 index 00000000000..e03d0520d06 --- /dev/null +++ b/gopls/internal/analysis/yield/doc.go @@ -0,0 +1,38 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package yield defines an Analyzer that checks for mistakes related +// to the yield function used in iterators. 
+// +// # Analyzer yield +// +// yield: report calls to yield where the result is ignored +// +// After a yield function returns false, the caller should not call +// the yield function again; generally the iterator should return +// promptly. +// +// This example fails to check the result of the call to yield, +// causing this analyzer to report a diagnostic: +// +// yield(1) // yield may be called again (on L2) after returning false +// yield(2) +// +// The corrected code is either this: +// +// if yield(1) { yield(2) } +// +// or simply: +// +// _ = yield(1) && yield(2) +// +// It is not always a mistake to ignore the result of yield. +// For example, this is a valid single-element iterator: +// +// yield(1) // ok to ignore result +// return +// +// It is only a mistake when the yield call that returned false may be +// followed by another call. +package yield diff --git a/gopls/internal/analysis/yield/main.go b/gopls/internal/analysis/yield/main.go new file mode 100644 index 00000000000..d0bb9613bf9 --- /dev/null +++ b/gopls/internal/analysis/yield/main.go @@ -0,0 +1,16 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build ignore + +// The yield command applies the yield analyzer to the specified +// packages of Go source code. +package main + +import ( + "golang.org/x/tools/go/analysis/singlechecker" + "golang.org/x/tools/gopls/internal/analysis/yield" +) + +func main() { singlechecker.Main(yield.Analyzer) } diff --git a/gopls/internal/analysis/yield/testdata/src/a/a.go b/gopls/internal/analysis/yield/testdata/src/a/a.go new file mode 100644 index 00000000000..9eb88b5ae69 --- /dev/null +++ b/gopls/internal/analysis/yield/testdata/src/a/a.go @@ -0,0 +1,120 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package yield + +import ( + "bufio" + "io" +) + +// +// +// Modify this block of comment lines as needed when changing imports +// to avoid perturbing subsequent line numbers (and thus error messages). +// +// This is L16. + +func goodIter(yield func(int) bool) { + _ = yield(1) && yield(2) && yield(3) // ok +} + +func badIterOR(yield func(int) bool) { + _ = yield(1) || // want `yield may be called again \(on L25\) after returning false` + yield(2) || // want `yield may be called again \(on L26\) after returning false` + yield(3) +} + +func badIterSeq(yield func(int) bool) { + yield(1) // want `yield may be called again \(on L31\) after returning false` + yield(2) // want `yield may be called again \(on L32\) after returning false` + yield(3) // ok +} + +func badIterLoop(yield func(int) bool) { + for { + yield(1) // want `yield may be called again after returning false` + } +} + +func goodIterLoop(yield func(int) bool) { + for { + if !yield(1) { + break + } + } +} + +func badIterIf(yield func(int) bool) { + ok := yield(1) // want `yield may be called again \(on L52\) after returning false` + if !ok { + yield(2) + } else { + yield(3) + } +} + +func singletonIter(yield func(int) bool) { + yield(1) // ok +} + +func twoArgumentYield(yield func(int, int) bool) { + _ = yield(1, 1) || // want `yield may be called again \(on L64\) after returning false` + yield(2, 2) +} + +func zeroArgumentYield(yield func() bool) { + _ = yield() || // want `yield may be called again \(on L69\) after returning false` + yield() +} + +func tricky(in io.ReadCloser) func(yield func(string, error) bool) { + return func(yield func(string, error) bool) { + scan := bufio.NewScanner(in) + for scan.Scan() { + if !yield(scan.Text(), nil) { // want `yield may be called again \(on L82\) after returning false` + _ = in.Close() + break + } + } + if err := scan.Err(); err != nil { + yield("", err) + } + } +} + +// Regression test for issue #70598. 
+func shortCircuitAND(yield func(int) bool) { + ok := yield(1) + ok = ok && yield(2) + ok = ok && yield(3) + ok = ok && yield(4) +} + +// This example has a bug because a false yield(2) may be followed by yield(3). +func tricky2(yield func(int) bool) { + cleanup := func() {} + ok := yield(1) // want "yield may be called again .on L104" + stop := !ok || yield(2) // want "yield may be called again .on L104" + if stop { + cleanup() + } else { + // dominated by !stop => !(!ok || yield(2)) => yield(1) && !yield(2): bad. + yield(3) + } +} + +// This example is sound, but the analyzer reports a false positive. +// TODO(adonovan): prune infeasible paths more carefully. +func tricky3(yield func(int) bool) { + cleanup := func() {} + ok := yield(1) // want "yield may be called again .on L118" + stop := !ok || !yield(2) // want "yield may be called again .on L118" + if stop { + cleanup() + } else { + // dominated by !stop => !(!ok || !yield(2)) => yield(1) && yield(2): good. + yield(3) + } +} diff --git a/gopls/internal/analysis/yield/yield.go b/gopls/internal/analysis/yield/yield.go new file mode 100644 index 00000000000..ccd30045f97 --- /dev/null +++ b/gopls/internal/analysis/yield/yield.go @@ -0,0 +1,193 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package yield + +// TODO(adonovan): also check for this pattern: +// +// for x := range seq { +// yield(x) +// } +// +// which should be entirely rewritten as +// +// seq(yield) +// +// to avoid unnecessary range desugaring and chains of dynamic calls. 
+ +import ( + _ "embed" + "fmt" + "go/ast" + "go/constant" + "go/token" + "go/types" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/buildssa" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" + "golang.org/x/tools/go/ssa" + "golang.org/x/tools/gopls/internal/util/safetoken" + "golang.org/x/tools/internal/analysisinternal" +) + +//go:embed doc.go +var doc string + +var Analyzer = &analysis.Analyzer{ + Name: "yield", + Doc: analysisinternal.MustExtractDoc(doc, "yield"), + Requires: []*analysis.Analyzer{inspect.Analyzer, buildssa.Analyzer}, + Run: run, + URL: "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/yield", +} + +func run(pass *analysis.Pass) (interface{}, error) { + inspector := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + + // Find all calls to yield of the right type. + yieldCalls := make(map[token.Pos]*ast.CallExpr) // keyed by CallExpr.Lparen. + nodeFilter := []ast.Node{(*ast.CallExpr)(nil)} + inspector.Preorder(nodeFilter, func(n ast.Node) { + call := n.(*ast.CallExpr) + if id, ok := call.Fun.(*ast.Ident); ok && id.Name == "yield" { + if sig, ok := pass.TypesInfo.TypeOf(id).(*types.Signature); ok && + sig.Params().Len() < 3 && + sig.Results().Len() == 1 && + types.Identical(sig.Results().At(0).Type(), types.Typ[types.Bool]) { + yieldCalls[call.Lparen] = call + } + } + }) + + // Common case: nothing to do. + if len(yieldCalls) == 0 { + return nil, nil + } + + // Study the control flow using SSA. + buildssa := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) + for _, fn := range buildssa.SrcFuncs { + // TODO(adonovan): opt: skip functions that don't contain any yield calls. + + // Find the yield calls in SSA. 
+ type callInfo struct { + syntax *ast.CallExpr + index int // index of instruction within its block + reported bool + } + ssaYieldCalls := make(map[*ssa.Call]*callInfo) + for _, b := range fn.Blocks { + for i, instr := range b.Instrs { + if call, ok := instr.(*ssa.Call); ok { + if syntax, ok := yieldCalls[call.Pos()]; ok { + ssaYieldCalls[call] = &callInfo{syntax: syntax, index: i} + } + } + } + } + + // Now search for a control path from the instruction after a + // yield call to another yield call--possibly the same one, + // following all block successors except "if yield() { ... }"; + // in such cases we know that yield returned true. + // + // Note that this is a "may" dataflow analysis: it + // reports when a yield function _may_ be called again + // without a positive intervening check, but it is + // possible that the check is beyond the ability of + // the representation to detect, perhaps involving + // sophisticated use of booleans, indirect state (not + // in SSA registers), or multiple flow paths some of + // which are infeasible. + // + // A "must" analysis (which would report when a second + // yield call can only be reached after failing the + // boolean check) would be too conservative. + // In particular, the most common mistake is to + // forget to check the boolean at all. + for call, info := range ssaYieldCalls { + visited := make([]bool, len(fn.Blocks)) // visited BasicBlock.Indexes + + // visit visits the instructions of a block (or a suffix if start > 0). 
+ var visit func(b *ssa.BasicBlock, start int) + visit = func(b *ssa.BasicBlock, start int) { + if !visited[b.Index] { + if start == 0 { + visited[b.Index] = true + } + for _, instr := range b.Instrs[start:] { + switch instr := instr.(type) { + case *ssa.Call: + if !info.reported && ssaYieldCalls[instr] != nil { + info.reported = true + where := "" // "" => same yield call (a loop) + if instr != call { + otherLine := safetoken.StartPosition(pass.Fset, instr.Pos()).Line + where = fmt.Sprintf("(on L%d) ", otherLine) + } + pass.Reportf(call.Pos(), "yield may be called again %safter returning false", where) + } + case *ssa.If: + // Visit both successors, unless cond is yield() or its negation. + // In that case visit only the "if !yield()" block. + cond := instr.Cond + t, f := b.Succs[0], b.Succs[1] + + // Strip off any NOT operator. + cond, t, f = unnegate(cond, t, f) + + // As a peephole optimization for this special case: + // ok := yield() + // ok = ok && yield() + // ok = ok && yield() + // which in SSA becomes: + // yield() + // phi(false, yield()) + // phi(false, yield()) + // we reduce a cond of phi(false, x) to just x. + if phi, ok := cond.(*ssa.Phi); ok { + var nonFalse []ssa.Value + for _, v := range phi.Edges { + if c, ok := v.(*ssa.Const); ok && + !constant.BoolVal(c.Value) { + continue // constant false + } + nonFalse = append(nonFalse, v) + } + if len(nonFalse) == 1 { + cond = nonFalse[0] + cond, t, f = unnegate(cond, t, f) + } + } + + if cond, ok := cond.(*ssa.Call); ok && ssaYieldCalls[cond] != nil { + // Skip the successor reached by "if yield() { ... }". + } else { + visit(t, 0) + } + visit(f, 0) + + case *ssa.Jump: + visit(b.Succs[0], 0) + } + } + } + } + + // Start at the instruction after the yield call. 
+ visit(call.Block(), info.index+1) + } + } + + return nil, nil +} + +func unnegate(cond ssa.Value, t, f *ssa.BasicBlock) (_ ssa.Value, _, _ *ssa.BasicBlock) { + if unop, ok := cond.(*ssa.UnOp); ok && unop.Op == token.NOT { + return unop.X, f, t + } + return cond, t, f +} diff --git a/gopls/internal/analysis/undeclaredname/undeclared_test.go b/gopls/internal/analysis/yield/yield_test.go similarity index 54% rename from gopls/internal/analysis/undeclaredname/undeclared_test.go rename to gopls/internal/analysis/yield/yield_test.go index ea3d724515b..af6784374e2 100644 --- a/gopls/internal/analysis/undeclaredname/undeclared_test.go +++ b/gopls/internal/analysis/yield/yield_test.go @@ -1,17 +1,17 @@ -// Copyright 2020 The Go Authors. All rights reserved. +// Copyright 2024 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package undeclaredname_test +package yield_test import ( "testing" "golang.org/x/tools/go/analysis/analysistest" - "golang.org/x/tools/gopls/internal/analysis/undeclaredname" + "golang.org/x/tools/gopls/internal/analysis/yield" ) func Test(t *testing.T) { testdata := analysistest.TestData() - analysistest.Run(t, testdata, undeclaredname.Analyzer, "a") + analysistest.Run(t, testdata, yield.Analyzer, "a") } diff --git a/gopls/internal/cache/check.go b/gopls/internal/cache/check.go index 0b9d4f8024d..dbae63b8529 100644 --- a/gopls/internal/cache/check.go +++ b/gopls/internal/cache/check.go @@ -644,7 +644,16 @@ func importLookup(mp *metadata.Package, source metadata.Source) func(PackagePath if prevID, ok := impMap[depPath]; ok { // debugging #63822 if prevID != depID { - bug.Reportf("inconsistent view of dependencies") + prev := source.Metadata(prevID) + curr := source.Metadata(depID) + switch { + case prev == nil || curr == nil: + bug.Reportf("inconsistent view of dependencies (missing dep)") + case prev.ForTest != curr.ForTest: + bug.Reportf("inconsistent view of 
dependencies (mismatching ForTest)") + default: + bug.Reportf("inconsistent view of dependencies") + } } continue } @@ -1782,7 +1791,7 @@ func depsErrors(ctx context.Context, snapshot *Snapshot, mp *metadata.Package) ( } } - modFile, err := nearestModFile(ctx, mp.CompiledGoFiles[0], snapshot) + modFile, err := findRootPattern(ctx, mp.CompiledGoFiles[0].Dir(), "go.mod", snapshot) if err != nil { return nil, err } diff --git a/gopls/internal/cache/diagnostics.go b/gopls/internal/cache/diagnostics.go index 95b1b9f1c18..0adbcb495db 100644 --- a/gopls/internal/cache/diagnostics.go +++ b/gopls/internal/cache/diagnostics.go @@ -191,13 +191,12 @@ func bundleLazyFixes(sd *Diagnostic) bool { // BundledLazyFixes extracts any bundled codeActions from the // diag.Data field. -func BundledLazyFixes(diag protocol.Diagnostic) []protocol.CodeAction { +func BundledLazyFixes(diag protocol.Diagnostic) ([]protocol.CodeAction, error) { var fix lazyFixesJSON if diag.Data != nil { err := protocol.UnmarshalJSON(*diag.Data, &fix) if err != nil { - bug.Reportf("unmarshalling lazy fix: %v", err) - return nil + return nil, fmt.Errorf("unmarshalling fix from diagnostic data: %v", err) } } @@ -205,8 +204,7 @@ func BundledLazyFixes(diag protocol.Diagnostic) []protocol.CodeAction { for _, action := range fix.Actions { // See bundleLazyFixes: for now we only support bundling commands. if action.Edit != nil { - bug.Reportf("bundled fix %q includes workspace edits", action.Title) - continue + return nil, fmt.Errorf("bundled fix %q includes workspace edits", action.Title) } // associate the action with the incoming diagnostic // (Note that this does not mutate the fix.Fixes slice). 
@@ -214,5 +212,5 @@ func BundledLazyFixes(diag protocol.Diagnostic) []protocol.CodeAction { actions = append(actions, action) } - return actions + return actions, nil } diff --git a/gopls/internal/cache/filemap.go b/gopls/internal/cache/filemap.go index ee64d7c32c3..c826141ed98 100644 --- a/gopls/internal/cache/filemap.go +++ b/gopls/internal/cache/filemap.go @@ -104,7 +104,7 @@ func (m *fileMap) set(key protocol.DocumentURI, fh file.Handle) { // addDirs adds all directories containing u to the dirs set. func (m *fileMap) addDirs(u protocol.DocumentURI) { - dir := filepath.Dir(u.Path()) + dir := u.DirPath() for dir != "" && !m.dirs.Contains(dir) { m.dirs.Add(dir) dir = filepath.Dir(dir) diff --git a/gopls/internal/cache/load.go b/gopls/internal/cache/load.go index 9987def6392..4868c0fa877 100644 --- a/gopls/internal/cache/load.go +++ b/gopls/internal/cache/load.go @@ -9,6 +9,7 @@ import ( "context" "errors" "fmt" + "go/types" "path/filepath" "slices" "sort" @@ -25,8 +26,8 @@ import ( "golang.org/x/tools/gopls/internal/util/immutable" "golang.org/x/tools/gopls/internal/util/pathutil" "golang.org/x/tools/internal/event" - "golang.org/x/tools/internal/gocommand" "golang.org/x/tools/internal/packagesinternal" + "golang.org/x/tools/internal/typesinternal" "golang.org/x/tools/internal/xcontext" ) @@ -42,7 +43,7 @@ var errNoPackages = errors.New("no packages returned") // errors associated with specific modules. // // If scopes contains a file scope there must be exactly one scope. -func (s *Snapshot) load(ctx context.Context, allowNetwork bool, scopes ...loadScope) (err error) { +func (s *Snapshot) load(ctx context.Context, allowNetwork AllowNetwork, scopes ...loadScope) (err error) { if ctx.Err() != nil { // Check context cancellation before incrementing id below: a load on a // cancelled context should be a no-op. 
@@ -57,6 +58,7 @@ func (s *Snapshot) load(ctx context.Context, allowNetwork bool, scopes ...loadSc // Keep track of module query -> module path so that we can later correlate query // errors with errors. moduleQueries := make(map[string]string) + for _, scope := range scopes { switch scope := scope.(type) { case packageLoadScope: @@ -118,21 +120,13 @@ func (s *Snapshot) load(ctx context.Context, allowNetwork bool, scopes ...loadSc startTime := time.Now() - inv, cleanupInvocation, err := s.GoCommandInvocation(allowNetwork, &gocommand.Invocation{ - WorkingDir: s.view.root.Path(), - }) - if err != nil { - return err - } - defer cleanupInvocation() - // Set a last resort deadline on packages.Load since it calls the go // command, which may hang indefinitely if it has a bug. golang/go#42132 // and golang/go#42255 have more context. ctx, cancel := context.WithTimeout(ctx, 10*time.Minute) defer cancel() - cfg := s.config(ctx, inv) + cfg := s.config(ctx, allowNetwork) pkgs, err := packages.Load(cfg, query...) // If the context was canceled, return early. Otherwise, we might be @@ -169,37 +163,21 @@ func (s *Snapshot) load(ctx context.Context, allowNetwork bool, scopes ...loadSc // package. We don't support this; theoretically we could, but it seems // unnecessarily complicated. // - // Prior to golang/go#64233 we just assumed that we'd get exactly one - // package here. The categorization of bug reports below may be a bit - // verbose, but anticipates that perhaps we don't fully understand - // possible failure modes. - errorf := bug.Errorf - if s.view.typ == GoPackagesDriverView { - errorf = fmt.Errorf // all bets are off - } - for _, pkg := range pkgs { - // Don't report bugs if any packages have errors. - // For example: given go list errors, go/packages may synthesize a - // package with ID equal to the query. 
- if len(pkg.Errors) > 0 { - errorf = fmt.Errorf - break - } - } - + // It's possible that we get no packages here, for example if the file is a + // cgo file and cgo is not enabled. var standalonePkg *packages.Package for _, pkg := range pkgs { if pkg.ID == "command-line-arguments" { if standalonePkg != nil { - return errorf("go/packages returned multiple standalone packages") + return fmt.Errorf("go/packages returned multiple standalone packages") } standalonePkg = pkg - } else if packagesinternal.GetForTest(pkg) == "" && !strings.HasSuffix(pkg.ID, ".test") { - return errorf("go/packages returned unexpected package %q for standalone file", pkg.ID) + } else if pkg.ForTest == "" && !strings.HasSuffix(pkg.ID, ".test") { + return fmt.Errorf("go/packages returned unexpected package %q for standalone file", pkg.ID) } } if standalonePkg == nil { - return errorf("go/packages failed to return non-test standalone package") + return fmt.Errorf("go/packages failed to return non-test standalone package") } if len(standalonePkg.CompiledGoFiles) > 0 { pkgs = []*packages.Package{standalonePkg} @@ -259,7 +237,7 @@ func (s *Snapshot) load(ctx context.Context, allowNetwork bool, scopes ...loadSc s.setBuiltin(pkg.GoFiles[0]) continue } - if packagesinternal.GetForTest(pkg) == "builtin" { + if pkg.ForTest == "builtin" { // We don't care about test variants of builtin. This caused test // failures in https://go.dev/cl/620196, when a test file was added to // builtin. @@ -278,7 +256,7 @@ func (s *Snapshot) load(ctx context.Context, allowNetwork bool, scopes ...loadSc if allFilesExcluded(pkg.GoFiles, filterFunc) { continue } - buildMetadata(newMetadata, pkg, cfg.Dir, standalone, s.view.typ != GoPackagesDriverView) + buildMetadata(newMetadata, cfg.Dir, standalone, pkg) } s.mu.Lock() @@ -362,6 +340,48 @@ func (m *moduleErrorMap) Error() string { return buf.String() } +// config returns the configuration used for the snapshot's interaction with +// the go/packages API. 
It uses the given working directory. +// +// TODO(rstambler): go/packages requires that we do not provide overlays for +// multiple modules in one config, so buildOverlay needs to filter overlays by +// module. +// TODO(rfindley): ^^ is this still true? +func (s *Snapshot) config(ctx context.Context, allowNetwork AllowNetwork) *packages.Config { + cfg := &packages.Config{ + Context: ctx, + Dir: s.view.root.Path(), + Env: s.view.Env(), + BuildFlags: slices.Clone(s.view.folder.Options.BuildFlags), + Mode: packages.NeedName | + packages.NeedFiles | + packages.NeedCompiledGoFiles | + packages.NeedImports | + packages.NeedDeps | + packages.NeedTypesSizes | + packages.NeedModule | + packages.NeedEmbedFiles | + packages.LoadMode(packagesinternal.DepsErrors) | + packages.NeedForTest, + Fset: nil, // we do our own parsing + Overlay: s.buildOverlays(), + Logf: func(format string, args ...interface{}) { + if s.view.folder.Options.VerboseOutput { + event.Log(ctx, fmt.Sprintf(format, args...)) + } + }, + Tests: true, + } + if !allowNetwork { + cfg.Env = append(cfg.Env, "GOPROXY=off") + } + // We want to type check cgo code if go/types supports it. + if typesinternal.SetUsesCgo(&types.Config{}) { + cfg.Mode |= packages.LoadMode(packagesinternal.TypecheckCgo) + } + return cfg +} + // buildMetadata populates the updates map with metadata updates to // apply, based on the given pkg. It recurs through pkg.Imports to ensure that // metadata exists for all dependencies. @@ -369,28 +389,30 @@ func (m *moduleErrorMap) Error() string { // Returns the metadata.Package that was built (or which was already present in // updates), or nil if the package could not be built. Notably, the resulting // metadata.Package may have an ID that differs from pkg.ID. 
-func buildMetadata(updates map[PackageID]*metadata.Package, pkg *packages.Package, loadDir string, standalone, goListView bool) *metadata.Package { +func buildMetadata(updates map[PackageID]*metadata.Package, loadDir string, standalone bool, pkg *packages.Package) *metadata.Package { // Allow for multiple ad-hoc packages in the workspace (see #47584). pkgPath := PackagePath(pkg.PkgPath) id := PackageID(pkg.ID) if metadata.IsCommandLineArguments(id) { var f string // file to use as disambiguating suffix - if len(pkg.CompiledGoFiles) > 0 { - f = pkg.CompiledGoFiles[0] - - // If there are multiple files, - // we can't use only the first. - // (Can this happen? #64557) - if len(pkg.CompiledGoFiles) > 1 { - bug.Reportf("unexpected files in command-line-arguments package: %v", pkg.CompiledGoFiles) + if len(pkg.GoFiles) > 0 { + f = pkg.GoFiles[0] + + // If there are multiple files, we can't use only the first. Note that we + // consider GoFiles, rather than CompiledGoFiles, as there can be + // multiple CompiledGoFiles in the presence of cgo processing, whereas a + // command-line-arguments package should always have exactly one nominal + // Go source file. (See golang/go#64557.) + if len(pkg.GoFiles) > 1 { + bug.Reportf("unexpected files in command-line-arguments package: %v", pkg.GoFiles) return nil } } else if len(pkg.IgnoredFiles) > 0 { // A file=empty.go query results in IgnoredFiles=[empty.go]. 
f = pkg.IgnoredFiles[0] } else { - bug.Reportf("command-line-arguments package has neither CompiledGoFiles nor IgnoredFiles") + bug.Reportf("command-line-arguments package has neither GoFiles nor IgnoredFiles") return nil } id = PackageID(pkg.ID + f) @@ -416,7 +438,7 @@ func buildMetadata(updates map[PackageID]*metadata.Package, pkg *packages.Packag ID: id, PkgPath: pkgPath, Name: PackageName(pkg.Name), - ForTest: PackagePath(packagesinternal.GetForTest(pkg)), + ForTest: PackagePath(pkg.ForTest), TypesSizes: pkg.TypesSizes, LoadDir: loadDir, Module: pkg.Module, @@ -522,7 +544,7 @@ func buildMetadata(updates map[PackageID]*metadata.Package, pkg *packages.Packag continue } - dep := buildMetadata(updates, imported, loadDir, false, goListView) // only top level packages can be standalone + dep := buildMetadata(updates, loadDir, false, imported) // only top level packages can be standalone // Don't record edges to packages with no name, as they cause trouble for // the importer (golang/go#60952). diff --git a/gopls/internal/cache/methodsets/methodsets.go b/gopls/internal/cache/methodsets/methodsets.go index 98b0563ceeb..d9173b3b4c3 100644 --- a/gopls/internal/cache/methodsets/methodsets.go +++ b/gopls/internal/cache/methodsets/methodsets.go @@ -52,8 +52,10 @@ import ( "strings" "golang.org/x/tools/go/types/objectpath" + "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/gopls/internal/util/frob" "golang.org/x/tools/gopls/internal/util/safetoken" + "golang.org/x/tools/internal/typesinternal" ) // An Index records the non-empty method sets of all package-level @@ -223,16 +225,40 @@ func (b *indexBuilder) build(fset *token.FileSet, pkg *types.Package) *Index { return } - m.Posn = objectPos(method) - m.PkgPath = b.string(method.Pkg().Path()) - // Instantiations of generic methods don't have an // object path, so we use the generic. 
- if p, err := objectpathFor(method.Origin()); err != nil { - panic(err) // can't happen for a method of a package-level type - } else { - m.ObjectPath = b.string(string(p)) + p, err := objectpathFor(method.Origin()) + if err != nil { + // This should never happen for a method of a package-level type. + // ...but it does (golang/go#70418). + // Refine the crash into various bug reports. + report := func() { + bug.Reportf("missing object path for %s", method.FullName()) + } + sig := method.Signature() + if sig.Recv() == nil { + report() + return + } + _, named := typesinternal.ReceiverNamed(sig.Recv()) + switch { + case named == nil: + report() + case sig.TypeParams().Len() > 0: + report() + case method.Origin() != method: + report() // instantiated? + case sig.RecvTypeParams().Len() > 0: + report() // generic? + default: + report() + } + return } + + m.Posn = objectPos(method) + m.PkgPath = b.string(method.Pkg().Path()) + m.ObjectPath = b.string(string(p)) } // We ignore aliases, though in principle they could define a diff --git a/gopls/internal/cache/mod.go b/gopls/internal/cache/mod.go index 6837ec3257c..f16cfbfe1af 100644 --- a/gopls/internal/cache/mod.go +++ b/gopls/internal/cache/mod.go @@ -8,7 +8,6 @@ import ( "context" "errors" "fmt" - "path/filepath" "regexp" "strings" @@ -19,7 +18,6 @@ import ( "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/protocol/command" "golang.org/x/tools/internal/event" - "golang.org/x/tools/internal/gocommand" "golang.org/x/tools/internal/memoize" ) @@ -252,11 +250,7 @@ func modWhyImpl(ctx context.Context, snapshot *Snapshot, fh file.Handle) (map[st for _, req := range pm.File.Require { args = append(args, req.Mod.Path) } - inv, cleanupInvocation, err := snapshot.GoCommandInvocation(false, &gocommand.Invocation{ - Verb: "mod", - Args: args, - WorkingDir: filepath.Dir(fh.URI().Path()), - }) + inv, cleanupInvocation, err := snapshot.GoCommandInvocation(NoNetwork, fh.URI().DirPath(), "mod", args) if 
err != nil { return nil, err } diff --git a/gopls/internal/cache/mod_tidy.go b/gopls/internal/cache/mod_tidy.go index 67a3e9c7eb9..4d473d39b12 100644 --- a/gopls/internal/cache/mod_tidy.go +++ b/gopls/internal/cache/mod_tidy.go @@ -23,7 +23,6 @@ import ( "golang.org/x/tools/gopls/internal/protocol/command" "golang.org/x/tools/internal/diff" "golang.org/x/tools/internal/event" - "golang.org/x/tools/internal/gocommand" "golang.org/x/tools/internal/memoize" ) @@ -108,12 +107,8 @@ func modTidyImpl(ctx context.Context, snapshot *Snapshot, pm *ParsedModule) (*Ti } defer cleanup() - inv, cleanupInvocation, err := snapshot.GoCommandInvocation(false, &gocommand.Invocation{ - Verb: "mod", - Args: []string{"tidy", "-modfile=" + filepath.Join(tempDir, "go.mod")}, - Env: []string{"GOWORK=off"}, - WorkingDir: pm.URI.Dir().Path(), - }) + args := []string{"tidy", "-modfile=" + filepath.Join(tempDir, "go.mod")} + inv, cleanupInvocation, err := snapshot.GoCommandInvocation(NoNetwork, pm.URI.DirPath(), "mod", args, "GOWORK=off") if err != nil { return nil, err } diff --git a/gopls/internal/cache/parsego/file.go b/gopls/internal/cache/parsego/file.go index 1dc46da823a..ea8db19b4ff 100644 --- a/gopls/internal/cache/parsego/file.go +++ b/gopls/internal/cache/parsego/file.go @@ -12,6 +12,7 @@ import ( "sync" "golang.org/x/tools/gopls/internal/protocol" + "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/gopls/internal/util/safetoken" ) @@ -116,6 +117,22 @@ func (pgf *File) RangePos(r protocol.Range) (token.Pos, token.Pos, error) { return pgf.Tok.Pos(start), pgf.Tok.Pos(end), nil } +// CheckNode asserts that the Node's positions are valid w.r.t. pgf.Tok. +func (pgf *File) CheckNode(node ast.Node) { + // Avoid safetoken.Offsets, and put each assertion on its own source line. + pgf.CheckPos(node.Pos()) + pgf.CheckPos(node.End()) +} + +// CheckPos asserts that the position is valid w.r.t. pgf.Tok. 
+func (pgf *File) CheckPos(pos token.Pos) { + if !pos.IsValid() { + bug.Report("invalid token.Pos") + } else if _, err := safetoken.Offset(pgf.Tok, pos); err != nil { + bug.Report("token.Pos out of range") + } +} + // Resolve lazily resolves ast.Ident.Objects in the enclosed syntax tree. // // Resolve must be called before accessing any of: diff --git a/gopls/internal/cache/parsego/parse.go b/gopls/internal/cache/parsego/parse.go index 08f9c6bbe85..52445a9fbbf 100644 --- a/gopls/internal/cache/parsego/parse.go +++ b/gopls/internal/cache/parsego/parse.go @@ -25,6 +25,7 @@ import ( "golang.org/x/tools/gopls/internal/label" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/util/astutil" + "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/gopls/internal/util/safetoken" "golang.org/x/tools/internal/diff" "golang.org/x/tools/internal/event" @@ -77,12 +78,21 @@ func Parse(ctx context.Context, fset *token.FileSet, uri protocol.DocumentURI, s tokenFile := func(file *ast.File) *token.File { tok := fset.File(file.FileStart) if tok == nil { + // Invalid File.FileStart (also File.{Package,Name.Pos}). + if file.Package.IsValid() { + bug.Report("ast.File has valid Package but no FileStart") + } + if file.Name.Pos().IsValid() { + bug.Report("ast.File has valid Name.Pos but no FileStart") + } tok = fset.AddFile(uri.Path(), -1, len(src)) tok.SetLinesForContent(src) - if file.FileStart.IsValid() { - file.FileStart = token.Pos(tok.Base()) - file.FileEnd = token.Pos(tok.Base() + tok.Size()) - } + // If the File contained any valid token.Pos values, + // they would all be invalid wrt the new token.File, + // but we have established that it lacks FileStart, + // Package, and Name.Pos. 
+ file.FileStart = token.Pos(tok.Base()) + file.FileEnd = token.Pos(tok.Base() + tok.Size()) } return tok } diff --git a/gopls/internal/cache/session.go b/gopls/internal/cache/session.go index 5947b373b16..d8c01a17a01 100644 --- a/gopls/internal/cache/session.go +++ b/gopls/internal/cache/session.go @@ -191,7 +191,7 @@ func (s *Session) createView(ctx context.Context, def *viewDefinition) (*View, * } else { dirs = append(dirs, def.folder.Env.GOMODCACHE) for m := range def.workspaceModFiles { - dirs = append(dirs, filepath.Dir(m.Path())) + dirs = append(dirs, m.DirPath()) } } ignoreFilter = newIgnoreFilter(dirs) diff --git a/gopls/internal/cache/snapshot.go b/gopls/internal/cache/snapshot.go index 63aed7be2e6..46b0a6a1b5c 100644 --- a/gopls/internal/cache/snapshot.go +++ b/gopls/internal/cache/snapshot.go @@ -13,7 +13,6 @@ import ( "go/build/constraint" "go/parser" "go/token" - "go/types" "os" "path" "path/filepath" @@ -26,7 +25,6 @@ import ( "sync" "golang.org/x/sync/errgroup" - "golang.org/x/tools/go/packages" "golang.org/x/tools/go/types/objectpath" "golang.org/x/tools/gopls/internal/cache/metadata" "golang.org/x/tools/gopls/internal/cache/methodsets" @@ -49,8 +47,6 @@ import ( "golang.org/x/tools/internal/event/label" "golang.org/x/tools/internal/gocommand" "golang.org/x/tools/internal/memoize" - "golang.org/x/tools/internal/packagesinternal" - "golang.org/x/tools/internal/typesinternal" ) // A Snapshot represents the current state for a given view. @@ -363,50 +359,6 @@ func (s *Snapshot) Templates() map[protocol.DocumentURI]file.Handle { return tmpls } -// config returns the configuration used for the snapshot's interaction with -// the go/packages API. It uses the given working directory. -// -// TODO(rstambler): go/packages requires that we do not provide overlays for -// multiple modules in one config, so buildOverlay needs to filter overlays by -// module. 
-func (s *Snapshot) config(ctx context.Context, inv *gocommand.Invocation) *packages.Config { - - cfg := &packages.Config{ - Context: ctx, - Dir: inv.WorkingDir, - Env: inv.Env, - BuildFlags: inv.BuildFlags, - Mode: packages.NeedName | - packages.NeedFiles | - packages.NeedCompiledGoFiles | - packages.NeedImports | - packages.NeedDeps | - packages.NeedTypesSizes | - packages.NeedModule | - packages.NeedEmbedFiles | - packages.LoadMode(packagesinternal.DepsErrors) | - packages.LoadMode(packagesinternal.ForTest), - Fset: nil, // we do our own parsing - Overlay: s.buildOverlays(), - ParseFile: func(*token.FileSet, string, []byte) (*ast.File, error) { - panic("go/packages must not be used to parse files") - }, - Logf: func(format string, args ...interface{}) { - if s.Options().VerboseOutput { - event.Log(ctx, fmt.Sprintf(format, args...)) - } - }, - Tests: true, - } - packagesinternal.SetModFile(cfg, inv.ModFile) - packagesinternal.SetModFlag(cfg, inv.ModFlag) - // We want to type check cgo code if go/types supports it. - if typesinternal.SetUsesCgo(&types.Config{}) { - cfg.Mode |= packages.LoadMode(packagesinternal.TypecheckCgo) - } - return cfg -} - // RunGoModUpdateCommands runs a series of `go` commands that updates the go.mod // and go.sum file for wd, and returns their updated contents. // @@ -423,16 +375,14 @@ func (s *Snapshot) RunGoModUpdateCommands(ctx context.Context, modURI protocol.D // TODO(rfindley): we must use ModFlag and ModFile here (rather than simply // setting Args), because without knowing the verb, we can't know whether // ModFlag is appropriate. Refactor so that args can be set by the caller. 
- inv, cleanupInvocation, err := s.GoCommandInvocation(true, &gocommand.Invocation{ - WorkingDir: modURI.Dir().Path(), - ModFlag: "mod", - ModFile: filepath.Join(tempDir, "go.mod"), - Env: []string{"GOWORK=off"}, - }) + inv, cleanupInvocation, err := s.GoCommandInvocation(NetworkOK, modURI.DirPath(), "", nil, "GOWORK=off") if err != nil { return nil, nil, err } defer cleanupInvocation() + + inv.ModFlag = "mod" + inv.ModFile = filepath.Join(tempDir, "go.mod") invoke := func(args ...string) (*bytes.Buffer, error) { inv.Verb = args[0] inv.Args = args[1:] @@ -499,6 +449,15 @@ func TempModDir(ctx context.Context, fs file.Source, modURI protocol.DocumentURI return dir, cleanup, nil } +// AllowNetwork determines whether Go commands are permitted to use the +// network. (Controlled via GOPROXY=off.) +type AllowNetwork bool + +const ( + NoNetwork AllowNetwork = false + NetworkOK AllowNetwork = true +) + // GoCommandInvocation populates inv with configuration for running go commands // on the snapshot. // @@ -509,23 +468,15 @@ func TempModDir(ctx context.Context, fs file.Source, modURI protocol.DocumentURI // additional refactoring is still required: the responsibility for Env and // BuildFlags should be more clearly expressed in the API. // -// If allowNetwork is set, do not set GOPROXY=off. -func (s *Snapshot) GoCommandInvocation(allowNetwork bool, inv *gocommand.Invocation) (_ *gocommand.Invocation, cleanup func(), _ error) { - // TODO(rfindley): it's not clear that this is doing the right thing. - // Should inv.Env really overwrite view.options? Should s.view.envOverlay - // overwrite inv.Env? (Do we ever invoke this with a non-empty inv.Env?) - // - // We should survey existing uses and write down rules for how env is - // applied. 
- inv.Env = slices.Concat( - os.Environ(), - s.Options().EnvSlice(), - inv.Env, - []string{"GO111MODULE=" + s.view.adjustedGO111MODULE()}, - s.view.EnvOverlay(), - ) - inv.BuildFlags = slices.Clone(s.Options().BuildFlags) - +// If allowNetwork is NoNetwork, set GOPROXY=off. +func (s *Snapshot) GoCommandInvocation(allowNetwork AllowNetwork, dir, verb string, args []string, env ...string) (_ *gocommand.Invocation, cleanup func(), _ error) { + inv := &gocommand.Invocation{ + Verb: verb, + Args: args, + WorkingDir: dir, + Env: append(s.view.Env(), env...), + BuildFlags: slices.Clone(s.Options().BuildFlags), + } if !allowNetwork { inv.Env = append(inv.Env, "GOPROXY=off") } @@ -743,7 +694,7 @@ func (s *Snapshot) MetadataForFile(ctx context.Context, uri protocol.DocumentURI // - ...but uri is not unloadable if (shouldLoad || len(ids) == 0) && !unloadable { scope := fileLoadScope(uri) - err := s.load(ctx, false, scope) + err := s.load(ctx, NoNetwork, scope) // // Return the context error here as the current operation is no longer @@ -863,7 +814,7 @@ func (s *Snapshot) fileWatchingGlobPatterns() map[protocol.RelativePattern]unit var dirs []string if s.view.typ.usesModules() { if s.view.typ == GoWorkView { - workVendorDir := filepath.Join(s.view.gowork.Dir().Path(), "vendor") + workVendorDir := filepath.Join(s.view.gowork.DirPath(), "vendor") workVendorURI := protocol.URIFromPath(workVendorDir) patterns[protocol.RelativePattern{BaseURI: workVendorURI, Pattern: watchGoFiles}] = unit{} } @@ -874,8 +825,7 @@ func (s *Snapshot) fileWatchingGlobPatterns() map[protocol.RelativePattern]unit // The assumption is that the user is not actively editing non-workspace // modules, so don't pay the price of file watching. for modFile := range s.view.workspaceModFiles { - dir := filepath.Dir(modFile.Path()) - dirs = append(dirs, dir) + dirs = append(dirs, modFile.DirPath()) // TODO(golang/go#64724): thoroughly test these patterns, particularly on // on Windows. 
@@ -1115,15 +1065,6 @@ func moduleForURI(modFiles map[protocol.DocumentURI]struct{}, uri protocol.Docum return match } -// nearestModFile finds the nearest go.mod file contained in the directory -// containing uri, or a parent of that directory. -// -// The given uri must be a file, not a directory. -func nearestModFile(ctx context.Context, uri protocol.DocumentURI, fs file.Source) (protocol.DocumentURI, error) { - dir := filepath.Dir(uri.Path()) - return findRootPattern(ctx, protocol.URIFromPath(dir), "go.mod", fs) -} - // Metadata returns the metadata for the specified package, // or nil if it was not found. func (s *Snapshot) Metadata(id PackageID) *metadata.Package { @@ -1320,7 +1261,7 @@ func (s *Snapshot) reloadWorkspace(ctx context.Context) { scopes = []loadScope{viewLoadScope{}} } - err := s.load(ctx, false, scopes...) + err := s.load(ctx, NoNetwork, scopes...) // Unless the context was canceled, set "shouldLoad" to false for all // of the metadata we attempted to load. @@ -1406,7 +1347,7 @@ searchOverlays: ) if initialErr != nil { msg = fmt.Sprintf("initialization failed: %v", initialErr.MainError) - } else if goMod, err := nearestModFile(ctx, fh.URI(), s); err == nil && goMod != "" { + } else if goMod, err := findRootPattern(ctx, fh.URI().Dir(), "go.mod", file.Source(s)); err == nil && goMod != "" { // Check if the file's module should be loadable by considering both // loaded modules and workspace modules. The former covers cases where // the file is outside of a workspace folder. The latter covers cases @@ -1419,7 +1360,7 @@ searchOverlays: // prescriptive diagnostic in the case that there is no go.mod file, but // it is harder to be precise in that case, and less important. 
if !(loadedMod || workspaceMod) { - modDir := filepath.Dir(goMod.Path()) + modDir := goMod.DirPath() viewDir := s.view.folder.Dir.Path() // When the module is underneath the view dir, we offer @@ -1720,7 +1661,7 @@ func (s *Snapshot) clone(ctx, bgCtx context.Context, changed StateChange, done f continue // like with go.mod files, we only reinit when things change on disk } dir, base := filepath.Split(uri.Path()) - if base == "go.work.sum" && s.view.typ == GoWorkView && dir == filepath.Dir(s.view.gowork.Path()) { + if base == "go.work.sum" && s.view.typ == GoWorkView && dir == s.view.gowork.DirPath() { reinit = true } if base == "go.sum" { @@ -2003,7 +1944,7 @@ func deleteMostRelevantModFile(m *persistent.Map[protocol.DocumentURI, *memoize. m.Range(func(modURI protocol.DocumentURI, _ *memoize.Promise) { if len(modURI) > len(mostRelevant) { - if pathutil.InDir(filepath.Dir(modURI.Path()), changedFile) { + if pathutil.InDir(modURI.DirPath(), changedFile) { mostRelevant = modURI } } @@ -2055,12 +1996,12 @@ func invalidatedPackageIDs(uri protocol.DocumentURI, known map[protocol.Document }{fi, err} return fi, err } - dir := filepath.Dir(uri.Path()) + dir := uri.DirPath() fi, err := getInfo(dir) if err == nil { // Aggregate all possibly relevant package IDs. 
for knownURI, ids := range known { - knownDir := filepath.Dir(knownURI.Path()) + knownDir := knownURI.DirPath() knownFI, err := getInfo(knownDir) if err != nil { continue diff --git a/gopls/internal/cache/typerefs/pkgrefs_test.go b/gopls/internal/cache/typerefs/pkgrefs_test.go index 9d4b5c011d3..3f9a976ccf7 100644 --- a/gopls/internal/cache/typerefs/pkgrefs_test.go +++ b/gopls/internal/cache/typerefs/pkgrefs_test.go @@ -342,7 +342,7 @@ func loadPackages(query string, needExport bool) (map[PackageID]string, Metadata packages.NeedModule | packages.NeedEmbedFiles | packages.LoadMode(packagesinternal.DepsErrors) | - packages.LoadMode(packagesinternal.ForTest), + packages.NeedForTest, Tests: true, } if needExport { @@ -364,7 +364,7 @@ func loadPackages(query string, needExport bool) (map[PackageID]string, Metadata ID: id, PkgPath: PackagePath(pkg.PkgPath), Name: packageName(pkg.Name), - ForTest: PackagePath(packagesinternal.GetForTest(pkg)), + ForTest: PackagePath(pkg.ForTest), TypesSizes: pkg.TypesSizes, LoadDir: cfg.Dir, Module: pkg.Module, diff --git a/gopls/internal/cache/view.go b/gopls/internal/cache/view.go index 5c8f4faec9e..d2adc5de019 100644 --- a/gopls/internal/cache/view.go +++ b/gopls/internal/cache/view.go @@ -356,6 +356,16 @@ func (v *View) Folder() *Folder { return v.folder } +// Env returns the environment to use for running go commands in this view. +func (v *View) Env() []string { + return slices.Concat( + os.Environ(), + v.folder.Options.EnvSlice(), + []string{"GO111MODULE=" + v.adjustedGO111MODULE()}, + v.EnvOverlay(), + ) +} + // UpdateFolders updates the set of views for the new folders. // // Calling this causes each view to be reinitialized. @@ -663,11 +673,10 @@ func (s *Snapshot) initialize(ctx context.Context, firstAttempt bool) { addError(modURI, fmt.Errorf("no module path for %s", modURI)) continue } - moduleDir := filepath.Dir(modURI.Path()) // Previously, we loaded /... 
for each module path, but that // is actually incorrect when the pattern may match packages in more than // one module. See golang/go#59458 for more details. - scopes = append(scopes, moduleLoadScope{dir: moduleDir, modulePath: parsed.File.Module.Mod.Path}) + scopes = append(scopes, moduleLoadScope{dir: modURI.DirPath(), modulePath: parsed.File.Module.Mod.Path}) } } else { scopes = append(scopes, viewLoadScope{}) @@ -679,7 +688,7 @@ func (s *Snapshot) initialize(ctx context.Context, firstAttempt bool) { if len(scopes) > 0 { scopes = append(scopes, packageLoadScope("builtin")) } - loadErr := s.load(ctx, true, scopes...) + loadErr := s.load(ctx, NetworkOK, scopes...) // A failure is retryable if it may have been due to context cancellation, // and this is not the initial workspace load (firstAttempt==true). @@ -816,7 +825,7 @@ func defineView(ctx context.Context, fs file.Source, folder *Folder, forFile fil } dir := folder.Dir.Path() if forFile != nil { - dir = filepath.Dir(forFile.URI().Path()) + dir = forFile.URI().DirPath() } def := new(viewDefinition) diff --git a/gopls/internal/cache/xrefs/xrefs.go b/gopls/internal/cache/xrefs/xrefs.go index 4113e08716e..2115322bfdc 100644 --- a/gopls/internal/cache/xrefs/xrefs.go +++ b/gopls/internal/cache/xrefs/xrefs.go @@ -17,6 +17,7 @@ import ( "golang.org/x/tools/gopls/internal/cache/metadata" "golang.org/x/tools/gopls/internal/cache/parsego" "golang.org/x/tools/gopls/internal/protocol" + "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/gopls/internal/util/frob" ) @@ -43,15 +44,6 @@ func Index(files []*parsego.File, pkg *types.Package, info *types.Info) []byte { objectpathFor := new(objectpath.Encoder).For for fileIndex, pgf := range files { - - nodeRange := func(n ast.Node) protocol.Range { - rng, err := pgf.PosRange(n.Pos(), n.End()) - if err != nil { - panic(err) // can't fail - } - return rng - } - ast.Inspect(pgf.File, func(n ast.Node) bool { switch n := n.(type) { case *ast.Ident: @@ -82,10 +74,15 @@ 
func Index(files []*parsego.File, pkg *types.Package, info *types.Info) []byte { objects[obj] = gobObj } - gobObj.Refs = append(gobObj.Refs, gobRef{ - FileIndex: fileIndex, - Range: nodeRange(n), - }) + // golang/go#66683: nodes can under/overflow the file. + // For example, "var _ = x." creates a SelectorExpr(Sel=Ident("_")) + // that is beyond EOF. (Arguably Ident.Name should be "".) + if rng, err := pgf.NodeRange(n); err == nil { + gobObj.Refs = append(gobObj.Refs, gobRef{ + FileIndex: fileIndex, + Range: rng, + }) + } } } @@ -102,10 +99,15 @@ func Index(files []*parsego.File, pkg *types.Package, info *types.Info) []byte { gobObj = &gobObject{Path: ""} objects[nil] = gobObj } - gobObj.Refs = append(gobObj.Refs, gobRef{ - FileIndex: fileIndex, - Range: nodeRange(n.Path), - }) + // golang/go#66683: nodes can under/overflow the file. + if rng, err := pgf.NodeRange(n.Path); err == nil { + gobObj.Refs = append(gobObj.Refs, gobRef{ + FileIndex: fileIndex, + Range: rng, + }) + } else { + bug.Reportf("out of bounds import spec %+v", n.Path) + } } return true }) diff --git a/gopls/internal/cmd/cmd.go b/gopls/internal/cmd/cmd.go index 91aca4683b5..d27542f79fb 100644 --- a/gopls/internal/cmd/cmd.go +++ b/gopls/internal/cmd/cmd.go @@ -425,6 +425,10 @@ func newConnection(server protocol.Server, client *cmdClient) *connection { } } +func (c *cmdClient) TextDocumentContentRefresh(context.Context, *protocol.TextDocumentContentRefreshParams) error { + return nil +} + func (c *cmdClient) CodeLensRefresh(context.Context) error { return nil } func (c *cmdClient) FoldingRangeRefresh(context.Context) error { return nil } diff --git a/gopls/internal/cmd/codeaction.go b/gopls/internal/cmd/codeaction.go index c349c7ab653..2096a153681 100644 --- a/gopls/internal/cmd/codeaction.go +++ b/gopls/internal/cmd/codeaction.go @@ -51,6 +51,7 @@ Valid kinds include: quickfix refactor refactor.extract + refactor.extract.constant refactor.extract.function refactor.extract.method 
refactor.extract.toNewFile diff --git a/gopls/internal/cmd/integration_test.go b/gopls/internal/cmd/integration_test.go index 15888b21f68..ad08119d397 100644 --- a/gopls/internal/cmd/integration_test.go +++ b/gopls/internal/cmd/integration_test.go @@ -818,9 +818,9 @@ const c = 0 got := res.stdout want := ` /*⇒7,keyword,[]*/package /*⇒1,namespace,[]*/a -/*⇒4,keyword,[]*/func /*⇒1,function,[definition]*/f() -/*⇒3,keyword,[]*/var /*⇒1,variable,[definition]*/v /*⇒3,type,[defaultLibrary number]*/int -/*⇒5,keyword,[]*/const /*⇒1,variable,[definition readonly]*/c = /*⇒1,number,[]*/0 +/*⇒4,keyword,[]*/func /*⇒1,function,[definition signature]*/f() +/*⇒3,keyword,[]*/var /*⇒1,variable,[definition number]*/v /*⇒3,type,[defaultLibrary number]*/int +/*⇒5,keyword,[]*/const /*⇒1,variable,[definition readonly number]*/c = /*⇒1,number,[]*/0 `[1:] if got != want { t.Errorf("semtok: got <<%s>>, want <<%s>>", got, want) diff --git a/gopls/internal/cmd/usage/codeaction.hlp b/gopls/internal/cmd/usage/codeaction.hlp index 6d6923ef458..d7bfe3ea99e 100644 --- a/gopls/internal/cmd/usage/codeaction.hlp +++ b/gopls/internal/cmd/usage/codeaction.hlp @@ -22,6 +22,7 @@ Valid kinds include: quickfix refactor refactor.extract + refactor.extract.constant refactor.extract.function refactor.extract.method refactor.extract.toNewFile diff --git a/gopls/internal/doc/api.json b/gopls/internal/doc/api.json index 298c3ab49e1..b64965ab863 100644 --- a/gopls/internal/doc/api.json +++ b/gopls/internal/doc/api.json @@ -569,11 +569,6 @@ "Doc": "check for calls of (time.Time).Format or time.Parse with 2006-02-01\n\nThe timeformat checker looks for time formats with the 2006-02-01 (yyyy-dd-mm)\nformat. 
Internationally, \"yyyy-dd-mm\" does not occur in common calendar date\nstandards, and so it is more likely that 2006-01-02 (yyyy-mm-dd) was intended.", "Default": "true" }, - { - "Name": "\"undeclaredname\"", - "Doc": "suggested fixes for \"undeclared name: \u003c\u003e\"\n\nThis checker provides suggested fixes for type errors of the\ntype \"undeclared name: \u003c\u003e\". It will either insert a new statement,\nsuch as:\n\n\t\u003c\u003e :=\n\nor a new function declaration, such as:\n\n\tfunc \u003c\u003e(inferred parameters) {\n\t\tpanic(\"implement me!\")\n\t}", - "Default": "true" - }, { "Name": "\"unmarshal\"", "Doc": "report passing non-pointer or non-interface values to unmarshal\n\nThe unmarshal analysis reports calls to functions such as json.Unmarshal\nin which the argument type is not a pointer or an interface.", @@ -613,6 +608,16 @@ "Name": "\"useany\"", "Doc": "check for constraints that could be simplified to \"any\"", "Default": "false" + }, + { + "Name": "\"waitgroup\"", + "Doc": "check for misuses of sync.WaitGroup\n\nThis analyzer detects mistaken calls to the (*sync.WaitGroup).Add\nmethod from inside a new goroutine, causing Add to race with Wait:\n\n\t// WRONG\n\tvar wg sync.WaitGroup\n\tgo func() {\n\t wg.Add(1) // \"WaitGroup.Add called from inside new goroutine\"\n\t defer wg.Done()\n\t ...\n\t}()\n\twg.Wait() // (may return prematurely before new goroutine starts)\n\nThe correct code calls Add before starting the goroutine:\n\n\t// RIGHT\n\tvar wg sync.WaitGroup\n\twg.Add(1)\n\tgo func() {\n\t\tdefer wg.Done()\n\t\t...\n\t}()\n\twg.Wait()", + "Default": "true" + }, + { + "Name": "\"yield\"", + "Doc": "report calls to yield where the result is ignored\n\nAfter a yield function returns false, the caller should not call\nthe yield function again; generally the iterator should return\npromptly.\n\nThis example fails to check the result of the call to yield,\ncausing this analyzer to report a diagnostic:\n\n\tyield(1) // yield may be called 
again (on L2) after returning false\n\tyield(2)\n\nThe corrected code is either this:\n\n\tif yield(1) { yield(2) }\n\nor simply:\n\n\t_ = yield(1) \u0026\u0026 yield(2)\n\nIt is not always a mistake to ignore the result of yield.\nFor example, this is a valid single-element iterator:\n\n\tyield(1) // ok to ignore result\n\treturn\n\nIt is only a mistake when the yield call that returned false may be\nfollowed by another call.", + "Default": "true" } ] }, @@ -811,7 +816,7 @@ }, { "Name": "\"run_govulncheck\"", - "Doc": "`\"run_govulncheck\"`: Run govulncheck\n\nThis codelens source annotates the `module` directive in a\ngo.mod file with a command to run Govulncheck.\n\n[Govulncheck](https://go.dev/blog/vuln) is a static\nanalysis tool that computes the set of functions reachable\nwithin your application, including dependencies;\nqueries a database of known security vulnerabilities; and\nreports any potential problems it finds.\n", + "Doc": "`\"run_govulncheck\"`: Run govulncheck (legacy)\n\nThis codelens source annotates the `module` directive in a go.mod file\nwith a command to run Govulncheck asynchronously.\n\n[Govulncheck](https://go.dev/blog/vuln) is a static analysis tool that\ncomputes the set of functions reachable within your application, including\ndependencies; queries a database of known security vulnerabilities; and\nreports any potential problems it finds.\n", "Default": "false" }, { @@ -833,6 +838,11 @@ "Name": "\"vendor\"", "Doc": "`\"vendor\"`: Update vendor directory\n\nThis codelens source annotates the `module` directive in a\ngo.mod file with a command to run [`go mod\nvendor`](https://go.dev/ref/mod#go-mod-vendor), which\ncreates or updates the directory named `vendor` in the\nmodule root so that it contains an up-to-date copy of all\nnecessary package dependencies.\n", "Default": "true" + }, + { + "Name": "\"vulncheck\"", + "Doc": "`\"vulncheck\"`: Run govulncheck\n\nThis codelens source annotates the `module` directive in a go.mod file\nwith 
a command to run govulncheck synchronously.\n\n[Govulncheck](https://go.dev/blog/vuln) is a static analysis tool that\ncomputes the set of functions reachable within your application, including\ndependencies; queries a database of known security vulnerabilities; and\nreports any potential problems it finds.\n", + "Default": "false" } ] }, @@ -953,8 +963,8 @@ { "FileType": "go.mod", "Lens": "run_govulncheck", - "Title": "Run govulncheck", - "Doc": "\nThis codelens source annotates the `module` directive in a\ngo.mod file with a command to run Govulncheck.\n\n[Govulncheck](https://go.dev/blog/vuln) is a static\nanalysis tool that computes the set of functions reachable\nwithin your application, including dependencies;\nqueries a database of known security vulnerabilities; and\nreports any potential problems it finds.\n", + "Title": "Run govulncheck (legacy)", + "Doc": "\nThis codelens source annotates the `module` directive in a go.mod file\nwith a command to run Govulncheck asynchronously.\n\n[Govulncheck](https://go.dev/blog/vuln) is a static analysis tool that\ncomputes the set of functions reachable within your application, including\ndependencies; queries a database of known security vulnerabilities; and\nreports any potential problems it finds.\n", "Default": false }, { @@ -977,6 +987,13 @@ "Title": "Update vendor directory", "Doc": "\nThis codelens source annotates the `module` directive in a\ngo.mod file with a command to run [`go mod\nvendor`](https://go.dev/ref/mod#go-mod-vendor), which\ncreates or updates the directory named `vendor` in the\nmodule root so that it contains an up-to-date copy of all\nnecessary package dependencies.\n", "Default": true + }, + { + "FileType": "go.mod", + "Lens": "vulncheck", + "Title": "Run govulncheck", + "Doc": "\nThis codelens source annotates the `module` directive in a go.mod file\nwith a command to run govulncheck synchronously.\n\n[Govulncheck](https://go.dev/blog/vuln) is a static analysis tool that\ncomputes the set 
of functions reachable within your application, including\ndependencies; queries a database of known security vulnerabilities; and\nreports any potential problems it finds.\n", + "Default": false } ], "Analyzers": [ @@ -1238,12 +1255,6 @@ "URL": "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/timeformat", "Default": true }, - { - "Name": "undeclaredname", - "Doc": "suggested fixes for \"undeclared name: \u003c\u003e\"\n\nThis checker provides suggested fixes for type errors of the\ntype \"undeclared name: \u003c\u003e\". It will either insert a new statement,\nsuch as:\n\n\t\u003c\u003e :=\n\nor a new function declaration, such as:\n\n\tfunc \u003c\u003e(inferred parameters) {\n\t\tpanic(\"implement me!\")\n\t}", - "URL": "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/undeclaredname", - "Default": true - }, { "Name": "unmarshal", "Doc": "report passing non-pointer or non-interface values to unmarshal\n\nThe unmarshal analysis reports calls to functions such as json.Unmarshal\nin which the argument type is not a pointer or an interface.", @@ -1291,6 +1302,18 @@ "Doc": "check for constraints that could be simplified to \"any\"", "URL": "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/useany", "Default": false + }, + { + "Name": "waitgroup", + "Doc": "check for misuses of sync.WaitGroup\n\nThis analyzer detects mistaken calls to the (*sync.WaitGroup).Add\nmethod from inside a new goroutine, causing Add to race with Wait:\n\n\t// WRONG\n\tvar wg sync.WaitGroup\n\tgo func() {\n\t wg.Add(1) // \"WaitGroup.Add called from inside new goroutine\"\n\t defer wg.Done()\n\t ...\n\t}()\n\twg.Wait() // (may return prematurely before new goroutine starts)\n\nThe correct code calls Add before starting the goroutine:\n\n\t// RIGHT\n\tvar wg sync.WaitGroup\n\twg.Add(1)\n\tgo func() {\n\t\tdefer wg.Done()\n\t\t...\n\t}()\n\twg.Wait()", + "URL": "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/waitgroup", + "Default": true + }, + { + 
"Name": "yield", + "Doc": "report calls to yield where the result is ignored\n\nAfter a yield function returns false, the caller should not call\nthe yield function again; generally the iterator should return\npromptly.\n\nThis example fails to check the result of the call to yield,\ncausing this analyzer to report a diagnostic:\n\n\tyield(1) // yield may be called again (on L2) after returning false\n\tyield(2)\n\nThe corrected code is either this:\n\n\tif yield(1) { yield(2) }\n\nor simply:\n\n\t_ = yield(1) \u0026\u0026 yield(2)\n\nIt is not always a mistake to ignore the result of yield.\nFor example, this is a valid single-element iterator:\n\n\tyield(1) // ok to ignore result\n\treturn\n\nIt is only a mistake when the yield call that returned false may be\nfollowed by another call.", + "URL": "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/yield", + "Default": true } ], "Hints": [ diff --git a/gopls/internal/golang/addtest.go b/gopls/internal/golang/addtest.go index bf4dfed0acf..8228faf0fc8 100644 --- a/gopls/internal/golang/addtest.go +++ b/gopls/internal/golang/addtest.go @@ -12,112 +12,195 @@ import ( "errors" "fmt" "go/ast" + "go/format" "go/token" "go/types" - "html/template" "os" "path/filepath" + "sort" "strconv" "strings" + "text/template" + "unicode" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/gopls/internal/cache" + "golang.org/x/tools/gopls/internal/cache/metadata" "golang.org/x/tools/gopls/internal/cache/parsego" "golang.org/x/tools/gopls/internal/protocol" goplsastutil "golang.org/x/tools/gopls/internal/util/astutil" + "golang.org/x/tools/internal/imports" "golang.org/x/tools/internal/typesinternal" ) -const testTmplString = `func {{.TestFuncName}}(t *testing.T) { - {{- /* Functions/methods input parameters struct declaration. */}} - {{- if gt (len .Args) 1}} - type args struct { - {{- range .Args}} - {{.Name}} {{.Type}} - {{- end}} - } - {{- end}} - - {{- /* Test cases struct declaration and empty initialization. 
*/}} - tests := []struct { - name string // description of this test case - {{- if gt (len .Args) 1}} - args args - {{- end}} - {{- if eq (len .Args) 1}} - arg {{(index .Args 0).Type}} - {{- end}} - {{- range $index, $res := .Results}} - {{- if eq $res.Name "gotErr"}} - wantErr bool - {{- else if eq $index 0}} - want {{$res.Type}} - {{- else}} - want{{add $index 1}} {{$res.Type}} - {{- end}} - {{- end}} - }{ - // TODO: Add test cases. - } - - {{- /* Loop over all the test cases. */}} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - {{/* Got variables. */}} - {{- if .Results}}{{fieldNames .Results ""}} := {{end}} - - {{- /* Call expression. In xtest package test, call function by PACKAGE.FUNC. */}} - {{- /* TODO(hxjiang): consider any renaming in existing xtest package imports. E.g. import renamedfoo "foo". */}} - {{- /* TODO(hxjiang): support add test for methods by calling the right constructor. */}} - {{- if .PackageName}}{{.PackageName}}.{{end}}{{.FuncName}} - - {{- /* Input parameters. */ -}} - ({{- if eq (len .Args) 1}}tt.arg{{end}}{{if gt (len .Args) 1}}{{fieldNames .Args "tt.args."}}{{end}}) - - {{- /* Handles the returned error before the rest of return value. */}} - {{- $last := index .Results (add (len .Results) -1)}} - {{- if eq $last.Name "gotErr"}} - if gotErr != nil { - if !tt.wantErr { - t.Errorf("{{$.FuncName}}() failed: %v", gotErr) - } - return - } - if tt.wantErr { - t.Fatal("{{$.FuncName}}() succeeded unexpectedly") - } - {{- end}} - - {{- /* Compare the returned values except for the last returned error. */}} - {{- if or (and .Results (ne $last.Name "gotErr")) (and (gt (len .Results) 1) (eq $last.Name "gotErr"))}} - // TODO: update the condition below to compare got with tt.want. 
- {{- range $index, $res := .Results}} - {{- if ne $res.Name "gotErr"}} - if true { - t.Errorf("{{$.FuncName}}() = %v, want %v", {{.Name}}, tt.{{if eq $index 0}}want{{else}}want{{add $index 1}}{{end}}) - } - {{- end}} - {{- end}} - {{- end}} - }) - } +const testTmplString = ` +func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) { + {{- /* Test cases struct declaration and empty initialization. */}} + tests := []struct { + name string // description of this test case + + {{- $commentPrinted := false }} + {{- if and .Receiver .Receiver.Constructor}} + {{- range .Receiver.Constructor.Args}} + {{- if .Name}} + {{- if not $commentPrinted}} + // Named input parameters for receiver constructor. + {{- $commentPrinted = true }} + {{- end}} + {{.Name}} {{.Type}} + {{- end}} + {{- end}} + {{- end}} + + {{- $commentPrinted := false }} + {{- range .Func.Args}} + {{- if .Name}} + {{- if not $commentPrinted}} + // Named input parameters for target function. + {{- $commentPrinted = true }} + {{- end}} + {{.Name}} {{.Type}} + {{- end}} + {{- end}} + + {{- range $index, $res := .Func.Results}} + {{- if eq $res.Name "gotErr"}} + wantErr bool + {{- else if eq $index 0}} + want {{$res.Type}} + {{- else}} + want{{add $index 1}} {{$res.Type}} + {{- end}} + {{- end}} + }{ + // TODO: Add test cases. + } + + {{- /* Loop over all the test cases. */}} + for _, tt := range tests { + t.Run(tt.name, func(t *{{.TestingPackageName}}.T) { + {{- /* Constructor or empty initialization. */}} + {{- if .Receiver}} + {{- if .Receiver.Constructor}} + {{- /* Receiver variable by calling constructor. */}} + {{fieldNames .Receiver.Constructor.Results ""}} := {{if .PackageName}}{{.PackageName}}.{{end}} + {{- .Receiver.Constructor.Name}} + + {{- /* Constructor input parameters. */ -}} + ( + {{- range $index, $arg := .Receiver.Constructor.Args}} + {{- if ne $index 0}}, {{end}} + {{- if .Name}}tt.{{.Name}}{{else}}{{.Value}}{{end}} + {{- end -}} + ) + + {{- /* Handles the error return from constructor. 
*/}} + {{- $last := last .Receiver.Constructor.Results}} + {{- if eq $last.Type "error"}} + if err != nil { + t.Fatalf("could not construct receiver type: %v", err) + } + {{- end}} + {{- else}} + {{- /* Receiver variable declaration. */}} + // TODO: construct the receiver type. + var {{.Receiver.Var.Name}} {{.Receiver.Var.Type}} + {{- end}} + {{- end}} + + {{- /* Got variables. */}} + {{if .Func.Results}}{{fieldNames .Func.Results ""}} := {{end}} + + {{- /* Call expression. */}} + {{- if .Receiver}}{{/* Call method by VAR.METHOD. */}} + {{- .Receiver.Var.Name}}. + {{- else if .PackageName}}{{/* Call function by PACKAGE.FUNC. */}} + {{- .PackageName}}. + {{- end}}{{.Func.Name}} + + {{- /* Input parameters. */ -}} + ( + {{- range $index, $arg := .Func.Args}} + {{- if ne $index 0}}, {{end}} + {{- if .Name}}tt.{{.Name}}{{else}}{{.Value}}{{end}} + {{- end -}} + ) + + {{- /* Handles the returned error before the rest of return value. */}} + {{- $last := last .Func.Results}} + {{- if eq $last.Type "error"}} + if gotErr != nil { + if !tt.wantErr { + t.Errorf("{{$.Func.Name}}() failed: %v", gotErr) + } + return + } + if tt.wantErr { + t.Fatal("{{$.Func.Name}}() succeeded unexpectedly") + } + {{- end}} + + {{- /* Compare the returned values except for the last returned error. */}} + {{- if or (and .Func.Results (ne $last.Type "error")) (and (gt (len .Func.Results) 1) (eq $last.Type "error"))}} + // TODO: update the condition below to compare got with tt.want. + {{- range $index, $res := .Func.Results}} + {{- if ne $res.Name "gotErr"}} + if true { + t.Errorf("{{$.Func.Name}}() = %v, want %v", {{.Name}}, tt.{{if eq $index 0}}want{{else}}want{{add $index 1}}{{end}}) + } + {{- end}} + {{- end}} + {{- end}} + }) + } } ` +// Name is the name of the field this input parameter should reference. +// Value is the expression this input parameter should accept. +// +// Exactly one of Name or Value must be set. 
type field struct {
-	Name, Type string
+	Name, Type, Value string
+}
+
+type function struct {
+	Name    string
+	Args    []field
+	Results []field
+}
+
+type receiver struct {
+	// Var is the name and type of the receiver variable.
+	Var field
+	// Constructor holds information about the constructor for the receiver type.
+	// If no qualified constructor is found, this field will be nil.
+	Constructor *function
 }
 
 type testInfo struct {
+	// TestingPackageName is the package name that should be used when referencing
+	// package "testing".
+	TestingPackageName string
+	// PackageName is the package name the target function/method is declared in.
 	PackageName  string
-	FuncName     string
 	TestFuncName string
-	Args         []field
-	Results      []field
+	// Func holds information about the function or method being tested.
+	Func function
+	// Receiver holds information about the receiver of the function or method
+	// being tested.
+	// This field is nil for functions and non-nil for methods.
+	Receiver *receiver
 }
 
 var testTmpl = template.Must(template.New("test").Funcs(template.FuncMap{
 	"add": func(a, b int) int { return a + b },
+	"last": func(slice []field) field {
+		if len(slice) == 0 {
+			return field{}
+		}
+		return slice[len(slice)-1]
+	},
 	"fieldNames": func(fields []field, qualifier string) (res string) {
 		var names []string
 		for _, f := range fields {
@@ -135,6 +218,10 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
 		return nil, err
 	}
 
+	if metadata.IsCommandLineArguments(pkg.Metadata().ID) {
+		return nil, fmt.Errorf("current file in command-line-arguments package")
+	}
+
 	if errors := pkg.ParseErrors(); len(errors) > 0 {
 		return nil, fmt.Errorf("package has parse errors: %v", errors[0])
 	}
@@ -142,35 +229,44 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
 		return nil, fmt.Errorf("package has type errors: %v", errors[0])
 	}
 
-	// imports is a map from package path to local package name.
- var imports = make(map[string]string) + // All three maps map the path of an imported package to + // the local name if explicit or "" otherwise. + var ( + fileImports map[string]string // imports in foo.go file + testImports map[string]string // imports in foo_test.go file + extraImports = make(map[string]string) // imports to add to test file + ) - var collectImports = func(file *ast.File) error { + var collectImports = func(file *ast.File) (map[string]string, error) { + imps := make(map[string]string) for _, spec := range file.Imports { // TODO(hxjiang): support dot imports. if spec.Name != nil && spec.Name.Name == "." { - return fmt.Errorf("\"add a test for FUNC\" does not support files containing dot imports") + return nil, fmt.Errorf("\"add test for func\" does not support files containing dot imports") } path, err := strconv.Unquote(spec.Path.Value) if err != nil { - return err + return nil, err } - if spec.Name != nil && spec.Name.Name != "_" { - imports[path] = spec.Name.Name + if spec.Name != nil { + if spec.Name.Name == "_" { + continue + } + imps[path] = spec.Name.Name } else { - imports[path] = filepath.Base(path) + imps[path] = "" } } - return nil + return imps, nil } // Collect all the imports from the x.go, keep track of the local package name. - if err := collectImports(pgf.File); err != nil { + if fileImports, err = collectImports(pgf.File); err != nil { return nil, err } testBase := strings.TrimSuffix(filepath.Base(loc.URI.Path()), ".go") + "_test.go" - goTestFileURI := protocol.URIFromPath(filepath.Join(loc.URI.Dir().Path(), testBase)) + goTestFileURI := protocol.URIFromPath(filepath.Join(loc.URI.DirPath(), testBase)) testFH, err := snapshot.ReadFile(ctx, goTestFileURI) if err != nil { @@ -185,14 +281,32 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol. // edits contains all the text edits to be applied to the test file. 
edits []protocol.TextEdit // xtest indicates whether the test file use package x or x_test. - // TODO(hxjiang): For now, we try to interpret the user's intention by - // reading the foo_test.go's package name. Instead, we can discuss the option - // to interpret the user's intention by which function they are selecting. - // Have one file for x_test package testing, one file for x package testing. + // TODO(hxjiang): We can discuss the option to interpret the user's + // intention by which function they are selecting. Have one file for + // x_test package testing, one file for x package testing. xtest = true ) - if testPGF, err := snapshot.ParseGo(ctx, testFH, parsego.Header); err != nil { + start, end, err := pgf.RangePos(loc.Range) + if err != nil { + return nil, err + } + + path, _ := astutil.PathEnclosingInterval(pgf.File, start, end) + if len(path) < 2 { + return nil, fmt.Errorf("no enclosing function") + } + + decl, ok := path[len(path)-2].(*ast.FuncDecl) + if !ok { + return nil, fmt.Errorf("no enclosing function") + } + + fn := pkg.TypesInfo().Defs[decl.Name].(*types.Func) + sig := fn.Signature() + + testPGF, err := snapshot.ParseGo(ctx, testFH, parsego.Header) + if err != nil { if !errors.Is(err, os.ErrNotExist) { return nil, err } @@ -205,24 +319,75 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol. // package decl based on the originating file. // Search for something that looks like a copyright header, to replicate // in the new file. - if groups := pgf.File.Comments; len(groups) > 0 { - // Copyright should appear before package decl and must be the first - // comment group. - // Avoid copying any other comment like package doc or directive comment. 
-		if c := groups[0]; c.Pos() < pgf.File.Package && c != pgf.File.Doc &&
-			!isDirective(c.List[0].Text) &&
-			strings.Contains(strings.ToLower(c.List[0].Text), "copyright") {
-			start, end, err := pgf.NodeOffsets(c)
-			if err != nil {
-				return nil, err
+	if c := copyrightComment(pgf.File); c != nil {
+		start, end, err := pgf.NodeOffsets(c)
+		if err != nil {
+			return nil, err
+		}
+		header.Write(pgf.Src[start:end])
+		// One empty line between copyright header and following.
+		header.WriteString("\n\n")
+	}
+
+	// If this test file was created by gopls, add build constraints
+	// matching the non-test file.
+	if c := buildConstraintComment(pgf.File); c != nil {
+		start, end, err := pgf.NodeOffsets(c)
+		if err != nil {
+			return nil, err
+		}
+		header.Write(pgf.Src[start:end])
+		// One empty line between build constraint and following.
+		header.WriteString("\n\n")
+	}
+
+	// Determine if a new test file should use in-package test (package x)
+	// or external test (package x_test). If any of the function parameters
+	// reference an unexported object, we cannot write out test cases from
+	// an x_test package.
+	externalTestOK := func() bool {
+		if !fn.Exported() {
+			return false
+		}
+		if fn.Signature().Recv() != nil {
+			if _, ident, _ := goplsastutil.UnpackRecv(decl.Recv.List[0].Type); ident == nil || !ident.IsExported() {
+				return false
 			}
-			header.Write(pgf.Src[start:end])
-			// One empty line between copyright header and package decl.
-			header.WriteString("\n\n")
 		}
+		refsUnexported := false
+		ast.Inspect(decl, func(n ast.Node) bool {
+			// The original function refers to an unexported object from the
+			// same package, so further inspection is unnecessary.
+			if refsUnexported {
+				return false
+			}
+			switch t := n.(type) {
+			case *ast.BlockStmt:
+				// Avoid inspecting the function body.
+				return false
+			case *ast.Ident:
+				// Use test variant (package foo) if the function signature
+				// references any unexported objects (like types or
+				// constants) from the same package.
+ // Note: types.PkgName is excluded from this check as it's + // always defined in the same package. + if obj, ok := pkg.TypesInfo().Uses[t]; ok && !obj.Exported() && obj.Pkg() == pkg.Types() && !is[*types.PkgName](obj) { + refsUnexported = true + } + return false + default: + return true + } + }) + return !refsUnexported + } + + xtest = externalTestOK() + if xtest { + fmt.Fprintf(&header, "package %s_test\n", pkg.Types().Name()) + } else { + fmt.Fprintf(&header, "package %s\n", pkg.Types().Name()) } - // One empty line between package decl and rest of the file. - fmt.Fprintf(&header, "package %s_test\n\n", pkg.Types().Name()) // Write the copyright and package decl to the beginning of the file. edits = append(edits, protocol.TextEdit{ @@ -247,47 +412,43 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol. return nil, err } - // Collect all the imports from the x_test.go, overwrite the local pakcage - // name collected from x.go. - if err := collectImports(testPGF.File); err != nil { + // Collect all the imports from the foo_test.go. + if testImports, err = collectImports(testPGF.File); err != nil { return nil, err } } - // qf qualifier returns the local package name need to use in x_test.go by - // consulting the consolidated imports map. + // qf qualifier determines the correct package name to use for a type in + // foo_test.go. It does this by: + // - Consult imports map from test file foo_test.go. + // - If not found, consult imports map from original file foo.go. + // If the package is not imported in test file foo_test.go, it is added to + // extraImports map. qf := func(p *types.Package) string { - // When generating test in x packages, any type/function defined in the same - // x package can emit package name. + // References from an in-package test should not be qualified. 
if !xtest && p == pkg.Types() { return "" } - if local, ok := imports[p.Path()]; ok { + // Prefer using the package name if already defined in foo_test.go + if local, ok := testImports[p.Path()]; ok { + if local != "" { + return local + } else { + return p.Name() + } + } + // TODO(hxjiang): we should consult the scope of the test package to + // ensure these new imports do not shadow any package-level names. + // Prefer the local import name (if any) used in the package under test. + if local, ok := fileImports[p.Path()]; ok && local != "" { + extraImports[p.Path()] = local return local } + // Fall back to the package name since there is no renaming. + extraImports[p.Path()] = "" return p.Name() } - // TODO(hxjiang): modify existing imports or add new imports. - - start, end, err := pgf.RangePos(loc.Range) - if err != nil { - return nil, err - } - - path, _ := astutil.PathEnclosingInterval(pgf.File, start, end) - if len(path) < 2 { - return nil, fmt.Errorf("no enclosing function") - } - - decl, ok := path[len(path)-2].(*ast.FuncDecl) - if !ok { - return nil, fmt.Errorf("no enclosing function") - } - - fn := pkg.TypesInfo().Defs[decl.Name].(*types.Func) - sig := fn.Signature() - if xtest { // Reject if function/method is unexported. if !fn.Exported() { @@ -296,7 +457,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol. // Reject if receiver is unexported. if sig.Recv() != nil { - if _, ident, _ := goplsastutil.UnpackRecv(decl.Recv.List[0].Type); !ident.IsExported() { + if _, ident, _ := goplsastutil.UnpackRecv(decl.Recv.List[0].Type); ident == nil || !ident.IsExported() { return nil, fmt.Errorf("cannot add external test for method %s.%s as receiver type is not exported", ident.Name, decl.Name) } } @@ -309,40 +470,281 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol. 
if err != nil { return nil, err } + data := testInfo{ - FuncName: fn.Name(), - TestFuncName: testName, + TestingPackageName: qf(types.NewPackage("testing", "testing")), + PackageName: qf(pkg.Types()), + TestFuncName: testName, + Func: function{ + Name: fn.Name(), + }, } - if sig.Recv() == nil && xtest { - data.PackageName = qf(pkg.Types()) + errorType := types.Universe.Lookup("error").Type() + + var isContextType = func(t types.Type) bool { + named, ok := t.(*types.Named) + if !ok { + return false + } + return named.Obj().Pkg().Path() == "context" && named.Obj().Name() == "Context" } for i := range sig.Params().Len() { - if i == 0 { - data.Args = append(data.Args, field{ - Name: "in", - Type: types.TypeString(sig.Params().At(i).Type(), qf), - }) + param := sig.Params().At(i) + name, typ := param.Name(), param.Type() + f := field{Type: types.TypeString(typ, qf)} + if i == 0 && isContextType(typ) { + f.Value = qf(types.NewPackage("context", "context")) + ".Background()" + } else if name == "" || name == "_" { + f.Value = typesinternal.ZeroString(typ, qf) } else { - data.Args = append(data.Args, field{ - Name: fmt.Sprintf("in%d", i+1), - Type: types.TypeString(sig.Params().At(i).Type(), qf), - }) + f.Name = name } + data.Func.Args = append(data.Func.Args, f) } - errorType := types.Universe.Lookup("error").Type() for i := range sig.Results().Len() { - name := "got" - if i == sig.Results().Len()-1 && types.Identical(sig.Results().At(i).Type(), errorType) { + typ := sig.Results().At(i).Type() + var name string + if i == sig.Results().Len()-1 && types.Identical(typ, errorType) { name = "gotErr" - } else if i > 0 { + } else if i == 0 { + name = "got" + } else { name = fmt.Sprintf("got%d", i+1) } - data.Results = append(data.Results, field{ + data.Func.Results = append(data.Func.Results, field{ Name: name, - Type: types.TypeString(sig.Results().At(i).Type(), qf), + Type: types.TypeString(typ, qf), + }) + } + + if sig.Recv() != nil { + // Find the preferred type for the 
receiver. We don't use
+		// typesinternal.ReceiverNamed here as we want to preserve aliases.
+		recvType := sig.Recv().Type()
+		if ptr, ok := recvType.(*types.Pointer); ok {
+			recvType = ptr.Elem()
+		}
+
+		t, ok := recvType.(typesinternal.NamedOrAlias)
+		if !ok {
+			return nil, fmt.Errorf("the receiver type is neither named type nor alias type")
+		}
+
+		var varName string
+		{
+			var possibleNames []string // list of candidates, preferring earlier entries.
+			if len(sig.Recv().Name()) > 0 {
+				possibleNames = append(possibleNames,
+					sig.Recv().Name(),            // receiver name.
+					string(sig.Recv().Name()[0]), // first character of receiver name.
+				)
+			}
+			possibleNames = append(possibleNames,
+				string(t.Obj().Name()[0]), // first character of receiver type name.
+			)
+			if len(t.Obj().Name()) >= 2 {
+				possibleNames = append(possibleNames,
+					string(t.Obj().Name()[:2]), // first two characters of receiver type name.
+				)
+			}
+			var camelCase []rune
+			for i, s := range t.Obj().Name() {
+				if i == 0 || unicode.IsUpper(s) {
+					camelCase = append(camelCase, s)
+				}
+			}
+			possibleNames = append(possibleNames,
+				string(camelCase), // capitalized initials.
+			)
+			for _, name := range possibleNames {
+				name = strings.ToLower(name)
+				if name == "" || name == "t" || name == "tt" {
+					continue
+				}
+				varName = name
+				break
+			}
+			if varName == "" {
+				varName = "r" // default as "r" for "receiver".
+			}
+		}
+
+		data.Receiver = &receiver{
+			Var: field{
+				Name: varName,
+				Type: types.TypeString(recvType, qf),
+			},
+		}
+
+		// constructor is the selected constructor for type T.
+		var constructor *types.Func
+
+		// When finding the qualified constructor, the function should return
+ _, wantType := typesinternal.ReceiverNamed(sig.Recv()) + for _, name := range pkg.Types().Scope().Names() { + f, ok := pkg.Types().Scope().Lookup(name).(*types.Func) + if !ok { + continue + } + if f.Signature().Recv() != nil { + continue + } + // Unexported constructor is not visible in x_test package. + if xtest && !f.Exported() { + continue + } + // Only allow constructors returning T, T, (T, error), or (T, error). + if f.Signature().Results().Len() > 2 || f.Signature().Results().Len() == 0 { + continue + } + + _, gotType := typesinternal.ReceiverNamed(f.Signature().Results().At(0)) + if gotType == nil || !types.Identical(gotType, wantType) { + continue + } + + if f.Signature().Results().Len() == 2 && !types.Identical(f.Signature().Results().At(1).Type(), errorType) { + continue + } + + if constructor == nil { + constructor = f + } + + // Functions named NewType are prioritized as constructors over other + // functions that match only the signature criteria. + if strings.EqualFold(strings.ToLower(f.Name()), strings.ToLower("new"+t.Obj().Name())) { + constructor = f + } + } + + if constructor != nil { + data.Receiver.Constructor = &function{Name: constructor.Name()} + for i := range constructor.Signature().Params().Len() { + param := constructor.Signature().Params().At(i) + name, typ := param.Name(), param.Type() + f := field{Type: types.TypeString(typ, qf)} + if i == 0 && isContextType(typ) { + f.Value = qf(types.NewPackage("context", "context")) + ".Background()" + } else if name == "" || name == "_" { + f.Value = typesinternal.ZeroString(typ, qf) + } else { + f.Name = name + } + data.Receiver.Constructor.Args = append(data.Receiver.Constructor.Args, f) + } + for i := range constructor.Signature().Results().Len() { + typ := constructor.Signature().Results().At(i).Type() + var name string + if i == 0 { + // The first return value must be of type T, *T, or a type whose named + // type is the same as named type of T. 
+ name = varName + } else if i == constructor.Signature().Results().Len()-1 && types.Identical(typ, errorType) { + name = "err" + } else { + // Drop any return values beyond the first and the last. + // e.g., "f, _, _, err := NewFoo()". + name = "_" + } + data.Receiver.Constructor.Results = append(data.Receiver.Constructor.Results, field{ + Name: name, + Type: types.TypeString(typ, qf), + }) + } + } + } + + // Resolves duplicate parameter names between the function and its + // receiver's constructor. It adds prefix to the constructor's parameters + // until no conflicts remain. + if data.Receiver != nil && data.Receiver.Constructor != nil { + seen := map[string]bool{} + for _, f := range data.Func.Args { + if f.Name == "" { + continue + } + seen[f.Name] = true + } + + // "" for no change, "c" for constructor, "i" for input. + for _, prefix := range []string{"", "c", "c_", "i", "i_"} { + conflict := false + for _, f := range data.Receiver.Constructor.Args { + if f.Name == "" { + continue + } + if seen[prefix+f.Name] { + conflict = true + break + } + } + if !conflict { + for i, f := range data.Receiver.Constructor.Args { + if f.Name == "" { + continue + } + data.Receiver.Constructor.Args[i].Name = prefix + data.Receiver.Constructor.Args[i].Name + } + break + } + } + } + + // Compute edits to update imports. + // + // If we're adding to an existing test file, we need to adjust existing + // imports. Otherwise, we can simply write out the imports to the new file. + if testPGF != nil { + var importFixes []*imports.ImportFix + for path, name := range extraImports { + importFixes = append(importFixes, &imports.ImportFix{ + StmtInfo: imports.ImportInfo{ + ImportPath: path, + Name: name, + }, + FixType: imports.AddImport, + }) + } + importEdits, err := ComputeImportFixEdits(snapshot.Options().Local, testPGF.Src, importFixes...) + if err != nil { + return nil, fmt.Errorf("could not compute the import fix edits: %w", err) + } + edits = append(edits, importEdits...) 
+ } else { + var importsBuffer bytes.Buffer + if len(extraImports) == 1 { + importsBuffer.WriteString("\nimport ") + for path, name := range extraImports { + if name != "" { + importsBuffer.WriteString(name + " ") + } + importsBuffer.WriteString(fmt.Sprintf("\"%s\"\n", path)) + } + } else { + importsBuffer.WriteString("\nimport(") + // Loop over the map in sorted order ensures deterministic outcome. + paths := make([]string, 0, len(extraImports)) + for key := range extraImports { + paths = append(paths, key) + } + sort.Strings(paths) + for _, path := range paths { + importsBuffer.WriteString("\n\t") + if name := extraImports[path]; name != "" { + importsBuffer.WriteString(name + " ") + } + importsBuffer.WriteString(fmt.Sprintf("\"%s\"", path)) + } + importsBuffer.WriteString("\n)\n") + } + edits = append(edits, protocol.TextEdit{ + Range: protocol.Range{}, + NewText: importsBuffer.String(), }) } @@ -351,10 +753,16 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol. return nil, err } - edits = append(edits, protocol.TextEdit{ - Range: eofRange, - NewText: test.String(), - }) + formatted, err := format.Source(test.Bytes()) + if err != nil { + return nil, err + } + + edits = append(edits, + protocol.TextEdit{ + Range: eofRange, + NewText: string(formatted), + }) return append(changes, protocol.DocumentChangeEdit(testFH, edits)), nil } diff --git a/gopls/internal/golang/assembly.go b/gopls/internal/golang/assembly.go index 63d8e82d9fd..7f0ace4daf6 100644 --- a/gopls/internal/golang/assembly.go +++ b/gopls/internal/golang/assembly.go @@ -16,13 +16,11 @@ import ( "context" "fmt" "html" - "path/filepath" "regexp" "strconv" "strings" "golang.org/x/tools/gopls/internal/cache" - "golang.org/x/tools/internal/gocommand" ) // AssemblyHTML returns an HTML document containing an assembly listing of the selected function. @@ -32,11 +30,7 @@ import ( // - cross-link jumps and block labels, like github.com/aclements/objbrowse. 
func AssemblyHTML(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, symbol string, web Web) ([]byte, error) { // Compile the package with -S, and capture its stderr stream. - inv, cleanupInvocation, err := snapshot.GoCommandInvocation(false, &gocommand.Invocation{ - Verb: "build", - Args: []string{"-gcflags=-S", "."}, - WorkingDir: filepath.Dir(pkg.Metadata().CompiledGoFiles[0].Path()), - }) + inv, cleanupInvocation, err := snapshot.GoCommandInvocation(cache.NoNetwork, pkg.Metadata().CompiledGoFiles[0].DirPath(), "build", []string{"-gcflags=-S", "."}) if err != nil { return nil, err // e.g. failed to write overlays (rare) } diff --git a/gopls/internal/golang/change_signature.go b/gopls/internal/golang/change_signature.go index 41c56ba6c2c..8157c6d03fb 100644 --- a/gopls/internal/golang/change_signature.go +++ b/gopls/internal/golang/change_signature.go @@ -20,6 +20,7 @@ import ( "golang.org/x/tools/gopls/internal/cache/parsego" "golang.org/x/tools/gopls/internal/file" "golang.org/x/tools/gopls/internal/protocol" + goplsastutil "golang.org/x/tools/gopls/internal/util/astutil" "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/gopls/internal/util/safetoken" "golang.org/x/tools/imports" @@ -30,21 +31,121 @@ import ( "golang.org/x/tools/internal/typesinternal" ) -// RemoveUnusedParameter computes a refactoring to remove the parameter -// indicated by the given range, which must be contained within an unused -// parameter name or field. +// Changing a signature works as follows, supposing we have the following +// original function declaration: // -// This operation is a work in progress. Remaining TODO: -// - Handle function assignment correctly. -// - Improve the extra newlines in output. -// - Stream type checking via ForEachPackage. -// - Avoid unnecessary additional type checking. 
-func RemoveUnusedParameter(ctx context.Context, fh file.Handle, rng protocol.Range, snapshot *cache.Snapshot) ([]protocol.DocumentChange, error) { +// func Foo(a, b, c int) +// +// Step 1: Write the declaration according to the given signature change. For +// example, given the parameter transformation [2, 0, 1], we construct a new +// ast.FuncDecl for the signature: +// +// func Foo0(c, a, b int) +// +// Step 2: Build a wrapper function that delegates to the new function. +// With this example, the wrapper would look like this: +// +// func Foo1(a, b, c int) { +// Foo0(c, a, b int) +// } +// +// Step 3: Swap in the wrapper for the original, and inline all calls. The +// trick here is to rename Foo1 to Foo, inline all calls (replacing them with +// a call to Foo0), and then rename Foo0 back to Foo, using a simple string +// replacement. +// +// For example, given a call +// +// func _() { +// Foo(1, 2, 3) +// } +// +// The inlining results in +// +// func _() { +// Foo0(3, 1, 2) +// } +// +// And then renaming results in +// +// func _() { +// Foo(3, 1, 2) +// } +// +// And the desired signature rewriting has occurred! Note: in practice, we +// don't use the names Foo0 and Foo1, as they are too likely to conflict with +// an existing declaration name. (Instead, we use the prefix G_o_ + p_l_s) +// +// The advantage of going through the inliner is that we get all of the +// semantic considerations for free: the inliner will check for side effects +// of arguments, check if the last use of a variable is being removed, check +// for unnecessary imports, etc. +// +// Furthermore, by running the change signature rewriting through the inliner, +// we ensure that the inliner gets better to the point that it can handle a +// change signature rewrite just as well as if we had implemented change +// signature as its own operation. For example, suppose we support reordering +// the results of a function. 
In that case, the wrapper would be: +// +// func Foo1() (int, int) { +// y, x := Foo0() +// return x, y +// } +// +// And a call would be rewritten from +// +// x, y := Foo() +// +// To +// +// r1, r2 := Foo() +// x, y := r2, r1 +// +// In order to make this idiomatic, we'd have to teach the inliner to rewrite +// this as y, x := Foo(). The simplest and most general way to achieve this is +// to teach the inliner to recognize when a variable is redundant (r1 and r2, +// in this case), lifting declarations. That's probably a very useful skill for +// the inliner to have. + +// removeParam computes a refactoring to remove the parameter indicated by the +// given range. +func removeParam(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, rng protocol.Range) ([]protocol.DocumentChange, error) { pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, fh.URI()) if err != nil { return nil, err } + // Find the unused parameter to remove. + info := findParam(pgf, rng) + if info == nil || info.paramIndex == -1 { + return nil, fmt.Errorf("no param found") + } + // Write a transformation to remove the param. + var newParams []int + for i := 0; i < info.decl.Type.Params.NumFields(); i++ { + if i != info.paramIndex { + newParams = append(newParams, i) + } + } + return ChangeSignature(ctx, snapshot, pkg, pgf, rng, newParams) +} +// ChangeSignature computes a refactoring to update the signature according to +// the provided parameter transformation, for the signature definition +// surrounding rng. +// +// newParams expresses the new parameters for the signature in terms of the old +// parameters. Each entry in newParams is the index of the new parameter in the +// original parameter list. For example, given func Foo(a, b, c int) and newParams +// [2, 0, 1], the resulting changed signature is Foo(c, a, b int). If newParams +// omits an index of the original signature, that parameter is removed. +// +// This operation is a work in progress. 
Remaining TODO: +// - Handle adding parameters. +// - Handle adding/removing/reordering results. +// - Improve the extra newlines in output. +// - Stream type checking via ForEachPackage. +// - Avoid unnecessary additional type checking. +func ChangeSignature(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, rng protocol.Range, newParams []int) ([]protocol.DocumentChange, error) { // Changes to our heuristics for whether we can remove a parameter must also // be reflected in the canRemoveParameter helper. if perrors, terrors := pkg.ParseErrors(), pkg.TypeErrors(); len(perrors) > 0 || len(terrors) > 0 { @@ -57,69 +158,135 @@ func RemoveUnusedParameter(ctx context.Context, fh file.Handle, rng protocol.Ran return nil, fmt.Errorf("can't change signatures for packages with parse or type errors: (e.g. %s)", sample) } - info, err := findParam(pgf, rng) - if err != nil { - return nil, err // e.g. invalid range + info := findParam(pgf, rng) + if info == nil || info.decl == nil { + return nil, fmt.Errorf("failed to find declaration") } - if info.field == nil { - return nil, fmt.Errorf("failed to find field") + + // Step 1: create the new declaration, which is a copy of the original decl + // with the rewritten signature. + + // Flatten, transform and regroup fields, using the flatField intermediate + // representation. A flatField is the result of flattening an *ast.FieldList + // along with type information. + type flatField struct { + name string // empty if the field is unnamed + typeExpr ast.Expr + typ types.Type } - // Create the new declaration, which is a copy of the original decl with the - // unnecessary parameter removed. 
- newDecl := internalastutil.CloneNode(info.decl) - if info.name != nil { - names := remove(newDecl.Type.Params.List[info.fieldIndex].Names, info.nameIndex) - newDecl.Type.Params.List[info.fieldIndex].Names = names + var newParamFields []flatField + for id, field := range goplsastutil.FlatFields(info.decl.Type.Params) { + typ := pkg.TypesInfo().TypeOf(field.Type) + if typ == nil { + return nil, fmt.Errorf("missing field type for field #%d", len(newParamFields)) + } + field := flatField{ + typeExpr: field.Type, + typ: typ, + } + if id != nil { + field.name = id.Name + } + newParamFields = append(newParamFields, field) + } + + // Select the new parameter fields. + newParamFields, ok := selectElements(newParamFields, newParams) + if !ok { + return nil, fmt.Errorf("failed to apply parameter transformation %v", newParams) } - if len(newDecl.Type.Params.List[info.fieldIndex].Names) == 0 { - // Unnamed, or final name was removed: in either case, remove the field. - newDecl.Type.Params.List = remove(newDecl.Type.Params.List, info.fieldIndex) + + // writeFields performs the regrouping of named fields. + writeFields := func(flatFields []flatField) *ast.FieldList { + list := new(ast.FieldList) + for i, f := range flatFields { + var field *ast.Field + if i > 0 && f.name != "" && flatFields[i-1].name != "" && types.Identical(f.typ, flatFields[i-1].typ) { + // Group named fields if they have the same type. + field = list.List[len(list.List)-1] + } else { + // Otherwise, create a new field. + field = &ast.Field{ + Type: internalastutil.CloneNode(f.typeExpr), + } + list.List = append(list.List, field) + } + if f.name != "" { + field.Names = append(field.Names, ast.NewIdent(f.name)) + } + } + return list } - // Compute inputs into building a wrapper function around the modified - // signature. + newDecl := internalastutil.CloneNode(info.decl) + newDecl.Type.Params = writeFields(newParamFields) + + // Step 2: build a wrapper function calling the new declaration. 
+ var ( - params = internalastutil.CloneNode(info.decl.Type.Params) // "_" names will be modified - args []ast.Expr // arguments to delegate + params = internalastutil.CloneNode(info.decl.Type.Params) // parameters of wrapper func: "_" names must be modified + args = make([]ast.Expr, len(newParams)) // arguments to the delegated call variadic = false // whether the signature is variadic ) { - allNames := make(map[string]bool) // for renaming blanks + // Record names used by non-blank parameters, just in case the user had a + // parameter named 'blank0', which would conflict with the synthetic names + // we construct below. + // TODO(rfindley): add an integration test for this behavior. + nonBlankNames := make(map[string]bool) // for detecting conflicts with renamed blanks for _, fld := range params.List { for _, n := range fld.Names { if n.Name != "_" { - allNames[n.Name] = true + nonBlankNames[n.Name] = true } } + if len(fld.Names) == 0 { + // All parameters must have a non-blank name. For convenience, give + // this field a blank name. + fld.Names = append(fld.Names, ast.NewIdent("_")) // will be named below + } + } + // oldParams maps parameters to their argument in the delegated call. + // In other words, it is the inverse of newParams, but it is represented as + // a map rather than a slice, as not every old param need exist in + // newParams. + oldParams := make(map[int]int) + for new, old := range newParams { + oldParams[old] = new } blanks := 0 - for i, fld := range params.List { - for j, n := range fld.Names { - if i == info.fieldIndex && j == info.nameIndex { - continue - } - if n.Name == "_" { - // Create names for blank (_) parameters so the delegating wrapper - // can refer to them. - for { - newName := fmt.Sprintf("blank%d", blanks) - blanks++ - if !allNames[newName] { - n.Name = newName - break - } + paramIndex := 0 // global param index. 
+ for id, field := range goplsastutil.FlatFields(params) { + argIndex, ok := oldParams[paramIndex] + paramIndex++ + if !ok { + continue // parameter is removed + } + if id.Name == "_" { // from above: every field has names + // Create names for blank (_) parameters so the delegating wrapper + // can refer to them. + for { + // These names will not be seen by the user, so give them an + // arbitrary name. + newName := fmt.Sprintf("blank%d", blanks) + blanks++ + if !nonBlankNames[newName] { + id.Name = newName + break } } - args = append(args, &ast.Ident{Name: n.Name}) - if i == len(params.List)-1 { - _, variadic = fld.Type.(*ast.Ellipsis) - } } + args[argIndex] = ast.NewIdent(id.Name) + // Record whether the call has an ellipsis. + // (Only the last loop iteration matters.) + _, variadic = field.Type.(*ast.Ellipsis) } } - // Rewrite all referring calls. + // Step 3: Rewrite all referring calls, by swapping in the wrapper and + // inlining all. + newContent, err := rewriteCalls(ctx, signatureRewrite{ snapshot: snapshot, pkg: pkg, @@ -239,24 +406,23 @@ func rewriteSignature(fset *token.FileSet, declIdx int, src0 []byte, newDecl *as // paramInfo records information about a param identified by a position. type paramInfo struct { decl *ast.FuncDecl // enclosing func decl (non-nil) - fieldIndex int // index of Field in Decl.Type.Params, or -1 + paramIndex int // index of param among all params, or -1 field *ast.Field // enclosing field of Decl, or nil if range not among parameters - nameIndex int // index of Name in Field.Names, or nil name *ast.Ident // indicated name (either enclosing, or Field.Names[0] if len(Field.Names) == 1) } // findParam finds the parameter information spanned by the given range. 
-func findParam(pgf *parsego.File, rng protocol.Range) (*paramInfo, error) { +func findParam(pgf *parsego.File, rng protocol.Range) *paramInfo { + info := paramInfo{paramIndex: -1} start, end, err := pgf.RangePos(rng) if err != nil { - return nil, err + return nil } path, _ := astutil.PathEnclosingInterval(pgf.File, start, end) var ( id *ast.Ident field *ast.Field - decl *ast.FuncDecl ) // Find the outermost enclosing node of each kind, whether or not they match // the semantics described in the docstring. @@ -267,37 +433,45 @@ func findParam(pgf *parsego.File, rng protocol.Range) (*paramInfo, error) { case *ast.Field: field = n case *ast.FuncDecl: - decl = n + info.decl = n } } - // Check the conditions described in the docstring. - if decl == nil { - return nil, fmt.Errorf("range is not within a function declaration") + if info.decl == nil { + return nil } - info := ¶mInfo{ - fieldIndex: -1, - nameIndex: -1, - decl: decl, + if field == nil { + return &info } - for fi, f := range decl.Type.Params.List { + pi := 0 + // Search for field and id among parameters of decl. + // This search may fail, even if one or both of id and field are non nil: + // field could be from a result or local declaration, and id could be part of + // the field type rather than names. + for _, f := range info.decl.Type.Params.List { if f == field { - info.fieldIndex = fi + info.paramIndex = pi // may be modified later info.field = f - for ni, n := range f.Names { + for _, n := range f.Names { if n == id { - info.nameIndex = ni + info.paramIndex = pi info.name = n break } + pi++ } if info.name == nil && len(info.field.Names) == 1 { - info.nameIndex = 0 info.name = info.field.Names[0] } break + } else { + m := len(f.Names) + if m == 0 { + m = 1 + } + pi += m } } - return info, nil + return &info } // signatureRewrite defines a rewritten function signature. 
@@ -467,7 +641,11 @@ func rewriteCalls(ctx context.Context, rw signatureRewrite) (map[protocol.Docume } post := func(got []byte) []byte { return bytes.ReplaceAll(got, []byte(tag), nil) } - return inlineAllCalls(ctx, logf, rw.snapshot, rw.pkg, rw.pgf, rw.origDecl, calleeInfo, post) + opts := &inline.Options{ + Logf: logf, + IgnoreEffects: true, + } + return inlineAllCalls(ctx, rw.snapshot, rw.pkg, rw.pgf, rw.origDecl, calleeInfo, post, opts) } // reTypeCheck re-type checks orig with new file contents defined by fileMask. @@ -581,6 +759,22 @@ func remove[T any](s []T, i int) []T { return append(s[:i], s[i+1:]...) } +// selectElements returns a new array of elements of s indicated by the +// provided list of indices. It returns false if any index was out of bounds. +// +// For example, given the slice []string{"a", "b", "c", "d"}, the +// indices []int{3, 0, 1} results in the slice []string{"d", "a", "b"}. +func selectElements[T any](s []T, indices []int) ([]T, bool) { + res := make([]T, len(indices)) + for i, index := range indices { + if index < 0 || index >= len(s) { + return nil, false + } + res[i] = s[index] + } + return res, true +} + // replaceFileDecl replaces old with new in the file described by pgf. // // TODO(rfindley): generalize, and combine with rewriteSignature. 
diff --git a/gopls/internal/golang/codeaction.go b/gopls/internal/golang/codeaction.go index 3e4f3113f9e..0a778ba758b 100644 --- a/gopls/internal/golang/codeaction.go +++ b/gopls/internal/golang/codeaction.go @@ -236,7 +236,8 @@ var codeActionProducers = [...]codeActionProducer{ {kind: settings.RefactorExtractFunction, fn: refactorExtractFunction}, {kind: settings.RefactorExtractMethod, fn: refactorExtractMethod}, {kind: settings.RefactorExtractToNewFile, fn: refactorExtractToNewFile}, - {kind: settings.RefactorExtractVariable, fn: refactorExtractVariable}, + {kind: settings.RefactorExtractConstant, fn: refactorExtractVariable, needPkg: true}, + {kind: settings.RefactorExtractVariable, fn: refactorExtractVariable, needPkg: true}, {kind: settings.RefactorInlineCall, fn: refactorInlineCall, needPkg: true}, {kind: settings.RefactorRewriteChangeQuote, fn: refactorRewriteChangeQuote}, {kind: settings.RefactorRewriteFillStruct, fn: refactorRewriteFillStruct, needPkg: true}, @@ -244,6 +245,8 @@ var codeActionProducers = [...]codeActionProducer{ {kind: settings.RefactorRewriteInvertIf, fn: refactorRewriteInvertIf}, {kind: settings.RefactorRewriteJoinLines, fn: refactorRewriteJoinLines, needPkg: true}, {kind: settings.RefactorRewriteRemoveUnusedParam, fn: refactorRewriteRemoveUnusedParam, needPkg: true}, + {kind: settings.RefactorRewriteMoveParamLeft, fn: refactorRewriteMoveParamLeft, needPkg: true}, + {kind: settings.RefactorRewriteMoveParamRight, fn: refactorRewriteMoveParamRight, needPkg: true}, {kind: settings.RefactorRewriteSplitLines, fn: refactorRewriteSplitLines, needPkg: true}, // Note: don't forget to update the allow-list in Server.CodeAction @@ -302,7 +305,7 @@ func quickFix(ctx context.Context, req *codeActionsRequest) error { continue } - msg := typeError.Error() + msg := typeError.Msg switch { // "Missing method" error? (stubmethods) // Offer a "Declare missing methods of INTERFACE" code action. 
@@ -329,6 +332,17 @@ func quickFix(ctx context.Context, req *codeActionsRequest) error { msg := fmt.Sprintf("Declare missing method %s.%s", si.Receiver.Obj().Name(), si.MethodName) req.addApplyFixAction(msg, fixMissingCalledFunction, req.loc) } + + // "undeclared name: x" or "undefined: x" compiler error. + // Offer a "Create variable/function x" code action. + // See [fixUndeclared] for command implementation. + case strings.HasPrefix(msg, "undeclared name: "), + strings.HasPrefix(msg, "undefined: "): + path, _ := astutil.PathEnclosingInterval(req.pgf.File, start, end) + title := undeclaredFixTitle(path, msg) + if title != "" { + req.addApplyFixAction(title, fixCreateUndeclared, req.loc) + } } } @@ -449,11 +463,25 @@ func refactorExtractMethod(ctx context.Context, req *codeActionsRequest) error { return nil } -// refactorExtractVariable produces "Extract variable" code actions. +// refactorExtractVariable produces "Extract variable|constant" code actions. // See [extractVariable] for command implementation. func refactorExtractVariable(ctx context.Context, req *codeActionsRequest) error { - if _, _, ok, _ := canExtractVariable(req.start, req.end, req.pgf.File); ok { - req.addApplyFixAction("Extract variable", fixExtractVariable, req.loc) + info := req.pkg.TypesInfo() + if expr, _, err := canExtractVariable(info, req.pgf.File, req.start, req.end); err == nil { + // Offer one of refactor.extract.{constant,variable} + // based on the constness of the expression; this is a + // limitation of the codeActionProducers mechanism. + // Beware that future evolutions of the refactorings + // may make them diverge to become non-complementary, + // for example because "if const x = ...; y {" is illegal. 
+ constant := info.Types[expr].Value != nil + if (req.kind == settings.RefactorExtractConstant) == constant { + title := "Extract variable" + if constant { + title = "Extract constant" + } + req.addApplyFixAction(title, fixExtractVariable, req.loc) + } } return nil } @@ -468,14 +496,9 @@ func refactorExtractToNewFile(ctx context.Context, req *codeActionsRequest) erro return nil } -// addTest produces "Add a test for FUNC" code actions. +// addTest produces "Add test for FUNC" code actions. // See [server.commandHandler.AddTest] for command implementation. func addTest(ctx context.Context, req *codeActionsRequest) error { - // Reject if the feature is turned off. - if !req.snapshot.Options().AddTestSourceCodeAction { - return nil - } - // Reject test package. if req.pkg.Metadata().ForTest != "" { return nil @@ -496,26 +519,98 @@ func addTest(ctx context.Context, req *codeActionsRequest) error { return nil } - cmd := command.NewAddTestCommand("Add a test for "+decl.Name.String(), req.loc) + // TODO(hxjiang): support functions with type parameter. + if decl.Type.TypeParams != nil { + return nil + } + + cmd := command.NewAddTestCommand("Add test for "+decl.Name.String(), req.loc) req.addCommandAction(cmd, true) // TODO(hxjiang): add code action for generate test for package/file. return nil } +// identityTransform returns a change signature transformation that leaves the +// given fieldlist unmodified. +func identityTransform(fields *ast.FieldList) []command.ChangeSignatureParam { + var id []command.ChangeSignatureParam + for i := 0; i < fields.NumFields(); i++ { + id = append(id, command.ChangeSignatureParam{OldIndex: i}) + } + return id +} + // refactorRewriteRemoveUnusedParam produces "Remove unused parameter" code actions. // See [server.commandHandler.ChangeSignature] for command implementation. 
func refactorRewriteRemoveUnusedParam(ctx context.Context, req *codeActionsRequest) error { - if canRemoveParameter(req.pkg, req.pgf, req.loc.Range) { - cmd := command.NewChangeSignatureCommand("Refactor: remove unused parameter", command.ChangeSignatureArgs{ - RemoveParameter: req.loc, - ResolveEdits: req.resolveEdits(), + if info := removableParameter(req.pkg, req.pgf, req.loc.Range); info != nil { + var transform []command.ChangeSignatureParam + for i := 0; i < info.decl.Type.Params.NumFields(); i++ { + if i != info.paramIndex { + transform = append(transform, command.ChangeSignatureParam{OldIndex: i}) + } + } + cmd := command.NewChangeSignatureCommand("Remove unused parameter", command.ChangeSignatureArgs{ + Location: req.loc, + NewParams: transform, + NewResults: identityTransform(info.decl.Type.Results), + ResolveEdits: req.resolveEdits(), }) req.addCommandAction(cmd, true) } return nil } +func refactorRewriteMoveParamLeft(ctx context.Context, req *codeActionsRequest) error { + if info := findParam(req.pgf, req.loc.Range); info != nil && + info.paramIndex > 0 && + !is[*ast.Ellipsis](info.field.Type) { + + // ^^ we can't currently handle moving a variadic param. + // TODO(rfindley): implement. 
+ + transform := identityTransform(info.decl.Type.Params) + transform[info.paramIndex] = command.ChangeSignatureParam{OldIndex: info.paramIndex - 1} + transform[info.paramIndex-1] = command.ChangeSignatureParam{OldIndex: info.paramIndex} + cmd := command.NewChangeSignatureCommand("Move parameter left", command.ChangeSignatureArgs{ + Location: req.loc, + NewParams: transform, + NewResults: identityTransform(info.decl.Type.Results), + ResolveEdits: req.resolveEdits(), + }) + + req.addCommandAction(cmd, true) + } + return nil +} + +func refactorRewriteMoveParamRight(ctx context.Context, req *codeActionsRequest) error { + if info := findParam(req.pgf, req.loc.Range); info != nil && info.paramIndex >= 0 { + params := info.decl.Type.Params + nparams := params.NumFields() + if info.paramIndex < nparams-1 { // not the last param + if info.paramIndex == nparams-2 && is[*ast.Ellipsis](params.List[len(params.List)-1].Type) { + // We can't currently handle moving a variadic param. + // TODO(rfindley): implement. + return nil + } + + transform := identityTransform(info.decl.Type.Params) + transform[info.paramIndex] = command.ChangeSignatureParam{OldIndex: info.paramIndex + 1} + transform[info.paramIndex+1] = command.ChangeSignatureParam{OldIndex: info.paramIndex} + cmd := command.NewChangeSignatureCommand("Move parameter right", command.ChangeSignatureArgs{ + Location: req.loc, + NewParams: transform, + NewResults: identityTransform(info.decl.Type.Results), + ResolveEdits: req.resolveEdits(), + }) + req.addCommandAction(cmd, true) + } + } + return nil +} + // refactorRewriteChangeQuote produces "Convert to {raw,interpreted} string literal" code actions. 
func refactorRewriteChangeQuote(ctx context.Context, req *codeActionsRequest) error { convertStringLiteral(req) @@ -582,10 +677,10 @@ func refactorRewriteFillSwitch(ctx context.Context, req *codeActionsRequest) err return nil } -// canRemoveParameter reports whether we can remove the function parameter -// indicated by the given [start, end) range. +// removableParameter returns paramInfo about a removable parameter indicated +// by the given [start, end) range, or nil if no such removal is available. // -// This is true if: +// Removing a parameter is possible if // - there are no parse or type errors, and // - [start, end) is contained within an unused field or parameter name // - ... of a non-method function declaration. @@ -594,33 +689,30 @@ func refactorRewriteFillSwitch(ctx context.Context, req *codeActionsRequest) err // much more precisely, allowing it to report its findings as diagnostics.) // // TODO(adonovan): inline into refactorRewriteRemoveUnusedParam. -func canRemoveParameter(pkg *cache.Package, pgf *parsego.File, rng protocol.Range) bool { +func removableParameter(pkg *cache.Package, pgf *parsego.File, rng protocol.Range) *paramInfo { if perrors, terrors := pkg.ParseErrors(), pkg.TypeErrors(); len(perrors) > 0 || len(terrors) > 0 { - return false // can't remove parameters from packages with errors + return nil // can't remove parameters from packages with errors } - info, err := findParam(pgf, rng) - if err != nil { - return false // e.g. 
invalid range - } - if info.field == nil { - return false // range does not span a parameter + info := findParam(pgf, rng) + if info == nil || info.field == nil { + return nil // range does not span a parameter } if info.decl.Body == nil { - return false // external function + return nil // external function } if len(info.field.Names) == 0 { - return true // no names => field is unused + return info // no names => field is unused } if info.name == nil { - return false // no name is indicated + return nil // no name is indicated } if info.name.Name == "_" { - return true // trivially unused + return info // trivially unused } obj := pkg.TypesInfo().Defs[info.name] if obj == nil { - return false // something went wrong + return nil // something went wrong } used := false @@ -630,7 +722,10 @@ func canRemoveParameter(pkg *cache.Package, pgf *parsego.File, rng protocol.Rang } return !used // keep going until we find a use }) - return !used + if used { + return nil + } + return info } // refactorInlineCall produces "Inline call to FUNC" code actions. diff --git a/gopls/internal/golang/comment.go b/gopls/internal/golang/comment.go index e1d154feac5..b7ff45037d8 100644 --- a/gopls/internal/golang/comment.go +++ b/gopls/internal/golang/comment.go @@ -12,6 +12,7 @@ import ( "go/doc/comment" "go/token" "go/types" + "slices" "strings" "golang.org/x/tools/gopls/internal/cache" @@ -19,6 +20,7 @@ import ( "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/settings" "golang.org/x/tools/gopls/internal/util/astutil" + "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/gopls/internal/util/safetoken" ) @@ -154,6 +156,18 @@ func lookupDocLinkSymbol(pkg *cache.Package, pgf *parsego.File, name string) typ // Try treating the prefix as a package name, // allowing for non-renaming and renaming imports. 
fileScope := pkg.TypesInfo().Scopes[pgf.File] + if fileScope == nil { + // This is theoretically possible if pgf is a GoFile but not a + // CompiledGoFile. However, we do not know how to produce such a package + // without using an external GoPackagesDriver. + // See if this is the source of golang/go#70635 + if slices.Contains(pkg.CompiledGoFiles(), pgf) { + bug.Reportf("missing file scope for compiled file") + } else { + bug.Reportf("missing file scope for non-compiled file") + } + return nil + } pkgname, ok := fileScope.Lookup(prefix).(*types.PkgName) // ok => prefix is imported name if !ok { // Handle renaming import, e.g. diff --git a/gopls/internal/golang/completion/completion.go b/gopls/internal/golang/completion/completion.go index 6bf8ad8acde..7b4abe774a4 100644 --- a/gopls/internal/golang/completion/completion.go +++ b/gopls/internal/golang/completion/completion.go @@ -38,6 +38,7 @@ import ( "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/settings" goplsastutil "golang.org/x/tools/gopls/internal/util/astutil" + "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/gopls/internal/util/safetoken" "golang.org/x/tools/gopls/internal/util/typesutil" "golang.org/x/tools/internal/event" @@ -134,6 +135,33 @@ func (i *CompletionItem) Snippet() string { return i.InsertText } +// addConversion wraps the existing completionItem in a conversion expression. +// Only affects the receiver's InsertText and snippet fields, not the Label. +// An empty conv argument has no effect. +func (i *CompletionItem) addConversion(c *completer, conv conversionEdits) error { + if conv.prefix != "" { + // If we are in a selector, add an edit to place prefix before selector. + if sel := enclosingSelector(c.path, c.pos); sel != nil { + edits, err := c.editText(sel.Pos(), sel.Pos(), conv.prefix) + if err != nil { + return err + } + i.AdditionalTextEdits = append(i.AdditionalTextEdits, edits...) 
+ } else { + // If there is no selector, just stick the prefix at the start. + i.InsertText = conv.prefix + i.InsertText + i.snippet.PrependText(conv.prefix) + } + } + + if conv.suffix != "" { + i.InsertText += conv.suffix + i.snippet.WriteText(conv.suffix) + } + + return nil +} + // Scoring constants are used for weighting the relevance of different candidates. const ( // stdScore is the base score for all completion items. @@ -682,7 +710,7 @@ func (c *completer) collectCompletions(ctx context.Context) error { } // Struct literals are handled entirely separately. - if c.wantStructFieldCompletions() { + if wantStructFieldCompletions(c.enclosingCompositeLiteral) { // If we are definitely completing a struct field name, deep completions // don't make sense. if c.enclosingCompositeLiteral.inKey { @@ -1133,12 +1161,11 @@ func (c *completer) addFieldItems(fields *ast.FieldList) { } } -func (c *completer) wantStructFieldCompletions() bool { - clInfo := c.enclosingCompositeLiteral - if clInfo == nil { +func wantStructFieldCompletions(enclosingCl *compLitInfo) bool { + if enclosingCl == nil { return false } - return is[*types.Struct](clInfo.clType) && (clInfo.inKey || clInfo.maybeInFieldName) + return is[*types.Struct](enclosingCl.clType) && (enclosingCl.inKey || enclosingCl.maybeInFieldName) } func (c *completer) wantTypeName() bool { @@ -2039,8 +2066,7 @@ func enclosingFunction(path []ast.Node, info *types.Info) *funcInfo { return nil } -func (c *completer) expectedCompositeLiteralType() types.Type { - clInfo := c.enclosingCompositeLiteral +func expectedCompositeLiteralType(clInfo *compLitInfo, pos token.Pos) types.Type { switch t := clInfo.clType.(type) { case *types.Slice: if clInfo.inKey { @@ -2079,7 +2105,7 @@ func (c *completer) expectedCompositeLiteralType() types.Type { // The order of the literal fields must match the order in the struct definition. // Find the element that the position belongs to and suggest that field's type. 
- if i := exprAtPos(c.pos, clInfo.cl.Elts); i < t.NumFields() { + if i := exprAtPos(pos, clInfo.cl.Elts); i < t.NumFields() { return t.Field(i).Type() } } @@ -2164,6 +2190,25 @@ type candidateInference struct { // convertibleTo is a type our candidate type must be convertible to. convertibleTo types.Type + // needsExactType is true if the candidate type must be exactly the type of + // the objType, e.g. an interface rather than it's implementors. + // + // This is necessary when objType is derived using reverse type inference: + // any different (but assignable) type may lead to different type inference, + // which may no longer be valid. + // + // For example, consider the following scenario: + // + // func f[T any](x T) []T { return []T{x} } + // + // var s []any = f(_) + // + // Reverse type inference would infer that the type at _ must be 'any', but + // that does not mean that any object in the lexical scope is valid: the type of + // the object must be *exactly* any, otherwise type inference will cause the + // slice assignment to fail. + needsExactType bool + // typeName holds information about the expected type name at // position, if any. typeName typeNameInference @@ -2235,7 +2280,7 @@ func expectedCandidate(ctx context.Context, c *completer) (inf candidateInferenc inf.typeName = expectTypeName(c) if c.enclosingCompositeLiteral != nil { - inf.objType = c.expectedCompositeLiteralType() + inf.objType = expectedCompositeLiteralType(c.enclosingCompositeLiteral, c.pos) } Nodes: @@ -2259,34 +2304,21 @@ Nodes: break Nodes } case *ast.AssignStmt: - // Only rank completions if you are on the right side of the token. 
- if c.pos > node.TokPos { - i := exprAtPos(c.pos, node.Rhs) - if i >= len(node.Lhs) { - i = len(node.Lhs) - 1 - } - if tv, ok := c.pkg.TypesInfo().Types[node.Lhs[i]]; ok { - inf.objType = tv.Type - } - - // If we have a single expression on the RHS, record the LHS - // assignees so we can favor multi-return function calls with - // matching result values. - if len(node.Rhs) <= 1 { - for _, lhs := range node.Lhs { - inf.assignees = append(inf.assignees, c.pkg.TypesInfo().TypeOf(lhs)) - } - } else { - // Otherwise, record our single assignee, even if its type is - // not available. We use this info to downrank functions - // with the wrong number of result values. - inf.assignees = append(inf.assignees, c.pkg.TypesInfo().TypeOf(node.Lhs[i])) - } - } + objType, assignees := expectedAssignStmtTypes(c.pkg, node, c.pos) + inf.objType = objType + inf.assignees = assignees return inf case *ast.ValueSpec: - if node.Type != nil && c.pos > node.Type.End() { - inf.objType = c.pkg.TypesInfo().TypeOf(node.Type) + inf.objType = expectedValueSpecType(c.pkg, node, c.pos) + return + case *ast.ReturnStmt: + if c.enclosingFunc != nil { + inf.objType = expectedReturnStmtType(c.enclosingFunc.sig, node, c.pos) + } + return inf + case *ast.SendStmt: + if typ := expectedSendStmtType(c.pkg, node, c.pos); typ != nil { + inf.objType = typ } return inf case *ast.CallExpr: @@ -2299,22 +2331,27 @@ Nodes: break Nodes } - sig, _ := c.pkg.TypesInfo().Types[node.Fun].Type.(*types.Signature) - - if sig != nil && sig.TypeParams().Len() > 0 { - // If we are completing a generic func call, re-check the call expression. 
- // This allows type param inference to work in cases like: - // - // func foo[T any](T) {} - // foo[int](<>) // <- get "int" completions instead of "T" - // - // TODO: remove this after https://go.dev/issue/52503 - info := &types.Info{Types: make(map[ast.Expr]types.TypeAndValue)} - types.CheckExpr(c.pkg.FileSet(), c.pkg.Types(), node.Fun.Pos(), node.Fun, info) - sig, _ = info.Types[node.Fun].Type.(*types.Signature) - } + if sig, ok := c.pkg.TypesInfo().Types[node.Fun].Type.(*types.Signature); ok { + // Out of bounds arguments get no inference completion. + if !sig.Variadic() && exprAtPos(c.pos, node.Args) >= sig.Params().Len() { + return inf + } + + if sig.TypeParams().Len() > 0 { + targs := c.getTypeArgs(node) + res := inferExpectedResultTypes(c, i) + substs := reverseInferTypeArgs(sig, targs, res) + inst := instantiate(sig, substs) + if inst != nil { + // TODO(jacobz): If partial signature instantiation becomes possible, + // make needsExactType only true if necessary. + // Currently, ambigious cases always resolve to a conversion expression + // wrapping the completion, which is occassionally superfluous. + inf.needsExactType = true + sig = inst + } + } - if sig != nil { inf = c.expectedCallParamType(inf, node, sig) } @@ -2344,17 +2381,6 @@ Nodes: return inf } - case *ast.ReturnStmt: - if c.enclosingFunc != nil { - sig := c.enclosingFunc.sig - // Find signature result that corresponds to our return statement. 
- if resultIdx := exprAtPos(c.pos, node.Results); resultIdx < len(node.Results) { - if resultIdx < sig.Results().Len() { - inf.objType = sig.Results().At(resultIdx).Type() - } - } - } - return inf case *ast.CaseClause: if swtch, ok := findSwitchStmt(c.path[i+1:], c.pos, node).(*ast.SwitchStmt); ok { if tv, ok := c.pkg.TypesInfo().Types[swtch.Tag]; ok { @@ -2398,6 +2424,9 @@ Nodes: inf.objType = ct inf.typeName.wantTypeName = true inf.typeName.isTypeParam = true + if typ := c.inferExpectedTypeArg(i+1, 0); typ != nil { + inf.objType = typ + } } } } @@ -2405,20 +2434,14 @@ Nodes: case *ast.IndexListExpr: if node.Lbrack < c.pos && c.pos <= node.Rbrack { if tv, ok := c.pkg.TypesInfo().Types[node.X]; ok { - if ct := expectedConstraint(tv.Type, exprAtPos(c.pos, node.Indices)); ct != nil { + typeParamIdx := exprAtPos(c.pos, node.Indices) + if ct := expectedConstraint(tv.Type, typeParamIdx); ct != nil { inf.objType = ct inf.typeName.wantTypeName = true inf.typeName.isTypeParam = true - } - } - } - return inf - case *ast.SendStmt: - // Make sure we are on right side of arrow (e.g. "foo <- <>"). - if c.pos > node.Arrow+1 { - if tv, ok := c.pkg.TypesInfo().Types[node.Chan]; ok { - if ch, ok := tv.Type.Underlying().(*types.Chan); ok { - inf.objType = ch.Elem() + if typ := c.inferExpectedTypeArg(i+1, typeParamIdx); typ != nil { + inf.objType = typ + } } } } @@ -2457,6 +2480,267 @@ Nodes: return inf } +// inferExpectedResultTypes takes the index of a call expression within the completion +// path and uses its surroundings to infer the expected result tuple of the call's signature. +// Returns the signature result tuple as a slice, or nil if reverse type inference fails. +// +// # For example +// +// func generic[T any, U any](a T, b U) (T, U) { ... 
} +// +// var x TypeA +// var y TypeB +// x, y := generic(, ) +// +// inferExpectedResultTypes can determine that the expected result type of the function is (TypeA, TypeB) +func inferExpectedResultTypes(c *completer, callNodeIdx int) []types.Type { + callNode, ok := c.path[callNodeIdx].(*ast.CallExpr) + if !ok { + bug.Reportf("inferExpectedResultTypes given callNodeIndex: %v which is not a ast.CallExpr\n", callNodeIdx) + return nil + } + + if len(c.path) <= callNodeIdx+1 { + return nil + } + + var expectedResults []types.Type + + // Check the parents of the call node to extract the expected result types of the call signature. + // Currently reverse inferences are only supported with the the following parent expressions, + // however this list isn't exhaustive. + switch node := c.path[callNodeIdx+1].(type) { + case *ast.KeyValueExpr: + enclosingCompositeLiteral := enclosingCompositeLiteral(c.path[callNodeIdx:], callNode.Pos(), c.pkg.TypesInfo()) + if !wantStructFieldCompletions(enclosingCompositeLiteral) { + expectedResults = append(expectedResults, expectedCompositeLiteralType(enclosingCompositeLiteral, callNode.Pos())) + } + case *ast.AssignStmt: + objType, assignees := expectedAssignStmtTypes(c.pkg, node, c.pos) + if len(assignees) > 0 { + return assignees + } else if objType != nil { + expectedResults = append(expectedResults, objType) + } + case *ast.ValueSpec: + if resultType := expectedValueSpecType(c.pkg, node, c.pos); resultType != nil { + expectedResults = append(expectedResults, resultType) + } + case *ast.SendStmt: + if resultType := expectedSendStmtType(c.pkg, node, c.pos); resultType != nil { + expectedResults = append(expectedResults, resultType) + } + case *ast.ReturnStmt: + if c.enclosingFunc == nil { + return nil + } + + // As a special case for reverse call inference in + // + // return foo() + // + // Pull the result type from the enclosing function + if exprAtPos(c.pos, node.Results) == 0 { + if callSig := 
c.pkg.TypesInfo().Types[callNode.Fun].Type.(*types.Signature); callSig != nil { + enclosingResults := c.enclosingFunc.sig.Results() + if callSig.Results().Len() == enclosingResults.Len() { + expectedResults = make([]types.Type, enclosingResults.Len()) + for i := range enclosingResults.Len() { + expectedResults[i] = enclosingResults.At(i).Type() + } + return expectedResults + } + } + } + + if resultType := expectedReturnStmtType(c.enclosingFunc.sig, node, c.pos); resultType != nil { + expectedResults = append(expectedResults, resultType) + } + case *ast.CallExpr: + // TODO(jacobz): This is a difficult case because the normal CallExpr candidateInference + // leans on control flow which is inaccessible in this helper function. + // It would probably take a significant refactor to a recursive solution to make this case + // work cleanly. For now it's unimplemented. + } + return expectedResults +} + +// expectedSendStmtType return the expected type at the position. +// Returns nil if unknown. +func expectedSendStmtType(pkg *cache.Package, node *ast.SendStmt, pos token.Pos) types.Type { + // Make sure we are on right side of arrow (e.g. "foo <- <>"). + if pos > node.Arrow+1 { + if tv, ok := pkg.TypesInfo().Types[node.Chan]; ok { + if ch, ok := tv.Type.Underlying().(*types.Chan); ok { + return ch.Elem() + } + } + } + return nil +} + +// expectedValueSpecType returns the expected type of a ValueSpec at the query +// position. +func expectedValueSpecType(pkg *cache.Package, node *ast.ValueSpec, pos token.Pos) types.Type { + if node.Type != nil && pos > node.Type.End() { + return pkg.TypesInfo().TypeOf(node.Type) + } + return nil +} + +// expectedAssignStmtTypes analyzes the provided assignStmt, and checks +// to see if the provided pos is within a RHS expresison. If so, it report +// the expected type of that expression, and the LHS type(s) to which it +// is being assigned. 
+func expectedAssignStmtTypes(pkg *cache.Package, node *ast.AssignStmt, pos token.Pos) (objType types.Type, assignees []types.Type) { + // Only rank completions if you are on the right side of the token. + if pos > node.TokPos { + i := exprAtPos(pos, node.Rhs) + if i >= len(node.Lhs) { + i = len(node.Lhs) - 1 + } + if tv, ok := pkg.TypesInfo().Types[node.Lhs[i]]; ok { + objType = tv.Type + } + + // If we have a single expression on the RHS, record the LHS + // assignees so we can favor multi-return function calls with + // matching result values. + if len(node.Rhs) <= 1 { + for _, lhs := range node.Lhs { + assignees = append(assignees, pkg.TypesInfo().TypeOf(lhs)) + } + } else { + // Otherwise, record our single assignee, even if its type is + // not available. We use this info to downrank functions + // with the wrong number of result values. + assignees = append(assignees, pkg.TypesInfo().TypeOf(node.Lhs[i])) + } + } + return objType, assignees +} + +// expectedReturnStmtType returns the expected type of a return statement. +// Returns nil if enclosingSig is nil. 
+func expectedReturnStmtType(enclosingSig *types.Signature, node *ast.ReturnStmt, pos token.Pos) types.Type { + if enclosingSig != nil { + if resultIdx := exprAtPos(pos, node.Results); resultIdx < enclosingSig.Results().Len() { + return enclosingSig.Results().At(resultIdx).Type() + } + } + return nil +} + +// Returns the number of type arguments in a callExpr +func (c *completer) getTypeArgs(callExpr *ast.CallExpr) []types.Type { + var targs []types.Type + switch fun := callExpr.Fun.(type) { + case *ast.IndexListExpr: + for i := range fun.Indices { + if typ, ok := c.pkg.TypesInfo().Types[fun.Indices[i]]; ok && typeIsValid(typ.Type) { + targs = append(targs, typ.Type) + } + } + case *ast.IndexExpr: + if typ, ok := c.pkg.TypesInfo().Types[fun.Index]; ok && typeIsValid(typ.Type) { + targs = []types.Type{typ.Type} + } + } + return targs +} + +// reverseInferTypeArgs takes a generic signature, a list of passed type arguments, and the expected concrete return types +// inferred from the signature's call site. If possible, it returns a list of types that could be used as the type arguments +// to the signature. If not possible, it returns nil. +// +// Does not panic if any of the arguments are nil. +func reverseInferTypeArgs(sig *types.Signature, typeArgs []types.Type, expectedResults []types.Type) []types.Type { + if len(expectedResults) == 0 || sig == nil || sig.TypeParams().Len() == 0 || sig.Results().Len() != len(expectedResults) { + return nil + } + + tparams := make([]*types.TypeParam, sig.TypeParams().Len()) + for i := range sig.TypeParams().Len() { + tparams[i] = sig.TypeParams().At(i) + } + + for i := len(typeArgs); i < sig.TypeParams().Len(); i++ { + typeArgs = append(typeArgs, nil) + } + + u := newUnifier(tparams, typeArgs) + for i, assignee := range expectedResults { + // Unify does not check the constraints of the type parameters. + // Checks must be applied after. 
+ if !u.unify(sig.Results().At(i).Type(), assignee, unifyModeExact) { + return nil + } + } + + substs := make([]types.Type, sig.TypeParams().Len()) + for i := 0; i < sig.TypeParams().Len(); i++ { + if sub := u.handles[sig.TypeParams().At(i)]; sub != nil && *sub != nil { + // Ensure the inferred subst is assignable to the type parameter's constraint. + if !assignableTo(*sub, sig.TypeParams().At(i).Constraint()) { + return nil + } + substs[i] = *sub + } + } + return substs +} + +// inferExpectedTypeArg gives a type param candidateInference based on the surroundings of it's call site. +// If successful, the inf parameter is returned with only it's objType field updated. +// +// callNodeIdx is the index within the completion path of the type parameter's parent call expression. +// typeParamIdx is the index of the type parameter at the completion pos. +func (c *completer) inferExpectedTypeArg(callNodeIdx int, typeParamIdx int) types.Type { + if len(c.path) <= callNodeIdx { + return nil + } + + callNode, ok := c.path[callNodeIdx].(*ast.CallExpr) + if !ok { + return nil + } + + // Infer the type parameters in a function call based on it's context + sig := c.pkg.TypesInfo().Types[callNode.Fun].Type.(*types.Signature) + expectedResults := inferExpectedResultTypes(c, callNodeIdx) + if typeParamIdx < 0 || typeParamIdx >= sig.TypeParams().Len() { + return nil + } + substs := reverseInferTypeArgs(sig, nil, expectedResults) + if substs == nil || substs[typeParamIdx] == nil { + return nil + } + + return substs[typeParamIdx] +} + +// Instantiates a signature with a set of type parameters. +// Wrapper around types.Instantiate but bad arguments won't cause a panic. 
+func instantiate(sig *types.Signature, substs []types.Type) *types.Signature { + if substs == nil || sig == nil || len(substs) != sig.TypeParams().Len() { + return nil + } + + for i := range substs { + if substs[i] == nil { + substs[i] = sig.TypeParams().At(i) + } + } + + if inst, err := types.Instantiate(nil, sig, substs, true); err == nil { + if inst, ok := inst.(*types.Signature); ok { + return inst + } + } + + return nil +} + func (c *completer) expectedCallParamType(inf candidateInference, node *ast.CallExpr, sig *types.Signature) candidateInference { numParams := sig.Params().Len() if numParams == 0 { @@ -2875,7 +3159,7 @@ func (c *completer) matchingCandidate(cand *candidate) bool { } // Bail out early if we are completing a field name in a composite literal. - if v, ok := cand.obj.(*types.Var); ok && v.IsField() && c.wantStructFieldCompletions() { + if v, ok := cand.obj.(*types.Var); ok && v.IsField() && wantStructFieldCompletions(c.enclosingCompositeLiteral) { return true } @@ -2972,6 +3256,14 @@ func (ci *candidateInference) candTypeMatches(cand *candidate) bool { cand.mods = append(cand.mods, takeDotDotDot) } + // Candidate matches, but isn't exactly identical to the expected type. + // Apply a conversion to allow it to match. + if ci.needsExactType && !types.Identical(candType, expType) { + cand.convertTo = expType + // Ranks barely lower if it needs a conversion, even though it's perfectly valid. + cand.score *= 0.95 + } + // Lower candidate score for untyped conversions. This avoids // ranking untyped constants above candidates with an exact type // match. Don't lower score of builtin constants, e.g. "true". @@ -3161,6 +3453,9 @@ func (c *completer) matchingTypeName(cand *candidate) bool { return false } + wantExactTypeParam := c.inference.typeName.isTypeParam && + c.inference.typeName.wantTypeName && c.inference.needsExactType + typeMatches := func(candType types.Type) bool { // Take into account any type name modifier prefixes. 
candType = c.inference.applyTypeNameModifiers(candType) @@ -3179,6 +3474,13 @@ func (c *completer) matchingTypeName(cand *candidate) bool { } } + // Suggest the exact type when performing reverse type inference. + // x = Foo[<>]() + // Where x is an interface kind, only suggest the interface type rather than its implementors + if wantExactTypeParam && types.Identical(candType, c.inference.objType) { + return true + } + if c.inference.typeName.wantComparable && !types.Comparable(candType) { return false } diff --git a/gopls/internal/golang/completion/format.go b/gopls/internal/golang/completion/format.go index baf0890497b..872025949fb 100644 --- a/gopls/internal/golang/completion/format.go +++ b/gopls/internal/golang/completion/format.go @@ -196,24 +196,9 @@ Suffixes: } if cand.convertTo != nil { - typeName := types.TypeString(cand.convertTo, c.qf) - - switch t := cand.convertTo.(type) { - // We need extra parens when casting to these types. For example, - // we need "(*int)(foo)", not "*int(foo)". - case *types.Pointer, *types.Signature: - typeName = "(" + typeName + ")" - case *types.Basic: - // If the types are incompatible (as determined by typeMatches), then we - // must need a conversion here. However, if the target type is untyped, - // don't suggest converting to e.g. "untyped float" (golang/go#62141). - if t.Info()&types.IsUntyped != 0 { - typeName = types.TypeString(types.Default(cand.convertTo), c.qf) - } - } - - prefix = typeName + "(" + prefix - suffix = ")" + conv := c.formatConversion(cand.convertTo) + prefix = conv.prefix + prefix + suffix = conv.suffix } if prefix != "" { @@ -288,6 +273,38 @@ Suffixes: return item, nil } +// conversionEdits represents the string edits needed to make a type conversion +// of an expression. +type conversionEdits struct { + prefix, suffix string +} + +// formatConversion returns the edits needed to make a type conversion +// expression, including parentheses if necessary. 
+// +// Returns empty conversionEdits if convertTo is nil. +func (c *completer) formatConversion(convertTo types.Type) conversionEdits { + if convertTo == nil { + return conversionEdits{} + } + + typeName := types.TypeString(convertTo, c.qf) + switch t := convertTo.(type) { + // We need extra parens when casting to these types. For example, + // we need "(*int)(foo)", not "*int(foo)". + case *types.Pointer, *types.Signature: + typeName = "(" + typeName + ")" + case *types.Basic: + // If the types are incompatible (as determined by typeMatches), then we + // must need a conversion here. However, if the target type is untyped, + // don't suggest converting to e.g. "untyped float" (golang/go#62141). + if t.Info()&types.IsUntyped != 0 { + typeName = types.TypeString(types.Default(convertTo), c.qf) + } + } + return conversionEdits{prefix: typeName + "(", suffix: ")"} +} + // importEdits produces the text edits necessary to add the given import to the current file. func (c *completer) importEdits(imp *importInfo) ([]protocol.TextEdit, error) { if imp == nil { diff --git a/gopls/internal/golang/completion/literal.go b/gopls/internal/golang/completion/literal.go index 7427d559e94..50ddb1fc26e 100644 --- a/gopls/internal/golang/completion/literal.go +++ b/gopls/internal/golang/completion/literal.go @@ -73,15 +73,21 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im cand.addressable = true } - if !c.matchingCandidate(&cand) || cand.convertTo != nil { + // Only suggest a literal conversion if the exact type is known. 
+ if !c.matchingCandidate(&cand) || (cand.convertTo != nil && !c.inference.needsExactType) { return } var ( - qf = c.qf - sel = enclosingSelector(c.path, c.pos) + qf = c.qf + sel = enclosingSelector(c.path, c.pos) + conversion conversionEdits ) + if cand.convertTo != nil { + conversion = c.formatConversion(cand.convertTo) + } + // Don't qualify the type name if we are in a selector expression // since the package name is already present. if sel != nil { @@ -129,13 +135,18 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im switch t := literalType.Underlying().(type) { case *types.Struct, *types.Array, *types.Slice, *types.Map: - c.compositeLiteral(t, snip.Clone(), typeName, float64(score), addlEdits) + item := c.compositeLiteral(t, snip.Clone(), typeName, float64(score), addlEdits) + item.addConversion(c, conversion) + c.items = append(c.items, item) case *types.Signature: // Add a literal completion for a signature type that implements // an interface. For example, offer "http.HandlerFunc()" when // expected type is "http.Handler". if expType != nil && types.IsInterface(expType) { - c.basicLiteral(t, snip.Clone(), typeName, float64(score), addlEdits) + if item, ok := c.basicLiteral(t, snip.Clone(), typeName, float64(score), addlEdits); ok { + item.addConversion(c, conversion) + c.items = append(c.items, item) + } } case *types.Basic: // Add a literal completion for basic types that implement our @@ -143,7 +154,10 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im // implements http.FileSystem), or are identical to our expected // type (i.e. yielding a type conversion such as "float64()"). 
if expType != nil && (types.IsInterface(expType) || types.Identical(expType, literalType)) { - c.basicLiteral(t, snip.Clone(), typeName, float64(score), addlEdits) + if item, ok := c.basicLiteral(t, snip.Clone(), typeName, float64(score), addlEdits); ok { + item.addConversion(c, conversion) + c.items = append(c.items, item) + } } } } @@ -155,11 +169,15 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im switch literalType.Underlying().(type) { case *types.Slice: // The second argument to "make()" for slices is required, so default to "0". - c.makeCall(snip.Clone(), typeName, "0", float64(score), addlEdits) + item := c.makeCall(snip.Clone(), typeName, "0", float64(score), addlEdits) + item.addConversion(c, conversion) + c.items = append(c.items, item) case *types.Map, *types.Chan: // Maps and channels don't require the second argument, so omit // to keep things simple for now. - c.makeCall(snip.Clone(), typeName, "", float64(score), addlEdits) + item := c.makeCall(snip.Clone(), typeName, "", float64(score), addlEdits) + item.addConversion(c, conversion) + c.items = append(c.items, item) } } @@ -167,7 +185,10 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im if score := c.matcher.Score("func"); !cand.hasMod(reference) && score > 0 && (expType == nil || !types.IsInterface(expType)) { switch t := literalType.Underlying().(type) { case *types.Signature: - c.functionLiteral(ctx, t, float64(score)) + if item, ok := c.functionLiteral(ctx, t, float64(score)); ok { + item.addConversion(c, conversion) + c.items = append(c.items, item) + } } } } @@ -178,9 +199,9 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im // correct type, so scale down highScore. const literalCandidateScore = highScore / 2 -// functionLiteral adds a function literal completion item for the -// given signature. 
-func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, matchScore float64) { +// functionLiteral returns a function literal completion item for the +// given signature, if applicable. +func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, matchScore float64) (CompletionItem, bool) { snip := &snippet.Builder{} snip.WriteText("func(") @@ -216,7 +237,7 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m if ctx.Err() == nil { event.Error(ctx, "formatting var type", err) } - return + return CompletionItem{}, false } name = abbreviateTypeName(typeName) } @@ -284,7 +305,7 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m if ctx.Err() == nil { event.Error(ctx, "formatting var type", err) } - return + return CompletionItem{}, false } if sig.Variadic() && i == sig.Params().Len()-1 { typeStr = strings.Replace(typeStr, "[]", "...", 1) @@ -342,7 +363,7 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m if ctx.Err() == nil { event.Error(ctx, "formatting var type", err) } - return + return CompletionItem{}, false } if tp, ok := types.Unalias(r.Type()).(*types.TypeParam); ok && !c.typeParamInScope(tp) { snip.WritePlaceholder(func(snip *snippet.Builder) { @@ -360,16 +381,18 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m snip.WriteFinalTabstop() snip.WriteText("}") - c.items = append(c.items, CompletionItem{ + return CompletionItem{ Label: "func(...) {}", Score: matchScore * literalCandidateScore, Kind: protocol.VariableCompletion, snippet: snip, - }) + }, true } // conventionalAcronyms contains conventional acronyms for type names // in lower case. For example, "ctx" for "context" and "err" for "error". +// +// Keep this up to date with golang.conventionalVarNames. 
var conventionalAcronyms = map[string]string{ "context": "ctx", "error": "err", @@ -382,11 +405,6 @@ var conventionalAcronyms = map[string]string{ // non-identifier runes. For example, "[]int" becomes "i", and // "struct { i int }" becomes "s". func abbreviateTypeName(s string) string { - var ( - b strings.Builder - useNextUpper bool - ) - // Trim off leading non-letters. We trim everything between "[" and // "]" to handle array types like "[someConst]int". var inBracket bool @@ -407,32 +425,12 @@ func abbreviateTypeName(s string) string { return acr } - for i, r := range s { - // Stop if we encounter a non-identifier rune. - if !unicode.IsLetter(r) && !unicode.IsNumber(r) { - break - } - - if i == 0 { - b.WriteRune(unicode.ToLower(r)) - } - - if unicode.IsUpper(r) { - if useNextUpper { - b.WriteRune(unicode.ToLower(r)) - useNextUpper = false - } - } else { - useNextUpper = true - } - } - - return b.String() + return golang.AbbreviateVarName(s) } -// compositeLiteral adds a composite literal completion item for the given typeName. +// compositeLiteral returns a composite literal completion item for the given typeName. // T is an (unnamed, unaliased) struct, array, slice, or map type. -func (c *completer) compositeLiteral(T types.Type, snip *snippet.Builder, typeName string, matchScore float64, edits []protocol.TextEdit) { +func (c *completer) compositeLiteral(T types.Type, snip *snippet.Builder, typeName string, matchScore float64, edits []protocol.TextEdit) CompletionItem { snip.WriteText("{") // Don't put the tab stop inside the composite literal curlies "{}" // for structs that have no accessible fields. 
@@ -443,22 +441,24 @@ func (c *completer) compositeLiteral(T types.Type, snip *snippet.Builder, typeNa nonSnippet := typeName + "{}" - c.items = append(c.items, CompletionItem{ + return CompletionItem{ Label: nonSnippet, InsertText: nonSnippet, Score: matchScore * literalCandidateScore, Kind: protocol.VariableCompletion, AdditionalTextEdits: edits, snippet: snip, - }) + } } -// basicLiteral adds a literal completion item for the given basic +// basicLiteral returns a literal completion item for the given basic // type name typeName. -func (c *completer) basicLiteral(T types.Type, snip *snippet.Builder, typeName string, matchScore float64, edits []protocol.TextEdit) { +// +// If T is untyped, this function returns false. +func (c *completer) basicLiteral(T types.Type, snip *snippet.Builder, typeName string, matchScore float64, edits []protocol.TextEdit) (CompletionItem, bool) { // Never give type conversions like "untyped int()". if isUntyped(T) { - return + return CompletionItem{}, false } snip.WriteText("(") @@ -467,7 +467,7 @@ func (c *completer) basicLiteral(T types.Type, snip *snippet.Builder, typeName s nonSnippet := typeName + "()" - c.items = append(c.items, CompletionItem{ + return CompletionItem{ Label: nonSnippet, InsertText: nonSnippet, Detail: T.String(), @@ -475,11 +475,11 @@ func (c *completer) basicLiteral(T types.Type, snip *snippet.Builder, typeName s Kind: protocol.VariableCompletion, AdditionalTextEdits: edits, snippet: snip, - }) + }, true } -// makeCall adds a completion item for a "make()" call given a specific type. -func (c *completer) makeCall(snip *snippet.Builder, typeName string, secondArg string, matchScore float64, edits []protocol.TextEdit) { +// makeCall returns a completion item for a "make()" call given a specific type. 
+func (c *completer) makeCall(snip *snippet.Builder, typeName string, secondArg string, matchScore float64, edits []protocol.TextEdit) CompletionItem { // Keep it simple and don't add any placeholders for optional "make()" arguments. snip.PrependText("make(") @@ -501,14 +501,15 @@ func (c *completer) makeCall(snip *snippet.Builder, typeName string, secondArg s } nonSnippet.WriteByte(')') - c.items = append(c.items, CompletionItem{ - Label: nonSnippet.String(), - InsertText: nonSnippet.String(), - Score: matchScore * literalCandidateScore, + return CompletionItem{ + Label: nonSnippet.String(), + InsertText: nonSnippet.String(), + // make() should be just below other literal completions + Score: matchScore * literalCandidateScore * 0.99, Kind: protocol.FunctionCompletion, AdditionalTextEdits: edits, snippet: snip, - }) + } } // Create a snippet for a type name where type params become placeholders. @@ -516,28 +517,30 @@ func (c *completer) typeNameSnippet(literalType types.Type, qf types.Qualifier) var ( snip snippet.Builder typeName string - // TODO(adonovan): think more about aliases. - // They should probably be treated more like Named. - named, _ = types.Unalias(literalType).(*types.Named) + pnt, _ = literalType.(typesinternal.NamedOrAlias) // = *Named | *Alias ) - if named != nil && named.Obj() != nil && named.TypeParams().Len() > 0 && !c.fullyInstantiated(named) { + tparams := typesinternal.TypeParams(pnt) + if tparams.Len() > 0 && !c.fullyInstantiated(pnt) { + // tparams.Len() > 0 implies pnt != nil. + // Inv: pnt is not "error" or "unsafe.Pointer", so pnt.Obj() != nil and has a Pkg(). + // We are not "fully instantiated" meaning we have type params that must be specified. - if pkg := qf(named.Obj().Pkg()); pkg != "" { + if pkg := qf(pnt.Obj().Pkg()); pkg != "" { typeName = pkg + "." } // We do this to get "someType" instead of "someType[T]". 
- typeName += named.Obj().Name() + typeName += pnt.Obj().Name() snip.WriteText(typeName + "[") if c.opts.placeholders { - for i := 0; i < named.TypeParams().Len(); i++ { + for i := 0; i < tparams.Len(); i++ { if i > 0 { snip.WriteText(", ") } snip.WritePlaceholder(func(snip *snippet.Builder) { - snip.WriteText(types.TypeString(named.TypeParams().At(i), qf)) + snip.WriteText(types.TypeString(tparams.At(i), qf)) }) } } else { @@ -556,25 +559,35 @@ func (c *completer) typeNameSnippet(literalType types.Type, qf types.Qualifier) // fullyInstantiated reports whether all of t's type params have // specified type args. -func (c *completer) fullyInstantiated(t *types.Named) bool { - tps := t.TypeParams() - tas := t.TypeArgs() +func (c *completer) fullyInstantiated(t typesinternal.NamedOrAlias) bool { + targs := typesinternal.TypeArgs(t) + tparams := typesinternal.TypeParams(t) - if tps.Len() != tas.Len() { + if tparams.Len() != targs.Len() { return false } - for i := 0; i < tas.Len(); i++ { - // TODO(adonovan) think about generic aliases. - switch ta := types.Unalias(tas.At(i)).(type) { + for i := 0; i < targs.Len(); i++ { + targ := targs.At(i) + + // The expansion of an alias can have free type parameters, + // whether or not the alias itself has type parameters: + // + // func _[K comparable]() { + // type Set = map[K]bool // free(Set) = {K} + // type MapTo[V] = map[K]V // free(Map[foo]) = {V} + // } + // + // So, we must Unalias. + switch targ := types.Unalias(targ).(type) { case *types.TypeParam: // A *TypeParam only counts as specified if it is currently in // scope (i.e. we are in a generic definition). 
- if !c.typeParamInScope(ta) { + if !c.typeParamInScope(targ) { return false } case *types.Named: - if !c.fullyInstantiated(ta) { + if !c.fullyInstantiated(targ) { return false } } diff --git a/gopls/internal/golang/completion/package.go b/gopls/internal/golang/completion/package.go index e71f5c9dd02..5fd6c04144d 100644 --- a/gopls/internal/golang/completion/package.go +++ b/gopls/internal/golang/completion/package.go @@ -205,7 +205,7 @@ func packageSuggestions(ctx context.Context, snapshot *cache.Snapshot, fileURI p } }() - dirPath := filepath.Dir(fileURI.Path()) + dirPath := fileURI.DirPath() dirName := filepath.Base(dirPath) if !isValidDirName(dirName) { return packages, nil @@ -227,7 +227,7 @@ func packageSuggestions(ctx context.Context, snapshot *cache.Snapshot, fileURI p // Only add packages that are previously used in the current directory. var relevantPkg bool for _, uri := range mp.CompiledGoFiles { - if filepath.Dir(uri.Path()) == dirPath { + if uri.DirPath() == dirPath { relevantPkg = true break } diff --git a/gopls/internal/golang/completion/postfix_snippets.go b/gopls/internal/golang/completion/postfix_snippets.go index d322775cc7f..e0fc12cc9b5 100644 --- a/gopls/internal/golang/completion/postfix_snippets.go +++ b/gopls/internal/golang/completion/postfix_snippets.go @@ -442,7 +442,7 @@ func (a *postfixTmplArgs) TypeName(t types.Type) (string, error) { // Zero return the zero value representation of type t func (a *postfixTmplArgs) Zero(t types.Type) string { - return formatZeroValue(t, a.qf) + return typesinternal.ZeroString(t, a.qf) } func (a *postfixTmplArgs) IsIdent() bool { diff --git a/gopls/internal/golang/completion/snippet.go b/gopls/internal/golang/completion/snippet.go index 8df81f87672..fe346203120 100644 --- a/gopls/internal/golang/completion/snippet.go +++ b/gopls/internal/golang/completion/snippet.go @@ -13,7 +13,7 @@ import ( // structFieldSnippet calculates the snippet for struct literal field names. 
 func (c *completer) structFieldSnippet(cand candidate, detail string, snip *snippet.Builder) {
-	if !c.wantStructFieldCompletions() {
+	if !wantStructFieldCompletions(c.enclosingCompositeLiteral) {
 		return
 	}
diff --git a/gopls/internal/golang/completion/statements.go b/gopls/internal/golang/completion/statements.go
index ce80cfb08ce..e187bf2bee0 100644
--- a/gopls/internal/golang/completion/statements.go
+++ b/gopls/internal/golang/completion/statements.go
@@ -15,6 +15,7 @@ import (
 	"golang.org/x/tools/gopls/internal/golang"
 	"golang.org/x/tools/gopls/internal/golang/completion/snippet"
 	"golang.org/x/tools/gopls/internal/protocol"
+	"golang.org/x/tools/internal/typesinternal"
 )
 
 // addStatementCandidates adds full statement completion candidates
@@ -294,7 +295,7 @@ func (c *completer) addErrCheck() {
 	} else {
 		snip.WriteText("return ")
 		for i := 0; i < result.Len()-1; i++ {
-			snip.WriteText(formatZeroValue(result.At(i).Type(), c.qf))
+			snip.WriteText(typesinternal.ZeroString(result.At(i).Type(), c.qf))
 			snip.WriteText(", ")
 		}
 		snip.WritePlaceholder(func(b *snippet.Builder) {
@@ -404,7 +405,7 @@ func (c *completer) addReturnZeroValues() {
 			fmt.Fprintf(&label, ", ")
 		}
 
-		zero := formatZeroValue(result.At(i).Type(), c.qf)
+		zero := typesinternal.ZeroString(result.At(i).Type(), c.qf)
 		snip.WritePlaceholder(func(b *snippet.Builder) {
 			b.WriteText(zero)
 		})
diff --git a/gopls/internal/golang/completion/unify.go b/gopls/internal/golang/completion/unify.go
new file mode 100644
index 00000000000..8f4a1d3cbe0
--- /dev/null
+++ b/gopls/internal/golang/completion/unify.go
@@ -0,0 +1,710 @@
+// Below was copied from go/types/unify.go on September 24, 2024,
+// and combined with snippets from other files as well.
+// It is copied to implement unification for code completion inferences,
+// in lieu of an official type unification API.
+//
+// TODO: When such an API is available, the code below should be deleted.
+// +// Due to complexity of extracting private types from the go/types package, +// the unifier does not fully implement interface unification. +// +// The code has been modified to compile without introducing any key functionality changes. +// + +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements type unification. +// +// Type unification attempts to make two types x and y structurally +// equivalent by determining the types for a given list of (bound) +// type parameters which may occur within x and y. If x and y are +// structurally different (say []T vs chan T), or conflicting +// types are determined for type parameters, unification fails. +// If unification succeeds, as a side-effect, the types of the +// bound type parameters may be determined. +// +// Unification typically requires multiple calls u.unify(x, y) to +// a given unifier u, with various combinations of types x and y. +// In each call, additional type parameter types may be determined +// as a side effect and recorded in u. +// If a call fails (returns false), unification fails. +// +// In the unification context, structural equivalence of two types +// ignores the difference between a defined type and its underlying +// type if one type is a defined type and the other one is not. +// It also ignores the difference between an (external, unbound) +// type parameter and its core type. +// If two types are not structurally equivalent, they cannot be Go +// identical types. On the other hand, if they are structurally +// equivalent, they may be Go identical or at least assignable, or +// they may be in the type set of a constraint. +// Whether they indeed are identical or assignable is determined +// upon instantiation and function argument passing. 
+ +package completion + +import ( + "fmt" + "go/types" + "strings" +) + +const ( + // Upper limit for recursion depth. Used to catch infinite recursions + // due to implementation issues (e.g., see issues go.dev/issue/48619, go.dev/issue/48656). + unificationDepthLimit = 50 + + // Whether to panic when unificationDepthLimit is reached. + // If disabled, a recursion depth overflow results in a (quiet) + // unification failure. + panicAtUnificationDepthLimit = true + + // If enableCoreTypeUnification is set, unification will consider + // the core types, if any, of non-local (unbound) type parameters. + enableCoreTypeUnification = true +) + +// A unifier maintains a list of type parameters and +// corresponding types inferred for each type parameter. +// A unifier is created by calling newUnifier. +type unifier struct { + // handles maps each type parameter to its inferred type through + // an indirection *Type called (inferred type) "handle". + // Initially, each type parameter has its own, separate handle, + // with a nil (i.e., not yet inferred) type. + // After a type parameter P is unified with a type parameter Q, + // P and Q share the same handle (and thus type). This ensures + // that inferring the type for a given type parameter P will + // automatically infer the same type for all other parameters + // unified (joined) with P. + handles map[*types.TypeParam]*types.Type + depth int // recursion depth during unification +} + +// newUnifier returns a new unifier initialized with the given type parameter +// and corresponding type argument lists. The type argument list may be shorter +// than the type parameter list, and it may contain nil types. Matching type +// parameters and arguments must have the same index. 
+func newUnifier(tparams []*types.TypeParam, targs []types.Type) *unifier { + handles := make(map[*types.TypeParam]*types.Type, len(tparams)) + // Allocate all handles up-front: in a correct program, all type parameters + // must be resolved and thus eventually will get a handle. + // Also, sharing of handles caused by unified type parameters is rare and + // so it's ok to not optimize for that case (and delay handle allocation). + for i, x := range tparams { + var t types.Type + if i < len(targs) { + t = targs[i] + } + handles[x] = &t + } + return &unifier{handles, 0} +} + +// unifyMode controls the behavior of the unifier. +type unifyMode uint + +const ( + // If unifyModeAssign is set, we are unifying types involved in an assignment: + // they may match inexactly at the top, but element types must match + // exactly. + unifyModeAssign unifyMode = 1 << iota + + // If unifyModeExact is set, types unify if they are identical (or can be + // made identical with suitable arguments for type parameters). + // Otherwise, a named type and a type literal unify if their + // underlying types unify, channel directions are ignored, and + // if there is an interface, the other type must implement the + // interface. + unifyModeExact +) + +// This function was copied from go/types/unify.go +// +// unify attempts to unify x and y and reports whether it succeeded. +// As a side-effect, types may be inferred for type parameters. +// The mode parameter controls how types are compared. +func (u *unifier) unify(x, y types.Type, mode unifyMode) bool { + return u.nify(x, y, mode) +} + +type typeParamsById []*types.TypeParam + +// join unifies the given type parameters x and y. +// If both type parameters already have a type associated with them +// and they are not joined, join fails and returns false. +func (u *unifier) join(x, y *types.TypeParam) bool { + switch hx, hy := u.handles[x], u.handles[y]; { + case hx == hy: + // Both type parameters already share the same handle. 
Nothing to do. + case *hx != nil && *hy != nil: + // Both type parameters have (possibly different) inferred types. Cannot join. + return false + case *hx != nil: + // Only type parameter x has an inferred type. Use handle of x. + u.setHandle(y, hx) + // This case is treated like the default case. + // case *hy != nil: + // // Only type parameter y has an inferred type. Use handle of y. + // u.setHandle(x, hy) + default: + // Neither type parameter has an inferred type. Use handle of y. + u.setHandle(x, hy) + } + return true +} + +// asBoundTypeParam returns x.(*types.TypeParam) if x is a type parameter recorded with u. +// Otherwise, the result is nil. +func (u *unifier) asBoundTypeParam(x types.Type) *types.TypeParam { + if x, _ := types.Unalias(x).(*types.TypeParam); x != nil { + if _, found := u.handles[x]; found { + return x + } + } + return nil +} + +// setHandle sets the handle for type parameter x +// (and all its joined type parameters) to h. +func (u *unifier) setHandle(x *types.TypeParam, h *types.Type) { + hx := u.handles[x] + for y, hy := range u.handles { + if hy == hx { + u.handles[y] = h + } + } +} + +// at returns the (possibly nil) type for type parameter x. +func (u *unifier) at(x *types.TypeParam) types.Type { + return *u.handles[x] +} + +// set sets the type t for type parameter x; +// t must not be nil. +func (u *unifier) set(x *types.TypeParam, t types.Type) { + *u.handles[x] = t +} + +// unknowns returns the number of type parameters for which no type has been set yet. +func (u *unifier) unknowns() int { + n := 0 + for _, h := range u.handles { + if *h == nil { + n++ + } + } + return n +} + +// inferred returns the list of inferred types for the given type parameter list. +// The result is never nil and has the same length as tparams; result types that +// could not be inferred are nil. Corresponding type parameters and result types +// have identical indices. 
+func (u *unifier) inferred(tparams []*types.TypeParam) []types.Type { + list := make([]types.Type, len(tparams)) + for i, x := range tparams { + list[i] = u.at(x) + } + return list +} + +// asInterface returns the underlying type of x as an interface if +// it is a non-type parameter interface. Otherwise it returns nil. +func asInterface(x types.Type) (i *types.Interface) { + if _, ok := types.Unalias(x).(*types.TypeParam); !ok { + i, _ = x.Underlying().(*types.Interface) + } + return i +} + +func isTypeParam(t types.Type) bool { + _, ok := types.Unalias(t).(*types.TypeParam) + return ok +} + +func asNamed(t types.Type) *types.Named { + n, _ := types.Unalias(t).(*types.Named) + return n +} + +func isTypeLit(t types.Type) bool { + switch types.Unalias(t).(type) { + case *types.Named, *types.TypeParam: + return false + } + return true +} + +// identicalOrigin reports whether x and y originated in the same declaration. +func identicalOrigin(x, y *types.Named) bool { + // TODO(gri) is this correct? + return x.Origin().Obj() == y.Origin().Obj() +} + +func match(x, y types.Type) types.Type { + // Common case: we don't have channels. + if types.Identical(x, y) { + return x + } + + // We may have channels that differ in direction only. + if x, _ := x.(*types.Chan); x != nil { + if y, _ := y.(*types.Chan); y != nil && types.Identical(x.Elem(), y.Elem()) { + // We have channels that differ in direction only. + // If there's an unrestricted channel, select the restricted one. + switch { + case x.Dir() == types.SendRecv: + return y + case y.Dir() == types.SendRecv: + return x + } + } + } + + // types are different + return nil +} + +func coreType(t types.Type) types.Type { + t = types.Unalias(t) + tpar, _ := t.(*types.TypeParam) + if tpar == nil { + return t.Underlying() + } + + return nil +} + +func sameId(obj *types.Var, pkg *types.Package, name string, foldCase bool) bool { + // If we don't care about capitalization, we also ignore packages. 
+ if foldCase && strings.EqualFold(obj.Name(), name) { + return true + } + // spec: + // "Two identifiers are different if they are spelled differently, + // or if they appear in different packages and are not exported. + // Otherwise, they are the same." + if obj.Name() != name { + return false + } + // obj.Name == name + if obj.Exported() { + return true + } + // not exported, so packages must be the same + if obj.Pkg() != nil && pkg != nil { + return obj.Pkg() == pkg + } + return obj.Pkg().Path() == pkg.Path() +} + +// nify implements the core unification algorithm which is an +// adapted version of Checker.identical. For changes to that +// code the corresponding changes should be made here. +// Must not be called directly from outside the unifier. +func (u *unifier) nify(x, y types.Type, mode unifyMode) (result bool) { + u.depth++ + defer func() { + u.depth-- + }() + + // nothing to do if x == y + if x == y || types.Unalias(x) == types.Unalias(y) { + return true + } + + // Stop gap for cases where unification fails. + if u.depth > unificationDepthLimit { + if panicAtUnificationDepthLimit { + panic("unification reached recursion depth limit") + } + return false + } + + // Unification is symmetric, so we can swap the operands. + // Ensure that if we have at least one + // - defined type, make sure one is in y + // - type parameter recorded with u, make sure one is in x + if asNamed(x) != nil || u.asBoundTypeParam(y) != nil { + x, y = y, x + } + + // Unification will fail if we match a defined type against a type literal. + // If we are matching types in an assignment, at the top-level, types with + // the same type structure are permitted as long as at least one of them + // is not a defined type. To accommodate for that possibility, we continue + // unification with the underlying type of a defined type if the other type + // is a type literal. This is controlled by the exact unification mode. 
+ // We also continue if the other type is a basic type because basic types + // are valid underlying types and may appear as core types of type constraints. + // If we exclude them, inferred defined types for type parameters may not + // match against the core types of their constraints (even though they might + // correctly match against some of the types in the constraint's type set). + // Finally, if unification (incorrectly) succeeds by matching the underlying + // type of a defined type against a basic type (because we include basic types + // as type literals here), and if that leads to an incorrectly inferred type, + // we will fail at function instantiation or argument assignment time. + // + // If we have at least one defined type, there is one in y. + if ny := asNamed(y); mode&unifyModeExact == 0 && ny != nil && isTypeLit(x) { + y = ny.Underlying() + // Per the spec, a defined type cannot have an underlying type + // that is a type parameter. + // x and y may be identical now + if x == y || types.Unalias(x) == types.Unalias(y) { + return true + } + } + + // Cases where at least one of x or y is a type parameter recorded with u. + // If we have at least one type parameter, there is one in x. + // If we have exactly one type parameter, because it is in x, + // isTypeLit(x) is false and y was not changed above. In other + // words, if y was a defined type, it is still a defined type + // (relevant for the logic below). + switch px, py := u.asBoundTypeParam(x), u.asBoundTypeParam(y); { + case px != nil && py != nil: + // both x and y are type parameters + if u.join(px, py) { + return true + } + // both x and y have an inferred type - they must match + return u.nify(u.at(px), u.at(py), mode) + + case px != nil: + // x is a type parameter, y is not + if x := u.at(px); x != nil { + // x has an inferred type which must match y + if u.nify(x, y, mode) { + // We have a match, possibly through underlying types. 
+ xi := asInterface(x) + yi := asInterface(y) + xn := asNamed(x) != nil + yn := asNamed(y) != nil + // If we have two interfaces, what to do depends on + // whether they are named and their method sets. + if xi != nil && yi != nil { + // Both types are interfaces. + // If both types are defined types, they must be identical + // because unification doesn't know which type has the "right" name. + if xn && yn { + return types.Identical(x, y) + } + return false + // Below is the original code for reference + + // In all other cases, the method sets must match. + // The types unified so we know that corresponding methods + // match and we can simply compare the number of methods. + // TODO(gri) We may be able to relax this rule and select + // the more general interface. But if one of them is a defined + // type, it's not clear how to choose and whether we introduce + // an order dependency or not. Requiring the same method set + // is conservative. + // if len(xi.typeSet().methods) != len(yi.typeSet().methods) { + // return false + // } + } else if xi != nil || yi != nil { + // One but not both of them are interfaces. + // In this case, either x or y could be viable matches for the corresponding + // type parameter, which means choosing either introduces an order dependence. + // Therefore, we must fail unification (go.dev/issue/60933). + return false + } + // If we have inexact unification and one of x or y is a defined type, select the + // defined type. This ensures that in a series of types, all matching against the + // same type parameter, we infer a defined type if there is one, independent of + // order. Type inference or assignment may fail, which is ok. + // Selecting a defined type, if any, ensures that we don't lose the type name; + // and since we have inexact unification, a value of equally named or matching + // undefined type remains assignable (go.dev/issue/43056). 
+ // + // Similarly, if we have inexact unification and there are no defined types but + // channel types, select a directed channel, if any. This ensures that in a series + // of unnamed types, all matching against the same type parameter, we infer the + // directed channel if there is one, independent of order. + // Selecting a directional channel, if any, ensures that a value of another + // inexactly unifying channel type remains assignable (go.dev/issue/62157). + // + // If we have multiple defined channel types, they are either identical or we + // have assignment conflicts, so we can ignore directionality in this case. + // + // If we have defined and literal channel types, a defined type wins to avoid + // order dependencies. + if mode&unifyModeExact == 0 { + switch { + case xn: + // x is a defined type: nothing to do. + case yn: + // x is not a defined type and y is a defined type: select y. + u.set(px, y) + default: + // Neither x nor y are defined types. + if yc, _ := y.Underlying().(*types.Chan); yc != nil && yc.Dir() != types.SendRecv { + // y is a directed channel type: select y. + u.set(px, y) + } + } + } + return true + } + return false + } + // otherwise, infer type from y + u.set(px, y) + return true + } + + // If u.EnableInterfaceInference is set and we don't require exact unification, + // if both types are interfaces, one interface must have a subset of the + // methods of the other and corresponding method signatures must unify. + // If only one type is an interface, all its methods must be present in the + // other type and corresponding method signatures must unify. + + // Unless we have exact unification, neither x nor y are interfaces now. + // Except for unbound type parameters (see below), x and y must be structurally + // equivalent to unify. + + // If we get here and x or y is a type parameter, they are unbound + // (not recorded with the unifier). 
+ // Ensure that if we have at least one type parameter, it is in x + // (the earlier swap checks for _recorded_ type parameters only). + // This ensures that the switch switches on the type parameter. + // + // TODO(gri) Factor out type parameter handling from the switch. + if isTypeParam(y) { + x, y = y, x + } + + // Type elements (array, slice, etc. elements) use emode for unification. + // Element types must match exactly if the types are used in an assignment. + emode := mode + if mode&unifyModeAssign != 0 { + emode |= unifyModeExact + } + + // Continue with unaliased types but don't lose original alias names, if any (go.dev/issue/67628). + xorig, x := x, types.Unalias(x) + yorig, y := y, types.Unalias(y) + + switch x := x.(type) { + case *types.Basic: + // Basic types are singletons except for the rune and byte + // aliases, thus we cannot solely rely on the x == y check + // above. See also comment in TypeName.IsAlias. + if y, ok := y.(*types.Basic); ok { + return x.Kind() == y.Kind() + } + + case *types.Array: + // Two array types unify if they have the same array length + // and their element types unify. + if y, ok := y.(*types.Array); ok { + // If one or both array lengths are unknown (< 0) due to some error, + // assume they are the same to avoid spurious follow-on errors. + return (x.Len() < 0 || y.Len() < 0 || x.Len() == y.Len()) && u.nify(x.Elem(), y.Elem(), emode) + } + + case *types.Slice: + // Two slice types unify if their element types unify. + if y, ok := y.(*types.Slice); ok { + return u.nify(x.Elem(), y.Elem(), emode) + } + + case *types.Struct: + // Two struct types unify if they have the same sequence of fields, + // and if corresponding fields have the same names, their (field) types unify, + // and they have identical tags. Two embedded fields are considered to have the same + // name. Lower-case field names from different packages are always different. 
+ if y, ok := y.(*types.Struct); ok { + if x.NumFields() == y.NumFields() { + for i := range x.NumFields() { + f := x.Field(i) + g := y.Field(i) + if f.Embedded() != g.Embedded() || + x.Tag(i) != y.Tag(i) || + !sameId(f, g.Pkg(), g.Name(), false) || + !u.nify(f.Type(), g.Type(), emode) { + return false + } + } + return true + } + } + + case *types.Pointer: + // Two pointer types unify if their base types unify. + if y, ok := y.(*types.Pointer); ok { + return u.nify(x.Elem(), y.Elem(), emode) + } + + case *types.Tuple: + // Two tuples types unify if they have the same number of elements + // and the types of corresponding elements unify. + if y, ok := y.(*types.Tuple); ok { + if x.Len() == y.Len() { + if x != nil { + for i := range x.Len() { + v := x.At(i) + w := y.At(i) + if !u.nify(v.Type(), w.Type(), mode) { + return false + } + } + } + return true + } + } + + case *types.Signature: + // Two function types unify if they have the same number of parameters + // and result values, corresponding parameter and result types unify, + // and either both functions are variadic or neither is. + // Parameter and result names are not required to match. + // TODO(gri) handle type parameters or document why we can ignore them. + if y, ok := y.(*types.Signature); ok { + return x.Variadic() == y.Variadic() && + u.nify(x.Params(), y.Params(), emode) && + u.nify(x.Results(), y.Results(), emode) + } + + case *types.Interface: + return false + // Below is the original code + + // Two interface types unify if they have the same set of methods with + // the same names, and corresponding function types unify. + // Lower-case method names from different packages are always different. + // The order of the methods is irrelevant. 
+ // xset := x.typeSet() + // yset := y.typeSet() + // if xset.comparable != yset.comparable { + // return false + // } + // if !xset.terms.equal(yset.terms) { + // return false + // } + // a := xset.methods + // b := yset.methods + // if len(a) == len(b) { + // // Interface types are the only types where cycles can occur + // // that are not "terminated" via named types; and such cycles + // // can only be created via method parameter types that are + // // anonymous interfaces (directly or indirectly) embedding + // // the current interface. Example: + // // + // // type T interface { + // // m() interface{T} + // // } + // // + // // If two such (differently named) interfaces are compared, + // // endless recursion occurs if the cycle is not detected. + // // + // // If x and y were compared before, they must be equal + // // (if they were not, the recursion would have stopped); + // // search the ifacePair stack for the same pair. + // // + // // This is a quadratic algorithm, but in practice these stacks + // // are extremely short (bounded by the nesting depth of interface + // // type declarations that recur via parameter types, an extremely + // // rare occurrence). An alternative implementation might use a + // // "visited" map, but that is probably less efficient overall. + // q := &ifacePair{x, y, p} + // for p != nil { + // if p.identical(q) { + // return true // same pair was compared before + // } + // p = p.prev + // } + // if debug { + // assertSortedMethods(a) + // assertSortedMethods(b) + // } + // for i, f := range a { + // g := b[i] + // if f.Id() != g.Id() || !u.nify(f.typ, g.typ, exact, q) { + // return false + // } + // } + // return true + // } + + case *types.Map: + // Two map types unify if their key and value types unify. 
+ if y, ok := y.(*types.Map); ok { + return u.nify(x.Key(), y.Key(), emode) && u.nify(x.Elem(), y.Elem(), emode) + } + + case *types.Chan: + // Two channel types unify if their value types unify + // and if they have the same direction. + // The channel direction is ignored for inexact unification. + if y, ok := y.(*types.Chan); ok { + return (mode&unifyModeExact == 0 || x.Dir() == y.Dir()) && u.nify(x.Elem(), y.Elem(), emode) + } + + case *types.Named: + // Two named types unify if their type names originate in the same type declaration. + // If they are instantiated, their type argument lists must unify. + if y := asNamed(y); y != nil { + // Check type arguments before origins so they unify + // even if the origins don't match; for better error + // messages (see go.dev/issue/53692). + xargs := x.TypeArgs() + yargs := y.TypeArgs() + if xargs.Len() != yargs.Len() { + return false + } + for i := range xargs.Len() { + xarg := xargs.At(i) + yarg := yargs.At(i) + if !u.nify(xarg, yarg, mode) { + return false + } + } + return identicalOrigin(x, y) + } + + case *types.TypeParam: + // By definition, a valid type argument must be in the type set of + // the respective type constraint. Therefore, the type argument's + // underlying type must be in the set of underlying types of that + // constraint. If there is a single such underlying type, it's the + // constraint's core type. It must match the type argument's under- + // lying type, irrespective of whether the actual type argument, + // which may be a defined type, is actually in the type set (that + // will be determined at instantiation time). + // Thus, if we have the core type of an unbound type parameter, + // we know the structure of the possible types satisfying such + // parameters. Use that core type for further unification + // (see go.dev/issue/50755 for a test case). 
+ if enableCoreTypeUnification { + // Because the core type is always an underlying type, + // unification will take care of matching against a + // defined or literal type automatically. + // If y is also an unbound type parameter, we will end + // up here again with x and y swapped, so we don't + // need to take care of that case separately. + if cx := coreType(x); cx != nil { + // If y is a defined type, it may not match against cx which + // is an underlying type (incl. int, string, etc.). Use assign + // mode here so that the unifier automatically takes under(y) + // if necessary. + return u.nify(cx, yorig, unifyModeAssign) + } + } + // x != y and there's nothing to do + + case nil: + // avoid a crash in case of nil type + + default: + panic(fmt.Sprintf("u.nify(%s, %s, %d)", xorig, yorig, mode)) + } + + return false +} diff --git a/gopls/internal/golang/completion/util.go b/gopls/internal/golang/completion/util.go index a13f5094839..766484e2fc8 100644 --- a/gopls/internal/golang/completion/util.go +++ b/gopls/internal/golang/completion/util.go @@ -277,28 +277,6 @@ func prevStmt(pos token.Pos, path []ast.Node) ast.Stmt { return nil } -// formatZeroValue produces Go code representing the zero value of T. It -// returns the empty string if T is invalid. -func formatZeroValue(T types.Type, qf types.Qualifier) string { - switch u := T.Underlying().(type) { - case *types.Basic: - switch { - case u.Info()&types.IsNumeric > 0: - return "0" - case u.Info()&types.IsString > 0: - return `""` - case u.Info()&types.IsBoolean > 0: - return "false" - default: - return "" - } - case *types.Pointer, *types.Interface, *types.Chan, *types.Map, *types.Slice, *types.Signature: - return "nil" - default: - return types.TypeString(T, qf) + "{}" - } -} - // isBasicKind returns whether t is a basic type of kind k. 
func isBasicKind(t types.Type, k types.BasicInfo) bool { b, _ := t.Underlying().(*types.Basic) diff --git a/gopls/internal/golang/completion/util_test.go b/gopls/internal/golang/completion/util_test.go deleted file mode 100644 index c94d279fbad..00000000000 --- a/gopls/internal/golang/completion/util_test.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2020 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package completion - -import ( - "go/types" - "testing" -) - -func TestFormatZeroValue(t *testing.T) { - tests := []struct { - typ types.Type - want string - }{ - {types.Typ[types.String], `""`}, - {types.Typ[types.Byte], "0"}, - {types.Typ[types.Invalid], ""}, - {types.Universe.Lookup("error").Type(), "nil"}, - } - - for _, test := range tests { - if got := formatZeroValue(test.typ, nil); got != test.want { - t.Errorf("formatZeroValue(%v) = %q, want %q", test.typ, got, test.want) - } - } -} diff --git a/gopls/internal/golang/embeddirective.go b/gopls/internal/golang/embeddirective.go index 3a35f907274..6dd542ddef8 100644 --- a/gopls/internal/golang/embeddirective.go +++ b/gopls/internal/golang/embeddirective.go @@ -35,7 +35,7 @@ func embedDefinition(m *protocol.Mapper, pos protocol.Position) ([]protocol.Loca // Find the first matching file. 
var match string - dir := filepath.Dir(m.URI.Path()) + dir := m.URI.DirPath() err := filepath.WalkDir(dir, func(abs string, d fs.DirEntry, e error) error { if e != nil { return e diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go index 2edda76b6c5..72d718c2faf 100644 --- a/gopls/internal/golang/extract.go +++ b/gopls/internal/golang/extract.go @@ -10,8 +10,10 @@ import ( "go/ast" "go/format" "go/parser" + "go/printer" "go/token" "go/types" + "slices" "sort" "strings" "text/scanner" @@ -21,119 +23,204 @@ import ( "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/gopls/internal/util/safetoken" "golang.org/x/tools/internal/analysisinternal" + "golang.org/x/tools/internal/typesinternal" ) +// extractVariable implements the refactor.extract.{variable,constant} CodeAction command. func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*token.FileSet, *analysis.SuggestedFix, error) { tokFile := fset.File(file.FileStart) - expr, path, ok, err := canExtractVariable(start, end, file) - if !ok { - return nil, nil, fmt.Errorf("extractVariable: cannot extract %s: %v", safetoken.StartPosition(fset, start), err) + expr, path, err := canExtractVariable(info, file, start, end) + if err != nil { + return nil, nil, fmt.Errorf("cannot extract %s: %v", safetoken.StartPosition(fset, start), err) } + constant := info.Types[expr].Value != nil - // Create new AST node for extracted code. + // Generate name(s) for new declaration. + baseName := cond(constant, "k", "x") var lhsNames []string switch expr := expr.(type) { - // TODO: stricter rules for selectorExpr. 
- case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr, - *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: - lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", 0) - lhsNames = append(lhsNames, lhsName) case *ast.CallExpr: tup, ok := info.TypeOf(expr).(*types.Tuple) if !ok { - // If the call expression only has one return value, we can treat it the - // same as our standard extract variable case. - lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", 0) - lhsNames = append(lhsNames, lhsName) - break - } - idx := 0 - for i := 0; i < tup.Len(); i++ { - // Generate a unique variable for each return value. - var lhsName string - lhsName, idx = generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", idx) - lhsNames = append(lhsNames, lhsName) + // conversion or single-valued call: + // treat it the same as our standard extract variable case. + name, _ := freshName(info, file, expr.Pos(), baseName, 0) + lhsNames = append(lhsNames, name) + + } else { + // call with multiple results + idx := 0 + for range tup.Len() { + // Generate a unique variable for each result. + var name string + name, idx = freshName(info, file, expr.Pos(), baseName, idx) + lhsNames = append(lhsNames, name) + } } + default: - return nil, nil, fmt.Errorf("cannot extract %T", expr) + // TODO: stricter rules for selectorExpr. + name, _ := freshName(info, file, expr.Pos(), baseName, 0) + lhsNames = append(lhsNames, name) } // TODO: There is a bug here: for a variable declared in a labeled // switch/for statement it returns the for/switch statement itself - // which produces the below code which is a compiler error e.g. - // label: - // switch r1 := r() { ... break label ... } + // which produces the below code which is a compiler error. e.g. + // label: + // switch r1 := r() { ... break label ... } // On extracting "r()" to a variable - // label: - // x := r() - // switch r1 := x { ... break label ... 
} // compiler error - insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path) - if insertBeforeStmt == nil { - return nil, nil, fmt.Errorf("cannot find location to insert extraction") - } - indent, err := calculateIndentation(src, tokFile, insertBeforeStmt) - if err != nil { - return nil, nil, err + // label: + // x := r() + // switch r1 := x { ... break label ... } // compiler error + // + // TODO(golang/go#70563): Another bug: extracting the + // expression to the recommended place may cause it to migrate + // across one or more declarations that it references. + // + // Before: + // if x := 1; cond { + // } else if y := «x + 2»; cond { + // } + // + // After: + // x1 := x + 2 // error: undefined x + // if x := 1; cond { + // } else if y := x1; cond { + // } + var ( + insertPos token.Pos + indentation string + stmtOK bool // ok to use ":=" instead of var/const decl? + ) + if before := analysisinternal.StmtToInsertVarBefore(path); before != nil { + // Within function: compute appropriate statement indentation. + indent, err := calculateIndentation(src, tokFile, before) + if err != nil { + return nil, nil, err + } + insertPos = before.Pos() + indentation = "\n" + indent + + // Currently, we always extract a constant expression + // to a const declaration (and logic in CodeAction + // assumes that we do so); this is conservative because + // it preserves its constant-ness. + // + // In future, constant expressions used only in + // contexts where constant-ness isn't important could + // be profitably extracted to a var declaration or := + // statement, especially if the latter is the Init of + // an {If,For,Switch}Stmt. + stmtOK = !constant + } else { + // Outside any statement: insert before the current + // declaration, without indentation. 
+ currentDecl := path[len(path)-2] + insertPos = currentDecl.Pos() + indentation = "\n" } - newLineIndent := "\n" + indent - lhs := strings.Join(lhsNames, ", ") - assignStmt := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(lhs)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{expr}, + // Create statement to declare extracted var/const. + // + // TODO(adonovan): beware the const decls are not valid short + // statements, so if fixing #70563 causes + // StmtToInsertVarBefore to evolve to permit declarations in + // the "pre" part of an IfStmt, like so: + // Before: + // if cond { + // } else if «1 + 2» > 0 { + // } + // After: + // if x := 1 + 2; cond { + // } else if x > 0 { + // } + // then it will need to become aware that this is invalid + // for constants. + // + // Conversely, a short var decl stmt is not valid at top level, + // so when we fix #70665, we'll need to use a var decl. + var newNode ast.Node + if !stmtOK { + // var/const x1, ..., xn = expr + var names []*ast.Ident + for _, name := range lhsNames { + names = append(names, ast.NewIdent(name)) + } + newNode = &ast.GenDecl{ + Tok: cond(constant, token.CONST, token.VAR), + Specs: []ast.Spec{ + &ast.ValueSpec{ + Names: names, + Values: []ast.Expr{expr}, + }, + }, + } + + } else { + // var: x1, ... xn := expr + var lhs []ast.Expr + for _, name := range lhsNames { + lhs = append(lhs, ast.NewIdent(name)) + } + newNode = &ast.AssignStmt{ + Tok: token.DEFINE, + Lhs: lhs, + Rhs: []ast.Expr{expr}, + } } + + // Format and indent the declaration. var buf bytes.Buffer - if err := format.Node(&buf, fset, assignStmt); err != nil { + if err := format.Node(&buf, fset, newNode); err != nil { return nil, nil, err } - assignment := strings.ReplaceAll(buf.String(), "\n", newLineIndent) + newLineIndent + // TODO(adonovan): not sound for `...` string literals containing newlines. 
+ assignment := strings.ReplaceAll(buf.String(), "\n", indentation) + indentation return fset, &analysis.SuggestedFix{ TextEdits: []analysis.TextEdit{ { - Pos: insertBeforeStmt.Pos(), - End: insertBeforeStmt.Pos(), + Pos: insertPos, + End: insertPos, NewText: []byte(assignment), }, { Pos: start, End: end, - NewText: []byte(lhs), + NewText: []byte(strings.Join(lhsNames, ", ")), }, }, }, nil } // canExtractVariable reports whether the code in the given range can be -// extracted to a variable. -func canExtractVariable(start, end token.Pos, file *ast.File) (ast.Expr, []ast.Node, bool, error) { +// extracted to a variable (or constant). +func canExtractVariable(info *types.Info, file *ast.File, start, end token.Pos) (ast.Expr, []ast.Node, error) { if start == end { - return nil, nil, false, fmt.Errorf("start and end are equal") + return nil, nil, fmt.Errorf("empty selection") + } + path, exact := astutil.PathEnclosingInterval(file, start, end) + if !exact { + return nil, nil, fmt.Errorf("selection is not an expression") } - path, _ := astutil.PathEnclosingInterval(file, start, end) if len(path) == 0 { - return nil, nil, false, fmt.Errorf("no path enclosing interval") + return nil, nil, bug.Errorf("no path enclosing interval") } for _, n := range path { if _, ok := n.(*ast.ImportSpec); ok { - return nil, nil, false, fmt.Errorf("cannot extract variable in an import block") + return nil, nil, fmt.Errorf("cannot extract variable or constant in an import block") } } - node := path[0] - if start != node.Pos() || end != node.End() { - return nil, nil, false, fmt.Errorf("range does not map to an AST node") - } - expr, ok := node.(ast.Expr) + expr, ok := path[0].(ast.Expr) if !ok { - return nil, nil, false, fmt.Errorf("node is not an expression") + return nil, nil, fmt.Errorf("selection is not an expression") // e.g. 
statement } - switch expr.(type) { - case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.CallExpr, - *ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: - return expr, path, true, nil + if tv, ok := info.Types[expr]; !ok || !tv.IsValue() || tv.Type == nil || tv.HasOk() { + // e.g. type, builtin, x.(type), 2-valued m[k], or ill-typed + return nil, nil, fmt.Errorf("selection is not a single-valued expression") } - return nil, nil, false, fmt.Errorf("cannot extract an %T to a variable", expr) + return expr, path, nil } // Calculate indentation for insertion. @@ -149,22 +236,42 @@ func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast. return string(content[lineOffset:stmtOffset]), nil } -// generateAvailableIdentifier adjusts the new function name until there are no collisions in scope. -// Possible collisions include other function and variable names. Returns the next index to check for prefix. -func generateAvailableIdentifier(pos token.Pos, path []ast.Node, pkg *types.Package, info *types.Info, prefix string, idx int) (string, int) { - scopes := CollectScopes(info, path, pos) - scopes = append(scopes, pkg.Scope()) - return generateIdentifier(idx, prefix, func(name string) bool { - for _, scope := range scopes { - if scope != nil && scope.Lookup(name) != nil { - return true +// freshName returns an identifier based on prefix (perhaps with a +// numeric suffix) that is not in scope at the specified position +// within the file. It returns the next numeric suffix to use. +func freshName(info *types.Info, file *ast.File, pos token.Pos, prefix string, idx int) (string, int) { + scope := info.Scopes[file].Innermost(pos) + return generateName(idx, prefix, func(name string) bool { + obj, _ := scope.LookupParent(name, pos) + return obj != nil + }) +} + +// freshNameOutsideRange is like [freshName], but ignores names +// declared between start and end for the purposes of detecting conflicts. 
+// +// This is used for function extraction, where [start, end) will be extracted +// to a new scope. +func freshNameOutsideRange(info *types.Info, file *ast.File, pos, start, end token.Pos, prefix string, idx int) (string, int) { + scope := info.Scopes[file].Innermost(pos) + return generateName(idx, prefix, func(name string) bool { + // Only report a collision if the object declaration + // was outside the extracted range. + for scope != nil { + obj, declScope := scope.LookupParent(name, pos) + if obj == nil { + return false // undeclared } + if !(start <= obj.Pos() && obj.Pos() < end) { + return true // declared outside ignored range + } + scope = declScope.Parent() } return false }) } -func generateIdentifier(idx int, prefix string, hasCollision func(string) bool) (string, int) { +func generateName(idx int, prefix string, hasCollision func(string) bool) (string, int) { name := prefix if idx != 0 { name += fmt.Sprintf("%d", idx) @@ -181,7 +288,7 @@ func generateIdentifier(idx int, prefix string, hasCollision func(string) bool) type returnVariable struct { // name is the identifier that is used on the left-hand side of the call to // the extracted function. - name ast.Expr + name *ast.Ident // decl is the declaration of the variable. It is used in the type signature of the // extracted function and for variable declarations. 
decl *ast.Field @@ -335,7 +442,7 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte // The blank identifier is always a local variable continue } - typ := analysisinternal.TypeExpr(file, pkg, v.obj.Type()) + typ := typesinternal.TypeExpr(file, pkg, v.obj.Type()) if typ == nil { return nil, nil, fmt.Errorf("nil AST expression for type: %v", v.obj.Name()) } @@ -423,7 +530,8 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte return nil, nil, err } selection := src[startOffset:endOffset] - extractedBlock, err := parseBlockStmt(fset, selection) + + extractedBlock, extractedComments, err := parseStmts(fset, selection) if err != nil { return nil, nil, err } @@ -516,7 +624,7 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte // statements in the selection. Update the type signature of the extracted // function and construct the if statement that will be inserted in the enclosing // function. - retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, start, hasNonNestedReturn) + retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, start, end, hasNonNestedReturn) if err != nil { return nil, nil, err } @@ -544,26 +652,53 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte if canDefine { sym = token.DEFINE } - var name, funName string + var funName string if isMethod { - name = "newMethod" // TODO(suzmue): generate a name that does not conflict for "newMethod". - funName = name + funName = "newMethod" } else { - name = "newFunction" - funName, _ = generateAvailableIdentifier(start, path, pkg, info, name, 0) + funName, _ = freshName(info, file, start, "newFunction", 0) } extractedFunCall := generateFuncCall(hasNonNestedReturn, hasReturnValues, params, append(returns, getNames(retVars)...), funName, sym, receiverName) - // Build the extracted function. 
+ // Create variable declarations for any identifiers that need to be initialized prior to + // calling the extracted function. We do not manually initialize variables if every return + // value is uninitialized. We can use := to initialize the variables in this situation. + var declarations []ast.Stmt + if canDefineCount != len(returns) { + declarations = initializeVars(uninitialized, retVars, seenUninitialized, seenVars) + } + + var declBuf, replaceBuf, newFuncBuf, ifBuf, commentBuf bytes.Buffer + if err := format.Node(&declBuf, fset, declarations); err != nil { + return nil, nil, err + } + if err := format.Node(&replaceBuf, fset, extractedFunCall); err != nil { + return nil, nil, err + } + if ifReturn != nil { + if err := format.Node(&ifBuf, fset, ifReturn); err != nil { + return nil, nil, err + } + } + + // Build the extracted function. We format the function declaration and body + // separately, so that comments are printed relative to the extracted + // BlockStmt. + // + // In other words, extractedBlock and extractedComments were parsed from a + // synthetic function declaration of the form func _() { ... }. If we now + // print the real function declaration, the length of the signature will have + // grown, causing some comment positions to be computed as inside the + // signature itself. newFunc := &ast.FuncDecl{ Name: ast.NewIdent(funName), Type: &ast.FuncType{ Params: &ast.FieldList{List: paramTypes}, Results: &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)}, }, - Body: extractedBlock, + // Body handled separately -- see above. } if isMethod { var names []*ast.Ident @@ -577,39 +712,20 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte }}, } } - - // Create variable declarations for any identifiers that need to be initialized prior to - // calling the extracted function. We do not manually initialize variables if every return - // value is uninitialized. 
We can use := to initialize the variables in this situation. - var declarations []ast.Stmt - if canDefineCount != len(returns) { - declarations = initializeVars(uninitialized, retVars, seenUninitialized, seenVars) - } - - var declBuf, replaceBuf, newFuncBuf, ifBuf, commentBuf bytes.Buffer - if err := format.Node(&declBuf, fset, declarations); err != nil { + if err := format.Node(&newFuncBuf, fset, newFunc); err != nil { return nil, nil, err } - if err := format.Node(&replaceBuf, fset, extractedFunCall); err != nil { + // Write a space between the end of the function signature and opening '{'. + if err := newFuncBuf.WriteByte(' '); err != nil { return nil, nil, err } - if ifReturn != nil { - if err := format.Node(&ifBuf, fset, ifReturn); err != nil { - return nil, nil, err - } + commentedNode := &printer.CommentedNode{ + Node: extractedBlock, + Comments: extractedComments, } - if err := format.Node(&newFuncBuf, fset, newFunc); err != nil { + if err := format.Node(&newFuncBuf, fset, commentedNode); err != nil { return nil, nil, err } - // Find all the comments within the range and print them to be put somewhere. - // TODO(suzmue): print these in the extracted function at the correct place. - for _, cg := range file.Comments { - if cg.Pos().IsValid() && cg.Pos() < end && cg.Pos() >= start { - for _, c := range cg.List { - fmt.Fprintln(&commentBuf, c.Text) - } - } - } // We're going to replace the whole enclosing function, // so preserve the text before and after the selected block. @@ -1161,37 +1277,37 @@ func varOverridden(info *types.Info, firstUse *ast.Ident, obj types.Object, isFr return isOverriden } -// parseBlockStmt generates an AST file from the given text. We then return the portion of the -// file that represents the text. -func parseBlockStmt(fset *token.FileSet, src []byte) (*ast.BlockStmt, error) { +// parseStmts parses the specified source (a list of statements) and +// returns them as a BlockStmt along with any associated comments. 
+func parseStmts(fset *token.FileSet, src []byte) (*ast.BlockStmt, []*ast.CommentGroup, error) { text := "package main\nfunc _() { " + string(src) + " }" - extract, err := parser.ParseFile(fset, "", text, parser.SkipObjectResolution) + file, err := parser.ParseFile(fset, "", text, parser.ParseComments|parser.SkipObjectResolution) if err != nil { - return nil, err + return nil, nil, err } - if len(extract.Decls) == 0 { - return nil, fmt.Errorf("parsed file does not contain any declarations") + if len(file.Decls) != 1 { + return nil, nil, fmt.Errorf("got %d declarations, want 1", len(file.Decls)) } - decl, ok := extract.Decls[0].(*ast.FuncDecl) + decl, ok := file.Decls[0].(*ast.FuncDecl) if !ok { - return nil, fmt.Errorf("parsed file does not contain expected function declaration") + return nil, nil, bug.Errorf("parsed file does not contain expected function declaration") } if decl.Body == nil { - return nil, fmt.Errorf("extracted function has no body") + return nil, nil, bug.Errorf("extracted function has no body") } - return decl.Body, nil + return decl.Body, file.Comments, nil } // generateReturnInfo generates the information we need to adjust the return statements and // signature of the extracted function. We prepare names, signatures, and "zero values" that // represent the new variables. We also use this information to construct the if statement that // is inserted below the call to the extracted function. -func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.Node, file *ast.File, info *types.Info, pos token.Pos, hasNonNestedReturns bool) ([]*returnVariable, *ast.IfStmt, error) { +func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.Node, file *ast.File, info *types.Info, start, end token.Pos, hasNonNestedReturns bool) ([]*returnVariable, *ast.IfStmt, error) { var retVars []*returnVariable var cond *ast.Ident if !hasNonNestedReturns { // Generate information for the added bool value. 
- name, _ := generateAvailableIdentifier(pos, path, pkg, info, "shouldReturn", 0) + name, _ := freshNameOutsideRange(info, file, path[0].Pos(), start, end, "shouldReturn", 0) cond = &ast.Ident{Name: name} retVars = append(retVars, &returnVariable{ name: cond, @@ -1201,24 +1317,43 @@ func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast. } // Generate information for the values in the return signature of the enclosing function. if enclosing.Results != nil { - idx := 0 + nameIdx := make(map[string]int) // last integral suffixes of generated names for _, field := range enclosing.Results.List { typ := info.TypeOf(field.Type) if typ == nil { return nil, nil, fmt.Errorf( "failed type conversion, AST expression: %T", field.Type) } - expr := analysisinternal.TypeExpr(file, pkg, typ) + expr := typesinternal.TypeExpr(file, pkg, typ) if expr == nil { return nil, nil, fmt.Errorf("nil AST expression") } - var name string - name, idx = generateAvailableIdentifier(pos, path, pkg, info, "returnValue", idx) - retVars = append(retVars, &returnVariable{ - name: ast.NewIdent(name), - decl: &ast.Field{Type: expr}, - zeroVal: analysisinternal.ZeroValue(file, pkg, typ), - }) + names := []string{""} + if len(field.Names) > 0 { + names = nil + for _, n := range field.Names { + names = append(names, n.Name) + } + } + for _, name := range names { + bestName := "result" + if name != "" && name != "_" { + bestName = name + } else if n, ok := varNameForType(typ); ok { + bestName = n + } + retName, idx := freshNameOutsideRange(info, file, path[0].Pos(), start, end, bestName, nameIdx[bestName]) + nameIdx[bestName] = idx + z := typesinternal.ZeroExpr(file, pkg, typ) + if z == nil { + return nil, nil, fmt.Errorf("can't generate zero value for %T", typ) + } + retVars = append(retVars, &returnVariable{ + name: ast.NewIdent(retName), + decl: &ast.Field{Type: expr}, + zeroVal: z, + }) + } } } var ifReturn *ast.IfStmt @@ -1235,6 +1370,48 @@ func generateReturnInfo(enclosing 
*ast.FuncType, pkg *types.Package, path []ast.
 	return retVars, ifReturn, nil
 }
 
+type objKey struct{ pkg, name string }
+
+// conventionalVarNames specifies conventional names for variables with various
+// standard library types.
+//
+// Keep this up to date with completion.conventionalAcronyms.
+//
+// TODO(rfindley): consider factoring out a "conventions" library.
+var conventionalVarNames = map[objKey]string{
+	{"", "error"}:              "err",
+	{"context", "Context"}:     "ctx",
+	{"sql", "Tx"}:              "tx",
+	{"http", "ResponseWriter"}: "rw", // Note: same as [AbbreviateVarName].
+}
+
+// varNameForType chooses a "good" name for a variable with the given type,
+// if possible. Otherwise, it returns "", false.
+//
+// For special types, it uses known conventional names.
+func varNameForType(t types.Type) (string, bool) {
+	var typeName string
+	if tn, ok := t.(interface{ Obj() *types.TypeName }); ok {
+		obj := tn.Obj()
+		k := objKey{name: obj.Name()}
+		if obj.Pkg() != nil {
+			k.pkg = obj.Pkg().Name()
+		}
+		if name, ok := conventionalVarNames[k]; ok {
+			return name, true
+		}
+		typeName = obj.Name()
+	} else if b, ok := t.(*types.Basic); ok {
+		typeName = b.Name()
+	}
+
+	if typeName == "" {
+		return "", false
+	}
+
+	return AbbreviateVarName(typeName), true
+}
+
 // adjustReturnStatements adds "zero values" of the given types to each return statement
 // in the given AST node.
func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object]ast.Expr, file *ast.File, pkg *types.Package, extractedBlock *ast.BlockStmt) error { @@ -1246,12 +1423,11 @@ func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object] if typ != returnType.Type { continue } - val = analysisinternal.ZeroValue(file, pkg, obj.Type()) + val = typesinternal.ZeroExpr(file, pkg, obj.Type()) break } if val == nil { - return fmt.Errorf( - "could not find matching AST expression for %T", returnType.Type) + return fmt.Errorf("could not find matching AST expression for %T", returnType.Type) } zeroVals = append(zeroVals, val) } @@ -1266,7 +1442,7 @@ func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object] return false } if n, ok := n.(*ast.ReturnStmt); ok { - n.Results = append(zeroVals, n.Results...) + n.Results = slices.Concat(zeroVals, n.Results) return false } return true @@ -1342,9 +1518,8 @@ func initializeVars(uninitialized []types.Object, retVars []*returnVariable, see // Each variable added from a return statement in the selection // must be initialized. for i, retVar := range retVars { - n := retVar.name.(*ast.Ident) valSpec := &ast.ValueSpec{ - Names: []*ast.Ident{n}, + Names: []*ast.Ident{retVar.name}, Type: retVars[i].decl.Type, } genDecl := &ast.GenDecl{ @@ -1382,3 +1557,11 @@ func getDecls(retVars []*returnVariable) []*ast.Field { } return decls } + +func cond[T any](cond bool, t, f T) T { + if cond { + return t + } else { + return f + } +} diff --git a/gopls/internal/golang/extracttofile.go b/gopls/internal/golang/extracttofile.go index ae26738a5c3..cda9cd51e6d 100644 --- a/gopls/internal/golang/extracttofile.go +++ b/gopls/internal/golang/extracttofile.go @@ -97,6 +97,8 @@ func ExtractToNewFile(ctx context.Context, snapshot *cache.Snapshot, fh file.Han if !ok { return nil, bug.Errorf("invalid selection") } + pgf.CheckPos(start) // #70553 + // Inv: start is valid wrt pgf.Tok. 
// select trailing empty lines offset, err := safetoken.Offset(pgf.Tok, end) @@ -104,7 +106,10 @@ func ExtractToNewFile(ctx context.Context, snapshot *cache.Snapshot, fh file.Han return nil, err } rest := pgf.Src[offset:] - end += token.Pos(len(rest) - len(bytes.TrimLeft(rest, " \t\n"))) + spaces := len(rest) - len(bytes.TrimLeft(rest, " \t\n")) + end += token.Pos(spaces) + pgf.CheckPos(end) // #70553 + // Inv: end is valid wrt pgf.Tok. replaceRange, err := pgf.PosRange(start, end) if err != nil { @@ -133,6 +138,26 @@ func ExtractToNewFile(ctx context.Context, snapshot *cache.Snapshot, fh file.Han } var buf bytes.Buffer + if c := copyrightComment(pgf.File); c != nil { + start, end, err := pgf.NodeOffsets(c) + if err != nil { + return nil, err + } + buf.Write(pgf.Src[start:end]) + // One empty line between copyright header and following. + buf.WriteString("\n\n") + } + + if c := buildConstraintComment(pgf.File); c != nil { + start, end, err := pgf.NodeOffsets(c) + if err != nil { + return nil, err + } + buf.Write(pgf.Src[start:end]) + // One empty line between build constraint and following. + buf.WriteString("\n\n") + } + fmt.Fprintf(&buf, "package %s\n", pgf.File.Name.Name) if len(adds) > 0 { buf.WriteString("import (") @@ -146,15 +171,15 @@ func ExtractToNewFile(ctx context.Context, snapshot *cache.Snapshot, fh file.Han buf.WriteString(")\n") } - newFile, err := chooseNewFile(ctx, snapshot, pgf.URI.Dir().Path(), firstSymbol) + newFile, err := chooseNewFile(ctx, snapshot, pgf.URI.DirPath(), firstSymbol) if err != nil { return nil, fmt.Errorf("%s: %w", errorPrefix, err) } fileStart := pgf.File.FileStart + pgf.CheckPos(fileStart) // #70553 buf.Write(pgf.Src[start-fileStart : end-fileStart]) - // TODO: attempt to duplicate the copyright header, if any. 
newFileContent, err := format.Source(buf.Bytes()) if err != nil { return nil, err @@ -202,31 +227,42 @@ func selectedToplevelDecls(pgf *parsego.File, start, end token.Pos) (token.Pos, firstName := "" for _, decl := range pgf.File.Decls { if posRangeIntersects(start, end, decl.Pos(), decl.End()) { - var id *ast.Ident - switch v := decl.(type) { + var ( + comment *ast.CommentGroup // (include comment preceding decl) + id *ast.Ident + ) + switch decl := decl.(type) { case *ast.BadDecl: return 0, 0, "", false + case *ast.FuncDecl: // if only selecting keyword "func" or function name, extend selection to the // whole function - if posRangeContains(v.Pos(), v.Name.End(), start, end) { - start, end = v.Pos(), v.End() + if posRangeContains(decl.Pos(), decl.Name.End(), start, end) { + pgf.CheckNode(decl) // #70553 + start, end = decl.Pos(), decl.End() + // Inv: start, end are valid wrt pgf.Tok. } - id = v.Name + comment = decl.Doc + id = decl.Name + case *ast.GenDecl: // selection cannot intersect an import declaration - if v.Tok == token.IMPORT { + if decl.Tok == token.IMPORT { return 0, 0, "", false } // if only selecting keyword "type", "const", or "var", extend selection to the // whole declaration - if v.Tok == token.TYPE && posRangeContains(v.Pos(), v.Pos()+token.Pos(len("type")), start, end) || - v.Tok == token.CONST && posRangeContains(v.Pos(), v.Pos()+token.Pos(len("const")), start, end) || - v.Tok == token.VAR && posRangeContains(v.Pos(), v.Pos()+token.Pos(len("var")), start, end) { - start, end = v.Pos(), v.End() + if decl.Tok == token.TYPE && posRangeContains(decl.Pos(), decl.Pos()+token.Pos(len("type")), start, end) || + decl.Tok == token.CONST && posRangeContains(decl.Pos(), decl.Pos()+token.Pos(len("const")), start, end) || + decl.Tok == token.VAR && posRangeContains(decl.Pos(), decl.Pos()+token.Pos(len("var")), start, end) { + pgf.CheckNode(decl) // #70553 + start, end = decl.Pos(), decl.End() + // Inv: start, end are valid wrt pgf.Tok. 
} - if len(v.Specs) > 0 { - switch spec := v.Specs[0].(type) { + comment = decl.Doc + if len(decl.Specs) > 0 { + switch spec := decl.Specs[0].(type) { case *ast.TypeSpec: id = spec.Name case *ast.ValueSpec: @@ -242,16 +278,10 @@ func selectedToplevelDecls(pgf *parsego.File, start, end token.Pos) (token.Pos, // may be "_" firstName = id.Name } - // extends selection to docs comments - var c *ast.CommentGroup - switch decl := decl.(type) { - case *ast.GenDecl: - c = decl.Doc - case *ast.FuncDecl: - c = decl.Doc - } - if c != nil && c.Pos() < start { - start = c.Pos() + if comment != nil && comment.Pos() < start { + pgf.CheckNode(comment) // #70553 + start = comment.Pos() + // Inv: start is valid wrt pgf.Tok. } } } diff --git a/gopls/internal/golang/fix.go b/gopls/internal/golang/fix.go index a20658fce7c..f88343f029c 100644 --- a/gopls/internal/golang/fix.go +++ b/gopls/internal/golang/fix.go @@ -14,7 +14,6 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/gopls/internal/analysis/embeddirective" "golang.org/x/tools/gopls/internal/analysis/fillstruct" - "golang.org/x/tools/gopls/internal/analysis/undeclaredname" "golang.org/x/tools/gopls/internal/analysis/unusedparams" "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/cache/parsego" @@ -36,7 +35,8 @@ import ( // The supplied token positions (start, end) must belong to // pkg.FileSet(), and the returned positions // (SuggestedFix.TextEdits[*].{Pos,End}) must belong to the returned -// FileSet. +// FileSet, which is not necessarily the same. +// (See [insertDeclsAfter] for explanation.) // // A fixer may return (nil, nil) if no fix is available. type fixer func(ctx context.Context, s *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) @@ -58,13 +58,14 @@ func singleFile(fixer1 singleFileFixer) fixer { // Names of ApplyFix.Fix created directly by the CodeAction handler. 
const ( - fixExtractVariable = "extract_variable" + fixExtractVariable = "extract_variable" // (or constant) fixExtractFunction = "extract_function" fixExtractMethod = "extract_method" fixInlineCall = "inline_call" fixInvertIfCondition = "invert_if_condition" fixSplitLines = "split_lines" fixJoinLines = "join_lines" + fixCreateUndeclared = "create_undeclared" fixMissingInterfaceMethods = "stub_missing_interface_method" fixMissingCalledFunction = "stub_missing_called_function" ) @@ -91,7 +92,7 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file // NarrowestPackageForFile/RangePos/suggestedFixToEdits // steps.) if fix == unusedparams.FixCategory { - return RemoveUnusedParameter(ctx, fh, rng, snapshot) + return removeParam(ctx, snapshot, fh, rng) } fixers := map[string]fixer{ @@ -99,7 +100,6 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file // These match the Diagnostic.Category. embeddirective.FixCategory: addEmbedImport, fillstruct.FixCategory: singleFile(fillstruct.SuggestedFix), - undeclaredname.FixCategory: singleFile(undeclaredname.SuggestedFix), // Ad-hoc fixers: these are used when the command is // constructed directly by logic in server/code_action. 
@@ -110,6 +110,7 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file fixInvertIfCondition: singleFile(invertIfCondition), fixSplitLines: singleFile(splitLines), fixJoinLines: singleFile(joinLines), + fixCreateUndeclared: singleFile(CreateUndeclared), fixMissingInterfaceMethods: stubMissingInterfaceMethodsFixer, fixMissingCalledFunction: stubMissingCalledFunctionFixer, } diff --git a/gopls/internal/golang/folding_range.go b/gopls/internal/golang/folding_range.go index 85faea5e31a..c61802d1b58 100644 --- a/gopls/internal/golang/folding_range.go +++ b/gopls/internal/golang/folding_range.go @@ -5,6 +5,7 @@ package golang import ( + "bytes" "context" "go/ast" "go/token" @@ -73,15 +74,12 @@ func foldingRangeFunc(pgf *parsego.File, n ast.Node, lineFoldingOnly bool) *Fold // TODO(suzmue): include trailing empty lines before the closing // parenthesis/brace. var kind protocol.FoldingRangeKind + // start and end define the range of content to fold away. var start, end token.Pos switch n := n.(type) { case *ast.BlockStmt: // Fold between positions of or lines between "{" and "}". - var startList, endList token.Pos - if num := len(n.List); num != 0 { - startList, endList = n.List[0].Pos(), n.List[num-1].End() - } - start, end = validLineFoldingRange(pgf.Tok, n.Lbrace, n.Rbrace, startList, endList, lineFoldingOnly) + start, end = getLineFoldingRange(pgf, n.Lbrace, n.Rbrace, lineFoldingOnly) case *ast.CaseClause: // Fold from position of ":" to end. start, end = n.Colon+1, n.End() @@ -89,26 +87,18 @@ func foldingRangeFunc(pgf *parsego.File, n ast.Node, lineFoldingOnly bool) *Fold // Fold from position of ":" to end. start, end = n.Colon+1, n.End() case *ast.CallExpr: - // Fold from position of "(" to position of ")". - start, end = n.Lparen+1, n.Rparen + // Fold between positions of or lines between "(" and ")". 
+ start, end = getLineFoldingRange(pgf, n.Lparen, n.Rparen, lineFoldingOnly) case *ast.FieldList: // Fold between positions of or lines between opening parenthesis/brace and closing parenthesis/brace. - var startList, endList token.Pos - if num := len(n.List); num != 0 { - startList, endList = n.List[0].Pos(), n.List[num-1].End() - } - start, end = validLineFoldingRange(pgf.Tok, n.Opening, n.Closing, startList, endList, lineFoldingOnly) + start, end = getLineFoldingRange(pgf, n.Opening, n.Closing, lineFoldingOnly) case *ast.GenDecl: // If this is an import declaration, set the kind to be protocol.Imports. if n.Tok == token.IMPORT { kind = protocol.Imports } // Fold between positions of or lines between "(" and ")". - var startSpecs, endSpecs token.Pos - if num := len(n.Specs); num != 0 { - startSpecs, endSpecs = n.Specs[0].Pos(), n.Specs[num-1].End() - } - start, end = validLineFoldingRange(pgf.Tok, n.Lparen, n.Rparen, startSpecs, endSpecs, lineFoldingOnly) + start, end = getLineFoldingRange(pgf, n.Lparen, n.Rparen, lineFoldingOnly) case *ast.BasicLit: // Fold raw string literals from position of "`" to position of "`". if n.Kind == token.STRING && len(n.Value) >= 2 && n.Value[0] == '`' && n.Value[len(n.Value)-1] == '`' { @@ -116,24 +106,25 @@ func foldingRangeFunc(pgf *parsego.File, n ast.Node, lineFoldingOnly bool) *Fold } case *ast.CompositeLit: // Fold between positions of or lines between "{" and "}". - var startElts, endElts token.Pos - if num := len(n.Elts); num != 0 { - startElts, endElts = n.Elts[0].Pos(), n.Elts[num-1].End() - } - start, end = validLineFoldingRange(pgf.Tok, n.Lbrace, n.Rbrace, startElts, endElts, lineFoldingOnly) + start, end = getLineFoldingRange(pgf, n.Lbrace, n.Rbrace, lineFoldingOnly) } // Check that folding positions are valid. if !start.IsValid() || !end.IsValid() { return nil } + if start == end { + // Nothing to fold. + return nil + } // in line folding mode, do not fold if the start and end lines are the same. 
if lineFoldingOnly && safetoken.Line(pgf.Tok, start) == safetoken.Line(pgf.Tok, end) { return nil } mrng, err := pgf.PosMappedRange(start, end) if err != nil { - bug.Errorf("%w", err) // can't happen + bug.Reportf("failed to create mapped range: %s", err) // can't happen + return nil } return &FoldingRangeInfo{ MappedRange: mrng, @@ -141,26 +132,67 @@ func foldingRangeFunc(pgf *parsego.File, n ast.Node, lineFoldingOnly bool) *Fold } } -// validLineFoldingRange returns start and end token.Pos for folding range if the range is valid. -// returns token.NoPos otherwise, which fails token.IsValid check -func validLineFoldingRange(tokFile *token.File, open, close, start, end token.Pos, lineFoldingOnly bool) (token.Pos, token.Pos) { - if lineFoldingOnly { - if !open.IsValid() || !close.IsValid() { - return token.NoPos, token.NoPos - } +// getLineFoldingRange returns the folding range for nodes with parentheses/braces/brackets +// that potentially can take up multiple lines. +func getLineFoldingRange(pgf *parsego.File, open, close token.Pos, lineFoldingOnly bool) (token.Pos, token.Pos) { + if !open.IsValid() || !close.IsValid() { + return token.NoPos, token.NoPos + } + if open+1 == close { + // Nothing to fold: (), {} or []. + return token.NoPos, token.NoPos + } + + if !lineFoldingOnly { + // Can fold between opening and closing parenthesis/brace + // even if they are on the same line. + return open + 1, close + } - // Don't want to fold if the start/end is on the same line as the open/close - // as an example, the example below should *not* fold: - // var x = [2]string{"d", - // "e" } - if safetoken.Line(tokFile, open) == safetoken.Line(tokFile, start) || - safetoken.Line(tokFile, close) == safetoken.Line(tokFile, end) { - return token.NoPos, token.NoPos + // Clients with "LineFoldingOnly" set to true can fold only full lines. 
+ // So, we return a folding range only when the closing parenthesis/brace + // and the end of the last argument/statement/element are on different lines. + // + // We could skip the check for the opening parenthesis/brace and start of + // the first argument/statement/element. For example, the following code + // + // var x = []string{"a", + // "b", + // "c" } + // + // can be folded to + // + // var x = []string{"a", ... + // "c" } + // + // However, this might look confusing. So, check the lines of "open" and + // "start" positions as well. + + // isOnlySpaceBetween returns true if there are only space characters between "from" and "to". + isOnlySpaceBetween := func(from token.Pos, to token.Pos) bool { + start, end, err := safetoken.Offsets(pgf.Tok, from, to) + if err != nil { + bug.Reportf("failed to get offsets: %s", err) // can't happen + return false } + return len(bytes.TrimSpace(pgf.Src[start:end])) == 0 + } - return open + 1, end + nextLine := safetoken.Line(pgf.Tok, open) + 1 + if nextLine > pgf.Tok.LineCount() { + return token.NoPos, token.NoPos } - return open + 1, close + nextLineStart := pgf.Tok.LineStart(nextLine) + if !isOnlySpaceBetween(open+1, nextLineStart) { + return token.NoPos, token.NoPos + } + + prevLineEnd := pgf.Tok.LineStart(safetoken.Line(pgf.Tok, close)) - 1 // there must be a previous line + if !isOnlySpaceBetween(prevLineEnd, close) { + return token.NoPos, token.NoPos + } + + return open + 1, prevLineEnd } // commentsFoldingRange returns the folding ranges for all comment blocks in file. @@ -185,7 +217,8 @@ func commentsFoldingRange(pgf *parsego.File) (comments []*FoldingRangeInfo) { } mrng, err := pgf.PosMappedRange(endLinePos, commentGrp.End()) if err != nil { - bug.Errorf("%w", err) // can't happen + bug.Reportf("failed to create mapped range: %s", err) // can't happen + continue } comments = append(comments, &FoldingRangeInfo{ // Fold from the end of the first line comment to the end of the comment block. 
diff --git a/gopls/internal/golang/gc_annotations.go b/gopls/internal/golang/gc_annotations.go index 03db9e74760..618216f6306 100644 --- a/gopls/internal/golang/gc_annotations.go +++ b/gopls/internal/golang/gc_annotations.go @@ -18,7 +18,6 @@ import ( "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/settings" "golang.org/x/tools/internal/event" - "golang.org/x/tools/internal/gocommand" ) // GCOptimizationDetails invokes the Go compiler on the specified @@ -33,7 +32,7 @@ func GCOptimizationDetails(ctx context.Context, snapshot *cache.Snapshot, mp *me if len(mp.CompiledGoFiles) == 0 { return nil, nil } - pkgDir := filepath.Dir(mp.CompiledGoFiles[0].Path()) + pkgDir := mp.CompiledGoFiles[0].DirPath() outDir, err := os.MkdirTemp("", fmt.Sprintf("gopls-%d.details", os.Getpid())) if err != nil { return nil, err @@ -57,14 +56,10 @@ func GCOptimizationDetails(ctx context.Context, snapshot *cache.Snapshot, mp *me if !strings.HasPrefix(outDir, "/") { outDirURI = protocol.DocumentURI(strings.Replace(string(outDirURI), "file:///", "file://", 1)) } - inv, cleanupInvocation, err := snapshot.GoCommandInvocation(false, &gocommand.Invocation{ - Verb: "build", - Args: []string{ - fmt.Sprintf("-gcflags=-json=0,%s", outDirURI), - fmt.Sprintf("-o=%s", tmpFile.Name()), - ".", - }, - WorkingDir: pkgDir, + inv, cleanupInvocation, err := snapshot.GoCommandInvocation(cache.NoNetwork, pkgDir, "build", []string{ + fmt.Sprintf("-gcflags=-json=0,%s", outDirURI), + fmt.Sprintf("-o=%s", tmpFile.Name()), + ".", }) if err != nil { return nil, err @@ -91,7 +86,7 @@ func GCOptimizationDetails(ctx context.Context, snapshot *cache.Snapshot, mp *me if fh == nil { continue } - if pkgDir != filepath.Dir(fh.URI().Path()) { + if pkgDir != fh.URI().DirPath() { // https://github.com/golang/go/issues/42198 // sometimes the detail diagnostics generated for files // outside the package can never be taken back. 
diff --git a/gopls/internal/golang/hover.go b/gopls/internal/golang/hover.go index 8e7febeaab3..3356a7db43a 100644 --- a/gopls/internal/golang/hover.go +++ b/gopls/internal/golang/hover.go @@ -72,9 +72,12 @@ type hoverJSON struct { // SymbolName is the human-readable name to use for the symbol in links. SymbolName string `json:"symbolName"` - // LinkPath is the pkg.go.dev link for the given symbol. - // For example, the "go/ast" part of "pkg.go.dev/go/ast#Node". - // It may have a module version suffix "@v1.2.3". + // LinkPath is the path of the package enclosing the given symbol, + // with the module portion (if any) replaced by "module@version". + // + // For example: "github.com/google/go-github/v48@v48.1.0/github". + // + // Use LinkTarget + "/" + LinkPath + "#" + LinkAnchor to form a pkgsite URL. LinkPath string `json:"linkPath"` // LinkAnchor is the pkg.go.dev link anchor for the given symbol. @@ -344,6 +347,13 @@ func hover(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, pp pro // if they embed platform-variant types. // var sizeOffset string // optional size/offset description + // debugging #69362: unexpected nil Defs[ident] value (?) + _ = ident.Pos() // (can't be nil due to check after referencedObject) + _ = pkg.TypesInfo() // (can't be nil due to check in call to inferredSignature) + _ = pkg.TypesInfo().Defs // (can't be nil due to nature of cache.Package) + if def, ok := pkg.TypesInfo().Defs[ident]; ok { + _ = def.Pos() // can't be nil due to reasoning in #69362. + } if def, ok := pkg.TypesInfo().Defs[ident]; ok && ident.Pos() == def.Pos() { // This is the declaring identifier. 
// (We can't simply use ident.Pos() == obj.Pos() because @@ -930,7 +940,7 @@ func hoverLit(pgf *parsego.File, lit *ast.BasicLit, pos token.Pos) (protocol.Ran func hoverEmbed(fh file.Handle, rng protocol.Range, pattern string) (protocol.Range, *hoverJSON, error) { s := &strings.Builder{} - dir := filepath.Dir(fh.URI().Path()) + dir := fh.URI().DirPath() var matches []string err := filepath.WalkDir(dir, func(abs string, d fs.DirEntry, e error) error { if e != nil { @@ -1360,7 +1370,16 @@ func formatLink(h *hoverJSON, options *settings.Options, pkgURL func(path Packag var url protocol.URI var caption string if pkgURL != nil { // LinksInHover == "gopls" - path, _, _ := strings.Cut(h.LinkPath, "@") // remove optional module version suffix + // Discard optional module version portion. + // (Ideally the hoverJSON would retain the structure...) + path := h.LinkPath + if module, versionDir, ok := strings.Cut(h.LinkPath, "@"); ok { + // "module@version/dir" + path = module + if _, dir, ok := strings.Cut(versionDir, "/"); ok { + path += "/" + dir + } + } url = pkgURL(PackagePath(path), h.LinkAnchor) caption = "in gopls doc viewer" } else { @@ -1620,7 +1639,12 @@ func computeSizeOffsetInfo(pkg *cache.Package, path []ast.Node, obj types.Object var tStruct *types.Struct for _, n := range path { if n, ok := n.(*ast.StructType); ok { - tStruct = pkg.TypesInfo().TypeOf(n).(*types.Struct) + t, ok := pkg.TypesInfo().TypeOf(n).(*types.Struct) + if ok { + // golang/go#69150: TypeOf(n) was observed not to be a Struct (likely + // nil) in some cases. 
+ tStruct = t + } break } } diff --git a/gopls/internal/golang/inline_all.go b/gopls/internal/golang/inline_all.go index addfe2bc250..ec9a458d61a 100644 --- a/gopls/internal/golang/inline_all.go +++ b/gopls/internal/golang/inline_all.go @@ -44,7 +44,7 @@ import ( // // The code below notes where are assumptions are made that only hold true in // the case of parameter removal (annotated with 'Assumption:') -func inlineAllCalls(ctx context.Context, logf func(string, ...any), snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, origDecl *ast.FuncDecl, callee *inline.Callee, post func([]byte) []byte) (map[protocol.DocumentURI][]byte, error) { +func inlineAllCalls(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, origDecl *ast.FuncDecl, callee *inline.Callee, post func([]byte) []byte, opts *inline.Options) (map[protocol.DocumentURI][]byte, error) { // Collect references. var refs []protocol.Location { @@ -112,6 +112,7 @@ func inlineAllCalls(ctx context.Context, logf func(string, ...any), snapshot *ca if err != nil { return nil, bug.Errorf("finding %s in %s: %v", ref.URI, refpkg.Metadata().ID, err) } + start, end, err := pgf.RangePos(ref.Range) if err != nil { return nil, err // e.g. invalid range @@ -124,6 +125,8 @@ func inlineAllCalls(ctx context.Context, logf func(string, ...any), snapshot *ca ) path, _ := astutil.PathEnclosingInterval(pgf.File, start, end) name, _ = path[0].(*ast.Ident) + + // TODO(rfindley): handle method expressions correctly. if _, ok := path[1].(*ast.SelectorExpr); ok { call, _ = path[2].(*ast.CallExpr) } else { @@ -137,11 +140,30 @@ func inlineAllCalls(ctx context.Context, logf func(string, ...any), snapshot *ca // use(func(...) { f(...) }) return nil, fmt.Errorf("cannot inline: found non-call function reference %v", ref) } + + // Heuristic: ignore references that overlap with type checker errors, as they may + // lead to invalid results (see golang/go#70268). 
+ hasTypeErrors := false + for _, typeErr := range refpkg.TypeErrors() { + if call.Lparen <= typeErr.Pos && typeErr.Pos <= call.Rparen { + hasTypeErrors = true + } + } + + if hasTypeErrors { + continue + } + + if typeutil.StaticCallee(refpkg.TypesInfo(), call) == nil { + continue // dynamic call + } + // Sanity check. if obj := refpkg.TypesInfo().ObjectOf(name); obj == nil || obj.Name() != origDecl.Name.Name || obj.Pkg() == nil || obj.Pkg().Path() != string(pkg.Metadata().PkgPath) { + return nil, bug.Errorf("cannot inline: corrupted reference %v", ref) } @@ -193,7 +215,7 @@ func inlineAllCalls(ctx context.Context, logf func(string, ...any), snapshot *ca Call: calls[currentCall], Content: content, } - res, err := inline.Inline(caller, callee, &inline.Options{Logf: logf}) + res, err := inline.Inline(caller, callee, opts) if err != nil { return nil, fmt.Errorf("inlining failed: %v", err) } @@ -230,6 +252,10 @@ func inlineAllCalls(ctx context.Context, logf func(string, ...any), snapshot *ca // anything in the surrounding scope. // // TODO(rfindley): improve this. + logf := func(string, ...any) {} + if opts != nil { + logf = opts.Logf + } tpkg, tinfo, err = reTypeCheck(logf, callInfo.pkg, map[protocol.DocumentURI]*ast.File{uri: file}, true) if err != nil { return nil, bug.Errorf("type checking after inlining failed: %v", err) diff --git a/gopls/internal/golang/lines.go b/gopls/internal/golang/lines.go index 6a17e928b34..b6a9823957d 100644 --- a/gopls/internal/golang/lines.go +++ b/gopls/internal/golang/lines.go @@ -151,6 +151,15 @@ func processLines(fset *token.FileSet, items []ast.Node, comments []*ast.Comment } edits = append(edits, analysis.TextEdit{Pos: pos, End: end, NewText: []byte(sep + indent)}) + + // Print the Ellipsis if we synthesized one earlier. 
+ if is[*ast.Ellipsis](nodes[i]) { + edits = append(edits, analysis.TextEdit{ + Pos: nodes[i].End(), + End: nodes[i].End(), + NewText: []byte("..."), + }) + } } return &analysis.SuggestedFix{TextEdits: edits} @@ -205,6 +214,18 @@ func findSplitJoinTarget(fset *token.FileSet, file *ast.File, src []byte, start, for _, arg := range node.Args { items = append(items, arg) } + + // Preserve "..." by wrapping the last + // argument in an Ellipsis node + // with the same Pos/End as the argument. + // See corresponding logic in processLines. + if node.Ellipsis.IsValid() { + last := &items[len(items)-1] + *last = &ast.Ellipsis{ + Ellipsis: (*last).Pos(), // determines Ellipsis.Pos() + Elt: (*last).(ast.Expr), // determines Ellipsis.End() + } + } case *ast.CompositeLit: for _, arg := range node.Elts { items = append(items, arg) diff --git a/gopls/internal/golang/rename.go b/gopls/internal/golang/rename.go index 7ff5857f186..914cd2b66ed 100644 --- a/gopls/internal/golang/rename.go +++ b/gopls/internal/golang/rename.go @@ -42,10 +42,13 @@ package golang // - FileID-based de-duplication of edits to different URIs for the same file. import ( + "bytes" "context" "errors" "fmt" "go/ast" + "go/parser" + "go/printer" "go/token" "go/types" "path" @@ -64,8 +67,10 @@ import ( "golang.org/x/tools/gopls/internal/cache/parsego" "golang.org/x/tools/gopls/internal/file" "golang.org/x/tools/gopls/internal/protocol" + goplsastutil "golang.org/x/tools/gopls/internal/util/astutil" "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/gopls/internal/util/safetoken" + internalastutil "golang.org/x/tools/internal/astutil" "golang.org/x/tools/internal/diff" "golang.org/x/tools/internal/event" "golang.org/x/tools/internal/typesinternal" @@ -126,6 +131,15 @@ func PrepareRename(ctx context.Context, snapshot *cache.Snapshot, f file.Handle, if err != nil { return nil, nil, err } + + // Check if we're in a 'func' keyword. If so, we hijack the renaming to + // change the function signature. 
+ if item, err := prepareRenameFuncSignature(pgf, pos); err != nil { + return nil, nil, err + } else if item != nil { + return item, nil, nil + } + targets, node, err := objectsAt(pkg.TypesInfo(), pgf.File, pos) if err != nil { return nil, nil, err @@ -193,6 +207,169 @@ func prepareRenamePackageName(ctx context.Context, snapshot *cache.Snapshot, pgf }, nil } +// prepareRenameFuncSignature prepares a change signature refactoring initiated +// through invoking a rename request at the 'func' keyword of a function +// declaration. +// +// The resulting text is the signature of the function, which may be edited to +// the new signature. +func prepareRenameFuncSignature(pgf *parsego.File, pos token.Pos) (*PrepareItem, error) { + fdecl := funcKeywordDecl(pgf, pos) + if fdecl == nil { + return nil, nil + } + ftyp := nameBlankParams(fdecl.Type) + var buf bytes.Buffer + if err := printer.Fprint(&buf, token.NewFileSet(), ftyp); err != nil { // use a new fileset so that the signature is formatted on a single line + return nil, err + } + rng, err := pgf.PosRange(ftyp.Func, ftyp.Func+token.Pos(len("func"))) + if err != nil { + return nil, err + } + text := buf.String() + return &PrepareItem{ + Range: rng, + Text: text, + }, nil +} + +// nameBlankParams returns a copy of ftype with blank or unnamed params +// assigned a unique name. +func nameBlankParams(ftype *ast.FuncType) *ast.FuncType { + ftype = internalastutil.CloneNode(ftype) + + // First, collect existing names. 
+ scope := make(map[string]bool) + for name := range goplsastutil.FlatFields(ftype.Params) { + if name != nil { + scope[name.Name] = true + } + } + blanks := 0 + for name, field := range goplsastutil.FlatFields(ftype.Params) { + if name == nil { + name = ast.NewIdent("_") + field.Names = append(field.Names, name) // ok to append + } + if name.Name == "" || name.Name == "_" { + for { + newName := fmt.Sprintf("_%d", blanks) + blanks++ + if !scope[newName] { + name.Name = newName + break + } + } + } + } + return ftype +} + +// renameFuncSignature computes and applies the effective change signature +// operation resulting from a 'renamed' (=rewritten) signature. +func renameFuncSignature(ctx context.Context, snapshot *cache.Snapshot, f file.Handle, pp protocol.Position, newName string) (map[protocol.DocumentURI][]protocol.TextEdit, error) { + // Find the renamed signature. + pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, f.URI()) + if err != nil { + return nil, err + } + pos, err := pgf.PositionPos(pp) + if err != nil { + return nil, err + } + fdecl := funcKeywordDecl(pgf, pos) + if fdecl == nil { + return nil, nil + } + ftyp := nameBlankParams(fdecl.Type) + + // Parse the user's requested new signature. + parsed, err := parser.ParseExpr(newName) + if err != nil { + return nil, err + } + newType, _ := parsed.(*ast.FuncType) + if newType == nil { + return nil, fmt.Errorf("parsed signature is %T, not a function type", parsed) + } + + // Check results, before we get into handling permutations of parameters. 
+ if got, want := newType.Results.NumFields(), ftyp.Results.NumFields(); got != want { + return nil, fmt.Errorf("changing results not yet supported (got %d results, want %d)", got, want) + } + var resultTypes []string + for _, field := range goplsastutil.FlatFields(ftyp.Results) { + resultTypes = append(resultTypes, FormatNode(token.NewFileSet(), field.Type)) + } + resultIndex := 0 + for _, field := range goplsastutil.FlatFields(newType.Results) { + if FormatNode(token.NewFileSet(), field.Type) != resultTypes[resultIndex] { + return nil, fmt.Errorf("changing results not yet supported") + } + resultIndex++ + } + + type paramInfo struct { + idx int + typ string + } + oldParams := make(map[string]paramInfo) + for name, field := range goplsastutil.FlatFields(ftyp.Params) { + oldParams[name.Name] = paramInfo{ + idx: len(oldParams), + typ: types.ExprString(field.Type), + } + } + + var newParams []int + for name, field := range goplsastutil.FlatFields(newType.Params) { + if name == nil { + return nil, fmt.Errorf("need named fields") + } + info, ok := oldParams[name.Name] + if !ok { + return nil, fmt.Errorf("couldn't find name %s: adding parameters not yet supported", name) + } + if newType := types.ExprString(field.Type); newType != info.typ { + return nil, fmt.Errorf("changing types (%s to %s) not yet supported", info.typ, newType) + } + newParams = append(newParams, info.idx) + } + + rng, err := pgf.PosRange(ftyp.Func, ftyp.Func) + if err != nil { + return nil, err + } + changes, err := ChangeSignature(ctx, snapshot, pkg, pgf, rng, newParams) + if err != nil { + return nil, err + } + transposed := make(map[protocol.DocumentURI][]protocol.TextEdit) + for _, change := range changes { + transposed[change.TextDocumentEdit.TextDocument.URI] = protocol.AsTextEdits(change.TextDocumentEdit.Edits) + } + return transposed, nil +} + +// funcKeywordDecl returns the FuncDecl for which pos is in the 'func' keyword, +// if any. 
+func funcKeywordDecl(pgf *parsego.File, pos token.Pos) *ast.FuncDecl { + path, _ := astutil.PathEnclosingInterval(pgf.File, pos, pos) + if len(path) < 1 { + return nil + } + fdecl, _ := path[0].(*ast.FuncDecl) + if fdecl == nil { + return nil + } + ftyp := fdecl.Type + if pos < ftyp.Func || pos > ftyp.Func+token.Pos(len("func")) { // tolerate renaming immediately after 'func' + return nil + } + return fdecl +} + func checkRenamable(obj types.Object) error { switch obj := obj.(type) { case *types.Var: @@ -219,6 +396,12 @@ func Rename(ctx context.Context, snapshot *cache.Snapshot, f file.Handle, pp pro ctx, done := event.Start(ctx, "golang.Rename") defer done() + if edits, err := renameFuncSignature(ctx, snapshot, f, pp, newName); err != nil { + return nil, false, err + } else if edits != nil { + return edits, false, nil + } + if !isValidIdentifier(newName) { return nil, false, fmt.Errorf("invalid identifier to rename: %q", newName) } @@ -605,7 +788,7 @@ func renamePackageName(ctx context.Context, s *cache.Snapshot, f file.Handle, ne } // Update the last component of the file's enclosing directory. - oldBase := filepath.Dir(f.URI().Path()) + oldBase := f.URI().DirPath() newPkgDir := filepath.Join(filepath.Dir(oldBase), string(newName)) // Update any affected replace directives in go.mod files. 
@@ -625,7 +808,7 @@ func renamePackageName(ctx context.Context, s *cache.Snapshot, f file.Handle, ne return nil, err } - modFileDir := filepath.Dir(pm.URI.Path()) + modFileDir := pm.URI.DirPath() affectedReplaces := []*modfile.Replace{} // Check if any replace directives need to be fixed diff --git a/gopls/internal/golang/semtok.go b/gopls/internal/golang/semtok.go index 4e24dafc23f..2043f9aaacc 100644 --- a/gopls/internal/golang/semtok.go +++ b/gopls/internal/golang/semtok.go @@ -109,9 +109,9 @@ type tokenVisitor struct { func (tv *tokenVisitor) visit() { f := tv.pgf.File // may not be in range, but harmless - tv.token(f.Package, len("package"), semtok.TokKeyword, nil) + tv.token(f.Package, len("package"), semtok.TokKeyword) if f.Name != nil { - tv.token(f.Name.NamePos, len(f.Name.Name), semtok.TokNamespace, nil) + tv.token(f.Name.NamePos, len(f.Name.Name), semtok.TokNamespace) } for _, decl := range f.Decls { // Only look at the decls that overlap the range. @@ -208,21 +208,6 @@ func (tv *tokenVisitor) comment(c *ast.Comment, importByName map[string]*types.P } } - tokenTypeByObject := func(obj types.Object) (semtok.TokenType, []string) { - switch obj.(type) { - case *types.PkgName: - return semtok.TokNamespace, nil - case *types.Func: - return semtok.TokFunction, nil - case *types.TypeName: - return semtok.TokType, appendTypeModifiers(nil, obj) - case *types.Const, *types.Var: - return semtok.TokVariable, nil - default: - return semtok.TokComment, nil - } - } - pos := c.Pos() for _, line := range strings.Split(c.Text, "\n") { last := 0 @@ -232,32 +217,32 @@ func (tv *tokenVisitor) comment(c *ast.Comment, importByName map[string]*types.P name := line[idx[2]:idx[3]] if objs := lookupObjects(name); len(objs) > 0 { if last < idx[2] { - tv.token(pos+token.Pos(last), idx[2]-last, semtok.TokComment, nil) + tv.token(pos+token.Pos(last), idx[2]-last, semtok.TokComment) } offset := pos + token.Pos(idx[2]) for i, obj := range objs { if i > 0 { - tv.token(offset, len("."), 
semtok.TokComment, nil) + tv.token(offset, len("."), semtok.TokComment) offset += token.Pos(len(".")) } id, rest, _ := strings.Cut(name, ".") name = rest - tok, mods := tokenTypeByObject(obj) - tv.token(offset, len(id), tok, mods) + tok, mods := tv.appendObjectModifiers(nil, obj) + tv.token(offset, len(id), tok, mods...) offset += token.Pos(len(id)) } last = idx[3] } } if last != len(c.Text) { - tv.token(pos+token.Pos(last), len(line)-last, semtok.TokComment, nil) + tv.token(pos+token.Pos(last), len(line)-last, semtok.TokComment) } pos += token.Pos(len(line) + 1) } } // token emits a token of the specified extent and semantics. -func (tv *tokenVisitor) token(start token.Pos, length int, typ semtok.TokenType, modifiers []string) { +func (tv *tokenVisitor) token(start token.Pos, length int, typ semtok.TokenType, modifiers ...semtok.Modifier) { if !start.IsValid() { return } @@ -338,7 +323,7 @@ func (tv *tokenVisitor) inspect(n ast.Node) (descend bool) { switch n := n.(type) { case *ast.ArrayType: case *ast.AssignStmt: - tv.token(n.TokPos, len(n.Tok.String()), semtok.TokOperator, nil) + tv.token(n.TokPos, len(n.Tok.String()), semtok.TokOperator) case *ast.BasicLit: if strings.Contains(n.Value, "\n") { // has to be a string. 
@@ -349,123 +334,119 @@ func (tv *tokenVisitor) inspect(n ast.Node) (descend bool) { if n.Kind == token.STRING { what = semtok.TokString } - tv.token(n.Pos(), len(n.Value), what, nil) + tv.token(n.Pos(), len(n.Value), what) case *ast.BinaryExpr: - tv.token(n.OpPos, len(n.Op.String()), semtok.TokOperator, nil) + tv.token(n.OpPos, len(n.Op.String()), semtok.TokOperator) case *ast.BlockStmt: case *ast.BranchStmt: - tv.token(n.TokPos, len(n.Tok.String()), semtok.TokKeyword, nil) - if n.Label != nil { - tv.token(n.Label.Pos(), len(n.Label.Name), semtok.TokLabel, nil) - } + tv.token(n.TokPos, len(n.Tok.String()), semtok.TokKeyword) case *ast.CallExpr: if n.Ellipsis.IsValid() { - tv.token(n.Ellipsis, len("..."), semtok.TokOperator, nil) + tv.token(n.Ellipsis, len("..."), semtok.TokOperator) } case *ast.CaseClause: iam := "case" if n.List == nil { iam = "default" } - tv.token(n.Case, len(iam), semtok.TokKeyword, nil) + tv.token(n.Case, len(iam), semtok.TokKeyword) case *ast.ChanType: // chan | chan <- | <- chan switch { case n.Arrow == token.NoPos: - tv.token(n.Begin, len("chan"), semtok.TokKeyword, nil) + tv.token(n.Begin, len("chan"), semtok.TokKeyword) case n.Arrow == n.Begin: - tv.token(n.Arrow, 2, semtok.TokOperator, nil) + tv.token(n.Arrow, 2, semtok.TokOperator) pos := tv.findKeyword("chan", n.Begin+2, n.Value.Pos()) - tv.token(pos, len("chan"), semtok.TokKeyword, nil) + tv.token(pos, len("chan"), semtok.TokKeyword) case n.Arrow != n.Begin: - tv.token(n.Begin, len("chan"), semtok.TokKeyword, nil) - tv.token(n.Arrow, 2, semtok.TokOperator, nil) + tv.token(n.Begin, len("chan"), semtok.TokKeyword) + tv.token(n.Arrow, 2, semtok.TokOperator) } case *ast.CommClause: length := len("case") if n.Comm == nil { length = len("default") } - tv.token(n.Case, length, semtok.TokKeyword, nil) + tv.token(n.Case, length, semtok.TokKeyword) case *ast.CompositeLit: case *ast.DeclStmt: case *ast.DeferStmt: - tv.token(n.Defer, len("defer"), semtok.TokKeyword, nil) + tv.token(n.Defer, 
len("defer"), semtok.TokKeyword) case *ast.Ellipsis: - tv.token(n.Ellipsis, len("..."), semtok.TokOperator, nil) + tv.token(n.Ellipsis, len("..."), semtok.TokOperator) case *ast.EmptyStmt: case *ast.ExprStmt: case *ast.Field: case *ast.FieldList: case *ast.ForStmt: - tv.token(n.For, len("for"), semtok.TokKeyword, nil) + tv.token(n.For, len("for"), semtok.TokKeyword) case *ast.FuncDecl: case *ast.FuncLit: case *ast.FuncType: if n.Func != token.NoPos { - tv.token(n.Func, len("func"), semtok.TokKeyword, nil) + tv.token(n.Func, len("func"), semtok.TokKeyword) } case *ast.GenDecl: - tv.token(n.TokPos, len(n.Tok.String()), semtok.TokKeyword, nil) + tv.token(n.TokPos, len(n.Tok.String()), semtok.TokKeyword) case *ast.GoStmt: - tv.token(n.Go, len("go"), semtok.TokKeyword, nil) + tv.token(n.Go, len("go"), semtok.TokKeyword) case *ast.Ident: tv.ident(n) case *ast.IfStmt: - tv.token(n.If, len("if"), semtok.TokKeyword, nil) + tv.token(n.If, len("if"), semtok.TokKeyword) if n.Else != nil { // x.Body.End() or x.Body.End()+1, not that it matters pos := tv.findKeyword("else", n.Body.End(), n.Else.Pos()) - tv.token(pos, len("else"), semtok.TokKeyword, nil) + tv.token(pos, len("else"), semtok.TokKeyword) } case *ast.ImportSpec: tv.importSpec(n) return false case *ast.IncDecStmt: - tv.token(n.TokPos, len(n.Tok.String()), semtok.TokOperator, nil) + tv.token(n.TokPos, len(n.Tok.String()), semtok.TokOperator) case *ast.IndexExpr: case *ast.IndexListExpr: case *ast.InterfaceType: - tv.token(n.Interface, len("interface"), semtok.TokKeyword, nil) + tv.token(n.Interface, len("interface"), semtok.TokKeyword) case *ast.KeyValueExpr: case *ast.LabeledStmt: - tv.token(n.Label.Pos(), len(n.Label.Name), semtok.TokLabel, []string{"definition"}) case *ast.MapType: - tv.token(n.Map, len("map"), semtok.TokKeyword, nil) + tv.token(n.Map, len("map"), semtok.TokKeyword) case *ast.ParenExpr: case *ast.RangeStmt: - tv.token(n.For, len("for"), semtok.TokKeyword, nil) + tv.token(n.For, len("for"), 
semtok.TokKeyword) // x.TokPos == token.NoPos is legal (for range foo {}) offset := n.TokPos if offset == token.NoPos { offset = n.For } pos := tv.findKeyword("range", offset, n.X.Pos()) - tv.token(pos, len("range"), semtok.TokKeyword, nil) + tv.token(pos, len("range"), semtok.TokKeyword) case *ast.ReturnStmt: - tv.token(n.Return, len("return"), semtok.TokKeyword, nil) + tv.token(n.Return, len("return"), semtok.TokKeyword) case *ast.SelectStmt: - tv.token(n.Select, len("select"), semtok.TokKeyword, nil) + tv.token(n.Select, len("select"), semtok.TokKeyword) case *ast.SelectorExpr: case *ast.SendStmt: - tv.token(n.Arrow, len("<-"), semtok.TokOperator, nil) + tv.token(n.Arrow, len("<-"), semtok.TokOperator) case *ast.SliceExpr: case *ast.StarExpr: - tv.token(n.Star, len("*"), semtok.TokOperator, nil) + tv.token(n.Star, len("*"), semtok.TokOperator) case *ast.StructType: - tv.token(n.Struct, len("struct"), semtok.TokKeyword, nil) + tv.token(n.Struct, len("struct"), semtok.TokKeyword) case *ast.SwitchStmt: - tv.token(n.Switch, len("switch"), semtok.TokKeyword, nil) + tv.token(n.Switch, len("switch"), semtok.TokKeyword) case *ast.TypeAssertExpr: if n.Type == nil { pos := tv.findKeyword("type", n.Lparen, n.Rparen) - tv.token(pos, len("type"), semtok.TokKeyword, nil) + tv.token(pos, len("type"), semtok.TokKeyword) } case *ast.TypeSpec: case *ast.TypeSwitchStmt: - tv.token(n.Switch, len("switch"), semtok.TokKeyword, nil) + tv.token(n.Switch, len("switch"), semtok.TokKeyword) case *ast.UnaryExpr: - tv.token(n.OpPos, len(n.Op.String()), semtok.TokOperator, nil) + tv.token(n.OpPos, len(n.Op.String()), semtok.TokOperator) case *ast.ValueSpec: // things only seen with parsing or type errors, so ignore them case *ast.BadDecl, *ast.BadExpr, *ast.BadStmt: @@ -482,40 +463,94 @@ func (tv *tokenVisitor) inspect(n ast.Node) (descend bool) { return true } +func (tv *tokenVisitor) appendObjectModifiers(mods []semtok.Modifier, obj types.Object) (semtok.TokenType, []semtok.Modifier) { + 
if obj.Pkg() == nil { + mods = append(mods, semtok.ModDefaultLibrary) + } + + // Note: PkgName, Builtin, Label have type Invalid, which adds no modifiers. + mods = appendTypeModifiers(mods, obj.Type()) + + switch obj := obj.(type) { + case *types.PkgName: + return semtok.TokNamespace, mods + + case *types.Builtin: + return semtok.TokFunction, mods + + case *types.Func: + if obj.Signature().Recv() != nil { + return semtok.TokMethod, mods + } else { + return semtok.TokFunction, mods + } + + case *types.TypeName: + if is[*types.TypeParam](types.Unalias(obj.Type())) { + return semtok.TokTypeParam, mods + } + return semtok.TokType, mods + + case *types.Const: + mods = append(mods, semtok.ModReadonly) + return semtok.TokVariable, mods + + case *types.Var: + if tv.isParam(obj.Pos()) { + return semtok.TokParameter, mods + } else { + return semtok.TokVariable, mods + } + + case *types.Label: + return semtok.TokLabel, mods + + case *types.Nil: + mods = append(mods, semtok.ModReadonly) + return semtok.TokVariable, mods + } + + panic(obj) +} + // appendTypeModifiers appends optional modifiers that describe the top-level -// type constructor of obj.Type(): "pointer", "map", etc. -func appendTypeModifiers(mods []string, obj types.Object) []string { - switch t := obj.Type().Underlying().(type) { +// type constructor of t: "pointer", "map", etc. +func appendTypeModifiers(mods []semtok.Modifier, t types.Type) []semtok.Modifier { + // For a type parameter, don't report "interface". 
+ if is[*types.TypeParam](types.Unalias(t)) { + return mods + } + + switch t := t.Underlying().(type) { case *types.Interface: - mods = append(mods, "interface") + mods = append(mods, semtok.ModInterface) case *types.Struct: - mods = append(mods, "struct") + mods = append(mods, semtok.ModStruct) case *types.Signature: - mods = append(mods, "signature") + mods = append(mods, semtok.ModSignature) case *types.Pointer: - mods = append(mods, "pointer") + mods = append(mods, semtok.ModPointer) case *types.Array: - mods = append(mods, "array") + mods = append(mods, semtok.ModArray) case *types.Map: - mods = append(mods, "map") + mods = append(mods, semtok.ModMap) case *types.Slice: - mods = append(mods, "slice") + mods = append(mods, semtok.ModSlice) case *types.Chan: - mods = append(mods, "chan") + mods = append(mods, semtok.ModChan) case *types.Basic: - mods = append(mods, "defaultLibrary") switch t.Kind() { case types.Invalid: - mods = append(mods, "invalid") + // ignore (e.g. Builtin, PkgName, Label) case types.String: - mods = append(mods, "string") + mods = append(mods, semtok.ModString) case types.Bool: - mods = append(mods, "bool") + mods = append(mods, semtok.ModBool) case types.UnsafePointer: - mods = append(mods, "pointer") + mods = append(mods, semtok.ModPointer) default: if t.Info()&types.IsNumeric != 0 { - mods = append(mods, "number") + mods = append(mods, semtok.ModNumber) } } } @@ -523,76 +558,38 @@ func appendTypeModifiers(mods []string, obj types.Object) []string { } func (tv *tokenVisitor) ident(id *ast.Ident) { - var obj types.Object - - // emit emits a token for the identifier's extent. - emit := func(tok semtok.TokenType, modifiers ...string) { - tv.token(id.Pos(), len(id.Name), tok, modifiers) - if semDebug { - q := "nil" - if obj != nil { - q = fmt.Sprintf("%T", obj.Type()) // e.g. 
"*types.Map" - } - log.Printf(" use %s/%T/%s got %s %v (%s)", - id.Name, obj, q, tok, modifiers, tv.strStack()) - } - } + var ( + tok semtok.TokenType + mods []semtok.Modifier + obj types.Object + ok bool + ) + if obj, ok = tv.info.Defs[id]; obj != nil { + // definition + mods = append(mods, semtok.ModDefinition) + tok, mods = tv.appendObjectModifiers(mods, obj) + + } else if obj, ok = tv.info.Uses[id]; ok { + // use + tok, mods = tv.appendObjectModifiers(mods, obj) + + } else if tok, mods = tv.unkIdent(id); tok != "" { + // ok - // definition? - obj = tv.info.Defs[id] - if obj != nil { - if tok, modifiers := tv.definitionFor(id, obj); tok != "" { - emit(tok, modifiers...) - } else if semDebug { - log.Printf(" for %s/%T/%T got '' %v (%s)", - id.Name, obj, obj.Type(), modifiers, tv.strStack()) - } + } else { return } - // use? - obj = tv.info.Uses[id] - switch obj := obj.(type) { - case *types.Builtin: - emit(semtok.TokFunction, "defaultLibrary") - case *types.Const: - if is[*types.Basic](obj.Type()) && - (id.Name == "iota" || id.Name == "true" || id.Name == "false") { - emit(semtok.TokVariable, "readonly", "defaultLibrary") - } else { - emit(semtok.TokVariable, "readonly") - } - case *types.Func: - emit(semtok.TokFunction) - case *types.Label: - // Labels are reliably covered by the syntax traversal. - case *types.Nil: - // nil is a predeclared identifier - emit(semtok.TokVariable, "readonly", "defaultLibrary") - case *types.PkgName: - emit(semtok.TokNamespace) - case *types.TypeName: // could be a TypeParam - if is[*types.TypeParam](types.Unalias(obj.Type())) { - emit(semtok.TokTypeParam) - } else { - emit(semtok.TokType, appendTypeModifiers(nil, obj)...) 
- } - case *types.Var: - if is[*types.Signature](types.Unalias(obj.Type())) { - emit(semtok.TokFunction) - } else if tv.isParam(obj.Pos()) { - // variable, unless use.pos is the pos of a Field in an ancestor FuncDecl - // or FuncLit and then it's a parameter - emit(semtok.TokParameter) - } else { - emit(semtok.TokVariable) - } - case nil: - if tok, modifiers := tv.unkIdent(id); tok != "" { - emit(tok, modifiers...) + // Emit a token for the identifier's extent. + tv.token(id.Pos(), len(id.Name), tok, mods...) + + if semDebug { + q := "nil" + if obj != nil { + q = fmt.Sprintf("%T", obj.Type()) // e.g. "*types.Map" } - default: - panic(obj) + log.Printf(" use %s/%T/%s got %s %v (%s)", + id.Name, obj, q, tok, mods, tv.strStack()) } } @@ -626,8 +623,8 @@ func (tv *tokenVisitor) isParam(pos token.Pos) bool { // def), use the parse stack. // A lot of these only happen when the package doesn't compile, // but in that case it is all best-effort from the parse tree. -func (tv *tokenVisitor) unkIdent(id *ast.Ident) (semtok.TokenType, []string) { - def := []string{"definition"} +func (tv *tokenVisitor) unkIdent(id *ast.Ident) (semtok.TokenType, []semtok.Modifier) { + def := []semtok.Modifier{semtok.ModDefinition} n := len(tv.stack) - 2 // parent of Ident; stack is [File ... Ident] if n < 0 { tv.errorf("no stack") // can't happen @@ -748,115 +745,6 @@ func (tv *tokenVisitor) unkIdent(id *ast.Ident) (semtok.TokenType, []string) { return "", nil } -func isDeprecated(n *ast.CommentGroup) bool { - if n != nil { - for _, c := range n.List { - if strings.HasPrefix(c.Text, "// Deprecated") { - return true - } - } - } - return false -} - -// definitionFor handles a defining identifier. -func (tv *tokenVisitor) definitionFor(id *ast.Ident, obj types.Object) (semtok.TokenType, []string) { - // The definition of a types.Label cannot be found by - // ascending the syntax tree, and doing so will reach the - // FuncDecl, causing us to misinterpret the label as a - // parameter (#65494). 
- // - // However, labels are reliably covered by the syntax - // traversal, so we don't need to use type information. - if is[*types.Label](obj) { - return "", nil - } - - // PJW: look into replacing these syntactic tests with types more generally - modifiers := []string{"definition"} - for i := len(tv.stack) - 1; i >= 0; i-- { - switch ancestor := tv.stack[i].(type) { - case *ast.AssignStmt, *ast.RangeStmt: - if id.Name == "_" { - return "", nil // not really a variable - } - return semtok.TokVariable, modifiers - case *ast.GenDecl: - if isDeprecated(ancestor.Doc) { - modifiers = append(modifiers, "deprecated") - } - if ancestor.Tok == token.CONST { - modifiers = append(modifiers, "readonly") - } - return semtok.TokVariable, modifiers - case *ast.FuncDecl: - // If x is immediately under a FuncDecl, it is a function or method - if i == len(tv.stack)-2 { - if isDeprecated(ancestor.Doc) { - modifiers = append(modifiers, "deprecated") - } - if ancestor.Recv != nil { - return semtok.TokMethod, modifiers - } - return semtok.TokFunction, modifiers - } - // if x < ... < FieldList < FuncDecl, this is the receiver, a variable - // PJW: maybe not. it might be a typeparameter in the type of the receiver - if is[*ast.FieldList](tv.stack[i+1]) { - if is[*types.TypeName](obj) { - return semtok.TokTypeParam, modifiers - } - return semtok.TokVariable, nil - } - // if x < ... 
< FieldList < FuncType < FuncDecl, this is a param - return semtok.TokParameter, modifiers - case *ast.FuncType: - if isTypeParam(id, ancestor) { - return semtok.TokTypeParam, modifiers - } - return semtok.TokParameter, modifiers - case *ast.InterfaceType: - return semtok.TokMethod, modifiers - case *ast.TypeSpec: - // GenDecl/Typespec/FuncType/FieldList/Field/Ident - // (type A func(b uint64)) (err error) - // b and err should not be semtok.TokType, but semtok.TokVariable - // and in GenDecl/TpeSpec/StructType/FieldList/Field/Ident - // (type A struct{b uint64} - // but on type B struct{C}), C is a type, but is not being defined. - // GenDecl/TypeSpec/FieldList/Field/Ident is a typeParam - if is[*ast.FieldList](tv.stack[i+1]) { - return semtok.TokTypeParam, modifiers - } - fldm := tv.stack[len(tv.stack)-2] - if fld, ok := fldm.(*ast.Field); ok { - // if len(fld.names) == 0 this is a semtok.TokType, being used - if len(fld.Names) == 0 { - return semtok.TokType, appendTypeModifiers(nil, obj) - } - return semtok.TokVariable, modifiers - } - return semtok.TokType, appendTypeModifiers(modifiers, obj) - } - } - // can't happen - tv.errorf("failed to find the decl for %s", safetoken.Position(tv.pgf.Tok, id.Pos())) - return "", nil -} - -func isTypeParam(id *ast.Ident, t *ast.FuncType) bool { - if tp := t.TypeParams; tp != nil { - for _, p := range tp.List { - for _, n := range p.Names { - if id == n { - return true - } - } - } - } - return false -} - // multiline emits a multiline token (`string` or /*comment*/). func (tv *tokenVisitor) multiline(start, end token.Pos, tok semtok.TokenType) { // TODO(adonovan): test with non-ASCII. 
@@ -875,13 +763,13 @@ func (tv *tokenVisitor) multiline(start, end token.Pos, tok semtok.TokenType) { sline := spos.Line eline := epos.Line // first line is from spos.Column to end - tv.token(start, length(sline)-spos.Column, tok, nil) // leng(sline)-1 - (spos.Column-1) + tv.token(start, length(sline)-spos.Column, tok) // leng(sline)-1 - (spos.Column-1) for i := sline + 1; i < eline; i++ { // intermediate lines are from 1 to end - tv.token(f.LineStart(i), length(i)-1, tok, nil) // avoid the newline + tv.token(f.LineStart(i), length(i)-1, tok) // avoid the newline } // last line is from 1 to epos.Column - tv.token(f.LineStart(eline), epos.Column-1, tok, nil) // columns are 1-based + tv.token(f.LineStart(eline), epos.Column-1, tok) // columns are 1-based } // findKeyword returns the position of a keyword by searching within @@ -907,7 +795,7 @@ func (tv *tokenVisitor) importSpec(spec *ast.ImportSpec) { if spec.Name != nil { name := spec.Name.String() if name != "_" && name != "." { - tv.token(spec.Name.Pos(), len(name), semtok.TokNamespace, nil) + tv.token(spec.Name.Pos(), len(name), semtok.TokNamespace) } return // don't mark anything for . or _ } @@ -933,7 +821,7 @@ func (tv *tokenVisitor) importSpec(spec *ast.ImportSpec) { } // Report virtual declaration at the position of the substring. start := spec.Path.Pos() + token.Pos(j) - tv.token(start, len(depMD.Name), semtok.TokNamespace, nil) + tv.token(start, len(depMD.Name), semtok.TokNamespace) } // errorf logs an error and reports a bug. @@ -968,19 +856,19 @@ func (tv *tokenVisitor) godirective(c *ast.Comment) { kind, _ := stringsCutPrefix(directive, "//go:") if _, ok := godirectives[kind]; !ok { // Unknown 'go:' directive. - tv.token(c.Pos(), len(c.Text), semtok.TokComment, nil) + tv.token(c.Pos(), len(c.Text), semtok.TokComment) return } // Make the 'go:directive' part stand out, the rest is comments. 
- tv.token(c.Pos(), len("//"), semtok.TokComment, nil) + tv.token(c.Pos(), len("//"), semtok.TokComment) directiveStart := c.Pos() + token.Pos(len("//")) - tv.token(directiveStart, len(directive[len("//"):]), semtok.TokNamespace, nil) + tv.token(directiveStart, len(directive[len("//"):]), semtok.TokNamespace) if len(args) > 0 { tailStart := c.Pos() + token.Pos(len(directive)+len(" ")) - tv.token(tailStart, len(args), semtok.TokComment, nil) + tv.token(tailStart, len(args), semtok.TokComment) } } diff --git a/gopls/internal/golang/stubmethods/stubcalledfunc.go b/gopls/internal/golang/stubmethods/stubcalledfunc.go index 0b6c1052182..1b1b6aba7de 100644 --- a/gopls/internal/golang/stubmethods/stubcalledfunc.go +++ b/gopls/internal/golang/stubmethods/stubcalledfunc.go @@ -91,7 +91,7 @@ func GetCallStubInfo(fset *token.FileSet, info *types.Info, path []ast.Node, pos // Emit writes to out the missing method based on type info of si.Receiver and CallExpr. func (si *CallStubInfo) Emit(out *bytes.Buffer, qual types.Qualifier) error { params := si.collectParams() - rets := typesFromContext(si.info, si.path, si.path[0].Pos()) + rets := typesutil.TypesFromContext(si.info, si.path, si.path[0].Pos()) recv := si.Receiver.Obj() // Pointer receiver? var star string @@ -193,116 +193,6 @@ func (si *CallStubInfo) collectParams() []param { return params } -// typesFromContext returns the type (or perhaps zero or multiple types) -// of the "hole" into which the expression identified by path must fit. -// -// For example, given -// -// s, i := "", 0 -// s, i = EXPR -// -// the hole that must be filled by EXPR has type (string, int). -// -// It returns nil on failure. 
-func typesFromContext(info *types.Info, path []ast.Node, pos token.Pos) []types.Type { - var typs []types.Type - parent := parentNode(path) - if parent == nil { - return nil - } - switch parent := parent.(type) { - case *ast.AssignStmt: - // Append all lhs's type - if len(parent.Rhs) == 1 { - for _, lhs := range parent.Lhs { - t := info.TypeOf(lhs) - if t != nil && !containsInvalid(t) { - t = types.Default(t) - } else { - t = anyType - } - typs = append(typs, t) - } - break - } - - // Lhs and Rhs counts do not match, give up - if len(parent.Lhs) != len(parent.Rhs) { - break - } - - // Append corresponding index of lhs's type - for i, rhs := range parent.Rhs { - if rhs.Pos() <= pos && pos <= rhs.End() { - t := info.TypeOf(parent.Lhs[i]) - if t != nil && !containsInvalid(t) { - t = types.Default(t) - } else { - t = anyType - } - typs = append(typs, t) - break - } - } - case *ast.CallExpr: - // Find argument containing pos. - argIdx := -1 - for i, callArg := range parent.Args { - if callArg.Pos() <= pos && pos <= callArg.End() { - argIdx = i - break - } - } - if argIdx == -1 { - break - } - - t := info.TypeOf(parent.Fun) - if t == nil { - break - } - - if sig, ok := t.Underlying().(*types.Signature); ok { - var paramType types.Type - if sig.Variadic() && argIdx >= sig.Params().Len()-1 { - v := sig.Params().At(sig.Params().Len() - 1) - if s, _ := v.Type().(*types.Slice); s != nil { - paramType = s.Elem() - } - } else if argIdx < sig.Params().Len() { - paramType = sig.Params().At(argIdx).Type() - } else { - break - } - if paramType == nil || containsInvalid(paramType) { - paramType = anyType - } - typs = append(typs, paramType) - } - default: - // TODO: support other common kinds of "holes", e.g. - // x + EXPR => typeof(x) - // !EXPR => bool - // var x int = EXPR => int - // etc. - } - return typs -} - -// parentNode returns the nodes immediately enclosing path[0], -// ignoring parens. 
-func parentNode(path []ast.Node) ast.Node { - if len(path) <= 1 { - return nil - } - for _, n := range path[1:] { - if _, ok := n.(*ast.ParenExpr); !ok { - return n - } - } - return nil -} - // containsInvalid checks if the type name contains "invalid type", // which is not a valid syntax to generate. func containsInvalid(t types.Type) bool { diff --git a/gopls/internal/golang/stubmethods/stubmethods.go b/gopls/internal/golang/stubmethods/stubmethods.go index dbfcefd9e16..f380f5b984d 100644 --- a/gopls/internal/golang/stubmethods/stubmethods.go +++ b/gopls/internal/golang/stubmethods/stubmethods.go @@ -13,9 +13,11 @@ import ( "go/ast" "go/token" "go/types" - "golang.org/x/tools/internal/typesinternal" "strings" + "golang.org/x/tools/internal/typesinternal" + + "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/gopls/internal/util/typesutil" ) @@ -272,23 +274,29 @@ func fromReturnStmt(fset *token.FileSet, info *types.Info, pos token.Pos, path [ concType, pointer := concreteType(ret.Results[returnIdx], info) if concType == nil || concType.Obj().Pkg() == nil { - return nil, nil + return nil, nil // result is not a named or *named or alias thereof } + // Inv: the return is not a spread return, + // such as "return f()" where f() has tuple type. conc := concType.Obj() if conc.Parent() != conc.Pkg().Scope() { return nil, fmt.Errorf("local type %q cannot be stubbed", conc.Name()) } - funcType := enclosingFunction(path, info) - if funcType == nil { - return nil, fmt.Errorf("could not find the enclosing function of the return statement") + sig := typesutil.EnclosingSignature(path, info) + if sig == nil { + // golang/go#70666: this bug may be reached in practice. + return nil, bug.Errorf("could not find the enclosing function of the return statement") } - if len(funcType.Results.List) != len(ret.Results) { + rets := sig.Results() + // The return operands and function results must match. + // (Spread returns were rejected earlier.) 
+ if rets.Len() != len(ret.Results) { return nil, fmt.Errorf("%d-operand return statement in %d-result function", len(ret.Results), - len(funcType.Results.List)) + rets.Len()) } - iface := ifaceType(funcType.Results.List[returnIdx].Type, info) + iface := ifaceObjFromType(rets.At(returnIdx).Type()) if iface == nil { return nil, nil } @@ -442,21 +450,3 @@ func concreteType(e ast.Expr, info *types.Info) (*types.Named, bool) { } return named, isPtr } - -// enclosingFunction returns the signature and type of the function -// enclosing the given position. -func enclosingFunction(path []ast.Node, info *types.Info) *ast.FuncType { - for _, node := range path { - switch t := node.(type) { - case *ast.FuncDecl: - if _, ok := info.Defs[t.Name]; ok { - return t.Type - } - case *ast.FuncLit: - if _, ok := info.Types[t]; ok { - return t.Type - } - } - } - return nil -} diff --git a/gopls/internal/analysis/undeclaredname/undeclared.go b/gopls/internal/golang/undeclared.go similarity index 58% rename from gopls/internal/analysis/undeclaredname/undeclared.go rename to gopls/internal/golang/undeclared.go index 47027be07e4..3d9954639b4 100644 --- a/gopls/internal/analysis/undeclaredname/undeclared.go +++ b/gopls/internal/golang/undeclared.go @@ -2,11 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-package undeclaredname +package golang import ( "bytes" - _ "embed" "fmt" "go/ast" "go/format" @@ -17,71 +16,32 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/ast/astutil" - "golang.org/x/tools/gopls/internal/util/safetoken" + "golang.org/x/tools/gopls/internal/util/typesutil" "golang.org/x/tools/internal/analysisinternal" + "golang.org/x/tools/internal/typesinternal" ) -//go:embed doc.go -var doc string - -var Analyzer = &analysis.Analyzer{ - Name: "undeclaredname", - Doc: analysisinternal.MustExtractDoc(doc, "undeclaredname"), - Requires: []*analysis.Analyzer{}, - Run: run, - RunDespiteErrors: true, - URL: "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/undeclaredname", -} - // The prefix for this error message changed in Go 1.20. var undeclaredNamePrefixes = []string{"undeclared name: ", "undefined: "} -func run(pass *analysis.Pass) (interface{}, error) { - for _, err := range pass.TypeErrors { - runForError(pass, err) - } - return nil, nil -} - -func runForError(pass *analysis.Pass, err types.Error) { +// undeclaredFixTitle generates a code action title for "undeclared name" errors, +// suggesting the creation of the missing variable or function if applicable. +func undeclaredFixTitle(path []ast.Node, errMsg string) string { // Extract symbol name from error. var name string for _, prefix := range undeclaredNamePrefixes { - if !strings.HasPrefix(err.Msg, prefix) { + if !strings.HasPrefix(errMsg, prefix) { continue } - name = strings.TrimPrefix(err.Msg, prefix) - } - if name == "" { - return - } - - // Find file enclosing error. - var file *ast.File - for _, f := range pass.Files { - if f.FileStart <= err.Pos && err.Pos < f.FileEnd { - file = f - break - } - } - if file == nil { - return - } - - // Find path to identifier in the error. 
- path, _ := astutil.PathEnclosingInterval(file, err.Pos, err.Pos) - if len(path) < 2 { - return + name = strings.TrimPrefix(errMsg, prefix) } ident, ok := path[0].(*ast.Ident) if !ok || ident.Name != name { - return + return "" } - - // Skip selector expressions because it might be too complex - // to try and provide a suggested fix for fields and methods. + // TODO: support create undeclared field if _, ok := path[1].(*ast.SelectorExpr); ok { - return + return "" } // Undeclared quick fixes only work in function bodies. @@ -89,16 +49,16 @@ func runForError(pass *analysis.Pass, err types.Error) { for i := range path { if _, inFunc = path[i].(*ast.FuncDecl); inFunc { if i == 0 { - return + return "" } if _, isBody := path[i-1].(*ast.BlockStmt); !isBody { - return + return "" } break } } if !inFunc { - return + return "" } // Offer a fix. @@ -106,22 +66,11 @@ func runForError(pass *analysis.Pass, err types.Error) { if isCallPosition(path) { noun = "function" } - pass.Report(analysis.Diagnostic{ - Pos: err.Pos, - End: err.Pos + token.Pos(len(name)), - Message: err.Msg, - Category: FixCategory, - SuggestedFixes: []analysis.SuggestedFix{{ - Message: fmt.Sprintf("Create %s %q", noun, name), - // No TextEdits => computed by a gopls command - }}, - }) + return fmt.Sprintf("Create %s %s", noun, name) } -const FixCategory = "undeclaredname" // recognized by gopls ApplyFix - -// SuggestedFix computes the edits for the lazy (no-edits) fix suggested by the analyzer. -func SuggestedFix(fset *token.FileSet, start, end token.Pos, content []byte, file *ast.File, pkg *types.Package, info *types.Info) (*token.FileSet, *analysis.SuggestedFix, error) { +// CreateUndeclared generates a suggested declaration for an undeclared variable or function. 
+func CreateUndeclared(fset *token.FileSet, start, end token.Pos, content []byte, file *ast.File, pkg *types.Package, info *types.Info) (*token.FileSet, *analysis.SuggestedFix, error) { pos := start // don't use the end path, _ := astutil.PathEnclosingInterval(file, pos, pos) if len(path) < 2 { @@ -137,34 +86,118 @@ func SuggestedFix(fset *token.FileSet, start, end token.Pos, content []byte, fil if isCallPosition(path) { return newFunctionDeclaration(path, file, pkg, info, fset) } + var ( + firstRef *ast.Ident // We should insert the new declaration before the first occurrence of the undefined ident. + assignTokPos token.Pos + funcDecl = path[len(path)-2].(*ast.FuncDecl) // This is already ensured by [undeclaredFixTitle]. + parent = ast.Node(funcDecl) + ) + // Search from enclosing FuncDecl to path[0], since we can not use := syntax outside function. + // Adds the missing colon after the first undefined symbol + // when it sits in lhs of an AssignStmt. + ast.Inspect(funcDecl, func(n ast.Node) bool { + if n == nil || firstRef != nil { + return false + } + if n, ok := n.(*ast.Ident); ok && n.Name == ident.Name && info.ObjectOf(n) == nil { + firstRef = n + // Only consider adding colon at the first occurrence. + if pos, ok := replaceableAssign(info, n, parent); ok { + assignTokPos = pos + return false + } + } + parent = n + return true + }) + if assignTokPos.IsValid() { + return fset, &analysis.SuggestedFix{ + TextEdits: []analysis.TextEdit{{ + Pos: assignTokPos, + End: assignTokPos, + NewText: []byte(":"), + }}, + }, nil + } - // Get the place to insert the new statement. - insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path) + // firstRef should never be nil, at least one ident at cursor position should be found, + // but be defensive. 
+ if firstRef == nil { + return nil, nil, fmt.Errorf("no identifier found") + } + p, _ := astutil.PathEnclosingInterval(file, firstRef.Pos(), firstRef.Pos()) + insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(p) if insertBeforeStmt == nil { return nil, nil, fmt.Errorf("could not locate insertion point") } - - insertBefore := safetoken.StartPosition(fset, insertBeforeStmt.Pos()).Offset - - // Get the indent to add on the line after the new statement. - // Since this will have a parse error, we can not use format.Source(). - contentBeforeStmt, indent := content[:insertBefore], "\n" - if nl := bytes.LastIndex(contentBeforeStmt, []byte("\n")); nl != -1 { - indent = string(contentBeforeStmt[nl:]) + indent, err := calculateIndentation(content, fset.File(file.FileStart), insertBeforeStmt) + if err != nil { + return nil, nil, err + } + typs := typesutil.TypesFromContext(info, path, start) + if typs == nil { + // Default to 0. + typs = []types.Type{types.Typ[types.Int]} + } + assignStmt := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(ident.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{typesinternal.ZeroExpr(file, pkg, typs[0])}, + } + var buf bytes.Buffer + if err := format.Node(&buf, fset, assignStmt); err != nil { + return nil, nil, err } + newLineIndent := "\n" + indent + assignment := strings.ReplaceAll(buf.String(), "\n", newLineIndent) + newLineIndent - // Create the new local variable statement. 
- newStmt := fmt.Sprintf("%s := %s", ident.Name, indent) return fset, &analysis.SuggestedFix{ - Message: fmt.Sprintf("Create variable %q", ident.Name), - TextEdits: []analysis.TextEdit{{ - Pos: insertBeforeStmt.Pos(), - End: insertBeforeStmt.Pos(), - NewText: []byte(newStmt), - }}, + TextEdits: []analysis.TextEdit{ + { + Pos: insertBeforeStmt.Pos(), + End: insertBeforeStmt.Pos(), + NewText: []byte(assignment), + }, + }, }, nil } +// replaceableAssign returns position of token.ASSIGN if ident meets the following conditions: +// 1) parent node must be an *ast.AssignStmt with Tok set to token.ASSIGN. +// 2) ident must not be self assignment. +// +// For example, we should not add a colon when +// a = a + 1 +// ^ ^ cursor here +func replaceableAssign(info *types.Info, ident *ast.Ident, parent ast.Node) (token.Pos, bool) { + var pos token.Pos + if assign, ok := parent.(*ast.AssignStmt); ok && assign.Tok == token.ASSIGN { + for _, rhs := range assign.Rhs { + if referencesIdent(info, rhs, ident) { + return pos, false + } + } + return assign.TokPos, true + } + return pos, false +} + +// referencesIdent checks whether the given undefined ident appears in the given expression. 
+func referencesIdent(info *types.Info, expr ast.Expr, ident *ast.Ident) bool { + var hasIdent bool + ast.Inspect(expr, func(n ast.Node) bool { + if n == nil { + return false + } + if i, ok := n.(*ast.Ident); ok && i.Name == ident.Name && info.ObjectOf(i) == nil { + hasIdent = true + return false + } + return true + }) + return hasIdent +} + func newFunctionDeclaration(path []ast.Node, file *ast.File, pkg *types.Package, info *types.Info, fset *token.FileSet) (*token.FileSet, *analysis.SuggestedFix, error) { if len(path) < 3 { return nil, nil, fmt.Errorf("unexpected set of enclosing nodes: %v", path) @@ -273,16 +306,23 @@ func newFunctionDeclaration(path []ast.Node, file *ast.File, pkg *types.Package, Names: []*ast.Ident{ ast.NewIdent(name), }, - Type: analysisinternal.TypeExpr(file, pkg, paramTypes[i]), + Type: typesinternal.TypeExpr(file, pkg, paramTypes[i]), + }) + } + + rets := &ast.FieldList{} + retTypes := typesutil.TypesFromContext(info, path[1:], path[1].Pos()) + for _, rt := range retTypes { + rets.List = append(rets.List, &ast.Field{ + Type: typesinternal.TypeExpr(file, pkg, rt), }) } decl := &ast.FuncDecl{ Name: ast.NewIdent(ident.Name), Type: &ast.FuncType{ - Params: params, - // TODO(golang/go#47558): Also handle result - // parameters here based on context of CallExpr. 
+ Params: params, + Results: rets, }, Body: &ast.BlockStmt{ List: []ast.Stmt{ @@ -305,7 +345,6 @@ func newFunctionDeclaration(path []ast.Node, file *ast.File, pkg *types.Package, return nil, nil, err } return fset, &analysis.SuggestedFix{ - Message: fmt.Sprintf("Create function %q", ident.Name), TextEdits: []analysis.TextEdit{{ Pos: pos, End: pos, @@ -352,8 +391,3 @@ func isCallPosition(path []ast.Node) bool { is[*ast.CallExpr](path[1]) && path[1].(*ast.CallExpr).Fun == path[0] } - -func is[T any](x any) bool { - _, ok := x.(T) - return ok -} diff --git a/gopls/internal/golang/util.go b/gopls/internal/golang/util.go index 18f72421a64..be5c7c0a735 100644 --- a/gopls/internal/golang/util.go +++ b/gopls/internal/golang/util.go @@ -12,6 +12,7 @@ import ( "go/types" "regexp" "strings" + "unicode" "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/cache/metadata" @@ -131,6 +132,9 @@ func findFileInDeps(s metadata.Source, mp *metadata.Package, uri protocol.Docume // CollectScopes returns all scopes in an ast path, ordered as innermost scope // first. +// +// TODO(adonovan): move this to golang/completion and simplify to use +// Scopes.Innermost and LookupParent instead. func CollectScopes(info *types.Info, path []ast.Node, pos token.Pos) []*types.Scope { // scopes[i], where i file.Package { + return nil + } + + for _, c := range cg.List { + // TODO: use ast.ParseDirective when available (#68021). + if buildConstraintRe.MatchString(c.Text) { + return c + } + } + } + + return nil +} diff --git a/gopls/internal/mod/code_lens.go b/gopls/internal/mod/code_lens.go index f80063625ff..fcc474a575d 100644 --- a/gopls/internal/mod/code_lens.go +++ b/gopls/internal/mod/code_lens.go @@ -21,10 +21,11 @@ import ( // CodeLensSources returns the sources of code lenses for go.mod files. 
func CodeLensSources() map[settings.CodeLensSource]cache.CodeLensSourceFunc { return map[settings.CodeLensSource]cache.CodeLensSourceFunc{ - settings.CodeLensUpgradeDependency: upgradeLenses, // commands: CheckUpgrades, UpgradeDependency - settings.CodeLensTidy: tidyLens, // commands: Tidy - settings.CodeLensVendor: vendorLens, // commands: Vendor - settings.CodeLensRunGovulncheck: vulncheckLenses, // commands: RunGovulncheck + settings.CodeLensUpgradeDependency: upgradeLenses, // commands: CheckUpgrades, UpgradeDependency + settings.CodeLensTidy: tidyLens, // commands: Tidy + settings.CodeLensVendor: vendorLens, // commands: Vendor + settings.CodeLensVulncheck: vulncheckLenses, // commands: Vulncheck + settings.CodeLensRunGovulncheck: runGovulncheckLenses, // commands: RunGovulncheck } } @@ -112,7 +113,7 @@ func vendorLens(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle) ( cmd := command.NewVendorCommand(title, command.URIArg{URI: uri}) // Change the message depending on whether or not the module already has a // vendor directory. - vendorDir := filepath.Join(filepath.Dir(fh.URI().Path()), "vendor") + vendorDir := filepath.Join(fh.URI().DirPath(), "vendor") if info, _ := os.Stat(vendorDir); info != nil && info.IsDir() { title = "Sync vendor directory" } @@ -162,6 +163,29 @@ func vulncheckLenses(ctx context.Context, snapshot *cache.Snapshot, fh file.Hand return nil, err } + vulncheck := command.NewVulncheckCommand("Run govulncheck", command.VulncheckArgs{ + URI: uri, + Pattern: "./...", + }) + return []protocol.CodeLens{ + {Range: rng, Command: vulncheck}, + }, nil +} + +func runGovulncheckLenses(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle) ([]protocol.CodeLens, error) { + pm, err := snapshot.ParseMod(ctx, fh) + if err != nil || pm.File == nil { + return nil, err + } + // Place the codelenses near the module statement. + // A module may not have the require block, + // but vulnerabilities can exist in standard libraries. 
+ uri := fh.URI() + rng, err := moduleStmtRange(fh, pm) + if err != nil { + return nil, err + } + vulncheck := command.NewRunGovulncheckCommand("Run govulncheck", command.VulncheckArgs{ URI: uri, Pattern: "./...", diff --git a/gopls/internal/protocol/command/command_gen.go b/gopls/internal/protocol/command/command_gen.go index 829a3824bc0..9991c95680e 100644 --- a/gopls/internal/protocol/command/command_gen.go +++ b/gopls/internal/protocol/command/command_gen.go @@ -65,6 +65,7 @@ const ( UpgradeDependency Command = "gopls.upgrade_dependency" Vendor Command = "gopls.vendor" Views Command = "gopls.views" + Vulncheck Command = "gopls.vulncheck" WorkspaceStats Command = "gopls.workspace_stats" ) @@ -110,6 +111,7 @@ var Commands = []Command{ UpgradeDependency, Vendor, Views, + Vulncheck, WorkspaceStats, } @@ -350,6 +352,12 @@ func Dispatch(ctx context.Context, params *protocol.ExecuteCommandParams, s Inte return nil, s.Vendor(ctx, a0) case Views: return s.Views(ctx) + case Vulncheck: + var a0 VulncheckArgs + if err := UnmarshalArgs(params.Arguments, &a0); err != nil { + return nil, err + } + return s.Vulncheck(ctx, a0) case WorkspaceStats: return s.WorkspaceStats(ctx) } @@ -684,6 +692,14 @@ func NewViewsCommand(title string) *protocol.Command { } } +func NewVulncheckCommand(title string, a0 VulncheckArgs) *protocol.Command { + return &protocol.Command{ + Title: title, + Command: Vulncheck.String(), + Arguments: MustMarshalArgs(a0), + } +} + func NewWorkspaceStatsCommand(title string) *protocol.Command { return &protocol.Command{ Title: title, diff --git a/gopls/internal/protocol/command/interface.go b/gopls/internal/protocol/command/interface.go index 258e1008395..0ce3af2aff9 100644 --- a/gopls/internal/protocol/command/interface.go +++ b/gopls/internal/protocol/command/interface.go @@ -16,6 +16,8 @@ package command import ( "context" + "encoding/json" + "fmt" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/vulncheck" @@ -186,16 +188,30 
@@ type Interface interface { // runner. StopProfile(context.Context, StopProfileArgs) (StopProfileResult, error) - // RunGovulncheck: Run vulncheck + // GoVulncheck: run vulncheck synchronously. // // Run vulnerability check (`govulncheck`). // - // This command is asynchronous; clients must wait for the 'end' progress notification. + // This command is synchronous, and returns the govulncheck result. + Vulncheck(context.Context, VulncheckArgs) (VulncheckResult, error) + + // RunGovulncheck: Run vulncheck asynchronously. + // + // Run vulnerability check (`govulncheck`). + // + // This command is asynchronous; clients must wait for the 'end' progress + // notification and then retrieve results using gopls.fetch_vulncheck_result. + // + // Deprecated: clients should call gopls.vulncheck instead, which returns the + // actual vulncheck result. RunGovulncheck(context.Context, VulncheckArgs) (RunVulncheckResult, error) // FetchVulncheckResult: Get known vulncheck result // // Fetch the result of latest vulnerability check (`govulncheck`). + // + // Deprecated: clients should call gopls.vulncheck instead, which returns the + // actual vulncheck result. FetchVulncheckResult(context.Context, URIArg) (map[protocol.DocumentURI]*vulncheck.Result, error) // MemStats: Fetch memory statistics @@ -224,7 +240,7 @@ type Interface interface { // to avoid conflicts with other counters gopls collects. AddTelemetryCounters(context.Context, AddTelemetryCountersArgs) error - // AddTest: add a test for the selected function + // AddTest: add test for the selected function AddTest(context.Context, protocol.Location) (*protocol.WorkspaceEdit, error) // MaybePromptForTelemetry: Prompt user to enable telemetry @@ -506,13 +522,12 @@ type VulncheckArgs struct { type RunVulncheckResult struct { // Token holds the progress token for LSP workDone reporting of the vulncheck // invocation. 
- // - // Deprecated: previously, this was used as a signal to retrieve the result - // using gopls.fetch_vulncheck_result. Clients should ignore this field: - // gopls.vulncheck now runs synchronously, and returns a result in the Result - // field. Token protocol.ProgressToken +} +// GovulncheckResult holds the result of synchronously running the vulncheck +// command. +type VulncheckResult struct { // Result holds the result of running vulncheck. Result *vulncheck.Result } @@ -569,12 +584,68 @@ type AddTelemetryCountersArgs struct { } // ChangeSignatureArgs specifies a "change signature" refactoring to perform. +// +// The new signature is expressed via the NewParams and NewResults fields. The +// elements of these lists each describe a new field of the signature, by +// either referencing a field in the old signature or by defining a new field: +// - If the element is an integer, it references a positional parameter in the +// old signature. +// - If the element is a string, it is parsed as a new field to add. +// +// Suppose we have a function `F(a, b int) (string, error)`. Here are some +// examples of refactoring this signature in practice, eliding the 'Location' +// and 'ResolveEdits' fields. +// - `{ "NewParams": [0], "NewResults": [0, 1] }` removes the second parameter +// - `{ "NewParams": [1, 0], "NewResults": [0, 1] }` flips the parameter order +// - `{ "NewParams": [0, 1, "a int"], "NewResults": [0, 1] }` adds a new field +// - `{ "NewParams": [1, 2], "NewResults": [1] }` drops the `error` result type ChangeSignatureArgs struct { - RemoveParameter protocol.Location + // Location is any range inside the function signature. By convention, this + // is the same location provided in the codeAction request. + Location protocol.Location // a range inside of the function signature, as passed to CodeAction + + // NewParams describes parameters of the new signature. + // An int value references a parameter in the old signature by index. 
+ // A string value describes a new parameter field (e.g. "x int"). + NewParams []ChangeSignatureParam + + // NewResults describes results of the new signature (see above). + // An int value references a result in the old signature by index. + // A string value describes a new result field (e.g. "err error"). + NewResults []ChangeSignatureParam + // Whether to resolve and return the edits. ResolveEdits bool } +// ChangeSignatureParam implements the API described in the doc string of +// [ChangeSignatureArgs]: a union of JSON int | string. +type ChangeSignatureParam struct { + OldIndex int + NewField string +} + +func (a *ChangeSignatureParam) UnmarshalJSON(b []byte) error { + var s string + if err := json.Unmarshal(b, &s); err == nil { + a.NewField = s + return nil + } + var i int + if err := json.Unmarshal(b, &i); err == nil { + a.OldIndex = i + return nil + } + return fmt.Errorf("must be int or string") +} + +func (a ChangeSignatureParam) MarshalJSON() ([]byte, error) { + if a.NewField != "" { + return json.Marshal(a.NewField) + } + return json.Marshal(a.OldIndex) +} + // DiagnoseFilesArgs specifies a set of files for which diagnostics are wanted. type DiagnoseFilesArgs struct { Files []protocol.DocumentURI diff --git a/gopls/internal/protocol/generate/main.go b/gopls/internal/protocol/generate/main.go index de42540a054..ef9bf943606 100644 --- a/gopls/internal/protocol/generate/main.go +++ b/gopls/internal/protocol/generate/main.go @@ -31,7 +31,7 @@ const vscodeRepo = "https://github.com/microsoft/vscode-languageserver-node" // protocol version 3.17.0 (as declared by the metaData.version field). // (Point releases are reflected in the git tag version even when they are cosmetic // and don't change the protocol.) 
-var lspGitRef = "release/protocol/3.17.6-next.2" +var lspGitRef = "release/protocol/3.17.6-next.9" var ( repodir = flag.String("d", "", "directory containing clone of "+vscodeRepo) diff --git a/gopls/internal/protocol/generate/output.go b/gopls/internal/protocol/generate/output.go index 87d6f66cccd..c981bf9c383 100644 --- a/gopls/internal/protocol/generate/output.go +++ b/gopls/internal/protocol/generate/output.go @@ -86,7 +86,7 @@ func genDecl(model *Model, method string, param, result *Type, dir string) { } } -func genCase(model *Model, method string, param, result *Type, dir string) { +func genCase(_ *Model, method string, param, result *Type, dir string) { out := new(bytes.Buffer) fmt.Fprintf(out, "\tcase %q:\n", method) var p string @@ -128,7 +128,7 @@ func genCase(model *Model, method string, param, result *Type, dir string) { } } -func genFunc(model *Model, method string, param, result *Type, dir string, isnotify bool) { +func genFunc(_ *Model, method string, param, result *Type, dir string, isnotify bool) { out := new(bytes.Buffer) var p, r string var goResult string diff --git a/gopls/internal/protocol/generate/tables.go b/gopls/internal/protocol/generate/tables.go index 2036e701d48..c80337f187b 100644 --- a/gopls/internal/protocol/generate/tables.go +++ b/gopls/internal/protocol/generate/tables.go @@ -256,6 +256,8 @@ var methodNames = map[string]string{ "workspace/inlineValue/refresh": "InlineValueRefresh", "workspace/semanticTokens/refresh": "SemanticTokensRefresh", "workspace/symbol": "Symbol", + "workspace/textDocumentContent": "TextDocumentContent", + "workspace/textDocumentContent/refresh": "TextDocumentContentRefresh", "workspace/willCreateFiles": "WillCreateFiles", "workspace/willDeleteFiles": "WillDeleteFiles", "workspace/willRenameFiles": "WillRenameFiles", diff --git a/gopls/internal/protocol/semtok/semtok.go b/gopls/internal/protocol/semtok/semtok.go index 850e234a1b0..fc269c38759 100644 --- a/gopls/internal/protocol/semtok/semtok.go +++ 
b/gopls/internal/protocol/semtok/semtok.go @@ -12,33 +12,79 @@ type Token struct { Line, Start uint32 Len uint32 Type TokenType - Modifiers []string + Modifiers []Modifier } type TokenType string const ( - // These are the tokens defined by LSP 3.17, but a client is + // These are the tokens defined by LSP 3.18, but a client is // free to send its own set; any tokens that the server emits // that are not in this set are simply not encoded in the bitfield. - TokNamespace TokenType = "namespace" - TokType TokenType = "type" - TokInterface TokenType = "interface" - TokTypeParam TokenType = "typeParameter" - TokParameter TokenType = "parameter" - TokVariable TokenType = "variable" - TokMethod TokenType = "method" - TokFunction TokenType = "function" - TokKeyword TokenType = "keyword" - TokComment TokenType = "comment" - TokString TokenType = "string" - TokNumber TokenType = "number" - TokOperator TokenType = "operator" - TokMacro TokenType = "macro" // for templates + // + // If you add or uncomment a token type, document it in + // gopls/doc/features/passive.md#semantic-tokens. 
+ TokComment TokenType = "comment" // for a comment + TokFunction TokenType = "function" // for a function + TokKeyword TokenType = "keyword" // for a keyword + TokLabel TokenType = "label" // for a control label (LSP 3.18) + TokMacro TokenType = "macro" // for text/template tokens + TokMethod TokenType = "method" // for a method + TokNamespace TokenType = "namespace" // for an imported package name + TokNumber TokenType = "number" // for a numeric literal + TokOperator TokenType = "operator" // for an operator + TokParameter TokenType = "parameter" // for a parameter variable + TokString TokenType = "string" // for a string literal + TokType TokenType = "type" // for a type name (plus other uses) + TokTypeParam TokenType = "typeParameter" // for a type parameter + TokVariable TokenType = "variable" // for a var or const + // TokClass TokenType = "class" + // TokDecorator TokenType = "decorator" + // TokEnum TokenType = "enum" + // TokEnumMember TokenType = "enumMember" + // TokEvent TokenType = "event" + // TokInterface TokenType = "interface" + // TokModifier TokenType = "modifier" + // TokProperty TokenType = "property" + // TokRegexp TokenType = "regexp" + // TokStruct TokenType = "struct" +) + +type Modifier string + +const ( + // LSP 3.18 standard modifiers + // As with TokenTypes, clients get only the modifiers they request. + // + // If you add or uncomment a modifier, document it in + // gopls/doc/features/passive.md#semantic-tokens. 
+ ModDefaultLibrary Modifier = "defaultLibrary" // for predeclared symbols + ModDefinition Modifier = "definition" // for the declaring identifier of a symbol + ModReadonly Modifier = "readonly" // for constants (TokVariable) + // ModAbstract Modifier = "abstract" + // ModAsync Modifier = "async" + // ModDeclaration Modifier = "declaration" + // ModDeprecated Modifier = "deprecated" + // ModDocumentation Modifier = "documentation" + // ModModification Modifier = "modification" + // ModStatic Modifier = "static" - // not part of LSP 3.17 (even though JS has labels) - // https://github.com/microsoft/vscode-languageserver-node/issues/1422 - TokLabel TokenType = "label" + // non-standard modifiers + // + // Since the type of a symbol is orthogonal to its kind, + // (e.g. a variable can have function type), + // we use modifiers for the top-level type constructor. + ModArray Modifier = "array" + ModBool Modifier = "bool" + ModChan Modifier = "chan" + ModInterface Modifier = "interface" + ModMap Modifier = "map" + ModNumber Modifier = "number" + ModPointer Modifier = "pointer" + ModSignature Modifier = "signature" // for function types + ModSlice Modifier = "slice" + ModString Modifier = "string" + ModStruct Modifier = "struct" ) // Encode returns the LSP encoding of a sequence of tokens. @@ -62,9 +108,9 @@ func Encode( typeMap[TokenType(t)] = i } - modMap := make(map[string]int) + modMap := make(map[Modifier]int) for i, m := range modifiers { - modMap[m] = 1 << uint(i) // go 1.12 compatibility + modMap[Modifier(m)] = 1 << i } // each semantic token needs five values diff --git a/gopls/internal/protocol/tsclient.go b/gopls/internal/protocol/tsclient.go index 3f860d5351a..8fd322d424a 100644 --- a/gopls/internal/protocol/tsclient.go +++ b/gopls/internal/protocol/tsclient.go @@ -6,8 +6,8 @@ package protocol -// Code generated from protocol/metaModel.json at ref release/protocol/3.17.6-next.2 (hash 654dc9be6673c61476c28fda604406279c3258d7). 
-// https://github.com/microsoft/vscode-languageserver-node/blob/release/protocol/3.17.6-next.2/protocol/metaModel.json +// Code generated from protocol/metaModel.json at ref release/protocol/3.17.6-next.9 (hash c94395b5da53729e6dff931293b051009ccaaaa4). +// https://github.com/microsoft/vscode-languageserver-node/blob/release/protocol/3.17.6-next.9/protocol/metaModel.json // LSP metaData.version = 3.17.0. import ( @@ -55,6 +55,8 @@ type Client interface { InlineValueRefresh(context.Context) error // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#workspace_semanticTokens_refresh SemanticTokensRefresh(context.Context) error + // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#workspace_textDocumentContent_refresh + TextDocumentContentRefresh(context.Context, *TextDocumentContentRefreshParams) error // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#workspace_workspaceFolders WorkspaceFolders(context.Context) ([]WorkspaceFolder, error) } @@ -202,6 +204,14 @@ func clientDispatch(ctx context.Context, client Client, reply jsonrpc2.Replier, err := client.SemanticTokensRefresh(ctx) return true, reply(ctx, nil, err) + case "workspace/textDocumentContent/refresh": + var params TextDocumentContentRefreshParams + if err := UnmarshalJSON(r.Params(), ¶ms); err != nil { + return true, sendParseError(ctx, reply, err) + } + err := client.TextDocumentContentRefresh(ctx, ¶ms) + return true, reply(ctx, nil, err) + case "workspace/workspaceFolders": resp, err := client.WorkspaceFolders(ctx) if err != nil { @@ -287,6 +297,9 @@ func (s *clientDispatcher) InlineValueRefresh(ctx context.Context) error { func (s *clientDispatcher) SemanticTokensRefresh(ctx context.Context) error { return s.sender.Call(ctx, "workspace/semanticTokens/refresh", nil, nil) } +func (s *clientDispatcher) TextDocumentContentRefresh(ctx context.Context, params 
*TextDocumentContentRefreshParams) error { + return s.sender.Call(ctx, "workspace/textDocumentContent/refresh", params, nil) +} func (s *clientDispatcher) WorkspaceFolders(ctx context.Context) ([]WorkspaceFolder, error) { var result []WorkspaceFolder if err := s.sender.Call(ctx, "workspace/workspaceFolders", nil, &result); err != nil { diff --git a/gopls/internal/protocol/tsjson.go b/gopls/internal/protocol/tsjson.go index 7f77ffa999f..0ee4c464167 100644 --- a/gopls/internal/protocol/tsjson.go +++ b/gopls/internal/protocol/tsjson.go @@ -6,8 +6,8 @@ package protocol -// Code generated from protocol/metaModel.json at ref release/protocol/3.17.6-next.2 (hash 654dc9be6673c61476c28fda604406279c3258d7). -// https://github.com/microsoft/vscode-languageserver-node/blob/release/protocol/3.17.6-next.2/protocol/metaModel.json +// Code generated from protocol/metaModel.json at ref release/protocol/3.17.6-next.9 (hash c94395b5da53729e6dff931293b051009ccaaaa4). +// https://github.com/microsoft/vscode-languageserver-node/blob/release/protocol/3.17.6-next.9/protocol/metaModel.json // LSP metaData.version = 3.17.0. 
import "encoding/json" @@ -1932,12 +1932,14 @@ func (t Or_TextDocumentEdit_edits_Elem) MarshalJSON() ([]byte, error) { switch x := t.Value.(type) { case AnnotatedTextEdit: return json.Marshal(x) + case SnippetTextEdit: + return json.Marshal(x) case TextEdit: return json.Marshal(x) case nil: return []byte("null"), nil } - return nil, fmt.Errorf("type %T not one of [AnnotatedTextEdit TextEdit]", t) + return nil, fmt.Errorf("type %T not one of [AnnotatedTextEdit SnippetTextEdit TextEdit]", t) } func (t *Or_TextDocumentEdit_edits_Elem) UnmarshalJSON(x []byte) error { @@ -1950,12 +1952,17 @@ func (t *Or_TextDocumentEdit_edits_Elem) UnmarshalJSON(x []byte) error { t.Value = h0 return nil } - var h1 TextEdit + var h1 SnippetTextEdit if err := json.Unmarshal(x, &h1); err == nil { t.Value = h1 return nil } - return &UnmarshalError{"unmarshal failed to match one of [AnnotatedTextEdit TextEdit]"} + var h2 TextEdit + if err := json.Unmarshal(x, &h2); err == nil { + t.Value = h2 + return nil + } + return &UnmarshalError{"unmarshal failed to match one of [AnnotatedTextEdit SnippetTextEdit TextEdit]"} } func (t Or_TextDocumentFilter) MarshalJSON() ([]byte, error) { @@ -2099,6 +2106,36 @@ func (t *Or_WorkspaceEdit_documentChanges_Elem) UnmarshalJSON(x []byte) error { return &UnmarshalError{"unmarshal failed to match one of [CreateFile DeleteFile RenameFile TextDocumentEdit]"} } +func (t Or_WorkspaceOptions_textDocumentContent) MarshalJSON() ([]byte, error) { + switch x := t.Value.(type) { + case TextDocumentContentOptions: + return json.Marshal(x) + case TextDocumentContentRegistrationOptions: + return json.Marshal(x) + case nil: + return []byte("null"), nil + } + return nil, fmt.Errorf("type %T not one of [TextDocumentContentOptions TextDocumentContentRegistrationOptions]", t) +} + +func (t *Or_WorkspaceOptions_textDocumentContent) UnmarshalJSON(x []byte) error { + if string(x) == "null" { + t.Value = nil + return nil + } + var h0 TextDocumentContentOptions + if err := 
json.Unmarshal(x, &h0); err == nil { + t.Value = h0 + return nil + } + var h1 TextDocumentContentRegistrationOptions + if err := json.Unmarshal(x, &h1); err == nil { + t.Value = h1 + return nil + } + return &UnmarshalError{"unmarshal failed to match one of [TextDocumentContentOptions TextDocumentContentRegistrationOptions]"} +} + func (t Or_textDocument_declaration) MarshalJSON() ([]byte, error) { switch x := t.Value.(type) { case Declaration: diff --git a/gopls/internal/protocol/tsprotocol.go b/gopls/internal/protocol/tsprotocol.go index b0b01a4b69a..198aeae7d01 100644 --- a/gopls/internal/protocol/tsprotocol.go +++ b/gopls/internal/protocol/tsprotocol.go @@ -6,8 +6,8 @@ package protocol -// Code generated from protocol/metaModel.json at ref release/protocol/3.17.6-next.2 (hash 654dc9be6673c61476c28fda604406279c3258d7). -// https://github.com/microsoft/vscode-languageserver-node/blob/release/protocol/3.17.6-next.2/protocol/metaModel.json +// Code generated from protocol/metaModel.json at ref release/protocol/3.17.6-next.9 (hash c94395b5da53729e6dff931293b051009ccaaaa4). +// https://github.com/microsoft/vscode-languageserver-node/blob/release/protocol/3.17.6-next.9/protocol/metaModel.json // LSP metaData.version = 3.17.0. import "encoding/json" @@ -33,6 +33,11 @@ type ApplyWorkspaceEditParams struct { Label string `json:"label,omitempty"` // The edits to apply. Edit WorkspaceEdit `json:"edit"` + // Additional data about the edit. + // + // @since 3.18.0 + // @proposed + Metadata *WorkspaceEditMetadata `json:"metadata,omitempty"` } // The result returned from the apply workspace edit request. 
@@ -216,7 +221,6 @@ type ChangeAnnotation struct { // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#changeAnnotationIdentifier type ChangeAnnotationIdentifier = string // (alias) // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#changeAnnotationsSupportOptions type ChangeAnnotationsSupportOptions struct { @@ -249,7 +253,6 @@ type ClientCapabilities struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientCodeActionKindOptions type ClientCodeActionKindOptions struct { @@ -261,7 +264,6 @@ type ClientCodeActionKindOptions struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientCodeActionLiteralOptions type ClientCodeActionLiteralOptions struct { @@ -271,7 +273,6 @@ type ClientCodeActionLiteralOptions struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientCodeActionResolveOptions type ClientCodeActionResolveOptions struct { @@ -280,7 +281,14 @@ type ClientCodeActionResolveOptions struct { } // @since 3.18.0 -// @proposed +// +// See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientCodeLensResolveOptions +type ClientCodeLensResolveOptions struct { + // The properties that a client can resolve lazily. 
+ Properties []string `json:"properties"` +} + +// @since 3.18.0 // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientCompletionItemInsertTextModeOptions type ClientCompletionItemInsertTextModeOptions struct { @@ -288,7 +296,6 @@ type ClientCompletionItemInsertTextModeOptions struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientCompletionItemOptions type ClientCompletionItemOptions struct { @@ -340,7 +347,6 @@ type ClientCompletionItemOptions struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientCompletionItemOptionsKind type ClientCompletionItemOptionsKind struct { @@ -356,7 +362,6 @@ type ClientCompletionItemOptionsKind struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientCompletionItemResolveOptions type ClientCompletionItemResolveOptions struct { @@ -365,7 +370,6 @@ type ClientCompletionItemResolveOptions struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientDiagnosticsTagOptions type ClientDiagnosticsTagOptions struct { @@ -374,7 +378,6 @@ type ClientDiagnosticsTagOptions struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientFoldingRangeKindOptions type ClientFoldingRangeKindOptions struct { @@ -386,7 +389,6 @@ type ClientFoldingRangeKindOptions struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientFoldingRangeOptions type ClientFoldingRangeOptions struct { @@ -401,7 +403,6 @@ type ClientFoldingRangeOptions struct { // // @since 3.15.0 // @since 
3.18.0 ClientInfo type name added. -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientInfo type ClientInfo struct { @@ -412,7 +413,6 @@ type ClientInfo struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientInlayHintResolveOptions type ClientInlayHintResolveOptions struct { @@ -421,7 +421,6 @@ type ClientInlayHintResolveOptions struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientSemanticTokensRequestFullDelta type ClientSemanticTokensRequestFullDelta struct { @@ -431,7 +430,6 @@ type ClientSemanticTokensRequestFullDelta struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientSemanticTokensRequestOptions type ClientSemanticTokensRequestOptions struct { @@ -444,7 +442,6 @@ type ClientSemanticTokensRequestOptions struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientShowMessageActionItemOptions type ClientShowMessageActionItemOptions struct { @@ -455,7 +452,6 @@ type ClientShowMessageActionItemOptions struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientSignatureInformationOptions type ClientSignatureInformationOptions struct { @@ -479,7 +475,6 @@ type ClientSignatureInformationOptions struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientSignatureParameterInformationOptions type ClientSignatureParameterInformationOptions struct { @@ -491,7 +486,6 @@ type ClientSignatureParameterInformationOptions struct { } // @since 3.18.0 -// @proposed 
// // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientSymbolKindOptions type ClientSymbolKindOptions struct { @@ -507,7 +501,6 @@ type ClientSymbolKindOptions struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientSymbolResolveOptions type ClientSymbolResolveOptions struct { @@ -517,7 +510,6 @@ type ClientSymbolResolveOptions struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#clientSymbolTagOptions type ClientSymbolTagOptions struct { @@ -649,7 +641,6 @@ type CodeActionContext struct { // Captures why the code action is currently disabled. // // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#codeActionDisabled type CodeActionDisabled struct { @@ -776,6 +767,11 @@ type CodeLens struct { type CodeLensClientCapabilities struct { // Whether code lens supports dynamic registration. DynamicRegistration bool `json:"dynamicRegistration,omitempty"` + // Whether the client supports resolving additional code lens + // properties via a separate `codeLens/resolve` request. + // + // @since 3.18.0 + ResolveSupport *ClientCodeLensResolveOptions `json:"resolveSupport,omitempty"` } // Code Lens provider options of a {@link CodeLensRequest}. @@ -1116,7 +1112,6 @@ type CompletionItemLabelDetails struct { type CompletionItemTag uint32 // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#completionItemTagOptions type CompletionItemTagOptions struct { @@ -1413,8 +1408,9 @@ type DeleteFilesParams struct { type Diagnostic struct { // The range at which the message applies Range Range `json:"range"` - // The diagnostic's severity. Can be omitted. 
If omitted it is up to the - // client to interpret diagnostics as error, warning, info or hint. + // The diagnostic's severity. To avoid interpretation mismatches when a + // server is used with different clients it is highly recommended that servers + // always provide a severity value. Severity DiagnosticSeverity `json:"severity,omitempty"` // The diagnostic's code, which usually appear in the user interface. Code interface{} `json:"code,omitempty"` @@ -1455,6 +1451,7 @@ type DiagnosticClientCapabilities struct { DynamicRegistration bool `json:"dynamicRegistration,omitempty"` // Whether the clients supports related documents for document diagnostic pulls. RelatedDocumentSupport bool `json:"relatedDocumentSupport,omitempty"` + DiagnosticsCapabilities } // Diagnostic options. @@ -1532,6 +1529,29 @@ type DiagnosticWorkspaceClientCapabilities struct { RefreshSupport bool `json:"refreshSupport,omitempty"` } +// General diagnostics capabilities for pull and push model. +// +// See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#diagnosticsCapabilities +type DiagnosticsCapabilities struct { + // Whether the clients accepts diagnostics with related information. + RelatedInformation bool `json:"relatedInformation,omitempty"` + // Client supports the tag property to provide meta data about a diagnostic. + // Clients supporting tags have to handle unknown tags gracefully. + // + // @since 3.15.0 + TagSupport *ClientDiagnosticsTagOptions `json:"tagSupport,omitempty"` + // Client supports a codeDescription property + // + // @since 3.16.0 + CodeDescriptionSupport bool `json:"codeDescriptionSupport,omitempty"` + // Whether code action supports the `data` property which is + // preserved between a `textDocument/publishDiagnostics` and + // `textDocument/codeAction` request. 
+ // + // @since 3.16.0 + DataSupport bool `json:"dataSupport,omitempty"` +} + // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#didChangeConfigurationClientCapabilities type DidChangeConfigurationClientCapabilities struct { // Did change configuration notification supports dynamic registration. @@ -2126,7 +2146,6 @@ type DocumentSymbolRegistrationOptions struct { // Edit range variant that includes ranges for insert and replace operations. // // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#editRangeWithInsertReplace type EditRangeWithInsertReplace struct { @@ -3135,7 +3154,6 @@ type LocationLink struct { // Location with only uri and does not include range. // // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#locationUriOnly type LocationUriOnly struct { @@ -3191,7 +3209,6 @@ type MarkdownClientCapabilities struct { // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#markedString type MarkedString = Or_MarkedString // (alias) // @since 3.18.0 -// @proposed // @deprecated use MarkupContent instead. // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#markedStringWithLanguage @@ -3346,7 +3363,6 @@ type NotebookCellArrayChange struct { type NotebookCellKind uint32 // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#notebookCellLanguage type NotebookCellLanguage struct { @@ -3397,7 +3413,6 @@ type NotebookDocument struct { // Structural changes to cells in a notebook document. 
// // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#notebookDocumentCellChangeStructure type NotebookDocumentCellChangeStructure struct { @@ -3412,7 +3427,6 @@ type NotebookDocumentCellChangeStructure struct { // Cell changes to a notebook document. // // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#notebookDocumentCellChanges type NotebookDocumentCellChanges struct { @@ -3429,7 +3443,6 @@ type NotebookDocumentCellChanges struct { // Content changes to a cell in a notebook document. // // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#notebookDocumentCellContentChanges type NotebookDocumentCellContentChanges struct { @@ -3474,7 +3487,6 @@ type NotebookDocumentFilter = Or_NotebookDocumentFilter // (alias) // A notebook document filter where `notebookType` is required field. // // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#notebookDocumentFilterNotebookType type NotebookDocumentFilterNotebookType struct { @@ -3483,13 +3495,12 @@ type NotebookDocumentFilterNotebookType struct { // A Uri {@link Uri.scheme scheme}, like `file` or `untitled`. Scheme string `json:"scheme,omitempty"` // A glob pattern. - Pattern string `json:"pattern,omitempty"` + Pattern *GlobPattern `json:"pattern,omitempty"` } // A notebook document filter where `pattern` is required field. // // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#notebookDocumentFilterPattern type NotebookDocumentFilterPattern struct { @@ -3498,13 +3509,12 @@ type NotebookDocumentFilterPattern struct { // A Uri {@link Uri.scheme scheme}, like `file` or `untitled`. 
Scheme string `json:"scheme,omitempty"` // A glob pattern. - Pattern string `json:"pattern"` + Pattern GlobPattern `json:"pattern"` } // A notebook document filter where `scheme` is required field. // // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#notebookDocumentFilterScheme type NotebookDocumentFilterScheme struct { @@ -3513,11 +3523,10 @@ type NotebookDocumentFilterScheme struct { // A Uri {@link Uri.scheme scheme}, like `file` or `untitled`. Scheme string `json:"scheme"` // A glob pattern. - Pattern string `json:"pattern,omitempty"` + Pattern *GlobPattern `json:"pattern,omitempty"` } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#notebookDocumentFilterWithCells type NotebookDocumentFilterWithCells struct { @@ -3530,7 +3539,6 @@ type NotebookDocumentFilterWithCells struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#notebookDocumentFilterWithNotebook type NotebookDocumentFilterWithNotebook struct { @@ -3913,7 +3921,7 @@ type Or_SignatureInformation_documentation struct { Value interface{} `json:"value"` } -// created for Or [AnnotatedTextEdit TextEdit] +// created for Or [AnnotatedTextEdit SnippetTextEdit TextEdit] type Or_TextDocumentEdit_edits_Elem struct { Value interface{} `json:"value"` } @@ -3938,6 +3946,11 @@ type Or_WorkspaceEdit_documentChanges_Elem struct { Value interface{} `json:"value"` } +// created for Or [TextDocumentContentOptions TextDocumentContentRegistrationOptions] +type Or_WorkspaceOptions_textDocumentContent struct { + Value interface{} `json:"value"` +} + // created for Or [Declaration []DeclarationLink] type Or_textDocument_declaration struct { Value interface{} `json:"value"` @@ -4054,7 +4067,6 @@ type Position struct { type PositionEncodingKind string // @since 3.18.0 -// 
@proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#prepareRenameDefaultBehavior type PrepareRenameDefaultBehavior struct { @@ -4068,7 +4080,6 @@ type PrepareRenameParams struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#prepareRenamePlaceholder type PrepareRenamePlaceholder struct { @@ -4120,28 +4131,12 @@ type ProgressToken = interface{} // (alias) // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#publishDiagnosticsClientCapabilities type PublishDiagnosticsClientCapabilities struct { - // Whether the clients accepts diagnostics with related information. - RelatedInformation bool `json:"relatedInformation,omitempty"` - // Client supports the tag property to provide meta data about a diagnostic. - // Clients supporting tags have to handle unknown tags gracefully. - // - // @since 3.15.0 - TagSupport *ClientDiagnosticsTagOptions `json:"tagSupport,omitempty"` // Whether the client interprets the version property of the // `textDocument/publishDiagnostics` notification's parameter. // // @since 3.15.0 VersionSupport bool `json:"versionSupport,omitempty"` - // Client supports a codeDescription property - // - // @since 3.16.0 - CodeDescriptionSupport bool `json:"codeDescriptionSupport,omitempty"` - // Whether code action supports the `data` property which is - // preserved between a `textDocument/publishDiagnostics` and - // `textDocument/codeAction` request. - // - // @since 3.16.0 - DataSupport bool `json:"dataSupport,omitempty"` + DiagnosticsCapabilities } // The publish diagnostic notification's parameters. 
@@ -4595,7 +4590,6 @@ type SemanticTokensEdit struct { // Semantic tokens options to support deltas for full documents // // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#semanticTokensFullDelta type SemanticTokensFullDelta struct { @@ -4794,7 +4788,6 @@ type ServerCapabilities struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#serverCompletionItemOptions type ServerCompletionItemOptions struct { @@ -4810,7 +4803,6 @@ type ServerCompletionItemOptions struct { // // @since 3.15.0 // @since 3.18.0 ServerInfo type name added. -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#serverInfo type ServerInfo struct { @@ -5047,9 +5039,23 @@ type SignatureInformation struct { ActiveParameter uint32 `json:"activeParameter,omitempty"` } +// An interactive text edit. +// // @since 3.18.0 // @proposed // +// See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#snippetTextEdit +type SnippetTextEdit struct { + // The range of the text document to be manipulated. + Range Range `json:"range"` + // The snippet to be inserted. + Snippet StringValue `json:"snippet"` + // The actual identifier of the snippet edit. + AnnotationID *ChangeAnnotationIdentifier `json:"annotationId,omitempty"` +} + +// @since 3.18.0 +// // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#staleRequestSupportOptions type StaleRequestSupportOptions struct { // The client will actively cancel the request. 
@@ -5247,7 +5253,6 @@ type TextDocumentClientCapabilities struct { // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#textDocumentContentChangeEvent type TextDocumentContentChangeEvent = TextDocumentContentChangePartial // (alias) // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#textDocumentContentChangePartial type TextDocumentContentChangePartial struct { @@ -5262,7 +5267,6 @@ type TextDocumentContentChangePartial struct { } // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#textDocumentContentChangeWholeDocument type TextDocumentContentChangeWholeDocument struct { @@ -5270,6 +5274,61 @@ type TextDocumentContentChangeWholeDocument struct { Text string `json:"text"` } +// Client capabilities for a text document content provider. +// +// @since 3.18.0 +// @proposed +// +// See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#textDocumentContentClientCapabilities +type TextDocumentContentClientCapabilities struct { + // Text document content provider supports dynamic registration. + DynamicRegistration bool `json:"dynamicRegistration,omitempty"` +} + +// Text document content provider options. +// +// @since 3.18.0 +// @proposed +// +// See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#textDocumentContentOptions +type TextDocumentContentOptions struct { + // The scheme for which the server provides content. + Scheme string `json:"scheme"` +} + +// Parameters for the `workspace/textDocumentContent` request. +// +// @since 3.18.0 +// @proposed +// +// See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#textDocumentContentParams +type TextDocumentContentParams struct { + // The uri of the text document. 
+ URI DocumentURI `json:"uri"` +} + +// Parameters for the `workspace/textDocumentContent/refresh` request. +// +// @since 3.18.0 +// @proposed +// +// See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#textDocumentContentRefreshParams +type TextDocumentContentRefreshParams struct { + // The uri of the text document to refresh. + URI DocumentURI `json:"uri"` +} + +// Text document content provider registration options. +// +// @since 3.18.0 +// @proposed +// +// See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#textDocumentContentRegistrationOptions +type TextDocumentContentRegistrationOptions struct { + TextDocumentContentOptions + StaticRegistrationOptions +} + // Describes textual changes on a text document. A TextDocumentEdit describes all changes // on a document version Si and after they are applied move the document to version Si+1. // So the creator of a TextDocumentEdit doesn't need to sort the array of edits or do any @@ -5283,6 +5342,9 @@ type TextDocumentEdit struct { // // @since 3.16.0 - support for AnnotatedTextEdit. This is guarded using a // client capability. + // + // @since 3.18.0 - support for SnippetTextEdit. This is guarded using a + // client capability. Edits []Or_TextDocumentEdit_edits_Elem `json:"edits"` } @@ -5309,7 +5371,6 @@ type TextDocumentFilter = Or_TextDocumentFilter // (alias) // A document filter where `language` is required field. // // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#textDocumentFilterLanguage type TextDocumentFilterLanguage struct { @@ -5318,13 +5379,14 @@ type TextDocumentFilterLanguage struct { // A Uri {@link Uri.scheme scheme}, like `file` or `untitled`. Scheme string `json:"scheme,omitempty"` // A glob pattern, like **​/*.{ts,js}. See TextDocumentFilter for examples. 
- Pattern string `json:"pattern,omitempty"` + // + // @since 3.18.0 - support for relative patterns. + Pattern *GlobPattern `json:"pattern,omitempty"` } // A document filter where `pattern` is required field. // // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#textDocumentFilterPattern type TextDocumentFilterPattern struct { @@ -5333,13 +5395,14 @@ type TextDocumentFilterPattern struct { // A Uri {@link Uri.scheme scheme}, like `file` or `untitled`. Scheme string `json:"scheme,omitempty"` // A glob pattern, like **​/*.{ts,js}. See TextDocumentFilter for examples. - Pattern string `json:"pattern"` + // + // @since 3.18.0 - support for relative patterns. + Pattern GlobPattern `json:"pattern"` } // A document filter where `scheme` is required field. // // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#textDocumentFilterScheme type TextDocumentFilterScheme struct { @@ -5348,7 +5411,9 @@ type TextDocumentFilterScheme struct { // A Uri {@link Uri.scheme scheme}, like `file` or `untitled`. Scheme string `json:"scheme"` // A glob pattern, like **​/*.{ts,js}. See TextDocumentFilter for examples. - Pattern string `json:"pattern,omitempty"` + // + // @since 3.18.0 - support for relative patterns. + Pattern *GlobPattern `json:"pattern,omitempty"` } // A literal to identify a text document in the client. @@ -5822,6 +5887,11 @@ type WorkspaceClientCapabilities struct { // @since 3.18.0 // @proposed FoldingRange *FoldingRangeWorkspaceClientCapabilities `json:"foldingRange,omitempty"` + // Capabilities specific to the `workspace/textDocumentContent` request. + // + // @since 3.18.0 + // @proposed + TextDocumentContent *TextDocumentContentClientCapabilities `json:"textDocumentContent,omitempty"` } // Parameters of the workspace diagnostic request. 
@@ -5927,6 +5997,27 @@ type WorkspaceEditClientCapabilities struct { // // @since 3.16.0 ChangeAnnotationSupport *ChangeAnnotationsSupportOptions `json:"changeAnnotationSupport,omitempty"` + // Whether the client supports `WorkspaceEditMetadata` in `WorkspaceEdit`s. + // + // @since 3.18.0 + // @proposed + MetadataSupport bool `json:"metadataSupport,omitempty"` + // Whether the client supports snippets as text edits. + // + // @since 3.18.0 + // @proposed + SnippetEditSupport bool `json:"snippetEditSupport,omitempty"` +} + +// Additional data about a workspace edit. +// +// @since 3.18.0 +// @proposed +// +// See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#workspaceEditMetadata +type WorkspaceEditMetadata struct { + // Signal to the editor that this edit is a refactoring. + IsRefactoring bool `json:"isRefactoring,omitempty"` } // A workspace folder inside a client. @@ -6007,7 +6098,6 @@ type WorkspaceFullDocumentDiagnosticReport struct { // Defines workspace specific capabilities of the server. // // @since 3.18.0 -// @proposed // // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#workspaceOptions type WorkspaceOptions struct { @@ -6019,6 +6109,11 @@ type WorkspaceOptions struct { // // @since 3.16.0 FileOperations *FileOperationOptions `json:"fileOperations,omitempty"` + // The server supports the `workspace/textDocumentContent` request. + // + // @since 3.18.0 + // @proposed + TextDocumentContent *Or_WorkspaceOptions_textDocumentContent `json:"textDocumentContent,omitempty"` } // A special workspace symbol that supports locations without a range. @@ -6080,6 +6175,12 @@ type WorkspaceSymbolOptions struct { type WorkspaceSymbolParams struct { // A query string to filter symbols by. Clients may send an empty // string here to request all symbols. 
+ // + // The `query`-parameter should be interpreted in a *relaxed way* as editors + // will apply their own highlighting and scoring on the results. A good rule + // of thumb is to match case-insensitive and to simply check that the + // characters of *query* appear in their order in a candidate symbol. + // Servers shouldn't use prefix, substring, or similar strict matching. Query string `json:"query"` WorkDoneProgressParams PartialResultParams @@ -6465,7 +6566,7 @@ const ( // If a client decides that a result is not of any use anymore // the client should cancel the request. ContentModified LSPErrorCodes = -32801 - // The client has canceled a request and a server as detected + // The client has canceled a request and a server has detected // the cancel. RequestCancelled LSPErrorCodes = -32800 // Predefined Language kinds @@ -6497,6 +6598,7 @@ const ( LangGo LanguageKind = "go" LangGroovy LanguageKind = "groovy" LangHandlebars LanguageKind = "handlebars" + LangHaskell LanguageKind = "haskell" LangHTML LanguageKind = "html" LangIni LanguageKind = "ini" LangJava LanguageKind = "java" @@ -6648,6 +6750,8 @@ const ( OperatorType SemanticTokenTypes = "operator" // @since 3.17.0 DecoratorType SemanticTokenTypes = "decorator" + // @since 3.18.0 + LabelType SemanticTokenTypes = "label" // How a signature help was triggered. // // @since 3.15.0 diff --git a/gopls/internal/protocol/tsserver.go b/gopls/internal/protocol/tsserver.go index 4e7df50cae1..51ddad9ec1f 100644 --- a/gopls/internal/protocol/tsserver.go +++ b/gopls/internal/protocol/tsserver.go @@ -6,8 +6,8 @@ package protocol -// Code generated from protocol/metaModel.json at ref release/protocol/3.17.6-next.2 (hash 654dc9be6673c61476c28fda604406279c3258d7). -// https://github.com/microsoft/vscode-languageserver-node/blob/release/protocol/3.17.6-next.2/protocol/metaModel.json +// Code generated from protocol/metaModel.json at ref release/protocol/3.17.6-next.9 (hash c94395b5da53729e6dff931293b051009ccaaaa4). 
+// https://github.com/microsoft/vscode-languageserver-node/blob/release/protocol/3.17.6-next.9/protocol/metaModel.json // LSP metaData.version = 3.17.0. import ( @@ -155,6 +155,8 @@ type Server interface { ExecuteCommand(context.Context, *ExecuteCommandParams) (interface{}, error) // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#workspace_symbol Symbol(context.Context, *WorkspaceSymbolParams) ([]SymbolInformation, error) + // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#workspace_textDocumentContent + TextDocumentContent(context.Context, *TextDocumentContentParams) (*string, error) // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#workspace_willCreateFiles WillCreateFiles(context.Context, *CreateFilesParams) (*WorkspaceEdit, error) // See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification#workspace_willDeleteFiles @@ -856,6 +858,17 @@ func serverDispatch(ctx context.Context, server Server, reply jsonrpc2.Replier, } return true, reply(ctx, resp, nil) + case "workspace/textDocumentContent": + var params TextDocumentContentParams + if err := UnmarshalJSON(r.Params(), &params); err != nil { + return true, sendParseError(ctx, reply, err) + } + resp, err := server.TextDocumentContent(ctx, &params) + if err != nil { + return true, reply(ctx, nil, err) + } + return true, reply(ctx, resp, nil) + case "workspace/willCreateFiles": var params CreateFilesParams if err := UnmarshalJSON(r.Params(), &params); err != nil { @@ -1304,6 +1317,13 @@ func (s *serverDispatcher) Symbol(ctx context.Context, params *WorkspaceSymbolPa } return result, nil } +func (s *serverDispatcher) TextDocumentContent(ctx context.Context, params *TextDocumentContentParams) (*string, error) { + var result *string + if err := s.sender.Call(ctx, "workspace/textDocumentContent", params, &result); err != nil { + return nil, err + } + 
return result, nil +} func (s *serverDispatcher) WillCreateFiles(ctx context.Context, params *CreateFilesParams) (*WorkspaceEdit, error) { var result *WorkspaceEdit if err := s.sender.Call(ctx, "workspace/willCreateFiles", params, &result); err != nil { diff --git a/gopls/internal/protocol/uri.go b/gopls/internal/protocol/uri.go index 86775b065f5..e4252909835 100644 --- a/gopls/internal/protocol/uri.go +++ b/gopls/internal/protocol/uri.go @@ -90,7 +90,13 @@ func (uri DocumentURI) Path() string { func (uri DocumentURI) Dir() DocumentURI { // This function could be more efficiently implemented by avoiding any call // to Path(), but at least consolidates URI manipulation. - return URIFromPath(filepath.Dir(uri.Path())) + return URIFromPath(uri.DirPath()) +} + +// DirPath returns the file path to the directory containing this URI, which +// must be a file URI. +func (uri DocumentURI) DirPath() string { + return filepath.Dir(uri.Path()) } // Encloses reports whether uri's path, considered as a sequence of segments, diff --git a/gopls/internal/server/code_action.go b/gopls/internal/server/code_action.go index e00e343850d..2e1c83f407f 100644 --- a/gopls/internal/server/code_action.go +++ b/gopls/internal/server/code_action.go @@ -265,7 +265,10 @@ func (s *server) codeActionsMatchingDiagnostics(ctx context.Context, uri protoco var actions []protocol.CodeAction var unbundled []protocol.Diagnostic // diagnostics without bundled code actions in their Data field for _, pd := range pds { - bundled := cache.BundledLazyFixes(pd) + bundled, err := cache.BundledLazyFixes(pd) + if err != nil { + return nil, err + } if len(bundled) > 0 { for _, fix := range bundled { if enabled(fix.Kind) { diff --git a/gopls/internal/server/command.go b/gopls/internal/server/command.go index 403eadf0d2c..9995d02117e 100644 --- a/gopls/internal/server/command.go +++ b/gopls/internal/server/command.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "log" "os" "path/filepath" "regexp" @@ -49,6 +50,15 @@ 
func (s *server) ExecuteCommand(ctx context.Context, params *protocol.ExecuteCom ctx, done := event.Start(ctx, "lsp.Server.executeCommand") defer done() + // For test synchronization, always create a progress notification. + // + // This may be in addition to user-facing progress notifications created in + // the course of command execution. + if s.Options().VerboseWorkDoneProgress { + work := s.progress.Start(ctx, params.Command, "Verbose: running command...", nil, nil) + defer work.End(ctx, "Done.") + } + var found bool for _, name := range s.Options().SupportedCommands { if name == params.Command { @@ -256,7 +266,11 @@ func (h *commandHandler) Packages(ctx context.Context, args command.PackagesArgs } func (h *commandHandler) MaybePromptForTelemetry(ctx context.Context) error { - go h.s.maybePromptForTelemetry(ctx, true) + // if the server's TelemetryPrompt is true, it's likely the server already + // handled prompting for it. Don't try to prompt again. + if !h.s.options.TelemetryPrompt { + go h.s.maybePromptForTelemetry(ctx, true) + } return nil } @@ -388,6 +402,22 @@ func (c *commandHandler) run(ctx context.Context, cfg commandConfig, run command return err } + // For legacy reasons, gopls.run_govulncheck must run asynchronously. + // TODO(golang/vscode-go#3572): remove this (along with the + // gopls.run_govulncheck command entirely) once VS Code only uses the new + // gopls.vulncheck command. + if c.params.Command == "gopls.run_govulncheck" { + if cfg.progress == "" { + log.Fatalf("asynchronous command gopls.run_govulncheck does not enable progress reporting") + } + go func() { + if err := runcmd(); err != nil { + showMessage(ctx, c.s.client, protocol.Error, err.Error()) + } + }() + return nil + } + return runcmd() } @@ -568,11 +598,7 @@ func (c *commandHandler) Vendor(ctx context.Context, args command.URIArg) error // modules.txt in-place. In that case we could theoretically allow this // command to run concurrently. 
stderr := new(bytes.Buffer) - inv, cleanupInvocation, err := deps.snapshot.GoCommandInvocation(true, &gocommand.Invocation{ - Verb: "mod", - Args: []string{"vendor"}, - WorkingDir: filepath.Dir(args.URI.Path()), - }) + inv, cleanupInvocation, err := deps.snapshot.GoCommandInvocation(cache.NetworkOK, args.URI.DirPath(), "mod", []string{"vendor"}) if err != nil { return err } @@ -698,7 +724,7 @@ func (c *commandHandler) Doc(ctx context.Context, args command.DocArgs) (protoco // Direct the client to open the /pkg page. result = web.PkgURL(deps.snapshot.View().ID(), pkgpath, fragment) if args.ShowDocument { - openClientBrowser(ctx, c.s.client, result) + openClientBrowser(ctx, c.s.client, "Doc", result, c.s.Options()) } return nil @@ -733,11 +759,8 @@ func (c *commandHandler) runTests(ctx context.Context, snapshot *cache.Snapshot, // Run `go test -run Func` on each test. var failedTests int for _, funcName := range tests { - inv, cleanupInvocation, err := snapshot.GoCommandInvocation(false, &gocommand.Invocation{ - Verb: "test", - Args: []string{pkgPath, "-v", "-count=1", fmt.Sprintf("-run=^%s$", regexp.QuoteMeta(funcName))}, - WorkingDir: filepath.Dir(uri.Path()), - }) + args := []string{pkgPath, "-v", "-count=1", fmt.Sprintf("-run=^%s$", regexp.QuoteMeta(funcName))} + inv, cleanupInvocation, err := snapshot.GoCommandInvocation(cache.NoNetwork, uri.DirPath(), "test", args) if err != nil { return err } @@ -753,10 +776,8 @@ func (c *commandHandler) runTests(ctx context.Context, snapshot *cache.Snapshot, // Run `go test -run=^$ -bench Func` on each test. 
var failedBenchmarks int for _, funcName := range benchmarks { - inv, cleanupInvocation, err := snapshot.GoCommandInvocation(false, &gocommand.Invocation{ - Verb: "test", - Args: []string{pkgPath, "-v", "-run=^$", fmt.Sprintf("-bench=^%s$", regexp.QuoteMeta(funcName))}, - WorkingDir: filepath.Dir(uri.Path()), + inv, cleanupInvocation, err := snapshot.GoCommandInvocation(cache.NoNetwork, uri.DirPath(), "test", []string{ + pkgPath, "-v", "-run=^$", fmt.Sprintf("-bench=^%s$", regexp.QuoteMeta(funcName)), }) if err != nil { return err @@ -816,11 +837,7 @@ func (c *commandHandler) Generate(ctx context.Context, args command.GenerateArgs if args.Recursive { pattern = "./..." } - inv, cleanupInvocation, err := deps.snapshot.GoCommandInvocation(true, &gocommand.Invocation{ - Verb: "generate", - Args: []string{"-x", pattern}, - WorkingDir: args.Dir.Path(), - }) + inv, cleanupInvocation, err := deps.snapshot.GoCommandInvocation(cache.NetworkOK, args.Dir.Path(), "generate", []string{"-x", pattern}) if err != nil { return err } @@ -849,12 +866,10 @@ func (c *commandHandler) GoGetPackage(ctx context.Context, args command.GoGetPac } defer cleanupModDir() - inv, cleanupInvocation, err := snapshot.GoCommandInvocation(true, &gocommand.Invocation{ - Verb: "list", - Args: []string{"-f", "{{.Module.Path}}@{{.Module.Version}}", "-mod=mod", "-modfile=" + filepath.Join(tempDir, "go.mod"), args.Pkg}, - Env: []string{"GOWORK=off"}, - WorkingDir: modURI.Dir().Path(), - }) + inv, cleanupInvocation, err := snapshot.GoCommandInvocation(cache.NetworkOK, modURI.DirPath(), "list", + []string{"-f", "{{.Module.Path}}@{{.Module.Version}}", "-mod=mod", "-modfile=" + filepath.Join(tempDir, "go.mod"), args.Pkg}, + "GOWORK=off", + ) if err != nil { return err } @@ -984,12 +999,8 @@ func addModuleRequire(invoke func(...string) (*bytes.Buffer, error), args []stri // TODO(rfindley): inline. 
func (s *server) getUpgrades(ctx context.Context, snapshot *cache.Snapshot, uri protocol.DocumentURI, modules []string) (map[string]string, error) { - inv, cleanup, err := snapshot.GoCommandInvocation(true, &gocommand.Invocation{ - Verb: "list", - // -mod=readonly is necessary when vendor is present (golang/go#66055) - Args: append([]string{"-mod=readonly", "-m", "-u", "-json"}, modules...), - WorkingDir: filepath.Dir(uri.Path()), - }) + args := append([]string{"-mod=readonly", "-m", "-u", "-json"}, modules...) + inv, cleanup, err := snapshot.GoCommandInvocation(cache.NetworkOK, uri.DirPath(), "list", args) if err != nil { return nil, err } @@ -1137,7 +1148,7 @@ func (c *commandHandler) StartDebugging(ctx context.Context, args command.Debugg return result, fmt.Errorf("starting debug server: %w", err) } result.URLs = []string{"http://" + listenedAddr} - openClientBrowser(ctx, c.s.client, result.URLs[0]) + openClientBrowser(ctx, c.s.client, "Debug", result.URLs[0], c.s.Options()) return result, nil } @@ -1206,23 +1217,91 @@ func (c *commandHandler) FetchVulncheckResult(ctx context.Context, arg command.U const GoVulncheckCommandTitle = "govulncheck" +func (c *commandHandler) Vulncheck(ctx context.Context, args command.VulncheckArgs) (command.VulncheckResult, error) { + if args.URI == "" { + return command.VulncheckResult{}, errors.New("VulncheckArgs is missing URI field") + } + + var commandResult command.VulncheckResult + err := c.run(ctx, commandConfig{ + progress: GoVulncheckCommandTitle, + requireSave: true, // govulncheck cannot honor overlays + forURI: args.URI, + }, func(ctx context.Context, deps commandDeps) error { + jsonrpc2.Async(ctx) // run this in parallel with other requests: vulncheck can be slow. 
+ + workDoneWriter := progress.NewWorkDoneWriter(ctx, deps.work) + dir := args.URI.DirPath() + pattern := args.Pattern + + result, err := scan.RunGovulncheck(ctx, pattern, deps.snapshot, dir, workDoneWriter) + if err != nil { + return err + } + commandResult.Result = result + + snapshot, release, err := c.s.session.InvalidateView(ctx, deps.snapshot.View(), cache.StateChange{ + Vulns: map[protocol.DocumentURI]*vulncheck.Result{args.URI: result}, + }) + if err != nil { + return err + } + defer release() + + // Diagnosing with the background context ensures new snapshots are fully + // diagnosed. + c.s.diagnoseSnapshot(snapshot.BackgroundContext(), snapshot, nil, 0) + + affecting := make(map[string]bool, len(result.Entries)) + for _, finding := range result.Findings { + if len(finding.Trace) > 1 { // at least 2 frames if callstack exists (vulnerability, entry) + affecting[finding.OSV] = true + } + } + if len(affecting) == 0 { + showMessage(ctx, c.s.client, protocol.Info, "No vulnerabilities found") + return nil + } + affectingOSVs := make([]string, 0, len(affecting)) + for id := range affecting { + affectingOSVs = append(affectingOSVs, id) + } + sort.Strings(affectingOSVs) + + showMessage(ctx, c.s.client, protocol.Warning, fmt.Sprintf("Found %v", strings.Join(affectingOSVs, ", "))) + + return nil + }) + if err != nil { + return command.VulncheckResult{}, err + } + return commandResult, nil +} + +// RunGovulncheck is like Vulncheck (in fact, a copy), but is tweaked slightly +// to run asynchronously rather than return a result. +// +// This logic was copied, rather than factored out, as this implementation is +// slated for deletion. 
+// +// TODO(golang/vscode-go#3572) func (c *commandHandler) RunGovulncheck(ctx context.Context, args command.VulncheckArgs) (command.RunVulncheckResult, error) { if args.URI == "" { return command.RunVulncheckResult{}, errors.New("VulncheckArgs is missing URI field") } - var commandResult command.RunVulncheckResult + // Return the workdone token so that clients can identify when this + // vulncheck invocation is complete. + // + // Since the run function executes asynchronously, we use a channel to + // synchronize the start of the run and return the token. + tokenChan := make(chan protocol.ProgressToken, 1) err := c.run(ctx, commandConfig{ progress: GoVulncheckCommandTitle, requireSave: true, // govulncheck cannot honor overlays forURI: args.URI, }, func(ctx context.Context, deps commandDeps) error { - // For compatibility with the legacy asynchronous API, return the workdone - // token that clients used to use to identify when this vulncheck - // invocation is complete. - commandResult.Token = deps.work.Token() - - jsonrpc2.Async(ctx) // run this in parallel with other requests: vulncheck can be slow. + tokenChan <- deps.work.Token() workDoneWriter := progress.NewWorkDoneWriter(ctx, deps.work) dir := filepath.Dir(args.URI.Path()) @@ -1232,7 +1311,6 @@ func (c *commandHandler) RunGovulncheck(ctx context.Context, args command.Vulnch if err != nil { return err } - commandResult.Result = result snapshot, release, err := c.s.session.InvalidateView(ctx, deps.snapshot.View(), cache.StateChange{ Vulns: map[protocol.DocumentURI]*vulncheck.Result{args.URI: result}, @@ -1269,7 +1347,12 @@ func (c *commandHandler) RunGovulncheck(ctx context.Context, args command.Vulnch if err != nil { return command.RunVulncheckResult{}, err } - return commandResult, nil + select { + case <-ctx.Done(): + return command.RunVulncheckResult{}, ctx.Err() + case token := <-tokenChan: + return command.RunVulncheckResult{Token: token}, nil + } } // MemStats implements the MemStats command. 
It returns an error as a @@ -1465,8 +1548,21 @@ func showMessage(ctx context.Context, cli protocol.Client, typ protocol.MessageT // openClientBrowser causes the LSP client to open the specified URL // in an external browser. -func openClientBrowser(ctx context.Context, cli protocol.Client, url protocol.URI) { - showDocumentImpl(ctx, cli, url, nil) +// +// If the client does not support window/showDocument, a window/showMessage +// request is instead used, with the format "$title: open your browser to $url". +func openClientBrowser(ctx context.Context, cli protocol.Client, title string, url protocol.URI, opts *settings.Options) { + if opts.ShowDocumentSupported { + showDocumentImpl(ctx, cli, url, nil, opts) + } else { + params := &protocol.ShowMessageParams{ + Type: protocol.Info, + Message: fmt.Sprintf("%s: open your browser to %s", title, url), + } + if err := cli.ShowMessage(ctx, params); err != nil { + event.Error(ctx, "failed to show browser url", err) + } + } } // openClientEditor causes the LSP client to open the specified document @@ -1474,11 +1570,17 @@ func openClientBrowser(ctx context.Context, cli protocol.Client, url protocol.UR // // Note that VS Code 1.87.2 doesn't currently raise the window; this is // https://github.com/microsoft/vscode/issues/207634 -func openClientEditor(ctx context.Context, cli protocol.Client, loc protocol.Location) { - showDocumentImpl(ctx, cli, protocol.URI(loc.URI), &loc.Range) +func openClientEditor(ctx context.Context, cli protocol.Client, loc protocol.Location, opts *settings.Options) { + if !opts.ShowDocumentSupported { + return // no op + } + showDocumentImpl(ctx, cli, protocol.URI(loc.URI), &loc.Range, opts) } -func showDocumentImpl(ctx context.Context, cli protocol.Client, url protocol.URI, rangeOpt *protocol.Range) { +func showDocumentImpl(ctx context.Context, cli protocol.Client, url protocol.URI, rangeOpt *protocol.Range, opts *settings.Options) { + if !opts.ShowDocumentSupported { + return // no op + } // In
principle we shouldn't send a showDocument request to a // client that doesn't support it, as reported by // ShowDocumentClientCapabilities. But even clients that do @@ -1513,10 +1615,23 @@ func showDocumentImpl(ctx context.Context, cli protocol.Client, url protocol.URI func (c *commandHandler) ChangeSignature(ctx context.Context, args command.ChangeSignatureArgs) (*protocol.WorkspaceEdit, error) { var result *protocol.WorkspaceEdit err := c.run(ctx, commandConfig{ - forURI: args.RemoveParameter.URI, + forURI: args.Location.URI, }, func(ctx context.Context, deps commandDeps) error { - // For now, gopls only supports removing unused parameters. - docedits, err := golang.RemoveUnusedParameter(ctx, deps.fh, args.RemoveParameter.Range, deps.snapshot) + pkg, pgf, err := golang.NarrowestPackageForFile(ctx, deps.snapshot, args.Location.URI) + if err != nil { + return err + } + + // For now, gopls only supports parameter permutation or removal. + var perm []int + for _, newParam := range args.NewParams { + if newParam.NewField != "" { + return fmt.Errorf("adding new parameters is currently unsupported") + } + perm = append(perm, newParam.OldIndex) + } + + docedits, err := golang.ChangeSignature(ctx, deps.snapshot, pkg, pgf, args.Location.Range, perm) if err != nil { return err } @@ -1595,7 +1710,7 @@ func (c *commandHandler) FreeSymbols(ctx context.Context, viewID string, loc pro return err } url := web.freesymbolsURL(viewID, loc) - openClientBrowser(ctx, c.s.client, url) + openClientBrowser(ctx, c.s.client, "Free symbols", url, c.s.Options()) return nil } @@ -1605,12 +1720,14 @@ func (c *commandHandler) Assembly(ctx context.Context, viewID, packageID, symbol return err } url := web.assemblyURL(viewID, packageID, symbol) - openClientBrowser(ctx, c.s.client, url) + openClientBrowser(ctx, c.s.client, "Assembly", url, c.s.Options()) return nil } func (c *commandHandler) ClientOpenURL(ctx context.Context, url string) error { - openClientBrowser(ctx, c.s.client, url) + // Fall 
back to "Gopls: open your browser..." if we must send a showMessage + // request, since we don't know the context of this command. + openClientBrowser(ctx, c.s.client, "Gopls", url, c.s.Options()) return nil } diff --git a/gopls/internal/server/prompt.go b/gopls/internal/server/prompt.go index 7eb400cfbe0..37f591487a6 100644 --- a/gopls/internal/server/prompt.go +++ b/gopls/internal/server/prompt.go @@ -162,6 +162,7 @@ func (s *server) maybePromptForTelemetry(ctx context.Context, enabled bool) { // v0.17 ~: must have all four fields. } else { state, attempts, creationTime, token = pUnknown, 0, 0, 0 + // TODO(hyangah): why do we want to present this as an error to user? errorf("malformed prompt result %q", string(content)) } } else if !os.IsNotExist(err) { diff --git a/gopls/internal/server/rename.go b/gopls/internal/server/rename.go index 93b2ac6f9c4..cdfb9c7a8fe 100644 --- a/gopls/internal/server/rename.go +++ b/gopls/internal/server/rename.go @@ -50,7 +50,7 @@ func (s *server) Rename(ctx context.Context, params *protocol.RenameParams) (*pr if isPkgRenaming { // Update the last component of the file's enclosing directory. - oldDir := filepath.Dir(fh.URI().Path()) + oldDir := fh.URI().DirPath() newDir := filepath.Join(filepath.Dir(oldDir), params.NewName) change := protocol.DocumentChangeRename( protocol.URIFromPath(oldDir), diff --git a/gopls/internal/server/server.go b/gopls/internal/server/server.go index 80e64bb996c..d9090250a66 100644 --- a/gopls/internal/server/server.go +++ b/gopls/internal/server/server.go @@ -303,7 +303,7 @@ func (s *server) initWeb() (*web, error) { openClientEditor(req.Context(), s.client, protocol.Location{ URI: uri, Range: protocol.Range{Start: posn, End: posn}, - }) + }, s.Options()) }) // The /pkg/PATH&view=... handler shows package documentation for PATH. 
diff --git a/gopls/internal/server/text_synchronization.go b/gopls/internal/server/text_synchronization.go index 257eadbbf41..6aef24691d6 100644 --- a/gopls/internal/server/text_synchronization.go +++ b/gopls/internal/server/text_synchronization.go @@ -105,7 +105,7 @@ func (s *server) DidOpen(ctx context.Context, params *protocol.DidOpenTextDocume // file is opened, and we can't do that inside didModifyFiles because we // don't want to request configuration while holding a lock. if len(s.session.Views()) == 0 { - dir := filepath.Dir(uri.Path()) + dir := uri.DirPath() s.addFolders(ctx, []protocol.WorkspaceFolder{{ URI: string(protocol.URIFromPath(dir)), Name: filepath.Base(dir), diff --git a/gopls/internal/server/unimplemented.go b/gopls/internal/server/unimplemented.go index 9347f42c42e..470a7cbb0ee 100644 --- a/gopls/internal/server/unimplemented.go +++ b/gopls/internal/server/unimplemented.go @@ -150,6 +150,10 @@ func (s *server) WillSaveWaitUntil(context.Context, *protocol.WillSaveTextDocume return nil, notImplemented("WillSaveWaitUntil") } +func (s *server) TextDocumentContent(context.Context, *protocol.TextDocumentContentParams) (*string, error) { + return nil, notImplemented("TextDocumentContent") +} + func notImplemented(method string) error { return fmt.Errorf("%w: %q not yet implemented", jsonrpc2.ErrMethodNotFound, method) } diff --git a/gopls/internal/settings/analysis.go b/gopls/internal/settings/analysis.go index 6bb85f1beca..d20526fc583 100644 --- a/gopls/internal/settings/analysis.go +++ b/gopls/internal/settings/analysis.go @@ -45,6 +45,7 @@ import ( "golang.org/x/tools/go/analysis/passes/unsafeptr" "golang.org/x/tools/go/analysis/passes/unusedresult" "golang.org/x/tools/go/analysis/passes/unusedwrite" + "golang.org/x/tools/go/analysis/passes/waitgroup" "golang.org/x/tools/gopls/internal/analysis/deprecated" "golang.org/x/tools/gopls/internal/analysis/embeddirective" "golang.org/x/tools/gopls/internal/analysis/fillreturns" @@ -54,10 +55,10 @@ import 
( "golang.org/x/tools/gopls/internal/analysis/simplifycompositelit" "golang.org/x/tools/gopls/internal/analysis/simplifyrange" "golang.org/x/tools/gopls/internal/analysis/simplifyslice" - "golang.org/x/tools/gopls/internal/analysis/undeclaredname" "golang.org/x/tools/gopls/internal/analysis/unusedparams" "golang.org/x/tools/gopls/internal/analysis/unusedvariable" "golang.org/x/tools/gopls/internal/analysis/useany" + "golang.org/x/tools/gopls/internal/analysis/yield" "golang.org/x/tools/gopls/internal/protocol" ) @@ -145,14 +146,16 @@ func init() { {analyzer: unusedresult.Analyzer, enabled: true}, // not suitable for vet: - // - some (nilness) use go/ssa; see #59714. + // - some (nilness, yield) use go/ssa; see #59714. // - others don't meet the "frequency" criterion; // see GOROOT/src/cmd/vet/README. {analyzer: atomicalign.Analyzer, enabled: true}, {analyzer: deepequalerrors.Analyzer, enabled: true}, {analyzer: nilness.Analyzer, enabled: true}, // uses go/ssa + {analyzer: yield.Analyzer, enabled: true}, // uses go/ssa {analyzer: sortslice.Analyzer, enabled: true}, {analyzer: embeddirective.Analyzer, enabled: true}, + {analyzer: waitgroup.Analyzer, enabled: true}, // to appear in cmd/vet@go1.25 // disabled due to high false positives {analyzer: shadow.Analyzer, enabled: false}, // very noisy @@ -174,7 +177,6 @@ func init() { {analyzer: fillreturns.Analyzer, enabled: true}, {analyzer: nonewvars.Analyzer, enabled: true}, {analyzer: noresultvalues.Analyzer, enabled: true}, - {analyzer: undeclaredname.Analyzer, enabled: true}, // TODO(rfindley): why isn't the 'unusedvariable' analyzer enabled, if it // is only enhancing type errors with suggested fixes? 
// diff --git a/gopls/internal/settings/codeactionkind.go b/gopls/internal/settings/codeactionkind.go index 16a2eecb2cb..7bc4f4e4d66 100644 --- a/gopls/internal/settings/codeactionkind.go +++ b/gopls/internal/settings/codeactionkind.go @@ -91,12 +91,15 @@ const ( RefactorRewriteInvertIf protocol.CodeActionKind = "refactor.rewrite.invertIf" RefactorRewriteJoinLines protocol.CodeActionKind = "refactor.rewrite.joinLines" RefactorRewriteRemoveUnusedParam protocol.CodeActionKind = "refactor.rewrite.removeUnusedParam" + RefactorRewriteMoveParamLeft protocol.CodeActionKind = "refactor.rewrite.moveParamLeft" + RefactorRewriteMoveParamRight protocol.CodeActionKind = "refactor.rewrite.moveParamRight" RefactorRewriteSplitLines protocol.CodeActionKind = "refactor.rewrite.splitLines" // refactor.inline RefactorInlineCall protocol.CodeActionKind = "refactor.inline.call" // refactor.extract + RefactorExtractConstant protocol.CodeActionKind = "refactor.extract.constant" RefactorExtractFunction protocol.CodeActionKind = "refactor.extract.function" RefactorExtractMethod protocol.CodeActionKind = "refactor.extract.method" RefactorExtractVariable protocol.CodeActionKind = "refactor.extract.variable" diff --git a/gopls/internal/settings/default.go b/gopls/internal/settings/default.go index 2f637f3d16d..0354101f045 100644 --- a/gopls/internal/settings/default.go +++ b/gopls/internal/settings/default.go @@ -61,6 +61,7 @@ func DefaultOptions(overrides ...func(*Options)) *Options { RefactorRewriteRemoveUnusedParam: true, RefactorRewriteSplitLines: true, RefactorInlineCall: true, + RefactorExtractConstant: true, RefactorExtractFunction: true, RefactorExtractMethod: true, RefactorExtractVariable: true, @@ -136,7 +137,6 @@ func DefaultOptions(overrides ...func(*Options)) *Options { LinkifyShowMessage: false, IncludeReplaceInWorkspace: false, ZeroConfig: true, - AddTestSourceCodeAction: false, }, } }) diff --git a/gopls/internal/settings/settings.go b/gopls/internal/settings/settings.go index 
02c59163609..5f1efef040d 100644 --- a/gopls/internal/settings/settings.go +++ b/gopls/internal/settings/settings.go @@ -76,6 +76,7 @@ type ClientOptions struct { CompletionDeprecated bool SupportedResourceOperations []protocol.ResourceOperationKind CodeActionResolveOptions []string + ShowDocumentSupported bool } // ServerOptions holds LSP-specific configuration that is provided by the @@ -250,13 +251,23 @@ const ( // Run govulncheck // - // This codelens source annotates the `module` directive in a - // go.mod file with a command to run Govulncheck. + // This codelens source annotates the `module` directive in a go.mod file + // with a command to run govulncheck synchronously. // - // [Govulncheck](https://go.dev/blog/vuln) is a static - // analysis tool that computes the set of functions reachable - // within your application, including dependencies; - // queries a database of known security vulnerabilities; and + // [Govulncheck](https://go.dev/blog/vuln) is a static analysis tool that + // computes the set of functions reachable within your application, including + // dependencies; queries a database of known security vulnerabilities; and + // reports any potential problems it finds. + CodeLensVulncheck CodeLensSource = "vulncheck" + + // Run govulncheck (legacy) + // + // This codelens source annotates the `module` directive in a go.mod file + // with a command to run Govulncheck asynchronously. + // + // [Govulncheck](https://go.dev/blog/vuln) is a static analysis tool that + // computes the set of functions reachable within your application, including + // dependencies; queries a database of known security vulnerabilities; and // reports any potential problems it finds. CodeLensRunGovulncheck CodeLensSource = "run_govulncheck" @@ -700,11 +711,6 @@ type InternalOptions struct { // TODO(rfindley): make pull diagnostics robust, and remove this option, // allowing pull diagnostics by default. 
PullDiagnostics bool - - // AddTestSourceCodeAction enables support for adding test as a source code - // action. - // TODO(hxjiang): remove this option once the feature is implemented. - AddTestSourceCodeAction bool } type SubdirWatchPatterns string @@ -862,6 +868,9 @@ func (o *Options) ForClientCapabilities(clientInfo *protocol.ClientInfo, caps pr o.InsertTextFormat = protocol.SnippetTextFormat } o.InsertReplaceSupported = caps.TextDocument.Completion.CompletionItem.InsertReplaceSupport + if caps.Window.ShowDocument != nil { + o.ShowDocumentSupported = caps.Window.ShowDocument.Support + } // Check if the client supports configuration messages. o.ConfigurationSupported = caps.Workspace.Configuration o.DynamicConfigurationSupported = caps.Workspace.DidChangeConfiguration.DynamicRegistration @@ -985,8 +994,6 @@ func (o *Options) setOne(name string, value any) error { return setBool(&o.DeepCompletion, value) case "completeUnimported": return setBool(&o.CompleteUnimported, value) - case "addTestSourceCodeAction": - return setBool(&o.AddTestSourceCodeAction, value) case "completionBudget": return setDuration(&o.CompletionBudget, value) case "matcher": diff --git a/gopls/internal/test/integration/bench/reload_test.go b/gopls/internal/test/integration/bench/reload_test.go index 332809ee1eb..b93b76f945d 100644 --- a/gopls/internal/test/integration/bench/reload_test.go +++ b/gopls/internal/test/integration/bench/reload_test.go @@ -4,6 +4,9 @@ package bench import ( + "fmt" + "path" + "regexp" "testing" . "golang.org/x/tools/gopls/internal/test/integration" @@ -14,39 +17,55 @@ import ( // This ensures we are able to diagnose a changed file without reloading all // invalidated packages. See also golang/go#61344 func BenchmarkReload(b *testing.B) { - // TODO(rfindley): add more tests, make this test table-driven - const ( - repo = "kubernetes" - // pkg/util/hash is transitively imported by a large number of packages. 
- // We should not need to reload those packages to get a diagnostic. - file = "pkg/util/hash/hash.go" - ) - b.Run(repo, func(b *testing.B) { - env := getRepo(b, repo).sharedEnv(b) - - env.OpenFile(file) - defer closeBuffer(b, env, file) - - env.AfterChange() - - if stopAndRecord := startProfileIfSupported(b, env, qualifiedName(repo, "reload")); stopAndRecord != nil { - defer stopAndRecord() - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - // Change the "hash" import. This may result in cache hits, but that's - // OK: the goal is to ensure that we don't reload more than just the - // current package. - env.RegexpReplace(file, `"hash"`, `"hashx"`) - // Note: don't use env.AfterChange() here: we only want to await the - // first diagnostic. - // - // Awaiting a full diagnosis would await diagnosing everything, which - // would require reloading everything. - env.Await(Diagnostics(ForFile(file))) - env.RegexpReplace(file, `"hashx"`, `"hash"`) - env.Await(NoDiagnostics(ForFile(file))) - } - }) + type replace map[string]string + tests := []struct { + repo string + file string + // replacements must be 'reversible', in the sense that the replacing + // string is unique. + replace replace + }{ + // pkg/util/hash is transitively imported by a large number of packages. We + // should not need to reload those packages to get a diagnostic. 
+ {"kubernetes", "pkg/util/hash/hash.go", replace{`"hash"`: `"hashx"`}}, + {"kubernetes", "pkg/kubelet/kubelet.go", replace{ + `"k8s.io/kubernetes/pkg/kubelet/config"`: `"k8s.io/kubernetes/pkg/kubelet/configx"`, + }}, + } + + for _, test := range tests { + b.Run(fmt.Sprintf("%s/%s", test.repo, path.Base(test.file)), func(b *testing.B) { + env := getRepo(b, test.repo).sharedEnv(b) + + env.OpenFile(test.file) + defer closeBuffer(b, env, test.file) + + env.AfterChange() + + profileName := qualifiedName("reload", test.repo, path.Base(test.file)) + if stopAndRecord := startProfileIfSupported(b, env, profileName); stopAndRecord != nil { + defer stopAndRecord() + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Mutate the file. This may result in cache hits, but that's OK: the + // goal is to ensure that we don't reload more than just the current + // package. + for k, v := range test.replace { + env.RegexpReplace(test.file, regexp.QuoteMeta(k), v) + } + // Note: don't use env.AfterChange() here: we only want to await the + // first diagnostic. + // + // Awaiting a full diagnosis would await diagnosing everything, which + // would require reloading everything. 
+ env.Await(Diagnostics(ForFile(test.file))) + for k, v := range test.replace { + env.RegexpReplace(test.file, regexp.QuoteMeta(v), k) + } + env.Await(NoDiagnostics(ForFile(test.file))) + } + }) + } } diff --git a/gopls/internal/test/integration/completion/completion_test.go b/gopls/internal/test/integration/completion/completion_test.go index c96e569f1ad..1f6eb2fe0fb 100644 --- a/gopls/internal/test/integration/completion/completion_test.go +++ b/gopls/internal/test/integration/completion/completion_test.go @@ -970,6 +970,275 @@ use ./missing/ }) } +const reverseInferenceSrcPrelude = ` +-- go.mod -- +module mod.com + +go 1.18 +-- a.go -- +package a + +type InterfaceA interface { + implA() +} + +type InterfaceB interface { + implB() +} + + +type TypeA struct{} + +func (TypeA) implA() {} + +type TypeX string + +func (TypeX) implB() {} + +type TypeB struct{} + +func (TypeB) implB() {} + +type TypeC struct{} // should have no impact + +type Wrap[T any] struct { + inner *T +} + +func NewWrap[T any](x T) Wrap[T] { + return Wrap[T]{inner: &x} +} + +func DoubleWrap[T any, U any](t T, u U) (Wrap[T], Wrap[U]) { + return Wrap[T]{inner: &t}, Wrap[U]{inner: &u} +} + +func IntWrap[T int32 | int64](x T) Wrap[T] { + return Wrap[T]{inner: &x} +} + +var ia InterfaceA +var ib InterfaceB + +var avar TypeA +var bvar TypeB + +var i int +var i32 int32 +var i64 int64 +` + +func TestReverseInferCompletion(t *testing.T) { + src := reverseInferenceSrcPrelude + ` + func main() { + var _ Wrap[int64] = IntWrap() + } + ` + Run(t, src, func(t *testing.T, env *Env) { + compl := env.RegexpSearch("a.go", `IntWrap\(()\)`) + + env.OpenFile("a.go") + result := env.Completion(compl) + + wantLabel := []string{"i64", "i", "i32", "int64()"} + + // only check the prefix due to formatting differences with escaped characters + wantText := []string{"i64", "int64(i", "int64(i32", "int64("} + + for i, item := range result.Items[:len(wantLabel)] { + if diff := cmp.Diff(wantLabel[i], item.Label); diff != "" { + 
t.Errorf("Completion: unexpected label mismatch (-want +got):\n%s", diff) + } + + if insertText, ok := item.TextEdit.Value.(protocol.InsertReplaceEdit); ok { + if diff := cmp.Diff(wantText[i], insertText.NewText[:len(wantText[i])]); diff != "" { + t.Errorf("Completion: unexpected insertText mismatch (checks prefix only) (-want +got):\n%s", diff) + } + } + } + }) +} + +func TestInterfaceReverseInferCompletion(t *testing.T) { + src := reverseInferenceSrcPrelude + ` + func main() { + var wa Wrap[InterfaceA] + var wb Wrap[InterfaceB] + wb = NewWrap() // wb is of type Wrap[InterfaceB] + } + ` + + Run(t, src, func(t *testing.T, env *Env) { + compl := env.RegexpSearch("a.go", `NewWrap\(()\)`) + + env.OpenFile("a.go") + result := env.Completion(compl) + + wantLabel := []string{"ib", "bvar", "wb.inner", "TypeB{}", "TypeX()", "nil"} + + // only check the prefix due to formatting differences with escaped characters + wantText := []string{"ib", "InterfaceB(", "*wb.inner", "InterfaceB(", "InterfaceB(", "nil"} + + for i, item := range result.Items[:len(wantLabel)] { + if diff := cmp.Diff(wantLabel[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected label mismatch (-want +got):\n%s", diff) + } + + if insertText, ok := item.TextEdit.Value.(protocol.InsertReplaceEdit); ok { + if diff := cmp.Diff(wantText[i], insertText.NewText[:len(wantText[i])]); diff != "" { + t.Errorf("Completion: unexpected insertText mismatch (checks prefix only) (-want +got):\n%s", diff) + } + } + } + }) +} + +func TestInvalidReverseInferenceDefaultsToConstraintCompletion(t *testing.T) { + src := reverseInferenceSrcPrelude + ` + func main() { + var wa Wrap[InterfaceA] + // This is ambiguous, so default to the constraint rather than the inference.
+ wa = IntWrap() + } + ` + Run(t, src, func(t *testing.T, env *Env) { + compl := env.RegexpSearch("a.go", `IntWrap\(()\)`) + + env.OpenFile("a.go") + result := env.Completion(compl) + + wantLabel := []string{"i32", "i64", "nil"} + + for i, item := range result.Items[:len(wantLabel)] { + if diff := cmp.Diff(wantLabel[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected label mismatch (-want +got):\n%s", diff) + } + } + }) +} + +func TestInterfaceReverseInferTypeParamCompletion(t *testing.T) { + src := reverseInferenceSrcPrelude + ` + func main() { + var wa Wrap[InterfaceA] + var wb Wrap[InterfaceB] + wb = NewWrap[]() + } + ` + + Run(t, src, func(t *testing.T, env *Env) { + compl := env.RegexpSearch("a.go", `NewWrap\[()\]\(\)`) + + env.OpenFile("a.go") + result := env.Completion(compl) + want := []string{"InterfaceB", "TypeB", "TypeX", "InterfaceA", "TypeA"} + for i, item := range result.Items[:len(want)] { + if diff := cmp.Diff(want[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected mismatch (-want +got):\n%s", diff) + } + } + }) +} + +func TestInvalidReverseInferenceTypeParamDefaultsToConstraintCompletion(t *testing.T) { + src := reverseInferenceSrcPrelude + ` + func main() { + var wa Wrap[InterfaceA] + // This is ambiguous, so default to the constraint rather the inference. 
+ wb = IntWrap[]() + } + ` + + Run(t, src, func(t *testing.T, env *Env) { + compl := env.RegexpSearch("a.go", `IntWrap\[()\]\(\)`) + + env.OpenFile("a.go") + result := env.Completion(compl) + want := []string{"int32", "int64"} + for i, item := range result.Items[:len(want)] { + if diff := cmp.Diff(want[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected mismatch (-want +got):\n%s", diff) + } + } + }) +} + +func TestReverseInferDoubleTypeParamCompletion(t *testing.T) { + src := reverseInferenceSrcPrelude + ` + func main() { + var wa Wrap[InterfaceA] + var wb Wrap[InterfaceB] + + wa, wb = DoubleWrap[]() + // _ is necessary to trick the parser into an index list expression + wa, wb = DoubleWrap[InterfaceA, _]() + } + ` + Run(t, src, func(t *testing.T, env *Env) { + env.OpenFile("a.go") + + compl := env.RegexpSearch("a.go", `DoubleWrap\[()\]\(\)`) + result := env.Completion(compl) + + wantLabel := []string{"InterfaceA", "TypeA", "InterfaceB", "TypeB", "TypeC"} + + for i, item := range result.Items[:len(wantLabel)] { + if diff := cmp.Diff(wantLabel[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected label mismatch (-want +got):\n%s", diff) + } + } + + compl = env.RegexpSearch("a.go", `DoubleWrap\[InterfaceA, (_)\]\(\)`) + result = env.Completion(compl) + + wantLabel = []string{"InterfaceB", "TypeB", "TypeX", "InterfaceA", "TypeA"} + + for i, item := range result.Items[:len(wantLabel)] { + if diff := cmp.Diff(wantLabel[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected label mismatch (-want +got):\n%s", diff) + } + } + }) +} + +func TestDoubleParamReturnCompletion(t *testing.T) { + src := reverseInferenceSrcPrelude + ` + func concrete() (Wrap[InterfaceA], Wrap[InterfaceB]) { + return DoubleWrap[]() + } + + func concrete2() (Wrap[InterfaceA], Wrap[InterfaceB]) { + return DoubleWrap[InterfaceA, _]() + } + ` + + Run(t, src, func(t *testing.T, env *Env) { + env.OpenFile("a.go") + + compl := env.RegexpSearch("a.go", 
`DoubleWrap\[()\]\(\)`) + result := env.Completion(compl) + + wantLabel := []string{"InterfaceA", "TypeA", "InterfaceB", "TypeB", "TypeC"} + + for i, item := range result.Items[:len(wantLabel)] { + if diff := cmp.Diff(wantLabel[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected label mismatch (-want +got):\n%s", diff) + } + } + + compl = env.RegexpSearch("a.go", `DoubleWrap\[InterfaceA, (_)\]\(\)`) + result = env.Completion(compl) + + wantLabel = []string{"InterfaceB", "TypeB", "TypeX", "InterfaceA", "TypeA"} + + for i, item := range result.Items[:len(wantLabel)] { + if diff := cmp.Diff(wantLabel[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected label mismatch (-want +got):\n%s", diff) + } + } + }) +} + func TestBuiltinCompletion(t *testing.T) { const files = ` -- go.mod -- diff --git a/gopls/internal/test/integration/env.go b/gopls/internal/test/integration/env.go index 1a7ea70c89b..4acd4603827 100644 --- a/gopls/internal/test/integration/env.go +++ b/gopls/internal/test/integration/env.go @@ -9,6 +9,7 @@ import ( "fmt" "strings" "sync" + "sync/atomic" "testing" "golang.org/x/tools/gopls/internal/protocol" @@ -34,6 +35,10 @@ type Env struct { Awaiter *Awaiter } +// nextAwaiterRegistration is used to create unique IDs for various Awaiter +// registrations. +var nextAwaiterRegistration atomic.Uint64 + // An Awaiter keeps track of relevant LSP state, so that it may be asserted // upon with Expectations. // @@ -46,9 +51,13 @@ type Awaiter struct { mu sync.Mutex // For simplicity, each waiter gets a unique ID. - nextWaiterID int - state State - waiters map[int]*condition + state State + waiters map[uint64]*condition + + // collectors map a registration to the collection of messages that have been + // received since the registration was created. 
+ docCollectors map[uint64][]*protocol.ShowDocumentParams + messageCollectors map[uint64][]*protocol.ShowMessageParams } func NewAwaiter(workdir *fake.Workdir) *Awaiter { @@ -60,7 +69,7 @@ func NewAwaiter(workdir *fake.Workdir) *Awaiter { startedWork: make(map[string]uint64), completedWork: make(map[string]uint64), }, - waiters: make(map[int]*condition), + waiters: make(map[uint64]*condition), } } @@ -79,9 +88,6 @@ func (a *Awaiter) Hooks() fake.ClientHooks { } } -// ResetShownDocuments resets the set of accumulated ShownDocuments seen so far. -func (a *Awaiter) ResetShownDocuments() { a.state.showDocument = nil } - // State encapsulates the server state TODO: explain more type State struct { // diagnostics are a map of relative path->diagnostics params @@ -171,20 +177,78 @@ func (a *Awaiter) onShowDocument(_ context.Context, params *protocol.ShowDocumen a.mu.Lock() defer a.mu.Unlock() + // Update any outstanding listeners. + for id, s := range a.docCollectors { + a.docCollectors[id] = append(s, params) + } + a.state.showDocument = append(a.state.showDocument, params) a.checkConditionsLocked() return nil } -func (a *Awaiter) onShowMessage(_ context.Context, m *protocol.ShowMessageParams) error { +// ListenToShownDocuments registers a listener to incoming showDocument +// notifications. Call the resulting func to deregister the listener and +// receive all notifications that have occurred since the listener was +// registered. 
+func (a *Awaiter) ListenToShownDocuments() func() []*protocol.ShowDocumentParams {
+	id := nextAwaiterRegistration.Add(1)
+
 	a.mu.Lock()
 	defer a.mu.Unlock()
-	a.state.showMessage = append(a.state.showMessage, m)
+	if a.docCollectors == nil {
+		a.docCollectors = make(map[uint64][]*protocol.ShowDocumentParams)
+	}
+	a.docCollectors[id] = nil
+
+	return func() []*protocol.ShowDocumentParams {
+		a.mu.Lock()
+		defer a.mu.Unlock()
+		params := a.docCollectors[id]
+		delete(a.docCollectors, id)
+		return params
+	}
+}
+
+func (a *Awaiter) onShowMessage(_ context.Context, params *protocol.ShowMessageParams) error {
+	a.mu.Lock()
+	defer a.mu.Unlock()
+
+	// Update any outstanding listeners.
+	for id, s := range a.messageCollectors {
+		a.messageCollectors[id] = append(s, params)
+	}
+
+	a.state.showMessage = append(a.state.showMessage, params)
 	a.checkConditionsLocked()
 	return nil
 }

+// ListenToShownMessages registers a listener to incoming showMessage
+// notifications. Call the resulting func to deregister the listener and
+// receive all notifications that have occurred since the listener was
+// registered.
+func (a *Awaiter) ListenToShownMessages() func() []*protocol.ShowMessageParams {
+	id := nextAwaiterRegistration.Add(1)
+
+	a.mu.Lock()
+	defer a.mu.Unlock()
+
+	if a.messageCollectors == nil {
+		a.messageCollectors = make(map[uint64][]*protocol.ShowMessageParams)
+	}
+	a.messageCollectors[id] = nil
+
+	return func() []*protocol.ShowMessageParams {
+		a.mu.Lock()
+		defer a.mu.Unlock()
+		params := a.messageCollectors[id]
+		delete(a.messageCollectors, id)
+		return params
+	}
+}
+
 func (a *Awaiter) onShowMessageRequest(_ context.Context, m *protocol.ShowMessageRequestParams) error {
 	a.mu.Lock()
 	defer a.mu.Unlock()
@@ -332,8 +396,7 @@ func (a *Awaiter) Await(ctx context.Context, expectations ...Expectation) error
 		expectations: expectations,
 		verdict:      make(chan Verdict),
 	}
-	a.waiters[a.nextWaiterID] = cond
-	a.nextWaiterID++
+	a.waiters[nextAwaiterRegistration.Add(1)] = cond
 	a.mu.Unlock()

 	var err error
diff --git a/gopls/internal/test/integration/expectation.go b/gopls/internal/test/integration/expectation.go
index f68f1de5e02..d5e6030bf20 100644
--- a/gopls/internal/test/integration/expectation.go
+++ b/gopls/internal/test/integration/expectation.go
@@ -452,6 +452,34 @@ type WorkStatus struct {
 	EndMsg string
 }

+// CompletedProgressToken expects that workDone progress is complete for the
+// given progress token. When non-nil WorkStatus is provided, it will be filled
+// when the expectation is met.
+//
+// If the token is not a progress token that the client has seen, this
+// expectation is Unmeetable.
+func CompletedProgressToken(token protocol.ProgressToken, into *WorkStatus) Expectation { + check := func(s State) Verdict { + work, ok := s.work[token] + if !ok { + return Unmeetable // TODO(rfindley): refactor to allow the verdict to explain this result + } + if work.complete { + if into != nil { + into.Msg = work.msg + into.EndMsg = work.endMsg + } + return Met + } + return Unmet + } + desc := fmt.Sprintf("completed work for token %v", token) + return Expectation{ + Check: check, + Description: desc, + } +} + // CompletedProgress expects that there is exactly one workDone progress with // the given title, and is satisfied when that progress completes. If it is // met, the corresponding status is written to the into argument. diff --git a/gopls/internal/test/integration/fake/client.go b/gopls/internal/test/integration/fake/client.go index 8fdddd92574..93eeab4a8af 100644 --- a/gopls/internal/test/integration/fake/client.go +++ b/gopls/internal/test/integration/fake/client.go @@ -73,6 +73,10 @@ func (c *Client) SemanticTokensRefresh(context.Context) error { return nil } func (c *Client) LogTrace(context.Context, *protocol.LogTraceParams) error { return nil } +func (c *Client) TextDocumentContentRefresh(context.Context, *protocol.TextDocumentContentRefreshParams) error { + return nil +} + func (c *Client) ShowMessage(ctx context.Context, params *protocol.ShowMessageParams) error { if c.hooks.OnShowMessage != nil { return c.hooks.OnShowMessage(ctx, params) diff --git a/gopls/internal/test/integration/fake/editor.go b/gopls/internal/test/integration/fake/editor.go index 466e833f269..1b1e0f170a2 100644 --- a/gopls/internal/test/integration/fake/editor.go +++ b/gopls/internal/test/integration/fake/editor.go @@ -10,6 +10,7 @@ import ( "encoding/json" "errors" "fmt" + "math/rand/v2" "os" "path" "path/filepath" @@ -17,6 +18,7 @@ import ( "slices" "strings" "sync" + "time" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/protocol/command" @@ 
-136,6 +138,10 @@ type EditorConfig struct { // If non-nil, MessageResponder is used to respond to ShowMessageRequest // messages. MessageResponder func(params *protocol.ShowMessageRequestParams) (*protocol.MessageActionItem, error) + + // MaxMessageDelay is used for fuzzing message delivery to reproduce test + // flakes. + MaxMessageDelay time.Duration } // NewEditor creates a new Editor. @@ -162,10 +168,11 @@ func (e *Editor) Connect(ctx context.Context, connector servertest.Connector, ho e.serverConn = conn e.Server = protocol.ServerDispatcher(conn) e.client = &Client{editor: e, hooks: hooks} - conn.Go(bgCtx, - protocol.Handlers( - protocol.ClientHandler(e.client, - jsonrpc2.MethodNotFound))) + handler := protocol.ClientHandler(e.client, jsonrpc2.MethodNotFound) + if e.config.MaxMessageDelay > 0 { + handler = DelayedHandler(e.config.MaxMessageDelay, handler) + } + conn.Go(bgCtx, protocol.Handlers(handler)) if err := e.initialize(ctx); err != nil { return nil, err @@ -174,6 +181,18 @@ func (e *Editor) Connect(ctx context.Context, connector servertest.Connector, ho return e, nil } +// DelayedHandler waits [0, maxDelay) before handling each message. 
+func DelayedHandler(maxDelay time.Duration, handler jsonrpc2.Handler) jsonrpc2.Handler { + return func(ctx context.Context, reply jsonrpc2.Replier, req jsonrpc2.Request) error { + delay := time.Duration(rand.Int64N(int64(maxDelay))) + select { + case <-ctx.Done(): + case <-time.After(delay): + } + return handler(ctx, reply, req) + } +} + func (e *Editor) Stats() CallCounts { e.callsMu.Lock() defer e.callsMu.Unlock() @@ -342,7 +361,8 @@ func clientCapabilities(cfg EditorConfig) (protocol.ClientCapabilities, error) { capabilities.TextDocument.Completion.CompletionItem.SnippetSupport = true capabilities.TextDocument.Completion.CompletionItem.InsertReplaceSupport = true capabilities.TextDocument.SemanticTokens.Requests.Full = &protocol.Or_ClientSemanticTokensRequestOptions_full{Value: true} - capabilities.Window.WorkDoneProgress = true // support window/workDoneProgress + capabilities.Window.WorkDoneProgress = true // support window/workDoneProgress + capabilities.Window.ShowDocument = &protocol.ShowDocumentClientCapabilities{Support: true} // support window/showDocument capabilities.TextDocument.SemanticTokens.TokenTypes = []string{ "namespace", "type", "class", "enum", "interface", "struct", "typeParameter", "parameter", "variable", "property", "enumMember", diff --git a/gopls/internal/test/integration/misc/codeactions_test.go b/gopls/internal/test/integration/misc/codeactions_test.go index 7e5ac9aba62..a9d0ce8b149 100644 --- a/gopls/internal/test/integration/misc/codeactions_test.go +++ b/gopls/internal/test/integration/misc/codeactions_test.go @@ -64,11 +64,11 @@ func g() {} } check("src/a.go", + settings.AddTest, settings.GoAssembly, settings.GoDoc, settings.GoFreeSymbols, settings.GoplsDocFeatures, - settings.RefactorExtractVariable, settings.RefactorInlineCall) check("gen/a.go", settings.GoAssembly, diff --git a/gopls/internal/test/integration/misc/hover_test.go b/gopls/internal/test/integration/misc/hover_test.go index 47a1cb066f8..1592b899b1d 100644 --- 
a/gopls/internal/test/integration/misc/hover_test.go +++ b/gopls/internal/test/integration/misc/hover_test.go @@ -592,7 +592,7 @@ func main() { ).Run(t, mod, func(t *testing.T, env *Env) { env.OpenFile("main.go") got, _ := env.Hover(env.RegexpSearch("main.go", "F")) - const wantRE = "\\[`a.F` in gopls doc viewer\\]\\(http://127.0.0.1:[0-9]+/gopls/[^/]+/pkg/example.com\\?view=[0-9]+#F\\)" // no version + const wantRE = "\\[`a.F` in gopls doc viewer\\]\\(http://127.0.0.1:[0-9]+/gopls/[^/]+/pkg/example.com/a\\?view=[0-9]+#F\\)" // no version if m, err := regexp.MatchString(wantRE, got.Value); err != nil { t.Fatalf("bad regexp in test: %v", err) } else if !m { diff --git a/gopls/internal/test/integration/misc/semantictokens_test.go b/gopls/internal/test/integration/misc/semantictokens_test.go index b8d8729c63a..46f1df9b2c6 100644 --- a/gopls/internal/test/integration/misc/semantictokens_test.go +++ b/gopls/internal/test/integration/misc/semantictokens_test.go @@ -53,23 +53,23 @@ func TestSemantic_2527(t *testing.T) { want := []fake.SemanticToken{ {Token: "package", TokenType: "keyword"}, {Token: "foo", TokenType: "namespace"}, - {Token: "// Deprecated (for testing)", TokenType: "comment"}, + {Token: "// comment", TokenType: "comment"}, {Token: "func", TokenType: "keyword"}, - {Token: "Add", TokenType: "function", Mod: "definition deprecated"}, + {Token: "Add", TokenType: "function", Mod: "definition signature"}, {Token: "T", TokenType: "typeParameter", Mod: "definition"}, {Token: "int", TokenType: "type", Mod: "defaultLibrary number"}, {Token: "target", TokenType: "parameter", Mod: "definition"}, {Token: "T", TokenType: "typeParameter"}, - {Token: "l", TokenType: "parameter", Mod: "definition"}, + {Token: "l", TokenType: "parameter", Mod: "definition slice"}, {Token: "T", TokenType: "typeParameter"}, {Token: "T", TokenType: "typeParameter"}, {Token: "return", TokenType: "keyword"}, {Token: "append", TokenType: "function", Mod: "defaultLibrary"}, - {Token: "l", 
TokenType: "parameter"}, + {Token: "l", TokenType: "parameter", Mod: "slice"}, {Token: "target", TokenType: "parameter"}, {Token: "for", TokenType: "keyword"}, {Token: "range", TokenType: "keyword"}, - {Token: "l", TokenType: "parameter"}, + {Token: "l", TokenType: "parameter", Mod: "slice"}, {Token: "// test coverage", TokenType: "comment"}, {Token: "return", TokenType: "keyword"}, {Token: "nil", TokenType: "variable", Mod: "readonly defaultLibrary"}, @@ -81,7 +81,7 @@ module example.com go 1.19 -- main.go -- package foo -// Deprecated (for testing) +// comment func Add[T int](target T, l []T) []T { return append(l, target) for range l {} // test coverage @@ -167,18 +167,18 @@ func bar() {} {Token: "go:linkname", TokenType: "namespace"}, {Token: "now time.Now", TokenType: "comment"}, {Token: "func", TokenType: "keyword"}, - {Token: "now", TokenType: "function", Mod: "definition"}, + {Token: "now", TokenType: "function", Mod: "definition signature"}, {Token: "//", TokenType: "comment"}, {Token: "go:noinline", TokenType: "namespace"}, {Token: "func", TokenType: "keyword"}, - {Token: "foo", TokenType: "function", Mod: "definition"}, + {Token: "foo", TokenType: "function", Mod: "definition signature"}, {Token: "// Mentioning go:noinline should not tokenize.", TokenType: "comment"}, {Token: "//go:notadirective", TokenType: "comment"}, {Token: "func", TokenType: "keyword"}, - {Token: "bar", TokenType: "function", Mod: "definition"}, + {Token: "bar", TokenType: "function", Mod: "definition signature"}, } WithOptions( diff --git a/gopls/internal/test/integration/misc/vuln_test.go b/gopls/internal/test/integration/misc/vuln_test.go index 05cdbe8594f..9f6061c43d9 100644 --- a/gopls/internal/test/integration/misc/vuln_test.go +++ b/gopls/internal/test/integration/misc/vuln_test.go @@ -7,6 +7,7 @@ package misc import ( "context" "encoding/json" + "fmt" "sort" "strings" "testing" @@ -51,7 +52,10 @@ package foo }) } -func TestRunGovulncheckError2(t *testing.T) { +func 
TestVulncheckError(t *testing.T) { + // This test checks an error of the gopls.vulncheck command, which should be + // returned synchronously. + const files = ` -- go.mod -- module mod.com @@ -69,12 +73,13 @@ func F() { // build error incomplete Settings{ "codelenses": map[string]bool{ "run_govulncheck": true, + "vulncheck": true, }, }, ).Run(t, files, func(t *testing.T, env *Env) { env.OpenFile("go.mod") - var result command.RunVulncheckResult - err := env.Editor.ExecuteCodeLensCommand(env.Ctx, "go.mod", command.RunGovulncheck, &result) + var result command.VulncheckResult + err := env.Editor.ExecuteCodeLensCommand(env.Ctx, "go.mod", command.Vulncheck, &result) if err == nil { t.Fatalf("govulncheck succeeded unexpectedly: %v", result) } @@ -185,37 +190,54 @@ func main() { t.Fatal(err) } defer db.Clean() - WithOptions( - EnvVars{ - // Let the analyzer read vulnerabilities data from the testdata/vulndb. - "GOVULNDB": db.URI(), - // When fetchinging stdlib package vulnerability info, - // behave as if our go version is go1.19 for this testing. - // The default behavior is to run `go env GOVERSION` (which isn't mutable env var). - cache.GoVersionForVulnTest: "go1.19", - "_GOPLS_TEST_BINARY_RUN_AS_GOPLS": "true", // needed to run `gopls vulncheck`. - }, - Settings{ - "codelenses": map[string]bool{ - "run_govulncheck": true, - }, - }, - ).Run(t, files, func(t *testing.T, env *Env) { - env.OpenFile("go.mod") - // Run Command included in the codelens. - var result command.RunVulncheckResult - env.ExecuteCodeLensCommand("go.mod", command.RunGovulncheck, &result) + for _, legacy := range []bool{false, true} { + t.Run(fmt.Sprintf("legacy=%v", legacy), func(t *testing.T) { + WithOptions( + EnvVars{ + // Let the analyzer read vulnerabilities data from the testdata/vulndb. + "GOVULNDB": db.URI(), + // When fetchinging stdlib package vulnerability info, + // behave as if our go version is go1.19 for this testing. 
+ // The default behavior is to run `go env GOVERSION` (which isn't mutable env var). + cache.GoVersionForVulnTest: "go1.19", + "_GOPLS_TEST_BINARY_RUN_AS_GOPLS": "true", // needed to run `gopls vulncheck`. + }, + Settings{ + "codelenses": map[string]bool{ + "run_govulncheck": true, + "vulncheck": true, + }, + }, + ).Run(t, files, func(t *testing.T, env *Env) { + env.OpenFile("go.mod") - env.OnceMet( - CompletedProgress(server.GoVulncheckCommandTitle, nil), - ShownMessage("Found GOSTDLIB"), - NoDiagnostics(ForFile("go.mod")), - ) - testFetchVulncheckResult(t, env, "go.mod", result.Result, map[string]fetchVulncheckResult{ - "go.mod": {IDs: []string{"GOSTDLIB"}, Mode: vulncheck.ModeGovulncheck}, + // Run Command included in the codelens. + + var result *vulncheck.Result + var expectation Expectation + if legacy { + var r command.RunVulncheckResult + env.ExecuteCodeLensCommand("go.mod", command.RunGovulncheck, &r) + expectation = CompletedProgressToken(r.Token, nil) + } else { + var r command.VulncheckResult + env.ExecuteCodeLensCommand("go.mod", command.Vulncheck, &r) + result = r.Result + expectation = CompletedProgress(server.GoVulncheckCommandTitle, nil) + } + + env.OnceMet( + expectation, + ShownMessage("Found GOSTDLIB"), + NoDiagnostics(ForFile("go.mod")), + ) + testFetchVulncheckResult(t, env, "go.mod", result, map[string]fetchVulncheckResult{ + "go.mod": {IDs: []string{"GOSTDLIB"}, Mode: vulncheck.ModeGovulncheck}, + }) + }) }) - }) + } } func TestFetchVulncheckResultStd(t *testing.T) { @@ -592,7 +614,7 @@ func TestRunVulncheckPackageDiagnostics(t *testing.T) { env.ExecuteCodeLensCommand("go.mod", command.RunGovulncheck, &result) gotDiagnostics := &protocol.PublishDiagnosticsParams{} env.OnceMet( - CompletedProgress(server.GoVulncheckCommandTitle, nil), + CompletedProgressToken(result.Token, nil), ShownMessage("Found"), ) env.OnceMet( @@ -640,7 +662,7 @@ func TestRunGovulncheck_Expiry(t *testing.T) { var result command.RunVulncheckResult 
env.ExecuteCodeLensCommand("go.mod", command.RunGovulncheck, &result) env.OnceMet( - CompletedProgress(server.GoVulncheckCommandTitle, nil), + CompletedProgressToken(result.Token, nil), ShownMessage("Found"), ) // Sleep long enough for the results to expire. @@ -671,7 +693,7 @@ func TestRunVulncheckWarning(t *testing.T) { env.ExecuteCodeLensCommand("go.mod", command.RunGovulncheck, &result) gotDiagnostics := &protocol.PublishDiagnosticsParams{} env.OnceMet( - CompletedProgress(server.GoVulncheckCommandTitle, nil), + CompletedProgressToken(result.Token, nil), ShownMessage("Found"), ) // Vulncheck diagnostics asynchronous to the vulncheck command. @@ -680,7 +702,7 @@ func TestRunVulncheckWarning(t *testing.T) { ReadDiagnostics("go.mod", gotDiagnostics), ) - testFetchVulncheckResult(t, env, "go.mod", result.Result, map[string]fetchVulncheckResult{ + testFetchVulncheckResult(t, env, "go.mod", nil, map[string]fetchVulncheckResult{ // All vulnerabilities (symbol-level, import-level, module-level) are reported. "go.mod": {IDs: []string{"GO-2022-01", "GO-2022-02", "GO-2022-03", "GO-2022-04"}, Mode: vulncheck.ModeGovulncheck}, }) @@ -826,7 +848,7 @@ func TestGovulncheckInfo(t *testing.T) { env.ExecuteCodeLensCommand("go.mod", command.RunGovulncheck, &result) gotDiagnostics := &protocol.PublishDiagnosticsParams{} env.OnceMet( - CompletedProgress(server.GoVulncheckCommandTitle, nil), + CompletedProgressToken(result.Token, nil), ShownMessage("No vulnerabilities found"), // only count affecting vulnerabilities. 
) @@ -836,7 +858,7 @@ func TestGovulncheckInfo(t *testing.T) { ReadDiagnostics("go.mod", gotDiagnostics), ) - testFetchVulncheckResult(t, env, "go.mod", result.Result, map[string]fetchVulncheckResult{ + testFetchVulncheckResult(t, env, "go.mod", nil, map[string]fetchVulncheckResult{ "go.mod": {IDs: []string{"GO-2022-02", "GO-2022-04"}, Mode: vulncheck.ModeGovulncheck}, }) // wantDiagnostics maps a module path in the require diff --git a/gopls/internal/test/integration/misc/webserver_test.go b/gopls/internal/test/integration/misc/webserver_test.go index d5a051ea348..11cd56eef99 100644 --- a/gopls/internal/test/integration/misc/webserver_test.go +++ b/gopls/internal/test/integration/misc/webserver_test.go @@ -73,10 +73,11 @@ func (G[T]) F(int, int, int, int, int, int, int, ...int) {} // downcall, this time for a "file:" URL, causing the // client editor to navigate to the source file. t.Log("extracted /src URL", srcURL) + collectDocs := env.Awaiter.ListenToShownDocuments() get(t, srcURL) // Check that that shown location is that of NewFunc. - shownSource := shownDocument(t, env, "file:") + shownSource := shownDocument(t, collectDocs(), "file:") gotLoc := protocol.Location{ URI: protocol.DocumentURI(shownSource.URI), // fishy conversion Range: *shownSource.Selection, @@ -89,6 +90,75 @@ func (G[T]) F(int, int, int, int, int, int, int, ...int) {} }) } +func TestShowDocumentUnsupported(t *testing.T) { + const files = ` +-- go.mod -- +module example.com + +-- a.go -- +package a + +const A = 1 +` + + for _, supported := range []bool{false, true} { + t.Run(fmt.Sprintf("supported=%v", supported), func(t *testing.T) { + opts := []RunOption{Modes(Default)} + if !supported { + opts = append(opts, CapabilitiesJSON([]byte(` +{ + "window": { + "showDocument": { + "support": false + } + } +}`))) + } + WithOptions(opts...).Run(t, files, func(t *testing.T, env *Env) { + env.OpenFile("a.go") + // Invoke the "Browse package documentation" code + // action to start the server. 
+ actions := env.CodeAction(env.Sandbox.Workdir.EntireFile("a.go"), nil, 0) + docAction, err := codeActionByKind(actions, settings.GoDoc) + if err != nil { + t.Fatal(err) + } + + // Execute the command. + // Its side effect should be a single showDocument request. + params := &protocol.ExecuteCommandParams{ + Command: docAction.Command.Command, + Arguments: docAction.Command.Arguments, + } + var result any + collectDocs := env.Awaiter.ListenToShownDocuments() + collectMessages := env.Awaiter.ListenToShownMessages() + env.ExecuteCommand(params, &result) + + // golang/go#70342: just because the command has finished does not mean + // that we will have received the necessary notifications. Synchronize + // using progress reports. + env.Await(CompletedWork(params.Command, 1, false)) + + wantDocs, wantMessages := 0, 1 + if supported { + wantDocs, wantMessages = 1, 0 + } + + docs := collectDocs() + messages := collectMessages() + + if gotDocs := len(docs); gotDocs != wantDocs { + t.Errorf("gopls.doc: got %d showDocument requests, want %d", gotDocs, wantDocs) + } + if gotMessages := len(messages); gotMessages != wantMessages { + t.Errorf("gopls.doc: got %d showMessage requests, want %d", gotMessages, wantMessages) + } + }) + }) + } +} + func TestPkgDocNoPanic66449(t *testing.T) { // This particular input triggered a latent bug in doc.New // that would corrupt the AST while filtering out unexported @@ -353,9 +423,10 @@ func viewPkgDoc(t *testing.T, env *Env, loc protocol.Location) protocol.URI { Arguments: docAction.Command.Arguments, } var result any + collectDocs := env.Awaiter.ListenToShownDocuments() env.ExecuteCommand(params, &result) - doc := shownDocument(t, env, "http:") + doc := shownDocument(t, collectDocs(), "http:") if doc == nil { t.Fatalf("no showDocument call had 'http:' prefix") } @@ -408,8 +479,9 @@ func f(buf bytes.Buffer, greeting string) { Arguments: action.Command.Arguments, } var result command.DebuggingResult + collectDocs := 
env.Awaiter.ListenToShownDocuments() env.ExecuteCommand(params, &result) - doc := shownDocument(t, env, "http:") + doc := shownDocument(t, collectDocs(), "http:") if doc == nil { t.Fatalf("no showDocument call had 'file:' prefix") } @@ -467,8 +539,9 @@ func g() { Arguments: action.Command.Arguments, } var result command.DebuggingResult + collectDocs := env.Awaiter.ListenToShownDocuments() env.ExecuteCommand(params, &result) - doc := shownDocument(t, env, "http:") + doc := shownDocument(t, collectDocs(), "http:") if doc == nil { t.Fatalf("no showDocument call had 'file:' prefix") } @@ -506,11 +579,8 @@ func g() { // shownDocument returns the first shown document matching the URI prefix. // It may be nil. // As a side effect, it clears the list of accumulated shown documents. -func shownDocument(t *testing.T, env *Env, prefix string) *protocol.ShowDocumentParams { +func shownDocument(t *testing.T, shown []*protocol.ShowDocumentParams, prefix string) *protocol.ShowDocumentParams { t.Helper() - var shown []*protocol.ShowDocumentParams - env.Await(ShownDocuments(&shown)) - env.Awaiter.ResetShownDocuments() // REVIEWERS: seems like a hack; better ideas? var first *protocol.ShowDocumentParams for _, sd := range shown { if strings.HasPrefix(sd.URI, prefix) { diff --git a/gopls/internal/test/integration/options.go b/gopls/internal/test/integration/options.go index 87be2114eaa..8090388e17d 100644 --- a/gopls/internal/test/integration/options.go +++ b/gopls/internal/test/integration/options.go @@ -7,6 +7,7 @@ package integration import ( "strings" "testing" + "time" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/test/integration/fake" @@ -192,3 +193,14 @@ func MessageResponder(f func(*protocol.ShowMessageRequestParams) (*protocol.Mess opts.editor.MessageResponder = f }) } + +// DelayMessages can be used to fuzz message delivery delays for the purpose of +// reproducing test flakes. 
+// +// (Even though this option may be unused, keep it around to aid in debugging +// future flakes.) +func DelayMessages(upto time.Duration) RunOption { + return optionSetter(func(opts *runConfig) { + opts.editor.MaxMessageDelay = upto + }) +} diff --git a/gopls/internal/test/integration/workspace/quickfix_test.go b/gopls/internal/test/integration/workspace/quickfix_test.go index c39e5ca3542..3f6b8e8dc32 100644 --- a/gopls/internal/test/integration/workspace/quickfix_test.go +++ b/gopls/internal/test/integration/workspace/quickfix_test.go @@ -341,7 +341,7 @@ func main() {} } func TestStubMethods64087(t *testing.T) { - // We can't use the @fix or @quickfixerr or @codeactionerr + // We can't use the @fix or @quickfixerr or @codeaction // because the error now reported by the corrected logic // is internal and silently causes no fix to be offered. // @@ -404,7 +404,7 @@ type myerror struct{any} } func TestStubMethods64545(t *testing.T) { - // We can't use the @fix or @quickfixerr or @codeactionerr + // We can't use the @fix or @quickfixerr or @codeaction // because the error now reported by the corrected logic // is internal and silently causes no fix to be offered. // diff --git a/gopls/internal/test/marker/doc.go b/gopls/internal/test/marker/doc.go index 509791d509c..abddbddacd3 100644 --- a/gopls/internal/test/marker/doc.go +++ b/gopls/internal/test/marker/doc.go @@ -6,20 +6,30 @@ Package marker defines a framework for running "marker" tests, each defined by a file in the testdata subdirectory. -Use this command to run the tests: +Use this command to run the tests, from the gopls module: - $ go test ./gopls/internal/test/marker [-update] + $ go test ./internal/test/marker [-update] -A marker test uses the '//@' marker syntax of the x/tools/internal/expect package -to annotate source code with various information such as locations and -arguments of LSP operations to be executed by the test. 
The syntax following -'@' is parsed as a comma-separated list of ordinary Go function calls, for -example +A marker test uses the '//@' syntax of the x/tools/internal/expect package to +annotate source code with various information such as locations and arguments +of LSP operations to be executed by the test. The syntax following '@' is +parsed as a comma-separated list of Go-like function calls, which we refer to +as 'markers' (or sometimes 'marks'), for example - //@foo(a, "b", 3),bar(0) + //@ foo(a, "b", 3), bar(0) -and delegates to a corresponding function to perform LSP-related operations. -See the Marker types documentation below for a list of supported markers. +Unlike ordinary Go, the marker syntax also supports optional named arguments +using the syntax name=value. If provided, named arguments must appear after all +positional arguments, though their ordering with respect to other named +arguments does not matter. For example + + //@ foo(a, "b", d=4, c=3) + +Each marker causes a corresponding function to be called in the test. Some +markers are declarations; for example, @loc declares a name for a source +location. Others have effects, such as executing an LSP operation and asserting +that it behaved as expected. See the Marker types documentation below for the +list of all supported markers. Each call argument is converted to the type of the corresponding parameter of the designated function. The conversion logic may use the surrounding context, @@ -39,26 +49,30 @@ There are several types of file within the test archive that are given special treatment by the test runner: - "skip": the presence of this file causes the test to be skipped, with - the file content used as the skip message. + its content used as the skip message. - "flags": this file is treated as a whitespace-separated list of flags that configure the MarkerTest instance. 
Supported flags: - -{min,max}_go=go1.20 sets the {min,max}imum Go version for the test - (inclusive) - -cgo requires that CGO_ENABLED is set and the cgo tool is available + + -{min,max}_go=go1.20 sets the {min,max}imum Go runtime version for the test + (inclusive). + -{min,max}_go_command=go1.20 sets the {min,max}imum Go command version for + the test (inclusive). + -cgo requires that CGO_ENABLED is set and the cgo tool is available. -write_sumfile=a,b,c instructs the test runner to generate go.sum files in these directories before running the test. -skip_goos=a,b,c instructs the test runner to skip the test for the listed GOOS values. -skip_goarch=a,b,c does the same for GOARCH. - -ignore_extra_diags suppresses errors for unmatched diagnostics TODO(rfindley): using build constraint expressions for -skip_go{os,arch} would be clearer. + -ignore_extra_diags suppresses errors for unmatched diagnostics -filter_builtins=false disables the filtering of builtins from completion results. -filter_keywords=false disables the filtering of keywords from completion results. -errors_ok=true suppresses errors for Error level log entries. + TODO(rfindley): support flag values containing whitespace. - "settings.json": this file is parsed as JSON, and used as the @@ -88,33 +102,56 @@ treatment by the test runner: # Marker types -Markers are of two kinds. A few are "value markers" (e.g. @item), which are -processed in a first pass and each computes a value that may be referred to -by name later. Most are "action markers", which are processed in a second -pass and take some action such as testing an LSP operation; they may refer -to values computed by value markers. +Markers are of two kinds: "value markers" and "action markers". Value markers +are processed in a first pass, and define named values that may be referred to +as arguments to action markers. For example, the @loc marker defines a named +location that may be used wherever a location is expected. 
Value markers cannot +refer to names defined by other value markers. Action markers are processed in +a second pass and perform some action such as testing an LSP operation. + +Below, we list supported markers using function signatures, augmented with the +named argument support name=value, as described above. The types referred to in +the signatures below are described in the Argument conversion section. + +Here is the list of supported value markers: + + - loc(name, location): specifies the name for a location in the source. These + locations may be referenced by other markers. Naturally, the location + argument may be specified only as a string or regular expression in the + first pass. + + - defloc(name, location): performs a textDocument/definition request at the + src location, and binds the result to the given name. This may be used to + refer to positions in the standard library. + + - hiloc(name, location, kind): defines a documentHighlight value of the + given location and kind. Use its label in a @highlightall marker to + indicate the expected result of a highlight query. + + - item(name, details, kind): defines a completionItem with the provided + fields. This information is not positional, and therefore @item markers + may occur anywhere in the source. Use in conjunction with @complete, + @snippet, or @rank. + + TODO(rfindley): rethink whether floating @item annotations are the best + way to specify completion results. -The following markers are supported within marker tests: +Here is the list of supported action markers: - acceptcompletion(location, label, golden): specifies that accepting the completion candidate produced at the given location with provided label results in the given golden state. - - codeaction(start, end, kind, golden): specifies a code action - to request for the given range. To support multi-line ranges, the range - is defined to be between start.Start and end.End.
The golden directory - contains changed file content after the code action is applied. + - codeaction(start location, kind string, end=location, edit=golden, result=golden, err=stringMatcher) - TODO(rfindley): now that 'location' supports multi-line matches, replace - uses of 'codeaction' with codeactionedit. + Specifies a code action to request at the location, with given kind. - - codeactionedit(location, kind, golden): a shorter form of - codeaction. Invokes a code action of the given kind for the given - in-line range, and compares the resulting formatted unified *edits* - (notably, not the full file content) with the golden directory. + If end is set, the location is defined to be between start.Start and end.End. - - codeactionerr(start, end, kind, wantError): specifies a codeaction that - fails with an error that matches the expectation. + Exactly one of edit, result, or err must be set. If edit is set, it is a + golden reference to the edits resulting from the code action. If result is + set, it is a golden reference to the full set of changed files resulting + from the code action. If err is set, it is the code action error. - codelens(location, title): specifies that a codelens is expected at the given location, with given title. Must be used in conjunction with @@ -135,8 +172,9 @@ The following markers are supported within marker tests: The specified location must match the start position of the diagnostic, but end positions are ignored unless exact=true. - TODO(adonovan): in the older marker framework, the annotation asserted - two additional fields (source="compiler", kind="error"). Restore them? + TODO(adonovan): in the older marker framework, the annotation asserted two + additional fields (source="compiler", kind="error"). Restore them using + optional named arguments. - def(src, dst location): performs a textDocument/definition request at the src location, and check the result points to the dst location. 
@@ -167,10 +205,6 @@ The following markers are supported within marker tests: textDocument/highlight request at the given src location, which should highlight the provided dst locations and kinds. - - hiloc(label, location, kind): defines a documentHighlight value of the - given location and kind. Use its label in a @highlightall marker to - indicate the expected result of a highlight query. - - hover(src, dst location, sm stringMatcher): performs a textDocument/hover at the src location, and checks that the result is the dst location, with matching hover content. @@ -188,36 +222,15 @@ The following markers are supported within marker tests: (These locations are the declarations of the functions enclosing the calls, not the calls themselves.) - - item(label, details, kind): defines a completionItem with the provided - fields. This information is not positional, and therefore @item markers - may occur anywhere in the source. Used in conjunction with @complete, - @snippet, or @rank. - - TODO(rfindley): rethink whether floating @item annotations are the best - way to specify completion results. - - - loc(name, location): specifies the name for a location in the source. These - locations may be referenced by other markers. - - outgoingcalls(src location, want ...location): makes a callHierarchy/outgoingCalls query at the src location, and checks that the set of call.To locations matches want. - - preparerename(src, spn, placeholder): asserts that a textDocument/prepareRename - request at the src location expands to the spn location, with given - placeholder. If placeholder is "", this is treated as a negative - assertion and prepareRename should return nil. - - - rename(location, new, golden): specifies a renaming of the - identifier at the specified location to the new name. - The golden directory contains the transformed files. - - - renameerr(location, new, wantError): specifies a renaming that - fails with an error that matches the expectation. 
- - - signature(location, label, active): specifies that - signatureHelp at the given location should match the provided string, with - the active parameter (an index) highlighted. + - preparerename(src location, placeholder string, span=location): asserts + that a textDocument/prepareRename request at the src location has the given + placeholder text. If present, the optional span argument is verified to be + the span of the prepareRename result. If placeholder is "", this is treated + as a negative assertion and prepareRename should return nil. - quickfix(location, regexp, golden): like diag, the location and regexp identify an expected diagnostic, which must have exactly one @@ -244,6 +257,17 @@ The following markers are supported within marker tests: 'want' locations. The first want location must be the declaration (assumedly unique). + - rename(location, new, golden): specifies a renaming of the + identifier at the specified location to the new name. + The golden directory contains the transformed files. + + - renameerr(location, new, wantError): specifies a renaming that + fails with an error that matches the expectation. + + - signature(location, label, active): specifies that + signatureHelp at the given location should match the provided string, with + the active parameter (an index) highlighted. + - snippet(location, string OR completionItem, snippet): executes a textDocument/completion request at the location, and searches for a result with label matching that its second argument, which may be a string literal @@ -288,20 +312,26 @@ the following tokens as defined by the Go spec: These values are passed as arguments to the corresponding parameter of the test function. Additional value conversions may occur for these argument -> parameter type pairs: + - string->regexp: the argument is parsed as a regular expression.
+ - string->location: the argument is converted to the location of the first instance of the argument in the file content starting from the beginning of the line containing the note. Multi-line matches are permitted, but the match must begin before the note. + - regexp->location: the argument is converted to the location of the first match for the argument in the file content starting from the beginning of the line containing the note. Multi-line matches are permitted, but the match must begin before the note. If the regular expression contains exactly one subgroup, the position of the subgroup is used rather than the position of the submatch. + - name->location: the argument is replaced by the named location. + - name->Golden: the argument is used to look up golden content prefixed by @. + - {string,regexp,identifier}->stringMatcher: a stringMatcher type specifies an expected string, either in the form of a substring that must be present, a regular expression that it must match, or an @@ -331,7 +361,7 @@ Here is a complete example: In this example, the @hover annotation tells the test runner to run the hoverMarker function, which has parameters: - (mark marker, src, dsc protocol.Location, g *Golden). + (mark marker, src, dst protocol.Location, g *Golden). The first argument holds the test context, including fake editor with open files, and sandboxed directory. @@ -366,12 +396,6 @@ Note that -update does not cause missing @diag or @loc markers to be added. # TODO - Rename the files .txtar. - - Provide some means by which locations in the standard library - (or builtin.go) can be named, so that, for example, we can we - can assert that MyError implements the built-in error type. - - If possible, improve handling for optional arguments. Rather than have - multiple variations of a marker, it would be nice to support a more - flexible signature: can codeaction, codeactionedit, codeactionerr, and - quickfix be consolidated? 
+ - Eliminate all *err markers, preferring named arguments. */ package marker diff --git a/gopls/internal/test/marker/marker_test.go b/gopls/internal/test/marker/marker_test.go index 272809c3384..654bca4ae5b 100644 --- a/gopls/internal/test/marker/marker_test.go +++ b/gopls/internal/test/marker/marker_test.go @@ -11,6 +11,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "flag" "fmt" "go/token" @@ -112,6 +113,7 @@ func Test(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { t.Parallel() + if test.skipReason != "" { t.Skip(test.skipReason) } @@ -146,8 +148,15 @@ func Test(t *testing.T) { testenv.SkipAfterGoCommand1Point(t, go1point) } if test.cgo { + if os.Getenv("CGO_ENABLED") == "0" { + // NeedsTool causes the test to fail if cgo is available but disabled + // on the current platform through the environment. I'm not sure why it + // behaves this way, but if CGO_ENABLED=0 is set, we want to skip. + t.Skip("skipping due to CGO_ENABLED=0") + } testenv.NeedsTool(t, "cgo") } + config := fake.EditorConfig{ Settings: test.settings, CapabilitiesJSON: test.capabilities, @@ -171,6 +180,7 @@ func Test(t *testing.T) { diags: make(map[protocol.Location][]protocol.Diagnostic), extraNotes: make(map[protocol.DocumentURI]map[string][]*expect.Note), } + // TODO(rfindley): make it easier to clean up the integration test environment. defer run.env.Editor.Shutdown(context.Background()) // ignore error defer run.env.Sandbox.Close() // ignore error @@ -340,7 +350,16 @@ func (mark marker) mapper() *protocol.Mapper { return mapper } -// errorf reports an error with a prefix indicating the position of the marker note. +// error reports an error with a prefix indicating the position of the marker +// note. +func (mark marker) error(args ...any) { + mark.T().Helper() + msg := fmt.Sprint(args...) + mark.T().Errorf("%s: %s", mark.run.fmtPos(mark.note.Pos), msg) +} + +// errorf reports a formatted error with a prefix indicating the position of +// the marker note. 
// // It formats the error message using mark.sprintf. func (mark marker) errorf(format string, args ...any) { @@ -396,7 +415,7 @@ func valueMarkerFunc(fn any) func(marker) { args := append([]any{mark}, mark.note.Args[1:]...) argValues, err := convertArgs(mark, ftype, args) if err != nil { - mark.errorf("converting args: %v", err) + mark.error(err) return } results := reflect.ValueOf(fn).Call(argValues) @@ -439,7 +458,7 @@ func actionMarkerFunc(fn any, allowedNames ...string) func(marker) { args := append([]any{mark}, mark.note.Args...) argValues, err := convertArgs(mark, ftype, args) if err != nil { - mark.errorf("converting args: %v", err) + mark.error(err) return } reflect.ValueOf(fn).Call(argValues) @@ -495,12 +514,42 @@ func namedArg[T any](mark marker, name string, dflt T) T { if e, ok := v.(T); ok { return e } else { - mark.errorf("invalid value for %q: %v", name, v) + v, err := convert(mark, v, reflect.TypeOf(dflt)) + if err != nil { + mark.errorf("invalid value for %q: could not convert %v (%T) to %T", name, v, v, dflt) + return dflt + } + return v.(T) + } + } + return dflt +} + +func namedArgFunc[T any](mark marker, name string, f func(marker, any) (T, error), dflt T) T { + if v, ok := mark.note.NamedArgs[name]; ok { + if v2, err := f(mark, v); err == nil { + return v2 + } else { + mark.errorf("invalid value for %q: %v: %v", name, v, err) } } return dflt } +func exactlyOneNamedArg(mark marker, names ...string) bool { + var found []string + for _, name := range names { + if _, ok := mark.note.NamedArgs[name]; ok { + found = append(found, name) + } + } + if len(found) != 1 { + mark.errorf("need exactly one of %v to be set, got %v", names, found) + return false + } + return true +} + // is reports whether arg is a T. func is[T any](arg any) bool { _, ok := arg.(T) @@ -509,17 +558,18 @@ func is[T any](arg any) bool { // Supported value marker functions. See [valueMarkerFunc] for more details. 
var valueMarkerFuncs = map[string]func(marker){ - "loc": valueMarkerFunc(locMarker), - "item": valueMarkerFunc(completionItemMarker), - "hiloc": valueMarkerFunc(highlightLocationMarker), + "loc": valueMarkerFunc(locMarker), + "item": valueMarkerFunc(completionItemMarker), + "hiloc": valueMarkerFunc(highlightLocationMarker), + "defloc": valueMarkerFunc(defLocMarker), } // Supported action marker functions. See [actionMarkerFunc] for more details. +// +// See doc.go for marker documentation. var actionMarkerFuncs = map[string]func(marker){ "acceptcompletion": actionMarkerFunc(acceptCompletionMarker), - "codeaction": actionMarkerFunc(codeActionMarker), - "codeactionedit": actionMarkerFunc(codeActionEditMarker), - "codeactionerr": actionMarkerFunc(codeActionErrMarker), + "codeaction": actionMarkerFunc(codeActionMarker, "end", "result", "edit", "err"), "codelenses": actionMarkerFunc(codeLensesMarker), "complete": actionMarkerFunc(completeMarker), "def": actionMarkerFunc(defMarker), @@ -535,7 +585,7 @@ var actionMarkerFuncs = map[string]func(marker){ "incomingcalls": actionMarkerFunc(incomingCallsMarker), "inlayhints": actionMarkerFunc(inlayhintsMarker), "outgoingcalls": actionMarkerFunc(outgoingCallsMarker), - "preparerename": actionMarkerFunc(prepareRenameMarker), + "preparerename": actionMarkerFunc(prepareRenameMarker, "span"), "rank": actionMarkerFunc(rankMarker), "refs": actionMarkerFunc(refsMarker), "rename": actionMarkerFunc(renameMarker), @@ -619,7 +669,8 @@ func (l stringListValue) String() string { return strings.Join([]string(l), ",") } -func (t *markerTest) getGolden(id expect.Identifier) *Golden { +func (mark *marker) getGolden(id expect.Identifier) *Golden { + t := mark.run.test golden, ok := t.golden[id] // If there was no golden content for this identifier, we must create one // to handle the case where -update is set: we need a place to store @@ -633,6 +684,9 @@ func (t *markerTest) getGolden(id expect.Identifier) *Golden { // markerTest during 
execution. Let's merge the two. t.golden[id] = golden } + if golden.firstReference == "" { + golden.firstReference = mark.path() + } return golden } @@ -641,9 +695,10 @@ func (t *markerTest) getGolden(id expect.Identifier) *Golden { // When -update is set, golden captures the updated golden contents for later // writing. type Golden struct { - id expect.Identifier - data map[string][]byte // key "" => @id itself - updated map[string][]byte + id expect.Identifier + firstReference string // file name first referencing this golden content + data map[string][]byte // key "" => @id itself + updated map[string][]byte } // Get returns golden content for the given name, which corresponds to the @@ -820,10 +875,12 @@ func formatTest(test *markerTest) ([]byte, error) { } updatedGolden := make(map[string][]byte) + firstReferences := make(map[string]string) for id, g := range test.golden { for name, data := range g.updated { filename := "@" + path.Join(string(id), name) // name may be "" updatedGolden[filename] = data + firstReferences[filename] = g.firstReference } } @@ -846,7 +903,7 @@ func formatTest(test *markerTest) ([]byte, error) { } } - // ...followed by any new golden files. + // ...but insert new golden files after their first reference. var newGoldenFiles []txtar.File for filename, data := range updatedGolden { // TODO(rfindley): it looks like this implicitly removes trailing newlines @@ -858,7 +915,25 @@ func formatTest(test *markerTest) ([]byte, error) { sort.Slice(newGoldenFiles, func(i, j int) bool { return newGoldenFiles[i].Name < newGoldenFiles[j].Name }) - arch.Files = append(arch.Files, newGoldenFiles...) + for _, g := range newGoldenFiles { + insertAt := len(arch.Files) + if firstRef := firstReferences[g.Name]; firstRef != "" { + for i, f := range arch.Files { + if f.Name == firstRef { + // Insert alphabetically among golden files following the test file. 
+ for i++; i < len(arch.Files); i++ { + f := arch.Files[i] + if !strings.HasPrefix(f.Name, "@") || f.Name >= g.Name { + insertAt = i + break + } + } + break + } + } + } + arch.Files = slices.Insert(arch.Files, insertAt, g) + } return txtar.Format(arch), nil } @@ -973,22 +1048,10 @@ func (run *markerTestRun) fmtPos(pos token.Pos) string { // archive-relative paths for files and including the line number in the full // archive file. func (run *markerTestRun) fmtLoc(loc protocol.Location) string { - formatted := run.fmtLocDetails(loc, true) - if formatted == "" { + if loc == (protocol.Location{}) { run.env.T.Errorf("unable to find %s in test archive", loc) return "" } - return formatted -} - -// See fmtLoc. If includeTxtPos is not set, the position in the full archive -// file is omitted. -// -// If the location cannot be found within the archive, fmtLocDetails returns "". -func (run *markerTestRun) fmtLocDetails(loc protocol.Location, includeTxtPos bool) string { - if loc == (protocol.Location{}) { - return "" - } lines := bytes.Count(run.test.archive.Comment, []byte("\n")) var name string for _, f := range run.test.archive.Files { @@ -1001,39 +1064,74 @@ func (run *markerTestRun) fmtLocDetails(loc protocol.Location, includeTxtPos boo lines += bytes.Count(f.Data, []byte("\n")) } if name == "" { - return "" - } + // Fall back to formatting the "lsp" location. + // These will be in UTF-16, but we probably don't need to clarify that, + // since it will be implied by the file:// URI format. 
+ return summarizeLoc(string(loc.URI), + int(loc.Range.Start.Line), int(loc.Range.Start.Character), + int(loc.Range.End.Line), int(loc.Range.End.Character)) + } + name, startLine, startCol, endLine, endCol := run.mapLocation(loc) + innerSpan := summarizeLoc(name, startLine, startCol, endLine, endCol) + outerSpan := summarizeLoc(run.test.name, lines+startLine, startCol, lines+endLine, endCol) + return fmt.Sprintf("%s (%s)", innerSpan, outerSpan) +} + +// mapLocation returns the relative path and utf8 span of the corresponding +// location, which must be a valid location in an archive file. +func (run *markerTestRun) mapLocation(loc protocol.Location) (name string, startLine, startCol, endLine, endCol int) { + // Note: Editor.Mapper fails if loc.URI is not open, but we always open all + // archive files, so this is probably OK. + // + // In the future, we may want to have the editor read contents from disk if + // the URI is not open. + name = run.env.Sandbox.Workdir.URIToPath(loc.URI) m, err := run.env.Editor.Mapper(name) if err != nil { run.env.T.Errorf("internal error: %v", err) - return "" + return } start, end, err := m.RangeOffsets(loc.Range) if err != nil { run.env.T.Errorf("error formatting location %s: %v", loc, err) + return + } + startLine, startCol = m.OffsetLineCol8(start) + endLine, endCol = m.OffsetLineCol8(end) + return name, startLine, startCol, endLine, endCol +} + +// fmtLocForGolden is like fmtLoc, but chooses more succinct and stable +// formatting, such as would be used for formatting locations in Golden +// content. 
+func (run *markerTestRun) fmtLocForGolden(loc protocol.Location) string { + if loc == (protocol.Location{}) { return "" } - var ( - startLine, startCol8 = m.OffsetLineCol8(start) - endLine, endCol8 = m.OffsetLineCol8(end) - ) - innerSpan := fmt.Sprintf("%d:%d", startLine, startCol8) // relative to the embedded file - outerSpan := fmt.Sprintf("%d:%d", lines+startLine, startCol8) // relative to the archive file - if start != end { - if endLine == startLine { - innerSpan += fmt.Sprintf("-%d", endCol8) - outerSpan += fmt.Sprintf("-%d", endCol8) - } else { - innerSpan += fmt.Sprintf("-%d:%d", endLine, endCol8) - outerSpan += fmt.Sprintf("-%d:%d", lines+endLine, endCol8) - } + name := run.env.Sandbox.Workdir.URIToPath(loc.URI) + // Note: we check IsAbs on filepaths rather than the slash-ified name for + // accurate handling of windows drive letters. + if filepath.IsAbs(filepath.FromSlash(name)) { + // Don't format any position information in this case, since it will be + // volatile. + return "" } + return summarizeLoc(run.mapLocation(loc)) +} - if includeTxtPos { - return fmt.Sprintf("%s:%s (%s:%s)", name, innerSpan, run.test.name, outerSpan) - } else { - return fmt.Sprintf("%s:%s", name, innerSpan) +// summarizeLoc formats a summary of the given location, in the form +// +// <name>:<startLine>:<startCol>[-[<endLine>:]<endCol>] +func summarizeLoc(name string, startLine, startCol, endLine, endCol int) string { + span := fmt.Sprintf("%s:%d:%d", name, startLine, startCol) + if startLine != endLine || startCol != endCol { + span += "-" + if endLine != startLine { + span += fmt.Sprintf("%d:", endLine) + } + span += fmt.Sprintf("%d", endCol) } + return span } // ---- converters ---- @@ -1068,6 +1166,8 @@ func convert(mark marker, arg any, paramType reflect.Type) (any, error) { // Handle stringMatcher and golden parameters before resolving identifiers, // because golden content lives in a separate namespace from other // identifiers.
This interacts + // poorly with named argument resolution. switch paramType { case stringMatcherType: return convertStringMatcher(mark, arg) @@ -1076,7 +1176,7 @@ func convert(mark marker, arg any, paramType reflect.Type) (any, error) { if !ok { return nil, fmt.Errorf("invalid input type %T: golden key must be an identifier", arg) } - return mark.run.test.getGolden(id), nil + return mark.getGolden(id), nil } if id, ok := arg.(expect.Identifier); ok { if arg2, ok := mark.run.values[id]; ok { @@ -1086,7 +1186,7 @@ func convert(mark marker, arg any, paramType reflect.Type) (any, error) { if converter, ok := customConverters[paramType]; ok { arg2, err := converter(mark, arg) if err != nil { - return nil, fmt.Errorf("converting for input type %T to %v: %v", arg, paramType, err) + return nil, err } arg = arg2 } @@ -1096,6 +1196,23 @@ func convert(mark marker, arg any, paramType reflect.Type) (any, error) { return nil, fmt.Errorf("cannot convert %v (%T) to %s", arg, arg, paramType) } +// convertNamedArgLocation is a workaround for converting locations referenced +// by a named argument. See the TODO in [convert]: this wouldn't be necessary +// if we flattened the namespace such that golden content lived in the same +// namespace as values. +func convertNamedArgLocation(mark marker, arg any) (protocol.Location, error) { + if id, ok := arg.(expect.Identifier); ok { + if v, ok := mark.run.values[id]; ok { + if loc, ok := v.(protocol.Location); ok { + return loc, nil + } else { + return protocol.Location{}, fmt.Errorf("invalid location value %v", v) + } + } + } + return convertLocation(mark, arg) +} + // convertLocation converts a string or regexp argument into the protocol // location corresponding to the first position of the string (or first match // of the regexp) in the line preceding the note. 
@@ -1196,7 +1313,7 @@ func convertStringMatcher(mark marker, arg any) (stringMatcher, error) { case *regexp.Regexp: return stringMatcher{pattern: arg}, nil case expect.Identifier: - golden := mark.run.test.getGolden(arg) + golden := mark.getGolden(arg) return stringMatcher{golden: golden}, nil default: return stringMatcher{}, fmt.Errorf("cannot convert %T to wantError (want: string, regexp, or identifier)", arg) @@ -1216,38 +1333,43 @@ type stringMatcher struct { substr string } -func (sc stringMatcher) String() string { - if sc.golden != nil { - return fmt.Sprintf("content from @%s entry", sc.golden.id) - } else if sc.pattern != nil { - return fmt.Sprintf("content matching %#q", sc.pattern) +// empty reports whether the receiver is an empty stringMatcher. +func (sm stringMatcher) empty() bool { + return sm.golden == nil && sm.pattern == nil && sm.substr == "" +} + +func (sm stringMatcher) String() string { + if sm.golden != nil { + return fmt.Sprintf("content from @%s entry", sm.golden.id) + } else if sm.pattern != nil { + return fmt.Sprintf("content matching %#q", sm.pattern) } else { - return fmt.Sprintf("content with substring %q", sc.substr) + return fmt.Sprintf("content with substring %q", sm.substr) } } // checkErr asserts that the given error matches the stringMatcher's expectations. -func (sc stringMatcher) checkErr(mark marker, err error) { +func (sm stringMatcher) checkErr(mark marker, err error) { if err == nil { - mark.errorf("@%s succeeded unexpectedly, want %v", mark.note.Name, sc) + mark.errorf("@%s succeeded unexpectedly, want %v", mark.note.Name, sm) return } - sc.check(mark, err.Error()) + sm.check(mark, err.Error()) } // check asserts that the given content matches the stringMatcher's expectations. 
-func (sc stringMatcher) check(mark marker, got string) { - if sc.golden != nil { - compareGolden(mark, []byte(got), sc.golden) - } else if sc.pattern != nil { +func (sm stringMatcher) check(mark marker, got string) { + if sm.golden != nil { + compareGolden(mark, []byte(got), sm.golden) + } else if sm.pattern != nil { // Content must match the regular expression pattern. - if !sc.pattern.MatchString(got) { - mark.errorf("got %q, does not match pattern %#q", got, sc.pattern) + if !sm.pattern.MatchString(got) { + mark.errorf("got %q, does not match pattern %#q", got, sm.pattern) } - } else if !strings.Contains(got, sc.substr) { + } else if !strings.Contains(got, sm.substr) { // Content must contain the expected substring. - mark.errorf("got %q, want substring %q", got, sc.substr) + mark.errorf("got %q, want substring %q", got, sm.substr) } } @@ -1683,10 +1805,15 @@ func hoverErrMarker(mark marker, src protocol.Location, em stringMatcher) { em.checkErr(mark, err) } -// locMarker implements the @loc marker. It is executed before other -// markers, so that locations are available. +// locMarker implements the @loc marker. func locMarker(mark marker, loc protocol.Location) protocol.Location { return loc } +// defLocMarker implements the @defloc marker, which binds a location to the +// (first) result of a jump-to-definition request. +func defLocMarker(mark marker, loc protocol.Location) protocol.Location { + return mark.run.env.GoToDefinition(loc) +} + // diagMarker implements the @diag marker. It eliminates diagnostics from // the observed set in mark.test. func diagMarker(mark marker, loc protocol.Location, re *regexp.Regexp) { @@ -1934,37 +2061,43 @@ func changedFiles(env *integration.Env, changes []protocol.DocumentChange) (map[ return result, nil } -func codeActionMarker(mark marker, start, end protocol.Location, actionKind string, g *Golden) { - // Request the range from start.Start to end.End. 
- loc := start - loc.Range.End = end.Range.End - - // Apply the fix it suggests. - changed, err := codeAction(mark.run.env, loc.URI, loc.Range, protocol.CodeActionKind(actionKind), nil) - if err != nil { - mark.errorf("codeAction failed: %v", err) +func codeActionMarker(mark marker, loc protocol.Location, kind string) { + if !exactlyOneNamedArg(mark, "edit", "result", "err") { return } - // Check the file state. - checkChangedFiles(mark, changed, g) -} + if end := namedArgFunc(mark, "end", convertNamedArgLocation, protocol.Location{}); end.URI != "" { + if end.URI != loc.URI { + panic("unreachable") + } + loc.Range.End = end.Range.End + } -func codeActionEditMarker(mark marker, loc protocol.Location, actionKind string, g *Golden) { - changed, err := codeAction(mark.run.env, loc.URI, loc.Range, protocol.CodeActionKind(actionKind), nil) - if err != nil { + var ( + edit = namedArg(mark, "edit", expect.Identifier("")) + result = namedArg(mark, "result", expect.Identifier("")) + wantErr = namedArgFunc(mark, "err", convertStringMatcher, stringMatcher{}) + ) + + changed, err := codeAction(mark.run.env, loc.URI, loc.Range, protocol.CodeActionKind(kind), nil) + if err != nil && wantErr.empty() { mark.errorf("codeAction failed: %v", err) return } - checkDiffs(mark, changed, g) -} - -func codeActionErrMarker(mark marker, start, end protocol.Location, actionKind string, wantErr stringMatcher) { - loc := start - loc.Range.End = end.Range.End - _, err := codeAction(mark.run.env, loc.URI, loc.Range, protocol.CodeActionKind(actionKind), nil) - wantErr.checkErr(mark, err) + switch { + case edit != "": + g := mark.getGolden(edit) + checkDiffs(mark, changed, g) + case result != "": + g := mark.getGolden(result) + // Check the file state. 
+ checkChangedFiles(mark, changed, g) + case !wantErr.empty(): + wantErr.checkErr(mark, err) + default: + panic("unreachable") + } } // codeLensesMarker runs the @codelenses() marker, collecting @codelens marks @@ -2015,7 +2148,7 @@ func documentLinkMarker(mark marker, g *Golden) { continue } loc := protocol.Location{URI: mark.uri(), Range: l.Range} - fmt.Fprintln(&b, mark.run.fmtLocDetails(loc, false), *l.Target) + fmt.Fprintln(&b, mark.run.fmtLocForGolden(loc), *l.Target) } compareGolden(mark, b.Bytes(), g) @@ -2118,10 +2251,12 @@ func codeActionChanges(env *integration.Env, uri protocol.DocumentURI, rng proto } } if len(candidates) != 1 { + var msg bytes.Buffer + fmt.Fprintf(&msg, "found %d CodeActions of kind %s for this diagnostic, want 1", len(candidates), kind) for _, act := range actions { - env.T.Logf("found CodeAction Kind=%s Title=%q", act.Kind, act.Title) + fmt.Fprintf(&msg, "\n\tfound %q (%s)", act.Title, act.Kind) } - return nil, fmt.Errorf("found %d CodeActions of kind %s for this diagnostic, want 1", len(candidates), kind) + return nil, errors.New(msg.String()) } action := candidates[0] @@ -2347,7 +2482,7 @@ func inlayhintsMarker(mark marker, g *Golden) { compareGolden(mark, got, g) } -func prepareRenameMarker(mark marker, src, spn protocol.Location, placeholder string) { +func prepareRenameMarker(mark marker, src protocol.Location, placeholder string) { params := &protocol.PrepareRenameParams{ TextDocumentPositionParams: protocol.LocationTextDocumentPositionParams(src), } @@ -2361,7 +2496,15 @@ func prepareRenameMarker(mark marker, src, spn protocol.Location, placeholder st } return } - want := &protocol.PrepareRenameResult{Range: spn.Range, Placeholder: placeholder} + + want := &protocol.PrepareRenameResult{ + Placeholder: placeholder, + } + if span := namedArg(mark, "span", protocol.Location{}); span != (protocol.Location{}) { + want.Range = span.Range + } else { + got.Range = protocol.Range{} // ignore Range + } if diff := cmp.Diff(want, got); 
diff != "" { mark.errorf("mismatching PrepareRename result:\n%s", diff) } @@ -2468,9 +2611,7 @@ func workspaceSymbolMarker(mark marker, query string, golden *Golden) { for _, s := range gotSymbols { // Omit the txtar position of the symbol location; otherwise edits to the // txtar archive lead to unexpected failures. - loc := mark.run.fmtLocDetails(s.Location, false) - // TODO(rfindley): can we do better here, by detecting if the location is - // relative to GOROOT? + loc := mark.run.fmtLocForGolden(s.Location) if loc == "" { loc = "" } diff --git a/gopls/internal/test/marker/testdata/codeaction/addtest.txt b/gopls/internal/test/marker/testdata/codeaction/addtest.txt index 5d669ec7d01..82c8ee1b2a6 100644 --- a/gopls/internal/test/marker/testdata/codeaction/addtest.txt +++ b/gopls/internal/test/marker/testdata/codeaction/addtest.txt @@ -8,415 +8,681 @@ module golang.org/lsptests/addtest go 1.18 --- settings.json -- -{ - "addTestSourceCodeAction": true -} - --- withcopyright/copyright.go -- +-- copyrightandbuildconstraint/copyrightandbuildconstraint.go -- // Copyright 2020 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build go1.23 +//go:build go1.18 // Package main is for lsp test. package main -func Foo(in string) string {return in} //@codeactionedit("Foo", "source.addTest", with_copyright) +func Foo(in string) string {return in} //@codeaction("Foo", "source.addTest", edit=with_copyright_build_constraint) --- @with_copyright/withcopyright/copyright_test.go -- -@@ -0,0 +1,24 @@ +-- @with_copyright_build_constraint/copyrightandbuildconstraint/copyrightandbuildconstraint_test.go -- +@@ -0,0 +1,32 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ ++//go:build go1.18 ++ +package main_test + ++import( ++ "golang.org/lsptests/addtest/copyrightandbuildconstraint" ++ "testing" ++) ++ +func TestFoo(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ arg string -+ want string -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got := main.Foo(tt.arg) -+ // TODO: update the condition below to compare got with tt.want. -+ if true { -+ t.Errorf("Foo() = %v, want %v", got, tt.want) -+ } -+ }) -+ } -+} --- withoutcopyright/copyright.go -- -//go:build go1.23 ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got := main.Foo(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("Foo() = %v, want %v", got, tt.want) ++ } ++ }) ++ } ++} +-- buildconstraint/buildconstraint.go -- +//go:build go1.18 // Package copyright is for lsp test. package copyright -func Foo(in string) string {return in} //@codeactionedit("Foo", "source.addTest", without_copyright) +func Foo(in string) string {return in} //@codeaction("Foo", "source.addTest", edit=with_build_constraint) --- @without_copyright/withoutcopyright/copyright_test.go -- -@@ -0,0 +1,20 @@ +-- @with_build_constraint/buildconstraint/buildconstraint_test.go -- +@@ -0,0 +1,28 @@ ++//go:build go1.18 ++ +package copyright_test + ++import( ++ "golang.org/lsptests/addtest/buildconstraint" ++ "testing" ++) ++ +func TestFoo(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ arg string -+ want string -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got := copyright.Foo(tt.arg) -+ // TODO: update the condition below to compare got with tt.want. 
-+ if true { -+ t.Errorf("Foo() = %v, want %v", got, tt.want) -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got := copyright.Foo(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("Foo() = %v, want %v", got, tt.want) ++ } ++ }) ++ } +} -- missingtestfile/missingtestfile.go -- package main -func ExportedFunction(in string) string {return in} //@codeactionedit("ExportedFunction", "source.addTest", missing_test_file_exported_function) - type Bar struct {} -func (*Bar) ExportedMethod(in string) string {return in} //@codeactionedit("ExportedMethod", "source.addTest", missing_test_file_exported_recv_exported_method) +type foo struct {} + +func ExportedFunction(in string) string {return in} //@codeaction("ExportedFunction", "source.addTest", edit=missing_test_file_exported_function) + +func UnexportedInputParam(in string, f foo) string {return in} //@codeaction("UnexportedInputParam", "source.addTest", edit=missing_test_file_function_unexported_input) + +func unexportedFunction(in string) string {return in} //@codeaction("unexportedFunction", "source.addTest", edit=missing_test_file_unexported_function) + +func (*Bar) ExportedMethod(in string) string {return in} //@codeaction("ExportedMethod", "source.addTest", edit=missing_test_file_exported_recv_exported_method) + +func (*Bar) UnexportedInputParam(in string, f foo) string {return in} //@codeaction("UnexportedInputParam", "source.addTest", edit=missing_test_file_method_unexported_input) + +func (*foo) ExportedMethod(in string) string {return in} //@codeaction("ExportedMethod", "source.addTest", edit=missing_test_file_unexported_recv) -- @missing_test_file_exported_function/missingtestfile/missingtestfile_test.go -- -@@ -0,0 +1,20 @@ +@@ -0,0 +1,26 
@@ +package main_test + ++import( ++ "golang.org/lsptests/addtest/missingtestfile" ++ "testing" ++) ++ +func TestExportedFunction(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ arg string -+ want string -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got := main.ExportedFunction(tt.arg) -+ // TODO: update the condition below to compare got with tt.want. -+ if true { -+ t.Errorf("ExportedFunction() = %v, want %v", got, tt.want) -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got := main.ExportedFunction(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("ExportedFunction() = %v, want %v", got, tt.want) ++ } ++ }) ++ } +} -- @missing_test_file_exported_recv_exported_method/missingtestfile/missingtestfile_test.go -- -@@ -0,0 +1,20 @@ +@@ -0,0 +1,28 @@ +package main_test + ++import( ++ "golang.org/lsptests/addtest/missingtestfile" ++ "testing" ++) ++ +func TestBar_ExportedMethod(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ arg string -+ want string -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got := ExportedMethod(tt.arg) -+ // TODO: update the condition below to compare got with tt.want. -+ if true { -+ t.Errorf("ExportedMethod() = %v, want %v", got, tt.want) -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ // TODO: construct the receiver type. 
++ var b main.Bar ++ got := b.ExportedMethod(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("ExportedMethod() = %v, want %v", got, tt.want) ++ } ++ }) ++ } ++} +-- @missing_test_file_function_unexported_input/missingtestfile/missingtestfile_test.go -- +@@ -0,0 +1,24 @@ ++package main ++ ++import "testing" ++ ++func TestUnexportedInputParam(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ f foo ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got := UnexportedInputParam(tt.in, tt.f) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("UnexportedInputParam() = %v, want %v", got, tt.want) ++ } ++ }) ++ } ++} +-- @missing_test_file_method_unexported_input/missingtestfile/missingtestfile_test.go -- +@@ -0,0 +1,26 @@ ++package main ++ ++import "testing" ++ ++func TestBar_UnexportedInputParam(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ f foo ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ // TODO: construct the receiver type. ++ var b Bar ++ got := b.UnexportedInputParam(tt.in, tt.f) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("UnexportedInputParam() = %v, want %v", got, tt.want) ++ } ++ }) ++ } ++} +-- @missing_test_file_unexported_function/missingtestfile/missingtestfile_test.go -- +@@ -0,0 +1,23 @@ ++package main ++ ++import "testing" ++ ++func Test_unexportedFunction(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. 
++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got := unexportedFunction(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("unexportedFunction() = %v, want %v", got, tt.want) ++ } ++ }) ++ } ++} +-- @missing_test_file_unexported_recv/missingtestfile/missingtestfile_test.go -- +@@ -0,0 +1,25 @@ ++package main ++ ++import "testing" ++ ++func Test_foo_ExportedMethod(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ // TODO: construct the receiver type. ++ var f foo ++ got := f.ExportedMethod(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("ExportedMethod() = %v, want %v", got, tt.want) ++ } ++ }) ++ } +} -- xpackagetestfile/xpackagetestfile.go -- package main -func ExportedFunction(in string) string {return in} //@codeactionedit("ExportedFunction", "source.addTest", xpackage_exported_function) -func unexportedFunction(in string) string {return in} //@codeactionedit("unexportedFunction", "source.addTest", xpackage_unexported_function) +func ExportedFunction(in string) string {return in} //@codeaction("ExportedFunction", "source.addTest", edit=xpackage_exported_function) +func unexportedFunction(in string) string {return in} //@codeaction("unexportedFunction", "source.addTest", edit=xpackage_unexported_function) type Bar struct {} -func (*Bar) ExportedMethod(in string) string {return in} //@codeactionedit("ExportedMethod", "source.addTest", xpackage_exported_recv_exported_method) -func (*Bar) unexportedMethod(in string) string {return in} //@codeactionedit("unexportedMethod", "source.addTest", xpackage_exported_recv_unexported_method) +func (*Bar) ExportedMethod(in string) string {return in} 
//@codeaction("ExportedMethod", "source.addTest", edit=xpackage_exported_recv_exported_method) +func (*Bar) unexportedMethod(in string) string {return in} //@codeaction("unexportedMethod", "source.addTest", edit=xpackage_exported_recv_unexported_method) type foo struct {} -func (*foo) ExportedMethod(in string) string {return in} //@codeactionedit("ExportedMethod", "source.addTest", xpackage_unexported_recv_exported_method) -func (*foo) unexportedMethod(in string) string {return in} //@codeactionedit("unexportedMethod", "source.addTest", xpackage_unexported_recv_unexported_method) +func (*foo) ExportedMethod(in string) string {return in} //@codeaction("ExportedMethod", "source.addTest", edit=xpackage_unexported_recv_exported_method) +func (*foo) unexportedMethod(in string) string {return in} //@codeaction("unexportedMethod", "source.addTest", edit=xpackage_unexported_recv_unexported_method) -- xpackagetestfile/xpackagetestfile_test.go -- package main -- @xpackage_exported_function/xpackagetestfile/xpackagetestfile_test.go -- -@@ -3 +3,18 @@ +@@ -3 +3,22 @@ ++import "testing" ++ ++ +func TestExportedFunction(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ arg string -+ want string -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got := ExportedFunction(tt.arg) -+ // TODO: update the condition below to compare got with tt.want. -+ if true { -+ t.Errorf("ExportedFunction() = %v, want %v", got, tt.want) -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got := ExportedFunction(tt.in) ++ // TODO: update the condition below to compare got with tt.want. 
++ if true { ++ t.Errorf("ExportedFunction() = %v, want %v", got, tt.want) ++ } ++ }) ++ } +} -- @xpackage_unexported_function/xpackagetestfile/xpackagetestfile_test.go -- -@@ -3 +3,18 @@ +@@ -3 +3,22 @@ ++import "testing" ++ ++ +func Test_unexportedFunction(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ arg string -+ want string -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got := unexportedFunction(tt.arg) -+ // TODO: update the condition below to compare got with tt.want. -+ if true { -+ t.Errorf("unexportedFunction() = %v, want %v", got, tt.want) -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got := unexportedFunction(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("unexportedFunction() = %v, want %v", got, tt.want) ++ } ++ }) ++ } +} -- @xpackage_exported_recv_exported_method/xpackagetestfile/xpackagetestfile_test.go -- -@@ -3 +3,18 @@ +@@ -3 +3,24 @@ ++import "testing" ++ ++ +func TestBar_ExportedMethod(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ arg string -+ want string -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got := ExportedMethod(tt.arg) -+ // TODO: update the condition below to compare got with tt.want. -+ if true { -+ t.Errorf("ExportedMethod() = %v, want %v", got, tt.want) -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. 
++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ // TODO: construct the receiver type. ++ var b Bar ++ got := b.ExportedMethod(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("ExportedMethod() = %v, want %v", got, tt.want) ++ } ++ }) ++ } +} -- @xpackage_exported_recv_unexported_method/xpackagetestfile/xpackagetestfile_test.go -- -@@ -3 +3,18 @@ +@@ -3 +3,24 @@ ++import "testing" ++ ++ +func TestBar_unexportedMethod(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ arg string -+ want string -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got := unexportedMethod(tt.arg) -+ // TODO: update the condition below to compare got with tt.want. -+ if true { -+ t.Errorf("unexportedMethod() = %v, want %v", got, tt.want) -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ // TODO: construct the receiver type. ++ var b Bar ++ got := b.unexportedMethod(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("unexportedMethod() = %v, want %v", got, tt.want) ++ } ++ }) ++ } +} -- @xpackage_unexported_recv_exported_method/xpackagetestfile/xpackagetestfile_test.go -- -@@ -3 +3,18 @@ +@@ -3 +3,24 @@ ++import "testing" ++ ++ +func Test_foo_ExportedMethod(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ arg string -+ want string -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got := ExportedMethod(tt.arg) -+ // TODO: update the condition below to compare got with tt.want. 
-+ if true { -+ t.Errorf("ExportedMethod() = %v, want %v", got, tt.want) -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ // TODO: construct the receiver type. ++ var f foo ++ got := f.ExportedMethod(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("ExportedMethod() = %v, want %v", got, tt.want) ++ } ++ }) ++ } +} -- @xpackage_unexported_recv_unexported_method/xpackagetestfile/xpackagetestfile_test.go -- -@@ -3 +3,18 @@ +@@ -3 +3,24 @@ ++import "testing" ++ ++ +func Test_foo_unexportedMethod(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ arg string -+ want string -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got := unexportedMethod(tt.arg) -+ // TODO: update the condition below to compare got with tt.want. -+ if true { -+ t.Errorf("unexportedMethod() = %v, want %v", got, tt.want) -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ // TODO: construct the receiver type. ++ var f foo ++ got := f.unexportedMethod(tt.in) ++ // TODO: update the condition below to compare got with tt.want. 
++ if true { ++ t.Errorf("unexportedMethod() = %v, want %v", got, tt.want) ++ } ++ }) ++ } +} -- aliasreceiver/aliasreceiver.go -- package main -type bar struct {} -type middle1 = bar -type middle2 = middle1 -type middle3 = middle2 -type Bar = middle3 +type bar0 struct {} +type bar1 = bar0 +type Bar = bar1 + +func (*Bar) ExportedMethod(in string) string {return in} //@codeaction("ExportedMethod", "source.addTest", edit=pointer_receiver_exported_method) +func (*Bar) unexportedMethod(in string) string {return in} //@codeaction("unexportedMethod", "source.addTest", edit=pointer_receiver_unexported_method) + +type foo0 struct {} +type foo1 = foo0 +type foo = foo1 + +func (foo) ExportedMethod(in string) string {return in} //@codeaction("ExportedMethod", "source.addTest", edit=alias_receiver_exported_method) +func (foo) unexportedMethod(in string) string {return in} //@codeaction("unexportedMethod", "source.addTest", edit=alias_receiver_unexported_method) + +type baz0 struct{} +type baz1 = baz0 +type baz = baz1 + +func newBaz0() baz0 {return baz0{}} -func (*Bar) ExportedMethod(in string) string {return in} //@codeactionedit("ExportedMethod", "source.addTest", pointer_receiver_exported_method) -func (*Bar) unexportedMethod(in string) string {return in} //@codeactionedit("unexportedMethod", "source.addTest", pointer_receiver_unexported_method) +func (baz) method(in string) string {return in} //@codeaction("method", "source.addTest", edit=alias_constructor_on_underlying_type) -type bar2 struct {} -type middle4 = bar2 -type middle5 = middle4 -type middle6 = middle5 -type foo = *middle6 +type qux0 struct{} +type qux1 = qux0 +type qux2 = qux1 +type Qux = *qux2 -func (foo) ExportedMethod(in string) string {return in} //@codeactionedit("ExportedMethod", "source.addTest", alias_receiver_exported_method) -func (foo) unexportedMethod(in string) string {return in} //@codeactionedit("unexportedMethod", "source.addTest", alias_receiver_unexported_method) +func newQux1() (qux1, error) 
{return qux1{}, nil} + +func (Qux) method(in string) string {return in} //@codeaction("method", "source.addTest", edit=alias_constructor_on_different_alias_type) -- aliasreceiver/aliasreceiver_test.go -- package main -- @pointer_receiver_exported_method/aliasreceiver/aliasreceiver_test.go -- -@@ -3 +3,18 @@ +@@ -3 +3,24 @@ ++import "testing" ++ ++ +func TestBar_ExportedMethod(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ arg string -+ want string -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got := ExportedMethod(tt.arg) -+ // TODO: update the condition below to compare got with tt.want. -+ if true { -+ t.Errorf("ExportedMethod() = %v, want %v", got, tt.want) -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ // TODO: construct the receiver type. ++ var b Bar ++ got := b.ExportedMethod(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("ExportedMethod() = %v, want %v", got, tt.want) ++ } ++ }) ++ } +} -- @pointer_receiver_unexported_method/aliasreceiver/aliasreceiver_test.go -- -@@ -3 +3,18 @@ +@@ -3 +3,24 @@ ++import "testing" ++ ++ +func TestBar_unexportedMethod(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ arg string -+ want string -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got := unexportedMethod(tt.arg) -+ // TODO: update the condition below to compare got with tt.want. 
-+ if true { -+ t.Errorf("unexportedMethod() = %v, want %v", got, tt.want) -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ // TODO: construct the receiver type. ++ var b Bar ++ got := b.unexportedMethod(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("unexportedMethod() = %v, want %v", got, tt.want) ++ } ++ }) ++ } +} -- @alias_receiver_exported_method/aliasreceiver/aliasreceiver_test.go -- -@@ -3 +3,18 @@ +@@ -3 +3,24 @@ ++import "testing" ++ ++ +func Test_foo_ExportedMethod(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ arg string -+ want string -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got := ExportedMethod(tt.arg) -+ // TODO: update the condition below to compare got with tt.want. -+ if true { -+ t.Errorf("ExportedMethod() = %v, want %v", got, tt.want) -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ // TODO: construct the receiver type. ++ var f foo ++ got := f.ExportedMethod(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("ExportedMethod() = %v, want %v", got, tt.want) ++ } ++ }) ++ } +} -- @alias_receiver_unexported_method/aliasreceiver/aliasreceiver_test.go -- -@@ -3 +3,18 @@ +@@ -3 +3,24 @@ ++import "testing" ++ ++ +func Test_foo_unexportedMethod(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ arg string -+ want string -+ }{ -+ // TODO: Add test cases. 
-+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got := unexportedMethod(tt.arg) -+ // TODO: update the condition below to compare got with tt.want. -+ if true { -+ t.Errorf("unexportedMethod() = %v, want %v", got, tt.want) -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ // TODO: construct the receiver type. ++ var f foo ++ got := f.unexportedMethod(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("unexportedMethod() = %v, want %v", got, tt.want) ++ } ++ }) ++ } ++} +-- @alias_constructor_on_underlying_type/aliasreceiver/aliasreceiver_test.go -- +@@ -3 +3,23 @@ ++import "testing" ++ ++ ++func Test_baz_method(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ b := newBaz0() ++ got := b.method(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("method() = %v, want %v", got, tt.want) ++ } ++ }) ++ } ++} +-- @alias_constructor_on_different_alias_type/aliasreceiver/aliasreceiver_test.go -- +@@ -3 +3,26 @@ ++import "testing" ++ ++ ++func TestQux_method(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. 
++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ q, err := newQux1() ++ if err != nil { ++ t.Fatalf("could not construct receiver type: %v", err) ++ } ++ got := q.method(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("method() = %v, want %v", got, tt.want) ++ } ++ }) ++ } +} -- multiinputoutput/multiinputoutput.go -- package main -func Foo(in, in1, in2, in3 string) (out, out1, out2 string) {return in, in, in} //@codeactionedit("Foo", "source.addTest", multi_input_output) +func Foo(in, in2, in3, in4 string) (out, out1, out2 string) {return "", "", ""} //@codeaction("Foo", "source.addTest", edit=multi_input_output) -- @multi_input_output/multiinputoutput/multiinputoutput_test.go -- -@@ -0,0 +1,34 @@ +@@ -0,0 +1,37 @@ +package main_test + ++import( ++ "golang.org/lsptests/addtest/multiinputoutput" ++ "testing" ++) ++ +func TestFoo(t *testing.T) { -+ type args struct { -+ in string -+ in2 string -+ in3 string -+ in4 string -+ } -+ tests := []struct { -+ name string // description of this test case -+ args args -+ want string -+ want2 string -+ want3 string -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got, got2, got3 := main.Foo(tt.args.in, tt.args.in2, tt.args.in3, tt.args.in4) -+ // TODO: update the condition below to compare got with tt.want. -+ if true { -+ t.Errorf("Foo() = %v, want %v", got, tt.want) -+ } -+ if true { -+ t.Errorf("Foo() = %v, want %v", got2, tt.want2) -+ } -+ if true { -+ t.Errorf("Foo() = %v, want %v", got3, tt.want3) -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ in2 string ++ in3 string ++ in4 string ++ want string ++ want2 string ++ want3 string ++ }{ ++ // TODO: Add test cases. 
++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got, got2, got3 := main.Foo(tt.in, tt.in2, tt.in3, tt.in4) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("Foo() = %v, want %v", got, tt.want) ++ } ++ if true { ++ t.Errorf("Foo() = %v, want %v", got2, tt.want2) ++ } ++ if true { ++ t.Errorf("Foo() = %v, want %v", got3, tt.want3) ++ } ++ }) ++ } +} -- xpackagerename/xpackagerename.go -- package main @@ -424,39 +690,47 @@ package main import ( mytime "time" myast "go/ast" + mytest "testing" ) -func Foo(t mytime.Time, a *myast.Node) (mytime.Time, *myast.Node) {return t, a} //@codeactionedit("Foo", "source.addTest", xpackage_rename) +var local mytest.T + +func Foo(t mytime.Time, a *myast.Node) (mytime.Time, *myast.Node) {return t, a} //@codeaction("Foo", "source.addTest", edit=xpackage_rename) -- @xpackage_rename/xpackagerename/xpackagerename_test.go -- -@@ -0,0 +1,28 @@ +@@ -0,0 +1,33 @@ +package main_test + -+func TestFoo(t *testing.T) { -+ type args struct { -+ in mytime.Time -+ in2 *myast.Node -+ } -+ tests := []struct { -+ name string // description of this test case -+ args args -+ want mytime.Time -+ want2 *myast.Node -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got, got2 := main.Foo(tt.args.in, tt.args.in2) -+ // TODO: update the condition below to compare got with tt.want. -+ if true { -+ t.Errorf("Foo() = %v, want %v", got, tt.want) -+ } -+ if true { -+ t.Errorf("Foo() = %v, want %v", got2, tt.want2) -+ } -+ }) -+ } ++import( ++ myast "go/ast" ++ "golang.org/lsptests/addtest/xpackagerename" ++ mytest "testing" ++ mytime "time" ++) ++ ++func TestFoo(t *mytest.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ t mytime.Time ++ a *myast.Node ++ want mytime.Time ++ want2 *myast.Node ++ }{ ++ // TODO: Add test cases. 
++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *mytest.T) { ++ got, got2 := main.Foo(tt.t, tt.a) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("Foo() = %v, want %v", got, tt.want) ++ } ++ if true { ++ t.Errorf("Foo() = %v, want %v", got2, tt.want2) ++ } ++ }) ++ } +} -- xtestpackagerename/xtestpackagerename.go -- package main @@ -464,149 +738,805 @@ package main import ( mytime "time" myast "go/ast" + mytest "testing" ) -func Foo(t mytime.Time, a *myast.Node) (mytime.Time, *myast.Node) {return t, a} //@codeactionedit("Foo", "source.addTest", xtest_package_rename) +var local mytest.T + +func Foo(t mytime.Time, a *myast.Node) (mytime.Time, *myast.Node) {return t, a} //@codeaction("Foo", "source.addTest", edit=xtest_package_rename) -- xtestpackagerename/xtestpackagerename_test.go -- package main_test import ( - yourtime "time" - yourast "go/ast" + yourast "go/ast" + yourtest "testing" + yourtime "time" ) var fooTime = yourtime.Time{} var fooNode = yourast.Node{} +var fooT yourtest.T -- @xtest_package_rename/xtestpackagerename/xtestpackagerename_test.go -- -@@ -11 +11,26 @@ -+func TestFoo(t *testing.T) { -+ type args struct { -+ in yourtime.Time -+ in2 *yourast.Node -+ } -+ tests := []struct { -+ name string // description of this test case -+ args args -+ want yourtime.Time -+ want2 *yourast.Node -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got, got2 := main.Foo(tt.args.in, tt.args.in2) -+ // TODO: update the condition below to compare got with tt.want. -+ if true { -+ t.Errorf("Foo() = %v, want %v", got, tt.want) -+ } -+ if true { -+ t.Errorf("Foo() = %v, want %v", got2, tt.want2) -+ } -+ }) -+ } +@@ -7 +7,2 @@ ++ ++ "golang.org/lsptests/addtest/xtestpackagerename" +@@ -13 +15,25 @@ ++ ++func TestFoo(t *yourtest.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. 
++ t yourtime.Time ++ a *yourast.Node ++ want yourtime.Time ++ want2 *yourast.Node ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *yourtest.T) { ++ got, got2 := main.Foo(tt.t, tt.a) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("Foo() = %v, want %v", got, tt.want) ++ } ++ if true { ++ t.Errorf("Foo() = %v, want %v", got2, tt.want2) ++ } ++ }) ++ } +} -- returnwitherror/returnwitherror.go -- package main -func OnlyErr() error {return nil} //@codeactionedit("OnlyErr", "source.addTest", return_only_error) -func StringErr() (string, error) {return "", nil} //@codeactionedit("StringErr", "source.addTest", return_string_error) -func MultipleStringErr() (string, string, string, error) {return "", "", "", nil} //@codeactionedit("MultipleStringErr", "source.addTest", return_multiple_string_error) +func OnlyErr() error {return nil} //@codeaction("OnlyErr", "source.addTest", edit=return_only_error) +func StringErr() (string, error) {return "", nil} //@codeaction("StringErr", "source.addTest", edit=return_string_error) +func MultipleStringErr() (string, string, string, error) {return "", "", "", nil} //@codeaction("MultipleStringErr", "source.addTest", edit=return_multiple_string_error) -- @return_only_error/returnwitherror/returnwitherror_test.go -- -@@ -0,0 +1,24 @@ +@@ -0,0 +1,29 @@ +package main_test + ++import( ++ "golang.org/lsptests/addtest/returnwitherror" ++ "testing" ++) ++ +func TestOnlyErr(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ wantErr bool -+ }{ -+ // TODO: Add test cases. 
-+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ gotErr := main.OnlyErr() -+ if gotErr != nil { -+ if !tt.wantErr { -+ t.Errorf("OnlyErr() failed: %v", gotErr) -+ } -+ return -+ } -+ if tt.wantErr { -+ t.Fatal("OnlyErr() succeeded unexpectedly") -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ wantErr bool ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ gotErr := main.OnlyErr() ++ if gotErr != nil { ++ if !tt.wantErr { ++ t.Errorf("OnlyErr() failed: %v", gotErr) ++ } ++ return ++ } ++ if tt.wantErr { ++ t.Fatal("OnlyErr() succeeded unexpectedly") ++ } ++ }) ++ } +} -- @return_string_error/returnwitherror/returnwitherror_test.go -- -@@ -0,0 +1,29 @@ +@@ -0,0 +1,34 @@ +package main_test + ++import( ++ "golang.org/lsptests/addtest/returnwitherror" ++ "testing" ++) ++ +func TestStringErr(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ want string -+ wantErr bool -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got, gotErr := main.StringErr() -+ if gotErr != nil { -+ if !tt.wantErr { -+ t.Errorf("StringErr() failed: %v", gotErr) -+ } -+ return -+ } -+ if tt.wantErr { -+ t.Fatal("StringErr() succeeded unexpectedly") -+ } -+ // TODO: update the condition below to compare got with tt.want. -+ if true { -+ t.Errorf("StringErr() = %v, want %v", got, tt.want) -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ want string ++ wantErr bool ++ }{ ++ // TODO: Add test cases. 
++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got, gotErr := main.StringErr() ++ if gotErr != nil { ++ if !tt.wantErr { ++ t.Errorf("StringErr() failed: %v", gotErr) ++ } ++ return ++ } ++ if tt.wantErr { ++ t.Fatal("StringErr() succeeded unexpectedly") ++ } ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("StringErr() = %v, want %v", got, tt.want) ++ } ++ }) ++ } +} -- @return_multiple_string_error/returnwitherror/returnwitherror_test.go -- -@@ -0,0 +1,37 @@ +@@ -0,0 +1,42 @@ +package main_test + ++import( ++ "golang.org/lsptests/addtest/returnwitherror" ++ "testing" ++) ++ +func TestMultipleStringErr(t *testing.T) { -+ tests := []struct { -+ name string // description of this test case -+ want string -+ want2 string -+ want3 string -+ wantErr bool -+ }{ -+ // TODO: Add test cases. -+ } -+ for _, tt := range tests { -+ t.Run(tt.name, func(t *testing.T) { -+ got, got2, got3, gotErr := main.MultipleStringErr() -+ if gotErr != nil { -+ if !tt.wantErr { -+ t.Errorf("MultipleStringErr() failed: %v", gotErr) -+ } -+ return -+ } -+ if tt.wantErr { -+ t.Fatal("MultipleStringErr() succeeded unexpectedly") -+ } -+ // TODO: update the condition below to compare got with tt.want. -+ if true { -+ t.Errorf("MultipleStringErr() = %v, want %v", got, tt.want) -+ } -+ if true { -+ t.Errorf("MultipleStringErr() = %v, want %v", got2, tt.want2) -+ } -+ if true { -+ t.Errorf("MultipleStringErr() = %v, want %v", got3, tt.want3) -+ } -+ }) -+ } ++ tests := []struct { ++ name string // description of this test case ++ want string ++ want2 string ++ want3 string ++ wantErr bool ++ }{ ++ // TODO: Add test cases. 
++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got, got2, got3, gotErr := main.MultipleStringErr() ++ if gotErr != nil { ++ if !tt.wantErr { ++ t.Errorf("MultipleStringErr() failed: %v", gotErr) ++ } ++ return ++ } ++ if tt.wantErr { ++ t.Fatal("MultipleStringErr() succeeded unexpectedly") ++ } ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("MultipleStringErr() = %v, want %v", got, tt.want) ++ } ++ if true { ++ t.Errorf("MultipleStringErr() = %v, want %v", got2, tt.want2) ++ } ++ if true { ++ t.Errorf("MultipleStringErr() = %v, want %v", got3, tt.want3) ++ } ++ }) ++ } +} +-- constructor/constructor.go -- +package main + +// Constructor returns the type T. +func NewReturnType() ReturnType {return ReturnType{}} + +type ReturnType struct {} + +func (*ReturnType) Method(in string) string {return in} //@codeaction("Method", "source.addTest", edit=constructor_return_type) + +// Constructor returns the type T and an error. +func NewReturnTypeError() (ReturnTypeError, error) {return ReturnTypeError{}, nil} + +type ReturnTypeError struct {} + +func (*ReturnTypeError) Method(in string) string {return in} //@codeaction("Method", "source.addTest", edit=constructor_return_type_error) + +// Constructor returns the type *T. +func NewReturnPtr() *ReturnPtr {return nil} + +type ReturnPtr struct {} + +func (*ReturnPtr) Method(in string) string {return in} //@codeaction("Method", "source.addTest", edit=constructor_return_ptr) + +// Constructor returns the type *T and an error. 
+func NewReturnPtrError() (*ReturnPtrError, error) {return nil, nil} + +type ReturnPtrError struct {} + +func (*ReturnPtrError) Method(in string) string {return in} //@codeaction("Method", "source.addTest", edit=constructor_return_ptr_error) + +-- @constructor_return_type/constructor/constructor_test.go -- +@@ -0,0 +1,27 @@ ++package main_test ++ ++import( ++ "golang.org/lsptests/addtest/constructor" ++ "testing" ++) ++ ++func TestReturnType_Method(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ r := main.NewReturnType() ++ got := r.Method(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("Method() = %v, want %v", got, tt.want) ++ } ++ }) ++ } ++} +-- @constructor_return_type_error/constructor/constructor_test.go -- +@@ -0,0 +1,30 @@ ++package main_test ++ ++import( ++ "golang.org/lsptests/addtest/constructor" ++ "testing" ++) ++ ++func TestReturnTypeError_Method(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ r, err := main.NewReturnTypeError() ++ if err != nil { ++ t.Fatalf("could not construct receiver type: %v", err) ++ } ++ got := r.Method(tt.in) ++ // TODO: update the condition below to compare got with tt.want. 
++ if true { ++ t.Errorf("Method() = %v, want %v", got, tt.want) ++ } ++ }) ++ } ++} +-- @constructor_return_ptr/constructor/constructor_test.go -- +@@ -0,0 +1,27 @@ ++package main_test ++ ++import( ++ "golang.org/lsptests/addtest/constructor" ++ "testing" ++) ++ ++func TestReturnPtr_Method(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ r := main.NewReturnPtr() ++ got := r.Method(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("Method() = %v, want %v", got, tt.want) ++ } ++ }) ++ } ++} +-- @constructor_return_ptr_error/constructor/constructor_test.go -- +@@ -0,0 +1,30 @@ ++package main_test ++ ++import( ++ "golang.org/lsptests/addtest/constructor" ++ "testing" ++) ++ ++func TestReturnPtrError_Method(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ r, err := main.NewReturnPtrError() ++ if err != nil { ++ t.Fatalf("could not construct receiver type: %v", err) ++ } ++ got := r.Method(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("Method() = %v, want %v", got, tt.want) ++ } ++ }) ++ } ++} +-- constructorcomparison/constructorcomparison.go -- +package main + +// Foo have two constructors. NewFoo is prefered over others. +func CreateAFoo() Foo {return Foo{}} +func NewFoo() Foo {return Foo{}} + +type Foo struct{} + +func (*Foo) Method(in string) string {return in} //@codeaction("Method", "source.addTest", edit=constructor_comparison_new) + +// Bar have two constructors. Bar is preferred due to alphabetical ordering. 
+func ABar() (Bar, error) {return Bar{}, nil} +// func CreateABar() Bar {return Bar{}} + +type Bar struct{} + +func (*Bar) Method(in string) string {return in} //@codeaction("Method", "source.addTest", edit=constructor_comparison_alphabetical) + +-- @constructor_comparison_new/constructorcomparison/constructorcomparison_test.go -- +@@ -0,0 +1,27 @@ ++package main_test ++ ++import( ++ "golang.org/lsptests/addtest/constructorcomparison" ++ "testing" ++) ++ ++func TestFoo_Method(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ f := main.NewFoo() ++ got := f.Method(tt.in) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("Method() = %v, want %v", got, tt.want) ++ } ++ }) ++ } ++} +-- @constructor_comparison_alphabetical/constructorcomparison/constructorcomparison_test.go -- +@@ -0,0 +1,30 @@ ++package main_test ++ ++import( ++ "golang.org/lsptests/addtest/constructorcomparison" ++ "testing" ++) ++ ++func TestBar_Method(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ in string ++ want string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ b, err := main.ABar() ++ if err != nil { ++ t.Fatalf("could not construct receiver type: %v", err) ++ } ++ got := b.Method(tt.in) ++ // TODO: update the condition below to compare got with tt.want. 
++ if true { ++ t.Errorf("Method() = %v, want %v", got, tt.want) ++ } ++ }) ++ } ++} +-- unnamedparam/unnamedparam.go -- +package main + +import "time" + +func FooInputBasic(one, two, _ string, _ int) (out, out1, out2 string) {return "", "", ""} //@codeaction("Foo", "source.addTest", edit=function_basic_type) + +func FooInputStruct(one string, _ time.Time) (out, out1, out2 string) {return "", "", ""} //@codeaction("Foo", "source.addTest", edit=function_struct_type) + +func FooInputPtr(one string, _ *time.Time) (out, out1, out2 string) {return "", "", ""} //@codeaction("Foo", "source.addTest", edit=function_ptr_type) + +func FooInputFunc(one string, _ func(time.Time) *time.Time) (out, out1, out2 string) {return "", "", ""} //@codeaction("Foo", "source.addTest", edit=function_func_type) + +type BarInputBasic struct{} + +func NewBarInputBasic(one, two, _ string, _ int) *BarInputBasic {return nil} + +func (r *BarInputBasic) Method(one, two, _ string, _ int) {} //@codeaction("Method", "source.addTest", edit=constructor_basic_type) + +type BarInputStruct struct{} + +func NewBarInputStruct(one string, _ time.Time) *BarInputStruct {return nil} + +func (r *BarInputStruct) Method(one string, _ time.Time) {} //@codeaction("Method", "source.addTest", edit=constructor_struct_type) + +type BarInputPtr struct{} + +func NewBarInputPtr(one string, _ *time.Time) *BarInputPtr {return nil} + +func (r *BarInputPtr) Method(one string, _ *time.Time) {} //@codeaction("Method", "source.addTest", edit=constructor_ptr_type) + +type BarInputFunction struct{} + +func NewBarInputFunction(one string, _ func(time.Time) *time.Time) *BarInputFunction {return nil} + +func (r *BarInputFunction) Method(one string, _ func(time.Time) *time.Time) {} //@codeaction("Method", "source.addTest", edit=constructor_func_type) + +-- @function_basic_type/unnamedparam/unnamedparam_test.go -- +@@ -0,0 +1,35 @@ ++package main_test ++ ++import( ++ "golang.org/lsptests/addtest/unnamedparam" ++ "testing" ++) ++ ++func 
TestFooInputBasic(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ one string ++ two string ++ want string ++ want2 string ++ want3 string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got, got2, got3 := main.FooInputBasic(tt.one, tt.two, "", 0) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("FooInputBasic() = %v, want %v", got, tt.want) ++ } ++ if true { ++ t.Errorf("FooInputBasic() = %v, want %v", got2, tt.want2) ++ } ++ if true { ++ t.Errorf("FooInputBasic() = %v, want %v", got3, tt.want3) ++ } ++ }) ++ } ++} +-- @function_func_type/unnamedparam/unnamedparam_test.go -- +@@ -0,0 +1,35 @@ ++package main_test ++ ++import( ++ "golang.org/lsptests/addtest/unnamedparam" ++ "testing" ++ "time" ++) ++ ++func TestFooInputFunc(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ one string ++ want string ++ want2 string ++ want3 string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got, got2, got3 := main.FooInputFunc(tt.one, nil) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("FooInputFunc() = %v, want %v", got, tt.want) ++ } ++ if true { ++ t.Errorf("FooInputFunc() = %v, want %v", got2, tt.want2) ++ } ++ if true { ++ t.Errorf("FooInputFunc() = %v, want %v", got3, tt.want3) ++ } ++ }) ++ } ++} +-- @function_ptr_type/unnamedparam/unnamedparam_test.go -- +@@ -0,0 +1,35 @@ ++package main_test ++ ++import( ++ "golang.org/lsptests/addtest/unnamedparam" ++ "testing" ++ "time" ++) ++ ++func TestFooInputPtr(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. 
++ one string ++ want string ++ want2 string ++ want3 string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got, got2, got3 := main.FooInputPtr(tt.one, nil) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("FooInputPtr() = %v, want %v", got, tt.want) ++ } ++ if true { ++ t.Errorf("FooInputPtr() = %v, want %v", got2, tt.want2) ++ } ++ if true { ++ t.Errorf("FooInputPtr() = %v, want %v", got3, tt.want3) ++ } ++ }) ++ } ++} +-- @function_struct_type/unnamedparam/unnamedparam_test.go -- +@@ -0,0 +1,35 @@ ++package main_test ++ ++import( ++ "golang.org/lsptests/addtest/unnamedparam" ++ "testing" ++ "time" ++) ++ ++func TestFooInputStruct(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for target function. ++ one string ++ want string ++ want2 string ++ want3 string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got, got2, got3 := main.FooInputStruct(tt.one, time.Time{}) ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("FooInputStruct() = %v, want %v", got, tt.want) ++ } ++ if true { ++ t.Errorf("FooInputStruct() = %v, want %v", got2, tt.want2) ++ } ++ if true { ++ t.Errorf("FooInputStruct() = %v, want %v", got3, tt.want3) ++ } ++ }) ++ } ++} +-- @constructor_basic_type/unnamedparam/unnamedparam_test.go -- +@@ -0,0 +1,26 @@ ++package main_test ++ ++import( ++ "golang.org/lsptests/addtest/unnamedparam" ++ "testing" ++) ++ ++func TestBarInputBasic_Method(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for receiver constructor. ++ cone string ++ ctwo string ++ // Named input parameters for target function. ++ one string ++ two string ++ }{ ++ // TODO: Add test cases. 
++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ r := main.NewBarInputBasic(tt.cone, tt.ctwo, "", 0) ++ r.Method(tt.one, tt.two, "", 0) ++ }) ++ } ++} +-- @constructor_func_type/unnamedparam/unnamedparam_test.go -- +@@ -0,0 +1,25 @@ ++package main_test ++ ++import( ++ "golang.org/lsptests/addtest/unnamedparam" ++ "testing" ++ "time" ++) ++ ++func TestBarInputFunction_Method(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for receiver constructor. ++ cone string ++ // Named input parameters for target function. ++ one string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ r := main.NewBarInputFunction(tt.cone, nil) ++ r.Method(tt.one, nil) ++ }) ++ } ++} +-- @constructor_ptr_type/unnamedparam/unnamedparam_test.go -- +@@ -0,0 +1,25 @@ ++package main_test ++ ++import( ++ "golang.org/lsptests/addtest/unnamedparam" ++ "testing" ++ "time" ++) ++ ++func TestBarInputPtr_Method(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for receiver constructor. ++ cone string ++ // Named input parameters for target function. ++ one string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ r := main.NewBarInputPtr(tt.cone, nil) ++ r.Method(tt.one, nil) ++ }) ++ } ++} +-- @constructor_struct_type/unnamedparam/unnamedparam_test.go -- +@@ -0,0 +1,25 @@ ++package main_test ++ ++import( ++ "golang.org/lsptests/addtest/unnamedparam" ++ "testing" ++ "time" ++) ++ ++func TestBarInputStruct_Method(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ // Named input parameters for receiver constructor. ++ cone string ++ // Named input parameters for target function. ++ one string ++ }{ ++ // TODO: Add test cases. 
++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ r := main.NewBarInputStruct(tt.cone, time.Time{}) ++ r.Method(tt.one, time.Time{}) ++ }) ++ } ++} +-- contextinput/contextinput.go -- +package main + +import "context" + +func Function(ctx context.Context, _, _ string) (out, out1, out2 string) {return "", "", ""} //@codeaction("Function", "source.addTest", edit=function_context) + +type Foo struct {} + +func NewFoo(ctx context.Context) (*Foo, error) {return nil, nil} + +func (*Foo) Method(ctx context.Context, _, _ string) (out, out1, out2 string) {return "", "", ""} //@codeaction("Method", "source.addTest", edit=method_context) +-- contextinput/contextinput_test.go -- +package main_test + +import renamedctx "context" + +var local renamedctx.Context + +-- @function_context/contextinput/contextinput_test.go -- +@@ -3 +3,3 @@ +-import renamedctx "context" ++import ( ++ renamedctx "context" ++ "testing" +@@ -5 +7,3 @@ ++ "golang.org/lsptests/addtest/contextinput" ++) ++ +@@ -7 +12,26 @@ ++ ++func TestFunction(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ want string ++ want2 string ++ want3 string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got, got2, got3 := main.Function(renamedctx.Background(), "", "") ++ // TODO: update the condition below to compare got with tt.want. 
++ if true { ++ t.Errorf("Function() = %v, want %v", got, tt.want) ++ } ++ if true { ++ t.Errorf("Function() = %v, want %v", got2, tt.want2) ++ } ++ if true { ++ t.Errorf("Function() = %v, want %v", got3, tt.want3) ++ } ++ }) ++ } ++} +-- @method_context/contextinput/contextinput_test.go -- +@@ -3 +3,3 @@ +-import renamedctx "context" ++import ( ++ renamedctx "context" ++ "testing" +@@ -5 +7,3 @@ ++ "golang.org/lsptests/addtest/contextinput" ++) ++ +@@ -7 +12,30 @@ ++ ++func TestFoo_Method(t *testing.T) { ++ tests := []struct { ++ name string // description of this test case ++ want string ++ want2 string ++ want3 string ++ }{ ++ // TODO: Add test cases. ++ } ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ f, err := main.NewFoo(renamedctx.Background()) ++ if err != nil { ++ t.Fatalf("could not construct receiver type: %v", err) ++ } ++ got, got2, got3 := f.Method(renamedctx.Background(), "", "") ++ // TODO: update the condition below to compare got with tt.want. ++ if true { ++ t.Errorf("Method() = %v, want %v", got, tt.want) ++ } ++ if true { ++ t.Errorf("Method() = %v, want %v", got2, tt.want2) ++ } ++ if true { ++ t.Errorf("Method() = %v, want %v", got3, tt.want3) ++ } ++ }) ++ } ++} +-- typeparameter/typeparameter.go -- +package main + +func Function[T any] () {} // no suggested fix + +type Foo struct {} + +func NewFoo() + +func (*Foo) Method[T any]() {} // no suggested fix diff --git a/gopls/internal/test/marker/testdata/codeaction/change_quote.txt b/gopls/internal/test/marker/testdata/codeaction/change_quote.txt index a3b4f8d4c83..928ddc4d88e 100644 --- a/gopls/internal/test/marker/testdata/codeaction/change_quote.txt +++ b/gopls/internal/test/marker/testdata/codeaction/change_quote.txt @@ -17,53 +17,53 @@ import ( func foo() { var s string - s = "hello" //@codeactionedit(`"`, "refactor.rewrite.changeQuote", a1) - s = `hello` //@codeactionedit("`", "refactor.rewrite.changeQuote", a2) - s = "hello\tworld" //@codeactionedit(`"`, 
"refactor.rewrite.changeQuote", a3) - s = `hello world` //@codeactionedit("`", "refactor.rewrite.changeQuote", a4) - s = "hello\nworld" //@codeactionedit(`"`, "refactor.rewrite.changeQuote", a5) + s = "hello" //@codeaction(`"`, "refactor.rewrite.changeQuote", edit=a1) + s = `hello` //@codeaction("`", "refactor.rewrite.changeQuote", edit=a2) + s = "hello\tworld" //@codeaction(`"`, "refactor.rewrite.changeQuote", edit=a3) + s = `hello world` //@codeaction("`", "refactor.rewrite.changeQuote", edit=a4) + s = "hello\nworld" //@codeaction(`"`, "refactor.rewrite.changeQuote", edit=a5) // add a comment to avoid affect diff compute s = `hello -world` //@codeactionedit("`", "refactor.rewrite.changeQuote", a6) - s = "hello\"world" //@codeactionedit(`"`, "refactor.rewrite.changeQuote", a7) - s = `hello"world` //@codeactionedit("`", "refactor.rewrite.changeQuote", a8) - s = "hello\x1bworld" //@codeactionerr(`"`, "", "refactor.rewrite.changeQuote", re"found 0 CodeActions") - s = "hello`world" //@codeactionerr(`"`, "", "refactor.rewrite.changeQuote", re"found 0 CodeActions") - s = "hello\x7fworld" //@codeactionerr(`"`, "", "refactor.rewrite.changeQuote", re"found 0 CodeActions") +world` //@codeaction("`", "refactor.rewrite.changeQuote", edit=a6) + s = "hello\"world" //@codeaction(`"`, "refactor.rewrite.changeQuote", edit=a7) + s = `hello"world` //@codeaction("`", "refactor.rewrite.changeQuote", edit=a8) + s = "hello\x1bworld" //@codeaction(`"`, "refactor.rewrite.changeQuote", err=re"found 0 CodeActions") + s = "hello`world" //@codeaction(`"`, "refactor.rewrite.changeQuote", err=re"found 0 CodeActions") + s = "hello\x7fworld" //@codeaction(`"`, "refactor.rewrite.changeQuote", err=re"found 0 CodeActions") fmt.Println(s) } -- @a1/a.go -- @@ -9 +9 @@ -- s = "hello" //@codeactionedit(`"`, "refactor.rewrite.changeQuote", a1) -+ s = `hello` //@codeactionedit(`"`, "refactor.rewrite.changeQuote", a1) +- s = "hello" //@codeaction(`"`, "refactor.rewrite.changeQuote", edit=a1) ++ s = `hello` 
//@codeaction(`"`, "refactor.rewrite.changeQuote", edit=a1) -- @a2/a.go -- @@ -10 +10 @@ -- s = `hello` //@codeactionedit("`", "refactor.rewrite.changeQuote", a2) -+ s = "hello" //@codeactionedit("`", "refactor.rewrite.changeQuote", a2) +- s = `hello` //@codeaction("`", "refactor.rewrite.changeQuote", edit=a2) ++ s = "hello" //@codeaction("`", "refactor.rewrite.changeQuote", edit=a2) -- @a3/a.go -- @@ -11 +11 @@ -- s = "hello\tworld" //@codeactionedit(`"`, "refactor.rewrite.changeQuote", a3) -+ s = `hello world` //@codeactionedit(`"`, "refactor.rewrite.changeQuote", a3) +- s = "hello\tworld" //@codeaction(`"`, "refactor.rewrite.changeQuote", edit=a3) ++ s = `hello world` //@codeaction(`"`, "refactor.rewrite.changeQuote", edit=a3) -- @a4/a.go -- @@ -12 +12 @@ -- s = `hello world` //@codeactionedit("`", "refactor.rewrite.changeQuote", a4) -+ s = "hello\tworld" //@codeactionedit("`", "refactor.rewrite.changeQuote", a4) +- s = `hello world` //@codeaction("`", "refactor.rewrite.changeQuote", edit=a4) ++ s = "hello\tworld" //@codeaction("`", "refactor.rewrite.changeQuote", edit=a4) -- @a5/a.go -- @@ -13 +13,2 @@ -- s = "hello\nworld" //@codeactionedit(`"`, "refactor.rewrite.changeQuote", a5) +- s = "hello\nworld" //@codeaction(`"`, "refactor.rewrite.changeQuote", edit=a5) + s = `hello -+world` //@codeactionedit(`"`, "refactor.rewrite.changeQuote", a5) ++world` //@codeaction(`"`, "refactor.rewrite.changeQuote", edit=a5) -- @a6/a.go -- @@ -15,2 +15 @@ - s = `hello --world` //@codeactionedit("`", "refactor.rewrite.changeQuote", a6) -+ s = "hello\nworld" //@codeactionedit("`", "refactor.rewrite.changeQuote", a6) +-world` //@codeaction("`", "refactor.rewrite.changeQuote", edit=a6) ++ s = "hello\nworld" //@codeaction("`", "refactor.rewrite.changeQuote", edit=a6) -- @a7/a.go -- @@ -17 +17 @@ -- s = "hello\"world" //@codeactionedit(`"`, "refactor.rewrite.changeQuote", a7) -+ s = `hello"world` //@codeactionedit(`"`, "refactor.rewrite.changeQuote", a7) +- s = "hello\"world" 
//@codeaction(`"`, "refactor.rewrite.changeQuote", edit=a7) ++ s = `hello"world` //@codeaction(`"`, "refactor.rewrite.changeQuote", edit=a7) -- @a8/a.go -- @@ -18 +18 @@ -- s = `hello"world` //@codeactionedit("`", "refactor.rewrite.changeQuote", a8) -+ s = "hello\"world" //@codeactionedit("`", "refactor.rewrite.changeQuote", a8) +- s = `hello"world` //@codeaction("`", "refactor.rewrite.changeQuote", edit=a8) ++ s = "hello\"world" //@codeaction("`", "refactor.rewrite.changeQuote", edit=a8) diff --git a/gopls/internal/test/marker/testdata/codeaction/extract-variadic-63287.txt b/gopls/internal/test/marker/testdata/codeaction/extract-variadic-63287.txt index d035119bc3a..afabcf49f2a 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extract-variadic-63287.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extract-variadic-63287.txt @@ -9,7 +9,7 @@ go 1.18 -- a/a.go -- package a -//@codeactionedit(block, "refactor.extract.function", out) +//@codeaction(block, "refactor.extract.function", edit=out) func _() { var logf func(string, ...any) diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_method.txt b/gopls/internal/test/marker/testdata/codeaction/extract_method.txt index 7cb22d1577d..49388f5bcbc 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extract_method.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extract_method.txt @@ -6,18 +6,18 @@ This test exercises function and method extraction. 
-- basic.go -- package extract -//@codeactionedit(A_XLessThanYP, "refactor.extract.method", meth1) -//@codeactionedit(A_XLessThanYP, "refactor.extract.function", func1) -//@codeactionedit(A_AddP1, "refactor.extract.method", meth2) -//@codeactionedit(A_AddP1, "refactor.extract.function", func2) -//@codeactionedit(A_AddP2, "refactor.extract.method", meth3) -//@codeactionedit(A_AddP2, "refactor.extract.function", func3) -//@codeactionedit(A_XLessThanY, "refactor.extract.method", meth4) -//@codeactionedit(A_XLessThanY, "refactor.extract.function", func4) -//@codeactionedit(A_Add1, "refactor.extract.method", meth5) -//@codeactionedit(A_Add1, "refactor.extract.function", func5) -//@codeactionedit(A_Add2, "refactor.extract.method", meth6) -//@codeactionedit(A_Add2, "refactor.extract.function", func6) +//@codeaction(A_XLessThanYP, "refactor.extract.method", edit=meth1) +//@codeaction(A_XLessThanYP, "refactor.extract.function", edit=func1) +//@codeaction(A_AddP1, "refactor.extract.method", edit=meth2) +//@codeaction(A_AddP1, "refactor.extract.function", edit=func2) +//@codeaction(A_AddP2, "refactor.extract.method", edit=meth3) +//@codeaction(A_AddP2, "refactor.extract.function", edit=func3) +//@codeaction(A_XLessThanY, "refactor.extract.method", edit=meth4) +//@codeaction(A_XLessThanY, "refactor.extract.function", edit=func4) +//@codeaction(A_Add1, "refactor.extract.method", edit=meth5) +//@codeaction(A_Add1, "refactor.extract.function", edit=func5) +//@codeaction(A_Add2, "refactor.extract.method", edit=meth6) +//@codeaction(A_Add2, "refactor.extract.function", edit=func6) type A struct { x int @@ -162,12 +162,12 @@ import ( "testing" ) -//@codeactionedit(B_AddP, "refactor.extract.method", contextMeth1) -//@codeactionedit(B_AddP, "refactor.extract.function", contextFunc1) -//@codeactionedit(B_LongList, "refactor.extract.method", contextMeth2) -//@codeactionedit(B_LongList, "refactor.extract.function", contextFunc2) -//@codeactionedit(B_AddPWithB, 
"refactor.extract.function", contextFuncB) -//@codeactionedit(B_LongListWithT, "refactor.extract.function", contextFuncT) +//@codeaction(B_AddP, "refactor.extract.method", edit=contextMeth1) +//@codeaction(B_AddP, "refactor.extract.function", edit=contextFunc1) +//@codeaction(B_LongList, "refactor.extract.method", edit=contextMeth2) +//@codeaction(B_LongList, "refactor.extract.function", edit=contextFunc2) +//@codeaction(B_AddPWithB, "refactor.extract.function", edit=contextFuncB) +//@codeaction(B_LongListWithT, "refactor.extract.function", edit=contextFuncT) type B struct { x int @@ -237,20 +237,14 @@ func (b *B) LongListWithT(ctx context.Context, t *testing.T) (int, error) { +} + -- @contextFuncB/context.go -- -@@ -33 +33,6 @@ -- sum := b.x + b.y //@loc(B_AddPWithB, re`(?s:^.*?Err\(\))`) -+ //@loc(B_AddPWithB, re`(?s:^.*?Err\(\))`) +@@ -33 +33,4 @@ + return newFunction(ctx, tB, b) +} + +func newFunction(ctx context.Context, tB *testing.B, b *B) (int, error) { -+ sum := b.x + b.y -- @contextFuncT/context.go -- -@@ -42 +42,6 @@ -- p4 := p1 + p2 //@loc(B_LongListWithT, re`(?s:^.*?Err\(\))`) -+ //@loc(B_LongListWithT, re`(?s:^.*?Err\(\))`) +@@ -42 +42,4 @@ + return newFunction(ctx, t, p1, p2, p3) +} + +func newFunction(ctx context.Context, t *testing.T, p1 int, p2 int, p3 int) (int, error) { -+ p4 := p1 + p2 diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt index 259b84a09a3..fabbbee99d3 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt @@ -16,7 +16,7 @@ import ( func f() io.Reader func main() { - switch r := f().(type) { //@codeactionedit("f()", "refactor.extract.variable", type_switch_func_call) + switch r := f().(type) { //@codeaction("f()", "refactor.extract.variable", edit=type_switch_func_call) default: _ = r } @@ -24,6 +24,6 @@ func 
main() { -- @type_switch_func_call/extract_switch.go -- @@ -10 +10,2 @@ -- switch r := f().(type) { //@codeactionedit("f()", "refactor.extract.variable", type_switch_func_call) +- switch r := f().(type) { //@codeaction("f()", "refactor.extract.variable", edit=type_switch_func_call) + x := f() -+ switch r := x.(type) { //@codeactionedit("f()", "refactor.extract.variable", type_switch_func_call) ++ switch r := x.(type) { //@codeaction("f()", "refactor.extract.variable", edit=type_switch_func_call) diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable-if.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable-if.txt new file mode 100644 index 00000000000..ab9d76b8602 --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable-if.txt @@ -0,0 +1,41 @@ +This test checks the behavior of the 'extract variable/constant' code actions +when the optimal place for the new declaration is within the "if" statement, +like so: + + if x := 1 + 2 or y + y ; true { + } else if x > 0 { + } + +A future refactor.variable implementation that does this should avoid +using a 'const' declaration, which is not legal at that location. 
+ +-- flags -- +-ignore_extra_diags + +-- a.go -- +package a + +func constant() { + if true { + } else if 1 + 2 > 0 { //@ codeaction("1 + 2", "refactor.extract.constant", edit=constant) + } +} + +func variable(y int) { + if true { + } else if y + y > 0 { //@ codeaction("y + y", "refactor.extract.variable", edit=variable) + } +} + +-- @constant/a.go -- +@@ -4 +4 @@ ++ const k = 1 + 2 +@@ -5 +6 @@ +- } else if 1 + 2 > 0 { //@ codeaction("1 + 2", "refactor.extract.constant", edit=constant) ++ } else if k > 0 { //@ codeaction("1 + 2", "refactor.extract.constant", edit=constant) +-- @variable/a.go -- +@@ -10 +10 @@ ++ x := y + y +@@ -11 +12 @@ +- } else if y + y > 0 { //@ codeaction("y + y", "refactor.extract.variable", edit=variable) ++ } else if x > 0 { //@ codeaction("y + y", "refactor.extract.variable", edit=variable) diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable-inexact.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable-inexact.txt new file mode 100644 index 00000000000..1781b3ce6af --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable-inexact.txt @@ -0,0 +1,36 @@ +This test checks that extract variable/constant permits: +- extraneous whitespace in the selection +- function literals +- pointer dereference expressions +- parenthesized expressions + +-- a.go -- +package a + +func _(ptr *int) { + var _ = 1 + 2 + 3 //@codeaction("1 + 2 ", "refactor.extract.constant", edit=spaces) + var _ = func() {} //@codeaction("func() {}", "refactor.extract.variable", edit=funclit) + var _ = *ptr //@codeaction("*ptr", "refactor.extract.variable", edit=ptr) + var _ = (ptr) //@codeaction("(ptr)", "refactor.extract.variable", edit=paren) +} + +-- @spaces/a.go -- +@@ -4 +4,2 @@ +- var _ = 1 + 2 + 3 //@codeaction("1 + 2 ", "refactor.extract.constant", edit=spaces) ++ const k = 1 + 2 ++ var _ = k+ 3 //@codeaction("1 + 2 ", "refactor.extract.constant", edit=spaces) +-- @funclit/a.go -- +@@ -5 +5,2 @@ +- var _ = 
func() {} //@codeaction("func() {}", "refactor.extract.variable", edit=funclit) ++ x := func() {} ++ var _ = x //@codeaction("func() {}", "refactor.extract.variable", edit=funclit) +-- @ptr/a.go -- +@@ -6 +6,2 @@ +- var _ = *ptr //@codeaction("*ptr", "refactor.extract.variable", edit=ptr) ++ x := *ptr ++ var _ = x //@codeaction("*ptr", "refactor.extract.variable", edit=ptr) +-- @paren/a.go -- +@@ -7 +7,2 @@ +- var _ = (ptr) //@codeaction("(ptr)", "refactor.extract.variable", edit=paren) ++ x := (ptr) ++ var _ = x //@codeaction("(ptr)", "refactor.extract.variable", edit=paren) diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable-toplevel.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable-toplevel.txt new file mode 100644 index 00000000000..b9166c6299d --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable-toplevel.txt @@ -0,0 +1,51 @@ +This test checks the behavior of the 'extract variable/constant' code action +at top level (outside any function). See issue #70665. 
+ +-- a.go -- +package a + +const length = len("hello") + 2 //@codeaction(`len("hello")`, "refactor.extract.constant", edit=lenhello) + +var slice = append([]int{}, 1, 2, 3) //@codeaction("[]int{}", "refactor.extract.variable", edit=sliceliteral) + +type SHA256 [32]byte //@codeaction("32", "refactor.extract.constant", edit=arraylen) + +func f([2]int) {} //@codeaction("2", "refactor.extract.constant", edit=paramtypearraylen) + +-- @lenhello/a.go -- +@@ -3 +3,2 @@ +-const length = len("hello") + 2 //@codeaction(`len("hello")`, "refactor.extract.constant", edit=lenhello) ++const k = len("hello") ++const length = k + 2 //@codeaction(`len("hello")`, "refactor.extract.constant", edit=lenhello) +-- @sliceliteral/a.go -- +@@ -5 +5,2 @@ +-var slice = append([]int{}, 1, 2, 3) //@codeaction("[]int{}", "refactor.extract.variable", edit=sliceliteral) ++var x = []int{} ++var slice = append(x, 1, 2, 3) //@codeaction("[]int{}", "refactor.extract.variable", edit=sliceliteral) +-- @arraylen/a.go -- +@@ -7 +7,2 @@ +-type SHA256 [32]byte //@codeaction("32", "refactor.extract.constant", edit=arraylen) ++const k = 32 ++type SHA256 [k]byte //@codeaction("32", "refactor.extract.constant", edit=arraylen) +-- @paramtypearraylen/a.go -- +@@ -9 +9,2 @@ +-func f([2]int) {} //@codeaction("2", "refactor.extract.constant", edit=paramtypearraylen) ++const k = 2 ++func f([k]int) {} //@codeaction("2", "refactor.extract.constant", edit=paramtypearraylen) +-- b/b.go -- +package b + +// Check that package- and file-level name collisions are avoided. 
+ +import x3 "errors" + +var x, x1, x2 any // these names are taken already +var _ = x3.New("") +var a, b int +var c = a + b //@codeaction("a + b", "refactor.extract.variable", edit=fresh) + +-- @fresh/b/b.go -- +@@ -10 +10,2 @@ +-var c = a + b //@codeaction("a + b", "refactor.extract.variable", edit=fresh) ++var x4 = a + b ++var c = x4 //@codeaction("a + b", "refactor.extract.variable", edit=fresh) diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable.txt index 8c500d02c1e..c14fb732978 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extract_variable.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable.txt @@ -1,4 +1,4 @@ -This test checks the behavior of the 'extract variable' code action. +This test checks the behavior of the 'extract variable/constant' code action. See extract_variable_resolve.txt for the same test with resolve support. -- flags -- @@ -8,41 +8,41 @@ See extract_variable_resolve.txt for the same test with resolve support. 
package extract func _() { - var _ = 1 + 2 //@codeactionedit("1", "refactor.extract.variable", basic_lit1) - var _ = 3 + 4 //@codeactionedit("3 + 4", "refactor.extract.variable", basic_lit2) + var _ = 1 + 2 //@codeaction("1", "refactor.extract.constant", edit=basic_lit1) + var _ = 3 + 4 //@codeaction("3 + 4", "refactor.extract.constant", edit=basic_lit2) } -- @basic_lit1/basic_lit.go -- @@ -4 +4,2 @@ -- var _ = 1 + 2 //@codeactionedit("1", "refactor.extract.variable", basic_lit1) -+ x := 1 -+ var _ = x + 2 //@codeactionedit("1", "refactor.extract.variable", basic_lit1) +- var _ = 1 + 2 //@codeaction("1", "refactor.extract.constant", edit=basic_lit1) ++ const k = 1 ++ var _ = k + 2 //@codeaction("1", "refactor.extract.constant", edit=basic_lit1) -- @basic_lit2/basic_lit.go -- @@ -5 +5,2 @@ -- var _ = 3 + 4 //@codeactionedit("3 + 4", "refactor.extract.variable", basic_lit2) -+ x := 3 + 4 -+ var _ = x //@codeactionedit("3 + 4", "refactor.extract.variable", basic_lit2) +- var _ = 3 + 4 //@codeaction("3 + 4", "refactor.extract.constant", edit=basic_lit2) ++ const k = 3 + 4 ++ var _ = k //@codeaction("3 + 4", "refactor.extract.constant", edit=basic_lit2) -- func_call.go -- package extract import "strconv" func _() { - x0 := append([]int{}, 1) //@codeactionedit("append([]int{}, 1)", "refactor.extract.variable", func_call1) + x0 := append([]int{}, 1) //@codeaction("append([]int{}, 1)", "refactor.extract.variable", edit=func_call1) str := "1" - b, err := strconv.Atoi(str) //@codeactionedit("strconv.Atoi(str)", "refactor.extract.variable", func_call2) + b, err := strconv.Atoi(str) //@codeaction("strconv.Atoi(str)", "refactor.extract.variable", edit=func_call2) } -- @func_call1/func_call.go -- @@ -6 +6,2 @@ -- x0 := append([]int{}, 1) //@codeactionedit("append([]int{}, 1)", "refactor.extract.variable", func_call1) +- x0 := append([]int{}, 1) //@codeaction("append([]int{}, 1)", "refactor.extract.variable", edit=func_call1) + x := append([]int{}, 1) -+ x0 := x 
//@codeactionedit("append([]int{}, 1)", "refactor.extract.variable", func_call1) ++ x0 := x //@codeaction("append([]int{}, 1)", "refactor.extract.variable", edit=func_call1) -- @func_call2/func_call.go -- @@ -8 +8,2 @@ -- b, err := strconv.Atoi(str) //@codeactionedit("strconv.Atoi(str)", "refactor.extract.variable", func_call2) +- b, err := strconv.Atoi(str) //@codeaction("strconv.Atoi(str)", "refactor.extract.variable", edit=func_call2) + x, x1 := strconv.Atoi(str) -+ b, err := x, x1 //@codeactionedit("strconv.Atoi(str)", "refactor.extract.variable", func_call2) ++ b, err := x, x1 //@codeaction("strconv.Atoi(str)", "refactor.extract.variable", edit=func_call2) -- scope.go -- package extract @@ -51,20 +51,20 @@ import "go/ast" func _() { x0 := 0 if true { - y := ast.CompositeLit{} //@codeactionedit("ast.CompositeLit{}", "refactor.extract.variable", scope1) + y := ast.CompositeLit{} //@codeaction("ast.CompositeLit{}", "refactor.extract.variable", edit=scope1) } if true { - x1 := !false //@codeactionedit("!false", "refactor.extract.variable", scope2) + x := !false //@codeaction("!false", "refactor.extract.constant", edit=scope2) } } -- @scope1/scope.go -- @@ -8 +8,2 @@ -- y := ast.CompositeLit{} //@codeactionedit("ast.CompositeLit{}", "refactor.extract.variable", scope1) +- y := ast.CompositeLit{} //@codeaction("ast.CompositeLit{}", "refactor.extract.variable", edit=scope1) + x := ast.CompositeLit{} -+ y := x //@codeactionedit("ast.CompositeLit{}", "refactor.extract.variable", scope1) ++ y := x //@codeaction("ast.CompositeLit{}", "refactor.extract.variable", edit=scope1) -- @scope2/scope.go -- @@ -11 +11,2 @@ -- x1 := !false //@codeactionedit("!false", "refactor.extract.variable", scope2) -+ x := !false -+ x1 := x //@codeactionedit("!false", "refactor.extract.variable", scope2) +- x := !false //@codeaction("!false", "refactor.extract.constant", edit=scope2) ++ const k = !false ++ x := k //@codeaction("!false", "refactor.extract.constant", edit=scope2) diff --git 
a/gopls/internal/test/marker/testdata/codeaction/extract_variable_resolve.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable_resolve.txt index b3a9a67059f..2bf1803a7d8 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extract_variable_resolve.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable_resolve.txt @@ -1,4 +1,4 @@ -This test checks the behavior of the 'extract variable' code action, with resolve support. +This test checks the behavior of the 'extract variable/constant' code action, with resolve support. See extract_variable.txt for the same test without resolve support. -- capabilities.json -- @@ -19,41 +19,41 @@ See extract_variable.txt for the same test without resolve support. package extract func _() { - var _ = 1 + 2 //@codeactionedit("1", "refactor.extract.variable", basic_lit1) - var _ = 3 + 4 //@codeactionedit("3 + 4", "refactor.extract.variable", basic_lit2) + var _ = 1 + 2 //@codeaction("1", "refactor.extract.constant", edit=basic_lit1) + var _ = 3 + 4 //@codeaction("3 + 4", "refactor.extract.constant", edit=basic_lit2) } -- @basic_lit1/basic_lit.go -- @@ -4 +4,2 @@ -- var _ = 1 + 2 //@codeactionedit("1", "refactor.extract.variable", basic_lit1) -+ x := 1 -+ var _ = x + 2 //@codeactionedit("1", "refactor.extract.variable", basic_lit1) +- var _ = 1 + 2 //@codeaction("1", "refactor.extract.constant", edit=basic_lit1) ++ const k = 1 ++ var _ = k + 2 //@codeaction("1", "refactor.extract.constant", edit=basic_lit1) -- @basic_lit2/basic_lit.go -- @@ -5 +5,2 @@ -- var _ = 3 + 4 //@codeactionedit("3 + 4", "refactor.extract.variable", basic_lit2) -+ x := 3 + 4 -+ var _ = x //@codeactionedit("3 + 4", "refactor.extract.variable", basic_lit2) +- var _ = 3 + 4 //@codeaction("3 + 4", "refactor.extract.constant", edit=basic_lit2) ++ const k = 3 + 4 ++ var _ = k //@codeaction("3 + 4", "refactor.extract.constant", edit=basic_lit2) -- func_call.go -- package extract import "strconv" func _() { - x0 := 
append([]int{}, 1) //@codeactionedit("append([]int{}, 1)", "refactor.extract.variable", func_call1) + x0 := append([]int{}, 1) //@codeaction("append([]int{}, 1)", "refactor.extract.variable", edit=func_call1) str := "1" - b, err := strconv.Atoi(str) //@codeactionedit("strconv.Atoi(str)", "refactor.extract.variable", func_call2) + b, err := strconv.Atoi(str) //@codeaction("strconv.Atoi(str)", "refactor.extract.variable", edit=func_call2) } -- @func_call1/func_call.go -- @@ -6 +6,2 @@ -- x0 := append([]int{}, 1) //@codeactionedit("append([]int{}, 1)", "refactor.extract.variable", func_call1) +- x0 := append([]int{}, 1) //@codeaction("append([]int{}, 1)", "refactor.extract.variable", edit=func_call1) + x := append([]int{}, 1) -+ x0 := x //@codeactionedit("append([]int{}, 1)", "refactor.extract.variable", func_call1) ++ x0 := x //@codeaction("append([]int{}, 1)", "refactor.extract.variable", edit=func_call1) -- @func_call2/func_call.go -- @@ -8 +8,2 @@ -- b, err := strconv.Atoi(str) //@codeactionedit("strconv.Atoi(str)", "refactor.extract.variable", func_call2) +- b, err := strconv.Atoi(str) //@codeaction("strconv.Atoi(str)", "refactor.extract.variable", edit=func_call2) + x, x1 := strconv.Atoi(str) -+ b, err := x, x1 //@codeactionedit("strconv.Atoi(str)", "refactor.extract.variable", func_call2) ++ b, err := x, x1 //@codeaction("strconv.Atoi(str)", "refactor.extract.variable", edit=func_call2) -- scope.go -- package extract @@ -62,20 +62,20 @@ import "go/ast" func _() { x0 := 0 if true { - y := ast.CompositeLit{} //@codeactionedit("ast.CompositeLit{}", "refactor.extract.variable", scope1) + y := ast.CompositeLit{} //@codeaction("ast.CompositeLit{}", "refactor.extract.variable", edit=scope1) } if true { - x1 := !false //@codeactionedit("!false", "refactor.extract.variable", scope2) + x := !false //@codeaction("!false", "refactor.extract.constant", edit=scope2) } } -- @scope1/scope.go -- @@ -8 +8,2 @@ -- y := ast.CompositeLit{} //@codeactionedit("ast.CompositeLit{}", 
"refactor.extract.variable", scope1) +- y := ast.CompositeLit{} //@codeaction("ast.CompositeLit{}", "refactor.extract.variable", edit=scope1) + x := ast.CompositeLit{} -+ y := x //@codeactionedit("ast.CompositeLit{}", "refactor.extract.variable", scope1) ++ y := x //@codeaction("ast.CompositeLit{}", "refactor.extract.variable", edit=scope1) -- @scope2/scope.go -- @@ -11 +11,2 @@ -- x1 := !false //@codeactionedit("!false", "refactor.extract.variable", scope2) -+ x := !false -+ x1 := x //@codeactionedit("!false", "refactor.extract.variable", scope2) +- x := !false //@codeaction("!false", "refactor.extract.constant", edit=scope2) ++ const k = !false ++ x := k //@codeaction("!false", "refactor.extract.constant", edit=scope2) diff --git a/gopls/internal/test/marker/testdata/codeaction/extracttofile.txt b/gopls/internal/test/marker/testdata/codeaction/extracttofile.txt index 158a9f9a22c..5577b5e9e26 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extracttofile.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extracttofile.txt @@ -12,46 +12,97 @@ go 1.18 package main // docs -func fn() {} //@codeactionedit("func", "refactor.extract.toNewFile", function_declaration) +func fn() {} //@codeaction("func", "refactor.extract.toNewFile", edit=function_declaration) -func fn2() {} //@codeactionedit("fn2", "refactor.extract.toNewFile", only_select_func_name) +func fn2() {} //@codeaction("fn2", "refactor.extract.toNewFile", edit=only_select_func_name) -func fn3() {} //@codeactionedit(re`()fn3`, "refactor.extract.toNewFile", zero_width_selection_on_func_name) +func fn3() {} //@codeaction(re`()fn3`, "refactor.extract.toNewFile", edit=zero_width_selection_on_func_name) // docs -type T int //@codeactionedit("type", "refactor.extract.toNewFile", type_declaration) +type T int //@codeaction("type", "refactor.extract.toNewFile", edit=type_declaration) // docs -var V int //@codeactionedit("var", "refactor.extract.toNewFile", var_declaration) +var V int //@codeaction("var", 
"refactor.extract.toNewFile", edit=var_declaration) // docs -const K = "" //@codeactionedit("const", "refactor.extract.toNewFile", const_declaration) +const K = "" //@codeaction("const", "refactor.extract.toNewFile", edit=const_declaration) -const ( //@codeactionedit("const", "refactor.extract.toNewFile", const_declaration_multiple_specs) +const ( //@codeaction("const", "refactor.extract.toNewFile", edit=const_declaration_multiple_specs) P = iota Q R ) -func fnA () {} //@codeaction("func", mdEnd, "refactor.extract.toNewFile", multiple_declarations) +func fnA () {} //@codeaction("func", "refactor.extract.toNewFile", end=mdEnd, result=multiple_declarations) // unattached comment func fnB () {} //@loc(mdEnd, "}") +-- @const_declaration_multiple_specs/p.go -- +@@ -0,0 +1,7 @@ ++package main ++ ++const ( //@codeaction("const", "refactor.extract.toNewFile", edit=const_declaration_multiple_specs) ++ P = iota ++ Q ++ R ++) +-- @multiple_declarations/fna.go -- +package main + +func fnA() {} //@codeaction("func", "refactor.extract.toNewFile", end=mdEnd, result=multiple_declarations) + +// unattached comment + +func fnB() {} +-- @multiple_declarations/a.go -- +package main + +// docs +func fn() {} //@codeaction("func", "refactor.extract.toNewFile", edit=function_declaration) + +func fn2() {} //@codeaction("fn2", "refactor.extract.toNewFile", edit=only_select_func_name) + +func fn3() {} //@codeaction(re`()fn3`, "refactor.extract.toNewFile", edit=zero_width_selection_on_func_name) + +// docs +type T int //@codeaction("type", "refactor.extract.toNewFile", edit=type_declaration) + +// docs +var V int //@codeaction("var", "refactor.extract.toNewFile", edit=var_declaration) + +// docs +const K = "" //@codeaction("const", "refactor.extract.toNewFile", edit=const_declaration) + +const ( //@codeaction("const", "refactor.extract.toNewFile", edit=const_declaration_multiple_specs) + P = iota + Q + R +) + +//@loc(mdEnd, "}") +-- @const_declaration_multiple_specs/a.go -- +@@ -19,6 +19 @@ 
+-const ( //@codeaction("const", "refactor.extract.toNewFile", edit=const_declaration_multiple_specs) +- P = iota +- Q +- R +-) +- -- existing.go -- -- existing2.go -- -- existing2.1.go -- -- b.go -- package main -func existing() {} //@codeactionedit("func", "refactor.extract.toNewFile", file_name_conflict) -func existing2() {} //@codeactionedit("func", "refactor.extract.toNewFile", file_name_conflict_again) +func existing() {} //@codeaction("func", "refactor.extract.toNewFile", edit=file_name_conflict) +func existing2() {} //@codeaction("func", "refactor.extract.toNewFile", edit=file_name_conflict_again) -- single_import.go -- package main import "fmt" -func F() { //@codeactionedit("func", "refactor.extract.toNewFile", single_import) +func F() { //@codeaction("func", "refactor.extract.toNewFile", edit=single_import) fmt.Println() } @@ -65,24 +116,24 @@ import ( func init(){ log.Println() } -func F() { //@codeactionedit("func", "refactor.extract.toNewFile", multiple_imports) +func F() { //@codeaction("func", "refactor.extract.toNewFile", edit=multiple_imports) fmt.Println() } -func g() string{ //@codeactionedit("func", "refactor.extract.toNewFile", renamed_import) +func g() string{ //@codeaction("func", "refactor.extract.toNewFile", edit=renamed_import) return time1.Now().string() } -- blank_import.go -- package main import _ "fmt" -func F() {} //@codeactionedit("func", "refactor.extract.toNewFile", blank_import) +func F() {} //@codeaction("func", "refactor.extract.toNewFile", edit=blank_import) -- @blank_import/blank_import.go -- @@ -3 +3 @@ --func F() {} //@codeactionedit("func", "refactor.extract.toNewFile", blank_import) -+//@codeactionedit("func", "refactor.extract.toNewFile", blank_import) +-func F() {} //@codeaction("func", "refactor.extract.toNewFile", edit=blank_import) ++//@codeaction("func", "refactor.extract.toNewFile", edit=blank_import) -- @blank_import/f.go -- @@ -0,0 +1,3 @@ +package main @@ -91,35 +142,18 @@ func F() {} //@codeactionedit("func", 
"refactor.extract.toNewFile", blank_import -- @const_declaration/a.go -- @@ -16,2 +16 @@ -// docs --const K = "" //@codeactionedit("const", "refactor.extract.toNewFile", const_declaration) -+//@codeactionedit("const", "refactor.extract.toNewFile", const_declaration) +-const K = "" //@codeaction("const", "refactor.extract.toNewFile", edit=const_declaration) ++//@codeaction("const", "refactor.extract.toNewFile", edit=const_declaration) -- @const_declaration/k.go -- @@ -0,0 +1,4 @@ +package main + +// docs +const K = "" --- @const_declaration_multiple_specs/a.go -- -@@ -19,6 +19 @@ --const ( //@codeactionedit("const", "refactor.extract.toNewFile", const_declaration_multiple_specs) -- P = iota -- Q -- R --) -- --- @const_declaration_multiple_specs/p.go -- -@@ -0,0 +1,7 @@ -+package main -+ -+const ( //@codeactionedit("const", "refactor.extract.toNewFile", const_declaration_multiple_specs) -+ P = iota -+ Q -+ R -+) -- @file_name_conflict/b.go -- @@ -2 +2 @@ --func existing() {} //@codeactionedit("func", "refactor.extract.toNewFile", file_name_conflict) -+//@codeactionedit("func", "refactor.extract.toNewFile", file_name_conflict) +-func existing() {} //@codeaction("func", "refactor.extract.toNewFile", edit=file_name_conflict) ++//@codeaction("func", "refactor.extract.toNewFile", edit=file_name_conflict) -- @file_name_conflict/existing.1.go -- @@ -0,0 +1,3 @@ +package main @@ -127,8 +161,8 @@ func F() {} //@codeactionedit("func", "refactor.extract.toNewFile", blank_import +func existing() {} -- @file_name_conflict_again/b.go -- @@ -3 +3 @@ --func existing2() {} //@codeactionedit("func", "refactor.extract.toNewFile", file_name_conflict_again) -+//@codeactionedit("func", "refactor.extract.toNewFile", file_name_conflict_again) +-func existing2() {} //@codeaction("func", "refactor.extract.toNewFile", edit=file_name_conflict_again) ++//@codeaction("func", "refactor.extract.toNewFile", edit=file_name_conflict_again) -- @file_name_conflict_again/existing2.2.go -- @@ -0,0 +1,3 @@ 
+package main @@ -137,50 +171,14 @@ func F() {} //@codeactionedit("func", "refactor.extract.toNewFile", blank_import -- @function_declaration/a.go -- @@ -3,2 +3 @@ -// docs --func fn() {} //@codeactionedit("func", "refactor.extract.toNewFile", function_declaration) -+//@codeactionedit("func", "refactor.extract.toNewFile", function_declaration) +-func fn() {} //@codeaction("func", "refactor.extract.toNewFile", edit=function_declaration) ++//@codeaction("func", "refactor.extract.toNewFile", edit=function_declaration) -- @function_declaration/fn.go -- @@ -0,0 +1,4 @@ +package main + +// docs +func fn() {} --- @multiple_declarations/a.go -- -package main - -// docs -func fn() {} //@codeactionedit("func", "refactor.extract.toNewFile", function_declaration) - -func fn2() {} //@codeactionedit("fn2", "refactor.extract.toNewFile", only_select_func_name) - -func fn3() {} //@codeactionedit(re`()fn3`, "refactor.extract.toNewFile", zero_width_selection_on_func_name) - -// docs -type T int //@codeactionedit("type", "refactor.extract.toNewFile", type_declaration) - -// docs -var V int //@codeactionedit("var", "refactor.extract.toNewFile", var_declaration) - -// docs -const K = "" //@codeactionedit("const", "refactor.extract.toNewFile", const_declaration) - -const ( //@codeactionedit("const", "refactor.extract.toNewFile", const_declaration_multiple_specs) - P = iota - Q - R -) - -//@loc(mdEnd, "}") - - --- @multiple_declarations/fna.go -- -package main - -func fnA() {} //@codeaction("func", mdEnd, "refactor.extract.toNewFile", multiple_declarations) - -// unattached comment - -func fnB() {} -- @multiple_imports/f.go -- @@ -0,0 +1,9 @@ +package main @@ -189,7 +187,7 @@ func fnB() {} + "fmt" +) + -+func F() { //@codeactionedit("func", "refactor.extract.toNewFile", multiple_imports) ++func F() { //@codeaction("func", "refactor.extract.toNewFile", edit=multiple_imports) + fmt.Println() +} -- @multiple_imports/multiple_imports.go -- @@ -197,13 +195,13 @@ func fnB() {} - "fmt" + @@ 
-10,3 +10 @@ --func F() { //@codeactionedit("func", "refactor.extract.toNewFile", multiple_imports) +-func F() { //@codeaction("func", "refactor.extract.toNewFile", edit=multiple_imports) - fmt.Println() -} -- @only_select_func_name/a.go -- @@ -6 +6 @@ --func fn2() {} //@codeactionedit("fn2", "refactor.extract.toNewFile", only_select_func_name) -+//@codeactionedit("fn2", "refactor.extract.toNewFile", only_select_func_name) +-func fn2() {} //@codeaction("fn2", "refactor.extract.toNewFile", edit=only_select_func_name) ++//@codeaction("fn2", "refactor.extract.toNewFile", edit=only_select_func_name) -- @only_select_func_name/fn2.go -- @@ -0,0 +1,3 @@ +package main @@ -217,20 +215,20 @@ func fnB() {} + "fmt" +) + -+func F() { //@codeactionedit("func", "refactor.extract.toNewFile", single_import) ++func F() { //@codeaction("func", "refactor.extract.toNewFile", edit=single_import) + fmt.Println() +} -- @single_import/single_import.go -- @@ -2,4 +2 @@ -import "fmt" --func F() { //@codeactionedit("func", "refactor.extract.toNewFile", single_import) +-func F() { //@codeaction("func", "refactor.extract.toNewFile", edit=single_import) - fmt.Println() -} -- @type_declaration/a.go -- @@ -10,2 +10 @@ -// docs --type T int //@codeactionedit("type", "refactor.extract.toNewFile", type_declaration) -+//@codeactionedit("type", "refactor.extract.toNewFile", type_declaration) +-type T int //@codeaction("type", "refactor.extract.toNewFile", edit=type_declaration) ++//@codeaction("type", "refactor.extract.toNewFile", edit=type_declaration) -- @type_declaration/t.go -- @@ -0,0 +1,4 @@ +package main @@ -240,8 +238,8 @@ func fnB() {} -- @var_declaration/a.go -- @@ -13,2 +13 @@ -// docs --var V int //@codeactionedit("var", "refactor.extract.toNewFile", var_declaration) -+//@codeactionedit("var", "refactor.extract.toNewFile", var_declaration) +-var V int //@codeaction("var", "refactor.extract.toNewFile", edit=var_declaration) ++//@codeaction("var", "refactor.extract.toNewFile", 
edit=var_declaration) -- @var_declaration/v.go -- @@ -0,0 +1,4 @@ +package main @@ -250,8 +248,8 @@ func fnB() {} +var V int -- @zero_width_selection_on_func_name/a.go -- @@ -8 +8 @@ --func fn3() {} //@codeactionedit(re`()fn3`, "refactor.extract.toNewFile", zero_width_selection_on_func_name) -+//@codeactionedit(re`()fn3`, "refactor.extract.toNewFile", zero_width_selection_on_func_name) +-func fn3() {} //@codeaction(re`()fn3`, "refactor.extract.toNewFile", edit=zero_width_selection_on_func_name) ++//@codeaction(re`()fn3`, "refactor.extract.toNewFile", edit=zero_width_selection_on_func_name) -- @zero_width_selection_on_func_name/fn3.go -- @@ -0,0 +1,3 @@ +package main @@ -265,7 +263,7 @@ func fnB() {} + time1 "time" +) + -+func g() string { //@codeactionedit("func", "refactor.extract.toNewFile", renamed_import) ++func g() string { //@codeaction("func", "refactor.extract.toNewFile", edit=renamed_import) + return time1.Now().string() +} -- @renamed_import/multiple_imports.go -- @@ -273,7 +271,81 @@ func fnB() {} - time1 "time" + @@ -13,4 +13 @@ --func g() string{ //@codeactionedit("func", "refactor.extract.toNewFile", renamed_import) +-func g() string{ //@codeaction("func", "refactor.extract.toNewFile", edit=renamed_import) - return time1.Now().string() -} - +-- copyright.go -- +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +// docs +const C = "" //@codeaction("const", "refactor.extract.toNewFile", edit=copyright) + +-- @copyright/c.go -- +@@ -0,0 +1,8 @@ ++// Copyright 2020 The Go Authors. All rights reserved. ++// Use of this source code is governed by a BSD-style ++// license that can be found in the LICENSE file. 
++ ++package main ++ ++// docs ++const C = "" +-- @copyright/copyright.go -- +@@ -7,2 +7 @@ +-// docs +-const C = "" //@codeaction("const", "refactor.extract.toNewFile", edit=copyright) ++//@codeaction("const", "refactor.extract.toNewFile", edit=copyright) +-- buildconstraint.go -- +//go:build go1.18 + +package main + +// docs +const C = "" //@codeaction("const", "refactor.extract.toNewFile", edit=buildconstraint) + +-- @buildconstraint/buildconstraint.go -- +@@ -5,2 +5 @@ +-// docs +-const C = "" //@codeaction("const", "refactor.extract.toNewFile", edit=buildconstraint) ++//@codeaction("const", "refactor.extract.toNewFile", edit=buildconstraint) +-- @buildconstraint/c.go -- +@@ -0,0 +1,6 @@ ++//go:build go1.18 ++ ++package main ++ ++// docs ++const C = "" +-- copyrightandbuildconstraint.go -- +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.18 + +package main + +// docs +const C = "" //@codeaction("const", "refactor.extract.toNewFile", edit=copyrightandbuildconstraint) +-- @copyrightandbuildconstraint/c.go -- +@@ -0,0 +1,10 @@ ++// Copyright 2020 The Go Authors. All rights reserved. ++// Use of this source code is governed by a BSD-style ++// license that can be found in the LICENSE file. 
++ ++//go:build go1.18 ++ ++package main ++ ++// docs ++const C = "" +-- @copyrightandbuildconstraint/copyrightandbuildconstraint.go -- +@@ -9,2 +9 @@ +-// docs +-const C = "" //@codeaction("const", "refactor.extract.toNewFile", edit=copyrightandbuildconstraint) ++//@codeaction("const", "refactor.extract.toNewFile", edit=copyrightandbuildconstraint) diff --git a/gopls/internal/test/marker/testdata/codeaction/fill_struct.txt b/gopls/internal/test/marker/testdata/codeaction/fill_struct.txt index 2b947bf8bbc..2cbd49cffe4 100644 --- a/gopls/internal/test/marker/testdata/codeaction/fill_struct.txt +++ b/gopls/internal/test/marker/testdata/codeaction/fill_struct.txt @@ -28,49 +28,49 @@ type basicStruct struct { foo int } -var _ = basicStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a1) +var _ = basicStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a1) type twoArgStruct struct { foo int bar string } -var _ = twoArgStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a2) +var _ = twoArgStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a2) type nestedStruct struct { bar string basic basicStruct } -var _ = nestedStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a3) +var _ = nestedStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a3) -var _ = data.B{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a4) +var _ = data.B{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a4) -- @a1/a.go -- @@ -11 +11,3 @@ --var _ = basicStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a1) +-var _ = basicStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a1) +var _ = basicStruct{ + foo: 0, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a1) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a1) -- @a2/a.go -- @@ -18 +18,4 @@ --var _ = twoArgStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a2) +-var _ = twoArgStruct{} //@codeaction("}", 
"refactor.rewrite.fillStruct", edit=a2) +var _ = twoArgStruct{ + foo: 0, + bar: "", -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a2) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a2) -- @a3/a.go -- @@ -25 +25,4 @@ --var _ = nestedStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a3) +-var _ = nestedStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a3) +var _ = nestedStruct{ + bar: "", + basic: basicStruct{}, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a3) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a3) -- @a4/a.go -- @@ -27 +27,3 @@ --var _ = data.B{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a4) +-var _ = data.B{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a4) +var _ = data.B{ + ExportedInt: 0, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a4) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a4) -- a2.go -- package fillstruct @@ -82,57 +82,57 @@ type typedStruct struct { a [2]string } -var _ = typedStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a21) +var _ = typedStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a21) type funStruct struct { fn func(i int) int } -var _ = funStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a22) +var _ = funStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a22) type funStructComplex struct { fn func(i int, s string) (string, int) } -var _ = funStructComplex{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a23) +var _ = funStructComplex{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a23) type funStructEmpty struct { fn func() } -var _ = funStructEmpty{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a24) +var _ = funStructEmpty{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a24) -- @a21/a2.go -- @@ -11 +11,7 @@ --var _ = typedStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a21) +-var _ = typedStruct{} 
//@codeaction("}", "refactor.rewrite.fillStruct", edit=a21) +var _ = typedStruct{ + m: map[string]int{}, + s: []int{}, + c: make(chan int), + c1: make(<-chan int), + a: [2]string{}, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a21) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a21) -- @a22/a2.go -- @@ -17 +17,4 @@ --var _ = funStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a22) +-var _ = funStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a22) +var _ = funStruct{ + fn: func(i int) int { + }, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a22) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a22) -- @a23/a2.go -- @@ -23 +23,4 @@ --var _ = funStructComplex{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a23) +-var _ = funStructComplex{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a23) +var _ = funStructComplex{ + fn: func(i int, s string) (string, int) { + }, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a23) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a23) -- @a24/a2.go -- @@ -29 +29,4 @@ --var _ = funStructEmpty{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a24) +-var _ = funStructEmpty{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a24) +var _ = funStructEmpty{ + fn: func() { + }, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a24) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a24) -- a3.go -- package fillstruct @@ -150,7 +150,7 @@ type Bar struct { Y *Foo } -var _ = Bar{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a31) +var _ = Bar{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a31) type importedStruct struct { m map[*ast.CompositeLit]ast.Field @@ -161,7 +161,7 @@ type importedStruct struct { st ast.CompositeLit } -var _ = importedStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a32) +var _ = importedStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", 
edit=a32) type pointerBuiltinStruct struct { b *bool @@ -169,23 +169,23 @@ type pointerBuiltinStruct struct { i *int } -var _ = pointerBuiltinStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a33) +var _ = pointerBuiltinStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a33) var _ = []ast.BasicLit{ - {}, //@codeactionedit("}", "refactor.rewrite.fillStruct", a34) + {}, //@codeaction("}", "refactor.rewrite.fillStruct", edit=a34) } -var _ = []ast.BasicLit{{}} //@codeactionedit("}", "refactor.rewrite.fillStruct", a35) +var _ = []ast.BasicLit{{}} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a35) -- @a31/a3.go -- @@ -17 +17,4 @@ --var _ = Bar{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a31) +-var _ = Bar{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a31) +var _ = Bar{ + X: &Foo{}, + Y: &Foo{}, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a31) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a31) -- @a32/a3.go -- @@ -28 +28,9 @@ --var _ = importedStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a32) +-var _ = importedStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a32) +var _ = importedStruct{ + m: map[*ast.CompositeLit]ast.Field{}, + s: []ast.BadExpr{}, @@ -194,31 +194,31 @@ var _ = []ast.BasicLit{{}} //@codeactionedit("}", "refactor.rewrite.fillStruct", + fn: func(ast_decl ast.DeclStmt) ast.Ellipsis { + }, + st: ast.CompositeLit{}, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a32) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a32) -- @a33/a3.go -- @@ -36 +36,5 @@ --var _ = pointerBuiltinStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a33) +-var _ = pointerBuiltinStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a33) +var _ = pointerBuiltinStruct{ + b: new(bool), + s: new(string), + i: new(int), -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a33) ++} //@codeaction("}", 
"refactor.rewrite.fillStruct", edit=a33) -- @a34/a3.go -- @@ -39 +39,5 @@ -- {}, //@codeactionedit("}", "refactor.rewrite.fillStruct", a34) +- {}, //@codeaction("}", "refactor.rewrite.fillStruct", edit=a34) + { + ValuePos: 0, + Kind: 0, + Value: "", -+ }, //@codeactionedit("}", "refactor.rewrite.fillStruct", a34) ++ }, //@codeaction("}", "refactor.rewrite.fillStruct", edit=a34) -- @a35/a3.go -- @@ -42 +42,5 @@ --var _ = []ast.BasicLit{{}} //@codeactionedit("}", "refactor.rewrite.fillStruct", a35) +-var _ = []ast.BasicLit{{}} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a35) +var _ = []ast.BasicLit{{ + ValuePos: 0, + Kind: 0, + Value: "", -+}} //@codeactionedit("}", "refactor.rewrite.fillStruct", a35) ++}} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a35) -- a4.go -- package fillstruct @@ -244,48 +244,48 @@ type assignStruct struct { func fill() { var x int - var _ = iStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a41) + var _ = iStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a41) var s string - var _ = sStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a42) + var _ = sStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a42) var n int _ = []int{} if true { arr := []int{1, 2} } - var _ = multiFill{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a43) + var _ = multiFill{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a43) var node *ast.CompositeLit - var _ = assignStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a45) + var _ = assignStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a45) } -- @a41/a4.go -- @@ -25 +25,3 @@ -- var _ = iStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a41) +- var _ = iStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a41) + var _ = iStruct{ + X: x, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", a41) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=a41) -- @a42/a4.go 
-- @@ -28 +28,3 @@ -- var _ = sStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a42) +- var _ = sStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a42) + var _ = sStruct{ + str: s, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", a42) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=a42) -- @a43/a4.go -- @@ -35 +35,5 @@ -- var _ = multiFill{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a43) +- var _ = multiFill{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a43) + var _ = multiFill{ + num: n, + strin: s, + arr: []int{}, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", a43) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=a43) -- @a45/a4.go -- @@ -38 +38,3 @@ -- var _ = assignStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a45) +- var _ = assignStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a45) + var _ = assignStruct{ + n: node, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", a45) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=a45) -- fillStruct.go -- package fillstruct @@ -306,42 +306,42 @@ type StructA3 struct { } func fill() { - a := StructA{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct1) - b := StructA2{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct2) - c := StructA3{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct3) + a := StructA{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct1) + b := StructA2{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct2) + c := StructA3{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct3) if true { - _ = StructA3{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct4) + _ = StructA3{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct4) } } -- @fillStruct1/fillStruct.go -- @@ -20 +20,7 @@ -- a := StructA{} //@codeactionedit("}", 
"refactor.rewrite.fillStruct", fillStruct1) +- a := StructA{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct1) + a := StructA{ + unexportedIntField: 0, + ExportedIntField: 0, + MapA: map[int]string{}, + Array: []int{}, + StructB: StructB{}, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct1) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct1) -- @fillStruct2/fillStruct.go -- @@ -21 +21,3 @@ -- b := StructA2{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct2) +- b := StructA2{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct2) + b := StructA2{ + B: &StructB{}, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct2) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct2) -- @fillStruct3/fillStruct.go -- @@ -22 +22,3 @@ -- c := StructA3{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct3) +- c := StructA3{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct3) + c := StructA3{ + B: StructB{}, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct3) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct3) -- @fillStruct4/fillStruct.go -- @@ -24 +24,3 @@ -- _ = StructA3{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct4) +- _ = StructA3{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct4) + _ = StructA3{ + B: StructB{}, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct4) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct4) -- fillStruct_anon.go -- package fillstruct @@ -355,16 +355,16 @@ type StructAnon struct { } func fill() { - _ := StructAnon{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_anon) + _ := StructAnon{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_anon) } -- @fillStruct_anon/fillStruct_anon.go -- @@ -13 +13,5 @@ -- _ := StructAnon{} //@codeactionedit("}", 
"refactor.rewrite.fillStruct", fillStruct_anon) +- _ := StructAnon{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_anon) + _ := StructAnon{ + a: struct{}{}, + b: map[string]interface{}{}, + c: map[string]struct{d int; e bool}{}, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_anon) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_anon) -- fillStruct_nested.go -- package fillstruct @@ -378,16 +378,16 @@ type StructC struct { func nested() { c := StructB{ - StructC: StructC{}, //@codeactionedit("}", "refactor.rewrite.fillStruct", fill_nested) + StructC: StructC{}, //@codeaction("}", "refactor.rewrite.fillStruct", edit=fill_nested) } } -- @fill_nested/fillStruct_nested.go -- @@ -13 +13,3 @@ -- StructC: StructC{}, //@codeactionedit("}", "refactor.rewrite.fillStruct", fill_nested) +- StructC: StructC{}, //@codeaction("}", "refactor.rewrite.fillStruct", edit=fill_nested) + StructC: StructC{ + unexportedInt: 0, -+ }, //@codeactionedit("}", "refactor.rewrite.fillStruct", fill_nested) ++ }, //@codeaction("}", "refactor.rewrite.fillStruct", edit=fill_nested) -- fillStruct_package.go -- package fillstruct @@ -398,25 +398,25 @@ import ( ) func unexported() { - a := data.B{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_package1) - _ = h2.Client{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_package2) + a := data.B{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_package1) + _ = h2.Client{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_package2) } -- @fillStruct_package1/fillStruct_package.go -- @@ -10 +10,3 @@ -- a := data.B{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_package1) +- a := data.B{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_package1) + a := data.B{ + ExportedInt: 0, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_package1) ++ } //@codeaction("}", 
"refactor.rewrite.fillStruct", edit=fillStruct_package1) -- @fillStruct_package2/fillStruct_package.go -- @@ -11 +11,7 @@ -- _ = h2.Client{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_package2) +- _ = h2.Client{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_package2) + _ = h2.Client{ + Transport: nil, + CheckRedirect: func(req *h2.Request, via []*h2.Request) error { + }, + Jar: nil, + Timeout: 0, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_package2) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_package2) -- fillStruct_partial.go -- package fillstruct @@ -434,13 +434,13 @@ type StructPartialB struct { func fill() { a := StructPartialA{ PrefilledInt: 5, - } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_partial1) + } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_partial1) b := StructPartialB{ /* this comment should disappear */ PrefilledInt: 7, // This comment should be blown away. 
/* As should this one */ - } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_partial2) + } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_partial2) } -- @fillStruct_partial1/fillStruct_partial.go -- @@ -465,15 +465,15 @@ type StructD struct { } func spaces() { - d := StructD{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_spaces) + d := StructD{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_spaces) } -- @fillStruct_spaces/fillStruct_spaces.go -- @@ -8 +8,3 @@ -- d := StructD{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_spaces) +- d := StructD{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_spaces) + d := StructD{ + ExportedIntField: 0, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_spaces) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_spaces) -- fillStruct_unsafe.go -- package fillstruct @@ -485,16 +485,16 @@ type unsafeStruct struct { } func fill() { - _ := unsafeStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_unsafe) + _ := unsafeStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_unsafe) } -- @fillStruct_unsafe/fillStruct_unsafe.go -- @@ -11 +11,4 @@ -- _ := unsafeStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_unsafe) +- _ := unsafeStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_unsafe) + _ := unsafeStruct{ + x: 0, + p: nil, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_unsafe) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_unsafe) -- typeparams.go -- package fillstruct @@ -506,59 +506,59 @@ type basicStructWithTypeParams[T any] struct { foo T } -var _ = basicStructWithTypeParams[int]{} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams1) +var _ = basicStructWithTypeParams[int]{} //@codeaction("}", "refactor.rewrite.fillStruct", 
edit=typeparams1) type twoArgStructWithTypeParams[F, B any] struct { foo F bar B } -var _ = twoArgStructWithTypeParams[string, int]{} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams2) +var _ = twoArgStructWithTypeParams[string, int]{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams2) var _ = twoArgStructWithTypeParams[int, string]{ bar: "bar", -} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams3) +} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams3) type nestedStructWithTypeParams struct { bar string basic basicStructWithTypeParams[int] } -var _ = nestedStructWithTypeParams{} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams4) +var _ = nestedStructWithTypeParams{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams4) func _[T any]() { type S struct{ t T } - _ = S{} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams5) + _ = S{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams5) } -- @typeparams1/typeparams.go -- @@ -11 +11,3 @@ --var _ = basicStructWithTypeParams[int]{} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams1) +-var _ = basicStructWithTypeParams[int]{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams1) +var _ = basicStructWithTypeParams[int]{ + foo: 0, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams1) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams1) -- @typeparams2/typeparams.go -- @@ -18 +18,4 @@ --var _ = twoArgStructWithTypeParams[string, int]{} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams2) +-var _ = twoArgStructWithTypeParams[string, int]{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams2) +var _ = twoArgStructWithTypeParams[string, int]{ + foo: "", + bar: 0, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams2) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams2) -- 
@typeparams3/typeparams.go -- @@ -21 +21 @@ + foo: 0, -- @typeparams4/typeparams.go -- @@ -29 +29,4 @@ --var _ = nestedStructWithTypeParams{} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams4) +-var _ = nestedStructWithTypeParams{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams4) +var _ = nestedStructWithTypeParams{ + bar: "", + basic: basicStructWithTypeParams{}, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams4) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams4) -- @typeparams5/typeparams.go -- @@ -33 +33,3 @@ -- _ = S{} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams5) +- _ = S{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams5) + _ = S{ + t: *new(T), -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams5) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams5) -- issue63921.go -- package fillstruct @@ -571,5 +571,5 @@ type invalidStruct struct { func _() { // Note: the golden content for issue63921 is empty: fillstruct produces no // edits, but does not panic. 
- invalidStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", issue63921) + invalidStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=issue63921) } diff --git a/gopls/internal/test/marker/testdata/codeaction/fill_struct_resolve.txt b/gopls/internal/test/marker/testdata/codeaction/fill_struct_resolve.txt index 24e7a9126e2..843bb20252d 100644 --- a/gopls/internal/test/marker/testdata/codeaction/fill_struct_resolve.txt +++ b/gopls/internal/test/marker/testdata/codeaction/fill_struct_resolve.txt @@ -39,49 +39,49 @@ type basicStruct struct { foo int } -var _ = basicStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a1) +var _ = basicStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a1) type twoArgStruct struct { foo int bar string } -var _ = twoArgStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a2) +var _ = twoArgStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a2) type nestedStruct struct { bar string basic basicStruct } -var _ = nestedStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a3) +var _ = nestedStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a3) -var _ = data.B{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a4) +var _ = data.B{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a4) -- @a1/a.go -- @@ -11 +11,3 @@ --var _ = basicStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a1) +-var _ = basicStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a1) +var _ = basicStruct{ + foo: 0, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a1) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a1) -- @a2/a.go -- @@ -18 +18,4 @@ --var _ = twoArgStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a2) +-var _ = twoArgStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a2) +var _ = twoArgStruct{ + foo: 0, + bar: "", -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a2) ++} 
//@codeaction("}", "refactor.rewrite.fillStruct", edit=a2) -- @a3/a.go -- @@ -25 +25,4 @@ --var _ = nestedStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a3) +-var _ = nestedStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a3) +var _ = nestedStruct{ + bar: "", + basic: basicStruct{}, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a3) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a3) -- @a4/a.go -- @@ -27 +27,3 @@ --var _ = data.B{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a4) +-var _ = data.B{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a4) +var _ = data.B{ + ExportedInt: 0, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a4) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a4) -- a2.go -- package fillstruct @@ -93,57 +93,57 @@ type typedStruct struct { a [2]string } -var _ = typedStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a21) +var _ = typedStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a21) type funStruct struct { fn func(i int) int } -var _ = funStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a22) +var _ = funStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a22) type funStructComplex struct { fn func(i int, s string) (string, int) } -var _ = funStructComplex{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a23) +var _ = funStructComplex{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a23) type funStructEmpty struct { fn func() } -var _ = funStructEmpty{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a24) +var _ = funStructEmpty{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a24) -- @a21/a2.go -- @@ -11 +11,7 @@ --var _ = typedStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a21) +-var _ = typedStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a21) +var _ = typedStruct{ + m: map[string]int{}, + s: []int{}, + c: make(chan int), + c1: 
make(<-chan int), + a: [2]string{}, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a21) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a21) -- @a22/a2.go -- @@ -17 +17,4 @@ --var _ = funStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a22) +-var _ = funStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a22) +var _ = funStruct{ + fn: func(i int) int { + }, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a22) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a22) -- @a23/a2.go -- @@ -23 +23,4 @@ --var _ = funStructComplex{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a23) +-var _ = funStructComplex{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a23) +var _ = funStructComplex{ + fn: func(i int, s string) (string, int) { + }, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a23) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a23) -- @a24/a2.go -- @@ -29 +29,4 @@ --var _ = funStructEmpty{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a24) +-var _ = funStructEmpty{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a24) +var _ = funStructEmpty{ + fn: func() { + }, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a24) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a24) -- a3.go -- package fillstruct @@ -161,7 +161,7 @@ type Bar struct { Y *Foo } -var _ = Bar{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a31) +var _ = Bar{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a31) type importedStruct struct { m map[*ast.CompositeLit]ast.Field @@ -172,7 +172,7 @@ type importedStruct struct { st ast.CompositeLit } -var _ = importedStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a32) +var _ = importedStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a32) type pointerBuiltinStruct struct { b *bool @@ -180,23 +180,23 @@ type pointerBuiltinStruct struct { i *int } -var _ = pointerBuiltinStruct{} 
//@codeactionedit("}", "refactor.rewrite.fillStruct", a33) +var _ = pointerBuiltinStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a33) var _ = []ast.BasicLit{ - {}, //@codeactionedit("}", "refactor.rewrite.fillStruct", a34) + {}, //@codeaction("}", "refactor.rewrite.fillStruct", edit=a34) } -var _ = []ast.BasicLit{{}} //@codeactionedit("}", "refactor.rewrite.fillStruct", a35) +var _ = []ast.BasicLit{{}} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a35) -- @a31/a3.go -- @@ -17 +17,4 @@ --var _ = Bar{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a31) +-var _ = Bar{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a31) +var _ = Bar{ + X: &Foo{}, + Y: &Foo{}, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a31) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a31) -- @a32/a3.go -- @@ -28 +28,9 @@ --var _ = importedStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a32) +-var _ = importedStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a32) +var _ = importedStruct{ + m: map[*ast.CompositeLit]ast.Field{}, + s: []ast.BadExpr{}, @@ -205,31 +205,31 @@ var _ = []ast.BasicLit{{}} //@codeactionedit("}", "refactor.rewrite.fillStruct", + fn: func(ast_decl ast.DeclStmt) ast.Ellipsis { + }, + st: ast.CompositeLit{}, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a32) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a32) -- @a33/a3.go -- @@ -36 +36,5 @@ --var _ = pointerBuiltinStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a33) +-var _ = pointerBuiltinStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a33) +var _ = pointerBuiltinStruct{ + b: new(bool), + s: new(string), + i: new(int), -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", a33) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a33) -- @a34/a3.go -- @@ -39 +39,5 @@ -- {}, //@codeactionedit("}", "refactor.rewrite.fillStruct", a34) +- {}, //@codeaction("}", 
"refactor.rewrite.fillStruct", edit=a34) + { + ValuePos: 0, + Kind: 0, + Value: "", -+ }, //@codeactionedit("}", "refactor.rewrite.fillStruct", a34) ++ }, //@codeaction("}", "refactor.rewrite.fillStruct", edit=a34) -- @a35/a3.go -- @@ -42 +42,5 @@ --var _ = []ast.BasicLit{{}} //@codeactionedit("}", "refactor.rewrite.fillStruct", a35) +-var _ = []ast.BasicLit{{}} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a35) +var _ = []ast.BasicLit{{ + ValuePos: 0, + Kind: 0, + Value: "", -+}} //@codeactionedit("}", "refactor.rewrite.fillStruct", a35) ++}} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a35) -- a4.go -- package fillstruct @@ -255,48 +255,48 @@ type assignStruct struct { func fill() { var x int - var _ = iStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a41) + var _ = iStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a41) var s string - var _ = sStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a42) + var _ = sStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a42) var n int _ = []int{} if true { arr := []int{1, 2} } - var _ = multiFill{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a43) + var _ = multiFill{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a43) var node *ast.CompositeLit - var _ = assignStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a45) + var _ = assignStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a45) } -- @a41/a4.go -- @@ -25 +25,3 @@ -- var _ = iStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a41) +- var _ = iStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a41) + var _ = iStruct{ + X: x, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", a41) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=a41) -- @a42/a4.go -- @@ -28 +28,3 @@ -- var _ = sStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a42) +- var _ = sStruct{} //@codeaction("}", 
"refactor.rewrite.fillStruct", edit=a42) + var _ = sStruct{ + str: s, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", a42) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=a42) -- @a43/a4.go -- @@ -35 +35,5 @@ -- var _ = multiFill{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a43) +- var _ = multiFill{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a43) + var _ = multiFill{ + num: n, + strin: s, + arr: []int{}, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", a43) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=a43) -- @a45/a4.go -- @@ -38 +38,3 @@ -- var _ = assignStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", a45) +- var _ = assignStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=a45) + var _ = assignStruct{ + n: node, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", a45) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=a45) -- fillStruct.go -- package fillstruct @@ -317,42 +317,42 @@ type StructA3 struct { } func fill() { - a := StructA{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct1) - b := StructA2{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct2) - c := StructA3{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct3) + a := StructA{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct1) + b := StructA2{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct2) + c := StructA3{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct3) if true { - _ = StructA3{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct4) + _ = StructA3{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct4) } } -- @fillStruct1/fillStruct.go -- @@ -20 +20,7 @@ -- a := StructA{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct1) +- a := StructA{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct1) + a := StructA{ + 
unexportedIntField: 0, + ExportedIntField: 0, + MapA: map[int]string{}, + Array: []int{}, + StructB: StructB{}, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct1) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct1) -- @fillStruct2/fillStruct.go -- @@ -21 +21,3 @@ -- b := StructA2{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct2) +- b := StructA2{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct2) + b := StructA2{ + B: &StructB{}, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct2) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct2) -- @fillStruct3/fillStruct.go -- @@ -22 +22,3 @@ -- c := StructA3{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct3) +- c := StructA3{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct3) + c := StructA3{ + B: StructB{}, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct3) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct3) -- @fillStruct4/fillStruct.go -- @@ -24 +24,3 @@ -- _ = StructA3{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct4) +- _ = StructA3{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct4) + _ = StructA3{ + B: StructB{}, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct4) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct4) -- fillStruct_anon.go -- package fillstruct @@ -366,16 +366,16 @@ type StructAnon struct { } func fill() { - _ := StructAnon{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_anon) + _ := StructAnon{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_anon) } -- @fillStruct_anon/fillStruct_anon.go -- @@ -13 +13,5 @@ -- _ := StructAnon{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_anon) +- _ := StructAnon{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_anon) + _ := 
StructAnon{ + a: struct{}{}, + b: map[string]interface{}{}, + c: map[string]struct{d int; e bool}{}, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_anon) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_anon) -- fillStruct_nested.go -- package fillstruct @@ -389,16 +389,16 @@ type StructC struct { func nested() { c := StructB{ - StructC: StructC{}, //@codeactionedit("}", "refactor.rewrite.fillStruct", fill_nested) + StructC: StructC{}, //@codeaction("}", "refactor.rewrite.fillStruct", edit=fill_nested) } } -- @fill_nested/fillStruct_nested.go -- @@ -13 +13,3 @@ -- StructC: StructC{}, //@codeactionedit("}", "refactor.rewrite.fillStruct", fill_nested) +- StructC: StructC{}, //@codeaction("}", "refactor.rewrite.fillStruct", edit=fill_nested) + StructC: StructC{ + unexportedInt: 0, -+ }, //@codeactionedit("}", "refactor.rewrite.fillStruct", fill_nested) ++ }, //@codeaction("}", "refactor.rewrite.fillStruct", edit=fill_nested) -- fillStruct_package.go -- package fillstruct @@ -409,25 +409,25 @@ import ( ) func unexported() { - a := data.B{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_package1) - _ = h2.Client{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_package2) + a := data.B{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_package1) + _ = h2.Client{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_package2) } -- @fillStruct_package1/fillStruct_package.go -- @@ -10 +10,3 @@ -- a := data.B{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_package1) +- a := data.B{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_package1) + a := data.B{ + ExportedInt: 0, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_package1) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_package1) -- @fillStruct_package2/fillStruct_package.go -- @@ -11 +11,7 @@ -- _ = h2.Client{} 
//@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_package2) +- _ = h2.Client{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_package2) + _ = h2.Client{ + Transport: nil, + CheckRedirect: func(req *h2.Request, via []*h2.Request) error { + }, + Jar: nil, + Timeout: 0, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_package2) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_package2) -- fillStruct_partial.go -- package fillstruct @@ -445,13 +445,13 @@ type StructPartialB struct { func fill() { a := StructPartialA{ PrefilledInt: 5, - } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_partial1) + } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_partial1) b := StructPartialB{ /* this comment should disappear */ PrefilledInt: 7, // This comment should be blown away. /* As should this one */ - } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_partial2) + } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_partial2) } -- @fillStruct_partial1/fillStruct_partial.go -- @@ -476,15 +476,15 @@ type StructD struct { } func spaces() { - d := StructD{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_spaces) + d := StructD{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_spaces) } -- @fillStruct_spaces/fillStruct_spaces.go -- @@ -8 +8,3 @@ -- d := StructD{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_spaces) +- d := StructD{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_spaces) + d := StructD{ + ExportedIntField: 0, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_spaces) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_spaces) -- fillStruct_unsafe.go -- package fillstruct @@ -496,16 +496,16 @@ type unsafeStruct struct { } func fill() { - _ := unsafeStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", 
fillStruct_unsafe) + _ := unsafeStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_unsafe) } -- @fillStruct_unsafe/fillStruct_unsafe.go -- @@ -11 +11,4 @@ -- _ := unsafeStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_unsafe) +- _ := unsafeStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_unsafe) + _ := unsafeStruct{ + x: 0, + p: nil, -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", fillStruct_unsafe) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=fillStruct_unsafe) -- typeparams.go -- package fillstruct @@ -517,59 +517,59 @@ type basicStructWithTypeParams[T any] struct { foo T } -var _ = basicStructWithTypeParams[int]{} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams1) +var _ = basicStructWithTypeParams[int]{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams1) type twoArgStructWithTypeParams[F, B any] struct { foo F bar B } -var _ = twoArgStructWithTypeParams[string, int]{} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams2) +var _ = twoArgStructWithTypeParams[string, int]{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams2) var _ = twoArgStructWithTypeParams[int, string]{ bar: "bar", -} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams3) +} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams3) type nestedStructWithTypeParams struct { bar string basic basicStructWithTypeParams[int] } -var _ = nestedStructWithTypeParams{} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams4) +var _ = nestedStructWithTypeParams{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams4) func _[T any]() { type S struct{ t T } - _ = S{} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams5) + _ = S{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams5) } -- @typeparams1/typeparams.go -- @@ -11 +11,3 @@ --var _ = basicStructWithTypeParams[int]{} 
//@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams1) +-var _ = basicStructWithTypeParams[int]{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams1) +var _ = basicStructWithTypeParams[int]{ + foo: 0, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams1) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams1) -- @typeparams2/typeparams.go -- @@ -18 +18,4 @@ --var _ = twoArgStructWithTypeParams[string, int]{} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams2) +-var _ = twoArgStructWithTypeParams[string, int]{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams2) +var _ = twoArgStructWithTypeParams[string, int]{ + foo: "", + bar: 0, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams2) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams2) -- @typeparams3/typeparams.go -- @@ -21 +21 @@ + foo: 0, -- @typeparams4/typeparams.go -- @@ -29 +29,4 @@ --var _ = nestedStructWithTypeParams{} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams4) +-var _ = nestedStructWithTypeParams{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams4) +var _ = nestedStructWithTypeParams{ + bar: "", + basic: basicStructWithTypeParams{}, -+} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams4) ++} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams4) -- @typeparams5/typeparams.go -- @@ -33 +33,3 @@ -- _ = S{} //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams5) +- _ = S{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams5) + _ = S{ + t: *new(T), -+ } //@codeactionedit("}", "refactor.rewrite.fillStruct", typeparams5) ++ } //@codeaction("}", "refactor.rewrite.fillStruct", edit=typeparams5) -- issue63921.go -- package fillstruct @@ -582,5 +582,5 @@ type invalidStruct struct { func _() { // Note: the golden content for issue63921 is empty: fillstruct produces no // edits, but does not panic. 
- invalidStruct{} //@codeactionedit("}", "refactor.rewrite.fillStruct", issue63921) + invalidStruct{} //@codeaction("}", "refactor.rewrite.fillStruct", edit=issue63921) } diff --git a/gopls/internal/test/marker/testdata/codeaction/fill_switch.txt b/gopls/internal/test/marker/testdata/codeaction/fill_switch.txt index 0d92b05fc41..1912c92c19a 100644 --- a/gopls/internal/test/marker/testdata/codeaction/fill_switch.txt +++ b/gopls/internal/test/marker/testdata/codeaction/fill_switch.txt @@ -50,19 +50,19 @@ func (notificationTwo) isNotification() {} func doSwitch() { var b data.TypeB switch b { - case data.TypeBOne: //@codeactionedit(":", "refactor.rewrite.fillSwitch", a1) + case data.TypeBOne: //@codeaction(":", "refactor.rewrite.fillSwitch", edit=a1) } var a typeA switch a { - case typeAThree: //@codeactionedit(":", "refactor.rewrite.fillSwitch", a2) + case typeAThree: //@codeaction(":", "refactor.rewrite.fillSwitch", edit=a2) } var n notification - switch n.(type) { //@codeactionedit("{", "refactor.rewrite.fillSwitch", a3) + switch n.(type) { //@codeaction("{", "refactor.rewrite.fillSwitch", edit=a3) } - switch nt := n.(type) { //@codeactionedit("{", "refactor.rewrite.fillSwitch", a4) + switch nt := n.(type) { //@codeaction("{", "refactor.rewrite.fillSwitch", edit=a4) } var s struct { @@ -70,7 +70,7 @@ func doSwitch() { } switch s.a { - case typeAThree: //@codeactionedit(":", "refactor.rewrite.fillSwitch", a5) + case typeAThree: //@codeaction(":", "refactor.rewrite.fillSwitch", edit=a5) } } -- @a1/a.go -- diff --git a/gopls/internal/test/marker/testdata/codeaction/fill_switch_resolve.txt b/gopls/internal/test/marker/testdata/codeaction/fill_switch_resolve.txt index 84464417b81..c8380a7d6d6 100644 --- a/gopls/internal/test/marker/testdata/codeaction/fill_switch_resolve.txt +++ b/gopls/internal/test/marker/testdata/codeaction/fill_switch_resolve.txt @@ -61,19 +61,19 @@ func (notificationTwo) isNotification() {} func doSwitch() { var b data.TypeB switch b { - case 
data.TypeBOne: //@codeactionedit(":", "refactor.rewrite.fillSwitch", a1) + case data.TypeBOne: //@codeaction(":", "refactor.rewrite.fillSwitch", edit=a1) } var a typeA switch a { - case typeAThree: //@codeactionedit(":", "refactor.rewrite.fillSwitch", a2) + case typeAThree: //@codeaction(":", "refactor.rewrite.fillSwitch", edit=a2) } var n notification - switch n.(type) { //@codeactionedit("{", "refactor.rewrite.fillSwitch", a3) + switch n.(type) { //@codeaction("{", "refactor.rewrite.fillSwitch", edit=a3) } - switch nt := n.(type) { //@codeactionedit("{", "refactor.rewrite.fillSwitch", a4) + switch nt := n.(type) { //@codeaction("{", "refactor.rewrite.fillSwitch", edit=a4) } var s struct { @@ -81,7 +81,7 @@ func doSwitch() { } switch s.a { - case typeAThree: //@codeactionedit(":", "refactor.rewrite.fillSwitch", a5) + case typeAThree: //@codeaction(":", "refactor.rewrite.fillSwitch", edit=a5) } } -- @a1/a.go -- diff --git a/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt b/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt index 1b9f487c49d..f84eeae7b4c 100644 --- a/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt +++ b/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt @@ -8,34 +8,32 @@ go 1.18 -- basic.go -- package extract -func _() { //@codeaction("{", closeBracket, "refactor.extract.function", outer) - a := 1 //@codeaction("a", end, "refactor.extract.function", inner) +func _() { //@codeaction("{", "refactor.extract.function", end=closeBracket, result=outer) + a := 1 //@codeaction("a", "refactor.extract.function", end=end, result=inner) _ = a + 4 //@loc(end, "4") } //@loc(closeBracket, "}") --- @inner/basic.go -- +-- @outer/basic.go -- package extract -func _() { //@codeaction("{", closeBracket, "refactor.extract.function", outer) - //@codeaction("a", end, "refactor.extract.function", inner) +func _() { //@codeaction("{", "refactor.extract.function", end=closeBracket, 
result=outer) newFunction() //@loc(end, "4") } func newFunction() { - a := 1 + a := 1 //@codeaction("a", "refactor.extract.function", end=end, result=inner) _ = a + 4 } //@loc(closeBracket, "}") --- @outer/basic.go -- +-- @inner/basic.go -- package extract -func _() { //@codeaction("{", closeBracket, "refactor.extract.function", outer) - //@codeaction("a", end, "refactor.extract.function", inner) +func _() { //@codeaction("{", "refactor.extract.function", end=closeBracket, result=outer) newFunction() //@loc(end, "4") } func newFunction() { - a := 1 + a := 1 //@codeaction("a", "refactor.extract.function", end=end, result=inner) _ = a + 4 } //@loc(closeBracket, "}") @@ -44,7 +42,7 @@ package extract func _() bool { x := 1 - if x == 0 { //@codeaction("if", ifend, "refactor.extract.function", return) + if x == 0 { //@codeaction("if", "refactor.extract.function", end=ifend, result=return) return true } //@loc(ifend, "}") return false @@ -55,16 +53,15 @@ package extract func _() bool { x := 1 - //@codeaction("if", ifend, "refactor.extract.function", return) - shouldReturn, returnValue := newFunction(x) + shouldReturn, b := newFunction(x) if shouldReturn { - return returnValue + return b } //@loc(ifend, "}") return false } func newFunction(x int) (bool, bool) { - if x == 0 { + if x == 0 { //@codeaction("if", "refactor.extract.function", end=ifend, result=return) return true, true } return false, false @@ -74,7 +71,7 @@ func newFunction(x int) (bool, bool) { package extract func _() bool { - x := 1 //@codeaction("x", rnnEnd, "refactor.extract.function", rnn) + x := 1 //@codeaction("x", "refactor.extract.function", end=rnnEnd, result=rnn) if x == 0 { return true } @@ -85,12 +82,11 @@ func _() bool { package extract func _() bool { - //@codeaction("x", rnnEnd, "refactor.extract.function", rnn) return newFunction() //@loc(rnnEnd, "false") } func newFunction() bool { - x := 1 + x := 1 //@codeaction("x", "refactor.extract.function", end=rnnEnd, result=rnn) if x == 0 { return 
true } @@ -105,7 +101,7 @@ import "fmt" func _() (int, string, error) { x := 1 y := "hello" - z := "bye" //@codeaction("z", rcEnd, "refactor.extract.function", rc) + z := "bye" //@codeaction("z", "refactor.extract.function", end=rcEnd, result=rc) if y == z { return x, y, fmt.Errorf("same") } else if false { @@ -123,16 +119,15 @@ import "fmt" func _() (int, string, error) { x := 1 y := "hello" - //@codeaction("z", rcEnd, "refactor.extract.function", rc) - z, shouldReturn, returnValue, returnValue1, returnValue2 := newFunction(y, x) + z, shouldReturn, i, s, err := newFunction(y, x) if shouldReturn { - return returnValue, returnValue1, returnValue2 + return i, s, err } //@loc(rcEnd, "}") return x, z, nil } func newFunction(y string, x int) (string, bool, int, string, error) { - z := "bye" + z := "bye" //@codeaction("z", "refactor.extract.function", end=rcEnd, result=rc) if y == z { return "", true, x, y, fmt.Errorf("same") } else if false { @@ -150,7 +145,7 @@ import "fmt" func _() (int, string, error) { x := 1 y := "hello" - z := "bye" //@codeaction("z", rcnnEnd, "refactor.extract.function", rcnn) + z := "bye" //@codeaction("z", "refactor.extract.function", end=rcnnEnd, result=rcnn) if y == z { return x, y, fmt.Errorf("same") } else if false { @@ -168,12 +163,11 @@ import "fmt" func _() (int, string, error) { x := 1 y := "hello" - //@codeaction("z", rcnnEnd, "refactor.extract.function", rcnn) return newFunction(y, x) //@loc(rcnnEnd, "nil") } func newFunction(y string, x int) (int, string, error) { - z := "bye" + z := "bye" //@codeaction("z", "refactor.extract.function", end=rcnnEnd, result=rcnn) if y == z { return x, y, fmt.Errorf("same") } else if false { @@ -190,7 +184,7 @@ import "go/ast" func _() { ast.Inspect(ast.NewIdent("a"), func(n ast.Node) bool { - if n == nil { //@codeaction("if", rflEnd, "refactor.extract.function", rfl) + if n == nil { //@codeaction("if", "refactor.extract.function", end=rflEnd, result=rfl) return true } //@loc(rflEnd, "}") return false 
@@ -204,17 +198,16 @@ import "go/ast" func _() { ast.Inspect(ast.NewIdent("a"), func(n ast.Node) bool { - //@codeaction("if", rflEnd, "refactor.extract.function", rfl) - shouldReturn, returnValue := newFunction(n) + shouldReturn, b := newFunction(n) if shouldReturn { - return returnValue + return b } //@loc(rflEnd, "}") return false }) } func newFunction(n ast.Node) (bool, bool) { - if n == nil { + if n == nil { //@codeaction("if", "refactor.extract.function", end=rflEnd, result=rfl) return true, true } return false, false @@ -227,7 +220,7 @@ import "go/ast" func _() { ast.Inspect(ast.NewIdent("a"), func(n ast.Node) bool { - if n == nil { //@codeaction("if", rflnnEnd, "refactor.extract.function", rflnn) + if n == nil { //@codeaction("if", "refactor.extract.function", end=rflnnEnd, result=rflnn) return true } return false //@loc(rflnnEnd, "false") @@ -241,13 +234,12 @@ import "go/ast" func _() { ast.Inspect(ast.NewIdent("a"), func(n ast.Node) bool { - //@codeaction("if", rflnnEnd, "refactor.extract.function", rflnn) return newFunction(n) //@loc(rflnnEnd, "false") }) } func newFunction(n ast.Node) bool { - if n == nil { + if n == nil { //@codeaction("if", "refactor.extract.function", end=rflnnEnd, result=rflnn) return true } return false @@ -258,7 +250,7 @@ package extract func _() string { x := 1 - if x == 0 { //@codeaction("if", riEnd, "refactor.extract.function", ri) + if x == 0 { //@codeaction("if", "refactor.extract.function", end=riEnd, result=ri) x = 3 return "a" } //@loc(riEnd, "}") @@ -271,17 +263,16 @@ package extract func _() string { x := 1 - //@codeaction("if", riEnd, "refactor.extract.function", ri) - shouldReturn, returnValue := newFunction(x) + shouldReturn, s := newFunction(x) if shouldReturn { - return returnValue + return s } //@loc(riEnd, "}") x = 2 return "b" } func newFunction(x int) (bool, string) { - if x == 0 { + if x == 0 { //@codeaction("if", "refactor.extract.function", end=riEnd, result=ri) x = 3 return true, "a" } @@ -293,7 +284,7 @@ 
package extract func _() string { x := 1 - if x == 0 { //@codeaction("if", rinnEnd, "refactor.extract.function", rinn) + if x == 0 { //@codeaction("if", "refactor.extract.function", end=rinnEnd, result=rinn) x = 3 return "a" } @@ -306,12 +297,11 @@ package extract func _() string { x := 1 - //@codeaction("if", rinnEnd, "refactor.extract.function", rinn) return newFunction(x) //@loc(rinnEnd, "\"b\"") } func newFunction(x int) string { - if x == 0 { + if x == 0 { //@codeaction("if", "refactor.extract.function", end=rinnEnd, result=rinn) x = 3 return "a" } @@ -324,10 +314,10 @@ package extract func _() { a := 1 - a = 5 //@codeaction("a", araend, "refactor.extract.function", ara) + a = 5 //@codeaction("a", "refactor.extract.function", end=araend, result=ara) a = a + 2 //@loc(araend, "2") - b := a * 2 //@codeaction("b", arbend, "refactor.extract.function", arb) + b := a * 2 //@codeaction("b", "refactor.extract.function", end=arbend, result=arb) _ = b + 4 //@loc(arbend, "4") } @@ -336,15 +326,14 @@ package extract func _() { a := 1 - //@codeaction("a", araend, "refactor.extract.function", ara) a = newFunction(a) //@loc(araend, "2") - b := a * 2 //@codeaction("b", arbend, "refactor.extract.function", arb) + b := a * 2 //@codeaction("b", "refactor.extract.function", end=arbend, result=arb) _ = b + 4 //@loc(arbend, "4") } func newFunction(a int) int { - a = 5 + a = 5 //@codeaction("a", "refactor.extract.function", end=araend, result=ara) a = a + 2 return a } @@ -354,15 +343,14 @@ package extract func _() { a := 1 - a = 5 //@codeaction("a", araend, "refactor.extract.function", ara) + a = 5 //@codeaction("a", "refactor.extract.function", end=araend, result=ara) a = a + 2 //@loc(araend, "2") - //@codeaction("b", arbend, "refactor.extract.function", arb) newFunction(a) //@loc(arbend, "4") } func newFunction(a int) { - b := a * 2 + b := a * 2 //@codeaction("b", "refactor.extract.function", end=arbend, result=arb) _ = b + 4 } @@ -371,7 +359,7 @@ package extract func _() { 
newFunction := 1 - a := newFunction //@codeaction("a", "newFunction", "refactor.extract.function", scope) + a := newFunction //@codeaction("a", "refactor.extract.function", end="newFunction", result=scope) _ = a // avoid diagnostic } @@ -384,7 +372,7 @@ package extract func _() { newFunction := 1 - a := newFunction2(newFunction) //@codeaction("a", "newFunction", "refactor.extract.function", scope) + a := newFunction2(newFunction) //@codeaction("a", "refactor.extract.function", end="newFunction", result=scope) _ = a // avoid diagnostic } @@ -402,7 +390,7 @@ package extract func _() { var a []int - a = append(a, 2) //@codeaction("a", siEnd, "refactor.extract.function", si) + a = append(a, 2) //@codeaction("a", "refactor.extract.function", end=siEnd, result=si) b := 4 //@loc(siEnd, "4") a = append(a, b) } @@ -412,13 +400,12 @@ package extract func _() { var a []int - //@codeaction("a", siEnd, "refactor.extract.function", si) a, b := newFunction(a) //@loc(siEnd, "4") a = append(a, b) } func newFunction(a []int) ([]int, int) { - a = append(a, 2) + a = append(a, 2) //@codeaction("a", "refactor.extract.function", end=siEnd, result=si) b := 4 return a, b } @@ -429,7 +416,7 @@ package extract func _() { var b []int var a int - a = 2 //@codeaction("a", srEnd, "refactor.extract.function", sr) + a = 2 //@codeaction("a", "refactor.extract.function", end=srEnd, result=sr) b = []int{} b = append(b, a) //@loc(srEnd, ")") b[0] = 1 @@ -441,13 +428,12 @@ package extract func _() { var b []int var a int - //@codeaction("a", srEnd, "refactor.extract.function", sr) b = newFunction(a, b) //@loc(srEnd, ")") b[0] = 1 } func newFunction(a int, b []int) []int { - a = 2 + a = 2 //@codeaction("a", "refactor.extract.function", end=srEnd, result=sr) b = []int{} b = append(b, a) return b @@ -458,7 +444,7 @@ package extract func _() { var b []int - a := 2 //@codeaction("a", upEnd, "refactor.extract.function", up) + a := 2 //@codeaction("a", "refactor.extract.function", end=upEnd, result=up) b = 
[]int{} b = append(b, a) //@loc(upEnd, ")") b[0] = 1 @@ -472,7 +458,6 @@ package extract func _() { var b []int - //@codeaction("a", upEnd, "refactor.extract.function", up) a, b := newFunction(b) //@loc(upEnd, ")") b[0] = 1 if a == 2 { @@ -481,7 +466,7 @@ func _() { } func newFunction(b []int) (int, []int) { - a := 2 + a := 2 //@codeaction("a", "refactor.extract.function", end=upEnd, result=up) b = []int{} b = append(b, a) return a, b @@ -491,9 +476,9 @@ func newFunction(b []int) (int, []int) { package extract func _() { - a := /* comment in the middle of a line */ 1 //@codeaction("a", commentEnd, "refactor.extract.function", comment1) - // Comment on its own line //@codeaction("Comment", commentEnd, "refactor.extract.function", comment2) - _ = a + 4 //@loc(commentEnd, "4"),codeaction("_", lastComment, "refactor.extract.function", comment3) + a := /* comment in the middle of a line */ 1 //@codeaction("a", "refactor.extract.function", end=commentEnd, result=comment1) + // Comment on its own line //@codeaction("Comment", "refactor.extract.function", end=commentEnd, result=comment2) + _ = a + 4 //@loc(commentEnd, "4"),codeaction("_", "refactor.extract.function", end=lastComment, result=comment3) // Comment right after 3 + 4 // Comment after with space //@loc(lastComment, "Comment") @@ -503,18 +488,15 @@ func _() { package extract func _() { - /* comment in the middle of a line */ - //@codeaction("a", commentEnd, "refactor.extract.function", comment1) - // Comment on its own line //@codeaction("Comment", commentEnd, "refactor.extract.function", comment2) - newFunction() //@loc(commentEnd, "4"),codeaction("_", lastComment, "refactor.extract.function", comment3) + newFunction() //@loc(commentEnd, "4"),codeaction("_", "refactor.extract.function", end=lastComment, result=comment3) // Comment right after 3 + 4 // Comment after with space //@loc(lastComment, "Comment") } func newFunction() { - a := 1 - + a := /* comment in the middle of a line */ 1 //@codeaction("a", 
"refactor.extract.function", end=commentEnd, result=comment1) + // Comment on its own line //@codeaction("Comment", "refactor.extract.function", end=commentEnd, result=comment2) _ = a + 4 } @@ -522,9 +504,9 @@ func newFunction() { package extract func _() { - a := /* comment in the middle of a line */ 1 //@codeaction("a", commentEnd, "refactor.extract.function", comment1) - // Comment on its own line //@codeaction("Comment", commentEnd, "refactor.extract.function", comment2) - newFunction(a) //@loc(commentEnd, "4"),codeaction("_", lastComment, "refactor.extract.function", comment3) + a := /* comment in the middle of a line */ 1 //@codeaction("a", "refactor.extract.function", end=commentEnd, result=comment1) + // Comment on its own line //@codeaction("Comment", "refactor.extract.function", end=commentEnd, result=comment2) + newFunction(a) //@loc(commentEnd, "4"),codeaction("_", "refactor.extract.function", end=lastComment, result=comment3) // Comment right after 3 + 4 // Comment after with space //@loc(lastComment, "Comment") @@ -538,9 +520,9 @@ func newFunction(a int) { package extract func _() { - a := /* comment in the middle of a line */ 1 //@codeaction("a", commentEnd, "refactor.extract.function", comment1) - // Comment on its own line //@codeaction("Comment", commentEnd, "refactor.extract.function", comment2) - newFunction(a) //@loc(commentEnd, "4"),codeaction("_", lastComment, "refactor.extract.function", comment3) + a := /* comment in the middle of a line */ 1 //@codeaction("a", "refactor.extract.function", end=commentEnd, result=comment1) + // Comment on its own line //@codeaction("Comment", "refactor.extract.function", end=commentEnd, result=comment2) + newFunction(a) //@loc(commentEnd, "4"),codeaction("_", "refactor.extract.function", end=lastComment, result=comment3) // Comment right after 3 + 4 // Comment after with space //@loc(lastComment, "Comment") @@ -557,7 +539,7 @@ import "strconv" func _() { i, err := strconv.Atoi("1") - u, err := 
strconv.Atoi("2") //@codeaction("u", ")", "refactor.extract.function", redefine) + u, err := strconv.Atoi("2") //@codeaction(re`u.*\)`, "refactor.extract.function", result=redefine) if i == u || err == nil { return } @@ -570,7 +552,7 @@ import "strconv" func _() { i, err := strconv.Atoi("1") - u, err := newFunction() //@codeaction("u", ")", "refactor.extract.function", redefine) + u, err := newFunction() //@codeaction(re`u.*\)`, "refactor.extract.function", result=redefine) if i == u || err == nil { return } @@ -588,7 +570,7 @@ import "slices" // issue go#64821 func _() { - var s []string //@codeaction("var", anonEnd, "refactor.extract.function", anon1) + var s []string //@codeaction("var", "refactor.extract.function", end=anonEnd, result=anon1) slices.SortFunc(s, func(a, b string) int { return cmp.Compare(a, b) }) @@ -602,12 +584,11 @@ import "slices" // issue go#64821 func _() { - //@codeaction("var", anonEnd, "refactor.extract.function", anon1) newFunction() //@loc(anonEnd, ")") } func newFunction() { - var s []string + var s []string //@codeaction("var", "refactor.extract.function", end=anonEnd, result=anon1) slices.SortFunc(s, func(a, b string) int { return cmp.Compare(a, b) }) diff --git a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue44813.txt b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue44813.txt index aaca44d6c7a..c1302b1bfef 100644 --- a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue44813.txt +++ b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue44813.txt @@ -12,7 +12,7 @@ package extract import "fmt" func main() { - x := []rune{} //@codeaction("x", end, "refactor.extract.function", ext) + x := []rune{} //@codeaction("x", "refactor.extract.function", end=end, result=ext) s := "HELLO" for _, c := range s { x = append(x, c) @@ -26,13 +26,12 @@ package extract import "fmt" func main() { - //@codeaction("x", end, "refactor.extract.function", ext) x := 
newFunction() //@loc(end, "}") fmt.Printf("%x\n", x) } func newFunction() []rune { - x := []rune{} + x := []rune{} //@codeaction("x", "refactor.extract.function", end=end, result=ext) s := "HELLO" for _, c := range s { x = append(x, c) diff --git a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue50851.txt b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue50851.txt new file mode 100644 index 00000000000..b085559cf2a --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue50851.txt @@ -0,0 +1,35 @@ +This test checks that function extraction moves comments along with the +extracted code. + +-- main.go -- +package main + +type F struct{} + +func (f *F) func1() { + println("a") + + println("b") //@ codeaction("print", "refactor.extract.function", end=end, result=result) + // This line prints the third letter of the alphabet. + println("c") //@loc(end, ")") + + println("d") +} +-- @result/main.go -- +package main + +type F struct{} + +func (f *F) func1() { + println("a") + + newFunction() //@loc(end, ")") + + println("d") +} + +func newFunction() { + println("b") //@ codeaction("print", "refactor.extract.function", end=end, result=result) + // This line prints the third letter of the alphabet. 
+ println("c") +} diff --git a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt new file mode 100644 index 00000000000..30db2fb3ed0 --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt @@ -0,0 +1,97 @@ + +-- a.go -- +package a + +import ( + "fmt" + "encoding/json" +) + +func F() error { + a, err := json.Marshal(0) //@codeaction("a", "refactor.extract.function", end=endF, result=F) + if err != nil { + return fmt.Errorf("1: %w", err) + } + b, err := json.Marshal(0) + if err != nil { + return fmt.Errorf("2: %w", err) + } //@loc(endF, "}") + fmt.Println(a, b) + return nil +} + +-- @F/a.go -- +package a + +import ( + "fmt" + "encoding/json" +) + +func F() error { + a, b, shouldReturn, err := newFunction() + if shouldReturn { + return err + } //@loc(endF, "}") + fmt.Println(a, b) + return nil +} + +func newFunction() ([]byte, []byte, bool, error) { + a, err := json.Marshal(0) //@codeaction("a", "refactor.extract.function", end=endF, result=F) + if err != nil { + return nil, nil, true, fmt.Errorf("1: %w", err) + } + b, err := json.Marshal(0) + if err != nil { + return nil, nil, true, fmt.Errorf("2: %w", err) + } + return a, b, false, nil +} + +-- b.go -- +package a + +import ( + "fmt" + "math/rand" +) + +func G() (x, y int) { + v := rand.Int() //@codeaction("v", "refactor.extract.function", end=endG, result=G) + if v < 0 { + return 1, 2 + } + if v > 0 { + return 3, 4 + } //@loc(endG, "}") + fmt.Println(v) + return 5, 6 +} +-- @G/b.go -- +package a + +import ( + "fmt" + "math/rand" +) + +func G() (x, y int) { + v, shouldReturn, x1, y1 := newFunction() + if shouldReturn { + return x1, y1 + } //@loc(endG, "}") + fmt.Println(v) + return 5, 6 +} + +func newFunction() (int, bool, int, int) { + v := rand.Int() //@codeaction("v", "refactor.extract.function", end=endG, result=G) + if v < 0 { + return 0, true, 1, 2 + } + 
if v > 0 { + return 0, true, 3, 4 + } + return v, false, 0, 0 +} diff --git a/gopls/internal/test/marker/testdata/codeaction/grouplines.txt b/gopls/internal/test/marker/testdata/codeaction/grouplines.txt index 1f14360d2e9..766b13b7f56 100644 --- a/gopls/internal/test/marker/testdata/codeaction/grouplines.txt +++ b/gopls/internal/test/marker/testdata/codeaction/grouplines.txt @@ -12,7 +12,7 @@ package func_arg func A( a string, b, c int64, - x int /*@codeaction("x", "x", "refactor.rewrite.joinLines", func_arg)*/, + x int /*@codeaction("x", "refactor.rewrite.joinLines", result=func_arg)*/, y int, ) (r1 string, r2, r3 int64, r4 int, r5 int) { return a, b, c, x, y @@ -21,7 +21,7 @@ func A( -- @func_arg/func_arg/func_arg.go -- package func_arg -func A(a string, b, c int64, x int /*@codeaction("x", "x", "refactor.rewrite.joinLines", func_arg)*/, y int) (r1 string, r2, r3 int64, r4 int, r5 int) { +func A(a string, b, c int64, x int /*@codeaction("x", "refactor.rewrite.joinLines", result=func_arg)*/, y int) (r1 string, r2, r3 int64, r4 int, r5 int) { return a, b, c, x, y } @@ -29,7 +29,7 @@ func A(a string, b, c int64, x int /*@codeaction("x", "x", "refactor.rewrite.joi package func_ret func A(a string, b, c int64, x int, y int) ( - r1 string /*@codeaction("r1", "r1", "refactor.rewrite.joinLines", func_ret)*/, + r1 string /*@codeaction("r1", "refactor.rewrite.joinLines", result=func_ret)*/, r2, r3 int64, r4 int, r5 int, @@ -40,7 +40,7 @@ func A(a string, b, c int64, x int, y int) ( -- @func_ret/func_ret/func_ret.go -- package func_ret -func A(a string, b, c int64, x int, y int) (r1 string /*@codeaction("r1", "r1", "refactor.rewrite.joinLines", func_ret)*/, r2, r3 int64, r4 int, r5 int) { +func A(a string, b, c int64, x int, y int) (r1 string /*@codeaction("r1", "refactor.rewrite.joinLines", result=func_ret)*/, r2, r3 int64, r4 int, r5 int) { return a, b, c, x, y } @@ -50,20 +50,20 @@ package functype_arg type A func( a string, b, c int64, - x int /*@codeaction("x", "x", 
"refactor.rewrite.joinLines", functype_arg)*/, + x int /*@codeaction("x", "refactor.rewrite.joinLines", result=functype_arg)*/, y int, ) (r1 string, r2, r3 int64, r4 int, r5 int) -- @functype_arg/functype_arg/functype_arg.go -- package functype_arg -type A func(a string, b, c int64, x int /*@codeaction("x", "x", "refactor.rewrite.joinLines", functype_arg)*/, y int) (r1 string, r2, r3 int64, r4 int, r5 int) +type A func(a string, b, c int64, x int /*@codeaction("x", "refactor.rewrite.joinLines", result=functype_arg)*/, y int) (r1 string, r2, r3 int64, r4 int, r5 int) -- functype_ret/functype_ret.go -- package functype_ret type A func(a string, b, c int64, x int, y int) ( - r1 string /*@codeaction("r1", "r1", "refactor.rewrite.joinLines", functype_ret)*/, + r1 string /*@codeaction("r1", "refactor.rewrite.joinLines", result=functype_ret)*/, r2, r3 int64, r4 int, r5 int, @@ -72,7 +72,7 @@ type A func(a string, b, c int64, x int, y int) ( -- @functype_ret/functype_ret/functype_ret.go -- package functype_ret -type A func(a string, b, c int64, x int, y int) (r1 string /*@codeaction("r1", "r1", "refactor.rewrite.joinLines", functype_ret)*/, r2, r3 int64, r4 int, r5 int) +type A func(a string, b, c int64, x int, y int) (r1 string /*@codeaction("r1", "refactor.rewrite.joinLines", result=functype_ret)*/, r2, r3 int64, r4 int, r5 int) -- func_call/func_call.go -- package func_call @@ -81,7 +81,7 @@ import "fmt" func a() { fmt.Println( - 1 /*@codeaction("1", "1", "refactor.rewrite.joinLines", func_call)*/, + 1 /*@codeaction("1", "refactor.rewrite.joinLines", result=func_call)*/, 2, 3, fmt.Sprintf("hello %d", 4), @@ -94,7 +94,7 @@ package func_call import "fmt" func a() { - fmt.Println(1 /*@codeaction("1", "1", "refactor.rewrite.joinLines", func_call)*/, 2, 3, fmt.Sprintf("hello %d", 4)) + fmt.Println(1 /*@codeaction("1", "refactor.rewrite.joinLines", result=func_call)*/, 2, 3, fmt.Sprintf("hello %d", 4)) } -- indent/indent.go -- @@ -108,7 +108,7 @@ func a() { 2, 3, fmt.Sprintf( 
- "hello %d" /*@codeaction("hello", "hello", "refactor.rewrite.joinLines", indent)*/, + "hello %d" /*@codeaction("hello", "refactor.rewrite.joinLines", result=indent)*/, 4, )) } @@ -123,7 +123,7 @@ func a() { 1, 2, 3, - fmt.Sprintf("hello %d" /*@codeaction("hello", "hello", "refactor.rewrite.joinLines", indent)*/, 4)) + fmt.Sprintf("hello %d" /*@codeaction("hello", "refactor.rewrite.joinLines", result=indent)*/, 4)) } -- structelts/structelts.go -- @@ -137,7 +137,7 @@ type A struct{ func a() { _ = A{ a: 1, - b: 2 /*@codeaction("b", "b", "refactor.rewrite.joinLines", structelts)*/, + b: 2 /*@codeaction("b", "refactor.rewrite.joinLines", result=structelts)*/, } } @@ -150,7 +150,7 @@ type A struct{ } func a() { - _ = A{a: 1, b: 2 /*@codeaction("b", "b", "refactor.rewrite.joinLines", structelts)*/} + _ = A{a: 1, b: 2 /*@codeaction("b", "refactor.rewrite.joinLines", result=structelts)*/} } -- sliceelts/sliceelts.go -- @@ -158,7 +158,7 @@ package sliceelts func a() { _ = []int{ - 1 /*@codeaction("1", "1", "refactor.rewrite.joinLines", sliceelts)*/, + 1 /*@codeaction("1", "refactor.rewrite.joinLines", result=sliceelts)*/, 2, } } @@ -167,7 +167,7 @@ func a() { package sliceelts func a() { - _ = []int{1 /*@codeaction("1", "1", "refactor.rewrite.joinLines", sliceelts)*/, 2} + _ = []int{1 /*@codeaction("1", "refactor.rewrite.joinLines", result=sliceelts)*/, 2} } -- mapelts/mapelts.go -- @@ -175,7 +175,7 @@ package mapelts func a() { _ = map[string]int{ - "a": 1 /*@codeaction("1", "1", "refactor.rewrite.joinLines", mapelts)*/, + "a": 1 /*@codeaction("1", "refactor.rewrite.joinLines", result=mapelts)*/, "b": 2, } } @@ -184,14 +184,14 @@ func a() { package mapelts func a() { - _ = map[string]int{"a": 1 /*@codeaction("1", "1", "refactor.rewrite.joinLines", mapelts)*/, "b": 2} + _ = map[string]int{"a": 1 /*@codeaction("1", "refactor.rewrite.joinLines", result=mapelts)*/, "b": 2} } -- starcomment/starcomment.go -- package starcomment func A( - /*1*/ x /*2*/ string /*3*/ 
/*@codeaction("x", "x", "refactor.rewrite.joinLines", starcomment)*/, + /*1*/ x /*2*/ string /*3*/ /*@codeaction("x", "refactor.rewrite.joinLines", result=starcomment)*/, /*4*/ y /*5*/ int /*6*/, ) (string, int) { return x, y @@ -200,7 +200,7 @@ func A( -- @starcomment/starcomment/starcomment.go -- package starcomment -func A(/*1*/ x /*2*/ string /*3*/ /*@codeaction("x", "x", "refactor.rewrite.joinLines", starcomment)*/, /*4*/ y /*5*/ int /*6*/) (string, int) { +func A(/*1*/ x /*2*/ string /*3*/ /*@codeaction("x", "refactor.rewrite.joinLines", result=starcomment)*/, /*4*/ y /*5*/ int /*6*/) (string, int) { return x, y } diff --git a/gopls/internal/test/marker/testdata/codeaction/import-shadows-builtin.txt b/gopls/internal/test/marker/testdata/codeaction/import-shadows-builtin.txt index aeb86a22686..da125d8a534 100644 --- a/gopls/internal/test/marker/testdata/codeaction/import-shadows-builtin.txt +++ b/gopls/internal/test/marker/testdata/codeaction/import-shadows-builtin.txt @@ -30,7 +30,7 @@ var V int -- main.go -- package main -import () //@codeaction("import", "", "source.organizeImports", out) +import () //@codeaction("import", "source.organizeImports", result=out) func main() { complex128.V() //@diag("V", re"type complex128 has no field") @@ -43,7 +43,7 @@ func _() { -- @out/main.go -- package main -import "example.com/complex127" //@codeaction("import", "", "source.organizeImports", out) +import "example.com/complex127" //@codeaction("import", "source.organizeImports", result=out) func main() { complex128.V() //@diag("V", re"type complex128 has no field") diff --git a/gopls/internal/test/marker/testdata/codeaction/imports.txt b/gopls/internal/test/marker/testdata/codeaction/imports.txt index 3d058fb36a1..ce365bd611f 100644 --- a/gopls/internal/test/marker/testdata/codeaction/imports.txt +++ b/gopls/internal/test/marker/testdata/codeaction/imports.txt @@ -6,7 +6,7 @@ module mod.test/imports go 1.18 -- add.go -- -package imports //@codeaction("imports", "", 
"source.organizeImports", add) +package imports //@codeaction("imports", "source.organizeImports", result=add) import ( "fmt" @@ -18,7 +18,7 @@ func _() { } -- @add/add.go -- -package imports //@codeaction("imports", "", "source.organizeImports", add) +package imports //@codeaction("imports", "source.organizeImports", result=add) import ( "bytes" @@ -31,7 +31,7 @@ func _() { } -- good.go -- -package imports //@codeactionerr("imports", "", "source.organizeImports", re"found 0 CodeActions") +package imports //@codeaction("imports", "source.organizeImports", err=re"found 0 CodeActions") import "fmt" @@ -46,7 +46,7 @@ fmt.Println("") // package doc -package imports //@codeaction("imports", "", "source.organizeImports", issue35458) +package imports //@codeaction("imports", "source.organizeImports", result=issue35458) @@ -66,7 +66,7 @@ func _() { -- @issue35458/issue35458.go -- // package doc -package imports //@codeaction("imports", "", "source.organizeImports", issue35458) +package imports //@codeaction("imports", "source.organizeImports", result=issue35458) @@ -85,7 +85,7 @@ func _() { -- multi.go -- -package imports //@codeaction("imports", "", "source.organizeImports", multi) +package imports //@codeaction("imports", "source.organizeImports", result=multi) import "fmt" @@ -96,7 +96,7 @@ func _() { } -- @multi/multi.go -- -package imports //@codeaction("imports", "", "source.organizeImports", multi) +package imports //@codeaction("imports", "source.organizeImports", result=multi) import "fmt" @@ -107,7 +107,7 @@ func _() { } -- needs.go -- -package imports //@codeaction("package", "", "source.organizeImports", needs) +package imports //@codeaction("package", "source.organizeImports", result=needs) func goodbye() { fmt.Printf("HI") //@diag("fmt", re"(undeclared|undefined)") @@ -115,7 +115,7 @@ func goodbye() { } -- @needs/needs.go -- -package imports //@codeaction("package", "", "source.organizeImports", needs) +package imports //@codeaction("package", 
"source.organizeImports", result=needs) import ( "fmt" @@ -128,7 +128,7 @@ func goodbye() { } -- remove.go -- -package imports //@codeaction("package", "", "source.organizeImports", remove) +package imports //@codeaction("package", "source.organizeImports", result=remove) import ( "bytes" //@diag("\"bytes\"", re"not used") @@ -140,7 +140,7 @@ func _() { } -- @remove/remove.go -- -package imports //@codeaction("package", "", "source.organizeImports", remove) +package imports //@codeaction("package", "source.organizeImports", result=remove) import ( "fmt" @@ -151,7 +151,7 @@ func _() { } -- removeall.go -- -package imports //@codeaction("package", "", "source.organizeImports", removeall) +package imports //@codeaction("package", "source.organizeImports", result=removeall) import ( "bytes" //@diag("\"bytes\"", re"not used") @@ -163,7 +163,7 @@ func _() { } -- @removeall/removeall.go -- -package imports //@codeaction("package", "", "source.organizeImports", removeall) +package imports //@codeaction("package", "source.organizeImports", result=removeall) //@diag("\"fmt\"", re"not used") @@ -172,4 +172,4 @@ func _() { -- twolines.go -- package imports -func main() {} //@codeactionerr("main", "", "source.organizeImports", re"found 0") +func main() {} //@codeaction("main", "source.organizeImports", err=re"found 0") diff --git a/gopls/internal/test/marker/testdata/codeaction/inline.txt b/gopls/internal/test/marker/testdata/codeaction/inline.txt index 050fe25b8ec..4c2bf15c207 100644 --- a/gopls/internal/test/marker/testdata/codeaction/inline.txt +++ b/gopls/internal/test/marker/testdata/codeaction/inline.txt @@ -9,7 +9,7 @@ go 1.18 package a func _() { - println(add(1, 2)) //@codeaction("add", ")", "refactor.inline.call", inline) + println(add(1, 2)) //@codeaction("add", "refactor.inline.call", end=")", result=inline) } func add(x, y int) int { return x + y } @@ -18,7 +18,7 @@ func add(x, y int) int { return x + y } package a func _() { - println(1 + 2) //@codeaction("add", 
")", "refactor.inline.call", inline) + println(1 + 2) //@codeaction("add", "refactor.inline.call", end=")", result=inline) } func add(x, y int) int { return x + y } diff --git a/gopls/internal/test/marker/testdata/codeaction/inline_issue67336.txt b/gopls/internal/test/marker/testdata/codeaction/inline_issue67336.txt new file mode 100644 index 00000000000..daae6e41144 --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/inline_issue67336.txt @@ -0,0 +1,72 @@ +This is the test case from golang/go#67335, where the inlining resulted in bad +formatting. + +-- go.mod -- +module example.com + +go 1.20 + +-- define/my/typ/foo.go -- +package typ +type T int + +-- some/other/pkg/foo.go -- +package pkg +import "context" +import "example.com/define/my/typ" +func Foo(typ.T) context.Context{ return nil } + +-- one/more/pkg/foo.go -- +package pkg +func Bar() {} + +-- to/be/inlined/foo.go -- +package inlined + +import "context" +import "example.com/some/other/pkg" +import "example.com/define/my/typ" + +func Baz(ctx context.Context) context.Context { + return pkg.Foo(typ.T(5)) +} + +-- b/c/foo.go -- +package c +import ( + "context" + "example.com/to/be/inlined" + "example.com/one/more/pkg" +) + +const ( + // This is a variable + someConst = 5 +) + +func foo() { + inlined.Baz(context.TODO()) //@ codeaction("Baz", "refactor.inline.call", result=inline) + pkg.Bar() +} + +-- @inline/b/c/foo.go -- +package c + +import ( + "context" + + "example.com/define/my/typ" + "example.com/one/more/pkg" + pkg0 "example.com/some/other/pkg" +) + +const ( + // This is a variable + someConst = 5 +) + +func foo() { + var _ context.Context = context.TODO() + pkg0.Foo(typ.T(5)) //@ codeaction("Baz", "refactor.inline.call", result=inline) + pkg.Bar() +} diff --git a/gopls/internal/test/marker/testdata/codeaction/inline_issue68554.txt b/gopls/internal/test/marker/testdata/codeaction/inline_issue68554.txt new file mode 100644 index 00000000000..49b18b27935 --- /dev/null +++ 
b/gopls/internal/test/marker/testdata/codeaction/inline_issue68554.txt @@ -0,0 +1,38 @@ +This test checks that inlining removes unnecessary interface conversions. + +-- main.go -- +package main + +import ( + "fmt" + "io" +) + +func f(d discard) { + g(d) //@codeaction("g", "refactor.inline.call", result=out) +} + +func g(w io.Writer) { fmt.Println(w) } + +var d discard +type discard struct{} +func (discard) Write(p []byte) (int, error) { return len(p), nil } +-- @out/main.go -- +package main + +import ( + "fmt" + "io" +) + +func f(d discard) { + fmt.Println(d) //@codeaction("g", "refactor.inline.call", result=out) +} + +func g(w io.Writer) { fmt.Println(w) } + +var d discard + +type discard struct{} + +func (discard) Write(p []byte) (int, error) { return len(p), nil } diff --git a/gopls/internal/test/marker/testdata/codeaction/inline_resolve.txt b/gopls/internal/test/marker/testdata/codeaction/inline_resolve.txt index fa8476e91f6..c889ed8bba3 100644 --- a/gopls/internal/test/marker/testdata/codeaction/inline_resolve.txt +++ b/gopls/internal/test/marker/testdata/codeaction/inline_resolve.txt @@ -20,7 +20,7 @@ go 1.18 package a func _() { - println(add(1, 2)) //@codeaction("add", ")", "refactor.inline.call", inline) + println(add(1, 2)) //@codeaction("add", "refactor.inline.call", end=")", result=inline) } func add(x, y int) int { return x + y } @@ -29,7 +29,7 @@ func add(x, y int) int { return x + y } package a func _() { - println(1 + 2) //@codeaction("add", ")", "refactor.inline.call", inline) + println(1 + 2) //@codeaction("add", "refactor.inline.call", end=")", result=inline) } func add(x, y int) int { return x + y } diff --git a/gopls/internal/test/marker/testdata/codeaction/invertif.txt b/gopls/internal/test/marker/testdata/codeaction/invertif.txt index 02f856f6977..6838d94b333 100644 --- a/gopls/internal/test/marker/testdata/codeaction/invertif.txt +++ b/gopls/internal/test/marker/testdata/codeaction/invertif.txt @@ -10,7 +10,7 @@ import ( func Boolean() { b 
:= true - if b { //@codeactionedit("if b", "refactor.rewrite.invertIf", boolean) + if b { //@codeaction("if b", "refactor.rewrite.invertIf", edit=boolean) fmt.Println("A") } else { fmt.Println("B") @@ -18,7 +18,7 @@ func Boolean() { } func BooleanFn() { - if os.IsPathSeparator('X') { //@codeactionedit("if os.IsPathSeparator('X')", "refactor.rewrite.invertIf", boolean_fn) + if os.IsPathSeparator('X') { //@codeaction("if os.IsPathSeparator('X')", "refactor.rewrite.invertIf", edit=boolean_fn) fmt.Println("A") } else { fmt.Println("B") @@ -30,7 +30,7 @@ func DontRemoveParens() { a := false b := true if !(a || - b) { //@codeactionedit("b", "refactor.rewrite.invertIf", dont_remove_parens) + b) { //@codeaction("b", "refactor.rewrite.invertIf", edit=dont_remove_parens) fmt.Println("A") } else { fmt.Println("B") @@ -46,7 +46,7 @@ func ElseIf() { // No inversion expected for else-if, that would become unreadable if len(os.Args) > 2 { fmt.Println("A") - } else if os.Args[0] == "X" { //@codeactionedit(re"if os.Args.0. == .X.", "refactor.rewrite.invertIf", else_if) + } else if os.Args[0] == "X" { //@codeaction(re"if os.Args.0. 
== .X.", "refactor.rewrite.invertIf", edit=else_if) fmt.Println("B") } else { fmt.Println("C") @@ -54,7 +54,7 @@ func ElseIf() { } func GreaterThan() { - if len(os.Args) > 2 { //@codeactionedit("i", "refactor.rewrite.invertIf", greater_than) + if len(os.Args) > 2 { //@codeaction("i", "refactor.rewrite.invertIf", edit=greater_than) fmt.Println("A") } else { fmt.Println("B") @@ -63,7 +63,7 @@ func GreaterThan() { func NotBoolean() { b := true - if !b { //@codeactionedit("if !b", "refactor.rewrite.invertIf", not_boolean) + if !b { //@codeaction("if !b", "refactor.rewrite.invertIf", edit=not_boolean) fmt.Println("A") } else { fmt.Println("B") @@ -71,7 +71,7 @@ func NotBoolean() { } func RemoveElse() { - if true { //@codeactionedit("if true", "refactor.rewrite.invertIf", remove_else) + if true { //@codeaction("if true", "refactor.rewrite.invertIf", edit=remove_else) fmt.Println("A") } else { fmt.Println("B") @@ -83,7 +83,7 @@ func RemoveElse() { func RemoveParens() { b := true - if !(b) { //@codeactionedit("if", "refactor.rewrite.invertIf", remove_parens) + if !(b) { //@codeaction("if", "refactor.rewrite.invertIf", edit=remove_parens) fmt.Println("A") } else { fmt.Println("B") @@ -91,7 +91,7 @@ func RemoveParens() { } func Semicolon() { - if _, err := fmt.Println("x"); err != nil { //@codeactionedit("if", "refactor.rewrite.invertIf", semicolon) + if _, err := fmt.Println("x"); err != nil { //@codeaction("if", "refactor.rewrite.invertIf", edit=semicolon) fmt.Println("A") } else { fmt.Println("B") @@ -99,7 +99,7 @@ func Semicolon() { } func SemicolonAnd() { - if n, err := fmt.Println("x"); err != nil && n > 0 { //@codeactionedit("f", "refactor.rewrite.invertIf", semicolon_and) + if n, err := fmt.Println("x"); err != nil && n > 0 { //@codeaction("f", "refactor.rewrite.invertIf", edit=semicolon_and) fmt.Println("A") } else { fmt.Println("B") @@ -107,7 +107,7 @@ func SemicolonAnd() { } func SemicolonOr() { - if n, err := fmt.Println("x"); err != nil || n < 5 { 
//@codeactionedit(re"if n, err := fmt.Println..x..; err != nil .. n < 5", "refactor.rewrite.invertIf", semicolon_or) + if n, err := fmt.Println("x"); err != nil || n < 5 { //@codeaction(re"if n, err := fmt.Println..x..; err != nil .. n < 5", "refactor.rewrite.invertIf", edit=semicolon_or) fmt.Println("A") } else { fmt.Println("B") @@ -116,103 +116,103 @@ func SemicolonOr() { -- @boolean/p.go -- @@ -10,3 +10 @@ -- if b { //@codeactionedit("if b", "refactor.rewrite.invertIf", boolean) +- if b { //@codeaction("if b", "refactor.rewrite.invertIf", edit=boolean) - fmt.Println("A") - } else { + if !b { @@ -14 +12,2 @@ -+ } else { //@codeactionedit("if b", "refactor.rewrite.invertIf", boolean) ++ } else { //@codeaction("if b", "refactor.rewrite.invertIf", edit=boolean) + fmt.Println("A") -- @boolean_fn/p.go -- @@ -18,3 +18 @@ -- if os.IsPathSeparator('X') { //@codeactionedit("if os.IsPathSeparator('X')", "refactor.rewrite.invertIf", boolean_fn) +- if os.IsPathSeparator('X') { //@codeaction("if os.IsPathSeparator('X')", "refactor.rewrite.invertIf", edit=boolean_fn) - fmt.Println("A") - } else { + if !os.IsPathSeparator('X') { @@ -22 +20,2 @@ -+ } else { //@codeactionedit("if os.IsPathSeparator('X')", "refactor.rewrite.invertIf", boolean_fn) ++ } else { //@codeaction("if os.IsPathSeparator('X')", "refactor.rewrite.invertIf", edit=boolean_fn) + fmt.Println("A") -- @dont_remove_parens/p.go -- @@ -29,4 +29,2 @@ - if !(a || -- b) { //@codeactionedit("b", "refactor.rewrite.invertIf", dont_remove_parens) +- b) { //@codeaction("b", "refactor.rewrite.invertIf", edit=dont_remove_parens) - fmt.Println("A") - } else { + if (a || + b) { @@ -34 +32,2 @@ -+ } else { //@codeactionedit("b", "refactor.rewrite.invertIf", dont_remove_parens) ++ } else { //@codeaction("b", "refactor.rewrite.invertIf", edit=dont_remove_parens) + fmt.Println("A") -- @else_if/p.go -- @@ -46,3 +46 @@ -- } else if os.Args[0] == "X" { //@codeactionedit(re"if os.Args.0. 
== .X.", "refactor.rewrite.invertIf", else_if) +- } else if os.Args[0] == "X" { //@codeaction(re"if os.Args.0. == .X.", "refactor.rewrite.invertIf", edit=else_if) - fmt.Println("B") - } else { + } else if os.Args[0] != "X" { @@ -50 +48,2 @@ -+ } else { //@codeactionedit(re"if os.Args.0. == .X.", "refactor.rewrite.invertIf", else_if) ++ } else { //@codeaction(re"if os.Args.0. == .X.", "refactor.rewrite.invertIf", edit=else_if) + fmt.Println("B") -- @greater_than/p.go -- @@ -54,3 +54 @@ -- if len(os.Args) > 2 { //@codeactionedit("i", "refactor.rewrite.invertIf", greater_than) +- if len(os.Args) > 2 { //@codeaction("i", "refactor.rewrite.invertIf", edit=greater_than) - fmt.Println("A") - } else { + if len(os.Args) <= 2 { @@ -58 +56,2 @@ -+ } else { //@codeactionedit("i", "refactor.rewrite.invertIf", greater_than) ++ } else { //@codeaction("i", "refactor.rewrite.invertIf", edit=greater_than) + fmt.Println("A") -- @not_boolean/p.go -- @@ -63,3 +63 @@ -- if !b { //@codeactionedit("if !b", "refactor.rewrite.invertIf", not_boolean) +- if !b { //@codeaction("if !b", "refactor.rewrite.invertIf", edit=not_boolean) - fmt.Println("A") - } else { + if b { @@ -67 +65,2 @@ -+ } else { //@codeactionedit("if !b", "refactor.rewrite.invertIf", not_boolean) ++ } else { //@codeaction("if !b", "refactor.rewrite.invertIf", edit=not_boolean) + fmt.Println("A") -- @remove_else/p.go -- @@ -71,3 +71 @@ -- if true { //@codeactionedit("if true", "refactor.rewrite.invertIf", remove_else) +- if true { //@codeaction("if true", "refactor.rewrite.invertIf", edit=remove_else) - fmt.Println("A") - } else { + if false { @@ -78 +76,3 @@ -+ //@codeactionedit("if true", "refactor.rewrite.invertIf", remove_else) ++ //@codeaction("if true", "refactor.rewrite.invertIf", edit=remove_else) + fmt.Println("A") + -- @remove_parens/p.go -- @@ -83,3 +83 @@ -- if !(b) { //@codeactionedit("if", "refactor.rewrite.invertIf", remove_parens) +- if !(b) { //@codeaction("if", "refactor.rewrite.invertIf", 
edit=remove_parens) - fmt.Println("A") - } else { + if b { @@ -87 +85,2 @@ -+ } else { //@codeactionedit("if", "refactor.rewrite.invertIf", remove_parens) ++ } else { //@codeaction("if", "refactor.rewrite.invertIf", edit=remove_parens) + fmt.Println("A") -- @semicolon/p.go -- @@ -91,3 +91 @@ -- if _, err := fmt.Println("x"); err != nil { //@codeactionedit("if", "refactor.rewrite.invertIf", semicolon) +- if _, err := fmt.Println("x"); err != nil { //@codeaction("if", "refactor.rewrite.invertIf", edit=semicolon) - fmt.Println("A") - } else { + if _, err := fmt.Println("x"); err == nil { @@ -95 +93,2 @@ -+ } else { //@codeactionedit("if", "refactor.rewrite.invertIf", semicolon) ++ } else { //@codeaction("if", "refactor.rewrite.invertIf", edit=semicolon) + fmt.Println("A") -- @semicolon_and/p.go -- @@ -99,3 +99 @@ -- if n, err := fmt.Println("x"); err != nil && n > 0 { //@codeactionedit("f", "refactor.rewrite.invertIf", semicolon_and) +- if n, err := fmt.Println("x"); err != nil && n > 0 { //@codeaction("f", "refactor.rewrite.invertIf", edit=semicolon_and) - fmt.Println("A") - } else { + if n, err := fmt.Println("x"); err == nil || n <= 0 { @@ -103 +101,2 @@ -+ } else { //@codeactionedit("f", "refactor.rewrite.invertIf", semicolon_and) ++ } else { //@codeaction("f", "refactor.rewrite.invertIf", edit=semicolon_and) + fmt.Println("A") -- @semicolon_or/p.go -- @@ -107,3 +107 @@ -- if n, err := fmt.Println("x"); err != nil || n < 5 { //@codeactionedit(re"if n, err := fmt.Println..x..; err != nil .. n < 5", "refactor.rewrite.invertIf", semicolon_or) +- if n, err := fmt.Println("x"); err != nil || n < 5 { //@codeaction(re"if n, err := fmt.Println..x..; err != nil .. n < 5", "refactor.rewrite.invertIf", edit=semicolon_or) - fmt.Println("A") - } else { + if n, err := fmt.Println("x"); err == nil && n >= 5 { @@ -111 +109,2 @@ -+ } else { //@codeactionedit(re"if n, err := fmt.Println..x..; err != nil .. 
n < 5", "refactor.rewrite.invertIf", semicolon_or) ++ } else { //@codeaction(re"if n, err := fmt.Println..x..; err != nil .. n < 5", "refactor.rewrite.invertIf", edit=semicolon_or) + fmt.Println("A") diff --git a/gopls/internal/test/marker/testdata/codeaction/issue64558.txt b/gopls/internal/test/marker/testdata/codeaction/issue64558.txt index 7ca661fbf00..a5a6594e74a 100644 --- a/gopls/internal/test/marker/testdata/codeaction/issue64558.txt +++ b/gopls/internal/test/marker/testdata/codeaction/issue64558.txt @@ -8,7 +8,7 @@ go 1.18 package a func _() { - f(1, 2) //@ diag("2", re"too many arguments"), codeactionerr("f", ")", "refactor.inline.call", re`inlining failed \("args/params mismatch"\), likely because inputs were ill-typed`) + f(1, 2) //@ diag("2", re"too many arguments"), codeaction("f", "refactor.inline.call", end=")", err=re`inlining failed \("too many arguments"\), likely because inputs were ill-typed`) } func f(int) {} diff --git a/gopls/internal/test/marker/testdata/codeaction/issue70268.txt b/gopls/internal/test/marker/testdata/codeaction/issue70268.txt new file mode 100644 index 00000000000..464f0eb01d8 --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/issue70268.txt @@ -0,0 +1,33 @@ +This test verifies the removal of unused parameters in case of syntax errors. +Issue golang/go#70268.
+ +-- go.mod -- +module unused.mod + +go 1.21 + +-- a/a.go -- +package a + +func A(x, unused int) int { //@codeaction("unused", "refactor.rewrite.removeUnusedParam", result=a) + return x +} + +-- @a/a/a.go -- +package a + +func A(x int) int { //@codeaction("unused", "refactor.rewrite.removeUnusedParam", result=a) + return x +} + +-- b/b.go -- +package b + +import "unused.mod/a" + +func main(){ + a.A/*dsdd*/(/*cccc*/ 1, + + + ) //@diag(")", re"not enough arguments") +} diff --git a/gopls/internal/test/marker/testdata/codeaction/moveparam.txt b/gopls/internal/test/marker/testdata/codeaction/moveparam.txt new file mode 100644 index 00000000000..2cc0cd8244f --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/moveparam.txt @@ -0,0 +1,178 @@ +This test checks basic functionality of the "move parameter left/right" code +action. + +Note that in many of these tests, a permutation can either be expressed as +a parameter move left or right. In these cases, the codeaction assertions +deliberately share the same golden data. + +-- go.mod -- +module example.com/moveparam + +go 1.19 + +-- basic/basic.go -- +package basic + +func Foo(a, b int) int { //@codeaction("a", "refactor.rewrite.moveParamRight", result=basic), codeaction("b", "refactor.rewrite.moveParamLeft", result=basic) + return a + b +} + +func _() { + x, y := 1, 2 + z := Foo(x, y) + _ = z +} + +-- basic/caller/caller.go -- +package caller + +import "example.com/moveparam/basic" + +func a() int { return 1 } +func b() int { return 2 } + +// Check that we can refactor a call in a toplevel var decl. +var _ = basic.Foo(1, 2) + +// Check that we can refactor a call with effects in a toplevel var decl. +var _ = basic.Foo(a(), b()) + +func _() { + // check various refactorings in a function body, and comment handling. 
+ _ = basic.Foo(1, 2) // with comments + // another comment + _ = basic.Foo(3, 4) + x := 4 + x = basic.Foo(x /* this is an inline comment */, 5) +} + +-- @basic/basic/basic.go -- +package basic + +func Foo(b, a int) int { //@codeaction("a", "refactor.rewrite.moveParamRight", result=basic), codeaction("b", "refactor.rewrite.moveParamLeft", result=basic) + return a + b +} + +func _() { + x, y := 1, 2 + z := Foo(y, x) + _ = z +} +-- @basic/basic/caller/caller.go -- +package caller + +import "example.com/moveparam/basic" + +func a() int { return 1 } +func b() int { return 2 } + +// Check that we can refactor a call in a toplevel var decl. +var _ = basic.Foo(2, 1) + +// Check that we can refactor a call with effects in a toplevel var decl. +var _ = basic.Foo(b(), a()) + +func _() { + // check various refactorings in a function body, and comment handling. + _ = basic.Foo(2, 1) // with comments + // another comment + _ = basic.Foo(4, 3) + x := 4 + x = basic.Foo(5, x) +} +-- method/method.go -- +package method + +type T struct{} + +func (T) Foo(a, b int) {} //@codeaction("a", "refactor.rewrite.moveParamRight", result=method), codeaction("b", "refactor.rewrite.moveParamLeft", result=method) + +func _() { + var t T + t.Foo(1, 2) + // TODO(rfindley): test method expressions here, once they are handled. +} + +-- method/caller/caller.go -- +package caller + +import "example.com/moveparam/method" + +func _() { + var t method.T + t.Foo(1, 2) +} + +-- @method/method/caller/caller.go -- +package caller + +import "example.com/moveparam/method" + +func _() { + var t method.T + t.Foo(2, 1) +} +-- @method/method/method.go -- +package method + +type T struct{} + +func (T) Foo(b, a int) {} //@codeaction("a", "refactor.rewrite.moveParamRight", result=method), codeaction("b", "refactor.rewrite.moveParamLeft", result=method) + +func _() { + var t T + t.Foo(2, 1) + // TODO(rfindley): test method expressions here, once they are handled. 
+} +-- fieldlist/joinfield.go -- +package fieldlist + +func JoinField(a int, b string, c int) {} //@codeaction("a", "refactor.rewrite.moveParamRight", result=joinfield), codeaction("b", "refactor.rewrite.moveParamLeft", result=joinfield) + +func _() { + JoinField(1, "2", 3) +} + +-- @joinfield/fieldlist/joinfield.go -- +package fieldlist + +func JoinField(b string, a, c int) {} //@codeaction("a", "refactor.rewrite.moveParamRight", result=joinfield), codeaction("b", "refactor.rewrite.moveParamLeft", result=joinfield) + +func _() { + JoinField("2", 1, 3) +} +-- fieldlist/splitfield.go -- +package fieldlist + +func SplitField(a int, b, c string) {} //@codeaction("a", "refactor.rewrite.moveParamRight", result=splitfield), codeaction("b", "refactor.rewrite.moveParamLeft", result=splitfield) + +func _() { + SplitField(1, "2", "3") +} + +-- @splitfield/fieldlist/splitfield.go -- +package fieldlist + +func SplitField(b string, a int, c string) {} //@codeaction("a", "refactor.rewrite.moveParamRight", result=splitfield), codeaction("b", "refactor.rewrite.moveParamLeft", result=splitfield) + +func _() { + SplitField("2", 1, "3") +} +-- unnamed/unnamed.go -- +package unnamed + +func Unnamed(int, string) { //@codeaction("int", "refactor.rewrite.moveParamRight", result=unnamed) +} + +func _() { + Unnamed(1, "hi") +} +-- @unnamed/unnamed/unnamed.go -- +package unnamed + +func Unnamed(string, int) { //@codeaction("int", "refactor.rewrite.moveParamRight", result=unnamed) +} + +func _() { + Unnamed("hi", 1) +} diff --git a/gopls/internal/test/marker/testdata/codeaction/moveparam_issue70599.txt b/gopls/internal/test/marker/testdata/codeaction/moveparam_issue70599.txt new file mode 100644 index 00000000000..71510c7bb64 --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/moveparam_issue70599.txt @@ -0,0 +1,99 @@ +This test checks the fixes for bugs encountered while bug-bashing on the +movement refactoring. 
+ +-- go.mod -- +module example.com + +go 1.21 + +-- unnecessaryconversion.go -- +package a + +// We should not add unnecessary conversions to concrete arguments to concrete +// parameters when the parameter use is in assignment context. + +type Hash [32]byte + +func Cache(key [32]byte, value any) { //@codeaction("key", "refactor.rewrite.moveParamRight", result=conversion) + // Not implemented. +} + +func _() { + var k Hash + Cache(k, 0) + Cache(Hash{}, 1) + Cache([32]byte{}, 2) +} + +-- @conversion/unnecessaryconversion.go -- +package a + +// We should not add unnecessary conversions to concrete arguments to concrete +// parameters when the parameter use is in assignment context. + +type Hash [32]byte + +func Cache(value any, key [32]byte) { //@codeaction("key", "refactor.rewrite.moveParamRight", result=conversion) + // Not implemented. +} + +func _() { + var k Hash + Cache(0, k) + Cache(1, Hash{}) + Cache(2, [32]byte{}) +} +-- shortvardecl.go -- +package a + +func Short(x, y int) (int, int) { //@codeaction("x", "refactor.rewrite.moveParamRight", result=short) + return x, y +} + +func _() { + x, y := Short(0, 1) + _, _ = x, y +} + +func _() { + var x, y int + x, y = Short(0, 1) + _, _ = x, y +} + +func _() { + _, _ = Short(0, 1) +} +-- @short/shortvardecl.go -- +package a + +func Short(y, x int) (int, int) { //@codeaction("x", "refactor.rewrite.moveParamRight", result=short) + return x, y +} + +func _() { + x, y := Short(1, 0) + _, _ = x, y +} + +func _() { + var x, y int + x, y = Short(1, 0) + _, _ = x, y +} + +func _() { + _, _ = Short(1, 0) +} +-- variadic.go -- +package a + +// We should not offer movement involving variadic parameters if it is not well +// supported. 
+ +func Variadic(x int, y ...string) { //@codeaction("x", "refactor.rewrite.moveParamRight", err="0 CodeActions"), codeaction("y", "refactor.rewrite.moveParamLeft", err="0 CodeActions") +} + +func _() { + Variadic(1, "a", "b") +} diff --git a/gopls/internal/test/marker/testdata/codeaction/removeparam.txt b/gopls/internal/test/marker/testdata/codeaction/removeparam.txt index 2b78b882df6..c8fddb0fff7 100644 --- a/gopls/internal/test/marker/testdata/codeaction/removeparam.txt +++ b/gopls/internal/test/marker/testdata/codeaction/removeparam.txt @@ -9,14 +9,14 @@ go 1.18 -- a/a.go -- package a -func A(x, unused int) int { //@codeaction("unused", "unused", "refactor.rewrite.removeUnusedParam", a) +func A(x, unused int) int { //@codeaction("unused", "refactor.rewrite.removeUnusedParam", result=a) return x } -- @a/a/a.go -- package a -func A(x int) int { //@codeaction("unused", "unused", "refactor.rewrite.removeUnusedParam", a) +func A(x int) int { //@codeaction("unused", "refactor.rewrite.removeUnusedParam", result=a) return x } @@ -99,7 +99,7 @@ func _() { -- field/field.go -- package field -func Field(x int, field int) { //@codeaction("int", "int", "refactor.rewrite.removeUnusedParam", field) +func Field(x int, field int) { //@codeaction("int", "refactor.rewrite.removeUnusedParam", result=field) } func _() { @@ -108,7 +108,7 @@ func _() { -- @field/field/field.go -- package field -func Field(field int) { //@codeaction("int", "int", "refactor.rewrite.removeUnusedParam", field) +func Field(field int) { //@codeaction("int", "refactor.rewrite.removeUnusedParam", result=field) } func _() { @@ -117,7 +117,7 @@ func _() { -- ellipsis/ellipsis.go -- package ellipsis -func Ellipsis(...any) { //@codeaction("any", "any", "refactor.rewrite.removeUnusedParam", ellipsis) +func Ellipsis(...any) { //@codeaction("any", "refactor.rewrite.removeUnusedParam", result=ellipsis) } func _() { @@ -138,7 +138,7 @@ func i() []any -- @ellipsis/ellipsis/ellipsis.go -- package ellipsis -func 
Ellipsis() { //@codeaction("any", "any", "refactor.rewrite.removeUnusedParam", ellipsis) +func Ellipsis() { //@codeaction("any", "refactor.rewrite.removeUnusedParam", result=ellipsis) } func _() { @@ -146,12 +146,10 @@ func _() { Ellipsis() Ellipsis() Ellipsis() - var _ []any = []any{1, f(), g()} Ellipsis() func(_ ...any) { Ellipsis() }(h()) - var _ []any = i() Ellipsis() } @@ -162,7 +160,7 @@ func i() []any -- ellipsis2/ellipsis2.go -- package ellipsis2 -func Ellipsis2(_, _ int, rest ...int) { //@codeaction("_", "_", "refactor.rewrite.removeUnusedParam", ellipsis2) +func Ellipsis2(_, _ int, rest ...int) { //@codeaction("_", "refactor.rewrite.removeUnusedParam", result=ellipsis2) } func _() { @@ -176,11 +174,11 @@ func h() (int, int) -- @ellipsis2/ellipsis2/ellipsis2.go -- package ellipsis2 -func Ellipsis2(_ int, rest ...int) { //@codeaction("_", "_", "refactor.rewrite.removeUnusedParam", ellipsis2) +func Ellipsis2(_ int, rest ...int) { //@codeaction("_", "refactor.rewrite.removeUnusedParam", result=ellipsis2) } func _() { - Ellipsis2(2, []int{3}...) + Ellipsis2(2, 3) func(_, blank0 int, rest ...int) { Ellipsis2(blank0, rest...) 
}(h()) @@ -191,7 +189,7 @@ func h() (int, int) -- overlapping/overlapping.go -- package overlapping -func Overlapping(i int) int { //@codeactionerr(re"(i) int", re"(i) int", "refactor.rewrite.removeUnusedParam", re"overlapping") +func Overlapping(i int) int { //@codeaction(re"(i) int", "refactor.rewrite.removeUnusedParam", err=re"overlapping") return 0 } @@ -203,7 +201,7 @@ func _() { -- effects/effects.go -- package effects -func effects(x, y int) int { //@ diag("y", re"unused"), codeaction("y", "y", "refactor.rewrite.removeUnusedParam", effects) +func effects(x, y int) int { //@ diag("y", re"unused"), codeaction("y", "refactor.rewrite.removeUnusedParam", result=effects) return x } @@ -217,7 +215,7 @@ func _() { -- @effects/effects/effects.go -- package effects -func effects(x int) int { //@ diag("y", re"unused"), codeaction("y", "y", "refactor.rewrite.removeUnusedParam", effects) +func effects(x int) int { //@ diag("y", re"unused"), codeaction("y", "refactor.rewrite.removeUnusedParam", result=effects) return x } @@ -225,23 +223,19 @@ func f() int func g() int func _() { - var x, _ int = f(), g() - effects(x) - { - var x, _ int = f(), g() - effects(x) - } + effects(f()) + effects(f()) } -- recursive/recursive.go -- package recursive -func Recursive(x int) int { //@codeaction("x", "x", "refactor.rewrite.removeUnusedParam", recursive) +func Recursive(x int) int { //@codeaction("x", "refactor.rewrite.removeUnusedParam", result=recursive) return Recursive(1) } -- @recursive/recursive/recursive.go -- package recursive -func Recursive() int { //@codeaction("x", "x", "refactor.rewrite.removeUnusedParam", recursive) +func Recursive() int { //@codeaction("x", "refactor.rewrite.removeUnusedParam", result=recursive) return Recursive() } diff --git a/gopls/internal/test/marker/testdata/codeaction/removeparam_formatting.txt b/gopls/internal/test/marker/testdata/codeaction/removeparam_formatting.txt index b192d79b584..084797e1b33 100644 --- 
a/gopls/internal/test/marker/testdata/codeaction/removeparam_formatting.txt +++ b/gopls/internal/test/marker/testdata/codeaction/removeparam_formatting.txt @@ -14,7 +14,7 @@ go 1.18 package a // A doc comment. -func A(x /* used parameter */, unused int /* unused parameter */ ) int { //@codeaction("unused", "unused", "refactor.rewrite.removeUnusedParam", a) +func A(x /* used parameter */, unused int /* unused parameter */ ) int { //@codeaction("unused", "refactor.rewrite.removeUnusedParam", result=a) // about to return return x // returning // just returned @@ -36,7 +36,7 @@ func one() int { package a // A doc comment. -func A(x int) int { //@codeaction("unused", "unused", "refactor.rewrite.removeUnusedParam", a) +func A(x int) int { //@codeaction("unused", "refactor.rewrite.removeUnusedParam", result=a) // about to return return x // returning // just returned diff --git a/gopls/internal/test/marker/testdata/codeaction/removeparam_funcvalue.txt b/gopls/internal/test/marker/testdata/codeaction/removeparam_funcvalue.txt index ec8f63c34b3..19fbd69a6f5 100644 --- a/gopls/internal/test/marker/testdata/codeaction/removeparam_funcvalue.txt +++ b/gopls/internal/test/marker/testdata/codeaction/removeparam_funcvalue.txt @@ -10,7 +10,7 @@ go 1.18 -- a/a.go -- package a -func A(x, unused int) int { //@codeactionerr("unused", "unused", "refactor.rewrite.removeUnusedParam", re"non-call function reference") +func A(x, unused int) int { //@codeaction("unused", "refactor.rewrite.removeUnusedParam", err=re"non-call function reference") return x } diff --git a/gopls/internal/test/marker/testdata/codeaction/removeparam_imports.txt b/gopls/internal/test/marker/testdata/codeaction/removeparam_imports.txt index 9bad8232231..d9f4f22dc7e 100644 --- a/gopls/internal/test/marker/testdata/codeaction/removeparam_imports.txt +++ b/gopls/internal/test/marker/testdata/codeaction/removeparam_imports.txt @@ -59,37 +59,21 @@ import "mod.test/c" var Chan chan c.C -func B(x, y c.C) { 
//@codeaction("x", "x", "refactor.rewrite.removeUnusedParam", b) +func B(x, y c.C) { //@codeaction("x", "refactor.rewrite.removeUnusedParam", result=b) } --- c/c.go -- -package c - -type C int - --- d/d.go -- -package d - -// Removing the parameter should remove this import. -import "mod.test/c" - -func D(x c.C) { //@codeaction("x", "x", "refactor.rewrite.removeUnusedParam", d) -} - -func _() { - D(1) -} - --- @b/a/a1.go -- +-- @b/a/a3.go -- package a import ( "mod.test/b" - "mod.test/c" ) func _() { - var _ c.C = <-b.Chan + b.B(<-b.Chan) +} + +func _() { b.B(<-b.Chan) } -- @b/a/a2.go -- @@ -97,30 +81,20 @@ package a import ( "mod.test/b" - "mod.test/c" ) func _() { - var _ c.C = <-b.Chan b.B(<-b.Chan) - var _ c.C = <-b.Chan b.B(<-b.Chan) } --- @b/a/a3.go -- +-- @b/a/a1.go -- package a import ( "mod.test/b" - "mod.test/c" ) func _() { - var _ c.C = <-b.Chan - b.B(<-b.Chan) -} - -func _() { - var _ c.C = <-b.Chan b.B(<-b.Chan) } -- @b/a/a4.go -- @@ -131,11 +105,9 @@ package a import ( "mod.test/b" . "mod.test/b" - "mod.test/c" ) func _() { - var _ c.C = <-Chan b.B(<-Chan) } -- @b/b/b.go -- @@ -145,14 +117,32 @@ import "mod.test/c" var Chan chan c.C -func B(y c.C) { //@codeaction("x", "x", "refactor.rewrite.removeUnusedParam", b) +func B(y c.C) { //@codeaction("x", "refactor.rewrite.removeUnusedParam", result=b) } +-- c/c.go -- +package c + +type C int + +-- d/d.go -- +package d + +// Removing the parameter should remove this import. +import "mod.test/c" + +func D(x c.C) { //@codeaction("x", "refactor.rewrite.removeUnusedParam", result=d) +} + +func _() { + D(1) +} + -- @d/d/d.go -- package d // Removing the parameter should remove this import. 
-func D() { //@codeaction("x", "x", "refactor.rewrite.removeUnusedParam", d) +func D() { //@codeaction("x", "refactor.rewrite.removeUnusedParam", result=d) } func _() { diff --git a/gopls/internal/test/marker/testdata/codeaction/removeparam_issue65217.txt b/gopls/internal/test/marker/testdata/codeaction/removeparam_issue65217.txt index f2ecae4ad1c..93729577444 100644 --- a/gopls/internal/test/marker/testdata/codeaction/removeparam_issue65217.txt +++ b/gopls/internal/test/marker/testdata/codeaction/removeparam_issue65217.txt @@ -27,7 +27,7 @@ func _() { _ = i } -func f(unused S, i int) int { //@codeaction("unused", "unused", "refactor.rewrite.removeUnusedParam", rewrite), diag("unused", re`unused`) +func f(unused S, i int) int { //@codeaction("unused", "refactor.rewrite.removeUnusedParam", result=rewrite), diag("unused", re`unused`) return i } @@ -48,11 +48,10 @@ func _() { func _() { var s S - var _ S = s i := f(s.Int()) _ = i } -func f(i int) int { //@codeaction("unused", "unused", "refactor.rewrite.removeUnusedParam", rewrite), diag("unused", re`unused`) +func f(i int) int { //@codeaction("unused", "refactor.rewrite.removeUnusedParam", result=rewrite), diag("unused", re`unused`) return i } diff --git a/gopls/internal/test/marker/testdata/codeaction/removeparam_method.txt b/gopls/internal/test/marker/testdata/codeaction/removeparam_method.txt index 614c4d3147f..9b01edd5ae8 100644 --- a/gopls/internal/test/marker/testdata/codeaction/removeparam_method.txt +++ b/gopls/internal/test/marker/testdata/codeaction/removeparam_method.txt @@ -4,6 +4,7 @@ Specifically, check 1. basic removal of unused parameters, when the receiver is named, locally and across package boundaries 2. handling of unnamed receivers +3. 
no panics related to references through interface satisfaction -- go.mod -- module example.com/rm @@ -15,7 +16,7 @@ package rm type Basic int -func (t Basic) Foo(x int) { //@codeaction("x", "x", "refactor.rewrite.removeUnusedParam", basic) +func (t Basic) Foo(x int) { //@codeaction("x", "refactor.rewrite.removeUnusedParam", result=basic) } func _(b Basic) { @@ -37,12 +38,22 @@ func _() { func sideEffects() int +type Fooer interface { + Foo(int) +} + +// Dynamic calls aren't rewritten. +// Previously, this would cause a bug report or crash (golang/go#69896). +func _(f Fooer) { + f.Foo(1) +} + -- @basic/basic.go -- package rm type Basic int -func (t Basic) Foo() { //@codeaction("x", "x", "refactor.rewrite.removeUnusedParam", basic) +func (t Basic) Foo() { //@codeaction("x", "refactor.rewrite.removeUnusedParam", result=basic) } func _(b Basic) { @@ -57,15 +68,21 @@ import "example.com/rm" func _() { x := new(rm.Basic) - var ( - t rm.Basic = *x - _ int = sideEffects() - ) - t.Foo() + x.Foo() rm.Basic(1).Foo() } func sideEffects() int + +type Fooer interface { + Foo(int) +} + +// Dynamic calls aren't rewritten. +// Previously, this would cause a bug report or crash (golang/go#69896). 
+func _(f Fooer) { + f.Foo(1) +} -- missingrecv.go -- package rm @@ -73,7 +90,7 @@ type Missing struct{} var r2 int -func (Missing) M(a, b, c, r0 int) (r1 int) { //@codeaction("b", "b", "refactor.rewrite.removeUnusedParam", missingrecv) +func (Missing) M(a, b, c, r0 int) (r1 int) { //@codeaction("b", "refactor.rewrite.removeUnusedParam", result=missingrecv) return a + c } @@ -101,13 +118,13 @@ type Missing struct{} var r2 int -func (Missing) M(a, c, r0 int) (r1 int) { //@codeaction("b", "b", "refactor.rewrite.removeUnusedParam", missingrecv) +func (Missing) M(a, c, r0 int) (r1 int) { //@codeaction("b", "refactor.rewrite.removeUnusedParam", result=missingrecv) return a + c } func _() { m := &Missing{} - _ = (*m).M(1, 3, 4) + _ = m.M(1, 3, 4) } -- @missingrecv/missingrecvuse/p.go -- package missingrecvuse @@ -116,7 +133,6 @@ import "example.com/rm" func _() { x := rm.Missing{} - var _ int = sideEffects() x.M(1, 3, 4) } diff --git a/gopls/internal/test/marker/testdata/codeaction/removeparam_resolve.txt b/gopls/internal/test/marker/testdata/codeaction/removeparam_resolve.txt index 92f8d299272..b51dd6fb8cf 100644 --- a/gopls/internal/test/marker/testdata/codeaction/removeparam_resolve.txt +++ b/gopls/internal/test/marker/testdata/codeaction/removeparam_resolve.txt @@ -20,14 +20,14 @@ go 1.18 -- a/a.go -- package a -func A(x, unused int) int { //@codeaction("unused", "unused", "refactor.rewrite.removeUnusedParam", a) +func A(x, unused int) int { //@codeaction("unused", "refactor.rewrite.removeUnusedParam", result=a) return x } -- @a/a/a.go -- package a -func A(x int) int { //@codeaction("unused", "unused", "refactor.rewrite.removeUnusedParam", a) +func A(x int) int { //@codeaction("unused", "refactor.rewrite.removeUnusedParam", result=a) return x } @@ -110,7 +110,7 @@ func _() { -- field/field.go -- package field -func Field(x int, field int) { //@codeaction("int", "int", "refactor.rewrite.removeUnusedParam", field) +func Field(x int, field int) { //@codeaction("int", 
"refactor.rewrite.removeUnusedParam", result=field) } func _() { @@ -119,7 +119,7 @@ func _() { -- @field/field/field.go -- package field -func Field(field int) { //@codeaction("int", "int", "refactor.rewrite.removeUnusedParam", field) +func Field(field int) { //@codeaction("int", "refactor.rewrite.removeUnusedParam", result=field) } func _() { @@ -128,7 +128,7 @@ func _() { -- ellipsis/ellipsis.go -- package ellipsis -func Ellipsis(...any) { //@codeaction("any", "any", "refactor.rewrite.removeUnusedParam", ellipsis) +func Ellipsis(...any) { //@codeaction("any", "refactor.rewrite.removeUnusedParam", result=ellipsis) } func _() { @@ -149,7 +149,7 @@ func i() []any -- @ellipsis/ellipsis/ellipsis.go -- package ellipsis -func Ellipsis() { //@codeaction("any", "any", "refactor.rewrite.removeUnusedParam", ellipsis) +func Ellipsis() { //@codeaction("any", "refactor.rewrite.removeUnusedParam", result=ellipsis) } func _() { @@ -157,12 +157,10 @@ func _() { Ellipsis() Ellipsis() Ellipsis() - var _ []any = []any{1, f(), g()} Ellipsis() func(_ ...any) { Ellipsis() }(h()) - var _ []any = i() Ellipsis() } @@ -173,7 +171,7 @@ func i() []any -- ellipsis2/ellipsis2.go -- package ellipsis2 -func Ellipsis2(_, _ int, rest ...int) { //@codeaction("_", "_", "refactor.rewrite.removeUnusedParam", ellipsis2) +func Ellipsis2(_, _ int, rest ...int) { //@codeaction("_", "refactor.rewrite.removeUnusedParam", result=ellipsis2) } func _() { @@ -187,11 +185,11 @@ func h() (int, int) -- @ellipsis2/ellipsis2/ellipsis2.go -- package ellipsis2 -func Ellipsis2(_ int, rest ...int) { //@codeaction("_", "_", "refactor.rewrite.removeUnusedParam", ellipsis2) +func Ellipsis2(_ int, rest ...int) { //@codeaction("_", "refactor.rewrite.removeUnusedParam", result=ellipsis2) } func _() { - Ellipsis2(2, []int{3}...) + Ellipsis2(2, 3) func(_, blank0 int, rest ...int) { Ellipsis2(blank0, rest...) 
}(h()) @@ -202,7 +200,7 @@ func h() (int, int) -- overlapping/overlapping.go -- package overlapping -func Overlapping(i int) int { //@codeactionerr(re"(i) int", re"(i) int", "refactor.rewrite.removeUnusedParam", re"overlapping") +func Overlapping(i int) int { //@codeaction(re"(i) int", "refactor.rewrite.removeUnusedParam", err=re"overlapping") return 0 } @@ -214,7 +212,7 @@ func _() { -- effects/effects.go -- package effects -func effects(x, y int) int { //@codeaction("y", "y", "refactor.rewrite.removeUnusedParam", effects), diag("y", re"unused") +func effects(x, y int) int { //@codeaction("y", "refactor.rewrite.removeUnusedParam", result=effects), diag("y", re"unused") return x } @@ -228,7 +226,7 @@ func _() { -- @effects/effects/effects.go -- package effects -func effects(x int) int { //@codeaction("y", "y", "refactor.rewrite.removeUnusedParam", effects), diag("y", re"unused") +func effects(x int) int { //@codeaction("y", "refactor.rewrite.removeUnusedParam", result=effects), diag("y", re"unused") return x } @@ -236,23 +234,19 @@ func f() int func g() int func _() { - var x, _ int = f(), g() - effects(x) - { - var x, _ int = f(), g() - effects(x) - } + effects(f()) + effects(f()) } -- recursive/recursive.go -- package recursive -func Recursive(x int) int { //@codeaction("x", "x", "refactor.rewrite.removeUnusedParam", recursive) +func Recursive(x int) int { //@codeaction("x", "refactor.rewrite.removeUnusedParam", result=recursive) return Recursive(1) } -- @recursive/recursive/recursive.go -- package recursive -func Recursive() int { //@codeaction("x", "x", "refactor.rewrite.removeUnusedParam", recursive) +func Recursive() int { //@codeaction("x", "refactor.rewrite.removeUnusedParam", result=recursive) return Recursive() } diff --git a/gopls/internal/test/marker/testdata/codeaction/removeparam_satisfies.txt b/gopls/internal/test/marker/testdata/codeaction/removeparam_satisfies.txt index 3b6ba360d29..5bb93610131 100644 --- 
a/gopls/internal/test/marker/testdata/codeaction/removeparam_satisfies.txt +++ b/gopls/internal/test/marker/testdata/codeaction/removeparam_satisfies.txt @@ -18,15 +18,23 @@ package rm type T int -func (t T) Foo(x int) { //@codeaction("x", "x", "refactor.rewrite.removeUnusedParam", basic) +func (t T) Foo(x int) { //@codeaction("x", "refactor.rewrite.removeUnusedParam", result=basic) } --- use/use.go -- +-- @basic/p.go -- +package rm + +type T int + +func (t T) Foo() { //@codeaction("x", "refactor.rewrite.removeUnusedParam", result=basic) +} + +-- @basic/use/use.go -- package use import "example.com/rm" -type Fooer interface{ +type Fooer interface { Foo(int) } @@ -34,22 +42,14 @@ var _ Fooer = rm.T(0) func _() { var x rm.T - x.Foo(1) + x.Foo() } --- @basic/p.go -- -package rm - -type T int - -func (t T) Foo() { //@codeaction("x", "x", "refactor.rewrite.removeUnusedParam", basic) -} - --- @basic/use/use.go -- +-- use/use.go -- package use import "example.com/rm" -type Fooer interface { +type Fooer interface{ Foo(int) } @@ -57,6 +57,5 @@ var _ Fooer = rm.T(0) func _() { var x rm.T - var t rm.T = x - t.Foo() + x.Foo(1) } diff --git a/gopls/internal/test/marker/testdata/codeaction/removeparam_witherrs.txt b/gopls/internal/test/marker/testdata/codeaction/removeparam_witherrs.txt index 5b4cd37a51a..212a4a24765 100644 --- a/gopls/internal/test/marker/testdata/codeaction/removeparam_witherrs.txt +++ b/gopls/internal/test/marker/testdata/codeaction/removeparam_witherrs.txt @@ -3,7 +3,7 @@ This test checks that we can't remove parameters for packages with errors. 
-- p.go -- package p -func foo(unused int) { //@codeactionerr("unused", "unused", "refactor.rewrite.removeUnusedParam", re"found 0") +func foo(unused int) { //@codeaction("unused", "refactor.rewrite.removeUnusedParam", err=re"found 0") } func _() { diff --git a/gopls/internal/test/marker/testdata/codeaction/splitlines-variadic.txt b/gopls/internal/test/marker/testdata/codeaction/splitlines-variadic.txt new file mode 100644 index 00000000000..700f0d9b7e1 --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/splitlines-variadic.txt @@ -0,0 +1,55 @@ +This is a regression test for #70519, in which the ellipsis +of a variadic call would go missing after split/join lines. + +-- go.mod -- +module example.com +go 1.18 + +-- a/a.go -- +package a + +var a, b, c []any +func f(any, any, ...any) + +func _() { + f(a, b, c...) //@codeaction("a", "refactor.rewrite.splitLines", result=split) + + f( + a, + b, + c..., /*@codeaction("c", "refactor.rewrite.joinLines", result=joined)*/ + ) +} + +-- @split/a/a.go -- +package a + +var a, b, c []any +func f(any, any, ...any) + +func _() { + f( + a, + b, + c..., + ) //@codeaction("a", "refactor.rewrite.splitLines", result=split) + + f( + a, + b, + c..., /*@codeaction("c", "refactor.rewrite.joinLines", result=joined)*/ + ) +} + +-- @joined/a/a.go -- +package a + +var a, b, c []any +func f(any, any, ...any) + +func _() { + f(a, b, c...) 
//@codeaction("a", "refactor.rewrite.splitLines", result=split) + + f(a, b, c..., /*@codeaction("c", "refactor.rewrite.joinLines", result=joined)*/) +} + diff --git a/gopls/internal/test/marker/testdata/codeaction/splitlines.txt b/gopls/internal/test/marker/testdata/codeaction/splitlines.txt index 5600ccb777a..f0f6ef6091c 100644 --- a/gopls/internal/test/marker/testdata/codeaction/splitlines.txt +++ b/gopls/internal/test/marker/testdata/codeaction/splitlines.txt @@ -9,7 +9,7 @@ go 1.18 -- func_arg/func_arg.go -- package func_arg -func A(a string, b, c int64, x int, y int) (r1 string, r2, r3 int64, r4 int, r5 int) { //@codeaction("x", "x", "refactor.rewrite.splitLines", func_arg) +func A(a string, b, c int64, x int, y int) (r1 string, r2, r3 int64, r4 int, r5 int) { //@codeaction("x", "refactor.rewrite.splitLines", result=func_arg) return a, b, c, x, y } @@ -21,14 +21,14 @@ func A( b, c int64, x int, y int, -) (r1 string, r2, r3 int64, r4 int, r5 int) { //@codeaction("x", "x", "refactor.rewrite.splitLines", func_arg) +) (r1 string, r2, r3 int64, r4 int, r5 int) { //@codeaction("x", "refactor.rewrite.splitLines", result=func_arg) return a, b, c, x, y } -- func_ret/func_ret.go -- package func_ret -func A(a string, b, c int64, x int, y int) (r1 string, r2, r3 int64, r4 int, r5 int) { //@codeaction("r1", "r1", "refactor.rewrite.splitLines", func_ret) +func A(a string, b, c int64, x int, y int) (r1 string, r2, r3 int64, r4 int, r5 int) { //@codeaction("r1", "refactor.rewrite.splitLines", result=func_ret) return a, b, c, x, y } @@ -40,14 +40,14 @@ func A(a string, b, c int64, x int, y int) ( r2, r3 int64, r4 int, r5 int, -) { //@codeaction("r1", "r1", "refactor.rewrite.splitLines", func_ret) +) { //@codeaction("r1", "refactor.rewrite.splitLines", result=func_ret) return a, b, c, x, y } -- functype_arg/functype_arg.go -- package functype_arg -type A func(a string, b, c int64, x int, y int) (r1 string, r2, r3 int64, r4 int, r5 int) //@codeaction("x", "x", 
"refactor.rewrite.splitLines", functype_arg) +type A func(a string, b, c int64, x int, y int) (r1 string, r2, r3 int64, r4 int, r5 int) //@codeaction("x", "refactor.rewrite.splitLines", result=functype_arg) -- @functype_arg/functype_arg/functype_arg.go -- package functype_arg @@ -57,12 +57,12 @@ type A func( b, c int64, x int, y int, -) (r1 string, r2, r3 int64, r4 int, r5 int) //@codeaction("x", "x", "refactor.rewrite.splitLines", functype_arg) +) (r1 string, r2, r3 int64, r4 int, r5 int) //@codeaction("x", "refactor.rewrite.splitLines", result=functype_arg) -- functype_ret/functype_ret.go -- package functype_ret -type A func(a string, b, c int64, x int, y int) (r1 string, r2, r3 int64, r4 int, r5 int) //@codeaction("r1", "r1", "refactor.rewrite.splitLines", functype_ret) +type A func(a string, b, c int64, x int, y int) (r1 string, r2, r3 int64, r4 int, r5 int) //@codeaction("r1", "refactor.rewrite.splitLines", result=functype_ret) -- @functype_ret/functype_ret/functype_ret.go -- package functype_ret @@ -72,7 +72,7 @@ type A func(a string, b, c int64, x int, y int) ( r2, r3 int64, r4 int, r5 int, -) //@codeaction("r1", "r1", "refactor.rewrite.splitLines", functype_ret) +) //@codeaction("r1", "refactor.rewrite.splitLines", result=functype_ret) -- func_call/func_call.go -- package func_call @@ -80,7 +80,7 @@ package func_call import "fmt" func a() { - fmt.Println(1, 2, 3, fmt.Sprintf("hello %d", 4)) //@codeaction("1", "1", "refactor.rewrite.splitLines", func_call) + fmt.Println(1, 2, 3, fmt.Sprintf("hello %d", 4)) //@codeaction("1", "refactor.rewrite.splitLines", result=func_call) } -- @func_call/func_call/func_call.go -- @@ -94,7 +94,7 @@ func a() { 2, 3, fmt.Sprintf("hello %d", 4), - ) //@codeaction("1", "1", "refactor.rewrite.splitLines", func_call) + ) //@codeaction("1", "refactor.rewrite.splitLines", result=func_call) } -- indent/indent.go -- @@ -103,7 +103,7 @@ package indent import "fmt" func a() { - fmt.Println(1, 2, 3, fmt.Sprintf("hello %d", 4)) 
//@codeaction("hello", "hello", "refactor.rewrite.splitLines", indent) + fmt.Println(1, 2, 3, fmt.Sprintf("hello %d", 4)) //@codeaction("hello", "refactor.rewrite.splitLines", result=indent) } -- @indent/indent/indent.go -- @@ -115,7 +115,7 @@ func a() { fmt.Println(1, 2, 3, fmt.Sprintf( "hello %d", 4, - )) //@codeaction("hello", "hello", "refactor.rewrite.splitLines", indent) + )) //@codeaction("hello", "refactor.rewrite.splitLines", result=indent) } -- indent2/indent2.go -- @@ -125,7 +125,7 @@ import "fmt" func a() { fmt. - Println(1, 2, 3, fmt.Sprintf("hello %d", 4)) //@codeaction("1", "1", "refactor.rewrite.splitLines", indent2) + Println(1, 2, 3, fmt.Sprintf("hello %d", 4)) //@codeaction("1", "refactor.rewrite.splitLines", result=indent2) } -- @indent2/indent2/indent2.go -- @@ -140,7 +140,7 @@ func a() { 2, 3, fmt.Sprintf("hello %d", 4), - ) //@codeaction("1", "1", "refactor.rewrite.splitLines", indent2) + ) //@codeaction("1", "refactor.rewrite.splitLines", result=indent2) } -- structelts/structelts.go -- @@ -152,7 +152,7 @@ type A struct{ } func a() { - _ = A{a: 1, b: 2} //@codeaction("b", "b", "refactor.rewrite.splitLines", structelts) + _ = A{a: 1, b: 2} //@codeaction("b", "refactor.rewrite.splitLines", result=structelts) } -- @structelts/structelts/structelts.go -- @@ -167,14 +167,14 @@ func a() { _ = A{ a: 1, b: 2, - } //@codeaction("b", "b", "refactor.rewrite.splitLines", structelts) + } //@codeaction("b", "refactor.rewrite.splitLines", result=structelts) } -- sliceelts/sliceelts.go -- package sliceelts func a() { - _ = []int{1, 2} //@codeaction("1", "1", "refactor.rewrite.splitLines", sliceelts) + _ = []int{1, 2} //@codeaction("1", "refactor.rewrite.splitLines", result=sliceelts) } -- @sliceelts/sliceelts/sliceelts.go -- @@ -184,14 +184,14 @@ func a() { _ = []int{ 1, 2, - } //@codeaction("1", "1", "refactor.rewrite.splitLines", sliceelts) + } //@codeaction("1", "refactor.rewrite.splitLines", result=sliceelts) } -- mapelts/mapelts.go -- package mapelts 
func a() { - _ = map[string]int{"a": 1, "b": 2} //@codeaction("1", "1", "refactor.rewrite.splitLines", mapelts) + _ = map[string]int{"a": 1, "b": 2} //@codeaction("1", "refactor.rewrite.splitLines", result=mapelts) } -- @mapelts/mapelts/mapelts.go -- @@ -201,13 +201,13 @@ func a() { _ = map[string]int{ "a": 1, "b": 2, - } //@codeaction("1", "1", "refactor.rewrite.splitLines", mapelts) + } //@codeaction("1", "refactor.rewrite.splitLines", result=mapelts) } -- starcomment/starcomment.go -- package starcomment -func A(/*1*/ x /*2*/ string /*3*/, /*4*/ y /*5*/ int /*6*/) (string, int) { //@codeaction("x", "x", "refactor.rewrite.splitLines", starcomment) +func A(/*1*/ x /*2*/ string /*3*/, /*4*/ y /*5*/ int /*6*/) (string, int) { //@codeaction("x", "refactor.rewrite.splitLines", result=starcomment) return x, y } @@ -217,7 +217,7 @@ package starcomment func A( /*1*/ x /*2*/ string /*3*/, /*4*/ y /*5*/ int /*6*/, -) (string, int) { //@codeaction("x", "x", "refactor.rewrite.splitLines", starcomment) +) (string, int) { //@codeaction("x", "refactor.rewrite.splitLines", result=starcomment) return x, y } diff --git a/gopls/internal/test/marker/testdata/completion/alias.txt b/gopls/internal/test/marker/testdata/completion/alias.txt index e4c340e3f1f..6e5a92253d5 100644 --- a/gopls/internal/test/marker/testdata/completion/alias.txt +++ b/gopls/internal/test/marker/testdata/completion/alias.txt @@ -26,6 +26,15 @@ func takesGeneric[a int | string](s[a]) { takesGeneric() //@rank(")", tpInScopeLit),snippet(")", tpInScopeLit, "s[a]{\\}") } +func _() { + s[int]{} //@item(tpInstLit, "s[int]{}", "", "var") + takesGeneric[int]() //@rank(")", tpInstLit),snippet(")", tpInstLit, "s[int]{\\}") + + "s[...]{}" //@item(tpUninstLit, "s[...]{}", "", "var") + takesGeneric() //@rank(")", tpUninstLit),snippet(")", tpUninstLit, "s[${1:}]{\\}") +} + + type myType int //@item(flType, "myType", "int", "type") type myt[T int] myType //@item(aflType, "myt[T]", "int", "type") diff --git 
a/gopls/internal/test/marker/testdata/completion/issue70636.txt b/gopls/internal/test/marker/testdata/completion/issue70636.txt new file mode 100644 index 00000000000..a684ee905aa --- /dev/null +++ b/gopls/internal/test/marker/testdata/completion/issue70636.txt @@ -0,0 +1,23 @@ +This test reproduces the crash of golang/go#70636, an out of bounds error when +analyzing a return statement with more results than the signature expects. + +-- flags -- +-ignore_extra_diags + +-- go.mod -- +module example.com + +go 1.21 + +-- p.go -- +package p + +var xx int +var xy string + + +func _() { + return Foo(x) //@ rank(re"x()", "xx", "xy") +} + +func Foo[T any](t T) T {} diff --git a/gopls/internal/test/marker/testdata/definition/standalone.txt b/gopls/internal/test/marker/testdata/definition/standalone.txt index 6af1149184d..04a80f23614 100644 --- a/gopls/internal/test/marker/testdata/definition/standalone.txt +++ b/gopls/internal/test/marker/testdata/definition/standalone.txt @@ -40,3 +40,4 @@ import "golang.org/lsptests/a" func main() { a.F() //@hovererr("F", "no package") } + diff --git a/gopls/internal/test/marker/testdata/definition/standalone_issue64557.txt b/gopls/internal/test/marker/testdata/definition/standalone_issue64557.txt new file mode 100644 index 00000000000..42b920c1fc4 --- /dev/null +++ b/gopls/internal/test/marker/testdata/definition/standalone_issue64557.txt @@ -0,0 +1,30 @@ +This test checks that we can load standalone files that use cgo. 
+ +-- flags -- +-cgo + +-- go.mod -- +module example.com + +-- main.go -- +//go:build ignore + +package main + +import ( + "C" + + "example.com/a" +) + +func F() {} //@loc(F, "F") + +func main() { + F() //@def("F", F) + println(a.A) //@def("A", A) +} + +-- a/a.go -- +package a + +const A = 0 //@loc(A, "A") diff --git a/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt b/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt index 312a0c57120..76f65a4ecd7 100644 --- a/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt +++ b/gopls/internal/test/marker/testdata/diagnostics/analyzers.txt @@ -66,6 +66,14 @@ func _() { slog.Info("msg", 1) //@diag("1", re`slog.Info arg "1" should be a string or a slog.Attr`) } +// waitgroup +func _() { + var wg sync.WaitGroup + go func() { + wg.Add(1) //@diag("(", re"WaitGroup.Add called from inside new goroutine") + }() +} + -- cgocall/cgocall.go -- package cgocall diff --git a/gopls/internal/test/marker/testdata/fixedbugs/issue59944.txt b/gopls/internal/test/marker/testdata/fixedbugs/issue59944.txt index 9e39d8f5fe9..c4cd4409bf0 100644 --- a/gopls/internal/test/marker/testdata/fixedbugs/issue59944.txt +++ b/gopls/internal/test/marker/testdata/fixedbugs/issue59944.txt @@ -4,8 +4,12 @@ the methodset of its receiver type. Adapted from the code in question from the issue. +The flag -ignore_extra_diags is included, as this bug was fixed in Go 1.24, so +that now the code below may produce a diagnostic. + -- flags -- -cgo +-ignore_extra_diags -- go.mod -- module example.com diff --git a/gopls/internal/test/marker/testdata/foldingrange/a.txt b/gopls/internal/test/marker/testdata/foldingrange/a.txt index 6210fc25251..2946767ec30 100644 --- a/gopls/internal/test/marker/testdata/foldingrange/a.txt +++ b/gopls/internal/test/marker/testdata/foldingrange/a.txt @@ -6,13 +6,17 @@ package folding //@foldingrange(raw) import ( "fmt" _ "log" + "sort" + "time" ) import _ "os" // bar is a function. 
// With a multiline doc comment. -func bar() string { +func bar() ( + string, +) { /* This is a single line comment */ switch { case true: @@ -76,19 +80,74 @@ func bar() string { this string is not indented` } + +func _() { + slice := []int{1, 2, 3} + sort.Slice(slice, func(i, j int) bool { + a, b := slice[i], slice[j] + return a < b + }) + + sort.Slice(slice, func(i, j int) bool { return slice[i] < slice[j] }) + + sort.Slice( + slice, + func(i, j int) bool { + return slice[i] < slice[j] + }, + ) + + fmt.Println( + 1, 2, 3, + 4, + ) + + fmt.Println(1, 2, 3, + 4, 5, 6, + 7, 8, 9, + 10) + + // Call with ellipsis. + _ = fmt.Errorf( + "test %d %d", + []any{1, 2, 3}..., + ) + + // Check multiline string. + fmt.Println( + `multi + line + string + `, + 1, 2, 3, + ) + + // Call without arguments. + _ = time.Now() +} + +func _( + a int, b int, + c int, +) { +} -- @raw -- package folding //@foldingrange(raw) import (<0 kind="imports"> "fmt" _ "log" + "sort" + "time" ) import _ "os" // bar is a function.<1 kind="comment"> // With a multiline doc comment. -func bar(<2 kind="">) string {<3 kind=""> +func bar() (<2 kind=""> + string, +) {<3 kind=""> /* This is a single line comment */ switch {<4 kind=""> case true:<5 kind=""> @@ -152,3 +211,54 @@ func bar(<2 kind="">) string {<3 kind=""> this string is not indented` } + +func _() {<35 kind=""> + slice := []int{<36 kind="">1, 2, 3} + sort.Slice(<37 kind="">slice, func(<38 kind="">i, j int) bool {<39 kind=""> + a, b := slice[i], slice[j] + return a < b + }) + + sort.Slice(<40 kind="">slice, func(<41 kind="">i, j int) bool {<42 kind=""> return slice[i] < slice[j] }) + + sort.Slice(<43 kind=""> + slice, + func(<44 kind="">i, j int) bool {<45 kind=""> + return slice[i] < slice[j] + }, + ) + + fmt.Println(<46 kind=""> + 1, 2, 3, + 4, + ) + + fmt.Println(<47 kind="">1, 2, 3, + 4, 5, 6, + 7, 8, 9, + 10) + + // Call with ellipsis. 
+ _ = fmt.Errorf(<48 kind=""> + "test %d %d", + []any{<49 kind="">1, 2, 3}..., + ) + + // Check multiline string. + fmt.Println(<50 kind=""> + <51 kind="">`multi + line + string + `, + 1, 2, 3, + ) + + // Call without arguments. + _ = time.Now() +} + +func _(<52 kind=""> + a int, b int, + c int, +) {<53 kind=""> +} diff --git a/gopls/internal/test/marker/testdata/foldingrange/a_lineonly.txt b/gopls/internal/test/marker/testdata/foldingrange/a_lineonly.txt index 0c532e760f1..fde2fb29c27 100644 --- a/gopls/internal/test/marker/testdata/foldingrange/a_lineonly.txt +++ b/gopls/internal/test/marker/testdata/foldingrange/a_lineonly.txt @@ -15,6 +15,8 @@ package folding //@foldingrange(raw) import ( "fmt" _ "log" + "sort" + "time" ) import _ "os" @@ -85,12 +87,65 @@ func bar() string { this string is not indented` } + +func _() { + slice := []int{1, 2, 3} + sort.Slice(slice, func(i, j int) bool { + a, b := slice[i], slice[j] + return a < b + }) + + sort.Slice(slice, func(i, j int) bool { return slice[i] < slice[j] }) + + sort.Slice( + slice, + func(i, j int) bool { + return slice[i] < slice[j] + }, + ) + + fmt.Println( + 1, 2, 3, + 4, + ) + + fmt.Println(1, 2, 3, + 4, 5, 6, + 7, 8, 9, + 10) + + // Call with ellipsis. + _ = fmt.Errorf( + "test %d %d", + []any{1, 2, 3}..., + ) + + // Check multiline string. + fmt.Println( + `multi + line + string + `, + 1, 2, 3, + ) + + // Call without arguments. 
+ _ = time.Now() +} + +func _( + a int, b int, + c int, +) { +} -- @raw -- package folding //@foldingrange(raw) import (<0 kind="imports"> "fmt" - _ "log" + _ "log" + "sort" + "time" ) import _ "os" @@ -122,7 +177,7 @@ func bar() string {<2 kind=""> _ = []int{<11 kind=""> 1, 2, - 3, + 3, } _ = [2]string{"d", "e", @@ -130,7 +185,7 @@ func bar() string {<2 kind=""> _ = map[string]int{<12 kind=""> "a": 1, "b": 2, - "c": 3, + "c": 3, } type T struct {<13 kind=""> f string @@ -140,7 +195,7 @@ func bar() string {<2 kind=""> _ = T{<14 kind=""> f: "j", g: 4, - h: "i", + h: "i", } x, y := make(chan bool), make(chan bool) select {<15 kind=""> @@ -161,3 +216,54 @@ func bar() string {<2 kind=""> this string is not indented` } + +func _() {<23 kind=""> + slice := []int{1, 2, 3} + sort.Slice(slice, func(i, j int) bool {<24 kind=""> + a, b := slice[i], slice[j] + return a < b + }) + + sort.Slice(slice, func(i, j int) bool { return slice[i] < slice[j] }) + + sort.Slice(<25 kind=""> + slice, + func(i, j int) bool {<26 kind=""> + return slice[i] < slice[j] + }, + ) + + fmt.Println(<27 kind=""> + 1, 2, 3, + 4, + ) + + fmt.Println(1, 2, 3, + 4, 5, 6, + 7, 8, 9, + 10) + + // Call with ellipsis. + _ = fmt.Errorf(<28 kind=""> + "test %d %d", + []any{1, 2, 3}..., + ) + + // Check multiline string. + fmt.Println(<29 kind=""> + <30 kind="">`multi + line + string + `, + 1, 2, 3, + ) + + // Call without arguments. + _ = time.Now() +} + +func _(<31 kind=""> + a int, b int, + c int, +) { +} diff --git a/gopls/internal/test/marker/testdata/foldingrange/bad.txt b/gopls/internal/test/marker/testdata/foldingrange/bad.txt index f9f14a4fa7d..14444e7aa44 100644 --- a/gopls/internal/test/marker/testdata/foldingrange/bad.txt +++ b/gopls/internal/test/marker/testdata/foldingrange/bad.txt @@ -31,11 +31,11 @@ import (<1 kind="imports"> _ "os" ) // badBar is a function. 
-func badBar(<2 kind="">) string {<3 kind=""> x := true - if x {<4 kind=""> +func badBar() string {<2 kind=""> x := true + if x {<3 kind=""> // This is the only foldable thing in this file when lineFoldingOnly - fmt.Println(<5 kind="">"true") - } else {<6 kind=""> - fmt.Println(<7 kind="">"false") } + fmt.Println(<4 kind="">"true") + } else {<5 kind=""> + fmt.Println(<6 kind="">"false") } return "" -} +} diff --git a/gopls/internal/test/marker/testdata/foldingrange/parse_errors.txt b/gopls/internal/test/marker/testdata/foldingrange/parse_errors.txt new file mode 100644 index 00000000000..ad98d549e7a --- /dev/null +++ b/gopls/internal/test/marker/testdata/foldingrange/parse_errors.txt @@ -0,0 +1,26 @@ +This test verifies that textDocument/foldingRange does not panic +and produces no folding ranges if a file contains errors. + +-- flags -- +-ignore_extra_diags + +-- a.go -- +package folding //@foldingrange(raw) + +// No comma. +func _( + a string +) {} + +// Extra brace. +func _() {}} +-- @raw -- +package folding //@foldingrange(raw) + +// No comma. +func _( + a string +) {} + +// Extra brace. 
+func _() {}} diff --git a/gopls/internal/test/marker/testdata/implementation/basic.txt b/gopls/internal/test/marker/testdata/implementation/basic.txt index 3f63a5d00c1..28522cb5bc4 100644 --- a/gopls/internal/test/marker/testdata/implementation/basic.txt +++ b/gopls/internal/test/marker/testdata/implementation/basic.txt @@ -43,6 +43,12 @@ type embedsImpP struct { //@loc(embedsImpP, "embedsImpP") ImpP //@implementation("ImpP", Laugher, OtherLaugher) } +var _ error //@defloc(StdError, "error") + +type MyError struct {} //@implementation("MyError", StdError) + +func (MyError) Error() string { return "bah" } + -- other/other.go -- package other diff --git a/gopls/internal/test/marker/testdata/quickfix/stub.txt b/gopls/internal/test/marker/testdata/quickfix/stub.txt index 6f0a0788679..385565e3eaf 100644 --- a/gopls/internal/test/marker/testdata/quickfix/stub.txt +++ b/gopls/internal/test/marker/testdata/quickfix/stub.txt @@ -170,6 +170,25 @@ type closer struct{} +func (c closer) Close() error { + panic("unimplemented") +} +-- successive_function_return.go -- +package stub + +import ( + "io" +) + +func _() (a, b int, c io.Closer) { + return 1, 2, closer2{} //@quickfix("c", re"does not implement", successive) +} + +type closer2 struct{} +-- @successive/successive_function_return.go -- +@@ -12 +12,5 @@ ++ ++// Close implements io.Closer. 
++func (c closer2) Close() error { ++ panic("unimplemented") ++} -- generic_receiver.go -- package stub diff --git a/gopls/internal/test/marker/testdata/quickfix/stubmethods/fromcall_returns.txt b/gopls/internal/test/marker/testdata/quickfix/stubmethods/fromcall_returns.txt index ca10f628402..fc108eb9c74 100644 --- a/gopls/internal/test/marker/testdata/quickfix/stubmethods/fromcall_returns.txt +++ b/gopls/internal/test/marker/testdata/quickfix/stubmethods/fromcall_returns.txt @@ -104,3 +104,175 @@ func fn() func(i int) { + panic("unimplemented") +} + +-- if_stmt.go -- +package fromcallreturns + +type IfStruct struct{} + +func testIfStmt() { + i := IfStruct{} + if i.isValid() { //@quickfix("isValid", re"has no field or method", infer_if_stmt) + // do something + } +} +-- @infer_if_stmt/if_stmt.go -- +@@ -5 +5,4 @@ ++func (i IfStruct) isValid() bool { ++ panic("unimplemented") ++} ++ +-- for_stmt.go -- +package fromcallreturns + +type ForStruct struct{} + +func testForStmt() { + f := ForStruct{} + for f.hasNext() { //@quickfix("hasNext", re"has no field or method", infer_for_stmt1) + // do something + } + for i := 0; f.inside(); i++ { //@quickfix("inside", re"has no field or method", infer_for_stmt2) + // do something + } +} +-- @infer_for_stmt1/for_stmt.go -- +@@ -5 +5,4 @@ ++func (f ForStruct) hasNext() bool { ++ panic("unimplemented") ++} ++ +-- @infer_for_stmt2/for_stmt.go -- +@@ -5 +5,4 @@ ++func (f ForStruct) inside() bool { ++ panic("unimplemented") ++} ++ +-- unary.go -- +package fromcallreturns + +type Unary struct{} + +func testUnaryExpr() { + u := Unary{} + a, b, c, d := !u.Boolean(), -u.Minus(), +u.Plus(), ^u.Xor() //@quickfix("Boolean", re"has no field or method", infer_unary_expr1),quickfix("Minus", re"has no field or method", infer_unary_expr2),quickfix("Plus", re"has no field or method", infer_unary_expr3),quickfix("Xor", re"has no field or method", infer_unary_expr4) + _, _, _, _ = a, b, c, d +} +-- @infer_unary_expr1/unary.go -- +@@ -5 +5,4 @@ 
++func (u Unary) Boolean() bool { ++ panic("unimplemented") ++} ++ +-- @infer_unary_expr2/unary.go -- +@@ -5 +5,4 @@ ++func (u Unary) Minus() int { ++ panic("unimplemented") ++} ++ +-- @infer_unary_expr3/unary.go -- +@@ -5 +5,4 @@ ++func (u Unary) Plus() int { ++ panic("unimplemented") ++} ++ +-- @infer_unary_expr4/unary.go -- +@@ -5 +5,4 @@ ++func (u Unary) Xor() int { ++ panic("unimplemented") ++} ++ +-- binary.go -- +package fromcallreturns + +type Binary struct{} + +func testBinaryExpr() { + b := Binary{} + _ = 1 + b.Num() //@quickfix("Num", re"has no field or method", infer_binary_expr1) + _ = "s" + b.Str() //@quickfix("Str", re"has no field or method", infer_binary_expr2) +} +-- @infer_binary_expr1/binary.go -- +@@ -5 +5,4 @@ ++func (b Binary) Num() int { ++ panic("unimplemented") ++} ++ +-- @infer_binary_expr2/binary.go -- +@@ -5 +5,4 @@ ++func (b Binary) Str() string { ++ panic("unimplemented") ++} ++ +-- value.go -- +package fromcallreturns + +type Value struct{} + +func v() { + v := Value{} + var a, b int = v.Multi() //@quickfix("Multi", re"has no field or method", infer_value_expr1) + var c, d int = 4, v.Single() //@quickfix("Single", re"has no field or method", infer_value_expr2) + _, _, _, _ = a, b, c, d +} +-- @infer_value_expr1/value.go -- +@@ -5 +5,4 @@ ++func (v Value) Multi() (int, int) { ++ panic("unimplemented") ++} ++ +-- @infer_value_expr2/value.go -- +@@ -5 +5,4 @@ ++func (v Value) Single() int { ++ panic("unimplemented") ++} ++ +-- return.go -- +package fromcallreturns + +type Return struct{} + +func r() { + r := Return{} + _ = func() (int, int) { + return r.Multi() //@quickfix("Multi", re"has no field or method", infer_retrun_expr1) + } + _ = func() string { + return r.Single() //@quickfix("Single", re"has no field or method", infer_retrun_expr2) + } +} +-- @infer_retrun_expr1/return.go -- +@@ -5 +5,4 @@ ++func (r Return) Multi() (int, int) { ++ panic("unimplemented") ++} ++ +-- @infer_retrun_expr2/return.go -- +@@ -5 +5,4 @@ ++func (r 
Return) Single() string { ++ panic("unimplemented") ++} ++ +-- successive_return.go -- +package fromcallreturns + +type R struct{} + +func _() (x int, y, z string, k int64) { + r := R{} + _ = func() (a, b float32, c int) { + return r.Multi() //@quickfix("Multi", re"has no field or method", successive1) + } + return 3, "", r.Single(), 6 //@quickfix("Single", re"has no field or method", successive2) +} +-- @successive1/successive_return.go -- +@@ -5 +5,4 @@ ++func (r R) Multi() (float32, float32, int) { ++ panic("unimplemented") ++} ++ +-- @successive2/successive_return.go -- +@@ -5 +5,4 @@ ++func (r R) Single() string { ++ panic("unimplemented") ++} ++ diff --git a/gopls/internal/test/marker/testdata/quickfix/undeclared.txt b/gopls/internal/test/marker/testdata/quickfix/undeclared.txt deleted file mode 100644 index 6b6e47e2765..00000000000 --- a/gopls/internal/test/marker/testdata/quickfix/undeclared.txt +++ /dev/null @@ -1,42 +0,0 @@ -Tests of suggested fixes for "undeclared name" diagnostics, -which are of ("compiler", "error") type. 
- --- go.mod -- -module example.com -go 1.12 - --- a.go -- -package p - -func a() { - z, _ := 1+y, 11 //@quickfix("y", re"(undeclared name|undefined): y", a) - _ = z -} - --- @a/a.go -- -@@ -4 +4 @@ -+ y := --- b.go -- -package p - -func b() { - if 100 < 90 { - } else if 100 > n+2 { //@quickfix("n", re"(undeclared name|undefined): n", b) - } -} - --- @b/b.go -- -@@ -4 +4 @@ -+ n := --- c.go -- -package p - -func c() { - for i < 200 { //@quickfix("i", re"(undeclared name|undefined): i", c) - } - r() //@diag("r", re"(undeclared name|undefined): r") -} - --- @c/c.go -- -@@ -4 +4 @@ -+ i := diff --git a/gopls/internal/test/marker/testdata/quickfix/undeclared/diag.txt b/gopls/internal/test/marker/testdata/quickfix/undeclared/diag.txt new file mode 100644 index 00000000000..88dbb88e8e6 --- /dev/null +++ b/gopls/internal/test/marker/testdata/quickfix/undeclared/diag.txt @@ -0,0 +1,97 @@ +This test checks @diag reports for undeclared variables and functions. + +-- x.go -- +package undeclared + +func x() int { + var z int + z = y //@diag("y", re"(undeclared name|undefined): y") + if z == m { //@diag("m", re"(undeclared name|undefined): m") + z = 1 + } + + if z == 1 { + z = 1 + } else if z == n+1 { //@diag("n", re"(undeclared name|undefined): n") + z = 1 + } + + switch z { + case 10: + z = 1 + case aa: //@diag("aa", re"(undeclared name|undefined): aa") + z = 1 + } + return z +} +-- channels.go -- +package undeclared + +func channels(s string) { + undefinedChannels(c()) //@diag("undefinedChannels", re"(undeclared name|undefined): undefinedChannels") +} + +func c() (<-chan string, chan string) { + return make(<-chan string), make(chan string) +} +-- consecutive_params.go -- +package undeclared + +func consecutiveParams() { + var s string + undefinedConsecutiveParams(s, s) //@diag("undefinedConsecutiveParams", re"(undeclared name|undefined): undefinedConsecutiveParams") +} +-- error_param.go -- +package undeclared + +func errorParam() { + var err error + 
undefinedErrorParam(err) //@diag("undefinedErrorParam", re"(undeclared name|undefined): undefinedErrorParam") +} +-- literals.go -- +package undeclared + +type T struct{} + +func literals() { + undefinedLiterals("hey compiler", T{}, &T{}) //@diag("undefinedLiterals", re"(undeclared name|undefined): undefinedLiterals") +} +-- operation.go -- +package undeclared + +import "time" + +func operation() { + undefinedOperation(10 * time.Second) //@diag("undefinedOperation", re"(undeclared name|undefined): undefinedOperation") +} +-- selector.go -- +package undeclared + +func selector() { + m := map[int]bool{} + undefinedSelector(m[1]) //@diag("undefinedSelector", re"(undeclared name|undefined): undefinedSelector") +} +-- slice.go -- +package undeclared + +func slice() { + undefinedSlice([]int{1, 2}) //@diag("undefinedSlice", re"(undeclared name|undefined): undefinedSlice") +} +-- tuple.go -- +package undeclared + +func tuple() { + undefinedTuple(b()) //@diag("undefinedTuple", re"(undeclared name|undefined): undefinedTuple") +} + +func b() (string, error) { + return "", nil +} +-- unique.go -- +package undeclared + +func uniqueArguments() { + var s string + var i int + undefinedUniqueArguments(s, i, s) //@diag("undefinedUniqueArguments", re"(undeclared name|undefined): undefinedUniqueArguments") +} diff --git a/gopls/internal/test/marker/testdata/quickfix/missingfunction.txt b/gopls/internal/test/marker/testdata/quickfix/undeclared/missingfunction.txt similarity index 85% rename from gopls/internal/test/marker/testdata/quickfix/missingfunction.txt rename to gopls/internal/test/marker/testdata/quickfix/undeclared/missingfunction.txt index a21ccca766f..3dd42a115d2 100644 --- a/gopls/internal/test/marker/testdata/quickfix/missingfunction.txt +++ b/gopls/internal/test/marker/testdata/quickfix/undeclared/missingfunction.txt @@ -125,3 +125,31 @@ func uniqueArguments() { +func undefinedUniqueArguments(s1 string, i int, s2 string) { + panic("unimplemented") +} +-- param.go -- 
+package missingfunction + +func inferFromParam() { + f(as_param()) //@quickfix("as_param", re"undefined", infer_param) +} + +func f(i int) {} +-- @infer_param/param.go -- +@@ -7 +7,4 @@ ++func as_param() int { ++ panic("unimplemented") ++} ++ +-- assign.go -- +package missingfunction + +func inferFromAssign() { + i := 42 + i = i + i = assign() //@quickfix("assign", re"undefined", infer_assign) +} +-- @infer_assign/assign.go -- +@@ -8 +8,4 @@ ++ ++func assign() int { ++ panic("unimplemented") ++} diff --git a/gopls/internal/test/marker/testdata/quickfix/undeclared/undeclared_variable.txt b/gopls/internal/test/marker/testdata/quickfix/undeclared/undeclared_variable.txt new file mode 100644 index 00000000000..a65f6b80f4b --- /dev/null +++ b/gopls/internal/test/marker/testdata/quickfix/undeclared/undeclared_variable.txt @@ -0,0 +1,108 @@ +Tests of suggested fixes for "undeclared name" diagnostics, +which are of ("compiler", "error") type. + +-- flags -- +-ignore_extra_diags + +-- go.mod -- +module example.com +go 1.12 + +-- a.go -- +package undeclared_var + +func a() { + z, _ := 1+y, 11 //@quickfix("y", re"(undeclared name|undefined): y", a) + _ = z +} + +-- @a/a.go -- +@@ -4 +4 @@ ++ y := 0 +-- b.go -- +package undeclared_var + +func b() { + if 100 < 90 { + } else if 100 > n+2 { //@quickfix("n", re"(undeclared name|undefined): n", b) + } +} + +-- @b/b.go -- +@@ -4 +4 @@ ++ n := 0 +-- c.go -- +package undeclared_var + +func c() { + for i < 200 { //@quickfix("i", re"(undeclared name|undefined): i", c) + } + r() //@diag("r", re"(undeclared name|undefined): r") +} + +-- @c/c.go -- +@@ -4 +4 @@ ++ i := 0 +-- add_colon.go -- +package undeclared_var + +func addColon() { + ac = 1 //@quickfix("ac", re"(undeclared name|undefined): ac", add_colon) +} + +-- @add_colon/add_colon.go -- +@@ -4 +4 @@ +- ac = 1 //@quickfix("ac", re"(undeclared name|undefined): ac", add_colon) ++ ac := 1 //@quickfix("ac", re"(undeclared name|undefined): ac", add_colon) +-- add_colon_first.go -- 
+package undeclared_var + +func addColonAtFirstStmt() { + ac = 1 + ac = 2 + ac = 3 + b := ac //@quickfix("ac", re"(undeclared name|undefined): ac", add_colon_first) +} + +-- @add_colon_first/add_colon_first.go -- +@@ -4 +4 @@ +- ac = 1 ++ ac := 1 +-- self_assign.go -- +package undeclared_var + +func selfAssign() { + ac = ac + 1 + ac = ac + 2 //@quickfix("ac", re"(undeclared name|undefined): ac", lhs) + ac = ac + 3 //@quickfix("ac + 3", re"(undeclared name|undefined): ac", rhs) +} + +-- @lhs/self_assign.go -- +@@ -4 +4 @@ ++ ac := nil +-- @rhs/self_assign.go -- +@@ -4 +4 @@ ++ ac := 0 +-- correct_type.go -- +package undeclared_var +import "fmt" +func selfAssign() { + fmt.Printf(ac) //@quickfix("ac", re"(undeclared name|undefined): ac", string) +} +-- @string/correct_type.go -- +@@ -4 +4 @@ ++ ac := "" +-- ignore.go -- +package undeclared_var +import "fmt" +type Foo struct { + bar int +} +func selfAssign() { + f := Foo{} + b = f.bar + c := bar //@quickfix("bar", re"(undeclared name|undefined): bar", ignore) +} +-- @ignore/ignore.go -- +@@ -9 +9 @@ ++ bar := nil diff --git a/gopls/internal/test/marker/testdata/quickfix/undeclaredfunc.txt b/gopls/internal/test/marker/testdata/quickfix/undeclared/undeclaredfunc.txt similarity index 80% rename from gopls/internal/test/marker/testdata/quickfix/undeclaredfunc.txt rename to gopls/internal/test/marker/testdata/quickfix/undeclared/undeclaredfunc.txt index 6a0f7be3870..68940ca858d 100644 --- a/gopls/internal/test/marker/testdata/quickfix/undeclaredfunc.txt +++ b/gopls/internal/test/marker/testdata/quickfix/undeclared/undeclaredfunc.txt @@ -1,8 +1,6 @@ This test checks the quick fix for "undeclared: f" that declares the missing function. See #47558. -TODO(adonovan): infer the result variables from the context (int, in this case). 
- -- a.go -- package a @@ -13,7 +11,7 @@ func _() int { return f(1, "") } //@quickfix(re"f.1", re"unde(fined|clared name) -func _() int { return f(1, "") } //@quickfix(re"f.1", re"unde(fined|clared name): f", x) +func _() int { return f(1, "") } @@ -5 +5,4 @@ -+func f(i int, s string) { ++func f(i int, s string) int { + panic("unimplemented") +} //@quickfix(re"f.1", re"unde(fined|clared name): f", x) + diff --git a/gopls/internal/test/marker/testdata/rename/func.txt b/gopls/internal/test/marker/testdata/rename/func.txt new file mode 100644 index 00000000000..04ad1e955d0 --- /dev/null +++ b/gopls/internal/test/marker/testdata/rename/func.txt @@ -0,0 +1,55 @@ +This test checks basic functionality for renaming (=changing) a function +signature. + +-- go.mod -- +module example.com + +go 1.20 + +-- a/a.go -- +package a + +//@rename(Foo, "func(i int, s string)", unchanged) +//@rename(Foo, "func(s string, i int)", reverse) +//@rename(Foo, "func(s string)", dropi) +//@rename(Foo, "func(i int)", drops) +//@rename(Foo, "func()", dropboth) +//@renameerr(Foo, "func(i int, s string, t bool)", "not yet supported") +//@renameerr(Foo, "func(i string)", "not yet supported") +//@renameerr(Foo, "func(i int, s string) int", "not yet supported") + +func Foo(i int, s string) { //@loc(Foo, "func") +} + +func _() { + Foo(0, "hi") +} +-- @dropboth/a/a.go -- +@@ -12 +12 @@ +-func Foo(i int, s string) { //@loc(Foo, "func") ++func Foo() { //@loc(Foo, "func") +@@ -16 +16 @@ +- Foo(0, "hi") ++ Foo() +-- @dropi/a/a.go -- +@@ -12 +12 @@ +-func Foo(i int, s string) { //@loc(Foo, "func") ++func Foo(s string) { //@loc(Foo, "func") +@@ -16 +16 @@ +- Foo(0, "hi") ++ Foo("hi") +-- @drops/a/a.go -- +@@ -12 +12 @@ +-func Foo(i int, s string) { //@loc(Foo, "func") ++func Foo(i int) { //@loc(Foo, "func") +@@ -16 +16 @@ +- Foo(0, "hi") ++ Foo(0) +-- @reverse/a/a.go -- +@@ -12 +12 @@ +-func Foo(i int, s string) { //@loc(Foo, "func") ++func Foo(s string, i int) { //@loc(Foo, "func") +@@ -16 +16 @@ +- Foo(0, 
"hi") ++ Foo("hi", 0) +-- @unchanged/a/a.go -- diff --git a/gopls/internal/test/marker/testdata/rename/issue43616.txt b/gopls/internal/test/marker/testdata/rename/issue43616.txt index 19cfac4a435..9ade79fb6be 100644 --- a/gopls/internal/test/marker/testdata/rename/issue43616.txt +++ b/gopls/internal/test/marker/testdata/rename/issue43616.txt @@ -4,15 +4,15 @@ fields. -- p.go -- package issue43616 -type foo int //@rename("foo", "bar", fooToBar),preparerename("oo","foo","foo") +type foo int //@rename("foo", "bar", fooToBar),preparerename("oo","foo",span="foo") var x struct{ foo } //@renameerr("foo", "baz", "rename the type directly") var _ = x.foo //@renameerr("foo", "quux", "rename the type directly") -- @fooToBar/p.go -- @@ -3 +3 @@ --type foo int //@rename("foo", "bar", fooToBar),preparerename("oo","foo","foo") -+type bar int //@rename("foo", "bar", fooToBar),preparerename("oo","foo","foo") +-type foo int //@rename("foo", "bar", fooToBar),preparerename("oo","foo",span="foo") ++type bar int //@rename("foo", "bar", fooToBar),preparerename("oo","foo",span="foo") @@ -5 +5 @@ -var x struct{ foo } //@renameerr("foo", "baz", "rename the type directly") +var x struct{ bar } //@renameerr("foo", "baz", "rename the type directly") diff --git a/gopls/internal/test/marker/testdata/rename/prepare.txt b/gopls/internal/test/marker/testdata/rename/prepare.txt index cd8439e41b3..7ac9581898e 100644 --- a/gopls/internal/test/marker/testdata/rename/prepare.txt +++ b/gopls/internal/test/marker/testdata/rename/prepare.txt @@ -33,9 +33,9 @@ func (*Y) Bobby() {} -- good/good0.go -- package good -func stuff() { //@item(good_stuff, "stuff", "func()", "func"),preparerename("stu", "stuff", "stuff") +func stuff() { //@item(good_stuff, "stuff", "func()", "func"),preparerename("stu", "stuff", span="stuff") x := 5 - random2(x) //@preparerename("dom", "random2", "random2") + random2(x) //@preparerename("dom", "random2", span="random2") } -- good/good1.go -- @@ -46,14 +46,14 @@ import ( ) func 
random() int { //@item(good_random, "random", "func() int", "func") - _ = "random() int" //@preparerename("random", "", "") - y := 6 + 7 //@preparerename("7", "", "") - return y //@preparerename("return", "","") + _ = "random() int" //@preparerename("random", "") + y := 6 + 7 //@preparerename("7", "") + return y //@preparerename("return", "", span="") } func random2(y int) int { //@item(good_random2, "random2", "func(y int) int", "func"),item(good_y_param, "y", "int", "var") //@complete("", good_y_param, types_import, good_random, good_random2, good_stuff) - var b types.Bob = &types.X{} //@preparerename("ypes","types", "types") + var b types.Bob = &types.X{} //@preparerename("ypes","types", span="types") if _, ok := b.(*types.X); ok { //@complete("X", X_struct, Y_struct, Bob_interface, CoolAlias) _ = 0 // suppress "empty branch" diagnostic } diff --git a/gopls/internal/test/marker/testdata/rename/prepare_func.txt b/gopls/internal/test/marker/testdata/rename/prepare_func.txt new file mode 100644 index 00000000000..2c73e69afe0 --- /dev/null +++ b/gopls/internal/test/marker/testdata/rename/prepare_func.txt @@ -0,0 +1,44 @@ +This test verifies the behavior of textDocument/prepareRename on function declarations. 
+ +-- settings.json -- +{ + "deepCompletion": false +} + +-- go.mod -- +module golang.org/lsptests + +go 1.18 + +-- main.go -- +package main + +func _(i int) //@ preparerename("unc", "func(i int)", span="func") + +func _(i int) //@ preparerename("func", "func(i int)") + +func _(a, b int) //@ preparerename("func", "func(a, b int)") + +func _(a, _ int) //@ preparerename("func", "func(a, _0 int)") + +func _(a, _, _ int) //@ preparerename("func", "func(a, _0, _1 int)") + +func _(a, _, _, d int, _ string) //@ preparerename("func", "func(a, _0, _1, d int, _2 string)") + +func _(a int, b string) //@ preparerename("func", "func(a int, b string)") + +func _(a int, b ...string) //@ preparerename("func", "func(a int, b ...string)") + +func _(a int, b string) error //@ preparerename("func", "func(a int, b string) error") + +func _(a int, b string) (int, error) //@ preparerename("func", "func(a int, b string) (int, error)") + +func _( //@ preparerename("func", "func(a int, b string)") + a int, + b string, +) + +func _( //@ preparerename("func", "func(a int, b string) (int, error)") + a int, + b string, +) (int, error) diff --git a/gopls/internal/test/marker/testdata/token/builtin_constant.txt b/gopls/internal/test/marker/testdata/token/builtin_constant.txt index 8f0c021b3a9..79736d625b7 100644 --- a/gopls/internal/test/marker/testdata/token/builtin_constant.txt +++ b/gopls/internal/test/marker/testdata/token/builtin_constant.txt @@ -17,5 +17,5 @@ func _() { } const ( - c = iota //@ token("iota", "variable", "readonly defaultLibrary") + c = iota //@ token("iota", "variable", "readonly defaultLibrary number") ) diff --git a/gopls/internal/test/marker/testdata/token/comment.txt b/gopls/internal/test/marker/testdata/token/comment.txt index a5ce9139c4e..113ffa744dd 100644 --- a/gopls/internal/test/marker/testdata/token/comment.txt +++ b/gopls/internal/test/marker/testdata/token/comment.txt @@ -21,21 +21,21 @@ var B = 2 type Foo int -// [F] accept a [Foo], and print it. 
//@token("F", "function", ""),token("Foo", "type", "defaultLibrary number") +// [F] accept a [Foo], and print it. //@token("F", "function", "signature"),token("Foo", "type", "number") func F(v Foo) { println(v) } /* - [F1] print [A] and [B] //@token("F1", "function", ""),token("A", "variable", ""),token("B", "variable", "") + [F1] print [A] and [B] //@token("F1", "function", "signature"),token("A", "variable", "readonly number"),token("B", "variable", "number") */ func F1() { - // print [A] and [B]. //@token("A", "variable", ""),token("B", "variable", "") + // print [A] and [B]. //@token("A", "variable", "readonly number"),token("B", "variable", "number") println(A, B) } -// [F2] use [strconv.Atoi] convert s, then print it //@token("F2", "function", ""),token("strconv", "namespace", ""),token("Atoi", "function", "") +// [F2] use [strconv.Atoi] convert s, then print it //@token("F2", "function", "signature"),token("strconv", "namespace", ""),token("Atoi", "function", "signature") func F2(s string) { a, _ := strconv.Atoi("42") b, _ := strconv.Atoi("42") @@ -44,12 +44,12 @@ func F2(s string) { -- b.go -- package p -// [F3] accept [*Foo] //@token("F3", "function", ""),token("Foo", "type", "defaultLibrary number") +// [F3] accept [*Foo] //@token("F3", "function", "signature"),token("Foo", "type", "number") func F3(v *Foo) { println(*v) } -// [F4] equal [strconv.Atoi] //@token("F4", "function", ""),token("strconv", "namespace", ""),token("Atoi", "function", "") +// [F4] equal [strconv.Atoi] //@token("F4", "function", "signature"),token("strconv", "namespace", ""),token("Atoi", "function", "signature") func F4(s string) (int, error) { return 0, nil } diff --git a/gopls/internal/test/marker/testdata/token/issue66809.txt b/gopls/internal/test/marker/testdata/token/issue66809.txt new file mode 100644 index 00000000000..369c0b3dd07 --- /dev/null +++ b/gopls/internal/test/marker/testdata/token/issue66809.txt @@ -0,0 +1,16 @@ +This is a regression test for #66809 (missing 
modifiers for +declarations of function-type variables). + +-- settings.json -- +{ + "semanticTokens": true +} + +-- main.go -- +package main + +func main() { + foo := func(x string) string { return x } //@token("foo", "variable", "definition signature") + _ = foo //@token("foo", "variable", "signature") + foo("hello") //@token("foo", "variable", "signature") +} diff --git a/gopls/internal/test/marker/testdata/token/issue70251.txt b/gopls/internal/test/marker/testdata/token/issue70251.txt new file mode 100644 index 00000000000..25136d654ec --- /dev/null +++ b/gopls/internal/test/marker/testdata/token/issue70251.txt @@ -0,0 +1,13 @@ +This is a regression test for #70251 (missing modifiers for +predeclared interfaces). + +-- settings.json -- +{ + "semanticTokens": true +} + +-- a/a.go -- +package a + +var _ any //@token("any", "type", "defaultLibrary interface") +var _ error //@token("error", "type", "defaultLibrary interface") diff --git a/gopls/internal/test/marker/testdata/token/range.txt b/gopls/internal/test/marker/testdata/token/range.txt index 2f98c043d8e..b4a6065ec94 100644 --- a/gopls/internal/test/marker/testdata/token/range.txt +++ b/gopls/internal/test/marker/testdata/token/range.txt @@ -10,12 +10,12 @@ TODO: add more assertions. 
-- a.go -- package p //@token("package", "keyword", "") -const C = 42 //@token("C", "variable", "definition readonly") +const C = 42 //@token("C", "variable", "definition readonly number") -func F() { //@token("F", "function", "definition") - x := 2 + 3//@token("x", "variable", "definition"),token("2", "number", ""),token("+", "operator", "") - _ = x //@token("x", "variable", "") - _ = F //@token("F", "function", "") +func F() { //@token("F", "function", "definition signature") + x := 2 + 3//@token("x", "variable", "definition number"),token("2", "number", ""),token("+", "operator", "") + _ = x //@token("x", "variable", "number") + _ = F //@token("F", "function", "signature") } func _() { diff --git a/gopls/internal/test/marker/testdata/workspacesymbol/allscope.txt b/gopls/internal/test/marker/testdata/workspacesymbol/allscope.txt index 18fe4e5446f..645a9c967c9 100644 --- a/gopls/internal/test/marker/testdata/workspacesymbol/allscope.txt +++ b/gopls/internal/test/marker/testdata/workspacesymbol/allscope.txt @@ -27,4 +27,4 @@ func Println(s string) { } -- @println -- fmt/fmt.go:5:6-13 mod.test/symbols/fmt.Println Function - fmt.Println Function + fmt.Println Function diff --git a/gopls/internal/util/astutil/fields.go b/gopls/internal/util/astutil/fields.go new file mode 100644 index 00000000000..8b81ea47a49 --- /dev/null +++ b/gopls/internal/util/astutil/fields.go @@ -0,0 +1,35 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package astutil + +import ( + "go/ast" + "iter" +) + +// FlatFields 'flattens' an ast.FieldList, returning an iterator over each +// (name, field) combination in the list. For unnamed fields, the identifier is +// nil. 
+func FlatFields(list *ast.FieldList) iter.Seq2[*ast.Ident, *ast.Field] { + return func(yield func(*ast.Ident, *ast.Field) bool) { + if list == nil { + return + } + + for _, field := range list.List { + if len(field.Names) == 0 { + if !yield(nil, field) { + return + } + } else { + for _, name := range field.Names { + if !yield(name, field) { + return + } + } + } + } + } +} diff --git a/gopls/internal/util/astutil/fields_test.go b/gopls/internal/util/astutil/fields_test.go new file mode 100644 index 00000000000..7344d807fe3 --- /dev/null +++ b/gopls/internal/util/astutil/fields_test.go @@ -0,0 +1,55 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package astutil_test + +import ( + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/token" + "go/types" + "testing" + + "golang.org/x/tools/gopls/internal/util/astutil" +) + +func TestFlatFields(t *testing.T) { + tests := []struct { + params string + want string + }{ + {"", ""}, + {"a int", "a int"}, + {"int", "int"}, + {"a, b int", "a int, b int"}, + {"a, b, c int", "a int, b int, c int"}, + {"int, string", "int, string"}, + {"_ int, b string", "_ int, b string"}, + {"a, _ int, b string", "a int, _ int, b string"}, + } + + for _, test := range tests { + src := fmt.Sprintf("package p; func _(%s)", test.params) + f, err := parser.ParseFile(token.NewFileSet(), "", src, 0) + if err != nil { + t.Fatal(err) + } + params := f.Decls[0].(*ast.FuncDecl).Type.Params + var got bytes.Buffer + for name, field := range astutil.FlatFields(params) { + if got.Len() > 0 { + got.WriteString(", ") + } + if name != nil { + fmt.Fprintf(&got, "%s ", name.Name) + } + got.WriteString(types.ExprString(field.Type)) + } + if got := got.String(); got != test.want { + // align 'got' and 'want' for easier inspection + t.Errorf("FlatFields(%q):\n got: %q\nwant: %q", test.params, got, test.want) + } + } +} diff --git 
a/gopls/internal/util/typesutil/typesutil.go b/gopls/internal/util/typesutil/typesutil.go index 11233e80bd2..98f5605200e 100644 --- a/gopls/internal/util/typesutil/typesutil.go +++ b/gopls/internal/util/typesutil/typesutil.go @@ -7,7 +7,9 @@ package typesutil import ( "bytes" "go/ast" + "go/token" "go/types" + "strings" ) // FileQualifier returns a [types.Qualifier] function that qualifies @@ -54,3 +56,204 @@ func FormatTypeParams(tparams *types.TypeParamList) string { buf.WriteByte(']') return buf.String() } + +// TypesFromContext returns the type (or perhaps zero or multiple types) +// of the "hole" into which the expression identified by path must fit. +// +// For example, given +// +// s, i := "", 0 +// s, i = EXPR +// +// the hole that must be filled by EXPR has type (string, int). +// +// It returns nil on failure. +func TypesFromContext(info *types.Info, path []ast.Node, pos token.Pos) []types.Type { + anyType := types.Universe.Lookup("any").Type() + var typs []types.Type + parent := parentNode(path) + if parent == nil { + return nil + } + + validType := func(t types.Type) types.Type { + if t != nil && !containsInvalid(t) { + return types.Default(t) + } else { + return anyType + } + } + + switch parent := parent.(type) { + case *ast.AssignStmt: + // Append all lhs's type + if len(parent.Rhs) == 1 { + for _, lhs := range parent.Lhs { + t := info.TypeOf(lhs) + typs = append(typs, validType(t)) + } + break + } + // Lhs and Rhs counts do not match, give up + if len(parent.Lhs) != len(parent.Rhs) { + break + } + // Append corresponding index of lhs's type + for i, rhs := range parent.Rhs { + if rhs.Pos() <= pos && pos <= rhs.End() { + t := info.TypeOf(parent.Lhs[i]) + typs = append(typs, validType(t)) + break + } + } + case *ast.ValueSpec: + if len(parent.Values) == 1 { + for _, lhs := range parent.Names { + t := info.TypeOf(lhs) + typs = append(typs, validType(t)) + } + break + } + if len(parent.Values) != len(parent.Names) { + break + } + t := 
info.TypeOf(parent.Type) + typs = append(typs, validType(t)) + case *ast.ReturnStmt: + sig := EnclosingSignature(path, info) + if sig == nil || sig.Results() == nil { + break + } + rets := sig.Results() + // Append all return declarations' type + if len(parent.Results) == 1 { + for i := 0; i < rets.Len(); i++ { + t := rets.At(i).Type() + typs = append(typs, validType(t)) + } + break + } + // Return declaration and actual return counts do not match, give up + if rets.Len() != len(parent.Results) { + break + } + // Append corresponding index of return declaration's type + for i, ret := range parent.Results { + if ret.Pos() <= pos && pos <= ret.End() { + t := rets.At(i).Type() + typs = append(typs, validType(t)) + break + } + } + case *ast.CallExpr: + // Find argument containing pos. + argIdx := -1 + for i, callArg := range parent.Args { + if callArg.Pos() <= pos && pos <= callArg.End() { + argIdx = i + break + } + } + if argIdx == -1 { + break + } + + t := info.TypeOf(parent.Fun) + if t == nil { + break + } + + if sig, ok := t.Underlying().(*types.Signature); ok { + var paramType types.Type + if sig.Variadic() && argIdx >= sig.Params().Len()-1 { + v := sig.Params().At(sig.Params().Len() - 1) + if s, _ := v.Type().(*types.Slice); s != nil { + paramType = s.Elem() + } + } else if argIdx < sig.Params().Len() { + paramType = sig.Params().At(argIdx).Type() + } else { + break + } + if paramType == nil || containsInvalid(paramType) { + paramType = anyType + } + typs = append(typs, paramType) + } + case *ast.IfStmt: + if parent.Cond == path[0] { + typs = append(typs, types.Typ[types.Bool]) + } + case *ast.ForStmt: + if parent.Cond == path[0] { + typs = append(typs, types.Typ[types.Bool]) + } + case *ast.UnaryExpr: + if parent.X == path[0] { + var t types.Type + switch parent.Op { + case token.NOT: + t = types.Typ[types.Bool] + case token.ADD, token.SUB, token.XOR: + t = types.Typ[types.Int] + default: + t = anyType + } + typs = append(typs, t) + } + case *ast.BinaryExpr: + 
if parent.X == path[0] { + t := info.TypeOf(parent.Y) + typs = append(typs, validType(t)) + } else if parent.Y == path[0] { + t := info.TypeOf(parent.X) + typs = append(typs, validType(t)) + } + default: + // TODO: support other kinds of "holes" as the need arises. + } + return typs +} + +// parentNode returns the nodes immediately enclosing path[0], +// ignoring parens. +func parentNode(path []ast.Node) ast.Node { + if len(path) <= 1 { + return nil + } + for _, n := range path[1:] { + if _, ok := n.(*ast.ParenExpr); !ok { + return n + } + } + return nil +} + +// containsInvalid checks if the type name contains "invalid type", +// which is not a valid syntax to generate. +func containsInvalid(t types.Type) bool { + typeString := types.TypeString(t, nil) + return strings.Contains(typeString, types.Typ[types.Invalid].String()) +} + +// EnclosingSignature returns the signature of the innermost +// function enclosing the syntax node denoted by path +// (see [astutil.PathEnclosingInterval]), or nil if the node +// is not within a function. 
+func EnclosingSignature(path []ast.Node, info *types.Info) *types.Signature { + for _, n := range path { + switch n := n.(type) { + case *ast.FuncDecl: + if f, ok := info.Defs[n.Name]; ok { + return f.Type().(*types.Signature) + } + return nil + case *ast.FuncLit: + if f, ok := info.Types[n]; ok { + return f.Type.(*types.Signature) + } + return nil + } + } + return nil +} diff --git a/gopls/internal/work/completion.go b/gopls/internal/work/completion.go index 194721ef36d..870450bd32d 100644 --- a/gopls/internal/work/completion.go +++ b/gopls/internal/work/completion.go @@ -54,7 +54,7 @@ func Completion(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, p pathPrefixSlash := completingFrom pathPrefixAbs := filepath.FromSlash(pathPrefixSlash) if !filepath.IsAbs(pathPrefixAbs) { - pathPrefixAbs = filepath.Join(filepath.Dir(pw.URI.Path()), pathPrefixAbs) + pathPrefixAbs = filepath.Join(pw.URI.DirPath(), pathPrefixAbs) } // pathPrefixDir is the directory that will be walked to find matches. 
diff --git a/gopls/internal/work/diagnostics.go b/gopls/internal/work/diagnostics.go index f1acd4d27c7..06ca48eeab6 100644 --- a/gopls/internal/work/diagnostics.go +++ b/gopls/internal/work/diagnostics.go @@ -81,7 +81,7 @@ func diagnoseOne(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle) } func modFileURI(pw *cache.ParsedWorkFile, use *modfile.Use) protocol.DocumentURI { - workdir := filepath.Dir(pw.URI.Path()) + workdir := pw.URI.DirPath() modroot := filepath.FromSlash(use.Path) if !filepath.IsAbs(modroot) { diff --git a/internal/analysisinternal/analysis.go b/internal/analysisinternal/analysis.go index 4ccaa210af1..fe67b0fa27a 100644 --- a/internal/analysisinternal/analysis.go +++ b/internal/analysisinternal/analysis.go @@ -15,7 +15,6 @@ import ( "go/types" "os" pathpkg "path" - "strconv" "golang.org/x/tools/go/analysis" ) @@ -66,200 +65,23 @@ func TypeErrorEndPos(fset *token.FileSet, src []byte, start token.Pos) token.Pos return end } -func ZeroValue(f *ast.File, pkg *types.Package, typ types.Type) ast.Expr { - // TODO(adonovan): think about generics, and also generic aliases. - under := types.Unalias(typ) - // Don't call Underlying unconditionally: although it removes - // Named and Alias, it also removes TypeParam. - if n, ok := under.(*types.Named); ok { - under = n.Underlying() - } - switch under := under.(type) { - case *types.Basic: - switch { - case under.Info()&types.IsNumeric != 0: - return &ast.BasicLit{Kind: token.INT, Value: "0"} - case under.Info()&types.IsBoolean != 0: - return &ast.Ident{Name: "false"} - case under.Info()&types.IsString != 0: - return &ast.BasicLit{Kind: token.STRING, Value: `""`} - default: - panic(fmt.Sprintf("unknown basic type %v", under)) - } - case *types.Chan, *types.Interface, *types.Map, *types.Pointer, *types.Signature, *types.Slice, *types.Array: - return ast.NewIdent("nil") - case *types.Struct: - texpr := TypeExpr(f, pkg, typ) // typ because we want the name here. 
- if texpr == nil { - return nil - } - return &ast.CompositeLit{ - Type: texpr, - } - } - return nil -} - -// IsZeroValue checks whether the given expression is a 'zero value' (as determined by output of -// analysisinternal.ZeroValue) -func IsZeroValue(expr ast.Expr) bool { - switch e := expr.(type) { - case *ast.BasicLit: - return e.Value == "0" || e.Value == `""` - case *ast.Ident: - return e.Name == "nil" || e.Name == "false" - default: - return false - } -} - -// TypeExpr returns syntax for the specified type. References to -// named types from packages other than pkg are qualified by an appropriate -// package name, as defined by the import environment of file. -func TypeExpr(f *ast.File, pkg *types.Package, typ types.Type) ast.Expr { - switch t := typ.(type) { - case *types.Basic: - switch t.Kind() { - case types.UnsafePointer: - return &ast.SelectorExpr{X: ast.NewIdent("unsafe"), Sel: ast.NewIdent("Pointer")} - default: - return ast.NewIdent(t.Name()) - } - case *types.Pointer: - x := TypeExpr(f, pkg, t.Elem()) - if x == nil { - return nil - } - return &ast.UnaryExpr{ - Op: token.MUL, - X: x, - } - case *types.Array: - elt := TypeExpr(f, pkg, t.Elem()) - if elt == nil { - return nil - } - return &ast.ArrayType{ - Len: &ast.BasicLit{ - Kind: token.INT, - Value: fmt.Sprintf("%d", t.Len()), - }, - Elt: elt, - } - case *types.Slice: - elt := TypeExpr(f, pkg, t.Elem()) - if elt == nil { - return nil - } - return &ast.ArrayType{ - Elt: elt, - } - case *types.Map: - key := TypeExpr(f, pkg, t.Key()) - value := TypeExpr(f, pkg, t.Elem()) - if key == nil || value == nil { - return nil - } - return &ast.MapType{ - Key: key, - Value: value, - } - case *types.Chan: - dir := ast.ChanDir(t.Dir()) - if t.Dir() == types.SendRecv { - dir = ast.SEND | ast.RECV - } - value := TypeExpr(f, pkg, t.Elem()) - if value == nil { - return nil - } - return &ast.ChanType{ - Dir: dir, - Value: value, - } - case *types.Signature: - var params []*ast.Field - for i := 0; i < 
t.Params().Len(); i++ { - p := TypeExpr(f, pkg, t.Params().At(i).Type()) - if p == nil { - return nil - } - params = append(params, &ast.Field{ - Type: p, - Names: []*ast.Ident{ - { - Name: t.Params().At(i).Name(), - }, - }, - }) - } - if t.Variadic() { - last := params[len(params)-1] - last.Type = &ast.Ellipsis{Elt: last.Type.(*ast.ArrayType).Elt} - } - var returns []*ast.Field - for i := 0; i < t.Results().Len(); i++ { - r := TypeExpr(f, pkg, t.Results().At(i).Type()) - if r == nil { - return nil - } - returns = append(returns, &ast.Field{ - Type: r, - }) - } - return &ast.FuncType{ - Params: &ast.FieldList{ - List: params, - }, - Results: &ast.FieldList{ - List: returns, - }, - } - case interface{ Obj() *types.TypeName }: // *types.{Alias,Named,TypeParam} - if t.Obj().Pkg() == nil { - return ast.NewIdent(t.Obj().Name()) - } - if t.Obj().Pkg() == pkg { - return ast.NewIdent(t.Obj().Name()) - } - pkgName := t.Obj().Pkg().Name() - - // If the file already imports the package under another name, use that. - for _, cand := range f.Imports { - if path, _ := strconv.Unquote(cand.Path.Value); path == t.Obj().Pkg().Path() { - if cand.Name != nil && cand.Name.Name != "" { - pkgName = cand.Name.Name - } - } - } - if pkgName == "." { - return ast.NewIdent(t.Obj().Name()) - } - return &ast.SelectorExpr{ - X: ast.NewIdent(pkgName), - Sel: ast.NewIdent(t.Obj().Name()), - } - case *types.Struct: - return ast.NewIdent(t.String()) - case *types.Interface: - return ast.NewIdent(t.String()) - default: - return nil - } -} - -// StmtToInsertVarBefore returns the ast.Stmt before which we can safely insert a new variable. -// Some examples: +// StmtToInsertVarBefore returns the ast.Stmt before which we can +// safely insert a new var declaration, or nil if the path denotes a +// node outside any statement. 
// // Basic Example: -// z := 1 -// y := z + x +// +// z := 1 +// y := z + x +// // If x is undeclared, then this function would return `y := z + x`, so that we // can insert `x := ` on the line before `y := z + x`. // // If stmt example: -// if z == 1 { -// } else if z == y {} +// +// if z == 1 { +// } else if z == y {} +// // If y is undeclared, then this function would return `if z == 1 {`, because we cannot // insert a statement between an if and an else if statement. As a result, we need to find // the top of the if chain to insert `y := ` before. @@ -272,7 +94,7 @@ func StmtToInsertVarBefore(path []ast.Node) ast.Stmt { } } if enclosingIndex == -1 { - return nil + return nil // no enclosing statement: outside function } enclosingStmt := path[enclosingIndex] switch enclosingStmt.(type) { @@ -280,6 +102,9 @@ func StmtToInsertVarBefore(path []ast.Node) ast.Stmt { // The enclosingStmt is inside of the if declaration, // We need to check if we are in an else-if stmt and // get the base if statement. + // TODO(adonovan): for non-constants, it may be preferable + // to add the decl as the Init field of the innermost + // enclosing ast.IfStmt. 
return baseIfStmt(path, enclosingIndex) case *ast.CaseClause: // Get the enclosing switch stmt if the enclosingStmt is diff --git a/internal/drivertest/driver_test.go b/internal/drivertest/driver_test.go index c1e3729a2fb..e1b170e2e43 100644 --- a/internal/drivertest/driver_test.go +++ b/internal/drivertest/driver_test.go @@ -68,7 +68,7 @@ package lib packages.NeedModule | packages.NeedEmbedFiles | packages.LoadMode(packagesinternal.DepsErrors) | - packages.LoadMode(packagesinternal.ForTest), + packages.NeedForTest, } tests := []struct { diff --git a/internal/facts/facts_test.go b/internal/facts/facts_test.go index bb7d36a07ad..0143fc5a298 100644 --- a/internal/facts/facts_test.go +++ b/internal/facts/facts_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:debug gotypesalias=1 + package facts_test import ( @@ -18,8 +20,10 @@ import ( "golang.org/x/tools/go/analysis/analysistest" "golang.org/x/tools/go/packages" + "golang.org/x/tools/internal/aliases" "golang.org/x/tools/internal/facts" "golang.org/x/tools/internal/testenv" + "golang.org/x/tools/internal/typesinternal" ) type myFact struct { @@ -35,10 +39,9 @@ func init() { func TestEncodeDecode(t *testing.T) { tests := []struct { - name string - typeparams bool // requires typeparams to be enabled - files map[string]string - plookups []pkgLookups // see testEncodeDecode for details + name string + files map[string]string + plookups []pkgLookups // see testEncodeDecode for details }{ { name: "loading-order", @@ -184,8 +187,7 @@ func TestEncodeDecode(t *testing.T) { }, }, { - name: "typeparams", - typeparams: true, + name: "typeparams", files: map[string]string{ "a/a.go": `package a type T1 int @@ -202,9 +204,9 @@ func TestEncodeDecode(t *testing.T) { type N3[T a.T3] func() T type N4[T a.T4|int8] func() T type N5[T interface{Bar() a.T5} ] func() T - + type t5 struct{}; func (t5) Bar() a.T5 { return 0 } - + var G1 N1[a.T1] var G2 
func() N2[a.T2] var G3 N3[a.T3] @@ -222,7 +224,7 @@ func TestEncodeDecode(t *testing.T) { v5 = b.G5 v6 = b.F6[t6] ) - + type t6 struct{}; func (t6) Foo() {} `, }, @@ -244,9 +246,7 @@ func TestEncodeDecode(t *testing.T) { }, }, } - - for i := range tests { - test := tests[i] + for _, test := range tests { t.Run(test.name, func(t *testing.T) { t.Parallel() testEncodeDecode(t, test.files, test.plookups) @@ -254,9 +254,36 @@ func TestEncodeDecode(t *testing.T) { } } +func TestEncodeDecodeAliases(t *testing.T) { + testenv.NeedsGo1Point(t, 24) + + files := map[string]string{ + "a/a.go": `package a + type A = int + `, + "b/b.go": `package b + import "a" + type B = a.A + `, + "c/c.go": `package c + import "b"; + type N1[T int|~string] = struct{} + + var V1 = N1[b.B]{} + `, + } + plookups := []pkgLookups{ + {"a", []lookup{}}, + {"b", []lookup{}}, + // fake objexpr for RHS of V1's type arg (see customFind hack) + {"c", []lookup{{"c.V1->c.N1->b.B->a.A", "myFact(a.A)"}}}, + } + testEncodeDecode(t, files, plookups) +} + type lookup struct { - objexpr string - want string + objexpr string // expression whose type is a named type + want string // printed form of fact associated with that type (or "no fact") } type pkgLookups struct { @@ -345,6 +372,19 @@ func testEncodeDecode(t *testing.T, files map[string]string, tests []pkgLookups) } } +// customFind allows for overriding how an object is looked up +// by find. This is necessary for objects that are accessible through +// the API but are not the type of any expression we can pass to types.CheckExpr. 
+var customFind = map[string]func(p *types.Package) types.Object{ + "c.V1->c.N1->b.B->a.A": func(p *types.Package) types.Object { + cV1 := p.Scope().Lookup("V1") + cN1 := cV1.Type().(*types.Alias) + aT1 := aliases.TypeArgs(cN1).At(0).(*types.Alias) + zZ1 := aliases.Rhs(aT1).(*types.Alias) + return zZ1.Obj() + }, +} + func find(p *types.Package, expr string) types.Object { // types.Eval only allows us to compute a TypeName object for an expression. // TODO(adonovan): support other expressions that denote an object: @@ -352,7 +392,9 @@ func find(p *types.Package, expr string) types.Object { // - new(T).f for a field or method // I've added CheckExpr in https://go-review.googlesource.com/c/go/+/144677. // If that becomes available, use it. - + if f := customFind[expr]; f != nil { + return f(p) + } // Choose an arbitrary position within the (single-file) package // so that we are within the scope of its import declarations. somepos := p.Scope().Lookup(p.Scope().Names()[0]).Pos() @@ -360,7 +402,7 @@ func find(p *types.Package, expr string) types.Object { if err != nil { return nil } - if n, ok := types.Unalias(tv.Type).(*types.Named); ok { + if n, ok := tv.Type.(typesinternal.NamedOrAlias); ok { return n.Obj() } return nil diff --git a/internal/facts/imports.go b/internal/facts/imports.go index c36f2a5af0c..ed5ec5fa131 100644 --- a/internal/facts/imports.go +++ b/internal/facts/imports.go @@ -6,6 +6,9 @@ package facts import ( "go/types" + + "golang.org/x/tools/internal/aliases" + "golang.org/x/tools/internal/typesinternal" ) // importMap computes the import map for a package by traversing the @@ -45,32 +48,41 @@ func importMap(imports []*types.Package) map[string]*types.Package { addType = func(T types.Type) { switch T := T.(type) { - case *types.Alias: - addType(types.Unalias(T)) case *types.Basic: // nop - case *types.Named: + case typesinternal.NamedOrAlias: // *types.{Named,Alias} + // Add the type arguments if this is an instance. 
+ if targs := typesinternal.TypeArgs(T); targs.Len() > 0 { + for i := 0; i < targs.Len(); i++ { + addType(targs.At(i)) + } + } + // Remove infinite expansions of *types.Named by always looking at the origin. // Some named types with type parameters [that will not type check] have // infinite expansions: // type N[T any] struct { F *N[N[T]] } // importMap() is called on such types when Analyzer.RunDespiteErrors is true. - T = T.Origin() + T = typesinternal.Origin(T) if !typs[T] { typs[T] = true + + // common aspects addObj(T.Obj()) - addType(T.Underlying()) - for i := 0; i < T.NumMethods(); i++ { - addObj(T.Method(i)) - } - if tparams := T.TypeParams(); tparams != nil { + if tparams := typesinternal.TypeParams(T); tparams.Len() > 0 { for i := 0; i < tparams.Len(); i++ { addType(tparams.At(i)) } } - if targs := T.TypeArgs(); targs != nil { - for i := 0; i < targs.Len(); i++ { - addType(targs.At(i)) + + // variant aspects + switch T := T.(type) { + case *types.Alias: + addType(aliases.Rhs(T)) + case *types.Named: + addType(T.Underlying()) + for i := 0; i < T.NumMethods(); i++ { + addObj(T.Method(i)) } } } diff --git a/internal/gcimporter/exportdata.go b/internal/gcimporter/exportdata.go index f6437feb1cf..6f5d8a21391 100644 --- a/internal/gcimporter/exportdata.go +++ b/internal/gcimporter/exportdata.go @@ -39,12 +39,15 @@ func readGopackHeader(r *bufio.Reader) (name string, size int64, err error) { } // FindExportData positions the reader r at the beginning of the -// export data section of an underlying GC-created object/archive +// export data section of an underlying cmd/compile created archive // file by reading from it. The reader must be positioned at the -// start of the file before calling this function. The hdr result -// is the string before the export data, either "$$" or "$$B". -// The size result is the length of the export data in bytes, or -1 if not known. 
-func FindExportData(r *bufio.Reader) (hdr string, size int64, err error) { +// start of the file before calling this function. +// The size result is the length of the export data in bytes. +// +// This function is needed by [gcexportdata.Read], which must +// accept inputs produced by the last two releases of cmd/compile, +// plus tip. +func FindExportData(r *bufio.Reader) (size int64, err error) { // Read first line to make sure this is an object file. line, err := r.ReadSlice('\n') if err != nil { @@ -52,27 +55,32 @@ func FindExportData(r *bufio.Reader) (hdr string, size int64, err error) { return } - if string(line) == "!\n" { - // Archive file. Scan to __.PKGDEF. - var name string - if name, size, err = readGopackHeader(r); err != nil { - return - } + // Is the first line an archive file signature? + if string(line) != "!\n" { + err = fmt.Errorf("not the start of an archive file (%q)", line) + return + } - // First entry should be __.PKGDEF. - if name != "__.PKGDEF" { - err = fmt.Errorf("go archive is missing __.PKGDEF") - return - } + // Archive file. Scan to __.PKGDEF. + var name string + if name, size, err = readGopackHeader(r); err != nil { + return + } + arsize := size - // Read first line of __.PKGDEF data, so that line - // is once again the first line of the input. - if line, err = r.ReadSlice('\n'); err != nil { - err = fmt.Errorf("can't find export data (%v)", err) - return - } - size -= int64(len(line)) + // First entry should be __.PKGDEF. + if name != "__.PKGDEF" { + err = fmt.Errorf("go archive is missing __.PKGDEF") + return + } + + // Read first line of __.PKGDEF data, so that line + // is once again the first line of the input. + if line, err = r.ReadSlice('\n'); err != nil { + err = fmt.Errorf("can't find export data (%v)", err) + return } + size -= int64(len(line)) // Now at __.PKGDEF in archive or still at beginning of file. // Either way, line should begin with "go object ". 
@@ -81,8 +89,8 @@ func FindExportData(r *bufio.Reader) (hdr string, size int64, err error) { return } - // Skip over object header to export data. - // Begins after first line starting with $$. + // Skip over object headers to get to the export data section header "$$B\n". + // Object headers are lines that do not start with '$'. for line[0] != '$' { if line, err = r.ReadSlice('\n'); err != nil { err = fmt.Errorf("can't find export data (%v)", err) @@ -90,9 +98,18 @@ func FindExportData(r *bufio.Reader) (hdr string, size int64, err error) { } size -= int64(len(line)) } - hdr = string(line) + + // Check for the binary export data section header "$$B\n". + hdr := string(line) + if hdr != "$$B\n" { + err = fmt.Errorf("unknown export data header: %q", hdr) + return + } + // TODO(taking): Remove end-of-section marker "\n$$\n" from size. + if size < 0 { - size = -1 + err = fmt.Errorf("invalid size (%d) in the archive file: %d bytes remain without section headers (recompile package)", arsize, size) + return } return diff --git a/internal/gcimporter/gcimporter.go b/internal/gcimporter/gcimporter.go index e6c5d51f8e5..dbbca860432 100644 --- a/internal/gcimporter/gcimporter.go +++ b/internal/gcimporter/gcimporter.go @@ -161,6 +161,8 @@ func FindPkg(path, srcDir string) (filename, id string) { // Import imports a gc-generated package given its import path and srcDir, adds // the corresponding package object to the packages map, and returns the object. // The packages map must contain all packages already imported. +// +// TODO(taking): Import is only used in tests. Move to gcimporter_test. 
func Import(packages map[string]*types.Package, path, srcDir string, lookup func(path string) (io.ReadCloser, error)) (pkg *types.Package, err error) { var rc io.ReadCloser var filename, id string @@ -210,58 +212,50 @@ func Import(packages map[string]*types.Package, path, srcDir string, lookup func } defer rc.Close() - var hdr string var size int64 buf := bufio.NewReader(rc) - if hdr, size, err = FindExportData(buf); err != nil { + if size, err = FindExportData(buf); err != nil { return } - switch hdr { - case "$$B\n": - var data []byte - data, err = io.ReadAll(buf) - if err != nil { - break - } + var data []byte + data, err = io.ReadAll(buf) + if err != nil { + return + } + if len(data) == 0 { + return nil, fmt.Errorf("no data to load a package from for path %s", id) + } - // TODO(gri): allow clients of go/importer to provide a FileSet. - // Or, define a new standard go/types/gcexportdata package. - fset := token.NewFileSet() - - // Select appropriate importer. - if len(data) > 0 { - switch data[0] { - case 'v', 'c', 'd': - // binary: emitted by cmd/compile till go1.10; obsolete. - return nil, fmt.Errorf("binary (%c) import format is no longer supported", data[0]) - - case 'i': - // indexed: emitted by cmd/compile till go1.19; - // now used only for serializing go/types. - // See https://github.com/golang/go/issues/69491. - _, pkg, err := IImportData(fset, packages, data[1:], id) - return pkg, err - - case 'u': - // unified: emitted by cmd/compile since go1.20. - _, pkg, err := UImportData(fset, packages, data[1:size], id) - return pkg, err - - default: - l := len(data) - if l > 10 { - l = 10 - } - return nil, fmt.Errorf("unexpected export data with prefix %q for path %s", string(data[:l]), id) - } - } + // TODO(gri): allow clients of go/importer to provide a FileSet. + // Or, define a new standard go/types/gcexportdata package. + fset := token.NewFileSet() + + // Select appropriate importer. 
+ switch data[0] { + case 'v', 'c', 'd': + // binary: emitted by cmd/compile till go1.10; obsolete. + return nil, fmt.Errorf("binary (%c) import format is no longer supported", data[0]) + + case 'i': + // indexed: emitted by cmd/compile till go1.19; + // now used only for serializing go/types. + // See https://github.com/golang/go/issues/69491. + _, pkg, err := IImportData(fset, packages, data[1:], id) + return pkg, err + + case 'u': + // unified: emitted by cmd/compile since go1.20. + _, pkg, err := UImportData(fset, packages, data[1:size], id) + return pkg, err default: - err = fmt.Errorf("unknown export data header: %q", hdr) + l := len(data) + if l > 10 { + l = 10 + } + return nil, fmt.Errorf("unexpected export data with prefix %q for path %s", string(data[:l]), id) } - - return } type byPath []*types.Package diff --git a/internal/imports/fix_test.go b/internal/imports/fix_test.go index 5409db0217f..02ddd480dfd 100644 --- a/internal/imports/fix_test.go +++ b/internal/imports/fix_test.go @@ -1652,9 +1652,9 @@ var _ = bytes.Buffer } func TestStdlibSelfImports(t *testing.T) { - const input = `package ecdsa + const input = `package rc4 -var _ = ecdsa.GenerateKey +var _ = rc4.NewCipher ` testConfig{ @@ -1663,7 +1663,7 @@ var _ = ecdsa.GenerateKey Files: fm{"x.go": "package x"}, }, }.test(t, func(t *goimportTest) { - got, err := t.processNonModule(filepath.Join(t.goroot, "src/crypto/ecdsa/foo.go"), []byte(input), nil) + got, err := t.processNonModule(filepath.Join(t.goroot, "src/crypto/rc4/foo.go"), []byte(input), nil) if err != nil { t.Fatalf("Process() = %v", err) } diff --git a/internal/imports/source.go b/internal/imports/source.go index 5d2aeeebc95..cbe4f3c5ba1 100644 --- a/internal/imports/source.go +++ b/internal/imports/source.go @@ -59,5 +59,5 @@ type Source interface { // candidates satisfy all missing references for that package name. It is up // to each data source to select the best result for each entry in the // missing map. 
- ResolveReferences(ctx context.Context, filename string, missing References) (map[PackageName]*Result, error) + ResolveReferences(ctx context.Context, filename string, missing References) ([]*Result, error) } diff --git a/internal/imports/source_env.go b/internal/imports/source_env.go index ff9555d2879..d14abaa3195 100644 --- a/internal/imports/source_env.go +++ b/internal/imports/source_env.go @@ -48,7 +48,7 @@ func (s *ProcessEnvSource) LoadPackageNames(ctx context.Context, srcDir string, return r.loadPackageNames(unknown, srcDir) } -func (s *ProcessEnvSource) ResolveReferences(ctx context.Context, filename string, refs map[string]map[string]bool) (map[string]*Result, error) { +func (s *ProcessEnvSource) ResolveReferences(ctx context.Context, filename string, refs map[string]map[string]bool) ([]*Result, error) { var mu sync.Mutex found := make(map[string][]pkgDistance) callback := &scanCallback{ @@ -121,5 +121,9 @@ func (s *ProcessEnvSource) ResolveReferences(ctx context.Context, filename strin if err := g.Wait(); err != nil { return nil, err } - return results, nil + var ans []*Result + for _, x := range results { + ans = append(ans, x) + } + return ans, nil } diff --git a/internal/imports/source_modindex.go b/internal/imports/source_modindex.go new file mode 100644 index 00000000000..05229f06ce6 --- /dev/null +++ b/internal/imports/source_modindex.go @@ -0,0 +1,103 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package imports + +import ( + "context" + "sync" + "time" + + "golang.org/x/tools/internal/modindex" +) + +// This code is here rather than in the modindex package +// to avoid import loops + +// implements Source using modindex, so only for module cache. +// +// this is perhaps over-engineered. A new Index is read at first use. +// And then Update is called after every 15 minutes, and a new Index +// is read if the index changed. 
It is not clear the Mutex is needed. +type IndexSource struct { + modcachedir string + mutex sync.Mutex + ix *modindex.Index + expires time.Time +} + +// create a new Source. Called from NewView in cache/session.go. +func NewIndexSource(cachedir string) *IndexSource { + return &IndexSource{modcachedir: cachedir} +} + +func (s *IndexSource) LoadPackageNames(ctx context.Context, srcDir string, paths []ImportPath) (map[ImportPath]PackageName, error) { + /// This is used by goimports to resolve the package names of imports of the + // current package, which is irrelevant for the module cache. + return nil, nil +} + +func (s *IndexSource) ResolveReferences(ctx context.Context, filename string, missing References) ([]*Result, error) { + if err := s.maybeReadIndex(); err != nil { + return nil, err + } + var cs []modindex.Candidate + for pkg, nms := range missing { + for nm := range nms { + x := s.ix.Lookup(pkg, nm, false) + cs = append(cs, x...) + } + } + found := make(map[string]*Result) + for _, c := range cs { + var x *Result + if x = found[c.ImportPath]; x == nil { + x = &Result{ + Import: &ImportInfo{ + ImportPath: c.ImportPath, + Name: "", + }, + Package: &PackageInfo{ + Name: c.PkgName, + Exports: make(map[string]bool), + }, + } + found[c.ImportPath] = x + } + x.Package.Exports[c.Name] = true + } + var ans []*Result + for _, x := range found { + ans = append(ans, x) + } + return ans, nil +} + +func (s *IndexSource) maybeReadIndex() error { + s.mutex.Lock() + defer s.mutex.Unlock() + + var readIndex bool + if time.Now().After(s.expires) { + ok, err := modindex.Update(s.modcachedir) + if err != nil { + return err + } + if ok { + readIndex = true + } + } + + if readIndex || s.ix == nil { + ix, err := modindex.ReadIndex(s.modcachedir) + if err != nil { + return err + } + s.ix = ix + // for now refresh every 15 minutes + s.expires = time.Now().Add(time.Minute * 15) + } + + return nil +} diff --git a/internal/imports/sourcex_test.go b/internal/imports/sourcex_test.go new 
file mode 100644 index 00000000000..e8a4d537f8f --- /dev/null +++ b/internal/imports/sourcex_test.go @@ -0,0 +1,107 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package imports_test + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" + "golang.org/x/tools/internal/imports" + "golang.org/x/tools/internal/modindex" +) + +// There are two cached packages, both resolving foo.Foo, +// but only one resolving foo.Bar +var ( + foo = tpkg{ + repo: "foo.com", + dir: "foo@v1.0.0", + syms: []string{"Foo"}, + } + foobar = tpkg{ + repo: "bar.com", + dir: "foo@v1.0.0", + syms: []string{"Foo", "Bar"}, + } + + fx = `package main + var _ = foo.Foo + var _ = foo.Bar + ` +) + +type tpkg struct { + // all packages are named foo + repo string // e.g. foo.com + dir string // e.g., foo@v1.0.0 + syms []string // exported syms +} + +func newpkgs(cachedir string, pks ...*tpkg) error { + for _, p := range pks { + fname := filepath.Join(cachedir, p.repo, p.dir, "foo.go") + if err := os.MkdirAll(filepath.Dir(fname), 0755); err != nil { + return err + } + fd, err := os.Create(fname) + if err != nil { + return err + } + fmt.Fprintf(fd, "package foo\n") + for _, s := range p.syms { + fmt.Fprintf(fd, "func %s() {}\n", s) + } + fd.Close() + } + return nil +} + +func TestSource(t *testing.T) { + + dirs := testDirs(t) + if err := newpkgs(dirs.cachedir, &foo, &foobar); err != nil { + t.Fatal(err) + } + source := imports.NewIndexSource(dirs.cachedir) + ctx := context.Background() + fixes, err := imports.FixImports(ctx, "tfile.go", []byte(fx), "unused", nil, source) + if err != nil { + t.Fatal(err) + } + opts := imports.Options{} + // ApplyFixes needs a non-nil opts + got, err := imports.ApplyFixes(fixes, "tfile.go", []byte(fx), &opts, 0) + + fxwant := "package main\n\nimport \"bar.com/foo\"\n\nvar _ = foo.Foo\nvar _ = foo.Bar\n" + 
if diff := cmp.Diff(string(got), fxwant); diff != "" { + t.Errorf("FixImports got\n%q, wanted\n%q\ndiff is\n%s", string(got), fxwant, diff) + } +} + +type dirs struct { + tmpdir string + cachedir string + rootdir string // goroot if we need it, which we don't +} + +func testDirs(t *testing.T) dirs { + t.Helper() + dir := t.TempDir() + modindex.IndexDir = func() (string, error) { return dir, nil } + x := dirs{ + tmpdir: dir, + cachedir: filepath.Join(dir, "pkg", "mod"), + rootdir: filepath.Join(dir, "root"), + } + if err := os.MkdirAll(x.cachedir, 0755); err != nil { + t.Fatal(err) + } + os.MkdirAll(x.rootdir, 0755) + return x +} diff --git a/internal/modindex/dir_test.go b/internal/modindex/dir_test.go index cbdf194ddb4..6e76f825116 100644 --- a/internal/modindex/dir_test.go +++ b/internal/modindex/dir_test.go @@ -205,41 +205,42 @@ func TestDirsSinglePath(t *testing.T) { } } -/* more data for tests - -directories.go:169: WEIRD cloud.google.com/go/iam/admin/apiv1 -map[cloud.google.com/go:1 cloud.google.com/go/iam:5]: -[cloud.google.com/go/iam@v0.12.0/admin/apiv1 -cloud.google.com/go/iam@v0.13.0/admin/apiv1 -cloud.google.com/go/iam@v0.3.0/admin/apiv1 -cloud.google.com/go/iam@v0.7.0/admin/apiv1 -cloud.google.com/go/iam@v1.0.1/admin/apiv1 -cloud.google.com/go@v0.94.0/iam/admin/apiv1] -directories.go:169: WEIRD cloud.google.com/go/iam -map[cloud.google.com/go:1 cloud.google.com/go/iam:5]: -[cloud.google.com/go/iam@v0.12.0 cloud.google.com/go/iam@v0.13.0 -cloud.google.com/go/iam@v0.3.0 cloud.google.com/go/iam@v0.7.0 -cloud.google.com/go/iam@v1.0.1 cloud.google.com/go@v0.94.0/iam] -directories.go:169: WEIRD cloud.google.com/go/compute/apiv1 -map[cloud.google.com/go:1 cloud.google.com/go/compute:4]: -[cloud.google.com/go/compute@v1.12.1/apiv1 -cloud.google.com/go/compute@v1.18.0/apiv1 -cloud.google.com/go/compute@v1.19.0/apiv1 -cloud.google.com/go/compute@v1.7.0/apiv1 -cloud.google.com/go@v0.94.0/compute/apiv1] -directories.go:169: WEIRD 
cloud.google.com/go/longrunning/autogen -map[cloud.google.com/go:2 cloud.google.com/go/longrunning:2]: -[cloud.google.com/go/longrunning@v0.3.0/autogen -cloud.google.com/go/longrunning@v0.4.1/autogen -cloud.google.com/go@v0.104.0/longrunning/autogen -cloud.google.com/go@v0.94.0/longrunning/autogen] -directories.go:169: WEIRD cloud.google.com/go/iam/credentials/apiv1 -map[cloud.google.com/go:1 cloud.google.com/go/iam:5]: -[cloud.google.com/go/iam@v0.12.0/credentials/apiv1 -cloud.google.com/go/iam@v0.13.0/credentials/apiv1 -cloud.google.com/go/iam@v0.3.0/credentials/apiv1 -cloud.google.com/go/iam@v0.7.0/credentials/apiv1 -cloud.google.com/go/iam@v1.0.1/credentials/apiv1 -cloud.google.com/go@v0.94.0/iam/credentials/apiv1] +func TestMissingCachedir(t *testing.T) { + // behave properly if the cached dir is empty + dir := testModCache(t) + if err := Create(dir); err != nil { + t.Fatal(err) + } + ixd, err := IndexDir() + if err != nil { + t.Fatal(err) + } + des, err := os.ReadDir(ixd) + if err != nil { + t.Fatal(err) + } + if len(des) != 2 { + t.Errorf("got %d, but expected two entries in index dir", len(des)) + } +} -*/ +func TestMissingIndex(t *testing.T) { + // behave properly if there is no existing index + dir := testModCache(t) + if ok, err := Update(dir); err != nil { + t.Fatal(err) + } else if !ok { + t.Error("Update returned !ok") + } + ixd, err := IndexDir() + if err != nil { + t.Fatal(err) + } + des, err := os.ReadDir(ixd) + if err != nil { + t.Fatal(err) + } + if len(des) != 2 { + t.Errorf("got %d, but expected two entries in index dir", len(des)) + } +} diff --git a/internal/modindex/index.go b/internal/modindex/index.go index c2443db408a..27b6dd832d7 100644 --- a/internal/modindex/index.go +++ b/internal/modindex/index.go @@ -7,9 +7,11 @@ package modindex import ( "bufio" "encoding/csv" + "errors" "fmt" "hash/crc64" "io" + "io/fs" "log" "os" "path/filepath" @@ -85,7 +87,8 @@ type Entry struct { // ReadIndex reads the latest version of the on-disk index // for
the cache directory cd. -// It returns nil if there is none, or if there is an error. +// It returns (nil, nil) if there is no index, but returns +// a non-nil error if the index exists but could not be read. func ReadIndex(cachedir string) (*Index, error) { cachedir, err := filepath.Abs(cachedir) if err != nil { @@ -100,10 +103,10 @@ func ReadIndex(cachedir string) (*Index, error) { iname := filepath.Join(dir, base) buf, err := os.ReadFile(iname) if err != nil { - if err == os.ErrNotExist { + if errors.Is(err, fs.ErrNotExist) { return nil, nil } - return nil, fmt.Errorf("reading %s: %s %T", iname, err, err) + return nil, fmt.Errorf("cannot read %s: %w", iname, err) } fname := filepath.Join(dir, string(buf)) fd, err := os.Open(fname) @@ -235,7 +238,6 @@ func writeIndexToFile(x *Index, fd *os.File) error { if err := w.Flush(); err != nil { return err } - log.Printf("%d Entries %d names", len(x.Entries), cnt) return nil } diff --git a/internal/modindex/modindex.go b/internal/modindex/modindex.go index 6d0b5f09d94..355a53e71aa 100644 --- a/internal/modindex/modindex.go +++ b/internal/modindex/modindex.go @@ -67,7 +67,7 @@ func modindexTimed(onlyBefore time.Time, cachedir Abspath, clear bool) (bool, er if clear && err != nil { return false, err } - // TODO(pjw): check that most of those directorie still exist + // TODO(pjw): check that most of those directories still exist } cfg := &work{ onlyBefore: onlyBefore, @@ -80,8 +80,8 @@ func modindexTimed(onlyBefore time.Time, cachedir Abspath, clear bool) (bool, er if err := cfg.buildIndex(); err != nil { return false, err } - if len(cfg.newIndex.Entries) == 0 { - // no changes, don't write a new index + if len(cfg.newIndex.Entries) == 0 && curIndex != nil { + // no changes from existing curIndex, don't write a new index return false, nil } if err := cfg.writeIndex(); err != nil { diff --git a/internal/packagesinternal/packages.go b/internal/packagesinternal/packages.go index 44719de173b..66e69b4389d 100644 --- 
a/internal/packagesinternal/packages.go +++ b/internal/packagesinternal/packages.go @@ -5,7 +5,6 @@ // Package packagesinternal exposes internal-only fields from go/packages. package packagesinternal -var GetForTest = func(p interface{}) string { return "" } var GetDepsErrors = func(p interface{}) []*PackageError { return nil } type PackageError struct { @@ -16,7 +15,6 @@ type PackageError struct { var TypecheckCgo int var DepsErrors int // must be set as a LoadMode to call GetDepsErrors -var ForTest int // must be set as a LoadMode to call GetForTest var SetModFlag = func(config interface{}, value string) {} var SetModFile = func(config interface{}, value string) {} diff --git a/internal/refactor/inline/callee.go b/internal/refactor/inline/callee.go index c72699cb772..ab1cbcb0070 100644 --- a/internal/refactor/inline/callee.go +++ b/internal/refactor/inline/callee.go @@ -18,6 +18,7 @@ import ( "golang.org/x/tools/go/types/typeutil" "golang.org/x/tools/internal/typeparams" + "golang.org/x/tools/internal/typesinternal" ) // A Callee holds information about an inlinable function. Gob-serializable. @@ -72,8 +73,8 @@ type object struct { PkgName string // name of object's package (or imported package if kind="pkgname") // TODO(rfindley): should we also track LocalPkgName here? Do we want to // preserve the local package name? - ValidPos bool // Object.Pos().IsValid() - Shadow map[string]bool // names shadowed at one of the object's refs + ValidPos bool // Object.Pos().IsValid() + Shadow shadowMap // shadowing info for the object's refs } // AnalyzeCallee analyzes a function that is a candidate for inlining @@ -124,6 +125,7 @@ func AnalyzeCallee(logf func(string, ...any), fset *token.FileSet, pkg *types.Pa // Record the location of all free references in the FuncDecl. // (Parameters are not free by this definition.) 
var ( + fieldObjs = fieldObjs(sig) freeObjIndex = make(map[types.Object]int) freeObjs []object freeRefs []freeRef // free refs that may need renaming @@ -202,25 +204,25 @@ func AnalyzeCallee(logf func(string, ...any), fset *token.FileSet, pkg *types.Pa objidx, ok := freeObjIndex[obj] if !ok { objidx = len(freeObjIndex) - var pkgpath, pkgname string + var pkgPath, pkgName string if pn, ok := obj.(*types.PkgName); ok { - pkgpath = pn.Imported().Path() - pkgname = pn.Imported().Name() + pkgPath = pn.Imported().Path() + pkgName = pn.Imported().Name() } else if obj.Pkg() != nil { - pkgpath = obj.Pkg().Path() - pkgname = obj.Pkg().Name() + pkgPath = obj.Pkg().Path() + pkgName = obj.Pkg().Name() } freeObjs = append(freeObjs, object{ Name: obj.Name(), Kind: objectKind(obj), - PkgName: pkgname, - PkgPath: pkgpath, + PkgName: pkgName, + PkgPath: pkgPath, ValidPos: obj.Pos().IsValid(), }) freeObjIndex[obj] = objidx } - freeObjs[objidx].Shadow = addShadows(freeObjs[objidx].Shadow, info, obj.Name(), stack) + freeObjs[objidx].Shadow = freeObjs[objidx].Shadow.add(info, fieldObjs, obj.Name(), stack) freeRefs = append(freeRefs, freeRef{ Offset: int(n.Pos() - decl.Pos()), @@ -382,14 +384,27 @@ func parseCompact(content []byte) (*token.FileSet, *ast.FuncDecl, error) { // A paramInfo records information about a callee receiver, parameter, or result variable. 
type paramInfo struct { - Name string // parameter name (may be blank, or even "") - Index int // index within signature - IsResult bool // false for receiver or parameter, true for result variable - Assigned bool // parameter appears on left side of an assignment statement - Escapes bool // parameter has its address taken - Refs []int // FuncDecl-relative byte offset of parameter ref within body - Shadow map[string]bool // names shadowed at one of the above refs - FalconType string // name of this parameter's type (if basic) in the falcon system + Name string // parameter name (may be blank, or even "") + Index int // index within signature + IsResult bool // false for receiver or parameter, true for result variable + IsInterface bool // parameter has a (non-type parameter) interface type + Assigned bool // parameter appears on left side of an assignment statement + Escapes bool // parameter has its address taken + Refs []refInfo // information about references to parameter within body + Shadow shadowMap // shadowing info for the above refs; see [shadowMap] + FalconType string // name of this parameter's type (if basic) in the falcon system +} + +type refInfo struct { + Offset int // FuncDecl-relative byte offset of parameter ref within body + Assignable bool // ref appears in context of assignment to known type + IfaceAssignment bool // ref is being assigned to an interface + AffectsInference bool // ref type may affect type inference + // IsSelectionOperand indicates whether the parameter reference is the + // operand of a selection (param.f). If so, and param's argument is itself + // a receiver parameter (a common case), we don't need to desugar (&v or *ptr) + // the selection: if param.Method is a valid selection, then so is param.fieldOrMethod. 
+ IsSelectionOperand bool } // analyzeParams computes information about parameters of function fn, @@ -405,15 +420,16 @@ func analyzeParams(logf func(string, ...any), fset *token.FileSet, info *types.I panic(fmt.Sprintf("%s: no func object for %q", fset.PositionFor(decl.Name.Pos(), false), decl.Name)) // ill-typed? } + sig := fnobj.Type().(*types.Signature) paramInfos := make(map[*types.Var]*paramInfo) { - sig := fnobj.Type().(*types.Signature) newParamInfo := func(param *types.Var, isResult bool) *paramInfo { info := ¶mInfo{ - Name: param.Name(), - IsResult: isResult, - Index: len(paramInfos), + Name: param.Name(), + IsResult: isResult, + Index: len(paramInfos), + IsInterface: isNonTypeParamInterface(param.Type()), } paramInfos[param] = info return info @@ -446,6 +462,7 @@ func analyzeParams(logf func(string, ...any), fset *token.FileSet, info *types.I // // TODO(adonovan): combine this traversal with the one that computes // FreeRefs. The tricky part is that calleefx needs this one first. + fieldObjs := fieldObjs(sig) var stack []ast.Node stack = append(stack, decl.Type) // for scope of function itself ast.Inspect(decl.Body, func(n ast.Node) bool { @@ -458,11 +475,27 @@ func analyzeParams(logf func(string, ...any), fset *token.FileSet, info *types.I if id, ok := n.(*ast.Ident); ok { if v, ok := info.Uses[id].(*types.Var); ok { if pinfo, ok := paramInfos[v]; ok { - // Record location of ref to parameter/result - // and any intervening (shadowing) names. - offset := int(n.Pos() - decl.Pos()) - pinfo.Refs = append(pinfo.Refs, offset) - pinfo.Shadow = addShadows(pinfo.Shadow, info, pinfo.Name, stack) + // Record ref information, and any intervening (shadowing) names. + // + // If the parameter v has an interface type, and the reference id + // appears in a context where assignability rules apply, there may be + // an implicit interface-to-interface widening. 
In that case it is + // not necessary to insert an explicit conversion from the argument + // to the parameter's type. + // + // Contrapositively, if param is not an interface type, then the + // assignment may lose type information, for example in the case that + // the substituted expression is an untyped constant or unnamed type. + assignable, ifaceAssign, affectsInference := analyzeAssignment(info, stack) + ref := refInfo{ + Offset: int(n.Pos() - decl.Pos()), + Assignable: assignable, + IfaceAssignment: ifaceAssign, + AffectsInference: affectsInference, + IsSelectionOperand: isSelectionOperand(stack), + } + pinfo.Refs = append(pinfo.Refs, ref) + pinfo.Shadow = pinfo.Shadow.add(info, fieldObjs, pinfo.Name, stack) } } } @@ -481,27 +514,300 @@ func analyzeParams(logf func(string, ...any), fset *token.FileSet, info *types.I // -- callee helpers -- -// addShadows returns the shadows set augmented by the set of names +// analyzeAssignment looks at the given stack, and analyzes certain +// attributes of the innermost expression. +// +// In all cases we 'fail closed' when we cannot detect (or for simplicity +// choose not to detect) the condition in question, meaning we err on the side +// of the more restrictive rule. This is noted for each result below. +// +// - assignable reports whether the expression is used in a position where +// assignability rules apply, such as in an actual assignment, as call +// argument, or in a send to a channel. Defaults to 'false'. If assignable +// is false, the other two results are irrelevant. +// - ifaceAssign reports whether that assignment is to an interface type. +// This is important as we want to preserve the concrete type in that +// assignment. Defaults to 'true'. Notably, if the assigned type is a type +// parameter, we assume that it could have interface type.
+// - affectsInference is (somewhat vaguely) defined as whether or not the +// type of the operand may affect the type of the surrounding syntax, +// through type inference. It is infeasible to completely reverse engineer +// type inference, so we over approximate: if the expression is an argument +// to a call to a generic function (but not method!) that uses type +// parameters, assume that unification of that argument may affect the +// inferred types. +func analyzeAssignment(info *types.Info, stack []ast.Node) (assignable, ifaceAssign, affectsInference bool) { + remaining, parent, expr := exprContext(stack) + if parent == nil { + return false, false, false + } + + // TODO(golang/go#70638): simplify when types.Info records implicit conversions. + + // Types do not need to match for assignment to a variable. + if assign, ok := parent.(*ast.AssignStmt); ok { + for i, v := range assign.Rhs { + if v == expr { + if i >= len(assign.Lhs) { + return false, false, false // ill typed + } + // Check to see if the assignment is to an interface type. + if i < len(assign.Lhs) { + // TODO: We could handle spread calls here, but in current usage expr + // is an ident. + if id, _ := assign.Lhs[i].(*ast.Ident); id != nil && info.Defs[id] != nil { + // Types must match for a defining identifier in a short variable + // declaration. + return false, false, false + } + // In all other cases, types should be known. + typ := info.TypeOf(assign.Lhs[i]) + return true, typ == nil || types.IsInterface(typ), false + } + // Default: + return assign.Tok == token.ASSIGN, true, false + } + } + } + + // Types do not need to match for an initializer with known type. + if spec, ok := parent.(*ast.ValueSpec); ok && spec.Type != nil { + for _, v := range spec.Values { + if v == expr { + typ := info.TypeOf(spec.Type) + return true, typ == nil || types.IsInterface(typ), false + } + } + } + + // Types do not need to match for index expressions.
+ if ix, ok := parent.(*ast.IndexExpr); ok { + if ix.Index == expr { + typ := info.TypeOf(ix.X) + if typ == nil { + return true, true, false + } + m, _ := typeparams.CoreType(typ).(*types.Map) + return true, m == nil || types.IsInterface(m.Key()), false + } + } + + // Types do not need to match for composite literal keys, values, or + // fields. + if kv, ok := parent.(*ast.KeyValueExpr); ok { + var under types.Type + if len(remaining) > 0 { + if complit, ok := remaining[len(remaining)-1].(*ast.CompositeLit); ok { + if typ := info.TypeOf(complit); typ != nil { + // Unpointer to allow for pointers to slices or arrays, which are + // permitted as the types of nested composite literals without a type + // name. + under = typesinternal.Unpointer(typeparams.CoreType(typ)) + } + } + } + if kv.Key == expr { // M{expr: ...}: assign to map key + m, _ := under.(*types.Map) + return true, m == nil || types.IsInterface(m.Key()), false + } + if kv.Value == expr { + switch under := under.(type) { + case interface{ Elem() types.Type }: // T{...: expr}: assign to map/array/slice element + return true, types.IsInterface(under.Elem()), false + case *types.Struct: // Struct{k: expr} + if id, _ := kv.Key.(*ast.Ident); id != nil { + for fi := 0; fi < under.NumFields(); fi++ { + field := under.Field(fi) + if info.Uses[id] == field { + return true, types.IsInterface(field.Type()), false + } + } + } + default: + return true, true, false + } + } + } + if lit, ok := parent.(*ast.CompositeLit); ok { + for i, v := range lit.Elts { + if v == expr { + typ := info.TypeOf(lit) + if typ == nil { + return true, true, false + } + // As in the KeyValueExpr case above, unpointer to handle pointers to + // array/slice literals. 
+ under := typesinternal.Unpointer(typeparams.CoreType(typ)) + switch under := under.(type) { + case interface{ Elem() types.Type }: // T{expr}: assign to map/array/slice element + return true, types.IsInterface(under.Elem()), false + case *types.Struct: // Struct{expr}: assign to unkeyed struct field + if i < under.NumFields() { + return true, types.IsInterface(under.Field(i).Type()), false + } + } + return true, true, false + } + } + } + + // Types do not need to match for values sent to a channel. + if send, ok := parent.(*ast.SendStmt); ok { + if send.Value == expr { + typ := info.TypeOf(send.Chan) + if typ == nil { + return true, true, false + } + ch, _ := typeparams.CoreType(typ).(*types.Chan) + return true, ch == nil || types.IsInterface(ch.Elem()), false + } + } + + // Types do not need to match for an argument to a call, unless the + // corresponding parameter has type parameters, as in that case the + // argument type may affect inference. + if call, ok := parent.(*ast.CallExpr); ok { + if _, ok := isConversion(info, call); ok { + return false, false, false // redundant conversions are handled at the call site + } + // Ordinary call. Could be a call of a func, builtin, or function value. + for i, arg := range call.Args { + if arg == expr { + typ := info.TypeOf(call.Fun) + if typ == nil { + return true, true, false + } + sig, _ := typeparams.CoreType(typ).(*types.Signature) + if sig != nil { + // Find the relevant parameter type, accounting for variadics. 
+ paramType := paramTypeAtIndex(sig, call, i) + ifaceAssign := paramType == nil || types.IsInterface(paramType) + affectsInference := false + if fn := typeutil.StaticCallee(info, call); fn != nil { + if sig2 := fn.Type().(*types.Signature); sig2.Recv() == nil { + originParamType := paramTypeAtIndex(sig2, call, i) + affectsInference = originParamType == nil || new(typeparams.Free).Has(originParamType) + } + } + return true, ifaceAssign, affectsInference + } + } + } + } + + return false, false, false +} + +// paramTypeAtIndex returns the effective parameter type at the given argument +// index in call, if valid. +func paramTypeAtIndex(sig *types.Signature, call *ast.CallExpr, index int) types.Type { + if plen := sig.Params().Len(); sig.Variadic() && index >= plen-1 && !call.Ellipsis.IsValid() { + if s, ok := sig.Params().At(plen - 1).Type().(*types.Slice); ok { + return s.Elem() + } + } else if index < plen { + return sig.Params().At(index).Type() + } + return nil // ill typed +} + +// exprContext returns the innermost parent->child expression nodes for the +// given outer-to-inner stack, after stripping parentheses, along with the +// remaining stack up to the parent node. +// +// If no such context exists, returns (nil, nil). +func exprContext(stack []ast.Node) (remaining []ast.Node, parent ast.Node, expr ast.Expr) { + expr, _ = stack[len(stack)-1].(ast.Expr) + if expr == nil { + return nil, nil, nil + } + i := len(stack) - 2 + for ; i >= 0; i-- { + if pexpr, ok := stack[i].(*ast.ParenExpr); ok { + expr = pexpr + } else { + parent = stack[i] + break + } + } + if parent == nil { + return nil, nil, nil + } + // inv: i is the index of parent in the stack. + return stack[:i], parent, expr +} + +// isSelectionOperand reports whether the innermost node of stack is operand +// (x) of a selection x.f. 
+func isSelectionOperand(stack []ast.Node) bool { + _, parent, expr := exprContext(stack) + if parent == nil { + return false + } + sel, ok := parent.(*ast.SelectorExpr) + return ok && sel.X == expr +} + +// A shadowMap records information about shadowing at any of the parameter's +// references within the callee decl. +// +// For each name shadowed at a reference to the parameter within the callee +// body, shadow map records the 1-based index of the callee decl parameter +// causing the shadowing, or -1, if the shadowing is not due to a callee decl. +// A value of zero (or missing) indicates no shadowing. By convention, +// self-shadowing is excluded from the map. +// +// For example, in the following callee +// +// func f(a, b int) int { +// c := 2 + b +// return a + c +// } +// +// the shadow map of a is {b: 2, c: -1}, because b is shadowed by the 2nd +// parameter. The shadow map of b is {a: 1}, because c is not shadowed at the +// use of b. +type shadowMap map[string]int + +// addShadows returns the [shadowMap] augmented by the set of names // locally shadowed at the location of the reference in the callee // (identified by the stack). The name of the reference itself is // excluded. // // These shadowed names may not be used in a replacement expression // for the reference. 
-func addShadows(shadows map[string]bool, info *types.Info, exclude string, stack []ast.Node) map[string]bool { +func (s shadowMap) add(info *types.Info, paramIndexes map[types.Object]int, exclude string, stack []ast.Node) shadowMap { for _, n := range stack { if scope := scopeFor(info, n); scope != nil { for _, name := range scope.Names() { if name != exclude { - if shadows == nil { - shadows = make(map[string]bool) + if s == nil { + s = make(shadowMap) + } + obj := scope.Lookup(name) + if idx, ok := paramIndexes[obj]; ok { + s[name] = idx + 1 + } else { + s[name] = -1 } - shadows[name] = true } } } } - return shadows + return s +} + +// fieldObjs returns a map of each types.Object defined by the given signature +// to its index in the parameter list. Parameters with missing or blank name +// are skipped. +func fieldObjs(sig *types.Signature) map[types.Object]int { + m := make(map[types.Object]int) + for i := range sig.Params().Len() { + if p := sig.Params().At(i); p.Name() != "" && p.Name() != "_" { + m[p] = i + } + } + return m } func isField(obj types.Object) bool { diff --git a/internal/refactor/inline/falcon.go b/internal/refactor/inline/falcon.go index 9154c5093fb..b62a32e7430 100644 --- a/internal/refactor/inline/falcon.go +++ b/internal/refactor/inline/falcon.go @@ -446,7 +446,7 @@ func (st *falconState) expr(e ast.Expr) (res any) { // = types.TypeAndValue | as // - for an array or *array, use [n]int. // The last two entail progressively stronger index checks. 
var ct ast.Expr // type syntax for constraint - switch t := t.(type) { + switch t := typeparams.CoreType(t).(type) { case *types.Map: if types.IsInterface(t.Key()) { ct = &ast.MapType{ @@ -465,7 +465,7 @@ func (st *falconState) expr(e ast.Expr) (res any) { // = types.TypeAndValue | as Elt: makeIdent(st.int), } default: - panic(t) + panic(fmt.Sprintf("%T: %v", t, t)) } st.emitUnique(ct, uniques) } diff --git a/internal/refactor/inline/inline.go b/internal/refactor/inline/inline.go index 0fda1c579f9..c981599b5b0 100644 --- a/internal/refactor/inline/inline.go +++ b/internal/refactor/inline/inline.go @@ -11,6 +11,7 @@ import ( "go/constant" "go/format" "go/parser" + "go/printer" "go/token" "go/types" pathpkg "path" @@ -41,11 +42,13 @@ type Caller struct { enclosingFunc *ast.FuncDecl // top-level function/method enclosing the call, if any } +type logger = func(string, ...any) + // Options specifies parameters affecting the inliner algorithm. // All fields are optional. type Options struct { - Logf func(string, ...any) // log output function, records decision-making process - IgnoreEffects bool // ignore potential side effects of arguments (unsound) + Logf logger // log output function, records decision-making process + IgnoreEffects bool // ignore potential side effects of arguments (unsound) } // Result holds the result of code transformation. @@ -200,17 +203,37 @@ func (st *state) inline() (*Result, error) { } } + // File rewriting. This proceeds in multiple passes, in order to maximally + // preserve comment positioning. (This could be greatly simplified once + // comments are stored in the tree.) + // // Don't call replaceNode(caller.File, res.old, res.new) // as it mutates the caller's syntax tree. // Instead, splice the file, replacing the extent of the "old" // node by a formatting of the "new" node, and re-parse. // We'll fix up the imports on this new tree, and format again. - var f *ast.File + // + // Inv: f is the result of parsing content, using fset. 
+ var ( + content = caller.Content + fset = caller.Fset + f *ast.File // parsed below + ) + reparse := func() error { + const mode = parser.ParseComments | parser.SkipObjectResolution | parser.AllErrors + f, err = parser.ParseFile(fset, "callee.go", content, mode) + if err != nil { + // Something has gone very wrong. + logf("failed to reparse <<%s>>: %v", string(content), err) // debugging + return err + } + return nil + } { - start := offsetOf(caller.Fset, res.old.Pos()) - end := offsetOf(caller.Fset, res.old.End()) + start := offsetOf(fset, res.old.Pos()) + end := offsetOf(fset, res.old.End()) var out bytes.Buffer - out.Write(caller.Content[:start]) + out.Write(content[:start]) // TODO(adonovan): might it make more sense to use // callee.Fset when formatting res.new? // The new tree is a mix of (cloned) caller nodes for @@ -230,21 +253,18 @@ func (st *state) inline() (*Result, error) { if i > 0 { out.WriteByte('\n') } - if err := format.Node(&out, caller.Fset, stmt); err != nil { + if err := format.Node(&out, fset, stmt); err != nil { return nil, err } } } else { - if err := format.Node(&out, caller.Fset, res.new); err != nil { + if err := format.Node(&out, fset, res.new); err != nil { return nil, err } } - out.Write(caller.Content[end:]) - const mode = parser.ParseComments | parser.SkipObjectResolution | parser.AllErrors - f, err = parser.ParseFile(caller.Fset, "callee.go", &out, mode) - if err != nil { - // Something has gone very wrong. - logf("failed to parse <<%s>>", &out) // debugging + out.Write(content[end:]) + content = out.Bytes() + if err := reparse(); err != nil { return nil, err } } @@ -255,15 +275,58 @@ func (st *state) inline() (*Result, error) { // to avoid migration of pre-import comments. // The imports will be organized below. if len(res.newImports) > 0 { - var importDecl *ast.GenDecl + // If we have imports to add, do so independent of the rest of the file. 
+ // Otherwise, the length of the new imports may consume floating comments, + // causing them to be printed inside the imports block. + var ( + importDecl *ast.GenDecl + comments []*ast.CommentGroup // relevant comments. + before, after []byte // pre- and post-amble for the imports block. + ) if len(f.Imports) > 0 { // Append specs to existing import decl importDecl = f.Decls[0].(*ast.GenDecl) + for _, comment := range f.Comments { + // Filter comments. Don't use CommentMap.Filter here, because we don't + // want to include comments that document the import decl itself, for + // example: + // + // // We don't want this comment to be duplicated. + // import ( + // "something" + // ) + if importDecl.Pos() <= comment.Pos() && comment.Pos() < importDecl.End() { + comments = append(comments, comment) + } + } + before = content[:offsetOf(fset, importDecl.Pos())] + importDecl.Doc = nil // present in before + after = content[offsetOf(fset, importDecl.End()):] } else { // Insert new import decl. importDecl = &ast.GenDecl{Tok: token.IMPORT} f.Decls = prepend[ast.Decl](importDecl, f.Decls...) + + // Make room for the new declaration after the package declaration. + pkgEnd := f.Name.End() + file := fset.File(pkgEnd) + if file == nil { + logf("internal error: missing pkg file") + return nil, fmt.Errorf("missing pkg file for %s", f.Name.Name) + } + // Preserve any comments after the package declaration, by splicing in + // the new import block after the end of the package declaration line. + line := file.Line(pkgEnd) + if line < len(file.Lines()) { // line numbers are 1-based + nextLinePos := file.LineStart(line + 1) + nextLine := offsetOf(fset, nextLinePos) + before = slices.Concat(content[:nextLine], []byte("\n")) + after = slices.Concat([]byte("\n\n"), content[nextLine:]) + } else { + before = slices.Concat(content, []byte("\n\n")) + } } + // Add new imports. for _, imp := range res.newImports { // Check that the new imports are accessible. 
path, _ := strconv.Unquote(imp.spec.Path.Value) @@ -272,6 +335,21 @@ func (st *state) inline() (*Result, error) { } importDecl.Specs = append(importDecl.Specs, imp.spec) } + var out bytes.Buffer + out.Write(before) + commented := &printer.CommentedNode{ + Node: importDecl, + Comments: comments, + } + if err := format.Node(&out, fset, commented); err != nil { + logf("failed to format new importDecl: %v", err) // debugging + return nil, err + } + out.Write(after) + content = out.Bytes() + if err := reparse(); err != nil { + return nil, err + } } // Delete imports referenced only by caller.Call.Fun. @@ -279,7 +357,8 @@ func (st *state) inline() (*Result, error) { // (We can't let imports.Process take care of it as it may // mistake obsolete imports for missing new imports when the // names are similar, as is common during a package migration.) - for _, specToDelete := range res.oldImports { + for _, oldImport := range res.oldImports { + specToDelete := oldImport.spec for _, decl := range f.Decls { if decl, ok := decl.(*ast.GenDecl); ok && decl.Tok == token.IMPORT { decl.Specs = slicesDeleteFunc(decl.Specs, func(spec ast.Spec) bool { @@ -375,17 +454,23 @@ func (st *state) inline() (*Result, error) { Content: newSrc, Literalized: literalized, }, nil +} +// An oldImport is an import that will be deleted from the caller file. +type oldImport struct { + pkgName *types.PkgName + spec *ast.ImportSpec } +// A newImport is an import that will be added to the caller file. type newImport struct { pkgName string spec *ast.ImportSpec } type inlineCallResult struct { - newImports []newImport // to add - oldImports []*ast.ImportSpec // to remove + newImports []newImport // to add + oldImports []oldImport // to remove // If elideBraces is set, old is an ast.Stmt and new is an ast.BlockStmt to // be spliced in. 
This allows the inlining analysis to assert that inlining @@ -412,6 +497,9 @@ type inlineCallResult struct { // transformation replacing the call and adding new variable // declarations, for example, or replacing a call statement by zero or // many statements.) +// NOTE(rfindley): we've sort-of done this, with the 'elideBraces' flag that +// allows inlining a statement list. However, due to loss of comments, more +// sophisticated rewrites are challenging. // // TODO(adonovan): in earlier drafts, the transformation was expressed // by splicing substrings of the two source files because syntax @@ -421,6 +509,33 @@ type inlineCallResult struct { // candidate for evaluating an alternative fully self-contained tree // representation, such as any proposed solution to #20744, or even // dst or some private fork of go/ast.) +// TODO(rfindley): see if we can reduce the amount of comment lossiness by +// using printer.CommentedNode, which has been useful elsewhere. +// +// TODO(rfindley): inlineCall is getting very long, and very stateful, making +// it very hard to read. The following refactoring may improve readability and +// maintainability: +// - Rename 'state' to 'callsite', since that is what it encapsulates. +// - Add results of pre-processing analysis into the callsite struct, such as +// the effective importMap, new/old imports, arguments, etc. Essentially +// anything that resulted from initial analysis of the call site, and which +// may be useful to inlining strategies. +// - Delegate this call site analysis to a constructor or initializer, such +// as 'analyzeCallsite', so that it does not consume bandwidth in the +// 'inlineCall' logical flow. +// - Once analyzeCallsite returns, the callsite is immutable, much in the +// same way as the Callee and Caller are immutable. +// - Decide on a standard interface for strategies (and substrategies), such +// that they may be delegated to a separate method on callsite. 
+// +// In this way, the logical flow of inline call will clearly follow the +// following structure: +// 1. Analyze the call site. +// 2. Try strategies, in order, until one succeeds. +// 3. Process the results. +// +// If any expensive analysis may be avoided by earlier strategies, it can be +// encapsulated in its own type and passed to subsequent strategies. func (st *state) inlineCall() (*inlineCallResult, error) { logf, caller, callee := st.opts.Logf, st.caller, &st.callee.impl @@ -469,39 +584,83 @@ func (st *state) inlineCall() (*inlineCallResult, error) { assign1 = func(v *types.Var) bool { return !updatedLocals[v] } } - // import map, initially populated with caller imports. + // import map, initially populated with caller imports, and updated below + // with new imports necessary to reference free symbols in the callee. // - // For simplicity we ignore existing dot imports, so that a - // qualified identifier (QI) in the callee is always - // represented by a QI in the caller, allowing us to treat a - // QI like a selection on a package name. + // For simplicity we ignore existing dot imports, so that a qualified + // identifier (QI) in the callee is always represented by a QI in the caller, + // allowing us to treat a QI like a selection on a package name. importMap := make(map[string][]string) // maps package path to local name(s) + var oldImports []oldImport // imports referenced only by caller.Call.Fun + for _, imp := range caller.File.Imports { - if pkgname, ok := importedPkgName(caller.Info, imp); ok && - pkgname.Name() != "." && - pkgname.Name() != "_" { - path := pkgname.Imported().Path() - importMap[path] = append(importMap[path], pkgname.Name()) + if pkgName, ok := importedPkgName(caller.Info, imp); ok && + pkgName.Name() != "." && + pkgName.Name() != "_" { + + // If the import's sole use is in caller.Call.Fun of the form p.F(...), + // where p.F is a qualified identifier, the p import may not be + // necessary. 
+ // + // Only the qualified identifier case matters, as other references to + // imported package names in the Call.Fun expression (e.g. + // x.after(3*time.Second).f() or time.Second.String()) will remain after + // inlining, as arguments. + // + // If that is the case, proactively check if any of the callee FreeObjs + // need this import. Doing so eagerly simplifies the resulting logic. + needed := true + sel, ok := ast.Unparen(caller.Call.Fun).(*ast.SelectorExpr) + if ok && soleUse(caller.Info, pkgName) == sel.X { + needed = false // no longer needed by caller + // Check to see if any of the inlined free objects need this package. + for _, obj := range callee.FreeObjs { + if obj.PkgPath == pkgName.Imported().Path() && obj.Shadow[pkgName.Name()] == 0 { + needed = true // needed by callee + break + } + } + } + + if needed { + path := pkgName.Imported().Path() + importMap[path] = append(importMap[path], pkgName.Name()) + } else { + oldImports = append(oldImports, oldImport{pkgName: pkgName, spec: imp}) + } } } - var oldImports []*ast.ImportSpec // imports referenced only caller.Call.Fun - - // localImportName returns the local name for a given imported package path. - var newImports []newImport - localImportName := func(obj *object) string { - // Does an import exist? - for _, name := range importMap[obj.PkgPath] { - // Check that either the import preexisted, - // or that it was newly added (no PkgName) but is not shadowed, - // either in the callee (shadows) or caller (caller.lookup). - if !obj.Shadow[name] { + // importName finds an existing import name to use in a particular shadowing + // context. It is used to determine the set of new imports in + // getOrMakeImportName, and is also used for writing out names in inlining + // strategies below. 
+ importName := func(pkgPath string, shadow shadowMap) string { + for _, name := range importMap[pkgPath] { + // Check that either the import preexisted, or that it was newly added + // (no PkgName) but is not shadowed, either in the callee (shadows) or + // caller (caller.lookup). + if shadow[name] == 0 { found := caller.lookup(name) if is[*types.PkgName](found) || found == nil { return name } } } + return "" + } + + // keep track of new imports that are necessary to reference any free names + // in the callee. + var newImports []newImport + + // getOrMakeImportName returns the local name for a given imported package path, + // adding one if it doesn't exists. + getOrMakeImportName := func(pkgPath, pkgName string, shadow shadowMap) string { + // Does an import already exist that works in this shadowing context? + if name := importName(pkgPath, shadow); name != "" { + return name + } newlyAdded := func(name string) bool { for _, new := range newImports { @@ -515,33 +674,17 @@ func (st *state) inlineCall() (*inlineCallResult, error) { // shadowedInCaller reports whether a candidate package name // already refers to a declaration in the caller. shadowedInCaller := func(name string) bool { - existing := caller.lookup(name) - - // If the candidate refers to a PkgName p whose sole use is - // in caller.Call.Fun of the form p.F(...), where p.F is a - // qualified identifier, the p import will be deleted, - // so it's safe (and better) to recycle the name. - // - // Only the qualified identifier case matters, as other - // references to imported package names in the Call.Fun - // expression (e.g. x.after(3*time.Second).f() - // or time.Second.String()) will remain after - // inlining, as arguments. 
- if pkgName, ok := existing.(*types.PkgName); ok { - if sel, ok := ast.Unparen(caller.Call.Fun).(*ast.SelectorExpr); ok { - if sole := soleUse(caller.Info, pkgName); sole == sel.X { - for _, spec := range caller.File.Imports { - pkgName2, ok := importedPkgName(caller.Info, spec) - if ok && pkgName2 == pkgName { - oldImports = append(oldImports, spec) - return false - } - } - } + obj := caller.lookup(name) + if obj == nil { + return false + } + // If obj will be removed, the name is available. + for _, old := range oldImports { + if old.pkgName == obj { + return false } } - - return existing != nil + return true } // import added by callee @@ -555,29 +698,28 @@ func (st *state) inlineCall() (*inlineCallResult, error) { // TODO(rfindley): is it worth preserving local package names for callee // imports? Are they likely to be better or worse than the name we choose // here? - base := obj.PkgName + base := pkgName name := base - for n := 0; obj.Shadow[name] || shadowedInCaller(name) || newlyAdded(name) || name == "init"; n++ { + for n := 0; shadow[name] != 0 || shadowedInCaller(name) || newlyAdded(name) || name == "init"; n++ { name = fmt.Sprintf("%s%d", base, n) } - - logf("adding import %s %q", name, obj.PkgPath) + logf("adding import %s %q", name, pkgPath) spec := &ast.ImportSpec{ Path: &ast.BasicLit{ Kind: token.STRING, - Value: strconv.Quote(obj.PkgPath), + Value: strconv.Quote(pkgPath), }, } // Use explicit pkgname (out of necessity) when it differs from the declared name, // or (for good style) when it differs from base(pkgpath). 
- if name != obj.PkgName || name != pathpkg.Base(obj.PkgPath) { + if name != pkgName || name != pathpkg.Base(pkgPath) { spec.Name = makeIdent(name) } newImports = append(newImports, newImport{ pkgName: name, spec: spec, }) - importMap[obj.PkgPath] = append(importMap[obj.PkgPath], name) + importMap[pkgPath] = append(importMap[pkgPath], name) return name } @@ -607,7 +749,8 @@ func (st *state) inlineCall() (*inlineCallResult, error) { var newName ast.Expr if obj.Kind == "pkgname" { // Use locally appropriate import, creating as needed. - newName = makeIdent(localImportName(&obj)) // imported package + n := getOrMakeImportName(obj.PkgPath, obj.PkgName, obj.Shadow) + newName = makeIdent(n) // imported package } else if !obj.ValidPos { // Built-in function, type, or value (e.g. nil, zero): // check not shadowed at caller. @@ -651,7 +794,7 @@ func (st *state) inlineCall() (*inlineCallResult, error) { // Form a qualified identifier, pkg.Name. if qualify { - pkgName := localImportName(&obj) + pkgName := getOrMakeImportName(obj.PkgPath, obj.PkgName, obj.Shadow) newName = &ast.SelectorExpr{ X: makeIdent(pkgName), Sel: makeIdent(obj.Name), @@ -672,11 +815,22 @@ func (st *state) inlineCall() (*inlineCallResult, error) { return nil, err // "can't happen" } - // replaceCalleeID replaces an identifier in the callee. - // The replacement tree must not belong to the caller; use cloneNode as needed. - replaceCalleeID := func(offset int, repl ast.Expr) { - id := findIdent(calleeDecl, calleeDecl.Pos()+token.Pos(offset)) + // replaceCalleeID replaces an identifier in the callee. See [replacer] for + // more detailed semantics. + replaceCalleeID := func(offset int, repl ast.Expr, unpackVariadic bool) { + path, id := findIdent(calleeDecl, calleeDecl.Pos()+token.Pos(offset)) logf("- replace id %q @ #%d to %q", id.Name, offset, debugFormatNode(calleeFset, repl)) + // Replace f([]T{a, b, c}...) with f(a, b, c). 
+ if lit, ok := repl.(*ast.CompositeLit); ok && unpackVariadic && len(path) > 0 { + if call, ok := last(path).(*ast.CallExpr); ok && + call.Ellipsis.IsValid() && + id == last(call.Args) { + + call.Args = append(call.Args[:len(call.Args)-1], lit.Elts...) + call.Ellipsis = token.NoPos + return + } + } replaceNode(calleeDecl, id, repl) } @@ -684,7 +838,7 @@ func (st *state) inlineCall() (*inlineCallResult, error) { // (The same tree may be spliced in multiple times, resulting in a DAG.) for _, ref := range callee.FreeRefs { if repl := objRenames[ref.Object]; repl != nil { - replaceCalleeID(ref.Offset, repl) + replaceCalleeID(ref.Offset, repl, false) } } @@ -760,14 +914,22 @@ func (st *state) inlineCall() (*inlineCallResult, error) { // nop } else { // ordinary call: f(a1, ... aN) -> f([]T{a1, ..., aN}) + // + // Substitution of []T{...} in the callee body may lead to + // g([]T{a1, ..., aN}...), which we simplify to g(a1, ..., an) + // later; see replaceCalleeID. n := len(params) - 1 ordinary, extra := args[:n], args[n:] var elts []ast.Expr + freevars := make(map[string]bool) pure, effects := true, false for _, arg := range extra { elts = append(elts, arg.expr) pure = pure && arg.pure effects = effects || arg.effects + for k, v := range arg.freevars { + freevars[k] = v + } } args = append(ordinary, &argument{ expr: &ast.CompositeLit{ @@ -779,7 +941,8 @@ func (st *state) inlineCall() (*inlineCallResult, error) { pure: pure, effects: effects, duplicable: false, - freevars: nil, // not needed + freevars: freevars, + variadic: true, }) } } @@ -988,8 +1151,9 @@ func (st *state) inlineCall() (*inlineCallResult, error) { (!needBindingDecl || (bindingDecl != nil && len(bindingDecl.names) == 0)) { // Reduces to: { var (bindings); lhs... := rhs... 
} - if newStmts, ok := st.assignStmts(stmt, results); ok { + if newStmts, ok := st.assignStmts(stmt, results, importName); ok { logf("strategy: reduce assign-context call to { return exprs }") + clearPositions(calleeDecl.Body) block := &ast.BlockStmt{ @@ -1190,6 +1354,10 @@ func (st *state) inlineCall() (*inlineCallResult, error) { Type: calleeDecl.Type, Body: calleeDecl.Body, } + // clear positions before prepending the binding decl below, since the + // binding decl contains syntax from the caller and we must not mutate the + // caller. (This was a prior bug.) + clearPositions(funcLit) // Literalization can still make use of a binding // decl as it gives a more natural reading order: @@ -1211,7 +1379,6 @@ func (st *state) inlineCall() (*inlineCallResult, error) { Ellipsis: token.NoPos, // f(slice...) is always simplified Args: remainingArgs, } - clearPositions(newCall.Fun) res.old = caller.Call res.new = newCall return res, nil @@ -1226,7 +1393,8 @@ type argument struct { effects bool // expr has effects (updates variables) duplicable bool // expr may be duplicated freevars map[string]bool // free names of expr - substitutable bool // is candidate for substitution + variadic bool // is explicit []T{...} for eliminated variadic + desugaredRecv bool // is *recv or &recv, where operator was elided } // arguments returns the effective arguments of the call. @@ -1320,12 +1488,14 @@ func (st *state) arguments(caller *Caller, calleeDecl *ast.FuncDecl, assign1 fun // &recv arg.expr = &ast.UnaryExpr{Op: token.AND, X: arg.expr} arg.typ = types.NewPointer(arg.typ) + arg.desugaredRecv = true } else if argIsPtr && !paramIsPtr { // *recv arg.expr = &ast.StarExpr{X: arg.expr} arg.typ = typeparams.Deref(arg.typ) arg.duplicable = false arg.pure = false + arg.desugaredRecv = true } } } @@ -1382,6 +1552,12 @@ type parameter struct { variadic bool // (final) parameter is unsimplified ...T } +// A replacer replaces an identifier at the given offset in the callee. 
+// The replacement tree must not belong to the caller; use cloneNode as needed. +// If unpackVariadic is set, the replacement is a composite resulting from +// variadic elimination, and may be unpackeded into variadic calls. +type replacer = func(offset int, repl ast.Expr, unpackVariadic bool) + // substitute implements parameter elimination by substitution. // // It considers each parameter and its corresponding argument in turn @@ -1401,7 +1577,7 @@ type parameter struct { // parameter, and is provided with its relative offset and replacement // expression (argument), and the corresponding elements of params and // args are replaced by nil. -func substitute(logf func(string, ...any), caller *Caller, params []*parameter, args []*argument, effects []int, falcon falconResult, replaceCalleeID func(offset int, repl ast.Expr)) { +func substitute(logf logger, caller *Caller, params []*parameter, args []*argument, effects []int, falcon falconResult, replace replacer) { // Inv: // in calls to variadic, len(args) >= len(params)-1 // in spread calls to non-variadic, len(args) < len(params) @@ -1409,9 +1585,24 @@ func substitute(logf func(string, ...any), caller *Caller, params []*parameter, // (In spread calls len(args) = 1, or 2 if call has receiver.) // Non-spread variadics have been simplified away already, // so the args[i] lookup is safe if we stop after the spread arg. + assert(len(args) <= len(params), "too many arguments") + + // Collect candidates for substitution. + // + // An argument is a candidate if it is not otherwise rejected, and any free + // variables that are shadowed only by other parameters. + // + // Therefore, substitution candidates are represented by a graph, where edges + // lead from each argument to the other arguments that, if substituted, would + // allow the argument to be substituted. We collect these edges in the + // [substGraph]. Any node that is known not to be elided from the graph. 
+ // Arguments in this graph with no edges are substitutable independent of + // other nodes, though they may be removed due to falcon or effects analysis. + sg := make(substGraph) next: for i, param := range params { arg := args[i] + // Check argument against parameter. // // Beware: don't use types.Info on arg since @@ -1453,78 +1644,143 @@ next: // references among other arguments which have non-zero references // within the callee. if v, ok := caller.lookup(free).(*types.Var); ok && within(v.Pos(), caller.enclosingFunc.Body) && !isUsedOutsideCall(caller, v) { - logf("keeping param %q: arg contains perhaps the last reference to caller local %v @ %v", - param.info.Name, v, caller.Fset.PositionFor(v.Pos(), false)) - continue next + + // Check to see if the substituted var is used within other args + // whose corresponding params ARE used in the callee + usedElsewhere := func() bool { + for i, param := range params { + if i < len(args) && len(param.info.Refs) > 0 { // excludes original param + for name := range args[i].freevars { + if caller.lookup(name) == v { + return true + } + } + } + } + return false + } + if !usedElsewhere() { + logf("keeping param %q: arg contains perhaps the last reference to caller local %v @ %v", + param.info.Name, v, caller.Fset.PositionFor(v.Pos(), false)) + continue next + } } } } } - // Check for shadowing. + // Arg is a potential substition candidate: analyze its shadowing. // // Consider inlining a call f(z, 1) to - // func f(x, y int) int { z := y; return x + y + z }: + // + // func f(x, y int) int { z := y; return x + y + z } + // // we can't replace x in the body by z (or any - // expression that has z as a free identifier) - // because there's an intervening declaration of z - // that would shadow the caller's one. + // expression that has z as a free identifier) because there's an + // intervening declaration of z that would shadow the caller's one. 
+ // + // However, we *could* replace x in the body by y, as long as the y + // parameter is also removed by substitution. + + sg[arg] = nil // Absent shadowing, the arg is substitutable. + for free := range arg.freevars { - if param.info.Shadow[free] { - logf("keeping param %q: cannot replace with argument as it has free ref to %s that is shadowed", param.info.Name, free) - continue next // shadowing conflict + switch s := param.info.Shadow[free]; { + case s < 0: + // Shadowed by a non-parameter symbol, so arg is not substitutable. + delete(sg, arg) + case s > 0: + // Shadowed by a parameter; arg may be substitutable, if only shadowed + // by other substitutable parameters. + if s > len(args) { + // Defensive: this should not happen in the current factoring, since + // spread arguments are already handled. + delete(sg, arg) + } + if edges, ok := sg[arg]; ok { + sg[arg] = append(edges, args[s-1]) + } } } - - arg.substitutable = true // may be substituted, if effects permit } - // Reject constant arguments as substitution candidates - // if they cause violation of falcon constraints. - checkFalconConstraints(logf, params, args, falcon) + // Process the initial state of the substitution graph. + sg.prune() + + // Now we check various conditions on the substituted argument set as a + // whole. These conditions reject substitution candidates, but since their + // analysis depends on the full set of candidates, we do not process side + // effects of their candidate rejection until after the analysis completes, + // in a call to prune. After pruning, we must re-run the analysis to check + // for additional rejections. 
+ // + // Here's an example of that in practice: + // + // var a [3]int + // + // func falcon(x, y, z int) { + // _ = x + a[y+z] + // } + // + // func _() { + // var y int + // const x, z = 1, 2 + // falcon(y, x, z) + // } + // + // In this example, arguments 0 and 1 are shadowed by each other's + // corresponding parameter, and so each can be substituted only if they are + // both substituted. But the fallible constant analysis finds a violated + // constraint: x + z = 3, and so the constant array index would cause a + // compile-time error if argument 1 (x) were substituted. Therefore, + // following the falcon analysis, we must also prune argument 0. + // + // As far as I (rfindley) can tell, the falcon analysis should always succeed + // after the first pass, as it's not possible for additional bindings to + // cause new constraint failures. Nevertheless, we re-run it to be sure. + // + // However, the same cannot be said of the effects analysis, as demonstrated + // by this example: + // + // func effects(w, x, y, z int) { + // _ = x + w + y + z + // } + + // func _() { + // v := 0 + // w := func() int { v++; return 0 } + // x := func() int { v++; return 0 } + // y := func() int { v++; return 0 } + // effects(x(), w(), y(), x()) //@ inline(re"effects", effects) + // } + // + // In this example, arguments 0, 1, and 3 are related by the substitution + // graph. The first effects analysis implies that arguments 0 and 1 must be + // bound, and therefore argument 3 must be bound. But then a subsequent + // effects analysis forces argument 2 to also be bound. + + // Reject constant arguments as substitution candidates if they cause + // violation of falcon constraints. + // + // Keep redoing the analysis until we no longer reject additional arguments, + // as the set of substituted parameters affects the falcon package. 
+ for checkFalconConstraints(logf, params, args, falcon, sg) { + sg.prune() + } // As a final step, introduce bindings to resolve any // evaluation order hazards. This must be done last, as // additional subsequent bindings could introduce new hazards. - resolveEffects(logf, args, effects) + // + // As with the falcon analysis, keep redoing the analysis until the no more + // arguments are rejected. + for resolveEffects(logf, args, effects, sg) { + sg.prune() + } // The remaining candidates are safe to substitute. for i, param := range params { - if arg := args[i]; arg.substitutable { - - // Wrap the argument in an explicit conversion if - // substitution might materially change its type. - // (We already did the necessary shadowing check - // on the parameter type syntax.) - // - // This is only needed for substituted arguments. All - // other arguments are given explicit types in either - // a binding decl or when using the literalization - // strategy. - - // If the types are identical, we can eliminate - // redundant type conversions such as this: - // - // Callee: - // func f(i int32) { print(i) } - // Caller: - // func g() { f(int32(1)) } - // Inlined as: - // func g() { print(int32(int32(1))) - // - // Recall that non-trivial does not imply non-identical - // for constant conversions; however, at this point state.arguments - // has already re-typechecked the constant and set arg.type to - // its (possibly "untyped") inherent type, so - // the conversion from untyped 1 to int32 is non-trivial even - // though both arg and param have identical types (int32). - if len(param.info.Refs) > 0 && - !types.Identical(arg.typ, param.obj.Type()) && - !trivialConversion(arg.constant, arg.typ, param.obj.Type()) { - arg.expr = convert(param.fieldType, arg.expr) - logf("param %q: adding explicit %s -> %s conversion around argument", - param.info.Name, arg.typ, param.obj.Type()) - } + if arg := args[i]; sg.has(arg) { // It is safe to substitute param and replace it with arg. 
// The formatter introduces parens as needed for precedence. @@ -1534,7 +1790,76 @@ next: logf("replacing parameter %q by argument %q", param.info.Name, debugFormatNode(caller.Fset, arg.expr)) for _, ref := range param.info.Refs { - replaceCalleeID(ref, internalastutil.CloneNode(arg.expr).(ast.Expr)) + // Apply any transformations necessary for this reference. + argExpr := arg.expr + + // If the reference itself is being selected, and we applied desugaring + // (an explicit &x or *x), we can undo that desugaring here as it is + // not necessary for a selector. We don't need to check addressability + // here because if we desugared, the receiver must have been + // addressable. + if ref.IsSelectionOperand && arg.desugaredRecv { + switch e := argExpr.(type) { + case *ast.UnaryExpr: + argExpr = e.X + case *ast.StarExpr: + argExpr = e.X + } + } + + // If the reference requires exact type agreement between parameter and + // argument, wrap the argument in an explicit conversion if + // substitution might materially change its type. (We already did the + // necessary shadowing check on the parameter type syntax.) + // + // The types must agree in any of these cases: + // - the argument affects type inference; + // - the reference's concrete type is assigned to an interface type; + // - the reference is not an assignment, nor a trivial conversion of an untyped constant. + // + // In all other cases, no explicit conversion is necessary as either + // the type does not matter, or must have already agreed for well-typed + // code. + // + // This is only needed for substituted arguments. All other arguments + // are given explicit types in either a binding decl or when using the + // literalization strategy. 
+ // + // If the types are identical, we can eliminate + // redundant type conversions such as this: + // + // Callee: + // func f(i int32) { fmt.Println(i) } + // Caller: + // func g() { f(int32(1)) } + // Inlined as: + // func g() { fmt.Println(int32(int32(1))) + // + // Recall that non-trivial does not imply non-identical for constant + // conversions; however, at this point state.arguments has already + // re-typechecked the constant and set arg.type to its (possibly + // "untyped") inherent type, so the conversion from untyped 1 to int32 + // is non-trivial even though both arg and param have identical types + // (int32). + needType := ref.AffectsInference || + (ref.Assignable && ref.IfaceAssignment && !param.info.IsInterface) || + (!ref.Assignable && !trivialConversion(arg.constant, arg.typ, param.obj.Type())) + + if needType && + !types.Identical(types.Default(arg.typ), param.obj.Type()) { + + // If arg.expr is already an interface call, strip it. + if call, ok := argExpr.(*ast.CallExpr); ok && len(call.Args) == 1 { + if typ, ok := isConversion(caller.Info, call); ok && isNonTypeParamInterface(typ) { + argExpr = call.Args[0] + } + } + + argExpr = convert(param.fieldType, argExpr) + logf("param %q (offset %d): adding explicit %s -> %s conversion around argument", + param.info.Name, ref.Offset, arg.typ, param.obj.Type()) + } + replace(ref.Offset, internalastutil.CloneNode(argExpr).(ast.Expr), arg.variadic) } params[i] = nil // substituted args[i] = nil // substituted @@ -1542,6 +1867,23 @@ next: } } +// isConversion reports whether the given call is a type conversion, returning +// (operand, true) if so. +// +// If the call is not a conversion, it returns (nil, false). +func isConversion(info *types.Info, call *ast.CallExpr) (types.Type, bool) { + if tv, ok := info.Types[call.Fun]; ok && tv.IsType() { + return tv.Type, true + } + return nil, false +} + +// isNonTypeParamInterface reports whether t is a non-type parameter interface +// type. 
+func isNonTypeParamInterface(t types.Type) bool { + return !typeparams.IsTypeParam(t) && types.IsInterface(t) +} + // isUsedOutsideCall reports whether v is used outside of caller.Call, within // the body of caller.enclosingFunc. func isUsedOutsideCall(caller *Caller, v *types.Var) bool { @@ -1579,7 +1921,7 @@ func isUsedOutsideCall(caller *Caller, v *types.Var) bool { // TODO(adonovan): we could obtain a finer result rejecting only the // freevars of each failed constraint, and processing constraints in // order of increasing arity, but failures are quite rare. -func checkFalconConstraints(logf func(string, ...any), params []*parameter, args []*argument, falcon falconResult) { +func checkFalconConstraints(logf logger, params []*parameter, args []*argument, falcon falconResult, sg substGraph) bool { // Create a dummy package, as this is the only // way to create an environment for CheckExpr. pkg := types.NewPackage("falcon", "falcon") @@ -1598,7 +1940,7 @@ func checkFalconConstraints(logf func(string, ...any), params []*parameter, args continue // unreferenced } arg := args[i] - if arg.constant != nil && arg.substitutable && param.info.FalconType != "" { + if arg.constant != nil && sg.has(arg) && param.info.FalconType != "" { t := pkg.Scope().Lookup(param.info.FalconType).Type() pkg.Scope().Insert(types.NewConst(token.NoPos, pkg, name, t, arg.constant)) logf("falcon env: const %s %s = %v", name, param.info.FalconType, arg.constant) @@ -1609,11 +1951,12 @@ func checkFalconConstraints(logf func(string, ...any), params []*parameter, args } } if nconst == 0 { - return // nothing to do + return false // nothing to do } // Parse and evaluate the constraints in the environment. 
fset := token.NewFileSet() + removed := false for _, falcon := range falcon.Constraints { expr, err := parser.ParseExprFrom(fset, "falcon", falcon, 0) if err != nil { @@ -1622,15 +1965,16 @@ func checkFalconConstraints(logf func(string, ...any), params []*parameter, args if err := types.CheckExpr(fset, pkg, token.NoPos, expr, nil); err != nil { logf("falcon: constraint %s violated: %v", falcon, err) for j, arg := range args { - if arg.constant != nil && arg.substitutable { + if arg.constant != nil && sg.has(arg) { logf("keeping param %q due falcon violation", params[j].info.Name) - arg.substitutable = false + removed = sg.remove(arg) || removed } } break } logf("falcon: constraint %s satisfied", falcon) } + return removed } // resolveEffects marks arguments as non-substitutable to resolve @@ -1677,7 +2021,7 @@ func checkFalconConstraints(logf func(string, ...any), params []*parameter, args // current argument. Subsequent iterations cannot introduce hazards // with that argument because they can result only in additional // binding of lower-ordered arguments. -func resolveEffects(logf func(string, ...any), args []*argument, effects []int) { +func resolveEffects(logf logger, args []*argument, effects []int, sg substGraph) bool { effectStr := func(effects bool, idx int) string { i := fmt.Sprint(idx) if idx == len(args) { @@ -1685,9 +2029,10 @@ func resolveEffects(logf func(string, ...any), args []*argument, effects []int) } return string("RW"[btoi(effects)]) + i } + removed := false for i := len(args) - 1; i >= 0; i-- { argi := args[i] - if argi.substitutable && !argi.pure { + if sg.has(argi) && !argi.pure { // i is not bound: check whether it must be bound due to hazards. 
 			idx := index(effects, i)
 			if idx >= 0 {
@@ -1706,25 +2051,111 @@ func resolveEffects(logf func(string, ...any), args []*argument, effects []int)
 					if ji > i && (jw || argi.effects) { // out of order evaluation
 						logf("binding argument %s: preceded by %s",
 							effectStr(argi.effects, i), effectStr(jw, ji))
-						argi.substitutable = false
+
+						removed = sg.remove(argi) || removed
 						break
 					}
 				}
 			}
 		}
-		if !argi.substitutable {
+		if !sg.has(argi) {
 			for j := 0; j < i; j++ {
 				argj := args[j]
 				if argj.pure {
 					continue
 				}
-				if (argi.effects || argj.effects) && argj.substitutable {
+				if (argi.effects || argj.effects) && sg.has(argj) {
 					logf("binding argument %s: %s is bound",
 						effectStr(argj.effects, j), effectStr(argi.effects, i))
-					argj.substitutable = false
+
+					removed = sg.remove(argj) || removed
+				}
+			}
+		}
+	}
+	return removed
+}
+
+// A substGraph is a directed graph representing arguments that may be
+// substituted, provided all of their related arguments (or "dependencies") are
+// also substituted. The candidate arguments for substitution are the keys in
+// this graph, and the edges represent shadowing of free variables of the key
+// by parameters corresponding to the dependency arguments.
+//
+// Any argument not present as a map key is known not to be substitutable. Some
+// arguments may have edges leading to other arguments that are not present in
+// the graph. In this case, those arguments also cannot be substituted, because
+// they have free variables that are shadowed by parameters that cannot be
+// substituted. Calling [substGraph.prune] removes these arguments from the
+// graph.
+//
+// The 'prune' operation is not built into the 'remove' step both because
+// analyses (falcon, effects) need local information about each argument
+// independent of dependencies, and for the efficiency of pruning once en masse
+// after each analysis.
+type substGraph map[*argument][]*argument
+
+// has reports whether arg is a candidate for substitution.
+func (g substGraph) has(arg *argument) bool { + _, ok := g[arg] + return ok +} + +// remove marks arg as not substitutable, reporting whether the arg was +// previously substitutable. +// +// remove does not have side effects on other arguments that may be +// unsubstitutable as a result of their dependency being removed. +// Call [substGraph.prune] to propagate these side effects, removing dependent +// arguments. +func (g substGraph) remove(arg *argument) bool { + pre := len(g) + delete(g, arg) + return len(g) < pre +} + +// prune updates the graph to remove any keys that reach other arguments not +// present in the graph. +func (g substGraph) prune() { + // visit visits the forward transitive closure of arg and reports whether any + // missing argument was encountered, removing all nodes on the path to it + // from arg. + // + // The seen map is used for cycle breaking. In the presence of cycles, visit + // may report a false positive for an intermediate argument. For example, + // consider the following graph, where only a and b are candidates for + // substitution (meaning, only a and b are present in the graph). + // + // a ↔ b + // ↓ + // [c] + // + // In this case, starting a visit from a, visit(b, seen) may report 'true', + // because c has not yet been considered. For this reason, we must guarantee + // that visit is called with an empty seen map at least once for each node. + var visit func(*argument, map[*argument]unit) bool + visit = func(arg *argument, seen map[*argument]unit) bool { + deps, ok := g[arg] + if !ok { + return false + } + if _, ok := seen[arg]; !ok { + seen[arg] = unit{} + for _, dep := range deps { + if !visit(dep, seen) { + delete(g, arg) + return false } } } + return true + } + for arg := range g { + // Remove any argument that is, or transitively depends upon, + // an unsubstitutable argument. + // + // Each visitation gets a fresh cycle-breaking set. 
+ visit(arg, make(map[*argument]unit)) } } @@ -1836,7 +2267,7 @@ type bindingDeclInfo struct { // // Strategies may impose additional checks on return // conversions, labels, defer, etc. -func createBindingDecl(logf func(string, ...any), caller *Caller, args []*argument, calleeDecl *ast.FuncDecl, results []*paramInfo) *bindingDeclInfo { +func createBindingDecl(logf logger, caller *Caller, args []*argument, calleeDecl *ast.FuncDecl, results []*paramInfo) *bindingDeclInfo { // Spread calls are tricky as they may not align with the // parameters' field groupings nor types. // For example, given @@ -1903,8 +2334,8 @@ func createBindingDecl(logf func(string, ...any), caller *Caller, args []*argume for _, field := range calleeDecl.Type.Params.List { // Each field (param group) becomes a ValueSpec. spec := &ast.ValueSpec{ - Names: field.Names, - Type: field.Type, + Names: cleanNodes(field.Names), + Type: cleanNode(field.Type), Values: values[:len(field.Names)], } values = values[len(field.Names):] @@ -1935,8 +2366,8 @@ func createBindingDecl(logf func(string, ...any), caller *Caller, args []*argume } if len(names) > 0 { spec := &ast.ValueSpec{ - Names: names, - Type: field.Type, + Names: cleanNodes(names), + Type: cleanNode(field.Type), } if shadow(spec) { return nil @@ -2619,6 +3050,24 @@ func replaceNode(root ast.Node, from, to ast.Node) { } } +// cleanNode returns a clone of node with positions cleared. +// +// It should be used for any callee nodes that are formatted using the caller +// file set. +func cleanNode[T ast.Node](node T) T { + clone := internalastutil.CloneNode(node) + clearPositions(clone) + return clone +} + +func cleanNodes[T ast.Node](nodes []T) []T { + var clean []T + for _, node := range nodes { + clean = append(clean, cleanNode(node)) + } + return clean +} + // clearPositions destroys token.Pos information within the tree rooted at root, // as positions in callee trees may cause caller comments to be emitted prematurely. 
// @@ -2658,26 +3107,38 @@ func clearPositions(root ast.Node) { }) } -// findIdent returns the Ident beneath root that has the given pos. -func findIdent(root ast.Node, pos token.Pos) *ast.Ident { +// findIdent finds the Ident beneath root that has the given pos. +// It returns the path to the ident (excluding the ident), and the ident +// itself, where the path is the sequence of ast.Nodes encountered in a +// depth-first search to find ident. +func findIdent(root ast.Node, pos token.Pos) ([]ast.Node, *ast.Ident) { // TODO(adonovan): opt: skip subtrees that don't contain pos. - var found *ast.Ident + var ( + path []ast.Node + found *ast.Ident + ) ast.Inspect(root, func(n ast.Node) bool { if found != nil { return false } + if n == nil { + path = path[:len(path)-1] + return false + } if id, ok := n.(*ast.Ident); ok { if id.Pos() == pos { found = id + return true } } + path = append(path, n) return true }) if found == nil { panic(fmt.Sprintf("findIdent %d not found in %s", pos, debugFormatNode(token.NewFileSet(), root))) } - return found + return path, found } func prepend[T any](elem T, slice ...T) []T { @@ -2869,6 +3330,15 @@ func declares(stmts []ast.Stmt) map[string]bool { return names } +// A importNameFunc is used to query local import names in the caller, in a +// particular shadowing context. +// +// The shadow map contains additional names shadowed in the inlined code, at +// the position the local import name is to be used. The shadow map only needs +// to contain newly introduced names in the inlined code; names shadowed at the +// caller are handled automatically. +type importNameFunc = func(pkgPath string, shadow shadowMap) string + // assignStmts rewrites a statement assigning the results of a call into zero // or more statements that assign its return operands, or (nil, false) if no // such rewrite is possible. 
The set of bindings created by the result of @@ -2911,7 +3381,7 @@ func declares(stmts []ast.Stmt) map[string]bool { // // Note: assignStmts may return (nil, true) if it determines that the rewritten // assignment consists only of _ = nil assignments. -func (st *state) assignStmts(callerStmt *ast.AssignStmt, returnOperands []ast.Expr) ([]ast.Stmt, bool) { +func (st *state) assignStmts(callerStmt *ast.AssignStmt, returnOperands []ast.Expr, importName importNameFunc) ([]ast.Stmt, bool) { logf, caller, callee := st.opts.Logf, st.caller, &st.callee.impl assert(len(callee.Returns) == 1, "unexpected multiple returns") @@ -2999,10 +3469,9 @@ func (st *state) assignStmts(callerStmt *ast.AssignStmt, returnOperands []ast.Ex // // This works as long as we don't need to write any additional type // information. - if callerStmt.Tok == token.ASSIGN && // LHS types already determined before call - len(nonTrivial) == 0 { // no non-trivial conversions to worry about + if len(nonTrivial) == 0 { // no non-trivial conversions to worry about - logf("substrategy: slice assignment") + logf("substrategy: splice assignment") return []ast.Stmt{&ast.AssignStmt{ Lhs: lhs, Tok: callerStmt.Tok, @@ -3014,18 +3483,23 @@ func (st *state) assignStmts(callerStmt *ast.AssignStmt, returnOperands []ast.Ex // Inlining techniques below will need to write type information in order to // preserve the correct types of LHS identifiers. // - // writeType is a simple helper to write out type expressions. + // typeExpr is a simple helper to write out type expressions. It currently + // handles (possibly qualified) type names. + // // TODO(rfindley): - // 1. handle qualified type names (potentially adding new imports) - // 2. expand this to handle more type expressions. - // 3. refactor to share logic with callee rewriting. + // 1. expand this to handle more type expressions. + // 2. refactor to share logic with callee rewriting. 
universeAny := types.Universe.Lookup("any") - typeExpr := func(typ types.Type, shadows ...map[string]bool) ast.Expr { - var typeName string + typeExpr := func(typ types.Type, shadow shadowMap) ast.Expr { + var ( + typeName string + obj *types.TypeName // nil for basic types + ) switch typ := typ.(type) { case *types.Basic: typeName = typ.Name() case interface{ Obj() *types.TypeName }: // Named, Alias, TypeParam + obj = typ.Obj() typeName = typ.Obj().Name() } @@ -3039,15 +3513,20 @@ func (st *state) assignStmts(callerStmt *ast.AssignStmt, returnOperands []ast.Ex return nil } - for _, shadow := range shadows { - if shadow[typeName] { + if obj == nil || obj.Pkg() == nil || obj.Pkg() == caller.Types { // local type or builtin + if shadow[typeName] != 0 { logf("cannot write shadowed type name %q", typeName) return nil } - } - obj, _ := caller.lookup(typeName).(*types.TypeName) - if obj != nil && types.Identical(obj.Type(), typ) { - return ast.NewIdent(typeName) + obj, _ := caller.lookup(typeName).(*types.TypeName) + if obj != nil && types.Identical(obj.Type(), typ) { + return ast.NewIdent(typeName) + } + } else if pkgName := importName(obj.Pkg().Path(), shadow); pkgName != "" { + return &ast.SelectorExpr{ + X: ast.NewIdent(pkgName), + Sel: ast.NewIdent(typeName), + } } return nil } @@ -3082,7 +3561,7 @@ func (st *state) assignStmts(callerStmt *ast.AssignStmt, returnOperands []ast.Ex var ( specs []ast.Spec specIdxs []int - shadow = make(map[string]bool) + shadow = make(shadowMap) ) failed := false byType.Iterate(func(typ types.Type, v any) { @@ -3166,7 +3645,7 @@ func (st *state) assignStmts(callerStmt *ast.AssignStmt, returnOperands []ast.Ex idx := origIdxs[i] if nonTrivial[idx] && defs[idx] != nil { typ := caller.Info.TypeOf(lhs[i]) - texpr := typeExpr(typ) + texpr := typeExpr(typ, nil) if texpr == nil { return nil, false } diff --git a/internal/refactor/inline/inline_test.go b/internal/refactor/inline/inline_test.go index 8da5fa98cd3..03fb5ccdb17 100644 --- 
a/internal/refactor/inline/inline_test.go +++ b/internal/refactor/inline/inline_test.go @@ -474,7 +474,7 @@ func TestDuplicable(t *testing.T) { }, { "Implicit conversions from underlying types are duplicable.", - `func f(i I) { print(i, i) }; type I int`, + `func f(i I) { print(i, i) }; type I int; func print(args ...any) {}`, `func _() { f(1) }`, `func _() { print(I(1), I(1)) }`, }, @@ -731,13 +731,25 @@ func TestSubstitution(t *testing.T) { `func _() { var local int; _ = local }`, }, { - "Arguments that are used are detected", + "Arguments that are used by other arguments are detected", `func f(x, y int) { print(x) }`, `func _() { var z int; f(z, z) }`, + `func _() { var z int; print(z) }`, + }, + { + "Arguments that are used by other variadic arguments are detected", + `func f(x int, ys ...int) { print(ys) }`, + `func _() { var z int; f(z, 1, 2, 3, z) }`, + `func _() { var z int; print([]int{1, 2, 3, z}) }`, + }, + { + "Arguments that are used by other variadic arguments are detected, 2", + `func f(x int, ys ...int) { print(ys) }`, + `func _() { var z int; f(z) }`, `func _() { var z int var _ int = z - print(z) + print([]int{}) }`, }, { @@ -1031,11 +1043,17 @@ func TestVariadic(t *testing.T) { `func _(slice []any) { f(slice...) }`, `func _(slice []any) { println(slice) }`, }, + { + "Undo variadic elimination", + `func f(args ...int) []int { return append([]int{1}, args...) 
}`, + `func _(a, b int) { f(a, b) }`, + `func _(a, b int) { _ = append([]int{1}, a, b) }`, + }, { "Variadic elimination (literalization).", `func f(x any, rest ...any) { defer println(x, rest) }`, // defer => literalization `func _() { f(1, 2, 3) }`, - `func _() { func() { defer println(any(1), []any{2, 3}) }() }`, + `func _() { func() { defer println(1, []any{2, 3}) }() }`, }, { "Variadic elimination (reduction).", @@ -1081,7 +1099,7 @@ func TestParameterBindingDecl(t *testing.T) { `func _() { f(g(0), g(1), g(2), g(3)) }`, `func _() { var w, _ any = g(0), g(1) - println(w, any(g(2)), g(3)) + println(w, g(2), g(3)) }`, }, { @@ -1207,6 +1225,60 @@ func TestEmbeddedFields(t *testing.T) { }) } +func TestSubstitutionGroups(t *testing.T) { + runTests(t, []testcase{ + { + // b -> a + "Basic", + `func f(a, b int) { print(a, b) }`, + `func _() { var a int; f(a, a) }`, + `func _() { var a int; print(a, a) }`, + }, + { + // a <-> b + "Cocycle", + `func f(a, b int) { print(a, b) }`, + `func _() { var a, b int; f(a+b, a+b) }`, + `func _() { var a, b int; print(a+b, a+b) }`, + }, + { + // a <-> b + // a -> c + // Don't compute b as substitutable due to bad cycle traversal. + "Middle cycle", + `func f(a, b, c int) { var d int; print(a, b, c, d) }`, + `func _() { var a, b, c, d int; f(a+b+c, a+b, d) }`, + `func _() { + var a, b, c, d int + { + var a, b, c int = a + b + c, a + b, d + var d int + print(a, b, c, d) + } +}`, + }, + { + // a -> b + // b -> c + // b -> d + // c + // + // Only c should be substitutable. 
+ "Singleton", + `func f(a, b, c, d int) { var e int; print(a, b, c, d, e) }`, + `func _() { var a, b, c, d, e int; f(a+b, c+d, c, e) }`, + `func _() { + var a, b, c, d, e int + { + var a, b, d int = a + b, c + d, e + var e int + print(a, b, c, d, e) + } +}`, + }, + }) +} + func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { runTests(t, []testcase{ { @@ -1295,7 +1367,7 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { }, { // In this example, the set() call is rejected as a substitution - // candidate due to a shadowing conflict (x). This must entail that the + // candidate due to a shadowing conflict (z). This must entail that the // selection x.y (R) is also rejected, because it is lower numbered. // // Incidentally this program (which panics when executed) illustrates @@ -1303,12 +1375,13 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { // as x.y are not ordered wrt writes, depending on the compiler. // Changing x.y to identity(x).y forces the ordering and avoids the panic. "Hazards with args already rejected (e.g. 
due to shadowing) are detected too.", - `func f(x, y int) int { return x + y }; func set[T any](ptr *T, old, new T) int { println(old); *ptr = new; return 0; }`, - `func _() { x := new(struct{ y int }); f(x.y, set(&x, x, nil)) }`, + `func f(x, y int) (z int) { return x + y }; func set[T any](ptr *T, old, new T) int { println(old); *ptr = new; return 0; }`, + `func _() { x := new(struct{ y int }); z := x; f(x.y, set(&x, z, nil)) }`, `func _() { x := new(struct{ y int }) + z := x { - var x, y int = x.y, set(&x, x, nil) + var x, y int = x.y, set(&x, z, nil) _ = x + y } }`, @@ -1341,7 +1414,7 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { "Defer f() evaluates f() before unknown effects", `func f(int, y any, z int) { defer println(int, y, z) }; func g(int) int`, `func _() { f(g(1), g(2), g(3)) }`, - `func _() { func() { defer println(any(g(1)), any(g(2)), g(3)) }() }`, + `func _() { func() { defer println(g(1), g(2), g(3)) }() }`, }, { "Effects are ignored when IgnoreEffects", @@ -1468,6 +1541,24 @@ func TestSubstitutionPreservesParameterType(t *testing.T) { `func _() { T.f(1) }`, `func _() { T(1).g() }`, }, + { + "Implicit reference is made explicit outside of selector", + `type T int; func (x *T) f() bool { return x == x.id() }; func (x *T) id() *T { return x }`, + `func _() { var t T; _ = t.f() }`, + `func _() { var t T; _ = &t == t.id() }`, + }, + { + "Implicit parenthesized reference is not made explicit in selector", + `type T int; func (x *T) f() bool { return x == (x).id() }; func (x *T) id() *T { return x }`, + `func _() { var t T; _ = t.f() }`, + `func _() { var t T; _ = &t == (t).id() }`, + }, + { + "Implicit dereference is made explicit outside of selector", // TODO(rfindley): avoid unnecessary literalization here + `type T int; func (x T) f() bool { return x == x.id() }; func (x T) id() T { return x }`, + `func _() { var t *T; _ = t.f() }`, + `func _() { var t *T; _ = func() bool { var x T = *t; return x == x.id() }() }`, + }, { "Check 
for shadowing error on type used in the conversion.", `func f(x T) { _ = &x == (*T)(nil) }; type T int16`, @@ -1481,16 +1572,198 @@ func TestRedundantConversions(t *testing.T) { runTests(t, []testcase{ { "Type conversion must be added if the constant is untyped.", - `func f(i int32) { print(i) }`, + `func f(i int32) { print(i) }; func print(x any) {}`, `func _() { f(1) }`, `func _() { print(int32(1)) }`, }, { "Type conversion must not be added if the constant is typed.", - `func f(i int32) { print(i) }`, + `func f(i int32) { print(i) }; func print(x any) {}`, `func _() { f(int32(1)) }`, `func _() { print(int32(1)) }`, }, + { + "No type conversion for argument to interface parameter", + `type T int; func f(x any) { g(x) }; func g(any) {}`, + `func _() { f(T(1)) }`, + `func _() { g(T(1)) }`, + }, + { + "No type conversion for parenthesized argument to interface parameter", + `type T int; func f(x any) { g((x)) }; func g(any) {}`, + `func _() { f(T(1)) }`, + `func _() { g((T(1))) }`, + }, + { + "Type conversion for argument to type parameter", + `type T int; func f(x any) { g(x) }; func g[P any](P) {}`, + `func _() { f(T(1)) }`, + `func _() { g(any(T(1))) }`, + }, + { + "Strip redundant interface conversions", + `type T interface{ M() }; func f(x any) { g(x) }; func g[P any](P) {}`, + `func _() { f(T(nil)) }`, + `func _() { g(any(nil)) }`, + }, + { + "No type conversion for argument to variadic interface parameter", + `type T int; func f(x ...any) { g(x...) }; func g(...any) {}`, + `func _() { f(T(1)) }`, + `func _() { g(T(1)) }`, + }, + { + "Type conversion for variadic argument", + `type T int; func f(x ...any) { g(x...) }; func g(...any) {}`, + `func _() { f([]any{T(1)}...) }`, + `func _() { g([]any{T(1)}...) 
}`, + }, + { + "Type conversion for argument to interface channel", + `type T int; var c chan any; func f(x T) { c <- x }`, + `func _() { f(1) }`, + `func _() { c <- T(1) }`, + }, + { + "No type conversion for argument to concrete channel", + `type T int32; var c chan T; func f(x T) { c <- x }`, + `func _() { f(1) }`, + `func _() { c <- 1 }`, + }, + { + "Type conversion for interface map key", + `type T int; var m map[any]any; func f(x T) { m[x] = 1 }`, + `func _() { f(1) }`, + `func _() { m[T(1)] = 1 }`, + }, + { + "No type conversion for interface to interface map key", + `type T int; var m map[any]any; func f(x any) { m[x] = 1 }`, + `func _() { f(T(1)) }`, + `func _() { m[T(1)] = 1 }`, + }, + { + "No type conversion for concrete map key", + `type T int; var m map[T]any; func f(x T) { m[x] = 1 }`, + `func _() { f(1) }`, + `func _() { m[1] = 1 }`, + }, + { + "Type conversion for interface literal key/value", + `type T int; type m map[any]any; func f(x, y T) { _ = m{x: y} }`, + `func _() { f(1, 2) }`, + `func _() { _ = m{T(1): T(2)} }`, + }, + { + "No type conversion for concrete literal key/value", + `type T int; type m map[T]T; func f(x, y T) { _ = m{x: y} }`, + `func _() { f(1, 2) }`, + `func _() { _ = m{1: 2} }`, + }, + { + "Type conversion for interface literal element", + `type T int; type s []any; func f(x T) { _ = s{x} }`, + `func _() { f(1) }`, + `func _() { _ = s{T(1)} }`, + }, + { + "No type conversion for concrete literal element", + `type T int; type s []T; func f(x T) { _ = s{x} }`, + `func _() { f(1) }`, + `func _() { _ = s{1} }`, + }, + { + "Type conversion for interface unkeyed struct field", + `type T int; type s struct{any}; func f(x T) { _ = s{x} }`, + `func _() { f(1) }`, + `func _() { _ = s{T(1)} }`, + }, + { + "No type conversion for concrete unkeyed struct field", + `type T int; type s struct{T}; func f(x T) { _ = s{x} }`, + `func _() { f(1) }`, + `func _() { _ = s{1} }`, + }, + { + "Type conversion for interface field value", + `type T int; 
type S struct{ F any }; func f(x T) { _ = S{F: x} }`, + `func _() { f(1) }`, + `func _() { _ = S{F: T(1)} }`, + }, + { + "No type conversion for concrete field value", + `type T int; type S struct{ F T }; func f(x T) { _ = S{F: x} }`, + `func _() { f(1) }`, + `func _() { _ = S{F: 1} }`, + }, + { + "Type conversion for argument to interface channel", + `type T int; var c chan any; func f(x any) { c <- x }`, + `func _() { f(T(1)) }`, + `func _() { c <- T(1) }`, + }, + { + "No type conversion for argument to concrete channel", + `type T int32; var c chan T; func f(x T) { c <- x }`, + `func _() { f(1) }`, + `func _() { c <- 1 }`, + }, + { + "No type conversion for assignment to an explicit interface type", + `type T int; func f(x any) { var y any; y = x; _ = y }`, + `func _() { f(T(1)) }`, + `func _() { + var y any + y = T(1) + _ = y +}`, + }, + { + "No type conversion for short variable assignment to an explicit interface type", + `type T int; func f(e error) { var err any; i, err := 1, e; _, _ = i, err }`, + `func _() { f(nil) }`, + `func _() { + var err any + i, err := 1, nil + _, _ = i, err +}`, + }, + { + "No type conversion for initializer of an explicit interface type", + `type T int; func f(x any) { var y any = x; _ = y }`, + `func _() { f(T(1)) }`, + `func _() { + var y any = T(1) + _ = y +}`, + }, + { + "No type conversion for use as a composite literal key", + `type T int; func f(x any) { _ = map[any]any{x: 1} }`, + `func _() { f(T(1)) }`, + `func _() { _ = map[any]any{T(1): 1} }`, + }, + { + "No type conversion for use as a composite literal value", + `type T int; func f(x any) { _ = []any{x} }`, + `func _() { f(T(1)) }`, + `func _() { _ = []any{T(1)} }`, + }, + { + "No type conversion for use as a composite literal field", + `type T int; func f(x any) { _ = struct{ F any }{F: x} }`, + `func _() { f(T(1)) }`, + `func _() { _ = struct{ F any }{F: T(1)} }`, + }, + { + "No type conversion for use in a send statement", + `type T int; func f(x any) { var c chan 
any; c <- x }`, + `func _() { f(T(1)) }`, + `func _() { + var c chan any + c <- T(1) +}`, + }, }) } @@ -1745,14 +2018,22 @@ func deepHash(n ast.Node) any { visit(v.Elem()) } - case reflect.Array, reflect.Chan, reflect.Func, reflect.Map, reflect.UnsafePointer: - panic(v) // unreachable in AST + case reflect.String: + writeUint64(uint64(v.Len())) + hasher.Write([]byte(v.String())) - default: // bool, string, number - if v.Kind() == reflect.String { // proper framing - writeUint64(uint64(v.Len())) - } + case reflect.Int: + writeUint64(uint64(v.Int())) + + case reflect.Uint: + writeUint64(uint64(v.Uint())) + + case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // Bools and fixed width numbers can be handled by binary.Write. binary.Write(hasher, le, v.Interface()) + + default: // reflect.Array, reflect.Chan, reflect.Func, reflect.Map, reflect.UnsafePointer, reflect.Uintptr + panic(v) // unreachable in AST } } visit(reflect.ValueOf(n)) @@ -1761,3 +2042,23 @@ func deepHash(n ast.Node) any { hasher.Sum(hash[:0]) return hash } + +func TestDeepHash(t *testing.T) { + // This test reproduces a bug in DeepHash that was encountered during work on + // the inliner. + // + // TODO(rfindley): consider replacing this with a fuzz test. + id := &ast.Ident{ + NamePos: 2, + Name: "t", + } + c := &ast.CallExpr{ + Fun: id, + } + h1 := deepHash(c) + id.NamePos = 1 + h2 := deepHash(c) + if h1 == h2 { + t.Fatal("bad") + } +} diff --git a/internal/refactor/inline/testdata/assignment-splice.txtar b/internal/refactor/inline/testdata/assignment-splice.txtar new file mode 100644 index 00000000000..f5a19c022f3 --- /dev/null +++ b/internal/refactor/inline/testdata/assignment-splice.txtar @@ -0,0 +1,62 @@ +This test checks the splice assignment substrategy. 
+ +-- go.mod -- +module testdata + +go 1.20 + +-- a.go -- +package a + +func a() (int32, string) { + return b() +} + +func b() (int32, string) { + return 0, "a" +} + +func c() (int, chan<- int) { + return 0, make(chan int) // nontrivial conversion +} + +-- a1.go -- +package a + +func _() { + x, y := a() //@ inline(re"a", a1) +} +-- a1 -- +package a + +func _() { + x, y := b() //@ inline(re"a", a1) +} +-- a2.go -- +package a + +func _() { + var x, y any + x, y = a() //@ inline(re"a", a2) +} +-- a2 -- +package a + +func _() { + var x, y any + x, y = b() //@ inline(re"a", a2) +} +-- a3.go -- +package a + +func _() { + var y chan<- int + x, y := c() //@ inline(re"c", a3) +} +-- a3 -- +package a + +func _() { + var y chan<- int + x, y := 0, make(chan int) //@ inline(re"c", a3) +} diff --git a/internal/refactor/inline/testdata/assignment.txtar b/internal/refactor/inline/testdata/assignment.txtar new file mode 100644 index 00000000000..c79c1732934 --- /dev/null +++ b/internal/refactor/inline/testdata/assignment.txtar @@ -0,0 +1,138 @@ +Basic tests of inlining a call on the RHS of an assignment. 
+ +-- go.mod -- +module testdata + +go 1.20 + +-- a/a1.go -- +package a + +import "testdata/b" + +func _() { + var y int + x, y := b.B1() //@ inline(re"B", b1) + _, _ = x, y +} + +-- a/a2.go -- +package a + +import "testdata/b" + +func _() { + var y int + x, y := b.B2() //@ inline(re"B", b2) + _, _ = x, y +} + +-- a/a3.go -- +package a + +import "testdata/b" + +func _() { + x, y := b.B3() //@ inline(re"B", b3) + _, _ = x, y +} + +-- a/a4.go -- +package a + +import "testdata/b" + +func _() { + x, y := b.B4() //@ inline(re"B", b4) + _, _ = x, y +} + +-- b/b.go -- +package b + +import ( + "testdata/c" +) + +func B1() (c.C, int) { + return 0, 1 +} + +func B2() (c.C, int) { + return B1() +} + +func B3() (c.C, c.C) { + return 0, 1 +} + +-- b/b4.go -- +package b + +import ( + c1 "testdata/c" + c2 "testdata/c2" +) + +func B4() (c1.C, c2.C) { + return 0, 1 +} + +-- c/c.go -- +package c + +type C int + +-- c2/c.go -- +package c + +type C int + +-- b1 -- +package a + +import ( + "testdata/c" +) + +func _() { + var y int + x, y := c.C(0), 1 //@ inline(re"B", b1) + _, _ = x, y +} +-- b2 -- +package a + +import ( + "testdata/b" +) + +func _() { + var y int + x, y := b.B1() //@ inline(re"B", b2) + _, _ = x, y +} +-- b3 -- +package a + +import ( + "testdata/c" +) + +func _() { + x, y := c.C(0), c.C(1) //@ inline(re"B", b3) + _, _ = x, y +} + +-- b4 -- +package a + +import ( + "testdata/c" + c0 "testdata/c2" +) + +func _() { + x, y := c.C(0), c0.C(1) //@ inline(re"B", b4) + _, _ = x, y +} diff --git a/internal/refactor/inline/testdata/import-comments.txtar b/internal/refactor/inline/testdata/import-comments.txtar new file mode 100644 index 00000000000..d4a4122c4d1 --- /dev/null +++ b/internal/refactor/inline/testdata/import-comments.txtar @@ -0,0 +1,113 @@ +This file checks various handling of comments when adding imports. + +-- go.mod -- +module testdata +go 1.12 + +-- a/empty.go -- +package a // This is package a. 
+ +func _() { + a() //@ inline(re"a", empty) +} + +-- empty -- +package a // This is package a. + +import "testdata/b" + +func _() { + b.B() //@ inline(re"a", empty) +} +-- a/existing.go -- +package a // This is package a. + +// This is an import block. +import ( + // This is an import of io. + "io" + + // This is an import of c. + "testdata/c" +) + +var ( + // This is an io.Writer. + _ io.Writer + // This is c.C + _ c.C +) + +func _() { + a() //@ inline(re"a", existing) +} + +-- existing -- +package a // This is package a. + +// This is an import block. +import ( + // This is an import of io. + "io" + + // This is an import of c. + "testdata/b" + "testdata/c" +) + +var ( + // This is an io.Writer. + _ io.Writer + // This is c.C + _ c.C +) + +func _() { + b.B() //@ inline(re"a", existing) +} + +-- a/noparens.go -- +package a // This is package a. + +// This is an import of c. +import "testdata/c" + +func _() { + var _ c.C + a() //@ inline(re"a", noparens) +} + +-- noparens -- +package a // This is package a. + +// This is an import of c. +import ( + "testdata/b" + "testdata/c" +) + +func _() { + var _ c.C + b.B() //@ inline(re"a", noparens) +} + +-- a/a.go -- +package a + +// This is an import of b. +import "testdata/b" + +func a() { + // This is a call to B. 
+ b.B() +} + +-- b/b.go -- +package b + +func B() {} + +-- c/c.go -- +package c + +type C int diff --git a/internal/refactor/inline/testdata/issue63298.txtar b/internal/refactor/inline/testdata/issue63298.txtar index cc556c90ecd..e7f36351219 100644 --- a/internal/refactor/inline/testdata/issue63298.txtar +++ b/internal/refactor/inline/testdata/issue63298.txtar @@ -38,13 +38,11 @@ func B() {} package a import ( - "testdata/b" b0 "testdata/another/b" - - //@ inline(re"a2", result) + "testdata/b" ) func _() { b.B() - b0.B() + b0.B() //@ inline(re"a2", result) } diff --git a/internal/refactor/inline/testdata/issue69441.txtar b/internal/refactor/inline/testdata/issue69441.txtar new file mode 100644 index 00000000000..259a2a2150a --- /dev/null +++ b/internal/refactor/inline/testdata/issue69441.txtar @@ -0,0 +1,44 @@ +This test checks that variadic elimination does not cause a semantic change due +to creation of a non-nil empty slice instead of a nil slice due to missing +variadic arguments. + +-- go.mod -- +module testdata +go 1.12 + +-- foo/foo.go -- +package foo +import "fmt" + +func F(is ...int) { + if is == nil { + fmt.Println("is is nil") + } else { + fmt.Println("is is not nil") + } +} + +func G(is ...int) { F(is...) } + +func main() { + G() //@ inline(re"G", G) +} + +-- G -- +package foo + +import "fmt" + +func F(is ...int) { + if is == nil { + fmt.Println("is is nil") + } else { + fmt.Println("is is not nil") + } +} + +func G(is ...int) { F(is...) } + +func main() { + F() //@ inline(re"G", G) +} diff --git a/internal/refactor/inline/testdata/issue69442.txtar b/internal/refactor/inline/testdata/issue69442.txtar new file mode 100644 index 00000000000..cf38bd8c9ec --- /dev/null +++ b/internal/refactor/inline/testdata/issue69442.txtar @@ -0,0 +1,34 @@ +This test checks that we don't introduce unnecessary (&v) or (*ptr) operations +when calling a method on an addressable receiver. 
+ +-- go.mod -- +module testdata + +go 1.20 + +-- main.go -- +package foo +type T int + +func (*T) F() {} + +func (t *T) G() { t.F() } + +func main() { + var t T + t.G() //@ inline(re"G", inline) +} + +-- inline -- +package foo + +type T int + +func (*T) F() {} + +func (t *T) G() { t.F() } + +func main() { + var t T + t.F() //@ inline(re"G", inline) +} diff --git a/internal/refactor/inline/testdata/substgroups.txtar b/internal/refactor/inline/testdata/substgroups.txtar new file mode 100644 index 00000000000..37f8f7d8127 --- /dev/null +++ b/internal/refactor/inline/testdata/substgroups.txtar @@ -0,0 +1,113 @@ +This test checks that parameter shadowing is avoided for substitution groups, +as well as the examples of recursive pruning of these groups based on falcon +and effects analysis. + +-- go.mod -- +module testdata + +go 1.20 + +-- falcon.go -- +package a + +var a [3]int + +func falcon(x, y, z int) { + _ = x + a[y+z] +} + +func _() { + var y int + const x, z = 1, 2 + falcon(y, x, z) //@ inline(re"falcon", falcon) +} + +-- falcon -- +package a + +var a [3]int + +func falcon(x, y, z int) { + _ = x + a[y+z] +} + +func _() { + var y int + const x, z = 1, 2 + { + var x, y, z int = y, x, z + _ = x + a[y+z] + } //@ inline(re"falcon", falcon) +} + +-- effects.go -- +package a + +func effects(w, x, y, z int) { + _ = x + w + y + z +} + +func _() { + v := 0 + w := func() int { v++; return 0 } + x := func() int { v++; return 0 } + y := func() int { v++; return 0 } + effects(x(), w(), y(), x()) //@ inline(re"effects", effects) +} + +-- effects -- +package a + +func effects(w, x, y, z int) { + _ = x + w + y + z +} + +func _() { + v := 0 + w := func() int { v++; return 0 } + x := func() int { v++; return 0 } + y := func() int { v++; return 0 } + { + var w, x, y, z int = x(), w(), y(), x() + _ = x + w + y + z + } //@ inline(re"effects", effects) +} + +-- negative.go -- +package a + +func _() { + i := -1 + if negative1(i, i) { //@ inline(re"negative1", negative1) + i := 0 + _ = i 
+ } +} + +func negative1(i, j int) bool { + return negative2(j, i) +} + +func negative2(i, j int) bool { + return i < 0 +} + +-- negative1 -- +package a + +func _() { + i := -1 + if negative2(i, i) { //@ inline(re"negative1", negative1) + i := 0 + _ = i + } +} + +func negative1(i, j int) bool { + return negative2(j, i) +} + +func negative2(i, j int) bool { + return i < 0 +} + diff --git a/internal/typesinternal/zerovalue.go b/internal/typesinternal/zerovalue.go new file mode 100644 index 00000000000..1066980649e --- /dev/null +++ b/internal/typesinternal/zerovalue.go @@ -0,0 +1,282 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package typesinternal + +import ( + "fmt" + "go/ast" + "go/token" + "go/types" + "strconv" + "strings" +) + +// ZeroString returns the string representation of the "zero" value of the type t. +// This string can be used on the right-hand side of an assignment where the +// left-hand side has that explicit type. +// Exception: This does not apply to tuples. Their string representation is +// informational only and cannot be used in an assignment. +// When assigning to a wider type (such as 'any'), it's the caller's +// responsibility to handle any necessary type conversions. +// See [ZeroExpr] for a variant that returns an [ast.Expr]. 
+func ZeroString(t types.Type, qf types.Qualifier) string { + switch t := t.(type) { + case *types.Basic: + switch { + case t.Info()&types.IsBoolean != 0: + return "false" + case t.Info()&types.IsNumeric != 0: + return "0" + case t.Info()&types.IsString != 0: + return `""` + case t.Kind() == types.UnsafePointer: + fallthrough + case t.Kind() == types.UntypedNil: + return "nil" + default: + panic(fmt.Sprint("ZeroString for unexpected type:", t)) + } + + case *types.Pointer, *types.Slice, *types.Interface, *types.Chan, *types.Map, *types.Signature: + return "nil" + + case *types.Named, *types.Alias: + switch under := t.Underlying().(type) { + case *types.Struct, *types.Array: + return types.TypeString(t, qf) + "{}" + default: + return ZeroString(under, qf) + } + + case *types.Array, *types.Struct: + return types.TypeString(t, qf) + "{}" + + case *types.TypeParam: + // Assumes func new is not shadowed. + return "*new(" + types.TypeString(t, qf) + ")" + + case *types.Tuple: + // Tuples are not normal values. + // We are currently format as "(t[0], ..., t[n])". Could be something else. + components := make([]string, t.Len()) + for i := 0; i < t.Len(); i++ { + components[i] = ZeroString(t.At(i).Type(), qf) + } + return "(" + strings.Join(components, ", ") + ")" + + case *types.Union: + // Variables of these types cannot be created, so it makes + // no sense to ask for their zero value. + panic(fmt.Sprintf("invalid type for a variable: %v", t)) + + default: + panic(t) // unreachable. + } +} + +// ZeroExpr returns the ast.Expr representation of the "zero" value of the type t. +// ZeroExpr is defined for types that are suitable for variables. +// It may panic for other types such as Tuple or Union. +// See [ZeroString] for a variant that returns a string. 
+func ZeroExpr(f *ast.File, pkg *types.Package, typ types.Type) ast.Expr { + switch t := typ.(type) { + case *types.Basic: + switch { + case t.Info()&types.IsBoolean != 0: + return &ast.Ident{Name: "false"} + case t.Info()&types.IsNumeric != 0: + return &ast.BasicLit{Kind: token.INT, Value: "0"} + case t.Info()&types.IsString != 0: + return &ast.BasicLit{Kind: token.STRING, Value: `""`} + case t.Kind() == types.UnsafePointer: + fallthrough + case t.Kind() == types.UntypedNil: + return ast.NewIdent("nil") + default: + panic(fmt.Sprint("ZeroExpr for unexpected type:", t)) + } + + case *types.Pointer, *types.Slice, *types.Interface, *types.Chan, *types.Map, *types.Signature: + return ast.NewIdent("nil") + + case *types.Named, *types.Alias: + switch under := t.Underlying().(type) { + case *types.Struct, *types.Array: + return &ast.CompositeLit{ + Type: TypeExpr(f, pkg, typ), + } + default: + return ZeroExpr(f, pkg, under) + } + + case *types.Array, *types.Struct: + return &ast.CompositeLit{ + Type: TypeExpr(f, pkg, typ), + } + + case *types.TypeParam: + return &ast.StarExpr{ // *new(T) + X: &ast.CallExpr{ + // Assumes func new is not shadowed. + Fun: ast.NewIdent("new"), + Args: []ast.Expr{ + ast.NewIdent(t.Obj().Name()), + }, + }, + } + + case *types.Tuple: + // Unlike ZeroString, there is no ast.Expr can express tuple by + // "(t[0], ..., t[n])". + panic(fmt.Sprintf("invalid type for a variable: %v", t)) + + case *types.Union: + // Variables of these types cannot be created, so it makes + // no sense to ask for their zero value. + panic(fmt.Sprintf("invalid type for a variable: %v", t)) + + default: + panic(t) // unreachable. + } +} + +// IsZeroExpr uses simple syntactic heuristics to report whether expr +// is a obvious zero value, such as 0, "", nil, or false. +// It cannot do better without type information. 
+func IsZeroExpr(expr ast.Expr) bool { + switch e := expr.(type) { + case *ast.BasicLit: + return e.Value == "0" || e.Value == `""` + case *ast.Ident: + return e.Name == "nil" || e.Name == "false" + default: + return false + } +} + +// TypeExpr returns syntax for the specified type. References to named types +// from packages other than pkg are qualified by an appropriate package name, as +// defined by the import environment of file. +// It may panic for types such as Tuple or Union. +func TypeExpr(f *ast.File, pkg *types.Package, typ types.Type) ast.Expr { + switch t := typ.(type) { + case *types.Basic: + switch t.Kind() { + case types.UnsafePointer: + // TODO(hxjiang): replace the implementation with types.Qualifier. + return &ast.SelectorExpr{X: ast.NewIdent("unsafe"), Sel: ast.NewIdent("Pointer")} + default: + return ast.NewIdent(t.Name()) + } + + case *types.Pointer: + return &ast.UnaryExpr{ + Op: token.MUL, + X: TypeExpr(f, pkg, t.Elem()), + } + + case *types.Array: + return &ast.ArrayType{ + Len: &ast.BasicLit{ + Kind: token.INT, + Value: fmt.Sprintf("%d", t.Len()), + }, + Elt: TypeExpr(f, pkg, t.Elem()), + } + + case *types.Slice: + return &ast.ArrayType{ + Elt: TypeExpr(f, pkg, t.Elem()), + } + + case *types.Map: + return &ast.MapType{ + Key: TypeExpr(f, pkg, t.Key()), + Value: TypeExpr(f, pkg, t.Elem()), + } + + case *types.Chan: + dir := ast.ChanDir(t.Dir()) + if t.Dir() == types.SendRecv { + dir = ast.SEND | ast.RECV + } + return &ast.ChanType{ + Dir: dir, + Value: TypeExpr(f, pkg, t.Elem()), + } + + case *types.Signature: + var params []*ast.Field + for i := 0; i < t.Params().Len(); i++ { + params = append(params, &ast.Field{ + Type: TypeExpr(f, pkg, t.Params().At(i).Type()), + Names: []*ast.Ident{ + { + Name: t.Params().At(i).Name(), + }, + }, + }) + } + if t.Variadic() { + last := params[len(params)-1] + last.Type = &ast.Ellipsis{Elt: last.Type.(*ast.ArrayType).Elt} + } + var returns []*ast.Field + for i := 0; i < t.Results().Len(); i++ { + returns 
= append(returns, &ast.Field{ + Type: TypeExpr(f, pkg, t.Results().At(i).Type()), + }) + } + return &ast.FuncType{ + Params: &ast.FieldList{ + List: params, + }, + Results: &ast.FieldList{ + List: returns, + }, + } + + case interface{ Obj() *types.TypeName }: // *types.{Alias,Named,TypeParam} + switch t.Obj().Pkg() { + case pkg, nil: + return ast.NewIdent(t.Obj().Name()) + } + pkgName := t.Obj().Pkg().Name() + + // TODO(hxjiang): replace the implementation with types.Qualifier. + // If the file already imports the package under another name, use that. + for _, cand := range f.Imports { + if path, _ := strconv.Unquote(cand.Path.Value); path == t.Obj().Pkg().Path() { + if cand.Name != nil && cand.Name.Name != "" { + pkgName = cand.Name.Name + } + } + } + if pkgName == "." { + return ast.NewIdent(t.Obj().Name()) + } + return &ast.SelectorExpr{ + X: ast.NewIdent(pkgName), + Sel: ast.NewIdent(t.Obj().Name()), + } + + case *types.Struct: + return ast.NewIdent(t.String()) + + case *types.Interface: + return ast.NewIdent(t.String()) + + case *types.Union: + // TODO(hxjiang): handle the union through syntax (~A | ... | ~Z). + // Remove nil check when calling typesinternal.TypeExpr. + return nil + + case *types.Tuple: + panic("invalid input type types.Tuple") + + default: + panic("unreachable") + } +} diff --git a/internal/typesinternal/zerovalue_test.go b/internal/typesinternal/zerovalue_test.go new file mode 100644 index 00000000000..6cb6ea672a5 --- /dev/null +++ b/internal/typesinternal/zerovalue_test.go @@ -0,0 +1,151 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +//go:debug gotypesalias=1 + +package typesinternal_test + +import ( + "bytes" + "go/ast" + "go/parser" + "go/printer" + "go/token" + "go/types" + "strings" + "testing" + + "golang.org/x/tools/internal/typesinternal" +) + +func TestZeroValue(t *testing.T) { + // This test only refernece types/functions defined within the same package. + // We can safely drop the package name when encountered. + qf := types.Qualifier(func(p *types.Package) string { + return "" + }) + src := ` +package main + +type foo struct{ + bar string +} + +type namedInt int +type namedString string +type namedBool bool +type namedPointer *foo +type namedSlice []foo +type namedInterface interface{ Error() string } +type namedChan chan int +type namedMap map[string]foo +type namedSignature func(string) string +type namedStruct struct{ bar string } +type namedArray [3]foo + +type aliasInt = int +type aliasString = string +type aliasBool = bool +type aliasPointer = *foo +type aliasSlice = []foo +type aliasInterface = interface{ Error() string } +type aliasChan = chan int +type aliasMap = map[string]foo +type aliasSignature = func(string) string +type aliasStruct = struct{ bar string } +type aliasArray = [3]foo + +func _[T any]() { + var ( + _ int // 0 + _ bool // false + _ string // "" + + _ *foo // nil + _ []string // nil + _ []foo // nil + _ interface{ Error() string } // nil + _ chan foo // nil + _ map[string]foo // nil + _ func(string) string // nil + + _ namedInt // 0 + _ namedString // "" + _ namedBool // false + _ namedSlice // nil + _ namedInterface // nil + _ namedChan // nil + _ namedMap// nil + _ namedSignature // nil + _ namedStruct // namedStruct{} + _ namedArray // namedArray{} + + _ aliasInt // 0 + _ aliasString // "" + _ aliasBool // false + _ aliasSlice // nil + _ aliasInterface // nil + _ aliasChan // nil + _ aliasMap// nil + _ aliasSignature // nil + _ aliasStruct // aliasStruct{} + _ aliasArray // aliasArray{} + + _ [4]string // [4]string{} + _ [5]foo // [5]foo{} + _ foo // 
foo{} + _ struct{f foo} // struct{f foo}{} + + _ T // *new(T) + ) +} +` + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "p.go", src, parser.ParseComments) + if err != nil { + t.Fatalf("parse file error %v on file source:\n%s\n", err, src) + } + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + } + var conf types.Config + pkg, err := conf.Check("", fset, []*ast.File{f}, info) + if err != nil { + t.Fatalf("type check error %v on file source:\n%s\n", err, src) + } + + fun, ok := f.Decls[len(f.Decls)-1].(*ast.FuncDecl) + if !ok { + t.Fatalf("the last decl of the file is not FuncDecl") + } + + decl, ok := fun.Body.List[0].(*ast.DeclStmt).Decl.(*ast.GenDecl) + if !ok { + t.Fatalf("the first statement of the function is not GenDecl") + } + + for _, spec := range decl.Specs { + s, ok := spec.(*ast.ValueSpec) + if !ok { + t.Fatalf("%s: got %T, want ValueSpec", fset.Position(spec.Pos()), spec) + } + want := strings.TrimSpace(s.Comment.Text()) + + typ := info.TypeOf(s.Type) + got := typesinternal.ZeroString(typ, qf) + if got != want { + t.Errorf("%s: ZeroString() = %q, want zero value %q", fset.Position(spec.Pos()), got, want) + } + + zeroExpr := typesinternal.ZeroExpr(f, pkg, typ) + var bytes bytes.Buffer + printer.Fprint(&bytes, fset, zeroExpr) + got = bytes.String() + if got != want { + t.Errorf("%s: ZeroExpr() = %q, want zero value %q", fset.Position(spec.Pos()), got, want) + } + } +} diff --git a/internal/versions/constraint.go b/internal/versions/constraint.go deleted file mode 100644 index 179063d4848..00000000000 --- a/internal/versions/constraint.go +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright 2024 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package versions - -import "go/build/constraint" - -// ConstraintGoVersion is constraint.GoVersion (if built with go1.21+). -// Otherwise nil. -// -// Deprecate once x/tools is after go1.21. -var ConstraintGoVersion func(x constraint.Expr) string