diff --git a/context.go b/context.go index 21e429748..3511cf7ac 100644 --- a/context.go +++ b/context.go @@ -139,8 +139,8 @@ func (c *Context) Response() http.ResponseWriter { return c.response } -// SetResponse sets `*http.ResponseWriter`. Some middleware require that given ResponseWriter implements following -// method `Unwrap() http.ResponseWriter` which eventually should return echo.Response instance. +// SetResponse sets `*http.ResponseWriter`. Some context methods and/or middleware require that given ResponseWriter implements following +// method `Unwrap() http.ResponseWriter` which eventually should return *echo.Response instance. func (c *Context) SetResponse(r http.ResponseWriter) { c.response = r } @@ -415,6 +415,15 @@ func (c *Context) Render(code int, name string, data any) (err error) { if c.echo.Renderer == nil { return ErrRendererNotRegistered } + // as Renderer.Render can fail, and in that case we need to delay sending status code to the client until + // (global) error handler decides the correct status code for the error to be sent to the client, so we need to write + // the rendered template to the buffer first. + // + // html.Template.ExecuteTemplate() documentations writes: + // > If an error occurs executing the template or writing its output, + // > execution stops, but partial results may already have been written to + // > the output writer. + buf := new(bytes.Buffer) if err = c.echo.Renderer.Render(c, buf, name, data); err != nil { return @@ -454,7 +463,18 @@ func (c *Context) jsonPBlob(code int, callback string, i any) (err error) { func (c *Context) json(code int, i any, indent string) error { c.writeContentType(MIMEApplicationJSON) - c.response.WriteHeader(code) + + // as JSONSerializer.Serialize can fail, and in that case we need to delay sending status code to the client until + // (global) error handler decides correct status code for the error to be sent to the client. + // For that we need to use writer that can store the proposed status code until the first Write is called. + if r, err := UnwrapResponse(c.response); err == nil { + r.Status = code + } else { + resp := c.Response() + c.SetResponse(&delayedStatusWriter{ResponseWriter: resp, status: code}) + defer c.SetResponse(resp) + } + return c.echo.JSONSerializer.Serialize(c, i, indent) } diff --git a/context_test.go b/context_test.go index 6c8dd01af..5945c9ecc 100644 --- a/context_test.go +++ b/context_test.go @@ -12,6 +12,7 @@ import ( "io" "io/fs" "log/slog" + "math" "mime/multipart" "net/http" "net/http/httptest" @@ -138,6 +139,24 @@ func TestContextRenderTemplate(t *testing.T) { } } +func TestContextRenderTemplateError(t *testing.T) { + // we test that when template rendering fails, no response is sent to the client yet, so the global error handler can decide what to do + e := New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + tmpl := &Template{ + templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")), + } + c.Echo().Renderer = tmpl + err := c.Render(http.StatusOK, "not_existing", "Jon Snow") + + assert.EqualError(t, err, `template: no template "not_existing" associated with template "hello"`) + assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client + assert.Empty(t, rec.Body.String()) // body must not be sent to the client +} + func TestContextRenderErrorsOnNoRenderer(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) @@ -173,10 +192,9 @@ func TestContextStream(t *testing.T) { } func TestContextHTML(t *testing.T) { - e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec) + c := NewContext(req, rec) err := c.HTML(http.StatusOK, "Hi, Jon Snow") if assert.NoError(t, err) { @@ -187,10 +205,9 @@ func TestContextHTML(t *testing.T) { } func TestContextHTMLBlob(t *testing.T) { - e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec) + c := NewContext(req, rec) err := c.HTMLBlob(http.StatusOK, []byte("Hi, Jon Snow")) if assert.NoError(t, err) { @@ -222,6 +239,24 @@ func TestContextJSONErrorsOut(t *testing.T) { err := c.JSON(http.StatusOK, make(chan bool)) assert.EqualError(t, err, "json: unsupported type: chan bool") + + assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client + assert.Empty(t, rec.Body.String()) // body must not be sent to the client +} + +func TestContextJSONWithNotEchoResponse(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + c := e.NewContext(req, rec) + + c.SetResponse(rec) + + err := c.JSON(http.StatusCreated, map[string]float64{"foo": math.NaN()}) + assert.EqualError(t, err, "json: unsupported value: NaN") + + assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client + assert.Empty(t, rec.Body.String()) // body must not be sent to the client } func TestContextJSONPretty(t *testing.T) { diff --git a/response.go b/response.go index 5cb9a78a1..aa9046765 100644 --- a/response.go +++ b/response.go @@ -126,7 +126,47 @@ func UnwrapResponse(rw http.ResponseWriter) (*Response, error) { rw = t.Unwrap() continue default: - return nil, errors.New("ResponseWriter does not implement 'Unwrap() http.ResponseWriter' interface") + return nil, errors.New("ResponseWriter does not implement 'Unwrap() http.ResponseWriter' interface or unwrap to *echo.Response") } } } + +// delayedStatusWriter is a wrapper around http.ResponseWriter that delays writing the status code until first Write is called. +// This allows (global) error handler to decide correct status code to be sent to the client. +type delayedStatusWriter struct { + http.ResponseWriter + commited bool + status int +} + +func (w *delayedStatusWriter) WriteHeader(statusCode int) { + // in case something else writes status code explicitly before us we need mark response commited + w.commited = true + w.ResponseWriter.WriteHeader(statusCode) +} + +func (w *delayedStatusWriter) Write(data []byte) (int, error) { + if !w.commited { + w.commited = true + if w.status == 0 { + w.status = http.StatusOK + } + w.ResponseWriter.WriteHeader(w.status) + } + return w.ResponseWriter.Write(data) +} + +func (w *delayedStatusWriter) Flush() { + err := http.NewResponseController(w.ResponseWriter).Flush() + if err != nil && errors.Is(err, http.ErrNotSupported) { + panic(errors.New("response writer flushing is not supported")) + } +} + +func (w *delayedStatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return http.NewResponseController(w.ResponseWriter).Hijack() +} + +func (w *delayedStatusWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} diff --git a/response_test.go b/response_test.go index 667e0e34a..6f069a499 100644 --- a/response_test.go +++ b/response_test.go @@ -115,3 +115,19 @@ func TestResponse_FlushPanics(t *testing.T) { res.Flush() }) } + +func TestResponse_UnwrapResponse(t *testing.T) { + orgRes := NewResponse(httptest.NewRecorder(), nil) + res, err := UnwrapResponse(orgRes) + + assert.NotNil(t, res) + assert.NoError(t, err) +} + +func TestResponse_UnwrapResponse_error(t *testing.T) { + rw := new(testResponseWriter) + res, err := UnwrapResponse(rw) + + assert.Nil(t, res) + assert.EqualError(t, err, "ResponseWriter does not implement 'Unwrap() http.ResponseWriter' interface or unwrap to *echo.Response") +}