diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d6b985d1..688090cb 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -10,8 +10,11 @@ jobs: golangci: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: stable - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v7 with: - version: v1.54.2 + version: v2.0 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e6e91273..bca5ba69 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,10 +10,11 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: [ '1.19.x', '1.20.x', '1.21.x'] + go: [ '1.22.x', '1.23.x', '1.24.x'] steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} + - run: go version - run: go test -v ./... diff --git a/.golangci.yml b/.golangci.yml index 23f37cbf..46ed573a 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,72 +1,70 @@ -# Configuration file for golangci-lint -# -# https://github.com/golangci/golangci-lint -# -# fighting with false positives? -# https://github.com/golangci/golangci-lint#nolint - +version: "2" linters: enable: - - bodyclose # checks whether HTTP response body is closed successfully [fast: false, auto-fix: false] - - errcheck # Inspects source code for security problems [fast: true, auto-fix: false] - - gocritic # The most opinionated Go source code linter [fast: true, auto-fix: false] - - gocyclo # Computes and checks the cyclomatic complexity of functions [fast: true, auto-fix: false] - - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification [fast: true, auto-fix: true] - - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports [fast: true, auto-fix: true] - - gosec # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases [fast: true, auto-fix: false] - - gosimple # Linter for Go source code that specializes in simplifying a code [fast: false, auto-fix: false] - - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string [fast: false, auto-fix: false] - - ineffassign # Detects when assignments to existing variables are not used [fast: true, auto-fix: false] - - misspell # Finds commonly misspelled English words in comments [fast: true, auto-fix: true] - - nakedret # Finds naked returns in functions greater than a specified function length [fast: true, auto-fix: false] - - prealloc # Finds slice declarations that could potentially be preallocated [fast: true, auto-fix: false] - - revive # Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes [fast: true, auto-fix: false] - - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks [fast: false, auto-fix: false] - - stylecheck # Stylecheck is a replacement for golint [fast: false, auto-fix: false] - - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code [fast: true, auto-fix: false] - - unconvert # Remove unnecessary type conversions [fast: true, auto-fix: false] - - unparam # Reports unused function parameters [fast: false, auto-fix: false] - - unused # Checks Go code for unused constants, variables, functions and types [fast: false, auto-fix: false] - + - bodyclose + - gocritic + - gocyclo + - gosec + - misspell + - nakedret + - prealloc + - revive + - staticcheck + - unconvert + - unparam disable: # TODO(ross): fix errors reported by these checkers and enable them - - dupl # Tool for code clone detection [fast: true, auto-fix: false] - - gochecknoglobals # Checks that no globals are present in Go code [fast: true, auto-fix: false] - - gochecknoinits # Checks that no init functions are present in Go code [fast: true, auto-fix: false] - - goconst # Finds repeated strings that could be replaced by a constant [fast: true, auto-fix: false] - - lll # Reports long lines [fast: true, auto-fix: false] - - depguard # Go linter that checks if package imports are in a list of acceptable packages [fast: true, auto-fix: false] -linters-settings: - goimports: - local-prefixes: github.com/crewjam/saml - govet: - disable: - - shadow - enable: - - asmdecl - - assign - - atomic - - bools - - buildtag - - cgocall - - composites - - copylocks - - errorsas - - httpresponse - - loopclosure - - lostcancel - - nilfunc - - printf - - shift - - stdmethods - - structtag - - tests - - unmarshal - - unreachable - - unsafeptr - - unusedresult -issues: - exclude-use-default: false - exclude: - - G104 # 'Errors unhandled. (gosec) - + - depguard + - dupl + - gochecknoglobals + - gochecknoinits + - goconst + - lll + settings: + govet: + enable: + - asmdecl + - assign + - atomic + - bools + - buildtag + - cgocall + - composites + - copylocks + - errorsas + - httpresponse + - loopclosure + - lostcancel + - nilfunc + - printf + - shift + - stdmethods + - structtag + - tests + - unmarshal + - unreachable + - unsafeptr + - unusedresult + disable: + - shadow + exclusions: + generated: lax + rules: + - path: (.+)\.go$ + text: G104 # 'Errors unhandled. (gosec) + paths: + - example/.*\.go$ +formatters: + enable: + - gofmt + - goimports + settings: + goimports: + local-prefixes: + - github.com/clerk/saml + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 00000000..6da4fdeb --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @crewjam diff --git a/README.md b/README.md index cd1415a5..1adeccff 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [](http://godoc.org/github.com/crewjam/saml) - + Package saml contains a partial implementation of the SAML standard in golang. SAML is a standard for identity federation, i.e. either allowing a third party to authenticate your users or allowing third parties to rely on us to authenticate their users. @@ -130,7 +130,7 @@ The SAML standard is huge and complex with many dark corners and strange, unused This package supports the **Web SSO** profile. Message flows from the service provider to the IDP are supported using the **HTTP Redirect** binding and the **HTTP POST** binding. Message flows from the IDP to the service provider are supported via the **HTTP POST** binding. -The package can produce signed SAML assertions, and can validate both signed and encrypted SAML assertions. It does not support signed or encrypted requests. +The package can produce signed SAML assertions, and can validate both signed and encrypted SAML assertions. ## RelayState diff --git a/example/idp/idp.go b/example/idp/idp.go index 6ed039a5..e1e54de8 100644 --- a/example/idp/idp.go +++ b/example/idp/idp.go @@ -6,9 +6,9 @@ import ( "crypto/x509" "encoding/pem" "flag" + "net/http" "net/url" - "github.com/zenazn/goji" "golang.org/x/crypto/bcrypt" "github.com/clerk/saml/logger" @@ -118,6 +118,5 @@ func main() { logr.Fatalf("%s", err) } - goji.Handle("/*", idpServer) - goji.Serve() + http.ListenAndServe(":8080", idpServer) } diff --git a/example/service.go b/example/service.go index 1be285c8..35ae8a41 100644 --- a/example/service.go +++ b/example/service.go @@ -7,6 +7,7 @@ import ( "crypto/rsa" "crypto/tls" "crypto/x509" + "encoding/base64" "encoding/xml" "flag" "fmt" @@ -14,11 +15,6 @@ import ( "net/url" "strings" - "github.com/dchest/uniuri" - "github.com/kr/pretty" - "github.com/zenazn/goji" - "github.com/zenazn/goji/web" - "github.com/clerk/saml/samlsp" ) @@ -32,10 +28,16 @@ type Link struct { } // CreateLink handles requests to create links -func CreateLink(_ web.C, w http.ResponseWriter, r *http.Request) { +func CreateLink(w http.ResponseWriter, r *http.Request) { account := r.Header.Get("X-Remote-User") + + randomness := make([]byte, 8) + if _, err := r.Body.Read(randomness); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } l := Link{ - ShortLink: uniuri.New(), + ShortLink: base64.RawURLEncoding.EncodeToString(randomness), Target: r.FormValue("t"), Owner: account, } @@ -45,7 +47,7 @@ func CreateLink(_ web.C, w http.ResponseWriter, r *http.Request) { } // ServeLink handles requests to redirect to a link -func ServeLink(_ web.C, w http.ResponseWriter, r *http.Request) { +func ServeLink(w http.ResponseWriter, r *http.Request) { l, ok := links[strings.TrimPrefix(r.URL.Path, "/")] if !ok { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) @@ -55,7 +57,7 @@ func ServeLink(_ web.C, w http.ResponseWriter, r *http.Request) { } // ListLinks returns a list of the current user's links -func ListLinks(_ web.C, w http.ResponseWriter, r *http.Request) { +func ListLinks(w http.ResponseWriter, r *http.Request) { account := r.Header.Get("X-Remote-User") for _, l := range links { if l.Owner == account { @@ -64,6 +66,27 @@ func ListLinks(_ web.C, w http.ResponseWriter, r *http.Request) { } } +// ServeWhoami serves the basic whoami endpoint +func ServeWhoami(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "text/plain") + + session := samlsp.SessionFromContext(r.Context()) + if session == nil { + fmt.Fprintln(w, "not signed in") + return + } + fmt.Fprintln(w, "signed in") + sessionWithAttrs, ok := session.(samlsp.SessionWithAttributes) + if ok { + fmt.Fprintln(w, "attributes:") + for name, values := range sessionWithAttrs.GetAttributes() { + for _, value := range values { + fmt.Fprintf(w, "%s: %v\n", name, value) + } + } + } +} + var ( key = []byte(`-----BEGIN RSA PRIVATE KEY----- MIICXgIBAAKBgQDU8wdiaFmPfTyRYuFlVPi866WrH/2JubkHzp89bBQopDaLXYxi @@ -140,11 +163,9 @@ func main() { // register with the service provider spMetadataBuf, _ := xml.MarshalIndent(samlSP.ServiceProvider.Metadata(), "", " ") - spURL := *idpMetadataURL spURL.Path = "/services/sp" resp, err := http.Post(spURL.String(), "text/xml", bytes.NewReader(spMetadataBuf)) - if err != nil { panic(err) } @@ -153,20 +174,12 @@ func main() { panic(err) } - goji.Handle("/saml/*", samlSP) - - authMux := web.New() - authMux.Use(samlSP.RequireAccount) - authMux.Get("/whoami", func(w http.ResponseWriter, r *http.Request) { - if _, err := pretty.Fprintf(w, "%# v", r); err != nil { - panic(err) - } - }) - authMux.Post("/", CreateLink) - authMux.Get("/", ListLinks) - - goji.Handle("/*", authMux) - goji.Get("/:link", ServeLink) + mux := http.NewServeMux() + mux.Handle("GET /saml/", samlSP) + mux.HandleFunc("GET /{link}", ServeLink) + mux.Handle("GET /whoami", samlSP.RequireAccount(http.HandlerFunc(ServeWhoami))) + mux.Handle("POST /", samlSP.RequireAccount(http.HandlerFunc(CreateLink))) + mux.Handle("GET /", samlSP.RequireAccount(http.HandlerFunc(ListLinks))) - goji.Serve() + http.ListenAndServe(":8080", mux) } diff --git a/go.mod b/go.mod index e8a75b8a..aaacc53e 100644 --- a/go.mod +++ b/go.mod @@ -3,26 +3,17 @@ module github.com/clerk/saml go 1.23.0 require ( - github.com/beevik/etree v1.2.0 - github.com/crewjam/httperr v0.2.0 - github.com/dchest/uniuri v1.2.0 - github.com/golang-jwt/jwt/v4 v4.5.2 - github.com/google/go-cmp v0.6.0 - github.com/kr/pretty v0.3.1 + github.com/beevik/etree v1.5.0 + github.com/golang-jwt/jwt/v5 v5.2.2 + github.com/google/go-cmp v0.7.0 github.com/mattermost/xml-roundtrip-validator v0.1.0 github.com/russellhaering/goxmldsig v1.4.0 - github.com/stretchr/testify v1.8.4 - github.com/zenazn/goji v1.0.1 golang.org/x/crypto v0.35.0 gotest.tools v2.2.0+incompatible ) require ( - github.com/davecgh/go-spew v1.1.1 // indirect github.com/jonboulle/clockwork v0.2.2 // indirect - github.com/kr/text v0.2.0 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.9.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect + github.com/stretchr/testify v1.10.0 // indirect ) diff --git a/go.sum b/go.sum index 9395b045..c8fd7c98 100644 --- a/go.sum +++ b/go.sum @@ -1,58 +1,43 @@ github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A= -github.com/beevik/etree v1.2.0 h1:l7WETslUG/T+xOPs47dtd6jov2Ii/8/OjCldk5fYfQw= -github.com/beevik/etree v1.2.0/go.mod h1:aiPf89g/1k3AShMVAzriilpcE4R/Vuor90y83zVZWFc= +github.com/beevik/etree v1.5.0 h1:iaQZFSDS+3kYZiGoc9uKeOkUY3nYMXOKLl6KIJxiJWs= +github.com/beevik/etree v1.5.0/go.mod h1:gPNJNaBGVZ9AwsidazFZyygnd+0pAU38N4D+WemwKNs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/crewjam/httperr v0.2.0 h1:b2BfXR8U3AlIHwNeFFvZ+BV1LFvKLlzMjzaTnZMybNo= -github.com/crewjam/httperr v0.2.0/go.mod h1:Jlz+Sg/XqBQhyMjdDiC+GNNRzZTD7x39Gu3pglZ5oH4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dchest/uniuri v1.2.0 h1:koIcOUdrTIivZgSLhHQvKgqdWZq5d7KdMEWF1Ud6+5g= -github.com/dchest/uniuri v1.2.0/go.mod h1:fSzm4SLHzNZvWLvWJew423PhAzkpNQYq+uNLq4kxhkY= -github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= -github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ= github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russellhaering/goxmldsig v1.4.0 h1:8UcDh/xGyQiyrW+Fq5t8f+l2DLB1+zlhYzkPUJ7Qhys= github.com/russellhaering/goxmldsig v1.4.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/zenazn/goji v1.0.1 h1:4lbD8Mx2h7IvloP7r2C0D6ltZP6Ufip8Hn0wmSK5LR8= -github.com/zenazn/goji v1.0.1/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/identity_provider.go b/identity_provider.go index dd754faa..f0ee6b74 100644 --- a/identity_provider.go +++ b/identity_provider.go @@ -8,13 +8,13 @@ import ( "encoding/base64" "encoding/xml" "fmt" + "html/template" "io" "net/http" "net/url" "os" "regexp" "strconv" - "text/template" "time" "github.com/beevik/etree" @@ -38,13 +38,14 @@ type Session struct { NameIDFormat string SubjectID string - Groups []string - UserName string - UserEmail string - UserCommonName string - UserSurname string - UserGivenName string - UserScopedAffiliation string + Groups []string + UserName string + UserEmail string + UserCommonName string + UserSurname string + UserGivenName string + UserScopedAffiliation string + EduPersonPrincipalName string `json:",omitempty"` CustomAttributes []Attribute } @@ -101,12 +102,14 @@ type IdentityProvider struct { Intermediates []*x509.Certificate MetadataURL url.URL SSOURL url.URL + LoginURL url.URL LogoutURL url.URL ServiceProviderProvider ServiceProviderProvider SessionProvider SessionProvider AssertionMaker AssertionMaker SignatureMethod string ValidDuration *time.Duration + ResponseFormTemplate *template.Template } // Metadata returns the metadata structure for this identity provider. @@ -175,7 +178,7 @@ func (idp *IdentityProvider) Metadata() *EntityDescriptor { } if idp.LogoutURL.String() != "" { - ed.IDPSSODescriptors[0].SSODescriptor.SingleLogoutServices = []Endpoint{ + ed.IDPSSODescriptors[0].SingleLogoutServices = []Endpoint{ { Binding: HTTPRedirectBinding, Location: idp.LogoutURL.String(), @@ -662,13 +665,33 @@ func (DefaultAssertionMaker) MakeAssertion(req *IdpAuthnRequest, session *Sessio } if session.UserEmail != "" { + attributes = append(attributes, Attribute{ + FriendlyName: "mail", + Name: "urn:oid:0.9.2342.19200300.100.1.3", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{{ + Type: "xs:string", + Value: session.UserEmail, + }}, + }) + } + if session.EduPersonPrincipalName != "" || session.UserEmail != "" { + value := session.EduPersonPrincipalName + if value == "" { + // We used to set eduPersonPrincipalName (urn:oid:1.3.6.1.4.1.5923.1.1.1.6) + // to the value of session.UserEmail. It is more correct to set + // mail (urn:oid:0.9.2342.19200300.100.1.3). To avoid breaking things, + // we preserve the former behavior. + value = session.UserEmail + } + attributes = append(attributes, Attribute{ FriendlyName: "eduPersonPrincipalName", Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.6", NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", Values: []AttributeValue{{ Type: "xs:string", - Value: session.UserEmail, + Value: value, }}, }) } @@ -709,7 +732,7 @@ func (DefaultAssertionMaker) MakeAssertion(req *IdpAuthnRequest, session *Sessio if session.UserScopedAffiliation != "" { attributes = append(attributes, Attribute{ - FriendlyName: "uid", + FriendlyName: "scopedAffiliation", Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.9", NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", Values: []AttributeValue{{ @@ -921,6 +944,16 @@ func (req *IdpAuthnRequest) PostBinding() (IdpAuthnRequestForm, error) { return form, nil } +var defaultResponseFormTemplate = template.Must(template.New("saml-post-form").Parse(`` + + `
` + + `` + + `` + + ``)) + // WriteResponse writes the `Response` to the http.ResponseWriter. If // `Response` is not already set, it calls MakeResponse to produce it. func (req *IdpAuthnRequest) WriteResponse(w http.ResponseWriter) error { @@ -929,15 +962,10 @@ func (req *IdpAuthnRequest) WriteResponse(w http.ResponseWriter) error { return err } - tmpl := template.Must(template.New("saml-post-form").Parse(`` + - `` + - `` + - `` + - ``)) + tmpl := req.IDP.ResponseFormTemplate + if tmpl == nil { + tmpl = defaultResponseFormTemplate + } buf := bytes.NewBuffer(nil) if err := tmpl.Execute(buf, form); err != nil { diff --git a/identity_provider_test.go b/identity_provider_test.go index 8c7d8c00..fc578b17 100644 --- a/identity_provider_test.go +++ b/identity_provider_test.go @@ -25,7 +25,6 @@ import ( "gotest.tools/golden" "github.com/beevik/etree" - "github.com/golang-jwt/jwt/v4" dsig "github.com/russellhaering/goxmldsig" "github.com/clerk/saml/logger" @@ -104,7 +103,6 @@ func NewIdentityProviderTest(t *testing.T, opts ...idpTestOpts) *IdentityProvide rv, _ := time.Parse("Mon Jan 2 15:04:05 MST 2006", "Mon Dec 1 01:57:09 UTC 2015") return rv } - jwt.TimeFunc = TimeNow RandReader = &testRandomReader{} // TODO(ross): remove this and use the below generator xmlenc.RandReader = rand.New(rand.NewSource(0)) //nolint:gosec // deterministic random numbers for tests @@ -126,7 +124,7 @@ func NewIdentityProviderTest(t *testing.T, opts ...idpTestOpts) *IdentityProvide MetadataURL: mustParseURL("https://idp.example.com/saml/metadata"), SSOURL: mustParseURL("https://idp.example.com/saml/sso"), ServiceProviderProvider: &mockServiceProviderProvider{ - GetServiceProviderFunc: func(r *http.Request, serviceProviderID string) (*EntityDescriptor, error) { + GetServiceProviderFunc: func(_ *http.Request, serviceProviderID string) (*EntityDescriptor, error) { if serviceProviderID == test.SP.MetadataURL.String() { return test.SP.Metadata(), nil } @@ -134,7 +132,7 @@ func NewIdentityProviderTest(t *testing.T, opts ...idpTestOpts) *IdentityProvide }, }, SessionProvider: &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { + GetSessionFunc: func(_ http.ResponseWriter, _ *http.Request, _ *IdpAuthnRequest) *Session { return nil }, }, @@ -241,9 +239,10 @@ func TestIDPHTTPCanHandleMetadataRequest(t *testing.T) { func TestIDPCanHandleRequestWithNewSession(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { - fmt.Fprintf(w, "RelayState: %s\nSAMLRequest: %s", + GetSessionFunc: func(w http.ResponseWriter, _ *http.Request, req *IdpAuthnRequest) *Session { + _, err := fmt.Fprintf(w, "RelayState: %s\nSAMLRequest: %s", req.RelayState, req.RequestBuffer) + assert.NilError(t, err) return nil }, } @@ -267,7 +266,7 @@ func TestIDPCanHandleRequestWithNewSession(t *testing.T) { func TestIDPCanHandleRequestWithExistingSession(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { + GetSessionFunc: func(_ http.ResponseWriter, _ *http.Request, _ *IdpAuthnRequest) *Session { return &Session{ ID: "f00df00df00d", UserName: "alice", @@ -292,7 +291,7 @@ func TestIDPCanHandleRequestWithExistingSession(t *testing.T) { func TestIDPCanHandlePostRequestWithExistingSession(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { + GetSessionFunc: func(_ http.ResponseWriter, _ *http.Request, _ *IdpAuthnRequest) *Session { return &Session{ ID: "f00df00df00d", UserName: "alice", @@ -321,7 +320,7 @@ func TestIDPCanHandlePostRequestWithExistingSession(t *testing.T) { func TestIDPRejectsInvalidRequest(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { + GetSessionFunc: func(_ http.ResponseWriter, _ *http.Request, _ *IdpAuthnRequest) *Session { panic("not reached") }, } @@ -484,7 +483,6 @@ func TestIDPCanValidate(t *testing.T) { ""), } assert.Check(t, is.Error(req.Validate(), "cannot find assertion consumer service: file does not exist")) - } func TestIDPMakeAssertion(t *testing.T) { @@ -591,83 +589,93 @@ func TestIDPMakeAssertion(t *testing.T) { }) assert.Check(t, err) - expectedAttributes := - []Attribute{ - { - FriendlyName: "uid", - Name: "urn:oid:0.9.2342.19200300.100.1.1", - NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", - Values: []AttributeValue{ - { - Type: "xs:string", - Value: "alice", - }, + expectedAttributes := []Attribute{ + { + FriendlyName: "uid", + Name: "urn:oid:0.9.2342.19200300.100.1.1", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: "alice", }, }, - { - FriendlyName: "eduPersonPrincipalName", - Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.6", - NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", - Values: []AttributeValue{ - { - Type: "xs:string", - Value: "alice@example.com", - }, + }, + { + FriendlyName: "mail", + Name: "urn:oid:0.9.2342.19200300.100.1.3", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: "alice@example.com", }, }, - { - FriendlyName: "sn", - Name: "urn:oid:2.5.4.4", - NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", - Values: []AttributeValue{ - { - Type: "xs:string", - Value: "Smith", - }, + }, + { + FriendlyName: "eduPersonPrincipalName", + Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.6", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: "alice@example.com", }, }, - { - FriendlyName: "givenName", - Name: "urn:oid:2.5.4.42", - NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", - Values: []AttributeValue{ - { - Type: "xs:string", - Value: "Alice", - }, + }, + { + FriendlyName: "sn", + Name: "urn:oid:2.5.4.4", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: "Smith", }, }, - { - FriendlyName: "cn", - Name: "urn:oid:2.5.4.3", - NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", - Values: []AttributeValue{ - { - Type: "xs:string", - Value: "Alice Smith", - }, + }, + { + FriendlyName: "givenName", + Name: "urn:oid:2.5.4.42", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: "Alice", }, }, - { - FriendlyName: "eduPersonAffiliation", - Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.1", - NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", - Values: []AttributeValue{ - { - Type: "xs:string", - Value: "Users", - }, - { - Type: "xs:string", - Value: "Administrators", - }, - { - Type: "xs:string", - Value: "♀", - }, + }, + { + FriendlyName: "cn", + Name: "urn:oid:2.5.4.3", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: "Alice Smith", }, }, - } + }, + { + FriendlyName: "eduPersonAffiliation", + Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.1", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: "Users", + }, + { + Type: "xs:string", + Value: "Administrators", + }, + { + Type: "xs:string", + Value: "♀", + }, + }, + }, + } assert.Check(t, is.DeepEqual(expectedAttributes, req.Assertion.AttributeStatements[0].Attributes)) } @@ -798,8 +806,9 @@ func TestIDPWriteResponse(t *testing.T) { func TestIDPIDPInitiatedNewSession(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { - fmt.Fprintf(w, "RelayState: %s", req.RelayState) + GetSessionFunc: func(w http.ResponseWriter, _ *http.Request, req *IdpAuthnRequest) *Session { + _, err := fmt.Fprintf(w, "RelayState: %s", req.RelayState) + assert.NilError(t, err) return nil }, } @@ -814,7 +823,7 @@ func TestIDPIDPInitiatedNewSession(t *testing.T) { func TestIDPIDPInitiatedExistingSession(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { + GetSessionFunc: func(_ http.ResponseWriter, _ *http.Request, _ *IdpAuthnRequest) *Session { return &Session{ ID: "f00df00df00d", UserName: "alice", @@ -832,7 +841,7 @@ func TestIDPIDPInitiatedExistingSession(t *testing.T) { func TestIDPIDPInitiatedBadServiceProvider(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { + GetSessionFunc: func(_ http.ResponseWriter, _ *http.Request, _ *IdpAuthnRequest) *Session { return &Session{ ID: "f00df00df00d", UserName: "alice", @@ -849,7 +858,7 @@ func TestIDPIDPInitiatedBadServiceProvider(t *testing.T) { func TestIDPCanHandleUnencryptedResponse(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { + GetSessionFunc: func(_ http.ResponseWriter, _ *http.Request, _ *IdpAuthnRequest) *Session { return &Session{ID: "f00df00df00d", UserName: "alice"} }, } @@ -860,7 +869,7 @@ func TestIDPCanHandleUnencryptedResponse(t *testing.T) { &metadata) assert.Check(t, err) test.IDP.ServiceProviderProvider = &mockServiceProviderProvider{ - GetServiceProviderFunc: func(r *http.Request, serviceProviderID string) (*EntityDescriptor, error) { + GetServiceProviderFunc: func(_ *http.Request, serviceProviderID string) (*EntityDescriptor, error) { if serviceProviderID == "https://gitlab.example.com/users/saml/metadata" { return &metadata, nil } @@ -976,6 +985,17 @@ func TestIDPRequestedAttributes(t *testing.T) { }, }, }, + { + FriendlyName: "mail", + Name: "urn:oid:0.9.2342.19200300.100.1.3", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: "alice@example.com", + }, + }, + }, { FriendlyName: "eduPersonPrincipalName", Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.6", @@ -1020,14 +1040,15 @@ func TestIDPRequestedAttributes(t *testing.T) { }, }, }, - }}} + }, + }} assert.Check(t, is.DeepEqual(expectedAttributes, req.Assertion.AttributeStatements)) } func TestIDPNoDestination(t *testing.T) { test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { + GetSessionFunc: func(_ http.ResponseWriter, _ *http.Request, _ *IdpAuthnRequest) *Session { return &Session{ID: "f00df00df00d", UserName: "alice"} }, } @@ -1036,7 +1057,7 @@ func TestIDPNoDestination(t *testing.T) { err := xml.Unmarshal(golden.Get(t, "TestIDPNoDestination_idp_metadata.xml"), &metadata) assert.Check(t, err) test.IDP.ServiceProviderProvider = &mockServiceProviderProvider{ - GetServiceProviderFunc: func(r *http.Request, serviceProviderID string) (*EntityDescriptor, error) { + GetServiceProviderFunc: func(_ *http.Request, serviceProviderID string) (*EntityDescriptor, error) { if serviceProviderID == "https://gitlab.example.com/users/saml/metadata" { return &metadata, nil } @@ -1067,9 +1088,10 @@ func TestIDPNoDestination(t *testing.T) { func TestIDPRejectDecompressionBomb(t *testing.T) { test := NewIdentityProviderTest(t) test.IDP.SessionProvider = &mockSessionProvider{ - GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { - fmt.Fprintf(w, "RelayState: %s\nSAMLRequest: %s", + GetSessionFunc: func(w http.ResponseWriter, _ *http.Request, req *IdpAuthnRequest) *Session { + _, err := fmt.Fprintf(w, "RelayState: %s\nSAMLRequest: %s", req.RelayState, req.RequestBuffer) + assert.NilError(t, err) return nil }, } diff --git a/metadata.go b/metadata.go index 006a9e67..d160ccc3 100644 --- a/metadata.go +++ b/metadata.go @@ -38,6 +38,55 @@ type EntitiesDescriptor struct { EntityDescriptors []EntityDescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata EntityDescriptor"` } +// MarshalXML implements xml.Marshaler +func (m EntitiesDescriptor) MarshalXML(e *xml.Encoder, _ xml.StartElement) error { + var validUntil *RelaxedTime + var cacheDuration *Duration + if m.ValidUntil != nil { + vu := RelaxedTime(*m.ValidUntil) + validUntil = &vu + } + if m.CacheDuration != nil { + cd := Duration(*m.CacheDuration) + cacheDuration = &cd + } + type Alias EntitiesDescriptor + aux := &struct { + ValidUntil *RelaxedTime `xml:"validUntil,attr,omitempty"` + CacheDuration *Duration `xml:"cacheDuration,attr,omitempty"` + *Alias + }{ + ValidUntil: validUntil, + CacheDuration: cacheDuration, + Alias: (*Alias)(&m), + } + return e.Encode(aux) +} + +// UnmarshalXML implements xml.Unmarshaler +func (m *EntitiesDescriptor) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + type Alias EntitiesDescriptor + aux := &struct { + ValidUntil *RelaxedTime `xml:"validUntil,attr,omitempty"` + CacheDuration *Duration `xml:"cacheDuration,attr,omitempty"` + *Alias + }{ + Alias: (*Alias)(m), + } + if err := d.DecodeElement(aux, &start); err != nil { + return err + } + if aux.ValidUntil != nil { + t := time.Time(*aux.ValidUntil) + m.ValidUntil = &t + } + if aux.CacheDuration != nil { + d := time.Duration(*aux.CacheDuration) + m.CacheDuration = &d + } + return nil +} + // Metadata as been renamed to EntityDescriptor // // This change was made to be consistent with the rest of the API which uses names diff --git a/saml.go b/saml.go index f5d4f4e0..3638548a 100644 --- a/saml.go +++ b/saml.go @@ -149,7 +149,7 @@ // // This package supports the Web SSO profile. Message flows from the service provider to the IDP are supported using the HTTP Redirect binding and the HTTP POST binding. Message flows from the IDP to the service provider are supported via the HTTP POST binding. // -// The package can produce signed SAML assertions, and can validate both signed and encrypted SAML assertions. It does not support signed or encrypted requests. +// The package can produce signed SAML assertions, and can validate both signed and encrypted SAML assertions. // // # RelayState // diff --git a/samlidp/samlidp.go b/samlidp/samlidp.go index b4bda685..25933457 100644 --- a/samlidp/samlidp.go +++ b/samlidp/samlidp.go @@ -5,25 +5,25 @@ package samlidp import ( "crypto" "crypto/x509" + "html/template" "net/http" "net/url" - "regexp" + "strings" "sync" - "github.com/zenazn/goji/web" - "github.com/clerk/saml" "github.com/clerk/saml/logger" ) // Options represent the parameters to New() for creating a new IDP server type Options struct { - URL url.URL - Key crypto.PrivateKey - Signer crypto.Signer - Logger logger.Interface - Certificate *x509.Certificate - Store Store + URL url.URL + Key crypto.PrivateKey + Signer crypto.Signer + Logger logger.Interface + Certificate *x509.Certificate + Store Store + LoginFormTemplate *template.Template } // Server represents an IDP server. The server provides the following URLs: @@ -38,19 +38,24 @@ type Options struct { // /shortcuts - RESTful interface to Shortcut objects type Server struct { http.Handler - idpConfigMu sync.RWMutex // protects calls into the IDP - logger logger.Interface - serviceProviders map[string]*saml.EntityDescriptor - IDP saml.IdentityProvider // the underlying IDP - Store Store // the data store + idpConfigMu sync.RWMutex // protects calls into the IDP + logger logger.Interface + serviceProviders map[string]*saml.EntityDescriptor + IDP saml.IdentityProvider // the underlying IDP + Store Store // the data store + LoginFormTemplate *template.Template } // New returns a new Server func New(opts Options) (*Server, error) { + opts.URL.Path = strings.TrimSuffix(opts.URL.Path, "/") + metadataURL := opts.URL metadataURL.Path += "/metadata" ssoURL := opts.URL ssoURL.Path += "/sso" + loginURL := opts.URL + loginURL.Path += "/login" logr := opts.Logger if logr == nil { logr = logger.DefaultLogger @@ -65,9 +70,11 @@ func New(opts Options) (*Server, error) { Certificate: opts.Certificate, MetadataURL: metadataURL, SSOURL: ssoURL, + LoginURL: loginURL, }, - logger: logr, - Store: opts.Store, + logger: logr, + Store: opts.Store, + LoginFormTemplate: opts.LoginFormTemplate, } s.IDP.SessionProvider = s @@ -84,40 +91,41 @@ func New(opts Options) (*Server, error) { // is called automatically for you by New, but you may need to call it // yourself if you don't create the object using New.) func (s *Server) InitializeHTTP() { - mux := web.New() + mux := http.NewServeMux() s.Handler = mux - mux.Get("/metadata", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("GET /metadata", func(w http.ResponseWriter, r *http.Request) { s.idpConfigMu.RLock() defer s.idpConfigMu.RUnlock() s.IDP.ServeMetadata(w, r) }) - mux.Handle("/sso", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/sso", func(w http.ResponseWriter, r *http.Request) { + s.idpConfigMu.RLock() + defer s.idpConfigMu.RUnlock() s.IDP.ServeSSO(w, r) }) - mux.Handle("/login", s.HandleLogin) - mux.Handle("/login/:shortcut", s.HandleIDPInitiated) - mux.Handle("/login/:shortcut/*", s.HandleIDPInitiated) - - mux.Get("/services/", s.HandleListServices) - mux.Get("/services/:id", s.HandleGetService) - mux.Put("/services/:id", s.HandlePutService) - mux.Post("/services/:id", s.HandlePutService) - mux.Delete("/services/:id", s.HandleDeleteService) - - mux.Get("/users/", s.HandleListUsers) - mux.Get("/users/:id", s.HandleGetUser) - mux.Put("/users/:id", s.HandlePutUser) - mux.Delete("/users/:id", s.HandleDeleteUser) - - sessionPath := regexp.MustCompile("/sessions/(?P{{.Toast}}
` + - `` + - ``)) +func (s *Server) sendLoginForm(w http.ResponseWriter, req *saml.IdpAuthnRequest, toast string) { + tmpl := s.LoginFormTemplate + if tmpl == nil { + tmpl = defaultLoginFormTemplate + } data := struct { Toast string URL string @@ -122,7 +129,7 @@ func (s *Server) sendLoginForm(w http.ResponseWriter, _ *http.Request, req *saml RelayState string }{ Toast: toast, - URL: req.IDP.SSOURL.String(), + URL: req.IDP.LoginURL.String(), SAMLRequest: base64.StdEncoding.EncodeToString(req.RequestBuffer), RelayState: req.RelayState, } @@ -136,7 +143,7 @@ func (s *Server) sendLoginForm(w http.ResponseWriter, _ *http.Request, req *saml // in the request body, then they are validated. For valid credentials, the response is a // 200 OK and the JSON session object. For invalid credentials, the HTML login prompt form // is sent. -func (s *Server) HandleLogin(_ web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleLogin(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return @@ -153,7 +160,7 @@ func (s *Server) HandleLogin(_ web.C, w http.ResponseWriter, r *http.Request) { // HandleListSessions handles the `GET /sessions/` request and responds with a JSON formatted list // of session names. -func (s *Server) HandleListSessions(_ web.C, w http.ResponseWriter, _ *http.Request) { +func (s *Server) HandleListSessions(w http.ResponseWriter, _ *http.Request) { sessions, err := s.Store.List("/sessions/") if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -171,9 +178,9 @@ func (s *Server) HandleListSessions(_ web.C, w http.ResponseWriter, _ *http.Requ // HandleGetSession handles the `GET /sessions/:id` request and responds with the session // object in JSON format. -func (s *Server) HandleGetSession(c web.C, w http.ResponseWriter, _ *http.Request) { +func (s *Server) HandleGetSession(w http.ResponseWriter, r *http.Request) { session := saml.Session{} - err := s.Store.Get(fmt.Sprintf("/sessions/%s", c.URLParams["id"]), &session) + err := s.Store.Get(fmt.Sprintf("/sessions/%s", r.PathValue("id")), &session) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return @@ -186,8 +193,8 @@ func (s *Server) HandleGetSession(c web.C, w http.ResponseWriter, _ *http.Reques // HandleDeleteSession handles the `DELETE /sessions/:id` request. It invalidates the // specified session. -func (s *Server) HandleDeleteSession(c web.C, w http.ResponseWriter, _ *http.Request) { - err := s.Store.Delete(fmt.Sprintf("/sessions/%s", c.URLParams["id"])) +func (s *Server) HandleDeleteSession(w http.ResponseWriter, r *http.Request) { + err := s.Store.Delete(fmt.Sprintf("/sessions/%s", r.PathValue("id"))) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return diff --git a/samlidp/session_test.go b/samlidp/session_test.go index cb3d5c3d..6244f923 100644 --- a/samlidp/session_test.go +++ b/samlidp/session_test.go @@ -63,4 +63,27 @@ func TestSessionsCrud(t *testing.T) { assert.Check(t, is.Equal("{\"sessions\":[]}\n", w.Body.String())) + // user doesn't exists case + w = httptest.NewRecorder() + r, _ = http.NewRequest("POST", "https://idp.example.com/login", + strings.NewReader("user=unknown&password=dummypassword")) + r.Header.Set("Content-type", "application/x-www-form-urlencoded") + test.Server.ServeHTTP(w, r) + assert.Check(t, is.Equal(http.StatusOK, w.Code)) + assert.Check(t, is.Equal("text/html; charset=utf-8", + w.Header().Get("Content-type"))) + assert.Check(t, is.Equal(`Invalid username or password
`, + w.Body.String())) + + // invalid username/password exists case + w = httptest.NewRecorder() + r, _ = http.NewRequest("POST", "https://idp.example.com/login", + strings.NewReader("user=alice&password=dummypassword")) + r.Header.Set("Content-type", "application/x-www-form-urlencoded") + test.Server.ServeHTTP(w, r) + assert.Check(t, is.Equal(http.StatusOK, w.Code)) + assert.Check(t, is.Equal("text/html; charset=utf-8", + w.Header().Get("Content-type"))) + assert.Check(t, is.Equal(`Invalid username or password
`, + w.Body.String())) } diff --git a/samlidp/shortcut.go b/samlidp/shortcut.go index 192e5022..f08efe7d 100644 --- a/samlidp/shortcut.go +++ b/samlidp/shortcut.go @@ -4,8 +4,6 @@ import ( "encoding/json" "fmt" "net/http" - - "github.com/zenazn/goji/web" ) // Shortcut represents an IDP-initiated SAML flow. When a user @@ -31,7 +29,7 @@ type Shortcut struct { // HandleListShortcuts handles the `GET /shortcuts/` request and responds with a JSON formatted list // of shortcut names. -func (s *Server) HandleListShortcuts(_ web.C, w http.ResponseWriter, _ *http.Request) { +func (s *Server) HandleListShortcuts(w http.ResponseWriter, _ *http.Request) { shortcuts, err := s.Store.List("/shortcuts/") if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -49,9 +47,9 @@ func (s *Server) HandleListShortcuts(_ web.C, w http.ResponseWriter, _ *http.Req // HandleGetShortcut handles the `GET /shortcuts/:id` request and responds with the shortcut // object in JSON format. -func (s *Server) HandleGetShortcut(c web.C, w http.ResponseWriter, _ *http.Request) { +func (s *Server) HandleGetShortcut(w http.ResponseWriter, r *http.Request) { shortcut := Shortcut{} - err := s.Store.Get(fmt.Sprintf("/shortcuts/%s", c.URLParams["id"]), &shortcut) + err := s.Store.Get(fmt.Sprintf("/shortcuts/%s", r.PathValue("id")), &shortcut) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return @@ -64,15 +62,15 @@ func (s *Server) HandleGetShortcut(c web.C, w http.ResponseWriter, _ *http.Reque // HandlePutShortcut handles the `PUT /shortcuts/:id` request. It accepts a JSON formatted // shortcut object in the request body and stores it. -func (s *Server) HandlePutShortcut(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandlePutShortcut(w http.ResponseWriter, r *http.Request) { shortcut := Shortcut{} if err := json.NewDecoder(r.Body).Decode(&shortcut); err != nil { http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } - shortcut.Name = c.URLParams["id"] + shortcut.Name = r.PathValue("id") - err := s.Store.Put(fmt.Sprintf("/shortcuts/%s", c.URLParams["id"]), &shortcut) + err := s.Store.Put(fmt.Sprintf("/shortcuts/%s", r.PathValue("id")), &shortcut) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return @@ -81,8 +79,8 @@ func (s *Server) HandlePutShortcut(c web.C, w http.ResponseWriter, r *http.Reque } // HandleDeleteShortcut handles the `DELETE /shortcuts/:id` request. -func (s *Server) HandleDeleteShortcut(c web.C, w http.ResponseWriter, _ *http.Request) { - err := s.Store.Delete(fmt.Sprintf("/shortcuts/%s", c.URLParams["id"])) +func (s *Server) HandleDeleteShortcut(w http.ResponseWriter, r *http.Request) { + err := s.Store.Delete(fmt.Sprintf("/shortcuts/%s", r.PathValue("id"))) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return @@ -93,8 +91,8 @@ func (s *Server) HandleDeleteShortcut(c web.C, w http.ResponseWriter, _ *http.Re // HandleIDPInitiated handles a request for an IDP initiated login flow. It looks up // the specified shortcut, generates the appropriate SAML assertion and redirects the // user via the HTTP-POST binding to the service providers ACS URL. -func (s *Server) HandleIDPInitiated(c web.C, w http.ResponseWriter, r *http.Request) { - shortcutName := c.URLParams["shortcut"] +func (s *Server) HandleIDPInitiated(w http.ResponseWriter, r *http.Request) { + shortcutName := r.PathValue("shortcut") shortcut := Shortcut{} if err := s.Store.Get(fmt.Sprintf("/shortcuts/%s", shortcutName), &shortcut); err != nil { s.logger.Printf("ERROR: %s", err) @@ -107,7 +105,9 @@ func (s *Server) HandleIDPInitiated(c web.C, w http.ResponseWriter, r *http.Requ case shortcut.RelayState != nil: relayState = *shortcut.RelayState case shortcut.URISuffixAsRelayState: - relayState = c.URLParams["*"] + if suffix := r.PathValue("suffix"); suffix != "" { + relayState = "/" + suffix + } } s.IDP.ServeIDPInitiated(w, r, shortcut.ServiceProviderID, relayState) diff --git a/samlidp/testdata/http_sso_response.html b/samlidp/testdata/http_sso_response.html index ea0f60f6..47db428a 100644 --- a/samlidp/testdata/http_sso_response.html +++ b/samlidp/testdata/http_sso_response.html @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/samlidp/user.go b/samlidp/user.go index c8c412cb..de08ede8 100644 --- a/samlidp/user.go +++ b/samlidp/user.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" - "github.com/zenazn/goji/web" "golang.org/x/crypto/bcrypt" ) @@ -25,7 +24,7 @@ type User struct { // HandleListUsers handles the `GET /users/` request and responds with a JSON formatted list // of user names. -func (s *Server) HandleListUsers(_ web.C, w http.ResponseWriter, _ *http.Request) { +func (s *Server) HandleListUsers(w http.ResponseWriter, _ *http.Request) { users, err := s.Store.List("/users/") if err != nil { s.logger.Printf("ERROR: %s", err) @@ -45,9 +44,9 @@ func (s *Server) HandleListUsers(_ web.C, w http.ResponseWriter, _ *http.Request // HandleGetUser handles the `GET /users/:id` request and responds with the user object in JSON // format. The HashedPassword field is excluded. -func (s *Server) HandleGetUser(c web.C, w http.ResponseWriter, _ *http.Request) { +func (s *Server) HandleGetUser(w http.ResponseWriter, r *http.Request) { user := User{} - err := s.Store.Get(fmt.Sprintf("/users/%s", c.URLParams["id"]), &user) + err := s.Store.Get(fmt.Sprintf("/users/%s", r.PathValue("id")), &user) if err != nil { s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -65,14 +64,14 @@ func (s *Server) HandleGetUser(c web.C, w http.ResponseWriter, _ *http.Request) // the request body and stores it. If the PlaintextPassword field is present then it is hashed // and stored in HashedPassword. If the PlaintextPassword field is not present then // HashedPassword retains it's stored value. -func (s *Server) HandlePutUser(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandlePutUser(w http.ResponseWriter, r *http.Request) { user := User{} if err := json.NewDecoder(r.Body).Decode(&user); err != nil { s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } - user.Name = c.URLParams["id"] + user.Name = r.PathValue("id") if user.PlaintextPassword != nil { var err error @@ -84,11 +83,11 @@ func (s *Server) HandlePutUser(c web.C, w http.ResponseWriter, r *http.Request) } } else { existingUser := User{} - err := s.Store.Get(fmt.Sprintf("/users/%s", c.URLParams["id"]), &existingUser) - switch { - case err == nil: + err := s.Store.Get(fmt.Sprintf("/users/%s", r.PathValue("id")), &existingUser) + switch err { + case nil: user.HashedPassword = existingUser.HashedPassword - case err == ErrNotFound: + case ErrNotFound: // nop default: s.logger.Printf("ERROR: %s", err) @@ -98,7 +97,7 @@ func (s *Server) HandlePutUser(c web.C, w http.ResponseWriter, r *http.Request) } user.PlaintextPassword = nil - err := s.Store.Put(fmt.Sprintf("/users/%s", c.URLParams["id"]), &user) + err := s.Store.Put(fmt.Sprintf("/users/%s", r.PathValue("id")), &user) if err != nil { s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -108,8 +107,8 @@ func (s *Server) HandlePutUser(c web.C, w http.ResponseWriter, r *http.Request) } // HandleDeleteUser handles the `DELETE /users/:id` request. -func (s *Server) HandleDeleteUser(c web.C, w http.ResponseWriter, _ *http.Request) { - err := s.Store.Delete(fmt.Sprintf("/users/%s", c.URLParams["id"])) +func (s *Server) HandleDeleteUser(w http.ResponseWriter, r *http.Request) { + err := s.Store.Delete(fmt.Sprintf("/users/%s", r.PathValue("id"))) if err != nil { s.logger.Printf("ERROR: %s", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) diff --git a/samlsp/basic_assertion_handler.go b/samlsp/basic_assertion_handler.go new file mode 100644 index 00000000..a1b86958 --- /dev/null +++ b/samlsp/basic_assertion_handler.go @@ -0,0 +1,15 @@ +package samlsp + +import ( + "github.com/clerk/saml" +) + +var _ AssertionHandler = NopAssertionHandler{} + +// NopAssertionHandler is an implementation of AssertionHandler that does nothing. +type NopAssertionHandler struct{} + +// HandleAssertion is called and passed a SAML assertion. This implementation does nothing. +func (as NopAssertionHandler) HandleAssertion(_ *saml.Assertion) error { + return nil +} diff --git a/samlsp/fetch_metadata.go b/samlsp/fetch_metadata.go index d9b82591..0b6b6c31 100644 --- a/samlsp/fetch_metadata.go +++ b/samlsp/fetch_metadata.go @@ -5,11 +5,11 @@ import ( "context" "encoding/xml" "errors" + "fmt" "io" "net/http" "net/url" - "github.com/crewjam/httperr" xrv "github.com/mattermost/xml-roundtrip-validator" "github.com/clerk/saml/logger" @@ -69,7 +69,7 @@ func FetchMetadata(ctx context.Context, httpClient *http.Client, metadataURL url } }() if resp.StatusCode >= 400 { - return nil, httperr.Response(*resp) + return nil, fmt.Errorf("failed to fetch metadata: unexpected status code %d", resp.StatusCode) } data, err := io.ReadAll(resp.Body) diff --git a/samlsp/middleware.go b/samlsp/middleware.go index 4392fbf3..e9e490ce 100644 --- a/samlsp/middleware.go +++ b/samlsp/middleware.go @@ -40,12 +40,13 @@ import ( // SAML service provider already has a private key, we borrow that key // to sign the JWTs as well. type Middleware struct { - ServiceProvider saml.ServiceProvider - OnError func(w http.ResponseWriter, r *http.Request, err error) - Binding string // either saml.HTTPPostBinding or saml.HTTPRedirectBinding - ResponseBinding string // either saml.HTTPPostBinding or saml.HTTPArtifactBinding - RequestTracker RequestTracker - Session SessionProvider + ServiceProvider saml.ServiceProvider + OnError func(w http.ResponseWriter, r *http.Request, err error) + Binding string // either saml.HTTPPostBinding or saml.HTTPRedirectBinding + ResponseBinding string // either saml.HTTPPostBinding or saml.HTTPArtifactBinding + RequestTracker RequestTracker + Session SessionProvider + AssertionHandler AssertionHandler } // ServeHTTP implements http.Handler and serves the SAML-specific HTTP endpoints @@ -99,6 +100,11 @@ func (m *Middleware) ServeACS(w http.ResponseWriter, r *http.Request) { return } + if handlerErr := m.AssertionHandler.HandleAssertion(assertion); handlerErr != nil { + m.OnError(w, r, handlerErr) + return + } + m.CreateSessionFromAssertion(w, r, assertion, m.ServiceProvider.DefaultRedirectURI) } @@ -226,11 +232,6 @@ func (m *Middleware) CreateSessionFromAssertion(w http.ResponseWriter, r *http.R // SAML attribute `name` be set to `value`. This can be used to require // that a remote user be a member of a group. It relies on the Claims assigned // to to the context in RequireAccount. -// -// For example: -// -// goji.Use(m.RequireAccount) -// goji.Use(RequireAttributeMiddleware("eduPersonAffiliation", "Staff")) func RequireAttribute(name, value string) func(http.Handler) http.Handler { return func(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/samlsp/middleware_test.go b/samlsp/middleware_test.go index 14d32e45..6bc32fff 100644 --- a/samlsp/middleware_test.go +++ b/samlsp/middleware_test.go @@ -17,7 +17,6 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt/v4" dsig "github.com/russellhaering/goxmldsig" "gotest.tools/assert" is "gotest.tools/assert/cmp" @@ -55,7 +54,6 @@ func NewMiddlewareTest(t *testing.T) *MiddlewareTest { rv, _ := time.Parse("Mon Jan 2 15:04:05.999999999 MST 2006", "Mon Dec 1 01:57:09.123456789 UTC 2015") return rv } - jwt.TimeFunc = saml.TimeNow saml.Clock = dsig.NewFakeClockAt(saml.TimeNow()) saml.RandReader = &testRandomReader{} @@ -150,7 +148,7 @@ func TestMiddlewareRequireAccountNoCreds(t *testing.T) { test.Middleware.ServiceProvider.AcsURL.Scheme = "http" handler := test.Middleware.RequireAccount( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") })) @@ -174,7 +172,7 @@ func TestMiddlewareRequireAccountNoCredsSecure(t *testing.T) { test := NewMiddlewareTest(t) handler := test.Middleware.RequireAccount( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") })) @@ -200,7 +198,7 @@ func TestMiddlewareRequireAccountNoCredsPostBinding(t *testing.T) { test.Middleware.ServiceProvider.GetSSOBindingLocation(saml.HTTPRedirectBinding))) handler := test.Middleware.RequireAccount( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") })) @@ -256,7 +254,7 @@ func TestMiddlewareRequireAccountCreds(t *testing.T) { func TestMiddlewareRequireAccountBadCreds(t *testing.T) { test := NewMiddlewareTest(t) handler := test.Middleware.RequireAccount( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") })) @@ -281,13 +279,13 @@ func TestMiddlewareRequireAccountBadCreds(t *testing.T) { func TestMiddlewareRequireAccountExpiredCreds(t *testing.T) { test := NewMiddlewareTest(t) - jwt.TimeFunc = func() time.Time { + saml.TimeNow = func() time.Time { rv, _ := time.Parse("Mon Jan 2 15:04:05 UTC 2006", "Mon Dec 1 01:31:21 UTC 2115") return rv } handler := test.Middleware.RequireAccount( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") })) @@ -306,13 +304,13 @@ func TestMiddlewareRequireAccountExpiredCreds(t *testing.T) { assert.Check(t, err) decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL) assert.Check(t, err) - golden.Assert(t, string(decodedRequest), "expected_authn_request_secure.xml") + golden.Assert(t, strings.Replace(string(decodedRequest), `IssueInstant="2115-12-01T01:31:21Z"`, `IssueInstant="2015-12-01T01:57:09.123Z"`, 1), "expected_authn_request_secure.xml") } func TestMiddlewareRequireAccountPanicOnRequestToACS(t *testing.T) { test := NewMiddlewareTest(t) handler := test.Middleware.RequireAccount( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") })) @@ -326,7 +324,7 @@ func TestMiddlewareRequireAttribute(t *testing.T) { test := NewMiddlewareTest(t) handler := test.Middleware.RequireAccount( RequireAttribute("eduPersonAffiliation", "Staff")( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusTeapot) }))) @@ -344,7 +342,7 @@ func TestMiddlewareRequireAttributeWrongValue(t *testing.T) { test := NewMiddlewareTest(t) handler := test.Middleware.RequireAccount( RequireAttribute("eduPersonAffiliation", "DomainAdmins")( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") }))) @@ -362,7 +360,7 @@ func TestMiddlewareRequireAttributeNotPresent(t *testing.T) { test := NewMiddlewareTest(t) handler := test.Middleware.RequireAccount( RequireAttribute("valueThatDoesntExist", "doesntMatter")( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") }))) @@ -379,7 +377,7 @@ func TestMiddlewareRequireAttributeNotPresent(t *testing.T) { func TestMiddlewareRequireAttributeMissingAccount(t *testing.T) { test := NewMiddlewareTest(t) handler := RequireAttribute("eduPersonAffiliation", "DomainAdmins")( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { panic("not reached") })) @@ -409,9 +407,10 @@ func TestMiddlewareCanParseResponse(t *testing.T) { assert.Check(t, is.Equal("/frob", resp.Header().Get("Location"))) assert.Check(t, is.DeepEqual([]string{ - "saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6=; Domain=15661444.ngrok.io; Expires=Thu, 01 Jan 1970 00:00:01 GMT", + "saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6=; Path=/saml2/acs; Domain=15661444.ngrok.io; Expires=Thu, 01 Jan 1970 00:00:01 GMT", "ttt=" + test.expectedSessionCookie + "; " + - "Path=/; Domain=15661444.ngrok.io; Max-Age=7200; HttpOnly; Secure"}, + "Path=/; Domain=15661444.ngrok.io; Max-Age=7200; HttpOnly; Secure", + }, resp.Header()["Set-Cookie"])) } @@ -455,7 +454,7 @@ func TestMiddlewareDefaultCookieDomainIPv6(t *testing.T) { func TestMiddlewareRejectsInvalidRelayState(t *testing.T) { test := NewMiddlewareTest(t) - test.Middleware.OnError = func(w http.ResponseWriter, r *http.Request, err error) { + test.Middleware.OnError = func(w http.ResponseWriter, _ *http.Request, err error) { assert.Check(t, is.Error(err, http.ErrNoCookie.Error())) http.Error(w, "forbidden", http.StatusTeapot) } @@ -478,7 +477,7 @@ func TestMiddlewareRejectsInvalidRelayState(t *testing.T) { func TestMiddlewareRejectsInvalidCookie(t *testing.T) { test := NewMiddlewareTest(t) - test.Middleware.OnError = func(w http.ResponseWriter, r *http.Request, err error) { + test.Middleware.OnError = func(w http.ResponseWriter, _ *http.Request, err error) { assert.Check(t, is.Error(err, "Authentication failed")) http.Error(w, "forbidden", http.StatusTeapot) } diff --git a/samlsp/new.go b/samlsp/new.go index ec8934c6..39827e0c 100644 --- a/samlsp/new.go +++ b/samlsp/new.go @@ -2,11 +2,15 @@ package samlsp import ( + "crypto" + "crypto/ecdsa" "crypto/rsa" "crypto/x509" + "fmt" "net/http" "net/url" + "github.com/golang-jwt/jwt/v5" dsig "github.com/russellhaering/goxmldsig" "github.com/clerk/saml" @@ -16,7 +20,7 @@ import ( type Options struct { EntityID string URL url.URL - Key *rsa.PrivateKey + Key crypto.Signer Certificate *x509.Certificate Intermediates []*x509.Certificate HTTPClient *http.Client @@ -33,11 +37,23 @@ type Options struct { LogoutBindings []string } +func getDefaultSigningMethod(signer crypto.Signer) jwt.SigningMethod { + if signer != nil { + switch signer.Public().(type) { + case *ecdsa.PublicKey: + return jwt.SigningMethodES256 + case *rsa.PublicKey: + return jwt.SigningMethodRS256 + } + } + return jwt.SigningMethodRS256 +} + // DefaultSessionCodec returns the default SessionCodec for the provided options, // a JWTSessionCodec configured to issue signed tokens. func DefaultSessionCodec(opts Options) JWTSessionCodec { return JWTSessionCodec{ - SigningMethod: defaultJWTSigningMethod, + SigningMethod: getDefaultSigningMethod(opts.Key), Audience: opts.URL.String(), Issuer: opts.URL.String(), MaxAge: defaultSessionMaxAge, @@ -67,7 +83,7 @@ func DefaultSessionProvider(opts Options) CookieSessionProvider { // options, a JWTTrackedRequestCodec that uses a JWT to encode TrackedRequests. func DefaultTrackedRequestCodec(opts Options) JWTTrackedRequestCodec { return JWTTrackedRequestCodec{ - SigningMethod: defaultJWTSigningMethod, + SigningMethod: getDefaultSigningMethod(opts.Key), Audience: opts.URL.String(), Issuer: opts.URL.String(), MaxAge: saml.MaxIssueDelay, @@ -99,7 +115,8 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider { if opts.ForceAuthn { forceAuthn = &opts.ForceAuthn } - signatureMethod := dsig.RSASHA1SignatureMethod + + signatureMethod := defaultSigningMethodForKey(opts.Key) if !opts.SignRequest { signatureMethod = "" } @@ -131,6 +148,25 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider { } } +func defaultSigningMethodForKey(key crypto.Signer) string { + switch key.(type) { + case *rsa.PrivateKey: + return dsig.RSASHA1SignatureMethod + case *ecdsa.PrivateKey: + return dsig.ECDSASHA256SignatureMethod + case nil: + return "" + default: + panic(fmt.Sprintf("programming error: unsupported key type %T", key)) + } +} + +// DefaultAssertionHandler returns the default AssertionHandler for the provided options, +// a NopAssertionHandler configured to do nothing. +func DefaultAssertionHandler(_ Options) NopAssertionHandler { + return NopAssertionHandler{} +} + // New creates a new Middleware with the default providers for the // given options. // @@ -139,11 +175,12 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider { // in the returned Middleware. func New(opts Options) (*Middleware, error) { m := &Middleware{ - ServiceProvider: DefaultServiceProvider(opts), - Binding: "", - ResponseBinding: saml.HTTPPostBinding, - OnError: DefaultOnError, - Session: DefaultSessionProvider(opts), + ServiceProvider: DefaultServiceProvider(opts), + Binding: "", + ResponseBinding: saml.HTTPPostBinding, + OnError: DefaultOnError, + Session: DefaultSessionProvider(opts), + AssertionHandler: DefaultAssertionHandler(opts), } m.RequestTracker = DefaultRequestTracker(opts, &m.ServiceProvider) if opts.UseArtifactResponse { diff --git a/samlsp/new_test.go b/samlsp/new_test.go index 7f0a760a..86eb49ee 100644 --- a/samlsp/new_test.go +++ b/samlsp/new_test.go @@ -3,7 +3,6 @@ package samlsp import ( "testing" - "github.com/stretchr/testify/require" "gotest.tools/assert" ) @@ -24,7 +23,7 @@ func TestNewCanAcceptCookieName(t *testing.T) { CookieName: tc.cookieName, } sp, err := New(opts) - require.Nil(t, err) + assert.Assert(t, err) cookieProvider := sp.Session.(CookieSessionProvider) assert.Equal(t, tc.expected, cookieProvider.Name) diff --git a/samlsp/request_tracker_cookie.go b/samlsp/request_tracker_cookie.go index 9a420c20..07dab43b 100644 --- a/samlsp/request_tracker_cookie.go +++ b/samlsp/request_tracker_cookie.go @@ -67,6 +67,7 @@ func (t CookieRequestTracker) StopTrackingRequest(w http.ResponseWriter, r *http cookie.Value = "" cookie.Domain = t.ServiceProvider.AcsURL.Hostname() cookie.Expires = time.Unix(1, 0) // past time as close to epoch as possible, but not zero time.Time{} + cookie.Path = t.ServiceProvider.AcsURL.Path http.SetCookie(w, cookie) return nil } diff --git a/samlsp/request_tracker_jwt.go b/samlsp/request_tracker_jwt.go index 906702c0..6ba4616a 100644 --- a/samlsp/request_tracker_jwt.go +++ b/samlsp/request_tracker_jwt.go @@ -1,24 +1,22 @@ package samlsp import ( - "crypto/rsa" + "crypto" "fmt" "time" - "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v5" "github.com/clerk/saml" ) -var defaultJWTSigningMethod = jwt.SigningMethodRS256 - // JWTTrackedRequestCodec encodes TrackedRequests as signed JWTs type JWTTrackedRequestCodec struct { SigningMethod jwt.SigningMethod Audience string Issuer string MaxAge time.Duration - Key *rsa.PrivateKey + Key crypto.Signer } var _ TrackedRequestCodec = JWTTrackedRequestCodec{} @@ -51,9 +49,12 @@ func (s JWTTrackedRequestCodec) Encode(value TrackedRequest) (string, error) { // Decode returns a Tracked request from an encoded string. func (s JWTTrackedRequestCodec) Decode(signed string) (*TrackedRequest, error) { - parser := jwt.Parser{ - ValidMethods: []string{s.SigningMethod.Alg()}, - } + parser := jwt.NewParser( + jwt.WithValidMethods([]string{s.SigningMethod.Alg()}), + jwt.WithTimeFunc(saml.TimeNow), + jwt.WithAudience(s.Audience), + jwt.WithIssuer(s.Issuer), + ) claims := JWTTrackedRequestClaims{} _, err := parser.ParseWithClaims(signed, &claims, func(*jwt.Token) (interface{}, error) { return s.Key.Public(), nil @@ -61,15 +62,9 @@ func (s JWTTrackedRequestCodec) Decode(signed string) (*TrackedRequest, error) { if err != nil { return nil, err } - if !claims.VerifyAudience(s.Audience, true) { - return nil, fmt.Errorf("expected audience %q, got %q", s.Audience, claims.Audience) - } - if !claims.VerifyIssuer(s.Issuer, true) { - return nil, fmt.Errorf("expected issuer %q, got %q", s.Issuer, claims.Issuer) - } if !claims.SAMLAuthnRequest { return nil, fmt.Errorf("expected saml-authn-request") } - claims.TrackedRequest.Index = claims.Subject + claims.Index = claims.Subject return &claims.TrackedRequest, nil } diff --git a/samlsp/saml_assertion_handler.go b/samlsp/saml_assertion_handler.go new file mode 100644 index 00000000..a18f00d1 --- /dev/null +++ b/samlsp/saml_assertion_handler.go @@ -0,0 +1,9 @@ +package samlsp + +import "github.com/clerk/saml" + +// AssertionHandler is an interface implemented by types that can handle +// assertions and add extra functionality +type AssertionHandler interface { + HandleAssertion(assertion *saml.Assertion) error +} diff --git a/samlsp/session_cookie.go b/samlsp/session_cookie.go index 17f9be6a..4e44fc67 100644 --- a/samlsp/session_cookie.go +++ b/samlsp/session_cookie.go @@ -21,6 +21,7 @@ type CookieSessionProvider struct { Secure bool SameSite http.SameSite MaxAge time.Duration + Path string Codec SessionCodec } @@ -43,6 +44,11 @@ func (c CookieSessionProvider) CreateSession(w http.ResponseWriter, r *http.Requ return err } + path := c.Path + if path == "" { + path = "/" + } + http.SetCookie(w, &http.Cookie{ Name: c.Name, Domain: c.Domain, @@ -51,7 +57,7 @@ func (c CookieSessionProvider) CreateSession(w http.ResponseWriter, r *http.Requ HttpOnly: c.HTTPOnly, Secure: c.Secure || r.URL.Scheme == "https", SameSite: c.SameSite, - Path: "/", + Path: path, }) return nil } diff --git a/samlsp/session_jwt.go b/samlsp/session_jwt.go index def76908..8760237d 100644 --- a/samlsp/session_jwt.go +++ b/samlsp/session_jwt.go @@ -1,12 +1,11 @@ package samlsp import ( - "crypto/rsa" + "crypto" "errors" - "fmt" "time" - "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v5" "github.com/clerk/saml" ) @@ -23,7 +22,7 @@ type JWTSessionCodec struct { Audience string Issuer string MaxAge time.Duration - Key *rsa.PrivateKey + Key crypto.Signer } var _ SessionCodec = JWTSessionCodec{} @@ -35,11 +34,11 @@ func (c JWTSessionCodec) New(assertion *saml.Assertion) (Session, error) { now := saml.TimeNow() claims := JWTSessionClaims{} claims.SAMLSession = true - claims.Audience = c.Audience + claims.Audience = jwt.ClaimStrings{c.Audience} claims.Issuer = c.Issuer - claims.IssuedAt = now.Unix() - claims.ExpiresAt = now.Add(c.MaxAge).Unix() - claims.NotBefore = now.Unix() + claims.IssuedAt = jwt.NewNumericDate(now) + claims.ExpiresAt = jwt.NewNumericDate(now.Add(c.MaxAge)) + claims.NotBefore = jwt.NewNumericDate(now) if sub := assertion.Subject; sub != nil { if nameID := sub.NameID; nameID != nil { @@ -89,9 +88,12 @@ func (c JWTSessionCodec) Encode(s Session) (string, error) { // Decode parses the serialized session that may have been returned by Encode // and returns a Session. func (c JWTSessionCodec) Decode(signed string) (Session, error) { - parser := jwt.Parser{ - ValidMethods: []string{c.SigningMethod.Alg()}, - } + parser := jwt.NewParser( + jwt.WithValidMethods([]string{c.SigningMethod.Alg()}), + jwt.WithTimeFunc(saml.TimeNow), + jwt.WithAudience(c.Audience), + jwt.WithIssuer(c.Issuer), + ) claims := JWTSessionClaims{} _, err := parser.ParseWithClaims(signed, &claims, func(*jwt.Token) (interface{}, error) { return c.Key.Public(), nil @@ -100,12 +102,6 @@ func (c JWTSessionCodec) Decode(signed string) (Session, error) { if err != nil { return nil, err } - if !claims.VerifyAudience(c.Audience, true) { - return nil, fmt.Errorf("expected audience %q, got %q", c.Audience, claims.Audience) - } - if !claims.VerifyIssuer(c.Issuer, true) { - return nil, fmt.Errorf("expected issuer %q, got %q", c.Issuer, claims.Issuer) - } if !claims.SAMLSession { return nil, errors.New("expected saml-session") } @@ -114,7 +110,7 @@ func (c JWTSessionCodec) Decode(signed string) (Session, error) { // JWTSessionClaims represents the JWT claims in the encoded session type JWTSessionClaims struct { - jwt.StandardClaims + jwt.RegisteredClaims Attributes Attributes `json:"attr"` SAMLSession bool `json:"saml-session"` } diff --git a/schema.go b/schema.go index 23cddbca..1a543de2 100644 --- a/schema.go +++ b/schema.go @@ -353,12 +353,15 @@ func (r *ArtifactResolve) Element() *etree.Element { if r.Issuer != nil { el.AddChild(r.Issuer.Element()) } - artifact := etree.NewElement("samlp:Artifact") - artifact.SetText(r.Artifact) - el.AddChild(artifact) if r.Signature != nil { + // ADFS requires that