Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 05cdee1

Browse files
feat(sse): add middleware to prevent proxy buffering of SSE connections (#70)
Co-authored-by: Cian Johnston <[email protected]>
1 parent 8c224f1 commit 05cdee1

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

lib/httpapi/server.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,19 @@ func hostAuthorizationMiddleware(allowedHosts []string, badHostHandler http.Hand
263263
}
264264
}
265265

266+
// sseMiddleware creates middleware that prevents proxy buffering for SSE endpoints
267+
func sseMiddleware(ctx huma.Context, next func(huma.Context)) {
268+
// Disable proxy buffering for SSE endpoints
269+
ctx.SetHeader("Cache-Control", "no-cache, no-store, must-revalidate")
270+
ctx.SetHeader("Pragma", "no-cache")
271+
ctx.SetHeader("Expires", "0")
272+
ctx.SetHeader("X-Accel-Buffering", "no") // nginx
273+
ctx.SetHeader("X-Proxy-Buffering", "no") // generic proxy
274+
ctx.SetHeader("Connection", "keep-alive")
275+
276+
next(ctx)
277+
}
278+
266279
func (s *Server) StartSnapshotLoop(ctx context.Context) {
267280
s.conversation.StartSnapshotLoop(ctx)
268281
go func() {
@@ -299,6 +312,7 @@ func (s *Server) registerRoutes() {
299312
Path: "/events",
300313
Summary: "Subscribe to events",
301314
Description: "The events are sent as Server-Sent Events (SSE). Initially, the endpoint returns a list of events needed to reconstruct the current state of the conversation and the agent's status. After that, it only returns events that have occurred since the last event was sent.\n\nNote: When an agent is running, the last message in the conversation history is updated frequently, and the endpoint sends a new message update event each time.",
315+
Middlewares: []func(huma.Context, func(huma.Context)){sseMiddleware},
302316
}, map[string]any{
303317
// Mapping of event type name to Go struct for that event.
304318
"message_update": MessageUpdateBody{},
@@ -311,6 +325,7 @@ func (s *Server) registerRoutes() {
311325
Path: "/internal/screen",
312326
Summary: "Subscribe to screen",
313327
Hidden: true,
328+
Middlewares: []func(huma.Context, func(huma.Context)){sseMiddleware},
314329
}, map[string]any{
315330
"screen": ScreenUpdateBody{},
316331
}, s.subscribeScreen)
@@ -390,6 +405,7 @@ func (s *Server) subscribeEvents(ctx context.Context, input *struct{}, send sse.
390405
return
391406
}
392407
}
408+
393409
for {
394410
select {
395411
case event, ok := <-ch:

lib/httpapi/server_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/coder/agentapi/lib/httpapi"
1616
"github.com/coder/agentapi/lib/logctx"
1717
"github.com/coder/agentapi/lib/msgfmt"
18+
"github.com/stretchr/testify/assert"
1819
"github.com/stretchr/testify/require"
1920
)
2021

@@ -631,3 +632,50 @@ func TestServer_CORSPreflightOrigins(t *testing.T) {
631632
})
632633
}
633634
}
635+
636+
func TestServer_SSEMiddleware_Events(t *testing.T) {
637+
t.Parallel()
638+
ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil)))
639+
srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{
640+
AgentType: msgfmt.AgentTypeClaude,
641+
Process: nil,
642+
Port: 0,
643+
ChatBasePath: "/chat",
644+
AllowedHosts: []string{"*"},
645+
AllowedOrigins: []string{"*"},
646+
})
647+
require.NoError(t, err)
648+
tsServer := httptest.NewServer(srv.Handler())
649+
t.Cleanup(tsServer.Close)
650+
651+
t.Run("events", func(t *testing.T) {
652+
t.Parallel()
653+
resp, err := tsServer.Client().Get(tsServer.URL + "/events")
654+
require.NoError(t, err)
655+
t.Cleanup(func() {
656+
_ = resp.Body.Close()
657+
})
658+
assertSSEHeaders(t, resp)
659+
})
660+
661+
t.Run("internal/screen", func(t *testing.T) {
662+
t.Parallel()
663+
664+
resp, err := tsServer.Client().Get(tsServer.URL + "/internal/screen")
665+
require.NoError(t, err)
666+
t.Cleanup(func() {
667+
_ = resp.Body.Close()
668+
})
669+
assertSSEHeaders(t, resp)
670+
})
671+
}
672+
673+
func assertSSEHeaders(t testing.TB, resp *http.Response) {
674+
t.Helper()
675+
assert.Equal(t, "no-cache, no-store, must-revalidate", resp.Header.Get("Cache-Control"))
676+
assert.Equal(t, "no-cache", resp.Header.Get("Pragma"))
677+
assert.Equal(t, "0", resp.Header.Get("Expires"))
678+
assert.Equal(t, "no", resp.Header.Get("X-Accel-Buffering"))
679+
assert.Equal(t, "no", resp.Header.Get("X-Proxy-Buffering"))
680+
assert.Equal(t, "keep-alive", resp.Header.Get("Connection"))
681+
}

0 commit comments

Comments
 (0)