package main

import (
	"fmt"
	"os"
	"regexp"
	"slices"
	"text/template"

	"github.com/janpfeifer/must"
)

// Names of the files written by the new-struct code generator (relative paths, so they are
// created in the current working directory).
const (
	// NewStructCFileName is the generated C source defining one new_<Struct>() allocator per struct.
	NewStructCFileName = "gen_new_struct.c"

	// NewStructHFileName is the generated C header declaring the new_<Struct>() allocators.
	NewStructHFileName = "gen_new_struct.h"
)

// cStructInfo holds what the templates need to emit the allocator for a single C struct.
type cStructInfo struct {
	// Name is the C struct identifier, e.g. "PJRT_Client_Create_Args".
	Name string

	// Comments are the "//" lines that directly precede the struct definition, newlines included;
	// they are copied verbatim in front of the generated allocator.
	Comments string

	// HasStructSize tells the templates whether the allocator should set the struct_size field.
	HasStructSize bool
}

var (
	// reStructs matches every C struct definition that is immediately followed by a
	// PJRT_DEFINE_STRUCT_TRAITS(...) line. Capture groups of interest:
	//   - group 2: the "//" comment lines directly preceding the struct (possibly empty);
	//   - group 4: the struct name.
	// PJRT_Api is not matched because it is defined differently, but no allocator is needed for it.
	reStructs = regexp.MustCompile(
		`(?m)(((^//.*$\n)*)` + // Match preceding comments.
			`^struct (\w+) \{$\n` + // Start of the struct.
			`(^.*$\n)*?^};$\n` + // Struct body, matched lazily up to the first closing "};".
			`^PJRT_DEFINE_STRUCT_TRAITS\((\w+),\s+(\w+)\);$\n)`) // End of the struct.

	// newStructCTemplate renders the C source: one allocator per cStructInfo.
	// Memory is always zero-initialized (calloc), and struct_size is set only when HasStructSize is true.
	// A failed allocation returns NULL instead of dereferencing it.
	newStructCTemplate = template.Must(template.New(NewStructCFileName).Parse(`
/***** File generated by ./cmd/codegen, don't edit it directly. *****/

#include <stdlib.h>
#include <string.h>
#include "pjrt_c_api.h"
#include "gen_new_struct.h"

{{range .}}
// new_{{.Name}} allocates a zero-initialized C.{{.Name}} structure{{if .HasStructSize}}, sets its .struct_size,{{end}} and returns it.
// It returns NULL if the allocation fails. The memory must be released with free().
{{.Comments}}{{.Name}}* new_{{.Name}}(void) {
	{{.Name}}* p = calloc(1, sizeof({{.Name}}));
{{- if .HasStructSize}}
	if (p != NULL) {
		p->struct_size = {{.Name}}_STRUCT_SIZE;
	}
{{- end}}
	return p;
}
{{end}}
`))

	// newStructHTemplate renders the C header declaring the allocators emitted by newStructCTemplate.
	// The prototypes use "(void)" so C compilers treat them as taking no arguments.
	newStructHTemplate = template.Must(template.New(NewStructHFileName).Parse(`
/***** File generated by ./cmd/codegen, don't edit it directly. *****/

#ifndef GOMLX_GOPJRT_GEN_NEW_STRUCT
#define GOMLX_GOPJRT_GEN_NEW_STRUCT
#include "pjrt_c_api.h"

#ifdef __cplusplus
extern "C" {
#endif
{{range .}}
// new_{{.Name}} allocates a zero-initialized C.{{.Name}} structure{{if .HasStructSize}}, sets its .struct_size,{{end}} and returns it.
{{.Comments}}extern {{.Name}}* new_{{.Name}}(void);
{{end}}

#ifdef __cplusplus
}
#endif
#endif
`))
)

// generateNewStruct scans contents (the text of pjrt_c_api.h) for the struct definitions matched by
// reStructs. It writes NewStructCFileName and NewStructHFileName, with one new_<Struct>() allocator
// per matched struct.
//
// Any error creating, rendering or closing the output files panics (via must).
func generateNewStruct(contents string) {
	// Structs treated as having no struct_size field: their allocators must not set it.
	withoutStructSize := []string{"PJRT_SendCallbackInfo", "PJRT_RecvCallbackInfo", "PJRT_ProcessInfo"}

	var allInfo []cStructInfo
	for _, m := range reStructs.FindAllStringSubmatch(contents, -1) {
		name := m[4] // Struct name.
		allInfo = append(allInfo, cStructInfo{
			Name:          name,
			Comments:      m[2], // Comment lines directly preceding the struct.
			HasStructSize: !slices.Contains(withoutStructSize, name),
		})
	}

	renderNewStructFile(NewStructCFileName, newStructCTemplate, allInfo)
	renderNewStructFile(NewStructHFileName, newStructHTemplate, allInfo)
}

// renderNewStructFile creates fileName, executes tmpl with allInfo into it and closes it. It panics
// (via must) on any error and prints a confirmation on success.
func renderNewStructFile(fileName string, tmpl *template.Template, allInfo []cStructInfo) {
	f := must.M1(os.Create(fileName))
	err := tmpl.Execute(f, allInfo)
	// Always close the file so the descriptor is not leaked. Report the Execute error first;
	// otherwise report any Close error, so a truncated file is never announced as a success.
	if closeErr := f.Close(); err == nil {
		err = closeErr
	}
	must.M(err)
	fmt.Printf("✅ Successfully generated %q based on pjrt_c_api.h\n", fileName)
}
