package gateway

import (
	"errors"
	"fmt"
	"net"
	"time"

	"go.sia.tech/core/types"
	"go.sia.tech/mux"
	"lukechampine.com/frand"
)

// A UniqueID is a randomly-generated nonce that helps prevent self-connections
// and double-connections. Each side of the handshake sends its UniqueID in its
// Header, and validateHeader rejects a peer whose UniqueID equals our own.
type UniqueID [8]byte

// GenerateUniqueID returns a fresh UniqueID whose bytes are all drawn from
// frand.
func GenerateUniqueID() (id UniqueID) {
	var nonce UniqueID
	frand.Read(nonce[:])
	id = nonce
	return id
}

// A Header contains various peer metadata which is exchanged during the gateway
// handshake.
type Header struct {
	// GenesisID must match the receiving side's GenesisID (see validateHeader).
	GenesisID  types.BlockID
	// UniqueID is the sender's nonce; it must differ from the receiver's.
	UniqueID   UniqueID
	// NetAddress is the host:port the sender advertises. The receiver uses
	// only its port, joined with the host it observes on the connection
	// (see readHeader).
	NetAddress string
}

// validateHeader checks the peer's header (theirs) against our own (ours).
// The peer must share our genesis block, and it must not present our own
// UniqueID; a matching UniqueID most likely means we have connected to
// ourselves. It returns nil when the header is acceptable.
func validateHeader(ours, theirs Header) error {
	switch {
	case ours.GenesisID != theirs.GenesisID:
		return errors.New("peer has different genesis block")
	case ours.UniqueID == theirs.UniqueID:
		return errors.New("peer has same unique ID as us")
	default:
		return nil
	}
}

// writeHeader sends ourHeader to the peer over conn and then waits for the
// peer's verdict. The verdict is decoded as a string with a 128-byte limit.
// The exact string "accept" means the header was accepted; any other reply
// is treated as the peer's rejection reason and returned as an error.
func writeHeader(conn net.Conn, ourHeader Header) error {
	if err := withV1Encoder(conn, ourHeader.encodeTo); err != nil {
		return fmt.Errorf("could not write our header: %w", err)
	}

	var verdict string
	readVerdict := func(d *types.Decoder) { verdict = d.ReadString() }
	if err := withV1Decoder(conn, 128, readVerdict); err != nil {
		return fmt.Errorf("could not read peer header acceptance: %w", err)
	}

	if verdict != "accept" {
		return fmt.Errorf("peer rejected our header: %v", verdict)
	}
	return nil
}

// readHeader reads the peer's Header from conn and replies with a verdict:
// the string "accept" when the header is acceptable, or a short rejection
// reason otherwise.
//
// A header is acceptable when it passes validateHeader and both the
// connection's remote address and the peer's advertised NetAddress split into
// host and port. Every check runs before "accept" is sent, so a peer is never
// told its header was accepted only for the handshake to fail afterwards.
//
// On success, *dialAddr is set to the host observed on conn joined with the
// port from the peer's NetAddress (the advertised host is ignored), and
// *uniqueID is set to the peer's UniqueID. On failure neither is modified.
func readHeader(conn net.Conn, ourHeader Header, dialAddr *string, uniqueID *UniqueID) error {
	// NOTE(review): the size limit presumably covers a 32-byte genesis ID, the
	// 8-byte UniqueID, and the length-prefixed NetAddress — confirm.
	var peerHeader Header
	if err := withV1Decoder(conn, 32+8+128, peerHeader.decodeFrom); err != nil {
		return fmt.Errorf("could not read peer's header: %w", err)
	}

	// reject tells the peer why its header was refused. The write is
	// best-effort: the handshake has already failed and the caller receives
	// the real reason from our return value, so a write error is ignored.
	reject := func(reason string) {
		_ = withV1Encoder(conn, func(e *types.Encoder) { e.WriteString(reason) })
	}

	if err := validateHeader(ourHeader, peerHeader); err != nil {
		reject(err.Error())
		return fmt.Errorf("unacceptable header: %w", err)
	}

	// The address rejections use fixed, short reasons. The peer decodes our
	// reply with a 128-byte limit (see writeHeader), and echoing back an
	// arbitrary address string could exceed it.
	host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
	if err != nil {
		reject("invalid remote address")
		return fmt.Errorf("invalid remote addr (%q): %w", conn.RemoteAddr(), err)
	}
	_, port, err := net.SplitHostPort(peerHeader.NetAddress)
	if err != nil {
		reject("invalid net address")
		return fmt.Errorf("peer provided invalid net address (%q): %w", peerHeader.NetAddress, err)
	}

	// Only now that every check has passed do we accept the peer's header.
	if err = withV1Encoder(conn, func(e *types.Encoder) { e.WriteString("accept") }); err != nil {
		return fmt.Errorf("could not write accept: %w", err)
	}
	*dialAddr = net.JoinHostPort(host, port)
	*uniqueID = peerHeader.UniqueID
	return nil
}

// A Transport provides a multiplexing transport for the Sia gateway protocol.
// Dial and Accept construct a Transport once the gateway handshake succeeds.
type Transport struct {
	// UniqueID is the peer's UniqueID, taken from the Header it sent.
	UniqueID UniqueID
	// Version is the version string the peer sent during the handshake.
	Version  string
	// Addr is the address at which the peer can be dialed: the host observed
	// on the connection joined with the port from the peer's advertised
	// NetAddress.
	Addr     string
	mux      *mux.Mux
}

// DialStream opens a new outgoing multiplexed stream on the Transport.
// The returned error is always nil.
func (t *Transport) DialStream() (*Stream, error) {
	ms := t.mux.DialStream()
	return &Stream{mux: ms}, nil
}

// AcceptStream accepts an incoming multiplexed stream from the peer.
// If the multiplexer reports an error, AcceptStream returns a nil Stream
// alongside it, rather than a Stream wrapping a nil mux stream.
func (t *Transport) AcceptStream() (*Stream, error) {
	s, err := t.mux.AcceptStream()
	if err != nil {
		return nil, err
	}
	return &Stream{mux: s}, nil
}

// Close shuts down the Transport's multiplexer and, with it, the underlying
// connection.
func (t *Transport) Close() error {
	err := t.mux.Close()
	return err
}

// A Stream provides a multiplexed stream for the Sia gateway protocol.
// Streams are obtained from Transport.DialStream or Transport.AcceptStream.
type Stream struct {
	// mux is the underlying multiplexed stream; all messages on it are
	// written and read via withV2Encoder / withV2Decoder.
	mux *mux.Stream
}

// withEncoder hands fn a v2 encoder targeting the underlying mux stream and
// returns whatever error withV2Encoder reports.
func (s *Stream) withEncoder(fn func(*types.Encoder)) error {
	dst := s.mux
	return withV2Encoder(dst, fn)
}

// withDecoder hands fn a v2 decoder reading from the underlying mux stream,
// passing maxLen through as the decode size limit, and returns whatever
// error withV2Decoder reports.
func (s *Stream) withDecoder(maxLen int, fn func(*types.Decoder)) error {
	src := s.mux
	return withV2Decoder(src, maxLen, fn)
}

// WriteID writes the RPC ID that idForObject associates with r to the stream.
func (s *Stream) WriteID(r Object) error {
	rpcID := idForObject(r)
	encode := rpcID.EncodeTo
	return s.withEncoder(encode)
}

// ReadID reads an RPC ID from the stream. The value 16 is the decode size
// limit passed to withDecoder (presumably the encoded size of a
// types.Specifier). On error, whatever was decoded is returned alongside it.
func (s *Stream) ReadID() (id types.Specifier, err error) {
	decode := id.DecodeFrom
	if err = s.withDecoder(16, decode); err != nil {
		return id, err
	}
	return id, nil
}

// WriteRequest encodes the request portion of r onto the stream.
func (s *Stream) WriteRequest(r Object) error {
	encode := r.encodeRequest
	return s.withEncoder(encode)
}

// ReadRequest decodes a request from the stream into r, using
// r.maxRequestLen() as the decode size limit. When that limit is zero,
// nothing is read and nil is returned (presumably such objects carry no
// request payload).
func (s *Stream) ReadRequest(r Object) error {
	limit := r.maxRequestLen()
	if limit != 0 {
		return s.withDecoder(limit, r.decodeRequest)
	}
	return nil
}

// WriteResponse encodes the response portion of r onto the stream.
func (s *Stream) WriteResponse(r Object) error {
	encode := r.encodeResponse
	return s.withEncoder(encode)
}

// ReadResponse decodes a response from the stream into r, using
// r.maxResponseLen() as the decode size limit. When that limit is zero,
// nothing is read and nil is returned (presumably such objects carry no
// response payload).
func (s *Stream) ReadResponse(r Object) error {
	limit := r.maxResponseLen()
	if limit != 0 {
		return s.withDecoder(limit, r.decodeResponse)
	}
	return nil
}

// SetDeadline sets the deadline on the underlying multiplexed stream, with
// the same signature as net.Conn's SetDeadline. Stream itself does not
// implement net.Conn.
func (s *Stream) SetDeadline(t time.Time) error {
	ms := s.mux
	return ms.SetDeadline(t)
}

// Close closes the underlying multiplexed stream.
func (s *Stream) Close() error {
	err := s.mux.Close()
	return err
}

// Dial initiates the gateway handshake with a peer over an already-connected
// conn. The dialing side always speaks first:
//
//  1. It sends our version string, then reads the peer's. The peer's version
//     is stored in the returned Transport's Version field; it is not
//     validated here.
//  2. It sends our Header (writeHeader), then reads and checks the peer's
//     Header (readHeader).
//  3. It establishes a multiplexer over conn.
//
// On any error Dial returns a nil Transport. Dial itself does not close conn,
// so the caller remains responsible for it.
func Dial(conn net.Conn, ourHeader Header) (*Transport, error) {
	p := &Transport{}

	// exchange versions
	const ourVersion = "2.0.0"
	if err := withV1Encoder(conn, func(e *types.Encoder) { e.WriteString(ourVersion) }); err != nil {
		return nil, fmt.Errorf("could not write our version: %w", err)
	} else if err := withV1Decoder(conn, 128, func(d *types.Decoder) { p.Version = d.ReadString() }); err != nil {
		return nil, fmt.Errorf("could not read peer version: %w", err)
	}

	// Exchange headers. writeHeader and readHeader already describe which
	// step failed. Their errors are returned as-is to avoid duplicated or
	// misleading prefixes such as "could not write our header: peer rejected
	// our header".
	if err := writeHeader(conn, ourHeader); err != nil {
		return nil, err
	} else if err := readHeader(conn, ourHeader, &p.Addr, &p.UniqueID); err != nil {
		return nil, err
	}

	// Establish the mux; never hand back a Transport without one.
	m, err := mux.DialAnonymous(conn)
	if err != nil {
		return nil, fmt.Errorf("could not establish mux: %w", err)
	}
	p.mux = m
	return p, nil
}

// Accept reciprocates the gateway handshake with a peer over an
// already-connected conn. It mirrors Dial, with the accepting side always
// listening first:
//
//  1. It reads the peer's version string, then sends ours. The peer's
//     version is stored in the returned Transport's Version field; it is not
//     validated here.
//  2. It reads and checks the peer's Header (readHeader), then sends ours
//     (writeHeader).
//  3. It accepts a multiplexer over conn.
//
// On any error Accept returns a nil Transport. Accept itself does not close
// conn, so the caller remains responsible for it.
func Accept(conn net.Conn, ourHeader Header) (*Transport, error) {
	p := &Transport{}

	// exchange versions
	const ourVersion = "2.0.0"
	if err := withV1Decoder(conn, 128, func(d *types.Decoder) { p.Version = d.ReadString() }); err != nil {
		return nil, fmt.Errorf("could not read peer version: %w", err)
	} else if err := withV1Encoder(conn, func(e *types.Encoder) { e.WriteString(ourVersion) }); err != nil {
		return nil, fmt.Errorf("could not write our version: %w", err)
	}

	// Exchange headers. readHeader and writeHeader already describe which
	// step failed. Their errors are returned as-is to avoid duplicated
	// prefixes such as "could not read peer's header: could not read peer's
	// header: EOF".
	if err := readHeader(conn, ourHeader, &p.Addr, &p.UniqueID); err != nil {
		return nil, err
	} else if err := writeHeader(conn, ourHeader); err != nil {
		return nil, err
	}

	// Establish the mux; never hand back a Transport without one.
	m, err := mux.AcceptAnonymous(conn)
	if err != nil {
		return nil, fmt.Errorf("could not establish mux: %w", err)
	}
	p.mux = m
	return p, nil
}
