diff --git a/coderd/httpmw/provisionerdaemon.go b/coderd/httpmw/provisionerdaemon.go index b2b4e2c04088e..e8a50ae0fc3b3 100644 --- a/coderd/httpmw/provisionerdaemon.go +++ b/coderd/httpmw/provisionerdaemon.go @@ -25,6 +25,9 @@ type ExtractProvisionerAuthConfig struct { PSK string } +// ExtractProvisionerDaemonAuthenticated authenticates a request as a provisioner daemon. +// If the request is not authenticated, the next handler is called unless Optional is true. +// This function currently is tested inside the enterprise package. func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/enterprise/coderd/httpmw/doc.go b/enterprise/coderd/httpmw/doc.go new file mode 100644 index 0000000000000..ef48f0f6e0498 --- /dev/null +++ b/enterprise/coderd/httpmw/doc.go @@ -0,0 +1,5 @@ +// Package httpmw contains middleware for HTTP handlers. +// Currently, the tested middleware is inside the AGPL package. +// As the middleware also contains enterprise-only logic, tests had to be +// moved here. +package httpmw diff --git a/enterprise/coderd/httpmw/provisionerdaemon_test.go b/enterprise/coderd/httpmw/provisionerdaemon_test.go new file mode 100644 index 0000000000000..84da7f546fa35 --- /dev/null +++ b/enterprise/coderd/httpmw/provisionerdaemon_test.go @@ -0,0 +1,290 @@ +package httpmw_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" + "github.com/coder/coder/v2/enterprise/coderd/license" + "github.com/coder/coder/v2/testutil" +) + +func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { + const ( + //nolint:gosec // test key generated by test + functionalKey = "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4" + ) + t.Parallel() + + tests := []struct { + name string + opts httpmw.ExtractProvisionerAuthConfig + expectedStatusCode int + expectedResponseMessage string + provisionerKey string + provisionerPSK string + }{ + { + name: "NoKeyProvided_Optional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: true, + }, + expectedStatusCode: http.StatusOK, + }, + { + name: "NoKeyProvided_NotOptional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: false, + }, + expectedStatusCode: http.StatusUnauthorized, + expectedResponseMessage: "provisioner daemon key required", + }, + { + name: "ProvisionerKeyAndPSKProvided_NotOptional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: false, + }, + provisionerKey: "key", + provisionerPSK: "psk", + expectedStatusCode: http.StatusBadRequest, + expectedResponseMessage: "provisioner daemon key and psk provided, but only one is allowed", + }, + { + name: "ProvisionerKeyAndPSKProvided_Optional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: true, + }, + provisionerKey: "key", + expectedStatusCode: http.StatusOK, + }, + { + name: "InvalidProvisionerKey_NotOptional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: false, + }, + provisionerKey: "invalid", + expectedStatusCode: http.StatusBadRequest, + expectedResponseMessage: "provisioner daemon key invalid", + }, + { + name: "InvalidProvisionerKey_Optional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: true, + }, + provisionerKey: "invalid", + expectedStatusCode: http.StatusOK, + }, + { + name: "InvalidProvisionerPSK_NotOptional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: false, + PSK: "psk", + }, + provisionerPSK: "invalid", + expectedStatusCode: http.StatusUnauthorized, + expectedResponseMessage: "provisioner daemon psk invalid", + }, + { + name: "InvalidProvisionerPSK_Optional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: true, + PSK: "psk", + }, + provisionerPSK: "invalid", + expectedStatusCode: http.StatusOK, + }, + { + name: "ValidProvisionerPSK_NotOptional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: false, + PSK: "ThisIsAValidPSK", + }, + provisionerPSK: "ThisIsAValidPSK", + expectedStatusCode: http.StatusOK, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + routeCtx := chi.NewRouteContext() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) + res := httptest.NewRecorder() + + if test.provisionerKey != "" { + r.Header.Set(codersdk.ProvisionerDaemonKey, test.provisionerKey) + } + if test.provisionerPSK != "" { + r.Header.Set(codersdk.ProvisionerDaemonPSK, test.provisionerPSK) + } + + httpmw.ExtractProvisionerDaemonAuthenticated(test.opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(res, r) + + //nolint:bodyclose + require.Equal(t, test.expectedStatusCode, res.Result().StatusCode) + if test.expectedResponseMessage != "" { + require.Contains(t, res.Body.String(), test.expectedResponseMessage) + } + }) + } + + t.Run("ProvisionerKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + }, + }, + }) + // nolint:gocritic // test + key, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{ + Name: "dont-TEST-me", + }) + require.NoError(t, err) + + routeCtx := chi.NewRouteContext() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) + res := httptest.NewRecorder() + + r.Header.Set(codersdk.ProvisionerDaemonKey, key.Key) + + httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ + DB: db, + Optional: false, + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(res, r) + + //nolint:bodyclose + require.Equal(t, http.StatusOK, res.Result().StatusCode) + }) + + t.Run("ProvisionerKey_NotFound", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + }, + }, + }) + // nolint:gocritic // test + _, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{ + Name: "dont-TEST-me", + }) + require.NoError(t, err) + + routeCtx := chi.NewRouteContext() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) + res := httptest.NewRecorder() + + //nolint:gosec // test key generated by test + pkey := "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4" + r.Header.Set(codersdk.ProvisionerDaemonKey, pkey) + + httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ + DB: db, + Optional: false, + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(res, r) + + //nolint:bodyclose + require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) + require.Contains(t, res.Body.String(), "provisioner daemon key invalid") + }) + + t.Run("ProvisionerKey_CompareFail", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockDB := dbmock.NewMockStore(ctrl) + + gomock.InOrder( + mockDB.EXPECT().GetProvisionerKeyByHashedSecret(gomock.Any(), gomock.Any()).Times(1).Return(database.ProvisionerKey{ + ID: uuid.New(), + HashedSecret: []byte("hashedSecret"), + }, nil), + ) + + routeCtx := chi.NewRouteContext() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) + res := httptest.NewRecorder() + + r.Header.Set(codersdk.ProvisionerDaemonKey, functionalKey) + + httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ + DB: mockDB, + Optional: false, + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(res, r) + + //nolint:bodyclose + require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) + require.Contains(t, res.Body.String(), "provisioner daemon key invalid") + }) + + t.Run("ProvisionerKey_DBError", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockDB := dbmock.NewMockStore(ctrl) + + gomock.InOrder( + mockDB.EXPECT().GetProvisionerKeyByHashedSecret(gomock.Any(), gomock.Any()).Times(1).Return(database.ProvisionerKey{}, xerrors.New("error")), + ) + + routeCtx := chi.NewRouteContext() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) + res := httptest.NewRecorder() + + //nolint:gosec // test key generated by test + r.Header.Set(codersdk.ProvisionerDaemonKey, functionalKey) + + httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ + DB: mockDB, + Optional: false, + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(res, r) + + //nolint:bodyclose + require.Equal(t, http.StatusInternalServerError, res.Result().StatusCode) + require.Contains(t, res.Body.String(), "get provisioner daemon key") + }) +}