Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b3b5b75

Browse files
committed
added unit tests for the prometheus middleware
1 parent d9f3dfa commit b3b5b75

File tree

2 files changed

+96
-1
lines changed

2 files changed

+96
-1
lines changed

coderd/httpmw/prometheus.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
102102

103103
func getRoutePattern(r *http.Request) string {
104104
rctx := chi.RouteContext(r.Context())
105+
if rctx == nil {
106+
return ""
107+
}
108+
105109
if pattern := rctx.RoutePattern(); pattern != "" {
106110
// Pattern is already available
107111
return pattern
@@ -113,7 +117,8 @@ func getRoutePattern(r *http.Request) string {
113117
}
114118

115119
tctx := chi.NewRouteContext()
116-
if !rctx.Routes.Match(tctx, r.Method, routePath) {
120+
routes := rctx.Routes
121+
if routes != nil && routes.Match(tctx, r.Method, routePath) {
117122
// No matching pattern, so just return an empty string.
118123
// It is done to avoid returning a static path for frontend requests.
119124
return ""

coderd/httpmw/prometheus_test.go

+90
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@ import (
88

99
"github.com/go-chi/chi/v5"
1010
"github.com/prometheus/client_golang/prometheus"
11+
cm "github.com/prometheus/client_model/go"
12+
"github.com/stretchr/testify/assert"
1113
"github.com/stretchr/testify/require"
1214

1315
"github.com/coder/coder/v2/coderd/httpmw"
1416
"github.com/coder/coder/v2/coderd/tracing"
17+
"github.com/coder/coder/v2/testutil"
18+
"github.com/coder/websocket"
1519
)
1620

1721
func TestPrometheus(t *testing.T) {
@@ -30,3 +34,89 @@ func TestPrometheus(t *testing.T) {
3034
require.Greater(t, len(metrics), 0)
3135
})
3236
}
37+
38+
func TestPrometheus_Concurrent(t *testing.T) {
39+
t.Parallel()
40+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
41+
defer cancel()
42+
43+
reg := prometheus.NewRegistry()
44+
promMW := httpmw.Prometheus(reg)
45+
46+
// Create a test handler to simulate a WebSocket connection
47+
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
48+
conn, err := websocket.Accept(rw, r, nil)
49+
if !assert.NoError(t, err, "failed to accept websocket") {
50+
return
51+
}
52+
defer conn.Close(websocket.StatusGoingAway, "")
53+
})
54+
55+
wrappedHandler := promMW(testHandler)
56+
57+
r := chi.NewRouter()
58+
r.Use(tracing.StatusWriterMiddleware, promMW)
59+
r.Get("/api/v2/build/{build}/logs", func(rw http.ResponseWriter, r *http.Request) {
60+
wrappedHandler.ServeHTTP(rw, r)
61+
})
62+
63+
srv := httptest.NewServer(r)
64+
defer srv.Close()
65+
// nolint: bodyclose
66+
conn, _, err := websocket.Dial(ctx, srv.URL+"/api/v2/build/1/logs", nil)
67+
require.NoError(t, err, "failed to dial WebSocket")
68+
defer conn.Close(websocket.StatusNormalClosure, "")
69+
70+
metrics, err := reg.Gather()
71+
require.NoError(t, err)
72+
require.Greater(t, len(metrics), 0)
73+
metricLabels := getMetricLabels(metrics)
74+
75+
concurrentWebsockets, ok := metricLabels["coderd_api_concurrent_websockets"]
76+
require.True(t, ok, "coderd_api_concurrent_websockets metric not found")
77+
require.Equal(t, "/api/v2/build/{build}/logs", concurrentWebsockets["path"])
78+
}
79+
80+
func TestGetRoutePattern_UserRoute(t *testing.T) {
81+
t.Parallel()
82+
reg := prometheus.NewRegistry()
83+
promMW := httpmw.Prometheus(reg)
84+
85+
r := chi.NewRouter()
86+
r.With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {})
87+
88+
req := httptest.NewRequest("GET", "/api/v2/users/john", nil)
89+
90+
sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
91+
92+
r.ServeHTTP(sw, req)
93+
94+
metrics, err := reg.Gather()
95+
require.NoError(t, err)
96+
require.Greater(t, len(metrics), 0)
97+
metricLabels := getMetricLabels(metrics)
98+
99+
reqProcessed, ok := metricLabels["coderd_api_requests_processed_total"]
100+
require.True(t, ok, "coderd_api_requests_processed_total metric not found")
101+
require.Equal(t, "/api/v2/users/{user}", reqProcessed["path"])
102+
require.Equal(t, "GET", reqProcessed["method"])
103+
104+
concurrentRequests, ok := metricLabels["coderd_api_concurrent_requests"]
105+
require.True(t, ok, "coderd_api_concurrent_requests metric not found")
106+
require.Equal(t, "/api/v2/users/{user}", concurrentRequests["path"])
107+
require.Equal(t, "GET", concurrentRequests["method"])
108+
}
109+
110+
func getMetricLabels(metrics []*cm.MetricFamily) map[string]map[string]string {
111+
metricLabels := map[string]map[string]string{}
112+
for _, metricFamily := range metrics {
113+
metricName := metricFamily.GetName()
114+
metricLabels[metricName] = map[string]string{}
115+
for _, metric := range metricFamily.GetMetric() {
116+
for _, labelPair := range metric.GetLabel() {
117+
metricLabels[metricName][labelPair.GetName()] = labelPair.GetValue()
118+
}
119+
}
120+
}
121+
return metricLabels
122+
}

0 commit comments

Comments
 (0)