@@ -508,6 +508,102 @@ func TestAPIKey(t *testing.T) {
508
508
require .Equal (t , sentAPIKey .ExpiresAt , gotAPIKey .ExpiresAt )
509
509
})
510
510
511
+ t .Run ("APIKeyExpiredOAuthExpired" , func (t * testing.T ) {
512
+ t .Parallel ()
513
+ var (
514
+ db = dbmem .New ()
515
+ user = dbgen .User (t , db , database.User {})
516
+ sentAPIKey , token = dbgen .APIKey (t , db , database.APIKey {
517
+ UserID : user .ID ,
518
+ LastUsed : dbtime .Now ().AddDate (0 , 0 , - 1 ),
519
+ ExpiresAt : dbtime .Now ().AddDate (0 , 0 , - 1 ),
520
+ LoginType : database .LoginTypeOIDC ,
521
+ })
522
+ _ = dbgen .UserLink (t , db , database.UserLink {
523
+ UserID : user .ID ,
524
+ LoginType : database .LoginTypeOIDC ,
525
+ OAuthExpiry : dbtime .Now ().AddDate (0 , 0 , - 1 ),
526
+ })
527
+
528
+ r = httptest .NewRequest ("GET" , "/" , nil )
529
+ rw = httptest .NewRecorder ()
530
+ )
531
+ r .Header .Set (codersdk .SessionTokenHeader , token )
532
+
533
+ // Include a valid oauth token for refreshing. If this token is invalid,
534
+ // it is difficult to tell an auth failure from an expired api key, or
535
+ // an expired oauth key.
536
+ oauthToken := & oauth2.Token {
537
+ AccessToken : "wow" ,
538
+ RefreshToken : "moo" ,
539
+ Expiry : dbtime .Now ().AddDate (0 , 0 , 1 ),
540
+ }
541
+ httpmw .ExtractAPIKeyMW (httpmw.ExtractAPIKeyConfig {
542
+ DB : db ,
543
+ OAuth2Configs : & httpmw.OAuth2Configs {
544
+ OIDC : & testutil.OAuth2Config {
545
+ Token : oauthToken ,
546
+ },
547
+ },
548
+ RedirectToLogin : false ,
549
+ })(successHandler ).ServeHTTP (rw , r )
550
+ res := rw .Result ()
551
+ defer res .Body .Close ()
552
+ require .Equal (t , http .StatusUnauthorized , res .StatusCode )
553
+
554
+ gotAPIKey , err := db .GetAPIKeyByID (r .Context (), sentAPIKey .ID )
555
+ require .NoError (t , err )
556
+
557
+ require .Equal (t , sentAPIKey .LastUsed , gotAPIKey .LastUsed )
558
+ require .Equal (t , sentAPIKey .ExpiresAt , gotAPIKey .ExpiresAt )
559
+ })
560
+
561
+ t .Run ("APIKeyExpiredOAuthNotExpired" , func (t * testing.T ) {
562
+ t .Parallel ()
563
+ var (
564
+ db = dbmem .New ()
565
+ user = dbgen .User (t , db , database.User {})
566
+ sentAPIKey , token = dbgen .APIKey (t , db , database.APIKey {
567
+ UserID : user .ID ,
568
+ LastUsed : dbtime .Now ().AddDate (0 , 0 , - 1 ),
569
+ ExpiresAt : dbtime .Now ().AddDate (0 , 0 , - 1 ),
570
+ LoginType : database .LoginTypeOIDC ,
571
+ })
572
+ _ = dbgen .UserLink (t , db , database.UserLink {
573
+ UserID : user .ID ,
574
+ LoginType : database .LoginTypeOIDC ,
575
+ })
576
+
577
+ r = httptest .NewRequest ("GET" , "/" , nil )
578
+ rw = httptest .NewRecorder ()
579
+ )
580
+ r .Header .Set (codersdk .SessionTokenHeader , token )
581
+
582
+ oauthToken := & oauth2.Token {
583
+ AccessToken : "wow" ,
584
+ RefreshToken : "moo" ,
585
+ Expiry : dbtime .Now ().AddDate (0 , 0 , 1 ),
586
+ }
587
+ httpmw .ExtractAPIKeyMW (httpmw.ExtractAPIKeyConfig {
588
+ DB : db ,
589
+ OAuth2Configs : & httpmw.OAuth2Configs {
590
+ OIDC : & testutil.OAuth2Config {
591
+ Token : oauthToken ,
592
+ },
593
+ },
594
+ RedirectToLogin : false ,
595
+ })(successHandler ).ServeHTTP (rw , r )
596
+ res := rw .Result ()
597
+ defer res .Body .Close ()
598
+ require .Equal (t , http .StatusUnauthorized , res .StatusCode )
599
+
600
+ gotAPIKey , err := db .GetAPIKeyByID (r .Context (), sentAPIKey .ID )
601
+ require .NoError (t , err )
602
+
603
+ require .Equal (t , sentAPIKey .LastUsed , gotAPIKey .LastUsed )
604
+ require .Equal (t , sentAPIKey .ExpiresAt , gotAPIKey .ExpiresAt )
605
+ })
606
+
511
607
t .Run ("OAuthRefresh" , func (t * testing.T ) {
512
608
t .Parallel ()
513
609
var (
@@ -553,7 +649,67 @@ func TestAPIKey(t *testing.T) {
553
649
require .NoError (t , err )
554
650
555
651
require .Equal (t , sentAPIKey .LastUsed , gotAPIKey .LastUsed )
556
- require .Equal (t , oauthToken .Expiry , gotAPIKey .ExpiresAt )
652
+ // Note that OAuth expiry is independent of APIKey expiry, so an OIDC refresh DOES NOT affect the expiry of the
653
+ // APIKey
654
+ require .Equal (t , sentAPIKey .ExpiresAt , gotAPIKey .ExpiresAt )
655
+
656
+ gotLink , err := db .GetUserLinkByUserIDLoginType (r .Context (), database.GetUserLinkByUserIDLoginTypeParams {
657
+ UserID : user .ID ,
658
+ LoginType : database .LoginTypeGithub ,
659
+ })
660
+ require .NoError (t , err )
661
+ require .Equal (t , gotLink .OAuthRefreshToken , "moo" )
662
+ })
663
+
664
+ t .Run ("OAuthExpiredNoRefresh" , func (t * testing.T ) {
665
+ t .Parallel ()
666
+ var (
667
+ ctx = testutil .Context (t , testutil .WaitShort )
668
+ db = dbmem .New ()
669
+ user = dbgen .User (t , db , database.User {})
670
+ sentAPIKey , token = dbgen .APIKey (t , db , database.APIKey {
671
+ UserID : user .ID ,
672
+ LastUsed : dbtime .Now (),
673
+ ExpiresAt : dbtime .Now ().AddDate (0 , 0 , 1 ),
674
+ LoginType : database .LoginTypeGithub ,
675
+ })
676
+
677
+ r = httptest .NewRequest ("GET" , "/" , nil )
678
+ rw = httptest .NewRecorder ()
679
+ )
680
+ _ , err := db .InsertUserLink (ctx , database.InsertUserLinkParams {
681
+ UserID : user .ID ,
682
+ LoginType : database .LoginTypeGithub ,
683
+ OAuthExpiry : dbtime .Now ().AddDate (0 , 0 , - 1 ),
684
+ OAuthAccessToken : "letmein" ,
685
+ })
686
+ require .NoError (t , err )
687
+
688
+ r .Header .Set (codersdk .SessionTokenHeader , token )
689
+
690
+ oauthToken := & oauth2.Token {
691
+ AccessToken : "wow" ,
692
+ RefreshToken : "moo" ,
693
+ Expiry : dbtime .Now ().AddDate (0 , 0 , 1 ),
694
+ }
695
+ httpmw .ExtractAPIKeyMW (httpmw.ExtractAPIKeyConfig {
696
+ DB : db ,
697
+ OAuth2Configs : & httpmw.OAuth2Configs {
698
+ Github : & testutil.OAuth2Config {
699
+ Token : oauthToken ,
700
+ },
701
+ },
702
+ RedirectToLogin : false ,
703
+ })(successHandler ).ServeHTTP (rw , r )
704
+ res := rw .Result ()
705
+ defer res .Body .Close ()
706
+ require .Equal (t , http .StatusUnauthorized , res .StatusCode )
707
+
708
+ gotAPIKey , err := db .GetAPIKeyByID (r .Context (), sentAPIKey .ID )
709
+ require .NoError (t , err )
710
+
711
+ require .Equal (t , sentAPIKey .LastUsed , gotAPIKey .LastUsed )
712
+ require .Equal (t , sentAPIKey .ExpiresAt , gotAPIKey .ExpiresAt )
557
713
})
558
714
559
715
t .Run ("RemoteIPUpdates" , func (t * testing.T ) {
0 commit comments