Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 64 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"strings"
)

// Config is the plugin configuration decoded from JSON; New copies each
// field into the corresponding unexported field of the Badger it returns.
type Config struct {
APIBaseUrl string `json:"apiBaseUrl"`
UserSessionCookieName string `json:"userSessionCookieName"`
ResourceSessionRequestParam string `json:"resourceSessionRequestParam"`
APIBaseUrl string `json:"apiBaseUrl"`
UserSessionCookieName string `json:"userSessionCookieName"`
ResourceSessionRequestParam string `json:"resourceSessionRequestParam"`
// ClientIPHeader optionally names the request header getClientIP reads the
// client IP from. When nil or empty, the host part of req.RemoteAddr is used.
// Pointer + omitempty so an absent key in the JSON config leaves it nil.
ClientIPHeader *string `json:"clientIpHeader,omitempty"`
}

type Badger struct {
Expand All @@ -21,6 +23,7 @@ type Badger struct {
apiBaseUrl string
userSessionCookieName string
resourceSessionRequestParam string
clientIPHeader *string
}

type VerifyBody struct {
Expand Down Expand Up @@ -72,19 +75,20 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
apiBaseUrl: config.APIBaseUrl,
userSessionCookieName: config.UserSessionCookieName,
resourceSessionRequestParam: config.ResourceSessionRequestParam,
clientIPHeader: config.ClientIPHeader,
}, nil
}

func (p *Badger) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
cookies := p.extractCookies(req)

clientIP := p.getClientIP(req)
queryValues := req.URL.Query()

if sessionRequestValue := queryValues.Get(p.resourceSessionRequestParam); sessionRequestValue != "" {
body := ExchangeSessionBody{
RequestToken: &sessionRequestValue,
RequestHost: &req.Host,
RequestIP: &req.RemoteAddr,
RequestIP: clientIP,
}

jsonData, err := json.Marshal(body)
Expand Down Expand Up @@ -160,7 +164,7 @@ func (p *Badger) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
RequestPath: &req.URL.Path,
RequestMethod: &req.Method,
TLS: req.TLS != nil,
RequestIP: &req.RemoteAddr,
RequestIP: clientIP,
Headers: headers,
Query: queryParams,
}
Expand Down Expand Up @@ -250,3 +254,57 @@ func (p *Badger) getScheme(req *http.Request) string {
}
return "http"
}

// getClientIP returns the IP address that identifies the client of req.
//
// When no clientIPHeader is configured (nil or empty), the host part of
// req.RemoteAddr is used. Otherwise the configured header is consulted:
//   - For X-Forwarded-For (header name matched case-insensitively) the value is
//     treated as a comma-separated list and the first entry that parses as an
//     IP address is used.
//   - For any other header the whole value must parse as an IP address.
//
// Header entries may carry a port ("203.0.113.7:4711", "[2001:db8::1]:443");
// the port is stripped. Whenever the header is absent or holds no usable IP,
// the function falls back to req.RemoteAddr, so the result is never nil.
//
// NOTE(security): header values are client-controlled unless a trusted proxy
// overwrites them. In particular the leftmost X-Forwarded-For entry can be
// forged by the client when the proxy appends to (rather than replaces) the
// header; only configure clientIPHeader behind a proxy that sanitizes it.
func (p *Badger) getClientIP(req *http.Request) *string {
	if p.clientIPHeader == nil || *p.clientIPHeader == "" {
		// No header configured: RemoteAddr is the safe default.
		return remoteAddrIP(req)
	}

	headerName := *p.clientIPHeader
	headerValue := req.Header.Get(headerName)
	if headerValue == "" {
		return remoteAddrIP(req)
	}

	if strings.EqualFold(headerName, "X-Forwarded-For") {
		// Format is "client, proxy1, proxy2"; take the first valid entry.
		for _, entry := range strings.Split(headerValue, ",") {
			if ip, ok := parseHeaderIP(entry); ok {
				return &ip
			}
		}
		return remoteAddrIP(req)
	}

	if ip, ok := parseHeaderIP(headerValue); ok {
		return &ip
	}
	// Header present but not an IP address: fall back to RemoteAddr.
	return remoteAddrIP(req)
}

// remoteAddrIP returns the host part of req.RemoteAddr, or RemoteAddr
// unchanged when it is not in host:port form.
func remoteAddrIP(req *http.Request) *string {
	host, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		// Not host:port; assume RemoteAddr is already just an IP.
		host = req.RemoteAddr
	}
	return &host
}

// parseHeaderIP trims s and reports whether it is an IP address, optionally
// followed by a port; on success it returns the IP text without the port.
func parseHeaderIP(s string) (string, bool) {
	s = strings.TrimSpace(s)
	if net.ParseIP(s) != nil {
		return s, true
	}
	if host, _, err := net.SplitHostPort(s); err == nil && net.ParseIP(host) != nil {
		return host, true
	}
	return "", false
}