Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 61683f1

Browse files
authored
fix: allow for alternate usernames on conflict (#4614)
1 parent 3c40698 commit 61683f1

File tree

5 files changed

+192
-49
lines changed

5 files changed

+192
-49
lines changed

coderd/coderdtest/coderdtest.go

+92
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@ import (
2929
"time"
3030

3131
"cloud.google.com/go/compute/metadata"
32+
"github.com/coreos/go-oidc/v3/oidc"
3233
"github.com/fullsailor/pkcs7"
3334
"github.com/golang-jwt/jwt"
3435
"github.com/google/uuid"
3536
"github.com/moby/moby/pkg/namesgenerator"
3637
"github.com/spf13/afero"
3738
"github.com/stretchr/testify/assert"
3839
"github.com/stretchr/testify/require"
40+
"golang.org/x/oauth2"
3941
"golang.org/x/xerrors"
4042
"google.golang.org/api/idtoken"
4143
"google.golang.org/api/option"
@@ -725,6 +727,80 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
725727
}
726728
}
727729

730+
type OIDCConfig struct {
731+
key *rsa.PrivateKey
732+
issuer string
733+
}
734+
735+
func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
736+
t.Helper()
737+
738+
block, _ := pem.Decode([]byte(testRSAPrivateKey))
739+
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
740+
require.NoError(t, err)
741+
742+
if issuer == "" {
743+
issuer = "https://coder.com"
744+
}
745+
746+
return &OIDCConfig{
747+
key: pkey,
748+
issuer: issuer,
749+
}
750+
}
751+
752+
func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
753+
return "/?state=" + url.QueryEscape(state)
754+
}
755+
756+
func (*OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
757+
return nil
758+
}
759+
760+
func (*OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
761+
token, err := base64.StdEncoding.DecodeString(code)
762+
if err != nil {
763+
return nil, xerrors.Errorf("decode code: %w", err)
764+
}
765+
return (&oauth2.Token{
766+
AccessToken: "token",
767+
}).WithExtra(map[string]interface{}{
768+
"id_token": string(token),
769+
}), nil
770+
}
771+
772+
func (o *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
773+
t.Helper()
774+
775+
if _, ok := claims["exp"]; !ok {
776+
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
777+
}
778+
779+
if _, ok := claims["iss"]; !ok {
780+
claims["iss"] = o.issuer
781+
}
782+
783+
if _, ok := claims["sub"]; !ok {
784+
claims["sub"] = "testme"
785+
}
786+
787+
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(o.key)
788+
require.NoError(t, err)
789+
790+
return base64.StdEncoding.EncodeToString([]byte(signed))
791+
}
792+
793+
func (o *OIDCConfig) OIDCConfig() *coderd.OIDCConfig {
794+
return &coderd.OIDCConfig{
795+
OAuth2Config: o,
796+
Verifier: oidc.NewVerifier(o.issuer, &oidc.StaticKeySet{
797+
PublicKeys: []crypto.PublicKey{o.key.Public()},
798+
}, &oidc.Config{
799+
SkipClientIDCheck: true,
800+
}),
801+
}
802+
}
803+
728804
// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking
729805
// instance authentication for Azure.
730806
func NewAzureInstanceIdentity(t *testing.T, instanceID string) (x509.VerifyOptions, *http.Client) {
@@ -805,3 +881,19 @@ func SDKError(t *testing.T, err error) *codersdk.Error {
805881
require.True(t, errors.As(err, &cerr))
806882
return cerr
807883
}
884+
885+
const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
886+
MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS
887+
v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92
888+
5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB
889+
AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0
890+
wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe
891+
rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB
892+
w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9
893+
pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8
894+
YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR
895+
Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a
896+
d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf
897+
sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u
898+
QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8
899+
-----END RSA PRIVATE KEY-----`

coderd/database/databasefake/databasefake.go

+6
Original file line numberDiff line numberDiff line change
@@ -2221,6 +2221,12 @@ func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam
22212221
q.mutex.Lock()
22222222
defer q.mutex.Unlock()
22232223

2224+
for _, user := range q.users {
2225+
if user.Username == arg.Username && !user.Deleted {
2226+
return database.User{}, errDuplicateKey
2227+
}
2228+
}
2229+
22242230
user := database.User{
22252231
ID: arg.ID,
22262232
Email: arg.Email,

coderd/userauth.go

+33
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/coreos/go-oidc/v3/oidc"
1414
"github.com/google/go-github/v43/github"
1515
"github.com/google/uuid"
16+
"github.com/moby/moby/pkg/namesgenerator"
1617
"golang.org/x/oauth2"
1718
"golang.org/x/xerrors"
1819

@@ -390,6 +391,38 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
390391
organizationID = organizations[0].ID
391392
}
392393

394+
_, err := tx.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
395+
Username: params.Username,
396+
})
397+
if err == nil {
398+
var (
399+
original = params.Username
400+
validUsername bool
401+
)
402+
for i := 0; i < 10; i++ {
403+
alternate := fmt.Sprintf("%s-%s", original, namesgenerator.GetRandomName(1))
404+
405+
params.Username = httpapi.UsernameFrom(alternate)
406+
407+
_, err := tx.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
408+
Username: params.Username,
409+
})
410+
if xerrors.Is(err, sql.ErrNoRows) {
411+
validUsername = true
412+
break
413+
}
414+
if err != nil {
415+
return xerrors.Errorf("get user by email/username: %w", err)
416+
}
417+
}
418+
if !validUsername {
419+
return httpError{
420+
code: http.StatusConflict,
421+
msg: fmt.Sprintf("exhausted alternatives for taken username %q", original),
422+
}
423+
}
424+
}
425+
393426
user, _, err = api.CreateUser(ctx, tx, CreateUserRequest{
394427
CreateUserRequest: codersdk.CreateUserRequest{
395428
Email: params.Email,

coderd/userauth_test.go

+55-47
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@ package coderd_test
33
import (
44
"context"
55
"crypto"
6-
"crypto/rand"
7-
"crypto/rsa"
6+
"fmt"
87
"io"
98
"net/http"
109
"net/url"
10+
"strings"
1111
"testing"
12-
"time"
1312

1413
"github.com/coreos/go-oidc/v3/oidc"
1514
"github.com/golang-jwt/jwt"
@@ -450,17 +449,19 @@ func TestUserOIDC(t *testing.T) {
450449
tc := tc
451450
t.Run(tc.Name, func(t *testing.T) {
452451
t.Parallel()
453-
config := createOIDCConfig(t, tc.Claims)
452+
conf := coderdtest.NewOIDCConfig(t, "")
453+
454+
config := conf.OIDCConfig()
454455
config.AllowSignups = tc.AllowSignups
455456
config.EmailDomain = tc.EmailDomain
457+
456458
client := coderdtest.New(t, &coderdtest.Options{
457459
OIDCConfig: config,
458460
})
459-
resp := oidcCallback(t, client)
461+
resp := oidcCallback(t, client, conf.EncodeClaims(t, tc.Claims))
460462
assert.Equal(t, tc.StatusCode, resp.StatusCode)
461463

462-
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
463-
defer cancel()
464+
ctx, _ := testutil.Context(t)
464465

465466
if tc.Username != "" {
466467
client.SessionToken = authCookieValue(resp.Cookies())
@@ -478,10 +479,50 @@ func TestUserOIDC(t *testing.T) {
478479
})
479480
}
480481

482+
t.Run("AlternateUsername", func(t *testing.T) {
483+
t.Parallel()
484+
485+
conf := coderdtest.NewOIDCConfig(t, "")
486+
487+
config := conf.OIDCConfig()
488+
config.AllowSignups = true
489+
490+
client := coderdtest.New(t, &coderdtest.Options{
491+
OIDCConfig: config,
492+
})
493+
494+
code := conf.EncodeClaims(t, jwt.MapClaims{
495+
"email": "[email protected]",
496+
})
497+
resp := oidcCallback(t, client, code)
498+
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
499+
500+
ctx, _ := testutil.Context(t)
501+
502+
client.SessionToken = authCookieValue(resp.Cookies())
503+
user, err := client.User(ctx, "me")
504+
require.NoError(t, err)
505+
require.Equal(t, "jon", user.Username)
506+
507+
// Pass a different subject field so that we prompt creating a
508+
// new user.
509+
code = conf.EncodeClaims(t, jwt.MapClaims{
510+
"email": "[email protected]",
511+
"sub": "diff",
512+
})
513+
resp = oidcCallback(t, client, code)
514+
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
515+
516+
client.SessionToken = authCookieValue(resp.Cookies())
517+
user, err = client.User(ctx, "me")
518+
require.NoError(t, err)
519+
require.True(t, strings.HasPrefix(user.Username, "jon-"), "username %q should have prefix %q", user.Username, "jon-")
520+
})
521+
481522
t.Run("Disabled", func(t *testing.T) {
482523
t.Parallel()
483524
client := coderdtest.New(t, nil)
484-
resp := oidcCallback(t, client)
525+
resp := oidcCallback(t, client, "asdf")
485526
require.Equal(t, http.StatusPreconditionRequired, resp.StatusCode)
486527
})
487528

@@ -492,7 +533,7 @@ func TestUserOIDC(t *testing.T) {
492533
OAuth2Config: &oauth2Config{},
493534
},
494535
})
495-
resp := oidcCallback(t, client)
536+
resp := oidcCallback(t, client, "asdf")
496537
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
497538
})
498539

@@ -514,48 +555,16 @@ func TestUserOIDC(t *testing.T) {
514555
Verifier: verifier,
515556
},
516557
})
517-
resp := oidcCallback(t, client)
558+
resp := oidcCallback(t, client, "asdf")
518559
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
519560
})
520561
}
521562

522-
// createOIDCConfig generates a new OIDCConfig that returns a static token
523-
// with the claims provided.
524-
func createOIDCConfig(t *testing.T, claims jwt.MapClaims) *coderd.OIDCConfig {
525-
t.Helper()
526-
key, err := rsa.GenerateKey(rand.Reader, 2048)
527-
require.NoError(t, err)
528-
529-
// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1
530-
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
531-
claims["iss"] = "https://coder.com"
532-
claims["sub"] = "hello"
533-
534-
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(key)
535-
require.NoError(t, err)
536-
537-
verifier := oidc.NewVerifier("https://coder.com", &oidc.StaticKeySet{
538-
PublicKeys: []crypto.PublicKey{key.Public()},
539-
}, &oidc.Config{
540-
SkipClientIDCheck: true,
541-
})
542-
543-
return &coderd.OIDCConfig{
544-
OAuth2Config: &oauth2Config{
545-
token: (&oauth2.Token{
546-
AccessToken: "token",
547-
}).WithExtra(map[string]interface{}{
548-
"id_token": signed,
549-
}),
550-
},
551-
Verifier: verifier,
552-
}
553-
}
554-
555563
func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
556564
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
557565
return http.ErrUseLastResponse
558566
}
567+
559568
state := "somestate"
560569
oauthURL, err := client.URL.Parse("/api/v2/users/oauth2/github/callback?code=asd&state=" + state)
561570
require.NoError(t, err)
@@ -573,19 +582,18 @@ func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
573582
return res
574583
}
575584

576-
func oidcCallback(t *testing.T, client *codersdk.Client) *http.Response {
585+
func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response {
577586
t.Helper()
578587
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
579588
return http.ErrUseLastResponse
580589
}
581-
state := "somestate"
582-
oauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback?code=asd&state=" + state)
590+
oauthURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/users/oidc/callback?code=%s&state=somestate", code))
583591
require.NoError(t, err)
584592
req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil)
585593
require.NoError(t, err)
586594
req.AddCookie(&http.Cookie{
587595
Name: codersdk.OAuth2StateKey,
588-
Value: state,
596+
Value: "somestate",
589597
})
590598
res, err := client.HTTPClient.Do(req)
591599
require.NoError(t, err)

enterprise/coderd/license/license_test.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,12 @@ func TestEntitlements(t *testing.T) {
140140
t.Run("TooManyUsers", func(t *testing.T) {
141141
t.Parallel()
142142
db := databasefake.New()
143-
db.InsertUser(context.Background(), database.InsertUserParams{})
144-
db.InsertUser(context.Background(), database.InsertUserParams{})
143+
db.InsertUser(context.Background(), database.InsertUserParams{
144+
Username: "test1",
145+
})
146+
db.InsertUser(context.Background(), database.InsertUserParams{
147+
Username: "test2",
148+
})
145149
db.InsertLicense(context.Background(), database.InsertLicenseParams{
146150
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
147151
UserLimit: 1,

0 commit comments

Comments
 (0)