|
1 | 1 | package httpmw_test
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "context" |
| 5 | + "crypto/sha256" |
| 6 | + "fmt" |
| 7 | + "math/rand" |
| 8 | + "net" |
4 | 9 | "net/http"
|
5 | 10 | "net/http/httptest"
|
6 | 11 | "testing"
|
| 12 | + "time" |
7 | 13 |
|
8 | 14 | "github.com/go-chi/chi/v5"
|
| 15 | + "github.com/google/uuid" |
9 | 16 | "github.com/stretchr/testify/require"
|
10 | 17 |
|
| 18 | + "github.com/coder/coder/coderd/database" |
| 19 | + "github.com/coder/coder/coderd/database/databasefake" |
11 | 20 | "github.com/coder/coder/coderd/httpmw"
|
| 21 | + "github.com/coder/coder/coderd/rbac" |
| 22 | + "github.com/coder/coder/codersdk" |
12 | 23 | "github.com/coder/coder/testutil"
|
13 | 24 | )
|
14 | 25 |
|
| 26 | +func insertAPIKey(ctx context.Context, t *testing.T, db database.Store, userID uuid.UUID) string { |
| 27 | + id, secret := randomAPIKeyParts() |
| 28 | + hashed := sha256.Sum256([]byte(secret)) |
| 29 | + |
| 30 | + _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ |
| 31 | + ID: id, |
| 32 | + HashedSecret: hashed[:], |
| 33 | + LastUsed: database.Now().AddDate(0, 0, -1), |
| 34 | + ExpiresAt: database.Now().AddDate(0, 0, 1), |
| 35 | + UserID: userID, |
| 36 | + LoginType: database.LoginTypePassword, |
| 37 | + Scope: database.APIKeyScopeAll, |
| 38 | + }) |
| 39 | + require.NoError(t, err) |
| 40 | + |
| 41 | + return fmt.Sprintf("%s-%s", id, secret) |
| 42 | +} |
| 43 | + |
| 44 | +func randRemoteAddr() string { |
| 45 | + var b [4]byte |
| 46 | + // nolint:gosec |
| 47 | + rand.Read(b[:]) |
| 48 | + // nolint:gosec |
| 49 | + return fmt.Sprintf("%s:%v", net.IP(b[:]).String(), rand.Int31()%(1<<16)) |
| 50 | +} |
| 51 | + |
15 | 52 | func TestRateLimit(t *testing.T) {
|
16 | 53 | t.Parallel()
|
17 |
| - t.Run("NoUser", func(t *testing.T) { |
| 54 | + t.Run("NoUserSucceeds", func(t *testing.T) { |
| 55 | + t.Parallel() |
| 56 | + rtr := chi.NewRouter() |
| 57 | + rtr.Use(httpmw.RateLimit(5, time.Second)) |
| 58 | + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { |
| 59 | + rw.WriteHeader(http.StatusOK) |
| 60 | + }) |
| 61 | + |
| 62 | + require.Eventually(t, func() bool { |
| 63 | + req := httptest.NewRequest("GET", "/", nil) |
| 64 | + rec := httptest.NewRecorder() |
| 65 | + rtr.ServeHTTP(rec, req) |
| 66 | + resp := rec.Result() |
| 67 | + defer resp.Body.Close() |
| 68 | + return resp.StatusCode == http.StatusTooManyRequests |
| 69 | + }, testutil.WaitShort, testutil.IntervalFast) |
| 70 | + }) |
| 71 | + |
| 72 | + t.Run("RandomIPs", func(t *testing.T) { |
| 73 | + t.Parallel() |
| 74 | + rtr := chi.NewRouter() |
| 75 | + rtr.Use(httpmw.RateLimit(5, time.Second)) |
| 76 | + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { |
| 77 | + rw.WriteHeader(http.StatusOK) |
| 78 | + }) |
| 79 | + |
| 80 | + require.Never(t, func() bool { |
| 81 | + req := httptest.NewRequest("GET", "/", nil) |
| 82 | + rec := httptest.NewRecorder() |
| 83 | + req.RemoteAddr = randRemoteAddr() |
| 84 | + rtr.ServeHTTP(rec, req) |
| 85 | + resp := rec.Result() |
| 86 | + defer resp.Body.Close() |
| 87 | + return resp.StatusCode == http.StatusTooManyRequests |
| 88 | + }, testutil.WaitShort, testutil.IntervalFast) |
| 89 | + }) |
| 90 | + |
| 91 | + t.Run("RegularUser", func(t *testing.T) { |
18 | 92 | t.Parallel()
|
| 93 | + |
| 94 | + ctx := context.Background() |
| 95 | + |
| 96 | + db := databasefake.New() |
| 97 | + |
| 98 | + u := createUser(ctx, t, db) |
| 99 | + key := insertAPIKey(ctx, t, db, u.ID) |
| 100 | + |
19 | 101 | rtr := chi.NewRouter()
|
20 |
| - rtr.Use(httpmw.RateLimitPerMinute(5)) |
| 102 | + rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ |
| 103 | + DB: db, |
| 104 | + Optional: false, |
| 105 | + })) |
| 106 | + |
| 107 | + rtr.Use(httpmw.RateLimit(5, time.Second)) |
21 | 108 | rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
22 | 109 | rw.WriteHeader(http.StatusOK)
|
23 | 110 | })
|
24 | 111 |
|
| 112 | + // Bypass must fail |
| 113 | + req := httptest.NewRequest("GET", "/", nil) |
| 114 | + req.Header.Set(codersdk.SessionCustomHeader, key) |
| 115 | + req.Header.Set(codersdk.BypassRatelimitHeader, "true") |
| 116 | + rec := httptest.NewRecorder() |
| 117 | + // Assert we're not using IP address. |
| 118 | + req.RemoteAddr = randRemoteAddr() |
| 119 | + rtr.ServeHTTP(rec, req) |
| 120 | + resp := rec.Result() |
| 121 | + defer resp.Body.Close() |
| 122 | + require.Equal(t, http.StatusPreconditionRequired, resp.StatusCode) |
| 123 | + |
25 | 124 | require.Eventually(t, func() bool {
|
26 | 125 | req := httptest.NewRequest("GET", "/", nil)
|
| 126 | + req.Header.Set(codersdk.SessionCustomHeader, key) |
| 127 | + rec := httptest.NewRecorder() |
| 128 | + // Assert we're not using IP address. |
| 129 | + req.RemoteAddr = randRemoteAddr() |
| 130 | + rtr.ServeHTTP(rec, req) |
| 131 | + resp := rec.Result() |
| 132 | + defer resp.Body.Close() |
| 133 | + return resp.StatusCode == http.StatusTooManyRequests |
| 134 | + }, testutil.WaitShort, testutil.IntervalFast) |
| 135 | + }) |
| 136 | + |
| 137 | + t.Run("OwnerBypass", func(t *testing.T) { |
| 138 | + t.Parallel() |
| 139 | + |
| 140 | + ctx := context.Background() |
| 141 | + |
| 142 | + db := databasefake.New() |
| 143 | + |
| 144 | + u := createUser(ctx, t, db, func(u *database.InsertUserParams) { |
| 145 | + u.RBACRoles = []string{rbac.RoleOwner()} |
| 146 | + }) |
| 147 | + |
| 148 | + key := insertAPIKey(ctx, t, db, u.ID) |
| 149 | + |
| 150 | + rtr := chi.NewRouter() |
| 151 | + rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ |
| 152 | + DB: db, |
| 153 | + Optional: false, |
| 154 | + })) |
| 155 | + |
| 156 | + rtr.Use(httpmw.RateLimit(5, time.Second)) |
| 157 | + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { |
| 158 | + rw.WriteHeader(http.StatusOK) |
| 159 | + }) |
| 160 | + |
| 161 | + require.Never(t, func() bool { |
| 162 | + req := httptest.NewRequest("GET", "/", nil) |
| 163 | + req.Header.Set(codersdk.SessionCustomHeader, key) |
| 164 | + req.Header.Set(codersdk.BypassRatelimitHeader, "true") |
27 | 165 | rec := httptest.NewRecorder()
|
| 166 | + // Assert we're not using IP address. |
| 167 | + req.RemoteAddr = randRemoteAddr() |
28 | 168 | rtr.ServeHTTP(rec, req)
|
29 | 169 | resp := rec.Result()
|
30 | 170 | defer resp.Body.Close()
|
|
0 commit comments