package gentest_test

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"slices"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestCustomQueriesSyncedRowScan makes sure the manual custom queries in
// modelqueries.go stay in sync with the autogenerated queries.sql.go. This
// should probably be autogenerated, but it is not at the moment, and this check
// is an easy way to surface a clearer error message when the two drift apart.
//
// If this breaks, and is hard to fix, you can t.Skip() it. It is not a critical
// test. Ping @Emyrk to fix it again.
func TestCustomQueriesSyncedRowScan(t *testing.T) {
	t.Parallel()

	// Keys are the autogenerated functions parsed from queries.sql.go; values
	// are their hand-written counterparts parsed from modelqueries.go.
	funcsToTrack := map[string]string{
		"GetTemplatesWithFilter": "GetAuthorizedTemplates",
		"GetWorkspaces":          "GetAuthorizedWorkspaces",
		"GetUsers":               "GetAuthorizedUsers",
	}

	// Names of the custom functions to look for in modelqueries.go.
	custom := make([]string, 0, len(funcsToTrack))
	for _, fn := range funcsToTrack {
		custom = append(custom, fn)
	}

	customFns := parseFile(t, "../modelqueries.go", func(name string) bool {
		return slices.Contains(custom, name)
	})
	generatedFns := parseFile(t, "../queries.sql.go", func(name string) bool {
		_, ok := funcsToTrack[name]
		return ok
	})

	// Combine both parse results into a fresh map instead of writing into
	// customFns, so neither input map is mutated.
	merged := make(map[string]*parsedFunc, len(customFns)+len(generatedFns))
	for k, v := range customFns {
		merged[k] = v
	}
	for k, v := range generatedFns {
		merged[k] = v
	}

	for a, b := range funcsToTrack {
		a, b := a, b
		if !compareFns(t, a, b, merged[a], merged[b]) {
			//nolint:revive
			defer func() {
				// Run this at the end so the suggested fix is the last thing printed.
				t.Errorf("The functions %q and %q need to have identical 'rows.Scan()' "+
					"and 'db.QueryContext()' arguments in their function bodies. "+
					"Make sure to copy the function body from the autogenerated %q body. "+
					"Specifically the parameters for 'rows.Scan()' and 'db.QueryContext()'.", a, b, a)
			}()
		}
	}
}

// parsedFunc records the call arguments pulled out of a single function body.
// RowScanArgs holds the arguments given to rows.Scan() and QueryArgs holds the
// arguments given to db.QueryContext(); either stays nil when the matching
// call was not located.
type parsedFunc struct {
	RowScanArgs, QueryArgs []ast.Expr
}

// parseFile parses the Go source file at filename and returns, keyed by
// declaration name, the rows.Scan() and db.QueryContext() arguments of every
// top-level function or method declaration whose name trackFunc accepts.
// Receivers are not inspected, so a later declaration with the same name
// overwrites an earlier one. The test fails immediately if the file cannot be
// parsed.
func parseFile(t *testing.T, filename string, trackFunc func(name string) bool) map[string]*parsedFunc {
	t.Helper()

	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, filename, nil, parser.SkipObjectResolution)
	require.NoErrorf(t, err, "failed to parse file %q", filename)

	parsed := make(map[string]*parsedFunc)
	for _, decl := range f.Decls {
		fn, ok := decl.(*ast.FuncDecl)
		if !ok {
			continue
		}
		name := fn.Name.Name
		if !trackFunc(name) {
			continue
		}
		parsed[name] = &parsedFunc{
			RowScanArgs: pullRowScanArgs(fn),
			QueryArgs:   pullQueryArgs(fn),
		}
	}

	return parsed
}

// compareFns reports whether functions aName and bName pass the same arguments
// to rows.Scan() and db.QueryContext(). A nil a or b is reported as a missing
// function and yields false. Both comparisons always run so that every
// mismatch is reported in a single test run. Neither a nor b is modified.
func compareFns(t *testing.T, aName, bName string, a, b *parsedFunc) bool {
	t.Helper()

	if a == nil {
		t.Errorf("The function %q is missing", aName)
		return false
	}
	if b == nil {
		t.Errorf("The function %q is missing", bName)
		return false
	}
	r := compareArgs(t, "rows.Scan() arguments", aName, bName, a.RowScanArgs, b.RowScanArgs)

	// QueryContext's second argument is the query text, and its name always
	// differs: one side uses the const, the other a variable that is a mutation
	// of the original query. Align it before comparing. Index 1 exists whenever
	// there are at least two arguments (ctx, query), so ">= 2" rather than "> 2"
	// is needed for queries that take no parameters. Clone first so the parsed
	// AST's argument slice is not overwritten.
	aQueryArgs := a.QueryArgs
	if len(a.QueryArgs) >= 2 && len(b.QueryArgs) >= 2 {
		aQueryArgs = slices.Clone(a.QueryArgs)
		aQueryArgs[1] = b.QueryArgs[1]
	}
	q := compareArgs(t, "db.QueryContext() arguments", aName, bName, aQueryArgs, b.QueryArgs)
	return r && q
}

// compareArgs asserts that the argument lists a and b render to the same names
// via argList, labelling any mismatch with argType and both function names.
// It returns the result of the assertion.
func compareArgs(t *testing.T, argType string, aName, bName string, a, b []ast.Expr) bool {
	// Mark as a helper so assertion failures are attributed to the caller.
	t.Helper()
	return assert.Equal(t, argList(t, a), argList(t, b), "mismatched %s for %s and %s", argType, aName, bName)
}

// argList turns call arguments into comparable names:
//   - i.Field and &i.Field become "Field"
//   - x and &x become "x"
//   - pkg.Fn(a, b) becomes "Fn(2)"; any other call becomes "call(N)"
//
// Any other expression is reported as a test error and recorded as "unknown".
// It returns nil for an empty args.
func argList(t *testing.T, args []ast.Expr) []string {
	t.Helper()

	var argNames []string
	for _, arg := range args {
		argname := "unknown"
		switch e := arg.(type) {
		case *ast.UnaryExpr:
			// This is "&i.Arg" style stuff. Checked assertions avoid the panic
			// the previous unchecked assertion raised for "&x".
			switch x := e.X.(type) {
			case *ast.SelectorExpr:
				argname = x.Sel.Name
			case *ast.Ident:
				argname = x.Name
			}
		case *ast.Ident:
			argname = e.Name
		case *ast.SelectorExpr:
			argname = e.Sel.Name
		case *ast.CallExpr:
			// Eh, this is pg.Array style stuff. Do a best effort.
			argname = fmt.Sprintf("call(%d)", len(e.Args))
			if fnCall, ok := e.Fun.(*ast.SelectorExpr); ok {
				argname = fmt.Sprintf("%s(%d)", fnCall.Sel.Name, len(e.Args))
			}
		}

		if argname == "unknown" {
			t.Errorf("Unknown arg, cannot parse: %T", arg)
		}
		argNames = append(argNames, argname)
	}
	return argNames
}

// pullQueryArgs returns the arguments of the QueryContext call in the first
// top-level statement of fn's body shaped like
//
//	rows, err := <expr>.QueryContext(...)
//
// (a plain "=" assignment to rows, err also matches). It returns nil when fn
// has no body or no such statement exists.
func pullQueryArgs(fn *ast.FuncDecl) []ast.Expr {
	if fn.Body == nil {
		// Bodyless declarations (e.g. implemented elsewhere) have nothing to scan.
		return nil
	}
	for _, stmt := range fn.Body.List {
		// find "rows, err :="
		assign, ok := stmt.(*ast.AssignStmt)
		if !ok || len(assign.Lhs) != 2 || len(assign.Rhs) == 0 {
			continue
		}
		if id, ok := assign.Lhs[0].(*ast.Ident); !ok || id.Name != "rows" {
			continue
		}
		// Skip non-call right-hand sides (e.g. "rows, err := a, b") instead of
		// panicking on an unchecked type assertion.
		query, ok := assign.Rhs[0].(*ast.CallExpr)
		if !ok {
			continue
		}
		if qSel, ok := query.Fun.(*ast.SelectorExpr); ok && qSel.Sel.Name == "QueryContext" {
			return query.Args
		}
	}
	return nil
}

// pullRowScanArgs returns the arguments of the Scan call found inside a
// top-level for loop of fn's body, matching the shape
//
//	for rows.Next() {
//		var i T
//		if err := rows.Scan(&i.A, &i.B); err != nil {
//			...
//		}
//	}
//
// Every statement of each top-level for loop is searched, so the if statement
// need not be at a fixed position. It returns nil when fn has no body or no
// such call is found.
func pullRowScanArgs(fn *ast.FuncDecl) []ast.Expr {
	if fn.Body == nil {
		return nil
	}
	for _, stmt := range fn.Body.List {
		forStmt, ok := stmt.(*ast.ForStmt)
		if !ok {
			continue
		}
		for _, inner := range forStmt.Body.List {
			if args, ok := rowScanArgsFromIf(inner); ok {
				return args
			}
		}
	}
	return nil
}

// rowScanArgsFromIf returns the call arguments when stmt is an if statement
// whose init is an assignment from a ".Scan(...)" call, e.g.
// "if err := rows.Scan(...); err != nil". Every assertion is checked, so
// unexpected shapes yield (nil, false) instead of panicking.
func rowScanArgsFromIf(stmt ast.Stmt) ([]ast.Expr, bool) {
	ifStmt, ok := stmt.(*ast.IfStmt)
	if !ok {
		return nil, false
	}
	// Init is nil for a plain "if cond {"; the checked assertion handles that.
	assign, ok := ifStmt.Init.(*ast.AssignStmt)
	if !ok || len(assign.Rhs) == 0 {
		return nil, false
	}
	call, ok := assign.Rhs[0].(*ast.CallExpr)
	if !ok {
		return nil, false
	}
	sel, ok := call.Fun.(*ast.SelectorExpr)
	if !ok || sel.Sel.Name != "Scan" {
		return nil, false
	}
	return call.Args, true
}
