package main

import (
	"fmt"
	"github.com/janpfeifer/must"
	"os"
	"os/exec"
	"slices"
	"text/template"
)

// Output file names for the generated simple ops and their test. They are relative paths, so the files
// are created in the current working directory of the generator.
const (
	simpleGoOpsFileName     = "gen_simple_ops.go"
	simpleGoOpsTestFileName = "gen_simple_ops_test.go"
)

var (
	// simpleGoOpsTemplate renders one Go function per op of type "one" (unary) or "two"/"two_cmp" (binary).
	// It is executed with a slice of op descriptions exposing .Name, .Type and .Comments.
	simpleGoOpsTemplate = template.Must(template.New(simpleGoOpsFileName).Parse(`
/***** File generated by gopjrt/internal/cmd/xlabuilder_codegen, based on op_types.txt. Don't edit it directly. *****/

package xlabuilder

import (
	"github.com/pkg/errors"
)

{{range .}}{{if eq .Type "one"}}
{{range .Comments}}// {{.}}
{{end}}// The op is created on the same XlaBuilder as used for x.
func {{.Name}}(x *Op) (*Op, error) {
	builder := x.builder
	y := newOp({{.Name}}Op, x)
	err := builder.addOp(y)
	if err != nil {
		return nil, err
	}
	return y, nil
}
{{end}}{{if or (eq .Type "two") (eq .Type "two_cmp")}}
{{range .Comments}}// {{.}}
{{end}}// The op is created on the same XlaBuilder as used for x0 and x1.
func {{.Name}}(x0, x1 *Op) (*Op, error) {
	if x0.builder != x1.builder {
		return nil, errors.Errorf("arguments of {{.Name}}(x0, x1) come from different XlaBuilder objects (%q and %q)", x0.builder.Name(), x1.builder.Name())
	}
	if x0.Shape.DType != x1.Shape.DType {
		return nil, errors.Errorf("dtype of first (%s) and second (%s) operands don't match", x0.Shape.DType, x1.Shape.DType)
	}
	builder := x0.builder
	y := newOp({{.Name}}Op, x0, x1)
	err := builder.addOp(y)
	if err != nil {
		return nil, err
	}
	return y, nil
}
{{end}}{{end}}`))

	// simpleGoOpsTestTemplate renders a smoke test that chains every (non-skipped) simple op into a single
	// computation. Every op result is folded into the final "result" node, so building the computation
	// exercises all of them. It is executed with the same kind of data as simpleGoOpsTemplate.
	simpleGoOpsTestTemplate = template.Must(template.New(simpleGoOpsTestFileName).Parse(`
/***** File generated by gopjrt/internal/cmd/xlabuilder_codegen, based on op_types.txt. Don't edit it directly. *****/

package xlabuilder_test

import (
	"fmt"
	"github.com/stretchr/testify/require"
	"github.com/gomlx/gopjrt/dtypes"
	. "github.com/gomlx/gopjrt/xlabuilder"
	"testing"
)

// TestSimpleOps simply concatenates all unary and then binary ops in a nonsensical computation, just to check
// that the HLO proto is actually generated -- and that all simple ops run.
func TestSimpleOps(t *testing.T) {
	builder := New("simple_ops_test")
	x, err := Parameter(builder, "x", 0, MakeShape(dtypes.F32)) // Scalar float32.
	require.NoError(t, err, "Failed to create parameter x")
	i, err := Parameter(builder, "i", 1, MakeShape(dtypes.S32)) // Scalar int32.
	require.NoError(t, err, "Failed to create parameter i")
	v, err := Parameter(builder, "v", 2, MakeShape(dtypes.F32, 2))
	require.NoError(t, err, "Failed to create parameter v")

	// Unary ops:{{range .}}{{if eq .Type "one"}}
	x, err = {{.Name}}(x)
	require.NoError(t, err, "Failed to build unary operation {{.Name}}"){{end}}{{end}}

	// Binary ops:{{range .}}{{if eq .Type "two"}}
	x, err = {{.Name}}(x, x)
	require.NoError(t, err, "Failed to build binary operation {{.Name}}"){{end}}{{end}}

	// Binary-comparison op.
	var result, cmp *Op{{range .}}{{if eq .Type "two_cmp"}}
	cmp, err = {{.Name}}(x, x)
	require.NoError(t, err, "Failed to build binary comparison operation {{.Name}}")
	if result == nil {
		result = cmp
	} else {
		result, err = LogicalAnd(result, cmp)
		require.NoError(t, err, "Failed to build logical And operation when aggregating the result from {{.Name}}")
	}{{end}}{{end}}

	// Other ops not tested above.
	x, err = Dot(v, v)
	require.NoError(t, err, "Failed to build Dot operation")
	var c *Op
	c, err = Complex(x, x)
	require.NoError(t, err, "Failed to build Complex operation")
	imgV, err := Imag(c)
	require.NoError(t, err, "Failed to build Imag operation")
	realV, err := Real(c)
	require.NoError(t, err, "Failed to build Real operation")
	same, err := Equal(imgV, realV)
	require.NoError(t, err, "Failed to build Equal operation")
	result, err = LogicalAnd(result, same)
	require.NoError(t, err, "Failed to build And operation")

	// IsFinite is aggregated into result (instead of replacing it), so the ops above stay in the graph.
	finite, err := IsFinite(x)
	require.NoError(t, err, "Failed to build IsFinite operation")
	require.Equal(t, dtypes.Bool, finite.Shape.DType, "IsFinite should return booleans")
	result, err = LogicalAnd(result, finite)
	require.NoError(t, err, "Failed to build And operation")

	i, err = Clz(i)
	require.NoError(t, err, "Failed to build Clz operation")

	i, err = PopulationCount(i)
	require.NoError(t, err, "Failed to build PopulationCount operation")
	same, err = Equal(i, i)
	require.NoError(t, err, "Failed to build Equal operation")

	result, err = BitwiseOr(result, same)
	require.NoError(t, err, "Failed to build Or operation")
	result, err = BitwiseAnd(result, same)
	require.NoError(t, err, "Failed to build And operation")
	result, err = BitwiseXor(result, same)
	require.NoError(t, err, "Failed to build Xor operation")
	i, err = BitwiseNot(i)
	require.NoError(t, err, "Failed to build unary operation BitwiseNot")

	result, err = LogicalOr(result, same)
	require.NoError(t, err, "Failed to build Or operation")
	result, err = LogicalAnd(result, same)
	require.NoError(t, err, "Failed to build And operation")
	result, err = LogicalXor(result, same)
	require.NoError(t, err, "Failed to build Xor operation")
	i, err = LogicalNot(i)
	require.NoError(t, err, "Failed to build unary operation LogicalNot")

	// Fold the final integer ops (BitwiseNot, LogicalNot) into result as well.
	same, err = Equal(i, i)
	require.NoError(t, err, "Failed to build Equal operation")
	result, err = LogicalAnd(result, same)
	require.NoError(t, err, "Failed to build And operation")

	// Get computation created: result depends on all of them.
	comp, err := builder.Build(result)
	require.NoError(t, err, "Failed to build the computation after setting all the ops")
	fmt.Printf("HloModule proto:\n%s\n\n", comp.TextHLO())
}
`))
)

// GenerateSimpleGoOps generates the Go code for the simple-to-implement ops into simpleGoOpsFileName,
// plus a smoke test for them in simpleGoOpsTestFileName. Both files are formatted with gofmt after being
// written. Currently supported op types (see op_types.txt):
//
//   - one (one arg)
//   - two (two args)
//   - two_cmp (two-arg comparison ops)
//
// opsInfo is not modified. Errors are not returned: they are passed to the must helpers.
func GenerateSimpleGoOps(opsInfo []OpInfo) {
	// writeGoFile renders tmpl with data into fileName, closes the file and runs gofmt on it.
	// It is a local closure to avoid clashing with helpers in other files of this package.
	writeGoFile := func(fileName string, tmpl *template.Template, data any) {
		f := must.M1(os.Create(fileName))
		execErr := tmpl.Execute(f, data)
		// Close before checking execErr, so the descriptor is never leaked, and before gofmt
		// rewrites the file in place.
		closeErr := f.Close()
		must.M(execErr)
		must.M(closeErr)
		must.M(exec.Command("gofmt", "-w", fileName).Run())
		fmt.Printf("✅ Successfully generated %q based on %q\n", fileName, OpTypesFileName)
	}

	writeGoFile(simpleGoOpsFileName, simpleGoOpsTemplate, opsInfo)

	// For testing we skip some that require special types.
	skip := []string{
		"LogicalNot", "LogicalAnd", "LogicalOr", "LogicalXor",
		"BitwiseNot", "BitwiseAnd", "BitwiseOr", "BitwiseXor",
		"Dot", "Clz", "Real", "Imag", "Conj", "Complex", "IsFinite", "PopulationCount"}
	// slices.DeleteFunc filters in place (shifting elements and zeroing the freed tail), so filter a
	// copy to leave the caller's opsInfo intact.
	filteredOps := slices.DeleteFunc(slices.Clone(opsInfo), func(info OpInfo) bool {
		return slices.Contains(skip, info.Name)
	})
	writeGoFile(simpleGoOpsTestFileName, simpleGoOpsTestTemplate, filteredOps)
}
