Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/middleware/requestid.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ app.Use(requestid.New(requestid.Config{
}))
```

If the request already includes the configured header, that value is reused instead of generating a new one.
If the request already includes the configured header, that value is reused instead of generating a new one. The middleware
rejects IDs that are empty or contain bytes outside the printable ASCII range 0x20–0x7E (for example, control characters such
as CR/LF, DEL, or obs-text bytes) and regenerates the value with up to three attempts from the configured generator (or a UUID
when no generator is set). When a custom generator fails to produce a valid ID, the middleware falls back to up to three UUID
attempts to keep headers RFC-compliant across transports.

Retrieve the request ID

Expand Down
56 changes: 53 additions & 3 deletions middleware/requestid/requestid.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package requestid

import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/utils/v2"
)

// The contextKey type is unexported to prevent collisions with context keys defined in
Expand All @@ -24,10 +25,9 @@ func New(config ...Config) fiber.Handler {
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Get id from request, else we generate one
rid := c.Get(cfg.Header)
rid := sanitizeRequestID(c.Get(cfg.Header), cfg.Generator)
if rid == "" {
rid = cfg.Generator()
rid = utils.UUID()
}

// Set new id to response header
Expand All @@ -41,6 +41,56 @@ func New(config ...Config) fiber.Handler {
}
}

// sanitizeRequestID returns rid unchanged when isValidRequestID accepts it.
// Otherwise it draws up to three candidates from generator (utils.UUID when
// generator is nil) and returns the first one that passes validation. If a
// non-nil generator never yields a valid ID, up to three utils.UUID values are
// tried as a second round. When every attempt fails validation, the empty
// string is returned so the caller can decide on a final fallback.
func sanitizeRequestID(rid string, generator func() string) string {
	if isValidRequestID(rid) {
		return rid
	}

	// firstValid calls produce at most attempts times and reports the first
	// candidate accepted by isValidRequestID.
	firstValid := func(produce func() string, attempts int) (string, bool) {
		for n := 0; n < attempts; n++ {
			if candidate := produce(); isValidRequestID(candidate) {
				return candidate, true
			}
		}
		return "", false
	}

	primary := generator
	if primary == nil {
		primary = utils.UUID
	}
	if id, ok := firstValid(primary, 3); ok {
		return id
	}

	// The UUID fallback round only applies when a custom generator was
	// supplied; with a nil generator the first round already used UUIDs.
	if generator == nil {
		return ""
	}
	if id, ok := firstValid(utils.UUID, 3); ok {
		return id
	}
	return ""
}

// isValidRequestID reports whether rid is non-empty and made up solely of
// printable ASCII bytes, i.e. 0x20 (space) through 0x7E ('~'). Any control
// character (including CR and LF), DEL (0x7F), or byte >= 0x80 (obs-text or
// part of a multi-byte UTF-8 sequence) makes it return false.
func isValidRequestID(rid string) bool {
	if len(rid) == 0 {
		return false
	}

	// Ranging over the []byte conversion inspects raw bytes rather than runes,
	// so non-ASCII UTF-8 sequences are rejected byte by byte.
	for _, b := range []byte(rid) {
		if b < ' ' || b > '~' {
			return false
		}
	}
	return true
}

// FromContext returns the request ID from context.
// If there is no request ID, an empty string is returned.
func FromContext(c fiber.Ctx) string {
Expand Down
70 changes: 70 additions & 0 deletions middleware/requestid/requestid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,76 @@ func Test_RequestID(t *testing.T) {
require.Equal(t, reqid, resp.Header.Get(fiber.HeaderXRequestID))
}

// Test_RequestID_InvalidHeaderValue exercises every path of sanitizeRequestID:
// a header rejected by validation is replaced by the generator's output, a
// valid header is kept without consulting the generator, a nil generator falls
// back to a UUID, and a generator that never yields a valid ID is retried three
// times before the UUID fallback kicks in.
func Test_RequestID_InvalidHeaderValue(t *testing.T) {
	t.Parallel()

	t.Run("invalid header uses generator", func(t *testing.T) {
		t.Parallel()

		rid := sanitizeRequestID("bad\r\nid", func() string {
			return "clean-generated-id"
		})

		require.Equal(t, "clean-generated-id", rid)
	})

	t.Run("valid header is kept without calling generator", func(t *testing.T) {
		t.Parallel()

		calls := 0
		rid := sanitizeRequestID("client-id-123", func() string {
			calls++
			return "generated"
		})

		require.Equal(t, "client-id-123", rid)
		require.Zero(t, calls)
	})

	t.Run("nil generator falls back to UUID", func(t *testing.T) {
		t.Parallel()

		rid := sanitizeRequestID("bad\x00id", nil)

		require.True(t, isValidRequestID(rid))
		require.Len(t, rid, 36, "Fallback should produce a UUID")
	})

	t.Run("failing generator is retried then replaced by UUID", func(t *testing.T) {
		t.Parallel()

		calls := 0
		rid := sanitizeRequestID("bad\xffid", func() string {
			calls++
			return "still\nbad"
		})

		// The custom generator gets exactly three attempts before UUIDs are used.
		require.Equal(t, 3, calls)
		require.True(t, isValidRequestID(rid))
		require.Len(t, rid, 36, "Fallback should produce a UUID")
	})
}

// Test_RequestID_InvalidGeneratedValue checks, end to end through the
// middleware, that a generator producing a CRLF-bearing ID never reaches the
// response header and that a UUID-sized fallback is emitted instead.
func Test_RequestID_InvalidGeneratedValue(t *testing.T) {
	t.Parallel()

	// A generator that always returns a header-injection payload.
	badGenerator := func() string {
		return "bad\r\nid"
	}
	okHandler := func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	}

	app := fiber.New()
	app.Use(New(Config{Generator: badGenerator}))
	app.Get("/", okHandler)

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	got := resp.Header.Get(fiber.HeaderXRequestID)
	require.NotEmpty(t, got)
	for _, forbidden := range []string{"\r", "\n"} {
		require.NotContains(t, got, forbidden)
	}
	require.Len(t, got, 36, "Fallback should produce a UUID")
}

func Test_isValidRequestID_VisibleASCII(t *testing.T) {
t.Parallel()

require.True(t, isValidRequestID("request-id-09AZaz ~"))
}

// Test_isValidRequestID_Boundaries probes the edges of the accepted byte range:
// 0x20 and 0x7E are accepted, 0x1F and 0x7F are rejected, and so is "".
func Test_isValidRequestID_Boundaries(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name   string
		inputs []string
		want   bool
	}{
		{name: "allows space and tilde", inputs: []string{" ~"}, want: true},
		{name: "rejects out of range", inputs: []string{string([]byte{0x1f}), string([]byte{0x7f})}, want: false},
		{name: "rejects empty", inputs: []string{""}, want: false},
	}

	for _, tc := range cases {
		// Per-iteration loop variables (Go 1.22+) make capturing tc safe here.
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			for _, input := range tc.inputs {
				require.Equal(t, tc.want, isValidRequestID(input))
			}
		})
	}
}

// Test_isValidRequestID_RejectsObsText verifies that obs-text bytes
// (0x80–0xFF) are rejected. Both ends of the range are checked so an
// off-by-one upper bound in the validator cannot slip through.
func Test_isValidRequestID_RejectsObsText(t *testing.T) {
	t.Parallel()

	require.False(t, isValidRequestID("valid\x80"))
	require.False(t, isValidRequestID("valid\xff"))
}

// go test -run Test_RequestID_Next
func Test_RequestID_Next(t *testing.T) {
t.Parallel()
Expand Down
Loading