Thanks to visit codestin.com
Credit goes to github.com

Skip to content

chore: make scim auth header case insensitive for 'bearer' #15538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 24 additions & 18 deletions enterprise/coderd/scim.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package coderd

import (
"bytes"
"crypto/subtle"
"database/sql"
"encoding/json"
Expand All @@ -26,16 +27,21 @@ import (
)

func (api *API) scimVerifyAuthHeader(r *http.Request) bool {
bearer := []byte("Bearer ")
bearer := []byte("bearer ")
hdr := []byte(r.Header.Get("Authorization"))

if len(hdr) >= len(bearer) && subtle.ConstantTimeCompare(hdr[:len(bearer)], bearer) == 1 {
// Use toLower to make the comparison case-insensitive.
if len(hdr) >= len(bearer) && subtle.ConstantTimeCompare(bytes.ToLower(hdr[:len(bearer)]), bearer) == 1 {
hdr = hdr[len(bearer):]
}

return len(api.SCIMAPIKey) != 0 && subtle.ConstantTimeCompare(hdr, api.SCIMAPIKey) == 1
}

func scimUnauthorized(rw http.ResponseWriter) {
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusUnauthorized, "invalidAuthorization", xerrors.New("invalid authorization")))
}

// scimServiceProviderConfig returns a static SCIM service provider configuration.
//
// @Summary SCIM 2.0: Service Provider Config
Expand Down Expand Up @@ -114,7 +120,7 @@ func (api *API) scimServiceProviderConfig(rw http.ResponseWriter, _ *http.Reques
//nolint:revive
func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) {
if !api.scimVerifyAuthHeader(r) {
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"})
scimUnauthorized(rw)
return
}

Expand Down Expand Up @@ -142,11 +148,11 @@ func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) {
//nolint:revive
func (api *API) scimGetUser(rw http.ResponseWriter, r *http.Request) {
if !api.scimVerifyAuthHeader(r) {
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"})
scimUnauthorized(rw)
return
}

_ = handlerutil.WriteError(rw, spec.ErrNotFound)
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("endpoint will always return 404")))
}

// We currently use our own struct instead of using the SCIM package. This was
Expand Down Expand Up @@ -192,7 +198,7 @@ var SCIMAuditAdditionalFields = map[string]string{
func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !api.scimVerifyAuthHeader(r) {
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"})
scimUnauthorized(rw)
return
}

Expand All @@ -209,7 +215,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
var sUser SCIMUser
err := json.NewDecoder(r.Body).Decode(&sUser)
if err != nil {
_ = handlerutil.WriteError(rw, err)
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err))
return
}

Expand All @@ -222,7 +228,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
}

if email == "" {
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusBadRequest, Type: "invalidEmail"})
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidEmail", xerrors.New("no primary email provided")))
return
}

Expand All @@ -232,7 +238,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
Username: sUser.UserName,
})
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
_ = handlerutil.WriteError(rw, err)
_ = handlerutil.WriteError(rw, err) // internal error
return
}
if err == nil {
Expand All @@ -248,7 +254,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
UpdatedAt: dbtime.Now(),
})
if err != nil {
_ = handlerutil.WriteError(rw, err)
_ = handlerutil.WriteError(rw, err) // internal error
return
}
aReq.New = newUser
Expand Down Expand Up @@ -284,14 +290,14 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
//nolint:gocritic // SCIM operations are a system user
orgSync, err := api.IDPSync.OrganizationSyncSettings(dbauthz.AsSystemRestricted(ctx), api.Database)
if err != nil {
_ = handlerutil.WriteError(rw, xerrors.Errorf("failed to get organization sync settings: %w", err))
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get organization sync settings: %w", err)))
return
}
if orgSync.AssignDefault {
//nolint:gocritic // SCIM operations are a system user
defaultOrganization, err := api.Database.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
if err != nil {
_ = handlerutil.WriteError(rw, err)
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get default organization: %w", err)))
return
}
organizations = append(organizations, defaultOrganization.ID)
Expand All @@ -309,7 +315,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
SkipNotifications: true,
})
if err != nil {
_ = handlerutil.WriteError(rw, err)
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to create user: %w", err)))
return
}
aReq.New = dbUser
Expand All @@ -335,7 +341,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !api.scimVerifyAuthHeader(r) {
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"})
scimUnauthorized(rw)
return
}

Expand All @@ -354,21 +360,21 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
var sUser SCIMUser
err := json.NewDecoder(r.Body).Decode(&sUser)
if err != nil {
_ = handlerutil.WriteError(rw, err)
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err))
return
}
sUser.ID = id

uid, err := uuid.Parse(id)
if err != nil {
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusBadRequest, Type: "invalidId"})
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidId", xerrors.Errorf("id must be a uuid: %w", err)))
return
}

//nolint:gocritic // needed for SCIM
dbUser, err := api.Database.GetUserByID(dbauthz.AsSystemRestricted(ctx), uid)
if err != nil {
_ = handlerutil.WriteError(rw, err)
_ = handlerutil.WriteError(rw, err) // internal error
return
}
aReq.Old = dbUser
Expand Down Expand Up @@ -400,7 +406,7 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
UpdatedAt: dbtime.Now(),
})
if err != nil {
_ = handlerutil.WriteError(rw, err)
_ = handlerutil.WriteError(rw, err) // internal error
return
}
dbUser = userNew
Expand Down
41 changes: 40 additions & 1 deletion enterprise/coderd/scim/scimtypes.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package scim

import "time"
import (
"encoding/json"
"time"

"github.com/imulab/go-scim/pkg/v2/spec"
)

type ServiceProviderConfig struct {
Schemas []string `json:"schemas"`
Expand Down Expand Up @@ -44,3 +49,37 @@ type AuthenticationScheme struct {
SpecURI string `json:"specUri"`
DocURI string `json:"documentationUri"`
}

// HTTPError wraps a *spec.Error for correct usage with
// 'handlerutil.WriteError'. This error type is cursed to be
// absolutely strange and specific to the SCIM library we use.
//
// The library expects *spec.Error to be returned on unwrap, and the
// internal error description to be returned by a json.Marshal of the
// top level error.
type HTTPError struct {
scim *spec.Error
internal error
}

func NewHTTPError(status int, eType string, err error) *HTTPError {
return &HTTPError{
scim: &spec.Error{
Status: status,
Type: eType,
},
internal: err,
}
}

func (e HTTPError) Error() string {
return e.internal.Error()
}

func (e HTTPError) MarshalJSON() ([]byte, error) {
return json.Marshal(e.internal)
}

func (e HTTPError) Unwrap() error {
return e.scim
}
30 changes: 27 additions & 3 deletions enterprise/coderd/scim_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@ import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/golang-jwt/jwt/v4"
"github.com/imulab/go-scim/pkg/v2/handlerutil"
"github.com/imulab/go-scim/pkg/v2/spec"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"

"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/coderdtest"
Expand All @@ -22,6 +26,7 @@ import (
"github.com/coder/coder/v2/enterprise/coderd"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
"github.com/coder/coder/v2/enterprise/coderd/license"
"github.com/coder/coder/v2/enterprise/coderd/scim"
"github.com/coder/coder/v2/testutil"
)

Expand Down Expand Up @@ -59,7 +64,8 @@ func setScimAuth(key []byte) func(*http.Request) {

func setScimAuthBearer(key []byte) func(*http.Request) {
return func(r *http.Request) {
r.Header.Set("Authorization", "Bearer "+string(key))
// Do strange casing to ensure it's case-insensitive
r.Header.Set("Authorization", "beAreR "+string(key))
}
}

Expand Down Expand Up @@ -111,7 +117,7 @@ func TestScim(t *testing.T) {
res, err := client.Request(ctx, "POST", "/scim/v2/Users", struct{}{})
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
})

t.Run("OK", func(t *testing.T) {
Expand Down Expand Up @@ -454,7 +460,7 @@ func TestScim(t *testing.T) {
require.NoError(t, err)
_, _ = io.Copy(io.Discard, res.Body)
_ = res.Body.Close()
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
})

t.Run("OK", func(t *testing.T) {
Expand Down Expand Up @@ -585,3 +591,21 @@ func TestScim(t *testing.T) {
})
})
}

func TestScimError(t *testing.T) {
t.Parallel()

// Demonstrates that we cannot use the standard errors
rw := httptest.NewRecorder()
_ = handlerutil.WriteError(rw, spec.ErrNotFound)
resp := rw.Result()
defer resp.Body.Close()
require.Equal(t, http.StatusInternalServerError, resp.StatusCode)

// Our error wrapper works
rw = httptest.NewRecorder()
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("not found")))
resp = rw.Result()
defer resp.Body.Close()
require.Equal(t, http.StatusNotFound, resp.StatusCode)
}
Loading