From ad4f901d145576b83f040cf4b49e1068902ad5a0 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 26 Aug 2024 10:32:20 -0500 Subject: [PATCH 1/2] chore: refactor entitlements to keep it in just the options Duplicating the reference did not feel valuable, just confusing --- enterprise/coderd/coderd.go | 26 ++++++++++++-------------- enterprise/coderd/coderd_test.go | 6 +++++- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 066bea50b2758..92225f3e44ecb 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -75,6 +75,9 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { // from when an additional replica was started. options.ReplicaErrorGracePeriod = time.Minute } + if options.Entitlements == nil { + options.Entitlements = entitlements.New() + } ctx, cancelFunc := context.WithCancel(ctx) @@ -105,25 +108,22 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { return nil, xerrors.Errorf("init database encryption: %w", err) } - entitlementsSet := entitlements.New() options.Database = cryptDB api := &API{ - ctx: ctx, - cancel: cancelFunc, - Options: options, - entitlements: entitlementsSet, + ctx: ctx, + cancel: cancelFunc, + Options: options, provisionerDaemonAuth: &provisionerDaemonAuth{ psk: options.ProvisionerDaemonPSK, authorizer: options.Authorizer, db: options.Database, }, licenseMetricsCollector: &license.MetricsCollector{ - Entitlements: entitlementsSet, + Entitlements: options.Entitlements, }, } // This must happen before coderd initialization! options.PostAuthAdditionalHeadersFunc = api.writeEntitlementWarningsHeader - options.Options.Entitlements = api.entitlements api.AGPL = coderd.New(options.Options) defer func() { if err != nil { @@ -561,8 +561,6 @@ type API struct { // ProxyHealth checks the reachability of all workspace proxies. ProxyHealth *proxyhealth.ProxyHealth - entitlements *entitlements.Set - provisionerDaemonAuth *provisionerDaemonAuth licenseMetricsCollector *license.MetricsCollector @@ -595,7 +593,7 @@ func (api *API) writeEntitlementWarningsHeader(a rbac.Subject, header http.Heade return } - api.entitlements.WriteEntitlementWarningHeaders(header) + api.Entitlements.WriteEntitlementWarningHeaders(header) } func (api *API) Close() error { @@ -658,7 +656,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { // // We don't simply append to entitlement.Errors since we don't want any // enterprise features enabled. - api.entitlements.Update(func(entitlements *codersdk.Entitlements) { + api.Entitlements.Update(func(entitlements *codersdk.Entitlements) { entitlements.Errors = []string{ "License requires telemetry but telemetry is disabled", } @@ -669,7 +667,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { } featureChanged := func(featureName codersdk.FeatureName) (initial, changed, enabled bool) { - return api.entitlements.FeatureChanged(featureName, reloadedEntitlements.Features[featureName]) + return api.Entitlements.FeatureChanged(featureName, reloadedEntitlements.Features[featureName]) } shouldUpdate := func(initial, changed, enabled bool) bool { @@ -835,7 +833,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { } reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption - api.entitlements.Replace(reloadedEntitlements) + api.Entitlements.Replace(reloadedEntitlements) return nil } @@ -1015,7 +1013,7 @@ func derpMapper(logger slog.Logger, proxyHealth *proxyhealth.ProxyHealth) func(* // @Router /entitlements [get] func (api *API) serveEntitlements(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - httpapi.Write(ctx, rw, http.StatusOK, api.entitlements.AsJSON()) + httpapi.Write(ctx, rw, http.StatusOK, api.Entitlements.AsJSON()) } func (api *API) runEntitlementsLoop(ctx context.Context) { diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index 5183a1d4f6a21..985b264979134 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -3,6 +3,7 @@ package coderd_test import ( "bytes" "context" + "fmt" "net/http" "reflect" "strings" @@ -46,7 +47,7 @@ func TestEntitlements(t *testing.T) { t.Parallel() t.Run("NoLicense", func(t *testing.T) { t.Parallel() - adminClient, adminUser := coderdenttest.New(t, &coderdenttest.Options{ + adminClient, _, api, adminUser := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ DontAddLicense: true, }) anotherClient, _ := coderdtest.CreateAnotherUser(t, adminClient, adminUser.OrganizationID) @@ -54,6 +55,9 @@ func TestEntitlements(t *testing.T) { require.NoError(t, err) require.False(t, res.HasLicense) require.Empty(t, res.Warnings) + + // Ensure the entitlements are the same reference + require.Equal(t, fmt.Sprintf("%p", api.Entitlements), fmt.Sprintf("%p", api.AGPL.Entitlements)) }) t.Run("FullLicense", func(t *testing.T) { // PGCoordinator requires a real postgres From b968ea3b84e8579297327b734ef73df4c1335353 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 26 Aug 2024 11:00:18 -0500 Subject: [PATCH 2/2] fix compile issues --- enterprise/coderd/jfrog.go | 2 +- enterprise/coderd/licenses.go | 2 +- enterprise/coderd/provisionerdaemons.go | 2 +- enterprise/coderd/scim.go | 2 +- enterprise/coderd/templates.go | 2 +- enterprise/coderd/userauth.go | 4 ++-- enterprise/coderd/users.go | 2 +- enterprise/coderd/workspaceagents.go | 2 +- enterprise/coderd/workspacequota.go | 2 +- 9 files changed, 10 insertions(+), 10 deletions(-) diff --git a/enterprise/coderd/jfrog.go b/enterprise/coderd/jfrog.go index e1afe473c2367..f176f48960c0e 100644 --- a/enterprise/coderd/jfrog.go +++ b/enterprise/coderd/jfrog.go @@ -107,7 +107,7 @@ func (api *API) jfrogEnabledMW(next http.Handler) http.Handler { // This doesn't actually use the external auth feature but we want // to lock this behind an enterprise license and it's somewhat // related to external auth (in that it is JFrog integration). - if !api.entitlements.Enabled(codersdk.FeatureMultipleExternalAuth) { + if !api.Entitlements.Enabled(codersdk.FeatureMultipleExternalAuth) { httpapi.RouteNotFound(rw) return } diff --git a/enterprise/coderd/licenses.go b/enterprise/coderd/licenses.go index 7db217234c25b..b3f38a8ca5f8d 100644 --- a/enterprise/coderd/licenses.go +++ b/enterprise/coderd/licenses.go @@ -189,7 +189,7 @@ func (api *API) postRefreshEntitlements(rw http.ResponseWriter, r *http.Request) // Prevent abuse by limiting how often we allow a forced refresh. now := time.Now() - if ok, wait := api.entitlements.AllowRefresh(now); !ok { + if ok, wait := api.Entitlements.AllowRefresh(now); !ok { rw.Header().Set("Retry-After", strconv.Itoa(int(wait.Seconds()))) httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: fmt.Sprintf("Entitlements already recently refreshed, please wait %d seconds to force a new refresh", int(wait.Seconds())), diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index 52836da237e23..10387eaf99b0c 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -39,7 +39,7 @@ import ( func (api *API) provisionerDaemonsEnabledMW(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - if !api.entitlements.Enabled(codersdk.FeatureExternalProvisionerDaemons) { + if !api.Entitlements.Enabled(codersdk.FeatureExternalProvisionerDaemons) { httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{ Message: "External provisioner daemons is an Enterprise feature. Contact sales!", }) diff --git a/enterprise/coderd/scim.go b/enterprise/coderd/scim.go index fdf478571f6f2..0e777111819b9 100644 --- a/enterprise/coderd/scim.go +++ b/enterprise/coderd/scim.go @@ -25,7 +25,7 @@ import ( func (api *API) scimEnabledMW(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - if !api.entitlements.Enabled(codersdk.FeatureSCIM) { + if !api.Entitlements.Enabled(codersdk.FeatureSCIM) { httpapi.RouteNotFound(rw) return } diff --git a/enterprise/coderd/templates.go b/enterprise/coderd/templates.go index bd0b803cb9c97..e9dc5ea638fff 100644 --- a/enterprise/coderd/templates.go +++ b/enterprise/coderd/templates.go @@ -349,7 +349,7 @@ func (api *API) RequireFeatureMW(feat codersdk.FeatureName) func(http.Handler) h return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { // Entitlement must be enabled. - if !api.entitlements.Enabled(feat) { + if !api.Entitlements.Enabled(feat) { licenseType := "a Premium" if feat.Enterprise() { licenseType = "an Enterprise" diff --git a/enterprise/coderd/userauth.go b/enterprise/coderd/userauth.go index 5c972515b789c..65c4a3473f3f7 100644 --- a/enterprise/coderd/userauth.go +++ b/enterprise/coderd/userauth.go @@ -14,7 +14,7 @@ import ( // nolint: revive func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error { - if !api.entitlements.Enabled(codersdk.FeatureTemplateRBAC) { + if !api.Entitlements.Enabled(codersdk.FeatureTemplateRBAC) { return nil } @@ -78,7 +78,7 @@ func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db databa } func (api *API) setUserSiteRoles(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, roles []string) error { - if !api.entitlements.Enabled(codersdk.FeatureUserRoleManagement) { + if !api.Entitlements.Enabled(codersdk.FeatureUserRoleManagement) { logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise entitlement, roles left unchanged", slog.F("user_id", userID), slog.F("roles", roles), ) diff --git a/enterprise/coderd/users.go b/enterprise/coderd/users.go index 808f91140f176..246dfde93368b 100644 --- a/enterprise/coderd/users.go +++ b/enterprise/coderd/users.go @@ -18,7 +18,7 @@ const TimeFormatHHMM = "15:04" func (api *API) autostopRequirementEnabledMW(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - feature, ok := api.entitlements.Feature(codersdk.FeatureAdvancedTemplateScheduling) + feature, ok := api.Entitlements.Feature(codersdk.FeatureAdvancedTemplateScheduling) if !ok || !feature.Entitlement.Entitled() { httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{ Message: "Advanced template scheduling (and user quiet hours schedule) is an Enterprise feature. Contact sales!", diff --git a/enterprise/coderd/workspaceagents.go b/enterprise/coderd/workspaceagents.go index baf4a9f4a9340..3223151425630 100644 --- a/enterprise/coderd/workspaceagents.go +++ b/enterprise/coderd/workspaceagents.go @@ -9,7 +9,7 @@ import ( ) func (api *API) shouldBlockNonBrowserConnections(rw http.ResponseWriter) bool { - if api.entitlements.Enabled(codersdk.FeatureBrowserOnly) { + if api.Entitlements.Enabled(codersdk.FeatureBrowserOnly) { httpapi.Write(context.Background(), rw, http.StatusConflict, codersdk.Response{ Message: "Non-browser connections are disabled for your deployment.", }) diff --git a/enterprise/coderd/workspacequota.go b/enterprise/coderd/workspacequota.go index da6546687d84c..8178f6304a947 100644 --- a/enterprise/coderd/workspacequota.go +++ b/enterprise/coderd/workspacequota.go @@ -155,7 +155,7 @@ func (api *API) workspaceQuota(rw http.ResponseWriter, r *http.Request) { user = httpmw.UserParam(r) ) - licensed := api.entitlements.Enabled(codersdk.FeatureTemplateRBAC) + licensed := api.Entitlements.Enabled(codersdk.FeatureTemplateRBAC) // There are no groups and thus no allowance if RBAC isn't licensed. var quotaAllowance int64 = -1