Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ad0dd1b

Browse files
authored
fix: Add client certs to OAuth HTTPClient context (#5126)
1 parent 663f7a3 commit ad0dd1b

File tree

4 files changed

+55
-29
lines changed

4 files changed

+55
-29
lines changed

cli/server.go

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,16 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co
207207
listener = tls.NewListener(listener, tlsConfig)
208208
}
209209

210+
ctx, httpClient, err := configureHTTPClient(
211+
ctx,
212+
cfg.TLS.ClientCertFile.Value,
213+
cfg.TLS.ClientKeyFile.Value,
214+
cfg.TLS.ClientCAFile.Value,
215+
)
216+
if err != nil {
217+
return xerrors.Errorf("configure http client: %w", err)
218+
}
219+
210220
tcpAddr, valid := listener.Addr().(*net.TCPAddr)
211221
if !valid {
212222
return xerrors.New("must be listening on tcp")
@@ -377,6 +387,7 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co
377387
DeploymentConfig: cfg,
378388
PrometheusRegistry: prometheus.NewRegistry(),
379389
APIRateLimit: cfg.APIRateLimit.Value,
390+
HTTPClient: httpClient,
380391
}
381392
if tlsConfig != nil {
382393
options.TLSCertificates = tlsConfig.Certificates
@@ -424,11 +435,6 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co
424435
return xerrors.Errorf("OIDC issuer URL must be set!")
425436
}
426437

427-
ctx, err := handleOauth2ClientCertificates(ctx, cfg)
428-
if err != nil {
429-
return xerrors.Errorf("configure oidc client certificates: %w", err)
430-
}
431-
432438
if cfg.OIDC.IgnoreEmailVerified.Value {
433439
logger.Warn(ctx, "coder will not check email_verified for OIDC logins")
434440
}
@@ -1088,19 +1094,27 @@ func configureTLS(tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles
10881094
return nil, nil //nolint:nilnil
10891095
}
10901096

1097+
err = configureCAPool(tlsClientCAFile, tlsConfig)
1098+
if err != nil {
1099+
return nil, err
1100+
}
1101+
1102+
return tlsConfig, nil
1103+
}
1104+
1105+
func configureCAPool(tlsClientCAFile string, tlsConfig *tls.Config) error {
10911106
if tlsClientCAFile != "" {
10921107
caPool := x509.NewCertPool()
10931108
data, err := os.ReadFile(tlsClientCAFile)
10941109
if err != nil {
1095-
return nil, xerrors.Errorf("read %q: %w", tlsClientCAFile, err)
1110+
return xerrors.Errorf("read %q: %w", tlsClientCAFile, err)
10961111
}
10971112
if !caPool.AppendCertsFromPEM(data) {
1098-
return nil, xerrors.Errorf("failed to parse CA certificate in tls-client-ca-file")
1113+
return xerrors.Errorf("failed to parse CA certificate in tls-client-ca-file")
10991114
}
11001115
tlsConfig.ClientCAs = caPool
11011116
}
1102-
1103-
return tlsConfig, nil
1117+
return nil
11041118
}
11051119

11061120
//nolint:revive // Ignore flag-parameter: parameter 'allowEveryone' seems to be a control flag, avoid control coupling (revive)
@@ -1319,20 +1333,27 @@ func startBuiltinPostgres(ctx context.Context, cfg config.Root, logger slog.Logg
13191333
return connectionURL, ep.Stop, nil
13201334
}
13211335

1322-
func handleOauth2ClientCertificates(ctx context.Context, cfg *codersdk.DeploymentConfig) (context.Context, error) {
1323-
if cfg.TLS.ClientCertFile.Value != "" && cfg.TLS.ClientKeyFile.Value != "" {
1324-
certificates, err := loadCertificates([]string{cfg.TLS.ClientCertFile.Value}, []string{cfg.TLS.ClientKeyFile.Value})
1336+
func configureHTTPClient(ctx context.Context, clientCertFile, clientKeyFile string, tlsClientCAFile string) (context.Context, *http.Client, error) {
1337+
if clientCertFile != "" && clientKeyFile != "" {
1338+
certificates, err := loadCertificates([]string{clientCertFile}, []string{clientKeyFile})
13251339
if err != nil {
1326-
return nil, err
1340+
return ctx, nil, err
13271341
}
13281342

1329-
return context.WithValue(ctx, oauth2.HTTPClient, &http.Client{
1343+
tlsClientConfig := &tls.Config{ //nolint:gosec
1344+
Certificates: certificates,
1345+
}
1346+
err = configureCAPool(tlsClientCAFile, tlsClientConfig)
1347+
if err != nil {
1348+
return nil, nil, err
1349+
}
1350+
1351+
httpClient := &http.Client{
13301352
Transport: &http.Transport{
1331-
TLSClientConfig: &tls.Config{ //nolint:gosec
1332-
Certificates: certificates,
1333-
},
1353+
TLSClientConfig: tlsClientConfig,
13341354
},
1335-
}), nil
1355+
}
1356+
return context.WithValue(ctx, oauth2.HTTPClient, httpClient), httpClient, nil
13361357
}
1337-
return ctx, nil
1358+
return ctx, &http.Client{}, nil
13381359
}

coderd/coderd.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ type Options struct {
107107
Experimental bool
108108
DeploymentConfig *codersdk.DeploymentConfig
109109
UpdateCheckOptions *updatecheck.Options // Set non-nil to enable update checking.
110+
HTTPClient *http.Client
110111
}
111112

112113
// New constructs a Coder API handler.
@@ -279,7 +280,7 @@ func New(options *Options) *API {
279280
for _, gitAuthConfig := range options.GitAuthConfigs {
280281
r.Route(fmt.Sprintf("/%s", gitAuthConfig.ID), func(r chi.Router) {
281282
r.Use(
282-
httpmw.ExtractOAuth2(gitAuthConfig),
283+
httpmw.ExtractOAuth2(gitAuthConfig, options.HTTPClient),
283284
apiKeyMiddleware,
284285
)
285286
r.Get("/callback", api.gitAuthCallback(gitAuthConfig))
@@ -428,12 +429,12 @@ func New(options *Options) *API {
428429
r.Get("/authmethods", api.userAuthMethods)
429430
r.Route("/oauth2", func(r chi.Router) {
430431
r.Route("/github", func(r chi.Router) {
431-
r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config))
432+
r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient))
432433
r.Get("/callback", api.userOAuth2Github)
433434
})
434435
})
435436
r.Route("/oidc/callback", func(r chi.Router) {
436-
r.Use(httpmw.ExtractOAuth2(options.OIDCConfig))
437+
r.Use(httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient))
437438
r.Get("/", api.userOIDC)
438439
})
439440
r.Group(func(r chi.Router) {

coderd/httpmw/oauth2.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,14 @@ func OAuth2(r *http.Request) OAuth2State {
4040
// ExtractOAuth2 is a middleware for automatically redirecting to OAuth
4141
// URLs, and handling the exchange inbound. Any route that does not have
4242
// a "code" URL parameter will be redirected.
43-
func ExtractOAuth2(config OAuth2Config) func(http.Handler) http.Handler {
43+
func ExtractOAuth2(config OAuth2Config, client *http.Client) func(http.Handler) http.Handler {
4444
return func(next http.Handler) http.Handler {
4545
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
4646
ctx := r.Context()
47+
if client != nil {
48+
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
49+
}
50+
4751
// Interfaces can hold a nil value
4852
if config == nil || reflect.ValueOf(config).IsNil() {
4953
httpapi.Write(ctx, rw, http.StatusPreconditionRequired, codersdk.Response{

coderd/httpmw/oauth2_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ func TestOAuth2(t *testing.T) {
3939
t.Parallel()
4040
req := httptest.NewRequest("GET", "/", nil)
4141
res := httptest.NewRecorder()
42-
httpmw.ExtractOAuth2(nil)(nil).ServeHTTP(res, req)
42+
httpmw.ExtractOAuth2(nil, nil)(nil).ServeHTTP(res, req)
4343
require.Equal(t, http.StatusPreconditionRequired, res.Result().StatusCode)
4444
})
4545
t.Run("RedirectWithoutCode", func(t *testing.T) {
4646
t.Parallel()
4747
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil)
4848
res := httptest.NewRecorder()
49-
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
49+
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
5050
location := res.Header().Get("Location")
5151
if !assert.NotEmpty(t, location) {
5252
return
@@ -59,14 +59,14 @@ func TestOAuth2(t *testing.T) {
5959
t.Parallel()
6060
req := httptest.NewRequest("GET", "/?code=something", nil)
6161
res := httptest.NewRecorder()
62-
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
62+
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
6363
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
6464
})
6565
t.Run("NoStateCookie", func(t *testing.T) {
6666
t.Parallel()
6767
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
6868
res := httptest.NewRecorder()
69-
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
69+
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
7070
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
7171
})
7272
t.Run("MismatchedState", func(t *testing.T) {
@@ -77,7 +77,7 @@ func TestOAuth2(t *testing.T) {
7777
Value: "mismatch",
7878
})
7979
res := httptest.NewRecorder()
80-
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
80+
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
8181
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
8282
})
8383
t.Run("ExchangeCodeAndState", func(t *testing.T) {
@@ -92,7 +92,7 @@ func TestOAuth2(t *testing.T) {
9292
Value: "/dashboard",
9393
})
9494
res := httptest.NewRecorder()
95-
httpmw.ExtractOAuth2(&testOAuth2Provider{})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
95+
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
9696
state := httpmw.OAuth2(r)
9797
require.Equal(t, "/dashboard", state.Redirect)
9898
})).ServeHTTP(res, req)

0 commit comments

Comments
 (0)