7
7
"crypto/x509"
8
8
"encoding/json"
9
9
"encoding/pem"
10
+ "errors"
10
11
"fmt"
11
12
"io"
12
13
"net"
@@ -41,7 +42,7 @@ import (
41
42
type FakeIDP struct {
42
43
issuer string
43
44
key * rsa.PrivateKey
44
- provider providerJSON
45
+ provider ProviderJSON
45
46
handler http.Handler
46
47
cfg * oauth2.Config
47
48
@@ -66,7 +67,7 @@ type FakeIDP struct {
66
67
// IDP -> Application. Almost all IDPs have the concept of
67
68
// "Authorized Redirect URLs". This can be used to emulate that.
68
69
hookValidRedirectURL func (redirectURL string ) error
69
- hookUserInfo func (email string ) jwt.MapClaims
70
+ hookUserInfo func (email string ) ( jwt.MapClaims , error )
70
71
fakeCoderd func (req * http.Request ) (* http.Response , error )
71
72
hookOnRefresh func (email string ) error
72
73
// Custom authentication for the client. This is useful if you want
@@ -75,6 +76,26 @@ type FakeIDP struct {
75
76
serve bool
76
77
}
77
78
79
+ func StatusError (code int , err error ) error {
80
+ return statusHookError {
81
+ Err : err ,
82
+ HTTPStatusCode : code ,
83
+ }
84
+ }
85
+
86
+ // statusHookError allows a hook to change the returned http status code.
87
+ type statusHookError struct {
88
+ Err error
89
+ HTTPStatusCode int
90
+ }
91
+
92
+ func (s statusHookError ) Error () string {
93
+ if s .Err == nil {
94
+ return ""
95
+ }
96
+ return s .Err .Error ()
97
+ }
98
+
78
99
type FakeIDPOpt func (idp * FakeIDP )
79
100
80
101
func WithAuthorizedRedirectURL (hook func (redirectURL string ) error ) func (* FakeIDP ) {
@@ -83,9 +104,9 @@ func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeID
83
104
}
84
105
}
85
106
86
- // WithRefreshHook is called when a refresh token is used. The email is
107
+ // WithRefresh is called when a refresh token is used. The email is
87
108
// the email of the user that is being refreshed assuming the claims are correct.
88
- func WithRefreshHook (hook func (email string ) error ) func (* FakeIDP ) {
109
+ func WithRefresh (hook func (email string ) error ) func (* FakeIDP ) {
89
110
return func (f * FakeIDP ) {
90
111
f .hookOnRefresh = hook
91
112
}
@@ -108,13 +129,13 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
108
129
// every user on the /userinfo endpoint.
109
130
func WithStaticUserInfo (info jwt.MapClaims ) func (* FakeIDP ) {
110
131
return func (f * FakeIDP ) {
111
- f .hookUserInfo = func (_ string ) jwt.MapClaims {
112
- return info
132
+ f .hookUserInfo = func (_ string ) ( jwt.MapClaims , error ) {
133
+ return info , nil
113
134
}
114
135
}
115
136
}
116
137
117
- func WithDynamicUserInfo (userInfoFunc func (email string ) jwt.MapClaims ) func (* FakeIDP ) {
138
+ func WithDynamicUserInfo (userInfoFunc func (email string ) ( jwt.MapClaims , error ) ) func (* FakeIDP ) {
118
139
return func (f * FakeIDP ) {
119
140
f .hookUserInfo = userInfoFunc
120
141
}
@@ -160,7 +181,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
160
181
stateToIDTokenClaims : syncmap .New [string , jwt.MapClaims ](),
161
182
refreshIDTokenClaims : syncmap .New [string , jwt.MapClaims ](),
162
183
hookOnRefresh : func (_ string ) error { return nil },
163
- hookUserInfo : func (email string ) jwt.MapClaims { return jwt.MapClaims {} },
184
+ hookUserInfo : func (email string ) ( jwt.MapClaims , error ) { return jwt.MapClaims {}, nil },
164
185
hookValidRedirectURL : func (redirectURL string ) error { return nil },
165
186
}
166
187
@@ -181,16 +202,20 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
181
202
return idp
182
203
}
183
204
205
+ func (f * FakeIDP ) WellknownConfig () ProviderJSON {
206
+ return f .provider
207
+ }
208
+
184
209
func (f * FakeIDP ) updateIssuerURL (t testing.TB , issuer string ) {
185
210
t .Helper ()
186
211
187
212
u , err := url .Parse (issuer )
188
213
require .NoError (t , err , "invalid issuer URL" )
189
214
190
215
f .issuer = issuer
191
- // providerJSON is the JSON representation of the OpenID Connect provider
216
+ // ProviderJSON is the JSON representation of the OpenID Connect provider
192
217
// These are all the urls that the IDP will respond to.
193
- f .provider = providerJSON {
218
+ f .provider = ProviderJSON {
194
219
Issuer : issuer ,
195
220
AuthURL : u .ResolveReference (& url.URL {Path : authorizePath }).String (),
196
221
TokenURL : u .ResolveReference (& url.URL {Path : tokenPath }).String (),
@@ -220,6 +245,15 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
220
245
return srv
221
246
}
222
247
248
+ // GenerateAuthenticatedToken skips all oauth2 flows, and just generates a
249
+ // valid token for some given claims.
250
+ func (f * FakeIDP ) GenerateAuthenticatedToken (claims jwt.MapClaims ) (* oauth2.Token , error ) {
251
+ state := uuid .NewString ()
252
+ f .stateToIDTokenClaims .Store (state , claims )
253
+ code := f .newCode (state )
254
+ return f .cfg .Exchange (oidc .ClientContext (context .Background (), f .HTTPClient (nil )), code )
255
+ }
256
+
223
257
// Login does the full OIDC flow starting at the "LoginButton".
224
258
// The client argument is just to get the URL of the Coder instance.
225
259
//
@@ -333,7 +367,8 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
333
367
return resp , nil
334
368
}
335
369
336
- type providerJSON struct {
370
+ // ProviderJSON is the .well-known/configuration JSON
371
+ type ProviderJSON struct {
337
372
Issuer string `json:"issuer"`
338
373
AuthURL string `json:"authorization_endpoint"`
339
374
TokenURL string `json:"token_endpoint"`
@@ -475,7 +510,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
475
510
err := f .hookValidRedirectURL (redirectURI )
476
511
if err != nil {
477
512
t .Errorf ("not authorized redirect_uri by custom hook %q: %s" , redirectURI , err .Error ())
478
- http .Error (rw , fmt .Sprintf ("invalid redirect_uri: %s" , err .Error ()), http .StatusBadRequest )
513
+ http .Error (rw , fmt .Sprintf ("invalid redirect_uri: %s" , err .Error ()), httpErrorCode ( http .StatusBadRequest , err ) )
479
514
return
480
515
}
481
516
@@ -501,7 +536,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
501
536
slog .F ("values" , values .Encode ()),
502
537
)
503
538
if err != nil {
504
- http .Error (rw , fmt .Sprintf ("invalid token request: %s" , err .Error ()), http .StatusBadRequest )
539
+ http .Error (rw , fmt .Sprintf ("invalid token request: %s" , err .Error ()), httpErrorCode ( http .StatusBadRequest , err ) )
505
540
return
506
541
}
507
542
getEmail := func (claims jwt.MapClaims ) string {
@@ -562,7 +597,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
562
597
claims = idTokenClaims
563
598
err := f .hookOnRefresh (getEmail (claims ))
564
599
if err != nil {
565
- http .Error (rw , fmt .Sprintf ("refresh hook blocked refresh: %s" , err .Error ()), http .StatusBadRequest )
600
+ http .Error (rw , fmt .Sprintf ("refresh hook blocked refresh: %s" , err .Error ()), httpErrorCode ( http .StatusBadRequest , err ) )
566
601
return
567
602
}
568
603
@@ -610,7 +645,12 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
610
645
http .Error (rw , "invalid access token, missing user info" , http .StatusBadRequest )
611
646
return
612
647
}
613
- _ = json .NewEncoder (rw ).Encode (f .hookUserInfo (email ))
648
+ claims , err := f .hookUserInfo (email )
649
+ if err != nil {
650
+ http .Error (rw , fmt .Sprintf ("user info hook returned error: %s" , err .Error ()), httpErrorCode (http .StatusBadRequest , err ))
651
+ return
652
+ }
653
+ _ = json .NewEncoder (rw ).Encode (claims )
614
654
}))
615
655
616
656
mux .Handle (keysPath , http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
@@ -768,6 +808,15 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co
768
808
return cfg
769
809
}
770
810
811
+ func httpErrorCode (defaultCode int , err error ) int {
812
+ var stautsErr statusHookError
813
+ status := defaultCode
814
+ if errors .As (err , & stautsErr ) {
815
+ status = stautsErr .HTTPStatusCode
816
+ }
817
+ return status
818
+ }
819
+
771
820
type fakeRoundTripper struct {
772
821
roundTrip func (req * http.Request ) (* http.Response , error )
773
822
}
0 commit comments