Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a63d27b

Browse files
committed
Initial oauth
1 parent 3304db0 commit a63d27b

File tree

16 files changed

+672
-154
lines changed

16 files changed

+672
-154
lines changed

coderd/coderd.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ type Options struct {
3434

3535
AWSCertificates awsidentity.Certificates
3636
GoogleTokenValidator *idtoken.Validator
37+
GithubOAuth2Provider GithubOAuth2Provider
3738

3839
SecureAuthCookie bool
3940
SSHKeygenAlgorithm gitsshkey.Algorithm
@@ -142,6 +143,12 @@ func New(options *Options) (http.Handler, func()) {
142143
r.Post("/first", api.postFirstUser)
143144
r.Post("/login", api.postLogin)
144145
r.Post("/logout", api.postLogout)
146+
r.Route("/auth", func(r chi.Router) {
147+
r.Route("/callback/github", func(r chi.Router) {
148+
r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Provider))
149+
r.Get("/", api.userAuthGithub)
150+
})
151+
})
145152
r.Group(func(r chi.Router) {
146153
r.Use(httpmw.ExtractAPIKey(options.Database, nil))
147154
r.Post("/", api.postUsers)

coderd/coderdtest/coderdtest.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ import (
4949

5050
type Options struct {
5151
AWSInstanceIdentity awsidentity.Certificates
52+
GithubOAuth2Provider coderd.GithubOAuth2Provider
5253
GoogleInstanceIdentity *idtoken.Validator
5354
SSHKeygenAlgorithm gitsshkey.Algorithm
5455
}
@@ -115,6 +116,7 @@ func New(t *testing.T, options *Options) *codersdk.Client {
115116
Pubsub: pubsub,
116117

117118
AWSCertificates: options.AWSInstanceIdentity,
119+
GithubOAuth2Provider: options.GithubOAuth2Provider,
118120
GoogleTokenValidator: options.GoogleInstanceIdentity,
119121
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
120122
})

coderd/database/databasefake/databasefake.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,16 @@ func (q *fakeQuerier) GetWorkspacesByUserID(_ context.Context, req database.GetW
365365
return workspaces, nil
366366
}
367367

368+
func (q *fakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) {
369+
q.mutex.RLock()
370+
defer q.mutex.RUnlock()
371+
372+
if len(q.organizations) == 0 {
373+
return nil, sql.ErrNoRows
374+
}
375+
return q.organizations, nil
376+
}
377+
368378
func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (database.Organization, error) {
369379
q.mutex.RLock()
370380
defer q.mutex.RUnlock()

coderd/database/querier.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 36 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/organizations.sql

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
-- name: GetOrganizations :many
2+
SELECT
3+
*
4+
FROM
5+
organizations;
6+
17
-- name: GetOrganizationByID :one
28
SELECT
39
*

coderd/httpmw/oauth.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package httpmw
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/http"
7+
8+
"golang.org/x/oauth2"
9+
10+
"github.com/coder/coder/coderd/httpapi"
11+
"github.com/coder/coder/cryptorand"
12+
)
13+
14+
const (
15+
oauth2StateCookieName = "oauth_state"
16+
oauth2RedirectCookieName = "oauth_redirect"
17+
)
18+
19+
type oauth2StateKey struct{}
20+
21+
type OAuth2State struct {
22+
Token *oauth2.Token
23+
Redirect string
24+
}
25+
26+
// OAuth2Provider exposes a subset of *oauth2.Config functions for easier testing.
27+
// *oauth2.Config should be used instead of implementing this in production.
28+
type OAuth2Provider interface {
29+
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
30+
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
31+
}
32+
33+
// OAuth2 returns the state from an oauth request.
34+
func OAuth2(r *http.Request) OAuth2State {
35+
oauth, ok := r.Context().Value(oauth2StateKey{}).(OAuth2State)
36+
if !ok {
37+
panic("developer error: oauth middleware not provided")
38+
}
39+
return oauth
40+
}
41+
42+
// ExtractOAuth2 adds a middleware for handling OAuth2 callbacks.
43+
// Any route that does not have a "code" URL parameter will be redirected
44+
// to the handler configuration provided.
45+
func ExtractOAuth2(provider OAuth2Provider) func(http.Handler) http.Handler {
46+
return func(next http.Handler) http.Handler {
47+
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
48+
code := r.URL.Query().Get("code")
49+
state := r.URL.Query().Get("state")
50+
51+
if code == "" {
52+
// If the code isn't provided, we'll redirect!
53+
state, err := cryptorand.String(32)
54+
if err != nil {
55+
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
56+
Message: fmt.Sprintf("generate state string: %s", err),
57+
})
58+
return
59+
}
60+
61+
http.SetCookie(rw, &http.Cookie{
62+
Name: oauth2StateCookieName,
63+
Value: state,
64+
Path: "/",
65+
HttpOnly: true,
66+
SameSite: http.SameSiteStrictMode,
67+
})
68+
// Redirect must always be specified, otherwise
69+
// an old redirect could apply!
70+
http.SetCookie(rw, &http.Cookie{
71+
Name: oauth2RedirectCookieName,
72+
Value: r.URL.Query().Get("redirect"),
73+
Path: "/",
74+
HttpOnly: true,
75+
SameSite: http.SameSiteStrictMode,
76+
})
77+
78+
http.Redirect(rw, r, provider.AuthCodeURL(state, oauth2.AccessTypeOffline), http.StatusTemporaryRedirect)
79+
return
80+
}
81+
82+
if state == "" {
83+
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
84+
Message: "state must be provided",
85+
})
86+
return
87+
}
88+
89+
stateCookie, err := r.Cookie(oauth2StateCookieName)
90+
if err != nil {
91+
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
92+
Message: fmt.Sprintf("%q cookie must be provided", oauth2StateCookieName),
93+
})
94+
return
95+
}
96+
if stateCookie.Value != state {
97+
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
98+
Message: "state mismatched",
99+
})
100+
return
101+
}
102+
103+
var redirect string
104+
stateRedirect, err := r.Cookie(oauth2RedirectCookieName)
105+
if err == nil {
106+
redirect = stateRedirect.Value
107+
}
108+
109+
oauthToken, err := provider.Exchange(r.Context(), code)
110+
if err != nil {
111+
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
112+
Message: fmt.Sprintf("exchange oauth code: %s", err),
113+
})
114+
return
115+
}
116+
117+
ctx := context.WithValue(r.Context(), oauth2StateKey{}, OAuth2State{
118+
Token: oauthToken,
119+
Redirect: redirect,
120+
})
121+
next.ServeHTTP(rw, r.WithContext(ctx))
122+
})
123+
}
124+
}

coderd/httpmw/oauth_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package httpmw_test
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"net/url"
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
"golang.org/x/oauth2"
13+
14+
"github.com/coder/coder/coderd/httpmw"
15+
)
16+
17+
type testOAuth2Provider struct {
18+
}
19+
20+
func (*testOAuth2Provider) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
21+
return "?state=" + url.QueryEscape(state)
22+
}
23+
24+
func (*testOAuth2Provider) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
25+
return &oauth2.Token{
26+
AccessToken: "hello",
27+
}, nil
28+
}
29+
30+
func TestOAuth2(t *testing.T) {
31+
t.Parallel()
32+
t.Run("RedirectWithoutCode", func(t *testing.T) {
33+
t.Parallel()
34+
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil)
35+
res := httptest.NewRecorder()
36+
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
37+
location := res.Header().Get("Location")
38+
if !assert.NotEmpty(t, location) {
39+
return
40+
}
41+
require.Len(t, res.Result().Cookies(), 2)
42+
cookie := res.Result().Cookies()[1]
43+
require.Equal(t, "/dashboard", cookie.Value)
44+
})
45+
t.Run("NoState", func(t *testing.T) {
46+
t.Parallel()
47+
req := httptest.NewRequest("GET", "/?code=something", nil)
48+
res := httptest.NewRecorder()
49+
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
50+
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
51+
})
52+
t.Run("NoStateCookie", func(t *testing.T) {
53+
t.Parallel()
54+
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
55+
res := httptest.NewRecorder()
56+
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
57+
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
58+
})
59+
t.Run("MismatchedState", func(t *testing.T) {
60+
t.Parallel()
61+
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
62+
req.AddCookie(&http.Cookie{
63+
Name: "oauth_state",
64+
Value: "mismatch",
65+
})
66+
res := httptest.NewRecorder()
67+
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
68+
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
69+
})
70+
t.Run("ExchangeCodeAndState", func(t *testing.T) {
71+
t.Parallel()
72+
req := httptest.NewRequest("GET", "/?code=test&state=something", nil)
73+
req.AddCookie(&http.Cookie{
74+
Name: "oauth_state",
75+
Value: "something",
76+
})
77+
req.AddCookie(&http.Cookie{
78+
Name: "oauth_redirect",
79+
Value: "/dashboard",
80+
})
81+
res := httptest.NewRecorder()
82+
httpmw.ExtractOAuth2(&testOAuth2Provider{})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
83+
state := httpmw.OAuth2(r)
84+
require.Equal(t, "/dashboard", state.Redirect)
85+
})).ServeHTTP(res, req)
86+
})
87+
88+
// t.Run("ExchangeCodeAndState", func(t *testing.T) {
89+
// t.Parallel()
90+
// req := httptest.NewRequest("GET", "/?code=test&state="+url.QueryEscape(state), nil)
91+
// res := httptest.NewRecorder()
92+
// ExtractOAuth(log, cipher, &testOAuthProvider{})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
93+
// rw.WriteHeader(http.StatusOK)
94+
// })).ServeHTTP(res, req)
95+
// assert.Equal(t, res.Result().StatusCode, http.StatusOK)
96+
// })
97+
}

0 commit comments

Comments
 (0)