Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions lib/httpapi/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,19 @@ func hostAuthorizationMiddleware(allowedHosts []string, badHostHandler http.Hand
}
}

// sseMiddleware creates middleware that prevents proxy buffering for SSE endpoints
func sseMiddleware(ctx huma.Context, next func(huma.Context)) {
// Disable proxy buffering for SSE endpoints
ctx.SetHeader("Cache-Control", "no-cache, no-store, must-revalidate")
ctx.SetHeader("Pragma", "no-cache")
ctx.SetHeader("Expires", "0")
ctx.SetHeader("X-Accel-Buffering", "no") // nginx
ctx.SetHeader("X-Proxy-Buffering", "no") // generic proxy
ctx.SetHeader("Connection", "keep-alive")

next(ctx)
}

func (s *Server) StartSnapshotLoop(ctx context.Context) {
s.conversation.StartSnapshotLoop(ctx)
go func() {
Expand Down Expand Up @@ -299,6 +312,7 @@ func (s *Server) registerRoutes() {
Path: "/events",
Summary: "Subscribe to events",
Description: "The events are sent as Server-Sent Events (SSE). Initially, the endpoint returns a list of events needed to reconstruct the current state of the conversation and the agent's status. After that, it only returns events that have occurred since the last event was sent.\n\nNote: When an agent is running, the last message in the conversation history is updated frequently, and the endpoint sends a new message update event each time.",
Middlewares: []func(huma.Context, func(huma.Context)){sseMiddleware},
}, map[string]any{
// Mapping of event type name to Go struct for that event.
"message_update": MessageUpdateBody{},
Expand All @@ -311,6 +325,7 @@ func (s *Server) registerRoutes() {
Path: "/internal/screen",
Summary: "Subscribe to screen",
Hidden: true,
Middlewares: []func(huma.Context, func(huma.Context)){sseMiddleware},
}, map[string]any{
"screen": ScreenUpdateBody{},
}, s.subscribeScreen)
Expand Down Expand Up @@ -390,6 +405,7 @@ func (s *Server) subscribeEvents(ctx context.Context, input *struct{}, send sse.
return
}
}

for {
select {
case event, ok := <-ch:
Expand Down
48 changes: 48 additions & 0 deletions lib/httpapi/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/coder/agentapi/lib/httpapi"
"github.com/coder/agentapi/lib/logctx"
"github.com/coder/agentapi/lib/msgfmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -631,3 +632,50 @@ func TestServer_CORSPreflightOrigins(t *testing.T) {
})
}
}

func TestServer_SSEMiddleware_Events(t *testing.T) {
t.Parallel()
ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil)))
srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{
AgentType: msgfmt.AgentTypeClaude,
Process: nil,
Port: 0,
ChatBasePath: "/chat",
AllowedHosts: []string{"*"},
AllowedOrigins: []string{"*"},
})
require.NoError(t, err)
tsServer := httptest.NewServer(srv.Handler())
t.Cleanup(tsServer.Close)

t.Run("events", func(t *testing.T) {
t.Parallel()
resp, err := tsServer.Client().Get(tsServer.URL + "/events")
require.NoError(t, err)
t.Cleanup(func() {
_ = resp.Body.Close()
})
assertSSEHeaders(t, resp)
})

t.Run("internal/screen", func(t *testing.T) {
t.Parallel()

resp, err := tsServer.Client().Get(tsServer.URL + "/internal/screen")
require.NoError(t, err)
t.Cleanup(func() {
_ = resp.Body.Close()
})
assertSSEHeaders(t, resp)
})
}

func assertSSEHeaders(t testing.TB, resp *http.Response) {
t.Helper()
assert.Equal(t, "no-cache, no-store, must-revalidate", resp.Header.Get("Cache-Control"))
assert.Equal(t, "no-cache", resp.Header.Get("Pragma"))
assert.Equal(t, "0", resp.Header.Get("Expires"))
assert.Equal(t, "no", resp.Header.Get("X-Accel-Buffering"))
assert.Equal(t, "no", resp.Header.Get("X-Proxy-Buffering"))
assert.Equal(t, "keep-alive", resp.Header.Get("Connection"))
}