diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index b72b4c3..90fd48f 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -263,6 +263,19 @@ func hostAuthorizationMiddleware(allowedHosts []string, badHostHandler http.Hand } } +// sseMiddleware creates middleware that prevents proxy buffering for SSE endpoints +func sseMiddleware(ctx huma.Context, next func(huma.Context)) { + // Disable proxy buffering for SSE endpoints + ctx.SetHeader("Cache-Control", "no-cache, no-store, must-revalidate") + ctx.SetHeader("Pragma", "no-cache") + ctx.SetHeader("Expires", "0") + ctx.SetHeader("X-Accel-Buffering", "no") // nginx + ctx.SetHeader("X-Proxy-Buffering", "no") // generic proxy + ctx.SetHeader("Connection", "keep-alive") + + next(ctx) +} + func (s *Server) StartSnapshotLoop(ctx context.Context) { s.conversation.StartSnapshotLoop(ctx) go func() { @@ -299,6 +312,7 @@ func (s *Server) registerRoutes() { Path: "/events", Summary: "Subscribe to events", Description: "The events are sent as Server-Sent Events (SSE). Initially, the endpoint returns a list of events needed to reconstruct the current state of the conversation and the agent's status. After that, it only returns events that have occurred since the last event was sent.\n\nNote: When an agent is running, the last message in the conversation history is updated frequently, and the endpoint sends a new message update event each time.", + Middlewares: []func(huma.Context, func(huma.Context)){sseMiddleware}, }, map[string]any{ // Mapping of event type name to Go struct for that event. "message_update": MessageUpdateBody{}, @@ -311,6 +325,7 @@ func (s *Server) registerRoutes() { Path: "/internal/screen", Summary: "Subscribe to screen", Hidden: true, + Middlewares: []func(huma.Context, func(huma.Context)){sseMiddleware}, }, map[string]any{ "screen": ScreenUpdateBody{}, }, s.subscribeScreen) @@ -390,6 +405,7 @@ func (s *Server) subscribeEvents(ctx context.Context, input *struct{}, send sse. return } } + for { select { case event, ok := <-ch: diff --git a/lib/httpapi/server_test.go b/lib/httpapi/server_test.go index bc50d3e..3778fc7 100644 --- a/lib/httpapi/server_test.go +++ b/lib/httpapi/server_test.go @@ -15,6 +15,7 @@ import ( "github.com/coder/agentapi/lib/httpapi" "github.com/coder/agentapi/lib/logctx" "github.com/coder/agentapi/lib/msgfmt" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -631,3 +632,50 @@ func TestServer_CORSPreflightOrigins(t *testing.T) { }) } } + +func TestServer_SSEMiddleware_Events(t *testing.T) { + t.Parallel() + ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) + srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: "/chat", + AllowedHosts: []string{"*"}, + AllowedOrigins: []string{"*"}, + }) + require.NoError(t, err) + tsServer := httptest.NewServer(srv.Handler()) + t.Cleanup(tsServer.Close) + + t.Run("events", func(t *testing.T) { + t.Parallel() + resp, err := tsServer.Client().Get(tsServer.URL + "/events") + require.NoError(t, err) + t.Cleanup(func() { + _ = resp.Body.Close() + }) + assertSSEHeaders(t, resp) + }) + + t.Run("internal/screen", func(t *testing.T) { + t.Parallel() + + resp, err := tsServer.Client().Get(tsServer.URL + "/internal/screen") + require.NoError(t, err) + t.Cleanup(func() { + _ = resp.Body.Close() + }) + assertSSEHeaders(t, resp) + }) +} + +func assertSSEHeaders(t testing.TB, resp *http.Response) { + t.Helper() + assert.Equal(t, "no-cache, no-store, must-revalidate", resp.Header.Get("Cache-Control")) + assert.Equal(t, "no-cache", resp.Header.Get("Pragma")) + assert.Equal(t, "0", resp.Header.Get("Expires")) + assert.Equal(t, "no", resp.Header.Get("X-Accel-Buffering")) + assert.Equal(t, "no", resp.Header.Get("X-Proxy-Buffering")) + assert.Equal(t, "keep-alive", resp.Header.Get("Connection")) +}