Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 58f7071

Browse files
authored
fix: make 'NoRefresh' honor unlimited tokens in gitauth (#9472)
* chore: fix NoRefresh to honor unlimited tokens * improve testing coverage of gitauth * refactor rest of gitauth tests
1 parent da0ef92 commit 58f7071

File tree

5 files changed

+354
-112
lines changed

5 files changed

+354
-112
lines changed

coderd/coderdtest/oidctest/idp.go

+64-15
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"crypto/x509"
88
"encoding/json"
99
"encoding/pem"
10+
"errors"
1011
"fmt"
1112
"io"
1213
"net"
@@ -41,7 +42,7 @@ import (
4142
type FakeIDP struct {
4243
issuer string
4344
key *rsa.PrivateKey
44-
provider providerJSON
45+
provider ProviderJSON
4546
handler http.Handler
4647
cfg *oauth2.Config
4748

@@ -66,7 +67,7 @@ type FakeIDP struct {
6667
// IDP -> Application. Almost all IDPs have the concept of
6768
// "Authorized Redirect URLs". This can be used to emulate that.
6869
hookValidRedirectURL func(redirectURL string) error
69-
hookUserInfo func(email string) jwt.MapClaims
70+
hookUserInfo func(email string) (jwt.MapClaims, error)
7071
fakeCoderd func(req *http.Request) (*http.Response, error)
7172
hookOnRefresh func(email string) error
7273
// Custom authentication for the client. This is useful if you want
@@ -75,6 +76,26 @@ type FakeIDP struct {
7576
serve bool
7677
}
7778

79+
func StatusError(code int, err error) error {
80+
return statusHookError{
81+
Err: err,
82+
HTTPStatusCode: code,
83+
}
84+
}
85+
86+
// statusHookError allows a hook to change the returned http status code.
87+
type statusHookError struct {
88+
Err error
89+
HTTPStatusCode int
90+
}
91+
92+
func (s statusHookError) Error() string {
93+
if s.Err == nil {
94+
return ""
95+
}
96+
return s.Err.Error()
97+
}
98+
7899
type FakeIDPOpt func(idp *FakeIDP)
79100

80101
func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) {
@@ -83,9 +104,9 @@ func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeID
83104
}
84105
}
85106

86-
// WithRefreshHook is called when a refresh token is used. The email is
107+
// WithRefresh is called when a refresh token is used. The email is
87108
// the email of the user that is being refreshed assuming the claims are correct.
88-
func WithRefreshHook(hook func(email string) error) func(*FakeIDP) {
109+
func WithRefresh(hook func(email string) error) func(*FakeIDP) {
89110
return func(f *FakeIDP) {
90111
f.hookOnRefresh = hook
91112
}
@@ -108,13 +129,13 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
108129
// every user on the /userinfo endpoint.
109130
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
110131
return func(f *FakeIDP) {
111-
f.hookUserInfo = func(_ string) jwt.MapClaims {
112-
return info
132+
f.hookUserInfo = func(_ string) (jwt.MapClaims, error) {
133+
return info, nil
113134
}
114135
}
115136
}
116137

117-
func WithDynamicUserInfo(userInfoFunc func(email string) jwt.MapClaims) func(*FakeIDP) {
138+
func WithDynamicUserInfo(userInfoFunc func(email string) (jwt.MapClaims, error)) func(*FakeIDP) {
118139
return func(f *FakeIDP) {
119140
f.hookUserInfo = userInfoFunc
120141
}
@@ -160,7 +181,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
160181
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
161182
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
162183
hookOnRefresh: func(_ string) error { return nil },
163-
hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} },
184+
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
164185
hookValidRedirectURL: func(redirectURL string) error { return nil },
165186
}
166187

@@ -181,16 +202,20 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
181202
return idp
182203
}
183204

205+
func (f *FakeIDP) WellknownConfig() ProviderJSON {
206+
return f.provider
207+
}
208+
184209
func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
185210
t.Helper()
186211

187212
u, err := url.Parse(issuer)
188213
require.NoError(t, err, "invalid issuer URL")
189214

190215
f.issuer = issuer
191-
// providerJSON is the JSON representation of the OpenID Connect provider
216+
// ProviderJSON is the JSON representation of the OpenID Connect provider
192217
// These are all the urls that the IDP will respond to.
193-
f.provider = providerJSON{
218+
f.provider = ProviderJSON{
194219
Issuer: issuer,
195220
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
196221
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
@@ -220,6 +245,15 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
220245
return srv
221246
}
222247

248+
// GenerateAuthenticatedToken skips all oauth2 flows, and just generates a
249+
// valid token for some given claims.
250+
func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Token, error) {
251+
state := uuid.NewString()
252+
f.stateToIDTokenClaims.Store(state, claims)
253+
code := f.newCode(state)
254+
return f.cfg.Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code)
255+
}
256+
223257
// Login does the full OIDC flow starting at the "LoginButton".
224258
// The client argument is just to get the URL of the Coder instance.
225259
//
@@ -333,7 +367,8 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
333367
return resp, nil
334368
}
335369

336-
type providerJSON struct {
370+
// ProviderJSON is the .well-known/configuration JSON
371+
type ProviderJSON struct {
337372
Issuer string `json:"issuer"`
338373
AuthURL string `json:"authorization_endpoint"`
339374
TokenURL string `json:"token_endpoint"`
@@ -475,7 +510,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
475510
err := f.hookValidRedirectURL(redirectURI)
476511
if err != nil {
477512
t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error())
478-
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
513+
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
479514
return
480515
}
481516

@@ -501,7 +536,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
501536
slog.F("values", values.Encode()),
502537
)
503538
if err != nil {
504-
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), http.StatusBadRequest)
539+
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
505540
return
506541
}
507542
getEmail := func(claims jwt.MapClaims) string {
@@ -562,7 +597,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
562597
claims = idTokenClaims
563598
err := f.hookOnRefresh(getEmail(claims))
564599
if err != nil {
565-
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), http.StatusBadRequest)
600+
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
566601
return
567602
}
568603

@@ -610,7 +645,12 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
610645
http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest)
611646
return
612647
}
613-
_ = json.NewEncoder(rw).Encode(f.hookUserInfo(email))
648+
claims, err := f.hookUserInfo(email)
649+
if err != nil {
650+
http.Error(rw, fmt.Sprintf("user info hook returned error: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
651+
return
652+
}
653+
_ = json.NewEncoder(rw).Encode(claims)
614654
}))
615655

616656
mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@@ -768,6 +808,15 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co
768808
return cfg
769809
}
770810

811+
func httpErrorCode(defaultCode int, err error) int {
812+
var stautsErr statusHookError
813+
status := defaultCode
814+
if errors.As(err, &stautsErr) {
815+
status = stautsErr.HTTPStatusCode
816+
}
817+
return status
818+
}
819+
771820
type fakeRoundTripper struct {
772821
roundTrip func(req *http.Request) (*http.Response, error)
773822
}

coderd/gitauth/config.go

+22-4
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,30 @@ type Config struct {
6060
}
6161

6262
// RefreshToken automatically refreshes the token if expired and permitted.
63-
// It returns the token and a bool indicating if the token was refreshed.
63+
// It returns the token and a bool indicating if the token is valid.
6464
func (c *Config) RefreshToken(ctx context.Context, db database.Store, gitAuthLink database.GitAuthLink) (database.GitAuthLink, bool, error) {
6565
// If the token is expired and refresh is disabled, we prompt
6666
// the user to authenticate again.
67-
if c.NoRefresh && gitAuthLink.OAuthExpiry.Before(dbtime.Now()) {
67+
if c.NoRefresh &&
68+
// If the time is set to 0, then it should never expire.
69+
// This is true for github, which has no expiry.
70+
!gitAuthLink.OAuthExpiry.IsZero() &&
71+
gitAuthLink.OAuthExpiry.Before(dbtime.Now()) {
6872
return gitAuthLink, false, nil
6973
}
7074

75+
// This is additional defensive programming. Because TokenSource is an interface,
76+
// we cannot be sure that the implementation will treat an 'IsZero' time
77+
// as "not-expired". The default implementation does, but a custom implementation
78+
// might not. Removing the refreshToken will guarantee a refresh will fail.
79+
refreshToken := gitAuthLink.OAuthRefreshToken
80+
if c.NoRefresh {
81+
refreshToken = ""
82+
}
83+
7184
token, err := c.TokenSource(ctx, &oauth2.Token{
7285
AccessToken: gitAuthLink.OAuthAccessToken,
73-
RefreshToken: gitAuthLink.OAuthRefreshToken,
86+
RefreshToken: refreshToken,
7487
Expiry: gitAuthLink.OAuthExpiry,
7588
}).Token()
7689
if err != nil {
@@ -130,8 +143,13 @@ func (c *Config) ValidateToken(ctx context.Context, token string) (bool, *coders
130143
if err != nil {
131144
return false, nil, err
132145
}
146+
147+
cli := http.DefaultClient
148+
if v, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
149+
cli = v
150+
}
133151
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
134-
res, err := http.DefaultClient.Do(req)
152+
res, err := cli.Do(req)
135153
if err != nil {
136154
return false, nil, err
137155
}

0 commit comments

Comments
 (0)