Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 1f927a8

Browse files
authored
Ability to specify response HTTP status code for Throttle middleware (#571)
* Fix typo in doc comment for Throttle middleware * Add ability to specify response HTTP status code for Throttle middleware
1 parent 2c4d128 commit 1f927a8

File tree

2 files changed

+65
-10
lines changed

2 files changed

+65
-10
lines changed

middleware/throttle.go

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ type ThrottleOpts struct {
2222
Limit int
2323
BacklogLimit int
2424
BacklogTimeout time.Duration
25+
StatusCode int
2526
}
2627

2728
// Throttle is a middleware that limits number of currently processed requests
2829
// at a time across all users. Note: Throttle is not a rate-limiter per user,
29-
// instead it just puts a ceiling on the number of currently in-flight requests
30+
// instead it just puts a ceiling on the number of current in-flight requests
3031
// being processed from the point from where the Throttle middleware is mounted.
3132
func Throttle(limit int) func(http.Handler) http.Handler {
3233
return ThrottleWithOpts(ThrottleOpts{Limit: limit, BacklogTimeout: defaultBacklogTimeout})
@@ -49,10 +50,16 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
4950
panic("chi/middleware: Throttle expects backlogLimit to be positive")
5051
}
5152

53+
statusCode := opts.StatusCode
54+
if statusCode == 0 {
55+
statusCode = http.StatusTooManyRequests
56+
}
57+
5258
t := throttler{
5359
tokens: make(chan token, opts.Limit),
5460
backlogTokens: make(chan token, opts.Limit+opts.BacklogLimit),
5561
backlogTimeout: opts.BacklogTimeout,
62+
statusCode: statusCode,
5663
retryAfterFn: opts.RetryAfterFn,
5764
}
5865

@@ -72,7 +79,7 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
7279

7380
case <-ctx.Done():
7481
t.setRetryAfterHeaderIfNeeded(w, true)
75-
http.Error(w, errContextCanceled, http.StatusTooManyRequests)
82+
http.Error(w, errContextCanceled, t.statusCode)
7683
return
7784

7885
case btok := <-t.backlogTokens:
@@ -85,12 +92,12 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
8592
select {
8693
case <-timer.C:
8794
t.setRetryAfterHeaderIfNeeded(w, false)
88-
http.Error(w, errTimedOut, http.StatusTooManyRequests)
95+
http.Error(w, errTimedOut, t.statusCode)
8996
return
9097
case <-ctx.Done():
9198
timer.Stop()
9299
t.setRetryAfterHeaderIfNeeded(w, true)
93-
http.Error(w, errContextCanceled, http.StatusTooManyRequests)
100+
http.Error(w, errContextCanceled, t.statusCode)
94101
return
95102
case tok := <-t.tokens:
96103
defer func() {
@@ -103,7 +110,7 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
103110

104111
default:
105112
t.setRetryAfterHeaderIfNeeded(w, false)
106-
http.Error(w, errCapacityExceeded, http.StatusTooManyRequests)
113+
http.Error(w, errCapacityExceeded, t.statusCode)
107114
return
108115
}
109116
}
@@ -119,8 +126,9 @@ type token struct{}
119126
type throttler struct {
120127
tokens chan token
121128
backlogTokens chan token
122-
retryAfterFn func(ctxDone bool) time.Duration
123129
backlogTimeout time.Duration
130+
statusCode int
131+
retryAfterFn func(ctxDone bool) time.Duration
124132
}
125133

126134
// setRetryAfterHeaderIfNeeded sets Retry-After HTTP header if corresponding retryAfterFn option of throttler is initialized.

middleware/throttle_test.go

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) {
116116
res, err := client.Get(server.URL)
117117
assertNoError(t, err)
118118
assertEqual(t, http.StatusOK, res.StatusCode)
119-
120119
}(i)
121120
}
122121

@@ -136,7 +135,6 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) {
136135
assertNoError(t, err)
137136
assertEqual(t, http.StatusTooManyRequests, res.StatusCode)
138137
assertEqual(t, errTimedOut, strings.TrimSpace(string(buf)))
139-
140138
}(i)
141139
}
142140

@@ -175,7 +173,6 @@ func TestThrottleMaximum(t *testing.T) {
175173
buf, err := ioutil.ReadAll(res.Body)
176174
assertNoError(t, err)
177175
assertEqual(t, testContent, buf)
178-
179176
}(i)
180177
}
181178

@@ -196,7 +193,6 @@ func TestThrottleMaximum(t *testing.T) {
196193
assertNoError(t, err)
197194
assertEqual(t, http.StatusTooManyRequests, res.StatusCode)
198195
assertEqual(t, errCapacityExceeded, strings.TrimSpace(string(buf)))
199-
200196
}(i)
201197
}
202198

@@ -252,3 +248,54 @@ func TestThrottleMaximum(t *testing.T) {
252248
253249
wg.Wait()
254250
}*/
251+
252+
func TestThrottleCustomStatusCode(t *testing.T) {
253+
const timeout = time.Second * 3
254+
255+
wait := make(chan struct{})
256+
257+
r := chi.NewRouter()
258+
r.Use(ThrottleWithOpts(ThrottleOpts{Limit: 1, StatusCode: http.StatusServiceUnavailable}))
259+
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
260+
select {
261+
case <-wait:
262+
case <-time.After(timeout):
263+
}
264+
w.WriteHeader(http.StatusOK)
265+
})
266+
server := httptest.NewServer(r)
267+
defer server.Close()
268+
269+
const totalRequestCount = 5
270+
271+
codes := make(chan int, totalRequestCount)
272+
errs := make(chan error, totalRequestCount)
273+
client := &http.Client{Timeout: timeout}
274+
for i := 0; i < totalRequestCount; i++ {
275+
go func() {
276+
resp, err := client.Get(server.URL)
277+
if err != nil {
278+
errs <- err
279+
return
280+
}
281+
codes <- resp.StatusCode
282+
}()
283+
}
284+
285+
waitResponse := func(wantCode int) {
286+
select {
287+
case err := <-errs:
288+
t.Fatal(err)
289+
case code := <-codes:
290+
assertEqual(t, wantCode, code)
291+
case <-time.After(timeout):
292+
t.Fatalf("waiting %d code, timeout exceeded", wantCode)
293+
}
294+
}
295+
296+
for i := 0; i < totalRequestCount-1; i++ {
297+
waitResponse(http.StatusServiceUnavailable)
298+
}
299+
close(wait) // Allow the last request to proceed.
300+
waitResponse(http.StatusOK)
301+
}

0 commit comments

Comments
 (0)