From 41dc49c8217753afce218f50f37d59aa5920f8a3 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 30 Aug 2024 11:11:33 +0400 Subject: [PATCH] fix: allow posting licenses that will be valid in future --- .../coderd/coderdenttest/coderdenttest.go | 9 +++- enterprise/coderd/license/license.go | 41 +++++++++++++++- enterprise/coderd/licenses.go | 32 +++++-------- enterprise/coderd/licenses_test.go | 48 +++++++++++++++++++ 4 files changed, 106 insertions(+), 24 deletions(-) diff --git a/enterprise/coderd/coderdenttest/coderdenttest.go b/enterprise/coderd/coderdenttest/coderdenttest.go index f5bfd05529fdd..1248781d483e4 100644 --- a/enterprise/coderd/coderdenttest/coderdenttest.go +++ b/enterprise/coderd/coderdenttest/coderdenttest.go @@ -174,6 +174,10 @@ type LicenseOptions struct { // ExpiresAt is the time at which the license will hard expire. // ExpiresAt should always be greater then GraceAt. ExpiresAt time.Time + // NotBefore is the time at which the license becomes valid. If set to the + // zero value, the `nbf` claim on the license is set to 1 minute in the + // past. + NotBefore time.Time Features license.Features } @@ -233,13 +237,16 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string { if options.GraceAt.IsZero() { options.GraceAt = time.Now().Add(time.Hour) } + if options.NotBefore.IsZero() { + options.NotBefore = time.Now().Add(-time.Minute) + } c := &license.Claims{ RegisteredClaims: jwt.RegisteredClaims{ ID: uuid.NewString(), Issuer: "test@testing.test", ExpiresAt: jwt.NewNumericDate(options.ExpiresAt), - NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Minute)), + NotBefore: jwt.NewNumericDate(options.NotBefore), IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), }, LicenseExpires: jwt.NewNumericDate(options.GraceAt), diff --git a/enterprise/coderd/license/license.go b/enterprise/coderd/license/license.go index fdb177d753eae..f81606afd66fd 100644 --- a/enterprise/coderd/license/license.go +++ b/enterprise/coderd/license/license.go @@ -287,6 +287,8 @@ var ( ErrInvalidVersion = xerrors.New("license must be version 3") ErrMissingKeyID = xerrors.Errorf("JOSE header must contain %s", HeaderKeyID) ErrMissingLicenseExpires = xerrors.New("license missing license_expires") + ErrMissingExp = xerrors.New("exp claim missing or not parsable") + ErrMultipleIssues = xerrors.New("license has multiple issues; contact support") ) type Features map[codersdk.FeatureName]int64 @@ -336,7 +338,7 @@ func ParseRaw(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, error return nil, xerrors.New("unable to parse Claims") } -// ParseClaims validates a database.License record, and if valid, returns the claims. If +// ParseClaims validates a raw JWT, and if valid, returns the claims. If // unparsable or invalid, it returns an error func ParseClaims(rawJWT string, keys map[string]ed25519.PublicKey) (*Claims, error) { tok, err := jwt.ParseWithClaims( @@ -348,18 +350,53 @@ func ParseClaims(rawJWT string, keys map[string]ed25519.PublicKey) (*Claims, err if err != nil { return nil, err } - if claims, ok := tok.Claims.(*Claims); ok && tok.Valid { + return validateClaims(tok) +} + +func validateClaims(tok *jwt.Token) (*Claims, error) { + if claims, ok := tok.Claims.(*Claims); ok { if claims.Version != uint64(CurrentVersion) { return nil, ErrInvalidVersion } if claims.LicenseExpires == nil { return nil, ErrMissingLicenseExpires } + if claims.ExpiresAt == nil { + return nil, ErrMissingExp + } return claims, nil } return nil, xerrors.New("unable to parse Claims") } +// ParseClaimsIgnoreNbf validates a raw JWT, but ignores `nbf` claim. If otherwise valid, it returns +// the claims. If unparsable or invalid, it returns an error. Ignoring the `nbf` (not before) is +// useful to determine if a JWT _will_ become valid at any point now or in the future. +func ParseClaimsIgnoreNbf(rawJWT string, keys map[string]ed25519.PublicKey) (*Claims, error) { + tok, err := jwt.ParseWithClaims( + rawJWT, + &Claims{}, + keyFunc(keys), + jwt.WithValidMethods(ValidMethods), + ) + var vErr *jwt.ValidationError + if xerrors.As(err, &vErr) { + // zero out the NotValidYet error to check if there were other problems + vErr.Errors = vErr.Errors & (^jwt.ValidationErrorNotValidYet) + if vErr.Errors != 0 { + // There are other errors besides not being valid yet. We _could_ go + // through all the jwt.ValidationError bits and try to work out the + // correct error, but if we get here something very strange is + // going on so let's just return a generic error that says to get in + // touch with our support team. + return nil, ErrMultipleIssues + } + } else if err != nil { + return nil, err + } + return validateClaims(tok) +} + func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, error) { return func(j *jwt.Token) (interface{}, error) { keyID, ok := j.Header[HeaderKeyID].(string) diff --git a/enterprise/coderd/licenses.go b/enterprise/coderd/licenses.go index b3f38a8ca5f8d..8e713886555a5 100644 --- a/enterprise/coderd/licenses.go +++ b/enterprise/coderd/licenses.go @@ -86,25 +86,7 @@ func (api *API) postLicense(rw http.ResponseWriter, r *http.Request) { return } - rawClaims, err := license.ParseRaw(addLicense.License, api.LicenseKeys) - if err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid license", - Detail: err.Error(), - }) - return - } - exp, ok := rawClaims["exp"].(float64) - if !ok { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid license", - Detail: "exp claim missing or not parsable", - }) - return - } - expTime := time.Unix(int64(exp), 0) - - claims, err := license.ParseClaims(addLicense.License, api.LicenseKeys) + claims, err := license.ParseClaimsIgnoreNbf(addLicense.License, api.LicenseKeys) if err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "Invalid license", @@ -134,7 +116,7 @@ func (api *API) postLicense(rw http.ResponseWriter, r *http.Request) { dl, err := api.Database.InsertLicense(ctx, database.InsertLicenseParams{ UploadedAt: dbtime.Now(), JWT: addLicense.License, - Exp: expTime, + Exp: claims.ExpiresAt.Time, UUID: id, }) if err != nil { @@ -160,7 +142,15 @@ func (api *API) postLicense(rw http.ResponseWriter, r *http.Request) { // don't fail the HTTP request, since we did write it successfully to the database } - httpapi.Write(ctx, rw, http.StatusCreated, convertLicense(dl, rawClaims)) + c, err := decodeClaims(dl) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to decode database response", + Detail: err.Error(), + }) + return + } + httpapi.Write(ctx, rw, http.StatusCreated, convertLicense(dl, c)) } // postRefreshEntitlements forces an `updateEntitlements` call and publishes diff --git a/enterprise/coderd/licenses_test.go b/enterprise/coderd/licenses_test.go index c2f7d83fbbd6b..bbd6ef717fe8e 100644 --- a/enterprise/coderd/licenses_test.go +++ b/enterprise/coderd/licenses_test.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "testing" + "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -82,6 +83,53 @@ func TestPostLicense(t *testing.T) { t.Error("expected to get error status 400") } }) + + // Test a license that isn't yet valid, but will be in the future. We should allow this so that + // operators can upload a license ahead of time. + t.Run("NotYet", func(t *testing.T) { + t.Parallel() + client, _ := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true}) + respLic := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ + AccountType: license.AccountTypeSalesforce, + AccountID: "testing", + Features: license.Features{ + codersdk.FeatureAuditLog: 1, + }, + NotBefore: time.Now().Add(time.Hour), + GraceAt: time.Now().Add(2 * time.Hour), + ExpiresAt: time.Now().Add(3 * time.Hour), + }) + assert.GreaterOrEqual(t, respLic.ID, int32(0)) + // just a couple spot checks for sanity + assert.Equal(t, "testing", respLic.Claims["account_id"]) + features, err := respLic.FeaturesClaims() + require.NoError(t, err) + assert.EqualValues(t, 1, features[codersdk.FeatureAuditLog]) + }) + + // Test we still reject a license that isn't valid yet, but has other issues (e.g. expired + // before it starts). + t.Run("NotEver", func(t *testing.T) { + t.Parallel() + client, _ := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true}) + lic := coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + AccountType: license.AccountTypeSalesforce, + AccountID: "testing", + Features: license.Features{ + codersdk.FeatureAuditLog: 1, + }, + NotBefore: time.Now().Add(time.Hour), + GraceAt: time.Now().Add(2 * time.Hour), + ExpiresAt: time.Now().Add(-time.Hour), + }) + _, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{ + License: lic, + }) + errResp := &codersdk.Error{} + require.ErrorAs(t, err, &errResp) + require.Equal(t, http.StatusBadRequest, errResp.StatusCode()) + require.Contains(t, errResp.Detail, license.ErrMultipleIssues.Error()) + }) } func TestGetLicense(t *testing.T) {