package websocket

import (
	"bytes"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"software.sslmate.com/src/go-pkcs12"
	"strings"
	"sync"
	"time"

	"github.com/fosrl/newt/logger"
	"github.com/gorilla/websocket"
)

// Client is a Newt WebSocket client. It obtains an auth token over HTTP,
// dials the server's WebSocket endpoint, dispatches incoming JSON messages to
// handlers registered per message type, and reconnects when a ping fails.
type Client struct {
	// conn is the current WebSocket connection; nil until the first
	// successful dial in establishConnection.
	conn              *websocket.Conn
	// config holds the Newt ID, secret, endpoint, cached token and the
	// optional client-certificate path.
	config            *Config
	// baseURL is the HTTP(S) base URL used for token requests and from which
	// the ws/wss URL is derived. Defaults to the endpoint given to NewClient.
	baseURL           string
	// handlers maps a message type to its handler; guarded by handlersMux.
	handlers          map[string]MessageHandler
	// done is closed by Close to stop the background goroutines.
	done              chan struct{}
	handlersMux       sync.RWMutex
	// reconnectInterval is the delay between failed connection attempts.
	reconnectInterval time.Duration
	// isConnected is written via setConnected under reconnectMux.
	isConnected       bool
	reconnectMux      sync.RWMutex

	// onConnect, if set, runs after each successful connection.
	onConnect func() error
}

type (
	// ClientOption customizes a Client inside NewClient. Options are applied
	// before the saved configuration is loaded.
	ClientOption func(*Client)

	// MessageHandler receives one incoming message whose Type matches the key
	// it was registered under with RegisterHandler.
	MessageHandler func(message WSMessage)
)

// WithBaseURL overrides the HTTP(S) base URL that the client uses for token
// requests and from which it derives the WebSocket URL. Without this option,
// the endpoint passed to NewClient is used.
func WithBaseURL(rawURL string) ClientOption {
	setBaseURL := func(c *Client) {
		c.baseURL = rawURL
	}
	return setBaseURL
}

// WithTLSConfig sets the path of a PKCS#12 client-certificate file. When set,
// the certificate is presented both on the token request and on the
// WebSocket dial. The file itself is read when connecting, not here.
func WithTLSConfig(tlsClientCertPath string) ClientOption {
	return func(client *Client) {
		cfg := client.config
		cfg.TlsClientCert = tlsClientCertPath
	}
}

// OnConnect registers fn to run after every successful connection, including
// reconnects. An error returned by fn is logged, not propagated. While a
// callback is registered, the config is also saved after each connection.
func (c *Client) OnConnect(fn func() error) {
	c.onConnect = fn
}

// NewClient builds a Client for the given Newt credentials and endpoint. It
// applies opts in order, skipping nil entries, and then loads any previously
// saved configuration. No connection is made until Connect is called.
func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*Client, error) {
	c := &Client{
		config: &Config{
			NewtID:   newtID,
			Secret:   secret,
			Endpoint: endpoint,
		},
		baseURL:           endpoint, // may be overridden by WithBaseURL
		handlers:          make(map[string]MessageHandler),
		done:              make(chan struct{}),
		reconnectInterval: 10 * time.Second,
	}

	// Options are applied before loadConfig runs.
	for _, apply := range opts {
		if apply != nil {
			apply(c)
		}
	}

	if err := c.loadConfig(); err != nil {
		return nil, fmt.Errorf("failed to load config: %w", err)
	}
	return c, nil
}

// Connect starts connecting in a background goroutine and returns at once.
// That goroutine retries, reconnectInterval apart, until a connection succeeds
// or Close is called. Connection errors are logged rather than returned, so
// the result is always nil.
func (c *Client) Connect() error {
	dial := c.connectWithRetry
	go dial()
	return nil
}

// Close stops the client's background goroutines (read pump, ping monitor,
// retry loop) by closing the done channel. It marks the client as
// disconnected and closes the current WebSocket connection, if any, returning
// that close error.
//
// A repeated sequential call does not panic. Concurrent calls to Close are
// not synchronized.
func (c *Client) Close() error {
	// Closing an already-closed channel panics, so only close done once.
	select {
	case <-c.done:
	default:
		close(c.done)
	}

	// Update the flag before closing the connection. The previous version
	// returned early when a connection existed and never cleared it.
	c.setConnected(false)

	if c.conn != nil {
		return c.conn.Close()
	}
	return nil
}

// SendMessage wraps data in a WSMessage of the given type and writes it to
// the current connection as JSON. It returns an error if no connection has
// been established yet.
//
// NOTE(review): gorilla/websocket allows at most one concurrent writer per
// connection. Callers appear responsible for not invoking this from several
// goroutines at once; confirm.
func (c *Client) SendMessage(messageType string, data interface{}) error {
	conn := c.conn
	if conn == nil {
		return fmt.Errorf("not connected")
	}
	return conn.WriteJSON(WSMessage{Type: messageType, Data: data})
}

// RegisterHandler installs handler for incoming messages whose Type equals
// messageType. It replaces any handler already registered for that type. The
// handler map is guarded by handlersMux, so registration may run concurrently
// with message dispatch.
func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
	mu := &c.handlersMux
	mu.Lock()
	defer mu.Unlock()
	c.handlers[messageType] = handler
}

// readPump reads JSON messages from the connection that is current when it
// starts and dispatches each one to the handler registered for its Type.
// Messages without a registered handler are dropped. It returns when a read
// fails or the client is closed, and closes that connection on exit.
func (c *Client) readPump() {
	// Pin the connection this pump was started for, so a later reconnect that
	// replaces c.conn does not leave two pumps reading the same connection.
	conn := c.conn
	defer conn.Close()

	for {
		select {
		case <-c.done:
			return
		default:
		}

		var msg WSMessage
		if err := conn.ReadJSON(&msg); err != nil {
			// A read error after Close is expected; anything else is reported.
			// Reconnection is not started here. pingMonitor's next ping should
			// fail on the closed connection and call reconnect.
			select {
			case <-c.done:
			default:
				logger.Error("Failed to read from WebSocket: %v", err)
			}
			return
		}

		// Look up the handler under the read lock, but call it only after the
		// lock is released. Otherwise a handler that calls RegisterHandler
		// deadlocks, and a slow handler blocks all registration.
		c.handlersMux.RLock()
		handler, ok := c.handlers[msg.Type]
		c.handlersMux.RUnlock()
		if ok {
			handler(msg)
		}
	}
}

// tokenRequestTimeout bounds each HTTP call made while obtaining a token, so a
// stalled server cannot block a connection attempt indefinitely.
const tokenRequestTimeout = 30 * time.Second

// getToken returns a token for authenticating the WebSocket connection.
//
// If the config already holds a token, that token is sent to the server
// together with the Newt ID and secret. It is reused when the server replies
// with success and the message "Token session already valid". Otherwise a new
// token is requested using the ID and secret alone.
//
// When a client-certificate path is configured, the certificate is loaded and
// presented on these requests.
func (c *Client) getToken() (string, error) {
	baseURL, err := url.Parse(c.baseURL)
	if err != nil {
		return "", fmt.Errorf("failed to parse base URL: %w", err)
	}
	// Strip trailing slashes so the API path can be appended directly.
	tokenURL := strings.TrimRight(baseURL.String(), "/") + "/api/v1/auth/newt/get-token"

	var tlsConfig *tls.Config
	if c.config.TlsClientCert != "" {
		tlsConfig, err = loadClientCertificate(c.config.TlsClientCert)
		if err != nil {
			return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
		}
	}

	// One client serves both requests, so the second can reuse the connection.
	httpClient := &http.Client{Timeout: tokenRequestTimeout}
	if tlsConfig != nil {
		transport := &http.Transport{TLSClientConfig: tlsConfig}
		// This transport is private to this call. Release its idle connections
		// when done so they are not kept open after the transport is dropped.
		defer transport.CloseIdleConnections()
		httpClient.Transport = transport
	}

	// If we already have a token, ask the server whether it is still valid.
	if c.config.Token != "" {
		checkResp, err := postTokenRequest(httpClient, tokenURL, map[string]interface{}{
			"newtId": c.config.NewtID,
			"secret": c.config.Secret,
			"token":  c.config.Token,
		})
		if err != nil {
			return "", fmt.Errorf("failed to check token validity: %w", err)
		}
		if checkResp.Success && checkResp.Message == "Token session already valid" {
			return c.config.Token, nil
		}
	}

	// Otherwise request a new token.
	tokenResp, err := postTokenRequest(httpClient, tokenURL, map[string]interface{}{
		"newtId": c.config.NewtID,
		"secret": c.config.Secret,
	})
	if err != nil {
		return "", fmt.Errorf("failed to request new token: %w", err)
	}
	if !tokenResp.Success {
		return "", fmt.Errorf("failed to get token: %s", tokenResp.Message)
	}
	if tokenResp.Data.Token == "" {
		return "", fmt.Errorf("received empty token from server")
	}
	return tokenResp.Data.Token, nil
}

// postTokenRequest POSTs payload as JSON to the token endpoint and decodes the
// JSON reply into a TokenResponse. The response body is read in full and
// closed before returning. If the body cannot be decoded, it is logged
// together with the HTTP status to aid debugging.
func postTokenRequest(httpClient *http.Client, endpoint string, payload map[string]interface{}) (TokenResponse, error) {
	jsonData, err := json.Marshal(payload)
	if err != nil {
		return TokenResponse{}, fmt.Errorf("failed to marshal token request data: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(jsonData))
	if err != nil {
		return TokenResponse{}, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("X-CSRF-Token", "x-csrf-protection")

	resp, err := httpClient.Do(req)
	if err != nil {
		return TokenResponse{}, err
	}
	defer resp.Body.Close()

	// Buffer the whole body. Decoding straight from resp.Body consumes it,
	// which left nothing to log when the payload was not valid JSON.
	var body bytes.Buffer
	if _, err := body.ReadFrom(resp.Body); err != nil {
		return TokenResponse{}, fmt.Errorf("failed to read token response: %w", err)
	}

	var tokenResp TokenResponse
	if err := json.Unmarshal(body.Bytes(), &tokenResp); err != nil {
		logger.Info("Token response (HTTP %d): %s", resp.StatusCode, body.String())
		return TokenResponse{}, fmt.Errorf("failed to decode token response: %w", err)
	}
	return tokenResp, nil
}

// connectWithRetry calls establishConnection until it succeeds or the client
// is closed. It waits reconnectInterval between failed attempts and logs each
// failure.
func (c *Client) connectWithRetry() {
	for {
		select {
		case <-c.done:
			return
		default:
		}

		err := c.establishConnection()
		if err == nil {
			return
		}
		logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)

		// Wait out the retry interval, but stop immediately if Close is called
		// in the meantime. A plain time.Sleep kept this goroutine alive for up
		// to reconnectInterval after the client was closed.
		timer := time.NewTimer(c.reconnectInterval)
		select {
		case <-c.done:
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}

// establishConnection obtains a token and dials the WebSocket endpoint
// derived from baseURL, passing the token as the "token" query parameter. The
// scheme is wss, or ws when the base URL scheme is http. On success it records
// the connection and starts the ping monitor and read pump for it.
//
// If an OnConnect callback is registered, the config is saved and the
// callback is run. Failures of either are logged, not returned.
func (c *Client) establishConnection() error {
	// Get token for authentication
	token, err := c.getToken()
	if err != nil {
		return fmt.Errorf("failed to get token: %w", err)
	}

	// Parse the base URL to determine protocol and hostname
	baseURL, err := url.Parse(c.baseURL)
	if err != nil {
		return fmt.Errorf("failed to parse base URL: %w", err)
	}

	// Determine WebSocket protocol based on HTTP protocol
	wsProtocol := "wss"
	if baseURL.Scheme == "http" {
		wsProtocol = "ws"
	}

	// NOTE(review): only the host of baseURL is kept, so any path prefix in the
	// base URL (which getToken does honor) is dropped here. Confirm intended.
	wsURL := fmt.Sprintf("%s://%s/api/v1/ws", wsProtocol, baseURL.Host)
	u, err := url.Parse(wsURL)
	if err != nil {
		return fmt.Errorf("failed to parse WebSocket URL: %w", err)
	}

	// Add token to query parameters
	q := u.Query()
	q.Set("token", token)
	u.RawQuery = q.Encode()

	// Work on a copy of the default dialer. websocket.DefaultDialer is a
	// package-level *Dialer shared by the whole process, and assigning
	// TLSClientConfig through the pointer would leak this client's
	// certificate into every other dial that uses the default.
	dialer := *websocket.DefaultDialer
	if c.config.TlsClientCert != "" {
		logger.Info("Adding tls to req")
		tlsConfig, err := loadClientCertificate(c.config.TlsClientCert)
		if err != nil {
			return fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
		}
		dialer.TLSClientConfig = tlsConfig
	}

	conn, resp, err := dialer.Dial(u.String(), nil)
	if err != nil {
		// On a failed handshake, gorilla returns the server's HTTP response.
		// Its status explains auth and routing failures.
		if resp != nil {
			return fmt.Errorf("failed to connect to WebSocket (HTTP %s): %w", resp.Status, err)
		}
		return fmt.Errorf("failed to connect to WebSocket: %w", err)
	}

	c.conn = conn
	c.setConnected(true)

	// Start the ping monitor
	go c.pingMonitor()
	// Start the read pump
	go c.readPump()

	if c.onConnect != nil {
		if err := c.saveConfig(); err != nil {
			logger.Error("Failed to save config: %v", err)
		}
		if err := c.onConnect(); err != nil {
			logger.Error("OnConnect callback failed: %v", err)
		}
	}

	return nil
}

// pingMonitor sends a WebSocket ping every 30 seconds, giving each ping write
// a 10-second deadline. If a ping cannot be written, it logs the error, calls
// reconnect and exits. It also exits when the client is closed.
func (c *Client) pingMonitor() {
	const (
		pingInterval = 30 * time.Second
		writeWait    = 10 * time.Second
	)

	ticker := time.NewTicker(pingInterval)
	defer ticker.Stop()

	for {
		select {
		case <-c.done:
			return
		case <-ticker.C:
		}

		deadline := time.Now().Add(writeWait)
		err := c.conn.WriteControl(websocket.PingMessage, []byte{}, deadline)
		if err == nil {
			continue
		}
		logger.Error("Ping failed: %v", err)
		c.reconnect()
		return
	}
}

// reconnect marks the client as disconnected and closes the current
// connection, if any, ignoring the close error. It then starts a fresh
// background connection loop.
func (c *Client) reconnect() {
	c.setConnected(false)
	if conn := c.conn; conn != nil {
		_ = conn.Close() // best effort: the connection is being replaced anyway
	}
	go c.connectWithRetry()
}

// setConnected records whether a connection is currently established,
// updating isConnected under reconnectMux.
func (c *Client) setConnected(status bool) {
	c.reconnectMux.Lock()
	c.isConnected = status
	c.reconnectMux.Unlock()
}

// loadClientCertificate reads the PKCS#12 bundle at p12Path and builds a TLS
// config that presents the bundle's certificate and private key as the client
// certificate. Any CA certificates in the bundle are added to the system root
// pool, which is then used to verify the server. The bundle is decoded with
// an empty password.
func loadClientCertificate(p12Path string) (*tls.Config, error) {
	logger.Info("Loading tls-client-cert %s", p12Path)

	raw, err := os.ReadFile(p12Path)
	if err != nil {
		return nil, fmt.Errorf("failed to read PKCS12 file: %w", err)
	}

	// An empty password is used, so bundles protected by a real password fail here.
	key, leaf, chain, err := pkcs12.DecodeChain(raw, "")
	if err != nil {
		return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
	}

	clientCert := tls.Certificate{
		Certificate: [][]byte{leaf.Raw},
		PrivateKey:  key,
	}

	pool, err := x509.SystemCertPool()
	if err != nil {
		return nil, fmt.Errorf("failed to load system cert pool: %w", err)
	}
	for _, ca := range chain {
		pool.AddCert(ca)
	}

	return &tls.Config{
		Certificates: []tls.Certificate{clientCert},
		RootCAs:      pool,
	}, nil
}
