Repository: go-pkgz/routegroup Branch: master Commit: 3ab4ae10eff2 Files: 17 Total size: 177.5 KB Directory structure: gitextract_nc4z17bg/ ├── .github/ │ ├── CODEOWNERS │ ├── FUNDING.yml │ └── workflows/ │ └── ci.yml ├── .gitignore ├── .golangci.yml ├── LICENSE ├── README.md ├── fileserver_test.go ├── go.mod ├── go.sum ├── group.go ├── group_test.go ├── middleware_test.go ├── mount_test.go ├── notfound_test.go ├── pathparams_test.go └── routing_test.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/CODEOWNERS ================================================ # These owners will be the default owners for everything in the repo. # Unless a later match takes precedence, @umputun will be requested for # review when someone opens a pull request. * @umputun ================================================ FILE: .github/FUNDING.yml ================================================ github: [umputun] ================================================ FILE: .github/workflows/ci.yml ================================================ name: build on: push: branches: tags: pull_request: permissions: contents: read # to fetch code (actions/checkout) jobs: build: runs-on: ubuntu-latest steps: - name: checkout uses: actions/checkout@v4 - name: set up go uses: actions/setup-go@v5 with: go-version: "1.23" id: go - name: build and test run: | go get -v go test -timeout=60s -race -covermode=atomic -coverprofile=$GITHUB_WORKSPACE/profile.cov_tmp cat $GITHUB_WORKSPACE/profile.cov_tmp | grep -v "_mock.go" > $GITHUB_WORKSPACE/profile.cov go build -race env: GO111MODULE: "on" TZ: "America/Chicago" - name: golangci-lint uses: golangci/golangci-lint-action@v7 with: version: v2.6 - name: install goveralls run: | go install github.com/mattn/goveralls@latest - name: submit coverage run: $(go env GOPATH)/bin/goveralls -service="github" 
-coverprofile=$GITHUB_WORKSPACE/profile.cov env: COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} ================================================ FILE: .gitignore ================================================ # Coverage files coverage.out coverage.html *.cover # Test binaries *.test # Go workspace files go.work go.work.sum ================================================ FILE: .golangci.yml ================================================ version: "2" run: concurrency: 4 linters: default: none enable: - copyloopvar - gochecknoinits - gocritic - gosec - govet - ineffassign - misspell - nakedret - prealloc - revive - staticcheck - unconvert - unparam - unused settings: goconst: min-len: 2 min-occurrences: 2 gocritic: disabled-checks: - wrapperFunc enabled-tags: - performance - style - experimental gocyclo: min-complexity: 15 govet: enable: - shadow lll: line-length: 140 misspell: locale: US exclusions: generated: lax rules: - linters: - gosec text: 'G114: Use of net/http serve function that has no support for setting timeouts' - linters: - revive - unparam path: _test\.go$ text: unused-parameter paths: - third_party$ - builtin$ - examples$ formatters: exclusions: generated: lax paths: - third_party$ - builtin$ - examples$ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2024 Umputun Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ ## routegroup [![Build Status](https://github.com/go-pkgz/routegroup/workflows/build/badge.svg)](https://github.com/go-pkgz/routegroup/actions) [![Go Report Card](https://goreportcard.com/badge/github.com/go-pkgz/routegroup)](https://goreportcard.com/report/github.com/go-pkgz/routegroup) [![Coverage Status](https://coveralls.io/repos/github/go-pkgz/routegroup/badge.svg?branch=master)](https://coveralls.io/github/go-pkgz/routegroup?branch=master) [![godoc](https://godoc.org/github.com/go-pkgz/routegroup?status.svg)](https://godoc.org/github.com/go-pkgz/routegroup) `routegroup` is a tiny Go package providing a lightweight wrapper for efficient route grouping and middleware integration with the standard `http.ServeMux`. ## Features - Simple and intuitive API for route grouping and route mounting. - Lightweight, just about 100 LOC - Easy middleware integration for individual routes or groups of routes. - Seamless integration with Go's standard `http.ServeMux`. - Fully compatible with the `http.Handler` interface and can be used as a drop-in replacement for `http.ServeMux`. - No external dependencies. 
## Requirements - Go 1.23 or higher *(This library uses `http.Request.Pattern` to make route patterns available to global middlewares and relies on the enhanced `http.ServeMux` routing behavior introduced in Go 1.22/1.23)* ## Install and update `go get -u github.com/go-pkgz/routegroup` ## Usage **Creating a New Route Group** To start, create a new route group without a base path: ```go func main() { mux := http.NewServeMux() group := routegroup.New(mux) } ``` **Adding Routes with Middleware** Add routes to your group, optionally with middleware: ```go group.Use(loggingMiddleware, corsMiddleware) group.Handle("/hello", helloHandler) group.Handle("/bye", byeHandler) ``` **Creating a Nested Route Group** For routes under a specific path prefix, the `Mount` method can be used to create a nested group: ```go apiGroup := routegroup.Mount(mux, "/api") apiGroup.Use(loggingMiddleware, corsMiddleware) apiGroup.Handle("/v1", apiV1Handler) apiGroup.Handle("/v2", apiV2Handler) ``` **Complete Example** Here's a complete example demonstrating route grouping and middleware usage: ```go package main import ( "net/http" "github.com/go-pkgz/routegroup" ) func main() { router := routegroup.New(http.NewServeMux()) router.Use(loggingMiddleware) // handle the /hello route router.Handle("GET /hello", helloHandler) // create a new group for the /api path apiRouter := router.Mount("/api") // add middleware apiRouter.Use(loggingMiddleware, corsMiddleware) // route handling apiRouter.HandleFunc("GET /hello", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello, API!")) }) // add another group with its own set of middlewares protectedGroup := router.Group() protectedGroup.Use(authMiddleware) protectedGroup.HandleFunc("GET /protected", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Protected API!")) }) http.ListenAndServe(":8080", router) } ``` **Applying Middleware to Specific Routes** You can also apply middleware to specific routes inside the group without 
modifying the group's middleware stack: ```go apiGroup.With(corsMiddleware, apiMiddleware).Handle("GET /hello", helloHandler) ``` **Alternative Usage with `Route`** You can also use the `Route` method to add routes and middleware in a single function call: ```go router := routegroup.New(http.NewServeMux()) router.Route(func(b *routegroup.Bundle) { b.Use(loggingMiddleware, corsMiddleware) b.Handle("GET /hello", helloHandler) b.Handle("GET /bye", byeHandler) }) http.ListenAndServe(":8080", router) ``` When called on the root bundle, `Route` automatically creates a new group to avoid accidentally modifying the root bundle's middleware stack. This means the middleware and routes defined inside the `Route` function are isolated from other routes on the root bundle. The `Route` method can also be chained after `Mount` or `Group` for a more functional style: ```go router := routegroup.New(http.NewServeMux()) router.Group().Route(func(b *routegroup.Bundle) { b.Use(loggingMiddleware, corsMiddleware) b.Handle("GET /hello", helloHandler) b.Handle("GET /bye", byeHandler) }) ``` **Setting optional `NotFoundHandler`** It is possible to set a custom `NotFoundHandler` for the group. This handler will be called when no other route matches the request: ```go group.NotFoundHandler(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "404 page not found, something is wrong!", http.StatusNotFound) }) ``` If a custom `NotFoundHandler` is not configured, `routegroup` will default to using the standard library behavior. Note on 405: In the current design, `routegroup` applies root-level middlewares to all requests at the top level without installing a catch‑all route. This preserves native `405 Method Not Allowed` responses from `http.ServeMux` when a path exists but a wrong method is used. A configured `NotFoundHandler` is only invoked when no route matches; it does not interfere with 405 handling. 
The custom `NotFoundHandler` will have the root bundle's global middlewares applied to it. Legacy note: `DisableNotFoundHandler()` is now a no‑op and preserved only for API compatibility. ### Middleware Ordering - Call `Use(...)` before registering routes on the same bundle. Calling `Use` after any handler has been registered on that bundle will panic with a descriptive error. - Root bundle middlewares (added via `router.Use(...)`) are applied globally to all requests at serve time. - Group/bundle middlewares (added via `group.Use(...)`) apply to the routes registered on that bundle and its descendants, provided they are added before those routes. - `With(...)` returns a new bundle; you can add middlewares there first, then register routes. This is the preferred way to add scoped middlewares without affecting previously defined routes. **Important**: Route registration (HandleFunc, Handle, HandleFiles, etc.) should be done during initialization and not performed concurrently. The library is designed for typical usage where routes are registered at startup time in a single goroutine. 
Examples Incorrect: calling `Use` after routes on the same bundle (will panic) ```go mux := http.NewServeMux() router := routegroup.New(mux) router.HandleFunc("/r", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) // This will panic: Use called after routes were registered on this bundle router.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // global header w.Header().Set("X-Global", "true") next.ServeHTTP(w, r) }) }) ``` Allowed: parent/root `Use` after child bundle routes ```go mux := http.NewServeMux() router := routegroup.New(mux) child := router.Group() child.HandleFunc("/child", func(w http.ResponseWriter, _ *http.Request) { w.Write([]byte("ok")) }) // Parent has not registered its own routes yet; this is allowed and will apply globally router.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Parent", "true") next.ServeHTTP(w, r) }) }) ``` Preferred: use `With` (or `Group`+`Use`) to attach scoped middleware before routes ```go mux := http.NewServeMux() router := routegroup.New(mux) // Global middleware (optional), add before any root routes router.Use(loggingMiddleware) // Scoped middleware using With: returns a new bundle on which we can add routes api := router.With(authMiddleware) api.HandleFunc("GET /items", itemsHandler) api.HandleFunc("POST /items", createItem) // Or using Group + Use before routes admin := router.Group() admin.Use(adminOnly) admin.HandleFunc("GET /dashboard", dashboardHandler) ``` **Handling Root Paths Without Trailing Slashes** When working with mounted groups, you often need to handle requests to the group's root path without a trailing slash. 
For this purpose, `routegroup` provides the `HandleRoot` or `HandleRootFunc` methods: ```go // Create mounted groups apiGroup := router.Mount("/api") v1Group := apiGroup.Mount("/v1") usersGroup := v1Group.Mount("/users") // Handle the root paths (no trailing slashes) apiGroup.HandleRoot("GET", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // This handles requests to "/api" (without trailing slash) w.Write([]byte("API Documentation")) })) usersGroup.HandleRoot("GET", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // This handles requests to "/api/v1/users" (without trailing slash) w.Write([]byte("List users")) })) // Different HTTP methods can be handled separately usersGroup.HandleRoot("POST", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // This handles POST requests to "/api/v1/users" w.Write([]byte("Create user")) })) ``` While it's also possible to handle such paths using a trailing slash pattern (`"/"`) with the regular `Handle` or `HandleFunc` methods, that approach results in a redirect from non-trailing slash URLs (e.g., `/api`) to the trailing slash version (e.g., `/api/`). The `HandleRoot` method avoids this redirect, providing a more direct response and avoiding an extra round-trip, which is especially important for non-GET requests or when clients don't automatically follow redirects. ### Using derived groups In some instances, it's practical to create an initial group that includes a set of middlewares, and then derive all other groups from it. This approach guarantees that every group incorporates a common set of middlewares as a foundation, allowing each to add its specific middlewares. To facilitate this scenario, `routegroup` offers both `Bundle.Group` and `Bundle.Mount` methods, and it also implements the `http.Handler` interface. 
The following example illustrates how to use derived groups: ```go // create a new bundle with a base set of middlewares // note: the bundle is also http.Handler and can be passed to http.ListenAndServe router := routegroup.New(http.NewServeMux()) router.Use(loggingMiddleware, corsMiddleware) // add a new, derived group with its own set of middlewares // this group will inherit the middlewares from the base group apiGroup := router.Group() apiGroup.Use(apiMiddleware) apiGroup.Handle("GET /hello", helloHandler) apiGroup.Handle("GET /bye", byeHandler) // mount another group for the /admin path with its own set of middlewares, // using `Route` method to show the alternative usage. // this group will inherit the middlewares from the base group as well router.Mount("/admin").Route(func(b *routegroup.Bundle) { b.Use(adminMiddleware) b.Handle("POST /do", doHandler) }) // start the server, passing the wrapped mux as the handler http.ListenAndServe(":8080", router) ``` ### Wrap Function Sometimes a route group is not necessary, and all you need is to apply middleware(s) directly to a single route. In this case, `routegroup` provides a `Wrap` function that can be used to wrap a single `http.Handler` with one or more middlewares. Here's an example: ```go mux := http.NewServeMux() mux.Handle("/hello", routegroup.Wrap(helloHandler, loggingMiddleware, corsMiddleware)) http.ListenAndServe(":8080", mux) ``` ### 404 and 405 behavior `routegroup` applies the root bundle's middlewares to all requests at the top level. This keeps the standard library's matching logic intact: - Wrong method on an existing path returns `405 Method Not Allowed` (with an `Allow` header). - Unknown path returns `404 Not Found`. You can optionally configure a custom 404 handler with `NotFoundHandler(fn)`. It will run only when no route matches and does not affect 405 handling. The custom handler will have global middlewares applied to it. 
The legacy `DisableNotFoundHandler()` is now a no‑op and kept only for compatibility. ### HandleFiles helper `routegroup` provides a helper function `HandleFiles` that can be used to serve static files from a directory. The function is a thin wrapper around the standard `http.FileServer` and can be used to serve files from a specific directory. Here's an example: ```go // serve static files from the "assets/static" directory router.HandleFiles("/static/", http.Dir("assets/static")) ``` ## Real-world example Here's an example of how `routegroup` can be used in a real-world application. The following code snippet is taken from a web service that provides a set of routes for user authentication, session management, and user management. The service also serves static files from the "assets/static" embedded file system. ```go // Routes returns http.Handler that handles all the routes for the Service. // It also serves static files from the "assets/static" directory. // The rootURL option sets prefix for the routes. 
func (s *Service) Routes() http.Handler { router := routegroup.Mount(http.NewServeMux(), s.rootURL) // make a bundle with the rootURL base path // add common middlewares router.Use(rest.Maybe(handlers.CompressHandler, func(*http.Request) bool { return !s.skipGZ })) router.Use(rest.Throttle(s.limitActiveReqs)) router.Use(s.middleware.securityHeaders(s.skipSecurityHeaders)) // prepare csrf middleware csrfMiddleware := s.middleware.csrf(s.skipCSRFCheck) // add open routes router.HandleFunc("GET /login", s.loginPageHandler) router.HandleFunc("POST /login", s.loginCheckHandler) router.HandleFunc("GET /logout", s.logoutHandler) // add routes with auth middleware router.Group().Route(func(auth *routegroup.Bundle) { auth.Use(s.middleware.Auth()) auth.HandleFunc("GET /update", s.pwdUpdateHandler) auth.With(csrfMiddleware).HandleFunc("PUT /update", s.pwdUpdateHandler) }) // add admin routes router.Mount("/admin").Route(func(admin *routegroup.Bundle) { admin.Use(s.middleware.Auth("admin")) admin.Use(s.middleware.AdminOnly) admin.HandleFunc("GET /", s.admin.renderHandler) admin.With(csrfMiddleware).Route(func(csrf *routegroup.Bundle) { csrf.HandleFunc("DELETE /sessions", s.admin.deleteSessionsHandler) csrf.HandleFunc("POST /user", s.admin.addUserHandler) csrf.HandleFunc("DELETE /user", s.admin.deleteUserHandler) }) }) router.HandleFunc("GET /static/*", s.fileServerHandlerFunc()) // serve static files return router } // fileServerHandlerFunc returns http.HandlerFunc that serves static files from the "assets/static" directory. // prefix is set by the rootURL option. 
func (s *Service) fileServerHandlerFunc() http.HandlerFunc { staticFS, err := fs.Sub(assets, "assets/static") // error is always nil if err != nil { panic(err) // should never happen we load from embedded FS } return func(w http.ResponseWriter, r *http.Request) { webFS := http.StripPrefix(s.rootURL+"/static/", http.FileServer(http.FS(staticFS))) webFS.ServeHTTP(w, r) } } ``` ## Contributing Contributions to `routegroup` are welcome! Please submit a pull request or open an issue for any bugs or feature requests. ## License `routegroup` is available under the MIT license. See the [LICENSE](https://github.com/go-pkgz/routegroup/blob/master/LICENSE) file for more info. ================================================ FILE: fileserver_test.go ================================================ package routegroup_test import ( "bytes" "fmt" "io" "net/http" "net/http/httptest" "os" "path/filepath" "testing" "github.com/go-pkgz/routegroup" ) func TestStaticFileServer(t *testing.T) { dir := t.TempDir() // create test file structure content := []byte("static file content") err := os.WriteFile(filepath.Join(dir, "test.txt"), content, 0o600) if err != nil { t.Fatal(err) } err = os.WriteFile(filepath.Join(dir, "index.html"), []byte("index content"), 0o600) if err != nil { t.Fatal(err) } // create subdirectory subDir := filepath.Join(dir, "sub") if err = os.Mkdir(subDir, 0o750); err != nil { t.Fatal(err) } subContent := []byte("sub file content") err = os.WriteFile(filepath.Join(subDir, "sub.txt"), subContent, 0o600) if err != nil { t.Fatal(err) } t.Run("serve files from root path with HEAD", func(t *testing.T) { router := routegroup.New(http.NewServeMux()) router.HandleFiles("/", http.Dir(dir)) srv := httptest.NewServer(router) defer srv.Close() // test GET request resp, err := http.Get(srv.URL + "/test.txt") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("GET - 
got status %d, want %d", resp.StatusCode, http.StatusOK) } if !bytes.Equal(body, content) { t.Errorf("GET - got body %q, want %q", body, content) } // test HEAD request req, err := http.NewRequest(http.MethodHead, srv.URL+"/test.txt", http.NoBody) if err != nil { t.Fatal(err) } resp, err = http.DefaultClient.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err = io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("HEAD - got status %d, want %d", resp.StatusCode, http.StatusOK) } if len(body) != 0 { t.Errorf("HEAD - should have no body, got %d bytes", len(body)) } if cl := resp.Header.Get("Content-Length"); cl != fmt.Sprint(len(content)) { t.Errorf("HEAD - got Content-Length %s, want %d", cl, len(content)) } }) t.Run("serve files from /files/ prefix", func(t *testing.T) { router := routegroup.New(http.NewServeMux()) router.HandleFiles("/files", http.Dir(dir)) srv := httptest.NewServer(router) defer srv.Close() // test GET request resp, err := http.Get(srv.URL + "/files/test.txt") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("GET - got status %d, want %d", resp.StatusCode, http.StatusOK) } if !bytes.Equal(body, content) { t.Errorf("GET - got body %q, want %q", body, content) } // test HEAD request req, err := http.NewRequest(http.MethodHead, srv.URL+"/files/test.txt", http.NoBody) if err != nil { t.Fatal(err) } resp, err = http.DefaultClient.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err = io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("HEAD - got status %d, want %d", resp.StatusCode, http.StatusOK) } if len(body) != 0 { t.Errorf("HEAD - should have no body, got %d bytes", len(body)) } if cl := resp.Header.Get("Content-Length"); cl != fmt.Sprint(len(content)) { t.Errorf("HEAD - got Content-Length %s, 
want %d", cl, len(content)) } }) t.Run("serve files from mounted group", func(t *testing.T) { router := routegroup.New(http.NewServeMux()) assets := router.Mount("/assets") assets.HandleFiles("/", http.Dir(dir)) srv := httptest.NewServer(router) defer srv.Close() // test GET request resp, err := http.Get(srv.URL + "/assets/test.txt") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("GET - got status %d, want %d", resp.StatusCode, http.StatusOK) } if !bytes.Equal(body, content) { t.Errorf("GET - got body %q, want %q", body, content) } // test HEAD request req, err := http.NewRequest(http.MethodHead, srv.URL+"/assets/test.txt", http.NoBody) if err != nil { t.Fatal(err) } resp, err = http.DefaultClient.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err = io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("HEAD - got status %d, want %d", resp.StatusCode, http.StatusOK) } if len(body) != 0 { t.Errorf("HEAD - should have no body, got %d bytes", len(body)) } if cl := resp.Header.Get("Content-Length"); cl != fmt.Sprint(len(content)) { t.Errorf("HEAD - got Content-Length %s, want %d", cl, len(content)) } }) } func TestDirectFileServerHandle(t *testing.T) { dir := t.TempDir() content := []byte("static file content") err := os.WriteFile(filepath.Join(dir, "test.txt"), content, 0o600) if err != nil { t.Fatal(err) } t.Run("raw Handle without strip", func(t *testing.T) { router := routegroup.New(http.NewServeMux()) router.Handle("/files/", http.FileServer(http.Dir(dir))) // without StripPrefix! 
srv := httptest.NewServer(router) defer srv.Close() // test GET request - should fail as we need StripPrefix resp, err := http.Get(srv.URL + "/files/test.txt") if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.StatusCode != http.StatusNotFound { t.Errorf("expected 404 without StripPrefix, got %d", resp.StatusCode) } // test HEAD request - should also fail req, err := http.NewRequest(http.MethodHead, srv.URL+"/files/test.txt", http.NoBody) if err != nil { t.Fatal(err) } resp, err = http.DefaultClient.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.StatusCode != http.StatusNotFound { t.Errorf("HEAD - expected 404 without StripPrefix, got %d", resp.StatusCode) } }) t.Run("Handle with strip prefix", func(t *testing.T) { router := routegroup.New(http.NewServeMux()) router.Handle("/files/", http.StripPrefix("/files/", http.FileServer(http.Dir(dir)))) srv := httptest.NewServer(router) defer srv.Close() // test GET request resp, err := http.Get(srv.URL + "/files/test.txt") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("GET - got status %d, want %d", resp.StatusCode, http.StatusOK) } if !bytes.Equal(body, content) { t.Errorf("GET - got body %q, want %q", body, content) } // test HEAD request req, err := http.NewRequest(http.MethodHead, srv.URL+"/files/test.txt", http.NoBody) if err != nil { t.Fatal(err) } resp, err = http.DefaultClient.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err = io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("HEAD - got status %d, want %d", resp.StatusCode, http.StatusOK) } if len(body) != 0 { t.Errorf("HEAD - should have no body, got %d bytes", len(body)) } if cl := resp.Header.Get("Content-Length"); cl != fmt.Sprint(len(content)) { t.Errorf("HEAD - got Content-Length %s, want %d", cl, len(content)) } }) t.Run("Handle 
with mounted group", func(t *testing.T) { router := routegroup.New(http.NewServeMux()) api := router.Mount("/api") api.Handle("/static/", http.StripPrefix("/api/static/", http.FileServer(http.Dir(dir)))) srv := httptest.NewServer(router) defer srv.Close() // test GET request resp, err := http.Get(srv.URL + "/api/static/test.txt") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("GET - got status %d, want %d", resp.StatusCode, http.StatusOK) } if !bytes.Equal(body, content) { t.Errorf("GET - got body %q, want %q", body, content) } // test HEAD request req, err := http.NewRequest(http.MethodHead, srv.URL+"/api/static/test.txt", http.NoBody) if err != nil { t.Fatal(err) } resp, err = http.DefaultClient.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err = io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("HEAD - got status %d, want %d", resp.StatusCode, http.StatusOK) } if len(body) != 0 { t.Errorf("HEAD - should have no body, got %d bytes", len(body)) } if cl := resp.Header.Get("Content-Length"); cl != fmt.Sprint(len(content)) { t.Errorf("HEAD - got Content-Length %s, want %d", cl, len(content)) } }) } func TestFileServerWithMiddleware(t *testing.T) { dir := t.TempDir() content := []byte("static file content") err := os.WriteFile(filepath.Join(dir, "test.txt"), content, 0o600) if err != nil { t.Fatal(err) } t.Run("root path with middleware", func(t *testing.T) { router := routegroup.New(http.NewServeMux()) router.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Root-MW", "called") next.ServeHTTP(w, r) }) }) router.HandleFiles("/", http.Dir(dir)) srv := httptest.NewServer(router) defer srv.Close() // test GET request resp, err := http.Get(srv.URL + "/test.txt") if err != nil { t.Fatal(err) } defer 
resp.Body.Close() if mw := resp.Header.Get("X-Root-MW"); mw != "called" { t.Errorf("middleware not called, got header %q", mw) } // test HEAD request req, err := http.NewRequest(http.MethodHead, srv.URL+"/test.txt", http.NoBody) if err != nil { t.Fatal(err) } resp, err = http.DefaultClient.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if mw := resp.Header.Get("X-Root-MW"); mw != "called" { t.Errorf("middleware not called for HEAD, got header %q", mw) } }) t.Run("prefixed path with middleware", func(t *testing.T) { router := routegroup.New(http.NewServeMux()) router.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Prefix-MW", "called") next.ServeHTTP(w, r) }) }) router.HandleFiles("/files", http.Dir(dir)) srv := httptest.NewServer(router) defer srv.Close() resp, err := http.Get(srv.URL + "/files/test.txt") if err != nil { t.Fatal(err) } defer resp.Body.Close() if mw := resp.Header.Get("X-Prefix-MW"); mw != "called" { t.Errorf("middleware not called, got header %q", mw) } }) t.Run("mounted path with chained middleware", func(t *testing.T) { router := routegroup.New(http.NewServeMux()) router.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Root-MW", "called") next.ServeHTTP(w, r) }) }) assets := router.Mount("/assets") assets.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Assets-MW", "called") next.ServeHTTP(w, r) }) }) assets.HandleFiles("/", http.Dir(dir)) srv := httptest.NewServer(router) defer srv.Close() // test both middleware being called resp, err := http.Get(srv.URL + "/assets/test.txt") if err != nil { t.Fatal(err) } defer resp.Body.Close() if mw := resp.Header.Get("X-Root-MW"); mw != "called" { t.Errorf("root middleware not called, got header %q", mw) } if mw := resp.Header.Get("X-Assets-MW"); 
mw != "called" { t.Errorf("assets middleware not called, got header %q", mw) } // test 404 path still triggers middleware resp, err = http.Get(srv.URL + "/assets/notfound.txt") if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.StatusCode != http.StatusNotFound { t.Errorf("got status %d, want %d", resp.StatusCode, http.StatusNotFound) } if mw := resp.Header.Get("X-Root-MW"); mw != "called" { t.Errorf("root middleware not called for 404, got header %q", mw) } if mw := resp.Header.Get("X-Assets-MW"); mw != "called" { t.Errorf("assets middleware not called for 404, got header %q", mw) } }) t.Run("direct Handle with middleware", func(t *testing.T) { router := routegroup.New(http.NewServeMux()) router.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Direct-MW", "called") next.ServeHTTP(w, r) }) }) router.Handle("/files/", http.StripPrefix("/files/", http.FileServer(http.Dir(dir)))) srv := httptest.NewServer(router) defer srv.Close() resp, err := http.Get(srv.URL + "/files/test.txt") if err != nil { t.Fatal(err) } defer resp.Body.Close() if mw := resp.Header.Get("X-Direct-MW"); mw != "called" { t.Errorf("middleware not called, got header %q", mw) } }) } func TestMixedHandlers(t *testing.T) { dir := t.TempDir() // create static files content := []byte("static file content") err := os.WriteFile(filepath.Join(dir, "test.txt"), content, 0o600) if err != nil { t.Fatal(err) } router := routegroup.New(http.NewServeMux()) // setup regular and file handlers in various combinations router.HandleFunc("GET /api/info", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("api info")) }) router.HandleFiles("/public", http.Dir(dir)) // setup api group with mixed handlers api := router.Mount("/v1") api.HandleFunc("GET /data", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("api data")) }) api.HandleFiles("/static", http.Dir(dir)) // setup admin group with 
both types and middleware admin := router.Mount("/admin") admin.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Admin", "true") next.ServeHTTP(w, r) }) }) admin.HandleFunc("GET /users", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("admin users")) }) admin.HandleFiles("/assets", http.Dir(dir)) srv := httptest.NewServer(router) defer srv.Close() tests := []struct { name string path string expectedStatus int expectedBody string expectedHeader string // for middleware check }{ {"api info endpoint", "/api/info", http.StatusOK, "api info", ""}, {"public static file", "/public/test.txt", http.StatusOK, "static file content", ""}, {"v1 api endpoint", "/v1/data", http.StatusOK, "api data", ""}, {"v1 static file", "/v1/static/test.txt", http.StatusOK, "static file content", ""}, {"admin endpoint", "/admin/users", http.StatusOK, "admin users", "true"}, {"admin static file", "/admin/assets/test.txt", http.StatusOK, "static file content", "true"}, {"non-existent api path", "/api/notfound", http.StatusNotFound, "404 page not found\n", ""}, {"non-existent static file", "/public/notfound.txt", http.StatusNotFound, "404 page not found\n", ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { resp, err := http.Get(srv.URL + tt.path) if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != tt.expectedStatus { t.Errorf("got status %d, want %d", resp.StatusCode, tt.expectedStatus) } if string(body) != tt.expectedBody { t.Errorf("got body %q, want %q", string(body), tt.expectedBody) } if tt.expectedHeader != "" { if h := resp.Header.Get("X-Admin"); h != tt.expectedHeader { t.Errorf("got X-Admin header %q, want %q", h, tt.expectedHeader) } } }) } } func TestIssue12StaticAndIndex(t *testing.T) { dir := t.TempDir() err := os.WriteFile(filepath.Join(dir, "test.txt"), []byte("static 
content"), 0o600) if err != nil { t.Fatal(err) } router := routegroup.New(http.NewServeMux()) router.Route(func(base *routegroup.Bundle) { base.Handle("/", http.FileServer(http.Dir(dir))) base.HandleFunc("GET /{$}", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("index page")) }) base.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("login page")) }) }) srv := httptest.NewServer(router) defer srv.Close() t.Run("serve static file", func(t *testing.T) { resp, err := http.Get(srv.URL + "/test.txt") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if got := string(body); got != "static content" { t.Errorf("got %q, want static content", got) } }) t.Run("serve index", func(t *testing.T) { resp, err := http.Get(srv.URL + "/") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if got := string(body); got != "index page" { t.Errorf("got %q, want index page", got) } }) t.Run("serve login page", func(t *testing.T) { resp, err := http.Get(srv.URL + "/login") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if got := string(body); got != "login page" { t.Errorf("got %q, want login page", got) } }) } ================================================ FILE: go.mod ================================================ module github.com/go-pkgz/routegroup go 1.23 ================================================ FILE: go.sum ================================================ ================================================ FILE: group.go ================================================ // Package routegroup provides a way to group routes and applies middleware to them. // Works with the standard library's http.ServeMux. 
package routegroup import ( "net/http" "regexp" "strings" ) // Bundle represents a group of routes with associated middleware. type Bundle struct { mux *http.ServeMux // the underlying mux to register the routes to basePath string // base path for the group middlewares []func(http.Handler) http.Handler // middlewares stack // optional custom 404 handler notFound http.HandlerFunc // root points to the root bundle for global middleware application. // for the root bundle, root == nil. root *Bundle // routesLocked indicates that routes have been registered on the root bundle // and no further root-level middlewares may be added. routesLocked bool // rootCount captures how many root middlewares were present when this bundle // was created. Used to avoid double-applying root middlewares for per-route wrapping. rootCount int } // New creates a new Group. func New(mux *http.ServeMux) *Bundle { return &Bundle{mux: mux} } // Mount creates a new group with a specified base path. func Mount(mux *http.ServeMux, basePath string) *Bundle { return &Bundle{mux: mux, basePath: basePath} } // ServeHTTP implements the http.Handler interface func (b *Bundle) ServeHTTP(w http.ResponseWriter, r *http.Request) { // resolve the root bundle (where global middlewares live). 
root := b if b.root != nil { root = b.root } // get the handler and pattern for this request _, pattern := b.mux.Handler(r) // if a pattern was found, create a shallow copy of the request with the pattern set // this allows global middlewares to see the pattern before mux.ServeHTTP is called if pattern != "" { r2 := *r r2.Pattern = pattern r = &r2 } // create a handler that will let the mux do its routing (including setting path parameters) // but intercept 404s to use custom handler if provided muxHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if pattern == "" && root.notFound != nil { // no route matched, need to check if it's a true 404 or a 405 // probe the mux to see what status it would return probe := &statusRecorder{status: http.StatusOK} b.mux.ServeHTTP(probe, r) // if mux wants to return 405 (Method Not Allowed), let it handle the request // to preserve the proper 405 response and Allow header if probe.status == http.StatusMethodNotAllowed { b.mux.ServeHTTP(w, r) return } // it's a true 404, use custom handler root.notFound.ServeHTTP(w, r) return } // let the mux handle the request normally (this sets path parameters) b.mux.ServeHTTP(w, r) }) // apply root (global) middlewares around the mux handler and serve the request. root.wrapGlobal(muxHandler).ServeHTTP(w, r) } // Group creates a new group with the same middleware stack as the original on top of the existing bundle. func (b *Bundle) Group() *Bundle { return b.clone() // copy the middlewares to avoid modifying the original } // Mount creates a new group with a specified base path on top of the existing bundle. func (b *Bundle) Mount(basePath string) *Bundle { g := b.clone() // copy the middlewares to avoid modifying the original g.basePath += basePath return g } // Use adds middleware(s) to the Group. // Middlewares are executed in the order they are added. 
// Note: Root-level middlewares (added to the root bundle) have access to the matched // route pattern via r.Pattern, but execute before path parameters are parsed. // Therefore, r.PathValue() will return empty strings in root middlewares. // Middlewares on mounted groups execute after routing and have full access to path values. func (b *Bundle) Use(middleware func(http.Handler) http.Handler, more ...func(http.Handler) http.Handler) { // disallow adding middlewares after any routes have been registered on this bundle. if b.routesLocked { panic("routegroup: Use called after routes were registered on this bundle; add middlewares before registering routes or use Group/With for scoped middleware") } b.middlewares = append(b.middlewares, middleware) b.middlewares = append(b.middlewares, more...) } // With adds new middleware(s) to the Group and returns a new Group with the updated middleware stack. // The With method is similar to Use, but instead of modifying the current Group, // it returns a new Group instance with the added middleware(s). // This allows for creating chain of middleware without affecting the original Group. func (b *Bundle) With(middleware func(http.Handler) http.Handler, more ...func(http.Handler) http.Handler) *Bundle { newMiddlewares := make([]func(http.Handler) http.Handler, len(b.middlewares), len(b.middlewares)+len(more)+1) copy(newMiddlewares, b.middlewares) newMiddlewares = append(newMiddlewares, middleware) newMiddlewares = append(newMiddlewares, more...) // preserve root pointer and rootCount nb := &Bundle{mux: b.mux, basePath: b.basePath, middlewares: newMiddlewares, root: b.root, rootCount: b.rootCount} if nb.root == nil { // b is the root, so all b's middlewares are root middlewares nb.root = b nb.rootCount = len(b.middlewares) } return nb } // Handle adds a new route to the Group's mux, applying all middlewares to the handler. 
func (b *Bundle) Handle(pattern string, handler http.Handler) { b.lockRoot() // lock root on first route registration // for file server paths (ending with /), preserve the pattern as-is if strings.HasSuffix(pattern, "/") { fullPath := b.basePath + pattern b.mux.Handle(fullPath, b.wrapMiddleware(handler)) return } b.register(pattern, handler.ServeHTTP) } // HandleFiles is a helper to serve static files from a directory func (b *Bundle) HandleFiles(pattern string, root http.FileSystem) { b.lockRoot() // lock root on first route registration // normalize pattern to always have trailing slash if !strings.HasSuffix(pattern, "/") { pattern += "/" } // build the full path for registration fullPath := b.basePath + pattern if pattern == "/" && b.basePath == "" { // root case - serve directly without stripping b.mux.Handle("/", b.wrapMiddleware(http.FileServer(root))) return } // for both mounted groups and prefixed paths, strip the fullPath handler := http.StripPrefix(strings.TrimSuffix(fullPath, "/"), http.FileServer(root)) b.mux.Handle(fullPath, b.wrapMiddleware(handler)) } // HandleFunc registers the handler function for the given pattern to the Group's mux. // The handler is wrapped with the Group's middlewares. func (b *Bundle) HandleFunc(pattern string, handler http.HandlerFunc) { b.register(pattern, handler) } // Handler returns the handler and the pattern that matches the request. // It always returns a non-nil handler, see http.ServeMux.Handler documentation for details. func (b *Bundle) Handler(r *http.Request) (h http.Handler, pattern string) { return b.mux.Handler(r) } // DisableNotFoundHandler used to disable auto-registration of a catch-all 404. // // Deprecated: now a no-op retained for API compatibility. func (b *Bundle) DisableNotFoundHandler() {} // NotFoundHandler sets a custom handler for any unmatched routes (404 responses). // Note: This handler is only used for true 404s. 
Requests to valid paths with // incorrect HTTP methods will still return 405 Method Not Allowed with Allow header. func (b *Bundle) NotFoundHandler(handler http.HandlerFunc) { // always set on the root bundle so custom 404 works regardless of which bundle serves. if b.root != nil { b.root.notFound = handler return } b.notFound = handler } // matches non-space characters, spaces, then anything, i.e. "GET /path/to/resource" var reGo122 = regexp.MustCompile(`^(\S+)\s+(.+)$`) func (b *Bundle) register(pattern string, handler http.HandlerFunc) { b.lockRoot() // lock root on first route registration matches := reGo122.FindStringSubmatch(pattern) var path, method string if len(matches) > 2 { // path in the form "GET /path/to/resource" method = matches[1] path = matches[2] pattern = method + " " + b.basePath + path } else { // path is just "/path/to/resource" path = pattern pattern = b.basePath + pattern // method is not set intentionally here, the request pattern had no method part } // if the pattern is the root path on / change it to /{$} // this keeps handling the root request without becoming a catch-all if pattern == "/" || path == "/" { if method != "" { // preserve the method part if it was set pattern = method + " " + b.basePath + "/{$}" } else { pattern = b.basePath + "/{$}" // no method part, just the path } } b.mux.HandleFunc(pattern, b.wrapMiddleware(handler).ServeHTTP) } // Route allows for configuring the Group inside the configureFn function. // When called on the root bundle, it automatically creates a new group to avoid // accidentally modifying the root bundle's middleware stack. 
func (b *Bundle) Route(configureFn func(*Bundle)) { // if called on root bundle, auto-create a group for better UX if b.root == nil { child := b.Group() configureFn(child) // if child registered routes, lock root too to prevent Use() after routes if child.routesLocked { b.routesLocked = true } return } configureFn(b) } // HandleRoot adds a handler for the group's root path without trailing slash. // This avoids the 301 redirect that would occur with a "/" pattern. // Method parameter can be empty to register for all HTTP methods. func (b *Bundle) HandleRoot(method string, handler http.Handler) { b.lockRoot() // lock root on first route registration // for empty base path, use "/" to match the root pattern := b.basePath if pattern == "" { pattern = "/" } // add method if specified if method != "" { pattern = method + " " + pattern } b.mux.Handle(pattern, b.wrapMiddleware(handler)) } // HandleRootFunc is like HandleRoot but takes a handler function. func (b *Bundle) HandleRootFunc(method string, handler http.HandlerFunc) { b.lockRoot() // lock root on first route registration // for empty base path, use "/" to match the root pattern := b.basePath if pattern == "" { pattern = "/" } // add method if specified if method != "" { pattern = method + " " + pattern } b.mux.HandleFunc(pattern, b.wrapMiddleware(handler).ServeHTTP) } // wrapMiddleware applies the registered middlewares to a handler. 
func (b *Bundle) wrapMiddleware(handler http.Handler) http.Handler { // root bundle: don't apply middlewares here, they're applied globally in ServeHTTP if b.root == nil { return handler } // child bundle: apply only middlewares added after mounting (exclude inherited root middlewares) start := b.rootCount if start > len(b.middlewares) { start = len(b.middlewares) // safety: ensure start doesn't exceed bounds } for i := len(b.middlewares) - 1; i >= start; i-- { handler = b.middlewares[i](handler) } return handler } func (b *Bundle) clone() *Bundle { middlewares := make([]func(http.Handler) http.Handler, len(b.middlewares)) copy(middlewares, b.middlewares) // preserve root pointer and rootCount nb := &Bundle{mux: b.mux, basePath: b.basePath, middlewares: middlewares, root: b.root, rootCount: b.rootCount} if nb.root == nil { // b is the root, so all b's middlewares are root middlewares nb.root = b nb.rootCount = len(b.middlewares) } return nb } // Wrap directly wraps the handler with the provided middleware(s). func Wrap(handler http.Handler, mw1 func(http.Handler) http.Handler, mws ...func(http.Handler) http.Handler) http.Handler { for i := len(mws) - 1; i >= 0; i-- { handler = mws[i](handler) } return mw1(handler) // apply the first middleware } // wrapGlobal applies only the root bundle's middlewares to the provided handler. func (b *Bundle) wrapGlobal(handler http.Handler) http.Handler { // resolve root bundle root := b if b.root != nil { root = b.root } for i := len(root.middlewares) - 1; i >= 0; i-- { handler = root.middlewares[i](handler) } return handler } // lockRoot marks this bundle as having registered routes. func (b *Bundle) lockRoot() { b.routesLocked = true } // statusRecorder is a minimal ResponseWriter that only records the status code. // Used to probe what status the mux would return without actually writing a response. 
type statusRecorder struct { status int } func (r *statusRecorder) Header() http.Header { return make(http.Header) } func (r *statusRecorder) Write([]byte) (int, error) { return 0, nil } func (r *statusRecorder) WriteHeader(status int) { r.status = status } ================================================ FILE: group_test.go ================================================ package routegroup_test import ( "fmt" "io" "net/http" "net/http/httptest" "reflect" "strings" "sync" "testing" "github.com/go-pkgz/routegroup" ) // testMiddleware is simple middleware for testing purposes. func testMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Println("Test middleware") w.Header().Add("X-Test-Middleware", "true") next.ServeHTTP(w, r) }) } func TestGroupMiddleware(t *testing.T) { group := routegroup.New(http.NewServeMux()) group.Use(testMiddleware) group.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, "/test", http.NoBody) if err != nil { t.Fatal(err) } group.ServeHTTP(recorder, request) if recorder.Code != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, recorder.Code) } if header := recorder.Header().Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } } func TestMountedBundleServeHTTP(t *testing.T) { // test ServeHTTP when called directly on a mounted bundle root := routegroup.New(http.NewServeMux()) root.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Root-Middleware", "true") next.ServeHTTP(w, r) }) }) mounted := root.Mount("/api") mounted.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("mounted handler")) }) // serve directly from the mounted bundle 
(not typical usage but should work) recorder := httptest.NewRecorder() request, _ := http.NewRequest(http.MethodGet, "/api/test", http.NoBody) mounted.ServeHTTP(recorder, request) if recorder.Code != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, recorder.Code) } if body := recorder.Body.String(); body != "mounted handler" { t.Errorf("Expected body 'mounted handler', got '%s'", body) } // should still apply root middleware when serving from mounted if header := recorder.Header().Get("X-Root-Middleware"); header != "true" { t.Errorf("Expected X-Root-Middleware header to be 'true', got '%s'", header) } } func TestGroupHandle(t *testing.T) { group := routegroup.New(http.NewServeMux()) group.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("test handler")) }) group.Handle("GET /test2", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("test2 handler")) })) t.Run("handler function", func(t *testing.T) { recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, "/test", http.NoBody) if err != nil { t.Fatal(err) } group.ServeHTTP(recorder, request) if recorder.Code != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, recorder.Code) } if body := recorder.Body.String(); body != "test handler" { t.Errorf("Expected body 'test handler', got '%s'", body) } }) t.Run("handle, wrong method -> 405", func(t *testing.T) { recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodPost, "/test2", http.NoBody) if err != nil { t.Fatal(err) } group.ServeHTTP(recorder, request) if recorder.Code != http.StatusMethodNotAllowed { t.Errorf("Expected status code %d, got %d", http.StatusMethodNotAllowed, recorder.Code) } if allow := recorder.Header().Get("Allow"); !strings.Contains(allow, http.MethodGet) { t.Errorf("expected Allow header to contain GET, got %q", 
allow) } }) t.Run("handler", func(t *testing.T) { recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, "/test2", http.NoBody) if err != nil { t.Fatal(err) } group.ServeHTTP(recorder, request) if recorder.Code != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, recorder.Code) } if body := recorder.Body.String(); body != "test2 handler" { t.Errorf("Expected body 'test2 handler', got '%s'", body) } }) } func TestBundleHandler(t *testing.T) { group := routegroup.New(http.NewServeMux()) group.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) t.Run("handler returns correct pattern and handler", func(t *testing.T) { request, err := http.NewRequest(http.MethodGet, "/test", http.NoBody) if err != nil { t.Fatal(err) } handler, pattern := group.Handler(request) if handler == nil { t.Error("Expected handler to be not nil") } if pattern != "/test" { t.Errorf("Expected pattern '/test', got '%s'", pattern) } }) t.Run("handler returns not-nil and empty pattern for non-existing route", func(t *testing.T) { request, err := http.NewRequest(http.MethodGet, "/non-existing", http.NoBody) if err != nil { t.Fatal(err) } handler, pattern := group.Handler(request) if handler == nil { t.Error("Expected handler to be not nil") } if pattern != "" { t.Errorf("Expected empty pattern, got '%s'", pattern) } }) } func TestGroupRoute(t *testing.T) { group := routegroup.New(http.NewServeMux()) group.Route(func(g *routegroup.Bundle) { g.Use(testMiddleware) g.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) g.HandleFunc("POST /test2", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) }) t.Run("GET /test", func(t *testing.T) { recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, "/test", http.NoBody) if err != nil { t.Fatal(err) } group.ServeHTTP(recorder, request) if recorder.Code != 
http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, recorder.Code) } if header := recorder.Header().Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("POST /test2", func(t *testing.T) { recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodPost, "/test2", http.NoBody) if err != nil { t.Fatal(err) } group.ServeHTTP(recorder, request) if recorder.Code != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, recorder.Code) } if header := recorder.Header().Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("GET /test2 wrong method -> 405", func(t *testing.T) { recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, "/test2", http.NoBody) if err != nil { t.Fatal(err) } group.ServeHTTP(recorder, request) if recorder.Code != http.StatusMethodNotAllowed { t.Errorf("Expected status code %d, got %d", http.StatusMethodNotAllowed, recorder.Code) } if allow := recorder.Header().Get("Allow"); allow != http.MethodPost { t.Errorf("expected Allow header to be POST, got %q", allow) } // with auto-wrapping, middleware is in a sub-group and doesn't apply to 405s if header := recorder.Header().Get("X-Test-Middleware"); header != "" { t.Errorf("Expected header X-Test-Middleware to be empty for 405 (group middleware), got '%s'", header) } }) } func TestGroupRouteAutoWrapping(t *testing.T) { // test that calling Route on root bundle auto-creates a group router := routegroup.New(http.NewServeMux()) // add middleware to router first router.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Root-Middleware", "true") next.ServeHTTP(w, r) }) }) // calling Route on root should auto-create a group router.Route(func(g *routegroup.Bundle) { g.Use(func(next 
http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Group-Middleware", "true") next.ServeHTTP(w, r) }) }) g.HandleFunc("/grouped", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("grouped handler")) }) }) // add another route directly to root router.HandleFunc("/root", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("root handler")) }) t.Run("grouped route has both middlewares", func(t *testing.T) { recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, "/grouped", http.NoBody) if err != nil { t.Fatal(err) } router.ServeHTTP(recorder, request) if recorder.Code != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, recorder.Code) } if body := recorder.Body.String(); body != "grouped handler" { t.Errorf("Expected body 'grouped handler', got '%s'", body) } if header := recorder.Header().Get("X-Root-Middleware"); header != "true" { t.Errorf("Expected X-Root-Middleware to be 'true', got '%s'", header) } if header := recorder.Header().Get("X-Group-Middleware"); header != "true" { t.Errorf("Expected X-Group-Middleware to be 'true', got '%s'", header) } }) t.Run("root route only has root middleware", func(t *testing.T) { recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, "/root", http.NoBody) if err != nil { t.Fatal(err) } router.ServeHTTP(recorder, request) if recorder.Code != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, recorder.Code) } if body := recorder.Body.String(); body != "root handler" { t.Errorf("Expected body 'root handler', got '%s'", body) } if header := recorder.Header().Get("X-Root-Middleware"); header != "true" { t.Errorf("Expected X-Root-Middleware to be 'true', got '%s'", header) } if header := recorder.Header().Get("X-Group-Middleware"); header != "" { t.Errorf("Expected X-Group-Middleware to be empty, got '%s'", header) } }) } func 
TestGroupWithMiddleware(t *testing.T) { group := routegroup.New(http.NewServeMux()) group.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("X-Original-Middleware", "true") next.ServeHTTP(w, r) }) }) newGroup := group.With(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("X-New-Middleware", "true") next.ServeHTTP(w, r) }) }) newGroup.HandleFunc("/with-test", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) newGroup.HandleFunc("POST /with-test-post-only", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) t.Run("GET /with-test", func(t *testing.T) { recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, "/with-test", http.NoBody) if err != nil { t.Fatal(err) } group.ServeHTTP(recorder, request) if recorder.Code != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, recorder.Code) } if header := recorder.Header().Get("X-Original-Middleware"); header != "true" { t.Errorf("Expected header X-Original-Middleware to be 'true', got '%s'", header) } if header := recorder.Header().Get("X-New-Middleware"); header != "true" { t.Errorf("Expected header X-New-Middleware to be 'true', got '%s'", header) } }) t.Run("POST /with-test", func(t *testing.T) { recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodPost, "/with-test", http.NoBody) if err != nil { t.Fatal(err) } group.ServeHTTP(recorder, request) if recorder.Code != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, recorder.Code) } if header := recorder.Header().Get("X-Original-Middleware"); header != "true" { t.Errorf("Expected header X-Original-Middleware to be 'true', got '%s'", header) } if header := recorder.Header().Get("X-New-Middleware"); header != "true" { t.Errorf("Expected header X-New-Middleware to be 
'true', got '%s'", header) } }) t.Run("GET /not-found", func(t *testing.T) { recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, "/not-found", http.NoBody) if err != nil { t.Fatal(err) } group.ServeHTTP(recorder, request) if recorder.Code != http.StatusNotFound { t.Errorf("Expected status code %d, got %d", http.StatusNotFound, recorder.Code) } if header := recorder.Header().Get("X-Original-Middleware"); header != "true" { t.Errorf("Expected header X-Original-Middleware to be 'true', got '%s'", header) } if header := recorder.Header().Get("X-New-Middleware"); header != "" { t.Errorf("Expected header X-New-Middleware to be not set, got '%s'", header) } }) t.Run("POST /with-test-post-only", func(t *testing.T) { recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodPost, "/with-test-post-only", http.NoBody) if err != nil { t.Fatal(err) } group.ServeHTTP(recorder, request) if recorder.Code != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, recorder.Code) } if header := recorder.Header().Get("X-Original-Middleware"); header != "true" { t.Errorf("Expected header X-Original-Middleware to be 'true', got '%s'", header) } if header := recorder.Header().Get("X-New-Middleware"); header != "true" { t.Errorf("Expected header X-New-Middleware to be 'true', got '%s'", header) } }) t.Run("GET /with-test-post-only wrong method -> 405", func(t *testing.T) { recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, "/with-test-post-only", http.NoBody) if err != nil { t.Fatal(err) } group.ServeHTTP(recorder, request) if recorder.Code != http.StatusMethodNotAllowed { t.Errorf("Expected status code %d, got %d", http.StatusMethodNotAllowed, recorder.Code) } if allow := recorder.Header().Get("Allow"); allow != http.MethodPost { t.Errorf("expected Allow header to be POST, got %q", allow) } if header := recorder.Header().Get("X-Original-Middleware"); header != "true" { t.Errorf("Expected 
header X-Original-Middleware to be 'true', got '%s'", header) } if header := recorder.Header().Get("X-New-Middleware"); header != "" { t.Errorf("Expected header X-New-Middleware to be not set, got '%s'", header) } }) } func TestGroupWithMiddlewareAndTopLevelAfter(t *testing.T) { group := routegroup.New(http.NewServeMux()) // create subgroup and register route sub := group.Group() sub.Use(testMiddleware) sub.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("test handler")) }) // calling Use on the same subgroup after routes should panic defer func() { if r := recover(); r == nil { t.Fatalf("expected panic on Use after routes registration on the same bundle") } }() sub.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("X-Top-Middleware", "true") next.ServeHTTP(w, r) }) }) } // Test that calling Use after routes are registered on the same bundle panics, // and that calling Use on a parent after child routes is allowed. 
func TestUseAfterRoutesPanicsAndParentAllowed(t *testing.T) { t.Run("root: Use after route panics", func(t *testing.T) { router := routegroup.New(http.NewServeMux()) router.HandleFunc("/r", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) defer func() { if r := recover(); r == nil { t.Fatalf("expected panic on root.Use after routes registration on the same bundle") } }() router.Use(testMiddleware) }) t.Run("root: Use after Route() with auto-wrap panics", func(t *testing.T) { router := routegroup.New(http.NewServeMux()) router.Route(func(b *routegroup.Bundle) { b.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) {}) }) defer func() { if r := recover(); r == nil { t.Fatalf("expected panic on root.Use after Route() registered routes") } }() router.Use(testMiddleware) }) t.Run("parent: Use after child routes is allowed", func(t *testing.T) { router := routegroup.New(http.NewServeMux()) child := router.Group() child.HandleFunc("/child", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("ok")) }) // parent hasn't registered any routes yet; calling Use should not panic router.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Parent", "true") next.ServeHTTP(w, r) }) }) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/child", http.NoBody) router.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("unexpected status %d", rec.Code) } if hv := rec.Header().Get("X-Parent"); hv != "true" { t.Fatalf("expected global parent middleware to apply, got %q", hv) } }) } // DisableNotFoundHandler semantics are removed; global middlewares always apply. 
func TestGroupWithMoreMiddleware(t *testing.T) { group := routegroup.New(http.NewServeMux()) newGroup := group.With( func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("X-New-Middleware", "true") next.ServeHTTP(w, r) }) }, func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("X-More-Middleware", "true") next.ServeHTTP(w, r) }) }, ) newGroup.HandleFunc("/with-test", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, "/with-test", http.NoBody) if err != nil { t.Fatal(err) } group.ServeHTTP(recorder, request) if recorder.Code != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, recorder.Code) } if header := recorder.Header().Get("X-New-Middleware"); header != "true" { t.Errorf("Expected header X-New-Middleware to be 'true', got '%s'", header) } if header := recorder.Header().Get("X-More-Middleware"); header != "true" { t.Errorf("Expected header X-More-Middleware to be 'true', got '%s'", header) } } func TestMiddlewareOrder(t *testing.T) { var order []string mkMiddleware := func(name string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { order = append(order, "before "+name) next.ServeHTTP(w, r) order = append(order, "after "+name) }) } } group := routegroup.New(http.NewServeMux()) group.Use(mkMiddleware("root")) api := group.Mount("/api") api.Use(mkMiddleware("api")) users := api.With(mkMiddleware("users")) users.HandleFunc("/action", func(w http.ResponseWriter, _ *http.Request) { order = append(order, "handler") _, _ = w.Write([]byte("ok")) }) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/action", http.NoBody) group.ServeHTTP(rec, req) expected := 
[]string{ "before root", "before api", "before users", "handler", "after users", "after api", "after root", } if !reflect.DeepEqual(order, expected) { t.Errorf("wrong middleware execution order\nwant: %v\ngot: %v", expected, order) } } func TestConcurrentRequests(t *testing.T) { group := routegroup.New(http.NewServeMux()) group.HandleFunc("/concurrent", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() rec := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/concurrent", http.NoBody) group.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("got %d, want %d", rec.Code, http.StatusOK) } }() } wg.Wait() } func TestHTTPServerWrap(t *testing.T) { mw1 := func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-MW1", "1") h.ServeHTTP(w, r) }) } mw2 := func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-MW2", "2") h.ServeHTTP(w, r) }) } handlers := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("test handler")) }) ts := httptest.NewServer(routegroup.Wrap(handlers, mw1, mw2)) defer ts.Close() resp, err := http.Get(ts.URL) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) } if header := resp.Header.Get("X-MW1"); header != "1" { t.Errorf("Expected header X-MW1 to be '1', got '%s'", header) } if header := resp.Header.Get("X-MW2"); header != "2" { t.Errorf("Expected header X-MW2 to be '2', got '%s'", header) } body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if string(body) != "test handler" { t.Errorf("Expected body 'test handler', got '%s'", string(body)) } } ================================================ FILE: 
middleware_test.go ================================================ package routegroup_test import ( "context" "net/http" "net/http/httptest" "sync" "testing" "github.com/go-pkgz/routegroup" ) func TestMiddlewareCanAccessPathValues(t *testing.T) { // test path value accessibility in middlewares // EXPECTED: root/global middlewares can't see PathValue (runs before routing) // EXPECTED: mounted group middlewares CAN see PathValue (applied at registration) tests := []struct { name string setupFunc func() *routegroup.Bundle requestPath string expectedID string expectedUser string }{ { name: "root middleware cannot access path params (expected)", setupFunc: func() *routegroup.Bundle { rtr := routegroup.New(http.NewServeMux()) // root middleware runs BEFORE mux.ServeHTTP sets path values rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // PathValue is empty here - this is EXPECTED id := r.PathValue("id") w.Header().Set("X-Root-Middleware-ID", id) // will be empty // but Pattern IS available (our fix from #24) w.Header().Set("X-Root-Pattern", r.Pattern) next.ServeHTTP(w, r) }) }) rtr.HandleFunc("GET /users/{id}", func(w http.ResponseWriter, r *http.Request) { // handler CAN access path values w.Header().Set("X-Handler-ID", r.PathValue("id")) w.WriteHeader(http.StatusOK) }) return rtr }, requestPath: "/users/123", expectedID: "", // empty in root middleware is EXPECTED }, { name: "mounted group with path params", setupFunc: func() *routegroup.Bundle { rtr := routegroup.New(http.NewServeMux()) api := rtr.Mount("/api") api.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // path values should be accessible in mounted group middleware id := r.PathValue("id") user := r.PathValue("user") if id != "" { w.Header().Set("X-Middleware-ID", id) } if user != "" { w.Header().Set("X-Middleware-User", user) } next.ServeHTTP(w, r) }) }) api.HandleFunc("GET 
/users/{user}/posts/{id}", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) return rtr }, requestPath: "/api/users/john/posts/456", expectedID: "456", expectedUser: "john", }, { name: "nested mounted groups with params", setupFunc: func() *routegroup.Bundle { rtr := routegroup.New(http.NewServeMux()) v1 := rtr.Mount("/v1") users := v1.Mount("/users") users.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") action := r.PathValue("action") if id != "" { w.Header().Set("X-Middleware-ID", id) } if action != "" { w.Header().Set("X-Middleware-Action", action) } next.ServeHTTP(w, r) }) }) users.HandleFunc("GET /{id}/{action}", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) return rtr }, requestPath: "/v1/users/789/edit", expectedID: "789", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bundle := tt.setupFunc() req, err := http.NewRequest(http.MethodGet, tt.requestPath, http.NoBody) if err != nil { t.Fatal(err) } rec := httptest.NewRecorder() bundle.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("expected status 200, got %d", rec.Code) } // verify path value accessibility based on middleware type if tt.name == "root middleware cannot access path params (expected)" { // verify root middleware can't see path values if got := rec.Header().Get("X-Root-Middleware-ID"); got != "" { t.Errorf("root middleware should not see path values, got %q", got) } // but handler can if got := rec.Header().Get("X-Handler-ID"); got != "123" { t.Errorf("handler should see path value, got %q", got) } // and Pattern is available (from our fix) if got := rec.Header().Get("X-Root-Pattern"); got != "GET /users/{id}" { t.Errorf("root middleware should see pattern, got %q", got) } } else { // mounted group middlewares CAN see path values if tt.expectedID != "" { if got := rec.Header().Get("X-Middleware-ID"); got != 
tt.expectedID { t.Errorf("middleware ID = %q, want %q", got, tt.expectedID) } } if tt.expectedUser != "" { if got := rec.Header().Get("X-Middleware-User"); got != tt.expectedUser { t.Errorf("middleware User = %q, want %q", got, tt.expectedUser) } } } }) } } func TestMiddlewareAbortChain(t *testing.T) { // test that middleware can stop the chain by not calling next.ServeHTTP() // this is critical for auth/security middleware t.Run("auth middleware aborts on unauthorized", func(t *testing.T) { handlerCalled := false middleware2Called := false rtr := routegroup.New(http.NewServeMux()) // first middleware - auth check that aborts rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Auth-Checked", "true") if r.Header.Get("Authorization") == "" { // abort chain - don't call next w.WriteHeader(http.StatusUnauthorized) _, _ = w.Write([]byte("unauthorized")) return } next.ServeHTTP(w, r) }) }) // second middleware - should not be called if first aborts rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { middleware2Called = true w.Header().Set("X-Middleware2", "called") next.ServeHTTP(w, r) }) }) rtr.HandleFunc("GET /protected", func(w http.ResponseWriter, r *http.Request) { handlerCalled = true w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("protected content")) }) // test unauthorized request req, _ := http.NewRequest(http.MethodGet, "/protected", http.NoBody) rec := httptest.NewRecorder() rtr.ServeHTTP(rec, req) if rec.Code != http.StatusUnauthorized { t.Errorf("expected 401, got %d", rec.Code) } if rec.Body.String() != "unauthorized" { t.Errorf("expected 'unauthorized', got %q", rec.Body.String()) } if handlerCalled { t.Error("handler should not be called when middleware aborts") } if middleware2Called { t.Error("second middleware should not be called when first aborts") } if rec.Header().Get("X-Auth-Checked") != "true" 
{ t.Error("first middleware should have run") } if rec.Header().Get("X-Middleware2") != "" { t.Error("second middleware should not have set header") } }) t.Run("middleware abort with mounted groups", func(t *testing.T) { handlerCalled := false rtr := routegroup.New(http.NewServeMux()) // root middleware - always passes rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Root", "passed") next.ServeHTTP(w, r) }) }) api := rtr.Mount("/api") // api middleware - aborts on missing API key api.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-API-Check", "true") if r.Header.Get("X-API-Key") == "" { w.WriteHeader(http.StatusForbidden) _, _ = w.Write([]byte("API key required")) return // abort } next.ServeHTTP(w, r) }) }) api.HandleFunc("GET /data", func(w http.ResponseWriter, r *http.Request) { handlerCalled = true w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("data")) }) // test without API key req, _ := http.NewRequest(http.MethodGet, "/api/data", http.NoBody) rec := httptest.NewRecorder() rtr.ServeHTTP(rec, req) if rec.Code != http.StatusForbidden { t.Errorf("expected 403, got %d", rec.Code) } if rec.Body.String() != "API key required" { t.Errorf("expected 'API key required', got %q", rec.Body.String()) } if handlerCalled { t.Error("handler should not be called when middleware aborts") } // root middleware should have run if rec.Header().Get("X-Root") != "passed" { t.Error("root middleware should have run") } // api middleware should have run and aborted if rec.Header().Get("X-API-Check") != "true" { t.Error("API middleware should have run") } }) t.Run("middleware chain continues with authorization", func(t *testing.T) { handlerCalled := false rtr := routegroup.New(http.NewServeMux()) // auth middleware that passes with correct header rtr.Use(func(next http.Handler) http.Handler { return 
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Auth-Checked", "true") if r.Header.Get("Authorization") == "" { w.WriteHeader(http.StatusUnauthorized) return } next.ServeHTTP(w, r) }) }) // second middleware rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Middleware2", "called") next.ServeHTTP(w, r) }) }) rtr.HandleFunc("GET /protected", func(w http.ResponseWriter, r *http.Request) { handlerCalled = true w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("protected content")) }) // test authorized request req, _ := http.NewRequest(http.MethodGet, "/protected", http.NoBody) req.Header.Set("Authorization", "Bearer token") rec := httptest.NewRecorder() rtr.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("expected 200, got %d", rec.Code) } if rec.Body.String() != "protected content" { t.Errorf("expected 'protected content', got %q", rec.Body.String()) } if !handlerCalled { t.Error("handler should be called with authorization") } if rec.Header().Get("X-Auth-Checked") != "true" { t.Error("auth middleware should have run") } if rec.Header().Get("X-Middleware2") != "called" { t.Error("second middleware should have run") } }) } func TestWithMethodMiddlewareCounting(t *testing.T) { // test that With() properly tracks middleware count to avoid double execution // this is critical to prevent issues like #24 t.Run("With() creates new bundle with correct middleware count", func(t *testing.T) { callCounts := make(map[string]int) rtr := routegroup.New(http.NewServeMux()) // root middleware rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCounts["root"]++ w.Header().Set("X-MW-Order", "root") next.ServeHTTP(w, r) }) }) // create new bundle with additional middleware using With() withBundle := rtr.With(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { callCounts["with"]++ existing := w.Header().Get("X-MW-Order") w.Header().Set("X-MW-Order", existing+",with") next.ServeHTTP(w, r) }) }) // register route on the With bundle withBundle.HandleFunc("GET /test", func(w http.ResponseWriter, r *http.Request) { callCounts["handler"]++ w.WriteHeader(http.StatusOK) }) req, _ := http.NewRequest(http.MethodGet, "/test", http.NoBody) rec := httptest.NewRecorder() rtr.ServeHTTP(rec, req) // verify each middleware called exactly once if callCounts["root"] != 1 { t.Errorf("root middleware called %d times, expected 1", callCounts["root"]) } if callCounts["with"] != 1 { t.Errorf("with middleware called %d times, expected 1", callCounts["with"]) } if callCounts["handler"] != 1 { t.Errorf("handler called %d times, expected 1", callCounts["handler"]) } // verify order if order := rec.Header().Get("X-MW-Order"); order != "root,with" { t.Errorf("middleware order = %q, expected 'root,with'", order) } }) t.Run("multiple With() calls maintain proper count", func(t *testing.T) { callCounts := make(map[string]int) rtr := routegroup.New(http.NewServeMux()) rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCounts["root"]++ next.ServeHTTP(w, r) }) }) // first With() bundle1 := rtr.With(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCounts["with1"]++ next.ServeHTTP(w, r) }) }) // second With() on top of first bundle2 := bundle1.With(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCounts["with2"]++ next.ServeHTTP(w, r) }) }) bundle2.HandleFunc("GET /nested", func(w http.ResponseWriter, r *http.Request) { callCounts["handler"]++ w.WriteHeader(http.StatusOK) }) req, _ := http.NewRequest(http.MethodGet, "/nested", http.NoBody) rec := httptest.NewRecorder() rtr.ServeHTTP(rec, req) // all middlewares should be 
called exactly once if callCounts["root"] != 1 { t.Errorf("root middleware called %d times, expected 1", callCounts["root"]) } if callCounts["with1"] != 1 { t.Errorf("with1 middleware called %d times, expected 1", callCounts["with1"]) } if callCounts["with2"] != 1 { t.Errorf("with2 middleware called %d times, expected 1", callCounts["with2"]) } if callCounts["handler"] != 1 { t.Errorf("handler called %d times, expected 1", callCounts["handler"]) } }) t.Run("With() on mounted group maintains correct count", func(t *testing.T) { callCounts := make(map[string]int) rtr := routegroup.New(http.NewServeMux()) // root middleware rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCounts["root"]++ next.ServeHTTP(w, r) }) }) // mount a group api := rtr.Mount("/api") // add middleware to mounted group api.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCounts["api"]++ next.ServeHTTP(w, r) }) }) // use With() on mounted group apiWith := api.With(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCounts["api-with"]++ next.ServeHTTP(w, r) }) }) apiWith.HandleFunc("GET /data", func(w http.ResponseWriter, r *http.Request) { callCounts["handler"]++ w.WriteHeader(http.StatusOK) }) req, _ := http.NewRequest(http.MethodGet, "/api/data", http.NoBody) rec := httptest.NewRecorder() rtr.ServeHTTP(rec, req) // verify no double execution if callCounts["root"] != 1 { t.Errorf("root middleware called %d times, expected 1", callCounts["root"]) } if callCounts["api"] != 1 { t.Errorf("api middleware called %d times, expected 1", callCounts["api"]) } if callCounts["api-with"] != 1 { t.Errorf("api-with middleware called %d times, expected 1", callCounts["api-with"]) } if callCounts["handler"] != 1 { t.Errorf("handler called %d times, expected 1", callCounts["handler"]) } }) } func 
TestResponseWriterInterception(t *testing.T) { // test that middlewares can intercept and modify responses // this is critical for logging, metrics, response manipulation t.Run("middleware can capture status code", func(t *testing.T) { var capturedStatus int rtr := routegroup.New(http.NewServeMux()) // status capturing middleware rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // wrap response writer to capture status wrapped := &statusRecorder{ResponseWriter: w, status: http.StatusOK} next.ServeHTTP(wrapped, r) capturedStatus = wrapped.status }) }) rtr.HandleFunc("GET /success", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusCreated) _, _ = w.Write([]byte("created")) }) rtr.HandleFunc("GET /error", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) _, _ = w.Write([]byte("error")) }) // test success response req1, _ := http.NewRequest(http.MethodGet, "/success", http.NoBody) rec1 := httptest.NewRecorder() rtr.ServeHTTP(rec1, req1) if capturedStatus != http.StatusCreated { t.Errorf("captured status = %d, want %d", capturedStatus, http.StatusCreated) } if rec1.Code != http.StatusCreated { t.Errorf("response status = %d, want %d", rec1.Code, http.StatusCreated) } // test error response req2, _ := http.NewRequest(http.MethodGet, "/error", http.NoBody) rec2 := httptest.NewRecorder() rtr.ServeHTTP(rec2, req2) if capturedStatus != http.StatusInternalServerError { t.Errorf("captured status = %d, want %d", capturedStatus, http.StatusInternalServerError) } }) t.Run("middleware can modify response body", func(t *testing.T) { rtr := routegroup.New(http.NewServeMux()) // response modifying middleware rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // capture response in buffer buf := &responseBuffer{ResponseWriter: w, buffer: []byte{}} next.ServeHTTP(buf, r) // modify and 
write actual response modified := append([]byte("PREFIX:"), buf.buffer...) _, _ = w.Write(modified) }) }) rtr.HandleFunc("GET /test", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("original")) }) req, _ := http.NewRequest(http.MethodGet, "/test", http.NoBody) rec := httptest.NewRecorder() rtr.ServeHTTP(rec, req) if body := rec.Body.String(); body != "PREFIX:original" { t.Errorf("body = %q, want %q", body, "PREFIX:original") } }) t.Run("multiple middlewares can wrap response writer", func(t *testing.T) { var statuses []int rtr := routegroup.New(http.NewServeMux()) // first middleware - captures status rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { wrapped := &statusRecorder{ResponseWriter: w, status: http.StatusOK} next.ServeHTTP(wrapped, r) statuses = append(statuses, wrapped.status) }) }) // second middleware - also captures status rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { wrapped := &statusRecorder{ResponseWriter: w, status: http.StatusOK} next.ServeHTTP(wrapped, r) statuses = append(statuses, wrapped.status) }) }) rtr.HandleFunc("GET /test", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusAccepted) }) req, _ := http.NewRequest(http.MethodGet, "/test", http.NoBody) rec := httptest.NewRecorder() statuses = nil rtr.ServeHTTP(rec, req) // both middlewares should capture the same status if len(statuses) != 2 { t.Errorf("expected 2 status captures, got %d", len(statuses)) } if len(statuses) == 2 { if statuses[0] != http.StatusAccepted { t.Errorf("first middleware captured %d, want %d", statuses[0], http.StatusAccepted) } if statuses[1] != http.StatusAccepted { t.Errorf("second middleware captured %d, want %d", statuses[1], http.StatusAccepted) } } }) } // statusRecorder wraps ResponseWriter to capture status code type statusRecorder struct { http.ResponseWriter status int 
written bool } func (r *statusRecorder) WriteHeader(status int) { if !r.written { r.status = status r.written = true } r.ResponseWriter.WriteHeader(status) } func (r *statusRecorder) Write(b []byte) (int, error) { if !r.written { r.written = true // status remains default (200) if WriteHeader wasn't called } return r.ResponseWriter.Write(b) } // responseBuffer captures response body type responseBuffer struct { http.ResponseWriter buffer []byte } func (b *responseBuffer) Write(data []byte) (int, error) { b.buffer = append(b.buffer, data...) return len(data), nil } func TestContextPropagation(t *testing.T) { // test that context values propagate through middleware chain // this is critical for request IDs, user auth, tracing, etc. type contextKey string const ( requestIDKey contextKey = "requestID" userIDKey contextKey = "userID" traceKey contextKey = "trace" ) t.Run("context values propagate through chain", func(t *testing.T) { rtr := routegroup.New(http.NewServeMux()) // first middleware - adds request ID rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), requestIDKey, "req-123") next.ServeHTTP(w, r.WithContext(ctx)) }) }) // second middleware - adds user ID and verifies request ID rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // verify request ID from first middleware if reqID := r.Context().Value(requestIDKey); reqID != "req-123" { t.Errorf("middleware 2: requestID = %v, want 'req-123'", reqID) } ctx := context.WithValue(r.Context(), userIDKey, "user-456") next.ServeHTTP(w, r.WithContext(ctx)) }) }) // handler - verifies both values rtr.HandleFunc("GET /test", func(w http.ResponseWriter, r *http.Request) { reqID := r.Context().Value(requestIDKey) userID := r.Context().Value(userIDKey) if reqID != "req-123" { t.Errorf("handler: requestID = %v, want 'req-123'", reqID) } if userID != 
"user-456" { t.Errorf("handler: userID = %v, want 'user-456'", userID) } w.WriteHeader(http.StatusOK) }) req, _ := http.NewRequest(http.MethodGet, "/test", http.NoBody) rec := httptest.NewRecorder() rtr.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) } }) t.Run("context cancellation stops chain", func(t *testing.T) { handlerCalled := false middleware2Called := false rtr := routegroup.New(http.NewServeMux()) // first middleware - cancels context rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithCancel(r.Context()) cancel() // immediately cancel // check if context is done before calling next select { case <-ctx.Done(): w.WriteHeader(http.StatusServiceUnavailable) return default: next.ServeHTTP(w, r.WithContext(ctx)) } }) }) // second middleware - should not be called rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { middleware2Called = true next.ServeHTTP(w, r) }) }) rtr.HandleFunc("GET /test", func(w http.ResponseWriter, r *http.Request) { handlerCalled = true w.WriteHeader(http.StatusOK) }) req, _ := http.NewRequest(http.MethodGet, "/test", http.NoBody) rec := httptest.NewRecorder() rtr.ServeHTTP(rec, req) if rec.Code != http.StatusServiceUnavailable { t.Errorf("status = %d, want %d", rec.Code, http.StatusServiceUnavailable) } if middleware2Called { t.Error("middleware 2 should not be called after context cancellation") } if handlerCalled { t.Error("handler should not be called after context cancellation") } }) t.Run("context values work with mounted groups", func(t *testing.T) { rtr := routegroup.New(http.NewServeMux()) // root middleware - adds trace ID rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), traceKey, "trace-root") 
next.ServeHTTP(w, r.WithContext(ctx)) }) }) api := rtr.Mount("/api") // api middleware - adds request ID and checks trace api.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // verify trace from root if trace := r.Context().Value(traceKey); trace != "trace-root" { t.Errorf("api middleware: trace = %v, want 'trace-root'", trace) } ctx := context.WithValue(r.Context(), requestIDKey, "req-api") next.ServeHTTP(w, r.WithContext(ctx)) }) }) api.HandleFunc("GET /data", func(w http.ResponseWriter, r *http.Request) { trace := r.Context().Value(traceKey) reqID := r.Context().Value(requestIDKey) if trace != "trace-root" { t.Errorf("handler: trace = %v, want 'trace-root'", trace) } if reqID != "req-api" { t.Errorf("handler: requestID = %v, want 'req-api'", reqID) } w.WriteHeader(http.StatusOK) }) req, _ := http.NewRequest(http.MethodGet, "/api/data", http.NoBody) rec := httptest.NewRecorder() rtr.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) } }) } func TestMiddlewareModifiesRequest(t *testing.T) { // test that middlewares can modify request properties // this is critical for adding headers, modifying paths, etc. 
t.Run("middleware can add and modify headers", func(t *testing.T) { var capturedHeaders http.Header rtr := routegroup.New(http.NewServeMux()) // first middleware - adds header rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.Header.Set("X-Request-ID", "req-123") r.Header.Set("X-Custom", "value1") next.ServeHTTP(w, r) }) }) // second middleware - modifies header rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // verify first middleware's header if reqID := r.Header.Get("X-Request-ID"); reqID != "req-123" { t.Errorf("middleware 2: X-Request-ID = %q, want 'req-123'", reqID) } // modify existing header r.Header.Set("X-Custom", "value2") r.Header.Set("X-Middleware-2", "processed") next.ServeHTTP(w, r) }) }) rtr.HandleFunc("GET /test", func(w http.ResponseWriter, r *http.Request) { capturedHeaders = r.Header.Clone() w.WriteHeader(http.StatusOK) }) req, _ := http.NewRequest(http.MethodGet, "/test", http.NoBody) req.Header.Set("X-Original", "client-value") rec := httptest.NewRecorder() rtr.ServeHTTP(rec, req) // verify headers in handler if capturedHeaders.Get("X-Original") != "client-value" { t.Errorf("X-Original = %q, want 'client-value'", capturedHeaders.Get("X-Original")) } if capturedHeaders.Get("X-Request-ID") != "req-123" { t.Errorf("X-Request-ID = %q, want 'req-123'", capturedHeaders.Get("X-Request-ID")) } if capturedHeaders.Get("X-Custom") != "value2" { t.Errorf("X-Custom = %q, want 'value2' (should be modified)", capturedHeaders.Get("X-Custom")) } if capturedHeaders.Get("X-Middleware-2") != "processed" { t.Errorf("X-Middleware-2 = %q, want 'processed'", capturedHeaders.Get("X-Middleware-2")) } }) t.Run("middleware can modify URL path", func(t *testing.T) { var capturedPath string rtr := routegroup.New(http.NewServeMux()) // middleware that modifies URL rtr.Use(func(next http.Handler) http.Handler { return 
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // create a new URL with modified path newURL := *r.URL newURL.Path = "/modified" + r.URL.Path // create new request with modified URL r2 := r.Clone(r.Context()) r2.URL = &newURL next.ServeHTTP(w, r2) }) }) // register handler for modified path rtr.HandleFunc("GET /modified/original", func(w http.ResponseWriter, r *http.Request) { capturedPath = r.URL.Path w.WriteHeader(http.StatusOK) }) req, _ := http.NewRequest(http.MethodGet, "/original", http.NoBody) rec := httptest.NewRecorder() rtr.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) } if capturedPath != "/modified/original" { t.Errorf("captured path = %q, want '/modified/original'", capturedPath) } }) t.Run("middleware modifications work with mounted groups", func(t *testing.T) { var handlerHeaders http.Header rtr := routegroup.New(http.NewServeMux()) // root middleware - adds base header rtr.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.Header.Set("X-Root", "root-value") next.ServeHTTP(w, r) }) }) api := rtr.Mount("/api") // api middleware - adds API header api.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // verify root header exists if root := r.Header.Get("X-Root"); root != "root-value" { t.Errorf("api middleware: X-Root = %q, want 'root-value'", root) } r.Header.Set("X-API", "api-value") r.Header.Set("X-API-Version", "v1") next.ServeHTTP(w, r) }) }) api.HandleFunc("GET /data", func(w http.ResponseWriter, r *http.Request) { handlerHeaders = r.Header.Clone() w.WriteHeader(http.StatusOK) }) req, _ := http.NewRequest(http.MethodGet, "/api/data", http.NoBody) rec := httptest.NewRecorder() rtr.ServeHTTP(rec, req) // verify all headers made it to handler if handlerHeaders.Get("X-Root") != "root-value" { t.Errorf("X-Root = %q, want 'root-value'", 
handlerHeaders.Get("X-Root")) } if handlerHeaders.Get("X-API") != "api-value" { t.Errorf("X-API = %q, want 'api-value'", handlerHeaders.Get("X-API")) } if handlerHeaders.Get("X-API-Version") != "v1" { t.Errorf("X-API-Version = %q, want 'v1'", handlerHeaders.Get("X-API-Version")) } }) } func TestRequestPatternAndMiddlewareCallCount(t *testing.T) { // regression test for issue #24 - verify that: // 1. middlewares are not executed twice // 2. Request.Pattern is available in global middlewares var callCount map[string]int var mu sync.Mutex patternLogger := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mu.Lock() callCount[r.URL.Path]++ patternBefore := r.Pattern mu.Unlock() next.ServeHTTP(w, r) mu.Lock() patternAfter := r.Pattern mu.Unlock() // verify pattern is set before calling next handler if patternBefore == "" { t.Errorf("pattern should be set before ServeHTTP, got empty for path %s", r.URL.Path) } // verify pattern remains consistent if patternBefore != patternAfter { t.Errorf("pattern changed from %q to %q for path %s", patternBefore, patternAfter, r.URL.Path) } }) } handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) id := r.PathValue("id") if id != "" { _, _ = w.Write([]byte("id: " + id)) } } t.Run("root group with path params", func(t *testing.T) { callCount = make(map[string]int) rtr := routegroup.New(http.NewServeMux()) rtr.Use(patternLogger) rtr.HandleFunc("GET /a/{id}", handler) recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, "/a/123", http.NoBody) if err != nil { t.Fatal(err) } rtr.ServeHTTP(recorder, request) if recorder.Code != http.StatusOK { t.Errorf("expected status 200, got %d", recorder.Code) } if body := recorder.Body.String(); body != "id: 123" { t.Errorf("expected 'id: 123', got %q", body) } // verify middleware was called exactly once if count := callCount["/a/123"]; count != 1 { t.Errorf("middleware should be 
called exactly once, but was called %d times", count) } }) t.Run("mounted group with path params", func(t *testing.T) { callCount = make(map[string]int) rtr := routegroup.New(http.NewServeMux()) rtr.Use(patternLogger) bGroup := rtr.Mount("/b") bGroup.HandleFunc("GET /{id}", handler) recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, "/b/456", http.NoBody) if err != nil { t.Fatal(err) } rtr.ServeHTTP(recorder, request) if recorder.Code != http.StatusOK { t.Errorf("expected status 200, got %d", recorder.Code) } if body := recorder.Body.String(); body != "id: 456" { t.Errorf("expected 'id: 456', got %q", body) } // verify middleware was called exactly once if count := callCount["/b/456"]; count != 1 { t.Errorf("middleware should be called exactly once for mounted path, but was called %d times", count) } }) t.Run("multiple middlewares see pattern", func(t *testing.T) { callCount = make(map[string]int) var patterns []string var mu2 sync.Mutex middleware1 := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mu.Lock() callCount["m1"]++ mu.Unlock() mu2.Lock() patterns = append(patterns, "m1:"+r.Pattern) mu2.Unlock() next.ServeHTTP(w, r) }) } middleware2 := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mu.Lock() callCount["m2"]++ mu.Unlock() mu2.Lock() patterns = append(patterns, "m2:"+r.Pattern) mu2.Unlock() next.ServeHTTP(w, r) }) } rtr := routegroup.New(http.NewServeMux()) rtr.Use(middleware1, middleware2) rtr.HandleFunc("GET /test/{id}", handler) recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, "/test/789", http.NoBody) if err != nil { t.Fatal(err) } rtr.ServeHTTP(recorder, request) // verify each middleware called once if count := callCount["m1"]; count != 1 { t.Errorf("middleware1 should be called once, got %d", count) } if count := callCount["m2"]; count != 1 { 
t.Errorf("middleware2 should be called once, got %d", count) } // verify both middlewares saw the pattern if len(patterns) != 2 { t.Errorf("expected 2 pattern records, got %d", len(patterns)) } if len(patterns) == 2 { if patterns[0] != "m1:GET /test/{id}" { t.Errorf("middleware1 pattern = %q, want %q", patterns[0], "m1:GET /test/{id}") } if patterns[1] != "m2:GET /test/{id}" { t.Errorf("middleware2 pattern = %q, want %q", patterns[1], "m2:GET /test/{id}") } } }) t.Run("route without path params", func(t *testing.T) { callCount = make(map[string]int) var seenPattern string checkPattern := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mu.Lock() callCount[r.URL.Path]++ seenPattern = r.Pattern mu.Unlock() next.ServeHTTP(w, r) }) } rtr := routegroup.New(http.NewServeMux()) rtr.Use(checkPattern) rtr.HandleFunc("GET /static", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("static")) }) recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, "/static", http.NoBody) if err != nil { t.Fatal(err) } rtr.ServeHTTP(recorder, request) if count := callCount["/static"]; count != 1 { t.Errorf("middleware should be called once for static route, got %d", count) } if seenPattern != "GET /static" { t.Errorf("pattern = %q, want %q", seenPattern, "GET /static") } }) } // TestRequestIsolation verifies that the original request passed to ServeHTTP // is not modified, ensuring proper isolation through shallow copy. 
func TestRequestIsolation(t *testing.T) { t.Run("original request not modified", func(t *testing.T) { router := routegroup.New(http.NewServeMux()) var middlewareRequest *http.Request // middleware that captures the request object router.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { middlewareRequest = r next.ServeHTTP(w, r) }) }) router.HandleFunc("GET /test/{id}", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) // create the original request originalRequest, err := http.NewRequest(http.MethodGet, "/test/123", http.NoBody) if err != nil { t.Fatal(err) } // save original state originalPattern := originalRequest.Pattern // make the request recorder := httptest.NewRecorder() router.ServeHTTP(recorder, originalRequest) // verify the original request was not modified if originalRequest.Pattern != originalPattern { t.Errorf("original request was modified: Pattern changed from %q to %q", originalPattern, originalRequest.Pattern) } // verify middleware received a different request object (shallow copy) if middlewareRequest == originalRequest { t.Error("middleware received the same request object (expected a copy)") } // verify middleware's request has the pattern set if middlewareRequest.Pattern == "" { t.Error("middleware's request should have Pattern set") } }) t.Run("isolation with 404", func(t *testing.T) { router := routegroup.New(http.NewServeMux()) router.NotFoundHandler(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) }) router.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Middleware", "ran") next.ServeHTTP(w, r) }) }) originalRequest, err := http.NewRequest(http.MethodGet, "/non-existent", http.NoBody) if err != nil { t.Fatal(err) } originalPattern := originalRequest.Pattern recorder := httptest.NewRecorder() router.ServeHTTP(recorder, 
originalRequest) if recorder.Code != http.StatusNotFound { t.Errorf("expected 404, got %d", recorder.Code) } // original request should not be modified if originalRequest.Pattern != originalPattern { t.Errorf("original request was modified: Pattern changed from %q to %q", originalPattern, originalRequest.Pattern) } }) } ================================================ FILE: mount_test.go ================================================ package routegroup_test import ( "io" "net/http" "net/http/httptest" "reflect" "testing" "github.com/go-pkgz/routegroup" ) func TestMount(t *testing.T) { basePath := "/api" group := routegroup.Mount(http.NewServeMux(), basePath) group.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("X-Mounted-Middleware", "true") next.ServeHTTP(w, r) }) }) group.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) recorder := httptest.NewRecorder() request, err := http.NewRequest(http.MethodGet, basePath+"/test", http.NoBody) if err != nil { t.Fatal(err) } group.ServeHTTP(recorder, request) if recorder.Code != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, recorder.Code) } if header := recorder.Header().Get("X-Mounted-Middleware"); header != "true" { t.Errorf("Expected header X-Mounted-Middleware to be 'true', got '%s'", header) } } func TestHTTPServerWithBasePathAndMiddleware(t *testing.T) { group := routegroup.Mount(http.NewServeMux(), "/api") group.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("X-Test-Middleware", "applied") next.ServeHTTP(w, r) }) }) group.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("test handler")) }) testServer := httptest.NewServer(group) defer testServer.Close() resp, err := http.Get(testServer.URL + "/api/test") if err != nil { t.Fatal(err) } 
defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if string(body) != "test handler" { t.Errorf("Expected body 'test handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Test-Middleware"); header != "applied" { t.Errorf("Expected header X-Test-Middleware to be 'applied', got '%s'", header) } } func TestHTTPServerWithBasePathNoMiddleware(t *testing.T) { group := routegroup.Mount(http.NewServeMux(), "/api") group.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("test handler")) }) testServer := httptest.NewServer(group) defer testServer.Close() resp, err := http.Get(testServer.URL + "/api/test") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if string(body) != "test handler" { t.Errorf("Expected body 'test handler', got '%s'", string(body)) } } func TestHTTPServerWithDerived(t *testing.T) { // create a new bundle with default middleware bundle := routegroup.New(http.NewServeMux()) bundle.NotFoundHandler(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusNotFound) _, _ = w.Write([]byte("not found handler")) }) bundle.Use(testMiddleware) // mount a group with additional middleware on /api group1 := bundle.Mount("/api") group1.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("X-API-Middleware", "applied") next.ServeHTTP(w, r) }) }) group1.HandleFunc("GET /test", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("GET test method handler")) }) group1.HandleFunc("POST /", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("POST api / method handler")) }) // add another group with middleware bundle.Group().Route(func(g *routegroup.Bundle) { g.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 
w.Header().Add("X-Blah-Middleware", "true") next.ServeHTTP(w, r) }) }) g.HandleFunc("GET /blah/blah", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("GET blah method handler")) }) }) // mount the bundle on /auth under /api group1.Mount("/auth").Route(func(g *routegroup.Bundle) { g.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("X-Auth-Middleware", "true") next.ServeHTTP(w, r) }) }) g.HandleFunc("GET /auth-test", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("GET auth-test method handler")) }) g.HandleFunc("GET /", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("auth GET / method handler")) }) }) testServer := httptest.NewServer(bundle) defer testServer.Close() t.Run("GET /api/test", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/api/test") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if string(body) != "GET test method handler" { t.Errorf("Expected body 'GET test method handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("GET /blah/blah", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/blah/blah") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if string(body) != "GET blah method handler" { t.Errorf("Expected body 'GET blah method handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Blah-Middleware"); header != "true" { t.Errorf("Expected header X-Blah-Middleware to be 'true', got '%s'", header) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("GET /api/auth/auth-test", func(t 
*testing.T) { resp, err := http.Get(testServer.URL + "/api/auth/auth-test") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if string(body) != "GET auth-test method handler" { t.Errorf("Expected body 'GET auth-test method handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Auth-Middleware"); header != "true" { t.Errorf("Expected header X-Auth-Middleware to be 'true', got '%s'", header) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("GET /api/auth/", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/api/auth/") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if string(body) != "auth GET / method handler" { t.Errorf("Expected body 'GET auth-test method handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Auth-Middleware"); header != "true" { t.Errorf("Expected header X-Auth-Middleware to be 'true', got '%s'", header) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("POST /api/", func(t *testing.T) { resp, err := http.Post(testServer.URL+"/api/", "application/json", http.NoBody) if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) } if string(body) != "POST api / method handler" { t.Errorf("Expected body 'GET auth-test method handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Auth-Middleware"); header != "" { t.Errorf("Expected header X-Auth-Middleware to be empty, got '%s'", header) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { 
t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("POST /api/not-found", func(t *testing.T) { resp, err := http.Post(testServer.URL+"/api/not-found", "application/json", http.NoBody) if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusNotFound { t.Errorf("Expected status code %d, got %d", http.StatusNotFound, resp.StatusCode) } if string(body) != "not found handler" { t.Errorf("Expected body '404 page not found', got '%s'", string(body)) } if header := resp.Header.Get("X-Auth-Middleware"); header != "" { t.Errorf("Expected header X-Auth-Middleware to be empty, got '%s'", header) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("GET /api/", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/api/") if err != nil { t.Fatal(err) } defer resp.Body.Close() // should return 405 Method Not Allowed since POST / is registered but not GET / if resp.StatusCode != http.StatusMethodNotAllowed { t.Errorf("Expected status code %d, got %d", http.StatusMethodNotAllowed, resp.StatusCode) } // 405 response should include Allow header if allowHeader := resp.Header.Get("Allow"); allowHeader == "" { t.Error("Expected Allow header for 405 response") } if header := resp.Header.Get("X-Auth-Middleware"); header != "" { t.Errorf("Expected header X-Auth-Middleware to be empty, got '%s'", header) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("GET /not-found", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/not-found") if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.StatusCode != http.StatusNotFound { t.Errorf("Expected status code %d, got %d", http.StatusNotFound, resp.StatusCode) } if header := 
resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) } func TestMountNested(t *testing.T) { bundle := routegroup.New(http.NewServeMux()) api := bundle.Mount("/api") v1 := api.Mount("/v1") v1.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) if _, err := w.Write([]byte("v1 test")); err != nil { t.Fatal(err) } }) rec := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/api/v1/test", http.NoBody) bundle.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("got %d, want %d", rec.Code, http.StatusOK) } if rec.Body.String() != "v1 test" { t.Errorf("got %q, want %q", rec.Body.String(), "v1 test") } } func TestMountPointMethodConflicts(t *testing.T) { group := routegroup.New(http.NewServeMux()) // register handler for /api directly group.HandleFunc("GET /api", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("api root")) }) // mount a group at /api api := group.Mount("/api") api.HandleFunc("/users", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("users")) }) srv := httptest.NewServer(group) defer srv.Close() t.Run("get /api root", func(t *testing.T) { resp, err := http.Get(srv.URL + "/api") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } if string(body) != "api root" { t.Errorf("expected 'api root', got %q", string(body)) } }) t.Run("get /api/users", func(t *testing.T) { resp, err := http.Get(srv.URL + "/api/users") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } if string(body) != "users" { t.Errorf("expected 'users', got %q", string(body)) } 
}) } func TestDeepNestedMounts(t *testing.T) { var callOrder []string mkMiddleware := func(name string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callOrder = append(callOrder, "before "+name) next.ServeHTTP(w, r) callOrder = append(callOrder, "after "+name) }) } } group := routegroup.New(http.NewServeMux()) group.Use(mkMiddleware("root")) v1 := group.Mount("/v1") v1.Use(mkMiddleware("v1")) api := v1.Mount("/api") api.Use(mkMiddleware("api")) users := api.Mount("/users") users.Use(mkMiddleware("users")) users.HandleFunc("/list", func(w http.ResponseWriter, _ *http.Request) { callOrder = append(callOrder, "handler") _, _ = w.Write([]byte("users list")) }) srv := httptest.NewServer(group) defer srv.Close() resp, err := http.Get(srv.URL + "/v1/api/users/list") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } if string(body) != "users list" { t.Errorf("expected 'users list', got %q", string(body)) } expected := []string{ "before root", "before v1", "before api", "before users", "handler", "after users", "after api", "after v1", "after root", } if !reflect.DeepEqual(callOrder, expected) { t.Errorf("middleware execution order mismatch\nwant: %v\ngot: %v", expected, callOrder) } } // TestSubgroupRootPathMatching tests that a subgroup with a root path pattern (/) // properly matches requests to the exact path without a trailing slash. 
func TestSubgroupRootPathMatching(t *testing.T) { mux := http.NewServeMux() router := routegroup.New(mux) // create a mounted group at /api/v1/users usersGroup := router.Mount("/api/v1/users") // add middleware to the group to test middleware invocation usersGroup.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Users-Middleware", "applied") next.ServeHTTP(w, r) }) }) // register handler for the root of the mounted group using "/" usersGroup.HandleFunc("GET /", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("users root")) }) // also add a child route for comparison usersGroup.HandleFunc("GET /list", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("users list")) }) srv := httptest.NewServer(router) defer srv.Close() t.Run("exact match without trailing slash", func(t *testing.T) { resp, err := http.Get(srv.URL + "/api/v1/users") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } if string(body) != "users root" { t.Errorf("expected 'users root', got %q", string(body)) } // check middleware was applied middlewareHeader := resp.Header.Get("X-Users-Middleware") if middlewareHeader != "applied" { t.Errorf("expected middleware header to be 'applied', got %q", middlewareHeader) } }) t.Run("with trailing slash", func(t *testing.T) { resp, err := http.Get(srv.URL + "/api/v1/users/") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } if string(body) != "users root" { t.Errorf("expected 'users root', got %q", string(body)) } // check middleware was applied middlewareHeader := resp.Header.Get("X-Users-Middleware") if 
middlewareHeader != "applied" { t.Errorf("expected middleware header to be 'applied', got %q", middlewareHeader) } }) t.Run("child route", func(t *testing.T) { resp, err := http.Get(srv.URL + "/api/v1/users/list") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } if string(body) != "users list" { t.Errorf("expected 'users list', got %q", string(body)) } // check middleware was applied middlewareHeader := resp.Header.Get("X-Users-Middleware") if middlewareHeader != "applied" { t.Errorf("expected middleware header to be 'applied', got %q", middlewareHeader) } }) } ================================================ FILE: notfound_test.go ================================================ package routegroup_test import ( "io" "net/http" "net/http/httptest" "testing" "github.com/go-pkgz/routegroup" ) func TestHTTPServerWithCustomNotFound(t *testing.T) { group := routegroup.New(http.NewServeMux()) group.Use(testMiddleware) group.NotFoundHandler(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "Custom 404: Page not found!", http.StatusNotFound) }) apiGroup := group.Mount("/api") apiGroup.HandleFunc("GET /test", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("test handler")) }) testServer := httptest.NewServer(group) defer testServer.Close() t.Run("GET /api/test", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/api/test") if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) } body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if string(body) != "test handler" { t.Errorf("Expected body 'test handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header 
X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("GET /api/not-found", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/api/not-found") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } t.Logf("body: %s", body) if resp.StatusCode != http.StatusNotFound { t.Errorf("Expected status code %d, got %d", http.StatusNotFound, resp.StatusCode) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } if string(body) != "Custom 404: Page not found!\n" { t.Errorf("Expected body 'Custom 404: Page not found!', got '%s'", string(body)) } }) t.Run("GET /not-found", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/not-found") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } t.Logf("body: %s", body) if resp.StatusCode != http.StatusNotFound { t.Errorf("Expected status code %d, got %d", http.StatusNotFound, resp.StatusCode) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } if string(body) != "Custom 404: Page not found!\n" { t.Errorf("Expected body 'Custom 404: Page not found!', got '%s'", string(body)) } }) } func TestHTTPServerWithCustomNotFoundNon404Status(t *testing.T) { group := routegroup.New(http.NewServeMux()) group.Use(testMiddleware) group.NotFoundHandler(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusServiceUnavailable) _, _ = w.Write([]byte("Custom 404: Page not found!\n")) }) apiGroup := group.Mount("/api") apiGroup.HandleFunc("GET /test", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("test handler")) }) testServer := httptest.NewServer(group) defer testServer.Close() t.Run("GET /api/test", func(t *testing.T) { resp, err := http.Get(testServer.URL + 
"/api/test") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) } if string(body) != "test handler" { t.Errorf("Expected body 'test handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("GET /api/not-found", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/api/not-found") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } t.Logf("body: %s", body) if resp.StatusCode != http.StatusServiceUnavailable { t.Errorf("Expected status code %d, got %d", http.StatusServiceUnavailable, resp.StatusCode) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } if string(body) != "Custom 404: Page not found!\n" { t.Errorf("Expected body 'Custom 404: Page not found!', got '%s'", string(body)) } }) } func TestCustomNotFoundHandlerChange(t *testing.T) { group := routegroup.New(http.NewServeMux()) group.NotFoundHandler(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "First handler", http.StatusNotFound) }) group.NotFoundHandler(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "Second handler", http.StatusNotFound) }) rec := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/not-found", http.NoBody) group.ServeHTTP(rec, req) if rec.Code != http.StatusNotFound { t.Errorf("got %d, want %d", rec.Code, http.StatusNotFound) } if rec.Body.String() != "Second handler\n" { t.Errorf("got %q, want %q", rec.Body.String(), "Second handler\n") } } func TestDisableNotFoundHandlerAfterRouteRegistration(t *testing.T) { group := routegroup.New(http.NewServeMux()) 
group.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) { if _, err := w.Write([]byte("test")); err != nil { t.Fatal(err) } }) group.DisableNotFoundHandler() rec := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/not-found", http.NoBody) group.ServeHTTP(rec, req) if rec.Code != http.StatusNotFound { t.Errorf("got %d, want %d", rec.Code, http.StatusNotFound) } if rec.Body.String() != "404 page not found\n" { t.Errorf("got %q, want %q", rec.Body.String(), "404 page not found\n") } } func TestNotFoundHandlerOnMountedGroup(t *testing.T) { // test that NotFoundHandler sets handler on root when called on mounted group root := routegroup.New(http.NewServeMux()) mounted := root.Mount("/api") // set NotFoundHandler on mounted group - should set it on root mounted.NotFoundHandler(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "Custom 404 from mounted", http.StatusNotFound) }) // add a route to the mounted group mounted.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("test")) }) testServer := httptest.NewServer(root) defer testServer.Close() // test that custom 404 works for non-matching routes resp, err := http.Get(testServer.URL + "/unknown") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusNotFound { t.Errorf("got status %d, want %d", resp.StatusCode, http.StatusNotFound) } if string(body) != "Custom 404 from mounted\n" { t.Errorf("got body %q, want %q", string(body), "Custom 404 from mounted\n") } } func TestStatusRecorderWith200Default(t *testing.T) { // test that statusRecorder correctly identifies handlers that return 200 without explicit WriteHeader group := routegroup.New(http.NewServeMux()) // set custom NotFound handler group.NotFoundHandler(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "Custom 404", http.StatusNotFound) }) // register a handler that returns 
200 without calling WriteHeader explicitly // this is common practice - handlers often just call Write() which implicitly sets 200 group.HandleFunc("GET /implicit200", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("success")) // no WriteHeader call, should be 200 }) testServer := httptest.NewServer(group) defer testServer.Close() // test that the handler works correctly and isn't mistaken for a 404 resp, err := http.Get(testServer.URL + "/implicit200") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) // this should return 200 with "success" body if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } if string(body) != "success" { t.Errorf("expected body 'success', got %q", string(body)) } // verify it didn't trigger the custom 404 handler if string(body) == "Custom 404\n" { t.Error("custom 404 handler was incorrectly triggered for a valid 200 response") } } func TestCustomNotFoundVsMethodNotAllowed(t *testing.T) { // test demonstrates issue #27 - custom NotFound handler should not override 405 Method Not Allowed t.Run("without custom NotFound handler", func(t *testing.T) { group := routegroup.New(http.NewServeMux()) // register a route for GET method only group.HandleFunc("GET /api/resource", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("GET response")) }) testServer := httptest.NewServer(group) defer testServer.Close() // test POST to the same path - should return 405 req, _ := http.NewRequest(http.MethodPost, testServer.URL+"/api/resource", http.NoBody) client := &http.Client{} resp, err := client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) t.Logf("without custom NotFound: status=%d, body=%s", resp.StatusCode, body) // this should return 405 Method Not Allowed if resp.StatusCode != http.StatusMethodNotAllowed { t.Errorf("expected status %d (Method Not Allowed), got %d", 
http.StatusMethodNotAllowed, resp.StatusCode) } // test that Allow header is present allowHeader := resp.Header.Get("Allow") if allowHeader == "" { t.Error("expected Allow header to be present for 405 response") } else { t.Logf("Allow header: %s", allowHeader) } }) t.Run("with custom NotFound handler", func(t *testing.T) { group := routegroup.New(http.NewServeMux()) // set custom NotFound handler group.NotFoundHandler(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "Custom 404: Not Found", http.StatusNotFound) }) // register a route for GET method only group.HandleFunc("GET /api/resource", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("GET response")) }) testServer := httptest.NewServer(group) defer testServer.Close() // test POST to the same path - should still return 405, not 404 req, _ := http.NewRequest(http.MethodPost, testServer.URL+"/api/resource", http.NoBody) client := &http.Client{} resp, err := client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) t.Logf("with custom NotFound: status=%d, body=%s", resp.StatusCode, body) // this should return 405 Method Not Allowed, but might return 404 if the issue exists if resp.StatusCode != http.StatusMethodNotAllowed { t.Errorf("expected status %d (Method Not Allowed), got %d - custom NotFound handler incorrectly overrides 405", http.StatusMethodNotAllowed, resp.StatusCode) } // test that Allow header is present allowHeader := resp.Header.Get("Allow") if allowHeader == "" && resp.StatusCode == http.StatusMethodNotAllowed { t.Error("expected Allow header to be present for 405 response") } else if allowHeader != "" { t.Logf("Allow header: %s", allowHeader) } // verify the body is not the custom 404 message when it should be 405 if resp.StatusCode == http.StatusNotFound && string(body) == "Custom 404: Not Found\n" { t.Error("custom NotFound handler was incorrectly called for a method mismatch (should be 405)") } }) // additional test 
case: verify that actual 404 still uses custom handler t.Run("verify actual 404 uses custom handler", func(t *testing.T) { group := routegroup.New(http.NewServeMux()) // set custom NotFound handler group.NotFoundHandler(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "Custom 404: Not Found", http.StatusNotFound) }) // register a route for GET method group.HandleFunc("GET /api/resource", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("GET response")) }) testServer := httptest.NewServer(group) defer testServer.Close() // test a completely non-existent path - should use custom 404 resp, err := http.Get(testServer.URL + "/api/nonexistent") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) t.Logf("actual 404: status=%d, body=%s", resp.StatusCode, body) if resp.StatusCode != http.StatusNotFound { t.Errorf("expected status %d, got %d", http.StatusNotFound, resp.StatusCode) } // verify custom 404 message is used if string(body) != "Custom 404: Not Found\n" { t.Errorf("expected custom 404 body, got %q", string(body)) } }) } ================================================ FILE: pathparams_test.go ================================================ package routegroup_test import ( "context" "fmt" "net/http" "net/http/httptest" "strings" "testing" "github.com/go-pkgz/routegroup" ) func TestPathParametersWithMount(t *testing.T) { tests := []struct { name string method string setupFunc func() *routegroup.Bundle requestPath string expectedParam string expectedStatus int }{ { name: "path parameters with mount", method: "POST", setupFunc: func() *routegroup.Bundle { mux := http.NewServeMux() bundle := routegroup.Mount(mux, "/api/v0") peerGroup := bundle.Mount("/peer") peerGroup.HandleFunc("POST /iface/{iface}/multiplenew", func(w http.ResponseWriter, r *http.Request) { interfaceID := r.PathValue("iface") _, _ = w.Write([]byte("iface=" + interfaceID)) }) return bundle }, requestPath: 
"/api/v0/peer/iface/test123/multiplenew", expectedParam: "iface=test123", expectedStatus: http.StatusOK, }, { name: "path parameters with group", method: "POST", setupFunc: func() *routegroup.Bundle { mux := http.NewServeMux() bundle := routegroup.Mount(mux, "/api/v0") peerGroup := bundle.Group() peerGroup.HandleFunc("POST /peer/iface/{iface}/multiplenew", func(w http.ResponseWriter, r *http.Request) { interfaceID := r.PathValue("iface") _, _ = w.Write([]byte("iface=" + interfaceID)) }) return bundle }, requestPath: "/api/v0/peer/iface/test123/multiplenew", expectedParam: "iface=test123", expectedStatus: http.StatusOK, }, { name: "multiple path parameters", method: "GET", setupFunc: func() *routegroup.Bundle { mux := http.NewServeMux() bundle := routegroup.Mount(mux, "/api") bundle.HandleFunc("GET /users/{userID}/posts/{postID}", func(w http.ResponseWriter, r *http.Request) { userID := r.PathValue("userID") postID := r.PathValue("postID") _, _ = fmt.Fprintf(w, "user=%s,post=%s", userID, postID) }) return bundle }, requestPath: "/api/users/alice/posts/42", expectedParam: "user=alice,post=42", expectedStatus: http.StatusOK, }, { name: "path parameters with middleware", method: "GET", setupFunc: func() *routegroup.Bundle { mux := http.NewServeMux() bundle := routegroup.Mount(mux, "/api") bundle.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Middleware", "applied") next.ServeHTTP(w, r) }) }) bundle.HandleFunc("GET /items/{id}", func(w http.ResponseWriter, r *http.Request) { itemID := r.PathValue("id") _, _ = w.Write([]byte("item=" + itemID)) }) return bundle }, requestPath: "/api/items/xyz", expectedParam: "item=xyz", expectedStatus: http.StatusOK, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bundle := tt.setupFunc() req := httptest.NewRequest(tt.method, tt.requestPath, http.NoBody) rr := httptest.NewRecorder() bundle.ServeHTTP(rr, req) if rr.Code != 
tt.expectedStatus { t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) } if rr.Body.String() != tt.expectedParam { t.Errorf("expected %q, got %q", tt.expectedParam, rr.Body.String()) } }) } } // TestRemainderWildcards tests the {path...} remainder wildcard feature func TestRemainderWildcards(t *testing.T) { tests := []struct { name string pattern string requestPath string expectedParam string expectedStatus int }{ { name: "single segment remainder", pattern: "GET /files/{path...}", requestPath: "/files/document.txt", expectedParam: "document.txt", expectedStatus: http.StatusOK, }, { name: "multiple segments remainder", pattern: "GET /files/{path...}", requestPath: "/files/docs/2024/report.pdf", expectedParam: "docs/2024/report.pdf", expectedStatus: http.StatusOK, }, { name: "remainder with mount", pattern: "GET /static/{filepath...}", requestPath: "/api/static/css/style.css", expectedParam: "css/style.css", expectedStatus: http.StatusOK, }, { name: "empty remainder", pattern: "GET /files/{path...}", requestPath: "/files/", expectedParam: "", expectedStatus: http.StatusOK, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mux := http.NewServeMux() if tt.name == "remainder with mount" { bundle := routegroup.Mount(mux, "/api") bundle.HandleFunc(tt.pattern, func(w http.ResponseWriter, r *http.Request) { filepathParam := r.PathValue("filepath") _, _ = w.Write([]byte(filepathParam)) }) req := httptest.NewRequest("GET", tt.requestPath, http.NoBody) rr := httptest.NewRecorder() bundle.ServeHTTP(rr, req) if rr.Code != tt.expectedStatus { t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) } if rr.Body.String() != tt.expectedParam { t.Errorf("expected %q, got %q", tt.expectedParam, rr.Body.String()) } } else { bundle := routegroup.New(mux) bundle.HandleFunc(tt.pattern, func(w http.ResponseWriter, r *http.Request) { path := r.PathValue("path") _, _ = w.Write([]byte(path)) }) req := httptest.NewRequest("GET", tt.requestPath, 
http.NoBody) rr := httptest.NewRecorder() bundle.ServeHTTP(rr, req) if rr.Code != tt.expectedStatus { t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) } if rr.Body.String() != tt.expectedParam { t.Errorf("expected %q, got %q", tt.expectedParam, rr.Body.String()) } } }) } } // TestMiddlewareWithContextAndPathParams tests that path params survive middleware context changes func TestMiddlewareWithContextAndPathParams(t *testing.T) { type contextKey string const userKey contextKey = "user" type requestIDKey string const reqIDKey requestIDKey = "request-id" tests := []struct { name string method string setupFunc func() *routegroup.Bundle requestPath string expectedParam string expectedUser string }{ { name: "WithContext preserves path params", method: "GET", setupFunc: func() *routegroup.Bundle { mux := http.NewServeMux() bundle := routegroup.New(mux) // middleware that adds context value using WithContext bundle.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), userKey, "alice") next.ServeHTTP(w, r.WithContext(ctx)) }) }) bundle.HandleFunc("GET /users/{id}", func(w http.ResponseWriter, r *http.Request) { userID := r.PathValue("id") user := r.Context().Value(userKey).(string) _, _ = fmt.Fprintf(w, "id=%s,user=%s", userID, user) }) return bundle }, requestPath: "/users/123", expectedParam: "id=123,user=alice", expectedUser: "alice", }, { name: "Clone preserves path params", method: "GET", setupFunc: func() *routegroup.Bundle { mux := http.NewServeMux() bundle := routegroup.New(mux) // middleware that clones request with new context bundle.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), userKey, "bob") newReq := r.Clone(ctx) next.ServeHTTP(w, newReq) }) }) bundle.HandleFunc("GET /items/{itemID}/details", func(w http.ResponseWriter, r *http.Request) { 
itemID := r.PathValue("itemID") user := r.Context().Value(userKey).(string) _, _ = fmt.Fprintf(w, "item=%s,user=%s", itemID, user) }) return bundle }, requestPath: "/items/xyz/details", expectedParam: "item=xyz,user=bob", expectedUser: "bob", }, { name: "Multiple middleware with context changes", method: "POST", setupFunc: func() *routegroup.Bundle { mux := http.NewServeMux() bundle := routegroup.New(mux) // first middleware bundle.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), reqIDKey, "req-123") next.ServeHTTP(w, r.WithContext(ctx)) }) }) // second middleware bundle.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), userKey, "charlie") next.ServeHTTP(w, r.WithContext(ctx)) }) }) bundle.HandleFunc("POST /api/{version}/users/{id}", func(w http.ResponseWriter, r *http.Request) { version := r.PathValue("version") userID := r.PathValue("id") user := r.Context().Value(userKey).(string) reqID := r.Context().Value(reqIDKey).(string) _, _ = fmt.Fprintf(w, "v=%s,id=%s,user=%s,req=%s", version, userID, user, reqID) }) return bundle }, requestPath: "/api/v2/users/456", expectedParam: "v=v2,id=456,user=charlie,req=req-123", expectedUser: "charlie", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bundle := tt.setupFunc() req := httptest.NewRequest(tt.method, tt.requestPath, http.NoBody) rr := httptest.NewRecorder() bundle.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("expected status OK, got %d", rr.Code) } if rr.Body.String() != tt.expectedParam { t.Errorf("expected %q, got %q", tt.expectedParam, rr.Body.String()) } }) } } // TestURLEncodedPathParams tests that URL-encoded path parameters are properly decoded func TestURLEncodedPathParams(t *testing.T) { tests := []struct { name string pattern string requestPath string expectedParam string 
expectedStatus int }{ { name: "space encoded as %20", pattern: "GET /users/{name}", requestPath: "/users/John%20Doe", expectedParam: "John Doe", expectedStatus: http.StatusOK, }, { name: "slash encoded as %2F", pattern: "GET /files/{filename}", requestPath: "/files/folder%2Ffile.txt", expectedParam: "folder/file.txt", expectedStatus: http.StatusOK, }, { name: "special characters", pattern: "GET /search/{query}", requestPath: "/search/hello%3Dworld%26foo%3Dbar", expectedParam: "hello=world&foo=bar", expectedStatus: http.StatusOK, }, { name: "unicode characters", pattern: "GET /users/{name}", requestPath: "/users/%E4%B8%AD%E6%96%87", expectedParam: "中文", expectedStatus: http.StatusOK, }, { name: "plus sign handling", pattern: "GET /api/{version}", requestPath: "/api/v1%2B2", expectedParam: "v1+2", expectedStatus: http.StatusOK, }, { name: "percent sign itself", pattern: "GET /discount/{code}", requestPath: "/discount/SAVE%2550", expectedParam: "SAVE%50", expectedStatus: http.StatusOK, }, { name: "encoded in remainder wildcard", pattern: "GET /static/{path...}", requestPath: "/static/images%2Flogo%20v2.png", expectedParam: "images/logo v2.png", expectedStatus: http.StatusOK, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mux := http.NewServeMux() bundle := routegroup.New(mux) bundle.HandleFunc(tt.pattern, func(w http.ResponseWriter, r *http.Request) { var param string switch { case strings.Contains(tt.pattern, "{path...}"): param = r.PathValue("path") case strings.Contains(tt.pattern, "{name}"): param = r.PathValue("name") case strings.Contains(tt.pattern, "{filename}"): param = r.PathValue("filename") case strings.Contains(tt.pattern, "{query}"): param = r.PathValue("query") case strings.Contains(tt.pattern, "{version}"): param = r.PathValue("version") case strings.Contains(tt.pattern, "{code}"): param = r.PathValue("code") } _, _ = w.Write([]byte(param)) }) req := httptest.NewRequest("GET", tt.requestPath, http.NoBody) rr := 
httptest.NewRecorder() bundle.ServeHTTP(rr, req) if rr.Code != tt.expectedStatus { t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) } if rr.Body.String() != tt.expectedParam { t.Errorf("expected %q, got %q", tt.expectedParam, rr.Body.String()) } }) } } // TestMethodlessPathParams tests path parameters without HTTP method prefix func TestMethodlessPathParams(t *testing.T) { mux := http.NewServeMux() bundle := routegroup.New(mux) // pattern without method prefix - should work for all methods bundle.HandleFunc("/items/{id}", func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") _, _ = w.Write([]byte("item:" + id)) }) tests := []struct { method string path string want string }{ {"GET", "/items/123", "item:123"}, {"POST", "/items/456", "item:456"}, {"PUT", "/items/789", "item:789"}, {"DELETE", "/items/abc", "item:abc"}, } for _, tt := range tests { t.Run(tt.method, func(t *testing.T) { req := httptest.NewRequest(tt.method, tt.path, http.NoBody) rr := httptest.NewRecorder() bundle.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("expected status OK, got %d", rr.Code) } if rr.Body.String() != tt.want { t.Errorf("expected %q, got %q", tt.want, rr.Body.String()) } }) } } // TestRootBundlePathParams tests path parameters on root bundle without Mount func TestRootBundlePathParams(t *testing.T) { mux := http.NewServeMux() bundle := routegroup.New(mux) bundle.HandleFunc("GET /users/{id}", func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") _, _ = w.Write([]byte("user:" + id)) }) req := httptest.NewRequest("GET", "/users/root123", http.NoBody) rr := httptest.NewRecorder() bundle.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("expected status OK, got %d", rr.Code) } if got := rr.Body.String(); got != "user:root123" { t.Errorf("expected user:root123, got %q", got) } } // TestHEADWithPathParams tests HEAD requests with path parameters func TestHEADWithPathParams(t *testing.T) { mux := http.NewServeMux() 
bundle := routegroup.New(mux) // GET handler should also handle HEAD bundle.HandleFunc("GET /api/{version}/status", func(w http.ResponseWriter, r *http.Request) { version := r.PathValue("version") w.Header().Set("X-Version", version) w.Header().Set("Content-Length", "10") if r.Method != "HEAD" { _, _ = w.Write([]byte("status:ok")) } }) // test HEAD request req := httptest.NewRequest("HEAD", "/api/v2/status", http.NoBody) rr := httptest.NewRecorder() bundle.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("expected status OK, got %d", rr.Code) } if got := rr.Header().Get("X-Version"); got != "v2" { t.Errorf("expected X-Version header v2, got %q", got) } if rr.Body.Len() != 0 { t.Errorf("HEAD response should have no body, got %d bytes", rr.Body.Len()) } } ================================================ FILE: routing_test.go ================================================ package routegroup_test import ( "fmt" "io" "net/http" "net/http/httptest" "strings" "testing" "github.com/go-pkgz/routegroup" ) func TestHTTPServerWithRoot(t *testing.T) { group := routegroup.New(http.NewServeMux()) group.Use(testMiddleware) group.HandleFunc("/test", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("test handler")) }) group.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("root handler")) }) testServer := httptest.NewServer(group) defer testServer.Close() t.Run("GET /test", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/test") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if string(body) != "test handler" { t.Errorf("Expected body 'test handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("GET /", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/") if err != nil { t.Fatal(err) 
} defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if string(body) != "root handler" { t.Errorf("Expected body 'root handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("/", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if string(body) != "root handler" { t.Errorf("Expected body 'root handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("GET /unknown-path", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/unknown-path") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusNotFound { t.Errorf("Expected status code %d, got %d", http.StatusNotFound, resp.StatusCode) } if string(body) != "404 page not found\n" { t.Errorf("Expected body '404 page not found\n', got '%s'", string(body)) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) } func TestHTTPServerWithRoot122(t *testing.T) { group := routegroup.New(http.NewServeMux()) group.Use(testMiddleware) group.HandleFunc("GET /test", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("test handler")) }) group.HandleFunc("GET /", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("root handler")) }) testServer := httptest.NewServer(group) defer testServer.Close() t.Run("GET /test", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/test") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := 
io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if string(body) != "test handler" { t.Errorf("Expected body 'test handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("GET /", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if string(body) != "root handler" { t.Errorf("Expected body 'root handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("POST / wrong method -> 405", func(t *testing.T) { resp, err := http.Post(testServer.URL+"/", "application/json", http.NoBody) if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusMethodNotAllowed { t.Errorf("Expected status code %d, got %d", http.StatusMethodNotAllowed, resp.StatusCode) } if string(body) != "Method Not Allowed\n" { t.Errorf("Expected body 'Method Not Allowed', got '%s'", string(body)) } if allow := resp.Header.Get("Allow"); !strings.Contains(allow, http.MethodGet) { t.Errorf("expected Allow header to contain GET, got %q", allow) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("GET /unknown-path", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/unknown-path") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusNotFound { t.Errorf("Expected status code %d, got %d", http.StatusNotFound, resp.StatusCode) } if string(body) != "404 page not found\n" { t.Errorf("Expected body 
'404 page not found\n', got '%s'", string(body)) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) } func TestRootAndCatchAll(t *testing.T) { group := routegroup.New(http.NewServeMux()) group.Use(testMiddleware) group.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("root handler")) }) group.NotFoundHandler(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("custom not found handler")) }) testServer := httptest.NewServer(group) defer testServer.Close() t.Run("GET /", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) } if string(body) != "root handler" { t.Errorf("Expected body 'root handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("GET /unknown-path", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/unknown-path") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) } if string(body) != "custom not found handler" { t.Errorf("Expected body 'custom not found handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) } func TestHTTPServerMethodAndPathHandling(t *testing.T) { group := routegroup.Mount(http.NewServeMux(), "/api") group.Use(testMiddleware) group.HandleFunc("GET /test", func(w 
http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("GET test method handler")) }) group.HandleFunc("/test2", func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("test2 method handler")) }) testServer := httptest.NewServer(group) defer testServer.Close() t.Run("handle with verb", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/api/test") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if string(body) != "GET test method handler" { t.Errorf("Expected body 'GET test method handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) t.Run("handle without verb", func(t *testing.T) { resp, err := http.Get(testServer.URL + "/api/test2") if err != nil { t.Fatal(err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if string(body) != "test2 method handler" { t.Errorf("Expected body 'test2 method handler', got '%s'", string(body)) } if header := resp.Header.Get("X-Test-Middleware"); header != "true" { t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) } }) } func TestMethodPatternsWithDifferentMethods(t *testing.T) { group := routegroup.New(http.NewServeMux()) group.HandleFunc("GET /test", func(w http.ResponseWriter, _ *http.Request) { if _, err := io.WriteString(w, "GET handler"); err != nil { t.Fatal(err) } }) group.HandleFunc("POST /test", func(w http.ResponseWriter, _ *http.Request) { if _, err := io.WriteString(w, "POST handler"); err != nil { t.Fatal(err) } }) tests := []struct { method, path, expected string code int }{ {http.MethodGet, "/test", "GET handler", http.StatusOK}, {http.MethodPost, "/test", "POST handler", http.StatusOK}, {http.MethodPut, "/test", "Method Not Allowed\n", http.StatusMethodNotAllowed}, } for _, tt := range tests { rec := 
httptest.NewRecorder() req, _ := http.NewRequest(tt.method, tt.path, http.NoBody) group.ServeHTTP(rec, req) if rec.Code != tt.code { t.Errorf("got %d, want %d", rec.Code, tt.code) } if rec.Body.String() != tt.expected { t.Errorf("got %q, want %q", rec.Body.String(), tt.expected) } } } func TestHandleTrailingSlash(t *testing.T) { router := routegroup.New(http.NewServeMux()) t.Run("handler for pattern with trailing slash", func(t *testing.T) { router.Handle("/path/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("handler with trailing slash")) })) router.HandleFunc("GET /path/sub", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("sub handler")) }) srv := httptest.NewServer(router) defer srv.Close() tests := []struct { name string method string path string wantStatus int wantBody string }{ {"GET /path/", http.MethodGet, "/path/", http.StatusOK, "handler with trailing slash"}, {"POST /path/", http.MethodPost, "/path/", http.StatusOK, "handler with trailing slash"}, {"GET /path/sub", http.MethodGet, "/path/sub", http.StatusOK, "sub handler"}, // more specific route wins {"POST /path/sub", http.MethodPost, "/path/sub", http.StatusOK, "handler with trailing slash"}, // falls back to /path/ {"GET /path/anything", http.MethodGet, "/path/anything", http.StatusOK, "handler with trailing slash"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req, err := http.NewRequest(tt.method, srv.URL+tt.path, http.NoBody) if err != nil { t.Fatal(err) } resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.StatusCode != tt.wantStatus { t.Errorf("got status %d, want %d", resp.StatusCode, tt.wantStatus) } body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if got := string(body); got != tt.wantBody { t.Errorf("got body %q, want %q", got, tt.wantBody) } }) } }) t.Run("mounted handler with trailing slash", func(t *testing.T) { api := router.Mount("/api") 
api.Handle("/v1/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("api v1")) })) api.HandleFunc("GET /v1/data", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("api data")) }) srv := httptest.NewServer(router) defer srv.Close() tests := []struct { name string method string path string wantStatus int wantBody string }{ {"GET /api/v1/", http.MethodGet, "/api/v1/", http.StatusOK, "api v1"}, {"POST /api/v1/", http.MethodPost, "/api/v1/", http.StatusOK, "api v1"}, {"GET /api/v1/data", http.MethodGet, "/api/v1/data", http.StatusOK, "api data"}, // more specific route wins {"POST /api/v1/data", http.MethodPost, "/api/v1/data", http.StatusOK, "api v1"}, // falls back to /v1/ {"GET /api/v1/anything", http.MethodGet, "/api/v1/anything", http.StatusOK, "api v1"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req, err := http.NewRequest(tt.method, srv.URL+tt.path, http.NoBody) if err != nil { t.Fatal(err) } resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.StatusCode != tt.wantStatus { t.Errorf("got status %d, want %d", resp.StatusCode, tt.wantStatus) } body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if got := string(body); got != tt.wantBody { t.Errorf("got body %q, want %q", got, tt.wantBody) } }) } }) } func TestInvalidPatterns(t *testing.T) { group := routegroup.New(http.NewServeMux()) tests := []struct { name string pattern string path string // actual URL path to test wantPanic bool }{ {"empty pattern", "", "/", true}, // ServeMux panics on empty pattern {"just spaces", " ", "/", true}, // ServeMux panics on spaces-only pattern {"spaces in path", "GET /path%20with%20spaces", "/path%20with%20spaces", false}, // encoded spaces work {"only method", "GET", "/", true}, // ServeMux panics on invalid pattern {"just one slash", "/", "/", false}, // root path works {"missing slash", "GET /path", "/path", false}, // normal pattern 
{"double slashes", "GET //path", "/", true}, // ServeMux panics on unclean paths {"path without slash", "path", "/", true}, // must start with / {"method path without slash", "GET path", "/", true}, // must start with / {"trailing slash", "GET /path/", "/path/", false}, // trailing slash is ok } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { handler := func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("handler")) } if tt.wantPanic { defer func() { if r := recover(); r == nil { t.Error("expected panic but got none") } }() group.HandleFunc(tt.pattern, handler) return } group.HandleFunc(tt.pattern, handler) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, tt.path, http.NoBody) group.ServeHTTP(rec, req) if rec.Code == 0 { t.Error("no response code set") } }) } } func TestHandleRoot(t *testing.T) { // create client that doesn't follow redirects client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse // don't follow redirects }, } t.Run("HandleRoot with middleware", func(t *testing.T) { group := routegroup.New(http.NewServeMux()) group.Mount("/api").Route(func(apiGroup *routegroup.Bundle) { apiGroup.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Middleware", "applied") next.ServeHTTP(w, r) }) }) apiGroup.HandleRoot("GET", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte("api root")); err != nil { t.Fatalf("failed to write response: %v", err) } })) apiGroup.HandleFunc("GET /test", func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte("test")); err != nil { t.Fatalf("failed to write response: %v", err) } }) }) group.Mount("/api-2").Route(func(apiGroup *routegroup.Bundle) { apiGroup.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 
w.Header().Set("X-Middleware", "applied") next.ServeHTTP(w, r) }) }) apiGroup.HandleRootFunc("GET", func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte("api root")); err != nil { t.Fatalf("failed to write response: %v", err) } }) apiGroup.HandleFunc("GET /test", func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte("test")); err != nil { t.Fatalf("failed to write response: %v", err) } }) }) ts := httptest.NewServer(group) defer ts.Close() apis := []string{"/api", "/api-2"} for _, api := range apis { // test direct access to registered root /api - should NOT redirect and middleware should be applied resp, err := client.Get(ts.URL + api) if err != nil { t.Fatalf("failed to make request: %v", err) } if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } if resp.Header.Get("X-Middleware") != "applied" { t.Errorf("middleware not applied") } body, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("failed to read response body: %v", err) } if string(body) != "api root" { t.Errorf("expected 'api root', got '%s'", body) } if closeErr := resp.Body.Close(); closeErr != nil { t.Errorf("failed to close response body: %v", closeErr) } // test access to /api/test resp, err = client.Get(ts.URL + api + "/test") if err != nil { t.Fatalf("failed to make request: %v", err) } body, err = io.ReadAll(resp.Body) if err != nil { t.Fatalf("failed to read response body: %v", err) } if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } if string(body) != "test" { t.Errorf("expected 'test', got '%s'", body) } if closeErr := resp.Body.Close(); closeErr != nil { t.Errorf("failed to close response body: %v", closeErr) } // test POST request to /api req, err := http.NewRequest(http.MethodPost, ts.URL+api, http.NoBody) if err != nil { t.Fatalf("failed to create request: %v", err) } resp, err = client.Do(req) if err != nil { t.Fatalf("failed to make request: %v", err) 
} if resp.StatusCode != http.StatusMethodNotAllowed { t.Errorf("expected status 405, got %d", resp.StatusCode) } if allow := resp.Header.Get("Allow"); !strings.Contains(allow, http.MethodGet) { t.Errorf("expected Allow header to contain GET, got %q", allow) } if closeErr := resp.Body.Close(); closeErr != nil { t.Errorf("failed to close response body: %v", closeErr) } } }) t.Run("HandleRoot without method", func(t *testing.T) { group := routegroup.New(http.NewServeMux()) group.Mount("/data").Route(func(dataGroup *routegroup.Bundle) { dataGroup.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Middleware", "applied") next.ServeHTTP(w, r) }) }) // register without specifying a method (empty string) dataGroup.HandleRoot("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte("data root")); err != nil { t.Fatalf("failed to write response: %v", err) } })) }) group.Mount("/data-2").Route(func(dataGroup *routegroup.Bundle) { dataGroup.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Middleware", "applied") next.ServeHTTP(w, r) }) }) // register without specifying a method (empty string) dataGroup.HandleRootFunc("", func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte("data root")); err != nil { t.Fatalf("failed to write response: %v", err) } }) }) ts := httptest.NewServer(group) defer ts.Close() paths := []string{"/data", "/data-2"} for _, path := range paths { // test GET request resp, err := client.Get(ts.URL + path) if err != nil { t.Fatalf("failed to make request: %v", err) } if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } if resp.Header.Get("X-Middleware") != "applied" { t.Errorf("middleware not applied") } body, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("failed to read response body: 
%v", err) } if string(body) != "data root" { t.Errorf("expected 'data root', got '%s'", body) } if closeErr := resp.Body.Close(); closeErr != nil { t.Errorf("failed to close response body: %v", closeErr) } // test POST request - should also work since no method was specified req, err := http.NewRequest(http.MethodPost, ts.URL+path, http.NoBody) if err != nil { t.Fatalf("failed to create request: %v", err) } resp, err = client.Do(req) if err != nil { t.Fatalf("failed to make request: %v", err) } if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } if resp.Header.Get("X-Middleware") != "applied" { t.Errorf("middleware not applied") } body, err = io.ReadAll(resp.Body) if err != nil { t.Fatalf("failed to read response body: %v", err) } if string(body) != "data root" { t.Errorf("expected 'data root', got '%s'", body) } if closeErr := resp.Body.Close(); closeErr != nil { t.Errorf("failed to close response body: %v", closeErr) } } }) t.Run("HandleRoot with empty base path", func(t *testing.T) { // create a group with empty base path group := routegroup.New(http.NewServeMux()) group.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Middleware", "applied") next.ServeHTTP(w, r) }) }) // handle the root path (empty base path) group.HandleRoot("GET", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte("root")); err != nil { t.Fatalf("failed to write response: %v", err) } })) ts := httptest.NewServer(group) defer ts.Close() // test GET request to root resp, err := client.Get(ts.URL + "/") if err != nil { t.Fatalf("failed to make request: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } if resp.Header.Get("X-Middleware") != "applied" { t.Errorf("middleware not applied") } body, err := io.ReadAll(resp.Body) if err != nil { 
t.Fatalf("failed to read response body: %v", err) } if string(body) != "root" { t.Errorf("expected 'root', got '%s'", body) } }) t.Run("HandleRootFunc with empty base path", func(t *testing.T) { // create a group with empty base path group := routegroup.New(http.NewServeMux()) group.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Middleware", "applied") next.ServeHTTP(w, r) }) }) // handle the root path (empty base path) group.HandleRootFunc("GET", func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte("root")); err != nil { t.Fatalf("failed to write response: %v", err) } }) ts := httptest.NewServer(group) defer ts.Close() // test GET request to root resp, err := client.Get(ts.URL + "/") if err != nil { t.Fatalf("failed to make request: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } if resp.Header.Get("X-Middleware") != "applied" { t.Errorf("middleware not applied") } body, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("failed to read response body: %v", err) } if string(body) != "root" { t.Errorf("expected 'root', got '%s'", body) } }) t.Run("handle with trailing slash", func(t *testing.T) { group := routegroup.New(http.NewServeMux()) apiGroup := group.Mount("/api") apiGroup.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte("api root")); err != nil { t.Fatalf("failed to write response: %v", err) } }) ts := httptest.NewServer(group) defer ts.Close() resp, err := client.Get(ts.URL + "/api") if err != nil { t.Fatalf("failed to make request: %v", err) } defer resp.Body.Close() // verify trailing slash approach causes redirect if resp.StatusCode != http.StatusMovedPermanently { t.Errorf("expected redirect status 301, got %d", resp.StatusCode) } location := resp.Header.Get("Location") if location != "/api/" { t.Errorf("expected 
redirect to '/api/', got '%s'", location) } }) }

// ExampleNew creates a bundle on top of a fresh ServeMux, attaches a single
// middleware, registers two routes and serves the bundle on port 8080.
func ExampleNew() {
	group := routegroup.New(http.NewServeMux())

	// middleware that adds a header and then hands the request to the next handler
	addHeader := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Add("X-Mounted-Middleware", "true")
			next.ServeHTTP(w, r)
		})
	}
	group.Use(addHeader)

	// both test routes answer with a bare 200
	respondOK := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
	group.HandleFunc("GET /test", respondOK)
	group.HandleFunc("POST /test2", respondOK)

	// start the server
	err := http.ListenAndServe(":8080", group)
	if err != nil {
		panic(err)
	}
}

// ExampleMount creates a bundle mounted at "/api" on a fresh ServeMux, attaches
// a single middleware, registers two routes and serves the bundle on port 8080.
func ExampleMount() {
	group := routegroup.Mount(http.NewServeMux(), "/api")

	// middleware that adds a header and then hands the request to the next handler
	addHeader := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Add("X-Test-Middleware", "true")
			next.ServeHTTP(w, r)
		})
	}
	group.Use(addHeader)

	// both test routes answer with a bare 200
	respondOK := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
	group.HandleFunc("GET /test", respondOK)
	group.HandleFunc("POST /test2", respondOK)

	// start the server
	err := http.ListenAndServe(":8080", group)
	if err != nil {
		panic(err)
	}
}

// ExampleBundle_Route configures a bundle inside a Route callback: the
// middleware and both routes are registered on the bundle passed to the
// callback, and the outer bundle is then served on port 8080.
func ExampleBundle_Route() {
	group := routegroup.New(http.NewServeMux())

	// configure the group using Route
	group.Route(func(g *routegroup.Bundle) {
		// middleware that adds a header and then hands the request to the next handler
		addHeader := func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Header().Add("X-Test-Middleware", "true")
				next.ServeHTTP(w, r)
			})
		}
		g.Use(addHeader)

		// both test routes answer with a bare 200
		respondOK := func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusOK)
		}
		g.HandleFunc("GET /test", respondOK)
		g.HandleFunc("POST /test2", respondOK)
	})

	// start the server
	err := http.ListenAndServe(":8080", group)
	if err != nil {
		panic(err)
	}
}

// This example shows how to use HandleRoot to handle the root path of a
// mounted group without a trailing slash.
func ExampleBundle_HandleRoot() {
	group := routegroup.New(http.NewServeMux())

	// create API group
	apiGroup := group.Mount("/api")

	// middleware that sets a header and then hands the request to the next handler
	markAPI := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("X-API", "true")
			next.ServeHTTP(w, r)
		})
	}
	apiGroup.Use(markAPI)

	// handle root path (responds to /api without redirect)
	rootHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		fmt.Fprint(w, "API root")
	})
	apiGroup.HandleRoot("GET", rootHandler)

	// regular routes
	apiGroup.HandleFunc("GET /users", func(w http.ResponseWriter, _ *http.Request) {
		fmt.Fprint(w, "List of users")
	})
}

// TestPathParametersWithMount tests path parameter extraction with mounted groups (issue #22)