diff --git a/enterprise/coderd/scim.go b/enterprise/coderd/scim.go index 01d04626a6948..a7bb502a300eb 100644 --- a/enterprise/coderd/scim.go +++ b/enterprise/coderd/scim.go @@ -1,6 +1,7 @@ package coderd import ( + "bytes" "crypto/subtle" "database/sql" "encoding/json" @@ -26,16 +27,21 @@ import ( ) func (api *API) scimVerifyAuthHeader(r *http.Request) bool { - bearer := []byte("Bearer ") + bearer := []byte("bearer ") hdr := []byte(r.Header.Get("Authorization")) - if len(hdr) >= len(bearer) && subtle.ConstantTimeCompare(hdr[:len(bearer)], bearer) == 1 { + // Use toLower to make the comparison case-insensitive. + if len(hdr) >= len(bearer) && subtle.ConstantTimeCompare(bytes.ToLower(hdr[:len(bearer)]), bearer) == 1 { hdr = hdr[len(bearer):] } return len(api.SCIMAPIKey) != 0 && subtle.ConstantTimeCompare(hdr, api.SCIMAPIKey) == 1 } +func scimUnauthorized(rw http.ResponseWriter) { + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusUnauthorized, "invalidAuthorization", xerrors.New("invalid authorization"))) +} + // scimServiceProviderConfig returns a static SCIM service provider configuration. // // @Summary SCIM 2.0: Service Provider Config @@ -114,7 +120,7 @@ func (api *API) scimServiceProviderConfig(rw http.ResponseWriter, _ *http.Reques //nolint:revive func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) { if !api.scimVerifyAuthHeader(r) { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"}) + scimUnauthorized(rw) return } @@ -142,11 +148,11 @@ func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) { //nolint:revive func (api *API) scimGetUser(rw http.ResponseWriter, r *http.Request) { if !api.scimVerifyAuthHeader(r) { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"}) + scimUnauthorized(rw) return } - _ = handlerutil.WriteError(rw, spec.ErrNotFound) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("endpoint will always return 404"))) } // We currently use our own struct instead of using the SCIM package. This was @@ -192,7 +198,7 @@ var SCIMAuditAdditionalFields = map[string]string{ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() if !api.scimVerifyAuthHeader(r) { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"}) + scimUnauthorized(rw) return } @@ -209,7 +215,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { var sUser SCIMUser err := json.NewDecoder(r.Body).Decode(&sUser) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err)) return } @@ -222,7 +228,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { } if email == "" { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusBadRequest, Type: "invalidEmail"}) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidEmail", xerrors.New("no primary email provided"))) return } @@ -232,7 +238,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { Username: sUser.UserName, }) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, err) // internal error return } if err == nil { @@ -248,7 +254,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { UpdatedAt: dbtime.Now(), }) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, err) // internal error return } aReq.New = newUser @@ -284,14 +290,14 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { //nolint:gocritic // SCIM operations are a system user orgSync, err := api.IDPSync.OrganizationSyncSettings(dbauthz.AsSystemRestricted(ctx), api.Database) if err != nil { - _ = handlerutil.WriteError(rw, xerrors.Errorf("failed to get organization sync settings: %w", err)) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get organization sync settings: %w", err))) return } if orgSync.AssignDefault { //nolint:gocritic // SCIM operations are a system user defaultOrganization, err := api.Database.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx)) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get default organization: %w", err))) return } organizations = append(organizations, defaultOrganization.ID) @@ -309,7 +315,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { SkipNotifications: true, }) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to create user: %w", err))) return } aReq.New = dbUser @@ -335,7 +341,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() if !api.scimVerifyAuthHeader(r) { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"}) + scimUnauthorized(rw) return } @@ -354,21 +360,21 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { var sUser SCIMUser err := json.NewDecoder(r.Body).Decode(&sUser) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err)) return } sUser.ID = id uid, err := uuid.Parse(id) if err != nil { - _ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusBadRequest, Type: "invalidId"}) + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidId", xerrors.Errorf("id must be a uuid: %w", err))) return } //nolint:gocritic // needed for SCIM dbUser, err := api.Database.GetUserByID(dbauthz.AsSystemRestricted(ctx), uid) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, err) // internal error return } aReq.Old = dbUser @@ -400,7 +406,7 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { UpdatedAt: dbtime.Now(), }) if err != nil { - _ = handlerutil.WriteError(rw, err) + _ = handlerutil.WriteError(rw, err) // internal error return } dbUser = userNew diff --git a/enterprise/coderd/scim/scimtypes.go b/enterprise/coderd/scim/scimtypes.go index e78b70b3e9f3f..39e022aa24e05 100644 --- a/enterprise/coderd/scim/scimtypes.go +++ b/enterprise/coderd/scim/scimtypes.go @@ -1,6 +1,11 @@ package scim -import "time" +import ( + "encoding/json" + "time" + + "github.com/imulab/go-scim/pkg/v2/spec" +) type ServiceProviderConfig struct { Schemas []string `json:"schemas"` @@ -44,3 +49,37 @@ type AuthenticationScheme struct { SpecURI string `json:"specUri"` DocURI string `json:"documentationUri"` } + +// HTTPError wraps a *spec.Error for correct usage with +// 'handlerutil.WriteError'. This error type is cursed to be +// absolutely strange and specific to the SCIM library we use. +// +// The library expects *spec.Error to be returned on unwrap, and the +// internal error description to be returned by a json.Marshal of the +// top level error. +type HTTPError struct { + scim *spec.Error + internal error +} + +func NewHTTPError(status int, eType string, err error) *HTTPError { + return &HTTPError{ + scim: &spec.Error{ + Status: status, + Type: eType, + }, + internal: err, + } +} + +func (e HTTPError) Error() string { + return e.internal.Error() +} + +func (e HTTPError) MarshalJSON() ([]byte, error) { + return json.Marshal(e.internal) +} + +func (e HTTPError) Unwrap() error { + return e.scim +} diff --git a/enterprise/coderd/scim_test.go b/enterprise/coderd/scim_test.go index 3e5c22f7e9461..1f9d230bf7f2d 100644 --- a/enterprise/coderd/scim_test.go +++ b/enterprise/coderd/scim_test.go @@ -6,11 +6,15 @@ import ( "fmt" "io" "net/http" + "net/http/httptest" "testing" "github.com/golang-jwt/jwt/v4" + "github.com/imulab/go-scim/pkg/v2/handlerutil" + "github.com/imulab/go-scim/pkg/v2/spec" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/coderdtest" @@ -22,6 +26,7 @@ import ( "github.com/coder/coder/v2/enterprise/coderd" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" "github.com/coder/coder/v2/enterprise/coderd/license" + "github.com/coder/coder/v2/enterprise/coderd/scim" "github.com/coder/coder/v2/testutil" ) @@ -59,7 +64,8 @@ func setScimAuth(key []byte) func(*http.Request) { func setScimAuthBearer(key []byte) func(*http.Request) { return func(r *http.Request) { - r.Header.Set("Authorization", "Bearer "+string(key)) + // Do strange casing to ensure it's case-insensitive + r.Header.Set("Authorization", "beAreR "+string(key)) } } @@ -111,7 +117,7 @@ func TestScim(t *testing.T) { res, err := client.Request(ctx, "POST", "/scim/v2/Users", struct{}{}) require.NoError(t, err) defer res.Body.Close() - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) }) t.Run("OK", func(t *testing.T) { @@ -454,7 +460,7 @@ func TestScim(t *testing.T) { require.NoError(t, err) _, _ = io.Copy(io.Discard, res.Body) _ = res.Body.Close() - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) }) t.Run("OK", func(t *testing.T) { @@ -585,3 +591,21 @@ func TestScim(t *testing.T) { }) }) } + +func TestScimError(t *testing.T) { + t.Parallel() + + // Demonstrates that we cannot use the standard errors + rw := httptest.NewRecorder() + _ = handlerutil.WriteError(rw, spec.ErrNotFound) + resp := rw.Result() + defer resp.Body.Close() + require.Equal(t, http.StatusInternalServerError, resp.StatusCode) + + // Our error wrapper works + rw = httptest.NewRecorder() + _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("not found"))) + resp = rw.Result() + defer resp.Body.Close() + require.Equal(t, http.StatusNotFound, resp.StatusCode) +}