package httpapi

import (
	"context"
	"errors"
	"time"

	"golang.org/x/xerrors"

	"cdr.dev/slog/v3"
	"github.com/coder/websocket"
)

// HeartbeatInterval is the period between heartbeat pings sent by
// HeartbeatClose, which also uses it as the timeout for each individual ping.
const HeartbeatInterval = 15 * time.Second

// HeartbeatClose loops to ping a WebSocket to keep it alive. Every
// HeartbeatInterval it sends a ping whose timeout is also HeartbeatInterval.
// On the first failed ping it closes conn with websocket.StatusGoingAway,
// calls exit, and returns. If ctx is done while waiting for the next tick it
// returns without closing conn or calling exit.
//
// Ping failures are logged at error level only when they are unexpected.
// Two cases are expected and not logged:
//   - a ping that timed out (context.DeadlineExceeded), which typically means
//     the client disconnected without sending a close frame;
//   - a ping that failed because ctx itself was cancelled, which means the
//     caller is already tearing the connection down.
func HeartbeatClose(ctx context.Context, logger slog.Logger, exit func(), conn *websocket.Conn) {
	ticker := time.NewTicker(HeartbeatInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}

		err := pingWithTimeout(ctx, conn, HeartbeatInterval)
		if err == nil {
			continue
		}

		// Checking ctx.Err() rather than errors.Is(err, context.Canceled) does
		// not depend on how the websocket library wraps the context error.
		expected := ctx.Err() != nil || errors.Is(err, context.DeadlineExceeded)
		if !expected {
			logger.Error(ctx, "failed to heartbeat ping", slog.Error(err))
		}
		// Best effort: the connection is being abandoned either way, so a
		// close error carries no actionable information.
		_ = conn.Close(websocket.StatusGoingAway, "Ping failed")
		exit()
		return
	}
}

// pingWithTimeout sends one ping on conn. The ping uses a context derived from
// ctx with the given timeout. Any error from the ping is returned wrapped;
// nil means the ping succeeded.
func pingWithTimeout(ctx context.Context, conn *websocket.Conn, timeout time.Duration) error {
	pingCtx, cancelPing := context.WithTimeout(ctx, timeout)
	defer cancelPing()

	if pingErr := conn.Ping(pingCtx); pingErr != nil {
		return xerrors.Errorf("failed to ping: %w", pingErr)
	}
	return nil
}
