From 34263b479b0bc8beed03ddf350ec4a524ceca6bc Mon Sep 17 00:00:00 2001 From: Lz Date: Sun, 9 Feb 2025 18:09:46 +0800 Subject: [PATCH 1/8] fix(context): exposed StatusCode (#38) --- compressor_deflate.go | 6 ++++-- compressor_gzip.go | 6 ++++-- context.go | 10 +++------- context_test.go | 2 +- ext/cookie/cookie_test.go | 14 +++++++------- response_writer.go | 1 + response_writer_deflate.go | 11 +++++------ response_writer_gzip.go | 11 +++++------ response_writer_std.go | 19 +++++++++++++++++++ response_writer_std_test.go | 22 ++++++++++++++++++++++ viewer_file.go | 4 +--- 11 files changed, 72 insertions(+), 34 deletions(-) create mode 100644 response_writer_std_test.go diff --git a/compressor_deflate.go b/compressor_deflate.go index 20e85fe..06cd8ed 100644 --- a/compressor_deflate.go +++ b/compressor_deflate.go @@ -23,7 +23,9 @@ func (c *DeflateCompressor) New(rw http.ResponseWriter) ResponseWriter { w, _ := flate.NewWriter(rw, flate.DefaultCompression) //nolint: errcheck because flate.DefaultCompression is a valid compression level return &deflateResponseWriter{ - w: w, - ResponseWriter: rw, + w: w, + stdResponseWriter: &stdResponseWriter{ + ResponseWriter: rw, + }, } } diff --git a/compressor_gzip.go b/compressor_gzip.go index 22e5dfa..d0ad295 100644 --- a/compressor_gzip.go +++ b/compressor_gzip.go @@ -21,8 +21,10 @@ func (c *GzipCompressor) New(rw http.ResponseWriter) ResponseWriter { rw.Header().Set("Content-Encoding", "gzip") return &gzipResponseWriter{ - w: gzip.NewWriter(rw), - ResponseWriter: rw, + w: gzip.NewWriter(rw), + stdResponseWriter: &stdResponseWriter{ + ResponseWriter: rw, + }, } } diff --git a/context.go b/context.go index cf45a1a..8a2b556 100644 --- a/context.go +++ b/context.go @@ -11,11 +11,10 @@ import ( type Context struct { Routing Routing app *App - Response http.ResponseWriter + Response ResponseWriter Request *http.Request - writtenStatus bool - values map[string]any + values map[string]any } // WriteStatus sets the HTTP status code for 
the response. @@ -23,10 +22,7 @@ type Context struct { // The status code will be sent to the client only once the response body is closed. // If a status code is not set, the default status code is 200 (OK). func (c *Context) WriteStatus(code int) { - if !c.writtenStatus { - c.Response.WriteHeader(code) - c.writtenStatus = true - } + c.Response.WriteHeader(code) } // WriteHeader sets a response header. diff --git a/context_test.go b/context_test.go index 49b206d..8380e88 100644 --- a/context_test.go +++ b/context_test.go @@ -252,7 +252,7 @@ func TestMixedViewers(t *testing.T) { func TestDeleteHeader(t *testing.T) { ctx := &Context{ - Response: httptest.NewRecorder(), + Response: NewResponseWriter(httptest.NewRecorder()), } ctx.WriteHeader("test", "value") diff --git a/ext/cookie/cookie_test.go b/ext/cookie/cookie_test.go index 0ce95ab..922df90 100644 --- a/ext/cookie/cookie_test.go +++ b/ext/cookie/cookie_test.go @@ -18,7 +18,7 @@ func TestCookie(t *testing.T) { t.Run("set", func(t *testing.T) { ctx := &xun.Context{ Request: httptest.NewRequest(http.MethodGet, "/", nil), - Response: httptest.NewRecorder(), + Response: xun.NewResponseWriter(httptest.NewRecorder()), } c := http.Cookie{Name: "test", Value: "value"} @@ -33,7 +33,7 @@ func TestCookie(t *testing.T) { t.Run("get", func(t *testing.T) { ctx := &xun.Context{ Request: httptest.NewRequest(http.MethodGet, "/", nil), - Response: httptest.NewRecorder(), + Response: xun.NewResponseWriter(httptest.NewRecorder()), } c := http.Cookie{Name: "test", Value: "dmFsdWU="} // base64 encoded "value" ctx.Request.Header.Set("Cookie", c.String()) @@ -48,7 +48,7 @@ func TestCookie(t *testing.T) { func TestDelete(t *testing.T) { ctx := &xun.Context{ Request: httptest.NewRequest(http.MethodGet, "/", nil), - Response: httptest.NewRecorder(), + Response: xun.NewResponseWriter(httptest.NewRecorder()), } c := http.Cookie{Name: "test", Value: "dmFsdWU="} // base64 encoded "value" Delete(ctx, c) @@ -62,7 +62,7 @@ func 
TestSignedCookie(t *testing.T) { cookie := http.Cookie{Name: "test", Value: "value"} ctx := &xun.Context{ Request: httptest.NewRequest(http.MethodGet, "/", nil), - Response: httptest.NewRecorder(), + Response: xun.NewResponseWriter(httptest.NewRecorder()), } ts, err := SetSigned(ctx, cookie, []byte("secret")) @@ -79,7 +79,7 @@ func TestSignedCookie(t *testing.T) { t.Run("get", func(t *testing.T) { ctx := &xun.Context{ Request: httptest.NewRequest(http.MethodGet, "/", nil), - Response: httptest.NewRecorder(), + Response: xun.NewResponseWriter(httptest.NewRecorder()), } ts := time.Now() @@ -102,7 +102,7 @@ func TestSignedCookie(t *testing.T) { func TestInvalidCookie(t *testing.T) { t.Run("too_long_value", func(t *testing.T) { ctx := &xun.Context{ - Response: httptest.NewRecorder(), + Response: xun.NewResponseWriter(httptest.NewRecorder()), } err := Set(ctx, http.Cookie{ @@ -143,7 +143,7 @@ func TestInvalidCookie(t *testing.T) { func TestInvalidSigned(t *testing.T) { t.Run("too_long_value", func(t *testing.T) { ctx := &xun.Context{ - Response: httptest.NewRecorder(), + Response: xun.NewResponseWriter(httptest.NewRecorder()), } _, err := SetSigned(ctx, http.Cookie{ diff --git a/response_writer.go b/response_writer.go index 4831c5a..f144383 100644 --- a/response_writer.go +++ b/response_writer.go @@ -10,5 +10,6 @@ import ( type ResponseWriter interface { http.ResponseWriter + StatusCode() int Close() } diff --git a/response_writer_deflate.go b/response_writer_deflate.go index db965e4..ae28516 100644 --- a/response_writer_deflate.go +++ b/response_writer_deflate.go @@ -2,24 +2,23 @@ package xun import ( "compress/flate" - "net/http" ) // deflateResponseWriter is a custom http.ResponseWriter that wraps the standard // ResponseWriter and compresses the response using the deflate algorithm. type deflateResponseWriter struct { + *stdResponseWriter w *flate.Writer - http.ResponseWriter } // Write writes the data to the underlying gzip writer. 
// It implements the io.Writer interface. -func (w *deflateResponseWriter) Write(p []byte) (int, error) { - return w.w.Write(p) +func (rw *deflateResponseWriter) Write(p []byte) (int, error) { + return rw.w.Write(p) } // Close closes the underlying writer, flushing any buffered data to the client. // It is important to call this method to ensure all data is properly sent. -func (w *deflateResponseWriter) Close() { - w.w.Close() +func (rw *deflateResponseWriter) Close() { + rw.w.Close() } diff --git a/response_writer_gzip.go b/response_writer_gzip.go index 27cb2bb..add65e6 100644 --- a/response_writer_gzip.go +++ b/response_writer_gzip.go @@ -2,23 +2,22 @@ package xun import ( "compress/gzip" - "net/http" ) // gzipResponseWriter is a custom http.ResponseWriter that wraps the standard // ResponseWriter and compresses the response using gzip. type gzipResponseWriter struct { + *stdResponseWriter w *gzip.Writer - http.ResponseWriter } // Write writes the data to the underlying gzip writer. // It implements the io.Writer interface. -func (w *gzipResponseWriter) Write(p []byte) (int, error) { - return w.w.Write(p) +func (rw *gzipResponseWriter) Write(p []byte) (int, error) { + return rw.w.Write(p) } // Close closes the gzipResponseWriter, ensuring that the underlying writer is also closed. -func (w *gzipResponseWriter) Close() { - w.w.Close() +func (rw *gzipResponseWriter) Close() { + rw.w.Close() } diff --git a/response_writer_std.go b/response_writer_std.go index 59919d9..b928a8b 100644 --- a/response_writer_std.go +++ b/response_writer_std.go @@ -5,9 +5,28 @@ import "net/http" // stdResponseWriter is a wrapper around http.ResponseWriter to implement the ResponseWriter interface. type stdResponseWriter struct { http.ResponseWriter + statusCode int } // Close implements the ResponseWriter interface Close method. // It is a no-op for the standard response writer. 
func (*stdResponseWriter) Close() { } + +func (rw *stdResponseWriter) WriteHeader(statusCode int) { + if rw.statusCode == 0 { + rw.statusCode = statusCode + rw.ResponseWriter.WriteHeader(statusCode) + } +} + +func (rw *stdResponseWriter) StatusCode() int { + if rw.statusCode == 0 { + return http.StatusOK + } + return rw.statusCode +} + +func NewResponseWriter(rw http.ResponseWriter) ResponseWriter { + return &stdResponseWriter{ResponseWriter: rw} +} diff --git a/response_writer_std_test.go b/response_writer_std_test.go new file mode 100644 index 0000000..0095be4 --- /dev/null +++ b/response_writer_std_test.go @@ -0,0 +1,22 @@ +package xun + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWriterStatus(t *testing.T) { + rw := NewResponseWriter(httptest.NewRecorder()) + + require.Equal(t, http.StatusOK, rw.StatusCode()) + + rw.WriteHeader(http.StatusNotFound) + require.Equal(t, http.StatusNotFound, rw.StatusCode()) + + rw.WriteHeader(http.StatusInternalServerError) + require.Equal(t, http.StatusNotFound, rw.StatusCode()) + +} diff --git a/viewer_file.go b/viewer_file.go index 2697cca..f5f3056 100644 --- a/viewer_file.go +++ b/viewer_file.go @@ -74,8 +74,6 @@ func (v *FileViewer) Render(w http.ResponseWriter, r *http.Request, data any) er if !v.isEmbed { return v.serveContent(w, r) } - - w.Header().Set("ETag", v.etag) if match := r.Header.Get("If-None-Match"); match != "" { for _, it := range strings.Split(match, ",") { if strings.TrimSpace(it) == v.etag { @@ -83,9 +81,9 @@ func (v *FileViewer) Render(w http.ResponseWriter, r *http.Request, data any) er return nil } } - } + w.Header().Set("ETag", v.etag) return v.serveContent(w, r) } From 7b7455da7f422f3edba2b544b31a7af3b87d419a Mon Sep 17 00:00:00 2001 From: Lz Date: Tue, 11 Feb 2025 20:01:54 +0800 Subject: [PATCH 2/8] feat(proxyproto): support serve http(s) server through proxy servers and load balancers (#40) --- CHANGELOG.md | 27 --- README.md | 106 
++++++++-- ext/proxyproto/conn.go | 83 ++++++++ ext/proxyproto/conn_test.go | 150 ++++++++++++++ ext/proxyproto/header.go | 306 +++++++++++++++++++++++++++++ ext/proxyproto/header_test.go | 359 ++++++++++++++++++++++++++++++++++ ext/proxyproto/helper.go | 15 ++ ext/proxyproto/helper_test.go | 15 ++ ext/proxyproto/listener.go | 25 +++ ext/proxyproto/serve.go | 38 ++++ ext/proxyproto/serve_test.go | 121 ++++++++++++ go.mod | 9 +- go.sum | 18 +- 13 files changed, 1220 insertions(+), 52 deletions(-) delete mode 100644 CHANGELOG.md create mode 100644 ext/proxyproto/conn.go create mode 100644 ext/proxyproto/conn_test.go create mode 100644 ext/proxyproto/header.go create mode 100644 ext/proxyproto/header_test.go create mode 100644 ext/proxyproto/helper.go create mode 100644 ext/proxyproto/helper_test.go create mode 100644 ext/proxyproto/listener.go create mode 100644 ext/proxyproto/serve.go create mode 100644 ext/proxyproto/serve_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index 388fd17..0000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,27 +0,0 @@ -# Changelog - -All notable changes to this project will be documented in this file. - -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), -and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
- -## [Unreleased] -- use `html/template` to parse template files (#7) - -## [1.0.3] - 2025-01-01 -### Changed -- renamed package name with `xun` (#4) -- moved `htmx` helper to `ext/htmx` (#4) - -### Fixed -- fixed syntax issue on `htmx.WriteHeader` - -### Added -- added logging `app.routes` in `app.Start` - -## [1.0.1] - 2024-12-30 -### Added -- added htmx helper -- support setup custom FuncMap on HtmlTemplate - -## [1.0.0] - 2024-12-25 diff --git a/README.md b/README.md index c78b98a..38bd40b 100644 --- a/README.md +++ b/README.md @@ -493,30 +493,110 @@ app := xun.New(WithCompressor(&GzipCompressor{}, &DeflateCompressor{})) Use `autotls.Configure` to set up servers for automatic obtaining and renewing of TLS certificates from Let's Encrypt. ```go +mux := http.NewServeMux() - mux := http.NewServeMux() +app := xun.New(xun.WithMux(mux)) + +//... - app := xun.New(xun.WithMux(mux)) +httpServer := &http.Server{ + Addr: ":http", + //... +} +httpsServer := &http.Server{ + Addr: ":https", //... +} + +autotls. + New(autotls.WithCache(autocert.DirCache("./certs")), + autotls.WithHosts("abc.com", "123.com")). + Configure(httpServer, httpsServer) + +go httpServer.ListenAndServe() +go httpsServer.ListenAndServeTLS("", "") +``` + +#### Cookie +Cookies are a way to store information at the client end. see [more examples](./ext/cookie/cookie_test.go) +> Write cookie with base64(URL Encoding) to client +```go +cookie.Set(ctx, http.Cookie{Name: "test", Value: "value"}) // Set-Cookie: test=dmFsdWU= +``` + +> Read and decoded cookie from client's request +```go +v, err := cookie.Get(ctx,"test") + +fmt.Println(v) // value +``` + +When signed, the cookies can't be forged, because their values are validated using HMAC. 
+```go +ts, err := cookie.SetSigned(ctx,http.Cookie{Name: "test", Value: "value"},[]byte("secret")) // ts is current timestamp + +v, ts, err := cookie.GetSigned(ctx, "test",[]byte("secret")) // v is value, ts is the timestamp that was signed +``` + +> Delete a cookie +```go +cookie.Delete(ctx, http.Cookie{Name: "test", Value: "dmFsdWU="}) // Set-Cookie: test=; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Max-Age=0 +``` + +#### HSTS +HTTP Strict Transport Security (HSTS) is a simple and widely supported standard to protect visitors by ensuring that their browsers always connect to a website over HTTPS. + + +> Redirect redirects plain HTTP requests to HTTPS. **DON'T use it if HTTPs is unsupported on your server.** +```go +app.Use(hsts.Redirect()) +``` + +> Write HSTS header if it is a HTTPs request. **It is only applied in HTTPs request.** +```go +app.Use(hsts.WriteHeader()) +``` + +#### Proxy Protocol +The PROXY protocol allows our application to receive client connection information that is passed through proxy servers and load balancers. Both PROXY protocol versions 1 and 2 are supported. + +[How to use the Proxy Protocol to preserve a client's ip address?](https://www.haproxy.com/blog/use-the-proxy-protocol-to-preserve-a-clients-ip-address) - httpServer := &http.Server{ - Addr: ":http", - //... +**Security Note: Do not enable the PROXY protocol on your servers unless they are located behind a proxy server or load balancer. If the PROXY protocol is enabled without such intermediaries, any client could potentially send fake IP addresses or other misleading information, posing a security risk.** + +> ListenAndServe + +```go + mux := http.NewServeMux() + + srv := &http.Server{ + Addr: ":80", + Handler: mux, } + app := xun.New(WithMux(mux)) + app.Start() + defer app.Close() + + // srv.ListenAndServe() + proxyproto.ListenAndServe(srv) +``` + +> ListenAndServeTLS + +```go httpsServer := &http.Server{ - Addr: ":https", - //... + Addr: ":443", + Handler: mux, } - autotls. 
- New(autotls.WithCache(autocert.DirCache("./certs")), - autotls.WithHosts("abc.com", "123.com")). - Configure(httpServer, httpsServer) + autotls.New(autotls.WithCache(autocert.DirCache("./certs")), + autotls.WithHosts("yaitoo.cn", "www.yaitoo.cn")). + Configure(srv, httpsServer) - go httpServer.ListenAndServe() - go httpsServer.ListenAndServeTLS("", "") + // httpsServer.ListenAndServeTLS( "", "") + proxyproto.ListenAndServeTLS(httpsServer, "", "") ``` ### Works with [tailwindcss](https://tailwindcss.com/docs/installation) diff --git a/ext/proxyproto/conn.go b/ext/proxyproto/conn.go new file mode 100644 index 0000000..dda346c --- /dev/null +++ b/ext/proxyproto/conn.go @@ -0,0 +1,83 @@ +package proxyproto + +import ( + "bufio" + "bytes" + "log" + "net" + "sync" +) + +var Logger = log.Default() + +type conn struct { + net.Conn + r *bufio.Reader + h *Header + + isLoaded bool + once sync.Once +} + +// NewConn wraps a net.Conn and returns a new proxyproto.Conn that reads the +// PROXY protocol header from the connection. If the connection is not a +// PROXY protocol connection, it returns the original connection. +func NewConn(nc net.Conn) net.Conn { + return &conn{Conn: nc, r: bufio.NewReader(nc)} +} + +// Read reads data from the connection. +// Read can be made to time out and return an error after a fixed +// time limit; see SetDeadline and SetReadDeadline. +func (c *conn) Read(b []byte) (n int, err error) { + if !c.isLoaded { + c.once.Do(c.tryUseProxy) + } + return c.r.Read(b) +} + +// LocalAddr returns the local network address, if known. +func (c *conn) LocalAddr() net.Addr { + if !c.isLoaded { + c.once.Do(c.tryUseProxy) + } + + if c.h != nil { + return c.h.LocalAddr + } + return c.Conn.LocalAddr() +} + +// RemoteAddr returns the remote network address, if known. 
+func (c *conn) RemoteAddr() net.Addr { + if !c.isLoaded { + c.once.Do(c.tryUseProxy) + } + if c.h != nil { + return c.h.RemoteAddr + } + return c.Conn.RemoteAddr() +} + +// tryUseProxy tries to read the PROXY protocol header from the connection. +// If the header is read successfully, it sets the Header field and marks the +// connection as proxied. If the header is invalid or not present, it does +// nothing. +func (c *conn) tryUseProxy() { + defer func() { + c.isLoaded = true + }() + // Read the first 13 bytes which should contain the identifier + buf, err := c.r.Peek(13) + if err != nil { + Logger.Println("proxyproto: peek", err) + return + } + + // v1 + if bytes.HasPrefix(buf[0:13], v1) { + c.h = readV1Header(c.r) + } else if bytes.HasPrefix(buf[0:13], v2) { + c.h = readV2Header(c.r) + } +} diff --git a/ext/proxyproto/conn_test.go b/ext/proxyproto/conn_test.go new file mode 100644 index 0000000..32e4b87 --- /dev/null +++ b/ext/proxyproto/conn_test.go @@ -0,0 +1,150 @@ +package proxyproto + +import ( + "io" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type mockConn struct { + bytes []byte + localAddr net.Addr + remoteAddr net.Addr +} + +func (m *mockConn) Read(p []byte) (n int, err error) { + + copy(p, m.bytes) + + if len(m.bytes) < len(p) { + return len(m.bytes), io.EOF + } + return len(p), nil +} +func (*mockConn) Write(_ []byte) (n int, err error) { panic("not implemented") } +func (*mockConn) Close() error { panic("not implemented") } +func (m *mockConn) LocalAddr() net.Addr { return m.localAddr } +func (m *mockConn) RemoteAddr() net.Addr { return m.remoteAddr } +func (*mockConn) SetDeadline(time.Time) error { panic("not implemented") } +func (*mockConn) SetReadDeadline(time.Time) error { panic("not implemented") } +func (*mockConn) SetWriteDeadline(time.Time) error { panic("not implemented") } + +func TestConn(t *testing.T) { + + tests := []struct { + name string + mc *mockConn + fireFunc func(c net.Conn) + remoteAddr 
net.Addr + localAddr net.Addr + err bool + }{ + { + name: "v1/read_first", + mc: &mockConn{ + bytes: []byte("PROXY TCP4 192.168.0.1 192.168.0.2 56324 443\r\n"), + localAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.2"), Port: 443}, + remoteAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 56324}, + }, + fireFunc: func(c net.Conn) { + c.Read(make([]byte, 1)) // nolint: errcheck + }, + remoteAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 56324}, + localAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.2"), Port: 443}, + }, + { + name: "v1/remote_first", + mc: &mockConn{ + bytes: []byte("PROXY TCP4 192.168.0.1 192.168.0.2 56324 443\r\n"), + localAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.2"), Port: 443}, + remoteAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 56324}, + }, + fireFunc: func(c net.Conn) { + c.RemoteAddr() + }, + remoteAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 56324}, + localAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.2"), Port: 443}, + }, + { + name: "v1/local_first", + mc: &mockConn{ + bytes: []byte("PROXY TCP4 192.168.0.1 192.168.0.2 56324 443\r\n"), + localAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.2"), Port: 443}, + remoteAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 56324}, + }, + fireFunc: func(c net.Conn) { + c.LocalAddr() + }, + remoteAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 56324}, + localAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.2"), Port: 443}, + }, + { + name: "v2/read_first", + mc: &mockConn{bytes: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x11, 0x00, 0x0C, + // IPV4 -------------| IPV4 ----------------| SRC PORT DEST PORT + 0x7F, 0x00, 0x00, 0x01, 0x7F, 0x00, 0x00, 0x02, 0xCA, 0x2B, 0x04, 0x01}, + localAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.2"), Port: 443}, + remoteAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 56324}, + }, + fireFunc: func(c net.Conn) { + c.Read(make([]byte, 1)) // nolint: errcheck + }, + 
remoteAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 51755}, + localAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.2"), Port: 1025}, + }, + { + name: "v2/remote_first", + mc: &mockConn{bytes: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x11, 0x00, 0x0C, + // IPV4 -------------| IPV4 ----------------| SRC PORT DEST PORT + 0x7F, 0x00, 0x00, 0x01, 0x7F, 0x00, 0x00, 0x02, 0xCA, 0x2B, 0x04, 0x01}, + localAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.2"), Port: 443}, + remoteAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 56324}, + }, + fireFunc: func(c net.Conn) { + c.RemoteAddr() + }, + remoteAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 51755}, + localAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.2"), Port: 1025}, + }, + { + name: "v2/local_first", + mc: &mockConn{bytes: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x11, 0x00, 0x0C, + // IPV4 -------------| IPV4 ----------------| SRC PORT DEST PORT + 0x7F, 0x00, 0x00, 0x01, 0x7F, 0x00, 0x00, 0x02, 0xCA, 0x2B, 0x04, 0x01}, + localAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.2"), Port: 443}, + remoteAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 56324}, + }, + fireFunc: func(c net.Conn) { + c.LocalAddr() + }, + remoteAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 51755}, + localAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.2"), Port: 1025}, + }, + { + name: "no_proxyproto", + mc: &mockConn{ + bytes: []byte("PROXY TCP4\r\n"), + localAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.2"), Port: 443}, + remoteAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 56324}, + }, + remoteAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 56324}, + localAddr: &net.TCPAddr{IP: net.ParseIP("192.168.0.2"), Port: 443}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mc := NewConn(tt.mc) + if tt.fireFunc != nil { + tt.fireFunc(mc) + } + + require.Equal(t, tt.remoteAddr.String(), 
mc.RemoteAddr().String()) + require.Equal(t, tt.localAddr.String(), mc.LocalAddr().String()) + + }) + } +} diff --git a/ext/proxyproto/header.go b/ext/proxyproto/header.go new file mode 100644 index 0000000..558a478 --- /dev/null +++ b/ext/proxyproto/header.go @@ -0,0 +1,306 @@ +package proxyproto + +import ( + "bufio" + "encoding/binary" + "io" + "net" + "strings" +) + +// Header represents the structure of the proxy protocol header. +// It holds information about the connection, such as local and remote addresses, +// the protocol version, and any additional TLVs (Type-Length-Value) in the v2 protocol. +type Header struct { + // LocalAddr is the ip address of the party that initiated the connection + LocalAddr net.Addr + // RemoteAddr is the ip address the remote party connected to; aka the address + // the proxy was listening for connections on. + RemoteAddr net.Addr + // The version of the proxy protocol parsed + Version int + + // V2 + Command Command + Protocol Protocol + // The unparsed TLVs (Type-Length-Value) that were appended to the end of + // the v2 proto proxy header. + RawTLVs []byte +} + +var ( + lengthUnspec = uint16(0) + lengthV4 = uint16(12) + lengthV6 = uint16(36) + lengthUnix = uint16(216) +) + +func (header *Header) validateLength(length uint16) bool { + if header.Protocol.IsIPv4() { + return length >= lengthV4 + } else if header.Protocol.IsIPv6() { + return length >= lengthV6 + } else if header.Protocol.IsUnix() { + return length >= lengthUnix + } else if header.Protocol.IsUnspec() { + return length >= lengthUnspec + } + return false +} + +var ( + v1 = []byte("PROXY ") + v2 = []byte("\r\n\r\n\x00\r\nQUIT\n") // 0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A +) + +const ( + v1_TCP6 = "TCP6" + v1_TCP4 = "TCP4" +) + +// readV1Header reads the v1 header. 
+// +// The v1 header is always 108 bytes long and contains the +// following information: +// +// - PROXY +// - Protocol (TCP4 or TCP6) +// - src_ip +// - dest_ip +// - src_port +// - dest_port +// +// The header is followed by a \r\n. +// +// Example: +// PROXY TCP4 192.168.0.1 192.168.0.10 12345 80\r\n +// PROXY TCP6 2001:db8::1 2001:db8::100 12345 80\r\n +// PROXY UNKNOWN\r\n +func readV1Header(r *bufio.Reader) *Header { + proxyLine, err := r.ReadString('\n') + if err != nil { + Logger.Println("proxyproto: can't read v1 header", err) + return nil + } + + // PROXY + Protocol + src_ip + dest_ip + src_port + dest_port + // PROXY TCP4 192.168.0.1 192.168.0.10 12345 80\r\n + // PROXY TCP6 2001:db8::1 2001:db8::100 12345 80\r\n + // PROXY UNKNOWN\r\n + fields := strings.Fields(proxyLine) + + if len(fields) < 6 { + Logger.Println("proxyproto: insufficient v1 header fields, found", len(fields)) + return nil + } + + if fields[1] == v1_TCP4 { + h := &Header{} + h.Version = 1 + + var err error + h.RemoteAddr, err = net.ResolveTCPAddr("tcp4", fields[2]+":"+fields[4]) + if err != nil { + Logger.Println("proxyproto: invalid remote ipv4", err) + return nil + } + h.LocalAddr, err = net.ResolveTCPAddr("tcp4", fields[3]+":"+fields[5]) + if err != nil { + Logger.Println("proxyproto: invalid local ipv4", err) + return nil + } + return h + } else if fields[1] == v1_TCP6 { + h := &Header{} + h.Version = 1 + + var err error + h.RemoteAddr, err = net.ResolveTCPAddr("tcp6", "["+fields[2]+"]:"+fields[4]) // [::1]:80 + if err != nil { + Logger.Println("proxyproto: invalid remote ipv6", err) + return nil + } + h.LocalAddr, err = net.ResolveTCPAddr("tcp6", "["+fields[3]+"]:"+fields[5]) + if err != nil { + Logger.Println("proxyproto: invalid local ipv6", err) + return nil + } + return h + } + + Logger.Println("proxyproto: unknown protocol", fields[1]) + return nil +} + +type tcp4Addr struct { + Remote [4]byte + Local [4]byte + RemotePort uint16 + LocalPort uint16 +} + +type tcp6Addr struct 
{ + Remote [16]byte + Local [16]byte + RemotePort uint16 + LocalPort uint16 +} + +// readV2Header reads the v2 header. +// +// The v2 header is always 16 bytes long and contains the +// following information: +// +// - 12 bytes signature ("PROXY ") +// - 2 bytes version (always 2) +// - 2 bytes command (LOCAL or PROXY) +// For v2 the header length is at most 52 bytes plus the length of the TLVs. +func readV2Header(reader *bufio.Reader) *Header { // skipcq: GO-R1005 + var err error + // Skip first 12 bytes (signature) + for i := 0; i < 12; i++ { + if _, err = reader.ReadByte(); err != nil { + Logger.Println("proxyproto: can't read v2 signature", err) + return nil + } + } + + header := new(Header) + header.Version = 2 + + // Read the 13th byte, protocol version and command + b13, err := reader.ReadByte() + if err != nil { + Logger.Println("proxyproto: can't read v2 command", err) + return nil + } + header.Command = Command(b13) + if _, ok := supportedCommand[header.Command]; !ok { + Logger.Println("proxyproto: invalid v2 command", header.Command) + return nil + } + + // Read the 14th byte, address family and protocol + b14, err := reader.ReadByte() + if err != nil { + Logger.Println("proxyproto: can't read v2 protocol", err) + return nil + } + header.Protocol = Protocol(b14) + + // Make sure there are bytes available as specified in length + var length uint16 + if err := binary.Read(io.LimitReader(reader, 2), binary.BigEndian, &length); err != nil { + Logger.Println("proxyproto: can't read v2 length", err) + return nil + } + if !header.validateLength(length) { + Logger.Println("proxyproto: invalid v2 length", length) + return nil + } + + // Return early if the length is zero, which means that + // there's no address information and TLVs present for UNSPEC. 
+ if length == 0 { + return nil + } + + if _, err := reader.Peek(int(length)); err != nil { + Logger.Println("proxyproto: can't peek v2 TLVs", err) + return nil + } + + // Length-limited reader for payload section + payloadReader := io.LimitReader(reader, int64(length)).(*io.LimitedReader) + + if header.Protocol.IsIPv4() { + var addr tcp4Addr + if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { + Logger.Println("proxyproto: can't read v2 tcp4 addresses", err) + return nil + } + header.RemoteAddr = toAddr(header.Protocol, addr.Remote[:], addr.RemotePort) + header.LocalAddr = toAddr(header.Protocol, addr.Local[:], addr.LocalPort) + } else if header.Protocol.IsIPv6() { + var addr tcp6Addr + if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { + Logger.Println("proxyproto: can't read v2 tcp6 addresses", err) + return nil + } + header.RemoteAddr = toAddr(header.Protocol, addr.Remote[:], addr.RemotePort) + header.LocalAddr = toAddr(header.Protocol, addr.Local[:], addr.LocalPort) + } else { + Logger.Println("proxyproto: unsupported v2 protocol", header.Protocol) + return nil + } + + if payloadReader.N > 0 { + // Copy bytes for optional Type-Length-Value vector + header.RawTLVs = make([]byte, payloadReader.N) // Allocate minimum size slice + if _, err = io.ReadFull(payloadReader, header.RawTLVs); err != nil && err != io.EOF { + Logger.Println("proxyproto: read v2 TLVs", err) + return nil + } + } + + return header +} + +// Command represents the command in proxy protocol v2. +// Command doesn't exist in v1 but it should be set since other parts of +// this library may rely on it for determining connection details. +type Command byte + +const ( + // LOCAL represents the LOCAL command in v2, + // in which case no address information is expected. + LOCAL Command = '\x20' + // PROXY represents the PROXY command in v2, + // in which case valid local/remote address and port information is expected. 
+ PROXY Command = '\x21' +) + +var supportedCommand = map[Command]bool{ + LOCAL: true, + PROXY: true, +} + +// Protocol represents address family and transport protocol. +type Protocol byte + +const ( + TCPv4 Protocol = '\x11' + UDPv4 Protocol = '\x12' + TCPv6 Protocol = '\x21' + UDPv6 Protocol = '\x22' +) + +// IsIPv4 returns true if the address family is IPv4 (AF_INET4), false otherwise. +func (ap Protocol) IsIPv4() bool { + return ap&0xF0 == 0x10 +} + +// IsIPv6 returns true if the address family is IPv6 (AF_INET6), false otherwise. +func (ap Protocol) IsIPv6() bool { + return ap&0xF0 == 0x20 +} + +// IsUnix returns true if the address family is UNIX (AF_UNIX), false otherwise. +func (ap Protocol) IsUnix() bool { + return ap&0xF0 == 0x30 +} + +// IsStream returns true if the transport protocol is TCP or STREAM (SOCK_STREAM), false otherwise. +func (ap Protocol) IsStream() bool { + return ap&0x0F == 0x01 +} + +// IsDatagram returns true if the transport protocol is UDP or DGRAM (SOCK_DGRAM), false otherwise. +func (ap Protocol) IsDatagram() bool { + return ap&0x0F == 0x02 +} + +// IsUnspec returns true if the transport protocol or address family is unspecified, false otherwise. 
+func (ap Protocol) IsUnspec() bool { + return (ap&0xF0 == 0x00) || (ap&0x0F == 0x00) +} diff --git a/ext/proxyproto/header_test.go b/ext/proxyproto/header_test.go new file mode 100644 index 0000000..5a4657c --- /dev/null +++ b/ext/proxyproto/header_test.go @@ -0,0 +1,359 @@ +package proxyproto + +import ( + "bufio" + "bytes" + "net" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestV1(t *testing.T) { + tests := []struct { + name string + header string + remote net.Addr + local net.Addr + err bool + }{ + { + name: "TCP4/minimal", + header: "PROXY TCP4 1.1.1.5 1.1.1.6 2 3\r\n", + remote: &net.TCPAddr{IP: net.ParseIP("1.1.1.5"), Port: 2}, + local: &net.TCPAddr{IP: net.ParseIP("1.1.1.6"), Port: 3}, + }, + { + name: "TCP4/maximal", + header: "PROXY TCP4 255.255.255.255 255.255.255.254 65535 65535\r\n", + remote: &net.TCPAddr{IP: net.ParseIP("255.255.255.255"), Port: 65535}, + local: &net.TCPAddr{IP: net.ParseIP("255.255.255.254"), Port: 65535}, + }, + { + name: "TCP6/minimal", + header: "PROXY TCP6 ::1 ::2 3 4\r\n", + remote: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 3}, + local: &net.TCPAddr{IP: net.ParseIP("::2"), Port: 4}, + }, + { + name: "TCP6/maximal", + header: "PROXY TCP6 0000:0000:0000:0000:0000:0000:0000:0002 0000:0000:0000:0000:0000:0000:0000:0001 65535 65535\r\n", + remote: &net.TCPAddr{IP: net.ParseIP("0000:0000:0000:0000:0000:0000:0000:0002"), Port: 65535}, + local: &net.TCPAddr{IP: net.ParseIP("0000:0000:0000:0000:0000:0000:0000:0001"), Port: 65535}, + }, + { + name: "UNKNOWN/minimal", + header: "PROXY UNKNOWN\r\n", + err: true, + }, + { + name: "UNKNOWN/maximal", + header: "PROXY UNKNOWN 0000:0000:0000:0000:0000:0000:0000:0002 0000:0000:0000:0000:0000:0000:0000:0001 65535 65535\r\n", + err: true, + }, + { + name: "TCP6/empty", + header: "PROXY TCP6\r\n", + remote: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 3}, + local: &net.TCPAddr{IP: net.ParseIP("::2"), Port: 4}, + err: true, + }, + { + name: "TCP6/cRLF_not_found", + 
header: "PROXY TCP6 0000:0000:0000:0000:0000:0000:0000:0001 0000:0000:0000:0000:0000:0000:0000:0001 65535 65535XXXX\r\n", + err: true, + }, + { + name: "UNKNOWN/cRLF_not_found", + header: "PROXY UNKNOWN 0000:0000:0000:0000:0000:0000:0000:0001 0000:0000:0000:0000:0000:0000:0000:0001 65535 65535X\r\n", + err: true, + }, + { + name: "UNKNOWN/no_cRLF", + header: "PROXY UNKNOWN", + err: true, + }, + { + name: "Header/only_cRLF", + header: "\r\n", + err: true, + }, + { + name: "Header/empty", + header: "", + err: true, + }, + { + name: "Header/garbage", + header: "ASDFASDGSAG@#!@#$!WDFGASDGASDFG#TAGASDFASDG@", + err: true, + }, + { + name: "TCP4/incomplete", + header: "PROXY TCP4 garbage\r\n", + err: true, + }, + { + name: "TCP6/incomplete", + header: "PROXY TCP6 garbage\r\n", + err: true, + }, + { + name: "PROTO/unrecognized", + header: "PROXY UNIX :1 :1 234 234\r\n", + err: true, + }, + { + name: "TCP4/invalid_src", + header: "PROXY TCP4 NOT-AN-IP 192.168.1.1 22 2345\r\n", + err: true, + }, + { + name: "TCP4/invalid_dest", + header: "PROXY TCP4 192.168.1.1 NOT-AN-IP 22 2345\r\n", + err: true, + }, + { + name: "TCP4/invalid_src_port", + header: "PROXY TCP4 192.168.1.1 192.168.1.1 NOT-A-PORT 2345\r\n", + err: true, + }, + { + name: "TCP4/invalid_dest_port", + header: "PROXY TCP4 192.168.1.1 192.168.1.1 22 NOT-A-PORT\r\n", + err: true, + }, + { + name: "TCP4/corrupted_address_line", + header: "PROXY TCP4 192.168.1.1 192.168.1.1 2345\r\n", + err: true, + }, + + { + name: "TCP6/invalid_src", + header: "PROXY TCP6 NOT-AN-IP ::1 22 2345\r\n", + err: true, + }, + { + name: "TCP6/invalid_dest", + header: "PROXY TCP6 NOT-AN-IP ::1 22 2345\r\n", + err: true, + }, + { + name: "TCP6/invalid_src_port", + header: "PROXY TCP6 ::1 ::2 NOT-A-PORT 2345\r\n", + err: true, + }, + { + name: "TCP6/invalid_dest_port", + header: "PROXY TCP6 ::1 ::2 22 NOT-A-PORT\r\n", + err: true, + }, + { + name: "TCP6/corrupted_address_line", + header: "PROXY TCP6 ::1 ::2 2345\r\n", + err: true, + }, + } + + 
for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := bufio.NewReader(strings.NewReader(tt.header)) + h := readV1Header(r) + + if tt.err { + require.Nil(t, h) + } else { + require.Equal(t, tt.local, h.LocalAddr) + require.Equal(t, tt.remote, h.RemoteAddr) + require.Equal(t, 1, h.Version) + } + }) + } +} + +func TestV2(t *testing.T) { + tests := []struct { + name string + header []byte + remote net.Addr + local net.Addr + rawTLVs []byte + err bool + }{ + { + name: "TCP4/127.0.0.1", + // VER IP/TCP LENGTH + header: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x11, 0x00, 0x0C, + // IPV4 -------------| IPV4 ----------------| SRC PORT DEST PORT + 0x7F, 0x00, 0x00, 0x01, 0x7F, 0x00, 0x00, 0x02, 0xCA, 0x2B, 0x04, 0x01}, + remote: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 51755}, + local: &net.TCPAddr{IP: net.ParseIP("127.0.0.2"), Port: 1025}, + }, + { + name: "UDP4/127.0.0.1", + // IP/UDP + header: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x12, 0x00, 0x0C, + 0x7F, 0x00, 0x00, 0x01, 0x7F, 0x00, 0x00, 0x01, 0xCA, 0x2B, 0x04, 0x01}, + remote: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 51755}, + local: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1025}, + }, + { + name: "TCP6/127.0.0.1", + // VER IP/TCP LENGTH + header: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x21, 0x00, 0x24, + // IPV6 -------------------------------------------------------------------------------------| + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0x7F, 0x00, 0x00, 0x01, + // IPV6 -------------------------------------------------------------------------------------| SRC PORT DEST PORT + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0x7F, 0x00, 0x00, 0x01, 0xCC, 0x4C, 0x04, 0x01}, + remote: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 52300}, + local: &net.TCPAddr{IP: 
net.ParseIP("127.0.0.1"), Port: 1025}, + }, + { + name: "TCP6/maximal", + // VER IP/TCP LENGTH + header: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x21, 0x00, 0x24, + // IPV6 -------------------------------------------------------------------------------------| + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + // IPV6 -------------------------------------------------------------------------------------| SRC PORT DEST PORT + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + remote: &net.TCPAddr{IP: net.ParseIP("FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF"), Port: 65535}, + local: &net.TCPAddr{IP: net.ParseIP("FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF"), Port: 65535}, + }, + { + name: "TCP6/::1", + // VER IP/TCP LENGTH + header: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x21, 0x00, 0x2B, + // IPV6 -------------------------------------------------------------------------------------| + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // IPV6 -------------------------------------------------------------------------------------| SRC PORT DEST PORT + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xCF, 0x8F, 0x04, 0x01, + // TLVs + 0x03, 0x00, 0x04, 0xFD, 0x16, 0xEE, 0x60}, + remote: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 53135}, + local: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 1025}, + rawTLVs: []byte{0x03, 0x00, 0x04, 0xFD, 0x16, 0xEE, 0x60}, + }, + { + name: "UDP6/::1", + // VER IP/TCP LENGTH + header: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x22, 0x00, 0x2B, + // IPV6 -------------------------------------------------------------------------------------| + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x01, + // IPV6 -------------------------------------------------------------------------------------| SRC PORT DEST PORT + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xCF, 0x8F, 0x04, 0x01, + // TLVs + 0x03, 0x00, 0x04, 0xFD, 0x16, 0xEE, 0x60}, + remote: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 53135}, + local: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 1025}, + rawTLVs: []byte{0x03, 0x00, 0x04, 0xFD, 0x16, 0xEE, 0x60}, + }, + { + name: "invalid/missing_proto_family_length", + header: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21}, + err: true, + }, + { + name: "invalid/version", + header: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x01, 0x21, 0x00, 0x2B}, + err: true, + }, + { + name: "invalid/length_too_long", + header: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x21, 0x08, 0x01}, + err: true, + }, + { + name: "invalid/too_less_bytes", + header: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x21, 0x08, 0x00}, + err: true, + }, + { + name: "invalid/local_with_no_trailing bytes", + header: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x20, 0x00, 0x00, 0x00}, + err: true, + }, + { + name: "invalid/local_with_trailing_bytes_TLVs", + header: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x20, 0xFF, 0x00, 0x07, + 0x03, 0x00, 0x04, 0xFD, 0x16, 0xEE, 0x60}, + err: true, + }, + { + name: "invalid/proxy_with_zero_byte_header", + header: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x00, 0x00, 0x00}, + err: true, + }, + { + name: "invalid/invalid-length_for_IPv4", + header: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x11, 0x00, 0x01, 0xFF}, + err: true, + }, + { + name: 
"invalid/invalid_length_for_IPv6", + header: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x21, 0x00, 0x01, 0xFF}, + err: true, + }, + { + name: "invalid/unix_socket_not_implemented", + header: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x31, 0x00, 0x01, 0xFF}, + err: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := bufio.NewReader(bytes.NewReader(tt.header)) + h := readV2Header(r) + + if tt.err { + require.Nil(t, h) + } else { + require.Equal(t, tt.local.String(), h.LocalAddr.String()) + require.Equal(t, tt.remote.String(), h.RemoteAddr.String()) + require.Equal(t, tt.rawTLVs, h.RawTLVs) + require.Equal(t, 2, h.Version) + } + + }) + } +} + +func TestBrokenReader(t *testing.T) { + + tests := []struct { + name string + bytes []byte + read func(i int, p []byte) (n int, err error) + }{ + { + name: "break_on_first_12_bytes", + }, + { + name: "break_on_13_byte", + + bytes: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}, + }, + { + name: "break_on_14_byte", + bytes: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21}, + }, + { + name: "break_on_16_byte", + bytes: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x11}, + }, + { + name: "invalid_on_14_byte", + bytes: []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, 0x21, 0x11, 0x00, 0x01}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + h := readV2Header(bufio.NewReader(bytes.NewReader(test.bytes))) + + require.Nil(t, h) + }) + } +} diff --git a/ext/proxyproto/helper.go b/ext/proxyproto/helper.go new file mode 100644 index 0000000..daf7e40 --- /dev/null +++ b/ext/proxyproto/helper.go @@ -0,0 +1,15 @@ +package proxyproto + +import ( + "net" +) + +func toAddr(transport Protocol, ip net.IP, port uint16) net.Addr { + if transport.IsStream() 
{ + return &net.TCPAddr{IP: ip, Port: int(port)} + } + if transport.IsDatagram() { + return &net.UDPAddr{IP: ip, Port: int(port)} + } + return nil +} diff --git a/ext/proxyproto/helper_test.go b/ext/proxyproto/helper_test.go new file mode 100644 index 0000000..56677a8 --- /dev/null +++ b/ext/proxyproto/helper_test.go @@ -0,0 +1,15 @@ +package proxyproto + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestToAddr(t *testing.T) { + t.Run("only_tcp_udp_work", func(t *testing.T) { + addr := toAddr(0, net.IPv4zero, 0) + require.Nil(t, addr) + }) +} diff --git a/ext/proxyproto/listener.go b/ext/proxyproto/listener.go new file mode 100644 index 0000000..a000c55 --- /dev/null +++ b/ext/proxyproto/listener.go @@ -0,0 +1,25 @@ +package proxyproto + +import ( + "net" +) + +type listener struct { + net.Listener +} + +func (l *listener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + return NewConn(c), nil +} + +// NewListener wraps a net.Listener and returns a new net.Listener that returns +// a proxyproto.Conn when Accept is called. +// +// It is used to handle PROXY protocol connections. +func NewListener(l net.Listener) net.Listener { + return &listener{Listener: l} +} diff --git a/ext/proxyproto/serve.go b/ext/proxyproto/serve.go new file mode 100644 index 0000000..4ed895b --- /dev/null +++ b/ext/proxyproto/serve.go @@ -0,0 +1,38 @@ +package proxyproto + +import ( + "net" + "net/http" +) + +// ListenAndServe listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming connections. +func ListenAndServe(srv *http.Server) error { + addr := srv.Addr + if addr == "" { + addr = ":http" + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + return srv.Serve(NewListener(ln)) +} + +// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls +// ServeTLS to handle requests on incoming TLS connections. 
+func ListenAndServeTLS(srv *http.Server, certFile, keyFile string) error { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + defer ln.Close() // skipcq: GO-S2307 + + return srv.ServeTLS(NewListener(ln), certFile, keyFile) +} diff --git a/ext/proxyproto/serve_test.go b/ext/proxyproto/serve_test.go new file mode 100644 index 0000000..24fdb86 --- /dev/null +++ b/ext/proxyproto/serve_test.go @@ -0,0 +1,121 @@ +package proxyproto + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestListenAndServe(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) // nolint: errcheck + })) + + srv := &http.Server{ // skipcq: GO-S2112 + Addr: strings.TrimPrefix(s.URL, "http://"), + Handler: s.Config.Handler, + } + + defer srv.Close() + + ln, err := net.Listen("tcp", ":http") // nolint:errcheck + if err == nil { + defer ln.Close() + } + + t.Run("http_80", func(t *testing.T) { + srv := &http.Server{ // skipcq: GO-S2112 + } + err := ListenAndServe(srv) + require.NotNil(t, err) + }) + + t.Run("fail_to_listen", func(t *testing.T) { + err := ListenAndServe(srv) + require.NotNil(t, err) + }) + + s.Close() + + t.Run("listen", func(t *testing.T) { + go ListenAndServe(srv) // nolint: errcheck + + time.Sleep(100 * time.Millisecond) + + resp, err := http.Get(s.URL) + require.NoError(t, err) + defer resp.Body.Close() // skipcq: GO-S2307 + require.Equal(t, http.StatusOK, resp.StatusCode) + + }) + + srv.Shutdown(context.TODO()) // nolint: errcheck + +} + +func TestListenAndServeTLS(t *testing.T) { + + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // skipcq: GSC-G402, GO-S1020 + client := http.Client{ + Transport: tr, + } + + 
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) // nolint: errcheck + })) + + time.Sleep(100 * time.Millisecond) + + srv := &http.Server{ // skipcq: GO-S2112 + Addr: strings.TrimPrefix(s.URL, "https://"), + Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) // nolint: errcheck + }), + TLSConfig: &tls.Config{Certificates: s.TLS.Certificates}, // skipcq: GSC-G402, GO-S1020 + } + + defer srv.Close() + + ln, err := net.Listen("tcp", ":https") // nolint:errcheck + if err == nil { + defer ln.Close() + } + + t.Run("https_443", func(t *testing.T) { + srv := &http.Server{ // skipcq: GO-S2112 + } + err := ListenAndServeTLS(srv, "", "") + require.NotNil(t, err) + }) + + t.Run("fail_to_listen", func(t *testing.T) { + err := ListenAndServeTLS(srv, "", "") + require.NotNil(t, err) + }) + + s.Close() + + t.Run("listen", func(t *testing.T) { + go ListenAndServeTLS(srv, "", "") // nolint: errcheck + + time.Sleep(100 * time.Millisecond) + + resp, err := client.Get(s.URL) + require.NoError(t, err) + defer resp.Body.Close() // skipcq: GO-S2307 + require.Equal(t, http.StatusOK, resp.StatusCode) + }) + + srv.Shutdown(context.TODO()) // nolint: errcheck +} diff --git a/go.mod b/go.mod index efa7620..e9a1efb 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/go-playground/validator/v10 v10.24.0 github.com/json-iterator/go v1.1.12 github.com/stretchr/testify v1.10.0 - golang.org/x/crypto v0.32.0 + golang.org/x/crypto v0.33.0 ) require ( @@ -19,8 +19,9 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - golang.org/x/net v0.34.0 // indirect - golang.org/x/sys v0.29.0 // indirect - golang.org/x/text v0.21.0 // indirect + 
github.com/stretchr/objx v0.5.2 // indirect + golang.org/x/net v0.35.0 // indirect + golang.org/x/sys v0.30.0 // indirect + golang.org/x/text v0.22.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index bf19d15..a9fc18c 100644 --- a/go.sum +++ b/go.sum @@ -29,17 +29,19 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= -golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= -golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= -golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= -golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= -golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= -golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= +golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/net v0.35.0 
h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= +golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= From a4888e75f0b91e8e9dc529efe1e10c494f9d1ba0 Mon Sep 17 00:00:00 2001 From: Lz Date: Tue, 11 Feb 2025 20:44:02 +0800 Subject: [PATCH 3/8] !feat(form): moved out Binder/Validator to reduce 3rd dependencies in the core class library (#41) --- README.md | 8 ++--- context.go | 6 ++++ binder.go => ext/form/binder.go | 15 +++------ binder_test.go => ext/form/binder_test.go | 38 ++++++++++++++++++----- validate.go => ext/form/validate.go | 2 +- 5 files changed, 46 insertions(+), 23 deletions(-) rename binder.go => ext/form/binder.go (90%) rename binder_test.go => ext/form/binder_test.go (85%) rename validate.go => ext/form/validate.go (99%) diff --git a/README.md b/README.md index 38bd40b..2b42dee 100644 --- a/README.md +++ b/README.md @@ -415,7 +415,7 @@ type Login struct { #### BindQuery ```go app.Get("/login", func(c *Context) error { - it, err := xun.BindQuery[Login](c.Request) + it, err := form.BindQuery[Login](c.Request) if err != nil { c.WriteStatus(http.StatusBadRequest) return ErrCancelled @@ -432,7 +432,7 @@ type Login struct { #### BindForm ```go app.Post("/login", func(c *Context) error { - it, err := xun.BindForm[Login](c.Request) + it, err := form.BindForm[Login](c.Request) if err != nil { c.WriteStatus(http.StatusBadRequest) return ErrCancelled @@ 
-449,7 +449,7 @@ app.Post("/login", func(c *Context) error { #### BindJson ```go app.Post("/login", func(c *Context) error { - it, err := xun.BindJson[Login](c.Request) + it, err := form.BindJson[Login](c.Request) if err != nil { c.WriteStatus(http.StatusBadRequest) return ErrCancelled @@ -790,7 +790,7 @@ admin := app.Group("/admin") app.Post("/login", func(c *xun.Context) error { - it, err := xun.BindForm[Login](c.Request) + it, err := form.BindForm[Login](c.Request) if err != nil { c.WriteStatus(http.StatusBadRequest) diff --git a/context.go b/context.go index 8a2b556..35e4a67 100644 --- a/context.go +++ b/context.go @@ -3,6 +3,12 @@ package xun import ( "net/http" "strings" + + jsoniter "github.com/json-iterator/go" +) + +var ( + json = jsoniter.Config{UseNumber: false}.Froze() ) // Context is the primary structure for handling HTTP requests. diff --git a/binder.go b/ext/form/binder.go similarity index 90% rename from binder.go rename to ext/form/binder.go index aadc7e2..1b21265 100644 --- a/binder.go +++ b/ext/form/binder.go @@ -1,4 +1,4 @@ -package xun +package form import ( "net/http" @@ -20,10 +20,8 @@ func BindQuery[T any](req *http.Request) (*TEntity[T], error) { data := new(T) - err := formDecoder.Decode(data, req.URL.Query()) - if err != nil { - return nil, err - } + // new(T) always is a pointer + formDecoder.Decode(data, req.URL.Query()) // nolint: errcheck return &TEntity[T]{ Data: *data, @@ -45,11 +43,8 @@ func BindForm[T any](req *http.Request) (*TEntity[T], error) { return nil, err } - // r.PostForm is a map of our POST form values - err = formDecoder.Decode(data, req.PostForm) - if err != nil { - return nil, err - } + // new(T) always is a pointer + formDecoder.Decode(data, req.PostForm) // nolint: errcheck return &TEntity[T]{ Data: *data, diff --git a/binder_test.go b/ext/form/binder_test.go similarity index 85% rename from binder_test.go rename to ext/form/binder_test.go index 18a205a..a53aad6 100644 --- a/binder_test.go +++ 
b/ext/form/binder_test.go @@ -1,7 +1,8 @@ -package xun +package form import ( "bytes" + "crypto/tls" "net/http" "net/http/httptest" @@ -13,14 +14,21 @@ import ( ut "github.com/go-playground/universal-translator" trans "github.com/go-playground/validator/v10/translations/zh" "github.com/stretchr/testify/require" + "github.com/yaitoo/xun" ) func TestBinder(t *testing.T) { + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // skipcq: GSC-G402, GO-S1020 + client := http.Client{ + Transport: tr, + } + mux := http.NewServeMux() srv := httptest.NewServer(mux) defer srv.Close() - app := New(WithMux(mux)) + app := xun.New(xun.WithMux(mux)) type Login struct { Email string `form:"email" json:"email" validate:"required,email"` @@ -29,11 +37,11 @@ func TestBinder(t *testing.T) { AddValidator(ut.New(zh.New()).GetFallback(), trans.RegisterDefaultTranslations) - app.Get("/login", func(c *Context) error { + app.Get("/login", func(c *xun.Context) error { it, err := BindQuery[Login](c.Request) if err != nil { c.WriteStatus(http.StatusBadRequest) - return ErrCancelled + return xun.ErrCancelled } if it.Validate(c.AcceptLanguage()...) && it.Data.Email == "xun@yaitoo.cn" && it.Data.Passwd == "123" { @@ -45,11 +53,11 @@ func TestBinder(t *testing.T) { return c.View(it) }) - app.Post("/login", func(c *Context) error { + app.Post("/login", func(c *xun.Context) error { it, err := BindForm[Login](c.Request) if err != nil { c.WriteStatus(http.StatusBadRequest) - return ErrCancelled + return xun.ErrCancelled } if it.Validate(c.AcceptLanguage()...) 
&& it.Data.Email == "xun@yaitoo.cn" && it.Data.Passwd == "123" { @@ -61,11 +69,11 @@ func TestBinder(t *testing.T) { return c.View(it) }) - app.Put("/login", func(c *Context) error { + app.Put("/login", func(c *xun.Context) error { it, err := BindJson[Login](c.Request) if err != nil { c.WriteStatus(http.StatusBadRequest) - return ErrCancelled + return xun.ErrCancelled } if it.Validate(c.AcceptLanguage()...) && it.Data.Email == "xun@yaitoo.cn" && it.Data.Passwd == "123" { @@ -180,3 +188,17 @@ func TestBinder(t *testing.T) { } } + +func TestInvalid(t *testing.T) { + + t.Run("invalid_form", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Body = nil + _, err := BindForm[int](req) + require.NotNil(t, err) + }) + t.Run("invalid_json", func(t *testing.T) { + _, err := BindJson[int](httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`"`))) + require.NotNil(t, err) + }) +} diff --git a/validate.go b/ext/form/validate.go similarity index 99% rename from validate.go rename to ext/form/validate.go index a3193a7..850f81b 100644 --- a/validate.go +++ b/ext/form/validate.go @@ -1,4 +1,4 @@ -package xun +package form import ( "github.com/go-playground/locales/en" From 2b113183e280202559fabbb115f5867efd090dce Mon Sep 17 00:00:00 2001 From: Lz Date: Tue, 11 Feb 2025 21:24:31 +0800 Subject: [PATCH 4/8] fix(viewer): doesn't write body in HEAD request (#42) --- viewer_html.go | 21 +++++++++++---------- viewer_html_test.go | 29 +++++++++++++++++++++++++++++ viewer_json.go | 17 ++++++++++------- viewer_string.go | 8 ++++++-- viewer_string_test.go | 2 +- viewer_text.go | 18 +++++++++++------- viewer_text_test.go | 29 +++++++++++++++++++++++++++++ viewer_xml.go | 18 +++++++++++------- 8 files changed, 108 insertions(+), 34 deletions(-) create mode 100644 viewer_html_test.go create mode 100644 viewer_text_test.go diff --git a/viewer_html.go b/viewer_html.go index 56354a6..7965068 100644 --- a/viewer_html.go +++ b/viewer_html.go @@ -28,16 +28,17 @@ 
func (*HtmlViewer) MimeType() *MimeType { // This implementation uses the `HtmlTemplate.Execute` method to render the template. // The rendered result is written to the http.ResponseWriter. func (v *HtmlViewer) Render(w http.ResponseWriter, r *http.Request, data any) error { // skipcq: RVV-B0012 - - buf := BufPool.Get() - defer BufPool.Put(buf) - - err := v.template.Execute(buf, data) - if err != nil { - return err - } - + var err error w.Header().Set("Content-Type", "text/html; charset=utf-8") - _, err = buf.WriteTo(w) + if r.Method != http.MethodHead { + buf := BufPool.Get() + defer BufPool.Put(buf) + + err = v.template.Execute(buf, data) + if err != nil { + return err + } + _, err = buf.WriteTo(w) + } return err } diff --git a/viewer_html_test.go b/viewer_html_test.go new file mode 100644 index 0000000..0dfca2e --- /dev/null +++ b/viewer_html_test.go @@ -0,0 +1,29 @@ +package xun + +import ( + "html/template" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestInvalidHtmlTemplate(t *testing.T) { + + type Data struct { + Name string + } + + l, err := template.New("invalid").Parse(`

Hello, {{.Name}}!

Age: {{.Age}}

`) + require.NoError(t, err) + + v := &HtmlViewer{ + template: &HtmlTemplate{ + template: l, + }, + } + + err = v.Render(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil), Data{}) + require.NotNil(t, err) +} diff --git a/viewer_json.go b/viewer_json.go index 5307f58..be39c02 100644 --- a/viewer_json.go +++ b/viewer_json.go @@ -23,15 +23,18 @@ func (*JsonViewer) MimeType() *MimeType { // // It sets the Content-Type header to "application/json". func (*JsonViewer) Render(w http.ResponseWriter, r *http.Request, data any) error { // skipcq: RVV-B0012 - buf := BufPool.Get() - defer BufPool.Put(buf) + var err error + w.Header().Set("Content-Type", "application/json") + if r.Method != http.MethodHead { + buf := BufPool.Get() + defer BufPool.Put(buf) - err := json.NewEncoder(buf).Encode(data) - if err != nil { - return err + err = json.NewEncoder(buf).Encode(data) + if err != nil { + return err + } + _, err = buf.WriteTo(w) } - w.Header().Set("Content-Type", "application/json") - _, err = buf.WriteTo(w) return err } diff --git a/viewer_string.go b/viewer_string.go index 2e2853c..7207bcc 100644 --- a/viewer_string.go +++ b/viewer_string.go @@ -24,11 +24,15 @@ func (*StringViewer) MimeType() *MimeType { // // It sets the Content-Type header to "text/plain; charset=utf-8". 
func (*StringViewer) Render(w http.ResponseWriter, r *http.Request, data any) error { // skipcq: RVV-B0012 + var err error + w.Header().Set("Content-Type", "text/plain; charset=utf-8") if data == nil { return nil } - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - _, err := fmt.Fprint(w, data) + if r.Method != http.MethodHead { + _, err = fmt.Fprint(w, data) + } + return err } diff --git a/viewer_string_test.go b/viewer_string_test.go index 8691e03..7ba39df 100644 --- a/viewer_string_test.go +++ b/viewer_string_test.go @@ -22,7 +22,7 @@ func TestStringViewer(t *testing.T) { err := v.Render(rw, httptest.NewRequest(http.MethodGet, "/", nil), nil) require.NoError(t, err) require.Equal(t, -1, rw.Code) // error StatusCode should not be written by StringViewer - require.Empty(t, rw.Header().Get("Content-Type")) + require.Equal(t, "text/plain; charset=utf-8", rw.Header().Get("Content-Type")) buf, err := io.ReadAll(rw.Body) require.NoError(t, err) require.Empty(t, buf) diff --git a/viewer_text.go b/viewer_text.go index 62eefa7..94a2231 100644 --- a/viewer_text.go +++ b/viewer_text.go @@ -18,15 +18,19 @@ func (v *TextViewer) MimeType() *MimeType { // It sets the Content-Type header to "text/plain; charset=utf-8" and writes the rendered content to the response. // If there is an error executing the template, it is returned. 
func (v *TextViewer) Render(w http.ResponseWriter, r *http.Request, data any) error { // skipcq: RVV-B0012 - buf := BufPool.Get() - defer BufPool.Put(buf) + var err error + w.Header().Set("Content-Type", v.template.mime.String()+v.template.charset) + if r.Method != http.MethodHead { + buf := BufPool.Get() + defer BufPool.Put(buf) + + err = v.template.Execute(buf, data) + if err != nil { + return err + } - err := v.template.Execute(buf, data) - if err != nil { - return err + _, err = buf.WriteTo(w) } - w.Header().Set("Content-Type", v.template.mime.String()+v.template.charset) - _, err = buf.WriteTo(w) return err } diff --git a/viewer_text_test.go b/viewer_text_test.go new file mode 100644 index 0000000..a9e716d --- /dev/null +++ b/viewer_text_test.go @@ -0,0 +1,29 @@ +package xun + +import ( + "net/http" + "net/http/httptest" + "testing" + "text/template" + + "github.com/stretchr/testify/require" +) + +func TestInvalidTextTemplate(t *testing.T) { + + type Data struct { + Name string + } + + l, err := template.New("invalid").Parse(`

Hello, {{.Name}}!

Age: {{.Age}}

`) + require.NoError(t, err) + + v := &TextViewer{ + template: &TextTemplate{ + template: l, + }, + } + + err = v.Render(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil), Data{}) + require.NotNil(t, err) +} diff --git a/viewer_xml.go b/viewer_xml.go index fcdf733..234f004 100644 --- a/viewer_xml.go +++ b/viewer_xml.go @@ -24,14 +24,18 @@ func (*XmlViewer) MimeType() *MimeType { // // It sets the Content-Type header to "text/xml; charset=utf-8". func (*XmlViewer) Render(w http.ResponseWriter, r *http.Request, data any) error { // skipcq: RVV-B0012 - buf := BufPool.Get() - defer BufPool.Put(buf) + var err error + w.Header().Set("Content-Type", "text/xml; charset=utf-8") + if r.Method != http.MethodHead { + buf := BufPool.Get() + defer BufPool.Put(buf) - err := xml.NewEncoder(buf).Encode(data) - if err != nil { - return err + err = xml.NewEncoder(buf).Encode(data) + if err != nil { + return err + } + _, err = buf.WriteTo(w) } - w.Header().Set("Content-Type", "text/xml; charset=utf-8") - _, err = buf.WriteTo(w) + return err } From 2e8d6383b470b02dc0c51d688d1cd39ca1f0adc4 Mon Sep 17 00:00:00 2001 From: Lz Date: Tue, 11 Feb 2025 22:08:34 +0800 Subject: [PATCH 5/8] fix(hsts): added IgnoreRules on Redirect (#43) --- ext/hsts/hsts.go | 12 ++++++++---- ext/hsts/hsts_test.go | 34 ++++++++++++++++++++++++++++++++++ ext/hsts/option.go | 19 ++++++++++++++++++- 3 files changed, 60 insertions(+), 5 deletions(-) diff --git a/ext/hsts/hsts.go b/ext/hsts/hsts.go index 2012d96..4e03957 100644 --- a/ext/hsts/hsts.go +++ b/ext/hsts/hsts.go @@ -61,13 +61,17 @@ func WriteHeader(opts ...Option) xun.Middleware { } // Redirect is a middleware that redirects plain HTTP requests to HTTPS. 
-func Redirect() xun.Middleware { +func Redirect(rules ...IgnoreRule) xun.Middleware { return func(next xun.HandleFunc) xun.HandleFunc { return func(c *xun.Context) error { - r := c.Request + if c.Request.TLS == nil && (c.Request.Method == "GET" || c.Request.Method == "HEAD") { + for _, it := range rules { + if it(c.Request) { + return next(c) + } + } - if r.TLS == nil && (r.Method == "GET" || r.Method == "HEAD") { - target := "https://" + stripPort(r.Host) + r.URL.RequestURI() + target := "https://" + stripPort(c.Request.Host) + c.Request.URL.RequestURI() c.Redirect(target, http.StatusFound) return xun.ErrCancelled diff --git a/ext/hsts/hsts_test.go b/ext/hsts/hsts_test.go index 5f0fc7f..d7511a0 100644 --- a/ext/hsts/hsts_test.go +++ b/ext/hsts/hsts_test.go @@ -144,6 +144,40 @@ func TestRedirect(t *testing.T) { require.Equal(t, "", resp.Header.Get("Strict-Transport-Security")) }) + t.Run("ignore_should_not_be_redirected", func(t *testing.T) { + mux := http.NewServeMux() + srv := httptest.NewServer(mux) + defer srv.Close() + app := xun.New(xun.WithMux(mux)) + app.Use(Redirect(Ignore("/status"))) + + u, err := url.Parse(srv.URL) + require.NoError(t, err) + + l := "https://" + u.Hostname() + "/" + + app.Get("/", func(c *xun.Context) error { + return c.View(nil) + }) + + req, err := http.NewRequest(http.MethodGet, srv.URL, nil) + require.NoError(t, err) + resp, err := c.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusFound, resp.StatusCode) + require.Equal(t, l, resp.Header.Get("Location")) + require.Equal(t, "", resp.Header.Get("Strict-Transport-Security")) + + req, err = http.NewRequest(http.MethodGet, srv.URL+"/status", nil) + require.NoError(t, err) + resp, err = c.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, "", resp.Header.Get("Location")) + require.Equal(t, "", resp.Header.Get("Strict-Transport-Security")) + + }) + } func TestStripPort(t *testing.T) { diff --git a/ext/hsts/option.go 
b/ext/hsts/option.go index 85a0094..7ec5706 100644 --- a/ext/hsts/option.go +++ b/ext/hsts/option.go @@ -1,6 +1,10 @@ package hsts -import "time" +import ( + "net/http" + "strings" + "time" +) // Config represents the configuration options for HSTS (HTTP Strict Transport Security). // It includes various settings such as MaxAge, IncludeSubDomains, and Preload. @@ -39,3 +43,16 @@ func WithPreload() Option { c.Preload = true } } + +type IgnoreRule func(*http.Request) bool + +func Ignore(paths ...string) IgnoreRule { + return func(r *http.Request) bool { + for _, path := range paths { + if strings.EqualFold(r.URL.Path, path) { + return true + } + } + return false + } +} From 9c23fef46d26d4adf2e39d8e8582426a0e7596ef Mon Sep 17 00:00:00 2001 From: Lz Date: Tue, 11 Feb 2025 22:29:15 +0800 Subject: [PATCH 6/8] fix(hsts): added StartsWith rule (#44) --- ext/hsts/hsts_test.go | 10 +++++++++- ext/hsts/option.go | 28 +++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/ext/hsts/hsts_test.go b/ext/hsts/hsts_test.go index d7511a0..c92919f 100644 --- a/ext/hsts/hsts_test.go +++ b/ext/hsts/hsts_test.go @@ -149,7 +149,7 @@ func TestRedirect(t *testing.T) { srv := httptest.NewServer(mux) defer srv.Close() app := xun.New(xun.WithMux(mux)) - app.Use(Redirect(Ignore("/status"))) + app.Use(Redirect(Match("/status"), StartsWith("/images"))) u, err := url.Parse(srv.URL) require.NoError(t, err) @@ -176,6 +176,14 @@ func TestRedirect(t *testing.T) { require.Equal(t, "", resp.Header.Get("Location")) require.Equal(t, "", resp.Header.Get("Strict-Transport-Security")) + req, err = http.NewRequest(http.MethodGet, srv.URL+"/images/xxx", nil) + require.NoError(t, err) + resp, err = c.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, "", resp.Header.Get("Location")) + require.Equal(t, "", resp.Header.Get("Strict-Transport-Security")) + }) } diff --git a/ext/hsts/option.go b/ext/hsts/option.go index 
7ec5706..e5b934a 100644 --- a/ext/hsts/option.go +++ b/ext/hsts/option.go @@ -44,9 +44,15 @@ func WithPreload() Option { } } +// IgnoreRule is a function that takes a pointer to an http.Request +// and returns a boolean indicating whether the request should be +// ignored by the HSTS middleware. type IgnoreRule func(*http.Request) bool -func Ignore(paths ...string) IgnoreRule { +// Match creates an IgnoreRule that matches the given paths to ignore requests. +// +// The paths are matched case-insensitively, so "/Doc" and "/doc" would be equivalent. +func Match(paths ...string) IgnoreRule { return func(r *http.Request) bool { for _, path := range paths { if strings.EqualFold(r.URL.Path, path) { @@ -56,3 +62,23 @@ func Ignore(paths ...string) IgnoreRule { return false } } + +// StartsWith creates an IgnoreRule that checks if the request path starts with any of the specified prefixes. +// +// The provided paths are automatically converted to lower-case for consistent matching. +func StartsWith(paths ...string) IgnoreRule { + // Convert provided paths to lower-case once for consistent matching. 
+ lowerPaths := make([]string, len(paths)) + for i, p := range paths { + lowerPaths[i] = strings.ToLower(p) + } + return func(r *http.Request) bool { + l := strings.ToLower(r.URL.Path) + for _, path := range lowerPaths { + if strings.HasPrefix(l, path) { + return true + } + } + return false + } +} From d41aa1780cf36feff89d490e2d4df03244028b18 Mon Sep 17 00:00:00 2001 From: Lz Date: Wed, 12 Feb 2025 08:08:09 +0800 Subject: [PATCH 7/8] feat(reqlog): added a middleware to logging requests (#45) --- README.md | 62 +++++++++++++++++++++++++++ ext/reqlog/format.go | 49 +++++++++++++++++++++ ext/reqlog/log.go | 34 +++++++++++++++ ext/reqlog/log_test.go | 88 ++++++++++++++++++++++++++++++++++++++ ext/reqlog/option.go | 86 +++++++++++++++++++++++++++++++++++++ response_writer.go | 1 + response_writer_deflate.go | 4 +- response_writer_gzip.go | 4 +- response_writer_std.go | 16 ++++++- 9 files changed, 341 insertions(+), 3 deletions(-) create mode 100644 ext/reqlog/format.go create mode 100644 ext/reqlog/log.go create mode 100644 ext/reqlog/log_test.go create mode 100644 ext/reqlog/option.go diff --git a/README.md b/README.md index 2b42dee..b3af594 100644 --- a/README.md +++ b/README.md @@ -599,6 +599,68 @@ The PROXY protocol allows our application to receive client connection informati proxyproto.ListenAndServeTLS(httpsServer, "", "") ``` +#### Logging + +Logs each incoming request to the provided logger. The format of the log messages is customizable using the `Format` option. The default format is the combined log format (XLF/ELF). + +> Enable `reqlog` middleware + +```go +func main(){ + //.... + logger, _ := setupLogger() + + app.Use(reqlog.New(reqlog.WithLogger(logger), + reqlog.WithUser(getUserID), + reqlog.WithVisitor(getVisitorID), + reqlog.WithFormat(reqlog.Combined))) + //... 
+} + +func setupLogger() (*log.Logger, error) { + logFile, err := os.OpenFile("./access.log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return nil, err + } + return log.New(logFile, "", 0), nil +} + +func getVisitorID(c *xun.Context) string { + v, err := c.Request.Cookie("visitor_id") // use fingerprintjs to generate visitor id in client's cookie + if err != nil { + return "" + } + + return v.Value +} + +func getUserID(c *xun.Context) string { + v, _, err := cookie.GetSigned(c, "session_id", secretKey) + if err != nil { + return "" + } + + return v +} + +``` + +> Install GoAccess to generate real-time analysis report + +[How to install GoAccess](https://goaccess.io/get-started) + +```bash +goaccess ./access.log --geoip-database=./GeoLite2-ASN.mmdb --geoip-database=./GeoLite2-City.mmdb -o ./realtime.html --log-format=COMBINED --real-time-html +``` + +> Serve the online real-time analysis report +```go + app.Get("/reports/realtime.html", func(c *xun.Context) error { + http.ServeFile(c.Response, c.Request, "./realtime.html") + return nil + }) +``` + ### Works with [tailwindcss](https://tailwindcss.com/docs/installation) #### Install Tailwind CSS Install tailwindcss via npm, and create your tailwind.config.js file. diff --git a/ext/reqlog/format.go b/ext/reqlog/format.go new file mode 100644 index 0000000..23a6cdb --- /dev/null +++ b/ext/reqlog/format.go @@ -0,0 +1,49 @@ +package reqlog + +import ( + "fmt" + "net" + "time" + + "github.com/yaitoo/xun" +) + +// Format is a function type that takes a Context pointer, an Options pointer, and a time.Time as arguments. +// It is used to format log messages. 
+type Format func(c *xun.Context, options *Options, starts time.Time) + +// Combined log request with Combined Log Format (XLF/ELF) +func Combined(c *xun.Context, options *Options, starts time.Time) { + requestLine := fmt.Sprintf(`"%s %s %s"`, c.Request.Method, c.Request.URL.Path, c.Request.Proto) + host, _, _ := net.SplitHostPort(c.Request.RemoteAddr) + + //COMBINED: remote、visitor、user、datetime、request line、status、body_bytes_sent、referer、user-agent + options.Logger.Printf("%s %s %s %s %s %d %d \"%s\" \"%s\"\n", + host, + options.GetVisitor(c), + options.GetUser(c), + starts.Format("[02/Jan/2006:15:04:05 -0700]"), + requestLine, + c.Response.StatusCode(), + c.Response.BodyBytesSent(), + c.Request.Referer(), + c.Request.UserAgent(), + ) +} + +// Common log request with Common Log Format (CLF) +func Common(c *xun.Context, options *Options, starts time.Time) { + requestLine := fmt.Sprintf(`"%s %s %s"`, c.Request.Method, c.Request.URL.Path, c.Request.Proto) + host, _, _ := net.SplitHostPort(c.Request.RemoteAddr) + + //Common: remote、visitor、user、datetime、request line、status、body_bytes_sent + options.Logger.Printf("%s %s %s %s %s %d %d\n", + host, + options.GetVisitor(c), + options.GetUser(c), + starts.Format("[02/Jan/2006:15:04:05 -0700]"), + requestLine, + c.Response.StatusCode(), + c.Response.BodyBytesSent(), + ) +} diff --git a/ext/reqlog/log.go b/ext/reqlog/log.go new file mode 100644 index 0000000..54f442d --- /dev/null +++ b/ext/reqlog/log.go @@ -0,0 +1,34 @@ +package reqlog + +import ( + "log" + "time" + + "github.com/yaitoo/xun" +) + +// New returns a middleware that logs each incoming request to the provided +// logger. The format of the log messages is customizable using the Format +// option. The default format is the combined log format (XLF/ELF). 
+func New(opts ...Option) xun.Middleware { + options := &Options{ + Logger: log.Default(), + GetVisitor: Miss, + GetUser: Miss, + Format: Combined, + } + + for _, opt := range opts { + opt(options) + } + + return func(next xun.HandleFunc) xun.HandleFunc { + return func(c *xun.Context) error { + now := time.Now() + defer func() { + options.Format(c, options, now) + }() + return next(c) + } + } +} diff --git a/ext/reqlog/log_test.go b/ext/reqlog/log_test.go new file mode 100644 index 0000000..b6b1d31 --- /dev/null +++ b/ext/reqlog/log_test.go @@ -0,0 +1,88 @@ +package reqlog + +import ( + "bytes" + "log" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/yaitoo/xun" +) + +func TestLogging(t *testing.T) { + + getVisitor := func(c *xun.Context) string { + return c.Request.Header.Get("X-Visitor-Id") + } + + getUser := func(c *xun.Context) string { + return c.Request.Header.Get("X-User-Id") + } + + t.Run("combined", func(t *testing.T) { + buf := bytes.Buffer{} + + logger := log.New(&buf, "", 0) + m := New(WithLogger(logger), + WithUser(getUser), + WithVisitor(getVisitor), + WithFormat(Combined)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Visitor-Id", "combined-vid") + req.Header.Set("X-User-Id", "combined-uid") + req.Header.Set("User-Agent", "combined-agent") + req.Header.Set("Referer", "combined-referer") + + ctx := &xun.Context{ + Request: req, + Response: xun.NewResponseWriter(httptest.NewRecorder()), + } + + err := m(func(c *xun.Context) error { + return nil + })(ctx) + + require.NoError(t, err) + + l := buf.String() + + require.True(t, strings.HasSuffix(l, "] \"GET / HTTP/1.1\" 200 0 \"combined-referer\" \"combined-agent\"\n")) + require.Contains(t, l, "combined-vid combined-uid [") + }) + + t.Run("common", func(t *testing.T) { + buf := bytes.Buffer{} + + logger := log.New(&buf, "", 0) + m := New(WithLogger(logger), + WithUser(getUser), + WithVisitor(getVisitor), + 
WithFormat(Common)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("User-Agent", "common-agent") + req.Header.Set("Referer", "common-referer") + + ctx := &xun.Context{ + Request: req, + Response: xun.NewResponseWriter(httptest.NewRecorder()), + } + + ctx.WriteStatus(http.StatusFound) + + err := m(func(c *xun.Context) error { + return nil + })(ctx) + + require.NoError(t, err) + + l := buf.String() + + require.True(t, strings.HasSuffix(l, "] \"GET / HTTP/1.1\" 302 0\n")) + require.Contains(t, l, "- - [") + }) +} diff --git a/ext/reqlog/option.go b/ext/reqlog/option.go new file mode 100644 index 0000000..42d1710 --- /dev/null +++ b/ext/reqlog/option.go @@ -0,0 +1,86 @@ +package reqlog + +import ( + "log" + + "github.com/yaitoo/xun" +) + +// Miss is the default visitor/user getter; it always returns a dash ("-"). +// New uses it for GetVisitor and GetUser unless WithVisitor or WithUser overrides them. +var Miss = func(*xun.Context) string { return "-" } + +// Options represents the configuration options for the RequestLog middleware. +// It allows customizing the request log message format, the logger instance, +// and the functions to retrieve visitor and user information from the request context. +type Options struct { + Logger *log.Logger + GetVisitor func(c *xun.Context) string + GetUser func(c *xun.Context) string + Format Format +} + +// Option is a function that takes a pointer to Options and modifies it. +// It is used to customize the behavior of the RequestLog middleware. +type Option func(o *Options) + +// WithLogger sets the logger for the RequestLog middleware. If not set, +// it will use the package-level logger from the log package. +func WithLogger(l *log.Logger) Option { + return func(o *Options) { + if l != nil { + o.Logger = l + } + } +} + +// WithVisitor sets a custom function to retrieve visitor information from the request context. +// It will be used to populate the visitor field in the request log message. 
+// +// The function should take a pointer to the xun.Context and return a string. +// The empty string will be replaced with a dash in the log message. +func WithVisitor(get func(c *xun.Context) string) Option { + return func(o *Options) { + if get != nil { + o.GetVisitor = func(c *xun.Context) string { + v := get(c) + if v == "" { + return "-" + } + + return v + } + } + + } +} + +// WithUser sets a custom function to retrieve user information from the request context. +// It will be used to populate the user field in the request log message. +// +// The function should take a pointer to the xun.Context and return a string. +// The empty string will be replaced with a dash in the log message. +func WithUser(get func(c *xun.Context) string) Option { + return func(o *Options) { + if get != nil { + o.GetUser = func(c *xun.Context) string { + + u := get(c) + if u == "" { + return "-" + } + + return u + } + } + } +} + +// WithFormat sets a custom format for the request log message. +func WithFormat(f Format) Option { + return func(o *Options) { + if f != nil { + o.Format = f + } + } +} diff --git a/response_writer.go b/response_writer.go index f144383..4b79a21 100644 --- a/response_writer.go +++ b/response_writer.go @@ -10,6 +10,7 @@ import ( type ResponseWriter interface { http.ResponseWriter + BodyBytesSent() int StatusCode() int Close() } diff --git a/response_writer_deflate.go b/response_writer_deflate.go index ae28516..a42e1fa 100644 --- a/response_writer_deflate.go +++ b/response_writer_deflate.go @@ -14,7 +14,9 @@ type deflateResponseWriter struct { // Write writes the data to the underlying gzip writer. // It implements the io.Writer interface. func (rw *deflateResponseWriter) Write(p []byte) (int, error) { - return rw.w.Write(p) + n, err := rw.w.Write(p) + rw.bodySentBytes += n + return n, err } // Close closes the underlying writer, flushing any buffered data to the client. 
diff --git a/response_writer_gzip.go b/response_writer_gzip.go index add65e6..6bd2e50 100644 --- a/response_writer_gzip.go +++ b/response_writer_gzip.go @@ -14,7 +14,9 @@ type gzipResponseWriter struct { // Write writes the data to the underlying gzip writer. // It implements the io.Writer interface. func (rw *gzipResponseWriter) Write(p []byte) (int, error) { - return rw.w.Write(p) + n, err := rw.w.Write(p) + rw.bodySentBytes += n + return n, err } // Close closes the gzipResponseWriter, ensuring that the underlying writer is also closed. diff --git a/response_writer_std.go b/response_writer_std.go index b928a8b..26c31d3 100644 --- a/response_writer_std.go +++ b/response_writer_std.go @@ -5,7 +5,8 @@ import "net/http" // stdResponseWriter is a wrapper around http.ResponseWriter to implement the ResponseWriter interface. type stdResponseWriter struct { http.ResponseWriter - statusCode int + bodySentBytes int + statusCode int } // Close implements the ResponseWriter interface Close method. 
@@ -27,6 +28,19 @@ func (rw *stdResponseWriter) StatusCode() int { return rw.statusCode } +func (rw *stdResponseWriter) BodyBytesSent() int { + return rw.bodySentBytes +} + +func (rw *stdResponseWriter) Write(b []byte) (int, error) { + + n, err := rw.ResponseWriter.Write(b) + + rw.bodySentBytes = rw.bodySentBytes + n + + return n, err +} + func NewResponseWriter(rw http.ResponseWriter) ResponseWriter { return &stdResponseWriter{ResponseWriter: rw} } From 0e69a1d1c066108bfe5339583caa4271362ebf68 Mon Sep 17 00:00:00 2001 From: Lz Date: Thu, 13 Feb 2025 14:23:55 +0800 Subject: [PATCH 8/8] feat(ext): added `ext/csrf` (#46) --- README.md | 37 ++++++++ ext/csrf/csrf.go | 105 +++++++++++++++++++++ ext/csrf/csrf.js | 23 +++++ ext/csrf/csrf_test.go | 203 +++++++++++++++++++++++++++++++++++++++++ ext/csrf/option.go | 38 ++++++++ ext/csrf/token.go | 85 +++++++++++++++++ ext/csrf/token_test.go | 32 +++++++ 7 files changed, 523 insertions(+) create mode 100644 ext/csrf/csrf.go create mode 100644 ext/csrf/csrf.js create mode 100644 ext/csrf/csrf_test.go create mode 100644 ext/csrf/option.go create mode 100644 ext/csrf/token.go create mode 100644 ext/csrf/token_test.go diff --git a/README.md b/README.md index b3af594..4be72cf 100644 --- a/README.md +++ b/README.md @@ -661,6 +661,43 @@ goaccess ./access.log --geoip-database=./GeoLite2-ASN.mmdb --geoip-database=./Ge }) ``` +#### CSRF Token +A CSRF (Cross-Site Request Forgery) token is a unique security measure designed to protect web applications from unauthorized or malicious requests. see more [examples](./ext/csrf/csrf_test.go) + +> Enable `csrf` middleware +```go +func main(){ + //.... + secretKey := []byte("your-secret-key") + + app.Use(csrf.New(secretKey)) + //... +} +``` + +> Enable `JsToken` to prevent bot requests on POST/PUT/DELETE + +- enable `csrf` with JsToken +```go +func main(){ + //.... 
+ secretKey := []byte("your-secret-key") + + app.Use(csrf.New(secretKey,csrf.WithJsToken())) + + app.Get("/assets/csrf.js",csrf.HandleFunc(secretKey)) + //... +} +``` + +- load `csrf.js` on html +```html +<script type="text/javascript" src="/assets/csrf.js" defer></script> +``` + + + + ### Works with [tailwindcss](https://tailwindcss.com/docs/installation) #### Install Tailwind CSS Install tailwindcss via npm, and create your tailwind.config.js file. diff --git a/ext/csrf/csrf.go b/ext/csrf/csrf.go new file mode 100644 index 0000000..2a47e80 --- /dev/null +++ b/ext/csrf/csrf.go @@ -0,0 +1,105 @@ +package csrf + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "embed" + "encoding/base64" + "html/template" + "io" + "net/http" + "strings" + "time" + + "github.com/yaitoo/xun" +) + +// New returns a middleware that generates a CSRF token and validates it in the request. +// +// The opts parameter is a list of Option functions that can be used to customize the +// behavior of the CSRF middleware. See the Option type for more information. +func New(secretKey []byte, opts ...Option) xun.Middleware { + o := &Options{ + SecretKey: secretKey, + CookieName: DefaultCookieName, + } + + for _, opt := range opts { + opt(o) + } + + return func(next xun.HandleFunc) xun.HandleFunc { + return func(c *xun.Context) error { + + token, _ := c.Request.Cookie(o.CookieName) // nolint: errcheck + + if c.Request.Method == "GET" || c.Request.Method == "HEAD" || c.Request.Method == "OPTIONS" { + if token == nil { // csrf_token doesn't exists + setTokenCookie(c, o) + } + + return next(c) + } + + if !verifyToken(token, c.Request, o) { + c.WriteStatus(http.StatusTeapot) + return xun.ErrCancelled + } + + return next(c) + } + } +} + +//go:embed csrf.js +var fsys embed.FS + +var zeroTime time.Time + +// HandleFunc serves the JavaScript token required for the CSRF middleware. +// +// It takes the secret key and options to customize the behavior. See the Option +// type for more information. 
+func HandleFunc(secretKey []byte, opts ...Option) xun.HandleFunc { + o := &Options{ + SecretKey: secretKey, + CookieName: DefaultCookieName, + } + for _, opt := range opts { + opt(o) + } + + f, _ := fsys.Open("csrf.js") // nolint: errcheck + defer f.Close() + + buf, _ := io.ReadAll(f) // nolint: errcheck + + t, _ := template.New("token").Parse(string(buf)) // nolint: errcheck + + var processed bytes.Buffer + t.Execute(&processed, o) // nolint: errcheck + + mac := hmac.New(sha256.New, o.SecretKey) + mac.Write(processed.Bytes()) + + etag := base64.URLEncoding.EncodeToString(mac.Sum(nil)) + + return func(c *xun.Context) error { + if match := c.Request.Header.Get("If-None-Match"); match != "" { + for _, it := range strings.Split(match, ",") { + if strings.TrimSpace(it) == etag { + c.Response.WriteHeader(http.StatusNotModified) + return nil + } + } + } + + c.Response.Header().Set("ETag", etag) + + content := bytes.NewReader(processed.Bytes()) + c.Response.Header().Set("Content-Type", "application/javascript") + http.ServeContent(c.Response, c.Request, "csrf.js", zeroTime, content) + return nil + } +} diff --git a/ext/csrf/csrf.js b/ext/csrf/csrf.js new file mode 100644 index 0000000..5014880 --- /dev/null +++ b/ext/csrf/csrf.js @@ -0,0 +1,23 @@ + function getCsrfToken() { + const cookies = document.cookie.split(';'); + for (let cookie of cookies) { + cookie = cookie.trim(); + const [key, ...valueParts] = cookie.split('='); + const value = valueParts.join('='); + if (key === "{{ .CookieName }}") { + return value; + } + } + return null; +} + +function setCsrfToken() { + const token = getCsrfToken() + if (token === null) { + return + } + const name =`js_{{.CookieName}}`; + document.cookie = `${name}=${token};path=/;samesite=lax`; +} + +setCsrfToken(); \ No newline at end of file diff --git a/ext/csrf/csrf_test.go b/ext/csrf/csrf_test.go new file mode 100644 index 0000000..8140512 --- /dev/null +++ b/ext/csrf/csrf_test.go @@ -0,0 +1,203 @@ +package csrf + +import ( + 
"bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "html/template" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "github.com/yaitoo/xun" +) + +var nop = func(c *xun.Context) error { + c.WriteStatus(http.StatusOK) + return nil +} + +func TestNew(t *testing.T) { + secretKey := []byte("test-secret-key-123") + + t.Run("set_cookie_when_missed", func(t *testing.T) { + m := New(secretKey) + + ctx := createContext(httptest.NewRequest("GET", "/", nil)) + + err := m(nop)(ctx) + require.NoError(t, err) + + require.NotNil(t, ctx.Response.Header().Get("Set-Cookie")) + }) + + t.Run("skip_set_cookie_when_present", func(t *testing.T) { + m := New(secretKey) + + ctx := createContext(httptest.NewRequest("GET", "/", nil)) + + ctx.Request.AddCookie(&http.Cookie{ + Name: DefaultCookieName, + Value: "test", + }) + + err := m(nop)(ctx) + require.NoError(t, err) + + require.Empty(t, ctx.Response.Header().Get("Set-Cookie")) + }) + + t.Run("verify_token", func(t *testing.T) { + m := New(secretKey) + + // fails + ctx := createContext(httptest.NewRequest("POST", "/", nil)) + err := m(nop)(ctx) + require.ErrorIs(t, err, xun.ErrCancelled) + require.Equal(t, http.StatusTeapot, ctx.Response.StatusCode()) + + // success + ctx = createContext(httptest.NewRequest("POST", "/", nil)) + cookie := generateToken(&Options{ + SecretKey: secretKey, + CookieName: DefaultCookieName, + }) + + ctx.Request.AddCookie(cookie) + + err = m(nop)(ctx) + require.NoError(t, err) + require.Equal(t, http.StatusOK, ctx.Response.StatusCode()) + + // skip + ctx = createContext(httptest.NewRequest("GET", "/", nil)) + err = m(nop)(ctx) + require.NoError(t, err) + require.Equal(t, http.StatusOK, ctx.Response.StatusCode()) + }) + + t.Run("options", func(t *testing.T) { + m := New(secretKey, WithCookie("test-cookie-name")) + + ctx := createContext(httptest.NewRequest("GET", "/", nil)) + + err := m(nop)(ctx) + require.NoError(t, err) + + require.NotNil(t, 
ctx.Response.Header().Get("Set-Cookie")) + require.Contains(t, ctx.Response.Header().Get("Set-Cookie"), "test-cookie-name=") + }) + + t.Run("verify_js_token", func(t *testing.T) { + m := New(secretKey, WithCookie("test_token"), WithJsToken()) + + cookie := generateToken(&Options{ + SecretKey: secretKey, + CookieName: "test_token", + }) + + // fails on js token + ctx := createContext(httptest.NewRequest("POST", "/", nil)) + ctx.Request.AddCookie(cookie) + + err := m(nop)(ctx) + require.ErrorIs(t, err, xun.ErrCancelled) + require.Equal(t, http.StatusTeapot, ctx.Response.StatusCode()) + + // fails on js token + ctx = createContext(httptest.NewRequest("POST", "/", nil)) + ctx.Request.AddCookie(cookie) + ctx.Request.AddCookie(&http.Cookie{ + Name: "js_test_token", + Value: "", + }) + err = m(nop)(ctx) + require.ErrorIs(t, err, xun.ErrCancelled) + require.Equal(t, http.StatusTeapot, ctx.Response.StatusCode()) + + // success + ctx = createContext(httptest.NewRequest("POST", "/", nil)) + + ctx.Request.AddCookie(cookie) + ctx.Request.AddCookie(&http.Cookie{ + Name: "js_test_token", + Value: cookie.Value, + }) + + err = m(nop)(ctx) + require.NoError(t, err) + require.Equal(t, http.StatusOK, ctx.Response.StatusCode()) + + // skip + ctx = createContext(httptest.NewRequest("GET", "/", nil)) + err = m(nop)(ctx) + require.NoError(t, err) + require.Equal(t, http.StatusOK, ctx.Response.StatusCode()) + }) + +} + +func TestHandleFunc(t *testing.T) { + + fn := HandleFunc([]byte("secret"), WithCookie("test_token")) + + t.Run("load", func(t *testing.T) { + w := httptest.NewRecorder() + ctx := &xun.Context{ + Request: httptest.NewRequest("GET", "/csrf.js", nil), + Response: xun.NewResponseWriter(w), + } + + err := fn(ctx) + require.NoError(t, err) + + require.Equal(t, http.StatusOK, w.Code) + require.Contains(t, w.Body.String(), `"test_token"`) + }) + + t.Run("etag", func(t *testing.T) { + f, _ := fsys.Open("csrf.js") // nolint: errcheck + defer f.Close() + + buf, _ := io.ReadAll(f) // 
nolint: errcheck + + p, _ := template.New("token").Parse(string(buf)) // nolint: errcheck + + var processed bytes.Buffer + // nolint: errcheck + p.Execute(&processed, &Options{ + SecretKey: []byte("secret"), + CookieName: "test_token", + }) + + mac := hmac.New(sha256.New, []byte("secret")) + mac.Write(processed.Bytes()) + + etag := base64.URLEncoding.EncodeToString(mac.Sum(nil)) + + w := httptest.NewRecorder() + ctx := &xun.Context{ + Request: httptest.NewRequest("GET", "/csrf.js", nil), + Response: xun.NewResponseWriter(w), + } + + ctx.Request.Header.Set("If-None-Match", etag) + + err := fn(ctx) + require.NoError(t, err) + + require.Equal(t, http.StatusNotModified, w.Code) + + }) + +} + +func createContext(r *http.Request) *xun.Context { + return &xun.Context{ + Request: r, + Response: xun.NewResponseWriter(httptest.NewRecorder()), + } +} diff --git a/ext/csrf/option.go b/ext/csrf/option.go new file mode 100644 index 0000000..cce0e62 --- /dev/null +++ b/ext/csrf/option.go @@ -0,0 +1,38 @@ +package csrf + +const ( + DefaultCookieName = "csrf_token" +) + +// Options represents the configuration for the CSRF middleware. +// It allows customizing the secret key, the cookie name, +// and whether the JavaScript token check is enabled. +type Options struct { + SecretKey []byte + CookieName string + JsToken bool +} + +// Option is a function type that takes a pointer to Options and modifies it. +// It is used to customize the behavior of the CSRF middleware. +type Option func(o *Options) + +// WithCookie sets the name of the cookie to use for storing the CSRF token. +// Defaults to "csrf_token". +func WithCookie(name string) Option { + return func(o *Options) { + if name != "" { + o.CookieName = name + } + } +} + +// WithJsToken enables the JavaScript token feature for CSRF protection. +// +// It sets the JsToken field in the Options struct to true, allowing the +// middleware to generate and handle CSRF tokens via JavaScript. 
+func WithJsToken() Option { + return func(o *Options) { + o.JsToken = true + } +} diff --git a/ext/csrf/token.go b/ext/csrf/token.go new file mode 100644 index 0000000..694e1fb --- /dev/null +++ b/ext/csrf/token.go @@ -0,0 +1,85 @@ +package csrf + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" + "net/http" + "strings" + + "github.com/yaitoo/xun" +) + +// verifyToken checks if the given cookie token is valid. +func verifyToken(token *http.Cookie, r *http.Request, o *Options) bool { + if token == nil { + return false + } + + if o.JsToken { + jsToken, err := r.Cookie("js_" + o.CookieName) + if err != nil { + return false + } + + if token.Value != jsToken.Value { + return false + } + } + + parts := strings.Split(token.Value, ".") + if len(parts) != 2 { + return false + } + + randomBytes, err := base64.URLEncoding.DecodeString(parts[0]) + if err != nil { + return false + } + + mac := hmac.New(sha256.New, o.SecretKey) + mac.Write(randomBytes) + expected := mac.Sum(nil) + + actual, err := base64.URLEncoding.DecodeString(parts[1]) + if err != nil { + return false + } + + return hmac.Equal(expected, actual) +} + +func generateToken(o *Options) *http.Cookie { + randomBytes := make([]byte, 32) + if _, err := rand.Read(randomBytes); err != nil { + return nil + } + + mac := hmac.New(sha256.New, o.SecretKey) + mac.Write(randomBytes) + signature := mac.Sum(nil) + + token := fmt.Sprintf( + "%s.%s", + base64.URLEncoding.EncodeToString(randomBytes), + base64.URLEncoding.EncodeToString(signature), + ) + + return &http.Cookie{ + Name: o.CookieName, + Value: token, + HttpOnly: false, + SameSite: http.SameSiteLaxMode, + Path: "/", + } +} + +func setTokenCookie(c *xun.Context, o *Options) { + token := generateToken(o) + + if token != nil { + http.SetCookie(c.Response, token) + } +} diff --git a/ext/csrf/token_test.go b/ext/csrf/token_test.go new file mode 100644 index 0000000..f4089d9 --- /dev/null +++ b/ext/csrf/token_test.go @@ -0,0 +1,32 @@ 
+package csrf + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestToken(t *testing.T) { + + ok := verifyToken(&http.Cookie{ + Value: "", + }, nil, &Options{}) + + require.Equal(t, false, ok) + + ok = verifyToken(&http.Cookie{ + Value: ".", + }, nil, &Options{}) + require.Equal(t, false, ok) + + ok = verifyToken(&http.Cookie{ + Value: "0.", + }, nil, &Options{}) + require.Equal(t, false, ok) + + ok = verifyToken(&http.Cookie{ + Value: "BVaCD9NBke1Oq2rtkU_bjRcWOEGrNmTGYd9ikcQ_5HM=.0", + }, nil, &Options{}) + require.Equal(t, false, ok) +}