Repository: labstack/echo Branch: master Commit: 675712da34c1 Files: 127 Total size: 1.1 MB Directory structure: gitextract_q3zebxle/ ├── .editorconfig ├── .gitattributes ├── .github/ │ ├── FUNDING.yml │ ├── ISSUE_TEMPLATE.md │ ├── stale.yml │ └── workflows/ │ ├── checks.yml │ └── echo.yml ├── .gitignore ├── API_CHANGES_V5.md ├── CHANGELOG.md ├── CLAUDE.md ├── LICENSE ├── Makefile ├── README.md ├── SECURITY.md ├── _fixture/ │ ├── _fixture/ │ │ └── README.md │ ├── certs/ │ │ ├── README.md │ │ ├── cert.pem │ │ └── key.pem │ ├── dist/ │ │ ├── private.txt │ │ └── public/ │ │ ├── assets/ │ │ │ ├── readme.md │ │ │ └── subfolder/ │ │ │ └── subfolder.md │ │ ├── index.html │ │ └── test.txt │ ├── folder/ │ │ └── index.html │ └── index.html ├── bind.go ├── bind_test.go ├── binder.go ├── binder_external_test.go ├── binder_generic.go ├── binder_generic_test.go ├── binder_test.go ├── codecov.yml ├── context.go ├── context_generic.go ├── context_generic_test.go ├── context_test.go ├── echo.go ├── echo_test.go ├── echotest/ │ ├── context.go │ ├── context_external_test.go │ ├── context_test.go │ ├── reader.go │ ├── reader_external_test.go │ ├── reader_test.go │ └── testdata/ │ └── test.json ├── go.mod ├── go.sum ├── group.go ├── group_test.go ├── httperror.go ├── httperror_external_test.go ├── httperror_test.go ├── ip.go ├── ip_test.go ├── json.go ├── json_test.go ├── middleware/ │ ├── DEVELOPMENT.md │ ├── basic_auth.go │ ├── basic_auth_test.go │ ├── body_dump.go │ ├── body_dump_test.go │ ├── body_limit.go │ ├── body_limit_test.go │ ├── compress.go │ ├── compress_test.go │ ├── context_timeout.go │ ├── context_timeout_test.go │ ├── cors.go │ ├── cors_test.go │ ├── csrf.go │ ├── csrf_test.go │ ├── decompress.go │ ├── decompress_test.go │ ├── extractor.go │ ├── extractor_test.go │ ├── key_auth.go │ ├── key_auth_test.go │ ├── method_override.go │ ├── method_override_test.go │ ├── middleware.go │ ├── middleware_test.go │ ├── proxy.go │ ├── proxy_test.go │ ├── rate_limiter.go │ ├── 
rate_limiter_test.go │ ├── recover.go │ ├── recover_test.go │ ├── redirect.go │ ├── redirect_test.go │ ├── request_id.go │ ├── request_id_test.go │ ├── request_logger.go │ ├── request_logger_test.go │ ├── rewrite.go │ ├── rewrite_test.go │ ├── secure.go │ ├── secure_test.go │ ├── slash.go │ ├── slash_test.go │ ├── static.go │ ├── static_other.go │ ├── static_test.go │ ├── testdata/ │ │ ├── dist/ │ │ │ ├── private.txt │ │ │ └── public/ │ │ │ ├── assets/ │ │ │ │ ├── readme.md │ │ │ │ └── subfolder/ │ │ │ │ └── subfolder.md │ │ │ ├── index.html │ │ │ └── test.txt │ │ └── private.txt │ ├── util.go │ └── util_test.go ├── renderer.go ├── renderer_test.go ├── response.go ├── response_test.go ├── route.go ├── route_test.go ├── router.go ├── router_concurrent.go ├── router_concurrent_test.go ├── router_test.go ├── server.go ├── server_test.go ├── version.go ├── vhost.go └── vhost_test.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .editorconfig ================================================ # EditorConfig coding styles definitions. 
For more information about the # properties used in this file, please see the EditorConfig documentation: # http://editorconfig.org/ # indicate this is the root of the project root = true [*] charset = utf-8 end_of_line = LF insert_final_newline = true trim_trailing_whitespace = true indent_style = space indent_size = 2 [Makefile] indent_style = tab [*.md] trim_trailing_whitespace = false [*.go] indent_style = tab ================================================ FILE: .gitattributes ================================================ # Automatically normalize line endings for all text-based files # http://git-scm.com/docs/gitattributes#_end_of_line_conversion * text=auto # For the following file types, normalize line endings to LF on checkin and # prevent conversion to CRLF when they are checked out (this is required in # order to prevent newline-related issues) .* text eol=lf *.go text eol=lf *.yml text eol=lf *.html text eol=lf *.css text eol=lf *.js text eol=lf *.json text eol=lf LICENSE text eol=lf ================================================ FILE: .github/FUNDING.yml ================================================ # These are supported funding model platforms github: [labstack] patreon: # Replace with a single Patreon username open_collective: # Replace with a single Open Collective username ko_fi: # Replace with a single Ko-fi username tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry liberapay: # Replace with a single Liberapay username issuehunt: # Replace with a single IssueHunt username otechie: # Replace with a single Otechie username custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] ================================================ FILE: .github/ISSUE_TEMPLATE.md ================================================ ### Issue Description ### Working code to debug ```go package main import (
"github.com/labstack/echo/v5" "net/http" "net/http/httptest" "testing" ) func TestExample(t *testing.T) { e := echo.New() e.GET("/", func(c *echo.Context) error { return c.String(http.StatusOK, "Hello, World!") }) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("got %d, want %d", rec.Code, http.StatusOK) } } ``` ### Version/commit ================================================ FILE: .github/stale.yml ================================================ # Number of days of inactivity before an issue becomes stale daysUntilStale: 60 # Number of days of inactivity before a stale issue is closed daysUntilClose: 30 # Issues with these labels will never be considered stale exemptLabels: - pinned - security - bug - enhancement # Label to use when marking an issue as stale staleLabel: stale # Comment to post when marking an issue as stale. Set to `false` to disable markComment: > This issue has been automatically marked as stale because it has not had recent activity. It will be closed within a month if no further activity occurs. Thank you for your contributions. # Comment to post when closing a stale issue. Set to `false` to disable closeComment: false ================================================ FILE: .github/workflows/checks.yml ================================================ name: Run checks on: push: branches: - master pull_request: branches: - master workflow_dispatch: permissions: contents: read # to fetch code (actions/checkout) env: # run static analysis only with the latest Go version LATEST_GO_VERSION: "1.26" jobs: check: runs-on: ubuntu-latest steps: - name: Checkout Code uses: actions/checkout@v5 - name: Set up Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ env.LATEST_GO_VERSION }} check-latest: true - name: Run golint run: | go install golang.org/x/lint/golint@latest golint -set_exit_status ./... 
- name: Run staticcheck run: | go install honnef.co/go/tools/cmd/staticcheck@latest staticcheck ./... - name: Run govulncheck run: | go version go install golang.org/x/vuln/cmd/govulncheck@latest govulncheck ./... ================================================ FILE: .github/workflows/echo.yml ================================================ name: Run Tests on: push: branches: - master pull_request: branches: - master workflow_dispatch: permissions: contents: read # to fetch code (actions/checkout) env: # run coverage and benchmarks only with the latest Go version LATEST_GO_VERSION: "1.26" jobs: test: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy # Echo aims to test with the last four major releases (unless there are pressing vulnerabilities) # As we depend on `golang.org/x/` libraries which only support the last 2 Go releases, we could have situations when # we deviate from the promise of testing the last four major releases. go: ["1.25", "1.26"] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: - name: Checkout Code uses: actions/checkout@v5 - name: Set up Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Run Tests run: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./...
- name: Upload coverage to Codecov if: success() && matrix.go == env.LATEST_GO_VERSION && matrix.os == 'ubuntu-latest' uses: codecov/codecov-action@v5 with: token: fail_ci_if_error: false benchmark: needs: test name: Benchmark comparison runs-on: ubuntu-latest steps: - name: Checkout Code (Previous) uses: actions/checkout@v5 with: ref: ${{ github.base_ref }} path: previous - name: Checkout Code (New) uses: actions/checkout@v5 with: path: new - name: Set up Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ env.LATEST_GO_VERSION }} - name: Install Dependencies run: go install golang.org/x/perf/cmd/benchstat@latest - name: Run Benchmark (Previous) run: | cd previous go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt - name: Run Benchmark (New) run: | cd new go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt - name: Run Benchstat run: | benchstat previous/benchmark.txt new/benchmark.txt ================================================ FILE: .gitignore ================================================ .DS_Store coverage.txt _test vendor .idea *.iml *.out .vscode ================================================ FILE: API_CHANGES_V5.md ================================================ # Echo v5 Public API Changes **Comparison between `master` (v4.15.0) and `v5` (v5.0.0-alpha) branches** Generated: 2026-01-01 --- ## Executive Summary (by authors) Echo `v5` is a maintenance release with **major breaking changes** - `Context` is now a struct instead of an interface, so methods can be added to it in future minor versions. - Adds a new `Router` interface to allow alternative routing implementations. - Drops the old logging interface in favor of the modern `log/slog` package. - Rearranges a lot of method/function signatures to make them more consistent.
## Executive Summary (by LLMs) Echo v5 represents a **major breaking release** with significant architectural changes focused on: - **Updated generic helpers** to take `*Context` and rename form helpers to `FormValue*` - **Simplified API surface** by moving Context from interface to concrete struct - **Modern Go patterns** including slog.Logger integration - **Enhanced routing** with explicit RouteInfo and Routes types - **Better error handling** with simplified HTTPError - **New test helpers** via the `echotest` package ### Change Statistics - **Major Breaking Changes**: 15+ - **New Functions Added**: 30+ - **Type Signature Changes**: 20+ - **Removed APIs**: 10+ - **New Packages Added**: 1 (`echotest`) - **Version Change**: `4.15.0` → `5.0.0-alpha` --- ## Critical Breaking Changes ### 1. **Context: Interface → Concrete Struct** **v4 (master):** ```go type Context interface { Request() *http.Request // ... many methods } // Handler signature func handler(c echo.Context) error ``` **v5:** ```go type Context struct { // Has unexported fields } // Handler signature - NOW USES POINTER! func handler(c *echo.Context) error ``` **Impact:** 🔴 **CRITICAL BREAKING CHANGE** - ALL handlers must change from `echo.Context` to `*echo.Context` - Context is now a concrete struct, not an interface - This affects every single handler function in user code **Migration:** ```go // Before (v4) func MyHandler(c echo.Context) error { return c.JSON(200, map[string]string{"hello": "world"}) } // After (v5) func MyHandler(c *echo.Context) error { return c.JSON(200, map[string]string{"hello": "world"}) } ``` --- ### 2. **Logger: Custom Interface → slog.Logger** **v4:** ```go type Echo struct { Logger Logger // Custom interface with Print, Debug, Info, etc. } type Logger interface { Output() io.Writer SetOutput(w io.Writer) Prefix() string // ... 
many custom methods } // Context returns Logger interface func (c Context) Logger() Logger ``` **v5:** ```go type Echo struct { Logger *slog.Logger // Standard library structured logger } // Context returns slog.Logger func (c *Context) Logger() *slog.Logger func (c *Context) SetLogger(logger *slog.Logger) ``` **Impact:** 🔴 **BREAKING CHANGE** - Must use Go's standard `log/slog` package - Logger interface completely removed - All logging code needs updating --- ### 3. **Router: From Router to DefaultRouter** **v4:** ```go type Router struct { ... } func NewRouter(e *Echo) *Router func (e *Echo) Router() *Router ``` **v5:** ```go type DefaultRouter struct { ... } func NewRouter(config RouterConfig) *DefaultRouter func (e *Echo) Router() Router // Returns interface ``` **Changes:** - New `Router` interface introduced - `DefaultRouter` is the concrete implementation - `NewRouter()` now takes `RouterConfig` instead of `*Echo` - Added `NewConcurrentRouter(r Router) Router` for thread-safe routing --- ### 4. **Route Return Types Changed** **v4:** ```go func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route func (e *Echo) Any(path string, h HandlerFunc, m ...MiddlewareFunc) []*Route func (e *Echo) Routes() []*Route ``` **v5:** ```go func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo func (e *Echo) Any(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo func (e *Echo) Match(...) Routes // Returns Routes type func (e *Echo) Router() Router // Returns interface ``` **New Types:** ```go type RouteInfo struct { Name string Method string Path string Parameters []string } type Routes []RouteInfo // Collection with helper methods ``` **Impact:** 🔴 **BREAKING CHANGE** - Route registration methods return `RouteInfo` instead of `*Route` - New `Routes` collection type with filtering methods - `Route` struct still exists but used differently --- ### 5. 
**Response Type Changed** **v4:** ```go func (c Context) Response() *Response type Response struct { Writer http.ResponseWriter Status int Size int64 Committed bool } func NewResponse(w http.ResponseWriter, e *Echo) *Response ``` **v5:** ```go func (c *Context) Response() http.ResponseWriter type Response struct { http.ResponseWriter // Embedded Status int Size int64 Committed bool } func NewResponse(w http.ResponseWriter, logger *slog.Logger) *Response func UnwrapResponse(rw http.ResponseWriter) (*Response, error) ``` **Changes:** - Context.Response() returns `http.ResponseWriter` instead of `*Response` - Response now embeds `http.ResponseWriter` - NewResponse takes `*slog.Logger` instead of `*Echo` - New `UnwrapResponse()` helper function --- ### 6. **HTTPError Simplified** **v4:** ```go type HTTPError struct { Internal error Message interface{} // Can be any type Code int } func NewHTTPError(code int, message ...interface{}) *HTTPError ``` **v5:** ```go type HTTPError struct { Code int Message string // Now string only // Has unexported fields (Internal moved) } func NewHTTPError(code int, message string) *HTTPError func (he HTTPError) Wrap(err error) error // New method func (he *HTTPError) StatusCode() int // Implements HTTPStatusCoder ``` **Changes:** - `Message` field changed from `interface{}` to `string` - `NewHTTPError()` now takes `string` instead of `...interface{}` - Added `HTTPStatusCoder` interface and `StatusCode()` method - Added `Wrap(err error)` method for error wrapping --- ### 7. **HTTPErrorHandler Signature Changed** **v4:** ```go type HTTPErrorHandler func(err error, c Context) func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) ``` **v5:** ```go type HTTPErrorHandler func(c *Context, err error) // Parameters swapped! 
func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler // Now a factory ``` **Impact:** 🔴 **BREAKING CHANGE** - Parameter order reversed: `(c *Context, err error)` instead of `(err error, c Context)` - DefaultHTTPErrorHandler is now a factory function that returns HTTPErrorHandler - Takes `exposeError` bool to control error message exposure --- ## Notable API Changes in v5 ### 1. **Generic Parameter Extraction Functions (Updated Signatures)** These helpers keep the same generic API but now accept `*Context`, and the form helpers are renamed from `FormParam*` to `FormValue*`: ```go // Query Parameters func QueryParam[T any](c *Context, key string, opts ...any) (T, error) func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error) func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) // Path Parameters func PathParam[T any](c *Context, paramName string, opts ...any) (T, error) func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...any) (T, error) // Form Values func FormValue[T any](c *Context, key string, opts ...any) (T, error) func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) // Generic Parsing func ParseValue[T any](value string, opts ...any) (T, error) func ParseValueOr[T any](value string, defaultValue T, opts ...any) (T, error) func ParseValues[T any](values []string, opts ...any) ([]T, error) func ParseValuesOr[T any](values []string, defaultValue []T, opts ...any) ([]T, error) ``` `FormParam*` was renamed to `FormValue*`; the rest keep names but now take `*Context`. 
**Supported Types:** - bool, string - int, int8, int16, int32, int64 - uint, uint8, uint16, uint32, uint64 - float32, float64 - time.Time, time.Duration - BindUnmarshaler, encoding.TextUnmarshaler, json.Unmarshaler **Example Usage:** ```go // v5 - Type-safe parameter binding id, err := echo.PathParam[int](c, "id") page, err := echo.QueryParamOr[int](c, "page", 1) tags, err := echo.QueryParams[string](c, "tags") ``` --- ### 2. **Context Store Helpers Now Use `*Context`** ```go // Type-safe context value retrieval func ContextGet[T any](c *Context, key string) (T, error) func ContextGetOr[T any](c *Context, key string, defaultValue T) (T, error) // Error types var ErrNonExistentKey = errors.New("non existent key") var ErrInvalidKeyType = errors.New("invalid key type") ``` These helpers existed in v4 with `Context` and now accept `*Context`. **Example:** ```go // v5 user, err := echo.ContextGet[*User](c, "user") count, err := echo.ContextGetOr[int](c, "count", 0) ``` --- ### 3. **PathValues Type** New structured path parameter handling: ```go type PathValue struct { Name string Value string } type PathValues []PathValue func (p PathValues) Get(name string) (string, bool) func (p PathValues) GetOr(name string, defaultValue string) string // Context methods func (c *Context) PathValues() PathValues func (c *Context) SetPathValues(pathValues PathValues) ``` --- ### 4. **Time Parsing Options** ```go type TimeLayout string const ( TimeLayoutUnixTime = TimeLayout("UnixTime") TimeLayoutUnixTimeMilli = TimeLayout("UnixTimeMilli") TimeLayoutUnixTimeNano = TimeLayout("UnixTimeNano") ) type TimeOpts struct { Layout TimeLayout ParseInLocation *time.Location ToInLocation *time.Location } ``` --- ### 5. 
**StartConfig for Server Configuration** ```go type StartConfig struct { Address string HideBanner bool HidePort bool CertFilesystem fs.FS TLSConfig *tls.Config ListenerNetwork string ListenerAddrFunc func(addr net.Addr) GracefulTimeout time.Duration OnShutdownError func(err error) BeforeServeFunc func(s *http.Server) error } func (sc StartConfig) Start(ctx context.Context, h http.Handler) error func (sc StartConfig) StartTLS(ctx context.Context, h http.Handler, certFile, keyFile any) error ``` **Example:** ```go // v5 - More control over server startup ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer cancel() sc := echo.StartConfig{ Address: ":8080", GracefulTimeout: 10 * time.Second, } if err := sc.Start(ctx, e); err != nil { log.Fatal(err) } ``` --- ### 6. **Echo Config and Constructors** ```go type Config struct { // Configuration for Echo (logger, binder, renderer, etc.) } func NewWithConfig(config Config) *Echo ``` This adds a configuration struct for creating an `Echo` instance without mutating fields after `New()`. --- ### 7. **Enhanced Routing Features** ```go // New route methods func (e *Echo) AddRoute(route Route) (RouteInfo, error) func (e *Echo) Middlewares() []MiddlewareFunc func (e *Echo) PreMiddlewares() []MiddlewareFunc type AddRouteError struct{ ... } // Routes collection with filters type Routes []RouteInfo func (r Routes) Clone() Routes func (r Routes) FilterByMethod(method string) (Routes, error) func (r Routes) FilterByName(name string) (Routes, error) func (r Routes) FilterByPath(path string) (Routes, error) func (r Routes) FindByMethodPath(method string, path string) (RouteInfo, error) func (r Routes) Reverse(routeName string, pathValues ...any) (string, error) // RouteInfo operations func (r RouteInfo) Clone() RouteInfo func (r RouteInfo) Reverse(pathValues ...any) string ``` --- ### 8. 
**Middleware Configuration Interface** ```go type MiddlewareConfigurator interface { ToMiddleware() (MiddlewareFunc, error) } ``` Allows middleware configs to be converted to middleware without panicking. --- ### 9. **New Context Methods** ```go // v5 additions func (c *Context) FileFS(file string, filesystem fs.FS) error func (c *Context) FormValueOr(name, defaultValue string) string func (c *Context) InitializeRoute(ri *RouteInfo, pathValues *PathValues) func (c *Context) ParamOr(name, defaultValue string) string func (c *Context) QueryParamOr(name, defaultValue string) string func (c *Context) RouteInfo() RouteInfo ``` --- ### 10. **Virtual Host Support** ```go func NewVirtualHostHandler(vhosts map[string]*Echo) *Echo ``` Creates an Echo instance that routes requests to different Echo instances based on host. --- ### 11. **New Binder Functions** ```go func BindBody(c *Context, target any) error func BindHeaders(c *Context, target any) error func BindPathValues(c *Context, target any) error // Renamed from BindPathParams func BindQueryParams(c *Context, target any) error ``` Top-level binding functions that work with `*Context`. --- ### 12. **New echotest Package** ```go package echotest // import "github.com/labstack/echo/v5/echotest" func LoadBytes(t *testing.T, name string, opts ...loadBytesOpts) []byte func TrimNewlineEnd(bytes []byte) []byte type ContextConfig struct{ ... } type MultipartForm struct{ ... } type MultipartFormFile struct{ ... } ``` Helpers for loading fixtures and constructing test contexts. --- ## Removed APIs in v5 ### Constants ```go // v4 - Removed in v5 const CONNECT = http.MethodConnect // Use http.MethodConnect directly ``` **Reason:** Deprecated in v4, use stdlib `http.Method*` constants instead. 
--- ### Constants Added in v5 ```go // v5 additions const ( NotFoundRouteName = "echo_route_not_found_name" ) ``` --- ### Error Variable Changes **v4 exports:** ```go ErrBadRequest ErrInvalidKeyType ErrNonExistentKey ``` **v5 exports:** ```go ErrBadRequest // Now backed by unexported httpError type ErrValidatorNotRegistered // New ErrInvalidKeyType ErrNonExistentKey ``` **Reason:** v5 centralizes on `NewHTTPError(code, message)` rather than a broad set of predefined HTTP error variables. --- ### Functions Removed ```go // v4 - Removed in v5 func GetPath(r *http.Request) string // Use r.URL.Path or r.URL.RawPath ``` ### Variables Removed ```go // v4 - Removed in v5 var MethodNotAllowedHandler = func(c Context) error { ... } var NotFoundHandler = func(c Context) error { ... } ``` ### Functions Renamed ```go // v4 func FormParam[T any](c Context, key string, opts ...any) (T, error) func FormParamOr[T any](c Context, key string, defaultValue T, opts ...any) (T, error) func FormParams[T any](c Context, key string, opts ...any) ([]T, error) func FormParamsOr[T any](c Context, key string, defaultValue []T, opts ...any) ([]T, error) // v5 func FormValue[T any](c *Context, key string, opts ...any) (T, error) func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) ``` --- ### Type Methods Removed/Changed **Echo struct changes:** ```go // v4 fields removed in v5 type Echo struct { StdLogger *stdLog.Logger // Removed Server *http.Server // Removed (use StartConfig) TLSServer *http.Server // Removed (use StartConfig) Listener net.Listener // Removed (use StartConfig) TLSListener net.Listener // Removed (use StartConfig) AutoTLSManager autocert.Manager // Removed ListenerNetwork string // Removed OnAddRouteHandler func(...) 
// Changed to OnAddRoute DisableHTTP2 bool // Removed (use StartConfig) Debug bool // Removed HideBanner bool // Removed (use StartConfig) HidePort bool // Removed (use StartConfig) } // v5 Echo struct (simplified) type Echo struct { Binder Binder Filesystem fs.FS // NEW Renderer Renderer Validator Validator JSONSerializer JSONSerializer IPExtractor IPExtractor OnAddRoute func(route Route) error // Simplified HTTPErrorHandler HTTPErrorHandler Logger *slog.Logger // Changed from Logger interface } ``` --- **Context interface → struct:** ```go // v4 type Context interface { // Had: SetResponse(*Response) Response() *Response // Had: ParamNames(), SetParamNames(), ParamValues(), SetParamValues() // These are removed in v5 (use PathValues() instead) } // v5 type Context struct { // Concrete struct with unexported fields } func (c *Context) Response() http.ResponseWriter // Changed return type func (c *Context) PathValues() PathValues // Replaces ParamNames/Values ``` --- **Types removed:** ```go // v4 type Map map[string]interface{} ``` **Group changes:** ```go // v4 func (g *Group) File(path, file string) // No return value func (g *Group) Static(pathPrefix, fsRoot string) // No return value func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) // No return value // v5 func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo func (g *Group) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo ``` Now return `RouteInfo` and accept middleware. 
--- ### Value Binder Factory Name Changes ```go // v4 func PathParamsBinder(c Context) *ValueBinder func QueryParamsBinder(c Context) *ValueBinder func FormFieldBinder(c Context) *ValueBinder // v5 func PathValuesBinder(c *Context) *ValueBinder // Renamed func QueryParamsBinder(c *Context) *ValueBinder func FormFieldBinder(c *Context) *ValueBinder ``` --- ## Type Signature Changes ### Binder Interface ```go // v4 type Binder interface { Bind(i interface{}, c Context) error } // v5 type Binder interface { Bind(c *Context, target any) error // Parameters swapped! } ``` --- ### DefaultBinder Methods ```go // v4 func (b *DefaultBinder) Bind(i interface{}, c Context) error func (b *DefaultBinder) BindBody(c Context, i interface{}) error func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error // v5 func (b *DefaultBinder) Bind(c *Context, target any) error // Swapped params // BindBody, BindPathParams, etc. are now top-level functions ``` --- ### JSONSerializer Interface ```go // v4 type JSONSerializer interface { Serialize(c Context, i interface{}, indent string) error Deserialize(c Context, i interface{}) error } // v5 type JSONSerializer interface { Serialize(c *Context, target any, indent string) error Deserialize(c *Context, target any) error } ``` --- ### Renderer Interface ```go // v4 type Renderer interface { Render(io.Writer, string, interface{}, Context) error } // v5 type Renderer interface { Render(c *Context, w io.Writer, templateName string, data any) error } ``` Parameters reordered with Context first. --- ### NewBindingError ```go // v4 func NewBindingError(sourceParam string, values []string, message interface{}, internalError error) error // v5 func NewBindingError(sourceParam string, values []string, message string, err error) error ``` Message parameter changed from `interface{}` to `string`. --- ### HandlerName ```go // v5 only func HandlerName(h HandlerFunc) string ``` New utility function to get handler function name. 
--- ## Middleware Package Changes ### Signature and Type Updates ```go // CORS now accepts optional allow-origins func CORS(allowOrigins ...string) echo.MiddlewareFunc // BodyLimit now accepts bytes func BodyLimit(limitBytes int64) echo.MiddlewareFunc // DefaultSkipper now uses *echo.Context func DefaultSkipper(c *echo.Context) bool // Trailing slash configs renamed/split func AddTrailingSlashWithConfig(config AddTrailingSlashConfig) echo.MiddlewareFunc func RemoveTrailingSlashWithConfig(config RemoveTrailingSlashConfig) echo.MiddlewareFunc type AddTrailingSlashConfig struct{ ... } type RemoveTrailingSlashConfig struct{ ... } // Auth + extractor signatures now use *echo.Context and add ExtractorSource type BasicAuthValidator func(c *echo.Context, user string, password string) (bool, error) type Extractor func(c *echo.Context) (string, error) type ExtractorSource string type KeyAuthValidator func(c *echo.Context, key string, source ExtractorSource) (bool, error) type KeyAuthErrorHandler func(c *echo.Context, err error) error // BodyDump handler now includes err type BodyDumpHandler func(c *echo.Context, reqBody []byte, resBody []byte, err error) // ValuesExtractor now returns extractor source and CreateExtractors takes a limit type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error) func CreateExtractors(lookups string, limit uint) ([]ValuesExtractor, error) type ValueExtractorError struct{ ... } // New constants const KB = 1024 // Rate limiter store now takes a float64 limit func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore) ``` ### Added Middleware Exports ```go var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key") var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key") var RedirectHTTPSConfig = RedirectConfig{ ... } var RedirectHTTPSWWWConfig = RedirectConfig{ ... } var RedirectNonHTTPSWWWConfig = RedirectConfig{ ... } var RedirectNonWWWConfig = RedirectConfig{ ... 
} var RedirectWWWConfig = RedirectConfig{ ... } ``` ### Removed/Consolidated Middleware Exports ```go // Removed in v5 func Logger() echo.MiddlewareFunc func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc func Timeout() echo.MiddlewareFunc func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc type ErrKeyAuthMissing struct{ ... } type CSRFErrorHandler func(err error, c echo.Context) error type LoggerConfig struct{ ... } type LogErrorFunc func(c echo.Context, err error, stack []byte) error type TargetProvider interface{ ... } type TrailingSlashConfig struct{ ... } type TimeoutConfig struct{ ... } ``` Also removed defaults: `DefaultBasicAuthConfig`, `DefaultBodyDumpConfig`, `DefaultBodyLimitConfig`, `DefaultCORSConfig`, `DefaultDecompressConfig`, `DefaultGzipConfig`, `DefaultLoggerConfig`, `DefaultRedirectConfig`, `DefaultRequestIDConfig`, `DefaultRewriteConfig`, `DefaultTimeoutConfig`, `DefaultTrailingSlashConfig`. --- ## Router Interface Changes ### v4 Router (Concrete Struct) ```go type Router struct { ... } func NewRouter(e *Echo) *Router func (r *Router) Add(method, path string, h HandlerFunc) func (r *Router) Find(method, path string, c Context) func (r *Router) Reverse(name string, params ...interface{}) string func (r *Router) Routes() []*Route ``` ### v5 Router (Interface + DefaultRouter) ```go type Router interface { Add(routable Route) (RouteInfo, error) Remove(method string, path string) error Routes() Routes Route(c *Context) HandlerFunc } type DefaultRouter struct { ... 
} func NewRouter(config RouterConfig) *DefaultRouter func NewConcurrentRouter(r Router) Router // NEW type RouterConfig struct { NotFoundHandler HandlerFunc MethodNotAllowedHandler HandlerFunc OptionsMethodHandler HandlerFunc AllowOverwritingRoute bool UnescapePathParamValues bool UseEscapedPathForMatching bool } ``` **Key Changes:** - Router is now an interface - DefaultRouter is the concrete implementation - Add() returns `(RouteInfo, error)` instead of being void - New `Remove()` method - New `Route()` method replaces `Find()` - Configuration through `RouterConfig` --- ## Echo Instance Method Changes ### Route Registration ```go // v4 func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route // v5 func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo func (e *Echo) AddRoute(route Route) (RouteInfo, error) // NEW ``` ### Static File Serving ```go // v4 func (e *Echo) Static(pathPrefix, fsRoot string) *Route func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route // v5 func (e *Echo) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo ``` Return type changed from `*Route` to `RouteInfo`. 
### Server Management ```go // v4 func (e *Echo) Start(address string) error func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) error func (e *Echo) StartAutoTLS(address string) error func (e *Echo) StartH2CServer(address string, h2s *http2.Server) error func (e *Echo) StartServer(s *http.Server) error func (e *Echo) Shutdown(ctx context.Context) error func (e *Echo) Close() error func (e *Echo) ListenerAddr() net.Addr func (e *Echo) TLSListenerAddr() net.Addr func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) // v5 func (e *Echo) Start(address string) error // Simplified func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) // Removed: StartTLS, StartAutoTLS, StartH2CServer, StartServer // Use StartConfig instead for advanced server configuration // Removed: Shutdown, Close, ListenerAddr, TLSListenerAddr // Removed: DefaultHTTPErrorHandler (now a top-level factory function) ``` **v5 provides** `StartConfig` type for all advanced server configuration. 
### Router Access ```go // v4 func (e *Echo) Router() *Router func (e *Echo) Routers() map[string]*Router // For multi-host func (e *Echo) Routes() []*Route func (e *Echo) Reverse(name string, params ...interface{}) string func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string func (e *Echo) URL(h HandlerFunc, params ...interface{}) string func (e *Echo) Host(name string, m ...MiddlewareFunc) *Group // v5 func (e *Echo) Router() Router // Returns interface // Removed: Routers(), Reverse(), URI(), URL(), Host() // Use router.Routes() and Routes.Reverse() instead ``` --- ## NewContext Changes ```go // v4 func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context func NewResponse(w http.ResponseWriter, e *Echo) *Response // v5 func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) *Context func NewContext(r *http.Request, w http.ResponseWriter, opts ...any) *Context // Standalone func NewResponse(w http.ResponseWriter, logger *slog.Logger) *Response ``` --- ## Migration Guide Summary If you are using Linux you can migrate easier parts like that: ```bash find . -type f -name "*.go" -exec sed -i 's/ echo.Context/ *echo.Context/g' {} + find . -type f -name "*.go" -exec sed -i 's/echo\/v4/echo\/v5/g' {} + ``` or in your favorite IDE Replace all: 1. ` echo.Context` -> ` *echo.Context` 2. `echo/v4` -> `echo/v5` ### 1. Update All Handler Signatures ```go // Before func MyHandler(c echo.Context) error { ... } // After func MyHandler(c *echo.Context) error { ... } ``` ### 2. Update Logger Usage ```go // Before e.Logger.Info("Server started") c.Logger().Error("Something went wrong") // After e.Logger.Info("Server started") c.Logger().Error("Something went wrong") // Same API, different logger ``` ### 3. Use Type-Safe Parameter Extraction ```go // Before idStr := c.Param("id") id, err := strconv.Atoi(idStr) // After id, err := echo.PathParam[int](c, "id") ``` ### 4. 
Update Error Handler ```go // Before e.HTTPErrorHandler = func(err error, c echo.Context) { // handle error } // After e.HTTPErrorHandler = func(c *echo.Context, err error) { // Swapped! // handle error } // Or use factory e.HTTPErrorHandler = echo.DefaultHTTPErrorHandler(true) // exposeError=true ``` ### 5. Update Server Startup ```go // Before e.Start(":8080") e.StartTLS(":443", "cert.pem", "key.pem") // After // Simple e.Start(":8080") // Advanced with graceful shutdown ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) defer cancel() sc := echo.StartConfig{Address: ":8080"} sc.Start(ctx, e) ``` ### 6. Update Route Info Access ```go // Before routes := e.Routes() for _, r := range routes { fmt.Println(r.Method, r.Path) } // After routes := e.Router().Routes() for _, r := range routes { fmt.Println(r.Method, r.Path) } ``` ### 7. Update HTTPError Creation ```go // Before return echo.NewHTTPError(400, "invalid request", someDetail) // After return echo.NewHTTPError(400, "invalid request") ``` ### 8. Update Custom Binder ```go // Before type MyBinder struct{} func (b *MyBinder) Bind(i interface{}, c echo.Context) error { ... } // After type MyBinder struct{} func (b *MyBinder) Bind(c *echo.Context, target any) error { ... } // Swapped! ``` ### 9. Path Parameters ```go // Before names := c.ParamNames() values := c.ParamValues() // After pathValues := c.PathValues() for _, pv := range pathValues { fmt.Println(pv.Name, pv.Value) } ``` ### 10. 
Response Access ```go // Before resp := c.Response() resp.Header().Set("X-Custom", "value") // After c.Response().Header().Set("X-Custom", "value") // Returns http.ResponseWriter // To get *echo.Response resp, err := echo.UnwrapResponse(c.Response()) ``` ### Go Version Requirements - **v4**: Go 1.24.0 (per `go.mod`) - **v5**: Go 1.25.0 (per `go.mod`) --- **Generated by comparing `go doc` output from master (v4.15.0) and v5 (v5.0.0-alpha) branches** ================================================ FILE: CHANGELOG.md ================================================ # Changelog ## v5.0.4 - 2026-02-15 **Enhancements** * Remove unused import 'errors' from README example by @kumapower17 in https://github.com/labstack/echo/pull/2889 * Fix Graceful shutdown: after `http.Server.Serve` returns we need to wait for graceful shutdown goroutine to finish by @aldas in https://github.com/labstack/echo/pull/2898 * Update location of oapi-codegen in README by @mromaszewicz in https://github.com/labstack/echo/pull/2896 * Add Go 1.26 to CI flow by @aldas in https://github.com/labstack/echo/pull/2899 * Add new function `echo.StatusCode` by @suwakei in https://github.com/labstack/echo/pull/2892 * CSRF: support older token-based CSRF protection handler that want to render token into template by @aldas in https://github.com/labstack/echo/pull/2894 * Add `echo.ResolveResponseStatus` function to help middleware/handlers determine HTTP status code and echo.Response by @aldas in https://github.com/labstack/echo/pull/2900 ## v5.0.3 - 2026-02-06 **Security** * Fix directory traversal vulnerability under Windows in Static middleware when default Echo filesystem is used. Reported by @shblue21. This applies to cases when: - Windows is used as OS - `middleware.StaticConfig.Filesystem` is `nil` (default) - `echo.Filesystem` has not been set explicitly (default) Exposure is restricted to the active process working directory and its subfolders.
## v5.0.2 - 2026-02-02 **Security** * Fix Static middleware with `config.Browse=true` listing all files/subfolders from the `config.Filesystem` root instead of starting from `config.Root` in https://github.com/labstack/echo/pull/2887 ## v5.0.1 - 2026-01-28 * Panic MW: will now return a custom PanicStackError with stack trace by @aldas in https://github.com/labstack/echo/pull/2871 * Docs: add missing err parameter to DenyHandler example by @cgalibern in https://github.com/labstack/echo/pull/2878 * improve: improve websocket checks in IsWebSocket() [per RFC 6455] by @raju-mechatronics in https://github.com/labstack/echo/pull/2875 * fix: Context.Json() should not send status code before serialization is complete by @aldas in https://github.com/labstack/echo/pull/2877 ## v5.0.0 - 2026-01-18 Echo `v5` is a maintenance release with **major breaking changes** - `Context` is now a struct instead of an interface, so methods can be added to it in future minor versions. - Adds new `Router` interface for possible new routing implementations. - Drops old logging interface and uses modern `log/slog` instead. - Rearranges a lot of methods/function signatures to make them more consistent. Upgrade notes and `v4` support: - Echo `v4` is supported with **security** updates and **bug** fixes until **2026-12-31** - If you are using Echo in a production environment, it is recommended to wait until after 2026-03-31 before upgrading. - Until 2026-03-31, any critical issues requiring breaking `v5` API changes will be addressed, even if this violates semantic versioning. See [API_CHANGES_V5.md](./API_CHANGES_V5.md) for public API changes between `v4` and `v5` and notes on **upgrading**. Upgrading TLDR: If you are using Linux you can migrate the easier parts like this: ```bash find . -type f -name "*.go" -exec sed -i 's/ echo.Context/ *echo.Context/g' {} + find . -type f -name "*.go" -exec sed -i 's/echo\/v4/echo\/v5/g' {} + ``` macOS ```bash find .
-type f -name "*.go" -exec sed -i '' 's/ echo.Context/ *echo.Context/g' {} + find . -type f -name "*.go" -exec sed -i '' 's/echo\/v4/echo\/v5/g' {} + ``` or in your favorite IDE Replace all: 1. ` echo.Context` -> ` *echo.Context` 2. `echo/v4` -> `echo/v5` This should solve most of the issues. Probably the hardest part is updating all the tests. ## v4.15.0 - 2026-01-01 **Security** NB: **If your application relies on cross-origin or same-site (same subdomain) requests do not blindly push this version to production** The CSRF middleware now supports the [**Sec-Fetch-Site**](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Sec-Fetch-Site) header as a modern, defense-in-depth approach to [CSRF protection](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers), implementing the OWASP-recommended Fetch Metadata API alongside the traditional token-based mechanism. **How it works:** Modern browsers automatically send the `Sec-Fetch-Site` header with all requests, indicating the relationship between the request origin and the target. The middleware uses this to make security decisions: - **`same-origin`** or **`none`**: Requests are allowed (exact origin match or direct user navigation) - **`same-site`**: Falls back to token validation (e.g., subdomain to main domain) - **`cross-site`**: Blocked by default with 403 error for unsafe methods (POST, PUT, DELETE, PATCH) For browsers that don't send this header (older browsers), the middleware seamlessly falls back to traditional token-based CSRF protection. 
**New Configuration Options:** - `TrustedOrigins []string`: Allowlist specific origins for cross-site requests (useful for OAuth callbacks, webhooks) - `AllowSecFetchSiteFunc func(echo.Context) (bool, error)`: Custom logic for same-site/cross-site request validation **Example:** ```go e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{ // Allow OAuth callbacks from trusted provider TrustedOrigins: []string{"https://oauth-provider.com"}, // Custom validation for same-site requests AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { // Your custom authorization logic here return validateCustomAuth(c), nil // return true, err // blocks request with error // return true, nil // allows CSRF request through // return false, nil // falls back to legacy token logic }, })) ``` PR: https://github.com/labstack/echo/pull/2858 **Type-Safe Generic Parameter Binding** * Added generic functions for type-safe parameter extraction and context access by @aldas in https://github.com/labstack/echo/pull/2856 Echo now provides generic functions for extracting path, query, and form parameters with automatic type conversion, eliminating manual string parsing and type assertions. **New Functions:** - Path parameters: `PathParam[T]`, `PathParamOr[T]` - Query parameters: `QueryParam[T]`, `QueryParamOr[T]`, `QueryParams[T]`, `QueryParamsOr[T]` - Form values: `FormParam[T]`, `FormParamOr[T]`, `FormParams[T]`, `FormParamsOr[T]` - Context store: `ContextGet[T]`, `ContextGetOr[T]` **Supported Types:** Primitives (`bool`, `string`, `int`/`uint` variants, `float32`/`float64`), `time.Duration`, `time.Time` (with custom layouts and Unix timestamp support), and custom types implementing `BindUnmarshaler`, `TextUnmarshaler`, or `JSONUnmarshaler`. 
**Example:** ```go // Before: Manual parsing idStr := c.Param("id") id, err := strconv.Atoi(idStr) // After: Type-safe with automatic parsing id, err := echo.PathParam[int](c, "id") // With default values page, err := echo.QueryParamOr[int](c, "page", 1) limit, err := echo.QueryParamOr[int](c, "limit", 20) // Type-safe context access (no more panics from type assertions) user, err := echo.ContextGet[*User](c, "user") ``` PR: https://github.com/labstack/echo/pull/2856 **DEPRECATION NOTICE** Timeout Middleware Deprecated - Use ContextTimeout Instead The `middleware.Timeout` middleware has been **deprecated** due to fundamental architectural issues that cause data races. Use `middleware.ContextTimeout` or `middleware.ContextTimeoutWithConfig` instead. **Why is this being deprecated?** The Timeout middleware manipulates response writers across goroutine boundaries, which causes data races that cannot be reliably fixed without a complete architectural redesign. The middleware: - Swaps the response writer using `http.TimeoutHandler` - Must be the first middleware in the chain (fragile constraint) - Can cause races with other middleware (Logger, metrics, custom middleware) - Has been the source of multiple race condition fixes over the years **What should you use instead?** The `ContextTimeout` middleware (available since v4.12.0) provides timeout functionality using Go's standard context mechanism. It is: - Race-free by design - Can be placed anywhere in the middleware chain - Simpler and more maintainable - Compatible with all other middleware **Migration Guide:** ```go // Before (deprecated): e.Use(middleware.Timeout()) // After (recommended): e.Use(middleware.ContextTimeout(30 * time.Second)) ``` **Important Behavioral Differences:** 1. **Handler cooperation required**: With ContextTimeout, your handlers must check `context.Done()` for cooperative cancellation. 
The old Timeout middleware would send a 503 response regardless of handler cooperation, but had data race issues. 2. **Error handling**: ContextTimeout returns errors through the standard error handling flow. Handlers that receive `context.DeadlineExceeded` should handle it appropriately: ```go e.GET("/long-task", func(c echo.Context) error { ctx := c.Request().Context() // Example: database query with context result, err := db.QueryContext(ctx, "SELECT * FROM large_table") if err != nil { if errors.Is(err, context.DeadlineExceeded) { // Handle timeout return echo.NewHTTPError(http.StatusServiceUnavailable, "Request timeout") } return err } return c.JSON(http.StatusOK, result) }) ``` 3. **Background tasks**: For long-running background tasks, use goroutines with context: ```go e.GET("/async-task", func(c echo.Context) error { ctx := c.Request().Context() resultCh := make(chan Result, 1) errCh := make(chan error, 1) go func() { result, err := performLongTask(ctx) if err != nil { errCh <- err return } resultCh <- result }() select { case result := <-resultCh: return c.JSON(http.StatusOK, result) case err := <-errCh: return err case <-ctx.Done(): return echo.NewHTTPError(http.StatusServiceUnavailable, "Request timeout") } }) ``` **Enhancements** * Fixes by @aldas in https://github.com/labstack/echo/pull/2852 * Generic functions by @aldas in https://github.com/labstack/echo/pull/2856 * CSRF with Sec-Fetch-Site checks by @aldas in https://github.com/labstack/echo/pull/2858 ## v4.14.0 - 2025-12-11 `middleware.Logger` has been deprecated. For request logging, use `middleware.RequestLogger` or `middleware.RequestLoggerWithConfig`. `middleware.RequestLogger` replaces `middleware.Logger`, offering comparable configuration while relying on the Go standard library’s new `slog` logger. The previous default output format was JSON. The new default follows the standard `slog` logger settings.
To continue emitting request logs in JSON, configure `slog` accordingly: ```go slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil))) e.Use(middleware.RequestLogger()) ``` **Security** * Logger middleware json string escaping and deprecation by @aldas in https://github.com/labstack/echo/pull/2849 **Enhancements** * Update deps by @aldas in https://github.com/labstack/echo/pull/2807 * refactor to use reflect.TypeFor by @cuiweixie in https://github.com/labstack/echo/pull/2812 * Use Go 1.25 in CI by @aldas in https://github.com/labstack/echo/pull/2810 * Modernize context.go by replacing interface{} with any by @vishr in https://github.com/labstack/echo/pull/2822 * Fix typo in SetParamValues comment by @vishr in https://github.com/labstack/echo/pull/2828 * Fix typo in ContextTimeout middleware comment by @vishr in https://github.com/labstack/echo/pull/2827 * Improve BasicAuth middleware: use strings.Cut and RFC compliance by @vishr in https://github.com/labstack/echo/pull/2825 * Fix duplicate plus operator in router backtracking logic by @yuya-morimoto in https://github.com/labstack/echo/pull/2832 * Replace custom private IP range check with built-in net.IP.IsPrivate by @kumapower17 in https://github.com/labstack/echo/pull/2835 * Ensure proxy connection is closed in proxyRaw function(#2837) by @kumapower17 in https://github.com/labstack/echo/pull/2838 * Update deps by @aldas in https://github.com/labstack/echo/pull/2843 * Update golang.org/x/* deps by @aldas in https://github.com/labstack/echo/pull/2850 ## v4.13.4 - 2025-05-22 **Enhancements** * chore: fix some typos in comment by @zhuhaicity in https://github.com/labstack/echo/pull/2735 * CI: test with Go 1.24 by @aldas in https://github.com/labstack/echo/pull/2748 * Add support for TLS WebSocket proxy by @t-ibayashi-safie in https://github.com/labstack/echo/pull/2762 **Security** * Update dependencies for [GO-2025-3487](https://pkg.go.dev/vuln/GO-2025-3487), 
[GO-2025-3503](https://pkg.go.dev/vuln/GO-2025-3503) and [GO-2025-3595](https://pkg.go.dev/vuln/GO-2025-3595) in https://github.com/labstack/echo/pull/2780 ## v4.13.3 - 2024-12-19 **Security** * Update golang.org/x/net dependency [GO-2024-3333](https://pkg.go.dev/vuln/GO-2024-3333) in https://github.com/labstack/echo/pull/2722 ## v4.13.2 - 2024-12-12 **Security** * Update dependencies (dependabot reports [GO-2024-3321](https://pkg.go.dev/vuln/GO-2024-3321)) in https://github.com/labstack/echo/pull/2721 ## v4.13.1 - 2024-12-11 **Fixes** * Fix BindBody ignoring `Transfer-Encoding: chunked` requests by @178inaba in https://github.com/labstack/echo/pull/2717 ## v4.13.0 - 2024-12-04 **BREAKING CHANGE** JWT Middleware Removed from Core use [labstack/echo-jwt](https://github.com/labstack/echo-jwt) instead The JWT middleware has been **removed from Echo core** due to another security vulnerability, [CVE-2024-51744](https://nvd.nist.gov/vuln/detail/CVE-2024-51744). For more details, refer to issue [#2699](https://github.com/labstack/echo/issues/2699). A drop-in replacement is available in the [labstack/echo-jwt](https://github.com/labstack/echo-jwt) repository. **Important**: Direct assignments like `token := c.Get("user").(*jwt.Token)` will now cause a panic due to an invalid cast. Update your code accordingly. Replace the current imports from `"github.com/golang-jwt/jwt"` in your handlers to the new middleware version using `"github.com/golang-jwt/jwt/v5"`. Background: The version of `golang-jwt/jwt` (v3.2.2) previously used in Echo core has been in an unmaintained state for some time. This is not the first vulnerability affecting this library; earlier issues were addressed in [PR #1946](https://github.com/labstack/echo/pull/1946). JWT middleware was marked as deprecated in Echo core as of [v4.10.0](https://github.com/labstack/echo/releases/tag/v4.10.0) on 2022-12-27. 
If you did not notice that, consider leveraging tools like [Staticcheck](https://staticcheck.dev/) to catch such deprecations earlier in your dev/CI flow. For bonus points - check out [gosec](https://github.com/securego/gosec). We sincerely apologize for any inconvenience caused by this change. While we strive to maintain backward compatibility within Echo core, recurring security issues with third-party dependencies have forced this decision. **Enhancements** * remove jwt middleware by @stevenwhitehead in https://github.com/labstack/echo/pull/2701 * optimization: struct alignment by @behnambm in https://github.com/labstack/echo/pull/2636 * bind: Maintain backwards compatibility for map[string]interface{} binding by @thesaltree in https://github.com/labstack/echo/pull/2656 * Add Go 1.23 to CI by @aldas in https://github.com/labstack/echo/pull/2675 * improve `MultipartForm` test by @martinyonatann in https://github.com/labstack/echo/pull/2682 * `bind` : add support of multipart multi files by @martinyonatann in https://github.com/labstack/echo/pull/2684 * Add TemplateRenderer struct to ease creating renderers for `html/template` and `text/template` packages.
by @aldas in https://github.com/labstack/echo/pull/2690 * Refactor TestBasicAuth to utilize table-driven test format by @ErikOlson in https://github.com/labstack/echo/pull/2688 * Remove broken header by @aldas in https://github.com/labstack/echo/pull/2705 * fix(bind body): content-length can be -1 by @phamvinhdat in https://github.com/labstack/echo/pull/2710 * CORS middleware should compile allowOrigin regexp at creation by @aldas in https://github.com/labstack/echo/pull/2709 * Shorten Github issue template and add test example by @aldas in https://github.com/labstack/echo/pull/2711 ## v4.12.0 - 2024-04-15 **Security** * Update golang.org/x/net dep because of [GO-2024-2687](https://pkg.go.dev/vuln/GO-2024-2687) by @aldas in https://github.com/labstack/echo/pull/2625 **Enhancements** * binder: make binding to Map work better with string destinations by @aldas in https://github.com/labstack/echo/pull/2554 * README.md: add Encore as sponsor by @marcuskohlberg in https://github.com/labstack/echo/pull/2579 * Reorder paragraphs in README.md by @aldas in https://github.com/labstack/echo/pull/2581 * CI: upgrade actions/checkout to v4 by @aldas in https://github.com/labstack/echo/pull/2584 * Remove default charset from 'application/json' Content-Type header by @doortts in https://github.com/labstack/echo/pull/2568 * CI: Use Go 1.22 by @aldas in https://github.com/labstack/echo/pull/2588 * binder: allow binding to a nil map by @georgmu in https://github.com/labstack/echo/pull/2574 * Add Skipper Unit Test In BasicBasicAuthConfig and Add More Detail Explanation regarding BasicAuthValidator by @RyoKusnadi in https://github.com/labstack/echo/pull/2461 * fix some typos by @teslaedison in https://github.com/labstack/echo/pull/2603 * fix: some typos by @pomadev in https://github.com/labstack/echo/pull/2596 * Allow ResponseWriters to unwrap writers when flushing/hijacking by @aldas in https://github.com/labstack/echo/pull/2595 * Add SPDX licence comments to files. 
by @aldas in https://github.com/labstack/echo/pull/2604 * Upgrade deps by @aldas in https://github.com/labstack/echo/pull/2605 * Change type definition blocks to single declarations. This helps copy… by @aldas in https://github.com/labstack/echo/pull/2606 * Fix Real IP logic by @cl-bvl in https://github.com/labstack/echo/pull/2550 * Default binder can use `UnmarshalParams(params []string) error` inter… by @aldas in https://github.com/labstack/echo/pull/2607 * Default binder can bind pointer to slice as struct field. For example `*[]string` by @aldas in https://github.com/labstack/echo/pull/2608 * Remove maxparam dependence from Context by @aldas in https://github.com/labstack/echo/pull/2611 * When route is registered with empty path it is normalized to `/`. by @aldas in https://github.com/labstack/echo/pull/2616 * proxy middleware should use httputil.ReverseProxy for SSE requests by @aldas in https://github.com/labstack/echo/pull/2624 ## v4.11.4 - 2023-12-20 **Security** * Upgrade golang.org/x/crypto to v0.17.0 to fix vulnerability [issue](https://pkg.go.dev/vuln/GO-2023-2402) [#2562](https://github.com/labstack/echo/pull/2562) **Enhancements** * Update deps and mark Go version to 1.18 as this is what golang.org/x/* use [#2563](https://github.com/labstack/echo/pull/2563) * Request logger: add example for Slog https://pkg.go.dev/log/slog [#2543](https://github.com/labstack/echo/pull/2543) ## v4.11.3 - 2023-11-07 **Security** * 'c.Attachment' and 'c.Inline' should escape filename in 'Content-Disposition' header to avoid 'Reflect File Download' vulnerability. 
[#2541](https://github.com/labstack/echo/pull/2541) **Enhancements** * Tests: refactor context tests to be separate functions [#2540](https://github.com/labstack/echo/pull/2540) * Proxy middleware: reuse echo request context [#2537](https://github.com/labstack/echo/pull/2537) * Mark unmarshallable yaml struct tags as ignored [#2536](https://github.com/labstack/echo/pull/2536) ## v4.11.2 - 2023-10-11 **Security** * Bump golang.org/x/net to prevent CVE-2023-39325 / CVE-2023-44487 HTTP/2 Rapid Reset Attack [#2527](https://github.com/labstack/echo/pull/2527) * fix(sec): randomString bias introduced by #2490 [#2492](https://github.com/labstack/echo/pull/2492) * CSRF/RequestID mw: switch math/rand usage to crypto/rand [#2490](https://github.com/labstack/echo/pull/2490) **Enhancements** * Delete unused context in body_limit.go [#2483](https://github.com/labstack/echo/pull/2483) * Use Go 1.21 in CI [#2505](https://github.com/labstack/echo/pull/2505) * Fix some typos [#2511](https://github.com/labstack/echo/pull/2511) * Allow CORS middleware to send Access-Control-Max-Age: 0 [#2518](https://github.com/labstack/echo/pull/2518) * Bump dependencies [#2522](https://github.com/labstack/echo/pull/2522) ## v4.11.1 - 2023-07-16 **Fixes** * Fix `Gzip` middleware not sending response code for no content responses (404, 301/302 redirects etc) [#2481](https://github.com/labstack/echo/pull/2481) ## v4.11.0 - 2023-07-14 **Fixes** * Fixes the proxy middleware concurrency issue of calling the Next() proxy target on Round Robin Balancer [#2409](https://github.com/labstack/echo/pull/2409) * Fix `group.RouteNotFound` not working when group has attached middlewares [#2411](https://github.com/labstack/echo/pull/2411) * Fix global error handler return error message when message is an error [#2456](https://github.com/labstack/echo/pull/2456) * Do not use global timeNow variables [#2477](https://github.com/labstack/echo/pull/2477) **Enhancements** * Added an optional config variable to disable
centralized error handler in recovery middleware [#2410](https://github.com/labstack/echo/pull/2410) * refactor: use `strings.ReplaceAll` directly [#2424](https://github.com/labstack/echo/pull/2424) * Add support for Go1.20 `http.rwUnwrapper` to Response struct [#2425](https://github.com/labstack/echo/pull/2425) * Check whether is nil before invoking centralized error handling [#2429](https://github.com/labstack/echo/pull/2429) * Proper colon support in `echo.Reverse` method [#2416](https://github.com/labstack/echo/pull/2416) * Fix misuses of a vs an in documentation comments [#2436](https://github.com/labstack/echo/pull/2436) * Add link to slog.Handler library for Echo logging into README.md [#2444](https://github.com/labstack/echo/pull/2444) * In proxy middleware, support retries of failed proxy requests [#2414](https://github.com/labstack/echo/pull/2414) * gofmt fixes to comments [#2452](https://github.com/labstack/echo/pull/2452) * gzip response only if it exceeds a minimal length [#2267](https://github.com/labstack/echo/pull/2267) * Upgrade packages [#2475](https://github.com/labstack/echo/pull/2475) ## v4.10.2 - 2023-02-22 **Security** * `filepath.Clean` behaviour has changed in Go 1.20 - adapt to it [#2406](https://github.com/labstack/echo/pull/2406) * Add `middleware.CORSConfig.UnsafeWildcardOriginWithAllowCredentials` to make UNSAFE usages of wildcard origin + allow credentials less likely [#2405](https://github.com/labstack/echo/pull/2405) **Enhancements** * Add more HTTP error values [#2277](https://github.com/labstack/echo/pull/2277) ## v4.10.1 - 2023-02-19 **Security** * Upgrade deps due to the latest golang.org/x/net vulnerability [#2402](https://github.com/labstack/echo/pull/2402) **Enhancements** * Add new JWT repository to the README [#2377](https://github.com/labstack/echo/pull/2377) * Return an empty string for ctx.path if there is no registered path [#2385](https://github.com/labstack/echo/pull/2385) * Add context timeout middleware
[#2380](https://github.com/labstack/echo/pull/2380) * Update link to jaegertracing [#2394](https://github.com/labstack/echo/pull/2394) ## v4.10.0 - 2022-12-27 **Security** * We are deprecating JWT middleware in this repository. Please use https://github.com/labstack/echo-jwt instead. JWT middleware is moved to separate repository to allow us to bump/upgrade version of JWT implementation (`github.com/golang-jwt/jwt`) we are using which we can not do in Echo core because this would break backwards compatibility guarantees we try to maintain. * This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are several vulnerabilities fixed in these libraries. Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise. **Enhancements** * Bump x/text to 0.3.8 [#2305](https://github.com/labstack/echo/pull/2305) * Bump dependencies and add notes about Go releases we support [#2336](https://github.com/labstack/echo/pull/2336) * Add helper interface for ProxyBalancer interface [#2316](https://github.com/labstack/echo/pull/2316) * Expose `middleware.CreateExtractors` function so we can use it from echo-contrib repository [#2338](https://github.com/labstack/echo/pull/2338) * Refactor func(Context) error to HandlerFunc [#2315](https://github.com/labstack/echo/pull/2315) * Improve function comments [#2329](https://github.com/labstack/echo/pull/2329) * Add new method HTTPError.WithInternal [#2340](https://github.com/labstack/echo/pull/2340) * Replace io/ioutil package usages [#2342](https://github.com/labstack/echo/pull/2342) * Add staticcheck to CI flow [#2343](https://github.com/labstack/echo/pull/2343) * Replace relative path determination from proprietary to std [#2345](https://github.com/labstack/echo/pull/2345) * Remove square brackets from ipv6 addresses in XFF (X-Forwarded-For header) [#2182](https://github.com/labstack/echo/pull/2182) * Add testcases for some BodyLimit 
middleware configuration options [#2350](https://github.com/labstack/echo/pull/2350) * Additional configuration options for RequestLogger and Logger middleware [#2341](https://github.com/labstack/echo/pull/2341) * Add route to request log [#2162](https://github.com/labstack/echo/pull/2162) * GitHub Workflows security hardening [#2358](https://github.com/labstack/echo/pull/2358) * Add govulncheck to CI and bump dependencies [#2362](https://github.com/labstack/echo/pull/2362) * Fix rate limiter docs [#2366](https://github.com/labstack/echo/pull/2366) * Refactor how `e.Routes()` works and introduce `e.OnAddRouteHandler` callback [#2337](https://github.com/labstack/echo/pull/2337) ## v4.9.1 - 2022-10-12 **Fixes** * Fix logger panicking (when template is set to empty) by bumping dependency version [#2295](https://github.com/labstack/echo/issues/2295) **Enhancements** * Improve CORS documentation [#2272](https://github.com/labstack/echo/pull/2272) * Update readme about supported Go versions [#2291](https://github.com/labstack/echo/pull/2291) * Tests: improve error handling on closing body [#2254](https://github.com/labstack/echo/pull/2254) * Tests: refactor some of the assertions in tests [#2275](https://github.com/labstack/echo/pull/2275) * Tests: refactor assertions [#2301](https://github.com/labstack/echo/pull/2301) ## v4.9.0 - 2022-09-04 **Security** * Fix open redirect vulnerability in handlers serving static directories (e.Static, e.StaticFs, echo.StaticDirectoryHandler) [#2260](https://github.com/labstack/echo/pull/2260) **Enhancements** * Allow configuring ErrorHandler in CSRF middleware [#2257](https://github.com/labstack/echo/pull/2257) * Replace HTTP method constants in tests with stdlib constants [#2247](https://github.com/labstack/echo/pull/2247) ## v4.8.0 - 2022-08-10 **Most notable things** You can now add any arbitrary HTTP method type as a route [#2237](https://github.com/labstack/echo/pull/2237) ```go e.Add("COPY", "/*", func(c echo.Context) error { return
c.String(http.StatusOK, "OK COPY") }) ``` You can add a custom 404 handler for specific paths [#2217](https://github.com/labstack/echo/pull/2217) ```go e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) }) g := e.Group("/images") g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) }) ``` **Enhancements** * Add new value binding methods (UnixTimeMilli, TextUnmarshaler, JSONUnmarshaler) to ValueBinder [#2127](https://github.com/labstack/echo/pull/2127) * Refactor: body_limit middleware unit test [#2145](https://github.com/labstack/echo/pull/2145) * Refactor: Timeout mw: rework how test waits for timeout. [#2187](https://github.com/labstack/echo/pull/2187) * BasicAuth middleware returns 500 InternalServerError on invalid base64 strings but should return 400 [#2191](https://github.com/labstack/echo/pull/2191) * Refactor: duplicated findStaticChild process at findChildWithLabel [#2176](https://github.com/labstack/echo/pull/2176) * Allow different param names in different methods with same path scheme [#2209](https://github.com/labstack/echo/pull/2209) * Add support for registering handlers for different 404 routes [#2217](https://github.com/labstack/echo/pull/2217) * Middlewares should use errors.As() instead of type assertion on HTTPError [#2227](https://github.com/labstack/echo/pull/2227) * Allow arbitrary HTTP method types to be added as routes [#2237](https://github.com/labstack/echo/pull/2237) ## v4.7.2 - 2022-03-16 **Fixes** * Fix nil pointer exception when calling Start again after address binding error [#2131](https://github.com/labstack/echo/pull/2131) * Fix CSRF middleware not being able to extract token from multipart/form-data form [#2136](https://github.com/labstack/echo/pull/2136) * Fix Timeout middleware write race [#2126](https://github.com/labstack/echo/pull/2126) **Enhancements** * Recover middleware should not log panic for aborted handler
[#2134](https://github.com/labstack/echo/pull/2134) ## v4.7.1 - 2022-03-13 **Fixes** * Fix `e.Static`, `.File()`, `c.Attachment()` being picky with paths starting with `./`, `../` and `/` after 4.7.0 introduced echo.Filesystem support (Go1.16+) [#2123](https://github.com/labstack/echo/pull/2123) **Enhancements** * Remove some unused code [#2116](https://github.com/labstack/echo/pull/2116) ## v4.7.0 - 2022-03-01 **Enhancements** * Add JWT, KeyAuth, CSRF multivalue extractors [#2060](https://github.com/labstack/echo/pull/2060) * Add LogErrorFunc to recover middleware [#2072](https://github.com/labstack/echo/pull/2072) * Add support for HEAD method query params binding [#2027](https://github.com/labstack/echo/pull/2027) * Improve filesystem support with echo.FileFS, echo.StaticFS, group.FileFS, group.StaticFS [#2064](https://github.com/labstack/echo/pull/2064) **Fixes** * Fix X-Real-IP bug, improve tests [#2007](https://github.com/labstack/echo/pull/2007) * Minor syntax fixes [#1994](https://github.com/labstack/echo/pull/1994), [#2102](https://github.com/labstack/echo/pull/2102), [#2102](https://github.com/labstack/echo/pull/2102) **General** * Add cache-control and connection headers [#2103](https://github.com/labstack/echo/pull/2103) * Add Retry-After header constant [#2078](https://github.com/labstack/echo/pull/2078) * Upgrade `go` directive in `go.mod` to 1.17 [#2049](https://github.com/labstack/echo/pull/2049) * Add Pagoda [#2077](https://github.com/labstack/echo/pull/2077) and Souin [#2069](https://github.com/labstack/echo/pull/2069) to 3rd-party middlewares in README ## v4.6.3 - 2022-01-10 **Fixes** * Fixed Echo version number in greeting message which was not incremented to `4.6.2` [#2066](https://github.com/labstack/echo/issues/2066) ## v4.6.2 - 2022-01-08 **Fixes** * Fixed route containing escaped colon should be matchable but is not matched to request path [#2047](https://github.com/labstack/echo/pull/2047) * Fixed a problem that returned wrong 
content-encoding when the gzip compressed content was empty. [#1921](https://github.com/labstack/echo/pull/1921) * Update (test) dependencies [#2021](https://github.com/labstack/echo/pull/2021) **Enhancements** * Add support for configurable target header for the request_id middleware [#2040](https://github.com/labstack/echo/pull/2040) * Change decompress middleware to use stream decompression instead of buffering [#2018](https://github.com/labstack/echo/pull/2018) * Documentation updates ## v4.6.1 - 2021-09-26 **Enhancements** * Add start time to request logger middleware values [#1991](https://github.com/labstack/echo/pull/1991) ## v4.6.0 - 2021-09-20 Introduced a new [request logger](https://github.com/labstack/echo/blob/master/middleware/request_logger.go) middleware to help with cases when you want to use some other logging library in your application. **Fixes** * fix timeout middleware warning: superfluous response.WriteHeader [#1905](https://github.com/labstack/echo/issues/1905) **Enhancements** * Add Cookie to KeyAuth middleware's KeyLookup [#1929](https://github.com/labstack/echo/pull/1929) * JWT middleware should ignore case of auth scheme in request header [#1951](https://github.com/labstack/echo/pull/1951) * Refactor default error handler to return first if response is already committed [#1956](https://github.com/labstack/echo/pull/1956) * Added request logger middleware which helps to use custom logger library for logging requests. [#1980](https://github.com/labstack/echo/pull/1980) * Allow escaping of colon in route path so Google Cloud API "custom methods" could be implemented [#1988](https://github.com/labstack/echo/pull/1988) ## v4.5.0 - 2021-08-01 **Important notes** A **BREAKING CHANGE** is introduced for JWT middleware users. 
The JWT library used for the JWT middleware had to be changed from [github.com/dgrijalva/jwt-go](https://github.com/dgrijalva/jwt-go) to [github.com/golang-jwt/jwt](https://github.com/golang-jwt/jwt) due to the former library being unmaintained and affected by security issues. The [github.com/golang-jwt/jwt](https://github.com/golang-jwt/jwt) project is a drop-in replacement, but supports only the latest 2 Go versions. So for JWT middleware users Go 1.15+ is required. For detailed information please read [#1940](https://github.com/labstack/echo/discussions/1940). To change the library imports in all .go files in your project replace all occurrences of `dgrijalva/jwt-go` with `golang-jwt/jwt`. For Linux CLI you can use: ```bash find . -type f -name "*.go" -exec sed -i "s/dgrijalva\/jwt-go/golang-jwt\/jwt/g" {} \; go mod tidy ``` **Fixes** * Change JWT library to `github.com/golang-jwt/jwt` [#1946](https://github.com/labstack/echo/pull/1946) ## v4.4.0 - 2021-07-12 **Fixes** * Split HeaderXForwardedFor header only by comma [#1878](https://github.com/labstack/echo/pull/1878) * Fix Timeout middleware Context propagation [#1910](https://github.com/labstack/echo/pull/1910) **Enhancements** * Bind data using headers as source [#1866](https://github.com/labstack/echo/pull/1866) * Adds JWTConfig.ParseTokenFunc to JWT middleware to allow different libraries implementing JWT parsing. [#1887](https://github.com/labstack/echo/pull/1887) * Adding tests for Echo#Host [#1895](https://github.com/labstack/echo/pull/1895) * Adds RequestIDHandler function to RequestID middleware [#1898](https://github.com/labstack/echo/pull/1898) * Allow for custom JSON encoding implementations [#1880](https://github.com/labstack/echo/pull/1880) ## v4.3.0 - 2021-05-08 **Important notes** * Route matching has improvements for the following cases: 1. Correctly match routes with parameter part as last part of route (with trailing slash) 2.
Considering handlers when resolving routes and search for matching http method handler * Echo's minimum Go version is now 1.13. **Fixes** * When url ends with slash first param route is the match [#1804](https://github.com/labstack/echo/pull/1812) * Router should check if node is suitable as matching route by path+method and if not then continue search in tree [#1808](https://github.com/labstack/echo/issues/1808) * Fix timeout middleware not writing response correctly when handler panics [#1864](https://github.com/labstack/echo/pull/1864) * Fix binder not working with embedded pointer structs [#1861](https://github.com/labstack/echo/pull/1861) * Add Go 1.16 to CI and drop 1.12 specific code [#1850](https://github.com/labstack/echo/pull/1850) **Enhancements** * Make KeyFunc public in JWT middleware [#1756](https://github.com/labstack/echo/pull/1756) * Add support for optional filesystem to the static middleware [#1797](https://github.com/labstack/echo/pull/1797) * Add a custom error handler to key-auth middleware [#1847](https://github.com/labstack/echo/pull/1847) * Allow JWT token to be looked up from multiple sources [#1845](https://github.com/labstack/echo/pull/1845) ## v4.2.2 - 2021-04-07 **Fixes** * Allow proxy middleware to use query part in rewrite (#1802) * Fix timeout middleware not sending status code when handler returns an error (#1805) * Fix Bind() when target is array/slice and path/query params complains bind target not being struct (#1835) * Fix panic in redirect middleware on short host name (#1813) * Fix timeout middleware docs (#1836) ## v4.2.1 - 2021-03-08 **Important notes** Due to a data race, the config parameters for the newly added timeout middleware required a change. See the [docs](https://echo.labstack.com/middleware/timeout). A performance regression has been fixed, even bringing better performance than before for some routing scenarios.
**Fixes** * Fix performance regression caused by path escaping (#1777, #1798, #1799, aldas) * Avoid context canceled errors (#1789, clwluvw) * Improve router to use on-stack backtracking (#1791, aldas, stffabi) * Fix panic in timeout middleware not being recovered and causing application crash (#1794, aldas) * Fix Echo.Serve() not serving on HTTP port correctly when TLSListener is used (#1785, #1793, aldas) * Apply go fmt (#1788, Le0tk0k) * Use strings.EqualFold (#1790, rkilingr) * Improve code quality (#1792, withshubh) This release was made possible by our **contributors**: aldas, clwluvw, lammel, Le0tk0k, maciej-jezierski, rkilingr, stffabi, withshubh ## v4.2.0 - 2021-02-11 **Important notes** The behaviour for binding data has been reworked for compatibility with echo before v4.1.11 by enforcing `explicit tagging` for processing parameters. This **may break** your code if you expect combined handling of query/path/form params. Please see the updated documentation for [request](https://echo.labstack.com/guide/request) and [binding](https://echo.labstack.com/guide/binding). The handling for rewrite rules has been slightly adjusted to expand `*` to a non-greedy `(.*?)` capture group. This is only relevant if multiple asterisks are used in your rules. Please see [rewrite](https://echo.labstack.com/middleware/rewrite) and [proxy](https://echo.labstack.com/middleware/proxy) for details.
**Security** * Fix directory traversal vulnerability for Windows (#1718, little-cui) * Fix open redirect vulnerability with trailing slash (#1771,#1775 aldas,GeoffreyFrogeye) **Enhancements** * Add Echo#ListenerNetwork as configuration (#1667, pafuent) * Add ability to change the status code using response beforeFuncs (#1706, RashadAnsari) * Echo server startup to allow data race free access to listener address * Binder: Restore pre v4.1.11 behaviour for c.Bind() to use query params only for GET or DELETE methods (#1727, aldas) * Binder: Add separate methods to bind only query params, path params or request body (#1681, aldas) * Binder: New fluent binder for query/path/form parameter binding (#1717, #1736, aldas) * Router: Performance improvements for missed routes (#1689, pafuent) * Router: Improve performance for Real-IP detection using IndexByte instead of Split (#1640, imxyb) * Middleware: Support real regex rules for rewrite and proxy middleware (#1767) * Middleware: New rate limiting middleware (#1724, iambenkay) * Middleware: New timeout middleware implementation for go1.13+ (#1743, ) * Middleware: Allow regex pattern for CORS middleware (#1623, KlotzAndrew) * Middleware: Add IgnoreBase parameter to static middleware (#1701, lnenad, iambenkay) * Middleware: Add an optional custom function to CORS middleware to validate origin (#1651, curvegrid) * Middleware: Support form fields in JWT middleware (#1704, rkfg) * Middleware: Use sync.Pool for (de)compress middleware to improve performance (#1699, #1672, pafuent) * Middleware: Add decompress middleware to support gzip compressed requests (#1687, arun0009) * Middleware: Add ErrJWTInvalid for JWT middleware (#1627, juanbelieni) * Middleware: Add SameSite mode for CSRF cookies to support iframes (#1524, pr0head) **Fixes** * Fix handling of special trailing slash case for partial prefix (#1741, stffabi) * Fix handling of static routes with trailing slash (#1747) * Fix Static files route not working (#1671, 
pwli0755, lammel) * Fix use of caret (^) in regex for rewrite middleware (#1588, chotow) * Fix Echo#Reverse for Any type routes (#1695, pafuent) * Fix Router#Find panic with infinite loop (#1661, pafuent) * Fix Router#Find panic fails on Param paths (#1659, pafuent) * Fix DefaultHTTPErrorHandler with Debug=true (#1477, lammel) * Fix incorrect CORS headers (#1669, ulasakdeniz) * Fix proxy middleware rewritePath to use url with updated tests (#1630, arun0009) * Fix rewritePath for proxy middleware to use escaped path in (#1628, arun0009) * Remove useless defer (#1656, imxyb) **General** * New maintainers for Echo: Roland Lammel (@lammel) and Pablo Andres Fuente (@pafuent) * Add GitHub action to compare benchmarks (#1702, pafuent) * Binding query/path params and form fields to struct only works for explicit tags (#1729,#1734, aldas) * Add support for Go 1.15 in CI (#1683, asahasrabuddhe) * Add test for request id to remain unchanged if provided (#1719, iambenkay) * Refactor echo instance listener access and startup to speed up testing (#1735, aldas) * Refactor and improve various tests for binding and routing * Run test workflow only for relevant changes (#1637, #1636, pofl) * Update .travis.yml (#1662, santosh653) * Update README.md with a recent framework benchmark (#1679, pafuent) This release was made possible by **over 100 commits** from more than **20 contributors**: asahasrabuddhe, aldas, AndrewKlotz, arun0009, chotow, curvegrid, iambenkay, imxyb, juanbelieni, lammel, little-cui, lnenad, pafuent, pofl, pr0head, pwli, RashadAnsari, rkfg, santosh653, segfiner, stffabi, ulasakdeniz ================================================ FILE: CLAUDE.md ================================================ # CLAUDE.md This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. ## About This Project Echo is a high performance, minimalist Go web framework.
This is the main repository for Echo v5, which is available as a Go module at `github.com/labstack/echo/v5` (Echo v4 remains available at `github.com/labstack/echo/v4`). ## Development Commands The project uses a Makefile for common development tasks: - `make check` - Run linting, vetting, and race condition tests (default target) - `make init` - Install required linting tools (golint, staticcheck) - `make lint` - Run staticcheck and golint - `make vet` - Run go vet - `make test` - Run short tests - `make race` - Run tests with race detector - `make benchmark` - Run benchmarks Example commands for development: ```bash # Setup development environment make init # Run all checks (lint, vet, race) make check # Run specific tests go test ./middleware/... go test -race ./... # Run benchmarks make benchmark ``` ## Code Architecture ### Core Components **Echo Instance (`echo.go`)** - The `Echo` struct is the top-level framework instance - Contains router, middleware stacks, and server configuration - Not goroutine-safe for mutations after server start **Context (`context.go`)** - The `Context` type (handlers receive it as `*echo.Context`) represents HTTP request/response context - Provides methods for request/response handling, path parameters, data binding - Core abstraction for request processing **Router (`router.go`)** - Radix tree-based HTTP router with smart route prioritization - Supports static routes, parameterized routes (`/users/:id`), and wildcard routes (`/static/*`) - Each HTTP method has its own routing tree **Middleware (`middleware/`)** - Around 20 built-in middlewares - Middleware can be applied at Echo, Group, or individual route level - Common middleware: RequestLogger, Recover, CORS, CSRF, Rate Limiter, etc. (JWT middleware lives in the separate `github.com/labstack/echo-jwt` repository)
### Key Patterns **Middleware Chain** - Pre-middleware runs before routing - Regular middleware runs after routing but before handlers - Middleware functions have signature `func(next echo.HandlerFunc) echo.HandlerFunc` **Route Groups** - Routes can be grouped with common prefixes and middleware - Groups support nested sub-groups - Defined in `group.go` **Data Binding** - Automatic binding of request data (JSON, XML, form) to Go structs - Implemented in `bind.go`, `binder.go`, and `binder_generic.go` with support for custom binders **Error Handling** - Centralized error handling via `HTTPErrorHandler` - Panic recovery (with stack traces) is opt-in via the `middleware.Recover()` middleware ## File Organization - Root directory: Core Echo functionality (echo.go, context.go, router.go, etc.) - `middleware/`: All built-in middleware implementations - `echotest/`: Helpers for writing tests against Echo - `_fixture/`: Test data files ## Code Style - Go code uses tabs for indentation (per .editorconfig) - Follows standard Go conventions and formatting - Uses gofmt, golint, and staticcheck for code quality ## Testing - Standard Go testing with `testing` package - Tests include unit tests, integration tests, and benchmarks - Race condition testing is required (`make race`) - Test files follow `*_test.go` naming convention ================================================ FILE: LICENSE ================================================ The MIT License (MIT) Copyright (c) 2022 LabStack Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: Makefile ================================================ PKG := "github.com/labstack/echo" PKG_LIST := $(shell go list ${PKG}/...) .DEFAULT_GOAL := check check: lint vet race ## Check project init: @go install golang.org/x/lint/golint@latest @go install honnef.co/go/tools/cmd/staticcheck@latest lint: ## Lint the files @staticcheck ${PKG_LIST} @golint -set_exit_status ${PKG_LIST} vet: ## Vet the files @go vet ${PKG_LIST} test: ## Run tests @go test -short ${PKG_LIST} race: ## Run tests with data race detector @go test -race ${PKG_LIST} benchmark: ## Run benchmarks @go test -run="-" -benchmem -bench=".*" ${PKG_LIST} help: ## Display this help screen @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' goversion ?= "1.25" test_version: ## Run tests inside Docker with given version (defaults to 1.25 oldest supported). 
Example: make test_version goversion=1.25 @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" ================================================ FILE: README.md ================================================ [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge) [![GoDoc](https://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v5) [![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo) [![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/labstack/echo/echo.yml?style=flat-square)](https://github.com/labstack/echo/actions) [![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo) [![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://github.com/labstack/echo/discussions) [![Twitter](https://img.shields.io/badge/twitter-@labstack-55acee.svg?style=flat-square)](https://twitter.com/labstack) [![License](https://img.shields.io/badge/license-mit-blue.svg?style=flat-square)](https://raw.githubusercontent.com/labstack/echo/master/LICENSE) ## Echo High performance, extensible, minimalist Go web framework.
* [Official website](https://echo.labstack.com) * [Quick start](https://echo.labstack.com/docs/quick-start) * [Middlewares](https://echo.labstack.com/docs/category/middleware) Help and questions: [GitHub Discussions](https://github.com/labstack/echo/discussions) ### Feature Overview - Optimized HTTP router which smartly prioritizes routes - Build robust and scalable RESTful APIs - Group APIs - Extensible middleware framework - Define middleware at root, group or route level - Data binding for JSON, XML and form payloads - Handy functions to send a variety of HTTP responses - Centralized HTTP error handling - Template rendering with any template engine - Define your format for the logger - Highly customizable - Automatic TLS via Let’s Encrypt - HTTP/2 support ## Sponsors
encore icon Encore – the platform for building Go-based cloud backends

Click [here](https://github.com/sponsors/labstack) for more information on sponsorship. ## [Guide](https://echo.labstack.com/guide) ### Supported Echo versions - Latest major version of Echo is `v5` as of 2026-01-18. - Until 2026-03-31, any critical issues requiring breaking API changes will be addressed, even if this violates semantic versioning. - See [API_CHANGES_V5.md](./API_CHANGES_V5.md) for public API changes between `v4` and `v5` and notes on upgrading. - If you are using Echo in a production environment, it is recommended to wait until after 2026-03-31 before upgrading. - Echo `v4` is supported with **security** updates and **bug** fixes until **2026-12-31** ### Installation ```sh # go get github.com/labstack/echo/{version} go get github.com/labstack/echo/v5 ``` The latest version of Echo supports the last four major Go [releases](https://go.dev/doc/devel/release) and might work with older versions. ### Example ```go package main import ( "github.com/labstack/echo/v5" "github.com/labstack/echo/v5/middleware" "log/slog" "net/http" ) func main() { // Echo instance e := echo.New() // Middleware e.Use(middleware.RequestLogger()) // use the RequestLogger middleware with slog logger e.Use(middleware.Recover()) // recover panics as errors for proper error handling // Routes e.GET("/", hello) // Start server if err := e.Start(":8080"); err != nil { slog.Error("failed to start server", "error", err) } } // Handler func hello(c *echo.Context) error { return c.String(http.StatusOK, "Hello, World!") } ``` # Official middleware repositories The following middleware repositories are maintained by the Echo team.
| Repository | Description | |------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------| | [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware | | [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [pprof](https://pkg.go.dev/net/http/pprof) middlewares | | [github.com/labstack/echo-opentelemetry](https://github.com/labstack/echo-opentelemetry) | [OpenTelemetry](https://opentelemetry.io/) middleware for tracing and metrics | | [github.com/labstack/echo-prometheus](https://github.com/labstack/echo-prometheus) | [Prometheus](https://github.com/prometheus/client_golang/) middleware for Echo | # Third-party middleware repositories Be careful when adding third-party middleware. The Echo team does not have the time or manpower to guarantee the safety and quality of the middleware in this list. | Repository | Description | |------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | [oapi-codegen/oapi-codegen](https://github.com/oapi-codegen/oapi-codegen) | Generate Go client and server code from [OpenAPI](https://swagger.io/specification/) specifications | | [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0.
| | [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | | [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber's [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | | [github.com/samber/slog-echo](https://github.com/samber/slog-echo) | Go [slog](https://pkg.go.dev/log/slog) logging library wrapper for Echo logger interface. | | [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending on your needs. | | [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. | | [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | Generate Echo server-side code from Protocol Buffers definitions | Please send a PR to add your own library here. ## Contribute **Use issues for everything** - For a small change, just send a PR. - For bigger changes open an issue for discussion before sending a PR.
- PR should have: - Test case - Documentation - Example (If it makes sense) - You can also contribute by: - Reporting issues - Suggesting new features or enhancements - Improve/fix documentation ## Credits - [Vishal Rana](https://github.com/vishr) (Author) - [Nitin Rana](https://github.com/nr17) (Consultant) - [Roland Lammel](https://github.com/lammel) (Maintainer) - [Martti T.](https://github.com/aldas) (Maintainer) - [Pablo Andres Fuente](https://github.com/pafuent) (Maintainer) - [Contributors](https://github.com/labstack/echo/graphs/contributors) ## License [MIT](https://github.com/labstack/echo/blob/master/LICENSE) ================================================ FILE: SECURITY.md ================================================ # Security Policy ## Supported Versions | Version | Supported | |-----------|-------------------------------------| | 5.x.x | :white_check_mark: | | >= 4.15.x | :white_check_mark: until 2026.12.31 | | < 4.15 | :x: | ## Reporting a Vulnerability https://github.com/labstack/echo/security/advisories/new or look for maintainers email(s) in commits and email them. 
================================================ FILE: _fixture/_fixture/README.md ================================================ This directory is used for the static middleware test ================================================ FILE: _fixture/certs/README.md ================================================ To generate a valid certificate and private key use the following command: ```bash # In OpenSSL ≥ 1.1.1 openssl req -x509 -newkey rsa:4096 -sha256 -days 9999 -nodes \ -keyout key.pem -out cert.pem -subj "/CN=localhost" \ -addext "subjectAltName=DNS:localhost,IP:127.0.0.1,IP:::1" ``` To check a certificate use the following command: ```bash openssl x509 -in cert.pem -text ``` ================================================ FILE: _fixture/certs/cert.pem ================================================ -----BEGIN CERTIFICATE----- MIIFODCCAyCgAwIBAgIUaTvDluaMf+VJgYHQ0HFTS3yuCHYwDQYJKoZIhvcNAQEL BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIxMDIyNzIxMzQ0MVoXDTQ4MDcx NDIxMzQ0MVowFDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEF AAOCAg8AMIICCgKCAgEAnqyyAAnWFH2TH7Epj5yfZxYrBvizydZe1Wo/1WpGR2IK QT+qIul5sEKX/ERqEOXsawSrL3fw9cuSM8Z2vD/57ZZdoSR7XIdVaMDEQenJ968a HObu4D27uBQwIwrM5ELgnd+fC4gis64nIu+2GSfHumZXi7lLW7DbNm8oWkMqI6tY 2s2wx2hwGYNVJrwSn4WGnkzhQ5U5mkcsLELMx7GR0Qnv6P7sNGZVeqMU7awkcSpR crKR1OUP7XCJkEq83WLHSx50+QZv7LiyDmGnujHevRbdSHlcFfHZtaufYat+qICe S3XADwRQe/0VSsmja6u3DAHy7VmL8PNisAdkopQZrhiI9OvGrpGZffs9zn+s/jeX N1bqVDihCMiEjqXMlHx2oj3AXrZTFxb7y7Ap9C07nf70lpxQWW9SjMYRF98JBiHF eJbQkNVkmz6T8ielQbX0l46F2SGK98oyFCGNIAZBUdj5CcS1E6w/lk4t58/em0k7 3wFC5qg0g0wfIbNSmxljBNxnaBYUqyaaAJJhpaEoOebm4RYV58hQ0FbMfpnLnSh4 dYStsk6i1PumWoa7D45DTtxF3kH7TB3YOB5aWaNGAPQC1m4Qcd23YB5Rd/ABirSp ux6/cFGosjSfJ/G+G0RhNUpmcbDJvFSOhD2WCuieVhCTAzp+VPIA9bSqD+InlT0C AwEAAaOBgTB/MB0GA1UdDgQWBBQZyM//SvzYKokQZI/0MVGb6PkH+zAfBgNVHSME GDAWgBQZyM//SvzYKokQZI/0MVGb6PkH+zAPBgNVHRMBAf8EBTADAQH/MCwGA1Ud EQQlMCOCCWxvY2FsaG9zdIcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG 
9w0BAQsFAAOCAgEAKGAJQmQ/KLw8iMb5QsyxxAonVjJ1eDAhNM3GWdHpM0/GFamO vVtATLQQldwDiZJvrsCQPEc8ctZ2Utvg/StLQ3+rZpsvt0+gcUlLJK61qguwYqb2 +T7VK5s7V/OyI/tsuboOW50Pka9vQHV+Z0aM06Yu+HNDAq/UTpEOb/3MQvZd6Ooy PTpZtFb/+5jIQa1dIsfFWmpBxF0+wUd9GEkX3j7nekwoZfJ8Ze4GWYERZbOFpDAQ rIHdthH5VJztnpQJmaKqzgIOF+Rurwlp5ecSC33xNNjDaYtuf/fiWnoKGhHVSBhT 61+0yxn3rTgh/Dsm95xY00rSX6lmcvI+kRNTUc8GGPz0ajBH6xyY7bNhfMjmnSW/ C/XTEDbTAhT7ndWC5vvzp7ZU0TvN+WY6A0f2kxSnnrEk6QRUvRtKkjAkmAFz8exi ttBBW0I3E5HNIC5CYRimq/9z+3clM/P1KbNblwuC65bL+PZ+nzFnn5hFaK9eLPol OwZQXv7IvAw8GfgLTrEUT7eBCQwe1IqesA7NTxF1BVwmNUb2XamvQZ7ly67QybRw 0uJq80XjpVjBWYTTQy1dsnC2OTKdqGsV9TVIDR+UGfIG9cxL70pEbiSH2AX+IDCy i3kNIvpXgBliAyOjW6Hj1fv6dNfAat/hqEfnquWkfvcs3HNrG/InwpwNAUs= -----END CERTIFICATE----- ================================================ FILE: _fixture/certs/key.pem ================================================ -----BEGIN PRIVATE KEY----- MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQCerLIACdYUfZMf sSmPnJ9nFisG+LPJ1l7Vaj/VakZHYgpBP6oi6XmwQpf8RGoQ5exrBKsvd/D1y5Iz xna8P/ntll2hJHtch1VowMRB6cn3rxoc5u7gPbu4FDAjCszkQuCd358LiCKzrici 77YZJ8e6ZleLuUtbsNs2byhaQyojq1jazbDHaHAZg1UmvBKfhYaeTOFDlTmaRyws QszHsZHRCe/o/uw0ZlV6oxTtrCRxKlFyspHU5Q/tcImQSrzdYsdLHnT5Bm/suLIO Yae6Md69Ft1IeVwV8dm1q59hq36ogJ5LdcAPBFB7/RVKyaNrq7cMAfLtWYvw82Kw B2SilBmuGIj068aukZl9+z3Of6z+N5c3VupUOKEIyISOpcyUfHaiPcBetlMXFvvL sCn0LTud/vSWnFBZb1KMxhEX3wkGIcV4ltCQ1WSbPpPyJ6VBtfSXjoXZIYr3yjIU IY0gBkFR2PkJxLUTrD+WTi3nz96bSTvfAULmqDSDTB8hs1KbGWME3GdoFhSrJpoA kmGloSg55ubhFhXnyFDQVsx+mcudKHh1hK2yTqLU+6ZahrsPjkNO3EXeQftMHdg4 HlpZo0YA9ALWbhBx3bdgHlF38AGKtKm7Hr9wUaiyNJ8n8b4bRGE1SmZxsMm8VI6E PZYK6J5WEJMDOn5U8gD1tKoP4ieVPQIDAQABAoICAEHF2CsH6MOpofi7GT08cR7s I33KTcxWngzc9ATk/qjMTO/rEf1Sxmx3zkR1n3nNtQhPcR5GG43nin0HwWQbKOCB OeJ4GuKp/o9jiHbCEEQpQyvD1jUBofSV+bYs3e2ogy8t6OGA1tGgWPy0XMlkoff0 QEnczw3864FO5m0z9h2/Ax//r02ZTw5kUEG0KAwT709jEuVO0AfRhM/8CKKmSola EyaDtSmrWbdyLlSuzJRUNFrVBno3UTjdM0iqkks6jN3ojBhFwNNhY/1uIXafAXNk LOnD1JYMIHCb6X809VWnqvYgozIWWb5rlA3iM2mITmId1LLqMYX5fWj2R5LUzSek 
H+XG+F9FIouTaL1ACoXr0zyeY5N5YJdyXYa1tThdW+axX9ZrnPgeiQrmxzKPIyb7 LLlVtNBQUg/t5tX80KyYjkNUu4j3oq/uBYPi0m//ovwMyi9bSbbyPT+cDXuXX5Bc oY7wyn3evXX0c1R7vdJLZLkLu+ctVex/9hvMjeW/mMasDjLnqY7pF3Skct1SX5N2 U8YVU9bGvFpLEwM9lmi/T7bcv+zbmGPlfTsZiFrCsixPLn7sX7y5M4L8au8O0jh0 nHm/8rWVg1Qw0Hobg3tA8FjeMa8Sr2fYmkNLVKFzhuJLxknTJLaUbX5CymNqWP4H OctvfSY0nSZ1eQpBkQaJAoIBAQDTb/NhYCfaJBLXHVMy/VYd7kWGZ+I87artcE/l 8u0pJ8XOP4kp0otFIumpHUFodysAeP6HrI79MuJB40fy91HzWZC+NrPufFFFuZ0z Ld1o3Y5nAeoZmMlf1F12Oe3OQZy7nm9eNNkfeoVtKqDv4FhAqk+aoMor86HscKsR C6HlZFdGc7kX0ylrQAXPq9KLhcvUU9oAUpbqTbhYK83IebRJgFDG45HkVo9SUHpF dmCFSb91eZpRGpdfNLCuLiSu52TebayaUCnceeAt8SyeiChJ/TwWmRRDJS0QUv6h s3Wdp+cx9ANoujA4XzAs8Fld5IZ4bcG5jjwD62/tJyWrCC5DAoIBAQDAHfHjrYCK GHBrMj+MA7cK7fCJUn/iJLSLGgo2ANYF5oq9gaCwHCtKIyB9DN/KiY0JpJ6PWg+Q 9Difq23YXiJjNEBS5EFTu9UwWAr1RhSAegrfHxm0sDbcAx31NtDYvBsADCWQYmzc KPfBshf5K4g/VCIj2VzC2CE6kNtdhqLU6AV2Pi1Tl1S82xWoAjHy91tDmlFQNWCj B2ZnZ7tY9zuwDfeBBOVCPHICgl5Q4PrY1KEWEXiNxgbtkNmOPAsY9WSqgOsP9pWK J924gdCCvovINzZtgRisxKth6Fkhra+VCsheg9SWvgR09Deo6CCoSwYxOSb0cjh2 oyX5Rb1kJ7Z/AoIBAQCX2iNVoBV/GcFeNXV3fXLH9ESCj0FwuNC1zp/TanDhyerK gd8k5k2Xzcc66gP73vpHUJ6dGlVni4/r+ivGV9HHkF/f/LGlaiuEhBZel2YY1mZb nIhg8dZOuNqW+mvMYlsKdHNPmW0GqpwBF0iWfu1jI+4gA7Kvdj6o7RIvH8eaVEJK GvqoHcP1fvmteJ2yDtmhGMfMy4QPqtnmmS8l+CJ/V2SsMuyorXIpkBsAoFAZ6ilT WY53CT4F5nWt4v39j7pl9SatfT1TV0SmOjvtb6Rf3zu0jyR6RMzkmHa/839ZRylI OxPntzDCi7qxy7yjLmlVPJ6RgZGgzwqHrEHlX+65AoIBAQCEzu6d3x5B2N02LZli eFr8MjqbI64GLiulEY5HgNJzZ8k3cjocJI0Ehj36VIEMaYRXSzbVkIO8SCgwsPiR n5mUDNX+t441jV62Odbxcc3Qdw226rABieOSupDmKEu92GOt57e8FV5939BOVYhf FunsJYQoViXbCEAIVYVgJSfBmNfVwuvgonfQyn8xErtm4/pyRGa71PqGGSKAj2Qi /16CuVUFGtZFsLV76JW8wZqHdI4bTF6TW3cEmaLbwcRGL7W0bMSS13rO8/pBh3QW PhUxhoGYt6rQHHEBkPa04nXDyZ10QRwgTSGVnBIyMK4KyTpxorm8OI2x7dzdcomX iCCPAoIBAETwfr2JKPb/AzrKhhbZgU+sLVn3WH/nb68VheNEmGOzsqXaSHCR2NOq /ow7bawjc8yUIhBRzokR4F/7jGolOmfdq0MYFb6/YokssKfv1ugxBhmvOxpZ6F6E cERJ8Ex/ffQU053gLR/0ammddVuS1GR5I/jEdP0lJVh0xapoZNUlT5dWYCgo20hY 
ZAmKpU+veyUn+5Li0pmm959vnLK5LJzEA5mpz3w1QPPtVwQs05dwmEV3CRAcCeeh 8sXp49WNCSW4I3BxuTZzRV845SGIFhZwgVV42PTp2LPKl2p6E7Bk8xpUCCvBpALp QmA5yIMx+u2Jpr7fUsXEXEPTEhvjff0= -----END PRIVATE KEY----- ================================================ FILE: _fixture/dist/private.txt ================================================ private file ================================================ FILE: _fixture/dist/public/assets/readme.md ================================================ readme in assets ================================================ FILE: _fixture/dist/public/assets/subfolder/subfolder.md ================================================ file inside subfolder ================================================ FILE: _fixture/dist/public/index.html ================================================

Hello from index

================================================
FILE: _fixture/dist/public/test.txt
================================================
test.txt contents

================================================
FILE: _fixture/folder/index.html
================================================
Echo

================================================
FILE: _fixture/index.html
================================================
Echo

================================================
FILE: bind.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors

package echo

import (
	"encoding"
	"encoding/xml"
	"errors"
	"mime/multipart"
	"net/http"
	"reflect"
	"strconv"
	"strings"
	"time"
)

// Binder is the interface that wraps the Bind method.
//
// Bind binds request data available through c into target.
type Binder interface {
	Bind(c *Context, target any) error
}

// DefaultBinder is the default implementation of the Binder interface.
// Its Bind method binds path values, then query parameters (only for GET, DELETE and HEAD requests),
// then the request body.
type DefaultBinder struct{}

// BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
// Types that don't implement this, but do implement encoding.TextUnmarshaler
// will use that interface instead.
type BindUnmarshaler interface {
	// UnmarshalParam decodes and assigns a value from a form, query, path or header param.
	UnmarshalParam(param string) error
}

// bindMultipleUnmarshaler is used by binder to unmarshal multiple values from request at once to
// type implementing this interface. For example, a request could have multiple query fields `?a=1&a=2&b=test`;
// in that case, for `a`, the slice `["1", "2"]` will be passed to the unmarshaler.
type bindMultipleUnmarshaler interface { UnmarshalParams(params []string) error } // BindPathValues binds path parameter values to bindable object func BindPathValues(c *Context, target any) error { params := map[string][]string{} for _, param := range c.PathValues() { params[param.Name] = []string{param.Value} } if err := bindData(target, params, "param", nil); err != nil { return ErrBadRequest.Wrap(err) } return nil } // BindQueryParams binds query params to bindable object func BindQueryParams(c *Context, target any) error { if err := bindData(target, c.QueryParams(), "query", nil); err != nil { return ErrBadRequest.Wrap(err) } return nil } // BindBody binds request body contents to bindable object // NB: then binding forms take note that this implementation uses standard library form parsing // which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm // See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm // See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm func BindBody(c *Context, target any) (err error) { req := c.Request() if req.ContentLength == 0 { return } // mediatype is found like `mime.ParseMediaType()` does it base, _, _ := strings.Cut(req.Header.Get(HeaderContentType), ";") mediatype := strings.TrimSpace(base) switch mediatype { case MIMEApplicationJSON: if err = c.Echo().JSONSerializer.Deserialize(c, target); err != nil { var hErr *HTTPError if errors.As(err, &hErr) { return err } return ErrBadRequest.Wrap(err) } case MIMEApplicationXML, MIMETextXML: if err = xml.NewDecoder(req.Body).Decode(target); err != nil { return ErrBadRequest.Wrap(err) } case MIMEApplicationForm: params, err := c.FormValues() if err != nil { return ErrBadRequest.Wrap(err) } if err = bindData(target, params, "form", nil); err != nil { return ErrBadRequest.Wrap(err) } case MIMEMultipartForm: params, err := c.MultipartForm() if err != nil { return ErrBadRequest.Wrap(err) } if err = 
bindData(target, params.Value, "form", params.File); err != nil { return ErrBadRequest.Wrap(err) } default: return &HTTPError{Code: http.StatusUnsupportedMediaType} } return nil } // BindHeaders binds HTTP headers to a bindable object func BindHeaders(c *Context, target any) error { if err := bindData(target, c.Request().Header, "header", nil); err != nil { return ErrBadRequest.Wrap(err) } return nil } // Bind implements the `Binder#Bind` function. // Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous // step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathValues. func (b *DefaultBinder) Bind(c *Context, target any) error { if err := BindPathValues(c, target); err != nil { return err } // Only bind query parameters for GET/DELETE/HEAD to avoid unexpected behavior with destination struct binding from body. // For example a request URL `&id=1&lang=en` with body `{"id":100,"lang":"de"}` would lead to precedence issues. 
// The HTTP method check restores pre-v4.1.11 behavior to avoid these problems (see issue #1670) method := c.Request().Method if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead { if err := BindQueryParams(c, target); err != nil { return err } } return BindBody(c, target) } // bindData will bind data ONLY fields in destination struct that have EXPLICIT tag func bindData(destination any, data map[string][]string, tag string, dataFiles map[string][]*multipart.FileHeader) error { if destination == nil || (len(data) == 0 && len(dataFiles) == 0) { return nil } hasFiles := len(dataFiles) > 0 typ := reflect.TypeOf(destination).Elem() val := reflect.ValueOf(destination).Elem() // Support binding to limited Map destinations: // - map[string][]string, // - map[string]string <-- (binds first value from data slice) // - map[string]any // You are better off binding to struct but there are user who want this map feature. Source of data for these cases are: // params,query,header,form as these sources produce string values, most of the time slice of strings, actually. 
if typ.Kind() == reflect.Map && typ.Key().Kind() == reflect.String { k := typ.Elem().Kind() isElemInterface := k == reflect.Interface isElemString := k == reflect.String isElemSliceOfStrings := k == reflect.Slice && typ.Elem().Elem().Kind() == reflect.String if !(isElemSliceOfStrings || isElemString || isElemInterface) { return nil } if val.IsNil() { val.Set(reflect.MakeMap(typ)) } for k, v := range data { if isElemString { val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0])) } else if isElemInterface { // To maintain backward compatibility, we always bind to the first string value // and not the slice of strings when dealing with map[string]any{} val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0])) } else { val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v)) } } return nil } // !struct if typ.Kind() != reflect.Struct { if tag == "param" || tag == "query" || tag == "header" { // incompatible type, data is probably to be found in the body return nil } return errors.New("binding element must be a struct") } for i := 0; i < typ.NumField(); i++ { // iterate over all destination fields typeField := typ.Field(i) structField := val.Field(i) if typeField.Anonymous { if structField.Kind() == reflect.Ptr { structField = structField.Elem() } } if !structField.CanSet() { continue } structFieldKind := structField.Kind() inputFieldName := typeField.Tag.Get(tag) if typeField.Anonymous && structFieldKind == reflect.Struct && inputFieldName != "" { // if anonymous struct with query/param/form tags, report an error return errors.New("query/param/form tags are not allowed with anonymous struct field") } if inputFieldName == "" { // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contain fields with tags). 
// structs that implement BindUnmarshaler are bound only when they have explicit tag if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { if err := bindData(structField.Addr().Interface(), data, tag, dataFiles); err != nil { return err } } // does not have explicit tag and is not an ordinary struct - so move to next field continue } if hasFiles { if ok, err := isFieldMultipartFile(structField.Type()); err != nil { return err } else if ok { if ok := setMultipartFileHeaderTypes(structField, inputFieldName, dataFiles); ok { continue } } } inputValue, exists := data[inputFieldName] if !exists { // Go json.Unmarshal supports case-insensitive binding. However the // url params are bound case-sensitive which is inconsistent. To // fix this we must check all of the map values in a // case-insensitive search. for k, v := range data { if strings.EqualFold(k, inputFieldName) { inputValue = v exists = true break } } } if !exists { continue } // NOTE: algorithm here is not particularly sophisticated. It probably does not work with absurd types like `**[]*int` // but it is smart enough to handle niche cases like `*int`,`*[]string`,`[]*int` . // try unmarshalling first, in case we're dealing with an alias to an array type if ok, err := unmarshalInputsToField(typeField.Type.Kind(), inputValue, structField); ok { if err != nil { return err } continue } formatTag := typeField.Tag.Get("format") if ok, err := unmarshalInputToField(typeField.Type.Kind(), inputValue[0], structField, formatTag); ok { if err != nil { return err } continue } // we could be dealing with pointer to slice `*[]string` so dereference it. There are weird OpenAPI generators // that could create struct fields like that. 
if structFieldKind == reflect.Pointer { structFieldKind = structField.Elem().Kind() structField = structField.Elem() } if structFieldKind == reflect.Slice { sliceOf := structField.Type().Elem().Kind() numElems := len(inputValue) slice := reflect.MakeSlice(structField.Type(), numElems, numElems) for j := 0; j < numElems; j++ { if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil { return err } } structField.Set(slice) continue } if err := setWithProperType(structFieldKind, inputValue[0], structField); err != nil { return err } } return nil } func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error { // But also call it here, in case we're dealing with an array of BindUnmarshalers // Note: format tag not available in this context, so empty string is passed if ok, err := unmarshalInputToField(valueKind, val, structField, ""); ok { return err } switch valueKind { case reflect.Ptr: return setWithProperType(structField.Elem().Kind(), val, structField.Elem()) case reflect.Int: return setIntField(val, 0, structField) case reflect.Int8: return setIntField(val, 8, structField) case reflect.Int16: return setIntField(val, 16, structField) case reflect.Int32: return setIntField(val, 32, structField) case reflect.Int64: return setIntField(val, 64, structField) case reflect.Uint: return setUintField(val, 0, structField) case reflect.Uint8: return setUintField(val, 8, structField) case reflect.Uint16: return setUintField(val, 16, structField) case reflect.Uint32: return setUintField(val, 32, structField) case reflect.Uint64: return setUintField(val, 64, structField) case reflect.Bool: return setBoolField(val, structField) case reflect.Float32: return setFloatField(val, 32, structField) case reflect.Float64: return setFloatField(val, 64, structField) case reflect.String: structField.SetString(val) default: return errors.New("unknown type") } return nil } func unmarshalInputsToField(valueKind reflect.Kind, values 
[]string, field reflect.Value) (bool, error) { if valueKind == reflect.Ptr { if field.IsNil() { field.Set(reflect.New(field.Type().Elem())) } field = field.Elem() } fieldIValue := field.Addr().Interface() unmarshaler, ok := fieldIValue.(bindMultipleUnmarshaler) if !ok { return false, nil } return true, unmarshaler.UnmarshalParams(values) } func unmarshalInputToField(valueKind reflect.Kind, val string, field reflect.Value, formatTag string) (bool, error) { if valueKind == reflect.Ptr { if field.IsNil() { field.Set(reflect.New(field.Type().Elem())) } field = field.Elem() } fieldIValue := field.Addr().Interface() // Handle time.Time with custom format tag if formatTag != "" { if _, isTime := fieldIValue.(*time.Time); isTime { t, err := time.Parse(formatTag, val) if err != nil { return true, err } field.Set(reflect.ValueOf(t)) return true, nil } } switch unmarshaler := fieldIValue.(type) { case BindUnmarshaler: return true, unmarshaler.UnmarshalParam(val) case encoding.TextUnmarshaler: return true, unmarshaler.UnmarshalText([]byte(val)) } return false, nil } func setIntField(value string, bitSize int, field reflect.Value) error { if value == "" { value = "0" } intVal, err := strconv.ParseInt(value, 10, bitSize) if err == nil { field.SetInt(intVal) } return err } func setUintField(value string, bitSize int, field reflect.Value) error { if value == "" { value = "0" } uintVal, err := strconv.ParseUint(value, 10, bitSize) if err == nil { field.SetUint(uintVal) } return err } func setBoolField(value string, field reflect.Value) error { if value == "" { value = "false" } boolVal, err := strconv.ParseBool(value) if err == nil { field.SetBool(boolVal) } return err } func setFloatField(value string, bitSize int, field reflect.Value) error { if value == "" { value = "0.0" } floatVal, err := strconv.ParseFloat(value, bitSize) if err == nil { field.SetFloat(floatVal) } return err } var ( // NOT supported by bind as you can NOT check easily empty struct being actual file or not 
multipartFileHeaderType = reflect.TypeFor[multipart.FileHeader]() // supported by bind as you can check by nil value if file existed or not multipartFileHeaderPointerType = reflect.TypeFor[*multipart.FileHeader]() multipartFileHeaderSliceType = reflect.TypeFor[[]multipart.FileHeader]() multipartFileHeaderPointerSliceType = reflect.TypeFor[[]*multipart.FileHeader]() ) func isFieldMultipartFile(field reflect.Type) (bool, error) { switch field { case multipartFileHeaderPointerType, multipartFileHeaderSliceType, multipartFileHeaderPointerSliceType: return true, nil case multipartFileHeaderType: return true, errors.New("binding to multipart.FileHeader struct is not supported, use pointer to struct") default: return false, nil } } func setMultipartFileHeaderTypes(structField reflect.Value, inputFieldName string, files map[string][]*multipart.FileHeader) bool { fileHeaders := files[inputFieldName] if len(fileHeaders) == 0 { return false } result := true switch structField.Type() { case multipartFileHeaderPointerSliceType: structField.Set(reflect.ValueOf(fileHeaders)) case multipartFileHeaderSliceType: headers := make([]multipart.FileHeader, len(fileHeaders)) for i, fileHeader := range fileHeaders { headers[i] = *fileHeader } structField.Set(reflect.ValueOf(headers)) case multipartFileHeaderPointerType: structField.Set(reflect.ValueOf(fileHeaders[0])) default: result = false } return result } ================================================ FILE: bind_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import ( "bytes" "encoding/json" "encoding/xml" "errors" "fmt" "io" "mime/multipart" "net/http" "net/http/httptest" "net/http/httputil" "net/url" "reflect" "strconv" "strings" "testing" "time" "github.com/stretchr/testify/assert" ) type bindTestStruct struct { T Timestamp GoT time.Time PtrI16 *int16 PtrUI *uint Tptr *Timestamp PtrF32 *float32 PtrB *bool 
PtrI32 *int32 GoTptr *time.Time PtrI64 *int64 PtrI *int PtrI8 *int8 PtrF64 *float64 PtrUI8 *uint8 PtrUI64 *uint64 PtrUI16 *uint16 PtrS *string PtrUI32 *uint32 S string cantSet string DoesntExist string SA StringArray F64 float64 I int UI64 uint64 UI uint I64 int64 F32 float32 UI32 uint32 I32 int32 UI16 uint16 I16 int16 B bool UI8 uint8 I8 int8 } type bindTestStructWithTags struct { T Timestamp `json:"T" form:"T"` GoT time.Time `json:"GoT" form:"GoT"` PtrI16 *int16 `json:"PtrI16" form:"PtrI16"` PtrUI *uint `json:"PtrUI" form:"PtrUI"` Tptr *Timestamp `json:"Tptr" form:"Tptr"` PtrF32 *float32 `json:"PtrF32" form:"PtrF32"` PtrB *bool `json:"PtrB" form:"PtrB"` PtrI32 *int32 `json:"PtrI32" form:"PtrI32"` GoTptr *time.Time `json:"GoTptr" form:"GoTptr"` PtrI64 *int64 `json:"PtrI64" form:"PtrI64"` PtrI *int `json:"PtrI" form:"PtrI"` PtrI8 *int8 `json:"PtrI8" form:"PtrI8"` PtrF64 *float64 `json:"PtrF64" form:"PtrF64"` PtrUI8 *uint8 `json:"PtrUI8" form:"PtrUI8"` PtrUI64 *uint64 `json:"PtrUI64" form:"PtrUI64"` PtrUI16 *uint16 `json:"PtrUI16" form:"PtrUI16"` PtrS *string `json:"PtrS" form:"PtrS"` PtrUI32 *uint32 `json:"PtrUI32" form:"PtrUI32"` S string `json:"S" form:"S"` cantSet string DoesntExist string `json:"DoesntExist" form:"DoesntExist"` SA StringArray `json:"SA" form:"SA"` F64 float64 `json:"F64" form:"F64"` I int `json:"I" form:"I"` UI64 uint64 `json:"UI64" form:"UI64"` UI uint `json:"UI" form:"UI"` I64 int64 `json:"I64" form:"I64"` F32 float32 `json:"F32" form:"F32"` UI32 uint32 `json:"UI32" form:"UI32"` I32 int32 `json:"I32" form:"I32"` UI16 uint16 `json:"UI16" form:"UI16"` I16 int16 `json:"I16" form:"I16"` B bool `json:"B" form:"B"` UI8 uint8 `json:"UI8" form:"UI8"` I8 int8 `json:"I8" form:"I8"` } type Timestamp time.Time type TA []Timestamp type StringArray []string type Struct struct { Foo string } type Bar struct { Baz int `json:"baz" query:"baz"` } func (t *Timestamp) UnmarshalParam(src string) error { ts, err := time.Parse(time.RFC3339, src) *t = Timestamp(ts) 
return err } func (a *StringArray) UnmarshalParam(src string) error { *a = StringArray(strings.Split(src, ",")) return nil } func (s *Struct) UnmarshalParam(src string) error { *s = Struct{ Foo: src, } return nil } func (t bindTestStruct) GetCantSet() string { return t.cantSet } var values = map[string][]string{ "I": {"0"}, "PtrI": {"0"}, "I8": {"8"}, "PtrI8": {"8"}, "I16": {"16"}, "PtrI16": {"16"}, "I32": {"32"}, "PtrI32": {"32"}, "I64": {"64"}, "PtrI64": {"64"}, "UI": {"0"}, "PtrUI": {"0"}, "UI8": {"8"}, "PtrUI8": {"8"}, "UI16": {"16"}, "PtrUI16": {"16"}, "UI32": {"32"}, "PtrUI32": {"32"}, "UI64": {"64"}, "PtrUI64": {"64"}, "B": {"true"}, "PtrB": {"true"}, "F32": {"32.5"}, "PtrF32": {"32.5"}, "F64": {"64.5"}, "PtrF64": {"64.5"}, "S": {"test"}, "PtrS": {"test"}, "cantSet": {"test"}, "T": {"2016-12-06T19:09:05+01:00"}, "Tptr": {"2016-12-06T19:09:05+01:00"}, "GoT": {"2016-12-06T19:09:05+01:00"}, "GoTptr": {"2016-12-06T19:09:05+01:00"}, "ST": {"bar"}, } // ptr return pointer to value. This is useful as `v := []*int8{&int8(1)}` will not compile func ptr[T any](value T) *T { return &value } func TestToMultipleFields(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?id=1&ID=2", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) type Root struct { ID int64 `query:"id"` Child2 struct { ID int64 } Child1 struct { ID int64 `query:"id"` } } u := new(Root) err := c.Bind(u) if assert.NoError(t, err) { assert.Equal(t, int64(1), u.ID) // perfectly reasonable assert.Equal(t, int64(1), u.Child1.ID) // untagged struct containing tagged field gets filled (by tag) assert.Equal(t, int64(0), u.Child2.ID) // untagged struct containing untagged field should not be bind } } func TestBindJSON(t *testing.T) { testBindOkay(t, strings.NewReader(userJSON), nil, MIMEApplicationJSON) testBindOkay(t, strings.NewReader(userJSON), dummyQuery, MIMEApplicationJSON) testBindArrayOkay(t, strings.NewReader(usersJSON), nil, MIMEApplicationJSON) testBindArrayOkay(t, 
strings.NewReader(usersJSON), dummyQuery, MIMEApplicationJSON) testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) testBindError(t, strings.NewReader(userJSONInvalidType), MIMEApplicationJSON, &json.UnmarshalTypeError{}) } func TestBindXML(t *testing.T) { testBindOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML) testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) testBindArrayOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML) testBindArrayOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) testBindError(t, strings.NewReader(invalidContent), MIMEApplicationXML, errors.New("")) testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMEApplicationXML, &strconv.NumError{}) testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMEApplicationXML, &xml.SyntaxError{}) testBindOkay(t, strings.NewReader(userXML), nil, MIMETextXML) testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMETextXML) testBindError(t, strings.NewReader(invalidContent), MIMETextXML, errors.New("")) testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMETextXML, &strconv.NumError{}) testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMETextXML, &xml.SyntaxError{}) } func TestBindForm(t *testing.T) { testBindOkay(t, strings.NewReader(userForm), nil, MIMEApplicationForm) testBindOkay(t, strings.NewReader(userForm), dummyQuery, MIMEApplicationForm) e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userForm)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(HeaderContentType, MIMEApplicationForm) err := c.Bind(&[]struct{ Field string }{}) assert.Error(t, err) } func TestBindQueryParams(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) err := c.Bind(u) if assert.NoError(t, err) { 
assert.Equal(t, 1, u.ID) assert.Equal(t, "Jon Snow", u.Name) } } func TestBindQueryParamsCaseInsensitive(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?ID=1&NAME=Jon+Snow", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) err := c.Bind(u) if assert.NoError(t, err) { assert.Equal(t, 1, u.ID) assert.Equal(t, "Jon Snow", u.Name) } } func TestBindQueryParamsCaseSensitivePrioritized(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?id=1&ID=2&NAME=Jon+Snow&name=Jon+Doe", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) err := c.Bind(u) if assert.NoError(t, err) { assert.Equal(t, 1, u.ID) assert.Equal(t, "Jon Doe", u.Name) } } func TestBindHeaderParam(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Name", "Jon Doe") req.Header.Set("Id", "2") rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) err := BindHeaders(c, u) if assert.NoError(t, err) { assert.Equal(t, 2, u.ID) assert.Equal(t, "Jon Doe", u.Name) } } func TestBindHeaderParamBadType(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Id", "salamander") rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) err := BindHeaders(c, u) assert.Error(t, err) httpErr, ok := err.(*HTTPError) if assert.True(t, ok) { assert.Equal(t, http.StatusBadRequest, httpErr.Code) } } func TestBindUnmarshalParam(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { T Timestamp `query:"ts"` ST Struct StWithTag struct { Foo string `query:"st"` } TA []Timestamp `query:"ta"` SA StringArray `query:"sa"` }{} err := c.Bind(&result) ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)) if assert.NoError(t, 
err) { // assert.Equal( Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T) assert.Equal(t, ts, result.T) assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA) assert.Equal(t, []Timestamp{ts, ts}, result.TA) assert.Equal(t, Struct{""}, result.ST) // child struct does not have a field with matching tag assert.Equal(t, "baz", result.StWithTag.Foo) // child struct has field with matching tag } } func TestBindUnmarshalText(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { T time.Time `query:"ts"` ST Struct TA []time.Time `query:"ta"` SA StringArray `query:"sa"` }{} err := c.Bind(&result) ts := time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC) if assert.NoError(t, err) { // assert.Equal(t, Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T) assert.Equal(t, ts, result.T) assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA) assert.Equal(t, []time.Time{ts, ts}, result.TA) assert.Equal(t, Struct{""}, result.ST) // field in child struct does not have tag } } func TestBindUnmarshalParamPtr(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { Tptr *Timestamp `query:"ts"` }{} err := c.Bind(&result) if assert.NoError(t, err) { assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), *result.Tptr) } } func TestBindUnmarshalParamAnonymousFieldPtr(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?baz=1", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { *Bar }{&Bar{}} err := c.Bind(&result) if assert.NoError(t, err) { assert.Equal(t, 1, result.Baz) } } func 
TestBindUnmarshalParamAnonymousFieldPtrNil(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?baz=1", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { *Bar }{} err := c.Bind(&result) if assert.NoError(t, err) { assert.Nil(t, result.Bar) } } func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, `/?bar={"baz":100}&baz=1`, nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { *Bar `json:"bar" query:"bar"` }{&Bar{}} err := c.Bind(&result) assert.Contains(t, err.Error(), "query/param/form tags are not allowed with anonymous struct field") } func TestBindUnmarshalTextPtr(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { Tptr *time.Time `query:"ts"` }{} err := c.Bind(&result) if assert.NoError(t, err) { assert.Equal(t, time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC), *result.Tptr) } } func TestBindMultipartForm(t *testing.T) { bodyBuffer := new(bytes.Buffer) mw := multipart.NewWriter(bodyBuffer) mw.WriteField("id", "1") mw.WriteField("name", "Jon Snow") mw.Close() body := bodyBuffer.Bytes() testBindOkay(t, bytes.NewReader(body), nil, mw.FormDataContentType()) testBindOkay(t, bytes.NewReader(body), dummyQuery, mw.FormDataContentType()) } func TestBindUnsupportedMediaType(t *testing.T) { testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) } func TestDefaultBinder_bindDataToMap(t *testing.T) { exampleData := map[string][]string{ "multiple": {"1", "2"}, "single": {"3"}, } t.Run("ok, bind to map[string]string", func(t *testing.T) { dest := map[string]string{} assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]string{ "multiple": "1", "single": "3", }, dest, ) }) t.Run("ok, bind to map[string]string with nil map", func(t 
*testing.T) { var dest map[string]string assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]string{ "multiple": "1", "single": "3", }, dest, ) }) t.Run("ok, bind to map[string][]string", func(t *testing.T) { dest := map[string][]string{} assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]string{ "multiple": {"1", "2"}, "single": {"3"}, }, dest, ) }) t.Run("ok, bind to map[string][]string with nil map", func(t *testing.T) { var dest map[string][]string assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]string{ "multiple": {"1", "2"}, "single": {"3"}, }, dest, ) }) t.Run("ok, bind to map[string]interface", func(t *testing.T) { dest := map[string]any{} assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]any{ "multiple": "1", "single": "3", }, dest, ) }) t.Run("ok, bind to map[string]interface with nil map", func(t *testing.T) { var dest map[string]any assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]any{ "multiple": "1", "single": "3", }, dest, ) }) t.Run("ok, bind to map[string]int skips", func(t *testing.T) { dest := map[string]int{} assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]int{}, dest) }) t.Run("ok, bind to map[string]int skips with nil map", func(t *testing.T) { var dest map[string]int assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]int(nil), dest) }) t.Run("ok, bind to map[string][]int skips", func(t *testing.T) { dest := map[string][]int{} assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]int{}, dest) }) t.Run("ok, bind to map[string][]int skips with nil map", func(t *testing.T) { var dest map[string][]int assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]int(nil), dest) }) } func TestBindbindData(t 
*testing.T) { ts := new(bindTestStruct) err := bindData(ts, values, "form", nil) assert.NoError(t, err) assert.Equal(t, 0, ts.I) assert.Equal(t, int8(0), ts.I8) assert.Equal(t, int16(0), ts.I16) assert.Equal(t, int32(0), ts.I32) assert.Equal(t, int64(0), ts.I64) assert.Equal(t, uint(0), ts.UI) assert.Equal(t, uint8(0), ts.UI8) assert.Equal(t, uint16(0), ts.UI16) assert.Equal(t, uint32(0), ts.UI32) assert.Equal(t, uint64(0), ts.UI64) assert.Equal(t, false, ts.B) assert.Equal(t, float32(0), ts.F32) assert.Equal(t, float64(0), ts.F64) assert.Equal(t, "", ts.S) assert.Equal(t, "", ts.cantSet) } func TestBindParam(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) c.InitializeRoute( &RouteInfo{Path: "/users/:id/:name"}, &PathValues{ {Name: "id", Value: "1"}, {Name: "name", Value: "Jon Snow"}, }, ) u := new(user) err := c.Bind(u) if assert.NoError(t, err) { assert.Equal(t, 1, u.ID) assert.Equal(t, "Jon Snow", u.Name) } // Second test for the absence of a param c2 := e.NewContext(req, rec) c2.InitializeRoute( &RouteInfo{Path: "/users/:id"}, &PathValues{ {Name: "id", Value: "1"}, }, ) u = new(user) err = c2.Bind(u) if assert.NoError(t, err) { assert.Equal(t, 1, u.ID) assert.Equal(t, "", u.Name) } // Bind something with param and post data payload body := bytes.NewBufferString(`{ "name": "Jon Snow" }`) e2 := New() req2 := httptest.NewRequest(http.MethodPost, "/", body) req2.Header.Set(HeaderContentType, MIMEApplicationJSON) rec2 := httptest.NewRecorder() c3 := e2.NewContext(req2, rec2) c3.InitializeRoute( &RouteInfo{Path: "/users/:id"}, &PathValues{ {Name: "id", Value: "1"}, }, ) u = new(user) err = c3.Bind(u) if assert.NoError(t, err) { assert.Equal(t, 1, u.ID) assert.Equal(t, "Jon Snow", u.Name) } } func TestBindUnmarshalTypeError(t *testing.T) { body := bytes.NewBufferString(`{ "id": "text" }`) e := New() req := httptest.NewRequest(http.MethodPost, "/", body) 
req.Header.Set(HeaderContentType, MIMEApplicationJSON) rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) err := c.Bind(u) assert.EqualError(t, err, `code=400, message=Bad Request, err=json: cannot unmarshal string into Go struct field user.id of type int`) } func TestBindSetWithProperType(t *testing.T) { ts := new(bindTestStruct) typ := reflect.TypeOf(ts).Elem() val := reflect.ValueOf(ts).Elem() for i := 0; i < typ.NumField(); i++ { typeField := typ.Field(i) structField := val.Field(i) if !structField.CanSet() { continue } if len(values[typeField.Name]) == 0 { continue } val := values[typeField.Name][0] err := setWithProperType(typeField.Type.Kind(), val, structField) assert.NoError(t, err) } assertBindTestStruct(t, ts) type foo struct { Bar bytes.Buffer } v := &foo{} typ = reflect.TypeOf(v).Elem() val = reflect.ValueOf(v).Elem() assert.Error(t, setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) } func BenchmarkBindbindDataWithTags(b *testing.B) { b.ReportAllocs() ts := new(bindTestStructWithTags) var err error b.ResetTimer() for i := 0; i < b.N; i++ { err = bindData(ts, values, "form", nil) } assert.NoError(b, err) assertBindTestStruct(b, (*bindTestStruct)(ts)) } func assertBindTestStruct(tb testing.TB, ts *bindTestStruct) { assert.Equal(tb, 0, ts.I) assert.Equal(tb, int8(8), ts.I8) assert.Equal(tb, int16(16), ts.I16) assert.Equal(tb, int32(32), ts.I32) assert.Equal(tb, int64(64), ts.I64) assert.Equal(tb, uint(0), ts.UI) assert.Equal(tb, uint8(8), ts.UI8) assert.Equal(tb, uint16(16), ts.UI16) assert.Equal(tb, uint32(32), ts.UI32) assert.Equal(tb, uint64(64), ts.UI64) assert.Equal(tb, true, ts.B) assert.Equal(tb, float32(32.5), ts.F32) assert.Equal(tb, float64(64.5), ts.F64) assert.Equal(tb, "test", ts.S) assert.Equal(tb, "", ts.GetCantSet()) } func testBindOkay(t *testing.T, r io.Reader, query url.Values, ctype string) { e := New() path := "/" if len(query) > 0 { path += "?" 
+ query.Encode() } req := httptest.NewRequest(http.MethodPost, path, r) rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(HeaderContentType, ctype) u := new(user) err := c.Bind(u) if assert.Equal(t, nil, err) { assert.Equal(t, 1, u.ID) assert.Equal(t, "Jon Snow", u.Name) } } func testBindArrayOkay(t *testing.T, r io.Reader, query url.Values, ctype string) { e := New() path := "/" if len(query) > 0 { path += "?" + query.Encode() } req := httptest.NewRequest(http.MethodPost, path, r) rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(HeaderContentType, ctype) u := []user{} err := c.Bind(&u) if assert.NoError(t, err) { assert.Equal(t, 1, len(u)) assert.Equal(t, 1, u[0].ID) assert.Equal(t, "Jon Snow", u[0].Name) } } func testBindError(t *testing.T, r io.Reader, ctype string, expectedInternal error) { e := New() req := httptest.NewRequest(http.MethodPost, "/", r) rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(HeaderContentType, ctype) u := new(user) err := c.Bind(u) switch { case strings.HasPrefix(ctype, MIMEApplicationJSON), strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML), strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): if assert.IsType(t, new(HTTPError), err) { assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).Code) assert.IsType(t, expectedInternal, err.(*HTTPError).Unwrap()) } default: if assert.IsType(t, new(HTTPError), err) { assert.Equal(t, ErrUnsupportedMediaType, err) assert.IsType(t, expectedInternal, err.(*HTTPError).Unwrap()) } } } func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { // tests to check binding behaviour when multiple sources (path params, query params and request body) are in use // binding is done in steps and one source could overwrite previous source bound data // these tests are to document this behaviour and detect further possible regressions when bind implementation 
is changed type Opts struct { Node string `json:"node" form:"node" query:"node" param:"node"` Lang string ID int `json:"id" form:"id" query:"id"` } var testCases = []struct { givenContent io.Reader whenBindTarget any expect any name string givenURL string givenMethod string expectError string whenNoPathValues bool }{ { name: "ok, POST bind to struct with: path param + query param + body", givenMethod: http.MethodPost, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{"id": 1}`), expect: &Opts{ID: 1, Node: "node_from_path"}, // query params are not used, node is filled from path }, { name: "ok, PUT bind to struct with: path param + query param + body", givenMethod: http.MethodPut, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{"id": 1}`), expect: &Opts{ID: 1, Node: "node_from_path"}, // query params are not used }, { name: "ok, GET bind to struct with: path param + query param + body", givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{"id": 1}`), expect: &Opts{ID: 1, Node: "xxx"}, // query overwrites previous path value }, { name: "ok, GET bind to struct with: path param + query param + body", givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), expect: &Opts{ID: 1, Node: "zzz"}, // body is bound last and overwrites previous (path,query) values }, { name: "ok, DELETE bind to struct with: path param + query param + body", givenMethod: http.MethodDelete, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), expect: &Opts{ID: 1, Node: "zzz"}, // for DELETE body is bound after query params }, { name: "ok, POST bind to struct with: path param + body", givenMethod: http.MethodPost, givenURL: "/api/real_node/endpoint", givenContent: strings.NewReader(`{"id": 1}`), expect: &Opts{ID: 1, Node: "node_from_path"}, }, { name: 
"ok, POST bind to struct with path + query + body = body has priority", givenMethod: http.MethodPost, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), expect: &Opts{ID: 1, Node: "zzz"}, // field value from content has higher priority }, { name: "nok, POST body bind failure", givenMethod: http.MethodPost, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{`), expect: &Opts{ID: 0, Node: "node_from_path"}, // query binding has already modified bind target expectError: "code=400, message=Bad Request, err=unexpected EOF", }, { name: "nok, GET with body bind failure when types are not convertible", givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint?id=nope", givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), expect: &Opts{ID: 0, Node: "node_from_path"}, // path params binding has already modified bind target expectError: `code=400, message=Bad Request, err=strconv.ParseInt: parsing "nope": invalid syntax`, }, { name: "nok, GET body bind failure - trying to bind json array to struct", givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), expect: &Opts{ID: 0, Node: "xxx"}, // query binding has already modified bind target expectError: `code=400, message=Bad Request, err=json: cannot unmarshal array into Go value of type echo.Opts`, }, { // query param is ignored as we do not know where exactly to bind it in slice name: "ok, GET bind to struct slice, ignore query param", givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), whenNoPathValues: true, whenBindTarget: &[]Opts{}, expect: &[]Opts{ {ID: 1, Node: ""}, }, }, { // binding query params interferes with body. 
b.BindBody() should be used to bind only body to slice name: "ok, POST binding to slice should not be affected query params types", givenMethod: http.MethodPost, givenURL: "/api/real_node/endpoint?id=nope&node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), whenNoPathValues: true, whenBindTarget: &[]Opts{}, expect: &[]Opts{{ID: 1}}, expectError: "", }, { // path param is ignored as we do not know where exactly to bind it in slice name: "ok, GET bind to struct slice, ignore path param", givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), whenBindTarget: &[]Opts{}, expect: &[]Opts{ {ID: 1, Node: ""}, }, }, { name: "ok, GET body bind json array to slice", givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint", givenContent: strings.NewReader(`[{"id": 1}]`), whenNoPathValues: true, whenBindTarget: &[]Opts{}, expect: &[]Opts{{ID: 1, Node: ""}}, expectError: "", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() // assume route we are testing is "/api/:node/endpoint?some_query_params=here" req := httptest.NewRequest(tc.givenMethod, tc.givenURL, tc.givenContent) req.Header.Set(HeaderContentType, MIMEApplicationJSON) rec := httptest.NewRecorder() c := e.NewContext(req, rec) if !tc.whenNoPathValues { c.SetPathValues(PathValues{ {Name: "node", Value: "node_from_path"}, }) } var bindTarget any if tc.whenBindTarget != nil { bindTarget = tc.whenBindTarget } else { bindTarget = &Opts{} } b := new(DefaultBinder) err := b.Bind(c, bindTarget) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, bindTarget) }) } } func TestDefaultBinder_BindBody(t *testing.T) { // tests to check binding behaviour when multiple sources (path params, query params and request body) are in use // generally when binding from request body - URL and path params are ignored - unless form is being bound. 
// these tests are to document this behaviour and detect further possible regressions when bind implementation is changed type Node struct { Node string `json:"node" xml:"node" form:"node" query:"node" param:"node"` ID int `json:"id" xml:"id" form:"id" query:"id"` } type Nodes struct { Nodes []Node `xml:"node" form:"node"` } var testCases = []struct { givenContent io.Reader whenBindTarget any expect any name string givenURL string givenMethod string givenContentType string expectError string whenNoPathValues bool whenChunkedBody bool }{ { name: "ok, JSON POST bind to struct with: path + query + empty field in body", givenURL: "/api/real_node/endpoint?node=xxx", givenMethod: http.MethodPost, givenContentType: MIMEApplicationJSON, givenContent: strings.NewReader(`{"id": 1}`), expect: &Node{ID: 1, Node: ""}, // path params or query params should not interfere with body }, { name: "ok, JSON POST bind to struct with: path + query + body", givenURL: "/api/real_node/endpoint?node=xxx", givenMethod: http.MethodPost, givenContentType: MIMEApplicationJSON, givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), expect: &Node{ID: 1, Node: "zzz"}, // field value from content has higher priority }, { name: "ok, JSON POST body bind json array to slice (has matching path/query params)", givenURL: "/api/real_node/endpoint?node=xxx", givenMethod: http.MethodPost, givenContentType: MIMEApplicationJSON, givenContent: strings.NewReader(`[{"id": 1}]`), whenNoPathValues: true, whenBindTarget: &[]Node{}, expect: &[]Node{{ID: 1, Node: ""}}, expectError: "", }, { // rare case as GET is not usually used to send request body name: "ok, JSON GET bind to struct with: path + query + empty field in body", givenURL: "/api/real_node/endpoint?node=xxx", givenMethod: http.MethodGet, givenContentType: MIMEApplicationJSON, givenContent: strings.NewReader(`{"id": 1}`), expect: &Node{ID: 1, Node: ""}, // path params or query params should not interfere with body }, { // rare case as GET is not 
usually used to send request body name: "ok, JSON GET bind to struct with: path + query + body", givenURL: "/api/real_node/endpoint?node=xxx", givenMethod: http.MethodGet, givenContentType: MIMEApplicationJSON, givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), expect: &Node{ID: 1, Node: "zzz"}, // field value from content has higher priority }, { name: "nok, JSON POST body bind failure", givenURL: "/api/real_node/endpoint?node=xxx", givenMethod: http.MethodPost, givenContentType: MIMEApplicationJSON, givenContent: strings.NewReader(`{`), expect: &Node{ID: 0, Node: ""}, expectError: "code=400, message=Bad Request, err=unexpected EOF", }, { name: "ok, XML POST bind to struct with: path + query + empty body", givenURL: "/api/real_node/endpoint?node=xxx", givenMethod: http.MethodPost, givenContentType: MIMEApplicationXML, givenContent: strings.NewReader(`1yyy`), expect: &Node{ID: 1, Node: "yyy"}, }, { name: "ok, XML POST bind array to slice with: path + query + body", givenURL: "/api/real_node/endpoint?node=xxx", givenMethod: http.MethodPost, givenContentType: MIMEApplicationXML, givenContent: strings.NewReader(`1yyy`), whenBindTarget: &Nodes{}, expect: &Nodes{Nodes: []Node{{ID: 1, Node: "yyy"}}}, }, { name: "nok, XML POST bind failure", givenURL: "/api/real_node/endpoint?node=xxx", givenMethod: http.MethodPost, givenContentType: MIMEApplicationXML, givenContent: strings.NewReader(`<`), expect: &Node{ID: 0, Node: ""}, expectError: "code=400, message=Bad Request, err=XML syntax error on line 1: unexpected EOF", }, { name: "ok, FORM POST bind to struct with: path + query + body", givenURL: "/api/real_node/endpoint?node=xxx", givenMethod: http.MethodPost, givenContentType: MIMEApplicationForm, givenContent: strings.NewReader(`id=1&node=yyy`), expect: &Node{ID: 1, Node: "yyy"}, }, { // NB: form values are taken from BOTH body and query for POST/PUT/PATCH by standard library implementation // See: https://golang.org/pkg/net/http/#Request.ParseForm name: "ok, FORM 
POST bind to struct with: path + query + empty field in body", givenURL: "/api/real_node/endpoint?node=xxx", givenMethod: http.MethodPost, givenContentType: MIMEApplicationForm, givenContent: strings.NewReader(`id=1`), expect: &Node{ID: 1, Node: "xxx"}, }, { // NB: form values are taken from query by standard library implementation // See: https://golang.org/pkg/net/http/#Request.ParseForm name: "ok, FORM GET bind to struct with: path + query + empty field in body", givenURL: "/api/real_node/endpoint?node=xxx", givenMethod: http.MethodGet, givenContentType: MIMEApplicationForm, givenContent: strings.NewReader(`id=1`), expect: &Node{ID: 0, Node: "xxx"}, // 'xxx' is taken from URL and body is not used with GET by implementation }, { name: "nok, unsupported content type", givenURL: "/api/real_node/endpoint?node=xxx", givenMethod: http.MethodPost, givenContentType: MIMETextPlain, givenContent: strings.NewReader(``), expect: &Node{ID: 0, Node: ""}, expectError: "code=415, message=Unsupported Media Type", }, // FIXME: REASON in Go 1.24 and earlier http.NoBody would result ContentLength=-1 // but as of Go 1.25 http.NoBody would result ContentLength=0 // I am too lazy to bother documenting this as 2 version specific tests. 
//{ // name: "nok, JSON POST with http.NoBody", // givenURL: "/api/real_node/endpoint?node=xxx", // givenMethod: http.MethodPost, // givenContentType: MIMEApplicationJSON, // givenContent: http.NoBody, // expect: &Node{ID: 0, Node: ""}, // expectError: "code=400, message=EOF, internal=EOF", //}, { name: "ok, JSON POST with empty body", givenURL: "/api/real_node/endpoint?node=xxx", givenMethod: http.MethodPost, givenContentType: MIMEApplicationJSON, givenContent: strings.NewReader(""), expect: &Node{ID: 0, Node: ""}, }, { name: "ok, JSON POST bind to struct with: path + query + chunked body", givenURL: "/api/real_node/endpoint?node=xxx", givenMethod: http.MethodPost, givenContentType: MIMEApplicationJSON, givenContent: httputil.NewChunkedReader(strings.NewReader("18\r\n" + `{"id": 1, "node": "zzz"}` + "\r\n0\r\n")), whenChunkedBody: true, expect: &Node{ID: 1, Node: "zzz"}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() // assume route we are testing is "/api/:node/endpoint?some_query_params=here" req := httptest.NewRequest(tc.givenMethod, tc.givenURL, tc.givenContent) switch tc.givenContentType { case MIMEApplicationXML: req.Header.Set(HeaderContentType, MIMEApplicationXML) case MIMEApplicationForm: req.Header.Set(HeaderContentType, MIMEApplicationForm) case MIMEApplicationJSON: req.Header.Set(HeaderContentType, MIMEApplicationJSON) } if tc.whenChunkedBody { req.ContentLength = -1 req.TransferEncoding = append(req.TransferEncoding, "chunked") } rec := httptest.NewRecorder() c := e.NewContext(req, rec) if !tc.whenNoPathValues { c.SetPathValues(PathValues{ {Name: "node", Value: "real_node"}, }) } var bindTarget any if tc.whenBindTarget != nil { bindTarget = tc.whenBindTarget } else { bindTarget = &Node{} } err := BindBody(c, bindTarget) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, bindTarget) }) } } func testBindURL(queryString string, target any) 
error { e := New() req := httptest.NewRequest(http.MethodGet, queryString, nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) return c.Bind(target) } type unixTimestamp struct { Time time.Time } func (t *unixTimestamp) UnmarshalParam(param string) error { n, err := strconv.ParseInt(param, 10, 64) if err != nil { return fmt.Errorf("'%s' is not an integer", param) } *t = unixTimestamp{Time: time.Unix(n, 0)} return err } type IntArrayA []int // UnmarshalParam converts value to *Int64Slice. This allows the API to accept // a comma-separated list of integers as a query parameter. func (i *IntArrayA) UnmarshalParam(value string) error { var values = strings.Split(value, ",") var numbers = make([]int, 0, len(values)) for _, v := range values { n, err := strconv.ParseInt(v, 10, 64) if err != nil { return fmt.Errorf("'%s' is not an integer", v) } numbers = append(numbers, int(n)) } *i = append(*i, numbers...) return nil } func TestBindUnmarshalParamExtras(t *testing.T) { // this test documents how bind handles `BindUnmarshaler` interface: // NOTE: BindUnmarshaler chooses first input value to be bound. 
t.Run("nok, unmarshalling fails", func(t *testing.T) { result := struct { V unixTimestamp `query:"t"` }{} err := testBindURL("/?t=xxxx", &result) assert.EqualError(t, err, `code=400, message=Bad Request, err='xxxx' is not an integer`) }) t.Run("ok, target is struct", func(t *testing.T) { result := struct { V unixTimestamp `query:"t"` }{} err := testBindURL("/?t=1710095540&t=1710095541", &result) assert.NoError(t, err) expect := unixTimestamp{ Time: time.Unix(1710095540, 0), } assert.Equal(t, expect, result.V) }) t.Run("ok, target is an alias to slice and is nil, append only values from first", func(t *testing.T) { result := struct { V IntArrayA `query:"a"` }{} err := testBindURL("/?a=1,2,3&a=4,5,6", &result) assert.NoError(t, err) assert.Equal(t, IntArrayA([]int{1, 2, 3}), result.V) }) t.Run("ok, target is an alias to slice and is nil, single input", func(t *testing.T) { result := struct { V IntArrayA `query:"a"` }{} err := testBindURL("/?a=1,2", &result) assert.NoError(t, err) assert.Equal(t, IntArrayA([]int{1, 2}), result.V) }) t.Run("ok, target is pointer an alias to slice and is nil", func(t *testing.T) { result := struct { V *IntArrayA `query:"a"` }{} err := testBindURL("/?a=1&a=4,5,6", &result) assert.NoError(t, err) var expected = IntArrayA([]int{1}) assert.Equal(t, &expected, result.V) }) t.Run("ok, target is pointer an alias to slice and is NOT nil", func(t *testing.T) { result := struct { V *IntArrayA `query:"a"` }{} result.V = new(IntArrayA) // NOT nil err := testBindURL("/?a=1&a=4,5,6", &result) assert.NoError(t, err) var expected = IntArrayA([]int{1}) assert.Equal(t, &expected, result.V) }) } type unixTimestampLast struct { Time time.Time } // this is silly example for `bindMultipleUnmarshaler` for type that uses last input value for unmarshalling func (t *unixTimestampLast) UnmarshalParams(params []string) error { lastInput := params[len(params)-1] n, err := strconv.ParseInt(lastInput, 10, 64) if err != nil { return fmt.Errorf("'%s' is not an 
integer", lastInput) } *t = unixTimestampLast{Time: time.Unix(n, 0)} return err } type IntArrayB []int func (i *IntArrayB) UnmarshalParams(params []string) error { var numbers = make([]int, 0, len(params)) for _, param := range params { var values = strings.Split(param, ",") for _, v := range values { n, err := strconv.ParseInt(v, 10, 64) if err != nil { return fmt.Errorf("'%s' is not an integer", v) } numbers = append(numbers, int(n)) } } *i = append(*i, numbers...) return nil } func TestBindUnmarshalParams(t *testing.T) { // this test documents how bind handles `bindMultipleUnmarshaler` interface: t.Run("nok, unmarshalling fails", func(t *testing.T) { result := struct { V unixTimestampLast `query:"t"` }{} err := testBindURL("/?t=xxxx", &result) assert.EqualError(t, err, "code=400, message=Bad Request, err='xxxx' is not an integer") }) t.Run("ok, target is struct", func(t *testing.T) { result := struct { V unixTimestampLast `query:"t"` }{} err := testBindURL("/?t=1710095540&t=1710095541", &result) assert.NoError(t, err) expect := unixTimestampLast{ Time: time.Unix(1710095541, 0), } assert.Equal(t, expect, result.V) }) t.Run("ok, target is an alias to slice and is nil, append multiple inputs", func(t *testing.T) { result := struct { V IntArrayB `query:"a"` }{} err := testBindURL("/?a=1,2,3&a=4,5,6", &result) assert.NoError(t, err) assert.Equal(t, IntArrayB([]int{1, 2, 3, 4, 5, 6}), result.V) }) t.Run("ok, target is an alias to slice and is nil, single input", func(t *testing.T) { result := struct { V IntArrayB `query:"a"` }{} err := testBindURL("/?a=1,2", &result) assert.NoError(t, err) assert.Equal(t, IntArrayB([]int{1, 2}), result.V) }) t.Run("ok, target is pointer an alias to slice and is nil", func(t *testing.T) { result := struct { V *IntArrayB `query:"a"` }{} err := testBindURL("/?a=1&a=4,5,6", &result) assert.NoError(t, err) var expected = IntArrayB([]int{1, 4, 5, 6}) assert.Equal(t, &expected, result.V) }) t.Run("ok, target is pointer an alias to slice and 
is NOT nil", func(t *testing.T) { result := struct { V *IntArrayB `query:"a"` }{} result.V = new(IntArrayB) // NOT nil err := testBindURL("/?a=1&a=4,5,6", &result) assert.NoError(t, err) var expected = IntArrayB([]int{1, 4, 5, 6}) assert.Equal(t, &expected, result.V) }) } func TestBindInt8(t *testing.T) { t.Run("nok, binding fails", func(t *testing.T) { type target struct { V int8 `query:"v"` } p := target{} err := testBindURL("/?v=x&v=2", &p) assert.EqualError(t, err, `code=400, message=Bad Request, err=strconv.ParseInt: parsing "x": invalid syntax`) }) t.Run("nok, int8 embedded in struct", func(t *testing.T) { type target struct { int8 `query:"v"` // embedded field is `Anonymous`. We can only set public fields } p := target{} err := testBindURL("/?v=1&v=2", &p) assert.NoError(t, err) assert.Equal(t, target{0}, p) }) t.Run("nok, pointer to int8 embedded in struct", func(t *testing.T) { type target struct { *int8 `query:"v"` // embedded field is `Anonymous`. We can only set public fields } p := target{} err := testBindURL("/?v=1&v=2", &p) assert.NoError(t, err) assert.Equal(t, target{int8: nil}, p) }) t.Run("ok, bind int8 as struct field", func(t *testing.T) { type target struct { V int8 `query:"v"` } p := target{V: 127} err := testBindURL("/?v=1&v=2", &p) assert.NoError(t, err) assert.Equal(t, target{V: 1}, p) }) t.Run("ok, bind pointer to int8 as struct field, value is nil", func(t *testing.T) { type target struct { V *int8 `query:"v"` } p := target{} err := testBindURL("/?v=1&v=2", &p) assert.NoError(t, err) assert.Equal(t, target{V: ptr(int8(1))}, p) }) t.Run("ok, bind pointer to int8 as struct field, value is set", func(t *testing.T) { type target struct { V *int8 `query:"v"` } p := target{V: ptr(int8(127))} err := testBindURL("/?v=1&v=2", &p) assert.NoError(t, err) assert.Equal(t, target{V: ptr(int8(1))}, p) }) t.Run("ok, bind int8 slice as struct field, value is nil", func(t *testing.T) { type target struct { V []int8 `query:"v"` } p := target{} err := 
testBindURL("/?v=1&v=2", &p) assert.NoError(t, err) assert.Equal(t, target{V: []int8{1, 2}}, p) }) t.Run("ok, bind slice of int8 as struct field, value is set", func(t *testing.T) { type target struct { V []int8 `query:"v"` } p := target{V: []int8{111}} err := testBindURL("/?v=1&v=2", &p) assert.NoError(t, err) assert.Equal(t, target{V: []int8{1, 2}}, p) }) t.Run("ok, bind slice of pointer to int8 as struct field, value is set", func(t *testing.T) { type target struct { V []*int8 `query:"v"` } p := target{V: []*int8{ptr(int8(127))}} err := testBindURL("/?v=1&v=2", &p) assert.NoError(t, err) assert.Equal(t, target{V: []*int8{ptr(int8(1)), ptr(int8(2))}}, p) }) t.Run("ok, bind pointer to slice of int8 as struct field, value is set", func(t *testing.T) { type target struct { V *[]int8 `query:"v"` } p := target{V: &[]int8{111}} err := testBindURL("/?v=1&v=2", &p) assert.NoError(t, err) assert.Equal(t, target{V: &[]int8{1, 2}}, p) }) } func TestBindMultipartFormFiles(t *testing.T) { file1 := createTestFormFile("file", "file1.txt") file11 := createTestFormFile("file", "file11.txt") file2 := createTestFormFile("file2", "file2.txt") filesA := createTestFormFile("files", "filesA.txt") filesB := createTestFormFile("files", "filesB.txt") t.Run("nok, can not bind to multipart file struct", func(t *testing.T) { var target struct { File multipart.FileHeader `form:"file"` } err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored assert.EqualError(t, err, `code=400, message=Bad Request, err=binding to multipart.FileHeader struct is not supported, use pointer to struct`) }) t.Run("ok, bind single multipart file to pointer to multipart file", func(t *testing.T) { var target struct { File *multipart.FileHeader `form:"file"` } err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored assert.NoError(t, err) assertMultipartFileHeader(t, target.File, file1) }) t.Run("ok, bind multiple multipart files to pointer to multipart file", func(t 
*testing.T) {
		var target struct {
			File *multipart.FileHeader `form:"file"`
		}
		err := bindMultipartFiles(t, &target, file1, file11)
		assert.NoError(t, err)
		assertMultipartFileHeader(t, target.File, file1) // should choose first one
	})

	t.Run("ok, bind multiple multipart files to slice of multipart file", func(t *testing.T) {
		var target struct {
			Files []multipart.FileHeader `form:"files"`
		}
		err := bindMultipartFiles(t, &target, filesA, filesB, file1)
		assert.NoError(t, err)

		assert.Len(t, target.Files, 2)
		assertMultipartFileHeader(t, &target.Files[0], filesA)
		assertMultipartFileHeader(t, &target.Files[1], filesB)
	})

	t.Run("ok, bind multiple multipart files to slice of pointer to multipart file", func(t *testing.T) {
		var target struct {
			Files []*multipart.FileHeader `form:"files"`
		}
		err := bindMultipartFiles(t, &target, filesA, filesB, file1)
		assert.NoError(t, err)

		assert.Len(t, target.Files, 2)
		assertMultipartFileHeader(t, target.Files[0], filesA)
		assertMultipartFileHeader(t, target.Files[1], filesB)
	})
}

// testFormFile describes one file part that bindMultipartFiles writes into a
// multipart/form-data request body.
type testFormFile struct {
	// Fieldname is the form field name the file part is sent under.
	Fieldname string
	// Filename is the file name reported in the part's headers.
	Filename string
	// Content is the raw file body.
	Content []byte
}

// createTestFormFile returns a testFormFile for form field formFieldName whose content is
// filename repeated 10 times, giving each file predictable content derived from its name.
func createTestFormFile(formFieldName string, filename string) testFormFile {
	return testFormFile{
		Fieldname: formFieldName,
		Filename:  filename,
		Content:   []byte(strings.Repeat(filename, 10)),
	}
}

// bindMultipartFiles encodes files as a multipart/form-data POST request to "/" and binds it
// into target with Context.Bind, returning the bind error. Failures while building the request
// are reported through t (via assert) rather than returned.
func bindMultipartFiles(t *testing.T, target any, files ...testFormFile) error {
	var body bytes.Buffer
	mw := multipart.NewWriter(&body)

	for _, file := range files {
		fw, err := mw.CreateFormFile(file.Fieldname, file.Filename)
		assert.NoError(t, err)

		n, err := fw.Write(file.Content)
		assert.NoError(t, err)
		assert.Equal(t, len(file.Content), n)
	}

	// Close writes the terminating boundary; the body is incomplete without it.
	err := mw.Close()
	assert.NoError(t, err)

	req, err := http.NewRequest(http.MethodPost, "/", &body)
	assert.NoError(t, err)
	// the content type carries the boundary generated by mw
	req.Header.Set("Content-Type", mw.FormDataContentType())

	rec := httptest.NewRecorder()

	e := New()
	c := e.NewContext(req, rec)
	return c.Bind(target)
}

// assertMultipartFileHeader checks that fh has the filename, size and content of file.
func assertMultipartFileHeader(t *testing.T, fh *multipart.FileHeader, file testFormFile) {
	assert.Equal(t, file.Filename, fh.Filename)
	assert.Equal(t, int64(len(file.Content)), fh.Size)
	fl, err := fh.Open()
	// NOTE(review): assert.NoError does not stop the test, so a failed Open leaves fl nil and
	// io.ReadAll below would panic; consider require.NoError here.
	assert.NoError(t, err)
	body, err := io.ReadAll(fl)
	assert.NoError(t, err)
	assert.Equal(t, string(file.Content), string(body))
	err = fl.Close()
	assert.NoError(t, err)
}

// TestTimeFormatBinding checks that the `format` struct tag selects the layout used to parse
// time.Time and *time.Time fields from form and query values. It also checks that fields without
// the tag use default parsing, and that values not matching the layout make Bind fail.
func TestTimeFormatBinding(t *testing.T) {
	type TestStruct struct {
		DateTimeLocal time.Time  `form:"datetime_local" format:"2006-01-02T15:04"`
		Date          time.Time  `query:"date" format:"2006-01-02"`
		CustomFormat  time.Time  `form:"custom" format:"01/02/2006 15:04:05"`
		DefaultTime   time.Time  `form:"default_time"` // No format tag - should use default parsing
		PtrTime       *time.Time `query:"ptr_time" format:"2006-01-02"`
	}

	testCases := []struct {
		name        string
		contentType string
		data        string
		queryParams string
		expect      TestStruct
		expectError bool
	}{
		{
			name:        "ok, datetime-local format binding",
			contentType: MIMEApplicationForm,
			data:        "datetime_local=2023-12-25T14:30&default_time=2023-12-25T14:30:45Z",
			expect: TestStruct{
				DateTimeLocal: time.Date(2023, 12, 25, 14, 30, 0, 0, time.UTC),
				DefaultTime:   time.Date(2023, 12, 25, 14, 30, 45, 0, time.UTC),
			},
		},
		{
			name:        "ok, date format binding via query params",
			queryParams: "?date=2023-01-15&ptr_time=2023-02-20",
			expect: TestStruct{
				Date: time.Date(2023, 1, 15, 0, 0, 0, 0, time.UTC),
				// non-nil marker only: the expected PtrTime value is hard-coded in the assertion below
				PtrTime: &time.Time{},
			},
		},
		{
			name:        "ok, custom format via form data",
			contentType: MIMEApplicationForm,
			data:        "custom=12/25/2023 14:30:45",
			expect: TestStruct{
				CustomFormat: time.Date(2023, 12, 25, 14, 30, 45, 0, time.UTC),
			},
		},
		{
			name:        "nok, invalid format should fail",
			contentType: MIMEApplicationForm,
			data:        "datetime_local=invalid-date",
			expectError: true,
		},
		{
			name:        "nok, wrong format should fail",
			contentType: MIMEApplicationForm,
			data:        "datetime_local=2023-12-25", // Missing time part
			expectError: true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := New()
			var req *http.Request

			// NOTE(review): no test case above uses MIMEApplicationJSON, so the first branch is currently unused.
			if tc.contentType == MIMEApplicationJSON {
				req = httptest.NewRequest(http.MethodPost, "/"+tc.queryParams, strings.NewReader(tc.data))
				req.Header.Set(HeaderContentType, tc.contentType)
			} else if tc.contentType == MIMEApplicationForm {
				req = httptest.NewRequest(http.MethodPost, "/"+tc.queryParams, strings.NewReader(tc.data))
				req.Header.Set(HeaderContentType, tc.contentType)
			} else {
				// no body: values come only from the query string
				req = httptest.NewRequest(http.MethodGet, "/"+tc.queryParams, nil)
			}

			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)

			var result TestStruct
			err := c.Bind(&result)

			if tc.expectError {
				assert.Error(t, err)
				return
			}

			assert.NoError(t, err)

			// Check individual fields since time comparison can be tricky
			// (time.Time.Equal compares instants, unlike == which also compares location/monotonic data)
			if !tc.expect.DateTimeLocal.IsZero() {
				assert.True(t, tc.expect.DateTimeLocal.Equal(result.DateTimeLocal), "DateTimeLocal: expected %v, got %v", tc.expect.DateTimeLocal, result.DateTimeLocal)
			}
			if !tc.expect.Date.IsZero() {
				assert.True(t, tc.expect.Date.Equal(result.Date), "Date: expected %v, got %v", tc.expect.Date, result.Date)
			}
			if !tc.expect.CustomFormat.IsZero() {
				assert.True(t, tc.expect.CustomFormat.Equal(result.CustomFormat), "CustomFormat: expected %v, got %v", tc.expect.CustomFormat, result.CustomFormat)
			}
			if !tc.expect.DefaultTime.IsZero() {
				assert.True(t, tc.expect.DefaultTime.Equal(result.DefaultTime), "DefaultTime: expected %v, got %v", tc.expect.DefaultTime, result.DefaultTime)
			}
			if tc.expect.PtrTime != nil {
				assert.NotNil(t, result.PtrTime)
				if result.PtrTime != nil {
					expectedPtr := time.Date(2023, 2, 20, 0, 0, 0, 0, time.UTC)
					assert.True(t, expectedPtr.Equal(*result.PtrTime), "PtrTime: expected %v, got %v", expectedPtr, *result.PtrTime)
				}
			}
		})
	}
}

================================================
FILE: binder.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors

package echo

import (
	"encoding"
	"encoding/json"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"time"
)

/**
	Following functions provide handful of methods for binding to Go native types from request query or path parameters.
	* QueryParamsBinder(c) - binds query parameters (source URL)
	* PathValuesBinder(c) - binds path parameters (source URL)
	* FormFieldBinder(c) - binds form fields (source URL + body)

	Example:
	```go
	var length int64
	err := echo.QueryParamsBinder(c).Int64("length", &length).BindError()
	```

	For every supported type there are following methods:
	* <Type>("param", &destination) - if parameter value exists then binds it to given destination of that type i.e Int64(...).
	* Must<Type>("param", &destination) - parameter value is required to exist, binds it to given destination of that type i.e MustInt64(...).
	* <Type>s("param", &destination) - (for slices) if parameter values exists then binds it to given destination of that type i.e Int64s(...).
	* Must<Type>s("param", &destination) - (for slices) parameter value is required to exist, binds it to given destination of that type i.e MustInt64s(...).

	for some slice types `BindWithDelimiter("param", &dest, ",")` supports splitting parameter values before type conversion is done
	i.e. URL `/api/search?id=1,2,3&id=1` can be bound to `[]int64{1,2,3,1}`

	`FailFast` flags binder to stop binding after first bind error during binder call chain. Enabled by default.
	`BindError()` returns first bind error from binder and resets errors in binder. Useful along with `FailFast()` method
	to do binding and returns on first problem
	`BindErrors()` returns all bind errors from binder and resets errors in binder.

	Types that are supported:
	* bool
	* float32
	* float64
	* int
	* int8
	* int16
	* int32
	* int64
	* uint
	* uint8/byte (does not support `bytes()`. Use BindUnmarshaler/CustomFunc to convert value from base64 etc to []byte{})
	* uint16
	* uint32
	* uint64
	* string
	* time
	* duration
	* BindUnmarshaler() interface
	* TextUnmarshaler() interface
	* JSONUnmarshaler() interface
	* UnixTime() - converts unix time (integer) to time.Time
	* UnixTimeMilli() - converts unix time with millisecond precision (integer) to time.Time
	* UnixTimeNano() - converts unix time with nanosecond precision (integer) to time.Time
	* CustomFunc() - callback function for your custom conversion logic. Signature `func(values []string) []error`
*/

// BindingError represents an error that occurred while binding request data.
// It embeds *HTTPError, so it carries an HTTP status code and message in addition to the field name.
type BindingError struct {
	// Field is the field name where value binding failed
	Field string `json:"field"`
	*HTTPError
	// Values of parameter that failed to bind. Not serialized to JSON.
	Values []string `json:"-"`
}

// NewBindingError creates new instance of binding error.
// The embedded HTTPError has status 400 (Bad Request), the given message, and err as its internal error.
// Its signature matches ValueBinder.ErrorFunc, which the binders in this file use it for.
func NewBindingError(sourceParam string, values []string, message string, err error) error {
	return &BindingError{
		Field:     sourceParam,
		Values:    values,
		HTTPError: &HTTPError{Code: http.StatusBadRequest, Message: message, err: err},
	}
}

// Error returns error message: the embedded HTTPError message followed by ", field=<Field>".
func (be *BindingError) Error() string {
	return fmt.Sprintf("%s, field=%s", be.HTTPError.Error(), be.Field)
}

// ValueBinder provides utility methods for binding query or path parameter to various Go built-in types
type ValueBinder struct {
	// ValueFunc is used to get single parameter (first) value from request
	ValueFunc func(sourceParam string) string
	// ValuesFunc is used to get all values for parameter from request. i.e. `/api/search?ids=1&ids=2`
	ValuesFunc func(sourceParam string) []string
	// ErrorFunc is used to create errors. Allows you to use your own error type, that for example marshals to your specific json response
	ErrorFunc func(sourceParam string, values []string, message string, internalError error) error
	// errors collects binding errors until BindError/BindErrors returns and resets them
	errors []error
	// failFast is flag for binding methods to return without attempting to bind when previous binding already failed
	failFast bool
}

// QueryParamsBinder creates query parameter value binder.
// Values are read from the request URL query string; fail-fast is enabled and errors are created with NewBindingError.
func QueryParamsBinder(c *Context) *ValueBinder {
	return &ValueBinder{
		failFast:  true,
		ValueFunc: c.QueryParam,
		ValuesFunc: func(sourceParam string) []string {
			values, ok := c.QueryParams()[sourceParam]
			if !ok {
				// report a missing parameter as nil so binding methods treat it as "no value"
				return nil
			}
			return values
		},
		ErrorFunc: NewBindingError,
	}
}

// PathValuesBinder creates path parameter value binder.
// Values are read with c.Param; fail-fast is enabled and errors are created with NewBindingError.
func PathValuesBinder(c *Context) *ValueBinder {
	return &ValueBinder{
		failFast:  true,
		ValueFunc: c.Param,
		ValuesFunc: func(sourceParam string) []string {
			// path parameter should not have multiple values so getting values does not make sense but lets not error out here
			value := c.Param(sourceParam)
			if value == "" {
				return nil
			}
			return []string{value}
		},
		ErrorFunc: NewBindingError,
	}
}

// FormFieldBinder creates form field value binder
// For all requests, FormFieldBinder parses the raw query from the URL and uses query params as form fields
//
// For POST, PUT, and PATCH requests, it also reads the request body, parses it
// as a form and uses query params as form fields. Request body parameters take precedence over URL query
// string values in r.Form.
// // NB: when binding forms take note that this implementation uses standard library form parsing // which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm // See https://golang.org/pkg/net/http/#Request.ParseForm func FormFieldBinder(c *Context) *ValueBinder { vb := &ValueBinder{ failFast: true, ValueFunc: func(sourceParam string) string { return c.Request().FormValue(sourceParam) }, ErrorFunc: NewBindingError, } vb.ValuesFunc = func(sourceParam string) []string { if c.Request().Form == nil { // this is same as `Request().FormValue()` does internally _, _ = c.MultipartForm() // we want to trigger c.request.ParseMultipartForm(c.formParseMaxMemory) } values, ok := c.Request().Form[sourceParam] if !ok { return nil } return values } return vb } // FailFast set internal flag to indicate if binding methods will return early (without binding) when previous bind failed // NB: call this method before any other binding methods as it modifies binding methods behaviour func (b *ValueBinder) FailFast(value bool) *ValueBinder { b.failFast = value return b } func (b *ValueBinder) setError(err error) { if b.errors == nil { b.errors = []error{err} return } b.errors = append(b.errors, err) } // BindError returns first seen bind error and resets/empties binder errors for further calls func (b *ValueBinder) BindError() error { if b.errors == nil { return nil } err := b.errors[0] b.errors = nil // reset errors so next chain will start from zero return err } // BindErrors returns all bind errors and resets/empties binder errors for further calls func (b *ValueBinder) BindErrors() []error { if b.errors == nil { return nil } errors := b.errors b.errors = nil // reset errors so next chain will start from zero return errors } // CustomFunc binds parameter values with Func. Func is called only when parameter values exist. 
func (b *ValueBinder) CustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder { return b.customFunc(sourceParam, customFunc, false) } // MustCustomFunc requires parameter values to exist to bind with Func. Returns error when value does not exist. func (b *ValueBinder) MustCustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder { return b.customFunc(sourceParam, customFunc, true) } func (b *ValueBinder) customFunc(sourceParam string, customFunc func(values []string) []error, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } values := b.ValuesFunc(sourceParam) if len(values) == 0 { if valueMustExist { b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) } return b } if errs := customFunc(values); errs != nil { b.errors = append(b.errors, errs...) } return b } // String binds parameter to string variable func (b *ValueBinder) String(sourceParam string, dest *string) *ValueBinder { if b.failFast && b.errors != nil { return b } value := b.ValueFunc(sourceParam) if value == "" { return b } *dest = value return b } // MustString requires parameter value to exist to bind to string variable. Returns error when value does not exist func (b *ValueBinder) MustString(sourceParam string, dest *string) *ValueBinder { if b.failFast && b.errors != nil { return b } value := b.ValueFunc(sourceParam) if value == "" { b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil)) return b } *dest = value return b } // Strings binds parameter values to slice of string func (b *ValueBinder) Strings(sourceParam string, dest *[]string) *ValueBinder { if b.failFast && b.errors != nil { return b } value := b.ValuesFunc(sourceParam) if value == nil { return b } *dest = value return b } // MustStrings requires parameter values to exist to bind to slice of string variables. 
Returns error when value does not exist func (b *ValueBinder) MustStrings(sourceParam string, dest *[]string) *ValueBinder { if b.failFast && b.errors != nil { return b } value := b.ValuesFunc(sourceParam) if value == nil { b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) return b } *dest = value return b } // BindUnmarshaler binds parameter to destination implementing BindUnmarshaler interface func (b *ValueBinder) BindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder { if b.failFast && b.errors != nil { return b } tmp := b.ValueFunc(sourceParam) if tmp == "" { return b } if err := dest.UnmarshalParam(tmp); err != nil { b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to BindUnmarshaler interface", err)) } return b } // MustBindUnmarshaler requires parameter value to exist to bind to destination implementing BindUnmarshaler interface. // Returns error when value does not exist func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder { if b.failFast && b.errors != nil { return b } value := b.ValueFunc(sourceParam) if value == "" { b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil)) return b } if err := dest.UnmarshalParam(value); err != nil { b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to BindUnmarshaler interface", err)) } return b } // JSONUnmarshaler binds parameter to destination implementing json.Unmarshaler interface func (b *ValueBinder) JSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder { if b.failFast && b.errors != nil { return b } tmp := b.ValueFunc(sourceParam) if tmp == "" { return b } if err := dest.UnmarshalJSON([]byte(tmp)); err != nil { b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err)) } return b } // MustJSONUnmarshaler requires parameter value to exist to bind 
to destination implementing json.Unmarshaler interface. // Returns error when value does not exist func (b *ValueBinder) MustJSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder { if b.failFast && b.errors != nil { return b } tmp := b.ValueFunc(sourceParam) if tmp == "" { b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil)) return b } if err := dest.UnmarshalJSON([]byte(tmp)); err != nil { b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err)) } return b } // TextUnmarshaler binds parameter to destination implementing encoding.TextUnmarshaler interface func (b *ValueBinder) TextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder { if b.failFast && b.errors != nil { return b } tmp := b.ValueFunc(sourceParam) if tmp == "" { return b } if err := dest.UnmarshalText([]byte(tmp)); err != nil { b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err)) } return b } // MustTextUnmarshaler requires parameter value to exist to bind to destination implementing encoding.TextUnmarshaler interface. // Returns error when value does not exist func (b *ValueBinder) MustTextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder { if b.failFast && b.errors != nil { return b } tmp := b.ValueFunc(sourceParam) if tmp == "" { b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil)) return b } if err := dest.UnmarshalText([]byte(tmp)); err != nil { b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err)) } return b } // BindWithDelimiter binds parameter to destination by suitable conversion function. 
// Delimiter is used before conversion to split parameter value to separate values func (b *ValueBinder) BindWithDelimiter(sourceParam string, dest any, delimiter string) *ValueBinder { return b.bindWithDelimiter(sourceParam, dest, delimiter, false) } // MustBindWithDelimiter requires parameter value to exist to bind destination by suitable conversion function. // Delimiter is used before conversion to split parameter value to separate values func (b *ValueBinder) MustBindWithDelimiter(sourceParam string, dest any, delimiter string) *ValueBinder { return b.bindWithDelimiter(sourceParam, dest, delimiter, true) } func (b *ValueBinder) bindWithDelimiter(sourceParam string, dest any, delimiter string, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } values := b.ValuesFunc(sourceParam) if len(values) == 0 { if valueMustExist { b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) } return b } tmpValues := make([]string, 0, len(values)) for _, v := range values { tmpValues = append(tmpValues, strings.Split(v, delimiter)...) 
} switch d := dest.(type) { case *[]string: *d = tmpValues return b case *[]bool: return b.bools(sourceParam, tmpValues, d) case *[]int64, *[]int32, *[]int16, *[]int8, *[]int: return b.ints(sourceParam, tmpValues, d) case *[]uint64, *[]uint32, *[]uint16, *[]uint8, *[]uint: // *[]byte is same as *[]uint8 return b.uints(sourceParam, tmpValues, d) case *[]float64, *[]float32: return b.floats(sourceParam, tmpValues, d) case *[]time.Duration: return b.durations(sourceParam, tmpValues, d) default: // support only cases when destination is slice // does not support time.Time as it needs argument (layout) for parsing or BindUnmarshaler b.setError(b.ErrorFunc(sourceParam, []string{}, "unsupported bind type", nil)) return b } } // Int64 binds parameter to int64 variable func (b *ValueBinder) Int64(sourceParam string, dest *int64) *ValueBinder { return b.intValue(sourceParam, dest, 64, false) } // MustInt64 requires parameter value to exist to bind to int64 variable. Returns error when value does not exist func (b *ValueBinder) MustInt64(sourceParam string, dest *int64) *ValueBinder { return b.intValue(sourceParam, dest, 64, true) } // Int32 binds parameter to int32 variable func (b *ValueBinder) Int32(sourceParam string, dest *int32) *ValueBinder { return b.intValue(sourceParam, dest, 32, false) } // MustInt32 requires parameter value to exist to bind to int32 variable. Returns error when value does not exist func (b *ValueBinder) MustInt32(sourceParam string, dest *int32) *ValueBinder { return b.intValue(sourceParam, dest, 32, true) } // Int16 binds parameter to int16 variable func (b *ValueBinder) Int16(sourceParam string, dest *int16) *ValueBinder { return b.intValue(sourceParam, dest, 16, false) } // MustInt16 requires parameter value to exist to bind to int16 variable. 
Returns error when value does not exist func (b *ValueBinder) MustInt16(sourceParam string, dest *int16) *ValueBinder { return b.intValue(sourceParam, dest, 16, true) } // Int8 binds parameter to int8 variable func (b *ValueBinder) Int8(sourceParam string, dest *int8) *ValueBinder { return b.intValue(sourceParam, dest, 8, false) } // MustInt8 requires parameter value to exist to bind to int8 variable. Returns error when value does not exist func (b *ValueBinder) MustInt8(sourceParam string, dest *int8) *ValueBinder { return b.intValue(sourceParam, dest, 8, true) } // Int binds parameter to int variable func (b *ValueBinder) Int(sourceParam string, dest *int) *ValueBinder { return b.intValue(sourceParam, dest, 0, false) } // MustInt requires parameter value to exist to bind to int variable. Returns error when value does not exist func (b *ValueBinder) MustInt(sourceParam string, dest *int) *ValueBinder { return b.intValue(sourceParam, dest, 0, true) } func (b *ValueBinder) intValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } value := b.ValueFunc(sourceParam) if value == "" { if valueMustExist { b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) } return b } return b.int(sourceParam, value, dest, bitSize) } func (b *ValueBinder) int(sourceParam string, value string, dest any, bitSize int) *ValueBinder { n, err := strconv.ParseInt(value, 10, bitSize) if err != nil { if bitSize == 0 { b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to int", err)) } else { b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to int%v", bitSize), err)) } return b } switch d := dest.(type) { case *int64: *d = n case *int32: *d = int32(n) // #nosec G115 case *int16: *d = int16(n) // #nosec G115 case *int8: *d = int8(n) // #nosec G115 case *int: *d = int(n) } return b } func (b *ValueBinder) 
intsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } values := b.ValuesFunc(sourceParam) if len(values) == 0 { if valueMustExist { b.setError(b.ErrorFunc(sourceParam, values, "required field value is empty", nil)) } return b } return b.ints(sourceParam, values, dest) } func (b *ValueBinder) ints(sourceParam string, values []string, dest any) *ValueBinder { switch d := dest.(type) { case *[]int64: tmp := make([]int64, len(values)) for i, v := range values { b.int(sourceParam, v, &tmp[i], 64) if b.failFast && b.errors != nil { return b } } if b.errors == nil { *d = tmp } case *[]int32: tmp := make([]int32, len(values)) for i, v := range values { b.int(sourceParam, v, &tmp[i], 32) if b.failFast && b.errors != nil { return b } } if b.errors == nil { *d = tmp } case *[]int16: tmp := make([]int16, len(values)) for i, v := range values { b.int(sourceParam, v, &tmp[i], 16) if b.failFast && b.errors != nil { return b } } if b.errors == nil { *d = tmp } case *[]int8: tmp := make([]int8, len(values)) for i, v := range values { b.int(sourceParam, v, &tmp[i], 8) if b.failFast && b.errors != nil { return b } } if b.errors == nil { *d = tmp } case *[]int: tmp := make([]int, len(values)) for i, v := range values { b.int(sourceParam, v, &tmp[i], 0) if b.failFast && b.errors != nil { return b } } if b.errors == nil { *d = tmp } } return b } // Int64s binds parameter to slice of int64 func (b *ValueBinder) Int64s(sourceParam string, dest *[]int64) *ValueBinder { return b.intsValue(sourceParam, dest, false) } // MustInt64s requires parameter value to exist to bind to int64 slice variable. 
Returns error when value does not exist func (b *ValueBinder) MustInt64s(sourceParam string, dest *[]int64) *ValueBinder { return b.intsValue(sourceParam, dest, true) } // Int32s binds parameter to slice of int32 func (b *ValueBinder) Int32s(sourceParam string, dest *[]int32) *ValueBinder { return b.intsValue(sourceParam, dest, false) } // MustInt32s requires parameter value to exist to bind to int32 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt32s(sourceParam string, dest *[]int32) *ValueBinder { return b.intsValue(sourceParam, dest, true) } // Int16s binds parameter to slice of int16 func (b *ValueBinder) Int16s(sourceParam string, dest *[]int16) *ValueBinder { return b.intsValue(sourceParam, dest, false) } // MustInt16s requires parameter value to exist to bind to int16 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt16s(sourceParam string, dest *[]int16) *ValueBinder { return b.intsValue(sourceParam, dest, true) } // Int8s binds parameter to slice of int8 func (b *ValueBinder) Int8s(sourceParam string, dest *[]int8) *ValueBinder { return b.intsValue(sourceParam, dest, false) } // MustInt8s requires parameter value to exist to bind to int8 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt8s(sourceParam string, dest *[]int8) *ValueBinder { return b.intsValue(sourceParam, dest, true) } // Ints binds parameter to slice of int func (b *ValueBinder) Ints(sourceParam string, dest *[]int) *ValueBinder { return b.intsValue(sourceParam, dest, false) } // MustInts requires parameter value to exist to bind to int slice variable. 
Returns error when value does not exist func (b *ValueBinder) MustInts(sourceParam string, dest *[]int) *ValueBinder { return b.intsValue(sourceParam, dest, true) } // Uint64 binds parameter to uint64 variable func (b *ValueBinder) Uint64(sourceParam string, dest *uint64) *ValueBinder { return b.uintValue(sourceParam, dest, 64, false) } // MustUint64 requires parameter value to exist to bind to uint64 variable. Returns error when value does not exist func (b *ValueBinder) MustUint64(sourceParam string, dest *uint64) *ValueBinder { return b.uintValue(sourceParam, dest, 64, true) } // Uint32 binds parameter to uint32 variable func (b *ValueBinder) Uint32(sourceParam string, dest *uint32) *ValueBinder { return b.uintValue(sourceParam, dest, 32, false) } // MustUint32 requires parameter value to exist to bind to uint32 variable. Returns error when value does not exist func (b *ValueBinder) MustUint32(sourceParam string, dest *uint32) *ValueBinder { return b.uintValue(sourceParam, dest, 32, true) } // Uint16 binds parameter to uint16 variable func (b *ValueBinder) Uint16(sourceParam string, dest *uint16) *ValueBinder { return b.uintValue(sourceParam, dest, 16, false) } // MustUint16 requires parameter value to exist to bind to uint16 variable. Returns error when value does not exist func (b *ValueBinder) MustUint16(sourceParam string, dest *uint16) *ValueBinder { return b.uintValue(sourceParam, dest, 16, true) } // Uint8 binds parameter to uint8 variable func (b *ValueBinder) Uint8(sourceParam string, dest *uint8) *ValueBinder { return b.uintValue(sourceParam, dest, 8, false) } // MustUint8 requires parameter value to exist to bind to uint8 variable. 
Returns error when value does not exist func (b *ValueBinder) MustUint8(sourceParam string, dest *uint8) *ValueBinder { return b.uintValue(sourceParam, dest, 8, true) } // Byte binds parameter to byte variable func (b *ValueBinder) Byte(sourceParam string, dest *byte) *ValueBinder { return b.uintValue(sourceParam, dest, 8, false) } // MustByte requires parameter value to exist to bind to byte variable. Returns error when value does not exist func (b *ValueBinder) MustByte(sourceParam string, dest *byte) *ValueBinder { return b.uintValue(sourceParam, dest, 8, true) } // Uint binds parameter to uint variable func (b *ValueBinder) Uint(sourceParam string, dest *uint) *ValueBinder { return b.uintValue(sourceParam, dest, 0, false) } // MustUint requires parameter value to exist to bind to uint variable. Returns error when value does not exist func (b *ValueBinder) MustUint(sourceParam string, dest *uint) *ValueBinder { return b.uintValue(sourceParam, dest, 0, true) } func (b *ValueBinder) uintValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } value := b.ValueFunc(sourceParam) if value == "" { if valueMustExist { b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) } return b } return b.uint(sourceParam, value, dest, bitSize) } func (b *ValueBinder) uint(sourceParam string, value string, dest any, bitSize int) *ValueBinder { n, err := strconv.ParseUint(value, 10, bitSize) if err != nil { if bitSize == 0 { b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to uint", err)) } else { b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to uint%v", bitSize), err)) } return b } switch d := dest.(type) { case *uint64: *d = n case *uint32: *d = uint32(n) // #nosec G115 case *uint16: *d = uint16(n) // #nosec G115 case *uint8: // byte is alias to uint8 *d = uint8(n) // #nosec G115 case *uint: *d = 
uint(n) // #nosec G115 } return b } func (b *ValueBinder) uintsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } values := b.ValuesFunc(sourceParam) if len(values) == 0 { if valueMustExist { b.setError(b.ErrorFunc(sourceParam, values, "required field value is empty", nil)) } return b } return b.uints(sourceParam, values, dest) } func (b *ValueBinder) uints(sourceParam string, values []string, dest any) *ValueBinder { switch d := dest.(type) { case *[]uint64: tmp := make([]uint64, len(values)) for i, v := range values { b.uint(sourceParam, v, &tmp[i], 64) if b.failFast && b.errors != nil { return b } } if b.errors == nil { *d = tmp } case *[]uint32: tmp := make([]uint32, len(values)) for i, v := range values { b.uint(sourceParam, v, &tmp[i], 32) if b.failFast && b.errors != nil { return b } } if b.errors == nil { *d = tmp } case *[]uint16: tmp := make([]uint16, len(values)) for i, v := range values { b.uint(sourceParam, v, &tmp[i], 16) if b.failFast && b.errors != nil { return b } } if b.errors == nil { *d = tmp } case *[]uint8: // byte is alias to uint8 tmp := make([]uint8, len(values)) for i, v := range values { b.uint(sourceParam, v, &tmp[i], 8) if b.failFast && b.errors != nil { return b } } if b.errors == nil { *d = tmp } case *[]uint: tmp := make([]uint, len(values)) for i, v := range values { b.uint(sourceParam, v, &tmp[i], 0) if b.failFast && b.errors != nil { return b } } if b.errors == nil { *d = tmp } } return b } // Uint64s binds parameter to slice of uint64 func (b *ValueBinder) Uint64s(sourceParam string, dest *[]uint64) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } // MustUint64s requires parameter value to exist to bind to uint64 slice variable. 
Returns error when value does not exist func (b *ValueBinder) MustUint64s(sourceParam string, dest *[]uint64) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } // Uint32s binds parameter to slice of uint32 func (b *ValueBinder) Uint32s(sourceParam string, dest *[]uint32) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } // MustUint32s requires parameter value to exist to bind to uint32 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint32s(sourceParam string, dest *[]uint32) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } // Uint16s binds parameter to slice of uint16 func (b *ValueBinder) Uint16s(sourceParam string, dest *[]uint16) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } // MustUint16s requires parameter value to exist to bind to uint16 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint16s(sourceParam string, dest *[]uint16) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } // Uint8s binds parameter to slice of uint8 func (b *ValueBinder) Uint8s(sourceParam string, dest *[]uint8) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } // MustUint8s requires parameter value to exist to bind to uint8 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint8s(sourceParam string, dest *[]uint8) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } // Uints binds parameter to slice of uint func (b *ValueBinder) Uints(sourceParam string, dest *[]uint) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } // MustUints requires parameter value to exist to bind to uint slice variable. 
Returns error when value does not exist func (b *ValueBinder) MustUints(sourceParam string, dest *[]uint) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } // Bool binds parameter to bool variable func (b *ValueBinder) Bool(sourceParam string, dest *bool) *ValueBinder { return b.boolValue(sourceParam, dest, false) } // MustBool requires parameter value to exist to bind to bool variable. Returns error when value does not exist func (b *ValueBinder) MustBool(sourceParam string, dest *bool) *ValueBinder { return b.boolValue(sourceParam, dest, true) } func (b *ValueBinder) boolValue(sourceParam string, dest *bool, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } value := b.ValueFunc(sourceParam) if value == "" { if valueMustExist { b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) } return b } return b.bool(sourceParam, value, dest) } func (b *ValueBinder) bool(sourceParam string, value string, dest *bool) *ValueBinder { n, err := strconv.ParseBool(value) if err != nil { b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to bool", err)) return b } *dest = n return b } func (b *ValueBinder) boolsValue(sourceParam string, dest *[]bool, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } values := b.ValuesFunc(sourceParam) if len(values) == 0 { if valueMustExist { b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) } return b } return b.bools(sourceParam, values, dest) } func (b *ValueBinder) bools(sourceParam string, values []string, dest *[]bool) *ValueBinder { tmp := make([]bool, len(values)) for i, v := range values { b.bool(sourceParam, v, &tmp[i]) if b.failFast && b.errors != nil { return b } } if b.errors == nil { *dest = tmp } return b } // Bools binds parameter values to slice of bool variables func (b *ValueBinder) Bools(sourceParam string, dest *[]bool) *ValueBinder { return 
b.boolsValue(sourceParam, dest, false) } // MustBools requires parameter values to exist to bind to slice of bool variables. Returns error when values does not exist func (b *ValueBinder) MustBools(sourceParam string, dest *[]bool) *ValueBinder { return b.boolsValue(sourceParam, dest, true) } // Float64 binds parameter to float64 variable func (b *ValueBinder) Float64(sourceParam string, dest *float64) *ValueBinder { return b.floatValue(sourceParam, dest, 64, false) } // MustFloat64 requires parameter value to exist to bind to float64 variable. Returns error when value does not exist func (b *ValueBinder) MustFloat64(sourceParam string, dest *float64) *ValueBinder { return b.floatValue(sourceParam, dest, 64, true) } // Float32 binds parameter to float32 variable func (b *ValueBinder) Float32(sourceParam string, dest *float32) *ValueBinder { return b.floatValue(sourceParam, dest, 32, false) } // MustFloat32 requires parameter value to exist to bind to float32 variable. Returns error when value does not exist func (b *ValueBinder) MustFloat32(sourceParam string, dest *float32) *ValueBinder { return b.floatValue(sourceParam, dest, 32, true) } func (b *ValueBinder) floatValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } value := b.ValueFunc(sourceParam) if value == "" { if valueMustExist { b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) } return b } return b.float(sourceParam, value, dest, bitSize) } func (b *ValueBinder) float(sourceParam string, value string, dest any, bitSize int) *ValueBinder { n, err := strconv.ParseFloat(value, bitSize) if err != nil { b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to float%v", bitSize), err)) return b } switch d := dest.(type) { case *float64: *d = n case *float32: *d = float32(n) } return b } func (b *ValueBinder) floatsValue(sourceParam string, dest any, 
valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } values := b.ValuesFunc(sourceParam) if len(values) == 0 { if valueMustExist { b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) } return b } return b.floats(sourceParam, values, dest) } func (b *ValueBinder) floats(sourceParam string, values []string, dest any) *ValueBinder { switch d := dest.(type) { case *[]float64: tmp := make([]float64, len(values)) for i, v := range values { b.float(sourceParam, v, &tmp[i], 64) if b.failFast && b.errors != nil { return b } } if b.errors == nil { *d = tmp } case *[]float32: tmp := make([]float32, len(values)) for i, v := range values { b.float(sourceParam, v, &tmp[i], 32) if b.failFast && b.errors != nil { return b } } if b.errors == nil { *d = tmp } } return b } // Float64s binds parameter values to slice of float64 variables func (b *ValueBinder) Float64s(sourceParam string, dest *[]float64) *ValueBinder { return b.floatsValue(sourceParam, dest, false) } // MustFloat64s requires parameter values to exist to bind to slice of float64 variables. Returns error when values does not exist func (b *ValueBinder) MustFloat64s(sourceParam string, dest *[]float64) *ValueBinder { return b.floatsValue(sourceParam, dest, true) } // Float32s binds parameter values to slice of float32 variables func (b *ValueBinder) Float32s(sourceParam string, dest *[]float32) *ValueBinder { return b.floatsValue(sourceParam, dest, false) } // MustFloat32s requires parameter values to exist to bind to slice of float32 variables. 
Returns error when values does not exist func (b *ValueBinder) MustFloat32s(sourceParam string, dest *[]float32) *ValueBinder { return b.floatsValue(sourceParam, dest, true) } // Time binds parameter to time.Time variable func (b *ValueBinder) Time(sourceParam string, dest *time.Time, layout string) *ValueBinder { return b.time(sourceParam, dest, layout, false) } // MustTime requires parameter value to exist to bind to time.Time variable. Returns error when value does not exist func (b *ValueBinder) MustTime(sourceParam string, dest *time.Time, layout string) *ValueBinder { return b.time(sourceParam, dest, layout, true) } func (b *ValueBinder) time(sourceParam string, dest *time.Time, layout string, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } value := b.ValueFunc(sourceParam) if value == "" { if valueMustExist { b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil)) } return b } t, err := time.Parse(layout, value) if err != nil { b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Time", err)) return b } *dest = t return b } // Times binds parameter values to slice of time.Time variables func (b *ValueBinder) Times(sourceParam string, dest *[]time.Time, layout string) *ValueBinder { return b.times(sourceParam, dest, layout, false) } // MustTimes requires parameter values to exist to bind to slice of time.Time variables. 
Returns error when values does not exist func (b *ValueBinder) MustTimes(sourceParam string, dest *[]time.Time, layout string) *ValueBinder { return b.times(sourceParam, dest, layout, true) } func (b *ValueBinder) times(sourceParam string, dest *[]time.Time, layout string, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } values := b.ValuesFunc(sourceParam) if len(values) == 0 { if valueMustExist { b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) } return b } tmp := make([]time.Time, len(values)) for i, v := range values { t, err := time.Parse(layout, v) if err != nil { b.setError(b.ErrorFunc(sourceParam, []string{v}, "failed to bind field value to Time", err)) if b.failFast { return b } continue } tmp[i] = t } if b.errors == nil { *dest = tmp } return b } // Duration binds parameter to time.Duration variable func (b *ValueBinder) Duration(sourceParam string, dest *time.Duration) *ValueBinder { return b.duration(sourceParam, dest, false) } // MustDuration requires parameter value to exist to bind to time.Duration variable. 
Returns error when value does not exist func (b *ValueBinder) MustDuration(sourceParam string, dest *time.Duration) *ValueBinder { return b.duration(sourceParam, dest, true) } func (b *ValueBinder) duration(sourceParam string, dest *time.Duration, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } value := b.ValueFunc(sourceParam) if value == "" { if valueMustExist { b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil)) } return b } t, err := time.ParseDuration(value) if err != nil { b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Duration", err)) return b } *dest = t return b } // Durations binds parameter values to slice of time.Duration variables func (b *ValueBinder) Durations(sourceParam string, dest *[]time.Duration) *ValueBinder { return b.durationsValue(sourceParam, dest, false) } // MustDurations requires parameter values to exist to bind to slice of time.Duration variables. 
Returns error when values does not exist func (b *ValueBinder) MustDurations(sourceParam string, dest *[]time.Duration) *ValueBinder { return b.durationsValue(sourceParam, dest, true) } func (b *ValueBinder) durationsValue(sourceParam string, dest *[]time.Duration, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } values := b.ValuesFunc(sourceParam) if len(values) == 0 { if valueMustExist { b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) } return b } return b.durations(sourceParam, values, dest) } func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]time.Duration) *ValueBinder { tmp := make([]time.Duration, len(values)) for i, v := range values { t, err := time.ParseDuration(v) if err != nil { b.setError(b.ErrorFunc(sourceParam, []string{v}, "failed to bind field value to Duration", err)) if b.failFast { return b } continue } tmp[i] = t } if b.errors == nil { *dest = tmp } return b } // UnixTime binds parameter to time.Time variable (in local Time corresponding to the given Unix time). // // Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 // // Note: // - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, false, time.Second) } // MustUnixTime requires parameter value to exist to bind to time.Duration variable (in local time corresponding // to the given Unix time). Returns error when value does not exist. 
// // Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 // // Note: // - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, true, time.Second) } // UnixTimeMilli binds parameter to time.Time variable (in local time corresponding to the given Unix time in millisecond precision). // // Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 // // Note: // - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, false, time.Millisecond) } // MustUnixTimeMilli requires parameter value to exist to bind to time.Duration variable (in local time corresponding // to the given Unix time in millisecond precision). Returns error when value does not exist. // // Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 // // Note: // - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, true, time.Millisecond) } // UnixTimeNano binds parameter to time.Time variable (in local time corresponding to the given Unix time in nanosecond precision). // // Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 // Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00 // Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00 // // Note: // - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal // - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. 
func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, false, time.Nanosecond) } // MustUnixTimeNano requires parameter value to exist to bind to time.Duration variable (in local Time corresponding // to the given Unix time value in nano second precision). Returns error when value does not exist. // // Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 // Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00 // Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00 // // Note: // - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal // - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, true, time.Nanosecond) } func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, precision time.Duration) *ValueBinder { if b.failFast && b.errors != nil { return b } value := b.ValueFunc(sourceParam) if value == "" { if valueMustExist { b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil)) } return b } n, err := strconv.ParseInt(value, 10, 64) if err != nil { b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Time", err)) return b } switch precision { case time.Second: *dest = time.Unix(n, 0) case time.Millisecond: *dest = time.UnixMilli(n) case time.Nanosecond: *dest = time.Unix(0, n) } return b } ================================================ FILE: binder_external_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors // run tests as external package to get real feel for API package echo_test import ( 
"encoding/base64" "fmt" "log" "net/http" "net/http/httptest" "github.com/labstack/echo/v5" ) func ExampleValueBinder_BindErrors() { // example route function that binds query params to different destinations and returns all bind errors in one go routeFunc := func(c *echo.Context) error { var opts struct { IDs []int64 Active bool } length := int64(50) // default length is 50 b := echo.QueryParamsBinder(c) errs := b.Int64("length", &length). Int64s("ids", &opts.IDs). Bool("active", &opts.Active). BindErrors() // returns all errors if errs != nil { for _, err := range errs { bErr := err.(*echo.BindingError) log.Printf("in case you want to access what field: %s values: %v failed", bErr.Field, bErr.Values) } return fmt.Errorf("%v fields failed to bind", len(errs)) } fmt.Printf("active = %v, length = %v, ids = %v", opts.Active, length, opts.IDs) return c.JSON(http.StatusOK, opts) } e := echo.New() c := e.NewContext( httptest.NewRequest(http.MethodGet, "/api/endpoint?active=true&length=25&ids=1&ids=2&ids=3", nil), httptest.NewRecorder(), ) _ = routeFunc(c) // Output: active = true, length = 25, ids = [1 2 3] } func ExampleValueBinder_BindError() { // example route function that binds query params to different destinations and stops binding on first bind error failFastRouteFunc := func(c *echo.Context) error { var opts struct { IDs []int64 Active bool } length := int64(50) // default length is 50 // create binder that stops binding at first error b := echo.QueryParamsBinder(c) err := b.Int64("length", &length). Int64s("ids", &opts.IDs). Bool("active", &opts.Active). 
BindError() // returns first binding error if err != nil { bErr := err.(*echo.BindingError) return fmt.Errorf("my own custom error for field: %s values: %v", bErr.Field, bErr.Values) } fmt.Printf("active = %v, length = %v, ids = %v\n", opts.Active, length, opts.IDs) return c.JSON(http.StatusOK, opts) } e := echo.New() c := e.NewContext( httptest.NewRequest(http.MethodGet, "/api/endpoint?active=true&length=25&ids=1&ids=2&ids=3", nil), httptest.NewRecorder(), ) _ = failFastRouteFunc(c) // Output: active = true, length = 25, ids = [1 2 3] } func ExampleValueBinder_CustomFunc() { // example route function that binds query params using custom function closure routeFunc := func(c *echo.Context) error { length := int64(50) // default length is 50 var binary []byte b := echo.QueryParamsBinder(c) errs := b.Int64("length", &length). CustomFunc("base64", func(values []string) []error { if len(values) == 0 { return nil } decoded, err := base64.URLEncoding.DecodeString(values[0]) if err != nil { // in this example we use only first param value but url could contain multiple params in reality and // therefore in theory produce multiple binding errors return []error{echo.NewBindingError("base64", values[0:1], "failed to decode base64", err)} } binary = decoded return nil }). 
BindErrors() // returns all errors if errs != nil { for _, err := range errs { bErr := err.(*echo.BindingError) log.Printf("in case you want to access what field: %s values: %v failed", bErr.Field, bErr.Values) } return fmt.Errorf("%v fields failed to bind", len(errs)) } fmt.Printf("length = %v, base64 = %s", length, binary) return c.JSON(http.StatusOK, "ok") } e := echo.New() c := e.NewContext( httptest.NewRequest(http.MethodGet, "/api/endpoint?length=25&base64=SGVsbG8gV29ybGQ%3D", nil), httptest.NewRecorder(), ) _ = routeFunc(c) // Output: length = 25, base64 = Hello World } ================================================ FILE: binder_generic.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import ( "encoding" "encoding/json" "fmt" "strconv" "time" ) // TimeLayout specifies the format for parsing time values in request parameters. // It can be a standard Go time layout string or one of the special Unix time layouts. type TimeLayout string // TimeOpts is options for parsing time.Time values type TimeOpts struct { // Layout specifies the format for parsing time values in request parameters. // It can be a standard Go time layout string or one of the special Unix time layouts. // // Parsing layout defaults to: echo.TimeLayout(time.RFC3339Nano) // - To convert to custom layout use `echo.TimeLayout("2006-01-02")` // - To convert unix timestamp (integer) to time.Time use `echo.TimeLayoutUnixTime` // - To convert unix timestamp in milliseconds to time.Time use `echo.TimeLayoutUnixTimeMilli` // - To convert unix timestamp in nanoseconds to time.Time use `echo.TimeLayoutUnixTimeNano` Layout TimeLayout // ParseInLocation is location used with time.ParseInLocation for layout that do not contain // timezone information to set output time in given location. 
// Defaults to time.UTC ParseInLocation *time.Location // ToInLocation is location to which parsed time is converted to after parsing. // The parsed time will be converted using time.In(ToInLocation). // Defaults to time.UTC ToInLocation *time.Location } // TimeLayout constants for parsing Unix timestamps in different precisions. const ( TimeLayoutUnixTime = TimeLayout("UnixTime") // Unix timestamp in seconds TimeLayoutUnixTimeMilli = TimeLayout("UnixTimeMilli") // Unix timestamp in milliseconds TimeLayoutUnixTimeNano = TimeLayout("UnixTimeNano") // Unix timestamp in nanoseconds ) // PathParam extracts and parses a path parameter from the context by name. // It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found. // // Empty String Handling: // If the parameter exists but has an empty value, the zero value of type T is returned // with no error. For example, a path parameter with value "" returns (0, nil) for int types. // This differs from standard library behavior where parsing empty strings returns errors. // To treat empty values as errors, validate the result separately or check the raw value. // // See ParseValue for supported types and options func PathParam[T any](c *Context, paramName string, opts ...any) (T, error) { for _, pv := range c.PathValues() { if pv.Name == paramName { v, err := ParseValue[T](pv.Value, opts...) if err != nil { return v, NewBindingError(paramName, []string{pv.Value}, "path value", err) } return v, nil } } var zero T return zero, ErrNonExistentKey } // PathParamOr extracts and parses a path parameter from the context by name. // Returns defaultValue if the parameter is not found or has an empty value. // Returns an error only if parsing fails (e.g., "abc" for int type). 
// // Example: // id, err := echo.PathParamOr[int](c, "id", 0) // // If "id" is missing: returns (0, nil) // // If "id" is "123": returns (123, nil) // // If "id" is "abc": returns (0, BindingError) // // See ParseValue for supported types and options func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...any) (T, error) { for _, pv := range c.PathValues() { if pv.Name == paramName { v, err := ParseValueOr[T](pv.Value, defaultValue, opts...) if err != nil { return v, NewBindingError(paramName, []string{pv.Value}, "path value", err) } return v, nil } } return defaultValue, nil } // QueryParam extracts and parses a single query parameter from the request by key. // It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found. // // Empty String Handling: // If the parameter exists but has an empty value (?key=), the zero value of type T is returned // with no error. For example, "?count=" returns (0, nil) for int types. // This differs from standard library behavior where parsing empty strings returns errors. // To treat empty values as errors, validate the result separately or check the raw value. // // Behavior Summary: // - Missing key (?other=value): returns (zero, ErrNonExistentKey) // - Empty value (?key=): returns (zero, nil) // - Invalid value (?key=abc for int): returns (zero, BindingError) // // See ParseValue for supported types and options func QueryParam[T any](c *Context, key string, opts ...any) (T, error) { values, ok := c.QueryParams()[key] if !ok { var zero T return zero, ErrNonExistentKey } if len(values) == 0 { var zero T return zero, nil } value := values[0] v, err := ParseValue[T](value, opts...) if err != nil { return v, NewBindingError(key, []string{value}, "query param", err) } return v, nil } // QueryParamOr extracts and parses a single query parameter from the request by key. // Returns defaultValue if the parameter is not found or has an empty value. 
// Returns an error only if parsing fails (e.g., "abc" for int type). // // Example: // page, err := echo.QueryParamOr[int](c, "page", 1) // // If "page" is missing: returns (1, nil) // // If "page" is "5": returns (5, nil) // // If "page" is "abc": returns (1, BindingError) // // See ParseValue for supported types and options func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) { values, ok := c.QueryParams()[key] if !ok { return defaultValue, nil } if len(values) == 0 { return defaultValue, nil } value := values[0] v, err := ParseValueOr[T](value, defaultValue, opts...) if err != nil { return v, NewBindingError(key, []string{value}, "query param", err) } return v, nil } // QueryParams extracts and parses all values for a query parameter key as a slice. // It returns the typed slice and an error if binding any value fails. Returns ErrNonExistentKey if parameter not found. // // See ParseValues for supported types and options func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error) { values, ok := c.QueryParams()[key] if !ok { return nil, ErrNonExistentKey } result, err := ParseValues[T](values, opts...) if err != nil { return nil, NewBindingError(key, values, "query params", err) } return result, nil } // QueryParamsOr extracts and parses all values for a query parameter key as a slice. // Returns defaultValue if the parameter is not found. // Returns an error only if parsing any value fails. // // Example: // ids, err := echo.QueryParamsOr[int](c, "ids", []int{}) // // If "ids" is missing: returns ([], nil) // // If "ids" is "1&ids=2": returns ([1, 2], nil) // // If "ids" contains "abc": returns ([], BindingError) // // See ParseValues for supported types and options func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) { values, ok := c.QueryParams()[key] if !ok { return defaultValue, nil } result, err := ParseValuesOr[T](values, defaultValue, opts...) 
if err != nil { return nil, NewBindingError(key, values, "query params", err) } return result, nil } // FormValue extracts and parses a single form value from the request by key. // It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found. // // Empty String Handling: // If the form field exists but has an empty value, the zero value of type T is returned // with no error. For example, an empty form field returns (0, nil) for int types. // This differs from standard library behavior where parsing empty strings returns errors. // To treat empty values as errors, validate the result separately or check the raw value. // // See ParseValue for supported types and options func FormValue[T any](c *Context, key string, opts ...any) (T, error) { formValues, err := c.FormValues() if err != nil { var zero T return zero, fmt.Errorf("failed to parse form value, key: %s, err: %w", key, err) } values, ok := formValues[key] if !ok { var zero T return zero, ErrNonExistentKey } if len(values) == 0 { var zero T return zero, nil } value := values[0] v, err := ParseValue[T](value, opts...) if err != nil { return v, NewBindingError(key, []string{value}, "form value", err) } return v, nil } // FormValueOr extracts and parses a single form value from the request by key. // Returns defaultValue if the parameter is not found or has an empty value. // Returns an error only if parsing fails or form parsing errors occur. 
// // Example: // limit, err := echo.FormValueOr[int](c, "limit", 100) // // If "limit" is missing: returns (100, nil) // // If "limit" is "50": returns (50, nil) // // If "limit" is "abc": returns (100, BindingError) // // See ParseValue for supported types and options func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) { formValues, err := c.FormValues() if err != nil { var zero T return zero, fmt.Errorf("failed to parse form value, key: %s, err: %w", key, err) } values, ok := formValues[key] if !ok { return defaultValue, nil } if len(values) == 0 { return defaultValue, nil } value := values[0] v, err := ParseValueOr[T](value, defaultValue, opts...) if err != nil { return v, NewBindingError(key, []string{value}, "form value", err) } return v, nil } // FormValues extracts and parses all values for a form values key as a slice. // It returns the typed slice and an error if binding any value fails. Returns ErrNonExistentKey if parameter not found. // // See ParseValues for supported types and options func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) { formValues, err := c.FormValues() if err != nil { return nil, fmt.Errorf("failed to parse form values, key: %s, err: %w", key, err) } values, ok := formValues[key] if !ok { return nil, ErrNonExistentKey } result, err := ParseValues[T](values, opts...) if err != nil { return nil, NewBindingError(key, values, "form values", err) } return result, nil } // FormValuesOr extracts and parses all values for a form values key as a slice. // Returns defaultValue if the parameter is not found. // Returns an error only if parsing any value fails or form parsing errors occur. 
// // Example: // tags, err := echo.FormValuesOr[string](c, "tags", []string{}) // // If "tags" is missing: returns ([], nil) // // If form parsing fails: returns (nil, error) // // See ParseValues for supported types and options func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) { formValues, err := c.FormValues() if err != nil { return nil, fmt.Errorf("failed to parse form values, key: %s, err: %w", key, err) } values, ok := formValues[key] if !ok { return defaultValue, nil } result, err := ParseValuesOr[T](values, defaultValue, opts...) if err != nil { return nil, NewBindingError(key, values, "form values", err) } return result, nil } // ParseValues parses value to generic type slice. Same types are supported as ParseValue // function but the result type is slice instead of scalar value. // // See ParseValue for supported types and options func ParseValues[T any](values []string, opts ...any) ([]T, error) { var zero []T return ParseValuesOr(values, zero, opts...) } // ParseValuesOr parses value to generic type slice, when value is empty defaultValue is returned. // Same types are supported as ParseValue function but the result type is slice instead of scalar value. // // See ParseValue for supported types and options func ParseValuesOr[T any](values []string, defaultValue []T, opts ...any) ([]T, error) { if len(values) == 0 { return defaultValue, nil } result := make([]T, 0, len(values)) for _, v := range values { tmp, err := ParseValue[T](v, opts...) 
if err != nil { return nil, err } result = append(result, tmp) } return result, nil } // ParseValue parses value to generic type // // Types that are supported: // - bool // - float32 // - float64 // - int // - int8 // - int16 // - int32 // - int64 // - uint // - uint8/byte // - uint16 // - uint32 // - uint64 // - string // - echo.BindUnmarshaler interface // - encoding.TextUnmarshaler interface // - json.Unmarshaler interface // - time.Duration // - time.Time use echo.TimeOpts or echo.TimeLayout to set time parsing configuration func ParseValue[T any](value string, opts ...any) (T, error) { var zero T return ParseValueOr(value, zero, opts...) } // ParseValueOr parses value to generic type, when value is empty defaultValue is returned. // // Types that are supported: // - bool // - float32 // - float64 // - int // - int8 // - int16 // - int32 // - int64 // - uint // - uint8/byte // - uint16 // - uint32 // - uint64 // - string // - echo.BindUnmarshaler interface // - encoding.TextUnmarshaler interface // - json.Unmarshaler interface // - time.Duration // - time.Time use echo.TimeOpts or echo.TimeLayout to set time parsing configuration func ParseValueOr[T any](value string, defaultValue T, opts ...any) (T, error) { if len(value) == 0 { return defaultValue, nil } var tmp T if err := bindValue(value, &tmp, opts...); err != nil { var zero T return zero, fmt.Errorf("failed to parse value, err: %w", err) } return tmp, nil } func bindValue(value string, dest any, opts ...any) error { // NOTE: if this function is ever made public the dest should be checked for nil // values when dealing with interfaces if len(opts) > 0 { if _, isTime := dest.(*time.Time); !isTime { return fmt.Errorf("options are only supported for time.Time, got %T", dest) } } switch d := dest.(type) { case *bool: n, err := strconv.ParseBool(value) if err != nil { return err } *d = n case *float32: n, err := strconv.ParseFloat(value, 32) if err != nil { return err } *d = float32(n) case *float64: n, err := 
strconv.ParseFloat(value, 64) if err != nil { return err } *d = n case *int: n, err := strconv.ParseInt(value, 10, 0) if err != nil { return err } *d = int(n) case *int8: n, err := strconv.ParseInt(value, 10, 8) if err != nil { return err } *d = int8(n) case *int16: n, err := strconv.ParseInt(value, 10, 16) if err != nil { return err } *d = int16(n) case *int32: n, err := strconv.ParseInt(value, 10, 32) if err != nil { return err } *d = int32(n) case *int64: n, err := strconv.ParseInt(value, 10, 64) if err != nil { return err } *d = n case *uint: n, err := strconv.ParseUint(value, 10, 0) if err != nil { return err } *d = uint(n) case *uint8: n, err := strconv.ParseUint(value, 10, 8) if err != nil { return err } *d = uint8(n) case *uint16: n, err := strconv.ParseUint(value, 10, 16) if err != nil { return err } *d = uint16(n) case *uint32: n, err := strconv.ParseUint(value, 10, 32) if err != nil { return err } *d = uint32(n) case *uint64: n, err := strconv.ParseUint(value, 10, 64) if err != nil { return err } *d = n case *string: *d = value case *time.Duration: t, err := time.ParseDuration(value) if err != nil { return err } *d = t case *time.Time: to := TimeOpts{ Layout: TimeLayout(time.RFC3339Nano), ParseInLocation: time.UTC, ToInLocation: time.UTC, } for _, o := range opts { switch v := o.(type) { case TimeOpts: if v.Layout != "" { to.Layout = v.Layout } if v.ParseInLocation != nil { to.ParseInLocation = v.ParseInLocation } if v.ToInLocation != nil { to.ToInLocation = v.ToInLocation } case TimeLayout: to.Layout = v } } var t time.Time var err error switch to.Layout { case TimeLayoutUnixTime: n, err := strconv.ParseInt(value, 10, 64) if err != nil { return err } t = time.Unix(n, 0) case TimeLayoutUnixTimeMilli: n, err := strconv.ParseInt(value, 10, 64) if err != nil { return err } t = time.UnixMilli(n) case TimeLayoutUnixTimeNano: n, err := strconv.ParseInt(value, 10, 64) if err != nil { return err } t = time.Unix(0, n) default: if to.ParseInLocation != nil { t, 
err = time.ParseInLocation(string(to.Layout), value, to.ParseInLocation) } else { t, err = time.Parse(string(to.Layout), value) } if err != nil { return err } } *d = t.In(to.ToInLocation) case BindUnmarshaler: if err := d.UnmarshalParam(value); err != nil { return err } case encoding.TextUnmarshaler: if err := d.UnmarshalText([]byte(value)); err != nil { return err } case json.Unmarshaler: if err := d.UnmarshalJSON([]byte(value)); err != nil { return err } default: return fmt.Errorf("unsupported value type: %T", dest) } return nil } ================================================ FILE: binder_generic_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import ( "cmp" "encoding/json" "fmt" "math" "net/http" "net/http/httptest" "strings" "testing" "time" "github.com/stretchr/testify/assert" ) // TextUnmarshalerType implements encoding.TextUnmarshaler but NOT BindUnmarshaler type TextUnmarshalerType struct { Value string } func (t *TextUnmarshalerType) UnmarshalText(data []byte) error { s := string(data) if s == "invalid" { return fmt.Errorf("invalid value: %s", s) } t.Value = strings.ToUpper(s) return nil } // JSONUnmarshalerType implements json.Unmarshaler but NOT BindUnmarshaler or TextUnmarshaler type JSONUnmarshalerType struct { Value string } func (j *JSONUnmarshalerType) UnmarshalJSON(data []byte) error { return json.Unmarshal(data, &j.Value) } func TestPathParam(t *testing.T) { var testCases = []struct { name string givenKey string givenValue string expect bool expectErr string }{ { name: "ok", givenValue: "true", expect: true, }, { name: "nok, non existent key", givenKey: "missing", givenValue: "true", expect: false, expectErr: ErrNonExistentKey.Error(), }, { name: "nok, invalid value", givenValue: "can_parse_me", expect: false, expectErr: `code=400, message=path value, err=failed to parse value, err: strconv.ParseBool: parsing 
"can_parse_me": invalid syntax, field=key`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := NewContext(nil, nil) c.SetPathValues(PathValues{{ Name: cmp.Or(tc.givenKey, "key"), Value: tc.givenValue, }}) v, err := PathParam[bool](c, "key") if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestPathParam_UnsupportedType(t *testing.T) { c := NewContext(nil, nil) c.SetPathValues(PathValues{{Name: "key", Value: "true"}}) v, err := PathParam[[]bool](c, "key") expectErr := "code=400, message=path value, err=failed to parse value, err: unsupported value type: *[]bool, field=key" assert.EqualError(t, err, expectErr) assert.Equal(t, []bool(nil), v) } func TestQueryParam(t *testing.T) { var testCases = []struct { name string givenURL string expect bool expectErr string }{ { name: "ok", givenURL: "/?key=true", expect: true, }, { name: "nok, non existent key", givenURL: "/?different=true", expect: false, expectErr: ErrNonExistentKey.Error(), }, { name: "nok, invalid value", givenURL: "/?key=invalidbool", expect: false, expectErr: `code=400, message=query param, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) c := NewContext(req, nil) v, err := QueryParam[bool](c, "key") if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestQueryParam_UnsupportedType(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil) c := NewContext(req, nil) v, err := QueryParam[[]bool](c, "key") expectErr := "code=400, message=query param, err=failed to parse value, err: unsupported value type: *[]bool, field=key" assert.EqualError(t, err, expectErr) assert.Equal(t, []bool(nil), v) } func 
TestQueryParams(t *testing.T) { var testCases = []struct { name string givenURL string expect []bool expectErr string }{ { name: "ok", givenURL: "/?key=true&key=false", expect: []bool{true, false}, }, { name: "nok, non existent key", givenURL: "/?different=true", expect: []bool(nil), expectErr: ErrNonExistentKey.Error(), }, { name: "nok, invalid value", givenURL: "/?key=true&key=invalidbool", expect: []bool(nil), expectErr: `code=400, message=query params, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) c := NewContext(req, nil) v, err := QueryParams[bool](c, "key") if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestQueryParams_UnsupportedType(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil) c := NewContext(req, nil) v, err := QueryParams[[]bool](c, "key") expectErr := "code=400, message=query params, err=failed to parse value, err: unsupported value type: *[]bool, field=key" assert.EqualError(t, err, expectErr) assert.Equal(t, [][]bool(nil), v) } func TestFormValue(t *testing.T) { var testCases = []struct { name string givenURL string expect bool expectErr string }{ { name: "ok", givenURL: "/?key=true", expect: true, }, { name: "nok, non existent key", givenURL: "/?different=true", expect: false, expectErr: ErrNonExistentKey.Error(), }, { name: "nok, invalid value", givenURL: "/?key=invalidbool", expect: false, expectErr: `code=400, message=form value, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) c := NewContext(req, nil) v, err := FormValue[bool](c, "key") if 
tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestFormValue_UnsupportedType(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil) c := NewContext(req, nil) v, err := FormValue[[]bool](c, "key") expectErr := "code=400, message=form value, err=failed to parse value, err: unsupported value type: *[]bool, field=key" assert.EqualError(t, err, expectErr) assert.Equal(t, []bool(nil), v) } func TestFormValues(t *testing.T) { var testCases = []struct { name string givenURL string expect []bool expectErr string }{ { name: "ok", givenURL: "/?key=true&key=false", expect: []bool{true, false}, }, { name: "nok, non existent key", givenURL: "/?different=true", expect: []bool(nil), expectErr: ErrNonExistentKey.Error(), }, { name: "nok, invalid value", givenURL: "/?key=true&key=invalidbool", expect: []bool(nil), expectErr: `code=400, message=form values, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) c := NewContext(req, nil) v, err := FormValues[bool](c, "key") if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestFormValues_UnsupportedType(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil) c := NewContext(req, nil) v, err := FormValues[[]bool](c, "key") expectErr := "code=400, message=form values, err=failed to parse value, err: unsupported value type: *[]bool, field=key" assert.EqualError(t, err, expectErr) assert.Equal(t, [][]bool(nil), v) } func TestParseValue_bool(t *testing.T) { var testCases = []struct { name string when string expect bool expectErr error }{ { name: "ok, true", when: "true", expect: true, }, { name: "ok, false", when: "false", expect: false, 
}, { name: "ok, 1", when: "1", expect: true, }, { name: "ok, 0", when: "0", expect: false, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[bool](tc.when) if tc.expectErr != nil { assert.ErrorIs(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValue_float32(t *testing.T) { var testCases = []struct { name string when string expect float32 expectErr string }{ { name: "ok, 123.345", when: "123.345", expect: 123.345, }, { name: "ok, 0", when: "0", expect: 0, }, { name: "ok, Inf", when: "+Inf", expect: float32(math.Inf(1)), }, { name: "ok, Inf", when: "-Inf", expect: float32(math.Inf(-1)), }, { name: "ok, NaN", when: "NaN", expect: float32(math.NaN()), }, { name: "ok, invalid value", when: "X", expect: 0, expectErr: `failed to parse value, err: strconv.ParseFloat: parsing "X": invalid syntax`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[float32](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } if math.IsNaN(float64(tc.expect)) { if !math.IsNaN(float64(v)) { t.Fatal("expected NaN but got non NaN") } } else { assert.Equal(t, tc.expect, v) } }) } } func TestParseValue_float64(t *testing.T) { var testCases = []struct { name string when string expect float64 expectErr string }{ { name: "ok, 123.345", when: "123.345", expect: 123.345, }, { name: "ok, 0", when: "0", expect: 0, }, { name: "ok, Inf", when: "+Inf", expect: math.Inf(1), }, { name: "ok, Inf", when: "-Inf", expect: math.Inf(-1), }, { name: "ok, NaN", when: "NaN", expect: math.NaN(), }, { name: "ok, invalid value", when: "X", expect: 0, expectErr: `failed to parse value, err: strconv.ParseFloat: parsing "X": invalid syntax`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[float64](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { 
assert.NoError(t, err) } if math.IsNaN(tc.expect) { if !math.IsNaN(v) { t.Fatal("expected NaN but got non NaN") } } else { assert.Equal(t, tc.expect, v) } }) } } func TestParseValue_int(t *testing.T) { var testCases = []struct { name string when string expect int expectErr string }{ { name: "ok, 0", when: "0", expect: 0, }, { name: "ok, 1", when: "1", expect: 1, }, { name: "ok, -1", when: "-1", expect: -1, }, { name: "ok, max int (64bit)", when: "9223372036854775807", expect: 9223372036854775807, }, { name: "ok, min int (64bit)", when: "-9223372036854775808", expect: -9223372036854775808, }, { name: "ok, overflow max int (64bit)", when: "9223372036854775808", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "9223372036854775808": value out of range`, }, { name: "ok, underflow min int (64bit)", when: "-9223372036854775809", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-9223372036854775809": value out of range`, }, { name: "ok, invalid value", when: "X", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[int](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValue_uint(t *testing.T) { var testCases = []struct { name string when string expect uint expectErr string }{ { name: "ok, 0", when: "0", expect: 0, }, { name: "ok, 1", when: "1", expect: 1, }, { name: "ok, max uint (64bit)", when: "18446744073709551615", expect: 18446744073709551615, }, { name: "nok, overflow max uint (64bit)", when: "18446744073709551616", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "18446744073709551616": value out of range`, }, { name: "nok, negative value", when: "-1", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": 
invalid syntax`, }, { name: "nok, invalid value", when: "X", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[uint](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValue_int8(t *testing.T) { var testCases = []struct { name string when string expect int8 expectErr string }{ { name: "ok, 0", when: "0", expect: 0, }, { name: "ok, 1", when: "1", expect: 1, }, { name: "ok, -1", when: "-1", expect: -1, }, { name: "ok, max int8", when: "127", expect: 127, }, { name: "ok, min int8", when: "-128", expect: -128, }, { name: "nok, overflow max int8", when: "128", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "128": value out of range`, }, { name: "nok, underflow min int8", when: "-129", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-129": value out of range`, }, { name: "nok, invalid value", when: "X", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[int8](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValue_int16(t *testing.T) { var testCases = []struct { name string when string expect int16 expectErr string }{ { name: "ok, 0", when: "0", expect: 0, }, { name: "ok, 1", when: "1", expect: 1, }, { name: "ok, -1", when: "-1", expect: -1, }, { name: "ok, max int16", when: "32767", expect: 32767, }, { name: "ok, min int16", when: "-32768", expect: -32768, }, { name: "nok, overflow max int16", when: "32768", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "32768": value out of range`, }, { 
name: "nok, underflow min int16", when: "-32769", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-32769": value out of range`, }, { name: "nok, invalid value", when: "X", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[int16](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValue_int32(t *testing.T) { var testCases = []struct { name string when string expect int32 expectErr string }{ { name: "ok, 0", when: "0", expect: 0, }, { name: "ok, 1", when: "1", expect: 1, }, { name: "ok, -1", when: "-1", expect: -1, }, { name: "ok, max int32", when: "2147483647", expect: 2147483647, }, { name: "ok, min int32", when: "-2147483648", expect: -2147483648, }, { name: "nok, overflow max int32", when: "2147483648", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "2147483648": value out of range`, }, { name: "nok, underflow min int32", when: "-2147483649", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-2147483649": value out of range`, }, { name: "nok, invalid value", when: "X", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[int32](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValue_int64(t *testing.T) { var testCases = []struct { name string when string expect int64 expectErr string }{ { name: "ok, 0", when: "0", expect: 0, }, { name: "ok, 1", when: "1", expect: 1, }, { name: "ok, -1", when: "-1", expect: -1, }, { name: "ok, max int64", when: "9223372036854775807", expect: 
9223372036854775807, }, { name: "ok, min int64", when: "-9223372036854775808", expect: -9223372036854775808, }, { name: "nok, overflow max int64", when: "9223372036854775808", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "9223372036854775808": value out of range`, }, { name: "nok, underflow min int64", when: "-9223372036854775809", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-9223372036854775809": value out of range`, }, { name: "nok, invalid value", when: "X", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[int64](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValue_uint8(t *testing.T) { var testCases = []struct { name string when string expect uint8 expectErr string }{ { name: "ok, 0", when: "0", expect: 0, }, { name: "ok, 1", when: "1", expect: 1, }, { name: "ok, max uint8", when: "255", expect: 255, }, { name: "nok, overflow max uint8", when: "256", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "256": value out of range`, }, { name: "nok, negative value", when: "-1", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`, }, { name: "nok, invalid value", when: "X", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[uint8](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValue_uint16(t *testing.T) { var testCases = []struct { name string when string expect uint16 expectErr string }{ { name: "ok, 0", when: "0", expect: 0, }, { 
name: "ok, 1", when: "1", expect: 1, }, { name: "ok, max uint16", when: "65535", expect: 65535, }, { name: "nok, overflow max uint16", when: "65536", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "65536": value out of range`, }, { name: "nok, negative value", when: "-1", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`, }, { name: "nok, invalid value", when: "X", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[uint16](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValue_uint32(t *testing.T) { var testCases = []struct { name string when string expect uint32 expectErr string }{ { name: "ok, 0", when: "0", expect: 0, }, { name: "ok, 1", when: "1", expect: 1, }, { name: "ok, max uint32", when: "4294967295", expect: 4294967295, }, { name: "nok, overflow max uint32", when: "4294967296", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "4294967296": value out of range`, }, { name: "nok, negative value", when: "-1", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`, }, { name: "nok, invalid value", when: "X", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[uint32](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValue_uint64(t *testing.T) { var testCases = []struct { name string when string expect uint64 expectErr string }{ { name: "ok, 0", when: "0", expect: 0, }, { name: "ok, 1", when: "1", expect: 1, }, { name: 
"ok, max uint64", when: "18446744073709551615", expect: 18446744073709551615, }, { name: "nok, overflow max uint64", when: "18446744073709551616", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "18446744073709551616": value out of range`, }, { name: "nok, negative value", when: "-1", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`, }, { name: "nok, invalid value", when: "X", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[uint64](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValue_string(t *testing.T) { var testCases = []struct { name string when string expect string expectErr string }{ { name: "ok, my", when: "my", expect: "my", }, { name: "ok, empty", when: "", expect: "", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[string](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValue_Duration(t *testing.T) { var testCases = []struct { name string when string expect time.Duration expectErr string }{ { name: "ok, 10h11m01s", when: "10h11m01s", expect: 10*time.Hour + 11*time.Minute + 1*time.Second, }, { name: "ok, empty", when: "", expect: 0, }, { name: "ok, invalid", when: "0x0", expect: 0, expectErr: `failed to parse value, err: time: unknown unit "x" in duration "0x0"`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[time.Duration](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValue_Time(t *testing.T) { tallinn, err := 
time.LoadLocation("Europe/Tallinn") if err != nil { t.Fatal(err) } berlin, err := time.LoadLocation("Europe/Berlin") if err != nil { t.Fatal(err) } parse := func(t *testing.T, layout string, s string) time.Time { result, err := time.Parse(layout, s) if err != nil { t.Fatal(err) } return result } parseInLoc := func(t *testing.T, layout string, s string, loc *time.Location) time.Time { result, err := time.ParseInLocation(layout, s, loc) if err != nil { t.Fatal(err) } return result } var testCases = []struct { name string when string whenLayout TimeLayout whenTimeOpts *TimeOpts expect time.Time expectErr string }{ { name: "ok, defaults to RFC3339Nano", when: "2006-01-02T15:04:05.999999999Z", expect: parse(t, time.RFC3339Nano, "2006-01-02T15:04:05.999999999Z"), }, { name: "ok, custom TimeOpt", when: "2006-01-02", whenTimeOpts: &TimeOpts{ Layout: time.DateOnly, ParseInLocation: tallinn, ToInLocation: berlin, }, expect: parseInLoc(t, time.DateTime, "2006-01-01 23:00:00", berlin), }, { name: "ok, custom layout", when: "2006-01-02", whenLayout: TimeLayout(time.DateOnly), expect: parse(t, time.DateOnly, "2006-01-02"), }, { name: "ok, TimeLayoutUnixTime", when: "1766604665", whenLayout: TimeLayoutUnixTime, expect: parse(t, time.RFC3339Nano, "2025-12-24T19:31:05Z"), }, { name: "nok, TimeLayoutUnixTime, invalid value", when: "176x6604665", whenLayout: TimeLayoutUnixTime, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "176x6604665": invalid syntax`, }, { name: "ok, TimeLayoutUnixTimeMilli", when: "1766604665123", whenLayout: TimeLayoutUnixTimeMilli, expect: parse(t, time.RFC3339Nano, "2025-12-24T19:31:05.123Z"), }, { name: "nok, TimeLayoutUnixTimeMilli, invalid value", when: "1x766604665123", whenLayout: TimeLayoutUnixTimeMilli, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "1x766604665123": invalid syntax`, }, { name: "ok, TimeLayoutUnixTimeMilli", when: "1766604665999999999", whenLayout: TimeLayoutUnixTimeNano, expect: parse(t, 
time.RFC3339Nano, "2025-12-24T19:31:05.999999999Z"), }, { name: "nok, TimeLayoutUnixTimeMilli, invalid value", when: "1x766604665999999999", whenLayout: TimeLayoutUnixTimeNano, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "1x766604665999999999": invalid syntax`, }, { name: "ok, invalid", when: "xx", expect: time.Time{}, expectErr: `failed to parse value, err: parsing time "xx" as "2006-01-02T15:04:05.999999999Z07:00": cannot parse "xx" as "2006"`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var opts []any if tc.whenLayout != "" { opts = append(opts, tc.whenLayout) } if tc.whenTimeOpts != nil { opts = append(opts, *tc.whenTimeOpts) } v, err := ParseValue[time.Time](tc.when, opts...) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValue_OptionsOnlyForTime(t *testing.T) { _, err := ParseValue[int]("test", TimeLayoutUnixTime) assert.EqualError(t, err, `failed to parse value, err: options are only supported for time.Time, got *int`) } func TestParseValue_BindUnmarshaler(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00") var testCases = []struct { name string when string expect Timestamp expectErr string }{ { name: "ok", when: "2020-12-23T09:45:31+02:00", expect: Timestamp(exampleTime), }, { name: "nok, invalid value", when: "2020-12-23T09:45:3102:00", expect: Timestamp{}, expectErr: `failed to parse value, err: parsing time "2020-12-23T09:45:3102:00" as "2006-01-02T15:04:05Z07:00": cannot parse "02:00" as "Z07:00"`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[Timestamp](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValue_TextUnmarshaler(t *testing.T) { var testCases = []struct { name string when string expect TextUnmarshalerType 
expectErr string }{ { name: "ok, converts to uppercase", when: "hello", expect: TextUnmarshalerType{Value: "HELLO"}, }, { name: "ok, empty string", when: "", expect: TextUnmarshalerType{Value: ""}, }, { name: "nok, invalid value", when: "invalid", expect: TextUnmarshalerType{}, expectErr: "failed to parse value, err: invalid value: invalid", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[TextUnmarshalerType](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValue_JSONUnmarshaler(t *testing.T) { var testCases = []struct { name string when string expect JSONUnmarshalerType expectErr string }{ { name: "ok, valid JSON string", when: `"hello"`, expect: JSONUnmarshalerType{Value: "hello"}, }, { name: "ok, empty JSON string", when: `""`, expect: JSONUnmarshalerType{Value: ""}, }, { name: "nok, invalid JSON", when: "not-json", expect: JSONUnmarshalerType{}, expectErr: "failed to parse value, err: invalid character 'o' in literal null (expecting 'u')", }, { name: "nok, unquoted string", when: "hello", expect: JSONUnmarshalerType{}, expectErr: "failed to parse value, err: invalid character 'h' looking for beginning of value", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { v, err := ParseValue[JSONUnmarshalerType](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestParseValues_bools(t *testing.T) { var testCases = []struct { name string when []string expect []bool expectErr string }{ { name: "ok", when: []string{"true", "0", "false", "1"}, expect: []bool{true, false, false, true}, }, { name: "nok", when: []string{"true", "10"}, expect: nil, expectErr: `failed to parse value, err: strconv.ParseBool: parsing "10": invalid syntax`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { 
v, err := ParseValues[bool](tc.when) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestPathParamOr(t *testing.T) { var testCases = []struct { name string givenKey string givenValue string defaultValue int expect int expectErr string }{ { name: "ok, param exists", givenKey: "id", givenValue: "123", defaultValue: 999, expect: 123, }, { name: "ok, param missing - returns default", givenKey: "other", givenValue: "123", defaultValue: 999, expect: 999, }, { name: "ok, param exists but empty - returns default", givenKey: "id", givenValue: "", defaultValue: 999, expect: 999, }, { name: "nok, invalid value", givenKey: "id", givenValue: "invalid", defaultValue: 999, expectErr: "code=400, message=path value, err=failed to parse value", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := NewContext(nil, nil) c.SetPathValues(PathValues{{Name: tc.givenKey, Value: tc.givenValue}}) v, err := PathParamOr[int](c, "id", tc.defaultValue) if tc.expectErr != "" { assert.ErrorContains(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestQueryParamOr(t *testing.T) { var testCases = []struct { name string givenURL string defaultValue int expect int expectErr string }{ { name: "ok, param exists", givenURL: "/?key=42", defaultValue: 999, expect: 42, }, { name: "ok, param missing - returns default", givenURL: "/?other=42", defaultValue: 999, expect: 999, }, { name: "ok, param exists but empty - returns default", givenURL: "/?key=", defaultValue: 999, expect: 999, }, { name: "nok, invalid value", givenURL: "/?key=invalid", defaultValue: 999, expectErr: "code=400, message=query param", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) c := NewContext(req, nil) v, err := QueryParamOr[int](c, "key", tc.defaultValue) if tc.expectErr != "" { 
assert.ErrorContains(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestQueryParamsOr(t *testing.T) { var testCases = []struct { name string givenURL string defaultValue []int expect []int expectErr string }{ { name: "ok, params exist", givenURL: "/?key=1&key=2&key=3", defaultValue: []int{999}, expect: []int{1, 2, 3}, }, { name: "ok, params missing - returns default", givenURL: "/?other=1", defaultValue: []int{7, 8, 9}, expect: []int{7, 8, 9}, }, { name: "nok, invalid value", givenURL: "/?key=1&key=invalid", defaultValue: []int{999}, expectErr: "code=400, message=query params", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) c := NewContext(req, nil) v, err := QueryParamsOr[int](c, "key", tc.defaultValue) if tc.expectErr != "" { assert.ErrorContains(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestFormValueOr(t *testing.T) { var testCases = []struct { name string givenURL string defaultValue string expect string expectErr string }{ { name: "ok, value exists", givenURL: "/?name=john", defaultValue: "default", expect: "john", }, { name: "ok, value missing - returns default", givenURL: "/?other=john", defaultValue: "default", expect: "default", }, { name: "ok, value exists but empty - returns default", givenURL: "/?name=", defaultValue: "default", expect: "default", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) c := NewContext(req, nil) v, err := FormValueOr[string](c, "name", tc.defaultValue) if tc.expectErr != "" { assert.ErrorContains(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } func TestFormValuesOr(t *testing.T) { var testCases = []struct { name string givenURL string defaultValue []string expect []string expectErr string }{ { name: "ok, values 
exist", givenURL: "/?tags=go&tags=rust&tags=python", defaultValue: []string{"default"}, expect: []string{"go", "rust", "python"}, }, { name: "ok, values missing - returns default", givenURL: "/?other=value", defaultValue: []string{"a", "b"}, expect: []string{"a", "b"}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) c := NewContext(req, nil) v, err := FormValuesOr[string](c, "tags", tc.defaultValue) if tc.expectErr != "" { assert.ErrorContains(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expect, v) }) } } ================================================ FILE: binder_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import ( "encoding/json" "errors" "fmt" "github.com/stretchr/testify/assert" "io" "math/big" "net/http" "net/http/httptest" "strconv" "strings" "testing" "time" ) func createTestContext(URL string, body io.Reader, pathValues map[string]string) *Context { e := New() req := httptest.NewRequest(http.MethodGet, URL, body) if body != nil { req.Header.Set(HeaderContentType, MIMEApplicationJSON) } rec := httptest.NewRecorder() c := e.NewContext(req, rec) if len(pathValues) > 0 { params := make(PathValues, 0) for name, value := range pathValues { params = append(params, PathValue{ Name: name, Value: value, }) } c.SetPathValues(params) } return c } func TestBindingError_Error(t *testing.T) { err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error")) assert.EqualError(t, err, `code=400, message=bind failed, err=internal error, field=id`) bErr := err.(*BindingError) assert.Equal(t, 400, bErr.Code) assert.Equal(t, "bind failed", bErr.Message) assert.Equal(t, errors.New("internal error"), bErr.err) assert.Equal(t, "id", bErr.Field) assert.Equal(t, []string{"1", "nope"}, bErr.Values) } func 
TestBindingError_ErrorJSON(t *testing.T) { err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error")) resp, _ := json.Marshal(err) assert.Equal(t, `{"field":"id","message":"bind failed"}`, string(resp)) } func TestPathValuesBinder(t *testing.T) { c := createTestContext("/api/user/999", nil, map[string]string{ "id": "1", "nr": "2", "slice": "3", }) b := PathValuesBinder(c) id := int64(99) nr := int64(88) var slice = make([]int64, 0) var notExisting = make([]int64, 0) err := b.Int64("id", &id). Int64("nr", &nr). Int64s("slice", &slice). Int64s("not_existing", ¬Existing). BindError() assert.NoError(t, err) assert.Equal(t, int64(1), id) assert.Equal(t, int64(2), nr) assert.Equal(t, []int64{3}, slice) // binding params to slice does not make sense but it should not panic either assert.Equal(t, []int64{}, notExisting) // binding params to slice does not make sense but it should not panic either } func TestQueryParamsBinder_FailFast(t *testing.T) { var testCases = []struct { name string whenURL string expectError []string givenFailFast bool }{ { name: "ok, FailFast=true stops at first error", whenURL: "/api/user/999?nr=en&id=nope", givenFailFast: true, expectError: []string{ `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "nope": invalid syntax, field=id`, }, }, { name: "ok, FailFast=false encounters all errors", whenURL: "/api/user/999?nr=en&id=nope", givenFailFast: false, expectError: []string{ `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "nope": invalid syntax, field=id`, `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "en": invalid syntax, field=nr`, }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, map[string]string{"id": "999"}) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) id := int64(99) nr := int64(88) errs := b.Int64("id", &id). 
Int64("nr", &nr). BindErrors() assert.Len(t, errs, len(tc.expectError)) for _, err := range errs { assert.Contains(t, tc.expectError, err.Error()) } }) } } func TestFormFieldBinder(t *testing.T) { e := New() body := `texta=foo&slice=5` req := httptest.NewRequest(http.MethodPost, "/api/search?id=1&nr=2&slice=3&slice=4", strings.NewReader(body)) req.Header.Set(HeaderContentLength, strconv.Itoa(len(body))) req.Header.Set(HeaderContentType, MIMEApplicationForm) rec := httptest.NewRecorder() c := e.NewContext(req, rec) b := FormFieldBinder(c) var texta string id := int64(99) nr := int64(88) var slice = make([]int64, 0) var notExisting = make([]int64, 0) err := b. Int64s("slice", &slice). Int64("id", &id). Int64("nr", &nr). String("texta", &texta). Int64s("notExisting", ¬Existing). BindError() assert.NoError(t, err) assert.Equal(t, "foo", texta) assert.Equal(t, int64(1), id) assert.Equal(t, int64(2), nr) assert.Equal(t, []int64{5, 3, 4}, slice) assert.Equal(t, []int64{}, notExisting) } func TestValueBinder_errorStopsBinding(t *testing.T) { // this test documents "feature" that binding multiple params can change destination if it was bound before // failing parameter binding c := createTestContext("/api/user/999?id=1&nr=nope", nil, nil) b := QueryParamsBinder(c) id := int64(99) // will be changed before nr binding fails nr := int64(88) // will not be changed err := b.Int64("id", &id). Int64("nr", &nr). BindError() assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=nr") assert.Equal(t, int64(1), id) assert.Equal(t, int64(88), nr) } func TestValueBinder_BindError(t *testing.T) { c := createTestContext("/api/user/999?nr=en&id=nope", nil, nil) b := QueryParamsBinder(c) id := int64(99) nr := int64(88) err := b.Int64("id", &id). Int64("nr", &nr). 
BindError()

	// the first error (field=id) is returned and the binder's collected errors are cleared
	assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=id")
	assert.Nil(t, b.errors)
	assert.Nil(t, b.BindError())
}

// TestValueBinder_GetValues checks that a custom ValuesFunc replaces the default value source and that nil or
// empty results leave the destination slice nil.
func TestValueBinder_GetValues(t *testing.T) {
	var testCases = []struct {
		whenValuesFunc func(sourceParam string) []string
		name           string
		expectError    string
		expect         []int64
	}{
		{
			name:   "ok, default implementation",
			expect: []int64{1, 101},
		},
		{
			name: "ok, values returns nil",
			whenValuesFunc: func(sourceParam string) []string {
				return nil
			},
			expect: []int64(nil),
		},
		{
			name: "ok, values returns empty slice",
			whenValuesFunc: func(sourceParam string) []string {
				return []string{}
			},
			expect: []int64(nil),
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext("/search?nr=en&id=1&id=101", nil, nil)
			b := QueryParamsBinder(c)
			if tc.whenValuesFunc != nil {
				b.ValuesFunc = tc.whenValuesFunc
			}

			var IDs []int64
			err := b.Int64s("id", &IDs).BindError()

			assert.Equal(t, tc.expect, IDs)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

// TestValueBinder_CustomFuncWithError checks that BindError reports the first error returned by a CustomFunc.
func TestValueBinder_CustomFuncWithError(t *testing.T) {
	c := createTestContext("/search?nr=en&id=1&id=101", nil, nil)
	b := QueryParamsBinder(c)

	id := int64(99)
	givenCustomFunc := func(values []string) []error {
		assert.Equal(t, []string{"1", "101"}, values)

		return []error{
			errors.New("first error"),
			errors.New("second error"),
		}
	}
	err := b.CustomFunc("id", givenCustomFunc).BindError()

	assert.Equal(t, int64(99), id)
	assert.EqualError(t, err, "first error")
}

// TestValueBinder_CustomFunc is the table test for CustomFunc: value binding, empty params, fail-fast with
// previous errors and errors returned by the func.
func TestValueBinder_CustomFunc(t *testing.T) {
	var testCases = []struct {
		expectValue       any
		name              string
		whenURL           string
		givenFuncErrors   []error
		expectParamValues []string
		expectErrors      []string
		givenFailFast     bool
	}{
		{
			name:              "ok, binds value",
			whenURL:           "/search?nr=en&id=1&id=100",
			expectParamValues: []string{"1", "100"},
			expectValue:       int64(1000),
		},
		{
			name:              "ok, params values empty, value is not changed",
			whenURL:           "/search?nr=en",
			expectParamValues: []string{},
			expectValue:       int64(99),
		},
		{
			name:              "nok, previous errors fail fast without binding value",
			givenFailFast:     true,
			whenURL:           "/search?nr=en&id=1&id=100",
			expectParamValues: []string{"1", "100"},
			expectValue:       int64(99),
			expectErrors:      []string{"previous error"},
		},
		{
			name: "nok, func returns errors",
			givenFuncErrors: []error{
				errors.New("first error"),
				errors.New("second error"),
			},
			whenURL:           "/search?nr=en&id=1&id=100",
			expectParamValues: []string{"1", "100"},
			expectValue:       int64(99),
			expectErrors:      []string{"first error", "second error"},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				b.errors = []error{errors.New("previous error")}
			}

			id := int64(99)
			givenCustomFunc := func(values []string) []error {
				assert.Equal(t, tc.expectParamValues, values)
				if tc.givenFuncErrors == nil {
					id = 1000 // emulated conversion and setting value
					return nil
				}
				return tc.givenFuncErrors
			}
			errs := b.CustomFunc("id", givenCustomFunc).BindErrors()

			assert.Equal(t, tc.expectValue, id)
			if tc.expectErrors != nil {
				assert.Len(t, errs, len(tc.expectErrors))
				for _, err := range errs {
					assert.Contains(t, tc.expectErrors, err.Error())
				}
			} else {
				assert.Nil(t, errs)
			}
		})
	}
}

// TestValueBinder_MustCustomFunc is the table test for MustCustomFunc, which additionally reports an error when
// the parameter has no values.
func TestValueBinder_MustCustomFunc(t *testing.T) {
	var testCases = []struct {
		expectValue       any
		name              string
		whenURL           string
		givenFuncErrors   []error
		expectParamValues []string
		expectErrors      []string
		givenFailFast     bool
	}{
		{
			name:              "ok, binds value",
			whenURL:           "/search?nr=en&id=1&id=100",
			expectParamValues: []string{"1", "100"},
			expectValue:       int64(1000),
		},
		{
			name:              "nok, params values empty, returns error, value is not changed",
			whenURL:           "/search?nr=en",
			expectParamValues: []string{},
			expectValue:       int64(99),
			expectErrors:      []string{"code=400, message=required field value is empty, field=id"},
		},
		{
			name:              "nok, previous errors fail fast without binding value",
			givenFailFast:     true,
			whenURL:           "/search?nr=en&id=1&id=100",
			expectParamValues: []string{"1", "100"},
			expectValue:       int64(99),
			expectErrors:      []string{"previous error"},
		},
		{
			name: "nok, func returns errors",
			givenFuncErrors: []error{
				errors.New("first error"),
				errors.New("second error"),
			},
			whenURL:           "/search?nr=en&id=1&id=100",
			expectParamValues: []string{"1", "100"},
			expectValue:       int64(99),
			expectErrors:      []string{"first error", "second error"},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				b.errors = []error{errors.New("previous error")}
			}

			id := int64(99)
			givenCustomFunc := func(values []string) []error {
				assert.Equal(t, tc.expectParamValues, values)
				if tc.givenFuncErrors == nil {
					id = 1000 // emulated conversion and setting value
					return nil
				}
				return tc.givenFuncErrors
			}
			errs := b.MustCustomFunc("id", givenCustomFunc).BindErrors()

			assert.Equal(t, tc.expectValue, id)
			if tc.expectErrors != nil {
				assert.Len(t, errs, len(tc.expectErrors))
				for _, err := range errs {
					assert.Contains(t, tc.expectErrors, err.Error())
				}
			} else {
				assert.Nil(t, errs)
			}
		})
	}
}

// TestValueBinder_String is the table test for String and MustString.
func TestValueBinder_String(t *testing.T) {
	var testCases = []struct {
		name          string
		whenURL       string
		expectValue   string
		expectError   string
		givenFailFast bool
		whenMust      bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=en&param=de",
			expectValue: "en",
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nr=en",
			expectValue: "default",
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?nr=en&id=1&id=100",
			expectValue:   "default",
			expectError:   "previous error",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=en&param=de",
			expectValue: "en",
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nr=en",
			expectValue: "default",
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?nr=en&id=1&id=100",
			expectValue:   "default",
			expectError:   "previous error",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				b.errors = []error{errors.New("previous error")}
			}
			dest := "default"

			var err error
			if tc.whenMust {
				err = b.MustString("param", &dest).BindError()
			} else {
				err = b.String("param", &dest).BindError()
			}

			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

// TestValueBinder_Strings is the table test for Strings and MustStrings.
func TestValueBinder_Strings(t *testing.T) {
	var testCases = []struct {
		name          string
		whenURL       string
		expectError   string
		expectValue   []string
		givenFailFast bool
		whenMust      bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=en&param=de",
			expectValue: []string{"en", "de"},
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nr=en",
			expectValue: []string{"default"},
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?nr=en&id=1&id=100",
			expectValue:   []string{"default"},
			expectError:   "previous error",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=en&param=de",
			expectValue: []string{"en", "de"},
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nr=en",
			expectValue: []string{"default"},
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?nr=en&id=1&id=100",
expectValue:   []string{"default"},
			expectError:   "previous error",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				b.errors = []error{errors.New("previous error")}
			}
			dest := []string{"default"}

			var err error
			if tc.whenMust {
				err = b.MustStrings("param", &dest).BindError()
			} else {
				err = b.Strings("param", &dest).BindError()
			}

			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

// TestValueBinder_Int64_intValue is the table test for Int64 and MustInt64.
func TestValueBinder_Int64_intValue(t *testing.T) {
	var testCases = []struct {
		name          string
		whenURL       string
		expectError   string
		expectValue   int64
		givenFailFast bool
		whenMust      bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=1&param=100",
			expectValue: 1,
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: 99,
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   99,
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: 99,
			expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=1&param=100",
			expectValue: 1,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: 99,
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   99,
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: 99,
			expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				b.errors = []error{errors.New("previous error")}
			}
			dest := int64(99)

			var err error
			if tc.whenMust {
				err = b.MustInt64("param", &dest).BindError()
			} else {
				err = b.Int64("param", &dest).BindError()
			}

			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

func TestValueBinder_Int_errorMessage(t *testing.T) {
	// int/uint (without byte size) has a little bit different error message so test these separately
	c := createTestContext("/search?param=nope", nil, nil)
	b := QueryParamsBinder(c).FailFast(false)

	destInt := 99
	destUint := uint(98)
	errs := b.Int("param", &destInt).Uint("param", &destUint).BindErrors()

	assert.Equal(t, 99, destInt)
	assert.Equal(t, uint(98), destUint)
	// guard the indexing below so a regression fails the test instead of panicking with index out of range
	if !assert.Len(t, errs, 2) {
		return
	}
	assert.EqualError(t, errs[0], `code=400, message=failed to bind field value to int, err=strconv.ParseInt: parsing "nope": invalid syntax, field=param`)
	assert.EqualError(t, errs[1], `code=400, message=failed to bind field value to uint, err=strconv.ParseUint: parsing "nope": invalid syntax, field=param`)
}

// TestValueBinder_Uint64_uintValue is the table test for Uint64 and MustUint64.
func TestValueBinder_Uint64_uintValue(t *testing.T) {
	var testCases = []struct {
		name          string
		whenURL       string
		expectError   string
		expectValue   uint64
		givenFailFast bool
		whenMust      bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=1&param=100",
			expectValue: 1,
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: 99,
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
expectValue:   99,
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: 99,
			expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=1&param=100",
			expectValue: 1,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: 99,
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   99,
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: 99,
			expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				b.errors = []error{errors.New("previous error")}
			}
			dest := uint64(99)

			var err error
			if tc.whenMust {
				err = b.MustUint64("param", &dest).BindError()
			} else {
				err = b.Uint64("param", &dest).BindError()
			}

			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

// TestValueBinder_Int_Types binds every scalar integer kind (and its Must variant) from its own query param.
func TestValueBinder_Int_Types(t *testing.T) {
	type target struct {
		int64      int64
		mustInt64  int64
		uint64     uint64
		mustUint64 uint64

		int32      int32
		mustInt32  int32
		uint32     uint32
		mustUint32 uint32

		int16      int16
		mustInt16  int16
		uint16     uint16
		mustUint16 uint16

		int8      int8
		mustInt8  int8
		uint8     uint8
		mustUint8 uint8

		byte     byte
		mustByte byte

		int      int
		mustInt  int
		uint     uint
		mustUint uint
	}
	types := []string{
		"int64=1", "mustInt64=2",
		"uint64=3", "mustUint64=4",
		"int32=5", "mustInt32=6",
		"uint32=7", "mustUint32=8",
		"int16=9", "mustInt16=10",
		"uint16=11", "mustUint16=12",
		"int8=13", "mustInt8=14",
		"uint8=15", "mustUint8=16",
		"byte=17", "mustByte=18",
		"int=19", "mustInt=20",
		"uint=21", "mustUint=22",
	}
	c := createTestContext("/search?"+strings.Join(types, "&"), nil, nil)
	b := QueryParamsBinder(c)

	dest := target{}
	err := b.
		Int64("int64", &dest.int64).
		MustInt64("mustInt64", &dest.mustInt64).
		Uint64("uint64", &dest.uint64).
		MustUint64("mustUint64", &dest.mustUint64).
		Int32("int32", &dest.int32).
		MustInt32("mustInt32", &dest.mustInt32).
		Uint32("uint32", &dest.uint32).
		MustUint32("mustUint32", &dest.mustUint32).
		Int16("int16", &dest.int16).
		MustInt16("mustInt16", &dest.mustInt16).
		Uint16("uint16", &dest.uint16).
		MustUint16("mustUint16", &dest.mustUint16).
		Int8("int8", &dest.int8).
		MustInt8("mustInt8", &dest.mustInt8).
		Uint8("uint8", &dest.uint8).
		MustUint8("mustUint8", &dest.mustUint8).
		Byte("byte", &dest.byte).
		MustByte("mustByte", &dest.mustByte).
		Int("int", &dest.int).
		MustInt("mustInt", &dest.mustInt).
		Uint("uint", &dest.uint).
		MustUint("mustUint", &dest.mustUint).
BindError()

	assert.NoError(t, err)
	assert.Equal(t, int64(1), dest.int64)
	assert.Equal(t, int64(2), dest.mustInt64)
	assert.Equal(t, uint64(3), dest.uint64)
	assert.Equal(t, uint64(4), dest.mustUint64)

	assert.Equal(t, int32(5), dest.int32)
	assert.Equal(t, int32(6), dest.mustInt32)
	assert.Equal(t, uint32(7), dest.uint32)
	assert.Equal(t, uint32(8), dest.mustUint32)

	assert.Equal(t, int16(9), dest.int16)
	assert.Equal(t, int16(10), dest.mustInt16)
	assert.Equal(t, uint16(11), dest.uint16)
	assert.Equal(t, uint16(12), dest.mustUint16)

	assert.Equal(t, int8(13), dest.int8)
	assert.Equal(t, int8(14), dest.mustInt8)
	assert.Equal(t, uint8(15), dest.uint8)
	assert.Equal(t, uint8(16), dest.mustUint8)

	assert.Equal(t, uint8(17), dest.byte)
	assert.Equal(t, uint8(18), dest.mustByte)

	assert.Equal(t, 19, dest.int)
	assert.Equal(t, 20, dest.mustInt)
	assert.Equal(t, uint(21), dest.uint)
	assert.Equal(t, uint(22), dest.mustUint)
}

// TestValueBinder_Int64s_intsValue is the table test for Int64s and MustInt64s.
func TestValueBinder_Int64s_intsValue(t *testing.T) {
	var testCases = []struct {
		name          string
		whenURL       string
		expectError   string
		expectValue   []int64
		givenFailFast bool
		whenMust      bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=1&param=2&param=1",
			expectValue: []int64{1, 2, 1},
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: []int64{99},
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   []int64{99},
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: []int64{99},
			expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=1&param=2&param=1",
			expectValue: []int64{1, 2, 1},
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: []int64{99},
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   []int64{99},
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: []int64{99},
			expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				b.errors = []error{errors.New("previous error")}
			}
			dest := []int64{99} // when values are set with bind - contents before bind is gone

			var err error
			if tc.whenMust {
				err = b.MustInt64s("param", &dest).BindError()
			} else {
				err = b.Int64s("param", &dest).BindError()
			}

			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

// TestValueBinder_Uint64s_uintsValue is the table test for Uint64s and MustUint64s.
func TestValueBinder_Uint64s_uintsValue(t *testing.T) {
	var testCases = []struct {
		name          string
		whenURL       string
		expectError   string
		expectValue   []uint64
		givenFailFast bool
		whenMust      bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=1&param=2&param=1",
			expectValue: []uint64{1, 2, 1},
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: []uint64{99},
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   []uint64{99},
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: []uint64{99},
			expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=1&param=2&param=1",
			expectValue: []uint64{1, 2, 1},
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: []uint64{99},
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   []uint64{99},
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: []uint64{99},
			expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				b.errors = []error{errors.New("previous error")}
			}
			dest := []uint64{99} // when values are set with bind - contents before bind is gone

			var err error
			if tc.whenMust {
				err = b.MustUint64s("param", &dest).BindError()
			} else {
				err = b.Uint64s("param", &dest).BindError()
			}

			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

// TestValueBinder_Ints_Types binds every integer slice kind (and its Must variant) from its own query param.
func TestValueBinder_Ints_Types(t *testing.T) {
	type target struct {
		int64      []int64
		mustInt64  []int64
		uint64     []uint64
		mustUint64 []uint64

		int32      []int32
		mustInt32  []int32
		uint32     []uint32
		mustUint32 []uint32

		int16      []int16
		mustInt16  []int16
		uint16     []uint16
		mustUint16 []uint16

		int8      []int8
		mustInt8  []int8
		uint8     []uint8
		mustUint8 []uint8

		int      []int
		mustInt  []int
		uint     []uint
		mustUint []uint
	}
	types := []string{
		"int64=1", "mustInt64=2",
		"uint64=3", "mustUint64=4",
		"int32=5",
"mustInt32=6", "uint32=7", "mustUint32=8", "int16=9", "mustInt16=10", "uint16=11", "mustUint16=12", "int8=13", "mustInt8=14", "uint8=15", "mustUint8=16", "int=19", "mustInt=20", "uint=21", "mustUint=22", } url := "/search?" for _, v := range types { url = url + "&" + v + "&" + v } c := createTestContext(url, nil, nil) b := QueryParamsBinder(c) dest := target{} err := b. Int64s("int64", &dest.int64). MustInt64s("mustInt64", &dest.mustInt64). Uint64s("uint64", &dest.uint64). MustUint64s("mustUint64", &dest.mustUint64). Int32s("int32", &dest.int32). MustInt32s("mustInt32", &dest.mustInt32). Uint32s("uint32", &dest.uint32). MustUint32s("mustUint32", &dest.mustUint32). Int16s("int16", &dest.int16). MustInt16s("mustInt16", &dest.mustInt16). Uint16s("uint16", &dest.uint16). MustUint16s("mustUint16", &dest.mustUint16). Int8s("int8", &dest.int8). MustInt8s("mustInt8", &dest.mustInt8). Uint8s("uint8", &dest.uint8). MustUint8s("mustUint8", &dest.mustUint8). Ints("int", &dest.int). MustInts("mustInt", &dest.mustInt). Uints("uint", &dest.uint). MustUints("mustUint", &dest.mustUint). 
BindError() assert.NoError(t, err) assert.Equal(t, []int64{1, 1}, dest.int64) assert.Equal(t, []int64{2, 2}, dest.mustInt64) assert.Equal(t, []uint64{3, 3}, dest.uint64) assert.Equal(t, []uint64{4, 4}, dest.mustUint64) assert.Equal(t, []int32{5, 5}, dest.int32) assert.Equal(t, []int32{6, 6}, dest.mustInt32) assert.Equal(t, []uint32{7, 7}, dest.uint32) assert.Equal(t, []uint32{8, 8}, dest.mustUint32) assert.Equal(t, []int16{9, 9}, dest.int16) assert.Equal(t, []int16{10, 10}, dest.mustInt16) assert.Equal(t, []uint16{11, 11}, dest.uint16) assert.Equal(t, []uint16{12, 12}, dest.mustUint16) assert.Equal(t, []int8{13, 13}, dest.int8) assert.Equal(t, []int8{14, 14}, dest.mustInt8) assert.Equal(t, []uint8{15, 15}, dest.uint8) assert.Equal(t, []uint8{16, 16}, dest.mustUint8) assert.Equal(t, []int{19, 19}, dest.int) assert.Equal(t, []int{20, 20}, dest.mustInt) assert.Equal(t, []uint{21, 21}, dest.uint) assert.Equal(t, []uint{22, 22}, dest.mustUint) } func TestValueBinder_Ints_Types_FailFast(t *testing.T) { // FailFast() should stop parsing and return early errTmpl := "code=400, message=failed to bind field value to %v, err=strconv.Parse%v: parsing \"nope\": invalid syntax, field=param" c := createTestContext("/search?param=1¶m=nope¶m=2", nil, nil) var dest64 []int64 err := QueryParamsBinder(c).FailFast(true).Int64s("param", &dest64).BindError() assert.Equal(t, []int64(nil), dest64) assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int64", "Int")) var dest32 []int32 err = QueryParamsBinder(c).FailFast(true).Int32s("param", &dest32).BindError() assert.Equal(t, []int32(nil), dest32) assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int32", "Int")) var dest16 []int16 err = QueryParamsBinder(c).FailFast(true).Int16s("param", &dest16).BindError() assert.Equal(t, []int16(nil), dest16) assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int16", "Int")) var dest8 []int8 err = QueryParamsBinder(c).FailFast(true).Int8s("param", &dest8).BindError() assert.Equal(t, []int8(nil), dest8) 
assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int8", "Int"))

	var dest []int
	err = QueryParamsBinder(c).FailFast(true).Ints("param", &dest).BindError()
	assert.Equal(t, []int(nil), dest)
	assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int", "Int"))

	var destu64 []uint64
	err = QueryParamsBinder(c).FailFast(true).Uint64s("param", &destu64).BindError()
	assert.Equal(t, []uint64(nil), destu64)
	assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint64", "Uint"))

	var destu32 []uint32
	err = QueryParamsBinder(c).FailFast(true).Uint32s("param", &destu32).BindError()
	assert.Equal(t, []uint32(nil), destu32)
	assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint32", "Uint"))

	var destu16 []uint16
	err = QueryParamsBinder(c).FailFast(true).Uint16s("param", &destu16).BindError()
	assert.Equal(t, []uint16(nil), destu16)
	assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint16", "Uint"))

	var destu8 []uint8
	err = QueryParamsBinder(c).FailFast(true).Uint8s("param", &destu8).BindError()
	assert.Equal(t, []uint8(nil), destu8)
	assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint8", "Uint"))

	var destu []uint
	err = QueryParamsBinder(c).FailFast(true).Uints("param", &destu).BindError()
	assert.Equal(t, []uint(nil), destu)
	assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint", "Uint"))
}

// TestValueBinder_Bool is the table test for Bool and MustBool.
func TestValueBinder_Bool(t *testing.T) {
	var testCases = []struct {
		name          string
		whenURL       string
		expectError   string
		givenFailFast bool
		whenMust      bool
		expectValue   bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=true&param=1",
			expectValue: true,
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: false,
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   false,
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: false,
			expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=1&param=100",
			expectValue: true,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: false,
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   false,
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: false,
			expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				b.errors = []error{errors.New("previous error")}
			}
			dest := false

			var err error
			if tc.whenMust {
				err = b.MustBool("param", &dest).BindError()
			} else {
				err = b.Bool("param", &dest).BindError()
			}

			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

// TestValueBinder_Bools is the table test for Bools and MustBools.
func TestValueBinder_Bools(t *testing.T) {
	var testCases = []struct {
		name            string
		whenURL         string
		expectError     string
		givenBindErrors []error
		expectValue     []bool
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=true&param=false&param=1&param=0",
			expectValue: []bool{true, false, true, false},
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: []bool(nil),
		},
		{
			name:            "nok, previous errors fail fast without binding value",
			givenFailFast:   true,
			givenBindErrors: []error{errors.New("previous error")},
			whenURL:         "/search?param=1&param=100",
			expectValue:     []bool(nil),
			expectError:     "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=true&param=nope&param=100",
			expectValue: []bool(nil),
			expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:          "nok, conversion fails fast, value is not changed",
			givenFailFast: true,
			whenURL:       "/search?param=true&param=nope&param=100",
			expectValue:   []bool(nil),
			expectError:   "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=true&param=false&param=1&param=0",
			expectValue: []bool{true, false, true, false},
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: []bool(nil),
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:            "nok (must), previous errors fail fast without binding value",
			givenFailFast:   true,
			givenBindErrors: []error{errors.New("previous error")},
			whenMust:        true,
			whenURL:         "/search?param=1&param=100",
			expectValue:     []bool(nil),
			expectError:     "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: []bool(nil),
			expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			b.errors = tc.givenBindErrors

			var dest []bool
			var err error
			if tc.whenMust {
				err = b.MustBools("param", &dest).BindError()
			} else {
				err = b.Bools("param", &dest).BindError()
			}

			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err,
tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

// TestValueBinder_Float64 is the table test for Float64 and MustFloat64.
func TestValueBinder_Float64(t *testing.T) {
	var testCases = []struct {
		name          string
		whenURL       string
		expectError   string
		expectValue   float64
		givenFailFast bool
		whenMust      bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=4.3&param=1",
			expectValue: 4.3,
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: 1.123,
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   1.123,
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: 1.123,
			expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=4.3&param=100",
			expectValue: 4.3,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: 1.123,
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   1.123,
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: 1.123,
			expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				b.errors = []error{errors.New("previous error")}
			}
			dest := 1.123

			var err error
			if tc.whenMust {
				err = b.MustFloat64("param", &dest).BindError()
			} else {
				err = b.Float64("param", &dest).BindError()
			}

			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

// TestValueBinder_Float64s is the table test for Float64s and MustFloat64s.
func TestValueBinder_Float64s(t *testing.T) {
	var testCases = []struct {
		name            string
		whenURL         string
		expectError     string
		givenBindErrors []error
		expectValue     []float64
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=4.3&param=0",
			expectValue: []float64{4.3, 0},
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: []float64(nil),
		},
		{
			name:            "nok, previous errors fail fast without binding value",
			givenFailFast:   true,
			givenBindErrors: []error{errors.New("previous error")},
			whenURL:         "/search?param=1&param=100",
			expectValue:     []float64(nil),
			expectError:     "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: []float64(nil),
			expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:          "nok, conversion fails fast, value is not changed",
			givenFailFast: true,
			whenURL:       "/search?param=0&param=nope&param=100",
			expectValue:   []float64(nil),
			expectError:   "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=4.3&param=0",
			expectValue: []float64{4.3, 0},
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: []float64(nil),
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:            "nok (must), previous errors fail fast without binding value",
			givenFailFast:   true,
			givenBindErrors: []error{errors.New("previous error")},
			whenMust:        true,
			whenURL:         "/search?param=1&param=100",
			expectValue:     []float64(nil),
			expectError:     "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: []float64(nil),
			expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			b.errors = tc.givenBindErrors

			var dest []float64
			var err error
			if tc.whenMust {
				err = b.MustFloat64s("param", &dest).BindError()
			} else {
				err = b.Float64s("param", &dest).BindError()
			}

			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

func TestValueBinder_Float32(t *testing.T) {
	var testCases = []struct {
		name            string
		whenURL         string
		expectError     string
		givenBindErrors []error
		expectValue     float32
		givenNoFailFast bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=4.3&param=1",
			expectValue: 4.3,
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: 1.123,
		},
		{
			name:            "nok, previous errors fail fast without binding value",
			givenNoFailFast: true,
			whenURL:         "/search?param=1&param=100",
			expectValue:     1.123,
			expectError:     "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: 1.123,
			expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=4.3&param=100",
			expectValue: 4.3,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: 1.123,
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:
"nok (must), previous errors fail fast without binding value", givenNoFailFast: true, whenMust: true, whenURL: "/search?param=1¶m=100", expectValue: 1.123, expectError: "previous error", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: 1.123, expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenNoFailFast) if tc.givenNoFailFast { b.errors = []error{errors.New("previous error")} } dest := float32(1.123) var err error if tc.whenMust { err = b.MustFloat32("param", &dest).BindError() } else { err = b.Float32("param", &dest).BindError() } assert.Equal(t, tc.expectValue, dest) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValueBinder_Float32s(t *testing.T) { var testCases = []struct { name string whenURL string expectError string givenBindErrors []error expectValue []float32 givenFailFast bool whenMust bool }{ { name: "ok, binds value", whenURL: "/search?param=4.3¶m=0", expectValue: []float32{4.3, 0}, }, { name: "ok, params values empty, value is not changed", whenURL: "/search?nope=1", expectValue: []float32(nil), }, { name: "nok, previous errors fail fast without binding value", givenFailFast: true, givenBindErrors: []error{errors.New("previous error")}, whenURL: "/search?param=1¶m=100", expectValue: []float32(nil), expectError: "previous error", }, { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []float32(nil), expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, { name: "nok, conversion fails fast, value is not changed", 
givenFailFast: true, whenURL: "/search?param=0¶m=nope¶m=100", expectValue: []float32(nil), expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", whenMust: true, whenURL: "/search?param=4.3¶m=0", expectValue: []float32{4.3, 0}, }, { name: "ok (must), params values empty, returns error, value is not changed", whenMust: true, whenURL: "/search?nope=1", expectValue: []float32(nil), expectError: "code=400, message=required field value is empty, field=param", }, { name: "nok (must), previous errors fail fast without binding value", givenFailFast: true, givenBindErrors: []error{errors.New("previous error")}, whenMust: true, whenURL: "/search?param=1¶m=100", expectValue: []float32(nil), expectError: "previous error", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []float32(nil), expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) b.errors = tc.givenBindErrors var dest []float32 var err error if tc.whenMust { err = b.MustFloat32s("param", &dest).BindError() } else { err = b.Float32s("param", &dest).BindError() } assert.Equal(t, tc.expectValue, dest) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValueBinder_Time(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00") var testCases = []struct { expectValue time.Time name string whenURL string whenLayout string expectError string givenBindErrors []error givenFailFast bool whenMust bool }{ { name: "ok, binds value", whenURL: 
"/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", whenLayout: time.RFC3339, expectValue: exampleTime, }, { name: "ok, params values empty, value is not changed", whenURL: "/search?nope=1", expectValue: time.Time{}, }, { name: "nok, previous errors fail fast without binding value", givenFailFast: true, whenURL: "/search?param=1¶m=100", expectValue: time.Time{}, expectError: "previous error", }, { name: "ok (must), binds value", whenMust: true, whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", whenLayout: time.RFC3339, expectValue: exampleTime, }, { name: "ok (must), params values empty, returns error, value is not changed", whenMust: true, whenURL: "/search?nope=1", expectValue: time.Time{}, expectError: "code=400, message=required field value is empty, field=param", }, { name: "nok (must), previous errors fail fast without binding value", givenFailFast: true, whenMust: true, whenURL: "/search?param=1¶m=100", expectValue: time.Time{}, expectError: "previous error", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) if tc.givenFailFast { b.errors = []error{errors.New("previous error")} } dest := time.Time{} var err error if tc.whenMust { err = b.MustTime("param", &dest, tc.whenLayout).BindError() } else { err = b.Time("param", &dest, tc.whenLayout).BindError() } assert.Equal(t, tc.expectValue, dest) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValueBinder_Times(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00") exampleTime2, _ := time.Parse(time.RFC3339, "2000-01-02T09:45:31+00:00") var testCases = []struct { name string whenURL string whenLayout string expectError string givenBindErrors []error expectValue []time.Time givenFailFast bool whenMust bool }{ { name: "ok, binds value", 
whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", whenLayout: time.RFC3339, expectValue: []time.Time{exampleTime, exampleTime2}, }, { name: "ok, params values empty, value is not changed", whenURL: "/search?nope=1", expectValue: []time.Time(nil), }, { name: "nok, previous errors fail fast without binding value", givenFailFast: true, givenBindErrors: []error{errors.New("previous error")}, whenURL: "/search?param=1¶m=100", expectValue: []time.Time(nil), expectError: "previous error", }, { name: "ok (must), binds value", whenMust: true, whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", whenLayout: time.RFC3339, expectValue: []time.Time{exampleTime, exampleTime2}, }, { name: "ok (must), params values empty, returns error, value is not changed", whenMust: true, whenURL: "/search?nope=1", expectValue: []time.Time(nil), expectError: "code=400, message=required field value is empty, field=param", }, { name: "nok (must), previous errors fail fast without binding value", givenFailFast: true, givenBindErrors: []error{errors.New("previous error")}, whenMust: true, whenURL: "/search?param=1¶m=100", expectValue: []time.Time(nil), expectError: "previous error", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) b.errors = tc.givenBindErrors layout := time.RFC3339 if tc.whenLayout != "" { layout = tc.whenLayout } var dest []time.Time var err error if tc.whenMust { err = b.MustTimes("param", &dest, layout).BindError() } else { err = b.Times("param", &dest, layout).BindError() } assert.Equal(t, tc.expectValue, dest) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValueBinder_Duration(t *testing.T) { example := 42 * time.Second var testCases = []struct { name string whenURL string expectError string givenBindErrors []error 
expectValue time.Duration givenFailFast bool whenMust bool }{ { name: "ok, binds value", whenURL: "/search?param=42s¶m=1ms", expectValue: example, }, { name: "ok, params values empty, value is not changed", whenURL: "/search?nope=1", expectValue: 0, }, { name: "nok, previous errors fail fast without binding value", givenFailFast: true, whenURL: "/search?param=1¶m=100", expectValue: 0, expectError: "previous error", }, { name: "ok (must), binds value", whenMust: true, whenURL: "/search?param=42s¶m=1ms", expectValue: example, }, { name: "ok (must), params values empty, returns error, value is not changed", whenMust: true, whenURL: "/search?nope=1", expectValue: 0, expectError: "code=400, message=required field value is empty, field=param", }, { name: "nok (must), previous errors fail fast without binding value", givenFailFast: true, whenMust: true, whenURL: "/search?param=1¶m=100", expectValue: 0, expectError: "previous error", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) if tc.givenFailFast { b.errors = []error{errors.New("previous error")} } var dest time.Duration var err error if tc.whenMust { err = b.MustDuration("param", &dest).BindError() } else { err = b.Duration("param", &dest).BindError() } assert.Equal(t, tc.expectValue, dest) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValueBinder_Durations(t *testing.T) { exampleDuration := 42 * time.Second exampleDuration2 := 1 * time.Millisecond var testCases = []struct { name string whenURL string expectError string givenBindErrors []error expectValue []time.Duration givenFailFast bool whenMust bool }{ { name: "ok, binds value", whenURL: "/search?param=42s¶m=1ms", expectValue: []time.Duration{exampleDuration, exampleDuration2}, }, { name: "ok, params values empty, value is not changed", whenURL: "/search?nope=1", 
expectValue: []time.Duration(nil), }, { name: "nok, previous errors fail fast without binding value", givenFailFast: true, givenBindErrors: []error{errors.New("previous error")}, whenURL: "/search?param=1¶m=100", expectValue: []time.Duration(nil), expectError: "previous error", }, { name: "ok (must), binds value", whenMust: true, whenURL: "/search?param=42s¶m=1ms", expectValue: []time.Duration{exampleDuration, exampleDuration2}, }, { name: "ok (must), params values empty, returns error, value is not changed", whenMust: true, whenURL: "/search?nope=1", expectValue: []time.Duration(nil), expectError: "code=400, message=required field value is empty, field=param", }, { name: "nok (must), previous errors fail fast without binding value", givenFailFast: true, givenBindErrors: []error{errors.New("previous error")}, whenMust: true, whenURL: "/search?param=1¶m=100", expectValue: []time.Duration(nil), expectError: "previous error", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) b.errors = tc.givenBindErrors var dest []time.Duration var err error if tc.whenMust { err = b.MustDurations("param", &dest).BindError() } else { err = b.Durations("param", &dest).BindError() } assert.Equal(t, tc.expectValue, dest) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValueBinder_BindUnmarshaler(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00") var testCases = []struct { expectValue Timestamp name string whenURL string expectError string givenBindErrors []error givenFailFast bool whenMust bool }{ { name: "ok, binds value", whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", expectValue: Timestamp(exampleTime), }, { name: "ok, params values empty, value is not changed", whenURL: "/search?nope=1", expectValue: Timestamp{}, }, { name: 
"nok, previous errors fail fast without binding value", givenFailFast: true, whenURL: "/search?param=1¶m=100", expectValue: Timestamp{}, expectError: "previous error", }, { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: Timestamp{}, expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", }, { name: "ok (must), binds value", whenMust: true, whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", expectValue: Timestamp(exampleTime), }, { name: "ok (must), params values empty, returns error, value is not changed", whenMust: true, whenURL: "/search?nope=1", expectValue: Timestamp{}, expectError: "code=400, message=required field value is empty, field=param", }, { name: "nok (must), previous errors fail fast without binding value", givenFailFast: true, whenMust: true, whenURL: "/search?param=1¶m=100", expectValue: Timestamp{}, expectError: "previous error", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: Timestamp{}, expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) if tc.givenFailFast { b.errors = []error{errors.New("previous error")} } var dest Timestamp var err error if tc.whenMust { err = b.MustBindUnmarshaler("param", &dest).BindError() } else { err = b.BindUnmarshaler("param", &dest).BindError() } assert.Equal(t, tc.expectValue, dest) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func 
TestValueBinder_JSONUnmarshaler(t *testing.T) { example := big.NewInt(999) var testCases = []struct { name string whenURL string expectError string expectValue big.Int givenBindErrors []error givenFailFast bool whenMust bool }{ { name: "ok, binds value", whenURL: "/search?param=999¶m=998", expectValue: *example, }, { name: "ok, params values empty, value is not changed", whenURL: "/search?nope=1", expectValue: big.Int{}, }, { name: "nok, previous errors fail fast without binding value", givenFailFast: true, whenURL: "/search?param=1¶m=100", expectValue: big.Int{}, expectError: "previous error", }, { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=xxx", expectValue: big.Int{}, expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", }, { name: "ok (must), binds value", whenMust: true, whenURL: "/search?param=999¶m=998", expectValue: *example, }, { name: "ok (must), params values empty, returns error, value is not changed", whenMust: true, whenURL: "/search?nope=1", expectValue: big.Int{}, expectError: "code=400, message=required field value is empty, field=param", }, { name: "nok (must), previous errors fail fast without binding value", givenFailFast: true, whenMust: true, whenURL: "/search?param=1¶m=xxx", expectValue: big.Int{}, expectError: "previous error", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=xxx", expectValue: big.Int{}, expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) if tc.givenFailFast { b.errors = []error{errors.New("previous error")} } var dest big.Int var err error if 
tc.whenMust { err = b.MustJSONUnmarshaler("param", &dest).BindError() } else { err = b.JSONUnmarshaler("param", &dest).BindError() } assert.Equal(t, tc.expectValue, dest) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValueBinder_TextUnmarshaler(t *testing.T) { example := big.NewInt(999) var testCases = []struct { name string whenURL string expectError string expectValue big.Int givenBindErrors []error givenFailFast bool whenMust bool }{ { name: "ok, binds value", whenURL: "/search?param=999¶m=998", expectValue: *example, }, { name: "ok, params values empty, value is not changed", whenURL: "/search?nope=1", expectValue: big.Int{}, }, { name: "nok, previous errors fail fast without binding value", givenFailFast: true, whenURL: "/search?param=1¶m=100", expectValue: big.Int{}, expectError: "previous error", }, { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=xxx", expectValue: big.Int{}, expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", }, { name: "ok (must), binds value", whenMust: true, whenURL: "/search?param=999¶m=998", expectValue: *example, }, { name: "ok (must), params values empty, returns error, value is not changed", whenMust: true, whenURL: "/search?nope=1", expectValue: big.Int{}, expectError: "code=400, message=required field value is empty, field=param", }, { name: "nok (must), previous errors fail fast without binding value", givenFailFast: true, whenMust: true, whenURL: "/search?param=1¶m=xxx", expectValue: big.Int{}, expectError: "previous error", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=xxx", expectValue: big.Int{}, expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a 
*big.Int, field=param", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) if tc.givenFailFast { b.errors = []error{errors.New("previous error")} } var dest big.Int var err error if tc.whenMust { err = b.MustTextUnmarshaler("param", &dest).BindError() } else { err = b.TextUnmarshaler("param", &dest).BindError() } assert.Equal(t, tc.expectValue, dest) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValueBinder_BindWithDelimiter_types(t *testing.T) { var testCases = []struct { expect any name string whenURL string }{ { name: "ok, strings", expect: []string{"1", "2", "1"}, }, { name: "ok, int64", expect: []int64{1, 2, 1}, }, { name: "ok, int32", expect: []int32{1, 2, 1}, }, { name: "ok, int16", expect: []int16{1, 2, 1}, }, { name: "ok, int8", expect: []int8{1, 2, 1}, }, { name: "ok, int", expect: []int{1, 2, 1}, }, { name: "ok, uint64", expect: []uint64{1, 2, 1}, }, { name: "ok, uint32", expect: []uint32{1, 2, 1}, }, { name: "ok, uint16", expect: []uint16{1, 2, 1}, }, { name: "ok, uint8", expect: []uint8{1, 2, 1}, }, { name: "ok, uint", expect: []uint{1, 2, 1}, }, { name: "ok, float64", expect: []float64{1, 2, 1}, }, { name: "ok, float32", expect: []float32{1, 2, 1}, }, { name: "ok, bool", whenURL: "/search?param=1,false¶m=true", expect: []bool{true, false, true}, }, { name: "ok, Duration", whenURL: "/search?param=1s,42s¶m=1ms", expect: []time.Duration{1 * time.Second, 42 * time.Second, 1 * time.Millisecond}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { URL := "/search?param=1,2¶m=1" if tc.whenURL != "" { URL = tc.whenURL } c := createTestContext(URL, nil, nil) b := QueryParamsBinder(c) switch tc.expect.(type) { case []string: var dest []string assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) assert.Equal(t, tc.expect, dest) case 
[]int64: var dest []int64 assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) assert.Equal(t, tc.expect, dest) case []int32: var dest []int32 assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) assert.Equal(t, tc.expect, dest) case []int16: var dest []int16 assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) assert.Equal(t, tc.expect, dest) case []int8: var dest []int8 assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) assert.Equal(t, tc.expect, dest) case []int: var dest []int assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) assert.Equal(t, tc.expect, dest) case []uint64: var dest []uint64 assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) assert.Equal(t, tc.expect, dest) case []uint32: var dest []uint32 assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) assert.Equal(t, tc.expect, dest) case []uint16: var dest []uint16 assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) assert.Equal(t, tc.expect, dest) case []uint8: var dest []uint8 assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) assert.Equal(t, tc.expect, dest) case []uint: var dest []uint assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) assert.Equal(t, tc.expect, dest) case []float64: var dest []float64 assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) assert.Equal(t, tc.expect, dest) case []float32: var dest []float32 assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) assert.Equal(t, tc.expect, dest) case []bool: var dest []bool assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) assert.Equal(t, tc.expect, dest) case []time.Duration: var dest []time.Duration assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) assert.Equal(t, tc.expect, dest) default: assert.Fail(t, "invalid type") } }) } } func 
TestValueBinder_BindWithDelimiter(t *testing.T) { var testCases = []struct { name string whenURL string expectError string givenBindErrors []error expectValue []int64 givenFailFast bool whenMust bool }{ { name: "ok, binds value", whenURL: "/search?param=1,2¶m=1", expectValue: []int64{1, 2, 1}, }, { name: "ok, params values empty, value is not changed", whenURL: "/search?nope=1", expectValue: []int64(nil), }, { name: "nok, previous errors fail fast without binding value", givenFailFast: true, whenURL: "/search?param=1¶m=100", expectValue: []int64(nil), expectError: "previous error", }, { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []int64(nil), expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", whenMust: true, whenURL: "/search?param=1,2¶m=1", expectValue: []int64{1, 2, 1}, }, { name: "ok (must), params values empty, returns error, value is not changed", whenMust: true, whenURL: "/search?nope=1", expectValue: []int64(nil), expectError: "code=400, message=required field value is empty, field=param", }, { name: "nok (must), previous errors fail fast without binding value", givenFailFast: true, whenMust: true, whenURL: "/search?param=1¶m=100", expectValue: []int64(nil), expectError: "previous error", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []int64(nil), expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) if tc.givenFailFast { b.errors = []error{errors.New("previous error")} } var dest []int64 var err error if tc.whenMust { err = 
b.MustBindWithDelimiter("param", &dest, ",").BindError() } else { err = b.BindWithDelimiter("param", &dest, ",").BindError() } assert.Equal(t, tc.expectValue, dest) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestBindWithDelimiter_invalidType(t *testing.T) { c := createTestContext("/search?param=1¶m=100", nil, nil) b := QueryParamsBinder(c) var dest []BindUnmarshaler err := b.BindWithDelimiter("param", &dest, ",").BindError() assert.Equal(t, []BindUnmarshaler(nil), dest) assert.EqualError(t, err, "code=400, message=unsupported bind type, field=param") } func TestValueBinder_UnixTime(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43+00:00") // => 1609180603 var testCases = []struct { expectValue time.Time name string whenURL string expectError string givenBindErrors []error givenFailFast bool whenMust bool }{ { name: "ok, binds value, unix time in seconds", whenURL: "/search?param=1609180603¶m=1609180604", expectValue: exampleTime, }, { name: "ok, binds value, unix time over int32 value", whenURL: "/search?param=2147483648¶m=1609180604", expectValue: time.Unix(2147483648, 0), }, { name: "ok, params values empty, value is not changed", whenURL: "/search?nope=1", expectValue: time.Time{}, }, { name: "nok, previous errors fail fast without binding value", givenFailFast: true, whenURL: "/search?param=1¶m=100", expectValue: time.Time{}, expectError: "previous error", }, { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", whenMust: true, whenURL: "/search?param=1609180603¶m=1609180604", expectValue: exampleTime, }, { name: "ok (must), params values empty, returns error, value is not changed", whenMust: true, whenURL: "/search?nope=1", 
expectValue: time.Time{}, expectError: "code=400, message=required field value is empty, field=param", }, { name: "nok (must), previous errors fail fast without binding value", givenFailFast: true, whenMust: true, whenURL: "/search?param=1¶m=100", expectValue: time.Time{}, expectError: "previous error", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) if tc.givenFailFast { b.errors = []error{errors.New("previous error")} } dest := time.Time{} var err error if tc.whenMust { err = b.MustUnixTime("param", &dest).BindError() } else { err = b.UnixTime("param", &dest).BindError() } assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano()) assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC)) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValueBinder_UnixTimeMilli(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339Nano, "2022-03-13T15:13:30.140000000+00:00") // => 1647184410140 var testCases = []struct { expectValue time.Time name string whenURL string expectError string givenBindErrors []error givenFailFast bool whenMust bool }{ { name: "ok, binds value, unix time in milliseconds", whenURL: "/search?param=1647184410140¶m=1647184410199", expectValue: exampleTime, }, { name: "ok, params values empty, value is not changed", whenURL: "/search?nope=1", expectValue: time.Time{}, }, { name: "nok, previous errors fail fast without binding value", givenFailFast: true, whenURL: "/search?param=1¶m=100", expectValue: time.Time{}, expectError: "previous error", }, { name: "nok, conversion fails, value 
is not changed", whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", whenMust: true, whenURL: "/search?param=1647184410140¶m=1647184410199", expectValue: exampleTime, }, { name: "ok (must), params values empty, returns error, value is not changed", whenMust: true, whenURL: "/search?nope=1", expectValue: time.Time{}, expectError: "code=400, message=required field value is empty, field=param", }, { name: "nok (must), previous errors fail fast without binding value", givenFailFast: true, whenMust: true, whenURL: "/search?param=1¶m=100", expectValue: time.Time{}, expectError: "previous error", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) if tc.givenFailFast { b.errors = []error{errors.New("previous error")} } dest := time.Time{} var err error if tc.whenMust { err = b.MustUnixTimeMilli("param", &dest).BindError() } else { err = b.UnixTimeMilli("param", &dest).BindError() } assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano()) assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC)) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValueBinder_UnixTimeNano(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43.000000000+00:00") // => 1609180603 exampleTimeNano, _ := time.Parse(time.RFC3339Nano, "2020-12-28T18:36:43.123456789+00:00") // => 1609180603123456789 exampleTimeNanoBelowSec, _ 
:= time.Parse(time.RFC3339Nano, "1970-01-01T00:00:00.999999999+00:00") var testCases = []struct { expectValue time.Time name string whenURL string expectError string givenBindErrors []error givenFailFast bool whenMust bool }{ { name: "ok, binds value, unix time in nano seconds (sec precision)", whenURL: "/search?param=1609180603000000000¶m=1609180604", expectValue: exampleTime, }, { name: "ok, binds value, unix time in nano seconds", whenURL: "/search?param=1609180603123456789¶m=1609180604", expectValue: exampleTimeNano, }, { name: "ok, binds value, unix time in nano seconds (below 1 sec)", whenURL: "/search?param=999999999¶m=1609180604", expectValue: exampleTimeNanoBelowSec, }, { name: "ok, params values empty, value is not changed", whenURL: "/search?nope=1", expectValue: time.Time{}, }, { name: "nok, previous errors fail fast without binding value", givenFailFast: true, whenURL: "/search?param=1¶m=100", expectValue: time.Time{}, expectError: "previous error", }, { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", whenMust: true, whenURL: "/search?param=1609180603000000000¶m=1609180604", expectValue: exampleTime, }, { name: "ok (must), params values empty, returns error, value is not changed", whenMust: true, whenURL: "/search?nope=1", expectValue: time.Time{}, expectError: "code=400, message=required field value is empty, field=param", }, { name: "nok (must), previous errors fail fast without binding value", givenFailFast: true, whenMust: true, whenURL: "/search?param=1¶m=100", expectValue: time.Time{}, expectError: "previous error", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, expectError: "code=400, message=failed to bind 
field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) if tc.givenFailFast { b.errors = []error{errors.New("previous error")} } dest := time.Time{} var err error if tc.whenMust { err = b.MustUnixTimeNano("param", &dest).BindError() } else { err = b.UnixTimeNano("param", &dest).BindError() } assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano()) assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC)) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) { type Opts struct { Param int64 `query:"param"` } c := createTestContext("/search?param=1¶m=100", nil, nil) b.ReportAllocs() b.ResetTimer() binder := new(DefaultBinder) for i := 0; i < b.N; i++ { var dest Opts _ = binder.Bind(c, &dest) } } func BenchmarkValueBinder_BindInt64_single(b *testing.B) { c := createTestContext("/search?param=1¶m=100", nil, nil) b.ReportAllocs() b.ResetTimer() type Opts struct { Param int64 } binder := QueryParamsBinder(c) for i := 0; i < b.N; i++ { var dest Opts _ = binder.Int64("param", &dest.Param).BindError() } } func BenchmarkRawFunc_Int64_single(b *testing.B) { c := createTestContext("/search?param=1¶m=100", nil, nil) rawFunc := func(input string, defaultValue int64) (int64, bool) { if input == "" { return defaultValue, true } n, err := strconv.Atoi(input) if err != nil { return 0, false } return int64(n), true } b.ReportAllocs() b.ResetTimer() type Opts struct { Param int64 } for i := 0; i < b.N; i++ { var dest Opts if n, ok := rawFunc(c.QueryParam("param"), 1); ok { dest.Param = n } } } func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) { type Opts struct { String string `query:"string"` Strings []string `query:"strings"` Int64 int64 
`query:"int64"` Uint64 uint64 `query:"uint64"` Int32 int32 `query:"int32"` Uint32 uint32 `query:"uint32"` Int16 int16 `query:"int16"` Uint16 uint16 `query:"uint16"` Int8 int8 `query:"int8"` Uint8 uint8 `query:"uint8"` } c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil) b.ReportAllocs() b.ResetTimer() binder := new(DefaultBinder) for i := 0; i < b.N; i++ { var dest Opts _ = binder.Bind(c, &dest) if dest.Int64 != 1 { b.Fatalf("int64!=1") } } } func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) { type Opts struct { String string `query:"string"` Strings []string `query:"strings"` Int64 int64 `query:"int64"` Uint64 uint64 `query:"uint64"` Int32 int32 `query:"int32"` Uint32 uint32 `query:"uint32"` Int16 int16 `query:"int16"` Uint16 uint16 `query:"uint16"` Int8 int8 `query:"int8"` Uint8 uint8 `query:"uint8"` } c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil) b.ReportAllocs() b.ResetTimer() binder := QueryParamsBinder(c) for i := 0; i < b.N; i++ { var dest Opts _ = binder. Int64("int64", &dest.Int64). Int32("int32", &dest.Int32). Int16("int16", &dest.Int16). Int8("int8", &dest.Int8). String("string", &dest.String). Uint64("int64", &dest.Uint64). Uint32("int32", &dest.Uint32). Uint16("int16", &dest.Uint16). Uint8("int8", &dest.Uint8). Strings("strings", &dest.Strings). 
BindError() if dest.Int64 != 1 { b.Fatalf("int64!=1") } } } func TestValueBinder_TimeError(t *testing.T) { var testCases = []struct { expectValue time.Time name string whenURL string whenLayout string expectError string givenBindErrors []error givenFailFast bool whenMust bool }{ { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\": extra text: \"nope\", field=param", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\": extra text: \"nope\", field=param", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) if tc.givenFailFast { b.errors = []error{errors.New("previous error")} } dest := time.Time{} var err error if tc.whenMust { err = b.MustTime("param", &dest, tc.whenLayout).BindError() } else { err = b.Time("param", &dest, tc.whenLayout).BindError() } assert.Equal(t, tc.expectValue, dest) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValueBinder_TimesError(t *testing.T) { var testCases = []struct { name string whenURL string whenLayout string expectError string givenBindErrors []error expectValue []time.Time givenFailFast bool whenMust bool }{ { name: "nok, fail fast without binding value", givenFailFast: true, whenURL: "/search?param=1¶m=100", expectValue: []time.Time(nil), expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param", }, { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", 
expectValue: []time.Time(nil), expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []time.Time(nil), expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) b.errors = tc.givenBindErrors layout := time.RFC3339 if tc.whenLayout != "" { layout = tc.whenLayout } var dest []time.Time var err error if tc.whenMust { err = b.MustTimes("param", &dest, layout).BindError() } else { err = b.Times("param", &dest, layout).BindError() } assert.Equal(t, tc.expectValue, dest) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValueBinder_DurationError(t *testing.T) { var testCases = []struct { name string whenURL string expectError string givenBindErrors []error expectValue time.Duration givenFailFast bool whenMust bool }{ { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: 0, expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: 0, expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) if 
tc.givenFailFast { b.errors = []error{errors.New("previous error")} } var dest time.Duration var err error if tc.whenMust { err = b.MustDuration("param", &dest).BindError() } else { err = b.Duration("param", &dest).BindError() } assert.Equal(t, tc.expectValue, dest) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValueBinder_DurationsError(t *testing.T) { var testCases = []struct { name string whenURL string expectError string givenBindErrors []error expectValue []time.Duration givenFailFast bool whenMust bool }{ { name: "nok, fail fast without binding value", givenFailFast: true, whenURL: "/search?param=1¶m=100", expectValue: []time.Duration(nil), expectError: "code=400, message=failed to bind field value to Duration, err=time: missing unit in duration \"1\", field=param", }, { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []time.Duration(nil), expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []time.Duration(nil), expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := createTestContext(tc.whenURL, nil, nil) b := QueryParamsBinder(c).FailFast(tc.givenFailFast) b.errors = tc.givenBindErrors var dest []time.Duration var err error if tc.whenMust { err = b.MustDurations("param", &dest).BindError() } else { err = b.Durations("param", &dest).BindError() } assert.Equal(t, tc.expectValue, dest) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } ================================================ FILE: codecov.yml 
================================================ coverage: status: project: default: threshold: 1% patch: default: threshold: 1% comment: require_changes: true ================================================ FILE: context.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import ( "bytes" "encoding/xml" "errors" "fmt" "io" "io/fs" "log/slog" "mime/multipart" "net" "net/http" "net/url" "path" "path/filepath" "strings" "sync" ) const ( // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain. // Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests. // It is added to context only when Router does not find matching method handler for request. ContextKeyHeaderAllow = "echo_header_allow" ) const ( // defaultMemory is default value for memory limit that is used when // parsing multipart forms (See (*http.Request).ParseMultipartForm) defaultMemory int64 = 32 << 20 // 32 MB indexPage = "index.html" ) // Context represents the context of the current HTTP request. It holds request and // response objects, path, path parameters, data and registered handler. type Context struct { request *http.Request orgResponse *Response response http.ResponseWriter query url.Values // formParseMaxMemory is used for http.Request.ParseMultipartForm formParseMaxMemory int64 route *RouteInfo pathValues *PathValues store map[string]any echo *Echo logger *slog.Logger path string lock sync.RWMutex } // NewContext returns a new Context instance. // // Note: request,response and e can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway // these arguments are useful when creating context for tests and cases like that. 
func NewContext(r *http.Request, w http.ResponseWriter, opts ...any) *Context { var e *Echo for _, opt := range opts { switch v := opt.(type) { case *Echo: e = v } } return newContext(r, w, e) } func newContext(r *http.Request, w http.ResponseWriter, e *Echo) *Context { c := &Context{ pathValues: nil, store: make(map[string]any), echo: e, logger: nil, } var logger *slog.Logger paramLen := int32(0) formParseMaxMemory := defaultMemory if e != nil { paramLen = e.contextPathParamAllocSize.Load() logger = e.Logger formParseMaxMemory = e.formParseMaxMemory } if logger == nil { logger = slog.Default() } c.logger = logger p := make(PathValues, 0, paramLen) c.pathValues = &p c.SetRequest(r) c.orgResponse = NewResponse(w, logger) c.response = c.orgResponse c.formParseMaxMemory = formParseMaxMemory return c } // Reset resets the context after request completes. It must be called along // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. // See `Echo#ServeHTTP()` func (c *Context) Reset(r *http.Request, w http.ResponseWriter) { c.request = r c.orgResponse.reset(w) c.response = c.orgResponse c.query = nil c.store = nil c.logger = c.echo.Logger c.route = nil c.path = "" // NOTE: empty by setting length to 0. PathValues has to have capacity of c.echo.contextPathParamAllocSize at all times *c.pathValues = (*c.pathValues)[:0] } func (c *Context) writeContentType(value string) { header := c.response.Header() if header.Get(HeaderContentType) == "" { header.Set(HeaderContentType, value) } } // Request returns `*http.Request`. func (c *Context) Request() *http.Request { return c.request } // SetRequest sets `*http.Request`. func (c *Context) SetRequest(r *http.Request) { c.request = r } // Response returns `*Response`. func (c *Context) Response() http.ResponseWriter { return c.response } // SetResponse sets `*http.ResponseWriter`. 
Some context methods and/or middleware require that given ResponseWriter implements following // method `Unwrap() http.ResponseWriter` which eventually should return *echo.Response instance. func (c *Context) SetResponse(r http.ResponseWriter) { c.response = r } // IsTLS returns true if HTTP connection is TLS otherwise false. func (c *Context) IsTLS() bool { return c.request.TLS != nil } // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. func (c *Context) IsWebSocket() bool { upgrade := c.request.Header.Get(HeaderUpgrade) connection := c.request.Header.Get(HeaderConnection) return strings.EqualFold(upgrade, "websocket") && strings.Contains(strings.ToLower(connection), "upgrade") } // Scheme returns the HTTP protocol scheme, `http` or `https`. func (c *Context) Scheme() string { // Can't use `r.Request.URL.Scheme` // See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0 if c.IsTLS() { return "https" } if scheme := c.request.Header.Get(HeaderXForwardedProto); scheme != "" { return scheme } if scheme := c.request.Header.Get(HeaderXForwardedProtocol); scheme != "" { return scheme } if ssl := c.request.Header.Get(HeaderXForwardedSsl); ssl == "on" { return "https" } if scheme := c.request.Header.Get(HeaderXUrlScheme); scheme != "" { return scheme } return "http" } // RealIP returns the client's network address based on `X-Forwarded-For` // or `X-Real-IP` request header. // The behavior can be configured using `Echo#IPExtractor`. 
func (c *Context) RealIP() string { if c.echo != nil && c.echo.IPExtractor != nil { return c.echo.IPExtractor(c.request) } // Fall back to legacy behavior if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" { i := strings.IndexAny(ip, ",") if i > 0 { xffip := strings.TrimSpace(ip[:i]) xffip = strings.TrimPrefix(xffip, "[") xffip = strings.TrimSuffix(xffip, "]") return xffip } return ip } if ip := c.request.Header.Get(HeaderXRealIP); ip != "" { ip = strings.TrimPrefix(ip, "[") ip = strings.TrimSuffix(ip, "]") return ip } ra, _, _ := net.SplitHostPort(c.request.RemoteAddr) return ra } // Path returns the registered path for the handler. func (c *Context) Path() string { return c.path } // SetPath sets the registered path for the handler. func (c *Context) SetPath(p string) { c.path = p } // RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route. // // RouteInfo returns generic "empty" struct for these cases: // * Context is accessed before Routing is done. For example inside Pre middlewares (`e.Pre()`) // * Router did not find matching route - 404 (route not found) // * Router did not find matching route with same method - 405 (method not allowed) func (c *Context) RouteInfo() RouteInfo { if c.route != nil { return c.route.Clone() } return RouteInfo{} } // Param returns path parameter by name. func (c *Context) Param(name string) string { return c.pathValues.GetOr(name, "") } // ParamOr returns the path parameter or default value for the provided name. 
// // Notes for DefaultRouter implementation: // Path parameter could be empty for cases like that: // * route `/release-:version/bin` and request URL is `/release-/bin` // * route `/api/:version/image.jpg` and request URL is `/api//image.jpg` // but not when path parameter is last part of route path // * route `/download/file.:ext` will not match request `/download/file.` func (c *Context) ParamOr(name, defaultValue string) string { return c.pathValues.GetOr(name, defaultValue) } // PathValues returns path parameter values. func (c *Context) PathValues() PathValues { return *c.pathValues } // SetPathValues sets path parameters for current request. func (c *Context) SetPathValues(pathValues PathValues) { if pathValues == nil { panic("context SetPathValues called with nil PathValues") } c.setPathValues(&pathValues) } // InitializeRoute sets the route related variables of this request to the context. func (c *Context) InitializeRoute(ri *RouteInfo, pathValues *PathValues) { c.route = ri c.path = ri.Path c.setPathValues(pathValues) } func (c *Context) setPathValues(pv *PathValues) { // Router accesses c.pathValues by index and may resize it to full capacity during routing // for that to work without going out-of-bounds we must make sure that c.pathValues slice is not replaced with smaller // slice than Router can set when routing Route with maximum amount of parameters. pathValues := c.pathValues if cap(*c.pathValues) < len(*pv) { // normally we should not end up here. pathValues is normally sized to Echo.contextPathParamAllocSize which should not // be smaller than anything router knows as maximum path parameter count to be. tmp := make(PathValues, len(*pv)) c.pathValues = &tmp pathValues = c.pathValues } else if len(*c.pathValues) != len(*pv) { *pathValues = (*pathValues)[0:len(*pv)] // resize slice to given params length for copy to work } copy(*pathValues, *pv) } // QueryParam returns the query param for the provided name. 
func (c *Context) QueryParam(name string) string { if c.query == nil { c.query = c.request.URL.Query() } return c.query.Get(name) } // QueryParamOr returns the query param or default value for the provided name. // Note: QueryParamOr does not distinguish if query had no value by that name or value was empty string // This means URLs `/test?search=` and `/test` would both return `1` for `c.QueryParamOr("search", "1")` func (c *Context) QueryParamOr(name, defaultValue string) string { value := c.QueryParam(name) if value == "" { value = defaultValue } return value } // QueryParams returns the query parameters as `url.Values`. func (c *Context) QueryParams() url.Values { if c.query == nil { c.query = c.request.URL.Query() } return c.query } // QueryString returns the URL query string. func (c *Context) QueryString() string { return c.request.URL.RawQuery } // FormValue returns the form field value for the provided name. func (c *Context) FormValue(name string) string { return c.request.FormValue(name) } // FormValueOr returns the form field value or default value for the provided name. // Note: FormValueOr does not distinguish if form had no value by that name or value was empty string func (c *Context) FormValueOr(name, defaultValue string) string { value := c.FormValue(name) if value == "" { value = defaultValue } return value } // FormValues returns the form field values as `url.Values`. func (c *Context) FormValues() (url.Values, error) { if strings.HasPrefix(c.request.Header.Get(HeaderContentType), MIMEMultipartForm) { if err := c.request.ParseMultipartForm(c.formParseMaxMemory); err != nil { return nil, err } } else { if err := c.request.ParseForm(); err != nil { return nil, err } } return c.request.Form, nil } // FormFile returns the multipart form file for the provided name. 
func (c *Context) FormFile(name string) (*multipart.FileHeader, error) { f, fh, err := c.request.FormFile(name) if err != nil { return nil, err } _ = f.Close() return fh, nil } // MultipartForm returns the multipart form. func (c *Context) MultipartForm() (*multipart.Form, error) { err := c.request.ParseMultipartForm(c.formParseMaxMemory) return c.request.MultipartForm, err } // Cookie returns the named cookie provided in the request. func (c *Context) Cookie(name string) (*http.Cookie, error) { return c.request.Cookie(name) } // SetCookie adds a `Set-Cookie` header in HTTP response. func (c *Context) SetCookie(cookie *http.Cookie) { http.SetCookie(c.Response(), cookie) } // Cookies returns the HTTP cookies sent with the request. func (c *Context) Cookies() []*http.Cookie { return c.request.Cookies() } // Get retrieves data from the context. // Method returns any(nil) when key does not exist which is different from typed nil (eg. []byte(nil)). func (c *Context) Get(key string) any { c.lock.RLock() defer c.lock.RUnlock() return c.store[key] } // Set saves data in the context. func (c *Context) Set(key string, val any) { c.lock.Lock() defer c.lock.Unlock() if c.store == nil { c.store = make(map[string]any) } c.store[key] = val } // Bind binds path params, query params and the request body into provided type `i`. The default binder // binds body based on Content-Type header. func (c *Context) Bind(i any) error { return c.echo.Binder.Bind(c, i) } // Validate validates provided `i`. It is usually called after `Context#Bind()`. // Validator must be registered using `Echo#Validator`. func (c *Context) Validate(i any) error { if c.echo.Validator == nil { return ErrValidatorNotRegistered } return c.echo.Validator.Validate(i) } // Render renders a template with data and sends a text/html response with status // code. Renderer must be registered using `Echo.Renderer`. 
func (c *Context) Render(code int, name string, data any) (err error) { if c.echo.Renderer == nil { return ErrRendererNotRegistered } // as Renderer.Render can fail, and in that case we need to delay sending status code to the client until // (global) error handler decides the correct status code for the error to be sent to the client, so we need to write // the rendered template to the buffer first. // // html.Template.ExecuteTemplate() documentations writes: // > If an error occurs executing the template or writing its output, // > execution stops, but partial results may already have been written to // > the output writer. buf := new(bytes.Buffer) if err = c.echo.Renderer.Render(c, buf, name, data); err != nil { return } return c.HTMLBlob(code, buf.Bytes()) } // HTML sends an HTTP response with status code. func (c *Context) HTML(code int, html string) (err error) { return c.HTMLBlob(code, []byte(html)) } // HTMLBlob sends an HTTP blob response with status code. func (c *Context) HTMLBlob(code int, b []byte) (err error) { return c.Blob(code, MIMETextHTMLCharsetUTF8, b) } // String sends a string response with status code. func (c *Context) String(code int, s string) (err error) { return c.Blob(code, MIMETextPlainCharsetUTF8, []byte(s)) } func (c *Context) jsonPBlob(code int, callback string, i any) (err error) { c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(callback + "(")); err != nil { return } if err = c.echo.JSONSerializer.Serialize(c, i, ""); err != nil { return } if _, err = c.response.Write([]byte(");")); err != nil { return } return } func (c *Context) json(code int, i any, indent string) error { c.writeContentType(MIMEApplicationJSON) // as JSONSerializer.Serialize can fail, and in that case we need to delay sending status code to the client until // (global) error handler decides correct status code for the error to be sent to the client. 
// For that we need to use writer that can store the proposed status code until the first Write is called. if r, err := UnwrapResponse(c.response); err == nil { r.Status = code } else { resp := c.Response() c.SetResponse(&delayedStatusWriter{ResponseWriter: resp, status: code}) defer c.SetResponse(resp) } return c.echo.JSONSerializer.Serialize(c, i, indent) } // JSON sends a JSON response with status code. func (c *Context) JSON(code int, i any) (err error) { return c.json(code, i, "") } // JSONPretty sends a pretty-print JSON with status code. func (c *Context) JSONPretty(code int, i any, indent string) (err error) { return c.json(code, i, indent) } // JSONBlob sends a JSON blob response with status code. func (c *Context) JSONBlob(code int, b []byte) (err error) { return c.Blob(code, MIMEApplicationJSON, b) } // JSONP sends a JSONP response with status code. It uses `callback` to construct // the JSONP payload. func (c *Context) JSONP(code int, callback string, i any) (err error) { return c.jsonPBlob(code, callback, i) } // JSONPBlob sends a JSONP blob response with status code. It uses `callback` // to construct the JSONP payload. func (c *Context) JSONPBlob(code int, callback string, b []byte) (err error) { c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(callback + "(")); err != nil { return } if _, err = c.response.Write(b); err != nil { return } _, err = c.response.Write([]byte(");")) return } func (c *Context) xml(code int, i any, indent string) (err error) { c.writeContentType(MIMEApplicationXMLCharsetUTF8) c.response.WriteHeader(code) enc := xml.NewEncoder(c.response) if indent != "" { enc.Indent("", indent) } if _, err = c.response.Write([]byte(xml.Header)); err != nil { return } return enc.Encode(i) } // XML sends an XML response with status code. func (c *Context) XML(code int, i any) (err error) { return c.xml(code, i, "") } // XMLPretty sends a pretty-print XML with status code. 
func (c *Context) XMLPretty(code int, i any, indent string) (err error) { return c.xml(code, i, indent) } // XMLBlob sends an XML blob response with status code. func (c *Context) XMLBlob(code int, b []byte) (err error) { c.writeContentType(MIMEApplicationXMLCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(xml.Header)); err != nil { return } _, err = c.response.Write(b) return } // Blob sends a blob response with status code and content type. func (c *Context) Blob(code int, contentType string, b []byte) (err error) { c.writeContentType(contentType) c.response.WriteHeader(code) _, err = c.response.Write(b) return } // Stream sends a streaming response with status code and content type. func (c *Context) Stream(code int, contentType string, r io.Reader) (err error) { c.writeContentType(contentType) c.response.WriteHeader(code) _, err = io.Copy(c.response, r) return } // File sends a response with the content of the file. func (c *Context) File(file string) error { return fsFile(c, file, c.echo.Filesystem) } // FileFS serves file from given file system. // // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths // including `assets/images` as their prefix. func (c *Context) FileFS(file string, filesystem fs.FS) error { return fsFile(c, file, filesystem) } func fsFile(c *Context, file string, filesystem fs.FS) error { file = path.Clean(file) // `os.Open` and `os.DirFs.Open()` behave differently, later does not like ``, `.`, `..` at all, but we allowed those now need to clean f, err := filesystem.Open(file) if err != nil { return ErrNotFound } defer f.Close() fi, _ := f.Stat() if fi.IsDir() { file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect. 
f, err = filesystem.Open(file) if err != nil { return ErrNotFound } defer f.Close() if fi, err = f.Stat(); err != nil { return err } } ff, ok := f.(io.ReadSeeker) if !ok { return errors.New("file does not implement io.ReadSeeker") } http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) return nil } // Attachment sends a response as attachment, prompting client to save the file. func (c *Context) Attachment(file, name string) error { return c.contentDisposition(file, name, "attachment") } // Inline sends a response as inline, opening the file in the browser. func (c *Context) Inline(file, name string) error { return c.contentDisposition(file, name, "inline") } var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") func (c *Context) contentDisposition(file, name, dispositionType string) error { c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf(`%s; filename="%s"`, dispositionType, quoteEscaper.Replace(name))) return c.File(file) } // NoContent sends a response with no body and a status code. func (c *Context) NoContent(code int) error { c.response.WriteHeader(code) return nil } // Redirect redirects the request to a provided URL with status code. func (c *Context) Redirect(code int, url string) error { if code < 300 || code > 308 { return ErrInvalidRedirectCode } c.response.Header().Set(HeaderLocation, url) c.response.WriteHeader(code) return nil } // Logger returns logger in Context func (c *Context) Logger() *slog.Logger { if c.logger != nil { return c.logger } return c.echo.Logger } // SetLogger sets logger in Context func (c *Context) SetLogger(logger *slog.Logger) { c.logger = logger } // Echo returns the `Echo` instance. 
func (c *Context) Echo() *Echo { return c.echo } ================================================ FILE: context_generic.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import "errors" // ErrNonExistentKey is error that is returned when key does not exist var ErrNonExistentKey = errors.New("non existent key") // ErrInvalidKeyType is error that is returned when the value is not castable to expected type. var ErrInvalidKeyType = errors.New("invalid key type") // ContextGet retrieves a value from the context store or ErrNonExistentKey error the key is missing. // Returns ErrInvalidKeyType error if the value is not castable to type T. func ContextGet[T any](c *Context, key string) (T, error) { c.lock.RLock() defer c.lock.RUnlock() val, ok := c.store[key] if !ok { var zero T return zero, ErrNonExistentKey } typed, ok := val.(T) if !ok { var zero T return zero, ErrInvalidKeyType } return typed, nil } // ContextGetOr retrieves a value from the context store or returns a default value when the key // is missing. Returns ErrInvalidKeyType error if the value is not castable to type T. 
func ContextGetOr[T any](c *Context, key string, defaultValue T) (T, error) { typed, err := ContextGet[T](c, key) if err == ErrNonExistentKey { return defaultValue, nil } return typed, err } ================================================ FILE: context_generic_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import ( "testing" "github.com/stretchr/testify/assert" ) func TestContextGetOK(t *testing.T) { c := NewContext(nil, nil) c.Set("key", int64(123)) v, err := ContextGet[int64](c, "key") assert.NoError(t, err) assert.Equal(t, int64(123), v) } func TestContextGetNonExistentKey(t *testing.T) { c := NewContext(nil, nil) c.Set("key", int64(123)) v, err := ContextGet[int64](c, "nope") assert.ErrorIs(t, err, ErrNonExistentKey) assert.Equal(t, int64(0), v) } func TestContextGetInvalidCast(t *testing.T) { c := NewContext(nil, nil) c.Set("key", int64(123)) v, err := ContextGet[bool](c, "key") assert.ErrorIs(t, err, ErrInvalidKeyType) assert.Equal(t, false, v) } func TestContextGetOrOK(t *testing.T) { c := NewContext(nil, nil) c.Set("key", int64(123)) v, err := ContextGetOr[int64](c, "key", 999) assert.NoError(t, err) assert.Equal(t, int64(123), v) } func TestContextGetOrNonExistentKey(t *testing.T) { c := NewContext(nil, nil) c.Set("key", int64(123)) v, err := ContextGetOr[int64](c, "nope", 999) assert.NoError(t, err) assert.Equal(t, int64(999), v) } func TestContextGetOrInvalidCast(t *testing.T) { c := NewContext(nil, nil) c.Set("key", int64(123)) v, err := ContextGetOr[float32](c, "key", float32(999)) assert.ErrorIs(t, err, ErrInvalidKeyType) assert.Equal(t, float32(0), v) } ================================================ FILE: context_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import ( "bytes" "crypto/tls" "encoding/json" 
"encoding/xml" "fmt" "io" "io/fs" "log/slog" "math" "mime/multipart" "net/http" "net/http/httptest" "net/url" "os" "strings" "testing" "text/template" "time" "github.com/stretchr/testify/assert" ) type Template struct { templates *template.Template } var testUser = user{ID: 1, Name: "Jon Snow"} func BenchmarkAllocJSONP(b *testing.B) { e := New() e.Logger = slog.New(slog.DiscardHandler) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { c.JSONP(http.StatusOK, "callback", testUser) } } func BenchmarkAllocJSON(b *testing.B) { e := New() e.Logger = slog.New(slog.DiscardHandler) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { c.JSON(http.StatusOK, testUser) } } func BenchmarkAllocXML(b *testing.B) { e := New() e.Logger = slog.New(slog.DiscardHandler) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { c.XML(http.StatusOK, testUser) } } func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) { c := Context{request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, }} for i := 0; i < b.N; i++ { c.RealIP() } } func (t *Template) Render(c *Context, w io.Writer, name string, data any) error { return t.templates.ExecuteTemplate(w, name, data) } func TestContextEcho(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) assert.Equal(t, e, c.Echo()) } func TestContextRequest(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := 
httptest.NewRecorder() c := e.NewContext(req, rec) assert.NotNil(t, c.Request()) assert.Equal(t, req, c.Request()) } func TestContextResponse(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) assert.NotNil(t, c.Response()) } func TestContextRenderTemplate(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) tmpl := &Template{ templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")), } c.Echo().Renderer = tmpl err := c.Render(http.StatusOK, "hello", "Jon Snow") if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "Hello, Jon Snow!", rec.Body.String()) } } func TestContextRenderTemplateError(t *testing.T) { // we test that when template rendering fails, no response is sent to the client yet, so the global error handler can decide what to do e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) tmpl := &Template{ templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")), } c.Echo().Renderer = tmpl err := c.Render(http.StatusOK, "not_existing", "Jon Snow") assert.EqualError(t, err, `template: no template "not_existing" associated with template "hello"`) assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client assert.Empty(t, rec.Body.String()) // body must not be sent to the client } func TestContextRenderErrorsOnNoRenderer(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) c.Echo().Renderer = nil assert.Error(t, c.Render(http.StatusOK, "hello", "Jon Snow")) } func TestContextStream(t *testing.T) { e := New() rec := httptest.NewRecorder() req := 
httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, rec) r, w := io.Pipe() go func() { defer w.Close() for i := 0; i < 3; i++ { fmt.Fprintf(w, "data: index %v\n\n", i) time.Sleep(5 * time.Millisecond) } }() err := c.Stream(http.StatusOK, "text/event-stream", r) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "text/event-stream", rec.Header().Get(HeaderContentType)) assert.Equal(t, "data: index 0\n\ndata: index 1\n\ndata: index 2\n\n", rec.Body.String()) } } func TestContextHTML(t *testing.T) { rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) c := NewContext(req, rec) err := c.HTML(http.StatusOK, "Hi, Jon Snow") if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, "Hi, Jon Snow", rec.Body.String()) } } func TestContextHTMLBlob(t *testing.T) { rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) c := NewContext(req, rec) err := c.HTMLBlob(http.StatusOK, []byte("Hi, Jon Snow")) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, "Hi, Jon Snow", rec.Body.String()) } } func TestContextJSON(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) c := e.NewContext(req, rec) err := c.JSON(http.StatusOK, user{ID: 1, Name: "Jon Snow"}) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) assert.Equal(t, userJSON+"\n", rec.Body.String()) } } func TestContextJSONErrorsOut(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) c := e.NewContext(req, rec) err := c.JSON(http.StatusOK, make(chan bool)) 
assert.EqualError(t, err, "json: unsupported type: chan bool") assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client assert.Empty(t, rec.Body.String()) // body must not be sent to the client } func TestContextJSONWithNotEchoResponse(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) c := e.NewContext(req, rec) c.SetResponse(rec) err := c.JSON(http.StatusCreated, map[string]float64{"foo": math.NaN()}) assert.EqualError(t, err, "json: unsupported value: NaN") assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client assert.Empty(t, rec.Body.String()) // body must not be sent to the client } func TestContextJSONPretty(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, rec) err := c.JSONPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, " ") if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } } func TestContextJSONWithEmptyIntent(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, rec) u := user{ID: 1, Name: "Jon Snow"} emptyIndent := "" buf := new(bytes.Buffer) enc := json.NewEncoder(buf) enc.SetIndent(emptyIndent, emptyIndent) _ = enc.Encode(u) err := c.JSONPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, emptyIndent) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) assert.Equal(t, buf.String(), rec.Body.String()) } } func TestContextJSONP(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, rec) callback := "callback" err := 
c.JSONP(http.StatusOK, callback, user{ID: 1, Name: "Jon Snow"}) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String()) } } func TestContextJSONBlob(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, rec) data, err := json.Marshal(user{ID: 1, Name: "Jon Snow"}) assert.NoError(t, err) err = c.JSONBlob(http.StatusOK, data) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) assert.Equal(t, userJSON, rec.Body.String()) } } func TestContextJSONPBlob(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, rec) callback := "callback" data, err := json.Marshal(user{ID: 1, Name: "Jon Snow"}) assert.NoError(t, err) err = c.JSONPBlob(http.StatusOK, callback, data) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, callback+"("+userJSON+");", rec.Body.String()) } } func TestContextXML(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, rec) err := c.XML(http.StatusOK, user{ID: 1, Name: "Jon Snow"}) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, xml.Header+userXML, rec.Body.String()) } } func TestContextXMLPretty(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, rec) err := c.XMLPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, " ") if assert.NoError(t, err) { 
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String()) } } func TestContextXMLBlob(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, rec) data, err := xml.Marshal(user{ID: 1, Name: "Jon Snow"}) assert.NoError(t, err) err = c.XMLBlob(http.StatusOK, data) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, xml.Header+userXML, rec.Body.String()) } } func TestContextXMLWithEmptyIntent(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, rec) u := user{ID: 1, Name: "Jon Snow"} emptyIndent := "" buf := new(bytes.Buffer) enc := xml.NewEncoder(buf) enc.Indent(emptyIndent, emptyIndent) _ = enc.Encode(u) err := c.XMLPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, emptyIndent) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, xml.Header+buf.String(), rec.Body.String()) } } func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := c.JSON(http.StatusCreated, user{ID: 1, Name: "Jon Snow"}) if assert.NoError(t, err) { assert.Equal(t, http.StatusCreated, rec.Code) assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) assert.Equal(t, userJSON+"\n", rec.Body.String()) } } func TestContextAttachment(t *testing.T) { var testCases = []struct { name string whenName string expectHeader string }{ { name: "ok", whenName: "walle.png", expectHeader: `attachment; filename="walle.png"`, }, { name: "ok, escape quotes 
in malicious filename", whenName: `malicious.sh"; \"; dummy=.txt`, expectHeader: `attachment; filename="malicious.sh\"; \\\"; dummy=.txt"`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, rec) err := c.Attachment("_fixture/images/walle.png", tc.whenName) if assert.NoError(t, err) { assert.Equal(t, tc.expectHeader, rec.Header().Get(HeaderContentDisposition)) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, 219885, rec.Body.Len()) } }) } } func TestContextInline(t *testing.T) { var testCases = []struct { name string whenName string expectHeader string }{ { name: "ok", whenName: "walle.png", expectHeader: `inline; filename="walle.png"`, }, { name: "ok, escape quotes in malicious filename", whenName: `malicious.sh"; \"; dummy=.txt`, expectHeader: `inline; filename="malicious.sh\"; \\\"; dummy=.txt"`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, rec) err := c.Inline("_fixture/images/walle.png", tc.whenName) if assert.NoError(t, err) { assert.Equal(t, tc.expectHeader, rec.Header().Get(HeaderContentDisposition)) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, 219885, rec.Body.Len()) } }) } } func TestContextNoContent(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) c := e.NewContext(req, rec) c.NoContent(http.StatusOK) assert.Equal(t, http.StatusOK, rec.Code) } func TestContextCookie(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) theme := "theme=light" user := "user=Jon Snow" req.Header.Add(HeaderCookie, theme) req.Header.Add(HeaderCookie, user) rec := httptest.NewRecorder() c := e.NewContext(req, rec) // Read single cookie, err := c.Cookie("theme") if assert.NoError(t, err) { 
assert.Equal(t, "theme", cookie.Name) assert.Equal(t, "light", cookie.Value) } // Read multiple for _, cookie := range c.Cookies() { switch cookie.Name { case "theme": assert.Equal(t, "light", cookie.Value) case "user": assert.Equal(t, "Jon Snow", cookie.Value) } } // Write cookie = &http.Cookie{ Name: "SSID", Value: "Ap4PGTEq", Domain: "labstack.com", Path: "/", Expires: time.Now(), Secure: true, HttpOnly: true, } c.SetCookie(cookie) assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "labstack.com") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Secure") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly") } func TestContext_PathValues(t *testing.T) { var testCases = []struct { name string given PathValues expect PathValues }{ { name: "param exists", given: PathValues{ {Name: "uid", Value: "101"}, {Name: "fid", Value: "501"}, }, expect: PathValues{ {Name: "uid", Value: "101"}, {Name: "fid", Value: "501"}, }, }, { name: "params is empty", given: PathValues{}, expect: PathValues{}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, nil) c.SetPathValues(tc.given) assert.EqualValues(t, tc.expect, c.PathValues()) }) } } func TestContext_PathParam(t *testing.T) { var testCases = []struct { name string given PathValues whenParamName string expect string }{ { name: "param exists", given: PathValues{ {Name: "uid", Value: "101"}, {Name: "fid", Value: "501"}, }, whenParamName: "uid", expect: "101", }, { name: "multiple same param values exists - return first", given: PathValues{ {Name: "uid", Value: "101"}, {Name: "uid", Value: "202"}, {Name: "fid", Value: "501"}, }, whenParamName: "uid", expect: "101", }, { name: "param does not exists", given: PathValues{ {Name: "uid", Value: "101"}, }, whenParamName: "nope", expect: 
"", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, nil) c.SetPathValues(tc.given) assert.EqualValues(t, tc.expect, c.Param(tc.whenParamName)) }) } } func TestContext_PathParamDefault(t *testing.T) { var testCases = []struct { name string given PathValues whenParamName string whenDefaultValue string expect string }{ { name: "param exists", given: PathValues{ {Name: "uid", Value: "101"}, {Name: "fid", Value: "501"}, }, whenParamName: "uid", whenDefaultValue: "999", expect: "101", }, { name: "param exists and is empty", given: PathValues{ {Name: "uid", Value: ""}, {Name: "fid", Value: "501"}, }, whenParamName: "uid", whenDefaultValue: "999", expect: "", // <-- this is different from QueryParamOr behaviour }, { name: "param does not exists", given: PathValues{ {Name: "uid", Value: "101"}, }, whenParamName: "nope", whenDefaultValue: "999", expect: "999", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, nil) c.SetPathValues(tc.given) assert.EqualValues(t, tc.expect, c.ParamOr(tc.whenParamName, tc.whenDefaultValue)) }) } } func TestContextGetAndSetPathValuesMutability(t *testing.T) { t.Run("c.PathValues() does not return copy and modifying raw slice mutates value in context", func(t *testing.T) { e := New() e.contextPathParamAllocSize.Store(1) req := httptest.NewRequest(http.MethodGet, "/:foo", nil) c := e.NewContext(req, nil) params := PathValues{{Name: "foo", Value: "101"}} c.SetPathValues(params) // round-trip param values with modification paramVals := c.PathValues() assert.Equal(t, params, c.PathValues()) // PathValues() does not return copy and modifying raw slice mutates value in context paramVals[0] = PathValue{Name: "xxx", Value: "yyy"} assert.Equal(t, PathValues{PathValue{Name: "xxx", Value: "yyy"}}, c.PathValues()) }) t.Run("calling 
SetPathValues with bigger size changes capacity in context", func(t *testing.T) { e := New() e.contextPathParamAllocSize.Store(1) req := httptest.NewRequest(http.MethodGet, "/:foo", nil) c := e.NewContext(req, nil) // increase path param capacity in context pathValues := PathValues{ {Name: "aaa", Value: "bbb"}, {Name: "ccc", Value: "ddd"}, } c.SetPathValues(pathValues) assert.Equal(t, pathValues, c.PathValues()) // shouldn't explode during Reset() afterwards! assert.NotPanics(t, func() { c.Reset(nil, nil) }) assert.Equal(t, PathValues{}, c.PathValues()) assert.Len(t, *c.pathValues, 0) assert.Equal(t, 2, cap(*c.pathValues)) }) t.Run("calling SetPathValues with smaller size slice does not change capacity in context", func(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/:foo", nil) c := e.NewContext(req, nil) c.pathValues = &PathValues{ {Name: "aaa", Value: "bbb"}, {Name: "ccc", Value: "ddd"}, } pathValues := PathValues{ {Name: "aaa", Value: "bbb"}, } // given pathValues slice is smaller. this should not decrease c.pathValues capacity c.SetPathValues(pathValues) assert.Equal(t, pathValues, c.PathValues()) // shouldn't explode during Reset() afterwards! 
assert.NotPanics(t, func() { c.Reset(nil, nil) }) assert.Equal(t, PathValues{}, c.PathValues()) assert.Len(t, *c.pathValues, 0) assert.Equal(t, 2, cap(*c.pathValues)) }) } // Issue #1655 func TestContext_SetParamNamesShouldNotModifyPathValuesCapacity(t *testing.T) { e := New() c := e.NewContext(nil, nil) assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load()) expectedTwoParams := PathValues{ {Name: "1", Value: "one"}, {Name: "2", Value: "two"}, } c.SetPathValues(expectedTwoParams) assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load()) assert.Equal(t, expectedTwoParams, c.PathValues()) expectedThreeParams := PathValues{ {Name: "1", Value: "one"}, {Name: "2", Value: "two"}, {Name: "3", Value: "three"}, } c.SetPathValues(expectedThreeParams) assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load()) assert.Equal(t, expectedThreeParams, c.PathValues()) } func TestContextFormValue(t *testing.T) { f := make(url.Values) f.Set("name", "Jon Snow") f.Set("email", "jon@labstack.com") e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) req.Header.Add(HeaderContentType, MIMEApplicationForm) c := e.NewContext(req, nil) // FormValue assert.Equal(t, "Jon Snow", c.FormValue("name")) assert.Equal(t, "jon@labstack.com", c.FormValue("email")) // FormValueOr assert.Equal(t, "Jon Snow", c.FormValueOr("name", "nope")) assert.Equal(t, "default", c.FormValueOr("missing", "default")) // FormValues values, err := c.FormValues() if assert.NoError(t, err) { assert.Equal(t, url.Values{ "name": []string{"Jon Snow"}, "email": []string{"jon@labstack.com"}, }, values) } // Multipart FormParams error req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) req.Header.Add(HeaderContentType, MIMEMultipartForm) c = e.NewContext(req, nil) values, err = c.FormValues() assert.Nil(t, values) assert.Error(t, err) } func TestContext_QueryParams(t *testing.T) { var testCases = []struct { expect url.Values name string givenURL string 
}{ { name: "multiple values in url", givenURL: "/?test=1&test=2&email=jon%40labstack.com", expect: url.Values{ "test": []string{"1", "2"}, "email": []string{"jon@labstack.com"}, }, }, { name: "single value in url", givenURL: "/?nope=1", expect: url.Values{ "nope": []string{"1"}, }, }, { name: "no query params in url", givenURL: "/?", expect: url.Values{}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) e := New() c := e.NewContext(req, nil) assert.Equal(t, tc.expect, c.QueryParams()) }) } } func TestContext_QueryParam(t *testing.T) { var testCases = []struct { name string givenURL string whenParamName string expect string }{ { name: "value exists in url", givenURL: "/?test=1", whenParamName: "test", expect: "1", }, { name: "multiple values exists in url", givenURL: "/?test=9&test=8", whenParamName: "test", expect: "9", // <-- first value in returned }, { name: "value does not exists in url", givenURL: "/?nope=1", whenParamName: "test", expect: "", }, { name: "value is empty in url", givenURL: "/?test=", whenParamName: "test", expect: "", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) e := New() c := e.NewContext(req, nil) assert.Equal(t, tc.expect, c.QueryParam(tc.whenParamName)) }) } } func TestContext_QueryParamDefault(t *testing.T) { var testCases = []struct { name string givenURL string whenParamName string whenDefaultValue string expect string }{ { name: "value exists in url", givenURL: "/?test=1", whenParamName: "test", whenDefaultValue: "999", expect: "1", }, { name: "value does not exists in url", givenURL: "/?nope=1", whenParamName: "test", whenDefaultValue: "999", expect: "999", }, { name: "value is empty in url", givenURL: "/?test=", whenParamName: "test", whenDefaultValue: "999", expect: "999", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := 
httptest.NewRequest(http.MethodGet, tc.givenURL, nil) e := New() c := e.NewContext(req, nil) assert.Equal(t, tc.expect, c.QueryParamOr(tc.whenParamName, tc.whenDefaultValue)) }) } } func TestContextFormFile(t *testing.T) { e := New() buf := new(bytes.Buffer) mr := multipart.NewWriter(buf) w, err := mr.CreateFormFile("file", "test") if assert.NoError(t, err) { w.Write([]byte("test")) } mr.Close() req := httptest.NewRequest(http.MethodPost, "/", buf) req.Header.Set(HeaderContentType, mr.FormDataContentType()) rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.FormFile("file") if assert.NoError(t, err) { assert.Equal(t, "test", f.Filename) } } func TestContextMultipartForm(t *testing.T) { e := New() buf := new(bytes.Buffer) mw := multipart.NewWriter(buf) mw.WriteField("name", "Jon Snow") fileContent := "This is a test file" w, err := mw.CreateFormFile("file", "test.txt") if assert.NoError(t, err) { w.Write([]byte(fileContent)) } mw.Close() req := httptest.NewRequest(http.MethodPost, "/", buf) req.Header.Set(HeaderContentType, mw.FormDataContentType()) rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.MultipartForm() if assert.NoError(t, err) { assert.NotNil(t, f) files := f.File["file"] if assert.Len(t, files, 1) { file := files[0] assert.Equal(t, "test.txt", file.Filename) assert.Equal(t, int64(len(fileContent)), file.Size) } } } func TestContextRedirect(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) assert.Equal(t, http.StatusMovedPermanently, rec.Code) assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) } func TestContextGet(t *testing.T) { var testCases = []struct { name string given any whenKey string expect any }{ { name: "ok, value exist", 
given: "Jon Snow", whenKey: "key", expect: "Jon Snow", }, { name: "ok, value does not exist", given: "Jon Snow", whenKey: "nope", expect: nil, }, { name: "ok, value is nil value", given: []byte(nil), whenKey: "key", expect: []byte(nil), }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var c = new(Context) c.Set("key", tc.given) v := c.Get(tc.whenKey) assert.Equal(t, tc.expect, v) }) } } func BenchmarkContext_Store(b *testing.B) { e := &Echo{} c := &Context{ echo: e, } for n := 0; n < b.N; n++ { c.Set("name", "Jon Snow") if c.Get("name") != "Jon Snow" { b.Fail() } } } type validator struct{} func (*validator) Validate(i any) error { return nil } func TestContext_Validate(t *testing.T) { e := New() c := e.NewContext(nil, nil) assert.Error(t, c.Validate(struct{}{})) e.Validator = &validator{} assert.NoError(t, c.Validate(struct{}{})) } func TestContext_QueryString(t *testing.T) { e := New() queryString := "query=string&var=val" req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil) c := e.NewContext(req, nil) assert.Equal(t, queryString, c.QueryString()) } func TestContext_Request(t *testing.T) { var c = new(Context) assert.Nil(t, c.Request()) req := httptest.NewRequest(http.MethodGet, "/path", nil) c.SetRequest(req) assert.Equal(t, req, c.Request()) } func TestContext_Scheme(t *testing.T) { tests := []struct { c *Context s string }{ { &Context{ request: &http.Request{ TLS: &tls.ConnectionState{}, }, }, "https", }, { &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedProto: []string{"https"}}, }, }, "https", }, { &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedProtocol: []string{"http"}}, }, }, "http", }, { &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedSsl: []string{"on"}}, }, }, "https", }, { &Context{ request: &http.Request{ Header: http.Header{HeaderXUrlScheme: []string{"https"}}, }, }, "https", }, { &Context{ request: &http.Request{}, }, "http", }, } for _, tt := 
range tests { assert.Equal(t, tt.s, tt.c.Scheme()) } } func TestContext_IsWebSocket(t *testing.T) { tests := []struct { c *Context ws assert.BoolAssertionFunc }{ { &Context{ request: &http.Request{ Header: http.Header{ HeaderUpgrade: []string{"websocket"}, HeaderConnection: []string{"upgrade"}, }, }, }, assert.True, }, { &Context{ request: &http.Request{ Header: http.Header{ HeaderUpgrade: []string{"Websocket"}, HeaderConnection: []string{"Upgrade"}, }, }, }, assert.True, }, { &Context{ request: &http.Request{}, }, assert.False, }, { &Context{ request: &http.Request{ Header: http.Header{HeaderUpgrade: []string{"other"}}, }, }, assert.False, }, { &Context{ request: &http.Request{ Header: http.Header{ HeaderUpgrade: []string{"websocket"}, HeaderConnection: []string{"close"}, }, }, }, assert.False, }, } for i, tt := range tests { t.Run(fmt.Sprintf("test %d", i+1), func(t *testing.T) { tt.ws(t, tt.c.IsWebSocket()) }) } } func TestContext_Bind(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) c := e.NewContext(req, nil) u := new(user) req.Header.Add(HeaderContentType, MIMEApplicationJSON) err := c.Bind(u) assert.NoError(t, err) assert.Equal(t, &user{ID: 1, Name: "Jon Snow"}, u) } func TestContext_RealIP(t *testing.T) { tests := []struct { c *Context s string }{ { &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, }, }, "127.0.0.1", }, { &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1,127.0.1.1"}}, }, }, "127.0.0.1", }, { &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1"}}, }, }, "127.0.0.1", }, { &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348], 2001:db8::1, "}}, }, }, "2001:db8:85a3:8d3:1319:8a2e:370:7348", }, { &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: 
[]string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348],[2001:db8::1]"}}, }, }, "2001:db8:85a3:8d3:1319:8a2e:370:7348", }, { &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"2001:db8:85a3:8d3:1319:8a2e:370:7348"}}, }, }, "2001:db8:85a3:8d3:1319:8a2e:370:7348", }, { &Context{ request: &http.Request{ Header: http.Header{ "X-Real-Ip": []string{"192.168.0.1"}, }, }, }, "192.168.0.1", }, { &Context{ request: &http.Request{ Header: http.Header{ "X-Real-Ip": []string{"[2001:db8::1]"}, }, }, }, "2001:db8::1", }, { &Context{ request: &http.Request{ RemoteAddr: "89.89.89.89:1654", }, }, "89.89.89.89", }, } for _, tt := range tests { assert.Equal(t, tt.s, tt.c.RealIP()) } } func TestContext_File(t *testing.T) { var testCases = []struct { whenFS fs.FS name string whenFile string expectError string expectStartsWith []byte expectStatus int }{ { name: "ok, from default file system", whenFile: "_fixture/images/walle.png", whenFS: nil, expectStatus: http.StatusOK, expectStartsWith: []byte{0x89, 0x50, 0x4e}, }, { name: "ok, from custom file system", whenFile: "walle.png", whenFS: os.DirFS("_fixture/images"), expectStatus: http.StatusOK, expectStartsWith: []byte{0x89, 0x50, 0x4e}, }, { name: "nok, not existent file", whenFile: "not.png", whenFS: os.DirFS("_fixture/images"), expectStatus: http.StatusOK, expectStartsWith: nil, expectError: "Not Found", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() if tc.whenFS != nil { e.Filesystem = tc.whenFS } handler := func(ec *Context) error { return ec.File(tc.whenFile) } req := httptest.NewRequest(http.MethodGet, "/match.png", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := handler(c) assert.Equal(t, tc.expectStatus, rec.Code) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } body := rec.Body.Bytes() if len(body) > len(tc.expectStartsWith) { body = body[:len(tc.expectStartsWith)] } assert.Equal(t, 
tc.expectStartsWith, body) }) } } func TestContext_FileFS(t *testing.T) { var testCases = []struct { whenFS fs.FS name string whenFile string expectError string expectStartsWith []byte expectStatus int }{ { name: "ok", whenFile: "walle.png", whenFS: os.DirFS("_fixture/images"), expectStatus: http.StatusOK, expectStartsWith: []byte{0x89, 0x50, 0x4e}, }, { name: "nok, not existent file", whenFile: "not.png", whenFS: os.DirFS("_fixture/images"), expectStatus: http.StatusOK, expectStartsWith: nil, expectError: "Not Found", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() handler := func(ec *Context) error { return ec.FileFS(tc.whenFile, tc.whenFS) } req := httptest.NewRequest(http.MethodGet, "/match.png", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := handler(c) assert.Equal(t, tc.expectStatus, rec.Code) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } body := rec.Body.Bytes() if len(body) > len(tc.expectStartsWith) { body = body[:len(tc.expectStartsWith)] } assert.Equal(t, tc.expectStartsWith, body) }) } } func TestLogger(t *testing.T) { e := New() c := e.NewContext(nil, nil) log1 := c.Logger() assert.NotNil(t, log1) assert.Equal(t, e.Logger, log1) customLogger := slog.New(slog.NewTextHandler(os.Stdout, nil)) c.SetLogger(customLogger) assert.Equal(t, customLogger, c.Logger()) // Resetting the context returns the initial Echo logger c.Reset(nil, nil) assert.Equal(t, e.Logger, c.Logger()) } func TestRouteInfo(t *testing.T) { e := New() c := e.NewContext(nil, nil) orgRI := RouteInfo{ Name: "root", Method: http.MethodGet, Path: "/*", Parameters: []string{"*"}, } c.route = &orgRI ri := c.RouteInfo() assert.Equal(t, orgRI, ri) // Test mutability when middlewares start to change things // RouteInfo inside context will not be affected when returned instance is changed expect := orgRI.Clone() ri.Path = "changed" ri.Parameters[0] = "changed" assert.Equal(t, expect, 
c.RouteInfo()) // RouteInfo inside context will not be affected when returned instance is changed expect = c.RouteInfo() orgRI.Name = "changed" assert.NotEqual(t, expect, c.RouteInfo()) } ================================================ FILE: echo.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors /* Package echo implements high performance, minimalist Go web framework. Example: package main import ( "log/slog" "net/http" "github.com/labstack/echo/v5" "github.com/labstack/echo/v5/middleware" ) // Handler func hello(c *echo.Context) error { return c.String(http.StatusOK, "Hello, World!") } func main() { // Echo instance e := echo.New() // Middleware e.Use(middleware.RequestLogger()) e.Use(middleware.Recover()) // Routes e.GET("/", hello) // Start server if err := e.Start(":8080"); err != nil { slog.Error("failed to start server", "error", err) } } Learn more at https://echo.labstack.com */ package echo import ( stdContext "context" "encoding/json" "errors" "fmt" "io/fs" "log/slog" "net/http" "net/url" "os" "os/signal" "path/filepath" "strings" "sync" "sync/atomic" "syscall" ) // Echo is the top-level framework instance. // // Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these // fields from handlers/middlewares and changing field values at the same time leads to data-races. // Same rule applies to adding new routes after server has been started - Adding a route is not Goroutine safe action. 
type Echo struct { serveHTTPFunc func(http.ResponseWriter, *http.Request) Binder Binder Filesystem fs.FS Renderer Renderer Validator Validator JSONSerializer JSONSerializer IPExtractor IPExtractor OnAddRoute func(route Route) error HTTPErrorHandler HTTPErrorHandler Logger *slog.Logger contextPool sync.Pool router Router // premiddleware are middlewares that are called before routing is done premiddleware []MiddlewareFunc // middleware are middlewares that are called after routing is done and before handler is called middleware []MiddlewareFunc contextPathParamAllocSize atomic.Int32 // formParseMaxMemory is passed to Context for multipart form parsing (See http.Request.ParseMultipartForm) formParseMaxMemory int64 } // JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. type JSONSerializer interface { Serialize(c *Context, target any, indent string) error Deserialize(c *Context, target any) error } // HTTPErrorHandler is a centralized HTTP error handler. type HTTPErrorHandler func(c *Context, err error) // HandlerFunc defines a function to serve HTTP requests. type HandlerFunc func(c *Context) error // MiddlewareFunc defines a function to process middleware. type MiddlewareFunc func(next HandlerFunc) HandlerFunc // MiddlewareConfigurator defines interface for creating middleware handlers with possibility to return configuration errors instead of panicking. type MiddlewareConfigurator interface { ToMiddleware() (MiddlewareFunc, error) } // Validator is the interface that wraps the Validate function. type Validator interface { Validate(i any) error } // MIME types const ( // MIMEApplicationJSON JavaScript Object Notation (JSON) https://www.rfc-editor.org/rfc/rfc8259 MIMEApplicationJSON = "application/json" // Deprecated: Please use MIMEApplicationJSON instead. JSON should be encoded using UTF-8 by default. // No "charset" parameter is defined for this registration. // Adding one really has no effect on compliant recipients. 
// See RFC 8259, section 8.1. https://datatracker.ietf.org/doc/html/rfc8259#section-8.1n" MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8 MIMEApplicationJavaScript = "application/javascript" MIMEApplicationJavaScriptCharsetUTF8 = MIMEApplicationJavaScript + "; " + charsetUTF8 MIMEApplicationXML = "application/xml" MIMEApplicationXMLCharsetUTF8 = MIMEApplicationXML + "; " + charsetUTF8 MIMETextXML = "text/xml" MIMETextXMLCharsetUTF8 = MIMETextXML + "; " + charsetUTF8 MIMEApplicationForm = "application/x-www-form-urlencoded" MIMEApplicationProtobuf = "application/protobuf" MIMEApplicationMsgpack = "application/msgpack" MIMETextHTML = "text/html" MIMETextHTMLCharsetUTF8 = MIMETextHTML + "; " + charsetUTF8 MIMETextPlain = "text/plain" MIMETextPlainCharsetUTF8 = MIMETextPlain + "; " + charsetUTF8 MIMEMultipartForm = "multipart/form-data" MIMEOctetStream = "application/octet-stream" ) const ( charsetUTF8 = "charset=UTF-8" // PROPFIND Method can be used on collection and property resources. PROPFIND = "PROPFIND" // REPORT Method can be used to get information about a resource, see rfc 3253 REPORT = "REPORT" // RouteNotFound is special method type for routes handling "route not found" (404) cases RouteNotFound = "echo_route_not_found" // RouteAny is special method type that matches any HTTP method in request. Any has lower // priority that other methods that have been registered with Router to that path. RouteAny = "echo_route_any" ) // Headers const ( HeaderAccept = "Accept" HeaderAcceptEncoding = "Accept-Encoding" // HeaderAllow is the name of the "Allow" header field used to list the set of methods // advertised as supported by the target resource. Returning an Allow header is mandatory // for status 405 (method not found) and useful for the OPTIONS method in responses. 
// See RFC 7231: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1 HeaderAllow = "Allow" HeaderAuthorization = "Authorization" HeaderContentDisposition = "Content-Disposition" HeaderContentEncoding = "Content-Encoding" HeaderContentLength = "Content-Length" HeaderContentType = "Content-Type" HeaderCookie = "Cookie" HeaderSetCookie = "Set-Cookie" HeaderIfModifiedSince = "If-Modified-Since" HeaderLastModified = "Last-Modified" HeaderLocation = "Location" HeaderRetryAfter = "Retry-After" HeaderUpgrade = "Upgrade" HeaderVary = "Vary" HeaderWWWAuthenticate = "WWW-Authenticate" HeaderXForwardedFor = "X-Forwarded-For" HeaderXForwardedProto = "X-Forwarded-Proto" HeaderXForwardedProtocol = "X-Forwarded-Protocol" HeaderXForwardedSsl = "X-Forwarded-Ssl" HeaderXUrlScheme = "X-Url-Scheme" HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" HeaderXRealIP = "X-Real-Ip" HeaderXRequestID = "X-Request-Id" HeaderXCorrelationID = "X-Correlation-Id" HeaderXRequestedWith = "X-Requested-With" HeaderServer = "Server" // HeaderOrigin request header indicates the origin (scheme, hostname, and port) that caused the request. 
// See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin HeaderOrigin = "Origin" HeaderCacheControl = "Cache-Control" HeaderConnection = "Connection" // Access control HeaderAccessControlRequestMethod = "Access-Control-Request-Method" HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers" HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin" HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods" HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers" HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials" HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers" HeaderAccessControlMaxAge = "Access-Control-Max-Age" // Security HeaderStrictTransportSecurity = "Strict-Transport-Security" HeaderXContentTypeOptions = "X-Content-Type-Options" HeaderXXSSProtection = "X-XSS-Protection" HeaderXFrameOptions = "X-Frame-Options" HeaderContentSecurityPolicy = "Content-Security-Policy" HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only" HeaderXCSRFToken = "X-CSRF-Token" // #nosec G101 HeaderReferrerPolicy = "Referrer-Policy" // HeaderSecFetchSite fetch metadata request header indicates the relationship between a request initiator's // origin and the origin of the requested resource. // See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Sec-Fetch-Site HeaderSecFetchSite = "Sec-Fetch-Site" ) // Config is configuration for NewWithConfig function type Config struct { // Logger is the slog logger instance used for application-wide structured logging. // If not set, a default TextHandler writing to stdout is created. Logger *slog.Logger // HTTPErrorHandler is the centralized error handler that processes errors returned // by handlers and middleware, converting them to appropriate HTTP responses. // If not set, DefaultHTTPErrorHandler(false) is used. 
HTTPErrorHandler HTTPErrorHandler // Router is the HTTP request router responsible for matching URLs to handlers // using a radix tree-based algorithm. // If not set, NewRouter(RouterConfig{}) is used. Router Router // OnAddRoute is an optional callback hook executed when routes are registered. // Useful for route validation, logging, or custom route processing. // If not set, no callback is executed. OnAddRoute func(route Route) error // Filesystem is the fs.FS implementation used for serving static files. // Supports os.DirFS, embed.FS, and custom implementations. // If not set, defaults to current working directory. Filesystem fs.FS // Binder handles automatic data binding from HTTP requests to Go structs. // Supports JSON, XML, form data, query parameters, and path parameters. // If not set, DefaultBinder is used. Binder Binder // Validator provides optional struct validation after data binding. // Commonly used with third-party validation libraries. // If not set, Context.Validate() returns ErrValidatorNotRegistered. Validator Validator // Renderer provides template rendering for generating HTML responses. // Requires integration with a template engine like html/template. // If not set, Context.Render() returns ErrRendererNotRegistered. Renderer Renderer // JSONSerializer handles JSON encoding and decoding for HTTP requests/responses. // Can be replaced with faster alternatives like jsoniter or sonic. // If not set, DefaultJSONSerializer using encoding/json is used. JSONSerializer JSONSerializer // IPExtractor defines the strategy for extracting the real client IP address // from requests, particularly important when behind proxies or load balancers. // Used for rate limiting, access control, and logging. // If not set, falls back to checking X-Forwarded-For and X-Real-IP headers. 
IPExtractor IPExtractor // FormParseMaxMemory is default value for memory limit that is used // when parsing multipart forms (See (*http.Request).ParseMultipartForm) FormParseMaxMemory int64 } // NewWithConfig creates an instance of Echo with given configuration. func NewWithConfig(config Config) *Echo { e := New() if config.Logger != nil { e.Logger = config.Logger } if config.HTTPErrorHandler != nil { e.HTTPErrorHandler = config.HTTPErrorHandler } if config.Router != nil { e.router = config.Router } if config.OnAddRoute != nil { e.OnAddRoute = config.OnAddRoute } if config.Filesystem != nil { e.Filesystem = config.Filesystem } if config.Binder != nil { e.Binder = config.Binder } if config.Validator != nil { e.Validator = config.Validator } if config.Renderer != nil { e.Renderer = config.Renderer } if config.JSONSerializer != nil { e.JSONSerializer = config.JSONSerializer } if config.IPExtractor != nil { e.IPExtractor = config.IPExtractor } if config.FormParseMaxMemory > 0 { e.formParseMaxMemory = config.FormParseMaxMemory } return e } // New creates an instance of Echo. func New() *Echo { logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) e := &Echo{ Logger: logger, Filesystem: newDefaultFS(), Binder: &DefaultBinder{}, JSONSerializer: &DefaultJSONSerializer{}, formParseMaxMemory: defaultMemory, } e.serveHTTPFunc = e.serveHTTP e.router = NewRouter(RouterConfig{}) e.HTTPErrorHandler = DefaultHTTPErrorHandler(false) e.contextPool.New = func() any { return newContext(nil, nil, e) } return e } // NewContext returns a new Context instance. // // Note: both request and response can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway // these arguments are useful when creating context for tests and cases like that. func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) *Context { return newContext(r, w, e) } // Router returns the default router. 
func (e *Echo) Router() Router { return e.router } // DefaultHTTPErrorHandler creates new default HTTP error handler implementation. It sends a JSON response // with status code. `exposeError` parameter decides if returned message will contain also error message or not // // Note: DefaultHTTPErrorHandler does not log errors. Use middleware for it if errors need to be logged (separately) // Note: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). // When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from // handler. Then the error that global error handler received will be ignored because we have already "committed" the // response and status code header has been sent to the client. func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler { return func(c *Context, err error) { if r, _ := UnwrapResponse(c.response); r != nil && r.Committed { return } code := http.StatusInternalServerError var sc HTTPStatusCoder if errors.As(err, &sc) { if tmp := sc.StatusCode(); tmp != 0 { code = tmp } } var result any switch m := sc.(type) { case json.Marshaler: // this type knows how to format itself to JSON result = m case *HTTPError: sText := m.Message if sText == "" { sText = http.StatusText(code) } msg := map[string]any{"message": sText} if exposeError { if wrappedErr := m.Unwrap(); wrappedErr != nil { msg["error"] = wrappedErr.Error() } } result = msg default: msg := map[string]any{"message": http.StatusText(code)} if exposeError { msg["error"] = err.Error() } result = msg } var cErr error if c.Request().Method == http.MethodHead { // Issue #608 cErr = c.NoContent(code) } else { cErr = c.JSON(code, result) } if cErr != nil { c.Logger().Error("echo default error handler failed to send error to client", "error", cErr) // truly rare case. 
ala client already disconnected } } } // Pre adds middleware to the chain which is run before router tries to find matching route. // Meaning middleware is executed even for 404 (not found) cases. func (e *Echo) Pre(middleware ...MiddlewareFunc) { e.premiddleware = append(e.premiddleware, middleware...) } // Use adds middleware to the chain which is run after router has found matching route and before route/request handler method is executed. func (e *Echo) Use(middleware ...MiddlewareFunc) { e.middleware = append(e.middleware, middleware...) } // CONNECT registers a new CONNECT route for a path with matching handler in the // router with optional route-level middleware. Panics on error. func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodConnect, path, h, m...) } // DELETE registers a new DELETE route for a path with matching handler in the router // with optional route-level middleware. Panics on error. func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodDelete, path, h, m...) } // GET registers a new GET route for a path with matching handler in the router // with optional route-level middleware. Panics on error. func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodGet, path, h, m...) } // HEAD registers a new HEAD route for a path with matching handler in the // router with optional route-level middleware. Panics on error. func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodHead, path, h, m...) } // OPTIONS registers a new OPTIONS route for a path with matching handler in the // router with optional route-level middleware. Panics on error. func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodOptions, path, h, m...) 
} // PATCH registers a new PATCH route for a path with matching handler in the // router with optional route-level middleware. Panics on error. func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPatch, path, h, m...) } // POST registers a new POST route for a path with matching handler in the // router with optional route-level middleware. Panics on error. func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPost, path, h, m...) } // PUT registers a new PUT route for a path with matching handler in the // router with optional route-level middleware. Panics on error. func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPut, path, h, m...) } // TRACE registers a new TRACE route for a path with matching handler in the // router with optional route-level middleware. Panics on error. func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodTrace, path, h, m...) } // RouteNotFound registers a special-case route which is executed when no other route is found (i.e. HTTP 404 cases) // for current request URL. // Path supports static and named/any parameters just like other http method is defined. Generally path is ended with // wildcard/match-any character (`/*`, `/download/*` etc). // // Example: `e.RouteNotFound("/*", func(c *echo.Context) error { return c.NoContent(http.StatusNotFound) })` func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(RouteNotFound, path, h, m...) } // Any registers a new route for all HTTP methods (supported by Echo) and path with matching handler // in the router with optional route-level middleware. // // Note: this method only adds specific set of supported HTTP methods as handler and is not true // "catch-any-arbitrary-method" way of matching requests. 
func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { return e.Add(RouteAny, path, handler, middleware...) } // Match registers a new route for multiple HTTP methods and path with matching // handler in the router with optional route-level middleware. Panics on error. func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { errs := make([]error, 0) ris := make(Routes, 0) for _, m := range methods { ri, err := e.AddRoute(Route{ Method: m, Path: path, Handler: handler, Middlewares: middleware, }) if err != nil { errs = append(errs, err) continue } ris = append(ris, ri) } if len(errs) > 0 { panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage } return ris } // Static registers a new route with path prefix to serve static files from the provided root directory. func (e *Echo) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo { subFs := MustSubFS(e.Filesystem, fsRoot) return e.Add( http.MethodGet, pathPrefix+"*", StaticDirectoryHandler(subFs, false), middleware..., ) } // StaticFS registers a new route with path prefix to serve static files from the provided file system. // // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths // including `assets/images` as their prefix. func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo { return e.Add( http.MethodGet, pathPrefix+"*", StaticDirectoryHandler(filesystem, false), middleware..., ) } // StaticDirectoryHandler creates handler function to serve files from provided file system // When disablePathUnescaping is set then file name from path is not unescaped and is served as is. 
func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc { return func(c *Context) error { p := c.Param("*") if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice tmpPath, err := url.PathUnescape(p) if err != nil { return fmt.Errorf("failed to unescape path variable: %w", err) } p = tmpPath } // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/"))) fi, err := fs.Stat(fileSystem, name) if err != nil { return ErrNotFound } // If the request is for a directory and does not end with "/" p = c.Request().URL.Path // path must not be empty. if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { // Redirect to ends with "/" return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/")) } return fsFile(c, name, fileSystem) } } // FileFS registers a new route with path to serve file from the provided file system. func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo { return e.GET(path, StaticFileHandler(file, filesystem), m...) } // StaticFileHandler creates handler function to serve file from provided file system func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { return func(c *Context) error { return fsFile(c, file, filesystem) } } // File registers a new route with path to serve a static file with optional route-level middleware. Panics on error. func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo { handler := func(c *Context) error { return c.File(file) } return e.Add(http.MethodGet, path, handler, middleware...) 
} // AddRoute registers a new Route with default host Router func (e *Echo) AddRoute(route Route) (RouteInfo, error) { return e.add(route) } func (e *Echo) add(route Route) (RouteInfo, error) { if e.OnAddRoute != nil { if err := e.OnAddRoute(route); err != nil { return RouteInfo{}, err } } ri, err := e.router.Add(route) if err != nil { return RouteInfo{}, err } paramsCount := int32(len(ri.Parameters)) // #nosec G115 if paramsCount > e.contextPathParamAllocSize.Load() { e.contextPathParamAllocSize.Store(paramsCount) } return ri, nil } // Add registers a new route for an HTTP method and path with matching handler // in the router with optional route-level middleware. func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { ri, err := e.add( Route{ Method: method, Path: path, Handler: handler, Middlewares: middleware, Name: "", }, ) if err != nil { panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage } return ri } // Group creates a new router group with prefix and optional group-level middleware. func (e *Echo) Group(prefix string, m ...MiddlewareFunc) (g *Group) { g = &Group{prefix: prefix, echo: e} g.Use(m...) return } // PreMiddlewares returns registered pre middlewares. These are middleware to the chain // which are run before router tries to find matching route. // Use this method to build your own ServeHTTP method. // // NOTE: returned slice is not a copy. Do not mutate. func (e *Echo) PreMiddlewares() []MiddlewareFunc { return e.premiddleware } // Middlewares returns registered route level middlewares. Does not contain any group level // middlewares. Use this method to build your own ServeHTTP method. // // NOTE: returned slice is not a copy. Do not mutate. func (e *Echo) Middlewares() []MiddlewareFunc { return e.middleware } // AcquireContext returns an empty `Context` instance from the pool. // You must return the context by calling `ReleaseContext()`. 
func (e *Echo) AcquireContext() *Context {
	return e.contextPool.Get().(*Context)
}

// ReleaseContext returns the `Context` instance back to the pool.
// You must call it after `AcquireContext()`.
func (e *Echo) ReleaseContext(c *Context) {
	e.contextPool.Put(c)
}

// ServeHTTP implements `http.Handler` interface, which serves HTTP requests.
func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	e.serveHTTPFunc(w, r)
}

// serveHTTP is the default request-serving implementation that ServeHTTP delegates to (via serveHTTPFunc).
// It takes a pooled Context, resets it for the request, builds the middleware/handler chain and passes
// any returned error to HTTPErrorHandler.
func (e *Echo) serveHTTP(w http.ResponseWriter, r *http.Request) {
	c := e.contextPool.Get().(*Context)
	defer e.contextPool.Put(c)

	c.Reset(r, w)
	var h HandlerFunc

	if len(e.premiddleware) == 0 {
		h = applyMiddleware(e.router.Route(c), e.middleware...)
	} else {
		// Routing is deferred into the innermost handler so that pre-middlewares run before the router
		// looks up the route (and can e.g. change the request before matching).
		h = func(cc *Context) error {
			h1 := applyMiddleware(e.router.Route(cc), e.middleware...)
			return h1(cc)
		}
		h = applyMiddleware(h, e.premiddleware...)
	}

	// Execute chain
	if err := h(c); err != nil {
		e.HTTPErrorHandler(c, err)
	}
}

// Start starts an HTTP server on the given address with Echo as the handler serving requests. The server can be
// shut down by sending the os.Interrupt signal (`ctrl+c`) or SIGTERM. Method returns only errors that are not
// http.ErrServerClosed.
//
// Note: this method is created for use in examples/demos and is deliberately simple without providing configuration
// options.
// // In need of customization use: // // ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) // defer cancel() // sc := echo.StartConfig{Address: ":8080"} // if err := sc.Start(ctx, e); err != nil && !errors.Is(err, http.ErrServerClosed) { // slog.Error(err.Error()) // } // // // or standard library `http.Server` // // s := http.Server{Addr: ":8080", Handler: e} // if err := s.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { // slog.Error(err.Error()) // } func (e *Echo) Start(address string) error { sc := StartConfig{Address: address} ctx, cancel := signal.NotifyContext(stdContext.Background(), os.Interrupt, syscall.SIGTERM) // start shutdown process on ctrl+c defer cancel() return sc.Start(ctx, e) } // WrapHandler wraps `http.Handler` into `echo.HandlerFunc`. func WrapHandler(h http.Handler) HandlerFunc { return func(c *Context) error { req := c.Request() req.Pattern = c.Path() for _, p := range c.PathValues() { req.SetPathValue(p.Name, p.Value) } h.ServeHTTP(c.Response(), req) return nil } } // WrapMiddleware wraps `func(http.Handler) http.Handler` into `echo.MiddlewareFunc` func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { return func(next HandlerFunc) HandlerFunc { return func(c *Context) (err error) { req := c.Request() req.Pattern = c.Path() for _, p := range c.PathValues() { req.SetPathValue(p.Name, p.Value) } m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c.SetRequest(r) c.SetResponse(NewResponse(w, c.echo.Logger)) err = next(c) })).ServeHTTP(c.Response(), req) return } } } func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc { for i := len(middleware) - 1; i >= 0; i-- { h = middleware[i](h) } return h } // defaultFS emulates os.Open behaviour with filesystem opened by `os.DirFs`. Difference between `os.Open` and `fs.Open` // is that FS does not allow to open path that start with `..` or `/` etc. 
// For example previously you could have `../images`
// in your application but `fs := os.DirFS("./")` would not allow you to use `fs.Open("../images")` and this would break
// all old applications that rely on being able to traverse up from current executable run path.
// NB: private because you really should use fs.FS implementation instances
type defaultFS struct {
	fs     fs.FS
	prefix string
}

// newDefaultFS returns a defaultFS rooted at the current working directory. If the working directory
// cannot be determined, it falls back to "." (resolved against the process working directory on each
// open) instead of building os.DirFS("") with an empty root.
func newDefaultFS() *defaultFS {
	dir, err := os.Getwd()
	if err != nil {
		dir = "."
	}
	return &defaultFS{
		prefix: dir,
		fs:     os.DirFS(dir),
	}
}

// Open implements fs.FS by delegating to the underlying os.DirFS.
func (dfs defaultFS) Open(name string) (fs.File, error) {
	return dfs.fs.Open(name)
}

// subFS returns a filesystem rooted at root inside currentFs. For *defaultFS, relative roots are resolved
// against its prefix and absolute roots are used as-is; any other fs.FS is delegated to fs.Sub.
func subFS(currentFs fs.FS, root string) (fs.FS, error) {
	root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows
	if dFS, ok := currentFs.(*defaultFS); ok {
		// we need to make an exception for `defaultFS` instances as they interpret the root prefix differently from fs.FS.
		// fs.FS.Open does not accept relative paths ("./", "../") or absolute paths at all, but with the earlier
		// echo.Filesystem we were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with
		// `os.Open` but not with fs.FS
		if !filepath.IsAbs(root) {
			root = filepath.Join(dFS.prefix, root)
		}
		return &defaultFS{
			prefix: root,
			fs:     os.DirFS(root),
		}, nil
	}
	return fs.Sub(currentFs, root)
}

// MustSubFS creates sub FS from current filesystem or panic on failure.
// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules.
//
// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with
// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory")` to
// create sub fs which uses necessary prefix for directory path.
func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { subFs, err := subFS(currentFs, fsRoot) if err != nil { panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err)) } return subFs } func sanitizeURI(uri string) string { // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') { uri = "/" + strings.TrimLeft(uri, `/\`) } return uri } ================================================ FILE: echo_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import ( "bytes" stdContext "context" "errors" "fmt" "io/fs" "log/slog" "net" "net/http" "net/http/httptest" "net/url" "os" "runtime" "strings" "testing" "time" "github.com/stretchr/testify/assert" ) type user struct { ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"` Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"` } const ( userJSON = `{"id":1,"name":"Jon Snow"}` usersJSON = `[{"id":1,"name":"Jon Snow"}]` userXML = `1Jon Snow` userForm = `id=1&name=Jon Snow` invalidContent = "invalid content" userJSONInvalidType = `{"id":"1","name":"Jon Snow"}` userXMLConvertNumberError = `Number oneJon Snow` userXMLUnsupportedTypeError = `<>Number oneJon Snow` ) const userJSONPretty = `{ "id": 1, "name": "Jon Snow" }` const userXMLPretty = ` 1 Jon Snow ` var dummyQuery = url.Values{"dummy": []string{"useless"}} func TestEcho(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) // Router assert.NotNil(t, e.Router()) e.HTTPErrorHandler(c, errors.New("error")) assert.Equal(t, http.StatusInternalServerError, rec.Code) } 
func TestNewWithConfig(t *testing.T) { e := NewWithConfig(Config{}) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() e.GET("/", func(c *Context) error { return c.String(http.StatusTeapot, "Hello, World!") }) e.ServeHTTP(rec, req) assert.Equal(t, http.StatusTeapot, rec.Code) assert.Equal(t, `Hello, World!`, rec.Body.String()) } func TestEcho_StaticFS(t *testing.T) { var testCases = []struct { givenFs fs.FS name string givenPrefix string givenFsRoot string whenURL string expectHeaderLocation string expectBodyStartsWith string expectStatus int }{ { name: "ok", givenPrefix: "/images", givenFs: os.DirFS("./_fixture/images"), whenURL: "/images/walle.png", expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), }, { name: "ok, from sub fs", givenPrefix: "/images", givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"), whenURL: "/images/walle.png", expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), }, { name: "No file", givenPrefix: "/images", givenFs: os.DirFS("_fixture/scripts"), whenURL: "/images/bolt.png", expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", }, { name: "Directory", givenPrefix: "/images", givenFs: os.DirFS("_fixture/images"), whenURL: "/images/", expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", }, { name: "Directory Redirect", givenPrefix: "/", givenFs: os.DirFS("_fixture/"), whenURL: "/folder", expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/folder/", expectBodyStartsWith: "", }, { name: "Directory Redirect with non-root path", givenPrefix: "/static", givenFs: os.DirFS("_fixture"), whenURL: "/static", expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/static/", expectBodyStartsWith: "", }, { name: "Prefixed directory 404 (request URL without slash)", givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" 
givenFs: os.DirFS("_fixture"), whenURL: "/folder", // no trailing slash expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", }, { name: "Prefixed directory redirect (without slash redirect to slash)", givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* givenFs: os.DirFS("_fixture"), whenURL: "/folder", // no trailing slash expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/folder/", expectBodyStartsWith: "", }, { name: "Directory with index.html", givenPrefix: "/", givenFs: os.DirFS("_fixture"), whenURL: "/", expectStatus: http.StatusOK, expectBodyStartsWith: "", }, { name: "Prefixed directory with index.html (prefix ending with slash)", givenPrefix: "/assets/", givenFs: os.DirFS("_fixture"), whenURL: "/assets/", expectStatus: http.StatusOK, expectBodyStartsWith: "", }, { name: "Prefixed directory with index.html (prefix ending without slash)", givenPrefix: "/assets", givenFs: os.DirFS("_fixture"), whenURL: "/assets/", expectStatus: http.StatusOK, expectBodyStartsWith: "", }, { name: "Sub-directory with index.html", givenPrefix: "/", givenFs: os.DirFS("_fixture"), whenURL: "/folder/", expectStatus: http.StatusOK, expectBodyStartsWith: "", }, { name: "do not allow directory traversal (backslash - windows separator)", givenPrefix: "/", givenFs: os.DirFS("_fixture/"), whenURL: `/..\\middleware/basic_auth.go`, expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", }, { name: "do not allow directory traversal (slash - unix separator)", givenPrefix: "/", givenFs: os.DirFS("_fixture/"), whenURL: `/../middleware/basic_auth.go`, expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", }, { name: "open redirect vulnerability", givenPrefix: "/", givenFs: os.DirFS("_fixture/"), whenURL: "/open.redirect.hackercom%2f..", expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/open.redirect.hackercom/../", // location 
starting with `//open` would be very bad expectBodyStartsWith: "", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() tmpFs := tc.givenFs if tc.givenFsRoot != "" { tmpFs = MustSubFS(tmpFs, tc.givenFsRoot) } e.StaticFS(tc.givenPrefix, tmpFs) req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, tc.expectStatus, rec.Code) body := rec.Body.String() if tc.expectBodyStartsWith != "" { assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) } else { assert.Equal(t, "", body) } if tc.expectHeaderLocation != "" { assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) } else { _, ok := rec.Result().Header["Location"] assert.False(t, ok) } }) } } func TestEcho_FileFS(t *testing.T) { var testCases = []struct { whenFS fs.FS name string whenPath string whenFile string givenURL string expectStartsWith []byte expectCode int }{ { name: "ok", whenPath: "/walle", whenFS: os.DirFS("_fixture/images"), whenFile: "walle.png", givenURL: "/walle", expectCode: http.StatusOK, expectStartsWith: []byte{0x89, 0x50, 0x4e}, }, { name: "nok, requesting invalid path", whenPath: "/walle", whenFS: os.DirFS("_fixture/images"), whenFile: "walle.png", givenURL: "/walle.png", expectCode: http.StatusNotFound, expectStartsWith: []byte(`{"message":"Not Found"}`), }, { name: "nok, serving not existent file from filesystem", whenPath: "/walle", whenFS: os.DirFS("_fixture/images"), whenFile: "not-existent.png", givenURL: "/walle", expectCode: http.StatusNotFound, expectStartsWith: []byte(`{"message":"Not Found"}`), }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, tc.expectCode, rec.Code) body := rec.Body.Bytes() if len(body) > len(tc.expectStartsWith) { body = 
body[:len(tc.expectStartsWith)] } assert.Equal(t, tc.expectStartsWith, body) }) } } func TestEcho_StaticPanic(t *testing.T) { var testCases = []struct { name string givenRoot string }{ { name: "panics for ../", givenRoot: "../assets", }, { name: "panics for /", givenRoot: "/assets", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() e.Filesystem = os.DirFS("./") assert.Panics(t, func() { e.Static("../assets", tc.givenRoot) }) }) } } func TestEchoStaticRedirectIndex(t *testing.T) { e := New() // HandlerFunc ri := e.Static("/static", "_fixture") assert.Equal(t, http.MethodGet, ri.Method) assert.Equal(t, "/static*", ri.Path) assert.Equal(t, "GET:/static*", ri.Name) assert.Equal(t, []string{"*"}, ri.Parameters) ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) defer cancel() addr, err := startOnRandomPort(ctx, e) if err != nil { assert.Fail(t, err.Error()) } code, body, err := doGet(fmt.Sprintf("http://%v/static", addr)) assert.NoError(t, err) assert.True(t, strings.HasPrefix(body, "")) assert.Equal(t, http.StatusOK, code) } func TestEchoFile(t *testing.T) { var testCases = []struct { name string givenPath string givenFile string whenPath string expectStartsWith string expectCode int }{ { name: "ok", givenPath: "/walle", givenFile: "_fixture/images/walle.png", whenPath: "/walle", expectCode: http.StatusOK, expectStartsWith: string([]byte{0x89, 0x50, 0x4e}), }, { name: "ok with relative path", givenPath: "/", givenFile: "./go.mod", whenPath: "/", expectCode: http.StatusOK, expectStartsWith: "module github.com/labstack/echo/v", }, { name: "nok file does not exist", givenPath: "/", givenFile: "./this-file-does-not-exist", whenPath: "/", expectCode: http.StatusNotFound, expectStartsWith: "{\"message\":\"Not Found\"}\n", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() // we are using echo.defaultFS instance e.File(tc.givenPath, tc.givenFile) c, b := request(http.MethodGet, 
tc.whenPath, e) assert.Equal(t, tc.expectCode, c) if len(b) > len(tc.expectStartsWith) { b = b[:len(tc.expectStartsWith)] } assert.Equal(t, tc.expectStartsWith, b) }) } } func TestEchoMiddleware(t *testing.T) { e := New() buf := new(bytes.Buffer) e.Pre(func(next HandlerFunc) HandlerFunc { return func(c *Context) error { // before route match is found RouteInfo does not exist assert.Equal(t, RouteInfo{}, c.RouteInfo()) buf.WriteString("-1") return next(c) } }) e.Use(func(next HandlerFunc) HandlerFunc { return func(c *Context) error { buf.WriteString("1") return next(c) } }) e.Use(func(next HandlerFunc) HandlerFunc { return func(c *Context) error { buf.WriteString("2") return next(c) } }) e.Use(func(next HandlerFunc) HandlerFunc { return func(c *Context) error { buf.WriteString("3") return next(c) } }) // Route e.GET("/", func(c *Context) error { return c.String(http.StatusOK, "OK") }) c, b := request(http.MethodGet, "/", e) assert.Equal(t, "-1123", buf.String()) assert.Equal(t, http.StatusOK, c) assert.Equal(t, "OK", b) } func TestEchoMiddlewareError(t *testing.T) { e := New() e.Use(func(next HandlerFunc) HandlerFunc { return func(c *Context) error { return errors.New("error") } }) e.GET("/", notFoundHandler) c, _ := request(http.MethodGet, "/", e) assert.Equal(t, http.StatusInternalServerError, c) } func TestEchoHandler(t *testing.T) { e := New() // HandlerFunc e.GET("/ok", func(c *Context) error { return c.String(http.StatusOK, "OK") }) c, b := request(http.MethodGet, "/ok", e) assert.Equal(t, http.StatusOK, c) assert.Equal(t, "OK", b) } func TestEchoWrapHandler(t *testing.T) { e := New() var actualID string var actualPattern string e.GET("/:id", WrapHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("test")) actualID = r.PathValue("id") actualPattern = r.Pattern }))) req := httptest.NewRequest(http.MethodGet, "/123", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, 
http.StatusOK, rec.Code) assert.Equal(t, "test", rec.Body.String()) assert.Equal(t, "123", actualID) assert.Equal(t, "/:id", actualPattern) } func TestEchoWrapMiddleware(t *testing.T) { e := New() var actualID string var actualPattern string e.Use(WrapMiddleware(func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { actualID = r.PathValue("id") actualPattern = r.Pattern h.ServeHTTP(w, r) }) })) e.GET("/:id", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) req := httptest.NewRequest(http.MethodGet, "/123", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusTeapot, rec.Code) assert.Equal(t, "OK", rec.Body.String()) assert.Equal(t, "123", actualID) assert.Equal(t, "/:id", actualPattern) } func TestEchoConnect(t *testing.T) { e := New() ri := e.CONNECT("/", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodConnect, ri.Method) assert.Equal(t, "/", ri.Path) assert.Equal(t, http.MethodConnect+":/", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodConnect, "/", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, "OK", body) } func TestEchoDelete(t *testing.T) { e := New() ri := e.DELETE("/", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodDelete, ri.Method) assert.Equal(t, "/", ri.Path) assert.Equal(t, http.MethodDelete+":/", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodDelete, "/", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, "OK", body) } func TestEchoGet(t *testing.T) { e := New() ri := e.GET("/", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodGet, ri.Method) assert.Equal(t, "/", ri.Path) assert.Equal(t, http.MethodGet+":/", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodGet, "/", e) assert.Equal(t, http.StatusTeapot, 
status) assert.Equal(t, "OK", body) } func TestEchoHead(t *testing.T) { e := New() ri := e.HEAD("/", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodHead, ri.Method) assert.Equal(t, "/", ri.Path) assert.Equal(t, http.MethodHead+":/", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodHead, "/", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, "OK", body) } func TestEchoOptions(t *testing.T) { e := New() ri := e.OPTIONS("/", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodOptions, ri.Method) assert.Equal(t, "/", ri.Path) assert.Equal(t, http.MethodOptions+":/", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodOptions, "/", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, "OK", body) } func TestEchoPatch(t *testing.T) { e := New() ri := e.PATCH("/", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodPatch, ri.Method) assert.Equal(t, "/", ri.Path) assert.Equal(t, http.MethodPatch+":/", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodPatch, "/", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, "OK", body) } func TestEchoPost(t *testing.T) { e := New() ri := e.POST("/", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodPost, ri.Method) assert.Equal(t, "/", ri.Path) assert.Equal(t, http.MethodPost+":/", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodPost, "/", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, "OK", body) } func TestEchoPut(t *testing.T) { e := New() ri := e.PUT("/", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodPut, ri.Method) assert.Equal(t, "/", ri.Path) assert.Equal(t, http.MethodPut+":/", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodPut, "/", e) 
assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, "OK", body) } func TestEchoTrace(t *testing.T) { e := New() ri := e.TRACE("/", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodTrace, ri.Method) assert.Equal(t, "/", ri.Path) assert.Equal(t, http.MethodTrace+":/", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodTrace, "/", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, "OK", body) } func TestEcho_Any(t *testing.T) { e := New() ri := e.Any("/activate", func(c *Context) error { return c.String(http.StatusTeapot, "OK from ANY") }) assert.Equal(t, RouteAny, ri.Method) assert.Equal(t, "/activate", ri.Path) assert.Equal(t, RouteAny+":/activate", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodTrace, "/activate", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, `OK from ANY`, body) } func TestEcho_Any_hasLowerPriority(t *testing.T) { e := New() e.Any("/activate", func(c *Context) error { return c.String(http.StatusTeapot, "ANY") }) e.GET("/activate", func(c *Context) error { return c.String(http.StatusLocked, "GET") }) status, body := request(http.MethodTrace, "/activate", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, `ANY`, body) status, body = request(http.MethodGet, "/activate", e) assert.Equal(t, http.StatusLocked, status) assert.Equal(t, `GET`, body) } func TestEchoMatch(t *testing.T) { // JFC e := New() ris := e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c *Context) error { return c.String(http.StatusOK, "Match") }) assert.Len(t, ris, 2) } func TestEchoServeHTTPPathEncoding(t *testing.T) { e := New() e.GET("/with/slash", func(c *Context) error { return c.String(http.StatusOK, "/with/slash") }) e.GET("/:id", func(c *Context) error { return c.String(http.StatusOK, c.Param("id")) }) var testCases = []struct { name string whenURL string expectURL string expectStatus int }{ { name: "url with encoding is not 
decoded for routing", whenURL: "/with%2Fslash", expectURL: "with%2Fslash", // `%2F` is not decoded to `/` for routing expectStatus: http.StatusOK, }, { name: "url without encoding is used as is", whenURL: "/with/slash", expectURL: "/with/slash", expectStatus: http.StatusOK, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, tc.expectStatus, rec.Code) assert.Equal(t, tc.expectURL, rec.Body.String()) }) } } func TestEchoGroup(t *testing.T) { e := New() buf := new(bytes.Buffer) e.Use(MiddlewareFunc(func(next HandlerFunc) HandlerFunc { return func(c *Context) error { buf.WriteString("0") return next(c) } })) h := func(c *Context) error { return c.NoContent(http.StatusOK) } //-------- // Routes //-------- e.GET("/users", h) // Group g1 := e.Group("/group1") g1.Use(func(next HandlerFunc) HandlerFunc { return func(c *Context) error { buf.WriteString("1") return next(c) } }) g1.GET("", h) // Nested groups with middleware g2 := e.Group("/group2") g2.Use(func(next HandlerFunc) HandlerFunc { return func(c *Context) error { buf.WriteString("2") return next(c) } }) g3 := g2.Group("/group3") g3.Use(func(next HandlerFunc) HandlerFunc { return func(c *Context) error { buf.WriteString("3") return next(c) } }) g3.GET("", h) request(http.MethodGet, "/users", e) assert.Equal(t, "0", buf.String()) buf.Reset() request(http.MethodGet, "/group1", e) assert.Equal(t, "01", buf.String()) buf.Reset() request(http.MethodGet, "/group2/group3", e) assert.Equal(t, "023", buf.String()) } func TestEcho_RouteNotFound(t *testing.T) { var testCases = []struct { expectRoute any name string whenURL string expectCode int }{ { name: "404, route to static not found handler /a/c/xx", whenURL: "/a/c/xx", expectRoute: "GET /a/c/xx", expectCode: http.StatusNotFound, }, { name: "404, route to path param not found handler /a/:file", whenURL: "/a/echo.exe", 
expectRoute: "GET /a/:file", expectCode: http.StatusNotFound, }, { name: "404, route to any not found handler /*", whenURL: "/b/echo.exe", expectRoute: "GET /*", expectCode: http.StatusNotFound, }, { name: "200, route /a/c/df to /a/c/df", whenURL: "/a/c/df", expectRoute: "GET /a/c/df", expectCode: http.StatusOK, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() okHandler := func(c *Context) error { return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) } notFoundHandler := func(c *Context) error { return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) } e.GET("/", okHandler) e.GET("/a/c/df", okHandler) e.GET("/a/b*", okHandler) e.PUT("/*", okHandler) e.RouteNotFound("/a/c/xx", notFoundHandler) // static e.RouteNotFound("/a/:file", notFoundHandler) // param e.RouteNotFound("/*", notFoundHandler) // any req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, tc.expectCode, rec.Code) assert.Equal(t, tc.expectRoute, rec.Body.String()) }) } } func TestEchoNotFound(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/files", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusNotFound, rec.Code) } func TestEchoMethodNotAllowed(t *testing.T) { e := New() e.GET("/", func(c *Context) error { return c.String(http.StatusOK, "Echo!") }) req := httptest.NewRequest(http.MethodPost, "/", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusMethodNotAllowed, rec.Code) assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow)) } func TestEcho_OnAddRoute(t *testing.T) { exampleRoute := Route{ Method: http.MethodGet, Path: "/api/files/:id", Handler: notFoundHandler, Middlewares: nil, Name: "x", } var testCases = []struct { whenRoute Route whenError error name string expectError string expectAdded []string expectLen int }{ { name: "ok", whenRoute: exampleRoute, whenError: 
nil, expectAdded: []string{"/static", "/api/files/:id"}, expectError: "", expectLen: 2, }, { name: "nok, error is returned", whenRoute: exampleRoute, whenError: errors.New("nope"), expectAdded: []string{"/static"}, expectError: "nope", expectLen: 1, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() added := make([]string, 0) cnt := 0 e.OnAddRoute = func(route Route) error { if cnt > 0 && tc.whenError != nil { // we want to GET /static to succeed for nok tests return tc.whenError } cnt++ added = append(added, route.Path) return nil } e.GET("/static", notFoundHandler) var err error _, err = e.AddRoute(tc.whenRoute) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } assert.Len(t, e.Router().Routes(), tc.expectLen) assert.Equal(t, tc.expectAdded, added) }) } } func TestEchoContext(t *testing.T) { e := New() c := e.AcquireContext() assert.IsType(t, new(Context), c) e.ReleaseContext(c) } func TestPreMiddlewares(t *testing.T) { e := New() assert.Equal(t, 0, len(e.PreMiddlewares())) e.Pre(func(next HandlerFunc) HandlerFunc { return func(c *Context) error { return next(c) } }) assert.Equal(t, 1, len(e.PreMiddlewares())) } func TestMiddlewares(t *testing.T) { e := New() assert.Equal(t, 0, len(e.Middlewares())) e.Use(func(next HandlerFunc) HandlerFunc { return func(c *Context) error { return next(c) } }) assert.Equal(t, 1, len(e.Middlewares())) } func TestEcho_Start(t *testing.T) { e := New() e.GET("/", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) rndPort, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) } defer rndPort.Close() errChan := make(chan error, 1) go func() { errChan <- e.Start(rndPort.Addr().String()) }() select { case <-time.After(250 * time.Millisecond): t.Fatal("start did not error out") case err := <-errChan: expectContains := "bind: address already in use" if runtime.GOOS == "windows" { expectContains = "bind: Only one usage of each socket 
address" } assert.Contains(t, err.Error(), expectContains) } } func request(method, path string, e *Echo) (int, string) { req := httptest.NewRequest(method, path, nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) return rec.Code, rec.Body.String() } type customError struct { Code int Message string } func (ce *customError) StatusCode() int { return ce.Code } func (ce *customError) MarshalJSON() ([]byte, error) { return []byte(fmt.Sprintf(`{"x":"%v"}`, ce.Message)), nil } func (ce *customError) Error() string { return ce.Message } func TestDefaultHTTPErrorHandler(t *testing.T) { var testCases = []struct { whenError error name string whenMethod string expectBody string expectLogged string expectStatus int givenExposeError bool givenLoggerFunc bool }{ { name: "ok, expose error = true, HTTPError, no wrapped err", givenExposeError: true, whenError: &HTTPError{Code: http.StatusTeapot, Message: "my_error"}, expectStatus: http.StatusTeapot, expectBody: `{"message":"my_error"}` + "\n", }, { name: "ok, expose error = true, HTTPError + wrapped error", givenExposeError: true, whenError: HTTPError{Code: http.StatusTeapot, Message: "my_error"}.Wrap(errors.New("internal_error")), expectStatus: http.StatusTeapot, expectBody: `{"error":"internal_error","message":"my_error"}` + "\n", }, { name: "ok, expose error = true, HTTPError + wrapped HTTPError", givenExposeError: true, whenError: HTTPError{Code: http.StatusTeapot, Message: "my_error"}.Wrap(&HTTPError{Code: http.StatusTeapot, Message: "early_error"}), expectStatus: http.StatusTeapot, expectBody: `{"error":"code=418, message=early_error","message":"my_error"}` + "\n", }, { name: "ok, expose error = false, HTTPError", whenError: &HTTPError{Code: http.StatusTeapot, Message: "my_error"}, expectStatus: http.StatusTeapot, expectBody: `{"message":"my_error"}` + "\n", }, { name: "ok, expose error = false, HTTPError, no message", whenError: &HTTPError{Code: http.StatusTeapot, Message: ""}, expectStatus: http.StatusTeapot, 
expectBody: `{"message":"I'm a teapot"}` + "\n", }, { name: "ok, expose error = false, HTTPError + internal HTTPError", whenError: HTTPError{Code: http.StatusTooEarly, Message: "my_error"}.Wrap(&HTTPError{Code: http.StatusTeapot, Message: "early_error"}), expectStatus: http.StatusTooEarly, expectBody: `{"message":"my_error"}` + "\n", }, { name: "ok, expose error = true, Error", givenExposeError: true, whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), expectStatus: http.StatusInternalServerError, expectBody: `{"error":"my errors wraps: internal_error","message":"Internal Server Error"}` + "\n", }, { name: "ok, expose error = false, Error", whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), expectStatus: http.StatusInternalServerError, expectBody: `{"message":"Internal Server Error"}` + "\n", }, { name: "ok, http.HEAD, expose error = true, Error", givenExposeError: true, whenMethod: http.MethodHead, whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), expectStatus: http.StatusInternalServerError, expectBody: ``, }, { name: "ok, custom error implement MarshalJSON + HTTPStatusCoder", whenMethod: http.MethodGet, whenError: &customError{Code: http.StatusTeapot, Message: "custom error msg"}, expectStatus: http.StatusTeapot, expectBody: `{"x":"custom error msg"}` + "\n", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { buf := new(bytes.Buffer) e := New() e.Logger = slog.New(slog.DiscardHandler) e.Any("/path", func(c *Context) error { return tc.whenError }) e.HTTPErrorHandler = DefaultHTTPErrorHandler(tc.givenExposeError) method := http.MethodGet if tc.whenMethod != "" { method = tc.whenMethod } c, b := request(method, "/path", e) assert.Equal(t, tc.expectStatus, c) assert.Equal(t, tc.expectBody, b) assert.Equal(t, tc.expectLogged, buf.String()) }) } } func TestDefaultHTTPErrorHandler_CommitedResponse(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", 
nil) resp := httptest.NewRecorder() c := e.NewContext(req, resp) c.orgResponse.Committed = true errHandler := DefaultHTTPErrorHandler(false) errHandler(c, errors.New("my_error")) assert.Equal(t, http.StatusOK, resp.Code) } func benchmarkEchoRoutes(b *testing.B, routes []testRoute) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) u := req.URL w := httptest.NewRecorder() b.ReportAllocs() // Add routes for _, route := range routes { e.Add(route.Method, route.Path, func(c *Context) error { return nil }) } // Find routes b.ResetTimer() for i := 0; i < b.N; i++ { for _, route := range routes { req.Method = route.Method u.Path = route.Path e.ServeHTTP(w, req) } } } func BenchmarkEchoStaticRoutes(b *testing.B) { benchmarkEchoRoutes(b, staticRoutes) } func BenchmarkEchoStaticRoutesMisses(b *testing.B) { benchmarkEchoRoutes(b, staticRoutes) } func BenchmarkEchoGitHubAPI(b *testing.B) { benchmarkEchoRoutes(b, gitHubAPI) } func BenchmarkEchoGitHubAPIMisses(b *testing.B) { benchmarkEchoRoutes(b, gitHubAPI) } func BenchmarkEchoParseAPI(b *testing.B) { benchmarkEchoRoutes(b, parseAPI) } ================================================ FILE: echotest/context.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echotest import ( "bytes" "io" "mime/multipart" "net/http" "net/http/httptest" "net/url" "strings" "testing" "github.com/labstack/echo/v5" ) // ContextConfig is configuration for creating echo.Context for testing purposes. 
type ContextConfig struct { // Request will be used instead of default `httptest.NewRequest(http.MethodGet, "/", nil)` Request *http.Request // Response will be used instead of default `httptest.NewRecorder()` Response *httptest.ResponseRecorder // QueryValues will be set as Request.URL.RawQuery value QueryValues url.Values // Headers will be set as Request.Header value Headers http.Header // PathValues initializes context.PathValues with given value. PathValues echo.PathValues // RouteInfo initializes context.RouteInfo() with given value RouteInfo *echo.RouteInfo // FormValues creates form-urlencoded form out of given values. If there is no // `content-type` header it will be set to `application/x-www-form-urlencoded` // In case Request was not set the Request.Method is set to `POST` // // FormValues, MultipartForm and JSONBody are mutually exclusive. FormValues url.Values // MultipartForm creates multipart form out of given value. If there is no // `content-type` header it will be set to `multipart/form-data` // In case Request was not set the Request.Method is set to `POST` // // FormValues, MultipartForm and JSONBody are mutually exclusive. MultipartForm *MultipartForm // JSONBody creates JSON body out of given bytes. If there is no // `content-type` header it will be set to `application/json` // In case Request was not set the Request.Method is set to `POST` // // FormValues, MultipartForm and JSONBody are mutually exclusive. 
JSONBody []byte } // MultipartForm is used to create multipart form out of given value type MultipartForm struct { Fields map[string]string Files []MultipartFormFile } // MultipartFormFile is used to create file in multipart form out of given value type MultipartFormFile struct { Fieldname string Filename string Content []byte } // ToContext converts ContextConfig to echo.Context func (conf ContextConfig) ToContext(t *testing.T) *echo.Context { c, _ := conf.ToContextRecorder(t) return c } // ToContextRecorder converts ContextConfig to echo.Context and httptest.ResponseRecorder func (conf ContextConfig) ToContextRecorder(t *testing.T) (*echo.Context, *httptest.ResponseRecorder) { if conf.Response == nil { conf.Response = httptest.NewRecorder() } isDefaultRequest := false if conf.Request == nil { isDefaultRequest = true conf.Request = httptest.NewRequest(http.MethodGet, "/", nil) } if len(conf.QueryValues) > 0 { conf.Request.URL.RawQuery = conf.QueryValues.Encode() } if len(conf.Headers) > 0 { conf.Request.Header = conf.Headers } if len(conf.FormValues) > 0 { body := strings.NewReader(url.Values(conf.FormValues).Encode()) conf.Request.Body = io.NopCloser(body) conf.Request.ContentLength = int64(body.Len()) if conf.Request.Header.Get(echo.HeaderContentType) == "" { conf.Request.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) } if isDefaultRequest { conf.Request.Method = http.MethodPost } } else if conf.MultipartForm != nil { var body bytes.Buffer mw := multipart.NewWriter(&body) for field, value := range conf.MultipartForm.Fields { if err := mw.WriteField(field, value); err != nil { t.Fatal(err) } } for _, file := range conf.MultipartForm.Files { fw, err := mw.CreateFormFile(file.Fieldname, file.Filename) if err != nil { t.Fatal(err) } if _, err = fw.Write(file.Content); err != nil { t.Fatal(err) } } if err := mw.Close(); err != nil { t.Fatal(err) } conf.Request.Body = io.NopCloser(&body) conf.Request.ContentLength = int64(body.Len()) if 
conf.Request.Header.Get(echo.HeaderContentType) == "" { conf.Request.Header.Set(echo.HeaderContentType, mw.FormDataContentType()) } if isDefaultRequest { conf.Request.Method = http.MethodPost } } else if conf.JSONBody != nil { body := bytes.NewReader(conf.JSONBody) conf.Request.Body = io.NopCloser(body) conf.Request.ContentLength = int64(body.Len()) if conf.Request.Header.Get(echo.HeaderContentType) == "" { conf.Request.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) } if isDefaultRequest { conf.Request.Method = http.MethodPost } } ec := echo.NewContext(conf.Request, conf.Response, echo.New()) if conf.RouteInfo == nil { conf.RouteInfo = &echo.RouteInfo{ Name: "", Method: conf.Request.Method, Path: "/test", Parameters: []string{}, } for _, p := range conf.PathValues { conf.RouteInfo.Parameters = append(conf.RouteInfo.Parameters, p.Name) } } ec.InitializeRoute(conf.RouteInfo, &conf.PathValues) return ec, conf.Response } // ServeWithHandler serves ContextConfig with given handler and returns httptest.ResponseRecorder for response checking func (conf ContextConfig) ServeWithHandler(t *testing.T, handler echo.HandlerFunc, opts ...any) *httptest.ResponseRecorder { c, rec := conf.ToContextRecorder(t) errHandler := echo.DefaultHTTPErrorHandler(false) for _, opt := range opts { switch o := opt.(type) { case echo.HTTPErrorHandler: errHandler = o } } err := handler(c) if err != nil { errHandler(c, err) } return rec } ================================================ FILE: echotest/context_external_test.go ================================================ package echotest_test import ( "net/http" "testing" "github.com/labstack/echo/v5" "github.com/labstack/echo/v5/echotest" "github.com/stretchr/testify/assert" ) func TestToContext_JSONBody(t *testing.T) { c := echotest.ContextConfig{ JSONBody: echotest.LoadBytes(t, "testdata/test.json"), }.ToContext(t) payload := struct { Field string `json:"field"` }{} if err := c.Bind(&payload); err != nil { t.Fatal(err) } 
assert.Equal(t, "value", payload.Field) assert.Equal(t, http.MethodPost, c.Request().Method) assert.Equal(t, echo.MIMEApplicationJSON, c.Request().Header.Get(echo.HeaderContentType)) } ================================================ FILE: echotest/context_test.go ================================================ package echotest import ( "net/http" "net/url" "strings" "testing" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestServeWithHandler(t *testing.T) { handler := func(c *echo.Context) error { return c.String(http.StatusOK, c.QueryParam("key")) } testConf := ContextConfig{ QueryValues: url.Values{"key": []string{"value"}}, } resp := testConf.ServeWithHandler(t, handler) assert.Equal(t, http.StatusOK, resp.Code) assert.Equal(t, "value", resp.Body.String()) } func TestServeWithHandler_error(t *testing.T) { handler := func(c *echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, "something went wrong") } testConf := ContextConfig{ QueryValues: url.Values{"key": []string{"value"}}, } customErrHandler := echo.DefaultHTTPErrorHandler(true) resp := testConf.ServeWithHandler(t, handler, customErrHandler) assert.Equal(t, http.StatusBadRequest, resp.Code) assert.Equal(t, `{"message":"something went wrong"}`+"\n", resp.Body.String()) } func TestToContext_QueryValues(t *testing.T) { testConf := ContextConfig{ QueryValues: url.Values{"t": []string{"2006-01-02"}}, } c := testConf.ToContext(t) v, err := echo.QueryParam[string](c, "t") assert.NoError(t, err) assert.Equal(t, "2006-01-02", v) } func TestToContext_Headers(t *testing.T) { testConf := ContextConfig{ Headers: http.Header{echo.HeaderXRequestID: []string{"ABC"}}, } c := testConf.ToContext(t) id := c.Request().Header.Get(echo.HeaderXRequestID) assert.Equal(t, "ABC", id) } func TestToContext_PathValues(t *testing.T) { testConf := ContextConfig{ PathValues: echo.PathValues{{ Name: "key", Value: "value", }}, } c := testConf.ToContext(t) key := c.Param("key") assert.Equal(t, 
"value", key) } func TestToContext_RouteInfo(t *testing.T) { testConf := ContextConfig{ RouteInfo: &echo.RouteInfo{ Name: "my_route", Method: http.MethodGet, Path: "/:id", Parameters: []string{"id"}, }, } c := testConf.ToContext(t) ri := c.RouteInfo() assert.Equal(t, echo.RouteInfo{ Name: "my_route", Method: http.MethodGet, Path: "/:id", Parameters: []string{"id"}, }, ri) } func TestToContext_FormValues(t *testing.T) { testConf := ContextConfig{ FormValues: url.Values{"key": []string{"value"}}, } c := testConf.ToContext(t) assert.Equal(t, "value", c.FormValue("key")) assert.Equal(t, http.MethodPost, c.Request().Method) assert.Equal(t, echo.MIMEApplicationForm, c.Request().Header.Get(echo.HeaderContentType)) } func TestToContext_MultipartForm(t *testing.T) { testConf := ContextConfig{ MultipartForm: &MultipartForm{ Fields: map[string]string{ "key": "value", }, Files: []MultipartFormFile{ { Fieldname: "file", Filename: "test.json", Content: LoadBytes(t, "testdata/test.json"), }, }, }, } c := testConf.ToContext(t) assert.Equal(t, "value", c.FormValue("key")) assert.Equal(t, http.MethodPost, c.Request().Method) assert.Equal(t, true, strings.HasPrefix(c.Request().Header.Get(echo.HeaderContentType), "multipart/form-data; boundary=")) fv, err := c.FormFile("file") if err != nil { t.Fatal(err) } assert.Equal(t, "test.json", fv.Filename) assert.Equal(t, int64(23), fv.Size) } func TestToContext_JSONBody(t *testing.T) { testConf := ContextConfig{ JSONBody: LoadBytes(t, "testdata/test.json"), } c := testConf.ToContext(t) payload := struct { Field string `json:"field"` }{} if err := c.Bind(&payload); err != nil { t.Fatal(err) } assert.Equal(t, "value", payload.Field) assert.Equal(t, http.MethodPost, c.Request().Method) assert.Equal(t, echo.MIMEApplicationJSON, c.Request().Header.Get(echo.HeaderContentType)) } ================================================ FILE: echotest/reader.go ================================================ // SPDX-License-Identifier: MIT // 
SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echotest import ( "os" "path/filepath" "runtime" "testing" ) type loadBytesOpts func([]byte) []byte // TrimNewlineEnd instructs LoadBytes to remove `\n` from the end of loaded file. func TrimNewlineEnd(bytes []byte) []byte { bLen := len(bytes) if bLen > 1 && bytes[bLen-1] == '\n' { bytes = bytes[:bLen-1] } return bytes } // LoadBytes is helper to load file contents relative to current (where test file is) package // directory. func LoadBytes(t *testing.T, name string, opts ...loadBytesOpts) []byte { bytes := loadBytes(t, name, 2) for _, f := range opts { bytes = f(bytes) } return bytes } func loadBytes(t *testing.T, name string, callDepth int) []byte { _, b, _, _ := runtime.Caller(callDepth) basepath := filepath.Dir(b) path := filepath.Join(basepath, name) // relative path bytes, err := os.ReadFile(path) if err != nil { t.Fatal(err) } return bytes[:] } ================================================ FILE: echotest/reader_external_test.go ================================================ package echotest_test import ( "strings" "testing" "github.com/labstack/echo/v5/echotest" "github.com/stretchr/testify/assert" ) const testJSONContent = `{ "field": "value" }` func TestLoadBytesOK(t *testing.T) { data := echotest.LoadBytes(t, "testdata/test.json") assert.Equal(t, []byte(testJSONContent+"\n"), data) } func TestLoadBytes_custom(t *testing.T) { data := echotest.LoadBytes(t, "testdata/test.json", func(bytes []byte) []byte { return []byte(strings.ToUpper(string(bytes))) }) assert.Equal(t, []byte(strings.ToUpper(testJSONContent)+"\n"), data) } ================================================ FILE: echotest/reader_test.go ================================================ package echotest import ( "testing" "github.com/stretchr/testify/assert" ) const testJSONContent = `{ "field": "value" }` func TestLoadBytesOK(t *testing.T) { data := LoadBytes(t, "testdata/test.json") assert.Equal(t, 
[]byte(testJSONContent+"\n"), data) } func TestLoadBytesOK_TrimNewlineEnd(t *testing.T) { data := LoadBytes(t, "testdata/test.json", TrimNewlineEnd) assert.Equal(t, []byte(testJSONContent), data) } ================================================ FILE: echotest/testdata/test.json ================================================ { "field": "value" } ================================================ FILE: go.mod ================================================ module github.com/labstack/echo/v5 go 1.25.0 require ( github.com/stretchr/testify v1.11.1 golang.org/x/net v0.49.0 golang.org/x/time v0.14.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/text v0.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) ================================================ FILE: go.sum ================================================ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= ================================================ FILE: group.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import ( "io/fs" "net/http" ) // Group is a set of sub-routes for a specified route. It can be used for inner // routes that share a common middleware or functionality that should be separate // from the parent echo instance while still inheriting from it. type Group struct { echo *Echo prefix string middleware []MiddlewareFunc } // Use implements `Echo#Use()` for sub-routes within the Group. // Group middlewares are not executed on request when there is no matching route found. func (g *Group) Use(middleware ...MiddlewareFunc) { g.middleware = append(g.middleware, middleware...) } // CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error. func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodConnect, path, h, m...) } // DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error. func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodDelete, path, h, m...) } // GET implements `Echo#GET()` for sub-routes within the Group. Panics on error. func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodGet, path, h, m...) } // HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error. func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodHead, path, h, m...) 
} // OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error. func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodOptions, path, h, m...) } // PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error. func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPatch, path, h, m...) } // POST implements `Echo#POST()` for sub-routes within the Group. Panics on error. func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPost, path, h, m...) } // PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error. func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPut, path, h, m...) } // TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error. func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodTrace, path, h, m...) } // Any implements `Echo#Any()` for sub-routes within the Group. Panics on error. func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { return g.Add(RouteAny, path, handler, middleware...) } // Match implements `Echo#Match()` for sub-routes within the Group. Panics on error. func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { errs := make([]error, 0) ris := make(Routes, 0) for _, m := range methods { ri, err := g.AddRoute(Route{ Method: m, Path: path, Handler: handler, Middlewares: middleware, }) if err != nil { errs = append(errs, err) continue } ris = append(ris, ri) } if len(errs) > 0 { panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage } return ris } // Group creates a new sub-group with prefix and optional sub-group-level middleware. // Important! 
Group middlewares are only executed in case there was exact route match and not // for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add // a catch-all route `/*` for the group which handler returns always 404 func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) { m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) m = append(m, g.middleware...) m = append(m, middleware...) sg = g.echo.Group(g.prefix+prefix, m...) return } // Static implements `Echo#Static()` for sub-routes within the Group. func (g *Group) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo { subFs := MustSubFS(g.echo.Filesystem, fsRoot) return g.StaticFS(pathPrefix, subFs, middleware...) } // StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. // // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths // including `assets/images` as their prefix. func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo { return g.Add( http.MethodGet, pathPrefix+"*", StaticDirectoryHandler(filesystem, false), middleware..., ) } // FileFS implements `Echo#FileFS()` for sub-routes within the Group. func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo { return g.GET(path, StaticFileHandler(file, filesystem), m...) } // File implements `Echo#File()` for sub-routes within the Group. Panics on error. func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo { handler := func(c *Context) error { return c.File(file) } return g.Add(http.MethodGet, path, handler, middleware...) } // RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group. 
// // Example: `g.RouteNotFound("/*", func(c *echo.Context) error { return c.NoContent(http.StatusNotFound) })` func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(RouteNotFound, path, h, m...) } // Add implements `Echo#Add()` for sub-routes within the Group. Panics on error. func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { ri, err := g.AddRoute(Route{ Method: method, Path: path, Handler: handler, Middlewares: middleware, }) if err != nil { panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage } return ri } // AddRoute registers a new Routable with Router func (g *Group) AddRoute(route Route) (RouteInfo, error) { // Combine middleware into a new slice to avoid accidentally passing the same slice for // multiple routes, which would lead to later add() calls overwriting the // middleware from earlier calls. groupRoute := route.WithPrefix(g.prefix, append([]MiddlewareFunc{}, g.middleware...)) return g.echo.add(groupRoute) } ================================================ FILE: group_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import ( "io/fs" "net/http" "net/http/httptest" "os" "strings" "testing" "github.com/stretchr/testify/assert" ) func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) { e := New() called := false mw := func(next HandlerFunc) HandlerFunc { return func(c *Context) error { called = true return c.NoContent(http.StatusTeapot) } } // even though group has middleware it will not be executed when there are no routes under that group _ = e.Group("/group", mw) status, body := request(http.MethodGet, "/group/nope", e) assert.Equal(t, http.StatusNotFound, status) assert.Equal(t, `{"message":"Not Found"}`+"\n", body) assert.False(t, called) } func 
TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) { e := New() called := false mw := func(next HandlerFunc) HandlerFunc { return func(c *Context) error { called = true return c.NoContent(http.StatusTeapot) } } // even though group has middleware and routes when we have no match on some route the middlewares for that // group will not be executed g := e.Group("/group", mw) g.GET("/yes", handlerFunc) status, body := request(http.MethodGet, "/group/nope", e) assert.Equal(t, http.StatusNotFound, status) assert.Equal(t, `{"message":"Not Found"}`+"\n", body) assert.False(t, called) } func TestGroup_multiLevelGroup(t *testing.T) { e := New() api := e.Group("/api") users := api.Group("/users") users.GET("/activate", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) status, body := request(http.MethodGet, "/api/users/activate", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, `OK`, body) } func TestGroupFile(t *testing.T) { e := New() g := e.Group("/group") g.File("/walle", "_fixture/images/walle.png") expectedData, err := os.ReadFile("_fixture/images/walle.png") assert.Nil(t, err) req := httptest.NewRequest(http.MethodGet, "/group/walle", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, expectedData, rec.Body.Bytes()) } func TestGroupRouteMiddleware(t *testing.T) { // Ensure middleware slices are not re-used e := New() g := e.Group("/group") h := func(*Context) error { return nil } m1 := func(next HandlerFunc) HandlerFunc { return func(c *Context) error { return next(c) } } m2 := func(next HandlerFunc) HandlerFunc { return func(c *Context) error { return next(c) } } m3 := func(next HandlerFunc) HandlerFunc { return func(c *Context) error { return next(c) } } m4 := func(next HandlerFunc) HandlerFunc { return func(c *Context) error { return c.NoContent(404) } } m5 := func(next HandlerFunc) HandlerFunc { return func(c *Context) error { return c.NoContent(405) } } 
g.Use(m1, m2, m3) g.GET("/404", h, m4) g.GET("/405", h, m5) c, _ := request(http.MethodGet, "/group/404", e) assert.Equal(t, 404, c) c, _ = request(http.MethodGet, "/group/405", e) assert.Equal(t, 405, c) } func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { // Ensure middleware and match any routes do not conflict e := New() g := e.Group("/group") m1 := func(next HandlerFunc) HandlerFunc { return func(c *Context) error { return next(c) } } m2 := func(next HandlerFunc) HandlerFunc { return func(c *Context) error { return c.String(http.StatusOK, c.RouteInfo().Path) } } h := func(c *Context) error { return c.String(http.StatusOK, c.RouteInfo().Path) } g.Use(m1) g.GET("/help", h, m2) g.GET("/*", h, m2) g.GET("", h, m2) e.GET("unrelated", h, m2) e.GET("*", h, m2) _, m := request(http.MethodGet, "/group/help", e) assert.Equal(t, "/group/help", m) _, m = request(http.MethodGet, "/group/help/other", e) assert.Equal(t, "/group/*", m) _, m = request(http.MethodGet, "/group/404", e) assert.Equal(t, "/group/*", m) _, m = request(http.MethodGet, "/group", e) assert.Equal(t, "/group", m) _, m = request(http.MethodGet, "/other", e) assert.Equal(t, "/*", m) _, m = request(http.MethodGet, "/", e) assert.Equal(t, "/*", m) } func TestGroup_CONNECT(t *testing.T) { e := New() users := e.Group("/users") ri := users.CONNECT("/activate", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodConnect, ri.Method) assert.Equal(t, "/users/activate", ri.Path) assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodConnect, "/users/activate", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, `OK`, body) } func TestGroup_DELETE(t *testing.T) { e := New() users := e.Group("/users") ri := users.DELETE("/activate", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodDelete, ri.Method) assert.Equal(t, "/users/activate", 
ri.Path) assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodDelete, "/users/activate", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, `OK`, body) } func TestGroup_HEAD(t *testing.T) { e := New() users := e.Group("/users") ri := users.HEAD("/activate", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodHead, ri.Method) assert.Equal(t, "/users/activate", ri.Path) assert.Equal(t, http.MethodHead+":/users/activate", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodHead, "/users/activate", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, `OK`, body) } func TestGroup_OPTIONS(t *testing.T) { e := New() users := e.Group("/users") ri := users.OPTIONS("/activate", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodOptions, ri.Method) assert.Equal(t, "/users/activate", ri.Path) assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodOptions, "/users/activate", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, `OK`, body) } func TestGroup_PATCH(t *testing.T) { e := New() users := e.Group("/users") ri := users.PATCH("/activate", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodPatch, ri.Method) assert.Equal(t, "/users/activate", ri.Path) assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodPatch, "/users/activate", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, `OK`, body) } func TestGroup_POST(t *testing.T) { e := New() users := e.Group("/users") ri := users.POST("/activate", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodPost, ri.Method) assert.Equal(t, "/users/activate", ri.Path) assert.Equal(t, 
http.MethodPost+":/users/activate", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodPost, "/users/activate", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, `OK`, body) } func TestGroup_PUT(t *testing.T) { e := New() users := e.Group("/users") ri := users.PUT("/activate", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodPut, ri.Method) assert.Equal(t, "/users/activate", ri.Path) assert.Equal(t, http.MethodPut+":/users/activate", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodPut, "/users/activate", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, `OK`, body) } func TestGroup_TRACE(t *testing.T) { e := New() users := e.Group("/users") ri := users.TRACE("/activate", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Equal(t, http.MethodTrace, ri.Method) assert.Equal(t, "/users/activate", ri.Path) assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodTrace, "/users/activate", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, `OK`, body) } func TestGroup_RouteNotFound(t *testing.T) { var testCases = []struct { expectRoute any name string whenURL string expectCode int }{ { name: "404, route to static not found handler /group/a/c/xx", whenURL: "/group/a/c/xx", expectRoute: "GET /group/a/c/xx", expectCode: http.StatusNotFound, }, { name: "404, route to path param not found handler /group/a/:file", whenURL: "/group/a/echo.exe", expectRoute: "GET /group/a/:file", expectCode: http.StatusNotFound, }, { name: "404, route to any not found handler /group/*", whenURL: "/group/b/echo.exe", expectRoute: "GET /group/*", expectCode: http.StatusNotFound, }, { name: "200, route /group/a/c/df to /group/a/c/df", whenURL: "/group/a/c/df", expectRoute: "GET /group/a/c/df", expectCode: http.StatusOK, }, } for _, tc := range testCases { t.Run(tc.name, func(t 
*testing.T) { e := New() g := e.Group("/group") okHandler := func(c *Context) error { return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) } notFoundHandler := func(c *Context) error { return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) } g.GET("/", okHandler) g.GET("/a/c/df", okHandler) g.GET("/a/b*", okHandler) g.PUT("/*", okHandler) g.RouteNotFound("/a/c/xx", notFoundHandler) // static g.RouteNotFound("/a/:file", notFoundHandler) // param g.RouteNotFound("/*", notFoundHandler) // any req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, tc.expectCode, rec.Code) assert.Equal(t, tc.expectRoute, rec.Body.String()) }) } } func TestGroup_Any(t *testing.T) { e := New() users := e.Group("/users") ri := users.Any("/activate", func(c *Context) error { return c.String(http.StatusTeapot, "OK from ANY") }) assert.Equal(t, RouteAny, ri.Method) assert.Equal(t, "/users/activate", ri.Path) assert.Equal(t, RouteAny+":/users/activate", ri.Name) assert.Nil(t, ri.Parameters) status, body := request(http.MethodTrace, "/users/activate", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, `OK from ANY`, body) } func TestGroup_Match(t *testing.T) { e := New() myMethods := []string{http.MethodGet, http.MethodPost} users := e.Group("/users") ris := users.Match(myMethods, "/activate", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) assert.Len(t, ris, 2) for _, m := range myMethods { status, body := request(m, "/users/activate", e) assert.Equal(t, http.StatusTeapot, status) assert.Equal(t, `OK`, body) } } func TestGroup_MatchWithErrors(t *testing.T) { e := New() users := e.Group("/users") users.GET("/activate", func(c *Context) error { return c.String(http.StatusOK, "OK") }) myMethods := []string{http.MethodGet, http.MethodPost} errs := func() (errs []error) { defer func() { if r := recover(); r != nil { if tmpErr, ok := r.([]error); ok { errs = tmpErr 
return } panic(r) } }() users.Match(myMethods, "/activate", func(c *Context) error { return c.String(http.StatusTeapot, "OK") }) return nil }() assert.Len(t, errs, 1) assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed") for _, m := range myMethods { status, body := request(m, "/users/activate", e) expect := http.StatusTeapot if m == http.MethodGet { expect = http.StatusOK } assert.Equal(t, expect, status) assert.Equal(t, `OK`, body) } } func TestGroup_Static(t *testing.T) { e := New() g := e.Group("/books") ri := g.Static("/download", "_fixture") assert.Equal(t, http.MethodGet, ri.Method) assert.Equal(t, "/books/download*", ri.Path) assert.Equal(t, "GET:/books/download*", ri.Name) assert.Equal(t, []string{"*"}, ri.Parameters) req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) body := rec.Body.String() assert.True(t, strings.HasPrefix(body, "")) } func TestGroup_StaticMultiTest(t *testing.T) { var testCases = []struct { name string givenPrefix string givenRoot string whenURL string expectHeaderLocation string expectBodyStartsWith string expectBodyNotContains string expectStatus int }{ { name: "ok", givenPrefix: "/images", givenRoot: "_fixture/images", whenURL: "/test/images/walle.png", expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), }, { name: "ok, without prefix", givenPrefix: "", givenRoot: "_fixture/images", whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png` expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), }, { name: "nok, without prefix does not serve dir index", givenPrefix: "", givenRoot: "_fixture/images", whenURL: "/test/", // `/test` + `*` creates route `/test*` expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", }, 
{ name: "No file", givenPrefix: "/images", givenRoot: "_fixture/scripts", whenURL: "/test/images/bolt.png", expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", }, { name: "Directory", givenPrefix: "/images", givenRoot: "_fixture/images", whenURL: "/test/images/", expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", }, { name: "Directory Redirect", givenPrefix: "/", givenRoot: "_fixture", whenURL: "/test/folder", expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/test/folder/", expectBodyStartsWith: "", }, { name: "Directory Redirect with non-root path", givenPrefix: "/static", givenRoot: "_fixture", whenURL: "/test/static", expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/test/static/", expectBodyStartsWith: "", }, { name: "Prefixed directory 404 (request URL without slash)", givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" givenRoot: "_fixture", whenURL: "/test/folder", // no trailing slash expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", }, { name: "Prefixed directory redirect (without slash redirect to slash)", givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* givenRoot: "_fixture", whenURL: "/test/folder", // no trailing slash expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/test/folder/", expectBodyStartsWith: "", }, { name: "Directory with index.html", givenPrefix: "/", givenRoot: "_fixture", whenURL: "/test/", expectStatus: http.StatusOK, expectBodyStartsWith: "", }, { name: "Prefixed directory with index.html (prefix ending with slash)", givenPrefix: "/assets/", givenRoot: "_fixture", whenURL: "/test/assets/", expectStatus: http.StatusOK, expectBodyStartsWith: "", }, { name: "Prefixed directory with index.html (prefix ending without slash)", givenPrefix: "/assets", givenRoot: "_fixture", whenURL: "/test/assets/", expectStatus: 
http.StatusOK, expectBodyStartsWith: "", }, { name: "Sub-directory with index.html", givenPrefix: "/", givenRoot: "_fixture", whenURL: "/test/folder/", expectStatus: http.StatusOK, expectBodyStartsWith: "", }, { name: "nok, URL encoded path traversal (single encoding, slash - unix separator)", givenRoot: "_fixture/dist/public", whenURL: "/%2e%2e%2fprivate.txt", expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", expectBodyNotContains: `private file`, }, { name: "nok, URL encoded path traversal (single encoding, backslash - windows separator)", givenRoot: "_fixture/dist/public", whenURL: "/%2e%2e%5cprivate.txt", expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", expectBodyNotContains: `private file`, }, { name: "do not allow directory traversal (backslash - windows separator)", givenPrefix: "/", givenRoot: "_fixture/", whenURL: `/test/..\\middleware/basic_auth.go`, expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", }, { name: "do not allow directory traversal (slash - unix separator)", givenPrefix: "/", givenRoot: "_fixture/", whenURL: `/test/../middleware/basic_auth.go`, expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() g := e.Group("/test") g.Static(tc.givenPrefix, tc.givenRoot) req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, tc.expectStatus, rec.Code) body := rec.Body.String() if tc.expectBodyStartsWith != "" { assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) } else { assert.Equal(t, "", body) } if tc.expectBodyNotContains != "" { assert.NotContains(t, body, tc.expectBodyNotContains) } if tc.expectHeaderLocation != "" { assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) } else { _, ok := 
rec.Result().Header["Location"] assert.False(t, ok) } }) } } func TestGroup_FileFS(t *testing.T) { var testCases = []struct { whenFS fs.FS name string whenPath string whenFile string givenURL string expectStartsWith []byte expectCode int }{ { name: "ok", whenPath: "/walle", whenFS: os.DirFS("_fixture/images"), whenFile: "walle.png", givenURL: "/assets/walle", expectCode: http.StatusOK, expectStartsWith: []byte{0x89, 0x50, 0x4e}, }, { name: "nok, requesting invalid path", whenPath: "/walle", whenFS: os.DirFS("_fixture/images"), whenFile: "walle.png", givenURL: "/assets/walle.png", expectCode: http.StatusNotFound, expectStartsWith: []byte(`{"message":"Not Found"}`), }, { name: "nok, serving not existent file from filesystem", whenPath: "/walle", whenFS: os.DirFS("_fixture/images"), whenFile: "not-existent.png", givenURL: "/assets/walle", expectCode: http.StatusNotFound, expectStartsWith: []byte(`{"message":"Not Found"}`), }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() g := e.Group("/assets") g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, tc.expectCode, rec.Code) body := rec.Body.Bytes() if len(body) > len(tc.expectStartsWith) { body = body[:len(tc.expectStartsWith)] } assert.Equal(t, tc.expectStartsWith, body) }) } } func TestGroup_StaticPanic(t *testing.T) { var testCases = []struct { name string givenRoot string }{ { name: "panics for ../", givenRoot: "../images", }, { name: "panics for /", givenRoot: "/images", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() e.Filesystem = os.DirFS("./") g := e.Group("/assets") assert.Panics(t, func() { g.Static("/images", tc.givenRoot) }) }) } } func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) { var testCases = []struct { expectBody any name string whenURL string expectCode int givenCustom404 bool expectMiddlewareCalled bool 
}{ { name: "ok, custom 404 handler is called with middleware", givenCustom404: true, whenURL: "/group/test3", expectBody: "404 GET /group/*", expectCode: http.StatusNotFound, expectMiddlewareCalled: true, // because RouteNotFound is added after middleware is added }, { name: "ok, default group 404 handler is not called with middleware", givenCustom404: false, whenURL: "/group/test3", expectBody: "404 GET /*", expectCode: http.StatusNotFound, expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added }, { name: "ok, (no slash) default group 404 handler is called with middleware", givenCustom404: false, whenURL: "/group", expectBody: "404 GET /*", expectCode: http.StatusNotFound, expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { okHandler := func(c *Context) error { return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) } notFoundHandler := func(c *Context) error { return c.String(http.StatusNotFound, "404 "+c.Request().Method+" "+c.Path()) } e := New() e.GET("/test1", okHandler) e.RouteNotFound("/*", notFoundHandler) g := e.Group("/group") g.GET("/test1", okHandler) middlewareCalled := false g.Use(func(next HandlerFunc) HandlerFunc { return func(c *Context) error { middlewareCalled = true return next(c) } }) if tc.givenCustom404 { g.RouteNotFound("/*", notFoundHandler) } req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, tc.expectMiddlewareCalled, middlewareCalled) assert.Equal(t, tc.expectCode, rec.Code) assert.Equal(t, tc.expectBody, rec.Body.String()) }) } } ================================================ FILE: httperror.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import ( "errors" "fmt" "net/http" ) // Following 
// errors can produce HTTP status code by implementing HTTPStatusCoder interface.
// Each value is an *httpError: StatusCode() returns the listed code and Error()
// returns only the standard status text.
var (
	ErrBadRequest                  = &httpError{http.StatusBadRequest}            // 400
	ErrUnauthorized                = &httpError{http.StatusUnauthorized}          // 401
	ErrForbidden                   = &httpError{http.StatusForbidden}             // 403
	ErrNotFound                    = &httpError{http.StatusNotFound}              // 404
	ErrMethodNotAllowed            = &httpError{http.StatusMethodNotAllowed}      // 405
	ErrRequestTimeout              = &httpError{http.StatusRequestTimeout}        // 408
	ErrStatusRequestEntityTooLarge = &httpError{http.StatusRequestEntityTooLarge} // 413
	ErrUnsupportedMediaType        = &httpError{http.StatusUnsupportedMediaType}  // 415
	ErrTooManyRequests             = &httpError{http.StatusTooManyRequests}       // 429
	ErrInternalServerError         = &httpError{http.StatusInternalServerError}   // 500
	ErrBadGateway                  = &httpError{http.StatusBadGateway}            // 502
	ErrServiceUnavailable          = &httpError{http.StatusServiceUnavailable}    // 503
)

// Following errors fall into 500 (InternalServerError) category.
// They do not implement HTTPStatusCoder, so StatusCode returns 0 for them and
// ResolveResponseStatus maps them to 500.
var (
	ErrValidatorNotRegistered = errors.New("validator not registered")
	ErrRendererNotRegistered  = errors.New("renderer not registered")
	ErrInvalidRedirectCode    = errors.New("invalid redirect status code")
	ErrCookieNotFound         = errors.New("cookie not found")
	ErrInvalidCertOrKeyType   = errors.New("invalid cert or key type, must be string or []byte")
	ErrInvalidListenerNetwork = errors.New("invalid listener network")
)

// HTTPStatusCoder is interface that errors can implement to produce status code for HTTP response
type HTTPStatusCoder interface {
	StatusCode() int
}

// StatusCode returns status code from error if it implements HTTPStatusCoder interface.
// The whole wrap chain of err is searched (errors.As), so wrapped errors are supported.
// If error does not implement the interface (or err is nil) it returns 0.
func StatusCode(err error) int {
	var sc HTTPStatusCoder
	if errors.As(err, &sc) {
		return sc.StatusCode()
	}
	return 0
}

// ResolveResponseStatus returns the Response and HTTP status code that should be (or has been) sent for rw,
// given an optional error.
//
// This function is useful for middleware and handlers that need to figure out the HTTP status
// code to return based on the error that occurred or what was set in the response.
//
// The returned resp comes from UnwrapResponse(rw) and may be nil when rw does not unwrap to a *Response.
//
// Precedence rules:
//  1. If the response has already been committed, the committed status wins (err is ignored).
//  2. Otherwise, start with 200 OK (net/http default if WriteHeader is never called).
//  3. If the response has a non-zero suggested status, use it.
//  4. If err != nil, it overrides the suggested status:
//     - StatusCode(err) if non-zero
//     - otherwise 500 Internal Server Error.
func ResolveResponseStatus(rw http.ResponseWriter, err error) (resp *Response, status int) {
	resp, _ = UnwrapResponse(rw)

	// once committed (sent to the client), the wire status is fixed; err cannot change it.
	if resp != nil && resp.Committed {
		if resp.Status == 0 {
			// unlikely path, but fall back to net/http implicit default if handler never calls WriteHeader
			return resp, http.StatusOK
		}
		return resp, resp.Status
	}

	// net/http implicit default if handler never calls WriteHeader.
	status = http.StatusOK

	// suggested status written from middleware/handlers, if present.
	if resp != nil && resp.Status != 0 {
		status = resp.Status
	}

	// error overrides suggested status (matches typical Echo error-handler semantics).
	if err != nil {
		if s := StatusCode(err); s != 0 {
			status = s
		} else {
			status = http.StatusInternalServerError
		}
	}

	return resp, status
}

// NewHTTPError creates new instance of HTTPError
func NewHTTPError(code int, message string) *HTTPError {
	return &HTTPError{
		Code:    code,
		Message: message,
	}
}

// HTTPError represents an error that occurred while handling a request.
type HTTPError struct {
	// Code is status code for HTTP response
	Code int `json:"-"`
	// Message is the text sent to the client (serialized as "message"); Error() falls
	// back to http.StatusText(Code) when it is empty.
	Message string `json:"message"`
	// err is the optional wrapped cause, set by Wrap and returned by Unwrap.
	err error
}

// StatusCode returns status code for HTTP response
func (he *HTTPError) StatusCode() int {
	return he.Code
}

// Error makes it compatible with `error` interface.
// The result has the form "code=<Code>, message=<Message>" and, when a cause is
// wrapped, ", err=<cause>" is appended. An empty Message is replaced by the standard
// status text for Code.
func (he *HTTPError) Error() string {
	msg := he.Message
	if msg == "" {
		msg = http.StatusText(he.Code)
	}
	if he.err == nil {
		return fmt.Sprintf("code=%d, message=%v", he.Code, msg)
	}
	return fmt.Sprintf("code=%d, message=%v, err=%v", he.Code, msg, he.err.Error())
}

// Wrap returns a new *HTTPError with the same Code and Message that wraps err.
// The receiver is a value, so the original HTTPError is never modified.
func (he HTTPError) Wrap(err error) error {
	return &HTTPError{
		Code:    he.Code,
		Message: he.Message,
		err:     err,
	}
}

// Unwrap returns the error wrapped by Wrap (nil if none), so errors.Is / errors.As
// can inspect the cause.
func (he *HTTPError) Unwrap() error {
	return he.err
}

// httpError is the status-only error type behind the ErrXxx sentinel values.
type httpError struct {
	code int
}

// StatusCode returns the HTTP status code carried by the sentinel.
func (he httpError) StatusCode() int {
	return he.code
}

// Error returns the standard status text for the code.
func (he httpError) Error() string {
	return http.StatusText(he.code) // does not include status code
}

// Wrap converts the sentinel into an *HTTPError with the same code, the standard
// status text as Message, and err as the wrapped cause.
func (he httpError) Wrap(err error) error {
	return &HTTPError{
		Code:    he.code,
		Message: http.StatusText(he.code),
		err:     err,
	}
}

================================================
FILE: httperror_external_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors

// run tests as external package to get real feel for API
package echo_test

import (
	"encoding/json"
	"fmt"
	"github.com/labstack/echo/v5"
	"net/http"
	"net/http/httptest"
)

// ExampleDefaultHTTPErrorHandler shows a handler returning a custom error type whose
// StatusCode and MarshalJSON decide the response status and body.
func ExampleDefaultHTTPErrorHandler() {
	e := echo.New()
	e.GET("/api/endpoint", func(c *echo.Context) error {
		return &apiError{
			Code: http.StatusBadRequest,
			Body: map[string]any{"message": "custom error"},
		}
	})

	req := httptest.NewRequest(http.MethodGet, "/api/endpoint?err=1", nil)
	resp := httptest.NewRecorder()

	e.ServeHTTP(resp, req)

	fmt.Printf("%d %s", resp.Code, resp.Body.String())

	// Output: 400 {"error":{"message":"custom error"}}
}

// apiError is an example error that implements StatusCode (status of the response)
// and MarshalJSON (body of the response, wrapped in an "error" object).
type apiError struct {
	Code int
	Body any
}

func (e *apiError) StatusCode() int {
	return e.Code
}

func (e *apiError) MarshalJSON() ([]byte, error) {
	type body struct {
		Error any `json:"error"`
	}
	return json.Marshal(body{Error: e.Body})
}

func (e *apiError) Error() string {
	return http.StatusText(e.Code)
}
================================================ FILE: httperror_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import ( "errors" "fmt" "net/http" "testing" "github.com/stretchr/testify/assert" ) func TestHTTPError_StatusCode(t *testing.T) { var err error = &HTTPError{Code: http.StatusBadRequest, Message: "my error message"} code := 0 var sc HTTPStatusCoder if errors.As(err, &sc) { code = sc.StatusCode() } assert.Equal(t, http.StatusBadRequest, code) } func TestHTTPError_Error(t *testing.T) { var testCases = []struct { name string error error expect string }{ { name: "ok, without message", error: &HTTPError{Code: http.StatusBadRequest}, expect: "code=400, message=Bad Request", }, { name: "ok, with message", error: &HTTPError{Code: http.StatusBadRequest, Message: "my error message"}, expect: "code=400, message=my error message", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { assert.Equal(t, tc.expect, tc.error.Error()) }) } } func TestHTTPError_WrapUnwrap(t *testing.T) { err := &HTTPError{Code: http.StatusBadRequest, Message: "bad"} wrapped := err.Wrap(errors.New("my_error")).(*HTTPError) err.Code = http.StatusOK err.Message = "changed" assert.Equal(t, http.StatusBadRequest, wrapped.Code) assert.Equal(t, "bad", wrapped.Message) assert.Equal(t, errors.New("my_error"), wrapped.Unwrap()) assert.Equal(t, "code=400, message=bad, err=my_error", wrapped.Error()) } func TestNewHTTPError(t *testing.T) { err := NewHTTPError(http.StatusBadRequest, "bad") err2 := &HTTPError{Code: http.StatusBadRequest, Message: "bad"} assert.Equal(t, err2, err) } func TestStatusCode(t *testing.T) { var testCases = []struct { name string err error expect int }{ { name: "ok, HTTPError", err: &HTTPError{Code: http.StatusNotFound}, expect: http.StatusNotFound, }, { name: "ok, sentinel error", err: ErrNotFound, expect: http.StatusNotFound, }, { name: 
"ok, wrapped HTTPError", err: fmt.Errorf("wrapped: %w", &HTTPError{Code: http.StatusTeapot}), expect: http.StatusTeapot, }, { name: "nok, normal error", err: errors.New("error"), expect: 0, }, { name: "nok, nil", err: nil, expect: 0, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { assert.Equal(t, tc.expect, StatusCode(tc.err)) }) } } func TestResolveResponseStatus(t *testing.T) { someErr := errors.New("some error") var testCases = []struct { name string whenResp http.ResponseWriter whenErr error expectStatus int expectResp bool }{ { name: "nil resp, nil err -> 200", whenResp: nil, whenErr: nil, expectStatus: http.StatusOK, expectResp: false, }, { name: "resp suggested status used when no error", whenResp: &Response{Status: http.StatusCreated}, whenErr: nil, expectStatus: http.StatusCreated, expectResp: true, }, { name: "error overrides suggested status with StatusCode(err)", whenResp: &Response{Status: http.StatusAccepted}, whenErr: ErrBadRequest, expectStatus: http.StatusBadRequest, expectResp: true, }, { name: "error overrides suggested status with 500 when StatusCode(err)==0", whenResp: &Response{Status: http.StatusAccepted}, whenErr: ErrInternalServerError, expectStatus: http.StatusInternalServerError, expectResp: true, }, { name: "nil resp, error -> 500 when StatusCode(err)==0", whenResp: nil, whenErr: someErr, expectStatus: http.StatusInternalServerError, expectResp: false, }, { name: "committed response wins over error", whenResp: &Response{Committed: true, Status: http.StatusNoContent}, whenErr: someErr, expectStatus: http.StatusNoContent, expectResp: true, }, { name: "committed response with status 0 falls back to 200 (defensive)", whenResp: &Response{Committed: true, Status: 0}, whenErr: someErr, expectStatus: http.StatusOK, expectResp: true, }, { name: "resp with status 0 and no error -> 200", whenResp: &Response{Status: 0}, whenErr: nil, expectStatus: http.StatusOK, expectResp: true, }, } for _, tc := range testCases { 
t.Run(tc.name, func(t *testing.T) { resp, status := ResolveResponseStatus(tc.whenResp, tc.whenErr) assert.Equal(t, tc.expectResp, resp != nil) assert.Equal(t, tc.expectStatus, status) }) } } ================================================ FILE: ip.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import ( "net" "net/http" "strings" ) /** By: https://github.com/tmshn (See: https://github.com/labstack/echo/pull/1478 , https://github.com/labstack/echox/pull/134 ) Source: https://echo.labstack.com/guide/ip-address/ IP address plays fundamental role in HTTP; it's used for access control, auditing, geo-based access analysis and more. Echo provides handy method [`Context#RealIP()`](https://godoc.org/github.com/labstack/echo#Context) for that. However, it is not trivial to retrieve the _real_ IP address from requests especially when you put L7 proxies before the application. In such situation, _real_ IP needs to be relayed on HTTP layer from proxies to your app, but you must not trust HTTP headers unconditionally. Otherwise, you might give someone a chance of deceiving you. **A security risk!** To retrieve IP address reliably/securely, you must let your application be aware of the entire architecture of your infrastructure. In Echo, this can be done by configuring `Echo#IPExtractor` appropriately. This guides show you why and how. > Note: if you don't set `Echo#IPExtractor` explicitly, Echo fallback to legacy behavior, which is not a good choice. Let's start from two questions to know the right direction: 1. Do you put any HTTP (L7) proxy in front of the application? - It includes both cloud solutions (such as AWS ALB or GCP HTTP LB) and OSS ones (such as Nginx, Envoy or Istio ingress gateway). 2. If yes, what HTTP header do your proxies use to pass client IP to the application? ## Case 1. 
With no proxy If you put no proxy (e.g.: directly facing the internet), all you need to (and have to) see is the IP address from the network layer. Any HTTP header is untrustable because the clients have full control over which headers are set. In this case, use `echo.ExtractIPDirect()`. ```go e.IPExtractor = echo.ExtractIPDirect() ``` ## Case 2. With proxies using `X-Forwarded-For` header [`X-Forwarded-For` (XFF)](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For) is the popular header to relay clients' IP addresses. At each hop on the proxies, they append the request IP address at the end of the header. The following example diagram illustrates this behavior. ```text ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ "Origin" │───────────>│ Proxy 1 │───────────>│ Proxy 2 │───────────>│ Your app │ │ (IP: a) │ │ (IP: b) │ │ (IP: c) │ │ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ Case 1. XFF: "" "a" "a, b" ~~~~~~ Case 2. XFF: "x" "x, a" "x, a, b" ~~~~~~~~~ ↑ What your app will see ``` In this case, use the **first _untrustable_ IP reading from the right**. Never use the first one reading from the left, as it is configurable by the client. Here "trustable" means "you are sure the IP address belongs to your infrastructure". In the above example, if `b` and `c` are trustable, the IP address of the client is `a` for both cases, never `x`. In Echo, use `ExtractIPFromXFFHeader(...TrustOption)`. ```go e.IPExtractor = echo.ExtractIPFromXFFHeader() ``` By default, it trusts internal IP addresses (loopback, link-local unicast, private-use and unique local address from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and [RFC4193](https://tools.ietf.org/html/rfc4193)). To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s.
E.g.: ```go e.IPExtractor = echo.ExtractIPFromXFFHeader( TrustLinkLocal(false), TrustIPRange(lbIPRange), ) ``` - Ref: https://godoc.org/github.com/labstack/echo#TrustOption ## Case 3. With proxies using `X-Real-IP` header `X-Real-IP` is another HTTP header to relay clients' IP addresses, but unlike XFF it carries only one address. If your proxies set this header, use `ExtractIPFromRealIPHeader(...TrustOption)`. ```go e.IPExtractor = echo.ExtractIPFromRealIPHeader() ``` Again, it trusts internal IP addresses by default (loopback, link-local unicast, private-use and unique local address from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and [RFC4193](https://tools.ietf.org/html/rfc4193)). To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s. - Ref: https://godoc.org/github.com/labstack/echo#TrustOption > **Never forget** to configure the outermost proxy (i.e., at the edge of your infrastructure) **not to pass through incoming headers**. > Otherwise there is a chance of fraud, as clients control those headers. ## About default behavior In default behavior, Echo sees all of the first XFF header, the X-Real-IP header and the IP from the network layer. As you might already notice after reading this article, this is not good. The sole reason this is the default is backward compatibility.
## Private IP ranges See: https://en.wikipedia.org/wiki/Private_network Private IPv4 address ranges (RFC 1918): * 10.0.0.0 – 10.255.255.255 (24-bit block) * 172.16.0.0 – 172.31.255.255 (20-bit block) * 192.168.0.0 – 192.168.255.255 (16-bit block) Private IPv6 address ranges: * fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA) */ type ipChecker struct { trustExtraRanges []*net.IPNet trustLoopback bool trustLinkLocal bool trustPrivateNet bool } // TrustOption is config for which IP address to trust type TrustOption func(*ipChecker) // TrustLoopback configures if you trust loopback address (default: true). func TrustLoopback(v bool) TrustOption { return func(c *ipChecker) { c.trustLoopback = v } } // TrustLinkLocal configures if you trust link-local address (default: true). func TrustLinkLocal(v bool) TrustOption { return func(c *ipChecker) { c.trustLinkLocal = v } } // TrustPrivateNet configures if you trust private network address (default: true). func TrustPrivateNet(v bool) TrustOption { return func(c *ipChecker) { c.trustPrivateNet = v } } // TrustIPRange add trustable IP ranges using CIDR notation. func TrustIPRange(ipRange *net.IPNet) TrustOption { return func(c *ipChecker) { c.trustExtraRanges = append(c.trustExtraRanges, ipRange) } } func newIPChecker(configs []TrustOption) *ipChecker { checker := &ipChecker{trustLoopback: true, trustLinkLocal: true, trustPrivateNet: true} for _, configure := range configs { configure(checker) } return checker } func (c *ipChecker) trust(ip net.IP) bool { if c.trustLoopback && ip.IsLoopback() { return true } if c.trustLinkLocal && ip.IsLinkLocalUnicast() { return true } if c.trustPrivateNet && ip.IsPrivate() { return true } for _, trustedRange := range c.trustExtraRanges { if trustedRange.Contains(ip) { return true } } return false } // IPExtractor is a function to extract IP addr from http.Request. // Set appropriate one to Echo#IPExtractor. // See https://echo.labstack.com/guide/ip-address for more details. 
type IPExtractor func(*http.Request) string // ExtractIPDirect extracts IP address using actual IP address. // Use this if your server faces to internet directory (i.e.: uses no proxy). func ExtractIPDirect() IPExtractor { return extractIP } func extractIP(req *http.Request) string { host, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { if net.ParseIP(req.RemoteAddr) != nil { return req.RemoteAddr } return "" } return host } // ExtractIPFromRealIPHeader extracts IP address using x-real-ip header. // Use this if you put proxy which uses this header. func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { checker := newIPChecker(options) return func(req *http.Request) string { realIP := req.Header.Get(HeaderXRealIP) if realIP != "" { realIP = strings.TrimPrefix(realIP, "[") realIP = strings.TrimSuffix(realIP, "]") if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) { return realIP } } return extractIP(req) } } // ExtractIPFromXFFHeader extracts IP address using x-forwarded-for header. // Use this if you put proxy which uses this header. // This returns nearest untrustable IP. If all IPs are trustable, returns furthest one (i.e.: XFF[0]). func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor { checker := newIPChecker(options) return func(req *http.Request) string { directIP := extractIP(req) xffs := req.Header[HeaderXForwardedFor] if len(xffs) == 0 { return directIP } ips := append(strings.Split(strings.Join(xffs, ","), ","), directIP) for i := len(ips) - 1; i >= 0; i-- { ips[i] = strings.TrimSpace(ips[i]) ips[i] = strings.TrimPrefix(ips[i], "[") ips[i] = strings.TrimSuffix(ips[i], "]") ip := net.ParseIP(ips[i]) if ip == nil { // Unable to parse IP; cannot trust entire records return directIP } if !checker.trust(ip) { return ip.String() } } // All of the IPs are trusted; return first element because it is furthest from server (best effort strategy). 
return strings.TrimSpace(ips[0]) } } ================================================ FILE: ip_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import ( "net" "net/http" "testing" "github.com/stretchr/testify/assert" ) func mustParseCIDR(s string) *net.IPNet { _, IPNet, err := net.ParseCIDR(s) if err != nil { panic(err) } return IPNet } func TestIPChecker_TrustOption(t *testing.T) { var testCases = []struct { name string whenIP string givenOptions []TrustOption expect bool }{ { name: "ip is within trust range, trusts additional private IPV6 network", givenOptions: []TrustOption{ TrustLoopback(false), TrustLinkLocal(false), TrustPrivateNet(false), // this is private IPv6 ip // CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48 // Address: 2001:0db8:0000:0000:0000:0000:0000:0103 // Range start: 2001:0db8:0000:0000:0000:0000:0000:0000 // Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff TrustIPRange(mustParseCIDR("2001:db8::103/48")), }, whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", expect: true, }, { name: "ip is within trust range, trusts additional private IPV6 network", givenOptions: []TrustOption{ TrustIPRange(mustParseCIDR("2001:db8::103/48")), }, whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", expect: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { checker := newIPChecker(tc.givenOptions) result := checker.trust(net.ParseIP(tc.whenIP)) assert.Equal(t, tc.expect, result) }) } } func TestTrustIPRange(t *testing.T) { var testCases = []struct { name string givenRange string whenIP string expect bool }{ { name: "ip is within trust range, IPV6 network range", // CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48 // Address: 2001:0db8:0000:0000:0000:0000:0000:0103 // Range start: 2001:0db8:0000:0000:0000:0000:0000:0000 // Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff givenRange: 
"2001:db8::103/48", whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", expect: true, }, { name: "ip is outside (upper bounds) of trust range, IPV6 network range", givenRange: "2001:db8::103/48", whenIP: "2001:0db8:0001:0000:0000:0000:0000:0000", expect: false, }, { name: "ip is outside (lower bounds) of trust range, IPV6 network range", givenRange: "2001:db8::103/48", whenIP: "2001:0db7:ffff:ffff:ffff:ffff:ffff:ffff", expect: false, }, { name: "ip is within trust range, IPV4 network range", // CIDR Notation: 8.8.8.8/24 // Address: 8.8.8.8 // Range start: 8.8.8.0 // Range end: 8.8.8.255 givenRange: "8.8.8.0/24", whenIP: "8.8.8.8", expect: true, }, { name: "ip is within trust range, IPV4 network range", // CIDR Notation: 8.8.8.8/24 // Address: 8.8.8.8 // Range start: 8.8.8.0 // Range end: 8.8.8.255 givenRange: "8.8.8.0/24", whenIP: "8.8.8.8", expect: true, }, { name: "ip is outside (upper bounds) of trust range, IPV4 network range", givenRange: "8.8.8.0/24", whenIP: "8.8.9.0", expect: false, }, { name: "ip is outside (lower bounds) of trust range, IPV4 network range", givenRange: "8.8.8.0/24", whenIP: "8.8.7.255", expect: false, }, { name: "public ip, trust everything in IPV4 network range", givenRange: "0.0.0.0/0", whenIP: "8.8.8.8", expect: true, }, { name: "internal ip, trust everything in IPV4 network range", givenRange: "0.0.0.0/0", whenIP: "127.0.10.1", expect: true, }, { name: "public ip, trust everything in IPV6 network range", givenRange: "::/0", whenIP: "2a00:1450:4026:805::200e", expect: true, }, { name: "internal ip, trust everything in IPV6 network range", givenRange: "::/0", whenIP: "0:0:0:0:0:0:0:1", expect: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { cidr := mustParseCIDR(tc.givenRange) checker := newIPChecker([]TrustOption{ TrustLoopback(false), // disable to avoid interference TrustLinkLocal(false), // disable to avoid interference TrustPrivateNet(false), // disable to avoid interference TrustIPRange(cidr), }) 
result := checker.trust(net.ParseIP(tc.whenIP)) assert.Equal(t, tc.expect, result) }) } } func TestTrustPrivateNet(t *testing.T) { var testCases = []struct { name string whenIP string expect bool }{ { name: "do not trust public IPv4 address", whenIP: "8.8.8.8", expect: false, }, { name: "do not trust public IPv6 address", whenIP: "2a00:1450:4026:805::200e", expect: false, }, { // Class A: 10.0.0.0 — 10.255.255.255 name: "do not trust IPv4 just outside of class A (lower bounds)", whenIP: "9.255.255.255", expect: false, }, { name: "do not trust IPv4 just outside of class A (upper bounds)", whenIP: "11.0.0.0", expect: false, }, { name: "trust IPv4 of class A (lower bounds)", whenIP: "10.0.0.0", expect: true, }, { name: "trust IPv4 of class A (upper bounds)", whenIP: "10.255.255.255", expect: true, }, { // Class B: 172.16.0.0 — 172.31.255.255 name: "do not trust IPv4 just outside of class B (lower bounds)", whenIP: "172.15.255.255", expect: false, }, { name: "do not trust IPv4 just outside of class B (upper bounds)", whenIP: "172.32.0.0", expect: false, }, { name: "trust IPv4 of class B (lower bounds)", whenIP: "172.16.0.0", expect: true, }, { name: "trust IPv4 of class B (upper bounds)", whenIP: "172.31.255.255", expect: true, }, { // Class C: 192.168.0.0 — 192.168.255.255 name: "do not trust IPv4 just outside of class C (lower bounds)", whenIP: "192.167.255.255", expect: false, }, { name: "do not trust IPv4 just outside of class C (upper bounds)", whenIP: "192.169.0.0", expect: false, }, { name: "trust IPv4 of class C (lower bounds)", whenIP: "192.168.0.0", expect: true, }, { name: "trust IPv4 of class C (upper bounds)", whenIP: "192.168.255.255", expect: true, }, { // fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA) // splits the address block in two equally sized halves, fc00::/8 and fd00::/8. 
// https://en.wikipedia.org/wiki/Unique_local_address name: "trust IPv6 private address", whenIP: "fdfc:3514:2cb3:4bd5::", expect: true, }, { name: "do not trust IPv6 just out of /fd (upper bounds)", whenIP: "/fe00:0000:0000:0000:0000", expect: false, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { checker := newIPChecker([]TrustOption{ TrustLoopback(false), // disable to avoid interference TrustLinkLocal(false), // disable to avoid interference TrustPrivateNet(true), }) result := checker.trust(net.ParseIP(tc.whenIP)) assert.Equal(t, tc.expect, result) }) } } func TestTrustLinkLocal(t *testing.T) { var testCases = []struct { name string whenIP string expect bool }{ { name: "trust link local IPv4 address (lower bounds)", whenIP: "169.254.0.0", expect: true, }, { name: "trust link local IPv4 address (upper bounds)", whenIP: "169.254.255.255", expect: true, }, { name: "do not trust link local IPv4 address (outside of lower bounds)", whenIP: "169.253.255.255", expect: false, }, { name: "do not trust link local IPv4 address (outside of upper bounds)", whenIP: "169.255.0.0", expect: false, }, { name: "trust link local IPv6 address ", whenIP: "fe80::1", expect: true, }, { name: "do not trust link local IPv6 address ", whenIP: "fec0::1", expect: false, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { checker := newIPChecker([]TrustOption{ TrustLoopback(false), // disable to avoid interference TrustPrivateNet(false), // disable to avoid interference TrustLinkLocal(true), }) result := checker.trust(net.ParseIP(tc.whenIP)) assert.Equal(t, tc.expect, result) }) } } func TestTrustLoopback(t *testing.T) { var testCases = []struct { name string whenIP string expect bool }{ { name: "trust IPv4 as localhost", whenIP: "127.0.0.1", expect: true, }, { name: "trust IPv6 as localhost", whenIP: "::1", expect: true, }, { name: "do not trust public ip as localhost", whenIP: "8.8.8.8", expect: false, }, } for _, tc := range testCases { 
t.Run(tc.name, func(t *testing.T) { checker := newIPChecker([]TrustOption{ TrustLinkLocal(false), // disable to avoid interference TrustPrivateNet(false), // disable to avoid interference TrustLoopback(true), }) result := checker.trust(net.ParseIP(tc.whenIP)) assert.Equal(t, tc.expect, result) }) } } func TestExtractIPDirect(t *testing.T) { var testCases = []struct { name string whenRequest http.Request expectIP string }{ { name: "request has no headers, extracts IP from request remote addr", whenRequest: http.Request{ RemoteAddr: "203.0.113.1:8080", }, expectIP: "203.0.113.1", }, { name: "remote addr is IP without port, extracts IP directly", whenRequest: http.Request{ RemoteAddr: "203.0.113.1", }, expectIP: "203.0.113.1", }, { name: "remote addr is IPv6 without port, extracts IP directly", whenRequest: http.Request{ RemoteAddr: "2001:db8::1", }, expectIP: "2001:db8::1", }, { name: "remote addr is IPv6 with port", whenRequest: http.Request{ RemoteAddr: "[2001:db8::1]:8080", }, expectIP: "2001:db8::1", }, { name: "remote addr is invalid, returns empty string", whenRequest: http.Request{ RemoteAddr: "invalid-ip-format", }, expectIP: "", }, { name: "request is from external IP has X-Real-Ip header, extractor still extracts IP from request remote addr", whenRequest: http.Request{ Header: http.Header{ HeaderXRealIP: []string{"203.0.113.10"}, }, RemoteAddr: "203.0.113.1:8080", }, expectIP: "203.0.113.1", }, { name: "request is from internal IP and has Real-IP header, extractor still extracts internal IP from request remote addr", whenRequest: http.Request{ Header: http.Header{ HeaderXRealIP: []string{"203.0.113.10"}, }, RemoteAddr: "127.0.0.1:8080", }, expectIP: "127.0.0.1", }, { name: "request is from external IP and has XFF + Real-IP header, extractor still extracts external IP from request remote addr", whenRequest: http.Request{ Header: http.Header{ HeaderXRealIP: []string{"203.0.113.10"}, HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 
2001:db8::103, 192.168.0.102, 169.254.0.101"}, }, RemoteAddr: "203.0.113.1:8080", }, expectIP: "203.0.113.1", }, { name: "request is from internal IP and has XFF + Real-IP header, extractor still extracts internal IP from request remote addr", whenRequest: http.Request{ Header: http.Header{ HeaderXRealIP: []string{"127.0.0.1"}, HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, }, RemoteAddr: "127.0.0.1:8080", }, expectIP: "127.0.0.1", }, { name: "request is from external IP and has XFF header, extractor still extracts external IP from request remote addr", whenRequest: http.Request{ Header: http.Header{ HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, }, RemoteAddr: "203.0.113.1:8080", }, expectIP: "203.0.113.1", }, { name: "request is from internal IP and has XFF header, extractor still extracts internal IP from request remote addr", whenRequest: http.Request{ Header: http.Header{ HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, }, RemoteAddr: "127.0.0.1:8080", }, expectIP: "127.0.0.1", }, { name: "request is from internal IP and has INVALID XFF header, extractor still extracts internal IP from request remote addr", whenRequest: http.Request{ Header: http.Header{ HeaderXForwardedFor: []string{"this.is.broken.lol, 169.254.0.101"}, }, RemoteAddr: "127.0.0.1:8080", }, expectIP: "127.0.0.1", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { extractedIP := ExtractIPDirect()(&tc.whenRequest) assert.Equal(t, tc.expectIP, extractedIP) }) } } func TestExtractIPFromRealIPHeader(t *testing.T) { _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64") var testCases = []struct { whenRequest http.Request name string expectIP string givenTrustOptions []TrustOption }{ { 
name: "request has no headers, extracts IP from request remote addr", whenRequest: http.Request{ RemoteAddr: "203.0.113.1:8080", }, expectIP: "203.0.113.1", }, { name: "request is from external IP has INVALID external X-Real-Ip header, extract IP from remote addr", whenRequest: http.Request{ Header: http.Header{ HeaderXRealIP: []string{"xxx.yyy.zzz.ccc"}, // <-- this is invalid }, RemoteAddr: "203.0.113.1:8080", }, expectIP: "203.0.113.1", }, { name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", whenRequest: http.Request{ Header: http.Header{ HeaderXRealIP: []string{"203.0.113.199"}, // <-- this is untrusted }, RemoteAddr: "203.0.113.1:8080", }, expectIP: "203.0.113.1", }, { name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", whenRequest: http.Request{ Header: http.Header{ HeaderXRealIP: []string{"[2001:db8::113:199]"}, // <-- this is untrusted }, RemoteAddr: "[2001:db8::113:1]:8080", }, expectIP: "2001:db8::113:1", }, { name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" }, whenRequest: http.Request{ Header: http.Header{ HeaderXRealIP: []string{"203.0.113.199"}, }, RemoteAddr: "203.0.113.1:8080", }, expectIP: "203.0.113.199", }, { name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64" }, whenRequest: http.Request{ Header: http.Header{ HeaderXRealIP: []string{"[2001:db8::113:199]"}, }, RemoteAddr: "[2001:db8::113:1]:8080", }, expectIP: "2001:db8::113:199", }, { name: "request is from external IP has XFF and 
valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" }, whenRequest: http.Request{ Header: http.Header{ HeaderXRealIP: []string{"203.0.113.199"}, HeaderXForwardedFor: []string{"203.0.113.198, 203.0.113.197"}, // <-- should not affect anything }, RemoteAddr: "203.0.113.1:8080", }, expectIP: "203.0.113.199", }, { name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64" }, whenRequest: http.Request{ Header: http.Header{ HeaderXRealIP: []string{"[2001:db8::113:199]"}, HeaderXForwardedFor: []string{"[2001:db8::113:198], [2001:db8::113:197]"}, // <-- should not affect anything }, RemoteAddr: "[2001:db8::113:1]:8080", }, expectIP: "2001:db8::113:199", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { extractedIP := ExtractIPFromRealIPHeader(tc.givenTrustOptions...)(&tc.whenRequest) assert.Equal(t, tc.expectIP, extractedIP) }) } } func TestExtractIPFromXFFHeader(t *testing.T) { _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64") var testCases = []struct { whenRequest http.Request name string expectIP string givenTrustOptions []TrustOption }{ { name: "request has no headers, extracts IP from request remote addr", whenRequest: http.Request{ RemoteAddr: "203.0.113.1:8080", }, expectIP: "203.0.113.1", }, { name: "request has INVALID external XFF header, extract IP from remote addr", whenRequest: http.Request{ Header: http.Header{ HeaderXForwardedFor: []string{"xxx.yyy.zzz.ccc, 127.0.0.2"}, // <-- this is invalid }, RemoteAddr: "127.0.0.1:8080", }, expectIP: 
"127.0.0.1", }, { name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain", whenRequest: http.Request{ Header: http.Header{ HeaderXForwardedFor: []string{"127.0.0.3, 127.0.0.2, 127.0.0.1"}, }, RemoteAddr: "127.0.0.1:8080", }, expectIP: "127.0.0.3", }, { name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain", whenRequest: http.Request{ Header: http.Header{ HeaderXForwardedFor: []string{"[fe80::3], [fe80::2], [fe80::1]"}, }, RemoteAddr: "[fe80::1]:8080", }, expectIP: "fe80::3", }, { name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr", whenRequest: http.Request{ Header: http.Header{ HeaderXForwardedFor: []string{"203.0.113.199"}, // <-- this is untrusted }, RemoteAddr: "203.0.113.1:8080", }, expectIP: "203.0.113.1", }, { name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr", whenRequest: http.Request{ Header: http.Header{ HeaderXForwardedFor: []string{"[2001:db8::1]"}, // <-- this is untrusted }, RemoteAddr: "[2001:db8::2]:8080", }, expectIP: "2001:db8::2", }, { name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header", givenTrustOptions: []TrustOption{ TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" }, // from request its seems that request has been proxied through 6 servers. // 1) 203.0.1.100 (this is external IP set by 203.0.100.100 which we do not trust - could be spoofed) // 2) 203.0.100.100 (this is outside of our network but set by 203.0.113.199 which we trust to set correct IPs) // 3) 203.0.113.199 (we trust, for example maybe our proxy from some other office) // 4) 192.168.1.100 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products) // 5) 127.0.0.1 (is proxy on localhost. 
maybe we have Nginx in front of our Echo instance doing some routing) whenRequest: http.Request{ Header: http.Header{ HeaderXForwardedFor: []string{"203.0.1.100, 203.0.100.100, 203.0.113.199, 192.168.1.100"}, }, RemoteAddr: "127.0.0.1:8080", // IP of proxy upstream of our APP }, expectIP: "203.0.100.100", // this is first trusted IP in XFF chain }, { name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header", givenTrustOptions: []TrustOption{ TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64" }, // from request its seems that request has been proxied through 6 servers. // 1) 2001:db8:1::1:100 (this is external IP set by 2001:db8:2::100:100 which we do not trust - could be spoofed) // 2) 2001:db8:2::100:100 (this is outside of our network but set by 2001:db8::113:199 which we trust to set correct IPs) // 3) 2001:db8::113:199 (we trust, for example maybe our proxy from some other office) // 4) fd12:3456:789a:1::1 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products) // 5) fe80::1 (is proxy on localhost. 
maybe we have Nginx in front of our Echo instance doing some routing) whenRequest: http.Request{ Header: http.Header{ HeaderXForwardedFor: []string{"[2001:db8:1::1:100], [2001:db8:2::100:100], [2001:db8::113:199], [fd12:3456:789a:1::1]"}, }, RemoteAddr: "[fe80::1]:8080", // IP of proxy upstream of our APP }, expectIP: "2001:db8:2::100:100", // this is first trusted IP in XFF chain }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { extractedIP := ExtractIPFromXFFHeader(tc.givenTrustOptions...)(&tc.whenRequest) assert.Equal(t, tc.expectIP, extractedIP) }) } } ================================================ FILE: json.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import ( "encoding/json" ) // DefaultJSONSerializer implements JSON encoding using encoding/json. type DefaultJSONSerializer struct{} // Serialize converts an interface into a json and writes it to the response. // You can optionally use the indent parameter to produce pretty JSONs. func (d DefaultJSONSerializer) Serialize(c *Context, target any, indent string) error { enc := json.NewEncoder(c.Response()) if indent != "" { enc.SetIndent("", indent) } return enc.Encode(target) } // Deserialize reads a JSON from a request body and converts it into an interface. func (d DefaultJSONSerializer) Deserialize(c *Context, target any) error { if err := json.NewDecoder(c.Request().Body).Decode(target); err != nil { return ErrBadRequest.Wrap(err) } return nil } ================================================ FILE: json_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package echo import ( "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" "strings" "testing" ) // Note this test is deliberately simple as there's not a lot to test. 
// Just need to ensure it writes JSONs. The heavy work is done by the context methods. func TestDefaultJSONCodec_Encode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) // Echo assert.Equal(t, e, c.Echo()) // Request assert.NotNil(t, c.Request()) // Response assert.NotNil(t, c.Response()) //-------- // Default JSON encoder //-------- enc := new(DefaultJSONSerializer) err := enc.Serialize(c, user{ID: 1, Name: "Jon Snow"}, "") if assert.NoError(t, err) { assert.Equal(t, userJSON+"\n", rec.Body.String()) } req = httptest.NewRequest(http.MethodPost, "/", nil) rec = httptest.NewRecorder() c = e.NewContext(req, rec) err = enc.Serialize(c, user{ID: 1, Name: "Jon Snow"}, " ") if assert.NoError(t, err) { assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } } // Note this test is deliberately simple as there's not a lot to test. // Just need to ensure it writes JSONs. The heavy work is done by the context methods. 
func TestDefaultJSONCodec_Decode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) // Echo assert.Equal(t, e, c.Echo()) // Request assert.NotNil(t, c.Request()) // Response assert.NotNil(t, c.Response()) //-------- // Default JSON encoder //-------- enc := new(DefaultJSONSerializer) var u = user{} err := enc.Deserialize(c, &u) if assert.NoError(t, err) { assert.Equal(t, u, user{ID: 1, Name: "Jon Snow"}) } var userUnmarshalSyntaxError = user{} req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(invalidContent)) rec = httptest.NewRecorder() c = e.NewContext(req, rec) err = enc.Deserialize(c, &userUnmarshalSyntaxError) assert.IsType(t, &HTTPError{}, err) assert.EqualError(t, err, "code=400, message=Bad Request, err=invalid character 'i' looking for beginning of value") var userUnmarshalTypeError = struct { ID string `json:"id"` Name string `json:"name"` }{} req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec = httptest.NewRecorder() c = e.NewContext(req, rec) err = enc.Deserialize(c, &userUnmarshalTypeError) assert.IsType(t, &HTTPError{}, err) assert.EqualError(t, err, "code=400, message=Bad Request, err=json: cannot unmarshal number into Go struct field .id of type string") } ================================================ FILE: middleware/DEVELOPMENT.md ================================================ # Development Guidelines for middlewares ## Best practices: * Do not use `panic` in middleware creator functions in case of invalid configuration. * In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead because previous middlewares up in call chain could have logic for dealing with returned errors. 
* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they want to create middleware with panics or with returning errors on configuration errors. * When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same. ================================================ FILE: middleware/basic_auth.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "bytes" "cmp" "encoding/base64" "errors" "strconv" "strings" "github.com/labstack/echo/v5" ) // BasicAuthConfig defines the config for BasicAuthWithConfig middleware. // // SECURITY: The Validator function is responsible for securely comparing credentials. // See BasicAuthValidator documentation for guidance on preventing timing attacks. type BasicAuthConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers // this function would be called once for each header until first valid result is returned // Required. Validator BasicAuthValidator // Realm is a string to define realm attribute of BasicAuthWithConfig. // Default value "Restricted". Realm string // AllowedCheckLimit set how many headers are allowed to be checked. This is useful // environments like corporate test environments with application proxies restricting // access to environment with their own auth scheme. // Defaults to 1. AllowedCheckLimit uint } // BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials. // // SECURITY WARNING: To prevent timing attacks that could allow attackers to enumerate // valid usernames or passwords, validator implementations MUST use constant-time // comparison for credential checking. 
Use crypto/subtle.ConstantTimeCompare instead // of standard string equality (==) or switch statements. // // Example of SECURE implementation: // // import "crypto/subtle" // // validator := func(c *echo.Context, username, password string) (bool, error) { // // Fetch expected credentials from database/config // expectedUser := "admin" // expectedPass := "secretpassword" // // // Use constant-time comparison to prevent timing attacks // userMatch := subtle.ConstantTimeCompare([]byte(username), []byte(expectedUser)) == 1 // passMatch := subtle.ConstantTimeCompare([]byte(password), []byte(expectedPass)) == 1 // // if userMatch && passMatch { // return true, nil // } // return false, nil // } // // Example of INSECURE implementation (DO NOT USE): // // // VULNERABLE TO TIMING ATTACKS - DO NOT USE // validator := func(c *echo.Context, username, password string) (bool, error) { // if username == "admin" && password == "secret" { // Timing leak! // return true, nil // } // return false, nil // } type BasicAuthValidator func(c *echo.Context, user string, password string) (bool, error) const ( basic = "basic" defaultRealm = "Restricted" ) // BasicAuth returns an BasicAuth middleware. // // For valid credentials it calls the next handler. // For missing or invalid credentials, it sends "401 - Unauthorized" response. func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { return BasicAuthWithConfig(BasicAuthConfig{Validator: fn}) } // BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config. 
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } // ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Validator == nil { return nil, errors.New("echo basic-auth middleware requires a validator function") } if config.Skipper == nil { config.Skipper = DefaultSkipper } realm := defaultRealm if config.Realm != "" { realm = config.Realm } realm = strconv.Quote(realm) limit := cmp.Or(config.AllowedCheckLimit, 1) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } var lastError error l := len(basic) i := uint(0) for _, auth := range c.Request().Header[echo.HeaderAuthorization] { if i >= limit { break } if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) { continue } i++ // Invalid base64 shouldn't be treated as error // instead should be treated as invalid client input b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:]) if errDecode != nil { lastError = echo.ErrBadRequest.Wrap(errDecode) continue } idx := bytes.IndexByte(b, ':') if idx >= 0 { valid, errValidate := config.Validator(c, string(b[:idx]), string(b[idx+1:])) if errValidate != nil { lastError = errValidate } else if valid { return next(c) } } } if lastError != nil { return lastError } // Need to return `401` for browsers to pop-up login box. 
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm) return echo.ErrUnauthorized } }, nil } ================================================ FILE: middleware/basic_auth_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "crypto/subtle" "encoding/base64" "errors" "net/http" "net/http/httptest" "strings" "testing" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestBasicAuth(t *testing.T) { validatorFunc := func(c *echo.Context, u, p string) (bool, error) { // Use constant-time comparison to prevent timing attacks userMatch := subtle.ConstantTimeCompare([]byte(u), []byte("joe")) == 1 passMatch := subtle.ConstantTimeCompare([]byte(p), []byte("secret")) == 1 if userMatch && passMatch { return true, nil } // Special case for testing error handling if u == "error" { return false, errors.New(p) } return false, nil } defaultConfig := BasicAuthConfig{Validator: validatorFunc} var testCases = []struct { name string givenConfig BasicAuthConfig whenAuth []string expectHeader string expectErr string }{ { name: "ok", givenConfig: defaultConfig, whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, }, { name: "ok, multiple", givenConfig: BasicAuthConfig{Validator: validatorFunc, AllowedCheckLimit: 2}, whenAuth: []string{ "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), basic + " NOT_BASE64", basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), }, }, { name: "nok, multiple, valid out of limit", givenConfig: BasicAuthConfig{Validator: validatorFunc, AllowedCheckLimit: 1}, whenAuth: []string{ "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid_password")), // limit only check first and should not check auth below basic + " " + 
base64.StdEncoding.EncodeToString([]byte("joe:secret")), }, expectHeader: basic + ` realm="Restricted"`, expectErr: "Unauthorized", }, { name: "nok, invalid Authorization header", givenConfig: defaultConfig, whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, expectHeader: basic + ` realm="Restricted"`, expectErr: "Unauthorized", }, { name: "nok, not base64 Authorization header", givenConfig: defaultConfig, whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"}, expectErr: "code=400, message=Bad Request, err=illegal base64 data at input byte 3", }, { name: "nok, missing Authorization header", givenConfig: defaultConfig, expectHeader: basic + ` realm="Restricted"`, expectErr: "Unauthorized", }, { name: "ok, realm", givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, }, { name: "ok, realm, case-insensitive header scheme", givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, }, { name: "nok, realm, invalid Authorization header", givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, expectHeader: basic + ` realm="someRealm"`, expectErr: "Unauthorized", }, { name: "nok, validator func returns an error", givenConfig: defaultConfig, whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))}, expectErr: "my_error", }, { name: "ok, skipped", givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c *echo.Context) bool { return true }}, whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, }, } for _, tc := range testCases { t.Run(tc.name, 
func(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) res := httptest.NewRecorder() c := e.NewContext(req, res) config := tc.givenConfig mw, err := config.ToMiddleware() assert.NoError(t, err) h := mw(func(c *echo.Context) error { return c.String(http.StatusTeapot, "test") }) if len(tc.whenAuth) != 0 { for _, a := range tc.whenAuth { req.Header.Add(echo.HeaderAuthorization, a) } } err = h(c) if tc.expectErr != "" { assert.Equal(t, http.StatusOK, res.Code) assert.EqualError(t, err, tc.expectErr) } else { assert.Equal(t, http.StatusTeapot, res.Code) assert.NoError(t, err) } if tc.expectHeader != "" { assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate)) } }) } } func TestBasicAuth_panic(t *testing.T) { assert.Panics(t, func() { mw := BasicAuth(nil) assert.NotNil(t, mw) }) mw := BasicAuth(func(c *echo.Context, user string, password string) (bool, error) { return true, nil }) assert.NotNil(t, mw) } func TestBasicAuthWithConfig_panic(t *testing.T) { assert.Panics(t, func() { mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil}) assert.NotNil(t, mw) }) mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c *echo.Context, user string, password string) (bool, error) { return true, nil }}) assert.NotNil(t, mw) } func TestBasicAuthRealm(t *testing.T) { e := echo.New() mockValidator := func(c *echo.Context, u, p string) (bool, error) { return false, nil // Always fail to trigger WWW-Authenticate header } tests := []struct { name string realm string expectedAuth string }{ { name: "Default realm", realm: "Restricted", expectedAuth: `basic realm="Restricted"`, }, { name: "Custom realm", realm: "My API", expectedAuth: `basic realm="My API"`, }, { name: "Realm with special characters", realm: `Realm with "quotes" and \backslashes`, expectedAuth: `basic realm="Realm with \"quotes\" and \\backslashes"`, }, { name: "Empty realm (falls back to default)", realm: "", expectedAuth: `basic realm="Restricted"`, }, { 
name: "Realm with unicode", realm: "测试领域", expectedAuth: `basic realm="测试领域"`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) res := httptest.NewRecorder() c := e.NewContext(req, res) h := BasicAuthWithConfig(BasicAuthConfig{ Validator: mockValidator, Realm: tt.realm, })(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) err := h(c) assert.Equal(t, echo.ErrUnauthorized, err) assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate)) }) } } ================================================ FILE: middleware/body_dump.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "bufio" "bytes" "errors" "io" "net" "net/http" "sync" "github.com/labstack/echo/v5" ) // BodyDumpConfig defines the config for BodyDump middleware. type BodyDumpConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // Handler receives request, response payloads and handler error if there are any. // Required. Handler BodyDumpHandler // MaxRequestBytes limits how much of the request body to dump. // If the request body exceeds this limit, only the first MaxRequestBytes // are dumped. The handler callback receives truncated data. // Default: 5 * MB (5,242,880 bytes) // Set to -1 to disable limits (not recommended in production). MaxRequestBytes int64 // MaxResponseBytes limits how much of the response body to dump. // If the response body exceeds this limit, only the first MaxResponseBytes // are dumped. The handler callback receives truncated data. // Default: 5 * MB (5,242,880 bytes) // Set to -1 to disable limits (not recommended in production). MaxResponseBytes int64 } // BodyDumpHandler receives the request and response payload. 
type BodyDumpHandler func(c *echo.Context, reqBody []byte, resBody []byte, err error) type bodyDumpResponseWriter struct { io.Writer http.ResponseWriter } // BodyDump returns a BodyDump middleware. // // BodyDump middleware captures the request and response payload and calls the // registered handler. // // SECURITY: By default, this limits dumped bodies to 5MB to prevent memory exhaustion // attacks. To customize limits, use BodyDumpWithConfig. To disable limits (not recommended // in production), explicitly set MaxRequestBytes and MaxResponseBytes to -1. func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc { return BodyDumpWithConfig(BodyDumpConfig{Handler: handler}) } // BodyDumpWithConfig returns a BodyDump middleware with config. // See: `BodyDump()`. // // SECURITY: If MaxRequestBytes and MaxResponseBytes are not set (zero values), they default // to 5MB each to prevent DoS attacks via large payloads. Set them explicitly to -1 to disable // limits if needed for your use case. 
func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } // ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Handler == nil { return nil, errors.New("echo body-dump middleware requires a handler function") } if config.Skipper == nil { config.Skipper = DefaultSkipper } if config.MaxRequestBytes == 0 { config.MaxRequestBytes = 5 * MB } if config.MaxResponseBytes == 0 { config.MaxResponseBytes = 5 * MB } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } reqBuf := bodyDumpBufferPool.Get().(*bytes.Buffer) reqBuf.Reset() defer bodyDumpBufferPool.Put(reqBuf) var bodyReader io.Reader = c.Request().Body if config.MaxRequestBytes > 0 { bodyReader = io.LimitReader(c.Request().Body, config.MaxRequestBytes) } _, readErr := io.Copy(reqBuf, bodyReader) if readErr != nil && readErr != io.EOF { return readErr } if config.MaxRequestBytes > 0 { // Drain any remaining body data to prevent connection issues _, _ = io.Copy(io.Discard, c.Request().Body) _ = c.Request().Body.Close() } reqBody := make([]byte, reqBuf.Len()) copy(reqBody, reqBuf.Bytes()) c.Request().Body = io.NopCloser(bytes.NewReader(reqBody)) // response part resBuf := bodyDumpBufferPool.Get().(*bytes.Buffer) resBuf.Reset() defer bodyDumpBufferPool.Put(resBuf) var respWriter io.Writer if config.MaxResponseBytes > 0 { respWriter = &limitedWriter{ response: c.Response(), dumpBuf: resBuf, limit: config.MaxResponseBytes, } } else { respWriter = io.MultiWriter(c.Response(), resBuf) } writer := &bodyDumpResponseWriter{ Writer: respWriter, ResponseWriter: c.Response(), } c.SetResponse(writer) err := next(c) // Callback config.Handler(c, reqBody, resBuf.Bytes(), err) return err } }, nil } func (w *bodyDumpResponseWriter) WriteHeader(code int) { 
w.ResponseWriter.WriteHeader(code) } func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) { return w.Writer.Write(b) } func (w *bodyDumpResponseWriter) Flush() { err := http.NewResponseController(w.ResponseWriter).Flush() if err != nil && errors.Is(err, http.ErrNotSupported) { panic(errors.New("response writer flushing is not supported")) } } func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return http.NewResponseController(w.ResponseWriter).Hijack() } func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter { return w.ResponseWriter } var bodyDumpBufferPool = sync.Pool{ New: func() any { return new(bytes.Buffer) }, } type limitedWriter struct { response http.ResponseWriter dumpBuf *bytes.Buffer dumped int64 limit int64 } func (w *limitedWriter) Write(b []byte) (n int, err error) { // Always write full data to actual response (don't truncate client response) n, err = w.response.Write(b) if err != nil { return n, err } // Write to dump buffer only up to limit if w.dumped < w.limit { remaining := w.limit - w.dumped toDump := int64(n) if toDump > remaining { toDump = remaining } w.dumpBuf.Write(b[:toDump]) w.dumped += toDump } return n, nil } ================================================ FILE: middleware/body_dump_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "errors" "io" "net/http" "net/http/httptest" "strings" "testing" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestBodyDump(t *testing.T) { e := echo.New() hw := "Hello, World!" 
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c *echo.Context) error { body, err := io.ReadAll(c.Request().Body) if err != nil { return err } return c.String(http.StatusOK, string(body)) } requestBody := "" responseBody := "" mw, err := BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { requestBody = string(reqBody) responseBody = string(resBody) }}.ToMiddleware() assert.NoError(t, err) if assert.NoError(t, mw(h)(c)) { assert.Equal(t, requestBody, hw) assert.Equal(t, responseBody, hw) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, hw, rec.Body.String()) } } func TestBodyDump_skipper(t *testing.T) { e := echo.New() isCalled := false mw, err := BodyDumpConfig{ Skipper: func(c *echo.Context) bool { return true }, Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { isCalled = true }, }.ToMiddleware() assert.NoError(t, err) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}")) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c *echo.Context) error { return errors.New("some error") } err = mw(h)(c) assert.EqualError(t, err, "some error") assert.False(t, isCalled) } func TestBodyDump_fails(t *testing.T) { e := echo.New() hw := "Hello, World!" 
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c *echo.Context) error { return errors.New("some error") } mw, err := BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}}.ToMiddleware() assert.NoError(t, err) err = mw(h)(c) assert.EqualError(t, err, "some error") assert.Equal(t, http.StatusOK, rec.Code) } func TestBodyDumpWithConfig_panic(t *testing.T) { assert.Panics(t, func() { mw := BodyDumpWithConfig(BodyDumpConfig{ Skipper: nil, Handler: nil, }) assert.NotNil(t, mw) }) assert.NotPanics(t, func() { mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}}) assert.NotNil(t, mw) }) } func TestBodyDump_panic(t *testing.T) { assert.Panics(t, func() { mw := BodyDump(nil) assert.NotNil(t, mw) }) assert.NotPanics(t, func() { BodyDump(func(c *echo.Context, reqBody, resBody []byte, err error) {}) }) } func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) { bdrw := bodyDumpResponseWriter{ ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush } assert.PanicsWithError(t, "response writer flushing is not supported", func() { bdrw.Flush() }) } func TestBodyDumpResponseWriter_CanFlush(t *testing.T) { trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}} bdrw := bodyDumpResponseWriter{ ResponseWriter: &trwu, } bdrw.Flush() assert.Equal(t, 1, trwu.unwrapCalled) } func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) { trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()} bdrw := bodyDumpResponseWriter{ ResponseWriter: trwu, } result := bdrw.Unwrap() assert.Equal(t, trwu, result) } func TestBodyDumpResponseWriter_CanHijack(t *testing.T) { trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}} bdrw := 
bodyDumpResponseWriter{ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping } _, _, err := bdrw.Hijack() assert.EqualError(t, err, "can hijack") } func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) { trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()} bdrw := bodyDumpResponseWriter{ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping } _, _, err := bdrw.Hijack() assert.EqualError(t, err, "feature not supported") } func TestBodyDump_ReadError(t *testing.T) { e := echo.New() // Create a reader that fails during read failingReader := &failingReadCloser{ data: []byte("partial data"), failAt: 7, // Fail after 7 bytes failWith: errors.New("connection reset"), } req := httptest.NewRequest(http.MethodPost, "/", failingReader) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c *echo.Context) error { // This handler should not be reached if body read fails body, _ := io.ReadAll(c.Request().Body) return c.String(http.StatusOK, string(body)) } requestBodyReceived := "" mw := BodyDump(func(c *echo.Context, reqBody, resBody []byte, err error) { requestBodyReceived = string(reqBody) }) err := mw(h)(c) // Verify error is propagated assert.Error(t, err) assert.Contains(t, err.Error(), "connection reset") // Verify handler was not executed (callback wouldn't have received data) assert.Empty(t, requestBodyReceived) } // failingReadCloser is a helper type for testing read errors type failingReadCloser struct { data []byte pos int failAt int failWith error } func (f *failingReadCloser) Read(p []byte) (n int, err error) { if f.pos >= f.failAt { return 0, f.failWith } n = copy(p, f.data[f.pos:]) f.pos += n if f.pos >= f.failAt { return n, f.failWith } return n, nil } func (f *failingReadCloser) Close() error { return nil } func TestBodyDump_RequestWithinLimit(t *testing.T) { e := echo.New() requestData := "Hello, World!" 
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c *echo.Context) error { body, _ := io.ReadAll(c.Request().Body) return c.String(http.StatusOK, string(body)) } requestBodyDumped := "" mw, err := BodyDumpConfig{ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { requestBodyDumped = string(reqBody) }, MaxRequestBytes: 1 * MB, // 1MB limit MaxResponseBytes: 1 * MB, }.ToMiddleware() assert.NoError(t, err) err = mw(h)(c) assert.NoError(t, err) assert.Equal(t, requestData, requestBodyDumped, "Small request should be fully dumped") assert.Equal(t, requestData, rec.Body.String(), "Handler should receive full request") } func TestBodyDump_RequestExceedsLimit(t *testing.T) { e := echo.New() // Create 2KB of data but limit to 1KB largeData := strings.Repeat("A", 2*1024) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeData)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c *echo.Context) error { body, _ := io.ReadAll(c.Request().Body) return c.String(http.StatusOK, string(body)) } requestBodyDumped := "" limit := int64(1024) // 1KB limit mw, err := BodyDumpConfig{ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { requestBodyDumped = string(reqBody) }, MaxRequestBytes: limit, MaxResponseBytes: 1 * MB, }.ToMiddleware() assert.NoError(t, err) err = mw(h)(c) assert.NoError(t, err) assert.Equal(t, int(limit), len(requestBodyDumped), "Dumped request should be truncated to limit") assert.Equal(t, strings.Repeat("A", 1024), requestBodyDumped, "Dumped data should match first N bytes") // Handler should receive truncated data (what was dumped) assert.Equal(t, strings.Repeat("A", 1024), rec.Body.String()) } func TestBodyDump_RequestAtExactLimit(t *testing.T) { e := echo.New() exactData := strings.Repeat("B", 1024) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(exactData)) rec := 
httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c *echo.Context) error { body, _ := io.ReadAll(c.Request().Body) return c.String(http.StatusOK, string(body)) } requestBodyDumped := "" limit := int64(1024) mw, err := BodyDumpConfig{ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { requestBodyDumped = string(reqBody) }, MaxRequestBytes: limit, MaxResponseBytes: 1 * MB, }.ToMiddleware() assert.NoError(t, err) err = mw(h)(c) assert.NoError(t, err) assert.Equal(t, int(limit), len(requestBodyDumped), "Exact limit should dump full data") assert.Equal(t, exactData, requestBodyDumped) } func TestBodyDump_ResponseWithinLimit(t *testing.T) { e := echo.New() responseData := "Response data" req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c *echo.Context) error { return c.String(http.StatusOK, responseData) } responseBodyDumped := "" mw, err := BodyDumpConfig{ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { responseBodyDumped = string(resBody) }, MaxRequestBytes: 1 * MB, MaxResponseBytes: 1 * MB, }.ToMiddleware() assert.NoError(t, err) err = mw(h)(c) assert.NoError(t, err) assert.Equal(t, responseData, responseBodyDumped, "Small response should be fully dumped") assert.Equal(t, responseData, rec.Body.String(), "Client should receive full response") } func TestBodyDump_ResponseExceedsLimit(t *testing.T) { e := echo.New() largeResponse := strings.Repeat("X", 2*1024) // 2KB req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c *echo.Context) error { return c.String(http.StatusOK, largeResponse) } responseBodyDumped := "" limit := int64(1024) // 1KB limit mw, err := BodyDumpConfig{ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { responseBodyDumped = string(resBody) }, MaxRequestBytes: 1 * MB, MaxResponseBytes: limit, }.ToMiddleware() assert.NoError(t, err) err = 
mw(h)(c) assert.NoError(t, err) // Dump should be truncated assert.Equal(t, int(limit), len(responseBodyDumped), "Dumped response should be truncated to limit") assert.Equal(t, strings.Repeat("X", 1024), responseBodyDumped) // Client should still receive full response! assert.Equal(t, largeResponse, rec.Body.String(), "Client must receive full response despite dump truncation") } func TestBodyDump_ClientGetsFullResponse(t *testing.T) { e := echo.New() // This is critical - even when dump is limited, client gets everything largeResponse := strings.Repeat("DATA", 500) // 2KB req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c *echo.Context) error { // Write response in chunks to test incremental writes for i := 0; i < 4; i++ { c.Response().Write([]byte(strings.Repeat("DATA", 125))) } return nil } responseBodyDumped := "" mw, err := BodyDumpConfig{ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { responseBodyDumped = string(resBody) }, MaxRequestBytes: 1 * MB, MaxResponseBytes: 512, // Very small limit }.ToMiddleware() assert.NoError(t, err) err = mw(h)(c) assert.NoError(t, err) assert.Equal(t, 512, len(responseBodyDumped), "Dump should be limited") assert.Equal(t, largeResponse, rec.Body.String(), "Client must get full response") } func TestBodyDump_BothLimitsSimultaneous(t *testing.T) { e := echo.New() largeRequest := strings.Repeat("Q", 2*1024) largeResponse := strings.Repeat("R", 2*1024) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeRequest)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c *echo.Context) error { io.ReadAll(c.Request().Body) // Consume request return c.String(http.StatusOK, largeResponse) } requestBodyDumped := "" responseBodyDumped := "" limit := int64(1024) mw, err := BodyDumpConfig{ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { requestBodyDumped = string(reqBody) responseBodyDumped = 
string(resBody) }, MaxRequestBytes: limit, MaxResponseBytes: limit, }.ToMiddleware() assert.NoError(t, err) err = mw(h)(c) assert.NoError(t, err) assert.Equal(t, int(limit), len(requestBodyDumped), "Request dump should be limited") assert.Equal(t, int(limit), len(responseBodyDumped), "Response dump should be limited") } func TestBodyDump_DefaultConfig(t *testing.T) { e := echo.New() smallData := "test" req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(smallData)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c *echo.Context) error { body, _ := io.ReadAll(c.Request().Body) return c.String(http.StatusOK, string(body)) } requestBodyDumped := "" // Use default config which should have 1MB limits config := BodyDumpConfig{} config.Handler = func(c *echo.Context, reqBody, resBody []byte, err error) { requestBodyDumped = string(reqBody) } mw, err := config.ToMiddleware() assert.NoError(t, err) err = mw(h)(c) assert.NoError(t, err) assert.Equal(t, smallData, requestBodyDumped) } func TestBodyDump_LargeRequestDosPrevention(t *testing.T) { e := echo.New() // Simulate a very large request (10MB) that could cause OOM largeSize := 10 * 1024 * 1024 // 10MB largeData := strings.Repeat("M", largeSize) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeData)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c *echo.Context) error { body, _ := io.ReadAll(c.Request().Body) return c.String(http.StatusOK, string(body)) } requestBodyDumped := "" limit := int64(1 * MB) // Only dump 1MB max mw, err := BodyDumpConfig{ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { requestBodyDumped = string(reqBody) }, MaxRequestBytes: limit, MaxResponseBytes: 1 * MB, }.ToMiddleware() assert.NoError(t, err) err = mw(h)(c) assert.NoError(t, err) // Verify only 1MB was dumped, not 10MB assert.Equal(t, int(limit), len(requestBodyDumped), "Should only dump up to limit, preventing DoS") assert.Less(t, 
len(requestBodyDumped), largeSize, "Dump should be much smaller than full request") } func TestBodyDump_LargeResponseDosPrevention(t *testing.T) { e := echo.New() // Simulate a very large response (10MB) largeSize := 10 * 1024 * 1024 // 10MB largeResponse := strings.Repeat("R", largeSize) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c *echo.Context) error { return c.String(http.StatusOK, largeResponse) } responseBodyDumped := "" limit := int64(1 * MB) // Only dump 1MB max mw, err := BodyDumpConfig{ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { responseBodyDumped = string(resBody) }, MaxRequestBytes: 1 * MB, MaxResponseBytes: limit, }.ToMiddleware() assert.NoError(t, err) err = mw(h)(c) assert.NoError(t, err) // Verify only 1MB was dumped, not 10MB assert.Equal(t, int(limit), len(responseBodyDumped), "Should only dump up to limit, preventing DoS") assert.Less(t, len(responseBodyDumped), largeSize, "Dump should be much smaller than full response") // Client still gets full response assert.Equal(t, largeSize, rec.Body.Len(), "Client must receive full response") } func BenchmarkBodyDump_WithLimit(b *testing.B) { e := echo.New() requestData := strings.Repeat("data", 256) // 1KB responseData := strings.Repeat("resp", 256) // 1KB h := func(c *echo.Context) error { io.ReadAll(c.Request().Body) return c.String(http.StatusOK, responseData) } mw, _ := BodyDumpConfig{ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { // Simulate logging _ = len(reqBody) + len(resBody) }, MaxRequestBytes: 1 * MB, MaxResponseBytes: 1 * MB, }.ToMiddleware() b.ResetTimer() for i := 0; i < b.N; i++ { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) mw(h)(c) } } func BenchmarkBodyDump_BufferPooling(b *testing.B) { e := echo.New() requestData := strings.Repeat("x", 1024) responseData := 
"response" h := func(c *echo.Context) error { io.ReadAll(c.Request().Body) return c.String(http.StatusOK, responseData) } mw, _ := BodyDumpConfig{ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}, MaxRequestBytes: 1 * MB, MaxResponseBytes: 1 * MB, }.ToMiddleware() b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) mw(h)(c) } } ================================================ FILE: middleware/body_limit.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "io" "net/http" "sync" "github.com/labstack/echo/v5" ) // BodyLimitConfig defines the config for BodyLimitWithConfig middleware. type BodyLimitConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // LimitBytes is maximum allowed size in bytes for a request body LimitBytes int64 } type limitedReader struct { BodyLimitConfig reader io.ReadCloser read int64 } // BodyLimit returns a BodyLimit middleware. // // BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it // sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request // header and actual content read, which makes it super secure. func BodyLimit(limitBytes int64) echo.MiddlewareFunc { return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes}) } // BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for // a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response. // The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which // makes it super secure. 
func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
	return toMiddlewareOrPanic(config)
}

// ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration
//
// A request is rejected with echo.ErrStatusRequestEntityTooLarge either up front, when its
// declared Content-Length exceeds LimitBytes, or while the body is being read, as soon as
// more than LimitBytes bytes have been read. LimitBytes is used as-is (no default is applied):
// with LimitBytes == 0 any request that declares or sends a non-empty body is rejected.
func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	if config.Skipper == nil {
		config.Skipper = DefaultSkipper
	}
	// limitedReader instances are recycled across requests to avoid a per-request allocation.
	pool := sync.Pool{
		New: func() any {
			return &limitedReader{BodyLimitConfig: config}
		},
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}
			req := c.Request()

			// Based on content length
			if req.ContentLength > config.LimitBytes {
				return echo.ErrStatusRequestEntityTooLarge
			}

			// Based on content read
			r, ok := pool.Get().(*limitedReader)
			if !ok {
				return echo.NewHTTPError(http.StatusInternalServerError, "invalid pool object")
			}
			r.Reset(req.Body)
			// NOTE(review): r is returned to the pool when this handler returns, yet req.Body keeps
			// pointing at it. Code that reads req.Body after next(c) has returned (e.g. a goroutine
			// started by the handler) may observe a reader already reused by another request.
			defer pool.Put(r)
			req.Body = r

			return next(c)
		}
	}, nil
}

// Read reads from the wrapped body and adds the bytes read to the running total. Once the
// total exceeds LimitBytes it returns echo.ErrStatusRequestEntityTooLarge (replacing any
// error from the underlying read) together with n, so the bytes of that read are still in b.
func (r *limitedReader) Read(b []byte) (n int, err error) {
	n, err = r.reader.Read(b)
	r.read += int64(n)
	if r.read > r.LimitBytes {
		return n, echo.ErrStatusRequestEntityTooLarge
	}
	return
}

// Close closes the wrapped body.
func (r *limitedReader) Close() error {
	return r.reader.Close()
}

// Reset points the reader at a new body and clears the byte counter so the instance can be reused.
func (r *limitedReader) Reset(reader io.ReadCloser) {
	r.reader = reader
	r.read = 0
}

================================================ FILE: middleware/body_limit_test.go ================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors

package middleware

import (
	"bytes"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/labstack/echo/v5"
	"github.com/stretchr/testify/assert"
)

func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
	e := echo.New()
	hw := []byte("Hello, World!")
	req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	h := func(c *echo.Context) error {
		body, err := io.ReadAll(c.Request().Body)
		if err != nil {
			return err
		}
		return c.String(http.StatusOK, string(body))
	}

	// Based on content length (within limit)
	mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
	assert.NoError(t, err)

	err = mw(h)(c)
	if assert.NoError(t, err) {
		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, hw, rec.Body.Bytes())
	}

	// Based on content length (overlimit): req still declares ContentLength = len(hw) > 2,
	// so it is rejected before any body is read.
	mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
	assert.NoError(t, err)
	he := mw(h)(c).(echo.HTTPStatusCoder)
	assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())

	// Based on content read (within limit)
	req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
	req.ContentLength = -1
	rec = httptest.NewRecorder()
	c = e.NewContext(req, rec)

	mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
	assert.NoError(t, err)
	err = mw(h)(c)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, "Hello, World!", rec.Body.String())

	// Based on content read (overlimit)
	req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
	req.ContentLength = -1
	rec = httptest.NewRecorder()
	c = e.NewContext(req, rec)
	mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
	assert.NoError(t, err)
	he = mw(h)(c).(echo.HTTPStatusCoder)
	assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
}

func TestBodyLimitReader(t *testing.T) {
	hw := []byte("Hello, World!")

	config := BodyLimitConfig{
		Skipper:    DefaultSkipper,
		LimitBytes: 2,
	}
	reader := &limitedReader{
		BodyLimitConfig: config,
		reader:          io.NopCloser(bytes.NewReader(hw)),
	}

	// read all should return ErrStatusRequestEntityTooLarge
	_, err := io.ReadAll(reader)
	he := err.(echo.HTTPStatusCoder)
	assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())

	// reset reader and read two bytes must succeed
	bt := make([]byte, 2)
	reader.Reset(io.NopCloser(bytes.NewReader(hw)))
	n, err := reader.Read(bt)
	assert.Equal(t, 2, n)
	assert.Equal(t, nil, err)
}

func TestBodyLimit_skipper(t *testing.T) {
	e := echo.New()
	h := func(c *echo.Context) error {
		body, err := io.ReadAll(c.Request().Body)
		if err != nil {
			return err
		}
		return c.String(http.StatusOK, string(body))
	}
	mw, err := BodyLimitConfig{
		Skipper: func(c *echo.Context) bool {
			return true
		},
		LimitBytes: 2,
	}.ToMiddleware()
	assert.NoError(t, err)

	hw := []byte("Hello, World!")
	req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	err = mw(h)(c)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, hw, rec.Body.Bytes())
}

func TestBodyLimitWithConfig(t *testing.T) {
	e := echo.New()
	hw := []byte("Hello, World!")
	req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	h := func(c *echo.Context) error {
		body, err := io.ReadAll(c.Request().Body)
		if err != nil {
			return err
		}
		return c.String(http.StatusOK, string(body))
	}

	mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB})

	err := mw(h)(c)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, hw, rec.Body.Bytes())
}

func TestBodyLimit(t *testing.T) {
	e := echo.New()
	hw := []byte("Hello, World!")
	req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	h := func(c *echo.Context) error {
		body, err := io.ReadAll(c.Request().Body)
		if err != nil {
			return err
		}
		return c.String(http.StatusOK, string(body))
	}

	mw := BodyLimit(2 * MB)

	err := mw(h)(c)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, hw, rec.Body.Bytes())
}

================================================ FILE: middleware/compress.go ================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors

package middleware

import (
	"bufio"
	"bytes"
	"compress/gzip"
	"errors"
	"io"
	"net"
	"net/http"
	"strings"
	"sync"

	"github.com/labstack/echo/v5"
)

const (
	gzipScheme = "gzip"
)

// GzipConfig defines the config for Gzip middleware.
type GzipConfig struct {
	// Skipper defines a function to skip middleware.
	Skipper Skipper

	// Gzip compression level.
	// Optional. Default value -1.
	// Valid values are -2 (gzip.HuffmanOnly) through 9 (gzip.BestCompression); other values make
	// ToMiddleware return an error. 0 is treated as the default (-1), not as gzip.NoCompression.
	Level int

	// Length threshold before gzip compression is applied.
	// Optional. Default value 0. Negative values are treated as 0.
	//
	// Most of the time you will not need to change the default. Compressing
	// a short response might increase the transmitted data because of the
	// gzip format overhead. Compressing the response will also consume CPU
	// and time on the server and the client (for decompressing). Depending on
	// your use case such a threshold might be useful.
	//
	// See also:
	// https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits
	MinLength int
}

// gzipResponseWriter buffers the response body until MinLength is reached and then switches
// to writing it through the gzip stream; the status code is held back until that decision.
type gzipResponseWriter struct {
	io.Writer // gzip stream that the compressed body is written to
	http.ResponseWriter
	wroteHeader       bool          // WriteHeader was called; the real header write is deferred
	wroteBody         bool          // Write was called at least once
	minLength         int           // copy of GzipConfig.MinLength
	minLengthExceeded bool          // buffered body reached minLength, so output is compressed
	buffer            *bytes.Buffer // holds body bytes until minLength is reached
	code              int           // status code saved by WriteHeader
}

// Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
func Gzip() echo.MiddlewareFunc {
	return GzipWithConfig(GzipConfig{})
}

// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } // ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultSkipper } if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression return nil, errors.New("invalid gzip level") } if config.Level == 0 { config.Level = -1 } if config.MinLength < 0 { config.MinLength = 0 } pool := gzipCompressPool(config) bpool := bufferPool() return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } res := c.Response() res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) { i := pool.Get() w, ok := i.(*gzip.Writer) if !ok { return echo.NewHTTPError(http.StatusInternalServerError, "invalid pool object") } rw := res w.Reset(rw) buf := bpool.Get().(*bytes.Buffer) buf.Reset() grw := &gzipResponseWriter{ Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf, } c.SetResponse(grw) defer func() { // There are different reasons for cases when we have not yet written response to the client and now need to do so. // a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now. // b) body is shorter than our minimum length threshold and being buffered currently and needs to be written if !grw.wroteBody { if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { res.Header().Del(echo.HeaderContentEncoding) } if grw.wroteHeader { rw.WriteHeader(grw.code) } // We have to reset response to it's pristine state when // nothing is written to body or error is returned. // See issue #424, #407. 
c.SetResponse(rw) w.Reset(io.Discard) } else if !grw.minLengthExceeded { // Write uncompressed response c.SetResponse(rw) if grw.wroteHeader { grw.ResponseWriter.WriteHeader(grw.code) } _, _ = grw.buffer.WriteTo(rw) w.Reset(io.Discard) } _ = w.Close() bpool.Put(buf) pool.Put(w) }() } return next(c) } }, nil } func (w *gzipResponseWriter) WriteHeader(code int) { w.Header().Del(echo.HeaderContentLength) // Issue #444 w.wroteHeader = true // Delay writing of the header until we know if we'll actually compress the response w.code = code } func (w *gzipResponseWriter) Write(b []byte) (int, error) { if w.Header().Get(echo.HeaderContentType) == "" { w.Header().Set(echo.HeaderContentType, http.DetectContentType(b)) } w.wroteBody = true if !w.minLengthExceeded { n, err := w.buffer.Write(b) if w.buffer.Len() >= w.minLength { w.minLengthExceeded = true // The minimum length is exceeded, add Content-Encoding header and write the header w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 if w.wroteHeader { w.ResponseWriter.WriteHeader(w.code) } return w.Writer.Write(w.buffer.Bytes()) } return n, err } return w.Writer.Write(b) } func (w *gzipResponseWriter) Flush() { if !w.minLengthExceeded { // Enforce compression because we will not know how much more data will come w.minLengthExceeded = true w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 if w.wroteHeader { w.ResponseWriter.WriteHeader(w.code) } _, _ = w.Writer.Write(w.buffer.Bytes()) } if gw, ok := w.Writer.(*gzip.Writer); ok { gw.Flush() } _ = http.NewResponseController(w.ResponseWriter).Flush() } func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return http.NewResponseController(w.ResponseWriter).Hijack() } func (w *gzipResponseWriter) Unwrap() http.ResponseWriter { return w.ResponseWriter } func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { if p, ok := w.ResponseWriter.(http.Pusher); ok { return p.Push(target, opts) } return 
http.ErrNotSupported } func gzipCompressPool(config GzipConfig) sync.Pool { return sync.Pool{ New: func() any { w, err := gzip.NewWriterLevel(io.Discard, config.Level) if err != nil { return err } return w }, } } func bufferPool() sync.Pool { return sync.Pool{ New: func() any { b := &bytes.Buffer{} return b }, } } ================================================ FILE: middleware/compress_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "bytes" "compress/gzip" "io" "net/http" "net/http/httptest" "os" "testing" "time" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestGzip_NoAcceptEncodingHeader(t *testing.T) { // Skip if no Accept-Encoding header h := Gzip()(func(c *echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := h(c) assert.NoError(t, err) assert.Equal(t, "test", rec.Body.String()) } func TestMustGzipWithConfig_panics(t *testing.T) { assert.Panics(t, func() { GzipWithConfig(GzipConfig{Level: 999}) }) } func TestGzip_AcceptEncodingHeader(t *testing.T) { h := Gzip()(func(c *echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := h(c) assert.NoError(t, err) assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) r, err := gzip.NewReader(rec.Body) assert.NoError(t, err) buf := new(bytes.Buffer) defer r.Close() buf.ReadFrom(r) assert.Equal(t, "test", buf.String()) } func TestGzip_chunked(t *testing.T) 
{ e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() c := e.NewContext(req, rec) chunkChan := make(chan struct{}) waitChan := make(chan struct{}) h := Gzip()(func(c *echo.Context) error { rc := http.NewResponseController(c.Response()) c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Transfer-Encoding", "chunked") // Write and flush the first part of the data c.Response().Write([]byte("first\n")) rc.Flush() chunkChan <- struct{}{} <-waitChan // Write and flush the second part of the data c.Response().Write([]byte("second\n")) rc.Flush() chunkChan <- struct{}{} <-waitChan // Write the final part of the data and return c.Response().Write([]byte("third")) chunkChan <- struct{}{} return nil }) go func() { err := h(c) chunkChan <- struct{}{} assert.NoError(t, err) }() <-chunkChan // wait for first write waitChan <- struct{}{} <-chunkChan // wait for second write waitChan <- struct{}{} <-chunkChan // wait for final write in handler <-chunkChan // wait for return from handler time.Sleep(5 * time.Millisecond) // to have time for flushing assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) r, err := gzip.NewReader(rec.Body) assert.NoError(t, err) buf := new(bytes.Buffer) buf.ReadFrom(r) assert.Equal(t, "first\nsecond\nthird", buf.String()) } func TestGzip_NoContent(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := Gzip()(func(c *echo.Context) error { return c.NoContent(http.StatusNoContent) }) if assert.NoError(t, h(c)) { assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) assert.Equal(t, 0, len(rec.Body.Bytes())) } } func TestGzip_Empty(t *testing.T) { e := echo.New() req := 
httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := Gzip()(func(c *echo.Context) error { return c.String(http.StatusOK, "") }) if assert.NoError(t, h(c)) { assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get(echo.HeaderContentType)) r, err := gzip.NewReader(rec.Body) if assert.NoError(t, err) { var buf bytes.Buffer buf.ReadFrom(r) assert.Equal(t, "", buf.String()) } } } func TestGzip_ErrorReturned(t *testing.T) { e := echo.New() e.Use(Gzip()) e.GET("/", func(c *echo.Context) error { return echo.ErrNotFound }) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusNotFound, rec.Code) assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) } func TestGzipWithConfig_invalidLevel(t *testing.T) { mw, err := GzipConfig{Level: 12}.ToMiddleware() assert.EqualError(t, err, "invalid gzip level") assert.Nil(t, mw) } // Issue #806 func TestGzipWithStatic(t *testing.T) { e := echo.New() e.Filesystem = os.DirFS("../") e.Use(Gzip()) e.Static("/test", "_fixture/images") req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) // Data is written out in chunks when Content-Length == "", so only // validate the content length if it's not set. 
if cl := rec.Header().Get("Content-Length"); cl != "" {
		// NOTE(review): cl is a string and rec.Body.Len() an int, so this assertion
		// can never pass when Content-Length is set — compare against strconv.Itoa(...).
		assert.Equal(t, cl, rec.Body.Len())
	}
	r, err := gzip.NewReader(rec.Body)
	if assert.NoError(t, err) {
		defer r.Close()
		want, err := os.ReadFile("../_fixture/images/walle.png")
		if assert.NoError(t, err) {
			buf := new(bytes.Buffer)
			buf.ReadFrom(r)
			assert.Equal(t, want, buf.Bytes())
		}
	}
}

// TestGzipWithMinLength checks that a body longer than MinLength is compressed.
func TestGzipWithMinLength(t *testing.T) {
	e := echo.New()
	// Minimal response length
	e.Use(GzipWithConfig(GzipConfig{MinLength: 10}))
	e.GET("/", func(c *echo.Context) error {
		c.Response().Write([]byte("foobarfoobar"))
		return nil
	})

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, req)
	assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
	r, err := gzip.NewReader(rec.Body)
	if assert.NoError(t, err) {
		buf := new(bytes.Buffer)
		defer r.Close()
		buf.ReadFrom(r)
		assert.Equal(t, "foobarfoobar", buf.String())
	}
}

// TestGzipWithMinLengthTooShort checks that a body shorter than MinLength is sent uncompressed.
func TestGzipWithMinLengthTooShort(t *testing.T) {
	e := echo.New()
	// Minimal response length
	e.Use(GzipWithConfig(GzipConfig{MinLength: 10}))
	e.GET("/", func(c *echo.Context) error {
		c.Response().Write([]byte("test"))
		return nil
	})
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, req)
	assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding))
	assert.Contains(t, rec.Body.String(), "test")
}

// TestGzipWithResponseWithoutBody checks that a body-less redirect gets no Content-Encoding.
func TestGzipWithResponseWithoutBody(t *testing.T) {
	e := echo.New()

	e.Use(Gzip())
	e.GET("/", func(c *echo.Context) error {
		return c.Redirect(http.StatusMovedPermanently, "http://localhost")
	})

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
	rec := httptest.NewRecorder()

	e.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusMovedPermanently, rec.Code)
	assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding))
}

// TestGzipWithMinLengthChunked checks that Flush forces compression even before MinLength is reached.
func TestGzipWithMinLengthChunked(t *testing.T) {
	e := echo.New()

	// Gzip chunked
	chunkBuf := make([]byte, 5)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
	rec := httptest.NewRecorder()

	var r *gzip.Reader = nil

	c := e.NewContext(req, rec)
	next := func(c *echo.Context) error {
		rc := http.NewResponseController(c.Response())
		c.Response().Header().Set("Content-Type", "text/event-stream")
		c.Response().Header().Set("Transfer-Encoding", "chunked")

		// Write and flush the first part of the data
		c.Response().Write([]byte("test\n"))
		rc.Flush()

		// Read the first part of the data
		assert.True(t, rec.Flushed)
		assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))

		var err error
		r, err = gzip.NewReader(rec.Body)
		assert.NoError(t, err)

		_, err = io.ReadFull(r, chunkBuf)
		assert.NoError(t, err)
		assert.Equal(t, "test\n", string(chunkBuf))

		// Write and flush the second part of the data
		c.Response().Write([]byte("test\n"))
		rc.Flush()

		_, err = io.ReadFull(r, chunkBuf)
		assert.NoError(t, err)
		assert.Equal(t, "test\n", string(chunkBuf))

		// Write the final part of the data and return
		c.Response().Write([]byte("test"))
		return nil
	}

	err := GzipWithConfig(GzipConfig{MinLength: 10})(next)(c)

	assert.NoError(t, err)
	assert.NotNil(t, r)

	buf := new(bytes.Buffer)

	buf.ReadFrom(r)
	assert.Equal(t, "test", buf.String())

	r.Close()
}

// TestGzipWithMinLengthNoContent checks that a 204 with MinLength set carries no encoding or body.
func TestGzipWithMinLengthNoContent(t *testing.T) {
	e := echo.New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c *echo.Context) error {
		return c.NoContent(http.StatusNoContent)
	})
	if assert.NoError(t, h(c)) {
		assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
		assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
		assert.Equal(t, 0, len(rec.Body.Bytes()))
	}
}

// TestGzipResponseWriter_CanUnwrap checks that Unwrap returns the wrapped writer.
func TestGzipResponseWriter_CanUnwrap(t *testing.T) {
	trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
	bdrw := gzipResponseWriter{
		ResponseWriter: trwu,
	}

	result := bdrw.Unwrap()
	assert.Equal(t, trwu, result)
}

// TestGzipResponseWriter_CanHijack checks that Hijack reaches a hijack-capable writer through unwrapping.
func TestGzipResponseWriter_CanHijack(t *testing.T) {
	trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
	bdrw := gzipResponseWriter{
		ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
	}

	_, _, err := bdrw.Hijack()
	assert.EqualError(t, err, "can hijack")
}

// TestGzipResponseWriter_CanNotHijack checks the error when no wrapped writer supports hijacking.
func TestGzipResponseWriter_CanNotHijack(t *testing.T) {
	trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
	bdrw := gzipResponseWriter{
		ResponseWriter: &trwu, // this RW does NOT support hijacking, not even through unwrapping
	}

	_, _, err := bdrw.Hijack()
	assert.EqualError(t, err, "feature not supported")
}

// BenchmarkGzip measures a small compressed response through the middleware.
func BenchmarkGzip(b *testing.B) {
	e := echo.New()

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)

	h := Gzip()(func(c *echo.Context) error {
		c.Response().Write([]byte("test")) // For Content-Type sniffing
		return nil
	})

	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		// Gzip
		rec := httptest.NewRecorder()
		c := e.NewContext(req, rec)
		h(c)
	}
}

================================================
FILE: middleware/context_timeout.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors

package middleware

import (
	"context"
	"errors"
	"time"

	"github.com/labstack/echo/v5"
)

// ContextTimeoutConfig defines the config for ContextTimeout middleware.
type ContextTimeoutConfig struct {
	// Skipper defines a function to skip middleware.
	Skipper Skipper

	// ErrorHandler is called with the non-nil error returned by the next handler and its
	// result is returned by the middleware. The default converts errors wrapping
	// context.DeadlineExceeded into a 503 Service Unavailable error.
ErrorHandler func(c *echo.Context, err error) error // Timeout configures a timeout for the middleware Timeout time.Duration } // ContextTimeout returns a middleware which returns error (503 Service Unavailable error) to client // when underlying method returns context.DeadlineExceeded error. func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc { return ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: timeout}) } // ContextTimeoutWithConfig returns a Timeout middleware with config. func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } // ToMiddleware converts Config to middleware. func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Timeout == 0 { return nil, errors.New("timeout must be set") } if config.Skipper == nil { config.Skipper = DefaultSkipper } if config.ErrorHandler == nil { config.ErrorHandler = func(c *echo.Context, err error) error { if err != nil && errors.Is(err, context.DeadlineExceeded) { return echo.ErrServiceUnavailable.Wrap(err) } return err } } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } timeoutContext, cancel := context.WithTimeout(c.Request().Context(), config.Timeout) defer cancel() c.SetRequest(c.Request().WithContext(timeoutContext)) if err := next(c); err != nil { return config.ErrorHandler(c, err) } return nil } }, nil } ================================================ FILE: middleware/context_timeout_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "context" "errors" "github.com/labstack/echo/v5" "net/http" "net/http/httptest" "net/url" "strings" "testing" "time" "github.com/stretchr/testify/assert" ) func TestContextTimeoutSkipper(t *testing.T) { t.Parallel() m := 
ContextTimeoutWithConfig(ContextTimeoutConfig{ Skipper: func(context *echo.Context) bool { return true }, Timeout: 10 * time.Millisecond, }) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() e := echo.New() c := e.NewContext(req, rec) err := m(func(c *echo.Context) error { if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil { return err } return errors.New("response from handler") })(c) // if not skipped we would have not returned error due context timeout logic assert.EqualError(t, err, "response from handler") } func TestContextTimeoutWithTimeout0(t *testing.T) { t.Parallel() assert.Panics(t, func() { ContextTimeout(time.Duration(0)) }) } func TestContextTimeoutErrorOutInHandler(t *testing.T) { t.Parallel() m := ContextTimeoutWithConfig(ContextTimeoutConfig{ // Timeout has to be defined or the whole flow for timeout middleware will be skipped Timeout: 10 * time.Millisecond, }) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() e := echo.New() c := e.NewContext(req, rec) rec.Code = 1 // we want to be sure that even 200 will not be sent err := m(func(c *echo.Context) error { // this error must not be written to the client response. Middlewares upstream of timeout middleware must be able // to handle returned error and this can be done only then handler has not yet committed (written status code) // the response. 
return echo.NewHTTPError(http.StatusTeapot, "err")
	})(c)

	assert.Error(t, err)
	assert.EqualError(t, err, "code=418, message=err")
	assert.Equal(t, 1, rec.Code)
	assert.Equal(t, "", rec.Body.String())
}

// TestContextTimeoutSuccessfulRequest checks that a fast handler's response passes through untouched.
func TestContextTimeoutSuccessfulRequest(t *testing.T) {
	t.Parallel()
	m := ContextTimeoutWithConfig(ContextTimeoutConfig{
		// Timeout has to be defined or the whole flow for timeout middleware will be skipped
		Timeout: 10 * time.Millisecond,
	})

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()

	e := echo.New()
	c := e.NewContext(req, rec)

	err := m(func(c *echo.Context) error {
		return c.JSON(http.StatusCreated, map[string]string{"data": "ok"})
	})(c)

	assert.NoError(t, err)
	assert.Equal(t, http.StatusCreated, rec.Code)
	assert.Equal(t, "{\"data\":\"ok\"}\n", rec.Body.String())
}

// TestContextTimeoutTestRequestClone checks that cookies, form values and query survive WithContext.
func TestContextTimeoutTestRequestClone(t *testing.T) {
	t.Parallel()
	req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode()))
	req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"})
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rec := httptest.NewRecorder()

	m := ContextTimeoutWithConfig(ContextTimeoutConfig{
		// Timeout has to be defined or the whole flow for timeout middleware will be skipped
		Timeout: 1 * time.Second,
	})

	e := echo.New()
	c := e.NewContext(req, rec)

	err := m(func(c *echo.Context) error {
		// Cookie test
		cookie, err := c.Request().Cookie("cookie")
		if assert.NoError(t, err) {
			assert.EqualValues(t, "cookie", cookie.Name)
			assert.EqualValues(t, "value", cookie.Value)
		}

		// Form values
		if assert.NoError(t, c.Request().ParseForm()) {
			assert.EqualValues(t, "value", c.Request().FormValue("form"))
		}

		// Query string
		assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0])
		return nil
	})(c)

	assert.NoError(t, err)
}

// TestContextTimeoutWithDefaultErrorMessage checks the default ErrorHandler's 503 conversion.
func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) {
	t.Parallel()

	timeout := 10 * time.Millisecond
	m := ContextTimeoutWithConfig(ContextTimeoutConfig{
		Timeout: timeout,
	})

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()

	e := echo.New()
	c := e.NewContext(req, rec)

	err := m(func(c *echo.Context) error {
		if err := sleepWithContext(c.Request().Context(), time.Duration(80*time.Millisecond)); err != nil {
			return err
		}
		return c.String(http.StatusOK, "Hello, World!")
	})(c)

	assert.Error(t, err)
	if assert.IsType(t, &echo.HTTPError{}, err) {
		assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
		assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message)
	}
}

// TestContextTimeoutCanHandleContextDeadlineOnNextHandler checks that a custom ErrorHandler receives the deadline error.
func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) {
	t.Parallel()

	timeoutErrorHandler := func(c *echo.Context, err error) error {
		if err != nil {
			if errors.Is(err, context.DeadlineExceeded) {
				return &echo.HTTPError{
					Code:    http.StatusServiceUnavailable,
					Message: "Timeout! change me",
				}
			}
			return err
		}
		return nil
	}

	timeout := 50 * time.Millisecond
	m := ContextTimeoutWithConfig(ContextTimeoutConfig{
		Timeout:      timeout,
		ErrorHandler: timeoutErrorHandler,
	})

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()

	e := echo.New()
	c := e.NewContext(req, rec)

	err := m(func(c *echo.Context) error {
		// NOTE: Very short periods are not reliable for tests due to Go routine scheduling and the unpredictable order
		// for 1) request and 2) time goroutine. For most OS this works as expected, but MacOS seems most flaky.

		if err := sleepWithContext(c.Request().Context(), 100*time.Millisecond); err != nil {
			return err
		}

		// The Request Context should have a Deadline set by the ContextTimeout middleware
		if _, ok := c.Request().Context().Deadline(); !ok {
			assert.Fail(t, "No timeout set on Request Context")
		}
		return c.String(http.StatusOK, "Hello, World!")
	})(c)
	assert.IsType(t, &echo.HTTPError{}, err)
	assert.Error(t, err)
	assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
	assert.Equal(t, "Timeout! change me", err.(*echo.HTTPError).Message)
}

// sleepWithContext blocks for d or until ctx is done, whichever happens first.
// It returns nil after a full sleep, and context.DeadlineExceeded whenever ctx finishes
// first (regardless of ctx.Err()).
func sleepWithContext(ctx context.Context, d time.Duration) error {
	timer := time.NewTimer(d)

	defer func() {
		_ = timer.Stop()
	}()

	select {
	case <-ctx.Done():
		return context.DeadlineExceeded
	case <-timer.C:
		return nil
	}
}

================================================
FILE: middleware/cors.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors

package middleware

import (
	"errors"
	"fmt"
	"net/http"
	"strconv"
	"strings"

	"github.com/labstack/echo/v5"
)

// CORSConfig defines the config for CORS middleware.
type CORSConfig struct {
	// Skipper defines a function to skip middleware.
	Skipper Skipper

	// AllowOrigins determines the value of the Access-Control-Allow-Origin
	// response header. This header defines a list of origins that may access the
	// resource.
	//
	// An origin consists of the following parts: `scheme + "://" + host + optional ":" + port`
	// Wildcard can be used, but has to be set explicitly []string{"*"}
	// Example: `https://example.com`, `http://example.com:8080`, `*`
	//
	// Security: use extreme caution when handling the origin, and carefully
	// validate any logic. Remember that attackers may register hostile domain names.
	// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
	// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
	//
	// Mandatory (unless UnsafeAllowOriginFunc is set).
	AllowOrigins []string

	// UnsafeAllowOriginFunc is an optional custom function to validate the origin. It takes the
	// origin as an argument and returns
	// - string, allowed origin
	// - bool, true if allowed or false otherwise.
	// - error, if an error is returned, it is returned immediately by the handler.
	// If this option is set, AllowOrigins is ignored.
	//
	// Security: use extreme caution when handling the origin, and carefully
	// validate any logic. Remember that attackers may register hostile (sub)domain names.
	// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
	//
	// Sub-domain checks example:
	//		UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) {
	//			if strings.HasSuffix(origin, ".example.com") {
	//				return origin, true, nil
	//			}
	//			return "", false, nil
	//		},
	//
	// Optional.
	UnsafeAllowOriginFunc func(c *echo.Context, origin string) (allowedOrigin string, allowed bool, err error)

	// AllowMethods determines the value of the Access-Control-Allow-Methods
	// response header. This header specifies the list of methods allowed when
	// accessing the resource. This is used in response to a preflight request.
	//
	// Optional. Defaults to GET, HEAD, PUT, PATCH, POST, DELETE.
	// If `allowMethods` is left empty, this middleware will fill for preflight
	// request `Access-Control-Allow-Methods` header value
	// from `Allow` header that echo.Router set into context.
	//
	// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
	AllowMethods []string

	// AllowHeaders determines the value of the Access-Control-Allow-Headers
	// response header. This header is used in response to a preflight request to
	// indicate which HTTP headers can be used when making the actual request.
	//
	// Optional. Defaults to an empty list.
	//
	// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
	AllowHeaders []string

	// AllowCredentials determines the value of the
	// Access-Control-Allow-Credentials response header. This header indicates
	// whether or not the response to the request can be exposed when the
	// credentials mode (Request.credentials) is true. When used as part of a
	// response to a preflight request, this indicates whether or not the actual
	// request can be made using credentials. See also
	// [MDN: Access-Control-Allow-Credentials].
	//
	// Optional. Default value false, in which case the header is not set.
	//
	// Security: `AllowCredentials = true` together with `AllowOrigins = *` is rejected
	// by ToMiddleware with an error; use UnsafeAllowOriginFunc if you really need it.
	// See "Exploiting CORS misconfigurations for Bitcoins and bounties",
	// https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
	//
	// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
	AllowCredentials bool

	// ExposeHeaders determines the value of Access-Control-Expose-Headers, which
	// defines a list of headers that clients are allowed to access.
	//
	// Optional. Default value []string{}, in which case the header is not set.
	//
	// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
	ExposeHeaders []string

	// MaxAge determines the value of the Access-Control-Max-Age response header.
	// This header indicates how long (in seconds) the results of a preflight
	// request can be cached.
	// The header is set only if MaxAge != 0, negative value sends "0" which instructs browsers not to cache that response.
	//
	// Optional. Default value 0 - meaning header is not sent.
	//
	// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
	MaxAge int
}

// CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
// See also [MDN: Cross-Origin Resource Sharing (CORS)].
//
// An origin consists of the following parts: `scheme + "://" + host + optional ":" + port`
// Wildcard `*` can be used, but has to be set explicitly.
// Example: `https://example.com`, `http://example.com:8080`, `*`
//
// Security: Poorly configured CORS can compromise security because it allows
// relaxation of the browser's Same-Origin policy. See [Exploiting CORS
// misconfigurations for Bitcoins and bounties] and [Portswigger: Cross-origin
// resource sharing (CORS)] for more details.
// // [MDN: Cross-Origin Resource Sharing (CORS)]: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS // [Exploiting CORS misconfigurations for Bitcoins and bounties]: https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html // [Portswigger: Cross-origin resource sharing (CORS)]: https://portswigger.net/web-security/cors func CORS(allowOrigins ...string) echo.MiddlewareFunc { c := CORSConfig{ AllowOrigins: allowOrigins, } return CORSWithConfig(c) } // CORSWithConfig returns a CORS middleware with config or panics on invalid configuration. // See: [CORS]. func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } // ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultSkipper } hasCustomAllowMethods := true if len(config.AllowMethods) == 0 { hasCustomAllowMethods = false config.AllowMethods = []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete} } allowMethods := strings.Join(config.AllowMethods, ",") allowHeaders := strings.Join(config.AllowHeaders, ",") exposeHeaders := strings.Join(config.ExposeHeaders, ",") maxAge := "0" if config.MaxAge > 0 { maxAge = strconv.Itoa(config.MaxAge) } allowOriginFunc := config.UnsafeAllowOriginFunc if config.UnsafeAllowOriginFunc == nil { if len(config.AllowOrigins) == 0 { return nil, errors.New("at least one AllowOrigins is required or UnsafeAllowOriginFunc must be provided") } allowOriginFunc = config.defaultAllowOriginFunc for _, origin := range config.AllowOrigins { if origin == "*" { if config.AllowCredentials { return nil, fmt.Errorf("* as allowed origin and AllowCredentials=true is insecure and not allowed. 
Use custom UnsafeAllowOriginFunc") } allowOriginFunc = config.starAllowOriginFunc break } if err := validateOrigin(origin, "allow origin"); err != nil { return nil, err } } config.AllowOrigins = append([]string(nil), config.AllowOrigins...) } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } req := c.Request() res := c.Response() origin := req.Header.Get(echo.HeaderOrigin) res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) // Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method, // Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request // For simplicity we just consider method type and later `Origin` header. preflight := req.Method == http.MethodOptions // Although router adds special handler in case of OPTIONS method we avoid calling next for OPTIONS in this middleware // as CORS requests do not have cookies / authentication headers by default, so we could get stuck in auth // middlewares by calling next(c). // But we still want to send `Allow` header as response in case of Non-CORS OPTIONS request as router default // handler does. routerAllowMethods := "" if preflight { tmpAllowMethods, ok := c.Get(echo.ContextKeyHeaderAllow).(string) if ok && tmpAllowMethods != "" { routerAllowMethods = tmpAllowMethods c.Response().Header().Set(echo.HeaderAllow, routerAllowMethods) } } // No Origin provided. 
This is (probably) not request from actual browser - proceed executing middleware chain if origin == "" { if preflight { // req.Method=OPTIONS return c.NoContent(http.StatusNoContent) } return next(c) // let non-browser calls through } allowedOrigin, allowed, err := allowOriginFunc(c, origin) if err != nil { return err } if !allowed { // Origin existed and was NOT allowed if preflight { // From: https://github.com/labstack/echo/issues/2767 // If the request's origin isn't allowed by the CORS configuration, // the middleware should simply omit the relevant CORS headers from the response // and let the browser fail the CORS check (if any). return c.NoContent(http.StatusNoContent) } // From: https://github.com/labstack/echo/issues/2767 // no CORS middleware should block non-preflight requests; // such requests should be let through. One reason is that not all requests that // carry an Origin header participate in the CORS protocol. return next(c) } // Origin existed and was allowed res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) if config.AllowCredentials { res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") } // Simple request will be let though if !preflight { if exposeHeaders != "" { res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders) } return next(c) } // Below code is for Preflight (OPTIONS) request // // Preflight will end with c.NoContent(http.StatusNoContent) as we do not know if // at the end of handler chain is actual OPTIONS route or 404/405 route which // response code will confuse browsers res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) if !hasCustomAllowMethods && routerAllowMethods != "" { res.Header().Set(echo.HeaderAccessControlAllowMethods, routerAllowMethods) } else { res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods) } if allowHeaders != "" { 
res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders) } else { h := req.Header.Get(echo.HeaderAccessControlRequestHeaders) if h != "" { res.Header().Set(echo.HeaderAccessControlAllowHeaders, h) } } if config.MaxAge != 0 { res.Header().Set(echo.HeaderAccessControlMaxAge, maxAge) } return c.NoContent(http.StatusNoContent) } }, nil } func (config CORSConfig) starAllowOriginFunc(c *echo.Context, origin string) (string, bool, error) { return "*", true, nil } func (config CORSConfig) defaultAllowOriginFunc(c *echo.Context, origin string) (string, bool, error) { for _, allowedOrigin := range config.AllowOrigins { if strings.EqualFold(allowedOrigin, origin) { return allowedOrigin, true, nil } } return "", false, nil } ================================================ FILE: middleware/cors_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "cmp" "errors" "net/http" "net/http/httptest" "strings" "testing" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestCORS(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodOptions, "/", nil) // Preflight request req.Header.Set(echo.HeaderOrigin, "http://example.com") rec := httptest.NewRecorder() c := e.NewContext(req, rec) mw := CORS("*") handler := mw(func(c *echo.Context) error { return nil }) err := handler(c) assert.NoError(t, err) assert.Equal(t, http.StatusNoContent, rec.Code) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } func TestCORSConfig(t *testing.T) { var testCases = []struct { name string givenConfig *CORSConfig whenMethod string whenHeaders map[string]string expectHeaders map[string]string notExpectHeaders map[string]string expectErr string }{ { name: "ok, wildcard origin", givenConfig: &CORSConfig{ AllowOrigins: []string{"*"}, }, whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"}, 
expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "*"}, }, { name: "ok, wildcard AllowedOrigin with no Origin header in request", givenConfig: &CORSConfig{ AllowOrigins: []string{"*"}, }, notExpectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: ""}, }, { name: "ok, specific AllowOrigins and AllowCredentials", givenConfig: &CORSConfig{ AllowOrigins: []string{"http://localhost", "http://localhost:8080"}, AllowCredentials: true, MaxAge: 3600, }, whenHeaders: map[string]string{echo.HeaderOrigin: "http://localhost"}, expectHeaders: map[string]string{ echo.HeaderAccessControlAllowOrigin: "http://localhost", echo.HeaderAccessControlAllowCredentials: "true", }, }, { name: "ok, preflight request with matching origin for `AllowOrigins`", givenConfig: &CORSConfig{ AllowOrigins: []string{"http://localhost"}, AllowCredentials: true, MaxAge: 3600, }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{ echo.HeaderOrigin: "http://localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, expectHeaders: map[string]string{ echo.HeaderAccessControlAllowOrigin: "http://localhost", echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", echo.HeaderAccessControlAllowCredentials: "true", echo.HeaderAccessControlMaxAge: "3600", }, }, { name: "ok, preflight request when `Access-Control-Max-Age` is set", givenConfig: &CORSConfig{ AllowOrigins: []string{"http://localhost"}, AllowCredentials: true, MaxAge: 1, }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{ echo.HeaderOrigin: "http://localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, expectHeaders: map[string]string{ echo.HeaderAccessControlMaxAge: "1", }, }, { name: "ok, preflight request when `Access-Control-Max-Age` is set to 0 - not to cache response", givenConfig: &CORSConfig{ AllowOrigins: []string{"http://localhost"}, AllowCredentials: true, MaxAge: -1, // forces `Access-Control-Max-Age: 0` }, whenMethod: http.MethodOptions, 
whenHeaders: map[string]string{ echo.HeaderOrigin: "http://localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, expectHeaders: map[string]string{ echo.HeaderAccessControlMaxAge: "0", }, }, { name: "ok, CORS check are skipped", givenConfig: &CORSConfig{ AllowOrigins: []string{"http://localhost"}, AllowCredentials: true, Skipper: func(c *echo.Context) bool { return true }, }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{ echo.HeaderOrigin: "http://localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, notExpectHeaders: map[string]string{ echo.HeaderAccessControlAllowOrigin: "localhost", echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", echo.HeaderAccessControlAllowCredentials: "true", echo.HeaderAccessControlMaxAge: "3600", }, }, { name: "nok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true", givenConfig: &CORSConfig{ AllowOrigins: []string{"*"}, AllowCredentials: true, MaxAge: 3600, }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{ echo.HeaderOrigin: "localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, expectErr: `* as allowed origin and AllowCredentials=true is insecure and not allowed. 
Use custom UnsafeAllowOriginFunc`, }, { name: "nok, preflight request with invalid `AllowOrigins` value", givenConfig: &CORSConfig{ AllowOrigins: []string{"http://server", "missing-scheme"}, }, expectErr: `allow origin is missing scheme or host: missing-scheme`, }, { name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` false", givenConfig: &CORSConfig{ AllowOrigins: []string{"*"}, AllowCredentials: false, // important for this testcase MaxAge: 3600, }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{ echo.HeaderOrigin: "localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, expectHeaders: map[string]string{ echo.HeaderAccessControlAllowOrigin: "*", echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", echo.HeaderAccessControlMaxAge: "3600", }, notExpectHeaders: map[string]string{ echo.HeaderAccessControlAllowCredentials: "", }, }, { name: "ok, INSECURE preflight request with wildcard `AllowOrigins` and `AllowCredentials` true", givenConfig: &CORSConfig{ AllowOrigins: []string{"*"}, AllowCredentials: true, MaxAge: 3600, }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{ echo.HeaderOrigin: "localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, expectErr: `* as allowed origin and AllowCredentials=true is insecure and not allowed. 
Use custom UnsafeAllowOriginFunc`, }, { name: "ok, preflight request with Access-Control-Request-Headers", givenConfig: &CORSConfig{ AllowOrigins: []string{"*"}, }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{ echo.HeaderOrigin: "localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, echo.HeaderAccessControlRequestHeaders: "Special-Request-Header", }, expectHeaders: map[string]string{ echo.HeaderAccessControlAllowOrigin: "*", echo.HeaderAccessControlAllowHeaders: "Special-Request-Header", echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", }, }, { name: "ok, preflight request with `AllowOrigins` which allow all subdomains aaa with *", givenConfig: &CORSConfig{ UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (allowedOrigin string, allowed bool, err error) { if strings.HasSuffix(origin, ".example.com") { allowed = true } return origin, allowed, nil }, }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{echo.HeaderOrigin: "http://aaa.example.com"}, expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://aaa.example.com"}, }, { name: "ok, preflight request with `AllowOrigins` which allow all subdomains bbb with *", givenConfig: &CORSConfig{ UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) { if strings.HasSuffix(origin, ".example.com") { return origin, true, nil } return "", false, nil }, }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{echo.HeaderOrigin: "http://bbb.example.com"}, expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://bbb.example.com"}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() var mw echo.MiddlewareFunc var err error if tc.givenConfig != nil { mw, err = tc.givenConfig.ToMiddleware() } else { mw, err = CORSConfig{}.ToMiddleware() } if err != nil { if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) return } t.Fatal(err) } h := 
mw(func(c *echo.Context) error { return nil }) method := cmp.Or(tc.whenMethod, http.MethodGet) req := httptest.NewRequest(method, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) for k, v := range tc.whenHeaders { req.Header.Set(k, v) } err = h(c) assert.NoError(t, err) header := rec.Header() for k, v := range tc.expectHeaders { assert.Equal(t, v, header.Get(k), "header: `%v` should be `%v`", k, v) } for k, v := range tc.notExpectHeaders { if v == "" { assert.Len(t, header.Values(k), 0, "header: `%v` should not be set", k) } else { assert.NotEqual(t, v, header.Get(k), "header: `%v` should not be `%v`", k, v) } } }) } } func Test_allowOriginScheme(t *testing.T) { tests := []struct { domain, pattern string expected bool }{ { domain: "http://example.com", pattern: "http://example.com", expected: true, }, { domain: "https://example.com", pattern: "https://example.com", expected: true, }, { domain: "http://example.com", pattern: "https://example.com", expected: false, }, { domain: "https://example.com", pattern: "http://example.com", expected: false, }, } e := echo.New() for _, tt := range tests { req := httptest.NewRequest(http.MethodOptions, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(echo.HeaderOrigin, tt.domain) cors := CORSWithConfig(CORSConfig{ AllowOrigins: []string{tt.pattern}, }) h := cors(func(c *echo.Context) error { return echo.ErrNotFound }) h(c) if tt.expected { assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } else { assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) } } } func TestCORSWithConfig_AllowMethods(t *testing.T) { var testCases = []struct { name string givenAllowOrigins []string givenAllowMethods []string whenAllowContextKey string whenOrigin string expectAllow string expectAccessControlAllowMethods string }{ { name: "custom AllowMethods, preflight, no origin, sets only allow header from context key", givenAllowOrigins: 
[]string{"*"}, givenAllowMethods: []string{http.MethodGet, http.MethodHead}, whenAllowContextKey: "OPTIONS, GET", whenOrigin: "", expectAllow: "OPTIONS, GET", }, { name: "default AllowMethods, preflight, no origin, no allow header in context key and in response", givenAllowOrigins: []string{"*"}, givenAllowMethods: nil, whenAllowContextKey: "", whenOrigin: "", expectAllow: "", }, { name: "custom AllowMethods, preflight, existing origin, sets both headers different values", givenAllowOrigins: []string{"*"}, givenAllowMethods: []string{http.MethodGet, http.MethodHead}, whenAllowContextKey: "OPTIONS, GET", whenOrigin: "http://google.com", expectAllow: "OPTIONS, GET", expectAccessControlAllowMethods: "GET,HEAD", }, { name: "default AllowMethods, preflight, existing origin, sets both headers", givenAllowOrigins: []string{"*"}, givenAllowMethods: nil, whenAllowContextKey: "OPTIONS, GET", whenOrigin: "http://google.com", expectAllow: "OPTIONS, GET", expectAccessControlAllowMethods: "OPTIONS, GET", }, { name: "default AllowMethods, preflight, existing origin, no allows, sets only CORS allow methods", givenAllowOrigins: []string{"*"}, givenAllowMethods: nil, whenAllowContextKey: "", whenOrigin: "http://google.com", expectAllow: "", expectAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() e.GET("/test", func(c *echo.Context) error { return c.String(http.StatusOK, "OK") }) cors := CORSWithConfig(CORSConfig{ AllowOrigins: tc.givenAllowOrigins, AllowMethods: tc.givenAllowMethods, }) req := httptest.NewRequest(http.MethodOptions, "/test", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(echo.HeaderOrigin, tc.whenOrigin) if tc.whenAllowContextKey != "" { c.Set(echo.ContextKeyHeaderAllow, tc.whenAllowContextKey) } h := cors(func(c *echo.Context) error { return c.String(http.StatusOK, "OK") }) h(c) assert.Equal(t, tc.expectAllow, 
rec.Header().Get(echo.HeaderAllow)) assert.Equal(t, tc.expectAccessControlAllowMethods, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) }) } } func TestCorsHeaders(t *testing.T) { tests := []struct { name string originDomain string method string allowedOrigin string expected bool expectStatus int expectAllowHeader string }{ { name: "non-preflight request, allow any origin, missing origin header = no CORS logic done", originDomain: "", allowedOrigin: "*", method: http.MethodGet, expected: false, expectStatus: http.StatusOK, }, { name: "non-preflight request, allow any origin, specific origin domain", originDomain: "http://example.com", allowedOrigin: "*", method: http.MethodGet, expected: true, expectStatus: http.StatusOK, }, { name: "non-preflight request, allow specific origin, missing origin header = no CORS logic done", originDomain: "", // Request does not have Origin header allowedOrigin: "http://example.com", method: http.MethodGet, expected: false, expectStatus: http.StatusOK, }, { name: "non-preflight request, allow specific origin, different origin header = CORS logic failure", originDomain: "http://bar.com", allowedOrigin: "http://example.com", method: http.MethodGet, expected: false, expectStatus: http.StatusOK, }, { name: "non-preflight request, allow specific origin, matching origin header = CORS logic done", originDomain: "http://example.com", allowedOrigin: "http://example.com", method: http.MethodGet, expected: true, expectStatus: http.StatusOK, }, { name: "preflight, allow any origin, missing origin header = no CORS logic done", originDomain: "", // Request does not have Origin header allowedOrigin: "*", method: http.MethodOptions, expected: false, expectStatus: http.StatusNoContent, expectAllowHeader: "OPTIONS, GET, POST", }, { name: "preflight, allow any origin, existing origin header = CORS logic done", originDomain: "http://example.com", allowedOrigin: "*", method: http.MethodOptions, expected: true, expectStatus: http.StatusNoContent, 
expectAllowHeader: "OPTIONS, GET, POST", }, { name: "preflight, allow any origin, missing origin header = no CORS logic done", originDomain: "", // Request does not have Origin header allowedOrigin: "http://example.com", method: http.MethodOptions, expected: false, expectStatus: http.StatusNoContent, expectAllowHeader: "OPTIONS, GET, POST", }, { name: "preflight, allow specific origin, different origin header = no CORS logic done", originDomain: "http://bar.com", allowedOrigin: "http://example.com", method: http.MethodOptions, expected: false, expectStatus: http.StatusNoContent, expectAllowHeader: "OPTIONS, GET, POST", }, { name: "preflight, allow specific origin, matching origin header = CORS logic done", originDomain: "http://example.com", allowedOrigin: "http://example.com", method: http.MethodOptions, expected: true, expectStatus: http.StatusNoContent, expectAllowHeader: "OPTIONS, GET, POST", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { e := echo.New() e.Use(CORSWithConfig(CORSConfig{ AllowOrigins: []string{tc.allowedOrigin}, //AllowCredentials: true, //MaxAge: 3600, })) e.GET("/", func(c *echo.Context) error { return c.String(http.StatusOK, "OK") }) e.POST("/", func(c *echo.Context) error { return c.String(http.StatusCreated, "OK") }) req := httptest.NewRequest(tc.method, "/", nil) rec := httptest.NewRecorder() if tc.originDomain != "" { req.Header.Set(echo.HeaderOrigin, tc.originDomain) } // we run through whole Echo handler chain to see how CORS works with Router OPTIONS handler e.ServeHTTP(rec, req) assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary)) assert.Equal(t, tc.expectAllowHeader, rec.Header().Get(echo.HeaderAllow)) assert.Equal(t, tc.expectStatus, rec.Code) expectedAllowOrigin := "" if tc.allowedOrigin == "*" { expectedAllowOrigin = "*" } else { expectedAllowOrigin = tc.originDomain } switch { case tc.expected && tc.method == http.MethodOptions: assert.Contains(t, rec.Header(), 
echo.HeaderAccessControlAllowMethods) assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary])) case tc.expected && tc.method == http.MethodGet: assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin default: assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin } }) } } func Test_allowOriginFunc(t *testing.T) { returnTrue := func(c *echo.Context, origin string) (string, bool, error) { return origin, true, nil } returnFalse := func(c *echo.Context, origin string) (string, bool, error) { return origin, false, nil } returnError := func(c *echo.Context, origin string) (string, bool, error) { return origin, true, errors.New("this is a test error") } allowOriginFuncs := []func(c *echo.Context, origin string) (string, bool, error){ returnTrue, returnFalse, returnError, } const origin = "http://example.com" e := echo.New() for _, allowOriginFunc := range allowOriginFuncs { req := httptest.NewRequest(http.MethodOptions, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(echo.HeaderOrigin, origin) cors, err := CORSConfig{UnsafeAllowOriginFunc: allowOriginFunc}.ToMiddleware() assert.NoError(t, err) h := cors(func(c *echo.Context) error { return echo.ErrNotFound }) err = h(c) allowedOrigin, allowed, expectedErr := allowOriginFunc(c, origin) if expectedErr != nil { assert.Equal(t, expectedErr, err) assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) continue } if allowed { assert.Equal(t, allowedOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } else { assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } } } ================================================ FILE: middleware/csrf.go 
================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "crypto/subtle" "net/http" "slices" "strings" "time" "github.com/labstack/echo/v5" ) // CSRFUsingSecFetchSite is a context key for CSRF middleware what is set when the client browser is using Sec-Fetch-Site // header and the request is deemed safe. // It is a dummy token value that can be used to render CSRF token for form by handlers. // // We know that the client is using a browser that supports Sec-Fetch-Site header, so when the form is submitted in // the future with this dummy token value it is OK. Although the request is safe, the template rendered by the // handler may need this value to render CSRF token for form. const CSRFUsingSecFetchSite = "_echo_csrf_using_sec_fetch_site_" // CSRFConfig defines the config for CSRF middleware. type CSRFConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // TrustedOrigin permits any request with `Sec-Fetch-Site` header whose `Origin` header // exactly matches the specified value. // Values should be formatted as Origin header "scheme://host[:port]". // // See [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers TrustedOrigins []string // AllowSecFetchSameSite allows custom behaviour for `Sec-Fetch-Site` requests that are about to // fail with CRSF error, to be allowed or replaced with custom error. 
// This function applies to `Sec-Fetch-Site` values: // - `same-site` same registrable domain (subdomain and/or different port) // - `cross-site` request originates from different site // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers AllowSecFetchSiteFunc func(c *echo.Context) (bool, error) // TokenLength is the length of the generated token. TokenLength uint8 // Optional. Default value 32. // TokenLookup is a string in the form of ":" or ":,:" that is used // to extract token from the request. // Optional. Default value "header:X-CSRF-Token". // Possible values: // - "header:" or "header::" // - "query:" // - "form:" // Multiple sources example: // - "header:X-CSRF-Token,query:csrf" TokenLookup string `yaml:"token_lookup"` // Generator defines a function to generate token. // Optional. Defaults tp randomString(TokenLength). Generator func() string // Context key to store generated CSRF token into context. // Optional. Default value "csrf". ContextKey string // Name of the CSRF cookie. This cookie will store CSRF token. // Optional. Default value "csrf". CookieName string // Domain of the CSRF cookie. // Optional. Default value none. CookieDomain string // Path of the CSRF cookie. // Optional. Default value none. CookiePath string // Max age (in seconds) of the CSRF cookie. // Optional. Default value 86400 (24hr). CookieMaxAge int // Indicates if CSRF cookie is secure. // Optional. Default value false. CookieSecure bool // Indicates if CSRF cookie is HTTP only. // Optional. Default value false. CookieHTTPOnly bool // Indicates SameSite mode of the CSRF cookie. // Optional. Default value SameSiteDefaultMode. CookieSameSite http.SameSite // ErrorHandler defines a function which is executed for returning custom errors. 
ErrorHandler func(c *echo.Context, err error) error } // ErrCSRFInvalid is returned when CSRF check fails var ErrCSRFInvalid = &echo.HTTPError{Code: http.StatusForbidden, Message: "invalid csrf token"} // DefaultCSRFConfig is the default CSRF middleware config. var DefaultCSRFConfig = CSRFConfig{ Skipper: DefaultSkipper, TokenLength: 32, TokenLookup: "header:" + echo.HeaderXCSRFToken, ContextKey: "csrf", CookieName: "_csrf", CookieMaxAge: 86400, CookieSameSite: http.SameSiteDefaultMode, } // CSRF returns a Cross-Site Request Forgery (CSRF) middleware. // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery func CSRF() echo.MiddlewareFunc { return CSRFWithConfig(DefaultCSRFConfig) } // CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration. func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } // ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultCSRFConfig.Skipper } if config.TokenLength == 0 { config.TokenLength = DefaultCSRFConfig.TokenLength } if config.Generator == nil { config.Generator = createRandomStringGenerator(config.TokenLength) } if config.TokenLookup == "" { config.TokenLookup = DefaultCSRFConfig.TokenLookup } if config.ContextKey == "" { config.ContextKey = DefaultCSRFConfig.ContextKey } if config.CookieName == "" { config.CookieName = DefaultCSRFConfig.CookieName } if config.CookieMaxAge == 0 { config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge } if config.CookieSameSite == http.SameSiteNoneMode { config.CookieSecure = true } if len(config.TrustedOrigins) > 0 { if err := validateOrigins(config.TrustedOrigins, "trusted origin"); err != nil { return nil, err } config.TrustedOrigins = append([]string(nil), config.TrustedOrigins...) 
} extractors, cErr := createExtractors(config.TokenLookup, 1) if cErr != nil { return nil, cErr } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } // use the `Sec-Fetch-Site` header as part of a modern approach to CSRF protection allow, err := config.checkSecFetchSiteRequest(c) if err != nil { return err } if allow { return next(c) } // Fallback to legacy token based CSRF protection token := "" if k, err := c.Cookie(config.CookieName); err != nil { token = config.Generator() // Generate token } else { token = k.Value // Reuse token } switch c.Request().Method { case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: default: // Validate token only for requests which are not defined as 'safe' by RFC7231 var lastExtractorErr error var lastTokenErr error outer: for _, extractor := range extractors { clientTokens, _, err := extractor(c) if err != nil { lastExtractorErr = err continue } for _, clientToken := range clientTokens { if validateCSRFToken(token, clientToken) { lastTokenErr = nil lastExtractorErr = nil break outer } lastTokenErr = ErrCSRFInvalid } } var finalErr error if lastTokenErr != nil { finalErr = lastTokenErr } else if lastExtractorErr != nil { finalErr = echo.ErrBadRequest.Wrap(lastExtractorErr) } if finalErr != nil { if config.ErrorHandler != nil { return config.ErrorHandler(c, finalErr) } return finalErr } } // Set CSRF cookie cookie := new(http.Cookie) cookie.Name = config.CookieName cookie.Value = token if config.CookiePath != "" { cookie.Path = config.CookiePath } if config.CookieDomain != "" { cookie.Domain = config.CookieDomain } if config.CookieSameSite != http.SameSiteDefaultMode { cookie.SameSite = config.CookieSameSite } cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second) cookie.Secure = config.CookieSecure cookie.HttpOnly = config.CookieHTTPOnly c.SetCookie(cookie) // Store token in the context 
c.Set(config.ContextKey, token) // Protect clients from caching the response c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie) return next(c) } }, nil } func validateCSRFToken(token, clientToken string) bool { return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1 } var safeMethods = []string{http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace} func (config CSRFConfig) checkSecFetchSiteRequest(c *echo.Context) (bool, error) { // https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers // Sec-Fetch-Site values are: // - `same-origin` exact origin match - allow always // - `same-site` same registrable domain (subdomain and/or different port) - block, unless explicitly trusted // - `cross-site` request originates from different site - block, unless explicitly trusted // - `none` direct navigation (URL bar, bookmark) - allow always secFetchSite := c.Request().Header.Get(echo.HeaderSecFetchSite) if secFetchSite == "" { return false, nil } if len(config.TrustedOrigins) > 0 { // trusted sites ala OAuth callbacks etc. should be let through origin := c.Request().Header.Get(echo.HeaderOrigin) if origin != "" { for _, trustedOrigin := range config.TrustedOrigins { if strings.EqualFold(origin, trustedOrigin) { return true, nil } } } } isSafe := slices.Contains(safeMethods, c.Request().Method) if !isSafe { // for state-changing request check SecFetchSite value isSafe = secFetchSite == "same-origin" || secFetchSite == "none" } if isSafe { // This helps handlers that support older token-based CSRF protection. // We know that the client is using a browser that supports Sec-Fetch-Site header, so when the form is submitted in // the future with this dummy token value it is OK. Although the request is safe, the template rendered by the // handler may need this value to render CSRF token for form. 
c.Set(config.ContextKey, CSRFUsingSecFetchSite) return true, nil } // we are here when request is state-changing and `cross-site` or `same-site` // Note: if you want to allow `same-site` use config.TrustedOrigins or `config.AllowSecFetchSiteFunc` if config.AllowSecFetchSiteFunc != nil { return config.AllowSecFetchSiteFunc(c) } if secFetchSite == "same-site" { return false, echo.NewHTTPError(http.StatusForbidden, "same-site request blocked by CSRF") } return false, echo.NewHTTPError(http.StatusForbidden, "cross-site request blocked by CSRF") } ================================================ FILE: middleware/csrf_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "cmp" "net/http" "net/http/httptest" "net/url" "strings" "testing" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestCSRF_tokenExtractors(t *testing.T) { var testCases = []struct { name string whenTokenLookup string whenCookieName string givenCSRFCookie string givenMethod string givenQueryTokens map[string][]string givenFormTokens map[string][]string givenHeaderTokens map[string][]string expectError string expectToMiddlewareError string }{ { name: "ok, multiple token lookups sources, succeeds on last one", whenTokenLookup: "header:X-CSRF-Token,form:csrf", givenCSRFCookie: "token", givenMethod: http.MethodPost, givenHeaderTokens: map[string][]string{ echo.HeaderXCSRFToken: {"invalid_token"}, }, givenFormTokens: map[string][]string{ "csrf": {"token"}, }, }, { name: "ok, token from POST form", whenTokenLookup: "form:csrf", givenCSRFCookie: "token", givenMethod: http.MethodPost, givenFormTokens: map[string][]string{ "csrf": {"token"}, }, }, { name: "ok, token from POST form, second token passes", whenTokenLookup: "form:csrf", givenCSRFCookie: "token", givenMethod: http.MethodPost, givenFormTokens: map[string][]string{ "csrf": {"invalid", "token"}, }, 
expectError: "code=403, message=invalid csrf token", }, { name: "nok, invalid token from POST form", whenTokenLookup: "form:csrf", givenCSRFCookie: "token", givenMethod: http.MethodPost, givenFormTokens: map[string][]string{ "csrf": {"invalid_token"}, }, expectError: "code=403, message=invalid csrf token", }, { name: "nok, missing token from POST form", whenTokenLookup: "form:csrf", givenCSRFCookie: "token", givenMethod: http.MethodPost, givenFormTokens: map[string][]string{}, expectError: "code=400, message=Bad Request, err=missing value in the form", }, { name: "ok, token from POST header", whenTokenLookup: "", // will use defaults givenCSRFCookie: "token", givenMethod: http.MethodPost, givenHeaderTokens: map[string][]string{ echo.HeaderXCSRFToken: {"token"}, }, }, { name: "nok, token from POST header, tokens limited to 1, second token would pass", whenTokenLookup: "header:" + echo.HeaderXCSRFToken, givenCSRFCookie: "token", givenMethod: http.MethodPost, givenHeaderTokens: map[string][]string{ echo.HeaderXCSRFToken: {"invalid", "token"}, }, expectError: "code=403, message=invalid csrf token", }, { name: "nok, invalid token from POST header", whenTokenLookup: "header:" + echo.HeaderXCSRFToken, givenCSRFCookie: "token", givenMethod: http.MethodPost, givenHeaderTokens: map[string][]string{ echo.HeaderXCSRFToken: {"invalid_token"}, }, expectError: "code=403, message=invalid csrf token", }, { name: "nok, missing token from POST header", whenTokenLookup: "header:" + echo.HeaderXCSRFToken, givenCSRFCookie: "token", givenMethod: http.MethodPost, givenHeaderTokens: map[string][]string{}, expectError: "code=400, message=Bad Request, err=missing value in request header", }, { name: "ok, token from PUT query param", whenTokenLookup: "query:csrf-param", givenCSRFCookie: "token", givenMethod: http.MethodPut, givenQueryTokens: map[string][]string{ "csrf-param": {"token"}, }, }, { name: "nok, token from PUT query form, second token would pass", whenTokenLookup: "query:csrf", 
givenCSRFCookie: "token", givenMethod: http.MethodPut, givenQueryTokens: map[string][]string{ "csrf": {"invalid", "token"}, }, expectError: "code=403, message=invalid csrf token", }, { name: "nok, invalid token from PUT query form", whenTokenLookup: "query:csrf", givenCSRFCookie: "token", givenMethod: http.MethodPut, givenQueryTokens: map[string][]string{ "csrf": {"invalid_token"}, }, expectError: "code=403, message=invalid csrf token", }, { name: "nok, missing token from PUT query form", whenTokenLookup: "query:csrf", givenCSRFCookie: "token", givenMethod: http.MethodPut, givenQueryTokens: map[string][]string{}, expectError: "code=400, message=Bad Request, err=missing value in the query string", }, { name: "nok, invalid TokenLookup", whenTokenLookup: "q", givenCSRFCookie: "token", givenMethod: http.MethodPut, givenQueryTokens: map[string][]string{}, expectToMiddlewareError: "extractor source for lookup could not be split into needed parts: q", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() q := make(url.Values) for queryParam, values := range tc.givenQueryTokens { for _, v := range values { q.Add(queryParam, v) } } f := make(url.Values) for formKey, values := range tc.givenFormTokens { for _, v := range values { f.Add(formKey, v) } } var req *http.Request switch tc.givenMethod { case http.MethodGet: req = httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil) case http.MethodPost, http.MethodPut: req = httptest.NewRequest(http.MethodPost, "/?"+q.Encode(), strings.NewReader(f.Encode())) req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) } for header, values := range tc.givenHeaderTokens { for _, v := range values { req.Header.Add(header, v) } } if tc.givenCSRFCookie != "" { req.Header.Set(echo.HeaderCookie, "_csrf="+tc.givenCSRFCookie) } rec := httptest.NewRecorder() c := e.NewContext(req, rec) config := CSRFConfig{ TokenLookup: tc.whenTokenLookup, CookieName: tc.whenCookieName, } csrf, err := 
config.ToMiddleware() if tc.expectToMiddlewareError != "" { assert.EqualError(t, err, tc.expectToMiddlewareError) return } else if err != nil { assert.NoError(t, err) } h := csrf(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) err = h(c) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestCSRFWithConfig(t *testing.T) { token := randomString(16) var testCases = []struct { name string givenConfig *CSRFConfig whenMethod string whenHeaders map[string]string expectEmptyBody bool expectMWError string expectCookieContains string expectTokenInContext string expectErr string }{ { name: "ok, GET", whenMethod: http.MethodGet, expectCookieContains: "_csrf", expectTokenInContext: "TESTTOKEN", }, { name: "ok, POST valid token", whenHeaders: map[string]string{ echo.HeaderCookie: "_csrf=" + token, echo.HeaderXCSRFToken: token, }, whenMethod: http.MethodPost, expectCookieContains: "_csrf", expectTokenInContext: token, }, { name: "nok, POST without token", whenMethod: http.MethodPost, expectEmptyBody: true, expectErr: `code=400, message=Bad Request, err=missing value in request header`, }, { name: "nok, POST empty token", whenHeaders: map[string]string{echo.HeaderXCSRFToken: ""}, whenMethod: http.MethodPost, expectEmptyBody: true, expectErr: `code=403, message=invalid csrf token`, }, { name: "nok, invalid trusted origin in Config", givenConfig: &CSRFConfig{ TrustedOrigins: []string{"http://example.com", "invalid"}, }, expectMWError: `trusted origin is missing scheme or host: invalid`, }, { name: "ok, TokenLength", givenConfig: &CSRFConfig{ TokenLength: 16, }, whenMethod: http.MethodGet, expectCookieContains: "_csrf", expectTokenInContext: "TESTTOKEN", }, { name: "ok, unsafe method + SecFetchSite=same-origin passes", whenHeaders: map[string]string{ echo.HeaderSecFetchSite: "same-origin", }, whenMethod: http.MethodPost, expectTokenInContext: "_echo_csrf_using_sec_fetch_site_", }, { name: "ok, 
safe method + SecFetchSite=same-origin passes", whenHeaders: map[string]string{ echo.HeaderSecFetchSite: "same-origin", }, whenMethod: http.MethodGet, expectTokenInContext: "_echo_csrf_using_sec_fetch_site_", }, { name: "nok, unsafe method + SecFetchSite=same-cross blocked", whenHeaders: map[string]string{ echo.HeaderSecFetchSite: "same-cross", }, whenMethod: http.MethodPost, expectEmptyBody: true, expectErr: `code=403, message=cross-site request blocked by CSRF`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() req := httptest.NewRequest(cmp.Or(tc.whenMethod, http.MethodPost), "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) for key, value := range tc.whenHeaders { req.Header.Set(key, value) } config := CSRFConfig{} if tc.givenConfig != nil { config = *tc.givenConfig } if config.Generator == nil { config.Generator = func() string { return "TESTTOKEN" } } mw, err := config.ToMiddleware() if tc.expectMWError != "" { assert.EqualError(t, err, tc.expectMWError) return } assert.NoError(t, err) h := mw(func(c *echo.Context) error { cToken := c.Get(cmp.Or(config.ContextKey, DefaultCSRFConfig.ContextKey)) assert.Equal(t, tc.expectTokenInContext, cToken) return c.String(http.StatusOK, "test") }) err = h(c) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } expect := "test" if tc.expectEmptyBody { expect = "" } assert.Equal(t, expect, rec.Body.String()) if tc.expectCookieContains != "" { assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), tc.expectCookieContains) } }) } } func TestCSRF(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) csrf := CSRF() h := csrf(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) // Generate CSRF token h(c) assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf") } func TestCSRFSetSameSiteMode(t *testing.T) { e 
:= echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) csrf := CSRFWithConfig(CSRFConfig{ CookieSameSite: http.SameSiteStrictMode, }) h := csrf(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) r := h(c) assert.NoError(t, r) assert.Regexp(t, "SameSite=Strict", rec.Header()["Set-Cookie"]) } func TestCSRFWithoutSameSiteMode(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) csrf := CSRFWithConfig(CSRFConfig{}) h := csrf(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) r := h(c) assert.NoError(t, r) assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"]) } func TestCSRFWithSameSiteDefaultMode(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) csrf := CSRFWithConfig(CSRFConfig{ CookieSameSite: http.SameSiteDefaultMode, }) h := csrf(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) r := h(c) assert.NoError(t, r) assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"]) } func TestCSRFWithSameSiteModeNone(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) csrf, err := CSRFConfig{ CookieSameSite: http.SameSiteNoneMode, }.ToMiddleware() assert.NoError(t, err) h := csrf(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) r := h(c) assert.NoError(t, r) assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"]) assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"]) } func TestCSRFConfig_skipper(t *testing.T) { var testCases = []struct { name string whenSkip bool expectCookies int }{ { name: "do skip", whenSkip: true, expectCookies: 0, }, { name: "do not skip", whenSkip: false, expectCookies: 1, }, } for _, tc := range testCases { 
t.Run(tc.name, func(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) csrf := CSRFWithConfig(CSRFConfig{ Skipper: func(c *echo.Context) bool { return tc.whenSkip }, }) h := csrf(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) r := h(c) assert.NoError(t, r) cookie := rec.Header()["Set-Cookie"] assert.Len(t, cookie, tc.expectCookies) }) } } func TestCSRFErrorHandling(t *testing.T) { cfg := CSRFConfig{ ErrorHandler: func(c *echo.Context, err error) error { return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed") }, } e := echo.New() e.POST("/", func(c *echo.Context) error { return c.String(http.StatusNotImplemented, "should not end up here") }) e.Use(CSRFWithConfig(cfg)) req := httptest.NewRequest(http.MethodPost, "/", nil) res := httptest.NewRecorder() e.ServeHTTP(res, req) assert.Equal(t, http.StatusTeapot, res.Code) assert.Equal(t, "{\"message\":\"error_handler_executed\"}\n", res.Body.String()) } func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { var testCases = []struct { name string givenConfig CSRFConfig whenMethod string whenSecFetchSite string whenOrigin string expectAllow bool expectErr string }{ { name: "ok, unsafe POST, no SecFetchSite is not blocked", givenConfig: CSRFConfig{}, whenMethod: http.MethodPost, whenSecFetchSite: "", expectAllow: false, // should fall back to token CSRF }, { name: "ok, safe GET + same-origin passes", givenConfig: CSRFConfig{}, whenMethod: http.MethodGet, whenSecFetchSite: "same-origin", expectAllow: true, }, { name: "ok, safe GET + none passes", givenConfig: CSRFConfig{}, whenMethod: http.MethodGet, whenSecFetchSite: "none", expectAllow: true, }, { name: "ok, safe GET + same-site passes", givenConfig: CSRFConfig{}, whenMethod: http.MethodGet, whenSecFetchSite: "same-site", expectAllow: true, }, { name: "ok, safe GET + cross-site passes", givenConfig: CSRFConfig{}, whenMethod: 
http.MethodGet, whenSecFetchSite: "cross-site", expectAllow: true, }, { name: "nok, unsafe POST + cross-site is blocked", givenConfig: CSRFConfig{}, whenMethod: http.MethodPost, whenSecFetchSite: "cross-site", expectAllow: false, expectErr: `code=403, message=cross-site request blocked by CSRF`, }, { name: "nok, unsafe POST + same-site is blocked", givenConfig: CSRFConfig{}, whenMethod: http.MethodPost, whenSecFetchSite: "same-site", expectAllow: false, expectErr: `code=403, message=same-site request blocked by CSRF`, }, { name: "ok, unsafe POST + same-origin passes", givenConfig: CSRFConfig{}, whenMethod: http.MethodPost, whenSecFetchSite: "same-origin", expectAllow: true, }, { name: "ok, unsafe POST + none passes", givenConfig: CSRFConfig{}, whenMethod: http.MethodPost, whenSecFetchSite: "none", expectAllow: true, }, { name: "ok, unsafe PUT + same-origin passes", givenConfig: CSRFConfig{}, whenMethod: http.MethodPut, whenSecFetchSite: "same-origin", expectAllow: true, }, { name: "ok, unsafe PUT + none passes", givenConfig: CSRFConfig{}, whenMethod: http.MethodPut, whenSecFetchSite: "none", expectAllow: true, }, { name: "ok, unsafe DELETE + same-origin passes", givenConfig: CSRFConfig{}, whenMethod: http.MethodDelete, whenSecFetchSite: "same-origin", expectAllow: true, }, { name: "ok, unsafe PATCH + same-origin passes", givenConfig: CSRFConfig{}, whenMethod: http.MethodPatch, whenSecFetchSite: "same-origin", expectAllow: true, }, { name: "nok, unsafe PUT + cross-site is blocked", givenConfig: CSRFConfig{}, whenMethod: http.MethodPut, whenSecFetchSite: "cross-site", expectAllow: false, expectErr: `code=403, message=cross-site request blocked by CSRF`, }, { name: "nok, unsafe PUT + same-site is blocked", givenConfig: CSRFConfig{}, whenMethod: http.MethodPut, whenSecFetchSite: "same-site", expectAllow: false, expectErr: `code=403, message=same-site request blocked by CSRF`, }, { name: "nok, unsafe DELETE + cross-site is blocked", givenConfig: CSRFConfig{}, 
whenMethod: http.MethodDelete, whenSecFetchSite: "cross-site", expectAllow: false, expectErr: `code=403, message=cross-site request blocked by CSRF`, }, { name: "nok, unsafe DELETE + same-site is blocked", givenConfig: CSRFConfig{}, whenMethod: http.MethodDelete, whenSecFetchSite: "same-site", expectAllow: false, expectErr: `code=403, message=same-site request blocked by CSRF`, }, { name: "nok, unsafe PATCH + cross-site is blocked", givenConfig: CSRFConfig{}, whenMethod: http.MethodPatch, whenSecFetchSite: "cross-site", expectAllow: false, expectErr: `code=403, message=cross-site request blocked by CSRF`, }, { name: "ok, safe HEAD + same-origin passes", givenConfig: CSRFConfig{}, whenMethod: http.MethodHead, whenSecFetchSite: "same-origin", expectAllow: true, }, { name: "ok, safe HEAD + cross-site passes", givenConfig: CSRFConfig{}, whenMethod: http.MethodHead, whenSecFetchSite: "cross-site", expectAllow: true, }, { name: "ok, safe OPTIONS + cross-site passes", givenConfig: CSRFConfig{}, whenMethod: http.MethodOptions, whenSecFetchSite: "cross-site", expectAllow: true, }, { name: "ok, safe TRACE + cross-site passes", givenConfig: CSRFConfig{}, whenMethod: http.MethodTrace, whenSecFetchSite: "cross-site", expectAllow: true, }, { name: "ok, unsafe POST + cross-site + matching trusted origin passes", givenConfig: CSRFConfig{ TrustedOrigins: []string{"https://trusted.example.com"}, }, whenMethod: http.MethodPost, whenSecFetchSite: "cross-site", whenOrigin: "https://trusted.example.com", expectAllow: true, }, { name: "ok, unsafe POST + same-site + matching trusted origin passes", givenConfig: CSRFConfig{ TrustedOrigins: []string{"https://trusted.example.com"}, }, whenMethod: http.MethodPost, whenSecFetchSite: "same-site", whenOrigin: "https://trusted.example.com", expectAllow: true, }, { name: "nok, unsafe POST + cross-site + non-matching origin is blocked", givenConfig: CSRFConfig{ TrustedOrigins: []string{"https://trusted.example.com"}, }, whenMethod: http.MethodPost, 
whenSecFetchSite: "cross-site", whenOrigin: "https://evil.example.com", expectAllow: false, expectErr: `code=403, message=cross-site request blocked by CSRF`, }, { name: "ok, unsafe POST + cross-site + case-insensitive trusted origin match passes", givenConfig: CSRFConfig{ TrustedOrigins: []string{"https://trusted.example.com"}, }, whenMethod: http.MethodPost, whenSecFetchSite: "cross-site", whenOrigin: "https://TRUSTED.example.com", expectAllow: true, }, { name: "ok, unsafe POST + same-origin + trusted origins configured but not matched passes", givenConfig: CSRFConfig{ TrustedOrigins: []string{"https://trusted.example.com"}, }, whenMethod: http.MethodPost, whenSecFetchSite: "same-origin", whenOrigin: "https://different.example.com", expectAllow: true, }, { name: "nok, unsafe POST + cross-site + empty origin + trusted origins configured is blocked", givenConfig: CSRFConfig{ TrustedOrigins: []string{"https://trusted.example.com"}, }, whenMethod: http.MethodPost, whenSecFetchSite: "cross-site", whenOrigin: "", expectAllow: false, expectErr: `code=403, message=cross-site request blocked by CSRF`, }, { name: "ok, unsafe POST + cross-site + multiple trusted origins, second one matches", givenConfig: CSRFConfig{ TrustedOrigins: []string{"https://first.example.com", "https://second.example.com"}, }, whenMethod: http.MethodPost, whenSecFetchSite: "cross-site", whenOrigin: "https://second.example.com", expectAllow: true, }, { name: "ok, unsafe POST + same-site + custom func allows", givenConfig: CSRFConfig{ AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { return true, nil }, }, whenMethod: http.MethodPost, whenSecFetchSite: "same-site", expectAllow: true, }, { name: "ok, unsafe POST + cross-site + custom func allows", givenConfig: CSRFConfig{ AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { return true, nil }, }, whenMethod: http.MethodPost, whenSecFetchSite: "cross-site", expectAllow: true, }, { name: "nok, unsafe POST + same-site + custom func 
returns custom error", givenConfig: CSRFConfig{ AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { return false, echo.NewHTTPError(http.StatusTeapot, "custom error from func") }, }, whenMethod: http.MethodPost, whenSecFetchSite: "same-site", expectAllow: false, expectErr: `code=418, message=custom error from func`, }, { name: "nok, unsafe POST + cross-site + custom func returns false with nil error", givenConfig: CSRFConfig{ AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { return false, nil }, }, whenMethod: http.MethodPost, whenSecFetchSite: "cross-site", expectAllow: false, expectErr: "", // custom func returns nil error, so no error expected }, { name: "nok, unsafe POST + invalid Sec-Fetch-Site value treated as cross-site", givenConfig: CSRFConfig{}, whenMethod: http.MethodPost, whenSecFetchSite: "invalid-value", expectAllow: false, expectErr: `code=403, message=cross-site request blocked by CSRF`, }, { name: "ok, unsafe POST + cross-site + trusted origin takes precedence over custom func", givenConfig: CSRFConfig{ TrustedOrigins: []string{"https://trusted.example.com"}, AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { return false, echo.NewHTTPError(http.StatusTeapot, "should not be called") }, }, whenMethod: http.MethodPost, whenSecFetchSite: "cross-site", whenOrigin: "https://trusted.example.com", expectAllow: true, }, { name: "nok, unsafe POST + cross-site + trusted origin not matched, custom func blocks", givenConfig: CSRFConfig{ TrustedOrigins: []string{"https://trusted.example.com"}, AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { return false, echo.NewHTTPError(http.StatusTeapot, "custom block") }, }, whenMethod: http.MethodPost, whenSecFetchSite: "cross-site", whenOrigin: "https://evil.example.com", expectAllow: false, expectErr: `code=418, message=custom block`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.whenMethod, "/", nil) if tc.whenSecFetchSite 
!= "" { req.Header.Set(echo.HeaderSecFetchSite, tc.whenSecFetchSite) } if tc.whenOrigin != "" { req.Header.Set(echo.HeaderOrigin, tc.whenOrigin) } res := httptest.NewRecorder() c := echo.NewContext(req, res) allow, err := tc.givenConfig.checkSecFetchSiteRequest(c) assert.Equal(t, tc.expectAllow, allow) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } }) } } ================================================ FILE: middleware/decompress.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "compress/gzip" "io" "net/http" "sync" "github.com/labstack/echo/v5" ) // DecompressConfig defines the config for Decompress middleware. type DecompressConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers GzipDecompressPool Decompressor // MaxDecompressedSize limits the maximum size of decompressed request body in bytes. // If the decompressed body exceeds this limit, the middleware returns HTTP 413 error. // This prevents zip bomb attacks where small compressed payloads decompress to huge sizes. // Default: 100 * MB (104,857,600 bytes) // Set to -1 to disable limits (not recommended in production). MaxDecompressedSize int64 } // GZIPEncoding content-encoding header if set to "gzip", decompress body contents. 
const GZIPEncoding string = "gzip" // Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers type Decompressor interface { gzipDecompressPool() sync.Pool } // DefaultGzipDecompressPool is the default implementation of Decompressor interface type DefaultGzipDecompressPool struct { } func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { return sync.Pool{New: func() any { return new(gzip.Reader) }} } // Decompress decompresses request body based if content encoding type is set to "gzip" with default config // // SECURITY: By default, this limits decompressed data to 100MB to prevent zip bomb attacks. // To customize the limit, use DecompressWithConfig. To disable limits (not recommended in production), // set MaxDecompressedSize to -1. func Decompress() echo.MiddlewareFunc { return DecompressWithConfig(DecompressConfig{}) } // DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration. // // SECURITY: If MaxDecompressedSize is not set (zero value), it defaults to 100MB to prevent // DoS attacks via zip bombs. Set to -1 to explicitly disable limits if needed for your use case. 
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } // ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultSkipper } if config.GzipDecompressPool == nil { config.GzipDecompressPool = &DefaultGzipDecompressPool{} } // Apply secure default for decompression limit if config.MaxDecompressedSize == 0 { config.MaxDecompressedSize = 100 * MB } return func(next echo.HandlerFunc) echo.HandlerFunc { pool := config.GzipDecompressPool.gzipDecompressPool() return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } if c.Request().Header.Get(echo.HeaderContentEncoding) != GZIPEncoding { return next(c) } i := pool.Get() gr, ok := i.(*gzip.Reader) if !ok || gr == nil { if err, isErr := i.(error); isErr { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return echo.NewHTTPError(http.StatusInternalServerError, "unexpected type from gzip decompression pool") } defer pool.Put(gr) b := c.Request().Body defer b.Close() if err := gr.Reset(b); err != nil { if err == io.EOF { //ignore if body is empty return next(c) } return err } // only Close gzip reader if it was set to a proper gzip source otherwise it will panic on close. 
defer gr.Close() // Apply decompression size limit to prevent zip bombs if config.MaxDecompressedSize > 0 { c.Request().Body = &limitedGzipReader{ Reader: gr, remaining: config.MaxDecompressedSize, limit: config.MaxDecompressedSize, } } else { // -1 means explicitly unlimited (not recommended) c.Request().Body = gr } return next(c) } }, nil } // limitedGzipReader wraps a gzip reader with size limiting to prevent zip bombs type limitedGzipReader struct { *gzip.Reader remaining int64 limit int64 } func (r *limitedGzipReader) Read(p []byte) (n int, err error) { if r.remaining <= 0 { // Limit exceeded - return 413 error return 0, echo.ErrStatusRequestEntityTooLarge } // Limit the read to remaining bytes if int64(len(p)) > r.remaining { p = p[:r.remaining] } n, err = r.Reader.Read(p) r.remaining -= int64(n) return n, err } func (r *limitedGzipReader) Close() error { return r.Reader.Close() } ================================================ FILE: middleware/decompress_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "bytes" "compress/gzip" "errors" "io" "net/http" "net/http/httptest" "strings" "sync" "testing" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestDecompress(t *testing.T) { e := echo.New() h := Decompress()(func(c *echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) // Decompress request body body := `{"name": "echo"}` gz, _ := gzipString(body) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := h(c) assert.NoError(t, err) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := io.ReadAll(req.Body) assert.NoError(t, err) assert.Equal(t, body, string(b)) } func 
TestDecompress_skippedIfNoHeader(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) rec := httptest.NewRecorder() c := e.NewContext(req, rec) // Skip if no Content-Encoding header h := Decompress()(func(c *echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) err := h(c) assert.NoError(t, err) assert.Equal(t, "test", rec.Body.String()) } func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h, err := DecompressConfig{}.ToMiddleware() assert.NoError(t, err) err = h(func(c *echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil })(c) assert.NoError(t, err) assert.Equal(t, "test", rec.Body.String()) } func TestDecompressWithConfig_DefaultConfig(t *testing.T) { e := echo.New() h := Decompress()(func(c *echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) // Decompress body := `{"name": "echo"}` gz, _ := gzipString(body) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := h(c) assert.NoError(t, err) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := io.ReadAll(req.Body) assert.NoError(t, err) assert.Equal(t, body, string(b)) } func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { e := echo.New() body := `{"name":"echo"}` gz, _ := gzipString(body) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() e.NewContext(req, rec) e.ServeHTTP(rec, req) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) 
b, err := io.ReadAll(req.Body) assert.NoError(t, err) assert.NotEqual(t, b, body) assert.Equal(t, b, gz) } func TestDecompressNoContent(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := Decompress()(func(c *echo.Context) error { return c.NoContent(http.StatusNoContent) }) err := h(c) if assert.NoError(t, err) { assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) assert.Equal(t, 0, len(rec.Body.Bytes())) } } func TestDecompressErrorReturned(t *testing.T) { e := echo.New() e.Use(Decompress()) e.GET("/", func(c *echo.Context) error { return echo.ErrNotFound }) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusNotFound, rec.Code) assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) } func TestDecompressSkipper(t *testing.T) { e := echo.New() e.Use(DecompressWithConfig(DecompressConfig{ Skipper: func(c *echo.Context) bool { return c.Request().URL.Path == "/skip" }, })) body := `{"name": "echo"}` req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body)) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) e.ServeHTTP(rec, req) assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSON) reqBody, err := io.ReadAll(c.Request().Body) assert.NoError(t, err) assert.Equal(t, body, string(reqBody)) } type TestDecompressPoolWithError struct { } func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool { return sync.Pool{ New: func() any { return errors.New("pool error") }, } } func TestDecompressPoolError(t *testing.T) { e := echo.New() e.Use(DecompressWithConfig(DecompressConfig{ Skipper: 
DefaultSkipper, GzipDecompressPool: &TestDecompressPoolWithError{}, })) body := `{"name": "echo"}` req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body)) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) e.ServeHTTP(rec, req) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) reqBody, err := io.ReadAll(c.Request().Body) assert.NoError(t, err) assert.Equal(t, body, string(reqBody)) assert.Equal(t, rec.Code, http.StatusInternalServerError) } func BenchmarkDecompress(b *testing.B) { e := echo.New() body := `{"name": "echo"}` gz, _ := gzipString(body) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) h := Decompress()(func(c *echo.Context) error { c.Response().Write([]byte(body)) // For Content-Type sniffing return nil }) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { // Decompress rec := httptest.NewRecorder() c := e.NewContext(req, rec) h(c) } } func gzipString(body string) ([]byte, error) { var buf bytes.Buffer gz := gzip.NewWriter(&buf) _, err := gz.Write([]byte(body)) if err != nil { return nil, err } if err := gz.Close(); err != nil { return nil, err } return buf.Bytes(), nil } func TestDecompress_WithinLimit(t *testing.T) { e := echo.New() body := strings.Repeat("test data ", 100) // Small payload ~1KB gz, _ := gzipString(body) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h, err := DecompressConfig{MaxDecompressedSize: 100 * MB}.ToMiddleware() assert.NoError(t, err) err = h(func(c *echo.Context) error { b, _ := io.ReadAll(c.Request().Body) return c.String(http.StatusOK, string(b)) })(c) assert.NoError(t, err) assert.Equal(t, body, rec.Body.String()) } func TestDecompress_ExceedsLimit(t *testing.T) { e := echo.New() // 
Create 2KB of data but limit to 1KB largeBody := strings.Repeat("A", 2*1024) gz, _ := gzipString(largeBody) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit assert.NoError(t, err) err = h(func(c *echo.Context) error { _, readErr := io.ReadAll(c.Request().Body) return readErr })(c) // Should return 413 error assert.Error(t, err) he, ok := err.(echo.HTTPStatusCoder) assert.True(t, ok) assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) } func TestDecompress_AtExactLimit(t *testing.T) { e := echo.New() exactBody := strings.Repeat("B", 1024) // Exactly 1KB gz, _ := gzipString(exactBody) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() assert.NoError(t, err) err = h(func(c *echo.Context) error { b, _ := io.ReadAll(c.Request().Body) return c.String(http.StatusOK, string(b)) })(c) assert.NoError(t, err) assert.Equal(t, exactBody, rec.Body.String()) } func TestDecompress_ZipBomb(t *testing.T) { e := echo.New() // Create highly compressed data that expands to 2MB // but limit is 1MB largeBody := bytes.Repeat([]byte("A"), 2*1024*1024) // 2MB var buf bytes.Buffer gzWriter := gzip.NewWriter(&buf) gzWriter.Write(largeBody) gzWriter.Close() req := httptest.NewRequest(http.MethodPost, "/", &buf) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h, err := DecompressConfig{MaxDecompressedSize: 1 * MB}.ToMiddleware() assert.NoError(t, err) err = h(func(c *echo.Context) error { _, readErr := io.ReadAll(c.Request().Body) return readErr })(c) // Should return 413 error assert.Error(t, err) 
he, ok := err.(echo.HTTPStatusCoder) assert.True(t, ok) assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) } func TestDecompress_UnlimitedExplicit(t *testing.T) { e := echo.New() largeBody := strings.Repeat("X", 10*1024) // 10KB gz, _ := gzipString(largeBody) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h, err := DecompressConfig{MaxDecompressedSize: -1}.ToMiddleware() // Unlimited assert.NoError(t, err) err = h(func(c *echo.Context) error { b, _ := io.ReadAll(c.Request().Body) return c.String(http.StatusOK, string(b)) })(c) assert.NoError(t, err) assert.Equal(t, largeBody, rec.Body.String()) } func TestDecompress_DefaultLimit(t *testing.T) { e := echo.New() smallBody := "test" gz, _ := gzipString(smallBody) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) // Use zero value which should apply 100MB default h, err := DecompressConfig{}.ToMiddleware() assert.NoError(t, err) err = h(func(c *echo.Context) error { b, _ := io.ReadAll(c.Request().Body) return c.String(http.StatusOK, string(b)) })(c) assert.NoError(t, err) assert.Equal(t, smallBody, rec.Body.String()) } func TestDecompress_SmallCustomLimit(t *testing.T) { e := echo.New() body := strings.Repeat("D", 512) // 512 bytes gz, _ := gzipString(body) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit assert.NoError(t, err) err = h(func(c *echo.Context) error { b, _ := io.ReadAll(c.Request().Body) return c.String(http.StatusOK, string(b)) })(c) assert.NoError(t, err) assert.Equal(t, body, rec.Body.String()) } func 
TestDecompress_MultipleReads(t *testing.T) { e := echo.New() // Test that limit is enforced across multiple Read() calls largeBody := strings.Repeat("M", 2*1024) // 2KB gz, _ := gzipString(largeBody) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit assert.NoError(t, err) err = h(func(c *echo.Context) error { // Read in small chunks buf := make([]byte, 256) total := 0 for { n, readErr := c.Request().Body.Read(buf) total += n if readErr != nil { if readErr == io.EOF { return nil } return readErr } } })(c) // Should return 413 error from cumulative reads assert.Error(t, err) he, ok := err.(echo.HTTPStatusCoder) assert.True(t, ok) assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) } func TestDecompress_LargePayloadDosPrevention(t *testing.T) { e := echo.New() // Simulate a DoS attack with highly compressed large payload largeSize := 10 * 1024 * 1024 // 10MB decompressed largeBody := bytes.Repeat([]byte("Z"), largeSize) var buf bytes.Buffer gzWriter := gzip.NewWriter(&buf) gzWriter.Write(largeBody) gzWriter.Close() req := httptest.NewRequest(http.MethodPost, "/", &buf) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h, err := DecompressConfig{MaxDecompressedSize: 1 * MB}.ToMiddleware() assert.NoError(t, err) err = h(func(c *echo.Context) error { _, readErr := io.ReadAll(c.Request().Body) return readErr })(c) // Should prevent DoS by returning 413 assert.Error(t, err) he, ok := err.(echo.HTTPStatusCoder) assert.True(t, ok) assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) } func BenchmarkDecompress_WithLimit(b *testing.B) { e := echo.New() body := strings.Repeat("benchmark data ", 1000) // ~15KB gz, _ := gzipString(body) h, _ := 
DecompressConfig{MaxDecompressedSize: 100 * MB}.ToMiddleware() b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h(func(c *echo.Context) error { io.ReadAll(c.Request().Body) return nil })(c) } } ================================================ FILE: middleware/extractor.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "fmt" "net/textproto" "strings" "github.com/labstack/echo/v5" ) const ( // extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource exhaustion // attack vector extractorLimit = 20 ) // ExtractorSource is type to indicate source for extracted value type ExtractorSource string const ( // ExtractorSourceHeader means value was extracted from request header ExtractorSourceHeader ExtractorSource = "header" // ExtractorSourceQuery means value was extracted from request query parameters ExtractorSourceQuery ExtractorSource = "query" // ExtractorSourcePathParam means value was extracted from route path parameters ExtractorSourcePathParam ExtractorSource = "param" // ExtractorSourceCookie means value was extracted from request cookies ExtractorSourceCookie ExtractorSource = "cookie" // ExtractorSourceForm means value was extracted from request form values ExtractorSourceForm ExtractorSource = "form" ) // ValueExtractorError is error type when middleware extractor is unable to extract value from lookups type ValueExtractorError struct { message string } // Error returns errors text func (e *ValueExtractorError) Error() string { return e.message } var errHeaderExtractorValueMissing = &ValueExtractorError{message: "missing value in request header"} var errHeaderExtractorValueInvalid = 
&ValueExtractorError{message: "invalid value in request header"} var errQueryExtractorValueMissing = &ValueExtractorError{message: "missing value in the query string"} var errParamExtractorValueMissing = &ValueExtractorError{message: "missing value in path params"} var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing value in cookies"} var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"} // ValuesExtractor defines a function for extracting values (keys/tokens) from the given context. type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error) // CreateExtractors creates ValuesExtractors from given lookups. // lookups is a string in the form of ":" or ":,:" that is used // to extract key from the request. // Possible values: // - "header:" or "header::" // `` is argument value to cut/trim prefix of the extracted value. This is useful if header // value has static prefix like `Authorization: ` where part that we // want to cut is ` ` note the space at the end. // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. // - "query:" // - "param:" // - "form:" // - "cookie:" // // Multiple sources example: // - "header:Authorization,header:X-Api-Key" // // limit sets the maximum amount how many lookups can be returned. 
func CreateExtractors(lookups string, limit uint) ([]ValuesExtractor, error) { return createExtractors(lookups, limit) } func createExtractors(lookups string, limit uint) ([]ValuesExtractor, error) { if lookups == "" { return nil, nil } if limit == 0 { limit = 1 } else if limit > extractorLimit { limit = extractorLimit } sources := strings.Split(lookups, ",") var extractors = make([]ValuesExtractor, 0) for _, source := range sources { parts := strings.Split(source, ":") if len(parts) < 2 { return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source) } switch parts[0] { case "query": extractors = append(extractors, valuesFromQuery(parts[1], limit)) case "param": extractors = append(extractors, valuesFromParam(parts[1], limit)) case "cookie": extractors = append(extractors, valuesFromCookie(parts[1], limit)) case "form": extractors = append(extractors, valuesFromForm(parts[1], limit)) case "header": prefix := "" if len(parts) > 2 { prefix = parts[2] } extractors = append(extractors, valuesFromHeader(parts[1], prefix, limit)) } } return extractors, nil } // valuesFromHeader returns a functions that extracts values from the request header. // valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static // prefix like `Authorization: ` where part that we want to remove is ` ` // note the space at the end. In case of basic authentication `Authorization: Basic ` prefix we want to remove // is `Basic `. In case of JWT tokens `Authorization: Bearer ` prefix is `Bearer `. // If prefix is left empty the whole value is returned. 
func valuesFromHeader(header string, valuePrefix string, limit uint) ValuesExtractor { prefixLen := len(valuePrefix) // standard library parses http.Request header keys in canonical form but we may provide something else so fix this header = textproto.CanonicalMIMEHeaderKey(header) if limit == 0 { limit = 1 } return func(c *echo.Context) ([]string, ExtractorSource, error) { values := c.Request().Header.Values(header) if len(values) == 0 { return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing } i := uint(0) result := make([]string, 0) for _, value := range values { if prefixLen == 0 { result = append(result, value) i++ if i >= limit { break } } else if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) { result = append(result, value[prefixLen:]) i++ if i >= limit { break } } } if len(result) == 0 { if prefixLen > 0 { return nil, ExtractorSourceHeader, errHeaderExtractorValueInvalid } return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing } return result, ExtractorSourceHeader, nil } } // valuesFromQuery returns a function that extracts values from the query string. func valuesFromQuery(param string, limit uint) ValuesExtractor { if limit == 0 { limit = 1 } return func(c *echo.Context) ([]string, ExtractorSource, error) { result := c.QueryParams()[param] if len(result) == 0 { return nil, ExtractorSourceQuery, errQueryExtractorValueMissing } else if len(result) > int(limit)-1 { result = result[:limit] } return result, ExtractorSourceQuery, nil } } // valuesFromParam returns a function that extracts values from the url param string. 
func valuesFromParam(param string, limit uint) ValuesExtractor { if limit == 0 { limit = 1 } return func(c *echo.Context) ([]string, ExtractorSource, error) { result := make([]string, 0) i := uint(0) for _, p := range c.PathValues() { if param != p.Name { continue } result = append(result, p.Value) i++ if i >= limit { break } } if len(result) == 0 { return nil, ExtractorSourcePathParam, errParamExtractorValueMissing } return result, ExtractorSourcePathParam, nil } } // valuesFromCookie returns a function that extracts values from the named cookie. func valuesFromCookie(name string, limit uint) ValuesExtractor { if limit == 0 { limit = 1 } return func(c *echo.Context) ([]string, ExtractorSource, error) { cookies := c.Cookies() if len(cookies) == 0 { return nil, ExtractorSourceCookie, errCookieExtractorValueMissing } i := uint(0) result := make([]string, 0) for _, cookie := range cookies { if name != cookie.Name { continue } result = append(result, cookie.Value) i++ if i >= limit { break } } if len(result) == 0 { return nil, ExtractorSourceCookie, errCookieExtractorValueMissing } return result, ExtractorSourceCookie, nil } } // valuesFromForm returns a function that extracts values from the form field. func valuesFromForm(name string, limit uint) ValuesExtractor { if limit == 0 { limit = 1 } return func(c *echo.Context) ([]string, ExtractorSource, error) { if c.Request().Form == nil { _, _ = c.MultipartForm() // we want to trigger c.request.ParseMultipartForm(c.formParseMaxMemory) } values := c.Request().Form[name] if len(values) == 0 { return nil, ExtractorSourceForm, errFormExtractorValueMissing } if len(values) > int(limit)-1 { values = values[:limit] } result := append([]string{}, values...) 
return result, ExtractorSourceForm, nil } } ================================================ FILE: middleware/extractor_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "bytes" "fmt" "mime/multipart" "net/http" "net/http/httptest" "net/url" "strings" "testing" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestCreateExtractors(t *testing.T) { var testCases = []struct { name string givenRequest func() *http.Request givenPathValues echo.PathValues whenLookups string whenLimit uint expectValues []string expectSource ExtractorSource expectCreateError string expectError string }{ { name: "ok, header", givenRequest: func() *http.Request { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAuthorization, "Bearer token") return req }, whenLookups: "header:Authorization:Bearer ", expectValues: []string{"token"}, expectSource: ExtractorSourceHeader, }, { name: "ok, form", givenRequest: func() *http.Request { f := make(url.Values) f.Set("name", "Jon Snow") req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) return req }, whenLookups: "form:name", expectValues: []string{"Jon Snow"}, expectSource: ExtractorSourceForm, }, { name: "ok, cookie", givenRequest: func() *http.Request { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderCookie, "_csrf=token") return req }, whenLookups: "cookie:_csrf", expectValues: []string{"token"}, expectSource: ExtractorSourceCookie, }, { name: "ok, param", givenPathValues: echo.PathValues{ {Name: "id", Value: "123"}, }, whenLookups: "param:id", expectValues: []string{"123"}, expectSource: ExtractorSourcePathParam, }, { name: "ok, query", givenRequest: func() *http.Request { req := httptest.NewRequest(http.MethodGet, "/?id=999", nil) return 
req }, whenLookups: "query:id", expectValues: []string{"999"}, expectSource: ExtractorSourceQuery, }, { name: "nok, invalid lookup", whenLookups: "query", expectCreateError: "extractor source for lookup could not be split into needed parts: query", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) if tc.givenRequest != nil { req = tc.givenRequest() } rec := httptest.NewRecorder() c := e.NewContext(req, rec) if tc.givenPathValues != nil { c.SetPathValues(tc.givenPathValues) } extractors, err := CreateExtractors(tc.whenLookups, tc.whenLimit) if tc.expectCreateError != "" { assert.EqualError(t, err, tc.expectCreateError) return } assert.NoError(t, err) for _, e := range extractors { values, source, eErr := e(c) assert.Equal(t, tc.expectValues, values) assert.Equal(t, tc.expectSource, source) if tc.expectError != "" { assert.EqualError(t, eErr, tc.expectError) return } assert.NoError(t, eErr) } }) } } func TestValuesFromHeader(t *testing.T) { exampleRequest := func(req *http.Request) { req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") } var testCases = []struct { name string givenRequest func(req *http.Request) whenName string whenValuePrefix string whenLimit uint expectValues []string expectError string }{ { name: "ok, single value", givenRequest: exampleRequest, whenName: echo.HeaderAuthorization, whenValuePrefix: "basic ", expectValues: []string{"dXNlcjpwYXNzd29yZA=="}, }, { name: "ok, single value, case insensitive", givenRequest: exampleRequest, whenName: echo.HeaderAuthorization, whenValuePrefix: "Basic ", expectValues: []string{"dXNlcjpwYXNzd29yZA=="}, }, { name: "ok, multiple value", givenRequest: func(req *http.Request) { req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") }, whenName: echo.HeaderAuthorization, whenValuePrefix: "basic ", whenLimit: 2, expectValues: 
[]string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"}, }, { name: "ok, empty prefix", givenRequest: exampleRequest, whenName: echo.HeaderAuthorization, whenValuePrefix: "", expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="}, }, { name: "nok, no matching due different prefix", givenRequest: func(req *http.Request) { req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") }, whenName: echo.HeaderAuthorization, whenValuePrefix: "Bearer ", expectError: errHeaderExtractorValueInvalid.Error(), }, { name: "nok, no matching due different prefix", givenRequest: func(req *http.Request) { req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") }, whenName: echo.HeaderWWWAuthenticate, whenValuePrefix: "", expectError: errHeaderExtractorValueMissing.Error(), }, { name: "nok, no headers", givenRequest: nil, whenName: echo.HeaderAuthorization, whenValuePrefix: "basic ", expectError: errHeaderExtractorValueMissing.Error(), }, { name: "ok, prefix, cut values over extractorLimit", givenRequest: func(req *http.Request) { for i := 1; i <= 25; i++ { req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("basic %v", i)) } }, whenName: echo.HeaderAuthorization, whenValuePrefix: "basic ", whenLimit: extractorLimit, expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", }, }, { name: "ok, cut values over extractorLimit", givenRequest: func(req *http.Request) { for i := 1; i <= 25; i++ { req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("%v", i)) } }, whenName: echo.HeaderAuthorization, whenValuePrefix: "", whenLimit: extractorLimit, expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() req 
:= httptest.NewRequest(http.MethodGet, "/", nil) if tc.givenRequest != nil { tc.givenRequest(req) } rec := httptest.NewRecorder() c := e.NewContext(req, rec) extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix, tc.whenLimit) values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) assert.Equal(t, ExtractorSourceHeader, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValuesFromQuery(t *testing.T) { var testCases = []struct { name string givenQueryPart string whenName string whenLimit uint expectValues []string expectError string }{ { name: "ok, single value", givenQueryPart: "?id=123&name=test", whenName: "id", expectValues: []string{"123"}, }, { name: "ok, multiple value", givenQueryPart: "?id=123&id=456&name=test", whenName: "id", whenLimit: 2, expectValues: []string{"123", "456"}, }, { name: "nok, missing value", givenQueryPart: "?id=123&name=test", whenName: "nope", expectError: errQueryExtractorValueMissing.Error(), }, { name: "ok, cut values over extractorLimit", givenQueryPart: "?name=test" + "&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" + "&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" + "&id=21&id=22&id=23&id=24&id=25", whenName: "id", whenLimit: extractorLimit, expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) extractor := valuesFromQuery(tc.whenName, tc.whenLimit) values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) assert.Equal(t, ExtractorSourceQuery, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValuesFromParam(t 
*testing.T) { examplePathValues := echo.PathValues{ {Name: "id", Value: "123"}, {Name: "gid", Value: "456"}, {Name: "gid", Value: "789"}, } examplePathValues20 := make(echo.PathValues, 0) for i := 1; i < 25; i++ { examplePathValues20 = append(examplePathValues20, echo.PathValue{Name: "id", Value: fmt.Sprintf("%v", i)}) } var testCases = []struct { name string givenPathValues echo.PathValues whenName string whenLimit uint expectValues []string expectError string }{ { name: "ok, single value", givenPathValues: examplePathValues, whenName: "id", expectValues: []string{"123"}, }, { name: "ok, multiple value", givenPathValues: examplePathValues, whenName: "gid", whenLimit: 2, expectValues: []string{"456", "789"}, }, { name: "nok, no values", givenPathValues: nil, whenName: "nope", expectValues: nil, expectError: errParamExtractorValueMissing.Error(), }, { name: "nok, no matching value", givenPathValues: examplePathValues, whenName: "nope", expectValues: nil, expectError: errParamExtractorValueMissing.Error(), }, { name: "ok, cut values over extractorLimit", givenPathValues: examplePathValues20, whenName: "id", whenLimit: extractorLimit, expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) if tc.givenPathValues != nil { c.SetPathValues(tc.givenPathValues) } extractor := valuesFromParam(tc.whenName, tc.whenLimit) values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) assert.Equal(t, ExtractorSourcePathParam, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValuesFromCookie(t *testing.T) { exampleRequest := func(req *http.Request) { req.Header.Set(echo.HeaderCookie, "_csrf=token") } var testCases 
= []struct { name string givenRequest func(req *http.Request) whenName string whenLimit uint expectValues []string expectError string }{ { name: "ok, single value", givenRequest: exampleRequest, whenName: "_csrf", expectValues: []string{"token"}, }, { name: "ok, multiple value", givenRequest: func(req *http.Request) { req.Header.Add(echo.HeaderCookie, "_csrf=token") req.Header.Add(echo.HeaderCookie, "_csrf=token2") }, whenName: "_csrf", whenLimit: 2, expectValues: []string{"token", "token2"}, }, { name: "nok, no matching cookie", givenRequest: exampleRequest, whenName: "xxx", expectValues: nil, expectError: errCookieExtractorValueMissing.Error(), }, { name: "nok, no cookies at all", givenRequest: nil, whenName: "xxx", expectValues: nil, expectError: errCookieExtractorValueMissing.Error(), }, { name: "ok, cut values over extractorLimit", givenRequest: func(req *http.Request) { for i := 1; i < 25; i++ { req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", i)) } }, whenName: "_csrf", whenLimit: extractorLimit, expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) if tc.givenRequest != nil { tc.givenRequest(req) } rec := httptest.NewRecorder() c := e.NewContext(req, rec) extractor := valuesFromCookie(tc.whenName, tc.whenLimit) values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) assert.Equal(t, ExtractorSourceCookie, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestValuesFromForm(t *testing.T) { examplePostFormRequest := func(mod func(v *url.Values)) *http.Request { f := make(url.Values) f.Set("name", "Jon Snow") f.Set("emails[]", "jon@labstack.com") if mod != nil { mod(&f) } req := httptest.NewRequest(http.MethodPost, "/", 
strings.NewReader(f.Encode())) req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) return req } exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request { f := make(url.Values) f.Set("name", "Jon Snow") f.Set("emails[]", "jon@labstack.com") if mod != nil { mod(&f) } req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil) return req } exampleMultiPartFormRequest := func(mod func(w *multipart.Writer)) *http.Request { var b bytes.Buffer w := multipart.NewWriter(&b) w.WriteField("name", "Jon Snow") w.WriteField("emails[]", "jon@labstack.com") if mod != nil { mod(w) } fw, _ := w.CreateFormFile("upload", "my.file") fw.Write([]byte(`
hi
`)) w.Close() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(b.String())) req.Header.Add(echo.HeaderContentType, w.FormDataContentType()) return req } var testCases = []struct { name string givenRequest *http.Request whenName string whenLimit uint expectValues []string expectError string }{ { name: "ok, POST form, single value", givenRequest: examplePostFormRequest(nil), whenName: "emails[]", expectValues: []string{"jon@labstack.com"}, }, { name: "ok, POST form, multiple value", givenRequest: examplePostFormRequest(func(v *url.Values) { v.Add("emails[]", "snow@labstack.com") }), whenName: "emails[]", whenLimit: 2, expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, }, { name: "ok, POST multipart/form, multiple value", givenRequest: exampleMultiPartFormRequest(func(w *multipart.Writer) { w.WriteField("emails[]", "snow@labstack.com") }), whenName: "emails[]", whenLimit: 2, expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, }, { name: "ok, GET form, single value", givenRequest: exampleGetFormRequest(nil), whenName: "emails[]", expectValues: []string{"jon@labstack.com"}, }, { name: "ok, GET form, multiple value", givenRequest: examplePostFormRequest(func(v *url.Values) { v.Add("emails[]", "snow@labstack.com") }), whenName: "emails[]", whenLimit: 2, expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, }, { name: "nok, POST form, value missing", givenRequest: examplePostFormRequest(nil), whenName: "nope", expectError: errFormExtractorValueMissing.Error(), }, { name: "ok, cut values over extractorLimit", givenRequest: examplePostFormRequest(func(v *url.Values) { for i := 1; i < 25; i++ { v.Add("id[]", fmt.Sprintf("%v", i)) } }), whenName: "id[]", whenLimit: extractorLimit, expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() req := 
tc.givenRequest rec := httptest.NewRecorder() c := e.NewContext(req, rec) extractor := valuesFromForm(tc.whenName, tc.whenLimit) values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) assert.Equal(t, ExtractorSourceForm, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } ================================================ FILE: middleware/key_auth.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "cmp" "errors" "fmt" "net/http" "github.com/labstack/echo/v5" ) // KeyAuthConfig defines the config for KeyAuth middleware. // // SECURITY: The Validator function is responsible for securely comparing API keys. // See KeyAuthValidator documentation for guidance on preventing timing attacks. type KeyAuthConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // KeyLookup is a string in the form of ":" or ":,:" that is used // to extract key from the request. // Optional. Default value "header:Authorization". // Possible values: // - "header:" or "header::" // `` is argument value to cut/trim prefix of the extracted value. This is useful if header // value has static prefix like `Authorization: ` where part that we // want to cut is ` ` note the space at the end. // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. // - "query:" // - "form:" // - "cookie:" // Multiple sources example: // - "header:Authorization,header:X-Api-Key" KeyLookup string // AllowedCheckLimit set how many KeyLookup values are allowed to be checked. This is // useful environments like corporate test environments with application proxies restricting // access to environment with their own auth scheme. AllowedCheckLimit uint // Validator is a function to validate key. // Required. 
Validator KeyAuthValidator // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key. // It may be used to define a custom error. // // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler. // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users // In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain. ErrorHandler KeyAuthErrorHandler // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to // ignore the error (by returning `nil`). // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. // In that case you can use ErrorHandler to set a default public key auth value in the request context // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. ContinueOnIgnoredError bool } // KeyAuthValidator defines a function to validate KeyAuth credentials. // // SECURITY WARNING: To prevent timing attacks that could allow attackers to enumerate // valid API keys, validator implementations MUST use constant-time comparison. // Use crypto/subtle.ConstantTimeCompare instead of standard string equality (==) // or switch statements. 
// // Example of SECURE implementation: // // import "crypto/subtle" // // validator := func(c *echo.Context, key string, source ExtractorSource) (bool, error) { // // Fetch valid keys from database/config // validKeys := []string{"key1", "key2", "key3"} // // for _, validKey := range validKeys { // // Use constant-time comparison to prevent timing attacks // if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 { // return true, nil // } // } // return false, nil // } // // Example of INSECURE implementation (DO NOT USE): // // // VULNERABLE TO TIMING ATTACKS - DO NOT USE // validator := func(c *echo.Context, key string, source ExtractorSource) (bool, error) { // switch key { // Timing leak! // case "valid-key": // return true, nil // default: // return false, nil // } // } type KeyAuthValidator func(c *echo.Context, key string, source ExtractorSource) (bool, error) // KeyAuthErrorHandler defines a function which is executed for an invalid key. type KeyAuthErrorHandler func(c *echo.Context, err error) error // ErrKeyMissing denotes an error raised when key value could not be extracted from request var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key") // ErrInvalidKey denotes an error raised when key value is invalid by validator var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key") // DefaultKeyAuthConfig is the default KeyAuth middleware config. var DefaultKeyAuthConfig = KeyAuthConfig{ Skipper: DefaultSkipper, KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ", } // KeyAuth returns an KeyAuth middleware. // // For valid key it calls the next handler. // For invalid key, it sends "401 - Unauthorized" response. // For missing key, it sends "400 - Bad Request" response. func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc { c := DefaultKeyAuthConfig c.Validator = fn return KeyAuthWithConfig(c) } // KeyAuthWithConfig returns an KeyAuth middleware or panics if configuration is invalid. 
// // For first valid key it calls the next handler. // For invalid key, it sends "401 - Unauthorized" response. // For missing key, it sends "400 - Bad Request" response. func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } // ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultKeyAuthConfig.Skipper } if config.KeyLookup == "" { config.KeyLookup = DefaultKeyAuthConfig.KeyLookup } if config.Validator == nil { return nil, errors.New("echo key-auth middleware requires a validator function") } limit := cmp.Or(config.AllowedCheckLimit, 1) extractors, cErr := createExtractors(config.KeyLookup, limit) if cErr != nil { return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", cErr) } if len(extractors) == 0 { return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string") } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } var lastExtractorErr error var lastValidatorErr error for _, extractor := range extractors { keys, source, extrErr := extractor(c) if extrErr != nil { lastExtractorErr = extrErr continue } for _, key := range keys { valid, err := config.Validator(c, key, source) if err != nil { lastValidatorErr = err continue } if !valid { lastValidatorErr = ErrInvalidKey continue } return next(c) } } // prioritize validator errors over extracting errors err := lastValidatorErr if err == nil { err = lastExtractorErr } if config.ErrorHandler != nil { tmpErr := config.ErrorHandler(c, err) if config.ContinueOnIgnoredError && tmpErr == nil { return next(c) } return tmpErr } if lastValidatorErr == nil { return ErrKeyMissing.Wrap(err) } return echo.ErrUnauthorized.Wrap(err) } }, nil } 
================================================
FILE: middleware/key_auth_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors

package middleware

import (
	"crypto/subtle"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/labstack/echo/v5"
	"github.com/stretchr/testify/assert"
)

// testKeyValidator is the KeyAuthValidator shared by the tests in this file.
// It accepts only "valid-key", returns an error for "error-key" (to exercise the
// validator-error path) and rejects every other key without an error.
func testKeyValidator(c *echo.Context, key string, source ExtractorSource) (bool, error) {
	// Use constant-time comparison to prevent timing attacks
	if subtle.ConstantTimeCompare([]byte(key), []byte("valid-key")) == 1 {
		return true, nil
	}
	// Special case for testing error handling
	if key == "error-key" { // Error path doesn't need constant-time
		return false, errors.New("some user defined error")
	}
	return false, nil
}

// TestKeyAuth checks that KeyAuth with the default config lets a request carrying
// "Authorization: Bearer valid-key" reach the handler.
func TestKeyAuth(t *testing.T) {
	handlerCalled := false
	handler := func(c *echo.Context) error {
		handlerCalled = true
		return c.String(http.StatusOK, "test")
	}
	middlewareChain := KeyAuth(testKeyValidator)(handler)

	e := echo.New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	err := middlewareChain(c)

	assert.NoError(t, err)
	assert.True(t, handlerCalled)
}

// TestKeyAuthWithConfig is a table test covering the supported key sources
// (header, query, form, cookie) and custom Skipper/ErrorHandler/Validator settings.
// expectError pins the exact error text returned by the middleware chain.
func TestKeyAuthWithConfig(t *testing.T) {
	var testCases = []struct {
		name                string
		givenRequestFunc    func() *http.Request // builds the request; default is GET "/"
		givenRequest        func(req *http.Request) // mutates the request before the call
		whenConfig          func(conf *KeyAuthConfig)
		expectHandlerCalled bool
		expectError         string
	}{
		{
			name: "ok, defaults, key from header",
			givenRequest: func(req *http.Request) {
				req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
			},
			expectHandlerCalled: true,
		},
		{
			name: "ok, custom skipper",
			givenRequest: func(req *http.Request) {
				req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.Skipper = func(context *echo.Context) bool {
					return true
				}
			},
			expectHandlerCalled: true,
		},
		{
			name: "nok, defaults, invalid key from header, Authorization: Bearer",
			givenRequest: func(req *http.Request) {
				req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
			},
			expectHandlerCalled: false,
			expectError:         "code=401, message=Unauthorized, err=code=401, message=invalid key",
		},
		{
			name: "nok, defaults, invalid scheme in header",
			givenRequest: func(req *http.Request) {
				req.Header.Set(echo.HeaderAuthorization, "Bear valid-key")
			},
			expectHandlerCalled: false,
			expectError:         "code=401, message=missing key, err=invalid value in request header",
		},
		{
			name:                "nok, defaults, missing header",
			givenRequest:        func(req *http.Request) {},
			expectHandlerCalled: false,
			expectError:         "code=401, message=missing key, err=missing value in request header",
		},
		{
			name: "ok, custom key lookup, header",
			givenRequest: func(req *http.Request) {
				req.Header.Set("API-Key", "valid-key")
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "header:API-Key"
			},
			expectHandlerCalled: true,
		},
		{
			name: "nok, custom key lookup, missing header",
			givenRequest: func(req *http.Request) {
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "header:API-Key"
			},
			expectHandlerCalled: false,
			expectError:         "code=401, message=missing key, err=missing value in request header",
		},
		{
			name: "ok, custom key lookup, query",
			givenRequest: func(req *http.Request) {
				q := req.URL.Query()
				q.Add("key", "valid-key")
				req.URL.RawQuery = q.Encode()
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "query:key"
			},
			expectHandlerCalled: true,
		},
		{
			name: "nok, custom key lookup, missing query param",
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "query:key"
			},
			expectHandlerCalled: false,
			expectError:         "code=401, message=missing key, err=missing value in the query string",
		},
		{
			name: "ok, custom key lookup, form",
			givenRequestFunc: func() *http.Request {
				req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("key=valid-key"))
				req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
				return req
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "form:key"
			},
			expectHandlerCalled: true,
		},
		{
			name: "nok, custom key lookup, missing key in form",
			givenRequestFunc: func() *http.Request {
				req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("xxx=valid-key"))
				req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
				return req
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "form:key"
			},
			expectHandlerCalled: false,
			expectError:         "code=401, message=missing key, err=missing value in the form",
		},
		{
			name: "ok, custom key lookup, cookie",
			givenRequest: func(req *http.Request) {
				req.AddCookie(&http.Cookie{
					Name:  "key",
					Value: "valid-key",
				})
				q := req.URL.Query()
				q.Add("key", "valid-key")
				req.URL.RawQuery = q.Encode()
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "cookie:key"
			},
			expectHandlerCalled: true,
		},
		{
			name: "nok, custom key lookup, missing cookie param",
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "cookie:key"
			},
			expectHandlerCalled: false,
			expectError:         "code=401, message=missing key, err=missing value in cookies",
		},
		{
			name: "nok, custom errorHandler, error from extractor",
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "header:token"
				conf.ErrorHandler = func(c *echo.Context, err error) error {
					return echo.NewHTTPError(http.StatusTeapot, "custom").Wrap(err)
				}
			},
			expectHandlerCalled: false,
			expectError:         "code=418, message=custom, err=missing value in request header",
		},
		{
			name: "nok, custom errorHandler, error from validator",
			givenRequest: func(req *http.Request) {
				req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.ErrorHandler = func(c *echo.Context, err error) error {
					return echo.NewHTTPError(http.StatusTeapot, "custom").Wrap(err)
				}
			},
			expectHandlerCalled: false,
			expectError:         "code=418, message=custom, err=some user defined error",
		},
		{
			name: "nok, defaults, error from validator",
			givenRequest: func(req *http.Request) {
				req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
			},
			whenConfig:          func(conf *KeyAuthConfig) {},
			expectHandlerCalled: false,
			expectError:         "code=401, message=Unauthorized, err=some user defined error",
		},
		{
			name: "ok, custom validator checks source",
			givenRequest: func(req *http.Request) {
				q := req.URL.Query()
				q.Add("key", "valid-key")
				req.URL.RawQuery = q.Encode()
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "query:key"
				conf.Validator = func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
					if source == ExtractorSourceQuery {
						return true, nil
					}
					return false, errors.New("invalid source")
				}
			},
			expectHandlerCalled: true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			handlerCalled := false
			handler := func(c *echo.Context) error {
				handlerCalled = true
				return c.String(http.StatusOK, "test")
			}
			// every case starts from testKeyValidator; whenConfig may override anything
			config := KeyAuthConfig{
				Validator: testKeyValidator,
			}
			if tc.whenConfig != nil {
				tc.whenConfig(&config)
			}
			middlewareChain := KeyAuthWithConfig(config)(handler)

			e := echo.New()
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.givenRequestFunc != nil {
				req = tc.givenRequestFunc()
			}
			if tc.givenRequest != nil {
				tc.givenRequest(req)
			}
			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)

			err := middlewareChain(c)

			assert.Equal(t, tc.expectHandlerCalled, handlerCalled)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

// TestKeyAuthWithConfig_errors checks the configuration errors reported by
// KeyAuthConfig.ToMiddleware (missing validator, unparsable or unusable KeyLookup).
func TestKeyAuthWithConfig_errors(t *testing.T) {
	var testCases = []struct {
		name        string
		whenConfig  KeyAuthConfig
		expectError string
	}{
		{
			name: "ok, no error",
			whenConfig: KeyAuthConfig{
				Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
					return false, nil
				},
			},
		},
		{
			name: "ok, missing validator func",
			whenConfig: KeyAuthConfig{
				Validator: nil,
			},
			expectError: "echo key-auth middleware requires a validator function",
		},
		{
			name: "ok, extractor source can not be split",
			whenConfig: KeyAuthConfig{
				KeyLookup: "nope",
				Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
					return false, nil
				},
			},
			expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope",
		},
		{
			name: "ok, no extractors",
			whenConfig: KeyAuthConfig{
				KeyLookup: "nope:nope",
				Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
					return false, nil
				},
			},
			expectError: "echo key-auth middleware could not create extractors from KeyLookup string",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			mw, err := tc.whenConfig.ToMiddleware()
			if tc.expectError != "" {
				assert.Nil(t, mw)
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NotNil(t, mw)
				assert.NoError(t, err)
			}
		})
	}
}

// TestMustKeyAuthWithConfig_panic checks that KeyAuthWithConfig panics on an
// invalid (validator-less) configuration.
func TestMustKeyAuthWithConfig_panic(t *testing.T) {
	assert.Panics(t, func() {
		KeyAuthWithConfig(KeyAuthConfig{})
	})
}

// TestKeyAuth_errorHandlerSwallowsError checks that an ErrorHandler returning nil,
// combined with ContinueOnIgnoredError, lets an unauthenticated request reach the
// handler with the value the ErrorHandler stored in the context.
func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) {
	handlerCalled := false
	var authValue string
	handler := func(c *echo.Context) error {
		handlerCalled = true
		authValue = c.Get("auth").(string)
		return c.String(http.StatusOK, "test")
	}
	middlewareChain := KeyAuthWithConfig(KeyAuthConfig{
		Validator: testKeyValidator,
		ErrorHandler: func(c *echo.Context, err error) error {
			// could check error to decide if we can swallow the error
			c.Set("auth", "public")
			return nil
		},
		ContinueOnIgnoredError: true,
	})(handler)

	e := echo.New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	// no auth header this time
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	err := middlewareChain(c)

	assert.NoError(t, err)
	assert.True(t, handlerCalled)
	assert.Equal(t, "public", authValue)
}

================================================
FILE: middleware/method_override.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware import ( "net/http" "github.com/labstack/echo/v5" ) // MethodOverrideConfig defines the config for MethodOverride middleware. type MethodOverrideConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // Getter is a function that gets overridden method from the request. // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). Getter MethodOverrideGetter } // MethodOverrideGetter is a function that gets overridden method from the request type MethodOverrideGetter func(c *echo.Context) string // DefaultMethodOverrideConfig is the default MethodOverride middleware config. var DefaultMethodOverrideConfig = MethodOverrideConfig{ Skipper: DefaultSkipper, Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), } // MethodOverride returns a MethodOverride middleware. // MethodOverride middleware checks for the overridden method from the request and // uses it instead of the original method. // // For security reasons, only `POST` method can be overridden. func MethodOverride() echo.MiddlewareFunc { return MethodOverrideWithConfig(DefaultMethodOverrideConfig) } // MethodOverrideWithConfig returns a Method Override middleware with config or panics on invalid configuration. 
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
	return toMiddlewareOrPanic(config)
}

// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration.
// An unset Skipper or Getter falls back to the value in DefaultMethodOverrideConfig.
func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	if config.Skipper == nil {
		config.Skipper = DefaultMethodOverrideConfig.Skipper
	}
	if config.Getter == nil {
		config.Getter = DefaultMethodOverrideConfig.Getter
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}

			req := c.Request()
			// Only POST requests are eligible for an override.
			if req.Method != http.MethodPost {
				return next(c)
			}
			if override := config.Getter(c); override != "" {
				req.Method = override
			}
			return next(c)
		}
	}, nil
}

// MethodFromHeader returns a `MethodOverrideGetter` that reads the overriding
// method from the given request header.
func MethodFromHeader(header string) MethodOverrideGetter {
	getter := func(c *echo.Context) string {
		h := c.Request().Header
		return h.Get(header)
	}
	return getter
}

// MethodFromForm returns a `MethodOverrideGetter` that reads the overriding
// method from the given form parameter.
func MethodFromForm(param string) MethodOverrideGetter {
	getter := func(c *echo.Context) string {
		value := c.FormValue(param)
		return value
	}
	return getter
}

// MethodFromQuery is a `MethodOverrideGetter` that gets overridden method from
// the query parameter.
func MethodFromQuery(param string) MethodOverrideGetter { return func(c *echo.Context) string { return c.QueryParam(param) } } ================================================ FILE: middleware/method_override_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "bytes" "net/http" "net/http/httptest" "testing" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestMethodOverride(t *testing.T) { e := echo.New() m := MethodOverride() h := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } // Override with http header req := httptest.NewRequest(http.MethodPost, "/", nil) rec := httptest.NewRecorder() req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) c := e.NewContext(req, rec) err := m(h)(c) assert.NoError(t, err) assert.Equal(t, http.MethodDelete, req.Method) } func TestMethodOverride_formParam(t *testing.T) { e := echo.New() h := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } // Override with form parameter m, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware() assert.NoError(t, err) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) rec := httptest.NewRecorder() req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) c := e.NewContext(req, rec) err = m(h)(c) assert.NoError(t, err) assert.Equal(t, http.MethodDelete, req.Method) } func TestMethodOverride_queryParam(t *testing.T) { e := echo.New() h := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } // Override with query parameter m, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware() assert.NoError(t, err) req := httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) err = m(h)(c) 
assert.NoError(t, err) assert.Equal(t, http.MethodDelete, req.Method) } func TestMethodOverride_ignoreGet(t *testing.T) { e := echo.New() m := MethodOverride() h := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } // Ignore `GET` req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := m(h)(c) assert.NoError(t, err) assert.Equal(t, http.MethodGet, req.Method) } ================================================ FILE: middleware/middleware.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "net/http" "regexp" "strconv" "strings" "github.com/labstack/echo/v5" ) // Skipper defines a function to skip middleware. Returning true skips processing the middleware. type Skipper func(c *echo.Context) bool // BeforeFunc defines a function which is executed just before the middleware. type BeforeFunc func(c *echo.Context) func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { groups := pattern.FindAllStringSubmatch(input, -1) if groups == nil { return nil } values := groups[0][1:] replace := make([]string, 2*len(values)) for i, v := range values { j := 2 * i replace[j] = "$" + strconv.Itoa(i+1) replace[j+1] = v } return strings.NewReplacer(replace...) 
} func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string { // Initialize rulesRegex := map[*regexp.Regexp]string{} for k, v := range rewrite { k = regexp.QuoteMeta(k) k = strings.ReplaceAll(k, `\*`, "(.*?)") if strings.HasPrefix(k, `\^`) { k = strings.ReplaceAll(k, `\^`, "^") } k = k + "$" rulesRegex[regexp.MustCompile(k)] = v } return rulesRegex } func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error { if len(rewriteRegex) == 0 { return nil } // Depending how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path. // We only want to use path part for rewriting and therefore trim prefix if it exists rawURI := req.RequestURI if rawURI != "" && rawURI[0] != '/' { prefix := "" if req.URL.Scheme != "" { prefix = req.URL.Scheme + "://" } if req.URL.Host != "" { prefix += req.URL.Host // host or host:port } if prefix != "" { rawURI = strings.TrimPrefix(rawURI, prefix) } } for k, v := range rewriteRegex { if replacer := captureTokens(k, rawURI); replacer != nil { url, err := req.URL.Parse(replacer.Replace(v)) if err != nil { return err } req.URL = url return nil // rewrite only once } } return nil } // DefaultSkipper returns false which processes the middleware. 
func DefaultSkipper(c *echo.Context) bool { return false } func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc { mw, err := config.ToMiddleware() if err != nil { panic(err) } return mw } ================================================ FILE: middleware/middleware_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "bufio" "errors" "github.com/stretchr/testify/assert" "net" "net/http" "net/http/httptest" "regexp" "testing" ) func TestRewriteURL(t *testing.T) { var testCases = []struct { whenURL string expectPath string expectRawPath string expectQuery string expectErr string }{ { whenURL: "http://localhost:8080/old", expectPath: "/new", expectRawPath: "", }, { // encoded `ol%64` (decoded `old`) should not be rewritten to `/new` whenURL: "/ol%64", // `%64` is decoded `d` expectPath: "/old", expectRawPath: "/ol%64", }, { whenURL: "http://localhost:8080/users/+_+/orders/___++++?test=1", expectPath: "/user/+_+/order/___++++", expectRawPath: "", expectQuery: "test=1", }, { whenURL: "http://localhost:8080/users/%20a/orders/%20aa", expectPath: "/user/ a/order/ aa", expectRawPath: "", }, { whenURL: "http://localhost:8080/%47%6f%2f?test=1", expectPath: "/Go/", expectRawPath: "/%47%6f%2f", expectQuery: "test=1", }, { whenURL: "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F", expectPath: "/user/jill/order/T/cO4lW/t/Vp/", expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", }, { // do nothing, replace nothing whenURL: "http://localhost:8080/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", expectPath: "/user/jill/order/T/cO4lW/t/Vp/", expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", }, { whenURL: "http://localhost:8080/static", expectPath: "/static/path", expectRawPath: "", expectQuery: "role=AUTHOR&limit=1000", }, { whenURL: "/static", expectPath: "/static/path", expectRawPath: "", expectQuery: "role=AUTHOR&limit=1000", 
}, } rules := map[*regexp.Regexp]string{ regexp.MustCompile("^/old$"): "/new", regexp.MustCompile("^/users/(.*?)/orders/(.*?)$"): "/user/$1/order/$2", regexp.MustCompile("^/static$"): "/static/path?role=AUTHOR&limit=1000", } for _, tc := range testCases { t.Run(tc.whenURL, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) err := rewriteURL(rules, req) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, tc.expectPath, req.URL.Path) // Path field is stored in decoded form: /%47%6f%2f becomes /Go/. assert.Equal(t, tc.expectRawPath, req.URL.RawPath) // RawPath, an optional field which only gets set if the default encoding is different from Path. assert.Equal(t, tc.expectQuery, req.URL.RawQuery) }) } } type testResponseWriterNoFlushHijack struct { } func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) { } func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) { return 0, nil } func (w *testResponseWriterNoFlushHijack) Header() http.Header { return nil } type testResponseWriterUnwrapper struct { unwrapCalled int rw http.ResponseWriter } func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) { } func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) { return 0, nil } func (w *testResponseWriterUnwrapper) Header() http.Header { return nil } func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter { w.unwrapCalled++ return w.rw } type testResponseWriterUnwrapperHijack struct { testResponseWriterUnwrapper } func (w *testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) { return nil, nil, errors.New("can hijack") } ================================================ FILE: middleware/proxy.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "context" "crypto/tls" "errors" 
"fmt" "io" "math/rand" "net" "net/http" "net/http/httputil" "net/url" "regexp" "strings" "sync" "time" "github.com/labstack/echo/v5" ) // TODO: Handle TLS proxy // ProxyConfig defines the config for Proxy middleware. type ProxyConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // Balancer defines a load balancing technique. // Required. Balancer ProxyBalancer // RetryCount defines the number of times a failed proxied request should be retried // using the next available ProxyTarget. Defaults to 0, meaning requests are never retried. RetryCount int // RetryFilter defines a function used to determine if a failed request to a // ProxyTarget should be retried. The RetryFilter will only be called when the number // of previous retries is less than RetryCount. If the function returns true, the // request will be retried. The provided error indicates the reason for the request // failure. When the ProxyTarget is unavailable, the error will be an instance of // echo.HTTPError with a code of http.StatusBadGateway. In all other cases, the error // will indicate an internal error in the Proxy middleware. When a RetryFilter is not // specified, all requests that fail with http.StatusBadGateway will be retried. A custom // RetryFilter can be provided to only retry specific requests. Note that RetryFilter is // only called when the request to the target fails, or an internal error in the Proxy // middleware has occurred. Successful requests that return a non-200 response code cannot // be retried. RetryFilter func(c *echo.Context, e error) bool // ErrorHandler defines a function which can be used to return custom errors from // the Proxy middleware. ErrorHandler is only invoked when there has been // either an internal error in the Proxy middleware or the ProxyTarget is // unavailable. Due to the way requests are proxied, ErrorHandler is not invoked // when a ProxyTarget returns a non-200 response. 
In these cases, the response // is already written so errors cannot be modified. ErrorHandler is only // invoked after all retry attempts have been exhausted. ErrorHandler func(c *echo.Context, err error) error // Rewrite defines URL path rewrite rules. The values captured in asterisk can be // retrieved by index e.g. $1, $2 and so on. // Examples: // "/old": "/new", // "/api/*": "/$1", // "/js/*": "/public/javascripts/$1", // "/users/*/orders/*": "/user/$1/order/$2", Rewrite map[string]string // RegexRewrite defines rewrite rules using regexp.Rexexp with captures // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. // Example: // "^/old/[0.9]+/": "/new", // "^/api/.+?/(.*)": "/v2/$1", RegexRewrite map[*regexp.Regexp]string // Context key to store selected ProxyTarget into context. // Optional. Default value "target". ContextKey string // To customize the transport to remote. // Examples: If custom TLS certificates are required. Transport http.RoundTripper // ModifyResponse defines function to modify response from ProxyTarget. ModifyResponse func(*http.Response) error } // ProxyTarget defines the upstream target. type ProxyTarget struct { Name string URL *url.URL Meta map[string]any } // ProxyBalancer defines an interface to implement a load balancing technique. type ProxyBalancer interface { AddTarget(target *ProxyTarget) bool RemoveTarget(targetName string) bool Next(c *echo.Context) (*ProxyTarget, error) } type commonBalancer struct { targets []*ProxyTarget mutex sync.Mutex } // RandomBalancer implements a random load balancing technique. type randomBalancer struct { commonBalancer random *rand.Rand } // RoundRobinBalancer implements a round-robin load balancing technique. type roundRobinBalancer struct { commonBalancer // tracking the index on `targets` slice for the next `*ProxyTarget` to be used i int } // DefaultProxyConfig is the default Proxy middleware config. 
var DefaultProxyConfig = ProxyConfig{ Skipper: DefaultSkipper, ContextKey: "target", } func proxyRaw(c *echo.Context, t *ProxyTarget, config ProxyConfig) http.Handler { var dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) if transport, ok := config.Transport.(*http.Transport); ok { if transport.TLSClientConfig != nil { d := tls.Dialer{ Config: transport.TLSClientConfig, } dialFunc = d.DialContext } } if dialFunc == nil { var d net.Dialer dialFunc = d.DialContext } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { in, _, err := http.NewResponseController(w).Hijack() if err != nil { c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL)) return } defer in.Close() out, err := dialFunc(c.Request().Context(), "tcp", t.URL.Host) if err != nil { c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL))) return } defer out.Close() // Write header err = r.Write(out) if err != nil { c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", err, t.URL))) return } errCh := make(chan error, 2) cp := func(dst io.Writer, src io.Reader) { _, copyErr := io.Copy(dst, src) errCh <- copyErr } go cp(out, in) go cp(in, out) // Wait for BOTH goroutines to complete err1 := <-errCh err2 := <-errCh if err1 != nil && err1 != io.EOF { c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err1, t.URL)) } else if err2 != nil && err2 != io.EOF { c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err2, t.URL)) } }) } // NewRandomBalancer returns a random proxy balancer. func NewRandomBalancer(targets []*ProxyTarget) ProxyBalancer { b := randomBalancer{} b.targets = targets // G404 (CWE-338): Use of weak random number generator (math/rand or math/rand/v2 instead of crypto/rand) // this random is used to select next target. 
I can not think of reason this must be cryptographically safe. If you can - please open PR. b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) // #nosec G404 return &b } // NewRoundRobinBalancer returns a round-robin proxy balancer. func NewRoundRobinBalancer(targets []*ProxyTarget) ProxyBalancer { b := roundRobinBalancer{} b.targets = targets return &b } // AddTarget adds an upstream target to the list and returns `true`. // // However, if a target with the same name already exists then the operation is aborted returning `false`. func (b *commonBalancer) AddTarget(target *ProxyTarget) bool { b.mutex.Lock() defer b.mutex.Unlock() for _, t := range b.targets { if t.Name == target.Name { return false } } b.targets = append(b.targets, target) return true } // RemoveTarget removes an upstream target from the list by name. // // Returns `true` on success, `false` if no target with the name is found. func (b *commonBalancer) RemoveTarget(name string) bool { b.mutex.Lock() defer b.mutex.Unlock() for i, t := range b.targets { if t.Name == name { b.targets = append(b.targets[:i], b.targets[i+1:]...) return true } } return false } // Next randomly returns an upstream target. // // Note: `nil` is returned in case upstream target list is empty. func (b *randomBalancer) Next(c *echo.Context) (*ProxyTarget, error) { b.mutex.Lock() defer b.mutex.Unlock() if len(b.targets) == 0 { return nil, nil } else if len(b.targets) == 1 { return b.targets[0], nil } return b.targets[b.random.Intn(len(b.targets))], nil } // Next returns an upstream target using round-robin technique. In the case // where a previously failed request is being retried, the round-robin // balancer will attempt to use the next target relative to the original // request. If the list of targets held by the balancer is modified while a // failed request is being retried, it is possible that the balancer will // return the original failed target. 
// // Note: `nil` is returned in case upstream target list is empty. func (b *roundRobinBalancer) Next(c *echo.Context) (*ProxyTarget, error) { b.mutex.Lock() defer b.mutex.Unlock() if len(b.targets) == 0 { return nil, nil } else if len(b.targets) == 1 { return b.targets[0], nil } var i int const lastIdxKey = "_round_robin_last_index" // This request is a retry, start from the index of the previous // target to ensure we don't attempt to retry the request with // the same failed target if c.Get(lastIdxKey) != nil { i = c.Get(lastIdxKey).(int) i++ if i >= len(b.targets) { i = 0 } } else { // This is a first time request, use the global index if b.i >= len(b.targets) { b.i = 0 } i = b.i b.i++ } c.Set(lastIdxKey, i) return b.targets[i], nil } // Proxy returns a Proxy middleware. // // Proxy middleware forwards the request to upstream server using a configured load balancing technique. func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc { c := DefaultProxyConfig c.Balancer = balancer return ProxyWithConfig(c) } // ProxyWithConfig returns a Proxy middleware or panics if configuration is invalid. // // Proxy middleware forwards the request to upstream server using a configured load balancing technique. 
func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } // ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultProxyConfig.Skipper } if config.ContextKey == "" { config.ContextKey = DefaultProxyConfig.ContextKey } if config.Balancer == nil { return nil, errors.New("echo proxy middleware requires balancer") } if config.RetryFilter == nil { config.RetryFilter = func(c *echo.Context, e error) bool { if httpErr, ok := e.(*echo.HTTPError); ok { return httpErr.Code == http.StatusBadGateway } return false } } if config.ErrorHandler == nil { config.ErrorHandler = func(c *echo.Context, err error) error { return err } } if config.Rewrite != nil { if config.RegexRewrite == nil { config.RegexRewrite = make(map[*regexp.Regexp]string) } for k, v := range rewriteRulesRegex(config.Rewrite) { config.RegexRewrite[k] = v } } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) (err error) { if config.Skipper(c) { return next(c) } req := c.Request() res := c.Response() if err := rewriteURL(config.RegexRewrite, req); err != nil { return config.ErrorHandler(c, err) } // Fix header // Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream. // However, for backward compatibility, legacy behavior is preserved unless you configure Echo#IPExtractor. if req.Header.Get(echo.HeaderXRealIP) == "" || c.Echo().IPExtractor != nil { req.Header.Set(echo.HeaderXRealIP, c.RealIP()) } if req.Header.Get(echo.HeaderXForwardedProto) == "" { req.Header.Set(echo.HeaderXForwardedProto, c.Scheme()) } if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy. 
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP()) } retries := config.RetryCount for { tgt, err := config.Balancer.Next(c) if err != nil { return config.ErrorHandler(c, err) } c.Set(config.ContextKey, tgt) //If retrying a failed request, clear any previous errors from //context here so that balancers have the option to check for //errors that occurred using previous target if retries < config.RetryCount { c.Set("_error", nil) } // This is needed for ProxyConfig.ModifyResponse and/or ProxyConfig.Transport to be able to process the Request // that Balancer may have replaced with c.SetRequest. req = c.Request() // Proxy switch { case c.IsWebSocket(): proxyRaw(c, tgt, config).ServeHTTP(res, req) default: // even SSE requests proxyHTTP(c, tgt, config).ServeHTTP(res, req) } err, hasError := c.Get("_error").(error) if !hasError { return nil } retry := retries > 0 && config.RetryFilter(c, err) if !retry { return config.ErrorHandler(c, err) } retries-- } } }, nil } // StatusCodeContextCanceled is a custom HTTP status code for situations // where a client unexpectedly closed the connection to the server. // As there is no standard error code for "client closed connection", but // various well-known HTTP clients and server implement this HTTP code we use // 499 too instead of the more problematic 5xx, which does not allow to detect this situation const StatusCodeContextCanceled = 499 func proxyHTTP(c *echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler { proxy := httputil.NewSingleHostReverseProxy(tgt.URL) proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { desc := tgt.URL.String() if tgt.Name != "" { desc = fmt.Sprintf("%s(%s)", tgt.Name, tgt.URL.String()) } // If the client canceled the request (usually by closing the connection), we can report a // client error (4xx) instead of a server error (5xx) to correctly identify the situation. // The Go standard library (at of late 2020) wraps the exported, standard // context. 
Canceled error with unexported garbage value requiring a substring check, see // https://github.com/golang/go/blob/6965b01ea248cabb70c3749fd218b36089a21efb/src/net/net.go#L416-L430 // From Caddy https://github.com/caddyserver/caddy/blob/afa778ae05503f563af0d1015cdf7e5e78b1eeec/modules/caddyhttp/reverseproxy/reverseproxy.go#L1352 if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "operation was canceled") { httpError := echo.NewHTTPError(StatusCodeContextCanceled, "client closed connection").Wrap(err) c.Set("_error", httpError) } else { httpError := echo.NewHTTPError( http.StatusBadGateway, "remote server unreachable, could not proxy request", ).Wrap(fmt.Errorf("server: %s, err: %w", desc, err)) c.Set("_error", httpError) } } proxy.Transport = config.Transport proxy.ModifyResponse = config.ModifyResponse return proxy } ================================================ FILE: middleware/proxy_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "bytes" "context" "crypto/tls" "errors" "fmt" "io" "net" "net/http" "net/http/httptest" "net/url" "regexp" "sync" "testing" "time" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" "golang.org/x/net/websocket" ) // Assert expected with url.EscapedPath method to obtain the path. 
func TestProxy(t *testing.T) {
	// Covers, in sequence:
	//   - random balancing;
	//   - AddTarget duplicate rejection and RemoveTarget bookkeeping;
	//   - round-robin ordering;
	//   - ModifyResponse;
	//   - storing the selected target under the default "target" context key.
	// Setup
	t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, "target 1")
	}))
	defer t1.Close()
	url1, _ := url.Parse(t1.URL)
	t2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, "target 2")
	}))
	defer t2.Close()
	url2, _ := url.Parse(t2.URL)

	targets := []*ProxyTarget{
		{
			Name: "target 1",
			URL:  url1,
		},
		{
			Name: "target 2",
			URL:  url2,
		},
	}
	rb := NewRandomBalancer(nil)
	// must add targets:
	for _, target := range targets {
		assert.True(t, rb.AddTarget(target))
	}

	// must ignore duplicates:
	for _, target := range targets {
		assert.False(t, rb.AddTarget(target))
	}

	// Random
	e := echo.New()
	e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, req)
	body := rec.Body.String()
	expected := map[string]bool{
		"target 1": true,
		"target 2": true,
	}
	assert.Condition(t, func() bool {
		return expected[body]
	})

	for _, target := range targets {
		assert.True(t, rb.RemoveTarget(target.Name))
	}
	assert.False(t, rb.RemoveTarget("unknown target"))

	// Round-robin
	// The same req is reused; the round-robin index lives in the balancer, so two
	// consecutive requests are expected to hit target 1 then target 2.
	rrb := NewRoundRobinBalancer(targets)
	e = echo.New()
	e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))

	rec = httptest.NewRecorder()
	e.ServeHTTP(rec, req)
	body = rec.Body.String()
	assert.Equal(t, "target 1", body)

	rec = httptest.NewRecorder()
	e.ServeHTTP(rec, req)
	body = rec.Body.String()
	assert.Equal(t, "target 2", body)

	// ModifyResponse
	e = echo.New()
	e.Use(ProxyWithConfig(ProxyConfig{
		Balancer: rrb,
		ModifyResponse: func(res *http.Response) error {
			res.Body = io.NopCloser(bytes.NewBuffer([]byte("modified")))
			res.Header.Set("X-Modified", "1")
			return nil
		},
	}))

	rec = httptest.NewRecorder()
	e.ServeHTTP(rec, req)
	assert.Equal(t, "modified", rec.Body.String())
	assert.Equal(t, "1", rec.Header().Get("X-Modified"))

	// ProxyTarget is set in context
	// contextObserver runs outside the proxy middleware and inspects the context after the proxy has run.
	contextObserver := func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) (err error) {
			next(c)
			assert.Contains(t, targets, c.Get("target"), "target is not set in context")
			return nil
		}
	}
	e = echo.New()
	e.Use(contextObserver)
	e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)}))
	rec = httptest.NewRecorder()
	e.ServeHTTP(rec, req)
}

// TestMustProxyWithConfig_emptyBalancerPanics checks that ProxyWithConfig panics when no Balancer is configured.
func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) {
	assert.Panics(t, func() {
		ProxyWithConfig(ProxyConfig{Balancer: nil})
	})
}

// TestProxyRealIPHeader checks the X-Real-IP header value sent on the request. The value
// depends on whether the header was already present and whether Echo#IPExtractor is set.
func TestProxyRealIPHeader(t *testing.T) {
	// Setup
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
	defer upstream.Close()
	// Shadows the url package for the rest of this function; url.Parse on the right-hand side still refers to the package.
	url, _ := url.Parse(upstream.URL)
	rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}})
	e := echo.New()
	e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()

	remoteAddrIP, _, _ := net.SplitHostPort(req.RemoteAddr)
	realIPHeaderIP := "203.0.113.1"
	extractedRealIP := "203.0.113.10"

	tests := []*struct {
		hasRealIPheader bool
		hasIPExtractor  bool
		expectedXRealIP string
	}{
		{false, false, remoteAddrIP},
		{false, true, extractedRealIP},
		{true, false, realIPHeaderIP},
		{true, true, extractedRealIP},
	}

	// req and rec are reused across iterations. The middleware sets the header on req in
	// place, which is what the assertion reads back.
	for _, tt := range tests {
		if tt.hasRealIPheader {
			req.Header.Set(echo.HeaderXRealIP, realIPHeaderIP)
		} else {
			req.Header.Del(echo.HeaderXRealIP)
		}
		if tt.hasIPExtractor {
			e.IPExtractor = func(*http.Request) string {
				return extractedRealIP
			}
		} else {
			e.IPExtractor = nil
		}
		e.ServeHTTP(rec, req)
		assert.Equal(t, tt.expectedXRealIP, req.Header.Get(echo.HeaderXRealIP), "hasRealIPheader: %t / hasIPExtractor: %t", tt.hasRealIPheader, tt.hasIPExtractor)
	}
}

// TestProxyRewrite checks that Rewrite rules change the request-target sent upstream. It also
// checks that escaped characters are neither decoded nor double-escaped, and that query
// strings are passed through.
func TestProxyRewrite(t *testing.T) {
	var testCases = []struct {
		whenPath         string
		expectProxiedURI string
		expectStatus     int
	}{
		{
			whenPath:         "/api/users",
			expectProxiedURI: "/users",
			expectStatus:     http.StatusOK,
		},
		{
			whenPath:         "/js/main.js",
			expectProxiedURI: "/public/javascripts/main.js",
			expectStatus:     http.StatusOK,
		},
		{
			whenPath:         "/old",
			expectProxiedURI: "/new",
			expectStatus:     http.StatusOK,
		},
		{
			whenPath:         "/users/jack/orders/1",
			expectProxiedURI: "/user/jack/order/1",
			expectStatus:     http.StatusOK,
		},
		{
			whenPath:         "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
			expectProxiedURI: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
			expectStatus:     http.StatusOK,
		},
		{ // ` ` (space) is encoded by httpClient to `%20` when doing request to Echo. `%20` should not be double escaped when proxying request
			whenPath:         "/api/new users",
			expectProxiedURI: "/new%20users",
			expectStatus:     http.StatusOK,
		},
		{ // query params should be proxied and not be modified
			whenPath:         "/api/users?limit=10",
			expectProxiedURI: "/users?limit=10",
			expectStatus:     http.StatusOK,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.whenPath, func(t *testing.T) {
			receivedRequestURI := make(chan string, 1)
			upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server
				// we need unmodified target to see if we are encoding/decoding the url in addition to rewrite/replace logic
				// if original request had `%2F` we should not magically decode it to `/` as it would change what was requested
				receivedRequestURI <- r.RequestURI
			}))
			defer upstream.Close()
			serverURL, _ := url.Parse(upstream.URL)
			rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: serverURL}})

			// Rewrite
			e := echo.New()
			e.Use(ProxyWithConfig(ProxyConfig{
				Balancer: rrb,
				Rewrite: map[string]string{
					"/old":              "/new",
					"/api/*":            "/$1",
					"/js/*":             "/public/javascripts/$1",
					"/users/*/orders/*": "/user/$1/order/$2",
				},
			}))

			targetURL, _ := serverURL.Parse(tc.whenPath)
			req := httptest.NewRequest(http.MethodGet, targetURL.String(), nil)
			rec := httptest.NewRecorder()

			e.ServeHTTP(rec, req)

			assert.Equal(t, tc.expectStatus, rec.Code)
			actualRequestURI := <-receivedRequestURI
			assert.Equal(t, tc.expectProxiedURI, actualRequestURI)
		})
	}
}

// TestProxyRewriteRegex checks Rewrite rules anchored with `^` together with RegexRewrite
// rules. One upstream server and one buffered channel (capacity 1) are shared by all
// subtests; each subtest sends exactly one request and reads one URI back.
func TestProxyRewriteRegex(t *testing.T) {
	// Setup
	receivedRequestURI := make(chan string, 1)
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server
		// we need unmodified target to see if we are encoding/decoding the url in addition to rewrite/replace logic
		// if original request had `%2F` we should not magically decode it to `/` as it would change what was requested
		receivedRequestURI <- r.RequestURI
	}))
	defer upstream.Close()
	tmpUrL, _ := url.Parse(upstream.URL)
	rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: tmpUrL}})

	// Rewrite
	e := echo.New()
	e.Use(ProxyWithConfig(ProxyConfig{
		Balancer: rrb,
		Rewrite: map[string]string{
			"^/a/*":     "/v1/$1",
			"^/b/*/c/*": "/v2/$2/$1",
			"^/c/*/*":   "/v3/$2",
		},
		RegexRewrite: map[*regexp.Regexp]string{
			regexp.MustCompile("^/x/.+?/(.*)"):   "/v4/$1",
			regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1",
		},
	}))

	testCases := []struct {
		requestPath string
		statusCode  int
		expectPath  string
	}{
		{"/unmatched", http.StatusOK, "/unmatched"},
		{"/a/test", http.StatusOK, "/v1/test"},
		{"/b/foo/c/bar/baz", http.StatusOK, "/v2/bar/baz/foo"},
		{"/c/ignore/test", http.StatusOK, "/v3/test"},
		{"/c/ignore1/test/this", http.StatusOK, "/v3/test/this"},
		{"/x/ignore/test", http.StatusOK, "/v4/test"},
		{"/y/foo/bar", http.StatusOK, "/v5/bar/foo"},
		// NB: fragment is not added by golang httputil.NewSingleHostReverseProxy implementation
		// $2 = `bar?q=1#frag`, $1 = `foo`. replaced uri = `/v5/bar?q=1#frag/foo` but httputil.NewSingleHostReverseProxy does not send `#frag/foo` (currently)
		{"/y/foo/bar?q=1#frag", http.StatusOK, "/v5/bar?q=1"},
	}

	for _, tc := range testCases {
		t.Run(tc.requestPath, func(t *testing.T) {
			targetURL, _ := url.Parse(tc.requestPath)
			req := httptest.NewRequest(http.MethodGet, targetURL.String(), nil)
			rec := httptest.NewRecorder()

			e.ServeHTTP(rec, req)

			actualRequestURI := <-receivedRequestURI
			assert.Equal(t, tc.expectPath, actualRequestURI)
			assert.Equal(t, tc.statusCode, rec.Code)
		})
	}
}

// TestProxyError checks that an unreachable upstream yields 502 Bad Gateway and leaves
// req.URL.Path unchanged. It assumes nothing is listening on 127.0.0.1:27121/27122.
func TestProxyError(t *testing.T) {
	// Setup
	url1, _ := url.Parse("http://127.0.0.1:27121")
	url2, _ := url.Parse("http://127.0.0.1:27122")

	targets := []*ProxyTarget{
		{
			Name: "target 1",
			URL:  url1,
		},
		{
			Name: "target 2",
			URL:  url2,
		},
	}
	rb := NewRandomBalancer(nil)
	// must add targets:
	for _, target := range targets {
		assert.True(t, rb.AddTarget(target))
	}

	// must ignore duplicates:
	for _, target := range targets {
		assert.False(t, rb.AddTarget(target))
	}

	// Random
	e := echo.New()
	e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
	req := httptest.NewRequest(http.MethodGet, "/", nil)

	// Remote unreachable
	rec := httptest.NewRecorder()
	req.URL.Path = "/api/users"
	e.ServeHTTP(rec, req)
	assert.Equal(t, "/api/users", req.URL.Path)
	assert.Equal(t, http.StatusBadGateway, rec.Code)
}

// TestClientCancelConnectionResultsHTTPCode499 checks that canceling the request context
// while the upstream is still blocked produces status 499 (StatusCodeContextCanceled).
// The cancel fires after a 10ms sleep, so the test depends on that timing.
func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) {
	var timeoutStop sync.WaitGroup
	timeoutStop.Add(1)
	HTTPTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		timeoutStop.Wait() // wait until we have canceled the request
		w.WriteHeader(http.StatusOK)
	}))
	defer HTTPTarget.Close()
	targetURL, _ := url.Parse(HTTPTarget.URL)
	target := &ProxyTarget{
		Name: "target",
		URL:  targetURL,
	}
	rb := NewRandomBalancer(nil)
	assert.True(t, rb.AddTarget(target))
	e := echo.New()
	e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	ctx, cancel := context.WithCancel(req.Context())
	req = req.WithContext(ctx)
	go func() {
		time.Sleep(10 * time.Millisecond)
		cancel()
	}()
	e.ServeHTTP(rec, req)
	// Unblock the upstream handler only after the proxy call has returned.
	timeoutStop.Done()
	assert.Equal(t, 499, rec.Code)
}

// testProvider is a balancer stub whose Next always returns the configured target and error.
// AddTarget/RemoveTarget come from the embedded commonBalancer.
type testProvider struct {
	commonBalancer
	target *ProxyTarget
	err    error
}

// Next returns the stub's fixed target and error.
func (p *testProvider) Next(c *echo.Context) (*ProxyTarget, error) {
	return p.target, p.err
}

// TestTargetProvider checks that a custom balancer's target is used by Proxy.
func TestTargetProvider(t *testing.T) {
	t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, "target 1")
	}))
	defer t1.Close()
	url1, _ := url.Parse(t1.URL)

	e := echo.New()
	tp := &testProvider{}
	tp.target = &ProxyTarget{Name: "target 1", URL: url1}
	e.Use(Proxy(tp))
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	e.ServeHTTP(rec, req)
	body := rec.Body.String()
	assert.Equal(t, "target 1", body)
}

// TestFailNextTarget checks that an error returned by the balancer's Next is rendered as the
// response body, even though a target is also returned.
func TestFailNextTarget(t *testing.T) {
	url1, err := url.Parse("http://dummy:8080")
	assert.Nil(t, err)

	e := echo.New()
	tp := &testProvider{}
	tp.target = &ProxyTarget{Name: "target 1", URL: url1}
	tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target")

	e.Use(Proxy(tp))
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	e.ServeHTTP(rec, req)
	body := rec.Body.String()
	assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body)
}

// TestRandomBalancerWithNoTargets checks that an empty random balancer returns (nil, nil) from Next.
func TestRandomBalancerWithNoTargets(t *testing.T) {
	e := echo.New()
	req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	// Assert balancer with empty targets does return `nil` on `Next()`
	rb := NewRandomBalancer(nil)
	target, err := rb.Next(c)
	assert.Nil(t, target)
	assert.NoError(t, err)
}

// TestRoundRobinBalancerWithNoTargets checks that an empty round-robin balancer returns (nil, nil) from Next.
func TestRoundRobinBalancerWithNoTargets(t *testing.T) {
	// Assert balancer with empty targets does return `nil` on `Next()`
	rrb := NewRoundRobinBalancer([]*ProxyTarget{})

	e := echo.New()
	req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	target, err := rrb.Next(c)
	assert.Nil(t, target)
	assert.NoError(t, err)
}

// TestProxyRetries checks how RetryCount and RetryFilter interact when proxying to a
// round-robin list of good and unreachable targets. Each case lists the exact sequence of
// RetryFilter results; the wrapper fails the test if the filter is called more or fewer
// times than listed.
func TestProxyRetries(t *testing.T) {
	newServer := func(res int) (*url.URL, *httptest.Server) {
		server := httptest.NewServer(
			http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(res)
			}),
		)
		targetURL, _ := url.Parse(server.URL)
		return targetURL, server
	}

	targetURL, server := newServer(http.StatusOK)
	defer server.Close()
	goodTarget := &ProxyTarget{
		Name: "Good",
		URL:  targetURL,
	}

	targetURL, server = newServer(http.StatusBadRequest)
	defer server.Close()
	goodTargetWith40X := &ProxyTarget{
		Name: "Good with 40X",
		URL:  targetURL,
	}

	// Assumes nothing is listening on this port, so requests to it fail with 502.
	targetURL, _ = url.Parse("http://127.0.0.1:27121")
	badTarget := &ProxyTarget{
		Name: "Bad",
		URL:  targetURL,
	}

	alwaysRetryFilter := func(c *echo.Context, e error) bool { return true }
	neverRetryFilter := func(c *echo.Context, e error) bool { return false }

	testCases := []struct {
		name             string
		retryCount       int
		retryFilters     []func(c *echo.Context, e error) bool
		targets          []*ProxyTarget
		expectedResponse int
	}{
		{
			name: "retry count 0 does not attempt retry on fail",
			targets: []*ProxyTarget{
				badTarget,
				goodTarget,
			},
			expectedResponse: http.StatusBadGateway,
		},
		{
			name:       "retry count 1 does not attempt retry on success",
			retryCount: 1,
			targets: []*ProxyTarget{
				goodTarget,
			},
			expectedResponse: http.StatusOK,
		},
		{
			name:       "retry count 1 does retry on handler return true",
			retryCount: 1,
			retryFilters: []func(c *echo.Context, e error) bool{
				alwaysRetryFilter,
			},
			targets: []*ProxyTarget{
				badTarget,
				goodTarget,
			},
			expectedResponse: http.StatusOK,
		},
		{
			name:       "retry count 1 does not retry on handler return false",
			retryCount: 1,
			retryFilters: []func(c *echo.Context, e error) bool{
				neverRetryFilter,
			},
			targets: []*ProxyTarget{
				badTarget,
				goodTarget,
			},
			expectedResponse: http.StatusBadGateway,
		},
		{
			name:       "retry count 2 returns error when no more retries left",
			retryCount: 2,
			retryFilters: []func(c *echo.Context, e error) bool{
				alwaysRetryFilter,
				alwaysRetryFilter,
			},
			targets: []*ProxyTarget{
				badTarget,
				badTarget,
				badTarget,
				goodTarget, //Should never be reached as only 2 retries
			},
			expectedResponse: http.StatusBadGateway,
		},
		{
			name:       "retry count 2 returns error when retries left but handler returns false",
			retryCount: 3,
			retryFilters: []func(c *echo.Context, e error) bool{
				alwaysRetryFilter,
				alwaysRetryFilter,
				neverRetryFilter,
			},
			targets: []*ProxyTarget{
				badTarget,
				badTarget,
				badTarget,
				goodTarget, //Should never be reached as retry handler returns false on 2nd check
			},
			expectedResponse: http.StatusBadGateway,
		},
		{
			name:       "retry count 3 succeeds",
			retryCount: 3,
			retryFilters: []func(c *echo.Context, e error) bool{
				alwaysRetryFilter,
				alwaysRetryFilter,
				alwaysRetryFilter,
			},
			targets: []*ProxyTarget{
				badTarget,
				badTarget,
				badTarget,
				goodTarget,
			},
			expectedResponse: http.StatusOK,
		},
		{
			name:       "40x responses are not retried",
			retryCount: 1,
			targets: []*ProxyTarget{
				goodTargetWith40X,
				goodTarget,
			},
			expectedResponse: http.StatusBadRequest,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			retryFilterCall := 0
			// retryFilter pops and applies the next expected filter, failing on unexpected extra calls.
			retryFilter := func(c *echo.Context, e error) bool {
				if len(tc.retryFilters) == 0 {
					assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall))
				}

				retryFilterCall++

				nextRetryFilter := tc.retryFilters[0]
				tc.retryFilters = tc.retryFilters[1:]

				return nextRetryFilter(c, e)
			}

			e := echo.New()
			e.Use(ProxyWithConfig(
				ProxyConfig{
					Balancer:    NewRoundRobinBalancer(tc.targets),
					RetryCount:  tc.retryCount,
					RetryFilter: retryFilter,
				},
			))

			req := httptest.NewRequest(http.MethodGet, "/", nil)
			rec := httptest.NewRecorder()

			e.ServeHTTP(rec, req)

			assert.Equal(t, tc.expectedResponse, rec.Code)
			if len(tc.retryFilters) > 0 {
				assert.FailNow(t, fmt.Sprintf("expected %d more retry handler calls", len(tc.retryFilters)))
			}
		})
	}
}

// TestProxyRetryWithBackendTimeout sends 20 concurrent requests through a round-robin pair
// of targets. One backend sleeps 1s while the transport's ResponseHeaderTimeout is 500ms;
// the other answers 200. With RetryCount 1, every request is expected to end with 200.
func TestProxyRetryWithBackendTimeout(t *testing.T) {
	transport := http.DefaultTransport.(*http.Transport).Clone()
	transport.ResponseHeaderTimeout = time.Millisecond * 500

	timeoutBackend := httptest.NewServer(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			time.Sleep(1 * time.Second)
			w.WriteHeader(404)
		}),
	)
	defer timeoutBackend.Close()

	timeoutTargetURL, _ := url.Parse(timeoutBackend.URL)
	goodBackend := httptest.NewServer(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(200)
		}),
	)
	defer goodBackend.Close()

	goodTargetURL, _ := url.Parse(goodBackend.URL)
	e := echo.New()
	e.Use(ProxyWithConfig(
		ProxyConfig{
			Transport: transport,
			Balancer: NewRoundRobinBalancer([]*ProxyTarget{
				{
					Name: "Timeout",
					URL:  timeoutTargetURL,
				},
				{
					Name: "Good",
					URL:  goodTargetURL,
				},
			}),
			RetryCount: 1,
		},
	))

	var wg sync.WaitGroup
	for i := 0; i < 20; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()

			req := httptest.NewRequest(http.MethodGet, "/", nil)
			rec := httptest.NewRecorder()

			e.ServeHTTP(rec, req)

			assert.Equal(t, 200, rec.Code)
		}()
	}

	wg.Wait()
}

// TestProxyErrorHandler checks that ProxyConfig.ErrorHandler is not called on success. On
// failure it receives the 502 *echo.HTTPError, and its return value is what reaches
// Echo#HTTPErrorHandler.
func TestProxyErrorHandler(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	goodURL, _ := url.Parse(server.URL)
	defer server.Close()
	goodTarget := &ProxyTarget{
		Name: "Good",
		URL:  goodURL,
	}

	// Assumes nothing is listening on this port, so requests to it fail with 502.
	badURL, _ := url.Parse("http://127.0.0.1:27121")
	badTarget := &ProxyTarget{
		Name: "Bad",
		URL:  badURL,
	}

	transformedError := errors.New("a new error")

	testCases := []struct {
		name             string
		target           *ProxyTarget
		errorHandler     func(c *echo.Context, e error) error
		expectFinalError func(t *testing.T, err error)
	}{
		{
			name:   "Error handler not invoked when request success",
			target: goodTarget,
			errorHandler: func(c *echo.Context, e error) error {
				assert.FailNow(t, "error handler should not be invoked")
				return e
			},
		},
		{
			name:   "Error handler invoked when request fails",
			target: badTarget,
			errorHandler: func(c *echo.Context, e error) error {
				httpErr, ok := e.(*echo.HTTPError)
				assert.True(t, ok, "expected http error to be passed to handler")
				assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler")
				return transformedError
			},
			expectFinalError: func(t *testing.T, err error) {
				assert.Equal(t, transformedError, err, "transformed error not returned from proxy")
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := echo.New()
			e.Use(ProxyWithConfig(
				ProxyConfig{
					Balancer:     NewRoundRobinBalancer([]*ProxyTarget{tc.target}),
					ErrorHandler: tc.errorHandler,
				},
			))

			errorHandlerCalled := false
			dheh := echo.DefaultHTTPErrorHandler(false)
			// NOTE(review): tc.expectFinalError is nil for the success case; this wrapper
			// relies on Echo#HTTPErrorHandler not being called there.
			e.HTTPErrorHandler = func(c *echo.Context, err error) {
				errorHandlerCalled = true
				tc.expectFinalError(t, err)
				dheh(c, err)
			}

			req := httptest.NewRequest(http.MethodGet, "/", nil)
			rec := httptest.NewRecorder()

			e.ServeHTTP(rec, req)

			if !errorHandlerCalled && tc.expectFinalError != nil {
				t.Fatalf("error handler was not called")
			}

		})
	}
}

// testContextKey is a private context key type for values injected by customBalancer.
type testContextKey string

// customBalancer always returns target. It also replaces the request with one whose context
// carries FROM_BALANCER, so tests can see whether later stages use the request set in Next.
type customBalancer struct {
	target *ProxyTarget
}

// AddTarget is a no-op that always reports failure.
func (b *customBalancer) AddTarget(target *ProxyTarget) bool {
	return false
}

// RemoveTarget is a no-op that always reports failure.
func (b *customBalancer) RemoveTarget(name string) bool {
	return false
}

// Next stores FROM_BALANCER=CUSTOM_BALANCER in a new request context via c.SetRequest and returns b.target.
func (b *customBalancer) Next(c *echo.Context) (*ProxyTarget, error) {
	ctx := context.WithValue(c.Request().Context(), testContextKey("FROM_BALANCER"), "CUSTOM_BALANCER")
	c.SetRequest(c.Request().WithContext(ctx))
	return b.target, nil
}

// TestModifyResponseUseContext checks that ModifyResponse sees the request context the
// balancer installed with c.SetRequest. It does so by copying the value into a response header.
func TestModifyResponseUseContext(t *testing.T) {
	server := httptest.NewServer(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("OK"))
		}),
	)
	defer server.Close()

	targetURL, _ := url.Parse(server.URL)
	e := echo.New()
	e.Use(ProxyWithConfig(
		ProxyConfig{
			Balancer: &customBalancer{
				target: &ProxyTarget{
					Name: "tst",
					URL:  targetURL,
				},
			},
			RetryCount: 1,
			ModifyResponse: func(res *http.Response) error {
				val := res.Request.Context().Value(testContextKey("FROM_BALANCER"))
				if valStr, ok := val.(string); ok {
					res.Header.Set("FROM_BALANCER", valStr)
				}
				return nil
			},
		},
	))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()

	e.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, "OK", rec.Body.String())
	assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
}

// createSimpleWebSocketServer starts an httptest server (TLS when serveTLS is true). Its
// WebSocket handler sends each received text message back until Receive returns an error.
func createSimpleWebSocketServer(serveTLS bool) *httptest.Server {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		wsHandler := func(conn *websocket.Conn) {
			defer conn.Close()
			for {
				var msg string
				err := websocket.Message.Receive(conn, &msg)
				if err != nil {
					return
				}
				// message back to the client
				websocket.Message.Send(conn, msg)
			}
		}
		websocket.Server{Handler: wsHandler}.ServeHTTP(w, r)
	})
	if serveTLS {
		return httptest.NewTLSServer(handler)
	}
	return httptest.NewServer(handler)
}

// createSimpleProxyServer starts an Echo server whose Proxy middleware targets srv.
//   - toTLS: the target scheme becomes "wss", and the proxy uses a cloned default transport
//     whose TLSClientConfig skips certificate verification.
//   - serveTLS: the proxy itself is served over TLS.
func createSimpleProxyServer(t *testing.T, srv *httptest.Server, serveTLS bool, toTLS bool) *httptest.Server {
	e := echo.New()

	if toTLS {
		// proxy to tls target
		tgtURL, _ := url.Parse(srv.URL)
		tgtURL.Scheme = "wss"
		balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})

		defaultTransport, ok := http.DefaultTransport.(*http.Transport)
		if !ok {
			t.Fatal("Default transport is not of type *http.Transport")
		}
		transport := defaultTransport.Clone()
		transport.TLSClientConfig = &tls.Config{
			InsecureSkipVerify: true,
		}
		e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer, Transport: transport}))
	} else {
		// proxy to non-TLS target
		tgtURL, _ := url.Parse(srv.URL)
		balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
		e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer}))
	}

	if serveTLS {
		// serve proxy server with TLS
		ts := httptest.NewTLSServer(e)
		return ts
	}
	// serve proxy server without TLS
	ts := httptest.NewServer(e)
	return ts
}

// TestProxyWithConfigWebSocketNonTLS2NonTLS tests the proxy with non-TLS to non-TLS WebSocket connection.
func TestProxyWithConfigWebSocketNonTLS2NonTLS(t *testing.T) { /* Arrange */ // Create a WebSocket test server (non-TLS) srv := createSimpleWebSocketServer(false) defer srv.Close() // create proxy server (non-TLS to non-TLS) ts := createSimpleProxyServer(t, srv, false, false) defer ts.Close() tsURL, _ := url.Parse(ts.URL) tsURL.Scheme = "ws" tsURL.Path = "/" /* Act */ // Connect to the proxy WebSocket wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/") assert.NoError(t, err) defer wsConn.Close() // Send message sendMsg := "Hello, Non TLS WebSocket!" err = websocket.Message.Send(wsConn, sendMsg) assert.NoError(t, err) /* Assert */ // Read response var recvMsg string err = websocket.Message.Receive(wsConn, &recvMsg) assert.NoError(t, err) assert.Equal(t, sendMsg, recvMsg) } // TestProxyWithConfigWebSocketTLS2TLS tests the proxy with TLS to TLS WebSocket connection. func TestProxyWithConfigWebSocketTLS2TLS(t *testing.T) { /* Arrange */ // Create a WebSocket test server (TLS) srv := createSimpleWebSocketServer(true) defer srv.Close() // create proxy server (TLS to TLS) ts := createSimpleProxyServer(t, srv, true, true) defer ts.Close() tsURL, _ := url.Parse(ts.URL) tsURL.Scheme = "wss" tsURL.Path = "/" /* Act */ origin, err := url.Parse(ts.URL) assert.NoError(t, err) config := &websocket.Config{ Location: tsURL, Origin: origin, TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing Version: websocket.ProtocolVersionHybi13, } wsConn, err := websocket.DialConfig(config) assert.NoError(t, err) defer wsConn.Close() // Send message sendMsg := "Hello, TLS to TLS WebSocket!" err = websocket.Message.Send(wsConn, sendMsg) assert.NoError(t, err) // Read response var recvMsg string err = websocket.Message.Receive(wsConn, &recvMsg) assert.NoError(t, err) assert.Equal(t, sendMsg, recvMsg) } // TestProxyWithConfigWebSocketNonTLS2TLS tests the proxy with non-TLS to TLS WebSocket connection. 
func TestProxyWithConfigWebSocketNonTLS2TLS(t *testing.T) { /* Arrange */ // Create a WebSocket test server (TLS) srv := createSimpleWebSocketServer(true) defer srv.Close() // create proxy server (Non-TLS to TLS) ts := createSimpleProxyServer(t, srv, false, true) defer ts.Close() tsURL, _ := url.Parse(ts.URL) tsURL.Scheme = "ws" tsURL.Path = "/" /* Act */ // Connect to the proxy WebSocket wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/") assert.NoError(t, err) defer wsConn.Close() // Send message sendMsg := "Hello, Non TLS to TLS WebSocket!" err = websocket.Message.Send(wsConn, sendMsg) assert.NoError(t, err) /* Assert */ // Read response var recvMsg string err = websocket.Message.Receive(wsConn, &recvMsg) assert.NoError(t, err) assert.Equal(t, sendMsg, recvMsg) } // TestProxyWithConfigWebSocketTLSToNoneTLS tests the proxy with TLS to non-TLS WebSocket connection. (TLS termination) func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) { /* Arrange */ // Create a WebSocket test server (non-TLS) srv := createSimpleWebSocketServer(false) defer srv.Close() // create proxy server (TLS to non-TLS) ts := createSimpleProxyServer(t, srv, true, false) defer ts.Close() tsURL, _ := url.Parse(ts.URL) tsURL.Scheme = "wss" tsURL.Path = "/" /* Act */ origin, err := url.Parse(ts.URL) assert.NoError(t, err) config := &websocket.Config{ Location: tsURL, Origin: origin, TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing Version: websocket.ProtocolVersionHybi13, } wsConn, err := websocket.DialConfig(config) assert.NoError(t, err) defer wsConn.Close() // Send message sendMsg := "Hello, TLS to NoneTLS WebSocket!" 
err = websocket.Message.Send(wsConn, sendMsg) assert.NoError(t, err) // Read response var recvMsg string err = websocket.Message.Receive(wsConn, &recvMsg) assert.NoError(t, err) assert.Equal(t, sendMsg, recvMsg) } ================================================ FILE: middleware/rate_limiter.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "errors" "math" "net/http" "sync" "time" "github.com/labstack/echo/v5" "golang.org/x/time/rate" ) // RateLimiterStore is the interface to be implemented by custom stores. type RateLimiterStore interface { Allow(identifier string) (bool, error) } // RateLimiterConfig defines the configuration for the rate limiter type RateLimiterConfig struct { Skipper Skipper BeforeFunc BeforeFunc // IdentifierExtractor uses *echo.Context to extract the identifier for a visitor IdentifierExtractor Extractor // Store defines a store for the rate limiter Store RateLimiterStore // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error ErrorHandler func(c *echo.Context, err error) error // DenyHandler provides a handler to be called when RateLimiter denies access DenyHandler func(c *echo.Context, identifier string, err error) error } // Extractor is used to extract data from *echo.Context type Extractor func(c *echo.Context) (string, error) // ErrRateLimitExceeded denotes an error raised when rate limit is exceeded var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") // ErrExtractorError denotes an error raised when extractor function is unsuccessful var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") // DefaultRateLimiterConfig defines default values for RateLimiterConfig var DefaultRateLimiterConfig = RateLimiterConfig{ Skipper: DefaultSkipper, IdentifierExtractor: func(ctx *echo.Context) (string, 
error) { id := ctx.RealIP() return id, nil }, ErrorHandler: func(c *echo.Context, err error) error { return ErrExtractorError.Wrap(err) }, DenyHandler: func(c *echo.Context, identifier string, err error) error { return ErrRateLimitExceeded.Wrap(err) }, } /* RateLimiter returns a rate limiting middleware e := echo.New() limiterStore := middleware.NewRateLimiterMemoryStore(20) e.GET("/rate-limited", func(c *echo.Context) error { return c.String(http.StatusOK, "test") }, RateLimiter(limiterStore)) */ func RateLimiter(store RateLimiterStore) echo.MiddlewareFunc { config := DefaultRateLimiterConfig config.Store = store return RateLimiterWithConfig(config) } /* RateLimiterWithConfig returns a rate limiting middleware e := echo.New() config := middleware.RateLimiterConfig{ Skipper: DefaultSkipper, Store: middleware.NewRateLimiterMemoryStore( middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute} ) IdentifierExtractor: func(ctx *echo.Context) (string, error) { id := ctx.RealIP() return id, nil }, ErrorHandler: func(ctx *echo.Context, err error) error { return context.JSON(http.StatusTooManyRequests, nil) }, DenyHandler: func(ctx *echo.Context, identifier string, err error) error { return context.JSON(http.StatusForbidden, nil) }, } e.GET("/rate-limited", func(c *echo.Context) error { return c.String(http.StatusOK, "test") }, middleware.RateLimiterWithConfig(config)) */ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } // ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid configuration func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultRateLimiterConfig.Skipper } if config.IdentifierExtractor == nil { config.IdentifierExtractor = DefaultRateLimiterConfig.IdentifierExtractor } if config.ErrorHandler == nil { config.ErrorHandler = DefaultRateLimiterConfig.ErrorHandler } 
if config.DenyHandler == nil { config.DenyHandler = DefaultRateLimiterConfig.DenyHandler } if config.Store == nil { return nil, errors.New("echo rate limiter store configuration must be provided") } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } if config.BeforeFunc != nil { config.BeforeFunc(c) } identifier, err := config.IdentifierExtractor(c) if err != nil { return config.ErrorHandler(c, err) } if allow, allowErr := config.Store.Allow(identifier); !allow { return config.DenyHandler(c, identifier, allowErr) } return next(c) } }, nil } // RateLimiterMemoryStore is the built-in store implementation for RateLimiter type RateLimiterMemoryStore struct { visitors map[string]*Visitor mutex sync.Mutex rate float64 // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit burst int expiresIn time.Duration lastCleanup time.Time timeNow func() time.Time } // Visitor signifies a unique user's limiter details type Visitor struct { *rate.Limiter lastSeen time.Time } /* NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with the provided rate (as req/s). for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. Burst and ExpiresIn will be set to default values. Note that if the provided rate is a float number and Burst is zero, Burst will be treated as the rounded down value of the rate. Example (with 20 requests/sec): limiterStore := middleware.NewRateLimiterMemoryStore(20) */ func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore) { return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ Rate: rateLimit, }) } /* NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore with the provided configuration. Rate must be provided. Burst will be set to the rounded down value of the configured rate if not provided or set to 0. 
The built-in memory store is usually capable for modest loads. For higher loads other store implementations should be considered. Characteristics: * Concurrency above 100 parallel requests may causes measurable lock contention * A high number of different IP addresses (above 16000) may be impacted by the internally used Go map * A high number of requests from a single IP address may cause lock contention Example: limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig( middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minute}, ) */ func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) { store = &RateLimiterMemoryStore{} store.rate = config.Rate store.burst = config.Burst store.expiresIn = config.ExpiresIn if config.ExpiresIn == 0 { store.expiresIn = DefaultRateLimiterMemoryStoreConfig.ExpiresIn } if config.Burst == 0 { store.burst = int(math.Max(1, math.Ceil(float64(config.Rate)))) } store.visitors = make(map[string]*Visitor) store.timeNow = time.Now store.lastCleanup = store.timeNow() return } // RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore type RateLimiterMemoryStoreConfig struct { Rate float64 // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. Burst int // Burst is maximum number of requests to pass at the same moment. It additionally allows a number of requests to pass when rate limit is reached. 
ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up } // DefaultRateLimiterMemoryStoreConfig provides default configuration values for RateLimiterMemoryStore var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{ ExpiresIn: 3 * time.Minute, } // Allow implements RateLimiterStore.Allow func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { store.mutex.Lock() limiter, exists := store.visitors[identifier] if !exists { limiter = new(Visitor) limiter.Limiter = rate.NewLimiter(rate.Limit(store.rate), store.burst) store.visitors[identifier] = limiter } now := store.timeNow() limiter.lastSeen = now if now.Sub(store.lastCleanup) > store.expiresIn { store.cleanupStaleVisitors(now) } allowed := limiter.AllowN(now, 1) store.mutex.Unlock() return allowed, nil } /* cleanupStaleVisitors helps manage the size of the visitors map by removing stale records of users who haven't visited again after the configured expiry time has elapsed */ func (store *RateLimiterMemoryStore) cleanupStaleVisitors(now time.Time) { for id, visitor := range store.visitors { if now.Sub(visitor.lastSeen) > store.expiresIn { delete(store.visitors, id) } } store.lastCleanup = now } ================================================ FILE: middleware/rate_limiter_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "errors" "math/rand" "net/http" "net/http/httptest" "sync" "sync/atomic" "testing" "time" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" "golang.org/x/time/rate" ) func TestRateLimiter(t *testing.T) { e := echo.New() handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore}) 
testCases := []struct { id string expectErr string }{ {id: "127.0.0.1"}, {id: "127.0.0.1"}, {id: "127.0.0.1"}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Add(echo.HeaderXRealIP, tc.id) rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := mw(handler)(c) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, http.StatusOK, rec.Code) } } func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) assert.Panics(t, func() { RateLimiterWithConfig(RateLimiterConfig{}) }) assert.NotPanics(t, func() { RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore}) }) } func TestRateLimiterWithConfig(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) e := echo.New() handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } mw, err := RateLimiterConfig{ IdentifierExtractor: func(c *echo.Context) (string, error) { id := c.Request().Header.Get(echo.HeaderXRealIP) if id == "" { return "", errors.New("invalid identifier") } return id, nil }, DenyHandler: func(ctx *echo.Context, identifier string, err error) error { return ctx.JSON(http.StatusForbidden, nil) }, ErrorHandler: func(ctx *echo.Context, err error) error { return ctx.JSON(http.StatusBadRequest, nil) }, Store: inMemoryStore, }.ToMiddleware() assert.NoError(t, err) testCases := []struct { id string code int }{ {"127.0.0.1", http.StatusOK}, {"127.0.0.1", http.StatusOK}, {"127.0.0.1", http.StatusOK}, {"127.0.0.1", 
http.StatusForbidden}, {"", http.StatusBadRequest}, {"127.0.0.1", http.StatusForbidden}, {"127.0.0.1", http.StatusForbidden}, } for _, tc := range testCases { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Add(echo.HeaderXRealIP, tc.id) rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := mw(handler)(c) assert.NoError(t, err) assert.Equal(t, tc.code, rec.Code) } } func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) e := echo.New() handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } mw, err := RateLimiterConfig{ IdentifierExtractor: func(c *echo.Context) (string, error) { id := c.Request().Header.Get(echo.HeaderXRealIP) if id == "" { return "", errors.New("invalid identifier") } return id, nil }, Store: inMemoryStore, }.ToMiddleware() assert.NoError(t, err) testCases := []struct { id string expectErr string }{ {id: "127.0.0.1"}, {id: "127.0.0.1"}, {id: "127.0.0.1"}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, {expectErr: "code=403, message=error while extracting identifier, err=invalid identifier"}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Add(echo.HeaderXRealIP, tc.id) rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := mw(handler)(c) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, http.StatusOK, rec.Code) } } func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) e := echo.New() handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } mw, err := 
RateLimiterConfig{ Store: inMemoryStore, }.ToMiddleware() assert.NoError(t, err) testCases := []struct { id string expectErr string }{ {id: "127.0.0.1"}, {id: "127.0.0.1"}, {id: "127.0.0.1"}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Add(echo.HeaderXRealIP, tc.id) rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := mw(handler)(c) if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, http.StatusOK, rec.Code) } } } func TestRateLimiterWithConfig_skipper(t *testing.T) { e := echo.New() var beforeFuncRan bool handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } var inMemoryStore = NewRateLimiterMemoryStore(5) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") rec := httptest.NewRecorder() c := e.NewContext(req, rec) mw, err := RateLimiterConfig{ Skipper: func(c *echo.Context) bool { return true }, BeforeFunc: func(c *echo.Context) { beforeFuncRan = true }, Store: inMemoryStore, IdentifierExtractor: func(ctx *echo.Context) (string, error) { return "127.0.0.1", nil }, }.ToMiddleware() assert.NoError(t, err) err = mw(handler)(c) assert.NoError(t, err) assert.Equal(t, false, beforeFuncRan) } func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { e := echo.New() var beforeFuncRan bool handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } var inMemoryStore = NewRateLimiterMemoryStore(5) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") rec := httptest.NewRecorder() c := e.NewContext(req, rec) mw, err := 
RateLimiterConfig{ Skipper: func(c *echo.Context) bool { return false }, BeforeFunc: func(c *echo.Context) { beforeFuncRan = true }, Store: inMemoryStore, IdentifierExtractor: func(ctx *echo.Context) (string, error) { return "127.0.0.1", nil }, }.ToMiddleware() assert.NoError(t, err) _ = mw(handler)(c) assert.Equal(t, true, beforeFuncRan) } func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { e := echo.New() handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } var beforeRan bool var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") rec := httptest.NewRecorder() c := e.NewContext(req, rec) mw, err := RateLimiterConfig{ BeforeFunc: func(c *echo.Context) { beforeRan = true }, Store: inMemoryStore, IdentifierExtractor: func(ctx *echo.Context) (string, error) { return "127.0.0.1", nil }, }.ToMiddleware() assert.NoError(t, err) err = mw(handler)(c) assert.NoError(t, err) assert.Equal(t, true, beforeRan) } func TestRateLimiterMemoryStore_Allow(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3, ExpiresIn: 2 * time.Second}) testCases := []struct { id string allowed bool }{ {"127.0.0.1", true}, // 0 ms {"127.0.0.1", true}, // 220 ms burst #2 {"127.0.0.1", true}, // 440 ms burst #3 {"127.0.0.1", false}, // 660 ms block {"127.0.0.1", false}, // 880 ms block {"127.0.0.1", true}, // 1100 ms next second #1 {"127.0.0.2", true}, // 1320 ms allow other ip {"127.0.0.1", false}, // 1540 ms no burst {"127.0.0.1", false}, // 1760 ms no burst {"127.0.0.1", false}, // 1980 ms no burst {"127.0.0.1", true}, // 2200 ms no burst {"127.0.0.1", false}, // 2420 ms no burst {"127.0.0.1", false}, // 2640 ms no burst {"127.0.0.1", false}, // 2860 ms no burst {"127.0.0.1", true}, // 3080 ms no burst {"127.0.0.1", false}, // 3300 ms no burst 
{"127.0.0.1", false}, // 3520 ms no burst {"127.0.0.1", false}, // 3740 ms no burst {"127.0.0.1", false}, // 3960 ms no burst {"127.0.0.1", true}, // 4180 ms no burst {"127.0.0.1", false}, // 4400 ms no burst {"127.0.0.1", false}, // 4620 ms no burst {"127.0.0.1", false}, // 4840 ms no burst {"127.0.0.1", true}, // 5060 ms no burst } for i, tc := range testCases { t.Logf("Running testcase #%d => %v", i, time.Duration(i)*220*time.Millisecond) inMemoryStore.timeNow = func() time.Time { return time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).Add(time.Duration(i) * 220 * time.Millisecond) } allowed, _ := inMemoryStore.Allow(tc.id) assert.Equal(t, tc.allowed, allowed) } } func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) inMemoryStore.visitors = map[string]*Visitor{ "A": { Limiter: rate.NewLimiter(1, 3), lastSeen: time.Now(), }, "B": { Limiter: rate.NewLimiter(1, 3), lastSeen: time.Now().Add(-1 * time.Minute), }, "C": { Limiter: rate.NewLimiter(1, 3), lastSeen: time.Now().Add(-5 * time.Minute), }, "D": { Limiter: rate.NewLimiter(1, 3), lastSeen: time.Now().Add(-10 * time.Minute), }, } inMemoryStore.Allow("D") inMemoryStore.cleanupStaleVisitors(time.Now()) var exists bool _, exists = inMemoryStore.visitors["A"] assert.Equal(t, true, exists) _, exists = inMemoryStore.visitors["B"] assert.Equal(t, true, exists) _, exists = inMemoryStore.visitors["C"] assert.Equal(t, false, exists) _, exists = inMemoryStore.visitors["D"] assert.Equal(t, true, exists) } func TestNewRateLimiterMemoryStore(t *testing.T) { testCases := []struct { rate float64 burst int expiresIn time.Duration expectedExpiresIn time.Duration }{ {1, 3, 5 * time.Second, 5 * time.Second}, {2, 4, 0, 3 * time.Minute}, {1, 5, 10 * time.Minute, 10 * time.Minute}, {3, 7, 0, 3 * time.Minute}, } for _, tc := range testCases { store := 
NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: tc.rate, Burst: tc.burst, ExpiresIn: tc.expiresIn}) assert.Equal(t, tc.rate, store.rate) assert.Equal(t, tc.burst, store.burst) assert.Equal(t, tc.expectedExpiresIn, store.expiresIn) } } func TestRateLimiterMemoryStore_FractionalRateDefaultBurst(t *testing.T) { store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ Rate: 0.5, // fractional rate should get a burst of at least 1 }) base := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) store.timeNow = func() time.Time { return base } allowed, err := store.Allow("user") assert.NoError(t, err) assert.True(t, allowed, "first request should not be blocked") allowed, err = store.Allow("user") assert.NoError(t, err) assert.False(t, allowed, "burst token should be consumed immediately") store.timeNow = func() time.Time { return base.Add(2 * time.Second) } allowed, err = store.Allow("user") assert.NoError(t, err) assert.True(t, allowed, "token should refill for fractional rate after time passes") } func generateAddressList(count int) []string { addrs := make([]string, count) for i := 0; i < count; i++ { addrs[i] = randomString(15) } return addrs } func run(wg *sync.WaitGroup, store RateLimiterStore, addrs []string, max int, b *testing.B) { for i := 0; i < b.N; i++ { store.Allow(addrs[rand.Intn(max)]) } wg.Done() } func benchmarkStore(store RateLimiterStore, parallel int, max int, b *testing.B) { addrs := generateAddressList(max) wg := &sync.WaitGroup{} for i := 0; i < parallel; i++ { wg.Add(1) go run(wg, store, addrs, max, b) } wg.Wait() } const ( testExpiresIn = 1000 * time.Millisecond ) func BenchmarkRateLimiterMemoryStore_1000(b *testing.B) { var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) benchmarkStore(store, 10, 1000, b) } func BenchmarkRateLimiterMemoryStore_10000(b *testing.B) { var store = 
NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) benchmarkStore(store, 10, 10000, b) } func BenchmarkRateLimiterMemoryStore_100000(b *testing.B) { var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) benchmarkStore(store, 10, 100000, b) } func BenchmarkRateLimiterMemoryStore_conc100_10000(b *testing.B) { var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) benchmarkStore(store, 100, 10000, b) } // TestRateLimiterMemoryStore_TOCTOUFix verifies that the TOCTOU race condition is fixed // by ensuring timeNow() is only called once per Allow() call func TestRateLimiterMemoryStore_TOCTOUFix(t *testing.T) { t.Parallel() store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ Rate: 1, Burst: 1, ExpiresIn: 2 * time.Second, }) // Track time calls to verify we use the same time value timeCallCount := 0 baseTime := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) store.timeNow = func() time.Time { timeCallCount++ return baseTime } // First request - should succeed allowed, err := store.Allow("127.0.0.1") assert.NoError(t, err) assert.True(t, allowed, "First request should be allowed") // Verify timeNow() was only called once assert.Equal(t, 1, timeCallCount, "timeNow() should only be called once per Allow()") } // TestRateLimiterMemoryStore_ConcurrentAccess verifies rate limiting correctness under concurrent load func TestRateLimiterMemoryStore_ConcurrentAccess(t *testing.T) { t.Parallel() store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ Rate: 10, Burst: 5, ExpiresIn: 5 * time.Second, }) const goroutines = 50 const requestsPerGoroutine = 20 var wg sync.WaitGroup var allowedCount, deniedCount int32 for i := 0; i < goroutines; i++ { wg.Add(1) go func() { defer wg.Done() for j := 0; j < requestsPerGoroutine; j++ { 
allowed, err := store.Allow("test-user") assert.NoError(t, err) if allowed { atomic.AddInt32(&allowedCount, 1) } else { atomic.AddInt32(&deniedCount, 1) } time.Sleep(time.Millisecond) } }() } wg.Wait() totalRequests := goroutines * requestsPerGoroutine allowed := int(allowedCount) denied := int(deniedCount) assert.Equal(t, totalRequests, allowed+denied, "All requests should be processed") assert.Greater(t, denied, 0, "Some requests should be denied due to rate limiting") assert.Greater(t, allowed, 0, "Some requests should be allowed") } // TestRateLimiterMemoryStore_RaceDetection verifies no data races with high concurrency // Run with: go test -race ./middleware -run TestRateLimiterMemoryStore_RaceDetection func TestRateLimiterMemoryStore_RaceDetection(t *testing.T) { t.Parallel() store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ Rate: 100, Burst: 200, ExpiresIn: 1 * time.Second, }) const goroutines = 100 const requestsPerGoroutine = 100 var wg sync.WaitGroup identifiers := []string{"user1", "user2", "user3", "user4", "user5"} for i := 0; i < goroutines; i++ { wg.Add(1) go func(routineID int) { defer wg.Done() for j := 0; j < requestsPerGoroutine; j++ { identifier := identifiers[routineID%len(identifiers)] _, err := store.Allow(identifier) assert.NoError(t, err) } }(i) } wg.Wait() } // TestRateLimiterMemoryStore_TimeOrdering verifies time ordering consistency in rate limiting decisions func TestRateLimiterMemoryStore_TimeOrdering(t *testing.T) { t.Parallel() store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ Rate: 1, Burst: 2, ExpiresIn: 5 * time.Second, }) currentTime := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) store.timeNow = func() time.Time { return currentTime } // First two requests should succeed (burst=2) allowed1, _ := store.Allow("user1") assert.True(t, allowed1, "Request 1 should be allowed (burst)") allowed2, _ := store.Allow("user1") assert.True(t, allowed2, "Request 2 should be allowed 
(burst)") // Third request should be denied allowed3, _ := store.Allow("user1") assert.False(t, allowed3, "Request 3 should be denied (burst exhausted)") // Advance time by 1 second currentTime = currentTime.Add(1 * time.Second) // Fourth request should succeed allowed4, _ := store.Allow("user1") assert.True(t, allowed4, "Request 4 should be allowed (1 token available)") } ================================================ FILE: middleware/recover.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "fmt" "net/http" "runtime" "github.com/labstack/echo/v5" ) // RecoverConfig defines the config for Recover middleware. type RecoverConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // Size of the stack to be printed. // Optional. Default value 4KB. StackSize int // DisableStackAll disables formatting stack traces of all other goroutines // into buffer after the trace for the current goroutine. // Optional. Default value false. DisableStackAll bool // DisablePrintStack disables printing stack trace. // Optional. Default value as false. DisablePrintStack bool } // DefaultRecoverConfig is the default Recover middleware config. var DefaultRecoverConfig = RecoverConfig{ Skipper: DefaultSkipper, StackSize: 4 << 10, // 4 KB DisableStackAll: false, DisablePrintStack: false, } // Recover returns a middleware which recovers from panics anywhere in the chain // and handles the control to the centralized HTTPErrorHandler. func Recover() echo.MiddlewareFunc { return RecoverWithConfig(DefaultRecoverConfig) } // RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration. 
func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
	return toMiddlewareOrPanic(config)
}

// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration.
//
// The returned middleware converts a panic in the downstream chain into the returned error.
// Non-error panic values are formatted with %v. Unless DisablePrintStack is set, the error is
// wrapped in *PanicStackError carrying up to StackSize bytes of stack trace.
// A panic with http.ErrAbortHandler is re-panicked untouched.
func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	// Defaults
	if config.Skipper == nil {
		config.Skipper = DefaultRecoverConfig.Skipper
	}
	if config.StackSize <= 0 {
		// A negative size would make make([]byte, ...) below panic inside the deferred
		// recovery handler, turning a recovered panic into an unrecovered one.
		config.StackSize = DefaultRecoverConfig.StackSize
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) (err error) {
			if config.Skipper(c) {
				return next(c)
			}

			defer func() {
				if r := recover(); r != nil {
					// http.ErrAbortHandler is a sentinel used to abort a handler; let it propagate.
					if r == http.ErrAbortHandler {
						panic(r)
					}
					tmpErr, ok := r.(error)
					if !ok {
						tmpErr = fmt.Errorf("%v", r)
					}
					if !config.DisablePrintStack {
						stack := make([]byte, config.StackSize)
						length := runtime.Stack(stack, !config.DisableStackAll)
						tmpErr = &PanicStackError{Stack: stack[:length], Err: tmpErr}
					}
					// err is the named result, so assigning here changes what the handler returns
					err = tmpErr
				}
			}()
			return next(c)
		}
	}, nil
}

// PanicStackError is an error type that wraps an error along with its stack trace.
// It is returned when config.DisablePrintStack is set to false.
type PanicStackError struct { Stack []byte Err error } func (e *PanicStackError) Error() string { return fmt.Sprintf("[PANIC RECOVER] %s %s", e.Err.Error(), e.Stack) } func (e *PanicStackError) Unwrap() error { return e.Err } ================================================ FILE: middleware/recover_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "bytes" "errors" "log/slog" "net/http" "net/http/httptest" "testing" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestRecover(t *testing.T) { e := echo.New() buf := new(bytes.Buffer) e.Logger = slog.New(&discardHandler{}) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := Recover()(func(c *echo.Context) error { panic("test") }) err := h(c) assert.Contains(t, err.Error(), "[PANIC RECOVER] test goroutine") var pse *PanicStackError if errors.As(err, &pse) { assert.Contains(t, string(pse.Stack), "middleware/recover.go") } else { assert.Fail(t, "not of type PanicStackError") } assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain assert.Contains(t, buf.String(), "") // nothing is logged } func TestRecover_skipper(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) config := RecoverConfig{ Skipper: func(c *echo.Context) bool { return true }, } h := RecoverWithConfig(config)(func(c *echo.Context) error { panic("testPANIC") }) var err error assert.Panics(t, func() { err = h(c) }) assert.NoError(t, err) assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. 
err is returned from middleware chain } func TestRecoverErrAbortHandler(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := Recover()(func(c *echo.Context) error { panic(http.ErrAbortHandler) }) defer func() { r := recover() if r == nil { assert.Fail(t, "expecting `http.ErrAbortHandler`, got `nil`") } else { if err, ok := r.(error); ok { assert.ErrorIs(t, err, http.ErrAbortHandler) } else { assert.Fail(t, "not of error type") } } }() hErr := h(c) assert.Equal(t, http.StatusInternalServerError, rec.Code) assert.NotContains(t, hErr.Error(), "PANIC RECOVER") } func TestRecoverWithConfig(t *testing.T) { var testCases = []struct { name string givenNoPanic bool whenConfig RecoverConfig expectErrContain string expectErr string }{ { name: "ok, default config", whenConfig: DefaultRecoverConfig, expectErrContain: "[PANIC RECOVER] testPANIC goroutine", }, { name: "ok, no panic", givenNoPanic: true, whenConfig: DefaultRecoverConfig, expectErrContain: "", }, { name: "ok, DisablePrintStack", whenConfig: RecoverConfig{ DisablePrintStack: true, }, expectErr: "testPANIC", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) config := tc.whenConfig h := RecoverWithConfig(config)(func(c *echo.Context) error { if tc.givenNoPanic { return nil } panic("testPANIC") }) err := h(c) if tc.expectErrContain != "" { assert.Contains(t, err.Error(), tc.expectErrContain) } else if tc.expectErr != "" { assert.Contains(t, err.Error(), tc.expectErr) } else { assert.NoError(t, err) } assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. 
err is returned from middleware chain }) } } ================================================ FILE: middleware/redirect.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "errors" "net/http" "strings" "github.com/labstack/echo/v5" ) // RedirectConfig defines the config for Redirect middleware. type RedirectConfig struct { // Skipper defines a function to skip middleware. Skipper // Status code to be used when redirecting the request. // Optional. Default value http.StatusMovedPermanently. Code int redirect redirectLogic } // redirectLogic represents a function that given a scheme, host and uri // can both: 1) determine if redirect is needed (will set ok accordingly) and // 2) return the appropriate redirect url. type redirectLogic func(scheme, host, uri string) (ok bool, url string) const www = "www." // RedirectHTTPSConfig is the HTTPS Redirect middleware config. var RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS} // RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config. var RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW} // RedirectNonHTTPSWWWConfig is the non HTTPS WWW Redirect middleware config. var RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW} // RedirectWWWConfig is the WWW Redirect middleware config. var RedirectWWWConfig = RedirectConfig{redirect: redirectWWW} // RedirectNonWWWConfig is the non WWW Redirect middleware config. var RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW} // HTTPSRedirect redirects http requests to https. // For example, http://labstack.com will be redirect to https://labstack.com. 
// // Usage `Echo#Pre(HTTPSRedirect())` func HTTPSRedirect() echo.MiddlewareFunc { return HTTPSRedirectWithConfig(RedirectHTTPSConfig) } // HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration. func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { config.redirect = redirectHTTPS return toMiddlewareOrPanic(config) } // HTTPSWWWRedirect redirects http requests to https www. // For example, http://labstack.com will be redirect to https://www.labstack.com. // // Usage `Echo#Pre(HTTPSWWWRedirect())` func HTTPSWWWRedirect() echo.MiddlewareFunc { return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig) } // HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration. func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { config.redirect = redirectHTTPSWWW return toMiddlewareOrPanic(config) } // HTTPSNonWWWRedirect redirects http requests to https non www. // For example, http://www.labstack.com will be redirect to https://labstack.com. // // Usage `Echo#Pre(HTTPSNonWWWRedirect())` func HTTPSNonWWWRedirect() echo.MiddlewareFunc { return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig) } // HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration. func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { config.redirect = redirectNonHTTPSWWW return toMiddlewareOrPanic(config) } // WWWRedirect redirects non www requests to www. // For example, http://labstack.com will be redirect to http://www.labstack.com. // // Usage `Echo#Pre(WWWRedirect())` func WWWRedirect() echo.MiddlewareFunc { return WWWRedirectWithConfig(RedirectWWWConfig) } // WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration. 
func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { config.redirect = redirectWWW return toMiddlewareOrPanic(config) } // NonWWWRedirect redirects www requests to non www. // For example, http://www.labstack.com will be redirect to http://labstack.com. // // Usage `Echo#Pre(NonWWWRedirect())` func NonWWWRedirect() echo.MiddlewareFunc { return NonWWWRedirectWithConfig(RedirectNonWWWConfig) } // NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration. func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { config.redirect = redirectNonWWW return toMiddlewareOrPanic(config) } // ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultSkipper } if config.Code == 0 { config.Code = http.StatusMovedPermanently } if config.redirect == nil { return nil, errors.New("redirectConfig is missing redirect function") } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } req, scheme := c.Request(), c.Scheme() host := req.Host if ok, url := config.redirect(scheme, host, req.RequestURI); ok { return c.Redirect(config.Code, url) } return next(c) } }, nil } var redirectHTTPS = func(scheme, host, uri string) (bool, string) { if scheme != "https" { return true, "https://" + host + uri } return false, "" } var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) { // Redirect if not HTTPS OR missing www prefix (needs either fix) if scheme != "https" || !strings.HasPrefix(host, www) { host = strings.TrimPrefix(host, www) // Remove www if present to avoid duplication return true, "https://www." 
+ host + uri } return false, "" } var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) { // Redirect if not HTTPS OR has www prefix (needs either fix) if scheme != "https" || strings.HasPrefix(host, www) { host = strings.TrimPrefix(host, www) return true, "https://" + host + uri } return false, "" } var redirectWWW = func(scheme, host, uri string) (bool, string) { if !strings.HasPrefix(host, www) { return true, scheme + "://www." + host + uri } return false, "" } var redirectNonWWW = func(scheme, host, uri string) (bool, string) { if strings.HasPrefix(host, www) { return true, scheme + "://" + host[4:] + uri } return false, "" } ================================================ FILE: middleware/redirect_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "net/http" "net/http/httptest" "testing" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) type middlewareGenerator func() echo.MiddlewareFunc func TestRedirectHTTPSRedirect(t *testing.T) { var testCases = []struct { whenHost string whenHeader http.Header expectLocation string expectStatusCode int }{ { whenHost: "labstack.com", expectLocation: "https://labstack.com/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "labstack.com", whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, expectLocation: "", expectStatusCode: http.StatusOK, }, } for _, tc := range testCases { t.Run(tc.whenHost, func(t *testing.T) { res := redirectTest(HTTPSRedirect, tc.whenHost, tc.whenHeader) assert.Equal(t, tc.expectStatusCode, res.Code) assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) }) } } func TestRedirectHTTPSWWWRedirect(t *testing.T) { var testCases = []struct { whenHost string whenHeader http.Header expectLocation string expectStatusCode int }{ { whenHost: "labstack.com", expectLocation: 
"https://www.labstack.com/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "www.labstack.com", expectLocation: "https://www.labstack.com/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "a.com", expectLocation: "https://www.a.com/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "ip", expectLocation: "https://www.ip/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "labstack.com", whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, expectLocation: "https://www.labstack.com/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "www.labstack.com", whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, expectLocation: "", expectStatusCode: http.StatusOK, }, } for _, tc := range testCases { t.Run(tc.whenHost, func(t *testing.T) { res := redirectTest(HTTPSWWWRedirect, tc.whenHost, tc.whenHeader) assert.Equal(t, tc.expectStatusCode, res.Code) assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) }) } } func TestRedirectHTTPSNonWWWRedirect(t *testing.T) { var testCases = []struct { whenHost string whenHeader http.Header expectLocation string expectStatusCode int }{ { whenHost: "www.labstack.com", expectLocation: "https://labstack.com/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "a.com", expectLocation: "https://a.com/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "ip", expectLocation: "https://ip/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "www.labstack.com", whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, expectLocation: "https://labstack.com/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "labstack.com", whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, expectLocation: "", expectStatusCode: http.StatusOK, }, } for _, tc := range testCases { t.Run(tc.whenHost, func(t *testing.T) { res := redirectTest(HTTPSNonWWWRedirect, 
tc.whenHost, tc.whenHeader) assert.Equal(t, tc.expectStatusCode, res.Code) assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) }) } } func TestRedirectWWWRedirect(t *testing.T) { var testCases = []struct { whenHost string whenHeader http.Header expectLocation string expectStatusCode int }{ { whenHost: "labstack.com", expectLocation: "http://www.labstack.com/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "a.com", expectLocation: "http://www.a.com/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "ip", expectLocation: "http://www.ip/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "a.com", whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, expectLocation: "https://www.a.com/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "www.ip", expectLocation: "", expectStatusCode: http.StatusOK, }, } for _, tc := range testCases { t.Run(tc.whenHost, func(t *testing.T) { res := redirectTest(WWWRedirect, tc.whenHost, tc.whenHeader) assert.Equal(t, tc.expectStatusCode, res.Code) assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) }) } } func TestRedirectNonWWWRedirect(t *testing.T) { var testCases = []struct { whenHost string whenHeader http.Header expectLocation string expectStatusCode int }{ { whenHost: "www.labstack.com", expectLocation: "http://labstack.com/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "www.a.com", expectLocation: "http://a.com/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "www.a.com", whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, expectLocation: "https://a.com/", expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "ip", expectLocation: "", expectStatusCode: http.StatusOK, }, } for _, tc := range testCases { t.Run(tc.whenHost, func(t *testing.T) { res := redirectTest(NonWWWRedirect, tc.whenHost, tc.whenHeader) assert.Equal(t, tc.expectStatusCode, res.Code) 
assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) }) } } func TestNonWWWRedirectWithConfig(t *testing.T) { var testCases = []struct { name string givenCode int givenSkipFunc func(c *echo.Context) bool whenHost string whenHeader http.Header expectLocation string expectStatusCode int }{ { name: "usual redirect", whenHost: "www.labstack.com", expectLocation: "http://labstack.com/", expectStatusCode: http.StatusMovedPermanently, }, { name: "redirect is skipped", givenSkipFunc: func(c *echo.Context) bool { return true // skip always }, whenHost: "www.labstack.com", expectLocation: "", expectStatusCode: http.StatusOK, }, { name: "redirect with custom status code", givenCode: http.StatusSeeOther, whenHost: "www.labstack.com", expectLocation: "http://labstack.com/", expectStatusCode: http.StatusSeeOther, }, } for _, tc := range testCases { t.Run(tc.whenHost, func(t *testing.T) { middleware := func() echo.MiddlewareFunc { return NonWWWRedirectWithConfig(RedirectConfig{ Skipper: tc.givenSkipFunc, Code: tc.givenCode, }) } res := redirectTest(middleware, tc.whenHost, tc.whenHeader) assert.Equal(t, tc.expectStatusCode, res.Code) assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) }) } } func redirectTest(fn middlewareGenerator, host string, header http.Header) *httptest.ResponseRecorder { e := echo.New() next := func(c *echo.Context) (err error) { return c.NoContent(http.StatusOK) } req := httptest.NewRequest(http.MethodGet, "/", nil) req.Host = host if header != nil { req.Header = header } res := httptest.NewRecorder() c := e.NewContext(req, res) fn()(next)(c) return res } ================================================ FILE: middleware/request_id.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "github.com/labstack/echo/v5" ) // RequestIDConfig defines the config for RequestID middleware. 
type RequestIDConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // Generator defines a function to generate an ID. // Optional. Default value random.String(32). Generator func() string // RequestIDHandler defines a function which is executed for a request id. RequestIDHandler func(c *echo.Context, requestID string) // TargetHeader defines what header to look for to populate the id. // Optional. Default value is `X-Request-Id` TargetHeader string } // RequestID returns a middleware that reads RequestIDConfig.TargetHeader (`X-Request-ID`) header value or when // the header value is empty, generates that value and sets request ID to response // as RequestIDConfig.TargetHeader (`X-Request-Id`) value. func RequestID() echo.MiddlewareFunc { return RequestIDWithConfig(RequestIDConfig{}) } // RequestIDWithConfig returns a middleware with given valid config or panics on invalid configuration. // The middleware reads RequestIDConfig.TargetHeader (`X-Request-ID`) header value or when the header value is empty, // generates that value and sets request ID to response as RequestIDConfig.TargetHeader (`X-Request-Id`) value. 
func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } // ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultSkipper } if config.Generator == nil { config.Generator = createRandomStringGenerator(32) } if config.TargetHeader == "" { config.TargetHeader = echo.HeaderXRequestID } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } req := c.Request() res := c.Response() rid := req.Header.Get(config.TargetHeader) if rid == "" { rid = config.Generator() } res.Header().Set(config.TargetHeader, rid) if config.RequestIDHandler != nil { config.RequestIDHandler(c, rid) } return next(c) } }, nil } ================================================ FILE: middleware/request_id_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "net/http" "net/http/httptest" "testing" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestRequestID(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } rid := RequestID() h := rid(handler) err := h(c) assert.NoError(t, err) assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) } func TestMustRequestIDWithConfig_skipper(t *testing.T) { e := echo.New() e.GET("/", func(c *echo.Context) error { return c.String(http.StatusTeapot, "test") }) generatorCalled := false e.Use(RequestIDWithConfig(RequestIDConfig{ Skipper: func(c *echo.Context) bool { return true }, Generator: func() string { generatorCalled = true return 
"customGenerator" }, })) req := httptest.NewRequest(http.MethodGet, "/", nil) res := httptest.NewRecorder() e.ServeHTTP(res, req) assert.Equal(t, http.StatusTeapot, res.Code) assert.Equal(t, "test", res.Body.String()) assert.Equal(t, res.Header().Get(echo.HeaderXRequestID), "") assert.False(t, generatorCalled) } func TestMustRequestIDWithConfig_customGenerator(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } rid := RequestIDWithConfig(RequestIDConfig{ Generator: func() string { return "customGenerator" }, }) h := rid(handler) err := h(c) assert.NoError(t, err) assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") } func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } called := false rid := RequestIDWithConfig(RequestIDConfig{ Generator: func() string { return "customGenerator" }, RequestIDHandler: func(c *echo.Context, s string) { called = true }, }) h := rid(handler) err := h(c) assert.NoError(t, err) assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") assert.True(t, called) } func TestRequestIDWithConfig(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } rid, err := RequestIDConfig{}.ToMiddleware() assert.NoError(t, err) h := rid(handler) h(c) assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) // Custom generator rid = RequestIDWithConfig(RequestIDConfig{ Generator: func() string { return "customGenerator" }, }) h = rid(handler) h(c) assert.Equal(t, 
rec.Header().Get(echo.HeaderXRequestID), "customGenerator") } func TestRequestID_IDNotAltered(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Add(echo.HeaderXRequestID, "") rec := httptest.NewRecorder() c := e.NewContext(req, rec) handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } rid := RequestIDWithConfig(RequestIDConfig{}) h := rid(handler) _ = h(c) assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "") } func TestRequestIDConfigDifferentHeader(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } rid := RequestIDWithConfig(RequestIDConfig{TargetHeader: echo.HeaderXCorrelationID}) h := rid(handler) h(c) assert.Len(t, rec.Header().Get(echo.HeaderXCorrelationID), 32) // Custom generator and handler customID := "customGenerator" calledHandler := false rid = RequestIDWithConfig(RequestIDConfig{ Generator: func() string { return customID }, TargetHeader: echo.HeaderXCorrelationID, RequestIDHandler: func(_ *echo.Context, id string) { calledHandler = true assert.Equal(t, customID, id) }, }) h = rid(handler) h(c) assert.Equal(t, rec.Header().Get(echo.HeaderXCorrelationID), "customGenerator") assert.True(t, calledHandler) } ================================================ FILE: middleware/request_logger.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "context" "errors" "log/slog" "net/http" "time" "github.com/labstack/echo/v5" ) // Example for `slog` https://pkg.go.dev/log/slog // logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogStatus: true, // LogURI: true, // HandleError: true, // forwards error to 
the global error handler, so it can decide appropriate status code // LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { // if v.Error == nil { // logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST", // slog.String("uri", v.URI), // slog.Int("status", v.Status), // ) // } else { // logger.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR", // slog.String("uri", v.URI), // slog.Int("status", v.Status), // slog.String("err", v.Error.Error()), // ) // } // return nil // }, // })) // // Example for `fmt.Printf` // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogStatus: true, // LogURI: true, // HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code // LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { // if v.Error == nil { // fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status) // } else { // fmt.Printf("REQUEST_ERROR: uri: %v, status: %v, err: %v\n", v.URI, v.Status, v.Error) // } // return nil // }, // })) // // Example for Zerolog (https://github.com/rs/zerolog) // logger := zerolog.New(os.Stdout) // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogURI: true, // LogStatus: true, // HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code // LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { // if v.Error == nil { // logger.Info(). // Str("URI", v.URI). // Int("status", v.Status). // Msg("request") // } else { // logger.Error(). // Err(v.Error). // Str("URI", v.URI). // Int("status", v.Status). 
// Msg("request error") // } // return nil // }, // })) // // Example for Zap (https://github.com/uber-go/zap) // logger, _ := zap.NewProduction() // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogURI: true, // LogStatus: true, // HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code // LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { // if v.Error == nil { // logger.Info("request", // zap.String("URI", v.URI), // zap.Int("status", v.Status), // ) // } else { // logger.Error("request error", // zap.String("URI", v.URI), // zap.Int("status", v.Status), // zap.Error(v.Error), // ) // } // return nil // }, // })) // // Example for Logrus (https://github.com/sirupsen/logrus) // log := logrus.New() // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogURI: true, // LogStatus: true, // HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code // LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { // if v.Error == nil { // log.WithFields(logrus.Fields{ // "URI": v.URI, // "status": v.Status, // }).Info("request") // } else { // log.WithFields(logrus.Fields{ // "URI": v.URI, // "status": v.Status, // "error": v.Error, // }).Error("request error") // } // return nil // }, // })) // RequestLoggerConfig is configuration for Request Logger middleware. type RequestLoggerConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // BeforeNextFunc defines a function that is called before next middleware or handler is called in chain. BeforeNextFunc func(c *echo.Context) // LogValuesFunc defines a function that is called with values extracted by logger from request/response. // Mandatory. 
LogValuesFunc func(c *echo.Context, v RequestLoggerValues) error // HandleError instructs logger to call global error handler when next middleware/handler returns an error. // This is useful when you have custom error handler that can decide to use different status codes. // // A side-effect of calling global error handler is that now Response has been committed and sent to the client // and middlewares up in chain can not change Response status code or response body. HandleError bool // LogLatency instructs logger to record duration it took to execute rest of the handler chain (next(c) call). LogLatency bool // LogProtocol instructs logger to extract request protocol (i.e. `HTTP/1.1` or `HTTP/2`) LogProtocol bool // LogRemoteIP instructs logger to extract request remote IP. See `echo.Context.RealIP()` for implementation details. LogRemoteIP bool // LogHost instructs logger to extract request host value (i.e. `example.com`) LogHost bool // LogMethod instructs logger to extract request method value (i.e. `GET` etc) LogMethod bool // LogURI instructs logger to extract request URI (i.e. `/list?lang=en&page=1`) LogURI bool // LogURIPath instructs logger to extract request URI path part (i.e. `/list`) LogURIPath bool // LogRoutePath instructs logger to extract route path part to which request was matched to (i.e. `/user/:id`) LogRoutePath bool // LogRequestID instructs logger to extract request ID from request `X-Request-ID` header or response if request did not have value. LogRequestID bool // LogReferer instructs logger to extract request referer values. LogReferer bool // LogUserAgent instructs logger to extract request user agent values. LogUserAgent bool // LogStatus instructs logger to extract response status code. If handler chain returns an error, // the status code is extracted from the error satisfying echo.StatusCoder interface. LogStatus bool // LogContentLength instructs logger to extract content length header value. 
Note: this value could be different from // actual request body size as it could be spoofed etc. LogContentLength bool // LogResponseSize instructs logger to extract response content length value. Note: when used with Gzip middleware // this value may not be always correct. LogResponseSize bool // LogHeaders instructs logger to extract given list of headers from request. Note: request can contain more than // one header with same value so slice of values is been logger for each given header. // // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding". LogHeaders []string // LogQueryParams instructs logger to extract given list of query parameters from request URI. Note: request can // contain more than one query parameter with same name so slice of values is been logger for each given query param name. LogQueryParams []string // LogFormValues instructs logger to extract given list of form values from request body+URI. Note: request can // contain more than one form value with same name so slice of values is been logger for each given form value name. LogFormValues []string timeNow func() time.Time } // RequestLoggerValues contains extracted values from logger. type RequestLoggerValues struct { // StartTime is time recorded before next middleware/handler is executed. StartTime time.Time // Latency is duration it took to execute rest of the handler chain (next(c) call). Latency time.Duration // Protocol is request protocol (i.e. `HTTP/1.1` or `HTTP/2`) Protocol string // RemoteIP is request remote IP. See `echo.Context.RealIP()` for implementation details. RemoteIP string // Host is request host value (i.e. `example.com`) Host string // Method is request method value (i.e. `GET` etc) Method string // URI is request URI (i.e. `/list?lang=en&page=1`) URI string // URIPath is request URI path part (i.e. 
`/list`) URIPath string // RoutePath is route path part to which request was matched to (i.e. `/user/:id`) RoutePath string // RequestID is request ID from request `X-Request-ID` header or response if request did not have value. RequestID string // Referer is request referer values. Referer string // UserAgent is request user agent values. UserAgent string // Status is a response status code. When the handler returns an error satisfying echo.StatusCoder interface, then code from it. Status int // Error is error returned from executed handler chain. Error error // ContentLength is content length header value. Note: this value could be different from actual request body size // as it could be spoofed etc. ContentLength string // ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct. ResponseSize int64 // Headers are list of headers from request. Note: request can contain more than one header with same value so slice // of values is what will be returned/logged for each given header. // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding". Headers map[string][]string // QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter // with same name so slice of values is what will be returned/logged for each given query param name. QueryParams map[string][]string // FormValues are list of form values from request body+URI. Note: request can contain more than one form value with // same name so slice of values is what will be returned/logged for each given form value name. FormValues map[string][]string } // RequestLoggerWithConfig returns a RequestLogger middleware with config. 
func RequestLoggerWithConfig(config RequestLoggerConfig) echo.MiddlewareFunc { mw, err := config.ToMiddleware() if err != nil { panic(err) } return mw } // ToMiddleware converts RequestLoggerConfig into middleware or returns an error for invalid configuration. func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultSkipper } now := time.Now if config.timeNow != nil { now = config.timeNow } if config.LogValuesFunc == nil { return nil, errors.New("missing LogValuesFunc callback function for request logger middleware") } logHeaders := len(config.LogHeaders) > 0 headers := append([]string(nil), config.LogHeaders...) for i, v := range headers { headers[i] = http.CanonicalHeaderKey(v) } logQueryParams := len(config.LogQueryParams) > 0 logFormValues := len(config.LogFormValues) > 0 return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } req := c.Request() start := now() if config.BeforeNextFunc != nil { config.BeforeNextFunc(c) } err := next(c) if err != nil && config.HandleError { // When global error handler writes the error to the client the Response gets "committed". This state can be // checked with `c.Response().Committed` field. 
c.Echo().HTTPErrorHandler(c, err) } res := c.Response() v := RequestLoggerValues{ StartTime: start, } if config.LogLatency { v.Latency = now().Sub(start) } if config.LogProtocol { v.Protocol = req.Proto } if config.LogRemoteIP { v.RemoteIP = c.RealIP() } if config.LogHost { v.Host = req.Host } if config.LogMethod { v.Method = req.Method } if config.LogURI { v.URI = req.RequestURI } if config.LogURIPath { p := req.URL.Path if p == "" { p = "/" } v.URIPath = p } if config.LogRoutePath { v.RoutePath = c.Path() } if config.LogRequestID { id := req.Header.Get(echo.HeaderXRequestID) if id == "" { id = res.Header().Get(echo.HeaderXRequestID) } v.RequestID = id } if config.LogReferer { v.Referer = req.Referer() } if config.LogUserAgent { v.UserAgent = req.UserAgent() } if config.LogStatus || config.LogResponseSize { resp, status := echo.ResolveResponseStatus(res, err) if config.LogStatus { v.Status = status } if config.LogResponseSize { v.ResponseSize = -1 if resp != nil { v.ResponseSize = resp.Size } } } if err != nil { v.Error = err } if config.LogContentLength { v.ContentLength = req.Header.Get(echo.HeaderContentLength) } if logHeaders { v.Headers = map[string][]string{} for _, header := range headers { if values, ok := req.Header[header]; ok { v.Headers[header] = values } } } if logQueryParams { queryParams := c.QueryParams() v.QueryParams = map[string][]string{} for _, param := range config.LogQueryParams { if values, ok := queryParams[param]; ok { v.QueryParams[param] = values } } } if logFormValues { v.FormValues = map[string][]string{} for _, formValue := range config.LogFormValues { if values, ok := req.Form[formValue]; ok { v.FormValues[formValue] = values } } } if errOnLog := config.LogValuesFunc(c, v); errOnLog != nil { return errOnLog } // in case of HandleError=true we are returning the error that we already have handled with global error handler // this is deliberate as this error could be useful for upstream middlewares and default global error handler // 
will ignore that error when it bubbles up in middleware chain. // Committed response can be checked in custom error handler with following logic // // if r, _ := echo.UnwrapResponse(c.Response()); r != nil && r.Committed { // return // } return err } }, nil } // RequestLogger creates Request Logger middleware with Echo default settings that uses Context.Logger() as logger. func RequestLogger() echo.MiddlewareFunc { return RequestLoggerWithConfig(RequestLoggerConfig{ LogLatency: true, LogRemoteIP: true, LogHost: true, LogMethod: true, LogURI: true, LogRequestID: true, LogUserAgent: true, LogStatus: true, LogContentLength: true, LogResponseSize: true, // forwards error to the global error handler, so it can decide appropriate status code. // NB: side-effect of that is - request is now "committed" written to the client. Middlewares up in chain can not // change Response status code or response body. HandleError: true, LogValuesFunc: func(c *echo.Context, v RequestLoggerValues) error { logger := c.Logger() if v.Error == nil { logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST", slog.String("method", v.Method), slog.String("uri", v.URI), slog.Int("status", v.Status), slog.Duration("latency", v.Latency), slog.String("host", v.Host), slog.String("bytes_in", v.ContentLength), slog.Int64("bytes_out", v.ResponseSize), slog.String("user_agent", v.UserAgent), slog.String("remote_ip", v.RemoteIP), slog.String("request_id", v.RequestID), ) return nil } logger.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR", slog.String("method", v.Method), slog.String("uri", v.URI), slog.Int("status", v.Status), slog.Duration("latency", v.Latency), slog.String("host", v.Host), slog.String("bytes_in", v.ContentLength), slog.Int64("bytes_out", v.ResponseSize), slog.String("user_agent", v.UserAgent), slog.String("remote_ip", v.RemoteIP), slog.String("request_id", v.RequestID), slog.String("error", v.Error.Error()), ) return nil }, }) } 
================================================ FILE: middleware/request_logger_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "bytes" "encoding/json" "errors" "log/slog" "net/http" "net/http/httptest" "net/url" "strconv" "strings" "testing" "time" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestRequestLoggerOK(t *testing.T) { old := slog.Default() t.Cleanup(func() { slog.SetDefault(old) }) e := echo.New() buf := new(bytes.Buffer) e.Logger = slog.New(slog.NewJSONHandler(buf, nil)) e.Use(RequestLogger()) e.POST("/test", func(c *echo.Context) error { return c.String(http.StatusTeapot, "OK") }) reader := strings.NewReader(`{"foo":"bar"}`) req := httptest.NewRequest(http.MethodPost, "/test", reader) req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size()))) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) req.Header.Set(echo.HeaderXRealIP, "8.8.8.8") req.Header.Set("User-Agent", "curl/7.68.0") rec := httptest.NewRecorder() e.ServeHTTP(rec, req) logAttrs := map[string]interface{}{} assert.NoError(t, json.Unmarshal(buf.Bytes(), &logAttrs)) logAttrs["latency"] = 123 logAttrs["time"] = "x" expect := map[string]interface{}{ "level": "INFO", "msg": "REQUEST", "method": "POST", "uri": "/test", "status": float64(418), "bytes_in": "13", "host": "example.com", "bytes_out": float64(2), "user_agent": "curl/7.68.0", "remote_ip": "8.8.8.8", "request_id": "", "time": "x", "latency": 123, } assert.Equal(t, expect, logAttrs) } func TestRequestLoggerError(t *testing.T) { old := slog.Default() t.Cleanup(func() { slog.SetDefault(old) }) e := echo.New() buf := new(bytes.Buffer) e.Logger = slog.New(slog.NewJSONHandler(buf, nil)) e.Use(RequestLogger()) e.GET("/test", func(c *echo.Context) error { return errors.New("nope") }) req := httptest.NewRequest(http.MethodGet, "/test", nil) rec := 
httptest.NewRecorder() e.ServeHTTP(rec, req) logAttrs := map[string]interface{}{} assert.NoError(t, json.Unmarshal(buf.Bytes(), &logAttrs)) logAttrs["latency"] = 123 logAttrs["time"] = "x" expect := map[string]interface{}{ "level": "ERROR", "msg": "REQUEST_ERROR", "method": "GET", "uri": "/test", "status": float64(500), "bytes_in": "", "host": "example.com", "bytes_out": float64(36.0), "user_agent": "", "remote_ip": "192.0.2.1", "request_id": "", "error": "nope", "latency": 123, "time": "x", } assert.Equal(t, expect, logAttrs) } func TestRequestLoggerWithConfig(t *testing.T) { e := echo.New() var expect RequestLoggerValues e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ LogRoutePath: true, LogURI: true, LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { expect = values return nil }, })) e.GET("/test", func(c *echo.Context) error { return c.String(http.StatusTeapot, "OK") }) req := httptest.NewRequest(http.MethodGet, "/test", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusTeapot, rec.Code) assert.Equal(t, "/test", expect.RoutePath) } func TestRequestLoggerWithConfig_missingOnLogValuesPanics(t *testing.T) { assert.Panics(t, func() { RequestLoggerWithConfig(RequestLoggerConfig{ LogValuesFunc: nil, }) }) } func TestRequestLogger_skipper(t *testing.T) { e := echo.New() loggerCalled := false e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ Skipper: func(c *echo.Context) bool { return true }, LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { loggerCalled = true return nil }, })) e.GET("/test", func(c *echo.Context) error { return c.String(http.StatusTeapot, "OK") }) req := httptest.NewRequest(http.MethodGet, "/test", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusTeapot, rec.Code) assert.False(t, loggerCalled) } func TestRequestLogger_beforeNextFunc(t *testing.T) { e := echo.New() var myLoggerInstance int 
e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ BeforeNextFunc: func(c *echo.Context) { c.Set("myLoggerInstance", 42) }, LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { myLoggerInstance = c.Get("myLoggerInstance").(int) return nil }, })) e.GET("/test", func(c *echo.Context) error { return c.String(http.StatusTeapot, "OK") }) req := httptest.NewRequest(http.MethodGet, "/test", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusTeapot, rec.Code) assert.Equal(t, 42, myLoggerInstance) } func TestRequestLogger_logError(t *testing.T) { e := echo.New() var actual RequestLoggerValues e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ LogStatus: true, LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { actual = values return nil }, })) e.GET("/test", func(c *echo.Context) error { return echo.NewHTTPError(http.StatusNotAcceptable, "nope") }) req := httptest.NewRequest(http.MethodGet, "/test", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusNotAcceptable, rec.Code) assert.Equal(t, http.StatusNotAcceptable, actual.Status) assert.EqualError(t, actual.Error, "code=406, message=nope") } func TestRequestLogger_HandleError(t *testing.T) { e := echo.New() var actual RequestLoggerValues e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ timeNow: func() time.Time { return time.Unix(1631045377, 0).UTC() }, HandleError: true, LogStatus: true, LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { actual = values return nil }, })) // to see if "HandleError" works we create custom error handler that uses its own status codes e.HTTPErrorHandler = func(c *echo.Context, err error) { if r, _ := echo.UnwrapResponse(c.Response()); r != nil && r.Committed { return } c.JSON(http.StatusTeapot, "custom error handler") } e.GET("/test", func(c *echo.Context) error { return echo.NewHTTPError(http.StatusForbidden, "nope") }) req := httptest.NewRequest(http.MethodGet, 
"/test", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusTeapot, rec.Code) expect := RequestLoggerValues{ StartTime: time.Unix(1631045377, 0).UTC(), Status: http.StatusTeapot, Error: echo.NewHTTPError(http.StatusForbidden, "nope"), } assert.Equal(t, expect, actual) } func TestRequestLogger_LogValuesFuncError(t *testing.T) { e := echo.New() var expect RequestLoggerValues e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ LogStatus: true, LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { expect = values return echo.NewHTTPError(http.StatusNotAcceptable, "LogValuesFuncError") }, })) e.GET("/test", func(c *echo.Context) error { return c.String(http.StatusTeapot, "OK") }) req := httptest.NewRequest(http.MethodGet, "/test", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) // NOTE: when global error handler received error returned from middleware the status has already // been written to the client and response has been "committed" therefore global error handler does not do anything // and error that bubbled up in middleware chain will not be reflected in response code. 
assert.Equal(t, http.StatusTeapot, rec.Code) assert.Equal(t, http.StatusTeapot, expect.Status) } func TestRequestLogger_ID(t *testing.T) { var testCases = []struct { name string whenFromRequest bool expect string }{ { name: "ok, ID is provided from request headers", whenFromRequest: true, expect: "123", }, { name: "ok, ID is from response headers", whenFromRequest: false, expect: "321", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() var expect RequestLoggerValues e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ LogRequestID: true, LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { expect = values return nil }, })) e.GET("/test", func(c *echo.Context) error { c.Response().Header().Set(echo.HeaderXRequestID, "321") return c.String(http.StatusTeapot, "OK") }) req := httptest.NewRequest(http.MethodGet, "/test", nil) if tc.whenFromRequest { req.Header.Set(echo.HeaderXRequestID, "123") } rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusTeapot, rec.Code) assert.Equal(t, tc.expect, expect.RequestID) }) } } func TestRequestLogger_headerIsCaseInsensitive(t *testing.T) { e := echo.New() var expect RequestLoggerValues mw := RequestLoggerWithConfig(RequestLoggerConfig{ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { expect = values return nil }, LogHeaders: []string{"referer", "User-Agent"}, })(func(c *echo.Context) error { c.Request().Header.Set(echo.HeaderXRequestID, "123") c.FormValue("to force parse form") return c.String(http.StatusTeapot, "OK") }) req := httptest.NewRequest(http.MethodGet, "/test?lang=en&checked=1&checked=2", nil) req.Header.Set("referer", "https://echo.labstack.com/") rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := mw(c) assert.NoError(t, err) assert.Len(t, expect.Headers, 1) assert.Equal(t, []string{"https://echo.labstack.com/"}, expect.Headers["Referer"]) } func TestRequestLogger_allFields(t *testing.T) { e := 
echo.New() isFirstNowCall := true var expect RequestLoggerValues mw := RequestLoggerWithConfig(RequestLoggerConfig{ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { expect = values return nil }, LogLatency: true, LogProtocol: true, LogRemoteIP: true, LogHost: true, LogMethod: true, LogURI: true, LogURIPath: true, LogRoutePath: true, LogRequestID: true, LogReferer: true, LogUserAgent: true, LogStatus: true, LogContentLength: true, LogResponseSize: true, LogHeaders: []string{"accept-encoding", "User-Agent"}, LogQueryParams: []string{"lang", "checked"}, LogFormValues: []string{"csrf", "multiple"}, timeNow: func() time.Time { if isFirstNowCall { isFirstNowCall = false return time.Unix(1631045377, 0) } return time.Unix(1631045377+10, 0) }, })(func(c *echo.Context) error { c.Request().Header.Set(echo.HeaderXRequestID, "123") c.FormValue("to force parse form") return c.String(http.StatusTeapot, "OK") }) f := make(url.Values) f.Set("csrf", "token") f.Set("multiple", "1") f.Add("multiple", "2") reader := strings.NewReader(f.Encode()) req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", reader) req.Header.Set("Referer", "https://echo.labstack.com/") req.Header.Set("User-Agent", "curl/7.68.0") req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size()))) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) req.Header.Set(echo.HeaderXRealIP, "8.8.8.8") rec := httptest.NewRecorder() c := e.NewContext(req, rec) c.SetPath("/test*") err := mw(c) assert.NoError(t, err) assert.Equal(t, time.Unix(1631045377, 0), expect.StartTime) assert.Equal(t, 10*time.Second, expect.Latency) assert.Equal(t, "HTTP/1.1", expect.Protocol) assert.Equal(t, "8.8.8.8", expect.RemoteIP) assert.Equal(t, "example.com", expect.Host) assert.Equal(t, http.MethodPost, expect.Method) assert.Equal(t, "/test?lang=en&checked=1&checked=2", expect.URI) assert.Equal(t, "/test", expect.URIPath) assert.Equal(t, "/test*", expect.RoutePath) 
assert.Equal(t, "123", expect.RequestID) assert.Equal(t, "https://echo.labstack.com/", expect.Referer) assert.Equal(t, "curl/7.68.0", expect.UserAgent) assert.Equal(t, 418, expect.Status) assert.Equal(t, nil, expect.Error) assert.Equal(t, "32", expect.ContentLength) assert.Equal(t, int64(2), expect.ResponseSize) assert.Len(t, expect.Headers, 1) assert.Equal(t, []string{"curl/7.68.0"}, expect.Headers["User-Agent"]) assert.Len(t, expect.QueryParams, 2) assert.Equal(t, []string{"en"}, expect.QueryParams["lang"]) assert.Equal(t, []string{"1", "2"}, expect.QueryParams["checked"]) assert.Len(t, expect.FormValues, 2) assert.Equal(t, []string{"token"}, expect.FormValues["csrf"]) assert.Equal(t, []string{"1", "2"}, expect.FormValues["multiple"]) } func TestTestRequestLogger(t *testing.T) { var testCases = []struct { name string whenStatus int whenError error expectStatus string expectError string }{ { name: "ok", whenStatus: http.StatusTeapot, expectStatus: "418", }, { name: "error", whenError: echo.NewHTTPError(http.StatusBadGateway, "bad gw"), expectStatus: "502", expectError: `"error":"code=502, message=bad gw"`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() buf := new(bytes.Buffer) e.Logger = slog.New(slog.NewJSONHandler(buf, nil)) e.Use(RequestLogger()) e.POST("/test", func(c *echo.Context) error { if tc.whenError != nil { return tc.whenError } return c.String(tc.whenStatus, "OK") }) f := make(url.Values) f.Set("csrf", "token") f.Set("multiple", "1") f.Add("multiple", "2") reader := strings.NewReader(f.Encode()) req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", reader) req.Header.Set("Referer", "https://echo.labstack.com/") req.Header.Set("User-Agent", "curl/7.68.0") req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size()))) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) req.Header.Set(echo.HeaderXRealIP, "8.8.8.8") req.Header.Set(echo.HeaderXRequestID, "MY_ID") 
rec := httptest.NewRecorder() e.ServeHTTP(rec, req) rawlog := buf.Bytes() if tc.expectError != "" { assert.Contains(t, string(rawlog), `"level":"ERROR"`) assert.Contains(t, string(rawlog), `"msg":"REQUEST_ERROR"`) assert.Contains(t, string(rawlog), tc.expectError) } else { assert.Contains(t, string(rawlog), `"level":"INFO"`) assert.Contains(t, string(rawlog), `"msg":"REQUEST"`) } assert.Contains(t, string(rawlog), `"status":`+tc.expectStatus) assert.Contains(t, string(rawlog), `"method":"POST"`) assert.Contains(t, string(rawlog), `"uri":"/test?lang=en&checked=1&checked=2"`) assert.Contains(t, string(rawlog), `"latency":`) // this value varies assert.Contains(t, string(rawlog), `"request_id":"MY_ID"`) assert.Contains(t, string(rawlog), `"remote_ip":"8.8.8.8"`) assert.Contains(t, string(rawlog), `"host":"example.com"`) assert.Contains(t, string(rawlog), `"user_agent":"curl/7.68.0"`) assert.Contains(t, string(rawlog), `"bytes_in":"32"`) assert.Contains(t, string(rawlog), `"bytes_out":2`) }) } } func BenchmarkRequestLogger_withoutMapFields(b *testing.B) { e := echo.New() mw := RequestLoggerWithConfig(RequestLoggerConfig{ Skipper: nil, LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { return nil }, LogLatency: true, LogProtocol: true, LogRemoteIP: true, LogHost: true, LogMethod: true, LogURI: true, LogURIPath: true, LogRoutePath: true, LogRequestID: true, LogReferer: true, LogUserAgent: true, LogStatus: true, LogContentLength: true, LogResponseSize: true, })(func(c *echo.Context) error { c.Request().Header.Set(echo.HeaderXRequestID, "123") return c.String(http.StatusTeapot, "OK") }) req := httptest.NewRequest(http.MethodGet, "/test?lang=en", nil) req.Header.Set("Referer", "https://echo.labstack.com/") req.Header.Set("User-Agent", "curl/7.68.0") b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { rec := httptest.NewRecorder() c := e.NewContext(req, rec) mw(c) } } func BenchmarkRequestLogger_withMapFields(b *testing.B) { e := echo.New() mw := 
RequestLoggerWithConfig(RequestLoggerConfig{ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { return nil }, LogLatency: true, LogProtocol: true, LogRemoteIP: true, LogHost: true, LogMethod: true, LogURI: true, LogURIPath: true, LogRoutePath: true, LogRequestID: true, LogReferer: true, LogUserAgent: true, LogStatus: true, LogContentLength: true, LogResponseSize: true, LogHeaders: []string{"accept-encoding", "User-Agent"}, LogQueryParams: []string{"lang", "checked"}, LogFormValues: []string{"csrf", "multiple"}, })(func(c *echo.Context) error { c.Request().Header.Set(echo.HeaderXRequestID, "123") c.FormValue("to force parse form") return c.String(http.StatusTeapot, "OK") }) f := make(url.Values) f.Set("csrf", "token") f.Add("multiple", "1") f.Add("multiple", "2") req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", strings.NewReader(f.Encode())) req.Header.Set("Referer", "https://echo.labstack.com/") req.Header.Set("User-Agent", "curl/7.68.0") req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { rec := httptest.NewRecorder() c := e.NewContext(req, rec) mw(c) } } ================================================ FILE: middleware/rewrite.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "errors" "regexp" "github.com/labstack/echo/v5" ) // RewriteConfig defines the config for Rewrite middleware. type RewriteConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // Rules defines the URL path rewrite rules. The values captured in asterisk can be // retrieved by index e.g. $1, $2 and so on. // Example: // "/old": "/new", // "/api/*": "/$1", // "/js/*": "/public/javascripts/$1", // "/users/*/orders/*": "/user/$1/order/$2", // Required. 
Rules map[string]string // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. // Example: // "^/old/[0.9]+/": "/new", // "^/api/.+?/(.*)": "/v2/$1", RegexRules map[*regexp.Regexp]string } // Rewrite returns a Rewrite middleware. // // Rewrite middleware rewrites the URL path based on the provided rules. func Rewrite(rules map[string]string) echo.MiddlewareFunc { c := RewriteConfig{} c.Rules = rules return RewriteWithConfig(c) } // RewriteWithConfig returns a Rewrite middleware or panics on invalid configuration. // // Rewrite middleware rewrites the URL path based on the provided rules. func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } // ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultSkipper } if config.Rules == nil && config.RegexRules == nil { return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules") } if config.RegexRules == nil { config.RegexRules = make(map[*regexp.Regexp]string) } for k, v := range rewriteRulesRegex(config.Rules) { config.RegexRules[k] = v } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) (err error) { if config.Skipper(c) { return next(c) } if err := rewriteURL(config.RegexRules, c.Request()); err != nil { return err } return next(c) } }, nil } ================================================ FILE: middleware/rewrite_test.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "io" "net/http" "net/http/httptest" "net/url" "regexp" "testing" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) 
func TestRewriteAfterRouting(t *testing.T) { e := echo.New() // middlewares added with `Use()` are executed after routing is done and do not affect which route handler is matched e.Use(RewriteWithConfig(RewriteConfig{ Rules: map[string]string{ "/old": "/new", "/api/*": "/$1", "/js/*": "/public/javascripts/$1", "/users/*/orders/*": "/user/$1/order/$2", }, })) e.GET("/public/*", func(c *echo.Context) error { return c.String(http.StatusOK, c.Param("*")) }) e.GET("/*", func(c *echo.Context) error { return c.String(http.StatusOK, c.Param("*")) }) var testCases = []struct { whenPath string expectRoutePath string expectRequestPath string expectRequestRawPath string }{ { whenPath: "/api/users", expectRoutePath: "api/users", expectRequestPath: "/users", expectRequestRawPath: "", }, { whenPath: "/js/main.js", expectRoutePath: "js/main.js", expectRequestPath: "/public/javascripts/main.js", expectRequestRawPath: "", }, { whenPath: "/users/jack/orders/1", expectRoutePath: "users/jack/orders/1", expectRequestPath: "/user/jack/order/1", expectRequestRawPath: "", }, { // no rewrite rule matched. already encoded URL should not be double encoded or changed in any way whenPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", expectRoutePath: "user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", expectRequestPath: "/user/jill/order/T/cO4lW/t/Vp/", // this is equal to `url.Parse(tc.whenPath)` result expectRequestRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", }, { // just rewrite but do not touch encoding. already encoded URL should not be double encoded whenPath: "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F", expectRoutePath: "users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F", expectRequestPath: "/user/jill/order/T/cO4lW/t/Vp/", // this is equal to `url.Parse(tc.whenPath)` result expectRequestRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", }, { // ` ` (space) is encoded by httpClient to `%20` when doing request to Echo. 
`%20` should not be double escaped or changed in any way when rewriting request whenPath: "/api/new users", expectRoutePath: "api/new users", expectRequestPath: "/new users", expectRequestRawPath: "", }, } for _, tc := range testCases { t.Run(tc.whenPath, func(t *testing.T) { target, _ := url.Parse(tc.whenPath) req := httptest.NewRequest(http.MethodGet, target.String(), nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, tc.expectRoutePath, rec.Body.String()) assert.Equal(t, tc.expectRequestPath, req.URL.Path) assert.Equal(t, tc.expectRequestRawPath, req.URL.RawPath) }) } } func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) { assert.Panics(t, func() { RewriteWithConfig(RewriteConfig{}) }) } func TestMustRewriteWithConfig_skipper(t *testing.T) { var testCases = []struct { name string givenSkipper func(c *echo.Context) bool whenURL string expectURL string expectStatus int }{ { name: "not skipped", whenURL: "/old", expectURL: "/new", expectStatus: http.StatusOK, }, { name: "skipped", givenSkipper: func(c *echo.Context) bool { return true }, whenURL: "/old", expectURL: "/old", expectStatus: http.StatusNotFound, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() e.Pre(RewriteWithConfig( RewriteConfig{ Skipper: tc.givenSkipper, Rules: map[string]string{"/old": "/new"}}, )) e.GET("/new", func(c *echo.Context) error { return c.NoContent(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, tc.expectURL, req.URL.EscapedPath()) assert.Equal(t, tc.expectStatus, rec.Code) }) } } // Issue #1086 func TestEchoRewritePreMiddleware(t *testing.T) { e := echo.New() // Rewrite old url to new one // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches e.Pre(RewriteWithConfig(RewriteConfig{ Rules: map[string]string{"/old": 
"/new"}}), ) // Route e.Add(http.MethodGet, "/new", func(c *echo.Context) error { return c.NoContent(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/old", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/new", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) } // Issue #1143 func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { e := echo.New() // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches e.Pre(RewriteWithConfig(RewriteConfig{ Rules: map[string]string{ "/api/*/mgmt/proj/*/agt": "/api/$1/hosts/$2", "/api/*/mgmt/proj": "/api/$1/eng", }, })) e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c *echo.Context) error { return c.String(http.StatusOK, "hosts") }) e.Add(http.MethodGet, "/api/:version/eng", func(c *echo.Context) error { return c.String(http.StatusOK, "eng") }) for i := 0; i < 100; i++ { req := httptest.NewRequest(http.MethodGet, "/api/v1/mgmt/proj/test/agt", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/api/v1/hosts/test", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) defer rec.Result().Body.Close() bodyBytes, _ := io.ReadAll(rec.Result().Body) assert.Equal(t, "hosts", string(bodyBytes)) } } // Issue #1573 func TestEchoRewriteWithCaret(t *testing.T) { e := echo.New() e.Pre(RewriteWithConfig(RewriteConfig{ Rules: map[string]string{ "^/abc/*": "/v1/abc/$1", }, })) rec := httptest.NewRecorder() var req *http.Request req = httptest.NewRequest(http.MethodGet, "/abc/test", nil) e.ServeHTTP(rec, req) assert.Equal(t, "/v1/abc/test", req.URL.Path) req = httptest.NewRequest(http.MethodGet, "/v1/abc/test", nil) e.ServeHTTP(rec, req) assert.Equal(t, "/v1/abc/test", req.URL.Path) req = httptest.NewRequest(http.MethodGet, "/v2/abc/test", nil) e.ServeHTTP(rec, req) assert.Equal(t, "/v2/abc/test", req.URL.Path) } // Verify regex used with rewrite func TestEchoRewriteWithRegexRules(t 
*testing.T) { e := echo.New() e.Pre(RewriteWithConfig(RewriteConfig{ Rules: map[string]string{ "^/a/*": "/v1/$1", "^/b/*/c/*": "/v2/$2/$1", "^/c/*/*": "/v3/$2", }, RegexRules: map[*regexp.Regexp]string{ regexp.MustCompile("^/x/.+?/(.*)"): "/v4/$1", regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1", }, })) var rec *httptest.ResponseRecorder var req *http.Request testCases := []struct { requestPath string expectPath string }{ {"/unmatched", "/unmatched"}, {"/a/test", "/v1/test"}, {"/b/foo/c/bar/baz", "/v2/bar/baz/foo"}, {"/c/ignore/test", "/v3/test"}, {"/c/ignore1/test/this", "/v3/test/this"}, {"/x/ignore/test", "/v4/test"}, {"/y/foo/bar", "/v5/bar/foo"}, } for _, tc := range testCases { t.Run(tc.requestPath, func(t *testing.T) { req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, tc.expectPath, req.URL.EscapedPath()) }) } } // Ensure correct escaping as defined in replacement (issue #1798) func TestEchoRewriteReplacementEscaping(t *testing.T) { e := echo.New() // NOTE: these are incorrect regexps as they do not factor in that URI we are replacing could contain ? 
(query) and # (fragment) parts // so in reality they append query and fragment part as `$1` matches everything after that prefix e.Pre(RewriteWithConfig(RewriteConfig{ Rules: map[string]string{ "^/a/*": "/$1?query=param", "^/b/*": "/$1;part#one", }, RegexRules: map[*regexp.Regexp]string{ regexp.MustCompile("^/x/(.*)"): "/$1?query=param", regexp.MustCompile("^/y/(.*)"): "/$1;part#one", regexp.MustCompile("^/z/(.*)"): "/$1?test=1#escaped%20test", }, })) var rec *httptest.ResponseRecorder var req *http.Request testCases := []struct { requestPath string expect string }{ {"/unmatched", "/unmatched"}, {"/a/test", "/test?query=param"}, {"/b/foo/bar", "/foo/bar;part#one"}, {"/x/test", "/test?query=param"}, {"/y/foo/bar", "/foo/bar;part#one"}, {"/z/foo/b%20ar", "/foo/b%20ar?test=1#escaped%20test"}, {"/z/foo/b%20ar?nope=1#yes", "/foo/b%20ar?nope=1#yes?test=1%23escaped%20test"}, // example of appending } for _, tc := range testCases { t.Run(tc.requestPath, func(t *testing.T) { req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, tc.expect, req.URL.String()) }) } } ================================================ FILE: middleware/secure.go ================================================ // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors package middleware import ( "fmt" "github.com/labstack/echo/v5" ) // SecureConfig defines the config for Secure middleware. type SecureConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // XSSProtection provides protection against cross-site scripting attack (XSS) // by setting the `X-XSS-Protection` header. // Optional. Default value "1; mode=block". XSSProtection string // ContentTypeNosniff provides protection against overriding Content-Type // header by setting the `X-Content-Type-Options` header. // Optional. Default value "nosniff". 
ContentTypeNosniff string // XFrameOptions can be used to indicate whether or not a browser should // be allowed to render a page in a ,