From e1bbb9a1b477ba6ab8c490bf2ac1c4a6a4fcec3e Mon Sep 17 00:00:00 2001 From: Lz Date: Thu, 20 Feb 2025 19:28:12 +0800 Subject: [PATCH 1/8] fix(htmx): added ext/htmx.js library (#52) --- README.md | 40 ++++++++------ etag.go | 120 ++++++++++++++++++++++++++++++++++++++++++ etag_test.go | 88 +++++++++++++++++++++++++++++++ ext/csrf/csrf.go | 45 ++++++++-------- ext/csrf/csrf_test.go | 10 ++-- ext/htmx/htmx.go | 38 +++++++++++++ ext/htmx/htmx.js | 27 ++++++++++ ext/htmx/htmx_test.go | 52 ++++++++++++++++++ go.mod | 3 +- go.sum | 6 +-- viewer_file.go | 21 ++------ 11 files changed, 382 insertions(+), 68 deletions(-) create mode 100644 etag.go create mode 100644 etag_test.go create mode 100644 ext/htmx/htmx.js diff --git a/README.md b/README.md index 4948ff7..be99ed4 100644 --- a/README.md +++ b/README.md @@ -833,13 +833,13 @@ func main() { ### Works with [tailwindcss](https://tailwindcss.com/docs/installation) -#### Install Tailwind CSS +#### 1. Install Tailwind CSS Install tailwindcss via npm, and create your tailwind.config.js file. ```bash npm install -D tailwindcss npx tailwindcss init ``` -#### Configure your template paths +#### 2. Configure your template paths Add the paths to all of your template files in your tailwind.config.js file. > tailwind.config.js @@ -854,7 +854,7 @@ module.exports = { } ``` -#### Add the Tailwind directives to your CSS +#### 3. Add the Tailwind directives to your CSS Add the @tailwind directives for each of Tailwind’s layers to your main CSS file. > app/tailwind.css ```css @@ -863,14 +863,14 @@ Add the @tailwind directives for each of Tailwind’s layers to your main CSS fi @tailwind utilities; ``` -#### Start the Tailwind CLI build process +#### 4. Start the Tailwind CLI build process Run the CLI tool to scan your template files for classes and build your CSS. ```bash npx tailwindcss -i ./app/tailwind.css -o ./app/public/theme.css --watch ``` -#### Start using Tailwind in your HTML +#### 5. 
Start using Tailwind in your HTML Add your compiled CSS file to the `assets.html` and start using Tailwind’s utility classes to style your content. > components/assets.html @@ -881,7 +881,7 @@ Add your compiled CSS file to the `assets.html` and start using Tailwind’s uti ``` ### Works with [htmx.js](https://htmx.org/docs/) -#### Add new pages +#### 1. Add new pages > `pages/admin/index.html` and `pages/login.html` ``` ├── app @@ -907,17 +907,25 @@ Add your compiled CSS file to the `assets.html` and start using Tailwind’s uti │   ├── tailwind.css ``` -#### Install htmx.js +#### 2. Serve [htmx-ext.js](./ext/htmx/htmx.js) library +The library to enable seamless integration between native JavaScript methods and htmx features, enhancing interactive capabilities without compromising core functionality. + +```go + app.Get("/htmx-ext.js", htmx.HandleFunc()) +``` + +#### 3. Install htmx.js and htmx-ext.js > components/assets.html ```html - + + ``` -#### Enabled `htmx` feature on pages +#### 4. Enabled `htmx` feature on pages > pages/index.html ```html @@ -982,25 +990,25 @@ Add your compiled CSS file to the `assets.html` and start using Tailwind’s uti {{ end }} ``` -#### Setup Hx-Trigger listener +#### 5. Setup Hx-Trigger listener > app.js ```js -window.addEventListener("DOMContentLoaded", (event) => { - document.body.addEventListener("showMessage", function(evt){ +xun.ready(function(evt) { + document.body.addEventListener("showMessage", function(evt){ alert(evt.detail.value); }) }); ``` -#### Apply `htmx` interceptor +#### 6. Apply `htmx` interceptor ```go app := xun.New(xun.WithInterceptor(htmx.New())) ``` -#### Create router handler to process request +#### 7. Create router handler to process request create an `admin` group router, and apply a middleware to check if it's logged. if not, redirect to /login. @@ -1062,7 +1070,9 @@ create an `admin` group router, and apply a middleware to check if it's logged. 
http.SetCookie(c.Response, &cookie) - c.Redirect(c.RequestReferer().Query().Get("return")) + u, _ := url.Parse(c.RequestReferer()) + + c.Redirect(u.Query().Get("return")) return nil }) ``` diff --git a/etag.go b/etag.go new file mode 100644 index 0000000..93f43aa --- /dev/null +++ b/etag.go @@ -0,0 +1,120 @@ +package xun + +import ( + "crypto/md5" // skipcq: GSC-G401, GSC-G501, GO-S1023 + "encoding/hex" + "hash" + "io" + "net/http" + "net/textproto" + "strings" +) + +// ComputeETag returns the ETag header value for the given reader content. +// +// The value is computed by taking the md5 of the content and encoding it +// as a hexadecimal string. +func ComputeETag(r io.Reader) string { + return ComputeETagWith(r, md5.New()) // skipcq: GSC-G401, GSC-G501, GO-S1023 +} + +// ComputeETagWith returns the ETag header value for the given reader content +// using the provided hash function. +func ComputeETagWith(r io.Reader, h hash.Hash) string { + if _, err := io.Copy(h, r); err != nil { + return "" + } + + return `"` + hex.EncodeToString(h.Sum(nil)) + `"` +} + +func WriteIfNoneMatch(w http.ResponseWriter, r *http.Request) bool { + if r.Method == "GET" || r.Method == "HEAD" { + if checkIfNoneMatch(w, r) { + writeNotModified(w) + return true + } + } + + return false +} + +func checkIfNoneMatch(w http.ResponseWriter, r *http.Request) bool { + inm := r.Header.Get("If-None-Match") + if inm == "" { + return false + } + buf := inm + for { + buf = textproto.TrimString(buf) + if len(buf) == 0 { + break + } + if buf[0] == ',' { + buf = buf[1:] + continue + } + if buf[0] == '*' { + return true + } + etag, remain := scanETag(buf) + if etag == "" { + break + } + if etagWeakMatch(etag, w.Header().Get("Etag")) { + return true + } + buf = remain + } + return false +} + +// scanETag determines if a syntactically valid ETag is present at s. If so, +// the ETag and remaining text after consuming ETag is returned. Otherwise, +// it returns "", "". 
+func scanETag(s string) (etag string, remain string) { + s = textproto.TrimString(s) + start := 0 + if strings.HasPrefix(s, "W/") { + start = 2 + } + if len(s[start:]) < 2 || s[start] != '"' { + return "", "" + } + // ETag is either W/"text" or "text". + // See RFC 7232 2.3. + for i := start + 1; i < len(s); i++ { + c := s[i] + switch { + // Character values allowed in ETags. + case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80: + case c == '"': + return s[:i+1], s[i+1:] + default: + return "", "" + } + } + return "", "" +} + +// etagWeakMatch reports whether a and b match using weak ETag comparison. +// Assumes a and b are valid ETags. +func etagWeakMatch(a, b string) bool { + return strings.TrimPrefix(a, "W/") == strings.TrimPrefix(b, "W/") +} + +func writeNotModified(w http.ResponseWriter) { + // RFC 7232 section 4.1: + // a sender SHOULD NOT generate representation metadata other than the + // above listed fields unless said metadata exists for the purpose of + // guiding cache updates (e.g., Last-Modified might be useful if the + // response does not have an ETag field). 
+ h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + delete(h, "Content-Encoding") + if h.Get("Etag") != "" { + delete(h, "Last-Modified") + } + w.WriteHeader(http.StatusNotModified) +} diff --git a/etag_test.go b/etag_test.go new file mode 100644 index 0000000..e448a82 --- /dev/null +++ b/etag_test.go @@ -0,0 +1,88 @@ +package xun + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestETag(t *testing.T) { + + t.Run("tag", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + require.True(t, !WriteIfNoneMatch(w, req)) + + req.Header.Set("If-None-Match", "\"737060cd8c284d8af7ad3082f209582d\"") + require.True(t, !WriteIfNoneMatch(w, req)) + + w.Header().Set("ETag", "\"737060cd8c284d8af7ad3082f209582d\"") + require.True(t, WriteIfNoneMatch(w, req)) + }) + + t.Run("weak_tag", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + require.True(t, !WriteIfNoneMatch(w, req)) + + req.Header.Set("If-None-Match", "W/\"737060cd8c284d8af7ad3082f209582d\"") + require.True(t, !WriteIfNoneMatch(w, req)) + + w.Header().Set("ETag", "W/\"737060cd8c284d8af7ad3082f209582d\"") + require.True(t, WriteIfNoneMatch(w, req)) + }) + + t.Run("multi-etags", func(t *testing.T) { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + + req.Header.Set("If-None-Match", `"etag1", "etag2", W/"weak-etag"`) + + w.Header().Set("ETag", `"etag1"`) + require.True(t, WriteIfNoneMatch(w, req)) + + w.Header().Set("ETag", `"etag2"`) + require.True(t, WriteIfNoneMatch(w, req)) + + w.Header().Set("ETag", `W/"weak-etag"`) + require.True(t, WriteIfNoneMatch(w, req)) + }) + + t.Run("any_etags", func(t *testing.T) { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + + req.Header.Set("If-None-Match", `*`) + + w.Header().Set("ETag", `"etag1"`) + require.True(t, 
WriteIfNoneMatch(w, req)) + + w.Header().Set("ETag", `"etag2"`) + require.True(t, WriteIfNoneMatch(w, req)) + + w.Header().Set("ETag", `W/"weak-etag"`) + require.True(t, WriteIfNoneMatch(w, req)) + }) + + t.Run("invalid_etags", func(t *testing.T) { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + + req.Header.Set("If-None-Match", `W\`) + + require.False(t, WriteIfNoneMatch(w, req)) + + req.Header.Set("If-None-Match", `""`) + require.False(t, WriteIfNoneMatch(w, req)) + + req.Header.Set("If-None-Match", `"etag "`) + require.False(t, WriteIfNoneMatch(w, req)) + + req.Header.Set("If-None-Match", `"etag`) + require.False(t, WriteIfNoneMatch(w, req)) + }) +} diff --git a/ext/csrf/csrf.go b/ext/csrf/csrf.go index 2a47e80..53f7ca4 100644 --- a/ext/csrf/csrf.go +++ b/ext/csrf/csrf.go @@ -5,11 +5,9 @@ import ( "crypto/hmac" "crypto/sha256" "embed" - "encoding/base64" "html/template" "io" "net/http" - "strings" "time" "github.com/yaitoo/xun" @@ -70,36 +68,35 @@ func HandleFunc(secretKey []byte, opts ...Option) xun.HandleFunc { opt(o) } - f, _ := fsys.Open("csrf.js") // nolint: errcheck - defer f.Close() - - buf, _ := io.ReadAll(f) // nolint: errcheck - - t, _ := template.New("token").Parse(string(buf)) // nolint: errcheck - - var processed bytes.Buffer - t.Execute(&processed, o) // nolint: errcheck + buf := loadJavaScript(o) mac := hmac.New(sha256.New, o.SecretKey) - mac.Write(processed.Bytes()) - - etag := base64.URLEncoding.EncodeToString(mac.Sum(nil)) + etag := xun.ComputeETagWith(bytes.NewReader(buf), mac) return func(c *xun.Context) error { - if match := c.Request.Header.Get("If-None-Match"); match != "" { - for _, it := range strings.Split(match, ",") { - if strings.TrimSpace(it) == etag { - c.Response.WriteHeader(http.StatusNotModified) - return nil - } - } - } - c.Response.Header().Set("ETag", etag) + if xun.WriteIfNoneMatch(c.Response, c.Request) { + return nil + } - content := bytes.NewReader(processed.Bytes()) + content := 
bytes.NewReader(buf) c.Response.Header().Set("Content-Type", "application/javascript") http.ServeContent(c.Response, c.Request, "csrf.js", zeroTime, content) + return nil } } + +func loadJavaScript(o *Options) []byte { + f, _ := fsys.Open("csrf.js") // nolint: errcheck + defer f.Close() + + buf, _ := io.ReadAll(f) // nolint: errcheck + + t, _ := template.New("token").Parse(string(buf)) // nolint: errcheck + + var body bytes.Buffer + t.Execute(&body, o) // nolint: errcheck + + return body.Bytes() +} diff --git a/ext/csrf/csrf_test.go b/ext/csrf/csrf_test.go index 8140512..bc60fe5 100644 --- a/ext/csrf/csrf_test.go +++ b/ext/csrf/csrf_test.go @@ -4,7 +4,6 @@ import ( "bytes" "crypto/hmac" "crypto/sha256" - "encoding/base64" "html/template" "io" "net/http" @@ -166,17 +165,14 @@ func TestHandleFunc(t *testing.T) { p, _ := template.New("token").Parse(string(buf)) // nolint: errcheck - var processed bytes.Buffer + var body bytes.Buffer // nolint: errcheck - p.Execute(&processed, &Options{ + p.Execute(&body, &Options{ SecretKey: []byte("secret"), CookieName: "test_token", }) - mac := hmac.New(sha256.New, []byte("secret")) - mac.Write(processed.Bytes()) - - etag := base64.URLEncoding.EncodeToString(mac.Sum(nil)) + etag := xun.ComputeETagWith(&body, hmac.New(sha256.New, []byte("secret"))) w := httptest.NewRecorder() ctx := &xun.Context{ diff --git a/ext/htmx/htmx.go b/ext/htmx/htmx.go index e9c9ba9..0916c0a 100644 --- a/ext/htmx/htmx.go +++ b/ext/htmx/htmx.go @@ -1,6 +1,12 @@ package htmx import ( + "bytes" + "embed" + "io" + "net/http" + "time" + jsoniter "github.com/json-iterator/go" "github.com/yaitoo/xun" ) @@ -79,3 +85,35 @@ func WriteHeader(c *xun.Context, key string, value any) { buf, _ := json.Marshal(value) c.WriteHeader(key, string(buf)) } + +//go:embed htmx.js +var fsys embed.FS + +var zeroTime time.Time + +// HandleFunc serves the htmx.js library for the htmx extension. 
+func HandleFunc() xun.HandleFunc { + buf := loadJavaScript() + etag := xun.ComputeETag(bytes.NewReader(buf)) + + return func(c *xun.Context) error { + c.Response.Header().Set("ETag", etag) + if xun.WriteIfNoneMatch(c.Response, c.Request) { + return nil + } + + content := bytes.NewReader(buf) + c.Response.Header().Set("Content-Type", "application/javascript") + http.ServeContent(c.Response, c.Request, "htmx.js", zeroTime, content) + return nil + } +} + +func loadJavaScript() []byte { + f, _ := fsys.Open("htmx.js") // nolint: errcheck + defer f.Close() + + buf, _ := io.ReadAll(f) // nolint: errcheck + + return buf +} diff --git a/ext/htmx/htmx.js b/ext/htmx/htmx.js new file mode 100644 index 0000000..3a2027f --- /dev/null +++ b/ext/htmx/htmx.js @@ -0,0 +1,27 @@ +window.xun = window.xun || { + /** + * A global object to manage custom events and callbacks. + * + * @property {Function} ready - Registers a callback function to be executed once when + * the DOM is fully loaded or when an `htmx:load` event occurs. + * + * @function ready + * @param {Function} fn - The callback function to be executed. 
+ */ + ready:function(fn){ + let boosted = false; + document.addEventListener('DOMContentLoaded',function(){ + fn(); + }); + document.addEventListener('htmx:load', function(evt) { + if(boosted){ + fn(evt); + boosted = false; + } + }); + document.addEventListener('htmx:beforeOnLoad', function(evt) { + // trigger ready function again when a boosted request is done + boosted = evt.detail.boosted; + }); + } + } \ No newline at end of file diff --git a/ext/htmx/htmx_test.go b/ext/htmx/htmx_test.go index 7a65175..1b99cd2 100644 --- a/ext/htmx/htmx_test.go +++ b/ext/htmx/htmx_test.go @@ -1,6 +1,9 @@ package htmx import ( + "bytes" + "html/template" + "io" "net/http" "net/http/httptest" "strconv" @@ -68,3 +71,52 @@ func TestHtmxWriteHeader(t *testing.T) { require.Equal(t, "message", header["name"]) } + +func TestHandleFunc(t *testing.T) { + + fn := HandleFunc() + + t.Run("load", func(t *testing.T) { + w := httptest.NewRecorder() + ctx := &xun.Context{ + Request: httptest.NewRequest("GET", "/htmx.js", nil), + Response: xun.NewResponseWriter(w), + } + + err := fn(ctx) + require.NoError(t, err) + + require.Equal(t, http.StatusOK, w.Code) + require.Contains(t, w.Body.String(), `DOMContentLoaded`) + }) + + t.Run("etag", func(t *testing.T) { + f, _ := fsys.Open("htmx.js") // nolint: errcheck + defer f.Close() + + buf, _ := io.ReadAll(f) // nolint: errcheck + + p, _ := template.New("token").Parse(string(buf)) // nolint: errcheck + + var body bytes.Buffer + // nolint: errcheck + p.Execute(&body, nil) + + etag := xun.ComputeETag(&body) + + w := httptest.NewRecorder() + ctx := &xun.Context{ + Request: httptest.NewRequest("GET", "/htmx.js", nil), + Response: xun.NewResponseWriter(w), + } + + ctx.Request.Header.Set("If-None-Match", etag) + + err := fn(ctx) + require.NoError(t, err) + + require.Equal(t, http.StatusNotModified, w.Code) + + }) + +} diff --git a/go.mod b/go.mod index e9a1efb..bdb0ff0 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( 
github.com/go-playground/form/v4 v4.2.1 github.com/go-playground/locales v0.14.1 github.com/go-playground/universal-translator v0.18.1 - github.com/go-playground/validator/v10 v10.24.0 + github.com/go-playground/validator/v10 v10.25.0 github.com/json-iterator/go v1.1.12 github.com/stretchr/testify v1.10.0 golang.org/x/crypto v0.33.0 @@ -19,7 +19,6 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/stretchr/objx v0.5.2 // indirect golang.org/x/net v0.35.0 // indirect golang.org/x/sys v0.30.0 // indirect golang.org/x/text v0.22.0 // indirect diff --git a/go.sum b/go.sum index a9fc18c..ceddbff 100644 --- a/go.sum +++ b/go.sum @@ -13,8 +13,8 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.24.0 h1:KHQckvo8G6hlWnrPX4NJJ+aBfWNAE/HH+qdL2cBpCmg= -github.com/go-playground/validator/v10 v10.24.0/go.mod h1:GGzBIJMuE98Ic/kJsBXbz1x/7cByt++cQ+YOuDM5wus= +github.com/go-playground/validator/v10 v10.25.0 h1:5Dh7cjvzR7BRZadnsVOzPhWsrwUr0nmsZJxEAnFLNO8= +github.com/go-playground/validator/v10 v10.25.0/go.mod h1:GGzBIJMuE98Ic/kJsBXbz1x/7cByt++cQ+YOuDM5wus= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= @@ -29,8 +29,6 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN 
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= diff --git a/viewer_file.go b/viewer_file.go index ee6ff9a..94a4183 100644 --- a/viewer_file.go +++ b/viewer_file.go @@ -1,12 +1,9 @@ package xun import ( - "crypto/sha256" - "encoding/hex" "io" "io/fs" "net/http" - "strings" ) // NewFileViewer creates a new FileViewer instance. 
@@ -23,12 +20,8 @@ func NewFileViewer(fsys fs.FS, path string, isEmbed bool) *FileViewer { } defer f.Close() - hash := sha256.New() // skipcq: GSC-G401, GO-S1023 - if _, err := io.Copy(hash, f); err != nil { - return v - } v.isEmbed = true - v.etag = `"` + hex.EncodeToString(hash.Sum(nil)) + `"` + v.etag = ComputeETag(f) } return v @@ -74,16 +67,12 @@ func (v *FileViewer) Render(ctx *Context, data any) error { if !v.isEmbed { return v.serveContent(ctx.Response, ctx.Request) } - if match := ctx.Request.Header.Get("If-None-Match"); match != "" { - for _, it := range strings.Split(match, ",") { - if strings.TrimSpace(it) == v.etag { - ctx.Response.WriteHeader(http.StatusNotModified) - return nil - } - } - } ctx.Response.Header().Set("ETag", v.etag) + if WriteIfNoneMatch(ctx.Response, ctx.Request) { + return nil + } + return v.serveContent(ctx.Response, ctx.Request) } From d22caf2b24d59691ed5c0de588ea6eb78cd02ea4 Mon Sep 17 00:00:00 2001 From: Lz Date: Thu, 20 Feb 2025 20:58:38 +0800 Subject: [PATCH 2/8] fix(htmx): added `selector` in `xun.ready` to check if the callback should be executed (#53) --- ext/htmx/htmx.js | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/ext/htmx/htmx.js b/ext/htmx/htmx.js index 3a2027f..cded7a4 100644 --- a/ext/htmx/htmx.js +++ b/ext/htmx/htmx.js @@ -6,16 +6,26 @@ window.xun = window.xun || { * the DOM is fully loaded or when an `htmx:load` event occurs. * * @function ready - * @param {Function} fn - The callback function to be executed. + * @param {Function} callback - The callback function to be executed. + * @param {String} selector - The selector to be used to check if the callback should be executed. 
*/ - ready:function(fn){ + ready:function(callback,selector){ + const f = function(evt){ + if(selector){ + if(document.querySelector(selector)){ + callback(evt); + } + }else{ + callback(evt); + } + } let boosted = false; - document.addEventListener('DOMContentLoaded',function(){ - fn(); + document.addEventListener('DOMContentLoaded',function(evt){ + f(evt); }); document.addEventListener('htmx:load', function(evt) { if(boosted){ - fn(evt); + f(evt); boosted = false; } }); From 8ede36cd0a9c7e712709b944dda1e1347db9a94f Mon Sep 17 00:00:00 2001 From: Lz Date: Fri, 21 Feb 2025 22:20:12 +0800 Subject: [PATCH 3/8] fix(json): use std json instead of json-iterator to support omitzero (#54) --- README.md | 4 ++-- app_test.go | 11 ++++++----- compressor_deflate_test.go | 2 +- compressor_gzip_test.go | 2 +- context.go | 6 ------ context_test.go | 3 ++- ext/form/binder.go | 6 +++--- ext/form/binder_test.go | 9 +++++---- ext/htmx/htmx.go | 10 ++++++---- ext/htmx/htmx_test.go | 1 + ext/htmx/interceptor_test.go | 8 ++++---- go.mod | 3 --- go.sum | 13 ------------- json.go | 34 ++++++++++++++++++++++++++++++++++ routing_option_test.go | 1 + viewengine_text_test.go | 2 +- viewer_json.go | 2 +- viewer_json_test.go | 2 +- 18 files changed, 69 insertions(+), 50 deletions(-) create mode 100644 json.go diff --git a/README.md b/README.md index be99ed4..412a614 100644 --- a/README.md +++ b/README.md @@ -994,10 +994,10 @@ The library to enable seamless integration between native JavaScript methods and > app.js ```js xun.ready(function(evt) { - document.body.addEventListener("showMessage", function(evt){ + document.addEventListener("showMessage", function(evt){ alert(evt.detail.value); }) -}); +},'body'); ``` diff --git a/app_test.go b/app_test.go index b7505a7..280d873 100644 --- a/app_test.go +++ b/app_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/tls" + "encoding/json" "errors" "io" "log/slog" @@ -80,7 +81,7 @@ func TestJsonViewer(t *testing.T) { resp, err := client.Do(req) 
require.NoError(t, err) - err = json.NewDecoder(resp.Body).Decode(data) + err = Json.NewDecoder(resp.Body).Decode(data) require.NoError(t, err) resp.Body.Close() @@ -92,7 +93,7 @@ func TestJsonViewer(t *testing.T) { resp, err = client.Do(req) require.NoError(t, err) - err = json.NewDecoder(resp.Body).Decode(&data) + err = Json.NewDecoder(resp.Body).Decode(&data) require.NoError(t, err) resp.Body.Close() require.Equal(t, "POST", data.Method) @@ -103,7 +104,7 @@ func TestJsonViewer(t *testing.T) { resp, err = client.Do(req) require.NoError(t, err) - err = json.NewDecoder(resp.Body).Decode(&data) + err = Json.NewDecoder(resp.Body).Decode(&data) require.NoError(t, err) resp.Body.Close() require.Equal(t, "PUT", data.Method) @@ -114,7 +115,7 @@ func TestJsonViewer(t *testing.T) { resp, err = client.Do(req) require.NoError(t, err) - err = json.NewDecoder(resp.Body).Decode(&data) + err = Json.NewDecoder(resp.Body).Decode(&data) require.NoError(t, err) resp.Body.Close() require.Equal(t, "DELETE", data.Method) @@ -125,7 +126,7 @@ func TestJsonViewer(t *testing.T) { resp, err = client.Do(req) require.NoError(t, err) - err = json.NewDecoder(resp.Body).Decode(&data) + err = Json.NewDecoder(resp.Body).Decode(&data) require.NoError(t, err) resp.Body.Close() require.Equal(t, "HandleFunc", data.Method) diff --git a/compressor_deflate_test.go b/compressor_deflate_test.go index 0c0104f..27d98d0 100644 --- a/compressor_deflate_test.go +++ b/compressor_deflate_test.go @@ -101,7 +101,7 @@ func TestDeflateCompressor(t *testing.T) { require.Equal(t, test.contentEncoding, resp.Header.Get("Content-Encoding")) data := make(map[string]string) - err = json.NewDecoder(test.createReader(resp.Body)).Decode(&data) + err = Json.NewDecoder(test.createReader(resp.Body)).Decode(&data) require.NoError(t, err) require.Equal(t, "hello", data["message"]) }) diff --git a/compressor_gzip_test.go b/compressor_gzip_test.go index a358e82..0d8dc96 100644 --- a/compressor_gzip_test.go +++ 
b/compressor_gzip_test.go @@ -112,7 +112,7 @@ func TestGzipCompressor(t *testing.T) { require.Equal(t, test.contentEncoding, resp.Header.Get("Content-Encoding")) data := make(map[string]string) - err = json.NewDecoder(test.createReader(resp.Body)).Decode(&data) + err = Json.NewDecoder(test.createReader(resp.Body)).Decode(&data) require.NoError(t, err) require.Equal(t, "hello", data["message"]) }) diff --git a/context.go b/context.go index ba20e24..d52e928 100644 --- a/context.go +++ b/context.go @@ -3,12 +3,6 @@ package xun import ( "net/http" "strings" - - jsoniter "github.com/json-iterator/go" -) - -var ( - json = jsoniter.Config{UseNumber: false}.Froze() ) type TempData map[string]any diff --git a/context_test.go b/context_test.go index a0b964d..1992b2d 100644 --- a/context_test.go +++ b/context_test.go @@ -1,6 +1,7 @@ package xun import ( + "encoding/json" "io" "net/http" "net/http/httptest" @@ -71,7 +72,7 @@ func TestTempData(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) var v string - err = json.NewDecoder(resp.Body).Decode(&v) + err = Json.NewDecoder(resp.Body).Decode(&v) require.NoError(t, err) require.Equal(t, "middleware", v) } diff --git a/ext/form/binder.go b/ext/form/binder.go index 1b21265..3beb955 100644 --- a/ext/form/binder.go +++ b/ext/form/binder.go @@ -5,11 +5,11 @@ import ( "github.com/go-playground/form/v4" "github.com/go-playground/validator/v10" - jsoniter "github.com/json-iterator/go" + "github.com/yaitoo/xun" ) var ( - json = jsoniter.Config{UseNumber: false}.Froze() + Json = xun.Json // use a single instance of Decoder, it caches struct info formDecoder = form.NewDecoder() @@ -61,7 +61,7 @@ func BindForm[T any](req *http.Request) (*TEntity[T], error) { func BindJson[T any](req *http.Request) (*TEntity[T], error) { data := new(T) - err := json.NewDecoder(req.Body).Decode(data) + err := Json.NewDecoder(req.Body).Decode(data) if err != nil { return nil, err } diff --git a/ext/form/binder_test.go b/ext/form/binder_test.go 
index a53aad6..5322a4e 100644 --- a/ext/form/binder_test.go +++ b/ext/form/binder_test.go @@ -3,6 +3,7 @@ package form import ( "bytes" "crypto/tls" + "encoding/json" "net/http" "net/http/httptest" @@ -135,7 +136,7 @@ func TestBinder(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) - err = json.NewDecoder(resp.Body).Decode(&result) + err = Json.NewDecoder(resp.Body).Decode(&result) require.NoError(t, err) resp.Body.Close() require.Equal(t, "xun@yaitoo.cn", result.Data.Email) @@ -148,7 +149,7 @@ func TestBinder(t *testing.T) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) - err = json.NewDecoder(resp.Body).Decode(&result) + err = Json.NewDecoder(resp.Body).Decode(&result) require.NoError(t, err) resp.Body.Close() require.Equal(t, "xun@yaitoo.cn", result.Data.Email) @@ -163,7 +164,7 @@ func TestBinder(t *testing.T) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) - err = json.NewDecoder(resp.Body).Decode(&result) + err = Json.NewDecoder(resp.Body).Decode(&result) require.NoError(t, err) resp.Body.Close() require.Len(t, result.Errors, 2) @@ -178,7 +179,7 @@ func TestBinder(t *testing.T) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) - err = json.NewDecoder(resp.Body).Decode(&result) + err = Json.NewDecoder(resp.Body).Decode(&result) require.NoError(t, err) resp.Body.Close() require.Len(t, result.Errors, 2) diff --git a/ext/htmx/htmx.go b/ext/htmx/htmx.go index 0916c0a..727a388 100644 --- a/ext/htmx/htmx.go +++ b/ext/htmx/htmx.go @@ -7,7 +7,6 @@ import ( "net/http" "time" - jsoniter "github.com/json-iterator/go" "github.com/yaitoo/xun" ) @@ -65,7 +64,7 @@ const ( ) var ( - json = jsoniter.Config{UseNumber: false}.Froze() + Json = xun.Json ) // HxHeader represents a map of string keys to values of any type. 
@@ -82,8 +81,11 @@ func WriteHeader(c *xun.Context, key string, value any) { return } - buf, _ := json.Marshal(value) - c.WriteHeader(key, string(buf)) + buf := bytes.NewBuffer(nil) + + Json.NewEncoder(buf).Encode(value) // nolint: errcheck + + c.WriteHeader(key, buf.String()) } //go:embed htmx.js diff --git a/ext/htmx/htmx_test.go b/ext/htmx/htmx_test.go index 1b99cd2..564c0ac 100644 --- a/ext/htmx/htmx_test.go +++ b/ext/htmx/htmx_test.go @@ -2,6 +2,7 @@ package htmx import ( "bytes" + "encoding/json" "html/template" "io" "net/http" diff --git a/ext/htmx/interceptor_test.go b/ext/htmx/interceptor_test.go index 535aba4..168aef3 100644 --- a/ext/htmx/interceptor_test.go +++ b/ext/htmx/interceptor_test.go @@ -62,7 +62,7 @@ func TestHtmxInterceptor(t *testing.T) { require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) - err = json.NewDecoder(resp.Body).Decode(&referer) + err = Json.NewDecoder(resp.Body).Decode(&referer) require.NoError(t, err) require.Equal(t, "/home", referer) @@ -76,7 +76,7 @@ func TestHtmxInterceptor(t *testing.T) { require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) - err = json.NewDecoder(resp.Body).Decode(&referer) + err = Json.NewDecoder(resp.Body).Decode(&referer) require.NoError(t, err) require.Equal(t, "/home", referer) @@ -89,7 +89,7 @@ func TestHtmxInterceptor(t *testing.T) { require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) - err = json.NewDecoder(resp.Body).Decode(&referer) + err = Json.NewDecoder(resp.Body).Decode(&referer) require.NoError(t, err) require.Empty(t, referer) // empty referer @@ -103,7 +103,7 @@ func TestHtmxInterceptor(t *testing.T) { require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) - err = json.NewDecoder(resp.Body).Decode(&referer) + err = Json.NewDecoder(resp.Body).Decode(&referer) require.NoError(t, err) require.Empty(t, referer) diff --git a/go.mod b/go.mod index bdb0ff0..cd97014 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ 
require ( github.com/go-playground/locales v0.14.1 github.com/go-playground/universal-translator v0.18.1 github.com/go-playground/validator/v10 v10.25.0 - github.com/json-iterator/go v1.1.12 github.com/stretchr/testify v1.10.0 golang.org/x/crypto v0.33.0 ) @@ -16,8 +15,6 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/gabriel-vasile/mimetype v1.4.8 // indirect github.com/leodido/go-urn v1.4.0 // indirect - github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect - github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect golang.org/x/net v0.35.0 // indirect golang.org/x/sys v0.30.0 // indirect diff --git a/go.sum b/go.sum index ceddbff..37f36c3 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= @@ -15,21 +13,10 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.25.0 h1:5Dh7cjvzR7BRZadnsVOzPhWsrwUr0nmsZJxEAnFLNO8= github.com/go-playground/validator/v10 v10.25.0/go.mod h1:GGzBIJMuE98Ic/kJsBXbz1x/7cByt++cQ+YOuDM5wus= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= -github.com/json-iterator/go v1.1.12/go.mod 
h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= -github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= diff --git a/json.go b/json.go new file mode 100644 index 0000000..4b42c64 --- /dev/null +++ b/json.go @@ -0,0 +1,34 @@ +package xun + +import ( + "encoding/json" + "io" +) + +var Json JsonEncoding = &stdJsonEncoding{} + +// JsonEncoding is the interface that defines the methods that the standard +// library encoding/json package provides. 
+type JsonEncoding interface { + NewEncoder(writer io.Writer) Encoder + NewDecoder(reader io.Reader) Decoder +} + +type Decoder interface { + Decode(obj interface{}) error +} + +type Encoder interface { + Encode(val interface{}) error +} + +type stdJsonEncoding struct { +} + +func (*stdJsonEncoding) NewEncoder(writer io.Writer) Encoder { + return json.NewEncoder(writer) +} + +func (*stdJsonEncoding) NewDecoder(reader io.Reader) Decoder { + return json.NewDecoder(reader) +} diff --git a/routing_option_test.go b/routing_option_test.go index 4edd636..ec17065 100644 --- a/routing_option_test.go +++ b/routing_option_test.go @@ -1,6 +1,7 @@ package xun import ( + "encoding/json" "encoding/xml" "io" "net/http" diff --git a/viewengine_text_test.go b/viewengine_text_test.go index d6b2096..015ef58 100644 --- a/viewengine_text_test.go +++ b/viewengine_text_test.go @@ -112,7 +112,7 @@ func TestWatchOnText(t *testing.T) { require.NoError(t, err) var content string - err = json.NewDecoder(resp.Body).Decode(&content) + err = Json.NewDecoder(resp.Body).Decode(&content) require.NoError(t, err) require.Empty(t, content) resp.Body.Close() diff --git a/viewer_json.go b/viewer_json.go index 84da626..8c42935 100644 --- a/viewer_json.go +++ b/viewer_json.go @@ -29,7 +29,7 @@ func (*JsonViewer) Render(ctx *Context, data any) error { // skipcq: RVV-B0012 buf := BufPool.Get() defer BufPool.Put(buf) - err = json.NewEncoder(buf).Encode(data) + err = Json.NewEncoder(buf).Encode(data) if err != nil { return err } diff --git a/viewer_json_test.go b/viewer_json_test.go index 9c8a68a..2a97192 100644 --- a/viewer_json_test.go +++ b/viewer_json_test.go @@ -25,7 +25,7 @@ func TestJsonViewerRenderError(t *testing.T) { // should get raw error when json.marshal fails, and StatusCode should be written err := v.Render(ctx, data) require.Error(t, err) - require.Equal(t, "chan int is unsupported type", err.Error()) + require.Equal(t, "json: unsupported type: chan int", err.Error()) require.Equal(t, -1, 
rw.Code) From aba6a8324fcea5d7bc4412f3e46405b9660e9d9d Mon Sep 17 00:00:00 2001 From: Lz Date: Sat, 22 Feb 2025 07:20:52 +0800 Subject: [PATCH 4/8] feat(acl): added `Whitelist` to allowing specific paths to bypass host checking (#55) --- ext/acl/acl.go | 20 +++++++++++++++++++- ext/acl/acl_test.go | 23 +++++++++++++++++++++++ ext/acl/config.go | 8 +++++++- ext/acl/config_test.go | 15 +++++++++++++++ ext/acl/option.go | 8 ++++++++ 5 files changed, 72 insertions(+), 2 deletions(-) diff --git a/ext/acl/acl.go b/ext/acl/acl.go index 5dc9f71..0e8431d 100644 --- a/ext/acl/acl.go +++ b/ext/acl/acl.go @@ -1,3 +1,10 @@ +// This package provides Access Control List (ACL) middleware for the Xun framework. +// It allows for configuring allowed and denied hosts, IP networks, and countries +// based on configuration files. The middleware supports dynamic reloading of rules +// when the configuration file changes, enabling real-time updates to access rules. +// It also offers functionality for host redirection and integrates with the Xun +// framework's context to apply these rules to incoming requests. + package acl import ( @@ -14,7 +21,9 @@ var ( v atomic.Value ) -func New(opts ...Option) xun.Middleware { +// New returns a new ACL middleware that applies access rules based on the provided options. +// It dynamically reloads rules if a configuration file is specified, enabling real-time updates. 
+func New(opts ...Option) xun.Middleware { // skipcq: GO-R1005 options := NewOptions() for _, opt := range opts { @@ -44,6 +53,15 @@ func New(opts ...Option) xun.Middleware { o := v.Load().(*Options) if len(o.AllowHosts) > 0 { _, allow := o.AllowHosts[m.Host] + if !allow { + for _, it := range o.HostWhitelist { + if strings.EqualFold(c.Request.URL.Path, it) { + allow = true + break + } + } + } + if !allow { if o.HostRedirectStatusCode > 0 && o.HostRedirectURL != "" { return redirect(c, o) diff --git a/ext/acl/acl_test.go b/ext/acl/acl_test.go index 7d2e156..7fd9599 100644 --- a/ext/acl/acl_test.go +++ b/ext/acl/acl_test.go @@ -49,6 +49,29 @@ func TestHosts(t *testing.T) { require.Equal(t, "http://127.0.0.2", w.Header().Get("Location")) }) + t.Run("host_whitelist", func(t *testing.T) { + m := New(AllowHosts("abc.com"), WithHostWhitelist("/status", "/Ping")) + + ctx := createContext(nil) + + ctx.Request = httptest.NewRequest(http.MethodGet, "http://123.com/status", nil) + err := m(nop)(ctx) + require.NoError(t, err) + + ctx.Request = httptest.NewRequest(http.MethodGet, "http://123.com/ping", nil) + err = m(nop)(ctx) + require.NoError(t, err) + + ctx.Request = httptest.NewRequest(http.MethodGet, "http://123.com/home", nil) + err = m(nop)(ctx) + require.ErrorIs(t, err, xun.ErrCancelled) + + ctx.Request = httptest.NewRequest(http.MethodGet, "http://abc.com/home", nil) + err = m(nop)(ctx) + require.NoError(t, err) + + }) + t.Run("redirect_with_invalid_url", func(t *testing.T) { m := New(AllowHosts("abc.com"), WithHostRedirect("", http.StatusFound)) diff --git a/ext/acl/config.go b/ext/acl/config.go index ccf0814..c9d31a6 100644 --- a/ext/acl/config.go +++ b/ext/acl/config.go @@ -19,6 +19,7 @@ const ( SectionDN // deny_ipnets SectionAC // allow_countries SectionDC // deny_countries + SectionWL // host_whitelist ) var openFile = func(file string) (fs.File, error) { @@ -26,7 +27,7 @@ var openFile = func(file string) (fs.File, error) { return os.OpenFile(file, os.O_RDONLY, 
0600) } -func loadOptions(file string, o *Options) bool { +func loadOptions(file string, o *Options) bool { // skipcq: GO-R1005 f, err := openFile(file) if err != nil { Logger.Println("acl: can't read file", file, err) @@ -65,6 +66,9 @@ func loadOptions(file string, o *Options) bool { case "[deny_countries]": section = SectionDC continue + case "[host_whitelist]": + section = SectionWL + continue } switch section { @@ -78,6 +82,8 @@ func loadOptions(file string, o *Options) bool { AllowCountries(l)(o) case SectionDC: DenyCountries(l)(o) + case SectionWL: + WithHostWhitelist(l)(o) } } diff --git a/ext/acl/config_test.go b/ext/acl/config_test.go index 46143f3..bc53f60 100644 --- a/ext/acl/config_test.go +++ b/ext/acl/config_test.go @@ -20,9 +20,15 @@ func TestConfig(t *testing.T) { Data: []byte(` [allow_hosts] abc.com + +[host_whitelist] +/allow +/admin + [host_redirect] url=http://abc.com status_code=301 + [allow_ipnets] 172.0.0.0/24 192.0.0.1 @@ -72,6 +78,7 @@ us require.Len(t, o.AllowHosts, 0) require.Empty(t, o.HostRedirectURL) require.Equal(t, 302, o.HostRedirectStatusCode) + require.Len(t, o.HostWhitelist, 0) require.Len(t, o.AllowIPNets, 0) require.Len(t, o.DenyIPNets, 0) @@ -94,6 +101,9 @@ us require.Len(t, o.AllowHosts, 1) require.Equal(t, "http://abc.com", o.HostRedirectURL) require.Equal(t, 301, o.HostRedirectStatusCode) + require.Len(t, o.HostWhitelist, 2) + require.Equal(t, "/allow", o.HostWhitelist[0]) + require.Equal(t, "/admin", o.HostWhitelist[1]) require.Len(t, o.AllowIPNets, 2) require.Equal(t, ParseIPNet("172.0.0.0/24"), o.AllowIPNets[0]) @@ -128,6 +138,9 @@ cn [host_redirect] url=http://123.com status_code=302 + +[host_whitelist] +/status `) mu.Lock() @@ -144,6 +157,8 @@ status_code=302 require.Len(t, o.AllowHosts, 1) require.Equal(t, "http://123.com", o.HostRedirectURL) require.Equal(t, 302, o.HostRedirectStatusCode) + require.Len(t, o.HostWhitelist, 1) + require.Equal(t, "/status", o.HostWhitelist[0]) require.Len(t, o.AllowIPNets, 0) diff --git 
a/ext/acl/option.go b/ext/acl/option.go index cd7b1bb..6cecbff 100644 --- a/ext/acl/option.go +++ b/ext/acl/option.go @@ -20,6 +20,7 @@ type Options struct { HostRedirectURL string HostRedirectStatusCode int + HostWhitelist []string AllowIPNets []*net.IPNet DenyIPNets []*net.IPNet @@ -154,6 +155,13 @@ func WithHostRedirect(u string, code int) Option { } } +// WithHostWhitelist sets the whitelist for AllowHosts,allowing specific paths to bypass host checking. +func WithHostWhitelist(paths ...string) Option { + return func(o *Options) { + o.HostWhitelist = append(o.HostWhitelist, paths...) + } +} + // CountryRule represents a rule for managing country-based access control. // It includes a map of country codes and a flag indicating if any country is allowed. type CountryRule struct { From 8916c23f7ddfc8c32a8add241175e1d1ba8ccada Mon Sep 17 00:00:00 2001 From: Lz Date: Sat, 22 Feb 2025 23:25:12 +0800 Subject: [PATCH 5/8] fix(htmx): added `fetch` wrapper to support `Hx-Trigger` feature like htmx request (#56) --- README.md | 4 +-- ext/htmx/htmx.js | 89 +++++++++++++++++++++++++++++++++--------------- 2 files changed, 64 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 412a614..d4e299b 100644 --- a/README.md +++ b/README.md @@ -993,9 +993,9 @@ The library to enable seamless integration between native JavaScript methods and #### 5. Setup Hx-Trigger listener > app.js ```js -xun.ready(function(evt) { +$x.ready(function(evt) { document.addEventListener("showMessage", function(evt){ - alert(evt.detail.value); + alert(evt.detail); }) },'body'); diff --git a/ext/htmx/htmx.js b/ext/htmx/htmx.js index cded7a4..068243d 100644 --- a/ext/htmx/htmx.js +++ b/ext/htmx/htmx.js @@ -1,37 +1,72 @@ -window.xun = window.xun || { - /** - * A global object to manage custom events and callbacks. - * - * @property {Function} ready - Registers a callback function to be executed once when - * the DOM is fully loaded or when an `htmx:load` event occurs. 
- * - * @function ready - * @param {Function} callback - The callback function to be executed. - * @param {String} selector - The selector to be used to check if the callback should be executed. - */ - ready:function(callback,selector){ - const f = function(evt){ - if(selector){ - if(document.querySelector(selector)){ +(function () { + window.$x = window.$x || { + /** + * A global object to manage custom events and callbacks. + * + * @property {Function} ready - Registers a callback function to be executed + * once when the DOM is fully loaded or when an `htmx:load` event occurs. + * + * @function ready + * @param {Function} callback - The callback function to be executed. + * @param {String} selector - The selector to be used to check if the callback + * should be executed. + */ + ready: function (callback, selector) { + const f = function (evt) { + if (selector) { + if (document.querySelector(selector)) { + callback(evt); + } + } else { callback(evt); } - }else{ - callback(evt); - } - } + }; let boosted = false; - document.addEventListener('DOMContentLoaded',function(evt){ + document.addEventListener("DOMContentLoaded", function (evt) { f(evt); }); - document.addEventListener('htmx:load', function(evt) { - if(boosted){ + document.addEventListener("htmx:load", function (evt) { + if (boosted) { f(evt); - boosted = false; + boosted = false; } - }); - document.addEventListener('htmx:beforeOnLoad', function(evt) { + }); + document.addEventListener("htmx:beforeOnLoad", function (evt) { // trigger ready function again when a boosted request is done boosted = evt.detail.boosted; }); - } - } \ No newline at end of file + }, + /** + * The fetch function is a wrapper of native fetch with Hx-Trigger support + * like it in htmx requests. + * + * @function fetch + * @async + * @param {String|Request} input - The URL to be requested or the Request + * object. + * @param {Object} init - The options to be used for the request. 
See the + * {@link + * https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch|fetch} + * API. + * @returns {Promise} - The response of the request. + */ + fetch: async (...args) => { + const response = await fetch(...args); + if (!response.ok) { + const hx = response.headers.get("Hx-Trigger"); + if (hx) { + try{ + const d = JSON.parse(hx); + const keys = Object.keys(d); + for (const key of keys) { + window.dispatchEvent(new CustomEvent(key, { detail: d[key] })); + } + }catch(e){ + // prevent invalid JSON from breaking the application + } + } + } + return response; + }, + }; +})(); From 573f3f1aeb1dbed77c2d3f4365b12f339aa19189 Mon Sep 17 00:00:00 2001 From: Lz Date: Sat, 22 Feb 2025 23:54:46 +0800 Subject: [PATCH 6/8] chore(docs): updated title --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d4e299b..79be0ca 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # Xun -Xun is an HTTP web framework built on Go's built-in html/template and net/http package’s router. +Xun is a web framework built on Go's built-in html/template and net/http package’s router. It is designed to be lightweight, fast, and easy to use. Xun provides a simple and intuitive API for building web applications, while also offering advanced features such as middleware, routing, and template rendering. Xun [ʃʊn] (pronounced 'shoon'), derived from the Chinese character 迅, signifies being lightweight and fast. 
From ac40b61e3e1d5b5341e54146c5a1065d1e3568c9 Mon Sep 17 00:00:00 2001 From: Lz Date: Wed, 26 Feb 2025 23:51:09 +0800 Subject: [PATCH 7/8] feat(sse): added ext/sse (#59) --- .github/workflows/tests.yml | 6 +- ext/sse/README.md | 0 ext/sse/client.go | 50 +++++++++++++++++ ext/sse/event.go | 8 +++ ext/sse/server.go | 86 ++++++++++++++++++++++++++++ ext/sse/server_test.go | 108 ++++++++++++++++++++++++++++++++++++ ext/sse/streamer.go | 10 ++++ go.mod | 7 ++- go.sum | 6 +- 9 files changed, 274 insertions(+), 7 deletions(-) create mode 100644 ext/sse/README.md create mode 100644 ext/sse/client.go create mode 100644 ext/sse/event.go create mode 100644 ext/sse/server.go create mode 100644 ext/sse/server_test.go create mode 100644 ext/sse/streamer.go diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4d152ee..ea30f36 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,11 +23,11 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: ^1.21 + go-version: 1.24.0 - name: golangci-lint uses: golangci/golangci-lint-action@v6 with: - version: v1.61 + version: v1.64.5 unit-tests: name: Unit Tests runs-on: ubuntu-latest @@ -37,7 +37,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: ^1.21 + go-version: 1.24.0 - name: Unit Tests run: | make unit-tests diff --git a/ext/sse/README.md b/ext/sse/README.md new file mode 100644 index 0000000..e69de29 diff --git a/ext/sse/client.go b/ext/sse/client.go new file mode 100644 index 0000000..371c3e5 --- /dev/null +++ b/ext/sse/client.go @@ -0,0 +1,50 @@ +package sse + +import ( + "context" + "encoding/json" + "errors" + "fmt" +) + +var ErrClientClosed = errors.New("sse: client closed") + +// Client represents a WebSocket client that handles HTTP responses and supports +// flushing data to the client. It contains a response writer, a flusher for +// sending data immediately, and a channel for managing the client's lifecycle. 
+type Client struct { + s Streamer + + ctx context.Context +} + +// Connect establishes a connection for the Client using the provided Streamer. +// It assigns the Streamer to the Client's rw field and ensures that it implements +// the http.Flusher interface for flushing data. +func (c *Client) Connect(ctx context.Context, s Streamer) { + c.s = s + c.ctx = ctx +} + +// Send sends an event to the client by writing the event name and data to the response writer. +// It marshals the event data into JSON format and flushes the output to ensure the data is sent immediately. +// This method is part of the Client struct and is intended for use in server-sent events (SSE) communication. +func (c *Client) Send(event Event) error { + select { + case <-c.ctx.Done(): + return ErrClientClosed + default: + buf, err := json.Marshal(event.Data) + if err != nil { + return err + } + _, err = fmt.Fprintf(c.s, "event: %s\ndata: %s\n\n", event.Name, string(buf)) + if err != nil { + return err + } + + c.s.Flush() + } + + return nil +} diff --git a/ext/sse/event.go b/ext/sse/event.go new file mode 100644 index 0000000..5661a81 --- /dev/null +++ b/ext/sse/event.go @@ -0,0 +1,8 @@ +package sse + +// Event represents a server-sent event with a name and associated data. +// It can be used to transmit information from the server to the client in real-time. +type Event struct { + Name string + Data any +} diff --git a/ext/sse/server.go b/ext/sse/server.go new file mode 100644 index 0000000..2ed3e59 --- /dev/null +++ b/ext/sse/server.go @@ -0,0 +1,86 @@ +// Package sse provides a server implementation for Server-Sent Events (SSE). +// SSE is a technology enabling a client to receive automatic updates from a server via HTTP connection. +package sse + +import ( + "context" + "sync" + + "github.com/yaitoo/async" +) + +// Server represents a structure that manages connected clients +// in a concurrent environment. 
It uses a read-write mutex to +// ensure safe access to the clients map, which holds the +// active Client instances identified by their unique keys. +type Server struct { + sync.RWMutex + clients map[string]*Client +} + +// New creates and returns a new instance of the Server struct. +func New() *Server { + return &Server{ + clients: make(map[string]*Client), + } +} + +// Join adds a new client to the server or retrieves an existing one based on the provided ID. +// It establishes a connection with the specified Streamer and sets the appropriate headers +// for Server-Sent Events (SSE). If a client with the given ID already exists, it reuses that client. +func (s *Server) Join(ctx context.Context, id string, sm Streamer) *Client { + s.Lock() + defer s.Unlock() + c, ok := s.clients[id] + + if !ok { + c = &Client{} + s.clients[id] = c + } + + c.Connect(ctx, sm) + + sm.Header().Set("Content-Type", "text/event-stream") + sm.Header().Set("Cache-Control", "no-cache") + sm.Header().Set("Connection", "keep-alive") + + return c +} + +// Leave removes a client from the server's client list by its ID. +// This method is safe for concurrent use, as it locks the server +// before modifying the clients map and ensures that the lock is +// released afterward. +func (s *Server) Leave(id string) { + s.Lock() + defer s.Unlock() + + delete(s.clients, id) +} + +// Get retrieves the Client associated with the given id from the Server. +// It uses a read lock to ensure thread-safe access to the clients map. +// Returns nil if no Client is found for the specified id. +func (s *Server) Get(id string) *Client { + s.RLock() + defer s.RUnlock() + return s.clients[id] +} + +// Broadcast sends the specified event to all connected clients. +// It acquires a read lock to ensure thread-safe access to the clients slice, +// and spawns a goroutine for each client to handle the sending of the event. 
+func (s *Server) Broadcast(ctx context.Context, event Event) ([]error, error) { + s.RLock() + defer s.RUnlock() + + task := async.NewA() + + for _, c := range s.clients { + task.Add(func(ctx context.Context) error { + return c.Send(event) + }) + } + + return task.Wait(ctx) +} diff --git a/ext/sse/server_test.go b/ext/sse/server_test.go new file mode 100644 index 0000000..58af42d --- /dev/null +++ b/ext/sse/server_test.go @@ -0,0 +1,108 @@ +package sse + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestServer(t *testing.T) { + t.Run("join", func(t *testing.T) { + srv := New() + rw := httptest.NewRecorder() + + c1 := srv.Join(context.TODO(), "join", rw) + + c2 := srv.Join(context.TODO(), "join", rw) + + require.Equal(t, c1, c2) + + c3 := srv.Get("join") + + require.Equal(t, c1, c3) + + srv.Leave("join") + + c4 := srv.Get("join") + require.Nil(t, c4) + + }) + + t.Run("send", func(t *testing.T) { + srv := New() + rw := httptest.NewRecorder() + + c := srv.Join(context.TODO(), "send", rw) + + err := c.Send(Event{Name: "event1", Data: "data1"}) + require.NoError(t, err) + buf := rw.Body.Bytes() + require.Equal(t, "event: event1\ndata: \"data1\"\n\n", string(buf)) + + err = c.Send(Event{Name: "event2", Data: "data2"}) + require.NoError(t, err) + buf = rw.Body.Bytes() + require.Equal(t, "event: event1\ndata: \"data1\"\n\nevent: event2\ndata: \"data2\"\n\n", string(buf)) + }) + + t.Run("broadcast", func(t *testing.T) { + srv := New() + + rw1 := httptest.NewRecorder() + rw2 := httptest.NewRecorder() + + c1 := srv.Join(context.TODO(), "c1", rw1) + require.NotNil(t, c1) + + c2 := srv.Join(context.TODO(), "c2", rw2) + require.NotNil(t, c2) + + errs, err := srv.Broadcast(context.TODO(), Event{Name: "event1", Data: "data1"}) + require.NoError(t, err) + require.Nil(t, errs) + + buf1 := rw1.Body.Bytes() + buf2 := rw2.Body.Bytes() + + require.Equal(t, buf1, buf2) + require.Equal(t, "event: 
event1\ndata: \"data1\"\n\n", string(buf1)) + }) + + t.Run("invalid", func(t *testing.T) { + srv := New() + + rw := &streamerMock{ + ResponseWriter: httptest.NewRecorder(), + } + + ctx, cancel := context.WithCancel(context.TODO()) + + c := srv.Join(ctx, "invalid", rw) + + err := c.Send(Event{Name: "event1", Data: make(chan int)}) + require.Error(t, err) + + err = c.Send(Event{Name: "event1"}) + require.Error(t, err) + + cancel() + + err = c.Send(Event{Name: "event1"}) + require.ErrorIs(t, err, ErrClientClosed) + + }) +} + +type streamerMock struct { + http.ResponseWriter +} + +func (*streamerMock) Write([]byte) (int, error) { + return 0, errors.New("mock: invalid") +} + +func (*streamerMock) Flush() {} diff --git a/ext/sse/streamer.go b/ext/sse/streamer.go new file mode 100644 index 0000000..703968c --- /dev/null +++ b/ext/sse/streamer.go @@ -0,0 +1,10 @@ +package sse + +import ( + "net/http" +) + +type Streamer interface { + http.ResponseWriter + http.Flusher +} diff --git a/go.mod b/go.mod index cd97014..dda81e5 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module github.com/yaitoo/xun -go 1.22 +go 1.23.0 + +toolchain go1.24.0 require ( github.com/go-playground/form/v4 v4.2.1 @@ -8,7 +10,8 @@ require ( github.com/go-playground/universal-translator v0.18.1 github.com/go-playground/validator/v10 v10.25.0 github.com/stretchr/testify v1.10.0 - golang.org/x/crypto v0.33.0 + github.com/yaitoo/async v1.0.4 + golang.org/x/crypto v0.35.0 ) require ( diff --git a/go.sum b/go.sum index 37f36c3..7d982e4 100644 --- a/go.sum +++ b/go.sum @@ -19,8 +19,10 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/crypto v0.33.0 
h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= -golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +github.com/yaitoo/async v1.0.4 h1:u+SWuJcSckgBOcMjMYz9IviojeCatDrdni3YNGLCiHY= +github.com/yaitoo/async v1.0.4/go.mod h1:IpSO7Ei7AxiqLxFqDjN4rJaVlt8wm4ZxMXyyQaWmM1g= +golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= +golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= From 990a2fa21309ec623552a88ca49c589648ec5a8c Mon Sep 17 00:00:00 2001 From: Lz Date: Thu, 27 Feb 2025 15:56:18 +0800 Subject: [PATCH 8/8] fix(sse): added JsonEvent/TextEvent and updated Join with http.ResponseWriter (#60) --- PULL_REQUEST_TEMPLATE.md | 2 +- README.md | 63 +++++++++++++++++++++++++- ext/sse/README.md | 0 ext/sse/client.go | 44 +++++++++++------- ext/sse/error.go | 17 +++++++ ext/sse/event.go | 48 ++++++++++++++++++-- ext/sse/server.go | 41 ++++++++++++++--- ext/sse/server_test.go | 98 ++++++++++++++++++++++++++++++++++------ ext/sse/streamer.go | 23 ++++++++++ response_writer_std.go | 23 +++++++++- 10 files changed, 314 insertions(+), 45 deletions(-) delete mode 100644 ext/sse/README.md create mode 100644 ext/sse/error.go diff --git a/PULL_REQUEST_TEMPLATE.md b/PULL_REQUEST_TEMPLATE.md index 685d699..aaa0e6d 100644 --- a/PULL_REQUEST_TEMPLATE.md +++ b/PULL_REQUEST_TEMPLATE.md @@ -9,6 +9,6 @@ ### Tests Tasks to complete before merging PR: -- [ ] Ensure unit tests are passing. If not run `make unit-test` to check for any regressions :clipboard: +- [ ] Ensure unit tests are passing. If not run `make unit-tests` to check for any regressions :clipboard: - [ ] Ensure lint tests are passing. 
if not run `make lint` to check for any issues - [ ] Ensure codecov/patch is passing for changes. \ No newline at end of file diff --git a/README.md b/README.md index 79be0ca..ade2da5 100644 --- a/README.md +++ b/README.md @@ -696,7 +696,7 @@ func main(){ ``` -#### Access Control List +#### Access Control List ([ACL](./ext/acl/)) The ACL filters and monitors HTTP traffic through granular rule sets, designed to protect web applications/APIs from malicious bots, exploit attempts, and unauthorized access. ##### Core Filtering Dimensions @@ -787,6 +787,67 @@ status_code=302 app.Use(acl.New(acl.WithConfig("./acl.ini"))) ``` +#### Server-Sent Events ([SSE](./ext/sse/)) +Server-Sent Events (SSE) is a server push technology enabling a client to receive automatic updates from a server via an HTTP connection. + +> use the `sse` extension to handle SSE requests +```go +ss := sse.New() + +app.Get("/chatroom/{id}", func(c *xun.Context) error { + id := c.Request.PathValue("id") + room, err := ss.Join(c.Request.Context(), id, c.Response) + if err != nil { + c.WriteStatus(http.StatusBadRequest) + return xun.ErrCancelled + } + + room.Wait() + + ss.Leave(id) + + return nil +}) + +``` + +> push an event to the chatroom +```go +r := ss.Get("room_id") +if r != nil { + r.Send(&sse.TextEvent{ + Name:"showMessage", + Data:"Hello", + }) +} +``` + +> broadcast an event to all chatrooms +```go +ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) +defer cancel() +ss.Broadcast(ctx, &sse.TextEvent{ + Name:"shutdown", + Data:"Server is shutting down", +}) +``` + +> shut down the server and close all chatrooms +```go +ss.Shutdown() +``` + +> use the [htmx-ext-sse](https://htmx.org/extensions/sse/) extension to connect to the SSE endpoint and receive events +```html + + + +
+... +
+``` + ### Deploy your application Leveraging Go's built-in `//go:embed` directive and the standard library's `fs.FS` interface, we can compile all static assets and configuration files into a single self-contained binary. This dependency-free approach enables seamless deployment to any server environment. diff --git a/ext/sse/README.md b/ext/sse/README.md deleted file mode 100644 index e69de29..0000000 diff --git a/ext/sse/client.go b/ext/sse/client.go index 371c3e5..e60e790 100644 --- a/ext/sse/client.go +++ b/ext/sse/client.go @@ -2,20 +2,21 @@ package sse import ( "context" - "encoding/json" "errors" - "fmt" ) -var ErrClientClosed = errors.New("sse: client closed") +var ( + ErrClientClosed = errors.New("sse: client closed") +) -// Client represents a WebSocket client that handles HTTP responses and supports -// flushing data to the client. It contains a response writer, a flusher for -// sending data immediately, and a channel for managing the client's lifecycle. +// Client represents a single server-sent events (SSE) connection. +// It holds the client's ID, a Streamer instance for writing to the stream, +// a context derived in Connect, and the cancel function Close uses to end it. type Client struct { - s Streamer - - ctx context.Context + ID string + s Streamer + ctx context.Context + cancel context.CancelFunc } // Connect establishes a connection for the Client using the provided Streamer. @@ -23,7 +24,7 @@ func (c *Client) Connect(ctx context.Context, s Streamer) { c.s = s - c.ctx = ctx + c.ctx, c.cancel = context.WithCancel(ctx) } // Send sends an event to the client by writing the event name and data to the response writer.
@@ -32,15 +33,11 @@ func (c *Client) Connect(ctx context.Context, s Streamer) { func (c *Client) Send(event Event) error { select { case <-c.ctx.Done(): - return ErrClientClosed + return NewError(c.ID, ErrClientClosed) default: - buf, err := json.Marshal(event.Data) - if err != nil { - return err - } - _, err = fmt.Fprintf(c.s, "event: %s\ndata: %s\n\n", event.Name, string(buf)) + err := event.Write(c.s) if err != nil { - return err + return NewError(c.ID, err) } c.s.Flush() @@ -48,3 +45,16 @@ func (c *Client) Send(event Event) error { return nil } + +// Wait blocks until the context is done or the client is closed. +// It listens for either the cancellation of the context or a signal +// to close the client, allowing for graceful shutdown. +func (c *Client) Wait() { + <-c.ctx.Done() +} + +// Close gracefully shuts down the Client by sending a signal to the close channel. +// This method should be called to ensure that any ongoing operations are properly terminated. +func (c *Client) Close() { + c.cancel() +} diff --git a/ext/sse/error.go b/ext/sse/error.go new file mode 100644 index 0000000..66ccdd1 --- /dev/null +++ b/ext/sse/error.go @@ -0,0 +1,17 @@ +package sse + +type Error struct { + ClientID string + error +} + +func (e *Error) Unwrap() error { + return e.error +} + +func NewError(clientID string, err error) *Error { + return &Error{ + ClientID: clientID, + error: err, + } +} diff --git a/ext/sse/event.go b/ext/sse/event.go index 5661a81..8bca6a4 100644 --- a/ext/sse/event.go +++ b/ext/sse/event.go @@ -1,8 +1,50 @@ package sse -// Event represents a server-sent event with a name and associated data. -// It can be used to transmit information from the server to the client in real-time. -type Event struct { +import ( + "encoding/json" + "fmt" + "io" +) + +// Event represents an interface for writing event data to an io.Writer. 
+// Implementations of this interface must provide the Write method, +// which takes an io.Writer and returns an error if the write operation fails. +type Event interface { + Write(r io.Writer) error +} + +// TextEvent represents a simple event structure with a name and associated data. +// It is used to encapsulate information for events in the SSE (Server-Sent Events) protocol. +type TextEvent struct { + Name string + Data string +} + +// Write formats the TextEvent as a string and writes it to the provided io.Writer. +// It outputs the event name and data in the SSE format, followed by two newlines. +// Returns an error if the write operation fails. +func (e *TextEvent) Write(w io.Writer) error { + _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", e.Name, e.Data) + return err +} + +// JsonEvent represents an event with a name and associated data. +// It can be used to structure events in a JSON format in the SSE (Server-Sent Events) protocol. +type JsonEvent struct { Name string Data any } + +// Write serializes the JsonEvent to the provided io.Writer in the SSE format. +// It writes the event name and the JSON-encoded data, followed by a double newline +// to indicate the end of the event. If an error occurs during marshaling or writing, +// it returns the error. +func (e *JsonEvent) Write(w io.Writer) error { + buf, err := json.Marshal(e.Data) + if err != nil { + return err + } + _, err = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", e.Name, string(buf)) + + return err +} diff --git a/ext/sse/server.go b/ext/sse/server.go index 2ed3e59..b189b2b 100644 --- a/ext/sse/server.go +++ b/ext/sse/server.go @@ -4,6 +4,7 @@ package sse import ( "context" + "net/http" "sync" "github.com/yaitoo/async" @@ -28,13 +29,20 @@ func New() *Server { // Join adds a new client to the server or retrieves an existing one based on the provided ID. // It establishes a connection with the specified Streamer and sets the appropriate headers // for Server-Sent Events (SSE). 
If a client with the given ID already exists, it reuses that client. -func (s *Server) Join(ctx context.Context, id string, sm Streamer) *Client { +func (s *Server) Join(ctx context.Context, id string, rw http.ResponseWriter) (*Client, error) { + sm, err := NewStreamer(rw) + if err != nil { + return nil, err + } + s.Lock() defer s.Unlock() c, ok := s.clients[id] if !ok { - c = &Client{} + c = &Client{ + ID: id, + } s.clients[id] = c } @@ -44,7 +52,7 @@ func (s *Server) Join(ctx context.Context, id string, sm Streamer) *Client { sm.Header().Set("Cache-Control", "no-cache") sm.Header().Set("Connection", "keep-alive") - return c + return c, nil } // Leave removes a client from the server's client list by its ID. @@ -74,13 +82,32 @@ func (s *Server) Broadcast(ctx context.Context, event Event) ([]error, error) { s.RLock() defer s.RUnlock() - task := async.NewA() + tasks := async.NewA() for _, c := range s.clients { - task.Add(func(ctx context.Context) error { - return c.Send(event) + + tasks.Add(func(ctx context.Context) error { + if err := ctx.Err(); err != nil { + return NewError(c.ID, err) + } + if err := c.Send(event); err != nil { + return NewError(c.ID, err) + } + return nil }) } - return task.Wait(ctx) + return tasks.Wait(ctx) +} + +// Shutdown gracefully closes all active client connections and cleans up the client list. +// It locks the server to ensure thread safety during the shutdown process. 
+func (s *Server) Shutdown() { + s.Lock() + defer s.Unlock() + for _, c := range s.clients { + c.Close() + } + + s.clients = make(map[string]*Client) } diff --git a/ext/sse/server_test.go b/ext/sse/server_test.go index 58af42d..d00ab59 100644 --- a/ext/sse/server_test.go +++ b/ext/sse/server_test.go @@ -6,8 +6,10 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/require" + "github.com/yaitoo/async" ) func TestServer(t *testing.T) { @@ -15,16 +17,37 @@ func TestServer(t *testing.T) { srv := New() rw := httptest.NewRecorder() - c1 := srv.Join(context.TODO(), "join", rw) + c1, err := srv.Join(context.TODO(), "join", nil) + require.Nil(t, c1) + require.ErrorIs(t, err, ErrNotStreamer) - c2 := srv.Join(context.TODO(), "join", rw) + c1, err = srv.Join(context.TODO(), "join", ¬Streamer{}) + require.Nil(t, c1) + require.ErrorIs(t, err, ErrNotStreamer) + + c1, err = srv.Join(context.TODO(), "join", rw) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + c2, err := srv.Join(ctx, "join", rw) + require.NoError(t, err) require.Equal(t, c1, c2) + require.Equal(t, "join", c1.ID) c3 := srv.Get("join") require.Equal(t, c1, c3) + go func() { + time.Sleep(1 * time.Second) + c3.Close() + }() + + c2.Wait() + c3.Wait() + srv.Leave("join") c4 := srv.Get("join") @@ -36,17 +59,18 @@ func TestServer(t *testing.T) { srv := New() rw := httptest.NewRecorder() - c := srv.Join(context.TODO(), "send", rw) + c, err := srv.Join(context.TODO(), "send", rw) + require.NoError(t, err) - err := c.Send(Event{Name: "event1", Data: "data1"}) + err = c.Send(&TextEvent{Name: "event1", Data: "data1"}) require.NoError(t, err) buf := rw.Body.Bytes() - require.Equal(t, "event: event1\ndata: \"data1\"\n\n", string(buf)) + require.Equal(t, "event: event1\ndata: data1\n\n", string(buf)) - err = c.Send(Event{Name: "event2", Data: "data2"}) + err = c.Send(&JsonEvent{Name: "event2", Data: "data2"}) require.NoError(t, err) buf 
= rw.Body.Bytes() - require.Equal(t, "event: event1\ndata: \"data1\"\n\nevent: event2\ndata: \"data2\"\n\n", string(buf)) + require.Equal(t, "event: event1\ndata: data1\n\nevent: event2\ndata: \"data2\"\n\n", string(buf)) }) t.Run("broadcast", func(t *testing.T) { @@ -55,13 +79,15 @@ func TestServer(t *testing.T) { rw1 := httptest.NewRecorder() rw2 := httptest.NewRecorder() - c1 := srv.Join(context.TODO(), "c1", rw1) + c1, err := srv.Join(context.TODO(), "c1", rw1) require.NotNil(t, c1) + require.NoError(t, err) - c2 := srv.Join(context.TODO(), "c2", rw2) + c2, err := srv.Join(context.TODO(), "c2", rw2) require.NotNil(t, c2) + require.NoError(t, err) - errs, err := srv.Broadcast(context.TODO(), Event{Name: "event1", Data: "data1"}) + errs, err := srv.Broadcast(context.TODO(), &TextEvent{Name: "event1", Data: "data1"}) require.NoError(t, err) require.Nil(t, errs) @@ -69,7 +95,14 @@ func TestServer(t *testing.T) { buf2 := rw2.Body.Bytes() require.Equal(t, buf1, buf2) - require.Equal(t, "event: event1\ndata: \"data1\"\n\n", string(buf1)) + require.Equal(t, "event: event1\ndata: data1\n\n", string(buf1)) + + ctx, cancel := context.WithCancel(context.TODO()) + cancel() + + _, err = srv.Broadcast(ctx, &TextEvent{Name: "event1", Data: "data1"}) + require.ErrorIs(t, err, context.Canceled) + }) t.Run("invalid", func(t *testing.T) { @@ -81,22 +114,57 @@ func TestServer(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) - c := srv.Join(ctx, "invalid", rw) + c, err := srv.Join(ctx, "invalid", rw) + require.NoError(t, err) - err := c.Send(Event{Name: "event1", Data: make(chan int)}) + err = c.Send(&JsonEvent{Name: "event1", Data: make(chan int)}) require.Error(t, err) - err = c.Send(Event{Name: "event1"}) + err = c.Send(&TextEvent{Name: "event1"}) require.Error(t, err) cancel() - err = c.Send(Event{Name: "event1"}) + err = c.Send(&TextEvent{Name: "event1"}) require.ErrorIs(t, err, ErrClientClosed) + errs, err := srv.Broadcast(context.TODO(), &TextEvent{Name: 
"event1", Data: "data1"}) + require.ErrorIs(t, err, async.ErrTooLessDone) + require.Len(t, errs, 1) + + }) + + t.Run("shutdown", func(t *testing.T) { + srv := New() + + c1, err := srv.Join(context.TODO(), "c1", httptest.NewRecorder()) + require.NoError(t, err) + require.NotNil(t, c1) + + c2, err := srv.Join(context.TODO(), "c1", httptest.NewRecorder()) + require.NoError(t, err) + require.NotNil(t, c2) + srv.Shutdown() + c1.Wait() + c2.Wait() + + require.Len(t, srv.clients, 0) }) } +type notStreamer struct { +} + +func (s *notStreamer) Header() http.Header { + return http.Header{} +} + +func (s *notStreamer) Write([]byte) (int, error) { + return 0, errors.New("mock: invalid") +} + +func (s *notStreamer) WriteHeader(int) {} + type streamerMock struct { http.ResponseWriter } diff --git a/ext/sse/streamer.go b/ext/sse/streamer.go index 703968c..afb2fab 100644 --- a/ext/sse/streamer.go +++ b/ext/sse/streamer.go @@ -1,10 +1,33 @@ package sse import ( + "errors" "net/http" ) +var ErrNotStreamer = errors.New("sse: not streamer") + type Streamer interface { http.ResponseWriter http.Flusher } + +type stdStreamer struct { + http.ResponseWriter + http.Flusher +} + +// NewStreamer creates a new Streamer instance from the provided http.ResponseWriter. +// It returns an error if the ResponseWriter is nil or does not implement the http.Flusher interface. +// This function is intended for use in handling server-sent events (SSE). +func NewStreamer(rw http.ResponseWriter) (Streamer, error) { + if rw == nil { + return nil, ErrNotStreamer + } + + flusher, ok := rw.(http.Flusher) + if !ok { + return nil, ErrNotStreamer + } + return &stdStreamer{rw, flusher}, nil +} diff --git a/response_writer_std.go b/response_writer_std.go index 26c31d3..5ec2dcc 100644 --- a/response_writer_std.go +++ b/response_writer_std.go @@ -14,6 +14,10 @@ type stdResponseWriter struct { func (*stdResponseWriter) Close() { } +// WriteHeader sends an HTTP response header with the specified status code. 
+// It ensures that the header is only written once by checking if the statusCode +// has already been set. If the statusCode is zero, it updates the statusCode +// and calls the underlying ResponseWriter's WriteHeader method to send the header. func (rw *stdResponseWriter) WriteHeader(statusCode int) { if rw.statusCode == 0 { rw.statusCode = statusCode @@ -21,6 +25,8 @@ func (rw *stdResponseWriter) WriteHeader(statusCode int) { } } +// StatusCode returns the HTTP status code of the response writer. +// If the status code has not been set, it defaults to http.StatusOK. func (rw *stdResponseWriter) StatusCode() int { if rw.statusCode == 0 { return http.StatusOK @@ -28,12 +34,16 @@ func (rw *stdResponseWriter) StatusCode() int { return rw.statusCode } +// BodyBytesSent returns the number of bytes sent in the response body. +// It is a method of the stdResponseWriter type and provides access +// to the internal byte count for monitoring or logging purposes. func (rw *stdResponseWriter) BodyBytesSent() int { return rw.bodySentBytes } +// Write writes the data to the underlying ResponseWriter and tracks the number of bytes sent. +// It returns the number of bytes written and any error encountered during the write operation. func (rw *stdResponseWriter) Write(b []byte) (int, error) { - n, err := rw.ResponseWriter.Write(b) rw.bodySentBytes = rw.bodySentBytes + n @@ -41,6 +51,17 @@ func (rw *stdResponseWriter) Write(b []byte) (int, error) { return n, err } +// Flush sends any buffered data to the client. It implements the http.Flusher interface, +// allowing the response writer to flush the response immediately. +func (rw *stdResponseWriter) Flush() { + f, ok := rw.ResponseWriter.(http.Flusher) + if ok { + f.Flush() + } +} + +// NewResponseWriter creates a new instance of ResponseWriter that wraps the provided http.ResponseWriter. +// It returns a pointer to a stdResponseWriter, which implements the ResponseWriter interface. 
func NewResponseWriter(rw http.ResponseWriter) ResponseWriter { return &stdResponseWriter{ResponseWriter: rw} }