diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index ee6f33ef8..82220c0a1 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -1,23 +1,32 @@ ### Issue Description -### Checklist +### Working code to debug -- [ ] Dependencies installed -- [ ] No typos -- [ ] Searched existing issues and docs +```go +package main -### Expected behaviour +import ( + "github.com/labstack/echo/v4" + "net/http" + "net/http/httptest" + "testing" +) -### Actual behaviour +func TestExample(t *testing.T) { + e := echo.New() -### Steps to reproduce + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "Hello, World!") + }) -### Working code to debug + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() -```go -package main + e.ServeHTTP(rec, req) -func main() { + if rec.Code != http.StatusOK { + t.Errorf("got %d, want %d", rec.Code, http.StatusOK) + } } ``` diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index d2d3386c4..436254a63 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -14,17 +14,17 @@ permissions: env: # run static analysis only with the latest Go version - LATEST_GO_VERSION: "1.20" + LATEST_GO_VERSION: "1.25" jobs: check: runs-on: ubuntu-latest steps: - name: Checkout Code - uses: actions/checkout@v3 + uses: actions/checkout@v5 - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: ${{ env.LATEST_GO_VERSION }} check-latest: true diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index e06183d5e..c7780fd21 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -14,7 +14,7 @@ permissions: env: # run coverage and benchmarks only with the latest Go version - LATEST_GO_VERSION: "1.20" + LATEST_GO_VERSION: "1.25" jobs: test: @@ -25,15 +25,15 @@ jobs: # Echo tests with last four major releases (unless there are pressing vulnerabilities) # As we depend on `golang.org/x/` libraries which only support last 2 Go releases we could have situations when # we derive from last four major releases promise. - go: ["1.18", "1.19", "1.20"] + go: ["1.22", "1.23", "1.24", "1.25"] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: - name: Checkout Code - uses: actions/checkout@v3 + uses: actions/checkout@v5 - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} @@ -42,7 +42,7 @@ jobs: - name: Upload coverage to Codecov if: success() && matrix.go == env.LATEST_GO_VERSION && matrix.os == 'ubuntu-latest' - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v5 with: token: fail_ci_if_error: false @@ -53,18 +53,18 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Code (Previous) - uses: actions/checkout@v3 + uses: actions/checkout@v5 with: ref: ${{ github.base_ref }} path: previous - name: Checkout Code (New) - uses: actions/checkout@v3 + uses: actions/checkout@v5 with: path: new - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: ${{ env.LATEST_GO_VERSION }} diff --git a/CHANGELOG.md b/CHANGELOG.md index fef7bb987..967fac2a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,144 @@ # Changelog +## v4.13.4 - 2025-05-22 + +**Enhancements** + +* chore: fix some typos in comment by @zhuhaicity in https://github.com/labstack/echo/pull/2735 +* CI: test with Go 1.24 by @aldas in https://github.com/labstack/echo/pull/2748 +* Add support for TLS WebSocket proxy by @t-ibayashi-safie in https://github.com/labstack/echo/pull/2762 + +**Security** + +* Update dependencies for [GO-2025-3487](https://pkg.go.dev/vuln/GO-2025-3487), [GO-2025-3503](https://pkg.go.dev/vuln/GO-2025-3503) and [GO-2025-3595](https://pkg.go.dev/vuln/GO-2025-3595) in https://github.com/labstack/echo/pull/2780 + + +## v4.13.3 - 2024-12-19 + +**Security** + +* Update golang.org/x/net dependency [GO-2024-3333](https://pkg.go.dev/vuln/GO-2024-3333) in https://github.com/labstack/echo/pull/2722 + + +## v4.13.2 - 2024-12-12 + +**Security** + +* Update dependencies (dependabot reports [GO-2024-3321](https://pkg.go.dev/vuln/GO-2024-3321)) in https://github.com/labstack/echo/pull/2721 + + +## v4.13.1 - 2024-12-11 + +**Fixes** + +* Fix BindBody ignoring `Transfer-Encoding: chunked` requests by @178inaba in https://github.com/labstack/echo/pull/2717 + + + +## v4.13.0 - 2024-12-04 + +**BREAKING CHANGE** JWT Middleware Removed from Core use [labstack/echo-jwt](https://github.com/labstack/echo-jwt) instead + +The JWT middleware has been **removed from Echo core** due to another security vulnerability, [CVE-2024-51744](https://nvd.nist.gov/vuln/detail/CVE-2024-51744). For more details, refer to issue [#2699](https://github.com/labstack/echo/issues/2699). A drop-in replacement is available in the [labstack/echo-jwt](https://github.com/labstack/echo-jwt) repository. + +**Important**: Direct assignments like `token := c.Get("user").(*jwt.Token)` will now cause a panic due to an invalid cast. Update your code accordingly. Replace the current imports from `"github.com/golang-jwt/jwt"` in your handlers to the new middleware version using `"github.com/golang-jwt/jwt/v5"`. + + +Background: + +The version of `golang-jwt/jwt` (v3.2.2) previously used in Echo core has been in an unmaintained state for some time. This is not the first vulnerability affecting this library; earlier issues were addressed in [PR #1946](https://github.com/labstack/echo/pull/1946). +JWT middleware was marked as deprecated in Echo core as of [v4.10.0](https://github.com/labstack/echo/releases/tag/v4.10.0) on 2022-12-27. If you did not notice that, consider leveraging tools like [Staticcheck](https://staticcheck.dev/) to catch such deprecations earlier in you dev/CI flow. For bonus points - check out [gosec](https://github.com/securego/gosec). + +We sincerely apologize for any inconvenience caused by this change. While we strive to maintain backward compatibility within Echo core, recurring security issues with third-party dependencies have forced this decision. + +**Enhancements** + +* remove jwt middleware by @stevenwhitehead in https://github.com/labstack/echo/pull/2701 +* optimization: struct alignment by @behnambm in https://github.com/labstack/echo/pull/2636 +* bind: Maintain backwards compatibility for map[string]interface{} binding by @thesaltree in https://github.com/labstack/echo/pull/2656 +* Add Go 1.23 to CI by @aldas in https://github.com/labstack/echo/pull/2675 +* improve `MultipartForm` test by @martinyonatann in https://github.com/labstack/echo/pull/2682 +* `bind` : add support of multipart multi files by @martinyonatann in https://github.com/labstack/echo/pull/2684 +* Add TemplateRenderer struct to ease creating renderers for `html/template` and `text/template` packages. by @aldas in https://github.com/labstack/echo/pull/2690 +* Refactor TestBasicAuth to utilize table-driven test format by @ErikOlson in https://github.com/labstack/echo/pull/2688 +* Remove broken header by @aldas in https://github.com/labstack/echo/pull/2705 +* fix(bind body): content-length can be -1 by @phamvinhdat in https://github.com/labstack/echo/pull/2710 +* CORS middleware should compile allowOrigin regexp at creation by @aldas in https://github.com/labstack/echo/pull/2709 +* Shorten Github issue template and add test example by @aldas in https://github.com/labstack/echo/pull/2711 + + +## v4.12.0 - 2024-04-15 + +**Security** + +* Update golang.org/x/net dep because of [GO-2024-2687](https://pkg.go.dev/vuln/GO-2024-2687) by @aldas in https://github.com/labstack/echo/pull/2625 + + +**Enhancements** + +* binder: make binding to Map work better with string destinations by @aldas in https://github.com/labstack/echo/pull/2554 +* README.md: add Encore as sponsor by @marcuskohlberg in https://github.com/labstack/echo/pull/2579 +* Reorder paragraphs in README.md by @aldas in https://github.com/labstack/echo/pull/2581 +* CI: upgrade actions/checkout to v4 by @aldas in https://github.com/labstack/echo/pull/2584 +* Remove default charset from 'application/json' Content-Type header by @doortts in https://github.com/labstack/echo/pull/2568 +* CI: Use Go 1.22 by @aldas in https://github.com/labstack/echo/pull/2588 +* binder: allow binding to a nil map by @georgmu in https://github.com/labstack/echo/pull/2574 +* Add Skipper Unit Test In BasicBasicAuthConfig and Add More Detail Explanation regarding BasicAuthValidator by @RyoKusnadi in https://github.com/labstack/echo/pull/2461 +* fix some typos by @teslaedison in https://github.com/labstack/echo/pull/2603 +* fix: some typos by @pomadev in https://github.com/labstack/echo/pull/2596 +* Allow ResponseWriters to unwrap writers when flushing/hijacking by @aldas in https://github.com/labstack/echo/pull/2595 +* Add SPDX licence comments to files. by @aldas in https://github.com/labstack/echo/pull/2604 +* Upgrade deps by @aldas in https://github.com/labstack/echo/pull/2605 +* Change type definition blocks to single declarations. This helps copy… by @aldas in https://github.com/labstack/echo/pull/2606 +* Fix Real IP logic by @cl-bvl in https://github.com/labstack/echo/pull/2550 +* Default binder can use `UnmarshalParams(params []string) error` inter… by @aldas in https://github.com/labstack/echo/pull/2607 +* Default binder can bind pointer to slice as struct field. For example `*[]string` by @aldas in https://github.com/labstack/echo/pull/2608 +* Remove maxparam dependence from Context by @aldas in https://github.com/labstack/echo/pull/2611 +* When route is registered with empty path it is normalized to `/`. by @aldas in https://github.com/labstack/echo/pull/2616 +* proxy middleware should use httputil.ReverseProxy for SSE requests by @aldas in https://github.com/labstack/echo/pull/2624 + + +## v4.11.4 - 2023-12-20 + +**Security** + +* Upgrade golang.org/x/crypto to v0.17.0 to fix vulnerability [issue](https://pkg.go.dev/vuln/GO-2023-2402) [#2562](https://github.com/labstack/echo/pull/2562) + +**Enhancements** + +* Update deps and mark Go version to 1.18 as this is what golang.org/x/* use [#2563](https://github.com/labstack/echo/pull/2563) +* Request logger: add example for Slog https://pkg.go.dev/log/slog [#2543](https://github.com/labstack/echo/pull/2543) + + +## v4.11.3 - 2023-11-07 + +**Security** + +* 'c.Attachment' and 'c.Inline' should escape filename in 'Content-Disposition' header to avoid 'Reflect File Download' vulnerability. [#2541](https://github.com/labstack/echo/pull/2541) + +**Enhancements** + +* Tests: refactor context tests to be separate functions [#2540](https://github.com/labstack/echo/pull/2540) +* Proxy middleware: reuse echo request context [#2537](https://github.com/labstack/echo/pull/2537) +* Mark unmarshallable yaml struct tags as ignored [#2536](https://github.com/labstack/echo/pull/2536) + + +## v4.11.2 - 2023-10-11 + +**Security** + +* Bump golang.org/x/net to prevent CVE-2023-39325 / CVE-2023-44487 HTTP/2 Rapid Reset Attack [#2527](https://github.com/labstack/echo/pull/2527) +* fix(sec): randomString bias introduced by #2490 [#2492](https://github.com/labstack/echo/pull/2492) +* CSRF/RequestID mw: switch math/random usage to crypto/random [#2490](https://github.com/labstack/echo/pull/2490) + +**Enhancements** + +* Delete unused context in body_limit.go [#2483](https://github.com/labstack/echo/pull/2483) +* Use Go 1.21 in CI [#2505](https://github.com/labstack/echo/pull/2505) +* Fix some typos [#2511](https://github.com/labstack/echo/pull/2511) +* Allow CORS middleware to send Access-Control-Max-Age: 0 [#2518](https://github.com/labstack/echo/pull/2518) +* Bump dependancies [#2522](https://github.com/labstack/echo/pull/2522) + ## v4.11.1 - 2023-07-16 **Fixes** diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..decbf0792 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,99 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## About This Project + +Echo is a high performance, minimalist Go web framework. This is the main repository for Echo v4, which is available as a Go module at `github.com/labstack/echo/v4`. + +## Development Commands + +The project uses a Makefile for common development tasks: + +- `make check` - Run linting, vetting, and race condition tests (default target) +- `make init` - Install required linting tools (golint, staticcheck) +- `make lint` - Run staticcheck and golint +- `make vet` - Run go vet +- `make test` - Run short tests +- `make race` - Run tests with race detector +- `make benchmark` - Run benchmarks + +Example commands for development: +```bash +# Setup development environment +make init + +# Run all checks (lint, vet, race) +make check + +# Run specific tests +go test ./middleware/... +go test -race ./... + +# Run benchmarks +make benchmark +``` + +## Code Architecture + +### Core Components + +**Echo Instance (`echo.go`)** +- The `Echo` struct is the top-level framework instance +- Contains router, middleware stacks, and server configuration +- Not goroutine-safe for mutations after server start + +**Context (`context.go`)** +- The `Context` interface represents HTTP request/response context +- Provides methods for request/response handling, path parameters, data binding +- Core abstraction for request processing + +**Router (`router.go`)** +- Radix tree-based HTTP router with smart route prioritization +- Supports static routes, parameterized routes (`/users/:id`), and wildcard routes (`/static/*`) +- Each HTTP method has its own routing tree + +**Middleware (`middleware/`)** +- Extensive middleware system with 50+ built-in middlewares +- Middleware can be applied at Echo, Group, or individual route level +- Common middleware: Logger, Recover, CORS, JWT, Rate Limiting, etc. + +### Key Patterns + +**Middleware Chain** +- Pre-middleware runs before routing +- Regular middleware runs after routing but before handlers +- Middleware functions have signature `func(next echo.HandlerFunc) echo.HandlerFunc` + +**Route Groups** +- Routes can be grouped with common prefixes and middleware +- Groups support nested sub-groups +- Defined in `group.go` + +**Data Binding** +- Automatic binding of request data (JSON, XML, form) to Go structs +- Implemented in `binder.go` with support for custom binders + +**Error Handling** +- Centralized error handling via `HTTPErrorHandler` +- Automatic panic recovery with stack traces + +## File Organization + +- Root directory: Core Echo functionality (echo.go, context.go, router.go, etc.) +- `middleware/`: All built-in middleware implementations +- `_test/`: Test fixtures and utilities +- `_fixture/`: Test data files + +## Code Style + +- Go code uses tabs for indentation (per .editorconfig) +- Follows standard Go conventions and formatting +- Uses gofmt, golint, and staticcheck for code quality + +## Testing + +- Standard Go testing with `testing` package +- Tests include unit tests, integration tests, and benchmarks +- Race condition testing is required (`make race`) +- Test files follow `*_test.go` naming convention \ No newline at end of file diff --git a/Makefile b/Makefile index 6aff6a89f..cbd78f1bf 100644 --- a/Makefile +++ b/Makefile @@ -31,6 +31,7 @@ benchmark: ## Run benchmarks help: ## Display this help screen @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -goversion ?= "1.17" -test_version: ## Run tests inside Docker with given version (defaults to 1.17 oldest supported). Example: make test_version goversion=1.17 - @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" +goversion ?= "1.22" +docker_user ?= "1000" +test_version: ## Run tests inside Docker with given version (defaults to 1.22 oldest supported). Example: make test_version goversion=1.22 + @docker run --rm -it --user $(docker_user) -e HOME=/tmp -e GOCACHE=/tmp/go-cache -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "mkdir -p /tmp/go-cache /tmp/.cache && cd /project && make init check" diff --git a/README.md b/README.md index ea8f30f64..5a920e875 100644 --- a/README.md +++ b/README.md @@ -1,28 +1,24 @@ - - [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge) [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4) [![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo) -[![Build Status](http://img.shields.io/travis/labstack/echo.svg?style=flat-square)](https://travis-ci.org/labstack/echo) +[![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/labstack/echo/echo.yml?style=flat-square)](https://github.com/labstack/echo/actions) [![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo) [![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://github.com/labstack/echo/discussions) [![Twitter](https://img.shields.io/badge/twitter-@labstack-55acee.svg?style=flat-square)](https://twitter.com/labstack) [![License](http://img.shields.io/badge/license-mit-blue.svg?style=flat-square)](https://raw.githubusercontent.com/labstack/echo/master/LICENSE) -## Supported Go versions +## Echo -Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with -older versions. +High performance, extensible, minimalist Go web framework. -As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules). -Therefore a Go version capable of understanding /vN suffixed imports is required: +* [Official website](https://echo.labstack.com) +* [Quick start](https://echo.labstack.com/docs/quick-start) +* [Middlewares](https://echo.labstack.com/docs/category/middleware) -Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended -way of using Echo going forward. +Help and questions: [Github Discussions](https://github.com/labstack/echo/discussions) -For older versions, please use the latest v3 tag. -## Feature Overview +### Feature Overview - Optimized HTTP router which smartly prioritize routes - Build robust and scalable RESTful APIs @@ -38,16 +34,17 @@ For older versions, please use the latest v3 tag. - Automatic TLS via Let’s Encrypt - HTTP/2 support -## Benchmarks - -Date: 2020/11/11
-Source: https://github.com/vishr/web-framework-benchmark
-Lower is better! +## Sponsors - - +
+ + encore icon + Encore – the platform for building Go-based cloud backends + +
+
-The benchmarks above were run on an Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz +Click [here](https://github.com/sponsors/labstack) for more information on sponsorship. ## [Guide](https://echo.labstack.com/guide) @@ -57,6 +54,7 @@ The benchmarks above were run on an Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz // go get github.com/labstack/echo/{version} go get github.com/labstack/echo/v4 ``` +Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions. ### Example @@ -66,6 +64,7 @@ package main import ( "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" + "log/slog" "net/http" ) @@ -81,7 +80,9 @@ func main() { e.GET("/", hello) // Start server - e.Logger.Fatal(e.Start(":1323")) + if err := e.Start(":8080"); err != nil && !errors.Is(err, http.ErrServerClosed) { + slog.Error("failed to start server", "error", err) + } } // Handler @@ -117,10 +118,6 @@ of middlewares in this list. Please send a PR to add your own library here. -## Help - -- [Forum](https://github.com/labstack/echo/discussions) - ## Contribute **Use issues for everything** diff --git a/bind.go b/bind.go index 374a2aec5..1d4fe6f0a 100644 --- a/bind.go +++ b/bind.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( @@ -5,31 +8,45 @@ import ( "encoding/xml" "errors" "fmt" + "mime/multipart" "net/http" "reflect" "strconv" "strings" + "time" ) -type ( - // Binder is the interface that wraps the Bind method. - Binder interface { - Bind(i interface{}, c Context) error - } +// Binder is the interface that wraps the Bind method. +type Binder interface { + Bind(i interface{}, c Context) error +} - // DefaultBinder is the default implementation of the Binder interface. - DefaultBinder struct{} +// DefaultBinder is the default implementation of the Binder interface. +type DefaultBinder struct{} - // BindUnmarshaler is the interface used to wrap the UnmarshalParam method. - // Types that don't implement this, but do implement encoding.TextUnmarshaler - // will use that interface instead. - BindUnmarshaler interface { - // UnmarshalParam decodes and assigns a value from an form or query param. - UnmarshalParam(param string) error - } -) +// BindUnmarshaler is the interface used to wrap the UnmarshalParam method. +// Types that don't implement this, but do implement encoding.TextUnmarshaler +// will use that interface instead. +type BindUnmarshaler interface { + // UnmarshalParam decodes and assigns a value from an form or query param. + UnmarshalParam(param string) error +} + +// bindMultipleUnmarshaler is used by binder to unmarshal multiple values from request at once to +// type implementing this interface. For example request could have multiple query fields `?a=1&a=2&b=test` in that case +// for `a` following slice `["1", "2"] will be passed to unmarshaller. +type bindMultipleUnmarshaler interface { + UnmarshalParams(params []string) error +} // BindPathParams binds path params to bindable object +// +// Time format support: time.Time fields can use `format` tags to specify custom parsing layouts. +// Example: `param:"created" format:"2006-01-02T15:04"` for datetime-local format +// Example: `param:"date" format:"2006-01-02"` for date format +// Uses Go's standard time format reference time: Mon Jan 2 15:04:05 MST 2006 +// Works with form data, query parameters, and path parameters (not JSON body) +// Falls back to default time.Time parsing if no format tag is specified func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { names := c.ParamNames() values := c.ParamValues() @@ -37,7 +54,7 @@ func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { for i, name := range names { params[name] = []string{values[i]} } - if err := b.bindData(i, params, "param"); err != nil { + if err := b.bindData(i, params, "param", nil); err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } return nil @@ -45,7 +62,7 @@ func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { // BindQueryParams binds query params to bindable object func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error { - if err := b.bindData(i, c.QueryParams(), "query"); err != nil { + if err := b.bindData(i, c.QueryParams(), "query", nil); err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } return nil @@ -62,9 +79,12 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { return } - ctype := req.Header.Get(HeaderContentType) - switch { - case strings.HasPrefix(ctype, MIMEApplicationJSON): + // mediatype is found like `mime.ParseMediaType()` does it + base, _, _ := strings.Cut(req.Header.Get(HeaderContentType), ";") + mediatype := strings.TrimSpace(base) + + switch mediatype { + case MIMEApplicationJSON: if err = c.Echo().JSONSerializer.Deserialize(c, i); err != nil { switch err.(type) { case *HTTPError: @@ -73,7 +93,7 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } } - case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML): + case MIMEApplicationXML, MIMETextXML: if err = xml.NewDecoder(req.Body).Decode(i); err != nil { if ute, ok := err.(*xml.UnsupportedTypeError); ok { return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err) @@ -82,12 +102,20 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { } return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } - case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): + case MIMEApplicationForm: params, err := c.FormParams() if err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } - if err = b.bindData(i, params, "form"); err != nil { + if err = b.bindData(i, params, "form", nil); err != nil { + return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + } + case MIMEMultipartForm: + params, err := c.MultipartForm() + if err != nil { + return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + } + if err = b.bindData(i, params.Value, "form", params.File); err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } default: @@ -98,7 +126,7 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { // BindHeaders binds HTTP headers to a bindable object func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error { - if err := b.bindData(i, c.Request().Header, "header"); err != nil { + if err := b.bindData(i, c.Request().Header, "header", nil); err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } return nil @@ -124,17 +152,41 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { } // bindData will bind data ONLY fields in destination struct that have EXPLICIT tag -func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string) error { - if destination == nil || len(data) == 0 { +func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string, dataFiles map[string][]*multipart.FileHeader) error { + if destination == nil || (len(data) == 0 && len(dataFiles) == 0) { return nil } + hasFiles := len(dataFiles) > 0 typ := reflect.TypeOf(destination).Elem() val := reflect.ValueOf(destination).Elem() - // Map - if typ.Kind() == reflect.Map { + // Support binding to limited Map destinations: + // - map[string][]string, + // - map[string]string <-- (binds first value from data slice) + // - map[string]interface{} + // You are better off binding to struct but there are user who want this map feature. Source of data for these cases are: + // params,query,header,form as these sources produce string values, most of the time slice of strings, actually. + if typ.Kind() == reflect.Map && typ.Key().Kind() == reflect.String { + k := typ.Elem().Kind() + isElemInterface := k == reflect.Interface + isElemString := k == reflect.String + isElemSliceOfStrings := k == reflect.Slice && typ.Elem().Elem().Kind() == reflect.String + if !(isElemSliceOfStrings || isElemString || isElemInterface) { + return nil + } + if val.IsNil() { + val.Set(reflect.MakeMap(typ)) + } for k, v := range data { - val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0])) + if isElemString { + val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0])) + } else if isElemInterface { + // To maintain backward compatibility, we always bind to the first string value + // and not the slice of strings when dealing with map[string]interface{}{} + val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0])) + } else { + val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v)) + } } return nil } @@ -148,7 +200,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri return errors.New("binding element must be a struct") } - for i := 0; i < typ.NumField(); i++ { + for i := 0; i < typ.NumField(); i++ { // iterate over all destination fields typeField := typ.Field(i) structField := val.Field(i) if typeField.Anonymous { @@ -161,16 +213,16 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri } structFieldKind := structField.Kind() inputFieldName := typeField.Tag.Get(tag) - if typeField.Anonymous && structField.Kind() == reflect.Struct && inputFieldName != "" { + if typeField.Anonymous && structFieldKind == reflect.Struct && inputFieldName != "" { // if anonymous struct with query/param/form tags, report an error return errors.New("query/param/form tags are not allowed with anonymous struct field") } if inputFieldName == "" { - // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags). - // structs that implement BindUnmarshaler are binded only when they have explicit tag + // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contain fields with tags). + // structs that implement BindUnmarshaler are bound only when they have explicit tag if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { - if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil { + if err := b.bindData(structField.Addr().Interface(), data, tag, dataFiles); err != nil { return err } } @@ -178,10 +230,20 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri continue } + if hasFiles { + if ok, err := isFieldMultipartFile(structField.Type()); err != nil { + return err + } else if ok { + if ok := setMultipartFileHeaderTypes(structField, inputFieldName, dataFiles); ok { + continue + } + } + } + inputValue, exists := data[inputFieldName] if !exists { - // Go json.Unmarshal supports case insensitive binding. However the - // url params are bound case sensitive which is inconsistent. To + // Go json.Unmarshal supports case-insensitive binding. However the + // url params are bound case-sensitive which is inconsistent. To // fix this we must check all of the map values in a // case-insensitive search. for k, v := range data { @@ -197,27 +259,47 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri continue } - // Call this first, in case we're dealing with an alias to an array type - if ok, err := unmarshalField(typeField.Type.Kind(), inputValue[0], structField); ok { + // NOTE: algorithm here is not particularly sophisticated. It probably does not work with absurd types like `**[]*int` + // but it is smart enough to handle niche cases like `*int`,`*[]string`,`[]*int` . + + // try unmarshalling first, in case we're dealing with an alias to an array type + if ok, err := unmarshalInputsToField(typeField.Type.Kind(), inputValue, structField); ok { if err != nil { return err } continue } - numElems := len(inputValue) - if structFieldKind == reflect.Slice && numElems > 0 { + formatTag := typeField.Tag.Get("format") + if ok, err := unmarshalInputToField(typeField.Type.Kind(), inputValue[0], structField, formatTag); ok { + if err != nil { + return err + } + continue + } + + // we could be dealing with pointer to slice `*[]string` so dereference it. There are weird OpenAPI generators + // that could create struct fields like that. + if structFieldKind == reflect.Pointer { + structFieldKind = structField.Elem().Kind() + structField = structField.Elem() + } + + if structFieldKind == reflect.Slice { sliceOf := structField.Type().Elem().Kind() + numElems := len(inputValue) slice := reflect.MakeSlice(structField.Type(), numElems, numElems) for j := 0; j < numElems; j++ { if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil { return err } } - val.Field(i).Set(slice) - } else if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil { - return err + structField.Set(slice) + continue + } + if err := setWithProperType(structFieldKind, inputValue[0], structField); err != nil { + return err } } return nil @@ -225,7 +307,8 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error { // But also call it here, in case we're dealing with an array of BindUnmarshalers - if ok, err := unmarshalField(valueKind, val, structField); ok { + // Note: format tag not available in this context, so empty string is passed + if ok, err := unmarshalInputToField(valueKind, val, structField, ""); ok { return err } @@ -266,33 +349,52 @@ func setWithProperType(valueKind reflect.Kind, val string, structField reflect.V return nil } -func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) { - switch valueKind { - case reflect.Ptr: - return unmarshalFieldPtr(val, field) - default: - return unmarshalFieldNonPtr(val, field) +func unmarshalInputsToField(valueKind reflect.Kind, values []string, field reflect.Value) (bool, error) { + if valueKind == reflect.Ptr { + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + field = field.Elem() } -} -func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) { fieldIValue := field.Addr().Interface() - if unmarshaler, ok := fieldIValue.(BindUnmarshaler); ok { - return true, unmarshaler.UnmarshalParam(value) + unmarshaler, ok := fieldIValue.(bindMultipleUnmarshaler) + if !ok { + return false, nil } - if unmarshaler, ok := fieldIValue.(encoding.TextUnmarshaler); ok { - return true, unmarshaler.UnmarshalText([]byte(value)) + return true, unmarshaler.UnmarshalParams(values) +} + +func unmarshalInputToField(valueKind reflect.Kind, val string, field reflect.Value, formatTag string) (bool, error) { + if valueKind == reflect.Ptr { + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + field = field.Elem() } - return false, nil -} + fieldIValue := field.Addr().Interface() + + // Handle time.Time with custom format tag + if formatTag != "" { + if _, isTime := fieldIValue.(*time.Time); isTime { + t, err := time.Parse(formatTag, val) + if err != nil { + return true, err + } + field.Set(reflect.ValueOf(t)) + return true, nil + } + } -func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) { - if field.IsNil() { - // Initialize the pointer to a nil value - field.Set(reflect.New(field.Type().Elem())) + switch unmarshaler := fieldIValue.(type) { + case BindUnmarshaler: + return true, unmarshaler.UnmarshalParam(val) + case encoding.TextUnmarshaler: + return true, unmarshaler.UnmarshalText([]byte(val)) } - return unmarshalFieldNonPtr(value, field.Elem()) + + return false, nil } func setIntField(value string, bitSize int, field reflect.Value) error { @@ -338,3 +440,50 @@ func setFloatField(value string, bitSize int, field reflect.Value) error { } return err } + +var ( + // NOT supported by bind as you can NOT check easily empty struct being actual file or not + multipartFileHeaderType = reflect.TypeFor[multipart.FileHeader]() + // supported by bind as you can check by nil value if file existed or not + multipartFileHeaderPointerType = reflect.TypeFor[*multipart.FileHeader]() + multipartFileHeaderSliceType = reflect.TypeFor[[]multipart.FileHeader]() + multipartFileHeaderPointerSliceType = reflect.TypeFor[[]*multipart.FileHeader]() +) + +func isFieldMultipartFile(field reflect.Type) (bool, error) { + switch field { + case multipartFileHeaderPointerType, + multipartFileHeaderSliceType, + multipartFileHeaderPointerSliceType: + return true, nil + case multipartFileHeaderType: + return true, errors.New("binding to multipart.FileHeader struct is not supported, use pointer to struct") + default: + return false, nil + } +} + +func setMultipartFileHeaderTypes(structField reflect.Value, inputFieldName string, files map[string][]*multipart.FileHeader) bool { + fileHeaders := files[inputFieldName] + if len(fileHeaders) == 0 { + return false + } + + result := true + switch structField.Type() { + case multipartFileHeaderPointerSliceType: + structField.Set(reflect.ValueOf(fileHeaders)) + case multipartFileHeaderSliceType: + headers := make([]multipart.FileHeader, len(fileHeaders)) + for i, fileHeader := range fileHeaders { + headers[i] = *fileHeader + } + structField.Set(reflect.ValueOf(headers)) + case multipartFileHeaderPointerType: + structField.Set(reflect.ValueOf(fileHeaders[0])) + default: + result = false + } + + return result +} diff --git a/bind_test.go b/bind_test.go index c35283dcf..3e387ba19 100644 --- a/bind_test.go +++ b/bind_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( @@ -5,10 +8,12 @@ import ( "encoding/json" "encoding/xml" "errors" + "fmt" "io" "mime/multipart" "net/http" "net/http/httptest" + "net/http/httputil" "net/url" "reflect" "strconv" @@ -19,91 +24,91 @@ import ( "github.com/stretchr/testify/assert" ) -type ( - bindTestStruct struct { - I int - PtrI *int - I8 int8 - PtrI8 *int8 - I16 int16 - PtrI16 *int16 - I32 int32 - PtrI32 *int32 - I64 int64 - PtrI64 *int64 - UI uint - PtrUI *uint - UI8 uint8 - PtrUI8 *uint8 - UI16 uint16 - PtrUI16 *uint16 - UI32 uint32 - PtrUI32 *uint32 - UI64 uint64 - PtrUI64 *uint64 - B bool - PtrB *bool - F32 float32 - PtrF32 *float32 - F64 float64 - PtrF64 *float64 - S string - PtrS *string - cantSet string - DoesntExist string - GoT time.Time - GoTptr *time.Time - T Timestamp - Tptr *Timestamp - SA StringArray - } - bindTestStructWithTags struct { - I int `json:"I" form:"I"` - PtrI *int `json:"PtrI" form:"PtrI"` - I8 int8 `json:"I8" form:"I8"` - PtrI8 *int8 `json:"PtrI8" form:"PtrI8"` - I16 int16 `json:"I16" form:"I16"` - PtrI16 *int16 `json:"PtrI16" form:"PtrI16"` - I32 int32 `json:"I32" form:"I32"` - PtrI32 *int32 `json:"PtrI32" form:"PtrI32"` - I64 int64 `json:"I64" form:"I64"` - PtrI64 *int64 `json:"PtrI64" form:"PtrI64"` - UI uint `json:"UI" form:"UI"` - PtrUI *uint `json:"PtrUI" form:"PtrUI"` - UI8 uint8 `json:"UI8" form:"UI8"` - PtrUI8 *uint8 `json:"PtrUI8" form:"PtrUI8"` - UI16 uint16 `json:"UI16" form:"UI16"` - PtrUI16 *uint16 `json:"PtrUI16" form:"PtrUI16"` - UI32 uint32 `json:"UI32" form:"UI32"` - PtrUI32 *uint32 `json:"PtrUI32" form:"PtrUI32"` - UI64 uint64 `json:"UI64" form:"UI64"` - PtrUI64 *uint64 `json:"PtrUI64" form:"PtrUI64"` - B bool `json:"B" form:"B"` - PtrB *bool `json:"PtrB" form:"PtrB"` - F32 float32 `json:"F32" form:"F32"` - PtrF32 *float32 `json:"PtrF32" form:"PtrF32"` - F64 float64 `json:"F64" form:"F64"` - PtrF64 *float64 `json:"PtrF64" form:"PtrF64"` - S string `json:"S" form:"S"` - PtrS *string `json:"PtrS" form:"PtrS"` - cantSet string - DoesntExist string `json:"DoesntExist" form:"DoesntExist"` - GoT time.Time `json:"GoT" form:"GoT"` - GoTptr *time.Time `json:"GoTptr" form:"GoTptr"` - T Timestamp `json:"T" form:"T"` - Tptr *Timestamp `json:"Tptr" form:"Tptr"` - SA StringArray `json:"SA" form:"SA"` - } - Timestamp time.Time - TA []Timestamp - StringArray []string - Struct struct { - Foo string - } - Bar struct { - Baz int `json:"baz" query:"baz"` - } -) +type bindTestStruct struct { + I int + PtrI *int + I8 int8 + PtrI8 *int8 + I16 int16 + PtrI16 *int16 + I32 int32 + PtrI32 *int32 + I64 int64 + PtrI64 *int64 + UI uint + PtrUI *uint + UI8 uint8 + PtrUI8 *uint8 + UI16 uint16 + PtrUI16 *uint16 + UI32 uint32 + PtrUI32 *uint32 + UI64 uint64 + PtrUI64 *uint64 + B bool + PtrB *bool + F32 float32 + PtrF32 *float32 + F64 float64 + PtrF64 *float64 + S string + PtrS *string + cantSet string + DoesntExist string + GoT time.Time + GoTptr *time.Time + T Timestamp + Tptr *Timestamp + SA StringArray +} + +type bindTestStructWithTags struct { + I int `json:"I" form:"I"` + PtrI *int `json:"PtrI" form:"PtrI"` + I8 int8 `json:"I8" form:"I8"` + PtrI8 *int8 `json:"PtrI8" form:"PtrI8"` + I16 int16 `json:"I16" form:"I16"` + PtrI16 *int16 `json:"PtrI16" form:"PtrI16"` + I32 int32 `json:"I32" form:"I32"` + PtrI32 *int32 `json:"PtrI32" form:"PtrI32"` + I64 int64 `json:"I64" form:"I64"` + PtrI64 *int64 `json:"PtrI64" form:"PtrI64"` + UI uint `json:"UI" form:"UI"` + PtrUI *uint `json:"PtrUI" form:"PtrUI"` + UI8 uint8 `json:"UI8" form:"UI8"` + PtrUI8 *uint8 `json:"PtrUI8" form:"PtrUI8"` + UI16 uint16 `json:"UI16" form:"UI16"` + PtrUI16 *uint16 `json:"PtrUI16" form:"PtrUI16"` + UI32 uint32 `json:"UI32" form:"UI32"` + PtrUI32 *uint32 `json:"PtrUI32" form:"PtrUI32"` + UI64 uint64 `json:"UI64" form:"UI64"` + PtrUI64 *uint64 `json:"PtrUI64" form:"PtrUI64"` + B bool `json:"B" form:"B"` + PtrB *bool `json:"PtrB" form:"PtrB"` + F32 float32 `json:"F32" form:"F32"` + PtrF32 *float32 `json:"PtrF32" form:"PtrF32"` + F64 float64 `json:"F64" form:"F64"` + PtrF64 *float64 `json:"PtrF64" form:"PtrF64"` + S string `json:"S" form:"S"` + PtrS *string `json:"PtrS" form:"PtrS"` + cantSet string + DoesntExist string `json:"DoesntExist" form:"DoesntExist"` + GoT time.Time `json:"GoT" form:"GoT"` + GoTptr *time.Time `json:"GoTptr" form:"GoTptr"` + T Timestamp `json:"T" form:"T"` + Tptr *Timestamp `json:"Tptr" form:"Tptr"` + SA StringArray `json:"SA" form:"SA"` +} + +type Timestamp time.Time +type TA []Timestamp +type StringArray []string +type Struct struct { + Foo string +} +type Bar struct { + Baz int `json:"baz" query:"baz"` +} func (t *Timestamp) UnmarshalParam(src string) error { ts, err := time.Parse(time.RFC3339, src) @@ -164,6 +169,11 @@ var values = map[string][]string{ "ST": {"bar"}, } +// ptr return pointer to value. This is useful as `v := []*int8{&int8(1)}` will not compile +func ptr[T any](value T) *T { + return &value +} + func TestToMultipleFields(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?id=1&ID=2", nil) @@ -429,10 +439,113 @@ func TestBindUnsupportedMediaType(t *testing.T) { testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) } +func TestDefaultBinder_bindDataToMap(t *testing.T) { + exampleData := map[string][]string{ + "multiple": {"1", "2"}, + "single": {"3"}, + } + + t.Run("ok, bind to map[string]string", func(t *testing.T) { + dest := map[string]string{} + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, + map[string]string{ + "multiple": "1", + "single": "3", + }, + dest, + ) + }) + + t.Run("ok, bind to map[string]string with nil map", func(t *testing.T) { + var dest map[string]string + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, + map[string]string{ + "multiple": "1", + "single": "3", + }, + dest, + ) + }) + + t.Run("ok, bind to map[string][]string", func(t *testing.T) { + dest := map[string][]string{} + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, + map[string][]string{ + "multiple": {"1", "2"}, + "single": {"3"}, + }, + dest, + ) + }) + + t.Run("ok, bind to map[string][]string with nil map", func(t *testing.T) { + var dest map[string][]string + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, + map[string][]string{ + "multiple": {"1", "2"}, + "single": {"3"}, + }, + dest, + ) + }) + + t.Run("ok, bind to map[string]interface", func(t *testing.T) { + dest := map[string]interface{}{} + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, + map[string]interface{}{ + "multiple": "1", + "single": "3", + }, + dest, + ) + }) + + t.Run("ok, bind to map[string]interface with nil map", func(t *testing.T) { + var dest map[string]interface{} + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, + map[string]interface{}{ + "multiple": "1", + "single": "3", + }, + dest, + ) + }) + + t.Run("ok, bind to map[string]int skips", func(t *testing.T) { + dest := map[string]int{} + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, map[string]int{}, dest) + }) + + t.Run("ok, bind to map[string]int skips with nil map", func(t *testing.T) { + var dest map[string]int + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, map[string]int(nil), dest) + }) + + t.Run("ok, bind to map[string][]int skips", func(t *testing.T) { + dest := map[string][]int{} + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, map[string][]int{}, dest) + }) + + t.Run("ok, bind to map[string][]int skips with nil map", func(t *testing.T) { + var dest map[string][]int + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, map[string][]int(nil), dest) + }) +} + func TestBindbindData(t *testing.T) { ts := new(bindTestStruct) b := new(DefaultBinder) - err := b.bindData(ts, values, "form") + err := b.bindData(ts, values, "form", nil) assert.NoError(t, err) assert.Equal(t, 0, ts.I) @@ -547,49 +660,6 @@ func TestBindSetWithProperType(t *testing.T) { assert.Error(t, setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) } -func TestBindSetFields(t *testing.T) { - - ts := new(bindTestStruct) - val := reflect.ValueOf(ts).Elem() - // Int - if assert.NoError(t, setIntField("5", 0, val.FieldByName("I"))) { - assert.Equal(t, 5, ts.I) - } - if assert.NoError(t, setIntField("", 0, val.FieldByName("I"))) { - assert.Equal(t, 0, ts.I) - } - - // Uint - if assert.NoError(t, setUintField("10", 0, val.FieldByName("UI"))) { - assert.Equal(t, uint(10), ts.UI) - } - if assert.NoError(t, setUintField("", 0, val.FieldByName("UI"))) { - assert.Equal(t, uint(0), ts.UI) - } - - // Float - if assert.NoError(t, setFloatField("15.5", 0, val.FieldByName("F32"))) { - assert.Equal(t, float32(15.5), ts.F32) - } - if assert.NoError(t, setFloatField("", 0, val.FieldByName("F32"))) { - assert.Equal(t, float32(0.0), ts.F32) - } - - // Bool - if assert.NoError(t, setBoolField("true", val.FieldByName("B"))) { - assert.Equal(t, true, ts.B) - } - if assert.NoError(t, setBoolField("", val.FieldByName("B"))) { - assert.Equal(t, false, ts.B) - } - - ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T")) - if assert.NoError(t, err) { - assert.Equal(t, ok, true) - assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) - } -} - func BenchmarkBindbindDataWithTags(b *testing.B) { b.ReportAllocs() ts := new(bindTestStructWithTags) @@ -597,7 +667,7 @@ func BenchmarkBindbindDataWithTags(b *testing.B) { var err error b.ResetTimer() for i := 0; i < b.N; i++ { - err = binder.bindData(ts, values, "form") + err = binder.bindData(ts, values, "form", nil) } assert.NoError(b, err) assertBindTestStruct(b, (*bindTestStruct)(ts)) @@ -683,7 +753,7 @@ func testBindError(t *testing.T, r io.Reader, ctype string, expectedInternal err } func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { - // tests to check binding behaviour when multiple sources path params, query params and request body are in use + // tests to check binding behaviour when multiple sources (path params, query params and request body) are in use // binding is done in steps and one source could overwrite previous source binded data // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed @@ -853,7 +923,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { } func TestDefaultBinder_BindBody(t *testing.T) { - // tests to check binding behaviour when multiple sources path params, query params and request body are in use + // tests to check binding behaviour when multiple sources (path params, query params and request body) are in use // generally when binding from request body - URL and path params are ignored - unless form is being binded. // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed @@ -872,6 +942,7 @@ func TestDefaultBinder_BindBody(t *testing.T) { givenMethod string givenContentType string whenNoPathParams bool + whenChunkedBody bool whenBindTarget interface{} expect interface{} expectError string @@ -991,6 +1062,35 @@ func TestDefaultBinder_BindBody(t *testing.T) { expect: &Node{ID: 0, Node: ""}, expectError: "code=415, message=Unsupported Media Type", }, + // FIXME: REASON in Go 1.24 and earlier http.NoBody would result ContentLength=-1 + // but as of Go 1.25 http.NoBody would result ContentLength=0 + // I am too lazy to bother documenting this as 2 version specific tests. + //{ + // name: "nok, JSON POST with http.NoBody", + // givenURL: "/api/real_node/endpoint?node=xxx", + // givenMethod: http.MethodPost, + // givenContentType: MIMEApplicationJSON, + // givenContent: http.NoBody, + // expect: &Node{ID: 0, Node: ""}, + // expectError: "code=400, message=EOF, internal=EOF", + //}, + { + name: "ok, JSON POST with empty body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationJSON, + givenContent: strings.NewReader(""), + expect: &Node{ID: 0, Node: ""}, + }, + { + name: "ok, JSON POST bind to struct with: path + query + chunked body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationJSON, + givenContent: httputil.NewChunkedReader(strings.NewReader("18\r\n" + `{"id": 1, "node": "zzz"}` + "\r\n0\r\n")), + whenChunkedBody: true, + expect: &Node{ID: 1, Node: "zzz"}, + }, } for _, tc := range testCases { @@ -1006,6 +1106,10 @@ func TestDefaultBinder_BindBody(t *testing.T) { case MIMEApplicationJSON: req.Header.Set(HeaderContentType, MIMEApplicationJSON) } + if tc.whenChunkedBody { + req.ContentLength = -1 + req.TransferEncoding = append(req.TransferEncoding, "chunked") + } rec := httptest.NewRecorder() c := e.NewContext(req, rec) @@ -1032,3 +1136,551 @@ func TestDefaultBinder_BindBody(t *testing.T) { }) } } + +func testBindURL(queryString string, target any) error { + e := New() + req := httptest.NewRequest(http.MethodGet, queryString, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + return c.Bind(target) +} + +type unixTimestamp struct { + Time time.Time +} + +func (t *unixTimestamp) UnmarshalParam(param string) error { + n, err := strconv.ParseInt(param, 10, 64) + if err != nil { + return fmt.Errorf("'%s' is not an integer", param) + } + *t = unixTimestamp{Time: time.Unix(n, 0)} + return err +} + +type IntArrayA []int + +// UnmarshalParam converts value to *Int64Slice. This allows the API to accept +// a comma-separated list of integers as a query parameter. +func (i *IntArrayA) UnmarshalParam(value string) error { + var values = strings.Split(value, ",") + var numbers = make([]int, 0, len(values)) + + for _, v := range values { + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return fmt.Errorf("'%s' is not an integer", v) + } + + numbers = append(numbers, int(n)) + } + + *i = append(*i, numbers...) + return nil +} + +func TestBindUnmarshalParamExtras(t *testing.T) { + // this test documents how bind handles `BindUnmarshaler` interface: + // NOTE: BindUnmarshaler chooses first input value to be bound. + + t.Run("nok, unmarshalling fails", func(t *testing.T) { + result := struct { + V unixTimestamp `query:"t"` + }{} + err := testBindURL("/?t=xxxx", &result) + + assert.EqualError(t, err, "code=400, message='xxxx' is not an integer, internal='xxxx' is not an integer") + }) + + t.Run("ok, target is struct", func(t *testing.T) { + result := struct { + V unixTimestamp `query:"t"` + }{} + err := testBindURL("/?t=1710095540&t=1710095541", &result) + + assert.NoError(t, err) + expect := unixTimestamp{ + Time: time.Unix(1710095540, 0), + } + assert.Equal(t, expect, result.V) + }) + + t.Run("ok, target is an alias to slice and is nil, append only values from first", func(t *testing.T) { + result := struct { + V IntArrayA `query:"a"` + }{} + err := testBindURL("/?a=1,2,3&a=4,5,6", &result) + + assert.NoError(t, err) + assert.Equal(t, IntArrayA([]int{1, 2, 3}), result.V) + }) + + t.Run("ok, target is an alias to slice and is nil, single input", func(t *testing.T) { + result := struct { + V IntArrayA `query:"a"` + }{} + err := testBindURL("/?a=1,2", &result) + + assert.NoError(t, err) + assert.Equal(t, IntArrayA([]int{1, 2}), result.V) + }) + + t.Run("ok, target is pointer an alias to slice and is nil", func(t *testing.T) { + result := struct { + V *IntArrayA `query:"a"` + }{} + err := testBindURL("/?a=1&a=4,5,6", &result) + + assert.NoError(t, err) + var expected = IntArrayA([]int{1}) + assert.Equal(t, &expected, result.V) + }) + + t.Run("ok, target is pointer an alias to slice and is NOT nil", func(t *testing.T) { + result := struct { + V *IntArrayA `query:"a"` + }{} + result.V = new(IntArrayA) // NOT nil + + err := testBindURL("/?a=1&a=4,5,6", &result) + + assert.NoError(t, err) + var expected = IntArrayA([]int{1}) + assert.Equal(t, &expected, result.V) + }) +} + +type unixTimestampLast struct { + Time time.Time +} + +// this is silly example for `bindMultipleUnmarshaler` for type that uses last input value for unmarshalling +func (t *unixTimestampLast) UnmarshalParams(params []string) error { + lastInput := params[len(params)-1] + n, err := strconv.ParseInt(lastInput, 10, 64) + if err != nil { + return fmt.Errorf("'%s' is not an integer", lastInput) + } + *t = unixTimestampLast{Time: time.Unix(n, 0)} + return err +} + +type IntArrayB []int + +func (i *IntArrayB) UnmarshalParams(params []string) error { + var numbers = make([]int, 0, len(params)) + + for _, param := range params { + var values = strings.Split(param, ",") + for _, v := range values { + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return fmt.Errorf("'%s' is not an integer", v) + } + numbers = append(numbers, int(n)) + } + } + + *i = append(*i, numbers...) + return nil +} + +func TestBindUnmarshalParams(t *testing.T) { + // this test documents how bind handles `bindMultipleUnmarshaler` interface: + + t.Run("nok, unmarshalling fails", func(t *testing.T) { + result := struct { + V unixTimestampLast `query:"t"` + }{} + err := testBindURL("/?t=xxxx", &result) + + assert.EqualError(t, err, "code=400, message='xxxx' is not an integer, internal='xxxx' is not an integer") + }) + + t.Run("ok, target is struct", func(t *testing.T) { + result := struct { + V unixTimestampLast `query:"t"` + }{} + err := testBindURL("/?t=1710095540&t=1710095541", &result) + + assert.NoError(t, err) + expect := unixTimestampLast{ + Time: time.Unix(1710095541, 0), + } + assert.Equal(t, expect, result.V) + }) + + t.Run("ok, target is an alias to slice and is nil, append multiple inputs", func(t *testing.T) { + result := struct { + V IntArrayB `query:"a"` + }{} + err := testBindURL("/?a=1,2,3&a=4,5,6", &result) + + assert.NoError(t, err) + assert.Equal(t, IntArrayB([]int{1, 2, 3, 4, 5, 6}), result.V) + }) + + t.Run("ok, target is an alias to slice and is nil, single input", func(t *testing.T) { + result := struct { + V IntArrayB `query:"a"` + }{} + err := testBindURL("/?a=1,2", &result) + + assert.NoError(t, err) + assert.Equal(t, IntArrayB([]int{1, 2}), result.V) + }) + + t.Run("ok, target is pointer an alias to slice and is nil", func(t *testing.T) { + result := struct { + V *IntArrayB `query:"a"` + }{} + err := testBindURL("/?a=1&a=4,5,6", &result) + + assert.NoError(t, err) + var expected = IntArrayB([]int{1, 4, 5, 6}) + assert.Equal(t, &expected, result.V) + }) + + t.Run("ok, target is pointer an alias to slice and is NOT nil", func(t *testing.T) { + result := struct { + V *IntArrayB `query:"a"` + }{} + result.V = new(IntArrayB) // NOT nil + + err := testBindURL("/?a=1&a=4,5,6", &result) + assert.NoError(t, err) + var expected = IntArrayB([]int{1, 4, 5, 6}) + assert.Equal(t, &expected, result.V) + }) +} + +func TestBindInt8(t *testing.T) { + t.Run("nok, binding fails", func(t *testing.T) { + type target struct { + V int8 `query:"v"` + } + p := target{} + err := testBindURL("/?v=x&v=2", &p) + assert.EqualError(t, err, "code=400, message=strconv.ParseInt: parsing \"x\": invalid syntax, internal=strconv.ParseInt: parsing \"x\": invalid syntax") + }) + + t.Run("nok, int8 embedded in struct", func(t *testing.T) { + type target struct { + int8 `query:"v"` // embedded field is `Anonymous`. We can only set public fields + } + p := target{} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + assert.Equal(t, target{0}, p) + }) + + t.Run("nok, pointer to int8 embedded in struct", func(t *testing.T) { + type target struct { + *int8 `query:"v"` // embedded field is `Anonymous`. We can only set public fields + } + p := target{} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + + assert.Equal(t, target{int8: nil}, p) + }) + + t.Run("ok, bind int8 as struct field", func(t *testing.T) { + type target struct { + V int8 `query:"v"` + } + p := target{V: 127} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + assert.Equal(t, target{V: 1}, p) + }) + + t.Run("ok, bind pointer to int8 as struct field, value is nil", func(t *testing.T) { + type target struct { + V *int8 `query:"v"` + } + p := target{} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + assert.Equal(t, target{V: ptr(int8(1))}, p) + }) + + t.Run("ok, bind pointer to int8 as struct field, value is set", func(t *testing.T) { + type target struct { + V *int8 `query:"v"` + } + p := target{V: ptr(int8(127))} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + assert.Equal(t, target{V: ptr(int8(1))}, p) + }) + + t.Run("ok, bind int8 slice as struct field, value is nil", func(t *testing.T) { + type target struct { + V []int8 `query:"v"` + } + p := target{} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + assert.Equal(t, target{V: []int8{1, 2}}, p) + }) + + t.Run("ok, bind slice of int8 as struct field, value is set", func(t *testing.T) { + type target struct { + V []int8 `query:"v"` + } + p := target{V: []int8{111}} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + assert.Equal(t, target{V: []int8{1, 2}}, p) + }) + + t.Run("ok, bind slice of pointer to int8 as struct field, value is set", func(t *testing.T) { + type target struct { + V []*int8 `query:"v"` + } + p := target{V: []*int8{ptr(int8(127))}} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + assert.Equal(t, target{V: []*int8{ptr(int8(1)), ptr(int8(2))}}, p) + }) + + t.Run("ok, bind pointer to slice of int8 as struct field, value is set", func(t *testing.T) { + type target struct { + V *[]int8 `query:"v"` + } + p := target{V: &[]int8{111}} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + assert.Equal(t, target{V: &[]int8{1, 2}}, p) + }) +} + +func TestBindMultipartFormFiles(t *testing.T) { + file1 := createTestFormFile("file", "file1.txt") + file11 := createTestFormFile("file", "file11.txt") + file2 := createTestFormFile("file2", "file2.txt") + filesA := createTestFormFile("files", "filesA.txt") + filesB := createTestFormFile("files", "filesB.txt") + + t.Run("nok, can not bind to multipart file struct", func(t *testing.T) { + var target struct { + File multipart.FileHeader `form:"file"` + } + err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored + + assert.EqualError(t, err, "code=400, message=binding to multipart.FileHeader struct is not supported, use pointer to struct, internal=binding to multipart.FileHeader struct is not supported, use pointer to struct") + }) + + t.Run("ok, bind single multipart file to pointer to multipart file", func(t *testing.T) { + var target struct { + File *multipart.FileHeader `form:"file"` + } + err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored + + assert.NoError(t, err) + assertMultipartFileHeader(t, target.File, file1) + }) + + t.Run("ok, bind multiple multipart files to pointer to multipart file", func(t *testing.T) { + var target struct { + File *multipart.FileHeader `form:"file"` + } + err := bindMultipartFiles(t, &target, file1, file11) + + assert.NoError(t, err) + assertMultipartFileHeader(t, target.File, file1) // should choose first one + }) + + t.Run("ok, bind multiple multipart files to slice of multipart file", func(t *testing.T) { + var target struct { + Files []multipart.FileHeader `form:"files"` + } + err := bindMultipartFiles(t, &target, filesA, filesB, file1) + + assert.NoError(t, err) + + assert.Len(t, target.Files, 2) + assertMultipartFileHeader(t, &target.Files[0], filesA) + assertMultipartFileHeader(t, &target.Files[1], filesB) + }) + + t.Run("ok, bind multiple multipart files to slice of pointer to multipart file", func(t *testing.T) { + var target struct { + Files []*multipart.FileHeader `form:"files"` + } + err := bindMultipartFiles(t, &target, filesA, filesB, file1) + + assert.NoError(t, err) + + assert.Len(t, target.Files, 2) + assertMultipartFileHeader(t, target.Files[0], filesA) + assertMultipartFileHeader(t, target.Files[1], filesB) + }) +} + +type testFormFile struct { + Fieldname string + Filename string + Content []byte +} + +func createTestFormFile(formFieldName string, filename string) testFormFile { + return testFormFile{ + Fieldname: formFieldName, + Filename: filename, + Content: []byte(strings.Repeat(filename, 10)), + } +} + +func bindMultipartFiles(t *testing.T, target any, files ...testFormFile) error { + var body bytes.Buffer + mw := multipart.NewWriter(&body) + + for _, file := range files { + fw, err := mw.CreateFormFile(file.Fieldname, file.Filename) + assert.NoError(t, err) + + n, err := fw.Write(file.Content) + assert.NoError(t, err) + assert.Equal(t, len(file.Content), n) + } + + err := mw.Close() + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, "/", &body) + assert.NoError(t, err) + req.Header.Set("Content-Type", mw.FormDataContentType()) + + rec := httptest.NewRecorder() + + e := New() + c := e.NewContext(req, rec) + return c.Bind(target) +} + +func assertMultipartFileHeader(t *testing.T, fh *multipart.FileHeader, file testFormFile) { + assert.Equal(t, file.Filename, fh.Filename) + assert.Equal(t, int64(len(file.Content)), fh.Size) + fl, err := fh.Open() + assert.NoError(t, err) + body, err := io.ReadAll(fl) + assert.NoError(t, err) + assert.Equal(t, string(file.Content), string(body)) + err = fl.Close() + assert.NoError(t, err) +} + +func TestTimeFormatBinding(t *testing.T) { + type TestStruct struct { + DateTimeLocal time.Time `form:"datetime_local" format:"2006-01-02T15:04"` + Date time.Time `query:"date" format:"2006-01-02"` + CustomFormat time.Time `form:"custom" format:"01/02/2006 15:04:05"` + DefaultTime time.Time `form:"default_time"` // No format tag - should use default parsing + PtrTime *time.Time `query:"ptr_time" format:"2006-01-02"` + } + + testCases := []struct { + name string + contentType string + data string + queryParams string + expect TestStruct + expectError bool + }{ + { + name: "ok, datetime-local format binding", + contentType: MIMEApplicationForm, + data: "datetime_local=2023-12-25T14:30&default_time=2023-12-25T14:30:45Z", + expect: TestStruct{ + DateTimeLocal: time.Date(2023, 12, 25, 14, 30, 0, 0, time.UTC), + DefaultTime: time.Date(2023, 12, 25, 14, 30, 45, 0, time.UTC), + }, + }, + { + name: "ok, date format binding via query params", + queryParams: "?date=2023-01-15&ptr_time=2023-02-20", + expect: TestStruct{ + Date: time.Date(2023, 1, 15, 0, 0, 0, 0, time.UTC), + PtrTime: &time.Time{}, + }, + }, + { + name: "ok, custom format via form data", + contentType: MIMEApplicationForm, + data: "custom=12/25/2023 14:30:45", + expect: TestStruct{ + CustomFormat: time.Date(2023, 12, 25, 14, 30, 45, 0, time.UTC), + }, + }, + { + name: "nok, invalid format should fail", + contentType: MIMEApplicationForm, + data: "datetime_local=invalid-date", + expectError: true, + }, + { + name: "nok, wrong format should fail", + contentType: MIMEApplicationForm, + data: "datetime_local=2023-12-25", // Missing time part + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + var req *http.Request + + if tc.contentType == MIMEApplicationJSON { + req = httptest.NewRequest(http.MethodPost, "/"+tc.queryParams, strings.NewReader(tc.data)) + req.Header.Set(HeaderContentType, tc.contentType) + } else if tc.contentType == MIMEApplicationForm { + req = httptest.NewRequest(http.MethodPost, "/"+tc.queryParams, strings.NewReader(tc.data)) + req.Header.Set(HeaderContentType, tc.contentType) + } else { + req = httptest.NewRequest(http.MethodGet, "/"+tc.queryParams, nil) + } + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + var result TestStruct + err := c.Bind(&result) + + if tc.expectError { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + + // Check individual fields since time comparison can be tricky + if !tc.expect.DateTimeLocal.IsZero() { + assert.True(t, tc.expect.DateTimeLocal.Equal(result.DateTimeLocal), + "DateTimeLocal: expected %v, got %v", tc.expect.DateTimeLocal, result.DateTimeLocal) + } + if !tc.expect.Date.IsZero() { + assert.True(t, tc.expect.Date.Equal(result.Date), + "Date: expected %v, got %v", tc.expect.Date, result.Date) + } + if !tc.expect.CustomFormat.IsZero() { + assert.True(t, tc.expect.CustomFormat.Equal(result.CustomFormat), + "CustomFormat: expected %v, got %v", tc.expect.CustomFormat, result.CustomFormat) + } + if !tc.expect.DefaultTime.IsZero() { + assert.True(t, tc.expect.DefaultTime.Equal(result.DefaultTime), + "DefaultTime: expected %v, got %v", tc.expect.DefaultTime, result.DefaultTime) + } + if tc.expect.PtrTime != nil { + assert.NotNil(t, result.PtrTime) + if result.PtrTime != nil { + expectedPtr := time.Date(2023, 2, 20, 0, 0, 0, 0, time.UTC) + assert.True(t, expectedPtr.Equal(*result.PtrTime), + "PtrTime: expected %v, got %v", expectedPtr, *result.PtrTime) + } + } + }) + } +} diff --git a/binder.go b/binder.go index 29cceca0b..da15ae82a 100644 --- a/binder.go +++ b/binder.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( @@ -66,9 +69,9 @@ import ( type BindingError struct { // Field is the field name where value binding failed Field string `json:"field"` + *HTTPError // Values of parameter that failed to bind. Values []string `json:"-"` - *HTTPError } // NewBindingError creates new instance of binding error @@ -91,16 +94,15 @@ func (be *BindingError) Error() string { // ValueBinder provides utility methods for binding query or path parameter to various Go built-in types type ValueBinder struct { - // failFast is flag for binding methods to return without attempting to bind when previous binding already failed - failFast bool - errors []error - // ValueFunc is used to get single parameter (first) value from request ValueFunc func(sourceParam string) string // ValuesFunc is used to get all values for parameter from request. i.e. `/api/search?ids=1&ids=2` ValuesFunc func(sourceParam string) []string // ErrorFunc is used to create errors. Allows you to use your own error type, that for example marshals to your specific json response ErrorFunc func(sourceParam string, values []string, message interface{}, internalError error) error + errors []error + // failFast is flag for binding methods to return without attempting to bind when previous binding already failed + failFast bool } // QueryParamsBinder creates query parameter value binder @@ -1323,7 +1325,7 @@ func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExi case time.Second: *dest = time.Unix(n, 0) case time.Millisecond: - *dest = time.Unix(n/1e3, (n%1e3)*1e6) // TODO: time.UnixMilli(n) exists since Go1.17 switch to that when min version allows + *dest = time.UnixMilli(n) case time.Nanosecond: *dest = time.Unix(0, n) } diff --git a/binder_external_test.go b/binder_external_test.go index f1aecb52b..e44055a23 100644 --- a/binder_external_test.go +++ b/binder_external_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + // run tests as external package to get real feel for API package echo_test diff --git a/binder_test.go b/binder_test.go index 0b27cae64..d552b604d 100644 --- a/binder_test.go +++ b/binder_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/context.go b/context.go index 27da28a9c..a70338d3c 100644 --- a/context.go +++ b/context.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( @@ -13,204 +16,216 @@ import ( "sync" ) -type ( - // Context represents the context of the current HTTP request. It holds request and - // response objects, path, path parameters, data and registered handler. - Context interface { - // Request returns `*http.Request`. - Request() *http.Request +// Context represents the context of the current HTTP request. It holds request and +// response objects, path, path parameters, data and registered handler. +type Context interface { + // Request returns `*http.Request`. + Request() *http.Request - // SetRequest sets `*http.Request`. - SetRequest(r *http.Request) + // SetRequest sets `*http.Request`. + SetRequest(r *http.Request) - // SetResponse sets `*Response`. - SetResponse(r *Response) + // SetResponse sets `*Response`. + SetResponse(r *Response) - // Response returns `*Response`. - Response() *Response + // Response returns `*Response`. + Response() *Response - // IsTLS returns true if HTTP connection is TLS otherwise false. - IsTLS() bool + // IsTLS returns true if HTTP connection is TLS otherwise false. + IsTLS() bool - // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. - IsWebSocket() bool + // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. + IsWebSocket() bool - // Scheme returns the HTTP protocol scheme, `http` or `https`. - Scheme() string + // Scheme returns the HTTP protocol scheme, `http` or `https`. + Scheme() string - // RealIP returns the client's network address based on `X-Forwarded-For` - // or `X-Real-IP` request header. - // The behavior can be configured using `Echo#IPExtractor`. - RealIP() string + // RealIP returns the client's network address based on `X-Forwarded-For` + // or `X-Real-IP` request header. + // The behavior can be configured using `Echo#IPExtractor`. + RealIP() string - // Path returns the registered path for the handler. - Path() string + // Path returns the registered path for the handler. + Path() string - // SetPath sets the registered path for the handler. - SetPath(p string) + // SetPath sets the registered path for the handler. + SetPath(p string) - // Param returns path parameter by name. - Param(name string) string + // Param returns path parameter by name. + Param(name string) string - // ParamNames returns path parameter names. - ParamNames() []string + // ParamNames returns path parameter names. + ParamNames() []string - // SetParamNames sets path parameter names. - SetParamNames(names ...string) + // SetParamNames sets path parameter names. + SetParamNames(names ...string) - // ParamValues returns path parameter values. - ParamValues() []string + // ParamValues returns path parameter values. + ParamValues() []string - // SetParamValues sets path parameter values. - SetParamValues(values ...string) + // SetParamValues sets path parameter values. + SetParamValues(values ...string) - // QueryParam returns the query param for the provided name. - QueryParam(name string) string + // QueryParam returns the query param for the provided name. + QueryParam(name string) string - // QueryParams returns the query parameters as `url.Values`. - QueryParams() url.Values + // QueryParams returns the query parameters as `url.Values`. + QueryParams() url.Values - // QueryString returns the URL query string. - QueryString() string + // QueryString returns the URL query string. + QueryString() string - // FormValue returns the form field value for the provided name. - FormValue(name string) string + // FormValue returns the form field value for the provided name. + FormValue(name string) string - // FormParams returns the form parameters as `url.Values`. - FormParams() (url.Values, error) + // FormParams returns the form parameters as `url.Values`. + FormParams() (url.Values, error) - // FormFile returns the multipart form file for the provided name. - FormFile(name string) (*multipart.FileHeader, error) + // FormFile returns the multipart form file for the provided name. + FormFile(name string) (*multipart.FileHeader, error) - // MultipartForm returns the multipart form. - MultipartForm() (*multipart.Form, error) + // MultipartForm returns the multipart form. + MultipartForm() (*multipart.Form, error) - // Cookie returns the named cookie provided in the request. - Cookie(name string) (*http.Cookie, error) + // Cookie returns the named cookie provided in the request. + Cookie(name string) (*http.Cookie, error) - // SetCookie adds a `Set-Cookie` header in HTTP response. - SetCookie(cookie *http.Cookie) + // SetCookie adds a `Set-Cookie` header in HTTP response. + SetCookie(cookie *http.Cookie) - // Cookies returns the HTTP cookies sent with the request. - Cookies() []*http.Cookie + // Cookies returns the HTTP cookies sent with the request. + Cookies() []*http.Cookie - // Get retrieves data from the context. - Get(key string) interface{} + // Get retrieves data from the context. + Get(key string) any - // Set saves data in the context. - Set(key string, val interface{}) + // Set saves data in the context. + Set(key string, val any) - // Bind binds path params, query params and the request body into provided type `i`. The default binder - // binds body based on Content-Type header. - Bind(i interface{}) error + // Bind binds path params, query params and the request body into provided type `i`. The default binder + // binds body based on Content-Type header. + Bind(i any) error - // Validate validates provided `i`. It is usually called after `Context#Bind()`. - // Validator must be registered using `Echo#Validator`. - Validate(i interface{}) error + // Validate validates provided `i`. It is usually called after `Context#Bind()`. + // Validator must be registered using `Echo#Validator`. + Validate(i any) error - // Render renders a template with data and sends a text/html response with status - // code. Renderer must be registered using `Echo.Renderer`. - Render(code int, name string, data interface{}) error + // Render renders a template with data and sends a text/html response with status + // code. Renderer must be registered using `Echo.Renderer`. + Render(code int, name string, data any) error - // HTML sends an HTTP response with status code. - HTML(code int, html string) error + // HTML sends an HTTP response with status code. + HTML(code int, html string) error - // HTMLBlob sends an HTTP blob response with status code. - HTMLBlob(code int, b []byte) error + // HTMLBlob sends an HTTP blob response with status code. + HTMLBlob(code int, b []byte) error - // String sends a string response with status code. - String(code int, s string) error + // String sends a string response with status code. + String(code int, s string) error - // JSON sends a JSON response with status code. - JSON(code int, i interface{}) error + // JSON sends a JSON response with status code. + JSON(code int, i any) error - // JSONPretty sends a pretty-print JSON with status code. - JSONPretty(code int, i interface{}, indent string) error + // JSONPretty sends a pretty-print JSON with status code. + JSONPretty(code int, i any, indent string) error - // JSONBlob sends a JSON blob response with status code. - JSONBlob(code int, b []byte) error + // JSONBlob sends a JSON blob response with status code. + JSONBlob(code int, b []byte) error - // JSONP sends a JSONP response with status code. It uses `callback` to construct - // the JSONP payload. - JSONP(code int, callback string, i interface{}) error + // JSONP sends a JSONP response with status code. It uses `callback` to construct + // the JSONP payload. + JSONP(code int, callback string, i any) error - // JSONPBlob sends a JSONP blob response with status code. It uses `callback` - // to construct the JSONP payload. - JSONPBlob(code int, callback string, b []byte) error + // JSONPBlob sends a JSONP blob response with status code. It uses `callback` + // to construct the JSONP payload. + JSONPBlob(code int, callback string, b []byte) error - // XML sends an XML response with status code. - XML(code int, i interface{}) error + // XML sends an XML response with status code. + XML(code int, i any) error - // XMLPretty sends a pretty-print XML with status code. - XMLPretty(code int, i interface{}, indent string) error + // XMLPretty sends a pretty-print XML with status code. + XMLPretty(code int, i any, indent string) error - // XMLBlob sends an XML blob response with status code. - XMLBlob(code int, b []byte) error + // XMLBlob sends an XML blob response with status code. + XMLBlob(code int, b []byte) error - // Blob sends a blob response with status code and content type. - Blob(code int, contentType string, b []byte) error + // Blob sends a blob response with status code and content type. + Blob(code int, contentType string, b []byte) error - // Stream sends a streaming response with status code and content type. - Stream(code int, contentType string, r io.Reader) error + // Stream sends a streaming response with status code and content type. + Stream(code int, contentType string, r io.Reader) error - // File sends a response with the content of the file. - File(file string) error + // File sends a response with the content of the file. + File(file string) error - // Attachment sends a response as attachment, prompting client to save the - // file. - Attachment(file string, name string) error + // Attachment sends a response as attachment, prompting client to save the + // file. + Attachment(file string, name string) error - // Inline sends a response as inline, opening the file in the browser. - Inline(file string, name string) error + // Inline sends a response as inline, opening the file in the browser. + Inline(file string, name string) error - // NoContent sends a response with no body and a status code. - NoContent(code int) error + // NoContent sends a response with no body and a status code. + NoContent(code int) error - // Redirect redirects the request to a provided URL with status code. - Redirect(code int, url string) error + // Redirect redirects the request to a provided URL with status code. + Redirect(code int, url string) error - // Error invokes the registered global HTTP error handler. Generally used by middleware. - // A side-effect of calling global error handler is that now Response has been committed (sent to the client) and - // middlewares up in chain can not change Response status code or Response body anymore. - // - // Avoid using this method in handlers as no middleware will be able to effectively handle errors after that. - Error(err error) + // Error invokes the registered global HTTP error handler. Generally used by middleware. + // A side-effect of calling global error handler is that now Response has been committed (sent to the client) and + // middlewares up in chain can not change Response status code or Response body anymore. + // + // Avoid using this method in handlers as no middleware will be able to effectively handle errors after that. + Error(err error) - // Handler returns the matched handler by router. - Handler() HandlerFunc + // Handler returns the matched handler by router. + Handler() HandlerFunc - // SetHandler sets the matched handler by router. - SetHandler(h HandlerFunc) + // SetHandler sets the matched handler by router. + SetHandler(h HandlerFunc) - // Logger returns the `Logger` instance. - Logger() Logger + // Logger returns the `Logger` instance. + Logger() Logger - // SetLogger Set the logger - SetLogger(l Logger) + // SetLogger Set the logger + SetLogger(l Logger) - // Echo returns the `Echo` instance. - Echo() *Echo + // Echo returns the `Echo` instance. + Echo() *Echo - // Reset resets the context after request completes. It must be called along - // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. - // See `Echo#ServeHTTP()` - Reset(r *http.Request, w http.ResponseWriter) - } + // Reset resets the context after request completes. It must be called along + // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. + // See `Echo#ServeHTTP()` + Reset(r *http.Request, w http.ResponseWriter) +} - context struct { - request *http.Request - response *Response - path string - pnames []string - pvalues []string - query url.Values - handler HandlerFunc - store Map - echo *Echo - logger Logger - lock sync.RWMutex - } -) +type context struct { + logger Logger + request *http.Request + response *Response + query url.Values + echo *Echo + + store Map + lock sync.RWMutex + + // following fields are set by Router + handler HandlerFunc + + // path is route path that Router matched. It is empty string where there is no route match. + // Route registered with RouteNotFound is considered as a match and path therefore is not empty. + path string + + // Usually echo.Echo is sizing pvalues but there could be user created middlewares that decide to + // overwrite parameter by calling SetParamNames + SetParamValues. + // When echo.Echo allocated that slice it length/capacity is tied to echo.Echo.maxParam value. + // + // It is important that pvalues size is always equal or bigger to pnames length. + pvalues []string + + // pnames length is tied to param count for the matched route + pnames []string +} const ( // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain. @@ -329,13 +344,9 @@ func (c *context) SetParamNames(names ...string) { c.pnames = names l := len(names) - if *c.echo.maxParam < l { - *c.echo.maxParam = l - } - if len(c.pvalues) < l { // Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them, - // probably those values will be overriden in a Context#SetParamValues + // probably those values will be overridden in a Context#SetParamValues newPvalues := make([]string, l) copy(newPvalues, c.pvalues) c.pvalues = newPvalues @@ -347,11 +358,11 @@ func (c *context) ParamValues() []string { } func (c *context) SetParamValues(values ...string) { - // NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times + // NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam (or bigger) at all times // It will brake the Router#Find code limit := len(values) - if limit > *c.echo.maxParam { - limit = *c.echo.maxParam + if limit > len(c.pvalues) { + c.pvalues = make([]string, limit) } for i := 0; i < limit; i++ { c.pvalues[i] = values[i] @@ -419,13 +430,13 @@ func (c *context) Cookies() []*http.Cookie { return c.request.Cookies() } -func (c *context) Get(key string) interface{} { +func (c *context) Get(key string) any { c.lock.RLock() defer c.lock.RUnlock() return c.store[key] } -func (c *context) Set(key string, val interface{}) { +func (c *context) Set(key string, val any) { c.lock.Lock() defer c.lock.Unlock() @@ -435,18 +446,18 @@ func (c *context) Set(key string, val interface{}) { c.store[key] = val } -func (c *context) Bind(i interface{}) error { +func (c *context) Bind(i any) error { return c.echo.Binder.Bind(i, c) } -func (c *context) Validate(i interface{}) error { +func (c *context) Validate(i any) error { if c.echo.Validator == nil { return ErrValidatorNotRegistered } return c.echo.Validator.Validate(i) } -func (c *context) Render(code int, name string, data interface{}) (err error) { +func (c *context) Render(code int, name string, data any) (err error) { if c.echo.Renderer == nil { return ErrRendererNotRegistered } @@ -469,7 +480,7 @@ func (c *context) String(code int, s string) (err error) { return c.Blob(code, MIMETextPlainCharsetUTF8, []byte(s)) } -func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error) { +func (c *context) jsonPBlob(code int, callback string, i any) (err error) { indent := "" if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { indent = defaultIndent @@ -488,13 +499,13 @@ func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error return } -func (c *context) json(code int, i interface{}, indent string) error { - c.writeContentType(MIMEApplicationJSONCharsetUTF8) +func (c *context) json(code int, i any, indent string) error { + c.writeContentType(MIMEApplicationJSON) c.response.Status = code return c.echo.JSONSerializer.Serialize(c, i, indent) } -func (c *context) JSON(code int, i interface{}) (err error) { +func (c *context) JSON(code int, i any) (err error) { indent := "" if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { indent = defaultIndent @@ -502,15 +513,15 @@ func (c *context) JSON(code int, i interface{}) (err error) { return c.json(code, i, indent) } -func (c *context) JSONPretty(code int, i interface{}, indent string) (err error) { +func (c *context) JSONPretty(code int, i any, indent string) (err error) { return c.json(code, i, indent) } func (c *context) JSONBlob(code int, b []byte) (err error) { - return c.Blob(code, MIMEApplicationJSONCharsetUTF8, b) + return c.Blob(code, MIMEApplicationJSON, b) } -func (c *context) JSONP(code int, callback string, i interface{}) (err error) { +func (c *context) JSONP(code int, callback string, i any) (err error) { return c.jsonPBlob(code, callback, i) } @@ -527,7 +538,7 @@ func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) { return } -func (c *context) xml(code int, i interface{}, indent string) (err error) { +func (c *context) xml(code int, i any, indent string) (err error) { c.writeContentType(MIMEApplicationXMLCharsetUTF8) c.response.WriteHeader(code) enc := xml.NewEncoder(c.response) @@ -540,7 +551,7 @@ func (c *context) xml(code int, i interface{}, indent string) (err error) { return enc.Encode(i) } -func (c *context) XML(code int, i interface{}) (err error) { +func (c *context) XML(code int, i any) (err error) { indent := "" if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { indent = defaultIndent @@ -548,7 +559,7 @@ func (c *context) XML(code int, i interface{}) (err error) { return c.xml(code, i, indent) } -func (c *context) XMLPretty(code int, i interface{}, indent string) (err error) { +func (c *context) XMLPretty(code int, i any, indent string) (err error) { return c.xml(code, i, indent) } @@ -584,8 +595,10 @@ func (c *context) Inline(file, name string) error { return c.contentDisposition(file, name, "inline") } +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + func (c *context) contentDisposition(file, name, dispositionType string) error { - c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf("%s; filename=%q", dispositionType, name)) + c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf(`%s; filename="%s"`, dispositionType, quoteEscaper.Replace(name))) return c.File(file) } @@ -640,8 +653,8 @@ func (c *context) Reset(r *http.Request, w http.ResponseWriter) { c.path = "" c.pnames = nil c.logger = nil - // NOTE: Don't reset because it has to have length c.echo.maxParam at all times - for i := 0; i < *c.echo.maxParam; i++ { + // NOTE: Don't reset because it has to have length c.echo.maxParam (or bigger) at all times + for i := 0; i < len(c.pvalues); i++ { c.pvalues[i] = "" } } diff --git a/context_fs.go b/context_fs.go index 1038f892e..1c25baf12 100644 --- a/context_fs.go +++ b/context_fs.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/context_fs_test.go b/context_fs_test.go index 51346c956..83232ea45 100644 --- a/context_fs_test.go +++ b/context_fs_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/context_test.go b/context_test.go index 11a63cfce..1fd89edb4 100644 --- a/context_test.go +++ b/context_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( @@ -19,14 +22,12 @@ import ( "time" "github.com/labstack/gommon/log" - testify "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" ) -type ( - Template struct { - templates *template.Template - } -) +type Template struct { + templates *template.Template +} var testUser = user{1, "Jon Snow"} @@ -85,303 +86,443 @@ func (t *Template) Render(w io.Writer, name string, data interface{}, c Context) return t.templates.ExecuteTemplate(w, name, data) } -type responseWriterErr struct { -} - -func (responseWriterErr) Header() http.Header { - return http.Header{} -} +func TestContextEcho(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() -func (responseWriterErr) Write([]byte) (int, error) { - return 0, errors.New("err") -} + c := e.NewContext(req, rec).(*context) -func (responseWriterErr) WriteHeader(statusCode int) { + assert.Equal(t, e, c.Echo()) } -func TestContext(t *testing.T) { +func TestContextRequest(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() + c := e.NewContext(req, rec).(*context) - assert := testify.New(t) + assert.NotNil(t, c.Request()) + assert.Equal(t, req, c.Request()) +} - // Echo - assert.Equal(e, c.Echo()) +func TestContextResponse(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() - // Request - assert.NotNil(c.Request()) + c := e.NewContext(req, rec).(*context) - // Response - assert.NotNil(c.Response()) + assert.NotNil(t, c.Response()) +} - //-------- - // Render - //-------- +func TestContextRenderTemplate(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec).(*context) tmpl := &Template{ templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")), } c.echo.Renderer = tmpl err := c.Render(http.StatusOK, "hello", "Jon Snow") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("Hello, Jon Snow!", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "Hello, Jon Snow!", rec.Body.String()) } +} + +func TestContextRenderErrorsOnNoRenderer(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec).(*context) c.echo.Renderer = nil - err = c.Render(http.StatusOK, "hello", "Jon Snow") - assert.Error(err) - - // JSON - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON+"\n", rec.Body.String()) - } - - // JSON with "?pretty" - req = httptest.NewRequest(http.MethodGet, "/?pretty", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSONPretty+"\n", rec.Body.String()) - } - req = httptest.NewRequest(http.MethodGet, "/", nil) // reset - - // JSONPretty - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSONPretty+"\n", rec.Body.String()) - } - - // JSON (error) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.JSON(http.StatusOK, make(chan bool)) - assert.Error(err) - - // JSONP - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + assert.Error(t, c.Render(http.StatusOK, "hello", "Jon Snow")) +} + +func TestContextJSON(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + c := e.NewContext(req, rec).(*context) + + err := c.JSON(http.StatusOK, user{1, "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON+"\n", rec.Body.String()) + } +} + +func TestContextJSONErrorsOut(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + c := e.NewContext(req, rec).(*context) + + err := c.JSON(http.StatusOK, make(chan bool)) + assert.EqualError(t, err, "json: unsupported type: chan bool") +} + +func TestContextJSONPrettyURL(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + + err := c.JSON(http.StatusOK, user{1, "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) + } +} + +func TestContextJSONPretty(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + err := c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) + } +} + +func TestContextJSONWithEmptyIntent(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + u := user{1, "Jon Snow"} + emptyIndent := "" + buf := new(bytes.Buffer) + + enc := json.NewEncoder(buf) + enc.SetIndent(emptyIndent, emptyIndent) + _ = enc.Encode(u) + err := c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) + assert.Equal(t, buf.String(), rec.Body.String()) + } +} + +func TestContextJSONP(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + callback := "callback" - err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(callback+"("+userJSON+"\n);", rec.Body.String()) - } - - // XML - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.XML(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXML, rec.Body.String()) - } - - // XML with "?pretty" - req = httptest.NewRequest(http.MethodGet, "/?pretty", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.XML(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) - } - req = httptest.NewRequest(http.MethodGet, "/", nil) - - // XML (error) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.XML(http.StatusOK, make(chan bool)) - assert.Error(err) - - // XML response write error - c = e.NewContext(req, rec).(*context) - c.response.Writer = responseWriterErr{} - err = c.XML(0, 0) - testify.Error(t, err) - - // XMLPretty - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) - } - - t.Run("empty indent", func(t *testing.T) { - var ( - u = user{1, "Jon Snow"} - buf = new(bytes.Buffer) - emptyIndent = "" - ) - - t.Run("json", func(t *testing.T) { - buf.Reset() - assert := testify.New(t) - - // New JSONBlob with empty indent - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - enc := json.NewEncoder(buf) - enc.SetIndent(emptyIndent, emptyIndent) - err = enc.Encode(u) - err = c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(buf.String(), rec.Body.String()) - } - }) + err := c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String()) + } +} - t.Run("xml", func(t *testing.T) { - buf.Reset() - assert := testify.New(t) - - // New XMLBlob with empty indent - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - enc := xml.NewEncoder(buf) - enc.Indent(emptyIndent, emptyIndent) - err = enc.Encode(u) - err = c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+buf.String(), rec.Body.String()) - } - }) - }) +func TestContextJSONBlob(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) - // Legacy JSONBlob - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) data, err := json.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + assert.NoError(t, err) err = c.JSONBlob(http.StatusOK, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON, rec.Body.String()) - } - - // Legacy JSONPBlob - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - callback = "callback" - data, err = json.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON, rec.Body.String()) + } +} + +func TestContextJSONPBlob(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + callback := "callback" + data, err := json.Marshal(user{1, "Jon Snow"}) + assert.NoError(t, err) err = c.JSONPBlob(http.StatusOK, callback, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(callback+"("+userJSON+");", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, callback+"("+userJSON+");", rec.Body.String()) + } +} + +func TestContextXML(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + err := c.XML(http.StatusOK, user{1, "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXML, rec.Body.String()) + } +} + +func TestContextXMLPrettyURL(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + + err := c.XML(http.StatusOK, user{1, "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String()) } +} + +func TestContextXMLPretty(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + err := c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String()) + } +} + +func TestContextXMLBlob(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) - // Legacy XMLBlob - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - data, err = xml.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + data, err := xml.Marshal(user{1, "Jon Snow"}) + assert.NoError(t, err) err = c.XMLBlob(http.StatusOK, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXML, rec.Body.String()) - } - - // String - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.String(http.StatusOK, "Hello, World!") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal("Hello, World!", rec.Body.String()) - } - - // HTML - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.HTML(http.StatusOK, "Hello, World!") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal("Hello, World!", rec.Body.String()) - } - - // Stream - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXML, rec.Body.String()) + } +} + +func TestContextXMLWithEmptyIntent(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + u := user{1, "Jon Snow"} + emptyIndent := "" + buf := new(bytes.Buffer) + + enc := xml.NewEncoder(buf) + enc.Indent(emptyIndent, emptyIndent) + _ = enc.Encode(u) + err := c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+buf.String(), rec.Body.String()) + } +} + +type responseWriterErr struct { +} + +func (responseWriterErr) Header() http.Header { + return http.Header{} +} + +func (responseWriterErr) Write([]byte) (int, error) { + return 0, errors.New("responseWriterErr") +} + +func (responseWriterErr) WriteHeader(statusCode int) { +} + +func TestContextXMLError(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + c.response.Writer = responseWriterErr{} + + err := c.XML(http.StatusOK, make(chan bool)) + assert.EqualError(t, err, "responseWriterErr") +} + +func TestContextString(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + + err := c.String(http.StatusOK, "Hello, World!") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hello, World!", rec.Body.String()) + } +} + +func TestContextHTML(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + + err := c.HTML(http.StatusOK, "Hello, World!") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hello, World!", rec.Body.String()) + } +} + +func TestContextStream(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + r := strings.NewReader("response from a stream") - err = c.Stream(http.StatusOK, "application/octet-stream", r) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("application/octet-stream", rec.Header().Get(HeaderContentType)) - assert.Equal("response from a stream", rec.Body.String()) - } - - // Attachment - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.Attachment("_fixture/images/walle.png", "walle.png") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(219885, rec.Body.Len()) - } - - // Inline - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.Inline("_fixture/images/walle.png", "walle.png") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(219885, rec.Body.Len()) - } - - // NoContent - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + err := c.Stream(http.StatusOK, "application/octet-stream", r) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType)) + assert.Equal(t, "response from a stream", rec.Body.String()) + } +} + +func TestContextAttachment(t *testing.T) { + var testCases = []struct { + name string + whenName string + expectHeader string + }{ + { + name: "ok", + whenName: "walle.png", + expectHeader: `attachment; filename="walle.png"`, + }, + { + name: "ok, escape quotes in malicious filename", + whenName: `malicious.sh"; \"; dummy=.txt`, + expectHeader: `attachment; filename="malicious.sh\"; \\\"; dummy=.txt"`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + err := c.Attachment("_fixture/images/walle.png", tc.whenName) + if assert.NoError(t, err) { + assert.Equal(t, tc.expectHeader, rec.Header().Get(HeaderContentDisposition)) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, 219885, rec.Body.Len()) + } + }) + } +} + +func TestContextInline(t *testing.T) { + var testCases = []struct { + name string + whenName string + expectHeader string + }{ + { + name: "ok", + whenName: "walle.png", + expectHeader: `inline; filename="walle.png"`, + }, + { + name: "ok, escape quotes in malicious filename", + whenName: `malicious.sh"; \"; dummy=.txt`, + expectHeader: `inline; filename="malicious.sh\"; \\\"; dummy=.txt"`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + err := c.Inline("_fixture/images/walle.png", tc.whenName) + if assert.NoError(t, err) { + assert.Equal(t, tc.expectHeader, rec.Header().Get(HeaderContentDisposition)) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, 219885, rec.Body.Len()) + } + }) + } +} + +func TestContextNoContent(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + c.NoContent(http.StatusOK) - assert.Equal(http.StatusOK, rec.Code) + assert.Equal(t, http.StatusOK, rec.Code) +} + +func TestContextError(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) - // Error - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) c.Error(errors.New("error")) - assert.Equal(http.StatusInternalServerError, rec.Code) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.True(t, c.Response().Committed) +} + +func TestContextReset(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) - // Reset c.SetParamNames("foo") c.SetParamValues("bar") c.Set("foe", "ban") c.query = url.Values(map[string][]string{"fon": {"baz"}}) + c.Reset(req, httptest.NewRecorder()) - assert.Equal(0, len(c.ParamValues())) - assert.Equal(0, len(c.ParamNames())) - assert.Equal(0, len(c.store)) - assert.Equal("", c.Path()) - assert.Equal(0, len(c.QueryParams())) + + assert.Len(t, c.ParamValues(), 0) + assert.Len(t, c.ParamNames(), 0) + assert.Len(t, c.Path(), 0) + assert.Len(t, c.QueryParams(), 0) + assert.Len(t, c.store, 0) } func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { @@ -391,11 +532,10 @@ func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { c := e.NewContext(req, rec).(*context) err := c.JSON(http.StatusCreated, user{1, "Jon Snow"}) - assert := testify.New(t) - if assert.NoError(err) { - assert.Equal(http.StatusCreated, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusCreated, rec.Code) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON+"\n", rec.Body.String()) } } @@ -406,9 +546,8 @@ func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) { c := e.NewContext(req, rec).(*context) err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()}) - assert := testify.New(t) - if assert.Error(err) { - assert.False(c.response.Committed) + if assert.Error(t, err) { + assert.False(t, c.response.Committed) } } @@ -422,22 +561,20 @@ func TestContextCookie(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) - assert := testify.New(t) - // Read single cookie, err := c.Cookie("theme") - if assert.NoError(err) { - assert.Equal("theme", cookie.Name) - assert.Equal("light", cookie.Value) + if assert.NoError(t, err) { + assert.Equal(t, "theme", cookie.Name) + assert.Equal(t, "light", cookie.Value) } // Read multiple for _, cookie := range c.Cookies() { switch cookie.Name { case "theme": - assert.Equal("light", cookie.Value) + assert.Equal(t, "light", cookie.Value) case "user": - assert.Equal("Jon Snow", cookie.Value) + assert.Equal(t, "Jon Snow", cookie.Value) } } @@ -452,11 +589,11 @@ func TestContextCookie(t *testing.T) { HttpOnly: true, } c.SetCookie(cookie) - assert.Contains(rec.Header().Get(HeaderSetCookie), "SSID") - assert.Contains(rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") - assert.Contains(rec.Header().Get(HeaderSetCookie), "labstack.com") - assert.Contains(rec.Header().Get(HeaderSetCookie), "Secure") - assert.Contains(rec.Header().Get(HeaderSetCookie), "HttpOnly") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "labstack.com") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Secure") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly") } func TestContextPath(t *testing.T) { @@ -469,14 +606,12 @@ func TestContextPath(t *testing.T) { c := e.NewContext(nil, nil) r.Find(http.MethodGet, "/users/1", c) - assert := testify.New(t) - - assert.Equal("/users/:id", c.Path()) + assert.Equal(t, "/users/:id", c.Path()) r.Add(http.MethodGet, "/users/:uid/files/:fid", handler) c = e.NewContext(nil, nil) r.Find(http.MethodGet, "/users/1/files/1", c) - assert.Equal("/users/:uid/files/:fid", c.Path()) + assert.Equal(t, "/users/:uid/files/:fid", c.Path()) } func TestContextPathParam(t *testing.T) { @@ -486,15 +621,15 @@ func TestContextPathParam(t *testing.T) { // ParamNames c.SetParamNames("uid", "fid") - testify.EqualValues(t, []string{"uid", "fid"}, c.ParamNames()) + assert.EqualValues(t, []string{"uid", "fid"}, c.ParamNames()) // ParamValues c.SetParamValues("101", "501") - testify.EqualValues(t, []string{"101", "501"}, c.ParamValues()) + assert.EqualValues(t, []string{"101", "501"}, c.ParamValues()) // Param - testify.Equal(t, "501", c.Param("fid")) - testify.Equal(t, "", c.Param("undefined")) + assert.Equal(t, "501", c.Param("fid")) + assert.Equal(t, "", c.Param("undefined")) } func TestContextGetAndSetParam(t *testing.T) { @@ -507,49 +642,61 @@ func TestContextGetAndSetParam(t *testing.T) { // round-trip param values with modification paramVals := c.ParamValues() - testify.EqualValues(t, []string{""}, c.ParamValues()) + assert.EqualValues(t, []string{""}, c.ParamValues()) paramVals[0] = "bar" c.SetParamValues(paramVals...) - testify.EqualValues(t, []string{"bar"}, c.ParamValues()) + assert.EqualValues(t, []string{"bar"}, c.ParamValues()) // shouldn't explode during Reset() afterwards! - testify.NotPanics(t, func() { + assert.NotPanics(t, func() { c.Reset(nil, nil) }) } -// Issue #1655 -func TestContextSetParamNamesShouldUpdateEchoMaxParam(t *testing.T) { - assert := testify.New(t) - +func TestContextSetParamNamesEchoMaxParam(t *testing.T) { e := New() - assert.Equal(0, *e.maxParam) + assert.Equal(t, 0, *e.maxParam) expectedOneParam := []string{"one"} expectedTwoParams := []string{"one", "two"} expectedThreeParams := []string{"one", "two", ""} - expectedABCParams := []string{"A", "B", "C"} - c := e.NewContext(nil, nil) - c.SetParamNames("1", "2") - c.SetParamValues(expectedTwoParams...) - assert.Equal(2, *e.maxParam) - assert.EqualValues(expectedTwoParams, c.ParamValues()) + { + c := e.AcquireContext() + c.SetParamNames("1", "2") + c.SetParamValues(expectedTwoParams...) + assert.Equal(t, 0, *e.maxParam) // has not been changed + assert.EqualValues(t, expectedTwoParams, c.ParamValues()) + e.ReleaseContext(c) + } + + { + c := e.AcquireContext() + c.SetParamNames("1", "2", "3") + c.SetParamValues(expectedThreeParams...) + assert.Equal(t, 0, *e.maxParam) // has not been changed + assert.EqualValues(t, expectedThreeParams, c.ParamValues()) + e.ReleaseContext(c) + } - c.SetParamNames("1") - assert.Equal(2, *e.maxParam) - // Here for backward compatibility the ParamValues remains as they are - assert.EqualValues(expectedOneParam, c.ParamValues()) + { // values is always same size as names length + c := e.NewContext(nil, nil) + c.SetParamValues([]string{"one", "two"}...) // more values than names should be ok + c.SetParamNames("1") + assert.Equal(t, 0, *e.maxParam) // has not been changed + assert.EqualValues(t, expectedOneParam, c.ParamValues()) + } - c.SetParamNames("1", "2", "3") - assert.Equal(3, *e.maxParam) - // Here for backward compatibility the ParamValues remains as they are, but the len is extended to e.maxParam - assert.EqualValues(expectedThreeParams, c.ParamValues()) + e.GET("/:id", handlerFunc) + assert.Equal(t, 1, *e.maxParam) // has not been changed - c.SetParamValues("A", "B", "C", "D") - assert.Equal(3, *e.maxParam) - // Here D shouldn't be returned - assert.EqualValues(expectedABCParams, c.ParamValues()) + { + c := e.NewContext(nil, nil) + c.SetParamValues([]string{"one", "two"}...) + c.SetParamNames("1") + assert.Equal(t, 1, *e.maxParam) // has not been changed + assert.EqualValues(t, expectedOneParam, c.ParamValues()) + } } func TestContextFormValue(t *testing.T) { @@ -563,13 +710,13 @@ func TestContextFormValue(t *testing.T) { c := e.NewContext(req, nil) // FormValue - testify.Equal(t, "Jon Snow", c.FormValue("name")) - testify.Equal(t, "jon@labstack.com", c.FormValue("email")) + assert.Equal(t, "Jon Snow", c.FormValue("name")) + assert.Equal(t, "jon@labstack.com", c.FormValue("email")) // FormParams params, err := c.FormParams() - if testify.NoError(t, err) { - testify.Equal(t, url.Values{ + if assert.NoError(t, err) { + assert.Equal(t, url.Values{ "name": []string{"Jon Snow"}, "email": []string{"jon@labstack.com"}, }, params) @@ -580,8 +727,8 @@ func TestContextFormValue(t *testing.T) { req.Header.Add(HeaderContentType, MIMEMultipartForm) c = e.NewContext(req, nil) params, err = c.FormParams() - testify.Nil(t, params) - testify.Error(t, err) + assert.Nil(t, params) + assert.Error(t, err) } func TestContextQueryParam(t *testing.T) { @@ -593,11 +740,11 @@ func TestContextQueryParam(t *testing.T) { c := e.NewContext(req, nil) // QueryParam - testify.Equal(t, "Jon Snow", c.QueryParam("name")) - testify.Equal(t, "jon@labstack.com", c.QueryParam("email")) + assert.Equal(t, "Jon Snow", c.QueryParam("name")) + assert.Equal(t, "jon@labstack.com", c.QueryParam("email")) // QueryParams - testify.Equal(t, url.Values{ + assert.Equal(t, url.Values{ "name": []string{"Jon Snow"}, "email": []string{"jon@labstack.com"}, }, c.QueryParams()) @@ -608,7 +755,7 @@ func TestContextFormFile(t *testing.T) { buf := new(bytes.Buffer) mr := multipart.NewWriter(buf) w, err := mr.CreateFormFile("file", "test") - if testify.NoError(t, err) { + if assert.NoError(t, err) { w.Write([]byte("test")) } mr.Close() @@ -617,8 +764,8 @@ func TestContextFormFile(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.FormFile("file") - if testify.NoError(t, err) { - testify.Equal(t, "test", f.Filename) + if assert.NoError(t, err) { + assert.Equal(t, "test", f.Filename) } } @@ -627,14 +774,26 @@ func TestContextMultipartForm(t *testing.T) { buf := new(bytes.Buffer) mw := multipart.NewWriter(buf) mw.WriteField("name", "Jon Snow") + fileContent := "This is a test file" + w, err := mw.CreateFormFile("file", "test.txt") + if assert.NoError(t, err) { + w.Write([]byte(fileContent)) + } mw.Close() req := httptest.NewRequest(http.MethodPost, "/", buf) req.Header.Set(HeaderContentType, mw.FormDataContentType()) rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.MultipartForm() - if testify.NoError(t, err) { - testify.NotNil(t, f) + if assert.NoError(t, err) { + assert.NotNil(t, f) + + files := f.File["file"] + if assert.Len(t, files, 1) { + file := files[0] + assert.Equal(t, "test.txt", file.Filename) + assert.Equal(t, int64(len(fileContent)), file.Size) + } } } @@ -643,16 +802,16 @@ func TestContextRedirect(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - testify.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) - testify.Equal(t, http.StatusMovedPermanently, rec.Code) - testify.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) - testify.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) + assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) + assert.Equal(t, http.StatusMovedPermanently, rec.Code) + assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) + assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) } func TestContextStore(t *testing.T) { var c Context = new(context) c.Set("name", "Jon Snow") - testify.Equal(t, "Jon Snow", c.Get("name")) + assert.Equal(t, "Jon Snow", c.Get("name")) } func BenchmarkContext_Store(b *testing.B) { @@ -682,19 +841,19 @@ func TestContextHandler(t *testing.T) { c := e.NewContext(nil, nil) r.Find(http.MethodGet, "/handler", c) err := c.Handler()(c) - testify.Equal(t, "handler", b.String()) - testify.NoError(t, err) + assert.Equal(t, "handler", b.String()) + assert.NoError(t, err) } func TestContext_SetHandler(t *testing.T) { var c Context = new(context) - testify.Nil(t, c.Handler()) + assert.Nil(t, c.Handler()) c.SetHandler(func(c Context) error { return nil }) - testify.NotNil(t, c.Handler()) + assert.NotNil(t, c.Handler()) } func TestContext_Path(t *testing.T) { @@ -703,7 +862,7 @@ func TestContext_Path(t *testing.T) { var c Context = new(context) c.SetPath(path) - testify.Equal(t, path, c.Path()) + assert.Equal(t, path, c.Path()) } type validator struct{} @@ -716,10 +875,10 @@ func TestContext_Validate(t *testing.T) { e := New() c := e.NewContext(nil, nil) - testify.Error(t, c.Validate(struct{}{})) + assert.Error(t, c.Validate(struct{}{})) e.Validator = &validator{} - testify.NoError(t, c.Validate(struct{}{})) + assert.NoError(t, c.Validate(struct{}{})) } func TestContext_QueryString(t *testing.T) { @@ -730,18 +889,18 @@ func TestContext_QueryString(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil) c := e.NewContext(req, nil) - testify.Equal(t, queryString, c.QueryString()) + assert.Equal(t, queryString, c.QueryString()) } func TestContext_Request(t *testing.T) { var c Context = new(context) - testify.Nil(t, c.Request()) + assert.Nil(t, c.Request()) req := httptest.NewRequest(http.MethodGet, "/path", nil) c.SetRequest(req) - testify.Equal(t, req, c.Request()) + assert.Equal(t, req, c.Request()) } func TestContext_Scheme(t *testing.T) { @@ -798,14 +957,14 @@ func TestContext_Scheme(t *testing.T) { } for _, tt := range tests { - testify.Equal(t, tt.s, tt.c.Scheme()) + assert.Equal(t, tt.s, tt.c.Scheme()) } } func TestContext_IsWebSocket(t *testing.T) { tests := []struct { c Context - ws testify.BoolAssertionFunc + ws assert.BoolAssertionFunc }{ { &context{ @@ -813,7 +972,7 @@ func TestContext_IsWebSocket(t *testing.T) { Header: http.Header{HeaderUpgrade: []string{"websocket"}}, }, }, - testify.True, + assert.True, }, { &context{ @@ -821,13 +980,13 @@ func TestContext_IsWebSocket(t *testing.T) { Header: http.Header{HeaderUpgrade: []string{"Websocket"}}, }, }, - testify.True, + assert.True, }, { &context{ request: &http.Request{}, }, - testify.False, + assert.False, }, { &context{ @@ -835,7 +994,7 @@ func TestContext_IsWebSocket(t *testing.T) { Header: http.Header{HeaderUpgrade: []string{"other"}}, }, }, - testify.False, + assert.False, }, } @@ -854,8 +1013,8 @@ func TestContext_Bind(t *testing.T) { req.Header.Add(HeaderContentType, MIMEApplicationJSON) err := c.Bind(u) - testify.NoError(t, err) - testify.Equal(t, &user{1, "Jon Snow"}, u) + assert.NoError(t, err) + assert.Equal(t, &user{1, "Jon Snow"}, u) } func TestContext_Logger(t *testing.T) { @@ -863,15 +1022,15 @@ func TestContext_Logger(t *testing.T) { c := e.NewContext(nil, nil) log1 := c.Logger() - testify.NotNil(t, log1) + assert.NotNil(t, log1) log2 := log.New("echo2") c.SetLogger(log2) - testify.Equal(t, log2, c.Logger()) + assert.Equal(t, log2, c.Logger()) // Resetting the context returns the initial logger c.Reset(nil, nil) - testify.Equal(t, log1, c.Logger()) + assert.Equal(t, log1, c.Logger()) } func TestContext_RealIP(t *testing.T) { @@ -959,6 +1118,6 @@ func TestContext_RealIP(t *testing.T) { } for _, tt := range tests { - testify.Equal(t, tt.s, tt.c.RealIP()) + assert.Equal(t, tt.s, tt.c.RealIP()) } } diff --git a/echo.go b/echo.go index 22a5b7af9..ea6ba1619 100644 --- a/echo.go +++ b/echo.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + /* Package echo implements high performance, minimalist Go web framework. @@ -42,7 +45,6 @@ import ( "encoding/json" "errors" "fmt" - "io" stdLog "log" "net" "net/http" @@ -60,97 +62,90 @@ import ( "golang.org/x/net/http2/h2c" ) -type ( - // Echo is the top-level framework instance. - // - // Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these - // fields from handlers/middlewares and changing field values at the same time leads to data-races. - // Adding new routes after the server has been started is also not safe! - Echo struct { - filesystem - common - // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get - // listener address info (on which interface/port was listener binded) without having data races. - startupMutex sync.RWMutex - colorer *color.Color - - // premiddleware are middlewares that are run before routing is done. In case a pre-middleware returns - // an error the router is not executed and the request will end up in the global error handler. - premiddleware []MiddlewareFunc - middleware []MiddlewareFunc - maxParam *int - router *Router - routers map[string]*Router - pool sync.Pool - - StdLogger *stdLog.Logger - Server *http.Server - TLSServer *http.Server - Listener net.Listener - TLSListener net.Listener - AutoTLSManager autocert.Manager - DisableHTTP2 bool - Debug bool - HideBanner bool - HidePort bool - HTTPErrorHandler HTTPErrorHandler - Binder Binder - JSONSerializer JSONSerializer - Validator Validator - Renderer Renderer - Logger Logger - IPExtractor IPExtractor - ListenerNetwork string - - // OnAddRouteHandler is called when Echo adds new route to specific host router. - OnAddRouteHandler func(host string, route Route, handler HandlerFunc, middleware []MiddlewareFunc) - } - - // Route contains a handler and information for matching against requests. - Route struct { - Method string `json:"method"` - Path string `json:"path"` - Name string `json:"name"` - } - - // HTTPError represents an error that occurred while handling a request. - HTTPError struct { - Code int `json:"-"` - Message interface{} `json:"message"` - Internal error `json:"-"` // Stores the error returned by an external dependency - } - - // MiddlewareFunc defines a function to process middleware. - MiddlewareFunc func(next HandlerFunc) HandlerFunc - - // HandlerFunc defines a function to serve HTTP requests. - HandlerFunc func(c Context) error - - // HTTPErrorHandler is a centralized HTTP error handler. - HTTPErrorHandler func(err error, c Context) - - // Validator is the interface that wraps the Validate function. - Validator interface { - Validate(i interface{}) error - } - - // JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. - JSONSerializer interface { - Serialize(c Context, i interface{}, indent string) error - Deserialize(c Context, i interface{}) error - } - - // Renderer is the interface that wraps the Render function. - Renderer interface { - Render(io.Writer, string, interface{}, Context) error - } - - // Map defines a generic map of type `map[string]interface{}`. - Map map[string]interface{} - - // Common struct for Echo & Group. - common struct{} -) +// Echo is the top-level framework instance. +// +// Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these +// fields from handlers/middlewares and changing field values at the same time leads to data-races. +// Adding new routes after the server has been started is also not safe! +type Echo struct { + filesystem + common + // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get + // listener address info (on which interface/port was listener bound) without having data races. + startupMutex sync.RWMutex + colorer *color.Color + + // premiddleware are middlewares that are run before routing is done. In case a pre-middleware returns + // an error the router is not executed and the request will end up in the global error handler. + premiddleware []MiddlewareFunc + middleware []MiddlewareFunc + maxParam *int + router *Router + routers map[string]*Router + pool sync.Pool + + StdLogger *stdLog.Logger + Server *http.Server + TLSServer *http.Server + Listener net.Listener + TLSListener net.Listener + AutoTLSManager autocert.Manager + HTTPErrorHandler HTTPErrorHandler + Binder Binder + JSONSerializer JSONSerializer + Validator Validator + Renderer Renderer + Logger Logger + IPExtractor IPExtractor + ListenerNetwork string + + // OnAddRouteHandler is called when Echo adds new route to specific host router. + OnAddRouteHandler func(host string, route Route, handler HandlerFunc, middleware []MiddlewareFunc) + DisableHTTP2 bool + Debug bool + HideBanner bool + HidePort bool +} + +// Route contains a handler and information for matching against requests. +type Route struct { + Method string `json:"method"` + Path string `json:"path"` + Name string `json:"name"` +} + +// HTTPError represents an error that occurred while handling a request. +type HTTPError struct { + Internal error `json:"-"` // Stores the error returned by an external dependency + Message interface{} `json:"message"` + Code int `json:"-"` +} + +// MiddlewareFunc defines a function to process middleware. +type MiddlewareFunc func(next HandlerFunc) HandlerFunc + +// HandlerFunc defines a function to serve HTTP requests. +type HandlerFunc func(c Context) error + +// HTTPErrorHandler is a centralized HTTP error handler. +type HTTPErrorHandler func(err error, c Context) + +// Validator is the interface that wraps the Validate function. +type Validator interface { + Validate(i interface{}) error +} + +// JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. +type JSONSerializer interface { + Serialize(c Context, i interface{}, indent string) error + Deserialize(c Context, i interface{}) error +} + +// Map defines a generic map of type `map[string]interface{}`. +type Map map[string]interface{} + +// Common struct for Echo & Group. +type common struct{} // HTTP methods // NOTE: Deprecated, please use the stdlib constants directly instead. @@ -169,7 +164,12 @@ const ( // MIME types const ( - MIMEApplicationJSON = "application/json" + // MIMEApplicationJSON JavaScript Object Notation (JSON) https://www.rfc-editor.org/rfc/rfc8259 + MIMEApplicationJSON = "application/json" + // Deprecated: Please use MIMEApplicationJSON instead. JSON should be encoded using UTF-8 by default. + // No "charset" parameter is defined for this registration. + // Adding one really has no effect on compliant recipients. + // See RFC 8259, section 8.1. https://datatracker.ietf.org/doc/html/rfc8259#section-8.1 MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8 MIMEApplicationJavaScript = "application/javascript" MIMEApplicationJavaScriptCharsetUTF8 = MIMEApplicationJavaScript + "; " + charsetUTF8 @@ -259,7 +259,7 @@ const ( const ( // Version of Echo - Version = "4.11.1" + Version = "4.13.4" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` @@ -274,21 +274,19 @@ ____________________________________O/_______ ` ) -var ( - methods = [...]string{ - http.MethodConnect, - http.MethodDelete, - http.MethodGet, - http.MethodHead, - http.MethodOptions, - http.MethodPatch, - http.MethodPost, - PROPFIND, - http.MethodPut, - http.MethodTrace, - REPORT, - } -) +var methods = [...]string{ + http.MethodConnect, + http.MethodDelete, + http.MethodGet, + http.MethodHead, + http.MethodOptions, + http.MethodPatch, + http.MethodPost, + PROPFIND, + http.MethodPut, + http.MethodTrace, + REPORT, +} // Errors var ( @@ -341,22 +339,23 @@ var ( ErrInvalidListenerNetwork = errors.New("invalid listener network") ) -// Error handlers -var ( - NotFoundHandler = func(c Context) error { - return ErrNotFound - } +// NotFoundHandler is the handler that router uses in case there was no matching route found. Returns an error that results +// HTTP 404 status code. +var NotFoundHandler = func(c Context) error { + return ErrNotFound +} - MethodNotAllowedHandler = func(c Context) error { - // See RFC 7231 section 7.4.1: An origin server MUST generate an Allow field in a 405 (Method Not Allowed) - // response and MAY do so in any other response. For disabled resources an empty Allow header may be returned - routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string) - if ok && routerAllowMethods != "" { - c.Response().Header().Set(HeaderAllow, routerAllowMethods) - } - return ErrMethodNotAllowed +// MethodNotAllowedHandler is the handler thar router uses in case there was no matching route found but there was +// another matching routes for that requested URL. Returns an error that results HTTP 405 Method Not Allowed status code. +var MethodNotAllowedHandler = func(c Context) error { + // See RFC 7231 section 7.4.1: An origin server MUST generate an Allow field in a 405 (Method Not Allowed) + // response and MAY do so in any other response. For disabled resources an empty Allow header may be returned + routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string) + if ok && routerAllowMethods != "" { + c.Response().Header().Set(HeaderAllow, routerAllowMethods) } -) + return ErrMethodNotAllowed +} // New creates an instance of Echo. func New() (e *Echo) { @@ -414,7 +413,7 @@ func (e *Echo) Routers() map[string]*Router { // // NOTE: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). // When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from -// handler. Then the error that global error handler received will be ignored because we have already "commited" the +// handler. Then the error that global error handler received will be ignored because we have already "committed" the // response and status code header has been sent to the client. func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { diff --git a/echo_fs.go b/echo_fs.go index 9f83a0351..0ffc4b0bf 100644 --- a/echo_fs.go +++ b/echo_fs.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( @@ -99,8 +102,8 @@ func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { // traverse up from current executable run path. // NB: private because you really should use fs.FS implementation instances type defaultFS struct { - prefix string fs fs.FS + prefix string } func newDefaultFS() *defaultFS { diff --git a/echo_fs_test.go b/echo_fs_test.go index eb072a28d..ab8faa7fa 100644 --- a/echo_fs_test.go +++ b/echo_fs_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( @@ -242,19 +245,16 @@ func TestEcho_FileFS(t *testing.T) { func TestEcho_StaticPanic(t *testing.T) { var testCases = []struct { - name string - givenRoot string - expectError string + name string + givenRoot string }{ { - name: "panics for ../", - givenRoot: "../assets", - expectError: "can not create sub FS, invalid root given, err: sub ../assets: invalid name", + name: "panics for ../", + givenRoot: "../assets", }, { - name: "panics for /", - givenRoot: "/assets", - expectError: "can not create sub FS, invalid root given, err: sub /assets: invalid name", + name: "panics for /", + givenRoot: "/assets", }, } @@ -263,7 +263,7 @@ func TestEcho_StaticPanic(t *testing.T) { e := New() e.Filesystem = os.DirFS("./") - assert.PanicsWithError(t, tc.expectError, func() { + assert.Panics(t, func() { e.Static("../assets", tc.givenRoot) }) }) diff --git a/echo_test.go b/echo_test.go index a352e4026..b7f32017a 100644 --- a/echo_test.go +++ b/echo_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( @@ -22,12 +25,10 @@ import ( "golang.org/x/net/http2" ) -type ( - user struct { - ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"` - Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"` - } -) +type user struct { + ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"` + Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"` +} const ( userJSON = `{"id":1,"name":"Jon Snow"}` @@ -1572,7 +1573,7 @@ func TestEcho_OnAddRouteHandler(t *testing.T) { }) } - e.GET("/static", NotFoundHandler) + e.GET("/static", dummyHandler) e.Host("domain.site").GET("/static/*", dummyHandler, func(next HandlerFunc) HandlerFunc { return func(c Context) error { return next(c) @@ -1582,7 +1583,7 @@ func TestEcho_OnAddRouteHandler(t *testing.T) { assert.Len(t, added, 2) assert.Equal(t, "", added[0].host) - assert.Equal(t, Route{Method: http.MethodGet, Path: "/static", Name: "github.com/labstack/echo/v4.glob..func1"}, added[0].route) + assert.Equal(t, Route{Method: http.MethodGet, Path: "/static", Name: "github.com/labstack/echo/v4.TestEcho_OnAddRouteHandler.func1"}, added[0].route) assert.Len(t, added[0].middleware, 0) assert.Equal(t, "domain.site", added[1].host) @@ -1597,6 +1598,11 @@ func TestEchoReverse(t *testing.T) { whenParams []interface{} expect string }{ + { + name: "ok, not existing path returns empty url", + whenRouteName: "not-existing", + expect: "", + }, { name: "ok,static with no params", whenRouteName: "/static", diff --git a/go.mod b/go.mod index fe2fd4e54..caaeec44b 100644 --- a/go.mod +++ b/go.mod @@ -1,24 +1,23 @@ module github.com/labstack/echo/v4 -go 1.17 +go 1.23.0 require ( - github.com/golang-jwt/jwt v3.2.2+incompatible - github.com/labstack/gommon v0.4.0 - github.com/stretchr/testify v1.8.1 + github.com/labstack/gommon v0.4.2 + github.com/stretchr/testify v1.10.0 github.com/valyala/fasttemplate v1.2.2 - golang.org/x/crypto v0.11.0 - golang.org/x/net v0.12.0 - golang.org/x/time v0.3.0 + golang.org/x/crypto v0.41.0 + golang.org/x/net v0.43.0 + golang.org/x/time v0.12.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.10.0 // indirect - golang.org/x/text v0.11.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 41490b181..9306cb9e6 100644 --- a/go.sum +++ b/go.sum @@ -1,87 +1,31 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= -github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8= -github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= -github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= -github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= +github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= -golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= -golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= -golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= -golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= -golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/group.go b/group.go index 749a5caab..cb37b123f 100644 --- a/group.go +++ b/group.go @@ -1,21 +1,22 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( "net/http" ) -type ( - // Group is a set of sub-routes for a specified route. It can be used for inner - // routes that share a common middleware or functionality that should be separate - // from the parent echo instance while still inheriting from it. - Group struct { - common - host string - prefix string - middleware []MiddlewareFunc - echo *Echo - } -) +// Group is a set of sub-routes for a specified route. It can be used for inner +// routes that share a common middleware or functionality that should be separate +// from the parent echo instance while still inheriting from it. +type Group struct { + common + host string + prefix string + echo *Echo + middleware []MiddlewareFunc +} // Use implements `Echo#Use()` for sub-routes within the Group. func (g *Group) Use(middleware ...MiddlewareFunc) { diff --git a/group_fs.go b/group_fs.go index aedc4c6a9..c1b7ec2d3 100644 --- a/group_fs.go +++ b/group_fs.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/group_fs_test.go b/group_fs_test.go index 958d9efb1..caa200940 100644 --- a/group_fs_test.go +++ b/group_fs_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( @@ -72,19 +75,16 @@ func TestGroup_FileFS(t *testing.T) { func TestGroup_StaticPanic(t *testing.T) { var testCases = []struct { - name string - givenRoot string - expectError string + name string + givenRoot string }{ { - name: "panics for ../", - givenRoot: "../images", - expectError: "can not create sub FS, invalid root given, err: sub ../images: invalid name", + name: "panics for ../", + givenRoot: "../images", }, { - name: "panics for /", - givenRoot: "/images", - expectError: "can not create sub FS, invalid root given, err: sub /images: invalid name", + name: "panics for /", + givenRoot: "/images", }, } @@ -95,7 +95,7 @@ func TestGroup_StaticPanic(t *testing.T) { g := e.Group("/assets") - assert.PanicsWithError(t, tc.expectError, func() { + assert.Panics(t, func() { g.Static("/images", tc.givenRoot) }) }) diff --git a/group_test.go b/group_test.go index d22f564b0..a97371418 100644 --- a/group_test.go +++ b/group_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/ip.go b/ip.go index 1bcd756ae..1fcd750ec 100644 --- a/ip.go +++ b/ip.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( @@ -21,7 +24,7 @@ To retrieve IP address reliably/securely, you must let your application be aware In Echo, this can be done by configuring `Echo#IPExtractor` appropriately. This guides show you why and how. -> Note: if you dont' set `Echo#IPExtractor` explicitly, Echo fallback to legacy behavior, which is not a good choice. +> Note: if you don't set `Echo#IPExtractor` explicitly, Echo fallback to legacy behavior, which is not a good choice. Let's start from two questions to know the right direction: @@ -64,7 +67,7 @@ XFF: "x" "x, a" "x, a, b" ``` In this case, use **first _untrustable_ IP reading from right**. Never use first one reading from left, as it is -configurable by client. Here "trustable" means "you are sure the IP address belongs to your infrastructre". +configurable by client. Here "trustable" means "you are sure the IP address belongs to your infrastructure". In above example, if `b` and `c` are trustable, the IP address of the client is `a` for both cases, never be `x`. In Echo, use `ExtractIPFromXFFHeader(...TrustOption)`. @@ -131,10 +134,10 @@ Private IPv6 address ranges: */ type ipChecker struct { + trustExtraRanges []*net.IPNet trustLoopback bool trustLinkLocal bool trustPrivateNet bool - trustExtraRanges []*net.IPNet } // TrustOption is config for which IP address to trust @@ -216,8 +219,14 @@ func ExtractIPDirect() IPExtractor { } func extractIP(req *http.Request) string { - ra, _, _ := net.SplitHostPort(req.RemoteAddr) - return ra + host, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + if net.ParseIP(req.RemoteAddr) != nil { + return req.RemoteAddr + } + return "" + } + return host } // ExtractIPFromRealIPHeader extracts IP address using x-real-ip header. @@ -225,15 +234,21 @@ func extractIP(req *http.Request) string { func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { checker := newIPChecker(options) return func(req *http.Request) string { + directIP := extractIP(req) realIP := req.Header.Get(HeaderXRealIP) - if realIP != "" { + if realIP == "" { + return directIP + } + + if checker.trust(net.ParseIP(directIP)) { realIP = strings.TrimPrefix(realIP, "[") realIP = strings.TrimSuffix(realIP, "]") - if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) { + if rIP := net.ParseIP(realIP); rIP != nil { return realIP } } - return extractIP(req) + + return directIP } } diff --git a/ip_test.go b/ip_test.go index 38c4a1cac..e850b78cb 100644 --- a/ip_test.go +++ b/ip_test.go @@ -1,10 +1,14 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( - "github.com/stretchr/testify/assert" "net" "net/http" "testing" + + "github.com/stretchr/testify/assert" ) func mustParseCIDR(s string) *net.IPNet { @@ -375,6 +379,34 @@ func TestExtractIPDirect(t *testing.T) { }, expectIP: "203.0.113.1", }, + { + name: "remote addr is IP without port, extracts IP directly", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1", + }, + expectIP: "203.0.113.1", + }, + { + name: "remote addr is IPv6 without port, extracts IP directly", + whenRequest: http.Request{ + RemoteAddr: "2001:db8::1", + }, + expectIP: "2001:db8::1", + }, + { + name: "remote addr is IPv6 with port", + whenRequest: http.Request{ + RemoteAddr: "[2001:db8::1]:8080", + }, + expectIP: "2001:db8::1", + }, + { + name: "remote addr is invalid, returns empty string", + whenRequest: http.Request{ + RemoteAddr: "invalid-ip-format", + }, + expectIP: "", + }, { name: "request is from external IP has X-Real-Ip header, extractor still extracts IP from request remote addr", whenRequest: http.Request{ @@ -458,7 +490,7 @@ func TestExtractIPDirect(t *testing.T) { } func TestExtractIPFromRealIPHeader(t *testing.T) { - _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") + _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.0/24") _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64") var testCases = []struct { @@ -486,36 +518,42 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, { name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"203.0.113.199"}, // <-- this is untrusted + HeaderXRealIP: []string{"203.0.113.199"}, }, - RemoteAddr: "203.0.113.1:8080", + RemoteAddr: "8.8.8.8:8080", // <-- this is untrusted }, - expectIP: "203.0.113.1", + expectIP: "8.8.8.8", }, { name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"[2001:db8::113:199]"}, // <-- this is untrusted + HeaderXRealIP: []string{"[bc01:1010::9090:1888]"}, }, - RemoteAddr: "[2001:db8::113:1]:8080", + RemoteAddr: "[fe64:aa10::1]:8080", // <-- this is untrusted }, - expectIP: "2001:db8::113:1", + expectIP: "fe64:aa10::1", }, { name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" - TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.0/24" }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"203.0.113.199"}, + HeaderXRealIP: []string{"8.8.8.8"}, }, RemoteAddr: "203.0.113.1:8080", }, - expectIP: "203.0.113.199", + expectIP: "8.8.8.8", }, { name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", @@ -524,11 +562,11 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"[2001:db8::113:199]"}, + HeaderXRealIP: []string{"[fe64:db8::113:199]"}, }, RemoteAddr: "[2001:db8::113:1]:8080", }, - expectIP: "2001:db8::113:199", + expectIP: "fe64:db8::113:199", }, { name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", @@ -537,12 +575,12 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"203.0.113.199"}, - HeaderXForwardedFor: []string{"203.0.113.198, 203.0.113.197"}, // <-- should not affect anything + HeaderXRealIP: []string{"8.8.8.8"}, + HeaderXForwardedFor: []string{"1.1.1.1 ,8.8.8.8"}, // <-- should not affect anything }, RemoteAddr: "203.0.113.1:8080", }, - expectIP: "203.0.113.199", + expectIP: "8.8.8.8", }, { name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", @@ -551,12 +589,12 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"[2001:db8::113:199]"}, - HeaderXForwardedFor: []string{"[2001:db8::113:198], [2001:db8::113:197]"}, // <-- should not affect anything + HeaderXRealIP: []string{"[fe64:db8::113:199]"}, + HeaderXForwardedFor: []string{"[feab:cde9::113:198], [fe64:db8::113:199]"}, // <-- should not affect anything }, RemoteAddr: "[2001:db8::113:1]:8080", }, - expectIP: "2001:db8::113:199", + expectIP: "fe64:db8::113:199", }, } diff --git a/json.go b/json.go index 16b2d0577..6da0aaf97 100644 --- a/json.go +++ b/json.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/json_test.go b/json_test.go index 27ee43e73..0b15ed1a1 100644 --- a/json_test.go +++ b/json_test.go @@ -1,7 +1,10 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( - testify "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" "strings" @@ -16,16 +19,14 @@ func TestDefaultJSONCodec_Encode(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) - assert := testify.New(t) - // Echo - assert.Equal(e, c.Echo()) + assert.Equal(t, e, c.Echo()) // Request - assert.NotNil(c.Request()) + assert.NotNil(t, c.Request()) // Response - assert.NotNil(c.Response()) + assert.NotNil(t, c.Response()) //-------- // Default JSON encoder @@ -34,16 +35,16 @@ func TestDefaultJSONCodec_Encode(t *testing.T) { enc := new(DefaultJSONSerializer) err := enc.Serialize(c, user{1, "Jon Snow"}, "") - if assert.NoError(err) { - assert.Equal(userJSON+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, userJSON+"\n", rec.Body.String()) } req = httptest.NewRequest(http.MethodPost, "/", nil) rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = enc.Serialize(c, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(userJSONPretty+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } } @@ -55,16 +56,14 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) - assert := testify.New(t) - // Echo - assert.Equal(e, c.Echo()) + assert.Equal(t, e, c.Echo()) // Request - assert.NotNil(c.Request()) + assert.NotNil(t, c.Request()) // Response - assert.NotNil(c.Response()) + assert.NotNil(t, c.Response()) //-------- // Default JSON encoder @@ -74,8 +73,8 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { var u = user{} err := enc.Deserialize(c, &u) - if assert.NoError(err) { - assert.Equal(u, user{ID: 1, Name: "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, u, user{ID: 1, Name: "Jon Snow"}) } var userUnmarshalSyntaxError = user{} @@ -83,8 +82,8 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = enc.Deserialize(c, &userUnmarshalSyntaxError) - assert.IsType(&HTTPError{}, err) - assert.EqualError(err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value") + assert.IsType(t, &HTTPError{}, err) + assert.EqualError(t, err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value") var userUnmarshalTypeError = struct { ID string `json:"id"` @@ -95,7 +94,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = enc.Deserialize(c, &userUnmarshalTypeError) - assert.IsType(&HTTPError{}, err) - assert.EqualError(err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string") + assert.IsType(t, &HTTPError{}, err) + assert.EqualError(t, err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string") } diff --git a/log.go b/log.go index 3f8de5904..0acd9ff03 100644 --- a/log.go +++ b/log.go @@ -1,41 +1,41 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( - "io" - "github.com/labstack/gommon/log" + "io" ) -type ( - // Logger defines the logging interface. - Logger interface { - Output() io.Writer - SetOutput(w io.Writer) - Prefix() string - SetPrefix(p string) - Level() log.Lvl - SetLevel(v log.Lvl) - SetHeader(h string) - Print(i ...interface{}) - Printf(format string, args ...interface{}) - Printj(j log.JSON) - Debug(i ...interface{}) - Debugf(format string, args ...interface{}) - Debugj(j log.JSON) - Info(i ...interface{}) - Infof(format string, args ...interface{}) - Infoj(j log.JSON) - Warn(i ...interface{}) - Warnf(format string, args ...interface{}) - Warnj(j log.JSON) - Error(i ...interface{}) - Errorf(format string, args ...interface{}) - Errorj(j log.JSON) - Fatal(i ...interface{}) - Fatalj(j log.JSON) - Fatalf(format string, args ...interface{}) - Panic(i ...interface{}) - Panicj(j log.JSON) - Panicf(format string, args ...interface{}) - } -) +// Logger defines the logging interface. +type Logger interface { + Output() io.Writer + SetOutput(w io.Writer) + Prefix() string + SetPrefix(p string) + Level() log.Lvl + SetLevel(v log.Lvl) + SetHeader(h string) + Print(i ...interface{}) + Printf(format string, args ...interface{}) + Printj(j log.JSON) + Debug(i ...interface{}) + Debugf(format string, args ...interface{}) + Debugj(j log.JSON) + Info(i ...interface{}) + Infof(format string, args ...interface{}) + Infoj(j log.JSON) + Warn(i ...interface{}) + Warnf(format string, args ...interface{}) + Warnj(j log.JSON) + Error(i ...interface{}) + Errorf(format string, args ...interface{}) + Errorj(j log.JSON) + Fatal(i ...interface{}) + Fatalj(j log.JSON) + Fatalf(format string, args ...interface{}) + Panic(i ...interface{}) + Panicj(j log.JSON) + Panicf(format string, args ...interface{}) +} diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index f9e8caafe..9285f29fd 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -9,37 +12,35 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // BasicAuthConfig defines the config for BasicAuth middleware. - BasicAuthConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// BasicAuthConfig defines the config for BasicAuth middleware. +type BasicAuthConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Validator is a function to validate BasicAuth credentials. - // Required. - Validator BasicAuthValidator + // Validator is a function to validate BasicAuth credentials. + // Required. + Validator BasicAuthValidator - // Realm is a string to define realm attribute of BasicAuth. - // Default value "Restricted". - Realm string - } + // Realm is a string to define realm attribute of BasicAuth. + // Default value "Restricted". + Realm string +} - // BasicAuthValidator defines a function to validate BasicAuth credentials. - BasicAuthValidator func(string, string, echo.Context) (bool, error) -) +// BasicAuthValidator defines a function to validate BasicAuth credentials. +// The function should return a boolean indicating whether the credentials are valid, +// and an error if any error occurs during the validation process. +type BasicAuthValidator func(string, string, echo.Context) (bool, error) const ( basic = "basic" defaultRealm = "Restricted" ) -var ( - // DefaultBasicAuthConfig is the default BasicAuth middleware config. - DefaultBasicAuthConfig = BasicAuthConfig{ - Skipper: DefaultSkipper, - Realm: defaultRealm, - } -) +// DefaultBasicAuthConfig is the default BasicAuth middleware config. +var DefaultBasicAuthConfig = BasicAuthConfig{ + Skipper: DefaultSkipper, + Realm: defaultRealm, +} // BasicAuth returns an BasicAuth middleware. // diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 20e769214..b3abfa172 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -1,7 +1,11 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( "encoding/base64" + "errors" "net/http" "net/http/httptest" "strings" @@ -13,63 +17,103 @@ import ( func TestBasicAuth(t *testing.T) { e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - c := e.NewContext(req, res) - f := func(u, p string, c echo.Context) (bool, error) { + + mockValidator := func(u, p string, c echo.Context) (bool, error) { if u == "joe" && p == "secret" { return true, nil } return false, nil } - h := BasicAuth(f)(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - - // Valid credentials - auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(t, h(c)) - h = BasicAuthWithConfig(BasicAuthConfig{ - Skipper: nil, - Validator: f, - Realm: "someRealm", - })(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) + tests := []struct { + name string + authHeader string + expectedCode int + expectedAuth string + skipperResult bool + expectedErr bool + expectedErrMsg string + }{ + { + name: "Valid credentials", + authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + expectedCode: http.StatusOK, + }, + { + name: "Case-insensitive header scheme", + authHeader: strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + expectedCode: http.StatusOK, + }, + { + name: "Invalid credentials", + authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")), + expectedCode: http.StatusUnauthorized, + expectedAuth: basic + ` realm="someRealm"`, + expectedErr: true, + expectedErrMsg: "Unauthorized", + }, + { + name: "Invalid base64 string", + authHeader: basic + " invalidString", + expectedCode: http.StatusBadRequest, + expectedErr: true, + expectedErrMsg: "Bad Request", + }, + { + name: "Missing Authorization header", + expectedCode: http.StatusUnauthorized, + expectedErr: true, + expectedErrMsg: "Unauthorized", + }, + { + name: "Invalid Authorization header", + authHeader: base64.StdEncoding.EncodeToString([]byte("invalid")), + expectedCode: http.StatusUnauthorized, + expectedErr: true, + expectedErrMsg: "Unauthorized", + }, + { + name: "Skipped Request", + authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")), + expectedCode: http.StatusOK, + skipperResult: true, + }, + } - // Valid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(t, h(c)) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { - // Case-insensitive header scheme - auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(t, h(c)) + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + c := e.NewContext(req, res) - // Invalid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")) - req.Header.Set(echo.HeaderAuthorization, auth) - he := h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) - assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate)) + if tt.authHeader != "" { + req.Header.Set(echo.HeaderAuthorization, tt.authHeader) + } - // Invalid base64 string - auth = basic + " invalidString" - req.Header.Set(echo.HeaderAuthorization, auth) - he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusBadRequest, he.Code) + h := BasicAuthWithConfig(BasicAuthConfig{ + Validator: mockValidator, + Realm: "someRealm", + Skipper: func(c echo.Context) bool { + return tt.skipperResult + }, + })(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) - // Missing Authorization header - req.Header.Del(echo.HeaderAuthorization) - he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) + err := h(c) - // Invalid Authorization header - auth = base64.StdEncoding.EncodeToString([]byte("invalid")) - req.Header.Set(echo.HeaderAuthorization, auth) - he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) + if tt.expectedErr { + var he *echo.HTTPError + errors.As(err, &he) + assert.Equal(t, tt.expectedCode, he.Code) + if tt.expectedAuth != "" { + assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate)) + } + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedCode, res.Code) + } + }) + } } diff --git a/middleware/body_dump.go b/middleware/body_dump.go index fa7891b16..e4119ec1e 100644 --- a/middleware/body_dump.go +++ b/middleware/body_dump.go @@ -1,8 +1,12 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( "bufio" "bytes" + "errors" "io" "net" "net/http" @@ -10,32 +14,28 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // BodyDumpConfig defines the config for BodyDump middleware. - BodyDumpConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// BodyDumpConfig defines the config for BodyDump middleware. +type BodyDumpConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Handler receives request and response payload. - // Required. - Handler BodyDumpHandler - } + // Handler receives request and response payload. + // Required. + Handler BodyDumpHandler +} - // BodyDumpHandler receives the request and response payload. - BodyDumpHandler func(echo.Context, []byte, []byte) +// BodyDumpHandler receives the request and response payload. +type BodyDumpHandler func(echo.Context, []byte, []byte) - bodyDumpResponseWriter struct { - io.Writer - http.ResponseWriter - } -) +type bodyDumpResponseWriter struct { + io.Writer + http.ResponseWriter +} -var ( - // DefaultBodyDumpConfig is the default BodyDump middleware config. - DefaultBodyDumpConfig = BodyDumpConfig{ - Skipper: DefaultSkipper, - } -) +// DefaultBodyDumpConfig is the default BodyDump middleware config. +var DefaultBodyDumpConfig = BodyDumpConfig{ + Skipper: DefaultSkipper, +} // BodyDump returns a BodyDump middleware. // @@ -98,9 +98,16 @@ func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) { } func (w *bodyDumpResponseWriter) Flush() { - w.ResponseWriter.(http.Flusher).Flush() + err := http.NewResponseController(w.ResponseWriter).Flush() + if err != nil && errors.Is(err, http.ErrNotSupported) { + panic(errors.New("response writer flushing is not supported")) + } } func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return w.ResponseWriter.(http.Hijacker).Hijack() + return http.NewResponseController(w.ResponseWriter).Hijack() +} + +func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter } diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go index de1de3356..e880af45b 100644 --- a/middleware/body_dump_test.go +++ b/middleware/body_dump_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -87,3 +90,53 @@ func TestBodyDumpFails(t *testing.T) { } }) } + +func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) { + bdrw := bodyDumpResponseWriter{ + ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush + } + + assert.PanicsWithError(t, "response writer flushing is not supported", func() { + bdrw.Flush() + }) +} + +func TestBodyDumpResponseWriter_CanFlush(t *testing.T) { + trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}} + bdrw := bodyDumpResponseWriter{ + ResponseWriter: &trwu, + } + + bdrw.Flush() + assert.Equal(t, 1, trwu.unwrapCalled) +} + +func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) { + trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()} + bdrw := bodyDumpResponseWriter{ + ResponseWriter: trwu, + } + + result := bdrw.Unwrap() + assert.Equal(t, trwu, result) +} + +func TestBodyDumpResponseWriter_CanHijack(t *testing.T) { + trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}} + bdrw := bodyDumpResponseWriter{ + ResponseWriter: &trwu, // this RW supports hijacking through unwrapping + } + + _, _, err := bdrw.Hijack() + assert.EqualError(t, err, "can hijack") +} + +func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) { + trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()} + bdrw := bodyDumpResponseWriter{ + ResponseWriter: &trwu, // this RW supports hijacking through unwrapping + } + + _, _, err := bdrw.Hijack() + assert.EqualError(t, err, "feature not supported") +} diff --git a/middleware/body_limit.go b/middleware/body_limit.go index b436bd595..7d3c665f2 100644 --- a/middleware/body_limit.go +++ b/middleware/body_limit.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -9,32 +12,27 @@ import ( "github.com/labstack/gommon/bytes" ) -type ( - // BodyLimitConfig defines the config for BodyLimit middleware. - BodyLimitConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// BodyLimitConfig defines the config for BodyLimit middleware. +type BodyLimitConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Maximum allowed size for a request body, it can be specified - // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P. - Limit string `yaml:"limit"` - limit int64 - } + // Maximum allowed size for a request body, it can be specified + // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P. + Limit string `yaml:"limit"` + limit int64 +} - limitedReader struct { - BodyLimitConfig - reader io.ReadCloser - read int64 - context echo.Context - } -) +type limitedReader struct { + BodyLimitConfig + reader io.ReadCloser + read int64 +} -var ( - // DefaultBodyLimitConfig is the default BodyLimit middleware config. - DefaultBodyLimitConfig = BodyLimitConfig{ - Skipper: DefaultSkipper, - } -) +// DefaultBodyLimitConfig is the default BodyLimit middleware config. +var DefaultBodyLimitConfig = BodyLimitConfig{ + Skipper: DefaultSkipper, +} // BodyLimit returns a BodyLimit middleware. // @@ -80,7 +78,7 @@ func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { // Based on content read r := pool.Get().(*limitedReader) - r.Reset(req.Body, c) + r.Reset(req.Body) defer pool.Put(r) req.Body = r @@ -102,9 +100,8 @@ func (r *limitedReader) Close() error { return r.reader.Close() } -func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) { +func (r *limitedReader) Reset(reader io.ReadCloser) { r.reader = reader - r.context = context r.read = 0 } diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index 2bfce372a..d14c2b649 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -56,9 +59,6 @@ func TestBodyLimit(t *testing.T) { func TestBodyLimitReader(t *testing.T) { hw := []byte("Hello, World!") - e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) - rec := httptest.NewRecorder() config := BodyLimitConfig{ Skipper: DefaultSkipper, @@ -68,7 +68,6 @@ func TestBodyLimitReader(t *testing.T) { reader := &limitedReader{ BodyLimitConfig: config, reader: io.NopCloser(bytes.NewReader(hw)), - context: e.NewContext(req, rec), } // read all should return ErrStatusRequestEntityTooLarge @@ -78,7 +77,7 @@ func TestBodyLimitReader(t *testing.T) { // reset reader and read two bytes must succeed bt := make([]byte, 2) - reader.Reset(io.NopCloser(bytes.NewReader(hw)), e.NewContext(req, rec)) + reader.Reset(io.NopCloser(bytes.NewReader(hw))) n, err := reader.Read(bt) assert.Equal(t, 2, n) assert.Equal(t, nil, err) diff --git a/middleware/compress.go b/middleware/compress.go index 3e9bd3201..012b76b01 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -13,54 +16,50 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // GzipConfig defines the config for Gzip middleware. - GzipConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Gzip compression level. - // Optional. Default value -1. - Level int `yaml:"level"` - - // Length threshold before gzip compression is applied. - // Optional. Default value 0. - // - // Most of the time you will not need to change the default. Compressing - // a short response might increase the transmitted data because of the - // gzip format overhead. Compressing the response will also consume CPU - // and time on the server and the client (for decompressing). Depending on - // your use case such a threshold might be useful. - // - // See also: - // https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits - MinLength int - } +// GzipConfig defines the config for Gzip middleware. +type GzipConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Gzip compression level. + // Optional. Default value -1. + Level int `yaml:"level"` + + // Length threshold before gzip compression is applied. + // Optional. Default value 0. + // + // Most of the time you will not need to change the default. Compressing + // a short response might increase the transmitted data because of the + // gzip format overhead. Compressing the response will also consume CPU + // and time on the server and the client (for decompressing). Depending on + // your use case such a threshold might be useful. + // + // See also: + // https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits + MinLength int +} - gzipResponseWriter struct { - io.Writer - http.ResponseWriter - wroteHeader bool - wroteBody bool - minLength int - minLengthExceeded bool - buffer *bytes.Buffer - code int - } -) +type gzipResponseWriter struct { + io.Writer + http.ResponseWriter + wroteHeader bool + wroteBody bool + minLength int + minLengthExceeded bool + buffer *bytes.Buffer + code int +} const ( gzipScheme = "gzip" ) -var ( - // DefaultGzipConfig is the default Gzip middleware config. - DefaultGzipConfig = GzipConfig{ - Skipper: DefaultSkipper, - Level: -1, - MinLength: 0, - } -) +// DefaultGzipConfig is the default Gzip middleware config. +var DefaultGzipConfig = GzipConfig{ + Skipper: DefaultSkipper, + Level: -1, + MinLength: 0, +} // Gzip returns a middleware which compresses HTTP response using gzip compression // scheme. @@ -191,13 +190,15 @@ func (w *gzipResponseWriter) Flush() { } w.Writer.(*gzip.Writer).Flush() - if flusher, ok := w.ResponseWriter.(http.Flusher); ok { - flusher.Flush() - } + _ = http.NewResponseController(w.ResponseWriter).Flush() +} + +func (w *gzipResponseWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter } func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return w.ResponseWriter.(http.Hijacker).Hijack() + return http.NewResponseController(w.ResponseWriter).Hijack() } func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { diff --git a/middleware/compress_test.go b/middleware/compress_test.go index 0ed16c813..4bbdfdbc2 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -311,6 +314,36 @@ func TestGzipWithStatic(t *testing.T) { } } +func TestGzipResponseWriter_CanUnwrap(t *testing.T) { + trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()} + bdrw := gzipResponseWriter{ + ResponseWriter: trwu, + } + + result := bdrw.Unwrap() + assert.Equal(t, trwu, result) +} + +func TestGzipResponseWriter_CanHijack(t *testing.T) { + trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}} + bdrw := gzipResponseWriter{ + ResponseWriter: &trwu, // this RW supports hijacking through unwrapping + } + + _, _, err := bdrw.Hijack() + assert.EqualError(t, err, "can hijack") +} + +func TestGzipResponseWriter_CanNotHijack(t *testing.T) { + trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()} + bdrw := gzipResponseWriter{ + ResponseWriter: &trwu, // this RW supports hijacking through unwrapping + } + + _, _, err := bdrw.Hijack() + assert.EqualError(t, err, "feature not supported") +} + func BenchmarkGzip(b *testing.B) { e := echo.New() diff --git a/middleware/context_timeout.go b/middleware/context_timeout.go index be260e188..e67173f21 100644 --- a/middleware/context_timeout.go +++ b/middleware/context_timeout.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -13,7 +16,7 @@ type ContextTimeoutConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper - // ErrorHandler is a function when error aries in middeware execution. + // ErrorHandler is a function when error aries in middleware execution. ErrorHandler func(err error, c echo.Context) error // Timeout configures a timeout for the middleware, defaults to 0 for no timeout diff --git a/middleware/context_timeout_test.go b/middleware/context_timeout_test.go index 24c6203e7..e69bcd268 100644 --- a/middleware/context_timeout_test.go +++ b/middleware/context_timeout_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/cors.go b/middleware/cors.go index 6ddb540af..a1f445321 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -9,112 +12,109 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // CORSConfig defines the config for CORS middleware. - CORSConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // AllowOrigins determines the value of the Access-Control-Allow-Origin - // response header. This header defines a list of origins that may access the - // resource. The wildcard characters '*' and '?' are supported and are - // converted to regex fragments '.*' and '.' accordingly. - // - // Security: use extreme caution when handling the origin, and carefully - // validate any logic. Remember that attackers may register hostile domain names. - // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html - // - // Optional. Default value []string{"*"}. - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin - AllowOrigins []string `yaml:"allow_origins"` - - // AllowOriginFunc is a custom function to validate the origin. It takes the - // origin as an argument and returns true if allowed or false otherwise. If - // an error is returned, it is returned by the handler. If this option is - // set, AllowOrigins is ignored. - // - // Security: use extreme caution when handling the origin, and carefully - // validate any logic. Remember that attackers may register hostile domain names. - // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html - // - // Optional. - AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` - - // AllowMethods determines the value of the Access-Control-Allow-Methods - // response header. This header specified the list of methods allowed when - // accessing the resource. This is used in response to a preflight request. - // - // Optional. Default value DefaultCORSConfig.AllowMethods. - // If `allowMethods` is left empty, this middleware will fill for preflight - // request `Access-Control-Allow-Methods` header value - // from `Allow` header that echo.Router set into context. - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods - AllowMethods []string `yaml:"allow_methods"` - - // AllowHeaders determines the value of the Access-Control-Allow-Headers - // response header. This header is used in response to a preflight request to - // indicate which HTTP headers can be used when making the actual request. - // - // Optional. Default value []string{}. - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers - AllowHeaders []string `yaml:"allow_headers"` - - // AllowCredentials determines the value of the - // Access-Control-Allow-Credentials response header. This header indicates - // whether or not the response to the request can be exposed when the - // credentials mode (Request.credentials) is true. When used as part of a - // response to a preflight request, this indicates whether or not the actual - // request can be made using credentials. See also - // [MDN: Access-Control-Allow-Credentials]. - // - // Optional. Default value false, in which case the header is not set. - // - // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. - // See "Exploiting CORS misconfigurations for Bitcoins and bounties", - // https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials - AllowCredentials bool `yaml:"allow_credentials"` - - // UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials - // flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header. - // - // This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties) - // attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject. - // - // Optional. Default value is false. - UnsafeWildcardOriginWithAllowCredentials bool `yaml:"unsafe_wildcard_origin_with_allow_credentials"` - - // ExposeHeaders determines the value of Access-Control-Expose-Headers, which - // defines a list of headers that clients are allowed to access. - // - // Optional. Default value []string{}, in which case the header is not set. - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header - ExposeHeaders []string `yaml:"expose_headers"` - - // MaxAge determines the value of the Access-Control-Max-Age response header. - // This header indicates how long (in seconds) the results of a preflight - // request can be cached. - // - // Optional. Default value 0. The header is set only if MaxAge > 0. - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age - MaxAge int `yaml:"max_age"` - } -) +// CORSConfig defines the config for CORS middleware. +type CORSConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper -var ( - // DefaultCORSConfig is the default CORS middleware config. - DefaultCORSConfig = CORSConfig{ - Skipper: DefaultSkipper, - AllowOrigins: []string{"*"}, - AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, - } -) + // AllowOrigins determines the value of the Access-Control-Allow-Origin + // response header. This header defines a list of origins that may access the + // resource. The wildcard characters '*' and '?' are supported and are + // converted to regex fragments '.*' and '.' accordingly. + // + // Security: use extreme caution when handling the origin, and carefully + // validate any logic. Remember that attackers may register hostile domain names. + // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // + // Optional. Default value []string{"*"}. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin + AllowOrigins []string `yaml:"allow_origins"` + + // AllowOriginFunc is a custom function to validate the origin. It takes the + // origin as an argument and returns true if allowed or false otherwise. If + // an error is returned, it is returned by the handler. If this option is + // set, AllowOrigins is ignored. + // + // Security: use extreme caution when handling the origin, and carefully + // validate any logic. Remember that attackers may register hostile domain names. + // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // + // Optional. + AllowOriginFunc func(origin string) (bool, error) `yaml:"-"` + + // AllowMethods determines the value of the Access-Control-Allow-Methods + // response header. This header specified the list of methods allowed when + // accessing the resource. This is used in response to a preflight request. + // + // Optional. Default value DefaultCORSConfig.AllowMethods. + // If `allowMethods` is left empty, this middleware will fill for preflight + // request `Access-Control-Allow-Methods` header value + // from `Allow` header that echo.Router set into context. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + AllowMethods []string `yaml:"allow_methods"` + + // AllowHeaders determines the value of the Access-Control-Allow-Headers + // response header. This header is used in response to a preflight request to + // indicate which HTTP headers can be used when making the actual request. + // + // Optional. Default value []string{}. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + AllowHeaders []string `yaml:"allow_headers"` + + // AllowCredentials determines the value of the + // Access-Control-Allow-Credentials response header. This header indicates + // whether or not the response to the request can be exposed when the + // credentials mode (Request.credentials) is true. When used as part of a + // response to a preflight request, this indicates whether or not the actual + // request can be made using credentials. See also + // [MDN: Access-Control-Allow-Credentials]. + // + // Optional. Default value false, in which case the header is not set. + // + // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. + // See "Exploiting CORS misconfigurations for Bitcoins and bounties", + // https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials + AllowCredentials bool `yaml:"allow_credentials"` + + // UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials + // flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header. + // + // This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties) + // attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject. + // + // Optional. Default value is false. + UnsafeWildcardOriginWithAllowCredentials bool `yaml:"unsafe_wildcard_origin_with_allow_credentials"` + + // ExposeHeaders determines the value of Access-Control-Expose-Headers, which + // defines a list of headers that clients are allowed to access. + // + // Optional. Default value []string{}, in which case the header is not set. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header + ExposeHeaders []string `yaml:"expose_headers"` + + // MaxAge determines the value of the Access-Control-Max-Age response header. + // This header indicates how long (in seconds) the results of a preflight + // request can be cached. + // The header is set only if MaxAge != 0, negative value sends "0" which instructs browsers not to cache that response. + // + // Optional. Default value 0 - meaning header is not sent. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age + MaxAge int `yaml:"max_age"` +} + +// DefaultCORSConfig is the default CORS middleware config. +var DefaultCORSConfig = CORSConfig{ + Skipper: DefaultSkipper, + AllowOrigins: []string{"*"}, + AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, +} // CORS returns a Cross-Origin Resource Sharing (CORS) middleware. // See also [MDN: Cross-Origin Resource Sharing (CORS)]. @@ -147,19 +147,35 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { config.AllowMethods = DefaultCORSConfig.AllowMethods } - allowOriginPatterns := []string{} + allowOriginPatterns := make([]*regexp.Regexp, 0, len(config.AllowOrigins)) for _, origin := range config.AllowOrigins { + if origin == "*" { + continue // "*" is handled differently and does not need regexp + } pattern := regexp.QuoteMeta(origin) pattern = strings.ReplaceAll(pattern, "\\*", ".*") pattern = strings.ReplaceAll(pattern, "\\?", ".") pattern = "^" + pattern + "$" - allowOriginPatterns = append(allowOriginPatterns, pattern) + + re, err := regexp.Compile(pattern) + if err != nil { + // this is to preserve previous behaviour - invalid patterns were just ignored. + // If we would turn this to panic, users with invalid patterns + // would have applications crashing in production due unrecovered panic. + // TODO: this should be turned to error/panic in `v5` + continue + } + allowOriginPatterns = append(allowOriginPatterns, re) } allowMethods := strings.Join(config.AllowMethods, ",") allowHeaders := strings.Join(config.AllowHeaders, ",") exposeHeaders := strings.Join(config.ExposeHeaders, ",") - maxAge := strconv.Itoa(config.MaxAge) + + maxAge := "0" + if config.MaxAge > 0 { + maxAge = strconv.Itoa(config.MaxAge) + } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -235,7 +251,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } if checkPatterns { for _, re := range allowOriginPatterns { - if match, _ := regexp.MatchString(re, origin); match { + if match := re.MatchString(origin); match { allowOrigin = origin break } @@ -282,7 +298,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { res.Header().Set(echo.HeaderAccessControlAllowHeaders, h) } } - if config.MaxAge > 0 { + if config.MaxAge != 0 { res.Header().Set(echo.HeaderAccessControlMaxAge, maxAge) } return c.NoContent(http.StatusNoContent) diff --git a/middleware/cors_test.go b/middleware/cors_test.go index c1bb91eb3..5461e9362 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -28,6 +31,18 @@ func TestCORS(t *testing.T) { name: "ok, wildcard AllowedOrigin with no Origin header in request", notExpectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: ""}, }, + { + name: "ok, invalid pattern is ignored", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{ + "\xff", // Invalid UTF-8 makes regexp.Compile to error + "*.example.com", + }, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{echo.HeaderOrigin: "http://aaa.example.com"}, + expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://aaa.example.com"}, + }, { name: "ok, specific AllowOrigins and AllowCredentials", givenMW: CORSWithConfig(CORSConfig{ @@ -60,6 +75,59 @@ func TestCORS(t *testing.T) { echo.HeaderAccessControlMaxAge: "3600", }, }, + { + name: "ok, preflight request when `Access-Control-Max-Age` is set", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, + AllowCredentials: true, + MaxAge: 1, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlMaxAge: "1", + }, + }, + { + name: "ok, preflight request when `Access-Control-Max-Age` is set to 0 - not to cache response", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, + AllowCredentials: true, + MaxAge: -1, // forces `Access-Control-Max-Age: 0` + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlMaxAge: "0", + }, + }, + { + name: "ok, CORS check are skipped", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, + AllowCredentials: true, + Skipper: func(c echo.Context) bool { + return true + }, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + notExpectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "localhost", + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + echo.HeaderAccessControlAllowCredentials: "true", + echo.HeaderAccessControlMaxAge: "3600", + }, + }, { name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true", givenMW: CORSWithConfig(CORSConfig{ diff --git a/middleware/csrf.go b/middleware/csrf.go index 6899700c7..92f4019dc 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -6,85 +9,80 @@ import ( "time" "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" ) -type ( - // CSRFConfig defines the config for CSRF middleware. - CSRFConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // TokenLength is the length of the generated token. - TokenLength uint8 `yaml:"token_length"` - // Optional. Default value 32. - - // TokenLookup is a string in the form of ":" or ":,:" that is used - // to extract token from the request. - // Optional. Default value "header:X-CSRF-Token". - // Possible values: - // - "header:" or "header::" - // - "query:" - // - "form:" - // Multiple sources example: - // - "header:X-CSRF-Token,query:csrf" - TokenLookup string `yaml:"token_lookup"` - - // Context key to store generated CSRF token into context. - // Optional. Default value "csrf". - ContextKey string `yaml:"context_key"` - - // Name of the CSRF cookie. This cookie will store CSRF token. - // Optional. Default value "csrf". - CookieName string `yaml:"cookie_name"` - - // Domain of the CSRF cookie. - // Optional. Default value none. - CookieDomain string `yaml:"cookie_domain"` - - // Path of the CSRF cookie. - // Optional. Default value none. - CookiePath string `yaml:"cookie_path"` - - // Max age (in seconds) of the CSRF cookie. - // Optional. Default value 86400 (24hr). - CookieMaxAge int `yaml:"cookie_max_age"` - - // Indicates if CSRF cookie is secure. - // Optional. Default value false. - CookieSecure bool `yaml:"cookie_secure"` - - // Indicates if CSRF cookie is HTTP only. - // Optional. Default value false. - CookieHTTPOnly bool `yaml:"cookie_http_only"` - - // Indicates SameSite mode of the CSRF cookie. - // Optional. Default value SameSiteDefaultMode. - CookieSameSite http.SameSite `yaml:"cookie_same_site"` - - // ErrorHandler defines a function which is executed for returning custom errors. - ErrorHandler CSRFErrorHandler - } +// CSRFConfig defines the config for CSRF middleware. +type CSRFConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // TokenLength is the length of the generated token. + TokenLength uint8 `yaml:"token_length"` + // Optional. Default value 32. + + // TokenLookup is a string in the form of ":" or ":,:" that is used + // to extract token from the request. + // Optional. Default value "header:X-CSRF-Token". + // Possible values: + // - "header:" or "header::" + // - "query:" + // - "form:" + // Multiple sources example: + // - "header:X-CSRF-Token,query:csrf" + TokenLookup string `yaml:"token_lookup"` + + // Context key to store generated CSRF token into context. + // Optional. Default value "csrf". + ContextKey string `yaml:"context_key"` + + // Name of the CSRF cookie. This cookie will store CSRF token. + // Optional. Default value "csrf". + CookieName string `yaml:"cookie_name"` + + // Domain of the CSRF cookie. + // Optional. Default value none. + CookieDomain string `yaml:"cookie_domain"` + + // Path of the CSRF cookie. + // Optional. Default value none. + CookiePath string `yaml:"cookie_path"` + + // Max age (in seconds) of the CSRF cookie. + // Optional. Default value 86400 (24hr). + CookieMaxAge int `yaml:"cookie_max_age"` + + // Indicates if CSRF cookie is secure. + // Optional. Default value false. + CookieSecure bool `yaml:"cookie_secure"` + + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value false. + CookieHTTPOnly bool `yaml:"cookie_http_only"` + + // Indicates SameSite mode of the CSRF cookie. + // Optional. Default value SameSiteDefaultMode. + CookieSameSite http.SameSite `yaml:"cookie_same_site"` + + // ErrorHandler defines a function which is executed for returning custom errors. + ErrorHandler CSRFErrorHandler +} - // CSRFErrorHandler is a function which is executed for creating custom errors. - CSRFErrorHandler func(err error, c echo.Context) error -) +// CSRFErrorHandler is a function which is executed for creating custom errors. +type CSRFErrorHandler func(err error, c echo.Context) error // ErrCSRFInvalid is returned when CSRF check fails var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") -var ( - // DefaultCSRFConfig is the default CSRF middleware config. - DefaultCSRFConfig = CSRFConfig{ - Skipper: DefaultSkipper, - TokenLength: 32, - TokenLookup: "header:" + echo.HeaderXCSRFToken, - ContextKey: "csrf", - CookieName: "_csrf", - CookieMaxAge: 86400, - CookieSameSite: http.SameSiteDefaultMode, - } -) +// DefaultCSRFConfig is the default CSRF middleware config. +var DefaultCSRFConfig = CSRFConfig{ + Skipper: DefaultSkipper, + TokenLength: 32, + TokenLookup: "header:" + echo.HeaderXCSRFToken, + ContextKey: "csrf", + CookieName: "_csrf", + CookieMaxAge: 86400, + CookieSameSite: http.SameSiteDefaultMode, +} // CSRF returns a Cross-Site Request Forgery (CSRF) middleware. // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery @@ -103,6 +101,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.TokenLength == 0 { config.TokenLength = DefaultCSRFConfig.TokenLength } + if config.TokenLookup == "" { config.TokenLookup = DefaultCSRFConfig.TokenLookup } @@ -132,7 +131,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { token := "" if k, err := c.Cookie(config.CookieName); err != nil { - token = random.String(config.TokenLength) // Generate token + token = randomString(config.TokenLength) } else { token = k.Value // Reuse token } diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 6bccdbe4d..98e5d04f6 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -8,7 +11,6 @@ import ( "testing" "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" "github.com/stretchr/testify/assert" ) @@ -233,7 +235,7 @@ func TestCSRF(t *testing.T) { assert.Error(t, h(c)) // Valid CSRF token - token := random.String(32) + token := randomString(32) req.Header.Set(echo.HeaderCookie, "_csrf="+token) req.Header.Set(echo.HeaderXCSRFToken, token) if assert.NoError(t, h(c)) { diff --git a/middleware/decompress.go b/middleware/decompress.go index a73c9738b..0c56176ee 100644 --- a/middleware/decompress.go +++ b/middleware/decompress.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -9,16 +12,14 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // DecompressConfig defines the config for Decompress middleware. - DecompressConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// DecompressConfig defines the config for Decompress middleware. +type DecompressConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers - GzipDecompressPool Decompressor - } -) + // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers + GzipDecompressPool Decompressor +} // GZIPEncoding content-encoding header if set to "gzip", decompress body contents. const GZIPEncoding string = "gzip" @@ -28,13 +29,11 @@ type Decompressor interface { gzipDecompressPool() sync.Pool } -var ( - //DefaultDecompressConfig defines the config for decompress middleware - DefaultDecompressConfig = DecompressConfig{ - Skipper: DefaultSkipper, - GzipDecompressPool: &DefaultGzipDecompressPool{}, - } -) +// DefaultDecompressConfig defines the config for decompress middleware +var DefaultDecompressConfig = DecompressConfig{ + Skipper: DefaultSkipper, + GzipDecompressPool: &DefaultGzipDecompressPool{}, +} // DefaultGzipDecompressPool is the default implementation of Decompressor interface type DefaultGzipDecompressPool struct { diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 2e73ba80e..63b1a68f5 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -131,7 +134,7 @@ func TestDecompressSkipper(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) e.ServeHTTP(rec, req) - assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8) + assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSON) reqBody, err := io.ReadAll(c.Request().Body) assert.NoError(t, err) assert.Equal(t, body, string(reqBody)) diff --git a/middleware/extractor.go b/middleware/extractor.go index 5d9cee6d0..3f2741407 100644 --- a/middleware/extractor.go +++ b/middleware/extractor.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go index 428c5563e..42cbcfeab 100644 --- a/middleware/extractor_test.go +++ b/middleware/extractor_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/jwt.go b/middleware/jwt.go deleted file mode 100644 index bc318c976..000000000 --- a/middleware/jwt.go +++ /dev/null @@ -1,304 +0,0 @@ -//go:build go1.15 -// +build go1.15 - -package middleware - -import ( - "errors" - "fmt" - "github.com/golang-jwt/jwt" - "github.com/labstack/echo/v4" - "net/http" - "reflect" -) - -type ( - // JWTConfig defines the config for JWT middleware. - JWTConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // BeforeFunc defines a function which is executed just before the middleware. - BeforeFunc BeforeFunc - - // SuccessHandler defines a function which is executed for a valid token before middleware chain continues with next - // middleware or handler. - SuccessHandler JWTSuccessHandler - - // ErrorHandler defines a function which is executed for an invalid token. - // It may be used to define a custom JWT error. - ErrorHandler JWTErrorHandler - - // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. - ErrorHandlerWithContext JWTErrorHandlerWithContext - - // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to - // ignore the error (by returning `nil`). - // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. - // In that case you can use ErrorHandlerWithContext to set a default public JWT token value in the request context - // and continue. Some logic down the remaining execution chain needs to check that (public) token value then. - ContinueOnIgnoredError bool - - // Signing key to validate token. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither user-defined KeyFunc nor SigningKeys is provided. - SigningKey interface{} - - // Map of signing keys to validate token with kid field usage. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither user-defined KeyFunc nor SigningKey is provided. - SigningKeys map[string]interface{} - - // Signing method used to check the token's signing algorithm. - // Optional. Default value HS256. - SigningMethod string - - // Context key to store user information from the token into context. - // Optional. Default value "user". - ContextKey string - - // Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation. - // Not used if custom ParseTokenFunc is set. - // Optional. Default value jwt.MapClaims - Claims jwt.Claims - - // TokenLookup is a string in the form of ":" or ":,:" that is used - // to extract token from the request. - // Optional. Default value "header:Authorization". - // Possible values: - // - "header:" or "header::" - // `` is argument value to cut/trim prefix of the extracted value. This is useful if header - // value has static prefix like `Authorization: ` where part that we - // want to cut is ` ` note the space at the end. - // In case of JWT tokens `Authorization: Bearer ` prefix we cut is `Bearer `. - // If prefix is left empty the whole value is returned. - // - "query:" - // - "param:" - // - "cookie:" - // - "form:" - // Multiple sources example: - // - "header:Authorization,cookie:myowncookie" - TokenLookup string - - // TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context. - // This is one of the two options to provide a token extractor. - // The order of precedence is user-defined TokenLookupFuncs, and TokenLookup. - // You can also provide both if you want. - TokenLookupFuncs []ValuesExtractor - - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string - - // KeyFunc defines a user-defined function that supplies the public key for a token validation. - // The function shall take care of verifying the signing algorithm and selecting the proper key. - // A user-defined KeyFunc can be useful if tokens are issued by an external party. - // Used by default ParseTokenFunc implementation. - // - // When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither SigningKeys nor SigningKey is provided. - // Not used if custom ParseTokenFunc is set. - // Default to an internal implementation verifying the signing algorithm and selecting the proper key. - KeyFunc jwt.Keyfunc - - // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token - // parsing fails or parsed token is invalid. - // Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library - ParseTokenFunc func(auth string, c echo.Context) (interface{}, error) - } - - // JWTSuccessHandler defines a function which is executed for a valid token. - JWTSuccessHandler func(c echo.Context) - - // JWTErrorHandler defines a function which is executed for an invalid token. - JWTErrorHandler func(err error) error - - // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. - JWTErrorHandlerWithContext func(err error, c echo.Context) error -) - -// Algorithms -const ( - AlgorithmHS256 = "HS256" -) - -// Errors -var ( - ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt") - ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") -) - -var ( - // DefaultJWTConfig is the default JWT auth middleware config. - DefaultJWTConfig = JWTConfig{ - Skipper: DefaultSkipper, - SigningMethod: AlgorithmHS256, - ContextKey: "user", - TokenLookup: "header:" + echo.HeaderAuthorization, - TokenLookupFuncs: nil, - AuthScheme: "Bearer", - Claims: jwt.MapClaims{}, - KeyFunc: nil, - } -) - -// JWT returns a JSON Web Token (JWT) auth middleware. -// -// For valid token, it sets the user in context and calls next handler. -// For invalid token, it returns "401 - Unauthorized" error. -// For missing token, it returns "400 - Bad Request" error. -// -// See: https://jwt.io/introduction -// See `JWTConfig.TokenLookup` -// -// Deprecated: Please use https://github.com/labstack/echo-jwt instead -func JWT(key interface{}) echo.MiddlewareFunc { - c := DefaultJWTConfig - c.SigningKey = key - return JWTWithConfig(c) -} - -// JWTWithConfig returns a JWT auth middleware with config. -// See: `JWT()`. -// -// Deprecated: Please use https://github.com/labstack/echo-jwt instead -func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { - // Defaults - if config.Skipper == nil { - config.Skipper = DefaultJWTConfig.Skipper - } - if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil { - panic("echo: jwt middleware requires signing key") - } - if config.SigningMethod == "" { - config.SigningMethod = DefaultJWTConfig.SigningMethod - } - if config.ContextKey == "" { - config.ContextKey = DefaultJWTConfig.ContextKey - } - if config.Claims == nil { - config.Claims = DefaultJWTConfig.Claims - } - if config.TokenLookup == "" && len(config.TokenLookupFuncs) == 0 { - config.TokenLookup = DefaultJWTConfig.TokenLookup - } - if config.AuthScheme == "" { - config.AuthScheme = DefaultJWTConfig.AuthScheme - } - if config.KeyFunc == nil { - config.KeyFunc = config.defaultKeyFunc - } - if config.ParseTokenFunc == nil { - config.ParseTokenFunc = config.defaultParseToken - } - - extractors, cErr := createExtractors(config.TokenLookup, config.AuthScheme) - if cErr != nil { - panic(cErr) - } - if len(config.TokenLookupFuncs) > 0 { - extractors = append(config.TokenLookupFuncs, extractors...) - } - - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - if config.Skipper(c) { - return next(c) - } - - if config.BeforeFunc != nil { - config.BeforeFunc(c) - } - - var lastExtractorErr error - var lastTokenErr error - for _, extractor := range extractors { - auths, err := extractor(c) - if err != nil { - lastExtractorErr = ErrJWTMissing // backwards compatibility: all extraction errors are same (unlike KeyAuth) - continue - } - for _, auth := range auths { - token, err := config.ParseTokenFunc(auth, c) - if err != nil { - lastTokenErr = err - continue - } - // Store user information from token into context. - c.Set(config.ContextKey, token) - if config.SuccessHandler != nil { - config.SuccessHandler(c) - } - return next(c) - } - } - // we are here only when we did not successfully extract or parse any of the tokens - err := lastTokenErr - if err == nil { // prioritize token errors over extracting errors - err = lastExtractorErr - } - if config.ErrorHandler != nil { - return config.ErrorHandler(err) - } - if config.ErrorHandlerWithContext != nil { - tmpErr := config.ErrorHandlerWithContext(err, c) - if config.ContinueOnIgnoredError && tmpErr == nil { - return next(c) - } - return tmpErr - } - - // backwards compatible errors codes - if lastTokenErr != nil { - return &echo.HTTPError{ - Code: ErrJWTInvalid.Code, - Message: ErrJWTInvalid.Message, - Internal: err, - } - } - return err // this is lastExtractorErr value - } - } -} - -func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) { - var token *jwt.Token - var err error - // Issue #647, #656 - if _, ok := config.Claims.(jwt.MapClaims); ok { - token, err = jwt.Parse(auth, config.KeyFunc) - } else { - t := reflect.ValueOf(config.Claims).Type().Elem() - claims := reflect.New(t).Interface().(jwt.Claims) - token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc) - } - if err != nil { - return nil, err - } - if !token.Valid { - return nil, errors.New("invalid token") - } - return token, nil -} - -// defaultKeyFunc returns a signing key of the given token. -func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) { - // Check the signing method - if t.Method.Alg() != config.SigningMethod { - return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) - } - if len(config.SigningKeys) > 0 { - if kid, ok := t.Header["kid"].(string); ok { - if key, ok := config.SigningKeys[kid]; ok { - return key, nil - } - } - return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"]) - } - - return config.SigningKey, nil -} diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go deleted file mode 100644 index 90e8cad81..000000000 --- a/middleware/jwt_test.go +++ /dev/null @@ -1,777 +0,0 @@ -//go:build go1.15 -// +build go1.15 - -package middleware - -import ( - "errors" - "fmt" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - - "github.com/golang-jwt/jwt" - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" -) - -// jwtCustomInfo defines some custom types we're going to use within our tokens. -type jwtCustomInfo struct { - Name string `json:"name"` - Admin bool `json:"admin"` -} - -// jwtCustomClaims are custom claims expanding default ones. -type jwtCustomClaims struct { - *jwt.StandardClaims - jwtCustomInfo -} - -func TestJWT(t *testing.T) { - e := echo.New() - - e.GET("/", func(c echo.Context) error { - token := c.Get("user").(*jwt.Token) - return c.JSON(http.StatusOK, token.Claims) - }) - - e.Use(JWT([]byte("secret"))) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, "bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") - res := httptest.NewRecorder() - - e.ServeHTTP(res, req) - - assert.Equal(t, http.StatusOK, res.Code) - assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String()) -} - -func TestJWTRace(t *testing.T) { - e := echo.New() - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - initialToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" - raceToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IlJhY2UgQ29uZGl0aW9uIiwiYWRtaW4iOmZhbHNlfQ.Xzkx9mcgGqYMTkuxSCbJ67lsDyk5J2aB7hu65cEE-Ss" - validKey := []byte("secret") - - h := JWTWithConfig(JWTConfig{ - Claims: &jwtCustomClaims{}, - SigningKey: validKey, - })(handler) - - makeReq := func(token string) echo.Context { - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" "+token) - c := e.NewContext(req, res) - assert.NoError(t, h(c)) - return c - } - - c := makeReq(initialToken) - user := c.Get("user").(*jwt.Token) - claims := user.Claims.(*jwtCustomClaims) - assert.Equal(t, claims.Name, "John Doe") - - makeReq(raceToken) - user = c.Get("user").(*jwt.Token) - claims = user.Claims.(*jwtCustomClaims) - // Initial context should still be "John Doe", not "Race Condition" - assert.Equal(t, claims.Name, "John Doe") - assert.Equal(t, claims.Admin, true) -} - -func TestJWTConfig(t *testing.T) { - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" - validKey := []byte("secret") - invalidKey := []byte("invalid-key") - validAuth := DefaultJWTConfig.AuthScheme + " " + token - - testCases := []struct { - name string - expPanic bool - expErrCode int // 0 for Success - config JWTConfig - reqURL string // "/" if empty - hdrAuth string - hdrCookie string // test.Request doesn't provide SetCookie(); use name=val - formValues map[string]string - }{ - { - name: "No signing key provided", - expPanic: true, - }, - { - name: "Unexpected signing method", - expErrCode: http.StatusBadRequest, - config: JWTConfig{ - SigningKey: validKey, - SigningMethod: "RS256", - }, - }, - { - name: "Invalid key", - expErrCode: http.StatusUnauthorized, - hdrAuth: validAuth, - config: JWTConfig{SigningKey: invalidKey}, - }, - { - name: "Valid JWT", - hdrAuth: validAuth, - config: JWTConfig{SigningKey: validKey}, - }, - { - name: "Valid JWT with custom AuthScheme", - hdrAuth: "Token" + " " + token, - config: JWTConfig{AuthScheme: "Token", SigningKey: validKey}, - }, - { - name: "Valid JWT with custom claims", - hdrAuth: validAuth, - config: JWTConfig{ - Claims: &jwtCustomClaims{}, - SigningKey: []byte("secret"), - }, - }, - { - name: "Invalid Authorization header", - hdrAuth: "invalid-auth", - expErrCode: http.StatusBadRequest, - config: JWTConfig{SigningKey: validKey}, - }, - { - name: "Empty header auth field", - config: JWTConfig{SigningKey: validKey}, - expErrCode: http.StatusBadRequest, - }, - { - name: "Valid query method", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", - }, - reqURL: "/?a=b&jwt=" + token, - }, - { - name: "Invalid query param name", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", - }, - reqURL: "/?a=b&jwtxyz=" + token, - expErrCode: http.StatusBadRequest, - }, - { - name: "Invalid query param value", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", - }, - reqURL: "/?a=b&jwt=invalid-token", - expErrCode: http.StatusUnauthorized, - }, - { - name: "Empty query", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", - }, - reqURL: "/?a=b", - expErrCode: http.StatusBadRequest, - }, - { - name: "Valid param method", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "param:jwt", - }, - reqURL: "/" + token, - }, - { - name: "Valid cookie method", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", - }, - hdrCookie: "jwt=" + token, - }, - { - name: "Multiple jwt lookuop", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt,cookie:jwt", - }, - hdrCookie: "jwt=" + token, - }, - { - name: "Invalid token with cookie method", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", - }, - expErrCode: http.StatusUnauthorized, - hdrCookie: "jwt=invalid", - }, - { - name: "Empty cookie", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", - }, - expErrCode: http.StatusBadRequest, - }, - { - name: "Valid form method", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "form:jwt", - }, - formValues: map[string]string{"jwt": token}, - }, - { - name: "Invalid token with form method", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "form:jwt", - }, - expErrCode: http.StatusUnauthorized, - formValues: map[string]string{"jwt": "invalid"}, - }, - { - name: "Empty form field", - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "form:jwt", - }, - expErrCode: http.StatusBadRequest, - }, - { - name: "Valid JWT with a valid key using a user-defined KeyFunc", - hdrAuth: validAuth, - config: JWTConfig{ - KeyFunc: func(*jwt.Token) (interface{}, error) { - return validKey, nil - }, - }, - }, - { - name: "Valid JWT with an invalid key using a user-defined KeyFunc", - hdrAuth: validAuth, - config: JWTConfig{ - KeyFunc: func(*jwt.Token) (interface{}, error) { - return invalidKey, nil - }, - }, - expErrCode: http.StatusUnauthorized, - }, - { - name: "Token verification does not pass using a user-defined KeyFunc", - hdrAuth: validAuth, - config: JWTConfig{ - KeyFunc: func(*jwt.Token) (interface{}, error) { - return nil, errors.New("faulty KeyFunc") - }, - }, - expErrCode: http.StatusUnauthorized, - }, - { - name: "Valid JWT with lower case AuthScheme", - hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token, - config: JWTConfig{SigningKey: validKey}, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - if tc.reqURL == "" { - tc.reqURL = "/" - } - - var req *http.Request - if len(tc.formValues) > 0 { - form := url.Values{} - for k, v := range tc.formValues { - form.Set(k, v) - } - req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode())) - req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded") - req.ParseForm() - } else { - req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil) - } - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) - req.Header.Set(echo.HeaderCookie, tc.hdrCookie) - c := e.NewContext(req, res) - - if tc.reqURL == "/"+token { - c.SetParamNames("jwt") - c.SetParamValues(token) - } - - if tc.expPanic { - assert.Panics(t, func() { - JWTWithConfig(tc.config) - }, tc.name) - return - } - - if tc.expErrCode != 0 { - h := JWTWithConfig(tc.config)(handler) - he := h(c).(*echo.HTTPError) - assert.Equal(t, tc.expErrCode, he.Code, tc.name) - return - } - - h := JWTWithConfig(tc.config)(handler) - if assert.NoError(t, h(c), tc.name) { - user := c.Get("user").(*jwt.Token) - switch claims := user.Claims.(type) { - case jwt.MapClaims: - assert.Equal(t, claims["name"], "John Doe", tc.name) - case *jwtCustomClaims: - assert.Equal(t, claims.Name, "John Doe", tc.name) - assert.Equal(t, claims.Admin, true, tc.name) - default: - panic("unexpected type of claims") - } - } - }) - } -} - -func TestJWTwithKID(t *testing.T) { - e := echo.New() - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - firstToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImZpcnN0T25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.w5VGpHOe0jlNgf7jMVLHzIYH_XULmpUlreJnilwSkWk" - secondToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.sdghDYQ85jdh0hgQ6bKbMguLI_NSPYWjkhVJkee-yZM" - wrongToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.RyhLybtVLpoewF6nz9YN79oXo32kAtgUxp8FNwTkb90" - staticToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.1_-XFYUPpJfgsaGwYhgZEt7hfySMg-a3GN-nfZmbW7o" - validKeys := map[string]interface{}{"firstOne": []byte("first_secret"), "secondOne": []byte("second_secret")} - invalidKeys := map[string]interface{}{"thirdOne": []byte("third_secret")} - staticSecret := []byte("static_secret") - invalidStaticSecret := []byte("invalid_secret") - - for _, tc := range []struct { - expErrCode int // 0 for Success - config JWTConfig - hdrAuth string - info string - }{ - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "First token valid", - }, - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "Second token valid", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + wrongToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "Wrong key id token", - }, - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken, - config: JWTConfig{SigningKey: staticSecret}, - info: "Valid static secret token", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken, - config: JWTConfig{SigningKey: invalidStaticSecret}, - info: "Invalid static secret", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken, - config: JWTConfig{SigningKeys: invalidKeys}, - info: "Invalid keys first token", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken, - config: JWTConfig{SigningKeys: invalidKeys}, - info: "Invalid keys second token", - }, - } { - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) - c := e.NewContext(req, res) - - if tc.expErrCode != 0 { - h := JWTWithConfig(tc.config)(handler) - he := h(c).(*echo.HTTPError) - assert.Equal(t, tc.expErrCode, he.Code, tc.info) - continue - } - - h := JWTWithConfig(tc.config)(handler) - if assert.NoError(t, h(c), tc.info) { - user := c.Get("user").(*jwt.Token) - switch claims := user.Claims.(type) { - case jwt.MapClaims: - assert.Equal(t, claims["name"], "John Doe", tc.info) - case *jwtCustomClaims: - assert.Equal(t, claims.Name, "John Doe", tc.info) - assert.Equal(t, claims.Admin, true, tc.info) - default: - panic("unexpected type of claims") - } - } - } -} - -func TestJWTConfig_skipper(t *testing.T) { - e := echo.New() - - e.Use(JWTWithConfig(JWTConfig{ - Skipper: func(context echo.Context) bool { - return true // skip everything - }, - SigningKey: []byte("secret"), - })) - - isCalled := false - e.GET("/", func(c echo.Context) error { - isCalled = true - return c.String(http.StatusTeapot, "test") - }) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - e.ServeHTTP(res, req) - - assert.Equal(t, http.StatusTeapot, res.Code) - assert.True(t, isCalled) -} - -func TestJWTConfig_BeforeFunc(t *testing.T) { - e := echo.New() - e.GET("/", func(c echo.Context) error { - return c.String(http.StatusTeapot, "test") - }) - - isCalled := false - e.Use(JWTWithConfig(JWTConfig{ - BeforeFunc: func(context echo.Context) { - isCalled = true - }, - SigningKey: []byte("secret"), - })) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") - res := httptest.NewRecorder() - e.ServeHTTP(res, req) - - assert.Equal(t, http.StatusTeapot, res.Code) - assert.True(t, isCalled) -} - -func TestJWTConfig_extractorErrorHandling(t *testing.T) { - var testCases = []struct { - name string - given JWTConfig - expectStatusCode int - }{ - { - name: "ok, ErrorHandler is executed", - given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandler: func(err error) error { - return echo.NewHTTPError(http.StatusTeapot, "custom_error") - }, - }, - expectStatusCode: http.StatusTeapot, - }, - { - name: "ok, ErrorHandlerWithContext is executed", - given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandlerWithContext: func(err error, context echo.Context) error { - return echo.NewHTTPError(http.StatusTeapot, "custom_error") - }, - }, - expectStatusCode: http.StatusTeapot, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - e.GET("/", func(c echo.Context) error { - return c.String(http.StatusNotImplemented, "should not end up here") - }) - - e.Use(JWTWithConfig(tc.given)) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - e.ServeHTTP(res, req) - - assert.Equal(t, tc.expectStatusCode, res.Code) - }) - } -} - -func TestJWTConfig_parseTokenErrorHandling(t *testing.T) { - var testCases = []struct { - name string - given JWTConfig - expectErr string - }{ - { - name: "ok, ErrorHandler is executed", - given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandler: func(err error) error { - return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error()) - }, - }, - expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n", - }, - { - name: "ok, ErrorHandlerWithContext is executed", - given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandlerWithContext: func(err error, context echo.Context) error { - return echo.NewHTTPError(http.StatusTeapot, "ErrorHandlerWithContext: "+err.Error()) - }, - }, - expectErr: "{\"message\":\"ErrorHandlerWithContext: parsing failed\"}\n", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - //e.Debug = true - e.GET("/", func(c echo.Context) error { - return c.String(http.StatusNotImplemented, "should not end up here") - }) - - config := tc.given - parseTokenCalled := false - config.ParseTokenFunc = func(auth string, c echo.Context) (interface{}, error) { - parseTokenCalled = true - return nil, errors.New("parsing failed") - } - e.Use(JWTWithConfig(config)) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") - res := httptest.NewRecorder() - - e.ServeHTTP(res, req) - - assert.Equal(t, http.StatusTeapot, res.Code) - assert.Equal(t, tc.expectErr, res.Body.String()) - assert.True(t, parseTokenCalled) - }) - } -} - -func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { - e := echo.New() - e.GET("/", func(c echo.Context) error { - return c.String(http.StatusTeapot, "test") - }) - - // example of minimal custom ParseTokenFunc implementation. Allows you to use different versions of `github.com/golang-jwt/jwt` - // with current JWT middleware - signingKey := []byte("secret") - - config := JWTConfig{ - ParseTokenFunc: func(auth string, c echo.Context) (interface{}, error) { - keyFunc := func(t *jwt.Token) (interface{}, error) { - if t.Method.Alg() != "HS256" { - return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) - } - return signingKey, nil - } - - // claims are of type `jwt.MapClaims` when token is created with `jwt.Parse` - token, err := jwt.Parse(auth, keyFunc) - if err != nil { - return nil, err - } - if !token.Valid { - return nil, errors.New("invalid token") - } - return token, nil - }, - } - - e.Use(JWTWithConfig(config)) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") - res := httptest.NewRecorder() - e.ServeHTTP(res, req) - - assert.Equal(t, http.StatusTeapot, res.Code) -} - -func TestJWTConfig_TokenLookupFuncs(t *testing.T) { - e := echo.New() - - e.GET("/", func(c echo.Context) error { - token := c.Get("user").(*jwt.Token) - return c.JSON(http.StatusOK, token.Claims) - }) - - e.Use(JWTWithConfig(JWTConfig{ - TokenLookupFuncs: []ValuesExtractor{ - func(c echo.Context) ([]string, error) { - return []string{c.Request().Header.Get("X-API-Key")}, nil - }, - }, - SigningKey: []byte("secret"), - })) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set("X-API-Key", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") - res := httptest.NewRecorder() - e.ServeHTTP(res, req) - - assert.Equal(t, http.StatusOK, res.Code) - assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String()) -} - -func TestJWTConfig_SuccessHandler(t *testing.T) { - var testCases = []struct { - name string - givenToken string - expectCalled bool - expectStatus int - }{ - { - name: "ok, success handler is called", - givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ", - expectCalled: true, - expectStatus: http.StatusOK, - }, - { - name: "nok, success handler is not called", - givenToken: "x.x.x", - expectCalled: false, - expectStatus: http.StatusUnauthorized, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - - e.GET("/", func(c echo.Context) error { - token := c.Get("user").(*jwt.Token) - return c.JSON(http.StatusOK, token.Claims) - }) - - wasCalled := false - e.Use(JWTWithConfig(JWTConfig{ - SuccessHandler: func(c echo.Context) { - wasCalled = true - }, - SigningKey: []byte("secret"), - })) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken) - res := httptest.NewRecorder() - - e.ServeHTTP(res, req) - - assert.Equal(t, tc.expectCalled, wasCalled) - assert.Equal(t, tc.expectStatus, res.Code) - }) - } -} - -func TestJWTConfig_ContinueOnIgnoredError(t *testing.T) { - var testCases = []struct { - name string - whenContinueOnIgnoredError bool - givenToken string - expectStatus int - expectBody string - }{ - { - name: "no error handler is called", - whenContinueOnIgnoredError: true, - givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ", - expectStatus: http.StatusTeapot, - expectBody: "", - }, - { - name: "ContinueOnIgnoredError is false and error handler is called for missing token", - whenContinueOnIgnoredError: false, - givenToken: "", - // empty response with 200. This emulates previous behaviour when error handler swallowed the error - expectStatus: http.StatusOK, - expectBody: "", - }, - { - name: "error handler is called for missing token", - whenContinueOnIgnoredError: true, - givenToken: "", - expectStatus: http.StatusTeapot, - expectBody: "public-token", - }, - { - name: "error handler is called for invalid token", - whenContinueOnIgnoredError: true, - givenToken: "x.x.x", - expectStatus: http.StatusUnauthorized, - expectBody: "{\"message\":\"Unauthorized\"}\n", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - - e.GET("/", func(c echo.Context) error { - testValue, _ := c.Get("test").(string) - return c.String(http.StatusTeapot, testValue) - }) - - e.Use(JWTWithConfig(JWTConfig{ - ContinueOnIgnoredError: tc.whenContinueOnIgnoredError, - SigningKey: []byte("secret"), - ErrorHandlerWithContext: func(err error, c echo.Context) error { - if err == ErrJWTMissing { - c.Set("test", "public-token") - return nil - } - return echo.ErrUnauthorized - }, - })) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - if tc.givenToken != "" { - req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken) - } - res := httptest.NewRecorder() - - e.ServeHTTP(res, req) - - assert.Equal(t, tc.expectStatus, res.Code) - assert.Equal(t, tc.expectBody, res.Body.String()) - }) - } -} diff --git a/middleware/key_auth.go b/middleware/key_auth.go index f6fcc5d69..79bee207c 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -6,69 +9,65 @@ import ( "net/http" ) -type ( - // KeyAuthConfig defines the config for KeyAuth middleware. - KeyAuthConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // KeyLookup is a string in the form of ":" or ":,:" that is used - // to extract key from the request. - // Optional. Default value "header:Authorization". - // Possible values: - // - "header:" or "header::" - // `` is argument value to cut/trim prefix of the extracted value. This is useful if header - // value has static prefix like `Authorization: ` where part that we - // want to cut is ` ` note the space at the end. - // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. - // - "query:" - // - "form:" - // - "cookie:" - // Multiple sources example: - // - "header:Authorization,header:X-Api-Key" - KeyLookup string - - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string - - // Validator is a function to validate key. - // Required. - Validator KeyAuthValidator - - // ErrorHandler defines a function which is executed for an invalid key. - // It may be used to define a custom error. - ErrorHandler KeyAuthErrorHandler - - // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to - // ignore the error (by returning `nil`). - // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. - // In that case you can use ErrorHandler to set a default public key auth value in the request context - // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. - ContinueOnIgnoredError bool - } +// KeyAuthConfig defines the config for KeyAuth middleware. +type KeyAuthConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // KeyLookup is a string in the form of ":" or ":,:" that is used + // to extract key from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" or "header::" + // `` is argument value to cut/trim prefix of the extracted value. This is useful if header + // value has static prefix like `Authorization: ` where part that we + // want to cut is ` ` note the space at the end. + // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. + // - "query:" + // - "form:" + // - "cookie:" + // Multiple sources example: + // - "header:Authorization,header:X-Api-Key" + KeyLookup string + + // AuthScheme to be used in the Authorization header. + // Optional. Default value "Bearer". + AuthScheme string + + // Validator is a function to validate key. + // Required. + Validator KeyAuthValidator + + // ErrorHandler defines a function which is executed for an invalid key. + // It may be used to define a custom error. + ErrorHandler KeyAuthErrorHandler + + // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to + // ignore the error (by returning `nil`). + // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. + // In that case you can use ErrorHandler to set a default public key auth value in the request context + // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. + ContinueOnIgnoredError bool +} - // KeyAuthValidator defines a function to validate KeyAuth credentials. - KeyAuthValidator func(auth string, c echo.Context) (bool, error) +// KeyAuthValidator defines a function to validate KeyAuth credentials. +type KeyAuthValidator func(auth string, c echo.Context) (bool, error) - // KeyAuthErrorHandler defines a function which is executed for an invalid key. - KeyAuthErrorHandler func(err error, c echo.Context) error -) - -var ( - // DefaultKeyAuthConfig is the default KeyAuth middleware config. - DefaultKeyAuthConfig = KeyAuthConfig{ - Skipper: DefaultSkipper, - KeyLookup: "header:" + echo.HeaderAuthorization, - AuthScheme: "Bearer", - } -) +// KeyAuthErrorHandler defines a function which is executed for an invalid key. +type KeyAuthErrorHandler func(err error, c echo.Context) error // ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups type ErrKeyAuthMissing struct { Err error } +// DefaultKeyAuthConfig is the default KeyAuth middleware config. +var DefaultKeyAuthConfig = KeyAuthConfig{ + Skipper: DefaultSkipper, + KeyLookup: "header:" + echo.HeaderAuthorization, + AuthScheme: "Bearer", +} + // Error returns errors text func (e *ErrKeyAuthMissing) Error() string { return e.Err.Error() diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go index ff8968c38..447f0bee8 100644 --- a/middleware/key_auth_test.go +++ b/middleware/key_auth_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/logger.go b/middleware/logger.go index 7958d873b..5d9d29e1b 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -14,85 +17,248 @@ import ( "github.com/valyala/fasttemplate" ) -type ( - // LoggerConfig defines the config for Logger middleware. - LoggerConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Tags to construct the logger format. - // - // - time_unix - // - time_unix_milli - // - time_unix_micro - // - time_unix_nano - // - time_rfc3339 - // - time_rfc3339_nano - // - time_custom - // - id (Request ID) - // - remote_ip - // - uri - // - host - // - method - // - path - // - route - // - protocol - // - referer - // - user_agent - // - status - // - error - // - latency (In nanoseconds) - // - latency_human (Human readable) - // - bytes_in (Bytes received) - // - bytes_out (Bytes sent) - // - header: - // - query: - // - form: - // - custom (see CustomTagFunc field) - // - // Example "${remote_ip} ${status}" - // - // Optional. Default value DefaultLoggerConfig.Format. - Format string `yaml:"format"` - - // Optional. Default value DefaultLoggerConfig.CustomTimeFormat. - CustomTimeFormat string `yaml:"custom_time_format"` - - // CustomTagFunc is function called for `${custom}` tag to output user implemented text by writing it to buf. - // Make sure that outputted text creates valid JSON string with other logged tags. - // Optional. - CustomTagFunc func(c echo.Context, buf *bytes.Buffer) (int, error) - - // Output is a writer where logs in JSON format are written. - // Optional. Default value os.Stdout. - Output io.Writer - - template *fasttemplate.Template - colorer *color.Color - pool *sync.Pool - } -) +// LoggerConfig defines the config for Logger middleware. +// +// # Configuration Examples +// +// ## Basic Usage with Default Settings +// +// e.Use(middleware.Logger()) +// +// This uses the default JSON format that logs all common request/response details. +// +// ## Custom Simple Format +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Format: "${time_rfc3339_nano} ${status} ${method} ${uri} ${latency_human}\n", +// })) +// +// ## JSON Format with Custom Fields +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Format: `{"timestamp":"${time_rfc3339_nano}","level":"info","remote_ip":"${remote_ip}",` + +// `"method":"${method}","uri":"${uri}","status":${status},"latency":"${latency_human}",` + +// `"user_agent":"${user_agent}","error":"${error}"}` + "\n", +// })) +// +// ## Custom Time Format +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Format: "${time_custom} ${method} ${uri} ${status}\n", +// CustomTimeFormat: "2006-01-02 15:04:05", +// })) +// +// ## Logging Headers and Parameters +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Format: `{"time":"${time_rfc3339_nano}","method":"${method}","uri":"${uri}",` + +// `"status":${status},"auth":"${header:Authorization}","user":"${query:user}",` + +// `"form_data":"${form:action}","session":"${cookie:session_id}"}` + "\n", +// })) +// +// ## Custom Output (File Logging) +// +// file, err := os.OpenFile("app.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) +// if err != nil { +// log.Fatal(err) +// } +// defer file.Close() +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Output: file, +// })) +// +// ## Custom Tag Function +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Format: `{"time":"${time_rfc3339_nano}","user_id":"${custom}","method":"${method}"}` + "\n", +// CustomTagFunc: func(c echo.Context, buf *bytes.Buffer) (int, error) { +// userID := getUserIDFromContext(c) // Your custom logic +// return buf.WriteString(strconv.Itoa(userID)) +// }, +// })) +// +// ## Conditional Logging (Skip Certain Requests) +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Skipper: func(c echo.Context) bool { +// // Skip logging for health check endpoints +// return c.Request().URL.Path == "/health" || c.Request().URL.Path == "/metrics" +// }, +// })) +// +// ## Integration with External Logging Service +// +// logBuffer := &SyncBuffer{} // Thread-safe buffer for external service +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Format: `{"timestamp":"${time_rfc3339_nano}","service":"my-api","level":"info",` + +// `"method":"${method}","uri":"${uri}","status":${status},"latency_ms":${latency},` + +// `"remote_ip":"${remote_ip}","user_agent":"${user_agent}","error":"${error}"}` + "\n", +// Output: logBuffer, +// })) +// +// # Available Tags +// +// ## Time Tags +// - time_unix: Unix timestamp (seconds) +// - time_unix_milli: Unix timestamp (milliseconds) +// - time_unix_micro: Unix timestamp (microseconds) +// - time_unix_nano: Unix timestamp (nanoseconds) +// - time_rfc3339: RFC3339 format (2006-01-02T15:04:05Z07:00) +// - time_rfc3339_nano: RFC3339 with nanoseconds +// - time_custom: Uses CustomTimeFormat field +// +// ## Request Information +// - id: Request ID from X-Request-ID header +// - remote_ip: Client IP address (respects proxy headers) +// - uri: Full request URI with query parameters +// - host: Host header value +// - method: HTTP method (GET, POST, etc.) +// - path: URL path without query parameters +// - route: Echo route pattern (e.g., /users/:id) +// - protocol: HTTP protocol version +// - referer: Referer header value +// - user_agent: User-Agent header value +// +// ## Response Information +// - status: HTTP status code +// - error: Error message if request failed +// - latency: Request processing time in nanoseconds +// - latency_human: Human-readable processing time +// - bytes_in: Request body size in bytes +// - bytes_out: Response body size in bytes +// +// ## Dynamic Tags +// - header:: Value of specific header (e.g., header:Authorization) +// - query:: Value of specific query parameter (e.g., query:user_id) +// - form:: Value of specific form field (e.g., form:username) +// - cookie:: Value of specific cookie (e.g., cookie:session_id) +// - custom: Output from CustomTagFunc +// +// # Troubleshooting +// +// ## Common Issues +// +// 1. **Missing logs**: Check if Skipper function is filtering out requests +// 2. **Invalid JSON**: Ensure CustomTagFunc outputs valid JSON content +// 3. **Performance issues**: Consider using a buffered writer for high-traffic applications +// 4. **File permission errors**: Ensure write permissions when logging to files +// +// ## Performance Tips +// +// - Use time_unix formats for better performance than time_rfc3339 +// - Minimize the number of dynamic tags (header:, query:, form:, cookie:) +// - Use Skipper to exclude high-frequency, low-value requests (health checks, etc.) +// - Consider async logging for very high-traffic applications +type LoggerConfig struct { + // Skipper defines a function to skip middleware. + // Use this to exclude certain requests from logging (e.g., health checks). + // + // Example: + // Skipper: func(c echo.Context) bool { + // return c.Request().URL.Path == "/health" + // }, + Skipper Skipper -var ( - // DefaultLoggerConfig is the default Logger middleware config. - DefaultLoggerConfig = LoggerConfig{ - Skipper: DefaultSkipper, - Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` + - `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + - `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + - `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", - CustomTimeFormat: "2006-01-02 15:04:05.00000", - colorer: color.New(), - } -) + // Format defines the logging format using template tags. + // Tags are enclosed in ${} and replaced with actual values. + // See the detailed tag documentation above for all available options. + // + // Default: JSON format with common fields + // Example: "${time_rfc3339_nano} ${status} ${method} ${uri} ${latency_human}\n" + Format string `yaml:"format"` + + // CustomTimeFormat specifies the time format used by ${time_custom} tag. + // Uses Go's reference time: Mon Jan 2 15:04:05 MST 2006 + // + // Default: "2006-01-02 15:04:05.00000" + // Example: "2006-01-02 15:04:05" or "15:04:05.000" + CustomTimeFormat string `yaml:"custom_time_format"` + + // CustomTagFunc is called when ${custom} tag is encountered. + // Use this to add application-specific information to logs. + // The function should write valid content for your log format. + // + // Example: + // CustomTagFunc: func(c echo.Context, buf *bytes.Buffer) (int, error) { + // userID := getUserFromContext(c) + // return buf.WriteString(`"user_id":"` + userID + `"`) + // }, + CustomTagFunc func(c echo.Context, buf *bytes.Buffer) (int, error) + + // Output specifies where logs are written. + // Can be any io.Writer: files, buffers, network connections, etc. + // + // Default: os.Stdout + // Example: Custom file, syslog, or external logging service + Output io.Writer + + template *fasttemplate.Template + colorer *color.Color + pool *sync.Pool +} + +// DefaultLoggerConfig is the default Logger middleware config. +var DefaultLoggerConfig = LoggerConfig{ + Skipper: DefaultSkipper, + Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` + + `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + + `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + + `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", + CustomTimeFormat: "2006-01-02 15:04:05.00000", + colorer: color.New(), +} -// Logger returns a middleware that logs HTTP requests. +// Logger returns a middleware that logs HTTP requests using the default configuration. +// +// The default format logs requests as JSON with the following fields: +// - time: RFC3339 nano timestamp +// - id: Request ID from X-Request-ID header +// - remote_ip: Client IP address +// - host: Host header +// - method: HTTP method +// - uri: Request URI +// - user_agent: User-Agent header +// - status: HTTP status code +// - error: Error message (if any) +// - latency: Processing time in nanoseconds +// - latency_human: Human-readable processing time +// - bytes_in: Request body size +// - bytes_out: Response body size +// +// Example output: +// +// {"time":"2023-01-15T10:30:45.123456789Z","id":"","remote_ip":"127.0.0.1", +// "host":"localhost:8080","method":"GET","uri":"/users/123","user_agent":"curl/7.81.0", +// "status":200,"error":"","latency":1234567,"latency_human":"1.234567ms", +// "bytes_in":0,"bytes_out":42} +// +// For custom configurations, use LoggerWithConfig instead. func Logger() echo.MiddlewareFunc { return LoggerWithConfig(DefaultLoggerConfig) } -// LoggerWithConfig returns a Logger middleware with config. -// See: `Logger()`. +// LoggerWithConfig returns a Logger middleware with custom configuration. +// +// This function allows you to customize all aspects of request logging including: +// - Log format and fields +// - Output destination +// - Time formatting +// - Custom tags and logic +// - Request filtering +// +// See LoggerConfig documentation for detailed configuration examples and options. +// +// Example: +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Format: "${time_rfc3339} ${status} ${method} ${uri} ${latency_human}\n", +// Output: customLogWriter, +// Skipper: func(c echo.Context) bool { +// return c.Request().URL.Path == "/health" +// }, +// })) func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { // Defaults if config.Skipper == nil { diff --git a/middleware/logger_test.go b/middleware/logger_test.go index 9f35a70bc..d5236e1ac 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/method_override.go b/middleware/method_override.go index 92b14d2ed..3991e1029 100644 --- a/middleware/method_override.go +++ b/middleware/method_override.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -6,28 +9,24 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // MethodOverrideConfig defines the config for MethodOverride middleware. - MethodOverrideConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// MethodOverrideConfig defines the config for MethodOverride middleware. +type MethodOverrideConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Getter is a function that gets overridden method from the request. - // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). - Getter MethodOverrideGetter - } + // Getter is a function that gets overridden method from the request. + // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). + Getter MethodOverrideGetter +} - // MethodOverrideGetter is a function that gets overridden method from the request - MethodOverrideGetter func(echo.Context) string -) +// MethodOverrideGetter is a function that gets overridden method from the request +type MethodOverrideGetter func(echo.Context) string -var ( - // DefaultMethodOverrideConfig is the default MethodOverride middleware config. - DefaultMethodOverrideConfig = MethodOverrideConfig{ - Skipper: DefaultSkipper, - Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), - } -) +// DefaultMethodOverrideConfig is the default MethodOverride middleware config. +var DefaultMethodOverrideConfig = MethodOverrideConfig{ + Skipper: DefaultSkipper, + Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), +} // MethodOverride returns a MethodOverride middleware. // MethodOverride middleware checks for the overridden method from the request and diff --git a/middleware/method_override_test.go b/middleware/method_override_test.go index 5760b1581..0000d1d80 100644 --- a/middleware/method_override_test.go +++ b/middleware/method_override_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/middleware.go b/middleware/middleware.go index 664f71f45..6f33cc5c1 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -9,14 +12,12 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // Skipper defines a function to skip middleware. Returning true skips processing - // the middleware. - Skipper func(c echo.Context) bool +// Skipper defines a function to skip middleware. Returning true skips processing +// the middleware. +type Skipper func(c echo.Context) bool - // BeforeFunc defines a function which is executed just before the middleware. - BeforeFunc func(c echo.Context) -) +// BeforeFunc defines a function which is executed just before the middleware. +type BeforeFunc func(c echo.Context) func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { groups := pattern.FindAllStringSubmatch(input, -1) @@ -53,7 +54,7 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error return nil } - // Depending how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path. + // Depending on how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path. // We only want to use path part for rewriting and therefore trim prefix if it exists rawURI := req.RequestURI if rawURI != "" && rawURI[0] != '/' { diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 44f44142c..7f3dc3866 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -1,7 +1,13 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( + "bufio" + "errors" "github.com/stretchr/testify/assert" + "net" "net/http" "net/http/httptest" "regexp" @@ -90,3 +96,46 @@ func TestRewriteURL(t *testing.T) { }) } } + +type testResponseWriterNoFlushHijack struct { +} + +func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) { +} + +func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) { + return 0, nil +} + +func (w *testResponseWriterNoFlushHijack) Header() http.Header { + return nil +} + +type testResponseWriterUnwrapper struct { + unwrapCalled int + rw http.ResponseWriter +} + +func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) { +} + +func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) { + return 0, nil +} + +func (w *testResponseWriterUnwrapper) Header() http.Header { + return nil +} + +func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter { + w.unwrapCalled++ + return w.rw +} + +type testResponseWriterUnwrapperHijack struct { + testResponseWriterUnwrapper +} + +func (w *testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, errors.New("can hijack") +} diff --git a/middleware/proxy.go b/middleware/proxy.go index e4f98d9ed..2744bc4a8 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -1,7 +1,11 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( "context" + "crypto/tls" "fmt" "io" "math/rand" @@ -19,119 +23,129 @@ import ( // TODO: Handle TLS proxy -type ( - // ProxyConfig defines the config for Proxy middleware. - ProxyConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Balancer defines a load balancing technique. - // Required. - Balancer ProxyBalancer - - // RetryCount defines the number of times a failed proxied request should be retried - // using the next available ProxyTarget. Defaults to 0, meaning requests are never retried. - RetryCount int - - // RetryFilter defines a function used to determine if a failed request to a - // ProxyTarget should be retried. The RetryFilter will only be called when the number - // of previous retries is less than RetryCount. If the function returns true, the - // request will be retried. The provided error indicates the reason for the request - // failure. When the ProxyTarget is unavailable, the error will be an instance of - // echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error - // will indicate an internal error in the Proxy middleware. When a RetryFilter is not - // specified, all requests that fail with http.StatusBadGateway will be retried. A custom - // RetryFilter can be provided to only retry specific requests. Note that RetryFilter is - // only called when the request to the target fails, or an internal error in the Proxy - // middleware has occurred. Successful requests that return a non-200 response code cannot - // be retried. - RetryFilter func(c echo.Context, e error) bool - - // ErrorHandler defines a function which can be used to return custom errors from - // the Proxy middleware. ErrorHandler is only invoked when there has been - // either an internal error in the Proxy middleware or the ProxyTarget is - // unavailable. Due to the way requests are proxied, ErrorHandler is not invoked - // when a ProxyTarget returns a non-200 response. In these cases, the response - // is already written so errors cannot be modified. ErrorHandler is only - // invoked after all retry attempts have been exhausted. - ErrorHandler func(c echo.Context, err error) error - - // Rewrite defines URL path rewrite rules. The values captured in asterisk can be - // retrieved by index e.g. $1, $2 and so on. - // Examples: - // "/old": "/new", - // "/api/*": "/$1", - // "/js/*": "/public/javascripts/$1", - // "/users/*/orders/*": "/user/$1/order/$2", - Rewrite map[string]string - - // RegexRewrite defines rewrite rules using regexp.Rexexp with captures - // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. - // Example: - // "^/old/[0.9]+/": "/new", - // "^/api/.+?/(.*)": "/v2/$1", - RegexRewrite map[*regexp.Regexp]string - - // Context key to store selected ProxyTarget into context. - // Optional. Default value "target". - ContextKey string - - // To customize the transport to remote. - // Examples: If custom TLS certificates are required. - Transport http.RoundTripper - - // ModifyResponse defines function to modify response from ProxyTarget. - ModifyResponse func(*http.Response) error - } +// ProxyConfig defines the config for Proxy middleware. +type ProxyConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Balancer defines a load balancing technique. + // Required. + Balancer ProxyBalancer + + // RetryCount defines the number of times a failed proxied request should be retried + // using the next available ProxyTarget. Defaults to 0, meaning requests are never retried. + RetryCount int + + // RetryFilter defines a function used to determine if a failed request to a + // ProxyTarget should be retried. The RetryFilter will only be called when the number + // of previous retries is less than RetryCount. If the function returns true, the + // request will be retried. The provided error indicates the reason for the request + // failure. When the ProxyTarget is unavailable, the error will be an instance of + // echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error + // will indicate an internal error in the Proxy middleware. When a RetryFilter is not + // specified, all requests that fail with http.StatusBadGateway will be retried. A custom + // RetryFilter can be provided to only retry specific requests. Note that RetryFilter is + // only called when the request to the target fails, or an internal error in the Proxy + // middleware has occurred. Successful requests that return a non-200 response code cannot + // be retried. + RetryFilter func(c echo.Context, e error) bool + + // ErrorHandler defines a function which can be used to return custom errors from + // the Proxy middleware. ErrorHandler is only invoked when there has been + // either an internal error in the Proxy middleware or the ProxyTarget is + // unavailable. Due to the way requests are proxied, ErrorHandler is not invoked + // when a ProxyTarget returns a non-200 response. In these cases, the response + // is already written so errors cannot be modified. ErrorHandler is only + // invoked after all retry attempts have been exhausted. + ErrorHandler func(c echo.Context, err error) error + + // Rewrite defines URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Examples: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + Rewrite map[string]string + + // RegexRewrite defines rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRewrite map[*regexp.Regexp]string + + // Context key to store selected ProxyTarget into context. + // Optional. Default value "target". + ContextKey string + + // To customize the transport to remote. + // Examples: If custom TLS certificates are required. + Transport http.RoundTripper + + // ModifyResponse defines function to modify response from ProxyTarget. + ModifyResponse func(*http.Response) error +} - // ProxyTarget defines the upstream target. - ProxyTarget struct { - Name string - URL *url.URL - Meta echo.Map - } +// ProxyTarget defines the upstream target. +type ProxyTarget struct { + Name string + URL *url.URL + Meta echo.Map +} - // ProxyBalancer defines an interface to implement a load balancing technique. - ProxyBalancer interface { - AddTarget(*ProxyTarget) bool - RemoveTarget(string) bool - Next(echo.Context) *ProxyTarget - } +// ProxyBalancer defines an interface to implement a load balancing technique. +type ProxyBalancer interface { + AddTarget(*ProxyTarget) bool + RemoveTarget(string) bool + Next(echo.Context) *ProxyTarget +} - // TargetProvider defines an interface that gives the opportunity for balancer - // to return custom errors when selecting target. - TargetProvider interface { - NextTarget(echo.Context) (*ProxyTarget, error) - } +// TargetProvider defines an interface that gives the opportunity for balancer +// to return custom errors when selecting target. +type TargetProvider interface { + NextTarget(echo.Context) (*ProxyTarget, error) +} - commonBalancer struct { - targets []*ProxyTarget - mutex sync.Mutex - } +type commonBalancer struct { + targets []*ProxyTarget + mutex sync.Mutex +} - // RandomBalancer implements a random load balancing technique. - randomBalancer struct { - commonBalancer - random *rand.Rand - } +// RandomBalancer implements a random load balancing technique. +type randomBalancer struct { + commonBalancer + random *rand.Rand +} - // RoundRobinBalancer implements a round-robin load balancing technique. - roundRobinBalancer struct { - commonBalancer - // tracking the index on `targets` slice for the next `*ProxyTarget` to be used - i int - } -) +// RoundRobinBalancer implements a round-robin load balancing technique. +type roundRobinBalancer struct { + commonBalancer + // tracking the index on `targets` slice for the next `*ProxyTarget` to be used + i int +} + +// DefaultProxyConfig is the default Proxy middleware config. +var DefaultProxyConfig = ProxyConfig{ + Skipper: DefaultSkipper, + ContextKey: "target", +} -var ( - // DefaultProxyConfig is the default Proxy middleware config. - DefaultProxyConfig = ProxyConfig{ - Skipper: DefaultSkipper, - ContextKey: "target", +func proxyRaw(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { + var dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) + if transport, ok := config.Transport.(*http.Transport); ok { + if transport.TLSClientConfig != nil { + d := tls.Dialer{ + Config: transport.TLSClientConfig, + } + dialFunc = d.DialContext + } + } + if dialFunc == nil { + var d net.Dialer + dialFunc = d.DialContext } -) -func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { in, _, err := c.Response().Hijack() if err != nil { @@ -139,13 +153,11 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { return } defer in.Close() - - out, err := net.Dial("tcp", t.URL.Host) + out, err := dialFunc(c.Request().Context(), "tcp", t.URL.Host) if err != nil { c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL))) return } - defer out.Close() // Write header err = r.Write(out) @@ -359,12 +371,15 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { c.Set("_error", nil) } + // This is needed for ProxyConfig.ModifyResponse and/or ProxyConfig.Transport to be able to process the Request + // that Balancer may have replaced with c.SetRequest. + req = c.Request() + // Proxy switch { case c.IsWebSocket(): - proxyRaw(tgt, c).ServeHTTP(res, req) - case req.Header.Get(echo.HeaderAccept) == "text/event-stream": - default: + proxyRaw(tgt, c, config).ServeHTTP(res, req) + default: // even SSE requests proxyHTTP(tgt, c, config).ServeHTTP(res, req) } diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 1b5ba6cbe..dbf07648b 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -1,8 +1,12 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( "bytes" "context" + "crypto/tls" "errors" "fmt" "io" @@ -17,6 +21,7 @@ import ( "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" + "golang.org/x/net/websocket" ) // Assert expected with url.EscapedPath method to obtain the path. @@ -188,7 +193,7 @@ func TestProxyRealIPHeader(t *testing.T) { tests := []*struct { hasRealIPheader bool hasIPExtractor bool - extectedXRealIP string + expectedXRealIP string }{ {false, false, remoteAddrIP}, {false, true, extractedRealIP}, @@ -210,7 +215,7 @@ func TestProxyRealIPHeader(t *testing.T) { e.IPExtractor = nil } e.ServeHTTP(rec, req) - assert.Equal(t, tt.extectedXRealIP, req.Header.Get(echo.HeaderXRealIP), "hasRealIPheader: %t / hasIPExtractor: %t", tt.hasRealIPheader, tt.hasIPExtractor) + assert.Equal(t, tt.expectedXRealIP, req.Header.Get(echo.HeaderXRealIP), "hasRealIPheader: %t / hasIPExtractor: %t", tt.hasRealIPheader, tt.hasIPExtractor) } } @@ -747,3 +752,291 @@ func TestProxyBalancerWithNoTargets(t *testing.T) { rrb := NewRoundRobinBalancer([]*ProxyTarget{}) assert.Nil(t, rrb.Next(nil)) } + +type testContextKey string + +type customBalancer struct { + target *ProxyTarget +} + +func (b *customBalancer) AddTarget(target *ProxyTarget) bool { + return false +} + +func (b *customBalancer) RemoveTarget(name string) bool { + return false +} + +func (b *customBalancer) Next(c echo.Context) *ProxyTarget { + ctx := context.WithValue(c.Request().Context(), testContextKey("FROM_BALANCER"), "CUSTOM_BALANCER") + c.SetRequest(c.Request().WithContext(ctx)) + return b.target +} + +func TestModifyResponseUseContext(t *testing.T) { + server := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }), + ) + defer server.Close() + + targetURL, _ := url.Parse(server.URL) + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Balancer: &customBalancer{ + target: &ProxyTarget{ + Name: "tst", + URL: targetURL, + }, + }, + RetryCount: 1, + ModifyResponse: func(res *http.Response) error { + val := res.Request.Context().Value(testContextKey("FROM_BALANCER")) + if valStr, ok := val.(string); ok { + res.Header.Set("FROM_BALANCER", valStr) + } + return nil + }, + }, + )) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "OK", rec.Body.String()) + assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER")) +} + +func createSimpleWebSocketServer(serveTLS bool) *httptest.Server { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsHandler := func(conn *websocket.Conn) { + defer conn.Close() + for { + var msg string + err := websocket.Message.Receive(conn, &msg) + if err != nil { + return + } + // message back to the client + websocket.Message.Send(conn, msg) + } + } + websocket.Server{Handler: wsHandler}.ServeHTTP(w, r) + }) + if serveTLS { + return httptest.NewTLSServer(handler) + } + return httptest.NewServer(handler) +} + +func createSimpleProxyServer(t *testing.T, srv *httptest.Server, serveTLS bool, toTLS bool) *httptest.Server { + e := echo.New() + + if toTLS { + // proxy to tls target + tgtURL, _ := url.Parse(srv.URL) + tgtURL.Scheme = "wss" + balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}}) + + defaultTransport, ok := http.DefaultTransport.(*http.Transport) + if !ok { + t.Fatal("Default transport is not of type *http.Transport") + } + transport := defaultTransport.Clone() + transport.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer, Transport: transport})) + } else { + // proxy to non-TLS target + tgtURL, _ := url.Parse(srv.URL) + balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}}) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer})) + } + + if serveTLS { + // serve proxy server with TLS + ts := httptest.NewTLSServer(e) + return ts + } + // serve proxy server without TLS + ts := httptest.NewServer(e) + return ts +} + +// TestProxyWithConfigWebSocketNonTLS2NonTLS tests the proxy with non-TLS to non-TLS WebSocket connection. +func TestProxyWithConfigWebSocketNonTLS2NonTLS(t *testing.T) { + /* + Arrange + */ + // Create a WebSocket test server (non-TLS) + srv := createSimpleWebSocketServer(false) + defer srv.Close() + + // create proxy server (non-TLS to non-TLS) + ts := createSimpleProxyServer(t, srv, false, false) + defer ts.Close() + + tsURL, _ := url.Parse(ts.URL) + tsURL.Scheme = "ws" + tsURL.Path = "/" + + /* + Act + */ + + // Connect to the proxy WebSocket + wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/") + assert.NoError(t, err) + defer wsConn.Close() + + // Send message + sendMsg := "Hello, Non TLS WebSocket!" + err = websocket.Message.Send(wsConn, sendMsg) + assert.NoError(t, err) + + /* + Assert + */ + // Read response + var recvMsg string + err = websocket.Message.Receive(wsConn, &recvMsg) + assert.NoError(t, err) + assert.Equal(t, sendMsg, recvMsg) +} + +// TestProxyWithConfigWebSocketTLS2TLS tests the proxy with TLS to TLS WebSocket connection. +func TestProxyWithConfigWebSocketTLS2TLS(t *testing.T) { + /* + Arrange + */ + // Create a WebSocket test server (TLS) + srv := createSimpleWebSocketServer(true) + defer srv.Close() + + // create proxy server (TLS to TLS) + ts := createSimpleProxyServer(t, srv, true, true) + defer ts.Close() + + tsURL, _ := url.Parse(ts.URL) + tsURL.Scheme = "wss" + tsURL.Path = "/" + + /* + Act + */ + origin, err := url.Parse(ts.URL) + assert.NoError(t, err) + config := &websocket.Config{ + Location: tsURL, + Origin: origin, + TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing + Version: websocket.ProtocolVersionHybi13, + } + wsConn, err := websocket.DialConfig(config) + assert.NoError(t, err) + defer wsConn.Close() + + // Send message + sendMsg := "Hello, TLS to TLS WebSocket!" + err = websocket.Message.Send(wsConn, sendMsg) + assert.NoError(t, err) + + // Read response + var recvMsg string + err = websocket.Message.Receive(wsConn, &recvMsg) + assert.NoError(t, err) + assert.Equal(t, sendMsg, recvMsg) +} + +// TestProxyWithConfigWebSocketNonTLS2TLS tests the proxy with non-TLS to TLS WebSocket connection. +func TestProxyWithConfigWebSocketNonTLS2TLS(t *testing.T) { + /* + Arrange + */ + + // Create a WebSocket test server (TLS) + srv := createSimpleWebSocketServer(true) + defer srv.Close() + + // create proxy server (Non-TLS to TLS) + ts := createSimpleProxyServer(t, srv, false, true) + defer ts.Close() + + tsURL, _ := url.Parse(ts.URL) + tsURL.Scheme = "ws" + tsURL.Path = "/" + + /* + Act + */ + // Connect to the proxy WebSocket + wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/") + assert.NoError(t, err) + defer wsConn.Close() + + // Send message + sendMsg := "Hello, Non TLS to TLS WebSocket!" + err = websocket.Message.Send(wsConn, sendMsg) + assert.NoError(t, err) + + /* + Assert + */ + // Read response + var recvMsg string + err = websocket.Message.Receive(wsConn, &recvMsg) + assert.NoError(t, err) + assert.Equal(t, sendMsg, recvMsg) +} + +// TestProxyWithConfigWebSocketTLSToNoneTLS tests the proxy with TLS to non-TLS WebSocket connection. (TLS termination) +func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) { + /* + Arrange + */ + + // Create a WebSocket test server (non-TLS) + srv := createSimpleWebSocketServer(false) + defer srv.Close() + + // create proxy server (TLS to non-TLS) + ts := createSimpleProxyServer(t, srv, true, false) + defer ts.Close() + + tsURL, _ := url.Parse(ts.URL) + tsURL.Scheme = "wss" + tsURL.Path = "/" + + /* + Act + */ + origin, err := url.Parse(ts.URL) + assert.NoError(t, err) + config := &websocket.Config{ + Location: tsURL, + Origin: origin, + TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing + Version: websocket.ProtocolVersionHybi13, + } + wsConn, err := websocket.DialConfig(config) + assert.NoError(t, err) + defer wsConn.Close() + + // Send message + sendMsg := "Hello, TLS to NoneTLS WebSocket!" + err = websocket.Message.Send(wsConn, sendMsg) + assert.NoError(t, err) + + // Read response + var recvMsg string + err = websocket.Message.Receive(wsConn, &recvMsg) + assert.NoError(t, err) + assert.Equal(t, sendMsg, recvMsg) +} diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index 1d24df52a..70b89b0e2 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -9,39 +12,34 @@ import ( "golang.org/x/time/rate" ) -type ( - // RateLimiterStore is the interface to be implemented by custom stores. - RateLimiterStore interface { - // Stores for the rate limiter have to implement the Allow method - Allow(identifier string) (bool, error) - } -) +// RateLimiterStore is the interface to be implemented by custom stores. +type RateLimiterStore interface { + // Stores for the rate limiter have to implement the Allow method + Allow(identifier string) (bool, error) +} -type ( - // RateLimiterConfig defines the configuration for the rate limiter - RateLimiterConfig struct { - Skipper Skipper - BeforeFunc BeforeFunc - // IdentifierExtractor uses echo.Context to extract the identifier for a visitor - IdentifierExtractor Extractor - // Store defines a store for the rate limiter - Store RateLimiterStore - // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error - ErrorHandler func(context echo.Context, err error) error - // DenyHandler provides a handler to be called when RateLimiter denies access - DenyHandler func(context echo.Context, identifier string, err error) error - } - // Extractor is used to extract data from echo.Context - Extractor func(context echo.Context) (string, error) -) +// RateLimiterConfig defines the configuration for the rate limiter +type RateLimiterConfig struct { + Skipper Skipper + BeforeFunc BeforeFunc + // IdentifierExtractor uses echo.Context to extract the identifier for a visitor + IdentifierExtractor Extractor + // Store defines a store for the rate limiter + Store RateLimiterStore + // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error + ErrorHandler func(context echo.Context, err error) error + // DenyHandler provides a handler to be called when RateLimiter denies access + DenyHandler func(context echo.Context, identifier string, err error) error +} -// errors -var ( - // ErrRateLimitExceeded denotes an error raised when rate limit is exceeded - ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") - // ErrExtractorError denotes an error raised when extractor function is unsuccessful - ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") -) +// Extractor is used to extract data from echo.Context +type Extractor func(context echo.Context) (string, error) + +// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded +var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") + +// ErrExtractorError denotes an error raised when extractor function is unsuccessful +var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") // DefaultRateLimiterConfig defines default values for RateLimiterConfig var DefaultRateLimiterConfig = RateLimiterConfig{ @@ -150,25 +148,24 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { } } -type ( - // RateLimiterMemoryStore is the built-in store implementation for RateLimiter - RateLimiterMemoryStore struct { - visitors map[string]*Visitor - mutex sync.Mutex - rate rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. +// RateLimiterMemoryStore is the built-in store implementation for RateLimiter +type RateLimiterMemoryStore struct { + visitors map[string]*Visitor + mutex sync.Mutex + rate rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. - burst int - expiresIn time.Duration - lastCleanup time.Time + burst int + expiresIn time.Duration + lastCleanup time.Time - timeNow func() time.Time - } - // Visitor signifies a unique user's limiter details - Visitor struct { - *rate.Limiter - lastSeen time.Time - } -) + timeNow func() time.Time +} + +// Visitor signifies a unique user's limiter details +type Visitor struct { + *rate.Limiter + lastSeen time.Time +} /* NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with @@ -194,7 +191,7 @@ NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStor with the provided configuration. Rate must be provided. Burst will be set to the rounded down value of the configured rate if not provided or set to 0. -The build-in memory store is usually capable for modest loads. For higher loads other +The built-in memory store is usually capable for modest loads. For higher loads other store implementations should be considered. Characteristics: diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go index 0f7c9141d..1de7b63e5 100644 --- a/middleware/rate_limiter_test.go +++ b/middleware/rate_limiter_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -10,7 +13,6 @@ import ( "time" "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" "github.com/stretchr/testify/assert" "golang.org/x/time/rate" ) @@ -410,7 +412,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) { func generateAddressList(count int) []string { addrs := make([]string, count) for i := 0; i < count; i++ { - addrs[i] = random.String(15) + addrs[i] = randomString(15) } return addrs } diff --git a/middleware/recover.go b/middleware/recover.go index 0466cfe56..e6a5940e4 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -9,56 +12,52 @@ import ( "github.com/labstack/gommon/log" ) -type ( +// LogErrorFunc defines a function for custom logging in the middleware. +type LogErrorFunc func(c echo.Context, err error, stack []byte) error + +// RecoverConfig defines the config for Recover middleware. +type RecoverConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Size of the stack to be printed. + // Optional. Default value 4KB. + StackSize int `yaml:"stack_size"` + + // DisableStackAll disables formatting stack traces of all other goroutines + // into buffer after the trace for the current goroutine. + // Optional. Default value false. + DisableStackAll bool `yaml:"disable_stack_all"` + + // DisablePrintStack disables printing stack trace. + // Optional. Default value as false. + DisablePrintStack bool `yaml:"disable_print_stack"` + + // LogLevel is log level to printing stack trace. + // Optional. Default value 0 (Print). + LogLevel log.Lvl + // LogErrorFunc defines a function for custom logging in the middleware. - LogErrorFunc func(c echo.Context, err error, stack []byte) error - - // RecoverConfig defines the config for Recover middleware. - RecoverConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Size of the stack to be printed. - // Optional. Default value 4KB. - StackSize int `yaml:"stack_size"` - - // DisableStackAll disables formatting stack traces of all other goroutines - // into buffer after the trace for the current goroutine. - // Optional. Default value false. - DisableStackAll bool `yaml:"disable_stack_all"` - - // DisablePrintStack disables printing stack trace. - // Optional. Default value as false. - DisablePrintStack bool `yaml:"disable_print_stack"` - - // LogLevel is log level to printing stack trace. - // Optional. Default value 0 (Print). - LogLevel log.Lvl - - // LogErrorFunc defines a function for custom logging in the middleware. - // If it's set you don't need to provide LogLevel for config. - // If this function returns nil, the centralized HTTPErrorHandler will not be called. - LogErrorFunc LogErrorFunc - - // DisableErrorHandler disables the call to centralized HTTPErrorHandler. - // The recovered error is then passed back to upstream middleware, instead of swallowing the error. - // Optional. Default value false. - DisableErrorHandler bool `yaml:"disable_error_handler"` - } -) + // If it's set you don't need to provide LogLevel for config. + // If this function returns nil, the centralized HTTPErrorHandler will not be called. + LogErrorFunc LogErrorFunc + + // DisableErrorHandler disables the call to centralized HTTPErrorHandler. + // The recovered error is then passed back to upstream middleware, instead of swallowing the error. + // Optional. Default value false. + DisableErrorHandler bool `yaml:"disable_error_handler"` +} -var ( - // DefaultRecoverConfig is the default Recover middleware config. - DefaultRecoverConfig = RecoverConfig{ - Skipper: DefaultSkipper, - StackSize: 4 << 10, // 4 KB - DisableStackAll: false, - DisablePrintStack: false, - LogLevel: 0, - LogErrorFunc: nil, - DisableErrorHandler: false, - } -) +// DefaultRecoverConfig is the default Recover middleware config. +var DefaultRecoverConfig = RecoverConfig{ + Skipper: DefaultSkipper, + StackSize: 4 << 10, // 4 KB + DisableStackAll: false, + DisablePrintStack: false, + LogLevel: 0, + LogErrorFunc: nil, + DisableErrorHandler: false, +} // Recover returns a middleware which recovers from panics anywhere in the chain // and handles the control to the centralized HTTPErrorHandler. diff --git a/middleware/recover_test.go b/middleware/recover_test.go index 3e0d35d79..8fa34fa5c 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/redirect.go b/middleware/redirect.go index 13877db38..b772ac131 100644 --- a/middleware/redirect.go +++ b/middleware/redirect.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go index 9d1b56205..88068ea2e 100644 --- a/middleware/redirect_test.go +++ b/middleware/redirect_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/request_id.go b/middleware/request_id.go index 8c5ff6605..14bd4fd15 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -1,36 +1,34 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" ) -type ( - // RequestIDConfig defines the config for RequestID middleware. - RequestIDConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// RequestIDConfig defines the config for RequestID middleware. +type RequestIDConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Generator defines a function to generate an ID. - // Optional. Default value random.String(32). - Generator func() string + // Generator defines a function to generate an ID. + // Optional. Defaults to generator for random string of length 32. + Generator func() string - // RequestIDHandler defines a function which is executed for a request id. - RequestIDHandler func(echo.Context, string) + // RequestIDHandler defines a function which is executed for a request id. + RequestIDHandler func(echo.Context, string) - // TargetHeader defines what header to look for to populate the id - TargetHeader string - } -) + // TargetHeader defines what header to look for to populate the id + TargetHeader string +} -var ( - // DefaultRequestIDConfig is the default RequestID middleware config. - DefaultRequestIDConfig = RequestIDConfig{ - Skipper: DefaultSkipper, - Generator: generator, - TargetHeader: echo.HeaderXRequestID, - } -) +// DefaultRequestIDConfig is the default RequestID middleware config. +var DefaultRequestIDConfig = RequestIDConfig{ + Skipper: DefaultSkipper, + Generator: generator, + TargetHeader: echo.HeaderXRequestID, +} // RequestID returns a X-Request-ID middleware. func RequestID() echo.MiddlewareFunc { @@ -73,5 +71,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { } func generator() string { - return random.String(32) + return randomString(32) } diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go index 21b777826..4e68b126a 100644 --- a/middleware/request_id_test.go +++ b/middleware/request_id_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/request_logger.go b/middleware/request_logger.go index ce76230c7..7c18200b0 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -8,6 +11,30 @@ import ( "github.com/labstack/echo/v4" ) +// Example for `slog` https://pkg.go.dev/log/slog +// logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) +// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ +// LogStatus: true, +// LogURI: true, +// LogError: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code +// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { +// if v.Error == nil { +// logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST", +// slog.String("uri", v.URI), +// slog.Int("status", v.Status), +// ) +// } else { +// logger.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR", +// slog.String("uri", v.URI), +// slog.Int("status", v.Status), +// slog.String("err", v.Error.Error()), +// ) +// } +// return nil +// }, +// })) +// // Example for `fmt.Printf` // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogStatus: true, diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go index 51d617abb..c612f5c22 100644 --- a/middleware/request_logger_test.go +++ b/middleware/request_logger_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -194,7 +197,7 @@ func TestRequestLogger_LogValuesFuncError(t *testing.T) { e.ServeHTTP(rec, req) // NOTE: when global error handler received error returned from middleware the status has already - // been written to the client and response has been "commited" therefore global error handler does not do anything + // been written to the client and response has been "committed" therefore global error handler does not do anything // and error that bubbled up in middleware chain will not be reflected in response code. assert.Equal(t, http.StatusTeapot, rec.Code) assert.Equal(t, http.StatusTeapot, expect.Status) diff --git a/middleware/rewrite.go b/middleware/rewrite.go index e5b0a6b56..4c19cc1cc 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -6,37 +9,33 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // RewriteConfig defines the config for Rewrite middleware. - RewriteConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// RewriteConfig defines the config for Rewrite middleware. +type RewriteConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Rules defines the URL path rewrite rules. The values captured in asterisk can be - // retrieved by index e.g. $1, $2 and so on. - // Example: - // "/old": "/new", - // "/api/*": "/$1", - // "/js/*": "/public/javascripts/$1", - // "/users/*/orders/*": "/user/$1/order/$2", - // Required. - Rules map[string]string `yaml:"rules"` + // Rules defines the URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Example: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + // Required. + Rules map[string]string `yaml:"rules"` - // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures - // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. - // Example: - // "^/old/[0.9]+/": "/new", - // "^/api/.+?/(.*)": "/v2/$1", - RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"` - } -) + // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRules map[*regexp.Regexp]string `yaml:"-"` +} -var ( - // DefaultRewriteConfig is the default Rewrite middleware config. - DefaultRewriteConfig = RewriteConfig{ - Skipper: DefaultSkipper, - } -) +// DefaultRewriteConfig is the default Rewrite middleware config. +var DefaultRewriteConfig = RewriteConfig{ + Skipper: DefaultSkipper, +} // Rewrite returns a Rewrite middleware. // diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index 47d707c30..d137b2d13 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/secure.go b/middleware/secure.go index 6c4051723..c904abf1a 100644 --- a/middleware/secure.go +++ b/middleware/secure.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -6,84 +9,80 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // SecureConfig defines the config for Secure middleware. - SecureConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // XSSProtection provides protection against cross-site scripting attack (XSS) - // by setting the `X-XSS-Protection` header. - // Optional. Default value "1; mode=block". - XSSProtection string `yaml:"xss_protection"` - - // ContentTypeNosniff provides protection against overriding Content-Type - // header by setting the `X-Content-Type-Options` header. - // Optional. Default value "nosniff". - ContentTypeNosniff string `yaml:"content_type_nosniff"` - - // XFrameOptions can be used to indicate whether or not a browser should - // be allowed to render a page in a ,