package main

import (
	"fmt"
	"github.com/janpfeifer/must"
	"os"
	"regexp"
	"text/template"
)

const (
	// apiCallsCFileName is the generated C source holding one call_<Name> wrapper per PJRT API method.
	apiCallsCFileName = "gen_api_calls.c"
	// apiCallsHFileName is the generated header declaring the wrappers defined in apiCallsCFileName.
	apiCallsHFileName = "gen_api_calls.h"
)

var (
	// reApiMethods matches the function-type typedefs that declare PJRT API methods, e.g.:
	//
	//	// Comment lines directly above the typedef are captured too.
	//	typedef PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args);
	//
	// Submatch indexes consumed by generateAPICalls:
	//   - 2: the "//" comment lines immediately preceding the typedef (each newline-terminated);
	//   - 4: the return type (optional "const", optional "struct", an identifier, optional "*");
	//   - 8: the method (typedef) name;
	//   - 9: the argument list, with surrounding whitespace trimmed.
	//
	// The typedef itself must fit on a single line, since "." does not match newlines.
	// PJRT_Api is not matched (it is declared differently), and no wrapper is needed for it.
	reApiMethods = regexp.MustCompile(
		`(?m)(((^//.*$\n)*)` + // Match preceding comments.
			`^typedef\s+((const\s+)?(struct\s+)?\w+(\s*\*)?)` + // Return type.
			`\s+(\w+)\(` + // Function name.
			`\s*(.*?)\s*\);$` + // Arguments.
			`)`) // Close definition.

	// apiCallsCTemplate renders apiCallsCFileName from a slice of values exposing
	// Name, Comments, Return and Args.
	//
	// Notice all methods take one argument, and the name of it is "args" -- while the type varies.
	// Methods whose return type is "void" are called without "return": ISO C forbids a
	// return statement with an expression in a function returning void.
	apiCallsCTemplate = template.Must(template.New(apiCallsCFileName).Parse(`
/***** File generated by ./cmd/codegen, don't edit it directly. *****/

#include <stdlib.h>
#include "pjrt_c_api.h"
#include "gen_api_calls.h"

{{range .}}
// call_{{.Name}} calls the corresponding PJRT API method.
{{.Comments}}{{.Return}} call_{{.Name}}(const PJRT_Api *api, {{.Args}}) {
	{{if ne .Return "void"}}return {{end}}api->{{.Name}}(args);
}
{{end}}
`))

	// apiCallsHTemplate renders apiCallsHFileName: extern declarations (inside an
	// include guard and an extern "C" block for C++) for the wrappers in apiCallsCTemplate.
	apiCallsHTemplate = template.Must(template.New(apiCallsHFileName).Parse(`
/***** File generated by ./cmd/codegen, don't edit it directly. *****/

#ifndef GOMLX_GOPJRT_GEN_API_CALLS
#define GOMLX_GOPJRT_GEN_API_CALLS
#include "pjrt_c_api.h"

#ifdef __cplusplus
extern "C" {
#endif
{{range .}}
// call_{{.Name}} calls the corresponding PJRT API method.
{{.Comments}}extern {{.Return}} call_{{.Name}}(const PJRT_Api *api, {{.Args}});
{{end}}

#ifdef __cplusplus
}
#endif
#endif
`))
)

// apiCallInfo carries what the code templates need to emit one call_<Name>
// wrapper, as extracted from a single method typedef in pjrt_c_api.h.
type apiCallInfo struct {
	// Name is the typedef'd method name; the templates also use it as the
	// PJRT_Api member through which the method is invoked (api->Name).
	Name string
	// Comments holds the "//" comment lines found right above the typedef,
	// copied verbatim in front of the generated wrapper.
	Comments string
	// Return is the C return type of the method, as written in the header.
	Return string
	// Args is the raw text of the method's argument list.
	Args string
}

func generateAPICalls(contents string) {
	var allInfo []apiCallInfo
	for _, matches := range reApiMethods.FindAllStringSubmatch(contents, -1) {
		info := apiCallInfo{
			Name:     matches[8],
			Comments: matches[2],
			Return:   matches[4],
			Args:     matches[9],
		}
		allInfo = append(allInfo, info)
	}

	f := must.M1(os.Create(apiCallsCFileName))
	must.M(apiCallsCTemplate.Execute(f, allInfo))
	fmt.Printf("✅ Successfully generated %q based on pjrt_c_api.h\n", apiCallsCFileName)

	f = must.M1(os.Create(apiCallsHFileName))
	must.M(apiCallsHTemplate.Execute(f, allInfo))
	fmt.Printf("✅ Successfully generated %q based on pjrt_c_api.h\n", apiCallsHFileName)
}
