From 5a1a1966da959ef488eeb4ccc42bd6ce8d6c5114 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 24 Aug 2022 20:35:30 +0000 Subject: [PATCH 1/5] feat: add panic recovery middleware --- coderd/coderd.go | 19 +-------- coderd/httpapi/httpapi.go | 12 ++++++ coderd/httpapi/request.go | 30 +++++++++++++ coderd/httpapi/status_writer.go | 56 +++++++++++++++++++++++++ coderd/httpmw/logger.go | 58 +++++++++++++++++++++++++ coderd/httpmw/prometheus.go | 21 +++------- coderd/httpmw/prometheus_test.go | 4 +- coderd/httpmw/recover.go | 40 ++++++++++++++++++ coderd/httpmw/recover_test.go | 72 ++++++++++++++++++++++++++++++++ coderd/tracing/httpmw.go | 10 +++-- 10 files changed, 285 insertions(+), 37 deletions(-) create mode 100644 coderd/httpapi/request.go create mode 100644 coderd/httpapi/status_writer.go create mode 100644 coderd/httpmw/logger.go create mode 100644 coderd/httpmw/recover.go create mode 100644 coderd/httpmw/recover_test.go diff --git a/coderd/coderd.go b/coderd/coderd.go index a2062b736227a..19b98bbc44ee9 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1,9 +1,7 @@ package coderd import ( - "context" "crypto/x509" - "fmt" "io" "net/http" "net/url" @@ -124,11 +122,8 @@ func New(options *Options) *API { apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, oauthConfigs, false) r.Use( - func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - next.ServeHTTP(middleware.NewWrapResponseWriter(w, r.ProtoMajor), r) - }) - }, + httpmw.Recover(api.Logger), + httpmw.Logger(api.Logger), httpmw.Prometheus(options.PrometheusRegistry), tracing.HTTPMW(api.TracerProvider, "coderd.http"), ) @@ -156,7 +151,6 @@ func New(options *Options) *API { r.Use( // Specific routes can specify smaller limits. httpmw.RateLimitPerMinute(options.APIRateLimit), - debugLogRequest(api.Logger), ) r.Get("/", func(w http.ResponseWriter, r *http.Request) { httpapi.Write(w, http.StatusOK, codersdk.Response{ @@ -433,15 +427,6 @@ func (api *API) Close() error { return api.workspaceAgentCache.Close() } -func debugLogRequest(log slog.Logger) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - log.Debug(context.Background(), fmt.Sprintf("%s %s", r.Method, r.URL.Path)) - next.ServeHTTP(rw, r) - }) - } -} - func compressHandler(h http.Handler) http.Handler { cmp := middleware.NewCompressor(5, "text/*", diff --git a/coderd/httpapi/httpapi.go b/coderd/httpapi/httpapi.go index b42d2257b45b5..5393a79bfc06c 100644 --- a/coderd/httpapi/httpapi.go +++ b/coderd/httpapi/httpapi.go @@ -59,6 +59,18 @@ func Forbidden(rw http.ResponseWriter) { }) } +func InternalServerError(rw http.ResponseWriter, err error) { + var details string + if err != nil { + details = err.Error() + } + + Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "An internal server error occurred.", + Detail: details, + }) +} + // Write outputs a standardized format to an HTTP response body. func Write(rw http.ResponseWriter, status int, response interface{}) { buf := &bytes.Buffer{} diff --git a/coderd/httpapi/request.go b/coderd/httpapi/request.go new file mode 100644 index 0000000000000..6a07ede6dce19 --- /dev/null +++ b/coderd/httpapi/request.go @@ -0,0 +1,30 @@ +package httpapi + +import "net/http" + +const ( + // XForwardedHostHeader is a header used by proxies to indicate the + // original host of the request. + XForwardedHostHeader = "X-Forwarded-Host" +) + +// RequestHost returns the name of the host from the request. It prioritizes +// 'X-Forwarded-Host' over r.Host since most requests are being proxied. +func RequestHost(r *http.Request) string { + host := r.Header.Get(XForwardedHostHeader) + if host != "" { + return host + } + + return r.Host +} + +func IsWebsocketUpgrade(r *http.Request) bool { + vs := r.Header.Values("Upgrade") + for _, v := range vs { + if v == "websocket" { + return true + } + } + return false +} diff --git a/coderd/httpapi/status_writer.go b/coderd/httpapi/status_writer.go new file mode 100644 index 0000000000000..a57d15f264d80 --- /dev/null +++ b/coderd/httpapi/status_writer.go @@ -0,0 +1,56 @@ +package httpapi + +import ( + "bufio" + "net" + "net/http" +) + +var _ http.ResponseWriter = (*StatusWriter)(nil) +var _ http.Hijacker = (*StatusWriter)(nil) + +// StatusWriter intercepts the status of the request and the response body up +// to maxBodySize if Status >= 400. It is guaranteed to be the ResponseWriter +// directly downstream from Middleware. +type StatusWriter struct { + http.ResponseWriter + Status int + Hijacked bool + ResponseBody []byte +} + +func (w *StatusWriter) WriteHeader(status int) { + w.Status = status + w.ResponseWriter.WriteHeader(status) +} + +func (w *StatusWriter) Write(b []byte) (int, error) { + const maxBodySize = 4096 + + if w.Status == 0 { + w.Status = http.StatusOK + } + + if w.Status >= http.StatusBadRequest { + // Instantiate the recorded response body to be at most + // maxBodySize length. + w.ResponseBody = make([]byte, minInt(len(b), maxBodySize)) + copy(w.ResponseBody, b) + } + + return w.ResponseWriter.Write(b) +} + +// minInt returns the smaller of a or b. This is helpful because math.Min only +// works with float64s. +func minInt(a, b int) int { + if a < b { + return a + } + return b +} + +func (w *StatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + w.Hijacked = true + return w.ResponseWriter.(http.Hijacker).Hijack() +} diff --git a/coderd/httpmw/logger.go b/coderd/httpmw/logger.go new file mode 100644 index 0000000000000..714d491980a93 --- /dev/null +++ b/coderd/httpmw/logger.go @@ -0,0 +1,58 @@ +package httpmw + +import ( + "net/http" + "time" + + "cdr.dev/slog" + "github.com/coder/coder/coderd/httpapi" +) + +func Logger(log slog.Logger) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + sw := &httpapi.StatusWriter{ResponseWriter: w} + + httplog := log.With( + slog.F("host", httpapi.RequestHost(r)), + slog.F("path", r.URL.Path), + slog.F("proto", r.Proto), + slog.F("remote_addr", r.RemoteAddr), + ) + + next.ServeHTTP(sw, r) + + // Don't log successful health check requests. + if r.URL.Path == "/api/v2" && sw.Status == 200 { + return + } + + httplog = httplog.With( + slog.F("took", time.Since(start)), + slog.F("status_code", sw.Status), + slog.F("latency_ms", float64(time.Since(start)/time.Millisecond)), + ) + + // For status codes 400 and higher we + // want to log the response body. + if sw.Status >= 400 { + httplog = httplog.With( + slog.F("response_body", string(sw.ResponseBody)), + ) + } + + logLevelFn := httplog.Debug + if sw.Status >= 400 { + logLevelFn = httplog.Warn + } + if sw.Status >= 500 { + // Server errors should be treated as an ERROR + // log level. + logLevelFn = httplog.Error + } + + logLevelFn(r.Context(), r.Method) + }) + } +} diff --git a/coderd/httpmw/prometheus.go b/coderd/httpmw/prometheus.go index acc57071f0ef0..d954adc2ccd2a 100644 --- a/coderd/httpmw/prometheus.go +++ b/coderd/httpmw/prometheus.go @@ -6,7 +6,8 @@ import ( "time" "github.com/go-chi/chi/v5" - chimw "github.com/go-chi/chi/v5/middleware" + + "github.com/coder/coder/coderd/httpapi" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -66,9 +67,9 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler rctx = chi.RouteContext(r.Context()) ) - sw, ok := w.(chimw.WrapResponseWriter) + sw, ok := w.(*httpapi.StatusWriter) if !ok { - panic("dev error: http.ResponseWriter is not chimw.WrapResponseWriter") + panic("dev error: http.ResponseWriter is not *httpapi.StatusWriter") } var ( @@ -76,7 +77,7 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler distOpts []string ) // We want to count WebSockets separately. - if isWebsocketUpgrade(r) { + if httpapi.IsWebsocketUpgrade(r) { websocketsConcurrent.Inc() defer websocketsConcurrent.Dec() @@ -93,20 +94,10 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler path := rctx.RoutePattern() distOpts = append(distOpts, path) - statusStr := strconv.Itoa(sw.Status()) + statusStr := strconv.Itoa(sw.Status) requestsProcessed.WithLabelValues(statusStr, method, path).Inc() dist.WithLabelValues(distOpts...).Observe(float64(time.Since(start)) / 1e6) }) } } - -func isWebsocketUpgrade(r *http.Request) bool { - vs := r.Header.Values("Upgrade") - for _, v := range vs { - if v == "websocket" { - return true - } - } - return false -} diff --git a/coderd/httpmw/prometheus_test.go b/coderd/httpmw/prometheus_test.go index 97c557540674a..95141834716c2 100644 --- a/coderd/httpmw/prometheus_test.go +++ b/coderd/httpmw/prometheus_test.go @@ -7,10 +7,10 @@ import ( "testing" "github.com/go-chi/chi/v5" - chimw "github.com/go-chi/chi/v5/middleware" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" + "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" ) @@ -20,7 +20,7 @@ func TestPrometheus(t *testing.T) { t.Parallel() req := httptest.NewRequest("GET", "/", nil) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, chi.NewRouteContext())) - res := chimw.NewWrapResponseWriter(httptest.NewRecorder(), 0) + res := &httpapi.StatusWriter{ResponseWriter: httptest.NewRecorder()} reg := prometheus.NewRegistry() httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) diff --git a/coderd/httpmw/recover.go b/coderd/httpmw/recover.go new file mode 100644 index 0000000000000..509d364552ca3 --- /dev/null +++ b/coderd/httpmw/recover.go @@ -0,0 +1,40 @@ +package httpmw + +import ( + "context" + "net/http" + "runtime/debug" + + "cdr.dev/slog" + "github.com/coder/coder/coderd/httpapi" +) + +func Recover(log slog.Logger) func(h http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + err := recover() + if err != nil { + log.Warn(context.Background(), + "panic serving http request (recovered)", + slog.F("err", err), + slog.F("stack", string(debug.Stack())), + ) + + var hijacked bool + if sw, ok := w.(*httpapi.StatusWriter); ok { + hijacked = sw.Hijacked + } + + // Only try to write errors on + // non-hijacked responses. + if !hijacked { + httpapi.InternalServerError(w, nil) + } + } + }() + + h.ServeHTTP(w, r) + }) + } +} diff --git a/coderd/httpmw/recover_test.go b/coderd/httpmw/recover_test.go new file mode 100644 index 0000000000000..ec1bcde04aba2 --- /dev/null +++ b/coderd/httpmw/recover_test.go @@ -0,0 +1,72 @@ +package httpmw_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/httpmw" +) + +func TestRecover(t *testing.T) { + handler := func(isPanic, hijack bool) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isPanic { + panic("Oh no!") + } + + w.WriteHeader(http.StatusOK) + }) + } + + cases := []struct { + Name string + Code int + Panic bool + Hijack bool + }{ + { + Name: "OK", + Code: http.StatusOK, + Panic: false, + Hijack: false, + }, + { + Name: "Panic", + Code: http.StatusInternalServerError, + Panic: true, + Hijack: false, + }, + { + Name: "Hijack", + Code: 0, + Panic: true, + Hijack: true, + }, + } + + for _, c := range cases { + c := c + + t.Run(c.Name, func(t *testing.T) { + t.Parallel() + + var ( + log = slogtest.Make(t, nil) + r = httptest.NewRequest("GET", "/", nil) + w = &httpapi.StatusWriter{ + ResponseWriter: httptest.NewRecorder(), + Hijacked: c.Hijack, + } + ) + + httpmw.Recover(log)(handler(c.Panic, c.Hijack)).ServeHTTP(w, r) + + require.Equal(t, c.Code, w.Status) + }) + } +} diff --git a/coderd/tracing/httpmw.go b/coderd/tracing/httpmw.go index 6e22e68e970f6..dd3f2c8297240 100644 --- a/coderd/tracing/httpmw.go +++ b/coderd/tracing/httpmw.go @@ -4,10 +4,11 @@ import ( "fmt" "net/http" - "github.com/go-chi/chi/middleware" "github.com/go-chi/chi/v5" sdktrace "go.opentelemetry.io/otel/sdk/trace" semconv "go.opentelemetry.io/otel/semconv/v1.10.0" + + "github.com/coder/coder/coderd/httpapi" ) // HTTPMW adds tracing to http routes. @@ -24,7 +25,10 @@ func HTTPMW(tracerProvider *sdktrace.TracerProvider, name string) func(http.Hand defer span.End() r = r.WithContext(ctx) - wrw := middleware.NewWrapResponseWriter(rw, r.ProtoMajor) + sw, ok := rw.(*httpapi.StatusWriter) + if !ok { + panic("ResponseWriter not a *httpapi.StatusWriter?") + } // pass the span through the request context and serve the request to the next middleware next.ServeHTTP(rw, r) @@ -41,7 +45,7 @@ func HTTPMW(tracerProvider *sdktrace.TracerProvider, name string) func(http.Hand span.SetAttributes(semconv.HTTPRouteKey.String(route)) // set the status code - status := wrw.Status() + status := sw.Status // 0 status means one has not yet been sent in which case net/http library will write StatusOK if status == 0 { status = http.StatusOK From 5d37630c71ff56cae658f4997af750fb84caad07 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 24 Aug 2022 21:56:52 +0000 Subject: [PATCH 2/5] add some more tests --- coderd/httpapi/status_writer.go | 29 ++++-- coderd/httpapi/status_writer_test.go | 129 +++++++++++++++++++++++++++ coderd/httpmw/recover_test.go | 2 + 3 files changed, 152 insertions(+), 8 deletions(-) create mode 100644 coderd/httpapi/status_writer_test.go diff --git a/coderd/httpapi/status_writer.go b/coderd/httpapi/status_writer.go index a57d15f264d80..8643087ef52df 100644 --- a/coderd/httpapi/status_writer.go +++ b/coderd/httpapi/status_writer.go @@ -4,6 +4,8 @@ import ( "bufio" "net" "net/http" + + "golang.org/x/xerrors" ) var _ http.ResponseWriter = (*StatusWriter)(nil) @@ -17,23 +19,31 @@ type StatusWriter struct { Status int Hijacked bool ResponseBody []byte + + wroteHeader bool } func (w *StatusWriter) WriteHeader(status int) { - w.Status = status - w.ResponseWriter.WriteHeader(status) + if !w.wroteHeader { + w.Status = status + w.wroteHeader = true + w.ResponseWriter.WriteHeader(status) + } } func (w *StatusWriter) Write(b []byte) (int, error) { const maxBodySize = 4096 - if w.Status == 0 { + if !w.wroteHeader { w.Status = http.StatusOK } if w.Status >= http.StatusBadRequest { - // Instantiate the recorded response body to be at most - // maxBodySize length. + // This is technically wrong as multiple calls to write + // will simply overwrite w.ResponseBody but given that + // we typically only write to the response body once + // and this field is only used for logging I'm leaving + // this as-is. w.ResponseBody = make([]byte, minInt(len(b), maxBodySize)) copy(w.ResponseBody, b) } @@ -41,8 +51,6 @@ func (w *StatusWriter) Write(b []byte) (int, error) { return w.ResponseWriter.Write(b) } -// minInt returns the smaller of a or b. This is helpful because math.Min only -// works with float64s. func minInt(a, b int) int { if a < b { return a @@ -52,5 +60,10 @@ func minInt(a, b int) int { func (w *StatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { w.Hijacked = true - return w.ResponseWriter.(http.Hijacker).Hijack() + hijacker, ok := w.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, xerrors.Errorf("%T is not a http.Hijacker", w.ResponseWriter) + } + + return hijacker.Hijack() } diff --git a/coderd/httpapi/status_writer_test.go b/coderd/httpapi/status_writer_test.go new file mode 100644 index 0000000000000..57ded1de3b558 --- /dev/null +++ b/coderd/httpapi/status_writer_test.go @@ -0,0 +1,129 @@ +package httpapi_test + +import ( + "bufio" + "crypto/rand" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/coderd/httpapi" +) + +func TestStatusWriter(t *testing.T) { + t.Parallel() + + t.Run("WriteHeader", func(t *testing.T) { + t.Parallel() + + var ( + rec = httptest.NewRecorder() + w = &httpapi.StatusWriter{ResponseWriter: rec} + ) + + w.WriteHeader(http.StatusOK) + require.Equal(t, http.StatusOK, w.Status) + // Validate that the code is written to the underlying Response. + require.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("WriteHeaderTwice", func(t *testing.T) { + t.Parallel() + + var ( + rec = httptest.NewRecorder() + w = &httpapi.StatusWriter{ResponseWriter: rec} + code = http.StatusNotFound + ) + + w.WriteHeader(code) + w.WriteHeader(http.StatusOK) + // Validate that we only record the first status code. + require.Equal(t, code, w.Status) + // Validate that the code is written to the underlying Response. + require.Equal(t, code, rec.Code) + }) + + t.Run("WriteNoHeader", func(t *testing.T) { + t.Parallel() + var ( + rec = httptest.NewRecorder() + w = &httpapi.StatusWriter{ResponseWriter: rec} + body = []byte("hello") + ) + + _, err := w.Write(body) + require.NoError(t, err) + + // Should set the status to OK. + require.Equal(t, http.StatusOK, w.Status) + // We don't record the body for codes <400. + require.Equal(t, []byte(nil), w.ResponseBody) + require.Equal(t, body, rec.Body.Bytes()) + }) + + t.Run("WriteAfterHeader", func(t *testing.T) { + t.Parallel() + var ( + rec = httptest.NewRecorder() + w = &httpapi.StatusWriter{ResponseWriter: rec} + body = []byte("hello") + code = http.StatusInternalServerError + ) + + w.WriteHeader(code) + _, err := w.Write(body) + require.NoError(t, err) + + require.Equal(t, code, w.Status) + require.Equal(t, body, w.ResponseBody) + require.Equal(t, body, rec.Body.Bytes()) + }) + + t.Run("WriteMaxBody", func(t *testing.T) { + t.Parallel() + var ( + rec = httptest.NewRecorder() + w = &httpapi.StatusWriter{ResponseWriter: rec} + // 8kb body. + body = make([]byte, 8<<10) + code = http.StatusInternalServerError + ) + + _, err := rand.Read(body) + require.NoError(t, err) + + w.WriteHeader(code) + _, err = w.Write(body) + require.NoError(t, err) + + require.Equal(t, code, w.Status) + require.Equal(t, body, rec.Body.Bytes()) + require.Equal(t, body[:4096], w.ResponseBody) + }) + + t.Run("Hijack", func(t *testing.T) { + t.Parallel() + var ( + rec = httptest.NewRecorder() + ) + + w := &httpapi.StatusWriter{ResponseWriter: hijacker{rec}} + + _, _, err := w.Hijack() + require.Error(t, err) + require.Equal(t, "hijacked", err.Error()) + }) +} + +type hijacker struct { + http.ResponseWriter +} + +func (hijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, xerrors.New("hijacked") +} diff --git a/coderd/httpmw/recover_test.go b/coderd/httpmw/recover_test.go index ec1bcde04aba2..f4b043f0baf6b 100644 --- a/coderd/httpmw/recover_test.go +++ b/coderd/httpmw/recover_test.go @@ -13,6 +13,8 @@ import ( ) func TestRecover(t *testing.T) { + t.Parallel() + handler := func(isPanic, hijack bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if isPanic { From f9d48b39db0320f5e23284e2b7eb45d9d959cd66 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 25 Aug 2022 22:30:58 +0000 Subject: [PATCH 3/5] pr comments --- coderd/httpapi/httpapi_test.go | 35 +++++++++++++++++++++++++++++++++ coderd/httpapi/status_writer.go | 5 +++-- coderd/httpmw/recover.go | 6 +++--- coderd/tracing/httpmw.go | 2 +- 4 files changed, 42 insertions(+), 6 deletions(-) diff --git a/coderd/httpapi/httpapi_test.go b/coderd/httpapi/httpapi_test.go index 35ed403ba48da..79a26d54a25b4 100644 --- a/coderd/httpapi/httpapi_test.go +++ b/coderd/httpapi/httpapi_test.go @@ -10,11 +10,46 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" ) +func TestInternalServerError(t *testing.T) { + t.Parallel() + + t.Run("NoError", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + httpapi.InternalServerError(w, nil) + + var resp codersdk.Response + err := json.NewDecoder(w.Body).Decode(&resp) + require.NoError(t, err) + require.Equal(t, http.StatusInternalServerError, w.Code) + require.NotEmpty(t, resp.Message) + require.Empty(t, resp.Detail) + }) + + t.Run("WithError", func(t *testing.T) { + t.Parallel() + var ( + w = httptest.NewRecorder() + httpErr = xerrors.New("error!") + ) + + httpapi.InternalServerError(w, httpErr) + + var resp codersdk.Response + err := json.NewDecoder(w.Body).Decode(&resp) + require.NoError(t, err) + require.Equal(t, http.StatusInternalServerError, w.Code) + require.NotEmpty(t, resp.Message) + require.Equal(t, httpErr.Error(), resp.Detail) + }) +} + func TestWrite(t *testing.T) { t.Parallel() t.Run("NoErrors", func(t *testing.T) { diff --git a/coderd/httpapi/status_writer.go b/coderd/httpapi/status_writer.go index 8643087ef52df..e8ee5711f3995 100644 --- a/coderd/httpapi/status_writer.go +++ b/coderd/httpapi/status_writer.go @@ -27,8 +27,8 @@ func (w *StatusWriter) WriteHeader(status int) { if !w.wroteHeader { w.Status = status w.wroteHeader = true - w.ResponseWriter.WriteHeader(status) } + w.ResponseWriter.WriteHeader(status) } func (w *StatusWriter) Write(b []byte) (int, error) { @@ -36,6 +36,7 @@ func (w *StatusWriter) Write(b []byte) (int, error) { if !w.wroteHeader { w.Status = http.StatusOK + w.wroteHeader = true } if w.Status >= http.StatusBadRequest { @@ -59,11 +60,11 @@ func minInt(a, b int) int { } func (w *StatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - w.Hijacked = true hijacker, ok := w.ResponseWriter.(http.Hijacker) if !ok { return nil, nil, xerrors.Errorf("%T is not a http.Hijacker", w.ResponseWriter) } + w.Hijacked = true return hijacker.Hijack() } diff --git a/coderd/httpmw/recover.go b/coderd/httpmw/recover.go index 509d364552ca3..a25c063c5f85f 100644 --- a/coderd/httpmw/recover.go +++ b/coderd/httpmw/recover.go @@ -13,11 +13,11 @@ func Recover(log slog.Logger) func(h http.Handler) http.Handler { return func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { - err := recover() - if err != nil { + r := recover() + if r != nil { log.Warn(context.Background(), "panic serving http request (recovered)", - slog.F("err", err), + slog.F("panic", r), slog.F("stack", string(debug.Stack())), ) diff --git a/coderd/tracing/httpmw.go b/coderd/tracing/httpmw.go index dd3f2c8297240..fae4f4a5058de 100644 --- a/coderd/tracing/httpmw.go +++ b/coderd/tracing/httpmw.go @@ -27,7 +27,7 @@ func HTTPMW(tracerProvider *sdktrace.TracerProvider, name string) func(http.Hand sw, ok := rw.(*httpapi.StatusWriter) if !ok { - panic("ResponseWriter not a *httpapi.StatusWriter?") + panic(fmt.Sprintf("ResponseWriter not a *httpapi.StatusWriter; got %T", rw)) } // pass the span through the request context and serve the request to the next middleware From 1086ff905ec46a083a791a423749b8a60dac86c5 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Mon, 29 Aug 2022 21:18:56 +0000 Subject: [PATCH 4/5] unexport ResponseBody --- coderd/httpapi/status_writer.go | 10 +++++++--- coderd/httpapi/status_writer_test.go | 6 +++--- coderd/httpmw/logger.go | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/coderd/httpapi/status_writer.go b/coderd/httpapi/status_writer.go index e8ee5711f3995..dcdb3345c6002 100644 --- a/coderd/httpapi/status_writer.go +++ b/coderd/httpapi/status_writer.go @@ -18,7 +18,7 @@ type StatusWriter struct { http.ResponseWriter Status int Hijacked bool - ResponseBody []byte + responseBody []byte wroteHeader bool } @@ -45,8 +45,8 @@ func (w *StatusWriter) Write(b []byte) (int, error) { // we typically only write to the response body once // and this field is only used for logging I'm leaving // this as-is. - w.ResponseBody = make([]byte, minInt(len(b), maxBodySize)) - copy(w.ResponseBody, b) + w.responseBody = make([]byte, minInt(len(b), maxBodySize)) + copy(w.responseBody, b) } return w.ResponseWriter.Write(b) @@ -68,3 +68,7 @@ func (w *StatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return hijacker.Hijack() } + +func (w *StatusWriter) ResponseBody() []byte { + return w.responseBody +} diff --git a/coderd/httpapi/status_writer_test.go b/coderd/httpapi/status_writer_test.go index 57ded1de3b558..ee713ac555220 100644 --- a/coderd/httpapi/status_writer_test.go +++ b/coderd/httpapi/status_writer_test.go @@ -62,7 +62,7 @@ func TestStatusWriter(t *testing.T) { // Should set the status to OK. require.Equal(t, http.StatusOK, w.Status) // We don't record the body for codes <400. - require.Equal(t, []byte(nil), w.ResponseBody) + require.Equal(t, []byte(nil), w.ResponseBody()) require.Equal(t, body, rec.Body.Bytes()) }) @@ -80,7 +80,7 @@ func TestStatusWriter(t *testing.T) { require.NoError(t, err) require.Equal(t, code, w.Status) - require.Equal(t, body, w.ResponseBody) + require.Equal(t, body, w.ResponseBody()) require.Equal(t, body, rec.Body.Bytes()) }) @@ -103,7 +103,7 @@ func TestStatusWriter(t *testing.T) { require.Equal(t, code, w.Status) require.Equal(t, body, rec.Body.Bytes()) - require.Equal(t, body[:4096], w.ResponseBody) + require.Equal(t, body[:4096], w.ResponseBody()) }) t.Run("Hijack", func(t *testing.T) { diff --git a/coderd/httpmw/logger.go b/coderd/httpmw/logger.go index 714d491980a93..6f3a700bc56bf 100644 --- a/coderd/httpmw/logger.go +++ b/coderd/httpmw/logger.go @@ -38,7 +38,7 @@ func Logger(log slog.Logger) func(next http.Handler) http.Handler { // want to log the response body. if sw.Status >= 400 { httplog = httplog.With( - slog.F("response_body", string(sw.ResponseBody)), + slog.F("response_body", string(sw.ResponseBody())), ) } From 28149492636cd8c198c072488bc3420ff7f05506 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Mon, 29 Aug 2022 21:50:01 +0000 Subject: [PATCH 5/5] fix unique constraint --- coderd/workspaces.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/workspaces.go b/coderd/workspaces.go index af8c4eddec618..d524dc634027f 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -512,7 +512,7 @@ func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) { return } // Check if the name was already in use. - if database.IsUniqueViolation(err, database.UniqueWorkspacesOwnerIDLowerIndex) { + if database.IsUniqueViolation(err) { httpapi.Write(rw, http.StatusConflict, codersdk.Response{ Message: fmt.Sprintf("Workspace %q already exists.", req.Name), Validations: []codersdk.ValidationError{{