diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 5cc235fbdacb9..90c9c386628f1 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -775,7 +775,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { if f.hookWellKnown != nil { err := f.hookWellKnown(r, &cpy) if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) + httpError(rw, http.StatusInternalServerError, err) return } } @@ -792,7 +792,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { clientID := r.URL.Query().Get("client_id") if !assert.Equal(t, f.clientID, clientID, "unexpected client_id") { - http.Error(rw, "invalid client_id", http.StatusBadRequest) + httpError(rw, http.StatusBadRequest, xerrors.New("invalid client_id")) return } @@ -818,7 +818,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { err := f.hookValidRedirectURL(redirectURI) if err != nil { t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error()) - http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) + httpError(rw, http.StatusBadRequest, xerrors.Errorf("invalid redirect_uri: %w", err)) return } @@ -853,7 +853,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { )...) 
if err != nil { - http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) + httpError(rw, http.StatusBadRequest, xerrors.Errorf("invalid token request: %w", err)) return } getEmail := func(claims jwt.MapClaims) string { @@ -914,7 +914,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { claims = idTokenClaims err := f.hookOnRefresh(getEmail(claims)) if err != nil { - http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) + httpError(rw, http.StatusBadRequest, xerrors.Errorf("refresh hook blocked refresh: %w", err)) return } @@ -1036,7 +1036,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { claims, err := f.hookUserInfo(email) if err != nil { - http.Error(rw, fmt.Sprintf("user info hook returned error: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err)) + httpError(rw, http.StatusBadRequest, xerrors.Errorf("user info hook returned error: %w", err)) return } _ = json.NewEncoder(rw).Encode(claims) @@ -1499,13 +1499,33 @@ func slogRequestFields(r *http.Request) []any { } } -func httpErrorCode(defaultCode int, err error) int { - var statusErr statusHookError +// httpError writes err to rw. A statusHookError overrides defaultCode; an oauth2.RetrieveError is sent as a form-encoded OAuth2 error body (using its response status when one is set); any other error is sent as plain text.
+func httpError(rw http.ResponseWriter, defaultCode int, err error) { status := defaultCode + + var statusErr statusHookError if errors.As(err, &statusErr) { status = statusErr.HTTPStatusCode } - return status + + var oauthErr *oauth2.RetrieveError + if errors.As(err, &oauthErr) { + if oauthErr.Response != nil && oauthErr.Response.StatusCode != 0 { + status = oauthErr.Response.StatusCode + } + + rw.Header().Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8") + form := url.Values{ + "error": {oauthErr.ErrorCode}, + "error_description": {oauthErr.ErrorDescription}, + "error_uri": {oauthErr.ErrorURI}, + } + rw.WriteHeader(status) + _, _ = rw.Write([]byte(form.Encode())) + return + } + + http.Error(rw, err.Error(), status) } type fakeRoundTripper struct { diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 4845ff22288fe..669ab42546777 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -3319,6 +3319,13 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg) } +func (q *querier) RemoveRefreshToken(ctx context.Context, arg database.RemoveRefreshTokenParams) error { + fetch := func(ctx context.Context, arg database.RemoveRefreshTokenParams) (database.ExternalAuthLink, error) { + return q.db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{UserID: arg.UserID, ProviderID: arg.ProviderID}) + } + return fetchAndExec(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.RemoveRefreshToken)(ctx, arg) +} + func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error { // This is a system function to clear user groups in group sync.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 2eb75f8b738c4..6570b51f263b8 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1269,6 +1269,14 @@ func (s *MethodTestSuite) TestUser() { UserID: u.ID, }).Asserts(u, policy.ActionUpdatePersonal) })) + s.Run("RemoveRefreshToken", s.Subtest(func(db database.Store, check *expects) { + link := dbgen.ExternalAuthLink(s.T(), db, database.ExternalAuthLink{}) + check.Args(database.RemoveRefreshTokenParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + UpdatedAt: link.UpdatedAt, + }).Asserts(rbac.ResourceUserObject(link.UserID), policy.ActionUpdatePersonal) + })) s.Run("UpdateExternalAuthLink", s.Subtest(func(db database.Store, check *expects) { link := dbgen.ExternalAuthLink(s.T(), db, database.ExternalAuthLink{}) check.Args(database.UpdateExternalAuthLinkParams{ diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index aed57e9284b3a..1891d62510351 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -8512,6 +8512,29 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg return database.WorkspaceProxy{}, sql.ErrNoRows } +func (q *FakeQuerier) RemoveRefreshToken(_ context.Context, arg database.RemoveRefreshTokenParams) error { + if err := validateDatabaseType(arg); err != nil { + return err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + for index, gitAuthLink := range q.externalAuthLinks { + if gitAuthLink.ProviderID != arg.ProviderID { + continue + } + if gitAuthLink.UserID != arg.UserID { + continue + } + gitAuthLink.UpdatedAt = arg.UpdatedAt + gitAuthLink.OAuthRefreshToken = "" + q.externalAuthLinks[index] = gitAuthLink + + return nil + } + return sql.ErrNoRows +} + func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID 
uuid.UUID) error { q.mutex.Lock() defer q.mutex.Unlock() diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 32d3cce658525..4b3431199ec48 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -2086,6 +2086,13 @@ func (m queryMetricsStore) RegisterWorkspaceProxy(ctx context.Context, arg datab return proxy, err } +func (m queryMetricsStore) RemoveRefreshToken(ctx context.Context, arg database.RemoveRefreshTokenParams) error { + start := time.Now() + r0 := m.s.RemoveRefreshToken(ctx, arg) + m.queryLatencies.WithLabelValues("RemoveRefreshToken").Observe(time.Since(start).Seconds()) + return r0 +} + func (m queryMetricsStore) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error { start := time.Now() r0 := m.s.RemoveUserFromAllGroups(ctx, userID) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index d6c34411f8208..28571a824e8c3 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -4448,6 +4448,20 @@ func (mr *MockStoreMockRecorder) RegisterWorkspaceProxy(arg0, arg1 any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterWorkspaceProxy", reflect.TypeOf((*MockStore)(nil).RegisterWorkspaceProxy), arg0, arg1) } +// RemoveRefreshToken mocks base method. +func (m *MockStore) RemoveRefreshToken(arg0 context.Context, arg1 database.RemoveRefreshTokenParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveRefreshToken", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveRefreshToken indicates an expected call of RemoveRefreshToken. +func (mr *MockStoreMockRecorder) RemoveRefreshToken(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveRefreshToken", reflect.TypeOf((*MockStore)(nil).RemoveRefreshToken), arg0, arg1) +} + // RemoveUserFromAllGroups mocks base method. 
func (m *MockStore) RemoveUserFromAllGroups(arg0 context.Context, arg1 uuid.UUID) error { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 49ba6fbf8496a..c19851f5d067f 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -423,6 +423,10 @@ type sqlcQuerier interface { OrganizationMembers(ctx context.Context, arg OrganizationMembersParams) ([]OrganizationMembersRow, error) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error) + // Removing the refresh token disables the refresh behavior for a given + // auth token. If a refresh token is marked invalid, it is better to remove it + // than continually attempt to refresh the token. + RemoveRefreshToken(ctx context.Context, arg RemoveRefreshTokenParams) error RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 09dd4c1fbc488..2bec1b5188b36 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1194,6 +1194,29 @@ func (q *sqlQuerier) InsertExternalAuthLink(ctx context.Context, arg InsertExter return i, err } +const removeRefreshToken = `-- name: RemoveRefreshToken :exec +UPDATE + external_auth_links +SET + oauth_refresh_token = '', + updated_at = $1 +WHERE provider_id = $2 AND user_id = $3 +` + +type RemoveRefreshTokenParams struct { + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + ProviderID string `db:"provider_id" json:"provider_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` +} + +// Removing the refresh token disables the refresh behavior for a given +// auth token.
If a refresh token is marked invalid, it is better to remove it +// than continually attempt to refresh the token. +func (q *sqlQuerier) RemoveRefreshToken(ctx context.Context, arg RemoveRefreshTokenParams) error { + _, err := q.db.ExecContext(ctx, removeRefreshToken, arg.UpdatedAt, arg.ProviderID, arg.UserID) + return err +} + const updateExternalAuthLink = `-- name: UpdateExternalAuthLink :one UPDATE external_auth_links SET updated_at = $3, diff --git a/coderd/database/queries/externalauth.sql b/coderd/database/queries/externalauth.sql index 8470c44ea9125..cd223bd792a2a 100644 --- a/coderd/database/queries/externalauth.sql +++ b/coderd/database/queries/externalauth.sql @@ -42,3 +42,14 @@ UPDATE external_auth_links SET oauth_expiry = $8, oauth_extra = $9 WHERE provider_id = $1 AND user_id = $2 RETURNING *; + +-- name: RemoveRefreshToken :exec +-- Removing the refresh token disables the refresh behavior for a given +-- auth token. If a refresh token is marked invalid, it is better to remove it +-- than continually attempt to refresh the token. +UPDATE + external_auth_links +SET + oauth_refresh_token = '', + updated_at = @updated_at +WHERE provider_id = @provider_id AND user_id = @user_id; diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index 2ad2761e80b46..1ce850c9cec03 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -118,7 +118,7 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu // This is true for github, which has no expiry. !externalAuthLink.OAuthExpiry.IsZero() && externalAuthLink.OAuthExpiry.Before(dbtime.Now()) { - return externalAuthLink, InvalidTokenError("token expired, refreshing is disabled") + return externalAuthLink, InvalidTokenError("token expired, refreshing is either disabled or refreshing failed and will not be retried") } // This is additional defensive programming.
Because TokenSource is an interface, @@ -130,16 +130,41 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu refreshToken = "" } - token, err := c.TokenSource(ctx, &oauth2.Token{ + existingToken := &oauth2.Token{ AccessToken: externalAuthLink.OAuthAccessToken, RefreshToken: refreshToken, Expiry: externalAuthLink.OAuthExpiry, - }).Token() + } + + token, err := c.TokenSource(ctx, existingToken).Token() if err != nil { - // Even if the token fails to be obtained, do not return the error as an error. + // TokenSource can fail for numerous reasons. If it fails because of + // a bad refresh token, then the refresh token is invalid, and we should + // get rid of it. Keeping it around will cause additional refresh + // attempts that will fail and cost us api rate limits. + if isFailedRefresh(existingToken, err) { + dbExecErr := db.RemoveRefreshToken(ctx, database.RemoveRefreshTokenParams{ + UpdatedAt: dbtime.Now(), + ProviderID: externalAuthLink.ProviderID, + UserID: externalAuthLink.UserID, + }) + if dbExecErr != nil { + // This error should be rare. + return externalAuthLink, InvalidTokenError(fmt.Sprintf("refresh token failed: %q, then removing refresh token failed: %q", err.Error(), dbExecErr.Error())) + } + // The refresh token was cleared + externalAuthLink.OAuthRefreshToken = "" + } + + // Unfortunately, we have to match exactly on the error message string. + // Improve the error message to account for refresh tokens being deleted + // on our end when they are invalid. + if err.Error() == "oauth2: token expired and refresh token is not set" { + return externalAuthLink, InvalidTokenError("token expired, refreshing is either disabled or refreshing failed and will not be retried") + } + // TokenSource(...).Token() will always return the current token if the token is not expired. - // If it is expired, it will attempt to refresh the token, and if it cannot, it will fail with - // an error. This error is a reason the token is invalid.
+ // So this error is only returned if a refresh of the token failed. return externalAuthLink, InvalidTokenError(fmt.Sprintf("refresh token: %s", err.Error())) } @@ -973,3 +998,55 @@ func IsGithubDotComURL(str string) bool { } return ghURL.Host == "github.com" } + +// isFailedRefresh returns true if the error returned by the TokenSource.Token() +// is due to a failed refresh, the cause being the refresh token itself. +// If this returns true, no amount of retries will fix the issue. +// +// Notes: Provider responses are not uniform. Here are some examples: +// Github +// - Returns a 200 with Code "bad_refresh_token" and Description "The refresh token passed is incorrect or expired." +// +// Gitea [TODO: get an expired refresh token] +// - [Bad JWT] Returns 400 with Code "unauthorized_client" and Description "unable to parse refresh token" +// +// Gitlab +// - Returns 400 with Code "invalid_grant" and Description "The provided authorization grant is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client." +func isFailedRefresh(existingToken *oauth2.Token, err error) bool { + if existingToken.RefreshToken == "" { + return false // No refresh token, so this cannot be refreshed + } + + if existingToken.Valid() { + return false // Valid tokens are not refreshed + } + + var oauthErr *oauth2.RetrieveError + if xerrors.As(err, &oauthErr) { + switch oauthErr.ErrorCode { + // Known error codes that indicate a failed refresh. + // 'Spec' means the code is defined in the spec. + case "bad_refresh_token", // Github + "invalid_grant", // Gitlab & Spec + "unauthorized_client", // Gitea & Spec + "unsupported_grant_type": // Spec, refresh not supported + return true + } + + // A RetrieveError constructed without a Response has no status code to inspect. + if oauthErr.Response == nil { + return false + } + + switch oauthErr.Response.StatusCode { + case http.StatusBadRequest, http.StatusUnauthorized, http.StatusForbidden, http.StatusOK: + // Status codes that indicate the request was processed, and rejected.
+ return true + case http.StatusInternalServerError, http.StatusTooManyRequests: + // These do not indicate a failed refresh, but could be a temporary issue. + return false + } + } + + return false +} diff --git a/coderd/externalauth/externalauth_test.go b/coderd/externalauth/externalauth_test.go index fbc1cab4b7091..84bded9856572 100644 --- a/coderd/externalauth/externalauth_test.go +++ b/coderd/externalauth/externalauth_test.go @@ -17,6 +17,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "golang.org/x/oauth2" "golang.org/x/xerrors" @@ -25,6 +26,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/codersdk" @@ -62,7 +64,7 @@ func TestRefreshToken(t *testing.T) { _, err := config.RefreshToken(ctx, nil, link) require.Error(t, err) require.True(t, externalauth.IsInvalidTokenError(err)) - require.Contains(t, err.Error(), "refreshing is disabled") + require.Contains(t, err.Error(), "refreshing is either disabled or refreshing failed") }) // NoRefreshNoExpiry tests that an oauth token without an expiry is always valid. @@ -141,6 +143,73 @@ func TestRefreshToken(t *testing.T) { require.True(t, validated, "token should have been attempted to be validated") }) + // RefreshRetries tests that refresh token retry behavior works as expected. + // If a refresh token fails because the token itself is invalid, no more + // refresh attempts should ever happen. An invalid refresh token does + // not magically become valid at some point in the future.
+ t.Run("RefreshRetries", func(t *testing.T) { + t.Parallel() + + var refreshErr error + + ctrl := gomock.NewController(t) + mDB := dbmock.NewMockStore(ctrl) + + refreshCount := 0 + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithRefresh(func(_ string) error { + refreshCount++ + return refreshErr + }), + // The IDP should not be contacted since the token is expired and + // refresh attempts will fail. + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + t.Error("token was validated, but it was expired and this should never have happened.") + return nil, xerrors.New("should not be called") + }), + }, + ExternalAuthOpt: func(cfg *externalauth.Config) {}, + }) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + // Expire the link + link.OAuthExpiry = expired + + // Make the failure a server internal error. Not related to the token + refreshErr = &oauth2.RetrieveError{ + Response: &http.Response{ + StatusCode: http.StatusInternalServerError, + }, + ErrorCode: "internal_error", + } + _, err := config.RefreshToken(ctx, mDB, link) + require.Error(t, err) + require.True(t, externalauth.IsInvalidTokenError(err)) + require.Equal(t, 1, refreshCount) + + // Try again with a bad refresh token error + // Expect DB call to remove the refresh token + mDB.EXPECT().RemoveRefreshToken(gomock.Any(), gomock.Any()).Return(nil).Times(1) + refreshErr = &oauth2.RetrieveError{ // github error + Response: &http.Response{ + StatusCode: http.StatusOK, + }, + ErrorCode: "bad_refresh_token", + } + _, err = config.RefreshToken(ctx, mDB, link) + require.Error(t, err) + require.True(t, externalauth.IsInvalidTokenError(err)) + require.Equal(t, 2, refreshCount) + + // When the refresh token is empty, no api calls should be made + link.OAuthRefreshToken = "" // mock'd db, so manually set the token to '' + _, err = config.RefreshToken(ctx, mDB, link) + require.Error(t, err) +
require.True(t, externalauth.IsInvalidTokenError(err)) + require.Equal(t, 2, refreshCount) + }) + // ValidateFailure tests if the token is no longer valid with a 401 response. t.Run("ValidateFailure", func(t *testing.T) { t.Parallel()