Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions rest_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package cas

import (
"fmt"
"github.com/golang/glog"
"github.com/patrickmn/go-cache"
"io/ioutil"
"net/http"
"net/url"
"path"
"time"

"github.com/golang/glog"
"github.com/patrickmn/go-cache"
)

// https://apereo.github.io/cas/4.2.x/protocol/REST-Protocol.html
Expand All @@ -28,6 +29,16 @@ type RestOptions struct {
ForwardUnauthenticatedRESTRequests bool
}

// RestAuthenticator handles the cas authentication via the rest protocol.
//
// restClientHandler depends on this interface rather than on *RestClient,
// which lets tests substitute a mock implementation.
type RestAuthenticator interface {
	// Handle wraps h so that requests reaching it are authenticated via the
	// CAS REST protocol (presumably returning a restClientHandler — confirm
	// against the implementation).
	Handle(h http.Handler) http.Handler

	// RequestGrantingTicket exchanges a username/password pair for a
	// ticket-granting ticket (TGT).
	RequestGrantingTicket(username, password string) (TicketGrantingTicket, error)

	// RequestServiceTicket obtains a service ticket (ST) using tgt.
	RequestServiceTicket(tgt TicketGrantingTicket) (ServiceTicket, error)

	// ValidateServiceTicket validates st and returns the authentication
	// response on success.
	ValidateServiceTicket(st ServiceTicket) (*AuthenticationResponse, error)

	// Logout ends the session associated with tgt.
	Logout(tgt TicketGrantingTicket) error

	// ShallForwardUnauthenticatedRESTRequests reports whether requests that
	// cannot be authenticated should still be passed to the wrapped handler
	// instead of being rejected with 401.
	ShallForwardUnauthenticatedRESTRequests() bool
}

// RestClient uses the rest protocol provided by cas
type RestClient struct {
urlScheme URLScheme
Expand Down Expand Up @@ -182,3 +193,8 @@ func (c *RestClient) Logout(tgt TicketGrantingTicket) error {

return nil
}

// ShallForwardUnauthenticatedRESTRequests reports whether requests that could
// not be authenticated against CAS should be handed to the wrapped handler
// (e.g. for local user authentication or anonymous access) rather than being
// answered with a Basic challenge.
func (c *RestClient) ShallForwardUnauthenticatedRESTRequests() bool {
	forward := c.forwardUnauthenticatedRESTRequests
	return forward
}
75 changes: 47 additions & 28 deletions rest_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ import (

// restClientHandler handles CAS REST Protocol over HTTP Basic Authentication
type restClientHandler struct {
c *RestClient
c RestAuthenticator
h http.Handler
cache *cache.Cache
}

// reaction is the remembered outcome of one authentication attempt: a function
// that either serves the request through the wrapped handler or treats it as
// unauthenticated. Reactions are cached per Authorization header so repeated
// requests with the same credentials do not contact the CAS server again.
type reaction func(w http.ResponseWriter, r *http.Request)

// ServeHTTP handles HTTP requests, processes HTTP Basic Authentication over CAS Rest api
// and passes requests up to its child http.Handler.
func (ch *restClientHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Expand All @@ -23,45 +25,62 @@ func (ch *restClientHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

username, password, ok := r.BasicAuth()
if !ok {
w.Header().Set("WWW-Authenticate", "Basic realm=\"CAS Protected Area\"")
w.WriteHeader(401)
ch.handleUnauthenticatedRequest(w, r)
return
}

// cache to avoid hitting cas server on every request
// use the authorization header as key and the authenticationResponse as value
authorizationHeader := r.Header.Get("Authorization")
authenticationResponse, keyWasFound := ch.cache.Get(authorizationHeader)
_, keyWasFound := ch.cache.Get(authorizationHeader)
if !keyWasFound {
newAuthenticationResponse, err := ch.authenticate(username, password)
if err != nil {
if glog.V(1) {
glog.Infof("cas: rest authentication failed %v", err)
}
// TODO: Check which kind of error (timeout? 401? 50X?) occurred and act appropriately
if ch.c.forwardUnauthenticatedRESTRequests {
if glog.V(1) {
glog.Infof("unauthenticated request will be forwarded to application")
}
// forward REST request for potential local user authentication
ch.h.ServeHTTP(w, r)
} else {
// TODO: cache unauthenticated requests
w.Header().Set("WWW-Authenticate", "Basic realm=\"CAS Protected Area\"")
w.WriteHeader(401)
}
return
}
ch.cache.Set(authorizationHeader, newAuthenticationResponse, cache.DefaultExpiration)
setFirstAuthenticatedRequest(r, true)
reaction := ch.tryToAuthenticateAndCreateReaction(r, username, password)
ch.cache.Set(authorizationHeader, reaction, cache.DefaultExpiration)
}

authenticationResponse, keyWasFound = ch.cache.Get(authorizationHeader)
setAuthenticationResponse(r, authenticationResponse.(*AuthenticationResponse))
ch.h.ServeHTTP(w, r)
cachedReaction, keyWasFound := ch.cache.Get(authorizationHeader)
if f, ok := cachedReaction.(reaction); ok && keyWasFound {
f(w, r)
} else {
if glog.V(1) {
glog.Error("Unexpected behaviour: did not find a cached reaction for given authorizationHeader")
}
}
return
}

// tryToAuthenticateAndCreateReaction authenticates username/password against
// CAS once and returns the reaction to apply to this request and to later
// requests carrying the same credentials.
//
// On failure the reaction is handleUnauthenticatedRequest, which either
// forwards or answers 401 depending on configuration. On success, request is
// marked as the first authenticated request, and the returned reaction
// attaches the authentication response to each request it serves before
// passing it to the wrapped handler.
func (ch *restClientHandler) tryToAuthenticateAndCreateReaction(request *http.Request, username string, password string) reaction {
	authResponse, err := ch.authenticate(username, password)
	if err != nil {
		if glog.V(1) {
			glog.Infof("cas: rest authentication failed %v", err)
		}
		// TODO: Check which kind of error (timeout? 401? 50X?) occurred and act appropriately
		// The method value has the reaction signature, so no wrapper closure is needed.
		return ch.handleUnauthenticatedRequest
	}

	// Only the request that triggered the CAS round-trip is flagged here;
	// requests later served from the cached reaction are not.
	setFirstAuthenticatedRequest(request, true)
	return func(w http.ResponseWriter, r *http.Request) {
		setAuthenticationResponse(r, authResponse)
		ch.h.ServeHTTP(w, r)
	}
}

// handleUnauthenticatedRequest deals with a request that carries no usable CAS
// credentials. If the authenticator is configured not to forward such requests,
// it sends a Basic challenge with status 401. Otherwise it passes the request
// unchanged to the wrapped handler, which may authenticate the user locally or
// treat them as anonymous.
func (ch *restClientHandler) handleUnauthenticatedRequest(w http.ResponseWriter, r *http.Request) {
	if !ch.c.ShallForwardUnauthenticatedRESTRequests() {
		w.Header().Set("WWW-Authenticate", `Basic realm="CAS Protected Area"`)
		w.WriteHeader(http.StatusUnauthorized)
		return
	}

	if glog.V(1) {
		glog.Info("unauthenticated request will be forwarded to application")
	}
	ch.h.ServeHTTP(w, r)
}

func (ch *restClientHandler) authenticate(username string, password string) (*AuthenticationResponse, error) {
tgt, err := ch.c.RequestGrantingTicket(username, password)
if err != nil {
Expand Down
177 changes: 177 additions & 0 deletions rest_handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package cas

import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/pkg/errors"

"github.com/patrickmn/go-cache"

"github.com/stretchr/testify/mock"

"github.com/stretchr/testify/assert"
)

// MockedRestClient is a testify mock of RestAuthenticator, used to drive
// restClientHandler in tests without contacting a CAS server.
type MockedRestClient struct {
	mock.Mock
}

// Compile-time check: the mock must keep satisfying the interface the handler
// depends on, so interface changes break the build here rather than at runtime.
var _ RestAuthenticator = (*MockedRestClient)(nil)

// Handle records the call and returns the handler configured with
// m.On("Handle", ...).Return(handler). A configured nil is returned as a nil
// handler instead of panicking; a value of the wrong type still panics so a
// misconfigured expectation is noticed.
func (m *MockedRestClient) Handle(h http.Handler) http.Handler {
	args := m.Called(h)
	var handler http.Handler
	if v := args.Get(0); v != nil {
		handler = v.(http.Handler)
	}
	return handler
}

// RequestGrantingTicket records the call and returns the (TGT, error) pair
// configured for username/password. A configured nil TGT (e.g.
// Return(nil, errors.New("failed"))) yields the zero TicketGrantingTicket
// instead of panicking on the type assertion.
func (m *MockedRestClient) RequestGrantingTicket(username string, password string) (TicketGrantingTicket, error) {
	args := m.Called(username, password)
	var tgt TicketGrantingTicket
	if v := args.Get(0); v != nil {
		tgt = v.(TicketGrantingTicket)
	}
	return tgt, args.Error(1)
}

// RequestServiceTicket records the call and returns the (ST, error) pair
// configured for tgt. A configured nil ST yields the zero ServiceTicket
// instead of panicking on the type assertion.
func (m *MockedRestClient) RequestServiceTicket(tgt TicketGrantingTicket) (ServiceTicket, error) {
	args := m.Called(tgt)
	var st ServiceTicket
	if v := args.Get(0); v != nil {
		st = v.(ServiceTicket)
	}
	return st, args.Error(1)
}

// ValidateServiceTicket records the call and returns the (response, error)
// pair configured for st. A configured untyped nil (e.g.
// Return(nil, errors.New("invalid"))) yields a nil *AuthenticationResponse
// instead of panicking on the type assertion.
func (m *MockedRestClient) ValidateServiceTicket(st ServiceTicket) (*AuthenticationResponse, error) {
	args := m.Called(st)
	var resp *AuthenticationResponse
	if v := args.Get(0); v != nil {
		resp = v.(*AuthenticationResponse)
	}
	return resp, args.Error(1)
}

// Logout records the call and returns the error configured for tgt via
// m.On("Logout", ...).Return(err).
func (m *MockedRestClient) Logout(tgt TicketGrantingTicket) error {
	return m.Called(tgt).Error(0)
}

// ShallForwardUnauthenticatedRESTRequests records the call and returns the
// bool configured via m.On("ShallForwardUnauthenticatedRESTRequests").Return(b).
func (m *MockedRestClient) ShallForwardUnauthenticatedRESTRequests() bool {
	return m.Called().Bool(0)
}

// serveCounter is a stub http.Handler that only counts its invocations, so
// tests can observe whether a request was forwarded to the application.
type serveCounter struct {
	counter int // number of ServeHTTP calls so far
}

// ServeHTTP increments the counter; the response writer and request are unused.
func (s *serveCounter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
	s.counter++
}

// TestServeHTTP checks that a request with valid Basic credentials is
// authenticated and forwarded to the wrapped handler exactly once.
func TestServeHTTP(t *testing.T) {
	// httptest.NewRequest builds a server-side request and panics on a bad
	// target, so there is no error to silently drop.
	req := httptest.NewRequest(http.MethodGet, "/foo", nil)
	req.SetBasicAuth("dirk", "gently")
	m := new(MockedRestClient)
	m.On("RequestGrantingTicket", "dirk", "gently").Return(TicketGrantingTicket("tgt"), nil)
	m.On("RequestServiceTicket", mock.Anything).Return(ServiceTicket("st"), nil)
	m.On("ValidateServiceTicket", mock.Anything).Return(&AuthenticationResponse{}, nil)
	s := &serveCounter{0}
	r := restClientHandler{c: m, h: s, cache: cache.New(time.Minute, time.Minute)}
	w := httptest.NewRecorder()

	r.ServeHTTP(w, req)

	// testify's assert.Equal takes (expected, actual).
	assert.Equal(t, 1, s.counter)
}

// TestServeHTTPCaching checks that a second request with the same
// Authorization header is served from the cache without contacting CAS again.
func TestServeHTTPCaching(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/foo", nil)
	req.SetBasicAuth("dirk", "gently")
	m := new(MockedRestClient)
	// Only the first TGT request succeeds. testify picks the first matching
	// expectation that still has repetitions left, so without Once() the later
	// failing expectation below would never be reached.
	m.On("RequestGrantingTicket", "dirk", "gently").Return(TicketGrantingTicket("tgt"), nil).Once()
	m.On("RequestServiceTicket", mock.Anything).Return(ServiceTicket("st"), nil)
	m.On("ValidateServiceTicket", mock.Anything).Return(&AuthenticationResponse{}, nil)
	s := &serveCounter{0}
	r := restClientHandler{c: m, h: s, cache: cache.New(time.Minute, time.Minute)}
	w := httptest.NewRecorder()

	r.ServeHTTP(w, req)
	assert.Equal(t, 1, s.counter)

	// Any further authentication against CAS would now fail, so a forwarded
	// second request can only have come from the cached reaction.
	m.On("RequestGrantingTicket", "dirk", "gently").Return(TicketGrantingTicket(""), errors.New("failed"))

	r.ServeHTTP(w, req)
	assert.Equal(t, 2, s.counter)
	m.AssertNumberOfCalls(t, "RequestGrantingTicket", 1)
}

// TestServeHTTPWithWrongCredentialsAndForward checks that, when forwarding is
// enabled, a request whose credentials CAS rejects is still passed to the
// wrapped handler and no Basic challenge is issued.
func TestServeHTTPWithWrongCredentialsAndForward(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/foo", nil)
	req.SetBasicAuth("dirk", "gently")
	m := new(MockedRestClient)
	m.On("RequestGrantingTicket", "dirk", "gently").Return(TicketGrantingTicket("tgt"),
		errors.New("wrong creds"))
	m.On("ShallForwardUnauthenticatedRESTRequests").Return(true)
	s := &serveCounter{0}
	r := restClientHandler{c: m, h: s, cache: cache.New(time.Minute, time.Minute)}
	w := httptest.NewRecorder()

	r.ServeHTTP(w, req)

	assert.Equal(t, 1, s.counter)
	assert.Empty(t, w.Header().Get("WWW-Authenticate"))
}

// TestServeHTTPWithWrongCredentialsAndForwardCaching checks that a failed
// authentication is cached too: a repeated request with the same credentials
// reuses the failure reaction even though CAS would now accept them.
func TestServeHTTPWithWrongCredentialsAndForwardCaching(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/foo", nil)
	req.SetBasicAuth("dirk", "gently")
	m := new(MockedRestClient)
	// Once() lets the succeeding expectation registered below take effect for
	// any later call, so the outcome really depends on the cache. testify
	// otherwise keeps matching the first, unlimited expectation.
	m.On("RequestGrantingTicket", "dirk", "gently").Return(TicketGrantingTicket("tgt"),
		errors.New("wrong creds")).Once()
	m.On("ShallForwardUnauthenticatedRESTRequests").Return(true)
	s := &serveCounter{0}
	r := restClientHandler{c: m, h: s, cache: cache.New(time.Minute, time.Minute)}
	w := httptest.NewRecorder()

	r.ServeHTTP(w, req)
	assert.Equal(t, 1, s.counter)

	m.On("RequestGrantingTicket", "dirk", "gently").Return(TicketGrantingTicket("tgt"),
		nil)
	r.ServeHTTP(w, req)
	// The cached failure reaction was reused, so no authentication response was attached.
	assert.Nil(t, getAuthenticationResponse(req))
	assert.Equal(t, 2, s.counter)
	m.AssertNumberOfCalls(t, "RequestGrantingTicket", 1)
}

// TestServeHTTPWithWrongCredentialsAndWithoutForward checks that, when
// forwarding is disabled, rejected credentials produce a 401 Basic challenge
// and the wrapped handler is not called.
func TestServeHTTPWithWrongCredentialsAndWithoutForward(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/foo", nil)
	req.SetBasicAuth("dirk", "gently")
	m := new(MockedRestClient)
	m.On("RequestGrantingTicket", "dirk", "gently").Return(TicketGrantingTicket("tgt"),
		errors.New("wrong creds"))
	m.On("ShallForwardUnauthenticatedRESTRequests").Return(false)
	s := &serveCounter{0}
	r := restClientHandler{c: m, h: s, cache: cache.New(time.Minute, time.Minute)}
	w := httptest.NewRecorder()

	r.ServeHTTP(w, req)

	assert.Equal(t, 0, s.counter)
	assert.Equal(t, http.StatusUnauthorized, w.Code)
	assert.Equal(t, "Basic realm=\"CAS Protected Area\"", w.Header().Get("WWW-Authenticate"))
}

// TestServeHTTPWithoutBasicAuthAndForward checks that, when forwarding is
// enabled, a request without Basic credentials reaches the wrapped handler.
func TestServeHTTPWithoutBasicAuthAndForward(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/foo", nil)
	m := new(MockedRestClient)
	m.On("ShallForwardUnauthenticatedRESTRequests").Return(true)
	s := &serveCounter{0}
	r := restClientHandler{c: m, h: s, cache: cache.New(time.Minute, time.Minute)}
	w := httptest.NewRecorder()

	r.ServeHTTP(w, req)

	assert.Equal(t, 1, s.counter)
}

// TestServeHTTPWithoutBasicAuthAndWithoutForward checks that, when forwarding
// is disabled, a request without Basic credentials gets a 401 Basic challenge
// and the wrapped handler is not called.
func TestServeHTTPWithoutBasicAuthAndWithoutForward(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/foo", nil)
	m := new(MockedRestClient)
	m.On("ShallForwardUnauthenticatedRESTRequests").Return(false)
	s := &serveCounter{0}
	r := restClientHandler{c: m, h: s, cache: cache.New(time.Minute, time.Minute)}
	w := httptest.NewRecorder()

	r.ServeHTTP(w, req)

	assert.Equal(t, 0, s.counter)
	assert.Equal(t, http.StatusUnauthorized, w.Code)
	assert.Equal(t, "Basic realm=\"CAS Protected Area\"", w.Header().Get("WWW-Authenticate"))
}