diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d6b985d1..688090cb 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -10,8 +10,11 @@ jobs: golangci: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: stable - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v7 with: - version: v1.54.2 + version: v2.0 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e6e91273..bca5ba69 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,10 +10,11 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: [ '1.19.x', '1.20.x', '1.21.x'] + go: [ '1.22.x', '1.23.x', '1.24.x'] steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} + - run: go version - run: go test -v ./... diff --git a/.golangci.yml b/.golangci.yml index 23f37cbf..46ed573a 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,72 +1,70 @@ -# Configuration file for golangci-lint -# -# https://github.com/golangci/golangci-lint -# -# fighting with false positives? -# https://github.com/golangci/golangci-lint#nolint - +version: "2" linters: enable: - - bodyclose # checks whether HTTP response body is closed successfully [fast: false, auto-fix: false] - - errcheck # Inspects source code for security problems [fast: true, auto-fix: false] - - gocritic # The most opinionated Go source code linter [fast: true, auto-fix: false] - - gocyclo # Computes and checks the cyclomatic complexity of functions [fast: true, auto-fix: false] - - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification [fast: true, auto-fix: true] - - goimports # Goimports does everything that gofmt does. 
Additionally it checks unused imports [fast: true, auto-fix: true] - - gosec # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases [fast: true, auto-fix: false] - - gosimple # Linter for Go source code that specializes in simplifying a code [fast: false, auto-fix: false] - - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string [fast: false, auto-fix: false] - - ineffassign # Detects when assignments to existing variables are not used [fast: true, auto-fix: false] - - misspell # Finds commonly misspelled English words in comments [fast: true, auto-fix: true] - - nakedret # Finds naked returns in functions greater than a specified function length [fast: true, auto-fix: false] - - prealloc # Finds slice declarations that could potentially be preallocated [fast: true, auto-fix: false] - - revive # Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes [fast: true, auto-fix: false] - - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks [fast: false, auto-fix: false] - - stylecheck # Stylecheck is a replacement for golint [fast: false, auto-fix: false] - - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code [fast: true, auto-fix: false] - - unconvert # Remove unnecessary type conversions [fast: true, auto-fix: false] - - unparam # Reports unused function parameters [fast: false, auto-fix: false] - - unused # Checks Go code for unused constants, variables, functions and types [fast: false, auto-fix: false] - + - bodyclose + - gocritic + - gocyclo + - gosec + - misspell + - nakedret + - prealloc + - revive + - staticcheck + - unconvert + - unparam disable: # TODO(ross): fix errors reported by these checkers and enable them - - dupl # Tool for code clone detection [fast: true, auto-fix: false] - 
- gochecknoglobals # Checks that no globals are present in Go code [fast: true, auto-fix: false] - - gochecknoinits # Checks that no init functions are present in Go code [fast: true, auto-fix: false] - - goconst # Finds repeated strings that could be replaced by a constant [fast: true, auto-fix: false] - - lll # Reports long lines [fast: true, auto-fix: false] - - depguard # Go linter that checks if package imports are in a list of acceptable packages [fast: true, auto-fix: false] -linters-settings: - goimports: - local-prefixes: github.com/crewjam/saml - govet: - disable: - - shadow - enable: - - asmdecl - - assign - - atomic - - bools - - buildtag - - cgocall - - composites - - copylocks - - errorsas - - httpresponse - - loopclosure - - lostcancel - - nilfunc - - printf - - shift - - stdmethods - - structtag - - tests - - unmarshal - - unreachable - - unsafeptr - - unusedresult -issues: - exclude-use-default: false - exclude: - - G104 # 'Errors unhandled. (gosec) - + - depguard + - dupl + - gochecknoglobals + - gochecknoinits + - goconst + - lll + settings: + govet: + enable: + - asmdecl + - assign + - atomic + - bools + - buildtag + - cgocall + - composites + - copylocks + - errorsas + - httpresponse + - loopclosure + - lostcancel + - nilfunc + - printf + - shift + - stdmethods + - structtag + - tests + - unmarshal + - unreachable + - unsafeptr + - unusedresult + disable: + - shadow + exclusions: + generated: lax + rules: + - path: (.+)\.go$ + text: G104 # 'Errors unhandled. 
(gosec) + paths: + - example/.*\.go$ +formatters: + enable: + - gofmt + - goimports + settings: + goimports: + local-prefixes: + - github.com/clerk/saml + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 00000000..6da4fdeb --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @crewjam diff --git a/README.md b/README.md index cd1415a5..1adeccff 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![](https://godoc.org/github.com/crewjam/saml?status.svg)](http://godoc.org/github.com/crewjam/saml) -![Build Status](https://github.com/crewjam/saml/workflows/Presubmit/badge.svg) +![Build Status](https://github.com/crewjam/saml/actions/workflows/test.yml/badge.svg) Package saml contains a partial implementation of the SAML standard in golang. SAML is a standard for identity federation, i.e. either allowing a third party to authenticate your users or allowing third parties to rely on us to authenticate their users. @@ -130,7 +130,7 @@ The SAML standard is huge and complex with many dark corners and strange, unused This package supports the **Web SSO** profile. Message flows from the service provider to the IDP are supported using the **HTTP Redirect** binding and the **HTTP POST** binding. Message flows from the IDP to the service provider are supported via the **HTTP POST** binding. -The package can produce signed SAML assertions, and can validate both signed and encrypted SAML assertions. It does not support signed or encrypted requests. +The package can produce signed SAML assertions, and can validate both signed and encrypted SAML assertions. 
## RelayState diff --git a/example/idp/idp.go b/example/idp/idp.go index 6ed039a5..e1e54de8 100644 --- a/example/idp/idp.go +++ b/example/idp/idp.go @@ -6,9 +6,9 @@ import ( "crypto/x509" "encoding/pem" "flag" + "net/http" "net/url" - "github.com/zenazn/goji" "golang.org/x/crypto/bcrypt" "github.com/clerk/saml/logger" @@ -118,6 +118,5 @@ func main() { logr.Fatalf("%s", err) } - goji.Handle("/*", idpServer) - goji.Serve() + http.ListenAndServe(":8080", idpServer) } diff --git a/example/service.go b/example/service.go index 1be285c8..35ae8a41 100644 --- a/example/service.go +++ b/example/service.go @@ -7,6 +7,7 @@ import ( "crypto/rsa" "crypto/tls" "crypto/x509" + "encoding/base64" "encoding/xml" "flag" "fmt" @@ -14,11 +15,6 @@ import ( "net/url" "strings" - "github.com/dchest/uniuri" - "github.com/kr/pretty" - "github.com/zenazn/goji" - "github.com/zenazn/goji/web" - "github.com/clerk/saml/samlsp" ) @@ -32,10 +28,16 @@ type Link struct { } // CreateLink handles requests to create links -func CreateLink(_ web.C, w http.ResponseWriter, r *http.Request) { +func CreateLink(w http.ResponseWriter, r *http.Request) { account := r.Header.Get("X-Remote-User") + + randomness := make([]byte, 8) + if _, err := r.Body.Read(randomness); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } l := Link{ - ShortLink: uniuri.New(), + ShortLink: base64.RawURLEncoding.EncodeToString(randomness), Target: r.FormValue("t"), Owner: account, } @@ -45,7 +47,7 @@ func CreateLink(_ web.C, w http.ResponseWriter, r *http.Request) { } // ServeLink handles requests to redirect to a link -func ServeLink(_ web.C, w http.ResponseWriter, r *http.Request) { +func ServeLink(w http.ResponseWriter, r *http.Request) { l, ok := links[strings.TrimPrefix(r.URL.Path, "/")] if !ok { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) @@ -55,7 +57,7 @@ func ServeLink(_ web.C, w http.ResponseWriter, r *http.Request) { } // ListLinks returns a list of the 
current user's links -func ListLinks(_ web.C, w http.ResponseWriter, r *http.Request) { +func ListLinks(w http.ResponseWriter, r *http.Request) { account := r.Header.Get("X-Remote-User") for _, l := range links { if l.Owner == account { @@ -64,6 +66,27 @@ func ListLinks(_ web.C, w http.ResponseWriter, r *http.Request) { } } +// ServeWhoami serves the basic whoami endpoint +func ServeWhoami(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "text/plain") + + session := samlsp.SessionFromContext(r.Context()) + if session == nil { + fmt.Fprintln(w, "not signed in") + return + } + fmt.Fprintln(w, "signed in") + sessionWithAttrs, ok := session.(samlsp.SessionWithAttributes) + if ok { + fmt.Fprintln(w, "attributes:") + for name, values := range sessionWithAttrs.GetAttributes() { + for _, value := range values { + fmt.Fprintf(w, "%s: %v\n", name, value) + } + } + } +} + var ( key = []byte(`-----BEGIN RSA PRIVATE KEY----- MIICXgIBAAKBgQDU8wdiaFmPfTyRYuFlVPi866WrH/2JubkHzp89bBQopDaLXYxi @@ -140,11 +163,9 @@ func main() { // register with the service provider spMetadataBuf, _ := xml.MarshalIndent(samlSP.ServiceProvider.Metadata(), "", " ") - spURL := *idpMetadataURL spURL.Path = "/services/sp" resp, err := http.Post(spURL.String(), "text/xml", bytes.NewReader(spMetadataBuf)) - if err != nil { panic(err) } @@ -153,20 +174,12 @@ func main() { panic(err) } - goji.Handle("/saml/*", samlSP) - - authMux := web.New() - authMux.Use(samlSP.RequireAccount) - authMux.Get("/whoami", func(w http.ResponseWriter, r *http.Request) { - if _, err := pretty.Fprintf(w, "%# v", r); err != nil { - panic(err) - } - }) - authMux.Post("/", CreateLink) - authMux.Get("/", ListLinks) - - goji.Handle("/*", authMux) - goji.Get("/:link", ServeLink) + mux := http.NewServeMux() + mux.Handle("GET /saml/", samlSP) + mux.HandleFunc("GET /{link}", ServeLink) + mux.Handle("GET /whoami", samlSP.RequireAccount(http.HandlerFunc(ServeWhoami))) + mux.Handle("POST /", 
samlSP.RequireAccount(http.HandlerFunc(CreateLink))) + mux.Handle("GET /", samlSP.RequireAccount(http.HandlerFunc(ListLinks))) - goji.Serve() + http.ListenAndServe(":8080", mux) } diff --git a/go.mod b/go.mod index e8a75b8a..aaacc53e 100644 --- a/go.mod +++ b/go.mod @@ -3,26 +3,17 @@ module github.com/clerk/saml go 1.23.0 require ( - github.com/beevik/etree v1.2.0 - github.com/crewjam/httperr v0.2.0 - github.com/dchest/uniuri v1.2.0 - github.com/golang-jwt/jwt/v4 v4.5.2 - github.com/google/go-cmp v0.6.0 - github.com/kr/pretty v0.3.1 + github.com/beevik/etree v1.5.0 + github.com/golang-jwt/jwt/v5 v5.2.2 + github.com/google/go-cmp v0.7.0 github.com/mattermost/xml-roundtrip-validator v0.1.0 github.com/russellhaering/goxmldsig v1.4.0 - github.com/stretchr/testify v1.8.4 - github.com/zenazn/goji v1.0.1 golang.org/x/crypto v0.35.0 gotest.tools v2.2.0+incompatible ) require ( - github.com/davecgh/go-spew v1.1.1 // indirect github.com/jonboulle/clockwork v0.2.2 // indirect - github.com/kr/text v0.2.0 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.9.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect + github.com/stretchr/testify v1.10.0 // indirect ) diff --git a/go.sum b/go.sum index 9395b045..c8fd7c98 100644 --- a/go.sum +++ b/go.sum @@ -1,58 +1,43 @@ github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A= -github.com/beevik/etree v1.2.0 h1:l7WETslUG/T+xOPs47dtd6jov2Ii/8/OjCldk5fYfQw= -github.com/beevik/etree v1.2.0/go.mod h1:aiPf89g/1k3AShMVAzriilpcE4R/Vuor90y83zVZWFc= +github.com/beevik/etree v1.5.0 h1:iaQZFSDS+3kYZiGoc9uKeOkUY3nYMXOKLl6KIJxiJWs= +github.com/beevik/etree v1.5.0/go.mod h1:gPNJNaBGVZ9AwsidazFZyygnd+0pAU38N4D+WemwKNs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/crewjam/httperr v0.2.0 h1:b2BfXR8U3AlIHwNeFFvZ+BV1LFvKLlzMjzaTnZMybNo= -github.com/crewjam/httperr v0.2.0/go.mod 
h1:Jlz+Sg/XqBQhyMjdDiC+GNNRzZTD7x39Gu3pglZ5oH4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dchest/uniuri v1.2.0 h1:koIcOUdrTIivZgSLhHQvKgqdWZq5d7KdMEWF1Ud6+5g= -github.com/dchest/uniuri v1.2.0/go.mod h1:fSzm4SLHzNZvWLvWJew423PhAzkpNQYq+uNLq4kxhkY= -github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= -github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ= github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russellhaering/goxmldsig v1.4.0 h1:8UcDh/xGyQiyrW+Fq5t8f+l2DLB1+zlhYzkPUJ7Qhys= github.com/russellhaering/goxmldsig v1.4.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/zenazn/goji v1.0.1 h1:4lbD8Mx2h7IvloP7r2C0D6ltZP6Ufip8Hn0wmSK5LR8= -github.com/zenazn/goji v1.0.1/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= 
+github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/identity_provider.go b/identity_provider.go index dd754faa..f0ee6b74 100644 --- a/identity_provider.go +++ b/identity_provider.go @@ -8,13 +8,13 @@ import ( "encoding/base64" "encoding/xml" "fmt" + "html/template" "io" "net/http" "net/url" "os" "regexp" "strconv" - "text/template" "time" "github.com/beevik/etree" @@ -38,13 +38,14 @@ type Session struct { NameIDFormat string SubjectID string - Groups []string - UserName string - UserEmail string - UserCommonName string - UserSurname string - UserGivenName string - UserScopedAffiliation string + Groups []string + UserName string + UserEmail string + UserCommonName string + UserSurname string + UserGivenName string + UserScopedAffiliation string + EduPersonPrincipalName string `json:",omitempty"` CustomAttributes 
[]Attribute } @@ -101,12 +102,14 @@ type IdentityProvider struct { Intermediates []*x509.Certificate MetadataURL url.URL SSOURL url.URL + LoginURL url.URL LogoutURL url.URL ServiceProviderProvider ServiceProviderProvider SessionProvider SessionProvider AssertionMaker AssertionMaker SignatureMethod string ValidDuration *time.Duration + ResponseFormTemplate *template.Template } // Metadata returns the metadata structure for this identity provider. @@ -175,7 +178,7 @@ func (idp *IdentityProvider) Metadata() *EntityDescriptor { } if idp.LogoutURL.String() != "" { - ed.IDPSSODescriptors[0].SSODescriptor.SingleLogoutServices = []Endpoint{ + ed.IDPSSODescriptors[0].SingleLogoutServices = []Endpoint{ { Binding: HTTPRedirectBinding, Location: idp.LogoutURL.String(), @@ -662,13 +665,33 @@ func (DefaultAssertionMaker) MakeAssertion(req *IdpAuthnRequest, session *Sessio } if session.UserEmail != "" { + attributes = append(attributes, Attribute{ + FriendlyName: "mail", + Name: "urn:oid:0.9.2342.19200300.100.1.3", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{{ + Type: "xs:string", + Value: session.UserEmail, + }}, + }) + } + if session.EduPersonPrincipalName != "" || session.UserEmail != "" { + value := session.EduPersonPrincipalName + if value == "" { + // We used to set eduPersonPrincipalName (urn:oid:1.3.6.1.4.1.5923.1.1.1.6) + // to the value of session.UserEmail. It is more correct to set + // mail (urn:oid:0.9.2342.19200300.100.1.3). To avoid breaking things, + // we preserve the former behavior. 
+ value = session.UserEmail + } + attributes = append(attributes, Attribute{ FriendlyName: "eduPersonPrincipalName", Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.6", NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", Values: []AttributeValue{{ Type: "xs:string", - Value: session.UserEmail, + Value: value, }}, }) } @@ -709,7 +732,7 @@ func (DefaultAssertionMaker) MakeAssertion(req *IdpAuthnRequest, session *Sessio if session.UserScopedAffiliation != "" { attributes = append(attributes, Attribute{ - FriendlyName: "uid", + FriendlyName: "scopedAffiliation", Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.9", NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", Values: []AttributeValue{{ @@ -921,6 +944,16 @@ func (req *IdpAuthnRequest) PostBinding() (IdpAuthnRequestForm, error) { return form, nil } +var defaultResponseFormTemplate = template.Must(template.New("saml-post-form").Parse(`` + + `
` + + `` + + `` + + `` + + `
` + + `` + + `` + + ``)) + // WriteResponse writes the `Response` to the http.ResponseWriter. If // `Response` is not already set, it calls MakeResponse to produce it. func (req *IdpAuthnRequest) WriteResponse(w http.ResponseWriter) error { @@ -929,15 +962,10 @@ func (req *IdpAuthnRequest) WriteResponse(w http.ResponseWriter) error { return err } - tmpl := template.Must(template.New("saml-post-form").Parse(`` + - `
` + - `` + - `` + - `` + - `
` + - `` + - `` + - ``)) + tmpl := req.IDP.ResponseFormTemplate + if tmpl == nil { + tmpl = defaultResponseFormTemplate + } buf := bytes.NewBuffer(nil) if err := tmpl.Execute(buf, form); err != nil { diff --git a/identity_provider_test.go b/identity_provider_test.go index 8c7d8c00..fc578b17 100644 --- a/identity_provider_test.go +++ b/identity_provider_test.go @@ -25,7 +25,6 @@ import ( "gotest.tools/golden" "github.com/beevik/etree" - "github.com/golang-jwt/jwt/v4" dsig "github.com/russellhaering/goxmldsig" "github.com/clerk/saml/logger" @@ -104,7 +103,6 @@ func NewIdentityProviderTest(t *testing.T, opts ...idpTestOpts) *IdentityProvide rv, _ := time.Parse("Mon Jan 2 15:04:05 MST 2006", "Mon Dec 1 01:57:09 UTC 2015") return rv } - jwt.TimeFunc = TimeNow RandReader = &testRandomReader{} // TODO(ross): remove this and use the below generator xmlenc.RandReader = rand.New(rand.NewSource(0)) //nolint:gosec // deterministic random numbers for tests @@ -126,7 +124,7 @@ func NewIdentityProviderTest(t *testing.T, opts ...idpTestOpts) *IdentityProvide MetadataURL: mustParseURL("https://idp.example.com/saml/metadata"), SSOURL: mustParseURL("https://idp.example.com/saml/sso"), ServiceProviderProvider: &mockServiceProviderProvider{ - GetServiceProviderFunc: func(r *http.Request, serviceProviderID string) (*EntityDescriptor, error) { + GetServiceProviderFunc: func(_ *http.Request, serviceProviderID string) (*EntityDescriptor, error) { if serviceProviderID == test.SP.MetadataURL.String() { return test.SP.Metadata(), nil } @@ -134,7 +132,7 @@ func NewIdentityProviderTest(t *testing.T, opts ...idpTestOpts) *IdentityProvide }, }, SessionProvider: &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { + GetSessionFunc: func(_ http.ResponseWriter, _ *http.Request, _ *IdpAuthnRequest) *Session { return nil }, }, @@ -241,9 +239,10 @@ func TestIDPHTTPCanHandleMetadataRequest(t *testing.T) { func 
TestIDPCanHandleRequestWithNewSession(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { - fmt.Fprintf(w, "RelayState: %s\nSAMLRequest: %s", + GetSessionFunc: func(w http.ResponseWriter, _ *http.Request, req *IdpAuthnRequest) *Session { + _, err := fmt.Fprintf(w, "RelayState: %s\nSAMLRequest: %s", req.RelayState, req.RequestBuffer) + assert.NilError(t, err) return nil }, } @@ -267,7 +266,7 @@ func TestIDPCanHandleRequestWithNewSession(t *testing.T) { func TestIDPCanHandleRequestWithExistingSession(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { + GetSessionFunc: func(_ http.ResponseWriter, _ *http.Request, _ *IdpAuthnRequest) *Session { return &Session{ ID: "f00df00df00d", UserName: "alice", @@ -292,7 +291,7 @@ func TestIDPCanHandleRequestWithExistingSession(t *testing.T) { func TestIDPCanHandlePostRequestWithExistingSession(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { + GetSessionFunc: func(_ http.ResponseWriter, _ *http.Request, _ *IdpAuthnRequest) *Session { return &Session{ ID: "f00df00df00d", UserName: "alice", @@ -321,7 +320,7 @@ func TestIDPCanHandlePostRequestWithExistingSession(t *testing.T) { func TestIDPRejectsInvalidRequest(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { + GetSessionFunc: func(_ http.ResponseWriter, _ *http.Request, _ *IdpAuthnRequest) *Session { panic("not reached") }, } @@ -484,7 +483,6 @@ func TestIDPCanValidate(t 
*testing.T) { ""), } assert.Check(t, is.Error(req.Validate(), "cannot find assertion consumer service: file does not exist")) - } func TestIDPMakeAssertion(t *testing.T) { @@ -591,83 +589,93 @@ func TestIDPMakeAssertion(t *testing.T) { }) assert.Check(t, err) - expectedAttributes := - []Attribute{ - { - FriendlyName: "uid", - Name: "urn:oid:0.9.2342.19200300.100.1.1", - NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", - Values: []AttributeValue{ - { - Type: "xs:string", - Value: "alice", - }, + expectedAttributes := []Attribute{ + { + FriendlyName: "uid", + Name: "urn:oid:0.9.2342.19200300.100.1.1", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: "alice", }, }, - { - FriendlyName: "eduPersonPrincipalName", - Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.6", - NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", - Values: []AttributeValue{ - { - Type: "xs:string", - Value: "alice@example.com", - }, + }, + { + FriendlyName: "mail", + Name: "urn:oid:0.9.2342.19200300.100.1.3", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: "alice@example.com", }, }, - { - FriendlyName: "sn", - Name: "urn:oid:2.5.4.4", - NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", - Values: []AttributeValue{ - { - Type: "xs:string", - Value: "Smith", - }, + }, + { + FriendlyName: "eduPersonPrincipalName", + Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.6", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: "alice@example.com", }, }, - { - FriendlyName: "givenName", - Name: "urn:oid:2.5.4.42", - NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", - Values: []AttributeValue{ - { - Type: "xs:string", - Value: "Alice", - }, + }, + { + FriendlyName: "sn", + Name: "urn:oid:2.5.4.4", + NameFormat: 
"urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: "Smith", }, }, - { - FriendlyName: "cn", - Name: "urn:oid:2.5.4.3", - NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", - Values: []AttributeValue{ - { - Type: "xs:string", - Value: "Alice Smith", - }, + }, + { + FriendlyName: "givenName", + Name: "urn:oid:2.5.4.42", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: "Alice", }, }, - { - FriendlyName: "eduPersonAffiliation", - Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.1", - NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", - Values: []AttributeValue{ - { - Type: "xs:string", - Value: "Users", - }, - { - Type: "xs:string", - Value: "Administrators", - }, - { - Type: "xs:string", - Value: "♀", - }, + }, + { + FriendlyName: "cn", + Name: "urn:oid:2.5.4.3", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: "Alice Smith", }, }, - } + }, + { + FriendlyName: "eduPersonAffiliation", + Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.1", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: "Users", + }, + { + Type: "xs:string", + Value: "Administrators", + }, + { + Type: "xs:string", + Value: "♀", + }, + }, + }, + } assert.Check(t, is.DeepEqual(expectedAttributes, req.Assertion.AttributeStatements[0].Attributes)) } @@ -798,8 +806,9 @@ func TestIDPWriteResponse(t *testing.T) { func TestIDPIDPInitiatedNewSession(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { - fmt.Fprintf(w, "RelayState: %s", req.RelayState) + GetSessionFunc: func(w http.ResponseWriter, _ *http.Request, req *IdpAuthnRequest) *Session { + _, err := 
fmt.Fprintf(w, "RelayState: %s", req.RelayState) + assert.NilError(t, err) return nil }, } @@ -814,7 +823,7 @@ func TestIDPIDPInitiatedNewSession(t *testing.T) { func TestIDPIDPInitiatedExistingSession(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { + GetSessionFunc: func(_ http.ResponseWriter, _ *http.Request, _ *IdpAuthnRequest) *Session { return &Session{ ID: "f00df00df00d", UserName: "alice", @@ -832,7 +841,7 @@ func TestIDPIDPInitiatedExistingSession(t *testing.T) { func TestIDPIDPInitiatedBadServiceProvider(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { + GetSessionFunc: func(_ http.ResponseWriter, _ *http.Request, _ *IdpAuthnRequest) *Session { return &Session{ ID: "f00df00df00d", UserName: "alice", @@ -849,7 +858,7 @@ func TestIDPIDPInitiatedBadServiceProvider(t *testing.T) { func TestIDPCanHandleUnencryptedResponse(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { + GetSessionFunc: func(_ http.ResponseWriter, _ *http.Request, _ *IdpAuthnRequest) *Session { return &Session{ID: "f00df00df00d", UserName: "alice"} }, } @@ -860,7 +869,7 @@ func TestIDPCanHandleUnencryptedResponse(t *testing.T) { &metadata) assert.Check(t, err) test.IDP.ServiceProviderProvider = &mockServiceProviderProvider{ - GetServiceProviderFunc: func(r *http.Request, serviceProviderID string) (*EntityDescriptor, error) { + GetServiceProviderFunc: func(_ *http.Request, serviceProviderID string) (*EntityDescriptor, error) { if serviceProviderID == "https://gitlab.example.com/users/saml/metadata" { return &metadata, nil } @@ 
-976,6 +985,17 @@ func TestIDPRequestedAttributes(t *testing.T) { }, }, }, + { + FriendlyName: "mail", + Name: "urn:oid:0.9.2342.19200300.100.1.3", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: "alice@example.com", + }, + }, + }, { FriendlyName: "eduPersonPrincipalName", Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.6", @@ -1020,14 +1040,15 @@ func TestIDPRequestedAttributes(t *testing.T) { }, }, }, - }}} + }, + }} assert.Check(t, is.DeepEqual(expectedAttributes, req.Assertion.AttributeStatements)) } func TestIDPNoDestination(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { + GetSessionFunc: func(_ http.ResponseWriter, _ *http.Request, _ *IdpAuthnRequest) *Session { return &Session{ID: "f00df00df00d", UserName: "alice"} }, } @@ -1036,7 +1057,7 @@ func TestIDPNoDestination(t *testing.T) { err := xml.Unmarshal(golden.Get(t, "TestIDPNoDestination_idp_metadata.xml"), &metadata) assert.Check(t, err) test.IDP.ServiceProviderProvider = &mockServiceProviderProvider{ - GetServiceProviderFunc: func(r *http.Request, serviceProviderID string) (*EntityDescriptor, error) { + GetServiceProviderFunc: func(_ *http.Request, serviceProviderID string) (*EntityDescriptor, error) { if serviceProviderID == "https://gitlab.example.com/users/saml/metadata" { return &metadata, nil } @@ -1067,9 +1088,10 @@ func TestIDPNoDestination(t *testing.T) { func TestIDPRejectDecompressionBomb(t *testing.T) { test := NewIdentityProviderTest(t) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { - fmt.Fprintf(w, "RelayState: %s\nSAMLRequest: %s", + GetSessionFunc: func(w http.ResponseWriter, _ *http.Request, req *IdpAuthnRequest) *Session { + _, err := fmt.Fprintf(w, "RelayState: 
%s\nSAMLRequest: %s", req.RelayState, req.RequestBuffer) + assert.NilError(t, err) return nil }, } diff --git a/metadata.go b/metadata.go index 006a9e67..d160ccc3 100644 --- a/metadata.go +++ b/metadata.go @@ -38,6 +38,55 @@ type EntitiesDescriptor struct { EntityDescriptors []EntityDescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata EntityDescriptor"` } +// MarshalXML implements xml.Marshaler +func (m EntitiesDescriptor) MarshalXML(e *xml.Encoder, _ xml.StartElement) error { + var validUntil *RelaxedTime + var cacheDuration *Duration + if m.ValidUntil != nil { + vu := RelaxedTime(*m.ValidUntil) + validUntil = &vu + } + if m.CacheDuration != nil { + cd := Duration(*m.CacheDuration) + cacheDuration = &cd + } + type Alias EntitiesDescriptor + aux := &struct { + ValidUntil *RelaxedTime `xml:"validUntil,attr,omitempty"` + CacheDuration *Duration `xml:"cacheDuration,attr,omitempty"` + *Alias + }{ + ValidUntil: validUntil, + CacheDuration: cacheDuration, + Alias: (*Alias)(&m), + } + return e.Encode(aux) +} + +// UnmarshalXML implements xml.Unmarshaler +func (m *EntitiesDescriptor) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + type Alias EntitiesDescriptor + aux := &struct { + ValidUntil *RelaxedTime `xml:"validUntil,attr,omitempty"` + CacheDuration *Duration `xml:"cacheDuration,attr,omitempty"` + *Alias + }{ + Alias: (*Alias)(m), + } + if err := d.DecodeElement(aux, &start); err != nil { + return err + } + if aux.ValidUntil != nil { + t := time.Time(*aux.ValidUntil) + m.ValidUntil = &t + } + if aux.CacheDuration != nil { + d := time.Duration(*aux.CacheDuration) + m.CacheDuration = &d + } + return nil +} + // Metadata as been renamed to EntityDescriptor // // This change was made to be consistent with the rest of the API which uses names diff --git a/saml.go b/saml.go index f5d4f4e0..3638548a 100644 --- a/saml.go +++ b/saml.go @@ -149,7 +149,7 @@ // // This package supports the Web SSO profile. 
Message flows from the service provider to the IDP are supported using the HTTP Redirect binding and the HTTP POST binding. Message flows from the IDP to the service provider are supported via the HTTP POST binding. // -// The package can produce signed SAML assertions, and can validate both signed and encrypted SAML assertions. It does not support signed or encrypted requests. +// The package can produce signed SAML assertions, and can validate both signed and encrypted SAML assertions. // // # RelayState // diff --git a/samlidp/samlidp.go b/samlidp/samlidp.go index b4bda685..25933457 100644 --- a/samlidp/samlidp.go +++ b/samlidp/samlidp.go @@ -5,25 +5,25 @@ package samlidp import ( "crypto" "crypto/x509" + "html/template" "net/http" "net/url" - "regexp" + "strings" "sync" - "github.com/zenazn/goji/web" - "github.com/clerk/saml" "github.com/clerk/saml/logger" ) // Options represent the parameters to New() for creating a new IDP server type Options struct { - URL url.URL - Key crypto.PrivateKey - Signer crypto.Signer - Logger logger.Interface - Certificate *x509.Certificate - Store Store + URL url.URL + Key crypto.PrivateKey + Signer crypto.Signer + Logger logger.Interface + Certificate *x509.Certificate + Store Store + LoginFormTemplate *template.Template } // Server represents an IDP server. 
The server provides the following URLs: @@ -38,19 +38,24 @@ type Options struct { // /shortcuts - RESTful interface to Shortcut objects type Server struct { http.Handler - idpConfigMu sync.RWMutex // protects calls into the IDP - logger logger.Interface - serviceProviders map[string]*saml.EntityDescriptor - IDP saml.IdentityProvider // the underlying IDP - Store Store // the data store + idpConfigMu sync.RWMutex // protects calls into the IDP + logger logger.Interface + serviceProviders map[string]*saml.EntityDescriptor + IDP saml.IdentityProvider // the underlying IDP + Store Store // the data store + LoginFormTemplate *template.Template } // New returns a new Server func New(opts Options) (*Server, error) { + opts.URL.Path = strings.TrimSuffix(opts.URL.Path, "/") + metadataURL := opts.URL metadataURL.Path += "/metadata" ssoURL := opts.URL ssoURL.Path += "/sso" + loginURL := opts.URL + loginURL.Path += "/login" logr := opts.Logger if logr == nil { logr = logger.DefaultLogger @@ -65,9 +70,11 @@ func New(opts Options) (*Server, error) { Certificate: opts.Certificate, MetadataURL: metadataURL, SSOURL: ssoURL, + LoginURL: loginURL, }, - logger: logr, - Store: opts.Store, + logger: logr, + Store: opts.Store, + LoginFormTemplate: opts.LoginFormTemplate, } s.IDP.SessionProvider = s @@ -84,40 +91,41 @@ func New(opts Options) (*Server, error) { // is called automatically for you by New, but you may need to call it // yourself if you don't create the object using New.) 
func (s *Server) InitializeHTTP() { - mux := web.New() + mux := http.NewServeMux() s.Handler = mux - mux.Get("/metadata", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("GET /metadata", func(w http.ResponseWriter, r *http.Request) { s.idpConfigMu.RLock() defer s.idpConfigMu.RUnlock() s.IDP.ServeMetadata(w, r) }) - mux.Handle("/sso", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/sso", func(w http.ResponseWriter, r *http.Request) { + s.idpConfigMu.RLock() + defer s.idpConfigMu.RUnlock() s.IDP.ServeSSO(w, r) }) - mux.Handle("/login", s.HandleLogin) - mux.Handle("/login/:shortcut", s.HandleIDPInitiated) - mux.Handle("/login/:shortcut/*", s.HandleIDPInitiated) - - mux.Get("/services/", s.HandleListServices) - mux.Get("/services/:id", s.HandleGetService) - mux.Put("/services/:id", s.HandlePutService) - mux.Post("/services/:id", s.HandlePutService) - mux.Delete("/services/:id", s.HandleDeleteService) - - mux.Get("/users/", s.HandleListUsers) - mux.Get("/users/:id", s.HandleGetUser) - mux.Put("/users/:id", s.HandlePutUser) - mux.Delete("/users/:id", s.HandleDeleteUser) - - sessionPath := regexp.MustCompile("/sessions/(?P.*)") - mux.Get("/sessions/", s.HandleListSessions) - mux.Get(sessionPath, s.HandleGetSession) - mux.Delete(sessionPath, s.HandleDeleteSession) - - mux.Get("/shortcuts/", s.HandleListShortcuts) - mux.Get("/shortcuts/:id", s.HandleGetShortcut) - mux.Put("/shortcuts/:id", s.HandlePutShortcut) - mux.Delete("/shortcuts/:id", s.HandleDeleteShortcut) + mux.HandleFunc("/login", s.HandleLogin) + mux.HandleFunc("/login/{shortcut}", s.HandleIDPInitiated) + mux.HandleFunc("/login/{shortcut}/{suffix}", s.HandleIDPInitiated) + + mux.HandleFunc("GET /services/", s.HandleListServices) + mux.HandleFunc("GET /services/{id}", s.HandleGetService) + mux.HandleFunc("PUT /services/{id}", s.HandlePutService) + mux.HandleFunc("POST /services/{id}", s.HandlePutService) + mux.HandleFunc("DELETE /services/{id}", s.HandleDeleteService) + + 
mux.HandleFunc("GET /users/", s.HandleListUsers) + mux.HandleFunc("GET /users/{id}", s.HandleGetUser) + mux.HandleFunc("PUT /users/{id}", s.HandlePutUser) + mux.HandleFunc("DELETE /users/{id}", s.HandleDeleteUser) + + mux.HandleFunc("GET /sessions/", s.HandleListSessions) + mux.HandleFunc("GET /sessions/{id}", s.HandleGetSession) + mux.HandleFunc("DELETE /sessions/{id}", s.HandleDeleteSession) + + mux.HandleFunc("GET /shortcuts/", s.HandleListShortcuts) + mux.HandleFunc("GET /shortcuts/{id}", s.HandleGetShortcut) + mux.HandleFunc("PUT /shortcuts/{id}", s.HandlePutShortcut) + mux.HandleFunc("DELETE /shortcuts/{id}", s.HandleDeleteShortcut) } diff --git a/samlidp/samlidp_test.go b/samlidp/samlidp_test.go index 0cf573ea..7600c213 100644 --- a/samlidp/samlidp_test.go +++ b/samlidp/samlidp_test.go @@ -16,8 +16,6 @@ import ( is "gotest.tools/assert/cmp" "gotest.tools/golden" - "github.com/golang-jwt/jwt/v4" - "github.com/clerk/saml" "github.com/clerk/saml/logger" ) @@ -83,7 +81,6 @@ func NewServerTest(t *testing.T) *ServerTest { rv, _ := time.Parse("Mon Jan 2 15:04:05 MST 2006", "Mon Dec 1 01:57:09 UTC 2015") return rv } - jwt.TimeFunc = saml.TimeNow saml.RandReader = &testRandomReader{} test.SPKey = mustParsePrivateKey(golden.Get(t, "sp_key.pem")).(*rsa.PrivateKey) @@ -139,7 +136,7 @@ func TestHTTPCanSSORequest(t *testing.T) { test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) assert.Check(t, - strings.HasPrefix(w.Body.String(), "

"), + strings.HasPrefix(w.Body.String(), "

"), w.Body.String()) golden.Assert(t, w.Body.String(), "http_sso_response.html") } diff --git a/samlidp/service.go b/samlidp/service.go index b39dc1a1..1d513834 100644 --- a/samlidp/service.go +++ b/samlidp/service.go @@ -7,8 +7,6 @@ import ( "net/http" "os" - "github.com/zenazn/goji/web" - "github.com/clerk/saml" ) @@ -37,7 +35,7 @@ func (s *Server) GetServiceProvider(_ *http.Request, serviceProviderID string) ( // HandleListServices handles the `GET /services/` request and responds with a JSON formatted list // of service names. -func (s *Server) HandleListServices(_ web.C, w http.ResponseWriter, _ *http.Request) { +func (s *Server) HandleListServices(w http.ResponseWriter, _ *http.Request) { services, err := s.Store.List("/services/") if err != nil { s.logger.Printf("ERROR: %s", err) @@ -56,9 +54,9 @@ func (s *Server) HandleListServices(_ web.C, w http.ResponseWriter, _ *http.Requ // HandleGetService handles the `GET /services/:id` request and responds with the service // metadata in XML format. -func (s *Server) HandleGetService(c web.C, w http.ResponseWriter, _ *http.Request) { +func (s *Server) HandleGetService(w http.ResponseWriter, r *http.Request) { service := Service{} - err := s.Store.Get(fmt.Sprintf("/services/%s", c.URLParams["id"]), &service) + err := s.Store.Get(fmt.Sprintf("/services/%s", r.PathValue("id")), &service) if err != nil { s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -73,7 +71,7 @@ func (s *Server) HandleGetService(c web.C, w http.ResponseWriter, _ *http.Reques // HandlePutService handles the `PUT /shortcuts/:id` request. It accepts the XML-formatted // service metadata in the request body and stores it. 
-func (s *Server) HandlePutService(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandlePutService(w http.ResponseWriter, r *http.Request) { service := Service{} metadata, err := getSPMetadata(r.Body) @@ -85,7 +83,7 @@ func (s *Server) HandlePutService(c web.C, w http.ResponseWriter, r *http.Reques service.Metadata = *metadata - err = s.Store.Put(fmt.Sprintf("/services/%s", c.URLParams["id"]), &service) + err = s.Store.Put(fmt.Sprintf("/services/%s", r.PathValue("id")), &service) if err != nil { s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -100,16 +98,16 @@ func (s *Server) HandlePutService(c web.C, w http.ResponseWriter, r *http.Reques } // HandleDeleteService handles the `DELETE /services/:id` request. -func (s *Server) HandleDeleteService(c web.C, w http.ResponseWriter, _ *http.Request) { +func (s *Server) HandleDeleteService(w http.ResponseWriter, r *http.Request) { service := Service{} - err := s.Store.Get(fmt.Sprintf("/services/%s", c.URLParams["id"]), &service) + err := s.Store.Get(fmt.Sprintf("/services/%s", r.PathValue("id")), &service) if err != nil { s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } - if err := s.Store.Delete(fmt.Sprintf("/services/%s", c.URLParams["id"])); err != nil { + if err := s.Store.Delete(fmt.Sprintf("/services/%s", r.PathValue("id"))); err != nil { s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return diff --git a/samlidp/session.go b/samlidp/session.go index 9e9af828..61236906 100644 --- a/samlidp/session.go +++ b/samlidp/session.go @@ -5,14 +5,12 @@ import ( "encoding/hex" "encoding/json" "fmt" + "html/template" "net/http" - "text/template" "time" "golang.org/x/crypto/bcrypt" - "github.com/zenazn/goji/web" - "github.com/clerk/saml" ) @@ -38,12 +36,14 @@ 
func (s *Server) GetSession(w http.ResponseWriter, r *http.Request, req *saml.Id if r.Method == "POST" && r.PostForm.Get("user") != "" { user := User{} if err := s.Store.Get(fmt.Sprintf("/users/%s", r.PostForm.Get("user")), &user); err != nil { - s.sendLoginForm(w, r, req, "Invalid username or password") + s.logger.Printf("ERROR: User '%s' doesn't exists", r.PostForm.Get("user")) + s.sendLoginForm(w, req, "Invalid username or password") return nil } if err := bcrypt.CompareHashAndPassword(user.HashedPassword, []byte(r.PostForm.Get("password"))); err != nil { - s.sendLoginForm(w, r, req, "Invalid username or password") + s.logger.Printf("ERROR: Invalid password for user '%s'", r.PostForm.Get("user")) + s.sendLoginForm(w, req, "Invalid username or password") return nil } @@ -75,6 +75,8 @@ func (s *Server) GetSession(w http.ResponseWriter, r *http.Request, req *saml.Id Secure: r.URL.Scheme == "https", Path: "/", }) + + s.logger.Printf("User '%s' authenticated successfully", r.PostForm.Get("user")) return session } @@ -82,7 +84,7 @@ func (s *Server) GetSession(w http.ResponseWriter, r *http.Request, req *saml.Id session := &saml.Session{} if err := s.Store.Get(fmt.Sprintf("/sessions/%s", sessionCookie.Value), session); err != nil { if err == ErrNotFound { - s.sendLoginForm(w, r, req, "") + s.sendLoginForm(w, req, "") return nil } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -90,31 +92,36 @@ func (s *Server) GetSession(w http.ResponseWriter, r *http.Request, req *saml.Id } if saml.TimeNow().After(session.ExpireTime) { - s.sendLoginForm(w, r, req, "") + s.sendLoginForm(w, req, "") return nil } return session } - s.sendLoginForm(w, r, req, "") + s.sendLoginForm(w, req, "") return nil } +var defaultLoginFormTemplate = template.Must(template.New("saml-post-form").Parse(`` + + `` + + `

{{.Toast}}

` + + `` + + `` + + `` + + `` + + `` + + `` + + `
` + + ``)) + // sendLoginForm produces a form which requests a username and password and directs the user // back to the IDP authorize URL to restart the SAML login flow, this time establishing a // session based on the credentials that were provided. -func (s *Server) sendLoginForm(w http.ResponseWriter, _ *http.Request, req *saml.IdpAuthnRequest, toast string) { - tmpl := template.Must(template.New("saml-post-form").Parse(`` + - `` + - `

{{.Toast}}

` + - `
` + - `` + - `` + - `` + - `` + - `` + - `
` + - ``)) +func (s *Server) sendLoginForm(w http.ResponseWriter, req *saml.IdpAuthnRequest, toast string) { + tmpl := s.LoginFormTemplate + if tmpl == nil { + tmpl = defaultLoginFormTemplate + } data := struct { Toast string URL string @@ -122,7 +129,7 @@ func (s *Server) sendLoginForm(w http.ResponseWriter, _ *http.Request, req *saml RelayState string }{ Toast: toast, - URL: req.IDP.SSOURL.String(), + URL: req.IDP.LoginURL.String(), SAMLRequest: base64.StdEncoding.EncodeToString(req.RequestBuffer), RelayState: req.RelayState, } @@ -136,7 +143,7 @@ func (s *Server) sendLoginForm(w http.ResponseWriter, _ *http.Request, req *saml // in the request body, then they are validated. For valid credentials, the response is a // 200 OK and the JSON session object. For invalid credentials, the HTML login prompt form // is sent. -func (s *Server) HandleLogin(_ web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleLogin(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return @@ -153,7 +160,7 @@ func (s *Server) HandleLogin(_ web.C, w http.ResponseWriter, r *http.Request) { // HandleListSessions handles the `GET /sessions/` request and responds with a JSON formatted list // of session names. -func (s *Server) HandleListSessions(_ web.C, w http.ResponseWriter, _ *http.Request) { +func (s *Server) HandleListSessions(w http.ResponseWriter, _ *http.Request) { sessions, err := s.Store.List("/sessions/") if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -171,9 +178,9 @@ func (s *Server) HandleListSessions(_ web.C, w http.ResponseWriter, _ *http.Requ // HandleGetSession handles the `GET /sessions/:id` request and responds with the session // object in JSON format. 
-func (s *Server) HandleGetSession(c web.C, w http.ResponseWriter, _ *http.Request) { +func (s *Server) HandleGetSession(w http.ResponseWriter, r *http.Request) { session := saml.Session{} - err := s.Store.Get(fmt.Sprintf("/sessions/%s", c.URLParams["id"]), &session) + err := s.Store.Get(fmt.Sprintf("/sessions/%s", r.PathValue("id")), &session) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return @@ -186,8 +193,8 @@ func (s *Server) HandleGetSession(c web.C, w http.ResponseWriter, _ *http.Reques // HandleDeleteSession handles the `DELETE /sessions/:id` request. It invalidates the // specified session. -func (s *Server) HandleDeleteSession(c web.C, w http.ResponseWriter, _ *http.Request) { - err := s.Store.Delete(fmt.Sprintf("/sessions/%s", c.URLParams["id"])) +func (s *Server) HandleDeleteSession(w http.ResponseWriter, r *http.Request) { + err := s.Store.Delete(fmt.Sprintf("/sessions/%s", r.PathValue("id"))) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return diff --git a/samlidp/session_test.go b/samlidp/session_test.go index cb3d5c3d..6244f923 100644 --- a/samlidp/session_test.go +++ b/samlidp/session_test.go @@ -63,4 +63,27 @@ func TestSessionsCrud(t *testing.T) { assert.Check(t, is.Equal("{\"sessions\":[]}\n", w.Body.String())) + // user doesn't exists case + w = httptest.NewRecorder() + r, _ = http.NewRequest("POST", "https://idp.example.com/login", + strings.NewReader("user=unknown&password=dummypassword")) + r.Header.Set("Content-type", "application/x-www-form-urlencoded") + test.Server.ServeHTTP(w, r) + assert.Check(t, is.Equal(http.StatusOK, w.Code)) + assert.Check(t, is.Equal("text/html; charset=utf-8", + w.Header().Get("Content-type"))) + assert.Check(t, is.Equal(`

Invalid username or password

`, + w.Body.String())) + + // invalid username/password exists case + w = httptest.NewRecorder() + r, _ = http.NewRequest("POST", "https://idp.example.com/login", + strings.NewReader("user=alice&password=dummypassword")) + r.Header.Set("Content-type", "application/x-www-form-urlencoded") + test.Server.ServeHTTP(w, r) + assert.Check(t, is.Equal(http.StatusOK, w.Code)) + assert.Check(t, is.Equal("text/html; charset=utf-8", + w.Header().Get("Content-type"))) + assert.Check(t, is.Equal(`

Invalid username or password

`, + w.Body.String())) } diff --git a/samlidp/shortcut.go b/samlidp/shortcut.go index 192e5022..f08efe7d 100644 --- a/samlidp/shortcut.go +++ b/samlidp/shortcut.go @@ -4,8 +4,6 @@ import ( "encoding/json" "fmt" "net/http" - - "github.com/zenazn/goji/web" ) // Shortcut represents an IDP-initiated SAML flow. When a user @@ -31,7 +29,7 @@ type Shortcut struct { // HandleListShortcuts handles the `GET /shortcuts/` request and responds with a JSON formatted list // of shortcut names. -func (s *Server) HandleListShortcuts(_ web.C, w http.ResponseWriter, _ *http.Request) { +func (s *Server) HandleListShortcuts(w http.ResponseWriter, _ *http.Request) { shortcuts, err := s.Store.List("/shortcuts/") if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -49,9 +47,9 @@ func (s *Server) HandleListShortcuts(_ web.C, w http.ResponseWriter, _ *http.Req // HandleGetShortcut handles the `GET /shortcuts/:id` request and responds with the shortcut // object in JSON format. -func (s *Server) HandleGetShortcut(c web.C, w http.ResponseWriter, _ *http.Request) { +func (s *Server) HandleGetShortcut(w http.ResponseWriter, r *http.Request) { shortcut := Shortcut{} - err := s.Store.Get(fmt.Sprintf("/shortcuts/%s", c.URLParams["id"]), &shortcut) + err := s.Store.Get(fmt.Sprintf("/shortcuts/%s", r.PathValue("id")), &shortcut) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return @@ -64,15 +62,15 @@ func (s *Server) HandleGetShortcut(c web.C, w http.ResponseWriter, _ *http.Reque // HandlePutShortcut handles the `PUT /shortcuts/:id` request. It accepts a JSON formatted // shortcut object in the request body and stores it. 
-func (s *Server) HandlePutShortcut(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandlePutShortcut(w http.ResponseWriter, r *http.Request) { shortcut := Shortcut{} if err := json.NewDecoder(r.Body).Decode(&shortcut); err != nil { http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } - shortcut.Name = c.URLParams["id"] + shortcut.Name = r.PathValue("id") - err := s.Store.Put(fmt.Sprintf("/shortcuts/%s", c.URLParams["id"]), &shortcut) + err := s.Store.Put(fmt.Sprintf("/shortcuts/%s", r.PathValue("id")), &shortcut) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return @@ -81,8 +79,8 @@ func (s *Server) HandlePutShortcut(c web.C, w http.ResponseWriter, r *http.Reque } // HandleDeleteShortcut handles the `DELETE /shortcuts/:id` request. -func (s *Server) HandleDeleteShortcut(c web.C, w http.ResponseWriter, _ *http.Request) { - err := s.Store.Delete(fmt.Sprintf("/shortcuts/%s", c.URLParams["id"])) +func (s *Server) HandleDeleteShortcut(w http.ResponseWriter, r *http.Request) { + err := s.Store.Delete(fmt.Sprintf("/shortcuts/%s", r.PathValue("id"))) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return @@ -93,8 +91,8 @@ func (s *Server) HandleDeleteShortcut(c web.C, w http.ResponseWriter, _ *http.Re // HandleIDPInitiated handles a request for an IDP initiated login flow. It looks up // the specified shortcut, generates the appropriate SAML assertion and redirects the // user via the HTTP-POST binding to the service providers ACS URL. 
-func (s *Server) HandleIDPInitiated(c web.C, w http.ResponseWriter, r *http.Request) { - shortcutName := c.URLParams["shortcut"] +func (s *Server) HandleIDPInitiated(w http.ResponseWriter, r *http.Request) { + shortcutName := r.PathValue("shortcut") shortcut := Shortcut{} if err := s.Store.Get(fmt.Sprintf("/shortcuts/%s", shortcutName), &shortcut); err != nil { s.logger.Printf("ERROR: %s", err) @@ -107,7 +105,9 @@ func (s *Server) HandleIDPInitiated(c web.C, w http.ResponseWriter, r *http.Requ case shortcut.RelayState != nil: relayState = *shortcut.RelayState case shortcut.URISuffixAsRelayState: - relayState = c.URLParams["*"] + if suffix := r.PathValue("suffix"); suffix != "" { + relayState = "/" + suffix + } } s.IDP.ServeIDPInitiated(w, r, shortcut.ServiceProviderID, relayState) diff --git a/samlidp/testdata/http_sso_response.html b/samlidp/testdata/http_sso_response.html index ea0f60f6..47db428a 100644 --- a/samlidp/testdata/http_sso_response.html +++ b/samlidp/testdata/http_sso_response.html @@ -1 +1 @@ -

\ No newline at end of file +

\ No newline at end of file diff --git a/samlidp/user.go b/samlidp/user.go index c8c412cb..de08ede8 100644 --- a/samlidp/user.go +++ b/samlidp/user.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" - "github.com/zenazn/goji/web" "golang.org/x/crypto/bcrypt" ) @@ -25,7 +24,7 @@ type User struct { // HandleListUsers handles the `GET /users/` request and responds with a JSON formatted list // of user names. -func (s *Server) HandleListUsers(_ web.C, w http.ResponseWriter, _ *http.Request) { +func (s *Server) HandleListUsers(w http.ResponseWriter, _ *http.Request) { users, err := s.Store.List("/users/") if err != nil { s.logger.Printf("ERROR: %s", err) @@ -45,9 +44,9 @@ func (s *Server) HandleListUsers(_ web.C, w http.ResponseWriter, _ *http.Request // HandleGetUser handles the `GET /users/:id` request and responds with the user object in JSON // format. The HashedPassword field is excluded. -func (s *Server) HandleGetUser(c web.C, w http.ResponseWriter, _ *http.Request) { +func (s *Server) HandleGetUser(w http.ResponseWriter, r *http.Request) { user := User{} - err := s.Store.Get(fmt.Sprintf("/users/%s", c.URLParams["id"]), &user) + err := s.Store.Get(fmt.Sprintf("/users/%s", r.PathValue("id")), &user) if err != nil { s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -65,14 +64,14 @@ func (s *Server) HandleGetUser(c web.C, w http.ResponseWriter, _ *http.Request) // the request body and stores it. If the PlaintextPassword field is present then it is hashed // and stored in HashedPassword. If the PlaintextPassword field is not present then // HashedPassword retains it's stored value. 
-func (s *Server) HandlePutUser(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandlePutUser(w http.ResponseWriter, r *http.Request) { user := User{} if err := json.NewDecoder(r.Body).Decode(&user); err != nil { s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } - user.Name = c.URLParams["id"] + user.Name = r.PathValue("id") if user.PlaintextPassword != nil { var err error @@ -84,11 +83,11 @@ func (s *Server) HandlePutUser(c web.C, w http.ResponseWriter, r *http.Request) } } else { existingUser := User{} - err := s.Store.Get(fmt.Sprintf("/users/%s", c.URLParams["id"]), &existingUser) - switch { - case err == nil: + err := s.Store.Get(fmt.Sprintf("/users/%s", r.PathValue("id")), &existingUser) + switch err { + case nil: user.HashedPassword = existingUser.HashedPassword - case err == ErrNotFound: + case ErrNotFound: // nop default: s.logger.Printf("ERROR: %s", err) @@ -98,7 +97,7 @@ func (s *Server) HandlePutUser(c web.C, w http.ResponseWriter, r *http.Request) } user.PlaintextPassword = nil - err := s.Store.Put(fmt.Sprintf("/users/%s", c.URLParams["id"]), &user) + err := s.Store.Put(fmt.Sprintf("/users/%s", r.PathValue("id")), &user) if err != nil { s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -108,8 +107,8 @@ func (s *Server) HandlePutUser(c web.C, w http.ResponseWriter, r *http.Request) } // HandleDeleteUser handles the `DELETE /users/:id` request. 
-func (s *Server) HandleDeleteUser(c web.C, w http.ResponseWriter, _ *http.Request) { - err := s.Store.Delete(fmt.Sprintf("/users/%s", c.URLParams["id"])) +func (s *Server) HandleDeleteUser(w http.ResponseWriter, r *http.Request) { + err := s.Store.Delete(fmt.Sprintf("/users/%s", r.PathValue("id"))) if err != nil { s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) diff --git a/samlsp/basic_assertion_handler.go b/samlsp/basic_assertion_handler.go new file mode 100644 index 00000000..a1b86958 --- /dev/null +++ b/samlsp/basic_assertion_handler.go @@ -0,0 +1,15 @@ +package samlsp + +import ( + "github.com/clerk/saml" +) + +var _ AssertionHandler = NopAssertionHandler{} + +// NopAssertionHandler is an implementation of AssertionHandler that does nothing. +type NopAssertionHandler struct{} + +// HandleAssertion is called and passed a SAML assertion. This implementation does nothing. +func (as NopAssertionHandler) HandleAssertion(_ *saml.Assertion) error { + return nil +} diff --git a/samlsp/fetch_metadata.go b/samlsp/fetch_metadata.go index d9b82591..0b6b6c31 100644 --- a/samlsp/fetch_metadata.go +++ b/samlsp/fetch_metadata.go @@ -5,11 +5,11 @@ import ( "context" "encoding/xml" "errors" + "fmt" "io" "net/http" "net/url" - "github.com/crewjam/httperr" xrv "github.com/mattermost/xml-roundtrip-validator" "github.com/clerk/saml/logger" @@ -69,7 +69,7 @@ func FetchMetadata(ctx context.Context, httpClient *http.Client, metadataURL url } }() if resp.StatusCode >= 400 { - return nil, httperr.Response(*resp) + return nil, fmt.Errorf("failed to fetch metadata: unexpected status code %d", resp.StatusCode) } data, err := io.ReadAll(resp.Body) diff --git a/samlsp/middleware.go b/samlsp/middleware.go index 4392fbf3..e9e490ce 100644 --- a/samlsp/middleware.go +++ b/samlsp/middleware.go @@ -40,12 +40,13 @@ import ( // SAML service provider already has a private key, we borrow that key // to sign the JWTs as 
well. type Middleware struct { - ServiceProvider saml.ServiceProvider - OnError func(w http.ResponseWriter, r *http.Request, err error) - Binding string // either saml.HTTPPostBinding or saml.HTTPRedirectBinding - ResponseBinding string // either saml.HTTPPostBinding or saml.HTTPArtifactBinding - RequestTracker RequestTracker - Session SessionProvider + ServiceProvider saml.ServiceProvider + OnError func(w http.ResponseWriter, r *http.Request, err error) + Binding string // either saml.HTTPPostBinding or saml.HTTPRedirectBinding + ResponseBinding string // either saml.HTTPPostBinding or saml.HTTPArtifactBinding + RequestTracker RequestTracker + Session SessionProvider + AssertionHandler AssertionHandler } // ServeHTTP implements http.Handler and serves the SAML-specific HTTP endpoints @@ -99,6 +100,11 @@ func (m *Middleware) ServeACS(w http.ResponseWriter, r *http.Request) { return } + if handlerErr := m.AssertionHandler.HandleAssertion(assertion); handlerErr != nil { + m.OnError(w, r, handlerErr) + return + } + m.CreateSessionFromAssertion(w, r, assertion, m.ServiceProvider.DefaultRedirectURI) } @@ -226,11 +232,6 @@ func (m *Middleware) CreateSessionFromAssertion(w http.ResponseWriter, r *http.R // SAML attribute `name` be set to `value`. This can be used to require // that a remote user be a member of a group. It relies on the Claims assigned // to to the context in RequireAccount. 
-// -// For example: -// -// goji.Use(m.RequireAccount) -// goji.Use(RequireAttributeMiddleware("eduPersonAffiliation", "Staff")) func RequireAttribute(name, value string) func(http.Handler) http.Handler { return func(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/samlsp/middleware_test.go b/samlsp/middleware_test.go index 14d32e45..6bc32fff 100644 --- a/samlsp/middleware_test.go +++ b/samlsp/middleware_test.go @@ -17,7 +17,6 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt/v4" dsig "github.com/russellhaering/goxmldsig" "gotest.tools/assert" is "gotest.tools/assert/cmp" @@ -55,7 +54,6 @@ func NewMiddlewareTest(t *testing.T) *MiddlewareTest { rv, _ := time.Parse("Mon Jan 2 15:04:05.999999999 MST 2006", "Mon Dec 1 01:57:09.123456789 UTC 2015") return rv } - jwt.TimeFunc = saml.TimeNow saml.Clock = dsig.NewFakeClockAt(saml.TimeNow()) saml.RandReader = &testRandomReader{} @@ -150,7 +148,7 @@ func TestMiddlewareRequireAccountNoCreds(t *testing.T) { test.Middleware.ServiceProvider.AcsURL.Scheme = "http" handler := test.Middleware.RequireAccount( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") })) @@ -174,7 +172,7 @@ func TestMiddlewareRequireAccountNoCredsSecure(t *testing.T) { test := NewMiddlewareTest(t) handler := test.Middleware.RequireAccount( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") })) @@ -200,7 +198,7 @@ func TestMiddlewareRequireAccountNoCredsPostBinding(t *testing.T) { test.Middleware.ServiceProvider.GetSSOBindingLocation(saml.HTTPRedirectBinding))) handler := test.Middleware.RequireAccount( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") })) @@ -256,7 +254,7 @@ 
func TestMiddlewareRequireAccountCreds(t *testing.T) { func TestMiddlewareRequireAccountBadCreds(t *testing.T) { test := NewMiddlewareTest(t) handler := test.Middleware.RequireAccount( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") })) @@ -281,13 +279,13 @@ func TestMiddlewareRequireAccountBadCreds(t *testing.T) { func TestMiddlewareRequireAccountExpiredCreds(t *testing.T) { test := NewMiddlewareTest(t) - jwt.TimeFunc = func() time.Time { + saml.TimeNow = func() time.Time { rv, _ := time.Parse("Mon Jan 2 15:04:05 UTC 2006", "Mon Dec 1 01:31:21 UTC 2115") return rv } handler := test.Middleware.RequireAccount( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") })) @@ -306,13 +304,13 @@ func TestMiddlewareRequireAccountExpiredCreds(t *testing.T) { assert.Check(t, err) decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL) assert.Check(t, err) - golden.Assert(t, string(decodedRequest), "expected_authn_request_secure.xml") + golden.Assert(t, strings.Replace(string(decodedRequest), `IssueInstant="2115-12-01T01:31:21Z"`, `IssueInstant="2015-12-01T01:57:09.123Z"`, 1), "expected_authn_request_secure.xml") } func TestMiddlewareRequireAccountPanicOnRequestToACS(t *testing.T) { test := NewMiddlewareTest(t) handler := test.Middleware.RequireAccount( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") })) @@ -326,7 +324,7 @@ func TestMiddlewareRequireAttribute(t *testing.T) { test := NewMiddlewareTest(t) handler := test.Middleware.RequireAccount( RequireAttribute("eduPersonAffiliation", "Staff")( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusTeapot) }))) 
@@ -344,7 +342,7 @@ func TestMiddlewareRequireAttributeWrongValue(t *testing.T) { test := NewMiddlewareTest(t) handler := test.Middleware.RequireAccount( RequireAttribute("eduPersonAffiliation", "DomainAdmins")( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") }))) @@ -362,7 +360,7 @@ func TestMiddlewareRequireAttributeNotPresent(t *testing.T) { test := NewMiddlewareTest(t) handler := test.Middleware.RequireAccount( RequireAttribute("valueThatDoesntExist", "doesntMatter")( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") }))) @@ -379,7 +377,7 @@ func TestMiddlewareRequireAttributeNotPresent(t *testing.T) { func TestMiddlewareRequireAttributeMissingAccount(t *testing.T) { test := NewMiddlewareTest(t) handler := RequireAttribute("eduPersonAffiliation", "DomainAdmins")( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") })) @@ -409,9 +407,10 @@ func TestMiddlewareCanParseResponse(t *testing.T) { assert.Check(t, is.Equal("/frob", resp.Header().Get("Location"))) assert.Check(t, is.DeepEqual([]string{ - "saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6=; Domain=15661444.ngrok.io; Expires=Thu, 01 Jan 1970 00:00:01 GMT", + "saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6=; Path=/saml2/acs; Domain=15661444.ngrok.io; Expires=Thu, 01 Jan 1970 00:00:01 GMT", "ttt=" + test.expectedSessionCookie + "; " + - "Path=/; Domain=15661444.ngrok.io; Max-Age=7200; HttpOnly; Secure"}, + "Path=/; Domain=15661444.ngrok.io; Max-Age=7200; HttpOnly; Secure", + }, resp.Header()["Set-Cookie"])) } @@ -455,7 +454,7 @@ func TestMiddlewareDefaultCookieDomainIPv6(t *testing.T) { func TestMiddlewareRejectsInvalidRelayState(t *testing.T) { test := NewMiddlewareTest(t) - 
test.Middleware.OnError = func(w http.ResponseWriter, r *http.Request, err error) { + test.Middleware.OnError = func(w http.ResponseWriter, _ *http.Request, err error) { assert.Check(t, is.Error(err, http.ErrNoCookie.Error())) http.Error(w, "forbidden", http.StatusTeapot) } @@ -478,7 +477,7 @@ func TestMiddlewareRejectsInvalidRelayState(t *testing.T) { func TestMiddlewareRejectsInvalidCookie(t *testing.T) { test := NewMiddlewareTest(t) - test.Middleware.OnError = func(w http.ResponseWriter, r *http.Request, err error) { + test.Middleware.OnError = func(w http.ResponseWriter, _ *http.Request, err error) { assert.Check(t, is.Error(err, "Authentication failed")) http.Error(w, "forbidden", http.StatusTeapot) } diff --git a/samlsp/new.go b/samlsp/new.go index ec8934c6..39827e0c 100644 --- a/samlsp/new.go +++ b/samlsp/new.go @@ -2,11 +2,15 @@ package samlsp import ( + "crypto" + "crypto/ecdsa" "crypto/rsa" "crypto/x509" + "fmt" "net/http" "net/url" + "github.com/golang-jwt/jwt/v5" dsig "github.com/russellhaering/goxmldsig" "github.com/clerk/saml" @@ -16,7 +20,7 @@ import ( type Options struct { EntityID string URL url.URL - Key *rsa.PrivateKey + Key crypto.Signer Certificate *x509.Certificate Intermediates []*x509.Certificate HTTPClient *http.Client @@ -33,11 +37,23 @@ type Options struct { LogoutBindings []string } +func getDefaultSigningMethod(signer crypto.Signer) jwt.SigningMethod { + if signer != nil { + switch signer.Public().(type) { + case *ecdsa.PublicKey: + return jwt.SigningMethodES256 + case *rsa.PublicKey: + return jwt.SigningMethodRS256 + } + } + return jwt.SigningMethodRS256 +} + // DefaultSessionCodec returns the default SessionCodec for the provided options, // a JWTSessionCodec configured to issue signed tokens. 
func DefaultSessionCodec(opts Options) JWTSessionCodec { return JWTSessionCodec{ - SigningMethod: defaultJWTSigningMethod, + SigningMethod: getDefaultSigningMethod(opts.Key), Audience: opts.URL.String(), Issuer: opts.URL.String(), MaxAge: defaultSessionMaxAge, @@ -67,7 +83,7 @@ func DefaultSessionProvider(opts Options) CookieSessionProvider { // options, a JWTTrackedRequestCodec that uses a JWT to encode TrackedRequests. func DefaultTrackedRequestCodec(opts Options) JWTTrackedRequestCodec { return JWTTrackedRequestCodec{ - SigningMethod: defaultJWTSigningMethod, + SigningMethod: getDefaultSigningMethod(opts.Key), Audience: opts.URL.String(), Issuer: opts.URL.String(), MaxAge: saml.MaxIssueDelay, @@ -99,7 +115,8 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider { if opts.ForceAuthn { forceAuthn = &opts.ForceAuthn } - signatureMethod := dsig.RSASHA1SignatureMethod + + signatureMethod := defaultSigningMethodForKey(opts.Key) if !opts.SignRequest { signatureMethod = "" } @@ -131,6 +148,25 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider { } } +func defaultSigningMethodForKey(key crypto.Signer) string { + switch key.(type) { + case *rsa.PrivateKey: + return dsig.RSASHA1SignatureMethod + case *ecdsa.PrivateKey: + return dsig.ECDSASHA256SignatureMethod + case nil: + return "" + default: + panic(fmt.Sprintf("programming error: unsupported key type %T", key)) + } +} + +// DefaultAssertionHandler returns the default AssertionHandler for the provided options, +// a NopAssertionHandler configured to do nothing. +func DefaultAssertionHandler(_ Options) NopAssertionHandler { + return NopAssertionHandler{} +} + // New creates a new Middleware with the default providers for the // given options. // @@ -139,11 +175,12 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider { // in the returned Middleware. 
func New(opts Options) (*Middleware, error) { m := &Middleware{ - ServiceProvider: DefaultServiceProvider(opts), - Binding: "", - ResponseBinding: saml.HTTPPostBinding, - OnError: DefaultOnError, - Session: DefaultSessionProvider(opts), + ServiceProvider: DefaultServiceProvider(opts), + Binding: "", + ResponseBinding: saml.HTTPPostBinding, + OnError: DefaultOnError, + Session: DefaultSessionProvider(opts), + AssertionHandler: DefaultAssertionHandler(opts), } m.RequestTracker = DefaultRequestTracker(opts, &m.ServiceProvider) if opts.UseArtifactResponse { diff --git a/samlsp/new_test.go b/samlsp/new_test.go index 7f0a760a..86eb49ee 100644 --- a/samlsp/new_test.go +++ b/samlsp/new_test.go @@ -3,7 +3,6 @@ package samlsp import ( "testing" - "github.com/stretchr/testify/require" "gotest.tools/assert" ) @@ -24,7 +23,7 @@ func TestNewCanAcceptCookieName(t *testing.T) { CookieName: tc.cookieName, } sp, err := New(opts) - require.Nil(t, err) + assert.Assert(t, err) cookieProvider := sp.Session.(CookieSessionProvider) assert.Equal(t, tc.expected, cookieProvider.Name) diff --git a/samlsp/request_tracker_cookie.go b/samlsp/request_tracker_cookie.go index 9a420c20..07dab43b 100644 --- a/samlsp/request_tracker_cookie.go +++ b/samlsp/request_tracker_cookie.go @@ -67,6 +67,7 @@ func (t CookieRequestTracker) StopTrackingRequest(w http.ResponseWriter, r *http cookie.Value = "" cookie.Domain = t.ServiceProvider.AcsURL.Hostname() cookie.Expires = time.Unix(1, 0) // past time as close to epoch as possible, but not zero time.Time{} + cookie.Path = t.ServiceProvider.AcsURL.Path http.SetCookie(w, cookie) return nil } diff --git a/samlsp/request_tracker_jwt.go b/samlsp/request_tracker_jwt.go index 906702c0..6ba4616a 100644 --- a/samlsp/request_tracker_jwt.go +++ b/samlsp/request_tracker_jwt.go @@ -1,24 +1,22 @@ package samlsp import ( - "crypto/rsa" + "crypto" "fmt" "time" - "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v5" "github.com/clerk/saml" ) -var 
defaultJWTSigningMethod = jwt.SigningMethodRS256 - // JWTTrackedRequestCodec encodes TrackedRequests as signed JWTs type JWTTrackedRequestCodec struct { SigningMethod jwt.SigningMethod Audience string Issuer string MaxAge time.Duration - Key *rsa.PrivateKey + Key crypto.Signer } var _ TrackedRequestCodec = JWTTrackedRequestCodec{} @@ -51,9 +49,12 @@ func (s JWTTrackedRequestCodec) Encode(value TrackedRequest) (string, error) { // Decode returns a Tracked request from an encoded string. func (s JWTTrackedRequestCodec) Decode(signed string) (*TrackedRequest, error) { - parser := jwt.Parser{ - ValidMethods: []string{s.SigningMethod.Alg()}, - } + parser := jwt.NewParser( + jwt.WithValidMethods([]string{s.SigningMethod.Alg()}), + jwt.WithTimeFunc(saml.TimeNow), + jwt.WithAudience(s.Audience), + jwt.WithIssuer(s.Issuer), + ) claims := JWTTrackedRequestClaims{} _, err := parser.ParseWithClaims(signed, &claims, func(*jwt.Token) (interface{}, error) { return s.Key.Public(), nil @@ -61,15 +62,9 @@ func (s JWTTrackedRequestCodec) Decode(signed string) (*TrackedRequest, error) { if err != nil { return nil, err } - if !claims.VerifyAudience(s.Audience, true) { - return nil, fmt.Errorf("expected audience %q, got %q", s.Audience, claims.Audience) - } - if !claims.VerifyIssuer(s.Issuer, true) { - return nil, fmt.Errorf("expected issuer %q, got %q", s.Issuer, claims.Issuer) - } if !claims.SAMLAuthnRequest { return nil, fmt.Errorf("expected saml-authn-request") } - claims.TrackedRequest.Index = claims.Subject + claims.Index = claims.Subject return &claims.TrackedRequest, nil } diff --git a/samlsp/saml_assertion_handler.go b/samlsp/saml_assertion_handler.go new file mode 100644 index 00000000..a18f00d1 --- /dev/null +++ b/samlsp/saml_assertion_handler.go @@ -0,0 +1,9 @@ +package samlsp + +import "github.com/clerk/saml" + +// AssertionHandler is an interface implemented by types that can handle +// assertions and add extra functionality +type AssertionHandler interface { + 
HandleAssertion(assertion *saml.Assertion) error +} diff --git a/samlsp/session_cookie.go b/samlsp/session_cookie.go index 17f9be6a..4e44fc67 100644 --- a/samlsp/session_cookie.go +++ b/samlsp/session_cookie.go @@ -21,6 +21,7 @@ type CookieSessionProvider struct { Secure bool SameSite http.SameSite MaxAge time.Duration + Path string Codec SessionCodec } @@ -43,6 +44,11 @@ func (c CookieSessionProvider) CreateSession(w http.ResponseWriter, r *http.Requ return err } + path := c.Path + if path == "" { + path = "/" + } + http.SetCookie(w, &http.Cookie{ Name: c.Name, Domain: c.Domain, @@ -51,7 +57,7 @@ func (c CookieSessionProvider) CreateSession(w http.ResponseWriter, r *http.Requ HttpOnly: c.HTTPOnly, Secure: c.Secure || r.URL.Scheme == "https", SameSite: c.SameSite, - Path: "/", + Path: path, }) return nil } diff --git a/samlsp/session_jwt.go b/samlsp/session_jwt.go index def76908..8760237d 100644 --- a/samlsp/session_jwt.go +++ b/samlsp/session_jwt.go @@ -1,12 +1,11 @@ package samlsp import ( - "crypto/rsa" + "crypto" "errors" - "fmt" "time" - "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v5" "github.com/clerk/saml" ) @@ -23,7 +22,7 @@ type JWTSessionCodec struct { Audience string Issuer string MaxAge time.Duration - Key *rsa.PrivateKey + Key crypto.Signer } var _ SessionCodec = JWTSessionCodec{} @@ -35,11 +34,11 @@ func (c JWTSessionCodec) New(assertion *saml.Assertion) (Session, error) { now := saml.TimeNow() claims := JWTSessionClaims{} claims.SAMLSession = true - claims.Audience = c.Audience + claims.Audience = jwt.ClaimStrings{c.Audience} claims.Issuer = c.Issuer - claims.IssuedAt = now.Unix() - claims.ExpiresAt = now.Add(c.MaxAge).Unix() - claims.NotBefore = now.Unix() + claims.IssuedAt = jwt.NewNumericDate(now) + claims.ExpiresAt = jwt.NewNumericDate(now.Add(c.MaxAge)) + claims.NotBefore = jwt.NewNumericDate(now) if sub := assertion.Subject; sub != nil { if nameID := sub.NameID; nameID != nil { @@ -89,9 +88,12 @@ func (c JWTSessionCodec) 
Encode(s Session) (string, error) { // Decode parses the serialized session that may have been returned by Encode // and returns a Session. func (c JWTSessionCodec) Decode(signed string) (Session, error) { - parser := jwt.Parser{ - ValidMethods: []string{c.SigningMethod.Alg()}, - } + parser := jwt.NewParser( + jwt.WithValidMethods([]string{c.SigningMethod.Alg()}), + jwt.WithTimeFunc(saml.TimeNow), + jwt.WithAudience(c.Audience), + jwt.WithIssuer(c.Issuer), + ) claims := JWTSessionClaims{} _, err := parser.ParseWithClaims(signed, &claims, func(*jwt.Token) (interface{}, error) { return c.Key.Public(), nil @@ -100,12 +102,6 @@ func (c JWTSessionCodec) Decode(signed string) (Session, error) { if err != nil { return nil, err } - if !claims.VerifyAudience(c.Audience, true) { - return nil, fmt.Errorf("expected audience %q, got %q", c.Audience, claims.Audience) - } - if !claims.VerifyIssuer(c.Issuer, true) { - return nil, fmt.Errorf("expected issuer %q, got %q", c.Issuer, claims.Issuer) - } if !claims.SAMLSession { return nil, errors.New("expected saml-session") } @@ -114,7 +110,7 @@ func (c JWTSessionCodec) Decode(signed string) (Session, error) { // JWTSessionClaims represents the JWT claims in the encoded session type JWTSessionClaims struct { - jwt.StandardClaims + jwt.RegisteredClaims Attributes Attributes `json:"attr"` SAMLSession bool `json:"saml-session"` } diff --git a/schema.go b/schema.go index 23cddbca..1a543de2 100644 --- a/schema.go +++ b/schema.go @@ -353,12 +353,15 @@ func (r *ArtifactResolve) Element() *etree.Element { if r.Issuer != nil { el.AddChild(r.Issuer.Element()) } - artifact := etree.NewElement("samlp:Artifact") - artifact.SetText(r.Artifact) - el.AddChild(artifact) if r.Signature != nil { + // ADFS requires that come before . 
+ // ref: https://github.com/crewjam/saml/issues/535 + // ref: https://www.wiktorzychla.com/2017/09/adfs-and-saml2-artifact-binding-woes.html el.AddChild(r.Signature) } + artifact := etree.NewElement("samlp:Artifact") + artifact.SetText(r.Artifact) + el.AddChild(artifact) return el } diff --git a/service_provider.go b/service_provider.go index 6a56ae4c..3b1f4720 100644 --- a/service_provider.go +++ b/service_provider.go @@ -4,7 +4,11 @@ import ( "bytes" "compress/flate" "context" + "crypto" + "crypto/ecdsa" "crypto/rsa" + "crypto/sha256" + "crypto/sha512" "crypto/tls" "crypto/x509" "encoding/base64" @@ -16,6 +20,7 @@ import ( "net/http" "net/url" "regexp" + "strings" "time" "github.com/beevik/etree" @@ -66,8 +71,9 @@ type ServiceProvider struct { // Entity ID is optional - if not specified then MetadataURL will be used EntityID string - // Key is the RSA private key we use to sign requests. - Key *rsa.PrivateKey + // Key is private key we use to sign requests. It must be either an + // *rsa.PrivateKey or an *ecdsa.PrivateKey. + Key crypto.Signer // Certificate is the RSA public part of Key. Certificate *x509.Certificate @@ -91,6 +97,18 @@ type ServiceProvider struct { // IDPMetadata is the metadata from the identity provider. IDPMetadata *EntityDescriptor + // IDPCertificateFingerprint is fingerprint of the idp public certificate. If this field is specified, + // IDPCertificateFingerprintAlgorithm must also be specified, and IDPCertificate must not be specified. + IDPCertificateFingerprint *string + // IDPCertificateFingerprintAlgorithm is fingerprint algorithm used to obtain fingerprint of the idp public + // certificate. + // If this field is specified, IDPCertificateFingerprint must also be specified, and IDPCertificate must not be specified. + IDPCertificateFingerprintAlgorithm *string + + // IDPCertificate to use as idp public certificate. If this field is specified, IDPCertificateFingerprint and + // IDPCertificateFingerprintAlgorithm must not be specified. 
+ IDPCertificate *string + // AuthnNameIDFormat is the format used in the NameIDPolicy for // authentication requests AuthnNameIDFormat NameIDFormat @@ -117,12 +135,30 @@ type ServiceProvider struct { // to verify signatures. SignatureVerifier SignatureVerifier - // SignatureMethod, if non-empty, authentication requests will be signed + // SignatureMethod, if non-empty, authentication requests will be signed. + // + // The method specified here must be consistent with the type of Key. + // + // If Key is *rsa.PrivateKey, then this must be one of dsig.RSASHA1SignatureMethod, + // dsig.RSASHA256SignatureMethod, dsig.RSASHA384SignatureMethod, or + // dsig.RSASHA512SignatureMethod: + // + // If Key is *ecdsa.PrivateKey, then this must be one of dsig.ECDSASHA1SignatureMethod, + // dsig.ECDSASHA256SignatureMethod, dsig.ECDSASHA384SignatureMethod, or + // dsig.ECDSASHA512SignatureMethod. SignatureMethod string // LogoutBindings specify the bindings available for SLO endpoint. If empty, // HTTP-POST binding is used. LogoutBindings []string + + // ValidateAudienceRestriction allows you to override the default audience validation + // for an assertion. If nil, the default audience validation is used. + ValidateAudienceRestriction func(assertion *Assertion) error + + // ValidateRequestID allows you to override the default request ID validation. + // If nil, the default request ID validation is used. 
+ ValidateRequestID func(response Response, possibleRequestIDs []string) error } // MaxIssueDelay is the longest allowed time between when a SAML assertion is @@ -247,28 +283,32 @@ func (sp *ServiceProvider) MakeRedirectAuthenticationRequest(relayState string) // Redirect returns a URL suitable for using the redirect binding with the request func (r *AuthnRequest) Redirect(relayState string, sp *ServiceProvider) (*url.URL, error) { - w := &bytes.Buffer{} - w1 := base64.NewEncoder(base64.StdEncoding, w) - w2, _ := flate.NewWriter(w1, 9) + var requestStr strings.Builder + base64Writer := base64.NewEncoder(base64.StdEncoding, &requestStr) + compressedWriter, _ := flate.NewWriter(base64Writer, 9) doc := etree.NewDocument() doc.SetRoot(r.Element()) - if _, err := doc.WriteTo(w2); err != nil { - panic(err) + if _, err := doc.WriteTo(compressedWriter); err != nil { + return nil, err } - if err := w2.Close(); err != nil { - panic(err) + if err := compressedWriter.Close(); err != nil { + return nil, err } - if err := w1.Close(); err != nil { - panic(err) + if err := base64Writer.Close(); err != nil { + return nil, err + } + + rv, err := url.Parse(r.Destination) + if err != nil { + return nil, err } - rv, _ := url.Parse(r.Destination) // We can't depend on Query().set() as order matters for signing query := rv.RawQuery if len(query) > 0 { - query += "&SAMLRequest=" + url.QueryEscape(w.String()) + query += "&SAMLRequest=" + url.QueryEscape(requestStr.String()) } else { - query += "SAMLRequest=" + url.QueryEscape(w.String()) + query += "SAMLRequest=" + url.QueryEscape(requestStr.String()) } if relayState != "" { @@ -378,6 +418,85 @@ func (sp *ServiceProvider) getIDPSigningCerts() ([]*x509.Certificate, error) { return certs, nil } +func (sp *ServiceProvider) getCertBasedOnFingerprint(el *etree.Element) ([]*x509.Certificate, error) { + x509CertEl := el.FindElement("./Signature/KeyInfo/X509Data/X509Certificate") + if x509CertEl == nil { + return nil, fmt.Errorf("cannot validate 
signature on %s: no certificate present", el.Tag) + } + if len(x509CertEl.Child) != 1 { + return nil, fmt.Errorf("cannot validate signature on %s: x509 cert el child len != 1: %d", el.Tag, len(x509CertEl.Child)) + } + + x509CertElCharData, ok := x509CertEl.Child[0].(*etree.CharData) + if !ok { + return nil, fmt.Errorf("cannot validate signature on %s: x509 cert el first child not char data: %T", el.Tag, x509CertEl.Child[0]) + } + + cert, err := parseCert(x509CertElCharData.Data) + if err != nil { + return nil, fmt.Errorf("cannot validate signature on %s: %w", el.Tag, err) + } + + finP, err := fingerprint(cert, *sp.IDPCertificateFingerprintAlgorithm) + if err != nil { + return nil, fmt.Errorf("cannot validate signature on %s: %w", el.Tag, err) + } + + if *sp.IDPCertificateFingerprint != finP { + return nil, fmt.Errorf("cannot validate signature on %s: fingerprint mismatch", el.Tag) + } + + return []*x509.Certificate{cert}, nil + +} + +func parseCert(x509Data string) (*x509.Certificate, error) { + // cleanup whitespace + regex := regexp.MustCompile(`\s+`) + certStr := regex.ReplaceAllString(x509Data, "") + certBytes, err := base64.StdEncoding.DecodeString(certStr) + if err != nil { + return nil, fmt.Errorf("parse cert, cannot base64 decode cert string: %w", err) + } + + parsedCert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, fmt.Errorf("parse cert, cannot parse certificate: %w", err) + } + + return parsedCert, nil +} + +func fingerprint(cert *x509.Certificate, fingerprintAlgorithm string) (string, error) { + switch fingerprintAlgorithm { + case "http://www.w3.org/2001/04/xmlenc#sha256": + fp := sha256.Sum256(cert.Raw) + return fingerprintFormat(fp[:]) + case "http://www.w3.org/2001/04/xmlenc#sha512": + fp := sha512.Sum512(cert.Raw) + return fingerprintFormat(fp[:]) + default: + return "", fmt.Errorf("fingerprint, unknown algorithm: %s", fingerprintAlgorithm) + } +} + +func fingerprintFormat(fp []byte) (string, error) { + var buf 
bytes.Buffer + for i, f := range fp { + if i > 0 { + _, err := fmt.Fprintf(&buf, ":") + if err != nil { + return "", fmt.Errorf("fingerprint format, print ':': %w", err) + } + } + _, err := fmt.Fprintf(&buf, "%02X", f) + if err != nil { + return "", fmt.Errorf("fingerprint format, print bytes: %w", err) + } + } + return buf.String(), nil +} + // MakeArtifactResolveRequest produces a new ArtifactResolve object to send to the idp's Artifact resolver func (sp *ServiceProvider) MakeArtifactResolveRequest(artifactID string) (*ArtifactResolve, error) { req := ArtifactResolve{ @@ -447,17 +566,38 @@ func GetSigningContext(sp *ServiceProvider) (*dsig.SigningContext, error) { // for _, cert := range sp.Intermediates { // keyPair.Certificate = append(keyPair.Certificate, cert.Raw) // } - keyStore := dsig.TLSCertKeyStore(keyPair) - if sp.SignatureMethod != dsig.RSASHA1SignatureMethod && - sp.SignatureMethod != dsig.RSASHA256SignatureMethod && - sp.SignatureMethod != dsig.RSASHA512SignatureMethod { + switch sp.SignatureMethod { + case dsig.RSASHA1SignatureMethod, + dsig.RSASHA256SignatureMethod, + dsig.RSASHA384SignatureMethod, + dsig.RSASHA512SignatureMethod: + if _, ok := sp.Key.(*rsa.PrivateKey); !ok { + return nil, fmt.Errorf("signature method %s requires a key of type rsa.PrivateKey, not %T", sp.SignatureMethod, sp.Key) + } + + case dsig.ECDSASHA1SignatureMethod, + dsig.ECDSASHA256SignatureMethod, + dsig.ECDSASHA384SignatureMethod, + dsig.ECDSASHA512SignatureMethod: + if _, ok := sp.Key.(*ecdsa.PrivateKey); !ok { + return nil, fmt.Errorf("signature method %s requires a key of type ecdsa.PrivateKey, not %T", sp.SignatureMethod, sp.Key) + } + default: return nil, fmt.Errorf("invalid signing method %s", sp.SignatureMethod) } - signatureMethod := sp.SignatureMethod - signingContext := dsig.NewDefaultSigningContext(keyStore) + + keyStore := dsig.TLSCertKeyStore(keyPair) + chain, err := keyStore.GetChain() + if err != nil { + return nil, err + } + signingContext, err := 
dsig.NewSigningContext(sp.Key, chain) + if err != nil { + return nil, err + } signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) - if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { + if err := signingContext.SetSignatureMethod(sp.SignatureMethod); err != nil { return nil, err } @@ -652,7 +792,7 @@ func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID retErr.PrivateErr = fmt.Errorf("Error during artifact resolution: %s", err) return nil, retErr } - assertion, err := sp.ParseXMLArtifactResponse(responseBody, possibleRequestIDs, artifactResolveRequest.ID) + assertion, err := sp.ParseXMLArtifactResponse(responseBody, possibleRequestIDs, artifactResolveRequest.ID, *req.URL) if err != nil { return nil, err } @@ -670,7 +810,7 @@ func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, possibleRequestI return nil, retErr } - assertion, err := sp.ParseXMLResponse(rawResponseBuf, possibleRequestIDs) + assertion, err := sp.ParseXMLResponse(rawResponseBuf, possibleRequestIDs, *req.URL) if err != nil { return nil, err } @@ -687,7 +827,7 @@ func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, possibleRequestI // properties are useful in describing which part of the parsing process // failed. However, to discourage inadvertent disclosure the diagnostic // information, the Error() method returns a static string. 
-func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, possibleRequestIDs []string, artifactRequestID string) (*Assertion, error) { +func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, possibleRequestIDs []string, artifactRequestID string, currentURL url.URL) (*Assertion, error) { now := TimeNow() retErr := &InvalidResponseError{ Response: string(soapResponseXML), @@ -727,10 +867,10 @@ func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, poss return nil, retErr } - return sp.parseArtifactResponse(artifactResponseEl, possibleRequestIDs, artifactRequestID, now) + return sp.parseArtifactResponse(artifactResponseEl, possibleRequestIDs, artifactRequestID, now, currentURL) } -func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Element, possibleRequestIDs []string, artifactRequestID string, now time.Time) (*Assertion, error) { +func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Element, possibleRequestIDs []string, artifactRequestID string, now time.Time, currentURL url.URL) (*Assertion, error) { retErr := &InvalidResponseError{ Now: now, Response: elementToString(artifactResponseEl), @@ -778,7 +918,7 @@ func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Eleme return nil, retErr } - assertion, err := sp.parseResponse(responseEl, possibleRequestIDs, now, signatureRequirement) + assertion, err := sp.parseResponse(responseEl, possibleRequestIDs, now, signatureRequirement, currentURL) if err != nil { retErr.PrivateErr = err return nil, retErr @@ -798,7 +938,7 @@ func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Eleme // properties are useful in describing which part of the parsing process // failed. However, to discourage inadvertent disclosure the diagnostic // information, the Error() method returns a static string. 
-func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleRequestIDs []string) (*Assertion, error) { +func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleRequestIDs []string, currentURL url.URL) (*Assertion, error) { now := TimeNow() var err error retErr := &InvalidResponseError{ @@ -822,7 +962,7 @@ func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleR return nil, retErr } - assertion, err := sp.parseResponse(doc.Root(), possibleRequestIDs, now, signatureRequired) + assertion, err := sp.parseResponse(doc.Root(), possibleRequestIDs, now, signatureRequired, currentURL) if err != nil { retErr.PrivateErr = err return nil, retErr @@ -844,7 +984,7 @@ const ( // This function handles decrypting the message, verifying the digital // signature on the assertion, and verifying that the specified conditions // and properties are met. -func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) { +func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement, currentURL url.URL) (*Assertion, error) { var responseSignatureErr error var responseHasSignature bool if signatureRequirement == signatureRequired { @@ -867,23 +1007,16 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ // If the response is *not* signed, the Destination may be omitted. if responseHasSignature || response.Destination != "" { - if response.Destination != sp.AcsURL.String() { - return nil, fmt.Errorf("`Destination` does not match AcsURL (expected %q, actual %q)", sp.AcsURL.String(), response.Destination) + // Per section 3.4.5.2 of the SAML spec, Destination must match the location at which the response was received, i.e. currentURL. 
+ // Historically, we checked against the SP's ACS URL instead of currentURL, which is usually the same but may differ in query params. + // To mitigate the risk of switching to comparing against currentURL, we still allow it if the ACS URL matches, even if the current URL doesn't. + if response.Destination != currentURL.String() && response.Destination != sp.AcsURL.String() { + return nil, fmt.Errorf("`Destination` does not match requested URL or AcsURL (destination %q, requested %q, acs %q)", response.Destination, currentURL.String(), sp.AcsURL.String()) } } - requestIDvalid := false - if sp.AllowIDPInitiated { - requestIDvalid = true - } else { - for _, possibleRequestID := range possibleRequestIDs { - if response.InResponseTo == possibleRequestID { - requestIDvalid = true - } - } - } - if !requestIDvalid { - return nil, fmt.Errorf("`InResponseTo` does not match any of the possible request IDs (expected %v)", possibleRequestIDs) + if err := sp.validateRequestID(response, possibleRequestIDs); err != nil { + return nil, err } if response.IssueInstant.Add(MaxIssueDelay).Before(now) { @@ -959,6 +1092,27 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ return &assertions[0], nil } +func (sp *ServiceProvider) validateRequestID(response Response, possibleRequestIDs []string) error { + if sp.ValidateRequestID != nil { + return sp.ValidateRequestID(response, possibleRequestIDs) + } + + requestIDvalid := false + if sp.AllowIDPInitiated { + requestIDvalid = true + } else { + for _, possibleRequestID := range possibleRequestIDs { + if response.InResponseTo == possibleRequestID { + requestIDvalid = true + } + } + } + if !requestIDvalid { + return fmt.Errorf("`InResponseTo` does not match any of the possible request IDs (expected %v)", possibleRequestIDs) + } + return nil +} + func (sp *ServiceProvider) parseEncryptedAssertion(encryptedAssertionEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement 
signatureRequirement) (*Assertion, error) { assertionEl, err := sp.decryptElement(encryptedAssertionEl) if err != nil { @@ -1076,6 +1230,20 @@ func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleReque return fmt.Errorf("assertion Conditions is expired") } + if err := sp.validateAudienceRestriction(assertion); err != nil { + return err + } + return nil +} + +func (sp *ServiceProvider) validateAudienceRestriction(assertion *Assertion) error { + if sp.ValidateAudienceRestriction != nil { + if err := sp.ValidateAudienceRestriction(assertion); err != nil { + return fmt.Errorf("audience restriction validation failed: %w", err) + } + return nil + } + audienceRestrictionsValid := len(assertion.Conditions.AudienceRestrictions) == 0 audience := firstSet(sp.EntityID, sp.MetadataURL.String()) for _, audienceRestriction := range assertion.Conditions.AudienceRestrictions { @@ -1101,9 +1269,28 @@ func (sp *ServiceProvider) validateSignature(el *etree.Element) error { return errSignatureElementNotPresent } - certs, err := sp.getIDPSigningCerts() - if err != nil { - return fmt.Errorf("cannot validate signature on %s: %v", el.Tag, err) + var certs []*x509.Certificate + if sp.IDPMetadata != nil && sp.IDPCertificateFingerprint == nil && sp.IDPCertificateFingerprintAlgorithm == nil && sp.IDPCertificate == nil { + certs, err = sp.getIDPSigningCerts() + if err != nil { + return fmt.Errorf("cannot validate signature on %s: %v", el.Tag, err) + } + } + if sp.IDPMetadata != nil && sp.IDPCertificateFingerprint != nil && sp.IDPCertificateFingerprintAlgorithm != nil && sp.IDPCertificate == nil { + certs, err = sp.getCertBasedOnFingerprint(el) + if err != nil { + return fmt.Errorf("cannot validate signature on %s: %v", el.Tag, err) + } + } + if sp.IDPMetadata != nil && sp.IDPCertificateFingerprint == nil && sp.IDPCertificateFingerprintAlgorithm == nil && sp.IDPCertificate != nil { + cert, err := parseCert(*sp.IDPCertificate) + if err != nil { + return fmt.Errorf("cannot 
validate signature on %s: %w", el.Tag, err) + } + certs = append(certs, cert) + } + if len(certs) == 0 { + return fmt.Errorf("cannot validate signature on %s: saml config not set up properly, specify either idp metadata url, fingerprints or actual certificate", el.Tag) } certificateStore := dsig.MemoryX509CertificateStore{ @@ -1159,31 +1346,12 @@ func (sp *ServiceProvider) validateSignature(el *etree.Element) error { // SignLogoutRequest adds the `Signature` element to the `LogoutRequest`. func (sp *ServiceProvider) SignLogoutRequest(req *LogoutRequest) error { - keyPair := tls.Certificate{ - Certificate: [][]byte{sp.Certificate.Raw}, - PrivateKey: sp.Key, - Leaf: sp.Certificate, - } - // TODO: add intermediates for SP - // for _, cert := range sp.Intermediates { - // keyPair.Certificate = append(keyPair.Certificate, cert.Raw) - // } - keyStore := dsig.TLSCertKeyStore(keyPair) - - if sp.SignatureMethod != dsig.RSASHA1SignatureMethod && - sp.SignatureMethod != dsig.RSASHA256SignatureMethod && - sp.SignatureMethod != dsig.RSASHA512SignatureMethod { - return fmt.Errorf("invalid signing method %s", sp.SignatureMethod) - } - signatureMethod := sp.SignatureMethod - signingContext := dsig.NewDefaultSigningContext(keyStore) - signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) - if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { + signingContext, err := GetSigningContext(sp) + if err != nil { return err } assertionEl := req.Element() - signedRequestEl, err := signingContext.SignEnveloped(assertionEl) if err != nil { return err @@ -1213,7 +1381,7 @@ func (sp *ServiceProvider) MakeLogoutRequest(idpURL, nameID string) (*LogoutRequ SPNameQualifier: sp.Metadata().EntityID, }, } - if len(sp.SignatureMethod) > 0 { + if sp.SignatureMethod != "" { if err := sp.SignLogoutRequest(&req); err != nil { return nil, err } @@ -1327,7 +1495,7 @@ func (sp *ServiceProvider) MakeLogoutResponse(idpURL, 
logoutRequestID string) (* }, } - if len(sp.SignatureMethod) > 0 { + if sp.SignatureMethod != "" { if err := sp.SignLogoutResponse(&response); err != nil { return nil, err } @@ -1424,31 +1592,12 @@ func (r *LogoutResponse) Post(relayState string) []byte { // SignLogoutResponse adds the `Signature` element to the `LogoutResponse`. func (sp *ServiceProvider) SignLogoutResponse(resp *LogoutResponse) error { - keyPair := tls.Certificate{ - Certificate: [][]byte{sp.Certificate.Raw}, - PrivateKey: sp.Key, - Leaf: sp.Certificate, - } - // TODO: add intermediates for SP - // for _, cert := range sp.Intermediates { - // keyPair.Certificate = append(keyPair.Certificate, cert.Raw) - // } - keyStore := dsig.TLSCertKeyStore(keyPair) - - if sp.SignatureMethod != dsig.RSASHA1SignatureMethod && - sp.SignatureMethod != dsig.RSASHA256SignatureMethod && - sp.SignatureMethod != dsig.RSASHA512SignatureMethod { - return fmt.Errorf("invalid signing method %s", sp.SignatureMethod) - } - signatureMethod := sp.SignatureMethod - signingContext := dsig.NewDefaultSigningContext(keyStore) - signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) - if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { + signingContext, err := GetSigningContext(sp) + if err != nil { return err } assertionEl := resp.Element() - signedRequestEl, err := signingContext.SignEnveloped(assertionEl) if err != nil { return err @@ -1551,7 +1700,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str } doc := etree.NewDocument() - if err := doc.ReadFromBytes(rawResponseBuf); err != nil { + if err := doc.ReadFromBytes(gr); err != nil { retErr.PrivateErr = err return retErr } @@ -1672,9 +1821,14 @@ func elementToBytes(el *etree.Element) ([]byte, error) { doc := etree.NewDocument() doc.SetRoot(el.Copy()) for space, uri := range namespaces { - doc.Root().CreateAttr("xmlns:"+space, uri) + if space == "" && 
len(doc.Root().SelectAttr("xmlns").Value) == 0 { + doc.Root().CreateAttr("xmlns", uri) + } else { + doc.Root().CreateAttr("xmlns:"+space, uri) + } } - + xmlstr, _ := doc.WriteToString() + fmt.Printf("%s", xmlstr) return doc.WriteToBytes() } diff --git a/service_provider_test.go b/service_provider_test.go index c75f370c..718e8704 100644 --- a/service_provider_test.go +++ b/service_provider_test.go @@ -267,39 +267,53 @@ func TestSPCanProducePostRequest(t *testing.T) { } func TestSPCanProduceSignedRequestRedirectBinding(t *testing.T) { - test := NewServiceProviderTest(t) - TimeNow = func() time.Time { - rv, _ := time.Parse("Mon Jan 2 15:04:05.999999999 UTC 2006", "Mon Dec 1 01:31:21.123456789 UTC 2015") - return rv - } - Clock = dsig.NewFakeClockAt(TimeNow()) - s := ServiceProvider{ - Key: test.Key, - Certificate: test.Certificate, - MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"), - AcsURL: mustParseURL("https://15661444.ngrok.io/saml2/acs"), - IDPMetadata: &EntityDescriptor{}, - SignatureMethod: dsig.RSASHA1SignatureMethod, + for _, alg := range []string{ + dsig.RSASHA1SignatureMethod, + dsig.RSASHA256SignatureMethod, + dsig.RSASHA384SignatureMethod, + dsig.RSASHA512SignatureMethod, + dsig.ECDSASHA1SignatureMethod, + dsig.ECDSASHA256SignatureMethod, + dsig.ECDSASHA384SignatureMethod, + dsig.ECDSASHA512SignatureMethod, + } { + testName := strings.Split(alg, "#")[1] + t.Run(testName, func(t *testing.T) { + test := NewServiceProviderTest(t) + TimeNow = func() time.Time { + rv, _ := time.Parse("Mon Jan 2 15:04:05.999999999 UTC 2006", "Mon Dec 1 01:31:21.123456789 UTC 2015") + return rv + } + Clock = dsig.NewFakeClockAt(TimeNow()) + s := ServiceProvider{ + Key: test.Key, + Certificate: test.Certificate, + MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"), + AcsURL: mustParseURL("https://15661444.ngrok.io/saml2/acs"), + IDPMetadata: &EntityDescriptor{}, + SignatureMethod: dsig.RSASHA1SignatureMethod, + } + err := 
xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) + assert.Check(t, err) + + redirectURL, err := s.MakeRedirectAuthenticationRequest("relayState") + assert.Assert(t, err) + // Signature we check against in the query string was validated with + // https://www.samltool.com/validate_authn_req.php . Once we add + // support for validating signed AuthN requests in the IDP implementation + // we can switch to testing using that. + golden.Assert(t, redirectURL.RawQuery, t.Name()+"_queryString") + + decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL) + assert.Check(t, err) + assert.Check(t, is.Equal("idp.testshib.org", + redirectURL.Host)) + assert.Check(t, is.Equal("/idp/profile/SAML2/Redirect/SSO", + redirectURL.Path)) + // Contains no enveloped signature + golden.Assert(t, string(decodedRequest), t.Name()+"_decodedRequest") + }) } - err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) - assert.Check(t, err) - - redirectURL, err := s.MakeRedirectAuthenticationRequest("relayState") - assert.Check(t, err) - // Signature we check against in the query string was validated with - // https://www.samltool.com/validate_authn_req.php . Once we add - // support for validating signed AuthN requests in the IDP implementation - // we can switch to testing using that. 
- golden.Assert(t, redirectURL.RawQuery, t.Name()+"_queryString") - - decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL) - assert.Check(t, err) - assert.Check(t, is.Equal("idp.testshib.org", - redirectURL.Host)) - assert.Check(t, is.Equal("/idp/profile/SAML2/Redirect/SSO", - redirectURL.Path)) - // Contains no enveloped signature - golden.Assert(t, string(decodedRequest), t.Name()+"_decodedRequest") } func TestSPCanProduceSignedRequestPostBinding(t *testing.T) { @@ -471,7 +485,7 @@ func TestSPCanHandleOneloginResponse(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"id-d40c15c104b52691eccf0a2a5c8a15595be75423"}) assert.Check(t, err) @@ -552,7 +566,7 @@ func TestSPCanHandleOktaSignedResponseEncryptedAssertion(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"id-a7364d1e4432aa9085a7a8bd824ea2fa8fa8f684"}) assert.Check(t, err) @@ -593,7 +607,7 @@ func TestSPCanHandleOktaResponseEncryptedSignedAssertion(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"id-6d976cdde8e76df5df0a8ff58148fc0b7ec6796d"}) assert.Check(t, err) @@ -634,7 +648,7 @@ func TestSPCanHandleOktaResponseEncryptedAssertionBothSigned(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := 
http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"id-953d4cab69ff475c5901d12e585b0bb15a7b85fe"}) assert.Check(t, err) @@ -675,7 +689,7 @@ func TestSPCanHandlePlaintextResponse(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"id-fd419a5ab0472645427f8e07d87a3a5dd0b2e9a6"}) assert.Check(t, err) @@ -739,7 +753,7 @@ func TestSPRejectsInjectedComment(t *testing.T) { // this is a valid response { - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"id-fd419a5ab0472645427f8e07d87a3a5dd0b2e9a6"}) assert.Check(t, err) @@ -752,7 +766,7 @@ func TestSPRejectsInjectedComment(t *testing.T) { y := strings.Replace(string(x), "ross@octolabs.io", "ross@octolabs.io", 1) SamlResponse = []byte(base64.StdEncoding.EncodeToString([]byte(y))) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"id-fd419a5ab0472645427f8e07d87a3a5dd0b2e9a6"}) @@ -774,7 +788,7 @@ func TestSPRejectsInjectedComment(t *testing.T) { y := strings.Replace(string(x), "ross@octolabs.io", "ross@octolabs.io.example.com", 1) SamlResponse = []byte(base64.StdEncoding.EncodeToString([]byte(y))) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) _, err := s.ParseResponse(&req, 
[]string{"id-fd419a5ab0472645427f8e07d87a3a5dd0b2e9a6"}) assert.Check(t, err != nil) @@ -797,7 +811,7 @@ func TestSPCanParseResponse(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, err) @@ -944,7 +958,7 @@ func TestSPCanProcessResponseWithoutDestination(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} test.replaceDestination("") req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) @@ -972,6 +986,13 @@ func removeDestinationFromDocument(doc *etree.Document) *etree.Document { return doc } +func overrideDestinationFromDocument(doc *etree.Document, newDestination string) *etree.Document { + responseEl := doc.FindElement("//Response") + destAttr := responseEl.SelectAttr("Destination") + destAttr.Value = newDestination + return doc +} + func TestServiceProviderMismatchedDestinationsWithSignaturePresent(t *testing.T) { test := NewServiceProviderTest(t) s := ServiceProvider{ @@ -984,13 +1005,61 @@ func TestServiceProviderMismatchedDestinationsWithSignaturePresent(t *testing.T) err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} s.AcsURL = mustParseURL("https://wrong/saml2/acs") bytes, _ := test.responseDom(t).WriteToBytes() req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) _, err = s.ParseResponse(&req, 
[]string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "`Destination` does not match AcsURL (expected \"https://wrong/saml2/acs\", actual \"https://15661444.ngrok.io/saml2/acs\")")) + "`Destination` does not match requested URL or AcsURL (destination \"https://15661444.ngrok.io/saml2/acs\", requested \"https://wrong/saml2/acs\", acs \"https://wrong/saml2/acs\")")) +} + +func TestDestinationMatchesCurrentUrlButNotAcsUrlWithSignaturePresent(t *testing.T) { + test := NewServiceProviderTest(t) + s := ServiceProvider{ + Key: test.Key, + Certificate: test.Certificate, + MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"), + AcsURL: mustParseURL("https://15661444.ngrok.io/saml2/acs"), + IDPMetadata: &EntityDescriptor{}, + } + err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) + assert.Check(t, err) + + currentURL := mustParseURL("https://15661444.ngrok.io/saml2/acs?current=true") + req := http.Request{PostForm: url.Values{}, URL: ¤tURL} + bytes, _ := overrideDestinationFromDocument(test.responseDom(t), "https://15661444.ngrok.io/saml2/acs?current=true").WriteToBytes() + req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) + assertion, err := s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) + if err != nil { + t.Logf("%s", err.(*InvalidResponseError).PrivateErr) + } + assert.Check(t, err) + assert.Check(t, is.Equal("_41bd295976dadd70e1480f318e772841", assertion.Subject.NameID.Value)) +} + +func TestDestinationMatchesAcsUrlButNotCurrentUrlWithSignaturePresent(t *testing.T) { + test := NewServiceProviderTest(t) + s := ServiceProvider{ + Key: test.Key, + Certificate: test.Certificate, + MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"), + AcsURL: mustParseURL("https://15661444.ngrok.io/saml2/acs"), + IDPMetadata: &EntityDescriptor{}, + } + err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) + assert.Check(t, err) + + currentURL 
:= mustParseURL("https://15661444.ngrok.io/saml2/acs?query=param") + req := http.Request{PostForm: url.Values{}, URL: ¤tURL} + bytes, _ := test.responseDom(t).WriteToBytes() + req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) + assertion, err := s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) + if err != nil { + t.Logf("%s", err.(*InvalidResponseError).PrivateErr) + } + assert.Check(t, err) + assert.Check(t, is.Equal("_41bd295976dadd70e1480f318e772841", assertion.Subject.NameID.Value)) } func TestServiceProviderMissingDestinationWithSignaturePresent(t *testing.T) { @@ -1005,12 +1074,12 @@ func TestServiceProviderMissingDestinationWithSignaturePresent(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} bytes, _ := removeDestinationFromDocument(addSignatureToDocument(test.responseDom(t))).WriteToBytes() req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "`Destination` does not match AcsURL (expected \"https://15661444.ngrok.io/saml2/acs\", actual \"\")")) + "`Destination` does not match requested URL or AcsURL (destination \"\", requested \"https://15661444.ngrok.io/saml2/acs\", acs \"https://15661444.ngrok.io/saml2/acs\")")) } func TestSPMismatchedDestinationsWithSignaturePresent(t *testing.T) { @@ -1025,13 +1094,13 @@ func TestSPMismatchedDestinationsWithSignaturePresent(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} test.replaceDestination("https://wrong/saml2/acs") bytes, _ := addSignatureToDocument(test.responseDom(t)).WriteToBytes() 
req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "`Destination` does not match AcsURL (expected \"https://15661444.ngrok.io/saml2/acs\", actual \"https://wrong/saml2/acs\")")) + "`Destination` does not match requested URL or AcsURL (destination \"https://wrong/saml2/acs\", requested \"https://15661444.ngrok.io/saml2/acs\", acs \"https://15661444.ngrok.io/saml2/acs\")")) } func TestSPMismatchedDestinationsWithNoSignaturePresent(t *testing.T) { @@ -1046,13 +1115,13 @@ func TestSPMismatchedDestinationsWithNoSignaturePresent(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} test.replaceDestination("https://wrong/saml2/acs") bytes, _ := test.responseDom(t).WriteToBytes() req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "`Destination` does not match AcsURL (expected \"https://15661444.ngrok.io/saml2/acs\", actual \"https://wrong/saml2/acs\")")) + "`Destination` does not match requested URL or AcsURL (destination \"https://wrong/saml2/acs\", requested \"https://15661444.ngrok.io/saml2/acs\", acs \"https://15661444.ngrok.io/saml2/acs\")")) } func TestSPMissingDestinationWithSignaturePresent(t *testing.T) { @@ -1067,13 +1136,13 @@ func TestSPMissingDestinationWithSignaturePresent(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} test.replaceDestination("") bytes, _ := addSignatureToDocument(test.responseDom(t)).WriteToBytes() 
req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "`Destination` does not match AcsURL (expected \"https://15661444.ngrok.io/saml2/acs\", actual \"\")")) + "`Destination` does not match requested URL or AcsURL (destination \"\", requested \"https://15661444.ngrok.io/saml2/acs\", acs \"https://15661444.ngrok.io/saml2/acs\")")) } func TestSPInvalidAssertions(t *testing.T) { @@ -1187,7 +1256,7 @@ func TestXswPermutationOneIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"id-d40c15c104b52691eccf0a2a5c8a15595be75423"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, @@ -1214,7 +1283,7 @@ func TestXswPermutationTwoIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"id-d40c15c104b52691eccf0a2a5c8a15595be75423"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, @@ -1241,7 +1310,7 @@ func TestXswPermutationThreeIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) @@ -1273,7 +1342,7 @@ func TestXswPermutationFourIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) 
assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) @@ -1303,7 +1372,7 @@ func TestXswPermutationFiveIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, @@ -1330,7 +1399,7 @@ func TestXswPermutationSixIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, @@ -1360,7 +1429,7 @@ func TestXswPermutationSevenIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) // It's the assertion signature that can't be verified. 
The error message is generic and always mentions Response @@ -1391,7 +1460,7 @@ func TestXswPermutationEightIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) // It's the assertion signature that can't be verified. The error message is generic and always mentions Response @@ -1422,7 +1491,7 @@ func TestXswPermutationNineIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) // It's the assertion signature that can't be verified. 
The error message is generic and always mentions Response @@ -1449,7 +1518,7 @@ func TestSPRealWorldKeyInfoHasRSAPublicKeyNotX509Cert(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(respStr)) _, err = s.ParseResponse(&req, []string{"id-3992f74e652d89c3cf1efd6c7e472abaac9bc917"}) if err != nil { @@ -1480,7 +1549,7 @@ func TestSPRealWorldAssertionSignedNotResponse(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(respStr)) _, err = s.ParseResponse(&req, []string{"id-3992f74e652d89c3cf1efd6c7e472abaac9bc917"}) if err != nil { @@ -1519,7 +1588,7 @@ func TestServiceProviderCanHandleSignedAssertionsResponse(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) if err != nil { @@ -1582,7 +1651,7 @@ func TestSPResponseWithNoIssuer(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} // Response with no (modified ServiceProviderTest.SamlResponse) samlResponse := golden.Get(t, "TestSPResponseWithNoIssuer_response") @@ -1695,7 +1764,7 @@ func TestParseXMLArtifactResponse(t *testing.T) { possibleReqIDs := []string{"id-f3c7bc7d626a4ededa6028b718e5252c6e770b94"} reqID := "id-218eb155248f7db7c85fe4e2709a3f17a70d09c7" - 
assertion, err := sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID) + assertion, err := sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID, sp.AcsURL) assert.Check(t, err) x, err := xml.Marshal(assertion) @@ -1727,7 +1796,7 @@ func TestParseBadXMLArtifactResponse(t *testing.T) { IDPMetadata: &EntityDescriptor{}, } - assertion, err := sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID) + assertion, err := sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID, sp.AcsURL) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "response Issuer does not match the IDP metadata (expected \"\")")) assert.Check(t, is.Nil(assertion)) @@ -1735,9 +1804,9 @@ func TestParseBadXMLArtifactResponse(t *testing.T) { err = xml.Unmarshal(test.IDPMetadata, &sp.IDPMetadata) assert.Check(t, err) - assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID) + assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID, sp.AcsURL) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "`Destination` does not match AcsURL (expected \"https://example.com/saml2/acs\", actual \"http://localhost:8000/saml/acs\")")) + "`Destination` does not match requested URL or AcsURL (destination \"http://localhost:8000/saml/acs\", requested \"https://example.com/saml2/acs\", acs \"https://example.com/saml2/acs\")")) assert.Check(t, is.Nil(assertion)) sp.AcsURL = mustParseURL("http://localhost:8000/saml/acs") @@ -1748,7 +1817,7 @@ func TestParseBadXMLArtifactResponse(t *testing.T) { return rv } - assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID) + assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID, sp.AcsURL) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "response IssueInstant expired at 2021-08-17 10:28:50.146 +0000 UTC")) assert.Check(t, is.Nil(assertion)) @@ -1763,38 +1832,38 @@ func 
TestParseBadXMLArtifactResponse(t *testing.T) { return rv } - assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID) + assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID, sp.AcsURL) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "cannot validate signature on ArtifactResponse: Cert is not valid at this time")) assert.Check(t, is.Nil(assertion)) Clock = dsig.NewFakeClockAt(TimeNow()) wrongReqID := "id-218eb155248f7db7c85fe4e2709a3f17a70d09c8" - assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, wrongReqID) + assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, wrongReqID, sp.AcsURL) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "`InResponseTo` does not match the artifact request ID (expected id-218eb155248f7db7c85fe4e2709a3f17a70d09c8)")) assert.Check(t, is.Nil(assertion)) wrongPossibleReqIDs := []string{"id-f3c7bc7d626a4ededa6028b718e5252c6e770b95"} - assertion, err = sp.ParseXMLArtifactResponse(samlResponse, wrongPossibleReqIDs, reqID) + assertion, err = sp.ParseXMLArtifactResponse(samlResponse, wrongPossibleReqIDs, reqID, sp.AcsURL) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "`InResponseTo` does not match any of the possible request IDs (expected [id-f3c7bc7d626a4ededa6028b718e5252c6e770b95])")) assert.Check(t, is.Nil(assertion)) // random other key sp.Key = mustParsePrivateKey(golden.Get(t, "key_2017.pem")).(*rsa.PrivateKey) - assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID) + assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID, sp.AcsURL) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "failed to decrypt EncryptedAssertion: certificate does not match provided key")) assert.Check(t, is.Nil(assertion)) // no input - assertion, err = sp.ParseXMLArtifactResponse([]byte(""), possibleReqIDs, reqID) + assertion, err = 
sp.ParseXMLArtifactResponse([]byte(""), possibleReqIDs, reqID, sp.AcsURL) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "invalid xml: no root")) assert.Check(t, is.Nil(assertion)) - assertion, err = sp.ParseXMLArtifactResponse([]byte(""), []string{}) + assertion, err := sp.ParseXMLResponse([]byte(""), []string{}, mustParseURL("http://test.com")) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "invalid xml: no root")) assert.Check(t, is.Nil(assertion)) - assertion, err = sp.ParseXMLResponse([]byte("
\ No newline at end of file +
\ No newline at end of file diff --git a/testdata/TestIDPCanHandleRequestWithExistingSession_http_response_body b/testdata/TestIDPCanHandleRequestWithExistingSession_http_response_body index 3af5e379..b63309c5 100644 --- a/testdata/TestIDPCanHandleRequestWithExistingSession_http_response_body +++ b/testdata/TestIDPCanHandleRequestWithExistingSession_http_response_body @@ -1 +1 @@ -
\ No newline at end of file +
\ No newline at end of file diff --git a/testdata/TestIDPIDPInitiatedExistingSession_response b/testdata/TestIDPIDPInitiatedExistingSession_response index afd22ef3..731f0fba 100644 --- a/testdata/TestIDPIDPInitiatedExistingSession_response +++ b/testdata/TestIDPIDPInitiatedExistingSession_response @@ -1 +1 @@ -
\ No newline at end of file +
\ No newline at end of file diff --git a/testdata/TestSPCanProduceSignedRequestRedirectBinding_decodedRequest b/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha1_decodedRequest similarity index 100% rename from testdata/TestSPCanProduceSignedRequestRedirectBinding_decodedRequest rename to testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha1_decodedRequest diff --git a/testdata/TestSPCanProduceSignedRequestRedirectBinding_queryString b/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha1_queryString similarity index 100% rename from testdata/TestSPCanProduceSignedRequestRedirectBinding_queryString rename to testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha1_queryString diff --git a/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha256_decodedRequest b/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha256_decodedRequest new file mode 100644 index 00000000..6cae3525 --- /dev/null +++ b/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha256_decodedRequest @@ -0,0 +1 @@ +https://15661444.ngrok.io/saml2/metadata \ No newline at end of file diff --git a/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha256_queryString b/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha256_queryString new file mode 100644 index 00000000..41ddd26c --- /dev/null +++ b/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha256_queryString @@ -0,0 +1 @@ 
+SAMLRequest=nFJNb9swDP0rhu62RNU1CqE2kDUYFqBbgzjbYTfVZhNituSJ9Lb%2B%2B8Fph2WXDOiV4uP70LtlPw6TW81yDDv8PiNL9mscArvloVZzCi56JnbBj8hOOteuPt47WxjnmTEJxaDOINNlzJSixC4OKtusa0V9boyxpjSVuTHedAbBgIUSKrgBDx2gNdba0lYq%2B4KJKYZa2cKobMM84yaw%2BCC1sgauc7C5gb0BdwXOQgH26qvK1shCwcsJeRSZ2GlN%2FVQIsvCRHouYDstATyk%2B0YB60Wr1DntK2Ilu2weVrf5YvYuB5xFTi%2BkHdfh5d%2F%2F3KlxXFZRlWYRDit8KinoJxGrfscq2r8bfUegpHC6n9PiyxO7Dfr%2FNtw%2FtXjWnn3In2yl7H9Po5fKRZUJ9%2FnRadRiE5Fk1%2FxM7ovjei7%2FVZ3zNa00%2B%2BRE3620cqHt%2BgwZJPjBhEJWthiH%2BvEvoBWslaUalmxfKf8vY%2FA4AAP%2F%2F&RelayState=relayState&SigAlg=http%3A%2F%2Fwww.w3.org%2F2000%2F09%2Fxmldsig%23rsa-sha1&Signature=WqMc7vKRJVNXwNHJmTemdfw5OML2XkLntYw%2FzwKoLMfavV%2FYy6fBP0GeGYlJVMweZBvbpjwoe%2BgpRkUCHKDUgixCG7hPi41p6MpQC%2Fp7ExTW5plvlS97iVAOvaF5V1MjvQCgBNKYnKNnvwAuxK%2Bu3N4rZjwGM%2F4JGgjJ5pannFQ%3D \ No newline at end of file diff --git a/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha384_decodedRequest b/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha384_decodedRequest new file mode 100644 index 00000000..6cae3525 --- /dev/null +++ b/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha384_decodedRequest @@ -0,0 +1 @@ +https://15661444.ngrok.io/saml2/metadata \ No newline at end of file diff --git a/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha384_queryString b/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha384_queryString new file mode 100644 index 00000000..41ddd26c --- /dev/null +++ b/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha384_queryString @@ -0,0 +1 @@ 
+SAMLRequest=nFJNb9swDP0rhu62RNU1CqE2kDUYFqBbgzjbYTfVZhNituSJ9Lb%2B%2B8Fph2WXDOiV4uP70LtlPw6TW81yDDv8PiNL9mscArvloVZzCi56JnbBj8hOOteuPt47WxjnmTEJxaDOINNlzJSixC4OKtusa0V9boyxpjSVuTHedAbBgIUSKrgBDx2gNdba0lYq%2B4KJKYZa2cKobMM84yaw%2BCC1sgauc7C5gb0BdwXOQgH26qvK1shCwcsJeRSZ2GlN%2FVQIsvCRHouYDstATyk%2B0YB60Wr1DntK2Ilu2weVrf5YvYuB5xFTi%2BkHdfh5d%2F%2F3KlxXFZRlWYRDit8KinoJxGrfscq2r8bfUegpHC6n9PiyxO7Dfr%2FNtw%2FtXjWnn3In2yl7H9Po5fKRZUJ9%2FnRadRiE5Fk1%2FxM7ovjei7%2FVZ3zNa00%2B%2BRE3620cqHt%2BgwZJPjBhEJWthiH%2BvEvoBWslaUalmxfKf8vY%2FA4AAP%2F%2F&RelayState=relayState&SigAlg=http%3A%2F%2Fwww.w3.org%2F2000%2F09%2Fxmldsig%23rsa-sha1&Signature=WqMc7vKRJVNXwNHJmTemdfw5OML2XkLntYw%2FzwKoLMfavV%2FYy6fBP0GeGYlJVMweZBvbpjwoe%2BgpRkUCHKDUgixCG7hPi41p6MpQC%2Fp7ExTW5plvlS97iVAOvaF5V1MjvQCgBNKYnKNnvwAuxK%2Bu3N4rZjwGM%2F4JGgjJ5pannFQ%3D \ No newline at end of file diff --git a/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha512_decodedRequest b/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha512_decodedRequest new file mode 100644 index 00000000..6cae3525 --- /dev/null +++ b/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha512_decodedRequest @@ -0,0 +1 @@ +https://15661444.ngrok.io/saml2/metadata \ No newline at end of file diff --git a/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha512_queryString b/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha512_queryString new file mode 100644 index 00000000..41ddd26c --- /dev/null +++ b/testdata/TestSPCanProduceSignedRequestRedirectBinding/ecdsa-sha512_queryString @@ -0,0 +1 @@ 
+SAMLRequest=nFJNb9swDP0rhu62RNU1CqE2kDUYFqBbgzjbYTfVZhNituSJ9Lb%2B%2B8Fph2WXDOiV4uP70LtlPw6TW81yDDv8PiNL9mscArvloVZzCi56JnbBj8hOOteuPt47WxjnmTEJxaDOINNlzJSixC4OKtusa0V9boyxpjSVuTHedAbBgIUSKrgBDx2gNdba0lYq%2B4KJKYZa2cKobMM84yaw%2BCC1sgauc7C5gb0BdwXOQgH26qvK1shCwcsJeRSZ2GlN%2FVQIsvCRHouYDstATyk%2B0YB60Wr1DntK2Ilu2weVrf5YvYuB5xFTi%2BkHdfh5d%2F%2F3KlxXFZRlWYRDit8KinoJxGrfscq2r8bfUegpHC6n9PiyxO7Dfr%2FNtw%2FtXjWnn3In2yl7H9Po5fKRZUJ9%2FnRadRiE5Fk1%2FxM7ovjei7%2FVZ3zNa00%2B%2BRE3620cqHt%2BgwZJPjBhEJWthiH%2BvEvoBWslaUalmxfKf8vY%2FA4AAP%2F%2F&RelayState=relayState&SigAlg=http%3A%2F%2Fwww.w3.org%2F2000%2F09%2Fxmldsig%23rsa-sha1&Signature=WqMc7vKRJVNXwNHJmTemdfw5OML2XkLntYw%2FzwKoLMfavV%2FYy6fBP0GeGYlJVMweZBvbpjwoe%2BgpRkUCHKDUgixCG7hPi41p6MpQC%2Fp7ExTW5plvlS97iVAOvaF5V1MjvQCgBNKYnKNnvwAuxK%2Bu3N4rZjwGM%2F4JGgjJ5pannFQ%3D \ No newline at end of file diff --git a/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha1_decodedRequest b/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha1_decodedRequest new file mode 100644 index 00000000..6cae3525 --- /dev/null +++ b/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha1_decodedRequest @@ -0,0 +1 @@ +https://15661444.ngrok.io/saml2/metadata \ No newline at end of file diff --git a/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha1_queryString b/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha1_queryString new file mode 100644 index 00000000..41ddd26c --- /dev/null +++ b/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha1_queryString @@ -0,0 +1 @@ 
+SAMLRequest=nFJNb9swDP0rhu62RNU1CqE2kDUYFqBbgzjbYTfVZhNituSJ9Lb%2B%2B8Fph2WXDOiV4uP70LtlPw6TW81yDDv8PiNL9mscArvloVZzCi56JnbBj8hOOteuPt47WxjnmTEJxaDOINNlzJSixC4OKtusa0V9boyxpjSVuTHedAbBgIUSKrgBDx2gNdba0lYq%2B4KJKYZa2cKobMM84yaw%2BCC1sgauc7C5gb0BdwXOQgH26qvK1shCwcsJeRSZ2GlN%2FVQIsvCRHouYDstATyk%2B0YB60Wr1DntK2Ilu2weVrf5YvYuB5xFTi%2BkHdfh5d%2F%2F3KlxXFZRlWYRDit8KinoJxGrfscq2r8bfUegpHC6n9PiyxO7Dfr%2FNtw%2FtXjWnn3In2yl7H9Po5fKRZUJ9%2FnRadRiE5Fk1%2FxM7ovjei7%2FVZ3zNa00%2B%2BRE3620cqHt%2BgwZJPjBhEJWthiH%2BvEvoBWslaUalmxfKf8vY%2FA4AAP%2F%2F&RelayState=relayState&SigAlg=http%3A%2F%2Fwww.w3.org%2F2000%2F09%2Fxmldsig%23rsa-sha1&Signature=WqMc7vKRJVNXwNHJmTemdfw5OML2XkLntYw%2FzwKoLMfavV%2FYy6fBP0GeGYlJVMweZBvbpjwoe%2BgpRkUCHKDUgixCG7hPi41p6MpQC%2Fp7ExTW5plvlS97iVAOvaF5V1MjvQCgBNKYnKNnvwAuxK%2Bu3N4rZjwGM%2F4JGgjJ5pannFQ%3D \ No newline at end of file diff --git a/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha256_decodedRequest b/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha256_decodedRequest new file mode 100644 index 00000000..6cae3525 --- /dev/null +++ b/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha256_decodedRequest @@ -0,0 +1 @@ +https://15661444.ngrok.io/saml2/metadata \ No newline at end of file diff --git a/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha256_queryString b/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha256_queryString new file mode 100644 index 00000000..41ddd26c --- /dev/null +++ b/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha256_queryString @@ -0,0 +1 @@ 
+SAMLRequest=nFJNb9swDP0rhu62RNU1CqE2kDUYFqBbgzjbYTfVZhNituSJ9Lb%2B%2B8Fph2WXDOiV4uP70LtlPw6TW81yDDv8PiNL9mscArvloVZzCi56JnbBj8hOOteuPt47WxjnmTEJxaDOINNlzJSixC4OKtusa0V9boyxpjSVuTHedAbBgIUSKrgBDx2gNdba0lYq%2B4KJKYZa2cKobMM84yaw%2BCC1sgauc7C5gb0BdwXOQgH26qvK1shCwcsJeRSZ2GlN%2FVQIsvCRHouYDstATyk%2B0YB60Wr1DntK2Ilu2weVrf5YvYuB5xFTi%2BkHdfh5d%2F%2F3KlxXFZRlWYRDit8KinoJxGrfscq2r8bfUegpHC6n9PiyxO7Dfr%2FNtw%2FtXjWnn3In2yl7H9Po5fKRZUJ9%2FnRadRiE5Fk1%2FxM7ovjei7%2FVZ3zNa00%2B%2BRE3620cqHt%2BgwZJPjBhEJWthiH%2BvEvoBWslaUalmxfKf8vY%2FA4AAP%2F%2F&RelayState=relayState&SigAlg=http%3A%2F%2Fwww.w3.org%2F2000%2F09%2Fxmldsig%23rsa-sha1&Signature=WqMc7vKRJVNXwNHJmTemdfw5OML2XkLntYw%2FzwKoLMfavV%2FYy6fBP0GeGYlJVMweZBvbpjwoe%2BgpRkUCHKDUgixCG7hPi41p6MpQC%2Fp7ExTW5plvlS97iVAOvaF5V1MjvQCgBNKYnKNnvwAuxK%2Bu3N4rZjwGM%2F4JGgjJ5pannFQ%3D \ No newline at end of file diff --git a/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha384_decodedRequest b/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha384_decodedRequest new file mode 100644 index 00000000..6cae3525 --- /dev/null +++ b/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha384_decodedRequest @@ -0,0 +1 @@ +https://15661444.ngrok.io/saml2/metadata \ No newline at end of file diff --git a/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha384_queryString b/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha384_queryString new file mode 100644 index 00000000..41ddd26c --- /dev/null +++ b/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha384_queryString @@ -0,0 +1 @@ 
+SAMLRequest=nFJNb9swDP0rhu62RNU1CqE2kDUYFqBbgzjbYTfVZhNituSJ9Lb%2B%2B8Fph2WXDOiV4uP70LtlPw6TW81yDDv8PiNL9mscArvloVZzCi56JnbBj8hOOteuPt47WxjnmTEJxaDOINNlzJSixC4OKtusa0V9boyxpjSVuTHedAbBgIUSKrgBDx2gNdba0lYq%2B4KJKYZa2cKobMM84yaw%2BCC1sgauc7C5gb0BdwXOQgH26qvK1shCwcsJeRSZ2GlN%2FVQIsvCRHouYDstATyk%2B0YB60Wr1DntK2Ilu2weVrf5YvYuB5xFTi%2BkHdfh5d%2F%2F3KlxXFZRlWYRDit8KinoJxGrfscq2r8bfUegpHC6n9PiyxO7Dfr%2FNtw%2FtXjWnn3In2yl7H9Po5fKRZUJ9%2FnRadRiE5Fk1%2FxM7ovjei7%2FVZ3zNa00%2B%2BRE3620cqHt%2BgwZJPjBhEJWthiH%2BvEvoBWslaUalmxfKf8vY%2FA4AAP%2F%2F&RelayState=relayState&SigAlg=http%3A%2F%2Fwww.w3.org%2F2000%2F09%2Fxmldsig%23rsa-sha1&Signature=WqMc7vKRJVNXwNHJmTemdfw5OML2XkLntYw%2FzwKoLMfavV%2FYy6fBP0GeGYlJVMweZBvbpjwoe%2BgpRkUCHKDUgixCG7hPi41p6MpQC%2Fp7ExTW5plvlS97iVAOvaF5V1MjvQCgBNKYnKNnvwAuxK%2Bu3N4rZjwGM%2F4JGgjJ5pannFQ%3D \ No newline at end of file diff --git a/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha512_decodedRequest b/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha512_decodedRequest new file mode 100644 index 00000000..6cae3525 --- /dev/null +++ b/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha512_decodedRequest @@ -0,0 +1 @@ +https://15661444.ngrok.io/saml2/metadata \ No newline at end of file diff --git a/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha512_queryString b/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha512_queryString new file mode 100644 index 00000000..41ddd26c --- /dev/null +++ b/testdata/TestSPCanProduceSignedRequestRedirectBinding/rsa-sha512_queryString @@ -0,0 +1 @@ 
+SAMLRequest=nFJNb9swDP0rhu62RNU1CqE2kDUYFqBbgzjbYTfVZhNituSJ9Lb%2B%2B8Fph2WXDOiV4uP70LtlPw6TW81yDDv8PiNL9mscArvloVZzCi56JnbBj8hOOteuPt47WxjnmTEJxaDOINNlzJSixC4OKtusa0V9boyxpjSVuTHedAbBgIUSKrgBDx2gNdba0lYq%2B4KJKYZa2cKobMM84yaw%2BCC1sgauc7C5gb0BdwXOQgH26qvK1shCwcsJeRSZ2GlN%2FVQIsvCRHouYDstATyk%2B0YB60Wr1DntK2Ilu2weVrf5YvYuB5xFTi%2BkHdfh5d%2F%2F3KlxXFZRlWYRDit8KinoJxGrfscq2r8bfUegpHC6n9PiyxO7Dfr%2FNtw%2FtXjWnn3In2yl7H9Po5fKRZUJ9%2FnRadRiE5Fk1%2FxM7ovjei7%2FVZ3zNa00%2B%2BRE3620cqHt%2BgwZJPjBhEJWthiH%2BvEvoBWslaUalmxfKf8vY%2FA4AAP%2F%2F&RelayState=relayState&SigAlg=http%3A%2F%2Fwww.w3.org%2F2000%2F09%2Fxmldsig%23rsa-sha1&Signature=WqMc7vKRJVNXwNHJmTemdfw5OML2XkLntYw%2FzwKoLMfavV%2FYy6fBP0GeGYlJVMweZBvbpjwoe%2BgpRkUCHKDUgixCG7hPi41p6MpQC%2Fp7ExTW5plvlS97iVAOvaF5V1MjvQCgBNKYnKNnvwAuxK%2Bu3N4rZjwGM%2F4JGgjJ5pannFQ%3D \ No newline at end of file diff --git a/xmlenc/cbc.go b/xmlenc/cbc.go index 991ba1eb..bb0e2882 100644 --- a/xmlenc/cbc.go +++ b/xmlenc/cbc.go @@ -3,7 +3,7 @@ package xmlenc import ( "crypto/aes" "crypto/cipher" - "crypto/des" // nolint: gas + "crypto/des" // nolint: gosec "encoding/base64" "errors" "fmt" diff --git a/xmlenc/decrypt.go b/xmlenc/decrypt.go index 98a575da..ea288b15 100644 --- a/xmlenc/decrypt.go +++ b/xmlenc/decrypt.go @@ -1,8 +1,6 @@ package xmlenc import ( - - // nolint: gas "crypto/rsa" "crypto/x509" "encoding/base64" diff --git a/xmlenc/decrypt_test.go b/xmlenc/decrypt_test.go index a8872f22..3670da50 100644 --- a/xmlenc/decrypt_test.go +++ b/xmlenc/decrypt_test.go @@ -18,6 +18,7 @@ func TestCanDecrypt(t *testing.T) { err := doc.ReadFromBytes(golden.Get(t, "input.xml")) assert.Check(t, err) + //nolint:gosec keyPEM := "-----BEGIN RSA PRIVATE 
KEY-----\nMIICXgIBAAKBgQDU8wdiaFmPfTyRYuFlVPi866WrH/2JubkHzp89bBQopDaLXYxi\n3PTu3O6Q/KaKxMOFBqrInwqpv/omOGZ4ycQ51O9I+Yc7ybVlW94lTo2gpGf+Y/8E\nPsVbnZaFutRctJ4dVIp9aQ2TpLiGT0xX1OzBO/JEgq9GzDRf+B+eqSuglwIDAQAB\nAoGBAMuy1eN6cgFiCOgBsB3gVDdTKpww87Qk5ivjqEt28SmXO13A1KNVPS6oQ8SJ\nCT5Azc6X/BIAoJCURVL+LHdqebogKljhH/3yIel1kH19vr4E2kTM/tYH+qj8afUS\nJEmArUzsmmK8ccuNqBcllqdwCZjxL4CHDUmyRudFcHVX9oyhAkEA/OV1OkjM3CLU\nN3sqELdMmHq5QZCUihBmk3/N5OvGdqAFGBlEeewlepEVxkh7JnaNXAXrKHRVu/f/\nfbCQxH+qrwJBANeQERF97b9Sibp9xgolb749UWNlAdqmEpmlvmS202TdcaaT1msU\n4rRLiQN3X9O9mq4LZMSVethrQAdX1whawpkCQQDk1yGf7xZpMJ8F4U5sN+F4rLyM\nRq8Sy8p2OBTwzCUXXK+fYeXjybsUUMr6VMYTRP2fQr/LKJIX+E5ZxvcIyFmDAkEA\nyfjNVUNVaIbQTzEbRlRvT6MqR+PTCefC072NF9aJWR93JimspGZMR7viY6IM4lrr\nvBkm0F5yXKaYtoiiDMzlOQJADqmEwXl0D72ZG/2KDg8b4QZEmC9i5gidpQwJXUc6\nhU+IVQoLxRq0fBib/36K9tcrrO5Ba4iEvDcNY+D8yGbUtA==\n-----END RSA PRIVATE KEY-----\n" b, _ := pem.Decode([]byte(keyPEM)) key, err := x509.ParsePKCS1PrivateKey(b.Bytes) diff --git a/xmlenc/digest.go b/xmlenc/digest.go index 3eaaf7bc..9a46450a 100644 --- a/xmlenc/digest.go +++ b/xmlenc/digest.go @@ -6,7 +6,7 @@ import ( "crypto/sha512" "hash" - //nolint:staticcheck // We should support this for legacy reasons. + //nolint:staticcheck,gosec // We should support this for legacy reasons.
"golang.org/x/crypto/ripemd160" ) diff --git a/xmlenc/fuzz.go b/xmlenc/fuzz.go index c035d65f..ae23bdf3 100644 --- a/xmlenc/fuzz.go +++ b/xmlenc/fuzz.go @@ -9,6 +9,7 @@ import ( ) var testKey = func() *rsa.PrivateKey { + //nolint:gosec const keyStr = `-----BEGIN RSA PRIVATE KEY----- MIICXQIBAAKBgQDkXTUsWzRVpUHjbDpWCfYDfXmQ/q4LkaioZoTpu4ut1Q3eQC5t gD14agJhgT8yzeY5S/YNlwCyuVkjuFyoyTHFX2IOPpz7jnh4KnQ+B1IH9fY/+kmk diff --git a/xmlenc/pubkey.go b/xmlenc/pubkey.go index 13d4d9e7..f8eae9cb 100644 --- a/xmlenc/pubkey.go +++ b/xmlenc/pubkey.go @@ -125,6 +125,9 @@ func (e RSA) Decrypt(key interface{}, ciphertextEl *etree.Element) ([]byte, erro // the block cipher used is AES-256 CBC and the digest method is SHA-256. You can // specify other ciphers and digest methods by assigning to BlockCipher or // DigestMethod. +// +// OAEP implements the older RSA-OAEP (2001 spec) for backward compatibility; you might +// prefer OAEP_SHA256 over using this method. func OAEP() RSA { return RSA{ BlockCipher: AES256CBC, @@ -139,6 +142,44 @@ func OAEP() RSA { } } +// OAEP_SHA256 returns a version of RSA that implements RSA in OAEP mode. By default +// the block cipher used is AES-256 CBC and the digest method is SHA-256. You can +// specify other ciphers and digest methods by assigning to BlockCipher or +// DigestMethod. +func OAEP_SHA256() RSA { //nolint:revive + return RSA{ + BlockCipher: AES256CBC, + DigestMethod: SHA256, + algorithm: "http://www.w3.org/2009/xmlenc11#rsa-oaep", + + keyEncrypter: func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) { + return rsa.EncryptOAEP(e.DigestMethod.Hash(), RandReader, pubKey, plaintext, nil) + }, + keyDecrypter: func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error) { + return rsa.DecryptOAEP(e.DigestMethod.Hash(), RandReader, privKey, ciphertext, nil) + }, + } +} + +// OAEP_SHA512 returns a version of RSA that implements RSA in OAEP mode. 
By default +// the block cipher used is AES-256 CBC and the digest method is SHA-512. You can +// specify other ciphers and digest methods by assigning to BlockCipher or +// DigestMethod. +func OAEP_SHA512() RSA { //nolint:revive + return RSA{ + BlockCipher: AES256CBC, + DigestMethod: SHA512, + algorithm: "http://www.w3.org/2009/xmlenc11#rsa-oaep", + + keyEncrypter: func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) { + return rsa.EncryptOAEP(e.DigestMethod.Hash(), RandReader, pubKey, plaintext, nil) + }, + keyDecrypter: func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error) { + return rsa.DecryptOAEP(e.DigestMethod.Hash(), RandReader, privKey, ciphertext, nil) + }, + } +} + // PKCS1v15 returns a version of RSA that implements RSA in PKCS1v15 mode. By default // the block cipher used is AES-256 CBC. The DigestMethod field is ignored because PKCS1v15 // does not use a digest function. @@ -147,10 +188,10 @@ func PKCS1v15() RSA { BlockCipher: AES256CBC, DigestMethod: nil, algorithm: "http://www.w3.org/2001/04/xmlenc#rsa-1_5", - keyEncrypter: func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) { + keyEncrypter: func(_ RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) { return rsa.EncryptPKCS1v15(RandReader, pubKey, plaintext) }, - keyDecrypter: func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error) { + keyDecrypter: func(_ RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error) { return rsa.DecryptPKCS1v15(RandReader, privKey, ciphertext) }, }