From 17263fa5e2ab11d17bdc089852f2f29a0f246fd1 Mon Sep 17 00:00:00 2001 From: Grant Nelson Date: Mon, 18 Dec 2023 12:16:33 -0700 Subject: [PATCH] Broke up parseAndAugment --- build/build.go | 257 +++++++++++++++++++++++---------------- build/build_test.go | 290 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 443 insertions(+), 104 deletions(-) diff --git a/build/build.go b/build/build.go index 070d05df1..08c96f4d9 100644 --- a/build/build.go +++ b/build/build.go @@ -117,6 +117,19 @@ func ImportDir(dir string, mode build.ImportMode, installSuffix string, buildTag return pkg, nil } +// overrideInfo is used by parseAndAugment methods to manage +// directives and how the overlay and original are merged. +type overrideInfo struct { + // KeepOriginal indicates that the original code should be kept + // but the identifier will be prefixed by `_gopherjs_original_foo`. + // If false the original code is removed. + keepOriginal bool + + // pruneMethodBody indicates that the body of the methods should be + // removed because they contain something that is invalid to GopherJS. + pruneMethodBody bool +} + // parseAndAugment parses and returns all .go files of given pkg. // Standard Go library packages are augmented with files in compiler/natives folder. // If isTest is true and pkg.ImportPath has no _test suffix, package is built for running internal tests. @@ -132,84 +145,86 @@ func ImportDir(dir string, mode build.ImportMode, installSuffix string, buildTag // the original identifier gets replaced by `_`. New identifiers that don't exist in original // package get added. func parseAndAugment(xctx XContext, pkg *PackageData, isTest bool, fileSet *token.FileSet) ([]*ast.File, []JSFile, error) { - var files []*ast.File + jsFiles, overlayFiles := parseOverlayFiles(xctx, pkg, isTest, fileSet) + + originalFiles, err := parserOriginalFiles(pkg, fileSet) + if err != nil { + return nil, nil, err + } + + overrides := make(map[string]overrideInfo) + for _, file := range overlayFiles { + augmentOverlayFile(file, overrides) + } + delete(overrides, "init") - type overrideInfo struct { - keepOriginal bool - pruneOriginal bool + for _, file := range originalFiles { + augmentOriginalImports(pkg.ImportPath, file) + augmentOriginalFile(file, overrides) } - replacedDeclNames := make(map[string]overrideInfo) + return append(overlayFiles, originalFiles...), jsFiles, nil +} + +// parseOverlayFiles loads and parses overlay files +// to augment the original files with. +func parseOverlayFiles(xctx XContext, pkg *PackageData, isTest bool, fileSet *token.FileSet) ([]JSFile, []*ast.File) { isXTest := strings.HasSuffix(pkg.ImportPath, "_test") importPath := pkg.ImportPath if isXTest { importPath = importPath[:len(importPath)-5] } - jsFiles := []JSFile{} - nativesContext := overlayCtx(xctx.Env()) + nativesPkg, err := nativesContext.Import(importPath, "", 0) + if err != nil { + return nil, nil + } - if nativesPkg, err := nativesContext.Import(importPath, "", 0); err == nil { - jsFiles = nativesPkg.JSFiles - names := nativesPkg.GoFiles - if isTest { - names = append(names, nativesPkg.TestGoFiles...) - } - if isXTest { - names = nativesPkg.XTestGoFiles + jsFiles := nativesPkg.JSFiles + var files []*ast.File + names := nativesPkg.GoFiles + if isTest { + names = append(names, nativesPkg.TestGoFiles...) + } + if isXTest { + names = nativesPkg.XTestGoFiles + } + + for _, name := range names { + fullPath := path.Join(nativesPkg.Dir, name) + r, err := nativesContext.bctx.OpenFile(fullPath) + if err != nil { + panic(err) } - for _, name := range names { - fullPath := path.Join(nativesPkg.Dir, name) - r, err := nativesContext.bctx.OpenFile(fullPath) - if err != nil { - panic(err) - } - // Files should be uniquely named and in the original package directory in order to be - // ordered correctly - newPath := path.Join(pkg.Dir, "gopherjs__"+name) - file, err := parser.ParseFile(fileSet, newPath, r, parser.ParseComments) - if err != nil { - panic(err) - } - r.Close() - for _, decl := range file.Decls { - switch d := decl.(type) { - case *ast.FuncDecl: - k := astutil.FuncKey(d) - replacedDeclNames[k] = overrideInfo{ - keepOriginal: astutil.KeepOriginal(d), - pruneOriginal: astutil.PruneOriginal(d), - } - case *ast.GenDecl: - switch d.Tok { - case token.TYPE: - for _, spec := range d.Specs { - replacedDeclNames[spec.(*ast.TypeSpec).Name.Name] = overrideInfo{} - } - case token.VAR, token.CONST: - for _, spec := range d.Specs { - for _, name := range spec.(*ast.ValueSpec).Names { - replacedDeclNames[name.Name] = overrideInfo{} - } - } - } - } - } - files = append(files, file) + // Files should be uniquely named and in the original package directory in order to be + // ordered correctly + newPath := path.Join(pkg.Dir, "gopherjs__"+name) + file, err := parser.ParseFile(fileSet, newPath, r, parser.ParseComments) + if err != nil { + panic(err) } + r.Close() + + files = append(files, file) } - delete(replacedDeclNames, "init") + return jsFiles, files +} +// parserOriginalFiles loads and parses the original files to augment. +func parserOriginalFiles(pkg *PackageData, fileSet *token.FileSet) ([]*ast.File, error) { + var files []*ast.File var errList compiler.ErrorList for _, name := range pkg.GoFiles { if !filepath.IsAbs(name) { // name might be absolute if specified directly. E.g., `gopherjs build /abs/file.go`. name = filepath.Join(pkg.Dir, name) } + r, err := buildutil.OpenFile(pkg.bctx, name) if err != nil { - return nil, nil, err + return nil, err } + file, err := parser.ParseFile(fileSet, name, r, parser.ParseComments) r.Close() if err != nil { @@ -226,68 +241,102 @@ func parseAndAugment(xctx XContext, pkg *PackageData, isTest bool, fileSet *toke continue } - switch pkg.ImportPath { - case "crypto/rand", "encoding/gob", "encoding/json", "expvar", "go/token", "log", "math/big", "math/rand", "regexp", "time": - for _, spec := range file.Imports { - path, _ := strconv.Unquote(spec.Path.Value) - if path == "sync" { - if spec.Name == nil { - spec.Name = ast.NewIdent("sync") + files = append(files, file) + } + + if errList != nil { + return nil, errList + } + return files, nil +} + +// augmentOverlayFile is the part of parseAndAugment that processes +// an overlay file AST to collect information such as compiler directives +// and perform any initial augmentation needed to the overlay. +func augmentOverlayFile(file *ast.File, overrides map[string]overrideInfo) { + for _, decl := range file.Decls { + switch d := decl.(type) { + case *ast.FuncDecl: + k := astutil.FuncKey(d) + overrides[k] = overrideInfo{ + keepOriginal: astutil.KeepOriginal(d), + pruneMethodBody: astutil.PruneOriginal(d), + } + case *ast.GenDecl: + for _, spec := range d.Specs { + switch s := spec.(type) { + case *ast.TypeSpec: + overrides[s.Name.Name] = overrideInfo{} + case *ast.ValueSpec: + for _, name := range s.Names { + overrides[name.Name] = overrideInfo{} } - spec.Path.Value = `"github.com/gopherjs/gopherjs/nosync"` } } } + } +} - for _, decl := range file.Decls { - switch d := decl.(type) { - case *ast.FuncDecl: - k := astutil.FuncKey(d) - if info, ok := replacedDeclNames[k]; ok { - if info.pruneOriginal { - // Prune function bodies, since it may contain code invalid for - // GopherJS and pin unwanted imports. - d.Body = nil - } - if info.keepOriginal { - // Allow overridden function calls - // The standard library implementation of foo() becomes _gopherjs_original_foo() - d.Name.Name = "_gopherjs_original_" + d.Name.Name - } else { - d.Name = ast.NewIdent("_") - } +// augmentOriginalImports is the part of parseAndAugment that processes +// an original file AST to modify the imports for that file. +func augmentOriginalImports(importPath string, file *ast.File) { + switch importPath { + case "crypto/rand", "encoding/gob", "encoding/json", "expvar", "go/token", "log", "math/big", "math/rand", "regexp", "time": + for _, spec := range file.Imports { + path, _ := strconv.Unquote(spec.Path.Value) + if path == "sync" { + if spec.Name == nil { + spec.Name = ast.NewIdent("sync") } - case *ast.GenDecl: - switch d.Tok { - case token.TYPE: - for _, spec := range d.Specs { - s := spec.(*ast.TypeSpec) - if _, ok := replacedDeclNames[s.Name.Name]; ok { - s.Name = ast.NewIdent("_") - s.Type = &ast.StructType{Struct: s.Pos(), Fields: &ast.FieldList{}} - s.TypeParams = nil - } + spec.Path.Value = `"github.com/gopherjs/gopherjs/nosync"` + } + } + } +} + +// augmentOriginalFile is the part of parseAndAugment that processes an +// original file AST to augment the source code using the overrides from +// the overlay files. +func augmentOriginalFile(file *ast.File, overrides map[string]overrideInfo) { + for _, decl := range file.Decls { + switch d := decl.(type) { + case *ast.FuncDecl: + if info, ok := overrides[astutil.FuncKey(d)]; ok { + if info.pruneMethodBody { + // Prune function bodies, since it may contain code invalid for + // GopherJS and pin unwanted imports. + d.Body = nil + } + if info.keepOriginal { + // Allow overridden function calls + // The standard library implementation of foo() becomes _gopherjs_original_foo() + d.Name.Name = "_gopherjs_original_" + d.Name.Name + } else { + // By setting the name to an underscore, the method will + // not be outputted. Doing this will keep the dependencies the same. + d.Name = ast.NewIdent("_") + } + } + case *ast.GenDecl: + for _, spec := range d.Specs { + switch s := spec.(type) { + case *ast.TypeSpec: + if _, ok := overrides[s.Name.Name]; ok { + s.Name = ast.NewIdent("_") + // Change to struct type with no type body and not type parameters. + s.Type = &ast.StructType{Struct: s.Pos(), Fields: &ast.FieldList{}} + s.TypeParams = nil } - case token.VAR, token.CONST: - for _, spec := range d.Specs { - s := spec.(*ast.ValueSpec) - for i, name := range s.Names { - if _, ok := replacedDeclNames[name.Name]; ok { - s.Names[i] = ast.NewIdent("_") - } + case *ast.ValueSpec: + for i, name := range s.Names { + if _, ok := overrides[name.Name]; ok { + s.Names[i] = ast.NewIdent("_") } } } } } - - files = append(files, file) - } - - if errList != nil { - return nil, nil, errList } - return files, jsFiles, nil } // Options controls build process behavior. diff --git a/build/build_test.go b/build/build_test.go index 2fa17e2c5..8364052d7 100644 --- a/build/build_test.go +++ b/build/build_test.go @@ -1,12 +1,15 @@ package build import ( + "bytes" "fmt" gobuild "go/build" + "go/printer" "go/token" "strconv" "testing" + "github.com/gopherjs/gopherjs/internal/srctesting" "github.com/shurcooL/go/importgraphutil" ) @@ -127,3 +130,290 @@ func (m stringSet) String() string { } return fmt.Sprintf("%q", s) } + +func TestOverlayAugmentation(t *testing.T) { + tests := []struct { + desc string + src string + expInfo map[string]overrideInfo + }{ + { + desc: `remove function`, + src: `func Foo(a, b int) int { + return a + b + }`, + expInfo: map[string]overrideInfo{ + `Foo`: {}, + }, + }, { + desc: `keep function`, + src: `//gopherjs:keep-original + func Foo(a, b int) int { + return a + b + }`, + expInfo: map[string]overrideInfo{ + `Foo`: {keepOriginal: true}, + }, + }, { + desc: `prune function body`, + src: `//gopherjs:prune-original + func Foo(a, b int) int { + return a + b + }`, + expInfo: map[string]overrideInfo{ + `Foo`: {pruneMethodBody: true}, + }, + }, { + desc: `remove constants and values`, + src: `import "time" + + const ( + foo = 42 + bar = "gopherjs" + ) + + var now = time.Now`, + expInfo: map[string]overrideInfo{ + `foo`: {}, + `bar`: {}, + `now`: {}, + }, + }, { + desc: `remove types`, + src: `import "time" + + type ( + foo struct {} + bar int + ) + + type bob interface {}`, + expInfo: map[string]overrideInfo{ + `foo`: {}, + `bar`: {}, + `bob`: {}, + }, + }, { + desc: `remove methods`, + src: `import "cmp" + + type Foo struct { + bar int + } + + func (x *Foo) GetBar() int { return x.bar } + func (x *Foo) SetBar(bar int) { x.bar = bar }`, + expInfo: map[string]overrideInfo{ + `Foo`: {}, + `Foo.GetBar`: {}, + `Foo.SetBar`: {}, + }, + }, { + desc: `remove generics`, + src: `import "cmp" + + type Pointer[T any] struct {} + + func Sort[S ~[]E, E cmp.Ordered](x S) {} + + // this is a stub for "func Equal[S ~[]E, E any](s1, s2 S) bool {}" + func Equal[S ~[]E, E any](s1, s2 S) bool {}`, + expInfo: map[string]overrideInfo{ + `Pointer`: {}, + `Sort`: {}, + `Equal`: {}, + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + pkgName := "package testpackage\n\n" + fsetSrc := token.NewFileSet() + fileSrc := srctesting.Parse(t, fsetSrc, pkgName+test.src) + + overrides := map[string]overrideInfo{} + augmentOverlayFile(fileSrc, overrides) + + for key, expInfo := range test.expInfo { + if gotInfo, ok := overrides[key]; !ok { + t.Errorf(`%q was expected but not gotten`, key) + } else if expInfo != gotInfo { + t.Errorf(`%q had wrong info, got %+v`, key, gotInfo) + } + } + for key, gotInfo := range overrides { + if _, ok := test.expInfo[key]; !ok { + t.Errorf(`%q with %+v was not expected`, key, gotInfo) + } + } + }) + } +} + +func TestOriginalAugmentation(t *testing.T) { + tests := []struct { + desc string + info map[string]overrideInfo + src string + want string + }{ + { + desc: `do not affect function`, + info: map[string]overrideInfo{}, + src: `func Foo(a, b int) int { + return a + b + }`, + want: `func Foo(a, b int) int { + return a + b + }`, + }, { + desc: `change unnamed sync import`, + info: map[string]overrideInfo{}, + src: `import "sync" + + var _ = &sync.Mutex{}`, + want: `import sync "github.com/gopherjs/gopherjs/nosync" + + var _ = &sync.Mutex{}`, + }, { + desc: `change named sync import`, + info: map[string]overrideInfo{}, + src: `import foo "sync" + + var _ = &foo.Mutex{}`, + want: `import foo "github.com/gopherjs/gopherjs/nosync" + + var _ = &foo.Mutex{}`, + }, { + desc: `remove function`, + info: map[string]overrideInfo{ + `Foo`: {}, + }, + src: `func Foo(a, b int) int { + return a + b + }`, + want: `func _(a, b int) int { + return a + b + }`, + }, { + desc: `keep original function`, + info: map[string]overrideInfo{ + `Foo`: {keepOriginal: true}, + }, + src: `func Foo(a, b int) int { + return a + b + }`, + want: `func _gopherjs_original_Foo(a, b int) int { + return a + b + }`, + }, { + desc: `remove types and values`, + info: map[string]overrideInfo{ + `Foo`: {}, + `now`: {}, + `bar1`: {}, + }, + src: `import "time" + + type Foo interface{ + bob(a, b string) string + } + + var now = time.Now + const bar1, bar2 = 21, 42`, + want: `import "time" + + type _ struct { + } + + var _ = time.Now + const _, bar2 = 21, 42`, + }, { + desc: `remove in multi-value context`, + info: map[string]overrideInfo{ + `bar`: {}, + }, + src: `const foo, bar = func() (int, int) { + return 24, 12 + }()`, + want: `const foo, _ = func() (int, int) { + return 24, 12 + }()`, + }, { + desc: `remove methods`, + info: map[string]overrideInfo{ + `Foo`: {}, + `Foo.GetBar`: {}, + `Foo.SetBar`: {}, + }, + src: `import "cmp" + + type Foo struct { + bar int + } + + func (x *Foo) GetBar() int { return x.bar } + func (x *Foo) SetBar(bar int) { x.bar = bar }`, + want: `import "cmp" + + type _ struct { + } + + func (x *Foo) _() int { return x.bar } + func (x *Foo) _(bar int) { x.bar = bar }`, + }, { + desc: `remove generics`, + info: map[string]overrideInfo{ + `Pointer`: {}, + `Sort`: {}, + `Equal`: {}, + }, + src: `import "cmp" + + type Pointer[T any] struct {} + + func Sort[S ~[]E, E cmp.Ordered](x S) {} + + // overlay had stub "func Equal() {}" + func Equal[S ~[]E, E any](s1, s2 S) bool {}`, + want: `import "cmp" + + type _ struct { + } + + func _[S ~[]E, E cmp.Ordered](x S) {} + + // overlay had stub "func Equal() {}" + func _[S ~[]E, E any](s1, s2 S) bool {}`, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + pkgName := "package testpackage\n\n" + importPath := `math/rand` + fsetSrc := token.NewFileSet() + fileSrc := srctesting.Parse(t, fsetSrc, pkgName+test.src) + + augmentOriginalImports(importPath, fileSrc) + augmentOriginalFile(fileSrc, test.info) + + buf := &bytes.Buffer{} + _ = printer.Fprint(buf, fsetSrc, fileSrc) + got := buf.String() + + fsetWant := token.NewFileSet() + fileWant := srctesting.Parse(t, fsetWant, pkgName+test.want) + + buf.Reset() + _ = printer.Fprint(buf, fsetWant, fileWant) + want := buf.String() + + if got != want { + t.Errorf("augmentOriginalImports, augmentOriginalFile, and pruneImports got unexpected code:\n"+ + "returned:\n\t%q\nwant:\n\t%q", got, want) + } + }) + } +}