77 "crypto/x509"
88 "encoding/json"
99 "encoding/pem"
10+ "errors"
1011 "fmt"
1112 "io"
1213 "net"
@@ -41,7 +42,7 @@ import (
4142type FakeIDP struct {
4243 issuer string
4344 key * rsa.PrivateKey
44- provider providerJSON
45+ provider ProviderJSON
4546 handler http.Handler
4647 cfg * oauth2.Config
4748
@@ -66,7 +67,7 @@ type FakeIDP struct {
6667 // IDP -> Application. Almost all IDPs have the concept of
6768 // "Authorized Redirect URLs". This can be used to emulate that.
6869 hookValidRedirectURL func (redirectURL string ) error
69- hookUserInfo func (email string ) jwt.MapClaims
70+ hookUserInfo func (email string ) ( jwt.MapClaims , error )
7071 fakeCoderd func (req * http.Request ) (* http.Response , error )
7172 hookOnRefresh func (email string ) error
7273 // Custom authentication for the client. This is useful if you want
@@ -75,6 +76,26 @@ type FakeIDP struct {
7576 serve bool
7677}
7778
79+ func StatusError (code int , err error ) error {
80+ return statusHookError {
81+ Err : err ,
82+ HTTPStatusCode : code ,
83+ }
84+ }
85+
86+ // statusHookError allows a hook to change the returned http status code.
87+ type statusHookError struct {
88+ Err error
89+ HTTPStatusCode int
90+ }
91+
92+ func (s statusHookError ) Error () string {
93+ if s .Err == nil {
94+ return ""
95+ }
96+ return s .Err .Error ()
97+ }
98+
7899type FakeIDPOpt func (idp * FakeIDP )
79100
80101func WithAuthorizedRedirectURL (hook func (redirectURL string ) error ) func (* FakeIDP ) {
@@ -83,9 +104,9 @@ func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeID
83104 }
84105}
85106
86- // WithRefreshHook is called when a refresh token is used. The email is
107+ // WithRefresh is called when a refresh token is used. The email is
87108// the email of the user that is being refreshed assuming the claims are correct.
88- func WithRefreshHook (hook func (email string ) error ) func (* FakeIDP ) {
109+ func WithRefresh (hook func (email string ) error ) func (* FakeIDP ) {
89110 return func (f * FakeIDP ) {
90111 f .hookOnRefresh = hook
91112 }
@@ -108,13 +129,13 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
108129// every user on the /userinfo endpoint.
109130func WithStaticUserInfo (info jwt.MapClaims ) func (* FakeIDP ) {
110131 return func (f * FakeIDP ) {
111- f .hookUserInfo = func (_ string ) jwt.MapClaims {
112- return info
132+ f .hookUserInfo = func (_ string ) ( jwt.MapClaims , error ) {
133+ return info , nil
113134 }
114135 }
115136}
116137
117- func WithDynamicUserInfo (userInfoFunc func (email string ) jwt.MapClaims ) func (* FakeIDP ) {
138+ func WithDynamicUserInfo (userInfoFunc func (email string ) ( jwt.MapClaims , error ) ) func (* FakeIDP ) {
118139 return func (f * FakeIDP ) {
119140 f .hookUserInfo = userInfoFunc
120141 }
@@ -160,7 +181,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
160181 stateToIDTokenClaims : syncmap .New [string , jwt.MapClaims ](),
161182 refreshIDTokenClaims : syncmap .New [string , jwt.MapClaims ](),
162183 hookOnRefresh : func (_ string ) error { return nil },
163- hookUserInfo : func (email string ) jwt.MapClaims { return jwt.MapClaims {} },
184+ hookUserInfo : func (email string ) ( jwt.MapClaims , error ) { return jwt.MapClaims {}, nil },
164185 hookValidRedirectURL : func (redirectURL string ) error { return nil },
165186 }
166187
@@ -181,16 +202,20 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
181202 return idp
182203}
183204
205+ func (f * FakeIDP ) WellknownConfig () ProviderJSON {
206+ return f .provider
207+ }
208+
184209func (f * FakeIDP ) updateIssuerURL (t testing.TB , issuer string ) {
185210 t .Helper ()
186211
187212 u , err := url .Parse (issuer )
188213 require .NoError (t , err , "invalid issuer URL" )
189214
190215 f .issuer = issuer
191- // providerJSON is the JSON representation of the OpenID Connect provider
216+ // ProviderJSON is the JSON representation of the OpenID Connect provider
192217 // These are all the urls that the IDP will respond to.
193- f .provider = providerJSON {
218+ f .provider = ProviderJSON {
194219 Issuer : issuer ,
195220 AuthURL : u .ResolveReference (& url.URL {Path : authorizePath }).String (),
196221 TokenURL : u .ResolveReference (& url.URL {Path : tokenPath }).String (),
@@ -220,6 +245,15 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
220245 return srv
221246}
222247
248+ // GenerateAuthenticatedToken skips all oauth2 flows, and just generates a
249+ // valid token for some given claims.
250+ func (f * FakeIDP ) GenerateAuthenticatedToken (claims jwt.MapClaims ) (* oauth2.Token , error ) {
251+ state := uuid .NewString ()
252+ f .stateToIDTokenClaims .Store (state , claims )
253+ code := f .newCode (state )
254+ return f .cfg .Exchange (oidc .ClientContext (context .Background (), f .HTTPClient (nil )), code )
255+ }
256+
223257// Login does the full OIDC flow starting at the "LoginButton".
224258// The client argument is just to get the URL of the Coder instance.
225259//
@@ -333,7 +367,8 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
333367 return resp , nil
334368}
335369
336- type providerJSON struct {
370+ // ProviderJSON is the .well-known/configuration JSON
371+ type ProviderJSON struct {
337372 Issuer string `json:"issuer"`
338373 AuthURL string `json:"authorization_endpoint"`
339374 TokenURL string `json:"token_endpoint"`
@@ -475,7 +510,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
475510 err := f .hookValidRedirectURL (redirectURI )
476511 if err != nil {
477512 t .Errorf ("not authorized redirect_uri by custom hook %q: %s" , redirectURI , err .Error ())
478- http .Error (rw , fmt .Sprintf ("invalid redirect_uri: %s" , err .Error ()), http .StatusBadRequest )
513+ http .Error (rw , fmt .Sprintf ("invalid redirect_uri: %s" , err .Error ()), httpErrorCode ( http .StatusBadRequest , err ) )
479514 return
480515 }
481516
@@ -501,7 +536,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
501536 slog .F ("values" , values .Encode ()),
502537 )
503538 if err != nil {
504- http .Error (rw , fmt .Sprintf ("invalid token request: %s" , err .Error ()), http .StatusBadRequest )
539+ http .Error (rw , fmt .Sprintf ("invalid token request: %s" , err .Error ()), httpErrorCode ( http .StatusBadRequest , err ) )
505540 return
506541 }
507542 getEmail := func (claims jwt.MapClaims ) string {
@@ -562,7 +597,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
562597 claims = idTokenClaims
563598 err := f .hookOnRefresh (getEmail (claims ))
564599 if err != nil {
565- http .Error (rw , fmt .Sprintf ("refresh hook blocked refresh: %s" , err .Error ()), http .StatusBadRequest )
600+ http .Error (rw , fmt .Sprintf ("refresh hook blocked refresh: %s" , err .Error ()), httpErrorCode ( http .StatusBadRequest , err ) )
566601 return
567602 }
568603
@@ -610,7 +645,12 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
610645 http .Error (rw , "invalid access token, missing user info" , http .StatusBadRequest )
611646 return
612647 }
613- _ = json .NewEncoder (rw ).Encode (f .hookUserInfo (email ))
648+ claims , err := f .hookUserInfo (email )
649+ if err != nil {
650+ http .Error (rw , fmt .Sprintf ("user info hook returned error: %s" , err .Error ()), httpErrorCode (http .StatusBadRequest , err ))
651+ return
652+ }
653+ _ = json .NewEncoder (rw ).Encode (claims )
614654 }))
615655
616656 mux .Handle (keysPath , http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
@@ -768,6 +808,15 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co
768808 return cfg
769809}
770810
811+ func httpErrorCode (defaultCode int , err error ) int {
812+ var stautsErr statusHookError
813+ status := defaultCode
814+ if errors .As (err , & stautsErr ) {
815+ status = stautsErr .HTTPStatusCode
816+ }
817+ return status
818+ }
819+
771820type fakeRoundTripper struct {
772821 roundTrip func (req * http.Request ) (* http.Response , error )
773822}
0 commit comments