Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 6b3da1b

Browse files
committed
Add ability to specify response HTTP status code for Throttle middleware
1 parent ff1d3c6 commit 6b3da1b

File tree

2 files changed

+49
-9
lines changed

2 files changed

+49
-9
lines changed

middleware/throttle.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ type ThrottleOpts struct {
2222
Limit int
2323
BacklogLimit int
2424
BacklogTimeout time.Duration
25+
StatusCode int
2526
}
2627

2728
// Throttle is a middleware that limits number of currently processed requests
@@ -49,10 +50,16 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
4950
panic("chi/middleware: Throttle expects backlogLimit to be positive")
5051
}
5152

53+
statusCode := opts.StatusCode
54+
if statusCode == 0 {
55+
statusCode = http.StatusTooManyRequests
56+
}
57+
5258
t := throttler{
5359
tokens: make(chan token, opts.Limit),
5460
backlogTokens: make(chan token, opts.Limit+opts.BacklogLimit),
5561
backlogTimeout: opts.BacklogTimeout,
62+
statusCode: statusCode,
5663
retryAfterFn: opts.RetryAfterFn,
5764
}
5865

@@ -72,7 +79,7 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
7279

7380
case <-ctx.Done():
7481
t.setRetryAfterHeaderIfNeeded(w, true)
75-
http.Error(w, errContextCanceled, http.StatusTooManyRequests)
82+
http.Error(w, errContextCanceled, t.statusCode)
7683
return
7784

7885
case btok := <-t.backlogTokens:
@@ -85,12 +92,12 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
8592
select {
8693
case <-timer.C:
8794
t.setRetryAfterHeaderIfNeeded(w, false)
88-
http.Error(w, errTimedOut, http.StatusTooManyRequests)
95+
http.Error(w, errTimedOut, t.statusCode)
8996
return
9097
case <-ctx.Done():
9198
timer.Stop()
9299
t.setRetryAfterHeaderIfNeeded(w, true)
93-
http.Error(w, errContextCanceled, http.StatusTooManyRequests)
100+
http.Error(w, errContextCanceled, t.statusCode)
94101
return
95102
case tok := <-t.tokens:
96103
defer func() {
@@ -103,7 +110,7 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
103110

104111
default:
105112
t.setRetryAfterHeaderIfNeeded(w, false)
106-
http.Error(w, errCapacityExceeded, http.StatusTooManyRequests)
113+
http.Error(w, errCapacityExceeded, t.statusCode)
107114
return
108115
}
109116
}
@@ -119,8 +126,9 @@ type token struct{}
119126
type throttler struct {
120127
tokens chan token
121128
backlogTokens chan token
122-
retryAfterFn func(ctxDone bool) time.Duration
123129
backlogTimeout time.Duration
130+
statusCode int
131+
retryAfterFn func(ctxDone bool) time.Duration
124132
}
125133

126134
// setRetryAfterHeaderIfNeeded sets Retry-After HTTP header if corresponding retryAfterFn option of throttler is initialized.

middleware/throttle_test.go

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) {
116116
res, err := client.Get(server.URL)
117117
assertNoError(t, err)
118118
assertEqual(t, http.StatusOK, res.StatusCode)
119-
120119
}(i)
121120
}
122121

@@ -136,7 +135,6 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) {
136135
assertNoError(t, err)
137136
assertEqual(t, http.StatusTooManyRequests, res.StatusCode)
138137
assertEqual(t, errTimedOut, strings.TrimSpace(string(buf)))
139-
140138
}(i)
141139
}
142140

@@ -175,7 +173,6 @@ func TestThrottleMaximum(t *testing.T) {
175173
buf, err := ioutil.ReadAll(res.Body)
176174
assertNoError(t, err)
177175
assertEqual(t, testContent, buf)
178-
179176
}(i)
180177
}
181178

@@ -196,7 +193,6 @@ func TestThrottleMaximum(t *testing.T) {
196193
assertNoError(t, err)
197194
assertEqual(t, http.StatusTooManyRequests, res.StatusCode)
198195
assertEqual(t, errCapacityExceeded, strings.TrimSpace(string(buf)))
199-
200196
}(i)
201197
}
202198

@@ -252,3 +248,39 @@ func TestThrottleMaximum(t *testing.T) {
252248
253249
wg.Wait()
254250
}*/
251+
252+
func TestThrottleCustomStatusCode(t *testing.T) {
253+
block := make(chan struct{})
254+
255+
r := chi.NewRouter()
256+
r.Use(ThrottleWithOpts(ThrottleOpts{Limit: 1, StatusCode: http.StatusServiceUnavailable}))
257+
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
258+
w.WriteHeader(http.StatusOK)
259+
block <- struct{}{}
260+
block <- struct{}{}
261+
w.Write(testContent)
262+
})
263+
server := httptest.NewServer(r)
264+
defer server.Close()
265+
266+
client := http.Client{
267+
Timeout: time.Second * 60, // Maximum waiting time.
268+
}
269+
270+
done := make(chan struct{})
271+
272+
go func() {
273+
res, err := client.Get(server.URL)
274+
assertNoError(t, err)
275+
assertEqual(t, http.StatusOK, res.StatusCode)
276+
done <- struct{}{}
277+
}()
278+
279+
<-block
280+
res, err := client.Get(server.URL)
281+
assertNoError(t, err)
282+
assertEqual(t, http.StatusServiceUnavailable, res.StatusCode)
283+
<-block
284+
285+
<-done
286+
}

0 commit comments

Comments
 (0)