package main

import (
	"errors"
	"fmt"
	"io"
	"slices"
	"strings"

	routingspb "go.temporal.io/server/api/routing/v1"
	"go.temporal.io/server/cmd/tools/codegen"
	"google.golang.org/protobuf/compiler/protogen"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/descriptorpb"
	"google.golang.org/protobuf/types/pluginpb"
)

// generatedFilenameExtension is appended to a proto file's
// GeneratedFilenamePrefix to name the client file emitted for it.
const (
	generatedFilenameExtension = "_client.pb.go"
)

// writer accumulates generated Go source in memory. Every fragment written by
// print is prefixed with one tab per current indentation level.
type writer struct {
	builder     strings.Builder
	indentation int
}

// print writes the current indentation followed by f formatted with args.
// No trailing newline is added.
func (w *writer) print(f string, args ...any) {
	// strings.Builder writes never fail, so their errors are deliberately dropped.
	for n := w.indentation; n > 0; n-- {
		_ = w.builder.WriteByte('\t')
	}
	_, _ = fmt.Fprintf(&w.builder, f, args...)
}

// println behaves like print and then terminates the line.
func (w *writer) println(f string, args ...any) {
	w.print(f, args...)
	_ = w.builder.WriteByte('\n')
}

// indent increases the indentation applied to subsequent output by one level.
func (w *writer) indent() {
	w.indentation++
}

// unindent decreases the indentation by one level. It panics when the
// indentation is already zero, since that means indent/unindent calls in the
// generator are unbalanced.
func (w *writer) unindent() {
	if w.indentation > 0 {
		w.indentation--
		return
	}
	// nolint: forbidigo
	panic("unmatched unindent")
}

// Plugin generates layered history clients for CHASM services. The embedded
// *protogen.Plugin is nil until Run is invoked, at which point Run stores the
// plugin it was handed.
type Plugin struct {
	*protogen.Plugin
}

// New returns an empty Plugin whose Run method serves as the protogen entry
// point.
func New() *Plugin {
	return &Plugin{}
}

// Run is the protogen entry point. It first declares the protobuf features and
// editions this plugin supports.
//
// It then considers every file protoc asked to generate. A file is handled
// when its Go import path contains go.temporal.io/server/chasm/lib and it
// declares at least one service. For each such file, Run writes
// <prefix>_client.pb.go containing a layered client for every service.
//
// Run returns the first error encountered while generating a client or writing
// an output file.
func (p *Plugin) Run(plugin *protogen.Plugin) error {
	plugin.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_SUPPORTS_EDITIONS | pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)
	plugin.SupportedEditionsMinimum = descriptorpb.Edition_EDITION_PROTO3
	plugin.SupportedEditionsMaximum = descriptorpb.Edition_EDITION_2023
	p.Plugin = plugin

	for _, file := range plugin.Files {
		if !file.Generate || len(file.Services) == 0 {
			continue
		}
		if !strings.Contains(string(file.GoImportPath), "go.temporal.io/server/chasm/lib") {
			continue
		}
		generatedFile := p.NewGeneratedFile(file.GeneratedFilenamePrefix+generatedFilenameExtension, file.GoImportPath)

		w := &writer{}
		writeClientFileHeader(w, file)
		for _, svc := range file.Services {
			if err := p.genClient(w, svc); err != nil {
				return err
			}
		}

		if _, err := io.WriteString(generatedFile, w.builder.String()); err != nil {
			return err
		}
	}

	return nil
}

// writeClientFileHeader writes the generated-code marker, the package clause
// and the import block of a generated client file.
//
// Go rejects both missing and unused imports. The list below must therefore
// match exactly what the generated clients reference. math/rand is only needed
// when some method routes to a random shard.
func writeClientFileHeader(w *writer, file *protogen.File) {
	w.println("// Code generated by protoc-gen-go-chasm. DO NOT EDIT.")
	// A blank line keeps the marker from becoming the package doc comment.
	w.println("")
	w.println("package %s", file.GoPackageName)
	w.println("")
	w.println("import (")
	w.indent()
	w.println(`"context"`)
	if usesRandomRouting(file) {
		w.println(`"math/rand"`)
	}
	w.println(`"time"`)
	w.println("")
	w.println(`"go.temporal.io/server/client/history"`)
	w.println(`"go.temporal.io/server/common"`)
	w.println(`"go.temporal.io/server/common/backoff"`)
	w.println(`"go.temporal.io/server/common/config"`)
	w.println(`"go.temporal.io/server/common/dynamicconfig"`)
	w.println(`"go.temporal.io/server/common/headers"`)
	w.println(`"go.temporal.io/server/common/log"`)
	w.println(`"go.temporal.io/server/common/membership"`)
	w.println(`"go.temporal.io/server/common/metrics"`)
	w.println(`"go.temporal.io/server/common/primitives"`)
	w.println(`"google.golang.org/grpc"`)
	w.unindent()
	w.println(")")
}

// usesRandomRouting reports whether any method in file carries routing options
// with random set. Lookup errors are ignored here because genAssignShard
// surfaces them when the method itself is generated.
func usesRandomRouting(file *protogen.File) bool {
	for _, svc := range file.Services {
		for _, m := range svc.Methods {
			if opts, err := routingOptions(m); err == nil && opts != nil && opts.Random {
				return true
			}
		}
	}
	return false
}

// genAssignShard returns the Go statement that declares shardID for a request
// to m. The statement is emitted inside a generated client method, where c is
// the client and request is the input.
//
// The shard is derived from m's routing options:
//   - random: int32(rand.Intn(c.numShards) + 1), i.e. a shard in [1, c.numShards].
//   - otherwise: common.WorkflowIDToHistoryShard over the namespace ID field
//     (namespace_id unless overridden) and the required business_id field.
//
// An error is returned in any of these cases:
//   - the routing options cannot be read, or none are set;
//   - random is combined with namespace_id or business_id;
//   - business_id is empty;
//   - a field path does not resolve on m's input message.
func genAssignShard(m *protogen.Method) (string, error) {
	opts, err := routingOptions(m)
	if err != nil {
		return "", err
	}
	if opts == nil {
		return "", fmt.Errorf("no routing directive specified on %s", m.Desc.FullName())
	}
	if opts.Random {
		if opts.NamespaceId != "" || opts.BusinessId != "" {
			return "", fmt.Errorf("random directive cannot be combined with namespace_id or business_id on %s", m.Desc.FullName())
		}
		return "shardID := int32(rand.Intn(int(c.numShards)) + 1)", nil
	}
	if opts.BusinessId == "" {
		return "", fmt.Errorf("business_id directive empty on %s", m.Desc.FullName())
	}

	namespaceIDField := opts.NamespaceId
	if namespaceIDField == "" {
		namespaceIDField = "namespace_id"
	}

	namespaceIDGetter, err := goFieldPath(m, namespaceIDField)
	if err != nil {
		return "", fmt.Errorf("unable to resolve namespace_id field path %q on %s: %w", namespaceIDField, m.Desc.FullName(), err)
	}
	businessIDGetter, err := goFieldPath(m, opts.BusinessId)
	if err != nil {
		return "", fmt.Errorf("unable to resolve business_id field path %q on %s: %w", opts.BusinessId, m.Desc.FullName(), err)
	}

	return fmt.Sprintf("shardID := common.WorkflowIDToHistoryShard(request%s, request%s, c.numShards)", namespaceIDGetter, businessIDGetter), nil
}

// goFieldPath turns a dot-separated proto field path into a chain of generated
// getter calls, such as ".GetFoo().GetBar()". The path is resolved relative to
// m's input message, and the result is appended to the request variable in
// generated code.
//
// Segments are matched by converting them with codegen.SnakeCaseToPascalCase
// and comparing against each field's GoName.
//
// Path rules:
//   - every segment except the last must name a message field;
//   - no segment may be a repeated or map field, because getters for those
//     return slices or maps that have no further getters.
//
// An error is returned for an empty path or for any segment that violates
// these rules or cannot be found.
func goFieldPath(m *protogen.Method, path string) (string, error) {
	if path == "" {
		return "", errors.New("empty field path")
	}
	parts := strings.Split(path, ".")
	msg := m.Input
	var goPath strings.Builder
	for i, part := range parts {
		fieldName := codegen.SnakeCaseToPascalCase(part)
		idx := slices.IndexFunc(msg.Fields, func(f *protogen.Field) bool {
			return f.GoName == fieldName
		})
		if idx < 0 {
			return "", fmt.Errorf("field %s not found in %s", part, msg.Desc.FullName())
		}
		field := msg.Fields[idx]
		if field.Desc.IsList() || field.Desc.IsMap() {
			return "", fmt.Errorf("field %s in %s is repeated or a map and cannot be used in a routing path", part, msg.Desc.FullName())
		}
		// Only message fields can be traversed further. Without this check the
		// next iteration would dereference a nil message.
		if i < len(parts)-1 && field.Message == nil {
			return "", fmt.Errorf("field %s in %s is not a message, cannot resolve %q", part, msg.Desc.FullName(), path)
		}

		goPath.WriteString(".Get")
		goPath.WriteString(field.GoName)
		goPath.WriteString("()")
		msg = field.Message
	}
	return goPath.String(), nil
}

// routingOptions returns the routing extension attached to m's method options.
//
// It returns (nil, nil) when the extension is not set, so callers can report a
// method-specific "no routing directive" error. It returns an error only when
// the extension is present but does not hold a *routingspb.RoutingOptions.
func routingOptions(m *protogen.Method) (*routingspb.RoutingOptions, error) {
	methodOpts := m.Desc.Options()
	// GetExtension on an unset message extension yields a default value rather
	// than an error. Check presence explicitly so "not set" is unambiguous.
	if !proto.HasExtension(methodOpts, routingspb.E_Routing) {
		return nil, nil
	}
	opts, ok := proto.GetExtension(methodOpts, routingspb.E_Routing).(*routingspb.RoutingOptions)
	if !ok {
		return nil, errors.New("routing extension on " + string(m.Desc.FullName()) + " is not a RoutingOptions message")
	}
	return opts, nil
}

// genClient writes a layered client for svc into w. The generated struct
// routes every RPC to a history shard through a history.Redirector, records
// client request/failure/latency metrics, and retries transient errors.
//
// For each method two functions are generated:
//   - call<Method>NoRetry: one routed attempt, with metrics and a
//     history.DefaultTimeout deadline.
//   - <Method>: the public entry point, which retries the attempt via
//     backoff.ThrottleRetryContextWithReturn.
//
// It returns an error if a method is streaming or has an invalid routing
// directive.
func (p *Plugin) genClient(w *writer, svc *protogen.Service) error {
	structName := fmt.Sprintf("%sLayeredClient", svc.GoName)
	genClientStruct(w, svc, structName)
	genClientConstructor(w, svc, structName)
	for _, method := range svc.Methods {
		if err := genClientMethod(w, svc, structName, method); err != nil {
			return err
		}
	}
	return nil
}

// genClientStruct writes the declaration of the client struct.
func genClientStruct(w *writer, svc *protogen.Service, structName string) {
	w.println("// %s is a client for %s.", structName, svc.GoName)
	w.println("type %s struct {", structName)
	w.indent()
	w.println("metricsHandler metrics.Handler")
	w.println("numShards      int32")
	w.println("redirector     history.Redirector[%sClient]", svc.GoName)
	w.println("retryPolicy    backoff.RetryPolicy")
	w.unindent()
	w.println("}")
}

// genClientConstructor writes New<structName>. The generated constructor
// resolves history hosts through the membership monitor. It then uses a
// caching redirector when HistoryClientOwnershipCachingEnabled is on, and a
// basic redirector otherwise.
func genClientConstructor(w *writer, svc *protogen.Service, structName string) {
	ctorName := "New" + structName
	w.println("// %s initializes a new %s.", ctorName, structName)
	w.println("func %s(", ctorName)
	w.indent()
	w.println("dc *dynamicconfig.Collection,")
	w.println("rpcFactory     common.RPCFactory,")
	w.println("monitor        membership.Monitor,")
	w.println("config         *config.Persistence,")
	w.println("logger         log.Logger,")
	w.println("metricsHandler metrics.Handler,")
	w.unindent()
	w.println(") (%sClient, error) {", svc.GoName)
	w.indent() // start ctor body
	w.println("resolver, err := monitor.GetResolver(primitives.HistoryService)")
	w.println("if err != nil {")
	w.indent()
	w.println("return nil, err")
	w.unindent()
	w.println("}")
	w.println("connections := history.NewConnectionPool(resolver, rpcFactory, New%sClient)", svc.GoName)
	w.println("var redirector history.Redirector[%sClient]", svc.GoName)
	w.println("if dynamicconfig.HistoryClientOwnershipCachingEnabled.Get(dc)() {")
	w.indent() // start if
	w.println("redirector = history.NewCachingRedirector(")
	w.indent() // start args
	w.println("connections,")
	w.println("resolver,")
	w.println("logger,")
	w.println("dynamicconfig.HistoryClientOwnershipCachingStaleTTL.Get(dc),")
	w.unindent() // close args
	w.println(")")
	w.unindent() // close if
	w.println("} else {")
	w.indent() // start else
	w.println("redirector = history.NewBasicRedirector(connections, resolver)")
	w.unindent() // close else
	w.println("}")
	w.println("return &%s{", structName)
	w.indent() // start struct literal
	w.println("metricsHandler: metricsHandler,")
	w.println("redirector:     redirector,")
	w.println("numShards:      config.NumHistoryShards,")
	w.println("retryPolicy:    common.CreateHistoryClientRetryPolicy(),")
	w.unindent() // close struct literal
	w.println("}, nil")
	w.unindent() // close ctor body
	w.println("}")
}

// genClientMethod writes call<Method>NoRetry and <Method> for method. The
// generated code only handles unary RPCs, so streaming methods are rejected
// instead of producing code that cannot compile.
func genClientMethod(w *writer, svc *protogen.Service, structName string, method *protogen.Method) error {
	if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
		return fmt.Errorf("streaming method %s is not supported", method.Desc.FullName())
	}
	// Resolve routing up front so an invalid directive is reported before any
	// of this method's code is written.
	assignShard, err := genAssignShard(method)
	if err != nil {
		return err
	}
	input := method.Input.GoIdent.GoName
	output := method.Output.GoIdent.GoName

	// call<Method>NoRetry: a single routed attempt, instrumented with metrics.
	w.println("func (c *%s) call%sNoRetry(", structName, method.GoName)
	w.indent()
	w.println("ctx context.Context,")
	w.println("request *%s,", input)
	w.println("opts ...grpc.CallOption,")
	w.unindent()
	w.println(") (*%s, error) {", output)
	w.indent()
	w.println("var response *%s", output)
	w.println("var err error")
	w.println("startTime := time.Now().UTC()")
	w.println("// the caller is a namespace, hence the tag below.")
	w.println("caller := headers.GetCallerInfo(ctx).CallerName")
	w.println("metricsHandler := c.metricsHandler.WithTags(")
	w.indent() // start args
	w.println(`metrics.OperationTag("%s.%s"),`, svc.GoName, method.GoName)
	w.println("metrics.NamespaceTag(caller),")
	w.println("metrics.ServiceRoleTag(metrics.HistoryRoleTagValue),")
	w.unindent() // close args
	w.println(")")
	w.println("metrics.ClientRequests.With(metricsHandler).Record(1)")
	w.println("defer func() {")
	w.indent() // start defer
	w.println("if err != nil {")
	w.indent() // start if
	w.println("metrics.ClientFailures.With(metricsHandler).Record(1, metrics.ServiceErrorTypeTag(err))")
	w.unindent() // close if
	w.println("}")
	w.println("metrics.ClientLatency.With(metricsHandler).Record(time.Since(startTime))")
	w.unindent() // close defer
	w.println("}()")
	w.println("%s", assignShard)
	w.println("op := func(ctx context.Context, client %sClient) error {", svc.GoName)
	w.indent()
	w.println("var err error")
	w.println("ctx, cancel := context.WithTimeout(ctx, history.DefaultTimeout)")
	w.println("defer cancel()")
	w.println("response, err = client.%s(ctx, request, opts...)", method.GoName)
	w.println("return err")
	w.unindent()
	w.println("}")
	w.println("err = c.redirector.Execute(ctx, shardID, op)")
	w.println("return response, err")
	w.unindent()
	w.println("}")

	// <Method>: the public entry point, retrying transient errors.
	w.println("// %s calls %s.%s on the routed history shard, retrying transient errors.", method.GoName, svc.GoName, method.GoName)
	w.println("func (c *%s) %s(", structName, method.GoName)
	w.indent()
	w.println("ctx context.Context,")
	w.println("request *%s,", input)
	w.println("opts ...grpc.CallOption,")
	w.unindent()
	w.println(") (*%s, error) {", output)
	w.indent()
	w.println("call := func(ctx context.Context) (*%s, error) {", output)
	w.indent()
	w.println("return c.call%sNoRetry(ctx, request, opts...)", method.GoName)
	w.unindent()
	w.println("}")
	w.println("return backoff.ThrottleRetryContextWithReturn(ctx, call, c.retryPolicy, common.IsServiceClientTransientError)")
	w.unindent()
	w.println("}")
	return nil
}

func main() {
	p := New()

	opts := protogen.Options{}

	opts.Run(p.Run)
}
